aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--CONTRIBUTING.md4
-rw-r--r--README.md3
-rw-r--r--RELEASE.md11
-rw-r--r--SECURITY.md11
-rw-r--r--WORKSPACE24
-rw-r--r--configure.py205
-rw-r--r--tensorflow/BUILD74
-rw-r--r--tensorflow/api_template.__init__.py18
-rw-r--r--tensorflow/c/c_api.cc99
-rw-r--r--tensorflow/c/c_api.h10
-rw-r--r--tensorflow/c/c_api_experimental.cc27
-rw-r--r--tensorflow/c/c_api_experimental.h14
-rw-r--r--tensorflow/c/c_api_test.cc65
-rw-r--r--tensorflow/c/c_test_util.cc7
-rw-r--r--tensorflow/c/c_test_util.h3
-rw-r--r--tensorflow/c/eager/BUILD6
-rw-r--r--tensorflow/c/eager/c_api.cc120
-rw-r--r--tensorflow/c/eager/c_api.h6
-rw-r--r--tensorflow/c/eager/c_api_internal.h4
-rw-r--r--tensorflow/c/eager/c_api_test.cc107
-rw-r--r--tensorflow/c/eager/tape.h7
-rw-r--r--tensorflow/cc/BUILD2
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc71
-rw-r--r--tensorflow/cc/framework/scope.cc30
-rw-r--r--tensorflow/cc/framework/scope_internal.h3
-rw-r--r--tensorflow/cc/framework/scope_test.cc10
-rw-r--r--tensorflow/cc/gradients/array_grad.cc52
-rw-r--r--tensorflow/cc/gradients/array_grad_test.cc7
-rw-r--r--tensorflow/cc/gradients/math_grad.cc1
-rw-r--r--tensorflow/compiler/aot/codegen.cc2
-rw-r--r--tensorflow/compiler/aot/tests/BUILD4
-rw-r--r--tensorflow/compiler/jit/BUILD77
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc22
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc481
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.h12
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc68
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc13
-rw-r--r--tensorflow/compiler/jit/legacy_flags/BUILD12
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc169
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc47
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.cc188
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util.h49
-rw-r--r--tensorflow/compiler/jit/xla_cluster_util_test.cc69
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc19
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc1
-rw-r--r--tensorflow/compiler/jit/xla_device.cc70
-rw-r--r--tensorflow/compiler/jit/xla_device.h22
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc238
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h19
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h79
-rw-r--r--tensorflow/compiler/jit/xla_fusion_optimizer.cc328
-rw-r--r--tensorflow/compiler/jit/xla_fusion_optimizer.h49
-rw-r--r--tensorflow/compiler/jit/xla_fusion_optimizer_test.cc183
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc1
-rw-r--r--tensorflow/compiler/jit/xla_interpreter_device.cc1
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc60
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h9
-rw-r--r--tensorflow/compiler/jit/xla_tensor.cc30
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h26
-rw-r--r--tensorflow/compiler/tests/BUILD190
-rw-r--r--tensorflow/compiler/tests/adadelta_test.py134
-rw-r--r--tensorflow/compiler/tests/adagrad_da_test.py165
-rw-r--r--tensorflow/compiler/tests/adagrad_test.py4
-rw-r--r--tensorflow/compiler/tests/adam_test.py4
-rw-r--r--tensorflow/compiler/tests/adamax_test.py139
-rw-r--r--tensorflow/compiler/tests/addsign_test.py142
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py27
-rw-r--r--tensorflow/compiler/tests/bucketize_op_test.py4
-rw-r--r--tensorflow/compiler/tests/categorical_op_test.py4
-rw-r--r--tensorflow/compiler/tests/cholesky_op_test.py4
-rw-r--r--tensorflow/compiler/tests/clustering_test.py4
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py8
-rw-r--r--tensorflow/compiler/tests/conv2d_test.py317
-rw-r--r--tensorflow/compiler/tests/conv3d_test.py6
-rw-r--r--tensorflow/compiler/tests/depthwise_conv_op_test.py4
-rw-r--r--tensorflow/compiler/tests/dynamic_slice_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/dynamic_stitch_test.py4
-rw-r--r--tensorflow/compiler/tests/eager_test.py167
-rw-r--r--tensorflow/compiler/tests/extract_image_patches_op_test.py4
-rw-r--r--tensorflow/compiler/tests/fake_quant_ops_test.py10
-rw-r--r--tensorflow/compiler/tests/fft_test.py10
-rw-r--r--tensorflow/compiler/tests/fifo_queue_test.py201
-rw-r--r--tensorflow/compiler/tests/ftrl_test.py4
-rw-r--r--tensorflow/compiler/tests/function_test.py4
-rw-r--r--tensorflow/compiler/tests/fused_batchnorm_test.py174
-rw-r--r--tensorflow/compiler/tests/gather_nd_op_test.py4
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py51
-rw-r--r--tensorflow/compiler/tests/lrn_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/matrix_band_part_test.py4
-rw-r--r--tensorflow/compiler/tests/matrix_triangular_solve_op_test.py4
-rw-r--r--tensorflow/compiler/tests/momentum_test.py4
-rw-r--r--tensorflow/compiler/tests/nary_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/nullary_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/placeholder_test.py4
-rw-r--r--tensorflow/compiler/tests/pooling_ops_3d_test.py83
-rw-r--r--tensorflow/compiler/tests/pooling_ops_test.py6
-rw-r--r--tensorflow/compiler/tests/powersign_test.py142
-rw-r--r--tensorflow/compiler/tests/proximal_adagrad_test.py172
-rw-r--r--tensorflow/compiler/tests/proximal_gradient_descent_test.py156
-rw-r--r--tensorflow/compiler/tests/qr_op_test.py112
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py95
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py6
-rw-r--r--tensorflow/compiler/tests/reduce_window_test.py4
-rw-r--r--tensorflow/compiler/tests/reverse_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/reverse_sequence_op_test.py4
-rw-r--r--tensorflow/compiler/tests/rmsprop_test.py121
-rw-r--r--tensorflow/compiler/tests/scan_ops_test.py6
-rw-r--r--tensorflow/compiler/tests/scatter_nd_op_test.py4
-rw-r--r--tensorflow/compiler/tests/segment_reduction_ops_test.py98
-rw-r--r--tensorflow/compiler/tests/slice_ops_test.py6
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py140
-rw-r--r--tensorflow/compiler/tests/spacetobatch_op_test.py6
-rw-r--r--tensorflow/compiler/tests/sparse_to_dense_op_test.py118
-rw-r--r--tensorflow/compiler/tests/stack_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py55
-rw-r--r--tensorflow/compiler/tests/ternary_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/test_utils.py63
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py270
-rw-r--r--tensorflow/compiler/tests/variable_ops_test.py221
-rw-r--r--tensorflow/compiler/tests/while_test.py4
-rw-r--r--tensorflow/compiler/tests/xla_device_test.py4
-rw-r--r--tensorflow/compiler/tests/xla_test.py57
-rw-r--r--tensorflow/compiler/tests/xla_test_test.py44
-rw-r--r--tensorflow/compiler/tf2xla/BUILD26
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc55
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc56
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD15
-rw-r--r--tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc65
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bcast_ops.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bias_ops.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc119
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bucketize_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cast_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/categorical_op.cc25
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cholesky_op.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc31
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc61
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cross_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc76
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/elu_op.cc30
-rw-r--r--tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc16
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc59
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fft_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fill_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/if_op.cc41
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_ops.cc159
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc220
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc18
-rw-r--r--tensorflow/compiler/tf2xla/kernels/l2loss_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/listdiff_op.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/lrn_ops.cc39
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matmul_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc37
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pack_op.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pad_op.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc114
-rw-r--r--tensorflow/compiler/tf2xla/kernels/qr_op.cc47
-rw-r--r--tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc138
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc212
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc34
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc23
-rw-r--r--tensorflow/compiler/tf2xla/kernels/relu_op.cc22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reshape_op.cc6
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc26
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_op.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc105
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc5
-rw-r--r--tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc108
-rw-r--r--tensorflow/compiler/tf2xla/kernels/select_op.cc9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sequence_ops.cc17
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc24
-rw-r--r--tensorflow/compiler/tf2xla/kernels/slice_op.cc7
-rw-r--r--tensorflow/compiler/tf2xla/kernels/softmax_op.cc82
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sort_ops.cc35
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc9
-rw-r--r--tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc88
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc11
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stack_ops.cc35
-rw-r--r--tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc160
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc44
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/topk_op.cc84
-rw-r--r--tensorflow/compiler/tf2xla/kernels/training_ops.cc579
-rw-r--r--tensorflow/compiler/tf2xla/kernels/transpose_op.cc32
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc206
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unpack_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/kernels/variable_ops.cc188
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc17
-rw-r--r--tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc (renamed from tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc)28
-rw-r--r--tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h (renamed from tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h)31
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD48
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc166
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.h7
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.cc318
-rw-r--r--tensorflow/compiler/tf2xla/lib/cholesky.h3
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.cc387
-rw-r--r--tensorflow/compiler/tf2xla/lib/qr.h40
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.cc55
-rw-r--r--tensorflow/compiler/tf2xla/lib/random.h35
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc52
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc1099
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.h22
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc106
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.cc265
-rw-r--r--tensorflow/compiler/tf2xla/lib/util.h54
-rw-r--r--tensorflow/compiler/tf2xla/lib/util_test.cc24
-rw-r--r--tensorflow/compiler/tf2xla/lib/while_loop.cc27
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc39
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h16
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc5
-rw-r--r--tensorflow/compiler/tf2xla/ops/BUILD7
-rw-r--r--tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc49
-rw-r--r--tensorflow/compiler/tf2xla/ops/functional_ops.cc74
-rw-r--r--tensorflow/compiler/tf2xla/ops/reduce_window_op.cc45
-rw-r--r--tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc61
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc182
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py2
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_test.cc5
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc103
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h26
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc115
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc44
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h13
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_gpu_backend.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc203
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h40
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc152
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h48
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc240
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h2
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry_test.cc119
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.cc25
-rw-r--r--tensorflow/compiler/tf2xla/xla_resource.h2
-rw-r--r--tensorflow/compiler/xla/BUILD50
-rw-r--r--tensorflow/compiler/xla/client/BUILD3
-rw-r--r--tensorflow/compiler/xla/client/client.cc2
-rw-r--r--tensorflow/compiler/xla/client/client.h2
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.cc6
-rw-r--r--tensorflow/compiler/xla/client/compile_only_client.h10
-rw-r--r--tensorflow/compiler/xla/client/executable_build_options.cc12
-rw-r--r--tensorflow/compiler/xla/client/executable_build_options.h8
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD92
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc50
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h2
-rw-r--r--tensorflow/compiler/xla/client/lib/constants.cc103
-rw-r--r--tensorflow/compiler/xla/client/lib/constants.h124
-rw-r--r--tensorflow/compiler/xla/client/lib/constants_test.cc159
-rw-r--r--tensorflow/compiler/xla/client/lib/math.cc152
-rw-r--r--tensorflow/compiler/xla/client/lib/math.h51
-rw-r--r--tensorflow/compiler/xla/client/lib/math_test.cc86
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.cc79
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric.h34
-rw-r--r--tensorflow/compiler/xla/client/lib/numeric_test.cc37
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc8
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc75
-rw-r--r--tensorflow/compiler/xla/client/local_client.h6
-rw-r--r--tensorflow/compiler/xla/client/xla_client/BUILD7
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc1048
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.h1473
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc176
-rw-r--r--tensorflow/compiler/xla/experimental/xla_sharding/BUILD18
-rw-r--r--tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py204
-rw-r--r--tensorflow/compiler/xla/layout_util.cc61
-rw-r--r--tensorflow/compiler/xla/layout_util_test.cc51
-rw-r--r--tensorflow/compiler/xla/literal.cc1967
-rw-r--r--tensorflow/compiler/xla/literal.h1152
-rw-r--r--tensorflow/compiler/xla/literal_comparison.cc15
-rw-r--r--tensorflow/compiler/xla/literal_comparison.h2
-rw-r--r--tensorflow/compiler/xla/literal_test.cc (renamed from tensorflow/compiler/xla/literal_util_test.cc)542
-rw-r--r--tensorflow/compiler/xla/literal_util.cc2119
-rw-r--r--tensorflow/compiler/xla/literal_util.h1164
-rw-r--r--tensorflow/compiler/xla/overflow_util.h50
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc2
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.h2
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc5
-rw-r--r--tensorflow/compiler/xla/primitive_util.h3
-rw-r--r--tensorflow/compiler/xla/python/BUILD5
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc319
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h48
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.i34
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.cc5
-rw-r--r--tensorflow/compiler/xla/python/numpy_bridge.h2
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py91
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py116
-rw-r--r--tensorflow/compiler/xla/python_api/BUILD36
-rw-r--r--tensorflow/compiler/xla/python_api/types.py124
-rw-r--r--tensorflow/compiler/xla/python_api/xla_literal.py95
-rw-r--r--tensorflow/compiler/xla/python_api/xla_shape.py155
-rw-r--r--tensorflow/compiler/xla/reference_util.cc5
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc46
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD6
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_client_test.cc20
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service.h2
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc6
-rw-r--r--tensorflow/compiler/xla/rpc/xla_service.proto16
-rw-r--r--tensorflow/compiler/xla/service/BUILD298
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc255
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc431
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc220
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.h6
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc31
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc208
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h14
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc111
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_support.cc7
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc3
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc128
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc1
-rw-r--r--tensorflow/compiler/xla/service/call_graph_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.cc9
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.h2
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.cc78
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.h78
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc6
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.h6
-rw-r--r--tensorflow/compiler/xla/service/compiler.cc21
-rw-r--r--tensorflow/compiler/xla/service/compiler.h30
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.h9
-rw-r--r--tensorflow/compiler/xla/service/computation_placer.cc2
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.cc2
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier_test.cc23
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc94
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h23
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc188
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD54
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc47
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.cc39
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc12
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc307
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h11
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.cc50
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool.h65
-rw-r--r--tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc82
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc754
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h29
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc2
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc15
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD10
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc130
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc42
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc9
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.cc22
-rw-r--r--tensorflow/compiler/xla/service/cpu/vector_support_library.h16
-rw-r--r--tensorflow/compiler/xla/service/defuser_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h8
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h8
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc191
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/executable.cc39
-rw-r--r--tensorflow/compiler/xla/service/executable.h32
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/g3doc/hlo_parser.md (renamed from tensorflow/compiler/xla/tools/parser/README.md)0
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.cc7
-rw-r--r--tensorflow/compiler/xla/service/gather_expander_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc52
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h16
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD84
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.cc21
-rw-r--r--tensorflow/compiler/xla/service/gpu/conditional_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc44
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc14
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h10
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc36
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/fft_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc18
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc18
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc97
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc78
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc88
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc123
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h110
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h6
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.cc45
-rw-r--r--tensorflow/compiler/xla/service/gpu/infeed_thunk.h15
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc26
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc71
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc53
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h1
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc1722
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h96
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/memset_thunk.cc10
-rw-r--r--tensorflow/compiler/xla/service/gpu/memset_thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc263
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.h56
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc353
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_manager.cc51
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_manager.h92
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc111
-rw-r--r--tensorflow/compiler/xla/service/gpu/outfeed_thunk.h52
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc27
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.h1
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/graphviz_example.cc5
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc86
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h43
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc114
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto5
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc81
-rw-r--r--tensorflow/compiler/xla/service/hlo_casting_utils.h104
-rw-r--r--tensorflow/compiler/xla/service/hlo_casting_utils_test.cc113
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc257
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h24
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc113
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc175
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc101
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc183
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h32
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc364
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce_test.cc37
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.cc44
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc263
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.cc124
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.h65
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc193
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc400
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h246
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc135
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc2375
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h853
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc181
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc1917
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h1153
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc)26
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_lexer.h)17
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc46
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h38
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h49
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser.cc)459
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser.h)24
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc)251
-rw-r--r--tensorflow/compiler/xla/service/hlo_query.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_query.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc49
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization_test.cc130
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc96
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h25
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc213
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc157
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.h91
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_metadata.cc104
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_token.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_token.h)11
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc152
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/human_readable_profile_builder.cc17
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc451
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h56
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc478
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc28
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc31
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc38
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc10
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executor.cc2
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc112
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h60
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc105
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD39
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc28
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h3
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc83
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc23
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h7
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc77
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.h60
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc51
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h198
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc118
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h80
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc45
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h83
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc149
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc20
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h9
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ops.cc4
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc16
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc21
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.cc338
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h169
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc17
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h36
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer_test.cc29
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h80
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc24
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc13
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/service.cc155
-rw-r--r--tensorflow/compiler/xla/service/service.h15
-rw-r--r--tensorflow/compiler/xla/service/session.proto85
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc544
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h29
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc152
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc141
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h73
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc1
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc74
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc122
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.cc31
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.h9
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier_test.cc79
-rw-r--r--tensorflow/compiler/xla/service/tuple_util_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/versioned_computation_handle.h55
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc45
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/while_util.cc9
-rw-r--r--tensorflow/compiler/xla/service/while_util_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc7
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc13
-rw-r--r--tensorflow/compiler/xla/shape_tree.h40
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc5
-rw-r--r--tensorflow/compiler/xla/shape_util.cc497
-rw-r--r--tensorflow/compiler/xla/shape_util.h120
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc197
-rw-r--r--tensorflow/compiler/xla/tests/BUILD103
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc1476
-rw-r--r--tensorflow/compiler/xla/tests/axpy_simple_test.cc32
-rw-r--r--tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc279
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc67
-rw-r--r--tensorflow/compiler/xla/tests/binop_scaling_test.cc41
-rw-r--r--tensorflow/compiler/xla/tests/bitcast_convert_test.cc52
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc324
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc79
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc64
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc22
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc44
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h66
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc22
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc48
-rw-r--r--tensorflow/compiler/xla/tests/concat_test.cc257
-rw-r--r--tensorflow/compiler/xla/tests/conditional_test.cc410
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc53
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc162
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc171
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc494
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/cross_replica_sum_test.cc54
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/deallocation_test.cc23
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc40
-rw-r--r--tensorflow/compiler/xla/tests/deep_graph_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc424
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc92
-rw-r--r--tensorflow/compiler/xla/tests/execution_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/filecheck.cc5
-rw-r--r--tensorflow/compiler/xla/tests/floor_ceil_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/fmax_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc198
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc213
-rw-r--r--tensorflow/compiler/xla/tests/half_test.cc91
-rw-r--r--tensorflow/compiler/xla/tests/hlo_metadata_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc15
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h37
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc28
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h19
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h31
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc46
-rw-r--r--tensorflow/compiler/xla/tests/llvm_compiler_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc27
-rw-r--r--tensorflow/compiler/xla/tests/llvm_irgen_test_base.h8
-rw-r--r--tensorflow/compiler/xla/tests/local_client_allocation_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc10
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc433
-rw-r--r--tensorflow/compiler/xla/tests/log_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc257
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc86
-rw-r--r--tensorflow/compiler/xla/tests/multidimensional_slice_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/multioutput_fusion_test.cc275
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc67
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc168
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc71
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc40
-rw-r--r--tensorflow/compiler/xla/tests/query_inferred_shape_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/reduce_hlo_test.cc34
-rw-r--r--tensorflow/compiler/xla/tests/reduce_precision_test.cc49
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc224
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc195
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc24
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc404
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc50
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc363
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc193
-rw-r--r--tensorflow/compiler/xla/tests/select_test.cc138
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc67
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc30
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h2
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc29
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc206
-rw-r--r--tensorflow/compiler/xla/tests/transfer_manager_test.cc316
-rw-r--r--tensorflow/compiler/xla/tests/transpose_test.cc57
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc333
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc130
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc85
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc214
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc714
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc105
-rw-r--r--tensorflow/compiler/xla/tests/xla_internal_test_main.cc1
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.cc2
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.h2
-rw-r--r--tensorflow/compiler/xla/text_literal_reader_test.cc2
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.cc2
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.h2
-rw-r--r--tensorflow/compiler/xla/text_literal_writer_test.cc6
-rw-r--r--tensorflow/compiler/xla/tools/BUILD9
-rw-r--r--tensorflow/compiler/xla/tools/convert_computation.cc4
-rw-r--r--tensorflow/compiler/xla/tools/parser/BUILD73
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc62
-rw-r--r--tensorflow/compiler/xla/tools/show_literal.cc2
-rw-r--r--tensorflow/compiler/xla/tools/show_text_literal.cc2
-rw-r--r--tensorflow/compiler/xla/util.h30
-rw-r--r--tensorflow/compiler/xla/xla.proto103
-rw-r--r--tensorflow/compiler/xla/xla_data.proto528
-rw-r--r--tensorflow/contrib/BUILD32
-rw-r--r--tensorflow/contrib/__init__.py3
-rw-r--r--tensorflow/contrib/all_reduce/BUILD10
-rw-r--r--tensorflow/contrib/all_reduce/__init__.py39
-rw-r--r--tensorflow/contrib/android/BUILD2
-rw-r--r--tensorflow/contrib/autograph/BUILD4
-rw-r--r--tensorflow/contrib/autograph/CONTRIBUTING.md49
-rw-r--r--tensorflow/contrib/autograph/LIMITATIONS.md50
-rw-r--r--tensorflow/contrib/autograph/README.md14
-rw-r--r--tensorflow/contrib/autograph/STYLE_GUIDE.md16
-rw-r--r--tensorflow/contrib/autograph/__init__.py8
-rw-r--r--tensorflow/contrib/autograph/converters/BUILD78
-rw-r--r--tensorflow/contrib/autograph/converters/asserts.py8
-rw-r--r--tensorflow/contrib/autograph/converters/asserts_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py12
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions.py10
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/call_trees.py53
-rw-r--r--tensorflow/contrib/autograph/converters/call_trees_test.py30
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements.py10
-rw-r--r--tensorflow/contrib/autograph/converters/continue_statements_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py37
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py12
-rw-r--r--tensorflow/contrib/autograph/converters/decorators.py75
-rw-r--r--tensorflow/contrib/autograph/converters/decorators_test.py72
-rw-r--r--tensorflow/contrib/autograph/converters/ifexp.py12
-rw-r--r--tensorflow/contrib/autograph/converters/ifexp_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/list_comprehension.py11
-rw-r--r--tensorflow/contrib/autograph/converters/list_comprehension_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/lists.py241
-rw-r--r--tensorflow/contrib/autograph/converters/lists_test.py134
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions.py12
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/name_scopes.py8
-rw-r--r--tensorflow/contrib/autograph/converters/name_scopes_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/side_effect_guards.py17
-rw-r--r--tensorflow/contrib/autograph/converters/side_effect_guards_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/single_return.py28
-rw-r--r--tensorflow/contrib/autograph/converters/single_return_test.py4
-rw-r--r--tensorflow/contrib/autograph/converters/slices.py83
-rw-r--r--tensorflow/contrib/autograph/converters/slices_test.py59
-rw-r--r--tensorflow/contrib/autograph/core/BUILD59
-rw-r--r--tensorflow/contrib/autograph/core/annos.py39
-rw-r--r--tensorflow/contrib/autograph/core/config.py (renamed from tensorflow/contrib/autograph/impl/config.py)0
-rw-r--r--tensorflow/contrib/autograph/core/converter.py210
-rw-r--r--tensorflow/contrib/autograph/core/converter_testing.py (renamed from tensorflow/contrib/autograph/converters/converter_test_base.py)40
-rw-r--r--tensorflow/contrib/autograph/core/naming.py (renamed from tensorflow/contrib/autograph/impl/naming.py)0
-rw-r--r--tensorflow/contrib/autograph/core/naming_test.py (renamed from tensorflow/contrib/autograph/impl/naming_test.py)2
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/BUILD29
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/keras_test.py37
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb666
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb11
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb311
-rw-r--r--tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb1093
-rw-r--r--tensorflow/contrib/autograph/impl/BUILD15
-rw-r--r--tensorflow/contrib/autograph/impl/api.py38
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py19
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py201
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py78
-rw-r--r--tensorflow/contrib/autograph/lang/BUILD40
-rw-r--r--tensorflow/contrib/autograph/lang/directives.py68
-rw-r--r--tensorflow/contrib/autograph/lang/special_functions.py59
-rw-r--r--tensorflow/contrib/autograph/lang/special_functions_test.py54
-rw-r--r--tensorflow/contrib/autograph/operators/BUILD8
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py2
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py2
-rw-r--r--tensorflow/contrib/autograph/pyct/BUILD17
-rw-r--r--tensorflow/contrib/autograph/pyct/anno.py21
-rw-r--r--tensorflow/contrib/autograph/pyct/anno_test.py4
-rw-r--r--tensorflow/contrib/autograph/pyct/cfg.py733
-rw-r--r--tensorflow/contrib/autograph/pyct/cfg_test.py790
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/BUILD38
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf.py57
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py53
-rw-r--r--tensorflow/contrib/autograph/pyct/context.py49
-rw-r--r--tensorflow/contrib/autograph/pyct/origin_info.py35
-rw-r--r--tensorflow/contrib/autograph/pyct/qual_names.py7
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/BUILD1
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/__init__.py12
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py12
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg.py27
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py29
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/live_values.py10
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py17
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info.py87
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py74
-rw-r--r--tensorflow/contrib/autograph/pyct/templates.py15
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer.py61
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer_test.py17
-rw-r--r--tensorflow/contrib/autograph/utils/BUILD2
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py23
-rw-r--r--tensorflow/contrib/autograph/utils/builtins_test.py17
-rw-r--r--tensorflow/contrib/batching/BUILD8
-rw-r--r--tensorflow/contrib/batching/__init__.py1
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops.py63
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py161
-rw-r--r--tensorflow/contrib/batching/serial_device_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py45
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo.py5
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py2
-rw-r--r--tensorflow/contrib/bigtable/BUILD213
-rw-r--r--tensorflow/contrib/bigtable/README.md10
-rw-r--r--tensorflow/contrib/bigtable/__init__.py39
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc355
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lib.cc45
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lib.h142
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc221
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc104
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc68
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h67
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc107
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc112
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc200
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc113
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc219
-rw-r--r--tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc374
-rw-r--r--tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h87
-rw-r--r--tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc78
-rw-r--r--tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc345
-rw-r--r--tensorflow/contrib/bigtable/ops/bigtable_ops.cc107
-rw-r--r--tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc27
-rw-r--r--tensorflow/contrib/bigtable/python/kernel_tests/__init__.py20
-rw-r--r--tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py272
-rw-r--r--tensorflow/contrib/bigtable/python/ops/__init__.py20
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py741
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD15
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py75
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py72
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py24
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py125
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py69
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py191
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc54
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py4
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py12
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py135
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py62
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py271
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc14
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h7
-rw-r--r--tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc48
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h2
-rw-r--r--tensorflow/contrib/boosted_trees/ops/prediction_ops.cc70
-rw-r--r--tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py1
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py610
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py413
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses.py67
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py11
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py25
-rw-r--r--tensorflow/contrib/checkpoint/python/split_dependency_test.py21
-rw-r--r--tensorflow/contrib/cloud/BUILD12
-rw-r--r--tensorflow/contrib/cloud/README.md18
-rw-r--r--tensorflow/contrib/cloud/__init__.py13
-rw-r--r--tensorflow/contrib/cloud/ops/gcs_config_ops.cc42
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops.py1
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py44
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py27
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py56
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt62
-rw-r--r--tensorflow/contrib/cmake/external/boringssl.cmake2
-rw-r--r--tensorflow/contrib/cmake/external/double_conversion.cmake6
-rw-r--r--tensorflow/contrib/cmake/external/grpc.cmake17
-rw-r--r--tensorflow/contrib/cmake/external/mkl.cmake68
-rw-r--r--tensorflow/contrib/cmake/external/mkldnn.cmake12
-rw-r--r--tensorflow/contrib/cmake/external/nsync.cmake2
-rw-r--r--tensorflow/contrib/cmake/external/protobuf.cmake2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt7
-rw-r--r--tensorflow/contrib/cmake/tf_c.cmake13
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake10
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake13
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake119
-rw-r--r--tensorflow/contrib/cmake/tf_shared_lib.cmake5
-rw-r--r--tensorflow/contrib/cmake/tf_stream_executor.cmake2
-rw-r--r--tensorflow/contrib/coder/python/layers/entropybottleneck.py8
-rw-r--r--tensorflow/contrib/constrained_optimization/README.md2
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py8
-rw-r--r--tensorflow/contrib/copy_graph/python/util/copy_elements.py1
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py6
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py2
-rw-r--r--tensorflow/contrib/data/__init__.py17
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc518
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc26
-rw-r--r--tensorflow/contrib/data/kernels/threadpool_dataset_op.cc27
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD382
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py268
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py118
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py319
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py62
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py24
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py128
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/ops/iterator_ops_test.py)0
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py357
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py87
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py558
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py91
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py482
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py331
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py5
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py16
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD526
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py83
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py190
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py)4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py95
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py)12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py)6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py)4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py61
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py57
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py46
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py86
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py88
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py140
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py39
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py101
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py139
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py)4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py118
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py46
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py)16
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py)4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py39
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py148
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py53
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py95
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py53
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py99
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py51
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py)4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py112
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py42
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py96
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py96
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py99
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py59
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py523
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD42
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py311
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py13
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py5
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py505
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py12
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py75
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py203
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py7
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py7
-rw-r--r--tensorflow/contrib/data/python/ops/resampling.py2
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py141
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py7
-rw-r--r--tensorflow/contrib/data/python/ops/sliding.py8
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py44
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py19
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py13
-rw-r--r--tensorflow/contrib/distribute/BUILD1
-rw-r--r--tensorflow/contrib/distribute/__init__.py2
-rw-r--r--tensorflow/contrib/distribute/python/BUILD26
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py57
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py306
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py242
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py145
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils_test.py18
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py438
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py28
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py160
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py566
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_strategy.py4
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py21
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py3
-rw-r--r--tensorflow/contrib/distribute/python/shared_variable_creator_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py17
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py136
-rw-r--r--tensorflow/contrib/distribute/python/values.py333
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py55
-rw-r--r--tensorflow/contrib/distributions/BUILD71
-rw-r--r--tensorflow/contrib/distributions/__init__.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py69
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py28
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py66
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py48
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/BUILD51
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py323
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py150
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py25
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/__init__.py6
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/chain.py25
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py24
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/exp.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py165
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/inline.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/invert.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py49
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/ordered.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/permute.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/reshape.py25
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py123
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/softplus.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/softsign.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/square.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py111
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/weibull.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/binomial.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/chi2.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py25
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution_util.py79
-rw-r--r--tensorflow/contrib/distributions/python/ops/estimator.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/geometric.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/kumaraswamy.py19
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/negative_binomial.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/onehot_categorical.py19
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py33
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py11
-rw-r--r--tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py19
-rw-r--r--tensorflow/contrib/distributions/python/ops/shape.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py97
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py25
-rw-r--r--tensorflow/contrib/eager/python/datasets.py1
-rw-r--r--tensorflow/contrib/eager/python/examples/BUILD4
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/BUILD29
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/densenet.py274
-rw-r--r--tensorflow/contrib/eager/python/examples/densenet/densenet_test.py83
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb1184
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb689
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py98
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py202
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/BUILD1
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/BUILD115
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/README.md45
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks.py357
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/blocks_test.py304
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_input.py116
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py154
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/config.py140
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/main.py256
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/ops.py70
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/ops_test.py80
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet.py301
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py332
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/BUILD59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/config.py72
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops.py71
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/ops_test.py59
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan.py232
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan_test.py101
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb282
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb1018
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb443
-rw-r--r--tensorflow/contrib/eager/python/metrics.py3
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py150
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py37
-rw-r--r--tensorflow/contrib/eager/python/network_test.py42
-rw-r--r--tensorflow/contrib/eager/python/tfe.py4
-rw-r--r--tensorflow/contrib/estimator/BUILD32
-rw-r--r--tensorflow/contrib/estimator/__init__.py7
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py8
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py46
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py78
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn.py40
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py30
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_test.py17
-rw-r--r--tensorflow/contrib/estimator/python/estimator/early_stopping.py468
-rw-r--r--tensorflow/contrib/estimator/python/estimator/early_stopping_test.py233
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py5
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py56
-rw-r--r--tensorflow/contrib/estimator/python/estimator/linear.py28
-rw-r--r--tensorflow/contrib/factorization/kernels/wals_solver_ops.cc44
-rw-r--r--tensorflow/contrib/factorization/ops/factorization_ops.cc19
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py36
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops.py1
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py24
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py320
-rw-r--r--tensorflow/contrib/ffmpeg/__init__.py1
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py1
-rw-r--r--tensorflow/contrib/framework/__init__.py3
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_test.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py10
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py120
-rw-r--r--tensorflow/contrib/fused_conv/BUILD2
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py11
-rw-r--r--tensorflow/contrib/gan/BUILD15
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py200
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py227
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py27
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py11
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc12
-rw-r--r--tensorflow/contrib/gdr/gdr_server_lib.cc2
-rw-r--r--tensorflow/contrib/graph_editor/transform.py5
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h30
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py20
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py3
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes.py126
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py51
-rw-r--r--tensorflow/contrib/kafka/ops/kafka_ops.cc44
-rw-r--r--tensorflow/contrib/keras/api/keras/layers/__init__.py8
-rw-r--r--tensorflow/contrib/kfac/README.md5
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py6
-rw-r--r--tensorflow/contrib/kinesis/BUILD113
-rw-r--r--tensorflow/contrib/kinesis/__init__.py (renamed from tensorflow/python/training/checkpointable/data_structures_base.py)17
-rw-r--r--tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc359
-rw-r--r--tensorflow/contrib/kinesis/ops/dataset_ops.cc42
-rw-r--r--tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py139
-rw-r--r--tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py96
-rw-r--r--tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py24
-rw-r--r--tensorflow/contrib/labeled_tensor/python/ops/ops.py3
-rw-r--r--tensorflow/contrib/layers/__init__.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py1
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops.py11
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py23
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py26
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py20
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py2
-rw-r--r--tensorflow/contrib/linear_optimizer/BUILD1
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py47
-rw-r--r--tensorflow/contrib/lite/BUILD12
-rw-r--r--tensorflow/contrib/lite/Makefile126
-rw-r--r--tensorflow/contrib/lite/allocation.cc6
-rw-r--r--tensorflow/contrib/lite/arena_planner.cc69
-rw-r--r--tensorflow/contrib/lite/arena_planner.h9
-rw-r--r--tensorflow/contrib/lite/arena_planner_test.cc49
-rw-r--r--tensorflow/contrib/lite/build_def.bzl16
-rwxr-xr-xtensorflow/contrib/lite/build_ios_universal_lib.sh37
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h35
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h14
-rw-r--r--tensorflow/contrib/lite/context.c3
-rw-r--r--tensorflow/contrib/lite/context.h56
-rw-r--r--tensorflow/contrib/lite/context_util.h48
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/BUILD31
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc694
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h (renamed from tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc)31
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc688
-rwxr-xr-xtensorflow/contrib/lite/download_dependencies.sh4
-rw-r--r--tensorflow/contrib/lite/examples/android/BUILD45
-rw-r--r--tensorflow/contrib/lite/examples/android/android.iml19
-rw-r--r--tensorflow/contrib/lite/examples/android/app/build.gradle60
-rw-r--r--tensorflow/contrib/lite/examples/android/app/download-models.gradle74
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml (renamed from tensorflow/contrib/lite/examples/android/AndroidManifest.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD (renamed from tensorflow/contrib/lite/examples/android/assets/BUILD)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt (renamed from tensorflow/contrib/lite/examples/android/assets/box_priors.txt)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt (renamed from tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt (renamed from tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt (renamed from tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt38
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java)13
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java)220
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java (renamed from tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml (renamed from tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png)bin1025 -> 1025 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png)bin4312 -> 4312 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png)bin196 -> 196 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png)bin665 -> 665 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png)bin2265 -> 2265 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png)bin1355 -> 1355 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png)bin6683 -> 6683 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png)bin2265 -> 2265 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png (renamed from tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png)bin12746 -> 12746 bytes
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml (renamed from tensorflow/contrib/lite/examples/android/res/drawable/border.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml (renamed from tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/attrs.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/base-strings.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/colors.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/strings.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml (renamed from tensorflow/contrib/lite/examples/android/res/values/template-styles.xml)0
-rw-r--r--tensorflow/contrib/lite/examples/android/build.gradle55
-rw-r--r--tensorflow/contrib/lite/examples/android/settings.gradle1
-rw-r--r--tensorflow/contrib/lite/examples/label_image/BUILD31
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc28
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h4
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image.cc12
-rw-r--r--tensorflow/contrib/lite/examples/label_image/label_image_test.cc16
-rw-r--r--tensorflow/contrib/lite/examples/minimal/BUILD27
-rw-r--r--tensorflow/contrib/lite/examples/minimal/minimal.cc26
-rw-r--r--tensorflow/contrib/lite/g3doc/apis.md3
-rw-r--r--tensorflow/contrib/lite/g3doc/benchmarks.md178
-rw-r--r--tensorflow/contrib/lite/g3doc/ops_versioning.md206
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md114
-rw-r--r--tensorflow/contrib/lite/graph_info.h3
-rw-r--r--tensorflow/contrib/lite/graph_info_test.cc2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc212
-rw-r--r--tensorflow/contrib/lite/interpreter.h61
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc265
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl4
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md9
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle45
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD2
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD2
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD2
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle4
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java8
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java9
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java60
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java243
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java198
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc339
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h93
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc159
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h61
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java33
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java256
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java152
-rw-r--r--tensorflow/contrib/lite/java/src/testdata/quantized.binbin0 -> 432 bytes
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD2
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java4
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD113
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc93
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc147
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc250
-rw-r--r--tensorflow/contrib/lite/kernels/add_test.cc96
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc (renamed from tensorflow/contrib/lite/kernels/arg_max.cc)69
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max_test.cc (renamed from tensorflow/contrib/lite/kernels/arg_max_test.cc)89
-rw-r--r--tensorflow/contrib/lite/kernels/cast.cc23
-rw-r--r--tensorflow/contrib/lite/kernels/cast_test.cc67
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc66
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc333
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc67
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc591
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc235
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc81
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h8
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc43
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc36
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc57
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_test.cc110
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims.cc113
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims_test.cc83
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc81
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant_test.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc129
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected_test.cc242
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc55
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h3
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD70
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc264
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h83
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc67
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h923
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h361
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h63
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h1525
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.cc38
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.h22
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h369
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h867
-rw-r--r--tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc (renamed from tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc)60
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc62
-rw-r--r--tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h16
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h38
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h257
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc56
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h33
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/log_softmax_test.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc791
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc1769
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc121
-rw-r--r--tensorflow/contrib/lite/kernels/mul_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/neg_test.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc137
-rw-r--r--tensorflow/contrib/lite/kernels/pow.cc143
-rw-r--r--tensorflow/contrib/lite/kernels/pow_test.cc117
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc (renamed from tensorflow/contrib/lite/kernels/mean.cc)109
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc (renamed from tensorflow/contrib/lite/kernels/mean_test.cc)178
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc34
-rw-r--r--tensorflow/contrib/lite/kernels/register.h2
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc23
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear_test.cc235
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc3
-rw-r--r--tensorflow/contrib/lite/kernels/select_test.cc37
-rw-r--r--tensorflow/contrib/lite/kernels/shape.cc93
-rw-r--r--tensorflow/contrib/lite/kernels/shape_test.cc95
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice_test.cc65
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc21
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc320
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc186
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h50
-rw-r--r--tensorflow/contrib/lite/kernels/test_util_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/tile.cc194
-rw-r--r--tensorflow/contrib/lite/kernels/tile_test.cc256
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc109
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv_test.cc133
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc468
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc1767
-rw-r--r--tensorflow/contrib/lite/model.cc185
-rw-r--r--tensorflow/contrib/lite/model.h1
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD2
-rw-r--r--tensorflow/contrib/lite/models/smartreply/predictor_test.cc8
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc309
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h6
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc17
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.h3
-rw-r--r--tensorflow/contrib/lite/profiling/BUILD7
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.cc33
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer.h3
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer_test.cc50
-rw-r--r--tensorflow/contrib/lite/python/BUILD9
-rw-r--r--tensorflow/contrib/lite/python/convert.py123
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py31
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_test.py1
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py112
-rw-r--r--tensorflow/contrib/lite/python/interpreter_test.py79
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc341
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h27
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i43
-rw-r--r--tensorflow/contrib/lite/python/lite.py98
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py372
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py124
-rw-r--r--tensorflow/contrib/lite/schema/BUILD1
-rw-r--r--tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs91
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h1279
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.cc16
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h3
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena_test.cc41
-rw-r--r--tensorflow/contrib/lite/string_util.cc2
-rw-r--r--tensorflow/contrib/lite/testing/BUILD16
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py450
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.cc85
-rw-r--r--tensorflow/contrib/lite/testing/generate_testspec.h4
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc65
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_example_test.cc23
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_flags.h6
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.cc7
-rw-r--r--tensorflow/contrib/lite/testing/tflite_diff_util.h6
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc27
-rw-r--r--tensorflow/contrib/lite/toco/BUILD46
-rw-r--r--tensorflow/contrib/lite/toco/README.md13
-rw-r--r--tensorflow/contrib/lite/toco/args.h14
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc20
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc361
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md410
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md292
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md79
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc15
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc94
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc60
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc23
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc102
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc34
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc144
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc18
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc25
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc178
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc33
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc14
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc53
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc197
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc108
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc24
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc5
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc23
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc36
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc42
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc4
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc97
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc (renamed from tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc)6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD20
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc1436
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h204
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc10
-rw-r--r--tensorflow/contrib/lite/toco/python/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_python_api.cc13
-rw-r--r--tensorflow/contrib/lite/toco/runtime/types.h1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc96
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/tflite/import.cc30
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc284
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h11
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc69
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.cc16
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types_test.cc19
-rw-r--r--tensorflow/contrib/lite/toco/toco.cc36
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc25
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto20
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.cc75
-rw-r--r--tensorflow/contrib/lite/toco/toco_port.h51
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model.cc189
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model.h53
-rw-r--r--tensorflow/contrib/lite/toco/toco_saved_model_test.cc274
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc22
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc162
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h26
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc17
-rw-r--r--tensorflow/contrib/lite/tools/BUILD83
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD100
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/README.md209
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc (renamed from tensorflow/contrib/lite/tools/benchmark_main.cc)4
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc (renamed from tensorflow/contrib/lite/tools/benchmark_model.cc)59
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_model.h (renamed from tensorflow/contrib/lite/tools/benchmark_model.h)27
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc57
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_params.h101
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc (renamed from tensorflow/contrib/lite/tools/benchmark_tflite_model.cc)108
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h (renamed from tensorflow/contrib/lite/tools/benchmark_tflite_model.h)19
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc (renamed from tensorflow/contrib/lite/tools/command_line_flags.cc)97
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags.h (renamed from tensorflow/contrib/lite/tools/command_line_flags.h)29
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc (renamed from tensorflow/contrib/lite/tools/command_line_flags_test.cc)56
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/README.md43
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj381
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h22
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m27
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json98
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json6
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard25
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard60
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h21
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm125
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist43
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json10
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m23
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/logging.h (renamed from tensorflow/contrib/lite/tools/logging.h)3
-rw-r--r--tensorflow/contrib/lite/tools/verifier_test.cc6
-rw-r--r--tensorflow/contrib/lite/tools/visualize.py17
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py22
-rwxr-xr-xtensorflow/contrib/makefile/build_all_android.sh8
-rwxr-xr-xtensorflow/contrib/makefile/build_all_ios.sh8
-rwxr-xr-xtensorflow/contrib/makefile/download_dependencies.sh4
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/contrib/metrics/BUILD27
-rw-r--r--tensorflow/contrib/metrics/__init__.py1
-rw-r--r--tensorflow/contrib/metrics/python/metrics/classification.py121
-rw-r--r--tensorflow/contrib/metrics/python/metrics/classification_test.py202
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py117
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py66
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py515
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py22
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py44
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py12
-rw-r--r--tensorflow/contrib/mpi_collectives/BUILD1
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc2
-rw-r--r--tensorflow/contrib/mpi_collectives/mpi_ops.py163
-rw-r--r--tensorflow/contrib/mpi_collectives/ring.cc80
-rw-r--r--tensorflow/contrib/mpi_collectives/ring.cu.cc117
-rw-r--r--tensorflow/contrib/mpi_collectives/ring.h327
-rw-r--r--tensorflow/contrib/nccl/BUILD45
-rw-r--r--tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py59
-rw-r--r--tensorflow/contrib/nccl/python/ops/nccl_ops.py45
-rw-r--r--tensorflow/contrib/opt/BUILD42
-rw-r--r--tensorflow/contrib/opt/__init__.py14
-rw-r--r--tensorflow/contrib/opt/python/training/addsign_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/ggt.py312
-rw-r--r--tensorflow/contrib/opt/python/training/ggt_test.py183
-rw-r--r--tensorflow/contrib/opt/python/training/powersign_test.py2
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers.py362
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py188
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py4
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py113
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py24
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2_test.py14
-rw-r--r--tensorflow/contrib/periodic_resample/BUILD19
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc5
-rw-r--r--tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h415
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops.cc53
-rw-r--r--tensorflow/contrib/periodic_resample/ops/array_ops_test.cc41
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py27
-rw-r--r--tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py8
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py24
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py19
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py6
-rw-r--r--tensorflow/contrib/proto/BUILD4
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/BUILD81
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl89
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py4
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py42
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt94
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py17
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt161
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt16
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt20
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt29
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt32
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt62
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt21
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/test_base.py407
-rw-r--r--tensorflow/contrib/quantize/README.md2
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py72
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms_test.py84
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py39
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph.py4
-rw-r--r--tensorflow/contrib/quantize/python/quantize_parameterized_test.py177
-rw-r--r--tensorflow/contrib/receptive_field/README.md32
-rw-r--r--tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md629
-rw-r--r--tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py82
-rw-r--r--tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py2
-rw-r--r--tensorflow/contrib/recurrent/BUILD5
-rw-r--r--tensorflow/contrib/rnn/BUILD1
-rw-r--r--tensorflow/contrib/rnn/__init__.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py163
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py6
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py342
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py2
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/decoder.py29
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py2
-rw-r--r--tensorflow/contrib/signal/python/kernel_tests/test_util.py1
-rw-r--r--tensorflow/contrib/slim/python/slim/evaluation_test.py28
-rw-r--r--tensorflow/contrib/solvers/python/ops/linear_equations.py1
-rw-r--r--tensorflow/contrib/stat_summarizer/BUILD5
-rw-r--r--tensorflow/contrib/summary/summary_ops_test.py13
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics.py45
-rw-r--r--tensorflow/contrib/tensorboard/db/BUILD1
-rw-r--r--tensorflow/contrib/tensorrt/BUILD20
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc1030
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.h61
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc801
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h133
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc48
-rw-r--r--tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h3
-rw-r--r--tensorflow/contrib/tensorrt/convert/utils.h37
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc136
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_calib_op.h52
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc588
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h98
-rw-r--r--tensorflow/contrib/tensorrt/ops/trt_calib_op.cc37
-rw-r--r--tensorflow/contrib/tensorrt/ops/trt_engine_op.cc18
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py55
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.cc2
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.h5
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc59
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h43
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_resources.h49
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h7
-rw-r--r--tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc76
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py138
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py401
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i102
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py13
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head.py27
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py101
-rw-r--r--tensorflow/contrib/tpu/BUILD41
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc12
-rw-r--r--tensorflow/contrib/tpu/ops/replication_ops.cc21
-rw-r--r--tensorflow/contrib/tpu/profiler/BUILD2
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc2
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py67
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py4
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h2
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD15
-rw-r--r--tensorflow/contrib/tpu/proto/compilation_result.proto4
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py3
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py513
-rw-r--r--tensorflow/contrib/tpu/python/tpu/topology.py5
-rw-r--r--tensorflow/contrib/tpu/python/tpu/topology_test.py46
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py124
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py45
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config_test.py55
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py78
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py244
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py54
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_test.py4
-rw-r--r--tensorflow/contrib/training/BUILD2
-rw-r--r--tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py187
-rw-r--r--tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py145
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py7
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py2
-rw-r--r--tensorflow/contrib/verbs/BUILD4
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service.cc6
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.cc16
-rw-r--r--tensorflow/contrib/verbs/grpc_verbs_service_impl.h16
-rw-r--r--tensorflow/contrib/verbs/rdma.cc6
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc8
-rw-r--r--tensorflow/core/BUILD141
-rw-r--r--tensorflow/core/api_def/BUILD7
-rw-r--r--tensorflow/core/api_def/api_test.cc39
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt18
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt128
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt36
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GcsConfigureBlockCache.pbtxt9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GcsConfigureCredentials.pbtxt33
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MatrixLogarithm.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt62
-rw-r--r--tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt69
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdAdd.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdNonAliasingAdd.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdSub.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Selu.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SlideDataset.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Softmax.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyAdagrad.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyCenteredRMSProp.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyFtrl.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyMomentum.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyProximalAdagrad.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyProximalGradientDescent.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseApplyRMSProp.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseMatMul.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSliceGrad.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StatefulPartitionedCall.pbtxt25
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt48
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt11
-rw-r--r--tensorflow/core/api_def/excluded_ops.cc3
-rw-r--r--tensorflow/core/api_def/java_api/api_def_Assert.pbtxt4
-rw-r--r--tensorflow/core/api_def/java_api/api_def_Const.pbtxt4
-rw-r--r--tensorflow/core/api_def/java_api/api_def_Switch.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Acos.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Add.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_AsString.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Asin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BesselI0e.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BesselI1e.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cos.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Cross.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Diag.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Equal.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Exp.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FFT.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeParam.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Floor.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Greater.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt5
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Less.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Log.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Qr.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Rint.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Sin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt3
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt3
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SparseSliceGrad.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatefulPartitionedCall.pbtxt1
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tan.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt10
-rw-r--r--tensorflow/core/common_runtime/base_collective_executor.h8
-rw-r--r--tensorflow/core/common_runtime/broadcaster.cc4
-rw-r--r--tensorflow/core/common_runtime/broadcaster_test.cc4
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.cc3
-rw-r--r--tensorflow/core/common_runtime/build_graph_options.h3
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.cc18
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h9
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr_test.cc11
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc68
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h26
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local_test.cc72
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.cc14
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.h2
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local_test.cc2
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc19
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.h15
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc21
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc339
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc21
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD144
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc51
-rw-r--r--tensorflow/core/common_runtime/eager/context.h24
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc316
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc1
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h6
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc44
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h31
-rw-r--r--tensorflow/core/common_runtime/executor.cc27
-rw-r--r--tensorflow/core/common_runtime/executor_factory.cc85
-rw-r--r--tensorflow/core/common_runtime/executor_factory.h51
-rw-r--r--tensorflow/core/common_runtime/executor_test.cc10
-rw-r--r--tensorflow/core/common_runtime/function.cc73
-rw-r--r--tensorflow/core/common_runtime/function.h6
-rw-r--r--tensorflow/core/common_runtime/function_test.cc72
-rw-r--r--tensorflow/core/common_runtime/gpu/cuda_host_allocator.h60
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc34
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h3
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_factory.cc7
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_test.cc4
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.cc (renamed from tensorflow/core/common_runtime/gpu/process_state.cc)170
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.h (renamed from tensorflow/core/common_runtime/gpu/process_state.h)90
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.cc18
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc6
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h15
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc185
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.cc18
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h4
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc5
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.cc7
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/placer.cc55
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc40
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc (renamed from tensorflow/core/common_runtime/gpu/pool_allocator.cc)10
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.h (renamed from tensorflow/core/common_runtime/gpu/pool_allocator.h)53
-rw-r--r--tensorflow/core/common_runtime/process_state.cc129
-rw-r--r--tensorflow/core/common_runtime/process_state.h132
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.cc2
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc19
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc4
-rw-r--r--tensorflow/core/common_runtime/test_collective_executor_mgr.h3
-rw-r--r--tensorflow/core/debug/BUILD6
-rw-r--r--tensorflow/core/debug/debug_grpc_testlib.h2
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc2
-rw-r--r--tensorflow/core/distributed_runtime/BUILD56
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc2
-rw-r--r--tensorflow/core/distributed_runtime/cancellable_call.h65
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc75
-rw-r--r--tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc7
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.cc53
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed.h1
-rw-r--r--tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc6
-rw-r--r--tensorflow/core/distributed_runtime/eager/BUILD8
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.cc48
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl.h3
-rw-r--r--tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc61
-rw-r--r--tensorflow/core/distributed_runtime/eager/remote_execute_node.h34
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc26
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h8
-rw-r--r--tensorflow/core/distributed_runtime/local_master.h2
-rw-r--r--tensorflow/core/distributed_runtime/master.cc8
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h5
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc85
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h3
-rw-r--r--tensorflow/core/distributed_runtime/master_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/remote_device_test.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD43
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/BUILD22
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h97
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc11
-rw-r--r--tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_call.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc10
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc87
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h33
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc33
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_state.h4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc28
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc4
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc10
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_util.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc8
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc16
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h18
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc168
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h85
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc171
-rw-r--r--tensorflow/core/distributed_runtime/worker.cc10
-rw-r--r--tensorflow/core/framework/api_def.proto12
-rw-r--r--tensorflow/core/framework/collective.h1
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc56
-rw-r--r--tensorflow/core/framework/common_shape_fns.h17
-rw-r--r--tensorflow/core/framework/cost_graph.proto3
-rw-r--r--tensorflow/core/framework/dataset.h19
-rw-r--r--tensorflow/core/framework/device_base.cc33
-rw-r--r--tensorflow/core/framework/device_base.h15
-rw-r--r--tensorflow/core/framework/device_base_test.cc62
-rw-r--r--tensorflow/core/framework/function.cc4
-rw-r--r--tensorflow/core/framework/function.h6
-rw-r--r--tensorflow/core/framework/graph_to_functiondef.cc2
-rw-r--r--tensorflow/core/framework/kernel_def.proto5
-rw-r--r--tensorflow/core/framework/kernel_def_util.cc83
-rw-r--r--tensorflow/core/framework/kernel_def_util.h31
-rw-r--r--tensorflow/core/framework/kernel_def_util_test.cc133
-rw-r--r--tensorflow/core/framework/memory_types.cc11
-rw-r--r--tensorflow/core/framework/node_def_util.cc13
-rw-r--r--tensorflow/core/framework/node_def_util.h5
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc14
-rw-r--r--tensorflow/core/framework/op_kernel.cc79
-rw-r--r--tensorflow/core/framework/op_kernel.h10
-rw-r--r--tensorflow/core/framework/op_kernel_test.cc22
-rw-r--r--tensorflow/core/framework/resource_mgr.cc4
-rw-r--r--tensorflow/core/framework/resource_op_kernel.h25
-rw-r--r--tensorflow/core/framework/resource_var.h2
-rw-r--r--tensorflow/core/framework/shape_inference.cc14
-rw-r--r--tensorflow/core/framework/shape_inference.h2
-rw-r--r--tensorflow/core/framework/stats_aggregator.h4
-rw-r--r--tensorflow/core/framework/types.h4
-rw-r--r--tensorflow/core/graph/control_flow.cc69
-rw-r--r--tensorflow/core/graph/control_flow.h13
-rw-r--r--tensorflow/core/graph/control_flow_test.cc131
-rw-r--r--tensorflow/core/graph/gradients.cc11
-rw-r--r--tensorflow/core/graph/graph.cc23
-rw-r--r--tensorflow/core/graph/graph.h20
-rw-r--r--tensorflow/core/graph/graph_constructor.cc53
-rw-r--r--tensorflow/core/graph/graph_constructor.h14
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc48
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc21
-rw-r--r--tensorflow/core/graph/subgraph.cc12
-rw-r--r--tensorflow/core/graph/tensor_id.cc5
-rw-r--r--tensorflow/core/graph/tensor_id.h31
-rw-r--r--tensorflow/core/graph/validate.cc54
-rw-r--r--tensorflow/core/graph/validate.h9
-rw-r--r--tensorflow/core/grappler/clusters/BUILD4
-rw-r--r--tensorflow/core/grappler/clusters/cluster.h3
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc9
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.h3
-rw-r--r--tensorflow/core/grappler/clusters/single_machine_test.cc8
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.cc12
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.h5
-rw-r--r--tensorflow/core/grappler/costs/BUILD4
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc245
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h5
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc301
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt117
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt251
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt251
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt317
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt597
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc95
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.h10
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc119
-rw-r--r--tensorflow/core/grappler/op_types.cc17
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD9
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc1312
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h30
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc550
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc13
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc18
-rw-r--r--tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h2
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD104
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename.h46
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename_test.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc25
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h9
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc59
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h8
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc90
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.h48
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc217
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc104
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h46
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc149
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer.cc210
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer.h17
-rw-r--r--tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc130
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc9
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer_test.cc14
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage.h8
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc15
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc132
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h3
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc4
-rw-r--r--tensorflow/core/grappler/utils/scc.cc12
-rw-r--r--tensorflow/core/kernels/BUILD97
-rw-r--r--tensorflow/core/kernels/as_string_op.cc2
-rw-r--r--tensorflow/core/kernels/batch_kernels.cc397
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc2
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc2
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD3
-rw-r--r--tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc32
-rw-r--r--tensorflow/core/kernels/bias_op.cc195
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.cu.cc28
-rw-r--r--tensorflow/core/kernels/bias_op_gpu.h35
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD8
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto17
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc153
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc23
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.h6
-rw-r--r--tensorflow/core/kernels/boosted_trees/stats_ops.cc95
-rw-r--r--tensorflow/core/kernels/boosted_trees/training_ops.cc85
-rw-r--r--tensorflow/core/kernels/boosted_trees/tree_helper.h69
-rw-r--r--tensorflow/core/kernels/concat_op.cc2
-rw-r--r--tensorflow/core/kernels/constant_op.cc3
-rw-r--r--tensorflow/core/kernels/control_flow_ops.cc23
-rw-r--r--tensorflow/core/kernels/control_flow_ops.h16
-rw-r--r--tensorflow/core/kernels/conv_2d.h2
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc5
-rw-r--r--tensorflow/core/kernels/conv_ops_fused.cc12
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc8
-rw-r--r--tensorflow/core/kernels/conv_ops_test.cc180
-rw-r--r--tensorflow/core/kernels/cwise_op_bessel.cc29
-rw-r--r--tensorflow/core/kernels/cwise_op_bessel.cu.cc27
-rw-r--r--tensorflow/core/kernels/cwise_op_equal_to_1.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc2
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_greater.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_greater_equal.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_igammas.cc3
-rw-r--r--tensorflow/core/kernels/cwise_op_less.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_less_equal.cc7
-rw-r--r--tensorflow/core/kernels/cwise_op_not_equal_to_1.cc4
-rw-r--r--tensorflow/core/kernels/cwise_op_random_grad.cc25
-rw-r--r--tensorflow/core/kernels/cwise_ops.h10
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h3
-rw-r--r--tensorflow/core/kernels/data/BUILD52
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc46
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc545
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc12
-rw-r--r--tensorflow/core/kernels/data/dataset_ops.cc47
-rw-r--r--tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc4
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc3
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc89
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc279
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc232
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc49
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc310
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc72
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc217
-rw-r--r--tensorflow/core/kernels/data/slide_dataset_op.cc51
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc29
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc196
-rw-r--r--tensorflow/core/kernels/data/window_dataset.cc1
-rw-r--r--tensorflow/core/kernels/data/window_dataset.h2
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc196
-rw-r--r--tensorflow/core/kernels/deep_conv2d.cc10
-rw-r--r--tensorflow/core/kernels/dense_update_ops.cc2
-rw-r--r--tensorflow/core/kernels/deserialize_sparse_string_op.cc293
-rw-r--r--tensorflow/core/kernels/deserialize_sparse_variant_op.cc372
-rw-r--r--tensorflow/core/kernels/eigen_pooling.h13
-rw-r--r--tensorflow/core/kernels/fifo_queue.cc15
-rw-r--r--tensorflow/core/kernels/fifo_queue.h23
-rw-r--r--tensorflow/core/kernels/fifo_queue_op.cc39
-rw-r--r--tensorflow/core/kernels/function_ops.cc48
-rw-r--r--tensorflow/core/kernels/functional_ops.cc29
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h2
-rw-r--r--tensorflow/core/kernels/inplace_ops.cc2
-rw-r--r--tensorflow/core/kernels/matmul_op.cc32
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc17
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc17
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc667
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc478
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc341
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h222
-rw-r--r--tensorflow/core/kernels/mkl_fused_batch_norm_op.cc8
-rw-r--r--tensorflow/core/kernels/mkl_identity_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc16
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc20
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc7
-rw-r--r--tensorflow/core/kernels/mkl_reshape_op.cc15
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h6
-rw-r--r--tensorflow/core/kernels/mkl_transpose_op.cc2
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc182
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op_test.cc237
-rw-r--r--tensorflow/core/kernels/pad_op.cc4
-rw-r--r--tensorflow/core/kernels/pad_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc395
-rw-r--r--tensorflow/core/kernels/quantize_and_dequantize_op.h8
-rw-r--r--tensorflow/core/kernels/queue_op.cc367
-rw-r--r--tensorflow/core/kernels/queue_op.h233
-rw-r--r--tensorflow/core/kernels/queue_ops.cc395
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h12
-rw-r--r--tensorflow/core/kernels/reshape_util.cc19
-rw-r--r--tensorflow/core/kernels/resize_area_op_test.cc3
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc24
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h12
-rw-r--r--tensorflow/core/kernels/serialize_sparse_op.cc265
-rw-r--r--tensorflow/core/kernels/sparse_slice_grad_op.cc126
-rw-r--r--tensorflow/core/kernels/string_split_op.cc130
-rw-r--r--tensorflow/core/kernels/tensor_array.cc10
-rw-r--r--tensorflow/core/kernels/tensor_array.h4
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc47
-rw-r--r--tensorflow/core/kernels/transpose_op.cc2
-rw-r--r--tensorflow/core/kernels/transpose_op.h4
-rw-r--r--tensorflow/core/kernels/unary_ops_composition.cc432
-rw-r--r--tensorflow/core/kernels/unary_ops_composition_test.cc179
-rw-r--r--tensorflow/core/kernels/variable_ops.cc3
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h12
-rw-r--r--tensorflow/core/lib/db/sqlite_test.cc15
-rw-r--r--tensorflow/core/lib/gtl/manual_constructor_test.cc3
-rw-r--r--tensorflow/core/lib/io/random_inputstream.cc10
-rw-r--r--tensorflow/core/lib/strings/numbers.cc26
-rw-r--r--tensorflow/core/lib/strings/numbers.h4
-rw-r--r--tensorflow/core/lib/strings/numbers_test.cc10
-rw-r--r--tensorflow/core/ops/batch_ops.cc20
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc47
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt2058
-rw-r--r--tensorflow/core/ops/control_flow_ops.cc13
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc44
-rw-r--r--tensorflow/core/ops/dataset_ops.cc115
-rw-r--r--tensorflow/core/ops/functional_ops.cc45
-rw-r--r--tensorflow/core/ops/image_ops.cc36
-rw-r--r--tensorflow/core/ops/math_ops.cc32
-rw-r--r--tensorflow/core/ops/nn_ops.cc6
-rw-r--r--tensorflow/core/ops/ops.pbtxt737
-rw-r--r--tensorflow/core/ops/random_ops.cc7
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc17
-rw-r--r--tensorflow/core/ops/sparse_ops.cc14
-rw-r--r--tensorflow/core/ops/sparse_ops_test.cc12
-rw-r--r--tensorflow/core/ops/state_ops.cc9
-rw-r--r--tensorflow/core/ops/string_ops.cc20
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc10
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.cc4
-rw-r--r--tensorflow/core/platform/default/build_config.bzl16
-rw-r--r--tensorflow/core/platform/default/build_config/BUILD12
-rw-r--r--tensorflow/core/platform/env.h2
-rw-r--r--tensorflow/core/platform/fingerprint.h2
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc21
-rw-r--r--tensorflow/core/platform/numa.h62
-rw-r--r--tensorflow/core/platform/numa_test.cc61
-rw-r--r--tensorflow/core/platform/posix/port.cc24
-rw-r--r--tensorflow/core/platform/profile_utils/cpu_utils.cc37
-rw-r--r--tensorflow/core/platform/s3/BUILD14
-rw-r--r--tensorflow/core/platform/s3/aws_crypto.cc113
-rw-r--r--tensorflow/core/platform/s3/aws_crypto.h35
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc6
-rw-r--r--tensorflow/core/platform/vmodule_benchmark_test.cc (renamed from tensorflow/compiler/xla/service/versioned_computation_handle.cc)22
-rw-r--r--tensorflow/core/platform/vmodule_test.cc117
-rw-r--r--tensorflow/core/platform/windows/port.cc5
-rw-r--r--tensorflow/core/profiler/internal/tfprof_timeline.cc16
-rw-r--r--tensorflow/core/protobuf/config.proto68
-rw-r--r--tensorflow/core/protobuf/eager_service.proto12
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto6
-rw-r--r--tensorflow/core/util/device_name_utils.cc57
-rw-r--r--tensorflow/core/util/device_name_utils.h12
-rw-r--r--tensorflow/core/util/device_name_utils_test.cc47
-rw-r--r--tensorflow/core/util/exec_on_stall.h89
-rw-r--r--tensorflow/core/util/exec_on_stall_test.cc58
-rw-r--r--tensorflow/core/util/mkl_util.h257
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h1
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h7
-rw-r--r--tensorflow/core/util/stat_summarizer.cc8
-rw-r--r--tensorflow/core/util/stat_summarizer.h2
-rw-r--r--tensorflow/core/util/stats_calculator.cc27
-rw-r--r--tensorflow/core/util/stats_calculator.h3
-rw-r--r--tensorflow/core/util/status_util.h36
-rw-r--r--tensorflow/core/util/status_util_test.cc36
-rw-r--r--tensorflow/core/util/tensor_format.cc14
-rw-r--r--tensorflow/core/util/tensor_format.h47
-rw-r--r--tensorflow/core/util/tensor_format_test.cc25
-rw-r--r--tensorflow/core/util/work_sharder.cc19
-rw-r--r--tensorflow/core/util/work_sharder.h45
-rw-r--r--tensorflow/core/util/work_sharder_test.cc17
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md50
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md32
-rw-r--r--tensorflow/docs_src/api_guides/python/contrib.distributions.md83
-rw-r--r--tensorflow/docs_src/api_guides/python/spectral_ops.md1
-rw-r--r--tensorflow/docs_src/deploy/s3.md2
-rw-r--r--tensorflow/docs_src/extend/index.md3
-rw-r--r--tensorflow/docs_src/guide/eager.md4
-rw-r--r--tensorflow/docs_src/guide/feature_columns.md6
-rw-r--r--tensorflow/docs_src/guide/graph_viz.md3
-rw-r--r--tensorflow/docs_src/guide/version_compat.md6
-rw-r--r--tensorflow/docs_src/javascript/index.md5
-rw-r--r--tensorflow/docs_src/javascript/leftnav_files1
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md35
-rw-r--r--tensorflow/docs_src/tutorials/_index.yaml13
-rw-r--r--tensorflow/docs_src/tutorials/_toc.yaml24
-rw-r--r--tensorflow/docs_src/tutorials/estimators/cnn.md (renamed from tensorflow/docs_src/tutorials/images/layers.md)0
-rw-r--r--tensorflow/docs_src/tutorials/images/deep_cnn.md14
-rw-r--r--tensorflow/docs_src/tutorials/images/image_recognition.md2
-rw-r--r--tensorflow/docs_src/tutorials/representation/linear.md10
-rw-r--r--tensorflow/docs_src/tutorials/representation/wide.md461
-rw-r--r--tensorflow/docs_src/tutorials/representation/wide_and_deep.md243
-rw-r--r--tensorflow/docs_src/tutorials/representation/word2vec.md10
-rw-r--r--tensorflow/examples/android/BUILD2
-rw-r--r--tensorflow/examples/learn/iris.py7
-rw-r--r--tensorflow/examples/tutorials/mnist/BUILD2
-rw-r--r--tensorflow/go/attrs.go245
-rw-r--r--tensorflow/go/attrs_test.go193
-rw-r--r--tensorflow/go/op/wrappers.go4808
-rw-r--r--tensorflow/go/operation.go66
-rw-r--r--tensorflow/go/operation_test.go62
-rw-r--r--tensorflow/go/tensor.go6
-rw-r--r--tensorflow/go/tensor_test.go49
-rw-r--r--tensorflow/java/BUILD5
-rw-r--r--tensorflow/java/maven/.gitignore6
-rw-r--r--tensorflow/java/maven/README.md6
-rw-r--r--tensorflow/java/maven/hadoop/pom.xml192
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml4
-rw-r--r--tensorflow/java/maven/proto/pom.xml4
-rw-r--r--tensorflow/java/maven/run_inside_container.sh52
-rw-r--r--tensorflow/java/maven/spark-connector/pom.xml349
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h2
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc22
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h2
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc147
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h42
-rw-r--r--tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java348
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java79
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Input.java48
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java153
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFString.java27
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFType.java20
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/Types.java52
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc54
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h9
-rw-r--r--tensorflow/java/src/main/native/session_jni.cc42
-rw-r--r--tensorflow/java/src/main/native/utils_jni.cc53
-rw-r--r--tensorflow/java/src/main/native/utils_jni.h33
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java103
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SessionTest.java38
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java34
-rw-r--r--tensorflow/python/BUILD131
-rw-r--r--tensorflow/python/__init__.py1
-rw-r--r--tensorflow/python/client/session.py177
-rw-r--r--tensorflow/python/client/session_test.py83
-rw-r--r--tensorflow/python/client/tf_session.i6
-rw-r--r--tensorflow/python/compat/BUILD22
-rw-r--r--tensorflow/python/compat/compat.py125
-rw-r--r--tensorflow/python/compat/compat_test.py70
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD14
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py294
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py32
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py37
-rw-r--r--tensorflow/python/data/kernel_tests/iterator_ops_test.py64
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py35
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py26
-rw-r--r--tensorflow/python/data/ops/BUILD2
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py1067
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py63
-rw-r--r--tensorflow/python/data/ops/readers.py10
-rw-r--r--tensorflow/python/data/util/BUILD1
-rw-r--r--tensorflow/python/data/util/convert.py38
-rw-r--r--tensorflow/python/data/util/convert_test.py73
-rw-r--r--tensorflow/python/data/util/random_seed_test.py2
-rw-r--r--tensorflow/python/debug/BUILD17
-rw-r--r--tensorflow/python/debug/cli/cli_shared.py44
-rw-r--r--tensorflow/python/debug/cli/cli_shared_test.py5
-rw-r--r--tensorflow/python/debug/cli/debugger_cli_common.py35
-rw-r--r--tensorflow/python/debug/cli/debugger_cli_common_test.py29
-rw-r--r--tensorflow/python/debug/examples/debug_keras.py89
-rwxr-xr-xtensorflow/python/debug/examples/examples_test.sh13
-rw-r--r--tensorflow/python/debug/lib/debug_data.py43
-rw-r--r--tensorflow/python/debug/lib/debug_graph_reconstruction_test.py3
-rw-r--r--tensorflow/python/debug/wrappers/framework.py87
-rw-r--r--tensorflow/python/debug/wrappers/grpc_wrapper.py6
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper.py4
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py118
-rw-r--r--tensorflow/python/eager/BUILD18
-rw-r--r--tensorflow/python/eager/backprop.py19
-rw-r--r--tensorflow/python/eager/backprop_test.py69
-rw-r--r--tensorflow/python/eager/context.py10
-rw-r--r--tensorflow/python/eager/function.py453
-rw-r--r--tensorflow/python/eager/function_test.py385
-rw-r--r--tensorflow/python/eager/graph_callable.py47
-rw-r--r--tensorflow/python/eager/memory_test.py108
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc125
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py43
-rw-r--r--tensorflow/python/eager/test.py1
-rw-r--r--tensorflow/python/estimator/BUILD459
-rw-r--r--tensorflow/python/estimator/__init__.py25
-rw-r--r--tensorflow/python/estimator/api/BUILD19
-rw-r--r--tensorflow/python/estimator/canned/baseline.py24
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py244
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py719
-rw-r--r--tensorflow/python/estimator/canned/dnn.py84
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py97
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py12
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py100
-rw-r--r--tensorflow/python/estimator/canned/head.py33
-rw-r--r--tensorflow/python/estimator/canned/head_test.py117
-rw-r--r--tensorflow/python/estimator/canned/linear.py92
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py141
-rw-r--r--tensorflow/python/estimator/canned/optimizers.py2
-rw-r--r--tensorflow/python/estimator/canned/optimizers_test.py30
-rw-r--r--tensorflow/python/estimator/canned/parsing_utils.py6
-rw-r--r--tensorflow/python/estimator/estimator.py208
-rw-r--r--tensorflow/python/estimator/estimator_test.py110
-rw-r--r--tensorflow/python/estimator/export/export.py16
-rw-r--r--tensorflow/python/estimator/export/export_output.py10
-rw-r--r--tensorflow/python/estimator/exporter.py10
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io.py4
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io.py45
-rw-r--r--tensorflow/python/estimator/inputs/pandas_io_test.py70
-rw-r--r--tensorflow/python/estimator/keras.py42
-rw-r--r--tensorflow/python/estimator/keras_test.py16
-rw-r--r--tensorflow/python/estimator/model_fn.py110
-rw-r--r--tensorflow/python/estimator/model_fn_test.py21
-rw-r--r--tensorflow/python/estimator/run_config.py51
-rw-r--r--tensorflow/python/estimator/run_config_test.py112
-rw-r--r--tensorflow/python/estimator/training.py184
-rw-r--r--tensorflow/python/estimator/training_test.py322
-rw-r--r--tensorflow/python/estimator/util.py22
-rw-r--r--tensorflow/python/feature_column/BUILD68
-rw-r--r--tensorflow/python/feature_column/feature_column.py508
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py113
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py3600
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py6583
-rw-r--r--tensorflow/python/framework/common_shapes.py12
-rw-r--r--tensorflow/python/framework/error_interpolation.py92
-rw-r--r--tensorflow/python/framework/error_interpolation_test.py49
-rw-r--r--tensorflow/python/framework/function.py71
-rw-r--r--tensorflow/python/framework/function_def_to_graph.py6
-rw-r--r--tensorflow/python/framework/importer.py6
-rw-r--r--tensorflow/python/framework/ops.py660
-rw-r--r--tensorflow/python/framework/ops_test.py9
-rw-r--r--tensorflow/python/framework/random_seed_test.py2
-rw-r--r--tensorflow/python/framework/tensor_util_test.py78
-rw-r--r--tensorflow/python/framework/test_util.py183
-rw-r--r--tensorflow/python/framework/test_util_test.py76
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py7
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py6
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py3
-rwxr-xr-xtensorflow/python/keras/BUILD10
-rw-r--r--tensorflow/python/keras/__init__.py1
-rw-r--r--tensorflow/python/keras/activations.py73
-rw-r--r--tensorflow/python/keras/applications/densenet.py4
-rw-r--r--tensorflow/python/keras/applications/inception_resnet_v2.py4
-rw-r--r--tensorflow/python/keras/applications/inception_v3.py4
-rw-r--r--tensorflow/python/keras/applications/mobilenet.py6
-rw-r--r--tensorflow/python/keras/applications/nasnet.py4
-rw-r--r--tensorflow/python/keras/applications/resnet50.py3
-rw-r--r--tensorflow/python/keras/applications/vgg16.py3
-rw-r--r--tensorflow/python/keras/applications/vgg19.py3
-rw-r--r--tensorflow/python/keras/applications/xception.py4
-rw-r--r--tensorflow/python/keras/backend.py307
-rw-r--r--tensorflow/python/keras/backend_test.py264
-rw-r--r--tensorflow/python/keras/callbacks.py153
-rw-r--r--tensorflow/python/keras/callbacks_test.py269
-rw-r--r--tensorflow/python/keras/datasets/boston_housing.py10
-rw-r--r--tensorflow/python/keras/datasets/fashion_mnist.py8
-rw-r--r--tensorflow/python/keras/datasets/imdb.py6
-rw-r--r--tensorflow/python/keras/datasets/mnist.py20
-rw-r--r--tensorflow/python/keras/datasets/reuters.py12
-rw-r--r--tensorflow/python/keras/engine/__init__.py6
-rw-r--r--tensorflow/python/keras/engine/base_layer.py101
-rw-r--r--tensorflow/python/keras/engine/input_layer.py8
-rw-r--r--tensorflow/python/keras/engine/network.py163
-rw-r--r--tensorflow/python/keras/engine/saving.py129
-rw-r--r--tensorflow/python/keras/engine/saving_test.py114
-rw-r--r--tensorflow/python/keras/engine/sequential.py17
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py36
-rw-r--r--tensorflow/python/keras/engine/topology_test.py76
-rw-r--r--tensorflow/python/keras/engine/training.py28
-rw-r--r--tensorflow/python/keras/engine/training_arrays.py21
-rw-r--r--tensorflow/python/keras/engine/training_eager.py29
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py26
-rw-r--r--tensorflow/python/keras/engine/training_generator.py20
-rw-r--r--tensorflow/python/keras/engine/training_test.py78
-rw-r--r--tensorflow/python/keras/engine/training_utils.py30
-rw-r--r--tensorflow/python/keras/estimator/__init__.py46
-rw-r--r--tensorflow/python/keras/initializers.py51
-rw-r--r--tensorflow/python/keras/initializers_test.py26
-rw-r--r--tensorflow/python/keras/layers/__init__.py11
-rw-r--r--tensorflow/python/keras/layers/advanced_activations.py41
-rw-r--r--tensorflow/python/keras/layers/advanced_activations_test.py14
-rw-r--r--tensorflow/python/keras/layers/convolutional.py114
-rw-r--r--tensorflow/python/keras/layers/convolutional_recurrent.py4
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py33
-rw-r--r--tensorflow/python/keras/layers/core.py69
-rw-r--r--tensorflow/python/keras/layers/core_test.py14
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent.py2
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent_test.py131
-rw-r--r--tensorflow/python/keras/layers/embeddings.py7
-rw-r--r--tensorflow/python/keras/layers/gru_test.py8
-rw-r--r--tensorflow/python/keras/layers/local.py64
-rw-r--r--tensorflow/python/keras/layers/local_test.py107
-rw-r--r--tensorflow/python/keras/layers/lstm_test.py8
-rw-r--r--tensorflow/python/keras/layers/merge.py8
-rw-r--r--tensorflow/python/keras/layers/merge_test.py16
-rw-r--r--tensorflow/python/keras/layers/noise.py2
-rw-r--r--tensorflow/python/keras/layers/noise_test.py2
-rw-r--r--tensorflow/python/keras/layers/normalization.py43
-rw-r--r--tensorflow/python/keras/layers/pooling.py4
-rw-r--r--tensorflow/python/keras/layers/pooling_test.py18
-rw-r--r--tensorflow/python/keras/layers/recurrent.py13
-rw-r--r--tensorflow/python/keras/layers/serialization.py4
-rw-r--r--tensorflow/python/keras/layers/simplernn_test.py8
-rw-r--r--tensorflow/python/keras/layers/wrappers.py141
-rw-r--r--tensorflow/python/keras/layers/wrappers_test.py118
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py107
-rw-r--r--tensorflow/python/keras/models_test.py16
-rw-r--r--tensorflow/python/keras/optimizers.py56
-rw-r--r--tensorflow/python/keras/optimizers_test.py6
-rw-r--r--tensorflow/python/keras/testing_utils.py74
-rw-r--r--tensorflow/python/keras/utils/data_utils.py4
-rw-r--r--tensorflow/python/keras/utils/io_utils.py5
-rw-r--r--tensorflow/python/keras/utils/io_utils_test.py24
-rw-r--r--tensorflow/python/keras/utils/layer_utils.py41
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils.py2
-rw-r--r--tensorflow/python/kernel_tests/BUILD24
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/as_string_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/atrous_convolution_test.py10
-rw-r--r--tensorflow/python/kernel_tests/betainc_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py144
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py44
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py212
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py536
-rw-r--r--tensorflow/python/kernel_tests/confusion_matrix_test.py2
-rw-r--r--tensorflow/python/kernel_tests/constant_op_eager_test.py31
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py106
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/dct_ops_test.py96
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/distributions/BUILD5
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py47
-rw-r--r--tensorflow/python/kernel_tests/distributions/beta_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/categorical_test.py30
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py32
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_test.py18
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py10
-rw-r--r--tensorflow/python/kernel_tests/distributions/gamma_test.py104
-rw-r--r--tensorflow/python/kernel_tests/distributions/laplace_test.py13
-rw-r--r--tensorflow/python/kernel_tests/distributions/multinomial_test.py24
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py55
-rw-r--r--tensorflow/python/kernel_tests/distributions/special_math_test.py6
-rw-r--r--tensorflow/python/kernel_tests/distributions/student_t_test.py56
-rw-r--r--tensorflow/python/kernel_tests/distributions/uniform_test.py49
-rw-r--r--tensorflow/python/kernel_tests/distributions/util_test.py83
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py1
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py218
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py20
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py136
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py71
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD20
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py35
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py110
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py68
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py22
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py73
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py31
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py31
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py96
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py28
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py6
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py44
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py16
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py112
-rw-r--r--tensorflow/python/kernel_tests/random/BUILD17
-rw-r--r--tensorflow/python/kernel_tests/random/multinomial_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/random/random_grad_test.py240
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py410
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py129
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py57
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py23
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_serialization_ops_test.py45
-rw-r--r--tensorflow/python/kernel_tests/sparse_slice_op_test.py22
-rw-r--r--tensorflow/python/kernel_tests/split_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py96
-rw-r--r--tensorflow/python/kernel_tests/template_test.py38
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py100
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py160
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py2
-rw-r--r--tensorflow/python/layers/base.py43
-rw-r--r--tensorflow/python/layers/base_test.py57
-rw-r--r--tensorflow/python/layers/convolutional.py5
-rw-r--r--tensorflow/python/layers/core.py1
-rw-r--r--tensorflow/python/layers/core_test.py22
-rw-r--r--tensorflow/python/layers/normalization.py3
-rw-r--r--tensorflow/python/lib/core/bfloat16.cc11
-rw-r--r--tensorflow/python/lib/core/bfloat16_test.py14
-rw-r--r--tensorflow/python/lib/core/ndarray_tensor.cc40
-rw-r--r--tensorflow/python/lib/core/numpy.h2
-rw-r--r--tensorflow/python/lib/core/py_func.cc70
-rw-r--r--tensorflow/python/lib/core/py_seq_tensor.cc39
-rw-r--r--tensorflow/python/lib/core/py_util.cc2
-rw-r--r--tensorflow/python/lib/io/tf_record_test.py322
-rw-r--r--tensorflow/python/ops/array_grad.py8
-rw-r--r--tensorflow/python/ops/array_ops.py11
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py2
-rw-r--r--tensorflow/python/ops/collective_ops.py2
-rw-r--r--tensorflow/python/ops/collective_ops_test.py4
-rw-r--r--tensorflow/python/ops/cond_v2.py (renamed from tensorflow/contrib/proto/python/kernel_tests/test_case.py)27
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py479
-rw-r--r--tensorflow/python/ops/control_flow_ops.py53
-rw-r--r--tensorflow/python/ops/control_flow_ops_test.py26
-rw-r--r--tensorflow/python/ops/conv2d_benchmark.py114
-rw-r--r--tensorflow/python/ops/custom_gradient.py2
-rw-r--r--tensorflow/python/ops/data_flow_ops.py46
-rw-r--r--tensorflow/python/ops/distributions/beta.py29
-rw-r--r--tensorflow/python/ops/distributions/categorical.py23
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py34
-rw-r--r--tensorflow/python/ops/distributions/distribution.py63
-rw-r--r--tensorflow/python/ops/distributions/exponential.py3
-rw-r--r--tensorflow/python/ops/distributions/gamma.py33
-rw-r--r--tensorflow/python/ops/distributions/student_t.py21
-rw-r--r--tensorflow/python/ops/distributions/util.py45
-rw-r--r--tensorflow/python/ops/embedding_ops.py169
-rw-r--r--tensorflow/python/ops/functional_ops.py63
-rw-r--r--tensorflow/python/ops/gradients_impl.py200
-rw-r--r--tensorflow/python/ops/gradients_test.py173
-rw-r--r--tensorflow/python/ops/image_ops_impl.py318
-rw-r--r--tensorflow/python/ops/image_ops_test.py437
-rw-r--r--tensorflow/python/ops/init_ops.py38
-rw-r--r--tensorflow/python/ops/linalg/linear_operator.py8
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_diag.py5
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_low_rank_update.py31
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_lower_triangular.py8
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py41
-rw-r--r--tensorflow/python/ops/linalg_grad.py11
-rw-r--r--tensorflow/python/ops/logging_ops.py9
-rw-r--r--tensorflow/python/ops/lookup_ops.py8
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py58
-rw-r--r--tensorflow/python/ops/losses/util.py6
-rw-r--r--tensorflow/python/ops/math_grad.py54
-rw-r--r--tensorflow/python/ops/math_ops.py87
-rw-r--r--tensorflow/python/ops/math_ops_test.py29
-rw-r--r--tensorflow/python/ops/metrics_impl.py281
-rw-r--r--tensorflow/python/ops/nn_impl.py5
-rw-r--r--tensorflow/python/ops/nn_ops.py9
-rw-r--r--tensorflow/python/ops/nn_test.py12
-rw-r--r--tensorflow/python/ops/parallel_for/BUILD129
-rw-r--r--tensorflow/python/ops/parallel_for/__init__.py35
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops.py123
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops_test.py1404
-rw-r--r--tensorflow/python/ops/parallel_for/gradients.py126
-rw-r--r--tensorflow/python/ops/parallel_for/gradients_test.py568
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py2552
-rw-r--r--tensorflow/python/ops/random_grad.py65
-rw-r--r--tensorflow/python/ops/random_ops.py71
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py156
-rw-r--r--tensorflow/python/ops/rnn.py22
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py55
-rw-r--r--tensorflow/python/ops/script_ops.py144
-rw-r--r--tensorflow/python/ops/sparse_grad.py29
-rw-r--r--tensorflow/python/ops/special_math_ops.py54
-rw-r--r--tensorflow/python/ops/special_math_ops_test.py123
-rw-r--r--tensorflow/python/ops/spectral_ops.py125
-rw-r--r--tensorflow/python/ops/standard_ops.py1
-rw-r--r--tensorflow/python/ops/state_ops.py79
-rw-r--r--tensorflow/python/ops/string_ops.py53
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py12
-rw-r--r--tensorflow/python/ops/template.py67
-rw-r--r--tensorflow/python/ops/tensor_array_grad.py1
-rw-r--r--tensorflow/python/ops/variable_scope.py382
-rw-r--r--tensorflow/python/ops/variables.py29
-rw-r--r--tensorflow/python/platform/benchmark.py10
-rw-r--r--tensorflow/python/platform/self_check.py2
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py4
-rw-r--r--tensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/saved_model/BUILD25
-rw-r--r--tensorflow/python/saved_model/builder_impl.py46
-rw-r--r--tensorflow/python/saved_model/loader_impl.py192
-rw-r--r--tensorflow/python/saved_model/loader_test.py217
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py128
-rw-r--r--tensorflow/python/training/adadelta.py17
-rw-r--r--tensorflow/python/training/adadelta_test.py116
-rw-r--r--tensorflow/python/training/adagrad.py12
-rw-r--r--tensorflow/python/training/adagrad_test.py73
-rw-r--r--tensorflow/python/training/adam.py20
-rw-r--r--tensorflow/python/training/adam_test.py18
-rw-r--r--tensorflow/python/training/checkpoint_utils.py10
-rw-r--r--tensorflow/python/training/checkpointable/BUILD25
-rw-r--r--tensorflow/python/training/checkpointable/base.py125
-rw-r--r--tensorflow/python/training/checkpointable/base_test.py32
-rw-r--r--tensorflow/python/training/checkpointable/data_structures.py297
-rw-r--r--tensorflow/python/training/checkpointable/data_structures_test.py71
-rw-r--r--tensorflow/python/training/checkpointable/layer_utils.py93
-rw-r--r--tensorflow/python/training/checkpointable/tracking.py72
-rw-r--r--tensorflow/python/training/checkpointable/tracking_test.py171
-rw-r--r--tensorflow/python/training/checkpointable/util.py262
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py180
-rw-r--r--tensorflow/python/training/device_util.py10
-rw-r--r--tensorflow/python/training/device_util_test.py2
-rw-r--r--tensorflow/python/training/distribute.py147
-rw-r--r--tensorflow/python/training/distribute_test.py39
-rw-r--r--tensorflow/python/training/gradient_descent.py15
-rw-r--r--tensorflow/python/training/gradient_descent_test.py26
-rw-r--r--tensorflow/python/training/learning_rate_decay.py384
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py522
-rw-r--r--tensorflow/python/training/momentum.py4
-rw-r--r--tensorflow/python/training/monitored_session.py14
-rw-r--r--tensorflow/python/training/optimizer.py32
-rw-r--r--tensorflow/python/training/optimizer_test.py14
-rw-r--r--tensorflow/python/training/rmsprop.py22
-rw-r--r--tensorflow/python/training/rmsprop_test.py54
-rw-r--r--tensorflow/python/training/saver.py148
-rw-r--r--tensorflow/python/training/saver_test.py187
-rw-r--r--tensorflow/python/util/deprecation.py38
-rw-r--r--tensorflow/python/util/deprecation_test.py6
-rw-r--r--tensorflow/python/util/lock_util.py128
-rw-r--r--tensorflow/python/util/lock_util_test.py63
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/serialization_test.py4
-rw-r--r--tensorflow/python/util/tf_export.py58
-rw-r--r--tensorflow/python/util/tf_export_test.py7
-rw-r--r--tensorflow/python/util/tf_inspect.py10
-rw-r--r--tensorflow/python/util/tf_inspect_test.py12
-rw-r--r--tensorflow/python/util/util.cc20
-rw-r--r--tensorflow/security/index.md4
-rw-r--r--tensorflow/stream_executor/BUILD9
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc15
-rw-r--r--tensorflow/stream_executor/cuda/cuda_diagnostics.cc64
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc3128
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h128
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc10
-rw-r--r--tensorflow/stream_executor/cuda/cuda_timer.h3
-rw-r--r--tensorflow/stream_executor/dnn.cc4
-rw-r--r--tensorflow/stream_executor/dnn.h5
-rw-r--r--tensorflow/stream_executor/event.cc11
-rw-r--r--tensorflow/stream_executor/event.h3
-rw-r--r--tensorflow/stream_executor/host/host_gpu_executor.cc9
-rw-r--r--tensorflow/stream_executor/stream.cc20
-rw-r--r--tensorflow/stream_executor/stream.h73
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc2
-rw-r--r--tensorflow/tensorflow.bzl54
-rw-r--r--tensorflow/tf_framework_version_script.lds11
-rw-r--r--tensorflow/tools/api/generator/BUILD45
-rw-r--r--tensorflow/tools/api/generator/api_gen.bzl77
-rw-r--r--tensorflow/tools/api/generator/create_python_api.py60
-rw-r--r--tensorflow/tools/api/generator/create_python_api_test.py12
-rw-r--r--tensorflow/tools/api/generator/doc_srcs.py31
-rw-r--r--tensorflow/tools/api/generator/doc_srcs_test.py33
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt16
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.debugging.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt7
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt24
-rw-r--r--tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.io.pbtxt39
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt176
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt175
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt176
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/tensorflow.manip.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/tensorflow.math.pbtxt232
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt40
-rw-r--r--tensorflow/tools/api/golden/tensorflow.quantization.pbtxt35
-rw-r--r--tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.spectral.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.strings.pbtxt36
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt2
-rw-r--r--tensorflow/tools/api/lib/python_object_to_proto_visitor.py3
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.cmake2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le20
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le28
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.cpu4
-rwxr-xr-xtensorflow/tools/ci_build/builds/pip.sh4
-rwxr-xr-xtensorflow/tools/ci_build/ci_build.sh4
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh20
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh4
-rwxr-xr-xtensorflow/tools/ci_build/copy_binary.py3
-rwxr-xr-xtensorflow/tools/ci_build/install/install_bazel_from_source.sh40
-rwxr-xr-xtensorflow/tools/ci_build/install/install_buildifier_from_source.sh30
-rwxr-xr-xtensorflow/tools/ci_build/install/install_golang_ppc64le.sh22
-rwxr-xr-xtensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh30
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh14
-rwxr-xr-xtensorflow/tools/ci_build/install/install_proto3.sh2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh9
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh6
-rwxr-xr-xtensorflow/tools/ci_build/linux/gpu/run_mkl.sh47
-rwxr-xr-xtensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh29
-rwxr-xr-xtensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh29
-rwxr-xr-xtensorflow/tools/ci_build/linux/mkl/build-dev-container.sh53
-rwxr-xr-xtensorflow/tools/ci_build/pi/build_raspberry_pi.sh2
-rwxr-xr-xtensorflow/tools/ci_build/update_version.py12
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh133
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/common_env.sh13
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh15
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat1
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh62
-rwxr-xr-xtensorflow/tools/ci_build/windows/libtensorflow_cpu.sh2
-rw-r--r--tensorflow/tools/compatibility/ast_edits.py502
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7115
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl128
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.mkl75
-rwxr-xr-xtensorflow/tools/docker/parameterized_docker_build.sh142
-rw-r--r--tensorflow/tools/docs/BUILD7
-rw-r--r--tensorflow/tools/docs/generate_lib.py112
-rw-r--r--tensorflow/tools/docs/generate_lib_test.py110
-rw-r--r--tensorflow/tools/docs/parser.py16
-rw-r--r--tensorflow/tools/docs/py_guide_parser.py3
-rwxr-xr-xtensorflow/tools/git/gen_git_source.py11
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc26
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc46
-rw-r--r--tensorflow/tools/lib_package/BUILD6
-rw-r--r--tensorflow/tools/pip_package/BUILD12
-rw-r--r--tensorflow/tools/pip_package/setup.py8
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc25
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc5
-rw-r--r--tensorflow/tools/quantization/quantize_graph_test.py12
-rw-r--r--tensorflow/workspace.bzl220
-rw-r--r--third_party/android/BUILD0
-rw-r--r--third_party/android/android.bzl.tpl9
-rw-r--r--third_party/android/android_configure.BUILD.tpl0
-rw-r--r--third_party/android/android_configure.bzl87
-rw-r--r--third_party/aws.BUILD3
-rw-r--r--third_party/clang_toolchain/download_clang.bzl8
-rw-r--r--third_party/codegen.BUILD16
-rw-r--r--third_party/curl.BUILD22
-rw-r--r--third_party/eigen.BUILD6
-rw-r--r--third_party/eigen3/BUILD60
-rw-r--r--third_party/eigen_fix_cuda_compilation.patch38
-rw-r--r--third_party/flatbuffers/flatbuffers.BUILD2
-rw-r--r--third_party/googleapis.BUILD45
-rw-r--r--third_party/gpus/crosstool/BUILD.tpl20
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL.tpl1111
-rwxr-xr-xthird_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl6
-rw-r--r--third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl20
-rw-r--r--third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl192
-rw-r--r--third_party/gpus/cuda/BUILD.tpl9
-rw-r--r--third_party/gpus/cuda/BUILD.windows.tpl163
-rw-r--r--third_party/gpus/cuda_configure.bzl2163
-rw-r--r--third_party/jsoncpp.BUILD7
-rw-r--r--third_party/kafka/BUILD5
-rw-r--r--third_party/libxsmm.BUILD2
-rw-r--r--third_party/llvm/llvm.autogenerated.BUILD (renamed from third_party/llvm/llvm.BUILD)338
-rw-r--r--third_party/llvm/llvm.bzl140
-rw-r--r--third_party/mkl/LICENSE201
-rw-r--r--third_party/nanopb.BUILD23
-rw-r--r--third_party/nasm.BUILD180
-rw-r--r--third_party/repo.bzl1
-rw-r--r--third_party/sqlite.BUILD1
-rw-r--r--third_party/toolchains/BUILD22
-rw-r--r--third_party/toolchains/clang6/CROSSTOOL.tpl3
-rw-r--r--tools/bazel.rc2
3067 files changed, 182667 insertions, 60061 deletions
diff --git a/.gitignore b/.gitignore
index 828bbe9bd3..b5306b8b79 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,7 @@ __pycache__
cmake_build/
.idea/**
/build/
+[Bb]uild/
/tensorflow/core/util/version_info.cc
/tensorflow/python/framework/fast_tensor_util.cpp
Pods
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8669c25c45..f598999f35 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g.,
Changes to TensorFlow C++ code should conform to
[Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html).
-Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do:
+Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do:
```bash
apt-get install -y clang-tidy
@@ -107,7 +107,7 @@ diff <my_cc_file> /tmp/my_cc_file.cc
#### Python coding style
Changes to TensorFlow Python code should conform to
-[Google Python Style Guide](https://google.github.io/styleguide/pyguide.html)
+[Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md)
Use `pylint` to check your Python changes. To install `pylint` and
retrieve TensorFlow's custom style definition:
diff --git a/README.md b/README.md
index 4e4d139bd1..05fcb23f7e 100644
--- a/README.md
+++ b/README.md
@@ -56,6 +56,7 @@ $ python
42
>>> sess.close()
```
+Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/).
## Contribution guidelines
@@ -95,6 +96,8 @@ The TensorFlow project strives to abide by generally accepted best practices in
| --- | --- | --- |
| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA |
| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA |
+| **IBM ppc64le GPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_PPC64LE_GPU/) | TBA |
+| **Linux CPU with Intel® MKL-DNN®** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-linux-cpu/) | TBA |
## For more information
diff --git a/RELEASE.md b/RELEASE.md
index 21207a7efa..6b67072f8e 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -6,7 +6,7 @@
* Update `tf.keras` to the Keras 2.1.6 API.
* Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082).
* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees).
-* The [python interface](https://tensorflow-dot-devsite.googleplex.com/versions/r1.9/api_docs/python/tf/contrib/lite)
+* The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/contrib/lite)
for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/README.md)
has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again
included in the standard `pip` installation.
@@ -33,7 +33,6 @@
* Using `tf.keras.layers` with custom variable scopes.
* Using `tf.layers` in a subclassed `tf.keras.Model` class. See
[here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details
-
* `tf.data`:
* `Dataset.from_generator()` now accepts an `args` list, in order to create nested generators.
* `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed.
@@ -491,14 +490,6 @@ answered questions, and were part of inspiring discussions.
* [`tf.data`](http://tensorflow.org/guide/datasets) is now part of
the core TensorFlow API.
* The API is now subject to backwards compatibility guarantees.
-
-# Release 1.4.0
-
-## Major Features And Improvements
-* `tf.keras` is now part of the core TensorFlow API.
-* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of
- the core TensorFlow API.
- * The API is now subject to backwards compatibility guarantees.
* For a guide to migrating from the `tf.contrib.data` API, see the
[README](https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/data/README.md).
* Major new features include `Dataset.from_generator()` (for building an input
diff --git a/SECURITY.md b/SECURITY.md
index 0a4be37cbc..0b52fdc7ab 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -242,12 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
-----END PGP PUBLIC KEY BLOCK-----
```
-### Known vulnerabilities
-
-| Type | Versions affected | Reported by | Additional Information |
-|--------------------|:-----------------:|-----------------------|-----------------------------|
-| TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-003.md) |
-| GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-002.md) |
-| BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | [security advisory](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/security/advisory/tfsa-2018-001.md) |
-| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+### Known Vulnerabilities
+For a list of known vulnerabilities and security advisories for TensorFlow,
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
diff --git a/WORKSPACE b/WORKSPACE
index 4ddfb9a383..fd7570a80a 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0")
load("//tensorflow:workspace.bzl", "tf_workspace")
-# Uncomment and update the paths in these entries to build the Android demo.
-#android_sdk_repository(
-# name = "androidsdk",
-# api_level = 23,
-# # Ensure that you have the build_tools_version below installed in the
-# # SDK manager as it updates periodically.
-# build_tools_version = "26.0.1",
-# # Replace with path to Android SDK on your system
-# path = "<PATH_TO_SDK>",
-#)
-#
-#android_ndk_repository(
-# name="androidndk",
-# path="<PATH_TO_NDK>",
-# # This needs to be 14 or higher to compile TensorFlow.
-# # Please specify API level to >= 21 to build for 64-bit
-# # archtectures or the Android NDK will automatically select biggest
-# # API level that it supports without notice.
-# # Note that the NDK version is not the API level.
-# api_level=14)
+load("//third_party/android:android_configure.bzl", "android_configure")
+android_configure(name="local_config_android")
+load("@local_config_android//:android.bzl", "android_workspace")
+android_workspace()
# Please add all new TensorFlow dependencies in workspace.bzl.
tf_workspace()
diff --git a/configure.py b/configure.py
index 96caa2e2dd..8930c3a1f1 100644
--- a/configure.py
+++ b/configure.py
@@ -670,8 +670,9 @@ def create_android_ndk_rule(environ_cp):
error_msg=('The path %s or its child file "source.properties" '
'does not exist.')
)
-
- write_android_ndk_workspace_rule(android_ndk_home_path)
+ write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path)
+ write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL',
+ check_ndk_level(android_ndk_home_path))
def create_android_sdk_rule(environ_cp):
@@ -733,41 +734,12 @@ def create_android_sdk_rule(environ_cp):
error_msg=('The selected SDK does not have build-tools version %s '
'available.'))
- write_android_sdk_workspace_rule(android_sdk_home_path,
- android_build_tools_version,
- android_api_level)
-
-
-def write_android_sdk_workspace_rule(android_sdk_home_path,
- android_build_tools_version,
- android_api_level):
- print('Writing android_sdk_workspace rule.\n')
- with open(_TF_WORKSPACE, 'a') as f:
- f.write("""
-android_sdk_repository(
- name="androidsdk",
- api_level=%s,
- path="%s",
- build_tools_version="%s")\n
-""" % (android_api_level, android_sdk_home_path, android_build_tools_version))
-
-
-def write_android_ndk_workspace_rule(android_ndk_home_path):
- print('Writing android_ndk_workspace rule.')
- ndk_api_level = check_ndk_level(android_ndk_home_path)
- if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
- print('WARNING: The API level of the NDK in %s is %s, which is not '
- 'supported by Bazel (officially supported versions: %s). Please use '
- 'another version. Compiling Android targets may result in confusing '
- 'errors.\n' % (android_ndk_home_path, ndk_api_level,
- _SUPPORTED_ANDROID_NDK_VERSIONS))
- with open(_TF_WORKSPACE, 'a') as f:
- f.write("""
-android_ndk_repository(
- name="androidndk",
- path="%s",
- api_level=%s)\n
-""" % (android_ndk_home_path, ndk_api_level))
+ write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION',
+ android_build_tools_version)
+ write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL',
+ android_api_level)
+ write_action_env_to_bazelrc('ANDROID_SDK_HOME',
+ android_sdk_home_path)
def check_ndk_level(android_ndk_home_path):
@@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path):
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
if revision:
- return revision.group(1)
- return None
-
-
-def workspace_has_any_android_rule():
- """Check the WORKSPACE for existing android_*_repository rules."""
- with open(_TF_WORKSPACE, 'r') as f:
- workspace = f.read()
- has_any_rule = re.search(r'^android_[ns]dk_repository',
- workspace,
- re.MULTILINE)
- return has_any_rule
+ ndk_api_level = revision.group(1)
+ else:
+ raise Exception('Unable to parse NDK revision.')
+ if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
+ print('WARNING: The API level of the NDK in %s is %s, which is not '
+ 'supported by Bazel (officially supported versions: %s). Please use '
+ 'another version. Compiling Android targets may result in confusing '
+ 'errors.\n' % (android_ndk_home_path, ndk_api_level,
+ _SUPPORTED_ANDROID_NDK_VERSIONS))
+ return ndk_api_level
def set_gcc_host_compiler_path(environ_cp):
@@ -865,6 +835,8 @@ def set_tf_cuda_version(environ_cp):
'[Default is %s]: ') % (tf_cuda_version, default_cuda_path)
cuda_toolkit_path = get_from_env_or_user_or_default(
environ_cp, 'CUDA_TOOLKIT_PATH', ask_cuda_path, default_cuda_path)
+ if is_windows() or is_cygwin():
+ cuda_toolkit_path = cygpath(cuda_toolkit_path)
if is_windows():
cuda_rt_lib_path = 'lib/x64/cudart.lib'
@@ -973,6 +945,35 @@ def set_tf_cudnn_version(environ_cp):
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
+def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
+ """Check compatibility between given library and cudnn/cudart libraries."""
+ ldd_bin = which('ldd') or '/usr/bin/ldd'
+ ldd_out = run_shell([ldd_bin, lib], True)
+ ldd_out = ldd_out.split(os.linesep)
+ cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
+ cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
+ cudnn = None
+ cudart = None
+ cudnn_ok = True # assume no cudnn dependency by default
+ cuda_ok = True # assume no cuda dependency by default
+ for line in ldd_out:
+ if 'libcudnn.so' in line:
+ cudnn = cudnn_pattern.search(line)
+ cudnn_ok = False
+ elif 'libcudart.so' in line:
+ cudart = cuda_pattern.search(line)
+ cuda_ok = False
+ if cudnn and len(cudnn.group(1)):
+ cudnn = convert_version_to_int(cudnn.group(1))
+ if cudart and len(cudart.group(1)):
+ cudart = convert_version_to_int(cudart.group(1))
+ if cudnn is not None:
+ cudnn_ok = (cudnn == cudnn_ver)
+ if cudart is not None:
+ cuda_ok = (cudart == cuda_ver)
+ return cudnn_ok and cuda_ok
+
+
def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
@@ -989,8 +990,8 @@ def set_tf_tensorrt_install_path(environ_cp):
raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support.
- if str(int(get_var(
- environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
+ if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
+ False))) != '1':
return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
@@ -1003,47 +1004,29 @@ def set_tf_tensorrt_install_path(environ_cp):
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
- trt_install_path = os.path.realpath(
- os.path.expanduser(trt_install_path))
+ trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
def find_libs(search_path):
"""Search for libnvinfer.so in "search_path"."""
fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path):
- fl.update([os.path.realpath(os.path.join(search_path, x))
- for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+ fl.update([
+ os.path.realpath(os.path.join(search_path, x))
+ for x in os.listdir(search_path)
+ if 'libnvinfer.so' in x
+ ])
return fl
possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
-
- def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
- """Check the compatibility between tensorrt and cudnn/cudart libraries."""
- ldd_bin = which('ldd') or '/usr/bin/ldd'
- ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
- cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
- cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
- cudnn = None
- cudart = None
- for line in ldd_out:
- if 'libcudnn.so' in line:
- cudnn = cudnn_pattern.search(line)
- elif 'libcudart.so' in line:
- cudart = cuda_pattern.search(line)
- if cudnn and len(cudnn.group(1)):
- cudnn = convert_version_to_int(cudnn.group(1))
- if cudart and len(cudart.group(1)):
- cudart = convert_version_to_int(cudart.group(1))
- return (cudnn == cudnn_ver) and (cudart == cuda_ver)
-
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
- if is_compatible(lib_file, cuda_ver, cudnn_ver):
+ if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
if len(matches.groups()) == 0:
continue
@@ -1059,12 +1042,13 @@ def set_tf_tensorrt_install_path(environ_cp):
# Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p'])
- search_result = re.search(
- '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
+ search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
+ ldconfig_output)
if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig):
- if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
+ if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
+ cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1)
break
@@ -1152,7 +1136,9 @@ def set_tf_nccl_install_path(environ_cp):
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
- if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+ nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt')
+ if os.path.exists(nccl_lib_path) and os.path.exists(
+ nccl_hdr_path) and os.path.exists(nccl_license_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
@@ -1223,7 +1209,7 @@ def set_tf_cuda_compute_capabilities(environ_cp):
# Check whether all capabilities from the input is valid
all_valid = True
# Remove all whitespace characters before splitting the string
- # that users may insert by accident, as this will result in error
+ # that users may insert by accident, as this will result in error
tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split())
for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability)
@@ -1250,28 +1236,13 @@ def set_tf_cuda_compute_capabilities(environ_cp):
def set_other_cuda_vars(environ_cp):
"""Set other CUDA related variables."""
- if is_windows():
- # The following three variables are needed for MSVC toolchain configuration
- # in Bazel
- environ_cp['CUDA_PATH'] = environ_cp.get('CUDA_TOOLKIT_PATH')
- environ_cp['CUDA_COMPUTE_CAPABILITIES'] = environ_cp.get(
- 'TF_CUDA_COMPUTE_CAPABILITIES')
- environ_cp['NO_WHOLE_ARCHIVE_OPTION'] = 1
- write_action_env_to_bazelrc('CUDA_PATH', environ_cp.get('CUDA_PATH'))
- write_action_env_to_bazelrc('CUDA_COMPUTE_CAPABILITIE',
- environ_cp.get('CUDA_COMPUTE_CAPABILITIE'))
- write_action_env_to_bazelrc('NO_WHOLE_ARCHIVE_OPTION',
- environ_cp.get('NO_WHOLE_ARCHIVE_OPTION'))
- write_to_bazelrc('build --config=win-cuda')
- write_to_bazelrc('test --config=win-cuda')
+ # If CUDA is enabled, always use GPU during build and test.
+ if environ_cp.get('TF_CUDA_CLANG') == '1':
+ write_to_bazelrc('build --config=cuda_clang')
+ write_to_bazelrc('test --config=cuda_clang')
else:
- # If CUDA is enabled, always use GPU during build and test.
- if environ_cp.get('TF_CUDA_CLANG') == '1':
- write_to_bazelrc('build --config=cuda_clang')
- write_to_bazelrc('test --config=cuda_clang')
- else:
- write_to_bazelrc('build --config=cuda')
- write_to_bazelrc('test --config=cuda')
+ write_to_bazelrc('build --config=cuda')
+ write_to_bazelrc('test --config=cuda')
def set_host_cxx_compiler(environ_cp):
@@ -1465,7 +1436,7 @@ def main():
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_S3'] = '0'
+ environ_cp['TF_NEED_AWS'] = '0'
environ_cp['TF_NEED_GCP'] = '0'
environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
@@ -1489,8 +1460,8 @@ def main():
'with_gcp_support', True, 'gcp')
set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
- 'with_s3_support', True, 's3')
+ set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
+ 'with_aws_support', True, 'aws')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
@@ -1556,21 +1527,15 @@ def main():
set_build_strip_flag()
set_windows_build_flags()
- if workspace_has_any_android_rule():
- print('The WORKSPACE file has at least one of ["android_sdk_repository", '
- '"android_ndk_repository"] already set. Will not ask to help '
- 'configure the WORKSPACE. Please delete the existing rules to '
- 'activate the helper.\n')
- else:
- if get_var(
- environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
- False,
- ('Would you like to interactively configure ./WORKSPACE for '
- 'Android builds?'),
- 'Searching for NDK and SDK installations.',
- 'Not configuring the WORKSPACE for Android builds.'):
- create_android_ndk_rule(environ_cp)
- create_android_sdk_rule(environ_cp)
+ if get_var(
+ environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
+ False,
+ ('Would you like to interactively configure ./WORKSPACE for '
+ 'Android builds?'),
+ 'Searching for NDK and SDK installations.',
+ 'Not configuring the WORKSPACE for Android builds.'):
+ create_android_ndk_rule(environ_cp)
+ create_android_sdk_rule(environ_cp)
print('Preconfigured Bazel build configs. You can use any of the below by '
'adding "--config=<>" to your build command. See tools/bazel.rc for '
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 9b07669a5d..51eea94847 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -155,6 +155,12 @@ config_setting(
)
config_setting(
+ name = "linux_s390x",
+ values = {"cpu": "s390x"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "debug",
values = {
"compilation_mode": "dbg",
@@ -210,8 +216,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support",
+ define_values = {"with_aws_support": "true"},
visibility = ["//visibility:public"],
)
@@ -238,8 +244,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_windows_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_windows_override",
+ define_values = {"with_aws_support": "true"},
values = {"cpu": "x64_windows"},
visibility = ["//visibility:public"],
)
@@ -252,6 +258,13 @@ config_setting(
)
config_setting(
+ name = "with_cuda_support_windows_override",
+ define_values = {"using_cuda_nvcc": "true"},
+ values = {"cpu": "x64_windows"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "with_gcp_support_android_override",
define_values = {"with_gcp_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
@@ -266,8 +279,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_android_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_android_override",
+ define_values = {"with_aws_support": "true"},
values = {"crosstool_top": "//external:android/crosstool"},
visibility = ["//visibility:public"],
)
@@ -287,8 +300,8 @@ config_setting(
)
config_setting(
- name = "with_s3_support_ios_override",
- define_values = {"with_s3_support": "true"},
+ name = "with_aws_support_ios_override",
+ define_values = {"with_aws_support": "true"},
values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
visibility = ["//visibility:public"],
)
@@ -398,6 +411,7 @@ config_setting(
package_group(
name = "internal",
packages = [
+ "-//third_party/tensorflow/python/estimator",
"//learning/meta_rank/...",
"//tensorflow/...",
"//tensorflow_fold/llgtm/...",
@@ -424,6 +438,22 @@ filegroup(
data = glob(["docs_src/**/*.md"]),
)
+cc_library(
+ name = "grpc",
+ deps = select({
+ ":linux_s390x": ["@grpc//:grpc_unsecure"],
+ "//conditions:default": ["@grpc"],
+ }),
+)
+
+cc_library(
+ name = "grpc++",
+ deps = select({
+ ":linux_s390x": ["@grpc//:grpc++_unsecure"],
+ "//conditions:default": ["@grpc//:grpc++"],
+ }),
+)
+
# A shared object which includes registration mechanisms for ops and
# kernels. Does not include the implementations of any ops or kernels. Instead,
# the library which loads libtensorflow_framework.so
@@ -451,6 +481,15 @@ filegroup(
tf_cc_shared_object(
name = "libtensorflow_framework.so",
framework_so = [],
+ linkopts = select({
+ "//tensorflow:darwin": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:windows_msvc": [],
+ "//conditions:default": [
+ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file
+ "$(location //tensorflow:tf_framework_version_script.lds)",
+ ],
+ }),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
@@ -460,6 +499,7 @@ tf_cc_shared_object(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/stream_executor:stream_executor_impl",
+ "//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(),
)
@@ -539,15 +579,27 @@ exports_files(
)
gen_api_init_files(
- name = "python_api_gen",
+ name = "tensorflow_python_api_gen",
srcs = ["api_template.__init__.py"],
root_init_template = "api_template.__init__.py",
)
py_library(
name = "tensorflow_py",
- srcs = [":python_api_gen"],
+ srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflow_py_no_contrib",
+ "//tensorflow/contrib:contrib_py",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_library(
+ name = "tensorflow_py_no_contrib",
+ srcs = [":tensorflow_python_api_gen"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = ["//tensorflow/python"],
+ deps = ["//tensorflow/python:no_contrib"],
)
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 9b0d7d48af..779f65d5b1 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -20,9 +20,25 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
+
+try:
+ import os # pylint: disable=g-import-not-at-top
+ # Add `estimator` attribute to allow access to estimator APIs via
+ # "tf.estimator..."
+ from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top
+
+ # Add `estimator` to the __path__ to allow "from tensorflow.estimator..."
+ # style imports.
+ from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top
+ __path__ += [os.path.dirname(estimator_api.__file__)]
+ del estimator_api
+ del os
+except (ImportError, AttributeError):
+ print('tf.estimator package not installed.')
+
# API IMPORTS PLACEHOLDER
-from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index b86b277ac3..5c218d3f25 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -390,64 +391,6 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
status->status = Reset(opt->options, container_names);
}
-// This traverses the specified nodes in topological order to verify there are
-// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
-// all been visited, and counts the total number of visited nodes. If there is a
-// cycle, nodes in the cycle will never be visited, and the visited count will
-// be less than the total node count.
-Status ValidateNoCycles(const Graph& g) {
- // TODO(nolivia): check this on a subset of the graph instead of all of it.
- // A node is ready when all of its inputs have been visited.
- std::vector<const Node*> ready;
- std::vector<int> pending_count(g.num_node_ids(), 0);
-
- for (int i = 0; i < g.num_node_ids(); ++i) {
- const Node* n = g.FindNodeId(i);
- if (n == nullptr) continue;
- pending_count[i] = n->in_edges().size();
- if (n->IsMerge()) {
- // While-loop cycles are legal cycles so we manually adjust the
- // pending_count to make sure that the loop is visited.
- for (const Edge* e : n->in_edges()) {
- if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
- pending_count[i]--;
- }
- }
- }
- if (pending_count[i] == 0) {
- ready.push_back(n);
- }
- }
-
- int processed = 0;
- while (!ready.empty()) {
- const Node* node = ready.back();
- ready.pop_back();
- ++processed;
-
- for (const Edge* out : node->out_edges()) {
- const int output_id = out->dst()->id();
- pending_count[output_id]--;
- if (pending_count[output_id] == 0) {
- ready.push_back(out->dst());
- }
- }
- }
-
- if (processed < g.num_nodes()) {
- std::vector<string> nodes_in_cycle;
- for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
- ++i) {
- if (pending_count[i] != 0) {
- nodes_in_cycle.push_back(g.FindNodeId(i)->name());
- }
- }
- return errors::InvalidArgument(
- "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
- " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
- }
- return Status::OK();
-}
} // namespace
} // namespace tensorflow
@@ -631,7 +574,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
"Failed to allocate memory to serialize message of type '",
in.GetTypeName(), "' and size ", proto_size);
}
- in.SerializeToArray(buf, proto_size);
+ // SerializeToArray takes size as an int.
+ // This next 'if' is a workaround till we update to depend on a version
+ // of protocol buffers that includes
+ // https://github.com/google/protobuf/pull/4739
+ if (proto_size > std::numeric_limits<int>::max()) {
+ return InvalidArgument("Cannot serialize protocol buffer of type ",
+ in.GetTypeName(), " as the serialized size (",
+ proto_size,
+ "bytes) would be larger than the limit (",
+ std::numeric_limits<int>::max(), " bytes)");
+ }
+ if (!in.SerializeToArray(buf, proto_size)) {
+ return InvalidArgument("Unable to serialize ", in.GetTypeName(),
+ " protocol buffer, perhaps the serialized size (",
+ proto_size, " bytes) is too large?");
+ }
out->data = buf;
out->length = proto_size;
out->data_deallocator = [](void* data, size_t length) {
@@ -731,7 +689,9 @@ bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
- status->status = tensorflow::ValidateNoCycles(session->graph->graph);
+ // TODO(nolivia): check this on a subset of the graph instead of all of
+ // it.
+ status->status = graph::ValidateGraphHasNoCycle(session->graph->graph);
if (!status->status.ok()) {
session->graph->mu.unlock();
return false;
@@ -2108,7 +2068,8 @@ TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults(
TF_Graph* graph, const TF_Buffer* graph_def,
const TF_ImportGraphDefOptions* options, TF_Status* status) {
GraphDef def;
- if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+ if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+ graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return nullptr;
}
@@ -2138,7 +2099,8 @@ void TF_GraphImportGraphDefWithReturnOutputs(
return;
}
GraphDef def;
- if (!def.ParseFromArray(graph_def->data, graph_def->length)) {
+ if (!tensorflow::ParseProtoUnlimited(&def, graph_def->data,
+ graph_def->length)) {
status->status = InvalidArgument("Invalid GraphDef");
return;
}
@@ -2454,7 +2416,18 @@ void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, TF_Output* x, int nx,
for (int i = first_new_node_id; i < g->graph.num_node_ids(); ++i) {
Node* n = g->graph.FindNodeId(i);
if (n == nullptr) continue;
- g->name_map[n->name()] = n;
+ // We have a convoluted scheme here: Using the C++ graph construction API
+ // to add potentially many nodes to the graph without running the checks
+ // (such as uniqueness of the names of nodes) we run with other functions
+ // that add a node to the graph (like TF_FinishOperation).
+ if (!g->name_map.insert(std::make_pair(n->name(), n)).second) {
+ status->status = tensorflow::errors::Internal(
+ "BUG: The API allowed construction of a graph with duplicate node "
+ "names (",
+ n->name(),
+ "). This is a bug. Please file an issue at "
+ "https://github.com/tensorflow/tensorflow/issues.");
+ }
}
}
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index c859434745..1eb75ef11f 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -894,7 +894,8 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_ImportGraphDefOptions* opts);
// Set the prefix to be prepended to the names of nodes in `graph_def` that will
-// be imported into `graph`.
+// be imported into `graph`. `prefix` is copied and has no lifetime
+// requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix(
TF_ImportGraphDefOptions* opts, const char* prefix);
@@ -915,6 +916,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix(
// Set any imported nodes with input `src_name:src_index` to have that input
// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
// `dst` references a node already existing in the graph being imported into.
+// `src_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping(
TF_ImportGraphDefOptions* opts, const char* src_name, int src_index,
TF_Output dst);
@@ -922,7 +924,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddInputMapping(
// Set any imported nodes with control input `src_name` to have that input
// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
// `dst` references an operation already existing in the graph being imported
-// into.
+// into. `src_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsRemapControlDependency(
TF_ImportGraphDefOptions* opts, const char* src_name, TF_Operation* dst);
@@ -934,6 +936,7 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddControlDependency(
// Add an output in `graph_def` to be returned via the `return_outputs` output
// parameter of TF_GraphImportGraphDef(). If the output is remapped via an input
// mapping, the corresponding existing tensor in `graph` will be returned.
+// `oper_name` is copied and has no lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput(
TF_ImportGraphDefOptions* opts, const char* oper_name, int index);
@@ -943,7 +946,8 @@ TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs(
const TF_ImportGraphDefOptions* opts);
// Add an operation in `graph_def` to be returned via the `return_opers` output
-// parameter of TF_GraphImportGraphDef().
+// parameter of TF_GraphImportGraphDef(). `oper_name` is copied and has no
+// lifetime requirements.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation(
TF_ImportGraphDefOptions* opts, const char* oper_name);
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 95b04f9058..170046c802 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -57,6 +57,33 @@ void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
}
}
+TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
+ unsigned char gpu_memory_allow_growth) {
+ tensorflow::ConfigProto config;
+ auto* optimizer_options =
+ config.mutable_graph_options()->mutable_optimizer_options();
+ if (enable_xla_compilation) {
+ optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);
+
+ // These XLA flags are needed to trigger XLA properly from C (more generally
+ // non-Python) clients. If this API is called again with `enable` set to
+ // false, it is safe to keep these flag values as is.
+ tensorflow::legacy_flags::MarkForCompilationPassFlags* flags =
+ tensorflow::legacy_flags::GetMarkForCompilationPassFlags();
+ flags->tf_xla_cpu_global_jit = true;
+ flags->tf_xla_min_cluster_size = 1;
+ } else {
+ optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
+ }
+
+ auto* gpu_options = config.mutable_gpu_options();
+ gpu_options->set_allow_growth(gpu_memory_allow_growth);
+
+ TF_Buffer* ret = TF_NewBuffer();
+ TF_CHECK_OK(MessageToBuffer(config, ret));
+ return ret;
+}
+
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
tensorflow::mutex_lock c(graph->mu);
const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 20bdace40f..2d81c01e0d 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -55,11 +55,21 @@ extern "C" {
// set XLA flag values to prepare for XLA compilation. Otherwise set
// global_jit_level to OFF.
//
-// This API is syntax sugar over TF_SetConfig(), and is used by clients that
-// cannot read/write the tensorflow.ConfigProto proto.
+// This and the next API are syntax sugar over TF_SetConfig(), and is used by
+// clients that cannot read/write the tensorflow.ConfigProto proto.
+// TODO: Migrate to TF_CreateConfig() below.
TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options,
unsigned char enable);
+// Create a serialized tensorflow.ConfigProto proto, where:
+//
+// a) ConfigProto.optimizer_options.global_jit_level is set to to ON_1 if
+// `enable_xla_compilation` is non-zero, and OFF otherwise.
+// b) ConfigProto.gpu_options.allow_growth is set to `gpu_memory_allow_growth`.
+TF_CAPI_EXPORT extern TF_Buffer* TF_CreateConfig(
+ unsigned char enable_xla_compilation,
+ unsigned char gpu_memory_allow_growth);
+
// Returns the graph content in a human-readable format, with length set in
// `len`. The format is subject to change in the future.
// The returned string is heap-allocated, and caller should call free() on it.
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 577f10c5e6..bc04b53fbb 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -1160,7 +1160,7 @@ TEST(CAPI, GetOpDef) {
}
void StringVectorToArrays(const std::vector<string>& v,
- std::unique_ptr<const void* []>* ptrs,
+ std::unique_ptr<const void*[]>* ptrs,
std::unique_ptr<size_t[]>* lens) {
ptrs->reset(new const void*[v.size()]);
lens->reset(new size_t[v.size()]);
@@ -1196,7 +1196,7 @@ class CApiColocationTest : public ::testing::Test {
void SetViaStringList(TF_OperationDescription* desc,
const std::vector<string>& list) {
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
StringVectorToArrays(list, &list_ptrs, &list_lens);
TF_SetAttrStringList(desc, tensorflow::kColocationAttrName, list_ptrs.get(),
@@ -1700,6 +1700,61 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) {
TestGradientsError(false);
}
+void ScalarFloatFromTensor(const TF_Tensor* t, float* f) {
+ ASSERT_TRUE(t != nullptr);
+ ASSERT_EQ(TF_FLOAT, TF_TensorType(t));
+ ASSERT_EQ(0, TF_NumDims(t));
+ ASSERT_EQ(4, TF_TensorByteSize(t));
+ float* p = static_cast<float*>(TF_TensorData(t));
+ *f = *p;
+}
+
+TEST_F(CApiGradientsTest, MultipleCallsToAddGradients) {
+ const float X = 3.0f, Y = 7.0f;
+ TF_Operation* x = Placeholder(graph_, s_, "x", TF_FLOAT);
+ TF_Operation* y = Placeholder(graph_, s_, "y", TF_FLOAT);
+ TF_Operation* xy = Mul(x, y, graph_, s_, "xy");
+ TF_Output dxy_dx, dxy_dy;
+
+ TF_Output outputs[1] = {{xy, 0}};
+ TF_Output inputs[1] = {{x, 0}};
+ TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dx);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ inputs[0] = {y, 0};
+ TF_AddGradients(graph_, outputs, 1, inputs, 1, nullptr, s_, &dxy_dy);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_SessionOptions* opts = TF_NewSessionOptions();
+ TF_Session* sess = TF_NewSession(graph_, opts, s_);
+ TF_DeleteSessionOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Output feeds[] = {{x, 0}, {y, 0}};
+ TF_Tensor* feedValues[] = {FloatTensor(X), FloatTensor(Y)};
+ TF_Output fetches[] = {dxy_dx, dxy_dy};
+ TF_Tensor* fetchValues[] = {nullptr, nullptr};
+
+ TF_SessionRun(sess, nullptr /* run_options */, feeds, feedValues, 2, fetches,
+ fetchValues, 2, nullptr /* target_opers */, 0,
+ nullptr /* run_metadata */, s_);
+ TF_DeleteTensor(feedValues[0]);
+ TF_DeleteTensor(feedValues[1]);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_DeleteSession(sess, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ float dxy_dxValue = 0.0f, dxy_dyValue = 0.0f;
+ ScalarFloatFromTensor(fetchValues[0], &dxy_dxValue);
+ EXPECT_EQ(Y, dxy_dxValue);
+
+ ScalarFloatFromTensor(fetchValues[1], &dxy_dyValue);
+ EXPECT_EQ(X, dxy_dyValue);
+
+ TF_DeleteTensor(fetchValues[0]);
+ TF_DeleteTensor(fetchValues[1]);
+}
+
// REGISTER_OP for CApiAttributesTest test cases.
// Registers two ops, each with a single attribute called 'v'.
// The attribute in one op will have a type 'type', the other
@@ -1784,7 +1839,7 @@ TEST_F(CApiAttributesTest, String) {
TEST_F(CApiAttributesTest, StringList) {
std::vector<string> list = {"bugs", "bunny", "duck"};
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
StringVectorToArrays(list, &list_ptrs, &list_lens);
int list_total_size = 0;
@@ -1800,7 +1855,7 @@ TEST_F(CApiAttributesTest, StringList) {
ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
EXPECT_TF_META("v", list.size(), TF_ATTR_STRING, list_total_size);
- std::unique_ptr<void* []> values(new void*[list.size()]);
+ std::unique_ptr<void*[]> values(new void*[list.size()]);
std::unique_ptr<size_t[]> lens(new size_t[list.size()]);
std::unique_ptr<char[]> storage(new char[list_total_size]);
TF_OperationGetAttrStringList(oper, "v", values.get(), lens.get(),
@@ -2025,7 +2080,7 @@ TEST_F(CApiAttributesTest, TensorShapeProtoList) {
tensorflow::PartialTensorShape(pts2).AsProto(&proto);
proto.SerializeToString(&bytes2);
- std::unique_ptr<const void* []> list_ptrs;
+ std::unique_ptr<const void*[]> list_ptrs;
std::unique_ptr<size_t[]> list_lens;
const std::vector<string> list = {bytes1, bytes2};
StringVectorToArrays(list, &list_ptrs, &list_lens);
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index f3b28c1708..24eb6c069b 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -216,6 +216,13 @@ TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
return MinWithDevice(l, r, graph, /*op_device=*/"", s, name);
}
+TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name) {
+ TF_Operation* op;
+ BinaryOpHelper("Mul", l, r, graph, s, name, &op, "", true);
+ return op;
+}
+
TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
const char* name) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index c16aba666e..38313d647c 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -80,6 +80,9 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "min");
+TF_Operation* Mul(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
+ TF_Status* s, const char* name = "mul");
+
// If `op_device` is non-empty, set the created op on that device.
TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
const string& op_device, TF_Status* s,
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index f265da2c2c..37be52f57d 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -54,7 +54,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
@@ -93,10 +92,10 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
],
)
@@ -122,6 +121,7 @@ tf_cuda_library(
tf_cuda_cc_test(
name = "c_api_test",
+ size = "small",
srcs = [
"c_api_debug_test.cc",
"c_api_test.cc",
@@ -139,7 +139,7 @@ tf_cuda_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 81221c4078..82ca2be2cf 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -36,9 +36,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -46,10 +46,12 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -107,7 +109,8 @@ tensorflow::Status GetAllRemoteDevices(
}
tensorflow::Status CreateRemoteContexts(
- const std::vector<string>& remote_workers,
+ const std::vector<string>& remote_workers, int64 rendezvous_id,
+ const tensorflow::ServerDef& server_def,
tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
for (int i = 0; i < remote_workers.size(); i++) {
@@ -115,12 +118,14 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::eager::CreateContextRequest request;
tensorflow::eager::CreateContextResponse response;
+ request.set_rendezvous_id(rendezvous_id);
tensorflow::DeviceNameUtils::ParsedName parsed_name;
if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
&parsed_name)) {
return tensorflow::errors::InvalidArgument(
"Unable to parse ", remote_worker, " as a device name");
}
+ *request.mutable_server_def() = server_def;
request.mutable_server_def()->set_job_name(parsed_name.job);
request.mutable_server_def()->set_task_index(parsed_name.task);
request.set_async(async);
@@ -147,46 +152,82 @@ tensorflow::Status CreateRemoteContexts(
tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts,
TFE_Context** ctx) {
+ // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
+ // server object (which currently CHECK-fails) and we miss the error, instead,
+ // we log the error, and then return to allow the user to see the error
+ // message.
+#define LOG_AND_RETURN_IF_ERROR(...) \
+ do { \
+ const ::tensorflow::Status _status = (__VA_ARGS__); \
+ if (TF_PREDICT_FALSE(!_status.ok())) { \
+ LOG(ERROR) << _status.error_message(); \
+ return _status; \
+ } \
+ } while (0);
+
string worker_name = tensorflow::strings::StrCat(
"/job:", opts->server_def.job_name(),
"/replica:0/task:", opts->server_def.task_index());
- std::unique_ptr<tensorflow::eager::EagerGrpcServer> server;
- TF_RETURN_IF_ERROR(
- tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server));
- TF_RETURN_IF_ERROR(server->Start());
+ std::unique_ptr<tensorflow::ServerInterface> server;
+ LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(opts->server_def, &server));
+
+ tensorflow::GrpcServer* grpc_server =
+ dynamic_cast<tensorflow::GrpcServer*>(server.get());
+ if (grpc_server == nullptr) {
+ LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
+ "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
+ }
+
+ LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
+
+ int64 rendezvous_id = tensorflow::random::New64();
std::vector<string> remote_workers;
- server->master_env()->worker_cache->ListWorkers(&remote_workers);
+ grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
remote_workers.erase(
std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
remote_workers.end());
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
- TF_RETURN_IF_ERROR(GetAllRemoteDevices(
- remote_workers, server->master_env()->worker_cache, &remote_device_mgr));
+ LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
+ remote_workers, grpc_server->master_env()->worker_cache,
+ &remote_device_mgr));
std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
- server->channel_cache();
+ grpc_server->channel_cache();
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
// Initialize remote eager workers.
tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
- TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers,
- remote_eager_workers.get(),
- opts->async, &remote_contexts));
+ LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
+ remote_workers, rendezvous_id, opts->server_def,
+ remote_eager_workers.get(), opts->async, &remote_contexts));
tensorflow::RemoteRendezvous* r =
- server->worker_env()->rendezvous_mgr->Find(0);
+ grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
+
+ auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
+ TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
+ session_name, opts->server_def, true));
+
+ std::shared_ptr<tensorflow::WorkerSession> worker_session;
+ TF_RETURN_IF_ERROR(
+ grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
+ session_name, &worker_session));
+
+ // Initialize remote tensor communication based on worker session.
+ TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
- auto* device_mgr = server->worker_env()->device_mgr;
+ auto* device_mgr = grpc_server->worker_env()->device_mgr;
*ctx = new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, r, std::move(server),
std::move(remote_eager_workers),
std::move(remote_device_mgr), remote_contexts);
return tensorflow::Status::OK();
+#undef LOG_AND_RETURN_IF_ERROR
}
} // namespace
@@ -307,16 +348,16 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
}
int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dims();
+ int result;
+ status->status = h->handle->NumDims(&result);
+ return result;
}
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
- const tensorflow::Tensor* t = nullptr;
- status->status = h->handle->Tensor(&t);
- return t == nullptr ? 0 : t->dim_size(dim_index);
+ tensorflow::int64 result;
+ status->status = h->handle->Dim(dim_index, &result);
+ return result;
}
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
@@ -421,8 +462,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
return ret;
}
-void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
- op->operation.MutableAttrs()->Set(attr_name, value);
+void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
+ size_t length) {
+ op->operation.MutableAttrs()->Set(
+ attr_name,
+ tensorflow::StringPiece(static_cast<const char*>(value), length));
}
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
@@ -473,16 +517,22 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
-#define TFE_OP_SET_ATTR_LIST(fn, type) \
- void fn(TFE_Op* op, const char* attr_name, const type* values, \
- int num_values) { \
- op->operation.MutableAttrs()->Set( \
- attr_name, \
- tensorflow::gtl::ArraySlice<const type>(values, num_values)); \
+void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
+ const void* const* values, const size_t* lengths,
+ int num_values) {
+ std::vector<tensorflow::StringPiece> v(num_values);
+ for (int i = 0; i < num_values; ++i) {
+ v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
+ lengths[i]);
}
-TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
-TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
-#undef TFE_OP_SET_ATTR_LIST
+ op->operation.MutableAttrs()->Set(attr_name, v);
+}
+
+void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
+ const float* values, int num_values) {
+ op->operation.MutableAttrs()->Set(
+ attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
+}
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
const int64_t* values, int num_values) {
@@ -655,9 +705,11 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
const tensorflow::AttrValue& default_value,
const char* attr_name, TF_Status* status) {
switch (default_value.value_case()) {
- case tensorflow::AttrValue::kS:
- TFE_OpSetAttrString(op, attr_name, default_value.s().data());
+ case tensorflow::AttrValue::kS: {
+ const string& v = default_value.s();
+ TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
break;
+ }
case tensorflow::AttrValue::kI:
TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
break;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 1862af3ce2..fdbd5374b2 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -278,7 +278,8 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType(
TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op,
const char* attr_name,
- const char* value);
+ const void* value,
+ size_t length);
TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name,
int64_t value);
TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name,
@@ -305,7 +306,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
const char* attr_name,
- const char** value,
+ const void* const* values,
+ const size_t* lengths,
int num_values);
TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op,
const char* attr_name,
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 04a6efc47c..4c5077023d 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
@@ -78,7 +78,7 @@ struct TFE_Context {
TFE_ContextDevicePlacementPolicy default_policy, bool async,
tensorflow::DeviceMgr* local_device_mgr,
tensorflow::Rendezvous* rendezvous,
- std::unique_ptr<tensorflow::GrpcServer> server,
+ std::unique_ptr<tensorflow::ServerInterface> server,
std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr,
const tensorflow::gtl::FlatMap<tensorflow::string, tensorflow::uint64>&
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 27ff5f7211..3504a8b5e7 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include "tensorflow/c/eager/c_api_test_util.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -132,18 +132,20 @@ void TestRemoteExecute(bool async) {
server_def.set_task_index(1);
- std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
- ASSERT_TRUE(
- tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
- .ok());
+ std::unique_ptr<tensorflow::GrpcServer> worker_server;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server)
+ .ok());
ASSERT_TRUE(worker_server->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
status);
- TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts,
+ TFE_DEVICE_PLACEMENT_EXPLICIT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
@@ -205,6 +207,95 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
+void TestRemoteExecuteSilentCopies(bool async) {
+ tensorflow::ServerDef server_def = GetServerDef(3);
+
+ // This server def has the task index set to 0.
+ string serialized = server_def.SerializeAsString();
+
+ server_def.set_task_index(1);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server1;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server1)
+ .ok());
+ ASSERT_TRUE(worker_server1->Start().ok());
+
+ server_def.set_task_index(2);
+ std::unique_ptr<tensorflow::GrpcServer> worker_server2;
+ ASSERT_TRUE(tensorflow::GrpcServer::Create(
+ server_def, tensorflow::Env::Default(), &worker_server2)
+ .ok());
+ ASSERT_TRUE(worker_server2->Start().ok());
+
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
+ status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
+ TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
+ TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
+ const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
+ const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
+
+ auto* h1_task2 =
+ TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ // Handles are on task0 (local), and task2, but op is on task1.
+ TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task2);
+ TFE_OpSetDevice(matmul, task1_name, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ auto* retval_task0 = TFE_TensorHandleCopyToDevice(
+ retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteTensorHandle(retval_task0);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteTensorHandle(h0_task0);
+ TFE_DeleteTensorHandle(h1_task0);
+ TFE_DeleteTensorHandle(h1_task2);
+ TFE_DeleteTensorHandle(retvals[0]);
+
+ TFE_DeleteOp(matmul);
+
+ TFE_ContextAsyncWait(ctx, status);
+ TFE_DeleteContext(ctx, status);
+ EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_DeleteStatus(status);
+
+ // TODO(nareshmodi): Figure out how to correctly shut the server down.
+ worker_server1.release();
+ worker_server2.release();
+}
+
+TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
+TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
+ TestRemoteExecuteSilentCopies(true);
+}
+
TEST(CAPI, TensorHandle) {
TFE_TensorHandle* h = TestMatrixTensorHandle();
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
@@ -1083,8 +1174,8 @@ TFE_TensorHandle* CreateVariable(TFE_Context* ctx, float value,
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
TFE_OpSetAttrShape(op, "shape", {}, 0, status);
- TFE_OpSetAttrString(op, "container", "");
- TFE_OpSetAttrString(op, "shared_name", "");
+ TFE_OpSetAttrString(op, "container", "", 0);
+ TFE_OpSetAttrString(op, "shared_name", "", 0);
if (TF_GetCode(status) != TF_OK) return nullptr;
TFE_TensorHandle* var_handle = nullptr;
int num_retvals = 1;
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 734e712daa..1adb0458c3 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -520,7 +520,12 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
}
} else {
any_gradient_nonzero = true;
- auto new_gradients = vspace.AggregateGradients(grad_it->second);
+ Gradient* new_gradients = nullptr;
+ if (grad_it->second.size() == 1) {
+ new_gradients = grad_it->second.at(0);
+ } else {
+ new_gradients = vspace.AggregateGradients(grad_it->second);
+ }
if (sources_set.find(grad_it->first) == sources_set.end()) {
gradients.erase(grad_it);
} else {
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 079e063d3e..a98f0b00b2 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -530,7 +530,7 @@ cc_library_with_android_deps(
"//tensorflow/core/api_def:base_api_def",
],
deps = [
- "//tensorflow/core:framework",
+ "//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:op_gen_lib",
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index d6a4f141b6..dfdef88945 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) {
return "<Unknown AttrValue type>"; // Prevent missing return warning
}
+bool IsEmptyList(const AttrValue::ListValue& list) {
+ return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 &&
+ list.b_size() == 0 && list.type_size() == 0 &&
+ list.shape_size() == 0 && list.tensor_size() == 0;
+}
+
string ToCamelCase(const string& str) {
string result;
const char joiner = '_';
@@ -297,9 +303,9 @@ string ToCamelCase(const string& str) {
// indicate whether to treat the type as const when accepting the C++ type as an
// argument to a function.
std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
- static const std::unordered_map<StringPiece, std::pair<const char*, bool>,
- StringPieceHasher>
- attr_type_map{
+ static const auto* attr_type_map =
+ new std::unordered_map<StringPiece, std::pair<const char*, bool>,
+ StringPieceHasher>{
{"string", {"StringPiece", false}},
{"list(string)", {"gtl::ArraySlice<string>", true}},
{"int", {"int64", false}},
@@ -317,14 +323,34 @@ std::pair<const char*, bool> AttrTypeName(StringPiece attr_type) {
{"func", {"NameAttrList", true}},
};
- auto entry = attr_type_map.find(attr_type);
- if (entry == attr_type_map.end()) {
+ auto entry = attr_type_map->find(attr_type);
+ if (entry == attr_type_map->end()) {
LOG(FATAL) << "Unsupported Attr type: " << attr_type;
return {"", false};
}
return entry->second;
}
+const char* ListElementTypeName(StringPiece attr_type) {
+ static const auto* attr_list_type_map =
+ new std::unordered_map<StringPiece, const char*, StringPieceHasher>{
+ {"list(string)", "string"},
+ {"list(int)", "int"},
+ {"list(float)", "float"},
+ {"list(bool)", "bool"},
+ {"list(type)", "DataType"},
+ {"list(shape)", "PartialTensorShape"},
+ {"list(tensor)", "TensorProto"},
+ };
+
+ auto entry = attr_list_type_map->find(attr_type);
+ if (entry == attr_list_type_map->end()) {
+ LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type;
+ return "";
+ }
+ return entry->second;
+}
+
bool IsCPPKeyword(StringPiece name) {
static const std::unordered_set<StringPiece, StringPieceHasher>
// Keywords obtained from http://en.cppreference.com/w/cpp/keyword
@@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
+ string defaults_static_storage;
for (int i = 0; i < graph_op_def.attr_size(); ++i) {
const auto& attr(graph_op_def.attr(i));
@@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const {
"_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
- strings::StrAppend(
- &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
- "_ = ",
- PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
- ";\n");
+ string field_initiliazer;
+ auto& default_value = api_def_attr.default_value();
+ if (default_value.value_case() == AttrValue::kList &&
+ !IsEmptyList(default_value.list())) {
+ // Non-empty lists need static storage for their defaults. Define a
+ // function with static local variable that stores the array.
+ strings::StrAppend(&defaults_static_storage, " static ",
+ attr_type_name, " Default_", api_def_attr.rename_to(),
+ "() {\n");
+ strings::StrAppend(
+ &defaults_static_storage, " static const ",
+ ListElementTypeName(attr.type()), " kStorage[] = ",
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
+ ";\n");
+ strings::StrAppend(&defaults_static_storage, " return ",
+ attr_type_name, "(kStorage);\n }\n");
+ // Set the field_initializer to call the defined function.
+ strings::StrAppend(&field_initiliazer, "Default_",
+ api_def_attr.rename_to(), "()");
+ } else {
+ field_initiliazer =
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value());
+ }
+ strings::StrAppend(&struct_fields, " ", attr_type_name, " ",
+ api_def_attr.rename_to(), "_ = ", field_initiliazer,
+ ";\n");
}
if (struct_fields.empty()) {
@@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const {
string struct_decl = MakeComment(attrs_comment, " ");
strings::StrAppend(&struct_decl, " struct Attrs {\n");
strings::StrAppend(&struct_decl, setters, struct_fields);
+ if (!defaults_static_storage.empty()) {
+ strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage);
+ }
strings::StrAppend(&struct_decl, " };\n");
return struct_decl;
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 62a889181e..8c886f3171 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -37,6 +37,11 @@ Scope& Scope::operator=(const Scope& other) {
return *this;
}
+namespace {
+const char kScopeSeparator[] = "/";
+const char kSuffixSeparator[] = "_";
+} // namespace
+
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
ShapeRefiner* refiner, bool disable_shape_inference)
: graph_(graph),
@@ -308,19 +313,23 @@ string Scope::Impl::GetUniqueName(const string& prefix,
return prefix;
}
auto entry = name_map_->find(prefix);
- string unique_name = prefix;
if (entry == name_map_->end()) {
name_map_->insert({prefix, 0});
- } else {
- unique_name = strings::StrCat(unique_name, "_", ++entry->second);
+ return prefix;
}
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second);
+ } while (name_map_->find(unique_name) != name_map_->end());
+ name_map_->insert({unique_name, 0});
return unique_name;
}
string Scope::Impl::GetNameForOp(const string& default_name) const {
const string unique_name =
GetUniqueName(default_name, true /* check_single_use */);
- const string sep = name_.empty() || unique_name.empty() ? "" : "/";
+ const string sep =
+ name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
return strings::StrCat(name_, sep, unique_name);
}
@@ -345,7 +354,8 @@ Scope Scope::NewSubScope(const string& child_scope_name) const {
}
const string unique_name =
impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
- const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/";
+ const string sep =
+ impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
return Scope(new Impl(*this, Impl::Tags::ScopeName(),
strings::StrCat(impl()->name_, sep, unique_name),
false /* copy_names */));
@@ -412,7 +422,7 @@ CompositeOpScopes Scope::GetCompositeOpScopes(
if (!impl()->single_use_scope()) {
Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
: impl()->op_name_);
- const string child_op_sep = impl()->name_.empty() ? "" : "_";
+ const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator;
const string child_name =
strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
return {child,
@@ -435,7 +445,13 @@ class InternalScope {
static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
for (const Node* node : graph->nodes()) {
- (*name_map)[node->name()] = 0;
+ const string& name = node->name();
+ (*name_map)[name] = 0;
+ // Add all name prefixes ('/' separated).
+ size_t idx = -1;
+ while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) {
+ (*name_map)[name.substr(0, idx)] = 0;
+ }
}
// We provide null destructors for these shared ptrs (except for name_map)
// since the caller owns them and doesn't want the scope to destroy them.
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 8efcfed20d..58adaef2e9 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -34,8 +34,7 @@ class Scope::Impl {
// name that has not been used so far in a scope will get no suffix. Later
// uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes
// can share the same NameMap. For instance, a new scope created using
- // WithControlDependencies() should would share the same NameMap with the
- // parent.
+ // WithControlDependencies() would share the same NameMap with the parent.
typedef std::unordered_map<string, int> NameMap;
Impl(const std::shared_ptr<Graph>& graph,
diff --git a/tensorflow/cc/framework/scope_test.cc b/tensorflow/cc/framework/scope_test.cc
index 9eca9d3fac..b40b345eb8 100644
--- a/tensorflow/cc/framework/scope_test.cc
+++ b/tensorflow/cc/framework/scope_test.cc
@@ -26,6 +26,16 @@ TEST(ScopeTest, BasicNames) {
EXPECT_EQ(root.GetUniqueNameForOp("mul"), "mul");
}
+TEST(ScopeTest, OpAndScopeNameCollision) {
+ Scope root = Scope::NewRootScope();
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_1");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_1"), "foo_1_1");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo"), "foo_3");
+ EXPECT_EQ(root.GetUniqueNameForOp("foo_2"), "foo_2_1");
+}
+
TEST(ScopeTest, HierarchicalNames) {
Scope root = Scope::NewRootScope();
Scope child = root.NewSubScope("child");
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index ff348fadb2..b353accddc 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -421,6 +421,58 @@ Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);
+Status SliceGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // Propagate the incoming gradient along all the selected values,
+ // and zero everywhere else. Use the Pad operator for this.
+ //
+ // First create an Nx2 padding where N is the number of input
+ // dimensions. The first column is the number of prepended zeros
+ // for each dimension, and the second column is the number of
+ // appended zeros.
+ //
+ // The first column is just the begin vector.
+ // The second column is the shape of the input element-wise
+ // subtracted by begin+size
+
+ // Running example:
+ // input.shape = [3, 5, 3]
+ // begin = [1, 2, 1], size = [1, 3, 2]
+ Input input = op.input(0);
+ Input begin = op.input(1);
+ // input_rank = 3
+ auto input_rank = Rank(scope, input);
+ // slice_size = [1, 3, 2]
+ auto slice_size = Shape(scope, op.output(0));
+ // padding_shape = [3, 1]
+ auto padding_shape = Stack(scope, {input_rank, 1});
+ // before_padding = [[1]
+ // [2]
+ // [1]]
+ Input before_padding = Reshape(scope, begin, padding_shape);
+ // after_padding_sizes = shape(input) - slice_size - begin
+ // = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
+ // = [1, 0, 0]
+ auto after_padding_sizes =
+ Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
+ // after_padding = [[1]
+ // [0]
+ // [0]]
+ Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
+ // paddings = [[1 1]
+ // [2 0]
+ // [1 0]]
+ auto paddings =
+ Concat(scope, {before_padding, after_padding}, Const(scope, 1));
+ grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
+ // Nothing propagated for "begin" and "size" inputs
+ grad_outputs->push_back(NoGradient());
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Slice", SliceGrad);
+
} // anonymous namespace
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc
index de3bd0fc9e..d09275b648 100644
--- a/tensorflow/cc/gradients/array_grad_test.cc
+++ b/tensorflow/cc/gradients/array_grad_test.cc
@@ -378,5 +378,12 @@ TEST_F(ArrayGradTest, StridedSliceGrad) {
RunTest(x, x_shape, y, {1, 2, 2, 2});
}
+TEST_F(ArrayGradTest, SliceGrad) {
+ TensorShape x_shape({3, 5, 3});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = Slice(scope_, x, {1, 2, 1}, {1, 3, 2});
+ RunTest(x, x_shape, y, {1, 3, 2});
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 52c177212a..35a01e0341 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
+REGISTER_NO_GRADIENT_OP("Floor");
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 0025842aea..28070d60db 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -287,7 +287,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
- if (result_index < 0 || result_index > temp_sizes.size()) {
+ if (result_index < 0 || result_index >= temp_sizes.size()) {
return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,",
temp_sizes.size(), ")");
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index fd2cf2b67d..0ecc3feeb6 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -7,6 +7,10 @@ package(
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+# We disable some tfcompile tests in the open source build with the
+# "manual" tag to avoid making our OSS users build LLVM twice
+# (once for host and once for target).
+
test_suite(
name = "all_tests",
tags = ["manual"],
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 6d6c030a26..c2245b8eae 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -175,11 +176,14 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
+ "//tensorflow/core/kernels:shape_ops",
"//tensorflow/core/kernels:variable_ops",
],
)
@@ -312,9 +316,9 @@ cc_library(
":common",
":shape_inference_helpers",
":union_find",
+ ":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
- "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:parallel_check_op",
"//tensorflow/compiler/jit/ops:xla_ops",
@@ -333,6 +337,19 @@ cc_library(
)
cc_library(
+ name = "xla_cluster_util",
+ srcs = ["xla_cluster_util.cc"],
+ hdrs = ["xla_cluster_util.h"],
+ deps = [
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/kernels:bounds_check",
+ ],
+)
+
+cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
@@ -384,6 +401,32 @@ tf_cc_test(
)
tf_cc_test(
+ name = "xla_cluster_util_test",
+ size = "small",
+ srcs = [
+ "xla_cluster_util_test.cc",
+ ],
+ deps = [
+ ":common",
+ ":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cc_test(
name = "xla_launch_util_test",
size = "small",
srcs = ["xla_launch_util_test.cc"],
@@ -408,6 +451,38 @@ tf_cc_test(
],
)
+cc_library(
+ name = "xla_fusion_optimizer",
+ srcs = ["xla_fusion_optimizer.cc"],
+ hdrs = ["xla_fusion_optimizer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":common",
+ ":union_find",
+ ":xla_cluster_util",
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/core:core_cpu_base",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "xla_fusion_optimizer_test",
+ srcs = ["xla_fusion_optimizer_test.cc"],
+ deps = [
+ ":common",
+ ":xla_cluster_util",
+ ":xla_fusion_optimizer",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 731b8ebfdc..a2e6285339 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -66,8 +66,28 @@ class SinglePassSearch {
Status CompilationRequested(const FunctionLibraryRuntime& flr,
const NodeDef& node_def) {
+ const FunctionDef* function_def =
+ flr.GetFunctionLibraryDefinition()->Find(node_def.name());
+ if (function_def == nullptr) {
+ // The node def is not calling a function. Individual ops can be
+ // run directly using on-demand mode, no need to create XlaLaunch
+ // kernel for them.
+ // TODO(b/110359382): Make custom kernel creation return a bool instead of
+ // status.
+ // We don't set error messages here to avoid unnecessary string copy.
+ // Similarly below.
+ return Status(error::INVALID_ARGUMENT, "");
+ }
+
+ // If kXlaCompileAttr is set on the node_def, use its value.
+ const auto& it = node_def.attr().find(kXlaCompileAttr);
+ if (it != node_def.attr().end()) {
+ return it->second.b() ? Status::OK() : Status(error::INVALID_ARGUMENT, "");
+ }
+
+ // kXlaCompileAttr is not set on node_def, check if it is set on
+ // FunctionDef.
bool xla_compile = false;
- // Check if op is marked _XlaCompile=true.
Status status = flr.GetFunctionLibraryDefinition()->GetAttr(
node_def, kXlaCompileAttr, &xla_compile);
if (!status.ok() || !xla_compile) {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 6d1e3325eb..9c424b201e 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
-#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
@@ -61,9 +60,9 @@ const char* const kXlaHostTransferSequencerAttr =
namespace {
-bool AreAllParentsConst(const Node& n,
- const gtl::FlatSet<const Node*>& runtime_const_nodes) {
- if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
+bool AreAllParentsGuaranteedConst(
+ const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
return true;
@@ -94,7 +93,8 @@ void MarkGuaranteedConstants(
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](const Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
- if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
+ if (AreAllParentsGuaranteedConst(*n,
+ guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});
@@ -107,41 +107,11 @@ void MarkGuaranteedConstants(
}
}
-// A node/slot pair.
-// TODO(phawkins): is there a common definition of this?
-struct NodeSlot {
- NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot)
- : node(node), slot(slot), dtype(DT_INVALID) {}
- NodeSlot(const Node* node, int slot, DataType dtype)
- : node(node), slot(slot), dtype(dtype) {}
-
- const Node* node;
- int slot;
-
- // Optional: used to record the destination type of a source NodeSlot in case
- // the source output is a Ref type that is cast to a Tensor at the
- // destination.
- DataType dtype;
-
- bool operator==(const NodeSlot& other) const {
- return node == other.node && slot == other.slot && dtype == other.dtype;
- }
-
- // Leave dtype out of the hash since there are never two NodeSlots with the
- // same node and slot and different dtypes.
- struct Hasher {
- uint64 operator()(NodeSlot const& s) const {
- return Hash64Combine(std::hash<const Node*>()(s.node),
- std::hash<int>()(s.slot));
- }
- };
-
- struct PairHasher {
- uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
- return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
- }
- };
+struct OutputInputTensorPairHasher {
+ uint64 operator()(std::pair<OutputTensor, InputTensor> const& s) const {
+ return Hash64Combine(OutputTensor::Hash()(s.first),
+ InputTensor::Hash()(s.second));
+ }
};
// TODO(phawkins) add a canonical copy of these operator names and refactor
@@ -182,8 +152,7 @@ class Encapsulator {
// Write a copy of the input graph to 'graph_out', where the subgraphs are
// replaced with calls to the new functions.
- Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
- FunctionLibraryDefinition* library);
+ Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library);
private:
// A subgraph of the input, all marked with a common 'group_attribute'
@@ -271,7 +240,7 @@ class Encapsulator {
// Adds the function call node to graph_out.
Status AddFunctionCallNode(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out);
+ Graph* graph_out);
// Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out.
Status AddOutsideCompilationHostIONodes(
@@ -284,11 +253,9 @@ class Encapsulator {
// Subgraph.
void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
- // Returns the Node that inputs to the function should be wired up to.
- Node* GetCallNodeForInputs() const;
-
- // Returns the Node that outputs to the function should be wired up to.
- Node* GetCallNodeForOutputs() const;
+ // Returns the Node that the inputs and outputs of the function should be
+ // wired up to.
+ Node* GetCallNode() const;
// Returns the index of the arg that the dst of edge should connect to.
int GetArgIndexForEdge(const Edge* edge) const;
@@ -380,7 +347,7 @@ class Encapsulator {
// Map from source (producer node/slot) tensors in the original graph to
// input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> inputs;
// Set of nodes in the original graph that are the source of control edges
// that cross from the containing compiled subgraph into the
@@ -396,8 +363,15 @@ class Encapsulator {
// node/slot) tensors in the original graph to output index (slot number
// in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
+ struct ArgNumAndType {
+ int index;
+ DataType dtype;
+
+ ArgNumAndType(int i, DataType t) : index(i), dtype(t) {}
+ };
+ std::unordered_map<OutputTensor, ArgNumAndType, OutputTensor::Hash>
+ outputs_by_src;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> outputs_by_dst;
// Set of nodes in the original graph that are the destination of control
// edges that cross from the outside_compilation subgraph into the
@@ -425,12 +399,6 @@ class Encapsulator {
OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph(
const string& outside_compilation_id);
- // Builds a ParallelCheck op that compares the output of the original
- // subgraph with the encapsulated subgraph.
- Status BuildParallelCheckOp(
- const std::unordered_map<const Node*, Node*>& node_images,
- Graph* graph_out);
-
// Builds a placeholder node used to provide the key input to a RecvAtHost
// or SendFromHost node. This placeholder node will be removed by a later
// pass.
@@ -482,26 +450,21 @@ class Encapsulator {
// Not owned.
Node* host_compute_key_placeholder_ = nullptr;
- // Function call node(s) in the output graph. Not owned.
- // If parallel_checking is enabled, 'call_node_inputs' is the function call
- // node to which inputs should be fed, and 'call_node_outputs' is the
- // parallel check op from which outputs should be read. If parallel checking
- // is disabled, both point to the function call node.
- Node* call_node_inputs_;
- Node* call_node_outputs_;
+ // Function call node in the output graph. Not owned.
+ Node* call_node_;
// Maps from source (producer node/slot) and destination
// (consumer node/slot) tensors in the input graph to _Arg numbers in
// the subgraph. The source map is one-to-one, whereas the dest map may be
// many-to-one.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src_;
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> args_by_src_;
+ std::unordered_map<InputTensor, int, InputTensor::Hash> args_by_dst_;
- // The _Arg nodes in the subgraph, in order by argument number.
+ // The arguments to the subgraph, in order.
std::vector<Node*> args_;
// Map from source tensor in the input graph to result #.
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results_;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash> results_;
// The outside_compilation clusters in this subgraph.
std::unordered_map<string, OutsideCompilationSubgraph>
@@ -541,13 +504,12 @@ class Encapsulator {
// Copies all nodes that aren't in a compiled subgraph to the output graph.
Status CopyNodesToOutputGraph(
- bool parallel_checking, Graph* graph_out,
- std::unordered_map<const Node*, Node*>* node_images);
+ Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images);
// Adds function call nodes for each compiled subgraph.
Status AddFunctionCallNodes(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out);
+ Graph* graph_out);
// Adds _RecvAtHost and _SendFromHost nodes, where needed, for all
// outside_compilation subgraphs.
@@ -598,9 +560,9 @@ class Encapsulator {
const string& src_outside_compilation_id, const string& dst_func_id,
const string& dst_outside_compilation_id,
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added);
+ Graph* graph_out,
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added);
// Adds control dependencies between subgraph call nodes that have
// dependencies via outside_compilation edges.
@@ -609,7 +571,7 @@ class Encapsulator {
// Adds all edges to the output graph.
Status AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out);
+ Graph* graph_out);
// Constructs a minimal shape inference graph that can be used to determine
// the shape of send_node at the time that the subgraph is compiled.
@@ -729,20 +691,14 @@ void TopologicalClusterSort(
} // namespace
-Node* Encapsulator::Subgraph::GetCallNodeForInputs() const {
- return call_node_inputs_;
-}
-
-Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const {
- return call_node_outputs_;
-}
+Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; }
int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const {
- return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input()));
+ return args_by_dst_.at(InputTensor(edge->dst(), edge->dst_input()));
}
int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const {
- return results_.at(NodeSlot(edge->src(), edge->src_output()));
+ return results_.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetRecvAtHostNode(
@@ -754,7 +710,7 @@ Node* Encapsulator::Subgraph::GetRecvAtHostNode(
int Encapsulator::Subgraph::GetRecvAtHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .inputs.at(NodeSlot(edge->src(), edge->src_output()));
+ .inputs.at(OutputTensor(edge->src(), edge->src_output()));
}
Node* Encapsulator::Subgraph::GetSendFromHostNode(
@@ -766,7 +722,7 @@ Node* Encapsulator::Subgraph::GetSendFromHostNode(
int Encapsulator::Subgraph::GetSendFromHostSlot(
const string& outside_compilation_subgraph_name, const Edge* edge) const {
return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name)
- .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
+ .outputs_by_dst.at(InputTensor(edge->dst(), edge->dst_input()));
}
Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
@@ -791,10 +747,10 @@ Status Encapsulator::Subgraph::RecordArg(
std::vector<std::pair<const Node*, Node*>>* src_arg_pairs) {
Node* src_node = edge->src();
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
- std::tie(iter, inserted) =
- args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size());
+ std::tie(iter, inserted) = args_by_src_.emplace(
+ OutputTensor(src_node, src_slot), args_by_src_.size());
int arg_index = iter->second;
if (inserted) {
NodeDef arg_def;
@@ -815,7 +771,7 @@ Status Encapsulator::Subgraph::RecordArg(
Node* dst_node = edge->dst();
Node* dst_image = node_images.at(dst_node);
int dst_slot = edge->dst_input();
- args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index;
+ args_by_dst_[InputTensor(dst_node, dst_slot)] = arg_index;
graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot);
return Status::OK();
}
@@ -826,10 +782,10 @@ Status Encapsulator::Subgraph::RecordResult(
Node* src_node = edge->src();
Node* src_image = node_images.at(src_node);
int src_slot = edge->src_output();
- std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ std::unordered_map<OutputTensor, int, OutputTensor::Hash>::iterator iter;
bool inserted;
std::tie(iter, inserted) =
- results_.emplace(NodeSlot(src_node, src_slot), results_.size());
+ results_.emplace(OutputTensor(src_node, src_slot), results_.size());
int ret_index = iter->second;
if (inserted) {
NodeDef ret_def;
@@ -867,8 +823,8 @@ void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl(
outside_subgraph->control_inputs.insert(edge->src());
} else {
int input_index = outside_subgraph->inputs.size();
- outside_subgraph->inputs.emplace(NodeSlot(edge->src(), edge->src_output()),
- input_index);
+ outside_subgraph->inputs.emplace(
+ OutputTensor(edge->src(), edge->src_output()), input_index);
}
}
@@ -882,11 +838,13 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
DataType dtype = edge->dst()->input_type(edge->dst_input());
auto output_iter =
outside_subgraph->outputs_by_src
- .emplace(NodeSlot(edge->src(), edge->src_output(), dtype),
- outside_subgraph->outputs_by_src.size())
+ .emplace(OutputTensor(edge->src(), edge->src_output()),
+ OutsideCompilationSubgraph::ArgNumAndType(
+ outside_subgraph->outputs_by_src.size(), dtype))
.first;
- int output_index = output_iter->second;
- outside_subgraph->outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
+ const int output_index = output_iter->second.index;
+ outside_subgraph
+ ->outputs_by_dst[InputTensor(edge->dst(), edge->dst_input())] =
output_index;
}
}
@@ -968,7 +926,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
DataType dtype = src_node->output_type(src_slot);
@@ -976,8 +934,8 @@ Status Encapsulator::Subgraph::AddHostComputes(
input_dtypes[input_index] = dtype;
}
for (const auto& output : oc_subgraph.outputs_by_src) {
- DataType dtype = output.first.dtype;
- int output_index = output.second;
+ DataType dtype = output.second.dtype;
+ int output_index = output.second.index;
output_dtypes[output_index] = dtype;
}
@@ -1015,7 +973,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = input_src.first.slot;
+ int src_slot = input_src.first.index;
int input_index = input_src.second;
graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}
@@ -1037,7 +995,7 @@ Status Encapsulator::Subgraph::AddHostComputes(
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
- int dst_slot = output.first.slot;
+ int dst_slot = output.first.index;
int output_index = output.second;
graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
@@ -1075,7 +1033,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name,
void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) {
if (sequencer_ != nullptr) {
VLOG(2) << "ConnectSequencerToCallNode";
- graph_out->AddControlEdge(sequencer_, call_node_inputs_);
+ graph_out->AddControlEdge(sequencer_, call_node_);
}
}
@@ -1090,14 +1048,19 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
call_node_def_.set_device(device_);
if (rewrite_subgraph_fn) {
+ std::vector<OutputTensor> arg_source_tensors(args_by_src_.size());
+ for (const auto& arg : args_by_src_) {
+ arg_source_tensors.at(arg.second) = arg.first;
+ }
// Initialize the input and output permutations to the identity.
std::vector<int> input_permutation(args_by_src_.size());
std::iota(input_permutation.begin(), input_permutation.end(), 0);
std::vector<int> output_permutation(results_.size());
std::iota(output_permutation.begin(), output_permutation.end(), 0);
- TF_RETURN_IF_ERROR(rewrite_subgraph_fn(
- &graph_, &input_permutation, &output_permutation, &call_node_def_));
+ TF_RETURN_IF_ERROR(
+ rewrite_subgraph_fn(arg_source_tensors, &graph_, &input_permutation,
+ &output_permutation, &call_node_def_));
// Apply the input/output permutations to the 'args_by_...' and 'results_'
// mappings, so when we build edges in BuildOutputGraph() we
@@ -1174,7 +1137,10 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo(
GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef));
host_compute->AddAttr("shape_inference_graph", inference_graph_name);
host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
- TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ // TODO(sibyl-Aix6ihai): Understand why there are multiple calls to Encapsulator.
+ if (library->Find(inference_graph_name) == nullptr) {
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ }
}
return Status::OK();
}
@@ -1200,83 +1166,16 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
return Status::OK();
}
-Status Encapsulator::Subgraph::BuildParallelCheckOp(
- const std::unordered_map<const Node*, Node*>& node_images,
- Graph* graph_out) {
- // Build an index mapping output positions to node/slot pairs in the
- // original graph.
- std::vector<NodeSlot> results_by_num(results_.size());
- for (const auto& entry : results_) {
- results_by_num[entry.second] = entry.first;
- }
-
- // Build a parallel check NodeDef.
- int num_results = results_by_num.size();
- std::vector<DataType> result_dtypes(num_results);
- std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
- std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
- for (int i = 0; i < num_results; ++i) {
- const NodeSlot& node_slot = results_by_num[i];
- result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
- expected_outputs[i] =
- NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
- node_slot.slot, result_dtypes[i]);
- actual_outputs[i] =
- NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]);
- }
- // Assign the parallel check op to a CPU on the same task as the cluster it is
- // checking.
- string device, dummy;
- if (!DeviceNameUtils::SplitDeviceName(
- call_node_inputs_->assigned_device_name(), &device, &dummy)) {
- return errors::InvalidArgument("Could not parse device name");
- }
- strings::StrAppend(&device, "/cpu:0");
-
- NodeDef check_def;
- TF_RETURN_IF_ERROR(
- NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(),
- "_parallel_check")),
- "ParallelCheck")
- .Device(device)
- .Attr("T", result_dtypes)
- .Input(expected_outputs)
- .Input(actual_outputs)
- .Finalize(&check_def));
-
- Status s;
- Node* check_op = graph_out->AddNode(check_def, &s);
- if (!s.ok()) return s;
- check_op->set_assigned_device_name(device);
-
- // TODO(phawkins): it seems redundant to call AddEdge as well as
- // pass Inputs to the NodeDefBuilder, but I have been unable to find a
- // way to avoid it.
- for (int i = 0; i < num_results; ++i) {
- const NodeSlot& node_slot = results_by_num[i];
- graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
- i);
- graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i);
- }
-
- call_node_outputs_ = check_op;
- return Status::OK();
-}
-
Status Encapsulator::Subgraph::AddFunctionCallNode(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out) {
+ Graph* graph_out) {
Status s;
- call_node_inputs_ = graph_out->AddNode(call_node_def_, &s);
+ call_node_ = graph_out->AddNode(call_node_def_, &s);
if (!s.ok()) return s;
// Copy the assigned device and the key_annotation over.
- call_node_inputs_->set_assigned_device_name(device_);
- call_node_outputs_ = call_node_inputs_;
+ call_node_->set_assigned_device_name(device_);
- if (parallel_checking) {
- TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out));
- }
return Status::OK();
}
@@ -1315,7 +1214,7 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
for (const auto& input : oc_subgraph->inputs) {
const Node* src_node = input.first.node;
- int src_slot = input.first.slot;
+ int src_slot = input.first.index;
int input_index = input.second;
DataType dtype = src_node->output_type(src_slot);
@@ -1369,8 +1268,8 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
for (const auto& output : oc_subgraph->outputs_by_src) {
const Node* src_node = output.first.node;
Node* src_image = node_images.at(src_node);
- int src_slot = output.first.slot;
- int output_index = output.second;
+ int src_slot = output.first.index;
+ int output_index = output.second.index;
DataType dtype = src_node->output_type(src_slot);
dtypes[output_index] = dtype;
@@ -1609,6 +1508,9 @@ Status Encapsulator::SplitIntoSubgraphs() {
for (auto& entry : subgraphs_) {
Subgraph& subgraph = entry.second;
FixupSourceAndSinkEdges(subgraph.GetGraph());
+ // Verify that the graph has well-formed control flow structure.
+ std::vector<ControlFlowInfo> dummy;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(subgraph.GetGraph(), &dummy));
}
return s;
@@ -1627,27 +1529,17 @@ Status Encapsulator::BuildFunctionDefs(
}
Status Encapsulator::CopyNodesToOutputGraph(
- bool parallel_checking, Graph* graph_out,
- std::unordered_map<const Node*, Node*>* node_images) {
+ Graph* graph_out, std::unordered_map<const Node*, Node*>* node_images) {
for (Node* node : graph_in_->op_nodes()) {
string func_id;
string outside_compilation_id;
TF_RETURN_IF_ERROR(
GetFunctionNameAttr(node, &func_id, &outside_compilation_id));
- // Don't copy nodes that going to be encapsulated, unless parallel checking
- // is enabled.
- if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking)
- continue;
+ // Don't copy nodes that are going to be encapsulated.
+ if (IsInSubgraph(func_id, outside_compilation_id)) continue;
Node* image = graph_out->CopyNode(node);
- if (!outside_compilation_id.empty()) {
- if (parallel_checking) {
- return errors::InvalidArgument(
- "Parallel checking is not supported when outside_compilation "
- "clusters are present.");
- }
- }
(*node_images)[node] = image;
}
(*node_images)[graph_in_->source_node()] = graph_out->source_node();
@@ -1657,10 +1549,10 @@ Status Encapsulator::CopyNodesToOutputGraph(
Status Encapsulator::AddFunctionCallNodes(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out) {
+ Graph* graph_out) {
for (auto& subgraph_entry : subgraphs_) {
- TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode(
- node_images, parallel_checking, graph_out));
+ TF_RETURN_IF_ERROR(
+ subgraph_entry.second.AddFunctionCallNode(node_images, graph_out));
}
return Status::OK();
}
@@ -1694,7 +1586,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc(
} else {
// The edge is from a subgraph to a regular node in the output graph so
// use the subgraph's call node output.
- *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs();
+ *src_image = subgraphs_.at(src_func_id).GetCallNode();
}
} else {
// The source of the edge is in the output graph so use the node image in
@@ -1742,7 +1634,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst(
} else {
// The edge is to a subgraph from a regular node in the output graph so
// use the subgraph's call node input.
- *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs();
+ *dst_image = subgraphs_.at(dst_func_id).GetCallNode();
}
} else {
// The destination of the edge is in the output graph so use the node image
@@ -1778,10 +1670,9 @@ Status Encapsulator::CopyEdgeToOutputGraph(
const Edge* edge, const string& src_func_id,
const string& src_outside_compilation_id, const string& dst_func_id,
const string& dst_outside_compilation_id,
- const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out,
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>*
- edges_added) {
+ const std::unordered_map<const Node*, Node*>& node_images, Graph* graph_out,
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>* edges_added) {
Node* src_image;
TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc(
src_func_id, src_outside_compilation_id, dst_func_id,
@@ -1796,16 +1687,12 @@ Status Encapsulator::CopyEdgeToOutputGraph(
if (edge->IsControlEdge()) {
// Add the control edge, if we have not already added it, using the images
// determined above (potentially call operators or RecvAtHost/SendFromHost).
- if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
+ if (edges_added
+ ->emplace(OutputTensor(src_image, -1), InputTensor(dst_image, -1))
.second) {
graph_out->AddControlEdge(src_image, dst_image);
}
- // If parallel checking is enabled, also add a control edge to the
- // corresponding parallel check op.
- if (parallel_checking) {
- graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
- }
return Status::OK();
}
@@ -1817,18 +1704,10 @@ Status Encapsulator::CopyEdgeToOutputGraph(
FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id,
dst_func_id, dst_outside_compilation_id, edge);
- if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
- parallel_checking) {
- // If we are parallel checking, also feed the tensor as an input to the
- // corresponding parallel check subgraph.
- graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
- edge->dst_input());
- }
-
// Add the edge, if we have not already added it.
if (edges_added
- ->emplace(NodeSlot(src_image, src_output),
- NodeSlot(dst_image, dst_input))
+ ->emplace(OutputTensor(src_image, src_output),
+ InputTensor(dst_image, dst_input))
.second) {
graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
}
@@ -1839,8 +1718,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) {
for (const auto& ancestors : subgraph_ancestors_) {
const string& subgraph = ancestors.first;
for (const string& ancestor : ancestors.second) {
- graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(),
- subgraphs_[subgraph].GetCallNodeForInputs());
+ graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(),
+ subgraphs_[subgraph].GetCallNode());
}
}
return Status::OK();
@@ -1848,11 +1727,12 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) {
Status Encapsulator::AddEdgesToOutputGraph(
const std::unordered_map<const Node*, Node*>& node_images,
- bool parallel_checking, Graph* graph_out) {
+ Graph* graph_out) {
// Set of edges already added to the output graph, represented as (src, dst)
// pairs. We use the set to deduplicate edges; multiple edges in the input
// graph may map to one edge in the output graph.
- std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
+ std::unordered_set<std::pair<OutputTensor, InputTensor>,
+ OutputInputTensorPairHasher>
edges_added;
for (const Edge* edge : graph_in_->edges()) {
@@ -1870,16 +1750,6 @@ Status Encapsulator::AddEdgesToOutputGraph(
if (IsInSubgraph(src_func_id, src_outside_compilation_id) &&
IsInSubgraph(dst_func_id, dst_outside_compilation_id) &&
src_func_id == dst_func_id) {
- if (parallel_checking) {
- Node* src_image = node_images.at(edge->src());
- Node* dst_image = node_images.at(edge->dst());
- if (edge->IsControlEdge()) {
- graph_out->AddControlEdge(src_image, dst_image);
- } else {
- graph_out->AddEdge(src_image, edge->src_output(), dst_image,
- edge->dst_input());
- }
- }
continue;
}
@@ -1887,8 +1757,7 @@ Status Encapsulator::AddEdgesToOutputGraph(
// unclustered graph.
TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph(
edge, src_func_id, src_outside_compilation_id, dst_func_id,
- dst_outside_compilation_id, node_images, parallel_checking, graph_out,
- &edges_added));
+ dst_outside_compilation_id, node_images, graph_out, &edges_added));
}
for (auto& subgraph_entry : subgraphs_) {
@@ -2504,18 +2373,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
return Status::OK();
}
-Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
+Status Encapsulator::BuildOutputGraph(Graph* graph_out,
FunctionLibraryDefinition* library) {
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;
- TF_RETURN_IF_ERROR(
- CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images));
- TF_RETURN_IF_ERROR(
- AddFunctionCallNodes(node_images, parallel_checking, graph_out));
+ TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images));
+ TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out));
TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out));
- TF_RETURN_IF_ERROR(
- AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
+ TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out));
TF_RETURN_IF_ERROR(
GetShapeInfoForOutsideCompilationSends(graph_out, library));
@@ -2528,8 +2394,8 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
Status EncapsulateSubgraphsInFunctions(
string group_attribute, string outside_compilation_attribute,
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
- bool parallel_checking, bool reuse_existing_functions,
- std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
+ bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
+ FunctionLibraryDefinition* library) {
Status s;
Encapsulator encapsulator(std::move(group_attribute),
@@ -2543,8 +2409,7 @@ Status EncapsulateSubgraphsInFunctions(
std::unique_ptr<Graph> out(new Graph(library));
out->set_versions(graph_in.versions());
- TF_RETURN_IF_ERROR(
- encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
+ TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library));
*graph_out = std::move(out);
return Status::OK();
@@ -2585,8 +2450,6 @@ static Status RenumberArguments(Graph* graph,
Status EncapsulateSubgraphsPass::Run(
const GraphOptimizationPassOptions& options) {
VLOG(1) << "EncapsulateSubgraphsPass::Run";
- legacy_flags::EncapsulateSubgraphsPassFlags* flags =
- legacy_flags::GetEncapsulateSubgraphsPassFlags();
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
options.flib_def);
@@ -2602,69 +2465,73 @@ Status EncapsulateSubgraphsPass::Run(
FunctionLibraryRuntime* flr =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
- auto rewrite_subgraph = [flr](std::unique_ptr<Graph>* subgraph,
- std::vector<int>* input_permutation,
- std::vector<int>* output_permutation,
- NodeDef* node) {
- // Optimize the subgraph.
- OptimizeGraph(flr, subgraph);
-
- const int num_args = input_permutation->size();
- std::vector<bool> const_args(num_args);
- TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
-
- DataTypeVector arg_types(num_args);
- TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
-
- // Compute a permutation of the arguments such that the constant arguments
- // are first.
- const int num_consts =
- std::count(const_args.begin(), const_args.end(), true);
-
- const int num_resources =
- std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
- const int num_nonconsts = num_args - num_resources - num_consts;
- if (num_nonconsts < 0) {
- return errors::Internal("num_nonconsts should be >= 0, was ",
- num_nonconsts);
- }
-
- int const_pos = 0;
- int arg_pos = num_consts;
- int resource_pos = num_consts + num_nonconsts;
- for (int i = 0; i < num_args; ++i) {
- if (const_args[i]) {
- if (arg_types[i] == DT_RESOURCE) {
- return errors::Internal(
- "Resource arguments cannot be constant (argument ", i, ")");
+ auto rewrite_subgraph =
+ [flr](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* subgraph,
+ std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation, NodeDef* node) {
+ // Optimize the subgraph.
+ OptimizeGraph(flr, subgraph);
+
+ const int num_args = input_permutation->size();
+ std::vector<bool> const_args(num_args);
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(**subgraph, &const_args));
+
+ DataTypeVector arg_types(num_args);
+ TF_RETURN_IF_ERROR(GetArgTypes(**subgraph, &arg_types));
+
+ // Compute a permutation of the arguments such that the constant
+ // arguments are first.
+ const int num_consts =
+ std::count(const_args.begin(), const_args.end(), true);
+
+ const int num_resources =
+ std::count(arg_types.begin(), arg_types.end(), DT_RESOURCE);
+ const int num_nonconsts = num_args - num_resources - num_consts;
+ if (num_nonconsts < 0) {
+ return errors::Internal("num_nonconsts should be >= 0, was ",
+ num_nonconsts);
}
- (*input_permutation)[i] = const_pos;
- ++const_pos;
- } else if (arg_types[i] == DT_RESOURCE) {
- (*input_permutation)[i] = resource_pos;
- ++resource_pos;
- } else {
- (*input_permutation)[i] = arg_pos;
- ++arg_pos;
- }
- }
- // Renumber argument nodes in the graph.
- TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
-
- // TODO(phawkins): add a forward is-constant analysis, similarly split
- // outputs into host-memory constants and device-memory non-constants.
-
- AddNodeAttr(kXlaCompiledKernelAttr, true, node);
- AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
- AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
- return Status::OK();
- };
+ int const_pos = 0;
+ int arg_pos = num_consts;
+ int resource_pos = num_consts + num_nonconsts;
+ for (int i = 0; i < num_args; ++i) {
+ if (const_args[i]) {
+ if (arg_types[i] == DT_RESOURCE) {
+ return errors::Internal(
+ "Resource arguments cannot be constant (argument ", i, ")");
+ }
+ (*input_permutation)[i] = const_pos;
+ ++const_pos;
+ } else if (arg_types[i] == DT_RESOURCE) {
+ (*input_permutation)[i] = resource_pos;
+ ++resource_pos;
+ } else {
+ (*input_permutation)[i] = arg_pos;
+ ++arg_pos;
+ }
+ }
- TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
- kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
- rewrite_subgraph, flags->tf_xla_parallel_checking,
- /*reuse_existing_functions=*/false, &graph_out, library));
+ // Renumber argument nodes in the graph.
+ TF_RETURN_IF_ERROR(
+ RenumberArguments(subgraph->get(), *input_permutation));
+
+ // TODO(phawkins): add a forward is-constant analysis, similarly split
+ // outputs into host-memory constants and device-memory non-constants.
+
+ AddNodeAttr(kXlaCompiledKernelAttr, true, node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resources, node);
+ return Status::OK();
+ };
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph,
+ rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out,
+ library),
+ "EncapsulateSubgraphsPass failed");
if (VLOG_IS_ON(1)) {
dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
index 5fee36f022..926589546f 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -28,6 +28,9 @@ limitations under the License.
namespace tensorflow {
// A rewriting function to apply to each subgraph during encapsulation.
+// 'arg_source_tensors' are the tensors corresponding to the arguments in the
+// original source graph (*not* 'graph').
+//
// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
// 'input_permutation' is a mapping from old argument numbers to new argument
// numbers, whereas 'output_permutation' is the same for outputs. Both
@@ -37,6 +40,7 @@ namespace tensorflow {
// The rewrite may also change the NodeDef's operator name, and that
// name will be used as the name of the generated function.
typedef std::function<Status(
+ const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node_def)>
RewriteSubgraphFn;
@@ -61,10 +65,6 @@ typedef std::function<Status(
// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
// function conversion.
//
-// If 'parallel_checking' is true, the unencapsulated operators are added to the
-// output graph, together with a "ParallelCheck" operator, that verifies that
-// the original and encapsulated subgraphs produce similar results.
-//
// If 'reuse_existing_functions' is set, use an existing function with the
// same name, if any.
//
@@ -76,8 +76,8 @@ typedef std::function<Status(
Status EncapsulateSubgraphsInFunctions(
string group_attribute, string outside_compilation_attribute,
const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn,
- bool parallel_checking, bool reuse_existing_functions,
- std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
+ bool reuse_existing_functions, std::unique_ptr<Graph>* graph_out,
+ FunctionLibraryDefinition* library);
// The attribute that marks function calls produced by the encapsulate
// subgraphs pass and that should in turn be compiled via XlaLaunch operators.
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index eef113a354..c0543a0079 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -511,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
std::unique_ptr<Graph> graph_out;
s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph,
/*rewrite_subgraph_fn=*/{},
- /*parallel_checking=*/false,
/*reuse_existing_functions=*/false,
&graph_out, lib_def.get());
if (!s.ok()) return s;
@@ -560,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) {
Node* b = Input(b1.opts().WithName("B"));
// Give nodes 'c' and 'd' names that collide after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
- Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
- "_encapsulate", "F1"));
+ Node* d = Binary(b, c,
+ b1.opts().WithName("c").WithControlInput(c).WithAttr(
+ "_encapsulate", "F1"));
Binary(a, d, b1.opts().WithName("E"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
@@ -614,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) {
Node* c =
Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
"_encapsulate", "F1"));
- Node* d =
- Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
+ Node* d = Binary(b, c,
+ b1.opts().WithName("D").WithControlInput(control).WithAttr(
"_encapsulate", "F2"));
Binary(a, d, b1.opts().WithName("E"));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
@@ -707,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
std::unique_ptr<Graph> graph;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_cluster", "_outside", graph_before_encapsulation,
- /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false,
+ /*rewrite_subgraph_fn=*/{},
/*reuse_existing_functions=*/false, &graph, &library));
std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
@@ -721,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) {
EXPECT_EQ(expected_edges, GraphEdges(*graph));
}
-TEST(EncapsulateSubgraphsTest, ParallelChecking) {
- Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
- "/job:localhost/replica:0/task:0/cpu:0");
- auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
- auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
- auto add1 = ops::Add(root.WithOpName("add1"), x1, x2);
- add1.node()->AddAttr("_cluster", "cluster1");
- auto add2 = ops::Add(root.WithOpName("add2"), add1, x2);
- add2.node()->AddAttr("_cluster", "cluster1");
- auto out = ops::Mul(root.WithOpName("mul"), x1, add2);
-
- Graph graph_before_encapsulation(OpRegistry::Global());
- TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
-
- FunctionLibraryDefinition library(OpRegistry::Global(), {});
- std::unique_ptr<Graph> graph;
- TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
- "_cluster", "_outside", graph_before_encapsulation,
- /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true,
- /*reuse_existing_functions=*/false, &graph, &library));
-
- std::vector<string> expected_nodes = {
- "add1", "add2", "cluster1", "cluster1_parallel_check/_0",
- "mul", "x1", "x2"};
- EXPECT_EQ(expected_nodes, GraphNodes(*graph));
-
- std::vector<std::pair<string, string>> expected_edges = {
- {"add1:0", "add2:0"},
- {"add2:0", "cluster1_parallel_check/_0:0"},
- {"cluster1:0", "cluster1_parallel_check/_0:1"},
- {"cluster1_parallel_check/_0:0", "mul:1"},
- {"x1:0", "add1:0"},
- {"x1:0", "cluster1:0"},
- {"x1:0", "mul:0"},
- {"x2:0", "add1:1"},
- {"x2:0", "add2:1"},
- {"x2:0", "cluster1:1"},
- };
- EXPECT_EQ(expected_edges, GraphEdges(*graph));
-}
-
const Node* FindNodeByName(const Graph& graph, const string& name) {
for (const Node* node : graph.nodes()) {
if (node->name() == name) return node;
@@ -783,10 +742,13 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
- auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f);
+ auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
+ auto const_guarantee_x2 =
+ ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
auto const_guarantee_x1 =
ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
- auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2);
+ auto add1 =
+ ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_guarantee_x2);
add1.node()->AddAttr("_encapsulate", "encapsulate1");
Graph graph_before(OpRegistry::Global());
@@ -798,7 +760,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
- [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
+ [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
@@ -814,7 +777,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
}
return Status::OK();
},
- /*parallel_checking=*/false,
/*reuse_existing_functions=*/false, &graph_after, &library));
EXPECT_EQ(2, guaranteed_consts);
}
@@ -843,7 +805,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
- [&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
+ [&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
+ std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
@@ -859,7 +822,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
}
return Status::OK();
},
- /*parallel_checking=*/false,
/*reuse_existing_functions=*/false, &graph_after, &library));
// Only 1 runtime const, which is const_guarantee_add1. Add2 has one const
// and another non-const, so overall non-const.
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 902fe27acd..338fb5a6f0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -115,6 +115,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
+ bool use_multiple_streams = s.ok() && metadata->UseMultipleStreams();
// Get the platform_id_ for XLA_* devices.
if (platform_id_ == nullptr) {
@@ -166,14 +167,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
+ // Optimization: don't resolve constants. If we resolve constants we never
+ // emit them on the device, meaning that if they are needed by a following
+ // computation the host has to transfer them.
+ compile_options.resolve_compile_time_constants = false;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
+
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
&kernel, &executable, &compile_options));
VLOG(1) << "Executing XLA Computation...";
- XlaComputationLaunchContext launch_context(client, xla_allocator,
- allocate_xla_tensors);
+ XlaComputationLaunchContext launch_context(
+ client, xla_allocator, allocate_xla_tensors, use_multiple_streams);
launch_context.PopulateInputs(ctx, kernel, variables);
// Execute the computation.
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
index 5d211f4d73..5b6692f523 100644
--- a/tensorflow/compiler/jit/legacy_flags/BUILD
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -17,18 +17,6 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//tensorflow:internal"])
cc_library(
- name = "encapsulate_subgraphs_pass_flags",
- srcs = ["encapsulate_subgraphs_pass_flags.cc"],
- hdrs = ["encapsulate_subgraphs_pass_flags.h"],
- deps =
- [
- "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "mark_for_compilation_pass_flags",
srcs = ["mark_for_compilation_pass_flags.cc"],
hdrs = ["mark_for_compilation_pass_flags.h"],
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 8e2ee0f1d7..8c3882116d 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -41,9 +42,6 @@ limitations under the License.
namespace tensorflow {
-const char* const kXlaClusterAttr = "_XlaCluster";
-const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
-
namespace {
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
@@ -60,6 +58,14 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
return false;
}
}
+
+ // XLA does not offer guaranteed aliasing between the input and output of the
+ // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
+ // such nodes out of XLA clusters.
+ if (HasForwardedRefInput(node)) {
+ return false;
+ }
+
return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok();
}
@@ -165,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def,
return true;
}
-// Returns the DeviceType corresponding to 'device'.
-Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
- DeviceNameUtils::ParsedName parsed;
- if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
- return errors::Internal("Malformed assigned device '", device, "'");
- }
- *device_type = DeviceType(parsed.type);
- return Status::OK();
-}
-
// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
@@ -183,18 +179,11 @@ bool HasResourceInputOrOutput(const Node& node) {
DT_RESOURCE) != node.output_types().end();
}
-struct NodeCompare {
- bool operator()(const Node* a, const Node* b) const {
- return a->id() < b->id();
- }
-};
-using OrderedNodeSet = std::set<Node*, NodeCompare>;
-
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
-// TODO(hpucha): Consider a black list instead of a white list as
-// implemented below.
+// TODO(hpucha): Remove this code since this functionality is subsumed by
+// Grappler XlaFusionOptimizer.
bool IsXlaFusable(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
@@ -364,7 +353,7 @@ Status FindCompilationCandidates(
for (Node* node : graph.op_nodes()) {
sorted_nodes.push_back(node);
}
- std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
+ std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
for (Node* node : sorted_nodes) {
VLOG(2) << "Fuel: " << fuel;
@@ -379,9 +368,13 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
- DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
+ DeviceToDeviceType(node->assigned_device_name(), &device_type));
- if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
+ if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
+ VLOG(2) << "Compilation rejected node: not compilable " << node->name()
+ << ": " << node->type_string();
+ continue;
+ }
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(
@@ -430,46 +423,6 @@ struct Cluster {
int representative = -1;
};
-// Returns a string describing how an edge from src to dst would
-// create a cycle.
-string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
- int dst) {
- int32 max_path_size = graph.num_node_ids() + 1;
- std::vector<int32> path(max_path_size);
- int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
- if (path_size == 0) {
- return "";
- }
-
- auto node_name = [&cycles, &graph](int node_id) {
- if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
- return string("(null)");
- }
- auto* node = graph.FindNodeId(node_id);
- if (node == nullptr) {
- return string("(null)");
- }
- return node->name();
- };
-
- string description;
- strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
- node_name(dst), " would create a cycle.\n");
- path.resize(path_size);
- for (int32 node_id : path) {
- string ascii_art;
- if (node_id == dst) {
- ascii_art = "+-> ";
- } else if (node_id != src) {
- ascii_art = "| ";
- } else {
- ascii_art = "+-- ";
- }
- strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
- }
- return description;
-}
-
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@@ -575,84 +528,13 @@ Status MarkForCompilationPass::RunImpl(
: Env::Default(),
is_compilable_fn, &compilation_candidates));
- GraphCycles cycles;
- for (int i = 0; i < graph->num_node_ids(); ++i) {
- // We rely on the node IDs in the cycle detection graph being consecutive
- // integers starting from 0.
- CHECK_EQ(i, cycles.NewNode());
+ if (compilation_candidates.empty()) {
+ VLOG(2) << "No compilable candidates";
+ return Status::OK();
}
- // Compute the loop structure of the graph.
- std::vector<ControlFlowInfo> control_flow_info;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
-
- // The clustering code must avoid adding cycles to the graph to prevent
- // deadlock. However, the graph may contain loops, which would trigger the
- // cycle detection code. To handle loops, we alter the structure of the cycle
- // detection graph, disconnecting each loop from the enclosing graph.
- // Specifically, we:
- // * add a new "frame" node for each loop.
- // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
- // to/from the corresponding frame node. In essence, we collapse the loop
- // into a single node for the purpose of cycle detection in the enclosing
- // graph.
- // * the body of the loop should now be disconnected from the rest of the
- // graph; we make it acyclic by breaking loop backedges (edges outgoing from
- // "NextIteration" nodes.
-
- // Map from frame name strings to node IDs in the cycle detection graph.
- std::unordered_map<string, int> frame_nodes;
-
- // Get the cycle graph node ID for frame 'frame_name', or add one if none
- // exists.
- auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
- int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
- if (frame_id < 0) {
- // The emplace succeeded; we have not allocated a frame node yet.
- frame_id = cycles.NewNode();
- }
- return frame_id;
- };
-
- for (Edge const* edge : graph->edges()) {
- if (edge->dst()->IsEnter()) {
- // Lift edges to an "Enter" node to the corresponding frame node.
- const string& frame_name =
- control_flow_info[edge->dst()->id()].frame_name;
- int dst = GetOrAddFrameNodeId(frame_name);
- if (!cycles.InsertEdge(edge->src()->id(), dst)) {
- return errors::Internal(
- "Cycle detected when adding enter->frame edge: ",
- DescribeCycle(cycles, *graph, edge->src()->id(), dst));
- }
- continue;
- }
- if (edge->src()->IsExit()) {
- // Lift edges from an "Exit" node to the corresponding frame node.
- const string& frame_name =
- control_flow_info[edge->src()->id()].frame_name;
- int src = GetOrAddFrameNodeId(frame_name);
- if (!cycles.InsertEdge(src, edge->dst()->id())) {
- return errors::Internal(
- "Cycle detected when adding frame->exit edge: ",
- DescribeCycle(cycles, *graph, src, edge->dst()->id()));
- }
- // Drop the original edge.
- continue;
- }
- if (edge->src()->IsNextIteration()) {
- // Break loop back-edges.
- continue;
- }
- if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
- // This should never happen. All cycles in the graph should contain
- // a control flow operator.
- return errors::Internal(
- "Found cycle in graph without control flow operator during XLA "
- "compilation: ",
- DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
- }
- }
+ GraphCycles cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@@ -670,6 +552,9 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
+ //
+ // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
+ // example, from the Grappler fusion pass).
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
@@ -778,7 +663,7 @@ Status MarkForCompilationPass::RunImpl(
// compilation.
DeviceType device_type("");
TF_RETURN_IF_ERROR(
- DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 703d8825d7..772c92d369 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) {
}
}
+TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output add = ops::Add(root.WithOpName("add"), neg, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name}, {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
+TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output variable = ops::Variable(root.WithOpName("variable"),
+ PartialTensorShape{}, DT_FLOAT);
+ Output read = ops::Identity(root.WithOpName("read"), variable);
+ Output neg = ops::Negate(root.WithOpName("negate"), read);
+ Output identity = ops::Negate(root.WithOpName("identity"), neg);
+ Output add = ops::Add(root.WithOpName("add"), identity, neg);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+
+ ASSERT_FALSE(clusters.empty());
+ string cluster_name = clusters.begin()->second;
+
+ std::unordered_map<string, string> expected_clusters(
+ {{"negate", cluster_name},
+ {"identity", cluster_name},
+ {"add", cluster_name}});
+ EXPECT_EQ(clusters, expected_clusters);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
new file mode 100644
index 0000000000..a5628b12a2
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -0,0 +1,188 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+const char* const kXlaClusterAttr = "_XlaCluster";
+const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
+
+namespace {
+// Returns a string describing how an edge from src to dst would
+// create a cycle.
+string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
+ int dst) {
+ int32 max_path_size = graph.num_node_ids() + 1;
+ std::vector<int32> path(max_path_size);
+ int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
+ if (path_size == 0) {
+ return "";
+ }
+
+ auto node_name = [cycles, &graph](int node_id) {
+ if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
+ return string("(null)");
+ }
+ auto* node = graph.FindNodeId(node_id);
+ if (node == nullptr) {
+ return string("(null)");
+ }
+ return node->name();
+ };
+
+ string description;
+ strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
+ node_name(dst), " would create a cycle.\n");
+ path.resize(path_size);
+ for (int32 node_id : path) {
+ string ascii_art;
+ if (node_id == dst) {
+ ascii_art = "+-> ";
+ } else if (node_id != src) {
+ ascii_art = "| ";
+ } else {
+ ascii_art = "+-- ";
+ }
+ strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
+ }
+ return description;
+}
+
+bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); }
+
+} // namespace
+
+Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
+ DeviceNameUtils::ParsedName parsed;
+ if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
+ return errors::Internal("Malformed assigned device '", device, "'");
+ }
+ *device_type = DeviceType(parsed.type);
+ return Status::OK();
+}
+
+bool HasForwardedRefInput(const Node& node) {
+ if (AlwaysForwardsRefInput(node)) {
+ for (const Edge* incoming_edge : node.in_edges()) {
+ if (incoming_edge->IsControlEdge()) {
+ continue;
+ }
+
+ Node* incoming_node = incoming_edge->src();
+ if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) {
+ VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input "
+ << incoming_node->name() << " " << incoming_node->type_string();
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
+ for (int i = 0; i < graph->num_node_ids(); ++i) {
+ // We rely on the node IDs in the cycle detection graph being consecutive
+ // integers starting from 0.
+ CHECK_EQ(i, cycles->NewNode());
+ }
+
+ // Compute the loop structure of the graph.
+ std::vector<ControlFlowInfo> control_flow_info;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
+
+ // The clustering code must avoid adding cycles to the graph to prevent
+ // deadlock. However, the graph may contain loops, which would trigger the
+ // cycle detection code. To handle loops, we alter the structure of the cycle
+ // detection graph, disconnecting each loop from the enclosing graph.
+ // Specifically, we:
+ // * add a new "frame" node for each loop.
+ // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
+ // to/from the corresponding frame node. In essence, we collapse the loop
+ // into a single node for the purpose of cycle detection in the enclosing
+ // graph.
+ // * the body of the loop should now be disconnected from the rest of the
+ // graph; we make it acyclic by breaking loop backedges (edges outgoing from
+ // "NextIteration" nodes.
+
+ // Map from frame name strings to node IDs in the cycle detection graph.
+ std::unordered_map<string, int> frame_nodes;
+
+ // Get the cycle graph node ID for frame 'frame_name', or add one if none
+ // exists.
+ auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
+ int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
+ if (frame_id < 0) {
+ // The emplace succeeded; we have not allocated a frame node yet.
+ frame_id = cycles->NewNode();
+ }
+ return frame_id;
+ };
+
+ for (Edge const* edge : graph->edges()) {
+ if (edge->dst()->IsEnter() || edge->src()->IsExit()) {
+ const char* src_type = "pre-enter";
+ const char* dst_type = "post-exit";
+ int src = edge->src()->id();
+ int dst = edge->dst()->id();
+
+ if (edge->dst()->IsEnter()) {
+ // Lift edges to an "Enter" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->dst()->id()].frame_name;
+ dst = GetOrAddFrameNodeId(frame_name);
+ dst_type = "frame";
+ }
+
+ if (edge->src()->IsExit()) {
+ // Lift edges from an "Exit" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->src()->id()].frame_name;
+ src = GetOrAddFrameNodeId(frame_name);
+ src_type = "frame";
+ }
+
+ if (!cycles->InsertEdge(src, dst)) {
+ return errors::Internal(
+ "Cycle detected when adding ", src_type, "->", dst_type,
+ " edge: ", DescribeCycle(cycles, *graph, src, dst));
+ }
+ // Drop the original edge.
+ continue;
+ }
+ if (edge->src()->IsNextIteration()) {
+ // Break loop back-edges.
+ continue;
+ }
+ if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
+ // This should never happen. All cycles in the graph should contain
+ // a control flow operator.
+ return errors::Internal(
+ "Found cycle in graph without control flow operator during XLA "
+ "compilation: ",
+ DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
new file mode 100644
index 0000000000..bcce082aaf
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains utilities for clustering compilable graph nodes via XLA.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+// The attribute that marks nodes to be grouped into functions by the
+// encapsulate subgraphs pass.
+extern const char* const kXlaClusterAttr;
+
+// The attribute that marks nodes in a cluster to be placed outside the xla
+// compilation by the encapsulate subgraphs pass.
+extern const char* const kXlaOutsideCompilationAttr;
+
+using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
+
+// Returns the DeviceType corresponding to 'device'.
+Status DeviceToDeviceType(const string& device, DeviceType* device_type);
+
+// Returns true if `node` has a ref tensor input that it forwards to its output.
+bool HasForwardedRefInput(const Node& node);
+
+// Creates a graph representation to enable cycle detection when clustering.
+// This representation handles loops in graph by disconnecting each loop from
+// the enclosing graph.
+Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc
new file mode 100644
index 0000000000..2cb351e1ec
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(CreateCycleDetectionGraph, ConnectivityThroughEnterExitRegion) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
+ Output enter =
+ ops::internal::Enter(root.WithOpName("enter"), a, "only_frame");
+ Output exit = ops::internal::Exit(root.WithOpName("exit"), enter);
+ Output b = ops::Add(root.WithOpName("b"), a, exit);
+
+ FixupSourceAndSinkEdges(root.graph());
+
+ GraphCycles cycles;
+ TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
+ EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
+}
+
+TEST(CreateCycleDetectionGraph, ConnectivityThroughMultipleEnterExitRegions) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output a = ops::Const(root.WithOpName("a"), Input::Initializer(0.0));
+ Output enter_0 =
+ ops::internal::Enter(root.WithOpName("enter_0"), a, "frame_0");
+ Output exit_0 = ops::internal::Exit(root.WithOpName("exit_0"), enter_0);
+ Output enter_1 =
+ ops::internal::Enter(root.WithOpName("enter_1"), a, "frame_1");
+ Output exit_1 = ops::internal::Exit(root.WithOpName("exit_1"), enter_1);
+ Output b = ops::Add(root.WithOpName("b"), a, exit_1);
+
+ FixupSourceAndSinkEdges(root.graph());
+
+ GraphCycles cycles;
+ TF_ASSERT_OK(CreateCycleDetectionGraph(root.graph(), &cycles));
+ EXPECT_FALSE(cycles.ContractEdge(a.node()->id(), b.node()->id()));
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index b1943d3e1a..d288d37bc7 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -53,7 +53,9 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
- client, client->backend().memory_allocator(), true);
+ client, client->backend().memory_allocator(),
+ /*allocate_xla_tensors=*/true,
+ /*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variables);
@@ -61,14 +63,18 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
- VLOG(2) << "Executing computation.";
+ VLOG(2) << "Executing computation: " << name();
+ for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
+ VLOG(2) << name() << ": " << *arg;
+ }
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(ctx->step_id());
- auto run_result = executable->Run(launch_context.arguments(), run_options);
+ xla::StatusOr<xla::ScopedShapedBuffer> run_result =
+ executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
@@ -159,6 +165,13 @@ Status XlaCompileOnDemandOp::Compile(
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
+ // Optimization: don't resolve constants. If we resolve constants we never
+ // emit them on the device, meaning that if they are needed by a following
+ // computation the host has to transfer them.
+ compile_options.resolve_compile_time_constants = false;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 43648402f6..7e159e3171 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -54,6 +54,7 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index ed007d603e..c55eba2f79 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -130,7 +130,7 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
+ bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
@@ -151,22 +151,24 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
- device->reset(new XlaDevice(
- options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
- padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
+ device->reset(
+ new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name),
+ platform.ValueOrDie(), transfer_as_literal,
+ use_multiple_streams, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- PaddedShapeFn padded_shape_fn)
+ PaddedShapeFn padded_shape_fn, bool use_multiple_streams)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)),
- padded_shape_fn_(std::move(padded_shape_fn)) {}
+ padded_shape_fn_(std::move(padded_shape_fn)),
+ use_multiple_streams_(use_multiple_streams) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -200,16 +202,18 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- se::Platform* platform, bool transfer_as_literal,
+ se::Platform* platform, bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn, padded_shape_fn),
+ shape_representation_fn, padded_shape_fn,
+ use_multiple_streams),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
+ use_multiple_streams_(use_multiple_streams),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
@@ -253,6 +257,30 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
+xla::StatusOr<se::Stream*> XlaDevice::GetDeviceToHostStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!device_to_host_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(device_to_host_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return device_to_host_stream_.get();
+}
+
+xla::StatusOr<se::Stream*> XlaDevice::GetHostToDeviceStream() {
+ if (!use_multiple_streams_) {
+ return GetStream();
+ }
+ if (!host_to_device_stream_) {
+ xla::Backend* backend = client()->mutable_backend();
+ TF_ASSIGN_OR_RETURN(host_to_device_stream_,
+ backend->BorrowStream(device_ordinal_));
+ }
+ return host_to_device_stream_.get();
+}
+
Status XlaDevice::CreateAndSetGpuDeviceInfo() {
if (gpu_device_info_ == nullptr) {
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
@@ -263,8 +291,9 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
- gpu_device_info_->default_context = new XlaDeviceContext(
- stream, client(), transfer_as_literal_, shape_representation_fn_);
+ gpu_device_info_->default_context =
+ new XlaDeviceContext(stream, stream, stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@@ -276,10 +305,16 @@ Status XlaDevice::FillContextMap(const Graph* graph,
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
- auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ auto ctx = new XlaDeviceContext(
+ stream, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@@ -326,8 +361,13 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- XlaTransferManager manager(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ TF_ASSIGN_OR_RETURN(se::Stream * device_to_host_stream,
+ GetDeviceToHostStream());
+ TF_ASSIGN_OR_RETURN(se::Stream * host_to_device_stream,
+ GetHostToDeviceStream());
+ XlaTransferManager manager(stream, host_to_device_stream,
+ device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 02e88ee679..fccdb14368 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -57,7 +57,7 @@ class XlaDevice : public LocalDevice {
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
- PaddedShapeFn padded_shape_fn);
+ PaddedShapeFn padded_shape_fn, bool use_multiple_streams);
// The index of the device on this host.
int device_ordinal() const;
@@ -70,12 +70,15 @@ class XlaDevice : public LocalDevice {
}
const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; }
+ bool UseMultipleStreams() const { return use_multiple_streams_; }
+
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
PaddedShapeFn padded_shape_fn_;
+ const bool use_multiple_streams_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -89,6 +92,8 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
+ // If 'use_multiple_streams' is true, we create separate streams for
+ // host-to-device and device-to-host communication.
// If padded_shape_fn is empty, a default implementation that returns
// the on-host shape is used.
static Status Create(
@@ -96,7 +101,7 @@ class XlaDevice : public LocalDevice {
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
+ bool transfer_as_literal, bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device);
@@ -106,6 +111,7 @@ class XlaDevice : public LocalDevice {
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
+ bool use_multiple_streams,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
const PaddedShapeFn& padded_shape_fn);
~XlaDevice() override;
@@ -126,6 +132,8 @@ class XlaDevice : public LocalDevice {
xla::LocalClient* client() const;
const Metadata& metadata() { return xla_metadata_; }
xla::StatusOr<se::Stream*> GetStream();
+ xla::StatusOr<se::Stream*> GetHostToDeviceStream();
+ xla::StatusOr<se::Stream*> GetDeviceToHostStream();
// If not already set, create and set GpuDeviceInfo.
// Not thread-safe
@@ -146,6 +154,16 @@ class XlaDevice : public LocalDevice {
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
xla::Backend::StreamPtr stream_;
+ // If true, only stream_ is valid and all computation and transfers use
+ // stream_. If false, computation is performed by stream_ and transfers are
+ // performed by host_to_device/device_to_host_stream.
+ bool use_multiple_streams_;
+ // If use_multiple_streams_, host to device transfers are performed using this
+ // stream.
+ xla::Backend::StreamPtr host_to_device_stream_;
+ // If use_multiple_streams_, device to host transfers are performed using this
+ // stream.
+ xla::Backend::StreamPtr device_to_host_stream_;
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 71e63b110b..04778c0090 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -48,17 +48,24 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(stream),
+ : stream_(compute_stream),
+ host_to_device_stream_(host_to_device_stream),
+ device_to_host_stream_(device_to_host_stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
+ CHECK(host_to_device_stream_ != nullptr);
+ CHECK(device_to_host_stream_ != nullptr);
+ CHECK(stream_ != nullptr);
if (!shape_representation_fn_) {
- shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
- return shape;
- };
+ shape_representation_fn_ =
+ [](const TensorShape& shape,
+ DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
}
}
@@ -70,12 +77,19 @@ Status XlaTransferManager::TransferLiteralToDevice(
xla::BorrowingLiteral literal(
static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
- const xla::ShapedBuffer& shaped_buffer =
- XlaTensor::FromTensor(device_tensor)->shaped_buffer();
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+ const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " "
<< shaped_buffer.ToString();
- return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal,
- shaped_buffer);
+ TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDevice(
+ host_to_device_stream_, literal, shaped_buffer));
+ if (UseMultipleStreams()) {
+ se::Event event(stream_->parent());
+ TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ }
+ return Status::OK();
}
Status XlaTransferManager::TransferLiteralFromDevice(
@@ -85,7 +99,7 @@ Status XlaTransferManager::TransferLiteralFromDevice(
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Literal> literal,
transfer_manager_->TransferLiteralFromDevice(
- stream_->parent(), shaped_buffer));
+ device_to_host_stream_, shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " "
<< shaped_buffer.ToString();
Tensor tensor;
@@ -103,63 +117,67 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
Tensor* device_tensor,
StatusCallback done) const {
- if (cpu_tensor->NumElements() > 0) {
- VLOG(2) << "CopyCPUTensorToDevice "
- << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << " "
- << reinterpret_cast<const void*>(
- device_tensor->tensor_data().data())
- << " " << cpu_tensor->NumElements() << " "
- << cpu_tensor->shape().DebugString() << " "
- << device_tensor->shape().DebugString();
-
- void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
- const int64 total_bytes = cpu_tensor->TotalBytes();
-
- XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
- CHECK(xla_tensor);
-
- TensorShape shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
- if (!xla_tensor->has_shaped_buffer()) {
- Status s = xla_tensor->AllocateShapedBuffer(
- device_tensor->dtype(), shape, client_,
- stream_->parent()->device_ordinal());
- if (!s.ok()) {
- done(s);
- return;
- }
- }
+ if (cpu_tensor->NumElements() == 0) {
+ VLOG(2) << "CopyCPUTensorToDevice empty tensor";
+ done(Status::OK());
+ return;
+ }
- Status status;
- if (transfer_as_literal_) {
- Tensor reshaped_cpu_tensor;
- if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
- done(errors::Internal(
- "Tensor::CopyFrom failed when copying from CPU to XLA device"));
- return;
- }
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
- } else {
- se::DeviceMemoryBase dev_dst_ptr =
- XlaTensor::DeviceMemoryFromTensor(*device_tensor);
- stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
- }
- }
- xla_tensor->set_host_tensor(*cpu_tensor);
+ VLOG(2) << "CopyCPUTensorToDevice "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+ << " " << cpu_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
- done(status);
+ void* src_ptr = const_cast<void*>(DMAHelper::base(cpu_tensor));
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+ CHECK(xla_tensor);
+
+ xla::StatusOr<TensorShape> shape_or_status =
+ shape_representation_fn_(device_tensor->shape(), device_tensor->dtype());
+ if (!shape_or_status.ok()) {
+ done(shape_or_status.status());
return;
}
+ TensorShape shape = shape_or_status.ValueOrDie();
+ if (!xla_tensor->has_shaped_buffer()) {
+ Status s =
+ xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
+ stream_->parent()->device_ordinal());
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ }
- VLOG(2) << "CopyCPUTensorToDevice empty tensor";
- done(Status::OK());
+ Status status;
+ if (transfer_as_literal_) {
+ Tensor reshaped_cpu_tensor;
+ if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
+ done(errors::Internal(
+ "Tensor::CopyFrom failed when copying from CPU to XLA device"));
+ return;
+ }
+ status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+ } else {
+ se::DeviceMemoryBase dev_dst_ptr =
+ XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+ host_to_device_stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = host_to_device_stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s",
+ host_to_device_stream_, block_status.error_message().c_str());
+ }
+ }
+ xla_tensor->set_host_tensor(*cpu_tensor);
+
+ done(status);
}
void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
@@ -167,62 +185,83 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Device* device,
Tensor* cpu_tensor,
StatusCallback done) {
- if (device_tensor->NumElements() > 0) {
- VLOG(2) << "CopyDeviceTensorToCPU "
- << reinterpret_cast<const void*>(
- device_tensor->tensor_data().data())
- << " "
- << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
- << " " << device_tensor->NumElements() << " "
- << cpu_tensor->shape().DebugString() << " "
- << device_tensor->shape().DebugString();
-
- const int64 total_bytes = cpu_tensor->TotalBytes();
- se::DeviceMemoryBase dev_src_ptr =
- XlaTensor::DeviceMemoryFromTensor(*device_tensor);
- void* dst_ptr = DMAHelper::base(cpu_tensor);
-
- Status status;
- if (transfer_as_literal_) {
- status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
- } else {
- stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
- // TODO(hpucha): Make this asynchronous.
- Status block_status = stream_->BlockHostUntilDone();
- if (!block_status.ok()) {
- status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
- block_status.error_message().c_str());
- }
- }
-
- done(status);
+ if (device_tensor->NumElements() == 0) {
+ VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
+ done(Status::OK());
return;
}
+ VLOG(2) << "CopyDeviceTensorToCPU "
+ << reinterpret_cast<const void*>(device_tensor->tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << " " << device_tensor->NumElements() << " "
+ << cpu_tensor->shape().DebugString() << " "
+ << device_tensor->shape().DebugString();
+
+ const int64 total_bytes = cpu_tensor->TotalBytes();
+ se::DeviceMemoryBase dev_src_ptr =
+ XlaTensor::DeviceMemoryFromTensor(*device_tensor);
+ void* dst_ptr = DMAHelper::base(cpu_tensor);
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
+
+ if (se::Event* event =
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ device_to_host_stream_->ThenWaitFor(event);
+ xla_tensor->SetDefinedOn(device_to_host_stream_);
+ }
+
+ Status status;
+ if (transfer_as_literal_) {
+ status = TransferLiteralFromDevice(cpu_tensor, *device_tensor);
+ } else {
+ device_to_host_stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes);
+ // TODO(hpucha): Make this asynchronous.
+ Status block_status = device_to_host_stream_->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ status = xla::InternalError(
+ "Failed to complete data transfer on stream %p: %s", stream_,
+ block_status.error_message().c_str());
+ }
+ }
- VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
- done(Status::OK());
+ done(status);
}
void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
Tensor* dst_tensor,
const StatusCallback& done) {
+ VLOG(2) << "CopyDeviceTensorToDevice "
+ << reinterpret_cast<const void*>(src_tensor.tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(dst_tensor->tensor_data().data());
// TODO(phawkins): replace this code with an asynchronous implementation.
auto body = [&]() {
if (src_tensor.NumElements() == 0) {
return Status::OK();
}
+ // TODO(jmolloy): We co-opt the device_to_host stream for device to device
+ // transfers; perhaps we should have a dedicated device to device stream? or
+ // one per device?
+ auto device_to_device_stream = device_to_host_stream_;
XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor);
XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor);
CHECK(xla_src && xla_dst)
<< "Missing destination tensor for device-to-device copy";
if (!xla_dst->has_shaped_buffer()) {
- TensorShape shape =
- shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+ TF_ASSIGN_OR_RETURN(
+ TensorShape shape,
+ shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()));
TF_RETURN_IF_ERROR(
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
}
+
+ if (se::Event* event =
+ xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ device_to_device_stream->ThenWaitFor(event);
+ xla_src->SetDefinedOn(device_to_device_stream);
+ TF_RETURN_IF_ERROR(device_to_device_stream->BlockHostUntilDone());
+ }
TF_RETURN_IF_ERROR(
xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
@@ -241,9 +280,12 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(stream, client, transfer_as_literal,
+ : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
+ client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index ee346e5653..c726495f96 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,7 +47,9 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -66,10 +68,17 @@ class XlaTransferManager {
Tensor* device_tensor) const;
Status TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor) const;
+ bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; }
- // Stream obtained from a Device, used to transfer tensors between
- // CPU and device.
+ // The main compute stream of the device, used to synchronize the transfer
+ // streams if they are set.
se::Stream* stream_;
+ // The stream to use for transferring data from host to device. Can be
+ // idential to stream_, but must not be nullptr.
+ se::Stream* host_to_device_stream_;
+ // The stream to use for transferring data from device to host. Can be
+ // idential to stream_, but must not be nullptr.
+ se::Stream* device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -85,7 +94,9 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+ se::Stream* compute_stream, se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream, xla::LocalClient* client,
+ bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index b27c32e9bc..a605335a94 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -23,11 +23,14 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
+#include "tensorflow/core/kernels/shape_ops.h"
#include "tensorflow/core/kernels/variable_ops.h"
namespace tensorflow {
@@ -87,6 +90,46 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER( \
Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \
ReadVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("Shape") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Shape") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeOp<int64>); \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeNOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("ShapeN") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ ShapeNOp<int64>); \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int32>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ SizeOp<int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Size") \
+ .Device(DEVICE) \
+ .HostMemory("output") \
+ .TypeConstraint<int64>("out_type") \
+ .TypeConstraint("T", TYPES), \
+ SizeOp<int64>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \
+ TYPES), \
+ RankOp); \
REGISTER_KERNEL_BUILDER( \
Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \
XlaAssignVariableOp); \
@@ -95,7 +138,41 @@ class XlaAssignVariableOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
SwitchOp); \
REGISTER_KERNEL_BUILDER( \
- Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp);
+ Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
+ REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("LoopCond") \
+ .Device(DEVICE) \
+ .HostMemory("input") \
+ .HostMemory("output"), \
+ LoopCondOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
+ REGISTER_KERNEL_BUILDER(Name("QueueSizeV2") \
+ .Device(DEVICE) \
+ .HostMemory("size") \
+ .HostMemory("handle"), \
+ QueueSizeOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"), \
+ QueueIsClosedOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
+
+// TODO(phawkins): currently we do not register the QueueEnqueueMany,
+// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
+// and write the tensors they access in order to concatenate them into a batch.
+// We would need either to call out to an XLA computation to perform the
+// concatenation, or we would need to refactor those kernels so the splitting
+// or merging is done in a separate operator that can be compiled.
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
new file mode 100644
index 0000000000..74257b09a8
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -0,0 +1,328 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+
+#include <atomic>
+#include <deque>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+namespace tensorflow {
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+static bool IsShapeConsumerOp(const Node& node) {
+ return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
+ node.type_string() == "Rank" || node.type_string() == "Size";
+}
+
+// Returns true if the op can be decomposed into XLA ops for which
+// there are fusable elemental implementations.
+bool IsXlaFusable(const NodeDef& node) {
+ static const std::unordered_set<std::string>* elementwise_ops =
+ new std::unordered_set<std::string>(
+ {// tf2xla/kernels/aggregate_ops.cc
+ "AddN",
+ // tf2xla/kernels/binary_ops.cc
+ "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
+ "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
+ "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
+ "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
+ "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
+ "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
+ "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
+ // tf2xla/kernels/unary_ops.cc
+ "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
+ "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
+ "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
+ "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
+ "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
+ "Square", "Tan", "Tanh", "Real", "Imag",
+ // tf2xla/kernels/bcast_ops.cc
+ "BroadcastArgs", "BroadcastGradientArgs",
+ // tf2xla/kernels/bias_ops.cc
+ "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
+ // tf2xla/kernels/cast_op.cc
+ "Cast",
+ // tf2xla/kernels/concat_op.cc
+ "Concat", "ConcatV2", "ConcatOffset",
+ // tf2xla/kernels/const_op.cc
+ "Const",
+ // tf2xla/kernels/elu_op.cc
+ "Elu", "EluGrad", "Selu", "SeluGrad",
+ // tf2xla/kernels/fill_op.cc
+ "Fill",
+ // tf2xla/kernels/identity_op.cc
+ "Identity", "IdentityN", "PreventGradient",
+ "StopGradient", /*"Snapshot",*/
+ // tf2xla/kernels/index_ops.cc
+ "ArgMax", "ArgMin",
+ // tf2xla/kernels/mirror_pad_op.cc
+ "MirrorPad",
+ // tf2xla/kernels/one_hot_op.cc
+ "OneHot",
+ // tf2xla/kernels/pack_op.cc
+ "Pack",
+ // tf2xla/kernels/pad_op.cc
+ "Pad", "PadV2",
+ // tf2xla/kernels/relu_op.cc
+ "Relu", "Relu6", "ReluGrad", "Relu6Grad",
+ // tf2xla/kernels/reshape_op.cc
+ "Reshape",
+ // tf2xla/kernels/reverse_op.cc
+ "Reverse", "ReverseV2",
+ // tf2xla/kernels/reverse_sequence_op.cc
+ "ReverseSequence",
+ // tf2xla/kernels/shape_op.cc
+ "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
+ "ZerosLike", "OnesLike",
+ // tf2xla/kernels/slice_op.cc
+ "Slice",
+ // tf2xla/kernels/split_op.cc
+ "Split", "SplitV",
+ // tf2xla/kernels/strided_slice_op.cc
+ "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
+ // tf2xla/kernels/tile_ops.cc
+ "Tile",
+ // tf2xla/kernels/transpose_op.cc
+ "Transpose", "InvertPermutation",
+ // tf2xla/kernels/unpack_op.cc
+ "Unpack"});
+
+ return elementwise_ops->count(node.op()) > 0;
+}
+
+Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
+ const grappler::GrapplerItem& item,
+ GraphDef* output) {
+ VLOG(2) << "Here at fusion optimizer";
+
+ // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
+ // Once that happens, the expected interaction between this optimizer and when
+ // the global_jit_level is set is as follows: Fusion optimizer will replace
+ // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
+ // be further compiled where possible via mark_for_compilation_pass. Note that
+ // this might lead to inefficient clustering, and it is best to use either the
+ // fusion optimizer or the global_jit flag, and not combine the two.
+
+ // Create a Graph out of GraphDef. This is required currently because the
+ // helpers around clustering, encapsulation etc work on graphs.
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ Graph graph(function_library);
+ ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
+ shape_refiner.set_require_shape_inference_fns(false);
+ shape_refiner.set_disable_constant_propagation(true);
+ ImportGraphDefOptions options;
+ // Graph optimization happens at the late stage of graph execution, when
+ // colocation constraints are already validated previously and the device
+ // placement of nodes has also completed, so there is no need to validate
+ // colocation constraints again.
+ options.validate_colocation_constraints = false;
+ options.validate_shape = false;
+ TF_RETURN_IF_ERROR(
+ ImportGraphDef(options, item.graph, &graph, &shape_refiner));
+
+ // Collect nodes that can be fused via XLA, while ignoring those that
+ // explicitly ask for XLA: (*) nodes that are marked to be compiled
+ // explicitly. (*) nodes assigned to XLA device.
+ OrderedNodeSet compilation_candidates;
+ for (Node* node : graph.op_nodes()) {
+ // If there is a _XlaCompile annotation, ignore the node if it is
+ // true. Nodes are marked with this attr via experimental_jit_scope, and
+ // will be handled by the mark_for_compilation pass.
+ bool compile = false;
+ Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
+ if (status.ok() && compile) {
+ continue;
+ }
+ // If there is already a _XlaCluster annotation, ignore the node. Nodes are
+ // marked with this attr to indicate they are already part of a cluster and
+ // hence ignored.
+ status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
+ if (status.ok()) {
+ continue;
+ }
+
+ // If there is an explicit XLA device placement, ignore the node.
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
+ if (device_type.type_string().find("XLA") != string::npos) continue;
+
+ // Assume all fusable ops are registered.
+ // TODO(hpucha): Check for registration if possible.
+ if (!IsXlaFusable(node->def())) {
+ continue;
+ }
+
+ // XLA does not offer guaranteed aliasing between the input and output of
+ // the XLA cluster so it can't implement the forward-tensor-ref semantic.
+ // Leave such nodes out of XLA clusters.
+ if (HasForwardedRefInput(*node)) {
+ continue;
+ }
+
+ compilation_candidates.insert(node);
+ }
+
+ if (compilation_candidates.empty()) {
+ VLOG(2) << "No compilable candidates";
+ *output = item.graph;
+ return Status::OK();
+ }
+
+ GraphCycles cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
+
+ // TODO(hpucha): Make clustering more robust. There are two known issues that
+ // we need to mitigate: (a) Non-resource variables can cause deadlocks
+ // when clustering changes order of execution. See b/77263461 for a specific
+ // example. (b) Queue operations can also cause deadlocks. See b/77261498 for
+ // example.
+
+ struct Cluster {
+ // Identifies the node that represents this cluster in the cycle detection
+ // graph.
+ int representative = -1;
+ };
+
+ // Each compilation candidate belongs to a cluster. The cluster's
+ // representative names the node in the 'cycles' graph that represents the
+ // cluster.
+ std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
+ std::deque<UnionFind<Cluster>*> worklist;
+ for (Node* node : compilation_candidates) {
+ Cluster& cluster = clusters[node->id()].Get();
+ cluster.representative = node->id();
+ worklist.push_back(&clusters[node->id()]);
+ }
+
+ // Repeatedly contract edges between clusters that are on the same device,
+ // provided the contraction would not create a cycle. This is a simplified
+ // version of the clustering in mark_for_compilation_pass that also deals with
+ // nodes that are explicitly tagged to be compiled/clustered.
+ while (!worklist.empty()) {
+ int from = worklist.front()->Get().representative;
+ worklist.pop_front();
+
+ Node* node_from = graph.FindNodeId(from);
+ if (node_from->IsControlFlow()) {
+ // Control flow nodes aren't compilation candidates and should never
+ // appear.
+ return errors::Internal(
+ "Found control flow node in clustering worklist: ",
+ node_from->type_string());
+ }
+ for (int to : cycles.Successors(from)) {
+ if (to >= graph.num_node_ids()) {
+ // Node is a "frame" node that is present only in the cycle detection
+ // graph. No clustering is possible.
+ continue;
+ }
+ Node* node_to = graph.FindNodeId(to);
+ if (compilation_candidates.find(node_to) ==
+ compilation_candidates.cend()) {
+ continue;
+ }
+
+ // Do not cluster across devices.
+ if (node_from->def().device() != node_to->def().device()) {
+ VLOG(2) << "Devices " << node_from->def().device() << " "
+ << node_to->def().device();
+ VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
+ << node_to->assigned_device_name();
+ continue;
+ }
+
+ // Ops that consume shapes cannot be the root of a cluster. This is an
+ // optimization.
+ if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
+ continue;
+ }
+
+ // If contracting the edge would create a cycle, bail out.
+ // However, just because we can't merge the clusters now does not mean
+ // we won't be able to merge them in the future.
+ // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
+ // 1->3. But if we first contract 1->2 then we can later contract 1->3.
+ if (!cycles.ContractEdge(from, to)) continue;
+
+ // Merge the clusters. ContractEdge uses 'from' as the number of the
+ // merged node, so make sure 'from' is the chosen representative.
+ clusters[from].Merge(&clusters[to]);
+
+ worklist.push_back(&clusters[from]);
+ break;
+ }
+ }
+
+ // Count the number of non-trivial elements in each cluster.
+ std::vector<int> effective_cluster_sizes(graph.num_node_ids());
+ for (const Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].Get().representative;
+ // Identity nodes will be removed if the node gets marked for compilation.
+ // Therefore we don't want to count them towards the effective cluster size.
+ if (n->def().op() != "Identity") {
+ effective_cluster_sizes[cluster]++;
+ }
+ }
+
+ const int min_cluster_size = 2;
+ int num_clusters = 0;
+ for (auto size : effective_cluster_sizes) {
+ if (size >= min_cluster_size) {
+ VLOG(3) << "Cluster " << num_clusters << " " << size;
+ num_clusters++;
+ }
+ }
+
+ // Names for each cluster.
+ std::unordered_map<int, string> cluster_names;
+ // Sequence number generator to ensure clusters have unique names.
+ static std::atomic<int64> cluster_sequence_num;
+
+ for (Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].Get().representative;
+
+ // Compile if this is a cluster of >= min_cluster_size compilable operators.
+ if (effective_cluster_sizes[cluster] >= min_cluster_size) {
+ string& name = cluster_names[cluster];
+
+ if (name.empty()) {
+ name = strings::StrCat("cluster_", cluster_sequence_num++);
+ }
+ n->AddAttr(kXlaClusterAttr, name);
+ VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
+ }
+ }
+
+ graph.ToGraphDef(output);
+ return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h
new file mode 100644
index 0000000000..3d2309e782
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+
+// Optimizes graphs by fusing ops where possible, resulting in more efficient
+// execution.
+class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
+ public:
+ XlaFusionOptimizer() {}
+ ~XlaFusionOptimizer() override {}
+
+ Status Init(
+ const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
+ return Status::OK();
+ }
+
+ string name() const override { return "xla-fusion"; };
+
+ Status Optimize(grappler::Cluster* cluster,
+ const grappler::GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override {
+ // Nothing to do for XlaFusionOptimizer.
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
new file mode 100644
index 0000000000..5736760a87
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("UncompilableNullary").Output("o: float");
+REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
+
+class XlaFusionOptimizerTest : public grappler::GrapplerTest {
+ protected:
+ std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
+ std::unordered_map<string, string> ids;
+ for (const NodeDef& node : graph.node()) {
+ string cluster;
+ if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
+ CHECK(!cluster.empty());
+ ids[node.name()] = cluster;
+ }
+ }
+ return ids;
+ }
+};
+
+TEST_F(XlaFusionOptimizerTest, Chains) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
+ Node* d =
+ ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
+ Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
+ ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(4, clusters.size());
+ EXPECT_EQ(clusters["B"], clusters["C"]);
+ EXPECT_EQ(clusters["E"], clusters["F"]);
+ EXPECT_NE(clusters["B"], clusters["E"]);
+ EXPECT_TRUE(clusters.find("A") == clusters.cend());
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST_F(XlaFusionOptimizerTest, FusableOps) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
+ Node* b = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
+
+ Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
+ ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
+ ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(2, clusters.size());
+ EXPECT_EQ(clusters["C"], clusters["E"]);
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
+ Node* b = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
+
+ Node* c = ops::BinaryOp(
+ "Add", a, b,
+ builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
+ ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
+ Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
+ ops::UnaryOp("Cos", e,
+ builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b =
+ ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
+ ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index c0d86a28c7..851b118b0c 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -49,6 +49,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device);
if (!status.ok()) {
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 661187f4a8..4574559674 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -52,6 +52,7 @@ Status XlaInterpreterDeviceFactory::CreateDevices(
DEVICE_INTERPRETER_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
+ /*use_multiple_streams=*/false,
/*shape_representation_fn=*/{},
/*padded_shape_fn=*/{}, &device));
devices->push_back(device.release());
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index d0c7a93651..616c3ed2a2 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -115,14 +115,22 @@ using internal::ExtractSubShapedBuffer;
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors)
+ bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
- allocate_xla_tensors_(allocate_xla_tensors) {}
+ allocate_xla_tensors_(allocate_xla_tensors),
+ use_multiple_streams_(use_multiple_streams) {
+ if (use_multiple_streams_) {
+ CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
+ "be allocating XLA tensors!";
+ }
+}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
const std::map<int, OptionalTensor>& variables) {
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1);
arg_buffers_.resize(kernel->xla_input_shapes.size());
@@ -140,6 +148,16 @@ void XlaComputationLaunchContext::PopulateInputs(
t = &(ctx->input(arg_num));
}
+ if (use_multiple_streams_) {
+ CHECK(stream) << "Must have a stream available when using XLA tensors!";
+ XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
+ CHECK(xla_tensor);
+ if (se::Event* event = xla_tensor->GetDefinitionEvent(stream)) {
+ stream->ThenWaitFor(event);
+ xla_tensor->SetDefinedOn(stream);
+ }
+ }
+
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (xla::ShapeUtil::IsTuple(on_device_shape)) {
@@ -176,6 +194,21 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+ // If the on-host-shape isn't a tuple, create a new single-element tuple
+ // buffer with a nullptr root index table. This allows the code below to treat
+ // output as a tuple unconditionally.
+ if (!xla::ShapeUtil::IsTuple(output.on_host_shape())) {
+ ShapedBuffer nontuple_buffer = output.release();
+ ShapedBuffer buffer(
+ xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_host_shape()}),
+ xla::ShapeUtil::MakeTupleShape({nontuple_buffer.on_device_shape()}),
+ output.platform(), output.device_ordinal());
+ buffer.buffers().CopySubtreeFrom(nontuple_buffer.buffers(),
+ /*source_base_index=*/{},
+ /*target_base_index=*/{0});
+ output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
+ }
+
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -230,9 +263,20 @@ void XlaComputationLaunchContext::PopulateOutputs(
Tensor* output_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
- CHECK(xla_tensor);
- xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
- ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (xla_tensor) {
+ xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
+ ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
+ } else {
+ // xla_tensor wasn't valid, which must mean this is a zero-element
+ // tensor.
+ CHECK_EQ(output_tensor->TotalBytes(), 0);
+ }
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
@@ -282,6 +326,12 @@ void XlaComputationLaunchContext::PopulateOutputs(
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
+ if (use_multiple_streams_) {
+ se::Event event(stream->parent());
+ CHECK(event.Init());
+ stream->ThenRecordEvent(&event);
+ xla_tensor->SetDefinedOn(stream, std::move(event));
+ }
*variable->tensor() = output_tensor;
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 4390701ccb..90531174ff 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -76,9 +76,15 @@ class XlaComputationLaunchContext {
// Create a new launch context. 'allocate_xla_tensors' is true if allocated
// output tensors and variables are always XlaTensors. If false they are
// assumed to be "normal" device pointers.
+ // If 'use_multiple_streams' is true, tensors may be defined and used on
+ // multiple streams and so se::Events must be defined and waited for. If
+ // 'use_multiple_streams' is true, 'allocate_xla_tensors' must also be true
+ // because we track inter-stream dependencies through events inside XlaTensor
+ // objects.
XlaComputationLaunchContext(xla::LocalClient* client,
xla::DeviceMemoryAllocator* xla_allocator,
- bool allocate_xla_tensors);
+ bool allocate_xla_tensors,
+ bool use_multiple_streams);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
@@ -99,6 +105,7 @@ class XlaComputationLaunchContext {
xla::LocalClient* client_;
xla::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
+ bool use_multiple_streams_;
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
};
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 3c44c4ae6d..5dff187fff 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -73,6 +73,36 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
return Status::OK();
}
+se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
+ mutex_lock lock(mu_);
+ if (!definition_event_.has_value()) {
+ return nullptr;
+ }
+
+ // The set of defined streams is expected to be very small indeed (usually
+ // 1-2), so a simple linear scan should be fast enough.
+ if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
+ stream) != streams_defined_on_.end()) {
+ // stream is in streams_defined_on_; it doesn't need to be waited on.
+ return nullptr;
+ }
+
+ return &*definition_event_;
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+ mutex_lock lock(mu_);
+ CHECK(!definition_event_.has_value())
+ << "SetDefinedOn must only be called once!";
+ definition_event_ = std::move(event);
+ streams_defined_on_.push_back(stream);
+}
+
+void XlaTensor::SetDefinedOn(se::Stream* stream) {
+ mutex_lock lock(mu_);
+ streams_defined_on_.push_back(stream);
+}
+
// The pointer tag, OR-ed into the XlaTensor's address to distinguish it from
// device-side tensors, which are either CPU or GPU memory pointers. This works
// because we're guaranteed that CPU and GPU pointers are aligned to > 1 bits.
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index c54001a999..f7e401c731 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -85,6 +85,24 @@ class XlaTensor {
host_tensor_.reset(new Tensor(tensor));
}
+ // If the tensor's content is not yet defined on 'stream', and there exists an
+ // se::Event declaring when the tensor's content is defined, return it.
+ // Otherwise, return nullptr. If this function returns nullptr then the
+ // tensor's content can be read on 'stream' without additional
+ // synchronization.
+ se::Event* GetDefinitionEvent(se::Stream* stream);
+
+ // Assert that the tensor's content is defined on 'stream' by the time 'event'
+ // triggers.
+ void SetDefinedOn(se::Stream* stream, se::Event event);
+
+ // Assert that the tensor's content is defined on 'stream'. This version does
+ // not provide an event, and must be called *after* SetDefinedOn(Stream,
+ // Event). This call can be read as an assertion that the definition event has
+ // been waited on by 'stream', so further calls to GetDefinitionEvent(stream)
+ // do not need to also wait on the event.
+ void SetDefinedOn(se::Stream* stream);
+
// Convert from a raw pointer to an XlaTensor, removing the pointer tag.
static XlaTensor* FromOpaquePointer(void* ptr);
// Convert to a raw pointer from an XlaTensor, adding the pointer tag.
@@ -95,6 +113,14 @@ class XlaTensor {
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
+ // An optional event that is triggered when the tensor's content has been
+ // defined. If this event is nullptr, it is assumed that the tensor's content
+ // is always defined.
+ gtl::optional<se::Event> definition_event_;
+ // A list of all streams for which the tensor's content is defined for any
+ // newly enqueued command.
+ gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
+ mutex mu_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index b51c11bf6e..e8e19f055e 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -51,6 +51,38 @@ py_library(
],
)
+py_library(
+ name = "test_utils",
+ testonly = 1,
+ srcs = ["test_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "xla_test_test",
+ size = "small",
+ srcs = ["xla_test_test.py"],
+ deps = [
+ ":xla_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["adadelta_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
tf_xla_py_test(
name = "adagrad_test",
size = "small",
@@ -66,6 +98,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adagrad_da_test",
+ size = "small",
+ srcs = ["adagrad_da_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "adam_test",
size = "small",
srcs = ["adam_test.py"],
@@ -80,6 +125,48 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "adamax_test",
+ size = "small",
+ srcs = ["adamax_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "addsign_test",
+ size = "small",
+ srcs = ["addsign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "powersign_test",
+ size = "small",
+ srcs = ["powersign_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/contrib/opt:opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
@@ -148,7 +235,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "cholesky_op_test",
- size = "small",
+ size = "medium",
srcs = ["cholesky_op_test.py"],
tags = ["optonly"],
deps = [
@@ -238,6 +325,7 @@ tf_xla_py_test(
srcs = ["conv2d_test.py"],
shard_count = 10,
deps = [
+ ":test_utils",
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
@@ -245,6 +333,7 @@ tf_xla_py_test(
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -351,6 +440,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "fifo_queue_test",
+ size = "medium",
+ srcs = ["fifo_queue_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:extra_py_tests_deps",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
@@ -536,16 +639,68 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "proximal_adagrad_test",
+ size = "medium",
+ srcs = ["proximal_adagrad_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "proximal_gradient_descent_test",
+ size = "medium",
+ srcs = ["proximal_gradient_descent_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:training",
+ ],
+)
+
+tf_xla_py_test(
+ name = "qr_op_test",
+ size = "medium",
+ srcs = ["qr_op_test.py"],
+ disabled_backends = [
+ # Test is very slow on CPU.
+ "cpu",
+ "cpu_ondemand",
+ ],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+tf_xla_py_test(
name = "random_ops_test",
size = "small",
srcs = ["random_ops_test.py"],
- # TODO(b/31361304): enable RNG ops on GPU when parallelized.
disabled_backends = [
+ # TODO(b/110300529): RngNormal doesn't return values with the expected variance
+ "cpu",
+ "cpu_ondemand",
+ # TODO(b/31361304): enable RNG ops on GPU when parallelized.
"gpu",
],
deps = [
":xla_test",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@@ -663,6 +818,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "sparse_to_dense_op_test",
+ size = "small",
+ srcs = ["sparse_to_dense_op_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:sparse_ops",
+ ],
+)
+
+tf_xla_py_test(
name = "stack_ops_test",
size = "small",
srcs = ["stack_ops_test.py"],
@@ -741,9 +909,10 @@ tf_xla_py_test(
tf_xla_py_test(
name = "fused_batchnorm_test",
- size = "small",
+ size = "medium",
srcs = ["fused_batchnorm_test.py"],
deps = [
+ ":test_utils",
":xla_test",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
@@ -753,6 +922,7 @@ tf_xla_py_test(
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -829,6 +999,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "sort_ops_test",
+ size = "medium",
+ srcs = ["sort_ops_test.py"],
+ # Times out in fastbuild mode.
+ tags = ["optonly"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/compiler/tf2xla/python:xla",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
new file mode 100644
index 0000000000..3e3c09c66e
--- /dev/null
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adadelta
+
+
+class AdadeltaOptimizerTest(xla_test.XLATestCase):
+
+ def testBasic(self):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lr, rho=rho, epsilon=epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, self.evaluate(var0))
+ self.assertAllClose(var1_init, self.evaluate(var1))
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ self.evaluate(adadelta_update)
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (
+ np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (
+ accum_update * rho + (update[step]**2) * (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype),
+ self.evaluate(slot[slot_idx]),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array([accum_update, accum_update], dtype=dtype),
+ self.evaluate(slot_update[slot_idx]),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype),
+ self.evaluate(var0),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype),
+ self.evaluate(var1),
+ rtol=1e-5)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py
new file mode 100644
index 0000000000..dc1625793a
--- /dev/null
+++ b/tensorflow/compiler/tests/adagrad_da_test.py
@@ -0,0 +1,165 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdagradDA optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad_da
+
+
+class AdagradDAOptimizerTest(xla_test.XLATestCase):
+
+ def testAdagradDAWithoutRegularizationBasic1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ # Let g to be gradient accumulator, gg to be gradient squared
+ # accumulator, T be the global step, lr is the learning rate, and k the
+ # initial gradient squared accumulator value.
+ # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})}
+ # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534
+ # similarly for others.
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAwithoutRegularizationBasic2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.904534, -1.603567]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.094821, -0.189358]), var1.eval())
+
+ def testAdagradDAWithL1(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.895489, -1.59555]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.085339, -0.17989]), var1.eval())
+
+ def testAdagradDAWithL1_L2(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ global_step = resource_variable_ops.ResourceVariable(
+ 0, dtype=dtypes.int64)
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = adagrad_da.AdagradDAOptimizer(
+ 3.0,
+ global_step,
+ initial_gradient_squared_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval())
+
+ # Run a step of AdagradDA
+ update.run()
+
+ self.assertAllCloseAccordingToType(
+ np.array([-0.046907, -0.093659]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([-0.004275, -0.009023]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/adagrad_test.py b/tensorflow/compiler/tests/adagrad_test.py
index 9a93b32164..d775850a80 100644
--- a/tensorflow/compiler/tests/adagrad_test.py
+++ b/tensorflow/compiler/tests/adagrad_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import adagrad
-class AdagradOptimizerTest(XLATestCase):
+class AdagradOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/adam_test.py b/tensorflow/compiler/tests/adam_test.py
index 3215dc36e5..03554d6933 100644
--- a/tensorflow/compiler/tests/adam_test.py
+++ b/tensorflow/compiler/tests/adam_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@@ -48,7 +48,7 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(XLATestCase):
+class AdamOptimizerTest(xla_test.XLATestCase):
def testBasic(self):
for dtype in self.float_types:
diff --git a/tensorflow/compiler/tests/adamax_test.py b/tensorflow/compiler/tests/adamax_test.py
new file mode 100644
index 0000000000..c4fdbc5974
--- /dev/null
+++ b/tensorflow/compiler/tests/adamax_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AdaMax optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import adamax
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adamax_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = np.maximum(beta2 * v, np.abs(g_t))
+ param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon))
+ return param_t, m_t, v_t
+
+
+class AdaMaxOptimizerTest(xla_test.XLATestCase):
+
+ def testBasic(self):
+ for i, dtype in enumerate(self.float_types):
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adamax.AdaMaxOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ update.run()
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval(), rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, var1.eval(), rtol=1e-2)
+ self.assertEqual("var0_%d/AdaMax:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testTensorLearningRate(self):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ variable_scope.get_variable_scope().set_use_resource(True)
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adamax.AdaMaxOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of AdaMax
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/addsign_test.py b/tensorflow/compiler/tests/addsign_test.py
new file mode 100644
index 0000000000..9ec5a964cb
--- /dev/null
+++ b/tensorflow/compiler/tests/addsign_test.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for AddSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import addsign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ def linear_decay(step):
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def addsign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ alpha=1.0,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = alpha + sign_decayed * np.sign(g_t) * np.sign(m_t)
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class AddSignTest(xla_test.XLATestCase):
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ alpha=1.0,
+ beta=0.9):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = addsign.AddSignOptimizer(
+ learning_rate=learning_rate,
+ alpha=alpha,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of AddSign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ var0_np, m0 = addsign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = addsign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ alpha=alpha,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=1e-2)
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.01, alpha=0.1, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 1e4dd32916..9cb3d04546 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
-class BinaryOpsTest(XLATestCase):
+class BinaryOpsTest(xla_test.XLATestCase):
"""Test cases for binary operators."""
def _testBinary(self, op, a, b, expected, equality_test=None):
@@ -226,6 +226,11 @@ class BinaryOpsTest(XLATestCase):
np.array([0b1, 0b101, 0b1000], dtype=dtype),
np.array([0b0, 0b101, 0b1001], dtype=dtype),
expected=np.array([0b1, 0b101, 0b1001], dtype=dtype))
+ self._testSymmetricBinary(
+ bitwise_ops.bitwise_xor,
+ np.array([0b1, 0b111, 0b1100], dtype=dtype),
+ np.array([0b0, 0b101, 0b1001], dtype=dtype),
+ expected=np.array([0b1, 0b010, 0b0101], dtype=dtype))
lhs = np.array([0, 5, 3, 14], dtype=dtype)
rhs = np.array([5, 0, 7, 11], dtype=dtype)
@@ -1216,6 +1221,24 @@ class BinaryOpsTest(XLATestCase):
np.array([1, 0], dtype=np.int32),
expected=np.array([[1, 3], [2, 4]], dtype=dtype))
+ def testConjugateTranspose(self):
+ for dtype in self.complex_types:
+ self._testBinary(
+ array_ops.conjugate_transpose,
+ np.zeros(shape=[1, 0, 4], dtype=dtype),
+ np.array([1, 2, 0], dtype=np.int32),
+ expected=np.zeros(shape=[0, 4, 1], dtype=dtype))
+ self._testBinary(
+ array_ops.conjugate_transpose,
+ np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype),
+ np.array([0, 1], dtype=np.int32),
+ expected=np.array([[1 + 1j, 2 - 2j], [3 + 3j, 4 - 4j]], dtype=dtype))
+ self._testBinary(
+ array_ops.conjugate_transpose,
+ np.array([[1 - 1j, 2 + 2j], [3 - 3j, 4 + 4j]], dtype=dtype),
+ np.array([1, 0], dtype=np.int32),
+ expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))
+
def testCross(self):
for dtype in self.float_types:
self._testBinary(
diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py
index fde9759a1c..ef4d5f6322 100644
--- a/tensorflow/compiler/tests/bucketize_op_test.py
+++ b/tensorflow/compiler/tests/bucketize_op_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -26,7 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class BucketizationOpTest(XLATestCase):
+class BucketizationOpTest(xla_test.XLATestCase):
def testInt(self):
with self.test_session() as sess:
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index 035cdea178..a4e7f75081 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -22,7 +22,7 @@ import collections
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
# TODO(srvasude): Merge this with
# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py.
-class CategoricalTest(XLATestCase):
+class CategoricalTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def output_dtypes(self):
diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py
index 1a8989d7c2..d2867278af 100644
--- a/tensorflow/compiler/tests/cholesky_op_test.py
+++ b/tensorflow/compiler/tests/cholesky_op_test.py
@@ -23,7 +23,7 @@ import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class CholeskyOpTest(XLATestCase):
+class CholeskyOpTest(xla_test.XLATestCase):
# Cholesky defined for float64, float32, complex64, complex128
# (https://www.tensorflow.org/api_docs/python/tf/cholesky)
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
index 574f82fc71..e42ebf8f9e 100644
--- a/tensorflow/compiler/tests/clustering_test.py
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -32,7 +32,7 @@ from tensorflow.python.platform import googletest
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
-class ClusteringTest(XLATestCase):
+class ClusteringTest(xla_test.XLATestCase):
def testAdd(self):
val1 = np.array([4, 3, 2, 1], dtype=np.float32)
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index f10973e19f..d9ad428147 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ConcatTest(XLATestCase):
+class ConcatTest(xla_test.XLATestCase):
def testHStack(self):
with self.test_session():
@@ -292,7 +292,7 @@ class ConcatTest(XLATestCase):
array_ops.concat([scalar, scalar, scalar], dim)
-class ConcatOffsetTest(XLATestCase):
+class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_session() as sess:
@@ -306,7 +306,7 @@ class ConcatOffsetTest(XLATestCase):
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
-class PackTest(XLATestCase):
+class PackTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_session() as sess:
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
index 62577b70ce..98d41ba7ed 100644
--- a/tensorflow/compiler/tests/conv2d_test.py
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -22,9 +22,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import test_utils
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
@@ -32,7 +34,15 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
-class Conv2DTest(XLATestCase):
+DATA_FORMATS = (
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+)
+
+
+class Conv2DTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -40,6 +50,8 @@ class Conv2DTest(XLATestCase):
strides=None,
dilations=None,
padding=None,
+ data_format_src="NHWC",
+ data_format_dst="NHWC",
expected=None):
"""Tests that tf.nn.conv2d produces the expected value.
@@ -51,8 +63,12 @@ class Conv2DTest(XLATestCase):
strides: Strides.
dilations: RHS dilations.
padding: Padding type.
+ data_format_src: Data format input is in.
+ data_format_dst: Data format verification will run and input is converted
+ to.
expected: Expected output.
"""
+
total_size_1 = np.prod(input_sizes)
total_size_2 = np.prod(filter_sizes)
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
@@ -62,6 +78,18 @@ class Conv2DTest(XLATestCase):
dilations = [1, 1]
dilations = [1] + dilations + [1]
+ # Convert between data formats.
+ expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src,
+ data_format_dst)
+ x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src,
+ data_format_dst)
+ input_sizes = test_utils.PermuteDimsBetweenDataFormats(
+ input_sizes, data_format_src, data_format_dst)
+ strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src,
+ data_format_dst)
+ dilations = test_utils.PermuteDimsBetweenDataFormats(
+ dilations, data_format_src, data_format_dst)
+
with self.test_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
@@ -71,12 +99,14 @@ class Conv2DTest(XLATestCase):
t2,
strides=strides,
padding=padding,
- data_format="NHWC",
+ data_format=data_format_dst,
dilations=dilations)
+
value = sess.run(out, {t1: x1, t2: x2})
self.assertAllClose(expected, value, 1e-3)
- def testConv2D1x1Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x1Filter(self, data_format):
expected_output = np.reshape([
30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
@@ -86,9 +116,12 @@ class Conv2DTest(XLATestCase):
filter_sizes=[1, 1, 3, 3],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Filter(self, data_format):
expected_output = np.reshape(
[2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3])
self._VerifyValues(
@@ -96,9 +129,12 @@ class Conv2DTest(XLATestCase):
filter_sizes=[2, 2, 3, 3],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Filter2x1Dilation(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Filter2x1Dilation(self, data_format):
expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]])
self._VerifyValues(
input_sizes=[1, 4, 4, 1],
@@ -106,9 +142,12 @@ class Conv2DTest(XLATestCase):
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2Filter(self, data_format):
expected_output = np.reshape([
231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0,
936.0, 1029.0
@@ -118,18 +157,24 @@ class Conv2DTest(XLATestCase):
filter_sizes=[1, 2, 3, 3],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2(self, data_format):
expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3])
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
strides=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2Same(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2Same(self, data_format):
expected_output = np.reshape(
[2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3])
self._VerifyValues(
@@ -137,47 +182,61 @@ class Conv2DTest(XLATestCase):
filter_sizes=[2, 2, 3, 3],
strides=[2, 2],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2DEmptyDilation(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DEmptyDilation(self, data_format):
self._VerifyValues(
input_sizes=[0, 2, 3, 3],
filter_sizes=[1, 1, 3, 3],
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.zeros([0, 2, 3, 3]))
- def testConv2D2x2FilterDilation(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterDilation(self, data_format):
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[2, 2, 3, 3],
strides=[1, 1],
dilations=[1, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3]))
- def testConv2D1x2FilterDilation(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterDilation(self, data_format):
self._VerifyValues(
input_sizes=[1, 2, 3, 3],
filter_sizes=[1, 2, 3, 3],
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.array([[[[231, 252, 273], [384, 423, 462]],
[[690, 765, 840], [843, 936, 1029]]]]))
- def testConv2DKernelSizeMatchesInputSizeDilation(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DKernelSizeMatchesInputSizeDilation(self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[2, 2, 1, 2],
strides=[1, 1],
dilations=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.reshape([108, 128], [1, 1, 1, 2]))
-class Conv2DBackpropInputTest(XLATestCase):
+class Conv2DBackpropInputTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -186,6 +245,8 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=None,
dilations=None,
padding=None,
+ data_format_src="NHWC",
+ data_format_dst="NHWC",
expected=None):
"""Tests that gen_nn_ops.conv2d_backprop_input produces the expected output.
@@ -198,8 +259,12 @@ class Conv2DBackpropInputTest(XLATestCase):
strides: Strides.
dilations: Dilations.
padding: Padding type.
+ data_format_src: Data format input is in.
+ data_format_dst: Data format verification will run and input is converted
+ to.
expected: Expected output.
"""
+
total_size_1 = np.prod(filter_sizes)
total_size_2 = np.prod(out_backprop_sizes)
x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes)
@@ -209,6 +274,23 @@ class Conv2DBackpropInputTest(XLATestCase):
if dilations is not None:
dilations = [1] + dilations + [1]
+ expected = np.reshape(expected, input_sizes)
+
+ # Convert between data formats.
+ expected = test_utils.ConvertBetweenDataFormats(expected, data_format_src,
+ data_format_dst)
+ x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src,
+ data_format_dst)
+ input_sizes = test_utils.PermuteDimsBetweenDataFormats(
+ input_sizes, data_format_src, data_format_dst)
+ out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats(
+ out_backprop_sizes, data_format_src, data_format_dst)
+ strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src,
+ data_format_dst)
+ if dilations is not None:
+ dilations = test_utils.PermuteDimsBetweenDataFormats(
+ dilations, data_format_src, data_format_dst)
+
with self.test_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
@@ -220,12 +302,14 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=strides,
dilations=dilations,
padding=padding,
- data_format="NHWC")
+ data_format=data_format_dst)
+
value = sess.run(out, {t1: x1, t2: x2})
self.assertAllEqual(input_sizes, value.shape)
- self.assertAllClose(expected, np.ravel(value), 1e-3)
+ self.assertAllClose(expected, value, 1e-3)
- def testConv2D1x1Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x1Filter(self, data_format):
expected_output = [
5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127,
41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71,
@@ -237,9 +321,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 4, 4, 2],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width5(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width5(self, data_format):
expected_output = [1, 2, 0, 2, 4]
self._VerifyValues(
input_sizes=[1, 1, 5, 1],
@@ -247,9 +334,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width6(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width6(self, data_format):
expected_output = [1, 2, 0, 2, 4, 0]
self._VerifyValues(
input_sizes=[1, 1, 6, 1],
@@ -257,9 +347,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width7(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width7(self, data_format):
expected_output = [1, 2, 0, 2, 4, 0, 0]
self._VerifyValues(
input_sizes=[1, 1, 7, 1],
@@ -267,9 +360,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterC1Same(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterC1Same(self, data_format):
expected_output = [1, 4, 7, 7, 23, 33]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
@@ -277,9 +373,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 2, 3, 1],
strides=[1, 1],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Filter(self, data_format):
expected_output = [
14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604,
437, 482, 527
@@ -290,9 +389,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 3],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterSame(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterSame(self, data_format):
expected_output = [
14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217,
1505, 1487, 1883, 2279
@@ -303,9 +405,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 2, 3, 3],
strides=[1, 1],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2Filter(self, data_format):
expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
@@ -313,9 +418,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 3, 2, 1],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterSame(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterSame(self, data_format):
expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
@@ -323,9 +431,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 3, 3, 1],
strides=[1, 1],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2(self, data_format):
expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12]
self._VerifyValues(
input_sizes=[1, 3, 5, 1],
@@ -333,9 +444,12 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 2, 2, 1],
strides=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2Same(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2Same(self, data_format):
expected_output = [1, 2, 2, 3, 4, 6]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
@@ -343,9 +457,13 @@ class Conv2DBackpropInputTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[2, 2],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(
+ self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 6, 1],
filter_sizes=[2, 2, 1, 1],
@@ -353,9 +471,12 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20])
- def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self, data_format):
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
@@ -363,9 +484,12 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=[1, 1],
dilations=[1, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[1, 0, 2, 3, 0, 4])
- def testConv2DEmptyBackpropInputDilation1x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DEmptyBackpropInputDilation1x2(self, data_format):
self._VerifyValues(
input_sizes=[0, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
@@ -373,9 +497,12 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=[1, 1],
dilations=[1, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.zeros([0]))
- def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self, data_format):
# The GPU version of this test is not very stable. So adjusting the
# error threshold to 1e-4.
self._VerifyValues(
@@ -385,12 +512,16 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[
14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194,
212
])
- def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(
+ self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[2, 2, 1, 2],
@@ -398,10 +529,12 @@ class Conv2DBackpropInputTest(XLATestCase):
strides=[1, 1],
dilations=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[5, 0, 11, 0, 0, 0, 17, 0, 23])
-class Conv2DBackpropFilterTest(XLATestCase):
+class Conv2DBackpropFilterTest(xla_test.XLATestCase, parameterized.TestCase):
def _VerifyValues(self,
input_sizes=None,
@@ -410,6 +543,8 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=None,
dilations=None,
padding=None,
+ data_format_src="NHWC",
+ data_format_dst="NHWC",
expected=None):
"""Tests that gen_nn_ops.conv2d_backprop_filter produces the right output.
@@ -422,6 +557,9 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides: Stride.
dilations: Dilations.
padding: Padding type.
+ data_format_src: Data format input is in.
+ data_format_dst: Data format verification will run and input is converted
+ to.
expected: Expected output.
"""
@@ -434,6 +572,23 @@ class Conv2DBackpropFilterTest(XLATestCase):
if dilations is not None:
dilations = [1] + dilations + [1]
+ expected = np.reshape(expected, filter_sizes)
+
+ # Convert between data formats.
+ x1 = test_utils.ConvertBetweenDataFormats(x1, data_format_src,
+ data_format_dst)
+ x2 = test_utils.ConvertBetweenDataFormats(x2, data_format_src,
+ data_format_dst)
+ input_sizes = test_utils.PermuteDimsBetweenDataFormats(
+ input_sizes, data_format_src, data_format_dst)
+ out_backprop_sizes = test_utils.PermuteDimsBetweenDataFormats(
+ out_backprop_sizes, data_format_src, data_format_dst)
+ strides = test_utils.PermuteDimsBetweenDataFormats(strides, data_format_src,
+ data_format_dst)
+ if dilations is not None:
+ dilations = test_utils.PermuteDimsBetweenDataFormats(
+ dilations, data_format_src, data_format_dst)
+
with self.test_session() as sess:
t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
@@ -445,13 +600,14 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=strides,
dilations=dilations,
padding=padding,
- data_format="NHWC")
+ data_format=data_format_dst)
value = sess.run(tensor, {t1: x1, t2: x2})
self.assertAllEqual(filter_sizes, value.shape)
- self.assertAllClose(expected, np.ravel(value), 1e-3)
+ self.assertAllClose(expected, value, 1e-3)
- def testConv2D1x1Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x1Filter(self, data_format):
expected_output = [8056, 8432, 8312, 8704, 8568, 8976]
self._VerifyValues(
input_sizes=[1, 4, 4, 3],
@@ -459,9 +615,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 4, 4, 2],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2Filter(self, data_format):
expected_output = [120, 141]
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
@@ -469,9 +628,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 3, 2, 1],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterDepth1(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterDepth1(self, data_format):
expected_output = [5, 8, 14, 17]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
@@ -479,9 +641,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Filter(self, data_format):
expected_output = [
17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72,
62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87,
@@ -493,9 +658,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 3],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width5(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width5(self, data_format):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 5, 1],
@@ -503,9 +671,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width6(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width6(self, data_format):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 6, 1],
@@ -513,9 +684,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x2FilterStride3Width7(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x2FilterStride3Width7(self, data_format):
expected_output = [9, 12]
self._VerifyValues(
input_sizes=[1, 1, 7, 1],
@@ -523,9 +697,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[3, 3],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x3Filter(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x3Filter(self, data_format):
expected_output = [5, 8, 11]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
@@ -533,9 +710,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[1, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x3FilterSame(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x3FilterSame(self, data_format):
expected_output = [20, 30, 20]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
@@ -543,9 +723,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 4, 1],
strides=[1, 1],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D1x3FilterSameOutbackprop2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D1x3FilterSameOutbackprop2(self, data_format):
expected_output = [7, 10, 3]
self._VerifyValues(
input_sizes=[1, 1, 4, 1],
@@ -553,9 +736,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[2, 2],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterC1Same(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterC1Same(self, data_format):
expected_output = [91, 58, 32, 17]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
@@ -563,9 +749,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 2, 3, 1],
strides=[1, 1],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2(self, data_format):
expected_output = [92, 102, 112]
self._VerifyValues(
input_sizes=[1, 3, 5, 1],
@@ -573,9 +762,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 2, 2, 1],
strides=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2FilterStride2Same(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2FilterStride2Same(self, data_format):
expected_output = [7, 2, 16, 5]
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
@@ -583,9 +775,13 @@ class Conv2DBackpropFilterTest(XLATestCase):
out_backprop_sizes=[1, 1, 2, 1],
strides=[2, 2],
padding="SAME",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=expected_output)
- def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(
+ self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 6, 1],
filter_sizes=[2, 2, 1, 1],
@@ -593,9 +789,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=[1, 1],
dilations=[2, 1],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[55, 70, 235, 250])
- def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self, data_format):
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 1],
@@ -603,9 +802,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=[1, 1],
dilations=[1, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[1, 3, 4, 6])
- def testConv2DEmptyBackpropFilterDilation1x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DEmptyBackpropFilterDilation1x2(self, data_format):
self._VerifyValues(
input_sizes=[1, 2, 3, 1],
filter_sizes=[2, 2, 1, 0],
@@ -613,9 +815,12 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=[1, 1],
dilations=[1, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=np.zeros([0]))
- def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 4, 3],
filter_sizes=[2, 2, 3, 3],
@@ -623,13 +828,17 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=[1, 1],
dilations=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[
17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78,
99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172,
239, 306, 177, 246, 315
])
- def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self):
+ @parameterized.named_parameters(*DATA_FORMATS)
+ def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(
+ self, data_format):
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[2, 2, 1, 2],
@@ -637,6 +846,8 @@ class Conv2DBackpropFilterTest(XLATestCase):
strides=[1, 1],
dilations=[2, 2],
padding="VALID",
+ data_format_src="NHWC",
+ data_format_dst=data_format,
expected=[1, 2, 3, 6, 7, 14, 9, 18])
diff --git a/tensorflow/compiler/tests/conv3d_test.py b/tensorflow/compiler/tests/conv3d_test.py
index 3bebf46511..31ee41f04f 100644
--- a/tensorflow/compiler/tests/conv3d_test.py
+++ b/tensorflow/compiler/tests/conv3d_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -33,7 +33,7 @@ from tensorflow.python.platform import googletest
# Test cloned from
# tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
-class Conv3DBackpropFilterV2GradTest(XLATestCase):
+class Conv3DBackpropFilterV2GradTest(xla_test.XLATestCase):
def testGradient(self):
with self.test_session(), self.test_scope():
@@ -66,7 +66,7 @@ class Conv3DBackpropFilterV2GradTest(XLATestCase):
# Test cloned from tensorflow/python/kernel_tests/conv3d_transpose_test.py
-class Conv3DTransposeTest(XLATestCase):
+class Conv3DTransposeTest(xla_test.XLATestCase):
def testConv3DTransposeSingleStride(self):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py
index 03d96a2cd8..98dc73e189 100644
--- a/tensorflow/compiler/tests/depthwise_conv_op_test.py
+++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -114,7 +114,7 @@ def CheckGradConfigsToTest():
yield i, f, o, s, p
-class DepthwiseConv2DTest(XLATestCase):
+class DepthwiseConv2DTest(xla_test.XLATestCase):
# This is testing that depthwise_conv2d and depthwise_conv2d_native
# produce the same results. It also tests that NCHW and NWHC
diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
index 6a46d2ec3e..154e36b10e 100644
--- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py
+++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DynamicUpdateSliceOpsTest(XLATestCase):
+class DynamicUpdateSliceOpsTest(xla_test.XLATestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index c109c27abe..edd78153b5 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import googletest
-class DynamicStitchTest(XLATestCase):
+class DynamicStitchTest(xla_test.XLATestCase):
def _AssertDynamicStitchResultIs(self, indices, data, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 4dff5f0f40..3524666499 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -31,14 +31,16 @@ from tensorflow.python.framework import ops
from tensorflow.python.layers import convolutional
from tensorflow.python.layers import pooling
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import googletest
+from tensorflow.python.training import adam
-class EagerTest(XLATestCase):
+class EagerTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_scope():
@@ -47,6 +49,21 @@ class EagerTest(XLATestCase):
product = three * five
self.assertAllEqual(15, product)
+ def testGradientTape(self):
+ with self.test_scope():
+
+ x = constant_op.constant(1.0)
+ y = constant_op.constant(10.0)
+ with backprop.GradientTape(persistent=True) as tape:
+ tape.watch(x)
+ tape.watch(y)
+ a = x + y + x * y
+ da_dx = tape.gradient(a, x)
+ da_dy = tape.gradient(a, y)
+
+ self.assertEqual(11.0, da_dx.numpy())
+ self.assertEqual(2.0, da_dy.numpy())
+
def testExecuteListOutputLen0(self):
with self.test_scope():
empty = constant_op.constant([], dtype=dtypes.float32)
@@ -160,12 +177,120 @@ class EagerTest(XLATestCase):
for _ in range(100):
values.append(var.value())
+ # The shape, shape_n, size, and rank are tested here because their
+ # execution kernels (as opposed to compilation only tf2xla kernels)
+ # are distincts from tf2xla kernels.
+
+ def testShape(self):
+ def const(value):
+ return array_ops.shape(
+ constant_op.constant(value)).numpy()
+
+ def ones(value):
+ return array_ops.shape(
+ array_ops.ones(value)).numpy()
+
+ with self.test_scope():
+ # Shapes of directly constructed tensors
+ self.assertAllEqual([], const(3))
+ self.assertAllEqual([3], const([1.0, 2.0, 3.0]))
+ self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]]))
+ self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]]))
+
+ # Shapes of tensors created by op running on device
+ # We make this distinction because directly constructed tensors
+ # are treated differently in a few places that can influence shape:
+ # - they always have on_host_tensor
+ # - they and their shapes can be cached
+ # - they end up on device via a copy, instead of as program output
+ self.assertAllEqual([], ones([]))
+ self.assertAllEqual([3], ones([3]))
+ self.assertAllEqual([2, 2], ones([2, 2]))
+ self.assertAllEqual([2, 1, 2], ones([2, 1, 2]))
+
+ def testShapeN(self):
+ with self.test_scope():
+ # Shapes of directly constructed tensors
+ shapes = array_ops.shape_n([
+ constant_op.constant(1.0),
+ constant_op.constant([1.0, 2.0, 3.0]),
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])])
+ self.assertAllEqual(
+ [[], [3], [2, 2]],
+ [x.numpy().tolist() for x in shapes])
+
+ # Shapes of tensors created by op running on device
+ shapes = array_ops.shape_n([
+ array_ops.ones([]),
+ array_ops.ones([3]),
+ array_ops.ones([2, 2])])
+ self.assertAllEqual(
+ [[], [3], [2, 2]],
+ [x.numpy().tolist() for x in shapes])
+
+ def testSize(self):
+ with self.test_scope():
+ self.assertEqual(
+ 1, array_ops.size(constant_op.constant(1.0)).numpy())
+ self.assertEqual(
+ 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy())
+ self.assertEqual(
+ 4, array_ops.size(
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
+
+ def testRank(self):
+ with self.test_scope():
+ self.assertEqual(
+ 0, array_ops.rank(constant_op.constant(1.0)).numpy())
+ self.assertEqual(
+ 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy())
+ self.assertEqual(
+ 2, array_ops.rank(
+ constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy())
+
+ def testAdam(self):
+ with self.test_scope():
+ optimizer = adam.AdamOptimizer(0.1)
+ x = resource_variable_ops.ResourceVariable(10.0)
+ with backprop.GradientTape() as tape:
+ y = x * x
+ dy_dx = tape.gradient(y, x)
+ optimizer.apply_gradients([(dy_dx, x)])
+ self.assertAlmostEqual(9.9, x.numpy(), places=3)
+
+ def testAdamSparse(self):
+ with ops.device('/cpu:0'):
+ # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates
+ # are not implemented on TPU.
+ embedding_matrix = resource_variable_ops.ResourceVariable(
+ array_ops.ones([3, 2]))
+
+ with self.test_scope():
+ with backprop.GradientTape() as tape:
+ embedding = embedding_ops.embedding_lookup(embedding_matrix, [1])
+ y = math_ops.reduce_sum(embedding)
+ dy_dx = tape.gradient(y, embedding_matrix)
+ self.assertIsInstance(dy_dx, ops.IndexedSlices)
+ optimizer = adam.AdamOptimizer(0.1)
+ # The gradient application operations will run on CPU because optimizer
+ # updates are always collocated with the variable.
+ optimizer.apply_gradients([(dy_dx, embedding_matrix)])
+
+ # This assign_add will run on CPU because when an input to an
+ # operation is a resource, this operation is placed on the resource's
+ # device by the eager runtime.
+ embedding_matrix.assign_add(array_ops.ones([3, 2]))
+
+ self.assertAllClose([[2.0, 2.0],
+ [1.9, 1.9],
+ [2.0, 2.0]], embedding_matrix.numpy())
+
-class EagerFunctionTest(XLATestCase):
+class EagerFunctionTest(xla_test.XLATestCase):
def testBasic(self):
with self.test_scope():
- matmul = function.defun(math_ops.matmul, compiled=True)
+ matmul = function.defun(math_ops.matmul)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
sq = matmul(t, t, transpose_a=True)
self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20])
@@ -187,7 +312,7 @@ class EagerFunctionTest(XLATestCase):
def model(x):
x = conv(x)
return pool(x)
- model = function.defun(model, compiled=True)
+ model = function.defun(model)
x = array_ops.ones([1, 4, 4, 1])
y = model(x)
@@ -197,7 +322,7 @@ class EagerFunctionTest(XLATestCase):
with self.test_scope():
v = resource_variable_ops.ResourceVariable(1.0)
- @function.defun(compiled=True)
+ @function.defun
def f():
return v.read_value()
@@ -212,7 +337,7 @@ class EagerFunctionTest(XLATestCase):
v.assign_add(1.0)
return v
- f = function.defun(f, compiled=True)
+ f = function.defun(f)
var = f(v)
self.assertEqual(2.0, var.numpy())
@@ -240,7 +365,7 @@ class EagerFunctionTest(XLATestCase):
d = r2 * v2
return a, b, c, d
- foo = function.defun(foo, compiled=True)
+ foo = function.defun(foo)
c1 = [0, 0]
c2 = array_ops.ones([2], dtype=dtypes.int32)
@@ -262,7 +387,7 @@ class EagerFunctionTest(XLATestCase):
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
- @function.defun(compiled=True)
+ @function.defun
def f(x):
x = v0 * v0 * x
return x
@@ -275,8 +400,26 @@ class EagerFunctionTest(XLATestCase):
self.assertEqual(75, y.numpy())
self.assertEqual(30, dy.numpy())
+ def testSliceInDefun(self):
+ with self.test_scope():
-class ExcessivePaddingTest(XLATestCase):
+ @function.defun(compiled=True)
+ def f(x, y):
+ return x[0::2, y:, ...]
+
+ x = array_ops.ones([2, 3, 4])
+ y = array_ops.ones([], dtype=dtypes.int32)
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ tape.watch(y)
+ z = f(x, y)
+ dz = tape.gradient(z, x)
+
+ self.assertAllEqual(np.ones([1, 2, 4]), z.numpy())
+ self.assertAllEqual((2, 3, 4), dz.shape.as_list())
+
+
+class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
Tensors that would normally be excessively padded when written
@@ -307,7 +450,7 @@ class ExcessivePaddingTest(XLATestCase):
def testAsFunctionInput(self):
with self.test_scope():
- @function.defun(compiled=True)
+ @function.defun
def f(x):
return math_ops.reduce_sum(x, axis=2)
@@ -318,7 +461,7 @@ class ExcessivePaddingTest(XLATestCase):
def testAsFunctionOutput(self):
with self.test_scope():
- @function.defun(compiled=True)
+ @function.defun
def f(x):
return x * constant_op.constant(100 * [[[10.0, 2.0]]])
diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py
index 0361702e7a..5529fdbb09 100644
--- a/tensorflow/compiler/tests/extract_image_patches_op_test.py
+++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ExtractImagePatches(XLATestCase):
+class ExtractImagePatches(xla_test.XLATestCase):
"""Functional tests for ExtractImagePatches op."""
def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py
index dfe9400ef0..c48ab178bf 100644
--- a/tensorflow/compiler/tests/fake_quant_ops_test.py
+++ b/tensorflow/compiler/tests/fake_quant_ops_test.py
@@ -17,14 +17,14 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import googletest
-class FakeQuantWithMinMaxArgsTest(XLATestCase):
+class FakeQuantWithMinMaxArgsTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgs operation."""
# 8 bits, wide range.
@@ -122,7 +122,7 @@ class FakeQuantWithMinMaxArgsTest(XLATestCase):
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
+class FakeQuantWithMinMaxArgsGradientTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgsGradient operation."""
# 8 bits, wide range.
@@ -223,7 +223,7 @@ class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxVarsTest(XLATestCase):
+class FakeQuantWithMinMaxVarsTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxVars operation."""
# 8 bits, wide range.
@@ -328,7 +328,7 @@ class FakeQuantWithMinMaxVarsTest(XLATestCase):
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
-class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
+class FakeQuantWithMinMaxVarsGradientTest(xla_test.XLATestCase):
"""Test cases for FakeQuantWithMinMaxVarsGradient operation."""
# 8 bits, wide range.
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index afb5fa4bb4..c64ea249ec 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -23,10 +23,11 @@ import itertools
import numpy as np
import scipy.signal as sps
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import spectral_ops
from tensorflow.python.platform import googletest
@@ -57,7 +58,7 @@ INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2))
INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2))
-class FFTTest(XLATestCase):
+class FFTTest(xla_test.XLATestCase):
def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
tf_method):
@@ -97,8 +98,11 @@ class FFTTest(XLATestCase):
ph = array_ops.placeholder(
dtypes.as_dtype(data.dtype), shape=data.shape)
out = signal.stft(ph, ws, hs)
+ grad = gradients_impl.gradients(out, ph,
+ grad_ys=array_ops.ones_like(out))
- value = sess.run(out, {ph: data})
+ # For gradients, we simply verify that they compile & execute.
+ value, _ = sess.run([out, grad], {ph: data})
self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
def testFFT(self):
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
new file mode 100644
index 0000000000..0f64cc87cd
--- /dev/null
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class FIFOQueueTest(xla_test.XLATestCase):
+
+ def testEnqueue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ enqueue_op.run()
+
+ def testEnqueueWithShape(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+ self.assertEqual(1, q.size().eval())
+
+ def testMultipleDequeues(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue([1]))
+ self.evaluate(q.enqueue([2]))
+ self.evaluate(q.enqueue([3]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ def testQueuesDontShare(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
+
+ def testEnqueueDictWithoutNames(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ with self.assertRaisesRegexp(ValueError, "must have names"):
+ q.enqueue({"a": 12.0})
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+
+ threads = [
+ self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
+ ]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i]], vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ for elem, result in zip(elems, results):
+ self.assertEqual([elem], result)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ x_val, y_val = sess.run(dequeued_t)
+ x, y = elems[i]
+ self.assertEqual([x], x_val)
+ self.assertEqual([y], y_val)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ self.assertEqual([0], q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual(1, size.eval())
+ dequeued_t.op.run()
+ self.assertEqual(0, size.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py
index 8e6407dffd..1da97fd512 100644
--- a/tensorflow/compiler/tests/ftrl_test.py
+++ b/tensorflow/compiler/tests/ftrl_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -30,7 +30,7 @@ from tensorflow.python.training import ftrl
from tensorflow.python.training import gradient_descent
-class FtrlOptimizerTest(XLATestCase):
+class FtrlOptimizerTest(xla_test.XLATestCase):
def initVariableAndGradient(self, dtype):
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index 8a3f4b0bdc..04fba44446 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class FunctionTest(XLATestCase):
+class FunctionTest(xla_test.XLATestCase):
def testFunction(self):
"""Executes a simple TensorFlow function."""
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index a80d69fa5f..132e42ac7a 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import test_utils
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
@@ -28,7 +30,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.platform import test
-class FusedBatchNormTest(XLATestCase):
+class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):
def _reference_training(self, x, scale, offset, epsilon, data_format):
if data_format != "NHWC":
@@ -63,24 +65,36 @@ class FusedBatchNormTest(XLATestCase):
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
- def testInference(self):
+ @parameterized.named_parameters(
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+ )
+ def testInference(self, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
scale_shape = [channel]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
-
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
- data_format = "NHWC"
+ epsilon = 0.001
+ data_format_src = "NHWC"
+ y_ref, mean_ref, var_ref = self._reference_training(
+ x_val, scale_val, offset_val, epsilon, data_format_src)
+
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
- t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+ x_val_converted = test_utils.ConvertBetweenDataFormats(
+ x_val, data_format_src, data_format)
+ y_ref_converted = test_utils.ConvertBetweenDataFormats(
+ y_ref, data_format_src, data_format)
+
+ t_val = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
- epsilon = 0.001
- y_ref, mean_ref, var_ref = self._reference_training(
- x_val, scale_val, offset_val, epsilon, data_format)
y, mean, variance = nn.fused_batch_norm(
t_val,
scale,
@@ -91,31 +105,39 @@ class FusedBatchNormTest(XLATestCase):
data_format=data_format,
is_training=False)
- y_val, _, _ = sess.run(
- [y, mean,
- variance], {t_val: x_val,
- scale: scale_val,
- offset: offset_val})
- self.assertAllClose(y_val, y_ref, atol=1e-3)
+ y_val, _, _ = sess.run([y, mean, variance], {
+ t_val: x_val_converted,
+ scale: scale_val,
+ offset: offset_val
+ })
+ self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
- def _testLearning(self, use_gradient_checker):
+ def _testLearning(self, use_gradient_checker, data_format):
channel = 3
x_shape = [2, 2, 6, channel]
scale_shape = [channel]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
-
offset_val = np.random.random_sample(scale_shape).astype(np.float32)
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)
- data_format = "NHWC"
+ epsilon = 0.001
+ data_format_src = "NHWC"
+ y_ref, mean_ref, var_ref = self._reference_training(
+ x_val, scale_val, offset_val, epsilon, data_format_src)
+
with self.test_session() as sess, self.test_scope():
# To avoid constant folding
- t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+ x_val_converted = test_utils.ConvertBetweenDataFormats(
+ x_val, data_format_src, data_format)
+ y_ref_converted = test_utils.ConvertBetweenDataFormats(
+ y_ref, data_format_src, data_format)
+
+ t_val = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
- epsilon = 0.001
y, mean, var = nn.fused_batch_norm(
t_val,
scale,
@@ -129,33 +151,50 @@ class FusedBatchNormTest(XLATestCase):
if use_gradient_checker:
err = gradient_checker.compute_gradient_error(
t_val,
- x_shape,
+ x_val_converted.shape,
y,
- x_shape,
+ x_val_converted.shape,
extra_feed_dict={
- t_val: x_val,
+ t_val: x_val_converted,
scale: scale_val,
offset: offset_val
})
self.assertLess(err, 1e-3)
- y_val, mean_val, var_val = sess.run(
- [y, mean, var], {t_val: x_val,
- scale: scale_val,
- offset: offset_val})
- y_ref, mean_ref, var_ref = self._reference_training(
- x_val, scale_val, offset_val, epsilon, data_format)
+ y_val, mean_val, var_val = sess.run([y, mean, var], {
+ t_val: x_val_converted,
+ scale: scale_val,
+ offset: offset_val
+ })
self.assertAllClose(mean_val, mean_ref, atol=1e-3)
- self.assertAllClose(y_val, y_ref, atol=1e-3)
+ self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
self.assertAllClose(var_val, var_ref, atol=1e-3)
- def testLearning(self):
- self._testLearning(False)
+ @parameterized.named_parameters(
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+ )
+ def testLearning(self, data_format):
+ self._testLearning(False, data_format)
- def testLearningWithGradientChecker(self):
- self._testLearning(True)
+ @parameterized.named_parameters(
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+ )
+ def testLearningWithGradientChecker(self, data_format):
+ self._testLearning(True, data_format)
- def testGradientTraining(self):
+ @parameterized.named_parameters(
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+ )
+ def testGradientTraining(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
channel = 3
@@ -167,33 +206,48 @@ class FusedBatchNormTest(XLATestCase):
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)
epsilon = 0.001
+ data_format_src = "NHWC"
+ grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
+ x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)
with self.test_session() as sess, self.test_scope():
- grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
- x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+ grad_val_converted = test_utils.ConvertBetweenDataFormats(
+ grad_val, data_format_src, data_format)
+ x_val_converted = test_utils.ConvertBetweenDataFormats(
+ x_val, data_format_src, data_format)
+ grad_x_ref_converted = test_utils.ConvertBetweenDataFormats(
+ grad_x_ref, data_format_src, data_format)
+
+ grad = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="grad")
+ x = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="x")
mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
- grad, x, scale, mean, var, data_format="NHWC", is_training=True)
+ grad, x, scale, mean, var, data_format=data_format, is_training=True)
grad_x_val, grad_scale_val, grad_offset_val = sess.run(
[grad_x, grad_scale, grad_offset], {
- grad: grad_val,
- x: x_val,
+ grad: grad_val_converted,
+ x: x_val_converted,
mean: mean_val,
var: var_val,
scale: scale_val
})
- grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
- x_val, grad_val, scale_val, mean_val, var_val, epsilon, "NHWC")
-
- self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
+ self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2)
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
- def testGradientInference(self):
+ @parameterized.named_parameters(
+ ("_data_format_NHWC", "NHWC"),
+ ("_data_format_NCHW", "NCHW"),
+ ("_data_format_HWNC", "HWNC"),
+ ("_data_format_HWCN", "HWCN"),
+ )
+ def testGradientInference(self, data_format):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
channel = 3
@@ -204,33 +258,47 @@ class FusedBatchNormTest(XLATestCase):
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)
+ data_format_src = "NHWC"
with self.test_session() as sess, self.test_scope():
- grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
- x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
+ grad_val_converted = test_utils.ConvertBetweenDataFormats(
+ grad_val, data_format_src, data_format)
+ x_val_converted = test_utils.ConvertBetweenDataFormats(
+ x_val, data_format_src, data_format)
+
+ grad = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="grad")
+ x = array_ops.placeholder(
+ np.float32, shape=x_val_converted.shape, name="x")
mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
with self.test_scope():
out = gen_nn_ops.fused_batch_norm_grad(
- grad, x, scale, mean, var, data_format="NHWC", is_training=False)
+ grad,
+ x,
+ scale,
+ mean,
+ var,
+ data_format=data_format,
+ is_training=False)
grad_x, grad_scale, grad_offset, _, _ = out
ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
- grad, x, scale, mean, var, data_format="NHWC", is_training=False)
+ grad, x, scale, mean, var, data_format=data_format, is_training=False)
grad_x_val, grad_scale_val, grad_offset_val, = sess.run(
[grad_x, grad_scale, grad_offset], {
- grad: grad_val,
- x: x_val,
+ grad: grad_val_converted,
+ x: x_val_converted,
mean: mean_val,
var: var_val,
scale: scale_val
})
grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run(
[ref_x, ref_scale, ref_offset], {
- grad: grad_val,
- x: x_val,
+ grad: grad_val_converted,
+ x: x_val_converted,
mean: mean_val,
var: var_val,
scale: scale_val
diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py
index 9378b1db72..23b0aed34f 100644
--- a/tensorflow/compiler/tests/gather_nd_op_test.py
+++ b/tensorflow/compiler/tests/gather_nd_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GatherNdTest(XLATestCase):
+class GatherNdTest(xla_test.XLATestCase):
def _runGather(self, params, indices):
with self.test_session():
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 1a8c451911..e9c8ef7c91 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -136,6 +136,20 @@ class GatherTest(xla_test.XLATestCase):
self.assertAllEqual(
[[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]}))
+ def testGatherPrecision(self):
+ with self.test_session() as session, self.test_scope():
+ data = np.array([[0, 0, 0, 0], [0, 2 * (1 + np.exp2(-8)), 0, 0],
+ [0, 0, 0, 0], [0.015789, 0.0985, 0.55789, 0.3842]])
+ indices = np.array([1, 2, 3, 1])
+ dtype = dtypes.float32
+ params_np = self._buildParams(data, dtype)
+ params = array_ops.placeholder(dtype=dtype)
+ indices_tf = constant_op.constant(indices)
+ gather_t = array_ops.gather(params, indices_tf)
+ gather_val = session.run(gather_t, feed_dict={params: params_np})
+ np_val = params_np[indices]
+ self.assertAllEqual(np_val, gather_val)
+
class GatherBenchmark(test.Benchmark):
"""Microbenchmarks for the gather op."""
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 42e637734c..8b01ef96db 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -25,7 +25,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def GenerateNumpyRandomRGB(shape):
return np.random.randint(0, 256, shape) / 256.
-class RGBToHSVTest(XLATestCase):
+class RGBToHSVTest(xla_test.XLATestCase):
def testBatch(self):
# Build an arbitrary RGB image
@@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase):
join1 = array_ops.stack(split1)
join2 = array_ops.stack(split2)
batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2],
- {
- batch0: inp
- })
+ {batch0: inp})
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1)
@@ -106,7 +104,7 @@ class RGBToHSVTest(XLATestCase):
self.assertAllCloseAccordingToType(hsv_tf, hsv_np)
-class AdjustContrastTest(XLATestCase):
+class AdjustContrastTest(xla_test.XLATestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
with self.test_session():
@@ -170,7 +168,7 @@ class AdjustContrastTest(XLATestCase):
self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5)
-class AdjustHueTest(XLATestCase):
+class AdjustHueTest(xla_test.XLATestCase):
def testAdjustNegativeHue(self):
x_shape = [2, 2, 3]
@@ -305,7 +303,7 @@ class AdjustHueTest(XLATestCase):
self._adjustHueTf(x_np, delta_h)
-class AdjustSaturationTest(XLATestCase):
+class AdjustSaturationTest(xla_test.XLATestCase):
def _adjust_saturation(self, image, saturation_factor):
image = ops.convert_to_tensor(image, name="image")
@@ -401,18 +399,17 @@ class AdjustSaturationTest(XLATestCase):
x = array_ops.placeholder(dtypes.float32, shape=x_shape)
with self.test_scope():
y_fused = self._adjust_saturation(x,
- scale).eval(feed_dict={
- x: x_np
- })
+ scale).eval(feed_dict={x: x_np})
self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5)
-class ResizeBilinearTest(XLATestCase):
+class ResizeBilinearTest(xla_test.XLATestCase):
def _assertForwardOpMatchesExpected(self,
image_np,
target_shape,
- expected=None):
+ expected=None,
+ large_tolerance=False):
if expected is None:
self.fail("expected must be specified")
with self.test_session() as sess, self.test_scope():
@@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase):
resized = gen_image_ops.resize_bilinear(
image, target_shape, align_corners=True)
out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]})
- self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out)
+ if large_tolerance:
+ self.assertAllClose(
+ expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1)
+ else:
+ self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out)
def _assertBackwardOpMatchesExpected(self,
grads_np,
@@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase):
[[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]],
dtype=np.float32))
+ def testAlignCorners4x4To8x8(self):
+ self._assertForwardOpMatchesExpected(
+ (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array(
+ [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8],
+ expected=3 *
+ (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
+ [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)),
+ large_tolerance=True)
+
+ def testAlignCorners8x8To16x16(self):
+ self._assertForwardOpMatchesExpected(
+ (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array(
+ [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0,
+ [16, 16],
+ expected=7 * (np.array(
+ [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],
+ dtype=np.float32) + np.array(
+ [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11],
+ [12], [13], [14], [15]],
+ dtype=np.float32)),
+ large_tolerance=True)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
index 69bd8f7230..253b45902f 100644
--- a/tensorflow/compiler/tests/lrn_ops_test.py
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -22,7 +22,7 @@ import copy
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -36,7 +36,7 @@ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
# Local response normalization tests. The forward tests are copied from
# tensorflow/python/kernel_tests/lrn_op_test.py
-class LRNTest(XLATestCase):
+class LRNTest(xla_test.XLATestCase):
def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
beta=0.5):
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 29394f9ea5..0d9f99f8a6 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -19,14 +19,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MatrixBandPartTest(XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase):
def _testMatrixBandPart(self, dtype, shape):
with self.test_session():
diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
index 5819b2bf2b..2bb8a97bda 100644
--- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
@@ -22,7 +22,7 @@ import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -35,7 +35,7 @@ def MakePlaceholder(x):
return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
-class MatrixTriangularSolveOpTest(XLATestCase):
+class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
# MatrixTriangularSolve defined for float64, float32, complex64, complex128
# (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve)
diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py
index af9394e7d7..c2592c54cf 100644
--- a/tensorflow/compiler/tests/momentum_test.py
+++ b/tensorflow/compiler/tests/momentum_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import momentum as momentum_lib
-class MomentumOptimizerTest(XLATestCase):
+class MomentumOptimizerTest(xla_test.XLATestCase):
def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
var += accum * lr * momentum
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
index e4843b169b..da08225e9f 100644
--- a/tensorflow/compiler/tests/nary_ops_test.py
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -22,14 +22,14 @@ import unittest
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class NAryOpsTest(XLATestCase):
+class NAryOpsTest(xla_test.XLATestCase):
def _testNAry(self, op, args, expected, equality_fn=None):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index 6f588d8ab5..2f9122645d 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import googletest
-class NullaryOpsTest(XLATestCase):
+class NullaryOpsTest(xla_test.XLATestCase):
def _testNullary(self, op, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
index 5e6d1313bd..a75d99189b 100644
--- a/tensorflow/compiler/tests/placeholder_test.py
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
-class PlaceholderTest(XLATestCase):
+class PlaceholderTest(xla_test.XLATestCase):
def test_placeholder_with_default_default(self):
with self.test_session() as sess, self.test_scope():
diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py
index 4eed903963..17f860db61 100644
--- a/tensorflow/compiler/tests/pooling_ops_3d_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,7 +41,7 @@ def _AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding):
padding=padding)
-class Pooling3DTest(XLATestCase):
+class Pooling3DTest(xla_test.XLATestCase):
def _VerifyValues(self, pool_func, input_sizes, window, strides, padding,
expected):
@@ -187,8 +187,14 @@ class Pooling3DTest(XLATestCase):
padding="VALID",
expected=[29.5, 32.5, 50.5, 53.5, 176.5, 179.5, 197.5, 200.5])
- def _VerifyGradient(self, pool_func, pool_grad_func, input_sizes, ksize,
- strides, padding):
+ def _VerifyGradient(self,
+ pool_func,
+ pool_grad_func,
+ input_sizes,
+ ksize,
+ strides,
+ padding,
+ pool_grad_grad_func=None):
"""Verifies the output values of the pooling gradient function.
Args:
@@ -198,6 +204,7 @@ class Pooling3DTest(XLATestCase):
ksize: The kernel size dimensions
strides: The stride dimensions
padding: Padding type.
+ pool_grad_grad_func: Second-order gradient function, if available.
"""
ksize = [1] + ksize + [1]
strides = [1] + strides + [1]
@@ -218,6 +225,8 @@ class Pooling3DTest(XLATestCase):
output_gradient_vals = np.arange(
1, output_vals.size + 1, dtype=np.float32)
output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
+ output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32)
+ output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape)
# Use the Tensorflow CPU pooling gradient to compute the expected input
# gradients.
@@ -236,6 +245,22 @@ class Pooling3DTest(XLATestCase):
{inputs: x,
output_gradients: output_gradient_vals})
+ output_grad_gradients = array_ops.placeholder(
+ dtypes.float32, shape=expected_input_gradient_vals.shape)
+ if pool_grad_grad_func is not None:
+ expected_grad_gradients = pool_grad_grad_func(
+ inputs,
+ outputs,
+ output_grad_gradients,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format="NDHWC")
+ expected_grad_gradients_vals = sess.run(expected_grad_gradients, {
+ inputs: x,
+ output_grad_gradients: output_grad_grad_vals
+ })
+
# Run the gradient op on the XLA device
with self.test_scope():
outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
@@ -246,6 +271,16 @@ class Pooling3DTest(XLATestCase):
ksize=ksize,
strides=strides,
padding=padding)
+ if pool_grad_grad_func is not None:
+ actual_grad_gradients = pool_grad_grad_func(
+ inputs,
+ outputs,
+ output_grad_gradients,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format="NDHWC")
+
actual = sess.run(actual_input_gradients, {
inputs: x,
outputs: output_vals,
@@ -260,6 +295,22 @@ class Pooling3DTest(XLATestCase):
atol=1e-6)
self.assertShapeEqual(actual, inputs)
+ if pool_grad_grad_func is not None:
+ actual_grad_gradients_vals = sess.run(
+ actual_grad_gradients, {
+ inputs: x,
+ outputs: output_vals,
+ output_grad_gradients: output_grad_grad_vals
+ })
+
+ # Compare the Tensorflow and XLA results.
+ self.assertAllClose(
+ expected_grad_gradients_vals,
+ actual_grad_gradients_vals,
+ rtol=1e-4,
+ atol=1e-6)
+ self.assertShapeEqual(actual_grad_gradients_vals, outputs)
+
def testMaxPoolGradValidPadding1_1_3d(self):
self._VerifyGradient(
nn_ops.max_pool3d,
@@ -267,7 +318,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[1, 3, 3, 3, 1],
ksize=[1, 1, 1],
strides=[1, 1, 1],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradValidPadding2_1_6_3d(self):
self._VerifyGradient(
@@ -276,9 +328,13 @@ class Pooling3DTest(XLATestCase):
input_sizes=[2, 3, 3, 6, 3],
ksize=[2, 2, 2],
strides=[1, 1, 1],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradValidPadding2_1_7_3d(self):
+ # TODO(b/73062247): the bfloat16 implementation of MaxPool3DGradGrad does
+ # not have enough precision for this test case to pass if
+ # pool_grad_grad_func is passed.
self._VerifyGradient(
nn_ops.max_pool3d,
gen_nn_ops.max_pool3d_grad,
@@ -294,7 +350,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[2, 2, 2, 2, 3],
ksize=[2, 2, 2],
strides=[2, 2, 2],
- padding="VALID")
+ padding="VALID",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradSamePadding1_1_3d(self):
self._VerifyGradient(
@@ -303,7 +360,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[2, 3, 2, 4, 1],
ksize=[1, 1, 1],
strides=[1, 1, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradSamePadding2_1_3d(self):
self._VerifyGradient(
@@ -312,7 +370,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[2, 3, 2, 4, 1],
ksize=[2, 2, 2],
strides=[1, 1, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradSamePadding2_2_3d(self):
self._VerifyGradient(
@@ -321,7 +380,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[2, 5, 2, 4, 3],
ksize=[2, 2, 2],
strides=[2, 2, 2],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testMaxPoolGradSamePadding3_1_3d(self):
self._VerifyGradient(
@@ -330,7 +390,8 @@ class Pooling3DTest(XLATestCase):
input_sizes=[1, 3, 3, 7, 1],
ksize=[3, 3, 3],
strides=[1, 1, 1],
- padding="SAME")
+ padding="SAME",
+ pool_grad_grad_func=gen_nn_ops.max_pool3d_grad_grad)
def testAvgPoolGradValidPadding1_1_3d(self):
self._VerifyGradient(
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
index fe270af3d6..9fc94752ea 100644
--- a/tensorflow/compiler/tests/pooling_ops_test.py
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -69,7 +69,7 @@ def GetTestConfigs():
return test_configs
-class PoolingTest(XLATestCase):
+class PoolingTest(xla_test.XLATestCase):
def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
data_format, expected):
@@ -288,7 +288,7 @@ class PoolingTest(XLATestCase):
expected=expected_output)
-class PoolGradTest(XLATestCase):
+class PoolGradTest(xla_test.XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
diff --git a/tensorflow/compiler/tests/powersign_test.py b/tensorflow/compiler/tests/powersign_test.py
new file mode 100644
index 0000000000..5fa7706d72
--- /dev/null
+++ b/tensorflow/compiler/tests/powersign_test.py
@@ -0,0 +1,142 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for PowerSign."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.contrib.opt.python.training import powersign
+from tensorflow.contrib.opt.python.training import sign_decay
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def py_linear_decay_fn(decay_steps):
+ def linear_decay(step):
+ step = min(step, decay_steps)
+ return float(decay_steps - step) / decay_steps
+ return linear_decay
+
+
+def powersign_update_numpy(params,
+ g_t,
+ m,
+ lr,
+ base=math.e,
+ beta=0.9,
+ py_sign_decay_fn=None,
+ t=None):
+ m_t = beta * m + (1 - beta) * g_t
+ if py_sign_decay_fn is None:
+ sign_decayed = 1.0
+ else:
+ sign_decayed = py_sign_decay_fn(t-1)
+ multiplier = base ** (sign_decayed * np.sign(g_t) * np.sign(m_t))
+ params_t = params - lr * multiplier * g_t
+ return params_t, m_t
+
+
+class PowerSignTest(xla_test.XLATestCase):
+
+ def _testDense(self,
+ learning_rate=0.1,
+ sign_decay_fn=None,
+ py_sign_decay_fn=None,
+ base=math.e,
+ beta=0.9):
+ for dtype in self.float_types:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ m0, m1 = 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ global_step = resource_variable_ops.ResourceVariable(0, trainable=False)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = powersign.PowerSignOptimizer(
+ learning_rate=learning_rate,
+ base=base,
+ beta=beta,
+ sign_decay_fn=sign_decay_fn,
+ )
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]),
+ global_step=global_step)
+ neg_update = opt.apply_gradients(zip([-grads0, -grads1], [var0, var1]),
+ global_step=global_step)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 7 steps of powersign
+ # first 4 steps with positive gradient
+ # last 3 steps with negative gradient (sign(gm) should be -1)
+ for t in range(1, 8):
+ if t < 5:
+ update.run()
+ else:
+ neg_update.run()
+
+ var0_np, m0 = powersign_update_numpy(
+ var0_np,
+ grads0_np if t < 5 else -grads0_np,
+ m0,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+ var1_np, m1 = powersign_update_numpy(
+ var1_np,
+ grads1_np if t < 5 else -grads1_np,
+ m1,
+ learning_rate,
+ base=base,
+ beta=beta,
+ py_sign_decay_fn=py_sign_decay_fn,
+ t=t,
+ )
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testDense(self):
+ decay_steps = 10
+ sign_decay_fn = sign_decay.get_linear_decay_fn(decay_steps)
+ py_sign_decay_fn = py_linear_decay_fn(decay_steps)
+ self._testDense()
+ self._testDense(learning_rate=0.1, base=10.0, beta=0.8)
+ self._testDense(
+ sign_decay_fn=sign_decay_fn, py_sign_decay_fn=py_sign_decay_fn)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/proximal_adagrad_test.py b/tensorflow/compiler/tests/proximal_adagrad_test.py
new file mode 100644
index 0000000000..cde87db63d
--- /dev/null
+++ b/tensorflow/compiler/tests/proximal_adagrad_test.py
@@ -0,0 +1,172 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Proximal Adagrad optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
+from tensorflow.python.training import proximal_adagrad
+
+
+class ProximalAdagradOptimizerTest(xla_test.XLATestCase):
+
+ def testResourceProximalAdagradwithoutRegularization(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+ opt = proximal_adagrad.ProximalAdagradOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run 3 steps Proximal Adagrad.
+ for _ in range(3):
+ update.run()
+
+ self.assertAllClose(np.array([-2.60260963, -4.29698515]), var0.eval())
+ self.assertAllClose(np.array([-0.28432083, -0.56694895]), var1.eval())
+ opt_vars = opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var1._shared_name)
+ self.assertEqual(2, len(opt_vars))
+
+ def testProximalAdagradwithoutRegularization2(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_adagrad.ProximalAdagradOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 3 steps Proximal Adagrad.
+ for _ in range(3):
+ update.run()
+ self.assertAllClose(np.array([-1.60261, -2.296985]), var0.eval())
+ self.assertAllClose(np.array([3.715679, 2.433051]), var1.eval())
+
+ def testProximalAdagradWithL1(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_adagrad.ProximalAdagradOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps Proximal Adagrad
+ for _ in range(10):
+ update.run()
+ self.assertAllClose(np.array([-6.663634, -9.190331]), var0.eval())
+ self.assertAllClose(np.array([2.959304, 1.029232]), var1.eval())
+
+ def testProximalAdagradWithL1_L2(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_adagrad.ProximalAdagradOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps Proximal Adagrad.
+ for _ in range(10):
+ update.run()
+
+ self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
+ self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())
+
+ def applyOptimizer(self, opt, steps=5):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run ProximalAdagrad for a few steps
+ for _ in range(steps):
+ update.run()
+
+ return var0.eval(), var1.eval()
+
+ def testEquivAdagradwithoutRegularization(self):
+ with self.test_session(), self.test_scope():
+ val0, val1 = self.applyOptimizer(
+ proximal_adagrad.ProximalAdagradOptimizer(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0))
+
+ with self.test_session(), self.test_scope():
+ val2, val3 = self.applyOptimizer(
+ adagrad.AdagradOptimizer(
+ 3.0, initial_accumulator_value=0.1))
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/proximal_gradient_descent_test.py b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
new file mode 100644
index 0000000000..11eb768711
--- /dev/null
+++ b/tensorflow/compiler/tests/proximal_gradient_descent_test.py
@@ -0,0 +1,156 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Proximal Gradient Descent optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import proximal_gradient_descent
+
+
+class ProximalGradientDescentOptimizerTest(xla_test.XLATestCase):
+
+ def testResourceProximalGradientDescentwithoutRegularization(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+ opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
+ 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([0.0, 0.0], var0.eval())
+ self.assertAllClose([0.0, 0.0], var1.eval())
+
+ # Run 3 steps Proximal Gradient Descent.
+ for _ in range(3):
+ update.run()
+
+ self.assertAllClose(np.array([-0.9, -1.8]), var0.eval())
+ self.assertAllClose(np.array([-0.09, -0.18]), var1.eval())
+
+ def testProximalGradientDescentwithoutRegularization2(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
+ 3.0, l1_regularization_strength=0.0, l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 3 steps Proximal Gradient Descent
+ for _ in range(3):
+ update.run()
+
+ self.assertAllClose(np.array([0.1, 0.2]), var0.eval())
+ self.assertAllClose(np.array([3.91, 2.82]), var1.eval())
+
+ def testProximalGradientDescentWithL1(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
+ 3.0, l1_regularization_strength=0.001, l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps proximal gradient descent.
+ for _ in range(10):
+ update.run()
+
+ self.assertAllClose(np.array([-1.988, -3.988001]), var0.eval())
+ self.assertAllClose(np.array([3.67, 2.37]), var1.eval())
+
+ def testProximalGradientDescentWithL1_L2(self):
+ with self.test_session(), self.test_scope():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([4.0, 3.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ opt = proximal_gradient_descent.ProximalGradientDescentOptimizer(
+ 3.0, l1_regularization_strength=0.001, l2_regularization_strength=2.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([4.0, 3.0], var1.eval())
+
+ # Run 10 steps Proximal Gradient Descent
+ for _ in range(10):
+ update.run()
+
+ self.assertAllClose(np.array([-0.0495, -0.0995]), var0.eval())
+ self.assertAllClose(np.array([-0.0045, -0.0095]), var1.eval())
+
+ def applyOptimizer(self, opt, steps=5):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0])
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0])
+ grads0 = constant_op.constant([0.1, 0.2])
+ grads1 = constant_op.constant([0.01, 0.02])
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run ProximalAdagrad for a few steps
+ for _ in range(steps):
+ update.run()
+
+ return var0.eval(), var1.eval()
+
+ def testEquivGradientDescentwithoutRegularization(self):
+ with self.test_session(), self.test_scope():
+ val0, val1 = self.applyOptimizer(
+ proximal_gradient_descent.ProximalGradientDescentOptimizer(
+ 3.0,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0))
+
+ with self.test_session(), self.test_scope():
+ val2, val3 = self.applyOptimizer(
+ gradient_descent.GradientDescentOptimizer(3.0))
+
+ self.assertAllClose(val0, val2)
+ self.assertAllClose(val1, val3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/qr_op_test.py b/tensorflow/compiler/tests/qr_op_test.py
new file mode 100644
index 0000000000..93752a21db
--- /dev/null
+++ b/tensorflow/compiler/tests/qr_op_test.py
@@ -0,0 +1,112 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class QrOpTest(xla_test.XLATestCase, parameterized.TestCase):
+
+ def AdjustedNorm(self, x):
+ """Computes the norm of matrices in 'x', adjusted for dimension and type."""
+ norm = np.linalg.norm(x, axis=(-2, -1))
+ return norm / (max(x.shape[-2:]) * np.finfo(x.dtype).eps)
+
+ def CompareOrthogonal(self, x, y, rank):
+ # We only compare the first 'rank' orthogonal vectors since the
+ # remainder form an arbitrary orthonormal basis for the
+ # (row- or column-) null space, whose exact value depends on
+ # implementation details. Notice that since we check that the
+ # matrices of singular vectors are unitary elsewhere, we do
+ # implicitly test that the trailing vectors of x and y span the
+ # same space.
+ x = x[..., 0:rank]
+ y = y[..., 0:rank]
+ # Q is only unique up to sign (complex phase factor for complex matrices),
+ # so we normalize the sign first.
+ sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
+ phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
+ x *= phases
+ self.assertTrue(np.all(self.AdjustedNorm(x - y) < 30.0))
+
+ def CheckApproximation(self, a, q, r):
+ # Tests that a ~= q*r.
+ precision = self.AdjustedNorm(a - np.matmul(q, r))
+ self.assertTrue(np.all(precision < 5.0))
+
+ def CheckUnitary(self, x):
+ # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
+ xx = math_ops.matmul(x, x, adjoint_a=True)
+ identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
+ precision = self.AdjustedNorm(xx.eval() - identity.eval())
+ self.assertTrue(np.all(precision < 5.0))
+
+ def _test(self, dtype, shape, full_matrices):
+ np.random.seed(1)
+ x_np = np.random.uniform(
+ low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype)
+
+ with self.test_session() as sess:
+ x_tf = array_ops.placeholder(dtype)
+ with self.test_scope():
+ q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices)
+ q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
+
+ q_dims = q_tf_val.shape
+ np_q = np.ndarray(q_dims, dtype)
+ np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
+ new_first_dim = np_q_reshape.shape[0]
+
+ x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
+ for i in range(new_first_dim):
+ if full_matrices:
+ np_q_reshape[i, :, :], _ = np.linalg.qr(
+ x_reshape[i, :, :], mode="complete")
+ else:
+ np_q_reshape[i, :, :], _ = np.linalg.qr(
+ x_reshape[i, :, :], mode="reduced")
+ np_q = np.reshape(np_q_reshape, q_dims)
+ self.CompareOrthogonal(np_q, q_tf_val, min(shape[-2:]))
+ self.CheckApproximation(x_np, q_tf_val, r_tf_val)
+ self.CheckUnitary(q_tf_val)
+
+ SIZES = [1, 2, 5, 10, 32, 100, 300]
+ DTYPES = [np.float32]
+ PARAMS = itertools.product(SIZES, SIZES, DTYPES)
+
+ @parameterized.parameters(*PARAMS)
+ def testQR(self, rows, cols, dtype):
+ # TODO(b/111317468): implement full_matrices=False, test other types.
+ for full_matrices in [True]:
+ # Only tests the (3, 2) case for small numbers of rows/columns.
+ for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
+ self._test(dtype, batch_dims + (rows, cols), full_matrices)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index d6c93088d4..14c5e7a975 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -18,15 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import math
+
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import googletest
-class RandomOpsTest(XLATestCase):
+class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
@@ -47,18 +52,18 @@ class RandomOpsTest(XLATestCase):
# We use exact equality here. If the random-number generator is producing
# deterministic output, all three outputs will be bitwise identical.
self.assertTrue((not np.array_equal(y, z)) or
- (not np.array_equal(z, w)) or
- (not np.array_equal(y, w)))
+ (not np.array_equal(z, w)) or (not np.array_equal(y, w)))
def testRandomUniformIsNotConstant(self):
+
def rng(dtype):
- return random_ops.random_uniform(shape=[2], dtype=dtype,
- maxval=1000000)
+ return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000)
for dtype in self._random_types():
self._testRngIsNotConstant(rng, dtype)
def testRandomNormalIsNotConstant(self):
+
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
@@ -70,22 +75,90 @@ class RandomOpsTest(XLATestCase):
for dtype in self._random_types():
with self.test_session() as sess:
with self.test_scope():
- x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2,
- maxval=33)
+ x = random_ops.random_uniform(
+ shape=[1000], dtype=dtype, minval=-2, maxval=33)
y = sess.run(x)
self.assertTrue((y >= -2).sum() == 1000)
self.assertTrue((y < 33).sum() == 1000)
+ def testTruncatedNormalIsNotConstant(self):
+
+ def rng(dtype):
+ return random_ops.truncated_normal(shape=[2], dtype=dtype)
+
+ # TODO(b/34339814): implement inverse erf support for non-F32 types.
+ self._testRngIsNotConstant(rng, dtypes.float32)
+
def testTruncatedNormalIsInRange(self):
- count = 10000
+ count = 10000000
# TODO(b/34339814): implement inverse erf support for non-F32 types.
for dtype in [dtypes.float32]:
with self.test_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype, seed=42)
y = sess.run(x)
- self.assertTrue((y >= -2).sum() == count)
- self.assertTrue((y <= 2).sum() == count)
+
+ def normal_cdf(x):
+ return .5 * math.erfc(-x / math.sqrt(2))
+
+ def normal_pdf(x):
+ return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
+
+ def probit(x, sess=sess):
+ return sess.run(special_math.ndtri(x))
+
+ a = -2.
+ b = 2.
+ mu = 0.
+ sigma = 1.
+
+ alpha = (a - mu) / sigma
+ beta = (b - mu) / sigma
+ z = normal_cdf(beta) - normal_cdf(alpha)
+
+ self.assertTrue((y >= a).sum() == count)
+ self.assertTrue((y <= b).sum() == count)
+
+ # For more information on these calculations, see:
+ # Burkardt, John. "The Truncated Normal Distribution".
+ # Department of Scientific Computing website. Florida State University.
+ expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
+ actual_mean = np.mean(y)
+ self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+
+ expected_median = mu + probit(
+ (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
+ actual_median = np.median(y)
+ self.assertAllClose(actual_median, expected_median, atol=8e-4)
+
+ expected_variance = sigma**2 * (1 + (
+ (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
+ (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
+ actual_variance = np.var(y)
+ self.assertAllClose(actual_variance, expected_variance, rtol=3e-4)
+
+ def testShuffle1d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = math_ops.range(1 << 16)
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = range(1 << 16)
+ # Compare sets to avoid randomness behavior changes but make sure still
+ # have all the values.
+ self.assertAllEqual(set(result), set(expected))
+
+ def testShuffle2d(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = array_ops.diag(math_ops.range(20))
+ shuffle = random_ops.random_shuffle(x)
+ result = sess.run(shuffle)
+ expected = np.diag(range(20)).flatten()
+ # Compare sets to avoid randomness behavior changes but make sure still
+ # have all the values.
+ self.assertAllEqual(len(result.flatten()), len(expected))
+ self.assertAllEqual(set(result.flatten()), set(expected))
if __name__ == '__main__':
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 7420724bdb..cea2ec816f 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -22,7 +22,7 @@ import functools
import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -30,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class ReduceOpsTest(XLATestCase):
+class ReduceOpsTest(xla_test.XLATestCase):
def _testReduction(self,
tf_reduce_fn,
@@ -156,7 +156,7 @@ class ReduceOpsTest(XLATestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
-class ReduceOpPrecisionTest(XLATestCase):
+class ReduceOpPrecisionTest(xla_test.XLATestCase):
def _testReduceSum(self,
expected_result,
diff --git a/tensorflow/compiler/tests/reduce_window_test.py b/tensorflow/compiler/tests/reduce_window_test.py
index e78a63465b..c69b6837b0 100644
--- a/tensorflow/compiler/tests/reduce_window_test.py
+++ b/tensorflow/compiler/tests/reduce_window_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class ReduceWindowTest(XLATestCase):
+class ReduceWindowTest(xla_test.XLATestCase):
"""Test cases for xla.reduce_window."""
def _reduce_window(self, operand, init, reducer, **kwargs):
diff --git a/tensorflow/compiler/tests/reverse_ops_test.py b/tensorflow/compiler/tests/reverse_ops_test.py
index 18fabca28c..d01c676e7c 100644
--- a/tensorflow/compiler/tests/reverse_ops_test.py
+++ b/tensorflow/compiler/tests/reverse_ops_test.py
@@ -21,14 +21,14 @@ from __future__ import print_function
import itertools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class ReverseOpsTest(XLATestCase):
+class ReverseOpsTest(xla_test.XLATestCase):
def testReverseOneDim(self):
shape = (7, 5, 9, 11)
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 1a5d05094e..ccfa630016 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ReverseSequenceTest(XLATestCase):
+class ReverseSequenceTest(xla_test.XLATestCase):
def _testReverseSequence(self,
x,
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index ecdce4f052..ff8bbac911 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
@@ -28,33 +28,104 @@ from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop
-class RmspropTest(XLATestCase):
+class RmspropTest(xla_test.XLATestCase):
+
+ def _rmsprop_update_numpy(self,
+ var,
+ g,
+ mg,
+ rms,
+ mom,
+ lr,
+ decay=0.9,
+ momentum=0.0,
+ epsilon=1e-10,
+ centered=False):
+ rms_t = rms * decay + (1 - decay) * g * g
+ denom_t = rms_t + epsilon
+ if centered:
+ mg_t = mg * decay + (1 - decay) * g
+ denom_t -= mg_t * mg_t
+ else:
+ mg_t = mg
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
def testBasic(self):
for dtype in self.float_types:
- with self.test_session(), self.test_scope():
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- rms_opt = rmsprop.RMSPropOptimizer(3.0)
- rms_update = rms_opt.apply_gradients(
- zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
-
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
-
- # Run 3 steps of RMSProp
- for _ in range(3):
- rms_update.run()
-
- # Validate updated params
- self.assertAllCloseAccordingToType(
- np.array([2.91705132e-04, 1.00029182e+00]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([2.89990854, 3.89990854]), var1.eval())
+ for centered in [False, True]:
+ with self.test_session(), self.test_scope():
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype)
+ mg0_np = np.array([0.0, 0.0], dtype=dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype)
+ rms0_np = np.array([1.0, 1.0], dtype=dtype)
+ rms1_np = np.array([1.0, 1.0], dtype=dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ learning_rate = 3.0
+ rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
+ rms_update = rms_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = rms_opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = rms_opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = rms_opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = rms_opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = rms_opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = rms_opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of RMSProp
+ for _ in range(3):
+ rms_update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np,
+ grads0_np,
+ mg0_np,
+ rms0_np,
+ mom0_np,
+ learning_rate,
+ centered=centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np,
+ grads1_np,
+ mg1_np,
+ rms1_np,
+ mom1_np,
+ learning_rate,
+ centered=centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 3260e63b23..4292352e76 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -69,7 +69,7 @@ def handle_options(func, x, axis, exclusive, reverse):
return x
-class CumsumTest(XLATestCase):
+class CumsumTest(xla_test.XLATestCase):
valid_dtypes = [np.float32]
@@ -147,7 +147,7 @@ class CumsumTest(XLATestCase):
math_ops.cumsum(input_tensor, [0]).eval()
-class CumprodTest(XLATestCase):
+class CumprodTest(xla_test.XLATestCase):
valid_dtypes = [np.float32]
diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py
index 638946e234..f606f88545 100644
--- a/tensorflow/compiler/tests/scatter_nd_op_test.py
+++ b/tensorflow/compiler/tests/scatter_nd_op_test.py
@@ -22,7 +22,7 @@ import functools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -68,7 +68,7 @@ def _NumpyUpdate(indices, updates, shape):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
-class ScatterNdTest(XLATestCase):
+class ScatterNdTest(xla_test.XLATestCase):
def _VariableRankTest(self,
np_scatter,
diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py
index 4a9c0e7471..772c20fd42 100644
--- a/tensorflow/compiler/tests/segment_reduction_ops_test.py
+++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py
@@ -21,26 +21,40 @@ from __future__ import print_function
import functools
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class SegmentReductionOpsTest(XLATestCase):
+class SegmentReductionOpsTest(xla_test.XLATestCase):
"""Test cases for segment reduction ops."""
- def UnsortedSegmentSum(self, data, indices, num_segments):
+ def _segmentReduction(self, op, data, indices, num_segments):
with self.test_session() as sess, self.test_scope():
d = array_ops.placeholder(data.dtype, shape=data.shape)
if isinstance(indices, int):
i = array_ops.placeholder(np.int32, shape=[])
else:
i = array_ops.placeholder(indices.dtype, shape=indices.shape)
- return sess.run(
- math_ops.unsorted_segment_sum(d, i, num_segments),
- {d: data,
- i: indices})
+ return sess.run(op(d, i, num_segments), {d: data, i: indices})
+
+ def _unsortedSegmentSum(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_sum, data, indices,
+ num_segments)
+
+ def _unsortedSegmentProd(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_prod, data, indices,
+ num_segments)
+
+ def _unsortedSegmentMin(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_min, data, indices,
+ num_segments)
+
+ def _unsortedSegmentMax(self, data, indices, num_segments):
+ return self._segmentReduction(math_ops.unsorted_segment_max, data, indices,
+ num_segments)
def testUnsortedSegmentSum0DIndices1DData(self):
for dtype in self.numeric_types:
@@ -49,14 +63,14 @@ class SegmentReductionOpsTest(XLATestCase):
[[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 1, 2, 3, 4, 5],
[0, 0, 0, 0, 0, 0]],
dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 2, 4))
def testUnsortedSegmentSum1DIndices1DData(self):
for dtype in self.numeric_types:
self.assertAllClose(
np.array([1, 3, 2, 9], dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
@@ -64,7 +78,7 @@ class SegmentReductionOpsTest(XLATestCase):
for dtype in self.numeric_types:
self.assertAllClose(
np.array([6, 3, 0, 6], dtype=dtype),
- self.UnsortedSegmentSum(
+ self._unsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
@@ -76,7 +90,7 @@ class SegmentReductionOpsTest(XLATestCase):
dtype=dtype)
indices = np.array([8, 1, 0, 3, 7], dtype=np.int32)
num_segments = 10
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[30, 31, 32, 33], [20, 21, 22, 23], [0, 0, 0, 0],
@@ -92,7 +106,7 @@ class SegmentReductionOpsTest(XLATestCase):
dtype=dtype)
indices = np.array([0, 1, 2, 0, 1], dtype=np.int32)
num_segments = 4
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[40, 42, 44, 46], [70, 72, 74, 76], [30, 31, 32, 33],
@@ -102,30 +116,30 @@ class SegmentReductionOpsTest(XLATestCase):
def testUnsortedSegmentSum2DIndices3DData(self):
for dtype in self.numeric_types:
data = np.array(
- [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
- [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
- [310, 311, 312]]],
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[
+ 200, 201, 202
+ ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]],
dtype=dtype)
indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32)
num_segments = 8
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
- [[210, 211, 212], [110, 111, 112], [310, 311, 312],
- [100, 102, 104], [0, 0, 0.], [210, 212, 214], [300, 301,
- 302], [0, 0, 0]],
+ [[210, 211, 212], [110, 111, 112], [310, 311, 312], [
+ 100, 102, 104
+ ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]],
dtype=dtype), y)
def testUnsortedSegmentSum1DIndices3DData(self):
for dtype in self.numeric_types:
data = np.array(
- [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]],
- [[200, 201, 202], [210, 211, 212]], [[300, 301, 302],
- [310, 311, 312]]],
+ [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[
+ 200, 201, 202
+ ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]],
dtype=dtype)
indices = np.array([3, 0, 2, 5], dtype=np.int32)
num_segments = 6
- y = self.UnsortedSegmentSum(data, indices, num_segments)
+ y = self._unsortedSegmentSum(data, indices, num_segments)
self.assertAllClose(
np.array(
[[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]],
@@ -138,10 +152,40 @@ class SegmentReductionOpsTest(XLATestCase):
data = np.ones((4, 8, 7), dtype=dtype)
indices = np.ones((3, 2), dtype=np.int32)
num_segments = 4
- self.assertRaises(ValueError,
- functools.partial(self.UnsortedSegmentSum, data,
- indices, num_segments))
+ self.assertRaises(
+ ValueError,
+ functools.partial(self._segmentReduction,
+ math_ops.unsorted_segment_sum, data, indices,
+ num_segments))
+
+ def testUnsortedSegmentOps1DIndices1DDataNegativeIndices(self):
+ """Tests for min, max, and prod ops.
+
+ These share most of their implementation with sum, so we only test basic
+ functionality.
+ """
+ for dtype in self.numeric_types:
+ self.assertAllClose(
+ np.array([8, 3, 1, 0], dtype=dtype),
+ self._unsortedSegmentProd(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
+
+ for dtype in self.int_types | self.float_types:
+ minval = dtypes.as_dtype(dtype).min
+ maxval = dtypes.as_dtype(dtype).max
+
+ self.assertAllClose(
+ np.array([2, 3, maxval, 0], dtype=dtype),
+ self._unsortedSegmentMin(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
+ self.assertAllClose(
+ np.array([4, 3, minval, 6], dtype=dtype),
+ self._unsortedSegmentMax(
+ np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype),
+ np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4))
-if __name__ == '__main__':
+if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py
index 305ca0c6b7..6c4890565d 100644
--- a/tensorflow/compiler/tests/slice_ops_test.py
+++ b/tensorflow/compiler/tests/slice_ops_test.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
-class SliceTest(XLATestCase):
+class SliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
@@ -110,7 +110,7 @@ class SliceTest(XLATestCase):
self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result)
-class StridedSliceTest(XLATestCase):
+class StridedSliceTest(xla_test.XLATestCase):
def test1D(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
new file mode 100644
index 0000000000..9e2ef964a1
--- /dev/null
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sorting operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.compiler.tf2xla.python import xla
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+class XlaSortOpTest(xla_test.XLATestCase):
+
+ def _assertOpOutputMatchesExpected(self, op, args, expected):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(*placeholders)
+ if isinstance(output, ops.Tensor):
+ output = [output]
+
+ results = session.run(output, feeds)
+ for result, v in zip(results, expected):
+ self.assertAllClose(v, result, rtol=1e-3)
+
+ def testSort(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=dtype)
+ np.random.shuffle(x)
+ self._assertOpOutputMatchesExpected(
+ xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
+
+ def testTopK(self):
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for dtype in supported_types.intersection(self.numeric_types):
+ # Use small input size for bfloat16. Otherwise, we'll get duplicate values
+ # after conversion to bfloat16, so the possible resulting index array is
+ # no longer unique.
+ if dtype == dtypes.bfloat16.as_numpy_dtype:
+ array_size = 20
+ k_options = [0, 1, 2, 10, 20]
+ else:
+ array_size = 200 * 1000
+ k_options = [0, 1, 2, 10, 20, 100, 1000, 200 * 1000]
+ for x in [np.arange(array_size)]:
+ np.random.shuffle(x)
+ for k in k_options:
+ indices = x.argsort()[::-1][:k]
+
+ def topk(v, k=k):
+ return nn_ops.top_k(v, k=k, sorted=True)
+
+ self._assertOpOutputMatchesExpected(
+ topk, [x.astype(dtype)],
+ expected=[x[indices].astype(dtype), indices])
+
+ def testTopKZeros(self):
+ """Tests that positive and negative zeros sort correctly."""
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ # Only bfloat16 is implemented.
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype
+ if bfloat16 not in self.numeric_types:
+ return
+
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtypes.bfloat16)
+ with self.test_scope():
+ topk = nn_ops.top_k(p, k=4)
+ results = sess.run(
+ topk,
+ {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)})
+ self.assertAllEqual(
+ np.array([3., 0., 0., 0.], dtype=bfloat16), results[0])
+ self.assertEqual(list([3, 0, 2, 6]), list(results[1]))
+
+ def testTopKInfinities(self):
+ """Tests that positive and negative infinity sort correctly."""
+ # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
+ if self.device in ["XLA_CPU", "XLA_GPU"]:
+ return
+
+ # Only bfloat16 is implemented.
+ bfloat16 = dtypes.bfloat16.as_numpy_dtype
+ if bfloat16 not in self.numeric_types:
+ return
+
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtypes.bfloat16)
+ with self.test_scope():
+ topk = nn_ops.top_k(p, k=6)
+ results = sess.run(topk, {
+ p: np.array(
+ [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16)
+ })
+ self.assertAllEqual(
+ np.array(
+ [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")],
+ dtype=bfloat16), results[0])
+ self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py
index f37c34156f..c685bc548f 100644
--- a/tensorflow/compiler/tests/spacetobatch_op_test.py
+++ b/tensorflow/compiler/tests/spacetobatch_op_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -68,7 +68,7 @@ def space_to_batch_direct(input_array, block_shape, paddings):
return permuted_reshaped_padded.reshape(output_shape)
-class SpaceToBatchTest(XLATestCase):
+class SpaceToBatchTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatch and BatchToSpace ops."""
def _testPad(self, inputs, paddings, block_size, outputs):
@@ -149,7 +149,7 @@ class SpaceToBatchTest(XLATestCase):
self._testOne(x_np, block_size, x_out)
-class SpaceToBatchNDTest(XLATestCase):
+class SpaceToBatchNDTest(xla_test.XLATestCase):
"""Tests input-output pairs for the SpaceToBatchND and BatchToSpaceND ops."""
def _testPad(self, inputs, block_shape, paddings, outputs):
diff --git a/tensorflow/compiler/tests/sparse_to_dense_op_test.py b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
new file mode 100644
index 0000000000..3db8101c4b
--- /dev/null
+++ b/tensorflow/compiler/tests/sparse_to_dense_op_test.py
@@ -0,0 +1,118 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.sparse_op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+def _SparseToDense(sparse_indices,
+ output_size,
+ sparse_values,
+ default_value,
+ validate_indices=True):
+ feed_sparse_indices = array_ops.placeholder(dtypes.int32)
+ feed_dict = {feed_sparse_indices: sparse_indices}
+ return sparse_ops.sparse_to_dense(
+ feed_sparse_indices,
+ output_size,
+ sparse_values,
+ default_value=default_value,
+ validate_indices=validate_indices).eval(feed_dict=feed_dict)
+
+
+class SparseToDenseTest(xla_test.XLATestCase):
+
+ def testInt(self):
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([1, 3], [5], 1, 0)
+ np_ans = np.array([0, 1, 0, 1, 0]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testFloat(self):
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([1, 3], [5], 1.0, 0.0)
+ np_ans = np.array([0, 1, 0, 1, 0]).astype(np.float32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testSetValue(self):
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([1, 3], [5], [1, 2], -1)
+ np_ans = np.array([-1, 1, -1, 2, -1]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testSetSingleValue(self):
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([1, 3], [5], 1, -1)
+ np_ans = np.array([-1, 1, -1, 1, -1]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def test2d(self):
+ # pylint: disable=bad-whitespace
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([[1, 3], [2, 0]], [3, 4], 1, -1)
+ np_ans = np.array([[-1, -1, -1, -1],
+ [-1, -1, -1, 1],
+ [ 1, -1, -1, -1]]).astype(np.int32)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testZeroDefault(self):
+ with self.test_session():
+ x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
+ self.assertAllEqual(x, [0, 0, 7, 0])
+
+ def test3d(self):
+ with self.test_session(), self.test_scope():
+ tf_ans = _SparseToDense([[1, 3, 0], [2, 0, 1]], [3, 4, 2], 1, -1)
+ np_ans = np.ones((3, 4, 2), dtype=np.int32) * -1
+ np_ans[1, 3, 0] = 1
+ np_ans[2, 0, 1] = 1
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testBadShape(self):
+ with self.test_session(), self.test_scope():
+ with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
+ _SparseToDense([1, 3], [[5], [3]], 1, -1)
+
+ def testBadValue(self):
+ with self.test_session(), self.test_scope():
+ with self.assertRaisesOpError(
+ r"sparse_values has incorrect shape \[2,1\], "
+ r"should be \[\] or \[2\]"):
+ _SparseToDense([1, 3], [5], [[5], [3]], -1)
+
+ def testBadNumValues(self):
+ with self.test_session(), self.test_scope():
+ with self.assertRaisesOpError(
+ r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
+ _SparseToDense([1, 3], [5], [1, 2, 3], -1)
+
+ def testBadDefault(self):
+ with self.test_session(), self.test_scope():
+ with self.assertRaisesOpError("default_value should be a scalar"):
+ _SparseToDense([1, 3], [5], [1, 2], [0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/stack_ops_test.py b/tensorflow/compiler/tests/stack_ops_test.py
index 94342f9567..b7dd787fef 100644
--- a/tensorflow/compiler/tests/stack_ops_test.py
+++ b/tensorflow/compiler/tests/stack_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -28,7 +28,7 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.platform import test
-class StackOpTest(XLATestCase):
+class StackOpTest(xla_test.XLATestCase):
def testStackPushPop(self):
with self.test_session(), self.test_scope():
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index b6f8390a45..d162675ef8 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -22,14 +22,15 @@ import math
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.contrib import stateless
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
-class StatelessRandomOpsTest(XLATestCase):
+class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
@@ -122,6 +123,56 @@ class StatelessRandomOpsTest(XLATestCase):
# so to avoid flakiness the seed is fixed.
self.assertTrue(self._anderson_darling(y) < 2.492)
+ def testTruncatedNormalIsInRange(self):
+ # TODO(b/34339814): implement inverse erf support for non-F32 types.
+ for dtype in [dtypes.float32]:
+ with self.test_session() as sess, self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ n = 10000000
+ x = stateless.stateless_truncated_normal(
+ shape=[n], seed=seed_t, dtype=dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+
+ def normal_cdf(x):
+ return .5 * math.erfc(-x / math.sqrt(2))
+
+ def normal_pdf(x):
+ return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
+
+ def probit(x, sess=sess):
+ return sess.run(special_math.ndtri(x))
+
+ a = -2.
+ b = 2.
+ mu = 0.
+ sigma = 1.
+
+ alpha = (a - mu) / sigma
+ beta = (b - mu) / sigma
+ z = normal_cdf(beta) - normal_cdf(alpha)
+
+ self.assertTrue((y >= a).sum() == n)
+ self.assertTrue((y <= b).sum() == n)
+
+ # For more information on these calculations, see:
+ # Burkardt, John. "The Truncated Normal Distribution".
+ # Department of Scientific Computing website. Florida State University.
+ expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
+ actual_mean = np.mean(y)
+ self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+
+ expected_median = mu + probit(
+ (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
+ actual_median = np.median(y)
+ self.assertAllClose(actual_median, expected_median, atol=8e-4)
+
+ expected_variance = sigma**2 * (1 + (
+ (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
+ (normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
+ actual_variance = np.var(y)
+ self.assertAllClose(actual_variance, expected_variance, rtol=1e-3)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index ef047005b6..effa5a59fe 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
@@ -28,7 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
-class TernaryOpsTest(XLATestCase):
+class TernaryOpsTest(xla_test.XLATestCase):
def _testTernary(self, op, a, b, c, expected):
with self.test_session() as session:
diff --git a/tensorflow/compiler/tests/test_utils.py b/tensorflow/compiler/tests/test_utils.py
new file mode 100644
index 0000000000..6abde18ea9
--- /dev/null
+++ b/tensorflow/compiler/tests/test_utils.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for helping test ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def ConvertBetweenDataFormats(x, data_format_src, data_format_dst):
+ """Converts 4D tensor between data formats."""
+
+ valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
+ if data_format_src not in valid_data_formats:
+ raise ValueError("data_format_src must be of %s, got %s." %
+ (valid_data_formats, data_format_src))
+ if data_format_dst not in valid_data_formats:
+ raise ValueError("data_format_dst must be of %s, got %s." %
+ (valid_data_formats, data_format_dst))
+ if len(x.shape) != 4:
+ raise ValueError("x must be 4D, got shape %s." % x.shape)
+
+ if data_format_src == data_format_dst:
+ return x
+
+ dim_map = {d: i for i, d in enumerate(data_format_src)}
+ transpose_dims = [dim_map[d] for d in data_format_dst]
+ return np.transpose(x, transpose_dims)
+
+
+def PermuteDimsBetweenDataFormats(dims, data_format_src, data_format_dst):
+ """Get new shape for converting between data formats."""
+
+ valid_data_formats = ["NHWC", "NCHW", "HWNC", "HWCN"]
+ if data_format_src not in valid_data_formats:
+ raise ValueError("data_format_src must be of %s, got %s." %
+ (valid_data_formats, data_format_src))
+ if data_format_dst not in valid_data_formats:
+ raise ValueError("data_format_dst must be of %s, got %s." %
+ (valid_data_formats, data_format_dst))
+ if len(dims) != 4:
+ raise ValueError("dims must be of length 4, got %s." % dims)
+
+ if data_format_src == data_format_dst:
+ return dims
+
+ dim_map = {d: i for i, d in enumerate(data_format_src)}
+ permuted_dims = [dims[dim_map[d]] for d in data_format_dst]
+ return permuted_dims
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 689a4a1f4e..6a7011aea6 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -23,7 +23,7 @@ import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
@@ -44,11 +44,16 @@ def nhwc_to_format(x, data_format):
raise ValueError("Unknown format {}".format(data_format))
-class UnaryOpsTest(XLATestCase):
+class UnaryOpsTest(xla_test.XLATestCase):
"""Test cases for unary operators."""
- def _assertOpOutputMatchesExpected(self, op, inp, expected,
- equality_test=None, rtol=1e-3, atol=1e-5):
+ def _assertOpOutputMatchesExpected(self,
+ op,
+ inp,
+ expected,
+ equality_test=None,
+ rtol=1e-3,
+ atol=1e-5):
"""Verifies that 'op' produces 'expected' when fed input 'inp' .
Args:
@@ -81,10 +86,10 @@ class UnaryOpsTest(XLATestCase):
def testAllTypeOps(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
- array_ops.diag,
- np.array([1, 2, 3, 4], dtype=dtype),
- np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
- dtype=dtype))
+ array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
+ np.array(
+ [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
+ dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.diag_part,
np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
@@ -102,8 +107,7 @@ class UnaryOpsTest(XLATestCase):
expected=np.array([[-1, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
- array_ops.matrix_diag,
- np.array([[1, 2], [3, 4]], dtype=dtype),
+ array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
@@ -115,10 +119,10 @@ class UnaryOpsTest(XLATestCase):
np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
np.array(
- [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]],
- [[4, 0, 0], [0, 5, 0], [0, 0, 6]]],
- [[[7, 0, 0], [0, 8, 0], [0, 0, 9]],
- [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]],
+ [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [
+ 0, 0, 6
+ ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0],
+ [0, 0, 12]]]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag_part,
@@ -159,36 +163,30 @@ class UnaryOpsTest(XLATestCase):
continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
- math_ops.acos,
- x.astype(dtype),
- expected=np.arccos(x).astype(dtype))
+ math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
self._assertOpOutputMatchesExpected(
- math_ops.asin,
- x.astype(dtype),
- expected=np.arcsin(x).astype(dtype))
+ math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype))
x = np.arange(-3, 3).reshape(1, 3, 2)
self._assertOpOutputMatchesExpected(
- math_ops.atan,
- x.astype(dtype),
- expected=np.arctan(x).astype(dtype))
+ math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype))
self._assertOpOutputMatchesExpected(
math_ops.acosh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([0, 1.3169579, 1.76274717, 2.06343707],
- dtype=dtype))
+ expected=np.array(
+ [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.asinh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255],
- dtype=dtype))
+ expected=np.array(
+ [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.atanh,
np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype),
- expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893],
- dtype=dtype))
+ expected=np.array(
+ [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.ceil,
@@ -198,8 +196,18 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.cosh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284],
- dtype=dtype))
+ expected=np.array(
+ [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype))
+
+ # Disable float16 testing for now
+ if dtype != np.float16:
+ x = np.arange(-10, 10, 1).astype(dtype)
+ with self.test_session() as session:
+ erf_x = session.run(math_ops.erf(x))
+ erfc_x = session.run(math_ops.erfc(x))
+
+ self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x)
+ self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x)
self._assertOpOutputMatchesExpected(
math_ops.exp,
@@ -219,8 +227,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))
# Tests for tf.nn ops.
@@ -261,16 +269,20 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.rint,
- np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
- [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
- expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
- dtype=dtype))
+ np.array(
+ [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
+ [0.5, 1.5, 2.5, 3.5]],
+ dtype=dtype),
+ expected=np.array(
+ [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.round,
- np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
- [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
- expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
- dtype=dtype))
+ np.array(
+ [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
+ [0.5, 1.5, 2.5, 3.5]],
+ dtype=dtype),
+ expected=np.array(
+ [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.rsqrt,
@@ -279,10 +291,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.sigmoid,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.7310586, 0.7310586, 0.7310586, 0.7310586],
[0.7310586, 0.880797, 0.95257413, 0.98201376]],
@@ -296,8 +305,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.sinh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172],
- dtype=dtype))
+ expected=np.array(
+ [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.sqrt,
@@ -307,15 +316,12 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.tan,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128],
- dtype=dtype))
+ expected=np.array(
+ [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.tanh,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.76159418, 0.76159418, 0.76159418, 0.76159418],
[0.76159418, 0.96402758, 0.99505478, 0.99932933]],
@@ -323,10 +329,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.log_softmax,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
@@ -360,10 +363,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.25, 0.25, 0.25, 0.25],
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
@@ -372,8 +372,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
- expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
- dtype=dtype))
+ expected=np.array(
+ [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
@@ -382,10 +382,23 @@ class UnaryOpsTest(XLATestCase):
expected=np.array(
[[True, False, True], [False, True, True]], dtype=np.bool))
+ def quantize_and_dequantize_v2(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x, -127, 127, signed_input=True, num_bits=8)
+
self._assertOpOutputMatchesExpected(
- lambda x: array_ops.quantize_and_dequantize_v2(x, -127, 127, True, 8),
+ quantize_and_dequantize_v2,
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
- expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
+ expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
+
+ def quantize_and_dequantize_v3(x):
+ return array_ops.quantize_and_dequantize_v3(
+ x, -127, 127, num_bits=8, signed_input=True, range_given=False)
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v3,
+ np.array([-1, -0.5, 0, 0.3], dtype=dtype),
+ expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
def testComplexOps(self):
for dtype in self.complex_types:
@@ -566,13 +579,13 @@ class UnaryOpsTest(XLATestCase):
for dtype in self.float_types:
self._assertOpOutputMatchesExpected(
math_ops.is_inf,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
self._assertOpOutputMatchesExpected(
math_ops.is_nan,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
def testLogicalOps(self):
@@ -589,14 +602,15 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
- np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
- dtype=np.float32),
+ np.array(
+ [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
expected=np.array([10., 26.], dtype=np.float32))
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
- types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) |
- self.complex_tf_types)
+ types = (
+ set([dtypes.bool, dtypes.int32, dtypes.float32])
+ | self.complex_tf_types)
for shape in shapes:
for src_type in types:
for dst_type in types:
@@ -638,14 +652,11 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
rank_op, dtype(7), expected=np.int32(0))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [[], []], dtype=dtype), expected=np.int32(2))
+ rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [-1, 1], dtype=dtype), expected=np.int32(1))
+ rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [[-1, 1]], dtype=dtype), expected=np.int32(2))
+ rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
self._assertOpOutputMatchesExpected(
rank_op,
np.array([[-1], [1], [4]], dtype=dtype),
@@ -710,97 +721,97 @@ class UnaryOpsTest(XLATestCase):
equality_test=self.ListsAreClose)
def testDepthToSpace(self):
+
def make_op(data_format):
+
def op(x):
- return array_ops.depth_to_space(x, block_size=2,
- data_format=data_format)
+ return array_ops.depth_to_space(
+ x, block_size=2, data_format=data_format)
+
return op
for dtype in self.numeric_types:
for data_format in ["NCHW", "NHWC"]:
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype),
- data_format),
- expected=nhwc_to_format(np.array([[[[1], [2]],
- [[3], [4]]]], dtype=dtype),
- data_format))
+ nhwc_to_format(
+ np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format),
+ expected=nhwc_to_format(
+ np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
nhwc_to_format(
- np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]],
- dtype=dtype),
+ np.array(
+ [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
- dtype=dtype),
- data_format))
+ np.array(
+ [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
+ dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
nhwc_to_format(
- np.array([[[[1, 2, 3, 4],
- [5, 6, 7, 8]],
- [[9, 10, 11, 12],
- [13, 14, 15, 16]]]], dtype=dtype),
- data_format),
+ np.array(
+ [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
+ [13, 14, 15, 16]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]], dtype=dtype),
- data_format))
+ np.array(
+ [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
+ dtype=dtype), data_format))
def testSpaceToDepth(self):
+
def make_op(data_format):
+
def op(x):
- return array_ops.space_to_depth(x, block_size=2,
- data_format=data_format)
+ return array_ops.space_to_depth(
+ x, block_size=2, data_format=data_format)
+
return op
for dtype in self.numeric_types:
for data_format in ["NCHW", "NHWC"]:
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1], [2]],
- [[3], [4]]]], dtype=dtype),
- data_format),
- expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype),
- data_format))
+ nhwc_to_format(
+ np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format),
+ expected=nhwc_to_format(
+ np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]],
- [[7, 8, 9], [10, 11, 12]]]], dtype=dtype),
- data_format),
+ nhwc_to_format(
+ np.array(
+ [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]],
- dtype=dtype),
+ np.array(
+ [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]], dtype=dtype),
- data_format),
+ nhwc_to_format(
+ np.array(
+ [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3, 4],
- [5, 6, 7, 8]],
- [[9, 10, 11, 12],
- [13, 14, 15, 16]]]], dtype=dtype),
- data_format))
+ np.array(
+ [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
+ [13, 14, 15, 16]]]],
+ dtype=dtype), data_format))
def _assertSoftplusMatchesExpected(self, features, dtype):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features)
self._assertOpOutputMatchesExpected(
- nn_ops.softplus, features, expected=expected,
- rtol=1e-6,
- atol=9.1e-6)
+ nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
def testSoftplus(self):
for dtype in self.float_types:
@@ -814,9 +825,10 @@ class UnaryOpsTest(XLATestCase):
one = dtype(1)
ten = dtype(10)
self._assertSoftplusMatchesExpected([
- log_eps, log_eps - one, log_eps + one, log_eps - ten,
- log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
- -log_eps - ten, -log_eps + ten], dtype)
+ log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten,
+ -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten,
+ -log_eps + ten
+ ], dtype)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index 2c09b03d5a..dd2c252d38 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -20,12 +20,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -36,7 +37,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
-class VariableOpsTest(XLATestCase):
+class VariableOpsTest(xla_test.XLATestCase):
"""Test cases for resource variable operators."""
def testOneWriteOneOutput(self):
@@ -52,9 +53,7 @@ class VariableOpsTest(XLATestCase):
with ops.control_dependencies([x]):
y = v.read_value()
self.assertAllClose(
- np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {
- p: 1
- }))
+ np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, {p: 1}))
def testSparseRead0DIndices(self):
for dtype in self.numeric_types:
@@ -103,9 +102,9 @@ class VariableOpsTest(XLATestCase):
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
- [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
- [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]],
- ).astype(dtype), sess.run(x))
+ [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]
+ ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
+ ],).astype(dtype), sess.run(x))
def testShape(self):
for dtype in self.numeric_types:
@@ -206,6 +205,206 @@ class VariableOpsTest(XLATestCase):
self.assertAllClose(update, result[1])
self.assertAllClose(update, result[2])
+ def testScatterAdd(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[2, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1], [7]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_add(
+ handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertAllEqual(sess.run(read), [[3], [7]])
+
+ def testScatterSub(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[2, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[4], [1]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_sub(
+ handle, [1], constant_op.constant([[2]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertAllEqual(sess.run(read), [[4], [-1]])
+
+ def testScatterMul(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_mul(
+ handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[5]])
+
+ def testScatterDiv(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_div(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertAllEqual(sess.run(read), [[2]])
+
+ def testScatterMin(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_min(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[3]])
+
+ def testScatterMax(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_max(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[6]])
+
+ def testScatterUpdate(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_update(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[3]])
+
+ def testScatterAddScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_add(
+ handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[3]])
+
+ def testScatterSubScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_sub(
+ handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[-1]])
+
+ def testScatterMulScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_mul(
+ handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[5]])
+
+ def testScatterDivScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_div(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[2]])
+
+ def testScatterMinScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_min(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[3]])
+
+ def testScatterMaxScalar(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ sess.run(
+ resource_variable_ops.resource_scatter_max(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(sess.run(read), [[6]])
+
+ def testScatterNdAddOps(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.float32, shape=[8])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([1] * 8, dtype=dtypes.float32)))
+ indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
+ expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
+ sess.run(gen_state_ops.resource_scatter_nd_add(handle, indices, updates))
+ read = resource_variable_ops.read_variable_op(
+ handle, dtype=dtypes.float32)
+ self.assertAllClose(expected, sess.run(read))
+
+ def testScatterNdUpdateAddOps(self):
+ with self.test_session() as sess, self.test_scope():
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.float32, shape=[8])
+ sess.run(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([1] * 8, dtype=dtypes.float32)))
+ indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
+ expected = np.array([1, 11, 1, 10, 9, 1, 1, 12])
+ sess.run(
+ gen_state_ops.resource_scatter_nd_update(handle, indices, updates))
+ read = resource_variable_ops.read_variable_op(
+ handle, dtype=dtypes.float32)
+ self.assertAllClose(expected, sess.run(read))
+
class StridedSliceAssignChecker(object):
"""Compares the results of a slice assignment using Tensorflow and numpy."""
@@ -236,12 +435,12 @@ class StridedSliceAssignChecker(object):
self.test.assertAllEqual(val, valnp)
-class SliceAssignTest(XLATestCase):
+class SliceAssignTest(xla_test.XLATestCase):
def testSliceAssign(self):
for dtype in self.numeric_types:
- checker = StridedSliceAssignChecker(self, [[1, 2, 3], [4, 5, 6]],
- dtype=dtype)
+ checker = StridedSliceAssignChecker(
+ self, [[1, 2, 3], [4, 5, 6]], dtype=dtype)
# No-op assignment
checker[:] = [[10, 20, 30], [40, 50, 60]]
# Checks trivial (1,1) shape tensor
diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py
index f79eb27435..b637cf31cf 100644
--- a/tensorflow/compiler/tests/while_test.py
+++ b/tensorflow/compiler/tests/while_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class WhileTest(XLATestCase):
+class WhileTest(xla_test.XLATestCase):
def testSingletonLoopHandrolled(self):
# Define a function for the loop body
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index f0b010fa67..06d977b93c 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -20,14 +20,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
-class XlaDeviceTest(XLATestCase):
+class XlaDeviceTest(xla_test.XLATestCase):
def testCopies(self):
"""Tests that copies onto and off XLA devices work."""
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index e924fe1e61..88827cb53b 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -49,6 +49,32 @@ flags.DEFINE_string('tf_xla_flags', None,
'Value to set the TF_XLA_FLAGS environment variable to')
+def parse_disabled_manifest(manifest_content):
+ comments_re = re.compile('#.*$')
+ disabled_tests = []
+ disabled_method_types = []
+ for l in manifest_content.splitlines():
+ stripped = comments_re.sub('', l).strip()
+ if not stripped:
+ continue
+ entry = stripped.split(' ')
+ if len(entry) == 1:
+ disabled_tests.append(entry[0])
+ elif len(entry) == 2:
+ disabled_method_types.append((entry[0], entry[1].strip().split(',')))
+ else:
+ raise ValueError('Bad entry in manifest file.')
+
+ disabled_regex = '|'.join(disabled_tests)
+ method_types_filter = dict()
+ for method, types in disabled_method_types:
+ method_types_filter[method] = set([
+ dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
+ for name in types
+ ])
+ return disabled_regex, method_types_filter
+
+
class XLATestCase(test.TestCase):
"""XLA test cases are parameterized test cases."""
@@ -85,38 +111,21 @@ class XLATestCase(test.TestCase):
# Parse the manifest file, if any, into a regex identifying tests to
# disable
- self.disabled_regex = None
- self._method_types_filter = dict()
# TODO(xpan): Make it text proto if it doesn't scale.
# Each line of the manifest file specifies an entry. The entry can be
# 1) TestNameRegex // E.g. CumprodTest.* Or
# 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
# The 1) disables the entire test. While 2) only filter some numeric types
# so that they are not used in those tests.
+ self.disabled_regex = None
+ self._method_types_filter = {}
if FLAGS.disabled_manifest is not None:
- comments_re = re.compile('#.*$')
- manifest_file = open(FLAGS.disabled_manifest, 'r')
- disabled_tests = []
- disabled_method_types = []
- for l in manifest_file.read().splitlines():
- if not l:
- continue
- entry = comments_re.sub('', l).strip().split(' ')
- if len(entry) == 1:
- disabled_tests.append(entry[0])
- elif len(entry) == 2:
- disabled_method_types.append(
- (entry[0], entry[1].strip().split(',')))
- else:
- raise ValueError('Bad entry in manifest file.')
-
- self.disabled_regex = re.compile('|'.join(disabled_tests))
- for method, types in disabled_method_types:
- self._method_types_filter[method] = set([
- dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
- for name in types])
- manifest_file.close()
+ with open(FLAGS.disabled_manifest, 'r') as manifest_file:
+ disabled_regex, self._method_types_filter = (
+ parse_disabled_manifest(manifest_file.read()))
+ if disabled_regex:
+ self.disabled_regex = re.compile(disabled_regex)
if FLAGS.tf_xla_flags is not None:
os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
diff --git a/tensorflow/compiler/tests/xla_test_test.py b/tensorflow/compiler/tests/xla_test_test.py
new file mode 100644
index 0000000000..2466445157
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_test_test.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the XLATestCase test fixture base class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.platform import test
+
+
+class XlaTestCaseTestCase(test.TestCase):
+
+ def testManifestEmptyLineDoesNotCatchAll(self):
+ manifest = """
+testCaseOne
+"""
+ disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
+ self.assertEqual(disabled_regex, "testCaseOne")
+
+ def testManifestWholeLineCommentDoesNotCatchAll(self):
+ manifest = """# I am a comment
+testCaseOne
+testCaseTwo
+"""
+ disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
+ self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index cd57452302..ff002d15b0 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -139,12 +139,14 @@ cc_library(
"xla_op_registry.cc",
"xla_resource.cc",
"xla_cpu_backend.cc",
+ "legacy_flags/backend_registration_flags.cc",
] + if_cuda_is_configured([
"xla_gpu_backend.cc",
]),
hdrs = [
"const_analysis.h",
"graph_compiler.h",
+ "legacy_flags/backend_registration_flags.h",
"xla_compilation_device.h",
"xla_compiler.h",
"xla_context.h",
@@ -162,18 +164,24 @@ cc_library(
":sharding_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
@@ -198,7 +206,7 @@ cc_library(
],
visibility = [":friends"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu_internal",
@@ -281,6 +289,7 @@ tf_cc_test(
deps = [
":tf2xla",
":tf2xla_proto",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
@@ -323,7 +332,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
@@ -360,6 +369,7 @@ tf_cc_test(
],
deps = [
":common",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -462,3 +472,13 @@ cc_library(
"//tensorflow/core:protos_all_cc",
],
)
+
+tf_cc_test(
+ name = "xla_op_registry_test",
+ srcs = ["xla_op_registry_test.cc"],
+ deps = [
+ ":xla_compiler",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 42585ad4d8..6cc95149a1 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -166,6 +166,27 @@ StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
return inserted_node;
}
+// Check that the graph has no cycle containing the given node.
+Status CheckNoCycleContains(const Node* node, const int num_nodes) {
+ std::vector<const Node*> ready;
+ ready.push_back(node);
+ std::vector<bool> visited(num_nodes);
+ while (!ready.empty()) {
+ const Node* current_node = ready.back();
+ ready.pop_back();
+ visited[current_node->id()] = true;
+ for (const Edge* out : current_node->out_edges()) {
+ if (out->dst() == node) {
+ return errors::Internal("Detect a cycle: Node \"", node->name(), "\"(",
+ node->def().op(), ") feeds into itself.");
+ } else if (!visited[out->dst()->id()]) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+ return Status::OK();
+}
+
StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
NodeDef arg_def;
NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
@@ -1407,6 +1428,10 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
TF_RETURN_IF_ERROR(
AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node));
TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
+ // Check that the if_node doesn't feed into itself.
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNoCycleContains(if_node, graph_->num_node_ids()),
+ "ConvertToXlaIf failed.");
return if_node;
}
@@ -1438,7 +1463,15 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
std::vector<ControlFlowInfo> cf_info;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
+ std::vector<string> unreachable_nodes;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes),
+ "FunctionalizeControlFlow failed");
+ if (!unreachable_nodes.empty()) {
+ return errors::InvalidArgument(
+ "The following nodes are unreachable from the source in the graph: ",
+ tensorflow::str_util::Join(unreachable_nodes, ", "));
+ }
// Builds Frames, indexed by name.
std::unordered_map<string, Frame> frames;
@@ -1458,10 +1491,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
frame.parent = parent;
frame.name = cf.frame_name;
++parent->num_children;
- } else if (frame.parent != parent) {
- return errors::InvalidArgument("Mismatched parent frames for ",
- cf.frame->id(), ": ", parent->name, " vs ",
- frame.parent->name);
}
if (IsEnter(node)) {
@@ -1471,12 +1500,6 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
&arg.is_loop_invariant));
frame.args.push_back(arg);
} else if (IsLoopCond(node)) {
- if (frame.loop_cond) {
- return errors::InvalidArgument(
- "Loop ", cf.frame_name,
- " has more than one LoopCond node: ", node->name(), " and ",
- frame.loop_cond->name());
- }
frame.loop_cond = node;
}
frame.nodes.insert(node);
@@ -1508,6 +1531,16 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
worklist.push_back(frame->parent);
}
}
+ // There should be no cycle at this point, since while loops have been removed
+ // from graph.
+ // Check that the newly added XlaWhile nodes don't feed into themselves.
+ for (const Node* node : graph->op_nodes()) {
+ if (node->def().op() == "XlaWhile") {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ CheckNoCycleContains(node, graph->num_node_ids()),
+ "FunctionalizeLoop failed.");
+ }
+ }
// FunctionalizeControlFlow is invoked for every function, so the loops's
// bodies and conditionals that were extracted into functions will be handled
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index 14977a908a..aae2f8ee5a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/util/equal_graph_def.h"
@@ -1012,5 +1013,60 @@ TEST(FunctionalizeControlFlow, Complex) {
}
}
+TEST(FunctionalizeControlFlow, Cycle) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ // -----------------------------------------------------
+ // | |
+ // | v
+ // less -> switch_1 --> add -> merge_1 -> identity -> switch_2
+ // | ^ |
+ // | | v
+ // --------> one -------------------------> add_2 ---> merge_2
+ {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+
+ auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
+ auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
+ auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
+ auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), x, less);
+ auto two =
+ ops::Const<int32>(scope.WithOpName("cond/two")
+ .WithControlDependencies(switch_1.output_true),
+ 2);
+ auto mul = ops::Multiply(scope.WithOpName("cond/true/mul"),
+ switch_1.output_true, two);
+ auto one =
+ ops::Const<int32>(scope.WithOpName("cond/one")
+ .WithControlDependencies(switch_1.output_false),
+ 1);
+ auto add = ops::Add(scope.WithOpName("cond/false/add"),
+ switch_1.output_false, one);
+
+ auto merge_1 = ops::Merge(scope.WithOpName("cond/Merge"),
+ std::initializer_list<Input>{add, mul});
+ auto identity =
+ ops::Identity(scope.WithOpName("cond/Merge/identity"), merge_1.output);
+ auto switch_2 =
+ ops::Switch(scope.WithOpName("grad/cond/Switch"), identity, less);
+ auto add_2 = ops::Add(scope.WithOpName("cond_2/false/add"),
+ switch_2.output_false, one);
+ auto mul_2 = ops::Multiply(scope.WithOpName("cond_2/true/mul"),
+ switch_2.output_true, two);
+ auto merge_2 = ops::Merge(scope.WithOpName("cond_2/Merge"),
+ std::initializer_list<Input>{add_2, mul_2});
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ }
+ // No cycle before functionalize control flow.
+ TF_EXPECT_OK(graph::ValidateGraphHasNoCycle(*graph));
+ FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ // switch_1 and switch_2 have the same switch depth. They are replaced by a
+ // single XlaIf node during FunctionalizeControlFlow, resulting in a cycle:
+ // less -> XlaIf <--> identity.
+ Status status = FunctionalizeControlFlow(graph.get(), &library);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "Detect a cycle"))
+ << status.error_message();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 212f6f3966..4900af6df1 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@@ -87,6 +89,8 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
}
} // namespace
Status GraphCompiler::Compile() {
+ // Check that the graph has no illegal cycles.
+ TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
// Maintain a mapping from node id to node outputs.
using NodeOutputs = std::vector<TensorValue>;
std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
@@ -227,7 +231,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
XlaContext& context = XlaContext::Get(op_context);
auto* b = context.builder();
- auto output_handle = b->Call(*result.computation, handles);
+ auto output_handle = xla::Call(b, *result.computation, handles);
// The output handle of `Call` computation is a tuple type. Unzip it so
// that it can fit into future computations.
int computation_output = 0;
@@ -236,7 +240,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
} else {
xla_op_context.SetOutput(
- i, b->GetTupleElement(output_handle, computation_output));
+ i, xla::GetTupleElement(output_handle, computation_output));
++computation_output;
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index edd2ab6301..5a335aa43c 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -58,6 +58,7 @@ tf_kernel_library(
"pack_op.cc",
"pad_op.cc",
"pooling_ops.cc",
+ "qr_op.cc",
"quantize_and_dequantize_op.cc",
"random_ops.cc",
"reduce_window_op.cc",
@@ -79,14 +80,17 @@ tf_kernel_library(
"shape_util.cc",
"slice_op.cc",
"softmax_op.cc",
+ "sort_ops.cc",
"spacetobatch_op.cc",
"spacetodepth_op.cc",
+ "sparse_to_dense_op.cc",
"split_op.cc",
"stack_ops.cc",
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tile_ops.cc",
+ "topk_op.cc",
"training_ops.cc",
"transpose_op.cc",
"unary_ops.cc",
@@ -104,12 +108,15 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:cholesky",
+ "//tensorflow/compiler/tf2xla/lib:qr",
+ "//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -117,6 +124,9 @@ tf_kernel_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
+ "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:image_ops_op_lib",
@@ -152,7 +162,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -168,7 +178,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -203,6 +213,7 @@ tf_kernel_library(
":index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client/lib:arithmetic",
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
index 1e59868621..e335328280 100644
--- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -31,7 +32,7 @@ class AddNOp : public XlaOpKernel {
xla::XlaOp sum = ctx->Input(0);
for (int i = 1; i < ctx->num_inputs(); ++i) {
- sum = ctx->builder()->Add(sum, ctx->Input(i));
+ sum = xla::Add(sum, ctx->Input(i));
}
ctx->SetOutput(0, sum);
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
index b0ba25b998..4cfe946b2e 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -28,11 +28,10 @@ class BatchMatMulOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1),
+ auto result = BatchDot(ctx->Input(0), ctx->Input(1),
/*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
/*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
- OP_REQUIRES_OK(ctx, result.status());
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, result);
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 15e1815a4c..c4af79281d 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -34,10 +35,11 @@ class FusedBatchNormOp : public XlaOpKernel {
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
+ (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
+ data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
errors::InvalidArgument(
"Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC and NCHW"));
+ "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -48,8 +50,6 @@ class FusedBatchNormOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(ctx->input_type(1), &scale_type));
- xla::XlaBuilder* builder = ctx->builder();
-
xla::XlaOp input = ctx->Input(0);
TensorShape input_shape = ctx->InputShape(0);
@@ -59,30 +59,30 @@ class FusedBatchNormOp : public XlaOpKernel {
// TODO(b/69928690): support mixed precision in the XLA batch normalization
// operators. As a workaround, cast everything to the statistics type (which
// may be more precise than the input type).
- input = builder->ConvertElementType(input, scale_type);
+ input = xla::ConvertElementType(input, scale_type);
if (is_training_) {
- xla::XlaOp output = builder->BatchNormTraining(
+ xla::XlaOp output = xla::BatchNormTraining(
input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index);
// In training mode, outputs the normalized value as well as the
// calculated mean and variance.
- ctx->SetOutput(0, builder->ConvertElementType(
- builder->GetTupleElement(output, 0), input_type));
- ctx->SetOutput(1, builder->GetTupleElement(output, 1));
- ctx->SetOutput(2, builder->GetTupleElement(output, 2));
+ ctx->SetOutput(0, xla::ConvertElementType(xla::GetTupleElement(output, 0),
+ input_type));
+ ctx->SetOutput(1, xla::GetTupleElement(output, 1));
+ ctx->SetOutput(2, xla::GetTupleElement(output, 2));
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
// space 1 & 2". They are used to pass the per-batch mean and
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
- ctx->SetOutput(3, builder->GetTupleElement(output, 1));
- ctx->SetOutput(4, builder->GetTupleElement(output, 2));
+ ctx->SetOutput(3, xla::GetTupleElement(output, 1));
+ ctx->SetOutput(4, xla::GetTupleElement(output, 2));
} else {
- xla::XlaOp output = builder->BatchNormInference(
+ xla::XlaOp output = xla::BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
epsilon_, feature_index);
- ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
+ ctx->SetOutput(0, xla::ConvertElementType(output, input_type));
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
@@ -111,10 +111,11 @@ class FusedBatchNormGradOp : public XlaOpKernel {
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
OP_REQUIRES(ctx,
- (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW),
+ (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW ||
+ data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN),
errors::InvalidArgument(
"Unsupported data format ", ToString(data_format_),
- "; supported formats are NHWC and NCHW"));
+ "; supported formats are NHWC, NCHW, HWNC and HWCN"));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -142,12 +143,12 @@ class FusedBatchNormGradOp : public XlaOpKernel {
xla::XlaOp offset_backprop;
if (is_training_) {
xla::XlaOp output =
- b->BatchNormGrad(activations, scale, mean, var, grad_backprop,
- epsilon_, feature_index);
+ xla::BatchNormGrad(activations, scale, mean, var, grad_backprop,
+ epsilon_, feature_index);
- x_backprop = b->GetTupleElement(output, 0);
- scale_backprop = b->GetTupleElement(output, 1);
- offset_backprop = b->GetTupleElement(output, 2);
+ x_backprop = xla::GetTupleElement(output, 0);
+ scale_backprop = xla::GetTupleElement(output, 1);
+ offset_backprop = xla::GetTupleElement(output, 2);
} else {
// Reduce over all dimensions except the feature dim.
std::vector<int64> reduction_dims(input_dims - 1);
@@ -164,35 +165,35 @@ class FusedBatchNormGradOp : public XlaOpKernel {
auto converted =
XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
// scratch1 = rsqrt(pop_var + epsilon)
auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5);
- auto scratch1 =
- b->Pow(b->Add(var, b->ConstantR0<float>(epsilon_)), neg_half);
+ auto scratch1 = xla::Pow(
+ xla::Add(var, xla::ConstantR0<float>(b, epsilon_)), neg_half);
// scratch2 = sum(y_backprop * (x - mean))
auto mul =
- b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index}));
+ xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index}));
converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type);
reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
x_backprop =
- b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index});
- scale_backprop = b->Mul(scratch1, scratch2);
+ xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index});
+ scale_backprop = xla::Mul(scratch1, scratch2);
}
ctx->SetOutput(0,
XlaHelpers::ConvertElementType(b, x_backprop, input_dtype));
ctx->SetOutput(1, scale_backprop);
ctx->SetOutput(2, offset_backprop);
- ctx->SetConstantOutput(3, Tensor(scale_dtype, {}));
- ctx->SetConstantOutput(4, Tensor(scale_dtype, {}));
+ ctx->SetConstantOutput(3, Tensor());
+ ctx->SetConstantOutput(4, Tensor());
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
index 642278ab99..26130fd9e7 100644
--- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -45,7 +46,6 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
", 2] instead of ",
xla::ShapeUtil::HumanString(crops.shape())));
- xla::XlaBuilder* b = ctx->builder();
const int64 batch_size = input_shape[0];
// Compute the product of the block_shape values.
@@ -72,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
reshaped_shape[block_rank] = batch_size / block_num_elems;
std::copy(input_shape.begin() + 1, input_shape.end(),
reshaped_shape.begin() + block_rank + 1);
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce `permuted` of shape
// [batch / prod(block_shape),
@@ -90,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
}
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
- xla::XlaOp permuted = b->Transpose(reshaped, permutation);
+ xla::XlaOp permuted = xla::Transpose(reshaped, permutation);
// 3. Reshape `permuted` to produce `reshaped_permuted` of shape
// [batch / prod(block_shape),
@@ -110,7 +110,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_permuted_shape.begin() + 1 + block_rank);
- xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape);
+ xla::XlaOp reshaped_permuted =
+ xla::Reshape(permuted, reshaped_permuted_shape);
// 4. Crop the start and end of dimensions `[1, ..., M]` of
// `reshaped_permuted` according to `crops` to produce the output of shape:
@@ -138,7 +139,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input,
" end: ", crop_end, " size ", reshaped_permuted_shape[1 + i]));
}
xla::XlaOp output =
- b->Slice(reshaped_permuted, start_indices, end_indices, strides);
+ xla::Slice(reshaped_permuted, start_indices, end_indices, strides);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
index ee2c920453..ba3b1c9dab 100644
--- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index 9d677f4266..e9b2c0b16d 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -60,8 +61,7 @@ class BiasOp : public XlaOpKernel {
"of the input tensor: ",
bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
- xla::XlaOp result =
- ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
+ xla::XlaOp result = xla::Add(ctx->Input(0), ctx->Input(1), {feature_dim});
ctx->SetOutput(0, result);
}
@@ -109,8 +109,8 @@ class BiasAddGradOp : public XlaOpKernel {
auto converted =
XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), reduce_dims);
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), reduce_dims);
ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0)));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index f04cde878e..d6d4ae8937 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -41,18 +41,19 @@ namespace {
const BCast& broadcast_helper, \
const std::vector<int64>& extend_dimensions) override { \
xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
return HLO; \
} \
}; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op)
-XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Add, xla::Add(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Sub, xla::Sub(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mul, xla::Mul(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
@@ -67,13 +68,13 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
- auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero));
- auto abs_x = b->Abs(x);
- auto abs_y = b->Abs(y);
- auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one));
- auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y));
+ auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
+ auto abs_x = xla::Abs(x);
+ auto abs_y = xla::Abs(y);
+ auto t = xla::Neg(xla::Sub(xla::Add(abs_x, abs_y), one));
+ auto result = xla::Select(different_sign, xla::Div(t, abs_y), xla::Div(x, y));
if (DataTypeIsFloating(dtype)) {
- result = b->Floor(result);
+ result = xla::Floor(result);
}
return result;
}
@@ -87,75 +88,78 @@ static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
- auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero));
- auto trunc_mod = b->Rem(x, y);
- return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y));
+ auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
+ auto trunc_mod = xla::Rem(x, y);
+ return xla::Select(same_sign, trunc_mod, xla::Rem(xla::Add(trunc_mod, y), y));
}
XLA_MAKE_BINARY(FloorMod,
FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
-XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseAnd, xla::And(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseOr, xla::Or(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(BitwiseXor, xla::Xor(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LeftShift, xla::ShiftLeft(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(RightShift,
(DataTypeIsUnsigned(ctx->input_type(0))
- ? b->ShiftRightLogical(lhs, rhs, extend_dimensions)
- : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions)));
-
-XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(ReciprocalGrad, b->Neg(b->Mul(rhs, b->Mul(lhs, lhs))));
+ ? xla::ShiftRightLogical(lhs, rhs, extend_dimensions)
+ : xla::ShiftRightArithmetic(lhs, rhs, extend_dimensions)));
+
+XLA_MAKE_BINARY(LogicalAnd, xla::And(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LogicalOr, xla::Or(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mod, xla::Rem(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Maximum, xla::Max(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Minimum, xla::Min(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(RealDiv, xla::Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(ReciprocalGrad, xla::Neg(xla::Mul(rhs, xla::Mul(lhs, lhs))));
XLA_MAKE_BINARY(
RsqrtGrad,
- b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
- b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
- extend_dimensions));
-XLA_MAKE_BINARY(SqrtGrad,
- b->Div(b->Mul(rhs,
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
- lhs, extend_dimensions));
+ xla::Mul(xla::Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
+ xla::Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
+ extend_dimensions));
+XLA_MAKE_BINARY(
+ SqrtGrad,
+ xla::Div(xla::Mul(rhs, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
+ lhs, extend_dimensions));
static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) {
- return builder->Mul(x, x);
+ return xla::Mul(x, x);
}
XLA_MAKE_BINARY(SquaredDifference,
- Square(b, b->Sub(lhs, rhs, extend_dimensions)));
+ Square(b, xla::Sub(lhs, rhs, extend_dimensions)));
-XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(TruncateDiv, xla::Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(TruncateMod, xla::Rem(lhs, rhs, extend_dimensions));
// Comparison ops
-XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions));
-XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Equal, xla::Eq(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(NotEqual, xla::Ne(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Greater, xla::Gt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(GreaterEqual, xla::Ge(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Less, xla::Lt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LessEqual, xla::Le(lhs, rhs, extend_dimensions));
// Non-linear ops
XLA_MAKE_BINARY(SigmoidGrad,
- b->Mul(b->Mul(rhs, lhs),
- b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
+ xla::Mul(xla::Mul(rhs, lhs),
+ xla::Sub(XlaHelpers::One(b, input_type(0)), lhs)));
XLA_MAKE_BINARY(SoftplusGrad,
- b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
- XlaHelpers::One(b, input_type(1)))));
+ xla::Div(lhs, xla::Add(xla::Exp(xla::Neg(rhs)),
+ XlaHelpers::One(b, input_type(1)))));
// softsigngrad(gradients, features) = gradients / (1 + abs(features)) ** 2
XLA_MAKE_BINARY(SoftsignGrad,
- b->Div(lhs, Square(b, b->Add(XlaHelpers::One(b, input_type(0)),
- b->Abs(rhs)))));
+ xla::Div(lhs,
+ Square(b, xla::Add(XlaHelpers::One(b, input_type(0)),
+ xla::Abs(rhs)))));
-XLA_MAKE_BINARY(TanhGrad, b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(lhs, lhs))));
+XLA_MAKE_BINARY(TanhGrad,
+ xla::Mul(rhs, xla::Sub(XlaHelpers::One(b, input_type(0)),
+ xla::Mul(lhs, lhs))));
-XLA_MAKE_BINARY(Pow, b->Pow(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Pow, xla::Pow(lhs, rhs, extend_dimensions));
#undef XLA_MAKE_BINARY
@@ -168,12 +172,13 @@ class ApproximateEqualOp : public XlaOpKernel {
// Computes the max of the scalar input x and 0.
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
- auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1)));
+ auto abs = xla::Abs(xla::Sub(ctx->Input(0), ctx->Input(1)));
auto abs_shape = b->GetShape(abs);
OP_REQUIRES_OK(ctx, abs_shape.status());
auto abs_type = abs_shape.ValueOrDie().element_type();
- auto result = b->Lt(
- abs, b->ConvertElementType(b->ConstantR0<float>(tolerance_), abs_type));
+ auto result =
+ xla::Lt(abs, xla::ConvertElementType(
+ xla::ConstantR0<float>(b, tolerance_), abs_type));
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
index ca9a6b4068..efbdb76eaa 100644
--- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -36,22 +37,22 @@ class BucketizeOp : public XlaOpKernel {
const DataType dtype = context->input_type(0);
xla::XlaOp input = context->Input(0);
- xla::XlaOp boundaries = builder->ConstantR1<float>(boundaries_);
+ xla::XlaOp boundaries = xla::ConstantR1<float>(builder, boundaries_);
// TODO(phawkins): the following behavior matches the behavior of the core
// Bucketize kernel. However, comparing an int32 or int64 against float may
// lead to inaccurate bucketing due to rounding.
if (dtype == DT_DOUBLE) {
- input = builder->ConvertElementType(input, xla::F64);
- boundaries = builder->ConvertElementType(boundaries, xla::F64);
+ input = xla::ConvertElementType(input, xla::F64);
+ boundaries = xla::ConvertElementType(boundaries, xla::F64);
} else {
- input = builder->ConvertElementType(input, xla::F32);
+ input = xla::ConvertElementType(input, xla::F32);
}
- xla::XlaOp comparison = builder->ConvertElementType(
- builder->Ge(builder->Broadcast(input, {1}), boundaries,
- /*broadcast_dimensions=*/{0}),
- xla::S32);
- xla::XlaOp buckets = builder->Reduce(
- comparison, /*init_value=*/builder->ConstantR0<int32>(0),
+ xla::XlaOp comparison =
+ xla::ConvertElementType(xla::Ge(xla::Broadcast(input, {1}), boundaries,
+ /*broadcast_dimensions=*/{0}),
+ xla::S32);
+ xla::XlaOp buckets = xla::Reduce(
+ comparison, /*init_value=*/xla::ConstantR0<int32>(builder, 0),
/*computation=*/xla::CreateScalarAddComputation(xla::S32, builder),
/*dimensions_to_reduce=*/{0});
context->SetOutput(0, buckets);
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
index e9d98c7685..62eebf762b 100644
--- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -40,14 +41,14 @@ class CastOp : public XlaOpKernel {
if (src_dtype_ == dst_dtype_) {
output = input;
} else if (dst_dtype_ == DT_BOOL) {
- output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
+ output = xla::Ne(input, XlaHelpers::Zero(builder, src_dtype_));
} else if (xla::primitive_util::IsComplexType(src_type_) &&
!xla::primitive_util::IsComplexType(dst_type_)) {
// As in cast_op.h, we replicate the numpy behavior of truncating the
// imaginary part.
- output = builder->ConvertElementType(builder->Real(input), dst_type_);
+ output = xla::ConvertElementType(xla::Real(input), dst_type_);
} else {
- output = builder->ConvertElementType(input, dst_type_);
+ output = xla::ConvertElementType(input, dst_type_);
}
ctx->SetOutput(0, output);
@@ -72,7 +73,6 @@ class BitcastOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
xla::XlaOp input = ctx->Input(0);
xla::XlaOp output;
@@ -92,7 +92,7 @@ class BitcastOp : public XlaOpKernel {
xla::primitive_util::BitWidth(dst_type_),
errors::Unimplemented(
"Only bitcasts between equally sized types supported."));
- output = builder->BitcastConvertType(input, dst_type_);
+ output = xla::BitcastConvertType(input, dst_type_);
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index 835a7f5689..1784e712b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -65,24 +66,22 @@ class CategoricalOp : public XlaOpKernel {
DataTypeToPrimitiveType(input_type(0), &uniform_xla_type));
xla::Shape uniform_shape =
xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array);
- auto uniforms = builder->RngUniform(
- XlaHelpers::Zero(builder, input_type(0)),
- XlaHelpers::One(builder, input_type(0)), uniform_shape);
+ auto uniforms =
+ xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)),
+ XlaHelpers::One(builder, input_type(0)), uniform_shape);
// Use Gumbel softmax trick to generate categorical samples.
// See:
// https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
// TODO(b/68769470): Switch to using a cumulative sum approach.
- auto softmax_entries =
- builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))),
- /*broadcast_dimensions=*/{0, 2});
-
- TensorShape softmax_shape(uniform_shape_array);
- xla::XlaOp argmax;
- OP_REQUIRES_OK(
- ctx,
- XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape,
- input_type(0), output_type(0), /*axis=*/2, &argmax));
+ auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)),
+ /*broadcast_dimensions=*/{0, 2});
+
+ xla::PrimitiveType xla_output_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(output_type(0), &xla_output_type));
+ xla::XlaOp argmax =
+ XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2);
ctx->SetOutput(0, argmax);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
index fe6651793d..9fcbc86adc 100644
--- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc
@@ -24,12 +24,7 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- auto result = Cholesky(ctx->builder(), ctx->Input(0));
- if (!result.ok()) {
- ctx->SetStatus(result.status());
- return;
- }
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, Cholesky(ctx->Input(0)));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
index a00bc912f9..4e6d33304c 100644
--- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -29,7 +30,6 @@ class ClipByValueOp : public XlaOpKernel {
const TensorShape min_shape = ctx->InputShape(1);
const TensorShape max_shape = ctx->InputShape(2);
- xla::XlaBuilder* builder = ctx->builder();
auto input = ctx->Input(0);
auto min = ctx->Input(1);
auto max = ctx->Input(2);
@@ -45,13 +45,13 @@ class ClipByValueOp : public XlaOpKernel {
if (shape != min_shape) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error());
- min = builder->Broadcast(min, shape.dim_sizes());
+ min = xla::Broadcast(min, shape.dim_sizes());
}
if (shape != max_shape) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error());
- max = builder->Broadcast(max, shape.dim_sizes());
+ max = xla::Broadcast(max, shape.dim_sizes());
}
- ctx->SetOutput(0, builder->Clamp(min, input, max));
+ ctx->SetOutput(0, xla::Clamp(min, input, max));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index 78285affa1..e3a32a5c0e 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -88,7 +89,7 @@ class ConcatBaseOp : public XlaOpKernel {
"] = ", in_shape.DebugString()));
if (in_shape.dims() == 0) {
// Inputs that come in as scalars must be reshaped to 1-vectors.
- input_data.push_back(ctx->builder()->Reshape(handle, {1}));
+ input_data.push_back(xla::Reshape(handle, {1}));
} else {
input_data.push_back(handle);
}
@@ -96,7 +97,7 @@ class ConcatBaseOp : public XlaOpKernel {
}
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index 59d06c654d..f4360d8c3f 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -53,41 +54,41 @@ class ConstOp : public XlaOpKernel {
switch (proto_.dtype()) {
case DT_BOOL:
if (proto_.bool_val_size() == 1) {
- ctx->SetOutput(0,
- b->Broadcast(b->ConstantR0<bool>(proto_.bool_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(
+ 0, xla::Broadcast(xla::ConstantR0<bool>(b, proto_.bool_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_FLOAT:
if (proto_.float_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<float>(proto_.float_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<float>(
+ b, proto_.float_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_DOUBLE:
if (proto_.double_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<double>(proto_.double_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<double>(
+ b, proto_.double_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
- ctx->SetOutput(0,
- b->Broadcast(b->ConstantR0<int32>(proto_.int_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(
+ 0, xla::Broadcast(xla::ConstantR0<int32>(b, proto_.int_val(0)),
+ shape.dim_sizes()));
return;
}
break;
case DT_INT64:
if (proto_.int64_val_size() == 1) {
- ctx->SetOutput(
- 0, b->Broadcast(b->ConstantR0<int64>(proto_.int64_val(0)),
- shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(xla::ConstantR0<int64>(
+ b, proto_.int64_val(0)),
+ shape.dim_sizes()));
return;
}
break;
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 627bad12f3..48ac4867ed 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -51,8 +53,8 @@ xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
+ return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
+ expanded_filter_shape.dim_sizes());
}
// Create a mask for depthwise convolution that will make a normal convolution
@@ -95,32 +97,27 @@ xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
// Create a M sized linspace and an M*N sized linspace that will be
// broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
- &input_feature_iota));
- xla::XlaOp expanded_feature_iota;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
- input_feature * depthwise_multiplier,
- &expanded_feature_iota));
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
// Divide the M*N sized linspace by the depthwise_multiplier to create
// [0 0 1 1 2 2] in the example in the function comment.
expanded_feature_iota =
- builder->Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
// Broadcast the N*M linspace to [H, W, ..., M, M*N].
auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota = builder->Broadcast(
- expanded_feature_iota, expanded_feature_broadcast_dims);
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
// Compare the broadcasted linspace to the input feature linspace in the
// input feature dimension to create a diagonal predicate.
- return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dims() - 2});
}
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
@@ -142,16 +139,16 @@ xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape,
implicit_broadcast_filter_shape.dims() - 1,
depthwise_multiplier * input_feature);
auto implicit_broadcast_filter =
- builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
+ xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
// Broadcast the filter to [H, W, ..., M, M*N].
auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
- auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero);
+ auto expanded_filter = xla::Add(implicit_broadcast_filter, expanded_zero);
// If the filter mask is set, choose the broadcasted filter, othwerwise,
// choose zero.
- return builder->Select(CreateExpandedFilterMask(filter_shape, builder),
- expanded_filter, expanded_zero);
+ return xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ expanded_filter, expanded_zero);
}
// Inverse of ExpandFilterForDepthwiseConvolution.
@@ -162,17 +159,17 @@ xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
xla::XlaBuilder* builder) {
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- auto masked_expanded_filter = builder->Select(
+ auto masked_expanded_filter = xla::Select(
CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
CreateExpandedZero(filter_shape, dtype, builder));
- return builder->Reshape(
+ return xla::Reshape(
// This reduce does not need inputs to be converted with
// XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
// ExpandedZero guarantees that only one element is non zero, so there
// cannot be accumulated precision error.
- builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype),
- {expanded_filter_shape.dims() - 2}),
+ xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
+ *ctx->GetOrCreateAdd(dtype),
+ {expanded_filter_shape.dims() - 2}),
filter_shape.dim_sizes());
}
@@ -289,8 +286,8 @@ class ConvOp : public XlaOpKernel {
}
xla::XlaOp conv =
- b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
- lhs_dilation, rhs_dilation, dims);
+ xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
+ lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
@@ -435,11 +432,11 @@ class ConvBackpropInputOp : public XlaOpKernel {
}
// Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims);
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = b->ConvGeneralDilated(
+ xla::XlaOp in_backprop = xla::ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
lhs_dilation, rhs_dilation, dnums);
@@ -638,8 +635,8 @@ class ConvBackpropFilterOp : public XlaOpKernel {
// This is done by specifying the window dilation factors in the
// convolution HLO below.
auto filter_backprop =
- b->ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
if (depthwise_) {
filter_backprop = ContractFilterForDepthwiseBackprop(
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
index 7fcd4170fb..500a564f3f 100644
--- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -58,21 +59,21 @@ class CrossOp : public XlaOpKernel {
auto in1 = ctx->Input(1);
starts.back() = 0;
limits.back() = 1;
- auto u1 = b->Slice(in0, starts, limits, strides);
- auto v1 = b->Slice(in1, starts, limits, strides);
+ auto u1 = xla::Slice(in0, starts, limits, strides);
+ auto v1 = xla::Slice(in1, starts, limits, strides);
starts.back() = 1;
limits.back() = 2;
- auto u2 = b->Slice(in0, starts, limits, strides);
- auto v2 = b->Slice(in1, starts, limits, strides);
+ auto u2 = xla::Slice(in0, starts, limits, strides);
+ auto v2 = xla::Slice(in1, starts, limits, strides);
starts.back() = 2;
limits.back() = 3;
- auto u3 = b->Slice(in0, starts, limits, strides);
- auto v3 = b->Slice(in1, starts, limits, strides);
+ auto u3 = xla::Slice(in0, starts, limits, strides);
+ auto v3 = xla::Slice(in1, starts, limits, strides);
- auto s1 = b->Sub(b->Mul(u2, v3), b->Mul(u3, v2));
- auto s2 = b->Sub(b->Mul(u3, v1), b->Mul(u1, v3));
- auto s3 = b->Sub(b->Mul(u1, v2), b->Mul(u2, v1));
- auto output = b->ConcatInDim({s1, s2, s3}, in0_shape.dims() - 1);
+ auto s1 = xla::Sub(xla::Mul(u2, v3), xla::Mul(u3, v2));
+ auto s2 = xla::Sub(xla::Mul(u3, v1), xla::Mul(u1, v3));
+ auto s3 = xla::Sub(xla::Mul(u1, v2), xla::Mul(u2, v1));
+ auto output = xla::ConcatInDim(b, {s1, s2, s3}, in0_shape.dims() - 1);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index 01aa1a83e7..9ff3e02228 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -96,18 +96,16 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// First reshape the inputs, which should be a metadata-only
// operation since we are flattening the dimensions in order.
- auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape());
+ auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
+ auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
// Next broadcast the necessary input dimensions. We rely on the
// XLA optimizer to be smart about the fact that we are asking
// it to broadcast size 1 on some of these dimensions, to avoid
// adding complexity to this code.
- auto lhs_broadcast =
- builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast());
+ auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast =
- builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast());
+ auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
int rhs_size = broadcast_helper.y_bcast().size();
// Now reshape them to the correct output shape. After the
@@ -122,15 +120,15 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
lhs_reorder.push_back(i);
lhs_reorder.push_back(i + lhs_size);
}
- auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder,
- broadcast_helper.output_shape());
+ auto lhs_output =
+ xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
std::vector<int64> rhs_reorder;
for (int i = 0; i < rhs_size; ++i) {
rhs_reorder.push_back(i);
rhs_reorder.push_back(i + rhs_size);
}
- auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder,
- broadcast_helper.output_shape());
+ auto rhs_output =
+ xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
return {lhs_output, rhs_output};
}
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index 23243f6246..f314920025 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -50,7 +51,6 @@ class DepthToSpaceOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
@@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel {
") is not divisible by square of the block size (",
block_size_, ")"));
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -141,7 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2],
// block_size_,
// depth / (block_size_ * block_size_)]
- xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -151,7 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel {
// input_shape[2] * block_size_,
// depth / (block_size_ * block_size_)]
//
- xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
index 931705ba83..6dec414c53 100644
--- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -18,6 +18,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -25,10 +28,10 @@ namespace tensorflow {
namespace {
// Create a diagonal / batch diagonal matrix with 'input' on the diagonal.
-xla::StatusOr<xla::XlaOp> CreateDiagonal(
- const xla::XlaOp& input, int64 last_dim_size,
- tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx,
- xla::XlaBuilder* builder) {
+xla::XlaOp CreateDiagonal(xla::XlaOp input, int64 last_dim_size,
+ gtl::ArraySlice<int64> other_dims,
+ xla::PrimitiveType element_type) {
+ xla::XlaBuilder* builder = input.builder();
// Create two matrices that have the following forms, and compare them:
//
// [[0, 0, 0, 0] [[0, 1, 2, 3]
@@ -38,16 +41,14 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
//
// This produces a predicate matrix of the right size, with "true" on the
// diagonal.
- xla::XlaOp iota;
- TF_RETURN_IF_ERROR(
- XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota));
- xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size});
- xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0});
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, last_dim_size);
+ xla::XlaOp iota_broadcast = xla::Broadcast(iota, {last_dim_size});
+ xla::XlaOp mask = xla::Eq(iota_broadcast, iota, {0});
// If this is a batched diagonal, broadcast the mask across the other
// dimensions.
if (!other_dims.empty()) {
- mask = builder->Broadcast(mask, other_dims);
+ mask = xla::Broadcast(mask, other_dims);
}
// Broadcast the input, and then use the mask computed above to select the
@@ -64,18 +65,15 @@ xla::StatusOr<xla::XlaOp> CreateDiagonal(
std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end());
broadcast_dims.push_back(1LL);
broadcast_dims.push_back(last_dim_size);
- xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims);
+ xla::XlaOp input_broadcast = xla::Reshape(input, broadcast_dims);
broadcast_dims[broadcast_dims.size() - 2] = last_dim_size;
- xla::PrimitiveType element_type;
- TF_RETURN_IF_ERROR(
- DataTypeToPrimitiveType(ctx->input_type(0), &element_type));
auto broadcast_shape =
xla::ShapeUtil::MakeShape(element_type, broadcast_dims);
- xla::XlaOp zeros = Zeros(builder, broadcast_shape);
+ xla::XlaOp zeros = xla::Zeros(builder, broadcast_shape);
- input_broadcast = builder->Add(input_broadcast, zeros);
- return builder->Select(mask, input_broadcast, zeros);
+ input_broadcast = xla::Add(input_broadcast, zeros);
+ return xla::Select(mask, input_broadcast, zeros);
}
class DiagOp : public XlaOpKernel {
@@ -83,8 +81,6 @@ class DiagOp : public XlaOpKernel {
explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("Diag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -104,19 +100,17 @@ class DiagOp : public XlaOpKernel {
// Flattens the input to 1D.
int64 size = input_shape.num_elements();
- input = builder->Reshape(input, {size});
+ input = xla::Reshape(input, {size});
// Create an R2 with the R1 diagonal.
- auto diag_or_status =
- CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- xla::XlaOp diag = diag_or_status.ValueOrDie();
+ xla::XlaOp diag =
+ CreateDiagonal(input, size, /*other_dims=*/{}, ctx->input_xla_type(0));
// Reshapes to the final shape.
std::vector<int64> new_dims(dims.size() * 2);
std::copy(dims.begin(), dims.end(), new_dims.begin());
std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
- diag = builder->Reshape(diag, new_dims);
+ diag = xla::Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
}
@@ -170,21 +164,21 @@ class DiagPartOp : public XlaOpKernel {
// Flattens the input to 1D.
int64 size = input_shape.num_elements();
- diag = builder->Reshape(diag, {size});
+ diag = xla::Reshape(diag, {size});
// Adds padding after the last element of 'new_size'.
xla::PaddingConfig config;
auto* dim = config.add_dimensions();
dim->set_edge_padding_high(new_size);
auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = builder->Pad(diag, zero, config);
+ diag = xla::Pad(diag, zero, config);
// Reshapes so the diagonal is now in the first column.
- diag = builder->Reshape(diag, {new_size, new_size + 1});
+ diag = xla::Reshape(diag, {new_size, new_size + 1});
// Slices out the first column and reshapes to the final shape.
- diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
- diag = builder->Reshape(diag, new_dims);
+ diag = xla::Slice(diag, {0, 0}, {new_size, 1}, {1, 1});
+ diag = xla::Reshape(diag, new_dims);
ctx->SetOutput(0, diag);
}
@@ -197,8 +191,6 @@ class MatrixDiagOp : public XlaOpKernel {
explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* builder = ctx->builder();
-
OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
errors::InvalidArgument("MatrixDiag op must have at an input"));
const TensorShape input_shape = ctx->InputShape(0);
@@ -208,17 +200,15 @@ class MatrixDiagOp : public XlaOpKernel {
errors::InvalidArgument("Expected 1 <= dims, got shape ",
input_shape.DebugString()));
- xla::XlaOp diag = ctx->Input(0);
int last_dim = dims.size() - 1;
int64 last_dim_size = input_shape.dim_size(last_dim);
tensorflow::gtl::ArraySlice<int64> other_dims(dims);
other_dims.pop_back();
- auto diag_or_status =
- CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder);
- OP_REQUIRES_OK(ctx, diag_or_status.status());
- diag = diag_or_status.ValueOrDie();
+ xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims,
+ ctx->input_xla_type(0));
ctx->SetOutput(0, diag);
}
};
@@ -265,7 +255,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
// Collapses the last two dimensions.
std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
flattened_dims.back() *= dims.back();
- diag = builder->Reshape(diag, flattened_dims);
+ diag = xla::Reshape(diag, flattened_dims);
// Slices or pads the last dimension to 'target_size'.
int64 actual_size = flattened_dims.back();
@@ -276,13 +266,13 @@ class MatrixDiagPartOp : public XlaOpKernel {
auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
dim->set_edge_padding_high(target_size - actual_size);
auto zero = XlaHelpers::Zero(builder, input_type(0));
- diag = builder->Pad(diag, zero, config);
+ diag = xla::Pad(diag, zero, config);
} else if (actual_size > target_size) {
std::vector<int64> start(flattened_dims.size(), 0);
std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
std::vector<int64> strides(flattened_dims.size(), 1);
limits[flattened_dims.size() - 1] = target_size;
- diag = builder->Slice(diag, start, limits, strides);
+ diag = xla::Slice(diag, start, limits, strides);
}
// Reshape so the target values are in the first position of the last
@@ -290,18 +280,18 @@ class MatrixDiagPartOp : public XlaOpKernel {
std::vector<int64> unflattened_dims(dims.begin(), dims.end());
dims[last_dim - 1] = smaller_dim_size;
dims[last_dim] = last_dim_size + 1;
- diag = builder->Reshape(diag, dims);
+ diag = xla::Reshape(diag, dims);
// Slices out the first column and reshapes to the final shape.
std::vector<int64> start(dims.size(), 0);
std::vector<int64> limits(dims.begin(), dims.end());
std::vector<int64> strides(dims.size(), 1);
limits[last_dim] = 1;
- diag = builder->Slice(diag, start, limits, strides);
+ diag = xla::Slice(diag, start, limits, strides);
// Collapses away the last dimension.
dims.pop_back();
- diag = builder->Reshape(diag, dims);
+ diag = xla::Reshape(diag, dims);
ctx->SetOutput(0, diag);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index 0419de78b2..3b86ea34c9 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -57,8 +57,8 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::XlaOp result = ctx->builder()->DynamicUpdateSlice(
- ctx->Input(0), ctx->Input(1), ctx->Input(2));
+ xla::XlaOp result =
+ xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
ctx->SetOutput(0, result);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index dd4a169087..958231505b 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -150,8 +151,7 @@ class DynamicStitchOp : public XlaOpKernel {
if (new_shape == data_shapes[input_num]) {
input[input_num] = handle;
} else {
- input[input_num] =
- ctx->builder()->Reshape(handle, new_shape.dim_sizes());
+ input[input_num] = xla::Reshape(handle, new_shape.dim_sizes());
}
}
@@ -175,10 +175,10 @@ class DynamicStitchOp : public XlaOpKernel {
// And place it in the concat list in the place indicated by
// the index.
to_concat[index_num] =
- ctx->builder()->Slice(expression, slice_start, slice_limit, stride);
+ xla::Slice(expression, slice_start, slice_limit, stride);
}
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), to_concat, 0));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index 493781a1e6..81f42e504e 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -34,9 +34,9 @@ class EluOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
const auto zero = XlaHelpers::Zero(b, input_type(0));
- const auto pred = b->Gt(ctx->Input(0), zero);
- const auto expm1 = b->Expm1(ctx->Input(0));
- ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
+ const auto pred = xla::Gt(ctx->Input(0), zero);
+ const auto expm1 = xla::Expm1(ctx->Input(0));
+ ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), expm1));
}
};
@@ -51,9 +51,9 @@ class EluGradOp : public XlaOpKernel {
const auto one = XlaHelpers::One(b, input_type(0));
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
- const auto exp_grad = b->Mul(grad, b->Add(activation, one));
- const auto pred = b->Gt(activation, zero);
- ctx->SetOutput(0, b->Select(pred, grad, exp_grad));
+ const auto exp_grad = xla::Mul(grad, xla::Add(activation, one));
+ const auto pred = xla::Gt(activation, zero);
+ ctx->SetOutput(0, xla::Select(pred, grad, exp_grad));
}
};
@@ -71,10 +71,10 @@ class SeluOp : public XlaOpKernel {
1.0507009873554804934193349852946);
const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
1.7580993408473768599402175208123);
- const auto pred = b->Gt(ctx->Input(0), zero);
- const auto expm1 = b->Expm1(ctx->Input(0));
- ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
- b->Mul(scale_alpha, expm1)));
+ const auto pred = xla::Gt(ctx->Input(0), zero);
+ const auto expm1 = xla::Expm1(ctx->Input(0));
+ ctx->SetOutput(0, xla::Select(pred, xla::Mul(scale, ctx->Input(0)),
+ xla::Mul(scale_alpha, expm1)));
}
};
@@ -92,10 +92,10 @@ class SeluGradOp : public XlaOpKernel {
1.7580993408473768599402175208123);
const auto grad = ctx->Input(0);
const auto activation = ctx->Input(1);
- const auto lin_grad = b->Mul(grad, scale);
- const auto exp_grad = b->Mul(grad, b->Add(activation, scale_alpha));
- const auto pred = b->Gt(activation, zero);
- ctx->SetOutput(0, b->Select(pred, lin_grad, exp_grad));
+ const auto lin_grad = xla::Mul(grad, scale);
+ const auto exp_grad = xla::Mul(grad, xla::Add(activation, scale_alpha));
+ const auto pred = xla::Gt(activation, zero);
+ ctx->SetOutput(0, xla::Select(pred, lin_grad, exp_grad));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index 6df01cabbf..65d42a302f 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -110,13 +112,11 @@ class ExtractImagePatchesOp : public XlaOpKernel {
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
- xla::XlaOp iota;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
- kernel_size * depth, &iota));
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, kernel_size * depth);
- auto lhs = builder->Reshape(iota, lhs_shape);
- auto filter = builder->ConvertElementType(
- builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
+ auto lhs = xla::Reshape(iota, lhs_shape);
+ auto filter = xla::ConvertElementType(
+ xla::Eq(lhs, iota, {num_spatial_dims + 1}), type);
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(num_spatial_dims);
@@ -148,8 +148,8 @@ class ExtractImagePatchesOp : public XlaOpKernel {
}
xla::XlaOp conv =
- builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
- padding, lhs_dilation, rhs_dilation, dims);
+ xla::ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
+ lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index 8f0de0a524..2fd1a34741 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -49,20 +50,20 @@ void XlaNudge(xla::XlaBuilder* b, const DataType data_type,
const float quant_min_value, const float quant_max_value,
xla::XlaOp* nudged_min, xla::XlaOp* nudged_max,
xla::XlaOp* scale) {
- *scale = b->Div(b->Sub(max, min),
- XlaHelpers::FloatLiteral(b, data_type,
- quant_max_value - quant_min_value));
+ *scale = xla::Div(xla::Sub(max, min),
+ XlaHelpers::FloatLiteral(
+ b, data_type, quant_max_value - quant_min_value));
xla::XlaOp quant_min =
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
- xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale));
+ xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale));
xla::XlaOp quant_max =
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
xla::XlaOp nudged_zero_point =
- b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
- b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
- b->Round(zero_point_from_min)));
- *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
- *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
+ xla::Select(xla::Le(zero_point_from_min, quant_min), quant_min,
+ xla::Select(xla::Ge(zero_point_from_min, quant_max),
+ quant_max, xla::Round(zero_point_from_min)));
+ *nudged_min = xla::Mul(xla::Sub(quant_min, nudged_zero_point), *scale);
+ *nudged_max = xla::Mul(xla::Sub(quant_max, nudged_zero_point), *scale);
}
xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
@@ -71,14 +72,14 @@ xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input,
const xla::XlaOp& nudged_input_max,
const xla::XlaOp& input_scale) {
xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
- xla::XlaOp inv_scale = b->Div(one, input_scale);
+ xla::XlaOp inv_scale = xla::Div(one, input_scale);
xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f);
- xla::XlaOp clamped = b->Clamp(nudged_input_min, input, nudged_input_max);
- xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min);
+ xla::XlaOp clamped = xla::Clamp(nudged_input_min, input, nudged_input_max);
+ xla::XlaOp clamped_shifted = xla::Sub(clamped, nudged_input_min);
xla::XlaOp rounded =
- b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
- return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
+ xla::Floor(xla::Add(xla::Mul(clamped_shifted, inv_scale), half));
+ return xla::Add(xla::Mul(rounded, input_scale), nudged_input_min);
}
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
@@ -163,11 +164,11 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
xla::XlaOp nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
- xla::XlaOp between_nudged_min_max =
- b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
- xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type),
- gradient_shape.dim_sizes());
- xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp between_nudged_min_max = xla::And(
+ xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
+ xla::XlaOp zeroes = xla::Broadcast(XlaHelpers::Zero(b, data_type),
+ gradient_shape.dim_sizes());
+ xla::XlaOp output = xla::Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output);
}
@@ -249,25 +250,25 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
- xla::XlaOp between_nudged_min_max =
- b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
+ xla::XlaOp between_nudged_min_max = xla::And(
+ xla::Le(nudged_input_min, input), xla::Le(input, nudged_input_max));
xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
- xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes());
- xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes);
+ xla::XlaOp zeroes = xla::Broadcast(zero, gradient_shape.dim_sizes());
+ xla::XlaOp output0 = xla::Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output0);
- xla::XlaOp below_min = b->Lt(input, nudged_input_min);
- xla::XlaOp select1 = b->Select(below_min, gradient, zeroes);
- xla::XlaOp reduce1 = b->ReduceAll(
+ xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
+ xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
+ xla::XlaOp reduce1 = xla::ReduceAll(
XlaHelpers::ConvertElementType(b, select1, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type);
ctx->SetOutput(1, output1);
- xla::XlaOp above_max = b->Gt(input, nudged_input_max);
- xla::XlaOp select2 = b->Select(above_max, gradient, zeroes);
- xla::XlaOp reduce2 = b->ReduceAll(
+ xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
+ xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
+ xla::XlaOp reduce2 = xla::ReduceAll(
XlaHelpers::ConvertElementType(b, select2, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index 933924cad1..b2b00e51e3 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,8 +63,7 @@ class GenericFftOp : public XlaOpKernel {
}
}
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length);
+ xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length);
ctx->SetOutput(0, fft);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
index e4467a0fb1..95faa1d058 100644
--- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -59,11 +60,11 @@ class FillOp : public XlaOpKernel {
xla::XlaOp data = ctx->Input(1);
if (value_shape.dims() > 0) {
CHECK_EQ(value_shape.dims(), 1);
- data = ctx->builder()->Reshape(data, {});
+ data = xla::Reshape(data, {});
}
// Emit the actual computation, which broadcasts the scalar to the
// desired shape.
- auto result = ctx->builder()->Broadcast(data, broadcast);
+ auto result = xla::Broadcast(data, broadcast);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
index d13e25bcdd..5f041be5df 100644
--- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -75,8 +76,8 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
out_shape.AppendShape(indices_shape_no_index_vectors);
out_shape.AppendShape(input_shape_post_axis);
- *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- out_shape.dim_sizes());
+ *gather_output =
+ xla::Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes());
return Status::OK();
}
@@ -142,7 +143,7 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape,
dim_numbers.add_gather_dims_to_operand_dims(i);
}
- *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds);
+ *gather_output = xla::Gather(input, indices, dim_numbers, window_bounds);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 8b9b026643..f5fcf3cacd 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -48,11 +49,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Building If: " << input_types_.size() << " inputs";
- std::vector<xla::XlaOp> inputs(input_types_.size());
std::vector<XlaCompiler::Argument> arguments(input_types_.size());
for (int i = 0; i < input_types_.size(); ++i) {
XlaCompiler::Argument& arg = arguments[i];
DataType type = ctx->input_type(i + 1);
+
if (type == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
@@ -60,7 +61,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.initialized = resource->initialized();
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind();
- OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
arg.type = resource->type();
arg.shape = resource->shape();
@@ -79,7 +79,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = input_types_[i];
arg.shape = ctx->InputShape(i + 1);
- inputs[i] = ctx->Input(i + 1);
VLOG(2) << "Arg type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString();
}
@@ -100,6 +99,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
arguments, &else_result));
+ bool has_tensor_array_gradients = false;
for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) {
for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) {
XlaResource* resource;
@@ -121,9 +121,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
+ if (!resource->tensor_array_gradients().empty())
+ has_tensor_array_gradients = true;
}
}
+ // Recompile the functions to update the argument shapes for tensor arrays.
+ if (has_tensor_array_gradients) {
+ then_result = {};
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_,
+ arguments, &then_result));
+ else_result = {};
+ OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_,
+ arguments, &else_result));
+ }
+
// Check that both branches have identical input shapes.
OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1,
errors::FailedPrecondition("Expected one input shape"));
@@ -175,13 +187,26 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
"Mismatch in resource of then and else branch for resource ", i));
}
- xla::XlaOp outputs =
- b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation,
- b->Tuple(inputs), *else_result.computation);
+ int num_inputs = then_result.input_mapping.size();
+ std::vector<xla::XlaOp> inputs(num_inputs);
+ for (int i = 0; i < num_inputs; ++i) {
+ int input_num = then_result.input_mapping[i] + 1;
+ if (ctx->input_type(input_num) == DT_RESOURCE) {
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
+ OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
+ } else {
+ inputs[i] = ctx->Input(i + 1);
+ }
+ }
+
+ xla::XlaOp outputs = xla::Conditional(
+ ctx->Input(0), xla::Tuple(b, inputs), *then_result.computation,
+ xla::Tuple(b, inputs), *else_result.computation);
// Sets non-variable outputs.
for (int i = 0; i < output_types_.size(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
- xla::XlaOp output_handle = b->GetTupleElement(outputs, i);
+ xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
if (VLOG_IS_ON(2)) {
LOG(INFO) << "Setting output " << i;
auto shape_or = b->GetShape(output_handle);
@@ -209,7 +234,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- b->GetTupleElement(outputs, pos), b));
+ xla::GetTupleElement(outputs, pos), b));
}
VLOG(2) << "If variable: pos: " << update.input_index
<< " name: " << resource->name()
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index 1568b33679..cb4caf7bcb 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -32,23 +33,26 @@ std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b,
auto red = rgb[0];
auto green = rgb[1];
auto blue = rgb[2];
- auto value = b->Max(b->Max(red, green), blue);
- auto minimum = b->Min(b->Min(red, green), blue);
- auto range = b->Sub(value, minimum);
-
- auto zeros = b->Broadcast(zero, shape.dim_sizes());
- auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros);
-
- auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
-
- auto hue = b->Select(b->Eq(green, value),
- b->Add(b->Mul(norm, b->Sub(blue, red)),
- XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
- b->Add(b->Mul(norm, b->Sub(red, green)),
- XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
- hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue);
- hue = b->Select(b->Gt(range, zero), hue, zeros);
- hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue);
+ auto value = xla::Max(xla::Max(red, green), blue);
+ auto minimum = xla::Min(xla::Min(red, green), blue);
+ auto range = xla::Sub(value, minimum);
+
+ auto zeros = xla::Broadcast(zero, shape.dim_sizes());
+ auto saturation =
+ xla::Select(xla::Gt(value, zero), xla::Div(range, value), zeros);
+
+ auto norm = xla::Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
+
+ auto hue =
+ xla::Select(xla::Eq(green, value),
+ xla::Add(xla::Mul(norm, xla::Sub(blue, red)),
+ XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
+ xla::Add(xla::Mul(norm, xla::Sub(red, green)),
+ XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
+ hue = xla::Select(xla::Eq(red, value), xla::Mul(norm, xla::Sub(green, blue)),
+ hue);
+ hue = xla::Select(xla::Gt(range, zero), hue, zeros);
+ hue = xla::Select(xla::Lt(hue, zero), xla::Add(hue, one), hue);
return {hue, saturation, value};
}
@@ -66,15 +70,15 @@ std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b,
auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0);
auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0);
- auto dh = b->Mul(hue, six);
- auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one);
- auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one);
- auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one);
- auto one_minus_s = b->Sub(one, saturation);
+ auto dh = xla::Mul(hue, six);
+ auto dr = xla::Clamp(zero, xla::Sub(xla::Abs(xla::Sub(dh, three)), one), one);
+ auto dg = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, two))), one);
+ auto db = xla::Clamp(zero, xla::Sub(two, xla::Abs(xla::Sub(dh, four))), one);
+ auto one_minus_s = xla::Sub(one, saturation);
- auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value);
- auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value);
- auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value);
+ auto red = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dr)), value);
+ auto green = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, dg)), value);
+ auto blue = xla::Mul(xla::Add(one_minus_s, xla::Mul(saturation, db)), value);
return {red, green, blue};
}
@@ -97,21 +101,21 @@ class RGBToHSVOp : public XlaOpKernel {
xla::XlaBuilder* b = context->builder();
xla::XlaOp input = context->Input(0);
- xla::XlaOp red =
- b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp green =
- b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp blue =
- b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
- /*dimno=*/channel_dim);
+ xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
+ /*limit_index=*/1, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
+ /*limit_index=*/2, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
+ /*limit_index=*/3, /*stride=*/1,
+ /*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
channel_shape);
- context->SetOutput(0, b->ConcatInDim(hsv, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, hsv, channel_dim));
}
};
REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp);
@@ -134,20 +138,20 @@ class HSVToRGBOp : public XlaOpKernel {
xla::XlaBuilder* b = context->builder();
xla::XlaOp input = context->Input(0);
- xla::XlaOp hue =
- b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp saturation =
- b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp value =
- b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
- /*dimno=*/channel_dim);
+ xla::XlaOp hue = xla::SliceInDim(input, /*start_index=*/0,
+ /*limit_index=*/1, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp saturation = xla::SliceInDim(input, /*start_index=*/1,
+ /*limit_index=*/2, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp value = xla::SliceInDim(input, /*start_index=*/2,
+ /*limit_index=*/3, /*stride=*/1,
+ /*dimno=*/channel_dim);
auto rgb = HSVToRGB(context->builder(), {hue, saturation, value},
context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp);
@@ -182,18 +186,20 @@ class AdjustContrastOpV2 : public XlaOpKernel {
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
XlaHelpers::ConvertElementType(b, input, accumulation_type);
- auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *context->GetOrCreateAdd(accumulation_type),
- {height_dim, width_dim});
+ auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *context->GetOrCreateAdd(accumulation_type),
+ {height_dim, width_dim});
auto output = XlaHelpers::ConvertElementType(b, reduce, type);
- output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
+ output =
+ xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
std::vector<int64> broadcast_dims(input_shape.dims() - 2);
std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
broadcast_dims.back() = channel_dim;
- output = b->Add(b->Mul(input, factor),
- b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)),
- broadcast_dims);
+ output =
+ xla::Add(xla::Mul(input, factor),
+ xla::Mul(output, xla::Sub(XlaHelpers::One(b, type), factor)),
+ broadcast_dims);
context->SetOutput(0, output);
}
};
@@ -226,26 +232,26 @@ class AdjustSaturationOp : public XlaOpKernel {
DataType type = context->input_type(0);
- xla::XlaOp red =
- b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp green =
- b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp blue =
- b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
- /*dimno=*/channel_dim);
+ xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
+ /*limit_index=*/1, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
+ /*limit_index=*/2, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
+ /*limit_index=*/3, /*stride=*/1,
+ /*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
channel_shape);
- hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale),
- XlaHelpers::One(b, type));
+ hsv[1] = xla::Clamp(XlaHelpers::Zero(b, type), xla::Mul(hsv[1], scale),
+ XlaHelpers::One(b, type));
auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp);
@@ -276,15 +282,15 @@ class AdjustHueOp : public XlaOpKernel {
DataType type = context->input_type(0);
- xla::XlaOp red =
- b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp green =
- b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1,
- /*dimno=*/channel_dim);
- xla::XlaOp blue =
- b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1,
- /*dimno=*/channel_dim);
+ xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0,
+ /*limit_index=*/1, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1,
+ /*limit_index=*/2, /*stride=*/1,
+ /*dimno=*/channel_dim);
+ xla::XlaOp blue = xla::SliceInDim(input, /*start_index=*/2,
+ /*limit_index=*/3, /*stride=*/1,
+ /*dimno=*/channel_dim);
TensorShape channel_shape = input_shape;
channel_shape.set_dim(channel_dim, 1);
auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0),
@@ -294,12 +300,13 @@ class AdjustHueOp : public XlaOpKernel {
auto one = XlaHelpers::One(b, type);
auto& hue = hsv[0];
- hue = b->Rem(b->Add(hsv[0], delta), one);
- hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue);
+ hue = xla::Rem(xla::Add(hsv[0], delta), one);
+ hue =
+ xla::Select(xla::Lt(hue, zero), xla::Rem(xla::Add(one, hue), one), hue);
auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0));
- context->SetOutput(0, b->ConcatInDim(rgb, channel_dim));
+ context->SetOutput(0, xla::ConcatInDim(b, rgb, channel_dim));
}
};
REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index 9058cbc747..d6bf92fb3d 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -99,46 +101,71 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters(
return dims;
}
+// Form a 2D convolution kernel like:
+// 1 2 3 2 1
+// 2 4 6 4 2
+// 1/9 * 3 6 9 6 3
+// 2 4 6 4 2
+// 1 2 3 2 1
+// by multiplying two 1D kernels of the form:
+// 1/3 * [1 2 3 2 1]
+// If the 2D kernel would be very large, the 1D kernel can be applied once in
+// each dimension due to the symmetry of the kernel along all axis to reduce the
+// computational intensity.
+std::vector<float> Make1DKernel(int64 n) {
+ std::vector<float> kernel(n * 2 - 1);
+ for (int64 i = 0; i < n; ++i) {
+ float v = (i + 1.0f) / n;
+ kernel[i] = v;
+ kernel[n * 2 - 2 - i] = v;
+ }
+ return kernel;
+}
+
+// Kernels with more than 16 spatial elements are considered intense and the
+// kernel should applied to each dimension independently.
+const int64 kMax2DKernelSize = 16;
+
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
gtl::ArraySlice<int64> kernel_size,
int64 channels) {
- // Form a 2D convolution kernel like:
- // 1 2 3 2 1
- // 2 4 6 4 2
- // 1/9 * 3 6 9 6 3
- // 2 4 6 4 2
- // 1 2 3 2 1
- // by multiplying two 1D kernels of the form:
- // 1/3 * [1 2 3 2 1]
- auto make_1d_kernel = [](int64 n) {
- std::vector<float> kernel(n * 2 - 1);
- for (int64 i = 0; i < n; ++i) {
- float v = (i + 1.0f) / n;
- kernel[i] = v;
- kernel[n * 2 - 2 - i] = v;
- }
- return kernel;
- };
-
- xla::XlaOp channels_iota;
- // DT_INT32 Iota will always return status::OK().
- TF_CHECK_OK(
- XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
+ xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
- auto diag = builder->ConvertElementType(
- builder->Eq(
- builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1,
+ auto diag = xla::ConvertElementType(
+ xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
xla::PrimitiveType::F32);
- return builder->Mul(
- builder->Mul(diag,
- builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1}),
- builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])),
+ return xla::Mul(
+ xla::Mul(diag,
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ /*broadcast_dimensions=*/{1}),
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
/*broadcast_dimensions=*/{0});
}
+xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
+ gtl::ArraySlice<int64> kernel_size,
+ int64 channels, int64 dim) {
+ xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+
+ auto diag = xla::ConvertElementType(
+ xla::Eq(
+ xla::Broadcast(channels_iota,
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
+ xla::PrimitiveType::F32);
+ if (dim == 1) {
+ return xla::Mul(
+ diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ /*broadcast_dimensions=*/{1});
+ }
+ return xla::Mul(diag,
+ xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ /*broadcast_dimensions=*/{0});
+}
+
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
const xla::XlaOp& input,
const int num_spatial_dims,
@@ -165,27 +192,49 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
dimension_numbers.add_output_spatial_dimensions(1 + i);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
ResizeConvolutionDims dims =
ComputeResizeConvolutionParameters(in_size, out_size);
- xla::XlaOp kernel =
- MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
- xla::XlaOp output = builder->ConvGeneralDilated(
- input, kernel, dims.stride,
- /*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
- {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
- /*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ xla::XlaOp output;
+ // Split convolutions into independent dimensions if they wmuld be a very
+ // large kernel.
+ if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
+ xla::XlaOp kernel =
+ MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
+ output = xla::ConvGeneralDilated(
+ input, kernel, dims.stride,
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
+ {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ /*lhs_dilation=*/dims.kernel_size,
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ } else {
+ xla::XlaOp kernel0 =
+ MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
+ output = xla::ConvGeneralDilated(
+ input, kernel0, {dims.stride[0], 1},
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
+ /*lhs_dilation=*/{dims.kernel_size[0], 1},
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ xla::XlaOp kernel1 =
+ MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
+ output = xla::ConvGeneralDilated(
+ output, kernel1, {1, dims.stride[1]},
+ /*padding=*/
+ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ /*lhs_dilation=*/{1, dims.kernel_size[1]},
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ }
// Add broadcasts to handle expanding from a size == 1 dimension to a
// size > 1 dimension.
for (int i = 0; i < num_spatial_dims; ++i) {
if (in_size[i] == 1 && out_size[i] > 1) {
- output = builder->Add(output, builder->ConstantR1<float>(out_size[i], 0),
- /*broadcast_dimensions=*/{1 + i});
+ output = xla::Add(output, xla::ConstantR1<float>(builder, out_size[i], 0),
+ /*broadcast_dimensions=*/{1 + i});
}
}
return output;
@@ -214,26 +263,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
}
dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
- xla::XlaOp kernel =
- MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
+ xla::XlaOp output;
+ if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
+ xla::XlaOp kernel =
+ MakeBilinearResizeKernel(builder, dims.kernel_size, channels);
+
+ // Broadcast the input kernel where the forward op expanded from a size == 1
+ // dimension to a size > 1 dimension. This has the effect of summing the
+ // gradient contributions in that dimension.
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ if (in_size[i] == 1 && grad_size[i] > 1) {
+ kernel =
+ xla::Add(kernel, xla::ConstantR1<float>(builder, grad_size[i], 0),
+ /*broadcast_dimensions=*/{i});
+ }
+ }
- // Broadcast the input kernel where the forward op expanded from a size == 1
- // dimension to a size > 1 dimension. This has the effect of summing the
- // gradient contributions in that dimension.
- for (int i = 0; i < num_spatial_dims; ++i) {
- if (in_size[i] == 1 && grad_size[i] > 1) {
- kernel = builder->Add(kernel, builder->ConstantR1<float>(grad_size[i], 0),
- /*broadcast_dimensions=*/{i});
+ output = xla::ConvGeneralDilated(
+ grad, kernel, /*window_strides=*/dims.kernel_size,
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
+ {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ /*lhs_dilation=*/dims.stride,
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ } else {
+ xla::XlaOp kernel0 =
+ MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
+ xla::XlaOp kernel1 =
+ MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
+
+ // Broadcast the input kernel where the forward op expanded from a size == 1
+ // dimension to a size > 1 dimension. This has the effect of summing the
+ // gradient contributions in that dimension.
+ if (in_size[0] == 1 && grad_size[0] > 1) {
+ kernel0 =
+ xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[0], 0),
+ /*broadcast_dimensions=*/{0});
+ }
+ if (in_size[1] == 1 && grad_size[1] > 1) {
+ kernel1 =
+ xla::Add(kernel0, xla::ConstantR1<float>(builder, grad_size[1], 0),
+ /*broadcast_dimensions=*/{1});
}
- }
- xla::XlaOp output = builder->ConvGeneralDilated(
- grad, kernel, /*window_strides=*/dims.kernel_size,
- /*padding=*/
- {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
- {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
- /*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ output = xla::ConvGeneralDilated(
+ grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
+ /*padding=*/
+ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
+ /*lhs_dilation=*/{dims.stride[0], 1},
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+
+ output = xla::ConvGeneralDilated(
+ output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
+ /*padding=*/
+ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
+ /*lhs_dilation=*/{1, dims.stride[1]},
+ /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ }
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
// Opposite of the slice performed by the forward op.
@@ -246,7 +332,7 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
}
}
if (pad_output) {
- output = builder->Pad(output, builder->ConstantR0<float>(0.0f), padding);
+ output = xla::Pad(output, xla::ConstantR0<float>(builder, 0.0f), padding);
}
return output;
}
@@ -302,13 +388,13 @@ class ResizeBilinearOp : public XlaOpKernel {
}
}
if (slice_input) {
- input = b->Slice(input, {0, 0, 0, 0},
- {batch, slice_size[0], slice_size[1], channels},
- {1, 1, 1, 1});
+ input = xla::Slice(input, {0, 0, 0, 0},
+ {batch, slice_size[0], slice_size[1], channels},
+ {1, 1, 1, 1});
}
// Output is always type float.
- input = b->ConvertElementType(input, xla::F32);
+ input = xla::ConvertElementType(input, xla::F32);
// Special Case:
// Instead of doing a ResizeUsingDilationAndConvolution directly,
@@ -438,7 +524,7 @@ class ResizeBilinearGradOp : public XlaOpKernel {
}
}
- output = b->ConvertElementType(output, output_type_);
+ output = xla::ConvertElementType(output, output_type_);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
index 36eb4c7545..f396474858 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -60,19 +60,15 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) {
input_shape.DebugString()));
DataType index_type = output_type(0);
+ xla::PrimitiveType index_xla_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &index_xla_type));
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
-
xla::XlaOp output;
if (is_min_) {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMin(input, index_xla_type, axis);
} else {
- OP_REQUIRES_OK(ctx,
- XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0),
- index_type, axis, &output));
+ output = XlaHelpers::ArgMax(input, index_xla_type, axis);
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 2c2d88486f..22a45b2a11 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -76,14 +77,15 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
// XLA passes <out> to the function, so it is not included here.
std::vector<xla::XlaOp> args;
args.push_back(ctx->Input(0));
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(input_shape.dim_sizes())));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
if (input_shape.dims() > 1) {
// Don't bother passing the output shape and dim for the 1d case, since
// the shape is always a scalar and the dim is always 0.
- args.push_back(b.ConstantLiteral(
- *xla::Literal::CreateR1<int64>(output_shape.dim_sizes())));
- args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0<int32>(dim)));
+ args.push_back(xla::ConstantLiteral(
+ &b, *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ args.push_back(
+ xla::ConstantLiteral(&b, *xla::LiteralUtil::CreateR0<int32>(dim)));
}
xla::Shape xla_shape =
@@ -94,10 +96,12 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
break;
case 2:
- output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape);
+ output =
+ xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
break;
default:
OP_REQUIRES(ctx, false,
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index 1decf7d72d..9e64711051 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -39,12 +39,12 @@ class L2LossOp : public XlaOpKernel {
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
auto t =
XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
- auto square = b->Mul(t, t);
- auto reduce = b->Reduce(square, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), dims);
+ auto square = xla::Mul(t, t);
+ auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), dims);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype);
auto two = XlaHelpers::IntegerLiteral(b, dtype, 2);
- ctx->SetOutput(0, b->Div(deconverted, two));
+ ctx->SetOutput(0, xla::Div(deconverted, two));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
index 0388b4c830..2fb072f827 100644
--- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -90,8 +91,10 @@ class ListDiffOp : public XlaOpKernel {
idx_output.push_back(i);
}
- context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output));
- context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output));
+ context->SetOutput(0,
+ xla::ConstantR1<Tval>(context->builder(), val_output));
+ context->SetOutput(1,
+ xla::ConstantR1<Tidx>(context->builder(), idx_output));
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index 39fbf98a62..dc934543cb 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -50,8 +51,8 @@ class LRNOp : public XlaOpKernel {
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
XlaHelpers::ConvertElementType(builder, input, accumulation_type);
- auto squared = builder->Mul(converted, converted);
- auto reduce = builder->ReduceWindow(
+ auto squared = xla::Mul(converted, converted);
+ auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -59,12 +60,12 @@ class LRNOp : public XlaOpKernel {
auto sqr_sum =
XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
- auto scale = builder->Pow(
- builder->Add(builder->ConstantR0<float>(bias_),
- builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)),
- builder->ConstantR0<float>(-beta_));
+ auto scale = xla::Pow(
+ xla::Add(xla::ConstantR0<float>(builder, bias_),
+ xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum)),
+ xla::ConstantR0<float>(builder, -beta_));
- ctx->SetOutput(0, builder->Mul(input, scale));
+ ctx->SetOutput(0, xla::Mul(input, scale));
}
private:
@@ -138,8 +139,8 @@ class LRNGradOp : public XlaOpKernel {
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
XlaHelpers::ConvertElementType(builder, in_image, accumulation_type);
- auto squared = builder->Mul(converted, converted);
- auto reduce = builder->ReduceWindow(
+ auto squared = xla::Mul(converted, converted);
+ auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -148,17 +149,17 @@ class LRNGradOp : public XlaOpKernel {
XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
auto norm =
- builder->Add(builder->ConstantR0<float>(bias_),
- builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum));
+ xla::Add(xla::ConstantR0<float>(builder, bias_),
+ xla::Mul(xla::ConstantR0<float>(builder, alpha_), sqr_sum));
- auto dy = builder->Mul(
- builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_),
- builder->Div(out_image, norm)),
+ auto dy = xla::Mul(
+ xla::Mul(xla::ConstantR0<float>(builder, -2.0f * alpha_ * beta_),
+ xla::Div(out_image, norm)),
in_grads);
auto converted_dy =
XlaHelpers::ConvertElementType(builder, dy, accumulation_type);
- auto dy_reduce = builder->ReduceWindow(
+ auto dy_reduce = xla::ReduceWindow(
converted_dy, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
@@ -166,10 +167,10 @@ class LRNGradOp : public XlaOpKernel {
auto dy_reduced =
XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0));
- xla::XlaOp gradients = builder->Add(
- builder->Mul(in_image, dy_reduced),
- builder->Mul(in_grads,
- builder->Pow(norm, builder->ConstantR0<float>(-beta_))));
+ xla::XlaOp gradients = xla::Add(
+ xla::Mul(in_image, dy_reduced),
+ xla::Mul(in_grads,
+ xla::Pow(norm, xla::ConstantR0<float>(builder, -beta_))));
ctx->SetOutput(0, gradients);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index 6949b296f4..844080b8cf 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
@@ -70,15 +71,15 @@ class MatMulOp : public XlaOpKernel {
xla::XlaOp b = ctx->Input(1);
if (is_sparse_) {
if (a_type_ == DT_BFLOAT16) {
- a = ctx->builder()->ConvertElementType(a, xla::F32);
+ a = xla::ConvertElementType(a, xla::F32);
}
if (b_type_ == DT_BFLOAT16) {
- b = ctx->builder()->ConvertElementType(b, xla::F32);
+ b = xla::ConvertElementType(b, xla::F32);
}
}
- auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a;
- auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b;
- ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs));
+ auto lhs = (transpose_a_) ? xla::Transpose(a, {1, 0}) : a;
+ auto rhs = (transpose_b_) ? xla::Transpose(b, {1, 0}) : b;
+ ctx->SetOutput(0, xla::Dot(lhs, rhs));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index fbd5dc0fda..e06c87db7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -50,6 +52,7 @@ class MatrixBandPartOp : public XlaOpKernel {
xla::XlaOp num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
+ xla::PrimitiveType index_xla_type = context->input_xla_type(1);
TensorShape batch_shape = input_shape;
batch_shape.RemoveLastDims(2);
@@ -58,33 +61,29 @@ class MatrixBandPartOp : public XlaOpKernel {
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
+ xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m);
+ xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n);
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
-
- auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
- /*broadcast_dimensions=*/{0});
+ auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
+ /*broadcast_dimensions=*/{0});
// If num_lower or num_upper are negative, include all lower/upper
// diagonals.
auto zero_index = XlaHelpers::Zero(builder, index_type);
- num_lower = builder->Select(
- builder->Lt(num_lower, zero_index),
- XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
- num_upper = builder->Select(
- builder->Lt(num_upper, zero_index),
- XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
+ num_lower = xla::Select(xla::Lt(num_lower, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, m),
+ num_lower);
+ num_upper = xla::Select(xla::Lt(num_upper, zero_index),
+ XlaHelpers::IntegerLiteral(builder, index_type, n),
+ num_upper);
- auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
- builder->Le(offset, num_upper));
- indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+ auto indicator = xla::And(xla::Le(xla::Neg(num_lower), offset),
+ xla::Le(offset, num_upper));
+ indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
auto zero_input = XlaHelpers::Zero(builder, input_type);
- auto output = builder->Select(
- indicator, input,
- builder->Broadcast(zero_input, input_shape.dim_sizes()));
+ auto output = xla::Select(
+ indicator, input, xla::Broadcast(zero_input, input_shape.dim_sizes()));
context->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index db53f6fef8..e2ab4b83cf 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -61,14 +63,11 @@ class MatrixSetDiagOp : public XlaOpKernel {
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
- xla::XlaOp iota_m;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
- xla::XlaOp iota_n;
- OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
- auto indicator = builder->Eq(iota_m,
- builder->Broadcast(iota_n, {m}),
- /*broadcast_dimensions=*/{0});
- indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
+ xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m);
+ xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n);
+ auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}),
+ /*broadcast_dimensions=*/{0});
+ indicator = xla::Broadcast(indicator, batch_shape.dim_sizes());
// Broadcast diag up to the input shape. Use an implicit broadcast (Add)
// because we need to broadcast on the right.
@@ -77,10 +76,10 @@ class MatrixSetDiagOp : public XlaOpKernel {
if (min_dim != m) {
diag_broadcast_dims.back() = rank - 1;
}
- diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
- /*broadcast_dimensions=*/diag_broadcast_dims);
+ diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()),
+ /*broadcast_dimensions=*/diag_broadcast_dims);
- auto output = builder->Select(indicator, diag, input);
+ auto output = xla::Select(indicator, diag, input);
context->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
index eaed931464..f4def11d08 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
@@ -30,13 +30,9 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
auto result = TriangularSolve(
- ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true,
+ ctx->Input(0), ctx->Input(1), /*left_side=*/true,
/*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
- if (!result.ok()) {
- ctx->SetStatus(result.status());
- return;
- }
- ctx->SetOutput(0, result.ValueOrDie());
+ ctx->SetOutput(0, result);
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
index 7e9de3ef9b..529959dbd9 100644
--- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
namespace tensorflow {
@@ -27,21 +28,21 @@ class MirrorPadOp : public XlaOpKernel {
xla::StatusOr<xla::XlaOp> DoMirrorPad(const xla::XlaOp& t,
const xla::Shape& original_shape,
- const xla::Literal& pad_literal,
+ const xla::LiteralSlice& pad_literal,
xla::XlaBuilder* b) {
xla::XlaOp accum = t;
for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0;
--dimno) {
- auto t_rev = b->Rev(accum, {dimno});
+ auto t_rev = xla::Rev(accum, {dimno});
TF_ASSIGN_OR_RETURN(int64 lhs_padding,
pad_literal.GetIntegralAsS64({dimno, 0}));
TF_ASSIGN_OR_RETURN(int64 rhs_padding,
pad_literal.GetIntegralAsS64({dimno, 1}));
int64 dim_size = original_shape.dimensions(dimno);
- auto lhs_pad = b->SliceInDim(t_rev, dim_size - 1 - lhs_padding,
- dim_size - 1, 1, dimno);
- auto rhs_pad = b->SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
- accum = b->ConcatInDim({lhs_pad, accum, rhs_pad}, dimno);
+ auto lhs_pad = xla::SliceInDim(t_rev, dim_size - 1 - lhs_padding,
+ dim_size - 1, 1, dimno);
+ auto rhs_pad = xla::SliceInDim(t_rev, 1, 1 + rhs_padding, 1, dimno);
+ accum = xla::ConcatInDim(b, {lhs_pad, accum, rhs_pad}, dimno);
}
return accum;
}
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
index aecaabb6dc..3aed47de26 100644
--- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -76,11 +77,10 @@ class PackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
// Reshape the inputs to have an extra dimension of size 1.
- reshaped_inputs[i] =
- ctx->builder()->Reshape(values[i], child_shape.dim_sizes());
+ reshaped_inputs[i] = xla::Reshape(values[i], child_shape.dim_sizes());
}
- ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis));
+ ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), reshaped_inputs, axis));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
index 7c95475e7b..89fd610bc6 100644
--- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
@@ -63,8 +64,8 @@ class PadOp : public XlaOpKernel {
int before = pad_literal.Get<int32>({i, 0});
int after = pad_literal.Get<int32>({i, 1});
OP_REQUIRES(ctx, before >= 0 && after >= 0,
- errors::InvalidArgument("Paddings must be non-negative: ",
- before, " ", after));
+ errors::InvalidArgument(
+ "Paddings must be non-negative: ", before, " ", after));
dim->set_edge_padding_low(before);
dim->set_edge_padding_high(after);
}
@@ -74,11 +75,10 @@ class PadOp : public XlaOpKernel {
if (ctx->num_inputs() == 3) {
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)),
errors::InvalidArgument("constant_values must be a scalar."));
- ctx->SetOutput(0,
- ctx->builder()->Pad(ctx->Input(0), ctx->Input(2), config));
+ ctx->SetOutput(0, xla::Pad(ctx->Input(0), ctx->Input(2), config));
} else {
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config));
+ ctx->SetOutput(0, xla::Pad(ctx->Input(0), zero, config));
}
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
index f8e7b48a0f..12d9cb9bac 100644
--- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -61,6 +63,9 @@ class PoolingOp : public XlaOpKernel {
Padding padding;
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@@ -113,8 +118,8 @@ class PoolingOp : public XlaOpKernel {
xla::XlaBuilder* const b = ctx->builder();
auto input =
XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_);
- auto reduce = ctx->builder()->ReduceWindow(
- input, InitValue(b), *Reduction(ctx), ksize, stride, padding_);
+ auto reduce = xla::ReduceWindow(input, InitValue(b), *Reduction(ctx), ksize,
+ stride, padding_);
auto pooled = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
ctx->SetOutput(0,
PostProcessOutput(ctx, pooled, input_type(0), input_shape));
@@ -127,6 +132,7 @@ class PoolingOp : public XlaOpKernel {
xla::Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
DataType reduction_type_;
+ xla::PrimitiveType xla_reduction_type_;
};
class MaxPoolOp : public PoolingOp {
@@ -136,7 +142,7 @@ class MaxPoolOp : public PoolingOp {
/*reduction_type=*/ctx->input_type(0)) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::MinValue(b, reduction_type_);
+ return xla::MinValue(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -190,7 +196,7 @@ static xla::XlaOp AvgPoolDivideByCount(
auto divisor =
XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
- return ctx->builder()->Div(output, divisor);
+ return xla::Div(output, divisor);
} else {
// For SAME padding, the padding shouldn't be included in the
// counts. We use another ReduceWindow to find the right counts.
@@ -212,18 +218,18 @@ static xla::XlaOp AvgPoolDivideByCount(
// Build a matrix of all 1s, with the same width/height as the input.
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto ones = ctx->builder()->Broadcast(
+ auto ones = xla::Broadcast(
XlaHelpers::One(ctx->builder(), accumulation_type), input_dim_sizes);
// Perform a ReduceWindow with the same window size, strides, and padding
// to count the number of contributions to each result element.
- auto reduce = ctx->builder()->ReduceWindow(
+ auto reduce = xla::ReduceWindow(
ones, XlaHelpers::Zero(ctx->builder(), accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), window_ksize, window_stride,
xla::Padding::kSame);
auto counts = XlaHelpers::ConvertElementType(ctx->builder(), reduce, dtype);
- return ctx->builder()->Div(output, counts, window_dims);
+ return xla::Div(output, counts, window_dims);
}
}
@@ -235,7 +241,7 @@ class AvgPoolOp : public PoolingOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitValue(xla::XlaBuilder* b) override {
- return XlaHelpers::Zero(b, reduction_type_);
+ return xla::Zero(b, xla_reduction_type_);
}
const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override {
@@ -347,9 +353,9 @@ class MaxPoolGradOp : public XlaOpKernel {
xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2));
auto select = CreateScalarGeComputation(element_type, ctx->builder());
auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
- xla::XlaOp gradients = ctx->builder()->SelectAndScatter(
- input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
- scatter);
+ xla::XlaOp gradients =
+ xla::SelectAndScatter(input, select, ksize_, stride_, xla_padding,
+ out_backprop, init_value, scatter);
ctx->SetOutput(0, gradients);
}
@@ -485,12 +491,12 @@ class AvgPoolGradOp : public XlaOpKernel {
}
auto zero = XlaHelpers::Zero(b, dtype);
- auto padded_gradients = b->Pad(out_backprop_div, zero, padding_config);
+ auto padded_gradients = xla::Pad(out_backprop_div, zero, padding_config);
// in_backprop = padded_gradients <conv> ones
std::vector<int64> ones(num_dims(), 1LL);
auto accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto in_backprop = b->ReduceWindow(
+ auto in_backprop = xla::ReduceWindow(
XlaHelpers::ConvertElementType(b, padded_gradients, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), ksize_,
@@ -614,58 +620,61 @@ class MaxPoolGradGradOp : public XlaOpKernel {
auto b = ctx->builder();
- auto sixteen = b->ConstantR0<uint32>(16);
+ auto sixteen = xla::ConstantR0<uint32>(b, 16);
// in (f32) -> round to bf16 -> f32 for correct bitwidth -> 16-high-bit u32
- auto in_hi = b->BitcastConvertType(
- b->ConvertElementType(b->ConvertElementType(input, xla::BF16),
- xla::F32),
+ auto in_hi = xla::BitcastConvertType(
+ xla::ConvertElementType(xla::ConvertElementType(input, xla::BF16),
+ xla::F32),
xla::U32);
- auto bp_int = b->BitcastConvertType(out_backprop, xla::U32);
- auto bp_hi = b->ShiftRightLogical(bp_int, sixteen);
- auto bp_lo = b->ShiftRightLogical(b->ShiftLeft(bp_int, sixteen), sixteen);
- auto in_hi_bp_hi = b->Add(in_hi, bp_hi); // Want an unsigned add.
- auto in_hi_bp_lo = b->Add(in_hi, bp_lo); // Want an unsigned add.
-
- auto init_value = XlaHelpers::MinValue(b, DT_FLOAT);
+ auto bp_int = xla::BitcastConvertType(out_backprop, xla::U32);
+ auto bp_hi = xla::ShiftRightLogical(bp_int, sixteen);
+ auto bp_lo =
+ xla::ShiftRightLogical(xla::ShiftLeft(bp_int, sixteen), sixteen);
+ auto in_hi_bp_hi = xla::Add(in_hi, bp_hi); // Want an unsigned add.
+ auto in_hi_bp_lo = xla::Add(in_hi, bp_lo); // Want an unsigned add.
+
+ auto init_value = xla::MinValue(b, xla::F32);
// We will reduce by taking the maximal value up to 16 bits (ignoring the lo
// 16 bits of packed-in hi/lo backprop value).
auto rb = b->CreateSubBuilder("GreaterOrEqOf_ByFirst16Bits");
{
// F32 parameters to satisfy lowering type restriction for reduce opcode.
const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
- auto lhs = rb->Parameter(0, scalar, "lhs");
- auto rhs = rb->Parameter(1, scalar, "rhs");
- auto sixteen = rb->ConstantR0<int32>(16);
- auto lhs_criteria = rb->ShiftLeft(
- rb->ShiftRightLogical(rb->BitcastConvertType(lhs, xla::S32), sixteen),
- sixteen);
- auto rhs_criteria = rb->ShiftLeft(
- rb->ShiftRightLogical(rb->BitcastConvertType(rhs, xla::S32), sixteen),
- sixteen);
+ auto lhs = xla::Parameter(rb.get(), 0, scalar, "lhs");
+ auto rhs = xla::Parameter(rb.get(), 1, scalar, "rhs");
+ auto sixteen = xla::ConstantR0<int32>(rb.get(), 16);
+ auto lhs_criteria =
+ xla::ShiftLeft(xla::ShiftRightLogical(
+ xla::BitcastConvertType(lhs, xla::S32), sixteen),
+ sixteen);
+ auto rhs_criteria =
+ xla::ShiftLeft(xla::ShiftRightLogical(
+ xla::BitcastConvertType(rhs, xla::S32), sixteen),
+ sixteen);
// Must use a F32 comparison, because S32 would not work for negatives.
- rb->Select(rb->Ge(rb->BitcastConvertType(lhs_criteria, xla::F32),
- rb->BitcastConvertType(rhs_criteria, xla::F32)),
- lhs, rhs);
+ xla::Select(xla::Ge(xla::BitcastConvertType(lhs_criteria, xla::F32),
+ xla::BitcastConvertType(rhs_criteria, xla::F32)),
+ lhs, rhs);
}
auto reduce = rb->BuildAndNoteError();
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
auto pooled_hi =
- b->ReduceWindow(b->BitcastConvertType(in_hi_bp_hi, xla::F32),
- init_value, reduce, ksize_, stride_, xla_padding);
+ xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_hi, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
auto pooled_lo =
- b->ReduceWindow(b->BitcastConvertType(in_hi_bp_lo, xla::F32),
- init_value, reduce, ksize_, stride_, xla_padding);
+ xla::ReduceWindow(xla::BitcastConvertType(in_hi_bp_lo, xla::F32),
+ init_value, reduce, ksize_, stride_, xla_padding);
auto grads_hi =
- b->ShiftLeft(b->BitcastConvertType(pooled_hi, xla::U32), sixteen);
- auto grads_lo = b->ShiftRightLogical(
- b->ShiftLeft(b->BitcastConvertType(pooled_lo, xla::U32), sixteen),
+ xla::ShiftLeft(xla::BitcastConvertType(pooled_hi, xla::U32), sixteen);
+ auto grads_lo = xla::ShiftRightLogical(
+ xla::ShiftLeft(xla::BitcastConvertType(pooled_lo, xla::U32), sixteen),
sixteen);
- auto grads = b->Add(grads_hi, grads_lo); // Want an unsigned add.
+ auto grads = xla::Add(grads_hi, grads_lo); // Want an unsigned add.
xla::PrimitiveType element_type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
- ctx->SetOutput(0, b->BitcastConvertType(grads, element_type));
+ ctx->SetOutput(0, xla::BitcastConvertType(grads, element_type));
}
protected:
@@ -694,5 +703,18 @@ REGISTER_XLA_OP(Name("MaxPoolGradGradV2")
.CompileTimeConstInput("strides"),
MaxPool2DGradGradOp);
+class MaxPool3DGradGradOp : public MaxPoolGradGradOp {
+ public:
+ explicit MaxPool3DGradGradOp(OpKernelConstruction* ctx)
+ : MaxPoolGradGradOp(ctx, /*num_spatial_dims=*/3) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ }
+};
+REGISTER_XLA_OP(Name("MaxPool3DGradGrad").TypeConstraint("T", DT_FLOAT),
+ MaxPool3DGradGradOp);
+
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
new file mode 100644
index 0000000000..de9068a640
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/qr.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+class QROp : public XlaOpKernel {
+ public:
+ explicit QROp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ bool full_matrices;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices));
+ OP_REQUIRES(
+ ctx, full_matrices,
+ errors::Unimplemented("full_matrices=False case of QR decomposition is "
+ "not implemented in TF/XLA"));
+ }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto result = QRDecomposition(ctx->Input(0));
+ if (!result.ok()) {
+ ctx->SetStatus(result.status());
+ return;
+ }
+ ctx->SetOutput(0, result.ValueOrDie().q);
+ ctx->SetOutput(1, result.ValueOrDie().r);
+ }
+};
+
+REGISTER_XLA_OP(Name("Qr").TypeConstraint("T", kFloatTypes), QROp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 661cd5923e..e88221e4f4 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
@@ -28,82 +31,115 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
- OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
- errors::InvalidArgument("num_bits is out of range: ", num_bits_,
- " with signed_input_ ", signed_input_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- // Comments taken from semantics description at
- // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize
- //
- // ... we find m such that
- //
- // m = max(abs(input_min), abs(input_max)) if range_given is true,
- // m = max(abs(min_elem(input)),
- // abs(max_elem(input))) otherwise.
+ xla::PrimitiveType xla_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(data_type, &xla_type));
+
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp input_min, input_max;
+
+ // The implementation follows
+ // tensorflow/core/kernels/quantize_and_dequantize_op.h closely.
+ xla::XlaOp min_range, max_range;
if (range_given_) {
- double input_min_value, input_max_value;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value));
- input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
- input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
+ min_range = ctx->Input(1);
+ max_range = ctx->Input(2);
} else {
const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
- input_min =
- b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
- input_max =
- b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
+ min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin);
+ max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax);
}
- xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max));
-
- // Next, we choose our fixed-point quantization buckets, [min_fixed,
- // max_fixed]. If signed_input is true, this is
- //
- // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1),
- // (1 << (num_bits - 1)) - 1].
- //
- // Otherwise, if signed_input is false, the fixed-point range is
- //
- // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1].
- int64 min_fixed, max_fixed;
+
+ xla::XlaOp num_bits;
+ if (num_bits_ < 0) {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() == 4,
+ errors::Internal("Expected 4 inputs to QuantizeAndDequantize"));
+ num_bits = ctx->Input(3);
+ } else {
+ num_bits = xla::ConstantR0<int32>(b, num_bits_);
+ }
+
+ const xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
+ const xla::XlaOp one = XlaHelpers::One(b, data_type);
+ const xla::XlaOp two = XlaHelpers::FloatLiteral(b, data_type, 2.0);
+ const xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5);
+
+ // Calculate the range for the simulated integer quantization:
+ // e.g. [-128,127] for signed = true, num_bits = 8,
+ // or [0, 255] for signed = false, num_bits = 8.
+ // We do this in floating point for hardware that does not have 64-bit
+ // integer support.
+ xla::XlaOp min_quantized, max_quantized;
if (signed_input_) {
- min_fixed = -((1LL << (num_bits_ - 1)) - 1);
- max_fixed = (1LL << (num_bits_ - 1)) - 1;
+ min_quantized =
+ -Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32>(b, 1),
+ xla_type));
+ max_quantized =
+ Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32>(b, 1),
+ xla_type)) -
+ one;
} else {
- min_fixed = 0;
- max_fixed = (1LL << num_bits_) - 1;
+ min_quantized = zero;
+ max_quantized = Pow(two, ConvertElementType(num_bits, xla_type)) - one;
}
- // From this we compute our scaling factor, s:
- //
- // s = (max_fixed - min_fixed) / (2 * m).
- xla::XlaOp s =
- b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
- b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
+ // Determine the maximum scaling factor that would scale
+ // [min_range, max_range] to not exceed [min_quantized, max_quantized],
+ // while keeping 0 unchanged.
+ xla::XlaOp scale_from_min_side =
+ Select(Gt(min_quantized * min_range, zero), min_quantized / min_range,
+ xla::MaxFiniteValue(b, xla_type));
+ xla::XlaOp scale_from_max_side =
+ Select(Gt(max_quantized * max_range, zero), max_quantized / max_range,
+ xla::MaxFiniteValue(b, xla_type));
- // Now we can quantize and dequantize the elements of our tensor. An element
- // e is transformed into e':
- //
- // e' = (e * s).round_to_nearest() / s.
- xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s);
+ // Note: Avoids changing the side of the range that determines scale.
+ xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side);
+ xla::XlaOp scale = Select(cond, scale_from_min_side, scale_from_max_side);
+ xla::XlaOp inverse_scale =
+ Select(cond, min_range / min_quantized, max_range / max_quantized);
+ min_range = Select(cond, min_range, min_quantized * inverse_scale);
+ max_range = Select(cond, max_quantized * inverse_scale, max_range);
+ if (range_given_) {
+ // Note: The clamping here is to avoid overflow in the quantized type.
+ // The semantics of the op does not guarantee to clamp to the specified
+ // min_range and max_range - because we may have changed either min_range
+ // or max_range.
+ // No need to clamp to min_range and max_range if range_given_ == false as
+ // in that case they were measured from the tensor.
+ input = Clamp(min_range, input, max_range);
+ }
+ xla::XlaOp result =
+ Floor((input - min_range) * scale + half) * inverse_scale + min_range;
ctx->SetOutput(0, result);
}
- int64 num_bits_;
+ protected:
+ int64 num_bits_ = -1;
bool signed_input_;
bool range_given_;
};
-REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp);
+class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp {
+ public:
+ explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
+ : QuantizeAndDequantizeOp(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
+ OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
+ errors::InvalidArgument("num_bits is out of range: ", num_bits_,
+ " with signed_input_ ", signed_input_));
+ }
+};
+
+REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeV2Op);
+REGISTER_XLA_OP(Name("QuantizeAndDequantizeV3"), QuantizeAndDequantizeOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 5f5bd58637..607cad798a 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -17,11 +17,17 @@ limitations under the License.
// TODO(misard,phawkins): handle random number generator seeds/states correctly.
// TODO(misard,phawkins): add tests.
+#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/random.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -42,8 +48,8 @@ class RandomUniformOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype),
- XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = xla::RngUniform(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -55,6 +61,142 @@ class RandomUniformOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"),
RandomUniformOp);
+class RandomShuffleOp : public XlaOpKernel {
+ public:
+ explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ xla::XlaOp input = ctx->Input(0);
+ TensorShape input_shape = ctx->InputShape(0);
+ const int64 n = input_shape.dim_size(0);
+ int64 num_elements = 1;
+ for (tensorflow::TensorShapeDim dimension : input_shape) {
+ num_elements *= dimension.size;
+ }
+
+ if (num_elements <= 1 || n <= 1) {
+ // No shuffling is required, so copy input directly to output
+ ctx->SetOutput(0, input);
+ return;
+ }
+
+ if (input_shape.dims() == 1) {
+ // For R1s, shuffle values by sorting instead of the obvious Fisher-Yates
+ // algorithm. Fisher-Yates is simple to implement and correct, but not
+ // easily parallelizable. For a sufficiently parallel architecture, it is
+ // faster to sort many times, than Fisher-Yates shuffle once.
+
+ // Shuffle values by assigning each value a random key and sorting the
+ // keys. Keys can collide causing detectable patterns in the shuffled
+ // output. Collisions translates into more ascending sub-sequences in the
+ // shuffled output than would be expected by chance. To avoid collisions,
+ // the number of possible key values must be sufficiently large.
+
+ // How are more than 2^32 keys created? In each loop iteration, the
+ // algorithm sorts by random keys. Conceptually, the earlier iterations
+ // are sorting on the lower-order bits of larger keys that are never
+ // actually assembled.
+
+ // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is
+ // the number of possible keys and n is the number of values. If d = n^2,
+ // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit
+ // as n goes to infinity is zero.
+
+ // This implementation ensures that the key-space is greater than or equal
+ // to the cube of the number of values. The risk of collisions can be
+ // further reduced by increasing Exponent at the expense of
+ // performance.
+
+ // For Exponent = 2, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is
+ // about 1/2.
+
+ // For Exponent = 3, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is
+ // about 1/3255.
+
+ // For Exponent = 4, the expected number of collisions per shuffle is
+ // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is
+ // about 1/132622.
+ constexpr int Exponent = 3;
+ const int rounds = static_cast<int>(
+ std::ceil(Exponent * std::log(num_elements) / std::log(kuint32max)));
+
+ const xla::Shape key_shape =
+ xla::ShapeUtil::MakeShape(xla::U32, {num_elements});
+ xla::XlaOp zero = xla::ConstantR0(builder, 0U);
+
+ // Unfortunately, xla::RngUniform gives values in the half open interval
+ // rather than the closed interval, so instead of 2^32 possible keys there
+ // are only 2^32 - 1 (kuint32max).
+ xla::XlaOp max_value = xla::ConstantR0(builder, kuint32max);
+
+ xla::XlaOp curr = input;
+ for (int i = 0; i < rounds; ++i) {
+ xla::XlaOp keys = xla::RngUniform(zero, max_value, key_shape);
+ xla::XlaOp sorted = xla::Sort(keys, curr);
+ curr = xla::GetTupleElement(sorted, 1);
+ }
+
+ ctx->SetOutput(0, curr);
+ return;
+ }
+
+ // The Fisher-Yates algorithm.
+
+ // Generate the random swaps for the indices.
+ auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n});
+ auto swaps =
+ xla::RngUniform(xla::ConstantR0<int32>(builder, 0),
+ xla::ConstantR0<int32>(builder, n), swaps_shape);
+
+ // Generate range(n) as the initial value for the indices to be swapped.
+ xla::XlaOp indices = xla::Iota(builder, xla::S32, n);
+
+ // Swap the indices at i and swaps[i].
+ auto swap_body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto swaps = loop_vars[0];
+ auto indices = loop_vars[1];
+ i = xla::Reshape(i, {1});
+ // temp = indices[i]
+ auto temp = xla::DynamicSlice(indices, i, {1});
+ // swap_index = swaps[i]
+ auto swap_index = xla::DynamicSlice(swaps, i, {1});
+ // swap_value = indices[swaps[i]]
+ auto swap_value = xla::DynamicSlice(indices, swap_index, {1});
+ // indices[i] = indices[swaps[i]]
+ indices = xla::DynamicUpdateSlice(indices, swap_value, i);
+ // indices[swaps[i]] = temp
+ indices = xla::DynamicUpdateSlice(indices, temp, swap_index);
+ return std::vector<xla::XlaOp>{swaps, indices};
+ };
+ // for i in range(n):
+ auto swap_loop_result =
+ XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices},
+ "indices_swap_loop", builder)
+ .ValueOrDie();
+ auto swapped_indices = swap_loop_result[1];
+
+ // Gather the data using the swapped indices as the shuffled order.
+ auto indices_tensor_shape = TensorShape({n});
+ DataType type = ctx->expected_output_dtype(0);
+ xla::XlaOp gather;
+ OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices,
+ indices_tensor_shape,
+ /*axis=*/0, /*indices_are_nd=*/false, type,
+ DT_INT32, builder, &gather));
+ ctx->SetOutput(0, gather);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp);
+};
+
+REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp);
+
class RandomUniformIntOp : public XlaOpKernel {
public:
explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
@@ -77,7 +219,7 @@ class RandomUniformIntOp : public XlaOpKernel {
auto minval = ctx->Input(1);
auto maxval = ctx->Input(2);
- ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape));
+ ctx->SetOutput(0, xla::RngUniform(minval, maxval, xla_shape));
}
private:
@@ -103,8 +245,8 @@ class RandomStandardNormalOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
// Normal distribution with a mean of 0 and a standard deviation of 1:
- xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype),
- XlaHelpers::One(b, dtype), xla_shape);
+ xla::XlaOp result = xla::RngNormal(XlaHelpers::Zero(b, dtype),
+ XlaHelpers::One(b, dtype), xla_shape);
ctx->SetOutput(0, result);
}
@@ -127,63 +269,21 @@ class TruncatedNormalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
xla::Shape xla_shape;
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
- xla::Shape xla_element_shape =
- xla::ShapeUtil::MakeShape(xla_shape.element_type(), {});
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
- xla::XlaOp stddev = XlaHelpers::One(b, dtype);
- xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape);
-
- auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) {
- return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0);
- };
- auto out_of_range_mask = [two_sd](xla::XlaOp candidate,
- xla::XlaBuilder* b) {
- xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b));
- xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b));
- return b->Or(too_large, too_small);
- };
- // The algorithm we're using is roughly:
- //
- // while (any(candidate < mean-2*sd || candidate > mean+2*sd)) {
- // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd
- // candidate = select(out_of_range_mask, rng_normal(), candidate)
- // }
- std::unique_ptr<xla::XlaBuilder> test_builder =
- b->CreateSubBuilder("truncated_normal_test");
- {
- auto* b = test_builder.get();
- xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
- out_of_range_mask(candidate, b);
- OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status());
- }
-
- std::unique_ptr<xla::XlaBuilder> body_builder =
- b->CreateSubBuilder("truncated_normal_body");
- {
- auto* b = body_builder.get();
- xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate");
- xla::XlaOp to_resample = out_of_range_mask(candidate, b);
- xla::XlaOp mean = XlaHelpers::Zero(b, dtype);
- xla::XlaOp stddev = XlaHelpers::One(b, dtype);
- b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate);
- }
-
- xla::StatusOr<xla::XlaComputation> test_computation = test_builder->Build();
- OP_REQUIRES_OK(ctx, test_computation.status());
- xla::StatusOr<xla::XlaComputation> body_computation = body_builder->Build();
- OP_REQUIRES_OK(ctx, body_computation.status());
- xla::XlaOp result = b->While(test_computation.ValueOrDie(),
- body_computation.ValueOrDie(), candidate);
-
- ctx->SetOutput(0, result);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
+ xla::XlaOp min_positive =
+ XlaHelpers::FloatLiteral(b, dtype, std::numeric_limits<float>::min());
+ auto uniform = xla::RngUniform(min_positive, one, xla_shape);
+ ctx->SetOutput(0, TruncatedNormal(uniform));
}
};
-REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"),
+REGISTER_XLA_OP(Name("TruncatedNormal")
+ .CompileTimeConstInput("shape")
+ .TypeConstraint("dtype", DT_FLOAT),
TruncatedNormalOp);
-} // anonymous namespace
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 08894489ac..76bd1e62aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -98,10 +99,10 @@ class ReduceWindowOp : public XlaOpKernel {
{
std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("wrapper");
- auto x = cb->Parameter(0, scalar_shape, "x");
- auto y = cb->Parameter(1, scalar_shape, "y");
- auto outputs = cb->Call(*reducer.computation, {x, y});
- cb->GetTupleElement(outputs, 0);
+ auto x = xla::Parameter(cb.get(), 0, scalar_shape, "x");
+ auto y = xla::Parameter(cb.get(), 1, scalar_shape, "y");
+ auto outputs = xla::Call(cb.get(), *reducer.computation, {x, y});
+ xla::GetTupleElement(outputs, 0);
xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(context, result.status());
wrapper = std::move(result.ValueOrDie());
@@ -112,7 +113,7 @@ class ReduceWindowOp : public XlaOpKernel {
padding[i] = {padding_low_[i], padding_high_[i]};
}
- xla::XlaOp output = builder->ReduceWindowWithGeneralPadding(
+ xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), wrapper, window_dimensions_,
window_strides_, padding);
context->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 0f42563779..be7f2bce8c 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -19,7 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -31,11 +33,11 @@ class SumOp : public XlaReductionOp {
: XlaReductionOp(ctx,
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Add(scalar_lhs, scalar_rhs);
+ xla::Add(scalar_lhs, scalar_rhs);
}
};
@@ -48,12 +50,12 @@ class ProdOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::One(builder, reduction_type_);
+ return xla::One(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Mul(scalar_lhs, scalar_rhs);
+ xla::Mul(scalar_lhs, scalar_rhs);
}
};
@@ -66,12 +68,12 @@ class MinOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MaxValue(builder, reduction_type_);
+ return xla::MaxValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Min(scalar_lhs, scalar_rhs);
+ xla::Min(scalar_lhs, scalar_rhs);
}
};
@@ -83,12 +85,12 @@ class MaxOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::MinValue(builder, reduction_type_);
+ return xla::MinValue(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Max(scalar_lhs, scalar_rhs);
+ xla::Max(scalar_lhs, scalar_rhs);
}
};
@@ -101,11 +103,11 @@ class MeanOp : public XlaReductionOp {
XlaHelpers::SumAccumulationType(ctx->input_type(0))) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return XlaHelpers::Zero(builder, reduction_type_);
+ return xla::Zero(builder, xla_reduction_type_);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Add(scalar_lhs, scalar_rhs);
+ xla::Add(scalar_lhs, scalar_rhs);
}
xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder,
@@ -113,7 +115,7 @@ class MeanOp : public XlaReductionOp {
int64 num_elements_reduced) override {
auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
num_elements_reduced);
- return builder->Div(reduce_output, divisor);
+ return reduce_output / divisor;
}
};
@@ -126,12 +128,12 @@ class AllOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return builder->ConstantR0<bool>(true);
+ return xla::ConstantR0<bool>(builder, true);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->And(scalar_lhs, scalar_rhs);
+ xla::And(scalar_lhs, scalar_rhs);
}
};
@@ -143,12 +145,12 @@ class AnyOp : public XlaReductionOp {
: XlaReductionOp(ctx, ctx->input_type(0)) {}
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
- return builder->ConstantR0<bool>(false);
+ return xla::ConstantR0<bool>(builder, false);
}
void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs,
const xla::XlaOp& scalar_rhs) override {
- builder->Or(scalar_lhs, scalar_rhs);
+ xla::Or(scalar_lhs, scalar_rhs);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 2ecfb854a1..8333f9b288 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -64,6 +64,7 @@ class XlaReductionOp : public XlaOpKernel {
protected:
DataType reduction_type_;
+ xla::PrimitiveType xla_reduction_type_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index 4fd5bfd039..ed1d1c6610 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -19,7 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
@@ -31,6 +32,8 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx,
OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+ OP_REQUIRES_OK(
+ ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
// Unless BuildFinalizer is overridden the reduction has no
@@ -56,9 +59,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
// Evaluate the constant, reshaping to a 1-vector if it is a scalar.
xla::Literal axes_literal;
- OP_REQUIRES_OK(ctx,
- ctx->ConstantInputReshaped(
- 1, {axes_tensor_shape.num_elements()}, &axes_literal));
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputReshaped(1, {axes_tensor_shape.num_elements()},
+ &axes_literal));
VLOG(1) << "data shape: " << data_shape.DebugString();
VLOG(1) << "axes : " << axes_literal.ToString();
@@ -101,20 +104,20 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type));
- auto data = b->ConvertElementType(ctx->Input(0), type);
+ auto data = xla::ConvertElementType(ctx->Input(0), type);
// Call virtual method to get the initial value.
- auto initial = b->ConvertElementType(InitialValue(b), type);
+ auto initial = xla::ConvertElementType(InitialValue(b), type);
// Make two scalar parameters of the desired type for the lambda.
- auto rx = r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
- auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
+ auto rx = xla::Parameter(&r, 0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ auto ry = xla::Parameter(&r, 1, xla::ShapeUtil::MakeShape(type, {}), "y");
// Call virtual method to build the reduction lambda.
BuildReducer(&r, rx, ry);
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
- auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes);
+ auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
auto finalized = BuildFinalizer(b, deconverted, num_elements_reduced);
- auto result = keep_dims_ ? b->Reshape(finalized, final_shape) : finalized;
+ auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
index ba7d484d53..f4b804e546 100644
--- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -34,7 +34,7 @@ class ReluOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
- ctx->SetOutput(0, builder->Max(zero, ctx->Input(0)));
+ ctx->SetOutput(0, xla::Max(zero, ctx->Input(0)));
}
};
@@ -46,7 +46,7 @@ class Relu6Op : public XlaOpKernel {
xla::XlaBuilder* builder = ctx->builder();
auto zero = XlaHelpers::Zero(builder, input_type(0));
auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6);
- ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six));
+ ctx->SetOutput(0, xla::Clamp(zero, ctx->Input(0), six));
}
};
@@ -59,9 +59,9 @@ class ReluGradOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
- b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
- const auto pred = b->Gt(ctx->Input(1), zero);
- ctx->SetOutput(0, b->Select(pred, ctx->Input(0), zero));
+ xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto pred = xla::Gt(ctx->Input(1), zero);
+ ctx->SetOutput(0, xla::Select(pred, ctx->Input(0), zero));
}
};
@@ -74,12 +74,12 @@ class Relu6GradOp : public XlaOpKernel {
xla::XlaBuilder* b = ctx->builder();
const TensorShape shape = ctx->InputShape(0);
const auto zero =
- b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
- const auto six = b->Broadcast(
+ xla::Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes());
+ const auto six = xla::Broadcast(
XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes());
- auto out =
- b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)),
- ctx->Input(0), zero);
+ auto out = xla::Select(
+ xla::And(xla::Lt(ctx->Input(1), six), xla::Gt(ctx->Input(1), zero)),
+ ctx->Input(0), zero);
ctx->SetOutput(0, out);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
index af4d64b159..354fec9be7 100644
--- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -19,7 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -90,8 +91,7 @@ class ReshapeOp : public XlaOpKernel {
VLOG(1) << "Reshape " << input_shape.DebugString() << " "
<< shape.DebugString();
- ctx->SetOutput(0,
- ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes()));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index a711278638..5be70a4ded 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,15 +63,24 @@ class RetvalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
- TensorShape representation_shape =
- tc.is_entry_computation()
- ? tc.RepresentationShape(shape, ctx->input_type(0))
- : shape;
+ ctx->SetStatus(is_constant.status());
+ TensorShape representation_shape;
+ if (tc.is_entry_computation()) {
+ xla::StatusOr<TensorShape> shape_or_status =
+ tc.RepresentationShape(shape, ctx->input_type(0));
+ if (!shape_or_status.ok()) {
+ ctx->SetStatus(shape_or_status.status());
+ return;
+ } else {
+ representation_shape = shape_or_status.ValueOrDie();
+ }
+ } else {
+ representation_shape = shape;
+ }
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
- output =
- ctx->builder()->Reshape(input, representation_shape.dim_sizes());
+ output = xla::Reshape(input, representation_shape.dim_sizes());
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
@@ -78,8 +88,8 @@ class RetvalOp : public XlaOpKernel {
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
- output = ctx->builder()->GetTupleElement(
- ctx->builder()->Tuple({output}), 0);
+ output =
+ xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
index 2872a3c4d4..ec15b4cc7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc
@@ -19,7 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -62,7 +63,7 @@ class ReverseOp : public XlaOpKernel {
}
}
- ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), dimensions));
+ ctx->SetOutput(0, xla::Rev(ctx->Input(0), dimensions));
}
};
@@ -100,7 +101,7 @@ class ReverseV2Op : public XlaOpKernel {
x_shape.dims(), ")."));
}
- ctx->SetOutput(0, ctx->builder()->Rev(ctx->Input(0), axes));
+ ctx->SetOutput(0, xla::Rev(ctx->Input(0), axes));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 5d1c052684..c810456f94 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -85,103 +87,96 @@ class ReverseSequenceOp : public XlaOpKernel {
auto condition_builder =
builder->CreateSubBuilder("reverse_sequence_condition");
{
- auto param = condition_builder->Parameter(0, tuple_shape, "param");
- auto i = condition_builder->GetTupleElement(param, 0);
- condition_builder->Lt(
- i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
- batch_size));
+ auto param =
+ xla::Parameter(condition_builder.get(), 0, tuple_shape, "param");
+ auto i = xla::GetTupleElement(param, 0);
+ xla::Lt(i, XlaHelpers::IntegerLiteral(condition_builder.get(),
+ seq_lens_type, batch_size));
}
auto condition = condition_builder->Build();
OP_REQUIRES_OK(context, condition.status());
auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
{
- auto param = body_builder->Parameter(0, tuple_shape, "param");
- auto i = body_builder->GetTupleElement(param, 0);
- auto seq_lens = body_builder->GetTupleElement(param, 1);
- auto output = body_builder->GetTupleElement(param, 2);
+ auto param = xla::Parameter(body_builder.get(), 0, tuple_shape, "param");
+ auto i = xla::GetTupleElement(param, 0);
+ auto seq_lens = xla::GetTupleElement(param, 1);
+ auto output = xla::GetTupleElement(param, 2);
// seq_len is the sequence length of the current batch element (rank 1)
- auto seq_len = body_builder->DynamicSlice(
- seq_lens, body_builder->Reshape(i, {1}), {1});
+ auto seq_len = xla::DynamicSlice(seq_lens, xla::Reshape(i, {1}), {1});
// Indices is the offset of the batch element in the input.
- auto batch_element_indices = body_builder->Broadcast(
- XlaHelpers::Zero(body_builder.get(), seq_lens_type),
- {input_shape.dims()});
- batch_element_indices = body_builder->DynamicUpdateSlice(
- batch_element_indices, body_builder->Reshape(i, {1}),
- body_builder->Reshape(
- XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
- batch_dim_),
- {1}));
+ auto batch_element_indices =
+ xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {input_shape.dims()});
+ batch_element_indices = xla::DynamicUpdateSlice(
+ batch_element_indices, xla::Reshape(i, {1}),
+ xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
+ seq_lens_type, batch_dim_),
+ {1}));
// Slice out the current batch element and pad it out in the sequence
// dimension.
TensorShape slice_shape = input_shape;
slice_shape.set_dim(batch_dim_, 1);
slice_shape.set_dim(seq_dim_, max_seq_len);
- auto slice = body_builder->DynamicSlice(output, batch_element_indices,
- slice_shape.dim_sizes());
+ auto slice = xla::DynamicSlice(output, batch_element_indices,
+ slice_shape.dim_sizes());
auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims());
padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high(
slice_shape.dim_size(seq_dim_));
- slice = body_builder->Pad(
- slice, XlaHelpers::Zero(body_builder.get(), input_type),
- padding_config);
+ slice = xla::Pad(slice, XlaHelpers::Zero(body_builder.get(), input_type),
+ padding_config);
// Now slice out the reversed sequence from its actual start.
// sequence_start_indices is the offset of the start of the reversed
// sequence in the input. The slice will go into the padding, however, we
// will mask off these elements and replace them with elements from the
// original input so their values do not matter.
- auto sequence_start_indices = body_builder->Broadcast(
- XlaHelpers::Zero(body_builder.get(), seq_lens_type),
- {slice_shape.dims()});
- sequence_start_indices = body_builder->DynamicUpdateSlice(
+ auto sequence_start_indices =
+ xla::Broadcast(XlaHelpers::Zero(body_builder.get(), seq_lens_type),
+ {slice_shape.dims()});
+ sequence_start_indices = xla::DynamicUpdateSlice(
sequence_start_indices,
- body_builder->Sub(XlaHelpers::IntegerLiteral(
- body_builder.get(), seq_lens_type, max_seq_len),
- seq_len),
- body_builder->Reshape(
- XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
- seq_dim_),
- {1}));
- slice = body_builder->DynamicSlice(slice, sequence_start_indices,
- slice_shape.dim_sizes());
+ xla::Sub(XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
+ max_seq_len),
+ seq_len),
+ xla::Reshape(XlaHelpers::IntegerLiteral(body_builder.get(),
+ seq_lens_type, seq_dim_),
+ {1}));
+ slice = xla::DynamicSlice(slice, sequence_start_indices,
+ slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
- output = body_builder->DynamicUpdateSlice(output, slice,
- batch_element_indices);
+ output = xla::DynamicUpdateSlice(output, slice, batch_element_indices);
- body_builder->Tuple(
- {body_builder->Add(
- i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
+ xla::Tuple(
+ body_builder.get(),
+ {xla::Add(i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
seq_lens, output});
}
auto body = body_builder->Build();
OP_REQUIRES_OK(context, body.status());
- auto loop_output = builder->While(
+ auto loop_output = xla::While(
condition.ValueOrDie(), body.ValueOrDie(),
- builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
- builder->Rev(input, {seq_dim_})}));
- auto output = builder->GetTupleElement(loop_output, 2);
+ xla::Tuple(builder, {XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
+ xla::Rev(input, {seq_dim_})}));
+ auto output = xla::GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
- xla::XlaOp iota;
- OP_REQUIRES_OK(
- context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
+ xla::XlaOp iota =
+ xla::Iota(builder, seq_lens_xla_shape.element_type(), max_seq_len);
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
- auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});
+ auto mask = xla::Lt(iota, xla::Reshape(seq_lens, dims), {seq_dim_});
// Broadcast the mask up to the input shape.
- mask =
- builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
- input_shape.dim_sizes()));
+ mask = xla::Or(mask, xla::Broadcast(xla::ConstantR0<bool>(builder, false),
+ input_shape.dim_sizes()));
- output = builder->Select(mask, output, input);
+ output = xla::Select(mask, output, input);
context->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index 1819fb5433..27ab3e1bf5 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
@@ -100,7 +101,7 @@ class ScanOp : public XlaOpKernel {
init = XlaHelpers::One(builder, dtype);
reducer = ctx->GetOrCreateMul(dtype);
}
- auto output = builder->ReduceWindowWithGeneralPadding(
+ auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
*reducer, window_dims, window_strides, padding);
output =
@@ -110,12 +111,12 @@ class ScanOp : public XlaOpKernel {
// of all the input elements. Slice off this extra "last" element.
if (exclusive_) {
if (reverse_) {
- output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1,
- 1, axis);
+ output =
+ xla::SliceInDim(output, 1, input_shape.dim_size(axis) + 1, 1, axis);
} else {
output =
- builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
+ xla::SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
}
}
ctx->SetOutput(0, output);
diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
index f2c63b4f90..14709bb6cb 100644
--- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -103,8 +104,8 @@ class ScatterNdOp : public XlaOpKernel {
updates_shape));
xla::XlaBuilder* builder = context->builder();
- auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
- buffer_shape.dim_sizes());
+ auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
+ buffer_shape.dim_sizes());
auto indices = context->Input(0);
auto updates = context->Input(1);
auto result =
diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
index 664078ca16..e2ac7da2c2 100644
--- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc
@@ -14,20 +14,30 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
-class UnsortedSegmentSum : public XlaOpKernel {
+class UnsortedSegmentReduce : public XlaOpKernel {
public:
- explicit UnsortedSegmentSum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ DataType dtype;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_));
}
+ // The initial value to initialize elements of the output to.
+ virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
+
+ // A function to combine two scalars with the same index (e.g., sum).
+ virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
+
void Compile(XlaOpKernelContext* ctx) override {
// output = unsorted_segment_sum(data, indices, num_segments)
// Compute a tensor such that:
@@ -50,28 +60,28 @@ class UnsortedSegmentSum : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
- errors::InvalidArgument(
- "UnsortedSegmentSum requires that indices' rank be"
- " less than or equal to data's rank."));
+ errors::InvalidArgument(type_string(),
+ " requires that indices' rank be"
+ " less than or equal to data's rank."));
// Validate that indices.shape is a prefix of data.shape.
for (int d = 0; d < indices_shape.dims(); ++d) {
- OP_REQUIRES(ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
- errors::InvalidArgument(
- "UnsortedSegmentSum requires indices shape to be prefix"
- " of data_shape, but dimension ",
- d, " differs ", data_shape.dim_size(d), " vs. ",
- indices_shape.dim_size(d)));
+ OP_REQUIRES(
+ ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
+ errors::InvalidArgument(type_string(),
+ " requires indices shape to be prefix"
+ " of data_shape, but dimension ",
+ d, " differs ", data_shape.dim_size(d),
+ " vs. ", indices_shape.dim_size(d)));
}
xla::XlaBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
- auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
- buffer_shape.dim_sizes());
+ auto buffer =
+ xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
- auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) {
- return builder->Add(a, b);
- };
+ auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
+ xla::XlaBuilder* builder) { return Combine(a, b); };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
@@ -79,13 +89,73 @@ class UnsortedSegmentSum : public XlaOpKernel {
ctx->SetOutput(0, result.ValueOrDie());
}
- private:
- DataType dtype_;
+ protected:
+ xla::PrimitiveType type_;
+};
+
+class UnsortedSegmentSum : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentSum(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return xla::Zero(builder, type_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; };
};
REGISTER_XLA_OP(
Name("UnsortedSegmentSum").CompileTimeConstInput("num_segments"),
UnsortedSegmentSum);
+class UnsortedSegmentProd : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentProd(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return xla::One(builder, type_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentProd").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentProd);
+
+class UnsortedSegmentMin : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentMin(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return xla::MaxFiniteValue(builder, type_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
+ return xla::Min(a, b);
+ };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentMin").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentMin);
+
+class UnsortedSegmentMax : public UnsortedSegmentReduce {
+ public:
+ explicit UnsortedSegmentMax(OpKernelConstruction* ctx)
+ : UnsortedSegmentReduce(ctx) {}
+
+ xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
+ return xla::MinFiniteValue(builder, type_);
+ };
+ xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
+ return xla::Max(a, b);
+ };
+};
+
+REGISTER_XLA_OP(
+ Name("UnsortedSegmentMax").CompileTimeConstInput("num_segments"),
+ UnsortedSegmentMax);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
index f9f48164d6..5c010c9df2 100644
--- a/tensorflow/compiler/tf2xla/kernels/select_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -40,8 +41,6 @@ class SelectOp : public XlaOpKernel {
"'then' and 'else' must have the same size. but received: ",
then_shape.DebugString(), " vs. ", else_shape.DebugString()));
- xla::XlaBuilder* builder = ctx->builder();
-
auto cond_handle = ctx->Input(0);
auto then_handle = ctx->Input(1);
auto else_handle = ctx->Input(2);
@@ -69,14 +68,14 @@ class SelectOp : public XlaOpKernel {
const auto dim_sizes = then_shape.dim_sizes();
gtl::ArraySlice<int64> bdims = dim_sizes;
bdims.pop_front();
- cond_handle = builder->Broadcast(cond_handle, bdims);
+ cond_handle = xla::Broadcast(cond_handle, bdims);
std::vector<int64> dim_order(then_shape.dims());
dim_order[0] = then_shape.dims() - 1;
std::iota(dim_order.begin() + 1, dim_order.end(), 0);
- cond_handle = builder->Transpose(cond_handle, dim_order);
+ cond_handle = xla::Transpose(cond_handle, dim_order);
}
- ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle));
+ ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index 9ce01d0d44..6281d6c653 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -45,7 +45,7 @@ void SendOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
- ctx->builder()->Send(ctx->Input(0), channel);
+ xla::Send(ctx->Input(0), channel);
}
REGISTER_XLA_OP(Name("XlaSend"), SendOp);
@@ -76,7 +76,7 @@ void RecvOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
- ctx->SetOutput(0, ctx->builder()->Recv(shape_, channel));
+ ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel));
}
REGISTER_XLA_OP(Name("XlaRecv"), RecvOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 2c31f8d908..25a5bcbe1d 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -55,9 +55,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
// The type-specific part of the implementation of Range.
template <typename T>
-Status CreateRangeTensor(const xla::Literal& start_literal,
- const xla::Literal& limit_literal,
- const xla::Literal& delta_literal, Tensor* output) {
+Status CreateRangeTensor(const xla::LiteralSlice& start_literal,
+ const xla::LiteralSlice& limit_literal,
+ const xla::LiteralSlice& delta_literal,
+ Tensor* output) {
T start = start_literal.Get<T>({});
T limit = limit_literal.Get<T>({});
T delta = delta_literal.Get<T>({});
@@ -67,13 +68,13 @@ Status CreateRangeTensor(const xla::Literal& start_literal,
}
if (delta > 0) {
if (start > limit) {
- return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
- start, "/", limit);
+ return errors::InvalidArgument(
+ "Requires start <= limit when delta > 0: ", start, "/", limit);
}
} else {
if (start < limit) {
- return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
- start, "/", limit);
+ return errors::InvalidArgument(
+ "Requires start >= limit when delta < 0: ", start, "/", limit);
}
}
int64 size =
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 05354bca5b..5798823cd5 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -43,7 +44,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape"), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -65,7 +66,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -81,7 +82,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank"), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -100,7 +101,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size"), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
@@ -147,7 +148,7 @@ class ExpandDimsOp : public XlaOpKernel {
dim = std::min<int32>(dim, existing_dims_size);
new_shape.emplace(new_shape.begin() + dim, 1);
- ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
}
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp);
@@ -189,10 +190,9 @@ class SqueezeOp : public XlaOpKernel {
if (!wrapped_squeeze_dims.empty()) {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
- errors::InvalidArgument("Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ",
- existing_dim));
+ errors::InvalidArgument(
+ "Tried to explicitly squeeze dimension ", i,
+ " but dimension was not 1: ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);
@@ -205,7 +205,7 @@ class SqueezeOp : public XlaOpKernel {
}
}
- ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
}
private:
@@ -222,7 +222,7 @@ class ZerosLikeOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
}
};
@@ -236,7 +236,7 @@ class OnesLikeOp : public XlaOpKernel {
const TensorShape input_shape = ctx->InputShape(0);
auto one = XlaHelpers::One(ctx->builder(), input_type(0));
- ctx->SetOutput(0, ctx->builder()->Broadcast(one, input_shape.dim_sizes()));
+ ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
index be1e97bf26..1864584ade 100644
--- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -92,8 +93,7 @@ class SliceOp : public XlaOpKernel {
limits.push_back(begin[i] + size[i]);
}
std::vector<int64> strides(begin.size(), 1);
- ctx->SetOutput(
- 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides));
+ ctx->SetOutput(0, xla::Slice(ctx->Input(0), begin, limits, strides));
} else {
// `begin` is not a compile-time constant.
for (int i = 0; i < input_dims; ++i) {
@@ -106,8 +106,7 @@ class SliceOp : public XlaOpKernel {
input_shape.dim_size(i), "], but ",
"got ", size[i]));
}
- ctx->SetOutput(
- 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size));
+ ctx->SetOutput(0, xla::DynamicSlice(ctx->Input(0), ctx->Input(1), size));
}
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index bbf5ee8b12..a71fbcd901 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,9 +15,12 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -41,6 +44,7 @@ class SoftmaxOp : public XlaOpKernel {
const int kClassDim = 1;
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
xla::XlaBuilder* const b = ctx->builder();
@@ -48,24 +52,27 @@ class SoftmaxOp : public XlaOpKernel {
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
- b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension.
- auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
- auto exp_shifted = b->Exp(shifted_logits);
+ auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
+ auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
+ xla::PrimitiveType xla_accumulation_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type,
+ &xla_accumulation_type));
auto converted =
- XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type);
+ xla::ConvertElementType(exp_shifted, xla_accumulation_type);
auto reduce =
- b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
auto softmax =
log_
// softmax = shifted_logits - log(sum(exp(shifted_logits)))
- ? b->Sub(shifted_logits, b->Log(sum), {kBatchDim})
+ ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim})
// softmax = exp(shifted_logits) / sum(exp(shifted_logits))
- : b->Div(exp_shifted, sum, {kBatchDim});
+ : xla::Div(exp_shifted, sum, {kBatchDim});
ctx->SetOutput(0, softmax);
}
@@ -77,8 +84,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);
std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
- XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
- const xla::XlaOp& labels) {
+ XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type,
+ xla::XlaOp logits, xla::XlaOp labels) {
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
const int kBatchDim = 0;
@@ -87,43 +94,44 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
- b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b.
// Broadcasts along the batch dimension.
- auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+ auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
// exp(logits - max_logits)
- auto exp_shifted_logits = b->Exp(shifted_logits);
+ auto exp_shifted_logits = xla::Exp(shifted_logits);
// sum_{class} (exp(logits - max_logits))
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type);
- auto reduce = b->Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ auto reduce =
+ xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type);
// log(sum(exp(logits - max_logits)))
- auto log_sum_exp = b->Log(sum_exp);
+ auto log_sum_exp = xla::Log(sum_exp);
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
// (The subtraction broadcasts along the batch dimension.)
- auto sub = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
- auto mul = b->Mul(b->Neg(labels), sub);
+ auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim});
+ auto mul = xla::Mul(xla::Neg(labels), sub);
auto sum =
- b->Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type),
+ XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto loss = XlaHelpers::ConvertElementType(b, sum, type);
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
// (where the division broadcasts along the batch dimension)
xla::XlaOp backprop =
- b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+ xla::Sub(xla::Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
return {loss, backprop};
}
@@ -146,12 +154,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
// check that "labels" is a matrix too.
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, type, logits, labels);
+ CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
@@ -187,8 +196,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
int64 batch_size = logits_shape.dim_size(0);
int64 depth = logits_shape.dim_size(1);
- DataType logits_type = input_type(0);
- DataType indices_type = input_type(1);
+ const DataType logits_type = input_type(0);
+ const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0);
+ const DataType indices_type = input_type(1);
xla::XlaOp indices = ctx->Input(1);
@@ -206,20 +216,18 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
// Builds a vector of {batch_size} that is 0 if the index is in range, or
// NaN otherwise; then add that vector to the labels to force out-of-range
// values to NaNs.
- xla::XlaOp nan_or_zero = builder->Select(
- builder->And(
- builder->Le(XlaHelpers::Zero(builder, indices_type), indices),
- builder->Lt(indices, XlaHelpers::IntegerLiteral(
- builder, indices_type, depth))),
- builder->Broadcast(XlaHelpers::Zero(builder, logits_type),
- {batch_size}),
- builder->Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
- {batch_size}));
- labels = builder->Add(labels, nan_or_zero, {0});
+ xla::XlaOp nan_or_zero = xla::Select(
+ xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices),
+ xla::Lt(indices, XlaHelpers::IntegerLiteral(
+ builder, indices_type, depth))),
+ xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}),
+ xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
+ {batch_size}));
+ labels = xla::Add(labels, nan_or_zero, {0});
xla::XlaOp loss, backprop;
- std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
+ std::tie(loss, backprop) = CrossEntropyWithLogits(
+ ctx, logits_type, xla_logits_type, ctx->Input(0), labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
new file mode 100644
index 0000000000..faaf8964ff
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaSortOp : public XlaOpKernel {
+ public:
+ explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ context->SetOutput(0, xla::Sort(context->Input(0)));
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
index ec077924b5..8a8525efa1 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -73,7 +74,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
"The product of the block dimensions must be positive"));
xla::XlaOp padded =
- b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
+ xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
// 2. Reshape `padded` to `reshaped_padded` of shape:
//
@@ -100,7 +101,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
reshaped_padded_shape.begin() + 1 + 2 * block_rank);
- xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape);
+ xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);
// 3. Permute dimensions of `reshaped_padded` to produce
// `permuted_reshaped_padded` of shape:
@@ -120,7 +121,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
1 + block_rank * 2);
xla::XlaOp permuted_reshaped_padded =
- b->Transpose(reshaped_padded, permutation);
+ xla::Transpose(reshaped_padded, permutation);
// 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -140,7 +141,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
std::copy(remainder_shape.begin(), remainder_shape.end(),
output_shape.begin() + 1 + block_rank);
- xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 4c5886ee2a..47d282fe9e 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -50,7 +51,6 @@ class SpaceToDepthOp : public XlaOpKernel {
const gtl::InlinedVector<int64, 4> input_shape =
input_tensor_shape.dim_sizes();
- xla::XlaBuilder* b = ctx->builder();
xla::XlaOp input = ctx->Input(0);
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
@@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[1] / block_size_, block_size_,
// input_shape[2] / block_size_, block_size_,
// depth]
- xla::XlaOp reshaped = b->Reshape(input, reshaped_shape);
+ xla::XlaOp reshaped = xla::Reshape(input, reshaped_shape);
// 2. Permute dimensions of `reshaped` to produce
// `permuted_reshaped` of shape:
@@ -145,7 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_, block_size_,
// depth]
- xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order);
+ xla::XlaOp permuted_reshaped = xla::Transpose(reshaped, transpose_order);
// 3. Reshape `permuted_reshaped` to flatten `block_shape` into the
// batch dimension, producing an output tensor of shape:
@@ -155,7 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel {
// input_shape[2] / block_size_,
// block_size_ * block_size_ * depth]
//
- xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape);
+ xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
new file mode 100644
index 0000000000..e831dc30a9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc
@@ -0,0 +1,88 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+
+namespace tensorflow {
+namespace {
+
+// Operator to convert sparse representations to dense.
+class SparseToDenseOp : public XlaOpKernel {
+ public:
+ explicit SparseToDenseOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ // sparse_indices
+ const TensorShape indices_shape = context->InputShape(0);
+ OP_REQUIRES(context, indices_shape.dims() <= 2,
+ errors::InvalidArgument(
+ "sparse_indices should be a scalar, vector, or matrix, "
+ "got shape ",
+ indices_shape.DebugString()));
+ const int64 num_elems =
+ indices_shape.dims() > 0 ? indices_shape.dim_size(0) : 1;
+ const int64 num_dims =
+ indices_shape.dims() > 1 ? indices_shape.dim_size(1) : 1;
+
+ // output_shape
+ TensorShape output_shape;
+ OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
+ OP_REQUIRES(context, output_shape.dims() == num_dims,
+ errors::InvalidArgument(
+ "output_shape has incorrect number of elements: ",
+ output_shape.num_elements(), " should be: ", num_dims));
+
+ // sparse_values
+ const TensorShape sparse_values_shape = context->InputShape(2);
+ const int64 num_values = sparse_values_shape.num_elements();
+ OP_REQUIRES(
+ context,
+ sparse_values_shape.dims() == 0 ||
+ (sparse_values_shape.dims() == 1 && num_values == num_elems),
+ errors::InvalidArgument("sparse_values has incorrect shape ",
+ sparse_values_shape.DebugString(),
+ ", should be [] or [", num_elems, "]"));
+
+ // default_value
+ const TensorShape default_value_shape = context->InputShape(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(default_value_shape),
+ errors::InvalidArgument("default_value should be a scalar."));
+
+ xla::XlaOp indices = context->Input(0);
+ xla::XlaOp sparse_values = context->Input(2);
+ xla::XlaOp default_value = context->Input(3);
+
+ if (sparse_values_shape.dims() == 0 && num_elems != 1) {
+ sparse_values = Broadcast(sparse_values, {num_elems});
+ }
+ xla::XlaBuilder* builder = context->builder();
+ auto buffer = Broadcast(default_value, output_shape.dim_sizes());
+
+ auto result = XlaScatter(buffer, sparse_values, indices,
+ /*indices_are_vectors=*/num_dims > 1,
+ /*combiner=*/{}, builder);
+ context->SetOutput(0, builder->ReportErrorOrReturn(result));
+ }
+};
+
+REGISTER_XLA_OP(Name("SparseToDense").CompileTimeConstInput("output_shape"),
+ SparseToDenseOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
index 8958b2e770..242638f981 100644
--- a/tensorflow/compiler/tf2xla/kernels/split_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -19,7 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -98,7 +99,7 @@ class SplitOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
begin[split_dim] = i * slice_size;
limits[split_dim] = (i + 1) * slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
+ ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
}
}
};
@@ -134,7 +135,7 @@ class SplitVOp : public XlaOpKernel {
errors::InvalidArgument(
"Number of ways to split should be > 0, but got ", num_split));
- // check that sizes are correct
+ // Check that sizes are correct.
int total_split_size = 0;
int neg_one_dim = -1;
std::vector<int64> split_sizes_vec(num_split, -1);
@@ -148,7 +149,7 @@ class SplitVOp : public XlaOpKernel {
" number of elements as the output. Got ",
split_size_shape.dims(), "-D and ",
split_size_shape.num_elements(), " elements"));
- // get the dimension of this split
+ // Get the dimension of this split.
xla::Literal split_size_literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
@@ -199,7 +200,7 @@ class SplitVOp : public XlaOpKernel {
// Slice out the ith split from the split dimension.
limits[split_dim] = begin[split_dim] + slice_size;
- ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits, strides));
+ ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
begin[split_dim] = limits[split_dim];
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 0fb05a2be7..df91900570 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
@@ -144,24 +144,25 @@ class StackPushOp : public XlaOpKernel {
// Initializes the Stack, if the element shape was not already known.
OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape));
- xla::XlaOp ta = b->GetTupleElement(resource->value(), 0);
- xla::XlaOp index = b->GetTupleElement(resource->value(), 1);
+ xla::XlaOp ta = xla::GetTupleElement(resource->value(), 0);
+ xla::XlaOp index = xla::GetTupleElement(resource->value(), 1);
xla::XlaOp value = ctx->Input(1);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
- auto update = b->Reshape(value, slice_shape.dim_sizes());
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
- {b->DynamicUpdateSlice(ta, update, start_indices),
- b->Add(index, b->ConstantR0<int32>(1))})));
+ OP_REQUIRES_OK(ctx,
+ resource->SetValue(xla::Tuple(
+ b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+ xla::Add(index, xla::ConstantR0<int32>(b, 1))})));
ctx->SetOutput(0, value);
}
@@ -197,27 +198,27 @@ class StackPopOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape));
xla::XlaOp state = resource->value();
- xla::XlaOp ta = b->GetTupleElement(state, 0);
- xla::XlaOp index = b->GetTupleElement(state, 1);
+ xla::XlaOp ta = xla::GetTupleElement(state, 0);
+ xla::XlaOp index = xla::GetTupleElement(state, 1);
- index = b->Sub(index, b->ConstantR0<int32>(1));
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
+ index = Sub(index, xla::ConstantR0<int32>(b, 1));
+ OP_REQUIRES_OK(ctx, resource->SetValue(xla::Tuple(b, {ta, index})));
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, stack_shape.dims() - 1}}));
auto slice_shape = stack_shape.dim_sizes();
slice_shape[0] = 1LL;
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
- ctx->SetOutput(0, b->Reshape(read, value_shape));
+ ctx->SetOutput(0, xla::Reshape(read, value_shape));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
index a99d4ddc7c..a6f5769e7b 100644
--- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc
@@ -15,11 +15,15 @@ limitations under the License.
#include <cmath>
+#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -32,17 +36,9 @@ namespace {
// Rotates a 32-bit integer 'v' left by 'distance' bits.
xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v,
int distance) {
- return builder->Or(
- builder->ShiftLeft(v, builder->ConstantR0<int>(distance)),
- builder->ShiftRightLogical(v, builder->ConstantR0<int>(32 - distance)));
-}
-
-// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than
-// building XOR out of other bitwise operators.
-xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x,
- const xla::XlaOp& y) {
- return builder->Or(builder->And(x, builder->Not(y)),
- builder->And(builder->Not(x), y));
+ return xla::Or(
+ xla::ShiftLeft(v, xla::ConstantR0<int>(builder, distance)),
+ xla::ShiftRightLogical(v, xla::ConstantR0<int>(builder, 32 - distance)));
}
using ThreeFry2x32State = std::array<xla::XlaOp, 2>;
@@ -58,22 +54,22 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
std::array<xla::XlaOp, 3> ks;
// 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm.
- ks[2] = builder->ConstantR0<int32>(0x1BD11BDA);
+ ks[2] = xla::ConstantR0<int32>(builder, 0x1BD11BDA);
for (int i = 0; i < 2; ++i) {
ks[i] = key[i];
x[i] = input[i];
- ks[2] = BitwiseXor(builder, ks[2], key[i]);
+ ks[2] = xla::Xor(ks[2], key[i]);
}
- x[0] = builder->Add(x[0], ks[0]);
- x[1] = builder->Add(x[1], ks[1]);
+ x[0] = xla::Add(x[0], ks[0]);
+ x[1] = xla::Add(x[1], ks[1]);
// Performs a single round of the Threefry2x32 algorithm, with a rotation
// amount 'rotation'.
auto round = [builder](ThreeFry2x32State v, int rotation) {
- v[0] = builder->Add(v[0], v[1]);
+ v[0] = xla::Add(v[0], v[1]);
v[1] = RotateLeftS32(builder, v[1], rotation);
- v[1] = BitwiseXor(builder, v[0], v[1]);
+ v[1] = xla::Xor(v[0], v[1]);
return v;
};
@@ -83,36 +79,36 @@ ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder,
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[1]);
- x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(1));
+ x[0] = xla::Add(x[0], ks[1]);
+ x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 1));
x = round(x, rotations[4]);
x = round(x, rotations[5]);
x = round(x, rotations[6]);
x = round(x, rotations[7]);
- x[0] = builder->Add(x[0], ks[2]);
- x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(2));
+ x[0] = xla::Add(x[0], ks[2]);
+ x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 2));
x = round(x, rotations[0]);
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[0]);
- x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0<int32>(3));
+ x[0] = xla::Add(x[0], ks[0]);
+ x[1] = xla::Add(xla::Add(x[1], ks[1]), xla::ConstantR0<int32>(builder, 3));
x = round(x, rotations[4]);
x = round(x, rotations[5]);
x = round(x, rotations[6]);
x = round(x, rotations[7]);
- x[0] = builder->Add(x[0], ks[1]);
- x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0<int32>(4));
+ x[0] = xla::Add(x[0], ks[1]);
+ x[1] = xla::Add(xla::Add(x[1], ks[2]), xla::ConstantR0<int32>(builder, 4));
x = round(x, rotations[0]);
x = round(x, rotations[1]);
x = round(x, rotations[2]);
x = round(x, rotations[3]);
- x[0] = builder->Add(x[0], ks[2]);
- x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0<int32>(5));
+ x[0] = xla::Add(x[0], ks[2]);
+ x[1] = xla::Add(xla::Add(x[1], ks[0]), xla::ConstantR0<int32>(builder, 5));
return x;
}
@@ -123,8 +119,8 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
const TensorShape& shape, double minval,
double maxval) {
// Split the seed into two 32-bit scalars to form a key.
- auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {});
- auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {});
+ auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
+ auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
ThreeFry2x32State key = {seed0, seed1};
const int64 size = shape.num_elements();
@@ -133,81 +129,36 @@ xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed,
// Fill the generator inputs with unique counter values.
ThreeFry2x32State inputs;
- TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0]));
- inputs[1] = builder->Add(inputs[0], builder->ConstantR0<int32>(half_size));
+ inputs[0] = xla::Iota(builder, xla::S32, half_size);
+ inputs[1] = xla::Add(inputs[0], xla::ConstantR0<int32>(builder, half_size));
ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key);
if (size_is_odd) {
- outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1});
+ outputs[1] = xla::Slice(outputs[1], {0}, {half_size - 1}, {1});
}
auto bits =
- builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes());
+ xla::Reshape(xla::ConcatInDim(builder, outputs, 0), shape.dim_sizes());
// Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit
// forces the random bits into the mantissa.
constexpr int kFloatBits = 32;
constexpr int kMantissaBits = 23;
- bits = builder->Or(
- builder->ShiftRightLogical(
- bits, builder->ConstantR0<int32>(kFloatBits - kMantissaBits)),
- builder->ConstantR0<int32>(bit_cast<int32>(1.0f)));
- auto floats = builder->BitcastConvertType(bits, xla::F32);
+ bits = xla::Or(
+ xla::ShiftRightLogical(
+ bits, xla::ConstantR0<int32>(builder, kFloatBits - kMantissaBits)),
+ xla::ConstantR0<int32>(builder, bit_cast<int32>(1.0f)));
+ auto floats = xla::BitcastConvertType(bits, xla::F32);
// We have a floating point number in the range [1.0, 2.0).
// Subtract 1.0f to shift to the range [0.0, 1.0)
- floats = builder->Sub(floats, builder->ConstantR0<float>(1.0f));
+ floats = xla::Sub(floats, xla::ConstantR0<float>(builder, 1.0f));
// Multiply and add to shift to the range [minval, maxval).
- floats = builder->Mul(floats, builder->ConstantR0<float>(maxval - minval));
- floats = builder->Add(floats, builder->ConstantR0<float>(minval));
+ floats = xla::Mul(floats, xla::ConstantR0<float>(builder, maxval - minval));
+ floats = xla::Add(floats, xla::ConstantR0<float>(builder, minval));
return floats;
}
-// Approximation for the inverse error function from
-// Giles, M., "Approximating the erfinv function".
-// The approximation has the form:
-// w = -log((1 - x) * (1 + x))
-// if ( w < 5 ) {
-// w = w - 2.5
-// p = sum_{i=1}^n lq[i]*w^i
-// } else {
-// w = sqrt(w) - 3
-// p = sum_{i=1}^n gq[i]*w^i
-// }
-// return p*x
-xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x,
- const TensorShape& shape) {
- constexpr int kDegree = 9;
- constexpr std::array<float, 9> w_less_than_5_constants = {
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- constexpr std::array<float, 9> w_greater_than_5_constants = {
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
-
- auto one = b->ConstantR0<float>(1.0);
- auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x))));
-
- auto lt = b->Lt(w, b->ConstantR0<float>(5.0));
- auto coefficient = [&](int i) {
- return b->Select(
- lt,
- b->Broadcast(b->ConstantR0<float>(w_less_than_5_constants[i]),
- shape.dim_sizes()),
- b->Broadcast(b->ConstantR0<float>(w_greater_than_5_constants[i]),
- shape.dim_sizes()));
- };
- w = b->Select(lt, b->Sub(w, b->ConstantR0<float>(2.5f)),
- b->Sub(b->SqrtF32(w), b->ConstantR0<float>(3.0f)));
- auto p = coefficient(0);
- for (int i = 1; i < kDegree; ++i) {
- p = b->Add(coefficient(i), b->Mul(p, w));
- }
- return b->Mul(p, x);
-}
-
} // namespace
class StatelessRandomUniformOp : public XlaOpKernel {
@@ -259,8 +210,8 @@ class StatelessRandomNormalOp : public XlaOpKernel {
RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0);
// Convert uniform distribution to normal distribution by computing
// sqrt(2) * erfinv(x)
- auto normal = builder->Mul(builder->ConstantR0<float>(std::sqrt(2.0)),
- ErfInvF32(builder, uniform, shape));
+ auto normal =
+ xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
ctx->SetOutput(0, normal);
}
@@ -275,4 +226,35 @@ REGISTER_XLA_OP(Name("StatelessRandomNormal")
.TypeConstraint("Tseed", DT_INT32),
StatelessRandomNormalOp);
+class StatelessTruncatedNormalOp : public XlaOpKernel {
+ public:
+ explicit StatelessTruncatedNormalOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+
+ TensorShape seed_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, seed_shape == TensorShape({2}),
+ errors::InvalidArgument("seed must have shape [2], not ",
+ seed_shape.DebugString()));
+ xla::XlaOp seed = ctx->Input(1);
+ xla::XlaBuilder* b = ctx->builder();
+
+ auto uniform =
+ RandomUniform(b, seed, shape, std::numeric_limits<float>::min(), 1.0);
+ ctx->SetOutput(0, TruncatedNormal(uniform));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(StatelessTruncatedNormalOp);
+};
+
+REGISTER_XLA_OP(Name("StatelessTruncatedNormal")
+ .CompileTimeConstInput("shape")
+ .TypeConstraint("dtype", DT_FLOAT)
+ .TypeConstraint("Tseed", DT_INT32),
+ StatelessTruncatedNormalOp);
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index 55254c746e..c2165ccd86 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -92,12 +93,12 @@ class StridedSliceOp : public XlaOpKernel {
xla::XlaOp slice = ctx->Input(0);
if (!dimensions_to_reverse.empty()) {
- slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
+ slice = xla::Rev(slice, dimensions_to_reverse);
}
- slice = ctx->builder()->Slice(slice, slice_begin, slice_end, slice_strides);
+ slice = xla::Slice(slice, slice_begin, slice_end, slice_strides);
- slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
+ slice = xla::Reshape(slice, final_shape.dim_sizes());
ctx->SetOutput(0, slice);
}
@@ -171,7 +172,7 @@ class StridedSliceGradOp : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(4);
// Undo any new/shrink axes.
- grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
+ grad = xla::Reshape(grad, processing_shape.dim_sizes());
// Pad the input gradients.
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@@ -204,9 +205,9 @@ class StridedSliceGradOp : public XlaOpKernel {
}
}
if (!dimensions_to_reverse.empty()) {
- grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
+ grad = xla::Rev(grad, dimensions_to_reverse);
}
- grad = ctx->builder()->Pad(grad, zero, padding_config);
+ grad = xla::Pad(grad, zero, padding_config);
ctx->SetOutput(0, grad);
}
@@ -306,17 +307,17 @@ class StridedSliceAssignOp : public XlaOpKernel {
}
if (!dimensions_to_reverse.empty()) {
- rhs = ctx->builder()->Rev(rhs, dimensions_to_reverse);
+ rhs = xla::Rev(rhs, dimensions_to_reverse);
}
- rhs = ctx->builder()->Reshape(rhs, slice_dims);
+ rhs = xla::Reshape(rhs, slice_dims);
if (lhs_shape.dims() == 0) {
// TODO(b/38323843): DynamicUpdateSlice crashes on rank 0 inputs. Fix
// and remove this workaround.
lhs = rhs;
} else {
- lhs = ctx->builder()->DynamicUpdateSlice(
- lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
+ lhs = xla::DynamicUpdateSlice(
+ lhs, rhs, xla::ConstantR1<int64>(ctx->builder(), slice_begin));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 9adee78a1f..26326f18b8 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -25,7 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_resource.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
@@ -123,10 +124,9 @@ xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand,
const xla::XlaOp& update,
const gtl::ArraySlice<int64>& update_dims,
const xla::XlaOp& start_indices) {
- xla::XlaOp current =
- builder->DynamicSlice(operand, start_indices, update_dims);
- xla::XlaOp sum = builder->Add(current, update);
- return builder->DynamicUpdateSlice(operand, sum, start_indices);
+ xla::XlaOp current = xla::DynamicSlice(operand, start_indices, update_dims);
+ xla::XlaOp sum = xla::Add(current, update);
+ return xla::DynamicUpdateSlice(operand, sum, start_indices);
}
class TensorArrayOp : public XlaOpKernel {
@@ -162,7 +162,7 @@ class TensorArrayOp : public XlaOpKernel {
ta_shape.AddDim(size);
ta_shape.AppendShape(shape);
xla::XlaOp zero = XlaHelpers::Zero(b, dtype_);
- value = b->Broadcast(zero, ta_shape.dim_sizes());
+ value = xla::Broadcast(zero, ta_shape.dim_sizes());
}
XlaContext& xc = XlaContext::Get(ctx);
@@ -215,12 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
TensorShape slice_shape = elem_shape;
slice_shape.InsertDim(0, 1LL);
- auto update = b->Reshape(value, slice_shape.dim_sizes());
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
xla::XlaOp written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
@@ -259,17 +259,17 @@ class TensorArrayReadOp : public XlaOpKernel {
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, ta_shape.dims() - 1}}));
auto slice_shape = ta_shape.dim_sizes();
slice_shape[0] = 1LL;
- xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape);
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
// Remove the leading '1' dimension.
std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
- ctx->SetOutput(0, b->Reshape(read, value_shape));
+ ctx->SetOutput(0, xla::Reshape(read, value_shape));
}
private:
@@ -326,7 +326,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
for (auto i = 1; i < ta_shape.dims(); i++) {
end[i] = ta_shape.dim_size(i);
}
- ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
+ ctx->SetOutput(0, xla::Slice(ta, begin, end, strides));
return;
}
}
@@ -391,7 +391,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
if (scatter_all_elements_in_order) {
- ta = b->Add(ta, value);
+ ta = xla::Add(ta, value);
} else {
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
@@ -407,13 +407,13 @@ class TensorArrayScatterOp : public XlaOpKernel {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
- auto slice = b->Slice(value, value_starts, value_ends, value_strides);
+ auto slice = xla::Slice(value, value_starts, value_ends, value_strides);
// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
- auto index = b->Slice(indices, {i}, {i + 1}, {1});
+ auto index = xla::Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
- b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
- xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}
@@ -452,7 +452,7 @@ class TensorArrayConcatOp : public XlaOpKernel {
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
shape[0] *= ta_shape.dim_size(0);
- ctx->SetOutput(0, b->Reshape(ta, shape));
+ ctx->SetOutput(0, xla::Reshape(ta, shape));
Tensor lengths(DT_INT64, {ta_dims[0]});
auto lengths_vec = lengths.vec<int64>();
@@ -522,8 +522,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
- ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(xla::Add(
+ ta, xla::Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
index e91075196b..c9e5694262 100644
--- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -93,9 +94,9 @@ class TileOp : public XlaOpKernel {
if (one_dimension_is_broadcasted_without_multiple) {
// Create a constant Zero the size of the output shape to leverage binary
// operation broadcast semantics.
- auto broadcasted_zero = ctx->builder()->Broadcast(
+ auto broadcasted_zero = xla::Broadcast(
XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape);
- ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input));
+ ctx->SetOutput(0, xla::Add(broadcasted_zero, input));
return;
}
@@ -103,7 +104,7 @@ class TileOp : public XlaOpKernel {
// dimension. This prepends the broadcasted dimensions, so an
// input of shape [2,3,1] broadcast with multiples [5,4,3] will
// end up with shape [5,4,3,2,3,1].
- auto broadcasted = ctx->builder()->Broadcast(input, multiples_array);
+ auto broadcasted = xla::Broadcast(input, multiples_array);
// Now flatten and reshape. The broadcasted dimensions are
// paired with the original dimensions so in the above example
// we flatten [0,3,1,4,2,5] then reshape to [10,12,3].
@@ -112,8 +113,7 @@ class TileOp : public XlaOpKernel {
flattened.push_back(i);
flattened.push_back(i + output_shape.size());
}
- xla::XlaOp output =
- ctx->builder()->Reshape(broadcasted, flattened, output_shape);
+ xla::XlaOp output = xla::Reshape(broadcasted, flattened, output_shape);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
new file mode 100644
index 0000000000..1ddcb08c8e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+namespace {
+
+class TopKOp : public XlaOpKernel {
+ public:
+ explicit TopKOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
+ }
+
+ void Compile(XlaOpKernelContext* context) override {
+ int64 k;
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(1, &k));
+ OP_REQUIRES(context, k >= 0,
+ errors::InvalidArgument("Need k >= 0, got ", k));
+ const TensorShape input_shape = context->InputShape(0);
+ OP_REQUIRES(context, input_shape.dims() >= 1,
+ errors::InvalidArgument("input must be >= 1-D, got shape ",
+ input_shape.DebugString()));
+ OP_REQUIRES(
+ context, input_shape.dim_size(input_shape.dims() - 1) >= k,
+ errors::InvalidArgument("input must have at least k columns. Had ",
+ input_shape.dim_size(input_shape.dims() - 1),
+ ", needed ", k));
+
+ OP_REQUIRES(
+ context, input_shape.dims() == 1,
+ errors::Unimplemented("TopK is implemented for 1-D inputs, got shape ",
+ input_shape.DebugString()));
+
+ xla::XlaBuilder* const b = context->builder();
+ if (input_shape.dim_size(0) < k) {
+ k = input_shape.dim_size(0);
+ }
+ const xla::XlaOp input = context->Input(0);
+ xla::XlaOp iota_s32 = xla::Iota(b, xla::S32, input_shape.dim_size(0));
+ xla::XlaOp sort_result = xla::Sort(xla::Neg(input), iota_s32);
+ xla::XlaOp values =
+ xla::Neg(xla::Slice(xla::GetTupleElement(sort_result, 0),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1}));
+ xla::XlaOp indices = xla::Slice(xla::GetTupleElement(sort_result, 1),
+ /*start_indices=*/{0},
+ /*limit_indices=*/{k},
+ /*strides=*/{1});
+ context->SetOutput(0, values);
+ context->SetOutput(1, indices);
+ }
+
+ private:
+ bool sorted_;
+};
+
+REGISTER_XLA_OP(Name("TopKV2").CompileTimeConstInput("k").TypeConstraint(
+ "T", {DT_UINT32, DT_INT32, DT_FLOAT, DT_BFLOAT16}),
+ TopKOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 34caefa050..98df730249 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
@@ -31,7 +33,6 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp handle;
- xla::XlaBuilder* b = ctx->builder();
DataType type = ctx->input_type(1);
TensorShape var_shape;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
@@ -48,7 +49,7 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
- handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
+ handle = handle - ctx->Input(1) * ctx->Input(2);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -56,6 +57,64 @@ REGISTER_XLA_OP(
Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes),
ResourceApplyGradientDescent);
+xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr,
+ xla::XlaOp l1, xla::XlaOp l2,
+ xla::XlaOp grad) {
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp prox_var = var - grad * lr;
+ xla::XlaOp l1_gt_zero = xla::Sign(prox_var) *
+ xla::Max(xla::Abs(prox_var) - lr * l1, zero) /
+ (one + lr * l2);
+ xla::XlaOp l1_le_zero = prox_var / (one + lr * l2);
+ return xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+}
+
+class ResourceApplyProximalGradientDescent : public XlaOpKernel {
+ public:
+ explicit ResourceApplyProximalGradientDescent(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp var;
+ TensorShape var_shape;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+
+ TensorShape alpha_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ TensorShape l1_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ TensorShape l2_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ TensorShape delta_shape = ctx->InputShape(4);
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(delta_shape),
+ errors::InvalidArgument("var and delta do not have the same shape: ",
+ var_shape.DebugString(), " vs ",
+ delta_shape.DebugString()));
+ xla::XlaOp alpha = ctx->Input(1);
+ xla::XlaOp l1 = ctx->Input(2);
+ xla::XlaOp l2 = ctx->Input(3);
+ xla::XlaOp delta = ctx->Input(4);
+ var = ProximalGradientDescentUpdate(var, alpha, l1, l2, delta);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyProximalGradientDescent")
+ .TypeConstraint("T", kFloatTypes),
+ ResourceApplyProximalGradientDescent);
+
class ResourceApplyMomentum : public XlaOpKernel {
public:
explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -63,8 +122,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -97,14 +154,13 @@ class ResourceApplyMomentum : public XlaOpKernel {
xla::XlaOp grad = ctx->Input(3);
xla::XlaOp momentum = ctx->Input(4);
- accum = b->Add(b->Mul(accum, momentum), grad);
+ accum = accum * momentum + grad;
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
- var = b->Sub(
- var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr)));
+ var = var - (grad * lr + accum * momentum * lr);
} else {
- var = b->Sub(var, b->Mul(accum, lr));
+ var = var - accum * lr;
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
@@ -121,8 +177,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
-
DataType type = ctx->input_type(2);
TensorShape var_shape, accum_shape;
@@ -149,10 +203,8 @@ class ResourceApplyAdagrad : public XlaOpKernel {
xla::XlaOp lr = ctx->Input(2);
xla::XlaOp grad = ctx->Input(3);
- accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
- var = b->Sub(
- var, b->Mul(b->Mul(grad, lr),
- b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
+ accum = accum + xla::Square(grad);
+ var = var - grad * lr * xla::Rsqrt(accum);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
@@ -160,6 +212,139 @@ class ResourceApplyAdagrad : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes),
ResourceApplyAdagrad);
+class ResourceApplyProximalAdagrad : public XlaOpKernel {
+ public:
+ explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape;
+ xla::XlaOp var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ TensorShape lr_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ TensorShape l1_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ TensorShape l2_shape = ctx->InputShape(4);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ TensorShape grad_shape = ctx->InputShape(5);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape: ",
+ var_shape.DebugString(), " vs ", grad_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp l1 = ctx->Input(3);
+ xla::XlaOp l2 = ctx->Input(4);
+ xla::XlaOp grad = ctx->Input(5);
+ accum = accum + xla::Square(grad);
+ // Adagrad learning rate.
+ xla::XlaOp adagrad_lr = lr * xla::Rsqrt(accum);
+ var = ProximalGradientDescentUpdate(var, adagrad_lr, l1, l2, grad);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyProximalAdagrad").TypeConstraint("T", kFloatTypes),
+ ResourceApplyProximalAdagrad);
+
+class ResourceApplyAdagradDA : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdagradDA(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, squared_accum_shape;
+ xla::XlaOp var, accum, squared_accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &squared_accum_shape,
+ &squared_accum));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(squared_accum_shape),
+ errors::InvalidArgument(
+ "var and squared accum do not have the same shape",
+ var_shape.DebugString(), " ", squared_accum_shape.DebugString()));
+
+ TensorShape grad_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape l1_shape = ctx->InputShape(5);
+ TensorShape l2_shape = ctx->InputShape(6);
+ TensorShape global_step_shape = ctx->InputShape(7);
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l1_shape),
+ errors::InvalidArgument("l1 is not a scalar: ",
+ l1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(l2_shape),
+ errors::InvalidArgument("l2 is not a scalar: ",
+ l2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(global_step_shape),
+ errors::InvalidArgument("global step is not a scalar: ",
+ global_step_shape.DebugString()));
+
+ xla::XlaOp grad = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp l1 = ctx->Input(5);
+ xla::XlaOp l2 = ctx->Input(6);
+ xla::XlaBuilder* const b = ctx->builder();
+ xla::XlaOp global_step =
+ XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+
+ accum = accum + grad;
+ squared_accum = squared_accum + xla::Square(grad);
+ xla::XlaOp zero = xla::ScalarLike(lr, 0.0);
+ xla::XlaOp denominator = global_step * lr * l2 + xla::Sqrt(squared_accum);
+ xla::XlaOp l1_le_zero = -lr * accum / denominator;
+ xla::XlaOp l1_gt_zero = -lr * xla::Sign(accum) *
+ xla::Max(xla::Abs(accum) - global_step * l1, zero) /
+ denominator;
+
+ var = xla::Select(xla::Gt(l1, zero), l1_gt_zero, l1_le_zero);
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, squared_accum));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdagradDA").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdagradDA);
+
class ResourceApplyAdam : public XlaOpKernel {
public:
explicit ResourceApplyAdam(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -227,17 +412,12 @@ class ResourceApplyAdam : public XlaOpKernel {
// variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon)
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
- xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
- xla::XlaOp alpha =
- b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)),
- b->Sub(one, beta1_power));
- m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1)));
- v = b->Add(v, b->Mul(b->Sub(b->Pow(grad, two), v), b->Sub(one, beta2)));
- var =
- b->Sub(var, b->Div(b->Mul(m, alpha), b->Add(b->Pow(v, half), epsilon)));
+ xla::XlaOp alpha = lr * xla::Sqrt(one - beta2_power) / (one - beta1_power);
+ m = m + (grad - m) * (one - beta1);
+ v = v + (xla::Square(grad) - v) * (one - beta2);
+ var = var - m * alpha / (xla::Sqrt(v) + epsilon);
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
@@ -250,38 +430,112 @@ class ResourceApplyAdam : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes),
ResourceApplyAdam);
-class ResourceApplyRMSProp : public XlaOpKernel {
+class ResourceApplyAdaMax : public XlaOpKernel {
public:
- explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ResourceApplyAdaMax(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaBuilder* b = ctx->builder();
+ TensorShape var_shape, m_shape, v_shape;
+ xla::XlaOp var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
- DataType type = ctx->input_type(3);
+ TensorShape beta1_power_shape = ctx->InputShape(3);
+ TensorShape lr_shape = ctx->InputShape(4);
+ TensorShape beta1_shape = ctx->InputShape(5);
+ TensorShape beta2_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape(8);
- TensorShape var_shape, ms_shape, mom_shape;
- xla::XlaOp var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power_shape),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar : ",
+ lr_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_shape),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2_shape),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(v_shape),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var_shape.DebugString(), " ",
+ v_shape.DebugString()));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
- TensorShape lr_shape = ctx->InputShape(3);
+ xla::XlaOp beta1_power = ctx->Input(3);
+ xla::XlaOp lr = ctx->Input(4);
+ xla::XlaOp beta1 = ctx->Input(5);
+ xla::XlaOp beta2 = ctx->Input(6);
+ xla::XlaOp epsilon = ctx->Input(7);
+ xla::XlaOp grad = ctx->Input(8);
+
+ xla::XlaOp one = xla::ScalarLike(lr, 1.0);
+ m = beta1 * m + (one - beta1) * grad;
+ v = xla::Max(beta2 * v, xla::Abs(grad));
+ var = var - lr / (one - beta1_power) * (m / (v + epsilon));
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, v));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdaMax").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdaMax);
+
+class ResourceApplyRMSProp : public XlaOpKernel {
+ public:
+ explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, ms_shape, mom_shape, mg_shape;
+ xla::XlaOp var, ms, mom, mg;
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("var", dtype_, &var_shape, &var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("mg", dtype_, &mg_shape, &mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput("ms", dtype_, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput("mom", dtype_, &mom_shape, &mom));
+
+ TensorShape lr_shape = ctx->InputShape("lr");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
- TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape rho_shape = ctx->InputShape("rho");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
errors::InvalidArgument("rho is not a scalar: ",
rho_shape.DebugString()));
- TensorShape momentum_shape = ctx->InputShape(5);
+ TensorShape momentum_shape = ctx->InputShape("momentum");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- TensorShape epsilon_shape = ctx->InputShape(6);
+ TensorShape epsilon_shape = ctx->InputShape("epsilon");
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
errors::InvalidArgument("epsilon is not a scalar: ",
epsilon_shape.DebugString()));
- TensorShape grad_shape = ctx->InputShape(7);
+ TensorShape grad_shape = ctx->InputShape("grad");
// var should be the same shape as mom and ms.
OP_REQUIRES(ctx, var_shape.IsSameSize(ms_shape),
@@ -297,11 +551,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::XlaOp lr = ctx->Input(3);
- xla::XlaOp rho = ctx->Input(4);
- xla::XlaOp momentum = ctx->Input(5);
- xla::XlaOp epsilon = ctx->Input(6);
- xla::XlaOp grad = ctx->Input(7);
+ xla::XlaOp lr = ctx->Input("lr");
+ xla::XlaOp rho = ctx->Input("rho");
+ xla::XlaOp momentum = ctx->Input("momentum");
+ xla::XlaOp epsilon = ctx->Input("epsilon");
+ xla::XlaOp grad = ctx->Input("grad");
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
@@ -320,25 +574,46 @@ class ResourceApplyRMSProp : public XlaOpKernel {
// ms <- grad**2 (1 - rho) + ms * rho
//
// Which is the equation listed above.
- xla::XlaOp new_ms = b->Add(
- ms,
- b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms),
- b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho)));
- xla::XlaOp new_mom =
- b->Add(b->Mul(mom, momentum),
- b->Mul(b->Mul(grad, lr),
- b->Pow(b->Add(new_ms, epsilon),
- XlaHelpers::FloatLiteral(b, type, -0.5))));
- xla::XlaOp new_var = b->Sub(var, new_mom);
-
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom));
+ xla::XlaOp one = xla::ScalarLike(ms, 1.0);
+ xla::XlaOp new_ms = xla::Square(grad) * (one - rho) + ms * rho;
+ xla::XlaOp denominator;
+ if (centered_) {
+ mg = grad * (one - rho) + mg * rho;
+ denominator = new_ms - xla::Square(mg) + epsilon;
+ } else {
+ denominator = new_ms + epsilon;
+ }
+ xla::XlaOp new_mom = mom * momentum + grad * lr * xla::Rsqrt(denominator);
+ xla::XlaOp new_var = var - new_mom;
+
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("var", dtype_, new_var));
+ if (centered_) {
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mg", dtype_, mg));
+ }
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("ms", dtype_, new_ms));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable("mom", dtype_, new_mom));
}
+
+ protected:
+ bool centered_ = false;
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes),
ResourceApplyRMSProp);
+class ResourceApplyCenteredRMSProp : public ResourceApplyRMSProp {
+ public:
+ explicit ResourceApplyCenteredRMSProp(OpKernelConstruction* ctx)
+ : ResourceApplyRMSProp(ctx) {
+ centered_ = true;
+ }
+};
+REGISTER_XLA_OP(
+ Name("ResourceApplyCenteredRMSProp").TypeConstraint("T", kFloatTypes),
+ ResourceApplyCenteredRMSProp);
+
void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::XlaBuilder* b = ctx->builder();
@@ -424,21 +699,18 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
xla::XlaOp grad_to_use;
if (has_l2_shrinkage) {
- grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var)));
+ grad_to_use = grad + two * l2_shrinkage * var;
} else {
grad_to_use = grad;
}
- xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two));
- xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power));
- xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power));
- linear = b->Add(
- linear,
- b->Sub(grad_to_use,
- b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var)));
- xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1);
- xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2));
- var = b->Div(b->Sub(linear_clipped, linear), quadratic);
+ xla::XlaOp new_accum = accum + xla::Square(grad_to_use);
+ xla::XlaOp new_accum_lr_pow = xla::Pow(new_accum, -lr_power);
+ xla::XlaOp accum_lr_pow = xla::Pow(accum, -lr_power);
+ linear = linear + grad_to_use - (new_accum_lr_pow - accum_lr_pow) / lr * var;
+ xla::XlaOp linear_clipped = xla::Clamp(-l1, linear, l1);
+ xla::XlaOp quadratic = new_accum_lr_pow / lr + two * l2;
+ var = (linear_clipped - linear) / quadratic;
accum = new_accum;
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype, var));
@@ -478,5 +750,176 @@ class ResourceApplyFtrlV2 : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes),
ResourceApplyFtrlV2);
+class ResourceApplyAdadelta : public XlaOpKernel {
+ public:
+ explicit ResourceApplyAdadelta(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, accum_shape, accum_update_shape;
+ xla::XlaOp var, accum, accum_update;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(1, dtype_, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &accum_update_shape,
+ &accum_update));
+
+ TensorShape lr_shape = ctx->InputShape(3);
+ TensorShape rho_shape = ctx->InputShape(4);
+ TensorShape epsilon_shape = ctx->InputShape(5);
+ TensorShape grad_shape = ctx->InputShape(6);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(rho_shape),
+ errors::InvalidArgument("rho is not a scalar: ",
+ rho_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
+ errors::InvalidArgument(
+ "var and accum do not have the same shape",
+ var_shape.DebugString(), " ", accum_shape.DebugString()));
+
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+
+ xla::XlaOp lr = ctx->Input(3);
+ xla::XlaOp rho = ctx->Input(4);
+ xla::XlaOp epsilon = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp neg_half = XlaHelpers::FloatLiteral(b, dtype_, -0.5);
+ xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5);
+ xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0);
+ xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0);
+
+ accum = rho * accum + (one - rho) * xla::Pow(grad, two);
+ xla::XlaOp update = xla::Pow(accum_update + epsilon, half) *
+ xla::Pow(accum + epsilon, neg_half) * grad;
+ accum_update = rho * accum_update + (one - rho) * xla::Pow(update, two);
+ var = var - update * lr;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, accum));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, dtype_, accum_update));
+ }
+
+ private:
+ DataType dtype_;
+};
+REGISTER_XLA_OP(Name("ResourceApplyAdadelta").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAdadelta);
+
+class ResourceApplySignBase : public XlaOpKernel {
+ public:
+ explicit ResourceApplySignBase(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape var_shape, m_shape;
+ xla::XlaOp var, m;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES(ctx, var_shape.IsSameSize(m_shape),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var_shape.DebugString(), " ",
+ m_shape.DebugString()));
+ TensorShape grad_shape = ctx->InputShape(6);
+ OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
+ errors::InvalidArgument(
+ "var and grad do not have the same shape",
+ var_shape.DebugString(), " ", grad_shape.DebugString()));
+ CheckScalarParams(ctx);
+
+ xla::XlaOp lr = ctx->Input(2);
+ xla::XlaOp alpha = ctx->Input(3);
+ xla::XlaOp sign_decay = ctx->Input(4);
+ xla::XlaOp beta = ctx->Input(5);
+ xla::XlaOp grad = ctx->Input(6);
+
+ m = m * beta + grad * (xla::ScalarLike(beta, 1.0) - beta);
+ xla::XlaOp decay = xla::Sign(grad) * xla::Sign(m) * sign_decay;
+
+ xla::XlaOp grad_scale = ComputeGradientScale(alpha, decay);
+ var = var - lr * grad_scale * grad;
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, var));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, dtype_, m));
+ }
+
+ virtual void CheckScalarParams(XlaOpKernelContext* ctx) {
+ TensorShape lr_shape = ctx->InputShape(2);
+ TensorShape sign_decay_shape = ctx->InputShape(4);
+ TensorShape beta_shape = ctx->InputShape(5);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
+ errors::InvalidArgument("lr is not a scalar: ",
+ lr_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sign_decay_shape),
+ errors::InvalidArgument("sign_decay is not a scalar: ",
+ sign_decay_shape.DebugString()));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta_shape),
+ errors::InvalidArgument("beta is not a scalar: ",
+ beta_shape.DebugString()));
+ }
+
+ virtual xla::XlaOp ComputeGradientScale(xla::XlaOp alpha,
+ xla::XlaOp decay) = 0;
+
+ private:
+ DataType dtype_;
+};
+
+class ResourceApplyAddSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyAddSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape alpha_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return alpha + decay;
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyAddSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyAddSign);
+
+class ResourceApplyPowerSign : public ResourceApplySignBase {
+ public:
+ explicit ResourceApplyPowerSign(OpKernelConstruction* ctx)
+ : ResourceApplySignBase(ctx) {}
+
+ void CheckScalarParams(XlaOpKernelContext* ctx) override {
+ ResourceApplySignBase::CheckScalarParams(ctx);
+ TensorShape logbase_shape = ctx->InputShape(3);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(logbase_shape),
+ errors::InvalidArgument("logbase is not a scalar: ",
+ logbase_shape.DebugString()));
+ }
+
+ xla::XlaOp ComputeGradientScale(xla::XlaOp alpha, xla::XlaOp decay) override {
+ return xla::Exp(alpha * decay);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceApplyPowerSign").TypeConstraint("T", kFloatTypes),
+ ResourceApplyPowerSign);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
index c167642174..6c721c48fe 100644
--- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -32,7 +33,8 @@ namespace {
class TransposeOp : public XlaOpKernel {
public:
- explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit TransposeOp(OpKernelConstruction* ctx, bool conjugate = false)
+ : XlaOpKernel(ctx), conjugate_(conjugate) {}
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
@@ -78,19 +80,37 @@ class TransposeOp : public XlaOpKernel {
errors::InvalidArgument(i, " is missing from 'perm' argument."));
}
+ xla::XlaOp transposed;
// 0-D, 1-D, and identity transposes do nothing.
if (dims <= 1 || is_identity) {
- ctx->SetOutput(0, ctx->Input(0));
- return;
+ transposed = ctx->Input(0);
+ } else {
+ transposed = xla::Transpose(ctx->Input(0), transposed_order);
}
- ctx->SetOutput(0,
- ctx->builder()->Transpose(ctx->Input(0), transposed_order));
+ // Conjugate the transposed result if this is ConjugateTransposeOp.
+ if (conjugate_) {
+ ctx->SetOutput(0, xla::Conj(transposed));
+ } else {
+ ctx->SetOutput(0, transposed);
+ }
}
+
+ private:
+ const bool conjugate_;
+};
+
+class ConjugateTransposeOp : public TransposeOp {
+ public:
+ explicit ConjugateTransposeOp(OpKernelConstruction* ctx)
+ : TransposeOp(ctx, /*conjugate=*/true) {}
};
REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp);
+REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstInput("perm"),
+ ConjugateTransposeOp);
+
// InvertPermutation frequently forms part of the gradient of Transpose.
//
// inv = InvertPermutationOp(T<int32> p) takes a permutation of
@@ -127,7 +147,7 @@ class InvertPermutationOp : public XlaOpKernel {
output[d] = i;
}
- ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
+ ctx->SetOutput(0, xla::ConstantR1<int32>(ctx->builder(), output));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index 71a9fd051b..116a020437 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -16,24 +16,26 @@ limitations under the License.
// Native XLA implementations of simple unary Ops
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace tensorflow {
namespace {
-// A subclass of a TlaUnaryOp must build the lambda computation that
-// describes the scalar->scalar function to apply to each element of
-// the input.
#define XLAJIT_MAKE_UNARY(NAME, COMPUTATION) \
class NAME##Op : public XlaOpKernel { \
public: \
explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
void Compile(XlaOpKernelContext* ctx) { \
xla::XlaBuilder* b = ctx->builder(); \
+ (void)b; \
xla::XlaOp x = ctx->Input(0); \
xla::XlaOp y = COMPUTATION; \
ctx->SetOutput(0, y); \
@@ -41,122 +43,100 @@ namespace {
}; \
REGISTER_XLA_OP(Name(#NAME), NAME##Op);
-XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x));
+XLAJIT_MAKE_UNARY(ComplexAbs, xla::Abs(x));
-XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x)));
+XLAJIT_MAKE_UNARY(Angle, xla::Atan2(xla::Imag(x), xla::Real(x)));
-XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
+XLAJIT_MAKE_UNARY(Conj, xla::Conj(x));
// Return x if x>0, otherwise -x.
-XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
+XLAJIT_MAKE_UNARY(Abs, xla::Abs(x));
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
-XLAJIT_MAKE_UNARY(
- Acos,
- b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
- b->Add(XlaHelpers::One(b, input_type(0)), x))));
+XLAJIT_MAKE_UNARY(Acos,
+ xla::ScalarLike(x, 2.0) *
+ xla::Atan2(xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x),
+ xla::ScalarLike(x, 1.0) + x));
// acosh(x) = log(x + sqrt(x^2 - 1))
// = log(x + sqrt((x+1)*(x-1)))
-XLAJIT_MAKE_UNARY(
- Acosh,
- b->Log(b->Add(x,
- b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))),
- b->Sub(x, XlaHelpers::One(b, input_type(0)))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Acosh,
+ xla::Log(x + xla::Sqrt((x + xla::ScalarLike(x, 1.0)) *
+ (x - xla::ScalarLike(x, 1.0)))));
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
- Asin,
- b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
- b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
- b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
- b->Mul(x, x)),
- XlaHelpers::FloatLiteral(b, input_type(0),
- 0.5))))));
+ Asin, xla::ScalarLike(x, 2.0) *
+ xla::Atan2(x, xla::ScalarLike(x, 1.0) +
+ xla::Sqrt(xla::ScalarLike(x, 1.0) - x * x)));
// asinh(x) = log(x + sqrt(x^2 + 1))
-XLAJIT_MAKE_UNARY(
- Asinh,
- b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
- XlaHelpers::One(b, input_type(0))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
+XLAJIT_MAKE_UNARY(Asinh,
+ xla::Log(x + xla::Sqrt(x * x + xla::ScalarLike(x, 1.0))));
-XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));
+XLAJIT_MAKE_UNARY(Atan, xla::Atan2(x, xla::ScalarLike(x, 1.0)));
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
+XLAJIT_MAKE_UNARY(Atanh, xla::Log((xla::ScalarLike(x, 1.0) + x) /
+ (xla::ScalarLike(x, 1.0) - x)) *
+ xla::ScalarLike(x, 0.5));
+XLAJIT_MAKE_UNARY(Ceil, xla::Ceil(x));
+XLAJIT_MAKE_UNARY(Cos, xla::Cos(x));
+XLAJIT_MAKE_UNARY(Cosh, (xla::Exp(x) + xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
+XLAJIT_MAKE_UNARY(Sin, xla::Sin(x));
+XLAJIT_MAKE_UNARY(Exp, xla::Exp(x));
+
+XLAJIT_MAKE_UNARY(Expm1, xla::Expm1(x));
+
+XLAJIT_MAKE_UNARY(Floor, xla::Floor(x));
+XLAJIT_MAKE_UNARY(IsFinite, xla::IsFinite(x));
XLAJIT_MAKE_UNARY(
- Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
- b->Sub(XlaHelpers::One(b, input_type(0)), x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x));
-XLAJIT_MAKE_UNARY(Cos, b->Cos(x));
-XLAJIT_MAKE_UNARY(Cosh,
- b->Mul(b->Add(b->Exp(x), b->Exp(b->Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Sin, b->Sin(x));
-XLAJIT_MAKE_UNARY(Exp, b->Exp(x));
-
-XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x));
-
-XLAJIT_MAKE_UNARY(Floor, b->Floor(x));
-XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x));
-XLAJIT_MAKE_UNARY(IsInf, b->Eq(b->Abs(x),
- XlaHelpers::FloatLiteral(
- b, input_type(0),
- std::numeric_limits<double>::infinity())));
-XLAJIT_MAKE_UNARY(IsNan, b->Ne(x, x));
+ IsInf,
+ xla::Eq(xla::Abs(x),
+ xla::ScalarLike(x, std::numeric_limits<double>::infinity())));
+XLAJIT_MAKE_UNARY(IsNan, xla::Ne(x, x));
// Return 1/x
-XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x));
-XLAJIT_MAKE_UNARY(Log, b->Log(x));
+XLAJIT_MAKE_UNARY(Inv, xla::ScalarLike(x, 1.0) / x);
+XLAJIT_MAKE_UNARY(Reciprocal, xla::ScalarLike(x, 1.0) / x);
+XLAJIT_MAKE_UNARY(Log, xla::Log(x));
-XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x));
+XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x));
-XLAJIT_MAKE_UNARY(Invert, b->Not(x));
-XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x));
-XLAJIT_MAKE_UNARY(Neg, b->Neg(x));
+XLAJIT_MAKE_UNARY(Invert, xla::Not(x));
+XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x));
+XLAJIT_MAKE_UNARY(Neg, -x);
// Implements Banker's rounding: numbers that are equidistant between two
// integers are rounded towards even.
-static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0);
- auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0);
-
- auto round_val = b->Floor(x);
- auto fraction = b->Sub(x, round_val);
- auto nearest_even_int =
- b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x))));
- auto is_odd = b->Eq(nearest_even_int, one);
- return b->Select(
- b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)),
- b->Add(round_val, one), round_val);
+xla::XlaOp RoundToEven(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ auto one = xla::ScalarLike(x, 1.0);
+ auto two = xla::ScalarLike(x, 2.0);
+
+ auto round_val = xla::Floor(x);
+ auto fraction = x - round_val;
+ auto nearest_even_int = round_val - two * xla::Floor(half * x);
+ auto is_odd = xla::Eq(nearest_even_int, one);
+ return xla::Select(xla::Or(xla::Gt(fraction, half),
+ xla::And(xla::Eq(fraction, half), is_odd)),
+ round_val + one, round_val);
}
-XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x));
-XLAJIT_MAKE_UNARY(Round, Round(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Rint, RoundToEven(x));
+XLAJIT_MAKE_UNARY(Round, RoundToEven(x));
-XLAJIT_MAKE_UNARY(Rsqrt,
- b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5)));
+XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));
// Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2.
-static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype,
- const xla::XlaOp& x) {
- auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5);
- return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x))));
+xla::XlaOp Sigmoid(xla::XlaOp x) {
+ auto half = xla::ScalarLike(x, 0.5);
+ return half + half * xla::Tanh(half * x);
}
-XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(b, input_type(0), x));
+XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
-XLAJIT_MAKE_UNARY(Sign, b->Sign(x));
-XLAJIT_MAKE_UNARY(Sinh,
- b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))),
- XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
+XLAJIT_MAKE_UNARY(Sinh, (xla::Exp(x) - xla::Exp(-x)) * xla::ScalarLike(x, 0.5));
// softplus(x) = log(1 + exp(x))
//
@@ -166,24 +146,48 @@ XLAJIT_MAKE_UNARY(Sinh,
//
// This is equivalent to:
// max(x, 0) + log1p(exp(-abs(x)))
-XLAJIT_MAKE_UNARY(Softplus,
- b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))),
- b->Log1p(b->Exp(b->Neg(b->Abs(x))))));
+XLAJIT_MAKE_UNARY(Softplus, xla::Max(x, xla::ScalarLike(x, 0.0)) +
+ xla::Log1p(xla::Exp(-xla::Abs(x))));
// softsign(x) = x / (abs(x) + 1)
-XLAJIT_MAKE_UNARY(Softsign,
- b->Div(x,
- b->Add(b->Abs(x), XlaHelpers::One(b, input_type(0)))));
-XLAJIT_MAKE_UNARY(Sqrt,
- b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)));
-XLAJIT_MAKE_UNARY(Square, b->Mul(x, x));
-XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x)));
-XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x));
-
-XLAJIT_MAKE_UNARY(Real, b->Real(x));
-XLAJIT_MAKE_UNARY(Imag, b->Imag(x));
+XLAJIT_MAKE_UNARY(Softsign, x / (xla::Abs(x) + xla::ScalarLike(x, 1.0)));
+XLAJIT_MAKE_UNARY(Sqrt, xla::Sqrt(x));
+XLAJIT_MAKE_UNARY(Square, x* x);
+XLAJIT_MAKE_UNARY(Tan, xla::Sin(x) / xla::Cos(x));
+XLAJIT_MAKE_UNARY(Tanh, xla::Tanh(x));
+
+XLAJIT_MAKE_UNARY(Real, xla::Real(x));
+XLAJIT_MAKE_UNARY(Imag, xla::Imag(x));
#undef XLAJIT_MAKE_UNARY
+// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial
+// is used outside of this range.
+class ErfOp : public XlaOpKernel {
+ public:
+ explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp x = ctx->Input(0);
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
+ auto y =
+ xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x));
+ ctx->SetOutput(0, y);
+ }
+};
+REGISTER_XLA_OP(Name("Erf"), ErfOp);
+
+class ErfcOp : public XlaOpKernel {
+ public:
+ explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaOp x = ctx->Input(0);
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
+ auto y =
+ xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x));
+ ctx->SetOutput(0, y);
+ }
+};
+REGISTER_XLA_OP(Name("Erfc"), ErfcOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
index f87586ba57..f951127bb9 100644
--- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -22,7 +22,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -74,10 +75,9 @@ class UnpackOp : public XlaOpKernel {
for (int i = 0; i < num; ++i) {
start_indices[axis] = i;
limit_indices[axis] = i + 1;
- auto slice = ctx->builder()->Slice(input, start_indices, limit_indices,
- strides);
+ auto slice = xla::Slice(input, start_indices, limit_indices, strides);
// Reshape to drop the 'axis' dimension.
- auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes());
+ auto result = xla::Reshape(slice, output_shape.dim_sizes());
ctx->SetOutput(i, result);
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index a163fa0a5b..bb27b5d56f 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -13,18 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
+#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/no_op.h"
namespace tensorflow {
namespace {
@@ -35,12 +33,33 @@ class VarIsInitializedOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
XlaResource* variable;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
- ctx->SetOutput(0,
- ctx->builder()->ConstantR0<bool>(variable->initialized()));
+ ctx->SetOutput(
+ 0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized()));
}
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
+class VariableShapeOp : public XlaOpKernel {
+ public:
+ explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType variable_dtype;
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
+ Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
+ OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
+ ctx->SetConstantOutput(0, shape_constant);
+ }
+
+ private:
+ DataType out_dtype_;
+};
+REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);
+
class ReadVariableOp : public XlaOpKernel {
public:
explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
@@ -77,7 +96,7 @@ class AssignAddVariableOp : public XlaOpKernel {
xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
- handle = ctx->builder()->Add(handle, ctx->Input(1));
+ handle = xla::Add(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -93,7 +112,7 @@ class AssignSubVariableOp : public XlaOpKernel {
xla::XlaOp handle;
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
- handle = ctx->builder()->Sub(handle, ctx->Input(1));
+ handle = xla::Sub(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
@@ -125,29 +144,152 @@ class ResourceGatherOp : public XlaOpKernel {
ctx->SetOutput(0, gather);
}
};
-REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes),
- ResourceGatherOp);
+REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
-class VariableShapeOp : public XlaOpKernel {
+class ResourceScatterOp : public XlaOpKernel {
public:
- explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ explicit ResourceScatterOp(
+ OpKernelConstruction* context, bool indices_are_vectors,
+ std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
+ xla::XlaBuilder*)>
+ combiner)
+ : XlaOpKernel(context),
+ indices_are_vectors_(indices_are_vectors),
+ combiner_(std::move(combiner)) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaBuilder* builder = context->builder();
+
+ DataType dtype = context->input_type(2);
+ TensorShape var_shape;
+ xla::XlaOp var_value;
+ OP_REQUIRES_OK(
+ context, context->ReadVariableInput(0, dtype, &var_shape, &var_value));
+
+ const xla::XlaOp indices = context->Input(1);
+ const xla::XlaOp updates = context->Input(2);
+
+ auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_,
+ combiner_, builder);
+ OP_REQUIRES_OK(context, result.status());
+ OP_REQUIRES_OK(context,
+ context->AssignVariable(0, dtype, result.ValueOrDie()));
}
- void Compile(XlaOpKernelContext* ctx) override {
- DataType variable_dtype;
- TensorShape shape;
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
- Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
- OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
- ctx->SetConstantOutput(0, shape_constant);
+ private:
+ const bool indices_are_vectors_;
+ const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
+ xla::XlaBuilder*)>
+ combiner_;
+};
+
+class ResourceScatterAddOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterAddOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Add(x, y);
}
+};
+REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp);
+
+class ResourceScatterSubOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterSubOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
private:
- DataType out_dtype_;
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Sub(x, y);
+ }
};
+REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp);
+
+class ResourceScatterMulOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterMulOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Mul(x, y);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp);
+
+class ResourceScatterDivOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterDivOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Div(x, y);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp);
+
+class ResourceScatterMinOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterMinOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Min(x, y);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp);
+
+class ResourceScatterMaxOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterMaxOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Max(x, y);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp);
+
+class ResourceScatterUpdateOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterUpdateOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/false,
+ /*combiner=*/{}) {}
+};
+REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp);
+
+class ResourceScatterNdUpdateOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/true,
+ /*combiner=*/{}) {}
+};
+REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp);
+
+class ResourceScatterNdAddOp : public ResourceScatterOp {
+ public:
+ explicit ResourceScatterNdAddOp(OpKernelConstruction* context)
+ : ResourceScatterOp(context, /*indices_are_vectors=*/true,
+ /*combiner=*/Combine) {}
+
+ private:
+ static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
+ xla::XlaBuilder* builder) {
+ return xla::Add(x, y);
+ }
+};
+REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp);
-REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 5467c5d994..9413a30a6c 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -246,7 +246,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
}
- xla::XlaOp init = builder->Tuple(inputs);
+ xla::XlaOp init = xla::Tuple(builder, inputs);
VLOG(1) << "Building while loop";
@@ -255,22 +255,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
{
std::unique_ptr<xla::XlaBuilder> cb =
builder->CreateSubBuilder("cond_wrapper");
- auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
- auto outputs = cb->Call(*cond.computation, {inputs});
- cb->GetTupleElement(outputs, 0);
+ auto inputs = xla::Parameter(cb.get(), 0, cond_input_shape, "inputs");
+ auto outputs = xla::Call(cb.get(), *cond.computation, {inputs});
+ xla::GetTupleElement(outputs, 0);
xla::StatusOr<xla::XlaComputation> result = cb->Build();
OP_REQUIRES_OK(ctx, result.status());
cond_wrapper = std::move(result.ValueOrDie());
}
- xla::XlaOp while_result =
- builder->While(cond_wrapper, *body.computation, init);
+ xla::XlaOp while_result = xla::While(cond_wrapper, *body.computation, init);
// Sets non-variable outputs.
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
ctx->SetOutput(body.input_mapping[i],
- builder->GetTupleElement(while_result, i));
+ xla::GetTupleElement(while_result, i));
}
}
@@ -284,7 +283,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- builder->GetTupleElement(while_result, pos), builder));
+ xla::GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name() << " modified: " << update.modified
diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
index 856475f12c..661505021f 100644
--- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc
+++ b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
+// Legacy flags for the XLA bridge's backend registration modules.
-#include <mutex>
+#include <mutex> // NOLINT
#include <vector>
-#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
+#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -28,33 +28,33 @@ namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
-static EncapsulateSubgraphsPassFlags* flags;
+static BackendRegistrationFlags* flags;
static std::vector<Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
- flags = new EncapsulateSubgraphsPassFlags;
- flags->tf_xla_parallel_checking = false;
+ flags = new BackendRegistrationFlags;
+ flags->tf_enable_prng_ops_gpu = false;
flag_list = new std::vector<Flag>({
- Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking,
- "Debug tool. Runs both JIT-compiled and interpreted graphs in "
- "parallel and verifies they produce the same outputs."),
+ Flag("tf_enable_prng_ops_gpu", &flags->tf_enable_prng_ops_gpu,
+ "Whether to enable PRNG ops: [RandomStandardNormal | RandomUniform "
+ "| RandomUniformInt | TruncatedNormal] on GPU."),
});
xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with the XLA bridge's
-// encapsulate_subgraphs_pass module.
-void AppendEncapsulateSubgraphsPassFlags(std::vector<Flag>* append_to) {
+// backend registration modules.
+void AppendBackendRegistrationFlags(std::vector<Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
-// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
+// Return a pointer to the BackendRegistrationFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
-EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() {
+BackendRegistrationFlags* GetBackendRegistrationFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
index d371bd269d..861c923dd5 100644
--- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h
+++ b/tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
-#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
-// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
+// Legacy flags for the XLA bridge's backend registration modules.
#include <vector>
@@ -27,24 +27,23 @@ namespace tensorflow {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with the XLA bridge's
-// encapsulate_subgraphs_pass module.
-void AppendEncapsulateSubgraphsPassFlags(
- std::vector<tensorflow::Flag>* flag_list);
+// backend registration modules.
+void AppendBackendRegistrationFlags(std::vector<tensorflow::Flag>* append_to);
-// The values of flags associated with the XLA bridge's
-// encapsulate_subgraphs_pass module.
+// The values of flags associated with the XLA bridge's backend registration
+// module.
typedef struct {
- bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and
- // interpreted graphs in parallel and verifies
- // they produce the same outputs.
-} EncapsulateSubgraphsPassFlags;
+ // Whether to enable RandomUniform op on GPU backend.
+ // TODO (b/32333178): Remove this flag or set its default to true.
+ bool tf_enable_prng_ops_gpu;
+} BackendRegistrationFlags;
-// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
+// Return a pointer to the BackendRegistrationFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
-EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags();
+BackendRegistrationFlags* GetBackendRegistrationFlags();
} // namespace legacy_flags
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
+#endif // TENSORFLOW_COMPILER_TF2XLA_LEGACY_FLAGS_BACKEND_REGISTRATION_FLAGS_H_
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index ee7f5d510a..becc8b84fe 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -40,10 +40,48 @@ cc_library(
":triangular_solve",
":util",
":while_loop",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "random",
+ srcs = ["random.cc"],
+ hdrs = ["random.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "qr",
+ srcs = ["qr.cc"],
+ hdrs = ["qr.h"],
+ deps = [
+ ":batch_dot",
+ ":util",
+ ":while_loop",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:math",
+ "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
@@ -57,7 +95,7 @@ cc_library(
deps = [
":util",
":while_loop",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -76,11 +114,12 @@ cc_library(
deps = [
":batch_dot",
":util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/core:lib",
@@ -94,7 +133,7 @@ xla_test(
deps = [
":triangular_solve",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -116,6 +155,7 @@ cc_library(
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -134,7 +174,7 @@ xla_test(
":batch_dot",
":util",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 526694d5a0..3c4eec081b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -25,91 +26,94 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
- xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x,
- bool conjugate_y) {
- TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
- TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
-
- // Check that both tensors have the same number of dimensions. There must be
- // at least two (the batch dimensions can be empty).
- if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
- return errors::InvalidArgument(
- "Arguments to BatchedDot have different ranks: ",
- xla::ShapeUtil::HumanString(x_shape), " vs. ",
- xla::ShapeUtil::HumanString(y_shape));
- }
- const int ndims = xla::ShapeUtil::Rank(x_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to BatchedDot must have rank >= 2: ", ndims);
- }
-
- // The batch dimensions must be equal and the matrix dimensions must be
- // valid.
- std::vector<int64> batch_dimension_numbers;
- for (int i = 0; i < ndims - 2; ++i) {
- if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
+xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
+ bool transpose_y, bool conjugate_x, bool conjugate_y) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y));
+
+ // Check that both tensors have the same number of dimensions. There must be
+ // at least two (the batch dimensions can be empty).
+ if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) {
return errors::InvalidArgument(
- "Dimension ", i, " of inputs to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(x_shape), " vs ",
+ "Arguments to BatchedDot have different ranks: ",
+ xla::ShapeUtil::HumanString(x_shape), " vs. ",
xla::ShapeUtil::HumanString(y_shape));
}
- batch_dimension_numbers.push_back(i);
- }
-
- int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
- int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
- if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
- return errors::InvalidArgument(
- "Dimensions ", x_inner_dim, " and ", y_inner_dim,
- " of arguments to BatchedDot must be equal: ",
- xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
- " vs. ", xla::ShapeUtil::HumanString(y_shape),
- " transpose: ", transpose_y);
- }
-
- // Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(x_shape) ||
- xla::ShapeUtil::HasZeroElements(y_shape)) {
- std::vector<int64> dimensions(batch_dimension_numbers.size());
- for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
- dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
+ const int ndims = xla::ShapeUtil::Rank(x_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to BatchedDot must have rank >= 2: ", ndims);
+ }
+
+ // The batch dimensions must be equal and the matrix dimensions must be
+ // valid.
+ std::vector<int64> batch_dimension_numbers;
+ for (int i = 0; i < ndims - 2; ++i) {
+ if (x_shape.dimensions(i) != y_shape.dimensions(i)) {
+ return errors::InvalidArgument(
+ "Dimension ", i, " of inputs to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(x_shape), " vs ",
+ xla::ShapeUtil::HumanString(y_shape));
+ }
+ batch_dimension_numbers.push_back(i);
+ }
+
+ int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1);
+ int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2);
+ if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) {
+ return errors::InvalidArgument(
+ "Dimensions ", x_inner_dim, " and ", y_inner_dim,
+ " of arguments to BatchedDot must be equal: ",
+ xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x,
+ " vs. ", xla::ShapeUtil::HumanString(y_shape),
+ " transpose: ", transpose_y);
+ }
+
+ // Check for zero lhs/rhs dim size.
+ if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
+ xla::ShapeUtil::IsZeroElementArray(y_shape)) {
+ std::vector<int64> dimensions(batch_dimension_numbers.size());
+ for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
+ dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
+ }
+ int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
+ int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
+ dimensions.push_back(x_shape.dimensions(x_outer_dim));
+ dimensions.push_back(y_shape.dimensions(y_outer_dim));
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder,
+ xla::LiteralUtil::Zero(x_shape.element_type())),
+ dimensions);
+ }
+
+ if (x_shape.element_type() == xla::C64 && conjugate_x) {
+ x = xla::Conj(x);
+ }
+ if (y_shape.element_type() == xla::C64 && conjugate_y) {
+ y = xla::Conj(y);
+ }
+
+ // If there are no batch dimensions, use a regular Dot.
+ // TODO(b/69062148) Remove this code when Dot emitters can be passed
+ // dimensions to transpose directly (i.e. without requiring a Transpose
+ // HLO).
+ if (batch_dimension_numbers.empty()) {
+ auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
+ auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
+ return xla::Dot(lhs, rhs);
+ }
+
+ xla::DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
+ dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
+ for (auto batch_dimension_number : batch_dimension_numbers) {
+ dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
+ dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
}
- int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2);
- int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1);
- dimensions.push_back(x_shape.dimensions(x_outer_dim));
- dimensions.push_back(y_shape.dimensions(y_outer_dim));
- return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())),
- dimensions);
- }
-
- if (x_shape.element_type() == xla::C64 && conjugate_x) {
- x = builder->Conj(x);
- }
- if (y_shape.element_type() == xla::C64 && conjugate_y) {
- y = builder->Conj(y);
- }
-
- // If there are no batch dimensions, use a regular Dot.
- // TODO(b/69062148) Remove this code when Dot emitters can be passed
- // dimensions to transpose directly (i.e. without requiring a Transpose HLO).
- if (batch_dimension_numbers.empty()) {
- auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x;
- auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y;
- return builder->Dot(lhs, rhs);
- }
-
- xla::DotDimensionNumbers dot_dnums;
- dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
- dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
- for (auto batch_dimension_number : batch_dimension_numbers) {
- dot_dnums.add_lhs_batch_dimensions(batch_dimension_number);
- dot_dnums.add_rhs_batch_dimensions(batch_dimension_number);
- }
- return builder->DotGeneral(x, y, dot_dnums);
+ return xla::DotGeneral(x, y, dot_dnums);
+ });
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h
index 1acc72033b..d07a9486f1 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.h
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h
@@ -43,10 +43,9 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
-xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
- xla::XlaOp y, bool transpose_x,
- bool transpose_y, bool conjugate_x = false,
- bool conjugate_y = false);
+xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false,
+ bool transpose_y = false, bool conjugate_x = false,
+ bool conjugate_y = false);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc
index 3f1384bc86..35b137aa2c 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.cc
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc
@@ -22,7 +22,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -47,179 +49,163 @@ namespace {
// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
// l[..., j, j]
// return l
-xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
- const xla::XlaOp& a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- const int n_dims = xla::ShapeUtil::Rank(a_shape);
- const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - 2);
-
- xla::XlaOp l = Zeros(builder, a_shape);
-
- // Construct the for loop body to iterate over rows.
- auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
- xla::XlaBuilder* body_builder)
- -> xla::StatusOr<std::vector<xla::XlaOp>> {
- xla::Shape col_shape;
- xla::Shape row_shape;
- for (int64 d : major_dims) {
- row_shape.add_dimensions(d);
- col_shape.add_dimensions(d);
- }
- row_shape.add_dimensions(1);
- row_shape.add_dimensions(n);
- row_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_row = Zeros(body_builder, row_shape);
-
- col_shape.add_dimensions(n);
- col_shape.add_dimensions(1);
- col_shape.set_element_type(a_shape.element_type());
- auto mask_zeros_col = Zeros(body_builder, col_shape);
-
- std::vector<int32> mask_vector(n);
- std::iota(mask_vector.begin(), mask_vector.end(), 0);
- auto mask_range = body_builder->ConstantR1<int32>(mask_vector);
- auto mask_range_row = body_builder->Broadcast(
- body_builder->Reshape(mask_range, {0}, {1, n}), major_dims);
- auto mask_range_col = body_builder->Broadcast(
- body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims);
- auto body_a = loop_vars[0];
- auto body_l = loop_vars[1];
-
- // row = l[..., i, :i]
- // select the whole i-th row, then mask out all columns past i-1
- auto zero = body_builder->ConstantR0<int32>(0);
- TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l,
- {i, zero}, {1, n}));
- auto row = body_builder->Select(body_builder->Ge(mask_range_row, i),
- mask_zeros_row, l_i);
- // a[..., i, i]
- TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
- {i, i}, {1, 1}));
- // np.dot(row, np.swapaxes(row, -1, -2))
- xla::XlaOp diag_dot;
- TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
- /*transpose_x=*/false,
- /*transpose_y=*/true));
- // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
- // np.swapaxes(row, -1, -2)))
- auto l_ii = body_builder->Pow(
- body_builder->Sub(a_ii, diag_dot),
- FloatLiteral(body_builder, a_shape.element_type(), 0.5));
-
- // a[..., i+1:, i]
- auto ip1 = body_builder->Add(i, body_builder->ConstantR0<int32>(1));
- // select the whole i-th column, then mask out all rows above i+1
+xla::XlaOp CholeskyUnblocked(xla::XlaOp a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int n_dims = xla::ShapeUtil::Rank(a_shape);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - 2);
+
+ xla::XlaOp l = xla::ZerosLike(a);
+
+ // Construct the for loop body to iterate over rows.
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
+ xla::XlaBuilder* body_builder)
+ -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ xla::Shape col_shape;
+ xla::Shape row_shape;
+ for (int64 d : major_dims) {
+ row_shape.add_dimensions(d);
+ col_shape.add_dimensions(d);
+ }
+ row_shape.add_dimensions(1);
+ row_shape.add_dimensions(n);
+ row_shape.set_element_type(a_shape.element_type());
+ auto mask_zeros_row = xla::Zeros(body_builder, row_shape);
+
+ col_shape.add_dimensions(n);
+ col_shape.add_dimensions(1);
+ col_shape.set_element_type(a_shape.element_type());
+ auto mask_zeros_col = xla::Zeros(body_builder, col_shape);
+
+ std::vector<int32> mask_vector(n);
+ std::iota(mask_vector.begin(), mask_vector.end(), 0);
+ auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
+ auto mask_range_row =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
+ auto mask_range_col =
+ xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
+ auto body_a = loop_vars[0];
+ auto body_l = loop_vars[1];
+
+ // row = l[..., i, :i]
+ // select the whole i-th row, then mask out all columns past i-1
+ auto zero = xla::ConstantR0<int32>(body_builder, 0);
+ auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n});
+ auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
+ // a[..., i, i]
+ auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ // np.dot(row, np.swapaxes(row, -1, -2))
+ auto diag_dot = BatchDot(row, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
+ // np.swapaxes(row, -1, -2)))
+ auto l_ii =
+ xla::Pow(a_ii - diag_dot,
+ FloatLiteral(body_builder, a_shape.element_type(), 0.5));
+
+ // a[..., i+1:, i]
+ // select the whole i-th column, then mask out all rows above i+1
+ auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1});
+ auto a_ip1i =
+ xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);
+
+ // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
+ // l[..., i, i]
+ // The columns in [i, n] are zeroed out in `row`, so we just have to
+ // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
+ // r.T)
+ auto dot = BatchDot(body_l, row,
+ /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ // np.dot(l[..., i+1:, :i], r.T)
+ auto dot_ip1 =
+ xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);
+
+ body_l =
+ DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i});
+ // Assign the diagonal after the rest of the column because otherwise the
+ // column assign will wrap around and overwrite the diagonal assign.
+ body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i});
+
+ return std::vector<xla::XlaOp>{body_a, body_l};
+ };
+
TF_ASSIGN_OR_RETURN(
- auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
- auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i),
- mask_zeros_col, a_0i);
-
- // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) /
- // l[..., i, i]
- // The columns in [i, n] are zeroed out in `row`, so we just have to
- // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
- // r.T)
- TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row,
- /*transpose_x=*/false,
- /*transpose_y=*/true));
- // np.dot(l[..., i+1:, :i], r.T)
- auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i),
- mask_zeros_col, dot);
-
- auto col_update =
- body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii);
- TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
- body_builder, body_l, col_update, {i}));
- // Assign the diagonal after the rest of the column because otherwise the
- // column assign will wrap around and overwrite the diagonal assign.
- TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
- body_builder, body_l, l_ii, {i, i}));
-
- return std::vector<xla::XlaOp>{body_a, body_l};
- };
-
- TF_ASSIGN_OR_RETURN(
- auto cholesky_while,
- XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
-
- return cholesky_while[1];
+ auto cholesky_while,
+ XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));
+
+ return cholesky_while[1];
+ });
}
} // namespace
-xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- const int ndims = xla::ShapeUtil::Rank(a_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to Cholesky must have rank >= 2: ", ndims);
- }
-
- const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
- if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
- return errors::InvalidArgument(
- "Arguments to Cholesky must be square matrices: ",
- xla::ShapeUtil::HumanString(a_shape));
- }
-
- if (block_size < 1) {
- return errors::InvalidArgument(
- "block_size argument to Cholesky must be >= 1; got ", block_size);
- }
-
- // Blocked left-looking Cholesky factorization.
- // Algorithm 1 from
- // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
- // execution." Proceedings of General Purpose GPUs. ACM, 2017.
- xla::XlaOp l = Zeros(builder, a_shape);
- for (int64 i = 0; i < n; i += block_size) {
- int64 k = std::min(block_size, n - i);
- if (i > 0) {
- // TODO(phawkins): consider implementing SYRK for the diagonal part of
- // the panel.
- // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
- TF_ASSIGN_OR_RETURN(auto lhs,
- SliceInMinorDims(builder, l, {i, 0}, {n, i}));
- TF_ASSIGN_OR_RETURN(auto rhs,
- SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
- TF_ASSIGN_OR_RETURN(auto delta,
- BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
- /*transpose_y=*/true, /*conjugate_x=*/false,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto before,
- SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
- TF_ASSIGN_OR_RETURN(
- a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta),
- {i, i}));
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must have rank >= 2: ", ndims);
+ }
+
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) {
+ return errors::InvalidArgument(
+ "Arguments to Cholesky must be square matrices: ",
+ xla::ShapeUtil::HumanString(a_shape));
+ }
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to Cholesky must be >= 1; got ", block_size);
}
- // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
- TF_ASSIGN_OR_RETURN(auto x,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x));
- TF_ASSIGN_OR_RETURN(l,
- UpdateSliceInMinorDims(builder, l, factorized, {i, i}));
-
- if (i + k < n) {
- // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
- TF_ASSIGN_OR_RETURN(auto panel,
- SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
- TF_ASSIGN_OR_RETURN(auto update,
- TriangularSolve(builder, factorized, panel,
- /*left_side=*/false,
- /*lower=*/true,
- /*transpose_a=*/true,
- /*conjugate_a=*/false,
- /*block_size=*/block_size));
- TF_ASSIGN_OR_RETURN(
- l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));
+ // Blocked left-looking Cholesky factorization.
+ // Algorithm 1 from
+ // Haidar, Azzam, et al. "High-performance Cholesky factorization for
+ // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
+ xla::XlaOp l = xla::ZerosLike(a);
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+ if (i > 0) {
+ // TODO(phawkins): consider implementing SYRK for the diagonal part of
+ // the panel.
+ // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
+ auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
+ auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
+ auto delta = BatchDot(lhs, rhs, /*transpose_x=*/false,
+ /*transpose_y=*/true);
+ auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
+ a = UpdateSliceInMinorDims(a, before - delta, {i, i});
+ }
+
+ // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
+ auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto factorized = CholeskyUnblocked(x);
+ l = UpdateSliceInMinorDims(l, factorized, {i, i});
+
+ if (i + k < n) {
+ // l[i+k:, i:i+k] =
+ // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
+ auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
+ auto update = TriangularSolve(factorized, panel,
+ /*left_side=*/false,
+ /*lower=*/true,
+ /*transpose_a=*/true,
+ /*conjugate_a=*/false,
+ /*block_size=*/block_size);
+ l = UpdateSliceInMinorDims(l, update, {i + k, i});
+ }
}
- }
- return l;
+ return l;
+ });
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h
index 20fca7969e..0f6e0e9d15 100644
--- a/tensorflow/compiler/tf2xla/lib/cholesky.h
+++ b/tensorflow/compiler/tf2xla/lib/cholesky.h
@@ -30,8 +30,7 @@ namespace tensorflow {
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
-xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
- int64 block_size = 256);
+xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc
new file mode 100644
index 0000000000..9c8ac7af25
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/qr.cc
@@ -0,0 +1,387 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/qr.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
+#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Computes a Householder reflection of the form:
+// H = I - tau v v.T.
+// such that
+// H . ( x1 ) = ( x1 )
+// ( x2 ) = ( x2 )
+// ( ... ) = ( ... )
+// ( xk ) = ( beta )
+// ( ... ) ( 0 )
+// ( ... ) ( 0 )
+// Unlike the usual formulation, we allow the caller to supply 'k' rather than
+// only providing the relevant part of 'x' to maintain XLA's static shape
+// invariant. In addition, the implementation supports batching.
+// Pseudo-code, without batching:
+// alpha = x[k]
+// x_copy = np.copy(x)
+// x_copy[:k+1] = 0
+// xnorm = norm2(x_copy)
+// if xnorm == 0:
+// beta = alpha
+// tau = 0
+// v = np.zeros_like(x)
+// else:
+// beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
+// tau = (beta - alpha) / beta
+// v = x / (alpha - beta)
+// v[k] = 1
+// return (v, tau, beta)
+// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
+// overflows in the norm/beta calculations. Perhaps do the same here.
+xla::Status House(xla::XlaOp x, xla::XlaOp k, gtl::ArraySlice<int64> batch_dims,
+ const int64 m, xla::XlaOp* v, xla::XlaOp* tau,
+ xla::XlaOp* beta) {
+ xla::XlaBuilder* const builder = x.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
+ const xla::PrimitiveType type = x_shape.element_type();
+
+ std::vector<int64> batch_dim_ids(batch_dims.size());
+ std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
+ const int64 minor_dim = batch_dims.size();
+
+ xla::XlaOp zero = xla::ScalarLike(x, 0.0);
+ xla::XlaOp one = xla::ScalarLike(x, 1.0);
+
+ // alpha = x[k]
+ xla::XlaOp alpha =
+ xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+
+ // Compute x[k+1:] (padded with zeros in elements 0..k)
+ xla::XlaOp iota = xla::Iota(builder, xla::S32, m);
+ xla::XlaOp x_after_k =
+ xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type),
+ /*broadcast_dimensions=*/{minor_dim});
+
+ // sigma = np.dot(x[k+1:], x[k+1:])
+ auto sigma =
+ xla::Reduce(x_after_k * x_after_k, zero,
+ xla::CreateScalarAddComputation(type, builder), {minor_dim});
+ // mu = np.sqrt(x[k]*x[k] + sigma)
+ auto mu = xla::Sqrt(xla::Square(alpha) + sigma);
+
+ auto sigma_is_zero = xla::Eq(sigma, zero);
+
+ *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu);
+ *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims),
+ (*beta - alpha) / *beta);
+ auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims),
+ alpha - *beta);
+
+ auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type),
+ std::vector<int64>(batch_dims.size(), 1));
+
+ // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
+ // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
+ *v = e_k +
+ xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+ return Status::OK();
+}
+
+// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
+// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
+// used as an inner routine of the blocked implementation.
+// Algorithm is adapted slightly so the shapes inside the loop are static, at
+// the cost of some redundant computation. Since this is used as an inner block
+// kernel, accumulates the Householder transformations (vs, taus) rather than
+// the matrix q.
+// Equivalent Python code, without batching:
+// def qr(a):
+// m = a.shape[0]
+// n = a.shape[1]
+// vs = np.zeros([m, n])
+// taus = np.zeros([n])
+// for j in xrange(min(m, n)):
+// v, tau, beta = house(a[:, j], j)
+// # Unusually, we apply the Householder transformation to the entirety of
+// # a, wasting FLOPs to maintain the static shape invariant that XLA
+// # requires. For columns that precede j this has no effect.
+// a[:, :] -= tau * np.dot(v[:, np.newaxis],
+// np.dot(v[np.newaxis, :], a[:, :]))
+// # Form column j explicitly rather than relying on the precision of the
+// # Householder update.
+// a[j, j] = beta
+// a[j+1:, j] = np.zeros([m - j - 1], dtype=a.dtype)
+// vs[:, j] = v
+// taus[j] = tau
+// return (q, vs, taus)
+struct QRBlockResult {
+ // The factored R value
+ xla::XlaOp r;
+
+ // Representation of the Householder matrices I - beta v v.T
+ xla::XlaOp taus; // Shape: [..., n]
+ xla::XlaOp vs; // Shape: [..., m, n]
+};
+xla::StatusOr<QRBlockResult> QRBlock(xla::XlaOp a) {
+ xla::XlaBuilder* builder = a.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int num_dims = xla::ShapeUtil::Rank(a_shape);
+ if (num_dims < 2) {
+ return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
+ num_dims);
+ }
+ xla::PrimitiveType type = a_shape.element_type();
+
+ const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+
+ const int64 num_batch_dims = num_dims - 2;
+ std::vector<int64> batch_dims(num_batch_dims);
+ for (int i = 0; i < num_batch_dims; ++i) {
+ batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+ }
+
+ std::vector<int64> batch_dim_indices(num_batch_dims);
+ std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+
+ auto qr_body_fn =
+ [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto a = values[0];
+ auto vs = values[1];
+ auto taus = values[2];
+
+ // v, beta = house(a[:, j], j)
+ auto x = DynamicSliceInMinorDims(a, {j}, {1});
+ xla::XlaOp v, tau, beta;
+ TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j,
+ batch_dims, m, &v, &tau, &beta));
+
+ std::vector<int64> shape = batch_dims;
+ shape.push_back(1);
+ shape.push_back(m);
+ auto v_broadcast = xla::Reshape(v, shape);
+ // a[:, :] -= tau * np.dot(v[:, np.newaxis],
+ // np.dot(v[np.newaxis, :], a[:, :]))
+ auto vva = BatchDot(v_broadcast, a);
+ vva = BatchDot(v_broadcast, vva, /*transpose_x=*/true);
+ a = a - xla::Mul(tau, vva,
+ /*broadcast_dimensions=*/batch_dim_indices);
+
+ // It is more precise to populate column 'k' explicitly, rather than
+ // computing it implicitly by applying the Householder transformation.
+ // a[k,k] = beta
+ // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
+ auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1});
+ auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type);
+ auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type),
+ std::vector<int64>(batch_dims.size(), 1));
+ auto new_x =
+ xla::Mul(x, predecessor_mask,
+ /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
+ xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+ a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
+
+ // vs[:, j] = v
+ vs = DynamicUpdateSliceInMinorDims(
+ vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
+ // taus[j] = tau
+ taus = DynamicUpdateSliceInMinorDims(
+ taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
+ return std::vector<xla::XlaOp>{a, vs, taus};
+ };
+
+ auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
+ type, ConcatVectors(batch_dims, {m, n})));
+ auto taus = xla::Zeros(
+ builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
+
+ TF_ASSIGN_OR_RETURN(auto values,
+ XlaForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
+ {a, vs, taus}, "qr", builder));
+
+ QRBlockResult result;
+ result.r = values[0];
+ result.vs = values[1];
+ result.taus = values[2];
+ return result;
+}
+
+// Computes W and Y such that I-WY is equivalent to the sequence of Householder
+// transformations given by vs and taus.
+// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
+// Y = np.zeros([m, n])
+// W = np.zeros([m, n])
+// Y[:, 0] = vs[:, 0]
+// W[:, 0] = -taus[0] * vs[:, 0]
+// for j in xrange(1, n):
+// v = vs[:, j]
+// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
+// W[:, j] = z
+// Y[:, j] = v
+// return W
+// There is no need to return Y since at termination of the loop it is equal to
+// vs.
+xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
+ xla::PrimitiveType type, gtl::ArraySlice<int64> batch_dims, xla::XlaOp vs,
+ xla::XlaOp taus, int64 m, int64 n) {
+ std::vector<int64> batch_dim_indices(batch_dims.size());
+ std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+ int64 n_index = batch_dims.size() + 1;
+
+ auto body_fn =
+ [&](xla::XlaOp j, gtl::ArraySlice<xla::XlaOp> values,
+ xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+ auto w = values[0];
+ auto y = values[1];
+ const auto vs = values[2];
+ const auto taus = values[3];
+
+ // Want j values in range [1, ... n).
+ j = j + xla::ConstantR0<int32>(builder, 1);
+ // vs has shape [..., m, 1]
+ auto v = DynamicSliceInMinorDims(vs, {j}, {1});
+ // beta has shape [..., 1]
+ auto beta = DynamicSliceInMinorDims(taus, {j}, {1});
+
+ // yv has shape [..., n, 1]
+ auto yv = BatchDot(y, v, /*transpose_x=*/true);
+ // wyv has shape [..., m, 1]
+ auto wyv = BatchDot(w, yv);
+
+ auto z = xla::Mul(
+ -beta, v + wyv,
+ /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+
+ w = DynamicUpdateSliceInMinorDims(w, z, {j});
+ y = DynamicUpdateSliceInMinorDims(y, v, {j});
+
+ return std::vector<xla::XlaOp>{w, y, vs, taus};
+ };
+
+ xla::XlaBuilder* builder = vs.builder();
+ auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
+ type, ConcatVectors(batch_dims, {m, n})));
+ auto y = w;
+ auto v = SliceInMinorDims(vs, {0}, {1});
+ auto beta = SliceInMinorDims(taus, {0}, {1});
+ y = UpdateSliceInMinorDims(y, v, {0});
+ auto bv = xla::Mul(
+ -beta, v,
+ /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+ w = UpdateSliceInMinorDims(w, bv, {0});
+
+ TF_ASSIGN_OR_RETURN(
+ auto values, XlaForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
+ "wy", builder));
+ return values[0];
+}
+
+} // namespace
+
+// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
+// def qr_blocked(a, block_size):
+// m = a.shape[0]
+// n = a.shape[1]
+// q = np.eye(m)
+// for i in xrange(0, min(m, n), block_size):
+// k = min(block_size, min(m, n) - s)
+// (a, vs, taus) = qr(a[i:, i:i+k])
+// y = vs
+// w = ComputeWYRepresentation(vs, taus, m-i, k)
+// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
+// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
+// return (q, a)
+// TODO(phawkins): consider using UT transformations (in the form I - V U V')
+// rather than WY transformations.
+xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
+ int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ const int num_dims = xla::ShapeUtil::Rank(a_shape);
+ if (num_dims < 2) {
+ return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
+ num_dims);
+ }
+ xla::PrimitiveType type = a_shape.element_type();
+
+ const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+ const int64 p = std::min(m, n);
+
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to QR must be >= 1; got ", block_size);
+ }
+
+ const int64 num_batch_dims = num_dims - 2;
+ std::vector<int64> batch_dims(num_batch_dims);
+ for (int i = 0; i < num_batch_dims; ++i) {
+ batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+ }
+
+ auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims);
+ for (int64 i = 0; i < p; i += block_size) {
+ int64 k = std::min(block_size, p - i);
+
+ auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
+ TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block));
+
+ a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
+
+ // Compute the I-WY block representation of a product of Householder
+ // matrices.
+ TF_ASSIGN_OR_RETURN(auto w,
+ ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+ qr_block.taus, m - i, k));
+ auto y = qr_block.vs;
+
+ // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
+ auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
+ auto a_update = BatchDot(w, a_panel, /*transpose_x=*/true);
+ a_update = BatchDot(y, a_update);
+ a_panel = a_panel + a_update;
+ a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
+
+ // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
+ auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
+ auto q_update = BatchDot(q_panel, w);
+ q_update =
+ BatchDot(q_update, y, /*transpose_x=*/false, /*transpose_y=*/true);
+ q_panel = q_panel + q_update;
+ q = UpdateSliceInMinorDims(q, q_panel, {0, i});
+ }
+ QRDecompositionResult result;
+ result.q = q;
+ result.r = a;
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h
new file mode 100644
index 0000000000..3aa6a9b075
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/qr.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace tensorflow {
+
+// Computes the QR decompositions of a batch of matrices. That is,
+// given a (batched) matrix a, computes an orthonormal matrix Q and an
+// upper-triangular matrix R such that a = QR.
+// `a` must be a (batched) matrix of size [..., m, n].
+// The algorithm implements a blocked QR decomposition; `block_size` is
+// the block size to use.
+// TODO(phawkins): handle the complex case.
+struct QRDecompositionResult {
+ xla::XlaOp q;
+ xla::XlaOp r;
+};
+
+xla::StatusOr<QRDecompositionResult> QRDecomposition(xla::XlaOp a,
+ int64 block_size = 128);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc
new file mode 100644
index 0000000000..8ff10fbd3f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/random.cc
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/random.h"
+
+#include <cmath>
+#include <limits>
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace tensorflow {
+
+xla::XlaOp TruncatedNormal(xla::XlaOp uniform) {
+ auto normal_cdf = [](double x) {
+ return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
+ };
+
+ const double kA = -2.0;
+ const double kB = 2.0;
+ const double kMu = 0.0;
+ const double kSigma = 1.0;
+ const double kAlpha = (kA - kMu) / kSigma;
+ const double kBeta = (kB - kMu) / kSigma;
+ const double kAlphaNormalCdf = normal_cdf(kAlpha);
+ const double kBetaNormalCdf = normal_cdf(kBeta);
+ const double kZ = kBetaNormalCdf - kAlphaNormalCdf;
+
+ xla::XlaOp one = xla::ScalarLike(uniform, 1.0);
+ xla::XlaOp two = xla::ScalarLike(uniform, 2.0);
+ xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0));
+ xla::XlaOp z = xla::ScalarLike(uniform, kZ);
+ xla::XlaOp alpha_normal_cdf = xla::ScalarLike(uniform, kAlphaNormalCdf);
+
+ auto p = alpha_normal_cdf + z * uniform;
+ // probit(p) = sqrt(2) * erfinv(2*p-1)
+ return sqrt_2 * xla::ErfInv(two * p - one);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h
new file mode 100644
index 0000000000..2c573fd85b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/random.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace tensorflow {
+
+// Builds an array filled with values sampled from a truncated normal
+// distribution such that no values are greater than two or less than negative
+// two.
+//
+// The "uniform" parameter must be an array of random numbers distributed in
+// (0,1).
+xla::XlaOp TruncatedNormal(xla::XlaOp uniform);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index d5a27abb25..6a5be1c2be 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -21,7 +21,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -97,8 +98,8 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
buffer_shape_post_axes.end());
// Construct the initial values of the loop-carried Tensors.
- auto flat_indices = builder->Reshape(indices, flat_indices_shape);
- auto flat_updates = builder->Reshape(updates, flat_updates_shape);
+ auto flat_indices = xla::Reshape(indices, flat_indices_shape);
+ auto flat_updates = xla::Reshape(updates, flat_updates_shape);
auto init = {flat_indices, flat_updates, buffer};
// Constructs the loop body. The implementation of scatter is essentially:
@@ -112,46 +113,44 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
- auto zero_index = body_builder->ConstantLiteral(
- xla::Literal::Zero(indices_shape.element_type()));
+ auto zero_index = xla::ConstantLiteral(
+ body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
// Slice the i-th index from the indices array.
xla::XlaOp index;
- auto indices_offset = body_builder->Reshape(i, {1});
+ auto indices_offset = xla::Reshape(i, {1});
if (indices_are_vectors) {
- indices_offset = body_builder->Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
+ indices_offset = xla::Pad(indices_offset, zero_index,
+ xla::MakeEdgePaddingConfig({{0, 1}}));
- index = body_builder->DynamicSlice(indices, indices_offset,
- {1, num_index_dims});
- index = body_builder->Collapse(index, {0, 1});
+ index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
+ index = xla::Collapse(index, {0, 1});
} else {
- index = body_builder->DynamicSlice(indices, indices_offset, {1});
+ index = xla::DynamicSlice(indices, indices_offset, {1});
}
// Discard updates with negative indices, since some users expect this.
- auto index_in_range =
- body_builder->ReduceAll(body_builder->Le(zero_index, index),
- body_builder->ConstantR0<bool>(true),
- xla::CreateScalarAndComputation(body_builder));
+ auto index_in_range = xla::ReduceAll(
+ xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
+ xla::CreateScalarAndComputation(body_builder));
// Make the index in bounds to prevent implementation defined behavior.
- index = body_builder->Max(index, zero_index);
- index = body_builder->Pad(
+ index = xla::Max(index, zero_index);
+ index = xla::Pad(
index, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
// Slice the i-th index from the updates array.
- auto updates_offset = body_builder->Reshape(i, {1});
- updates_offset = body_builder->Pad(
+ auto updates_offset = xla::Reshape(i, {1});
+ updates_offset = xla::Pad(
updates_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
std::vector<int64> flat_updates_slice_shape({1});
flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
- auto update = body_builder->DynamicSlice(updates, updates_offset,
- flat_updates_slice_shape);
+ auto update =
+ xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
// Unflatten the major (iteration) dimensions of the slice to their
// original shape.
@@ -159,20 +158,19 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
updates_slice_shape.insert(updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
- update = body_builder->Reshape(update, updates_slice_shape);
+ update = xla::Reshape(update, updates_slice_shape);
// Apply the update to the buffer. If there is a combiner, use it to merge
// the current values with the update.
- auto current_value =
- body_builder->DynamicSlice(buffer, index, updates_slice_shape);
+ auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
if (combiner) {
update = combiner(current_value, update, body_builder);
}
// Use the current value instead of the update if the index is out of
// bounds.
- update = body_builder->Select(index_in_range, update, current_value);
+ update = xla::Select(index_in_range, update, current_value);
// Apply the update.
- buffer = body_builder->DynamicUpdateSlice(buffer, update, index);
+ buffer = xla::DynamicUpdateSlice(buffer, update, index);
return std::vector<xla::XlaOp>{indices, updates, buffer};
};
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index b4503601f9..ce0f28db8f 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -29,619 +31,564 @@ limitations under the License.
namespace tensorflow {
-xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
- const xla::XlaOp& a, xla::XlaOp b,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a,
- int64 block_size) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have different ranks: ",
- xla::ShapeUtil::HumanString(a_shape), " vs. ",
- xla::ShapeUtil::HumanString(b_shape));
- }
- const int ndims = xla::ShapeUtil::Rank(a_shape);
- if (ndims < 2) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve must have rank >= 2: ", ndims);
- }
- // The batch dimensions must be equal.
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- int64 b_size = b_shape.dimensions(i);
- if (a_size != b_size) {
+xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ int64 block_size) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) {
return errors::InvalidArgument(
- "Batch dimensions of arguments to TriangularSolve must be equal: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
+ "Arguments to TriangularSolve have different ranks: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs. ",
xla::ShapeUtil::HumanString(b_shape));
}
- batch_dimensions.push_back(a_size);
- }
-
- if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
- xla::ShapeUtil::GetDimension(a_shape, -2)) {
- return errors::InvalidArgument(
- "The 'a' arguments to TriangularSolve must be square matrices: ",
- xla::ShapeUtil::HumanString(a_shape));
- }
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
- return errors::InvalidArgument(
- "Arguments to TriangularSolve have incompatible matrix shapes: ",
- xla::ShapeUtil::HumanString(a_shape), " vs ",
- xla::ShapeUtil::HumanString(b_shape));
- }
-
- if (block_size < 1) {
- return errors::InvalidArgument(
- "block_size argument to TriangularSolve must be >= 1; got ",
- block_size);
- }
-
- std::map<int, xla::XlaComputation> base_computations;
- auto get_base_triangular_solve =
- [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
- xla::XlaComputation& computation = base_computations[k];
- if (computation.IsNull()) {
- std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
- tensorflow::strings::StrCat("trsm_base_", k));
-
- auto a_param = sub->Parameter(
- 0,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
- "a");
-
- std::array<int64, 2> b_lastd;
- if (left_side) {
- b_lastd = {k, n};
- } else {
- b_lastd = {m, k};
- }
- auto b_param = sub->Parameter(
- 1,
- xla::ShapeUtil::MakeShape(
- b_shape.element_type(),
- PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
- "b");
-
- // We use a left-looking or right-looking subroutine on the block diagonal
- // in the lower=true cases, while falling back to a recursive call in
- // others. The left-looking and right-looking subroutines are written with
- // a While loop and so yields much faster compile times. Moreover, they
- // can give higher performance on smaller (sub)problems.
- if (left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else if (!left_side && lower) {
- TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param,
- b_param, transpose_a,
- conjugate_a)
- .status());
- } else {
- TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
- left_side, lower, transpose_a,
- conjugate_a,
- /*block_size=*/1)
- .status());
+ const int ndims = xla::ShapeUtil::Rank(a_shape);
+ if (ndims < 2) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve must have rank >= 2: ", ndims);
+ }
+ // The batch dimensions must be equal.
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ int64 b_size = b_shape.dimensions(i);
+ if (a_size != b_size) {
+ return errors::InvalidArgument(
+ "Batch dimensions of arguments to TriangularSolve must be equal: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
+ batch_dimensions.push_back(a_size);
+ }
- TF_ASSIGN_OR_RETURN(computation, sub->Build());
+ if (xla::ShapeUtil::GetDimension(a_shape, -1) !=
+ xla::ShapeUtil::GetDimension(a_shape, -2)) {
+ return errors::InvalidArgument(
+ "The 'a' arguments to TriangularSolve must be square matrices: ",
+ xla::ShapeUtil::HumanString(a_shape));
+ }
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) {
+ return errors::InvalidArgument(
+ "Arguments to TriangularSolve have incompatible matrix shapes: ",
+ xla::ShapeUtil::HumanString(a_shape), " vs ",
+ xla::ShapeUtil::HumanString(b_shape));
}
- return &computation;
- };
-
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Right-looking blocked triangular solve.
- // For an explanation of the algorithm, see the TRSM discussion in:
- // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
- // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
- // (2008): 4.
-
- // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
- // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
- // conjugate_a is True.
-
- if (!left_side && lower == transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < n; i += block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
- if (i + k < n) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
- } else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
- }
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
- b_update = builder->Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
- }
+ if (block_size < 1) {
+ return errors::InvalidArgument(
+ "block_size argument to TriangularSolve must be >= 1; got ",
+ block_size);
}
- } else if (left_side && lower != transpose_a) {
- // for i in range(0, a.shape[-1], block_size):
- for (int64 i = 0; i < m; i += block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] = triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i + k < a.shape[-1]:
- // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
- if (i + k < m) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
+ std::map<int, xla::XlaComputation> base_computations;
+ auto get_base_triangular_solve =
+ [&](int k) -> xla::StatusOr<xla::XlaComputation*> {
+ xla::XlaComputation& computation = base_computations[k];
+ if (computation.IsNull()) {
+ std::unique_ptr<xla::XlaBuilder> sub = builder->CreateSubBuilder(
+ tensorflow::strings::StrCat("trsm_base_", k));
+
+ auto a_param = xla::Parameter(
+ sub.get(), 0,
+ xla::ShapeUtil::MakeShape(b_shape.element_type(),
+ ConcatVectors(batch_dimensions, {k, k})),
+ "a");
+
+ std::array<int64, 2> b_lastd;
+ if (left_side) {
+ b_lastd = {k, n};
+ } else {
+ b_lastd = {m, k};
+ }
+ auto b_param = xla::Parameter(
+ sub.get(), 1,
+ xla::ShapeUtil::MakeShape(b_shape.element_type(),
+ ConcatVectors(batch_dimensions, b_lastd)),
+ "b");
+
+ // We use a left-looking or right-looking subroutine on the block
+ // diagonal in the lower=true cases, while falling back to a recursive
+ // call in others. The left-looking and right-looking subroutines are
+ // written with a While loop and so yields much faster compile times.
+ // Moreover, they can give higher performance on smaller (sub)problems.
+ if (left_side && lower) {
+ TriangularSolveLeftLooking(a_param, b_param, transpose_a,
+ conjugate_a);
+ } else if (!left_side && lower) {
+ TriangularSolveRightLooking(a_param, b_param, transpose_a,
+ conjugate_a);
} else {
- TF_ASSIGN_OR_RETURN(
- a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
+ TriangularSolve(a_param, b_param, left_side, lower, transpose_a,
+ conjugate_a,
+ /*block_size=*/1);
}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
- b_update = builder->Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
+ TF_ASSIGN_OR_RETURN(computation, sub->Build());
}
- }
- } else if (!left_side && lower != transpose_a) {
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, n - i);
-
- // output[..., :, i:i+k] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
- }
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+ return &computation;
+ };
+
+ xla::XlaOp output = xla::ZerosLike(b);
+
+ // Right-looking blocked triangular solve.
+ // For an explanation of the algorithm, see the TRSM discussion in:
+ // Goto, Kazushige, and Robert Van De Geijn. "High-performance
+ // implementation of the level-3 BLAS." ACM Transactions on Mathematical
+ // Software (TOMS) 35.1 (2008): 4.
+
+ // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
+ // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
+ // conjugate_a is True.
+
+ if (!left_side && lower == transpose_a) {
+ // for i in range(0, a.shape[-1], block_size):
+ for (int64 i = 0; i < n; i += block_size) {
+ int64 k = std::min(block_size, n - i);
+
+ // output[..., :, i:i+k] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {0, i});
+
+ // if i + k < a.shape[-1]:
+ // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
+ if (i + k < n) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i + k, i}, {n, i + k});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, n});
+ }
+
+ auto b_update = BatchDot(update, a_slice_2,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+ auto b_slice_2 = SliceInMinorDims(b, {0, i + k}, {m, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, i + k});
}
+ }
- TF_ASSIGN_OR_RETURN(auto b_update,
- BatchDot(builder, update, a_slice_2,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {m, i}));
- b_update = builder->Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
+ } else if (left_side && lower != transpose_a) {
+ // for i in range(0, a.shape[-1], block_size):
+ for (int64 i = 0; i < m; i += block_size) {
+ int64 k = std::min(block_size, m - i);
+
+ // output[..., i:i+k, :] = triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
+ } else {
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
+
+ // if i + k < a.shape[-1]:
+ // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
+ if (i + k < m) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i + k, i}, {m, i + k});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {i, i + k}, {i + k, m});
+ }
+
+ auto b_update = BatchDot(a_slice_2, update,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto b_slice_2 = SliceInMinorDims(b, {i + k, 0}, {m, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {i + k, 0});
+ }
}
- }
- } else { // left_side && lower == transpose_a
- // for i in reversed(range(0, a.shape[-1], block_size)):
- const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
- for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
- int64 k = std::min(block_size, m - i);
-
- // output[..., i:i+k, :] triangular_solve(
- // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
- xla::XlaOp update;
- if (k > 1) {
- TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
- get_base_triangular_solve(k));
- update = builder->Call(*solve, {a_slice, b_slice});
- } else {
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- update = builder->Div(b_slice, a_slice_conj);
+ } else if (!left_side && lower != transpose_a) {
+ // for i in reversed(range(0, a.shape[-1], block_size)):
+ const int64 last_blk_ix =
+ xla::RoundUpToNearest(n, block_size) - block_size;
+ for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+ int64 k = std::min(block_size, n - i);
+
+ // output[..., :, i:i+k] triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {0, i}, {m, i + k});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
+ } else {
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {0, i});
+
+ // if i - k >= 0:
+ // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
+ if (i - k >= 0) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k});
+ }
+
+ auto b_update = BatchDot(update, a_slice_2,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+ auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {m, i});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0});
+ }
}
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
-
- // if i - k >= 0:
- // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
- // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
- // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
- if (i - k >= 0) {
- xla::XlaOp a_slice_2;
- if (lower) {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
+ } else { // left_side && lower == transpose_a
+ // for i in reversed(range(0, a.shape[-1], block_size)):
+ const int64 last_blk_ix =
+ xla::RoundUpToNearest(m, block_size) - block_size;
+ for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
+ int64 k = std::min(block_size, m - i);
+
+ // output[..., i:i+k, :] triangular_solve(
+ // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + k, n});
+ xla::XlaOp update;
+ if (k > 1) {
+ TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve,
+ get_base_triangular_solve(k));
+ update = xla::Call(builder, *solve, {a_slice, b_slice});
} else {
- TF_ASSIGN_OR_RETURN(a_slice_2,
- SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ update = b_slice / a_slice_conj;
+ }
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
+
+ // if i - k >= 0:
+ // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
+ // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
+ // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
+ if (i - k >= 0) {
+ xla::XlaOp a_slice_2;
+ if (lower) {
+ a_slice_2 = SliceInMinorDims(a, {i, 0}, {i + k, i});
+ } else {
+ a_slice_2 = SliceInMinorDims(a, {0, i}, {i, i + k});
+ }
+
+ auto b_update = BatchDot(a_slice_2, update,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto b_slice_2 = SliceInMinorDims(b, {0, 0}, {i, n});
+ b = UpdateSliceInMinorDims(b, b_slice_2 - b_update, {0, 0});
}
-
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(auto b_slice_2,
- SliceInMinorDims(builder, b, {0, 0}, {i, n}));
- b_update = builder->Sub(b_slice_2, b_update);
- TF_ASSIGN_OR_RETURN(
- b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
}
- }
- return output;
+ return output;
+ });
}
-xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
-
- // Allocate the output and set its first or last row,
- // output = np.zeros_like(b)
- // if transpose_a:
- // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
- // else:
- // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
- xla::XlaOp output = Zeros(builder, b_shape);
- {
- auto i = transpose_a ? m - 1 : 0;
- TF_ASSIGN_OR_RETURN(auto a_slice,
- SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
- TF_ASSIGN_OR_RETURN(auto b_slice,
- SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
- TF_ASSIGN_OR_RETURN(auto a_slice_conj,
- MaybeConjugate(builder, a_slice, conjugate_a));
- auto update = builder->Div(b_slice, a_slice_conj);
- TF_ASSIGN_OR_RETURN(
- output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
- }
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (m-2, output, a, b)
- // else:
- // init = (1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
- auto init = builder->Tuple({init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i >= 0 if transpose_a else i < m
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
- {
- auto i = condb->GetTupleElement(
- condb->Parameter(0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple"),
- 0);
- if (transpose_a) {
- condb->Ge(i, condb->ConstantR0<int32>(0));
- } else {
- condb->Lt(i, condb->ConstantR0<int32>(m));
+xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
+
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ batch_dimensions.push_back(a_size);
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2)
- // else:
- // a_row = a[..., i:i+1, :i]
- // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
- // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
- {
- auto input_tuple = bodyb->Parameter(0, tuple_shape,
- "TriangularSolveLeftLookingWhileTuple");
- // i, output, a, b = loop_carry
- auto i = bodyb->GetTupleElement(input_tuple, 0);
- auto body_out = bodyb->GetTupleElement(input_tuple, 1);
- auto body_a = bodyb->GetTupleElement(input_tuple, 2);
- auto body_b = bodyb->GetTupleElement(input_tuple, 3);
- auto zero = bodyb->ConstantR0<int32>(0);
+ // The main computation is performed in a While loop.
- // We'd like to implement this:
- // if transpose_a:
- // a_row = T(a[..., i+1:, i:i+1])
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a_row, body_out[..., i+1:, :]))
- // else:
- // result_row = (b[..., i:i+1, :]
- // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
- // But since we can't have intermediate array sizes depend on the loop
- // counter, we instead exploit the fact that we initialized the output to
- // all zeros and use that as zero-padding (doing unnecessary FLOPs).
- xla::XlaOp a_row;
- if (transpose_a) {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {zero, i}, {m, 1}));
- } else {
- TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, zero}, {1, m}));
+ // Allocate the output and set its first or last row,
+ // output = np.zeros_like(b)
+ // if transpose_a:
+ // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
+ // else:
+ // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
+ xla::XlaOp output = xla::ZerosLike(b);
+ {
+ auto i = transpose_a ? m - 1 : 0;
+ auto a_slice = SliceInMinorDims(a, {i, i}, {i + 1, i + 1});
+ auto b_slice = SliceInMinorDims(b, {i, 0}, {i + 1, n});
+ auto a_slice_conj = MaybeConjugate(a_slice, conjugate_a);
+ auto update = b_slice / a_slice_conj;
+ output = UpdateSliceInMinorDims(output, update, {i, 0});
}
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
- /*transpose_x=*/transpose_a,
- /*transpose_y=*/false,
- /*conjugate_x=*/conjugate_a,
- /*conjugate_y=*/false));
- TF_ASSIGN_OR_RETURN(
- auto result_row_slice,
- DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
- auto result_row = bodyb->Sub(result_row_slice, b_update);
-
- // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_elt_conj,
- MaybeConjugate(bodyb.get(), a_elt, conjugate_a));
- auto div_result = bodyb->Div(result_row, a_elt_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {i, zero}));
+ // Construct the initial loop carry tuple,
// if transpose_a:
- // return (i - 1, body_out, a, b)
+ // init = (m-2, output, a, b)
// else:
- // return (i + 1, body_out, a, b)
- auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
- bodyb->Tuple({next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = builder->While(cond, body, init);
- return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
+ // init = (1, output, a, b)
+ std::vector<xla::Shape> tuple_shapes = {
+ // The loop iteration counter is a scalar, incremented each iteration.
+ xla::ShapeUtil::MakeShape(xla::S32, {}),
+ // The output has the shape of b, with one row updated each iteration.
+ b_shape,
+ // The coefficient matrix a is a loop invariant.
+ a_shape,
+ // The right-hand-side matrix b is a loop invariant.
+ b_shape};
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? m - 2 : 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
+
+ // Construct the loop condition function,
+ // def cond_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // return i >= 0 if transpose_a else i < m
+ std::unique_ptr<xla::XlaBuilder> condb =
+ builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
+ {
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple"),
+ 0);
+ if (transpose_a) {
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+ } else {
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), m));
+ }
+ }
+ TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+ // Construct the loop body function,
+ // def body_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // if transpose_a:
+ // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2)
+ // else:
+ // a_row = a[..., i:i+1, :i]
+ // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
+ // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+ // if transpose_a:
+ // return (i - 1, output, a, b)
+ // else:
+ // return (i + 1, output, a, b)
+ // We have to do some extra FLOPs propagating zeros in the matrix multiply
+ // because we can't have the size of its arguments depend on the loop
+ // counter.
+ std::unique_ptr<xla::XlaBuilder> bodyb =
+ builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
+ {
+ auto input_tuple = xla::Parameter(bodyb.get(), 0, tuple_shape,
+ "TriangularSolveLeftLookingWhileTuple");
+
+ // i, output, a, b = loop_carry
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
+
+ // We'd like to implement this:
+ // if transpose_a:
+ // a_row = T(a[..., i+1:, i:i+1])
+ // result_row = (b[..., i:i+1, :]
+ // - np.matmul(a_row, body_out[..., i+1:, :]))
+ // else:
+ // result_row = (b[..., i:i+1, :]
+ // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
+ // But since we can't have intermediate array sizes depend on the loop
+ // counter, we instead exploit the fact that we initialized the output to
+ // all zeros and use that as zero-padding (doing unnecessary FLOPs).
+ xla::XlaOp a_row;
+ if (transpose_a) {
+ a_row = DynamicSliceInMinorDims(body_a, {zero, i}, {m, 1});
+ } else {
+ a_row = DynamicSliceInMinorDims(body_a, {i, zero}, {1, m});
+ }
+ auto b_update = BatchDot(a_row, body_out,
+ /*transpose_x=*/transpose_a,
+ /*transpose_y=*/false,
+ /*conjugate_x=*/conjugate_a,
+ /*conjugate_y=*/false);
+ auto result_row_slice =
+ DynamicSliceInMinorDims(body_b, {i, zero}, {1, n});
+ auto result_row = result_row_slice - b_update;
+
+ // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
+ auto a_elt = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ auto a_elt_conj = MaybeConjugate(a_elt, conjugate_a);
+ auto div_result = xla::Div(result_row, a_elt_conj);
+ body_out = DynamicUpdateSliceInMinorDims(body_out, div_result, {i, zero});
+
+ // if transpose_a:
+ // return (i - 1, body_out, a, b)
+ // else:
+ // return (i + 1, body_out, a, b)
+ auto next_i = xla::Add(
+ i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? -1 : 1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
+ }
+ TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+ // Construct the While loop and return the result,
+ // return while_loop(cond_fun, body_fun, init)[1]
+ auto triangular_solve_left_looking_while = xla::While(cond, body, init);
+ return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ });
}
-xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a) {
- TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
- TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
- const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
- const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
- const int64 ndims = xla::ShapeUtil::Rank(a_shape);
-
- std::vector<int64> batch_dimensions;
- for (int i = 0; i < ndims - 2; ++i) {
- int64 a_size = a_shape.dimensions(i);
- batch_dimensions.push_back(a_size);
- }
-
- // The main computation is performed in a While loop.
- xla::XlaOp output = Zeros(builder, b_shape);
-
- // Construct the initial loop carry tuple,
- // if transpose_a:
- // init = (0, output, a, b)
- // else:
- // init = (n-1, output, a, b)
- std::vector<xla::Shape> tuple_shapes = {
- // The loop iteration counter is a scalar, incremented each iteration.
- xla::ShapeUtil::MakeShape(xla::S32, {}),
- // The output has the shape of b, with one row updated each iteration.
- b_shape,
- // The coefficient matrix a is a loop invariant.
- a_shape,
- // The right-hand-side matrix b is a loop invariant.
- b_shape};
- xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
- auto init_i = builder->ConstantR0<int32>(transpose_a ? 0 : n - 1);
- auto init = builder->Tuple({init_i, output, a, b});
-
- // Construct the loop condition function,
- // def cond_fun(loop_carry):
- // i, output, a, b = loop_carry
- // return i < n if transpose_a else i >= 0
- std::unique_ptr<xla::XlaBuilder> condb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
- {
- auto i = condb->GetTupleElement(
- condb->Parameter(0, tuple_shape,
- "TriangularSolveRightLookingWhileTuple"),
- 0);
- if (transpose_a) {
- condb->Lt(i, condb->ConstantR0<int32>(n));
- } else {
- condb->Ge(i, condb->ConstantR0<int32>(0));
+xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a) {
+ xla::XlaBuilder* builder = a.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
+ TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b));
+ const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2);
+ const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1);
+ const int64 ndims = xla::ShapeUtil::Rank(a_shape);
+
+ std::vector<int64> batch_dimensions;
+ for (int i = 0; i < ndims - 2; ++i) {
+ int64 a_size = a_shape.dimensions(i);
+ batch_dimensions.push_back(a_size);
}
- }
- TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
-
- // Construct the loop body function,
- // def body_fun(loop_carry):
- // i, output, a, b = loop_carry
- // if transpose_a:
- // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2)
- // else:
- // a_row = a[..., :, i:i+1]
- // result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
- // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- // if transpose_a:
- // return (i - 1, output, a, b)
- // else:
- // return (i + 1, output, a, b)
- // We have to do some extra FLOPs propagating zeros in the matrix multiply
- // because we can't have the size of its arguments depend on the loop counter.
- std::unique_ptr<xla::XlaBuilder> bodyb =
- builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
- {
- auto input_tuple = bodyb->Parameter(
- 0, tuple_shape, "TriangularSolveRightLookingWhileTuple");
-
- // i, output, a, b = loop_carry
- auto i = bodyb->GetTupleElement(input_tuple, 0);
- auto body_out = bodyb->GetTupleElement(input_tuple, 1);
- auto body_a = bodyb->GetTupleElement(input_tuple, 2);
- auto body_b = bodyb->GetTupleElement(input_tuple, 3);
- auto zero = bodyb->ConstantR0<int32>(0);
-
- // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :,
- // i:i+1]) But since we can't have intermediate array sizes depend on the
- // loop counter, we instead exploit the fact that we initialized the output
- // to all zeros and use that as zero-padding (doing unnecessary FLOPs).
- TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a,
- /*transpose_x=*/false,
- /*transpose_y=*/transpose_a,
- /*conjugate_x=*/false,
- /*conjugate_y=*/conjugate_a));
- // result = b - np.matmul(output, a)
- auto result = bodyb->Sub(body_b, b_update);
- // result_row = result[..., :, i:i+1]
- TF_ASSIGN_OR_RETURN(
- auto result_row,
- DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1}));
-
- // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
- TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a,
- {i, i}, {1, 1}));
- TF_ASSIGN_OR_RETURN(auto a_ii_conj,
- MaybeConjugate(bodyb.get(), a_ii, conjugate_a));
- auto div_result = bodyb->Div(result_row, a_ii_conj);
- TF_ASSIGN_OR_RETURN(body_out,
- DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
- div_result, {zero, i}));
+ // The main computation is performed in a While loop.
+ xla::XlaOp output = xla::ZerosLike(b);
+
+ // Construct the initial loop carry tuple,
// if transpose_a:
- // return (i + 1, body_out, a, b)
+ // init = (0, output, a, b)
// else:
- // return (i - 1, body_out, a, b)
- auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? 1 : -1));
- bodyb->Tuple({next_i, body_out, body_a, body_b});
- }
- TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
-
- // Construct the While loop and return the result,
- // return while_loop(cond_fun, body_fun, init)[1]
- auto triangular_solve_left_looking_while = builder->While(cond, body, init);
- return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
+ // init = (n-1, output, a, b)
+ std::vector<xla::Shape> tuple_shapes = {
+ // The loop iteration counter is a scalar, incremented each iteration.
+ xla::ShapeUtil::MakeShape(xla::S32, {}),
+ // The output has the shape of b, with one row updated each iteration.
+ b_shape,
+ // The coefficient matrix a is a loop invariant.
+ a_shape,
+ // The right-hand-side matrix b is a loop invariant.
+ b_shape};
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
+ auto init_i = xla::ConstantR0<int32>(builder, transpose_a ? 0 : n - 1);
+ auto init = xla::Tuple(builder, {init_i, output, a, b});
+
+ // Construct the loop condition function,
+ // def cond_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // return i < n if transpose_a else i >= 0
+ std::unique_ptr<xla::XlaBuilder> condb =
+ builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond");
+ {
+ auto i = xla::GetTupleElement(
+ xla::Parameter(condb.get(), 0, tuple_shape,
+ "TriangularSolveRightLookingWhileTuple"),
+ 0);
+ if (transpose_a) {
+ xla::Lt(i, xla::ConstantR0<int32>(condb.get(), n));
+ } else {
+ xla::Ge(i, xla::ConstantR0<int32>(condb.get(), 0));
+ }
+ }
+ TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
+
+ // Construct the loop body function,
+ // def body_fun(loop_carry):
+ // i, output, a, b = loop_carry
+ // if transpose_a:
+ // a_row = np.swapaxes(a[..., :, i:i+1], -1, -2)
+ // else:
+ // a_row = a[..., :, i:i+1]
+ // result_row = b[..., :, i:i+1] - np.matmul(output, a_row)
+ // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1]
+ // if transpose_a:
+ // return (i - 1, output, a, b)
+ // else:
+ // return (i + 1, output, a, b)
+ // We have to do some extra FLOPs propagating zeros in the matrix multiply
+ // because we can't have the size of its arguments depend on the loop
+ // counter.
+ std::unique_ptr<xla::XlaBuilder> bodyb =
+ builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody");
+ {
+ auto input_tuple = xla::Parameter(
+ bodyb.get(), 0, tuple_shape, "TriangularSolveRightLookingWhileTuple");
+
+ // i, output, a, b = loop_carry
+ auto i = xla::GetTupleElement(input_tuple, 0);
+ auto body_out = xla::GetTupleElement(input_tuple, 1);
+ auto body_a = xla::GetTupleElement(input_tuple, 2);
+ auto body_b = xla::GetTupleElement(input_tuple, 3);
+ auto zero = xla::ConstantR0<int32>(bodyb.get(), 0);
+
+ // result = b - np.matmul(output, a)
+ // result_row = result[..., :, i:i+1]
+ auto body_b_slice = DynamicSliceInMinorDims(body_b, {zero, i}, {m, 1});
+ xla::XlaOp a_slice;
+ if (transpose_a) {
+ a_slice = DynamicSliceInMinorDims(body_a, {i, zero}, {1, n});
+ } else {
+ a_slice = DynamicSliceInMinorDims(body_a, {zero, i}, {n, 1});
+ }
+ auto b_update = body_b_slice - BatchDot(body_out, a_slice,
+ /*transpose_x=*/false,
+ /*transpose_y=*/transpose_a,
+ /*conjugate_x=*/false,
+ /*conjugate_y=*/conjugate_a);
+
+ // body_out[..., :, i:i+1] = b_update / a[..., i:i+1, i:i+1]
+ auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1});
+ auto a_ii_conj = MaybeConjugate(a_ii, conjugate_a);
+ body_out = DynamicUpdateSliceInMinorDims(body_out, b_update / a_ii_conj,
+ {zero, i});
+
+ // if transpose_a:
+ // return (i + 1, body_out, a, b)
+ // else:
+ // return (i - 1, body_out, a, b)
+ auto next_i = xla::Add(
+ i, xla::ConstantR0<int32>(bodyb.get(), transpose_a ? 1 : -1));
+ xla::Tuple(bodyb.get(), {next_i, body_out, body_a, body_b});
+ }
+ TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
+
+ // Construct the While loop and return the result,
+ // return while_loop(cond_fun, body_fun, init)[1]
+ auto triangular_solve_left_looking_while = xla::While(cond, body, init);
+ return xla::GetTupleElement(triangular_solve_left_looking_while, 1);
+ });
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
index 540c26b247..80c2bc4c9c 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h
@@ -57,23 +57,15 @@ namespace tensorflow {
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
-xla::StatusOr<xla::XlaOp> TriangularSolve(xla::XlaBuilder* builder,
- const xla::XlaOp& a, xla::XlaOp b,
- bool left_side, bool lower,
- bool transpose_a, bool conjugate_a,
- int64 block_size = 256);
+xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side,
+ bool lower, bool transpose_a, bool conjugate_a,
+ int64 block_size = 256);
-xla::StatusOr<xla::XlaOp> TriangularSolveLeftLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a);
+xla::XlaOp TriangularSolveLeftLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a);
-xla::StatusOr<xla::XlaOp> TriangularSolveRightLooking(xla::XlaBuilder* builder,
- const xla::XlaOp& a,
- const xla::XlaOp& b,
- bool transpose_a,
- bool conjugate_a);
+xla::XlaOp TriangularSolveRightLooking(xla::XlaOp a, xla::XlaOp b,
+ bool transpose_a, bool conjugate_a);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
index 87ea4763f7..f1bff6037b 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -85,11 +85,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -107,11 +106,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -129,11 +127,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
@@ -151,11 +148,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
@@ -173,11 +169,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -196,11 +191,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/true,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -219,11 +213,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -242,11 +235,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/false, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
@@ -267,11 +259,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/false, /*lower=*/true,
- /*transpose_a=*/true, /*conjugate_a=*/true,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/false, /*lower=*/true,
+ /*transpose_a=*/true, /*conjugate_a=*/true,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, complex64(0.08333333, 0.08333333),
@@ -295,11 +286,10 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
- auto result = TriangularSolve(&builder, a, b,
- /*left_side=*/true, /*lower=*/false,
- /*transpose_a=*/true, /*conjugate_a=*/false,
- /*block_size=*/2);
- TF_ASSERT_OK(result.status());
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/true, /*conjugate_a=*/false,
+ /*block_size=*/2);
xla::Array2D<complex64> expected({
{0.5, 1., 1.5},
@@ -323,10 +313,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
+ TriangularSolveLeftLooking(a, b,
+ /*transpose_a=*/false,
+ /*conjugate_a=*/false);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
@@ -345,10 +334,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
xla::XlaOp a, b;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
- auto result = TriangularSolveLeftLooking(&builder, a, b,
- /*transpose_a=*/false,
- /*conjugate_a=*/false);
- TF_ASSERT_OK(result.status());
+ TriangularSolveLeftLooking(a, b,
+ /*transpose_a=*/false,
+ /*conjugate_a=*/false);
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index d9ff7e6259..a6f5d346cb 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -28,8 +30,9 @@ limitations under the License.
namespace tensorflow {
xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) {
- return builder->Broadcast(
- builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
+ return xla::Broadcast(
+ xla::ConstantLiteral(builder,
+ xla::LiteralUtil::Zero(shape.element_type())),
xla::AsInt64Slice(shape.dimensions()));
}
@@ -37,19 +40,19 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
double value) {
switch (type) {
case xla::F16:
- return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
+ return xla::ConstantR0<xla::half>(builder, static_cast<xla::half>(value));
break;
case xla::BF16:
- return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value));
+ return xla::ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
break;
case xla::F32:
- return builder->ConstantR0<float>(static_cast<float>(value));
+ return xla::ConstantR0<float>(builder, static_cast<float>(value));
break;
case xla::F64:
- return builder->ConstantR0<double>(value);
+ return xla::ConstantR0<double>(builder, value);
break;
case xla::C64:
- return builder->ConstantR0<xla::complex64>(value);
+ return xla::ConstantR0<xla::complex64>(builder, value);
break;
default:
LOG(FATAL) << "unhandled element type " << type;
@@ -61,31 +64,31 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
xla::Literal literal;
switch (type) {
case xla::U8:
- literal = std::move(*xla::Literal::CreateR0<uint8>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint8>(value));
break;
case xla::U32:
- literal = std::move(*xla::Literal::CreateR0<uint32>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint32>(value));
break;
case xla::U64:
- literal = std::move(*xla::Literal::CreateR0<uint64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<uint64>(value));
break;
case xla::S8:
- literal = std::move(*xla::Literal::CreateR0<int8>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int8>(value));
break;
case xla::S32:
- literal = std::move(*xla::Literal::CreateR0<int32>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int32>(value));
break;
case xla::S64:
- literal = std::move(*xla::Literal::CreateR0<int64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<int64>(value));
break;
case xla::F32:
- literal = std::move(*xla::Literal::CreateR0<float>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<float>(value));
break;
case xla::F64:
- literal = std::move(*xla::Literal::CreateR0<double>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<double>(value));
break;
case xla::C64:
- literal = std::move(*xla::Literal::CreateR0<complex64>(value));
+ literal = std::move(*xla::LiteralUtil::CreateR0<complex64>(value));
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
@@ -94,11 +97,11 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal = std::move(
- *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
+ *xla::LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
break;
case xla::F16:
- literal = std::move(
- *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
+ literal = std::move(*xla::LiteralUtil::CreateR0<xla::half>(
+ static_cast<xla::half>(value)));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
@@ -107,134 +110,140 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
default:
LOG(FATAL) << "unhandled element type " << type;
}
- return builder->ConstantLiteral(literal);
+ return xla::ConstantLiteral(builder, literal);
}
-xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end) {
- TF_RET_CHECK(start.size() == end.size());
- int64 n_minor_dims = start.size();
-
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
-
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - n_minor_dims);
-
- // Prepends 0s in the major dim
- std::vector<int64> padded_start(n_dims, 0);
- std::copy(start.begin(), start.end(),
- padded_start.begin() + major_dims.size());
-
- // Prepends the shape of the major dims.
- std::vector<int64> padded_end(n_dims);
- std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
- std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
-
- std::vector<int64> strides(n_dims, 1);
- return builder->Slice(x, padded_start, padded_end, strides);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_RET_CHECK(start.size() == end.size());
+ int64 n_minor_dims = start.size();
+
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - n_minor_dims);
+
+ // Prepends 0s in the major dim
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + major_dims.size());
+
+ // Prepends the shape of the major dims.
+ std::vector<int64> padded_end(n_dims);
+ std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
+ std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());
+
+ std::vector<int64> strides(n_dims, 1);
+ return xla::Slice(x, padded_start, padded_end, strides);
+ });
}
-std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
- const gtl::ArraySlice<int64>& major_dims,
- const gtl::ArraySlice<int64>& indices) {
- std::vector<int64> output(indices.size() + major_dims.size());
- std::copy(major_dims.begin(), major_dims.end(), output.begin());
- std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size());
+std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
+ gtl::ArraySlice<int64> ys) {
+ std::vector<int64> output(xs.size() + ys.size());
+ std::copy(xs.begin(), xs.end(), output.begin());
+ std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
return output;
}
-xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts,
- const gtl::ArraySlice<int64>& sizes) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- int64 n_minor_dims = starts.size();
- TF_RET_CHECK(n_minor_dims == sizes.size());
- TF_RET_CHECK(n_minor_dims <= n_dims);
- gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
- /*pos=*/0,
- /*len=*/n_dims - sizes.size());
- TF_ASSIGN_OR_RETURN(auto padded_starts,
- PrependZerosInMajorDims(builder, x, starts));
- auto padded_sizes = PrependMajorDims(builder, major_dims, sizes);
- return builder->DynamicSlice(x, padded_starts, padded_sizes);
+xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts,
+ gtl::ArraySlice<int64> sizes) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ int64 n_minor_dims = starts.size();
+ TF_RET_CHECK(n_minor_dims == sizes.size());
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape.dimensions()),
+ /*pos=*/0,
+ /*len=*/n_dims - sizes.size());
+ auto padded_starts = PrependZerosInMajorDims(x, starts);
+ auto padded_sizes = ConcatVectors(major_dims, sizes);
+ return xla::DynamicSlice(x, padded_starts, padded_sizes);
+ });
}
-xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start) {
- // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
- std::vector<int32> start_as_int32(start.begin(), start.end());
- auto start_constant = builder->ConstantR1<int32>(start_as_int32);
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
- builder->GetShape(start_constant));
- const int64 start_length =
- xla::ShapeUtil::GetDimension(start_constant_shape, -1);
- TF_RET_CHECK(start_length == n_dims);
- return builder->DynamicUpdateSlice(x, update, start_constant);
+xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
+ std::vector<int32> start_as_int32(start.begin(), start.end());
+ auto start_constant = xla::ConstantR1<int32>(builder, start_as_int32);
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape,
+ builder->GetShape(start_constant));
+ const int64 start_length =
+ xla::ShapeUtil::GetDimension(start_constant_shape, -1);
+ TF_RET_CHECK(start_length == n_dims);
+ return xla::DynamicUpdateSlice(x, update, start_constant);
+ });
}
-xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- const int64 n_minor_dims = start.size();
- TF_RET_CHECK(n_minor_dims <= n_dims);
- std::vector<int64> padded_start(n_dims, 0);
- std::copy(start.begin(), start.end(),
- padded_start.begin() + (n_dims - n_minor_dims));
- return UpdateSlice(builder, x, update, padded_start);
+xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ const int64 n_minor_dims = start.size();
+ TF_RET_CHECK(n_minor_dims <= n_dims);
+ std::vector<int64> padded_start(n_dims, 0);
+ std::copy(start.begin(), start.end(),
+ padded_start.begin() + (n_dims - n_minor_dims));
+ return UpdateSlice(x, update, padded_start);
+ });
}
-xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
- const std::vector<xla::XlaOp>& starts) {
- TF_ASSIGN_OR_RETURN(auto padded_starts,
- PrependZerosInMajorDims(builder, x, starts));
- return builder->DynamicUpdateSlice(x, update, padded_starts);
+xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<xla::XlaOp> starts) {
+ auto padded_starts = PrependZerosInMajorDims(x, starts);
+ return xla::DynamicUpdateSlice(x, update, padded_starts);
}
-xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- auto zero = builder->Reshape(builder->ConstantR0<int32>(0), {1});
- std::vector<xla::XlaOp> padded_starts(n_dims, zero);
- for (int i = 0; i < starts.size(); ++i) {
- padded_starts[n_dims - starts.size() + i] =
- builder->Reshape(starts[i], {1});
- }
- return builder->ConcatInDim(padded_starts, 0);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ auto zero = xla::Reshape(xla::ConstantR0<int32>(builder, 0), {1});
+ std::vector<xla::XlaOp> padded_starts(n_dims, zero);
+ for (int i = 0; i < starts.size(); ++i) {
+ padded_starts[n_dims - starts.size() + i] = xla::Reshape(starts[i], {1});
+ }
+ return xla::ConcatInDim(builder, padded_starts, 0);
+ });
}
-xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- const int64 n_dims = xla::ShapeUtil::Rank(shape);
- TF_RET_CHECK(n_dims >= 2);
- std::vector<int64> permutation(n_dims);
- std::iota(permutation.begin(), permutation.end(), 0);
- std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
- return builder->Transpose(x, permutation);
+xla::XlaOp TransposeInMinorDims(xla::XlaOp x) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ const int64 n_dims = xla::ShapeUtil::Rank(shape);
+ TF_RET_CHECK(n_dims >= 2);
+ std::vector<int64> permutation(n_dims);
+ std::iota(permutation.begin(), permutation.end(), 0);
+ std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
+ return xla::Transpose(x, permutation);
+ });
}
-xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
- const xla::XlaOp& x, bool conjugate) {
- TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
- auto perform_conj = shape.element_type() == xla::C64 && conjugate;
- return perform_conj ? builder->Conj(x) : x;
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate) {
+ xla::XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ auto perform_conj = shape.element_type() == xla::C64 && conjugate;
+ return perform_conj ? xla::Conj(x) : x;
+ });
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h
index 3c120a2548..6cb6c088e9 100644
--- a/tensorflow/compiler/tf2xla/lib/util.h
+++ b/tensorflow/compiler/tf2xla/lib/util.h
@@ -23,9 +23,6 @@ limitations under the License.
namespace tensorflow {
-// Returns a zero-filled tensor with shape `shape`.
-xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape);
-
// Returns a floating point scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
@@ -33,7 +30,7 @@ xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros
// prepended until the array is length n_dims.
-xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
gtl::ArraySlice<xla::XlaOp> starts);
// Returns a integer scalar constant of 'type' with 'value'.
@@ -41,54 +38,43 @@ xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder,
xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type,
int64 value);
-// Builds a vector of zeros of length rank(x) with the last two values being
+// Builds a vector of zeros of length rank(x) with the last values being
// those in `starts`.
-xla::StatusOr<xla::XlaOp> PrependZerosInMajorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts);
+xla::XlaOp PrependZerosInMajorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Performs a slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::XlaOp> SliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- gtl::ArraySlice<int64> start,
- gtl::ArraySlice<int64> end);
+xla::XlaOp SliceInMinorDims(xla::XlaOp x, gtl::ArraySlice<int64> start,
+ gtl::ArraySlice<int64> end);
-// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`.
-std::vector<int64> PrependMajorDims(xla::XlaBuilder* builder,
- const gtl::ArraySlice<int64>& major_dims,
- const gtl::ArraySlice<int64>& indices);
+// Returns the concatenation of `xs` and `ys`.
+std::vector<int64> ConcatVectors(gtl::ArraySlice<int64> xs,
+ gtl::ArraySlice<int64> ys);
// Performs a dynamic slice in the minor dimensions of a Tensor.
-xla::StatusOr<xla::XlaOp> DynamicSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x,
- const std::vector<xla::XlaOp>& starts, const gtl::ArraySlice<int64>& sizes);
+xla::XlaOp DynamicSliceInMinorDims(xla::XlaOp x,
+ gtl::ArraySlice<xla::XlaOp> starts,
+ gtl::ArraySlice<int64> sizes);
// Updates a slice of 'x', i.e.,
// x[start[0], ..., start[n]] = update
-xla::StatusOr<xla::XlaOp> UpdateSlice(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start);
+xla::XlaOp UpdateSlice(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start);
// Updates a slice of 'x', where 'start' contains a list of minor dimensions:
// x[..., start[0], ..., start[n]] = update
-xla::StatusOr<xla::XlaOp> UpdateSliceInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x,
- const xla::XlaOp& update,
- gtl::ArraySlice<int64> start);
+xla::XlaOp UpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<int64> start);
-xla::StatusOr<xla::XlaOp> DynamicUpdateSliceInMinorDims(
- xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update,
- const std::vector<xla::XlaOp>& starts);
+xla::XlaOp DynamicUpdateSliceInMinorDims(xla::XlaOp x, xla::XlaOp update,
+ gtl::ArraySlice<xla::XlaOp> starts);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
-xla::StatusOr<xla::XlaOp> TransposeInMinorDims(xla::XlaBuilder* builder,
- const xla::XlaOp& x);
+xla::XlaOp TransposeInMinorDims(xla::XlaOp x);
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
-xla::StatusOr<xla::XlaOp> MaybeConjugate(xla::XlaBuilder* builder,
- const xla::XlaOp& x, bool conjugate);
+xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc
index 265b39402c..442fe92c34 100644
--- a/tensorflow/compiler/tf2xla/lib/util_test.cc
+++ b/tensorflow/compiler/tf2xla/lib/util_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -70,8 +70,7 @@ XLA_TEST_F(UtilTest, Simple2dLookup) {
auto a_data = CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &a);
auto x_data = CreateR0Parameter<int>(2, 1, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 2, "y", &builder, &y);
- auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1});
- TF_ASSERT_OK(result.status());
+ DynamicSliceInMinorDims(a, {x, y}, {1, 1});
ComputeAndCompareR2<float>(&builder, {{10}},
{a_data.get(), x_data.get(), y_data.get()},
@@ -86,10 +85,8 @@ XLA_TEST_F(UtilTest, Simple3dLookup) {
CreateR3Parameter<float>(BatchedAValsFull(), 0, "a", &builder, &a);
auto index_data = CreateR0Parameter<int>(1, 1, "index", &builder, &index);
- TF_ASSERT_OK_AND_ASSIGN(
- auto l_index,
- DynamicSliceInMinorDims(&builder, a,
- {index, builder.ConstantR0<int32>(0)}, {1, 4}));
+ DynamicSliceInMinorDims(a, {index, xla::ConstantR0<int32>(&builder, 0)},
+ {1, 4});
ComputeAndCompareR3<float>(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}},
{a_data.get(), index_data.get()});
@@ -104,8 +101,7 @@ XLA_TEST_F(UtilTest, SimpleSliceUpdate) {
auto x_data = CreateR0Parameter<int>(2, 2, "x", &builder, &x);
auto y_data = CreateR0Parameter<int>(1, 3, "y", &builder, &y);
- auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y});
- TF_ASSERT_OK(result.status());
+ DynamicUpdateSliceInMinorDims(a, b, {x, y});
xla::Array2D<float> expected(
{{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}});
@@ -128,13 +124,9 @@ XLA_TEST_F(UtilTest, RowBatchDot) {
// Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull().
auto index_data = CreateR0Parameter<int>(1, 2, "index", &builder, &index);
- TF_ASSERT_OK_AND_ASSIGN(
- auto l_index,
- DynamicSliceInMinorDims(&builder, a,
- {index, builder.ConstantR0<int32>(0)}, {1, n}));
- TF_ASSERT_OK_AND_ASSIGN(
- auto dot, BatchDot(&builder, l_index, row,
- /*transpose_x=*/false, /*transpose_y=*/true));
+ auto l_index = DynamicSliceInMinorDims(
+ a, {index, xla::ConstantR0<int32>(&builder, 0)}, {1, n});
+ BatchDot(l_index, row, /*transpose_x=*/false, /*transpose_y=*/true);
ComputeAndCompareR3<float>(&builder, {{{33}}, {{292}}},
{a_data.get(), row_data.get(), index_data.get()});
diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc
index 09ce594930..574e70ddee 100644
--- a/tensorflow/compiler/tf2xla/lib/while_loop.cc
+++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -39,7 +40,7 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
xla::XlaBuilder* builder) {
std::vector<xla::XlaOp> elements(arity);
for (int i = 0; i < arity; ++i) {
- elements[i] = builder->GetTupleElement(tuple, i);
+ elements[i] = xla::GetTupleElement(tuple, i);
}
return elements;
};
@@ -48,7 +49,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
std::unique_ptr<xla::XlaBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
- auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
+ auto parameter =
+ xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter");
TF_RETURN_IF_ERROR(
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
@@ -61,7 +63,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
std::unique_ptr<xla::XlaBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
- auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
+ auto parameter =
+ xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN(
auto result,
@@ -69,11 +72,11 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaWhileLoop(
body_builder.get()));
TF_RET_CHECK(result.size() == initial_values.size());
- body_builder->Tuple(result);
+ xla::Tuple(body_builder.get(), result);
}
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
- auto outputs = builder->While(cond, body, builder->Tuple(initial_values));
+ auto outputs = xla::While(cond, body, xla::Tuple(builder, initial_values));
return unpack_tuple(outputs, arity, builder);
}
@@ -86,9 +89,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
auto while_cond_fn =
[&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* cond_builder) -> xla::StatusOr<xla::XlaOp> {
- return cond_builder->Lt(
- values[0],
- IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
+ return xla::Lt(values[0], IntegerLiteral(cond_builder, num_iterations_type,
+ num_iterations));
};
auto while_body_fn = [&](gtl::ArraySlice<xla::XlaOp> values,
xla::XlaBuilder* body_builder)
@@ -97,9 +99,10 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> updated_values;
updated_values.reserve(values.size());
- updated_values.push_back(body_builder->Add(
+ updated_values.push_back(xla::Add(
iteration,
- body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
+ xla::ConstantLiteral(body_builder,
+ xla::LiteralUtil::One(num_iterations_type))));
values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::XlaOp> body_outputs,
@@ -111,8 +114,8 @@ xla::StatusOr<std::vector<xla::XlaOp>> XlaForEachIndex(
std::vector<xla::XlaOp> values;
values.reserve(initial_values.size() + 1);
- values.push_back(
- builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
+ values.push_back(xla::ConstantLiteral(
+ builder, xla::LiteralUtil::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index 43e1c1e9fe..2fb66913ad 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -17,26 +17,39 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
- xla::Shape literal_shape;
- TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
- host_tensor.dtype(), host_tensor.shape(), &literal_shape));
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal) {
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
+ host_tensor.shape(), &xla_shape));
+ *literal = xla::BorrowingLiteral(
+ static_cast<const char*>(DMAHelper::base(&host_tensor)), xla_shape);
+ return Status::OK();
+}
- *literal = xla::Literal(literal_shape);
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal) {
+ std::vector<const char*> buf_ptrs;
+ buf_ptrs.reserve(host_tensors.size());
+ std::vector<xla::Shape> tensor_shapes(host_tensors.size());
- // memcpy over the payload ...
- // TODO(phawkins): handle string types.
- size_t total_bytes = host_tensor.TotalBytes();
- if (total_bytes > 0) {
- void* dst_ptr = literal->untyped_data();
- const void* src_ptr = DMAHelper::base(&host_tensor);
- memcpy(dst_ptr, src_ptr, total_bytes);
+ for (int i = 0; i < host_tensors.size(); i++) {
+ // Validate runtime shapes and fail if it doesn't match the contract.
+ const Tensor* tensor = &host_tensors[i];
+ buf_ptrs.emplace_back(static_cast<const char*>(DMAHelper::base(tensor)));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(),
+ &tensor_shapes[i]));
}
+
+ *literal = xla::BorrowingLiteral(
+ buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes));
+
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index 220bec1553..0610a57029 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -18,16 +18,24 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
-// Copies 'host_tensor' to an XLA Literal. Fails if host_tensor is of an
-// unsupported type.
-Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
+// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by
+// 'host_tensor'.
+Status HostTensorToBorrowingLiteral(const Tensor& host_tensor,
+ xla::BorrowingLiteral* literal);
+
+// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers
+// owned by 'host_tensors'.
+Status HostTensorsToBorrowingLiteralTuple(
+ tensorflow::gtl::ArraySlice<Tensor> host_tensors,
+ xla::BorrowingLiteral* literal);
// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
// type <target_type>.
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index f3d6787daa..a3404c2b3d 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -27,7 +28,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
{
std::vector<int64> int64_values = {1, 2, 3};
std::unique_ptr<xla::Literal> int64_values_literal =
- xla::Literal::CreateR1(gtl::ArraySlice<int64>(int64_values));
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
Tensor host_tensor;
EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
@@ -48,7 +49,7 @@ TEST(LiteralUtil, LiteralToHostTensor) {
Tensor host_tensor;
std::vector<int32> int32_values = {10, 11};
std::unique_ptr<xla::Literal> int32_values_literal =
- xla::Literal::CreateR1(gtl::ArraySlice<int32>(int32_values));
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
EXPECT_TRUE(
LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
.ok());
diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD
index bb9168fa35..ace6fd1d8e 100644
--- a/tensorflow/compiler/tf2xla/ops/BUILD
+++ b/tensorflow/compiler/tf2xla/ops/BUILD
@@ -8,12 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
cc_library(
name = "xla_ops",
- srcs = [
- "dynamic_slice_ops.cc",
- "functional_ops.cc",
- "reduce_window_op.cc",
- "sendrecv_ops.cc",
- ],
+ srcs = ["xla_ops.cc"],
deps = [
"//tensorflow/core:framework",
],
diff --git a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc
deleted file mode 100644
index d6c0edbb88..0000000000
--- a/tensorflow/compiler/tf2xla/ops/dynamic_slice_ops.cc
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-REGISTER_OP("XlaDynamicUpdateSlice")
- .Input("input: T")
- .Input("update: T")
- .Input("indices: Tindices")
- .Output("output: T")
- .Attr("T: type")
- .Attr("Tindices: {int32, int64}")
- .SetShapeFn(shape_inference::UnchangedShape)
- .Doc(R"doc(
-Wraps the XLA DynamicUpdateSlice operator, documented at
- https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
-.
-
-XlaDynamicUpdateSlice generates a result which is the value of the `input`
-operand, with a slice update overwritten at `indices`. The shape of `update`
-determines the shape of the sub-array of the result which is updated. The shape
-of indices must be rank == 1, with dimension size equal to the rank of `input`.
-
-Handling of out-of-bounds slice indices is implementation-defined.
-
-input: A `Tensor` of type T.
-indices: A vector of indices into `input`. Must have length equal to the rank of
- `input`.
-update: A `Tensor` of type T. Same rank as `input`.
-output: A `Tensor` of type T.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc
deleted file mode 100644
index 4a669f8e6e..0000000000
--- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-// TODO(b/37549631) setting the While Op to always be stateful is too
-// conservative.
-REGISTER_OP("XlaWhile")
- .Input("input: T")
- .Output("output: T")
- .Attr("T: list(type) >= 0")
- .Attr("cond: func")
- .Attr("body: func")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-output = input; While (Cond(output)) { output = Body(output) }
-
-input: A list of input tensors whose types are T.
-output: A list of output tensors whose types are T.
-cond: A function takes 'input' and returns a tensor. If the tensor is
- a scalar of non-boolean, the scalar is converted to a boolean
- according to the following rule: if the scalar is a numerical
- value, non-zero means True and zero means False; if the scalar is
- a string, non-empty means True and empty means False. If the
- tensor is not a scalar, non-emptiness means True and False
- otherwise.
-body: A function that takes a list of tensors and returns another
- list of tensors. Both lists have the same types as specified by T.
-)doc");
-
-// TODO(b/37549631) setting the If Op to always be stateful is too
-// conservative.
-REGISTER_OP("XlaIf")
- .Input("cond: Tcond")
- .Input("inputs: Tin")
- .Output("output: Tout")
- .Attr("Tcond: type")
- .Attr("then_branch: func")
- .Attr("else_branch: func")
- .Attr("Tin: list(type) >= 0")
- .Attr("Tout: list(type) >= 0")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-output = cond ? then_branch(inputs) : else_branch(inputs).
-
-cond: A boolean scalar.
-inputs: A list of input tensors.
-output: A list of tensors returned by either then_branch(inputs) or
- else_branch(inputs). The input shapes of the then_branch and
- else_branch must match.
-then_branch: A function takes 'inputs' and returns a list of tensors,
- whose types are the same as what else_branch returns.
-else_branch: A function takes 'inputs' and returns a list of tensors.
- whose types are the same as what then_branch returns.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc b/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc
deleted file mode 100644
index d9af982adc..0000000000
--- a/tensorflow/compiler/tf2xla/ops/reduce_window_op.cc
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("XlaReduceWindow")
- .Input("input: T")
- .Input("init_value: T")
- .Attr("T: numbertype")
- .Attr("computation: func")
- .Attr("window_dimensions: list(int)")
- .Attr("window_strides: list(int)")
- .Attr("padding_low: list(int)")
- .Attr("padding_high: list(int)")
- .Output("output: T")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Wraps the XLA ReduceWindow operator, documented at
- https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
-
-input: the input tensor
-init_value: a scalar representing the initial value for the reduction
-computation: a reducer function to apply
-window_dimensions: the shape of the window
-window_strides: the inter-window strides
-padding_low: the padding to apply at the start of each input dimensions
-padding_high: the padding to apply at the end of each input dimension.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc
deleted file mode 100644
index 7ec7b50e90..0000000000
--- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("XlaSend")
- .Input("tensor: T")
- .Attr("T: type")
- .Attr("tensor_name: string")
- .SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Sends the named tensor to another XLA computation. Wraps the XLA Send operator
-documented at
- https://www.tensorflow.org/performance/xla/operation_semantics#send .
-
-tensor: The tensor to send.
-tensor_name: A string key that identifies the channel.
-)doc");
-
-REGISTER_OP("XlaRecv")
- .Output("tensor: dtype")
- .Attr("dtype: type")
- .Attr("tensor_name: string")
- .Attr("shape: shape")
- .SetIsStateful()
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- TensorShape shape_attr;
- TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
- shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
- c->set_output(0, s);
- return Status::OK();
- })
- .Doc(R"doc(
-Receives the named tensor from another XLA computation. Wraps the XLA Recv
-operator documented at
- https://www.tensorflow.org/performance/xla/operation_semantics#recv .
-
-tensor: The tensor to receive.
-dtype: The type of the tensor.
-tensor_name: A string key that identifies the channel.
-shape: The shape of the tensor.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
new file mode 100644
index 0000000000..a59c77f5c3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("XlaDynamicUpdateSlice")
+ .Input("input: T")
+ .Input("update: T")
+ .Input("indices: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA DynamicUpdateSlice operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
+.
+
+XlaDynamicUpdateSlice generates a result which is the value of the `input`
+operand, with a slice update overwritten at `indices`. The shape of `update`
+determines the shape of the sub-array of the result which is updated. The shape
+of indices must be rank == 1, with dimension size equal to the rank of `input`.
+
+Handling of out-of-bounds slice indices is implementation-defined.
+
+input: A `Tensor` of type T.
+indices: A vector of indices into `input`. Must have length equal to the rank of
+ `input`.
+update: A `Tensor` of type T. Same rank as `input`.
+output: A `Tensor` of type T.
+)doc");
+
+// TODO(b/37549631) setting the If Op to always be stateful is too
+// conservative.
+REGISTER_OP("XlaIf")
+ .Input("cond: Tcond")
+ .Input("inputs: Tin")
+ .Output("output: Tout")
+ .Attr("Tcond: type")
+ .Attr("then_branch: func")
+ .Attr("else_branch: func")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+output = cond ? then_branch(inputs) : else_branch(inputs).
+
+cond: A boolean scalar.
+inputs: A list of input tensors.
+output: A list of tensors returned by either then_branch(inputs) or
+ else_branch(inputs). The input shapes of the then_branch and
+ else_branch must match.
+then_branch: A function takes 'inputs' and returns a list of tensors,
+ whose types are the same as what else_branch returns.
+else_branch: A function takes 'inputs' and returns a list of tensors.
+ whose types are the same as what then_branch returns.
+)doc");
+
+REGISTER_OP("XlaRecv")
+ .Output("tensor: dtype")
+ .Attr("dtype: type")
+ .Attr("tensor_name: string")
+ .Attr("shape: shape")
+ .SetIsStateful()
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ TensorShape shape_attr;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Receives the named tensor from another XLA computation. Wraps the XLA Recv
+operator documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#recv .
+
+tensor: The tensor to receive.
+dtype: The type of the tensor.
+tensor_name: A string key that identifies the channel.
+shape: The shape of the tensor.
+)doc");
+
+REGISTER_OP("XlaReduceWindow")
+ .Input("input: T")
+ .Input("init_value: T")
+ .Attr("T: numbertype")
+ .Attr("computation: func")
+ .Attr("window_dimensions: list(int)")
+ .Attr("window_strides: list(int)")
+ .Attr("padding_low: list(int)")
+ .Attr("padding_high: list(int)")
+ .Output("output: T")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA ReduceWindow operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .
+
+input: the input tensor
+init_value: a scalar representing the initial value for the reduction
+computation: a reducer function to apply
+window_dimensions: the shape of the window
+window_strides: the inter-window strides
+padding_low: the padding to apply at the start of each input dimensions
+padding_high: the padding to apply at the end of each input dimension.
+)doc");
+
+REGISTER_OP("XlaSend")
+ .Input("tensor: T")
+ .Attr("T: type")
+ .Attr("tensor_name: string")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Sends the named tensor to another XLA computation. Wraps the XLA Send operator
+documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#send .
+
+tensor: The tensor to send.
+tensor_name: A string key that identifies the channel.
+)doc");
+
+REGISTER_OP("XlaSort")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA Sort operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#sort
+.
+
+Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
+
+input: A `Tensor` of type T.
+output: A `Tensor` of type T.
+)doc");
+
+// TODO(b/37549631) setting the While Op to always be stateful is too
+// conservative.
+REGISTER_OP("XlaWhile")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: list(type) >= 0")
+ .Attr("cond: func")
+ .Attr("body: func")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+output = input; While (Cond(output)) { output = Body(output) }
+
+input: A list of input tensors whose types are T.
+output: A list of output tensors whose types are T.
+cond: A function takes 'input' and returns a tensor. If the tensor is
+ a scalar of non-boolean, the scalar is converted to a boolean
+ according to the following rule: if the scalar is a numerical
+ value, non-zero means True and zero means False; if the scalar is
+ a string, non-empty means True and empty means False. If the
+ tensor is not a scalar, non-emptiness means True and False
+ otherwise.
+body: A function that takes a list of tensors and returns another
+ list of tensors. Both lists have the same types as specified by T.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index e5ce65bec9..2fc47dffb8 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -77,4 +77,6 @@ def reduce_window(operand,
recv = gen_xla_ops.xla_recv
send = gen_xla_ops.xla_send
+sort = gen_xla_ops.xla_sort
+
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 84c133ffab..f0b30dcf4e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -73,8 +74,8 @@ TEST(ConvertGraphDefToXla, Sum) {
TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation));
// Set up arguments.
- auto x_literal = xla::Literal::CreateR0<int32>(10);
- auto y_literal = xla::Literal::CreateR0<int32>(32);
+ auto x_literal = xla::LiteralUtil::CreateR0<int32>(10);
+ auto y_literal = xla::LiteralUtil::CreateR0<int32>(32);
auto x_global_or = client->TransferToServer(*x_literal);
auto y_global_or = client->TransferToServer(*y_literal);
TF_EXPECT_OK(x_global_or.status());
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index a8bd199675..319cbc74e9 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -230,10 +231,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
- TensorShape shape =
- is_entry_computation
- ? options_.shape_representation_fn(arg.shape, arg.type)
- : arg.shape;
+ TensorShape shape;
+ if (is_entry_computation) {
+ TF_ASSIGN_OR_RETURN(
+ shape, options_.shape_representation_fn(arg.shape, arg.type));
+ } else {
+ shape = arg.shape;
+ }
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: {
@@ -241,8 +245,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
switch (arg.resource_kind) {
case XlaResource::kVariable: {
- TensorShape representation_shape =
- options_.shape_representation_fn(arg.shape, arg.type);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
@@ -338,9 +343,9 @@ Status BuildComputation(
const std::vector<int>& arg_cores,
const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
- bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
- xla::XlaComputation* computation, int* num_computation_outputs,
- int* num_nonconst_outputs,
+ bool return_updated_values_for_all_resources, bool always_return_tuple,
+ xla::XlaBuilder* builder, xla::XlaComputation* computation,
+ int* num_computation_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
@@ -384,13 +389,14 @@ Status BuildComputation(
const XlaCompiler::Argument& arg = args[resource->arg_num()];
const int core = arg_cores[resource->arg_num()];
DCHECK_LT(resource->arg_num(), arg_cores.size());
- bool modified = resource->value() != resource->initial_value();
+ bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
// TensorArray gradients were modified if their values changed or there are
// any newly created gradients.
for (const auto& grad : resource->tensor_array_gradients()) {
- modified = modified ||
- grad.second->value() != grad.second->initial_value() ||
- arg.tensor_array_gradients.count(grad.first) == 0;
+ modified =
+ modified ||
+ !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
+ arg.tensor_array_gradients.count(grad.first) == 0;
}
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
@@ -415,7 +421,7 @@ Status BuildComputation(
// create a tuple/get-tuple-element combination so that sharding
// assignment will be placed on this value, which will cause the resource
// update to be returned from the same device that provided the resource.
- handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
+ handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
elems.push_back(handle);
}
@@ -424,7 +430,9 @@ Status BuildComputation(
*num_computation_outputs = elems.size();
// Builds the XLA computation.
- builder->Tuple(elems);
+ if (always_return_tuple || elems.size() != 1) {
+ xla::Tuple(builder, elems);
+ }
builder->ClearOpMetadata();
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
@@ -551,16 +559,16 @@ Status XlaCompiler::BuildArguments(
}
xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
tuple_sharding);
- tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
+ tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
} else {
- tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
+ tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
- arg_handles[i] = builder->GetTupleElement(tuple, i);
+ arg_handles[i] = xla::GetTupleElement(tuple, i);
}
} else {
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
@@ -568,8 +576,8 @@ Status XlaCompiler::BuildArguments(
xla::XlaScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
- arg_handles[i] =
- builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i));
+ arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i],
+ strings::StrCat("arg", i));
}
}
@@ -600,7 +608,7 @@ Status XlaCompiler::BuildArguments(
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
- builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
+ xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
} else {
arg_expression.set_handle(arg_handles[i]);
}
@@ -652,6 +660,7 @@ Status XlaCompiler::CompileSingleOp(
.Finalize(graph.get(), &node);
TF_RETURN_IF_ERROR(status);
}
+ FixupSourceAndSinkEdges(graph.get());
return CompileGraph(options, name, std::move(graph), args, result);
}
@@ -659,20 +668,17 @@ Status XlaCompiler::CompileSingleOp(
namespace {
// Check that the ops of all non-functional nodes have been registered.
-string ValidateFunctionDef(const FunctionDef* fdef,
+Status ValidateFunctionDef(const FunctionDef* fdef,
const FunctionLibraryDefinition& flib_def) {
- std::vector<string> invalid_ops;
for (const NodeDef& node : fdef->node_def()) {
const string& op = node.op();
if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) {
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) {
- invalid_ops.push_back(op);
- }
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(op, &op_def));
}
- return tensorflow::str_util::Join(invalid_ops, ", ");
+ return Status::OK();
}
// Check that the graph doesn't have any invalid nodes (e.g. incompatible with
@@ -680,35 +686,33 @@ string ValidateFunctionDef(const FunctionDef* fdef,
Status ValidateGraph(const Graph* graph,
const FunctionLibraryDefinition& flib_def,
const DeviceType& device_type, const string& name) {
- std::vector<string> invalid_ops;
+ auto maybe_error = [&](const string& op, const Status& s) -> Status {
+ if (!s.ok()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Detected unsupported operations when trying to compile graph ", name,
+ " on ", device_type.type_string(), ": ", op, " (", s.error_message(),
+ ")"));
+ }
+ return Status::OK();
+ };
+
for (const Node* node : graph->nodes()) {
if (node->type_string() == FunctionLibraryDefinition::kGradientOp) {
continue;
}
const FunctionDef* fdef = flib_def.Find(node->def().op());
+ Status s;
if (fdef) {
- string error_msg = ValidateFunctionDef(fdef, flib_def);
- if (!error_msg.empty()) {
- invalid_ops.push_back(
- strings::StrCat(node->def().op(), ":{", error_msg, "}"));
- }
+ s = ValidateFunctionDef(fdef, flib_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
continue;
}
const OpDef* op_def;
- if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) {
- invalid_ops.push_back(node->def().op());
- continue;
- }
+ s = OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def));
- if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) {
- invalid_ops.push_back(node->def().op());
- }
- }
- if (!invalid_ops.empty()) {
- return errors::InvalidArgument(strings::StrCat(
- "Detected unsupported operations when trying to compile graph ", name,
- " on ", device_type.type_string(), ":",
- tensorflow::str_util::Join(invalid_ops, ", ")));
+ s = FindKernelDef(device_type, node->def(), nullptr, nullptr);
+ TF_RETURN_IF_ERROR(maybe_error(node->def().op(), s));
}
return Status::OK();
}
@@ -766,9 +770,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(),
- options.return_updated_values_for_all_resources, &builder,
- result->computation.get(), &num_computation_outputs,
- &num_nonconst_outputs, &result->outputs, &result->resource_updates));
+ options.return_updated_values_for_all_resources,
+ options.always_return_tuple, &builder, result->computation.get(),
+ &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
+ &result->resource_updates));
VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index c93850ce27..079c99797e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -52,13 +53,7 @@ class XlaContext;
// (kind kResource).
//
// Only kParameter and initialized kResource arguments become runtime parameters
-// to the generated XLA computation. The XLA computation will have run-time
-// parameters in the following order:
-// +---------------------+-----------------------------------------+
-// | kParameter values | Initial values of kResource arguments |
-// +---------------------+-----------------------------------------+
-// Within each block, the arguments are arranged by the _Arg index from which
-// they were derived.
+// to the generated XLA computation.
//
// The run-time outputs of the XLA computation are arranged in the following
// order:
@@ -77,10 +72,10 @@ class XlaContext;
// tensors with a different shape to their representation inside the XLA
// computation.
//
-// In both inputs and outputs, kResource values are placed the end. When
+// In computation outputs, updated kResource values are placed the end. When
// emitting While loop bodies, we must ensure that the loop body has
-// identical input and output signatures. By moving variable values
-// to the end of the argument list and using the
+// identical input and output signatures. By passing variable values
+// at the end of the argument list and using the
// `return_updated_values_for_all_variables` option, we can ensure that the
// input and output values of resources appear at the same positions.
//
@@ -175,6 +170,11 @@ class XlaCompiler {
// computation.
bool resolve_compile_time_constants = true;
+ // If 'always_return_tuple' is true, then the output of a computation will
+ // always be a tuple. Otherwise, a single-element output will not be wrapped
+ // in a tuple.
+ bool always_return_tuple = true;
+
// True when compiling the entry computation, false for subcomputations
// (while, call, etc.)
bool is_entry_computation = true;
@@ -234,7 +234,8 @@ class XlaCompiler {
tf2xla::HostComputeMetadata host_compute_metadata;
// Resources whose values were updated by the computation, ordered
- // by return value position. Resource updates follow the non-constant
+ // by return value position (which is the same as the order the resources
+ // were passed as arguments). Resource updates follow the non-constant
// results in the outputs of XLA computation.
std::vector<ResourceUpdate> resource_updates;
@@ -242,7 +243,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
- typedef std::function<TensorShape(const TensorShape&, DataType)>
+ typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&,
+ DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. It must be set by the caller.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 5fbf4b952c..6f76816a86 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -205,9 +206,9 @@ TEST_F(XlaCompilerTest, Simple) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -221,9 +222,9 @@ TEST_F(XlaCompilerTest, Simple) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -305,7 +306,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -316,9 +317,9 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get()});
EXPECT_TRUE(
xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -340,7 +341,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -350,11 +351,12 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected0 = xla::Literal::CreateR0<int32>(7);
+ std::unique_ptr<xla::Literal> expected0 =
+ xla::LiteralUtil::CreateR0<int32>(7);
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({-7, -42});
+ xla::LiteralUtil::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
}
}
@@ -568,11 +570,11 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> input_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> input_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::Literal> input =
- xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
+ xla::LiteralUtil::MakeTuple({input_base.get(), input_grad2.get()});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*input).ConsumeValueOrDie();
@@ -582,17 +584,18 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
+ std::unique_ptr<xla::Literal> output_read =
+ xla::LiteralUtil::CreateR0<int32>(42);
std::unique_ptr<xla::Literal> output_base =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> output_grad1 =
- xla::Literal::CreateR1<int32>({0, 1});
+ xla::LiteralUtil::CreateR1<int32>({0, 1});
std::unique_ptr<xla::Literal> output_grad2 =
- xla::Literal::CreateR1<int32>({-3, 101});
- std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> output_resource = xla::LiteralUtil::MakeTuple(
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
+ xla::LiteralUtil::MakeTuple({output_read.get(), output_resource.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -795,9 +798,9 @@ TEST_F(XlaCompilerTest, Variables) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({7, 42});
+ xla::LiteralUtil::CreateR1<int32>({7, 42});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({-3, 101});
+ xla::LiteralUtil::CreateR1<int32>({-3, 101});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -811,11 +814,11 @@ TEST_F(XlaCompilerTest, Variables) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({5, 144});
+ xla::LiteralUtil::CreateR1<int32>({5, 144});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({4, 143});
+ xla::LiteralUtil::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -883,9 +886,9 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
+ xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -899,11 +902,11 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
+ xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -952,9 +955,9 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({4, 55, 1, -3});
+ xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
@@ -968,11 +971,11 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({27, 67, 35, 402});
+ xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
@@ -1020,8 +1023,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
std::move(graph), args, &result);
ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}"))
+ EXPECT_TRUE(str_util::StrContains(status.error_message(), "InvalidOp"))
<< status.error_message();
}
@@ -1049,5 +1051,42 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
<< status.error_message();
}
+TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ NodeDef no_op;
+ no_op.set_name("NoOp");
+ no_op.set_op("NoOp");
+ Status status;
+ graph->AddNode(no_op, &status);
+ TF_ASSERT_OK(status);
+
+ std::vector<XlaCompiler::Argument> args;
+ XlaCompiler compiler(DefaultOptions());
+ // No control edge linking NoOp with source/sink.
+ {
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ XlaCompiler::CompilationResult result;
+ status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
+ std::move(graph_copy), args, &result);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "The following nodes are unreachable "
+ "from the source in the graph: NoOp"))
+ << status.error_message();
+ }
+
+ // Fix control edges for NoOp.
+ {
+ std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
+ CopyGraph(*graph, graph_copy.get());
+ EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
+ std::move(graph_copy), args, &result));
+ EXPECT_EQ(0, result.resource_updates.size());
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 098072d33c..0dea366476 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -66,8 +66,8 @@ XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn)
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
@@ -92,7 +92,7 @@ void XlaContext::AddRetval(int retval_index, DataType type,
}
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
- const xla::Literal& literal) {
+ const xla::LiteralSlice& literal) {
VLOG(1) << "Adding retval index " << retval_index
<< " with non-data-dependent tensor to XLA computation";
if (retvals_.size() <= retval_index) {
@@ -119,8 +119,8 @@ Status XlaContext::CreateResource(
return Status::OK();
}
-TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
- DataType type) const {
+xla::StatusOr<TensorShape> XlaContext::RepresentationShape(
+ const TensorShape& shape, DataType type) const {
return (*shape_representation_fn_)(shape, type);
}
@@ -131,9 +131,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
xla::XlaBuilder b("max<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Max(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Max(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -145,9 +147,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) {
xla::XlaBuilder b("min<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Min(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Min(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -159,9 +163,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) {
xla::XlaBuilder b("add<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Add(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Add(x, y);
return b.Build().ConsumeValueOrDie();
});
}
@@ -173,9 +179,11 @@ const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) {
xla::XlaBuilder b("mul<" + type_string + ">");
xla::PrimitiveType xla_type;
TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
- auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
- auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
- b.Mul(x, y);
+ auto x =
+ xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y =
+ xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ xla::Mul(x, y);
return b.Build().ConsumeValueOrDie();
});
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 341bf6ff1f..38d8cd653c 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -47,8 +48,8 @@ class XlaContext : public ResourceBase {
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn);
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -83,7 +84,7 @@ class XlaContext : public ResourceBase {
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
- const xla::Literal& literal);
+ const xla::LiteralSlice& literal);
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
@@ -101,8 +102,8 @@ class XlaContext : public ResourceBase {
// Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`, or of an argument or return value of a top-level computation.
- TensorShape RepresentationShape(const TensorShape& shape,
- DataType type) const;
+ xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape,
+ DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -160,7 +161,7 @@ class XlaContext : public ResourceBase {
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
- const std::function<TensorShape(const TensorShape&, DataType)>*
+ const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>*
shape_representation_fn_;
// Cache of prebuilt computations indexed by their type.
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index ead229aacc..23d04d43b3 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -31,6 +31,10 @@ bool CpuOpFilter(KernelDef* kdef) {
DT_FLOAT);
return true;
}
+ // TODO(b/26783907): The CPU backend currently does not implement sort.
+ if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
+ return false;
+ }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
index 62168b6483..dc98d4fda6 100644
--- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/legacy_flags/backend_registration_flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
@@ -22,8 +23,16 @@ namespace tensorflow {
bool GpuOpFilter(KernelDef* kdef) {
// TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to
// slow code.
- if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
- kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") {
+ legacy_flags::BackendRegistrationFlags* flags =
+ legacy_flags::GetBackendRegistrationFlags();
+ VLOG(2) << "flags->tf_enable_prng_ops_gpu: " << flags->tf_enable_prng_ops_gpu;
+ if (!flags->tf_enable_prng_ops_gpu &&
+ (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" ||
+ kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal")) {
+ return false;
+ }
+ // TODO(b/26783907): The GPU backend currently does not implement sort.
+ if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
return false;
}
if (kdef->op() == "Const") {
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index f1594193af..4d1b3b1a13 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -19,9 +19,13 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -32,103 +36,71 @@ namespace tensorflow {
namespace {
-Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- bool is_min, xla::XlaOp* argminmax) {
- xla::XlaOp init_value;
- const xla::XlaComputation* reducer;
- if (is_min) {
- init_value = XlaHelpers::MaxValue(builder, input_type);
- reducer = ctx->GetOrCreateMin(input_type);
- } else {
- init_value = XlaHelpers::MinValue(builder, input_type);
- reducer = ctx->GetOrCreateMax(input_type);
- }
-
- xla::PrimitiveType xla_output_type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type));
-
- xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer,
- /*dimensions_to_reduce=*/{axis});
- std::vector<int64> broadcast_dims(input_shape.dims() - 1);
- std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
- std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- // Compute a mask that has 1s for elements equal to the maximum.
- xla::XlaOp partial_mask = builder->ConvertElementType(
- builder->Eq(input, input_max, broadcast_dims), xla_output_type);
-
- // In order to make identity elements for a bitwise And, we:
- // Left shift the 1 to the leftmost bit, yielding 0x10...0
- // Arithmetic right shift the 1 back to the rightmost bit, yielding
- // 0xFF...F
- int32 bits_in_type =
- xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1;
- xla::XlaOp shift_amount =
- XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type);
- xla::XlaOp full_mask = builder->ShiftRightArithmetic(
- builder->ShiftLeft(partial_mask, shift_amount), shift_amount);
-
- // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
- // index.
- xla::XlaOp iota;
-
- const int64 axis_size = input_shape.dim_size(axis);
- TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota));
- xla::XlaOp product =
- builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis});
-
- // If there are multiple maximum elements, choose the one with the highest
- // index.
- xla::XlaOp output =
- builder->Reduce(product, XlaHelpers::MinValue(builder, output_type),
- *ctx->GetOrCreateMax(output_type),
- /*dimensions_to_reduce=*/{axis});
- *argminmax = output;
- return Status::OK();
+xla::XlaOp ArgMinMax(xla::XlaOp input, xla::PrimitiveType output_type, int axis,
+ bool is_min) {
+ xla::XlaBuilder* builder = input.builder();
+ return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ xla::XlaOp init_value;
+ xla::XlaComputation reducer;
+ if (is_min) {
+ init_value = xla::MaxValue(builder, input_shape.element_type());
+ reducer =
+ xla::CreateScalarMinComputation(input_shape.element_type(), builder);
+ } else {
+ init_value = xla::MinValue(builder, input_shape.element_type());
+ reducer =
+ xla::CreateScalarMaxComputation(input_shape.element_type(), builder);
+ }
+
+ xla::XlaOp input_max = xla::Reduce(input, init_value, reducer,
+ /*dimensions_to_reduce=*/{axis});
+ std::vector<int64> broadcast_dims(xla::ShapeUtil::Rank(input_shape) - 1);
+ std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
+ std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
+ // Compute a mask that has 1s for elements equal to the maximum.
+ xla::XlaOp partial_mask = xla::ConvertElementType(
+ xla::Eq(input, input_max, broadcast_dims), output_type);
+
+ // In order to make identity elements for a bitwise And, we:
+ // Left shift the 1 to the leftmost bit, yielding 0x10...0
+ // Arithmetic right shift the 1 back to the rightmost bit, yielding
+ // 0xFF...F
+ int32 bits_in_type =
+ xla::ShapeUtil::ByteSizeOfPrimitiveType(output_type) * 8 - 1;
+ xla::XlaOp shift_amount =
+ xla::ConstantR0WithType(builder, output_type, bits_in_type);
+ xla::XlaOp full_mask = xla::ShiftRightArithmetic(
+ xla::ShiftLeft(partial_mask, shift_amount), shift_amount);
+
+ // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its
+ // index.
+
+ const int64 axis_size = xla::ShapeUtil::GetDimension(input_shape, axis);
+ xla::XlaOp iota = xla::Iota(builder, output_type, axis_size);
+ xla::XlaOp product =
+ xla::And(full_mask, iota, /*broadcast_dimensions=*/{axis});
+
+ // If there are multiple maximum elements, choose the one with the highest
+ // index.
+ return xla::Reduce(product, xla::MinValue(builder, output_type),
+ xla::CreateScalarMaxComputation(output_type, builder),
+ /*dimensions_to_reduce=*/{axis});
+ });
}
} // namespace
-xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::MinValue(type));
-}
-
-xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) {
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::MaxValue(type));
-}
-
xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::Zero(type));
+ return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}
xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
- return b->ConstantLiteral(xla::Literal::One(type));
-}
-
-xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) {
- switch (data_type) {
- case DT_HALF:
- return b->ConstantR0<Eigen::half>(
- static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
- case DT_BFLOAT16:
- return b->ConstantR0<bfloat16>(bfloat16::epsilon());
- case DT_FLOAT:
- return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
- case DT_DOUBLE:
- return b->ConstantR0<double>(std::numeric_limits<double>::epsilon());
- default:
- LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: "
- << DataTypeString(data_type);
- }
+ return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}
xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
@@ -176,44 +148,14 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) {
return linspace;
}
-Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, xla::XlaOp* argmax) {
- return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
- axis, /*is_min=*/false, argmax);
-}
-
-Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input,
- const TensorShape& input_shape, DataType input_type,
- DataType output_type, int axis, xla::XlaOp* argmin) {
- return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type,
- axis, /*is_min=*/true, argmin);
+xla::XlaOp XlaHelpers::ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis) {
+ return ArgMinMax(input, output_type, axis, /*is_min=*/false);
}
-Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
- xla::XlaOp* iota) {
- TensorShape linspace_shape({size});
- Tensor linspace;
- switch (dtype) {
- case DT_UINT8:
- linspace = MakeLinspaceTensor<uint8>(linspace_shape, size);
- break;
- case DT_INT32:
- linspace = MakeLinspaceTensor<int32>(linspace_shape, size);
- break;
- case DT_INT64:
- linspace = MakeLinspaceTensor<int64>(linspace_shape, size);
- break;
- default:
- return errors::InvalidArgument("Invalid argument type ",
- DataTypeString(dtype));
- }
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
- *iota = builder->ConstantLiteral(linspace_literal);
- return Status::OK();
+xla::XlaOp XlaHelpers::ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis) {
+ return ArgMinMax(input, output_type, axis, /*is_min=*/true);
}
Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
@@ -245,25 +187,28 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
return errors::InvalidArgument("Invalid argument type ",
DataTypeString(index_type));
}
- xla::Literal linspace_literal;
- TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal));
+
+ xla::BorrowingLiteral linspace_literal;
+ TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal));
// Broadcast the linspace constant across the indices along the new axis,
// and test equality at each position.
std::vector<int64> broadcast_dims(indices_shape.dims());
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
- xla::XlaOp one_hot_bool = builder->Eq(
- indices, builder->ConstantLiteral(linspace_literal), broadcast_dims);
+ xla::XlaOp one_hot_bool = xla::Eq(
+ indices, xla::ConstantLiteral(builder, linspace_literal), broadcast_dims);
// Selects the user-provided off_value and on_value values.
- *one_hot = builder->Select(
- one_hot_bool, builder->Broadcast(on_value, output_shape.dim_sizes()),
- builder->Broadcast(off_value, output_shape.dim_sizes()));
+ *one_hot = xla::Select(one_hot_bool,
+ xla::Broadcast(on_value, output_shape.dim_sizes()),
+ xla::Broadcast(off_value, output_shape.dim_sizes()));
return Status::OK();
}
DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
+ // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
+ // repeated floating point additions.
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return DT_FLOAT;
}
@@ -275,7 +220,7 @@ xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder,
const DataType new_element_type) {
xla::PrimitiveType convert_to;
TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
- return builder->ConvertElementType(operand, convert_to);
+ return xla::ConvertElementType(operand, convert_to);
}
} // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index c3fdc5252e..d6ca4ab934 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -28,14 +28,6 @@ namespace tensorflow {
// Helper methods for building XLA computations.
class XlaHelpers {
public:
- // Returns a handle representing the minimum value of a scalar
- // element of data_type.
- static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type);
-
- // Returns a handle representing the maximum value of a scalar
- // element of data_type.
- static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the zero value of a scalar
// element of data_type.
static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type);
@@ -44,10 +36,6 @@ class XlaHelpers {
// element of data_type.
static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type);
- // Returns the machine epsilon for floating-point type `data_type`, i.e.,
- // the difference between 1.0 and the next representable value.
- static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type);
-
// Returns a handle representing the given value of an integer scalar
// element of data_type.
// Note that unlike One and Zero, does not work on boolean types.
@@ -65,25 +53,15 @@ class XlaHelpers {
gtl::ArraySlice<int64> shape,
xla::Literal* output);
- // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmax`.
- static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmax);
-
- // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and
- // `input_dtype` are the shape and dtype of `input` respectively, and
- // `output_type` is the dtype to use for `argmin`.
- static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx,
- const xla::XlaOp& input, const TensorShape& input_shape,
- DataType input_type, DataType output_type, int axis,
- xla::XlaOp* argmin);
-
- // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`.
- static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size,
- xla::XlaOp* iota);
+ // Returns the argmax of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMax(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
+
+ // Returns the argmin of `input` along `axis`. `output_type` is the type to
+ // use for the output.
+ static xla::XlaOp ArgMin(xla::XlaOp input, xla::PrimitiveType output_type,
+ int axis);
// Converts `indices` into a one-hot representation. `depth` is the size
// of the new axis to add. `axis` is the position at which to add the new
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 76c68d81af..e8eafb3819 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -19,7 +19,11 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
@@ -38,8 +42,7 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
- CHECK(expression->handle().builder() != nullptr ||
- expression->resource() != nullptr);
+ CHECK(expression->handle().valid() || expression->resource() != nullptr);
VLOG(1) << "Fetched T" << expression->handle();
return expression;
}
@@ -48,7 +51,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
- CHECK_EQ(expression->handle().builder(), nullptr);
+ CHECK(!expression->handle().valid());
return const_cast<XlaExpression*>(expression);
}
@@ -63,10 +66,32 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) {
return GetComputationFromTensor(context_->input(index));
}
+const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) {
+ return GetComputationFromTensor(GetInputTensorByName(name));
+}
+
TensorShape XlaOpKernelContext::InputShape(int index) {
return context_->input(index).shape();
}
+TensorShape XlaOpKernelContext::InputShape(StringPiece name) {
+ return GetInputTensorByName(name).shape();
+}
+
+DataType XlaOpKernelContext::input_type(int index) const {
+ return context_->input(index).dtype();
+}
+
+xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
+ xla::PrimitiveType type;
+ Status status = DataTypeToPrimitiveType(input_type(index), &type);
+ if (!status.ok()) {
+ SetStatus(status);
+ return xla::PRIMITIVE_TYPE_INVALID;
+ }
+ return type;
+}
+
Status XlaOpKernelContext::ConstantInput(int index,
xla::Literal* constant_literal) {
return ConstantInputReshaped(
@@ -87,6 +112,25 @@ Status XlaOpKernelContext::ConstantInputReshaped(
}
const XlaExpression* expression = CastExpressionFromTensor(tensor);
+ auto copy_tensor_to_literal = [](const Tensor& tensor,
+ xla::Literal* literal) {
+ xla::Shape literal_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), &literal_shape));
+
+ *literal = xla::Literal(literal_shape);
+
+ // memcpy over the payload ...
+ // TODO(phawkins): handle string types.
+ size_t total_bytes = tensor.TotalBytes();
+ if (total_bytes > 0) {
+ void* dst_ptr = literal->untyped_data();
+ const void* src_ptr = DMAHelper::base(&tensor);
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ }
+ return Status::OK();
+ };
+
// If the tensor has a known constant value, there is no need to invoke XLA.
if (expression->has_constant_value()) {
Tensor temp(tensor.dtype());
@@ -95,19 +139,21 @@ Status XlaOpKernelContext::ConstantInputReshaped(
// with the enclosing Tensor.
return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
}
- return HostTensorToLiteral(temp, constant_literal);
+
+ return copy_tensor_to_literal(temp, constant_literal);
}
// Make sure we treat zero-element tensors as constant.
if (new_shape.num_elements() == 0) {
Tensor temp(tensor.dtype(), new_shape);
- return HostTensorToLiteral(temp, constant_literal);
+
+ return copy_tensor_to_literal(temp, constant_literal);
}
xla::XlaOp handle = expression->handle();
if (new_shape != tensor.shape()) {
// Reshape the handle to the desired shape.
- handle = builder()->Reshape(handle, new_shape.dim_sizes());
+ handle = xla::Reshape(handle, new_shape.dim_sizes());
}
// The XLA layout is specified minor to major, and TensorFlow's minor
@@ -162,7 +208,8 @@ Status XlaOpKernelContext::ConstantInputReshaped(
}
// Converts an int32 or int64 scalar literal to an int64.
-static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
+static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
+ int64* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
return errors::InvalidArgument("value is not a scalar");
}
@@ -177,7 +224,8 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) {
}
// Converts an float32 or float64 scalar literal to a float64.
-static Status LiteralToFloat64Scalar(const xla::Literal& literal, double* out) {
+static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
+ double* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 0) {
return errors::InvalidArgument("value is not a scalar");
}
@@ -204,7 +252,7 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) {
}
// Converts an int32 or int64 1D literal to an int64 vector.
-static Status LiteralToInt64Vector(const xla::Literal& literal,
+static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
std::vector<int64>* out) {
if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
return errors::InvalidArgument("value is not 1D");
@@ -292,10 +340,11 @@ Status XlaOpKernelContext::ConstantInputList(
return Status::OK();
}
-Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
- TensorShape* shape,
- xla::XlaOp* value) {
- const Tensor& tensor = context_->input(index);
+namespace {
+
+Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, TensorShape* shape,
+ xla::XlaOp* value) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
@@ -313,18 +362,34 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
*shape = variable->shape();
}
- XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(variable->shape(), variable->type());
+ XlaContext& xla_context = XlaContext::Get(ctx);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
- *value =
- builder()->Reshape(variable->value(), variable->shape().dim_sizes());
+ *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
}
return Status::OK();
}
+} // namespace
+
+Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(context_->input(index), type, context_, shape,
+ value);
+}
+
+Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type,
+ TensorShape* shape,
+ xla::XlaOp* value) {
+ return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
+ shape, value);
+}
+
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
@@ -368,10 +433,11 @@ void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
const TensorShape& shape = constant.shape();
- xla::Literal literal;
- OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
- xla::XlaOp handle = builder()->ConstantLiteral(literal);
- CHECK_NE(handle.builder(), nullptr);
+ xla::BorrowingLiteral literal;
+ OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
+
+ xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
+ CHECK(handle.valid());
// Make the Tensor that will refer to the expression.
Tensor* output = nullptr;
@@ -414,17 +480,17 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
return Status::OK();
}
-Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
- xla::XlaOp handle) {
- TF_RET_CHECK(handle.builder() != nullptr);
+namespace {
- const XlaExpression* expression =
- CastExpressionFromTensor(context_->input(input_index));
+Status AssignVariableTensor(const Tensor& tensor, DataType type,
+ const OpKernelContext* ctx, xla::XlaOp handle,
+ xla::XlaBuilder* builder) {
+ const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- auto shape_or_status = builder()->GetShape(handle);
+ auto shape_or_status = builder->GetShape(handle);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -434,15 +500,31 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
- XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(shape, type);
+ XlaContext& xla_context = XlaContext::Get(ctx);
+ TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
+ xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
- handle = builder()->Reshape(handle, representation_shape.dim_sizes());
+ handle = xla::Reshape(handle, representation_shape.dim_sizes());
}
return variable->SetValue(handle);
}
+} // namespace
+
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(context_->input(input_index), type, context_,
+ handle, builder());
+}
+
+Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type,
+ xla::XlaOp handle) {
+ TF_RET_CHECK(handle.valid());
+ return AssignVariableTensor(GetInputTensorByName(name), type, context_,
+ handle, builder());
+}
+
XlaCompiler* XlaOpKernelContext::compiler() const {
return XlaContext::Get(context_).compiler();
}
@@ -482,6 +564,12 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
return XlaContext::Get(context_).GetOrCreateMul(type);
}
+const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) {
+ const Tensor* tensor;
+ CHECK(context_->input(name, &tensor).ok());
+ return *tensor;
+}
+
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
void XlaOpKernel::Compute(OpKernelContext* context) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 667dc262ca..6203cffd80 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -66,16 +67,26 @@ class XlaOpKernelContext {
// Returns the number of inputs to the operator.
int num_inputs() const { return context_->num_inputs(); }
- // Returns the type of input 'index'.
- DataType input_type(int index) { return context_->input(index).dtype(); }
+ // Returns the type of input `index`.
+ DataType input_type(int index) const;
- // Returns the shape of input 'index'.
+ // Returns the type of input `index` as an xla::PrimitiveType. If the type
+ // is not representable as an XLA type, sets an error status and returns
+ // xla::PRIMITIVE_TYPE_INVALID.
+ xla::PrimitiveType input_xla_type(int index);
+
+ // Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns input 'index' as a XlaOp. Unlike
+ // Returns the shape of input `name`.
+ TensorShape InputShape(StringPiece name);
+
+ // Returns input `index` as a XlaOp. Unlike
// OpKernelContext::Input returns a symbolic value rather than a concrete
// Tensor.
const xla::XlaOp& Input(int index);
+ // Returns input `name` as a XlaOp.
+ const xla::XlaOp& Input(StringPiece name);
// Returns true if all inputs are the same shape, otherwise sets the
// status to a non-OK value and returns false.
@@ -90,13 +101,13 @@ class XlaOpKernelContext {
// Helper methods for constant inputs.
- // Evaluates input 'index' and stores it in '*constant_literal'. If the
+ // Evaluates input `index` and stores it in `*constant_literal`. If the
// expression cannot be evaluated, e.g., because it depends on unbound
// parameters, returns a non-OK status.
Status ConstantInput(int index, xla::Literal* constant_literal);
- // Evaluates input 'index', reshapes it to 'new_shape' if new_shape !=
- // InputShape(index), and stores it in '*constant_literal'. If the input
+ // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
+ // InputShape(index), and stores it in `*constant_literal`. If the input
// cannot be evaluated, e.g., because it depends on unbound parameters,
// returns a non-Ok status. If InputShape(index).num_elements() !=
// new_shape.num_elements(), returns an error status.
@@ -131,17 +142,17 @@ class XlaOpKernelContext {
return context_->expected_output_dtype(index);
}
- // Sets output 'index' to the XlaOp 'handle'.
+ // Sets output `index` to the XlaOp `handle`.
// All outputs should be set using SetOutput and SetConstantOutput, not
// via the underlying OpKernelContext.
void SetOutput(int index, const xla::XlaOp& handle);
- // Sets output 'index' to compile-time constant 'host_tensor', where
- // 'host_tensor' is a tensor in host memory. It is preferable to use
+ // Sets output `index` to compile-time constant `host_tensor`, where
+ // `host_tensor` is a tensor in host memory. It is preferable to use
// SetConstantOutput where possible.
void SetConstantOutput(int index, const Tensor& host_tensor);
- // Sets output 'index' to an invalid value.
+ // Sets output `index` to an invalid value.
// Any subsequent attempt to consume this output will cause an error.
void SetInvalidOutput(int index);
@@ -151,10 +162,10 @@ class XlaOpKernelContext {
// Variables
- // Sets '*resource' to the resource associated with input `index`.
+ // Sets `*resource` to the resource associated with input `index`.
Status GetResourceInput(int index, XlaResource** resource);
- // Sets output 'index' to be a reference to resource 'resource'.
+ // Sets output `index` to be a reference to resource `resource`.
void SetResourceOutput(int index, XlaResource* resource);
// Sets `*type` and `*shape` to the current type and shape of a variable's
@@ -163,17 +174,23 @@ class XlaOpKernelContext {
TensorShape* shape) const;
// Reads the current value of the resouce variable referred to by input
- // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+ // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if
// its type does not match `type`.
Status ReadVariableInput(int index, DataType type, TensorShape* shape,
xla::XlaOp* value);
+ // Reads the current value of the resouce variable referred to by input
+ // `name`.
+ Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape,
+ xla::XlaOp* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. The variable must be of `type`. Returns an error if the
// variable has been initialized with a different type or with a
// different shape.
Status AssignVariable(int input_index, DataType type, xla::XlaOp handle);
+ // Assigns the value `handle` to the variable referenced by input `name`.
+ Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);
@@ -221,6 +238,9 @@ class XlaOpKernelContext {
const xla::XlaComputation* GetOrCreateMul(const DataType type);
private:
+ // Returns the tensor of input `name`.
+ const Tensor& GetInputTensorByName(StringPiece name);
+
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 4692038b61..46785bc1f0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -71,16 +71,18 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible allow_resource_types settings.";
return false;
}
- if (!x.has_device_whitelist || !y.has_device_whitelist) {
- LOG(WARNING) << "Registrations of " << x.name
- << " do not both have device whitelists.";
+ if (!x.has_device_whitelist && !y.has_device_whitelist) {
+ LOG(WARNING) << "Duplicate registrations of " << x.name
+ << "with no device whitelists.";
return false;
}
- for (const auto& device : x.device_whitelist) {
- if (y.device_whitelist.count(device) != 0) {
- LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
- << device;
- return false;
+ if (x.has_device_whitelist && y.has_device_whitelist) {
+ for (const auto& device : x.device_whitelist) {
+ if (y.device_whitelist.count(device) != 0) {
+ LOG(WARNING) << "Multiple registrations of " << x.name << " on device "
+ << device;
+ return false;
+ }
}
}
if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) {
@@ -157,97 +159,143 @@ void XlaOpRegistry::RegisterCompilationKernels() {
registry.jit_kernels_registered_ = true;
OpRegistryInterface* op_registry = OpRegistry::Global();
- for (const auto& op : registry.ops_) {
- const string& op_name = op.first;
- const std::unique_ptr<OpRegistration>& op_registration = op.second;
- const OpDef* op_def;
- Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
- if (!lookup_status.ok()) {
- LOG(ERROR) << lookup_status.error_message();
- XLA_LOG_LINES(
- ERROR, "Ops registered: \n" +
- dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
+ // Order of op registration:
+ // The goal is to allow the co-existence of backend-specific kernels and
+ // generic kernels. To achieve this, we enforce the following order of
+ // registrations for one op:
+ // 1. Process op registration with device whitelists:
+ // this pass registers backend-specific kernels for this op.
+ // 2. Process op registration without device whitelists:
+ // this pass registers the kernels for all the other supported backends.
+ for (auto& ops : registry.ops_) {
+ const string& op_name = ops.first;
+ std::vector<std::unique_ptr<OpRegistration>>& op_registrations = ops.second;
+ // Partition the op registration so that the ones with device whitelists
+ // precede the one without device whitelist.
+ std::partition(op_registrations.begin(), op_registrations.end(),
+ [](const std::unique_ptr<OpRegistration>& op_reg) {
+ return op_reg->has_device_whitelist;
+ });
+
+ // Collect a set of backend registered by ops with device whitelists.
+ // The op registration without whitelists will register a generic kernel
+ // for all other backends not in this set.
+ std::unordered_set<string> whitelisted_backend;
+ for (auto& op_registration : op_registrations) {
+ if (op_registration->has_device_whitelist) {
+ whitelisted_backend.insert(op_registration->device_whitelist.begin(),
+ op_registration->device_whitelist.end());
+ }
}
- TF_CHECK_OK(lookup_status);
- std::unordered_set<string> type_attrs;
- for (const OpDef::AttrDef& attr_def : op_def->attr()) {
- if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
- type_attrs.insert(attr_def.name());
+ for (auto& op_registration : op_registrations) {
+ const OpDef* op_def;
+ Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def);
+ if (!lookup_status.ok()) {
+ LOG(ERROR) << lookup_status.error_message();
+ XLA_LOG_LINES(
+ ERROR,
+ "Ops registered: \n" +
+ dynamic_cast<OpRegistry*>(op_registry)->DebugString(true));
}
- }
+ TF_CHECK_OK(lookup_status);
- // Checks there are no type constraints referring to unknown attributes.
- for (const auto& constraint : op_registration->type_constraints) {
- if (type_attrs.find(constraint.first) == type_attrs.end()) {
- LOG(FATAL) << "Unknown type attribute " << constraint.first
- << " in XLA op registration for " << op_name;
+ std::unordered_set<string> type_attrs;
+ for (const OpDef::AttrDef& attr_def : op_def->attr()) {
+ if (attr_def.type() == "type" || attr_def.type() == "list(type)") {
+ type_attrs.insert(attr_def.name());
+ }
}
- }
- for (auto& backend : registry.backends_) {
- // If the operator has a device whitelist, only register on whitelisted
- // devices.
- if (op_registration->has_device_whitelist &&
- op_registration->device_whitelist.find(backend.first) ==
- op_registration->device_whitelist.end()) {
- continue;
+ // Checks there are no type constraints referring to unknown attributes.
+ for (const auto& constraint : op_registration->type_constraints) {
+ if (type_attrs.find(constraint.first) == type_attrs.end()) {
+ LOG(FATAL) << "Unknown type attribute " << constraint.first
+ << " in XLA op registration for " << op_name;
+ }
}
- std::unique_ptr<KernelDef> kdef(new KernelDef);
- kdef->set_op(op_registration->name);
- kdef->set_device_type(backend.first);
-
- // Constrain each type attribute to the intersection of:
- // a) the types supported by the backend, and
- // b) the types allowed by the OpDef, and
- // c) the type constraints.
- for (const string& type_attr : type_attrs) {
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name(type_attr);
- auto* allowed_values =
- attr_constraint->mutable_allowed_values()->mutable_list();
-
- const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
- const auto* op_def_allowed_types =
- op_def_attr.has_allowed_values()
- ? &op_def_attr.allowed_values().list().type()
- : nullptr;
- auto constraint_it = op_registration->type_constraints.find(type_attr);
- const std::set<DataType>* type_constraints =
- constraint_it != op_registration->type_constraints.end()
- ? &constraint_it->second
- : nullptr;
- for (DataType dtype : backend.second.supported_types) {
- // Filter out types that aren't allowed by the OpDef.
- if (op_def_allowed_types != nullptr &&
- std::find(op_def_allowed_types->begin(),
- op_def_allowed_types->end(),
- dtype) == op_def_allowed_types->end()) {
- continue;
+ for (auto& backend : registry.backends_) {
+ // If the operator has a device whitelist, only register on whitelisted
+ // devices.
+ if (op_registration->has_device_whitelist &&
+ op_registration->device_whitelist.find(backend.first) ==
+ op_registration->device_whitelist.end()) {
+ continue;
+ }
+
+ // If the operator does NOT has a device whitelist, skip all devices
+ // that has already been registered.
+ if (!op_registration->has_device_whitelist &&
+ whitelisted_backend.find(backend.first) !=
+ whitelisted_backend.end()) {
+ continue;
+ }
+
+ std::unique_ptr<KernelDef> kdef(new KernelDef);
+ kdef->set_op(op_registration->name);
+ kdef->set_device_type(backend.first);
+
+ // Constrain each type attribute to the intersection of:
+ // a) the types supported by the backend, and
+ // b) the types allowed by the OpDef, and
+ // c) the type constraints.
+ bool unsatisfiable_type_constraint = false;
+ for (const string& type_attr : type_attrs) {
+ KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
+ attr_constraint->set_name(type_attr);
+ auto* allowed_values =
+ attr_constraint->mutable_allowed_values()->mutable_list();
+
+ const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def);
+ const auto* op_def_allowed_types =
+ op_def_attr.has_allowed_values()
+ ? &op_def_attr.allowed_values().list().type()
+ : nullptr;
+ auto constraint_it =
+ op_registration->type_constraints.find(type_attr);
+ const std::set<DataType>* type_constraints =
+ constraint_it != op_registration->type_constraints.end()
+ ? &constraint_it->second
+ : nullptr;
+ for (DataType dtype : backend.second.supported_types) {
+ // Filter out types that aren't allowed by the OpDef.
+ if (op_def_allowed_types != nullptr &&
+ std::find(op_def_allowed_types->begin(),
+ op_def_allowed_types->end(),
+ dtype) == op_def_allowed_types->end()) {
+ continue;
+ }
+ // Filter out types based on the type constraints.
+ if (type_constraints != nullptr &&
+ type_constraints->find(dtype) == type_constraints->end()) {
+ continue;
+ }
+ // Passed all the filters, this type is allowed.
+ allowed_values->add_type(dtype);
}
- // Filter out types based on the type constraints.
- if (type_constraints != nullptr &&
- type_constraints->find(dtype) == type_constraints->end()) {
- continue;
+ if (op_registration->allow_resource_types) {
+ allowed_values->add_type(DT_RESOURCE);
+ }
+ // Don't build KernelDefs that have unsatisfiable type constraints.
+ if (allowed_values->type().empty()) {
+ unsatisfiable_type_constraint = true;
+ break;
}
- // Passed all the filters, this type is allowed.
- allowed_values->add_type(dtype);
}
- if (op_registration->allow_resource_types) {
- allowed_values->add_type(DT_RESOURCE);
+ if (unsatisfiable_type_constraint) continue;
+
+ if (backend.second.op_filter != nullptr &&
+ !backend.second.op_filter(kdef.get())) {
+ continue;
}
+ VLOG(2) << "XLA op registration: device: " << backend.first
+ << " op: " << op_name;
+ registry.kernel_registrars_.emplace_back(
+ new kernel_factory::OpKernelRegistrar(
+ new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
+ backend.second.kernel_defs.push_back(std::move(kdef));
}
- if (backend.second.op_filter != nullptr &&
- !backend.second.op_filter(kdef.get())) {
- continue;
- }
- VLOG(2) << "XLA op registration: device: " << backend.first
- << " op: " << op_name;
- registry.kernel_registrars_.emplace_back(
- new kernel_factory::OpKernelRegistrar(
- new KernelDef(*kdef), "XlaJitOp", op_registration->factory));
- backend.second.kernel_defs.push_back(std::move(kdef));
}
}
}
@@ -265,12 +313,12 @@ std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
<< "Unknown backend " << compilation_device_name;
for (const std::unique_ptr<KernelDef>& k : it->second.kernel_defs) {
auto op_iter = registry.ops_.find(k->op());
- CHECK(op_iter != registry.ops_.end());
+ CHECK(op_iter != registry.ops_.end() && !op_iter->second.empty());
// The test in IsCompatible ensures that if there are multiple matching
// registrations for this op name, they all have the same value of
// compilation_only, so only the first match needs to be tested.
if (include_compilation_only_kernels ||
- !op_iter->second->compilation_only) {
+ !op_iter->second.front()->compilation_only) {
kernels.push_back(k.get());
}
}
@@ -282,10 +330,13 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
XlaOpRegistry& registry = Instance();
mutex_lock lock(registry.mutex_);
auto it = registry.ops_.find(op);
- if (it == registry.ops_.end()) {
+ if (it == registry.ops_.end() || it->second.empty()) {
return nullptr;
}
- return &it->second->compile_time_constant_inputs;
+ // The test in IsCompatible ensures that if there are multiple matching
+ // registrations for this op name, they all have the same value of
+ // compile_time_constant_inputs, so only the first match is returned.
+ return &it->second.front()->compile_time_constant_inputs;
}
std::vector<string> XlaOpRegistry::BackendNames() {
@@ -378,16 +429,15 @@ XlaOpRegistrar::XlaOpRegistrar(
std::unique_ptr<XlaOpRegistry::OpRegistration> registration) {
XlaOpRegistry& registry = XlaOpRegistry::Instance();
mutex_lock lock(registry.mutex_);
- auto existing_ops = registry.ops_.equal_range(registration->name);
- for (auto existing = existing_ops.first; existing != existing_ops.second;
- ++existing) {
- if (!XlaOpRegistry::IsCompatible(*existing->second, *registration)) {
+ auto& existing_ops = registry.ops_[registration->name];
+ for (auto& existing : existing_ops) {
+ if (!XlaOpRegistry::IsCompatible(*existing, *registration)) {
LOG(FATAL)
<< "XLA op registration " << registration->name
<< " is incompatible with existing registration of the same name.";
}
}
- registry.ops_.emplace(registration->name, std::move(registration));
+ existing_ops.emplace_back(std::move(registration));
}
XlaBackendRegistrar::XlaBackendRegistrar(
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index e255b01dd7..2d4593ea49 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -203,7 +203,7 @@ class XlaOpRegistry {
// Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
// Registrations present under the same key must satisfy IsCompatible above,
// and this is checked during registration.
- std::unordered_multimap<string, std::unique_ptr<OpRegistration>> ops_
+ std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
GUARDED_BY(mutex_);
// Have we already registered the JIT kernels on the JIT devices?
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc
new file mode 100644
index 0000000000..7b3b15b1af
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// This test is to verify the correctness of XLA op registration with specific
+// backend overrides.
+
+// A dummy backend-specific OpKernel for CPU.
+class DummyCPUOp : public XlaOpKernel {
+ public:
+ explicit DummyCPUOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->SetOutput(0, ctx->Input(0));
+ }
+};
+
+// A dummy generic OpKernel for all backends.
+class DummyGenericOp : public XlaOpKernel {
+ public:
+ explicit DummyGenericOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->SetOutput(0, ctx->Input(0));
+ }
+};
+
+REGISTER_OP("DummyDuplicateOp")
+ .Attr("T: {float, int32}")
+ .Input("input: int32")
+ .Output("output: int32")
+ .Doc(R"doc(
+A dummy Op.
+
+input: dummy input.
+output: dummy output.
+)doc");
+
+// Register the DummyCPUOp kernel for CPU with type INT32.
+REGISTER_XLA_OP(Name("DummyDuplicateOp")
+ .Device(DEVICE_CPU_XLA_JIT)
+ .TypeConstraint("T", DT_INT32),
+ DummyCPUOp);
+// Register the DummyGeneric kernel for all registered device (except CPU since
+// it is already registered), with type FLOAT.
+REGISTER_XLA_OP(Name("DummyDuplicateOp").TypeConstraint("T", DT_FLOAT),
+ DummyGenericOp);
+
+// Test the correctness of registered kernels. The kernel registered for CPU
+// should have type INT32 while all other kernels should have type FLOAT.
+TEST(XlaOpRegistryTest, XlaOpRegistrationWithOverride) {
+ XlaOpRegistry::RegisterCompilationKernels();
+ auto registered_kernels = GetAllRegisteredKernels().kernel();
+ for (const auto& kernels : registered_kernels) {
+ if (kernels.op() == "DummyDuplicateOp") {
+ EXPECT_EQ(kernels.constraint_size(), 1);
+ EXPECT_EQ(kernels.constraint(0).name(), "T");
+ if (kernels.device_type() == "XLA_CPU_JIT") {
+ EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
+ DT_INT32);
+ } else {
+ EXPECT_EQ(kernels.constraint(0).allowed_values().list().type(0),
+ DT_FLOAT);
+ }
+ }
+ }
+}
+
+// A dummy generic OpKernel for all backends.
+class DummyInfeasibleTypeConstraintOp : public XlaOpKernel {
+ public:
+ explicit DummyInfeasibleTypeConstraintOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+ void Compile(XlaOpKernelContext* ctx) override {
+ LOG(FATAL) << "unreachable";
+ }
+};
+
+REGISTER_OP("DummyInfeasibleTypeConstraintOp")
+ .Attr("T: {float, string}")
+ .Input("input: T")
+ .Output("output: T")
+ .Doc(R"doc(
+A dummy Op.
+
+input: dummy input.
+output: dummy output.
+)doc");
+REGISTER_XLA_OP(
+ Name("DummyInfeasibleTypeConstraintOp").TypeConstraint("T", DT_STRING),
+ DummyInfeasibleTypeConstraintOp);
+
+TEST(XlaOpRegistryTest, OpWithInfeasibleTypeConstraintIsNotRegistered) {
+ XlaOpRegistry::RegisterCompilationKernels();
+ auto registered_kernels = GetAllRegisteredKernels().kernel();
+ for (const auto& kernels : registered_kernels) {
+ // The operator should not be registered.
+ EXPECT_NE(kernels.op(), "DummyInfeasibleTypeConstraintOp");
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 540c65c597..baea814965 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
namespace tensorflow {
@@ -89,16 +90,16 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
}
switch (kind_) {
case kVariable: {
- value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
- shape_.dim_sizes());
+ value_ =
+ xla::Broadcast(XlaHelpers::Zero(builder, type_), shape_.dim_sizes());
break;
}
case kTensorArray: {
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
- value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
- ta_shape.dim_sizes());
+ value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes());
break;
}
case kStack: {
@@ -106,9 +107,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
value_ =
- builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
- ta_shape.dim_sizes()),
- builder->ConstantR0<int32>(0)});
+ xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes()),
+ xla::ConstantR0<int32>(builder, 0)});
break;
}
@@ -130,8 +131,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source,
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
- xla::XlaOp gradient_value = builder->Broadcast(
- XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
+ xla::XlaOp gradient_value =
+ xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
@@ -152,7 +153,7 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const {
for (const auto& gradient : tensor_array_gradients_) {
elems.push_back(gradient.second->value_);
}
- *pack = builder->Tuple(elems);
+ *pack = xla::Tuple(builder, elems);
}
return Status::OK();
}
@@ -168,7 +169,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
} else {
TF_RET_CHECK(kind_ == kTensorArray);
int pos = 0;
- auto v = builder->GetTupleElement(pack, pos++);
+ auto v = xla::GetTupleElement(pack, pos++);
if (!initialized()) {
initial_value_ = v;
}
@@ -178,7 +179,7 @@ Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
XlaResource* gradient;
TF_RETURN_IF_ERROR(
GetOrCreateTensorArrayGradient(source, builder, &gradient));
- auto v = builder->GetTupleElement(pack, pos++);
+ auto v = xla::GetTupleElement(pack, pos++);
if (!gradient->initialized()) {
gradient->initial_value_ = v;
}
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index 9ce36d1aa7..4de18a7788 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -75,7 +75,7 @@ class XlaResource {
const xla::XlaOp& initial_value() const { return initial_value_; }
// A variable is initialized if it has a value.
- bool initialized() const { return value_.builder() != nullptr; }
+ bool initialized() const { return value_.valid(); }
// Sets the type and shape of the resource. The type and shape of a resource
// must not change once the variable has been initialized.
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index afa8ce730b..f1c383fd9e 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -53,7 +53,6 @@ xla_proto_library(
deps = [
":xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
- "//tensorflow/compiler/xla/service:session_proto",
],
)
@@ -161,6 +160,7 @@ cc_library(
hdrs = [
"iterator_util.h",
"map_util.h",
+ "overflow_util.h",
"ptr_util.h",
"util.h",
],
@@ -236,7 +236,7 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework_internal",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
@@ -254,6 +254,7 @@ tf_cc_test(
":types",
":util",
":xla_data_proto",
+ "//tensorflow/core:lib",
"//tensorflow/core:test_main",
],
)
@@ -281,9 +282,9 @@ tf_cc_test(
)
cc_library(
- name = "literal_util",
- srcs = ["literal_util.cc"],
- hdrs = ["literal_util.h"],
+ name = "literal",
+ srcs = ["literal.cc"],
+ hdrs = ["literal.h"],
visibility = ["//visibility:public"],
deps = [
":array2d",
@@ -295,17 +296,17 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
tf_cc_test(
- name = "literal_util_test",
- srcs = ["literal_util_test.cc"],
+ name = "literal_test",
+ srcs = ["literal_test.cc"],
deps = [
":array3d",
":array4d",
+ ":literal",
":literal_util",
":shape_util",
":test",
@@ -318,6 +319,26 @@ tf_cc_test(
)
cc_library(
+ name = "literal_util",
+ srcs = ["literal_util.cc"],
+ hdrs = ["literal_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":array2d",
+ ":array3d",
+ ":array4d",
+ ":literal",
+ ":shape_util",
+ ":sparse_index_array",
+ ":status_macros",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "error_spec",
hdrs = ["error_spec.h"],
)
@@ -328,6 +349,7 @@ cc_library(
hdrs = ["literal_comparison.h"],
deps = [
":error_spec",
+ ":literal",
":literal_util",
":util",
"//tensorflow/core:lib",
@@ -459,7 +481,7 @@ cc_library(
hdrs = ["packed_literal_reader.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":statusor",
@@ -490,7 +512,7 @@ cc_library(
hdrs = ["text_literal_reader.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":statusor",
@@ -506,7 +528,7 @@ tf_cc_test(
name = "text_literal_reader_test",
srcs = ["text_literal_reader_test.cc"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":test",
":text_literal_reader",
@@ -523,7 +545,7 @@ cc_library(
hdrs = ["text_literal_writer.h"],
visibility = [":internal"],
deps = [
- ":literal_util",
+ ":literal",
":shape_util",
":status_macros",
":types",
@@ -536,6 +558,7 @@ tf_cc_test(
name = "text_literal_writer_test",
srcs = ["text_literal_writer_test.cc"],
deps = [
+ ":literal",
":literal_util",
":test",
":test_helpers",
@@ -608,6 +631,7 @@ cc_library(
":array2d",
":array3d",
":array4d",
+ ":literal_util",
":util",
":window_util",
":xla_data_proto",
@@ -628,7 +652,7 @@ tf_cc_test(
":array2d",
":array3d",
":array4d",
- ":literal_util",
+ ":literal",
":reference_util",
":test",
":util",
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index c4f0c4468f..25666cad40 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -65,7 +65,7 @@ cc_library(
deps = [
":global_data",
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:service_interface",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -110,6 +110,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:source_map_util",
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 3d596a6e65..3a157c69cd 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 68f0d0ac78..69d4d300ca 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc
index dc69d2097e..5c9abad4c3 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.cc
+++ b/tensorflow/compiler/xla/client/compile_only_client.cc
@@ -24,7 +24,8 @@ namespace xla {
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyClient::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options) {
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<CompileOnlyService::AotXlaComputationInstance> service_instances;
service_instances.reserve(computations.size());
for (const AotXlaComputationInstance& instance : computations) {
@@ -36,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime(
service_instance.argument_layouts = instance.argument_layouts;
service_instance.result_layout = instance.result_layout;
}
- return compiler_service_->CompileAheadOfTime(service_instances, options);
+ return compiler_service_->CompileAheadOfTime(service_instances, options,
+ metadata);
}
int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) {
diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h
index f9a7c31270..332c965036 100644
--- a/tensorflow/compiler/xla/client/compile_only_client.h
+++ b/tensorflow/compiler/xla/client/compile_only_client.h
@@ -46,13 +46,15 @@ class CompileOnlyClient : public Client {
const Shape* result_layout;
};
- // Compiles a list of xla computations for ahead-of-time execution. This is
- // intended for use in static compilation. The |options| parameter describes
- // the target for which the compiler should emit code.
+ // Compiles a list of xla computations for ahead-of-time execution.
+ // This is intended for use in static compilation. The |options|
+ // parameter describes the target for which the compiler should emit
+ // code. |metadata|, if provided, is populated during compilation.
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options);
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata = nullptr);
// Returns the size of a pointer in bytes for a given triple.
static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 6e3c5cb484..7dee41f6a0 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
return dump_optimized_hlo_proto_to_;
}
+ExecutableBuildOptions&
+ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to(
+ tensorflow::StringPiece dirpath) {
+ dump_unoptimized_hlo_proto_to_ = dirpath.ToString();
+ return *this;
+}
+
+const tensorflow::gtl::optional<string>&
+ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const {
+ return dump_unoptimized_hlo_proto_to_;
+}
+
ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
tensorflow::StringPiece dirpath) {
dump_per_pass_hlo_proto_to_ = dirpath.ToString();
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 393da381fb..9dc9be4423 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -65,6 +65,13 @@ class ExecutableBuildOptions {
tensorflow::StringPiece dirpath);
const tensorflow::gtl::optional<string>& dump_optimized_hlo_proto_to() const;
+ // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO
+ // protobuf to (as in DebugOptions).
+ ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to(
+ tensorflow::StringPiece dirpath);
+ const tensorflow::gtl::optional<string>& dump_unoptimized_hlo_proto_to()
+ const;
+
// If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
// to (as in DebugOptions).
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
@@ -95,6 +102,7 @@ class ExecutableBuildOptions {
bool result_layout_set_ = false;
tensorflow::gtl::optional<string> generate_hlo_graph_;
tensorflow::gtl::optional<string> dump_optimized_hlo_proto_to_;
+ tensorflow::gtl::optional<string> dump_unoptimized_hlo_proto_to_;
tensorflow::gtl::optional<string> dump_per_pass_hlo_proto_to_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
std::vector<std::string> disabled_hlo_passes_;
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index d49d959a6c..6933e9a838 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -13,11 +13,18 @@ filegroup(
]),
)
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
+
+# Generate test_suites for all backends, named "${backend}_tests".
+generate_backend_suites()
+
cc_library(
name = "arithmetic",
srcs = ["arithmetic.cc"],
hdrs = ["arithmetic.h"],
deps = [
+ ":constants",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
@@ -29,12 +36,95 @@ cc_library(
)
cc_library(
+ name = "constants",
+ srcs = ["constants.cc"],
+ hdrs = ["constants.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "constants_test",
+ srcs = ["constants_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "math",
+ srcs = ["math.cc"],
+ hdrs = ["math.h"],
+ deps = [
+ ":constants",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "math_test",
+ srcs = ["math_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":math",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "numeric",
+ srcs = ["numeric.cc"],
+ hdrs = ["numeric.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ ],
+)
+
+xla_test(
+ name = "numeric_test",
+ srcs = ["numeric_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
+ deps = [
+ ":numeric",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
name = "testing",
srcs = ["testing.cc"],
hdrs = ["testing.h"],
deps = [
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
index a1d34796cc..978fc40f34 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.cc
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -42,8 +43,8 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
}
const Shape scalar = ShapeUtil::MakeShape(type, {});
- auto lhs = b->Parameter(0, scalar, "lhs");
- auto rhs = b->Parameter(1, scalar, "rhs");
+ auto lhs = Parameter(b.get(), 0, scalar, "lhs");
+ auto rhs = Parameter(b.get(), 1, scalar, "rhs");
generator(b.get(), lhs, rhs);
return b->BuildAndNoteError();
}
@@ -55,7 +56,7 @@ XlaComputation CreateScalarAddComputation(PrimitiveType type,
return CreateScalarComputation(
"add", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Add(lhs, rhs);
+ return Add(lhs, rhs);
});
}
@@ -64,17 +65,15 @@ XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
return CreateScalarComputation(
"mul", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Mul(lhs, rhs);
+ return Mul(lhs, rhs);
});
}
XlaComputation CreateScalarGeComputation(PrimitiveType type,
XlaBuilder* builder) {
- return CreateScalarComputation(
- "ge", type, builder,
- [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Ge(lhs, rhs);
- });
+ return CreateScalarComputation("ge", type, builder,
+ [](XlaBuilder* b, const XlaOp& lhs,
+ const XlaOp& rhs) { return Ge(lhs, rhs); });
}
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
@@ -82,7 +81,7 @@ XlaComputation CreateScalarMaxComputation(PrimitiveType type,
return CreateScalarComputation(
"max", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Max(lhs, rhs);
+ return Max(lhs, rhs);
});
}
@@ -91,7 +90,7 @@ XlaComputation CreateScalarMinComputation(PrimitiveType type,
return CreateScalarComputation(
"min", type, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Min(lhs, rhs);
+ return Min(lhs, rhs);
});
}
@@ -99,26 +98,27 @@ XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
return CreateScalarComputation(
"and", PRED, builder,
[](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->And(lhs, rhs);
+ return And(lhs, rhs);
});
}
XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
- return CreateScalarComputation(
- "or", PRED, builder,
- [](XlaBuilder* b, const XlaOp& lhs, const XlaOp& rhs) {
- return b->Or(lhs, rhs);
- });
+ return CreateScalarComputation("or", PRED, builder,
+ [](XlaBuilder* b, const XlaOp& lhs,
+ const XlaOp& rhs) { return Or(lhs, rhs); });
}
-StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
- auto f = builder->ConstantR0<bool>(false);
- XlaComputation logical_or = CreateScalarOrComputation(builder);
- TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
- builder->GetShape(predicates));
- std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
- std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
- return builder->Reduce(predicates, f, logical_or, all_dimensions);
+XlaOp Any(XlaOp predicates) {
+ XlaBuilder* builder = predicates.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ auto f = ConstantR0<bool>(builder, false);
+ XlaComputation logical_or = CreateScalarOrComputation(builder);
+ TF_ASSIGN_OR_RETURN(const Shape& predicates_shape,
+ builder->GetShape(predicates));
+ std::vector<int64> all_dimensions(ShapeUtil::Rank(predicates_shape));
+ std::iota(all_dimensions.begin(), all_dimensions.end(), 0);
+ return Reduce(predicates, f, logical_or, all_dimensions);
+ });
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
index 64b6b7d633..d0b916e8c8 100644
--- a/tensorflow/compiler/xla/client/lib/arithmetic.h
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -53,7 +53,7 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Returns whether any predicate in "predicates" is set.
//
// Note: if predicates is zero-sized, Any() vacuously returns false.
-StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);
+XlaOp Any(XlaOp predicates);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc
new file mode 100644
index 0000000000..031d62e4ff
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants.cc
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, LiteralUtil::Zero(type));
+}
+
+XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
+ return Broadcast(Zero(builder, shape.element_type()),
+ AsInt64Slice(shape.dimensions()));
+}
+
+XlaOp ZerosLike(XlaOp prototype) {
+ XlaBuilder* builder = prototype.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
+ return Zeros(builder, shape);
+ });
+}
+
+XlaOp One(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, LiteralUtil::One(type));
+}
+
+XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(
+ builder,
+ static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::epsilon());
+ case F32:
+ return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
+ case F64:
+ return ConstantR0<double>(builder,
+ std::numeric_limits<double>::epsilon());
+ default:
+ return builder->ReportError(InvalidArgument(
+ "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str()));
+ }
+}
+
+XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, LiteralUtil::MinValue(type));
+}
+
+XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(builder,
+ Eigen::NumTraits<Eigen::half>::lowest());
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::lowest());
+ case F32:
+ return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
+ case F64:
+ return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
+ default:
+ return MinValue(builder, type);
+ }
+}
+
+XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
+ return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
+}
+
+XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(builder,
+ Eigen::NumTraits<Eigen::half>::highest());
+ case BF16:
+ return ConstantR0<bfloat16>(builder, bfloat16::highest());
+ case F32:
+ return ConstantR0<float>(builder, std::numeric_limits<float>::max());
+ case F64:
+ return ConstantR0<double>(builder, std::numeric_limits<double>::max());
+ default:
+ return MaxValue(builder, type);
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
new file mode 100644
index 0000000000..b47f5243f0
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
+
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
+// determined at C++ run-time, rather than C++ compile-time.
+// If 'value' is floating point but 'type' is not, or if 'value' is complex but
+// 'type' is not, an error will be returned. This is to catch accidental
+// truncation; in such cases, use an explicit cast.
+template <typename T>
+XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
+ if (std::is_floating_point<T>::value &&
+ !(primitive_util::IsFloatingPointType(type) ||
+ primitive_util::IsComplexType(type))) {
+ return builder->ReportError(InvalidArgument(
+ "Invalid cast from floating point type to %s in ConstantR0WithType.",
+ PrimitiveType_Name(type).c_str()));
+ }
+ if (std::is_same<T, complex64>::value &&
+ !primitive_util::IsComplexType(type)) {
+ return builder->ReportError(InvalidArgument(
+ "Invalid cast from complex type to %s in ConstantR0WithType.",
+ PrimitiveType_Name(type).c_str()));
+ }
+ switch (type) {
+ case F16:
+ return ConstantR0<half>(builder, static_cast<half>(value));
+ case BF16:
+ return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
+ case F32:
+ return ConstantR0<float>(builder, static_cast<float>(value));
+ case F64:
+ return ConstantR0<double>(builder, static_cast<double>(value));
+ case C64:
+ return ConstantR0<complex64>(builder, static_cast<complex64>(value));
+ case U8:
+ return ConstantR0<uint8>(builder, static_cast<uint8>(value));
+ case U32:
+ return ConstantR0<uint32>(builder, static_cast<uint32>(value));
+ case U64:
+ return ConstantR0<uint64>(builder, static_cast<uint64>(value));
+ case S8:
+ return ConstantR0<int8>(builder, static_cast<int8>(value));
+ case S32:
+ return ConstantR0<int32>(builder, static_cast<int32>(value));
+ case S64:
+ return ConstantR0<int64>(builder, static_cast<int64>(value));
+ default:
+ return builder->ReportError(
+ InvalidArgument("Invalid type for ConstantR0WithType (%s).",
+ PrimitiveType_Name(type).c_str()));
+ }
+}
+
+// Returns a scalar containing 'value' cast to the same run-time type as
+// 'prototype'.
+// If 'value' is floating point but 'prototype' is not, or if 'value' is complex
+// 'prototype' is not, an error will be returned.
+template <typename T>
+XlaOp ScalarLike(XlaOp prototype, T value) {
+ XlaBuilder* builder = prototype.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
+ return ConstantR0WithType(builder, shape.element_type(), value);
+ });
+}
+
+// Returns a scalar with value '0' of 'type'.
+XlaOp Zero(XlaBuilder* builder, PrimitiveType type);
+
+// Returns a zero-filled tensor with shape `shape`.
+XlaOp Zeros(XlaBuilder* builder, const Shape& shape);
+
+// Returns a zero-filled tensor with the same shape as `prototype`.
+XlaOp ZerosLike(XlaOp prototype);
+
+// Returns a scalar with value '1' of 'type'.
+XlaOp One(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the machine epsilon for floating-point type `type`, i.e.,
+// the difference between 1.0 and the next representable value.
+XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the minimum representable finite or infinite value for 'type'.
+// Returns '-inf' for floating-point types.
+XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the minimum representable finite value for 'type'. For a floating
+// point type, this is equal to -MaxFiniteValue().
+XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the maximum representable finite or infinite value for 'type'.
+// Returns 'inf' for floating-point types.
+XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);
+
+// Returns the maximum representable finite value for 'type'.
+XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc
new file mode 100644
index 0000000000..f1e3439862
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/constants_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using ConstantsTest = ClientLibraryTestBase;
+
+using ::testing::HasSubstr;
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::S32, 4);
+ ComputeAndCompareR0<int32>(&builder, 4, {});
+}
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeS32DoesNotAcceptFloats) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::S32, 4.5);
+ auto statusor = builder.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("Invalid cast"));
+}
+
+XLA_TEST_F(ConstantsTest, ConstantR0WithTypeF32) {
+ XlaBuilder builder(TestName());
+ ConstantR0WithType(&builder, xla::F32, -7);
+ ComputeAndCompareR0<float>(&builder, -7, {});
+ ConstantR0WithType(&builder, xla::F32, 0.5);
+ ComputeAndCompareR0<float>(&builder, 0.5, {});
+}
+
+XLA_TEST_F(ConstantsTest, ScalarLikeS32) {
+ XlaBuilder builder(TestName());
+ ScalarLike(ConstantR0<int32>(&builder, 42), -3);
+ ComputeAndCompareR0<int32>(&builder, -3, {});
+}
+
+XLA_TEST_F(ConstantsTest, ScalarLikeF32) {
+ XlaBuilder builder(TestName());
+ ScalarLike(ConstantR0<float>(&builder, 42.75), -3.2);
+ ComputeAndCompareR0<float>(&builder, -3.2, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZeroS32) {
+ XlaBuilder builder(TestName());
+ Zero(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, 0, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZeroF32) {
+ XlaBuilder builder(TestName());
+ Zero(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, 0.0, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZerosS32) {
+ XlaBuilder builder(TestName());
+ Zeros(&builder, ShapeUtil::MakeShape(S32, {2, 2}));
+ ComputeAndCompareR2<int32>(&builder, {{0, 0}, {0, 0}}, {});
+}
+
+XLA_TEST_F(ConstantsTest, ZerosLikeF32) {
+ XlaBuilder builder(TestName());
+ ZerosLike(ConstantR1<float>(&builder, {1., 2., 3.}));
+ ComputeAndCompareR1<float>(&builder, {0., 0., 0.}, {});
+}
+
+XLA_TEST_F(ConstantsTest, OneS32) {
+ XlaBuilder builder(TestName());
+ One(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, 1, {});
+}
+
+XLA_TEST_F(ConstantsTest, OneF32) {
+ XlaBuilder builder(TestName());
+ One(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, 1., {});
+}
+
+XLA_TEST_F(ConstantsTest, EpsilonF32) {
+ XlaBuilder builder(TestName());
+ Epsilon(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::epsilon(),
+ {});
+}
+
+XLA_TEST_F(ConstantsTest, MinFiniteValueS32) {
+ XlaBuilder builder(TestName());
+ MinFiniteValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::min(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxFiniteValueS32) {
+ XlaBuilder builder(TestName());
+ MaxFiniteValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinFiniteValueF32) {
+ XlaBuilder builder(TestName());
+ MinFiniteValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, -std::numeric_limits<float>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxFiniteValueF32) {
+ XlaBuilder builder(TestName());
+ MaxFiniteValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinValueS32) {
+ XlaBuilder builder(TestName());
+ MinValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::min(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxValueS32) {
+ XlaBuilder builder(TestName());
+ MaxValue(&builder, S32);
+ ComputeAndCompareR0<int32>(&builder, std::numeric_limits<int32>::max(), {});
+}
+
+XLA_TEST_F(ConstantsTest, MinValueF32) {
+ XlaBuilder builder(TestName());
+ MinValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, -std::numeric_limits<float>::infinity(),
+ {});
+}
+
+XLA_TEST_F(ConstantsTest, MaxValueF32) {
+ XlaBuilder builder(TestName());
+ MaxValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::infinity(),
+ {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
new file mode 100644
index 0000000000..5587559040
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+
+XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
+
+XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
+
+XlaOp Square(XlaOp operand) { return Pow(operand, ScalarLike(operand, 2.0)); }
+
+XlaOp Reciprocal(XlaOp operand) {
+ return Pow(operand, ScalarLike(operand, -1.0));
+}
+
+namespace {
+
+// Polynomials for computing erf/erfc. Originally from cephes.
+// Note we use float for compatibility across devices, at the cost of some
+// precision for 64 bit computations.
+//
+// Coefficients are in descending order.
+std::array<float, 9> kErfcPCoefficient = {
+ 2.46196981473530512524E-10, 5.64189564831068821977E-1,
+ 7.46321056442269912687E0, 4.86371970985681366614E1,
+ 1.96520832956077098242E2, 5.26445194995477358631E2,
+ 9.34528527171957607540E2, 1.02755188689515710272E3,
+ 5.57535335369399327526E2};
+std::array<float, 9> kErfcQCoefficient = {
+ 1.00000000000000000000E0, 1.32281951154744992508E1,
+ 8.67072140885989742329E1, 3.54937778887819891062E2,
+ 9.75708501743205489753E2, 1.82390916687909736289E3,
+ 2.24633760818710981792E3, 1.65666309194161350182E3,
+ 5.57535340817727675546E2};
+std::array<float, 6> kErfcRCoefficient = {
+ 5.64189583547755073984E-1, 1.27536670759978104416E0,
+ 5.01905042251180477414E0, 6.16021097993053585195E0,
+ 7.40974269950448939160E0, 2.97886665372100240670E0};
+std::array<float, 7> kErfcSCoefficient = {
+ 1.00000000000000000000E0, 2.26052863220117276590E0,
+ 9.39603524938001434673E0, 1.20489539808096656605E1,
+ 1.70814450747565897222E1, 9.60896809063285878198E0,
+ 3.36907645100081516050E0};
+std::array<float, 5> kErfTCoefficient = {
+ 9.60497373987051638749E0, 9.00260197203842689217E1,
+ 2.23200534594684319226E3, 7.00332514112805075473E3,
+ 5.55923013010394962768E4};
+std::array<float, 6> kErfUCoefficient = {
+ 1.00000000000000000000E0, 3.35617141647503099647E1,
+ 5.21357949780152679795E2, 4.59432382970980127987E3,
+ 2.26290000613890934246E4, 4.92673942608635921086E4};
+} // namespace
+
+// Evaluate the polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients) {
+ XlaOp poly = ScalarLike(x, 0.0);
+ for (float c : coefficients) {
+ poly = poly * x + ScalarLike(x, c);
+ }
+ return poly;
+}
+
+// Compute an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x) {
+ XlaOp abs_x = Abs(x);
+ XlaOp z = Exp(-x * x);
+
+ XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient);
+ XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient);
+ XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient);
+ XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient);
+
+ XlaOp y = Select(Lt(abs_x, ScalarLike(x, 8.0)), z * pp / pq, z * pr / ps);
+
+ return Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y);
+}
+
+// Compute a polynomial approximation of the error function.
+XlaOp Erf(XlaOp x) {
+ XlaOp z = x * x;
+ XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient);
+ XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient);
+ return x * pt / pu;
+}
+
+// Approximation for the inverse error function from
+// Giles, M., "Approximating the erfinv function".
+// The approximation has the form:
+// w = -log((1 - x) * (1 + x))
+// if ( w < 5 ) {
+// w = w - 2.5
+// p = sum_{i=1}^n lq[i]*w^i
+// } else {
+// w = sqrt(w) - 3
+// p = sum_{i=1}^n gq[i]*w^i
+// }
+// return p*x
+XlaOp ErfInv(XlaOp x) {
+ XlaBuilder* b = x.builder();
+ return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
+ constexpr int kDegree = 9;
+ constexpr std::array<float, 9> w_less_than_5_constants = {
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ constexpr std::array<float, 9> w_greater_than_5_constants = {
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+
+ auto one = ScalarLike(x, 1.0);
+ auto w = -Log((one - x) * (one + x));
+
+ auto lt = Lt(w, ScalarLike(x, 5.0));
+ auto coefficient = [&](int i) {
+ return Select(lt,
+ Broadcast(ScalarLike(x, w_less_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())),
+ Broadcast(ScalarLike(x, w_greater_than_5_constants[i]),
+ AsInt64Slice(shape.dimensions())));
+ };
+ w = Select(lt, w - ScalarLike(x, 2.5), Sqrt(w) - ScalarLike(x, 3.0));
+ auto p = coefficient(0);
+ for (int i = 1; i < kDegree; ++i) {
+ p = coefficient(i) + p * w;
+ }
+ return p * x;
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
new file mode 100644
index 0000000000..e7c8b50273
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+
+namespace xla {
+
+// Computes the square root of 'operand'.
+XlaOp Sqrt(XlaOp operand);
+
+// Computes the reciprocal of the square root of 'operand'.
+XlaOp Rsqrt(XlaOp operand);
+
+// Computes the square of 'operand'.
+XlaOp Square(XlaOp operand);
+
+// Computes the reciprocal of 'operand'.
+XlaOp Reciprocal(XlaOp operand);
+
+// Evaluates a polynomial given coefficients and `x`.
+// N.B. Coefficients should be supplied in decreasing order.
+XlaOp EvaluatePolynomial(XlaOp x,
+ tensorflow::gtl::ArraySlice<float> coefficients);
+
+// Computes an approximation of the error function complement (1 - erf(x)).
+XlaOp Erfc(XlaOp x);
+
+// Computes an approximation of the error function.
+XlaOp Erf(XlaOp x);
+
+// Computes an approximation of the inverse of the error function.
+XlaOp ErfInv(XlaOp x);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_MATH_H_
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
new file mode 100644
index 0000000000..068cd2e586
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class MathTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(MathTest, SqrtF32) {
+ XlaBuilder builder(TestName());
+ Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32);
+
+ std::unique_ptr<GlobalData> zero_data =
+ client_->TransferToServer(zero_literal).ConsumeValueOrDie();
+
+ XlaOp zero = Parameter(&builder, 0, zero_literal.shape(), "zero");
+ Sqrt(zero);
+
+ ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SquareTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Square(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, ReciprocalTenValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reciprocal(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtZeroes) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {0.0, -0.0});
+ Sqrt(x);
+
+ ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
+}
+
+XLA_TEST_F(MathTest, SqrtSixValues) {
+ XlaBuilder builder(TestName());
+ auto x = ConstantR1<float>(&builder, {16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
+ Sqrt(x);
+
+ std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/numeric.cc b/tensorflow/compiler/xla/client/lib/numeric.cc
new file mode 100644
index 0000000000..fd4e8fc390
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+
+#include <numeric>
+#include <vector>
+
+namespace xla {
+
+namespace {
+
+template <typename T>
+XlaOp MakeIota(XlaBuilder* builder, int64 size) {
+ std::vector<T> values(size);
+ for (int64 i = 0; i < size; ++i) {
+ values[i] = static_cast<T>(i);
+ }
+ return xla::ConstantR1<T>(builder, values);
+}
+
+} // namespace
+
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size) {
+ switch (type) {
+ case S8:
+ return MakeIota<int8>(builder, size);
+ case S16:
+ return MakeIota<int16>(builder, size);
+ case S32:
+ return MakeIota<int32>(builder, size);
+ case S64:
+ return MakeIota<int64>(builder, size);
+ case U8:
+ return MakeIota<uint8>(builder, size);
+ case U16:
+ return MakeIota<uint16>(builder, size);
+ case U32:
+ return MakeIota<uint32>(builder, size);
+ case U64:
+ return MakeIota<uint64>(builder, size);
+ case BF16:
+ return MakeIota<bfloat16>(builder, size);
+ case F16:
+ return MakeIota<half>(builder, size);
+ case F32:
+ return MakeIota<float>(builder, size);
+ case F64:
+ return MakeIota<double>(builder, size);
+ case C64:
+ return MakeIota<complex64>(builder, size);
+ default:
+ return builder->ReportError(
+ InvalidArgument("Unimplemented type for Iota: %s.",
+ PrimitiveType_Name(type).c_str()));
+ }
+}
+
+XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m,
+ int64 n) {
+ auto a = Iota(builder, type, m);
+ auto b = Iota(builder, type, n);
+ auto indicator = Eq(a, Broadcast(b, {m}), /*broadcast_dimensions=*/{0});
+ return ConvertElementType(indicator, type);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
new file mode 100644
index 0000000000..79707007b2
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
+
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...].
+XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
+
+// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere
+// else.
+XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_NUMERIC_H_
diff --git a/tensorflow/compiler/xla/client/lib/numeric_test.cc b/tensorflow/compiler/xla/client/lib/numeric_test.cc
new file mode 100644
index 0000000000..bc8a73e9d7
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/numeric_test.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using NumericTest = ClientLibraryTestBase;
+
+XLA_TEST_F(NumericTest, Iota) {
+ XlaBuilder builder(TestName());
+ Iota(&builder, S32, 10);
+
+ ComputeAndCompareR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, {});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index 3380af9f30..534c509868 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -48,15 +48,15 @@ int64 DataSizeOfShape(const Shape& shape) {
// Creates a XlaOp for an op what generates fake data with the given shape.
XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) {
if (ShapeUtil::IsArray(shape)) {
- return builder->Broadcast(
- builder->ConstantLiteral(Literal::One(shape.element_type())),
+ return Broadcast(
+ ConstantLiteral(builder, LiteralUtil::One(shape.element_type())),
AsInt64Slice(shape.dimensions()));
}
std::vector<XlaOp> parts;
for (const Shape& s : shape.tuple_shapes()) {
parts.push_back(BuildFakeDataOpOnDevice(s, builder));
}
- return builder->Tuple(parts);
+ return Tuple(builder, parts);
}
std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index f9003373a6..5f9710914b 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -51,24 +51,17 @@ LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& run_options, const Backend& backend) {
- const ComputationLayout& host_computation_layout =
- executable_->module_config().host_entry_computation_layout();
- const ComputationLayout& device_computation_layout =
- executable_->module_config().device_entry_computation_layout();
+ const ComputationLayout& computation_layout =
+ executable_->module_config().entry_computation_layout();
// Check argument number, shapes, and layouts.
- if (arguments.size() != host_computation_layout.parameter_count()) {
+ if (arguments.size() != computation_layout.parameter_count()) {
return InvalidArgument(
"invalid number of arguments for computation: expected %d, got %zu",
- host_computation_layout.parameter_count(), arguments.size());
- }
- if (arguments.size() != device_computation_layout.parameter_count()) {
- return InvalidArgument(
- "invalid number of arguments for computation: expected %d, got %zu",
- device_computation_layout.parameter_count(), arguments.size());
+ computation_layout.parameter_count(), arguments.size());
}
for (int i = 0; i < arguments.size(); ++i) {
- if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
arguments[i]->on_host_shape())) {
return InvalidParameterArgument(
executable_.get(), i,
@@ -76,24 +69,10 @@ Status LocalExecutable::ValidateExecutionOptions(
"parameter "
"%d: want %s, got %s",
i,
- ShapeUtil::HumanString(
- host_computation_layout.parameter_layout(i).shape())
+ ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
.c_str(),
ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str());
}
- if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape(
- arguments[i]->on_device_shape())) {
- return InvalidParameterArgument(
- executable_.get(), i,
- "Argument does not match device shape or layout of computation "
- "parameter "
- "%d: want %s, got %s",
- i,
- ShapeUtil::HumanString(
- device_computation_layout.parameter_layout(i).shape())
- .c_str(),
- ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str());
- }
}
if (run_options.stream() != nullptr) {
@@ -185,7 +164,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
run_options, backend_->StreamBorrower(),
backend_->eigen_intra_op_thread_pool());
- if (executable_->dumping()) {
+ if (executable_->dumping_snapshot()) {
return ExecuteAndDump(&service_options, arguments);
}
return executable_->ExecuteOnStreamWrapper(
@@ -195,45 +174,44 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
- executable_->session_module()->set_execution_platform(
+ executable_->hlo_snapshot()->set_execution_platform(
backend_->platform()->Name());
- TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
+ TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot()));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer result,
executable_->ExecuteOnStream(run_options, arguments,
/*hlo_execution_profile=*/nullptr));
- TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module()));
- TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
+ TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot()));
+ TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot());
return std::move(result);
}
Status LocalExecutable::RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- SessionModule* session_module) {
- session_module->clear_arguments();
+ HloSnapshot* hlo_snapshot) {
+ hlo_snapshot->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*argument));
- *session_module->add_arguments() = literal->ToProto();
+ *hlo_snapshot->add_arguments() = literal->ToProto();
}
return Status::OK();
}
Status LocalExecutable::RecordResult(const ShapedBuffer* result,
- SessionModule* session_module) {
- session_module->clear_result();
+ HloSnapshot* hlo_snapshot) {
+ hlo_snapshot->clear_result();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
LiteralFromShapedBuffer(*result));
- *session_module->mutable_result() = literal->ToProto();
+ *hlo_snapshot->mutable_result() = literal->ToProto();
return Status::OK();
}
StatusOr<std::unique_ptr<Literal>> LocalExecutable::LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer) {
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- backend_->stream_executor(shaped_buffer.device_ordinal()));
- return backend_->transfer_manager()->TransferLiteralFromDevice(executor,
+ TF_ASSIGN_OR_RETURN(auto stream,
+ backend_->BorrowStream(shaped_buffer.device_ordinal()));
+ return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(),
shaped_buffer);
}
@@ -288,19 +266,18 @@ StatusOr<ScopedShapedBuffer> LocalClient::LiteralToShapedBuffer(
TF_ASSIGN_OR_RETURN(auto scoped_buffer,
backend().transfer_manager()->AllocateScopedShapedBuffer(
literal.shape(), allocator, device_ordinal));
- TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
- backend().stream_executor(device_ordinal));
+ TF_ASSIGN_OR_RETURN(auto stream,
+ mutable_backend()->BorrowStream(device_ordinal));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- executor, literal, scoped_buffer));
+ stream.get(), literal, scoped_buffer));
return std::move(scoped_buffer);
}
StatusOr<std::unique_ptr<Literal>> LocalClient::ShapedBufferToLiteral(
const ShapedBuffer& shaped_buffer) {
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- backend().stream_executor(shaped_buffer.device_ordinal()));
- return backend().transfer_manager()->TransferLiteralFromDevice(executor,
+ TF_ASSIGN_OR_RETURN(auto stream, mutable_backend()->BorrowStream(
+ shaped_buffer.device_ordinal()));
+ return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
shaped_buffer);
}
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 5b408cc6b2..4d9e0d7cd9 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -78,11 +79,10 @@ class LocalExecutable {
// proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- SessionModule* session_module);
+ HloSnapshot* hlo_snapshot);
// Records the result of the computation in a SessionModule proto.
- Status RecordResult(const ShapedBuffer* result,
- SessionModule* session_module);
+ Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot);
// Returns a literal containing the contents of the given ShapedBuffer.
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index 507a2dc5f0..763653c685 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -1,7 +1,5 @@
# Description:
# The new XLA client libraries.
-#
-# This is NOT YET ready to use.
licenses(["notice"]) # Apache 2.0
@@ -41,9 +39,11 @@ cc_library(
name = "xla_builder",
srcs = ["xla_builder.cc"],
hdrs = ["xla_builder.h"],
+ visibility = ["//visibility:public"],
deps = [
":xla_computation",
"//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -52,6 +52,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shape_inference",
@@ -64,7 +65,7 @@ tf_cc_test(
srcs = ["xla_builder_test.cc"],
deps = [
":xla_builder",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index ae506317c2..aac7df4383 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -47,6 +48,7 @@ int64 GetUniqueId() {
// computation.
bool CanBeRoot(HloOpcode opcode) {
switch (opcode) {
+ case HloOpcode::kAfterAll:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kOutfeed:
@@ -59,6 +61,36 @@ bool CanBeRoot(HloOpcode opcode) {
} // namespace
+XlaOp operator-(const XlaOp& x) { return Neg(x); }
+XlaOp operator+(const XlaOp& x, const XlaOp& y) { return Add(x, y); }
+XlaOp operator-(const XlaOp& x, const XlaOp& y) { return Sub(x, y); }
+XlaOp operator*(const XlaOp& x, const XlaOp& y) { return Mul(x, y); }
+XlaOp operator/(const XlaOp& x, const XlaOp& y) { return Div(x, y); }
+XlaOp operator%(const XlaOp& x, const XlaOp& y) { return Rem(x, y); }
+
+XlaOp operator~(const XlaOp& x) { return Not(x); }
+XlaOp operator&(const XlaOp& x, const XlaOp& y) { return And(x, y); }
+XlaOp operator|(const XlaOp& x, const XlaOp& y) { return Or(x, y); }
+XlaOp operator^(const XlaOp& x, const XlaOp& y) { return Xor(x, y); }
+XlaOp operator<<(const XlaOp& x, const XlaOp& y) { return ShiftLeft(x, y); }
+
+XlaOp operator>>(const XlaOp& x, const XlaOp& y) {
+ XlaBuilder* builder = x.builder();
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x));
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Argument to >> operator does not have an integral type (%s).",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+ if (ShapeUtil::ElementIsSigned(shape)) {
+ return ShiftRightArithmetic(x, y);
+ } else {
+ return ShiftRightLogical(x, y);
+ }
+ });
+}
+
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
@@ -81,7 +113,7 @@ XlaBuilder::XlaBuilder(const string& computation_name)
XlaBuilder::~XlaBuilder() {}
-void XlaBuilder::NoteError(const Status& error) {
+XlaOp XlaBuilder::ReportError(const Status& error) {
CHECK(!error.ok());
if (die_immediately_on_error_) {
LOG(FATAL) << "error building computation: " << error;
@@ -91,19 +123,22 @@ void XlaBuilder::NoteError(const Status& error) {
first_error_ = error;
first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
}
+ return XlaOp(this);
}
-XlaOp XlaBuilder::NoteErrorOrReturn(
- const std::function<StatusOr<XlaOp>()>& op_creator) {
+XlaOp XlaBuilder::ReportErrorOrReturn(const StatusOr<XlaOp>& op) {
if (!first_error_.ok()) {
- return {};
+ return XlaOp(this);
}
- auto op = op_creator();
if (!op.ok()) {
- NoteError(op.status());
- return {};
+ return ReportError(op.status());
}
- return op.ConsumeValueOrDie();
+ return op.ValueOrDie();
+}
+
+XlaOp XlaBuilder::ReportErrorOrReturn(
+ const std::function<StatusOr<XlaOp>()>& op_creator) {
+ return ReportErrorOrReturn(op_creator());
}
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
@@ -207,7 +242,7 @@ XlaComputation XlaBuilder::BuildAndNoteError() {
DCHECK(parent_builder_ != nullptr);
auto build_status = Build();
if (!build_status.ok()) {
- parent_builder_->NoteError(
+ parent_builder_->ReportError(
AddStatus(build_status.status(),
tensorflow::strings::StrCat("error from: ", name_)));
return {};
@@ -315,7 +350,7 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
}
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
@@ -327,7 +362,7 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
XlaOp XlaBuilder::BinaryOp(
HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -383,7 +418,7 @@ XlaOp XlaBuilder::BinaryOp(
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
const XlaOp& ehs) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -430,7 +465,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = literal.shape();
*instr.mutable_literal() = literal.ToProto();
@@ -440,7 +475,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
@@ -461,7 +496,7 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
const string& name) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!parameter_numbers_.insert(parameter_number).second) {
return InvalidArgument("parameter %lld already registered",
@@ -476,7 +511,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
XlaOp XlaBuilder::Broadcast(
const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
const Shape& shape,
@@ -498,6 +533,14 @@ XlaOp XlaBuilder::Broadcast(
});
}
+XlaOp XlaBuilder::BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return InDimBroadcast(shape, operand, broadcast_dimensions);
+ });
+}
+
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
TF_RETURN_IF_ERROR(first_error_);
@@ -510,7 +553,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -530,7 +573,7 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
std::vector<int64> limits(shape.dimensions().begin(),
@@ -545,7 +588,7 @@ XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -566,7 +609,7 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -584,7 +627,7 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
int64 dimension) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
@@ -603,7 +646,7 @@ XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
const PaddingConfig& padding_config) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -624,7 +667,7 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
ShapeInference::InferReshapeShape(
@@ -638,7 +681,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
std::iota(dimensions.begin(), dimensions.end(), 0);
@@ -648,7 +691,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (dimensions.size() <= 1) {
// Not collapsing anything, trivially we can return the operand versus
// enqueueing a trivial reshape.
@@ -690,21 +733,29 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
- *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ *instr.mutable_literal() = LiteralUtil::CreateR1U8(tag)->ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
}
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
const XlaOp& on_false) {
- return TernaryOp(HloOpcode::kSelect, pred, on_true, on_false);
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& true_shape, GetShape(on_true));
+ TF_ASSIGN_OR_RETURN(const Shape& false_shape, GetShape(on_false));
+ TF_RET_CHECK(ShapeUtil::IsTuple(true_shape) ==
+ ShapeUtil::IsTuple(false_shape));
+ HloOpcode opcode = ShapeUtil::IsTuple(true_shape) ? HloOpcode::kTupleSelect
+ : HloOpcode::kSelect;
+ return TernaryOp(opcode, pred, on_true, on_false);
+ });
}
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
@@ -718,7 +769,7 @@ XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
}
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
if (!ShapeUtil::IsTuple(tuple_shape)) {
@@ -767,7 +818,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
DotDimensionNumbers dimension_numbers;
@@ -780,7 +831,7 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -859,7 +910,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions(
const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -905,7 +956,7 @@ XlaOp XlaBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
@@ -992,7 +1043,7 @@ StatusOr<Window> XlaBuilder::MakeWindow(
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
const tensorflow::gtl::ArraySlice<int64> fft_length) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1009,23 +1060,144 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
}
XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Given shape to Infeed must have a layout");
}
- *instr.mutable_shape() = shape;
+ const Shape infeed_instruction_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ *instr.mutable_shape() = infeed_instruction_shape;
instr.set_infeed_config(config);
- return AddInstruction(std::move(instr), HloOpcode::kInfeed);
+
+ if (ShapeUtil::IsArray(shape) && sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
+ // TODO(b/110793772): Support tiled array-shaped infeeds.
+ return InvalidArgument(
+ "Tiled sharding is not yet supported for array-shaped infeeds");
+ }
+
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ return InvalidArgument(
+ "Replicated sharding is not yet supported for infeeds");
+ }
+
+ // The sharding is set by the client according to the data tuple shape.
+ // However, the shape of the infeed instruction is a tuple containing the
+ // data and a token. For tuple sharding type, the sharding must be changed
+ // to accommodate the token.
+ XlaOp infeed;
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_TUPLE) {
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ OpSharding infeed_instruction_sharding = *sharding();
+ // Arbitrarily assign the token to device 0.
+ *infeed_instruction_sharding.add_tuple_shardings() =
+ sharding_builder::AssignDevice(0);
+ XlaScopedShardingAssignment scoped_sharding(this,
+ infeed_instruction_sharding);
+ TF_ASSIGN_OR_RETURN(infeed,
+ AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ } else {
+ TF_ASSIGN_OR_RETURN(infeed,
+ AddInstruction(std::move(instr), HloOpcode::kInfeed));
+ }
+
+ // The infeed instruction produces a tuple of the infed data and a token
+ // type. Return XLA op containing the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto infeed_data;
+ *infeed_data.mutable_shape() = shape;
+ infeed_data.set_tuple_index(0);
+ return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
+ {infeed});
+ });
+}
+
+XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Infeed must have a layout");
+ }
+ const Shape infeed_instruction_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ *instr.mutable_shape() = infeed_instruction_shape;
+ instr.set_infeed_config(config);
+
+ if (ShapeUtil::IsArray(shape) && sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_OTHER) {
+ // TODO(b/110793772): Support tiled array-shaped infeeds.
+ return InvalidArgument(
+ "Tiled sharding is not yet supported for array-shaped infeeds");
+ }
+
+ if (sharding() &&
+ sharding()->type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+ return InvalidArgument(
+ "Replicated sharding is not yet supported for infeeds");
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
});
}
void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
- *instr.mutable_shape() = ShapeUtil::MakeNil();
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+
+ // Check and set outfeed shape.
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::Compatible(operand_shape, shape_with_layout)) {
+ return InvalidArgument(
+ "Outfeed shape %s must be compatible with operand shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(operand_shape).c_str());
+ }
+ *instr.mutable_outfeed_shape() = shape_with_layout;
+
+ instr.set_outfeed_config(outfeed_config);
+
+ TF_RETURN_IF_ERROR(
+ AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand})
+ .status());
+
+ // The outfeed instruction produces a token. However, existing users expect
+ // a nil shape (empty tuple). This should only be relevant if the outfeed is
+ // the root of a computation.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto tuple_instr;
+ *tuple_instr.mutable_shape() = ShapeUtil::MakeNil();
+
+ // The dummy tuple should have no sharding.
+ {
+ XlaScopedShardingAssignment scoped_sharding(this, OpSharding());
+ TF_ASSIGN_OR_RETURN(
+ XlaOp empty_tuple,
+ AddInstruction(std::move(tuple_instr), HloOpcode::kTuple, {}));
+ return empty_tuple;
+ }
+ });
+}
+
+XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
@@ -1042,14 +1214,34 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
instr.set_outfeed_config(outfeed_config);
- return AddInstruction(std::move(instr), HloOpcode::kOutfeed, {operand});
+ return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
+ {operand, token});
+ });
+}
+
+XlaOp XlaBuilder::CreateToken() {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
+ });
+}
+
+XlaOp XlaBuilder::AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (tokens.empty()) {
+ return InvalidArgument("AfterAll requires at least one operand");
+ }
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
});
}
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (tensorflow::str_util::StartsWith(call_target_name, "$")) {
return InvalidArgument(
@@ -1066,7 +1258,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
const string& channel_name,
int64 cost_estimate_ns, const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape;
instr.set_channel_name(channel_name);
@@ -1120,11 +1312,9 @@ XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
}
-// TODO(b/65209188): Create a dedicated lowering for Xor.
XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return Or(And(Not(lhs), rhs, broadcast_dimensions),
- And(lhs, Not(rhs), broadcast_dimensions));
+ return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
}
XlaOp XlaBuilder::Not(const XlaOp& operand) {
@@ -1223,7 +1413,7 @@ XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> permutation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1238,7 +1428,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
XlaOp XlaBuilder::Rev(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1251,13 +1441,31 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::Sort(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kSort, operand);
-}
-
-XlaOp XlaBuilder::SqrtF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(0.5),
- /*broadcast_dimensions=*/{});
+XlaOp XlaBuilder::Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ std::vector<const Shape*> operand_shape_ptrs;
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ operand_shape_ptrs.push_back(&keys_shape);
+ Shape values_shape;
+ if (values.has_value()) {
+ TF_ASSIGN_OR_RETURN(values_shape, GetShape(*values));
+ operand_shape_ptrs.push_back(&values_shape);
+ }
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, operand_shape_ptrs));
+ if (dimension == -1) {
+ TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
+ dimension = ShapeUtil::Rank(keys_shape) - 1;
+ }
+ instr.add_dimensions(dimension);
+ return values.has_value()
+ ? AddInstruction(std::move(instr), HloOpcode::kSort,
+ {keys, *values})
+ : AddInstruction(std::move(instr), HloOpcode::kSort, {keys});
+ });
}
XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
@@ -1267,7 +1475,7 @@ XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1279,7 +1487,7 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
@@ -1290,16 +1498,6 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
});
}
-XlaOp XlaBuilder::SquareF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(2.0),
- /*broadcast_dimensions=*/{});
-}
-
-XlaOp XlaBuilder::ReciprocalF32(const XlaOp& operand) {
- return BinaryOp(HloOpcode::kPower, operand, ConstantR0<float>(-1.0),
- /*broadcast_dimensions=*/{});
-}
-
XlaOp XlaBuilder::Neg(const XlaOp& operand) {
return UnaryOp(HloOpcode::kNegate, operand);
}
@@ -1313,13 +1511,12 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
if (!static_operands.empty()) {
return Unimplemented("static_operands is not supported in Map");
}
HloInstructionProto instr;
-
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@@ -1331,16 +1528,32 @@ XlaOp XlaBuilder::Map(tensorflow::gtl::ArraySlice<XlaOp> operands,
ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape,
dimensions));
+ const Shape& output_shape = instr.shape();
+ const int64 output_rank = ShapeUtil::Rank(output_shape);
AddCalledComputation(computation, &instr);
+ std::vector<XlaOp> new_operands(operands.begin(), operands.end());
+ for (XlaOp& new_operand : new_operands) {
+ TF_ASSIGN_OR_RETURN(Shape shape, GetShape(new_operand));
+ const int64 rank = ShapeUtil::Rank(shape);
+ if (rank != output_rank) {
+ TF_ASSIGN_OR_RETURN(new_operand,
+ InDimBroadcast(output_shape, new_operand, {}));
+ TF_ASSIGN_OR_RETURN(shape, GetShape(new_operand));
+ }
+ if (!ShapeUtil::SameDimensions(output_shape, shape)) {
+ TF_ASSIGN_OR_RETURN(new_operand,
+ AddBroadcastSequence(output_shape, new_operand));
+ }
+ }
- return AddInstruction(std::move(instr), HloOpcode::kMap, operands);
+ return AddInstruction(std::move(instr), HloOpcode::kMap, new_operands);
});
}
XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
tensorflow::gtl::ArraySlice<XlaOp> parameters,
const Shape& shape) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Check the number of parameters per RNG distribution.
@@ -1378,7 +1591,7 @@ XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, const XlaOp& init) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
// Infer shape.
@@ -1400,7 +1613,7 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
@@ -1425,7 +1638,7 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
const XlaComputation& false_computation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& predicate_shape, GetShape(predicate));
@@ -1457,13 +1670,14 @@ XlaOp XlaBuilder::Reduce(
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
+
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferReduceShape(
operand_shape, init_shape, dimensions_to_reduce,
@@ -1482,7 +1696,7 @@ XlaOp XlaBuilder::Reduce(
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
std::vector<int64> all_dimnos(ShapeUtil::Rank(operand_shape));
std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
@@ -1495,7 +1709,7 @@ XlaOp XlaBuilder::ReduceWindow(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1518,7 +1732,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1542,7 +1756,7 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1565,7 +1779,7 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
const XlaOp& offset, const XlaOp& mean,
const XlaOp& variance, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1590,7 +1804,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
const XlaOp& batch_mean, const XlaOp& batch_var,
const XlaOp& grad_output, float epsilon,
int64 feature_index) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1611,14 +1825,40 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
});
}
-XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+XlaOp XlaBuilder::CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
+ auto b = CreateSubBuilder("sum");
+ b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
+ b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
+ TF_ASSIGN_OR_RETURN(auto computation, b->Build());
+ return CrossReplicaSum(operand, computation, replica_group_ids,
+ /*channel_id=*/tensorflow::gtl::nullopt);
+ });
+}
+
+XlaOp XlaBuilder::CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (channel_id.has_value()) {
+ return Unimplemented("channel_id is not supported in AllReduce");
+ }
+ HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+ for (int64 replica_group_id : replica_group_ids) {
+ instr.add_replica_group_ids(replica_group_id);
+ }
+
+ AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
{operand});
@@ -1631,7 +1871,7 @@ XlaOp XlaBuilder::SelectAndScatter(
tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
return SelectAndScatterWithGeneralPadding(
operand, select, window_dimensions, window_strides,
@@ -1648,7 +1888,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const XlaOp& source, const XlaOp& init_value,
const XlaComputation& scatter) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
@@ -1676,7 +1916,7 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
const int mantissa_bits) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
@@ -1690,20 +1930,51 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
}
void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
- NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Send HLO takes two operands: a data operand and a token. Generate the
+ // token to pass into the send.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
+
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
- // Send instruction produces a tuple of {aliased operand, U32 context}.
+ HloInstructionProto send_done_instr;
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ send_done_instr.set_channel_id(handle.handle());
+ return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
+ {send});
+ });
+}
+
+XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Send instruction produces a tuple of {aliased operand, U32 context,
+ // token}.
+ HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
- *instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(
- XlaOp send,
- AddInstruction(std::move(instr), HloOpcode::kSend, {operand}));
+ *send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ send_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp send,
+ AddInstruction(std::move(send_instr), HloOpcode::kSend,
+ {operand, token}));
HloInstructionProto send_done_instr;
- *send_done_instr.mutable_shape() = ShapeUtil::MakeNil();
+ *send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
send_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
{send});
@@ -1711,18 +1982,60 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
}
XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
- return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Recv HLO takes a single token operand. Generate the token to pass into
+ // the Recv and RecvDone instructions.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto token_instr;
+ *token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
+ TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
+ HloOpcode::kAfterAll, {}));
+
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ recv_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
- *instr.mutable_shape() =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- instr.set_channel_id(handle.handle());
- TF_ASSIGN_OR_RETURN(XlaOp recv,
- AddInstruction(std::move(instr), HloOpcode::kRecv, {}));
+ HloInstructionProto recv_done_instr;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
+ recv_done_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv_done,
+ AddInstruction(std::move(recv_done_instr),
+ HloOpcode::kRecvDone, {recv}));
+
+ // The RecvDone instruction produces a tuple of the data and a token
+ // type. Return XLA op containing the data.
+ // TODO(b/80000000): Remove this when clients have been updated to handle
+ // tokens.
+ HloInstructionProto recv_data;
+ *recv_data.mutable_shape() = shape;
+ recv_data.set_tuple_index(0);
+ return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
+ {recv_done});
+ });
+}
+
+XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ // Recv instruction produces a tuple of {receive buffer, U32 context,
+ // token}.
+ HloInstructionProto recv_instr;
+ *recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
+ {shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
+ recv_instr.set_channel_id(handle.handle());
+ TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
+ HloOpcode::kRecv, {token}));
HloInstructionProto recv_done_instr;
- *recv_done_instr.mutable_shape() = shape;
+ *recv_done_instr.mutable_shape() =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
recv_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
{recv});
@@ -1966,9 +2279,526 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
return &instructions_[op.handle()];
}
-XlaOp XlaBuilder::UnimplementedOp() {
- NoteError(Unimplemented("Op not implemented"));
- return {};
+// Enqueues a "retrieve parameter value" instruction for a parameter that was
+// passed to the computation.
+XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
+ const string& name) {
+ return builder->Parameter(parameter_number, shape, name);
+}
+
+// Enqueues a constant with the value of the given literal onto the
+// computation.
+XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal) {
+ return builder->ConstantLiteral(literal);
+}
+
+XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ return operand.builder()->Broadcast(operand, broadcast_sizes);
+}
+
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return operand.builder()->BroadcastInDim(operand, shape,
+ broadcast_dimensions);
+}
+
+XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config) {
+ return operand.builder()->Pad(operand, padding_value, padding_config);
+}
+
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ return operand.builder()->Reshape(operand, dimensions, new_sizes);
+}
+
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ return operand.builder()->Reshape(operand, new_sizes);
+}
+
+XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return operand.builder()->Collapse(operand, dimensions);
+}
+
+XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides) {
+ return operand.builder()->Slice(operand, start_indices, limit_indices,
+ strides);
+}
+
+XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno) {
+ return operand.builder()->SliceInDim(operand, start_index, limit_index,
+ stride, dimno);
+}
+
+XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ return operand.builder()->DynamicSlice(operand, start_indices, slice_sizes);
+}
+
+XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices) {
+ return operand.builder()->DynamicUpdateSlice(operand, update, start_indices);
+}
+
+XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension) {
+ return builder->ConcatInDim(operands, dimension);
+}
+
+void Trace(const string& tag, const XlaOp& operand) {
+ return operand.builder()->Trace(tag, operand);
+}
+
+XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false) {
+ return pred.builder()->Select(pred, on_true, on_false);
+}
+
+XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements) {
+ return builder->Tuple(elements);
+}
+
+XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index) {
+ return tuple_data.builder()->GetTupleElement(tuple_data, index);
+}
+
+XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs) {
+ return lhs.builder()->Dot(lhs, rhs);
+}
+
+XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers);
+}
+
+XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ return lhs.builder()->Conv(lhs, rhs, window_strides, padding);
+}
+
+XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides,
+ padding);
+}
+
+XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return lhs.builder()->ConvWithGeneralDimensions(lhs, rhs, window_strides,
+ padding, dimension_numbers);
+}
+
+XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding,
+ dimension_numbers);
+}
+
+XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return lhs.builder()->ConvGeneralDilated(lhs, rhs, window_strides, padding,
+ lhs_dilation, rhs_dilation,
+ dimension_numbers);
+}
+
+XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length) {
+ return operand.builder()->Fft(operand, fft_type, fft_length);
+}
+
+XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
+ return builder->Infeed(shape, config);
+}
+
+void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return operand.builder()->Outfeed(operand, shape_with_layout, outfeed_config);
+}
+
+XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands) {
+ return builder->Call(computation, operands);
+}
+
+XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape) {
+ return builder->CustomCall(call_target_name, operands, shape);
+}
+
+XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape) {
+ return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape);
+}
+
+XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return real.builder()->Complex(real, imag, broadcast_dimensions);
+}
+
+XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
+
+XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
+
+XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ return operand.builder()->Reduce(operand, init_value, computation,
+ dimensions_to_reduce);
+}
+
+XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation) {
+ return operand.builder()->ReduceAll(operand, init_value, computation);
+}
+
+XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ return operand.builder()->ReduceWindow(operand, init_value, computation,
+ window_dimensions, window_strides,
+ padding);
+}
+
+XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ return operand.builder()->ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ padding);
+}
+
+XlaOp CrossReplicaSum(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
+ return operand.builder()->CrossReplicaSum(operand, replica_group_ids);
+}
+
+XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+ return operand.builder()->CrossReplicaSum(operand, computation,
+ replica_group_ids, channel_id);
+}
+
+XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter) {
+ return operand.builder()->SelectAndScatter(operand, select, window_dimensions,
+ window_strides, padding, source,
+ init_value, scatter);
+}
+
+XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter) {
+ return operand.builder()->SelectAndScatterWithGeneralPadding(
+ operand, select, window_dimensions, window_strides, padding, source,
+ init_value, scatter);
+}
+
+XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
+
+XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return y.builder()->Atan2(y, x, broadcast_dimensions);
+}
+
+XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); }
+
+XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); }
+
+XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); }
+
+XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); }
+
+XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); }
+
+XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); }
+
+XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); }
+
+XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); }
+
+XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); }
+
+XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); }
+
+XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); }
+
+XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); }
+
+XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
+
+XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
+
+XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
+}
+
+XlaOp IsFinite(const XlaOp& operand) {
+ return operand.builder()->IsFinite(operand);
+}
+
+XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) {
+ return operand.builder()->ConvertElementType(operand, new_element_type);
+}
+
+XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type) {
+ return operand.builder()->BitcastConvertType(operand, new_element_type);
+}
+
+XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
+
+XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ return operand.builder()->Transpose(operand, permutation);
+}
+
+XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ return operand.builder()->Rev(operand, dimensions);
+}
+
+XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension) {
+ return keys.builder()->Sort(keys, std::move(values), dimension);
+}
+
+XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max) {
+ return min.builder()->Clamp(min, operand, max);
+}
+
+XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands) {
+ return builder->Map(operands, computation, dimensions, static_operands);
+}
+
+XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape) {
+ return mu.builder()->RngNormal(mu, sigma, shape);
+}
+
+XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape) {
+ return a.builder()->RngUniform(a, b, shape);
+}
+
+XlaOp While(const XlaComputation& condition, const XlaComputation& body,
+ const XlaOp& init) {
+ return init.builder()->While(condition, body, init);
+}
+
+XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation) {
+ return predicate.builder()->Conditional(predicate, true_operand,
+ true_computation, false_operand,
+ false_computation);
+}
+
+XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits) {
+ return operand.builder()->ReducePrecision(operand, exponent_bits,
+ mantissa_bits);
+}
+
+XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ return input.builder()->Gather(input, gather_indices, dimension_numbers,
+ window_bounds);
+}
+
+void Send(const XlaOp& operand, const ChannelHandle& handle) {
+ return operand.builder()->Send(operand, handle);
+}
+
+XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle) {
+ return builder->Recv(shape, handle);
+}
+
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle) {
+ return operand.builder()->SendWithToken(operand, token, handle);
+}
+
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle) {
+ return token.builder()->RecvWithToken(token, shape, handle);
+}
+
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config) {
+ return token.builder()->InfeedWithToken(token, shape, config);
+}
+
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config) {
+ return operand.builder()->OutfeedWithToken(operand, token, shape_with_layout,
+ outfeed_config);
+}
+
+XlaOp CreateToken(XlaBuilder* builder) { return builder->CreateToken(); }
+
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens) {
+ return builder->AfterAll(tokens);
+}
+
+XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index) {
+ return operand.builder()->BatchNormTraining(operand, scale, offset, epsilon,
+ feature_index);
+}
+
+XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index) {
+ return operand.builder()->BatchNormInference(
+ operand, scale, offset, mean, variance, epsilon, feature_index);
+}
+
+XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index) {
+ return operand.builder()->BatchNormGrad(operand, scale, batch_mean, batch_var,
+ grad_output, epsilon, feature_index);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 2b3013a91c..2be6f4a553 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -18,10 +18,12 @@ limitations under the License.
#include <map>
#include <string>
+#include <type_traits>
#include <utility>
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -46,17 +48,23 @@ class XlaBuilder;
// instruction as an operand.
class XlaOp {
public:
- XlaOp() : handle_(0), builder_(nullptr) {}
- ~XlaOp() {}
+ XlaOp() : handle_(-1), builder_(nullptr) {
+ static_assert(std::is_trivially_destructible<XlaOp>::value,
+ "XlaOp should be trivially destructible");
+ }
+ ~XlaOp() = default;
- const XlaBuilder* builder() const { return builder_; }
+ XlaBuilder* builder() const { return builder_; }
- bool operator==(const XlaOp& rhs) const {
- return handle_ == rhs.handle_ && builder_ == rhs.builder_;
- }
+ // Returns true if the XlaOp represents valid, non-erroneous value.
+ bool valid() const { return handle_ >= 0; }
+
+ // Returns true if the XlaOp was created by the XlaOp() constructor and
+ // not returned by a builder.
+ bool IsUninitialized() const { return builder_ == nullptr; }
- bool operator!=(const XlaOp& rhs) const {
- return handle_ != rhs.handle_ || builder_ != rhs.builder_;
+ bool IsIdenticalTo(const XlaOp& rhs) const {
+ return handle_ == rhs.handle_ && builder_ == rhs.builder_;
}
friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
@@ -65,6 +73,7 @@ class XlaOp {
}
private:
+ explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
XlaOp(int64 handle, XlaBuilder* builder)
: handle_(handle), builder_(builder) {}
@@ -72,10 +81,38 @@ class XlaOp {
friend class XlaBuilder;
+ // < 0 means "invalid handle".
int64 handle_;
- XlaBuilder* builder_; // Not owned.
+
+ // Not owned. Non-null for any handle returned by XlaBuilder, even if the
+ // handle is invalid.
+ XlaBuilder* builder_;
};
+// Arithmetic operator overloads for the XlaOp type.
+XlaOp operator-(const XlaOp& x);
+XlaOp operator+(const XlaOp& x, const XlaOp& y);
+XlaOp operator-(const XlaOp& x, const XlaOp& y);
+XlaOp operator*(const XlaOp& x, const XlaOp& y);
+XlaOp operator/(const XlaOp& x, const XlaOp& y);
+XlaOp operator%(const XlaOp& x, const XlaOp& y);
+
+// Bitwise operator overloads for the XlaOp type.
+XlaOp operator~(const XlaOp& x);
+XlaOp operator&(const XlaOp& x, const XlaOp& y);
+XlaOp operator|(const XlaOp& x, const XlaOp& y);
+XlaOp operator^(const XlaOp& x, const XlaOp& y);
+XlaOp operator<<(const XlaOp& x, const XlaOp& y);
+// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
+// a right logical shift.
+XlaOp operator>>(const XlaOp& x, const XlaOp& y);
+
+// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
+// semantics might be surprising since their result types are usually 'bool'.
+// Further programmers may expect == to be a structural equality.
+// We also choose not to overload any of the mutating operators (e.g., +=, -=)
+// because the semantics might be misleading — XLA computations are immutable.
+
// A convenient interface for building up computations.
//
// Thread-compatible.
@@ -122,6 +159,93 @@ class XlaBuilder {
die_immediately_on_error_ = enabled;
}
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Returns an error if the convolution dimension numbers have conflicts.
+ static Status Validate(const ConvolutionDimensionNumbers& dnum);
+
+ // Returns a new XlaBuilder whose resultant Computation is used only by this
+ // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
+ // behavior as the parent.
+ std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status. Note that all ops that have been enqueued will be moved to the
+ // computation being returned.
+ StatusOr<XlaComputation> Build();
+
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent XlaBuilder and returns an empty computation if building failed.
+ // This function is intended to be used where the returned XlaComputation is
+ // only used by the parent XlaBuilder and hence further operation on the
+ // returned XlaComputation will simply be error'ed out if an error occurred
+ // while building this computation. If the built computation is to be used by
+ // a XlaBuilder other than the parent XlaBuilder then Build() should be used
+ // instead.
+ XlaComputation BuildAndNoteError();
+
+ // Returns a subgraph that roots on the given root. If the root is not a
+ // compile-time constant (see `IsConstant`), returns an error.
+ //
+ // This will copy the needed ops/computations to the subgraph.
+ StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
+
+ // Returns the first error that was encountered while building the
+ // computation. When an error is encountered, by default we return a vacuous
+ // XlaOp and inform the user of the error that occurred while
+ // building the computation when they make a final call to Build().
+ //
+ // See also set_die_immediately_on_error().
+ Status first_error() const { return first_error_; }
+
+ // Returns the shape of the given op.
+ StatusOr<Shape> GetShape(const XlaOp& op) const;
+
+ // Returns the (inferred) result for the current computation's shape.
+ StatusOr<ProgramShape> GetProgramShape() const;
+
+ // Reports an error to the builder, by
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to
+ // Build()/GetShape()/...)
+ // * dying if die_immediately_on_error_ is true.
+ // Returns an XlaOp with an invalid handle but a valid builder. This value can
+ // be returned in place of a value in APIs that return an XlaOp.
+ XlaOp ReportError(const Status& error);
+
+ // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
+ // If the Status was an error, reports the error to builder and returns an
+ // invalid XlaOp handle.
+ XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
+
+ // A helper function that runs a function that returns a StatusOr<XlaOp> and
+ // returns an XlaOp.
+ XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on any parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`.
+ //
+ // This tests whether a computation is a compile-time constant without
+ // evaluating the computation.
+ StatusOr<bool> IsConstant(const XlaOp& operand) const;
+
+ private:
// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(int64 parameter_number, const Shape& shape,
@@ -194,6 +318,27 @@ class XlaBuilder {
XlaOp Broadcast(const XlaOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+ // Performs in-dimension-style broadcast.
+ //
+ // Operand specifies the input to be broadcast. "shape" is expected output
+ // shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+ // Dimension numbers in broadcast_dimensions map to individual dimensions
+ // of the operand, and specify what dimension of the output shape they
+ // should be broadcast.
+ // e.g.
+ // Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+ // and dimension of shape is [2,2].
+ // Specifying {1} as brodcast_dimension will generate output
+ // [1 , 2]
+ // [1 , 2]
+ // On the other hand, specifying {0} as broadcast_dimension
+ // will generate output
+ // [1 , 1]
+ // [2 , 2]
+ XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
@@ -342,26 +487,6 @@ class XlaBuilder {
XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers);
- // Default dimension numbers used for a 2D convolution.
- static constexpr int64 kConvBatchDimension = 0;
- static constexpr int64 kConvFeatureDimension = 1;
- static constexpr int64 kConvFirstSpatialDimension = 2;
- static constexpr int64 kConvSecondSpatialDimension = 3;
- static constexpr int64 kConvKernelOutputDimension = 0;
- static constexpr int64 kConvKernelInputDimension = 1;
- static constexpr int64 kConvKernelFirstSpatialDimension = 2;
- static constexpr int64 kConvKernelSecondSpatialDimension = 3;
-
- // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
- // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
- // the kernel operand
- // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
- static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
- int num_spatial_dims = 2);
-
- // Returns an error if the convolution dimension numbers have conflicts.
- static Status Validate(const ConvolutionDimensionNumbers& dnum);
-
// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
@@ -408,6 +533,8 @@ class XlaBuilder {
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(const Shape& shape, const string& config = "");
+ XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
@@ -417,6 +544,9 @@ class XlaBuilder {
// will occur.
void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
const string& outfeed_config);
+ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
// Enqueues a call instruction onto the computation.
XlaOp Call(const XlaComputation& computation,
@@ -528,9 +658,35 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- // Returns the sum of the operand value across all replicas. All replicas
- // supply one input to the sum and all replicas receive the resulting sum.
- XlaOp CrossReplicaSum(const XlaOp& operand);
+ // Returns the sum of the operand value within each subgroup of replicas. All
+ // replicas supply one input to the sum and all replicas receive the resulting
+ // sum for each subgroup.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+
+ // Enqueues an operation that do an AllReduce of the operand cross cores. Here
+ // AllReduce means doing a reduction on the input operand cross cores and then
+ // broadcasting the reduction result to those cores. The reduction function is
+ // defined by `computation`, which should be a commutative computation on
+ // scalars, e.g., add, min, or max. The way that AllReduce is applied is
+ // configured by:
+ //
+ // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+ // replicas belong to one group. Allreduce will be applied within subgroups.
+ // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+ // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+ //
+ // - `channel_id`: for Allreduce nodes from different models, if they have the
+ // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+ // applied cross models.
+ //
+ // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id =
+ tensorflow::gtl::nullopt);
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
@@ -601,16 +757,6 @@ class XlaBuilder {
// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(const XlaOp& operand);
- // Enqueues a float32 sqrt instruction onto the computation.
- // (float32 is specified as there is an implicit float32 0.5f constant
- // exponent).
- XlaOp SqrtF32(const XlaOp& operand);
-
- // Enqueues a float32 square instruction onto the computation.
- // (float32 is specified as there is an implicit float32 2.0f constant
- // exponent).
- XlaOp SquareF32(const XlaOp& operand);
-
// Enqueues a lhs^rhs computation onto the computation.
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
@@ -633,14 +779,6 @@ class XlaBuilder {
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- // Enqueues a float32 reciprocal instruction onto the computation.
- // (float32 is specified as there is an implicit float32 -1.0f constant
- // exponent).
- //
- // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
- // shape of the operand.
- XlaOp ReciprocalF32(const XlaOp& operand);
-
// Enqueues a negate instruction onto the computation.
XlaOp Neg(const XlaOp& operand);
@@ -655,7 +793,24 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> dimensions);
// Enqueues a sort (as increasing order) instruction onto the computation.
- XlaOp Sort(const XlaOp& operand);
+ // If only keys are provided:
+ // * If the keys are an rank-1 tensor (an array), the result is a sorted array
+ // of keys, in ascending order.
+ // * If the keys have higher rank, the keys are sorted along the provided
+ // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+ // value of 0 will indepenently sort every column, and a dimension value of 1
+ // will independently sort each row. If no dimension number is provided, then
+ // the last dimension is chosen by default.
+ //
+ // If both keys and values are provided:
+ // * The keys and the values must tensors with the same dimensions. The
+ // element types of the tensors may be different.
+ // * The result is a tuple that consists of a sorted tensor of keys (along the
+ // provided dimension, as above) as the first element, and a tensor with their
+ // corresponding values as the second element.
+ XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
@@ -696,19 +851,23 @@ class XlaBuilder {
// Enqueues a Send node onto the computation, to send the given operand to
// a Recv instruction that shares the same channel handle.
void Send(const XlaOp& operand, const ChannelHandle& handle);
+ XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp CreateToken();
+
+ // Enqueues an AfterAll operation with no operands producing a token-shaped
+ // value.
+ XlaOp AfterAll(tensorflow::gtl::ArraySlice<XlaOp> tokens);
// Enqueues a Recv node onto the computation. The data comes from a Send
// instruction that shares the same channel handle and its shape must
// be the same as the given shape.
XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
-
- // Returns true if 'operand' is a compile-time constant. A compile-time
- // constant does not depend on any parameters, or on stateful operators such
- // as `RngNormal` or `Infeed`.
- //
- // This tests whether a computation is a compile-time constant without
- // evaluating the computation.
- StatusOr<bool> IsConstant(const XlaOp& operand) const;
+ XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
// Normalizes operand across spatial and batch dimensions for each feature.
//
@@ -748,47 +907,6 @@ class XlaBuilder {
const XlaOp& grad_output, float epsilon,
int64 feature_index);
- // Returns a new XlaBuilder whose resultant Computation is used only by this
- // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
- // behavior as the parent.
- std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
-
- // Builds the computation with the requested operations, or returns a non-ok
- // status. Note that all ops that have been enqueued will be moved to the
- // computation being returned.
- StatusOr<XlaComputation> Build();
-
- // Builds the computation with the requested operations, or notes an error in
- // the parent XlaBuilder and returns an empty computation if building failed.
- // This function is intended to be used where the returned XlaComputation is
- // only used by the parent XlaBuilder and hence further operation on the
- // returned XlaComputation will simply be error'ed out if an error occurred
- // while building this computation. If the built computation is to be used by
- // a XlaBuilder other than the parent XlaBuilder then Build() should be used
- // instead.
- XlaComputation BuildAndNoteError();
-
- // Returns a subgraph that roots on the given root. If the root is not a
- // compile-time constant (see `IsConstant`), returns an error.
- //
- // This will copy the needed ops/computations to the subgraph.
- StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op) const;
-
- // Returns the first error that was encountered while building the
- // computation. When an error is encountered, by default we return a vacuous
- // XlaOp and inform the user of the error that occurred while
- // building the computation when they make a final call to Build().
- //
- // See also set_die_immediately_on_error().
- Status first_error() const { return first_error_; }
-
- // Returns the shape of the given op.
- StatusOr<Shape> GetShape(const XlaOp& op) const;
-
- // Returns the (inferred) result for the current computation's shape.
- StatusOr<ProgramShape> GetProgramShape() const;
-
- private:
StatusOr<XlaOp> AddInstruction(
HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands = {});
@@ -796,17 +914,6 @@ class XlaBuilder {
void AddCalledComputation(const XlaComputation& computation,
HloInstructionProto* instr);
- // Notes that the error occurred by:
- // * storing it internally and capturing a backtrace if it's the first error
- // (this deferred value will be produced on the call to Build())
- // * dying if die_immediately_on_error_ is true
- void NoteError(const Status& error);
-
- XlaOp NoteErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
-
- // Helper method that creates an empty op and notes error.
- XlaOp UnimplementedOp();
-
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
// Internal helper method that does the building for an arbitrary unary op.
@@ -902,16 +1009,1032 @@ class XlaBuilder {
bool die_immediately_on_error_ = false;
XlaBuilder* parent_builder_{nullptr};
+
+ friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
+ const Shape& shape, const string& name);
+ friend XlaOp ConstantLiteral(XlaBuilder* builder,
+ const LiteralSlice& literal);
+ template <typename NativeT>
+ friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+ friend XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2(
+ XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantFromArray(XlaBuilder* builder,
+ const Array<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+ template <typename NativeT>
+ friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+ template <typename NativeT>
+ friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+ friend XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ friend XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ friend XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ friend XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
+ int64 limit_index, int64 stride, int64 dimno);
+
+ friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+ friend XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ int64 dimension);
+
+ friend void Trace(const string& tag, const XlaOp& operand);
+
+ friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
+ const XlaOp& on_false);
+ friend XlaOp Tuple(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> elements);
+ friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+ friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+ friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+ friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+ friend XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneral(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config);
+ friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+ friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+ friend XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+ friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Conj(const XlaOp& operand);
+ friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Not(const XlaOp& operand);
+ friend XlaOp ShiftLeft(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+ friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+ friend XlaOp ReduceWindow(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+ friend XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids);
+ friend XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id);
+ friend XlaOp SelectAndScatter(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+ friend XlaOp Abs(const XlaOp& operand);
+ friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp Exp(const XlaOp& operand);
+ friend XlaOp Expm1(const XlaOp& operand);
+ friend XlaOp Floor(const XlaOp& operand);
+ friend XlaOp Ceil(const XlaOp& operand);
+ friend XlaOp Round(const XlaOp& operand);
+ friend XlaOp Log(const XlaOp& operand);
+ friend XlaOp Log1p(const XlaOp& operand);
+ friend XlaOp Sign(const XlaOp& operand);
+ friend XlaOp Clz(const XlaOp& operand);
+ friend XlaOp Cos(const XlaOp& operand);
+ friend XlaOp Sin(const XlaOp& operand);
+ friend XlaOp Tanh(const XlaOp& operand);
+ friend XlaOp Real(const XlaOp& operand);
+ friend XlaOp Imag(const XlaOp& operand);
+ friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+ friend XlaOp IsFinite(const XlaOp& operand);
+ friend XlaOp ConvertElementType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp BitcastConvertType(const XlaOp& operand,
+ PrimitiveType new_element_type);
+ friend XlaOp Neg(const XlaOp& operand);
+ friend XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+ friend XlaOp Rev(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ friend XlaOp Sort(XlaOp keys, tensorflow::gtl::optional<XlaOp> values,
+ int64 dimension);
+ friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+ friend XlaOp Map(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands);
+ friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
+ const Shape& shape);
+ friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+ friend XlaOp While(const XlaComputation& condition,
+ const XlaComputation& body, const XlaOp& init);
+ friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+ friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+ friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend void Send(const XlaOp& operand, const ChannelHandle& handle);
+ friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+ friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+ friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+ friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+ friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config);
+ friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+ friend XlaOp CreateToken(XlaBuilder* builder);
+ friend XlaOp AfterAll(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> tokens);
+};
+
+// RAII-style object: sets the current sharding assignment in builder on
+// construction, and sets back to the previous assignment on destruction.
+class XlaScopedShardingAssignment {
+ public:
+ XlaScopedShardingAssignment(xla::XlaBuilder* builder,
+ tensorflow::gtl::optional<OpSharding> sharding)
+ : builder_(builder), prev_sharding_(builder->sharding()) {
+ SetSharding(sharding);
+ }
+
+ XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
+ XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
+ delete;
+
+ ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
+
+ private:
+ void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
+ if (sharding.has_value()) {
+ builder_->SetSharding(sharding.value());
+ } else {
+ builder_->ClearSharding();
+ }
+ }
+
+ xla::XlaBuilder* const builder_;
+ tensorflow::gtl::optional<OpSharding> prev_sharding_;
};
+// Free functions for building XlaOps. The intention is that these will
+// become the public API for building XlaOps rather than calling methods on
+// XlaBuilder directly.
+
+// Enqueues a "retrieve parameter value" instruction for a parameter that was
+// passed to the computation.
+XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
+ const string& name);
+
+// Enqueues a constant with the value of the given literal onto the
+// computation.
+XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
+
+// Enqueues a constant onto the computation. Methods are templated on the
+// native host type (NativeT) which corresponds to a specific XLA
+// PrimitiveType as given in the following table:
+//
+// Native Type PrimitiveType
+// -----------------------------
+// bool PRED
+// int32 S32
+// int64 S64
+// uint32 U32
+// uint64 U64
+// float F32
+// double F64
+//
+// Note: not all primitive types defined in xla_data.proto have a
+// corresponding native type yet.
+template <typename NativeT>
+XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values);
+XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
+template <typename NativeT>
+XlaOp ConstantR2(XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values);
+template <typename NativeT>
+XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values);
+template <typename NativeT>
+XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout);
+template <typename NativeT>
+XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values);
+
+// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the
+// computation. The vector has size 'length' and every element has the value
+// 'value'.
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
+
+// Adds dimensions to an array by duplicating the data in the array.
+//
+// The new dimensions are inserted on the left, i.e. if
+// broadcast_sizes has values {a0, ..., aN} and the operand shape
+// has dimensions {b0, ..., bM} then the shape of the output has
+// dimensions {a0, ..., aN, b0, ..., bM}.
+//
+// The new dimensions index into copies of the operand, i.e.
+//
+// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
+XlaOp Broadcast(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+// Performs in-dimension-style broadcast.
+//
+// Operand specifies the input to be broadcast. "shape" is expected output
+// shape. "broadcast_dimensions" are the dimensions to be broadcasting into.
+// Dimension numbers in broadcast_dimensions map to individual dimensions
+// of the operand, and specify what dimension of the output shape they
+// should be broadcast.
+// e.g.
+// Say operand = [1, 2], i.e., a 1D tensor with 2 elements.
+// and dimension of shape is [2,2].
+// Specifying {1} as brodcast_dimension will generate output
+// [1 , 2]
+// [1 , 2]
+// On the other hand, specifying {0} as broadcast_dimension
+// will generate output
+// [1 , 1]
+// [2 , 2]
+XlaOp BroadcastInDim(
+ const XlaOp& operand, const Shape& shape,
+ const tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+// Enqueues a pad operation onto the computation that pads the given value on
+// the edges as well as between the elements of the input. padding_config
+// specifies the padding amount for each dimension.
+XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
+ const PaddingConfig& padding_config);
+
+// Enqueues an operation onto the computation that flattens the operand based
+// on the dimension order (major/slowest-varying to minor/fastest-varying)
+// given, followed by reshaping it into the shape with the given dimension
+// sizes (also major to minor). Conceptually, this is a limited form of
+// "shape casting".
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+// Enqueues an operation onto the computation that collapses the operand, from
+// first to last dimension (C order), then reshapes it to the given dimension
+// sizes. Conceptually, this is a limited form of "shape casting".
+XlaOp Reshape(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+// Wrapper for Reshape.
+// Enqueues an operation to collapse the provided dimensions; e.g. an
+// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
+// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
+// be a consecutive, in-order subsequence of the operand dimensions.
+//
+// Note that collapsing a single dimension does nothing:
+//
+// {256} collapsing {0} => {256}
+// {1} collapsing {0} => {1}
+//
+// Collapsing multiple dimensions produces a single result dimension:
+//
+// {256, 2} collapsing {0,1} => {512}
+// {256, 2, 3} collapsing {0,1} => {512, 3}
+//
+// This could potentially cause data to be moved -- it provides a more
+// structured form of reshaping than an arbitrary Reshape operation.
+XlaOp Collapse(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+// Enqueues a slice operation onto the computation that slices the operand
+// from the start indices to the limit indices; e.g.
+//
+// x
+// [ 0 1 2 3 ]
+// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
+// [ 8 9 a b ]
+//
+// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
+// range notation.
+// The strides parameter determines the stride over the slice
+XlaOp Slice(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+// Enqueues a slice operation in a given dimension, taking all other
+// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
+// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
+// for:
+//
+// array[:, 2:4:1, :]
+XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
+ int64 stride, int64 dimno);
+
+// Enqueues a slice operation onto the computation that slices the 'operand'
+// from dynamic start indices which are passed in 'start_indices'.
+// The size of the slice in each dimension is passed in 'slice_sizes',
+// which specify the end point of exclusive slice intervals in each
+// dimension [start, start + size).
+// The shape of 'start_indices' must be rank == 1, with dimension size
+// equal to the rank of the 'operand'.
+// Slice index calculations are computed modulo input dimension sizes to
+// prevent dynamic start indices from generating out-of-bound array accesses.
+XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+// Enqueues a dynamic update slice operation onto the computation, which
+// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
+// The shape of 'update' determines the shape of the slice of 'operand'
+// which is updated.
+// The indices specified in 'start_indices' specify the offset of the slice
+// of 'operand' which is updated.
+//
+// update = {10, 11} // calculated at runtime.
+// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ]
+// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11]
+// [7 8 9] [7 8 9 ]
+//
+// The shape of 'start_indices' must be rank == 1, with dimension size
+// equal to the rank of the 'operand'.
+// Slice index calculations are computed modulo update dimension sizes to
+// prevent dynamic start indices from generating out-of-bound array accesses.
+XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
+ const XlaOp& start_indices);
+
+// Enqueues a concatenate instruction onto the computation. 'operands' must
+// have >= 1 entry.
+XlaOp ConcatInDim(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands, int64 dimension);
+
+// Enqueue a tracing operation onto the computation; the computation will emit
+// a logging message with the operand.
+void Trace(const string& tag, const XlaOp& operand);
+
+// Enqueues a conditional-move-like select operation onto the computation;
+// predicated on pred, selects between on_true and on_false.
+XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
+
+// Enqueues a tuple-creation instruction onto the computation.
+XlaOp Tuple(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> elements);
+
+// Enqueues a tuple-element-get instruction onto the computation.
+XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
+
+// Enqueues an equal-to comparison instruction onto the computation.
+XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a not-equal comparison instruction onto the computation.
+XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a greater-or-equal comparison instruction onto the computation.
+XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a greater-than comparison instruction onto the computation.
+XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a less-than comparison instruction onto the computation.
+XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a less-or-equal comparison instruction onto the computation.
+XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a dot instruction onto the computation.
+XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs);
+
+// Enqueues a general dot instruction onto the computation.
+XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, which uses the
+// default convolution dimension numbers.
+XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration in the format returned by MakePadding().
+XlaOp ConvWithGeneralPadding(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided dimension numbers configuration.
+XlaOp ConvWithGeneralDimensions(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration as well as the dimension numbers.
+XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues a convolution instruction onto the computation, with the caller
+// provided padding configuration, dilation factors and dimension numbers.
+XlaOp ConvGeneralDilated(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+// Enqueues an FFT instruction onto the computation, of the given type and
+// with the given FFT length.
+XlaOp Fft(const XlaOp& operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+
+// Enqueues an infeed instruction onto the computation, which writes data of
+// the given shape to the infeed buffer of the device.
+XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
+ const string& config = "");
+
+// Variant of Infeed which takes a token-shaped operand and produces a
+// two-element tuple containing the data value and a token-shaped value.
+// Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
+ const string& config = "");
+
+// Enqueues an outfeed instruction onto the computation. This instruction
+// generates outgoing data transfers for the given data.
+//
+// shape_with_layout communicates the laid out shape that we want to outfeed
+// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
+// will occur.
+void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+// Variant of Outfeed which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
+ const Shape& shape_with_layout,
+ const string& outfeed_config);
+
+// Enqueues a call instruction onto the computation.
+XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<XlaOp> operands);
+
+// Enqueues a custom call instruction onto the computation.
+// During code generation, a call instruction is emitted which targets a
+// symbol with the name |call_target_name|. The |operands| are passed to the
+// call instruction. |shape| is the resultant shape.
+XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const Shape& shape);
+
+// Enqueues a pseudo-op to represent host-side computation data-dependencies.
+// During code generation, host send and receive operations will be generated
+// to transfer |operands| to the host and a single result of |shape| back to
+// the device. Host send/recv operations are emitted using |channel_name|.
+// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
+// instruction scheduling.
+XlaOp HostCompute(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const string& channel_name, int64 cost_estimate_ns,
+ const Shape& shape);
+
+// The following methods enqueue element-wise binary arithmetic operations
+// onto the computation. The shapes of the operands have to match unless one
+// of the operands is a scalar, or an explicit broadcast dimension is given
+// (see g3doc for more details).
+
+// Enqueues a complex compose instruction onto the computation.
+XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a complex conjugate instruction onto the computation.
+XlaOp Conj(const XlaOp& operand);
+
+// Enqueues an add instruction onto the computation.
+XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a subtract instruction onto the computation.
+XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a multiply instruction onto the computation.
+XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a divide instruction onto the computation.
+XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a remainder instruction onto the computation.
+XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a max instruction onto the computation.
+XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues a min instruction onto the computation.
+XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Element-wise logical operators
+XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+XlaOp Not(const XlaOp& operand);
+
+XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+XlaOp ShiftRightArithmetic(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+XlaOp ShiftRightLogical(
+ const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Reduces an array among the provided dimensions, given "computation" as a
+// reduction operator.
+XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+// Convenience wrapper around the above that reduces all the dimensions in the
+// operand shape.
+XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation);
+
+// Enqueues a windowed reduce instruction onto the computation.
+XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+// As ReduceWindow(), but the padding is given in the format
+// returned by MakePadding().
+XlaOp ReduceWindowWithGeneralPadding(
+ const XlaOp& operand, const XlaOp& init_value,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+// Returns the sum of the operand value within each subgroup of replicas. All
+// replicas supply one input to the sum and all replicas receive the resulting
+// sum for each subgroup.
+XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
+
+// Enqueues an operation that do an AllReduce of the operand cross cores. Here
+// AllReduce means doing a reduction on the input operand cross cores and then
+// broadcasting the reduction result to those cores. The reduction function is
+// defined by `computation`, which should be a commutative computation on
+// scalars, e.g., add, min, or max. The way that AllReduce is applied is
+// configured by:
+//
+// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+// replicas belong to one group. Allreduce will be applied within subgroups.
+// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+//
+// - `channel_id`: for Allreduce nodes from different models, if they have the
+// same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+// applied cross models.
+//
+// TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<ChannelHandle>&
+ channel_id = tensorflow::gtl::nullopt);
+
+// Enqueues an operation that scatters the `source` array to the selected
+// indices of each window.
+XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding, const XlaOp& source,
+ const XlaOp& init_value, const XlaComputation& scatter);
+
+// As SelectAndScatter(), but the padding is given in the format
+// returned by MakePadding().
+XlaOp SelectAndScatterWithGeneralPadding(
+ const XlaOp& operand, const XlaComputation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const XlaOp& source, const XlaOp& init_value,
+ const XlaComputation& scatter);
+
+// Enqueues an abs instruction onto the computation.
+XlaOp Abs(const XlaOp& operand);
+
+// Enqueues a atan2 instruction onto the computation.
+XlaOp Atan2(const XlaOp& y, const XlaOp& x,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues an exp instruction onto the computation.
+XlaOp Exp(const XlaOp& operand);
+
+// Enqueues an expm1 instruction onto the computation.
+XlaOp Expm1(const XlaOp& operand);
+
+// Enqueues a floor instruction onto the computation.
+XlaOp Floor(const XlaOp& operand);
+
+// Enqueues a ceil instruction onto the computation.
+XlaOp Ceil(const XlaOp& operand);
+
+// Enqueues a round instruction onto the computation, rounding to nearest even
+// with half-way cases rounding away from zero.
+XlaOp Round(const XlaOp& operand);
+
+// Enqueues an log instruction (natural logarithm) onto the computation.
+XlaOp Log(const XlaOp& operand);
+
+// Enqueues an log1p instruction (log(x+1)) onto the computation.
+XlaOp Log1p(const XlaOp& operand);
+
+// Enqueues a sign instruction onto the computation.
+XlaOp Sign(const XlaOp& operand);
+
+// Enqueues a count leading zeros instruction onto the computation.
+XlaOp Clz(const XlaOp& operand);
+
+// Enqueues a cosine instruction onto the computation.
+XlaOp Cos(const XlaOp& operand);
+
+// Enqueues a sine instruction onto the computation.
+XlaOp Sin(const XlaOp& operand);
+
+// Enqueues a tanh instruction onto the computation.
+XlaOp Tanh(const XlaOp& operand);
+
+// Enqueues a real-part instruction onto the computation.
+XlaOp Real(const XlaOp& operand);
+
+// Enqueues an imaginary-part instruction onto the computation.
+XlaOp Imag(const XlaOp& operand);
+
+// Enqueues a lhs^rhs computation onto the computation.
+XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+// Enqueues an operator that tests if the operand's values are finite, i.e.,
+// not Inf or NaN. Defined only for floating-point types. Returns an array of
+// booleans with the same shape where entries are true iff the corresponding
+// entry was NaN.
+XlaOp IsFinite(const XlaOp& operand);
+
+// Enqueues a convert instruction onto the computation that changes the
+// element type of the operand array to primitive_type.
+XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
+
+// Enqueues a no-op instruction onto the computation that changes
+// the element type of the operand array to primitive_type. The
+// bit-widths of the source and destination element types must be
+// identical.
+XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
+
+// Enqueues a negate instruction onto the computation.
+XlaOp Neg(const XlaOp& operand);
+
+// Enqueues a transpose instruction onto the computation.
+XlaOp Transpose(const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+// Enqueues a reverse instruction onto the computation. The order of the
+// elements in the given dimensions is reversed (i.e., the element at index i
+// is moved to index dimension_size - 1 - i).
+XlaOp Rev(const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+
+// Enqueues a sort (as increasing order) instruction onto the computation.
+// If only keys are provided:
+// * If the keys are an rank-1 tensor (an array), the result is a sorted array
+// of keys, in ascending order.
+// * If the keys have higher rank, the keys are sorted along the provided
+// dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
+// value of 0 will indepenently sort every column, and a dimension value of 1
+// will independently sort each row. If no dimension number is provided, then
+// the last dimension is chosen by default.
+//
+// If both keys and values are provided:
+// * The keys and the values must tensors with the same dimensions. The
+// element types of the tensors may be different.
+// * The result is a tuple that consists of a sorted tensor of keys (along the
+// provided dimension, as above) as the first element, and a tensor with their
+// corresponding values as the second element.
+XlaOp Sort(XlaOp keys,
+ tensorflow::gtl::optional<XlaOp> values = tensorflow::gtl::nullopt,
+ int64 dimension = -1);
+
+// Enqueues a clamp instruction onto the computation.
+XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
+
+// Enqueues a map instruction onto the computation.
+XlaOp Map(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> operands,
+ const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<XlaOp> static_operands = {});
+
+// Enqueues a N(mu, sigma) random number generation instruction onto the
+// computation.
+XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
+
+// Enqueues a U(a, b) random number generation instruction onto the
+// computation. Returns values in the semi-open interval [a, b).
+XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
+
+// Enqueues a while node onto the computation.
+XlaOp While(const XlaComputation& condition, const XlaComputation& body,
+ const XlaOp& init);
+
+// Enqueues a conditional node onto the computation.
+XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
+ const XlaComputation& true_computation,
+ const XlaOp& false_operand,
+ const XlaComputation& false_computation);
+
+// Enqueues a ReducePrecision node onto the computation.
+XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
+ const int mantissa_bits);
+
+// Enqueues a Gather node onto the computation.
+XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
+ const GatherDimensionNumbers& dimension_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
+// Enqueues a Send node onto the computation, to send the given operand to
+// a Recv instruction that shares the same channel handle.
+void Send(const XlaOp& operand, const ChannelHandle& handle);
+
+// Variant of Send which takes a token-shaped operand and produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
+ const ChannelHandle& handle);
+
+// Enqueues a Recv node onto the computation. The data comes from a Send
+// instruction that shares the same channel handle and its shape must
+// be the same as the given shape.
+XlaOp Recv(XlaBuilder* builder, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Variant of Recv which takes a token-shaped operand and produces a two-element
+// tuple containing the data value and a token-shaped value. Tokens are used
+// for ordering side-effecting operations.
+// TODO(b/110532604): Replace all uses of the non-token form with this variant.
+XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
+ const ChannelHandle& handle);
+
+// Enqueues an operation (AfterAll) with no operands that produces a
+// token-shaped value. Tokens are used for ordering side-effecting operations.
+// This is a separate method from AfterAll to facility the removal of
+// operand-less AfterAll instructions.
+// TODO(b/110532604): Remove this function when all tokens are derived from a
+// single token generated or passed into the entry computation.
+XlaOp CreateToken(XlaBuilder* builder);
+
+// Enqueues an AfterAll instruction which produces a token-shaped value and
+// takes a variadic number of token-shaped operands. The number of operands must
+// be greater than zero. Used for joining tokens.
+XlaOp AfterAll(XlaBuilder* builder, tensorflow::gtl::ArraySlice<XlaOp> tokens);
+
+// Normalizes operand across spatial and batch dimensions for each feature.
+//
+// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
+// is the normalized result and batch_mean and batch_var are the mean and
+// variance, respectively, across batch for the operand.
+XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, float epsilon,
+ int64 feature_index);
+
+// Normalizes operand across spatial and batch dimensions for each feature.
+//
+// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
+// computing `mean` and `variance` for each batch inside the operation. It
+// uses the input `mean` and `variance` instead as estimated values. The
+// purpose of this op is to reduce latency in inference, hence the name
+// `BatchNormInference`.
+//
+// The output has the same shape as `operand`, and contains the normalized
+// values for each batch.
+XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& offset, const XlaOp& mean,
+ const XlaOp& variance, float epsilon,
+ int64 feature_index);
+
+// Calculates the gradients of a batch norm op.
+//
+// The inputs `batch_mean` and `batch_var` represent the mean and variance
+// across the batch.
+//
+// Returns a tuple of three elements:
+// - grad_operand: Gradient with respect to input `operand`
+// - grad_offset: Gradient with respect to input `offset`
+// - grad_scale: Gradient with respect to input `scale`
+XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
+ const XlaOp& batch_mean, const XlaOp& batch_var,
+ const XlaOp& grad_output, float epsilon,
+ int64 feature_index);
+
+// Implementation details below this point.
+
template <typename NativeT>
XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(*Literal::CreateR0<NativeT>(value));
+ return ConstantLiteral(*LiteralUtil::CreateR0<NativeT>(value));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- return ConstantLiteral(*Literal::CreateR1<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1<NativeT>(values));
}
template <typename NativeT>
@@ -923,44 +2046,44 @@ XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
}
inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(*Literal::CreateR1(values));
+ return ConstantLiteral(*LiteralUtil::CreateR1(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(*Literal::CreateR2<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateFromArray<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateFromArray<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateFromArrayWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(*Literal::CreateR2FromArray2D<NativeT>(values));
+ return ConstantLiteral(*LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}
template <typename NativeT>
XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantLiteral(
- *Literal::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}
template <typename NativeT>
@@ -979,34 +2102,96 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
return ConstantFromArray(values);
}
-// RAII-style object: sets the current sharding assignment in builder on
-// construction, and sets back to the previous assignment on destruction.
-class XlaScopedShardingAssignment {
- public:
- XlaScopedShardingAssignment(xla::XlaBuilder* builder,
- tensorflow::gtl::optional<OpSharding> sharding)
- : builder_(builder), prev_sharding_(builder->sharding()) {
- SetSharding(sharding);
- }
+// Free function template implementations.
- XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
- XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
- delete;
+template <typename NativeT>
+XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR0<NativeT>(value));
+}
- ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1<NativeT>(values));
+}
- private:
- void SetSharding(const tensorflow::gtl::optional<OpSharding>& sharding) {
- if (sharding.has_value()) {
- builder_->SetSharding(sharding.value());
- } else {
- builder_->ClearSharding();
- }
- }
+template <typename NativeT>
+XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
+ Literal literal(ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
+ literal.PopulateWithValue(value);
+ return ConstantLiteral(builder, literal);
+}
- xla::XlaBuilder* const builder_;
- tensorflow::gtl::optional<OpSharding> prev_sharding_;
-};
+inline XlaOp ConstantR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR1(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2(XlaBuilder* builder,
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return ConstantLiteral(builder, *LiteralUtil::CreateR2<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
+ const Array<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateFromArray<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
+ const Array2D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
+ const Array2D<NativeT>& values) {
+ return ConstantLiteral(builder,
+ *LiteralUtil::CreateR2FromArray2D<NativeT>(values));
+}
+
+template <typename NativeT>
+XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
+ const Array3D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantLiteral(
+ builder,
+ *LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
+}
+
+template <typename NativeT>
+XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
+ const Array3D<NativeT>& values) {
+ return ConstantFromArray(builder, values);
+}
+
+template <typename NativeT>
+XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
+ const Array4D<NativeT>& values,
+ const Layout& layout) {
+ return ConstantFromArrayWithLayout(builder, values, layout);
+}
+
+template <typename NativeT>
+XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
+ const Array4D<NativeT>& values) {
+ return ConstantFromArray(builder, values);
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
index 2df3ea3af0..3b8beb2c78 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc
@@ -53,16 +53,86 @@ class XlaBuilderTest : public ::testing::Test {
TEST_F(XlaBuilderTest, OnePlusTwo) {
XlaBuilder b(TestName());
- b.Add(b.ConstantR0<float>(1.0), b.ConstantR0<float>(2.0));
+ Add(ConstantR0<float>(&b, 1.0), ConstantR0<float>(&b, 2.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
}
+TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
+ auto test_unary_operator =
+ [&](std::function<XlaOp(XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(ConstantR0<int32>(&b, 1));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+ test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant()));
+ test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
+ auto test_binary_operator =
+ [&](std::function<XlaOp(XlaOp, XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(ConstantR0<int32>(&b, 1), ConstantR0<int32>(&b, 2));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+
+ test_binary_operator([](XlaOp x, XlaOp y) { return x + y; },
+ op::Add(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x - y; },
+ op::Subtract(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x * y; },
+ op::Multiply(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x / y; },
+ op::Divide(op::Constant(), op::Constant()));
+
+ test_binary_operator([](XlaOp x, XlaOp y) { return x & y; },
+ op::And(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x | y; },
+ op::Or(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x ^ y; },
+ op::Xor(op::Constant(), op::Constant()));
+ test_binary_operator([](XlaOp x, XlaOp y) { return x << y; },
+ op::ShiftLeft(op::Constant(), op::Constant()));
+ test_binary_operator(
+ [](XlaOp x, XlaOp y) { return x >> y; },
+ op::ShiftRightArithmetic(op::Constant(), op::Constant()));
+
+ auto test_unsigned_binary_operator =
+ [&](std::function<XlaOp(XlaOp, XlaOp)> op,
+ ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
+ XlaBuilder b(TestName());
+ op(ConstantR0<uint32>(&b, 1), ConstantR0<uint32>(&b, 2));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ auto root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, matches_pattern);
+ };
+ test_unsigned_binary_operator(
+ [](XlaOp x, XlaOp y) { return x >> y; },
+ op::ShiftRightLogical(op::Constant(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
+ XlaBuilder b(TestName());
+ ConstantR0<float>(&b, 1) >> ConstantR0<float>(&b, 2);
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Argument to >> operator does not have an integral type"));
+}
+
TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
- b.Add(x, b.ConstantR0<float>(1.0));
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
+ Add(x, ConstantR0<float>(&b, 1.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Parameter(), op::Broadcast(op::Constant())));
@@ -72,9 +142,9 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
XlaBuilder b(TestName());
const auto& x_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
const auto& y_shape = ShapeUtil::MakeShape(S32, {2, 4});
- auto x = b.Parameter(0, x_shape, "x");
- auto y = b.Parameter(1, y_shape, "y");
- auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
+ auto x = Parameter(&b, 0, x_shape, "x");
+ auto y = Parameter(&b, 1, y_shape, "y");
+ auto add = Add(x, y, /*broadcast_dimensions=*/{0, 1});
TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add));
EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape));
@@ -86,8 +156,8 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
TEST_F(XlaBuilderTest, XPlusX) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x");
- b.Add(x, x);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {1, 3, 5, 7}), "x");
+ Add(x, x);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0)));
@@ -95,9 +165,9 @@ TEST_F(XlaBuilderTest, XPlusX) {
TEST_F(XlaBuilderTest, ShapeInferenceError) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
- b.Add(x, y);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
+ auto y = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
+ Add(x, y);
auto statusor = BuildHloModule(&b);
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(), HasSubstr("shape inference"));
@@ -105,12 +175,12 @@ TEST_F(XlaBuilderTest, ShapeInferenceError) {
TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
XlaBuilder b_call("add");
- b_call.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
+ Parameter(&b_call, 0, ShapeUtil::MakeShape(PRED, {}), "x");
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
- auto y = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "y");
- b.Add(x, y);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x");
+ auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y");
+ Add(x, y);
auto statusor = BuildHloModule(&b);
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
@@ -119,16 +189,16 @@ TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
TEST_F(XlaBuilderTest, Call) {
XlaBuilder b_call("the_only_to_apply");
- auto p0 = b_call.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
- auto p1 = b_call.Parameter(1, ShapeUtil::MakeShape(F32, {}), "p1");
- b_call.Add(p0, p1);
+ auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
+ auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
+ Add(p0, p1);
TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- auto one = b.ConstantR0<float>(1);
- auto two = b.ConstantR0<float>(2);
- b.Add(b.Call(call, {x, y}), b.Call(call, {one, two}));
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto one = ConstantR0<float>(&b, 1);
+ auto two = ConstantR0<float>(&b, 2);
+ Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Add(op::Call(op::Parameter(), op::Parameter()),
@@ -137,9 +207,9 @@ TEST_F(XlaBuilderTest, Call) {
TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
- b.Add(x, y);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
+ auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
+ Add(x, y);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
// Expected:
@@ -158,9 +228,9 @@ TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
- b.Add(x, y, /*broadcast_dimensions=*/{0, 1});
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
+ auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
+ Add(x, y, /*broadcast_dimensions=*/{0, 1});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
// The binary operation has in-dim broadcast and degenerate broadcast, should
@@ -183,9 +253,10 @@ TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
XlaBuilder b1("b1");
- auto p0 = b1.Parameter(0, ShapeUtil::MakeShape(F32, {}), "p0");
+ auto p0 = Parameter(&b1, 0, ShapeUtil::MakeShape(F32, {}), "p0");
XlaBuilder builder("main");
- builder.Add(p0, p0);
+ auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "p");
+ Add(p, p0);
auto statusor = builder.Build();
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(
@@ -196,8 +267,8 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
- b.Reshape(x, /*new_sizes=*/{6, 35});
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
+ Reshape(x, /*new_sizes=*/{6, 35});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Reshape(op::Parameter()));
@@ -205,8 +276,8 @@ TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
- b.Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
+ Reshape(x, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Reshape(op::Transpose(op::Parameter())));
@@ -214,25 +285,38 @@ TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
TEST_F(XlaBuilderTest, Transpose) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
- b.Transpose(x, /*permutation=*/{1, 0});
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ Transpose(x, /*permutation=*/{1, 0});
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Transpose(op::Parameter()));
}
-// TODO(b/65209188): Create a dedicated lowering for Xor.
-TEST_F(XlaBuilderTest, Xor) {
+TEST_F(XlaBuilderTest, ReportError) {
XlaBuilder b(TestName());
- auto x = b.Parameter(0, ShapeUtil::MakeShape(PRED, {}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(PRED, {}), "y");
- b.Xor(x, y);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+ Add(b.ReportError(InvalidArgument("a test error")), x);
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
+}
+
+TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
+ XlaBuilder b(TestName());
+ StatusOr<XlaOp> op(ConstantR0<float>(&b, 1.0));
+ Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0));
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
auto root = module->entry_computation()->root_instruction();
- LOG(ERROR) << module->ToString();
- EXPECT_THAT(root,
- op::Or(op::And(op::Not(op::Parameter(0)), op::Parameter(1)),
- op::And(op::Parameter(0), op::Not(op::Parameter(1)))));
+ EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
+}
+
+TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
+ XlaBuilder b(TestName());
+ StatusOr<XlaOp> op(InvalidArgument("a test error"));
+ Add(b.ReportErrorOrReturn(op), ConstantR0<float>(&b, 2.0));
+ auto statusor = b.Build();
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/BUILD b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD
new file mode 100644
index 0000000000..a26b20c861
--- /dev/null
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/BUILD
@@ -0,0 +1,18 @@
+# Description:
+# Python API for shardings in XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_library(
+ name = "xla_sharding",
+ srcs = ["xla_sharding.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ "//tensorflow/compiler/xla/python_api:types",
+ "//tensorflow/compiler/xla/python_api:xla_shape",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
new file mode 100644
index 0000000000..abd10b164e
--- /dev/null
+++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py
@@ -0,0 +1,204 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Experimental support for defining XLA shardings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import xla_shape
+from tensorflow.core.framework import attr_value_pb2
+
+
+class Sharding(object):
+ """A class to support adding sharding attributes to Ops.
+
+ Use the factory constructors and then call apply_to_tensor:
+ Sharding.replicate().apply_to_tensor(tensor)
+ """
+
+ def __init__(self, proto=None):
+ """Do not use this constructor; use the factory functions below."""
+ self._proto = proto
+
+ @classmethod
+ def replicate(cls):
+ """Returns a replicated sharding attribute.
+
+ This causes an op to be computed in its entirety independently on all
+ cores in the XLA device.
+ """
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
+
+ @classmethod
+ def assign_device(cls, core):
+ """Returns an AssignDevice sharding attribute.
+
+ This causes an op to be computed in its entirety only on one core in
+ the XLA device.
+ Args:
+ core: The core to assign this Op to.
+ """
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.MAXIMAL,
+ tile_assignment_dimensions=[1],
+ tile_assignment_devices=[core]))
+
+ @classmethod
+ def tile(cls, tile_shape, tile_assignment):
+ """Returns a Tiled sharding attribute.
+
+ This causes an op to be partially computed on multiple cores in the
+ XLA device.
+
+ Args:
+ tile_shape: A xla_shape.Shape describing the tile shape that each core
+ will compute.
+ The tile shape does not need to be divisible by the tile assignment.
+ tile_assignment: An np.ndarray describing the topology of the tiling and
+ which device will compute which part of the topology.
+
+ Raises:
+ TypeError: tile_assignment was not of np.array type or tile_shape was
+ not of xla_shape.Shape type.
+
+ TODO(jmolloy): This concept is nefarious and is not
+ something we really want to expose to users (especially as the
+ contract for tile_assignment is very strict).
+ """
+ if not isinstance(tile_assignment, np.ndarray):
+ raise TypeError('Tile assignment must be of type np.ndarray')
+ if not isinstance(tile_shape, xla_shape.Shape):
+ raise TypeError('Tile shape must be of type xla_shape.Shape')
+ dims = list(tile_assignment.shape)
+ flattened_devices = tile_assignment.reshape(-1, order='C')
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.OTHER,
+ tile_shape=tile_shape.message,
+ tile_assignment_dimensions=dims,
+ tile_assignment_devices=list(flattened_devices)))
+
+ @classmethod
+ def split(cls, tensor, split_dimension, num_devices):
+ """Returns a Sharding that splits a tensor across a dimension.
+
+ This creates a Tiled attribute, similar to tile(), but easier to use for the
+ common case of tiling a tensor N ways in one dimension.
+
+ Args:
+ tensor: A tf.Tensor to split.
+ split_dimension: The dimension number to split.
+ num_devices: The number of cores to split `tensor` over.
+
+ Raises:
+ ValueError: The tensor to split was smaller in the split dimension than
+ the number of devices to split over.
+ """
+ tensor.shape.assert_is_fully_defined()
+ shape = tensor.shape.as_list()
+ if shape[split_dimension] < num_devices:
+ raise ValueError('Split dimension was smaller than the required number '
+ 'of splits: shape=%r, dimension=%r, num_devices=%r',
+ shape, split_dimension, num_devices)
+
+ tile_shape = shape
+ tile_shape[split_dimension] = int(
+ math.ceil(tile_shape[split_dimension] / num_devices))
+ tile_shape_proto = xla_data_pb2.Shape(
+ element_type=xla_data_pb2.F32, dimensions=tile_shape)
+
+ tile_assignment_dims = [1] * len(shape)
+ tile_assignment_dims[split_dimension] = num_devices
+
+ return Sharding(
+ proto=xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.OTHER,
+ tile_shape=tile_shape_proto,
+ tile_assignment_dimensions=tile_assignment_dims,
+ tile_assignment_devices=range(num_devices)))
+
+ def apply_to_tensor(self, tensor):
+ """Applies this Sharding attribute to `tensor`."""
+ if len(tensor.op.outputs) > 1:
+ proto = self._get_or_create_tuple_proto(tensor.op)
+ # We can't mutate an element of old_proto.tuple_shardings, so create
+ # a new proto.
+ tuple_shardings = list(proto.tuple_shardings)
+ tuple_shardings[tensor.value_index] = self._proto
+ proto = xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
+ else:
+ proto = self._proto
+
+ attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
+ # TODO(jmolloy): This need to be seriously revisited before declaring this
+ # API available for public use.
+ # pylint: disable=protected-access
+ tensor.op._set_attr('_XlaSharding', attr_value)
+
+ @property
+ def proto(self):
+ """Return the sharding protobuf of type xla_data_pb2.OpSharding."""
+ return self._proto
+
+ def _get_or_create_tuple_proto(self, op):
+ try:
+ attr = op.get_attr('_XlaSharding')
+ proto = xla_data_pb2.OpSharding()
+ proto.ParseFromString(attr)
+ return proto
+ except ValueError:
+ return self._create_tuple_proto(op)
+
+ def _create_tuple_proto(self, op):
+ shardings = [
+ xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
+ for _ in op.outputs
+ ]
+ return xla_data_pb2.OpSharding(
+ type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
+
+
+# Helpers for the above factory functions that allow easy application of
+# shardings, for example:
+# tensor = xla_sharding.replicate(tensor)
+
+
+def replicate(tensor):
+ Sharding.replicate().apply_to_tensor(tensor)
+ return tensor
+
+
+def assign_device(tensor, device):
+ Sharding.assign_device(device).apply_to_tensor(tensor)
+ return tensor
+
+
+def tile(tensor, tile_shape, tile_assignment):
+ Sharding.tile(tile_shape, tile_assignment).apply_to_tensor(tensor)
+ return tensor
+
+
+def split(tensor, split_dimension, num_devices):
+ Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(tensor)
+ return tensor
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 89cafa1a7d..15eeb2ea13 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} // namespace
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
+ if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
+ // Opaque and token types have empty layouts.
+ return Layout();
+ }
+
// A Layout proto corresponds to a single array, not a tuple.
- DCHECK(!ShapeUtil::IsTuple(shape));
+ CHECK(ShapeUtil::IsArray(shape));
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
@@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
SetToDefaultLayout(&element_shape);
}
shape->clear_layout();
- } else if (ShapeUtil::IsOpaque(*shape)) {
- shape->clear_layout();
- } else {
+ } else if (ShapeUtil::IsArray(*shape)) {
shape->mutable_layout()->set_format(DENSE);
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
minor_to_major->Resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major);
+ } else {
+ // Opaque, token types etc. have no layout.
+ shape->clear_layout();
}
}
@@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
}
return Status::OK();
- } else if (ShapeUtil::IsOpaque(shape)) {
- if (shape.has_layout()) {
- return InvalidArgument("opaque should not have a layout field");
- }
- return Status::OK();
- } else {
- // Array shape.
+ } else if (ShapeUtil::IsArray(shape)) {
if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout",
ShapeUtil::HumanString(shape).c_str());
}
return ValidateLayoutForShape(shape.layout(), shape);
+ } else {
+ // Token, opaque, etc. shape.
+ if (shape.has_layout()) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return Status::OK();
}
}
@@ -181,7 +189,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return InvalidArgument("a single Layout is not valid for tuple shapes");
}
- if (ShapeUtil::IsOpaque(shape)) {
+ if (!ShapeUtil::IsArray(shape)) {
+ if (layout.minor_to_major_size() != 0 ||
+ layout.padded_dimensions_size() != 0) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a non-trivial layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
return Status::OK();
}
@@ -234,6 +248,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
}
+ if (layout.format() == SPARSE) {
+ if (!layout.padded_dimensions().empty()) {
+ return InvalidArgument("Sparse layout has padded dimensions");
+ }
+ }
+
return Status::OK();
}
@@ -273,7 +293,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
+ if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) {
return false;
}
@@ -323,7 +343,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
// Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); });
- } else if (ShapeUtil::IsOpaque(shape)) {
+ } else if (!ShapeUtil::IsArray(shape)) {
+ // Opaque, token types etc. ignore layout.
return true;
}
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
@@ -432,12 +453,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) {
- if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
- return false;
- }
if (ShapeUtil::IsTuple(lhs)) {
- if (ShapeUtil::TupleElementCount(lhs) !=
- ShapeUtil::TupleElementCount(rhs)) {
+ if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
+ ShapeUtil::TupleElementCount(rhs)) {
return false;
}
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
@@ -446,9 +464,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
}
return true;
- } else {
+ } else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
LayoutUtil::Equal(lhs.layout(), rhs.layout());
+ } else {
+ // Layouts of non-array and non-tuple shapes is ignored.
+ return true;
}
}
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
index 4fd1d818e3..e4c825450d 100644
--- a/tensorflow/compiler/xla/layout_util_test.cc
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
"elements, but shape is rank"));
}
+TEST_F(LayoutUtilTest, CopyTokenLayout) {
+ Shape src = ShapeUtil::MakeTokenShape();
+ Shape dst = ShapeUtil::MakeTokenShape();
+
+ // Layouts are trivially the same for token types and copying layouts should
+ // be a nop.
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
+TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
+ Shape src = ShapeUtil::MakeOpaqueShape();
+ Shape dst = ShapeUtil::MakeOpaqueShape();
+
+ // Layouts are trivially the same for opaque types and copying layouts should
+ // be a nop.
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
+TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
+ Shape src = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
+ Shape dst = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
@@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
}
+TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
+ // Opaque and token types trivially have layouts.
+ for (Shape shape :
+ {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
+ EXPECT_TRUE(LayoutUtil::HasLayout(shape));
+ LayoutUtil::ClearLayout(&shape);
+ EXPECT_TRUE(LayoutUtil::HasLayout(shape));
+ }
+}
+
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
new file mode 100644
index 0000000000..5db124b5a2
--- /dev/null
+++ b/tensorflow/compiler/xla/literal.cc
@@ -0,0 +1,1967 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal.h"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::strings::Printf;
+using tensorflow::strings::StrCat;
+
+namespace xla {
+
+namespace {
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian.
+//
+// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
+void ConvertEndianShort(string* bytes) {
+ CHECK_EQ(bytes->size() / 2, 0);
+ for (int64 i = 0; i < bytes->size(); i += 2) {
+ std::swap((*bytes)[i], (*bytes)[i + 1]);
+ }
+}
+
+void ConvertEndianShort(char* bytes, int64 size) {
+ CHECK_EQ(size / 2, 0);
+ for (int64 i = 0; i < size; i += 2) {
+ std::swap(bytes[i], bytes[i + 1]);
+ }
+}
+
+} // namespace
+
+LiteralBase::~LiteralBase() {}
+
+std::ostream& operator<<(std::ostream& out, const Literal& literal) {
+ out << literal.ToString();
+ return out;
+}
+
+Literal::StrideConfig::StrideConfig(
+ const Shape& source_shape, const Shape& dest_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : dimensions(dimensions),
+ base(dimensions.size(), 0),
+ step(dimensions.size(), 1) {
+ if (!dimensions.empty()) {
+ // Selects the shape with the largest minor dimension as the one upon
+ // which to run the tight stride loop.
+ if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
+ dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
+ minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
+ dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
+ } else {
+ minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
+ source_stride =
+ IndexUtil::GetDimensionStride(source_shape, minor_dimension);
+ }
+ minor_loop_size = dimensions[minor_dimension];
+ step[minor_dimension] = minor_loop_size;
+ }
+}
+
+Literal::Literal(const Shape& shape)
+ : Literal(shape, /*allocate_arrays=*/true) {}
+
+void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ SetPiece(subshape, &child_piece, allocate_arrays);
+
+ piece->emplace_back(std::move(child_piece));
+ }
+ } else if (ShapeUtil::IsArray(shape)) {
+ if (allocate_arrays) {
+ if (LayoutUtil::IsSparseArray(shape)) {
+ // For sparse arrays, the buffer must be of the size of the maximum
+ // number of sparse elements possible.
+ const int64 max_sparse_elements =
+ LayoutUtil::MaxSparseElements(shape.layout());
+ piece->set_buffer(
+ new char[max_sparse_elements *
+ ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
+ piece->set_sparse_indices(
+ new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
+ } else {
+ piece->set_buffer(new char[piece->size_bytes()]);
+ }
+ }
+ } else {
+ // If the shape is neither an array nor tuple, then it must be
+ // zero-sized. Otherwise, some memory needs to be allocated for it.
+ CHECK_EQ(piece->size_bytes(), 0);
+ }
+}
+
+Literal::Literal(const Shape& shape, bool allocate_arrays)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(LayoutUtil::HasLayout(*shape_));
+ root_piece_ = new Piece();
+ root_piece_->set_subshape(shape_.get());
+ CHECK(&root_piece_->subshape() == shape_.get());
+
+ SetPiece(*shape_, root_piece_, allocate_arrays);
+}
+
+Literal::~Literal() {
+ if (root_piece_ != nullptr) {
+ DeallocateBuffers();
+ delete root_piece_;
+ }
+}
+
+void Literal::DeallocateBuffers() {
+ root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (piece->buffer() != nullptr) {
+ delete[] piece->buffer();
+ delete piece->sparse_indices();
+ }
+ });
+}
+
+Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
+
+Literal& Literal::operator=(Literal&& other) {
+ DCHECK(&other.root_piece_->subshape() == other.shape_.get());
+ using std::swap;
+ swap(shape_, other.shape_);
+ swap(root_piece_, other.root_piece_);
+ DCHECK(&root_piece_->subshape() == shape_.get());
+
+ return *this;
+}
+
+std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
+ auto literal = MakeUnique<Literal>(shape);
+ literal->root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (ShapeUtil::IsArray(piece->subshape())) {
+ memset(piece->untyped_data(), 0, piece->size_bytes());
+ }
+ });
+ return literal;
+}
+
+const SparseIndexArray* LiteralBase::sparse_indices(
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).sparse_indices();
+}
+
+SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
+ return piece(shape_index).sparse_indices();
+}
+
+template <typename NativeT>
+Status Literal::CopySliceFromInternal(
+ const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
+ TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
+
+ auto linear_index = [](const Shape& shape,
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
+ };
+
+ if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
+ ShapeUtil::Rank(shape()) == 0) {
+ // If any of the two shapes are scalars, we can just call the StridedCopy()
+ // directly, and we know we will be copying only one value.
+ TF_RET_CHECK(copy_size.empty());
+ StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
+ src_literal.data<NativeT>(),
+ linear_index(src_literal.shape(), src_base), 0, 1);
+ } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
+ !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
+ // Perform copy if neither src nor dest has dimensions with zero element,
+ // otherwise it's a no-op.
+ TF_RET_CHECK(src_base.size() == dest_base.size());
+ TF_RET_CHECK(src_base.size() == copy_size.size());
+
+ // Scan the source from minor, stepping in copy size blocks, then within
+ // the index enumaration functor, do a strided copy advancing source index
+ // by one (walking through the minor dimension), and destination index by
+ // proper stride size at the matching dimension.
+ DimensionVector src_indexes(src_base.size(), 0);
+ DimensionVector dest_indexes(dest_base.size(), 0);
+ Literal::StrideConfig stride_config(src_literal.shape(), shape(),
+ copy_size);
+
+ auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ // Map from multi-dimensional index, to source index.
+ std::transform(indexes.begin(), indexes.end(), src_base.begin(),
+ src_indexes.begin(), std::plus<int64>());
+ // Map from multi-dimensional index, to destination index.
+ std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
+ dest_indexes.begin(), std::plus<int64>());
+
+ int64 src_index = linear_index(src_literal.shape(), src_indexes);
+ int64 dest_index = linear_index(shape(), dest_indexes);
+
+ // `this->` is needed to workaround MSVC bug: #16882
+ StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
+ src_literal.data<NativeT>(), src_index,
+ stride_config.source_stride, stride_config.minor_loop_size);
+ return true;
+ };
+
+ ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
+ stride_config.dimensions, stride_config.step,
+ copy_proc);
+ }
+ return Status::OK();
+}
+
+Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index) {
+ DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
+ const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ src_literal.shape(), src_index);
+ const int64 dest_linear_index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ char* dest_address =
+ static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
+ const char* source_address =
+ static_cast<const char*>(src_literal.untyped_data()) +
+ src_linear_index * primitive_size;
+ if (dest_address != source_address) {
+ memcpy(dest_address, source_address, primitive_size);
+ }
+ return Status::OK();
+}
+
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
+ const LiteralProto& proto) {
+ if (!proto.has_shape()) {
+ return InvalidArgument("LiteralProto has no shape");
+ }
+ if (!LayoutUtil::HasLayout(proto.shape())) {
+ return InvalidArgument("LiteralProto has no layout");
+ }
+
+ auto literal = MakeUnique<Literal>(proto.shape());
+
+ TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ const LiteralProto* proto_element = &proto;
+ for (int64 i : index) {
+ CHECK(i < proto_element->tuple_literals_size());
+ proto_element = &proto_element->tuple_literals(i);
+ }
+
+ if (ShapeUtil::IsTuple(piece->subshape())) {
+ if (proto_element->tuple_literals_size() !=
+ ShapeUtil::TupleElementCount(piece->subshape())) {
+ return InvalidArgument(
+ "Expected %lld tuple elements in LiteralProto, has %d",
+ ShapeUtil::TupleElementCount(piece->subshape()),
+ proto_element->tuple_literals_size());
+ }
+ return Status::OK();
+ }
+ if (piece->subshape().element_type() == TOKEN) {
+ return Status::OK();
+ }
+
+ CHECK(ShapeUtil::IsArray(piece->subshape()));
+ TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
+
+ return Status::OK();
+ }));
+
+ return std::move(literal);
+}
+
+std::vector<Literal> Literal::DecomposeTuple() {
+ CHECK(ShapeUtil::IsTuple(shape()));
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
+ /*allocate_arrays=*/false));
+ Literal& element = elements.back();
+ element.root_piece_->ForEachMutableSubpiece(
+ [&](const ShapeIndex& index, Piece* dest_piece) {
+ ShapeIndex src_index = {i};
+ for (int64 j : index) {
+ src_index.push_back(j);
+ }
+ Piece& src_piece = piece(src_index);
+
+ // Move the respective buffer and sparse indices over to the element
+ // Literal.
+ dest_piece->set_buffer(src_piece.buffer());
+ src_piece.set_buffer(nullptr);
+ dest_piece->set_sparse_indices(src_piece.sparse_indices());
+ src_piece.set_sparse_indices(nullptr);
+ });
+ }
+ // Set this literal to be nil-shaped.
+ *this = Literal();
+ return elements;
+}
+
+namespace {
+
+// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
+// the array slices are indicated by dest_shape and src_shape respectively.
+template <typename NativeT>
+void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ tensorflow::gtl::ArraySlice<NativeT> src,
+ const Shape& dest_shape, const Shape& src_shape) {
+ CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
+ if (ShapeUtil::IsZeroElementArray(dest_shape)) {
+ return;
+ }
+ std::vector<int64> index(ShapeUtil::Rank(dest_shape));
+ do {
+ dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
+ src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
+ } while (IndexUtil::BumpIndices(dest_shape, &index));
+}
+
+} // namespace
+
+Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
+ CHECK(subshape_ != nullptr);
+ CHECK(src.subshape_ != nullptr);
+ if (ShapeUtil::Equal(subshape(), src.subshape())) {
+ // If the layouts are equal it's faster just to memcpy.
+ memcpy(buffer(), src.buffer(), src.size_bytes());
+ } else {
+ TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
+ std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
+ switch (subshape().element_type()) {
+#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
+ case (XLA_T): \
+ CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
+ subshape(), src.subshape()); \
+ break;
+ COPY_ELEMENTS(U8, uint8);
+ COPY_ELEMENTS(U16, uint16);
+ COPY_ELEMENTS(U32, uint32);
+ COPY_ELEMENTS(U64, uint64);
+ COPY_ELEMENTS(S8, int8);
+ COPY_ELEMENTS(S16, int16);
+ COPY_ELEMENTS(S32, int32);
+ COPY_ELEMENTS(S64, int64);
+ COPY_ELEMENTS(F16, half);
+ COPY_ELEMENTS(BF16, bfloat16);
+ COPY_ELEMENTS(F32, float);
+ COPY_ELEMENTS(F64, double);
+ COPY_ELEMENTS(C64, complex64);
+ COPY_ELEMENTS(PRED, bool);
+#undef COPY_ELEMENTS
+ default:
+ return Unimplemented(
+ "Copying a Literal object with element type %s is not implemented.",
+ PrimitiveType_Name(subshape().element_type()).c_str());
+ }
+ }
+ return Status::OK();
+}
+
+Status Literal::CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index,
+ const ShapeIndex& src_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ const Shape& src_subshape =
+ ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
+ if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
+ return InvalidArgument(
+ "Destination subshape incompatible with source subshape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_subshape).c_str());
+ }
+ return root_piece_->ForEachMutableSubpieceWithStatus(
+ [&](const ShapeIndex& index, Piece* piece) {
+ if (!ShapeUtil::IsArray(piece->subshape())) {
+ return Status::OK();
+ }
+
+ // Determine if this index is in the part of this literal that we want
+ // to copy over from src_literal.
+ bool in_subtree_to_copy = true;
+ for (int i = 0; i < dest_shape_index.size(); ++i) {
+ if (index[i] != dest_shape_index[i]) {
+ in_subtree_to_copy = false;
+ break;
+ }
+ }
+ if (!in_subtree_to_copy) {
+ return Status::OK();
+ }
+ // Construct the index of the corresponding piece in the source literal.
+ ShapeIndex src_piece_index = src_shape_index;
+ for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
+ src_piece_index.push_back(index[i]);
+ }
+ TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
+ return Status::OK();
+ });
+}
+
+Status Literal::MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index) {
+ const Shape& dest_subshape =
+ ShapeUtil::GetSubshape(shape(), dest_shape_index);
+ if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
+ return InvalidArgument(
+ "Destination subshape not equal to source shape: %s vs %s",
+ ShapeUtil::HumanString(dest_subshape).c_str(),
+ ShapeUtil::HumanString(src_literal.shape()).c_str());
+ }
+
+ src_literal.root_piece_->ForEachSubpiece(
+ [&](const ShapeIndex& src_index, const Piece& src_piece) {
+ if (!ShapeUtil::IsArray(src_piece.subshape())) {
+ return;
+ }
+
+ ShapeIndex dest_index = dest_shape_index;
+ for (int64 i : src_index) {
+ dest_index.push_back(i);
+ }
+ Piece& dest_piece = piece(dest_index);
+ delete[] dest_piece.buffer();
+ dest_piece.set_buffer(src_piece.buffer());
+ delete dest_piece.sparse_indices();
+ dest_piece.set_sparse_indices(src_piece.sparse_indices());
+ });
+
+ src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
+ delete src_literal.root_piece_;
+ src_literal.root_piece_ = new LiteralBase::Piece();
+ src_literal.root_piece_->set_subshape(src_literal.shape_.get());
+
+ return Status::OK();
+}
+
+Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size) {
+ TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
+ TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
+ << ShapeUtil::HumanString(src_literal.shape());
+ TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
+
+ switch (shape().element_type()) {
+ case U8:
+ return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
+ copy_size);
+ case U16:
+ return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
+ copy_size);
+ case U32:
+ return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
+ copy_size);
+ case U64:
+ return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
+ copy_size);
+ case S8:
+ return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
+ copy_size);
+ case S16:
+ return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
+ copy_size);
+ case S32:
+ return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
+ copy_size);
+ case S64:
+ return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
+ copy_size);
+ case F16:
+ return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
+ copy_size);
+ case BF16:
+ return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
+ copy_size);
+ case F32:
+ return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
+ copy_size);
+ case F64:
+ return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
+ copy_size);
+ case C64:
+ return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
+ copy_size);
+ case PRED:
+ return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
+ copy_size);
+ default:
+ break;
+ }
+ return Unimplemented(
+ "Copying a slice from a Literal object with element type %d is not "
+ "implemented.",
+ shape().element_type());
+}
+
+void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(element_count(), values.bits());
+ CHECK_EQ(shape().element_type(), PRED);
+ for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
+ Set({i}, values.get(i));
+ }
+}
+
+std::unique_ptr<Literal> LiteralBase::Relayout(
+ const Layout& new_layout, const ShapeIndex& shape_index) const {
+ // Create new shape with 'new_layout' set at the given shape index.
+ Shape new_shape = shape();
+ Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
+ TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
+ *subshape->mutable_layout() = new_layout;
+ auto result = MakeUnique<Literal>(new_shape);
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
+}
+
+std::unique_ptr<Literal> LiteralBase::Relayout(
+ const Shape& shape_with_layout) const {
+ CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
+ << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
+ << " not compatible with literal shape "
+ << ShapeUtil::HumanString(shape());
+ std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
+ ShapeUtil::ForEachSubshape(
+ result->shape(),
+ [this, &result](const Shape& subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ TF_CHECK_OK(result->CopyFrom(*this,
+ /*dest_shape_index=*/index,
+ /*src_shape_index=*/index));
+ }
+ });
+ return result;
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return InvalidArgument("Broadcast only supports arrays.");
+ }
+
+ for (int64 i = 0; i < dimensions.size(); i++) {
+ TF_RET_CHECK(shape().dimensions(i) ==
+ result_shape.dimensions(dimensions[i]));
+ }
+
+ std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+
+ // scratch_source_index is temporary storage space for the computed index into
+ // the input literal. We put it here to avoid allocating an std::vector in
+ // every iteration of ShapeUtil::ForEachIndex.
+ std::vector<int64> scratch_source_index(shape().dimensions_size());
+
+ char* dest_data = static_cast<char*>(result->untyped_data());
+ const char* source_data = static_cast<const char*>(untyped_data());
+ const int64 primitive_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+ ShapeUtil::ForEachIndex(
+ result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ for (int64 i = 0; i < dimensions.size(); ++i) {
+ scratch_source_index[i] = output_index[dimensions[i]];
+ }
+ int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ result_shape, output_index);
+ int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+ shape(), scratch_source_index);
+ memcpy(dest_data + primitive_size * dest_index,
+ source_data + primitive_size * source_index, primitive_size);
+ return true;
+ });
+
+ return std::move(result);
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const {
+ if (!ShapeUtil::IsArray(shape())) {
+ return InvalidArgument("Reshape does not support tuples.");
+ }
+ std::unique_ptr<Literal> output;
+ if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
+ output =
+ Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
+ } else {
+ output = CloneToUnique();
+ }
+ // Because the layout is monotonic, we can simply reuse the same sequence of
+ // values without changing their order.
+ *output->mutable_shape_do_not_use() =
+ ShapeUtil::MakeShape(shape().element_type(), dimensions);
+
+ int64 elements_before = ShapeUtil::ElementsIn(shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ if (elements_before != elements_after) {
+ return InvalidArgument(
+ "Shapes before and after Literal::Reshape have different numbers "
+ "of elements: %s vs %s.",
+ ShapeUtil::HumanString(shape()).c_str(),
+ ShapeUtil::HumanString(output->shape()).c_str());
+ }
+ return std::move(output);
+}
+
+std::unique_ptr<Literal> LiteralBase::Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const {
+ CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
+ CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
+ << "Given permutation is not a permutation of dimension numbers";
+ // To transpose the array, we just permute the dimensions and layout, and
+ // do a straight memory copy of the raw data set.
+ // This is considerably faster than iterating over every array element using
+ // the EachCell<>() and Set<>() APIs.
+ std::vector<int64> inverse_permutation = InversePermutation(permutation);
+ Shape permuted_shape =
+ ShapeUtil::PermuteDimensions(inverse_permutation, shape());
+ // Replace the layout with one affine to this shape, such that a
+ // transpose operation can be performed by leaving the flat values
+ // representation intact.
+ // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
+ // The shape with affine layout resulting from that operation will be
+ // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
+ // most minor.
+ //
+ // Essentially, given MinMaj(Di) the position of the Di dimension within the
+ // minor to major vector, and given T(Di) the index that the original Di
+ // dimension has within the transposed array, a layout is affine if
+ // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
+ // vector of the affine layout.
+ CHECK(LayoutUtil::IsDenseArray(permuted_shape));
+ Layout* layout = permuted_shape.mutable_layout();
+ layout->clear_minor_to_major();
+ for (auto index : LayoutUtil::MinorToMajor(shape())) {
+ layout->add_minor_to_major(inverse_permutation[index]);
+ }
+ auto new_literal = MakeUnique<Literal>(permuted_shape);
+ DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
+ ShapeUtil::ByteSizeOf(shape()));
+ std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
+ return new_literal;
+}
+
+template <typename NativeT>
+std::unique_ptr<Literal> LiteralBase::SliceInternal(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices) const {
+ auto result_literal = MakeUnique<Literal>(result_shape);
+ DimensionVector new_indices(ShapeUtil::Rank(result_shape));
+ result_literal->EachCell<NativeT>(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ NativeT value = Get<NativeT>(new_indices);
+ result_literal->Set<NativeT>(indices, value);
+ });
+ return result_literal;
+}
+
+std::unique_ptr<Literal> LiteralBase::Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const {
+ CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
+
+ DimensionVector result_dimensions;
+ for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
+ CHECK_GE(start_indices[dnum], 0);
+ CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
+ << "dnum = " << dnum;
+ int64 dimension = limit_indices[dnum] - start_indices[dnum];
+ CHECK_GE(dimension, 0) << "dnum = " << dnum;
+ result_dimensions.push_back(dimension);
+ }
+ const auto result_shape =
+ ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
+ LayoutUtil::MinorToMajor(shape()));
+ switch (result_shape.element_type()) {
+ case F32:
+ return SliceInternal<float>(result_shape, start_indices);
+ case BF16:
+ return SliceInternal<bfloat16>(result_shape, start_indices);
+ case C64:
+ return SliceInternal<complex64>(result_shape, start_indices);
+ case S32:
+ return SliceInternal<int32>(result_shape, start_indices);
+ case U32:
+ return SliceInternal<uint32>(result_shape, start_indices);
+ default:
+ LOG(FATAL) << "not yet implemented: "
+ << PrimitiveType_Name(result_shape.element_type());
+ }
+}
+
+Literal LiteralBase::Clone() const {
+ Literal result(shape());
+ TF_CHECK_OK(result.CopyFrom(*this));
+ return result;
+}
+
+std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
+ auto result = MakeUnique<Literal>(shape());
+ TF_CHECK_OK(result->CopyFrom(*this));
+ return result;
+}
+
+string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsDenseArray(subshape));
+ switch (subshape.element_type()) {
+ case PRED:
+ return Get<bool>(multi_index, shape_index) ? "true" : "false";
+ case S8:
+ return StrCat(Get<int8>(multi_index, shape_index));
+ case S16:
+ return StrCat(Get<int16>(multi_index, shape_index));
+ case S32:
+ return StrCat(Get<int32>(multi_index, shape_index));
+ case S64:
+ return StrCat(Get<int64>(multi_index, shape_index));
+ case U8:
+ return StrCat(Get<uint8>(multi_index, shape_index));
+ case U16:
+ return StrCat(Get<uint16>(multi_index, shape_index));
+ case U32:
+ return StrCat(Get<uint32>(multi_index, shape_index));
+ case U64:
+ return StrCat(Get<uint64>(multi_index, shape_index));
+ case F16:
+ return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
+ case F32:
+ return StrCat(Get<float>(multi_index, shape_index));
+ case BF16:
+ return StrCat(
+ static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
+ case F64:
+ return StrCat(Get<double>(multi_index, shape_index));
+ case C64: {
+ complex64 c = Get<complex64>(multi_index, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
+ }
+ default:
+ LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
+ }
+}
+
+string LiteralBase::GetSparseElementAsString(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
+ const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ switch (subshape.element_type()) {
+ case PRED:
+ return GetSparseElement<bool>(sparse_element_number, shape_index)
+ ? "true"
+ : "false";
+ case S8:
+ return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
+ case S16:
+ return StrCat(
+ GetSparseElement<int16>(sparse_element_number, shape_index));
+ case S32:
+ return StrCat(
+ GetSparseElement<int32>(sparse_element_number, shape_index));
+ case S64:
+ return StrCat(
+ GetSparseElement<int64>(sparse_element_number, shape_index));
+ case U8:
+ return StrCat(
+ GetSparseElement<uint8>(sparse_element_number, shape_index));
+ case U16:
+ return StrCat(
+ GetSparseElement<uint16>(sparse_element_number, shape_index));
+ case U32:
+ return StrCat(
+ GetSparseElement<uint32>(sparse_element_number, shape_index));
+ case U64:
+ return StrCat(
+ GetSparseElement<uint64>(sparse_element_number, shape_index));
+ case F16:
+ return StrCat(static_cast<float>(
+ GetSparseElement<half>(sparse_element_number, shape_index)));
+ case F32:
+ return StrCat(
+ GetSparseElement<float>(sparse_element_number, shape_index));
+ case BF16:
+ return StrCat(static_cast<float>(
+ GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
+ case F64:
+ return StrCat(
+ GetSparseElement<double>(sparse_element_number, shape_index));
+ case C64: {
+ complex64 c =
+ GetSparseElement<complex64>(sparse_element_number, shape_index);
+ return StrCat("(", c.real(), ", ", c.imag(), ")");
+ }
+ default:
+ LOG(FATAL) << "Invalid element type for sparse arrays: "
+ << PrimitiveType_Name(subshape.element_type());
+ }
+}
+
+StatusOr<int64> LiteralBase::GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(shape()));
+ switch (shape().element_type()) {
+ case PRED:
+ return Get<bool>(multi_index);
+ case U8:
+ return Get<uint8>(multi_index);
+ case S32:
+ return Get<int32>(multi_index);
+ case S64:
+ return Get<int64>(multi_index);
+ case U32:
+ return Get<uint32>(multi_index);
+ case U64:
+ return Get<uint64>(multi_index);
+ default:
+ return FailedPrecondition(
+ "Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()).c_str());
+ }
+}
+
+size_t LiteralBase::Hash() const {
+ using tensorflow::Hash64;
+ using tensorflow::Hash64Combine;
+
+ size_t hash_value = ShapeUtil::Hash(shape());
+
+ ShapeUtil::ForEachSubshape(
+ shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+ if (!ShapeUtil::IsArray(subshape)) {
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDense(subshape.layout()));
+ hash_value = Hash64Combine(
+ hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
+ size_bytes(index)));
+ });
+
+ return hash_value;
+}
+
+Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value) {
+ CHECK(LayoutUtil::IsDenseArray(shape()));
+ switch (shape().element_type()) {
+ case PRED:
+ Set<bool>(multi_index, value);
+ break;
+ case U8:
+ Set<uint8>(multi_index, value);
+ break;
+ case S32:
+ Set<int32>(multi_index, value);
+ break;
+ case S64:
+ Set<int64>(multi_index, value);
+ break;
+ case U32:
+ Set<uint32>(multi_index, value);
+ break;
+ case U64:
+ Set<uint64>(multi_index, value);
+ break;
+ default:
+ return FailedPrecondition(
+ "Array element type is not integral: %s",
+ PrimitiveType_Name(shape().element_type()).c_str());
+ }
+ return Status::OK();
+}
+
+tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index) const {
+ const Piece& p = piece(shape_index);
+ CHECK_GE(sparse_element_number, 0);
+ CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
+ return p.sparse_indices()->At(sparse_element_number);
+}
+
+void Literal::SortSparseElements(const ShapeIndex& shape_index) {
+ piece(shape_index).SortSparseElements();
+}
+
+void LiteralBase::Piece::SortSparseElements() {
+ switch (subshape().element_type()) {
+ case PRED:
+ SortSparseElementsInternal<bool>();
+ break;
+ case S8:
+ SortSparseElementsInternal<int8>();
+ break;
+ case U8:
+ SortSparseElementsInternal<uint8>();
+ break;
+ case S16:
+ SortSparseElementsInternal<int16>();
+ break;
+ case U16:
+ SortSparseElementsInternal<uint16>();
+ break;
+ case S32:
+ SortSparseElementsInternal<int32>();
+ break;
+ case U32:
+ SortSparseElementsInternal<uint32>();
+ break;
+ case S64:
+ SortSparseElementsInternal<int64>();
+ break;
+ case U64:
+ SortSparseElementsInternal<uint64>();
+ break;
+ case F32:
+ SortSparseElementsInternal<float>();
+ break;
+ case F64:
+ SortSparseElementsInternal<double>();
+ break;
+ case C64:
+ SortSparseElementsInternal<complex64>();
+ break;
+ case F16:
+ SortSparseElementsInternal<half>();
+ break;
+ case BF16:
+ SortSparseElementsInternal<bfloat16>();
+ break;
+ default:
+ LOG(FATAL) << "Element type not valid for sparse array: "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+template <typename NativeT>
+void LiteralBase::Piece::SortSparseElementsInternal() {
+ CHECK(LayoutUtil::IsSparseArray(subshape()));
+ int64 num_elements = sparse_indices()->index_count();
+ auto values = data<NativeT>();
+ CHECK_LE(num_elements, values.size());
+ sparse_indices()->SortWithValues(
+ tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
+}
+
+namespace {
+
+void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
+ bool print_layout, std::vector<string>* pieces) {
+ const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
+ CHECK(LayoutUtil::HasLayout(literal.shape()));
+ CHECK(LayoutUtil::HasLayout(subshape));
+
+ auto shape_to_string = [print_layout](const Shape& shape) {
+ if (print_layout) {
+ return ShapeUtil::HumanStringWithLayout(shape);
+ } else {
+ return ShapeUtil::HumanString(shape);
+ }
+ };
+
+ // TODO(b/32894291): refactor this code to reduce code duplication.
+ if (ShapeUtil::IsTuple(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" (\n");
+ std::vector<string> tuple_pieces;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
+ ShapeIndex element_index = shape_index;
+ element_index.push_back(i);
+ std::vector<string> element_pieces;
+ ToStringHelper(literal, element_index, print_layout, &element_pieces);
+ tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
+ }
+ pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
+ pieces->push_back("\n)");
+ return;
+ }
+
+ if (ShapeUtil::IsToken(subshape)) {
+ pieces->push_back("token");
+ return;
+ }
+
+ if (LayoutUtil::IsSparseArray(subshape)) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back("{");
+ int64 rank = ShapeUtil::Rank(subshape);
+ int64 num_elements = literal.sparse_element_count();
+ for (int64 i = 0; i < num_elements; ++i) {
+ if (i > 0) {
+ pieces->push_back(", ");
+ }
+ if (rank == 1) {
+ pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
+ pieces->push_back(": ");
+ } else {
+ pieces->push_back("[");
+ pieces->push_back(
+ tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
+ pieces->push_back("]: ");
+ }
+ pieces->push_back(literal.GetSparseElementAsString(i));
+ }
+ pieces->push_back("}");
+ return;
+ }
+
+ CHECK(LayoutUtil::IsDenseArray(subshape));
+
+ auto element_to_string =
+ [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ PrimitiveType element_type = subshape.element_type();
+ if (element_type == PRED) {
+ // We display predicates in a densely packed form.
+ return literal.Get<bool>(indices, shape_index) ? "1" : "0";
+ }
+ return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
+ literal.GetAsString(indices, shape_index);
+ };
+
+ if (ShapeUtil::Rank(subshape) == 0) {
+ pieces->push_back(literal.GetAsString({}, shape_index));
+ } else if (ShapeUtil::Rank(subshape) == 1) {
+ pieces->push_back("{");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(element_to_string({i0}));
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 2) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(" { ");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(element_to_string({i0, i1}));
+ }
+ pieces->push_back(" ");
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 3) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(i0 > 0 ? ",\n{" : "{");
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(i1 > 0 ? ",\n { " : " { ");
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(element_to_string({i0, i1, i2}));
+ }
+ pieces->push_back(" }");
+ }
+ pieces->push_back(" }");
+ }
+ pieces->push_back("\n}");
+ } else if (ShapeUtil::Rank(subshape) == 4) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(" {");
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3}));
+ }
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
+ }
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
+ }
+ pieces->push_back("}");
+ } else if (ShapeUtil::Rank(subshape) == 5) {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {\n");
+ for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
+ pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
+ for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
+ pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
+ for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
+ pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
+ for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
+ pieces->push_back(" {");
+ for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
+ pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
+ }
+ pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
+ : "},\n");
+ }
+ pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
+ : " },\n");
+ }
+ pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
+ }
+ pieces->push_back("}");
+ } else {
+ pieces->push_back(shape_to_string(subshape));
+ pieces->push_back(" {");
+ literal.EachCellAsString(
+ [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ pieces->push_back(" ");
+ pieces->push_back(value);
+ });
+ pieces->push_back("}");
+ }
+}
+
+} // namespace
+
+int64 LiteralBase::sparse_element_count() const {
+ CHECK(LayoutUtil::IsSparseArray(shape()));
+ return sparse_indices()->index_count();
+}
+
+string LiteralBase::ToString(bool print_layout) const {
+ std::vector<string> pieces;
+ CHECK(LayoutUtil::HasLayout(this->shape()));
+ ToStringHelper(*this, {}, print_layout, &pieces);
+ return tensorflow::str_util::Join(pieces, "");
+}
+
+void LiteralBase::EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
+ return;
+ }
+ std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
+ shape(), /*linear_index=*/0);
+ do {
+ per_cell(indices, GetAsString(indices));
+ } while (IndexUtil::BumpIndices(shape(), &indices));
+}
+
+namespace {
+template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
+std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
+ const LiteralBase& src_literal, const ConverterType& converter) {
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
+ src_literal.shape(),
+ primitive_util::NativeToPrimitiveType<NativeDestT>()));
+ auto src_data = src_literal.data<NativeSrcT>();
+ auto dest_data = result_literal->template data<NativeDestT>();
+ int64 num_elements = src_literal.element_count();
+
+ for (int64 i = 0; i < num_elements; ++i) {
+ dest_data[i] = converter(src_data[i]);
+ }
+ return result_literal;
+}
+
+template <typename NativeSrcT, typename NativeDestT>
+std::unique_ptr<Literal> ConvertBetweenNativeTypes(
+ const LiteralBase& src_literal) {
+ auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
+ auto converter = [](NativeSrcT src) {
+ return tensorflow::bit_cast<NativeDestT>(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
+ src_literal, converter);
+}
+
+// This template specialization is here to make the compiler happy. bit_cast has
+// a static check that the types are the same size. This specialization should
+// never be used because the source and destination types are checked for
+// identical sizes higher up.
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
+ std::unique_ptr<Literal>>::type
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
+ LOG(FATAL) << "Invalid bitcast between types of different sizes.";
+}
+
+template <PrimitiveType primitive_src_type>
+std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
+ CHECK(ShapeUtil::IsArray(src_literal.shape()));
+ auto result_literal = MakeUnique<Literal>(
+ ShapeUtil::ChangeElementType(src_literal.shape(), C64));
+ using NativeSrcT =
+ typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
+ tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
+ src_literal.data<NativeSrcT>();
+ tensorflow::gtl::MutableArraySlice<complex64> dest_data =
+ result_literal->data<complex64>();
+ int64 num_elements = src_literal.element_count();
+ for (int64 i = 0; i < num_elements; ++i) {
+ dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
+ }
+ return result_literal;
+}
+
+template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
+std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
+ bool bitcast) {
+ CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
+ if (bitcast) {
+ return BitcastBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ } else {
+ return ConvertBetweenNativeTypes<
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal);
+ }
+}
+
+template <PrimitiveType primitive_src_type>
+StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
+ const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
+ switch (primitive_dest_type) {
+#define CONVERT_IF_TYPES_MATCH(type) \
+ case (type): \
+ return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
+ bitcast);
+ CONVERT_IF_TYPES_MATCH(PRED)
+ CONVERT_IF_TYPES_MATCH(S8)
+ CONVERT_IF_TYPES_MATCH(S32)
+ CONVERT_IF_TYPES_MATCH(S64)
+ CONVERT_IF_TYPES_MATCH(U8)
+ CONVERT_IF_TYPES_MATCH(U32)
+ CONVERT_IF_TYPES_MATCH(U64)
+ CONVERT_IF_TYPES_MATCH(F16)
+ CONVERT_IF_TYPES_MATCH(F32)
+ CONVERT_IF_TYPES_MATCH(F64)
+ CONVERT_IF_TYPES_MATCH(BF16)
+#undef CONVERT_IF_TYPES_MATCH
+ case C64:
+ if (!bitcast) {
+ return ConvertToC64<primitive_src_type>(src_literal);
+ }
+ break;
+ // Other types are not yet supported.
+ default:
+ break;
+ }
+ return Unimplemented(
+ "Converting from type %s to type %s is not implemented.",
+ PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str());
+}
+
+StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
+ const LiteralBase& literal, PrimitiveType primitive_dest_type,
+ bool bitcast) {
+ TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
+ if (literal.shape().element_type() == primitive_dest_type) {
+ return literal.CloneToUnique();
+ }
+ switch (literal.shape().element_type()) {
+#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
+ case (type): \
+ return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
+ bitcast);
+ CONVERT_IF_DEST_TYPE_MATCHES(PRED)
+ CONVERT_IF_DEST_TYPE_MATCHES(S8)
+ CONVERT_IF_DEST_TYPE_MATCHES(S32)
+ CONVERT_IF_DEST_TYPE_MATCHES(S64)
+ CONVERT_IF_DEST_TYPE_MATCHES(U8)
+ CONVERT_IF_DEST_TYPE_MATCHES(U32)
+ CONVERT_IF_DEST_TYPE_MATCHES(U64)
+ CONVERT_IF_DEST_TYPE_MATCHES(F16)
+ CONVERT_IF_DEST_TYPE_MATCHES(F32)
+ CONVERT_IF_DEST_TYPE_MATCHES(F64)
+ CONVERT_IF_DEST_TYPE_MATCHES(BF16)
+#undef CONVERT_IF_DEST_TYPE_MATCHES
+ // Other types are not yet supported.
+ default:
+ return Unimplemented(
+ "%s from type %s to type %s is not implemented.",
+ (bitcast ? "Bitcast converting" : "Converting"),
+ PrimitiveType_Name(literal.shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str());
+ }
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
+ PrimitiveType primitive_dest_type) const {
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
+ PrimitiveType primitive_dest_type) const {
+ if (primitive_util::BitWidth(shape().element_type()) !=
+ primitive_util::BitWidth(primitive_dest_type)) {
+ return InvalidArgument(
+ "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
+ "%d",
+ PrimitiveType_Name(shape().element_type()).c_str(),
+ PrimitiveType_Name(primitive_dest_type).c_str(),
+ primitive_util::BitWidth(shape().element_type()),
+ primitive_util::BitWidth(primitive_dest_type));
+ }
+ return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
+}
+
+StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16) const {
+ if (!ShapeUtil::IsTuple(dest_shape)) {
+ if (round_f32_to_bf16 && shape().element_type() == F32 &&
+ dest_shape.element_type() == BF16) {
+ auto converter = [](float src) {
+ return tensorflow::bfloat16::round_to_bfloat16(src);
+ };
+ return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
+ converter);
+ }
+ return Convert(dest_shape.element_type());
+ }
+ std::vector<Literal> elements;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
+ auto element = LiteralSlice(*this, {i});
+ TF_ASSIGN_OR_RETURN(
+ auto new_element,
+ element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
+ elements.push_back(std::move(*new_element));
+ }
+ auto converted = MakeUnique<Literal>();
+ *converted = Literal::MoveIntoTuple(&elements);
+ return std::move(converted);
+}
+
+/* static */ Literal Literal::MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements) {
+ std::vector<Shape> element_shapes;
+ for (const Literal& element : elements) {
+ element_shapes.push_back(element.shape());
+ }
+ Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
+ /*allocate_arrays=*/false);
+ for (int i = 0; i < elements.size(); ++i) {
+ TF_CHECK_OK(
+ literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
+ }
+ return literal;
+}
+
+template <typename NativeT>
+bool LiteralBase::Piece::EqualElementsInternal(
+ const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
+ if (multi_index->size() == ShapeUtil::Rank(subshape())) {
+ return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
+ }
+ for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
+ multi_index->push_back(i);
+ if (!EqualElementsInternal<NativeT>(other, multi_index)) {
+ return false;
+ }
+ multi_index->pop_back();
+ }
+ return true;
+}
+
+bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
+ DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
+
+ std::vector<int64> multi_index;
+ switch (subshape().element_type()) {
+ case PRED:
+ return EqualElementsInternal<bool>(other, &multi_index);
+ case U8:
+ return EqualElementsInternal<uint8>(other, &multi_index);
+ case S32:
+ return EqualElementsInternal<int32>(other, &multi_index);
+ case S64:
+ return EqualElementsInternal<int64>(other, &multi_index);
+ case U32:
+ return EqualElementsInternal<uint32>(other, &multi_index);
+ case U64:
+ return EqualElementsInternal<uint64>(other, &multi_index);
+ case F32:
+ return EqualElementsInternal<float>(other, &multi_index);
+ case F64:
+ return EqualElementsInternal<double>(other, &multi_index);
+ case F16:
+ return EqualElementsInternal<half>(other, &multi_index);
+ case BF16:
+ return EqualElementsInternal<bfloat16>(other, &multi_index);
+ case C64:
+ return EqualElementsInternal<complex64>(other, &multi_index);
+ default:
+ LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
+ << PrimitiveType_Name(subshape().element_type());
+ }
+}
+
+bool LiteralBase::operator==(const LiteralBase& other) const {
+ if (!ShapeUtil::Compatible(shape(), other.shape())) {
+ return false;
+ }
+
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ const Piece& other_piece = other.piece(index);
+ if (!piece.EqualElements(other_piece)) {
+ return false;
+ }
+ return true;
+ });
+}
+
+namespace {
+
+template <typename NativeT>
+static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
+ NativeT value) {
+ for (int64 i = 0; i < data.size(); ++i) {
+ if (data[i] != value) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+bool LiteralBase::IsAll(int8 value) const {
+ return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
+ const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case U8:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
+ }
+ return false;
+ case U32:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
+ }
+ return false;
+ case U64:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
+ }
+ return false;
+ case S8:
+ return AllElementsEqualValue<int8>(piece.data<int8>(), value);
+ case S32:
+ return AllElementsEqualValue<int32>(piece.data<int32>(), value);
+ case S64:
+ return AllElementsEqualValue<int64>(piece.data<int64>(), value);
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
+ static_cast<bfloat16>(value));
+ case PRED:
+ if (value == 0) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), false);
+ }
+ if (value == 1) {
+ return AllElementsEqualValue<bool>(piece.data<bool>(), true);
+ }
+ return false;
+ default:
+ return false;
+ }
+ return false;
+ };
+
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsAllFloat(float value) const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ auto piece_is_all = [&]() {
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
+ }
+ };
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsAllComplex(complex64 value) const {
+ switch (shape().element_type()) {
+ case C64:
+ return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
+ value);
+ default:
+ return false;
+ }
+}
+
+bool LiteralBase::IsAllFirst() const {
+ return root_piece().ForEachSubpieceWithBool(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ if (!ShapeUtil::IsArray(piece.subshape())) {
+ return true;
+ }
+
+ // Empty shapes are not all the first element since there is no first
+ // element.
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
+ return false;
+ }
+ auto piece_is_all = [&]() {
+ switch (piece.subshape().element_type()) {
+ case PRED: {
+ auto data = piece.data<bool>();
+ return AllElementsEqualValue<bool>(data, data[0]);
+ }
+ // 8 bit types
+ case S8: {
+ auto data = piece.data<int8>();
+ return AllElementsEqualValue<int8>(data, data[0]);
+ }
+ case U8: {
+ auto data = piece.data<uint8>();
+ return AllElementsEqualValue<uint8>(data, data[0]);
+ }
+ // 16 bit types
+ case BF16: {
+ auto data = piece.data<bfloat16>();
+ return AllElementsEqualValue<bfloat16>(data, data[0]);
+ }
+ case F16: {
+ auto data = piece.data<half>();
+ return AllElementsEqualValue<half>(data, data[0]);
+ }
+ case S16: {
+ auto data = piece.data<int16>();
+ return AllElementsEqualValue<int16>(data, data[0]);
+ }
+ case U16: {
+ auto data = piece.data<uint16>();
+ return AllElementsEqualValue<uint16>(data, data[0]);
+ }
+ // 32 bit types
+ case F32: {
+ auto data = piece.data<float>();
+ return AllElementsEqualValue<float>(data, data[0]);
+ }
+ case U32: {
+ auto data = piece.data<uint32>();
+ return AllElementsEqualValue<uint32>(data, data[0]);
+ }
+ case S32: {
+ auto data = piece.data<int32>();
+ return AllElementsEqualValue<int32>(data, data[0]);
+ }
+ // 64 bit types
+ case C64: {
+ auto data = piece.data<complex64>();
+ return AllElementsEqualValue<complex64>(data, data[0]);
+ }
+ case F64: {
+ auto data = piece.data<double>();
+ return AllElementsEqualValue<double>(data, data[0]);
+ }
+ case S64: {
+ auto data = piece.data<int64>();
+ return AllElementsEqualValue<int64>(data, data[0]);
+ }
+ case U64: {
+ auto data = piece.data<uint64>();
+ return AllElementsEqualValue<uint64>(data, data[0]);
+ }
+ default:
+ return false;
+ }
+ };
+
+ if (!piece_is_all()) {
+ return false;
+ }
+ return true;
+ });
+}
+
+bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ switch (shape().element_type()) {
+ case U8:
+ return Get<uint8>(indices) == 0;
+ case U32:
+ return Get<uint32>(indices) == 0;
+ case U64:
+ return Get<uint64>(indices) == 0;
+ case S8:
+ return Get<int8>(indices) == 0;
+ case S32:
+ return Get<int32>(indices) == 0;
+ case S64:
+ return Get<int64>(indices) == 0;
+ case F32:
+ return Get<float>(indices) == 0.0f;
+ case F64:
+ return Get<double>(indices) == 0.0;
+ case C64:
+ return Get<complex64>(indices) == complex64(0.0f, 0.0f);
+ case F16:
+ return Get<half>(indices) == static_cast<half>(0.0f);
+ case BF16:
+ return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
+ case PRED:
+ return Get<bool>(indices) == false;
+ default:
+ LOG(FATAL) << "Input literal must be an array.";
+ }
+}
+
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+void CopyToRepeatedField(RepeatedFieldT* dest,
+ const tensorflow::gtl::ArraySlice<NativeT> src) {
+ *dest = RepeatedFieldT(src.begin(), src.end());
+}
+
+} // namespace
+
+void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
+ *proto->mutable_shape() = subshape();
+ switch (subshape().element_type()) {
+ case PRED:
+ CopyToRepeatedField(proto->mutable_preds(), data<bool>());
+ break;
+ case U8:
+ proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
+ element_count());
+ break;
+ case U32:
+ CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
+ break;
+ case U64:
+ CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
+ break;
+ case S32:
+ CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
+ break;
+ case S64:
+ CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
+ break;
+ case F16:
+ *proto->mutable_f16s() = string(
+ reinterpret_cast<const char*>(data<half>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_f16s());
+ }
+ break;
+ case BF16:
+ *proto->mutable_bf16s() = string(
+ reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
+ if (!kLittleEndian) {
+ ConvertEndianShort(proto->mutable_bf16s());
+ }
+ break;
+ case F32:
+ CopyToRepeatedField(proto->mutable_f32s(), data<float>());
+ break;
+ case F64:
+ CopyToRepeatedField(proto->mutable_f64s(), data<double>());
+ break;
+ case C64:
+ for (complex64 value : data<complex64>()) {
+ proto->add_c64s(value.real());
+ proto->add_c64s(value.imag());
+ }
+ break;
+ case TUPLE:
+ case TOKEN:
+ // Nothing to do but assign the shape which is done above.
+ return;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+}
+
+const void* LiteralBase::Piece::untyped_data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+void* LiteralBase::Piece::untyped_data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ return buffer();
+}
+
+namespace {
+
+template <typename RepeatedFieldT, typename NativeT>
+Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
+ const RepeatedFieldT& src) {
+ if (dest.size() != src.size()) {
+ return InvalidArgument(
+ "Expected %lu elements in LiteralProto repeated field, has %d",
+ dest.size(), src.size());
+ }
+ std::copy(src.begin(), src.end(), dest.begin());
+ return Status::OK();
+}
+
+} // namespace
+
+Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
+ // These conditions should have been checked in Literal::CreateFromProto.
+ TF_RET_CHECK(proto.has_shape());
+ TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+
+ switch (subshape().element_type()) {
+ case PRED:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
+ break;
+ case U8: {
+ auto u8_data = data<uint8>();
+ TF_RET_CHECK(proto.u8s().size() == u8_data.size());
+ std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
+ } break;
+ case S32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
+ break;
+ case S64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
+ break;
+ case U32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
+ break;
+ case U64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
+ break;
+ case F16: {
+ const string& s(proto.f16s());
+ TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+
+ case BF16: {
+ const string& s(proto.bf16s());
+ TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
+ memcpy(untyped_data(), s.data(), s.size());
+ if (!kLittleEndian) {
+ ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
+ }
+ } break;
+ case F32:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
+ break;
+ case F64:
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
+ break;
+ case C64: {
+ auto complex_data = data<complex64>();
+ TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
+ for (int64 i = 0; i < complex_data.size(); ++i) {
+ complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
+ }
+ } break;
+ case TUPLE:
+ LOG(FATAL) << "Should not be called on tuple shapes: "
+ << ShapeUtil::HumanString(subshape());
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ }
+ return Status::OK();
+}
+
+LiteralProto LiteralBase::ToProto() const {
+ LiteralProto proto;
+ root_piece().ForEachSubpiece(
+ [&](const ShapeIndex& index, const Piece& piece) {
+ LiteralProto* proto_piece = &proto;
+ for (int64 i : index) {
+ while (proto_piece->tuple_literals_size() <= i) {
+ proto_piece->add_tuple_literals();
+ }
+ proto_piece = proto_piece->mutable_tuple_literals(i);
+ }
+ piece.WriteToProto(proto_piece);
+ });
+
+ if (LayoutUtil::IsSparseArray(shape())) {
+ CopyToRepeatedField(proto.mutable_sparse_indices(),
+ sparse_indices()->data());
+ }
+
+ return proto;
+}
+
+const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
+ return piece(shape_index).untyped_data();
+}
+
+void* Literal::untyped_data(const ShapeIndex& shape_index) {
+ return piece(shape_index).untyped_data();
+}
+
+int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
+ return piece(shape_index).size_bytes();
+}
+
+string LiteralBase::GetR1U8AsString() const {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(shape().element_type(), U8);
+ return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
+ ShapeUtil::ElementsIn(shape()));
+}
+
+void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
+ CHECK(ShapeUtil::IsTuple(shape));
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const Shape& subshape = shape.tuple_shapes(i);
+
+ auto child_piece = Piece();
+ child_piece.set_subshape(&subshape);
+
+ if (ShapeUtil::IsTuple(subshape)) {
+ BuildPieceSubtree(subshape, &child_piece);
+ }
+
+ piece->emplace_back(std::move(child_piece));
+ }
+}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal)
+ : LiteralBase(), root_piece_(&literal.root_piece()) {}
+
+LiteralSlice::LiteralSlice(const LiteralBase& literal,
+ const ShapeIndex& view_root)
+ : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
+
+BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsArray(*shape_));
+ CHECK(LayoutUtil::HasLayout(*shape_));
+
+ root_piece_ = Piece();
+ root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
+ root_piece_.set_subshape(shape_.get());
+}
+
+BorrowingLiteral::BorrowingLiteral(
+ tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
+ : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
+ CHECK(ShapeUtil::IsTuple(*shape_));
+ CHECK(!ShapeUtil::IsNestedTuple(*shape_));
+ CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_));
+ root_piece_ = Piece();
+ root_piece_.set_subshape(shape_.get());
+ BuildPieceSubtree(*shape_, &root_piece_);
+
+ for (int i = 0; i < src_buf_ptrs.size(); ++i) {
+ const auto& src_shape = shape_->tuple_shapes(i);
+ CHECK(ShapeUtil::IsArray(src_shape));
+ root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
new file mode 100644
index 0000000000..dd67dfa8d4
--- /dev/null
+++ b/tensorflow/compiler/xla/literal.h
@@ -0,0 +1,1152 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_H_
+
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <memory>
+#include <ostream>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/sparse_index_array.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Forward declare Literal and LiteralSlice class to be used by the creation
+// methods in the base class.
+class Literal;
+class LiteralSlice;
+
+// Abstract base class for literals.
+class LiteralBase {
+ public:
+ virtual ~LiteralBase() = 0;
+
+ // Literals are equal if they have compatible shapes and the same data
+ // values. Layout is not compared.
+ bool operator==(const LiteralBase& other) const;
+ bool operator!=(const LiteralBase& other) const { return !(*this == other); }
+
+ // Returns the shape of the literal.
+ const Shape& shape() const { return root_piece().subshape(); }
+
+ // Serialize to proto.
+ LiteralProto ToProto() const;
+
+ // Returns an ArraySlice of the array for this literal for the given NativeT
+ // (e.g., float). CHECKs if the subshape of the literal at the given
+ // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
+ // to native type.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to the sparse index array. Returns nullptr if the
+ // literal is not a sparse array.
+ const SparseIndexArray* sparse_indices(
+ const ShapeIndex& shape_index = {}) const;
+
+ // Returns a const pointer to (or size of) the underlying buffer holding the
+ // array at the given shape index. CHECKs if the subshape of the literal at
+ // the given ShapeIndex is not array.
+ const void* untyped_data(const ShapeIndex& shape_index = {}) const;
+ int64 size_bytes(const ShapeIndex& shape_index = {}) const;
+
+ // Returns this literal's data as a string. This literal must be a rank-1 U8
+ // array.
+ string GetR1U8AsString() const;
+
+ // Returns a string representation of the literal value.
+ // Warning: this function can take minutes for multi-million element Literals.
+ string ToString(bool print_layout = false) const;
+
+ // Gets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const;
+ // Overloads of Get for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the element value at index (0, ..., 0), however many zeroes are
+ // required for that index.
+ template <typename NativeT>
+ NativeT GetFirstElement() const;
+
+ // As Get(), but determines the correct type and converts the value
+ // into text.
+ string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index = {}) const;
+ // As GetSparseElement(), but determines the correct type and converts the
+ // value into text.
+ string GetSparseElementAsString(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+ // As Get(), but determines the correct type and converts the value into
+ // int64. This literal must be an array.
+ StatusOr<int64> GetIntegralAsS64(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const;
+
+ // Returns the multi-index of the element in a sparse literal at the given
+ // sparse element number. The sparse element number is the position with in
+ // the sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
+ int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
+
+ // Returns the value of the element in a sparse literal at the given sparse
+ // element number. The sparse element number is the position with in the
+ // sparse array's list of (index, value) pairs, and is checked against the
+ // total number of (index, value) pairs in the sparse array.
+ template <typename NativeT>
+ NativeT GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index = {}) const;
+
+ // Invokes the "per cell" callback for each element in the provided
+ // literal with the element's indices and a string representation of
+ // the element's value.
+ //
+ // This function is useful if you want a polymorphic representation
+ // of the tensor's elements (turning it to a string for something
+ // like representation in a protobuf).
+ //
+ // This literal must have a dense layout.
+ void EachCellAsString(
+ const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>& per_cell) const;
+ template <typename NativeT>
+ void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const;
+
+ // Returns whether every element in this literal is equal to value.
+ //
+ // value is an int8 because we expect this to be called with small
+ // compile-time constants (0, -1, etc.) and so that whatever value you pass
+ // can be represented exactly by floating-point types as small as 16 bits.
+ //
+ // If value doesn't fit in this literal's type, returns false. Values of 1/0
+ // are considered equal to true/false; other values are not considered equal
+ // to true. Also if this literal is not array-shaped false is returned.
+ bool IsAll(int8 value) const;
+
+ // Like IsAll(const Literal&, int8), except we check whether the literal is
+ // equal to a particular floating-point number.
+ //
+ // If the literal is not a floating-point value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for values that can be expressed precisely as a float,
+ // e.g. -0.5. Also if this literal is not array-shaped false is returned.
+ bool IsAllFloat(float value) const;
+
+ // Like IsAll(const Literal&, int8), except we check whether the literal is
+ // equal to a particular complex number.
+ //
+ // If the literal is not a complex value, this always returns false.
+ //
+ // This casts value to the type of literal, then compares using ==. The usual
+ // admonishments about floating-point equality checks apply. We expect you to
+ // use this to check for complex values that can be expressed precisely as
+ // float pairs e.g. (-0.5, 1.0).
+ //
+ // This literal must have a dense layout.
+ bool IsAllComplex(complex64 value) const;
+
+ // Literal consists entirely of the first element of the literal.
+ bool IsAllFirst() const;
+
+ // Returns whether this literal is zero at the specified index. This literal
+ // must be an array with a dense layout.
+ bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
+
+ // Returns the count of the elements in the array at the given shape index in
+ // this literal.
+ int64 element_count(const ShapeIndex& index = {}) const {
+ return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
+ }
+
+ // Returns the count of the elements in the sparse array at the given shape
+ // index in this literal, which will be no larger than
+ // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
+ int64 sparse_element_count() const;
+
+ // Compute a hash for this literal. This literal must not be a sparse tensor
+ // or a tuple containing a sparse tensor.
+ size_t Hash() const;
+
+ // Converts this literal to the given shape. Returns an error is the
+ // conversion is not possible.
+ //
+ // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
+ // instead of truncation; otherwise, truncation is used.
+ //
+ // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
+ // the default behavior.
+ StatusOr<std::unique_ptr<Literal>> ConvertToShape(
+ const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+
+ // Converts this literal to another primitive type using a bitcast
+ // conversion. The to and from primitive types must have the same bit
+ // width. Returns an error if the conversion is not possible. This literal
+ // must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> BitcastConvert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Converts this literal to another primitive type. Returns an error if the
+ // conversion is not possible. This literal must be array-shaped.
+ StatusOr<std::unique_ptr<Literal>> Convert(
+ PrimitiveType primitive_dest_type) const;
+
+ // Clones the underlying buffers into a new Literal, or new
+ // std::unique_ptr<Literal>.
+ Literal Clone() const;
+ std::unique_ptr<Literal> CloneToUnique() const;
+
+ // TODO(b/67651157): The methods below which perform computation on Literals
+ // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
+ // evaluator code which operates on Literals.
+ //
+ // Creates a new value that has the equivalent value as this
+ // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
+ // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
+ // minor-to-major dimension layout and the value in the cell at any given
+ // logical index (i0, i1) will be the same.
+ //
+ // For tuple shaped literals, shape_index should be used to select the inner
+ // array that the new layout applies to.
+ //
+ // Note: this is useful when the client wants to ensure that a value placed in
+ // the XLA allocation tracker has a particular layout; for efficiency
+ // purposes or avoiding unimplemented operation/layout combinations.
+ std::unique_ptr<Literal> Relayout(const Layout& new_layout,
+ const ShapeIndex& shape_index = {}) const;
+
+ // An overload of Relayout which changes the layout of the entire shape rather
+ // than being limited to a single array within the shape.
+ std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
+
+ // Creates a new literal by reshaping this literal to have the given
+ // dimensions. The total number of elements must not change; The
+ // implementation currently only supports monotonic dim0-major layouts.
+ // This literal must be an array.
+ StatusOr<std::unique_ptr<Literal>> Reshape(
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+ // Creates a new literal by broadcasting this literal with `dimensions` to
+ // yield a literal of shape `result_shape`.
+ StatusOr<std::unique_ptr<Literal>> Broadcast(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
+ // Creates a new literal by reordering the dimensions of this literal.
+ // The given `permutation` must be a permutation of the dimension numbers
+ // in the original literal, and it specifies the order of the new dimensions
+ // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+ // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+ // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ // This literal must be an array.
+ std::unique_ptr<Literal> Transpose(
+ tensorflow::gtl::ArraySlice<int64> permutation) const;
+
+ // Creates a sub-array from this literal by extracting the indices
+ // [start_index, limit_index) of each dimension. The result literal has the
+ // same rank and layout as for the given literal. The number of indices in
+ // start_indices and limit_indices must be the rank of the literal, and the
+ // indices follow the order of the dimensions.
+ // This literal must be an array.
+ std::unique_ptr<Literal> Slice(
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) const;
+
+ // Creates a literal with a prepended dimension with bound "times"; e.g. a
+ // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
+ // literal replicated four times.
+ // This literal must be an array.
+ template <typename NativeT>
+ std::unique_ptr<Literal> Replicate(int64 times) const;
+
+ // Creates a new Literal object with the shape specified as parameter.
+ // The content of the literal values is the default value of the primitive
+ // type of literal itself (0 for numeric types, and false for predicates).
+ //
+ // Note: It's an antipattern to use this method then immediately call
+ // Literal::Populate on the result (since that results in zero initialization,
+ // then reinitialization. Conside if a call to MakeUnique<Literal>(shape),
+ // followed by the call to Literal::Populate can be used instead.
+ static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
+
+ protected:
+ // A data structure representing a subshape at a particular ShapeIndex within
+ // the literal. For array-shaped ShapeIndexes, this data structure holds the
+ // pointer to the memory allocated for the array data.
+ class Piece {
+ public:
+ // Returns the buffer holding the array data for this piece as an array
+ // slice. This piece must be array-shaped.
+ template <typename NativeT>
+ tensorflow::gtl::ArraySlice<NativeT> data() const;
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data();
+
+ // Returns the buffer holding the array data for this piece as a void*. This
+ // piece must be array-shaped.
+ void* untyped_data();
+ const void* untyped_data() const;
+
+ // Gets or sets an element in the array at the given index. The multi_index
+ // is CHECKed against the dimension sizes of the array. This piece must be
+ // array-shaped.
+ template <typename NativeT>
+ NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
+
+ // Gets/sets the buffer holding the array data.
+ char* buffer() const { return buffer_; }
+ void set_buffer(char* buffer) { buffer_ = buffer; }
+
+ // The array of multi-indices that provide the locations of non-zero
+ // elements in a sparse array. Only used if
+ // LayoutUtil::IsSparseArray(shape()) is true.
+ SparseIndexArray* sparse_indices() const { return sparse_indices_; }
+ void set_sparse_indices(SparseIndexArray* sparse_indices) {
+ sparse_indices_ = sparse_indices;
+ }
+
+ // Gets or sets the subshape of this piece. This reference points to a
+ // subshape within the shape in the containing Literal (Literal::shape_).
+ const Shape& subshape() const { return *subshape_; }
+ void set_subshape(const Shape* subshape) { subshape_ = subshape; }
+
+ // Returns the size in bytes of the buffer holding the array data.
+ int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
+
+ // Returns the number of elements in this piece's array.
+ int64 element_count() const {
+ // If this is a sparse array, use the number of elements represented by
+ // the indices in the associated SparseIndexArray.
+ return LayoutUtil::IsSparseArray(subshape())
+ ? sparse_indices()->index_count()
+ : ShapeUtil::ElementsIn(subshape());
+ }
+
+ // Returns the child piece at 'index' of this piece.
+ Piece& child(int64 index) { return children_[index]; }
+
+ // Adds a child piece to this piece's children.
+ void emplace_back(Piece child_piece) {
+ children_.emplace_back(std::move(child_piece));
+ }
+
+ // Returns the size of children pieces of this piece.
+ int64 children_size() { return children_.size(); }
+
+ // Visitor functions that recursively traverses the piece and calls the
+ // given function at each child piece. The function has the type:
+ // void (const ShapeIndex& index, const Piece& piece)
+ template <typename Fn>
+ void ForEachSubpiece(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(
+ [&func](const ShapeIndex& index, const Piece& piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ *this, &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, const Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachSubpieceWithStatus(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelper(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // Bool (const ShapeIndex& index, const Piece& piece)
+ // The first non-true return value is returned by the function.
+ template <typename Fn>
+ bool ForEachSubpieceWithBool(const Fn& func) const {
+ ShapeIndex index;
+ return ForEachHelperBool(func, *this, &index);
+ }
+ // Same as above, but the function has the type:
+ // Void (const ShapeIndex& index, Piece& piece)
+ template <typename Fn>
+ void ForEachMutableSubpiece(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ [&func](const ShapeIndex& index, Piece* piece) {
+ func(index, piece);
+ return Status::OK();
+ },
+ const_cast<xla::LiteralBase::Piece*>(this), &index)
+ .IgnoreError();
+ }
+ // Same as above, but the function has the type:
+ // Status (const ShapeIndex& index, Piece& piece)
+ // The first non-OK return value is returned by the function.
+ template <typename Fn>
+ Status ForEachMutableSubpieceWithStatus(const Fn& func) {
+ ShapeIndex index;
+ return ForEachMutableHelper(
+ func, const_cast<xla::LiteralBase::Piece*>(this), &index);
+ }
+
+ // Returns true if this piece and 'other' contain the same data. This piece
+ // and 'other' must be array-shaped and compatible.
+ bool EqualElements(const Piece& other) const;
+
+ // Writes the shape and data (if array-shaped) into the given proto.
+ void WriteToProto(LiteralProto* proto) const;
+
+ // Copy the data from 'src' into this piece's buffer. Shapes of this piece
+ // and src must be compatible.
+ Status CopyFrom(const Piece& src);
+
+ // Copies the data from the given proto into this piece. The shape of this
+ // piece must be equal (not just compatible) to the shape of the proto.
+ Status CopyFromProto(const LiteralProto& proto);
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements();
+
+ private:
+ // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
+ // The first non-OK (or non-true) value is returned by the function.
+ // The callable 'func' has the same signature as described above in
+ // ForEachSubpiece*.
+ template <typename Fn>
+ Status ForEachHelper(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+ template <typename Fn>
+ bool ForEachHelperBool(const Fn& func, const Piece& piece,
+ ShapeIndex* index) const {
+ if (!func(*index, piece)) {
+ return false;
+ }
+ for (int64 i = 0; i < piece.children_.size(); ++i) {
+ index->push_back(i);
+ if (!ForEachHelperBool(func, piece.children_[i], index)) {
+ return false;
+ }
+ index->pop_back();
+ }
+ return true;
+ }
+ template <typename Fn>
+ Status ForEachMutableHelper(const Fn& func, Piece* piece,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(*index, piece));
+ for (int64 i = 0; i < piece->children_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(
+ ForEachMutableHelper(func, &piece->children_[i], index));
+ index->pop_back();
+ }
+ return Status::OK();
+ }
+
+ // Recursive helper for EqualElements.
+ template <typename NativeT>
+ bool EqualElementsInternal(const Piece& other,
+ std::vector<int64>* multi_index) const;
+
+ // Helper for SortSparseElements that has the element type as a template
+ // parameter.
+ template <typename NativeT>
+ void SortSparseElementsInternal();
+
+ // For array-shaped pieces, this is the buffer holding the literal data.
+ char* buffer_ = nullptr;
+
+ // For sparse arrays, this is the array of indices.
+ SparseIndexArray* sparse_indices_ = nullptr;
+
+ // The shape of piece. This points into the shape of the containing Literal
+ // (Literal::shape_).
+ const Shape* subshape_ = nullptr;
+
+ // Children pieces for tuple shaped pieces.
+ std::vector<Piece> children_ = {};
+ }; // class Piece
+
+ const Piece& piece(const ShapeIndex& shape_index) const {
+ Piece* piece = &const_cast<Piece&>(root_piece());
+ for (const auto i : shape_index) {
+ DCHECK_GE(i, 0);
+ DCHECK_LT(i, piece->children_size());
+ piece = &piece->child(i);
+ }
+ return *piece;
+ }
+
+ // Returns the piece at the root of the shape.
+ virtual const Piece& root_piece() const = 0;
+
+ // LiteralSlice and Literal must access Pieces of other Literals.
+ friend class Literal;
+ friend class LiteralSlice;
+ friend class BorrowingLiteral;
+
+ private:
+ template <typename NativeT>
+ std::unique_ptr<Literal> SliceInternal(
+ const Shape& result_shape,
+ tensorflow::gtl::ArraySlice<int64> start_indices) const;
+};
+
+// Class representing literal values in XLA.
+//
+// The underlying buffer and shape is always owned by this class.
+class Literal : public LiteralBase {
+ public:
+ Literal() : Literal(ShapeUtil::MakeNil()) {}
+
+ // Create a literal of the given shape. The literal is allocated sufficient
+ // memory to hold the shape. Memory is uninitialized.
+ explicit Literal(const Shape& shape);
+ virtual ~Literal();
+
+ // Literals are moveable, but not copyable. To copy a literal use
+ // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
+ // of literals which can be expensive.
+ Literal(const Literal& other) = delete;
+ Literal& operator=(const Literal& other) = delete;
+ Literal(Literal&& other);
+ // 'allocate_arrays' indicates whether to allocate memory for the arrays in
+ // the shape. If false, buffer pointers inside of the Literal::Pieces are set
+ // to nullptr.
+ Literal(const Shape& shape, bool allocate_arrays);
+ Literal& operator=(Literal&& other);
+
+ // TODO(b/67651157): Remove this accessor. Literal users should not be able to
+ // mutate the shape as this can produce malformed Literals.
+ Shape* mutable_shape_do_not_use() { return shape_.get(); }
+
+ // Returns a MutableArraySlice view of the array for this literal for the
+ // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
+ // given ShapeIndex is not array. See primitive_util.h for the mapping from
+ // XLA type to native type.
+ template <typename NativeT>
+ tensorflow::gtl::MutableArraySlice<NativeT> data(
+ const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::data;
+
+ // Returns a pointer to the sparse index array. Returns nullptr if the literal
+ // is not a sparse array.
+ SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
+
+ // Returns a pointer to the underlying buffer holding the array at the given
+ // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
+ // is not array.
+ void* untyped_data(const ShapeIndex& shape_index = {});
+ // Unhide const method from parent class.
+ using LiteralBase::untyped_data;
+
+ // Populates a literal with a sparse layout with the given indices and values.
+ // Each index in the indices array is CHECKed against the dimensions in the
+ // literal's shape. If sort is true, then the indices and values will be
+ // sorted. If sort is false, then the indices and values are assumed to
+ // already be in sorted order. See CreateSparse for an example of how data
+ // are populated.
+ template <typename NativeT>
+ void PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort = true);
+
+ // Copy values from 'src_literal' rooted at 'src_shape_index' into this
+ // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
+ // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
+ // rooted at 'src_shape_index', but need not be arrays.
+ Status CopyFrom(const LiteralSlice& src_literal,
+ const ShapeIndex& dest_shape_index = {},
+ const ShapeIndex& src_shape_index = {});
+
+ // Returns a vector containing the tuple elements of this Literal as separate
+ // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
+ // elements are moved into the new Literals; no data is copied. Upon return
+ // this Literal is set to a nil shape (empty tuple)
+ std::vector<Literal> DecomposeTuple();
+
+ // Similar to CopyFrom, but with move semantincs. The subshape of this literal
+ // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+ // (layouts and shapes must match), but need not be arrays. The memory
+ // allocated in this literal for the subshape at dest_shape_index is
+ // deallocated, and the respective buffers are replaced with those in
+ // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+ Status MoveFrom(Literal&& src_literal,
+ const ShapeIndex& dest_shape_index = {});
+
+ // Copies the values from src_literal, starting at src_base shape indexes,
+ // to this literal, starting at dest_base, where the copy size in each
+ // dimension is specified by copy_size.
+ // The src_literal and this literal must have the same primitive type,
+ // src_base+copy_size must fit the source literal dimensions, as well as
+ // dest_base+copy_size must fit the destination literal dimensions.
+ // Note: if either src_literal or this literal contains dimensions with zero
+ // element, then copy_size must be 0 in these dimensions while the
+ // corresponding base indices being 0.
+ // This literal and 'src_literal' must be arrays.
+ Status CopySliceFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Copies one element from src_literal[src_index] to (*this)[dest_index].
+ Status CopyElementFrom(const LiteralSlice& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_index,
+ tensorflow::gtl::ArraySlice<int64> dest_index);
+
+ // Sets an element in the literal at the given index. The multi_index is
+ // CHECKed against the dimension sizes.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value);
+ // Overloads of Set for array literals. CHECKs if the literal is not
+ // array-shaped and dense.
+ template <typename NativeT>
+ void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
+
+ // Appends the given element to the literal. If the elements are not appended
+ // in sorted order, then SortSparseElements should be called before calling
+ // other methods. This literal must have a sparse layout.
+ template <typename NativeT>
+ void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value, const ShapeIndex& shape_index = {});
+
+ // Sorts the elements in a sparse array.
+ void SortSparseElements(const ShapeIndex& shape_index = {});
+
+ // As Set(), but truncates `value` to the literal element type before storing.
+ // This literal must be an array.
+ Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value);
+
+ // Populate this literal with the given values. Examples:
+ //
+ // // Populate with floats.
+ // Array2D<float> float_values = ...
+ // literal.PopulateR2FromArray2D(values);
+ //
+ // // Populate with int32s.
+ // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
+ //
+ // The shape and element type of this literal must match given values. For
+ // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+ // array of S32.
+ template <typename NativeT>
+ void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ void PopulateR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ void PopulateFromArray(const Array<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ void PopulateR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Populates literal values by calling the generator function for every cell
+ // in this literal object.
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ //
+ // This literal must have a dense layout.
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
+
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
+
+ // Fills this literal with the given value.
+ template <typename NativeT>
+ void PopulateWithValue(NativeT value);
+
+ // This operation is the inverse of DecomposeTuple. The given elements are
+ // moved into the tuple elements of a new tuple-shaped Literal which is
+ // returned. Upon return, each of the Literals in 'elements' is set to a nil
+ // shape (empty tuple).
+ static Literal MoveIntoTuple(
+ tensorflow::gtl::MutableArraySlice<Literal> elements);
+
+ // Serialize from a proto.
+ static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
+ const LiteralProto& proto);
+
+ private:
+ // Recursively sets the subshapes and buffers of all subpieces rooted at
+ // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
+ // the shape.
+ void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
+
+ // Returns the piece at the given ShapeIndex.
+ Piece& piece(const ShapeIndex& shape_index) {
+ return const_cast<Piece&>(LiteralBase::piece(shape_index));
+ }
+
+ Piece& root_piece() const override { return *root_piece_; };
+
+ // Internal template helper for the Literal::CopySliceFrom(), matching its
+ // arguments one by one.
+ template <typename NativeT>
+ Status CopySliceFromInternal(const LiteralBase& src_literal,
+ tensorflow::gtl::ArraySlice<int64> src_base,
+ tensorflow::gtl::ArraySlice<int64> dest_base,
+ tensorflow::gtl::ArraySlice<int64> copy_size);
+
+ // Utility structure which is used to create the optimal configuration for
+ // a ShapeUtil::ForEachIndex() scan across two literals.
+ struct StrideConfig {
+ StrideConfig(const Shape& source_shape, const Shape& dest_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // The dimensions of the stride operation. Essentially every dimension
+ // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
+ // steps.
+ tensorflow::gtl::ArraySlice<int64> dimensions;
+ DimensionVector base;
+ DimensionVector step;
+ int64 minor_dimension = 0;
+ // The size of the strides for source and destination. One of the two
+ // (the one looping through its most minor dimension) will be 1, while
+ // the other will be the stride size at the dimension matching the other
+ // shape most minor dimension being scanned.
+ int64 dest_stride = 1;
+ int64 source_stride = 1;
+ // The size of the inner loop on the most minor dimension.
+ int64 minor_loop_size = 1;
+ };
+
+ // Literal class always owns the shape. The parent class borrows this shape.
+ std::unique_ptr<Shape> shape_;
+
+ Piece* root_piece_ = nullptr;
+
+ // Implementation details shared between Populate() and PopulateParallel()
+ template <typename NativeT, typename FnType>
+ Status PopulateInternal(const FnType& generator, bool parallel);
+
+ // Deallocate the buffers held by this literal.
+ void DeallocateBuffers();
+
+ friend class LiteralBase;
+};
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
+
+// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
+// literal buffers always owned by others.
+class LiteralSlice : public LiteralBase {
+ public:
+ LiteralSlice() : LiteralBase() {}
+
+ // Implicit conversion constructors.
+ LiteralSlice(const LiteralBase& literal);
+ LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
+
+ private:
+ const Piece& root_piece() const override { return *root_piece_; };
+
+ const Piece* root_piece_; // Not owned.
+};
+
+// A read-only Literal where the underlying buffers are never owned by this
+// class.
+class BorrowingLiteral : public LiteralBase {
+ public:
+ BorrowingLiteral() : LiteralBase() {}
+
+ // 'src_buf_ptr' is not owned by this class and must outlive the
+ // lifetime of this class. It points to an appropirately sized buffer with
+ // data interpretered as indicated by 'shape'.
+ // This constructor is only used for array shapes.
+ BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
+ // Similar as above, except to be used for constructing non-nested tuples.
+ BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
+ const Shape& shape);
+ // TODO(b/79707221): adding constructors for nested tuples as well.
+
+ private:
+ // Recursively builds the subtree for the given piece and sets the subshapes
+ // of the given piece with the given shape.
+ void BuildPieceSubtree(const Shape& shape, Piece* piece);
+
+ // Accessor for the root piece of this literal.
+ const Piece& root_piece() const override { return root_piece_; };
+ Piece root_piece_;
+
+ // Shape of this literal. Stored as unique_ptr so such that the (default)
+ // move construction of this class would be trivially correct: the pointer to
+ // Shape root_piece_ stores will still point to the correct address.
+ std::unique_ptr<Shape> shape_;
+};
+
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::ArraySlice<NativeT>(
+ reinterpret_cast<const NativeT*>(buffer()), element_count());
+}
+
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
+ CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ CHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
+ << "Attempting to access "
+ << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
+ << " type, but literal element type is "
+ << PrimitiveType_Name(subshape().element_type());
+ return tensorflow::gtl::MutableArraySlice<NativeT>(
+ reinterpret_cast<NativeT*>(buffer()), element_count());
+}
+
+template <typename NativeT>
+NativeT LiteralBase::Piece::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
+ return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)];
+}
+
+template <typename NativeT>
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ CHECK(LayoutUtil::IsDenseArray(subshape()));
+ data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
+ subshape(), multi_index)] = value;
+}
+
+template <typename NativeT>
+tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).data<NativeT>();
+}
+
+template <typename NativeT>
+tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
+ const ShapeIndex& shape_index) {
+ return piece(shape_index).data<NativeT>();
+}
+
+template <typename NativeT>
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index) const {
+ return piece(shape_index).Get<NativeT>(multi_index);
+}
+
+template <typename NativeT>
+inline NativeT LiteralBase::Get(
+ tensorflow::gtl::ArraySlice<int64> multi_index) const {
+ return root_piece().Get<NativeT>(multi_index);
+}
+
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ const ShapeIndex& shape_index, NativeT value) {
+ return piece(shape_index).Set<NativeT>(multi_index, value);
+}
+
+template <typename NativeT>
+inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ return root_piece().Set<NativeT>(multi_index, value);
+}
+
+template <typename NativeT>
+NativeT LiteralBase::GetFirstElement() const {
+ return data<NativeT>().at(0);
+}
+
+template <typename NativeT>
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+ const ShapeIndex& shape_index) const {
+ CHECK(
+ LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
+ return data<NativeT>(shape_index)[sparse_element_number];
+}
+
+template <typename NativeT>
+void Literal::AppendSparseElement(
+ tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
+ const ShapeIndex& shape_index) {
+ Piece& p = piece(shape_index);
+ const Shape& subshape = p.subshape();
+ CHECK(LayoutUtil::IsSparseArray(subshape));
+ int64 rank = ShapeUtil::Rank(subshape);
+ CHECK_EQ(multi_index.size(), rank);
+ int64 last_element = p.sparse_indices()->index_count();
+ CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
+ p.sparse_indices()->Append(multi_index);
+ CHECK_LT(last_element, p.data<NativeT>().size());
+ p.data<NativeT>()[last_element] = value;
+}
+
+template <typename NativeT>
+void LiteralBase::EachCell(
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) const {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
+ return;
+ }
+ std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
+ do {
+ per_cell(indices, Get<NativeT>(indices));
+ } while (IndexUtil::BumpIndices(shape(), &indices));
+}
+
+template <typename NativeT>
+inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 1);
+ CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ for (int64 i = 0; i < values.size(); ++i) {
+ Set({i}, values[i]);
+ }
+}
+
+template <typename NativeT>
+void Literal::PopulateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(ShapeUtil::Rank(shape()), 2);
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+
+ const int64 dim0_size = values.size();
+ const int64 dim1_size = values.begin()->size();
+ CHECK_EQ(dim0_size, shape().dimensions(0));
+ CHECK_EQ(dim1_size, shape().dimensions(1));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto value : inner_list) {
+ Set({dim0, dim1}, value);
+ ++dim1;
+ }
+ CHECK_EQ(dim1_size, dim1);
+ ++dim0;
+ }
+}
+
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
+ for (int dim = 0; dim < values.num_dimensions(); ++dim) {
+ CHECK_EQ(values.dim(dim), shape().dimensions(dim));
+ }
+ values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value) { this->Set(indices, value); });
+}
+
+template <typename NativeT>
+void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
+ PopulateFromArray(values);
+}
+
+template <typename NativeT>
+void Literal::PopulateSparse(SparseIndexArray indices,
+ tensorflow::gtl::ArraySlice<NativeT> values,
+ bool sort) {
+ CHECK(LayoutUtil::IsSparseArray(shape()));
+ int rank = ShapeUtil::Rank(shape());
+ CHECK_EQ(indices.rank(), rank);
+ int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
+ CHECK_LE(indices.max_indices(), max_elements);
+ int64 num_elements = values.size();
+ CHECK_LE(num_elements, max_elements);
+ CHECK_EQ(num_elements, indices.index_count());
+ auto root_data = root_piece().data<NativeT>();
+ // Piece::data() returns an ArraySlice of size equal to the number of indices
+ // in the SparseIndexArray. So there is no need to adjust the size of the data
+ // here. It is enough to just copy the incoming values into the data buffer.
+ std::copy(values.begin(), values.end(), root_data.begin());
+ *this->root_piece().sparse_indices() = std::move(indices);
+ if (sort) {
+ auto root_data = this->root_piece().data<NativeT>();
+ this->root_piece().sparse_indices()->SortWithValues(root_data);
+ }
+ DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
+ const Shape& this_shape = shape();
+ const int64 rank = ShapeUtil::Rank(this_shape);
+ TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
+ TF_RET_CHECK(this_shape.element_type() ==
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
+ if (rank > 0) {
+ StrideConfig stride_config(this_shape, this_shape,
+ AsInt64Slice(this_shape.dimensions()));
+ int64 minor_dimension_size =
+ ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
+
+ auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ DimensionVector minor_scan_indexes(rank, 0);
+ const int64 index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
+ std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
+ for (int64 i = 0; i < minor_dimension_size; ++i) {
+ minor_scan_indexes[stride_config.minor_dimension] = i;
+ literal_data.at(index + i) = generator(minor_scan_indexes);
+ }
+ };
+ if (parallel) {
+ ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
+ stride_config.dimensions,
+ stride_config.step, init_function);
+ } else {
+ ShapeUtil::ForEachIndex(
+ this_shape, stride_config.base, stride_config.dimensions,
+ stride_config.step,
+ [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+ init_function(indexes);
+ return true;
+ });
+ }
+ } else {
+ // For scalars.
+ literal_data.at(0) = generator({});
+ }
+ return Status::OK();
+}
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/false);
+}
+
+template <typename NativeT, typename FnType>
+Status Literal::PopulateParallel(const FnType& generator) {
+ return PopulateInternal<NativeT>(generator, /*parallel=*/true);
+}
+
+template <typename NativeT>
+void Literal::PopulateWithValue(NativeT value) {
+ CHECK(ShapeUtil::IsArray(shape()));
+ CHECK_EQ(shape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>());
+ for (NativeT& element : data<NativeT>()) {
+ element = value;
+ }
+}
+
+template <typename NativeT>
+std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
+ DimensionVector bounds = {times};
+ bounds.reserve(shape().dimensions_size() + 1);
+ for (int64 bound : shape().dimensions()) {
+ bounds.push_back(bound);
+ }
+ auto literal =
+ MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
+ int64 elements = ShapeUtil::ElementsIn(literal->shape());
+ if (elements == 0) {
+ return literal;
+ }
+
+ DimensionVector output_indices(bounds.size(), 0);
+ tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
+ input_indices.remove_prefix(1);
+
+ bool done = false;
+ while (!done) {
+ const auto element = Get<NativeT>(input_indices);
+ literal->Set<NativeT>(output_indices, element);
+
+ done = true;
+ for (int n = 0; n < output_indices.size(); ++n) {
+ ++output_indices[n];
+ if (output_indices[n] < bounds[n]) {
+ done = false;
+ break;
+ }
+ output_indices[n] = 0;
+ }
+ }
+ return literal;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LITERAL_H_
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index bf9679cafe..94993cc874 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <cmath>
#include <vector>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -217,7 +218,7 @@ class NearComparator {
return Printf(
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
- Literal::MultiIndexAsString(
+ LiteralUtil::MultiIndexAsString(
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
linear_index))
.c_str(),
@@ -606,8 +607,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
} // namespace
Status EqualShapes(const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return InvalidArgument("tupleness-mismatch! want: %s got %s",
+ if (expected.element_type() != actual.element_type()) {
+ return InvalidArgument("element type mismatch, want: %s got %s",
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
@@ -626,7 +627,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return AppendStatus(result, StrCat("mismatch in tuple index", i));
}
}
- } else {
+ } else if (ShapeUtil::IsArray(expected)) {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return InvalidArgument("want rank of %s got rank of %s",
ShapeUtil::HumanString(expected).c_str(),
@@ -652,6 +653,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
}
}
}
+ // Non-array, non-tuple shapes are trivially equivalent.
return Status::OK();
}
@@ -705,6 +707,9 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
}
break;
}
+ case TOKEN:
+ // Tokens have no on-device representation and are trivially equal.
+ return Status::OK();
default:
LOG(FATAL)
<< "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
@@ -718,7 +723,7 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
return AppendStatus(result,
tensorflow::strings::Printf(
"\nat index: %s\nexpected: %s\nactual: %s",
- Literal::MultiIndexAsString(multi_index).c_str(),
+ LiteralUtil::MultiIndexAsString(multi_index).c_str(),
ToStringTruncated(expected).c_str(),
ToStringTruncated(actual).c_str()));
}
diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h
index 00a13e3619..9e5bf7c1d0 100644
--- a/tensorflow/compiler/xla/literal_comparison.h
+++ b/tensorflow/compiler/xla/literal_comparison.h
@@ -20,7 +20,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
#include "tensorflow/compiler/xla/error_spec.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/lib/core/status.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_test.cc
index f127cee0fd..e8f919950f 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include <vector>
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
@@ -76,11 +77,11 @@ class LiteralUtilTest : public ::testing::Test {
layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
literal_r4_2x2x3x3_dim0major_ =
- Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
- layout_r4_dim0major_);
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0major_);
literal_r4_2x2x3x3_dim0minor_ =
- Literal::CreateR4FromArray4DWithLayout<float>(arr4d,
- layout_r4_dim0minor_);
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0minor_);
}
Layout layout_r2_dim0major_;
@@ -94,47 +95,47 @@ class LiteralUtilTest : public ::testing::Test {
};
TEST_F(LiteralUtilTest, LiteralScalarToString) {
- auto true_lit = Literal::CreateR0<bool>(true);
+ auto true_lit = LiteralUtil::CreateR0<bool>(true);
ASSERT_EQ("true", true_lit->ToString());
- auto false_lit = Literal::CreateR0<bool>(false);
+ auto false_lit = LiteralUtil::CreateR0<bool>(false);
ASSERT_EQ("false", false_lit->ToString());
- auto u32_lit = Literal::CreateR0<uint32>(42);
+ auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
ASSERT_EQ("42", u32_lit->ToString());
- auto s32_lit = Literal::CreateR0<int32>(-999);
+ auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
ASSERT_EQ("-999", s32_lit->ToString());
- auto f32_lit = Literal::CreateR0<float>(3.14f);
+ auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
ASSERT_EQ("3.14", f32_lit->ToString());
- auto f16_lit = Literal::CreateR0<half>(static_cast<half>(0.5f));
+ auto f16_lit = LiteralUtil::CreateR0<half>(static_cast<half>(0.5f));
ASSERT_EQ("0.5", f16_lit->ToString());
- auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
+ auto c64_lit = LiteralUtil::CreateR0<complex64>({3.14f, 2.78f});
ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
- auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
+ auto bf16_lit = LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
ASSERT_EQ("0.5", bf16_lit->ToString());
// 3.14 will be truncated to 3.125 in bfloat16 format.
auto bf16_lit_truncated =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
+ LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.125", bf16_lit_truncated->ToString());
auto bf16_lit_truncated2 =
- Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
+ LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
ASSERT_EQ("9", bf16_lit_truncated2->ToString());
}
TEST_F(LiteralUtilTest, LiteralVectorToString) {
- auto pred_vec = Literal::CreateR1<bool>({true, false, true});
+ auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
ASSERT_EQ("{101}", pred_vec->ToString());
}
TEST_F(LiteralUtilTest, R2ToString) {
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}});
+ const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
const string expected = R"(s32[3,2] {
{ 1, 2 },
{ 3, 4 },
@@ -144,7 +145,8 @@ TEST_F(LiteralUtilTest, R2ToString) {
}
TEST_F(LiteralUtilTest, R3ToString) {
- const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
+ const auto literal =
+ LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
const string expected = R"(s32[3,2,1] {
{ { 1 },
{ 2 } },
@@ -157,9 +159,9 @@ TEST_F(LiteralUtilTest, R3ToString) {
}
TEST_F(LiteralUtilTest, TupleToString) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
const string expected = R"((f32[], f32[2,2]) (
1,
f32[2,2] {
@@ -182,7 +184,7 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
});
// clang-format on
- auto literal = Literal::CreateR3FromArray3D(array_3d);
+ auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2));
string result = literal->ToString();
const string expected = R"(f32[2,3,2] {
@@ -205,7 +207,7 @@ TEST_F(LiteralUtilTest, CreateSparse) {
{3, 5, 6},
};
std::vector<int64> values = {7, 8, 9, 10};
- auto literal = Literal::CreateSparse<int64>(
+ auto literal = LiteralUtil::CreateSparse<int64>(
dimensions, SparseIndexArray(indices.n1() + 3, indices), values);
Array2D<int64> expected_indices = {
@@ -224,7 +226,7 @@ TEST_F(LiteralUtilTest, CreateSparse) {
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
// clang-format off
- auto literal = Literal::CreateR4Projected<float>({
+ auto literal = LiteralUtil::CreateR4Projected<float>({
{1, 2},
{1001, 1002},
{2001, 2002},
@@ -284,7 +286,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
TEST_F(LiteralUtilTest, EachCellR2F32) {
// clang-format off
- auto literal = Literal::CreateR2<float>({
+ auto literal = LiteralUtil::CreateR2<float>({
{3.1f, 4.2f},
{9.3f, 12.4f},
});
@@ -303,26 +305,27 @@ TEST_F(LiteralUtilTest, EachCellR2F32) {
TEST_F(LiteralUtilTest, ScalarEquality) {
// Test equality with scalars.
- auto f32_42 = Literal::CreateR0<float>(42.0);
- auto f32_42_clone = Literal::CreateR0<float>(42.0);
+ auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
EXPECT_EQ(*f32_42, *f32_42);
EXPECT_EQ(*f32_42, *f32_42_clone);
- auto f32_123 = Literal::CreateR0<float>(123.0);
+ auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
EXPECT_NE(*f32_42, *f32_123);
- auto f64_42 = Literal::CreateR0<double>(42.0);
+ auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
EXPECT_NE(*f32_42, *f64_42);
}
TEST_F(LiteralUtilTest, NonScalarEquality) {
// Test equality with nonscalars.
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_clone = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_different = Literal::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
- auto vector_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
- auto scalar = Literal::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_different =
+ LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
+ auto vector_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
Literal nil(ShapeUtil::MakeNil());
EXPECT_EQ(*matrix, *matrix);
@@ -334,6 +337,22 @@ TEST_F(LiteralUtilTest, NonScalarEquality) {
EXPECT_EQ(nil, nil);
}
+TEST_F(LiteralUtilTest, TokenEquality) {
+ auto token0 = LiteralUtil::CreateToken();
+ auto token1 = LiteralUtil::CreateToken();
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+
+ EXPECT_EQ(*token0, *token1);
+ EXPECT_NE(*token0, *scalar);
+
+ EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get()}),
+ *LiteralUtil::MakeTuple({token0.get()}));
+ EXPECT_EQ(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
+ *LiteralUtil::MakeTuple({token1.get(), scalar.get()}));
+ EXPECT_NE(*LiteralUtil::MakeTuple({token0.get(), scalar.get()}),
+ *LiteralUtil::MakeTuple({scalar.get(), token1.get()}));
+}
+
TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
// Test equality with literals which have different layouts.
auto colmajor =
@@ -355,43 +374,46 @@ TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
TEST_F(LiteralUtilTest, TupleEquality) {
// Test equality with tuples.
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
- auto scalar_clone = Literal::CreateR0<float>(1.0);
- auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()});
+ auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
+ auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
EXPECT_EQ(*tuple1, *tuple2);
// Tuple with elements reversed.
- auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()});
+ auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
EXPECT_NE(*tuple1, *reversed_tuple);
// Tuple with different value.
- auto scalar_42 = Literal::CreateR0<float>(42.0);
- auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()});
+ auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto different_tuple =
+ LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
EXPECT_NE(*tuple1, *different_tuple);
}
TEST_F(LiteralUtilTest, C64Equality) {
// Test equality with tuples.
- auto vector = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
// Tuple with the same elements. One element is shared with the original
// tuple, the other is a clone of the element in the original tuple.
- auto vector_clone = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector_clone =
+ LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
EXPECT_EQ(*vector, *vector_clone);
- auto vector_reversed = Literal::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
+ auto vector_reversed =
+ LiteralUtil::CreateR1<complex64>({{3.0, 4.0}, {1.0, 2.0}});
EXPECT_NE(*vector, *vector_reversed);
}
TEST_F(LiteralUtilTest, IsAllTuple) {
- auto element1 = Literal::CreateR0<float>(0.0);
- auto element2 = Literal::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
- auto tuple = Literal::MakeTuple({element1.get(), element1.get()});
+ auto element1 = LiteralUtil::CreateR0<float>(0.0);
+ auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
+ auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
// Tuples should always return false for IsAll.
EXPECT_FALSE(tuple->IsAll(0));
@@ -400,140 +422,141 @@ TEST_F(LiteralUtilTest, IsAllTuple) {
// Verifies that CreateFromShape works for tuples.
TEST_F(LiteralUtilTest, CreateFromShapeTuple) {
- auto scalar = Literal::CreateR0<float>(0.0);
- auto matrix = Literal::CreateR2<int32>({{0, 0}, {0, 0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(0.0);
+ auto matrix = LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
auto x = Literal::CreateFromShape(tuple->shape());
EXPECT_EQ(*tuple, *x);
}
TEST_F(LiteralUtilTest, IsAll) {
- EXPECT_TRUE(Literal::CreateR0<bool>(false)->IsAll(0));
- EXPECT_TRUE(Literal::CreateR0<bool>(true)->IsAll(1));
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(1));
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAll(2));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(0));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(2));
- EXPECT_FALSE(Literal::CreateR0<bool>(true)->IsAll(-1));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(false)->IsAll(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<bool>(true)->IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(1));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(2));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(true)->IsAll(-1));
// We shouldn't reinterpret int8_min as an unsigned type and then decide that
// it is equal to 255.
auto int8_min = std::numeric_limits<int8>::min();
- EXPECT_FALSE(Literal::CreateR0<uint8>(255)->IsAll(int8_min));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(255)->IsAll(int8_min));
- EXPECT_TRUE(Literal::CreateR0<float>(42.0)->IsAll(42));
- EXPECT_FALSE(Literal::CreateR0<float>(42.0001)->IsAll(42));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(42.0)->IsAll(42));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(42.0001)->IsAll(42));
- EXPECT_TRUE(Literal::CreateR1<int>({100, 100, 100})->IsAll(100));
- EXPECT_FALSE(Literal::CreateR1<double>({100, 100, 100.001})->IsAll(100));
+ EXPECT_TRUE(LiteralUtil::CreateR1<int>({100, 100, 100})->IsAll(100));
+ EXPECT_FALSE(LiteralUtil::CreateR1<double>({100, 100, 100.001})->IsAll(100));
- EXPECT_TRUE(Literal::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}})->IsAll(8));
half h8(8.0f);
half h9(9.0f);
- EXPECT_TRUE(Literal::CreateR2<half>({{h8}, {h8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<half>({{h8}, {h8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h8}, {h9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<half>({{h9}, {h8}})->IsAll(8));
bfloat16 b8(8.0f);
bfloat16 b9(9.0f);
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
- EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
// 9.001 will be truncated to 9.0
bfloat16 b91(9.001f);
bfloat16 b90(9.00f);
- EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
complex64 c8_9 = {8, 9};
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
auto uint64_max = std::numeric_limits<uint64>::max();
- EXPECT_FALSE(Literal::CreateR2<uint64>(
+ EXPECT_FALSE(LiteralUtil::CreateR2<uint64>(
{{uint64_max, uint64_max}, {uint64_max, uint64_max}})
->IsAll(-1));
}
TEST_F(LiteralUtilTest, IsAllFloat) {
// IsAllFloat always returns false when the literal is not floating-point.
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllFloat(0));
- EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllFloat(0));
-
- EXPECT_TRUE(Literal::CreateR0<float>(0)->IsAllFloat(0));
- EXPECT_TRUE(Literal::CreateR0<float>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(Literal::CreateR0<float>(-.5)->IsAllFloat(-.49));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllFloat(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllFloat(0));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(0)->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(.5)->IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(-.5)->IsAllFloat(-.49));
EXPECT_FALSE(
- Literal::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
- EXPECT_TRUE(
- Literal::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5));
-
- EXPECT_TRUE(Literal::CreateR0<double>(0)->IsAllFloat(0));
- EXPECT_TRUE(Literal::CreateR0<double>(.5)->IsAllFloat(.5));
- EXPECT_TRUE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.5));
- EXPECT_FALSE(Literal::CreateR0<double>(-.5)->IsAllFloat(-.49));
+ LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR2<float>({{.5, .5, .5}, {.5, .5, .5}})
+ ->IsAllFloat(.5));
+
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(0)->IsAllFloat(0));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(.5)->IsAllFloat(.5));
+ EXPECT_TRUE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.5));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(-.5)->IsAllFloat(-.49));
EXPECT_FALSE(
- Literal::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
+ LiteralUtil::CreateR2<double>({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0));
}
TEST_F(LiteralUtilTest, IsAllComplex) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(Literal::CreateR0<bool>(false)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<int8>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<uint8>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<int>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<float>(0)->IsAllComplex(0));
- EXPECT_FALSE(Literal::CreateR0<double>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<bool>(false)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<uint8>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<int>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<float>(0)->IsAllComplex(0));
+ EXPECT_FALSE(LiteralUtil::CreateR0<double>(0)->IsAllComplex(0));
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})
->IsAllComplex({8.0f, 9.0f}));
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c7_9}})
+ EXPECT_FALSE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c7_9}})
->IsAllComplex({8.0f, 9.0f}));
}
TEST_F(LiteralUtilTest, IsAllFirst) {
// IsAllComplex always returns false when the literal is not complex.
- EXPECT_FALSE(Literal::CreateR1<bool>({false, true})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<bool>({false, false})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<int8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<int32>({1, 1, 2})->IsAllFirst());
- EXPECT_TRUE(Literal::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<bool>({false, true})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<bool>({false, false})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int8>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int8>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint8>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<int32>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<int32>({1, 1, 2})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR1<uint32>({5, 5, 5, 5})->IsAllFirst());
+ EXPECT_FALSE(LiteralUtil::CreateR1<uint32>({1, 1, 2})->IsAllFirst());
complex64 c8_9 = {8, 9};
complex64 c7_9 = {7, 9};
- EXPECT_TRUE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
- EXPECT_FALSE(Literal::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
+ EXPECT_TRUE(LiteralUtil::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAllFirst());
+ EXPECT_FALSE(
+ LiteralUtil::CreateR2<complex64>({{c7_9}, {c8_9}})->IsAllFirst());
}
TEST_F(LiteralUtilTest, IsZero) {
- auto scalar_zero = Literal::CreateR0<float>(0.0f);
- auto scalar_one = Literal::CreateR0<float>(1.0f);
+ auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
+ auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
EXPECT_TRUE(scalar_zero->IsZero({}));
EXPECT_FALSE(scalar_one->IsZero({}));
- auto array = Literal::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
+ auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
EXPECT_FALSE(array->IsZero({0, 1}));
EXPECT_TRUE(array->IsZero({0, 2}));
EXPECT_TRUE(array->IsZero({1, 1}));
EXPECT_FALSE(array->IsZero({1, 2}));
- auto complex_zero = Literal::CreateR0<complex64>(0.0f);
- auto complex_nonzero = Literal::CreateR0<complex64>(0.5f);
+ auto complex_zero = LiteralUtil::CreateR0<complex64>(0.0f);
+ auto complex_nonzero = LiteralUtil::CreateR0<complex64>(0.5f);
EXPECT_TRUE(complex_zero->IsZero({}));
EXPECT_FALSE(complex_nonzero->IsZero({}));
}
@@ -547,7 +570,7 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
// Make a non-integer for floating point types.
TypeParam half = TypeParam(1) / TypeParam(2);
- auto data = Literal::CreateR2<TypeParam>({{half, 2}, {3, 4}});
+ auto data = LiteralUtil::CreateR2<TypeParam>({{half, 2}, {3, 4}});
const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
@@ -561,7 +584,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
}
TEST_F(LiteralUtilTest, ReshapeR0) {
- auto original = Literal::CreateR0<float>(1.7f);
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie();
EXPECT_EQ(*original, *reshape);
}
@@ -569,13 +592,13 @@ TEST_F(LiteralUtilTest, ReshapeR0) {
TEST_F(LiteralUtilTest, ReshapeR4) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4WithLayout<float>({{
+ auto original = LiteralUtil::CreateR4WithLayout<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
// F32[1x3x4x2]
- auto expected = Literal::CreateR3WithLayout<float>({
+ auto expected = LiteralUtil::CreateR3WithLayout<float>({
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
@@ -589,13 +612,13 @@ TEST_F(LiteralUtilTest, ReshapeR4) {
TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4WithLayout<float>({{
+ auto original = LiteralUtil::CreateR4WithLayout<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0minor_);
// F32[1x3x4x2]
- auto expected = Literal::CreateR3WithLayout<float>({
+ auto expected = LiteralUtil::CreateR3WithLayout<float>({
{{10, 11}, {12, 13}, {14, 15}, {16, 17}},
{{18, 19}, {20, 21}, {22, 23}, {24, 25}},
{{26, 27}, {28, 29}, {30, 31}, {32, 33}},
@@ -607,7 +630,7 @@ TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) {
}
TEST_F(LiteralUtilTest, TransposeR0) {
- auto original = Literal::CreateR0<float>(1.7f);
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
auto reshape = original->Transpose(/*permutation=*/{});
EXPECT_EQ(*original, *reshape);
}
@@ -615,7 +638,7 @@ TEST_F(LiteralUtilTest, TransposeR0) {
TEST_F(LiteralUtilTest, TransposeR4) {
// clang-format off
// F32[1x3x2x4]
- auto original = Literal::CreateR4<float>({{
+ auto original = LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -643,7 +666,7 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
- auto mat_dim0minor = Literal::CreateR2WithLayout<int32>(
+ auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
EXPECT_EQ(mat_dim0minor->element_count(), 6);
EXPECT_THAT(mat_dim0minor->data<int32>(), ElementsAre(1, 4, 2, 5, 3, 6));
@@ -654,7 +677,7 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) {
ElementsAre(1, 2, 3, 4, 5, 6));
// Test expected memory layout of R2 created with dim0-major (row-major).
- auto mat_dim0major = Literal::CreateR2WithLayout<int32>(
+ auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
EXPECT_EQ(mat_dim0major->element_count(), 6);
EXPECT_THAT(mat_dim0major->data<int32>(), ElementsAre(1, 2, 3, 4, 5, 6));
@@ -679,8 +702,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
{10, 11, 12},
},
}); // clang-format on
- auto lit_dim0minor =
- Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0minor_);
+ auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0minor_);
EXPECT_EQ(lit_dim0minor->element_count(), 12);
std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
@@ -694,8 +717,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
testing::ElementsAreArray(expected_dim0major));
// Test expected memory layout of R3 created with dim0-major (row-major).
- auto lit_dim0major =
- Literal::CreateR3FromArray3DWithLayout<int>(arr3d, layout_r3_dim0major_);
+ auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0major_);
EXPECT_EQ(lit_dim0major->element_count(), 12);
EXPECT_THAT(lit_dim0major->data<int32>(),
testing::ElementsAreArray(expected_dim0major));
@@ -707,28 +730,28 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) {
}
TEST_F(LiteralUtilTest, SliceR0S32) {
- auto input = Literal::CreateR0<int32>(1);
+ auto input = LiteralUtil::CreateR0<int32>(1);
auto result = input->Slice({}, {});
EXPECT_EQ(*input, *result);
}
TEST_F(LiteralUtilTest, SliceR1F32) {
- auto input = Literal::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
+ auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
auto result = input->Slice({3}, {4});
- auto expected = Literal::CreateR1<float>({4.0});
+ auto expected = LiteralUtil::CreateR1<float>({4.0});
EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR2U32) {
- auto input_3x4 =
- Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
+ auto input_3x4 = LiteralUtil::CreateR2<uint32>(
+ {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto result = input_3x4->Slice({0, 2}, {2, 4});
- auto expected = Literal::CreateR2<uint32>({{3, 4}, {7, 8}});
+ auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
EXPECT_EQ(*expected, *result);
}
TEST_F(LiteralUtilTest, SliceR3U32Full) {
- auto input_2x3x2 = Literal::CreateR3<uint32>(
+ auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2});
EXPECT_EQ(*input_2x3x2, *result);
@@ -737,21 +760,21 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) {
TEST_F(LiteralUtilTest, PopulateR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {1}));
output.PopulateR1<int64>({77});
- auto expected = Literal::CreateR1<int64>({77});
+ auto expected = LiteralUtil::CreateR1<int64>({77});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1U64) {
Literal output(ShapeUtil::MakeShape(U64, {2}));
output.PopulateR1<uint64>({{77, 88}});
- auto expected = Literal::CreateR1<uint64>({{77, 88}});
+ auto expected = LiteralUtil::CreateR1<uint64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateR1C64) {
Literal output(ShapeUtil::MakeShape(C64, {1}));
output.PopulateR1<complex64>({{77, 88}});
- auto expected = Literal::CreateR1<complex64>({{77, 88}});
+ auto expected = LiteralUtil::CreateR1<complex64>({{77, 88}});
EXPECT_EQ(output, *expected);
}
@@ -759,7 +782,7 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
auto expected =
- Literal::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
+ LiteralUtil::CreateR2<complex64>({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}});
EXPECT_EQ(output, *expected);
}
@@ -767,7 +790,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {}));
bfloat16 h(0.25f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR0<bfloat16>(h);
+ auto expected = LiteralUtil::CreateR0<bfloat16>(h);
EXPECT_EQ(output, *expected);
}
@@ -775,7 +798,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {3}));
bfloat16 h(0.5f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR1<bfloat16>({h, h, h});
+ auto expected = LiteralUtil::CreateR1<bfloat16>({h, h, h});
EXPECT_EQ(output, *expected);
}
@@ -783,28 +806,28 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
Literal output(ShapeUtil::MakeShape(BF16, {2, 2}));
bfloat16 h(2.0f);
output.PopulateWithValue<bfloat16>(h);
- auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
+ auto expected = LiteralUtil::CreateR2<bfloat16>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
Literal output(ShapeUtil::MakeShape(F32, {}));
output.PopulateWithValue<float>(2.5f);
- auto expected = Literal::CreateR0<float>(2.5f);
+ auto expected = LiteralUtil::CreateR0<float>(2.5f);
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
Literal output(ShapeUtil::MakeShape(S64, {3}));
output.PopulateWithValue<int64>(-7);
- auto expected = Literal::CreateR1<int64>({-7, -7, -7});
+ auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
Literal output(ShapeUtil::MakeShape(U64, {2, 2}));
output.PopulateWithValue<uint64>(42);
- auto expected = Literal::CreateR2<uint64>({{42, 42}, {42, 42}});
+ auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
EXPECT_EQ(output, *expected);
}
@@ -812,7 +835,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C64) {
Literal output(ShapeUtil::MakeShape(C64, {2, 2}));
output.PopulateWithValue<complex64>({4, 2});
auto expected =
- Literal::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
+ LiteralUtil::CreateR2<complex64>({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}});
EXPECT_EQ(output, *expected);
}
@@ -820,7 +843,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR0F16) {
Literal output(ShapeUtil::MakeShape(F16, {}));
half h(0.25f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR0<half>(h);
+ auto expected = LiteralUtil::CreateR0<half>(h);
EXPECT_EQ(output, *expected);
}
@@ -828,7 +851,7 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F16) {
Literal output(ShapeUtil::MakeShape(F16, {3}));
half h(0.5f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR1<half>({h, h, h});
+ auto expected = LiteralUtil::CreateR1<half>({h, h, h});
EXPECT_EQ(output, *expected);
}
@@ -836,15 +859,15 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2F16) {
Literal output(ShapeUtil::MakeShape(F16, {2, 2}));
half h(2.0f);
output.PopulateWithValue<half>(h);
- auto expected = Literal::CreateR2<half>({{h, h}, {h, h}});
+ auto expected = LiteralUtil::CreateR2<half>({{h, h}, {h, h}});
EXPECT_EQ(output, *expected);
}
TEST_F(LiteralUtilTest, ReplicateR2U32) {
- auto input =
- Literal::CreateR2<uint32>({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
+ auto input = LiteralUtil::CreateR2<uint32>(
+ {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
auto output = input->Replicate<uint32>(3);
- auto expected = Literal::CreateR3<uint32>(
+ auto expected = LiteralUtil::CreateR3<uint32>(
{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}});
@@ -898,12 +921,12 @@ TEST_F(LiteralUtilTest, CopySliceFrom) {
}
TEST_F(LiteralUtilTest, CopyFromScalars) {
- auto zero = Literal::CreateR0<uint32>(0);
- auto nine = Literal::CreateR0<uint32>(9);
+ auto zero = LiteralUtil::CreateR0<uint32>(0);
+ auto nine = LiteralUtil::CreateR0<uint32>(9);
TF_EXPECT_OK(zero->CopyFrom(*nine));
EXPECT_EQ(*zero, *nine);
- auto vect = Literal::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
+ auto vect = LiteralUtil::CreateR1<uint32>({3, 4, 9, 12, 5, 17, 21});
TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {}));
EXPECT_EQ(zero->Get<uint32>({}), 17);
TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {}));
@@ -912,13 +935,13 @@ TEST_F(LiteralUtilTest, CopyFromScalars) {
TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
const Shape empty_r1_shape = ShapeUtil::MakeShape(F32, {0});
- const auto const_nine = Literal::CreateR1<float>({9});
+ const auto const_nine = LiteralUtil::CreateR1<float>({9});
const auto const_empty = Literal::CreateFromShape(empty_r1_shape);
{
// Source contains dimension with zero elements.
const auto empty = Literal::CreateFromShape(empty_r1_shape);
- auto nine = Literal::CreateR1<float>({9});
+ auto nine = LiteralUtil::CreateR1<float>({9});
TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0}));
EXPECT_EQ(*nine, *const_nine);
@@ -927,7 +950,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) {
{
// Copy 0 element to destination with zero elements.
const auto empty = Literal::CreateFromShape(empty_r1_shape);
- auto nine = Literal::CreateR1<float>({9});
+ auto nine = LiteralUtil::CreateR1<float>({9});
TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0}));
EXPECT_EQ(*empty, *const_empty);
@@ -942,16 +965,16 @@ TEST_F(LiteralUtilTest, CopyFromNilShape) {
}
TEST_F(LiteralUtilTest, CopyFromArrays) {
- auto scalar_42 = Literal::CreateR0<float>(42.0);
- auto scalar_123 = Literal::CreateR0<float>(123.0);
+ auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto scalar_123 = LiteralUtil::CreateR0<float>(123.0);
EXPECT_NE(*scalar_42, *scalar_123);
TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{},
/*src_shape_index=*/{}));
EXPECT_EQ(*scalar_42, *scalar_123);
EXPECT_EQ(scalar_42->Get<float>({}), 123.0f);
- auto matrix_1234 = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto matrix_5678 = Literal::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+ auto matrix_1234 = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_5678 = LiteralUtil::CreateR2<float>({{5.0, 6.0}, {7.0, 8.0}});
EXPECT_NE(*matrix_1234, *matrix_5678);
EXPECT_EQ(matrix_1234->Get<float>({0, 0}), 1.0f);
TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{},
@@ -961,19 +984,19 @@ TEST_F(LiteralUtilTest, CopyFromArrays) {
}
TEST_F(LiteralUtilTest, CopyFromTuples) {
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
+ auto nested_tuple = LiteralUtil::MakeTuple(
{matrix.get(),
- Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get(),
- &nil_literal})
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
.get()});
// Create a tuple the same shape as the inner tuple of nested_tuple but with
// different values..
- auto tuple = Literal::MakeTuple({Literal::CreateR0<int32>(-5).get(),
- Literal::CreateR1<double>({2.0, 4.0}).get(),
- &nil_literal});
+ auto tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(-5).get(),
+ LiteralUtil::CreateR1<double>({2.0, 4.0}).get(), &nil_literal});
EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0}));
EXPECT_EQ(nested_tuple->Get<int32>({}, {1, 0}), 42);
@@ -994,8 +1017,8 @@ TEST_F(LiteralUtilTest, CopyFromTuples) {
EXPECT_EQ(nested_tuple->Get<double>({1}, {1, 1}), 4.0);
}
TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
- auto tuple = Literal::MakeTuple(
- {Literal::CreateR0<int32>(-2).get(), Literal::CreateR0<int32>(4).get()});
+ auto tuple = LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int32>(-2).get(),
+ LiteralUtil::CreateR0<int32>(4).get()});
EXPECT_EQ(tuple->Get<int32>({}, {0}), -2);
EXPECT_EQ(tuple->Get<int32>({}, {1}), 4);
@@ -1009,8 +1032,8 @@ TEST_F(LiteralUtilTest, CopyBetweenSameTuple) {
}
TEST_F(LiteralUtilTest, CopyFromDifferentShapes) {
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto vector = Literal::CreateR1<float>({5.0, 7.0});
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector = LiteralUtil::CreateR1<float>({5.0, 7.0});
Status status = matrix->CopyFrom(*vector);
ASSERT_FALSE(status.ok());
ASSERT_THAT(status.error_message(),
@@ -1035,7 +1058,7 @@ TEST_F(LiteralUtilTest, F16) {
half h1(1.0f);
half h2(2.0f);
- auto m2 = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
+ auto m2 = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l2 = m2.get();
const char* d2 = reinterpret_cast<const char*>(l2->data<half>().data());
EXPECT_EQ(d2[0], 0);
@@ -1134,12 +1157,12 @@ TEST_F(LiteralUtilTest, PopulateParallel) {
TEST_F(LiteralUtilTest, ConvertR4) {
// clang-format off
- auto original = Literal::CreateR4WithLayout<int8>({{
+ auto original = LiteralUtil::CreateR4WithLayout<int8>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
}}, layout_r4_dim0major_);
- auto expected = Literal::CreateR4WithLayout<uint32>({{
+ auto expected = LiteralUtil::CreateR4WithLayout<uint32>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -1153,42 +1176,42 @@ TEST_F(LiteralUtilTest, ConvertR4) {
TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
// clang-format off
- auto s8 = Literal::CreateR4WithLayout<int8>({{
+ auto s8 = LiteralUtil::CreateR4WithLayout<int8>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto s32 = Literal::CreateR4WithLayout<int32>({{
+ auto s32 = LiteralUtil::CreateR4WithLayout<int32>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto u32 = Literal::CreateR4WithLayout<uint32>({{
+ auto u32 = LiteralUtil::CreateR4WithLayout<uint32>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto s64 = Literal::CreateR4WithLayout<int64>({{
+ auto s64 = LiteralUtil::CreateR4WithLayout<int64>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto u64 = Literal::CreateR4WithLayout<uint64>({{
+ auto u64 = LiteralUtil::CreateR4WithLayout<uint64>({{
{{10, 0, 12, 0}, {0, 15, 0, 17}},
{{0, 19, 0, 21}, {22, 0, 24, 0}},
{{26, 0, 28, 0}, {0, 31, 0, 33}},
}}, layout_r4_dim0major_);
- auto pred = Literal::CreateR4WithLayout<bool>({{
+ auto pred = LiteralUtil::CreateR4WithLayout<bool>({{
{{true, false, true, false}, {false, true, false, true}},
{{false, true, false, true}, {true, false, true, false}},
{{true, false, true, false}, {false, true, false, true}},
}}, layout_r4_dim0major_);
- auto int32_pred = Literal::CreateR4WithLayout<int32>({{
+ auto int32_pred = LiteralUtil::CreateR4WithLayout<int32>({{
{{1, 0, 1, 0}, {0, 1, 0, 1}},
{{0, 1, 0, 1}, {1, 0, 1, 0}},
{{1, 0, 1, 0}, {0, 1, 0, 1}},
}}, layout_r4_dim0major_);
- auto f16 = Literal::CreateR4WithLayout<half>({{
+ auto f16 = LiteralUtil::CreateR4WithLayout<half>({{
{{half(10.0), half(0.0), half(12.0), half(0.0)},
{half(0.0), half(15.0), half(0.0), half(17.0)}},
{{half(0.0), half(19.0), half(0.0), half(21.0)},
@@ -1196,7 +1219,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{half(26.0), half(0.0), half(28.0), half(0.0)},
{half(0.0), half(31.0), half(0.0), half(33.0)}},
}}, layout_r4_dim0major_);
- auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
+ auto bf16 = LiteralUtil::CreateR4WithLayout<bfloat16>({{
{{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
{{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
@@ -1204,17 +1227,17 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
{{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
{bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
}}, layout_r4_dim0major_);
- auto f32 = Literal::CreateR4WithLayout<float>({{
+ auto f32 = LiteralUtil::CreateR4WithLayout<float>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
}}, layout_r4_dim0major_);
- auto f64 = Literal::CreateR4WithLayout<double>({{
+ auto f64 = LiteralUtil::CreateR4WithLayout<double>({{
{{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}},
{{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}},
{{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}},
}}, layout_r4_dim0major_);
- auto c64 = Literal::CreateR4WithLayout<complex64>({{
+ auto c64 = LiteralUtil::CreateR4WithLayout<complex64>({{
{{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
{{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
{{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}},
@@ -1286,18 +1309,18 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
}
TEST_F(LiteralUtilTest, BitcastConvert) {
- auto original =
- Literal::CreateR1<uint32>({tensorflow::bit_cast<uint32>(2.5f),
- tensorflow::bit_cast<uint32>(-42.25f),
- tensorflow::bit_cast<uint32>(100.f), 0xbeef});
- auto expected = Literal::CreateR1<float>(
+ auto original = LiteralUtil::CreateR1<uint32>(
+ {tensorflow::bit_cast<uint32>(2.5f),
+ tensorflow::bit_cast<uint32>(-42.25f),
+ tensorflow::bit_cast<uint32>(100.f), 0xbeef});
+ auto expected = LiteralUtil::CreateR1<float>(
{2.5f, -42.25f, 100.0f, tensorflow::bit_cast<float>(0xbeef)});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> converted,
original->BitcastConvert(F32));
}
TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
- auto literal = Literal::CreateR0<uint32>(1234);
+ auto literal = LiteralUtil::CreateR0<uint32>(1234);
Status status = literal->BitcastConvert(F64).status();
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(tensorflow::str_util::StrContains(status.error_message(),
@@ -1332,7 +1355,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
half h1(1.0f);
half h2(2.0f);
- auto m = Literal::CreateR2<half>({{h1, h2}, {h2, h1}});
+ auto m = LiteralUtil::CreateR2<half>({{h1, h2}, {h2, h1}});
Literal* l = m.get();
EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape()));
EXPECT_EQ(4, l->data<half>().size());
@@ -1375,10 +1398,10 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
}
TEST_F(LiteralUtilTest, LiteralSliceTest) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
Literal nil(ShapeUtil::MakeNil());
EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar);
@@ -1397,10 +1420,10 @@ TEST_F(LiteralUtilTest, LiteralSliceTest) {
}
TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
// Verify that changing the underlying data beneath the view changes the
// data of the view itself.
const auto nested_tuple_view = LiteralSlice(*nested_tuple);
@@ -1420,18 +1443,19 @@ TEST_F(LiteralUtilTest, MutatingLiteralSlice) {
}
TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) {
- auto scalar = Literal::CreateR0<float>(1.0);
- auto matrix = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
- auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()});
- auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ auto nested_tuple = LiteralUtil::MakeTuple({tuple.get(), scalar.get()});
const auto nested_tuple_view = LiteralSlice(*nested_tuple);
const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0});
const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1});
- EXPECT_EQ(matrix_view, *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ EXPECT_EQ(matrix_view,
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) {
std::vector<int64> int64_values = {1, 2, 3};
const Shape literal_shape = ShapeUtil::MakeShape(S64, {3});
@@ -1443,7 +1467,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) {
EXPECT_EQ(literal.Get<int64>({2}), 3);
}
-TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) {
+TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) {
std::vector<int64> one_two_three = {1, 2, 3};
const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3});
@@ -1472,7 +1496,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) {
TEST_F(LiteralUtilTest, LiteralMove) {
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
Literal literal(std::move(*matrix));
EXPECT_TRUE(
@@ -1485,11 +1509,11 @@ TEST_F(LiteralUtilTest, LiteralMove) {
TEST_F(LiteralUtilTest, DecomposeTuple) {
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
- {Literal::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
- Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get(),
- &nil_literal})
+ auto nested_tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}}).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get(), &nil_literal})
.get(),
&nil_literal});
@@ -1526,13 +1550,13 @@ TEST_F(LiteralUtilTest, DecomposeEmptyTuple) {
TEST_F(LiteralUtilTest, MoveIntoTuple) {
std::vector<Literal> elements;
- elements.push_back(std::move(*Literal::CreateR0<float>(1.0)));
- elements.push_back(std::move(*Literal::CreateR1<int32>({4, 8})));
- elements.push_back(std::move(
- *Literal::MakeTuple({Literal::CreateR0<int32>(42).get(),
- Literal::CreateR1<double>({23.0, 44.0}).get()})
+ elements.push_back(std::move(*LiteralUtil::CreateR0<float>(1.0)));
+ elements.push_back(std::move(*LiteralUtil::CreateR1<int32>({4, 8})));
+ elements.push_back(std::move(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR1<double>({23.0, 44.0}).get()})
- ));
+ ));
Literal literal = Literal::MoveIntoTuple(&elements);
ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape()));
@@ -1561,7 +1585,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape()));
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
literal = std::move(*matrix);
EXPECT_TRUE(
@@ -1574,7 +1598,7 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) {
TEST_F(LiteralUtilTest, LiteralSliceCopy) {
std::unique_ptr<Literal> matrix =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
const auto matrix_view = LiteralSlice(*matrix);
LiteralSlice matrix_view_copy(matrix_view);
@@ -1585,9 +1609,9 @@ TEST_F(LiteralUtilTest, LiteralSliceCopy) {
}
TEST_F(LiteralUtilTest, GetSetTuple) {
- auto tuple = Literal::MakeTuple(
- {Literal::CreateR0<float>(42.0).get(),
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
+ auto tuple = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(42.0).get(),
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get()});
EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0);
tuple->Set<float>(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0);
EXPECT_EQ(tuple->Get<float>(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0);
@@ -1628,20 +1652,20 @@ TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) {
TEST_F(LiteralUtilTest, ProtoRoundTrip) {
// Test serializing then deserializing a Literal through a proto.
- auto one_f32 = Literal::CreateR0<float>(1.0);
- auto two_f32 = Literal::CreateR0<float>(2.0);
- auto vector_int8 = Literal::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
- auto vector_c64 = Literal::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
- auto vector_bfloat16 = Literal::CreateR1<bfloat16>(
+ auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
+ auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
+ auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
+ auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
auto vector_half =
- Literal::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
+ LiteralUtil::CreateR1<half>({half{10.0}, half{20.0}, half{-30.0}});
auto matrix_pred =
- Literal::CreateR2<bool>({{true, false, true}, {false, false, true}});
- auto tuple = Literal::MakeTuple(
+ LiteralUtil::CreateR2<bool>({{true, false, true}, {false, false, true}});
+ auto tuple = LiteralUtil::MakeTuple(
{one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()});
Literal nil_literal(ShapeUtil::MakeNil());
- auto nested_tuple = Literal::MakeTuple(
+ auto nested_tuple = LiteralUtil::MakeTuple(
{tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal});
auto to_from_proto = [](const Literal& literal) -> Literal {
@@ -1774,8 +1798,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
}
TEST_F(LiteralUtilTest, SortSparseElements) {
- auto literal =
- Literal::CreateSparse<float>({10, 10, 10}, SparseIndexArray(10, 3), {});
+ auto literal = LiteralUtil::CreateSparse<float>({10, 10, 10},
+ SparseIndexArray(10, 3), {});
literal->AppendSparseElement<float>({2, 3, 4}, 2.0);
literal->AppendSparseElement<float>({3, 4, 5}, 3.0);
literal->AppendSparseElement<float>({1, 2, 3}, 1.0);
@@ -1789,21 +1813,22 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}});
ASSERT_EQ(
- Literal::CreateSparse<bool>(dimensions, indices, {true, false, true})
+ LiteralUtil::CreateSparse<bool>(dimensions, indices, {true, false, true})
->GetSparseElementAsString(1),
"false");
- ASSERT_EQ(Literal::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
+ ASSERT_EQ(LiteralUtil::CreateSparse<int64>(dimensions, indices, {1, 2, 3})
->GetSparseElementAsString(1),
tensorflow::strings::StrCat(int64{2}));
- ASSERT_EQ(Literal::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
- ->GetSparseElementAsString(1),
- tensorflow::strings::StrCat(double{2.0}));
- ASSERT_EQ(Literal::CreateSparse<half>(dimensions, indices,
- {half{1.0}, half{2.0}, half{3.0}})
+ ASSERT_EQ(
+ LiteralUtil::CreateSparse<double>(dimensions, indices, {1.0, 2.0, 3.0})
+ ->GetSparseElementAsString(1),
+ tensorflow::strings::StrCat(double{2.0}));
+ ASSERT_EQ(LiteralUtil::CreateSparse<half>(dimensions, indices,
+ {half{1.0}, half{2.0}, half{3.0}})
->GetSparseElementAsString(1),
tensorflow::strings::StrCat(static_cast<float>(half{2.0})));
ASSERT_EQ(
- Literal::CreateSparse<complex64>(
+ LiteralUtil::CreateSparse<complex64>(
dimensions, indices,
std::vector<complex64>{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})
->GetSparseElementAsString(1),
@@ -1811,33 +1836,36 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) {
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
/*dimensions=*/{0}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 1}, {2, 2}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int64>({{1, 1}, {2, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int64>({1, 2});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int64>({1, 2});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
/*dimensions=*/{1}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int64>({{1, 2}, {1, 2}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int64>({{1, 2}, {1, 2}}));
}
TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
- std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(9);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(9);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> broadcasted_literal,
literal->Broadcast(
/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
/*dimensions=*/{}));
- EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2<int32>({{9, 9}, {9, 9}}));
+ EXPECT_EQ(*broadcasted_literal,
+ *LiteralUtil::CreateR2<int32>({{9, 9}, {9, 9}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 7563cc1e34..548fbe8a83 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -43,25 +43,6 @@ namespace xla {
namespace {
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
-// Converts between little and big endian.
-//
-// Precondition: size % 2 == 0 (elements in the array are 16 bits long)
-void ConvertEndianShort(string* bytes) {
- CHECK_EQ(bytes->size() / 2, 0);
- for (int64 i = 0; i < bytes->size(); i += 2) {
- std::swap((*bytes)[i], (*bytes)[i + 1]);
- }
-}
-
-void ConvertEndianShort(char* bytes, int64 size) {
- CHECK_EQ(size / 2, 0);
- for (int64 i = 0; i < size; i += 2) {
- std::swap(bytes[i], bytes[i + 1]);
- }
-}
-
// Return a literal with all arrays of type FromNativeT converted to type
// ToNativeT in the given literal.
template <typename FromNativeT, typename ToNativeT>
@@ -103,498 +84,54 @@ std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
} // namespace
-LiteralBase::~LiteralBase() {}
-
-std::ostream& operator<<(std::ostream& out, const Literal& literal) {
- out << literal.ToString();
- return out;
-}
-
-Literal::StrideConfig::StrideConfig(
- const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions)
- : dimensions(dimensions),
- base(dimensions.size(), 0),
- step(dimensions.size(), 1) {
- if (!dimensions.empty()) {
- // Selects the shape with the largest minor dimension as the one upon
- // which to run the tight stride loop.
- if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >=
- dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) {
- minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0);
- dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension);
- } else {
- minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0);
- source_stride =
- IndexUtil::GetDimensionStride(source_shape, minor_dimension);
- }
- minor_loop_size = dimensions[minor_dimension];
- step[minor_dimension] = minor_loop_size;
- }
-}
-
-Literal::Literal(const Shape& shape)
- : Literal(shape, /*allocate_arrays=*/true) {}
-
-void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) {
- if (ShapeUtil::IsTuple(shape)) {
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& subshape = shape.tuple_shapes(i);
-
- auto child_piece = Piece();
- child_piece.set_subshape(&subshape);
-
- SetPiece(subshape, &child_piece, allocate_arrays);
-
- piece->emplace_back(std::move(child_piece));
- }
- } else {
- CHECK(ShapeUtil::IsArray(shape));
- if (allocate_arrays) {
- if (LayoutUtil::IsSparseArray(shape)) {
- // For sparse arrays, the buffer must be of the size of the maximum
- // number of sparse elements possible.
- const int64 max_sparse_elements =
- LayoutUtil::MaxSparseElements(shape.layout());
- piece->set_buffer(
- new char[max_sparse_elements *
- ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]);
- piece->set_sparse_indices(
- new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape)));
- } else {
- piece->set_buffer(new char[piece->size_bytes()]);
- }
- }
- }
-}
-
-Literal::Literal(const Shape& shape, bool allocate_arrays)
- : LiteralBase(), shape_(MakeUnique<Shape>(shape)) {
- CHECK(LayoutUtil::HasLayout(*shape_));
- root_piece_ = new Piece();
- root_piece_->set_subshape(shape_.get());
- CHECK(&root_piece_->subshape() == shape_.get());
-
- SetPiece(*shape_, root_piece_, allocate_arrays);
-}
-
-Literal::~Literal() {
- if (root_piece_ != nullptr) {
- DeallocateBuffers();
- delete root_piece_;
- }
-}
-
-void Literal::DeallocateBuffers() {
- root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* piece) {
- if (piece->buffer() != nullptr) {
- delete[] piece->buffer();
- delete piece->sparse_indices();
- }
- });
-}
-
-Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); }
-
-Literal& Literal::operator=(Literal&& other) {
- DCHECK(&other.root_piece_->subshape() == other.shape_.get());
- using std::swap;
- swap(shape_, other.shape_);
- swap(root_piece_, other.root_piece_);
- DCHECK(&root_piece_->subshape() == shape_.get());
-
- return *this;
-}
-
-std::unique_ptr<Literal> LiteralBase::CreateFromShape(const Shape& shape) {
- auto literal = MakeUnique<Literal>(shape);
- literal->root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* piece) {
- if (ShapeUtil::IsArray(piece->subshape())) {
- memset(piece->untyped_data(), 0, piece->size_bytes());
- }
- });
- return literal;
-}
-
-const SparseIndexArray* LiteralBase::sparse_indices(
- const ShapeIndex& shape_index) const {
- return piece(shape_index).sparse_indices();
-}
-
-SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
- return piece(shape_index).sparse_indices();
-}
-
-/* static */ std::unique_ptr<Literal> Literal::CreateFromDimensions(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromDimensions(
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
+ return Literal::CreateFromShape(
+ ShapeUtil::MakeShape(primitive_type, dimensions));
}
-/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertBF16ToF32(
const LiteralSlice& bf16_literal) {
return ConvertType<bfloat16, float>(bf16_literal);
}
-/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ConvertF32ToBF16(
const LiteralSlice& f32_literal) {
return ConvertType<float, bfloat16>(f32_literal);
}
-template <typename NativeT>
-Status Literal::CopySliceFromInternal(
- const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
- TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size());
- TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size());
-
- auto linear_index = [](const Shape& shape,
- tensorflow::gtl::ArraySlice<int64> multi_index) {
- return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index);
- };
-
- if (ShapeUtil::Rank(src_literal.shape()) == 0 ||
- ShapeUtil::Rank(shape()) == 0) {
- // If any of the two shapes are scalars, we can just call the StridedCopy()
- // directly, and we know we will be copying only one value.
- TF_RET_CHECK(copy_size.empty());
- StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
- src_literal.data<NativeT>(),
- linear_index(src_literal.shape(), src_base), 0, 1);
- } else if (!ShapeUtil::HasZeroElements(shape()) &&
- !ShapeUtil::HasZeroElements(src_literal.shape())) {
- // Perform copy if neither src nor dest has dimensions with zero element,
- // otherwise it's a no-op.
- TF_RET_CHECK(src_base.size() == dest_base.size());
- TF_RET_CHECK(src_base.size() == copy_size.size());
-
- // Scan the source from minor, stepping in copy size blocks, then within
- // the index enumaration functor, do a strided copy advancing source index
- // by one (walking through the minor dimension), and destination index by
- // proper stride size at the matching dimension.
- DimensionVector src_indexes(src_base.size(), 0);
- DimensionVector dest_indexes(dest_base.size(), 0);
- Literal::StrideConfig stride_config(src_literal.shape(), shape(),
- copy_size);
-
- auto copy_proc = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- // Map from multi-dimensional index, to source index.
- std::transform(indexes.begin(), indexes.end(), src_base.begin(),
- src_indexes.begin(), std::plus<int64>());
- // Map from multi-dimensional index, to destination index.
- std::transform(indexes.begin(), indexes.end(), dest_base.begin(),
- dest_indexes.begin(), std::plus<int64>());
-
- int64 src_index = linear_index(src_literal.shape(), src_indexes);
- int64 dest_index = linear_index(shape(), dest_indexes);
-
- // `this->` is needed to workaround MSVC bug: #16882
- StridedCopy(this->data<NativeT>(), dest_index, stride_config.dest_stride,
- src_literal.data<NativeT>(), src_index,
- stride_config.source_stride, stride_config.minor_loop_size);
- return true;
- };
-
- ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base,
- stride_config.dimensions, stride_config.step,
- copy_proc);
- }
- return Status::OK();
-}
-
-Status Literal::CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index) {
- DCHECK_EQ(shape().element_type(), src_literal.shape().element_type());
- const int64 src_linear_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- src_literal.shape(), src_index);
- const int64 dest_linear_index =
- IndexUtil::MultidimensionalIndexToLinearIndex(shape(), dest_index);
- const int64 primitive_size =
- ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
-
- char* dest_address =
- static_cast<char*>(untyped_data()) + dest_linear_index * primitive_size;
- const char* source_address =
- static_cast<const char*>(src_literal.untyped_data()) +
- src_linear_index * primitive_size;
- if (dest_address != source_address) {
- memcpy(dest_address, source_address, primitive_size);
- }
- return Status::OK();
-}
-
-std::vector<Literal> Literal::DecomposeTuple() {
- CHECK(ShapeUtil::IsTuple(shape()));
- std::vector<Literal> elements;
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}),
- /*allocate_arrays=*/false));
- Literal& element = elements.back();
- element.root_piece_->ForEachMutableSubpiece(
- [&](const ShapeIndex& index, Piece* dest_piece) {
- ShapeIndex src_index = {i};
- for (int64 j : index) {
- src_index.push_back(j);
- }
- Piece& src_piece = piece(src_index);
-
- // Move the respective buffer and sparse indices over to the element
- // Literal.
- dest_piece->set_buffer(src_piece.buffer());
- src_piece.set_buffer(nullptr);
- dest_piece->set_sparse_indices(src_piece.sparse_indices());
- src_piece.set_sparse_indices(nullptr);
- });
- }
- // Set this literal to be nil-shaped.
- *this = Literal();
- return elements;
-}
-
-/* static */ Literal Literal::MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements) {
- std::vector<Shape> element_shapes;
- for (const Literal& element : elements) {
- element_shapes.push_back(element.shape());
- }
- Literal literal(ShapeUtil::MakeTupleShape(element_shapes),
- /*allocate_arrays=*/false);
- for (int i = 0; i < elements.size(); ++i) {
- TF_CHECK_OK(
- literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
- }
- return literal;
-}
-
-namespace {
-
-// Copies the elements in 'src' to 'dest'. The shape and layout of the data in
-// the array slices are indicated by dest_shape and src_shape respectively.
-template <typename NativeT>
-void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- tensorflow::gtl::ArraySlice<NativeT> src,
- const Shape& dest_shape, const Shape& src_shape) {
- CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
- if (ShapeUtil::HasZeroElements(dest_shape)) {
- return;
- }
- std::vector<int64> index(ShapeUtil::Rank(dest_shape));
- do {
- dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] =
- src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)];
- } while (IndexUtil::BumpIndices(dest_shape, &index));
-}
-
-} // namespace
-
-Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) {
- CHECK(subshape_ != nullptr);
- CHECK(src.subshape_ != nullptr);
- if (ShapeUtil::Equal(subshape(), src.subshape())) {
- // If the layouts are equal it's faster just to memcpy.
- memcpy(buffer(), src.buffer(), src.size_bytes());
- } else {
- TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape()));
- std::vector<int64> origin(ShapeUtil::Rank(subshape()), 0);
- switch (subshape().element_type()) {
-#define COPY_ELEMENTS(XLA_T, NATIVE_T) \
- case (XLA_T): \
- CopyElementsBetween<NATIVE_T>(data<NATIVE_T>(), src.data<NATIVE_T>(), \
- subshape(), src.subshape()); \
- break;
- COPY_ELEMENTS(U8, uint8);
- COPY_ELEMENTS(U16, uint16);
- COPY_ELEMENTS(U32, uint32);
- COPY_ELEMENTS(U64, uint64);
- COPY_ELEMENTS(S8, int8);
- COPY_ELEMENTS(S16, int16);
- COPY_ELEMENTS(S32, int32);
- COPY_ELEMENTS(S64, int64);
- COPY_ELEMENTS(F16, half);
- COPY_ELEMENTS(BF16, bfloat16);
- COPY_ELEMENTS(F32, float);
- COPY_ELEMENTS(F64, double);
- COPY_ELEMENTS(C64, complex64);
- COPY_ELEMENTS(PRED, bool);
-#undef COPY_ELEMENTS
- default:
- return Unimplemented(
- "Copying a Literal object with element type %s is not implemented.",
- PrimitiveType_Name(subshape().element_type()).c_str());
- }
- }
- return Status::OK();
-}
-
-Status Literal::CopyFrom(const LiteralSlice& src_literal,
- const ShapeIndex& dest_shape_index,
- const ShapeIndex& src_shape_index) {
- const Shape& dest_subshape =
- ShapeUtil::GetSubshape(shape(), dest_shape_index);
- const Shape& src_subshape =
- ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index);
- if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) {
- return InvalidArgument(
- "Destination subshape incompatible with source subshape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_subshape).c_str());
- }
- return root_piece_->ForEachMutableSubpieceWithStatus(
- [&](const ShapeIndex& index, Piece* piece) {
- if (!ShapeUtil::IsArray(piece->subshape())) {
- return Status::OK();
- }
-
- // Determine if this index is in the part of this literal that we want
- // to copy over from src_literal.
- bool in_subtree_to_copy = true;
- for (int i = 0; i < dest_shape_index.size(); ++i) {
- if (index[i] != dest_shape_index[i]) {
- in_subtree_to_copy = false;
- break;
- }
- }
- if (!in_subtree_to_copy) {
- return Status::OK();
- }
- // Construct the index of the corresponding piece in the source literal.
- ShapeIndex src_piece_index = src_shape_index;
- for (int64 i = dest_shape_index.size(); i < index.size(); ++i) {
- src_piece_index.push_back(index[i]);
- }
- TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index)));
- return Status::OK();
- });
-}
-
-Status Literal::MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index) {
- const Shape& dest_subshape =
- ShapeUtil::GetSubshape(shape(), dest_shape_index);
- if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) {
- return InvalidArgument(
- "Destination subshape not equal to source shape: %s vs %s",
- ShapeUtil::HumanString(dest_subshape).c_str(),
- ShapeUtil::HumanString(src_literal.shape()).c_str());
- }
-
- src_literal.root_piece_->ForEachSubpiece(
- [&](const ShapeIndex& src_index, const Piece& src_piece) {
- if (!ShapeUtil::IsArray(src_piece.subshape())) {
- return;
- }
-
- ShapeIndex dest_index = dest_shape_index;
- for (int64 i : src_index) {
- dest_index.push_back(i);
- }
- Piece& dest_piece = piece(dest_index);
- delete[] dest_piece.buffer();
- dest_piece.set_buffer(src_piece.buffer());
- delete dest_piece.sparse_indices();
- dest_piece.set_sparse_indices(src_piece.sparse_indices());
- });
-
- src_literal.shape_ = MakeUnique<Shape>(ShapeUtil::MakeNil());
- delete src_literal.root_piece_;
- src_literal.root_piece_ = new LiteralBase::Piece();
- src_literal.root_piece_->set_subshape(src_literal.shape_.get());
-
- return Status::OK();
-}
-
-Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size) {
- TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape());
- TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape()))
- << ShapeUtil::HumanString(src_literal.shape());
- TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape()));
-
- switch (shape().element_type()) {
- case U8:
- return CopySliceFromInternal<uint8>(src_literal, src_base, dest_base,
- copy_size);
- case U16:
- return CopySliceFromInternal<uint16>(src_literal, src_base, dest_base,
- copy_size);
- case U32:
- return CopySliceFromInternal<uint32>(src_literal, src_base, dest_base,
- copy_size);
- case U64:
- return CopySliceFromInternal<uint64>(src_literal, src_base, dest_base,
- copy_size);
- case S8:
- return CopySliceFromInternal<int8>(src_literal, src_base, dest_base,
- copy_size);
- case S16:
- return CopySliceFromInternal<int16>(src_literal, src_base, dest_base,
- copy_size);
- case S32:
- return CopySliceFromInternal<int32>(src_literal, src_base, dest_base,
- copy_size);
- case S64:
- return CopySliceFromInternal<int64>(src_literal, src_base, dest_base,
- copy_size);
- case F16:
- return CopySliceFromInternal<half>(src_literal, src_base, dest_base,
- copy_size);
- case BF16:
- return CopySliceFromInternal<bfloat16>(src_literal, src_base, dest_base,
- copy_size);
- case F32:
- return CopySliceFromInternal<float>(src_literal, src_base, dest_base,
- copy_size);
- case F64:
- return CopySliceFromInternal<double>(src_literal, src_base, dest_base,
- copy_size);
- case C64:
- return CopySliceFromInternal<complex64>(src_literal, src_base, dest_base,
- copy_size);
- case PRED:
- return CopySliceFromInternal<bool>(src_literal, src_base, dest_base,
- copy_size);
- default:
- break;
- }
- return Unimplemented(
- "Copying a slice from a Literal object with element type %d is not "
- "implemented.",
- shape().element_type());
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateToken() {
+ return MakeUnique<Literal>(ShapeUtil::MakeTokenShape());
}
-/* static */ Literal Literal::Zero(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*Literal::CreateR0<uint8>(0));
+ return std::move(*LiteralUtil::CreateR0<uint8>(0));
case U32:
- return std::move(*Literal::CreateR0<uint32>(0));
+ return std::move(*LiteralUtil::CreateR0<uint32>(0));
case U64:
- return std::move(*Literal::CreateR0<uint64>(0));
+ return std::move(*LiteralUtil::CreateR0<uint64>(0));
case S8:
- return std::move(*Literal::CreateR0<int8>(0));
+ return std::move(*LiteralUtil::CreateR0<int8>(0));
case S32:
- return std::move(*Literal::CreateR0<int32>(0));
+ return std::move(*LiteralUtil::CreateR0<int32>(0));
case S64:
- return std::move(*Literal::CreateR0<int64>(0));
+ return std::move(*LiteralUtil::CreateR0<int64>(0));
case F16:
- return std::move(*Literal::CreateR0<half>(static_cast<half>(0.0f)));
+ return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(0.0f)));
case BF16:
return std::move(
- *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
+ *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f)));
case F32:
- return std::move(*Literal::CreateR0<float>(0));
+ return std::move(*LiteralUtil::CreateR0<float>(0));
case F64:
- return std::move(*Literal::CreateR0<double>(0));
+ return std::move(*LiteralUtil::CreateR0<double>(0));
case C64:
- return std::move(*Literal::CreateR0<complex64>(0));
+ return std::move(*LiteralUtil::CreateR0<complex64>(0));
case PRED:
- return std::move(*Literal::CreateR0<bool>(false));
+ return std::move(*LiteralUtil::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -607,33 +144,33 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::One(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
- return std::move(*Literal::CreateR0<uint8>(1));
+ return std::move(*LiteralUtil::CreateR0<uint8>(1));
case U32:
- return std::move(*Literal::CreateR0<uint32>(1));
+ return std::move(*LiteralUtil::CreateR0<uint32>(1));
case U64:
- return std::move(*Literal::CreateR0<uint64>(1));
+ return std::move(*LiteralUtil::CreateR0<uint64>(1));
case S8:
- return std::move(*Literal::CreateR0<int8>(1));
+ return std::move(*LiteralUtil::CreateR0<int8>(1));
case S32:
- return std::move(*Literal::CreateR0<int32>(1));
+ return std::move(*LiteralUtil::CreateR0<int32>(1));
case S64:
- return std::move(*Literal::CreateR0<int64>(1));
+ return std::move(*LiteralUtil::CreateR0<int64>(1));
case F16:
- return std::move(*Literal::CreateR0<half>(static_cast<half>(1.0f)));
+ return std::move(*LiteralUtil::CreateR0<half>(static_cast<half>(1.0f)));
case BF16:
return std::move(
- *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
+ *LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f)));
case F32:
- return std::move(*Literal::CreateR0<float>(1));
+ return std::move(*LiteralUtil::CreateR0<float>(1));
case F64:
- return std::move(*Literal::CreateR0<double>(1));
+ return std::move(*LiteralUtil::CreateR0<double>(1));
case C64:
- return std::move(*Literal::CreateR0<complex64>(1));
+ return std::move(*LiteralUtil::CreateR0<complex64>(1));
case PRED:
- return std::move(*Literal::CreateR0<bool>(true));
+ return std::move(*LiteralUtil::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
@@ -646,44 +183,44 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::MinValue(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
- *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
+ *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min()));
case U32:
return std::move(
- *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
+ *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min()));
case U64:
return std::move(
- *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
+ *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min()));
case S8:
return std::move(
- *Literal::CreateR0<int8>(std::numeric_limits<int8>::min()));
+ *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min()));
case S32:
return std::move(
- *Literal::CreateR0<int32>(std::numeric_limits<int32>::min()));
+ *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min()));
case S64:
return std::move(
- *Literal::CreateR0<int64>(std::numeric_limits<int64>::min()));
+ *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min()));
case F32:
- return std::move(
- *Literal::CreateR0<float>(-std::numeric_limits<float>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity()));
case F64:
- return std::move(
- *Literal::CreateR0<double>(-std::numeric_limits<double>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity()));
case C64:
LOG(FATAL) << "C64 element type has no minimum value";
case PRED:
- return std::move(*Literal::CreateR0<bool>(false));
+ return std::move(*LiteralUtil::CreateR0<bool>(false));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*Literal::CreateR0<half>(
+ return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(-std::numeric_limits<float>::infinity())));
case BF16:
- return std::move(*Literal::CreateR0<bfloat16>(
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(-std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no minimum value";
@@ -694,42 +231,42 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) {
+/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
switch (primitive_type) {
case U8:
return std::move(
- *Literal::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
+ *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max()));
case U32:
return std::move(
- *Literal::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
+ *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max()));
case U64:
return std::move(
- *Literal::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
+ *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max()));
case S8:
return std::move(
- *Literal::CreateR0<int8>(std::numeric_limits<int8>::max()));
+ *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max()));
case S32:
return std::move(
- *Literal::CreateR0<int32>(std::numeric_limits<int32>::max()));
+ *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max()));
case S64:
return std::move(
- *Literal::CreateR0<int64>(std::numeric_limits<int64>::max()));
+ *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max()));
case F32:
- return std::move(
- *Literal::CreateR0<float>(std::numeric_limits<float>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity()));
case F64:
- return std::move(
- *Literal::CreateR0<double>(std::numeric_limits<double>::infinity()));
+ return std::move(*LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity()));
case PRED:
- return std::move(*Literal::CreateR0<bool>(true));
+ return std::move(*LiteralUtil::CreateR0<bool>(true));
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case F16:
- return std::move(*Literal::CreateR0<half>(
+ return std::move(*LiteralUtil::CreateR0<half>(
static_cast<half>(std::numeric_limits<float>::infinity())));
case BF16:
- return std::move(*Literal::CreateR0<bfloat16>(
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
static_cast<bfloat16>(std::numeric_limits<float>::infinity())));
case TUPLE:
LOG(FATAL) << "tuple element type has no maximum value";
@@ -740,7 +277,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
}
}
-/* static */ std::unique_ptr<Literal> Literal::CreateR1(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
const tensorflow::core::Bitmap& values) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
@@ -748,17 +285,7 @@ Status Literal::CopySliceFrom(const LiteralSlice& src_literal,
return literal;
}
-void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(element_count(), values.bits());
- CHECK_EQ(shape().element_type(), PRED);
- for (int64 i = 0; i < static_cast<int64>(values.bits()); ++i) {
- Set({i}, values.get(i));
- }
-}
-
-/* static */ std::unique_ptr<Literal> Literal::CreateR1U8(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
tensorflow::StringPiece value) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
@@ -768,116 +295,13 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::CreateR2F32Linspace(float from,
- float to,
- int64 rows,
- int64 cols) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
+ float from, float to, int64 rows, int64 cols) {
auto value = MakeLinspaceArray2D(from, to, rows, cols);
return CreateR2FromArray2D(*value);
}
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Layout& new_layout, const ShapeIndex& shape_index) const {
- // Create new shape with 'new_layout' set at the given shape index.
- Shape new_shape = shape();
- Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index);
- TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape));
- *subshape->mutable_layout() = new_layout;
- auto result = MakeUnique<Literal>(new_shape);
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
-std::unique_ptr<Literal> LiteralBase::Relayout(
- const Shape& shape_with_layout) const {
- CHECK(ShapeUtil::Compatible(shape_with_layout, shape()))
- << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout)
- << " not compatible with literal shape "
- << ShapeUtil::HumanString(shape());
- std::unique_ptr<Literal> result = CreateFromShape(shape_with_layout);
- ShapeUtil::ForEachSubshape(
- result->shape(),
- [this, &result](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsArray(subshape)) {
- TF_CHECK_OK(result->CopyFrom(*this,
- /*dest_shape_index=*/index,
- /*src_shape_index=*/index));
- }
- });
- return result;
-}
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
- if (!ShapeUtil::IsArray(shape())) {
- return InvalidArgument("Broadcast only supports arrays.");
- }
-
- for (int64 i = 0; i < dimensions.size(); i++) {
- TF_RET_CHECK(shape().dimensions(i) ==
- result_shape.dimensions(dimensions[i]));
- }
-
- std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
-
- // scratch_source_index is temporary storage space for the computed index into
- // the input literal. We put it here to avoid allocating an std::vector in
- // every iteration of ShapeUtil::ForEachIndex.
- std::vector<int64> scratch_source_index(shape().dimensions_size());
-
- char* dest_data = static_cast<char*>(result->untyped_data());
- const char* source_data = static_cast<const char*>(untyped_data());
- const int64 primitive_size =
- ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
-
- ShapeUtil::ForEachIndex(
- result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
- for (int64 i = 0; i < dimensions.size(); ++i) {
- scratch_source_index[i] = output_index[dimensions[i]];
- }
- int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- result_shape, output_index);
- int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
- shape(), scratch_source_index);
- memcpy(dest_data + primitive_size * dest_index,
- source_data + primitive_size * source_index, primitive_size);
- return true;
- });
-
- return std::move(result);
-}
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const {
- if (!ShapeUtil::IsArray(shape())) {
- return InvalidArgument("Reshape does not support tuples.");
- }
- std::unique_ptr<Literal> output;
- if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) {
- output =
- Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape())));
- } else {
- output = CloneToUnique();
- }
- // Because the layout is monotonic, we can simply reuse the same sequence of
- // values without changing their order.
- *output->mutable_shape_do_not_use() =
- ShapeUtil::MakeShape(shape().element_type(), dimensions);
-
- int64 elements_before = ShapeUtil::ElementsIn(shape());
- int64 elements_after = ShapeUtil::ElementsIn(output->shape());
- if (elements_before != elements_after) {
- return InvalidArgument(
- "Shapes before and after Literal::Reshape have different numbers "
- "of elements: %s vs %s.",
- ShapeUtil::HumanString(shape()).c_str(),
- ShapeUtil::HumanString(output->shape()).c_str());
- }
- return std::move(output);
-}
-
-/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+/* static */ std::unique_ptr<Literal> LiteralUtil::ReshapeSlice(
tensorflow::gtl::ArraySlice<int64> new_dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major,
const LiteralSlice& literal) {
@@ -949,587 +373,64 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
return new_literal;
}
-std::unique_ptr<Literal> LiteralBase::Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const {
- CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
- CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape())))
- << "Given permutation is not a permutation of dimension numbers";
- // To transpose the array, we just permute the dimensions and layout, and
- // do a straight memory copy of the raw data set.
- // This is considerably faster than iterating over every array element using
- // the EachCell<>() and Set<>() APIs.
- std::vector<int64> inverse_permutation = InversePermutation(permutation);
- Shape permuted_shape =
- ShapeUtil::PermuteDimensions(inverse_permutation, shape());
- // Replace the layout with one affine to this shape, such that a
- // transpose operation can be performed by leaving the flat values
- // representation intact.
- // For example, consider the shape F32[11,8]{1,0} under a {1,0} permutation.
- // The shape with affine layout resulting from that operation will be
- // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the
- // most minor.
- //
- // Essentially, given MinMaj(Di) the position of the Di dimension within the
- // minor to major vector, and given T(Di) the index that the original Di
- // dimension has within the transposed array, a layout is affine if
- // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major
- // vector of the affine layout.
- CHECK(LayoutUtil::IsDenseArray(permuted_shape));
- Layout* layout = permuted_shape.mutable_layout();
- layout->clear_minor_to_major();
- for (auto index : LayoutUtil::MinorToMajor(shape())) {
- layout->add_minor_to_major(inverse_permutation[index]);
- }
- auto new_literal = MakeUnique<Literal>(permuted_shape);
- DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()),
- ShapeUtil::ByteSizeOf(shape()));
- std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes());
- return new_literal;
-}
-
-std::unique_ptr<Literal> LiteralBase::Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const {
- CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice";
-
- DimensionVector result_dimensions;
- for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) {
- CHECK_GE(start_indices[dnum], 0);
- CHECK_LE(limit_indices[dnum], shape().dimensions(dnum))
- << "dnum = " << dnum;
- int64 dimension = limit_indices[dnum] - start_indices[dnum];
- CHECK_GE(dimension, 0) << "dnum = " << dnum;
- result_dimensions.push_back(dimension);
- }
- const auto result_shape =
- ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
- LayoutUtil::MinorToMajor(shape()));
-
- auto result_literal = MakeUnique<Literal>(result_shape);
-
- DimensionVector new_indices(ShapeUtil::Rank(result_shape));
- switch (result_shape.element_type()) {
- case F32:
- result_literal->EachCell<float>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, float /*value*/) {
- for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- float value = Get<float>(new_indices);
- result_literal->Set<float>(indices, value);
- });
- return result_literal;
- case C64:
- result_literal->EachCell<complex64>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, complex64 /*value*/) {
- for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- complex64 value = Get<complex64>(new_indices);
- result_literal->Set<complex64>(indices, value);
- });
- return result_literal;
- case S32:
- result_literal->EachCell<int32>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
- for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- int32 value = Get<int32>(new_indices);
- result_literal->Set<int32>(indices, value);
- });
- return result_literal;
- case U32:
- result_literal->EachCell<uint32>(
- [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 /*value*/) {
- for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
- new_indices[i] = indices[i] + start_indices[i];
- }
- uint32 value = Get<uint32>(new_indices);
- result_literal->Set<uint32>(indices, value);
- });
- return result_literal;
- default:
- LOG(FATAL) << "not yet implemented: "
- << PrimitiveType_Name(result_shape.element_type());
- }
-}
-
-Literal LiteralBase::Clone() const {
- Literal result(shape());
- TF_CHECK_OK(result.CopyFrom(*this));
- return result;
-}
-
-std::unique_ptr<Literal> LiteralBase::CloneToUnique() const {
- auto result = MakeUnique<Literal>(shape());
- TF_CHECK_OK(result->CopyFrom(*this));
- return result;
-}
-
-string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
- const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
- CHECK(LayoutUtil::IsDenseArray(subshape));
- switch (subshape.element_type()) {
- case PRED:
- return Get<bool>(multi_index, shape_index) ? "true" : "false";
- case S8:
- return StrCat(Get<int8>(multi_index, shape_index));
- case S16:
- return StrCat(Get<int16>(multi_index, shape_index));
- case S32:
- return StrCat(Get<int32>(multi_index, shape_index));
- case S64:
- return StrCat(Get<int64>(multi_index, shape_index));
- case U8:
- return StrCat(Get<uint8>(multi_index, shape_index));
- case U16:
- return StrCat(Get<uint16>(multi_index, shape_index));
- case U32:
- return StrCat(Get<uint32>(multi_index, shape_index));
- case U64:
- return StrCat(Get<uint64>(multi_index, shape_index));
- case F16:
- return StrCat(static_cast<float>(Get<half>(multi_index, shape_index)));
- case F32:
- return StrCat(Get<float>(multi_index, shape_index));
- case BF16:
- return StrCat(
- static_cast<float>(Get<bfloat16>(multi_index, shape_index)));
- case F64:
- return StrCat(Get<double>(multi_index, shape_index));
- case C64: {
- complex64 c = Get<complex64>(multi_index, shape_index);
- return StrCat("(", c.real(), ", ", c.imag(), ")");
- }
- default:
- LOG(FATAL) << PrimitiveType_Name(subshape.element_type());
- }
-}
-
-string LiteralBase::GetSparseElementAsString(
- int64 sparse_element_number, const ShapeIndex& shape_index) const {
- const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index);
- CHECK(LayoutUtil::IsSparseArray(subshape));
- switch (subshape.element_type()) {
- case PRED:
- return GetSparseElement<bool>(sparse_element_number, shape_index)
- ? "true"
- : "false";
- case S8:
- return StrCat(GetSparseElement<int8>(sparse_element_number, shape_index));
- case S16:
- return StrCat(
- GetSparseElement<int16>(sparse_element_number, shape_index));
- case S32:
- return StrCat(
- GetSparseElement<int32>(sparse_element_number, shape_index));
- case S64:
- return StrCat(
- GetSparseElement<int64>(sparse_element_number, shape_index));
- case U8:
- return StrCat(
- GetSparseElement<uint8>(sparse_element_number, shape_index));
- case U16:
- return StrCat(
- GetSparseElement<uint16>(sparse_element_number, shape_index));
- case U32:
- return StrCat(
- GetSparseElement<uint32>(sparse_element_number, shape_index));
- case U64:
- return StrCat(
- GetSparseElement<uint64>(sparse_element_number, shape_index));
- case F16:
- return StrCat(static_cast<float>(
- GetSparseElement<half>(sparse_element_number, shape_index)));
- case F32:
- return StrCat(
- GetSparseElement<float>(sparse_element_number, shape_index));
- case BF16:
- return StrCat(static_cast<float>(
- GetSparseElement<bfloat16>(sparse_element_number, shape_index)));
- case F64:
- return StrCat(
- GetSparseElement<double>(sparse_element_number, shape_index));
- case C64: {
- complex64 c =
- GetSparseElement<complex64>(sparse_element_number, shape_index);
- return StrCat("(", c.real(), ", ", c.imag(), ")");
- }
- default:
- LOG(FATAL) << "Invalid element type for sparse arrays: "
- << PrimitiveType_Name(subshape.element_type());
- }
-}
-
-StatusOr<int64> LiteralBase::GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(LayoutUtil::IsDenseArray(shape()));
- switch (shape().element_type()) {
- case PRED:
- return Get<bool>(multi_index);
- case U8:
- return Get<uint8>(multi_index);
- case S32:
- return Get<int32>(multi_index);
- case S64:
- return Get<int64>(multi_index);
- case U32:
- return Get<uint32>(multi_index);
- case U64:
- return Get<uint64>(multi_index);
- default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
- }
-}
-
-size_t LiteralBase::Hash() const {
- using tensorflow::Hash64;
- using tensorflow::Hash64Combine;
-
- size_t hash_value = ShapeUtil::Hash(shape());
-
- ShapeUtil::ForEachSubshape(
- shape(), [&](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsTuple(subshape)) {
- return;
- }
-
- CHECK(LayoutUtil::IsDense(subshape.layout()));
- hash_value = Hash64Combine(
- hash_value, Hash64(static_cast<const char*>(untyped_data(index)),
- size_bytes(index)));
- });
-
- return hash_value;
-}
-
-Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value) {
- CHECK(LayoutUtil::IsDenseArray(shape()));
- switch (shape().element_type()) {
- case PRED:
- Set<bool>(multi_index, value);
- break;
- case U8:
- Set<uint8>(multi_index, value);
- break;
- case S32:
- Set<int32>(multi_index, value);
- break;
- case S64:
- Set<int64>(multi_index, value);
- break;
- case U32:
- Set<uint32>(multi_index, value);
- break;
- case U64:
- Set<uint64>(multi_index, value);
- break;
- default:
- return FailedPrecondition(
- "Array element type is not integral: %s",
- PrimitiveType_Name(shape().element_type()).c_str());
- }
- return Status::OK();
-}
-
-tensorflow::gtl::ArraySlice<int64> LiteralBase::GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index) const {
- const Piece& p = piece(shape_index);
- CHECK_GE(sparse_element_number, 0);
- CHECK_LT(sparse_element_number, p.sparse_indices()->index_count());
- return p.sparse_indices()->At(sparse_element_number);
-}
-
-void Literal::SortSparseElements(const ShapeIndex& shape_index) {
- piece(shape_index).SortSparseElements();
-}
-
-Literal LiteralBase::GetFirstScalarLiteral() const {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_GT(ShapeUtil::ElementsIn(shape()), 0);
- switch (shape().element_type()) {
+/* static */ Literal LiteralUtil::GetFirstScalarLiteral(
+ const LiteralSlice& literal) {
+ CHECK(ShapeUtil::IsArray(literal.shape()));
+ CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
+ switch (literal.shape().element_type()) {
case PRED:
- return std::move(*Literal::CreateR0<bool>(GetFirstElement<bool>()));
+ return std::move(
+ *LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>()));
// 8 bit types.
case S8:
- return std::move(*Literal::CreateR0<int8>(GetFirstElement<int8>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>()));
case U8:
- return std::move(*Literal::CreateR0<uint8>(GetFirstElement<uint8>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>()));
// 16 bit types.
case BF16:
- return std::move(
- *Literal::CreateR0<bfloat16>(GetFirstElement<bfloat16>()));
+ return std::move(*LiteralUtil::CreateR0<bfloat16>(
+ literal.GetFirstElement<bfloat16>()));
case F16:
- return std::move(*Literal::CreateR0<half>(GetFirstElement<half>()));
+ return std::move(
+ *LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>()));
case S16:
- return std::move(*Literal::CreateR0<int16>(GetFirstElement<int16>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>()));
case U16:
- return std::move(*Literal::CreateR0<uint16>(GetFirstElement<uint16>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>()));
// 32 bit types.
case F32:
- return std::move(*Literal::CreateR0<float>(GetFirstElement<float>()));
+ return std::move(
+ *LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>()));
case S32:
- return std::move(*Literal::CreateR0<int32>(GetFirstElement<int32>()));
+ return std::move(
+ *LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>()));
case U32:
- return std::move(*Literal::CreateR0<uint32>(GetFirstElement<uint32>()));
+ return std::move(
+ *LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>()));
// 64 bit types.
case C64:
- return std::move(
- *Literal::CreateR0<complex64>(GetFirstElement<complex64>()));
+ return std::move(*LiteralUtil::CreateR0<complex64>(
+ literal.GetFirstElement<complex64>()));
case F64:
- return std::move(*Literal::CreateR0<double>(GetFirstElement<double>()));
- case S64:
- return std::move(*Literal::CreateR0<int64>(GetFirstElement<int64>()));
- case U64:
- return std::move(*Literal::CreateR0<uint64>(GetFirstElement<uint64>()));
- default:
- LOG(FATAL) << "Unhandled primitive type " << shape().element_type();
- }
-}
-
-void LiteralBase::Piece::SortSparseElements() {
- switch (subshape().element_type()) {
- case PRED:
- SortSparseElementsInternal<bool>();
- break;
- case S8:
- SortSparseElementsInternal<int8>();
- break;
- case U8:
- SortSparseElementsInternal<uint8>();
- break;
- case S16:
- SortSparseElementsInternal<int16>();
- break;
- case U16:
- SortSparseElementsInternal<uint16>();
- break;
- case S32:
- SortSparseElementsInternal<int32>();
- break;
- case U32:
- SortSparseElementsInternal<uint32>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>()));
case S64:
- SortSparseElementsInternal<int64>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>()));
case U64:
- SortSparseElementsInternal<uint64>();
- break;
- case F32:
- SortSparseElementsInternal<float>();
- break;
- case F64:
- SortSparseElementsInternal<double>();
- break;
- case C64:
- SortSparseElementsInternal<complex64>();
- break;
- case F16:
- SortSparseElementsInternal<half>();
- break;
- case BF16:
- SortSparseElementsInternal<bfloat16>();
- break;
+ return std::move(
+ *LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>()));
default:
- LOG(FATAL) << "Element type not valid for sparse array: "
- << PrimitiveType_Name(subshape().element_type());
- }
-}
-
-template <typename NativeT>
-void LiteralBase::Piece::SortSparseElementsInternal() {
- CHECK(LayoutUtil::IsSparseArray(subshape()));
- int64 num_elements = sparse_indices()->index_count();
- auto values = data<NativeT>();
- CHECK_LE(num_elements, values.size());
- sparse_indices()->SortWithValues(
- tensorflow::gtl::MutableArraySlice<NativeT>(values.data(), num_elements));
-}
-
-namespace {
-
-void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index,
- bool print_layout, std::vector<string>* pieces) {
- const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index);
- CHECK(LayoutUtil::HasLayout(literal.shape()));
- CHECK(LayoutUtil::HasLayout(subshape));
-
- auto shape_to_string = [print_layout](const Shape& shape) {
- if (print_layout) {
- return ShapeUtil::HumanStringWithLayout(shape);
- } else {
- return ShapeUtil::HumanString(shape);
- }
- };
-
- // TODO(b/32894291): refactor this code to reduce code duplication.
- if (ShapeUtil::IsTuple(subshape)) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" (\n");
- std::vector<string> tuple_pieces;
- for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) {
- ShapeIndex element_index = shape_index;
- element_index.push_back(i);
- std::vector<string> element_pieces;
- ToStringHelper(literal, element_index, print_layout, &element_pieces);
- tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, ""));
- }
- pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n"));
- pieces->push_back("\n)");
- return;
- }
-
- if (LayoutUtil::IsSparseArray(subshape)) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back("{");
- int64 rank = ShapeUtil::Rank(subshape);
- int64 num_elements = literal.sparse_element_count();
- for (int64 i = 0; i < num_elements; ++i) {
- if (i > 0) {
- pieces->push_back(", ");
- }
- if (rank == 1) {
- pieces->push_back(StrCat(literal.GetSparseIndex(i)[0]));
- pieces->push_back(": ");
- } else {
- pieces->push_back("[");
- pieces->push_back(
- tensorflow::str_util::Join(literal.GetSparseIndex(i), ", "));
- pieces->push_back("]: ");
- }
- pieces->push_back(literal.GetSparseElementAsString(i));
- }
- pieces->push_back("}");
- return;
- }
-
- CHECK(LayoutUtil::IsDenseArray(subshape));
-
- auto element_to_string =
- [&](tensorflow::gtl::ArraySlice<int64> indices) -> string {
- PrimitiveType element_type = subshape.element_type();
- if (element_type == PRED) {
- // We display predicates in a densely packed form.
- return literal.Get<bool>(indices, shape_index) ? "1" : "0";
- }
- return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
- literal.GetAsString(indices, shape_index);
- };
-
- if (ShapeUtil::Rank(subshape) == 0) {
- pieces->push_back(literal.GetAsString({}, shape_index));
- } else if (ShapeUtil::Rank(subshape) == 1) {
- pieces->push_back("{");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(element_to_string({i0}));
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 2) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(" { ");
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(element_to_string({i0, i1}));
- }
- pieces->push_back(" ");
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? "}\n" : "},\n");
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 3) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(i0 > 0 ? ",\n{" : "{");
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(i1 > 0 ? ",\n { " : " { ");
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(element_to_string({i0, i1, i2}));
- }
- pieces->push_back(" }");
- }
- pieces->push_back(" }");
- }
- pieces->push_back("\n}");
- } else if (ShapeUtil::Rank(subshape) == 4) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(" {");
- for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
- pieces->push_back(element_to_string({i0, i1, i2, i3}));
- }
- pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n");
- }
- pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
- }
- pieces->push_back("}");
- } else if (ShapeUtil::Rank(subshape) == 5) {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {\n");
- for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) {
- pieces->push_back(Printf(" { /*i0=%lld*/\n", i0));
- for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) {
- pieces->push_back(Printf(" { /*i1=%lld*/\n", i1));
- for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) {
- pieces->push_back(Printf(" { /*i2=%lld*/\n", i2));
- for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) {
- pieces->push_back(" {");
- for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) {
- pieces->push_back(element_to_string({i0, i1, i2, i3, i4}));
- }
- pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n"
- : "},\n");
- }
- pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n"
- : " },\n");
- }
- pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n");
- }
- pieces->push_back("}");
- } else {
- pieces->push_back(shape_to_string(subshape));
- pieces->push_back(" {");
- literal.EachCellAsString(
- [&](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
- pieces->push_back(" ");
- pieces->push_back(value);
- });
- pieces->push_back("}");
+ LOG(FATAL) << "Unhandled primitive type "
+ << literal.shape().element_type();
}
}
-} // namespace
-
-int64 LiteralBase::sparse_element_count() const {
- CHECK(LayoutUtil::IsSparseArray(shape()));
- return sparse_indices()->index_count();
-}
-
-string LiteralBase::ToString(bool print_layout) const {
- std::vector<string> pieces;
- CHECK(LayoutUtil::HasLayout(this->shape()));
- ToStringHelper(*this, {}, print_layout, &pieces);
- return tensorflow::str_util::Join(pieces, "");
-}
-
-/* static */ std::unique_ptr<Literal> Literal::MakeTuple(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
tensorflow::gtl::ArraySlice<const Literal*> elements) {
std::vector<Shape> element_shapes;
for (const auto* element : elements) {
@@ -1542,7 +443,7 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::MakeTupleFromSlices(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleFromSlices(
tensorflow::gtl::ArraySlice<LiteralSlice> elements) {
std::vector<Shape> element_shapes;
for (const auto& element : elements) {
@@ -1555,7 +456,7 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-/* static */ std::unique_ptr<Literal> Literal::MakeTupleOwned(
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTupleOwned(
std::vector<std::unique_ptr<Literal>> elements) {
std::vector<Shape> element_shapes;
element_shapes.reserve(elements.size());
@@ -1570,819 +471,9 @@ string LiteralBase::ToString(bool print_layout) const {
return literal;
}
-void LiteralBase::EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
- return;
- }
- std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
- shape(), /*linear_index=*/0);
- do {
- per_cell(indices, GetAsString(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
-}
-
-namespace {
-template <typename NativeSrcT, typename NativeDestT, typename ConverterType>
-std::unique_ptr<Literal> ConvertBetweenNativeTypesWithConverter(
- const LiteralBase& src_literal, const ConverterType& converter) {
- CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(ShapeUtil::ChangeElementType(
- src_literal.shape(),
- primitive_util::NativeToPrimitiveType<NativeDestT>()));
- auto src_data = src_literal.data<NativeSrcT>();
- auto dest_data = result_literal->template data<NativeDestT>();
- int64 num_elements = src_literal.element_count();
-
- for (int64 i = 0; i < num_elements; ++i) {
- dest_data[i] = converter(src_data[i]);
- }
- return result_literal;
-}
-
-template <typename NativeSrcT, typename NativeDestT>
-std::unique_ptr<Literal> ConvertBetweenNativeTypes(
- const LiteralBase& src_literal) {
- auto converter = [](NativeSrcT src) { return static_cast<NativeDestT>(src); };
- return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
- src_literal, converter);
-}
-
-template <typename NativeSrcT, typename NativeDestT>
-typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
- auto converter = [](NativeSrcT src) {
- return tensorflow::bit_cast<NativeDestT>(src);
- };
- return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
- src_literal, converter);
-}
-
-// This template specialization is here to make the compiler happy. bit_cast has
-// a static check that the types are the same size. This specialization should
-// never be used because the source and destination types are checked for
-// identical sizes higher up.
-template <typename NativeSrcT, typename NativeDestT>
-typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)),
- std::unique_ptr<Literal>>::type
-BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
- LOG(FATAL) << "Invalid bitcast between types of different sizes.";
-}
-
-template <PrimitiveType primitive_src_type>
-std::unique_ptr<Literal> ConvertToC64(const LiteralBase& src_literal) {
- CHECK(ShapeUtil::IsArray(src_literal.shape()));
- auto result_literal = MakeUnique<Literal>(
- ShapeUtil::ChangeElementType(src_literal.shape(), C64));
- using NativeSrcT =
- typename primitive_util::PrimitiveTypeToNative<primitive_src_type>::type;
- tensorflow::gtl::ArraySlice<NativeSrcT> src_data =
- src_literal.data<NativeSrcT>();
- tensorflow::gtl::MutableArraySlice<complex64> dest_data =
- result_literal->data<complex64>();
- int64 num_elements = src_literal.element_count();
- for (int64 i = 0; i < num_elements; ++i) {
- dest_data[i] = complex64(static_cast<float>(src_data[i]), 0);
- }
- return result_literal;
-}
-
-template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
-std::unique_ptr<Literal> ConvertIfTypesMatch(const LiteralBase& src_literal,
- bool bitcast) {
- CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
- if (bitcast) {
- return BitcastBetweenNativeTypes<
- typename primitive_util::PrimitiveTypeToNative<
- primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
- } else {
- return ConvertBetweenNativeTypes<
- typename primitive_util::PrimitiveTypeToNative<
- primitive_src_type>::type,
- typename primitive_util::PrimitiveTypeToNative<
- primitive_dest_type>::type>(src_literal);
- }
-}
-
-template <PrimitiveType primitive_src_type>
-StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
- const LiteralBase& src_literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
- switch (primitive_dest_type) {
-#define CONVERT_IF_TYPES_MATCH(type) \
- case (type): \
- return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal, \
- bitcast);
- CONVERT_IF_TYPES_MATCH(PRED)
- CONVERT_IF_TYPES_MATCH(S8)
- CONVERT_IF_TYPES_MATCH(S32)
- CONVERT_IF_TYPES_MATCH(S64)
- CONVERT_IF_TYPES_MATCH(U8)
- CONVERT_IF_TYPES_MATCH(U32)
- CONVERT_IF_TYPES_MATCH(U64)
- CONVERT_IF_TYPES_MATCH(F16)
- CONVERT_IF_TYPES_MATCH(F32)
- CONVERT_IF_TYPES_MATCH(F64)
- CONVERT_IF_TYPES_MATCH(BF16)
-#undef CONVERT_IF_TYPES_MATCH
- case C64:
- if (!bitcast) {
- return ConvertToC64<primitive_src_type>(src_literal);
- }
- break;
- // Other types are not yet supported.
- default:
- break;
- }
- return Unimplemented(
- "Converting from type %s to type %s is not implemented.",
- PrimitiveType_Name(src_literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
-}
-
-StatusOr<std::unique_ptr<Literal>> ConvertSwitch(
- const LiteralBase& literal, PrimitiveType primitive_dest_type,
- bool bitcast) {
- TF_RET_CHECK(ShapeUtil::IsArray(literal.shape()));
- if (literal.shape().element_type() == primitive_dest_type) {
- return literal.CloneToUnique();
- }
- switch (literal.shape().element_type()) {
-#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
- case (type): \
- return ConvertIfDestTypeMatches<(type)>(literal, primitive_dest_type, \
- bitcast);
- CONVERT_IF_DEST_TYPE_MATCHES(PRED)
- CONVERT_IF_DEST_TYPE_MATCHES(S8)
- CONVERT_IF_DEST_TYPE_MATCHES(S32)
- CONVERT_IF_DEST_TYPE_MATCHES(S64)
- CONVERT_IF_DEST_TYPE_MATCHES(U8)
- CONVERT_IF_DEST_TYPE_MATCHES(U32)
- CONVERT_IF_DEST_TYPE_MATCHES(U64)
- CONVERT_IF_DEST_TYPE_MATCHES(F16)
- CONVERT_IF_DEST_TYPE_MATCHES(F32)
- CONVERT_IF_DEST_TYPE_MATCHES(F64)
- CONVERT_IF_DEST_TYPE_MATCHES(BF16)
-#undef CONVERT_IF_DEST_TYPE_MATCHES
- // Other types are not yet supported.
- default:
- return Unimplemented(
- "%s from type %s to type %s is not implemented.",
- (bitcast ? "Bitcast converting" : "Converting"),
- PrimitiveType_Name(literal.shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str());
- }
-}
-
-} // namespace
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::Convert(
- PrimitiveType primitive_dest_type) const {
- return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false);
-}
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::BitcastConvert(
- PrimitiveType primitive_dest_type) const {
- if (primitive_util::BitWidth(shape().element_type()) !=
- primitive_util::BitWidth(primitive_dest_type)) {
- return InvalidArgument(
- "Cannot bitcast convert from %s to %s, bit widths are different: %d != "
- "%d",
- PrimitiveType_Name(shape().element_type()).c_str(),
- PrimitiveType_Name(primitive_dest_type).c_str(),
- primitive_util::BitWidth(shape().element_type()),
- primitive_util::BitWidth(primitive_dest_type));
- }
- return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
-}
-
-StatusOr<std::unique_ptr<Literal>> LiteralBase::ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16) const {
- if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
- return Convert(dest_shape.element_type());
- }
- std::vector<Literal> elements;
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) {
- auto element = LiteralSlice(*this, {i});
- TF_ASSIGN_OR_RETURN(
- auto new_element,
- element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i})));
- elements.push_back(std::move(*new_element));
- }
- auto converted = MakeUnique<Literal>();
- *converted = Literal::MoveIntoTuple(&elements);
- return std::move(converted);
-}
-
-template <typename NativeT>
-bool LiteralBase::Piece::EqualElementsInternal(
- const LiteralBase::Piece& other, std::vector<int64>* multi_index) const {
- if (multi_index->size() == ShapeUtil::Rank(subshape())) {
- return (Get<NativeT>(*multi_index) == other.Get<NativeT>(*multi_index));
- }
- for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) {
- multi_index->push_back(i);
- if (!EqualElementsInternal<NativeT>(other, multi_index)) {
- return false;
- }
- multi_index->pop_back();
- }
- return true;
-}
-
-bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const {
- DCHECK(ShapeUtil::Compatible(subshape(), other.subshape()));
-
- std::vector<int64> multi_index;
- switch (subshape().element_type()) {
- case PRED:
- return EqualElementsInternal<bool>(other, &multi_index);
- case U8:
- return EqualElementsInternal<uint8>(other, &multi_index);
- case S32:
- return EqualElementsInternal<int32>(other, &multi_index);
- case S64:
- return EqualElementsInternal<int64>(other, &multi_index);
- case U32:
- return EqualElementsInternal<uint32>(other, &multi_index);
- case U64:
- return EqualElementsInternal<uint64>(other, &multi_index);
- case F32:
- return EqualElementsInternal<float>(other, &multi_index);
- case F64:
- return EqualElementsInternal<double>(other, &multi_index);
- case F16:
- return EqualElementsInternal<half>(other, &multi_index);
- case BF16:
- return EqualElementsInternal<bfloat16>(other, &multi_index);
- case C64:
- return EqualElementsInternal<complex64>(other, &multi_index);
- default:
- LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type "
- << PrimitiveType_Name(subshape().element_type());
- }
-}
-
-bool LiteralBase::operator==(const LiteralBase& other) const {
- if (!ShapeUtil::Compatible(shape(), other.shape())) {
- return false;
- }
-
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- const Piece& other_piece = other.piece(index);
- if (!piece.EqualElements(other_piece)) {
- return false;
- }
- return true;
- });
-}
-
-namespace {
-
-template <typename NativeT>
-static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice<NativeT> data,
- NativeT value) {
- for (int64 i = 0; i < data.size(); ++i) {
- if (data[i] != value) {
- return false;
- }
- }
- return true;
-}
-
-} // namespace
-
-bool LiteralBase::IsAll(int8 value) const {
- return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index,
- const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case U8:
- if (value >= 0) {
- return AllElementsEqualValue<uint8>(piece.data<uint8>(), value);
- }
- return false;
- case U32:
- if (value >= 0) {
- return AllElementsEqualValue<uint32>(piece.data<uint32>(), value);
- }
- return false;
- case U64:
- if (value >= 0) {
- return AllElementsEqualValue<uint64>(piece.data<uint64>(), value);
- }
- return false;
- case S8:
- return AllElementsEqualValue<int8>(piece.data<int8>(), value);
- case S32:
- return AllElementsEqualValue<int32>(piece.data<int32>(), value);
- case S64:
- return AllElementsEqualValue<int64>(piece.data<int64>(), value);
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(piece.data<bfloat16>(),
- static_cast<bfloat16>(value));
- case PRED:
- if (value == 0) {
- return AllElementsEqualValue<bool>(piece.data<bool>(), false);
- }
- if (value == 1) {
- return AllElementsEqualValue<bool>(piece.data<bool>(), true);
- }
- return false;
- default:
- return false;
- }
- return false;
- };
-
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
-
-bool LiteralBase::IsAllFloat(float value) const {
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(
- piece.data<bfloat16>(), static_cast<bfloat16>(value));
- default:
- return false;
- }
- };
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
-
-bool LiteralBase::IsAllComplex(complex64 value) const {
- switch (shape().element_type()) {
- case C64:
- return AllElementsEqualValue<complex64>(root_piece().data<complex64>(),
- value);
- default:
- return false;
- }
-}
-
-bool LiteralBase::IsAllFirst() const {
- return root_piece().ForEachSubpieceWithBool(
- [&](const ShapeIndex& index, const Piece& piece) {
- if (!ShapeUtil::IsArray(piece.subshape())) {
- return true;
- }
-
- // Empty shapes are not all the first element since there is no first
- // element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
- return false;
- }
- auto piece_is_all = [&]() {
- switch (piece.subshape().element_type()) {
- case PRED: {
- auto data = piece.data<bool>();
- return AllElementsEqualValue<bool>(data, data[0]);
- }
- // 8 bit types
- case S8: {
- auto data = piece.data<int8>();
- return AllElementsEqualValue<int8>(data, data[0]);
- }
- case U8: {
- auto data = piece.data<uint8>();
- return AllElementsEqualValue<uint8>(data, data[0]);
- }
- // 16 bit types
- case BF16: {
- auto data = piece.data<bfloat16>();
- return AllElementsEqualValue<bfloat16>(data, data[0]);
- }
- case F16: {
- auto data = piece.data<half>();
- return AllElementsEqualValue<half>(data, data[0]);
- }
- case S16: {
- auto data = piece.data<int16>();
- return AllElementsEqualValue<int16>(data, data[0]);
- }
- case U16: {
- auto data = piece.data<uint16>();
- return AllElementsEqualValue<uint16>(data, data[0]);
- }
- // 32 bit types
- case F32: {
- auto data = piece.data<float>();
- return AllElementsEqualValue<float>(data, data[0]);
- }
- case U32: {
- auto data = piece.data<uint32>();
- return AllElementsEqualValue<uint32>(data, data[0]);
- }
- case S32: {
- auto data = piece.data<int32>();
- return AllElementsEqualValue<int32>(data, data[0]);
- }
- // 64 bit types
- case C64: {
- auto data = piece.data<complex64>();
- return AllElementsEqualValue<complex64>(data, data[0]);
- }
- case F64: {
- auto data = piece.data<double>();
- return AllElementsEqualValue<double>(data, data[0]);
- }
- case S64: {
- auto data = piece.data<int64>();
- return AllElementsEqualValue<int64>(data, data[0]);
- }
- case U64: {
- auto data = piece.data<uint64>();
- return AllElementsEqualValue<uint64>(data, data[0]);
- }
- default:
- return false;
- }
- };
-
- if (!piece_is_all()) {
- return false;
- }
- return true;
- });
-}
-
-bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
- CHECK(ShapeUtil::IsArray(shape()));
- switch (shape().element_type()) {
- case U8:
- return Get<uint8>(indices) == 0;
- case U32:
- return Get<uint32>(indices) == 0;
- case U64:
- return Get<uint64>(indices) == 0;
- case S8:
- return Get<int8>(indices) == 0;
- case S32:
- return Get<int32>(indices) == 0;
- case S64:
- return Get<int64>(indices) == 0;
- case F32:
- return Get<float>(indices) == 0.0f;
- case F64:
- return Get<double>(indices) == 0.0;
- case C64:
- return Get<complex64>(indices) == complex64(0.0f, 0.0f);
- case F16:
- return Get<half>(indices) == static_cast<half>(0.0f);
- case BF16:
- return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
- case PRED:
- return Get<bool>(indices) == false;
- default:
- LOG(FATAL) << "Input literal must be an array.";
- }
-}
-
-namespace {
-
-template <typename RepeatedFieldT, typename NativeT>
-void CopyToRepeatedField(RepeatedFieldT* dest,
- const tensorflow::gtl::ArraySlice<NativeT> src) {
- *dest = RepeatedFieldT(src.begin(), src.end());
-}
-
-} // namespace
-
-void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
- *proto->mutable_shape() = subshape();
- switch (subshape().element_type()) {
- case PRED:
- CopyToRepeatedField(proto->mutable_preds(), data<bool>());
- break;
- case U8:
- proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
- element_count());
- break;
- case U32:
- CopyToRepeatedField(proto->mutable_u32s(), data<uint32>());
- break;
- case U64:
- CopyToRepeatedField(proto->mutable_u64s(), data<uint64>());
- break;
- case S32:
- CopyToRepeatedField(proto->mutable_s32s(), data<int32>());
- break;
- case S64:
- CopyToRepeatedField(proto->mutable_s64s(), data<int64>());
- break;
- case F16:
- *proto->mutable_f16s() = string(
- reinterpret_cast<const char*>(data<half>().data()), size_bytes());
- if (!kLittleEndian) {
- ConvertEndianShort(proto->mutable_f16s());
- }
- break;
- case BF16:
- *proto->mutable_bf16s() = string(
- reinterpret_cast<const char*>(data<bfloat16>().data()), size_bytes());
- if (!kLittleEndian) {
- ConvertEndianShort(proto->mutable_bf16s());
- }
- break;
- case F32:
- CopyToRepeatedField(proto->mutable_f32s(), data<float>());
- break;
- case F64:
- CopyToRepeatedField(proto->mutable_f64s(), data<double>());
- break;
- case C64:
- for (complex64 value : data<complex64>()) {
- proto->add_c64s(value.real());
- proto->add_c64s(value.imag());
- }
- break;
- case TUPLE:
- // Nothing to do but assign the shape which is done above.
- return;
- default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
- }
-}
-
-const void* LiteralBase::Piece::untyped_data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- return buffer();
-}
-
-void* LiteralBase::Piece::untyped_data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- return buffer();
-}
-
-namespace {
-
-template <typename RepeatedFieldT, typename NativeT>
-Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice<NativeT> dest,
- const RepeatedFieldT& src) {
- if (dest.size() != src.size()) {
- return InvalidArgument(
- "Expected %lu elements in LiteralProto repeated field, has %d",
- dest.size(), src.size());
- }
- std::copy(src.begin(), src.end(), dest.begin());
- return Status::OK();
-}
-
-} // namespace
-
-Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
- // These conditions should have been checked in Literal::CreateFromProto.
- TF_RET_CHECK(proto.has_shape());
- TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
- TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
-
- switch (subshape().element_type()) {
- case PRED:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
- break;
- case U8: {
- auto u8_data = data<uint8>();
- TF_RET_CHECK(proto.u8s().size() == u8_data.size());
- std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin());
- } break;
- case S32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int32>(), proto.s32s()));
- break;
- case S64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<int64>(), proto.s64s()));
- break;
- case U32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint32>(), proto.u32s()));
- break;
- case U64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<uint64>(), proto.u64s()));
- break;
- case F16: {
- const string& s(proto.f16s());
- TF_RET_CHECK(data<half>().size() * sizeof(half) == s.size());
- memcpy(untyped_data(), s.data(), s.size());
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
- }
- } break;
-
- case BF16: {
- const string& s(proto.bf16s());
- TF_RET_CHECK(data<bfloat16>().size() * sizeof(bfloat16) == s.size());
- memcpy(untyped_data(), s.data(), s.size());
- if (!kLittleEndian) {
- ConvertEndianShort(reinterpret_cast<char*>(untyped_data()), s.size());
- }
- } break;
- case F32:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<float>(), proto.f32s()));
- break;
- case F64:
- TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<double>(), proto.f64s()));
- break;
- case C64: {
- auto complex_data = data<complex64>();
- TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2);
- for (int64 i = 0; i < complex_data.size(); ++i) {
- complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)};
- }
- } break;
- case TUPLE:
- LOG(FATAL) << "Should not be called on tuple shapes: "
- << ShapeUtil::HumanString(subshape());
- break;
- default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
- }
- return Status::OK();
-}
-
-LiteralProto LiteralBase::ToProto() const {
- LiteralProto proto;
- root_piece().ForEachSubpiece(
- [&](const ShapeIndex& index, const Piece& piece) {
- LiteralProto* proto_piece = &proto;
- for (int64 i : index) {
- while (proto_piece->tuple_literals_size() <= i) {
- proto_piece->add_tuple_literals();
- }
- proto_piece = proto_piece->mutable_tuple_literals(i);
- }
- piece.WriteToProto(proto_piece);
- });
-
- if (LayoutUtil::IsSparseArray(shape())) {
- CopyToRepeatedField(proto.mutable_sparse_indices(),
- sparse_indices()->data());
- }
-
- return proto;
-}
-
-/* static */
-StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
- const LiteralProto& proto) {
- if (!proto.has_shape()) {
- return InvalidArgument("LiteralProto has no shape");
- }
- if (!LayoutUtil::HasLayout(proto.shape())) {
- return InvalidArgument("LiteralProto has no layout");
- }
-
- auto literal = MakeUnique<Literal>(proto.shape());
-
- TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus(
- [&](const ShapeIndex& index, Piece* piece) {
- const LiteralProto* proto_element = &proto;
- for (int64 i : index) {
- CHECK(i < proto_element->tuple_literals_size());
- proto_element = &proto_element->tuple_literals(i);
- }
-
- if (ShapeUtil::IsTuple(piece->subshape())) {
- if (proto_element->tuple_literals_size() !=
- ShapeUtil::TupleElementCount(piece->subshape())) {
- return InvalidArgument(
- "Expected %lld tuple elements in LiteralProto, has %d",
- ShapeUtil::TupleElementCount(piece->subshape()),
- proto_element->tuple_literals_size());
- }
- return Status::OK();
- }
-
- CHECK(ShapeUtil::IsArray(piece->subshape()));
- TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element));
-
- return Status::OK();
- }));
-
- return std::move(literal);
-}
-
-/* static */ string Literal::MultiIndexAsString(
+/* static */ string LiteralUtil::MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index) {
return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
}
-const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
- return piece(shape_index).untyped_data();
-}
-
-void* Literal::untyped_data(const ShapeIndex& shape_index) {
- return piece(shape_index).untyped_data();
-}
-
-int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const {
- return piece(shape_index).size_bytes();
-}
-
-string LiteralBase::GetR1U8AsString() const {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(shape().element_type(), U8);
- return string(tensorflow::bit_cast<const char*>(data<uint8>().data()),
- ShapeUtil::ElementsIn(shape()));
-}
-
-void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) {
- CHECK(ShapeUtil::IsTuple(shape));
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& subshape = shape.tuple_shapes(i);
-
- auto child_piece = Piece();
- child_piece.set_subshape(&subshape);
-
- if (ShapeUtil::IsTuple(subshape)) {
- BuildPieceSubtree(subshape, &child_piece);
- }
-
- piece->emplace_back(std::move(child_piece));
- }
-}
-
-LiteralSlice::LiteralSlice(const LiteralBase& literal)
- : LiteralBase(), root_piece_(&literal.root_piece()) {}
-
-LiteralSlice::LiteralSlice(const LiteralBase& literal,
- const ShapeIndex& view_root)
- : LiteralBase(), root_piece_(&literal.piece(view_root)) {}
-
-BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsArray(shape_));
- CHECK_NE(src_buf_ptr, nullptr);
- CHECK(LayoutUtil::HasLayout(shape_));
-
- root_piece_ = Piece();
- root_piece_.set_buffer(const_cast<char*>(src_buf_ptr));
- root_piece_.set_subshape(&shape_);
-}
-
-BorrowingLiteral::BorrowingLiteral(
- tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs, const Shape& shape)
- : LiteralBase(), shape_(shape) {
- CHECK(ShapeUtil::IsTuple(shape_));
- CHECK(!ShapeUtil::IsNestedTuple(shape_));
- CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_));
- root_piece_ = Piece();
- root_piece_.set_subshape(&shape_);
- BuildPieceSubtree(shape_, &root_piece_);
-
- for (int i = 0; i < src_buf_ptrs.size(); ++i) {
- const auto& src_shape = shape_.tuple_shapes(i);
- CHECK(ShapeUtil::IsArray(src_shape));
- root_piece_.child(i).set_buffer(const_cast<char*>(src_buf_ptrs[i]));
- }
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 2ca9060cc7..e3737a9d00 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -51,673 +52,12 @@ limitations under the License.
namespace xla {
-// Forward declare Literal and LiteralSlice class to be used by the creation
-// methods in the base class.
-class Literal;
-class LiteralSlice;
-
-// Abstract base class for literals.
-class LiteralBase {
+class LiteralUtil {
public:
- virtual ~LiteralBase() = 0;
-
- // Literals are equal if they have compatible shapes and the same data
- // values. Layout is not compared.
- bool operator==(const LiteralBase& other) const;
- bool operator!=(const LiteralBase& other) const { return !(*this == other); }
-
- // Returns the shape of the literal.
- const Shape& shape() const { return root_piece().subshape(); }
-
- // Serialize to proto.
- LiteralProto ToProto() const;
-
- // Returns an ArraySlice of the array for this literal for the given NativeT
- // (e.g., float). CHECKs if the subshape of the literal at the given
- // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type
- // to native type.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {}) const;
-
- // Returns a const pointer to the sparse index array. Returns nullptr if the
- // literal is not a sparse array.
- const SparseIndexArray* sparse_indices(
- const ShapeIndex& shape_index = {}) const;
-
- // Returns a const pointer to (or size of) the underlying buffer holding the
- // array at the given shape index. CHECKs if the subshape of the literal at
- // the given ShapeIndex is not array.
- const void* untyped_data(const ShapeIndex& shape_index = {}) const;
- int64 size_bytes(const ShapeIndex& shape_index = {}) const;
-
- // Returns this literal's data as a string. This literal must be a rank-1 U8
- // array.
- string GetR1U8AsString() const;
-
- // Returns a string representation of the literal value.
- // Warning: this function can take minutes for multi-million element Literals.
- string ToString(bool print_layout = false) const;
-
- // Gets an element in the literal at the given index. The multi_index is
- // CHECKed against the dimension sizes.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const;
- // Overloads of Get for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // Returns the element value at index (0, ..., 0), however many zeroes are
- // required for that index.
- template <typename NativeT>
- NativeT GetFirstElement() const;
-
- // As Get(), but determines the correct type and converts the value
- // into text.
- string GetAsString(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index = {}) const;
- // As GetSparseElement(), but determines the correct type and converts the
- // value into text.
- string GetSparseElementAsString(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
- // As Get(), but determines the correct type and converts the value into
- // int64. This literal must be an array.
- StatusOr<int64> GetIntegralAsS64(
- tensorflow::gtl::ArraySlice<int64> multi_index) const;
-
- // Returns the multi-index of the element in a sparse literal at the given
- // sparse element number. The sparse element number is the position with in
- // the sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- tensorflow::gtl::ArraySlice<int64> GetSparseIndex(
- int64 sparse_element_number, const ShapeIndex& shape_index = {}) const;
-
- // Returns the value of the element in a sparse literal at the given sparse
- // element number. The sparse element number is the position with in the
- // sparse array's list of (index, value) pairs, and is checked against the
- // total number of (index, value) pairs in the sparse array.
- template <typename NativeT>
- NativeT GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index = {}) const;
-
- // Invokes the "per cell" callback for each element in the provided
- // literal with the element's indices and a string representation of
- // the element's value.
- //
- // This function is useful if you want a polymorphic representation
- // of the tensor's elements (turning it to a string for something
- // like representation in a protobuf).
- //
- // This literal must have a dense layout.
- void EachCellAsString(
- const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- const string& value)>& per_cell) const;
- template <typename NativeT>
- void EachCell(std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const;
-
- // Returns whether every element in this literal is equal to value.
- //
- // value is an int8 because we expect this to be called with small
- // compile-time constants (0, -1, etc.) and so that whatever value you pass
- // can be represented exactly by floating-point types as small as 16 bits.
- //
- // If value doesn't fit in this literal's type, returns false. Values of 1/0
- // are considered equal to true/false; other values are not considered equal
- // to true. Also if this literal is not array-shaped false is returned.
- bool IsAll(int8 value) const;
-
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular floating-point number.
- //
- // If the literal is not a floating-point value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for values that can be expressed precisely as a float,
- // e.g. -0.5. Also if this literal is not array-shaped false is returned.
- bool IsAllFloat(float value) const;
-
- // Like IsAll(const Literal&, int8), except we check whether the literal is
- // equal to a particular complex number.
- //
- // If the literal is not a complex value, this always returns false.
- //
- // This casts value to the type of literal, then compares using ==. The usual
- // admonishments about floating-point equality checks apply. We expect you to
- // use this to check for complex values that can be expressed precisely as
- // float pairs e.g. (-0.5, 1.0).
- //
- // This literal must have a dense layout.
- bool IsAllComplex(complex64 value) const;
-
- // Literal consists entirely of the first element of the literal.
- bool IsAllFirst() const;
-
- // Returns whether this literal is zero at the specified index. This literal
- // must be an array with a dense layout.
- bool IsZero(tensorflow::gtl::ArraySlice<int64> indices) const;
-
- // Returns the count of the elements in the array at the given shape index in
- // this literal.
- int64 element_count(const ShapeIndex& index = {}) const {
- return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
- }
-
- // Returns the count of the elements in the sparse array at the given shape
- // index in this literal, which will be no larger than
- // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()).
- int64 sparse_element_count() const;
-
- // Compute a hash for this literal. This literal must not be a sparse tensor
- // or a tuple containing a sparse tensor.
- size_t Hash() const;
-
- // Converts this literal to the given shape. Returns an error is the
- // conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<std::unique_ptr<Literal>> ConvertToShape(
- const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
-
- // Converts this literal to another primitive type using a bitcast
- // conversion. The to and from primitive types must have the same bit
- // width. Returns an error if the conversion is not possible. This literal
- // must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> BitcastConvert(
- PrimitiveType primitive_dest_type) const;
-
- // Converts this literal to another primitive type. Returns an error if the
- // conversion is not possible. This literal must be array-shaped.
- StatusOr<std::unique_ptr<Literal>> Convert(
- PrimitiveType primitive_dest_type) const;
+ LiteralUtil() = delete;
// Returns a literal scalar representing the first element.
- Literal GetFirstScalarLiteral() const;
-
- // Clones the underlying buffers into a new Literal, or new
- // std::unique_ptr<Literal>.
- Literal Clone() const;
- std::unique_ptr<Literal> CloneToUnique() const;
-
- // TODO(b/67651157): The methods below which perform computation on Literals
- // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with
- // evaluator code which operates on Literals.
- //
- // Creates a new value that has the equivalent value as this
- // literal, but conforms to new_layout; e.g. a literal matrix that was in {0,
- // 1} minor-to-major dimension layout can be re-layed-out as {1, 0}
- // minor-to-major dimension layout and the value in the cell at any given
- // logical index (i0, i1) will be the same.
- //
- // For tuple shaped literals, shape_index should be used to select the inner
- // array that the new layout applies to.
- //
- // Note: this is useful when the client wants to ensure that a value placed in
- // the XLA allocation tracker has a particular layout; for efficiency
- // purposes or avoiding unimplemented operation/layout combinations.
- std::unique_ptr<Literal> Relayout(const Layout& new_layout,
- const ShapeIndex& shape_index = {}) const;
-
- // An overload of Relayout which changes the layout of the entire shape rather
- // than being limited to a single array within the shape.
- std::unique_ptr<Literal> Relayout(const Shape& shape_with_layout) const;
-
- // Creates a new literal by reshaping this literal to have the given
- // dimensions. The total number of elements must not change; The
- // implementation currently only supports monotonic dim0-major layouts.
- // This literal must be an array.
- StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by broadcasting this literal with `dimensions` to
- // yield a literal of shape `result_shape`.
- StatusOr<std::unique_ptr<Literal>> Broadcast(
- const Shape& result_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions) const;
-
- // Creates a new literal by reordering the dimensions of this literal.
- // The given `permutation` must be a permutation of the dimension numbers
- // in the original literal, and it specifies the order of the new dimensions
- // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
- // For example, a transpose call on a literal of shape [3 x 8 x 4] and
- // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
- // This literal must be an array.
- std::unique_ptr<Literal> Transpose(
- tensorflow::gtl::ArraySlice<int64> permutation) const;
-
- // Creates a sub-array from this literal by extracting the indices
- // [start_index, limit_index) of each dimension. The result literal has the
- // same rank and layout as for the given literal. The number of indices in
- // start_indices and limit_indices must be the rank of the literal, and the
- // indices follow the order of the dimensions.
- // This literal must be an array.
- std::unique_ptr<Literal> Slice(
- tensorflow::gtl::ArraySlice<int64> start_indices,
- tensorflow::gtl::ArraySlice<int64> limit_indices) const;
-
- // Creates a literal with a prepended dimension with bound "times"; e.g. a
- // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this
- // literal replicated four times.
- // This literal must be an array.
- template <typename NativeT>
- std::unique_ptr<Literal> Replicate(int64 times) const;
-
- // Creates a new Literal object with the shape specified as parameter.
- // The content of the literal values is the default value of the primitive
- // type of literal itself (0 for numeric types, and false for predicates).
- //
- // Note: It's an antipattern to use this method then immediately call
- // Literal::Populate on the result (since that results in zero initialization,
- // then reinitialization. Conside if a call to MakeUnique<Literal>(shape),
- // followed by the call to Literal::Populate can be used instead.
- static std::unique_ptr<Literal> CreateFromShape(const Shape& shape);
-
- protected:
- // A data structure representing a subshape at a particular ShapeIndex within
- // the literal. For array-shaped ShapeIndexes, this data structure holds the
- // pointer to the memory allocated for the array data.
- class Piece {
- public:
- // Returns the buffer holding the array data for this piece as an array
- // slice. This piece must be array-shaped.
- template <typename NativeT>
- tensorflow::gtl::ArraySlice<NativeT> data() const;
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data();
-
- // Returns the buffer holding the array data for this piece as a void*. This
- // piece must be array-shaped.
- void* untyped_data();
- const void* untyped_data() const;
-
- // Gets or sets an element in the array at the given index. The multi_index
- // is CHECKed against the dimension sizes of the array. This piece must be
- // array-shaped.
- template <typename NativeT>
- NativeT Get(tensorflow::gtl::ArraySlice<int64> index) const;
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> index, NativeT value);
-
- // Gets/sets the buffer holding the array data.
- char* buffer() const { return buffer_; }
- void set_buffer(char* buffer) { buffer_ = buffer; }
-
- // The array of multi-indices that provide the locations of non-zero
- // elements in a sparse array. Only used if
- // LayoutUtil::IsSparseArray(shape()) is true.
- SparseIndexArray* sparse_indices() const { return sparse_indices_; }
- void set_sparse_indices(SparseIndexArray* sparse_indices) {
- sparse_indices_ = sparse_indices;
- }
-
- // Gets or sets the subshape of this piece. This reference points to a
- // subshape within the shape in the containing Literal (Literal::shape_).
- const Shape& subshape() const { return *subshape_; }
- void set_subshape(const Shape* subshape) { subshape_ = subshape; }
-
- // Returns the size in bytes of the buffer holding the array data.
- int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); }
-
- // Returns the number of elements in this piece's array.
- int64 element_count() const {
- // If this is a sparse array, use the number of elements represented by
- // the indices in the associated SparseIndexArray.
- return LayoutUtil::IsSparseArray(subshape())
- ? sparse_indices()->index_count()
- : ShapeUtil::ElementsIn(subshape());
- }
-
- // Returns the child piece at 'index' of this piece.
- Piece& child(int64 index) { return children_[index]; }
-
- // Adds a child piece to this piece's children.
- void emplace_back(Piece child_piece) {
- children_.emplace_back(std::move(child_piece));
- }
-
- // Returns the size of children pieces of this piece.
- int64 children_size() { return children_.size(); }
-
- // Visitor functions that recursively traverses the piece and calls the
- // given function at each child piece. The function has the type:
- // void (const ShapeIndex& index, const Piece& piece)
- template <typename Fn>
- void ForEachSubpiece(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelper(
- [&func](const ShapeIndex& index, const Piece& piece) {
- func(index, piece);
- return Status::OK();
- },
- *this, &index)
- .IgnoreError();
- }
- // Same as above, but the function has the type:
- // Status (const ShapeIndex& index, const Piece& piece)
- // The first non-OK return value is returned by the function.
- template <typename Fn>
- Status ForEachSubpieceWithStatus(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelper(func, *this, &index);
- }
- // Same as above, but the function has the type:
- // Bool (const ShapeIndex& index, const Piece& piece)
- // The first non-true return value is returned by the function.
- template <typename Fn>
- bool ForEachSubpieceWithBool(const Fn& func) const {
- ShapeIndex index;
- return ForEachHelperBool(func, *this, &index);
- }
- // Same as above, but the function has the type:
- // Void (const ShapeIndex& index, Piece& piece)
- template <typename Fn>
- void ForEachMutableSubpiece(const Fn& func) {
- ShapeIndex index;
- return ForEachMutableHelper(
- [&func](const ShapeIndex& index, Piece* piece) {
- func(index, piece);
- return Status::OK();
- },
- const_cast<xla::LiteralBase::Piece*>(this), &index)
- .IgnoreError();
- }
- // Same as above, but the function has the type:
- // Status (const ShapeIndex& index, Piece& piece)
- // The first non-OK return value is returned by the function.
- template <typename Fn>
- Status ForEachMutableSubpieceWithStatus(const Fn& func) {
- ShapeIndex index;
- return ForEachMutableHelper(
- func, const_cast<xla::LiteralBase::Piece*>(this), &index);
- }
-
- // Returns true if this piece and 'other' contain the same data. This piece
- // and 'other' must be array-shaped and compatible.
- bool EqualElements(const Piece& other) const;
-
- // Writes the shape and data (if array-shaped) into the given proto.
- void WriteToProto(LiteralProto* proto) const;
-
- // Copy the data from 'src' into this piece's buffer. Shapes of this piece
- // and src must be compatible.
- Status CopyFrom(const Piece& src);
-
- // Copies the data from the given proto into this piece. The shape of this
- // piece must be equal (not just compatible) to the shape of the proto.
- Status CopyFromProto(const LiteralProto& proto);
-
- // Sorts the elements in a sparse array.
- void SortSparseElements();
-
- private:
- // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'.
- // The first non-OK (or non-true) value is returned by the function.
- // The callable 'func' has the same signature as described above in
- // ForEachSubpiece*.
- template <typename Fn>
- Status ForEachHelper(const Fn& func, const Piece& piece,
- ShapeIndex* index) const {
- TF_RETURN_IF_ERROR(func(*index, piece));
- for (int64 i = 0; i < piece.children_.size(); ++i) {
- index->push_back(i);
- TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index));
- index->pop_back();
- }
- return Status::OK();
- }
- template <typename Fn>
- bool ForEachHelperBool(const Fn& func, const Piece& piece,
- ShapeIndex* index) const {
- if (!func(*index, piece)) {
- return false;
- }
- for (int64 i = 0; i < piece.children_.size(); ++i) {
- index->push_back(i);
- if (!ForEachHelperBool(func, piece.children_[i], index)) {
- return false;
- }
- index->pop_back();
- }
- return true;
- }
- template <typename Fn>
- Status ForEachMutableHelper(const Fn& func, Piece* piece,
- ShapeIndex* index) {
- TF_RETURN_IF_ERROR(func(*index, piece));
- for (int64 i = 0; i < piece->children_.size(); ++i) {
- index->push_back(i);
- TF_RETURN_IF_ERROR(
- ForEachMutableHelper(func, &piece->children_[i], index));
- index->pop_back();
- }
- return Status::OK();
- }
-
- // Recursive helper for EqualElements.
- template <typename NativeT>
- bool EqualElementsInternal(const Piece& other,
- std::vector<int64>* multi_index) const;
-
- // Helper for SortSparseElements that has the element type as a template
- // parameter.
- template <typename NativeT>
- void SortSparseElementsInternal();
-
- // For array-shaped pieces, this is the buffer holding the literal data.
- char* buffer_ = nullptr;
-
- // For sparse arrays, this is the array of indices.
- SparseIndexArray* sparse_indices_ = nullptr;
-
- // The shape of piece. This points into the shape of the containing Literal
- // (Literal::shape_).
- const Shape* subshape_ = nullptr;
-
- // Children pieces for tuple shaped pieces.
- std::vector<Piece> children_ = {};
- }; // class Piece
-
- const Piece& piece(const ShapeIndex& shape_index) const {
- Piece* piece = &const_cast<Piece&>(root_piece());
- for (const auto i : shape_index) {
- DCHECK_GE(i, 0);
- DCHECK_LT(i, piece->children_size());
- piece = &piece->child(i);
- }
- return *piece;
- }
-
- // Returns the piece at the root of the shape.
- virtual const Piece& root_piece() const = 0;
-
- // LiteralSlice and Literal must access Pieces of other Literals.
- friend class Literal;
- friend class LiteralSlice;
- friend class BorrowingLiteral;
-};
-
-// Class representing literal values in XLA.
-//
-// The underlying buffer and shape is always owned by this class.
-class Literal : public LiteralBase {
- public:
- Literal() : Literal(ShapeUtil::MakeNil()) {}
-
- // Create a literal of the given shape. The literal is allocated sufficient
- // memory to hold the shape. Memory is uninitialized.
- explicit Literal(const Shape& shape);
- virtual ~Literal();
-
- // Literals are moveable, but not copyable. To copy a literal use
- // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies
- // of literals which can be expensive.
- Literal(const Literal& other) = delete;
- Literal& operator=(const Literal& other) = delete;
- Literal(Literal&& other);
- // 'allocate_arrays' indicates whether to allocate memory for the arrays in
- // the shape. If false, buffer pointers inside of the Literal::Pieces are set
- // to nullptr.
- Literal(const Shape& shape, bool allocate_arrays);
- Literal& operator=(Literal&& other);
-
- // TODO(b/67651157): Remove this accessor. Literal users should not be able to
- // mutate the shape as this can produce malformed Literals.
- Shape* mutable_shape_do_not_use() { return shape_.get(); }
-
- // Returns a MutableArraySlice view of the array for this literal for the
- // given NativeT (e.g., float). CHECKs if the subshape of the literal at the
- // given ShapeIndex is not array. See primitive_util.h for the mapping from
- // XLA type to native type.
- template <typename NativeT>
- tensorflow::gtl::MutableArraySlice<NativeT> data(
- const ShapeIndex& shape_index = {});
- // Unhide const method from parent class.
- using LiteralBase::data;
-
- // Returns a pointer to the sparse index array. Returns nullptr if the literal
- // is not a sparse array.
- SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {});
-
- // Returns a pointer to the underlying buffer holding the array at the given
- // shape index. CHECKs if the subshape of the literal at the given ShapeIndex
- // is not array.
- void* untyped_data(const ShapeIndex& shape_index = {});
- // Unhide const method from parent class.
- using LiteralBase::untyped_data;
-
- // Populates a literal with a sparse layout with the given indices and values.
- // Each index in the indices array is CHECKed against the dimensions in the
- // literal's shape. If sort is true, then the indices and values will be
- // sorted. If sort is false, then the indices and values are assumed to
- // already be in sorted order. See CreateSparse for an example of how data
- // are populated.
- template <typename NativeT>
- void PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort = true);
-
- // Copy values from 'src_literal' rooted at 'src_shape_index' into this
- // literal rooted at 'dest_shape_index'. The subshape of this literal rooted
- // at 'dest_shape_index' must be compatible with the subshape of 'src_literal'
- // rooted at 'src_shape_index', but need not be arrays.
- Status CopyFrom(const LiteralSlice& src_literal,
- const ShapeIndex& dest_shape_index = {},
- const ShapeIndex& src_shape_index = {});
-
- // Similar to CopyFrom, but with move semantincs. The subshape of this literal
- // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
- // (layouts and shapes must match), but need not be arrays. The memory
- // allocated in this literal for the subshape at dest_shape_index is
- // deallocated, and the respective buffers are replaced with those in
- // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
- Status MoveFrom(Literal&& src_literal,
- const ShapeIndex& dest_shape_index = {});
-
- // Copies the values from src_literal, starting at src_base shape indexes,
- // to this literal, starting at dest_base, where the copy size in each
- // dimension is specified by copy_size.
- // The src_literal and this literal must have the same primitive type,
- // src_base+copy_size must fit the source literal dimensions, as well as
- // dest_base+copy_size must fit the destination literal dimensions.
- // Note: if either src_literal or this literal contains dimensions with zero
- // element, then copy_size must be 0 in these dimensions while the
- // corresponding base indices being 0.
- // This literal and 'src_literal' must be arrays.
- Status CopySliceFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Copies one element from src_literal[src_index] to (*this)[dest_index].
- Status CopyElementFrom(const LiteralSlice& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_index,
- tensorflow::gtl::ArraySlice<int64> dest_index);
-
- // Sets an element in the literal at the given index. The multi_index is
- // CHECKed against the dimension sizes.
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value);
- // Overloads of Set for array literals. CHECKs if the literal is not
- // array-shaped and dense.
- template <typename NativeT>
- void Set(tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value);
-
- // Appends the given element to the literal. If the elements are not appended
- // in sorted order, then SortSparseElements should be called before calling
- // other methods. This literal must have a sparse layout.
- template <typename NativeT>
- void AppendSparseElement(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value, const ShapeIndex& shape_index = {});
-
- // Sorts the elements in a sparse array.
- void SortSparseElements(const ShapeIndex& shape_index = {});
-
- // As Set(), but truncates `value` to the literal element type before storing.
- // This literal must be an array.
- Status SetIntegralAsS64(tensorflow::gtl::ArraySlice<int64> multi_index,
- int64 value);
-
- // Populate this literal with the given values. Examples:
- //
- // // Populate with floats.
- // Array2D<float> float_values = ...
- // literal.PopulateR2FromArray2D(values);
- //
- // // Populate with int32s.
- // literal.PopulateR2<int32>({{1, 2}, {3, 4}});
- //
- // The shape and element type of this literal must match given values. For
- // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
- // array of S32.
- template <typename NativeT>
- void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values);
- void PopulateR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- void PopulateR2(std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- void PopulateFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- void PopulateR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- void PopulateR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- void PopulateR4FromArray4D(const Array4D<NativeT>& values);
-
- // Populates literal values by calling the generator function for every cell
- // in this literal object.
- //
- // generator must be a callable of the type
- // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
- //
- // This literal must have a dense layout.
- template <typename NativeT, typename FnType>
- Status Populate(const FnType& generator);
-
- // A parallel version of Populate(). This can be used if the generator is
- // thread-safe and the values for the shape's different elements are
- // independent.
- template <typename NativeT, typename FnType>
- Status PopulateParallel(const FnType& generator);
-
- // Fills this literal with the given value.
- template <typename NativeT>
- void PopulateWithValue(NativeT value);
-
- // Factory methods below.
- //
-
- // Serialize from a proto.
- static StatusOr<std::unique_ptr<Literal>> CreateFromProto(
- const LiteralProto& proto);
+ static Literal GetFirstScalarLiteral(const LiteralSlice& literal);
// Creates a new literal of a given rank. To minimize ambiguity (for users
// and the compiler) these CreateR[0-2] methods should explicitly specify the
@@ -883,7 +223,7 @@ class Literal : public LiteralBase {
// As above, but intended to be invoked with move semantics; i.e.
//
// std::vector<std::unique_ptr<Literal>> elements = ...;
- // auto result = Literal::MakeTupleOwned(std::move(elements));
+ // auto result = LiteralUtil::MakeTupleOwned(std::move(elements));
//
// This would have been declared as an overload, but there is ambiguity
// in invocation between the above signature and this one.
@@ -893,7 +233,7 @@ class Literal : public LiteralBase {
// This overload lets you pass a braced list of unique_ptr<Literal>s to
// MakeTupleOwned:
//
- // Literal::MakeTupleOwned(Literal::CreateR1(...), ...).
+ // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...).
//
// Simply relying on the MakeTupleOwned(std::vector<unique_ptr<Literal>>)
// overload doesn't work because std::initializer_list's elements are always
@@ -911,18 +251,8 @@ class Literal : public LiteralBase {
return MakeTupleOwned(std::move(v));
}
- // Returns a vector containing the tuple elements of this Literal as separate
- // Literals. This Literal must be tuple-shaped and can be a nested tuple. The
- // elements are moved into the new Literals; no data is copied. Upon return
- // this Literal is set to a nil shape (empty tuple)
- std::vector<Literal> DecomposeTuple();
-
- // This operation is the inverse of DecomposeTuple. The given elements are
- // moved into the tuple elements of a new tuple-shaped Literal which is
- // returned. Upon return, each of the Literals in 'elements' is set to a nil
- // shape (empty tuple).
- static Literal MoveIntoTuple(
- tensorflow::gtl::MutableArraySlice<Literal> elements);
+ // Create a constant token literal. Token types have no value.
+ static std::unique_ptr<Literal> CreateToken();
// Creates a new Literal object with its values havings the primitive_type
// type, and with dimensions defined by the dimensions parameter.
@@ -991,192 +321,12 @@ class Literal : public LiteralBase {
// dimension 1 equal to 8.
static string MultiIndexAsString(
tensorflow::gtl::ArraySlice<int64> multi_index);
-
- private:
- // Recursively sets the subshapes and buffers of all subpieces rooted at
- // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
- // the shape.
- void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays);
-
- // Returns the piece at the given ShapeIndex.
- Piece& piece(const ShapeIndex& shape_index) {
- return const_cast<Piece&>(LiteralBase::piece(shape_index));
- }
-
- Piece& root_piece() const override { return *root_piece_; };
-
- // Internal template helper for the Literal::CopySliceFrom(), matching its
- // arguments one by one.
- template <typename NativeT>
- Status CopySliceFromInternal(const LiteralBase& src_literal,
- tensorflow::gtl::ArraySlice<int64> src_base,
- tensorflow::gtl::ArraySlice<int64> dest_base,
- tensorflow::gtl::ArraySlice<int64> copy_size);
-
- // Utility structure which is used to create the optimal configuration for
- // a ShapeUtil::ForEachIndex() scan across two literals.
- struct StrideConfig {
- StrideConfig(const Shape& source_shape, const Shape& dest_shape,
- tensorflow::gtl::ArraySlice<int64> dimensions);
-
- // The dimensions of the stride operation. Essentially every dimension
- // will be iterated from base[i] to base[i]+dimensions[i], in step[i]
- // steps.
- tensorflow::gtl::ArraySlice<int64> dimensions;
- DimensionVector base;
- DimensionVector step;
- int64 minor_dimension = 0;
- // The size of the strides for source and destination. One of the two
- // (the one looping through its most minor dimension) will be 1, while
- // the other will be the stride size at the dimension matching the other
- // shape most minor dimension being scanned.
- int64 dest_stride = 1;
- int64 source_stride = 1;
- // The size of the inner loop on the most minor dimension.
- int64 minor_loop_size = 1;
- };
-
- // Literal class always owns the shape. The parent class borrows this shape.
- std::unique_ptr<Shape> shape_;
-
- Piece* root_piece_ = nullptr;
-
- // Implementation details shared between Populate() and PopulateParallel()
- template <typename NativeT, typename FnType>
- Status PopulateInternal(const FnType& generator, bool parallel);
-
- // Deallocate the buffers held by this literal.
- void DeallocateBuffers();
-
- friend class LiteralBase;
};
-std::ostream& operator<<(std::ostream& out, const Literal& literal);
-
-// A read-only view of a Literal. A LiteralSlice contains pointers to shape and
-// literal buffers always owned by others.
-class LiteralSlice : public LiteralBase {
- public:
- LiteralSlice() : LiteralBase() {}
-
- // Implicit conversion constructors.
- LiteralSlice(const LiteralBase& literal);
- LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root);
-
- private:
- const Piece& root_piece() const override { return *root_piece_; };
- const Piece* root_piece_; // Not owned.
-};
-
-// A read-only Literal where the underlying buffers are never owned by this
-// class.
-class BorrowingLiteral : public LiteralBase {
- public:
- BorrowingLiteral() : LiteralBase() {}
-
- // 'src_buf_ptr' is not owned by this class and must outlive the
- // lifetime of this class. It points to an appropirately sized buffer with
- // data interpretered as indicated by 'shape'.
- // This constructor is only used for array shapes.
- BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
- // Similar as above, except to be used for constructing non-nested tuples.
- BorrowingLiteral(tensorflow::gtl::ArraySlice<const char*> src_buf_ptrs,
- const Shape& shape);
- // TODO(b/79707221): adding constructors for nested tuples as well.
-
- private:
- // Recursively builds the subtree for the given piece and sets the subshapes
- // of the given piece with the given shape.
- void BuildPieceSubtree(const Shape& shape, Piece* piece);
-
- // Accessor for the root piece of this literal.
- const Piece& root_piece() const override { return root_piece_; };
- Piece root_piece_;
-
- // Shape of this literal.
- const Shape shape_;
-};
-
-template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
- << "Attempting to access "
- << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
- << " type, but literal element type is "
- << PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::ArraySlice<NativeT>(
- reinterpret_cast<const NativeT*>(buffer()), element_count());
-}
-
-template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
- << "Attempting to access "
- << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
- << " type, but literal element type is "
- << PrimitiveType_Name(subshape().element_type());
- return tensorflow::gtl::MutableArraySlice<NativeT>(
- reinterpret_cast<NativeT*>(buffer()), element_count());
-}
-
-template <typename NativeT>
-NativeT LiteralBase::Piece::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- CHECK(LayoutUtil::IsDenseArray(subshape()));
- return data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
- subshape(), multi_index)];
-}
-
-template <typename NativeT>
-void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
- CHECK(LayoutUtil::IsDenseArray(subshape()));
- data<NativeT>()[IndexUtil::MultidimensionalIndexToLinearIndex(
- subshape(), multi_index)] = value;
-}
-
-template <typename NativeT>
-tensorflow::gtl::ArraySlice<NativeT> LiteralBase::data(
- const ShapeIndex& shape_index) const {
- return piece(shape_index).data<NativeT>();
-}
-
-template <typename NativeT>
-tensorflow::gtl::MutableArraySlice<NativeT> Literal::data(
- const ShapeIndex& shape_index) {
- return piece(shape_index).data<NativeT>();
-}
-
-template <typename NativeT>
-inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index) const {
- return piece(shape_index).Get<NativeT>(multi_index);
-}
-
-template <typename NativeT>
-inline NativeT LiteralBase::Get(
- tensorflow::gtl::ArraySlice<int64> multi_index) const {
- return root_piece().Get<NativeT>(multi_index);
-}
-
-template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- const ShapeIndex& shape_index, NativeT value) {
- return piece(shape_index).Set<NativeT>(multi_index, value);
-}
-
-template <typename NativeT>
-inline void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
- NativeT value) {
- return root_piece().Set<NativeT>(multi_index, value);
-}
+std::ostream& operator<<(std::ostream& out, const Literal& literal);
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR0(NativeT value) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<NativeT>(), {}));
literal->Set({}, value);
@@ -1184,7 +334,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR1(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
tensorflow::gtl::ArraySlice<NativeT> values) {
auto literal = MakeUnique<Literal>(
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
@@ -1194,7 +344,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
@@ -1207,13 +357,13 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
std::initializer_list<std::initializer_list<NativeT>> values) {
return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values,
const Layout& layout) {
@@ -1238,14 +388,14 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
values) {
return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4WithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values,
@@ -1276,7 +426,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateSparse(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateSparse(
tensorflow::gtl::ArraySlice<int64> dimensions, SparseIndexArray indices,
tensorflow::gtl::ArraySlice<NativeT> values, bool sort) {
int64 num_elements = values.size();
@@ -1291,7 +441,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
values) {
@@ -1299,7 +449,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
@@ -1309,38 +459,40 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR3Projected(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection) {
int64 dim0_size = projection;
@@ -1365,7 +517,7 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4Projected(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
std::initializer_list<std::initializer_list<NativeT>> values,
int64 projection_p, int64 projection_z) {
int64 dim0_size = projection_p;
@@ -1393,49 +545,21 @@ template <typename NativeT>
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateFromArray(values);
}
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout) {
return CreateFromArrayWithLayout(values, layout);
}
-template <typename NativeT>
-NativeT LiteralBase::GetFirstElement() const {
- return data<NativeT>().at(0);
-}
-
-template <typename NativeT>
-NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
- const ShapeIndex& shape_index) const {
- CHECK(
- LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index)));
- return data<NativeT>(shape_index)[sparse_element_number];
-}
-
-template <typename NativeT>
-void Literal::AppendSparseElement(
- tensorflow::gtl::ArraySlice<int64> multi_index, NativeT value,
- const ShapeIndex& shape_index) {
- Piece& p = piece(shape_index);
- const Shape& subshape = p.subshape();
- CHECK(LayoutUtil::IsSparseArray(subshape));
- int64 rank = ShapeUtil::Rank(subshape);
- CHECK_EQ(multi_index.size(), rank);
- int64 last_element = p.sparse_indices()->index_count();
- CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
- p.sparse_indices()->Append(multi_index);
- CHECK_LT(last_element, p.data<NativeT>().size());
- p.data<NativeT>()[last_element] = value;
-}
-
// Returns an identity matrix (rank 2) with the given row and column count.
template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::MakeIdentityR2(int64 size) {
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
Array2D<NativeT> array(size, size, 0);
for (int64 i = 0; i < size; ++i) {
array(i, i) = 1;
@@ -1444,174 +568,8 @@ template <typename NativeT>
}
template <typename NativeT>
-void LiteralBase::EachCell(
- std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value)>
- per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
- return;
- }
- std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
- do {
- per_cell(indices, Get<NativeT>(indices));
- } while (IndexUtil::BumpIndices(shape(), &indices));
-}
-
-template <typename NativeT>
-inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 1);
- CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size());
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- for (int64 i = 0; i < values.size(); ++i) {
- Set({i}, values[i]);
- }
-}
-
-template <typename NativeT>
-void Literal::PopulateR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(ShapeUtil::Rank(shape()), 2);
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
-
- const int64 dim0_size = values.size();
- const int64 dim1_size = values.begin()->size();
- CHECK_EQ(dim0_size, shape().dimensions(0));
- CHECK_EQ(dim1_size, shape().dimensions(1));
-
- int64 dim0 = 0;
- for (auto inner_list : values) {
- int64 dim1 = 0;
- for (auto value : inner_list) {
- Set({dim0, dim1}, value);
- ++dim1;
- }
- CHECK_EQ(dim1_size, dim1);
- ++dim0;
- }
-}
-
-template <typename NativeT>
-void Literal::PopulateFromArray(const Array<NativeT>& values) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions());
- for (int dim = 0; dim < values.num_dimensions(); ++dim) {
- CHECK_EQ(values.dim(dim), shape().dimensions(dim));
- }
- values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
- NativeT value) { this->Set(indices, value); });
-}
-
-template <typename NativeT>
-void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
- PopulateFromArray(values);
-}
-
-template <typename NativeT>
-void Literal::PopulateSparse(SparseIndexArray indices,
- tensorflow::gtl::ArraySlice<NativeT> values,
- bool sort) {
- CHECK(LayoutUtil::IsSparseArray(shape()));
- int rank = ShapeUtil::Rank(shape());
- CHECK_EQ(indices.rank(), rank);
- int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout());
- CHECK_LE(indices.max_indices(), max_elements);
- int64 num_elements = values.size();
- CHECK_LE(num_elements, max_elements);
- CHECK_EQ(num_elements, indices.index_count());
- auto root_data = root_piece().data<NativeT>();
- // Piece::data() returns an ArraySlice of size equal to the number of indices
- // in the SparseIndexArray. So there is no need to adjust the size of the data
- // here. It is enough to just copy the incoming values into the data buffer.
- std::copy(values.begin(), values.end(), root_data.begin());
- *this->root_piece().sparse_indices() = std::move(indices);
- if (sort) {
- auto root_data = this->root_piece().data<NativeT>();
- this->root_piece().sparse_indices()->SortWithValues(root_data);
- }
- DCHECK(this->root_piece().sparse_indices()->Validate(shape()));
-}
-
-template <typename NativeT, typename FnType>
-Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
- const Shape& this_shape = shape();
- const int64 rank = ShapeUtil::Rank(this_shape);
- TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
- TF_RET_CHECK(this_shape.element_type() ==
- primitive_util::NativeToPrimitiveType<NativeT>());
- tensorflow::gtl::MutableArraySlice<NativeT> literal_data = data<NativeT>();
- if (rank > 0) {
- StrideConfig stride_config(this_shape, this_shape,
- AsInt64Slice(this_shape.dimensions()));
- int64 minor_dimension_size =
- ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
-
- auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- DimensionVector minor_scan_indexes(rank, 0);
- const int64 index =
- IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
- std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
- for (int64 i = 0; i < minor_dimension_size; ++i) {
- minor_scan_indexes[stride_config.minor_dimension] = i;
- literal_data.at(index + i) = generator(minor_scan_indexes);
- }
- };
- if (parallel) {
- ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
- stride_config.dimensions,
- stride_config.step, init_function);
- } else {
- ShapeUtil::ForEachIndex(
- this_shape, stride_config.base, stride_config.dimensions,
- stride_config.step,
- [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
- init_function(indexes);
- return true;
- });
- }
- } else {
- // For scalars.
- literal_data.at(0) = generator({});
- }
- return Status::OK();
-}
-template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
- return PopulateInternal<NativeT>(generator, /*parallel=*/false);
-}
-
-template <typename NativeT, typename FnType>
-Status Literal::PopulateParallel(const FnType& generator) {
- return PopulateInternal<NativeT>(generator, /*parallel=*/true);
-}
-
-template <typename NativeT>
-void Literal::PopulateWithValue(NativeT value) {
- CHECK(ShapeUtil::IsArray(shape()));
- CHECK_EQ(shape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>());
- for (NativeT& element : data<NativeT>()) {
- element = value;
- }
-}
-
-template <typename NativeT>
-/* static */ std::unique_ptr<Literal> Literal::CreateFullWithDescendingLayout(
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateFullWithDescendingLayout(
tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
auto literal = MakeUnique<Literal>(ShapeUtil::MakeShapeWithDescendingLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), dimensions));
@@ -1619,44 +577,9 @@ template <typename NativeT>
return literal;
}
-template <typename NativeT>
-std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
- DimensionVector bounds = {times};
- bounds.reserve(shape().dimensions_size() + 1);
- for (int64 bound : shape().dimensions()) {
- bounds.push_back(bound);
- }
- auto literal =
- MakeUnique<Literal>(ShapeUtil::MakeShape(shape().element_type(), bounds));
- int64 elements = ShapeUtil::ElementsIn(literal->shape());
- if (elements == 0) {
- return literal;
- }
-
- DimensionVector output_indices(bounds.size(), 0);
- tensorflow::gtl::ArraySlice<int64> input_indices = output_indices;
- input_indices.remove_prefix(1);
-
- bool done = false;
- while (!done) {
- const auto element = Get<NativeT>(input_indices);
- literal->Set<NativeT>(output_indices, element);
-
- done = true;
- for (int n = 0; n < output_indices.size(); ++n) {
- ++output_indices[n];
- if (output_indices[n] < bounds[n]) {
- done = false;
- break;
- }
- output_indices[n] = 0;
- }
- }
- return literal;
-}
-
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(
const Shape& shape,
const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
@@ -1670,8 +593,9 @@ template <PrimitiveType type, typename T>
}
template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev) {
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
+ T stddev) {
using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
std::normal_distribution<NativeT> generator(mean, stddev);
return CreateRandomLiteral<type, NativeT>(
@@ -1681,8 +605,8 @@ template <PrimitiveType type, typename E, typename T>
}
template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
- const Shape& shape, T mean, T stddev) {
+/* static */ StatusOr<std::unique_ptr<Literal>>
+LiteralUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
std::minstd_rand0 engine;
return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
}
diff --git a/tensorflow/compiler/xla/overflow_util.h b/tensorflow/compiler/xla/overflow_util.h
new file mode 100644
index 0000000000..8657d3a4bf
--- /dev/null
+++ b/tensorflow/compiler/xla/overflow_util.h
@@ -0,0 +1,50 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Multiply two nonnegative int64's, returning negative for overflow
+inline int64 MultiplyWithoutOverflow(const int64 x, const int64 y) {
+ // Multiply in uint64 rather than int64 since signed overflow is undefined.
+ // Negative values will wrap around to large unsigned values in the casts
+ // (see section 4.7 [conv.integral] of the C++14 standard).
+ const uint64 ux = x;
+ const uint64 uy = y;
+ const uint64 uxy = ux * uy;
+
+ // Check if we overflow uint64, using a cheap check if both inputs are small
+ if (TF_PREDICT_FALSE((ux | uy) >> 32 != 0)) {
+ // Ensure nonnegativity. Note that negative numbers will appear "large"
+ // to the unsigned comparisons above.
+ CHECK(x >= 0 && y >= 0);
+
+ // Otherwise, detect overflow using a division
+ if (ux != 0 && uxy / ux != uy) return -1;
+ }
+
+ // Cast back to signed. Any negative value will signal an error.
+ return static_cast<int64>(uxy);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_OVERFLOW_UTIL_H_
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
index 857aae0a79..6b7fd10d63 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.cc
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
index 45a9fe0127..98dccaa9a2 100644
--- a/tensorflow/compiler/xla/packed_literal_reader.h
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 143c9a2366..b16147e3be 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
}
}
+bool IsArrayType(PrimitiveType primitive_type) {
+ return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
+ primitive_type != OPAQUE && primitive_type != TOKEN;
+}
+
} // namespace primitive_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index b26a10ade6..889e9a1cec 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type);
bool IsIntegralType(PrimitiveType type);
+// Returns true if values of the given primitive type are held in array shapes.
+bool IsArrayType(PrimitiveType primitive_type);
+
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 83834c1ff6..fe346f9956 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -33,6 +33,7 @@ cc_library(
srcs = ["numpy_bridge.cc"],
hdrs = ["numpy_bridge.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -52,9 +53,9 @@ cc_library(
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
- "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
@@ -70,7 +71,7 @@ tf_py_wrap_cc(
deps = [
":local_computation_builder",
":numpy_bridge",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:cpu_plugin",
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index f808990cad..be55d50b23 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -14,13 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/local_computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/thread_annotations.h"
namespace xla {
-
namespace swig {
// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
@@ -97,6 +98,36 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
return &shaped_buffer_;
}
+ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); }
+
+LocalShapedBufferTuple::LocalShapedBufferTuple(
+ std::vector<LocalShapedBuffer*> elements)
+ : elements_(std::move(elements)) {
+ for (auto* element : elements_) {
+ DCHECK(element != nullptr);
+ }
+}
+
+LocalShapedBufferTuple::~LocalShapedBufferTuple() {
+ for (LocalShapedBuffer* element : elements_) {
+ if (element != nullptr) {
+ delete element;
+ }
+ }
+}
+
+StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
+ LocalShapedBuffer* element = elements_[i];
+ if (element == nullptr) {
+ return InvalidArgument("Attempted to release already-released element %d.",
+ i);
+ }
+ elements_[i] = nullptr;
+ return element;
+}
+
+int LocalShapedBufferTuple::size() const { return elements_.size(); }
+
static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
int device_ordinal,
const Literal& arg) {
@@ -145,73 +176,73 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
GetReplicaCount());
for (int replica = 0; replica < GetReplicaCount(); ++replica) {
- pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
- &results] {
- StatusOr<int> device_ordinal_status =
- client->ReplicaNumberToDeviceOrdinal(replica);
- if (!device_ordinal_status.ok()) {
- results[replica] = device_ordinal_status.status();
- return;
- }
- const int device_ordinal = device_ordinal_status.ValueOrDie();
- VLOG(3) << "Replica " << replica
- << " mapped to device ordinal for execution: "
- << device_ordinal;
-
- // Transfer arguments in
- std::vector<ScopedShapedBuffer> scoped_buffers;
- scoped_buffers.reserve(arguments.size());
- for (int i = 0; i < arguments.size(); ++i) {
- const Literal& argument = arguments[i];
- const tensorflow::gtl::optional<Shape>& shape_with_layout =
- shapes_with_layout[i];
-
- StatusOr<ScopedShapedBuffer> pushed;
- if (shape_with_layout) {
- std::unique_ptr<Literal> relaid =
- argument.Relayout(shape_with_layout.value());
- pushed = ToBuffer(client, device_ordinal, *relaid);
- } else {
- pushed = ToBuffer(client, device_ordinal, argument);
- }
- if (!pushed.ok()) {
- results[replica] = pushed.status();
- return;
- }
-
- scoped_buffers.push_back(std::move(pushed).ValueOrDie());
- }
-
- // Execute
- std::vector<const ShapedBuffer*> argument_buffers;
- argument_buffers.reserve(scoped_buffers.size());
- for (auto& buffer : scoped_buffers) {
- argument_buffers.push_back(&buffer);
- }
-
- DeviceAssignment device_assignment =
- client->backend()
- .computation_placer()
- ->AssignDevices(GetReplicaCount(), /*computation_count=*/1)
- .ConsumeValueOrDie();
-
- ExecutableRunOptions options;
- options.set_device_ordinal(device_ordinal);
- options.set_allocator(client->backend().memory_allocator());
- options.set_intra_op_thread_pool(
- client->backend().eigen_intra_op_thread_pool_device());
- options.set_device_assignment(&device_assignment);
- StatusOr<ScopedShapedBuffer> result_buffer_status =
- executable_->Run(argument_buffers, options);
- if (!result_buffer_status.ok()) {
- results[replica] = result_buffer_status.status();
- return;
- }
-
- // Transfer result out
- results[replica] = client->ShapedBufferToLiteral(
- std::move(result_buffer_status).ValueOrDie());
- });
+ pool.Schedule(
+ [this, client, replica, &arguments, &shapes_with_layout, &results] {
+ StatusOr<int> device_ordinal_status =
+ client->ReplicaNumberToDeviceOrdinal(replica);
+ if (!device_ordinal_status.ok()) {
+ results[replica] = device_ordinal_status.status();
+ return;
+ }
+ const int device_ordinal = device_ordinal_status.ValueOrDie();
+ VLOG(3) << "Replica " << replica
+ << " mapped to device ordinal for execution: "
+ << device_ordinal;
+
+ // Transfer arguments in
+ std::vector<ScopedShapedBuffer> scoped_buffers;
+ scoped_buffers.reserve(arguments.size());
+ for (int i = 0; i < arguments.size(); ++i) {
+ const Literal& argument = arguments[i];
+ const tensorflow::gtl::optional<Shape>& shape_with_layout =
+ shapes_with_layout[i];
+
+ StatusOr<ScopedShapedBuffer> pushed;
+ if (shape_with_layout) {
+ std::unique_ptr<Literal> relaid =
+ argument.Relayout(shape_with_layout.value());
+ pushed = ToBuffer(client, device_ordinal, *relaid);
+ } else {
+ pushed = ToBuffer(client, device_ordinal, argument);
+ }
+ if (!pushed.ok()) {
+ results[replica] = pushed.status();
+ return;
+ }
+
+ scoped_buffers.push_back(std::move(pushed).ValueOrDie());
+ }
+
+ // Execute
+ std::vector<const ShapedBuffer*> argument_buffers;
+ argument_buffers.reserve(scoped_buffers.size());
+ for (auto& buffer : scoped_buffers) {
+ argument_buffers.push_back(&buffer);
+ }
+
+ DeviceAssignment device_assignment =
+ client->backend()
+ .computation_placer()
+ ->AssignDevices(GetReplicaCount(), /*computation_count=*/1)
+ .ConsumeValueOrDie();
+
+ ExecutableRunOptions options;
+ options.set_device_ordinal(device_ordinal);
+ options.set_allocator(client->backend().memory_allocator());
+ options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
+ StatusOr<ScopedShapedBuffer> result_buffer_status =
+ executable_->Run(argument_buffers, options);
+ if (!result_buffer_status.ok()) {
+ results[replica] = result_buffer_status.status();
+ return;
+ }
+
+ // Transfer result out
+ results[replica] = client->ShapedBufferToLiteral(
+ std::move(result_buffer_status).ValueOrDie());
+ });
}
}
@@ -312,14 +343,11 @@ StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
LocalOp LocalComputationBuilder::Parameter(int64 parameter_number,
const Shape& shape,
const string& name) {
- return builder_.Parameter(parameter_number, shape, name);
+ return xla::Parameter(&builder_, parameter_number, shape, name);
}
-std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
- const LocalOp& operand) {
- auto result = MakeUnique<Shape>();
- *result = builder_.GetShape(operand.op()).ValueOrDie();
- return result;
+StatusOr<Shape> LocalComputationBuilder::GetShape(const LocalOp& operand) {
+ return builder_.GetShape(operand.op());
}
StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
@@ -328,72 +356,70 @@ StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
}
LocalOp LocalComputationBuilder::Infeed(const Shape& shape) {
- return builder_.Infeed(shape);
+ return xla::Infeed(&builder_, shape);
}
void LocalComputationBuilder::Outfeed(const LocalOp& operand,
const Shape& shape,
const string& outfeed_config) {
- builder_.Outfeed(operand.op(), shape, outfeed_config);
+ xla::Outfeed(operand.op(), shape, outfeed_config);
}
LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
- return builder_.ConstantLiteral(literal);
+ return xla::ConstantLiteral(&builder_, literal);
}
LocalOp LocalComputationBuilder::Broadcast(
const LocalOp& operand,
tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- return builder_.Broadcast(operand.op(), broadcast_sizes);
+ return xla::Broadcast(operand.op(), broadcast_sizes);
}
LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
const LocalOp& padding_value,
const PaddingConfig& padding_config) {
- return builder_.Pad(operand.op(), padding_value.op(), padding_config);
+ return xla::Pad(operand.op(), padding_value.op(), padding_config);
}
LocalOp LocalComputationBuilder::Reshape(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- return builder_.Reshape(operand.op(), dimensions, new_sizes);
+ return xla::Reshape(operand.op(), dimensions, new_sizes);
}
LocalOp LocalComputationBuilder::Collapse(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- return builder_.Collapse(operand.op(), dimensions);
+ return xla::Collapse(operand.op(), dimensions);
}
LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
- return builder_.CrossReplicaSum(operand.op());
+ return xla::CrossReplicaSum(operand.op());
}
LocalOp LocalComputationBuilder::Slice(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return builder_.Slice(operand.op(), start_indices, limit_indices, strides);
+ return xla::Slice(operand.op(), start_indices, limit_indices, strides);
}
LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
int64 start_index,
int64 limit_index, int64 stride,
int64 dimno) {
- return builder_.SliceInDim(operand.op(), start_index, limit_index, stride,
- dimno);
+ return xla::SliceInDim(operand.op(), start_index, limit_index, stride, dimno);
}
LocalOp LocalComputationBuilder::DynamicSlice(
const LocalOp& operand, const LocalOp& start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
+ return xla::DynamicSlice(operand.op(), start_indices.op(), slice_sizes);
}
LocalOp LocalComputationBuilder::DynamicUpdateSlice(
const LocalOp& operand, const LocalOp& update,
const LocalOp& start_indices) {
- return builder_.DynamicUpdateSlice(operand.op(), update.op(),
- start_indices.op());
+ return xla::DynamicUpdateSlice(operand.op(), update.op(), start_indices.op());
}
LocalOp LocalComputationBuilder::ConcatInDim(
@@ -403,7 +429,7 @@ LocalOp LocalComputationBuilder::ConcatInDim(
for (const auto& op : operands) {
xla_ops.push_back(op.op());
}
- return builder_.ConcatInDim(xla_ops, dimension);
+ return xla::ConcatInDim(&builder_, xla_ops, dimension);
}
LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
@@ -413,7 +439,7 @@ LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
const LocalOp& source, const LocalOp& init_value,
const LocalComputation& scatter) {
- return builder_.SelectAndScatterWithGeneralPadding(
+ return xla::SelectAndScatterWithGeneralPadding(
operand.op(), select.computation(), window_dimensions, window_strides,
padding, source.op(), init_value.op(), scatter.computation());
}
@@ -426,22 +452,22 @@ LocalOp LocalComputationBuilder::Tuple(
xla_ops.push_back(op.op());
}
- return builder_.Tuple(xla_ops);
+ return xla::Tuple(&builder_, xla_ops);
}
LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
int64 index) {
- return builder_.GetTupleElement(tuple_data.op(), index);
+ return xla::GetTupleElement(tuple_data.op(), index);
}
LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
- return builder_.Dot(lhs.op(), rhs.op());
+ return xla::Dot(lhs.op(), rhs.op());
}
LocalOp LocalComputationBuilder::DotGeneral(
const LocalOp& lhs, const LocalOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
+ return xla::DotGeneral(lhs.op(), rhs.op(), dimension_numbers);
}
LocalOp LocalComputationBuilder::ConvGeneralDilated(
@@ -451,14 +477,13 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers) {
- return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides,
- padding, lhs_dilation, rhs_dilation,
- dimension_numbers);
+ return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
+ lhs_dilation, rhs_dilation, dimension_numbers);
}
LocalOp LocalComputationBuilder::ConvertElementType(
const LocalOp& operand, PrimitiveType new_element_type) {
- return builder_.ConvertElementType(operand.op(), new_element_type);
+ return xla::ConvertElementType(operand.op(), new_element_type);
}
LocalOp LocalComputationBuilder::Call(
@@ -469,46 +494,39 @@ LocalOp LocalComputationBuilder::Call(
for (const auto& op : operands) {
xla_ops.push_back(op.op());
}
- return builder_.Call(local_computation.computation(), xla_ops);
+ return xla::Call(&builder_, local_computation.computation(), xla_ops);
}
LocalOp LocalComputationBuilder::Transpose(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
- return builder_.Transpose(operand.op(), permutation);
+ return xla::Transpose(operand.op(), permutation);
}
LocalOp LocalComputationBuilder::Rev(
const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- return builder_.Rev(operand.op(), dimensions);
+ return xla::Rev(operand.op(), dimensions);
}
LocalOp LocalComputationBuilder::Map(
tensorflow::gtl::ArraySlice<LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<LocalOp> static_operands) {
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
std::vector<XlaOp> xla_ops;
xla_ops.reserve(operands.size());
for (const auto& op : operands) {
xla_ops.push_back(op.op());
}
- std::vector<XlaOp> static_xla_ops;
- static_xla_ops.reserve(static_operands.size());
- for (const auto& op : static_operands) {
- static_xla_ops.push_back(op.op());
- }
-
- return builder_.Map(xla_ops, local_computation.computation(), dimensions,
- static_xla_ops);
+ return xla::Map(&builder_, xla_ops, local_computation.computation(),
+ dimensions);
}
LocalOp LocalComputationBuilder::Reduce(
const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return builder_.Reduce(operand.op(), init_value.op(),
- local_computation.computation(), dimensions_to_reduce);
+ return xla::Reduce(operand.op(), init_value.op(),
+ local_computation.computation(), dimensions_to_reduce);
}
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
@@ -517,7 +535,7 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
- return builder_.ReduceWindowWithGeneralPadding(
+ return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
window_dimensions, window_strides, padding);
}
@@ -525,27 +543,27 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
const LocalOp& sigma,
const Shape& shape) {
- return builder_.RngNormal(mu.op(), sigma.op(), shape);
+ return xla::RngNormal(mu.op(), sigma.op(), shape);
}
LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
const Shape& shape) {
- return builder_.RngUniform(a.op(), b.op(), shape);
+ return xla::RngUniform(a.op(), b.op(), shape);
}
LocalOp LocalComputationBuilder::While(const LocalComputation& condition,
const LocalComputation& body,
const LocalOp& init) {
- return builder_.While(condition.computation(), body.computation(), init.op());
+ return xla::While(condition.computation(), body.computation(), init.op());
}
LocalOp LocalComputationBuilder::Conditional(
const LocalOp& predicate, const LocalOp& true_operand,
const LocalComputation& true_computation, const LocalOp& false_operand,
const LocalComputation& false_computation) {
- return builder_.Conditional(
- predicate.op(), true_operand.op(), true_computation.computation(),
- false_operand.op(), false_computation.computation());
+ return xla::Conditional(predicate.op(), true_operand.op(),
+ true_computation.computation(), false_operand.op(),
+ false_computation.computation());
}
StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
@@ -561,7 +579,7 @@ StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
#define _FORWARD(method_name, return_sig, args_sig, args) \
return_sig LocalComputationBuilder::method_name args_sig { \
- return builder_.method_name args; \
+ return xla::method_name args; \
}
#define _FORWARD_UNOP(method_name) \
@@ -595,22 +613,25 @@ _FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
+_FORWARD_BINOP(Xor)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
+_FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
+_FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
-_FORWARD_UNOP(SqrtF32)
-_FORWARD_UNOP(SquareF32)
+_FORWARD_UNOP(Sqrt)
+_FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
-_FORWARD_UNOP(ReciprocalF32)
+_FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
@@ -631,6 +652,54 @@ void DeleteLocalComputation(LocalComputation* computation) {
delete computation;
}
-} // namespace swig
+StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
+ LocalShapedBuffer* local_shaped_buffer) {
+ if (!ShapeUtil::IsTuple(
+ local_shaped_buffer->shaped_buffer()->on_device_shape())) {
+ return InvalidArgument(
+ "Attemped to destructure a LocalShapedBuffer that did not have a tuple "
+ "shape; shape: %s",
+ ShapeUtil::HumanString(
+ local_shaped_buffer->shaped_buffer()->on_device_shape())
+ .c_str());
+ }
+ DeviceMemoryAllocator* allocator =
+ local_shaped_buffer->shaped_buffer()->memory_allocator();
+ ShapedBuffer tuple_buffer = local_shaped_buffer->Release();
+
+ // Extract some metadata we use to construct scoped buffers.
+ const se::Platform* platform = tuple_buffer.platform();
+ int device_ordinal = tuple_buffer.device_ordinal();
+
+ ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
+ const Shape& tuple_shape = tuple_buffer.on_device_shape();
+ std::vector<LocalShapedBuffer*> results;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) {
+ // Create a shaped buffer for this destructured tuple element.
+ const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
+ VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
+ ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);
+
+ ShapeUtil::ForEachSubshape(
+ subshape, [&](const Shape& s, const ShapeIndex& index) {
+ ShapeIndex original(index);
+ original.push_front(i);
+ se::DeviceMemoryBase* device_memory =
+ shape_tree.mutable_element(original);
+ shaped_buffer.set_buffer(*device_memory, index);
+ *device_memory = se::DeviceMemoryBase();
+ });
+
+ VLOG(3) << "Completed tuple element: " << i;
+ results.push_back(new LocalShapedBuffer(
+ ScopedShapedBuffer(std::move(shaped_buffer), allocator)));
+ }
+ // Deallocate the root buffer.
+ se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
+ TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
+ return new LocalShapedBufferTuple(std::move(results));
+}
+
+} // namespace swig
} // namespace xla
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 9ac13b6523..690ff277e8 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
-
namespace swig {
// Initializes the number of replicas that XLA will be initialized with (when
@@ -69,10 +68,42 @@ class LocalShapedBuffer {
StatusOr<std::unique_ptr<Literal> > ToLiteral() const;
+ // Transfers ownership of the encapsulated ShapedBuffer to the caller,
+ // analogous to std::unique_ptr::release().
+ ShapedBuffer Release();
+
private:
ScopedShapedBuffer shaped_buffer_;
};
+// Result of a tuple destructuring operation on a LocalShapedBuffer -- this
+// appears to be a simpler mechanism for the time being than an alternative like
+// using SWIG to transform std::vectors into Python lists of SWIG objects
+// directly.
+class LocalShapedBufferTuple {
+ public:
+ // Note: any LocalShapedBuffer elements that are not Release()'d will be
+ // deallocated in the destructor.
+ explicit LocalShapedBufferTuple(std::vector<LocalShapedBuffer*> elements);
+
+ ~LocalShapedBufferTuple();
+
+ // Releases the ith element to the caller. Further attempts to release the ith
+ // element will return an invalid argument error.
+ StatusOr<LocalShapedBuffer*> Release(int i);
+
+ // Returns the number of elements in the destructured tuple.
+ int size() const;
+
+ private:
+ std::vector<LocalShapedBuffer*> elements_;
+};
+
+// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements
+// in LocalShapedBufferTuple form.
+StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
+ LocalShapedBuffer* local_shaped_buffer);
+
// Wraps a LocalExecutable produced by compiling a
// LocalComputation. The Execute method forwards to that of the
// underlying LocalExecutable, and additionally handles tranferring
@@ -156,7 +187,7 @@ class LocalComputationBuilder {
LocalOp Parameter(int64 parameter_number, const Shape& shape,
const string& name);
- std::unique_ptr<Shape> GetShape(const LocalOp& operand);
+ StatusOr<Shape> GetShape(const LocalOp& operand);
// Returns the shape of the current return value for the computation.
StatusOr<Shape> GetReturnValueShape();
@@ -239,8 +270,7 @@ class LocalComputationBuilder {
LocalOp Map(tensorflow::gtl::ArraySlice<LocalOp> operands,
const LocalComputation& local_computation,
- tensorflow::gtl::ArraySlice<int64> dimensions,
- tensorflow::gtl::ArraySlice<LocalOp> static_operands);
+ tensorflow::gtl::ArraySlice<int64> dimensions);
LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value,
const LocalComputation& local_computation,
@@ -302,22 +332,25 @@ class LocalComputationBuilder {
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
+ _FORWARD_BINOP(Xor)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
+ _FORWARD_UNOP(Expm1)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
+ _FORWARD_UNOP(Log1p)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
- _FORWARD_UNOP(SqrtF32)
- _FORWARD_UNOP(SquareF32)
+ _FORWARD_UNOP(Sqrt)
+ _FORWARD_UNOP(Square)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
- _FORWARD_UNOP(ReciprocalF32)
+ _FORWARD_UNOP(Reciprocal)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)
@@ -336,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation);
void DeleteLocalComputation(LocalComputation* computation);
} // namespace swig
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 51412ca474..afdea88cb7 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -109,7 +109,7 @@ limitations under the License.
// Must be included first
#include "tensorflow/python/lib/core/numpy.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -200,6 +200,20 @@ tensorflow::ImportNumpy();
}
}
+%typemap(out) StatusOr<xla::swig::LocalShapedBufferTuple*> {
+ if ($1.ok()) {
+ auto* value = $1.ValueOrDie();
+ {
+ auto* $1 = value;
+ $typemap(out, xla::swig::LocalShapedBufferTuple*)
+ }
+ } else {
+ PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+ SWIG_fail;
+ }
+}
+
+
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
@@ -851,6 +865,11 @@ tensorflow::ImportNumpy();
})) {
return nullptr;
}
+ if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) {
+ build_options.set_dump_unoptimized_hlo_proto_to(std::move(s));
+ })) {
+ return nullptr;
+ }
if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) {
build_options.set_dump_per_pass_hlo_proto_to(std::move(s));
})) {
@@ -900,6 +919,9 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral;
+%unignore xla::swig::LocalShapedBufferTuple;
+%unignore xla::swig::LocalShapedBufferTuple::Release;
+%unignore xla::swig::LocalShapedBufferTuple::size;
%unignore xla::swig::CompiledLocalComputation;
%unignore xla::swig::CompiledLocalComputation::Execute;
%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers;
@@ -966,24 +988,28 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Min;
%unignore xla::swig::LocalComputationBuilder::And;
%unignore xla::swig::LocalComputationBuilder::Or;
+%unignore xla::swig::LocalComputationBuilder::Xor;
%unignore xla::swig::LocalComputationBuilder::Not;
%unignore xla::swig::LocalComputationBuilder::Abs;
%unignore xla::swig::LocalComputationBuilder::Exp;
+%unignore xla::swig::LocalComputationBuilder::Expm1;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
+%unignore xla::swig::LocalComputationBuilder::Log1p;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
%unignore xla::swig::LocalComputationBuilder::Sin;
%unignore xla::swig::LocalComputationBuilder::Tanh;
-%unignore xla::swig::LocalComputationBuilder::SqrtF32;
-%unignore xla::swig::LocalComputationBuilder::SquareF32;
+%unignore xla::swig::LocalComputationBuilder::Sqrt;
+%unignore xla::swig::LocalComputationBuilder::Square;
%unignore xla::swig::LocalComputationBuilder::Pow;
%unignore xla::swig::LocalComputationBuilder::IsFinite;
-%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
+%unignore xla::swig::LocalComputationBuilder::Reciprocal;
%unignore xla::swig::LocalComputationBuilder::Neg;
%unignore xla::swig::LocalComputationBuilder::Sort;
+%unignore xla::swig::DestructureLocalShapedBufferTuple;
%unignore xla::swig::DeleteLocalShapedBuffer;
%unignore xla::swig::DeleteLocalComputation;
%unignore xla::swig::DeleteCompiledLocalComputation;
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 68648a3a17..71351abd59 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/python/numpy_bridge.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -374,7 +375,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element));
elements.push_back(std::move(literal));
}
- return Literal::MakeTupleOwned(std::move(elements));
+ return LiteralUtil::MakeTupleOwned(std::move(elements));
} else if (PyArray_Check(o)) {
PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(o);
int rank = PyArray_NDIM(py_array);
@@ -383,7 +384,7 @@ StatusOr<std::unique_ptr<Literal>> XlaLiteralFromPyObject(PyObject* o) {
dimensions[i] = PyArray_DIM(py_array, i);
}
int np_type = PyArray_TYPE(py_array);
- auto literal = Literal::CreateFromDimensions(
+ auto literal = LiteralUtil::CreateFromDimensions(
NumpyTypeToPrimitiveType(np_type), dimensions);
TF_RETURN_IF_ERROR(
CopyNumpyArrayToLiteral(np_type, py_array, literal.get()));
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 64f0aae0f9..a67c93a4fb 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -25,7 +25,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/python/lib/core/numpy.h"
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 50b548afa5..e2b6eaa096 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -89,18 +89,20 @@ _UNARY_OPS = [
'Not',
'Abs',
'Exp',
+ 'Expm1',
'Floor',
'Round',
'Ceil',
'Log',
+ 'Log1p',
'Sign',
'Cos',
'Sin',
'Tanh',
- 'SqrtF32',
- 'SquareF32',
+ 'Sqrt',
+ 'Square',
'IsFinite',
- 'ReciprocalF32',
+ 'Reciprocal',
'Neg',
'Sort',
]
@@ -121,6 +123,7 @@ _BINARY_OPS = [
'Min',
'And',
'Or',
+ 'Xor',
'Pow',
]
@@ -184,6 +187,14 @@ class LocalBuffer(object):
self._delete(self.c_local_shaped_buffer)
self.c_local_shaped_buffer = None
+ def destructure(self):
+ assert self.c_local_shaped_buffer is not None
+ result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer)
+ self.c_local_shaped_buffer = None
+ size = result.size()
+ destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size))
+ return destructured
+
def is_deleted(self):
return self.c_local_shaped_buffer is None
@@ -247,9 +258,12 @@ class Shape(object):
self._dimensions == other._dimensions and
self._minor_to_major == other._minor_to_major)
+ def __ne__(self, other):
+ return not self == other
+
def __repr__(self):
return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
- '_is_tuple={!r}), _minor_to_major={!r}').format(
+ '_is_tuple={!r}, _minor_to_major={!r})').format(
self._dtype, self._dimensions, self._is_tuple,
self._minor_to_major)
@@ -353,6 +367,7 @@ class CompileOptions(object):
def __init__(self):
self.generate_hlo_graph = None
self.dump_optimized_hlo_proto_to = None
+ self.dump_unoptimized_hlo_proto_to = None
self.dump_per_pass_hlo_proto_to = None
self.hlo_profile = False
@@ -446,14 +461,16 @@ class LocalComputation(object):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
+ result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
+
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
- result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape())
result_shape = result_shape.map_leaves(layout_fn)
- compile_options = compile_options or CompileOptions()
- compile_options.result_shape = result_shape
+
+ compile_options = compile_options or CompileOptions()
+ compile_options.result_shape = result_shape
return LocalComputation(
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)
@@ -894,20 +911,19 @@ class ComputationBuilder(object):
"""
return self._client.Call(computation_to_apply.c_local_computation, operands)
- def Map(self, operands, computation_to_apply, dimensions, static_operands=()):
+ def Map(self, operands, computation_to_apply, dimensions):
"""Enqueues a map operation onto the computation.
Args:
operands: an iterable of LocalOp.
computation_to_apply: a Computation object.
dimensions: dimensions over which to apply map the function.
- static_operands: auxiliary arguments passed to the applied computation.
Returns:
A LocalOp representing the added Map op.
"""
return self._client.Map(operands, computation_to_apply.c_local_computation,
- dimensions, static_operands)
+ dimensions)
def Reduce(self, operand, init_value, computation_to_apply, dimensions):
"""Enqueues a reduction operation onto the computation.
@@ -1112,6 +1128,61 @@ class ComputationBuilder(object):
dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd))
return dimension_numbers
+ def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
+ rhs_dilation, dimension_numbers):
+ """Enqueues a ConvGeneralDilated operation onto the computation.
+
+ Args:
+ lhs: LocalOp for the rank N+2 array of inputs.
+ rhs: LocalOp for the rank N+2 array of kernel weights.
+ window_strides: length-N array-like of integer kernel strides.
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+ lhs_dilation: length-N array-like of integer dilation factors.
+ rhs_dilation: length-N array-like of integer dilation factors.
+ dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a
+ triple (lhs_spec, rhs_spec, out_spec) where each element is a string of
+ length N+2 identifying by position (1) batch dimensions in lhs, rhs, and
+ the output with the character 'N', (2) feature dimensions in lhs and the
+ output with the character 'C', (3) input and output feature dimensions
+ in rhs with the characters 'I' and 'O' respectively, and (4) spatial
+ dimension correspondences between lhs, rhs, and the output using any
+ distinct characters. For example, to indicate dimension numbers
+ consistent with the Conv operation with two spatial dimensions, one
+ could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate
+ dimension numbers consistent with the TensorFlow Conv2D operation, one
+ could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of
+ convolution dimension specification, window strides are associated with
+ spatial dimension character labels according to the order in which the
+ labels appear in the rhs_spec string, so that window_strides[0] is
+ matched with the dimension corresponding to the first character
+ appearing in rhs_spec that is not 'I' or 'O'.
+
+ Returns: a LocalOp representing the ConvGenralDilated operation.
+ """
+ if not isinstance(dimension_numbers,
+ xla_data_pb2.ConvolutionDimensionNumbers):
+ lhs_spec, rhs_spec, out_spec = dimension_numbers
+ dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers()
+
+ dimension_numbers.input_batch_dimension = lhs_spec.index('N')
+ dimension_numbers.input_feature_dimension = lhs_spec.index('C')
+ dimension_numbers.output_batch_dimension = out_spec.index('N')
+ dimension_numbers.output_feature_dimension = out_spec.index('C')
+ dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O')
+ dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I')
+
+ dimension_numbers.kernel_spatial_dimensions.extend(
+ i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'})
+ dimension_numbers.input_spatial_dimensions.extend(
+ sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}),
+ key=lambda i: rhs_spec.index(lhs_spec[i])))
+ dimension_numbers.output_spatial_dimensions.extend(
+ sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}),
+ key=lambda i: rhs_spec.index(out_spec[i])))
+ return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
+ lhs_dilation, rhs_dilation,
+ dimension_numbers)
+
def _forward_methods_to_local_builder():
"""Forward remaining ComputationBuilder methods to the C API.
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index e3d393bccc..0564ddcb85 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -157,6 +157,13 @@ class ComputationsWithConstantsTest(LocalComputationTest):
c.Constant(NumpyArrayBool([True, True, False, False])))
self._ExecuteAndCompareExact(c, expected=[True, True, True, False])
+ def testBooleanXor(self):
+ c = self._NewComputation()
+ c.Xor(
+ c.Constant(NumpyArrayBool([True, False, True, False])),
+ c.Constant(NumpyArrayBool([True, True, False, False])))
+ self._ExecuteAndCompareExact(c, expected=[False, True, True, False])
+
def testSum2DF32(self):
c = self._NewComputation()
c.Add(
@@ -365,6 +372,55 @@ class LocalBufferTest(LocalComputationTest):
with self.assertRaises(ValueError):
compiled_c.ExecuteWithLocalBuffers([arg_buffer])
+ def testDestructureTupleEmpty(self):
+ t = ()
+ local_buffer = xla_client.LocalBuffer.from_pyval(t)
+ pieces = local_buffer.destructure()
+ self.assertTrue(local_buffer.is_deleted())
+ self.assertEqual(len(pieces), 0)
+
+ def testDestructureTupleOneArrayElement(self):
+ t = (np.array([1, 2, 3, 4], dtype=np.int32),)
+ local_buffer = xla_client.LocalBuffer.from_pyval(t)
+ pieces = local_buffer.destructure()
+ self.assertTrue(local_buffer.is_deleted())
+ self.assertEqual(len(pieces), 1)
+ array = pieces[0]
+ got = array.to_py()
+ want = NumpyArrayS32([1, 2, 3, 4])
+ np.testing.assert_equal(want, got)
+
+ def testDestructureTupleTwoArrayElementDifferentType(self):
+ t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
+ np.array([2, 3, 4, 5], dtype=np.int32))
+ local_buffer = xla_client.LocalBuffer.from_pyval(t)
+ pieces = local_buffer.destructure()
+ self.assertTrue(local_buffer.is_deleted())
+ self.assertEqual(len(pieces), 2)
+ array0, array1 = pieces
+ got = array0.to_py()
+ want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0])
+ np.testing.assert_equal(want, got)
+ got = array1.to_py()
+ want = NumpyArrayS32([2, 3, 4, 5])
+ np.testing.assert_equal(want, got)
+
+ def testDestructureTupleNested(self):
+ t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
+ local_buffer = xla_client.LocalBuffer.from_pyval(t)
+ pieces = local_buffer.destructure()
+ self.assertTrue(local_buffer.is_deleted())
+ self.assertEqual(len(pieces), 2)
+ tuple0, array1 = pieces
+ got = array1.to_py()
+ want = NumpyArrayS32([5])
+ np.testing.assert_equal(want, got)
+ got = tuple0.to_py()
+ self.assertEqual(type(got), tuple)
+ self.assertEqual(len(got), 2)
+ np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0])
+ np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1])
+
class SingleOpTest(LocalComputationTest):
"""Tests for single ops.
@@ -519,6 +575,46 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=result)
+ def testConvGeneralDilatedF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 1, 2, 3)
+ rhs = a(1, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
+ def testConvGeneralDilatedPermutedF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 1, 2, 3)
+ rhs = a(1, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+
+ dimension_numbers = ("NHWC", "OIHW", "CWNH")
+ c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))),
+ c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]]]])
+ self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
@@ -531,6 +627,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
+ def testExpm1(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Expm1(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.expm1(arr))
+
def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@@ -543,6 +645,12 @@ class SingleOpTest(LocalComputationTest):
c.Log(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.log(arr))
+ def testLog1p(self):
+ c = self._NewComputation()
+ arr = NumpyArrayF32([3.3, 12.1])
+ c.Log1p(c.Constant(arr))
+ self._ExecuteAndCompareClose(c, expected=np.log1p(arr))
+
def testNeg(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@@ -1067,14 +1175,6 @@ class EmbeddedComputationsTest(LocalComputationTest):
self._CreateBinaryDivF64Computation(), [0])
self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0])
- def DISABLED_testMapWithStaticOperands(self):
- c = self._NewComputation()
- factor = c.ConstantF32Scalar(3.0)
- c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))],
- self._CreateMulF32ByParamComputation(), [0],
- static_operands=[factor])
- self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0])
-
def testSelectAndScatterF32(self):
c = self._NewComputation()
c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])),
diff --git a/tensorflow/compiler/xla/python_api/BUILD b/tensorflow/compiler/xla/python_api/BUILD
new file mode 100644
index 0000000000..8999cda5ef
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/BUILD
@@ -0,0 +1,36 @@
+# Description:
+# Python API for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+py_library(
+ name = "types",
+ srcs = ["types.py"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "xla_shape",
+ srcs = ["xla_shape.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ ],
+)
+
+py_library(
+ name = "xla_literal",
+ srcs = ["xla_literal.py"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ ":xla_shape",
+ "//tensorflow/compiler/xla:xla_data_proto_py",
+ ],
+)
diff --git a/tensorflow/compiler/xla/python_api/types.py b/tensorflow/compiler/xla/python_api/types.py
new file mode 100644
index 0000000000..b60f8dce92
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/types.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""Utilities for XLA-specific Python types."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+
+# Records corresponsence between a XLA primitive type and Python/Numpy types.
+#
+# primitive_type: value of type xla_data_pb2.PrimitiveType
+# numpy_dtype: corresponsing Numpy "dtype" (like np.float32)
+# literal_field_name: name of the field in the LiteralProto message elements
+# of this type go into.
+# literal_field_type: type of the field named 'literal_field_name'.
+#
+# TODO(eliben): figure out how to avoid knowing the extra Python type and the
+# astype cast when writing into Literals.
+TypeConversionRecord = collections.namedtuple('TypeConversionRecord', [
+ 'primitive_type', 'numpy_dtype', 'literal_field_name', 'literal_field_type'
+])
+
+# Maps from XLA primitive types to TypeConversionRecord.
+MAP_XLA_TYPE_TO_RECORD = {
+ xla_data_pb2.F16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F16,
+ numpy_dtype=np.float16,
+ literal_field_name='f16s',
+ literal_field_type=float),
+ xla_data_pb2.F32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F32,
+ numpy_dtype=np.float32,
+ literal_field_name='f32s',
+ literal_field_type=float),
+ xla_data_pb2.F64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.F64,
+ numpy_dtype=np.float64,
+ literal_field_name='f64s',
+ literal_field_type=float),
+ xla_data_pb2.S8:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S8,
+ numpy_dtype=np.int8,
+ literal_field_name='s8s',
+ literal_field_type=int),
+ xla_data_pb2.S16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S16,
+ numpy_dtype=np.int16,
+ literal_field_name='s16s',
+ literal_field_type=int),
+ xla_data_pb2.S32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S32,
+ numpy_dtype=np.int32,
+ literal_field_name='s32s',
+ literal_field_type=int),
+ xla_data_pb2.S64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.S64,
+ numpy_dtype=np.int64,
+ literal_field_name='s64s',
+ literal_field_type=int),
+ xla_data_pb2.U8:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U8,
+ numpy_dtype=np.uint8,
+ literal_field_name='s8s',
+ literal_field_type=int),
+ xla_data_pb2.U16:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U16,
+ numpy_dtype=np.uint16,
+ literal_field_name='s16s',
+ literal_field_type=int),
+ xla_data_pb2.U32:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U32,
+ numpy_dtype=np.uint32,
+ literal_field_name='s32s',
+ literal_field_type=int),
+ xla_data_pb2.U64:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.U64,
+ numpy_dtype=np.uint64,
+ literal_field_name='s64s',
+ literal_field_type=int),
+ xla_data_pb2.PRED:
+ TypeConversionRecord(
+ primitive_type=xla_data_pb2.PRED,
+ numpy_dtype=np.bool,
+ literal_field_name='preds',
+ literal_field_type=bool)
+}
+
+# Maps from Numpy dtypes to TypeConversionRecord.
+# Note the conversion on the key. Numpy has a known issue wherein dtype hashing
+# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus,
+# when keying by dtype in this dict, we use the string form of dtypes.
+MAP_DTYPE_TO_RECORD = {
+ str(np.dtype(record.numpy_dtype)): record
+ for record in MAP_XLA_TYPE_TO_RECORD.values()
+}
diff --git a/tensorflow/compiler/xla/python_api/xla_literal.py b/tensorflow/compiler/xla/python_api/xla_literal.py
new file mode 100644
index 0000000000..b040098c29
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/xla_literal.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""XLA LiteralProto utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import types
+from tensorflow.compiler.xla.python_api import xla_shape
+
+
+def ConvertLiteralToNumpyArray(literal):
+ """Converts a XLA literal to a Numpy array."""
+ element_type = literal.shape.element_type
+ if element_type == xla_data_pb2.TUPLE:
+ return tuple(
+ ConvertLiteralToNumpyArray(subliteral)
+ for subliteral in literal.tuple_literals)
+
+ type_record = types.MAP_XLA_TYPE_TO_RECORD[element_type]
+ if not literal.shape.dimensions:
+ return np.array(
+ getattr(literal, type_record.literal_field_name)[0],
+ type_record.numpy_dtype)
+ else:
+ # Infer the proper Numpy order from the LiteralProto's layout. The repeated
+ # field representing the array's content in the Literal is linearized.
+ # Reading is done in two steps:
+ #
+ # 1. Read the array as 1D from the LiteralProto repeated field.
+ # 2. Reshape the array to its proper shape, using the right order depending
+ # on the LiteralProto's layout.
+ layout_order = literal.shape.layout.minor_to_major
+ numpy_shape = tuple(literal.shape.dimensions)
+ if layout_order == range(len(literal.shape.dimensions)):
+ numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='F')
+ elif layout_order == range(len(literal.shape.dimensions) - 1, -1, -1):
+ numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C')
+ else:
+ raise NotImplementedError('Unsupported layout: {0}'.format(layout_order))
+ ndarray = np.array(
+ getattr(literal, type_record.literal_field_name),
+ copy=False,
+ dtype=type_record.numpy_dtype)
+ return numpy_reshaper(ndarray)
+
+
+def _ConvertNumpyArrayToLiteral(ndarray):
+ """Converts a Numpy array to a XLA literal."""
+ type_record = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)]
+ literal = xla_data_pb2.LiteralProto()
+ literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(ndarray).message)
+
+ if ndarray.ndim == 0:
+ getattr(literal, type_record.literal_field_name).append(
+ np.asscalar(ndarray.astype(type_record.literal_field_type)))
+ else:
+ # Ndarrays with boolean dtypes need special type conversion with protobufs
+ if ndarray.dtype in {np.bool_, np.dtype('bool')}:
+ for element in np.nditer(ndarray):
+ getattr(literal, type_record.literal_field_name).append(
+ type_record.literal_field_type(element))
+ else:
+ ndarray_flat = ndarray.ravel(order='A')
+ getattr(literal, type_record.literal_field_name).extend(ndarray_flat)
+ return literal
+
+
+def ConvertNumpyArrayToLiteral(value):
+ """Converts a Numpy array or a nested tuple thereof to an XLA literal."""
+ if isinstance(value, tuple):
+ literal = xla_data_pb2.LiteralProto()
+ literal.shape.CopyFrom(xla_shape.CreateShapeFromNumpy(value).message)
+ for component in value:
+ component_literal = literal.tuple_literals.add()
+ component_literal.CopyFrom(ConvertNumpyArrayToLiteral(component))
+ return literal
+ else:
+ return _ConvertNumpyArrayToLiteral(value)
diff --git a/tensorflow/compiler/xla/python_api/xla_shape.py b/tensorflow/compiler/xla/python_api/xla_shape.py
new file mode 100644
index 0000000000..6af2895803
--- /dev/null
+++ b/tensorflow/compiler/xla/python_api/xla_shape.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ======================================
+"""XLA Shape utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.xla import xla_data_pb2
+from tensorflow.compiler.xla.python_api import types
+
+
+class Shape(object):
+ """Wraps a xla_data_pb2.Shape message with a convenient Python type.
+
+ Provides direct access to the underlying xla_data_pb2.Shape message in the
+ message attribute, along with accessor wrappers to the message's fields.
+ Avoid direct access to .message unless interacting directly with protobuf APIs
+ like CopyFrom. In other words, prefer hauling the shape around in a Shape, and
+ only access .message when strictly required by the protobuf API.
+ """
+
+ def __init__(self, element_type, dimensions, layout=None):
+ """Creates a new XLA Shape.
+
+ Args:
+ element_type: element type from xla_data_pb2.
+ dimensions: sequence of dimensions sizes (integers), or sequence
+ of Shapes in the case of a tuple, i.e. when element_type is
+ TUPLE.
+ layout: optional minor_to_major sequence for layout. If not given, the
+ default major-to-minor layout is used.
+
+ Raises:
+ ValueError: if element_type is TUPLE but dimensions are not Shape objects.
+ """
+ self.message = xla_data_pb2.Shape()
+ self.message.element_type = element_type
+ if element_type == xla_data_pb2.TUPLE:
+ if not all(isinstance(subshape, Shape) for subshape in dimensions):
+ raise ValueError(
+ 'XLA tuple requires sequence of Shape objects as dimensions')
+ self._tuple_shapes = tuple(dimensions)
+ for component_shape in self._tuple_shapes:
+ component_message = self.message.tuple_shapes.add()
+ component_message.CopyFrom(component_shape.message)
+ else:
+ self.message.dimensions.extend(dimensions)
+ if layout is None:
+ layout = list(reversed(range(len(dimensions))))
+ self.message.layout.format = xla_data_pb2.DENSE
+ self.message.layout.minor_to_major.extend(layout)
+
+ def element_type(self):
+ return self.message.element_type
+
+ def is_tuple(self):
+ return self.element_type() == xla_data_pb2.TUPLE
+
+ def dimensions(self):
+ if self.is_tuple():
+ raise ValueError('Tuple shape has no dimensions. Try tuple_shapes()?')
+ return self.message.dimensions
+
+ def tuple_shapes(self):
+ """If this is a tuple, returns its sequence of constituent Shape objects.
+
+ Returns:
+ Tuple sub-shapes.
+
+ Raises:
+ ValueError: if this is not a tuple.
+ """
+ if not self.is_tuple():
+ raise ValueError('tuple_shapes() called on a non-tuple shape')
+ return self._tuple_shapes
+
+ def layout(self):
+ return self.message.layout
+
+ @staticmethod
+ def from_pyval(pyval):
+ return CreateShapeFromNumpy(pyval)
+
+
+def _CreateShapeFromNumpy(ndarray): # pylint: disable=invalid-name
+ """Create a Shape from a given Numpy array.
+
+ Args:
+ ndarray: Numpy array.
+
+ Returns:
+ A Shape object.
+ """
+ element_type = types.MAP_DTYPE_TO_RECORD[str(ndarray.dtype)].primitive_type
+ dimensions = ndarray.shape
+
+ # Set the shape's layout based on the ordering of ndarray.
+ # Numpy arrays come in two orders: Fortran (column-major) and C (row-major).
+ if np.isfortran(ndarray):
+ # Column-major layout. This corresponds to a "dimension order is
+ # minor-to-major" layout in XLA.
+ layout = range(ndarray.ndim)
+ else:
+ # Row-major layout. This corresponds to a "dimension order is
+ # major-to-minor" layout int XLA.
+ layout = list(reversed(xrange(ndarray.ndim)))
+
+ return Shape(element_type, dimensions, layout)
+
+
+def CreateShapeFromNumpy(value): # pylint: disable=invalid-name
+ """Create a Shape from a Numpy array or a nested tuple structure thereof.
+
+ Args:
+ value: Numpy array or (possibly nested) tuple structure that bottoms out in
+ Numpy arrays.
+
+ Returns:
+ A Shape object.
+ """
+ if isinstance(value, tuple):
+ return Shape(
+ xla_data_pb2.TUPLE,
+ [CreateShapeFromNumpy(component) for component in value])
+ else:
+ return _CreateShapeFromNumpy(value)
+
+
+def CreateShapeFromDtypeAndTuple(dtype, shape_tuple): # pylint: disable=invalid-name
+ """Create a shape from a Numpy dtype and a sequence of nonnegative integers.
+
+ Args:
+ dtype: a numpy dtype, e.g. np.dtype('int32').
+ shape_tuple: a sequence of nonnegative integers.
+
+ Returns:
+ A Shape object.
+ """
+ element_type = types.MAP_DTYPE_TO_RECORD[str(dtype)].primitive_type
+ return Shape(element_type, shape_tuple)
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index c289c84cff..6397f1f479 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -510,8 +511,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
ConvolutionDimensionNumbers dnums) {
HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
std::array<int64, 2> ordered_kernel_strides;
std::array<int64, 2> ordered_input_dimensions;
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 9da9bc60a2..8091bed499 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -53,7 +53,7 @@ class ReferenceUtilTest : public ::testing::Test {
TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -65,7 +65,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
{11.f, 12.f},
});
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -73,7 +73,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
- auto actual_literal = Literal::CreateR1<float>(*result);
+ auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
ErrorSpec(0.0001));
}
@@ -81,13 +81,13 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
- auto actual_literal = Literal::CreateR1<float>(*result);
+ auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
- auto result = Literal::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
+ auto result = LiteralUtil::CreateR1<float>(ReferenceUtil::Reduce4DTo1D(
Array4D<float>(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2},
[](float a, float b) { return a + b; }));
LiteralTestUtil::ExpectR1Equal<float>({0}, *result);
@@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) {
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
ErrorSpec(0.0001));
}
@@ -106,7 +106,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
return value + row + col;
};
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*actual_literal, ErrorSpec(0.0001));
}
@@ -117,7 +117,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
@@ -134,7 +134,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
};
auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
@@ -144,7 +144,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
TEST_F(ReferenceUtilTest, SliceArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
*actual_literal, ErrorSpec(0.0001));
@@ -152,7 +152,7 @@ TEST_F(ReferenceUtilTest, SliceArray2D) {
TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
- auto actual_literal = Literal::CreateR2FromArray2D(*result);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
*actual_literal, ErrorSpec(0.0001));
@@ -164,7 +164,7 @@ TEST_F(ReferenceUtilTest, SliceArray3D) {
auto result =
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}});
- auto actual_literal = Literal::CreateR3FromArray3D(*result);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
@@ -177,7 +177,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
auto result =
ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}});
- auto actual_literal = Literal::CreateR3FromArray3D(*result);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*result);
LiteralTestUtil::ExpectR3Near<float>(
{{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
@@ -190,7 +190,7 @@ TEST_F(ReferenceUtilTest, SliceArray4D) {
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}},
{{1, 1, 1, 1}});
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
@@ -203,7 +203,7 @@ TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}},
{{1, 2, 2, 2}});
- auto actual_literal = Literal::CreateR4FromArray4D(*result);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
LiteralTestUtil::ExpectR4Near<float>(
{{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
@@ -218,7 +218,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithSamePadding) {
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kSame);
Array3D<float> expected = {{{17, 28, 39, 20}}};
- auto actual_literal = Literal::CreateR3FromArray3D(*actual);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -231,7 +231,7 @@ TEST_F(ReferenceUtilTest, ConvArray3DWithValidPadding) {
ReferenceUtil::ConvArray3D(input, weights, 1, Padding::kValid);
Array3D<float> expected = {{{17, 28, 39}}};
- auto actual_literal = Literal::CreateR3FromArray3D(*actual);
+ auto actual_literal = LiteralUtil::CreateR3FromArray3D(*actual);
LiteralTestUtil::ExpectR3NearArray3D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -266,7 +266,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
}));
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -300,7 +300,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
}));
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -356,7 +356,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
}});
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -409,7 +409,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
Array4D<float> expected({{{{2514, 2685}}}});
// clang-format on
- auto actual_literal = Literal::CreateR4FromArray4D(*actual);
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
ErrorSpec(0.0001));
@@ -422,7 +422,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
auto actual = ReferenceUtil::ApplyElementwise2D(
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
- auto actual_literal = Literal::CreateR2FromArray2D(*actual);
+ auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
}
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 0d56a9a477..0b1cec1925 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -39,10 +39,10 @@ tf_cc_binary(
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
+ "//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
- "@grpc//:grpc++_unsecure",
],
)
@@ -54,6 +54,7 @@ tf_cc_test(
],
deps = [
":grpc_stub",
+ "//tensorflow:grpc++",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -61,7 +62,6 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "@grpc//:grpc++_unsecure",
],
)
@@ -71,9 +71,9 @@ cc_library(
hdrs = ["grpc_service.h"],
deps = [
":xla_service_proto",
+ "//tensorflow:grpc++",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
- "@grpc//:grpc++_unsecure",
],
)
diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
index 313f11a9a9..90efee50b4 100644
--- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "grpc++/create_channel.h"
-#include "grpc++/security/credentials.h"
+#include "grpcpp/create_channel.h"
+#include "grpcpp/security/credentials.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
@@ -85,19 +85,19 @@ TEST_F(GRPCClientTestBase, ItsAlive) {
TEST_F(GRPCClientTestBase, AxpyTenValues) {
XlaBuilder builder("axpy_10");
- auto alpha = builder.ConstantR0<float>(3.1415926535);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto y = builder.ConstantR1<float>(
- {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
- auto ax = builder.Mul(alpha, x);
- auto axpy = builder.Add(ax, y);
+ auto alpha = ConstantR0<float>(&builder, 3.1415926535);
+ auto x = ConstantR1<float>(
+ &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto y = ConstantR1<float>(
+ &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
+ auto ax = Mul(alpha, x);
+ Add(ax, y);
std::vector<float> expected = {
1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<float>(expected);
+ LiteralUtil::CreateR1<float>(expected);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h
index 5cd573167a..ca1b09b648 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service.h
+++ b/tensorflow/compiler/xla/rpc/grpc_service.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_
-#include "grpc++/server_context.h"
+#include "grpcpp/server_context.h"
#include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h"
#include "tensorflow/compiler/xla/service/service.h"
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index e29908ccec..c68c857c30 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -15,9 +15,9 @@ limitations under the License.
// Basic server binary that exposes a xla::Service through a GRPC interface
// on a configurable port.
-#include "grpc++/security/server_credentials.h"
-#include "grpc++/server.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/security/server_credentials.h"
+#include "grpcpp/server.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto
index 92eb19ec0f..551ae895e0 100644
--- a/tensorflow/compiler/xla/rpc/xla_service.proto
+++ b/tensorflow/compiler/xla/rpc/xla_service.proto
@@ -115,10 +115,6 @@ service XlaService {
returns (ComputeConstantResponse) {
}
- // Retrieves the inferred shape for a value within a computation.
- rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) {
- }
-
// Requests one or more device handles from the target. The returned device
// handles can be used to specify the device on which to execute computations
// or transfer data.
@@ -132,18 +128,6 @@ service XlaService {
returns (CreateChannelHandleResponse) {
}
- // Requests that the referenced computation be specialized for the provided
- // arguments for subsequent execution. This permits things such as value
- // specialization.
- rpc Specialize(SpecializeRequest) returns (SpecializeResponse) {
- }
-
- // Modifies the provided computation so that subsequent executions
- // will compute the provided ComputationDataHandle, rather than the
- // last expression enqueued on that Computation.
- rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) {
- }
-
// Invokes the provided computation with the provided global data passed as
// immutable arguments. The request contains the whole computation graph.
// Returns global data output and execution timing.
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2b14b63ea8..6e3431df52 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -22,13 +22,6 @@ load(
)
xla_proto_library(
- name = "session_proto",
- srcs = ["session.proto"],
- visibility = ["//visibility:public"],
- deps = ["//tensorflow/compiler/xla:xla_data_proto"],
-)
-
-xla_proto_library(
name = "hlo_proto",
srcs = ["hlo.proto"],
visibility = ["//visibility:public"],
@@ -39,6 +32,7 @@ tf_proto_library_py(
name = "hlo_proto", # bzl adds a _py suffix only to the OSS target.
srcs = ["hlo.proto"],
visibility = ["//visibility:public"],
+ deps = ["//tensorflow/compiler/xla:xla_data_proto_py"],
)
xla_proto_library(
@@ -142,7 +136,7 @@ cc_library(
":hlo_dce",
":hlo_pass",
":tuple_simplifier",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -233,6 +227,7 @@ cc_library(
":hlo",
":hlo_query",
":shape_inference",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -250,7 +245,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_evaluator",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -276,6 +271,7 @@ cc_library(
"dfs_hlo_visitor.cc",
"hlo_computation.cc",
"hlo_instruction.cc",
+ "hlo_instructions.cc",
"hlo_module.cc",
"hlo_opcode.cc",
"hlo_sharding.cc",
@@ -287,17 +283,19 @@ cc_library(
"hlo_computation.h",
"hlo_domain_metadata.h",
"hlo_instruction.h",
+ "hlo_instructions.h",
"hlo_module.h",
"hlo_opcode.h",
"hlo_sharding.h",
],
deps = [
+ ":hlo_casting_utils",
":hlo_module_config",
":hlo_proto",
":hlo_reachability",
":name_uniquer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:array",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_tree",
@@ -349,8 +347,8 @@ tf_cc_test(
":hlo",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -388,8 +386,8 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -399,28 +397,20 @@ tf_cc_test(
srcs = ["hlo_matchers_test.cc"],
deps = [
":hlo_matchers",
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
-cc_library(
- name = "versioned_computation_handle",
- srcs = ["versioned_computation_handle.cc"],
- hdrs = ["versioned_computation_handle.h"],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- ],
-)
-
tf_cc_test(
name = "hlo_instruction_test",
srcs = ["hlo_instruction_test.cc"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -429,7 +419,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -438,15 +427,15 @@ tf_cc_test(
srcs = ["hlo_sharding_test.cc"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -467,7 +456,7 @@ tf_cc_test(
srcs = ["call_graph_test.cc"],
deps = [
":call_graph",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -501,6 +490,7 @@ cc_library(
hdrs = ["call_inliner.h"],
deps = [
":call_graph",
+ ":hlo_dce",
":hlo_pass",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
@@ -516,7 +506,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -535,7 +525,7 @@ tf_cc_test(
deps = [
":call_graph",
":flatten_call_graph",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -594,7 +584,6 @@ cc_library(
":allocation_tracker",
":backend",
":channel_tracker",
- ":compilation_cache",
":compiler",
":computation_layout",
":device_memory_allocator",
@@ -607,10 +596,8 @@ cc_library(
":hlo_module_config",
":hlo_proto_util",
":platform_util",
- ":session_proto",
":source_map_util",
":transfer_manager",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:service_interface",
@@ -645,7 +632,6 @@ cc_library(
":platform_util",
":service",
":shaped_buffer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
@@ -765,9 +751,7 @@ cc_library(
":hlo_graph_dumper",
":hlo_proto",
":pool",
- ":session_proto",
":shaped_buffer",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -817,7 +801,7 @@ cc_library(
hdrs = ["transfer_manager.h"],
deps = [
":shaped_buffer",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -869,8 +853,6 @@ cc_library(
hdrs = ["channel_tracker.h"],
deps = [
":hlo",
- ":session_proto",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -982,16 +964,16 @@ tf_cc_test(
":hlo",
":hlo_ordering",
":hlo_scheduling",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -1027,9 +1009,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1060,7 +1042,7 @@ tf_cc_test(
":hlo_ordering",
":hlo_value",
":tuple_points_to_analysis",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1124,15 +1106,16 @@ tf_cc_test(
srcs = ["hlo_scheduling_test.cc"],
deps = [
":buffer_value",
+ ":heap_simulator",
":hlo",
":hlo_ordering",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1142,7 +1125,7 @@ cc_library(
hdrs = ["hlo_query.h"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
],
)
@@ -1165,9 +1148,22 @@ tf_cc_test(
deps = [
":hlo_matchers",
":instruction_fusion",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ ],
+)
+
+cc_library(
+ name = "multi_output_fusion",
+ srcs = ["multi_output_fusion.cc"],
+ hdrs = ["multi_output_fusion.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/core:lib",
],
)
@@ -1178,6 +1174,7 @@ cc_library(
deps = [
":hlo",
":shape_inference",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1208,6 +1205,7 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1227,6 +1225,7 @@ cc_library(
":hlo_creation_utils",
":hlo_pass",
":while_util",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
],
@@ -1240,8 +1239,9 @@ tf_cc_test(
":batchnorm_expander",
":hlo",
":hlo_matchers",
+ ":hlo_parser",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1263,6 +1263,7 @@ cc_library(
":hlo_pass",
":hlo_query",
":pattern_matcher",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1282,7 +1283,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1318,7 +1319,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1339,9 +1340,9 @@ tf_cc_test(
deps = [
":gather_expander",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1353,7 +1354,7 @@ cc_library(
":call_inliner",
":hlo",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -1369,6 +1370,7 @@ tf_cc_test(
":conditional_simplifier",
":hlo",
":hlo_matchers",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1428,7 +1430,7 @@ tf_cc_test(
deps = [
":defuser",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
@@ -1456,7 +1458,7 @@ tf_cc_test(
deps = [
":hlo_matchers",
":implicit_broadcast_remover",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
@@ -1498,7 +1500,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":tuple_simplifier",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -1513,7 +1515,7 @@ cc_library(
hdrs = ["reshape_mover.h"],
deps = [
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -1528,7 +1530,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":reshape_mover",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1563,7 +1565,7 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":inliner",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1580,7 +1582,7 @@ cc_library(
hdrs = ["computation_placer.h"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -1612,7 +1614,7 @@ cc_library(
hdrs = ["generic_transfer_manager.h"],
deps = [
":transfer_manager",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1652,7 +1654,6 @@ tf_cc_test(
":hlo_cost_analysis",
":local_service",
":service",
- ":versioned_computation_handle",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
@@ -1691,9 +1692,9 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -1704,7 +1705,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1719,6 +1720,7 @@ tf_cc_binary(
deps = [
":hlo",
":hlo_graph_dumper",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -1733,7 +1735,7 @@ tf_cc_test(
srcs = ["hlo_module_test.cc"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1831,7 +1833,7 @@ tf_cc_test(
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -1868,15 +1870,15 @@ tf_cc_test(
deps = [
":hlo",
":hlo_liveness_analysis",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -1929,7 +1931,7 @@ tf_cc_test(
":hlo_matchers",
":hlo_ordering",
":instruction_fusion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1961,8 +1963,10 @@ cc_library(
hdrs = ["tuple_points_to_analysis.h"],
deps = [
":hlo",
+ ":hlo_dataflow_analysis",
":logical_buffer",
":logical_buffer_analysis",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1981,6 +1985,7 @@ tf_cc_test(
":hlo_matchers",
":instruction_fusion",
":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -1994,20 +1999,6 @@ tf_cc_test(
)
cc_library(
- name = "compilation_cache",
- srcs = ["compilation_cache.cc"],
- hdrs = ["compilation_cache.h"],
- deps = [
- ":executable",
- ":hlo_module_config",
- ":versioned_computation_handle",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "layout_assignment",
srcs = [
"layout_assignment.cc",
@@ -2066,7 +2057,7 @@ tf_cc_test(
":hlo_graph_dumper",
":hlo_matchers",
":hlo_runner",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -2117,6 +2108,7 @@ cc_library(
hdrs = ["hlo_verifier.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:status_macros",
@@ -2148,6 +2140,7 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
+ ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
@@ -2155,6 +2148,7 @@ cc_library(
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
+ ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -2168,6 +2162,7 @@ tf_cc_test(
name = "hlo_rematerialization_test",
srcs = ["hlo_rematerialization_test.cc"],
deps = [
+ ":flatten_call_graph",
":hlo",
":hlo_matchers",
":hlo_ordering",
@@ -2177,6 +2172,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
],
)
@@ -2186,6 +2182,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_dce",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
@@ -2206,16 +2203,16 @@ tf_cc_test(
deps = [
":hlo",
":hlo_module_dce",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2230,16 +2227,16 @@ tf_cc_test(
":hlo",
":hlo_matchers",
":layout_assignment",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2289,7 +2286,7 @@ cc_library(
":hlo",
":hlo_domain_map",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -2305,15 +2302,15 @@ tf_cc_test(
":hlo",
":hlo_cse",
":hlo_matchers",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -2327,7 +2324,7 @@ cc_library(
":hlo_evaluator",
":hlo_pass",
":hlo_query",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
@@ -2342,7 +2339,7 @@ tf_cc_test(
":hlo_constant_folding",
":hlo_matchers",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -2380,6 +2377,20 @@ cc_library(
)
cc_library(
+ name = "hlo_domain_verifier",
+ srcs = ["hlo_domain_verifier.cc"],
+ hdrs = ["hlo_domain_verifier.h"],
+ deps = [
+ ":hlo",
+ ":hlo_domain_map",
+ ":hlo_graph_dumper",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "hlo_domain_isolator",
srcs = ["hlo_domain_isolator.cc"],
hdrs = ["hlo_domain_isolator.h"],
@@ -2398,12 +2409,11 @@ cc_library(
hdrs = ["hlo_domain_remover.h"],
deps = [
":hlo",
- ":hlo_domain_isolator",
":hlo_domain_map",
+ ":hlo_domain_verifier",
":hlo_graph_dumper",
":hlo_pass",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
@@ -2415,12 +2425,13 @@ tf_cc_test(
":hlo",
":hlo_domain_isolator",
":hlo_domain_remover",
+ ":hlo_parser",
":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2434,7 +2445,7 @@ cc_library(
":hlo_evaluator",
":hlo_pass",
":hlo_query",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
@@ -2506,10 +2517,10 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2567,10 +2578,9 @@ cc_library(
name = "hlo_tfgraph_builder",
srcs = ["hlo_tfgraph_builder.cc"],
hdrs = ["hlo_tfgraph_builder.h"],
- visibility = ["//tensorflow/compiler/xla/tools:__pkg__"],
deps = [
":hlo",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
@@ -2598,9 +2608,10 @@ cc_library(
hdrs = ["hlo_graph_dumper.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_execution_profile",
":hlo_tfgraph_builder",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:window_util",
@@ -2618,6 +2629,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_graph_dumper",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -2649,16 +2661,16 @@ tf_cc_test(
":hlo_matchers",
":shape_inference",
":transpose_folding",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -2670,7 +2682,7 @@ cc_library(
deps = [
":hlo",
":hlo_pass",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
@@ -2685,7 +2697,7 @@ tf_cc_test(
":hlo",
":shape_inference",
":zero_sized_hlo_elimination",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
@@ -2795,7 +2807,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -2831,8 +2843,8 @@ tf_cc_test(
":tuple_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2845,6 +2857,7 @@ cc_library(
":hlo",
":hlo_creation_utils",
":tuple_util",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/core:lib",
],
)
@@ -2857,8 +2870,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2884,8 +2897,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2911,8 +2924,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2965,9 +2978,76 @@ tf_cc_test(
":hlo_matchers",
":indexed_array_analysis",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "hlo_parser",
+ srcs = ["hlo_parser.cc"],
+ hdrs = ["hlo_parser.h"],
+ deps = [
+ ":hlo",
+ ":hlo_lexer",
+ ":hlo_sharding_metadata",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_parser_test",
+ size = "small",
+ srcs = ["hlo_parser_test.cc"],
+ deps = [
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main", # fixdeps: keep
+ ],
+)
+
+cc_library(
+ name = "hlo_lexer",
+ srcs = ["hlo_lexer.cc"],
+ hdrs = [
+ "hlo_lexer.h",
+ "hlo_token.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ ],
+)
+
+cc_library(
+ name = "hlo_casting_utils",
+ hdrs = ["hlo_casting_utils.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
+tf_cc_test(
+ name = "hlo_casting_utils_test",
+ srcs = ["hlo_casting_utils_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_casting_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index e1a45e453e..af7728da54 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -50,20 +51,15 @@ namespace {
namespace m = match;
-// Returns whether operand is a literal with the given value.
-bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
- return operand->opcode() == HloOpcode::kConstant &&
- operand->literal().IsAll(value);
-}
-
bool IsAll(const HloInstruction* op, int8 value) {
- if (IsLiteralWithValue(op, value)) {
- return true;
- }
- if (op->opcode() == HloOpcode::kBroadcast && IsAll(op->operand(0), value)) {
- return true;
+ switch (op->opcode()) {
+ case HloOpcode::kBroadcast:
+ return IsAll(op->operand(0), value);
+ case HloOpcode::kConstant:
+ return op->literal().IsAll(value);
+ default:
+ return false;
}
- return false;
}
// Returns whether the given transpose produces a result which is bit-wise
@@ -75,21 +71,22 @@ bool TransposeIsBitcast(const HloInstruction* transpose) {
transpose->dimensions());
}
-// Returns true if the given reshape produces a result which is bit-wise
+// Returns true if the given reshape/copy produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
//
// This function is conservative -- even if this function returns false, the
// reshape may still be a bitcast. For example, a reshape from [28x28] to [784].
-bool ReshapeIsBitcast(
- const HloInstruction* reshape,
+bool ReshapeOrCopyIsBitcast(
+ const HloInstruction* instr,
const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
- CHECK_EQ(HloOpcode::kReshape, reshape->opcode());
+ CHECK(HloOpcode::kReshape == instr->opcode() ||
+ HloOpcode::kCopy == instr->opcode());
- const HloInstruction* operand = reshape->operand(0);
+ const HloInstruction* operand = instr->operand(0);
// Can't insert bitcasts if the compiler used a memory layout which isn't
// compatible.
- return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
- valid_bitcast_callback(operand->shape(), reshape->shape());
+ return ShapeUtil::ReshapeIsBitcast(operand->shape(), instr->shape()) &&
+ valid_bitcast_callback(operand->shape(), instr->shape());
}
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -159,9 +156,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleMap(HloInstruction* map) override;
- Status HandleMaximum(HloInstruction* maximum) override;
- Status HandleMinimum(HloInstruction* minimum) override;
-
// Returns whether algebraic simplification has occurred.
const bool changed() const { return changed_; }
@@ -200,8 +194,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
// Helper method to perform and add reduction in a single dimension.
HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) {
- HloInstruction* zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction* zero =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(hlo->shape().element_type()).CloneToUnique()));
HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation();
Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape());
return computation_->AddInstruction(HloInstruction::CreateReduce(
@@ -433,7 +428,15 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
}
// All copies can be eliminated (assuming layout constraints are satisified).
- ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0));
+ if (ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0))) {
+ return Status::OK();
+ }
+
+ if (is_layout_sensitive_ &&
+ ReshapeOrCopyIsBitcast(copy, valid_bitcast_callback_)) {
+ ReplaceWithBitcast(copy);
+ }
+
return Status::OK();
}
@@ -449,7 +452,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
// Filter out and remove empty operands.
std::vector<HloInstruction*> nonempty_operands;
for (HloInstruction* operand : operands) {
- if (!ShapeUtil::HasZeroElements(operand->shape())) {
+ if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
nonempty_operands.push_back(operand);
}
}
@@ -528,11 +531,15 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
constant, BuildTupleConstant(computation_, constant->literal()));
}
+ if (constant->shape().element_type() == TOKEN) {
+ return Status::OK();
+ }
+
// If a literal is all the same element replace it with a scalar broadcast.
if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
constant->literal().IsAllFirst()) {
- std::unique_ptr<Literal> unique_scalar =
- MakeUnique<Literal>(constant->literal().GetFirstScalarLiteral());
+ std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>(
+ LiteralUtil::GetFirstScalarLiteral(constant->literal()));
HloInstruction* scalar = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(unique_scalar)));
return ReplaceWithNewInstruction(
@@ -563,6 +570,14 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
return Status::OK();
}
+namespace {
+template <typename T>
+Status InvertConstant(const HloInstruction& constant, Literal* result) {
+ return result->Populate<T>([&](tensorflow::gtl::ArraySlice<int64> indices) {
+ return T{1.0} / constant.literal().Get<T>(indices);
+ });
+}
+} // namespace
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
Shape* shape;
@@ -624,14 +639,31 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
- HloInstruction* one =
- computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(a->shape().element_type()).CloneToUnique()));
- HloInstruction* inverse = computation_->AddInstruction(
- HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kMultiply, a, inverse));
+ Literal new_literal(b->shape());
+ switch (b->shape().element_type()) {
+ case F16:
+ TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
+ break;
+ case F32:
+ TF_RETURN_IF_ERROR(InvertConstant<float>(*b, &new_literal));
+ break;
+ case BF16:
+ TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*b, &new_literal));
+ break;
+ case F64:
+ TF_RETURN_IF_ERROR(InvertConstant<double>(*b, &new_literal));
+ break;
+ case C64:
+ TF_RETURN_IF_ERROR(InvertConstant<complex64>(*b, &new_literal));
+ break;
+ default:
+ return Status::OK();
+ }
+ auto inverse = computation_->AddInstruction(
+ HloInstruction::CreateConstant((new_literal.CloneToUnique())));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
+ return ReplaceInstruction(divide, new_divide);
}
// (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
@@ -651,18 +683,18 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
TF_ASSIGN_OR_RETURN(auto b_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, b, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a, b_times_c));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
+ return ReplaceInstruction(divide, new_divide);
}
// A / (B / C) => (A*C) / B
if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
TF_ASSIGN_OR_RETURN(auto a_times_c,
MakeBinaryHlo(HloOpcode::kMultiply, a, c));
- return ReplaceWithNewInstruction(
- divide, HloInstruction::CreateBinary(divide->shape(),
- HloOpcode::kDivide, a_times_c, b));
+ TF_ASSIGN_OR_RETURN(auto new_divide,
+ MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
+ return ReplaceInstruction(divide, new_divide);
}
return Status::OK();
@@ -1058,11 +1090,11 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
}
// Replace a zero element dot with a broadcast of the constant 0.
- if (ShapeUtil::HasZeroElements(dot->shape()) ||
- ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
+ ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
@@ -1221,9 +1253,10 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
switch (instruction->opcode()) {
case HloOpcode::kReshape:
case HloOpcode::kReverse:
- case HloOpcode::kSort:
case HloOpcode::kTranspose:
return true;
+ case HloOpcode::kSort:
+ return (!ShapeUtil::IsTuple(instruction->shape()));
default:
return false;
}
@@ -1392,7 +1425,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
- if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
@@ -1487,7 +1520,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
if (IsAll(rhs, 0)) {
auto one = HloInstruction::CreateConstant(
- Literal::One(power->shape().element_type()).CloneToUnique());
+ LiteralUtil::One(power->shape().element_type()).CloneToUnique());
std::unique_ptr<HloInstruction> ones;
if (ShapeUtil::IsScalar(power->shape())) {
ones = std::move(one);
@@ -1522,7 +1555,7 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
if (IsAll(rhs, -1)) {
auto* one = computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::One(rhs->shape().element_type()).CloneToUnique()));
+ LiteralUtil::One(rhs->shape().element_type()).CloneToUnique()));
// Explicitly broadcast scalar 1 to the output shape, to avoid implicit
// broadcast in divide HLO as we are trying to eliminate implicit
@@ -1638,7 +1671,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// Reshape directly to empty constant if the shape contains zero-element
// dimension.
- if (ShapeUtil::HasZeroElements(reshape->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
auto empty_constant = HloInstruction::CreateConstant(
Literal::CreateFromShape(reshape->shape()));
@@ -1672,7 +1705,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// Make this a bitcast if possible.
if (is_layout_sensitive_ &&
- ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
+ ReshapeOrCopyIsBitcast(reshape, valid_bitcast_callback_)) {
ReplaceWithBitcast(reshape);
return Status::OK();
}
@@ -1739,7 +1772,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
// If any dimension of update is 0, elide the DynamicUpdateSlice. This
// optimization becomes invalid should we later prefer to warn about out of
// bound indices.
- if (ShapeUtil::HasZeroElements(update->shape())) {
+ if (ShapeUtil::IsZeroElementArray(update->shape())) {
return ReplaceInstruction(dynamic_update_slice,
dynamic_update_slice->mutable_operand(0));
}
@@ -1751,8 +1784,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto init_value = reduce->mutable_operand(1);
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- if (ShapeUtil::HasZeroElements(arg->shape()) ||
- ShapeUtil::HasZeroElements(reduce->shape())) {
+ if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
+ ShapeUtil::IsZeroElementArray(reduce->shape())) {
return ReplaceWithNewInstruction(
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
@@ -1774,11 +1807,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
new_reduce_dimensions, function));
}
+ // If the reduction results in the same number of elements, then the only
+ // possible side effect would be a reshape. Since the init_value is an
+ // identity of the reduction function, we can therefore replace the reduce
+ // with a simple reshape, ignoring the reduction function completely.
if (ShapeUtil::ElementsIn(reduce->shape()) ==
ShapeUtil::ElementsIn(arg->shape())) {
return ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReshape(reduce->shape(), arg));
}
+
+ // If a reduce feeds a reduce with the same computation and initial value,
+ // they can be combined into a single reduce.
+ if (arg->opcode() == HloOpcode::kReduce &&
+ init_value->Identical(*arg->operand(1)) &&
+ *function == *arg->to_apply()) {
+ // Create a new reduce with the combined reduction dimensions of both
+ // reduces.
+ std::vector<int64> arg_dims = arg->dimensions();
+ std::sort(arg_dims.begin(), arg_dims.end());
+ std::vector<int64> reduce_dims = reduce->dimensions();
+ std::sort(reduce_dims.begin(), reduce_dims.end());
+ // Transform reduce_dims to the same rank as the operand of the operand.
+ for (int64 arg_dim : arg_dims) {
+ for (int64& dim : reduce_dims) {
+ if (dim >= arg_dim) {
+ ++dim;
+ }
+ }
+ }
+ std::vector<int64> new_dimensions;
+ new_dimensions.reserve(arg->dimensions().size() +
+ reduce->dimensions().size());
+ std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
+ reduce_dims.end(), std::back_inserter(new_dimensions));
+ return ReplaceWithNewInstruction(
+ reduce,
+ HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0),
+ init_value, new_dimensions, function));
+ }
+
// A reshape that collapses multiple dimensions into a dimension being
// reduced can just reduce all of those dimensions instead of doing a
// collapsing reshape before a reduction.
@@ -1828,7 +1896,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
- if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateBroadcast(reduce_window->shape(),
@@ -1842,7 +1910,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateMap(reduce_window->shape(),
- {operand, reduce_window->mutable_operand(1)},
+ {reduce_window->mutable_operand(1), operand},
function));
}
@@ -2024,16 +2092,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
auto lhs = convolution->mutable_operand(0);
auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
convolution->shape(),
- computation_->AddInstruction(HloInstruction::CreateConvert(
- ShapeUtil::MakeShape(convolution->shape().element_type(), {}),
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(convolution->shape().element_type())
+ .CloneToUnique())),
{}));
}
const auto& window = convolution->window();
@@ -2205,68 +2272,6 @@ Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
return ReplaceWithNewInstruction(map, std::move(clone));
}
-Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
- // Match the following tree:
- // min_operand operand
- // \ /
- // max_operand min
- // \ /
- // max
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* min;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMinimum, maximum,
- /*matching_operand=*/&min,
- /*other_operand=*/&max_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, min,
- /*matching_operand=*/&min_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
-Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
- // Match the following tree:
- // max_operand operand
- // \ /
- // min_operand max
- // \ /
- // min
- // where max_operand and min_operand are scalar constants.
- {
- HloInstruction* max;
- HloInstruction* max_operand;
- HloInstruction* min_operand;
- HloInstruction* operand;
-
- if (hlo_query::MatchBinaryInstructionOperandOpcode(
- HloOpcode::kMaximum, minimum,
- /*matching_operand=*/&max,
- /*other_operand=*/&min_operand) &&
- hlo_query::MatchBinaryInstructionOperand(
- hlo_query::IsScalarConstant, max,
- /*matching_operand=*/&max_operand,
- /*other_operand=*/&operand) &&
- TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
- max_operand)) {
- return Status::OK();
- }
- }
-
- return Status::OK();
-}
-
StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
XLA_VLOG_LINES(2,
"AlgebraicSimplifier::Run(), before:\n" + module->ToString());
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index cda157f9fa..92bbcbd740 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -60,7 +60,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
@@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
+// Test that Reduce(Reduce(A)) -> Reduce(A)
+TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
+ HloComputation::Builder builder(TestName());
+ // Create add computation.
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r4f32, "param"));
+ std::vector<int64> dims0({0});
+ Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7});
+ HloInstruction* reduce0 = builder.AddInstruction(
+ HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation));
+ std::vector<int64> dims1({1, 2});
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
+ builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero,
+ dims1, add_computation));
+ module().AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ HloInstruction* root = module().entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Reduce(param, zero));
+ EXPECT_EQ(root->dimensions(), std::vector<int64>({0, 2, 3}));
+}
+
// Test that Const + A is canonicalized to A + Const.
TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
@@ -81,7 +119,7 @@ TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
@@ -102,9 +140,9 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.14159f)));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
@@ -127,7 +165,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
HloInstruction* bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, zero, {0, 1}));
builder.AddInstruction(
@@ -162,9 +200,12 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- builder.AddInstruction(
- HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(HloInstruction::CreateMap(
+ r2f32,
+ {param0, builder.AddInstruction(
+ HloInstruction::CreateBroadcast(r2f32, zero, {}))},
+ add_computation));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -173,7 +214,7 @@ TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) {
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
root = computation->root_instruction();
- EXPECT_THAT(root, op::Add(param0, zero));
+ EXPECT_THAT(root, op::Add(param0, op::Broadcast(zero)));
}
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
@@ -182,7 +223,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0, 0, 0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0, 0, 0})));
HloInstruction* bcast =
builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1}));
builder.AddInstruction(
@@ -201,7 +242,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) {
TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14f, 3.14f, 3.14f})));
+ LiteralUtil::CreateR1<float>({3.14f, 3.14f, 3.14f})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -217,7 +258,7 @@ TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({3.14, 3.14, 4})));
+ LiteralUtil::CreateR1<float>({3.14, 3.14, 4})));
auto computation = module().AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
@@ -236,7 +277,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
@@ -257,7 +298,7 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kSubtract, param0, constant));
@@ -329,17 +370,16 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) {
// Test that (A/B)/(C/D) is simplified to (A*D)/(B*C).
TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r2f32 = ShapeUtil::MakeShape(F32, {42, 123});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r2f32, "param1"));
HloInstruction* param2 = builder.AddInstruction(
HloInstruction::CreateParameter(2, r2f32, "param2"));
HloInstruction* param3 = builder.AddInstruction(
- HloInstruction::CreateParameter(3, r0f32, "param3"));
+ HloInstruction::CreateParameter(3, r2f32, "param3"));
HloInstruction* div0 = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, param1));
HloInstruction* div1 = builder.AddInstruction(
@@ -360,8 +400,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) {
EXPECT_THAT(
computation->root_instruction(),
op::Divide(op::Multiply(param0, param3), op::Multiply(param1, param2)));
- EXPECT_TRUE(
- ShapeUtil::Compatible(computation->root_instruction()->shape(), r2f32));
}
// Test that A/exp(B) is simplified to A*exp(-B).
@@ -421,7 +459,6 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) {
// Test that broadcasting is done on the right step when simplifying A/pow(B,C)
// to A*pow(B,-C).
TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
HloComputation::Builder builder(TestName());
HloInstruction* param0 = builder.AddInstruction(
@@ -429,7 +466,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* param2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction::CreateParameter(2, r1f32, "param2"));
HloInstruction* power = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2));
builder.AddInstruction(
@@ -446,14 +483,9 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) {
ASSERT_THAT(computation->root_instruction(),
op::Multiply(param0, op::Power(param1, op::Negate(param2))));
-
- const HloInstruction* negate =
- computation->root_instruction()->operand(1)->operand(1);
- const Shape& negate_shape = negate->shape();
- EXPECT_EQ(0, negate_shape.dimensions_size());
}
-// A / Const => A * (1 / Const)
+// A / Const => A * InvertedConst
TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
Shape r1f32 = ShapeUtil::MakeShape(F32, {3});
HloComputation::Builder builder(TestName());
@@ -461,7 +493,7 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 1.f, 2.f})));
+ LiteralUtil::CreateR1<float>({0.f, 1.f, 2.f})));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide,
param0, constant));
@@ -472,20 +504,19 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) {
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
- op::Multiply(param0, op::Divide(op::Constant(), constant)));
+ op::Multiply(param0, op::Constant()));
}
// pow(pow(A, X), Y) => pow(A, X*Y)
TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
Shape r1f32 = ShapeUtil::MakeShape(F32, {7});
HloComputation::Builder builder(TestName());
HloInstruction* base = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* exp2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction::CreateParameter(2, r1f32, "param2"));
HloInstruction* inner_power = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1));
builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower,
@@ -502,15 +533,14 @@ TEST_F(AlgebraicSimplifierTest, PowerOfPower) {
// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex
// numbers.
TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) {
- Shape r0c64 = ShapeUtil::MakeShape(C64, {});
Shape r1c64 = ShapeUtil::MakeShape(C64, {7});
HloComputation::Builder builder(TestName());
HloInstruction* base = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1c64, "param0"));
HloInstruction* exp1 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0c64, "param1"));
+ HloInstruction::CreateParameter(1, r1c64, "param1"));
HloInstruction* exp2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0c64, "param2"));
+ HloInstruction::CreateParameter(2, r1c64, "param2"));
HloInstruction* inner_power = builder.AddInstruction(
HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1));
builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower,
@@ -529,7 +559,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
@@ -550,7 +580,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r2f32, "param0"));
HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
HloInstruction* div = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
@@ -830,7 +860,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
@@ -854,7 +884,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
@@ -882,7 +912,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
@@ -904,7 +934,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
@@ -926,7 +956,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param0"));
HloInstruction* negative_one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
param0, negative_one));
@@ -1017,7 +1047,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) {
builder.AddInstruction(HloInstruction::CreateReduceWindow(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))),
window, add_computation));
module().AddEntryComputation(builder.Build());
HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
@@ -1044,7 +1074,7 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) {
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {5, 2}), param,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
module().AddEntryComputation(builder.Build());
EXPECT_THAT(module().entry_computation()->root_instruction(),
@@ -1086,7 +1116,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@@ -1121,6 +1151,33 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
EXPECT_THAT(computation->root_instruction(), param0);
}
+TEST_F(AlgebraicSimplifierTest, CopyEqualsBitcast) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), "param"));
+ *param->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3});
+ HloInstruction* copy = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {1, 14, 14, 64}), HloOpcode::kCopy, param));
+ *copy->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({1, 2, 0, 3});
+ auto computation = module().AddEntryComputation(builder.Build());
+ EXPECT_THAT(computation->root_instruction(), op::Copy(param));
+
+ AlgebraicSimplifier simplifier1(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ ASSERT_FALSE(simplifier1.Run(&module()).ValueOrDie());
+ // Verify that the copy is not replaced.
+ EXPECT_THAT(computation->root_instruction(), op::Copy(param));
+
+ AlgebraicSimplifier simplifier2(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier2.Run(&module()).ValueOrDie());
+ // Verify that the copy is replaced.
+ EXPECT_THAT(computation->root_instruction(), op::Bitcast(param));
+}
+
// Test that unary concatenates are removed.
TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) {
Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
@@ -1151,7 +1208,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
HloInstruction* param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r1f32, "param1"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1}));
@@ -1181,7 +1238,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "param0"));
HloInstruction* empty_literal = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({})));
HloInstruction* empty_slice =
builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1}));
@@ -1351,33 +1408,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape));
}
-// Regression test for a bug in the reshape sinking transformation, where
-// moving a reshape to a scalar led to a crash.
-TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) {
- HloComputation::Builder builder(TestName());
- HloInstruction* param =
- builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {1, 1}), "param"));
- HloInstruction* reshape = builder.AddInstruction(
- HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {}), param));
- HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1., 2., 3.})));
- builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero));
- auto computation = module().AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Reshape(param), zero));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- bitcasting_callback());
-
- simplifier.Run(&module()).ValueOrDie();
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Reshape(param), zero));
-}
-
// Regression test for a bug where if we failed to sink a reshape, we'd set the
// 'changed' bit in AlgebraicSimplifier to false.
TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
@@ -1390,7 +1420,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add));
@@ -1413,7 +1443,7 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) {
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0")),
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0, 0}, {0, 0}})))));
+ LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}})))));
builder.AddInstruction(
HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add,
@@ -1696,7 +1726,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
@@ -1714,7 +1744,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1727,7 +1757,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {10, 10}), "param"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
@@ -1759,7 +1789,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1781,7 +1811,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1804,7 +1834,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1932,7 +1962,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@@ -2037,160 +2068,6 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
EXPECT_EQ("NO_CHANGE", build_and_simplify());
}
-// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
-TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMinimum, param0, min_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Maximum(op::Minimum(param0, min_value), max_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
-// values.
-TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
-// broadcasted scalar values.
-TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r1f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r1f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Clamp(max_value, param0, min_value));
-}
-
-// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
-// clamp(non-constant1, A, non-constant2)
-TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateParameter(1, r0f32, "param1"));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateParameter(2, r0f32, "param2"));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Maximum(param0, max_value), min_value));
-}
-
-// Test that min(f(max(A, constant1)), constant2) is not transformed to
-// clamp(constant1, A, constant2)
-TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
- Shape r0f32 = ShapeUtil::MakeShape(F32, {});
- HloComputation::Builder builder(TestName());
- HloInstruction* param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, r0f32, "param0"));
- HloInstruction* min_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
- HloInstruction* max_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
- HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMaximum, param0, max_value));
- HloInstruction* fmax = builder.AddInstruction(
- HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
- builder.AddInstruction(HloInstruction::CreateBinary(
- r0f32, HloOpcode::kMinimum, fmax, min_value));
-
- auto module = CreateNewModule();
- auto computation = module->AddEntryComputation(builder.Build());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
- min_value));
-
- AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
- non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
-
- EXPECT_THAT(computation->root_instruction(),
- op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
- min_value));
-}
-
// Test that slice(broadcast(/*scalar value*/)) simplifies to a single
// broadcast.
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
@@ -2200,10 +2077,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, scalar_param,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
@@ -2219,10 +2094,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2234,13 +2109,11 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloComputation::Builder builder(TestName());
HloInstruction* forty_two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, forty_two,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
HloInstruction* transpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -2259,7 +2132,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2268,7 +2141,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2282,7 +2156,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding));
@@ -2313,7 +2187,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, pad, reduce_init_value, window,
@@ -2349,7 +2223,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2363,7 +2238,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
padding.mutable_dimensions(3)->set_edge_padding_high(2);
HloInstruction* pad_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
@@ -2398,7 +2273,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
const Shape reduce_window_shape =
ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
HloInstruction* reduce_init_value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0f)));
HloInstruction* reduce_window =
builder.AddInstruction(HloInstruction::CreateReduceWindow(
reduce_window_shape, convert, reduce_init_value, window,
@@ -2444,7 +2319,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
@@ -2469,9 +2344,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
HloComputation::Builder call_builder(TestName() + ".Call");
HloInstruction* zero = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({0.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({0.0f})));
HloInstruction* one = call_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0f})));
call_builder.AddInstruction(
HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get()));
@@ -2487,9 +2362,9 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) {
HloComputation::Builder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
- std::unique_ptr<Literal> value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get()});
+ std::unique_ptr<Literal> value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
builder.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2512,8 +2387,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) {
shape,
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "slice_from")),
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0}))),
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int>({0, 0, 0}))),
/*slice_sizes=*/{10, 100, 1000}));
auto computation = module().AddEntryComputation(builder.Build());
@@ -2546,8 +2421,8 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) {
builder.AddInstruction(
HloInstruction::CreateParameter(2, slice_shape, "to_update")),
slice,
- builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0, 0, 0})))));
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int>({0, 0, 0})))));
auto computation = module().AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
@@ -2562,7 +2437,7 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts) {
HloComputation::Builder builder(TestName());
Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* input_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({3, 4})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({3, 4})));
HloInstruction* inner_bcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(r2f32, input_array, {1}));
Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2});
@@ -2671,7 +2546,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
pad_shape, input,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0.0f))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))),
padding));
HloComputation* add_computation = nullptr;
@@ -2690,7 +2565,7 @@ TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) {
Window window = window_util::MakeWindow(
decorate_spatials(param.reduce_window_spatials, 1, 1));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape,
ShapeInference::InferReduceWindowShape(
pad->shape(), zero->shape(), window,
@@ -2829,7 +2704,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) {
Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k)));
Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n});
@@ -2908,7 +2783,7 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) {
Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n)));
DotDimensionNumbers dot_dnums;
@@ -2955,7 +2830,7 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) {
HloInstruction* const update = builder.AddInstruction(
HloInstruction::CreateParameter(1, update_shape, "update"));
HloInstruction* const start_indices = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int>({0})));
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
dslice_shape, operand, update, start_indices));
const HloComputation* const computation =
@@ -3004,7 +2879,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase);
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -3012,7 +2887,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int32 start_col = (spec.lcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.lcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
@@ -3023,7 +2898,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k;
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -3071,7 +2946,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k;
Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols});
auto* lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows,
/*cols=*/lhs_cols)));
@@ -3082,7 +2957,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase);
Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols});
auto* rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows,
/*cols=*/rhs_cols)));
@@ -3090,7 +2965,7 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) {
int32 start_col = (spec.rcd == 0) ? spec.s : 0;
const auto start_indices =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<int32>({start_row, start_col})));
+ LiteralUtil::CreateR1<int32>({start_row, start_col})));
int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1;
int64 slice_col_size = (spec.rcd == 0) ? 1 : spec.k;
Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size});
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 598718c72c..c4cd60c120 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -41,6 +43,8 @@ namespace xla {
namespace {
+using tensorflow::gtl::optional;
+
// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm
// operations into smaller operations.
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
@@ -58,8 +62,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
// Runs the visitor on a computation.
static bool Run(HloComputation* computation, bool rewrite_training_op,
- bool rewrite_inference_op, bool rewrite_grad_op,
- bool use_fusion);
+ bool rewrite_inference_op, bool rewrite_grad_op);
// Returns whether any batch norm ops were rewritten.
const bool changed() const { return changed_; }
@@ -70,21 +73,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
explicit BatchNormExpanderVisitor(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op, bool use_fusion)
+ bool rewrite_grad_op)
: computation_(computation),
rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_fusion_(use_fusion) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
HloComputation* GetOrCreateScalarAddComputation(
PrimitiveType primitive_type) {
- HloComputation** scalar_add_computation =
- &scalar_add_computations_[primitive_type];
- if (*scalar_add_computation) {
- return *scalar_add_computation;
- }
-
HloComputation::Builder b("scalar_add_computation");
Shape shape = ShapeUtil::MakeShape(primitive_type, {});
auto scalar_lhs = b.AddInstruction(
@@ -93,71 +89,38 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
- *scalar_add_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_add_computation;
+ return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
}
- // TODO(b/80534766): Remove maps after performance issues with scalar
- // broadcasts are resolved on all backends.
- HloComputation* GetOrCreateScalarRsqrtComputation(
- PrimitiveType primitive_type) {
- HloComputation** scalar_rsqrt_computation =
- &scalar_rsqrt_computations_[primitive_type];
- if (*scalar_rsqrt_computation) {
- return *scalar_rsqrt_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<float>(-0.5f)))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kPower, scalar_lhs, scalar_rhs));
- *scalar_rsqrt_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_rsqrt_computation;
+ std::unique_ptr<HloInstruction> Rsqrt(
+ HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ add_instruction) {
+ HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast(
+ operand->shape(),
+ add_instruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(operand->shape().element_type(), {}),
+ add_instruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(-0.5f))))),
+ {}));
+ return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower,
+ operand, exponent);
}
- std::unique_ptr<HloInstruction> Rsqrt(HloInstruction* operand) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarRsqrtComputation(operand->shape().element_type()));
- }
-
- HloComputation* GetOrCreateScalarMeanComputation(PrimitiveType primitive_type,
- int64 element_count) {
- HloComputation** scalar_mean_computation =
- &scalar_mean_computations_[std::pair<PrimitiveType, int64>(
- primitive_type, element_count)];
- if (*scalar_mean_computation) {
- return *scalar_mean_computation;
- }
-
- HloComputation::Builder b("scalar_add_computation");
- Shape shape = ShapeUtil::MakeShape(primitive_type, {});
- auto scalar_lhs = b.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
- auto scalar_rhs = b.AddInstruction(HloInstruction::CreateConvert(
- shape, b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(
- 1.0f / static_cast<float>(element_count))))));
- auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kMultiply, scalar_lhs, scalar_rhs));
- *scalar_mean_computation =
- computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
- return *scalar_mean_computation;
- }
-
- std::unique_ptr<HloInstruction> Mean(int64 element_count,
- HloInstruction* operand) {
- return HloInstruction::CreateMap(
- operand->shape(), {operand},
- GetOrCreateScalarMeanComputation(operand->shape().element_type(),
- element_count));
+ std::unique_ptr<HloInstruction> Mean(
+ int64 element_count, HloInstruction* operand,
+ const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
+ add_instruction) {
+ HloInstruction* elem_count_recip =
+ add_instruction(HloInstruction::CreateBroadcast(
+ operand->shape(),
+ add_instruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(operand->shape().element_type(), {}),
+ add_instruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(1.0 / element_count))))),
+ {}));
+ return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply,
+ operand, elem_count_recip);
}
// Replaces the existing HLO instruction old_instruction, with
@@ -189,18 +152,9 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_fusion_;
// Whether rewrite has occurred.
bool changed_ = false;
-
- // Cached computations for adding two scalars.
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_add_computations_;
- tensorflow::gtl::FlatMap<PrimitiveType, HloComputation*>
- scalar_rsqrt_computations_;
- tensorflow::gtl::FlatMap<std::pair<PrimitiveType, int64>, HloComputation*>
- scalar_mean_computations_;
};
} // namespace
@@ -208,13 +162,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault {
bool BatchNormExpanderVisitor::Run(HloComputation* computation,
bool rewrite_training_op,
bool rewrite_inference_op,
- bool rewrite_grad_op, bool use_fusion) {
+ bool rewrite_grad_op) {
BatchNormExpanderVisitor visitor(
computation,
/*rewrite_training_op=*/rewrite_training_op,
/*rewrite_inference_op=*/rewrite_inference_op,
- /*rewrite_grad_op=*/rewrite_grad_op,
- /*use_fusion=*/use_fusion);
+ /*rewrite_grad_op=*/rewrite_grad_op);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
@@ -251,11 +204,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
HloInstruction* offset = batch_norm->mutable_operand(2);
const Shape feature_shape = scale->shape();
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = add(HloInstruction::CreateBroadcast(
operand_shape,
@@ -290,28 +243,14 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
feature_shape, operand_squared, zero, dimensions_without_feature,
add_reduce_computation));
- // Fuse two parallel reduces together to improve performance.
- if (use_fusion_ && !batch_norm->has_sharding()) {
- auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));
-
- auto fused = computation_->CreateFusionInstruction(
- {tuple, sum, squared_sum, operand_squared},
- HloInstruction::FusionKind::kInput);
-
- sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
-
- squared_sum =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
- }
-
// E[X].
- auto mean = add(Mean(elements_per_feature_int64, sum));
+ auto mean = add(Mean(elements_per_feature_int64, sum, add));
auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
// E[X^2].
- auto square_mean = add(Mean(elements_per_feature_int64, squared_sum));
+ auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add));
// E^2[X].
auto mean_square =
@@ -329,7 +268,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
@@ -353,16 +292,22 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining(
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
+ const HloSharding& sharding = batch_norm->sharding();
HloSharding operand_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
inst->set_sharding(operand_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
@@ -385,7 +330,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast(
operand_shape,
@@ -431,7 +376,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon);
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon));
+ auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add));
// X - E[X].
auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract,
@@ -453,14 +398,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference(
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
+ optional<int64> unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
- inst->set_sharding(batch_norm->sharding());
+ inst->set_sharding(sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- shifted_normalized->set_sharding(batch_norm->sharding());
+ shifted_normalized->set_sharding(sharding);
}
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
@@ -512,11 +463,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
const int64 feature_count = activation_shape.dimensions(feature_index);
const int64 elements_per_feature_int64 = size_in_elements / feature_count;
- auto zero_literal = Literal::CreateR0(0.0f);
+ auto zero_literal = LiteralUtil::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
- auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
+ auto epsilon_literal = LiteralUtil::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon_scalar =
add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
@@ -545,10 +496,12 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
// rsqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon_broadcasted =
add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd,
- variance_broadcasted, epsilon_activation)));
+ variance_broadcasted, epsilon_activation),
+ add));
auto rsqrt_var_add_epsilon = add(Rsqrt(
- add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature)));
+ add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature),
+ add));
// X - E[X].
auto activation_minus_mean = add_binary(
@@ -573,21 +526,6 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
feature_shape, grad_output, zero, dimensions_without_feature,
add_reduce_computation));
- if (use_fusion_ && !batch_norm->has_sharding()) {
- auto tuple = add(HloInstruction::CreateTuple(
- {sum_grad_output_times_activiation_minus_mean, grad_beta}));
-
- auto fused = computation_->CreateFusionInstruction(
- {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta},
- HloInstruction::FusionKind::kInput);
-
- sum_grad_output_times_activiation_minus_mean =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
-
- grad_beta =
- add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
- }
-
// Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply,
sum_grad_output_times_activiation_minus_mean,
@@ -616,11 +554,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted,
rsqrt_var_add_epsilon_broadcasted);
- scale_times_rsqrt_var_add_epsilon =
- add(Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon));
+ scale_times_rsqrt_var_add_epsilon = add(
+ Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add));
auto elements_per_feature_literal =
- Literal::CreateR0<float>(elements_per_feature_int64);
+ LiteralUtil::CreateR0<float>(elements_per_feature_int64);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = add(
@@ -640,19 +578,25 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad(
auto tuple =
HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
if (batch_norm->has_sharding()) {
+ const HloSharding& sharding = batch_norm->sharding();
int64 instruction_count_after = computation_->instruction_count();
CHECK_EQ(instruction_count_after,
instruction_count_before + added_instructions.size());
HloSharding activation_sharding =
- batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0});
+ sharding.GetAsShapeTree(batch_norm->shape()).element({0});
+ auto unique_device = batch_norm->sharding_unique_device();
+ HloSharding default_sharding =
+ unique_device.has_value()
+ ? HloSharding::AssignDevice(unique_device.value())
+ : HloSharding::Replicate();
for (HloInstruction* inst : added_instructions) {
if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
inst->set_sharding(activation_sharding);
} else {
- inst->set_sharding(HloSharding::Replicate());
+ inst->set_sharding(default_sharding);
}
}
- tuple->set_sharding(batch_norm->sharding());
+ tuple->set_sharding(sharding);
}
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
@@ -665,8 +609,8 @@ StatusOr<bool> BatchNormExpander::Run(HloModule* module) {
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_,
- rewrite_inference_op_, rewrite_grad_op_,
- use_fusion_)) {
+ rewrite_inference_op_,
+ rewrite_grad_op_)) {
changed = true;
}
}
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 4ad987085d..7ae202c583 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface {
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
bool rewrite_inference_op = false,
- bool rewrite_grad_op = false, bool use_fusion = true)
+ bool rewrite_grad_op = false)
: rewrite_training_op_(rewrite_training_op),
rewrite_inference_op_(rewrite_inference_op),
- rewrite_grad_op_(rewrite_grad_op),
- use_fusion_(use_fusion) {}
+ rewrite_grad_op_(rewrite_grad_op) {}
~BatchNormExpander() = default;
tensorflow::StringPiece name() const override { return "batchnorm_expander"; }
@@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface {
bool rewrite_training_op_;
bool rewrite_inference_op_;
bool rewrite_grad_op_;
- bool use_fusion_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aa36e64b07..32f785a70a 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -19,12 +19,13 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -114,5 +115,33 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
}
+TEST_F(BatchNormExpanderTest, BatchNormTrainingSharding) {
+ const char* module_str = R"(
+HloModule module
+ENTRY entry {
+ %param.0 = f32[8,4] parameter(0)
+ %param.1 = f32[4] parameter(1)
+ %param.2 = f32[4] parameter(2)
+ ROOT %batch-norm-training = (f32[8,4], f32[4], f32[4])
+ batch-norm-training(f32[8,4] %param.0, f32[4] %param.1, f32[4] %param.2),
+ epsilon=0.001, feature_index=1, sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ BatchNormExpander rewriter(/*rewrite_training_op=*/true,
+ /*rewrite_inference_op=*/true,
+ /*rewrite_grad_op=*/true);
+ ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+
+ for (auto* instruction : module->entry_computation()->instructions()) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ ASSERT_TRUE(instruction->has_sharding());
+ TF_ASSERT_OK_AND_ASSIGN(int device, instruction->sharding().UniqueDevice());
+ EXPECT_EQ(device, 1);
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 28e71c2054..f7b4c1405d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto builder = HloComputation::Builder(TestName());
+
+ auto module = CreateNewModule();
+ HloComputation::Builder sum_builder("add");
+ auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+ auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+ sum_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+ HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());
+
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
- ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}));
+ ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
+ sum, /*replica_group_ids=*/{}, /*barrier=*/""));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
@@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({gte_a, convert_gte_b}));
- auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(FoldConversions(module.get()));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 1afaefd9df..830f26422b 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
+ auto module = CreateNewModule();
+ HloComputation::Builder sum_builder("sum");
+ auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+ auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+ sum_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+ HloComputation* reduction =
+ module->AddEmbeddedComputation(sum_builder.Build());
+
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
- ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
+ ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
+ /*replica_group_ids=*/{}, /*barrier=*/""));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
- auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index ed0746980f..b21c83a07f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -85,9 +85,9 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
- for (const auto& index : root_changes_it->second) {
+ for (const auto& entry : root_changes_it->second) {
for (const HloValue* value :
- dataflow_->GetValueSet(root, index).values()) {
+ dataflow_->GetValueSet(root, entry.second).values()) {
changed_root_buffers.insert(value);
}
}
@@ -204,6 +204,12 @@ void BFloat16Propagation::DetermineWhileComputationsPrecision(
bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
const ShapeIndex& index) const {
+ // If the subshape isn't floating point then none of the users will be BF16.
+ const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
+ if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
+ return false;
+ }
+
auto& value_set = dataflow_->GetValueSet(&hlo, index);
for (const HloValue* value : value_set.values()) {
if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
@@ -257,23 +263,34 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
// If the op propagates precision and it outputs a BF16, then it's OK to
// supply BF16 also as the input. In the backward pass, the users shapes
// should have already been processed.
- PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID;
- if (use.instruction->opcode() == HloOpcode::kTuple ||
- (use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
- ShapeUtil::IsTuple(use.instruction->shape()))) {
- ShapeIndex use_output_index{use.operand_number};
- for (int64 i : use.operand_index) {
- use_output_index.push_back(i);
- }
- user_output_type =
- OutputTypeAfterChange(use.instruction, use_output_index);
- } else {
- user_output_type = OutputTypeAfterChange(use.instruction, {});
- }
if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
- *use.instruction, use.operand_number) &&
- user_output_type == BF16) {
- continue;
+ *use.instruction, use.operand_number)) {
+ if (use.instruction->opcode() == HloOpcode::kTuple ||
+ (use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
+ ShapeUtil::IsTuple(use.instruction->shape()))) {
+ ShapeIndex use_output_index{use.operand_number};
+ for (int64 i : use.operand_index) {
+ use_output_index.push_back(i);
+ }
+ if (OutputTypeAfterChange(use.instruction, use_output_index) ==
+ BF16) {
+ continue;
+ }
+ } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
+ ShapeIndex use_output_index;
+ for (int64 i = 1; i < use.operand_index.size(); ++i) {
+ use_output_index.push_back(use.operand_index[i]);
+ }
+ if (OutputTypeAfterChange(use.instruction, use_output_index) ==
+ BF16) {
+ continue;
+ }
+ } else {
+ if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
+ BF16) {
+ continue;
+ }
+ }
}
return false;
}
@@ -368,6 +385,7 @@ bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
hlo->opcode() != HloOpcode::kTuple &&
hlo->opcode() != HloOpcode::kGetTupleElement &&
+ hlo->opcode() != HloOpcode::kDomain &&
hlo->shape().element_type() != BF16) {
for (int64 i = 0; i < hlo->operand_count(); ++i) {
if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
@@ -559,7 +577,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
- std::list<HloComputation*> computations_topological_order =
+ const auto& computations_topological_order =
module->MakeComputationPostOrder();
tensorflow::gtl::FlatSet<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
@@ -597,7 +615,6 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
// (1) a is F32 but tuple is BF16
// (2) after adding conversion
// (3) after tuple simplifier and DCE.
- bool needs_tuple_simplifier = false;
for (auto computation : module->MakeComputationPostOrder()) {
auto insts = computation->MakeInstructionPostOrder();
for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
@@ -611,67 +628,25 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
continue;
}
ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
- // Iterate through nodes in the shape tree in pre-order and initialize
- // each non-root node with a corresponding get-tuple-element. For a leaf
- // node, if its shape does not match the fusion output, create a
- // conversion node to overwrite the node value.
- for (auto it = converted_outputs.begin(); it != converted_outputs.end();
- ++it) {
- ShapeIndex output_index = it->first;
- HloInstruction*& output = it->second;
- const Shape subshape =
- ShapeUtil::GetSubshape(hlo->shape(), output_index);
- if (output_index.empty()) {
- output = fusion_root;
- } else {
- ShapeIndex parent_index = output_index;
- parent_index.pop_back();
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- subshape, converted_outputs.element(parent_index),
- output_index.back()));
- }
- if (ShapeUtil::IsTuple(subshape)) {
- continue;
- }
- if (!ShapeUtil::Compatible(
- subshape,
- ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) {
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateConvert(subshape, output));
- }
- }
- // Iterate through nodes in the shape tree in reverse pre-order and create
- // a tuple instruction for each non-leaf node where the elements are the
- // values of its child nodes.
- for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend();
- ++it) {
- ShapeIndex output_index = it->first;
- HloInstruction*& output = it->second;
- const Shape& subshape =
- ShapeUtil::GetSubshape(hlo->shape(), output_index);
- if (!ShapeUtil::IsTuple(subshape)) {
- continue;
- }
- std::vector<HloInstruction*> elements(
- ShapeUtil::TupleElementCount(subshape));
- ShapeIndex child_index = output_index;
- for (int64 i = 0; i < elements.size(); ++i) {
- child_index.push_back(i);
- elements[i] = converted_outputs.element(child_index);
- child_index.pop_back();
- }
- output = fusion_computation->AddInstruction(
- HloInstruction::CreateTuple(elements));
- }
- fusion_computation->set_root_instruction(converted_outputs.element({}));
- needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape());
+ // Deep copy the fusion root, and convert a leaf node only if its shape
+ // does not match the fusion output.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * copy,
+ fusion_computation->DeepCopyInstructionWithCustomCopier(
+ fusion_root,
+ [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* comp) {
+ const Shape& hlo_subshape =
+ ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
+ if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
+ return leaf;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreateConvert(hlo_subshape, leaf));
+ }));
+ fusion_computation->set_root_instruction(copy);
}
}
- if (needs_tuple_simplifier) {
- TupleSimplifier tuple_simplifier;
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- }
return Status::OK();
}
@@ -740,10 +715,38 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
changes_to_bf16_.clear();
changed_ = false;
+ auto computations_topological_order = module->MakeComputationPostOrder();
+
+ // Before running the propagation pass, we insert copies (kConvert to the same
+ // type) of F32 inputs to while loops. This prevents other uses of the same
+ // input from aliasing the while loop input/output, so that there's greater
+ // chance to use BF16 inside the loop. If some of these added copies do not
+ // help, they will remain F32 after BF16 propagation and will be removed since
+ // they are no-ops.
+ for (auto computation : computations_topological_order) {
+ for (auto inst : computation->MakeInstructionPostOrder()) {
+ if (inst->opcode() != HloOpcode::kWhile) {
+ continue;
+ }
+
+ auto operand = inst->mutable_operand(0);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * copy,
+ computation->DeepCopyInstructionWithCustomCopier(
+ operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* comp) {
+ if (leaf->shape().element_type() != F32) {
+ return leaf;
+ }
+ return comp->AddInstruction(
+ HloInstruction::CreateConvert(leaf->shape(), leaf));
+ }));
+ TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
+ }
+ }
+
TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
- std::list<HloComputation*> computations_topological_order =
- module->MakeComputationPostOrder();
// The first step is a forward pass (parameters to root), where we determine
// the potential candidate instructions to use bfloat16 in the outputs that
// are not likely to cause overhead from extra explicit conversions. This is
@@ -784,39 +787,42 @@ StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
// Apply the changes in changes_to_bf16_.
for (auto& change : changes_to_bf16_) {
- auto shape = change.first->mutable_shape();
- for (const auto& index : change.second) {
- auto subshape = ShapeUtil::GetMutableSubshape(shape, index);
+ for (const auto& entry : change.second) {
+ auto subshape = entry.first;
CHECK_EQ(subshape->element_type(), F32);
subshape->set_element_type(BF16);
changed_ = true;
}
}
+ // Removes redundant HLOs added by this pass, either when inserting
+ // de-aliasing copies to while loop inputs, or later when converting output
+ // types.
+ auto clean_up = [this, module]() {
+ TF_RETURN_IF_ERROR(SkipNoopConversions(module));
+ TupleSimplifier tuple_simplifier;
+ TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
+ HloDCE dce;
+ TF_RETURN_IF_ERROR(dce.Run(module).status());
+ return Status::OK();
+ };
+
if (!changed_) {
+ TF_RETURN_IF_ERROR(clean_up());
return false;
}
TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module));
TF_RETURN_IF_ERROR(ResolveConvertedConstants(module));
- // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 ->
- // BF16), so we skip them now.
- TF_RETURN_IF_ERROR(SkipNoopConversions(module));
-
- {
- // We may have dead HLOs after ResolveInconsistentFusions,
- // ResolveConvertedConstants and SkipNoopConversions.
- HloDCE dce;
- TF_RETURN_IF_ERROR(dce.Run(module).status());
- }
+ TF_RETURN_IF_ERROR(clean_up());
return true;
}
PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
HloInstruction* hlo, const ShapeIndex& index) const {
- PrimitiveType type_on_hlo =
- ShapeUtil::GetSubshape(hlo->shape(), index).element_type();
+ Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
+ const PrimitiveType type_on_hlo = subshape->element_type();
if (type_on_hlo != F32) {
return type_on_hlo;
}
@@ -824,7 +830,7 @@ PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
if (it == changes_to_bf16_.end()) {
return type_on_hlo;
}
- return ContainsKey(it->second, index) ? BF16 : F32;
+ return ContainsKey(it->second, subshape) ? BF16 : F32;
}
PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
@@ -838,14 +844,16 @@ void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
if (target_type == BF16) {
auto& entry = changes_to_bf16_[hlo];
- entry.insert(index);
+ entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
+ index);
} else {
CHECK_EQ(target_type, F32);
auto it = changes_to_bf16_.find(hlo);
if (it == changes_to_bf16_.end()) {
return;
}
- it->second.erase(index);
+ it->second.erase(
+ ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
}
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index de0355ddfc..02b8cad089 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -194,17 +194,11 @@ class BFloat16Propagation : public HloPassInterface {
// are subject to further adjustment, then finally applied to the HLOs. This
// avoids setting changed_ to true but all changes are reverted during
// adjustment.
- struct IndexHasher {
- int64 operator()(const ShapeIndex& index) const {
- int64 hash = 0;
- for (int64 i : index) {
- hash = tensorflow::Hash64Combine(hash, std::hash<int64>()(i));
- }
- return hash;
- }
- };
+ //
+ // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
+ // the output as a map from in-place pointers to subshapes to shape indices.
tensorflow::gtl::FlatMap<HloInstruction*,
- tensorflow::gtl::FlatSet<ShapeIndex, IndexHasher>>
+ tensorflow::gtl::FlatMap<Shape*, ShapeIndex>>
changes_to_bf16_;
// Whether the last processed HLO module has been changed by this pass.
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 5e1499ee6b..aeafb25ad7 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -133,9 +133,9 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
array_b.FillUnique(10.0f);
HloInstruction* a = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateFromArray(array_a)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_a)));
HloInstruction* b = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateFromArray(array_b)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateFromArray(array_b)));
HloInstruction* dot = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kDot, a, b));
@@ -150,11 +150,11 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
EXPECT_TRUE(LiteralTestUtil::Equal(
- dot->operand(0)->literal(),
- *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
+ *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_a)),
+ dot->operand(0)->literal()));
EXPECT_TRUE(LiteralTestUtil::Equal(
- dot->operand(1)->literal(),
- *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
+ *LiteralUtil::ConvertF32ToBF16(*LiteralUtil::CreateFromArray(array_b)),
+ dot->operand(1)->literal()));
}
// Tests that BF16 can be propagated through nested tuples.
@@ -240,12 +240,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), dot);
- EXPECT_TRUE(OutputsBF16(add0));
EXPECT_TRUE(OutputsBF16(add1));
EXPECT_TRUE(OutputsBF16(lhs));
- // rhs is a get-tuple-element, which does not define a buffer, but its shape
- // should also be adjusted accordingly.
- EXPECT_TRUE(OutputsBF16(rhs));
+
+ // add0 and rhs have been eliminated by simplification and DCE.
}
// Tests that a non-fusion computation's root should not be changed.
@@ -434,7 +432,7 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) {
HloInstruction* tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({param, add1}));
HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ tuple0->shape(), HloOpcode::kTupleSelect, pred, tuple0, tuple1));
HloInstruction* gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(shape, sel, 0));
HloInstruction* gte1 = builder.AddInstruction(
@@ -734,12 +732,95 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), add2);
- EXPECT_EQ(add2->operand(0), gte0);
- EXPECT_EQ(add2->operand(1), gte1);
- EXPECT_EQ(gte0->shape().element_type(), BF16);
- EXPECT_EQ(gte1->shape().element_type(), BF16);
+ EXPECT_EQ(add2->operand(0), add0);
+ EXPECT_EQ(add2->operand(1), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
EXPECT_EQ(add1->shape().element_type(), BF16);
}
+TEST_F(BFloat16PropagationTest, TupleDomain) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+
+ HloInstruction* a =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* a_trans =
+ builder.AddInstruction(HloInstruction::CreateTranspose(shape, a, {0, 1}));
+ HloInstruction* b_trans =
+ builder.AddInstruction(HloInstruction::CreateTranspose(shape, b, {0, 1}));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({a_trans, b_trans}));
+ HloInstruction* domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(tuple->shape(), tuple, nullptr, nullptr));
+ HloInstruction* a_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 0));
+ HloInstruction* b_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 1));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_gte, b_gte));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+ EXPECT_EQ(computation->root_instruction(), root);
+
+ // test BF16 propagated through domain
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(),
+ BF16);
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(),
+ BF16);
+
+ EXPECT_TRUE(OutputsBF16(a_trans));
+ EXPECT_TRUE(OutputsBF16(b_trans));
+ EXPECT_TRUE(OutputsBF16(a_gte));
+ EXPECT_TRUE(OutputsBF16(b_gte));
+ EXPECT_FALSE(OutputsBF16(a));
+ EXPECT_FALSE(OutputsBF16(b));
+}
+
+// Tests that bf16 is not propagated through a domain in case its input cannot
+// be propagated. In the case below the input of the domain is the parameter
+// tuple which cannot be propagated, so the domain instruction is not propagated
+// either.
+TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ HloInstruction* domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
+ HloInstruction* a_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 0));
+ HloInstruction* b_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 1));
+ HloInstruction* a_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
+ HloInstruction* b_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), root);
+ EXPECT_TRUE(OutputsBF16(a_trans));
+ EXPECT_TRUE(OutputsBF16(b_trans));
+ EXPECT_FALSE(OutputsBF16(a_gte));
+ EXPECT_FALSE(OutputsBF16(b_gte));
+ EXPECT_FALSE(OutputsBF16(domain));
+ EXPECT_FALSE(OutputsBF16(param));
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc
index 07b4b14b5e..23645346e6 100644
--- a/tensorflow/compiler/xla/service/bfloat16_support.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_support.cc
@@ -25,6 +25,7 @@ bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
+ case HloOpcode::kDomain:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
@@ -43,6 +44,7 @@ bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
+ case HloOpcode::kDomain:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
@@ -81,6 +83,7 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
case HloOpcode::kConcatenate:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
+ case HloOpcode::kDomain:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -92,11 +95,15 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return true;
+ case HloOpcode::kBitcast:
+ return hlo.shape().element_type() ==
+ hlo.operand(0)->shape().element_type();
case HloOpcode::kDynamicSlice:
return operand_index == 0;
case HloOpcode::kDynamicUpdateSlice:
return operand_index == 0 || operand_index == 1;
case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
return operand_index == 1 || operand_index == 2;
default:
break;
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index c0b8bf9039..afe4b2e142 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType(
worklist.push_back(std::make_pair(subcomputation,
false)); // Not thread local.
break;
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
@@ -632,7 +633,7 @@ Status BufferAssignment::ComputeSummaryStats() {
if (module_sequence.size() == module_->computation_count()) {
TF_ASSIGN_OR_RETURN(
const int64 min_size,
- MinimumMemoryForSequence(module_sequence, buffer_size_));
+ HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_));
stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size;
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index bdcea92882..125ade2a11 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
@@ -32,12 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -125,7 +125,7 @@ class BufferAssignmentTest : public HloTestBase {
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
auto value = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
return builder.Build();
@@ -142,7 +142,7 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto index = builder.AddInstruction(
@@ -167,9 +167,9 @@ class BufferAssignmentTest : public HloTestBase {
const string& name) {
auto builder = HloComputation::Builder(name);
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
auto indexc = builder.AddInstruction(
@@ -290,7 +290,7 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
TEST_F(BufferAssignmentTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -304,9 +304,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) {
// no buffers assigned, and their consumer has a buffer.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
auto module = CreateNewModule();
@@ -327,7 +327,7 @@ TEST_F(BufferAssignmentTest, HasAllocationAt) {
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32vec100_, "param0"));
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
auto tuple = builder.AddInstruction(
@@ -352,7 +352,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) {
// This computation copies a constant to output.
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
auto module = CreateNewModule();
@@ -371,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -418,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) {
// share anything.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -477,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) {
// have the color 0, which allows the mul and add to share buffers.
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -547,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
//
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -601,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
module->AddEntryComputation(builder.Build());
@@ -654,13 +654,13 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ HloInstruction::CreateParameter(0, f32a100x10_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
auto exp2 = builder.AddInstruction(
HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
/*shape=*/f32vec10_,
/*operand=*/exp2,
@@ -708,9 +708,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) {
// Creates the main kernel and verifies instruction counts.
auto builder = HloComputation::Builder(TestName());
auto const3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
@@ -773,11 +773,11 @@ TEST_F(BufferAssignmentTest, ExampleConditional) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
auto const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.4f)));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
r0f32_, pred, const1, true_computation, const2, false_computation));
module->AddEntryComputation(builder.Build());
@@ -818,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
// param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
auto builder = HloComputation::Builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32vec100_, ""));
+ HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
auto tanh = builder.AddInstruction(
@@ -1200,8 +1200,9 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
// Test that a tuple constant which is forwarded to the computation output
// is properly handled.
auto builder = HloComputation::Builder(TestName());
- builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple(
- {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()})));
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()})));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1365,8 +1366,9 @@ TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
HloInstruction::CreateParameter(1, tuple_shape, "param1"));
auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
2, ShapeUtil::MakeShape(PRED, {}), "param1"));
- auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1));
+ auto select = builder.AddInstruction(
+ HloInstruction::CreateTernary(tuple_shape, HloOpcode::kTupleSelect,
+ pred_param, tuple_param0, tuple_param1));
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
@@ -1496,11 +1498,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) {
// param1[100] --------------/--------/
auto builder = HloComputation::Builder(TestName());
auto paramscalar =
- builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p"));
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, f32vec100_, ""));
+ HloInstruction::CreateParameter(1, f32vec100_, "p1"));
auto param1 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, f32vec100_, ""));
+ HloInstruction::CreateParameter(2, f32vec100_, "p2"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
auto add = builder.AddInstruction(
@@ -1536,7 +1538,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
// be {%rev, %neg, %concat}. This occurs right at the concat itself.
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32vec100_, ""));
+ HloInstruction::CreateParameter(0, f32vec100_, "p"));
auto log = builder.AddInstruction(
HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param));
auto rev = builder.AddInstruction(
@@ -1583,7 +1585,7 @@ TEST_F(BufferAssignmentTest, PeakBuffersWhile) {
auto b = HloComputation::Builder(TestName() + ".cond");
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
condition = module->AddEmbeddedComputation(b.Build());
}
HloComputation* body;
@@ -1646,9 +1648,9 @@ class WhileBufferAssignmentTest : public HloTestBase {
builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
auto ten = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(10)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(10)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten));
return builder.Build();
@@ -1673,7 +1675,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
auto sequence =
- CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
+ ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
ByteSizeOf,
@@ -1707,7 +1709,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
HloInstruction::CreateParameter(2, data_shape_, "weights1"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
@@ -1793,7 +1795,7 @@ ENTRY %test_module {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
// Run CopyInsertion and check if the graph constructed above doesn't need
// any copies inserted for BufferAssignment to run.
@@ -1850,7 +1852,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto build_cond = [&]() {
auto builder = HloComputation::Builder("cond");
auto const4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1862,7 +1864,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto build_body = [&]() {
auto builder = HloComputation::Builder("body");
auto const9 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int>(9)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(9)));
auto param =
builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x"));
builder.AddInstruction(
@@ -1874,11 +1876,15 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
auto module = CreateNewModule();
auto builder = HloComputation::Builder("entry");
- auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, ""));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ auto infeed =
+ builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, token, ""));
+ auto infeed_data = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r0s32, infeed, 0));
auto cond0 = module->AddEmbeddedComputation(build_cond());
auto body0 = module->AddEmbeddedComputation(build_body());
auto while0 = builder.AddInstruction(
- HloInstruction::CreateWhile(r0s32, cond0, body0, infeed));
+ HloInstruction::CreateWhile(r0s32, cond0, body0, infeed_data));
auto cond1 = module->AddEmbeddedComputation(build_cond());
auto body1 = module->AddEmbeddedComputation(build_body());
@@ -1886,7 +1892,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
HloInstruction::CreateWhile(r0s32, cond1, body1, while0));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero));
auto cond2 = module->AddEmbeddedComputation(build_cond());
@@ -1909,8 +1915,8 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) {
// computation, since the issue this test stresses depends on the order the
// nodes are traversed during BufferAssignment.
SequentialHloOrdering::HloModuleSequence sequence;
- sequence[module->entry_computation()] = {infeed, while0, while1, zero,
- add, while2, tuple};
+ sequence[module->entry_computation()] = {
+ token, infeed, infeed_data, while0, while1, zero, add, while2, tuple};
TF_ASSERT_OK_AND_ASSIGN(
auto assignment,
BufferAssigner::Run(
@@ -1948,7 +1954,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
@@ -1992,16 +1998,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) {
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "param"));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1));
sub_computation = module->AddEmbeddedComputation(builder.Build(add));
}
auto builder = HloComputation::Builder(TestName());
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto call1 = builder.AddInstruction(
HloInstruction::CreateCall(r0f32, {constant2}, sub_computation));
auto call2 = builder.AddInstruction(
@@ -2053,9 +2059,9 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto input0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape_, "input0"));
@@ -2103,7 +2109,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
RunCopyInsertion(module.get());
auto sequence =
- CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
+ ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo sequence for the
// root computation, so we overwrite that entry with a manually
@@ -2137,7 +2143,7 @@ TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
HloInstruction::CreateParameter(1, data_shape_, "weights0"));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto output1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index f623aef67a..4a927b5767 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -327,11 +327,12 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) {
builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(
- HloInstruction::CreateRecv(vec_, /*channel_id=*/0));
+ HloInstruction::CreateRecv(vec_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(recv_done, /*channel_id=*/1));
+ HloInstruction::CreateSend(recv_done, token, /*channel_id=*/1));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
auto module = CreateNewModule();
@@ -438,11 +439,13 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
// computation. The buffer containing {0, 1} is copied by GetTupleElement, and
// the buffers containing {3} and 3 are dead.
auto builder = HloComputation::Builder(TestName());
- auto inner_tuple0 = Literal::MakeTuple(
- {Literal::CreateR0<int64>(0).get(), Literal::CreateR0<int64>(1).get()});
- auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0<int64>(3).get()});
+ auto inner_tuple0 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()});
+ auto inner_tuple1 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
inner_tuple0->shape(), tuple_constant, 0));
@@ -490,7 +493,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@@ -502,7 +505,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element1_shape, tuple_param0, 1));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
@@ -554,7 +557,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
tuple_element0_shape, tuple_param0, 0));
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
@@ -626,7 +629,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
HloInstruction* slice = nullptr;
if (update_uses_tuple_element1) {
// Create a slice instruction as an additional user of 'gte1'.
@@ -637,7 +640,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -756,7 +759,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
@@ -765,7 +768,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index a8053d15e1..a23427f00c 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kConditional:
case HloOpcode::kWhile:
return CallContext::kSequential;
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc
index 1ea7d538cd..cc80b74843 100644
--- a/tensorflow/compiler/xla/service/call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/call_graph_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -82,7 +82,7 @@ class CallGraphTest : public HloTestBase {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
return builder.Build();
@@ -247,11 +247,11 @@ TEST_F(CallGraphTest, ComputationWithConditional) {
HloComputation::Builder builder(TestName());
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloInstruction* const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.4f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.4f)));
HloInstruction* const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.6f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.6f)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
kScalarShape, pred, const1, true_computation, const2,
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 482ccc5b67..256d05a73e 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -151,6 +152,14 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
}
return Status::OK();
}));
+ if (did_mutate) {
+ // Run DCE to remove called computations which are now becoming unused.
+ // This can result then in problems if within the called computation, there
+ // were send/recv instructions, which the module group verifier will flag as
+ // error findingthe same channel ID used for multiple send/recv
+ // instructions.
+ TF_RETURN_IF_ERROR(HloDCE().Run(module).status());
+ }
return did_mutate;
}
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 738d00881d..ff968bca29 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,9 +48,9 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// the "one" value.
HloComputation::Builder inner(TestName() + ".inner");
HloInstruction* zero = inner.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(24.0f)));
HloInstruction* one = inner.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
TF_ASSERT_OK(zero->AddControlDependencyTo(one));
auto module = CreateNewModule();
HloComputation* inner_computation =
@@ -87,7 +87,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
// little trickier.
HloComputation::Builder just_false(TestName() + ".false");
just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* false_computation =
module->AddEmbeddedComputation(just_false.Build());
@@ -99,7 +99,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder outer(TestName() + ".outer");
HloInstruction* init_value = outer.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
outer.AddInstruction(
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
@@ -123,9 +123,9 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
HloComputation::Builder just_false(TestName() + ".false");
auto* true_constant = just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<bool>({true})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<bool>({true})));
auto* false_constant = just_false.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant));
HloComputation* false_computation =
module->AddEmbeddedComputation(just_false.Build());
@@ -147,15 +147,17 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
HloComputation::Builder outfeeder(TestName() + ".outfeeder");
auto value = outfeeder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto token = outfeeder.AddInstruction(HloInstruction::CreateToken());
outfeeder.AddInstruction(
- HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/""));
+ HloInstruction::CreateOutfeed(f32, value, token, /*outfeed_config=*/""));
auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build());
HloComputation::Builder outer(TestName() + ".outer");
outer.AddInstruction(HloInstruction::CreateCall(
- ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation));
+ outfeed_computation->root_instruction()->shape(), /*operands=*/{},
+ outfeed_computation));
module->AddEntryComputation(outer.Build());
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
index e415fb27e6..fac0afd672 100644
--- a/tensorflow/compiler/xla/service/channel_tracker.h
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -19,8 +19,6 @@ limitations under the License.
#include <map>
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc
deleted file mode 100644
index b16907da9e..0000000000
--- a/tensorflow/compiler/xla/service/compilation_cache.cc
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/compilation_cache.h"
-
-#include <utility>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace xla {
-
-std::shared_ptr<Executable> CompilationCache::Insert(
- std::unique_ptr<Executable> executable,
- const HloModuleConfig& module_config) {
- tensorflow::mutex_lock lock(mutex_);
-
- CacheKey key =
- BuildKey(executable->entry_computation_handle(), module_config);
- VLOG(2) << "inserting cache key: " << key;
- if (cache_.count(key) == 0) {
- cache_.emplace(key, std::move(executable));
- } else {
- // Executable already exists in the cache. This can happen if two Execute
- // calls for a new computation are received simultaneously by the
- // service. In this case, we discard the Executable given as a parameter and
- // return what is in the cache. This is necessary because the service relies
- // on the cache to keep ownership of the Executable. We only want to store
- // one Executable for a given computation version and we can't discard the
- // executable which is in the cache because it may be in use.
- executable.reset();
- }
- return cache_.at(key);
-}
-
-std::shared_ptr<Executable> CompilationCache::LookUp(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const {
- tensorflow::mutex_lock lock(mutex_);
-
- CacheKey key = BuildKey(versioned_handle, module_config);
- VLOG(2) << "looking up cache key: " << key;
- if (cache_.count(key) == 0) {
- VLOG(2) << "cache key not found: " << key;
- return nullptr;
- } else {
- std::shared_ptr<Executable> result = cache_.at(key);
- VLOG(2) << "hit executable with module config: "
- << result->module_config().compilation_cache_key();
- return result;
- }
-}
-
-CompilationCache::CacheKey CompilationCache::BuildKey(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const {
- // The computation shape is represented entirely by its ProgramShape member,
- // so just serialize the proto as part of the key.
- return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::",
- versioned_handle.version, "::",
- module_config.compilation_cache_key());
-}
-
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h
deleted file mode 100644
index 09989726ae..0000000000
--- a/tensorflow/compiler/xla/service/compilation_cache.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
-
-#include <map>
-#include <memory>
-#include <string>
-
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-
-namespace xla {
-
-// A cache which stores Executables indexed by computation handle and version.
-class CompilationCache {
- public:
- CompilationCache() {}
-
- // Insert the given Executable into the cache. Return a bare Executable
- // pointer for the caller to use. Note: the returned pointer will *not* be the
- // same as the given unique pointer if the computation already exists in the
- // cache. See comments in the .cc implementation for details of this case.
- //
- // module_config is provided by the caller, instead of being taken from the
- // executable, so that we can insert keys into the compilation cache that are
- // devoid of layout (where XLA gets to choose what layout to compile).
- //
- // A shared_ptr is returned so the caller can keep the Executable from being
- // destructed in the event that the Executable is evicted from the
- // computation cache (and the cache's shared_ptr to the Executable is
- // destructed).
- std::shared_ptr<Executable> Insert(std::unique_ptr<Executable> executable,
- const HloModuleConfig& module_config);
-
- // Lookup the Executable for the specified versioned computation in the cache.
- // Return a shared_ptr to the Executable if it exists in the cache. Return
- // nullptr otherwise.
- std::shared_ptr<Executable> LookUp(
- const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const;
-
- protected:
- mutable tensorflow::mutex mutex_;
-
- // Map from versioned handle with program layout to Executable built
- // for that computation version and program layout.
- using CacheKey = string;
-
- CacheKey BuildKey(const VersionedComputationHandle& versioned_handle,
- const HloModuleConfig& module_config) const;
- std::map<CacheKey, std::shared_ptr<Executable>> cache_ GUARDED_BY(mutex_);
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
-};
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index d8fdccf9bb..7426672a7a 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -63,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
- const AotCompilationOptions& options) {
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
std::vector<std::unique_ptr<HloModule>> hlo_modules;
for (const AotXlaComputationInstance& instance : computations) {
TF_RET_CHECK(instance.computation.has_program_shape());
@@ -100,7 +101,8 @@ CompileOnlyService::CompileAheadOfTime(
hlo_modules.push_back(std::move(hlo_module));
}
- return compiler_->CompileAheadOfTime(std::move(hlo_modules), options);
+ return compiler_->CompileAheadOfTime(std::move(hlo_modules), options,
+ metadata);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h
index e6a66c202d..1ac950bdd6 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.h
+++ b/tensorflow/compiler/xla/service/compile_only_service.h
@@ -53,6 +53,12 @@ class CompileOnlyService : public Service {
const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
const AotCompilationOptions& options);
+ StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(
+ const tensorflow::gtl::ArraySlice<AotXlaComputationInstance> computations,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata);
+
Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
GetDeviceHandlesResponse* result) override {
return Unimplemented("CompileOnlyService does not support devices.");
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
index 6f06bba679..6b3b9820f0 100644
--- a/tensorflow/compiler/xla/service/compiler.cc
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -35,6 +35,27 @@ Compiler::ComputeBackendConfigs(const HloInstruction& hlo,
return {};
}
+std::unique_ptr<tensorflow::protobuf::Message>
+Compiler::ComputeDefaultBackendConfig(const HloInstruction& hlo,
+ se::StreamExecutor* executor) const {
+ CHECK(executor != nullptr);
+ return nullptr;
+}
+
+// Define a default version where metadata is not used.
+StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+Compiler::CompileAheadOfTime(
+ std::vector<std::unique_ptr<HloModule>> modules,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata) {
+ if (metadata != nullptr) {
+ return Unimplemented(
+ "Populating AotCompilationMetadata is not implemented on this "
+ "compiler.");
+ }
+ return CompileAheadOfTime(std::move(modules), options);
+}
+
/* static */ std::map<se::Platform::Id, Compiler::CompilerFactory>*
Compiler::GetPlatformCompilerFactories() {
static auto* r = new std::map<se::Platform::Id, CompilerFactory>;
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 6c52ffd800..99abb9bae3 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -94,6 +94,19 @@ class AotCompilationOptions {
DebugOptions debug_options_;
};
+// Abstract superclass describing metadata produced during ahead-of-time
+// compilation.
+class AotCompilationMetadata {
+ public:
+ AotCompilationMetadata(const AotCompilationMetadata&) = delete;
+ AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete;
+
+ virtual ~AotCompilationMetadata() = default;
+
+ protected:
+ AotCompilationMetadata() = default;
+};
+
// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
@@ -166,12 +179,29 @@ class Compiler {
ComputeBackendConfigs(const HloInstruction& hlo,
se::StreamExecutor* executor) const;
+ // Returns the backend configuration that the backend chooses by default for
+ // the given HLO. Returns no configuration if the backend does not support
+ // configurations for the given HLO.
+ //
+ // The stream executor is passed in to provide information about the hardware
+ // that the backend configurations would be targeting.
+ virtual std::unique_ptr<tensorflow::protobuf::Message>
+ ComputeDefaultBackendConfig(const HloInstruction& hlo,
+ se::StreamExecutor* executor) const;
+
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.
virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
const AotCompilationOptions& options) = 0;
+ // Similar to CompileAheadOfTime above but AotCompilationMetadata
+ // has an argument that can be populated during compilation.
+ virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
+ CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
+ const AotCompilationOptions& options,
+ std::unique_ptr<AotCompilationMetadata>* metadata);
+
/////
// The Compiler class also serves as a point to register compiler objects
// for the various platforms.
diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h
index 53c3a3f7b7..6975f387b4 100644
--- a/tensorflow/compiler/xla/service/computation_layout.h
+++ b/tensorflow/compiler/xla/service/computation_layout.h
@@ -32,12 +32,21 @@ namespace xla {
// mutable layouts.
class ComputationLayout {
public:
+ // Creates a new ComputationLayout with the given result layout.
+ explicit ComputationLayout(ShapeLayout result_layout)
+ : result_layout_(std::move(result_layout)) {}
+
// Constructs a ComputationLayout from a ProgramShape. The layouts of the
// parameters and results are set to the default layout. Layouts in the
// ProgramShape are ignored if ignore_layouts is true.
explicit ComputationLayout(const ProgramShape& program_shape,
bool ignore_layouts = true);
+ // Adds a new parameter layout to the computation layout.
+ void add_parameter_layout(ShapeLayout shape_layout) {
+ parameter_layouts_.push_back(std::move(shape_layout));
+ }
+
// Returns the layout of a particular parameter.
const ShapeLayout& parameter_layout(int64 param_no) const {
return parameter_layouts_[param_no];
diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc
index 7c1bacff92..d26486fcfe 100644
--- a/tensorflow/compiler/xla/service/computation_placer.cc
+++ b/tensorflow/compiler/xla/service/computation_placer.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc
index e9ec796121..b7be3ba605 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
index 868348547d..c43a31b167 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc
@@ -55,7 +55,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
true_computation_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "param"));
auto one = true_computation_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
true_computation_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, one));
@@ -73,7 +73,7 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}),
"param"));
auto forty_two = false_computation_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
false_computation_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, forty_two));
@@ -82,11 +82,11 @@ HloComputation* ConditionalSimplifierTest::MakeConditional(HloModule* module) {
}
auto false_instrn = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto false_param = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {}), "false_param"));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
builder.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeShape(S32, {}), false_instrn, one, true_computation,
@@ -106,7 +106,7 @@ TEST_F(ConditionalSimplifierTest, ConditionalWithControlDependency) {
HloComputation* computation = MakeConditional(&module());
auto* true_op = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
TF_ASSERT_OK(
true_op->AddControlDependencyTo(computation->root_instruction()));
@@ -119,10 +119,11 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsSend) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
+ auto* token = true_computation->AddInstruction(HloInstruction::CreateToken());
auto* send = true_computation->AddInstruction(HloInstruction::CreateSend(
true_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
- /*channel_id=*/0));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
+ token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
@@ -133,8 +134,9 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsRecv) {
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* true_computation = conditional->true_computation();
+ auto* token = true_computation->AddInstruction(HloInstruction::CreateToken());
auto* recv = true_computation->AddInstruction(HloInstruction::CreateRecv(
- ShapeUtil::MakeShape(F32, {1}), /*channel_id=*/0));
+ ShapeUtil::MakeShape(F32, {1}), token, /*channel_id=*/0));
true_computation->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
@@ -144,8 +146,9 @@ TEST_F(ConditionalSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) {
auto* conditional = computation->root_instruction();
ASSERT_EQ(conditional->opcode(), HloOpcode::kConditional);
auto* false_computation = conditional->false_computation();
- false_computation->AddInstruction(
- HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
+ auto token = false_computation->AddInstruction(HloInstruction::CreateToken());
+ false_computation->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(ConditionalSimplifier().Run(&module()).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 33d8338809..ab3d846403 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -472,6 +472,10 @@ class CopyRemover {
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
+ if (ShapeUtil::IsToken(value_a->shape())) {
+ // Token values have no representation and cannot interfere.
+ continue;
+ }
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
@@ -613,7 +617,10 @@ class CopyRemover {
VLOG(2) << copy->name() << " is not removable";
return false;
}
-
+ if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
+ VLOG(2) << copy->name() << " is not removable (shape mismatch)";
+ return false;
+ }
const CopyNodes& copy_node = copy_map_.at(copy);
ValueNode* src = copy_node.src;
ValueNode* dest = copy_node.dest;
@@ -947,28 +954,6 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
- }
- }
- }
- return Status::OK();
-}
-
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
HloInstruction* instruction = pair.first;
const ShapeTree<bool>& indices_to_copy = pair.second;
+ ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
std::vector<HloInstruction*> users = instruction->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
+ instruction, &indices_to_copy, &copies_added));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
+ // Special case copies are not eligible for later copy elision passes.
+ indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
+ if (has_copy) {
+ HloInstruction* copy = *copies_added.mutable_element(index);
+ if (copy != nullptr) {
+ copy->SetCopyElisionAllowed(false);
+ }
+ }
+ });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering, HloModule* module,
+ const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer) {
+ MaybeDumpModule("after adding copies to resolve interference", *module);
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ HloAliasAnalysis::Run(module, fusion_can_share_buffer));
+ CopyRemover copy_remover(*alias_analysis, ordering, module);
+ XLA_VLOG_LINES(3, copy_remover.ToString());
+
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy &&
+ instruction->CopyElisionAllowed()) {
+ TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ }
+ }
+ }
+ MaybeDumpModule("after removing unnecessary copies", *module);
+
+ return Status::OK();
+}
+
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Copy insertion is performed in three steps:
//
@@ -1130,16 +1150,13 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
"Call graph must be flattened before copy insertion.");
}
- // Gather Ids of existing kCopy instructions in the module. We avoid removing
- // these copies (except via DCE in TupleSimplifier) because they may have been
- // added for reasons not considered by copy insertion (eg, layout assignment).
- // Instruction id is used instead of HloInstruction* because the pointer
- // values may be recycled.
- tensorflow::gtl::FlatSet<int> existing_copies;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- existing_copies.insert(instruction->unique_id());
+ int64 num_existing_copies = 0;
+ if (VLOG_IS_ON(1)) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy) {
+ ++num_existing_copies;
+ }
}
}
}
@@ -1158,13 +1175,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(
- RemoveUnnecessaryCopies(ordering, existing_copies, module));
-
- MaybeDumpModule("after removing unnecessary copies", *module);
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
@@ -1185,7 +1197,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
}
}
}
- VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
+ VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 65e3d31e34..e1973db928 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -48,6 +47,15 @@ class CopyInsertion : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "copy-insertion"; }
+ // fusion_can_share_buffer: backend specific function that decides whether a
+ // fusion can share buffer with its operand.
+ //
+ // TODO(b/80315712): Find a better way to tell whether a fusion can share
+ // buffer.
+ CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr)
+ : fusion_can_share_buffer_(fusion_can_share_buffer) {}
+
// Run the pass on the given module. Returns whether the module was changed
// (copies were inserted).
StatusOr<bool> Run(HloModule* module) override;
@@ -62,8 +70,21 @@ class CopyInsertion : public HloPassInterface {
//
// TODO(b/62548313): Remove this when buffer assignment is module-scoped.
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
+
+ private:
+ // Backend specific function that decides whether a fusion can share buffer
+ // with its operand.
+ HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
};
+// Try to remove as many copies from the module as possible without introducing
+// live range interference. Only copy instructions that are eligible for
+// copy elision are considered for removal.
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering, HloModule* module,
+ const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 153f062d01..cd735256b8 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -108,7 +108,7 @@ TEST_F(CopyInsertionTest, SingleConstant) {
// be copied before entering the tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
@@ -125,21 +125,27 @@ TEST_F(CopyInsertionTest, SingleConstant) {
}
TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
- // Verify that an kCopy instructions which exist in the pass before
+ // Verify that kCopy instructions which change layout and exist before
// copy-insertion remain in the graph after copy-insertion.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
- HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}})));
+ auto minor_to_major = LayoutUtil::MinorToMajor(constant->shape());
+ Layout reversed_layout =
+ LayoutUtil::MakeLayoutFromMajorToMinor(minor_to_major);
+ Shape copy_shape = constant->shape();
+ *copy_shape.mutable_layout() = reversed_layout;
+ HloInstruction* copy_1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
+ HloInstruction* copy_2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(copy_shape, HloOpcode::kCopy, constant));
HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
constant->shape(), HloOpcode::kAdd, copy_1, copy_2));
- HloInstruction* add_copy = builder.AddInstruction(
- HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(add->shape(), HloOpcode::kCopy, add));
module->AddEntryComputation(builder.Build());
@@ -147,12 +153,11 @@ TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 3);
+ EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))));
+ EXPECT_EQ(module->entry_computation()->root_instruction(), add);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())));
}
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
@@ -162,9 +167,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* x = builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
@@ -192,11 +197,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
// the computation result. Verify that copies are added properly.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -204,9 +209,9 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
HloInstruction::CreateTuple({constant3, constant2}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
EXPECT_THAT(constant1->users(), UnorderedElementsAre(tuple1));
EXPECT_THAT(constant2->users(), UnorderedElementsAre(tuple1, tuple2));
@@ -250,8 +255,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
// The output of a bitcast is its operand (same buffer), so a bitcast
// constant feeding the result must have a copy added.
auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 42.0})));
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 42.0})));
HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
@@ -365,9 +371,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
// copy is added.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -375,9 +381,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
HloInstruction::CreateTuple({constant2, constant1}));
HloInstruction* pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
HloInstruction* gte =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
@@ -408,7 +414,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const Shape& loop_state_shape) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
auto induction_variable =
@@ -437,7 +443,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// Update data GTE(1).
@@ -475,7 +481,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -544,7 +550,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
// add0 = Add(in0, 1)
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
@@ -559,8 +565,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
data = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
}
- auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
@@ -593,7 +600,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, 0));
auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
gte0->shape(), HloOpcode::kAdd, gte0, inc));
@@ -603,8 +610,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// GTE(GTE(loop_state, 1), 0) -> Add
auto gte10 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
- auto update10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto update10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, gte10, update10));
@@ -628,10 +636,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
bool nested = false) {
auto builder = HloComputation::Builder(TestName() + ".While");
auto induction_var_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
if (nested) {
auto inner_init = builder.AddInstruction(
@@ -654,8 +663,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
auto builder = HloComputation::Builder(TestName() + ".While");
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
&builder);
}
@@ -672,11 +682,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v1 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto v2 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
@@ -684,9 +694,9 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
- nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2));
+ nested_tuple_shape_, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
data_init, &builder);
@@ -696,7 +706,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto one_vec = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
auto data_init =
@@ -709,11 +719,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
auto builder = HloComputation::Builder(TestName() + ".While");
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto data_init = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, one, {1}));
- auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto one_vec = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// Take a reference to 'data_init' to make it interfere with while result.
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data_init, one_vec));
@@ -745,7 +756,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
const bool nested =
ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
auto induction_var_init = builder->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto condition = module_->AddEmbeddedComputation(
BuildConditionComputation(loop_state_shape));
auto body = module_->AddEmbeddedComputation(
@@ -1247,7 +1258,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
-
// Two while loops shares the same loop init tuple.
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition1, body1, loop_init));
@@ -1305,7 +1315,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1313,9 +1323,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1370,7 +1380,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1378,9 +1388,9 @@ TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -1430,7 +1440,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1438,7 +1448,7 @@ TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
builder.AddInstruction(
@@ -1515,7 +1525,7 @@ TEST_F(CopyInsertionTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
cond_builder.AddInstruction(HloInstruction::CreateUnary(
cond_constant->shape(), HloOpcode::kNot, cond_constant));
HloComputation* condition =
@@ -1570,14 +1580,14 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -1595,12 +1605,51 @@ TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
EXPECT_THAT(condition->root_instruction(), op::Constant());
}
+TEST_F(CopyInsertionTest, TokensShouldNotBeCopied) {
+ string module_string = R"(
+HloModule TokensShouldNotBeCopied
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %TokensShouldNotBeCopied () -> s32[] {
+ %one = s32[] constant(1)
+ %negative_one = s32[] negate(%one)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %negative_one, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(
+ module_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should be no copies added because tokens should not be copied.
+ EXPECT_EQ(CountCopies(*module), 0);
+}
+
std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
auto builder = HloComputation::Builder("trivial_condition");
builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "loop_state"));
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kNot, constant));
return builder.Build();
@@ -1636,8 +1685,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
+ HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_SequentialWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1677,8 +1725,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) {
for (int i = 0; i < num_iters; ++i) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
+ HloModule module("BM_SequentialWhiles", config);
auto builder = HloComputation::Builder("BM_ParallelWhiles");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -1750,8 +1797,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) {
std::vector<HloInstruction*> tuple_params(num_tuple_inputs);
for (int i = 0; i < num_iters; ++i) {
auto builder = HloComputation::Builder("BM_ParallelWhiles");
- HloModule module("BM_ManyElementTuple", VersionedComputationHandle(),
- config);
+ HloModule module("BM_ManyElementTuple", config);
for (int j = 0; j < num_tuple_inputs; ++j) {
tuple_params[j] = builder.AddInstruction(
HloInstruction::CreateParameter(j, element_shape, ""));
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index bfd85f257f..c45d914e93 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -37,6 +37,7 @@ cc_library(
srcs = ["cpu_transfer_manager.cc"],
hdrs = ["cpu_transfer_manager.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -54,29 +55,6 @@ cc_library(
)
cc_library(
- name = "external_constant_pool",
- srcs = ["external_constant_pool.cc"],
- hdrs = ["external_constant_pool.h"],
- deps = [
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
- ],
-)
-
-tf_cc_test(
- name = "external_constant_pool_test",
- srcs = ["external_constant_pool_test.cc"],
- deps = [
- ":external_constant_pool",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:test",
- ],
-)
-
-cc_library(
name = "cpu_compiler",
srcs = ["cpu_compiler.cc"],
hdrs = ["cpu_compiler.h"],
@@ -95,7 +73,7 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -112,7 +90,6 @@ cc_library(
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
- "//tensorflow/compiler/xla/service:gather_expander",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
@@ -151,7 +128,14 @@ cc_library(
"@llvm//:target", # fixdeps: keep
"@llvm//:x86_code_gen", # fixdeps: keep
"@llvm//:x86_disassembler", # fixdeps: keep
- ],
+ ] + select({
+ "//tensorflow:linux_ppc64le": [
+ "@llvm//:powerpc_disassembler",
+ "@llvm//:powerpc_code_gen",
+ ],
+ "//conditions:default": [
+ ],
+ }),
alwayslink = True, # Contains compiler registration
)
@@ -168,7 +152,6 @@ cc_library(
":cpu_runtime",
":custom_call_target_registry",
":disassembler",
- ":external_constant_pool",
":orc_jit_memory_mapper",
":runtime_fp16",
":runtime_conv2d",
@@ -249,7 +232,6 @@ cc_library(
":cpu_options",
":cpu_runtime",
":dot_op_emitter",
- ":external_constant_pool",
":ir_emission_utils",
":ir_function",
":parallel_loop_emitter",
@@ -266,6 +248,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
@@ -372,7 +355,7 @@ tf_cc_binary(
srcs = ["sample_harness.cc"],
deps = [
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -649,10 +632,10 @@ tf_cc_test(
deps = [
":cpu_instruction_fusion",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -706,9 +689,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -734,7 +717,7 @@ tf_cc_test(
deps = [
":cpu_layout_assignment",
":target_machine_features_fake",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -826,7 +809,7 @@ tf_cc_test(
":cpu_executable",
":parallel_task_assignment",
":target_machine_features_fake",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -898,6 +881,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
"@llvm//:core",
"@llvm//:support",
],
@@ -908,7 +892,7 @@ tf_cc_test(
srcs = ["cpu_copy_insertion_test.cc"],
deps = [
":cpu_copy_insertion",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -958,7 +942,7 @@ tf_cc_test(
":ir_emission_utils",
":target_machine_features_fake",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
index 375b017b09..547d4c696d 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -60,11 +60,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in CNHW order.
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
// The kernel dimensions are in OIHW order.
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
ConvolutionDimensionNumbers dnums;
@@ -122,11 +122,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
auto builder = HloComputation::Builder(TestName());
// The input dimensions are in NHWC order.
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
// The kernel dimensions are in HWIO order.
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount))));
ConvolutionDimensionNumbers dnums;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 25b18eff20..29fa29d33a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -30,6 +30,7 @@ limitations under the License.
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Object/ObjectFile.h"
@@ -38,7 +39,7 @@ limitations under the License.
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -66,7 +67,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
@@ -264,12 +264,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_fusion=*/false);
+ /*rewrite_grad_op=*/true);
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; },
/*enable_dot_strength_reduction=*/false);
+ pass.AddPass<HloDCE>();
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
@@ -297,22 +297,24 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
- pipeline.AddPass<GatherExpander>();
-
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
pipeline.AddPass<CpuLayoutAssignment>(
- module->mutable_device_entry_computation_layout(),
- &target_machine_features);
+ module->mutable_entry_computation_layout(), &target_machine_features);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
- pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
- /*is_layout_sensitive=*/true,
- [](const Shape&, const Shape&) { return true; },
- /*enable_dot_strength_reduction=*/false);
- pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ {
+ auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
+ "after layout assignement");
+ pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ [](const Shape&, const Shape&) { return true; },
+ /*enable_dot_strength_reduction=*/false);
+ pass.AddPass<HloDCE>();
+ pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ }
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// Outline ops in the entry computation into calls to subcomputations.
const int max_parallelism =
@@ -550,8 +552,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// and reduced memory usage (as compared to using DependencyHloOrdering).
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(),
- DFSMemoryScheduler));
+ ScheduleComputationsInModule(*module, BufferSizeBytesFunction(),
+ DFSMemoryScheduler));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
@@ -580,7 +582,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- &target_machine_features, jit->external_constant_pool());
+ &target_machine_features);
for (auto embedded_computation :
entry_computation->MakeEmbeddedComputationsList()) {
@@ -603,7 +605,13 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
/*is_top_level_computation=*/true,
&module_sequence.at(entry_computation)));
- string function_name = llvm_ir::AsString(entry_function->getName());
+ string function_name = [&]() {
+ llvm::SmallVector<char, 40> function_name_vector;
+ llvm::Mangler::getNameWithPrefix(
+ function_name_vector, entry_function->getName(), jit->data_layout());
+ return string(function_name_vector.begin(), function_name_vector.end());
+ }();
+
string ir_module_string;
if (embed_ir_in_executable) {
ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
@@ -730,7 +738,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(
SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction()));
+ ScheduleComputationsInModule(*module, BufferSizeBytesFunction()));
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
@@ -767,8 +775,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
IrEmitter ir_emitter(*module, *assignment, &llvm_module,
std::move(instruction_to_profile_idx),
std::move(computation_to_profile_idx),
- &target_machine_features,
- /*external_constant_pool=*/nullptr);
+ &target_machine_features);
HloComputation* computation = module->entry_computation();
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
index a05a269417..4db7fa446e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -74,14 +74,14 @@ TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) {
body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
auto cond_builder = HloComputation::Builder("condition");
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -114,7 +114,7 @@ TEST_F(CpuCopyInsertionTest, TupleCall) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param"));
auto constant = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0)));
auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, sub_param, constant));
sub_builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
index d12fa6bb9a..8727c72b6e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -40,7 +40,7 @@ ENTRY DotOperation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* dot = module->entry_computation()->root_instruction();
@@ -71,7 +71,7 @@ ENTRY ConvOperation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* conv = module->entry_computation()->root_instruction();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index cf43b74c69..1093559892 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -206,8 +206,8 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
tensorflow::gtl::MutableArraySlice<OwningDeviceMemory> buffers) {
se::Stream* stream = run_options->stream();
ScopedShapedBuffer result_buffer(
- /*on_host_shape=*/host_result_shape(),
- /*on_device_shape=*/host_result_shape(), run_options->allocator(),
+ /*on_host_shape=*/result_shape(),
+ /*on_device_shape=*/result_shape(), run_options->allocator(),
stream->parent()->device_ordinal());
// Move OwningDeviceMemory values which contain the array(s) of the result
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 46fe060817..991b14f17d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
@@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -282,7 +282,7 @@ class OpcodeFusionTest : public InstructionFusionTest {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "arg0"));
HloInstruction* one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, arg0, one));
return module->AddEmbeddedComputation(builder.Build());
@@ -501,8 +501,8 @@ TEST_F(OpcodeFusionTest, UnaryMapOfExp) {
HloInstruction* exp = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param0));
- builder.AddInstruction(HloInstruction::CreateMap(
- shape, {exp}, CreateAdderToOne(module.get()), /*static_operands=*/{}));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(shape, {exp}, CreateAdderToOne(module.get())));
module->AddEntryComputation(builder.Build());
@@ -525,8 +525,8 @@ TEST_F(OpcodeFusionTest, BinaryMapOfExps) {
HloInstruction* exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kExp, param1));
- builder.AddInstruction(HloInstruction::CreateMap(
- shape, {exp0, exp1}, CreateMax(module.get()), /*static_operands=*/{}));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(shape, {exp0, exp1}, CreateMax(module.get())));
module->AddEntryComputation(builder.Build());
@@ -595,7 +595,7 @@ TEST_F(OpcodeFusionTest, MessOfFusileNodes) {
auto pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(S32, {5}), idx_choice,
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))),
padding_config));
auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice(
@@ -775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
string hlo_string = tensorflow::strings::StrCat(
"HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
RunFusionAndCheckOpcodesWereFused(
module.get(),
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
index 429fc7b786..3681d12d8d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index e75fcb6bc9..3ed7876715 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace {
@@ -24,6 +25,7 @@ const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
const char* const kXlaEnableExperimentalLlvmIrGemm =
"xla_enable_experimental_llvm_ir_gemm";
+const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
} // namespace
@@ -62,6 +64,43 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) {
return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0;
}
+static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str,
+ tensorflow::StringPiece suffix) {
+ CHECK_GE(str.size(), suffix.size());
+ CHECK_EQ(str.substr(str.size() - suffix.size()), suffix);
+ return str.substr(0, str.size() - suffix.size());
+}
+
+tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+ const HloModuleConfig& config) {
+ const auto& extra_options_map =
+ config.debug_options().xla_backend_extra_options();
+ auto it = extra_options_map.find(kLlvmIrGemmTileSize);
+ if (it == extra_options_map.end()) {
+ return tensorflow::gtl::nullopt;
+ }
+
+ std::vector<string> tile_components =
+ tensorflow::str_util::Split(it->second, ':');
+ CHECK_EQ(tile_components.size(), 3);
+
+ int64 tile_size_m;
+ int64 tile_size_k;
+ int64 tile_size_n_in_vector_width;
+
+ CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m));
+ CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k));
+
+ tensorflow::StringPiece tile_size_n_in_vector_width_str =
+ RemoveSuffix(tile_components[2], "*vectwidth");
+
+ CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str,
+ &tile_size_n_in_vector_width));
+
+ return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k,
+ tile_size_n_in_vector_width);
+}
+
} // namespace options
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 106dfbbc62..429b9e16cb 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -29,6 +29,8 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config);
bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
const HloModuleConfig& config);
+tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
+ const HloModuleConfig& config);
} // namespace options
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index d97802ee45..156166bf2b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -160,9 +161,8 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor,
int32 size_32 = static_cast<int32>(size);
CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32);
- Status s =
- TransferBufferToDevice(executor, /*size=*/size,
- /*source=*/source, queued_buffer->device_memory());
+ Status s = executor->SynchronousMemcpyH2D(
+ /*host_src=*/source, /*size=*/size, queued_buffer->device_memory());
if (!s.ok()) {
queued_buffer->Done(s);
@@ -181,7 +181,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::gtl::ArraySlice<int64> dimensions(
tensorflow::bit_cast<const int64*>(literal_shape.dimensions().data()),
literal_shape.dimensions().size());
- *literal = std::move(*Literal::CreateFromDimensions(
+ *literal = std::move(*LiteralUtil::CreateFromDimensions(
literal_shape.element_type(), dimensions));
TF_ASSIGN_OR_RETURN(Shape received_shape,
TransferArrayBufferFromOutfeed(
@@ -212,7 +212,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
tensorflow::bit_cast<const int64*>(
tuple_element_shape.dimensions().data()),
tuple_element_shape.dimensions().size());
- auto empty = Literal::CreateFromDimensions(
+ auto empty = LiteralUtil::CreateFromDimensions(
tuple_element_shape.element_type(), dimensions);
int64 size = GetByteSizeRequirement(tuple_element_shape);
buffer_data.push_back({empty->untyped_data(), size});
@@ -233,7 +233,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed(
for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) {
*elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i);
}
- *literal = std::move(*Literal::MakeTupleOwned(std::move(elements)));
+ *literal = std::move(*LiteralUtil::MakeTupleOwned(std::move(elements)));
TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index d77076546f..58228180ca 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -324,11 +324,11 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() {
int64 column_remainder = k() % tile_cols();
int64 column_limit = k() - column_remainder;
- ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
- [&](llvm::Value* column, bool is_first_column) {
- EmitOuterLoopBody(column, tile_cols(), is_first_column);
- });
+ ksl_.ForReturnVoid("dot.outer.tiled",
+ /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(),
+ [&](llvm::Value* column, bool is_first_column) {
+ EmitOuterLoopBody(column, tile_cols(), is_first_column);
+ });
if (column_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
@@ -341,19 +341,20 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
int64 columns, bool is_first_column) {
int64 row_limit = m() - (m() % tile_rows());
- ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
- /*step=*/tile_rows(), [&](llvm::Value* row) {
- std::vector<llvm::Value*> lhs_tile =
- lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
- llvm::Value* accumulator =
- is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
- : vsl_.GetZeroVector())
- : vsl_.LoadVector(result_, row);
- for (int i = 0; i < columns; i++) {
- accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
- }
- vsl_.StoreVector(accumulator, result_, row);
- });
+ ksl_.ForReturnVoid(
+ "dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
+ /*step=*/tile_rows(), [&](llvm::Value* row) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row);
+ llvm::Value* accumulator =
+ is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row)
+ : vsl_.GetZeroVector())
+ : vsl_.LoadVector(result_, row);
+ for (int i = 0; i < columns; i++) {
+ accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
+ }
+ vsl_.StoreVector(accumulator, result_, row);
+ });
}
void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
@@ -372,7 +373,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
// // initialized.
// }
- ksl_.For(
+ ksl_.ForReturnVoid(
"dot.inner.epilg.outer", /*start=*/current_tile_col,
/*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
/*step=*/1, /*peel_first_iteration=*/false,
@@ -382,7 +383,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
ir_builder_->CreateMul(col, ir_builder_->getInt64(m()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For(
+ ksl_.ForReturnVoid(
"dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(),
/*step=*/1, [&](llvm::Value* scalar_row) {
llvm::Value* product = vsl_.Mul(
@@ -390,7 +391,7 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
is_first_scalar_col,
ir_builder_->getInt1(is_first_tiled_column));
- ksl_.If(
+ ksl_.IfReturnVoid(
setting_result_first_time,
/*true_block_generator=*/
[&]() {
@@ -571,9 +572,10 @@ void RowMajorMatrixVectorProductEmitter::Emit() {
int64 row_remainder = m() % tile_rows();
int64 row_limit = m() - row_remainder;
- ksl_.For("dot.outer.tiled",
- /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
- [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
+ ksl_.ForReturnVoid(
+ "dot.outer.tiled",
+ /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(),
+ [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); });
if (row_remainder != 0) {
EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
@@ -585,17 +587,17 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
std::vector<VectorVariable>* vector_accumulators) {
int64 column_limit = k() - (k() % tile_cols());
- ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
- /*step=*/tile_cols(), [&](llvm::Value* col) {
- std::vector<llvm::Value*> lhs_tile =
- lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
- llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
- for (int i = 0; i < rows; i++) {
- llvm::Value* old_sum = (*vector_accumulators)[i].Get();
- (*vector_accumulators)[i].Set(
- vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
- }
- });
+ ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
+ /*step=*/tile_cols(), [&](llvm::Value* col) {
+ std::vector<llvm::Value*> lhs_tile =
+ lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col);
+ llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
+ for (int i = 0; i < rows; i++) {
+ llvm::Value* old_sum = (*vector_accumulators)[i].Get();
+ (*vector_accumulators)[i].Set(vsl_.Add(
+ old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
+ }
+ });
}
void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
@@ -612,14 +614,15 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
ir_builder_->getInt64(k()));
llvm::Value* lhs_base_pointer =
vsl_.ComputeOffsetPointer(lhs_, total_offset);
- ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
- /*step=*/1, [&](llvm::Value* scalar_col) {
- llvm::Value* product =
- vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
- vsl_.LoadScalar(rhs_, scalar_col));
- llvm::Value* old_value = (*scalar_accumulators)[r].Get();
- (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
- });
+ ksl_.ForReturnVoid(
+ "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(),
+ /*step=*/1, [&](llvm::Value* scalar_col) {
+ llvm::Value* product =
+ vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
+ vsl_.LoadScalar(rhs_, scalar_col));
+ llvm::Value* old_value = (*scalar_accumulators)[r].Get();
+ (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
+ });
}
}
@@ -665,6 +668,10 @@ class MatrixMatrixBlockPanelEmitter {
// the largest vector register we will use). This can be larger than the
// largest vector register supported by the machine -- LLVM will legalize
// these large vector widths into legally sized vectors.
+ //
+ // `max_vector_count` is the maximum number of vectors of size
+ // `max_vectorization_width` that we will attempt to process at once.
+ //
// `min_vectorization_width` is the smallest vector width the emitter will use
// -- below that it will devolve to using a scalar loop.
//
@@ -674,12 +681,13 @@ class MatrixMatrixBlockPanelEmitter {
class Config {
public:
explicit Config(PrimitiveType scalar_type, Dimensions dims,
- int64 max_vectorization_width,
+ int64 max_vectorization_width, int64 max_vector_count,
int64 min_vectorization_width, int64 tile_size_m,
int64 tile_size_k)
: scalar_type_(scalar_type),
dims_(dims),
max_vectorization_width_(max_vectorization_width),
+ max_vector_count_(max_vector_count),
min_vectorization_width_(min_vectorization_width),
tile_size_m_(tile_size_m),
tile_size_k_(tile_size_k) {}
@@ -694,6 +702,7 @@ class MatrixMatrixBlockPanelEmitter {
PrimitiveType scalar_type() const { return scalar_type_; }
Dimensions dims() const { return dims_; }
int64 max_vectorization_width() const { return max_vectorization_width_; }
+ int64 max_vector_count() const { return max_vector_count_; }
int64 min_vectorization_width() const { return min_vectorization_width_; }
int64 tile_size_m() const { return tile_size_m_; }
@@ -703,6 +712,7 @@ class MatrixMatrixBlockPanelEmitter {
PrimitiveType scalar_type_;
Dimensions dims_;
int64 max_vectorization_width_;
+ int64 max_vector_count_;
int64 min_vectorization_width_;
int64 tile_size_m_;
int64 tile_size_k_;
@@ -721,39 +731,35 @@ class MatrixMatrixBlockPanelEmitter {
ksl_(ir_builder_) {
CHECK(max_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(max_vectorization_width())));
+ CHECK_GT(max_vector_count(), 0);
CHECK(min_vectorization_width() > 0 &&
IsPowerOfTwo(static_cast<uint64>(min_vectorization_width())));
+ CHECK_GE(max_vectorization_width(), min_vectorization_width());
CHECK_GT(tile_size_k(), 0);
}
void Emit();
private:
- // This emits a loop that loops over the `n` dimension in multiples of
- // `max_vectorization_width` as much as possible and then emits a remainder
- // epilogue.
- void EmitLoopOverN();
-
- // This emits a loop that loops over the `k` dimension in multiples of
- // `tile_size_k` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start,
- llvm::Value* n_end);
-
- // This emits a loop that loops over the `m` dimension in multiples of
- // `tile_size_m` as much as possible and then emits a remainder epilogue.
- void EmitLoopOverM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ // The HandleResiduesOnX helpers split the iteration space for dimension X
+ // into a multiple of the tile size on dimension X and an epilogue. These
+ // helpers ultimately call into `EmitTiledGemm` for emitting the
+ // tiled GEMM kernel.
+
+ void HandleResiduesOnN();
+ void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start,
+ llvm::Value* n_end);
+ void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k,
+ llvm::Value* k_start, llvm::Value* k_end,
+ llvm::Value* n_start, llvm::Value* n_end);
+
+ // This emits a tiled GEMM kernel. For a detailed description see the comment
+ // on the implementation.
+ void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k,
llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end);
-
- // This emits the inner reduction loop. This inner reduction loop multiplies
- // a tile from the LHS of size [tile_size_m,tile_size_k] and a tile from the
- // RHS of size [`tile_size_k`, vls->vector_width()] to update a tile of size
- // [`tile_size_m`, vls->vector_width()] in the result.
- void EmitTiledReductionLoop(VectorSupportLibrary* vsl, int64 tile_size_k,
- llvm::Value* k_start, llvm::Value* k_end,
- llvm::Value* n_start, llvm::Value* n_end,
- int64 tile_size_m, llvm::Value* m_start,
- llvm::Value* m_end);
+ llvm::Value* n_start, llvm::Value* n_end,
+ int64 tile_size_m, llvm::Value* m_start,
+ llvm::Value* m_end);
llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); }
@@ -763,6 +769,7 @@ class MatrixMatrixBlockPanelEmitter {
int64 max_vectorization_width() const {
return config().max_vectorization_width();
}
+ int64 max_vector_count() const { return config().max_vector_count(); }
int64 min_vectorization_width() const {
return config().min_vectorization_width();
}
@@ -779,16 +786,19 @@ class MatrixMatrixBlockPanelEmitter {
KernelSupportLibrary ksl_;
};
-void MatrixMatrixBlockPanelEmitter::Emit() { EmitLoopOverN(); }
+void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); }
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() {
// We can only iterate the `n` dimension for an extent that is divisible by
// the vectorization width. So we emit an outer loop that first processes the
// largest extent in `n` that is divisible by max_vectorization_width, then
// the largest remaining extent that is divisible by max_vectorization_width /
// 2 etc.
- int64 current_vectorization_width = max_vectorization_width();
+ int64 current_vectorization_width =
+ max_vector_count() * max_vectorization_width();
+ int64 current_vector_count = max_vector_count();
+
int64 n_start = 0;
while (n_start != dims().n() &&
current_vectorization_width >= min_vectorization_width()) {
@@ -796,53 +806,67 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverN() {
if (n_start != n_end) {
VectorSupportLibrary vsl(scalar_type(), current_vectorization_width,
ir_builder_, "gebp");
- EmitLoopOverK(&vsl, GetInt64(n_start), GetInt64(n_end));
+ HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end));
n_start = n_end;
}
- current_vectorization_width /= 2;
+ if (current_vector_count == 1) {
+ current_vectorization_width /= 2;
+ } else {
+ current_vector_count--;
+ current_vectorization_width =
+ current_vector_count * max_vectorization_width();
+ }
}
if (n_start != dims().n()) {
VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp");
- ksl_.For("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
+ ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) {
llvm::Value* n_i_next =
ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1));
- EmitLoopOverK(&vsl, n_i, n_i_next);
+ HandleResiduesOnK(&vsl, n_i, n_i_next);
});
}
}
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl,
- llvm::Value* n_start,
- llvm::Value* n_end) {
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl,
+ llvm::Value* n_start,
+ llvm::Value* n_end) {
int64 k_start = 0;
int64 k_end = dims().k() - (dims().k() % tile_size_k());
if (k_end != k_start) {
- EmitLoopOverM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
- n_start, n_end);
+ HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end),
+ n_start, n_end);
k_start = k_end;
}
if (k_start != dims().k()) {
- EmitLoopOverM(vsl, dims().k() - k_start, GetInt64(k_start),
- GetInt64(dims().k()), n_start, n_end);
+ HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start),
+ GetInt64(dims().k()), n_start, n_end);
}
}
-void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
+void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) {
const int64 m_end = dims().m() - dims().m() % tile_size_m();
- EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
- tile_size_m(), GetInt64(0), GetInt64(m_end));
+ EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(),
+ GetInt64(0), GetInt64(m_end));
if (m_end != dims().m()) {
- EmitTiledReductionLoop(vsl, tile_size_k, k_start, k_end, n_start, n_end,
- dims().m() - m_end, GetInt64(m_end),
- GetInt64(dims().m()));
+ EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end,
+ dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m()));
}
}
+// The loop structure is:
+//
+// Iterate over dimension M as m:
+// Iterate over dimension N as n:
+// Iterate over dimension K as k:
+// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n])
+//
+// I.e. a just a tiled version of a "naive" GEMM.
+//
// The tiling scheme is as follows:
//
// Let the LHS be:
@@ -904,41 +928,48 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverM(
// +-------------------+-------------------+-------------------+---------
// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ...
// +-------------------+-------------------+-------------------+---------
-void MatrixMatrixBlockPanelEmitter::EmitTiledReductionLoop(
+void MatrixMatrixBlockPanelEmitter::EmitTiledGemm(
VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start,
llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end,
int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) {
- ksl_.For("dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
- MemoryTile result_memory_tile(vsl, ir_builder_, /*matrix=*/result_,
- /*matrix_size_along_minor_dim=*/dims().n(),
- /*major_dim_offset=*/m_i,
- /*tile_size_along_major_dim=*/tile_size_m);
- MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
- /*matrix_size_along_minor_dim=*/dims().k(),
- /*major_dim_offset=*/m_i,
- /*tile_size_along_major_dim=*/tile_size_m);
-
- ksl_.For("dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
- MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, dims().n(), k_i,
- tile_size_k);
- std::vector<std::vector<llvm::Value*>> lhs_tile =
- lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
- ksl_.For(
- "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
- std::vector<llvm::Value*> rhs_tile = rhs_memory_tile.LoadTile(n_i);
- std::vector<llvm::Value*> result_tile =
- result_memory_tile.LoadTile(n_i);
- for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
- for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
- result_tile[r_m_i] =
- vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
- result_tile[r_m_i]);
- }
- }
- result_memory_tile.StoreTile(result_tile, n_i);
- });
- });
- });
+ ksl_.ForReturnVoid(
+ "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) {
+ MemoryTile result_memory_tile(
+ vsl, ir_builder_, /*matrix=*/result_,
+ /*matrix_size_along_minor_dim=*/dims().n(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_,
+ /*matrix_size_along_minor_dim=*/dims().k(),
+ /*major_dim_offset=*/m_i,
+ /*tile_size_along_major_dim=*/tile_size_m);
+ ksl_.ForReturnVoid(
+ "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) {
+ TileVariable result_tile_var(vsl,
+ result_memory_tile.LoadTile(n_i));
+ ksl_.ForReturnVoid(
+ "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) {
+ MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_,
+ dims().n(), k_i, tile_size_k);
+ std::vector<std::vector<llvm::Value*>> lhs_tile =
+ lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k);
+ std::vector<llvm::Value*> rhs_tile =
+ rhs_memory_tile.LoadTile(n_i);
+ std::vector<llvm::Value*> result_tile =
+ result_tile_var.Get();
+ for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) {
+ for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) {
+ result_tile[r_m_i] =
+ vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i],
+ result_tile[r_m_i]);
+ }
+ }
+ result_tile_var.Set(result_tile);
+ });
+
+ result_memory_tile.StoreTile(result_tile_var.Get(), n_i);
+ });
+ });
}
} // namespace
@@ -1023,16 +1054,21 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
target, ir_builder_->getInt8(0), size_bytes,
target_machine_features_.minimum_alignment_for_allocation(size_bytes));
- int64 max_vector_width =
+ int64 max_target_vector_width =
target_machine_features_.vector_register_num_elements(
*ir_builder_->GetInsertBlock()->getParent(), primitive_type);
+ int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width;
+ std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) =
+ GetGemmTileSize();
+
MatrixMatrixBlockPanelEmitter::Config config(
/*scalar_type=*/primitive_type,
MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n},
- /*max_vectorization_width=*/max_vector_width,
- /*min_vectorization_width=*/std::min<int64>(4, max_vector_width),
- /*tile_size_m=*/3, /*tile_size_k=*/5);
+ /*max_vectorization_width=*/max_target_vector_width,
+ /*max_vector_count=*/tile_size_n_in_vector_width,
+ /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width),
+ /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k);
VLOG(2) << "Emitting GEBP kernel in LLVM IR with config "
<< config.GetCacheKey();
@@ -1265,8 +1301,11 @@ Status DotOpEmitter::Emit() {
// from messing up the vectorization.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction",
- /*prevent_unrolling=*/lhs_reduction_along_minor_dimension &&
- rhs_reduction_along_minor_dimension);
+ /*unroll_mode=*/
+ (lhs_reduction_along_minor_dimension &&
+ rhs_reduction_along_minor_dimension)
+ ? xla::llvm_ir::UnrollMode::kNoUnroll
+ : xla::llvm_ir::UnrollMode::kDefaultUnroll);
// The final entry in the rhs and lhs indexes is the indvar of the
// reduction loop.
@@ -1341,7 +1380,7 @@ Status DotOpEmitter::Emit() {
// the rhs and lhs indexes with the reduction dimensions removed. The terms
// from the rhs index are the lower dimensions in the index so we add them
// first.
- llvm_ir::IrArray::Index target_index;
+ llvm_ir::IrArray::Index target_index(lhs_index.GetType());
for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
if (dimension != lhs_reduction_dimension) {
target_index.push_back(lhs_index[dimension]);
@@ -1365,10 +1404,13 @@ Status DotOpEmitter::Emit() {
Status DotOpEmitter::EmitScalarDot() {
// A scalar dot is just a scalar multiply.
llvm::Value* result;
+ // Use the same index_type for all tensor accesses in the same kernel.
+ llvm::Type* index_type = ir_builder_->getInt64Ty();
+ llvm_ir::IrArray::Index element_index(index_type);
llvm::Value* lhs_value =
- lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
+ lhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_);
llvm::Value* rhs_value =
- rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
+ rhs_array_.EmitReadArrayElement(/*index=*/element_index, ir_builder_);
if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) {
#define REAL(x) ir_builder_->CreateExtractValue(x, {0})
#define IMAG(x) ir_builder_->CreateExtractValue(x, {1})
@@ -1386,7 +1428,8 @@ Status DotOpEmitter::EmitScalarDot() {
} else {
result = ir_builder_->CreateFMul(lhs_value, rhs_value);
}
- target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
+ target_array_.EmitWriteArrayElement(/*index=*/element_index, result,
+ ir_builder_);
return Status::OK();
}
@@ -1588,8 +1631,8 @@ bool PotentiallyImplementedAsEigenDot(
const Shape& lhs_shape = hlo.operand(0)->shape();
const Shape& rhs_shape = hlo.operand(1)->shape();
- if (ShapeUtil::HasZeroElements(lhs_shape) ||
- ShapeUtil::HasZeroElements(rhs_shape)) {
+ if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
+ ShapeUtil::IsZeroElementArray(rhs_shape)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index d88ccea0db..ed2a18976a 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -143,6 +143,17 @@ class DotOpEmitter {
.value_or(kDefaultTilingFactor);
}
+ std::tuple<int64, int64, int64> GetGemmTileSize() const {
+ // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz
+ //
+ // TODO(b/80093688): Tune for other architectures and centralize this
+ // information in one place.
+ const std::tuple<int64, int64, int64> kDefaultTileSize =
+ std::tuple<int64, int64, int64>(11, 9, 1);
+ return options::LlvmIrGemmTileSize(hlo_module_config_)
+ .value_or(kDefaultTileSize);
+ }
+
// Returns true if we should use an experimental implementation of GEMM
// (general matrix matrix multiplication) if possible.
bool EnableExperimentalLlvmIrGemm() const {
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
deleted file mode 100644
index c562865591..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc
+++ /dev/null
@@ -1,50 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
-
-#include <algorithm>
-#include <cstdlib>
-#include <cstring>
-
-#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
-
-namespace xla {
-namespace cpu {
-void ExternalConstantPool::Insert(string name, const LiteralSlice& literal,
- int64 alignment) {
- CHECK(!ShapeUtil::IsTuple(literal.shape()));
- CHECK(alignment > 0 && IsPowerOfTwo(static_cast<uint64>(alignment)));
- CHECK(entries_.find(name) == entries_.end());
-
- const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape());
- void* raw_pointer = tensorflow::port::AlignedMalloc(
- literal_size, std::max<size_t>(alignment, sizeof(void*)));
- CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size
- << " bytes with alignment of " << alignment;
-
- std::memcpy(raw_pointer, literal.untyped_data(), literal_size);
- entries_.emplace(std::move(name), static_cast<uint8*>(raw_pointer));
-}
-
-const uint8* ExternalConstantPool::Find(const string& name) {
- auto it = entries_.find(name);
- return it == entries_.end() ? nullptr : it->second.get();
-}
-} // namespace cpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
deleted file mode 100644
index 0677f5f0b5..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h
+++ /dev/null
@@ -1,65 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
-
-#include <memory>
-
-#include "tensorflow/compiler/xla/literal_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/platform/mem.h"
-
-namespace xla {
-namespace cpu {
-// An ExternalConstantPool maintains a set of constants kept external to
-// generated LLVM IR. These constants are accessed from the IR via globals with
-// extern linkage. This current incarnation of ExternalConstantPool only
-// supports the JIT CPU backend; the AOT backend is not supported.
-//
-// Implementation-wise, this is a simple wrapper around a map of strings to byte
-// buffers. This simply implementation works in a JIT scenario. This class
-// will have to become smarter if we decide to support external constant pools
-// on AOT compiles in the future.
-class ExternalConstantPool {
- public:
- // Inserts a buffer with the contents of `literal` into the constant pool with
- // the name `name`. It is an error to try to insert two constants with the
- // same `name` into the same constant pool. The buffer for literal is aligned
- // to `aligment` bytes, and `alignment` must be a power of 2.
- //
- // The constant pool copies out the contents of `literal` into a buffer it
- // owns -- it does not keep pointers to `literal`, or to memory owned by
- // `literal`.
- void Insert(string name, const LiteralSlice& literal, int64 alignment);
-
- // Find the constant with name `name` in this constant pool. If there isn't
- // such constant, return nullptr.
- const uint8* Find(const string& name);
-
- private:
- // We need to `AlignedFree` pointers allocated into `entries_` since we
- // allocate them with `AlignedMalloc`.
- struct FreeDeleter {
- void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); }
- };
-
- tensorflow::gtl::FlatMap<string, std::unique_ptr<uint8, FreeDeleter>>
- entries_;
-};
-} // namespace cpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_
diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc
deleted file mode 100644
index 9290a4e5df..0000000000
--- a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
-#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace xla {
-namespace cpu {
-namespace {
-class ExternalConstantPoolTest : public ::testing::Test {};
-
-template <typename T>
-T GetFromBuffer(const uint8* buffer, int64 index) {
- T result;
- std::memcpy(&result, buffer + index * sizeof(T), sizeof(T));
- return result;
-}
-
-TEST(ExternalConstantPoolTest, Basic) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}});
- constant_pool.Insert("name-0", *literal, 4);
- const uint8* constant = constant_pool.Find("name-0");
- ASSERT_NE(constant, nullptr);
-
- EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 2);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 3);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4);
-
- EXPECT_EQ(constant_pool.Find("name-1"), nullptr);
-}
-
-TEST(ExternalConstantPoolTest, RowMinorLayout) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
- const auto literal = Literal::CreateR2WithLayout(
- {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
- constant_pool.Insert("name-0", *literal, 4);
- const uint8* constant = constant_pool.Find("name-0");
- ASSERT_NE(constant, nullptr);
-
- EXPECT_EQ(GetFromBuffer<int32>(constant, 0), 1);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 1), 3);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 2), 2);
- EXPECT_EQ(GetFromBuffer<int32>(constant, 3), 4);
-}
-
-TEST(ExternalConstantPoolTest, Alignment) {
- ExternalConstantPool constant_pool;
- EXPECT_EQ(constant_pool.Find("name-0"), nullptr);
-
- for (int i = 0; i < 8; i++) {
- int64 alignment = 1 << i;
- string name = tensorflow::strings::StrCat("name-", i);
-
- const auto literal = Literal::CreateR2({{1, 2}, {3, 4}});
- constant_pool.Insert(name, *literal, alignment);
-
- const uint8* constant = constant_pool.Find(name);
- ASSERT_NE(constant, nullptr);
- EXPECT_EQ(reinterpret_cast<intptr_t>(constant) % alignment, 0);
- }
-}
-
-} // namespace
-} // namespace cpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index b560b7531c..1a8bedfe6a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution(
return false;
}
- if (ShapeUtil::HasZeroElements(input_shape) ||
- ShapeUtil::HasZeroElements(kernel_shape)) {
+ if (ShapeUtil::IsZeroElementArray(input_shape) ||
+ ShapeUtil::IsZeroElementArray(kernel_shape)) {
return false;
}
// Make sure input and kernel has the same data type.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
index abb2471e6a..530ebce854 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -35,7 +35,7 @@ ENTRY Conv {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* entry_computation = module->entry_computation();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 13bd5e73db..2ad41374d3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -48,6 +48,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
@@ -83,8 +85,7 @@ IrEmitter::IrEmitter(
llvm::Module* llvm_module,
std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
- const TargetMachineFeatures* target_machine_features,
- ExternalConstantPool* external_constant_pool)
+ const TargetMachineFeatures* target_machine_features)
: assignment_(assignment),
module_(llvm_module),
arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
@@ -94,8 +95,7 @@ IrEmitter::IrEmitter(
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
hlo_module_config_(hlo_module.config()),
is_top_level_computation_(false),
- target_machine_features_(*target_machine_features),
- external_constant_pool_(external_constant_pool) {
+ target_machine_features_(*target_machine_features) {
ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_enable_fast_math()));
@@ -160,47 +160,25 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
return Status::OK();
}
-llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
- llvm::GlobalVariable* result;
-
- // We avoid creating large constants in the LLVM IR since LLVM is not
- // efficient for large constant arrays. We still emit "small enough" constant
- // arrays into the Ir, in the off chance the LLVM optimizer can do something
- // interesting with it.
- const int kMaxInternalConstantSizeInBytes = 128;
- if (external_constant_pool_ &&
- ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) {
- string global_name = tensorflow::strings::StrCat(
- "constant_global_", external_global_constant_counter_++);
- result = new llvm::GlobalVariable(
- /*Module=*/*module_,
- /*Type=*/IrShapeType(literal.shape()),
- /*isConstant=*/true,
- /*Linkage=*/llvm::GlobalValue::ExternalLinkage,
- /*Initializer=*/nullptr,
- /*Name=*/AsStringRef(global_name));
- result->setAlignment(MinimumAlignmentForShape(literal.shape()));
- external_constant_pool_->Insert(global_name, literal,
- MinimumAlignmentForShape(literal.shape()));
- } else {
- llvm::Constant* initializer =
- llvm_ir::ConvertLiteralToIrConstant(literal, module_);
- result = new llvm::GlobalVariable(
- /*Module=*/*module_,
- /*Type=*/initializer->getType(),
- /*isConstant=*/true,
- /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/initializer,
- /*Name=*/"");
- result->setAlignment(MinimumAlignmentForShape(literal.shape()));
- }
- return result;
+llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) {
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, module_);
+ llvm::GlobalVariable* result_global = new llvm::GlobalVariable(
+ /*Module=*/*module_,
+ /*Type=*/initializer->getType(),
+ /*isConstant=*/true,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/initializer,
+ /*Name=*/"");
+ result_global->setAlignment(MinimumAlignmentForShape(literal.shape()));
+ return llvm::ConstantExpr::getBitCast(
+ result_global, IrShapeType(literal.shape())->getPointerTo());
}
Status IrEmitter::HandleConstant(HloInstruction* constant) {
VLOG(2) << "HandleConstant: " << constant->ToString();
const Literal& literal = constant->literal();
- llvm::GlobalVariable* global_for_const;
+ llvm::Constant* global_for_const;
auto it = emitted_literals_.find(&literal);
if (it != emitted_literals_.end()) {
@@ -221,10 +199,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
return EmitMemcpy(*(copy->operand(0)), *copy);
- } else {
- // Use the elemental emitter for non-tuple shapes.
+ } else if (ShapeUtil::IsArray(copy->shape())) {
+ // Use the elemental emitter for array shapes.
return DefaultAction(copy);
}
+ return Unimplemented(
+ "unsupported operand type %s for copy instruction",
+ PrimitiveType_Name(copy->shape().element_type()).c_str());
}
// Calculate the alignment of a buffer allocated for a given primitive type.
@@ -298,45 +279,60 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
Status IrEmitter::HandleSelect(HloInstruction* select) {
auto pred = select->operand(0);
- auto on_true = select->operand(1);
- auto on_false = select->operand(2);
TF_RET_CHECK(pred->shape().element_type() == PRED);
-
- if (ShapeUtil::IsTuple(select->shape())) {
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select));
- llvm_ir::EmitTupleSelect(
- GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true),
- GetEmittedValueFor(on_false), &ir_builder_, module_);
- return Status::OK();
- }
-
return DefaultAction(select);
}
-Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
- VLOG(2) << "HandleInfeed: " << infeed->ToString();
+Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
+ auto pred = tuple_select->operand(0);
+ auto on_true = tuple_select->operand(1);
+ auto on_false = tuple_select->operand(2);
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+ TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
+ TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape()));
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple_select));
+ llvm_ir::EmitTupleSelect(GetIrArrayFor(tuple_select), GetIrArrayFor(pred),
+ GetEmittedValueFor(on_true),
+ GetEmittedValueFor(on_false), &ir_builder_, module_);
+ return Status::OK();
+}
- const Shape& shape = infeed->shape();
+Status IrEmitter::HandleInfeed(HloInstruction* instruction) {
+ HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
+ VLOG(2) << "HandleInfeed: " << infeed->ToString();
- // The infeed operation produces data (dequeued from the infeed queue) at this
- // address, which has been provided by buffer assignment.
+ // The infeed operation produces a two-element tuple containing data and a
+ // token value. HloInfeedInstruction::infeed_shape gives us the data shape.
+ const Shape& data_shape = infeed->infeed_shape();
+ DCHECK(ShapeUtil::Equal(data_shape,
+ ShapeUtil::GetTupleElementShape(infeed->shape(), 0)));
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed));
- llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed);
- if (ShapeUtil::IsTuple(shape)) {
- TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape));
+ // Write the tuple index table.
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
+ assignment_.GetUniqueSlice(infeed, {0}));
+ llvm::Value* data_address = EmitTempBufferPointer(data_slice, data_shape);
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice token_slice,
+ assignment_.GetUniqueSlice(infeed, {1}));
+ llvm::Value* token_address = EmitTempBufferPointer(
+ token_slice, ShapeUtil::GetTupleElementShape(infeed->shape(), 1));
+ llvm_ir::EmitTuple(GetIrArrayFor(infeed), {data_address, token_address},
+ &ir_builder_, module_);
+
+ if (ShapeUtil::IsTuple(data_shape)) {
+ TF_RET_CHECK(!ShapeUtil::IsNestedTuple(data_shape));
// For a tuple, we first copy each of the internal elements to
// their corresponding target locations. We then construct the
// tuple outer buffer containing pointers to the internal
// elements.
std::vector<llvm::Value*> tuple_element_addresses;
- for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
+ for (int64 i = 0; i < data_shape.tuple_shapes_size(); ++i) {
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer,
- assignment_.GetUniqueSlice(infeed, {i}));
+ assignment_.GetUniqueSlice(infeed, {0, i}));
const Shape& tuple_element_shape =
- ShapeUtil::GetTupleElementShape(shape, i);
+ ShapeUtil::GetTupleElementShape(data_shape, i);
// Only the outer tuple buffer's target address is obtained from
// GetEmittedValueFor, to handle the case when Infeed is the root
@@ -351,11 +347,11 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
tuple_element_addresses.push_back(tuple_element_address);
}
- llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_,
- module_);
+ llvm_ir::EmitTuple(llvm_ir::IrArray(data_address, data_shape),
+ tuple_element_addresses, &ir_builder_, module_);
} else {
- TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape,
- GetEmittedValueFor(infeed)));
+ TF_RETURN_IF_ERROR(
+ EmitXfeedTransfer(XfeedKind::kInfeed, data_shape, data_address));
}
return Status::OK();
@@ -480,42 +476,111 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
+ llvm::Function* mapped_ir_function =
+ FindOrDie(emitted_functions_, map->to_apply());
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : map->operands()) {
+ const llvm_ir::IrArray& array = GetIrArrayFor(operand);
+ parameter_addresses.push_back(
+ array.EmitArrayElementAddress(index, &ir_builder_));
+ }
+ return EmitElementFunctionCall(mapped_ir_function, map->shape(),
+ parameter_addresses, "map_function");
+}
+
Status IrEmitter::HandleMap(HloInstruction* map) {
- gtl::ArraySlice<HloInstruction*> operands(map->operands());
- HloComputation* function = map->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
-
- return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function](
- const llvm_ir::IrArray::Index& index) {
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : operands) {
- const llvm_ir::IrArray& array = GetIrArrayFor(operand);
- parameter_addresses.push_back(
- array.EmitArrayElementAddress(index, &ir_builder_));
- }
- return EmitElementFunctionCall(mapped_ir_function, map->shape(),
- parameter_addresses, "map_function");
+ return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
});
}
-Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
- auto operand = reduce_window->operand(0);
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
HloComputation* function = reduce_window->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+
+ // We fold inputs into the accumulator and initialize it to
+ // the initial value on the reduce_window.
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
+ "reduce_window_accumulator_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(operand_element_type));
+ ir_builder_.CreateStore(
+ ir_builder_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))),
+ accumulator_address);
+
+ llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), index.size());
+ llvm::Value* in_bounds_condition = nullptr;
+ for (size_t i = 0; i < index.size(); ++i) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ ir_builder_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to check if 0 <= input_index[i] < bound, as otherwise we are in
+ // the padding so that we can skip the computation. That is equivalent to
+ // input_index[i] < bound as an *unsigned* comparison, since a negative
+ // value will wrap to a large positive value.
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ input_index[i],
+ ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = index_condition;
+ } else {
+ in_bounds_condition =
+ ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ }
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ llvm_ir::IrArray input_array(GetIrArrayFor(operand));
+ llvm::Value* input_value_address =
+ input_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce_window->shape(),
+ {accumulator_address, input_value_address}, "reducer_function");
+ ir_builder_.CreateStore(result, accumulator_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_address);
+}
+
+Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
- /*instruction=*/*reduce_window, /*operands=*/{operand},
+ /*instruction=*/*reduce_window,
+ /*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32}));
// TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
+ if (window_util::HasDilation(reduce_window->window())) {
return Unimplemented(
"Dilation for ReduceWindow is not implemented on CPU.");
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
@@ -530,72 +595,9 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// This is completely un-optimized and just here to have something
// that works.
return EmitTargetElementLoop(
- reduce_window, [this, reduce_window, operand, window,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // We fold inputs into the accumulator and initialize it to
- // the initial value on the reduce_window.
- PrimitiveType operand_element_type = operand->shape().element_type();
- llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accumulator_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(operand_element_type));
- ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
- reduce_window->operand(1))),
- accumulator_address);
-
- llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"),
- &ir_builder_);
- std::vector<int64> window_size;
- for (const auto& dim : window.dimensions()) {
- window_size.push_back(dim.size());
- }
- const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
- ShapeUtil::MakeShape(operand_element_type, window_size), "window");
- CHECK_EQ(window_index.size(), index.size());
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- llvm_ir::IrArray::Index input_index(index.size());
- llvm::Value* in_bounds_condition = nullptr;
- for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
- input_index[i] = ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- ir_builder_.getInt64(window.dimensions(i).padding_low()));
-
- // We need to check if 0 <= input_index[i] < bound, as
- // otherwise we are in the padding so that we can skip the
- // computation. That is equivalent to input_index[i] < bound
- // as an *unsigned* comparison, since a negative value will
- // wrap to a large positive value.
- llvm::Value* index_condition = ir_builder_.CreateICmpULT(
- input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
- operand->shape(), i)));
- if (in_bounds_condition == nullptr) {
- in_bounds_condition = index_condition;
- } else {
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, index_condition);
- }
- }
- CHECK(in_bounds_condition != nullptr);
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value_address =
- input_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce_window->shape(),
- {accumulator_address, input_value_address}, "reducer_function");
- ir_builder_.CreateStore(result, accumulator_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_address);
+ reduce_window, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduceWindow(
+ Cast<HloReduceWindowInstruction>(reduce_window), index);
});
}
@@ -686,7 +688,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(source_index.size());
+ llvm_ir::IrArray::Index operand_index(ir_builder_.getInt64Ty(),
+ source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getTrue();
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
@@ -760,7 +763,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// value and the current output value.
SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
- llvm_ir::IrArray::Index selected_index;
+ llvm_ir::IrArray::Index selected_index(source_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
@@ -823,17 +826,157 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_machine_features_);
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* lhs = convolution->operand(0);
+ const HloInstruction* rhs = convolution->operand(1);
+ const Window& window = convolution->window();
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+ int num_spatial_dims = dnums.output_spatial_dimensions_size();
+ std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
+ }
+ llvm::Value* output_feature = index[dnums.output_feature_dimension()];
+ llvm::Value* batch = index[dnums.output_batch_dimension()];
+
+ // We will accumulate the products into this sum to calculate the output entry
+ // at the given index.
+ PrimitiveType lhs_element_type = lhs->shape().element_type();
+ llvm::Type* lhs_llvm_type =
+ llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
+ llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ lhs_llvm_type, "convolution_sum_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(lhs_element_type));
+ llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type);
+ ir_builder_.CreateStore(constant_zero, sum_address);
+
+ llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
+ std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial[i] =
+ loops
+ .AddLoop(
+ 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)),
+ tensorflow::strings::StrCat("k", i))
+ ->GetIndVarValue();
+ }
+ llvm::Value* input_feature =
+ loops
+ .AddLoop(0, lhs->shape().dimensions(dnums.input_feature_dimension()),
+ "iz")
+ ->GetIndVarValue();
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Calculate the spatial index in the input array, taking striding, dilation
+ // and padding into account. An index in the padding will be out of the bounds
+ // of the array.
+ const auto calculate_input_index = [this](llvm::Value* output_index,
+ llvm::Value* kernel_index,
+ const WindowDimension& window_dim) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ output_index, ir_builder_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
+ kernel_index, ir_builder_.getInt64(window_dim.window_dilation()));
+ return ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
+ ir_builder_.getInt64(window_dim.padding_low()));
+ };
+ std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] = calculate_input_index(
+ output_spatial[i], kernel_spatial[i], window.dimensions(i));
+ }
+
+ // We need to check if 0 <= input dim < bound, as otherwise we are in the
+ // padding so that we can skip the computation. That is equivalent to input
+ // dim < bound as an *unsigned* comparison, since a negative value will wrap
+ // to a large positive value. The input dim is dilated, so we need to dilate
+ // the bound as well to match.
+
+ // Also need to check that the input coordinates are not in one of the
+ // holes created by base dilation.
+ const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) {
+ llvm::Value* remainder = ir_builder_.CreateSRem(
+ input_index, ir_builder_.getInt64(base_dilation));
+ return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
+ };
+
+ llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ llvm::ConstantInt* input_bound =
+ ir_builder_.getInt64(window_util::DilatedBound(
+ lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
+ window.dimensions(i).base_dilation()));
+ llvm::Value* dim_in_bound =
+ ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_not_in_hole =
+ not_in_hole(input_spatial[i], window.dimensions(i).base_dilation());
+ llvm::Value* dim_ok = ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition = ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
+ }
+
+ // Now we need to map the dilated base coordinates back to the actual
+ // data indices on the lhs.
+ const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) {
+ return ir_builder_.CreateSDiv(input_index,
+ ir_builder_.getInt64(base_dilation));
+ };
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] =
+ undilate(input_spatial[i], window.dimensions(i).base_dilation());
+ }
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ int num_dims = num_spatial_dims + 2;
+ llvm_ir::IrArray::Index input_index(ir_builder_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
+ }
+ input_index[dnums.input_feature_dimension()] = input_feature;
+ input_index[dnums.input_batch_dimension()] = batch;
+
+ llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
+ llvm_ir::IrArray::Index kernel_index(ir_builder_.getInt64Ty(), num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_index[dnums.kernel_spatial_dimensions(i)] =
+ window.dimensions(i).window_reversal()
+ ? ir_builder_.CreateNSWSub(
+ ir_builder_.getInt64(window.dimensions(i).size() - 1),
+ kernel_spatial[i])
+ : kernel_spatial[i];
+ }
+
+ kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
+ kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+ llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
+ llvm::Value* product = ir_builder_.CreateFMul(
+ input_array.EmitReadArrayElement(input_index, &ir_builder_),
+ kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
+ llvm::Value* sum =
+ ir_builder_.CreateFAdd(ir_builder_.CreateLoad(sum_address), product);
+ ir_builder_.CreateStore(sum, sum_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(sum_address);
+}
+
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
- const auto& window = convolution->window();
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F16, F32, C64}));
- const ConvolutionDimensionNumbers& dnums =
- convolution->convolution_dimension_numbers();
-
// TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support
// different data layouts.
if (PotentiallyImplementedAsEigenConvolution(*convolution,
@@ -990,149 +1133,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
// See the description of convolution in the XLA documentation for the pseudo
// code for convolution.
return EmitTargetElementLoop(
- convolution, [this, convolution, lhs, rhs, window,
- dnums](const llvm_ir::IrArray::Index& index) {
- int num_spatial_dims = dnums.output_spatial_dimensions_size();
- std::vector<llvm::Value*> output_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- output_spatial[i] = index[dnums.output_spatial_dimensions(i)];
- }
- llvm::Value* output_feature = index[dnums.output_feature_dimension()];
- llvm::Value* batch = index[dnums.output_batch_dimension()];
-
- // We will accumulate the products into this sum to calculate
- // the output entry at the given index.
- PrimitiveType lhs_element_type = lhs->shape().element_type();
- llvm::Type* lhs_llvm_type =
- llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_);
- llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
- lhs_llvm_type, "convolution_sum_address", &ir_builder_,
- MinimumAlignmentForPrimitiveType(lhs_element_type));
- llvm::Value* constant_zero =
- llvm::Constant::getNullValue(lhs_llvm_type);
- ir_builder_.CreateStore(constant_zero, sum_address);
-
- llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_);
- std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_spatial[i] =
- loops
- .AddLoop(0,
- rhs->shape().dimensions(
- dnums.kernel_spatial_dimensions(i)),
- tensorflow::strings::StrCat("k", i))
- ->GetIndVarValue();
- }
- llvm::Value* input_feature =
- loops
- .AddLoop(
- 0, lhs->shape().dimensions(dnums.input_feature_dimension()),
- "iz")
- ->GetIndVarValue();
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Calculate the spatial index in the input array, taking striding,
- // dilation and padding into account. An index in the padding will be
- // out of the bounds of the array.
- const auto calculate_input_index =
- [this](llvm::Value* output_index, llvm::Value* kernel_index,
- const WindowDimension& window_dim) {
- llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- output_index, ir_builder_.getInt64(window_dim.stride()));
- llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
- kernel_index,
- ir_builder_.getInt64(window_dim.window_dilation()));
- return ir_builder_.CreateNSWSub(
- ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
- ir_builder_.getInt64(window_dim.padding_low()));
- };
- std::vector<llvm::Value*> input_spatial(num_spatial_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] = calculate_input_index(
- output_spatial[i], kernel_spatial[i], window.dimensions(i));
- }
-
- // We need to check if 0 <= input dim < bound, as otherwise we are in
- // the padding so that we can skip the computation. That is equivalent
- // to input dim < bound as an *unsigned* comparison, since a negative
- // value will wrap to a large positive value. The input dim is dilated,
- // so we need to dilate the bound as well to match.
-
- // Also need to check that the input coordinates are not in one of the
- // holes created by base dilation.
- const auto not_in_hole = [&](llvm::Value* input_index,
- int64 base_dilation) {
- llvm::Value* remainder = ir_builder_.CreateSRem(
- input_index, ir_builder_.getInt64(base_dilation));
- return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
- };
-
- llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
- for (int i = 0; i < num_spatial_dims; ++i) {
- llvm::ConstantInt* input_bound =
- ir_builder_.getInt64(window_util::DilatedBound(
- lhs->shape().dimensions(dnums.input_spatial_dimensions(i)),
- window.dimensions(i).base_dilation()));
- llvm::Value* dim_in_bound =
- ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
- llvm::Value* dim_not_in_hole = not_in_hole(
- input_spatial[i], window.dimensions(i).base_dilation());
- llvm::Value* dim_ok =
- ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
- in_bounds_condition =
- ir_builder_.CreateAnd(in_bounds_condition, dim_ok);
- }
-
- // Now we need to map the dilated base coordinates back to the actual
- // data indices on the lhs.
- const auto undilate = [&](llvm::Value* input_index,
- int64 base_dilation) {
- return ir_builder_.CreateSDiv(input_index,
- ir_builder_.getInt64(base_dilation));
- };
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_spatial[i] =
- undilate(input_spatial[i], window.dimensions(i).base_dilation());
- }
-
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- in_bounds_condition, "in-bounds", &ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
-
- // We are not in the padding, so carry out the computation.
- int num_dims = num_spatial_dims + 2;
- llvm_ir::IrArray::Index input_index(num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i];
- }
- input_index[dnums.input_feature_dimension()] = input_feature;
- input_index[dnums.input_batch_dimension()] = batch;
-
- llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs));
- llvm_ir::IrArray::Index kernel_index(num_dims);
- for (int i = 0; i < num_spatial_dims; ++i) {
- kernel_index[dnums.kernel_spatial_dimensions(i)] =
- window.dimensions(i).window_reversal()
- ? ir_builder_.CreateNSWSub(
- ir_builder_.getInt64(window.dimensions(i).size() - 1),
- kernel_spatial[i])
- : kernel_spatial[i];
- }
-
- kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
- kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
-
- llvm_ir::IrArray input_array(GetIrArrayFor(lhs));
- llvm::Value* product = ir_builder_.CreateFMul(
- input_array.EmitReadArrayElement(input_index, &ir_builder_),
- kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
- llvm::Value* sum = ir_builder_.CreateFAdd(
- ir_builder_.CreateLoad(sum_address), product);
- ir_builder_.CreateStore(sum, sum_address);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(sum_address);
+ convolution, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForConvolution(
+ Cast<HloConvolutionInstruction>(convolution), index);
});
}
@@ -1421,6 +1424,10 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator(
return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); };
+ case HloOpcode::kXor:
+ return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
+ llvm::Value* rhs) { return ir_builder->CreateXor(lhs, rhs); };
+
case HloOpcode::kMaximum:
return [root_is_floating_point, root_is_signed](
llvm::IRBuilder<>* ir_builder, llvm::Value* lhs,
@@ -1677,7 +1684,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
// }
llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_);
- llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size());
+ llvm_ir::IrArray::Index array_index(ir_builder_.getInt64Ty(),
+ reduce->shape().dimensions_size());
for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0;
--i) {
int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i);
@@ -1764,6 +1772,64 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
return true;
}
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index) {
+ const HloInstruction* arg = reduce->mutable_operand(0);
+ const HloInstruction* init_value = reduce->mutable_operand(1);
+ gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ HloComputation* function = reduce->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+
+ // Initialize an accumulator with init_value.
+ PrimitiveType accumulator_type = reduce->shape().element_type();
+ llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator",
+ &ir_builder_, MinimumAlignmentForPrimitiveType(accumulator_type));
+ llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
+ llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
+ ir_builder_.CreateStore(load_init_value, accumulator_addr);
+
+ // The enclosing loops go over all the target elements. Now we have to compute
+ // the actual target element. For this, we build a new loop nest to iterate
+ // over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction Value*s
+ // are placed for each dimension in dimensions, and all the rest are nullptrs.
+ llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
+ const llvm_ir::IrArray::Index reduced_dims_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Build a full index for the input argument, using reduced_dims_index as the
+ // base. In reduced_dims_index only the reduction dimensions are filled in. We
+ // fill in the rest of the dimensions with induction Value*s taken from
+ // 'index' which iterates over the target array. See the high-level
+ // description in the XLA documentation for details.
+ llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
+ llvm_ir::IrArray::Index input_index = reduced_dims_index;
+ llvm_ir::IrArray::Index::const_iterator it = index.begin();
+
+ for (size_t i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+
+ // Apply the reduction function to the loaded value.
+ llvm::Value* input_address =
+ arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce->shape(), {accumulator_addr, input_address},
+ "reduce_function");
+ ir_builder_.CreateStore(result, accumulator_addr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_addr);
+}
+
Status IrEmitter::HandleReduce(HloInstruction* reduce) {
auto arg = reduce->mutable_operand(0);
auto init_value = reduce->mutable_operand(1);
@@ -1785,61 +1851,11 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
}
}
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
- return EmitTargetElementLoop(
- reduce, [this, reduce, arg, init_value, dimensions,
- reducer_function](const llvm_ir::IrArray::Index& index) {
- // Initialize an accumulator with init_value.
- PrimitiveType accumulator_type = reduce->shape().element_type();
- llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_),
- "accumulator", &ir_builder_,
- MinimumAlignmentForPrimitiveType(accumulator_type));
- llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
- llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
- ir_builder_.CreateStore(load_init_value, accumulator_addr);
-
- // The enclosing loops go over all the target elements. Now we have to
- // compute the actual target element. For this, we build a new loop nest
- // to iterate over all the reduction dimensions in the argument.
- // AddLoopsForShapeOnDimensions will return an Index where induction
- // Value*s are placed for each dimension in dimensions, and all the rest
- // are nullptrs.
- llvm_ir::ForLoopNest loops(IrName(reduce, "inner"), &ir_builder_);
- const llvm_ir::IrArray::Index reduced_dims_index =
- loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
- "reduction_dim");
-
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Build a full index for the input argument, using reduced_dims_index
- // as the base. In reduced_dims_index only the reduction dimensions are
- // filled in. We fill in the rest of the dimensions with induction
- // Value*s taken from 'index' which iterates over the target array.
- // See the high-level description in the XLA documentation for details.
- llvm_ir::IrArray arg_array(GetIrArrayFor(arg));
- llvm_ir::IrArray::Index input_index = reduced_dims_index;
- llvm_ir::IrArray::Index::const_iterator it = index.begin();
-
- for (size_t i = 0; i < input_index.size(); ++i) {
- if (input_index[i] == nullptr) {
- input_index[i] = *it++;
- }
- }
- CHECK(index.end() == it);
-
- // Apply the reduction function to the loaded value.
- llvm::Value* input_address =
- arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce->shape(),
- {accumulator_addr, input_address}, "reduce_function");
- ir_builder_.CreateStore(result, accumulator_addr);
-
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
- return ir_builder_.CreateLoad(accumulator_addr);
- });
+ return EmitTargetElementLoop(reduce,
+ [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForReduce(
+ Cast<HloReduceInstruction>(reduce), index);
+ });
}
Status IrEmitter::HandleSend(HloInstruction* send) {
@@ -1868,7 +1884,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
- if (ShapeUtil::HasZeroElements(slice->shape())) {
+ if (ShapeUtil::IsZeroElementArray(slice->shape())) {
return Status::OK();
}
@@ -2061,7 +2077,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
// Compute the output index the operand element should be assigned to.
// output_index := edge_padding_low + operand_index * (interior_padding + 1)
const PaddingConfig& padding_config = pad->padding_config();
- llvm_ir::IrArray::Index output_index;
+ llvm_ir::IrArray::Index output_index(operand_index.GetType());
for (size_t i = 0; i < operand_index.size(); ++i) {
llvm::Value* offset = ir_builder_.CreateMul(
operand_index[i],
@@ -2523,6 +2539,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
return Status::OK();
}
+Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
+ TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0);
+ // No code to generate, but we need to emit an address for book-keeping.
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token));
+ return Status::OK();
+}
+
Status IrEmitter::FinishVisit(HloInstruction* root) {
// When this method is called, we should have already emitted an IR value for
// the root (return) op. The IR value holds the address of the buffer holding
@@ -2804,7 +2827,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = compute_function_->result_arg();
- if (!ShapeUtil::IsNil(target_shape)) {
+ if ((ShapeUtil::IsArray(target_shape) &&
+ !ShapeUtil::IsZeroElementArray(target_shape)) ||
+ (ShapeUtil::IsTuple(target_shape) &&
+ !ShapeUtil::IsEmptyTuple(target_shape))) {
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index f49cfc1dc3..419f19c24d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -30,12 +30,12 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -67,17 +67,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// index in the profiling array.
// computation_to_profile_idx: the mapping from HLO computations to their
// index in the profiling array.
- // external_constant_pool: if non-null, points to an ExternalConstantPool
- // instance into which the Ir emitter can spill
- // constants.
IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
llvm::Module* llvm_module,
std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx,
std::unordered_map<const HloComputation*, int64>
computation_to_profile_idx,
- const TargetMachineFeatures* target_machine,
- ExternalConstantPool* external_constant_pool);
+ const TargetMachineFeatures* target_machine);
~IrEmitter() override;
// Emit and return the given HLO computation as an LLVM IR
@@ -122,6 +118,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleCopy(HloInstruction* copy) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
@@ -150,6 +147,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
+ Status HandleAfterAll(HloInstruction* gen_token) override;
Status FinishVisit(HloInstruction* root) override;
Status Preprocess(HloInstruction* hlo) override;
@@ -517,6 +515,17 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Returns the number of bytes within the shape.
int64 ByteSizeOf(const Shape& shape) const;
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduceWindow(
+ HloReduceWindowInstruction* reduce_window,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForConvolution(
+ HloConvolutionInstruction* convolution,
+ const llvm_ir::IrArray::Index& index);
+ StatusOr<llvm::Value*> EmitTargetElementLoopBodyForReduce(
+ HloReduceInstruction* reduce, const llvm_ir::IrArray::Index& index);
+
enum class XfeedKind {
kInfeed,
kOutfeed,
@@ -527,7 +536,8 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value* program_buffer_address);
- llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal);
+ // Returns a ConstExpr bitcast.
+ llvm::Constant* EmitGlobalForLiteral(const Literal& literal);
const HloModuleConfig& hlo_module_config_;
@@ -535,9 +545,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const TargetMachineFeatures& target_machine_features_;
- int64 external_global_constant_counter_ = 0;
- ExternalConstantPool* external_constant_pool_;
-
struct LiteralPtrHashFunctor {
size_t operator()(const Literal* literal) const { return literal->Hash(); }
};
@@ -548,7 +555,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
}
};
- tensorflow::gtl::FlatMap<const Literal*, llvm::GlobalVariable*,
+ tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*,
LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
emitted_literals_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index 54af40506d..59ae5acd8b 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -31,13 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter(
std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name) {
+ tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ CHECK_NE(index_type, nullptr);
+
CHECK(!ShapeUtil::IsTuple(shape_));
CHECK(!ShapeUtil::IsScalar(shape_));
llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_);
const int64 num_dims = shape_.dimensions_size();
- llvm_ir::IrArray::Index array_index(num_dims);
+ llvm_ir::IrArray::Index array_index(index_type, num_dims);
// Add loops from outer-most to inner-most dimensions.
for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index 755715634a..25e182a26d 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name) override;
+ tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
private:
const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index fc2efbaf9a..36c9f74385 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -110,8 +110,9 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) {
const string hlo_string = R"(
HloModule TestTaskParallel_infeed_outfeed
ENTRY InfeedOutfeed {
- infeed0 = u32[12345678,2]{1,0} infeed()
- ROOT outfeed0 = u32[12345678,2]{1,0} outfeed(infeed0)
+ infeed0 = (u32[12345678,2]{1,0}, token[]) infeed()
+ infeed0.data = u32[12345678,2]{1,0} get-tuple-element((u32[12345678,2]{1,0}, token[]) infeed0), index=0
+ ROOT outfeed0 = token[] outfeed(infeed0.data)
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
index 92da5f71c2..f8c8dd5e93 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "third_party/intel_mkl_ml/include/mkl_cblas.h"
#include "third_party/intel_mkl_ml/include/mkl_service.h"
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
index 167aa4adda..d9e8dcaed9 100644
--- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -38,20 +38,21 @@ int main(int argc, char** argv) {
// Transfer parameters.
std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<xla::GlobalData> param0_data =
client->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> param1_literal = xla::Literal::CreateR2<float>(
- {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
std::unique_ptr<xla::GlobalData> param1_data =
client->TransferToServer(*param1_literal).ConsumeValueOrDie();
// Build computation.
xla::XlaBuilder builder("");
- auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto add = builder.Add(p1, p0, {0});
+ auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Add(p1, p0, {0});
xla::StatusOr<xla::XlaComputation> computation_status = builder.Build();
xla::XlaComputation computation = computation_status.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index c4c90515ac..be772cfb7e 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -127,13 +127,6 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
}
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
- if (const uint8* from_constant_pool =
- external_constant_pool_.Find(string(name))) {
- return llvm::JITEvaluatedSymbol(
- reinterpret_cast<uint64_t>(from_constant_pool),
- llvm::JITSymbolFlags::None);
- }
-
void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
if (func_addr == nullptr) {
return nullptr;
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
index 1851a3ee0b..d74b63fcf4 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
-#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
@@ -91,10 +90,6 @@ class SimpleOrcJIT {
llvm::TargetMachine* target_machine() const { return target_machine_.get(); }
- ExternalConstantPool* external_constant_pool() {
- return &external_constant_pool_;
- }
-
// Creates an llvm::TargetMachine suitable for JITting code that will run on
// the current machine.
static std::unique_ptr<llvm::TargetMachine> InferTargetMachineForJIT(
@@ -112,7 +107,6 @@ class SimpleOrcJIT {
std::shared_ptr<llvm::orc::SymbolResolver> symbol_resolver_;
ObjLayerT object_layer_;
CompileLayerT compile_layer_;
- ExternalConstantPool external_constant_pool_;
};
} // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 67f776e7b5..b4c33e2f6c 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -40,7 +40,7 @@ tf_cc_test(
name = "cpu_fusion_test",
srcs = ["cpu_fusion_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -82,7 +82,7 @@ tf_cc_test(
name = "cpu_noalias_test",
srcs = ["cpu_noalias_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -128,7 +128,7 @@ tf_cc_test(
name = "cpu_infeed_test",
srcs = ["cpu_infeed_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
@@ -152,9 +152,9 @@ tf_cc_test(
srcs = ["cpu_literal_caching_test.cc"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -166,9 +166,9 @@ tf_cc_test(
srcs = ["cpu_outfeed_test.cc"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
index 7c8d07a10b..77b3a0301f 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h
@@ -22,7 +22,7 @@ namespace xla {
namespace cpu {
// Tests that verify IR emitted by the CPU backend is as expected.
-class CpuCodegenTest : public LLVMIRGenTestBase {};
+class CpuCodegenTest : public LlvmIrGenTestBase {};
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
index ed8f375bd6..00a7aa2ad2 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc
@@ -40,7 +40,7 @@ class CpuExternalConstantsTest : public CpuCodegenTest {
HloInstruction* constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D(backing_array)));
+ LiteralUtil::CreateR2FromArray2D(backing_array)));
HloInstruction* param =
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
builder.AddInstruction(
@@ -56,7 +56,8 @@ class CpuExternalConstantsTest : public CpuCodegenTest {
TEST_F(CpuExternalConstantsTest, Basic) {
TestWithArray(/*rows=*/1024, /*cols=*/1024, R"(
-CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16
+CHECK-NOT: @constant_global_0 = external constant [1024 x [1024 x float]], align 16
+CHECK: @0 = private constant [4194304 x i8] {{.*}}, align 16
)");
}
@@ -64,8 +65,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) {
// The constant array in this test case is small enough that there is no need
// to externalize it.
TestWithArray(/*rows=*/4, /*cols=*/4, R"(
-CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8
-CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8
+CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8
+CHECK: @0 = private constant [64 x i8] {{.*}}, align 8
)");
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
index 23e7a3de4d..d98856fdbf 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -43,8 +43,8 @@ class CpuFusionTest : public HloTestBase {
TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
auto builder = HloComputation::Builder(TestName());
- auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
- auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
Shape vshape = input_literal1->shape();
auto input1 = builder.AddInstruction(
@@ -83,7 +83,7 @@ TEST_F(CpuFusionTest, FuseTwoElementwiseOps) {
TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
auto input = builder.AddInstruction(
@@ -96,8 +96,11 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil));
auto floor = builder.AddInstruction(
HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp));
- auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ vshape,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
+ {}));
builder.AddInstruction(
HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor));
@@ -114,9 +117,9 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) {
EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
EXPECT_EQ(HloOpcode::kMultiply,
fusion_instruction->fused_expression_root()->opcode());
- // There should be 7 fused instructions: 2 parameters and the fused
+ // There should be 8 fused instructions: 2 parameters and the fused
// operations.
- EXPECT_EQ(7, fusion_instruction->fused_instruction_count());
+ EXPECT_EQ(8, fusion_instruction->fused_instruction_count());
// Compile and execute the computation.
auto result = ExecuteAndTransfer(std::move(module), {});
@@ -131,7 +134,7 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
// middle.
auto module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({-1.5, -2.5, -3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({-1.5, -2.5, -3.0});
Shape vshape = input_literal->shape();
auto input = builder.AddInstruction(
@@ -163,15 +166,18 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
ShapeUtil::MakeShape(F32, {6, 1}), concatenate)),
/*init_value=*/
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{1}, add_f32));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce));
auto floor = builder.AddInstruction(
HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp));
- auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ auto two = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ cshape,
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0))),
+ {}));
builder.AddInstruction(
HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor));
@@ -188,9 +194,9 @@ TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) {
EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode());
EXPECT_EQ(HloOpcode::kMultiply,
fusion_instruction1->fused_expression_root()->opcode());
- // There should be 5 fused instructions in the root fusion instruction: 2
+ // There should be 6 fused instructions in the root fusion instruction: 2
// parameters, multiply, floor, and exp.
- EXPECT_EQ(5, fusion_instruction1->fused_instruction_count())
+ EXPECT_EQ(6, fusion_instruction1->fused_instruction_count())
<< fusion_instruction1->fused_instructions_computation()->ToString();
auto fusion_instruction2 = reduce->operand(0);
@@ -225,7 +231,7 @@ TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) {
// operand vectors. Test for this problem by counting the number of nodes in
// each fusion instruction to ensure that negate is not duplicated.
auto builder = HloComputation::Builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
Shape vshape = input_literal->shape();
auto constant = builder.AddInstruction(
@@ -286,10 +292,10 @@ TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) {
// computation. The duplication is caused by the other use of exp2 in the
// tuple.
auto builder = HloComputation::Builder(TestName());
- auto input_literal1 = Literal::CreateR1<float>({1.0, 2.0, 3.0});
- auto input_literal2 = Literal::CreateR1<float>({-2.0, -42.0, 2.0});
+ auto input_literal1 = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
+ auto input_literal2 = LiteralUtil::CreateR1<float>({-2.0, -42.0, 2.0});
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
Shape shape = constant->shape();
auto exp1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
index dd63b998e9..0d45918d09 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -47,7 +47,7 @@ class InfeedTest : public ClientLibraryTestBase {
// don't use ResetDevice since it is not implemented on CPU.
ASSERT_IS_OK(client_->TransferToInfeed(literal));
XlaBuilder builder(TestName());
- builder.Infeed(literal.shape());
+ Infeed(&builder, literal.shape());
if (ShapeUtil::IsTuple(literal.shape())) {
// TODO(b/30609564): Use ComputeAndCompareLiteral instead.
ComputeAndCompareTuple(&builder, literal, {});
@@ -58,52 +58,52 @@ class InfeedTest : public ClientLibraryTestBase {
};
TEST_F(InfeedTest, SingleInfeedR0Bool) {
- TestInfeedRoundTrip(*Literal::CreateR0<bool>(true));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR0<bool>(true));
}
TEST_F(InfeedTest, SingleInfeedR1U32) {
- TestInfeedRoundTrip(*Literal::CreateR1<uint32>({1, 2, 3}));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
TEST_F(InfeedTest, SingleInfeedR2F32) {
- TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR2F32Linspace(0.0, 1.0, 128, 64));
}
TEST_F(InfeedTest, SingleInfeedR3F32) {
TestInfeedRoundTrip(
- *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) {
const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2});
const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0});
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0minor));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0minor));
- TestInfeedRoundTrip(
- *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
- r3_dim0major));
+ TestInfeedRoundTrip(*LiteralUtil::CreateR3WithLayout(
+ {{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}},
+ r3_dim0major));
}
TEST_F(InfeedTest, SingleInfeedR4S32) {
- TestInfeedRoundTrip(*Literal::CreateR4(
+ TestInfeedRoundTrip(*LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
TEST_F(InfeedTest, SingleInfeedTuple) {
TestInfeedRoundTrip(
- *Literal::MakeTuple({Literal::CreateR1<uint32>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(false).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<uint32>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()}));
}
TEST_F(InfeedTest, SingleInfeedEmptyTuple) {
- TestInfeedRoundTrip(*Literal::MakeTuple({}));
+ TestInfeedRoundTrip(*LiteralUtil::MakeTuple({}));
}
// Tests Infeed operation used in a while loop, as in the code below. The
@@ -125,8 +125,8 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Gt(builder.ConstantR0<float>(40.0f), prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Gt(ConstantR0<float>(&builder, 40.0f), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add the reduced value of the Infeed
@@ -134,17 +134,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto infeed = builder.Infeed(infeed_shape);
- auto addend =
- builder.Reduce(infeed, builder.ConstantR0<float>(0.0f),
- CreateScalarAddComputation(F32, &builder), {0});
- builder.Add(prev, addend);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto infeed = Infeed(&builder, infeed_shape);
+ auto addend = Reduce(infeed, ConstantR0<float>(&builder, 0.0f),
+ CreateScalarAddComputation(F32, &builder), {0});
+ Add(prev, addend);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- auto init = builder.ConstantR0<float>(0.0f);
- builder.While(condition, body, init);
+ auto init = ConstantR0<float>(&builder, 0.0f);
+ While(condition, body, init);
// Build and asynchronously launch the computation.
auto computation = builder.Build().ConsumeValueOrDie();
@@ -157,13 +156,16 @@ TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) {
});
// Send 5 Infeed data of shape F32[3].
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({1, 2, 3})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({4, 5, 6})));
- ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1<float>({7, 8, 9})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({10, 11, 12})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({1, 2, 3})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({4, 5, 6})));
ASSERT_IS_OK(
- client_->TransferToInfeed(*Literal::CreateR1<float>({13, 14, 15})));
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({7, 8, 9})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({10, 11, 12})));
+ ASSERT_IS_OK(
+ client_->TransferToInfeed(*LiteralUtil::CreateR1<float>({13, 14, 15})));
delete computation_thread; // Joins the thread.
auto result_literal = client_->Transfer(*result).ConsumeValueOrDie();
@@ -207,8 +209,8 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.GetTupleElement(prev, 1);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ GetTupleElement(prev, 1);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -221,44 +223,44 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
const auto build_body = [this, &result_shape](const Shape& infeed_shape) {
XlaComputation body;
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto infeed = builder.Infeed(infeed_shape);
- auto addend = builder.Reduce(
- builder.GetTupleElement(infeed, 0), builder.ConstantR0<float>(0.0f),
- CreateScalarAddComputation(F32, &builder), {0});
- auto result = builder.Add(builder.GetTupleElement(prev, 0), addend);
- builder.Tuple({result, builder.GetTupleElement(infeed, 1)});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto infeed = Infeed(&builder, infeed_shape);
+ auto addend =
+ Reduce(GetTupleElement(infeed, 0), ConstantR0<float>(&builder, 0.0f),
+ CreateScalarAddComputation(F32, &builder), {0});
+ auto result = Add(GetTupleElement(prev, 0), addend);
+ Tuple(&builder, {result, GetTupleElement(infeed, 1)});
return builder.Build().ConsumeValueOrDie();
};
// Create the first while loop with infeed1_shape.
- auto init = builder.Tuple(
- {builder.ConstantR0<float>(0.0f), builder.ConstantR0<bool>(true)});
- auto while1 = builder.While(condition, build_body(infeed1_shape), init);
- auto result1 = builder.Tuple(
- {builder.GetTupleElement(while1, 0), builder.ConstantR0<bool>(true)});
+ auto init = Tuple(&builder, {ConstantR0<float>(&builder, 0.0f),
+ ConstantR0<bool>(&builder, true)});
+ auto while1 = While(condition, build_body(infeed1_shape), init);
+ auto result1 = Tuple(
+ &builder, {GetTupleElement(while1, 0), ConstantR0<bool>(&builder, true)});
// Create the second while loop with infeed2_shape. Note that the result from
// the first while loop is used as the initial value.
- auto while2 = builder.While(condition, build_body(infeed2_shape), result1);
- builder.GetTupleElement(while2, 0);
+ auto while2 = While(condition, build_body(infeed2_shape), result1);
+ GetTupleElement(while2, 0);
// Build the computation.
auto computation = builder.Build().ConsumeValueOrDie();
// Send the first 4 Infeed data of shape Tuple(F32[2], PRED).
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({3, 4}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({3, 4}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
// Asynchronously launch the execution on the device.
std::unique_ptr<GlobalData> result;
@@ -273,14 +275,14 @@ TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) {
// Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED).
sleep(1);
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({7, 8, 9}).get(),
- Literal::CreateR0<bool>(false).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({7, 8, 9}).get(),
+ LiteralUtil::CreateR0<bool>(false).get()})));
ASSERT_IS_OK(client_->TransferToInfeed(
- *Literal::MakeTuple({Literal::CreateR1<float>({4, 5, 6}).get(),
- Literal::CreateR0<bool>(true).get()})));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ LiteralUtil::CreateR0<bool>(true).get()})));
// Wait for the execution to be done, and transfer the result.
delete computation_thread; // Joins the thread.
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
index d6e0425c55..90b99c828e 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -38,7 +38,8 @@ while_body {
while_cond {
arg_cond = f32[2,3,2] parameter(0)
- ROOT unknown = pred[] infeed()
+ infeed = (pred[], token[]) infeed()
+ ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
@@ -49,18 +50,18 @@ ENTRY main {
{{2, 1}, {2001, 3002}, {2001, 2002}}})
const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body
- out0 = () outfeed(f32[2,3,2] const_a)
- ROOT out1 = () outfeed(f32[2,3,2] const_b)
+ out0 = token[] outfeed(f32[2,3,2] const_a)
+ ROOT out1 = token[] outfeed(f32[2,3,2] const_b)
}
)";
string filecheck_pattern = R"(
-CHECK: private constant [2 x [3 x [2 x float]]]
-CHECK-NOT: private constant [2 x [3 x [2 x float]]]
+CHECK: private constant [48 x i8]
+CHECK-NOT: private constant [48 x i8]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
@@ -78,34 +79,35 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) {
HloModule RepeatedConstants
while_body {
- arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
- ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
+ arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
+ ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
}
while_cond {
- arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0)
- ROOT unknown = pred[] infeed()
+ arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0)
+ infeed = (pred[], token[]) infeed()
+ ROOT unknown = pred[] get-tuple-element((pred[], token[]) infeed), index=0
}
ENTRY main {
param = f32[2,3,2] parameter(0)
- const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} ))
- const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body
+ const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} ))
+ const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body
- out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a)
- ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b)
+ out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a)
+ ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b)
}
)";
string filecheck_pattern = R"(
-CHECK: private constant [2 x float]
-CHECK: private constant [2 x [1 x float]]
-CHECK-NOT: private constant [2 x float]
-CHECK-NOT: private constant [2 x [1 x float]]
+CHECK: private constant [4 x i8]
+CHECK: private constant [8 x i8]
+CHECK-NOT: private constant [4 x i8]
+CHECK-NOT: private constant [8 x i8]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 3b6b0ed740..ccb61740f6 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <utility>
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
@@ -42,7 +42,7 @@ TEST_F(CpuNoAliasTest, Concat) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto param_shape = ShapeUtil::MakeShape(F32, {2, 2});
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "x"));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 879372eb13..dac416e1c7 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -32,16 +32,17 @@ ENTRY main {
{{{1, 2}, {1001, 1002}, {2001, 2002}},
{{2, 1}, {2001, 3002}, {2001, 2002}}})
- ROOT out = () outfeed(f32[2,3,2] const_a)
+ outfeed = token[] outfeed(f32[2,3,2] const_a)
+ ROOT root = () tuple()
}
)";
string filecheck_pattern = R"(
-CHECK: private constant [2 x [3 x [2 x float]]]
+CHECK: private constant [48 x i8]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
index cd1165e238..c444d15185 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc
@@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const {
void LlvmVariable::Set(llvm::Value* new_value) {
ir_builder_->CreateStore(new_value, alloca_);
}
+
+TileVariable::TileVariable(VectorSupportLibrary* vector_support,
+ std::vector<llvm::Value*> initial_value) {
+ for (llvm::Value* initial_vector_value : initial_value) {
+ storage_.emplace_back(vector_support, initial_vector_value);
+ }
+}
+
+std::vector<llvm::Value*> TileVariable::Get() const {
+ std::vector<llvm::Value*> result;
+ c_transform(storage_, std::back_inserter(result),
+ [&](VectorVariable vect_var) { return vect_var.Get(); });
+ return result;
+}
+
+void TileVariable::Set(tensorflow::gtl::ArraySlice<llvm::Value*> value) {
+ CHECK_EQ(value.size(), storage_.size());
+ for (int64 i = 0, e = value.size(); i < e; i++) {
+ storage_[i].Set(value[i]);
+ }
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index edcaec5849..49c2a4e2f4 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace cpu {
@@ -317,6 +318,21 @@ class ScalarVariable : public LlvmVariable {
Set(initial_value);
}
};
+
+// This wraps a set of alloca-backed stack variables that can, as a whole, store
+// a tile. A "tile" is a sequence of vectors that is typically used as a 2D
+// grid of scalar values (e.g. for tiled GEMMs).
+class TileVariable {
+ public:
+ TileVariable(VectorSupportLibrary* vector_support,
+ std::vector<llvm::Value*> initial_value);
+
+ std::vector<llvm::Value*> Get() const;
+ void Set(tensorflow::gtl::ArraySlice<llvm::Value*> value);
+
+ private:
+ std::vector<VectorVariable> storage_;
+};
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc
index 32b5c5d35f..e727ba49cb 100644
--- a/tensorflow/compiler/xla/service/defuser_test.cc
+++ b/tensorflow/compiler/xla/service/defuser_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/defuser.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
@@ -124,7 +124,7 @@ TEST_F(DefuserTest, NonTrivialFusionInstruction) {
auto div = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto add2 = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
@@ -162,7 +162,7 @@ TEST_F(DefuserTest, MultipleFusionInstructions) {
auto div = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3));
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto add2 = builder.AddInstruction(
HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div));
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 64678d9d74..51f16bdc94 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <type_traits>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
@@ -76,6 +76,7 @@ class DfsHloVisitorBase {
virtual Status HandleClamp(HloInstructionPtr hlo) = 0;
virtual Status HandleSelect(HloInstructionPtr hlo) = 0;
+ virtual Status HandleTupleSelect(HloInstructionPtr hlo) = 0;
virtual Status HandleMaximum(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
@@ -183,6 +184,9 @@ class DfsHloVisitorBase {
virtual Status HandleOr(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
+ virtual Status HandleXor(HloInstructionPtr hlo) {
+ return HandleElementwiseBinary(hlo);
+ }
virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo);
}
@@ -243,6 +247,8 @@ class DfsHloVisitorBase {
virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
+ virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
+
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
virtual Status FinishVisit(HloInstructionPtr root) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 240faebe62..0686ca74af 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/types.h"
@@ -79,6 +79,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleSelect(HloInstructionPtr select) override {
return DefaultAction(select);
}
+ Status HandleTupleSelect(HloInstructionPtr tuple_select) override {
+ return DefaultAction(tuple_select);
+ }
Status HandleDot(HloInstructionPtr dot) override {
return DefaultAction(dot);
}
@@ -188,6 +191,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
+ Status HandleAfterAll(HloInstructionPtr token) override {
+ return DefaultAction(token);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 9a8bab353e..bd68685153 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -456,17 +456,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
- // (x == x) && abs(x) != inf
+ // abs(x) o!= inf, this works because the comparison returns false if
+ // either operand is NaN.
auto type = operand_value->getType();
- auto equal_self =
- ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
auto abs_value = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
auto infinity = llvm::ConstantFP::getInfinity(type);
auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
- auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
return ir_builder_->CreateZExt(
- result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
return ir_builder_->CreateFNeg(operand_value);
@@ -1166,6 +1164,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
return ir_builder_->CreateOr(lhs_value, rhs_value);
+ case HloOpcode::kXor:
+ return ir_builder_->CreateXor(lhs_value, rhs_value);
// Shifting out bits >= the number of bits in the type being shifted
// produces a poison value in LLVM which is basically "deferred undefined
@@ -1222,25 +1222,32 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const Shape& operand_shape = hlo.operand(operand_no)->shape();
// If the operand is scalar, the source index is always {}.
if (ShapeUtil::IsScalar(operand_shape)) {
- return llvm_ir::IrArray::Index();
+ return llvm_ir::IrArray::Index(target_index.GetType());
}
// If no implicit broadcast is needed for this operand, returns the target
// index as the source index.
- if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
+ //
+ // `IrArray::Index` may contain a physical linear which we can propagate to
+ // our operand only if our layouts match. "only if" is a bit strong since
+ // e.g. we can still forward the linear index if the operand shape is
+ // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases
+ // are probably not worth handling here for now.
+ if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) &&
+ LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) {
return target_index;
}
// If implicit broadcast is needed, the source dimensions that are broadcast
// have index 0.
CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
- llvm_ir::IrArray::Index source_index;
+ llvm_ir::IrArray::Index source_index(target_index.GetType());
for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
source_index.push_back(target_index[i]);
} else {
CHECK_EQ(1, operand_shape.dimensions(i));
- source_index.push_back(ir_builder_->getInt64(0));
+ source_index.push_back(target_index.GetConstantWithIndexType(0));
}
}
return source_index;
@@ -1542,9 +1549,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Emit IR to read dynamic start indices from hlo->operand(1).
const HloInstruction* input_hlo = hlo->operand(0);
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
+ // Use the same index type for all tensor accesses in the same kernel.
+ llvm::Type* index_type = index.GetType();
+ llvm_ir::IrArray::Index slice_start_index(index_type, rank);
for (int64 i = 0; i < rank; ++i) {
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+ llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(hlo->operand(1))(dim_index));
@@ -1553,18 +1565,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
- start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
- index[i]->getType());
- llvm::Value* operand_dim_size = llvm::ConstantInt::get(
- start_index_value->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* output_dim_size = llvm::ConstantInt::get(
- start_index_value->getType(), hlo->shape().dimensions(i));
+ // to officially document different behavior.
+ start_index_value =
+ ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
+ llvm::Value* operand_dim_size =
+ index_typed_const(input_hlo->shape().dimensions(i));
+ llvm::Value* output_dim_size =
+ index_typed_const(hlo->shape().dimensions(i));
start_index_value = EmitIntegralMin(
ir_builder_->CreateSub(operand_dim_size, output_dim_size),
- EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
- start_index_value, /*is_signed=*/true),
+ EmitIntegralMax(index_typed_const(0), start_index_value,
+ /*is_signed=*/true),
/*is_signed=*/true);
start_index_value->setName(
@@ -1572,7 +1584,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
slice_start_index[i] = start_index_value;
}
- llvm_ir::IrArray::Index input_index(rank);
+ llvm_ir::IrArray::Index input_index(index_type, rank);
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
@@ -1596,25 +1608,29 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
const llvm_ir::ElementGenerator& indices_generator =
operand_to_generator.at(hlo->operand(1));
+ llvm::Type* index_type = index.GetType();
// This is the index into `operand` that holds the element we want to
- // generate. This index "unsafe" as in the components in here may be
- // out of bounds.
- IrArray::Index unsafe_operand_index;
-
- // First copy in the window indices to unsafe_operand_index.
- for (int64 i = 0, e = operand_shape.dimensions_size(),
- unsafe_operand_index_dim = 0;
+ // generate.
+ IrArray::Index operand_index(index_type);
+
+ // First copy in the window indices to operand_index. Also collect a mapping
+ // from operand dimension to output window dimension. Elided window dimensions
+ // map to -1.
+ std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
+ for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
- unsafe_operand_index.push_back(ir_builder_->getInt64(0));
+ operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- unsafe_operand_index.push_back(
- index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
+ int64 output_window_dim =
+ dim_numbers.output_window_dims(operand_index_dim++);
+ operand_to_output_dim[i] = output_window_dim;
+ operand_index.push_back(index[output_window_dim]);
}
}
// This is the index of the index vector in the gather_indices tensor.
- IrArray::Index gather_index_index;
+ IrArray::Index gather_index_index(index_type);
{
std::vector<llvm::Value*> gather_index_index_components;
for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) {
@@ -1628,40 +1644,54 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
}
}
- auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
- int64 dim) {
- llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc(
- index_component, ir_builder_->getInt64Ty());
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
- ir_builder_->CreateAdd(
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
- gather_dim_component_extended);
+ auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
+ llvm::Value* gather_dim_component_extended =
+ ir_builder_->CreateSExtOrTrunc(index_component, index_type);
+ int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 output_dim = operand_to_output_dim[operand_dim];
+ // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
+ // This means we set the iteration index to 0, so for the purpose of the
+ // following calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
+ int64 largest_valid_start_index =
+ operand_shape.dimensions(operand_dim) - output_dim_size;
+ CHECK_GE(largest_valid_start_index, 0);
+
+ // Clamp the gather index so that the gather region fits in the operand.
+ // gather_dim_component_extended_inbound =
+ // clamp(gather_dim_component_extended, 0, largest_valid_start_index);
+
+ // TODO(b/111078873): This is implementation defined behavior.
+
+ bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
+ auto gather_dim_component_extended_inbound = EmitIntegralMin(
+ index.GetConstantWithIndexType(largest_valid_start_index),
+ EmitIntegralMax(index.GetConstantWithIndexType(0),
+ gather_dim_component_extended,
+ /*is_signed=*/is_signed),
+ /*is_signed=*/is_signed);
+
+ operand_index[operand_dim] = ir_builder_->CreateAdd(
+ operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, 0);
+ add_to_operand_index(gather_dim_component, 0);
} else {
int64 index_vector_size =
indices_shape.dimensions(dim_numbers.index_vector_dim());
for (int64 i = 0; i < index_vector_size; i++) {
gather_index_index[dim_numbers.index_vector_dim()] =
- ir_builder_->getInt64(i);
+ index.GetConstantWithIndexType(i);
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, i);
+ add_to_operand_index(gather_dim_component, i);
}
}
-
- IrArray::Index safe_operand_index;
- for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
- safe_operand_index.push_back(ir_builder_->CreateURem(
- unsafe_operand_index[i],
- ir_builder_->getInt64(operand_shape.dimensions(i))));
- }
-
- return operand_generator(safe_operand_index);
+ return operand_generator(operand_index);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
@@ -1673,14 +1703,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
const HloInstruction* start_hlo = hlo->operand(2);
// Calculate slice start/end indices.
const int64 rank = ShapeUtil::Rank(input_hlo->shape());
- llvm_ir::IrArray::Index slice_start_index(rank);
- llvm_ir::IrArray::Index slice_limit_index(rank);
+ llvm_ir::IrArray::Index slice_start_index(index.GetType(), rank);
+ llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank);
// Slice intersection gathers (ANDs) conditions on all ranks for which
// 'input' is set to 'update'
llvm::Value* slice_intersection = ir_builder_->getTrue();
for (int64 i = 0; i < rank; ++i) {
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ llvm::Type* index_type = index[0]->getType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+ llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
operand_to_generator.at(start_hlo)(dim_index));
@@ -1689,19 +1723,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
- start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value,
- index[i]->getType());
- llvm::Value* input_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), input_hlo->shape().dimensions(i));
- llvm::Value* update_dim_size = llvm::ConstantInt::get(
- index[i]->getType(), update_hlo->shape().dimensions(i));
-
- start_index_value = EmitIntegralMin(
- ir_builder_->CreateSub(input_dim_size, update_dim_size),
- EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0),
- start_index_value, /*is_signed=*/true),
- /*is_signed=*/true);
+ // to officially document different behavior.
+ start_index_value =
+ ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
+ llvm::Value* input_dim_size =
+ index_typed_const(input_hlo->shape().dimensions(i));
+ llvm::Value* update_dim_size =
+ index_typed_const(update_hlo->shape().dimensions(i));
+
+ start_index_value =
+ EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size),
+ EmitIntegralMax(index_typed_const(0), start_index_value,
+ /*is_signed=*/true),
+ /*is_signed=*/true);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
@@ -1731,7 +1765,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Handle true BB (return data from 'update')
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
// Compute update index for intersection case.
- llvm_ir::IrArray::Index update_index(rank);
+ llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
}
@@ -1799,7 +1833,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
- operand_to_generator.at(hlo->operand(1))({}));
+ operand_to_generator.at(hlo->operand(1))(
+ IrArray::Index(index.GetType())));
ir_builder_->CreateStore(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
@@ -1826,10 +1861,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
- std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
- IrName(hlo, "inner"), ir_builder_->getInt64(0),
- ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1),
- ir_builder_);
+ llvm::Type* index_type = dot_result_index[0]->getType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+
+ std::unique_ptr<llvm_ir::ForLoop> inner_loop =
+ llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0),
+ index_typed_const(contracted_dim_size),
+ index_typed_const(1), ir_builder_);
SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_);
PrimitiveType primitive_type = hlo->shape().element_type();
@@ -1848,7 +1888,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
// Given an output index [a,b,c,d,e] in the result, we compute:
// sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
- IrArray::Index lhs_index, rhs_index;
+ IrArray::Index lhs_index(index_type), rhs_index(index_type);
for (int64 i = 0; i < lhs_dims - 1; i++) {
lhs_index.push_back(dot_result_index[i]);
@@ -1947,6 +1987,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kMultiply:
case HloOpcode::kNe:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kShiftLeft:
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index b43dc0c65d..addb016b04 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -33,7 +33,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text, config));
+ ParseHloString(hlo_text, config));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt));
}
};
@@ -57,8 +57,8 @@ ENTRY main {
}
)";
- std::unique_ptr<Literal> lhs = Literal::CreateR3<int32>({{{1}, {2}}});
- std::unique_ptr<Literal> rhs = Literal::CreateR3<int32>({{{3}, {4}}});
+ std::unique_ptr<Literal> lhs = LiteralUtil::CreateR3<int32>({{{1}, {2}}});
+ std::unique_ptr<Literal> rhs = LiteralUtil::CreateR3<int32>({{{3}, {4}}});
RunTest(hlo_text, {lhs.get(), rhs.get()});
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 8119478ce9..7cf2746947 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -116,6 +116,11 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
if (profile->compute_time_ns() == 0) {
profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
}
+
+ const int64 executable_size_in_bytes = SizeInBytes();
+ if (executable_size_in_bytes != 0) {
+ profile->set_executable_size_in_bytes(executable_size_in_bytes);
+ }
}
if (profile_ptr != nullptr) {
@@ -129,19 +134,7 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
return return_value;
}
-Status Executable::DumpSessionModule() {
- TF_RET_CHECK(dumping());
- const string& directory_path =
- module_config().debug_options().xla_dump_executions_to();
- VersionedComputationHandle versioned_handle = entry_computation_handle();
- // This filename does not include the version number because the computation
- // is only ever executed at one version.
- string filename = tensorflow::strings::Printf(
- "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(),
- session_module_->entry().name().c_str(), ++execution_count_);
- return Executable::DumpToDirectory(directory_path, filename,
- *session_module_);
-}
+int64 Executable::SizeInBytes() { return -1; }
Status Executable::DumpHloSnapshot() {
TF_RET_CHECK(dumping_snapshot());
@@ -158,26 +151,6 @@ Status Executable::DumpHloSnapshot() {
/* static */ Status Executable::DumpToDirectory(
const string& directory_path, string filename,
- const SessionModule& session_module) {
- tensorflow::Env* env = tensorflow::Env::Default();
- if (!env->IsDirectory(directory_path).ok()) {
- // NB! CreateDir does not work reliably with multiple XLA threads -- two
- // threads can race to observe the absence of the dump directory and
- // simultaneously try to create it, causing the "losing" thread to get a
- // "directory already exists" error.
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path));
- }
- filename = SanitizeFileName(std::move(filename));
- string file_path = tensorflow::io::JoinPath(directory_path, filename);
- string result;
- TF_RET_CHECK(
- tensorflow::SerializeToStringDeterministic(session_module, &result));
- return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path,
- result);
-}
-
-/* static */ Status Executable::DumpToDirectory(
- const string& directory_path, string filename,
const HloSnapshot& hlo_session) {
tensorflow::Env* env = tensorflow::Env::Default();
if (!env->IsDirectory(directory_path).ok()) {
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 4f0466c544..98eaeee30a 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -27,9 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/service_executable_run_options.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -90,8 +88,7 @@ class Executable {
// called explicitly for other (async, for example) variants after the stream
// has completed.
virtual Status PopulateExecutionProfile(
- HloExecutionProfile* hlo_execution_profile,
- se::StreamExecutor* executor) {
+ HloExecutionProfile* hlo_execution_profile, se::Stream* stream) {
return Status::OK();
}
@@ -132,25 +129,15 @@ class Executable {
const HloModuleConfig& module_config() const { return hlo_module_->config(); }
- // Returns the versioned computation handle of the computation computed by
- // this executable.
- const VersionedComputationHandle& entry_computation_handle() const {
- return hlo_module_->entry_computation_handle();
- }
-
// The shape (including layout) that results from this execution. This is the
// shape of the DeviceMemoryBase result value in ExecuteOnStream above.
- const Shape& host_result_shape() const {
- return hlo_module_->config().host_entry_computation_layout().result_shape();
+ const Shape& result_shape() const {
+ return hlo_module_->config().entry_computation_layout().result_shape();
}
- // TODO(b/74197823): Delete the session module dumping helpers.
- void set_session_module(std::unique_ptr<xla::SessionModule> session_module) {
- session_module_ = std::move(session_module);
- }
- bool dumping() const { return session_module_ != nullptr; }
- SessionModule* session_module() const { return session_module_.get(); }
- Status DumpSessionModule();
+ // Returns the size of the executable in bytes. Returns -1 by default if the
+ // method is not overridden to support this kind of query.
+ virtual int64 SizeInBytes();
// Dumping helpers.
void set_hlo_snapshot(std::unique_ptr<xla::HloSnapshot> hlo_snapshot) {
@@ -160,10 +147,6 @@ class Executable {
HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); }
Status DumpHloSnapshot();
- // Dump session_module to directory_path/filename.
- static Status DumpToDirectory(const string& directory_path, string filename,
- const SessionModule& session_module);
-
// Dump hlo snapshot to directory_path/filename.
static Status DumpToDirectory(const string& directory_path, string filename,
const HloSnapshot& hlo_session);
@@ -179,9 +162,6 @@ class Executable {
// around.
const std::unique_ptr<const HloModule> hlo_module_;
- // SessionModule this was compiled from. Null if not dumping executions.
- std::unique_ptr<SessionModule> session_module_;
-
// HloSnapshot this was compiled from. Null if not dumping executions.
std::unique_ptr<HloSnapshot> hlo_snapshot_;
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index d3854b40de..8f6608241e 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase {
HloInstruction* param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, kScalarShape, "param0"));
HloInstruction* zero = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero));
return builder.Build();
@@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(PRED, {}), "param0"));
HloInstruction* false_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
HloOpcode::kEq, param0, false_constant));
@@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
{
HloComputation::Builder builder(TestName() + ".entry");
HloInstruction* false_constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
builder.AddInstruction(HloInstruction::CreateWhile(
ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation,
false_constant));
@@ -232,11 +232,11 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
// computation in the true and false branch.
HloComputation::Builder builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
builder.AddInstruction(HloInstruction::CreateConditional(
kScalarShape, pred, constant1, sub_computation, constant2,
sub_computation));
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md
index f0f3dd7785..f0f3dd7785 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 2d3e4b1fcd..e3a42d0d06 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <utility>
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -113,7 +114,7 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
const Shape& index_shape = index_vector->shape();
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateFromDimensions(index_shape.element_type(), {1})));
+ LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
// We extract out individual components from the smaller index and concatenate
// them (interspersing zeros as needed) into the larger index.
@@ -300,7 +301,7 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* gather_instr) {
- CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape()));
+ CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape()));
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
@@ -369,7 +370,7 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
return inst->opcode() == HloOpcode::kGather &&
// Avoid expanding gather ops that produce zero sized tensors,
// instead punt these to ZeroSizedHloElimination.
- !ShapeUtil::HasZeroElements(inst->shape());
+ !ShapeUtil::IsZeroElementArray(inst->shape());
};
std::vector<HloInstruction*> gather_instrs;
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 1c72ca0665..020ffcd106 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gather_expander.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -36,7 +36,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
Status status = GatherExpander{}.Run(module.get()).status();
EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
@@ -63,7 +63,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
ASSERT_TRUE(changed);
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 5ee67ccb4a..33730049c4 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -43,7 +43,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const {
}
Status GenericTransferManager::WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) {
TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));
@@ -52,12 +52,24 @@ Status GenericTransferManager::WriteSingleTupleIndexTable(
for (const se::DeviceMemoryBase& element : elements) {
element_pointers.push_back(element.opaque());
}
- return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
- element_pointers.data(), region);
+ TF_RETURN_IF_ERROR(TransferBufferToDevice(
+ stream, GetByteSizeRequirement(shape), element_pointers.data(), region));
+ // Ensure the buffer is transferred before we destroy element_pointers.
+ return stream->BlockHostUntilDone();
+}
+
+void GenericTransferManager::TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ Status status = stream->BlockHostUntilDone();
+ if (!status.ok()) {
+ return done(status);
+ }
+ done(TransferLiteralFromDeviceInternal(stream->parent(), device_buffer));
}
StatusOr<std::unique_ptr<Literal>>
-GenericTransferManager::TransferLiteralFromDevice(
+GenericTransferManager::TransferLiteralFromDeviceInternal(
se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
VLOG(2) << "transferring literal from device ordinal "
<< executor->device_ordinal() << "; device buffer: " << device_buffer;
@@ -74,9 +86,8 @@ GenericTransferManager::TransferLiteralFromDevice(
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
- if (!ShapeUtil::IsTuple(subshape)) {
- TF_RETURN_IF_ERROR(TransferBufferFromDevice(
- executor,
+ if (ShapeUtil::IsArray(subshape)) {
+ TF_RETURN_IF_ERROR(executor->SynchronousMemcpyD2H(
/*source=*/device_buffer.buffer(index),
/*size=*/GetByteSizeRequirement(subshape),
/*destination=*/
@@ -88,8 +99,8 @@ GenericTransferManager::TransferLiteralFromDevice(
return std::move(literal);
}
-Status GenericTransferManager::TransferLiteralToDevice(
- se::StreamExecutor* executor, const LiteralSlice& literal,
+Status GenericTransferManager::TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
const ShapedBuffer& device_buffer) {
const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: "
@@ -103,9 +114,10 @@ Status GenericTransferManager::TransferLiteralToDevice(
TF_RET_CHECK(
ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
- TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+ TF_RET_CHECK(stream->parent()->device_ordinal() ==
+ device_buffer.device_ordinal());
- TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer));
+ TF_RETURN_IF_ERROR(WriteTupleIndexTables(stream, device_buffer));
return ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
@@ -121,16 +133,21 @@ Status GenericTransferManager::TransferLiteralToDevice(
if (LayoutUtil::Equal(device_subshape.layout(),
subliteral.shape().layout())) {
source = subliteral.untyped_data();
+ return TransferBufferToDevice(
+ stream,
+ /*size=*/GetByteSizeRequirement(device_subshape), source,
+ &device_memory);
} else {
// Relayout data before transferring.
relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
/*shape_index=*/{});
source = relayed_out_literal->untyped_data();
+ TF_RETURN_IF_ERROR(TransferBufferToDevice(
+ stream,
+ /*size=*/GetByteSizeRequirement(device_subshape), source,
+ &device_memory));
+ return stream->BlockHostUntilDone();
}
- return TransferBufferToDevice(
- executor,
- /*size=*/GetByteSizeRequirement(device_subshape), source,
- &device_memory);
}
return Status::OK();
});
@@ -149,8 +166,7 @@ Status GenericTransferManager::TransferBufferToInfeed(
Status GenericTransferManager::TransferLiteralFromOutfeed(
se::StreamExecutor* executor, const Shape& literal_shape,
Literal* literal) {
- return Unimplemented(
- "Outfeed is not supported on this platform (b/30467474)");
+ return Unimplemented("Generic transfer from Outfeed");
}
Status GenericTransferManager::ResetDevices(
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 3da9570ef7..d216fe7d29 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -41,12 +41,13 @@ class GenericTransferManager : public TransferManager {
se::Platform::Id PlatformId() const override;
- StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override;
+ void TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) override;
- Status TransferLiteralToDevice(se::StreamExecutor* executor,
- const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) override;
+ Status TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
@@ -64,11 +65,14 @@ class GenericTransferManager : public TransferManager {
const void* source) override;
Status WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) override;
private:
+ StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDeviceInternal(
+ se::StreamExecutor* executor, const ShapedBuffer& device_buffer);
+
// The platform this transfer manager targets.
const se::Platform::Id platform_id_;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index fe597bfb45..9fca3a51c8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -150,7 +150,7 @@ cc_library(
":parallel_loop_emitter",
":partition_assignment",
":while_transformer",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -164,6 +164,8 @@ cc_library(
"//tensorflow/compiler/xla/service:name_uniquer",
"//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
+ "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
@@ -198,7 +200,7 @@ cc_library(
srcs = ["elemental_ir_emitter.cc"],
hdrs = ["elemental_ir_emitter.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -237,6 +239,20 @@ cc_library(
)
cc_library(
+ name = "hlo_execution_profiler",
+ srcs = ["hlo_execution_profiler.cc"],
+ hdrs = ["hlo_execution_profiler.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:pool",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ptr_util",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
name = "gpu_executable",
srcs = [
"conditional_thunk.cc",
@@ -250,6 +266,7 @@ cc_library(
"infeed_thunk.cc",
"kernel_thunk.cc",
"memset_thunk.cc",
+ "outfeed_thunk.cc",
"sequential_thunk.cc",
"thunk_schedule.cc",
"tuple_thunk.cc",
@@ -267,6 +284,7 @@ cc_library(
"infeed_thunk.h",
"kernel_thunk.h",
"memset_thunk.h",
+ "outfeed_thunk.h",
"sequential_thunk.h",
"thunk.h",
"thunk_schedule.h",
@@ -274,14 +292,16 @@ cc_library(
"while_thunk.h",
],
deps = [
- ":backend_configs",
":buffer_allocations",
":cudnn_convolution_runner",
+ ":hlo_execution_profiler",
":infeed_manager",
":ir_emission_utils",
+ ":outfeed_manager",
":partition_assignment",
":stream_assignment",
"//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -335,6 +355,7 @@ cc_library(
":cudnn_convolution_runner",
":gpu_executable",
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -366,7 +387,7 @@ cc_library(
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -416,9 +437,38 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
+cc_library(
+ name = "multi_output_fusion",
+ srcs = ["multi_output_fusion.cc"],
+ hdrs = ["multi_output_fusion.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:multi_output_fusion",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "multi_output_fusion_test",
+ srcs = ["multi_output_fusion_test.cc"],
+ deps = [
+ ":multi_output_fusion",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/core:lib",
],
)
@@ -460,9 +510,9 @@ tf_cc_test(
":instruction_fusion",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -472,6 +522,7 @@ cc_library(
hdrs = ["pad_insertion.h"],
deps = [
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -488,6 +539,8 @@ cc_library(
hdrs = ["gpu_transfer_manager.h"],
deps = [
":gpu_compiler",
+ ":outfeed_manager",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -522,6 +575,7 @@ cc_library(
":instruction_fusion",
":ir_emission_utils",
":ir_emitter",
+ ":multi_output_fusion",
":pad_insertion",
":partition_assignment",
":stream_assignment",
@@ -539,7 +593,6 @@ cc_library(
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
- "//tensorflow/compiler/xla/service:gather_expander",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_constant_folding",
"//tensorflow/compiler/xla/service:hlo_cse",
@@ -569,7 +622,6 @@ cc_library(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
"@llvm//:core",
- "@llvm//:support",
],
alwayslink = True, # Contains compiler registration
)
@@ -580,6 +632,7 @@ cc_library(
hdrs = ["cudnn_batchnorm_rewriter.h"],
deps = [
":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -599,6 +652,19 @@ cc_library(
)
cc_library(
+ name = "outfeed_manager",
+ srcs = ["outfeed_manager.cc"],
+ hdrs = ["outfeed_manager.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "gpu_layout_assignment",
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
@@ -672,7 +738,7 @@ cc_library(
srcs = ["while_transformer.cc"],
hdrs = ["while_transformer.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
index 77a48965e0..5780e0af40 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,8 +33,11 @@ ConditionalThunk::ConditionalThunk(
predicate_buffer_index_(predicate_buffer_index),
true_operand_buffer_index_(true_operand_buffer_index),
false_operand_buffer_index_(false_operand_buffer_index),
- true_thunk_(std::move(true_thunk_sequence), hlo),
- false_thunk_(std::move(false_thunk_sequence), hlo) {}
+ // Pass nullptr as the HloInstruction* to the true_thunk_ and false_thunk_
+ // constructors because these SequentialThunks are logically "part of"
+ // this ConditionalThunk, and shouldn't be profiled separately from it.
+ true_thunk_(std::move(true_thunk_sequence), nullptr),
+ false_thunk_(std::move(false_thunk_sequence), nullptr) {}
Status ConditionalThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -43,7 +47,9 @@ Status ConditionalThunk::Initialize(const GpuExecutable& executable,
}
Status ConditionalThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
// Copy the predicate value from device.
bool predicate;
se::DeviceMemoryBase predicate_address =
@@ -59,10 +65,15 @@ Status ConditionalThunk::ExecuteOnStream(
// Execute the true or the false computation depending on the value of the
// predicate.
if (predicate) {
- TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(
+ true_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->true_computation());
} else {
+ profiler->StartHloComputation();
TF_RETURN_IF_ERROR(
- false_thunk_.ExecuteOnStream(buffer_allocations, stream));
+ false_thunk_.ExecuteOnStream(buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->false_computation());
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
index ee03865d17..aef24342c9 100644
--- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -50,7 +51,8 @@ class ConditionalThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice predicate_buffer_index_;
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index f088112412..7833a4077e 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -55,7 +56,8 @@ ConvolutionThunk::ConvolutionThunk(
tensor_ops_enabled_(tensor_ops_enabled) {}
Status ConvolutionThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase input_data =
buffer_allocations.GetDeviceAddress(input_buffer_);
se::DeviceMemoryBase filter_data =
@@ -68,6 +70,7 @@ Status ConvolutionThunk::ExecuteOnStream(
se::dnn::AlgorithmConfig algorithm_config(
se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
filter_data, output_data, scratch, window_, dim_nums_, algorithm_config,
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index 6d845025b1..d76ca6698d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -66,7 +67,8 @@ class ConvolutionThunk : public Thunk {
// Does the convolution for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
class ScratchAllocator;
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
index ee38c0318a..92e03f94c1 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
@@ -30,9 +31,11 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk(
mem_size_(mem_size) {}
Status HostToDeviceCopyThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
return Status::OK();
}
@@ -47,11 +50,13 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk(
mem_size_(mem_size) {}
Status DeviceToDeviceCopyThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase destination_data =
buffer_allocations.GetDeviceAddress(destination_buffer_);
se::DeviceMemoryBase source_data =
buffer_allocations.GetDeviceAddress(source_buffer_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemcpy(&destination_data, source_data, mem_size_);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
index 8b128386f6..91564b520a 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -40,7 +41,8 @@ class HostToDeviceCopyThunk : public Thunk {
HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const void* source_address_;
@@ -63,7 +65,8 @@ class DeviceToDeviceCopyThunk : public Thunk {
DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice source_buffer_;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
index db6924c742..6028950652 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -66,11 +67,12 @@ Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -101,11 +103,12 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
@@ -126,12 +129,17 @@ Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) {
HloInstruction* variance_plus_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev,
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-2)))));
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ inverse_stddev->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(-2))),
+ {}))));
HloInstruction* variance =
computation_->AddInstruction(HloInstruction::CreateBinary(
variance_plus_epsilon->shape(), HloOpcode::kSubtract,
- variance_plus_epsilon, epsilon));
+ variance_plus_epsilon,
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ variance_plus_epsilon->shape(), epsilon, {}))));
// Repackage the results.
std::unique_ptr<HloInstruction> new_tuple = HloInstruction::CreateTuple({
@@ -164,23 +172,29 @@ Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) {
return Status::OK();
}
- HloInstruction* epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
+ HloInstruction* epsilon =
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0(batch_norm->epsilon())));
HloInstruction* feature_index =
computation_->AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0(batch_norm->feature_index())));
+ LiteralUtil::CreateR0(batch_norm->feature_index())));
// The cudnn libcall expects its input to be rsqrt(variance + epsilon), but
// the batchnorm HLO takes plain variance as input. Fix it up.
HloInstruction* var_plus_epsilon =
computation_->AddInstruction(HloInstruction::CreateBinary(
batch_norm->operand(3)->shape(), HloOpcode::kAdd,
- batch_norm->mutable_operand(3), epsilon));
+ batch_norm->mutable_operand(3),
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ batch_norm->operand(3)->shape(), epsilon, {}))));
HloInstruction* inverse_stddev =
computation_->AddInstruction(HloInstruction::CreateBinary(
var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon,
- computation_->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-.5)))));
+ computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ var_plus_epsilon->shape(),
+ computation_->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<float>(-.5))),
+ {}))));
std::vector<HloInstruction*> operands(batch_norm->operands().begin(),
batch_norm->operands().end());
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
index 68099fd638..7b172812c3 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@@ -99,13 +100,15 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk(
}
Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
std::tie(operand_desc, scale_offset_desc) =
MakeDescriptors(hlo_instruction()->shape(), feature_index_);
se::DeviceMemory<float> output(buffer_allocations.GetDeviceAddress(output_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -123,6 +126,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream(
/*is_training=*/false, //
/*var_to_inv_var=*/nullptr, //
/*inv_var_to_var=*/nullptr);
+
if (!stream->ok()) {
return InternalError("BatchNormalizationForward call failed.");
}
@@ -158,7 +162,8 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk(
}
Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
// The BatchNormTraining HLO outputs a tuple of three elements: output data,
@@ -175,6 +180,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(output_inv_stddev_));
se::DeviceMemory<float> null_device_ptr(nullptr);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationForward(
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(operand_)),
se::DeviceMemory<float>(buffer_allocations.GetDeviceAddress(scale_)),
@@ -240,7 +246,8 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk(
}
Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
dnn::BatchDescriptor operand_desc;
dnn::BatchDescriptor scale_offset_desc;
@@ -257,6 +264,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream(
se::DeviceMemory<float> output_grad_offset(
buffer_allocations.GetDeviceAddress(output_grad_offset_));
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenBatchNormalizationBackward(
se::DeviceMemory<float>(
buffer_allocations.GetDeviceAddress(grad_output_)),
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
index 874f85a863..d2143b3952 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -60,7 +61,8 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk {
const CudnnBatchNormForwardInferenceThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
@@ -90,7 +92,8 @@ class CudnnBatchNormForwardTrainingThunk : public Thunk {
const CudnnBatchNormForwardTrainingThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
@@ -123,7 +126,8 @@ class CudnnBatchNormBackwardThunk : public Thunk {
delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
BufferAllocation::Slice operand_;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 3dc98c4c93..5a63e65208 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -80,8 +81,7 @@ bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
const ConvolutionDimensionNumbers& dnums,
se::StreamExecutor* stream_exec) {
// Skip this check for cudnn7 and newer.
- auto version =
- stream_exec->AsDnn()->GetVersion();
+ auto version = stream_exec->AsDnn()->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
@@ -338,8 +338,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
computation->AddInstruction(HloInstruction::CreateTuple(
{computation->AddInstruction(HloInstruction::CreateGetTupleElement(
new_call_shape.tuple_shapes(0), new_call, 0)),
- computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<uint8>({})))}));
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<uint8>({})))}));
TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple));
return true;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index e0c73aa73a..905b5ee876 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
}
// CuDNN does not accept zero-element arguments
- if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
- ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index e5e2a0478a..e594cec2f8 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -53,11 +53,17 @@ using llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrAppend;
+namespace {
// Returns whether operand is a floating-point literal with the given value.
bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
- return operand->opcode() == HloOpcode::kConstant &&
- operand->literal().IsAllFloat(value);
+ if (operand->opcode() == HloOpcode::kConstant &&
+ operand->literal().IsAllFloat(value)) {
+ return true;
+ }
+ return operand->opcode() == HloOpcode::kBroadcast &&
+ IsFPLiteralWithValue(operand->operand(0), value);
}
+} // namespace
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
@@ -370,11 +376,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
"reduce_window_accum_ptr", ir_builder_);
{
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
- operand_to_generator.at(hlo->operand(1))({}));
+ operand_to_generator.at(hlo->operand(1))(
+ IrArray::Index(index.GetType())));
ir_builder_->CreateStore(init_value, accum_ptr);
}
- llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_);
+ llvm::Type* index_type = index.GetType();
+ auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
+ return index.GetConstantWithIndexType(c);
+ };
+
+ llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
@@ -385,14 +397,14 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
- IrArray::Index input_index(index.size());
+ IrArray::Index input_index(index_type, index.size());
llvm::Value* in_bounds = ir_builder_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
- index[i], ir_builder_->getInt64(window.dimensions(i).stride()));
+ index[i], index_typed_const(window.dimensions(i).stride()));
input_index[i] = ir_builder_->CreateNSWSub(
ir_builder_->CreateNSWAdd(stridden_index, window_index[i]),
- ir_builder_->getInt64(window.dimensions(i).padding_low()));
+ index_typed_const(window.dimensions(i).padding_low()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
@@ -403,7 +415,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
in_bounds,
ir_builder_->CreateICmpULT(
input_index[i],
- ir_builder_->getInt64(operand->shape().dimensions(i))));
+ index_typed_const(operand->shape().dimensions(i))));
}
llvm_ir::LlvmIfData if_data =
@@ -429,11 +441,13 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
llvm::Value* accum_ptr =
ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
hlo->shape().element_type(), module_));
+ llvm::Type* index_type = output_index.GetType();
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
- operand_to_generator.at(hlo->operand(1))({}));
+ operand_to_generator.at(hlo->operand(1))(
+ IrArray::Index(index_type)));
ir_builder()->CreateStore(init_value, accum_ptr);
- llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_);
+ llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type);
IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions(
operand->shape(), hlo->dimensions(), "reduction_dim");
if (!ShapeUtil::IsScalar(hlo->shape())) {
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
index e14ee6918b..0cdddf8bcf 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <string>
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -107,7 +108,8 @@ FftThunk::FftThunk(FftType fft_type,
output_shape_(output_shape) {}
Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(3) << "FFT type: " << FftTypeToString(fft_type_);
VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_);
VLOG(3) << "Output shape: "
@@ -116,6 +118,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator());
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (fft_plan_ == nullptr) {
const int64 fft_rank = fft_length_.size();
CHECK_LE(fft_rank, 3);
diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
index b0a22564f3..8c53be5077 100644
--- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/types.h"
@@ -72,7 +73,8 @@ class FftThunk : public Thunk {
// Does the FFT for the thunk on "stream".
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const se::fft::Type fft_type_;
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
index b36539e0cb..b3a3c5dcb4 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -27,8 +28,11 @@ ForThunk::ForThunk(const int64 loop_limit,
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
loop_limit_(loop_limit),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ // Pass nullptr as the HloInstruction* to the body_thunk_sequence_
+ // constructor because this SequentialThunk is logically "part of"
+ // this ForThunk, and shouldn't be profiled separately from it.
+ std::move(*body_thunk_sequence), nullptr)) {}
Status ForThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -37,11 +41,15 @@ Status ForThunk::Initialize(const GpuExecutable& executable,
}
Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (int64 i = 0; i < loop_limit_; ++i) {
+ profiler->StartHloComputation();
// Invoke loop body thunk sequence.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
index 41ddfe0ceb..c2d39071b2 100644
--- a/tensorflow/compiler/xla/service/gpu/for_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -39,7 +40,8 @@ class ForThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const int64 loop_limit_;
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index 2217776c7d..b22bb1d39b 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace gpu {
@@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {};
// Tuple
//
TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule MergeSharedFusionInstruction
comp.3 {
@@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 {
//
// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio.
TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule FlopsToBytesRatioThresholdExceeded
comp.2 {
@@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 {
// is merged into Fusion0 and Fusion1) would exceed the bytes transferred
// threshold.
TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule BytesTransferredThresholdExeceeded
comp.2 {
@@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 {
// Fusion2 is reduced for this test which makes the merge operation into its
// operand below the bytes transferred threshold.
TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule BytesTransferredThresholdNotExeceeded
comp.2 {
@@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 {
// Check that we're willing to merge f1_computation into f2_computation, even
// though f2 is an input fusion node.
TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule m
f1_computation {
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 79fca43d02..dbc7754e25 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -252,7 +252,8 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
alpha_(alpha) {}
Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(2) << "Executing a GemmThunk";
se::DeviceMemoryBase lhs_data =
@@ -352,6 +353,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
alpha_, stream);
};
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
bool launch_ok;
if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
launch_ok = launch(
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
index 7a4830d64e..939c7f85e3 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -48,7 +49,8 @@ class GemmThunk : public Thunk {
// Does the gemm operation for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
// Returns true if we'll perform autotuning if run on the given stream. If
// so, we want the GPU to be quiescent during autotuning, so as not to
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b857219807..e1da8d940c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
-#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
@@ -52,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
@@ -159,16 +159,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
pass.AddPass<CudnnBatchNormRewriter>();
}
- // TODO(kramerb): Remove use_fusion once instruction fusion can create
- // multi-output fusions from the unfused expander output.
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
- /*rewrite_grad_op=*/true,
- /*use_fusion=*/true);
-
- // Rewrite gather ops into smaller ones.
- pass.AddPass<GatherExpander>();
+ /*rewrite_grad_op=*/true);
// BatchNormExpander can create zero-sized ops, so zero-sized HLO
// elimination has to come after that pass.
@@ -211,7 +205,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassPipeline pipeline("layout_assignment");
pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+ hlo_module->mutable_entry_computation_layout(), stream_exec);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -261,6 +255,9 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
+ fusion.AddPass<GpuMultiOutputFusion>();
+ fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
+ /*only_fusion_computations=*/true);
TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
HloPassPipeline reduce_pipeline("reduce-precision");
@@ -555,8 +552,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
&ir_emitter_context);
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
- TF_RETURN_IF_ERROR(
- entry_computation->root_instruction()->Accept(&ir_emitter));
+ TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter));
}
if (user_pre_optimization_hook_) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index c5ccdd4a7d..fbc1303085 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -52,60 +52,20 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
HloDataflowAnalysis::Run(*module));
// Make sure all operands of a library call are in memory instead of constants
- // in IR.
- for (HloInstruction* hlo :
- module->entry_computation()->MakeInstructionPostOrder()) {
- // Inserts a copy of hlo->operand(n) if it's a constant.
- auto copy_operand_if_constant = [&](int64 n) -> Status {
- HloInstruction* operand = hlo->mutable_operand(n);
- TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
- const auto& values = dataflow->GetValueSet(operand).values();
- if (std::any_of(values.begin(), values.end(), [](const HloValue* value) {
- return value->defining_instruction()->opcode() ==
- HloOpcode::kConstant;
- })) {
- TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand));
- TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy));
- changed = true;
- }
- return Status::OK();
- };
-
- if (IsCustomCallToDnnBatchNorm(*hlo)) {
- // The epsilon and feature_index operands to a CUDNN batchnorm op don't
- // need to be materialized in memory -- in fact, they must be constants.
- // These are the last two operands of all three batchnorm ops.
- for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
- } else if (ImplementedAsLibraryCall(*hlo) ||
- hlo->opcode() == HloOpcode::kCrossReplicaSum) {
- // For all other library calls and cross-replica-sum, materialize all the
- // operands into memory. (Cross-replica-sum gets its constant args
- // materialized even if it's not implemented as a libcall to simplify the
- // implementation. It's slower, but we can constant fold away constant
- // args *anyway*, so we just need to make it work.)
- for (int64 i = 0; i < hlo->operand_count(); ++i) {
- TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
- }
- }
- }
-
- // Init values of while and conditional nodes cannot be constants. Insert
- // copies for any constants found at the operands of these nodes.
+ // in IR. Also, init values of while and conditional nodes cannot be
+ // constants. Insert copies for any constants found at the operands of these
+ // nodes.
tensorflow::gtl::FlatSet<HloInstruction*> inserted_copies;
for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kWhile &&
- instruction->opcode() != HloOpcode::kConditional) {
- continue;
- }
- for (auto operand : instruction->operands()) {
+ for (HloInstruction* hlo : computation->instructions()) {
+ // Inserts a copy of hlo->operand(n) if it's a constant.
+ auto copy_operand_if_constant = [&](int64 n) -> Status {
+ HloInstruction* operand = hlo->mutable_operand(n);
// Skip the operands that have already been replaced with a copy in a
// previous iteration (which is possible when a constant is used as an
// operand in multiple places).
if (ContainsKey(inserted_copies, operand)) {
- continue;
+ return Status::OK();
}
for (auto& pair : dataflow->GetInstructionValueSet(operand)) {
const HloValueSet& value_set = pair.second;
@@ -121,6 +81,47 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
}
+ return Status::OK();
+ };
+
+ if (IsCustomCallToDnnBatchNorm(*hlo)) {
+ // The epsilon and feature_index operands to a CUDNN batchnorm op don't
+ // need to be materialized in memory -- in fact, they must be constants.
+ // These are the last two operands of all three batchnorm ops.
+ for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
+ TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
+ }
+ } else if (ImplementedAsLibraryCall(*hlo) ||
+ hlo->opcode() == HloOpcode::kCrossReplicaSum ||
+ hlo->opcode() == HloOpcode::kWhile ||
+ hlo->opcode() == HloOpcode::kConditional) {
+ // For all other library calls, cross-replica-sum, while and conditional
+ // ops materialize all the operands into memory. (Cross-replica-sum
+ // gets its constant args materialized even if it's not implemented as a
+ // libcall to simplify the implementation. It's slower, but we can
+ // constant fold away constant args *anyway*, so we just need to make it
+ // work.)
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
+ }
+ }
+ }
+ }
+
+ if (changed) {
+ // Check the assumption that the epsilon and feature_index constants of the
+ // CUDNN batchnorm op are not shared with other ops where we would replace
+ // them with a copy. These custom op calls are generated with the
+ // CudnnBatchNormRewriter, so this would only happen if HloCSE merges them.
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* hlo : computation->instructions()) {
+ if (!IsCustomCallToDnnBatchNorm(*hlo)) {
+ continue;
+ }
+ for (int64 i = hlo->operand_count() - 2; i < hlo->operand_count();
+ ++i) {
+ CHECK_EQ(hlo->operand(i)->opcode(), HloOpcode::kConstant);
+ }
}
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 25d8f720ea..0cad2958c7 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -41,77 +41,6 @@ namespace {
using tensorflow::tracing::ScopedAnnotation;
-// A helper class for profiling HLO in the course of GPU program execution.
-// All of the profiling is guarded internally, to avoid the caller needing to
-// have lots of conditionals sprinkled around.
-class HloExecutionProfiler {
- public:
- // If profiling is enabled, start an execution timer running.
- explicit HloExecutionProfiler(
- bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
- const HloComputation* computation)
- : do_profile_(do_profile),
- profile_(profile),
- stream_(stream),
- sub_streams_(sub_streams),
- computation_(computation) {
- if (do_profile_) {
- clock_rate_ghz_ =
- stream->parent()->GetDeviceDescription().clock_rate_ghz();
- execution_timer_.reset(new se::Timer(stream->parent()));
- per_op_timer_.reset(new se::Timer(stream->parent()));
- stream->InitTimer(execution_timer_.get())
- .ThenStartTimer(execution_timer_.get());
- stream->InitTimer(per_op_timer_.get());
- }
- }
-
- // If profiling is enabled, sets the total cycle count on the profile from the
- // execution timer.
- void FinishExecution() {
- CHECK(!finished_execution_) << "Call FinishExecution only once!";
- finished_execution_ = true;
- if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(execution_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
- profile_->set_total_cycles_executed(
- *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_);
- }
- }
-
- // If profiling is enabled, starts the per-operation timer.
- void StartOperation() {
- if (do_profile_) {
- stream_->ThenStartTimer(per_op_timer_.get());
- }
- }
-
- // If profiling is enabled, stops the per-operation timer and records the time
- // that the hlo_instruction took to execute in the profile.
- void FinishOperation(const HloInstruction* hlo_instruction) {
- if (do_profile_) {
- stream_->ThenWaitFor(&sub_streams_);
- stream_->ThenStopTimer(per_op_timer_.get());
- stream_->BlockHostUntilDone().IgnoreError();
- profile_->SetCyclesTakenBy(
- hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_);
- }
- }
-
- private:
- const bool do_profile_;
- double clock_rate_ghz_;
- HloExecutionProfile* profile_;
- se::Stream* stream_;
- const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
- const HloComputation* computation_;
- std::unique_ptr<se::Timer> execution_timer_;
- std::unique_ptr<se::Timer> per_op_timer_;
- bool finished_execution_ = false;
-};
-
} // namespace
// Implementation note: HLO profiling is always enabled for GPU executables,
@@ -207,18 +136,17 @@ Status GpuExecutable::ExecuteThunks(
TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone());
}
- profiler.StartOperation();
VLOG(2) << "Executing the thunk for "
<< thunk->hlo_instruction()->ToString() << " on stream "
<< stream_no;
- TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(
+ thunk->ExecuteOnStream(buffer_allocations, stream, &profiler));
if (thunk_schedule_->Depended(thunk)) {
auto finish_event = MakeUnique<se::Event>(main_stream->parent());
finish_event->Init();
stream->ThenRecordEvent(finish_event.get());
thunk_to_finish_event[thunk] = std::move(finish_event);
}
- profiler.FinishOperation(thunk->hlo_instruction());
}
main_stream->ThenWaitFor(&sub_streams);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 178457721a..09ef62c87f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -51,7 +51,7 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// H <=> Y
// W <=> X
//
- // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW.
+ // Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
// As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
// fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
@@ -159,7 +159,13 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
Status GpuLayoutAssignment::AddBackendConstraints(
LayoutConstraints* constraints) {
- for (auto* instruction : constraints->computation()->instructions()) {
+ // Add convolution constraints in reverse postorder that the earliest
+ // convolution layout propagates first. This reduces the likelihood of fusion
+ // nodes with copies.
+ auto post_order = constraints->computation()->MakeInstructionPostOrder();
+ for (auto iterator = post_order.rbegin(); iterator != post_order.rend();
+ ++iterator) {
+ HloInstruction* instruction = *iterator;
if (IsCustomCallToDnnConvolution(*instruction)) {
TF_RETURN_IF_ERROR(
AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index e48165c142..95f78ae293 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -132,10 +132,10 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
HloInstruction::CreateParameter(4, aux_shape, "variance"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
shape,
@@ -201,10 +201,10 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
HloInstruction::CreateParameter(2, offset_scale_shape, "offset"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall(
batchnorm_shape, {operand, scale, offset, epsilon, feature_index},
@@ -278,10 +278,10 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
HloInstruction::CreateParameter(4, shape, "var"));
auto* epsilon = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
auto* feature_index =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR0<int64>(kFeatureIndex)));
+ LiteralUtil::CreateR0<int64>(kFeatureIndex)));
auto* batchnorm =
builder.AddInstruction(HloInstruction::CreateCustomCall(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 7bb8df6581..3c8018a030 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include <vector>
#include "llvm/IR/DataLayout.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -55,33 +57,28 @@ Status GpuTransferManager::TransferLiteralToInfeed(
return TransferBufferToInfeed(executor, size, literal.untyped_data());
}
- if (ShapeUtil::IsNestedTuple(shape)) {
- return Unimplemented(
- "Infeed with a nested tuple shape is not supported: %s",
- ShapeUtil::HumanString(literal.shape()).c_str());
- }
-
// For a tuple, we transfer each of its elements to the device and
// enqueue the resulting destination device addresses with the
// infeed manager.
std::vector<gpu::InfeedBuffer*> buffers;
- buffers.reserve(ShapeUtil::TupleElementCount(shape));
auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() {
for (gpu::InfeedBuffer* b : buffers) {
b->Done();
}
});
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const Shape& tuple_element_shape =
- ShapeUtil::GetTupleElementShape(shape, i);
- int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape);
- TF_ASSIGN_OR_RETURN(
- gpu::InfeedBuffer * buffer,
- TransferBufferToInfeedInternal(executor, tuple_element_size,
- literal.untyped_data({i})));
- buffers.push_back(buffer);
- }
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
+ shape, [&](const Shape& literal_subshape, const ShapeIndex& index) {
+ if (ShapeUtil::IsArray(literal_subshape)) {
+ int64 tuple_element_size = GetByteSizeRequirement(literal_subshape);
+ TF_ASSIGN_OR_RETURN(
+ gpu::InfeedBuffer * buffer,
+ TransferBufferToInfeedInternal(executor, tuple_element_size,
+ literal.untyped_data(index)));
+ buffers.push_back(buffer);
+ }
+ return Status::OK();
+ }));
cleanup.release();
return EnqueueBuffersToInfeed(executor, buffers);
@@ -144,6 +141,63 @@ StatusOr<gpu::InfeedBuffer*> GpuTransferManager::TransferBufferToInfeedInternal(
return buffer;
}
+static std::unique_ptr<Literal> ShapeTreeToLiteral(
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree) {
+ // This is a struct instead of a lambda for std::function-free recursion.
+ struct Helper {
+ static std::unique_ptr<Literal> helper(
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>>* shape_tree,
+ ShapeIndex* index) {
+ const Shape& shape = ShapeUtil::GetSubshape(shape_tree->shape(), *index);
+ if (ShapeUtil::IsArray(shape)) {
+ return (*shape_tree->mutable_element(*index))->WaitUntilAvailable();
+ }
+
+ CHECK(ShapeUtil::IsTuple(shape))
+ << ShapeUtil::HumanStringWithLayout(shape);
+ const int64 tuple_element_count = ShapeUtil::TupleElementCount(shape);
+ index->push_back(0);
+ std::vector<std::unique_ptr<Literal>> tuple_operands;
+ for (int64 i = 0; i < tuple_element_count; ++i) {
+ index->back() = i;
+ tuple_operands.push_back(helper(shape_tree, index));
+ }
+ index->pop_back();
+ return LiteralUtil::MakeTupleOwned(std::move(tuple_operands));
+ }
+ };
+ ShapeIndex index;
+ return Helper::helper(shape_tree, &index);
+}
+
+Status GpuTransferManager::TransferLiteralFromOutfeed(
+ se::StreamExecutor* /*executor*/, const Shape& literal_shape,
+ Literal* literal) {
+ ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
+ &literal_shape);
+
+ // First create a tree of literal buffers that the device can write to.
+ outfeed_buffers.ForEachMutableElement(
+ [&](const ShapeIndex& index,
+ std::unique_ptr<gpu::OutfeedBuffer>* buffer) {
+ const Shape& shape = ShapeUtil::GetSubshape(literal_shape, index);
+ // Do not transfer tuple index buffers.
+ if (ShapeUtil::IsTuple(shape)) {
+ return;
+ }
+ *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
+ });
+
+ // Give the tree of buffers to the outfeed mananger. The device will fill it
+ // while we're waiting for it below.
+ gpu::OutfeedManager* outfeed_manager = gpu::GetOrCreateOutfeedManager();
+ outfeed_manager->EnqueueOutfeedDestination(&outfeed_buffers);
+
+ // Now turn the tree of buffers back into a literal.
+ *literal = std::move(*ShapeTreeToLiteral(&outfeed_buffers));
+ return Status::OK();
+}
+
} // namespace xla
static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 09f8227f50..9dff1e5a50 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -40,6 +40,9 @@ class GpuTransferManager : public GenericTransferManager {
const LiteralSlice& literal) override;
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source) override;
+ Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+ const Shape& literal_shape,
+ Literal* literal) override;
private:
// Initiates the infeed data transfers. InfeedBuffer->Done() must be
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
new file mode 100644
index 0000000000..19420e590d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+
+#include <memory>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers,
+ se::Stream* stream) {
+ timers->push(MakeUnique<se::Timer>(stream->parent()));
+ stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get());
+}
+
+uint64 GetCyclesTaken(
+ std::stack<std::unique_ptr<se::Timer>>* timers,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ se::Stream* stream, double clock_rate_ghz) {
+ CHECK_GT(timers->size(), 0);
+ stream->ThenWaitFor(&sub_streams);
+ stream->ThenStopTimer(timers->top().get());
+ stream->BlockHostUntilDone().IgnoreError();
+ double nanoseconds = timers->top()->Nanoseconds();
+ timers->pop();
+ return static_cast<uint64>(nanoseconds * clock_rate_ghz);
+}
+} // namespace
+
+HloExecutionProfiler::HloExecutionProfiler(
+ bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const HloComputation* computation)
+ : do_profile_(do_profile),
+ profile_(profile),
+ stream_(stream),
+ sub_streams_(sub_streams),
+ computation_(computation) {
+ if (do_profile_) {
+ clock_rate_ghz_ = stream->parent()->GetDeviceDescription().clock_rate_ghz();
+ InitAndStartTimer(&timers_, stream);
+ }
+}
+
+void HloExecutionProfiler::FinishExecution() {
+ CHECK(!finished_execution_) << "Call FinishExecution only once!";
+ finished_execution_ = true;
+ if (do_profile_) {
+ profile_->set_total_cycles_executed(
+ *computation_,
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
+ }
+}
+
+void HloExecutionProfiler::StartHloComputation() {
+ if (do_profile_) {
+ InitAndStartTimer(&timers_, stream_);
+ }
+}
+
+void HloExecutionProfiler::FinishHloComputation(
+ const HloComputation* computation) {
+ if (do_profile_) {
+ profile_->set_total_cycles_executed(
+ *computation,
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
+ }
+}
+
+void HloExecutionProfiler::StartHloInstruction() {
+ if (do_profile_) {
+ InitAndStartTimer(&timers_, stream_);
+ }
+}
+
+void HloExecutionProfiler::FinishHloInstruction(
+ const HloInstruction* hlo_instruction) {
+ if (do_profile_) {
+ hlo_instructions_.erase(hlo_instruction);
+ profile_->SetCyclesTakenBy(
+ hlo_instruction,
+ GetCyclesTaken(&timers_, sub_streams_, stream_, clock_rate_ghz_));
+ }
+}
+
+std::unique_ptr<ScopedInstructionProfiler>
+HloExecutionProfiler::MakeScopedInstructionProfiler(
+ const HloInstruction* hlo_instruction) {
+ if (do_profile_ && hlo_instruction != nullptr) {
+ // Make sure that we are not already measuring the time for the same
+ // 'hlo_instruction'.
+ CHECK(hlo_instructions_.insert(hlo_instruction).second)
+ << hlo_instruction->name();
+ }
+ return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
new file mode 100644
index 0000000000..6654850bef
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
+
+#include <memory>
+#include <stack>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/pool.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+class ScopedInstructionProfiler;
+
+// A helper class for profiling HLO in the course of GPU program execution.
+// All of the profiling is guarded internally, to avoid the caller needing to
+// have lots of conditionals sprinkled around.
+class HloExecutionProfiler {
+ public:
+ // If profiling is enabled, start an execution timer running.
+ explicit HloExecutionProfiler(
+ bool do_profile, HloExecutionProfile* profile, se::Stream* stream,
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams,
+ const HloComputation* computation);
+
+ // If profiling is enabled, sets the total cycle count on the profile from the
+ // execution timer.
+ void FinishExecution();
+
+ // If profiling is enabled, starts a timer for a (sub)computation.
+ void StartHloComputation();
+
+ // If profiling is enabled stops the timer for a (sub)computation and records
+ // the time that the computation took to execute in the profile.
+ void FinishHloComputation(const HloComputation* computation);
+
+ // If profiling is enabled, starts a per-operation timer.
+ void StartHloInstruction();
+
+ // If profiling is enabled, stops the per-operation timer and records the time
+ // that the hlo_instruction took to execute in the profile.
+ void FinishHloInstruction(const HloInstruction* hlo_instruction);
+
+ // Returns a ScopedInstructionProfiler and triggers a call to
+ // StartHloInstruction(). Once the returned ScopedInstructionProfiler goes
+ // out of scope, it triggers a call to FinishHloInstruction().
+ std::unique_ptr<ScopedInstructionProfiler> MakeScopedInstructionProfiler(
+ const HloInstruction* hlo_instruction);
+
+ private:
+ const bool do_profile_;
+ double clock_rate_ghz_;
+ HloExecutionProfile* profile_;
+ se::Stream* stream_;
+ const std::vector<Pool<se::Stream>::SmartPtr>& sub_streams_;
+ const HloComputation* computation_;
+ std::stack<std::unique_ptr<se::Timer>> timers_;
+ // Contains the HLO instructions for which we are currently measuring the
+ // time.
+ std::unordered_set<const HloInstruction*> hlo_instructions_;
+ bool finished_execution_ = false;
+};
+
+// This class can be used within the ExecuteOnStream() implementations of
+// Thunks. It ensures that we always have a pair of matching
+// StartHloInstruction() and FinishHloInstruction() calls to the profiler.
+class ScopedInstructionProfiler {
+ public:
+ ScopedInstructionProfiler(HloExecutionProfiler* profiler,
+ const HloInstruction* hlo_instruction)
+ : profiler_(profiler), hlo_instruction_(hlo_instruction) {
+ if (hlo_instruction != nullptr) {
+ profiler->StartHloInstruction();
+ }
+ }
+ ~ScopedInstructionProfiler() {
+ if (hlo_instruction_ != nullptr) {
+ profiler_->FinishHloInstruction(hlo_instruction_);
+ }
+ }
+
+ private:
+ HloExecutionProfiler* profiler_;
+ const HloInstruction* hlo_instruction_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_EXECUTION_PROFILER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index f766f96882..19de37b0fb 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -100,7 +100,7 @@ GpuHloOrdering::GpuHloOrdering(
if (last_instruction_per_stream[stream_no] != nullptr) {
immediate_preds.push_back(last_instruction_per_stream[stream_no]);
}
- predecessor_map->SetReachabilityToUnion(immediate_preds, hlo);
+ predecessor_map->FastSetReachabilityToUnion(immediate_preds, hlo);
last_instruction_per_stream[stream_no] = hlo;
} else {
// Only parameters and constants don't have an assigned stream, since they
@@ -199,7 +199,7 @@ StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
// concurrency by optimizing for minimal memory usage.
TF_ASSIGN_OR_RETURN(
schedule->thunk_launch_order_,
- CreateMemoryMinimizingSequence(
+ ScheduleOneComputation(
*entry_computation, [pointer_size](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size);
}));
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
index e230d538cc..45f0a1c645 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>("test_module", config);
}
HloVec RemoveHlo(const HloVec& input,
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
index 061210352c..d420863b85 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -137,7 +137,7 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
- const ShapeIndex& shape_index,
+ ShapeIndexView shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
@@ -158,7 +158,7 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
llvm::Value* ir_value,
- const ShapeIndex& shape_index) {
+ ShapeIndexView shape_index) {
VLOG(2) << "Binding " << hlo.ToString();
const Shape& hlo_shape = hlo.shape();
@@ -202,7 +202,7 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
<< " of " << hlo.ToString();
llvm_ir::IrArray ir_array(base_ptr,
ShapeUtil::GetSubshape(hlo.shape(), shape_index));
- alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
+ alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);
// The GPU backend emits one kernel per top-level HLO, and LLVM views
// execution of one kernel as the "whole program" executed on the GPU.
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
index 3d34311b43..a86e6e78c6 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -51,7 +51,7 @@ class HloToIrBindings {
// Rebinds the given HLO to the LLVM IR value that represent its address.
void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value,
- const ShapeIndex& shape_index = {});
+ ShapeIndexView shape_index = {});
// Unbinds all IR values that's defined in an LLVM function, e.g., function
// arguments and stack variables. Global variables will be kept in bindings_.
@@ -71,7 +71,7 @@ class HloToIrBindings {
// A helper method that returns the base pointer of the IrArray containing the
// output of "inst".at the given ShapeIndex.
llvm::Value* GetBasePointer(const HloInstruction& hlo,
- const ShapeIndex& shape_index = {}) const {
+ ShapeIndexView shape_index = {}) const {
auto it = base_ptrs_.find(&hlo);
CHECK(it != base_ptrs_.end()) << hlo.ToString();
return it->second.element(shape_index);
@@ -97,7 +97,7 @@ class HloToIrBindings {
// Returns an llvm typed ir representation of 'ir_value' based on 'hlo' shape.
llvm::Value* GetTypedIrValue(const HloInstruction& hlo,
- const ShapeIndex& shape_index,
+ ShapeIndexView shape_index,
llvm::Value* ir_value);
const BufferAssignment* buffer_assignment_;
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index ea34d5b30c..62915febb1 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -22,29 +23,31 @@ namespace xla {
namespace gpu {
InfeedThunk::InfeedThunk(
- tensorflow::gtl::ArraySlice<BufferAllocation::Slice> tuple_element_buffers,
- const BufferAllocation::Slice& destination_buffer,
+ const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction)
- : Thunk(Kind::kInfeed, hlo_instruction),
- tuple_element_buffers_(tuple_element_buffers.begin(),
- tuple_element_buffers.end()),
- destination_buffer_(destination_buffer) {}
+ : Thunk(Kind::kInfeed, hlo_instruction), infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
VLOG(2) << "Infeeding to GPU ";
- se::DeviceMemoryBase destination_address =
- buffer_allocations.GetDeviceAddress(destination_buffer_);
-
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
+ // First copy the infeed data which is element 0 of the infeed instruction's
+ // two-tuple output (the other element is a token).
+ se::DeviceMemoryBase data_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({0}));
InfeedManager* infeed_manager = GetOrCreateInfeedManager();
std::vector<InfeedBuffer*> infeed_buffers;
- if (ShapeUtil::IsTuple(hlo_instruction()->shape())) {
- CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape()));
+ const Shape& data_shape =
+ ShapeUtil::GetTupleElementShape(hlo_instruction()->shape(), 0);
+ if (ShapeUtil::IsTuple(data_shape)) {
+ CHECK(!ShapeUtil::IsNestedTuple(data_shape));
// Transfer the tuple elements first.
std::vector<void*> tuple_element_addresses;
- for (BufferAllocation::Slice tuple_element_buffer :
- tuple_element_buffers_) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(data_shape); ++i) {
+ const BufferAllocation::Slice& tuple_element_buffer =
+ infeed_slices_.element({0, i});
se::DeviceMemoryBase tuple_element_address =
buffer_allocations.GetDeviceAddress(tuple_element_buffer);
@@ -56,15 +59,23 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
}
// Transfer the tuple outer buffer.
auto host_size = tuple_element_addresses.size() * sizeof(void*);
- stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(),
+ stream->ThenMemcpy(&data_address, tuple_element_addresses.data(),
host_size);
} else {
InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer();
infeed_buffers.push_back(buffer);
- stream->ThenMemcpy(&destination_address, *(buffer->device_memory()),
+ stream->ThenMemcpy(&data_address, *(buffer->device_memory()),
buffer->length());
}
+ // Construct top-level tuple of infeed containing the data and the token. Use
+ // a nullptr for the token, it should never be dereferenced.
+ std::vector<void*> infeed_addresses = {data_address.opaque(), nullptr};
+ se::DeviceMemoryBase top_level_address =
+ buffer_allocations.GetDeviceAddress(infeed_slices_.element({}));
+ stream->ThenMemcpy(&top_level_address, infeed_addresses.data(),
+ 2 * sizeof(void*));
+
Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
return InternalError("Failed to complete data transfer on stream %p: %s",
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
index 93713cb12d..59487e245b 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -32,23 +33,19 @@ namespace gpu {
class InfeedThunk : public Thunk {
public:
// Constructs a InfeedThunk that copies data from the on-device
- // infeed queue to the device buffer
- // `destination_buffer`. `mem_size` is the size of the data in
- // bytes.
- InfeedThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Slice>
- tuple_element_buffers,
- const BufferAllocation::Slice& destination_buffer,
+ // infeed queue into the buffers in the given shape tree.
+ InfeedThunk(const ShapeTree<BufferAllocation::Slice>& infeed_slices,
const HloInstruction* hlo_instruction);
InfeedThunk(const InfeedThunk&) = delete;
InfeedThunk& operator=(const InfeedThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
- const std::vector<BufferAllocation::Slice> tuple_element_buffers_;
- const BufferAllocation::Slice destination_buffer_;
+ const ShapeTree<BufferAllocation::Slice> infeed_slices_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 36a1b82a26..64ed3d748f 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -40,6 +40,7 @@ bool IsFusile(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kDynamicSlice ||
hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
hlo.opcode() == HloOpcode::kFusion ||
+ hlo.opcode() == HloOpcode::kGather ||
hlo.opcode() == HloOpcode::kPad ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
@@ -77,15 +78,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (consumer->operand_count() == 2 &&
- (producer->opcode() == HloOpcode::kDot ||
- (producer->opcode() == HloOpcode::kFusion &&
- producer->fused_expression_root()->opcode() == HloOpcode::kDot))) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
HloInstruction* op1 = nullptr;
HloInstruction* op2 = nullptr;
- if (consumer->opcode() == HloOpcode::kFusion &&
+ if (consumer->operand_count() == 1 &&
+ consumer->opcode() == HloOpcode::kFusion &&
consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
Match(consumer->fused_expression_root(),
match::Op()
@@ -103,10 +103,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
op2->opcode() != HloOpcode::kBroadcast) {
return false;
}
- if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) {
return true;
}
- } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ } else if (consumer->operand_count() == 2 &&
+ consumer->opcode() == HloOpcode::kMultiply) {
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
// Fuse if 'alpha' is a broadcast of a scalar constant.
if (alpha->opcode() == HloOpcode::kBroadcast &&
alpha->dimensions().empty() &&
@@ -173,6 +175,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
return false;
}
+ // Fuse scalar constants into loop fusion nodes, this reduces the number of
+ // parameters and makes matching scalar broadcasts easier.
+ if (ShapeUtil::IsEffectiveScalar(producer->shape()) &&
+ consumer->opcode() == HloOpcode::kFusion &&
+ producer->opcode() == HloOpcode::kConstant) {
+ return true;
+ }
+
return IsFusile(*producer) && IsFusile(*consumer) &&
InstructionFusion::ShouldFuse(consumer, operand_index);
}
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index ec60f3a167..98ba162cd9 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -33,7 +33,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndOperandElementReusingConsumerNotFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* broadcast2 =
@@ -53,7 +53,7 @@ TEST_F(InstructionFusionTest,
NonCostlyProducerAndOperandElementReusingConsumerFused) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
HloInstruction* broadcast2 =
@@ -73,7 +73,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* reshape2 = builder.AddInstruction(
@@ -92,7 +92,7 @@ TEST_F(InstructionFusionTest,
CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
HloComputation::Builder builder(TestName());
HloInstruction* const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
HloInstruction* transpose2 = builder.AddInstruction(
@@ -143,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
// Tests that broadcasts fused into a fusion with a reduce root.
TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
add {
@@ -168,11 +168,11 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_THAT(root->fused_expression_root(),
- op::Reduce(op::Broadcast(op::Parameter()), op::Parameter()));
+ op::Reduce(op::Broadcast(op::Constant()), op::Constant()));
}
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY BroadcastIntoAdd {
@@ -194,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) {
}
TEST_F(InstructionFusionTest, AddIntoBitcast) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY BroadcastIntoAdd {
@@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) {
}
TEST_F(InstructionFusionTest, DontFuseGTE) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY DontFuseGTE {
p0 = (f32[10], f32[10]) parameter(0)
@@ -232,7 +232,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) {
}
TEST_F(InstructionFusionTest, DotOutputFusion) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
alpha = f32[] constant(3)
@@ -255,13 +255,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
EXPECT_THAT(
root->fused_expression_root(),
op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
- op::Broadcast(op::Parameter())));
+ op::Broadcast(op::Constant())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
// duplicated and fused into both reduces.
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
Add {
lhs = f32[] parameter(0)
@@ -292,7 +292,7 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
// is *not* duplicated and fused into both reduces, because we say that integer
// division is not cheap.
TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
Add {
lhs = s32[] parameter(0)
@@ -317,7 +317,7 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
}
TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY NoOutputFusion {
alpha = f32[] constant(3)
@@ -339,7 +339,7 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
EXPECT_THAT(root->fused_expression_root(),
op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
- op::Broadcast(op::Parameter())));
+ op::Broadcast(op::Constant())));
}
// Counts the HLO ops with a given op code in the specified module.
@@ -371,7 +371,7 @@ static StatusOr<const HloInstruction*> FindHloInstruction(
TEST_F(InstructionFusionTest, MultiOutputFusion) {
// sub --> add --> tuple
// \---------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -403,7 +403,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) {
TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) {
// tanh --> add --> tuple
// \---------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -424,7 +424,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) {
TEST_F(InstructionFusionTest, MultiOutputFusion2) {
// sub --> add1 --\--------\
// \----------> add2 --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -457,7 +457,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion2) {
TEST_F(InstructionFusionTest, MultiOutputFusion3) {
// sub --> add1 ----\--------\
// \ --> add2 --> add3 --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -492,7 +492,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion3) {
TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) {
// sub --> mul ---\
// \--> call --> add --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
c = f32[] constant(42)
@@ -527,7 +527,7 @@ TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) {
TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) {
// sub[2,3] --> add[4,3] --> tuple([2,3], [4,3])
// \-------------------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[2,3]{1,0} parameter(0)
@@ -548,7 +548,7 @@ TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) {
}
TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
add_computation {
@@ -581,5 +581,30 @@ TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) {
<< module->ToString();
}
+TEST_F(InstructionFusionTest, FuseScalarConstant) {
+ auto module = ParseHloString(R"(
+ HloModule test_module
+
+ ENTRY FuseScalarConstant {
+ p0 = f32[] parameter(0)
+ c0 = f32[] constant(1)
+ add1 = f32[] add(p0, c0)
+ b0 = f32[2]{0} broadcast(add1), dimensions={}
+ c1 = f32[2]{0} constant({1, 2})
+ ROOT add2 = f32[2]{0} add(b0, c1)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())),
+ op::Parameter()));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 67890bfed1..388aa35d7d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape) &&
- !ShapeUtil::HasZeroElements(lhs_shape) &&
- !ShapeUtil::HasZeroElements(rhs_shape);
+ !ShapeUtil::IsZeroElementArray(lhs_shape) &&
+ !ShapeUtil::IsZeroElementArray(rhs_shape);
}
bool DotImplementedAsGemm(const HloInstruction& dot) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 1e0db2821a..fe83d017f4 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -191,6 +191,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
HloOpcode root_opcode = computation.root_instruction()->opcode();
PrimitiveType element_type =
computation.root_instruction()->shape().element_type();
+ bool is_atomic_integral = element_type == S32 || element_type == U32 ||
+ element_type == S64 || element_type == U64;
llvm::Value* source = ir_builder_.CreateLoad(source_address, "source");
if (root_opcode == HloOpcode::kAdd) {
// NVPTX supports atomicAdd on F32 and integer types.
@@ -201,7 +203,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
{output_address->getType()}, &ir_builder_);
return true;
}
- if (primitive_util::IsIntegralType(element_type)) {
+ if (is_atomic_integral) {
// integral + integral
ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address,
source,
@@ -210,9 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
}
}
- // NVPTX supports atomicMax and atomicMin on only integer types.
- if (root_opcode == HloOpcode::kMaximum &&
- primitive_util::IsIntegralType(element_type)) {
+ // NVPTX supports atomicMax and atomicMin only on integer types.
+ if (root_opcode == HloOpcode::kMaximum && is_atomic_integral) {
// max(integral, integral)
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
@@ -222,8 +223,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
return true;
}
- if (root_opcode == HloOpcode::kMinimum &&
- primitive_util::IsIntegralType(element_type)) {
+ if (root_opcode == HloOpcode::kMinimum && is_atomic_integral) {
// min(integral, integral)
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
@@ -421,24 +421,27 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation(
Status IrEmitter::HandleSelect(HloInstruction* select) {
auto pred = select->operand(0);
- auto on_true = select->operand(1);
- auto on_false = select->operand(2);
TF_RET_CHECK(pred->shape().element_type() == PRED);
-
- if (ShapeUtil::IsTuple(select->shape())) {
- llvm_ir::EmitTupleSelect(GetIrArray(*select, *select),
- GetIrArray(*pred, *select),
- GetBasePointer(*on_true),
- GetBasePointer(*on_false), &ir_builder_, module_);
- return Status::OK();
- }
-
// We must not call the subclass `DefaultAction` method, lest its
// `HandleSelect` call `IrEmitter::HandleSelect` and its `DefaultAction`
// assume no handler has already been called.
return IrEmitter::DefaultAction(select);
}
+Status IrEmitter::HandleTupleSelect(HloInstruction* tuple_select) {
+ auto pred = tuple_select->operand(0);
+ auto on_true = tuple_select->operand(1);
+ auto on_false = tuple_select->operand(2);
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+ TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()));
+ TF_RET_CHECK(ShapeUtil::IsTuple(tuple_select->shape()));
+ llvm_ir::EmitTupleSelect(GetIrArray(*tuple_select, *tuple_select),
+ GetIrArray(*pred, *tuple_select),
+ GetBasePointer(*on_true), GetBasePointer(*on_false),
+ &ir_builder_, module_);
+ return Status::OK();
+}
+
namespace {
llvm::Value* Real(llvm::Value* x, llvm::IRBuilder<>* ir_builder) {
return ir_builder->CreateExtractValue(x, {0});
@@ -475,12 +478,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const Shape& lhs_shape = lhs_instruction->shape();
const Shape& rhs_shape = rhs_instruction->shape();
+ // TODO(b/110211620): Convert to use i32 index_type when it is possible.
+ llvm::Type* index_type = ir_builder_.getInt64Ty();
+ llvm_ir::IrArray::Index element_index(index_type);
if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
// If the operands are scalar, don't emit any loops.
llvm::Value* lhs_value =
- lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
+ lhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_);
llvm::Value* rhs_value =
- rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
+ rhs_array.EmitReadArrayElement(/*index=*/element_index, &ir_builder_);
llvm::Value* result;
if (ShapeUtil::ElementIsComplex(lhs_shape)) {
auto value = MultiplyComplex(lhs_value, rhs_value, &ir_builder_);
@@ -490,7 +496,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
} else {
result = ir_builder_.CreateFMul(lhs_value, rhs_value);
}
- target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
+ target_array.EmitWriteArrayElement(/*index=*/element_index, result,
+ &ir_builder_);
return Status::OK();
}
@@ -581,7 +588,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// address. The index into the target address is the concatenation of the rhs
// and lhs indexes with the reduction dimensions removed. The terms from the
// rhs index are the lower dimensions in the index so we add them first.
- llvm_ir::IrArray::Index target_index;
+ llvm_ir::IrArray::Index target_index(index_type);
for (size_t dimension = 0; dimension < lhs_index.size(); ++dimension) {
if (dimension != lhs_reduction_dimension) {
target_index.push_back(lhs_index[dimension]);
@@ -607,7 +614,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
- if (ShapeUtil::HasZeroElements(convolution->shape())) {
+ if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
@@ -617,7 +624,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
}
Status IrEmitter::HandleFft(HloInstruction* fft) {
- if (ShapeUtil::HasZeroElements(fft->shape())) {
+ if (ShapeUtil::IsZeroElementArray(fft->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index e55dfc6dae..d2dd335f10 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -88,6 +88,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
index bb47a42805..c9574c87a3 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -120,9 +120,10 @@ Status IrEmitterNested::EmitTargetElementLoop(
// For MOF we give the loop emitter an array for every output it should
// generate.
if (hlo.IsMultiOutputFusion()) {
+ const int64 num_elems = ShapeUtil::TupleElementCount(hlo.shape());
std::vector<llvm_ir::IrArray> target_arrays;
- for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e;
- ++i) {
+ target_arrays.reserve(num_elems);
+ for (int64 i = 0; i != num_elems; ++i) {
target_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
TF_RETURN_IF_ERROR(
@@ -130,6 +131,7 @@ Status IrEmitterNested::EmitTargetElementLoop(
.EmitLoop());
std::vector<llvm::Value*> tuple_operand_ptrs;
+ tuple_operand_ptrs.reserve(num_elems);
for (const llvm_ir::IrArray& array : target_arrays) {
tuple_operand_ptrs.push_back(array.GetBasePointer());
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 0f5c003341..673ba530df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@@ -59,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
@@ -78,6 +80,7 @@ namespace gpu {
namespace {
+using llvm_ir::IrArray;
using llvm_ir::IrName;
using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::InlinedVector;
@@ -282,6 +285,69 @@ int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
// Cannot unroll.
return 1;
}
+
+// Returns the llvm type for the indices used in the kernel that contains the
+// hlo instruction. Such indices include the index for the parallel loop and
+// the indices for the tensors accessed by the kernel. The return type is i32
+// iff the following conditions are met:
+// . The launch_size of the kernel is within the range of i32.
+// . The sizes of all the tensors accessed within the kernel are within the
+// range of i32.
+// Otherwise, the return type is i64.
+llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
+ llvm::IRBuilder<>* ir_builder) {
+ // Find the unnested hlo instructon for which the kernel is generated for.
+ const HloInstruction* unnested_hlo = hlo;
+ const HloComputation* computation = hlo->parent();
+ if (computation->IsFusionComputation()) {
+ unnested_hlo = computation->FusionInstruction();
+ }
+
+ auto shape_in_range = [&](const Shape& s) {
+ bool in_range = true;
+ ShapeUtil::ForEachSubshape(
+ s, [&](const Shape& sub_shape, const ShapeIndex& /*index*/) {
+ if (ShapeUtil::IsArray(sub_shape) &&
+ !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
+ in_range = false;
+ }
+ });
+
+ return in_range;
+ };
+
+ llvm::Type* i64_ty = ir_builder->getInt64Ty();
+ // Check launch dimension
+ if (!IsInt32(launch_size)) {
+ return i64_ty;
+ }
+
+ // Check the size of result tensors
+ if (!shape_in_range(unnested_hlo->shape())) {
+ return i64_ty;
+ }
+
+ auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
+ return shape_in_range(operand->shape());
+ };
+
+ // Check the size of input tensors
+ if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
+ return i64_ty;
+ }
+
+ // Check the size of the internal result tensors
+ if (unnested_hlo->opcode() == HloOpcode::kFusion) {
+ if (!c_all_of(
+ unnested_hlo->fused_instructions_computation()->instructions(),
+ hlo_shape_in_range)) {
+ return i64_ty;
+ }
+ }
+
+ return ir_builder->getInt32Ty();
+}
+
} // namespace
Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
@@ -291,7 +357,8 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
unroll_factor = ComputeMaxUnrollFactor(hlo);
}
- thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor));
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ hlo, /*implements_whole_instruction=*/true, unroll_factor));
return IrEmitter::DefaultAction(hlo);
}
@@ -305,7 +372,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
return IrEmitter::HandleDot(dot);
}
@@ -315,7 +383,8 @@ Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
}
Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
- thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
return IrEmitter::HandleConvolution(convolution);
}
@@ -501,24 +570,32 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
case HloOpcode::kReduce: {
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
- ArraySlice<HloInstruction*> reduces =
+ ArraySlice<HloInstruction*> output_instructions =
root->opcode() == HloOpcode::kTuple
? root->operands()
: ArraySlice<HloInstruction*>(&root, 1);
// For multi-output fusion emit an initializer for each tuple element.
// Otherwise it's sufficient to just initialize the single output.
- for (int i = 0, e = reduces.size(); i != e; ++i) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Thunk> initializer_thunk,
- BuildInitializerThunk(
- fusion, reduces[i] == root ? ShapeIndex() : ShapeIndex({i})));
- thunks.push_back(std::move(initializer_thunk));
+ HloInstruction* first_reduce = nullptr;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ if (output_instructions[i]->opcode() == HloOpcode::kReduce) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Thunk> initializer_thunk,
+ BuildInitializerThunk(fusion, output_instructions[i] == root
+ ? ShapeIndex()
+ : ShapeIndex({i})));
+ thunks.push_back(std::move(initializer_thunk));
+ first_reduce =
+ first_reduce == nullptr ? output_instructions[i] : first_reduce;
+ }
}
- thunks.push_back(BuildKernelThunk(fusion));
+ CHECK(first_reduce != nullptr);
+ thunks.push_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), fusion));
- std::vector<llvm_ir::IrArray> parameter_arrays;
+ std::vector<IrArray> parameter_arrays;
for (HloInstruction* operand : fusion->operands()) {
parameter_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -533,29 +610,49 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// fusion is a special case of that.
InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
+ std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens;
InlinedVector<HloComputation*, 1> reducers;
- for (const HloInstruction* reduce : reduces) {
- CHECK_EQ(HloOpcode::kReduce, reduce->opcode());
- // TODO(kramerb): CHECK that layouts are equal. Currently this
- // breaks multioutputfusion_test. The test has pre-fused
- // instructions, but layout_assignment will not assign any layouts
- // for instructions inside of a fused computation. It just removes
- // the layouts instead.
- CHECK(ShapeUtil::Compatible(reduces[0]->shape(), reduce->shape()));
- CHECK(ShapeUtil::Compatible(reduces[0]->operand(0)->shape(),
- reduce->operand(0)->shape()));
- CHECK(ShapeUtil::Compatible(reduces[0]->operand(1)->shape(),
- reduce->operand(1)->shape()));
- CHECK(reduces[0]->dimensions() == reduce->dimensions());
- input_gens.push_back(fused_emitter.GetGenerator(reduce->operand(0)));
- init_value_gens.push_back(
- fused_emitter.GetGenerator(reduce->operand(1)));
- reducers.push_back(reduce->to_apply());
+ InlinedVector<ShapeIndex, 1> reduce_output_shapes;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ const HloInstruction* inst = output_instructions[i];
+ ShapeIndex output_shape_index;
+ if (root->opcode() == HloOpcode::kTuple) {
+ output_shape_index = {i};
+ }
+ if (inst->opcode() == HloOpcode::kReduce) {
+ CHECK(IsReductionToVector(*inst))
+ << "Only reductions to vector are supported";
+ // Shapes, layouts and dimensions must be the same for all reduces
+ // inside of this fusion.
+ CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
+ CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+ inst->operand(0)->shape()));
+ CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+ inst->operand(1)->shape()));
+ CHECK(first_reduce->dimensions() == inst->dimensions());
+ input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
+ init_value_gens.push_back(
+ fused_emitter.GetGenerator(inst->operand(1)));
+ reducers.push_back(inst->to_apply());
+ reduce_output_shapes.push_back(std::move(output_shape_index));
+ } else {
+ // For extra outputs we can relax shape equality to allow different
+ // types (with the same number of elements). Layouts still have to
+ // match.
+ CHECK(ShapeUtil::CompatibleIgnoringElementType(
+ first_reduce->operand(0)->shape(), inst->shape()));
+ CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+ inst->shape().layout()));
+ extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
+ std::move(output_shape_index));
+ }
}
- const Shape& input_shape = reduces[0]->operand(0)->shape();
- return EmitReductionToVector(reduces[0], input_shape, input_gens,
- init_value_gens, reduces[0]->dimensions(),
- reducers);
+ const Shape& input_shape = first_reduce->operand(0)->shape();
+ return EmitReductionToVector(first_reduce, input_shape, input_gens,
+ init_value_gens,
+ first_reduce->dimensions(), reducers,
+ reduce_output_shapes, extra_output_gens);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -569,8 +666,9 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// touching the un-updated elements.
// Set up kernel thunk and fused ir emitter.
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
- std::vector<llvm_ir::IrArray> operand_arrays;
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(fusion, /*implements_whole_instruction=*/true));
+ std::vector<IrArray> operand_arrays;
for (HloInstruction* operand : fusion->operands()) {
operand_arrays.push_back(GetIrArray(*operand, *fusion));
}
@@ -583,7 +681,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
// Array to write into. Because this is an in-place operation, this is the
// same as operand 0's array.
- llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion);
+ IrArray output_array = GetIrArray(*fusion, *fusion);
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
update_shape, ir_emitter_context_->device_description());
@@ -596,314 +694,25 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
fusion, operand_arrays, output_array, &elemental_emitter,
launch_dimensions, &ir_builder_);
}
+
if (ImplementedAsGemm(*fusion)) {
thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
return Status::OK();
}
- CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop);
- int unroll_factor = ComputeMaxUnrollFactor(fusion);
+ CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop);
- thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor));
- return IrEmitter::HandleFusion(fusion);
-}
-
-namespace {
-
-// Returns the indices of the first elements of all consecutive subarrays of the
-// given array. For example:
-// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
-std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
- std::vector<size_t> is = {0};
- for (size_t i = 1; i < xs.size(); ++i) {
- if (1 != xs[i] - xs[i - 1]) {
- is.push_back(i);
- }
- }
- return is;
-}
-
-// Merges the sequences of dimensions of the given shape which start at the
-// given indices `segs`.
-Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
- const Shape& shape) {
- std::vector<int64> dimensions;
- for (size_t i = 1; i <= segs.size(); ++i) {
- dimensions.push_back(std::accumulate(
- shape.dimensions().begin() + segs[i - 1],
- shape.dimensions().begin() +
- (segs.size() == i ? shape.dimensions().size() : segs[i]),
- 1, std::multiplies<int64>()));
- }
- return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
- dimensions);
-}
-
-// Returns whether the given shapes and permutation are a 0-2-1 transpose, and
-// if so, the normalized and rank-reduced shapes. The shapes must have the same
-// dimensions, so this considers layout only.
-//
-// This function recognizes higher-rank transposes which are elementwise
-// equivalent to a 0-2-1 transpose.
-std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
- CHECK(ShapeUtil::Compatible(a, b));
- std::vector<int64> perm(a.dimensions().size());
- {
- auto layout_a_orig = LayoutUtil::MinorToMajor(a);
- std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
- auto layout_b_orig = LayoutUtil::MinorToMajor(b);
- std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
- for (size_t i = 0; i < perm.size(); ++i) {
- perm[i] = PositionInContainer(layout_b, layout_a[i]);
- }
+ if (CheckAndEmitHloWithTile021(fusion)) {
+ return Status::OK();
}
- auto segs = ConsecutiveSegments(perm);
- Shape norm_a =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
- Shape norm_b =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b);
- if (3 == segs.size() && 0 == perm[0]) {
- Shape reduced_a = MergeDimensions(segs, norm_a);
- Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout(
- b.element_type(),
- Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
- return std::make_tuple(true, reduced_a, reduced_b);
- }
- return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
-}
-
-// Returns whether the given shapes are potentially of a 0-2-1 transpose.
-// As 0-2-1 is a self-inverse permutation, which shape is input or output is
-// arbitrary.
-bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
- return 3 == b.dimensions().size() &&
- ShapeUtil::Compatible(
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a),
- ShapeUtil::PermuteDimensions(
- {0, 2, 1},
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- b)));
-}
-
-// Emits a tiled 0-2-1 transpose, assuming both input and output lain out from
-// major to minor. The x- and y- dimensions are tiled in square tiles of edge
-// length `tile_size`. Each thread block of `tile_size` x `num_rows` threads
-// transposes one tile: each thread copies a row from the input to a shared
-// memory tile, then copies a column from the shared memory tile to the output.
-//
-// `tile_size` should usually be same as warp size.
-//
-// Returns (number of tiles = number of thread blocks needed).
-//
-// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
-// to launch fewer blocks so each transposes many tiles, and
-// in any case, the number of blocks we can launch is limited.
-//
-// This is the same algorithm in CUDA:
-// https://github.com/tensorflow/tensorflow/blob/d2693c8a70567cc78b2e8a9ac8020d321620ca83/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L189
-int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
- const int64 tile_size, const int64 num_rows,
- llvm::IRBuilder<>* builder) {
- // Adds `addend` to the given `dim` of `index`.
- auto offset_dim = [builder](llvm_ir::IrArray::Index index,
- llvm::Value* addend, int64 dim) {
- index[dim] = builder->CreateAdd(index[dim], addend);
- return index;
- };
-
- CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
-
- Shape input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input.GetShape());
- Shape output_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- output.GetShape());
- input = input.CastToShape(input_shape, builder);
- output = output.CastToShape(output_shape, builder);
-
- llvm::Type* tile_type = llvm::ArrayType::get(
- llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
- // One extra here to avoid share memory bank conflict
- tile_size + 1);
- auto* tile = new llvm::GlobalVariable(
- *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
- /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
- llvm::UndefValue::get(tile_type), "tile", nullptr,
- llvm::GlobalValue::NotThreadLocal,
- /*AddressSpace=*/3 /* GPU shared memory */);
-
- // let x = threadIdx.x
- llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
- llvm_ir::AddRangeMetadata(0, num_rows * tile_size,
- static_cast<llvm::Instruction*>(x));
- x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
- "thread.id.x");
-
- // computing logical thread ids
- // logical_x = x % tile_size
- auto logical_x = builder->CreateURem(x, builder->getInt64(tile_size));
-
- // logical_y = x / tile_size
- auto logical_y = builder->CreateUDiv(x, builder->getInt64(tile_size));
-
- // `emit_cp` emits equivalent to following pseudocode:
- // if (tile_size == tile_width && tile_size == tile_height) {
- // unroll for (i in range(0, tile_size, num_rows)) {
- // emit_cp_element(index + {0, i, 0}, y + logical_y);
- // }
- // } else if (x < tile_width) {
- // tile_height_upperbound = ceil(tile_height / num_rows) * num_rows;
- // for (i in range(0, tile_height_upperbound, num_rows)) {
- // y_loc = i + logical_y;
- // if (y_loc < tile_height)
- // emit_cp_element(index + {0, i, 0}, y_loc);
- // }
- // }
- //
- // We use this to emit both the copy from input to tile and the copy from tile
- // to output.
- //
- // `index` is the origin of the row or column in the input or output array.
- //
- // `emit_cp_element(index, y)` emits code to copy a single element between the
- // tile and the input or output array, where `y` is the `y`-position in the
- // tile, whether which is row or column is a function of whether we're copying
- // from input or to output, and `index` is the index into the input or output
- // array.
- auto emit_cp_tile = [builder, tile_size, &offset_dim, num_rows, logical_x,
- logical_y](
- std::function<void(const llvm_ir::IrArray::Index&,
- llvm::Value*)>
- emit_cp_element,
- llvm::Value* tile_width, llvm::Value* tile_height,
- const llvm_ir::IrArray::Index& index,
- const string& loop_name) {
- llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
- builder->CreateAnd(
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
- builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
- "not_last_row", builder);
- builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
- for (int64 i = 0; i < tile_size; i += num_rows) {
- auto source_idx = offset_dim(index, builder->getInt64(i), /*dim=*/1);
- auto y_loc = builder->CreateAdd(builder->getInt64(i), logical_y);
- emit_cp_element(source_idx, y_loc);
- }
- builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
- llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(logical_x, tile_width), "x_in_tile", builder);
- builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
-
- // tile_height_upper_bound = ceil(tile_height / num_rows) * num_rows
- auto tile_height_upper_bound = builder->CreateMul(
- builder->CreateUDiv(
- builder->CreateAdd(tile_height, builder->getInt64(num_rows - 1)),
- builder->getInt64(num_rows)),
- builder->getInt64(num_rows));
-
- auto loop = llvm_ir::ForLoop::EmitForLoop(
- loop_name, builder->getInt64(0), tile_height_upper_bound,
- builder->getInt64(num_rows), builder);
- llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
- builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
-
- auto y_loc = builder->CreateAdd(loop->GetIndVarValue(), logical_y);
- auto if_y_in_tile = llvm_ir::EmitIfThenElse(
- builder->CreateICmpULT(y_loc, tile_height), "y_in_tile", builder);
- builder->SetInsertPoint(if_y_in_tile.true_block->getTerminator());
-
- emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
- y_loc);
- builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
- };
- auto input_dims_in_tiles = input_shape.dimensions();
- // Unpermuted dimensions are untiled.
- for (int i = 1; i < 3; ++i) {
- input_dims_in_tiles[i] =
- CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
- }
- int64 num_tiles =
- std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
- std::multiplies<int64>());
- const llvm_ir::IrArray::Index input_tile_index(
- /*linear=*/builder->CreateIntCast(
- llvm_ir::AddRangeMetadata(
- 0, num_tiles,
- static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
- builder))),
- builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
- ShapeUtil::MakeShapeWithDescendingLayout(
- PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
- builder);
- const llvm_ir::IrArray::Index input_tile_origin = ({
- llvm_ir::IrArray::Index index = input_tile_index;
- for (int i = 1; i < 3; ++i) {
- index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
- "tile_origin." + std::to_string(i));
- }
- index;
- });
- const llvm_ir::IrArray::Index input_index =
- offset_dim(offset_dim(input_tile_origin, logical_x, /*dim=*/2), logical_y,
- /*dim=*/1);
- std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
- // Only last row or column may not have full size.
- for (int i = 1; i < 3; ++i) {
- tile_dims[i] = builder->CreateSelect(
- builder->CreateICmpEQ(input_tile_index[i],
- builder->getInt64(input_dims_in_tiles[i] - 1)),
- builder->getInt64(input_shape.dimensions(i) -
- (input_dims_in_tiles[i] - 1) * tile_size),
- builder->getInt64(tile_size), "tile_size");
- }
-
- // Load data from input memory to shared memory tile.
- emit_cp_tile(
- // tile[y, x] = input_array[index]
- [builder, tile, &input, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- builder->CreateStore(
- input.EmitReadArrayElement(index, builder, "input_element"),
- builder->CreateGEP(tile, {builder->getInt64(0), y, logical_x}));
- },
- tile_dims[2], tile_dims[1], input_index, "input");
+ int unroll_factor = ComputeMaxUnrollFactor(fusion);
- // Wait for all threads to reach this point, lest we copy a value from tile to
- // output before the other thread copies it from input to tile.
- // This is `__syncthreads` in CUDA.
- llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
-
- const llvm_ir::IrArray::Index output_tile_index(
- Permute({0, 2, 1}, input_tile_index.multidim()));
- const llvm_ir::IrArray::Index output_tile_origin(
- Permute({0, 2, 1}, input_tile_origin.multidim()));
- const llvm_ir::IrArray::Index output_index =
- offset_dim(offset_dim(output_tile_origin, logical_x, /*dim=*/2),
- logical_y, /*dim=*/1);
-
- // Store data from shared memory tile to output memory.
- emit_cp_tile(
- // output_array[index] = tile[x, y]
- [builder, tile, &output, logical_x](const llvm_ir::IrArray::Index& index,
- llvm::Value* y) {
- output.EmitWriteArrayElement(
- index,
- builder->CreateLoad(
- builder->CreateGEP(tile, {builder->getInt64(0), logical_x, y}),
- "output_element"),
- builder);
- },
- tile_dims[1], tile_dims[2], output_index, "output");
-
- return num_tiles;
+ thunk_sequence_->emplace_back(BuildKernelThunk(
+ fusion, /*implements_whole_instruction=*/true, unroll_factor));
+ return IrEmitter::HandleFusion(fusion);
}
-} // namespace
-
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(),
*copy)) {
@@ -915,36 +724,40 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy));
return Status::OK();
}
- bool is_transpose_021;
- Shape reduced_input_shape, reduced_output_shape;
- std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
- IsTranspose021(copy->operand(0)->shape(), copy->shape());
- if (is_transpose_021 &&
- reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
- reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
- thunk_sequence_->emplace_back(BuildKernelThunk(copy));
- VLOG(3) << "Emitting tiled 0-2-1 transposition";
- constexpr int64 tile_size = 32;
- constexpr int64 num_rows = 8;
- int64 num_tiles = EmitTranspose021Tiled(
- GetIrArray(*copy->operand(0), *copy)
- .CastToShape(reduced_input_shape, &ir_builder_),
- GetIrArray(*copy, *copy)
- .CastToShape(reduced_output_shape, &ir_builder_),
- tile_size, num_rows, &ir_builder_);
- UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size),
- LastThunk(), ir_emitter_context_->llvm_module());
+ if (CheckAndEmitHloWithTile021(copy)) {
return Status::OK();
}
return IrEmitter::HandleCopy(copy);
}
+Status IrEmitterUnnested::EmitExtraOutputsForReduce(
+ const HloInstruction* reduce, const IrArray::Index& index,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
+ for (int i = 0; i != extra_output_gens.size(); ++i) {
+ const HloInstruction* output = reduce->parent()->FusionInstruction();
+ llvm::Value* extra_output_address =
+ GetIrArray(*output, *output, extra_output_gens[i].second)
+ .EmitArrayElementAddress(index, &ir_builder_,
+ "extra_output_element_address");
+ TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
+ extra_output_gens[i].first(index));
+ ir_builder_.CreateStore(extra_output_ir_value, extra_output_address);
+ }
+ return Status::OK();
+}
+
Status IrEmitterUnnested::EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// Number of elements processed by a single thread.
constexpr int64 kTileSize = 16;
int64 num_elems = ShapeUtil::ElementsIn(input_shape);
@@ -956,6 +769,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
int64 num_tiles =
RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize);
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+ reduce->shape().element_type(), {num_tiles}, {0});
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ tiled_input_shape, ir_emitter_context_->device_description());
+
+ llvm::Type* index_ty = GetIndexTypeForKernel(
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
+
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
// Check whether every thread will process a full tile's worth of elements
// without reading outside the bounds of the input. If this is true, we can
// skip some bounds checks in the final algorithm.
@@ -994,8 +819,7 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// // and threads_per_block is a multiple of warpSize.
// reduce_kernel<<<num_blocks, threads_per_block>>>();
//
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
llvm::Type* element_ir_type =
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
@@ -1005,40 +829,42 @@ Status IrEmitterUnnested::EmitReductionToScalar(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index({})));
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
llvm::Value* x_in_tiles = tile_index[0];
+ x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
// Emit an inner for-loop that reduces the elements in the tile.
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- ir_builder_.getInt64(0),
- ir_builder_.getInt64(kTileSize),
- ir_builder_.getInt64(1), &ir_builder_);
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* x = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// x-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)),
+ ir_builder_.CreateICmpULT(x, index_typed_constant(num_elems)),
"x_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
}
- llvm_ir::IrArray::Index input_index(
+
+ IrArray::Index input_index(
/*linear=*/x, input_shape, &ir_builder_);
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
for (int i = 0; i != num_reduces; ++i) {
@@ -1050,18 +876,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
{partial_reduction_result_addresses[i], input_address},
partial_reduction_result_addresses[i]));
}
- return Status::OK();
+ return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens);
};
// x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
// immediately beyond the tile.
llvm::Value* x_end = ir_builder_.CreateNSWAdd(
- ir_builder_.getInt64(kTileSize),
- ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)));
// The tile is entirely in bound if all_threads_in_bounds or
// x_end <= num_elems.
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)),
+ ir_builder_.CreateICmpULE(x_end, index_typed_constant(num_elems)),
ir_builder_.getInt1(all_threads_in_bounds));
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
@@ -1112,25 +938,21 @@ Status IrEmitterUnnested::EmitReductionToScalar(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm::Value* lane_id = ir_builder_.CreateURem(
- x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id");
+ x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
+ IrArray::Index(
/*linear=*/ir_builder_.getInt64(0),
ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
+ reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
@@ -1140,10 +962,6 @@ Status IrEmitterUnnested::EmitReductionToScalar(
};
// Emit a parallel loop that iterates through all input tiles, one per thread.
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {num_tiles}, {0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
UpdateLaunchDimensions(
launch_dimensions,
@@ -1151,14 +969,18 @@ Status IrEmitterUnnested::EmitReductionToScalar(
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
launch_dimensions, &ir_builder_)
- .EmitLoop(IrName(reduce));
+ .EmitLoop(IrName(reduce), index_ty);
}
Status IrEmitterUnnested::EmitColumnReduction(
int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// Divide the input matrix into tiles of size Kx1. For example, when the
// input matrix is 4x4 and K=2, the tiled matrix looks like
//
@@ -1178,6 +1000,17 @@ Status IrEmitterUnnested::EmitColumnReduction(
// If the height is not a multiple of the tile size, we pad the bottom of the
// input matrix.
const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+ reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ tiled_input_shape, ir_emitter_context_->device_description());
+
+ // TODO(b/110211620): Convert to use i32 index_type when it is possible.
+ llvm::Type* index_ty = ir_builder_.getInt64Ty();
+
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
// for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
// linear_index < height_in_tiles * width;
@@ -1202,8 +1035,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
// }
// AtomicReducer(&output[x], partial_result);
// }
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
// Emit the loop body that reduces one tile.
llvm::Type* element_ir_type =
@@ -1214,7 +1046,7 @@ Status IrEmitterUnnested::EmitColumnReduction(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index({})));
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
@@ -1225,24 +1057,28 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x = tile_index[1];
+ y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
+ x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
+
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- ir_builder_.getInt64(0),
- ir_builder_.getInt64(kTileSize),
- ir_builder_.getInt64(1), &ir_builder_);
+ index_typed_constant(0),
+ index_typed_constant(kTileSize),
+ index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* y = ir_builder_.CreateNSWAdd(
- ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
tile_element_loop->GetIndVarValue());
+
// Unless we know the tile is entirely in bounds, we have to emit a
// y-in-bounds check before reading from the input.
if (!tile_in_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)),
+ ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
"y_in_bounds", &ir_builder_);
// Emit code that reads the input element and accumulates it to
@@ -1266,9 +1102,9 @@ Status IrEmitterUnnested::EmitColumnReduction(
const Shape input_matrix_shape =
ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
{height, width});
- const llvm_ir::IrArray::Index input_matrix_index(
- {y, x}, input_matrix_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
+ const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
+ &ir_builder_);
+ const IrArray::Index input_index =
input_matrix_index
.SourceIndexOfReshape(input_matrix_shape,
normalized_input_shape, &ir_builder_)
@@ -1284,17 +1120,18 @@ Status IrEmitterUnnested::EmitColumnReduction(
{partial_reduction_result_addresses[i], input_address},
partial_reduction_result_addresses[i]));
}
- return Status::OK();
+ return EmitExtraOutputsForReduce(reduce, input_index,
+ extra_output_gens);
}
};
// y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
// immediately beyond the tile.
llvm::Value* y_end = ir_builder_.CreateNSWAdd(
- ir_builder_.getInt64(kTileSize),
- ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)));
+ index_typed_constant(kTileSize),
+ ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)),
+ ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
ir_builder_.getInt1(height % kTileSize == 0));
// The tile is entirely in bound if "height" is a multiple of kTileSize or
// y_end <= height.
@@ -1315,18 +1152,13 @@ Status IrEmitterUnnested::EmitColumnReduction(
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- x,
- ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
- &ir_builder_),
+ IrArray::Index(x,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
@@ -1335,10 +1167,6 @@ Status IrEmitterUnnested::EmitColumnReduction(
};
// Emit a parallel loop that iterate through all input tiles.
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
UpdateLaunchDimensions(
launch_dimensions,
@@ -1346,7 +1174,31 @@ Status IrEmitterUnnested::EmitColumnReduction(
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
launch_dimensions, &ir_builder_)
- .EmitLoop(IrName(reduce));
+ .EmitLoop(IrName(reduce), index_ty);
+}
+
+static std::pair<int64, int64> ComputeTilingSchemeForReduction(
+ int64 depth, int64 width, int64 kWarpSize) {
+ constexpr int64 kTargetNumElementsPerThread = 64;
+ int64 x_tile_size = kTargetNumElementsPerThread;
+ int64 z_tile_size = 1;
+
+ // Only tile along the x dimension with tile size kTargetNumElementsPerThread
+ // if doing so doesn't require a slow version of loop with bound check on each
+ // dimension. A more sophisticated heuristics is to enable tile along the
+ // x dimension with tile size kTargetNumElementsPerThread when either width is
+ // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big
+ // enough so that only a small fraction of the threads execute the slow
+ // version of loop with bound check.
+ if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) {
+ x_tile_size = 8;
+ z_tile_size = 8;
+ while (depth % z_tile_size != 0) {
+ z_tile_size -= 1;
+ }
+ }
+
+ return std::pair<int64, int64>(x_tile_size, z_tile_size);
}
Status IrEmitterUnnested::EmitRowReduction(
@@ -1354,9 +1206,13 @@ Status IrEmitterUnnested::EmitRowReduction(
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// A naive algorithm is:
- // 1. Divide the input tensor into tiles of size 1x1xK.
+ // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
// 2. Partially reduces each tile to a scalar using one thread.
// 3. Accumulates that scalar to the output vector using atomic operations.
//
@@ -1367,15 +1223,15 @@ Status IrEmitterUnnested::EmitRowReduction(
// int y = linear_index / width_in_tiles % height;
// int z = linear_index / (height * width_in_tiles);
// float partial_result = 0;
- // for (element_id_in_tile : range(kTileSize)) {
- // int x = x_in_tiles * kTileSize + element_id_in_tile;
+ // for (element_id_in_tile : range(x_tile_size)) {
+ // int x = x_in_tiles * x_tile_size + element_id_in_tile;
// if (x < width)
- // partial_result = reducer(partial_result, input[z][y][z]);
+ // partial_result = reducer(partial_result, input[z][y][x]);
// }
// AtomicReducer(&output[y], partial_result);
// }
//
- // Three optimizations are performed.
+ // Four optimizations are performed.
//
// 1. To coalesce global memory accesses, dilate the tile with a factor of 32
// (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
@@ -1402,29 +1258,46 @@ Status IrEmitterUnnested::EmitRowReduction(
// element_id_in_tile, which makes the code more friendly to optimizations
// such as LICM.
//
+ // 4. When the width is too small and x_tile_size is less than the target
+ // number of elements per thread and use a small factor of depth as
+ // z_tile_size to increase the number of elements calculated by each
+ // partial sum. This can reduce the needed number of dynamic shfl_down and
+ // atomic operations.
+ //
// for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
// linear_index < depth * height * width_in_tiles;
// linear_index += blockDim.x * gridDim.x) {
// int x_in_tiles = linear_index % width_in_tiles;
// int y = linear_index / width_in_tiles % height;
- // int z = linear_index / (height * width_in_tiles);
+ // int z_in_tiles = linear_index / (height * width_in_tiles);
// int warp_id = x_in_tiles / warpSize;
// int lane_id = x_in_tiles % warpSize;
// float partial_result = 0;
// int x = warp_id * kTileSize * warpSize + lane_id;
- // if (width % (kTileSize * warpSize) == 0 ||
- // x + (kTileSize - 1) * warpSize < width) {
- // // The entire tile is in bounds.
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
- // ++element_id_in_tile, x += warpSize) {
- // partial_result = Reducer(partial_result, input[z][y][x]);
+ // if (width % (x_tile_size * warpSize) == 0 ||
+ // x + (x_tile_size - 1) * warpSize < width) {
+ // // The entire x_tile is in bounds.
+ // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
+ // ++element_id_in_z_tile) {
+ // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
+ // int tx = x;
+ // for (int element_id_in_x_tile = 0;
+ // element_id_in_x_tile < x_tile_size;
+ // ++element_id_in_x_tile, tx += warpSize) {
+ // partial_result = Reducer(partial_result, input[z][y][tx]);
+ // }
// }
// } else {
// // The tile is partially in bounds.
- // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
- // ++element_id_in_tile, x += warpSize) {
- // if (x < width)
- // partial_result = Reducer(partial_result, input[z][y][x]);
+ // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
+ // ++element_id_in_z_tile) {
+ // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
+ // int tx = x;
+ // for (int element_id_in_x_tile = 0; element_id_in_x_tile <
+ // x_tile_size; ++element_id_in_tile, tx += warpSize) {
+ // if (tx < width)
+ // partial_result = Reducer(partial_result, input[z][y][tx]);
+ // }
// }
// }
// for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
@@ -1435,17 +1308,30 @@ Status IrEmitterUnnested::EmitRowReduction(
// AtomicReducer(&output[y], partial_result);
// }
//
- // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
- constexpr int64 kTileSize = 8;
+
+ int64 x_tile_size;
+ int64 z_tile_size;
+ std::tie(x_tile_size, z_tile_size) =
+ ComputeTilingSchemeForReduction(depth, width, kWarpSize);
+
// Round the width in tiles up to the nearest multiple of kWarpSize, so that
// the use of shfl_down is valid.
const int64 width_in_tiles =
- RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize);
+ RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize);
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+ reduce->shape().element_type(),
+ {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0});
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ tiled_input_shape, ir_emitter_context_->device_description());
+ llvm::Type* index_ty = GetIndexTypeForKernel(
+ reduce, launch_dimensions.launch_bound(), &ir_builder_);
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ auto loop_body_emitter = [=](const IrArray::Index& tile_index) {
const int num_reduces = reducers.size();
- // Emit the loop body that reduces one tile.
llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
input_shape.element_type(), ir_emitter_context_->llvm_module());
std::vector<llvm::Value*> partial_reduction_result_addresses;
@@ -1454,122 +1340,149 @@ Status IrEmitterUnnested::EmitRowReduction(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](llvm_ir::IrArray::Index({})));
+ init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
- // Emit an inner for-loop that partially reduces the elements in the given
- // tile.
- llvm::Value* z = tile_index[0];
+ llvm::Value* z_tile = tile_index[0];
llvm::Value* y = tile_index[1];
llvm::Value* x_tile = tile_index[2];
+
+ x_tile = ir_builder_.CreateZExtOrTrunc(x_tile, index_ty);
+
llvm::Value* warp_id = ir_builder_.CreateUDiv(
- x_tile, ir_builder_.getInt64(kWarpSize), "warp_id");
+ x_tile, index_typed_constant(kWarpSize), "warp_id");
llvm::Value* lane_id = ir_builder_.CreateURem(
- x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");
+ x_tile, index_typed_constant(kWarpSize), "lane_id");
- // The x-location of the last element in this tile.
- // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize);
+ // The x-location of the last element in this z-x-tile.
+ // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
llvm::Value* last_x = ir_builder_.CreateNSWAdd(
- lane_id,
- ir_builder_.CreateNSWMul(
- ir_builder_.getInt64(kWarpSize),
- ir_builder_.CreateNSWAdd(
- ir_builder_.getInt64(kTileSize - 1),
- ir_builder_.CreateNSWMul(warp_id,
- ir_builder_.getInt64(kTileSize)))));
-
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
- std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
- ir_builder_.getInt64(0),
- ir_builder_.getInt64(kTileSize),
- ir_builder_.getInt64(1), &ir_builder_);
-
- // Emit the body of the partial reduction loop.
- llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &ir_builder_);
- // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
- llvm::Value* x = ir_builder_.CreateNSWAdd(
- lane_id,
- ir_builder_.CreateNSWMul(
- ir_builder_.getInt64(kWarpSize),
- ir_builder_.CreateNSWAdd(
- tile_element_loop->GetIndVarValue(),
- ir_builder_.CreateNSWMul(warp_id,
- ir_builder_.getInt64(kTileSize)))));
-
- // Unless we know the tile is entirely in bounds, we have to emit a
- // x-in-bounds check before reading from the input.
- if (!tile_in_bounds) {
- llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)),
- "x_in_bounds", &ir_builder_);
-
- // Points ir_builder_ to the then-block.
- llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
- &ir_builder_);
- }
-
- // Emit code that reads the input element and accumulates it to the
- // partial reduction result.
- llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
- {
- // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We
- // need to convert that to an index to input_shape (the shape of the
- // operand of "reduce"). This conversion is composed of a transposition
- // from input_shape to normalized_input_shape and a reshape from
- // normalized_input_shape to input_3d_tensor_shape.
- const Shape normalized_input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input_shape);
- auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
- const std::vector<int64> transpose_dimension_mapping(
- input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
- const Shape input_3d_tensor_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
- {depth, height, width});
- const llvm_ir::IrArray::Index input_3d_tensor_index(
- {z, y, x}, input_3d_tensor_shape, &ir_builder_);
- const llvm_ir::IrArray::Index input_index =
- input_3d_tensor_index
- .SourceIndexOfReshape(input_3d_tensor_shape,
- normalized_input_shape, &ir_builder_)
- .SourceIndexOfTranspose(normalized_input_shape, input_shape,
- transpose_dimension_mapping,
- &ir_builder_);
- for (int i = 0; i != num_reduces; ++i) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
- input_gens[i](input_index));
- ir_builder_.CreateStore(input_ir_value, input_address);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], input_address},
- partial_reduction_result_addresses[i]));
- }
+ lane_id, ir_builder_.CreateNSWMul(
+ index_typed_constant(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ index_typed_constant(x_tile_size - 1),
+ ir_builder_.CreateNSWMul(
+ warp_id, index_typed_constant(x_tile_size)))));
+
+ KernelSupportLibrary ksl(
+ &ir_builder_,
+ /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll,
+ /*prevent_vectorization=*/false);
+
+ // Emit a for-loop that partially reduces the elements in the given
+ // z-x-tile.
+ auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
+ int64 x_tile_loop_bound) -> Status {
+ auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
+ llvm::Value* z = ir_builder_.CreateNSWAdd(
+ z_indvar, ir_builder_.CreateNSWMul(
+ index_typed_constant(z_tile_size), z_tile));
+ TF_RETURN_IF_ERROR(ksl.For(
+ "x_tile",
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(x_tile_loop_bound),
+ /*step=*/1, [&](llvm::Value* x_indvar) -> Status {
+ // x = lane_id +
+ // warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
+ llvm::Value* x = ir_builder_.CreateNSWAdd(
+ lane_id,
+ ir_builder_.CreateNSWMul(
+ index_typed_constant(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ x_indvar, ir_builder_.CreateNSWMul(
+ warp_id, llvm::ConstantInt::get(
+ index_ty, x_tile_size)))));
+
+ // Unless we know the x-tile is entirely in bounds, we have to
+ // emit a x-in-bounds check before reading from the input.
+ if (!x_tile_in_bounds) {
+ llvm_ir::LlvmIfData if_x_in_bounds_data =
+ llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT(
+ x, index_typed_constant(width)),
+ "x_in_bounds", &ir_builder_);
+ // Points ir_builder_ to the then-block.
+ llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
+ &ir_builder_);
+ }
+
+ // Emit code that reads the input element and accumulates it
+ // to the partial reduction result.
+ llvm::Value* input_address =
+ ir_builder_.CreateAlloca(element_ir_type);
+ {
+ // {z,y,x} is an index to input_3d_tensor_shape
+ // [depth,height,width]. We need to convert that to an index
+ // to input_shape (the shape of the operand of "reduce").
+ // This conversion is composed of a transposition from
+ // input_shape to normalized_input_shape and a reshape from
+ // normalized_input_shape to input_3d_tensor_shape.
+ const Shape normalized_input_shape = ShapeUtil::
+ MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ input_shape);
+ auto input_shape_min2maj =
+ LayoutUtil::MinorToMajor(input_shape);
+ const std::vector<int64> transpose_dimension_mapping(
+ input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
+ const Shape input_3d_tensor_shape =
+ ShapeUtil::MakeShapeWithDescendingLayout(
+ input_shape.element_type(), {depth, height, width});
+ const IrArray::Index input_3d_tensor_index(
+ {z, y, x}, input_3d_tensor_shape, &ir_builder_);
+ const IrArray::Index input_index =
+ input_3d_tensor_index
+ .SourceIndexOfReshape(input_3d_tensor_shape,
+ normalized_input_shape,
+ &ir_builder_)
+ .SourceIndexOfTranspose(
+ normalized_input_shape, input_shape,
+ transpose_dimension_mapping, &ir_builder_);
+
+ for (int i = 0; i != num_reduces; ++i) {
+ TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
+ input_gens[i](input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], input_address},
+ partial_reduction_result_addresses[i]));
+ }
+ return EmitExtraOutputsForReduce(reduce, input_index,
+ extra_output_gens);
+ }
+ }));
return Status::OK();
- }
+ };
+
+ return ksl.For("z_tile",
+ /*start=*/index_typed_constant(0),
+ /*end=*/index_typed_constant(z_tile_size),
+ /*step=*/1, emit_z_tile_element_loop);
};
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
- ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0),
- ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width)));
- llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
- &ir_builder_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
-
- // After the if-then-else statement on tile_in_bounds, emit calls to
- // shfl_down that accumulate the partial reduction results of all threads
- // from the warp.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
- &ir_builder_);
+ ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0),
+ ir_builder_.CreateICmpULT(last_x, index_typed_constant(width)));
+
+ TF_RETURN_IF_ERROR(
+ ksl.If(tile_in_bounds,
+ /*true_block_generator=*/
+ [&]() -> Status {
+ return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true,
+ x_tile_size);
+ },
+ /*false_block_generator=*/
+ [&]() -> Status {
+ return emit_z_x_tile_element_loop(
+ /*x_tile_in_bounds=*/false,
+ CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize));
+ }));
+
+ // After accumulating the elements of the z_x_tile, emit calls to
+ // shfl_down that accumulate the partial reduction results of all
+ // threads in a warp.
int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
// bitcast cannot be applied to aggregate types (even packed ones), so we
// instead bitcast addresses of load/store to intN* of the same bit-width.
@@ -1605,36 +1518,36 @@ Status IrEmitterUnnested::EmitRowReduction(
// lane 0 (which holds the partially accumulated result for its warp) to the
// output element.
llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
+ ir_builder_.CreateICmpEQ(lane_id, index_typed_constant(0)),
"lane_id_is_zero", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
&ir_builder_);
for (int i = 0; i != num_reduces; ++i) {
- ShapeIndex output_shape_index;
- if (output->IsMultiOutputFusion()) {
- output_shape_index = {i};
- }
llvm::Value* output_address =
- GetIrArray(*output, *output, output_shape_index)
+ GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
- llvm_ir::IrArray::Index(
- y,
- ShapeUtil::GetSubshape(output->shape(),
- output_shape_index),
- &ir_builder_),
+ IrArray::Index(y,
+ ShapeUtil::GetSubshape(
+ output->shape(), reduce_output_shapes[i]),
+ &ir_builder_),
&ir_builder_, "output_element_address");
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, partial_reduction_result_addresses[i]));
+ // We don't need to emit atomic operations if there is only one tile of
+ // results. 'depth' is the z dimension, 'width' is the x dimension.
+ if (z_tile_size >= depth && x_tile_size >= width) {
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducers[i],
+ {output_address, partial_reduction_result_addresses[i]},
+ output_address));
+ } else {
+ TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address,
+ partial_reduction_result_addresses[i]));
+ }
}
return Status::OK();
};
// Emit a parallel loop that iterates through every input tiles.
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {depth, height, width_in_tiles},
- {2, 1, 0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
UpdateLaunchDimensions(
launch_dimensions,
@@ -1642,7 +1555,7 @@ Status IrEmitterUnnested::EmitRowReduction(
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
launch_dimensions, &ir_builder_)
- .EmitLoop(IrName(reduce));
+ .EmitLoop(IrName(reduce), index_ty);
}
// Figures out whether `reduce` is a row or column reduction, and which
@@ -1656,7 +1569,11 @@ Status IrEmitterUnnested::EmitReductionToVector(
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers) {
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens) {
// This emission requires "reduce" to have an input layout. It is either set
// by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
// a fused kReduce).
@@ -1692,7 +1609,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
// dimension of the input is to keep.
if (input_dims_to_keep.empty()) {
return EmitReductionToScalar(reduce, input_shape, input_gens,
- init_value_gens, reducers);
+ init_value_gens, reducers,
+ reduce_output_shapes, extra_output_gens);
} else if (input_dims_to_keep.front() ==
LayoutUtil::Minor(input_shape.layout(), 0)) {
// Column reduction. Treat the result of "input" as a matrix whose width
@@ -1710,7 +1628,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
}
return EmitColumnReduction(height, width, reduce, input_shape, input_gens,
- init_value_gens, reducers);
+ init_value_gens, reducers, reduce_output_shapes,
+ extra_output_gens);
} else {
// Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
// 3D tensor. The size of dimension 1 (the height) is the size of the
@@ -1736,7 +1655,8 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
const int64 height = ShapeUtil::ElementsIn(reduce->shape());
return EmitRowReduction(depth, height, width, reduce, input_shape,
- input_gens, init_value_gens, reducers);
+ input_gens, init_value_gens, reducers,
+ reduce_output_shapes, extra_output_gens);
}
}
@@ -1748,30 +1668,30 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
// HandleReduce specializes reduction from a multi-dimensional array to a 1D
// array. The specialized version requires an initializer thunk that
// initializes the output array to the initial value of the reduce.
- if (IsReductionToVector(*reduce) &&
- // NVPTX backend can't do atomic cmpxchg any narrower than 32 bits
- 32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
+ if (IsReductionToVector(*reduce)) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
BuildInitializerThunk(reduce));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(reduce));
+ thunks.push_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), reduce));
return EmitReductionToVector(
- reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) {
+ reduce, input->shape(), {[&](const IrArray::Index& index) {
return GetIrArray(*input, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
- {[&](const llvm_ir::IrArray::Index& index) {
+ {[&](const IrArray::Index& index) {
return GetIrArray(*init_value, *reduce)
.EmitReadArrayElement(index, &ir_builder_);
}},
- dimensions_to_reduce, {reducer});
+ dimensions_to_reduce, {reducer}, {{}}, {});
}
- thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(reduce, /*implements_whole_instruction=*/true));
return IrEmitter::HandleReduce(reduce);
}
@@ -1800,7 +1720,8 @@ Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
return Status::OK();
}
- thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
return IrEmitter::HandleTuple(tuple);
}
@@ -1825,7 +1746,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
BuildInitializerThunk(select_and_scatter));
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(std::move(initializer_thunk));
- thunks.push_back(BuildKernelThunk(select_and_scatter));
+ thunks.push_back(BuildKernelThunk(select_and_scatter,
+ /*implements_whole_instruction=*/false));
thunk_sequence_->emplace_back(
MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
@@ -1835,6 +1757,14 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
"Dilation for SelectAndScatter not implemented on GPU.");
}
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ source->shape(), ir_emitter_context_->device_description());
+ llvm::Type* index_type = GetIndexTypeForKernel(
+ select_and_scatter, launch_dimensions.launch_bound(), &ir_builder_);
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_type, c);
+ };
+
// kSelectAndScatter is implemented as two kernel launches: the first launch
// initializes the output array to the given initial value,
// and the second accumulates the "source" matrix to the
@@ -1854,8 +1784,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// selected_index = I
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
- auto loop_body_emitter =
- [=](const llvm_ir::IrArray::Index& source_index) -> Status {
+ auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
// Allocate space to keep the currently selected value, its index, and a
// boolean flag if the value is initialized. The initialized_flag is set
// false.
@@ -1865,8 +1794,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
"selected_value_address", &ir_builder_);
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
- "selected_index_address", &ir_builder_);
+ index_type, index_typed_constant(rank), "selected_index_address",
+ &ir_builder_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
ir_builder_.CreateStore(ir_builder_.getInt1(false),
@@ -1874,13 +1803,13 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"),
- &ir_builder_);
+ &ir_builder_, index_type);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
CHECK_GT(dim.size(), 0);
}
- const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ const IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
&ir_builder_);
@@ -1888,17 +1817,17 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
- llvm_ir::IrArray::Index operand_index(source_index.size());
+ IrArray::Index operand_index(index_type, source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
- source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ source_index[i], index_typed_constant(window.dimensions(i).stride()));
operand_index[i] = ir_builder_.CreateNSWSub(
ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
- ir_builder_.getInt64(window.dimensions(i).padding_low()));
+ index_typed_constant(window.dimensions(i).padding_low()));
llvm::Value* index_condition = ir_builder_.CreateICmpULT(
operand_index[i],
- ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
in_bounds_condition =
ir_builder_.CreateAnd(in_bounds_condition, index_condition);
}
@@ -1916,8 +1845,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
- const auto save_operand_index = [&](
- const llvm_ir::IrArray::Index& operand_index) {
+ const auto save_operand_index = [&](const IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot =
ir_builder_.CreateInBoundsGEP(selected_index_address,
@@ -1925,7 +1853,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
- llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
+ IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
ir_builder_.CreateStore(operand_data, selected_value_address);
@@ -1970,7 +1898,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
// value and the current output value.
llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
- llvm_ir::IrArray::Index selected_index;
+ IrArray::Index selected_index(operand_index.GetType());
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
@@ -1988,8 +1916,6 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
source_value_address);
};
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- source->shape(), ir_emitter_context_->device_description());
UpdateLaunchDimensions(
launch_dimensions,
// IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
@@ -2000,7 +1926,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter(
ir_emitter_context_->llvm_module());
return ParallelLoopEmitter(loop_body_emitter, source->shape(),
launch_dimensions, &ir_builder_)
- .EmitLoop(IrName(select_and_scatter));
+ .EmitLoop(IrName(select_and_scatter), index_type);
}
Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
@@ -2027,15 +1953,23 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
}
Status IrEmitterUnnested::HandleRng(HloInstruction* random) {
- thunk_sequence_->push_back(BuildKernelThunk(random));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(random, /*implements_whole_instruction=*/true));
return IrEmitter::HandleRng(random);
}
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
- thunk_sequence_->push_back(BuildKernelThunk(select));
+ thunk_sequence_->push_back(
+ BuildKernelThunk(select, /*implements_whole_instruction=*/true));
return IrEmitter::HandleSelect(select);
}
+Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
+ thunk_sequence_->push_back(
+ BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
+ return IrEmitter::HandleTupleSelect(tuple_select);
+}
+
Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
if (hlo_module_config_.replica_count() != 1) {
// TODO(b/33011107): Support nontrivial cross replica sum on GPU.
@@ -2071,22 +2005,31 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) {
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
/*source_address=*/GetAllocationSlice(*crs->operand(i)),
/*destination_buffer=*/tuple_element_buffers.back(),
- /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), crs));
+ /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
}
// Output a tuple of the buffers above.
thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers,
- GetAllocationSlice(*crs), crs));
+ GetAllocationSlice(*crs), nullptr));
thunk_sequence_->push_back(
MakeUnique<SequentialThunk>(std::move(thunks), crs));
return Status::OK();
}
+Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) {
+ return Status::OK();
+}
+
Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
thunk_sequence_->emplace_back(BuildInfeedThunk(infeed));
return Status::OK();
}
+Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
+ thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed));
+ return Status::OK();
+}
+
// Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO).
//
@@ -2205,13 +2148,9 @@ GetHloBufferSlices(const HloInstruction* hlo,
return slices;
}
-Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
- // TODO(b/72710576): Gather is not implemented on GPUs
- return Unimplemented("Gather is not implemented on GPUs.");
-}
-
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
- const HloInstruction* inst, int unroll_factor) {
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
@@ -2303,7 +2242,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
}
return MakeUnique<KernelThunk>(buffers, llvm_ir::AsString(kernel->getName()),
- inst, unroll_factor);
+ implements_whole_instruction ? inst : nullptr,
+ unroll_factor);
}
std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk(
@@ -2335,17 +2275,31 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
const HloInstruction* inst) {
CHECK_EQ(HloOpcode::kInfeed, inst->opcode());
- std::vector<BufferAllocation::Slice> tuple_element_buffers;
- for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) {
- BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment()
- .GetUniqueSlice(inst, {i})
- .ConsumeValueOrDie();
- tuple_element_buffers.push_back(buffer);
- }
+ ShapeTree<BufferAllocation::Slice> slices(inst->shape());
+ slices.ForEachMutableElement(
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ *slice = ir_emitter_context_->buffer_assignment()
+ .GetUniqueSlice(inst, index)
+ .ConsumeValueOrDie();
+ });
+ return MakeUnique<InfeedThunk>(slices, inst);
+}
- return MakeUnique<InfeedThunk>(
- tuple_element_buffers,
- /*destination_buffer=*/GetAllocationSlice(*inst), inst);
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
+ const HloInstruction* inst) {
+ CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());
+
+ ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());
+ slices.ForEachMutableElement(
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ auto status_or_slice =
+ ir_emitter_context_->buffer_assignment().GetUniqueSlice(
+ inst->operand(0), index);
+ if (status_or_slice.ok()) {
+ *slice = status_or_slice.ConsumeValueOrDie();
+ }
+ });
+ return MakeUnique<OutfeedThunk>(std::move(slices), inst);
}
namespace {
@@ -2390,7 +2344,9 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (alpha->opcode() == HloOpcode::kBroadcast) {
alpha = alpha->operand(0);
}
- alpha = inst->operand(alpha->parameter_number());
+ if (alpha->opcode() == HloOpcode::kParameter) {
+ alpha = inst->operand(alpha->parameter_number());
+ }
// TODO(b/74185543): Remove the following if block once we support fusion
// with a non-constant as well. Then we will just always use the constant
// on the device.
@@ -2436,15 +2392,18 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
const HloInstruction* hlo, const ShapeIndex& index) {
bool fused = HloOpcode::kFusion == hlo->opcode();
const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
- const HloInstruction* init_value = [&] {
+ const HloInstruction* init_value_operand = [&] {
switch (inst->opcode()) {
case HloOpcode::kSelectAndScatter:
return inst->operand(2);
case HloOpcode::kReduce:
return inst->operand(1);
case HloOpcode::kTuple:
- CHECK(hlo->IsMultiOutputFusion() &&
- inst->operand(index.back())->opcode() == HloOpcode::kReduce);
+ CHECK(hlo->IsMultiOutputFusion())
+ << ": " << hlo->ToString() << " is not a multi-output fusion.";
+ CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce)
+ << ": Found '" << inst->operand(index.back())->opcode() << "' in "
+ << inst->ToString() << " but expected 'reduce'.";
// For multi-output fusion look through the tuple.
return inst->operand(index.back())->operand(1);
default:
@@ -2453,10 +2412,16 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
}
}();
+ const HloInstruction* init_value = init_value_operand;
if (fused && init_value->opcode() == HloOpcode::kParameter) {
init_value = hlo->operand(init_value->parameter_number());
}
+ // Initializer thunks don't implement a whole instruction, and we want to
+ // profile the whole instruction instead of the individual thunks it consists
+ // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
+ // generate below.
+ //
// In the common case, the initializer is a constant. In this case, emit a
// device-memset call if we can. Currently StreamExecutor only supports
// zeroing and 32-bit memsets.
@@ -2470,24 +2435,26 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
ArraySlice<uint8> literal_bytes(
reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
- return {MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), hlo)};
+ return {
+ MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
// repeating the literal 4 or 2 times, so long as the destination buffer is
// an even multiple of 32 bits long.
+ const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
if ((num_bytes == 1 || num_bytes == 2) &&
- ShapeUtil::ByteSizeOf(hlo->shape()) % 4 == 0) {
+ ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
uint16 pattern16;
if (num_bytes == 1) {
uint8 b = literal_bytes.front();
pattern16 = uint16{b} | (uint16{b} << 8);
} else {
- pattern16 = literal_bytes.front();
+ memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
}
uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
return {MakeUnique<Memset32BitValueThunk>(
- pattern32, GetAllocationSlice(*hlo, index), hlo)};
+ pattern32, GetAllocationSlice(*hlo, index), nullptr)};
}
// If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
@@ -2498,19 +2465,31 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
uint32 word;
memcpy(&word, literal_bytes.data(), sizeof(word));
return {MakeUnique<Memset32BitValueThunk>(
- word, GetAllocationSlice(*hlo, index), hlo)};
+ word, GetAllocationSlice(*hlo, index), nullptr)};
}
}
// Otherwise fall back to our slow initializer code.
- std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(hlo);
- TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
- *hlo,
- [=](const llvm_ir::IrArray::Index& index) {
- return GetIrArray(*init_value, *hlo)
- .EmitReadArrayElement(index, &ir_builder_);
- },
- kernel_thunk.get()));
+ std::unique_ptr<KernelThunk> kernel_thunk =
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
+ LaunchDimensions launch_dimensions =
+ CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
+ ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+ ir_emitter_context_->llvm_module());
+ // If the init_value was fused into this reduce we have to generate it first.
+ if (fused && init_value_operand->opcode() != HloOpcode::kParameter) {
+ CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode());
+ TF_RETURN_IF_ERROR(HandleConstant(const_cast<HloInstruction*>(init_value)));
+ }
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(
+ [=](const IrArray::Index& index) {
+ return GetIrArray(*init_value, *hlo)
+ .EmitReadArrayElement(index, &ir_builder_);
+ },
+ GetIrArray(*hlo, *hlo, index), launch_dimensions,
+ &ir_builder_)
+ .EmitLoop(IrName(hlo)));
// Clean up state left behind by emitting the loop above. (This is normally
// done in IrEmitterUnnested::Postprocess().)
@@ -2694,18 +2673,22 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
if (!hlo.IsMultiOutputFusion()) {
return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
launch_dimensions, &ir_builder_, unroll_factor)
- .EmitLoop(IrName(&hlo));
+ .EmitLoop(IrName(&hlo),
+ GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(),
+ &ir_builder_));
}
- // For multiple outputs fusion, we need to emit each operand and the root.
- std::vector<llvm_ir::IrArray> output_arrays;
+ // For multioutput fusion, we need to emit each operand and the root.
+ std::vector<IrArray> output_arrays;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
}
- TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays,
- launch_dimensions, &ir_builder_,
- unroll_factor)
- .EmitLoop(IrName(&hlo)));
+ TF_RETURN_IF_ERROR(
+ ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
+ &ir_builder_, unroll_factor)
+ .EmitLoop(IrName(&hlo),
+ GetIndexTypeForKernel(
+ &hlo, launch_dimensions.launch_bound(), &ir_builder_)));
std::vector<llvm::Value*> tuple_operand_ptrs;
for (int64 i = 0; i < output_arrays.size(); ++i) {
@@ -2725,5 +2708,482 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
+int IrEmitterUnnested::ConstructIrArrayForOutputs(
+ const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_arrays->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
+ }
+ } else {
+ output_arrays->push_back(GetIrArray(hlo, hlo));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_arrays->reserve(num_params);
+ for (const HloInstruction* param : hlo.operands()) {
+ param_arrays->push_back(GetIrArray(*param, hlo));
+ }
+ return num_params;
+}
+
+int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<IrArray>* output_in_reduced_shape_arrays) {
+ int64 num_outputs = 1;
+ if (hlo.IsMultiOutputFusion()) {
+ num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
+ output_in_reduced_shape_arrays->reserve(num_outputs);
+ output_reduced_shapes->reserve(num_outputs);
+ for (int64 i = 0; i < num_outputs; ++i) {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ ShapeUtil::GetSubshape(hlo.shape(), {i}).element_type(),
+ reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[i].CastToShape(
+ (*output_reduced_shapes)[i], &ir_builder_));
+ }
+ } else {
+ output_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ hlo.shape().element_type(), reduced_output_dims));
+ output_in_reduced_shape_arrays->push_back(output_arrays[0].CastToShape(
+ (*output_reduced_shapes)[0], &ir_builder_));
+ }
+ return num_outputs;
+}
+
+int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<IrArray>* param_in_reduced_shape_arrays) {
+ int64 num_params = hlo.operands().size();
+ param_in_reduced_shape_arrays->reserve(num_params);
+ param_reduced_shapes->reserve(num_params);
+ for (int64 id = 0; id < num_params; ++id) {
+ if (param_buffers[id] == nullptr) {
+ param_reduced_shapes->push_back(Shape());
+ param_in_reduced_shape_arrays->push_back(IrArray());
+ continue;
+ }
+ const HloInstruction* param = hlo.operand(id);
+ param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
+ param->shape().element_type(),
+ Permute({0, 2, 1}, reduced_output_dims)));
+ param_in_reduced_shape_arrays->push_back(param_arrays[id].CastToShape(
+ (*param_reduced_shapes)[id], &ir_builder_));
+ }
+ return num_params;
+}
+
+namespace {
+
+// Reads thread_idx.x and converts it to a (y,x) coordinate, assuming that the
+// thread lives within a square tile of size tile_size (so thread blocks are of
+// size tile_size * tile_size).
+std::tuple<llvm::Value*, llvm::Value*> CalculateYXCoordinateWithinTile(
+ llvm::IRBuilder<>* builder, llvm::Value* tile_size,
+ int64 threads_per_tile) {
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, threads_per_tile,
+ llvm::cast<llvm::Instruction>(thread_id));
+ thread_id = builder->CreateIntCast(thread_id, tile_size->getType(),
+ /*isSigned=*/true, "thread.id.x");
+ auto x = builder->CreateURem(thread_id, tile_size);
+ auto y = builder->CreateUDiv(thread_id, tile_size);
+ return std::make_tuple(y, x);
+}
+
+// Reads block_idx.x, casts it to type index_ty, and adds the assumption that
+// it's in the range [0, num_blocks].
+llvm::Value* GetBlockIdx(llvm::IRBuilder<>* builder, llvm::Type* index_ty,
+ int64 num_blocks) {
+ llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, num_blocks,
+ llvm::cast<llvm::Instruction>(block_id));
+ return builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true,
+ "block.id.x");
+}
+
+// Emits code to process up to (tile_size/num_rows) elements in a tile, given
+// `emit_elem_function` is the function to emit code to process one element, `y`
+// and `x` are the coordinates for the first element to process, and `index` is
+// the index for the origin of the tile. Emits bounds check to ensure that each
+// processed element is within the boundary defined by `tile_width` and
+// `tile_height`.
+void EmitTiledElementalCodeWithBoundsCheck(
+ int64 tile_size, int64 num_rows, const IrArray::Index& index,
+ const string& loop_name, KernelSupportLibrary* ksl,
+ llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ llvm::Type* index_ty = tile_width->getType();
+ // Emits a constant value with index type.
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = builder->CreateAdd(index[dim], addend);
+ return index;
+ };
+
+ auto emit_full_tile = [&] {
+ for (int64 i = 0; i < tile_size; i += num_rows) {
+ auto source_idx = offset_dim(index, index_typed_constant(i), /*dim=*/1);
+ auto y_loc = builder->CreateAdd(index_typed_constant(i), y);
+ emit_elem_function(source_idx, y_loc);
+ }
+ };
+
+ auto emit_last_row = [&] {
+ ksl->IfReturnVoid("x_in_tile", builder->CreateICmpULT(x, tile_width), [&] {
+ // tile_height_upper_bound =
+ // ceil(tile_height / num_rows) * num_rows
+ auto tile_height_upper_bound = builder->CreateMul(
+ builder->CreateUDiv(
+ builder->CreateAdd(tile_height,
+ index_typed_constant(num_rows - 1)),
+ index_typed_constant(num_rows)),
+ index_typed_constant(num_rows));
+ ksl->ForReturnVoid(
+ loop_name, /*start=*/index_typed_constant(0),
+ /*end=*/tile_height_upper_bound,
+ /*step=*/index_typed_constant(num_rows), [&](llvm::Value* y_indvar) {
+ auto y_loc = builder->CreateAdd(y_indvar, y);
+ ksl->IfReturnVoid(
+ "y_in_tile", builder->CreateICmpULT(y_loc, tile_height), [&] {
+ emit_elem_function(offset_dim(index, y_indvar, /*dim=*/1),
+ y_loc);
+ });
+ });
+ });
+ };
+ ksl->IfReturnVoid(
+ "full_tile",
+ builder->CreateAnd(
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_width),
+ builder->CreateICmpEQ(index_typed_constant(tile_size), tile_height)),
+ emit_full_tile, emit_last_row);
+}
+} // namespace
+
+// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
+// algorithm to improve the memory access patterns for the input parameters
+// which have a shape that is a 0-2-1 transpose of the output tensors.
+//
+// For the purpose of tiling, the output tensors have a logical shape of three
+// components 0-2-1 while the relevant input parameters have a logical shape of
+// three components 0-1-2 in the order major to minor. The x- and y- dimensions
+// of the tensors are tiled in square tiles of edge length `kTileSize`. Each
+// thread block of `kTileSize` x `kNumRows` threads transposes one tile: each
+// thread copies kTileSize/kNumRows elements from the input to a shared memory
+// tile, then the otherwise "regular hlo kernel" reads from the shared memory
+// instead of the original input.
+//
+// This is similar to the following CUDA algorithm in TensorFlow:
+// https://goo.gl/MStRV6.
+//
+// `kTileSize` should usually be same as warp size. We currently choose 32 for
+// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
+// to launch fewer blocks so each transposes many tiles.
+LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
+ HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids) {
+ // Parameters for the tiling algorithm.
+ constexpr int64 kTileSize = 32;
+ constexpr int64 kNumRows = 4;
+ constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
+
+ // Construct IrArrays for the inputs and outputs.
+ std::vector<IrArray> output_arrays;
+ int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
+ std::vector<IrArray> param_arrays;
+ int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+
+ // Allocate shared memory buffers to store the tiled inputs.
+ std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
+ for (int64 id : tiled_param_ids) {
+ const HloInstruction* param = hlo->operand(id);
+ // Add 1 to the minor dimension to reduce shared memory bank conflicts.
+ llvm::Type* tile_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(llvm_ir::PrimitiveTypeToIrType(
+ param->shape().element_type(), module_),
+ kTileSize + 1),
+ kTileSize);
+ const int kNVPTXSharedMemoryAddrSpace = 3;
+ auto* tile_base_ptr = new llvm::GlobalVariable(
+ *ir_builder_.GetInsertBlock()->getParent()->getParent(), tile_type,
+ /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+ llvm::UndefValue::get(tile_type),
+ llvm_ir::AsStringRef(IrName(hlo, StrCat("tile", id))), nullptr,
+ llvm::GlobalValue::NotThreadLocal, kNVPTXSharedMemoryAddrSpace);
+ param_shmem_buffers[id] = tile_base_ptr;
+ VLOG(3) << "Added shmem buffer for parameter " << id << ": "
+ << llvm_ir::DumpToString(*tile_base_ptr);
+ }
+
+ // The 0-2-1 shape of the tiling scheme is the reduced shape of the HLO result
+ // for the purpose of tiling. Calculate the logical output dimensions in the
+ // tile from the reduced output dimensions.
+ std::vector<int64> output_dims_in_tiles = std::vector<int64>(
+ reduced_output_dims.begin(), reduced_output_dims.end());
+ CHECK_EQ(output_dims_in_tiles.size(), 3);
+ for (int i = 1; i < 3; ++i) {
+ output_dims_in_tiles[i] =
+ CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize);
+ }
+ const int64 num_tiles =
+ c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>());
+ LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile);
+
+ llvm::Type* index_ty = GetIndexTypeForKernel(
+ hlo, launch_dimensions.launch_bound(), &ir_builder_);
+ auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
+ return llvm::ConstantInt::get(index_ty, c);
+ };
+
+ // Cast each output IrArray to its corresponding reduced shape and keep the
+ // reduced shape live during IR emission.
+ std::vector<IrArray> output_in_reduced_shape_arrays;
+ std::vector<Shape> output_reduced_shapes;
+ CHECK_EQ(ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ *hlo, output_arrays, reduced_output_dims, &output_reduced_shapes,
+ &output_in_reduced_shape_arrays),
+ num_outputs);
+
+ // For each tiled parameter, cast its input IrArray to the corresponding
+ // reduced shape and keep the reduced shape live during IR emission.
+ std::vector<IrArray> param_in_reduced_shape_arrays;
+ std::vector<Shape> param_reduced_shapes;
+ CHECK_EQ(ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ *hlo, param_arrays, param_shmem_buffers, reduced_output_dims,
+ &param_reduced_shapes, &param_in_reduced_shape_arrays),
+ num_params);
+
+ // Calculate the starting element coordinate within a tile for the current
+ // thread, (y, x) from thread_id.
+ llvm::Value* x;
+ llvm::Value* y;
+ std::tie(y, x) = CalculateYXCoordinateWithinTile(
+ &ir_builder_, index_typed_constant(kTileSize), kThreadsPerTile);
+
+ // Calculate the index for the current output tile from block_id.
+ const IrArray::Index output_tile_index(
+ GetBlockIdx(&ir_builder_, index_ty, num_tiles),
+ ShapeUtil::MakeShapeWithDescendingLayout(PRED /*arbitrary*/,
+ output_dims_in_tiles),
+ &ir_builder_);
+
+ // Output tile origin is the index for the first element of the current output
+ // tile.
+ const IrArray::Index output_tile_origin = [&] {
+ IrArray::Index index = output_tile_index;
+ for (int i = 1; i < 3; ++i) {
+ index[i] = ir_builder_.CreateMul(output_tile_index[i],
+ index_typed_constant(kTileSize),
+ "tile_origin." + std::to_string(i));
+ }
+ return index;
+ }();
+
+ // Calculate the input tile origin from the output tile origin.
+ const IrArray::Index input_tile_origin(
+ Permute({0, 2, 1}, output_tile_origin.multidim()));
+
+ // Calculate the current output tile bounds in each of the logical dimensions.
+ std::vector<llvm::Value*> output_tile_bounds(3);
+ for (int i = 1; i < 3; ++i) {
+ // Only last row or column may not have full size.
+ output_tile_bounds[i] = ir_builder_.CreateSelect(
+ ir_builder_.CreateICmpEQ(
+ output_tile_index[i],
+ index_typed_constant(output_dims_in_tiles[i] - 1)),
+ index_typed_constant(reduced_output_dims[i] -
+ (output_dims_in_tiles[i] - 1) * kTileSize),
+ index_typed_constant(kTileSize), "kTileSize");
+ }
+
+ KernelSupportLibrary ksl(&ir_builder_, llvm_ir::UnrollMode::kDefaultUnroll);
+
+ // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
+ auto emit_tiled_elemental_code_with_bounds_check =
+ [&](const IrArray::Index& index, const string& loop_name,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const std::function<void(const IrArray::Index&, llvm::Value*)>&
+ emit_elem_function) {
+ EmitTiledElementalCodeWithBoundsCheck(
+ kTileSize, kNumRows, index, loop_name, &ksl, &ir_builder_, y, x,
+ tile_width, tile_height, emit_elem_function);
+ };
+
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) {
+ index[dim] = ir_builder_.CreateAdd(index[dim], addend);
+ return index;
+ };
+ const IrArray::Index input_index =
+ offset_dim(offset_dim(input_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+ // Copy input parameter values to shared memory buffers:
+ // tile[y, x] = input[index]
+ emit_tiled_elemental_code_with_bounds_check(
+ input_index, "input", output_tile_bounds[1], output_tile_bounds[2],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id];
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ // TODO(jlebar): Add AA metadata to this store. Tile buffers are
+ // global variables, so LLVM can't infer much about it.
+ ir_builder_.CreateStore(
+ input_in_logical_shape.EmitReadArrayElement(index, &ir_builder_,
+ "input_element"),
+ ir_builder_.CreateGEP(shmem_buffer,
+ {index_typed_constant(0), y_loc, x}));
+ }
+ });
+
+ // Wait for all threads to reach this point, lest we copy a value from tile to
+ // output before the other thread copies it from input to tile.
+ // This is `__syncthreads` in CUDA.
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {},
+ &ir_builder_);
+
+ llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
+
+ const IrArray::Index output_index =
+ offset_dim(offset_dim(output_tile_origin, x, /*dim=*/2), y, /*dim=*/1);
+
+ // Write to output[index] by emitting code like normal, except that values for
+ // the tiled parameters are read from the shmem buffers.
+ if (hlo->opcode() == HloOpcode::kCopy) {
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ // TODO(jlebar): Add AA metadata to this load.
+ llvm::Instruction* load_from_shmem_buffer = ir_builder_.CreateLoad(
+ ir_builder_.CreateGEP(param_shmem_buffers[0],
+ {ir_builder_.getInt64(0), x, y_loc}),
+ "output_element");
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, load_from_shmem_buffer, &ir_builder_);
+ });
+ } else {
+ CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+ emit_tiled_elemental_code_with_bounds_check(
+ output_index, "output", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc) {
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_,
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(param_arrays, &elem_emitter);
+ tiled_param_info.set_y(y_loc);
+ fused_emitter.SetTiledParameterInfo(&tiled_param_info);
+ TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
+ IrArray::Index untiled_index = llvm_ir::GetUnreducedOutputIndex(
+ index, output_reduced_shapes[0], output_arrays[0].GetShape(),
+ &ir_builder_);
+ const llvm_ir::ElementGenerator& output_generator =
+ fused_emitter.GetRootGenerator();
+ llvm::Value* output_value =
+ output_generator(untiled_index).ValueOrDie();
+ if (hlo->IsMultiOutputFusion()) {
+ CHECK(output_value->getType()->isStructTy());
+ CHECK_EQ(output_value->getType()->getStructNumElements(),
+ output_in_reduced_shape_arrays.size());
+ for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) {
+ output_in_reduced_shape_arrays[i].EmitWriteArrayElement(
+ index, ir_builder_.CreateExtractValue(output_value, i),
+ &ir_builder_);
+ }
+ } else {
+ output_in_reduced_shape_arrays[0].EmitWriteArrayElement(
+ index, output_value, &ir_builder_);
+ }
+ });
+ }
+
+ // For multioutput fusion, emit a tuple with all the individual outputs.
+ if (hlo->IsMultiOutputFusion()) {
+ std::vector<llvm::Value*> tuple_operand_ptrs;
+ for (int64 i = 0; i < output_arrays.size(); ++i) {
+ tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
+ }
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &ir_builder_,
+ module_);
+ }
+
+ return launch_dimensions;
+}
+
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+ HloOpcode opcode = hlo->opcode();
+ CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
+ CHECK(opcode != HloOpcode::kFusion ||
+ hlo->fusion_kind() == HloInstruction::FusionKind::kLoop)
+ << "Only loop fusions are supported.";
+
+ const Shape& output_shape = hlo->IsMultiOutputFusion()
+ ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+ : hlo->shape();
+
+ // If the output_shape is reduced to 021 shape, find all the parameters of the
+ // hlo that are in the corresponding 012 shape.
+ std::vector<int64> params_012;
+ optional<std::vector<int64>> reduced_dims_021;
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ HloInstruction* operand = hlo->mutable_operand(operand_idx);
+ auto find_transpose_result =
+ llvm_ir::FindTranspose021(operand->shape(), output_shape);
+ if (!find_transpose_result.has_value()) {
+ continue;
+ }
+ const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
+ if (!reduced_dims_021.has_value()) {
+ reduced_dims_021 = curr_reduced_dims_021;
+ }
+ if (!ContainersEqual(*reduced_dims_021, curr_reduced_dims_021)) {
+ // There is more than one possible transpose. Instead of picking one
+ // transpose, we simply give up here.
+ return false;
+ }
+ params_012.push_back(operand_idx);
+ }
+
+ if (!reduced_dims_021.has_value()) {
+ return false;
+ }
+
+ if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
+ (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
+ return false;
+ }
+
+ VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
+ thunk_sequence_->emplace_back(
+ BuildKernelThunk(hlo, /*implements_whole_instruction=*/true));
+ const LaunchDimensions launch_dimensions =
+ EmitHlo021Tile(hlo, *reduced_dims_021, params_012);
+ UpdateLaunchDimensions(launch_dimensions, LastThunk(),
+ ir_emitter_context_->llvm_module());
+
+ return true;
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index b41eaa303b..a1cc38401c 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
namespace xla {
namespace gpu {
@@ -67,16 +68,18 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleDot(HloInstruction* dot) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
- Status HandleGather(HloInstruction* gather) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleSelectAndScatter(HloInstruction* instruction) override;
Status HandleTuple(HloInstruction* tuple) override;
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleInfeed(HloInstruction* xla_infeed) override;
+ Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleAfterAll(HloInstruction* gen_token) override;
Status EmitTargetElementLoop(
const HloInstruction& hlo,
@@ -100,6 +103,13 @@ class IrEmitterUnnested : public IrEmitter {
const HloInstruction& inst,
tensorflow::gtl::ArraySlice<const BufferAllocation*> args);
+ // Helper for writing extra outputs from inside a reduce kernel.
+ Status EmitExtraOutputsForReduce(
+ const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
+
// EmitColumnReduction and EmitRowReduction emit code for column and row
// reduction of a matrix and/or 3D tensor. Row and column reduction have
// different memory access pattern, so for performance their implementations
@@ -115,7 +125,11 @@ class IrEmitterUnnested : public IrEmitter {
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Emits code that reduces a 3D tensor of shape [depth x height x width] to a
// vector of shape [height]. Other parameters have the same meaning as those
@@ -127,14 +141,22 @@ class IrEmitterUnnested : public IrEmitter {
const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Emits code that reduces a tensor of arbitrary rank to a scalar.
Status EmitReductionToScalar(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
// Figures out whether `reduce` is a row or column reduction, and which
// dimensions to reduce, and calls either `EmitRowReduction` or
@@ -147,20 +169,72 @@ class IrEmitterUnnested : public IrEmitter {
// Multiple reduces can be emitted in the same loop, assuming they have the
// same input and output shapes, and the same reduce dimensions.
//
+ // extra_output_gens can contain extra generators for intermediate outputs.
+ // These must have the same shape as the reduce input as they are computed
+ // when the reduce inputs are being read.
+ //
// Prerequisite: `IsReductionToVector(*reduce)`
Status EmitReductionToVector(
HloInstruction* reduce, const Shape& input_shape,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> input_gens,
tensorflow::gtl::ArraySlice<llvm_ir::ElementGenerator> init_value_gens,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
- tensorflow::gtl::ArraySlice<HloComputation*> reducers);
+ tensorflow::gtl::ArraySlice<HloComputation*> reducers,
+ tensorflow::gtl::ArraySlice<ShapeIndex> reduce_output_shapes,
+ tensorflow::gtl::ArraySlice<
+ std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens);
+
+ // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
+ // for the hlo instruction.
+ bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
+ // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
+ // returns the launch dimensions for the kernel. This is a helper to support
+ // the implementation of CheckAndEmitHloWithTile021.
+ LaunchDimensions EmitHlo021Tile(
+ HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ tensorflow::gtl::ArraySlice<int64> tiled_param_ids);
+ // Generates the IrArray for each output of hlo and returns the number of
+ // outputs.
+ int ConstructIrArrayForOutputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* output_arrays);
+ // Generates the IrArray for each input of hlo and returns the number of
+ // inputs.
+ int ConstructIrArrayForInputs(const HloInstruction& hlo,
+ std::vector<llvm_ir::IrArray>* param_arrays);
+ // For each output of the `hlo` instruction, constructs the reduced shape for
+ // the output with the given `reduced_output_dims` and cast the original
+ // output IrArray element in `output_arrays` to the reduced shape. Returns
+ // the number of outputs.
+ int ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& output_arrays,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* output_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* output_in_reduced_shape_arrays);
+ // For each input of the `hlo` instruction, checks its value in
+ // `param_buffers` to find out whether the input has a reduced shape. If the
+ // input has a reduced shape, constructs the reduced shape for the input and
+ // casts the original input IrArray in `param_arrays` to the reduced shape.
+ // Return the total number of inputs.
+ int ConstructInputReducedShapeAndCastInputIrArrayToShape(
+ const HloInstruction& hlo,
+ const std::vector<llvm_ir::IrArray>& param_arrays,
+ const std::vector<llvm::Value*>& param_buffers,
+ tensorflow::gtl::ArraySlice<int64> reduced_output_dims,
+ std::vector<Shape>* param_reduced_shapes,
+ std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
// Thunk object. The kernel implementation will be unrolled if unroll_factor
- // is greater than one.
- std::unique_ptr<KernelThunk> BuildKernelThunk(const HloInstruction* inst,
- int unroll_factor = 1);
+ // is greater than one. 'implements_whole_instruction' specifies whether this
+ // KernelThunk implements the whole 'inst' HloInstruction. In some cases
+ // 'inst' will be implemented by a sequence of Thunks.
+ std::unique_ptr<KernelThunk> BuildKernelThunk(
+ const HloInstruction* inst, bool implements_whole_instruction,
+ int unroll_factor = 1);
// Returns a FftThunk that calls cuFFT to implement `inst`.
std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
@@ -181,10 +255,14 @@ class IrEmitterUnnested : public IrEmitter {
std::unique_ptr<Thunk> BuildDeviceToDeviceCopyThunk(
const HloInstruction* inst);
- // Returns an InfeedThunk that performs device-to-device memcpy to implement
+ // Returns an InfeedThunk that performs a host-to-device memcpy to implement
// `inst`.
std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+ // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
+ // `inst`.
+ std::unique_ptr<Thunk> BuildOutfeedThunk(const HloInstruction* inst);
+
// Returns a WhileThunk that invokes thunk sequences for 'condition' and
// 'body' sub-computations of while instruction 'hlo'.
std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
index f56c1ce69f..e76823ad10 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -75,7 +76,8 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
}
Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
// Load the kernel.
se::StreamExecutor* executor = stream->parent();
LaunchDimensions launch_dimensions;
@@ -100,6 +102,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " ("
<< buf.size() << "B)";
}
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (!stream->parent()->Launch(
stream, se::ThreadDim(launch_dimensions.threads_per_block()),
se::BlockDim(launch_dimensions.block_count()), *kernel,
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
index 7def27e189..d751de50ad 100644
--- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -62,7 +63,8 @@ class KernelThunk : public Thunk {
// Executes the kernel for the thunk on "stream", which must be non-null.
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
// Buffers passed to the kernel as arguments.
diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
index d4100a898b..9fd6cf7157 100644
--- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc
@@ -14,21 +14,27 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace gpu {
Status MemzeroThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemZero(&dest_data, dest_data.size());
return Status::OK();
}
Status Memset32BitValueThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
stream->ThenMemset32(&dest_data, value_, dest_data.size());
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h
index 51c332d287..d1fec0bd76 100644
--- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MEMSET_THUNK_H_
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status.h"
@@ -36,7 +37,8 @@ class MemzeroThunk : public Thunk {
: Thunk(Kind::kMemzero, hlo), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice dest_;
@@ -52,7 +54,8 @@ class Memset32BitValueThunk : public Thunk {
: Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {}
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
uint32 value_;
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
new file mode 100644
index 0000000000..ea661b3c2c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -0,0 +1,263 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+
+#include <stdint.h>
+#include <algorithm>
+#include <iterator>
+#include <list>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {}
+
+bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ auto get_element_instr =
+ [&](const HloInstruction* instr) -> const HloInstruction* {
+ const HloInstruction* element_instr = instr;
+ if (instr->opcode() == HloOpcode::kFusion) {
+ auto fused_expression_root = instr->fused_expression_root();
+ if (instr->IsMultiOutputFusion()) {
+ // If possible, we want to pick a reduce operand of the fusion root,
+ // because it has the most constraints.
+ for (const auto* inst : fused_expression_root->operands()) {
+ if (inst->opcode() == HloOpcode::kReduce) {
+ return inst;
+ }
+ }
+ return fused_expression_root->operands()[0];
+ } else {
+ element_instr = fused_expression_root;
+ }
+ }
+ return element_instr;
+ };
+
+ auto get_element_shape = [&](const HloInstruction* element_instr) {
+ // Special handling of kReduce instructions -- the fusion
+ // applies to the first operand.
+ if (element_instr->opcode() == HloOpcode::kReduce) {
+ return element_instr->operand(0)->shape();
+ }
+ return element_instr->shape();
+ };
+
+ // The shapes in all tuple operands should agree, unless it is a reduce.
+ // In that case, the operand of the reduce needs to have the same shape
+ // as the other tuple operands, but also we need to compare the output
+ // shapes of the reduces.
+ // TODO(tjoerg): Allow differences in fp precision.
+ auto* element_instr_1 = get_element_instr(instr1);
+ auto* element_instr_2 = get_element_instr(instr2);
+ if (element_instr_1->opcode() == HloOpcode::kReduce &&
+ element_instr_2->opcode() == HloOpcode::kReduce &&
+ !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) {
+ return false;
+ }
+ // The elementwise output shapes must be the same (including layout).
+ return ShapeUtil::Equal(get_element_shape(element_instr_1),
+ get_element_shape(element_instr_2));
+}
+
+namespace {
+bool IsInputFusibleReduction(HloInstruction* instr) {
+ if (instr->IsMultiOutputFusion()) {
+ for (const HloInstruction* operand :
+ instr->fused_expression_root()->operands()) {
+ if (operand->opcode() == HloOpcode::kReduce) {
+ CHECK(instr->fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Reduce multi-output fusion " << instr->ToString()
+ << " must be an input fusion.";
+ return true;
+ }
+ }
+ return false;
+ } else if (instr->opcode() == HloOpcode::kFusion) {
+ // The loop emitter can handle to-vector reduce fusions. Such reduce
+ // fusions have the fusion kind kLoop rather than kInput. We do not fuse
+ // to-vector reduce fusions, because the resulting fusions may no longer be
+ // supported by loop emitter.
+ return IsReductionToVector(*instr->fused_expression_root());
+ } else {
+ return IsReductionToVector(*instr);
+ }
+}
+} // namespace
+
+bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
+ // We can fuse reduces and loop fusions.
+ return IsInputFusibleReduction(instr) ||
+ (instr->opcode() == HloOpcode::kFusion &&
+ instr->fusion_kind() == HloInstruction::FusionKind::kLoop);
+}
+
+int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ for (auto instr : instr1->operands()) {
+ if (!IsProfitableOperand(instr)) {
+ continue;
+ }
+ in_list.insert(instr);
+ }
+ int64 profit = 0;
+ for (auto instr : instr2->operands()) {
+ if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) {
+ continue;
+ }
+ profit += ShapeUtil::ByteSizeOf(instr->shape());
+ }
+ VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name()
+ << ", the profit is =" << profit;
+ return profit;
+}
+
+bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) {
+ return false;
+ }
+ // If we're fusing fusions only do it if the fusion kind matches. Loop fusions
+ // merge into bigger loop fusions and input (reduce) fusions become fusions
+ // with multiple reduce outputs. We could fuse reduce and loop fusions
+ // together too (the result being an input fusion) if we find cases where this
+ // improves things.
+ CHECK(instr1->opcode() == HloOpcode::kFusion);
+ if (instr2->opcode() == HloOpcode::kFusion) {
+ return instr1->fusion_kind() == instr2->fusion_kind();
+ }
+ return instr1->fusion_kind() != HloInstruction::FusionKind::kLoop;
+}
+
+bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
+ bool changed = false;
+ RecomputeReachability();
+
+ tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ // Keep a list of the instructions to fuse after making all the fusion
+ // decisions. We first aggressively add instructions to potential_fusion_list,
+ // then filter out instructions that will be no longer fusable because of
+ // reachability change. This avoids recalculating reachability on a large set
+ // of instructions.
+ std::vector<std::pair<HloInstruction*, HloInstruction*>>
+ potential_fusion_list;
+ std::vector<std::pair<HloInstruction*, HloInstruction*>> fusion_list;
+ std::vector<HloInstruction*> instrs_to_update_reachability;
+
+ // For each reduce or reduce multi-output fusion, try to fuse it with loop
+ // fusions operands.
+ for (HloInstruction* consumer : computation()->MakeInstructionPostOrder()) {
+ if (consumer->user_count() == 0) {
+ continue;
+ }
+ if (!IsInputFusibleReduction(consumer)) {
+ continue;
+ }
+
+ auto consumer_operands = consumer->operands();
+ for (size_t i = 0; i < consumer_operands.size(); ++i) {
+ HloInstruction* producer = consumer_operands[i];
+ if (!producer->IsFusable()) {
+ continue;
+ }
+ const bool is_loop_fusion =
+ producer->opcode() == HloOpcode::kFusion &&
+ producer->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ if (!is_loop_fusion) {
+ continue;
+ }
+ if (!ShapesCompatibleForFusion(producer, consumer)) {
+ continue;
+ }
+ // If we have already decided to fuse this producer, skip it.
+ if (ContainsKey(to_fuse, producer)) {
+ continue;
+ }
+ // Do not fuse a producer if the other operands of the fusion are
+ // reachable from the producer, this would create a cycle.
+ if (c_any_of(consumer_operands, [&](HloInstruction* operand) {
+ return producer != operand &&
+ reachability()->IsReachable(producer, operand);
+ })) {
+ break;
+ }
+ to_fuse.insert(producer);
+ potential_fusion_list.emplace_back(producer, consumer);
+ instrs_to_update_reachability.push_back(producer);
+ instrs_to_update_reachability.push_back(consumer);
+ break;
+ }
+ }
+
+ // Filter out pairs that will be no longer fusable because of reachability
+ // change.
+ for (auto& fusion_pair : potential_fusion_list) {
+ HloInstruction* producer = fusion_pair.first;
+ HloInstruction* consumer = fusion_pair.second;
+ if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) {
+ return producer != operand &&
+ reachability()->IsReachable(producer, operand);
+ })) {
+ UpdateReachability(producer, consumer, instrs_to_update_reachability);
+ fusion_list.push_back(fusion_pair);
+ }
+ }
+
+ for (auto fusions_to_create : fusion_list) {
+ HloInstruction* producer = fusions_to_create.first;
+ HloInstruction* consumer = fusions_to_create.second;
+ if (consumer->opcode() != HloOpcode::kFusion) {
+ // Fusing with a reduce (fusion) always results in an input fusion.
+ HloInstruction* input_fusion =
+ computation()->AddInstruction(HloInstruction::CreateFusion(
+ consumer->shape(), HloInstruction::FusionKind::kInput, consumer));
+ VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
+ << consumer->name() << " into " << input_fusion->name();
+ TF_CHECK_OK(computation()->ReplaceInstruction(consumer, input_fusion));
+ if (producer->opcode() == HloOpcode::kFusion) {
+ input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ input_fusion->FuseInstructionIntoMultiOutput(producer);
+ }
+ } else {
+ VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
+ << consumer->name();
+
+ if (producer->opcode() == HloOpcode::kFusion) {
+ consumer->MergeFusionInstructionIntoMultiOutput(producer);
+ } else {
+ consumer->FuseInstructionIntoMultiOutput(producer);
+ }
+ }
+ changed = true;
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
new file mode 100644
index 0000000000..67ca5d49ee
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
+
+#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+
+namespace xla {
+namespace gpu {
+
+// Multi-output fusion of sibling and producer-consumer instructions for the
+// Jellyfish backend.
+class GpuMultiOutputFusion : public MultiOutputFusion {
+ public:
+ GpuMultiOutputFusion();
+
+ protected:
+ // Test if instr1 and instr2 have the compatible shapes that can be legally
+ // fused.
+ bool ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) override;
+
+ // We currently only consider reduce and reduce fusion nodes as candidates.
+ bool IsFusible(HloInstruction* instr) override;
+
+ // This function estimates the amount of memory reads saved by merging
+ // instr1 and instr2 into one multi-output fusion instruction. For a fusion
+ // instruction, all the operands need to be loaded from memory. If we merge
+ // instr1 and instr2, common operands will not be loaded twice. The profit is
+ // estimated as the size of the common operands b/w instr1 and instr2.
+ int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override;
+
+ // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
+ bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) override;
+
+ // Fuse loop fusions into reduce fusions.
+ bool DoProducerConsumerMultiOutputFusion() override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
new file mode 100644
index 0000000000..979ea79243
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc
@@ -0,0 +1,353 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace op = xla::testing::opcode_matchers;
+
+namespace xla {
+namespace gpu {
+
+using InstructionFusionTest = HloTestBase;
+
+const char kModulePrefix[] = R"(
+ HloModule test_module
+
+ scalar_add_computation {
+ scalar_lhs.0 = f32[] parameter(0)
+ scalar_rhs.0 = f32[] parameter(1)
+ ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0)
+ }
+ scalar_mul_computation {
+ scalar_lhs.1 = f32[] parameter(0)
+ scalar_rhs.1 = f32[] parameter(1)
+ ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1)
+ })";
+
+TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
+ // Fusion with reduce instruction root and a sibling reduce instruction
+ // sharing the same input param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] constant(1)
+ fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation
+ reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce()));
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[6400]{0} parameter(1)
+ mul = f32[6400]{0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[6400]{0} parameter(1)
+ r1 = f32[64,100]{0,1} reshape(p1.2)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[6400]{0} parameter(1)
+ fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[10,10]{1,0} parameter(1)
+ mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[10,10]{1,0} parameter(1)
+ const.2 = f32[10]{0} parameter(0)
+ ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1.3 = f32[10,10]{1,0} parameter(1)
+ fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1
+ p2 = f32[] parameter(2)
+ fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
+ // Two sibling fusions with reduce instruction roots sharing the same input
+ // param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ const.2 = f32[] parameter(0)
+ ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[128,512,28,28]{3,2,1,0} parameter(1)
+ fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce()));
+}
+
+TEST_F(InstructionFusionTest,
+ MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) {
+ // Multi-output fusion with two reduce instructions root and a sibling reduce
+ // instruction sharing the same input param.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) {
+ const.1 = f32[] constant(1)
+ p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1)
+ reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2)
+ }
+
+ ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) {
+ p0 = f32[128,512,28,28]{3,2,1,0} parameter(0)
+ const = f32[] constant(1)
+ fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation
+ get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0
+ get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1
+ reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation
+ ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Reduce()));
+}
+
+TEST_F(InstructionFusionTest,
+ MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) {
+ // Verify that if we already have a multi-output fusion that we prefer to pick
+ // a reduce op from its operands for checking shape compatibility.
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p1.1 = f32[10,10]{1,0} parameter(1)
+ mul = f32[10,10]{1,0} multiply(p1.1, p1.1)
+ const.1 = f32[] parameter(0)
+ reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1)
+ }
+
+ fused_computation_2 {
+ p1.2 = f32[10,10]{1,0} parameter(1)
+ const.2 = f32[10] parameter(0)
+ ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation
+ }
+
+ ENTRY entry {
+ p0 = f32[] parameter(0)
+ p1 = f32[10,10]{1,0} parameter(1)
+ p2 = f32[10]{0} parameter(2)
+ fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1
+ get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0
+ get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1
+ fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2
+ ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, MultiOutputFusionTwoLoops) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_computation_1 {
+ p0.1 = f32[6400]{0} parameter(0)
+ ROOT mul = f32[6400]{0} multiply(p0.1, p0.1)
+ }
+
+ fused_computation_2 {
+ p0.2 = f32[6400]{0} parameter(0)
+ const.2 = f32[] constant(1)
+ ROOT div = f32[6400]{0} divide(p0.2, const.2)
+ }
+
+ ENTRY entry {
+ p0 = f32[6400]{0} parameter(0)
+ fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1
+ fusion.2 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_2
+ ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, fusion.2)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* fusion =
+ module->entry_computation()->root_instruction()->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Multiply(), op::Divide()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduce) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_add {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT add = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ add = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_add
+ reduce = f32[2,2]{1,0} reduce(add, c0), dimensions={2}, to_apply=scalar_add_computation
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(reduce, add)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Add()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_select {
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ c0 = f32[] constant(0)
+ broadcast = f32[2,2,2]{2,1,0} broadcast(f32[] c0), dimensions={}
+ greater-than = pred[2,2,2]{2,1,0} greater-than(f32[2,2,2]{2,1,0} p1.1, f32[2,2,2]{2,1,0} broadcast)
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT select = f32[2,2,2]{2,1,0} select(pred[2,2,2]{2,1,0} greater-than, f32[2,2,2]{2,1,0} p0.1, f32[2,2,2]{2,1,0} broadcast)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ c1 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0.2, c1), dimensions={2}, to_apply=scalar_add_computation
+ mul = f32[2,2,2]{2,1,0} multiply(p0.2, p0.2)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=scalar_add_computation
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ select = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_select
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(select), kind=kInput, calls=fused_reduce
+ gte0 = f32[2,2]{1,0} get-tuple-element(fusion), index=0
+ gte1 = f32[2,2]{1,0} get-tuple-element(fusion), index=1
+ ROOT root = (f32[2,2]{1,0}, f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(gte1, gte1, select)
+ })"))
+ .ValueOrDie();
+ ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+ SCOPED_TRACE(module->ToString());
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
+ op::GetTupleElement()));
+ const HloInstruction* fusion = root->operand(0)->operand(0);
+ ASSERT_TRUE(fusion->IsMultiOutputFusion());
+ EXPECT_THAT(fusion->fused_expression_root(),
+ op::Tuple(op::Reduce(), op::Reduce(), op::Select()));
+}
+
+TEST_F(InstructionFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) {
+ auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
+ fused_element_wise {
+ p0.1 = f32[2,2,2]{2,1,0} parameter(0)
+ p1.1 = f32[2,2,2]{2,1,0} parameter(1)
+ ROOT root = f32[2,2,2]{2,1,0} add(p0.1, p1.1)
+ }
+
+ fused_reduce {
+ p0.2 = f32[2,2,2]{2,1,0} parameter(0)
+ mul = f32[2,2,2]{2,1,0} multiply(f32[2,2,2]{2,1,0} p0.2, f32[2,2,2]{2,1,0} p0.2)
+ c1 = f32[] constant(0)
+ ROOT reduce = f32[2,2]{1,0} reduce(f32[2,2,2]{2,1,0} mul, f32[] c1), dimensions={1}, to_apply=scalar_add_computation
+ }
+
+ ENTRY reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ p1 = f32[2,2,2]{2,1,0} parameter(1)
+ element_wise = f32[2,2,2]{2,1,0} fusion(p0, p1), kind=kLoop, calls=fused_element_wise
+ fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(element_wise), kind=kLoop, calls=fused_reduce
+ ROOT root = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}) tuple(fusion, element_wise)
+ })"))
+ .ValueOrDie();
+ ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
new file mode 100644
index 0000000000..47744548b9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+void OutfeedManager::EnqueueOutfeedDestination(
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers) {
+ tensorflow::mutex_lock l(mu_);
+ enqueued_buffers_.push_back(buffers);
+ cv_.notify_one();
+}
+
+ShapeTree<std::unique_ptr<OutfeedBuffer>>*
+OutfeedManager::BlockingGetNextOutfeedDestination() {
+ tensorflow::mutex_lock l(mu_);
+ while (enqueued_buffers_.empty()) {
+ cv_.wait(l);
+ }
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>* current_buffer =
+ enqueued_buffers_.front();
+ enqueued_buffers_.pop_front();
+ return current_buffer;
+}
+
+OutfeedManager* GetOrCreateOutfeedManager() {
+ static auto* manager = new OutfeedManager;
+ return manager;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
new file mode 100644
index 0000000000..f580c24e17
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+
+namespace xla {
+namespace gpu {
+
+// TODO(b/30467474) Once GPU outfeed implementation settles, consider
+// folding back the cpu and gpu outfeed implementations into a generic
+// one if possible.
+
+// Defines a buffer holding the destination for an outfeed in host memory and a
+// notification when that triggers when the transfer is done.
+class OutfeedBuffer {
+ public:
+ OutfeedBuffer(int64 length) : length_(length) {}
+
+ // Waits for the device transfer to be finished.
+ std::unique_ptr<Literal> WaitUntilAvailable() {
+ done_.WaitForNotification();
+ return std::move(destination_);
+ }
+
+ int64 length() const { return length_; }
+ void set_destination(std::unique_ptr<Literal> destination) {
+ destination_ = std::move(destination);
+ }
+ Literal* destination() { return destination_.get(); }
+
+ // Callback to signal that this buffer is consumed.
+ void Done() { done_.Notify(); }
+
+ private:
+ std::unique_ptr<Literal> destination_;
+ const int64 length_;
+ tensorflow::Notification done_;
+};
+
+// Manages a thread-safe queue of buffers. The buffers are supposed to be
+// produced by the transfer manager and consumed by the device.
+class OutfeedManager {
+ public:
+ // Adds a tree of buffers to the queue. The individual buffers correspond to
+ // the elements of a tuple and may be nullptr if the buffer is a tuple index
+ // buffer.
+ void EnqueueOutfeedDestination(
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>* buffers);
+
+ // Blocks until the queue is non-empty, then returns the buffer at the head of
+ // the queue.
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>*
+ BlockingGetNextOutfeedDestination();
+
+ private:
+ tensorflow::mutex mu_;
+
+ // Condition variable that is signaled every time a buffer is enqueued.
+ tensorflow::condition_variable cv_;
+
+ // The queue of trees of buffers. OutfeedBuffer* queue contents are not owned.
+ std::deque<ShapeTree<std::unique_ptr<OutfeedBuffer>>*> enqueued_buffers_;
+};
+
+// Singleton creator-or-accessor: Returns the GPU outfeed manager.
+OutfeedManager* GetOrCreateOutfeedManager();
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
new file mode 100644
index 0000000000..4c0f1421e9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+OutfeedThunk::OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
+ const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kOutfeed, hlo_instruction),
+ outfeed_slices_(std::move(outfeed_slices)) {}
+
+Status OutfeedThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString();
+
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
+ OutfeedManager* outfeed_manager = GetOrCreateOutfeedManager();
+ ShapeTree<std::unique_ptr<OutfeedBuffer>>* outfeed_buffers =
+ outfeed_manager->BlockingGetNextOutfeedDestination();
+
+ // Nothing to be done for empty tuples.
+ if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) {
+ return Status::OK();
+ }
+ CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(),
+ outfeed_buffers->shape()));
+
+ TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus(
+ [&](const ShapeIndex& index, std::unique_ptr<OutfeedBuffer>* buffer) {
+ if (!*buffer) { // Tuple pointers.
+ return Status::OK();
+ }
+ // Allocate storage for the literal data.
+ const Shape& shape =
+ ShapeUtil::GetSubshape(outfeed_buffers->shape(), index);
+ (*buffer)->set_destination(Literal::CreateFromShape(shape));
+
+ BufferAllocation::Slice slice = outfeed_slices_.element(index);
+ se::DeviceMemoryBase data_address;
+ if (slice.allocation()) {
+ // If we have a static allocation, read it from there. This avoids
+ // synchronizing the host and device just to read a pointer.
+ data_address = buffer_allocations.GetDeviceAddress(slice);
+ } else {
+ // Otherwise we have to read the tuple pointer first.
+ CHECK(!index.empty());
+ // Copy the parent buffer to the host.
+ BufferAllocation::Slice tuple_slice =
+ outfeed_slices_.element(ShapeIndexView(index).ConsumeFront());
+ if (!tuple_slice.allocation()) {
+ return Unimplemented(
+ "Nested dynamic tuples are not supported on GPU");
+ }
+ se::DeviceMemoryBase tuple_address =
+ buffer_allocations.GetDeviceAddress(tuple_slice);
+ CHECK(tuple_slice.size() % sizeof(void*) == 0)
+ << "Tuple size must be a multiple of pointer size";
+ std::vector<void*> tuple_element_buffer_addresses(tuple_slice.size() /
+ sizeof(void*));
+ stream->ThenMemcpy(tuple_element_buffer_addresses.data(),
+ tuple_address, tuple_slice.size());
+ TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+ // The data address is specified by the element of the tuple pointer
+ // buffer.
+ data_address =
+ se::DeviceMemoryBase(tuple_element_buffer_addresses[index.back()],
+ (*buffer)->length());
+ }
+
+ // TODO(b/111309141): Run this on a separate stream so it doesn't block
+ // the GPU from doing work during the transfer. This could be handled by
+ // making StreamAssignment do something intelligent with outfeed thunks.
+ stream
+ ->ThenMemcpy((*buffer)->destination()->untyped_data(), data_address,
+ (*buffer)->length())
+ .ThenDoHostCallback([buffer]() { (*buffer)->Done(); });
+ return Status::OK();
+ }));
+
+ Status block_status = stream->BlockHostUntilDone();
+ if (!block_status.ok()) {
+ return InternalError("Failed to complete data transfer on stream %p: %s",
+ stream, block_status.error_message().c_str());
+ }
+
+ VLOG(2) << "Outfeeding from GPU complete";
+ return Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
new file mode 100644
index 0000000000..8ed89f05f0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A thunk that outfeeds data. Data must be already resident on the host. This
+// thunk performs a host to device copy from the buffer allocated for the
+// outfeed op to the host location.
+class OutfeedThunk : public Thunk {
+ public:
+ // Constructs a OutfeedThunk that copies data to the host-side
+ // outfeed queue from the buffers in the given shape tree.
+ OutfeedThunk(ShapeTree<BufferAllocation::Slice> outfeed_slices,
+ const HloInstruction* hlo_instruction);
+
+ OutfeedThunk(const OutfeedThunk&) = delete;
+ OutfeedThunk& operator=(const OutfeedThunk&) = delete;
+
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
+
+ private:
+ const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index c8f0d4185c..b22040eee1 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
@@ -68,7 +69,7 @@ HloInstruction* MaybePaddedAndSlicedInput(
PrimitiveType element_type = input->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
input = MakePadHlo(input, padding, padding_config).ValueOrDie();
}
@@ -125,7 +126,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window,
PrimitiveType element_type = kernel->shape().element_type();
HloInstruction* padding =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
return MakePadHlo(kernel, padding, padding_config).ValueOrDie();
}
} // namespace
@@ -234,9 +235,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// Create a new backward convolution replacing the old one.
HloComputation* computation = backward_conv->parent();
HloInstruction* output = backward_conv->mutable_operand(1);
- HloInstruction* padding =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(input->shape().element_type()))));
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(
+ LiteralUtil::Zero(input->shape().element_type()))));
HloInstruction* padded_input =
MakePadHlo(input, padding, input_padding_config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index d8c07dc311..cd833ec7bd 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -58,7 +58,7 @@ ParallelLoopEmitter::ParallelLoopEmitter(
std::vector<llvm_ir::IrArray::Index>
ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name) {
+ tensorflow::StringPiece loop_name, llvm::Type* index_type) {
// Emit the following code in LLVM IR:
// linear_index = blockIdx.x * blockDim.x + threadIdx.x;
// if (linear_index < num_elements) {
@@ -71,14 +71,13 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
//
// %nctaid.x is currently specified as 2147483647.
VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_;
+ CHECK_NE(index_type, nullptr);
std::vector<llvm_ir::IrArray::Index> array_indices;
-
llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
static_cast<llvm::Instruction*>(block_id));
- block_id =
- ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id");
+ block_id = ir_builder_->CreateZExtOrTrunc(block_id, index_type, "block_id");
// Per the PTX documentation:
// "It is guaranteed that [...] 0 <= %tid.x < %ntid.x"
@@ -88,13 +87,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(),
static_cast<llvm::Instruction*>(thread_id));
- thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(),
- "thread_id");
+ thread_id =
+ ir_builder_->CreateZExtOrTrunc(thread_id, index_type, "thread_id");
llvm::Value* linear_index_base = ir_builder_->CreateAdd(
ir_builder_->CreateMul(
block_id,
- ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "",
+ llvm::ConstantInt::get(index_type,
+ launch_dimensions_.threads_per_block()),
+ "",
/*HasNUW=*/true, /*HasNSW=*/true),
thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true);
@@ -110,21 +111,23 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
llvm::Intrinsic::assume,
{ir_builder_->CreateICmpULT(
linear_index_base,
- ir_builder_->getInt64(launch_dimensions_.threads_per_block() *
- launch_dimensions_.block_count()),
+ llvm::ConstantInt::get(index_type,
+ launch_dimensions_.threads_per_block() *
+ launch_dimensions_.block_count()),
"linear_index_in_range")},
{}, ir_builder_);
if (unroll_factor_ > 1) {
linear_index_base = ir_builder_->CreateMul(
- linear_index_base, ir_builder_->getInt64(unroll_factor_),
+ linear_index_base, llvm::ConstantInt::get(index_type, unroll_factor_),
"linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
}
array_indices.emplace_back(linear_index_base, shape_, ir_builder_);
for (int i = 1; i < unroll_factor_; ++i) {
llvm::Value* linear_index = ir_builder_->CreateAdd(
- linear_index_base, ir_builder_->getInt64(i), "linear_index",
+ linear_index_base, llvm::ConstantInt::get(index_type, i),
+ "linear_index",
/*HasNUW=*/true, /*HasNSW=*/true);
array_indices.emplace_back(linear_index, shape_, ir_builder_);
}
@@ -132,7 +135,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
auto if_in_bounds = llvm_ir::EmitIfThenElse(
ir_builder_->CreateICmpULT(
linear_index_base,
- ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
+ llvm::ConstantInt::get(index_type, ShapeUtil::ElementsIn(shape_))),
llvm_ir::IrName(loop_name, "in_bounds"), ir_builder_, false);
// Set exit_bb_ to the exit block of the if structure.
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index 25318b3bed..302e1bf1bc 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
~ParallelLoopEmitter() override = default;
std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name) override;
+ tensorflow::StringPiece loop_name, llvm::Type* index_type) override;
private:
// The thread and block dimension to parallelize the loop on.
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
index c125474edb..02471129e0 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
@@ -47,6 +47,7 @@ class LaunchDimensions {
int64 block_count() const { return block_count_; }
int64 threads_per_block() const { return threads_per_block_; }
+ int64 launch_bound() const { return block_count() * threads_per_block(); }
private:
int64 block_count_;
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
index 88cb10883e..84285be70a 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -33,9 +34,12 @@ Status SequentialThunk::Initialize(const GpuExecutable& executable,
}
Status SequentialThunk::ExecuteOnStream(
- const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ const BufferAllocations& buffer_allocations, se::Stream* stream,
+ HloExecutionProfiler* profiler) {
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
for (const auto& thunk : thunks_) {
- TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ TF_RETURN_IF_ERROR(
+ thunk->ExecuteOnStream(buffer_allocations, stream, profiler));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
index 135f79e413..3c4de1d1a6 100644
--- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -41,7 +42,8 @@ class SequentialThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
// The list of sub-thunks.
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
index 696fa7e019..6f4bb0580e 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase {
auto debug_options = GetDebugOptionsForTest();
debug_options.set_xla_gpu_disable_multi_streaming(false);
config.set_debug_options(debug_options);
- return MakeUnique<HloModule>("test_module", VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>("test_module", config);
}
// Pre-canned shapes.
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index 931c0bffab..99a1a0eae9 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -53,6 +54,7 @@ class Thunk {
kKernel,
kMemset32BitValue,
kMemzero,
+ kOutfeed,
kSequential,
kTuple,
kWhile,
@@ -94,11 +96,12 @@ class Thunk {
// Execute the kernel for the thunk on the given stream. This method must be
// called after Initialize and can be called multiple times over Thunk's
- // lifetime. Stream argument must be non-null.
+ // lifetime. 'stream' and 'profiler' must be non-null.
//
// Precondition: Initialize(stream->parent()) has been called.
virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) = 0;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) = 0;
private:
Kind kind_;
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
index 97cb04c38f..a10e40451c 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -15,13 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace gpu {
Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
std::vector<void*> tuple_element_buffer_addresses;
for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) {
tuple_element_buffer_addresses.push_back(
@@ -31,6 +33,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
buffer_allocations.GetDeviceAddress(dest_buffer_));
auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
if (!stream
->ThenMemcpy(&dest_buffer_address,
tuple_element_buffer_addresses.data(), host_size)
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
index 951f809b51..2d5735d6c4 100644
--- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -46,7 +47,8 @@ class TupleThunk : public Thunk {
TupleThunk& operator=(const TupleThunk&) = delete;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const std::vector<BufferAllocation::Slice> tuple_element_buffers_;
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
index 30b9640c4c..1315a4183a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -29,10 +30,14 @@ WhileThunk::WhileThunk(
const HloInstruction* hlo)
: Thunk(Kind::kWhile, hlo),
condition_result_buffer_index_(condition_result_buffer_index),
+ // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_
+ // and body_thunk_sequence_ constructors because these SequentialThunks
+ // are logically "part of" this WhileThunk, and shouldn't be profiled
+ // separately from it.
condition_thunk_sequence_(MakeUnique<SequentialThunk>(
- std::move(*condition_thunk_sequence), hlo)),
- body_thunk_sequence_(
- MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+ std::move(*condition_thunk_sequence), nullptr)),
+ body_thunk_sequence_(MakeUnique<SequentialThunk>(
+ std::move(*body_thunk_sequence), nullptr)) {}
Status WhileThunk::Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) {
@@ -43,14 +48,18 @@ Status WhileThunk::Initialize(const GpuExecutable& executable,
}
Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) {
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) {
se::DeviceMemoryBase condition_result_data =
buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
+ auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
while (true) {
// Invoke thunk sequence for while 'condition' computation.
- TF_RETURN_IF_ERROR(
- condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ profiler->StartHloComputation();
+ TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(
+ buffer_allocations, stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_condition());
// Copy the result of condition computation and break the loop if 'false'.
bool condition_result;
@@ -66,9 +75,14 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
break;
}
- // Invoke thunk sequence for while 'body' computation.
- TF_RETURN_IF_ERROR(
- body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ // We measure the time of one execution of the while body computation. The
+ // while body may be executed more than once, the last measurement "wins".
+ profiler->StartHloComputation();
+ // Invoke thunk sequence for while 'body' computation, and pass on
+ // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'.
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(buffer_allocations,
+ stream, profiler));
+ profiler->FinishHloComputation(hlo_instruction()->while_body());
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
index 22176685a9..9270f95ee6 100644
--- a/tensorflow/compiler/xla/service/gpu/while_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,7 +49,8 @@ class WhileThunk : public Thunk {
Status Initialize(const GpuExecutable& executable,
se::StreamExecutor* executor) override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
- se::Stream* stream) override;
+ se::Stream* stream,
+ HloExecutionProfiler* profiler) override;
private:
const BufferAllocation::Slice condition_result_buffer_index_;
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index ad55728c45..c5321df6c4 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase {
return InvalidArgument("Unexpected tuple index instruction : %s",
inst->name().c_str());
} else if (tag == "loop_increment") {
- // Parse the constant which represents the loop induction variable
- // increment value.
+ // ParseHloString the constant which represents the loop induction
+ // variable increment value.
TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
} else if (tag == "param0" &&
inst != computation_->parameter_instruction(0)) {
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 2f290f61bd..dbc8442ed2 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -42,7 +42,7 @@ class WhileTransformerTest : public HloTestBase {
const int64 tuple_index, const int64 limit) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(limit)));
auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
0, GetLoopStateShape(tuple_index), "loop_state"));
auto induction_variable =
@@ -65,8 +65,8 @@ class WhileTransformerTest : public HloTestBase {
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
induction_variable_shape_, loop_state, ind_var_tuple_index));
- auto inc = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(increment)));
+ auto inc = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(increment)));
auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
// Update data GTE(data_tuple_index).
@@ -89,10 +89,12 @@ class WhileTransformerTest : public HloTestBase {
const int64 ind_var_tuple_index,
const int64 ind_var_init) {
auto builder = HloComputation::Builder(TestName() + ".While");
- auto induction_var_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(ind_var_init)));
- auto data_init = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto induction_var_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(ind_var_init)));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
auto loop_state_init =
ind_var_tuple_index == 0
? builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index acf6611486..aa89567ee8 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -47,7 +48,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "x_value"));
auto half = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.5)));
builder.AddInstruction(HloInstruction::CreateBinary(
half->shape(), HloOpcode::kAdd, x_value, half));
return module->AddEmbeddedComputation(builder.Build());
@@ -122,7 +123,7 @@ std::unique_ptr<HloModule> MakeBigGraph() {
auto rng = builder.AddInstruction(
HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m}));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_computation = ScalarSumComputation(module.get());
builder.AddInstruction(
HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation));
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 06a5e0351b..4005fc0d11 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -27,6 +27,46 @@ using tensorflow::gtl::FlatMap;
using tensorflow::gtl::FlatSet;
/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function) {
+ if (module_sequence.empty()) {
+ return 0;
+ }
+
+ const HloModule* module = module_sequence.begin()->first->parent();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+
+ // The absolute minimum memory required for a given sequence of instructions
+ // is determined by the sequence of Alloc and Free calls on a simulated heap,
+ // ignoring fragmentation. We run the heap simulation on the whole module,
+ // rather than summing each computation, since it gives us a better lower
+ // bound, by minimizing the liveness of sub-computations.
+ TF_ASSIGN_OR_RETURN(
+ HeapSimulator::Result result,
+ HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+ module_sequence, *points_to_analysis, size_function));
+ return result.heap_size;
+}
+
+/*static*/
+StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
+ TF_ASSIGN_OR_RETURN(
+ HeapSimulator::Result result,
+ HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
+ sequence, points_to_analysis, size_function,
+ HeapSimulator::Options(), memory_by_computation));
+ return result.heap_size;
+}
+
+/*static*/
StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloModule& module,
const SequentialHloOrdering::HloModuleSequence& module_sequence,
@@ -46,9 +86,11 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
std::unique_ptr<HeapAlgorithm> algorithm, const HloComputation& computation,
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
- const BufferValue::SizeFunction& size_fn, const Options& options) {
+ const BufferValue::SizeFunction& size_fn, const Options& options,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
- /*module_sequence=*/nullptr);
+ /*module_sequence=*/nullptr, memory_by_computation);
TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence,
points_to_analysis));
return heap.Finish();
@@ -188,6 +230,9 @@ Status HeapSimulator::RunComputation(
//
// INVARIANT: Either Alloc or ShareBuffer will be called for each buffer
// that we should assign.
+
+ // Make sure each buffer get reused at most once.
+ FlatSet<const BufferValue*> reused_buffers;
for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
@@ -200,6 +245,9 @@ Status HeapSimulator::RunComputation(
bool shared = false;
if (options_.may_reuse_operand_buffers) {
for (const BufferValue* operand_buffer : operand_buffers_to_free) {
+ if (reused_buffers.count(operand_buffer) != 0) {
+ continue;
+ }
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
buffer->instruction()->opcode() != HloOpcode::kCopy &&
points_to_analysis.CanShareOperandBufferWithUser(
@@ -209,6 +257,7 @@ Status HeapSimulator::RunComputation(
<< operand_buffer->ToString();
ShareBuffer(buffer, operand_buffer, instruction);
shared = true;
+ reused_buffers.insert(operand_buffer);
break;
}
}
@@ -219,6 +268,12 @@ Status HeapSimulator::RunComputation(
Alloc(buffer, instruction);
}
}
+ // Account for the memory used by subcomputations when estimating the
+ // current heap size.
+ if (memory_by_computation_ != nullptr) {
+ algorithm_->AccountForSubcomputationMemory(instruction,
+ *memory_by_computation_);
+ }
// If the whole module is sequential, we can save memory by running the
// heap-simulation for sub-computations inline. E.g. the buffers for the
@@ -286,12 +341,15 @@ Status HeapSimulator::RunComputation(
HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence)
+ const SequentialHloOrdering::HloModuleSequence* module_sequence,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation)
: no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
size_fn_(size_fn),
options_(options),
- module_sequence_(module_sequence) {
+ module_sequence_(module_sequence),
+ memory_by_computation_(memory_by_computation) {
debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr);
}
@@ -460,6 +518,26 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) {
}
}
+void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {
+ // We only count the memory usage of the largest subcomputation, instead of
+ // adding them all, because subcomputations won't execute in parallel.
+ int64 max_subcomputation_bytes = 0;
+ for (const auto* c : instruction->called_computations()) {
+ auto it = memory_by_computation.find(c);
+ if (it != memory_by_computation.end()) {
+ int64 subcomputation_bytes = it->second;
+ if (subcomputation_bytes > max_subcomputation_bytes) {
+ max_subcomputation_bytes = subcomputation_bytes;
+ }
+ }
+ }
+ max_heap_size_ =
+ std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes);
+}
+
void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) {
current_heap_size_ -= size;
}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 8b2b43a37a..811a6042df 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -85,6 +85,23 @@ class HeapSimulator {
const BufferValueFlatSet* buffers_to_assign;
};
+ // Returns the minimum memory required to compute an HLO module where all
+ // computations have been scheduled (represented by the given
+ // module_sequence), assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForModule(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function);
+
+ // Returns the minimum memory required to compute the given computation,
+ // assuming no fragmentation.
+ static StatusOr<int64> MinimumMemoryForComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
+
// Run the heap simulation with the given algorithm, assuming the given
// module_sequence, which must contain a topologically-consistent total
// ordering of all instructions within each computation. The result is invalid
@@ -111,7 +128,9 @@ class HeapSimulator {
const std::vector<const HloInstruction*>& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
- const Options& options = Options());
+ const Options& options = Options(),
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
private:
// If 'module_sequence' is non-null, it is used to find kCall and kWhile
@@ -120,7 +139,9 @@ class HeapSimulator {
HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const SequentialHloOrdering::HloModuleSequence* module_sequence);
+ const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation = nullptr);
~HeapSimulator();
Status RunComputation(
@@ -144,7 +165,13 @@ class HeapSimulator {
const std::unique_ptr<HeapAlgorithm> algorithm_;
const BufferValue::SizeFunction size_fn_;
const Options options_;
+ // module_sequence_ is set by buffer assignment, and memory_by_computation_ is
+ // set by hlo scheduling. Then, in RunComputation, we check both in order to
+ // handle subcomputations. It would be good to unify the handling of
+ // subcomputations, but it's not clear how.
const SequentialHloOrdering::HloModuleSequence* module_sequence_;
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ memory_by_computation_;
// In addition to Alloc and Free, the heap simulator exposes a concept of
// buffer sharing. When ShareBuffer is called, instead of allocating new
@@ -189,6 +216,11 @@ class HeapAlgorithm {
// Alloc allocates a buffer of 'size' bytes.
virtual void Alloc(const BufferValue* buffer, int64 size) = 0;
+ virtual void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) {}
+
// Free de-allocates a previously allocated buffer.
virtual void Free(const BufferValue* buffer, int64 size) = 0;
@@ -207,7 +239,14 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
~NoFragmentationStatsHeap() override = default;
void Alloc(const BufferValue* buffer, int64 size) override;
+
+ void AccountForSubcomputationMemory(
+ const HloInstruction* instruction,
+ const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ memory_by_computation) override;
+
void Free(const BufferValue* buffer, int64 size) override;
+
Result Finish() override;
private:
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 6271652412..b41dc66fe9 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -34,6 +34,65 @@ limitations under the License.
namespace xla {
namespace {
+class MinimumMemoryForSequenceTest : public HloTestBase {};
+
+TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+ HloInstruction* cond_iter = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+ HloInstruction* cond_data = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+ // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
+ HloInstruction* cond_lt = cond_builder.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kLt, cond_iter, cond_data));
+ HloComputation* cond_computation =
+ module->AddEmbeddedComputation(cond_builder.Build());
+
+ auto body_builder = HloComputation::Builder("WhileBody");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+ HloComputation* body_computation =
+ module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ // Entry params: 8 bytes (4 bytes per param), TOTAL=8
+ HloInstruction* iter = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
+ HloInstruction* data = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
+ // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
+ // While: 8 bytes (4 bytes per element), TOTAL=32
+ // Both cond and body use a max of 24 bytes, TOTAL=56
+ HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+ tuple_shape, cond_computation, body_computation, tuple));
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
+ };
+
+ SequentialHloOrdering::HloModuleSequence module_sequence;
+ module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
+ cond_lt};
+ module_sequence[body_computation] = {body_param};
+ module_sequence[entry_computation] = {iter, data, tuple, while_op};
+ EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn)
+ .ValueOrDie());
+}
+
const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kFinish[] = "Finish";
@@ -139,6 +198,11 @@ class HeapSimulatorTracker {
.ConsumeValueOrDie();
}
+ int64 OffsetAt(const HloInstruction* instruction, const ShapeIndex& index) {
+ const BufferValue* buffer = BufferAt(instruction, index);
+ return result_.chunk_map.at(buffer).offset;
+ }
+
// Ensures the expected sequence of Alloc/Free/Finish calls was performed.
void ExpectCallSequence(const CallSequence& expected) const {
EXPECT_EQ(expected, actual_calls_);
@@ -150,10 +214,9 @@ class HeapSimulatorTracker {
const ShapeIndex& index_a,
const HloInstruction* instruction_b,
const ShapeIndex& index_b) {
- const BufferValue* a = BufferAt(instruction_a, index_a);
- const BufferValue* b = BufferAt(instruction_b, index_b);
- EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset)
- << *a << ", " << *b;
+ int64 offset_a = OffsetAt(instruction_a, index_a);
+ int64 offset_b = OffsetAt(instruction_b, index_b);
+ EXPECT_EQ(offset_a, offset_b);
}
private:
@@ -176,7 +239,7 @@ class HeapSimulatorTest : public HloTestBase {
TEST_F(HeapSimulatorTest, ScalarConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
// Constants aren't assigned. See b/32248867
HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0});
@@ -252,6 +315,43 @@ TEST_F(HeapSimulatorTest, MultiplyAdd) {
tracker.ExpectSharedBuffers(add, {}, mul, {});
}
+TEST_F(HeapSimulatorTest, BufferReusedOnce) {
+ HeapSimulatorTracker tracker(TestName());
+ auto builder = HloComputation::Builder(TestName());
+
+ HloComputation::Builder fusion_builder("fusion");
+ {
+ HloComputation::Builder& builder = fusion_builder;
+ auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32vec4_, "A"));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec4_, HloOpcode::kExp, a_param));
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({exp, neg}));
+ }
+ auto fusion_computation =
+ tracker.module()->AddEmbeddedComputation(fusion_builder.Build());
+ auto a_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec4_, "paramA"));
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec4_, HloOpcode::kNegate, a_param));
+ auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+ ShapeUtil::MakeTupleShape({f32vec4_, f32vec4_}),
+ HloInstruction::FusionKind::kLoop, {neg}, fusion_computation));
+ tracker.module()->AddEntryComputation(builder.Build());
+
+ tracker.RunWholeModule({a_param, neg, fusion});
+
+ auto neg_buffer = tracker.OffsetAt(neg, {});
+ int64 output_buffer_0 = tracker.OffsetAt(fusion, {0});
+ int64 output_buffer_1 = tracker.OffsetAt(fusion, {1});
+ // Only one buffer should be shared.
+ EXPECT_TRUE((neg_buffer == output_buffer_0) ^
+ (neg_buffer == output_buffer_1));
+}
+
TEST_F(HeapSimulatorTest, MultiplyDot) {
auto builder = HloComputation::Builder(TestName());
auto paramA = builder.AddInstruction(
@@ -574,7 +674,7 @@ class HeapAlgorithmTestBase : public ::testing::Test {
const BufferValue* DummyBufferValue() {
const BufferValue::Id id = buffers_.size();
auto const0 = builder_.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{}));
return buffers_.back().get();
}
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 1f7c1cffd3..d241791060 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -150,6 +150,11 @@ message HloInstructionProto {
// Backend configuration for the instruction. Has backend-specific meaning.
string backend_config = 43;
+
+ // Cross Replica Sum fields.
+ repeated int64 replica_group_ids = 44;
+ int64 all_reduce_id = 45;
+ string cross_replica_sum_barrier = 46;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index a88283ed9a..e8a4b034b4 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -452,15 +452,16 @@ string HloAliasAnalysis::ToString() const {
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
- HloModule* module) {
+ HloModule* module, const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer) {
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
- TF_ASSIGN_OR_RETURN(
- alias_analysis->dataflow_analysis_,
- HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
- /*bitcast_defines_value=*/false));
+ TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
+ /*bitcast_defines_value=*/false,
+ fusion_can_share_buffer));
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
@@ -493,6 +494,16 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
bool HloAliasAnalysis::HasLiveRangeInterference(
const HloOrdering& ordering) const {
for (const HloBuffer& buffer : buffers()) {
+ CHECK(!buffer.values().empty());
+ if (ShapeUtil::IsToken(buffer.values().front()->shape())) {
+ // Tokens have no on-device representation and cannot interfere.
+ for (const HloValue* value : buffer.values()) {
+ // If one of the values is a token, all values must be a token.
+ DCHECK(ShapeUtil::IsToken(value->shape()));
+ }
+ continue;
+ }
+
// Check that the values in the buffer are totally ordered with respect to
// 'ordering'. Begin by sorting the values with respect to 'ordering' with a
// tie-break using value ID. The tie-break is necessary because we need a
@@ -517,7 +528,6 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
// a buffer and A interferes with C, then necessarily A also interferes
// with B. So to check interference you only need to check interference
// between A and B, and between B and C.
- CHECK(!values.empty());
for (int i = 1; i < values.size(); ++i) {
if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) {
VLOG(1) << values[i - 1]->ToShortString() << " and "
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index 67dfd4301b..afb0c20f0c 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -39,7 +39,10 @@ class HloAliasAnalysis {
public:
// The callgraph of the given HloModule must be flattened
// (xla::FlattenCallGraph) prior to running the analysis.
- static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(HloModule* module);
+ static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(
+ HloModule* module,
+ const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr);
string ToString() const;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index 8f18d50f6e..403d4df6b5 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -116,9 +116,9 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) {
// Test the analysis on a single binary operation (Add).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, constant1, constant2));
module_->AddEntryComputation(builder.Build());
@@ -228,9 +228,9 @@ TEST_F(HloAliasAnalysisTest, SingleCall) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
module_->AddEntryComputation(builder.Build());
@@ -267,9 +267,9 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -346,15 +346,15 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -439,15 +439,15 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while0 = builder.AddInstruction(
@@ -498,7 +498,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
return cond_builder.Build();
};
// Build separate condition computations so the call graph is flat. The
@@ -543,9 +543,9 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto entry_while = builder.AddInstruction(
@@ -608,17 +608,17 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2, constant3}));
auto xla_while = builder.AddInstruction(
@@ -654,19 +654,18 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
}
TEST_F(HloAliasAnalysisTest, TupleSelect) {
- // Test a kSelect of a tuple value. Non-top-level element flow through the
- // instruction.
+ // Test a kTupleSelect. Non-top-level element flow through the instruction.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -677,13 +676,13 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
const Shape tuple_shape = tuple1->shape();
auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, select12, select34));
+ tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
module_->AddEntryComputation(builder.Build());
@@ -718,7 +717,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) {
}
TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
- // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
+ // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
//
// body((F32[], F32[]) %tuple_param):
// %negate = Negate(%tuple_param{0})
@@ -754,22 +753,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, select));
@@ -806,7 +805,7 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
// Bitcasting a value should not produce a new buffer.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
@@ -825,7 +824,7 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) {
// interference.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast}));
@@ -844,13 +843,13 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) {
// the other use of the init.
auto builder = HloComputation::Builder(TestName());
auto init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, init->shape(), "param"));
auto cond_root = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h
new file mode 100644
index 0000000000..7f73bba036
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Casting utilitiy functions for HLO instructions.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
+
+#include <type_traits>
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+class HloInstruction;
+
+template <class T>
+using EnableIfDerivedFromHlo =
+ typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type;
+
+// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like
+// RTTI if it turns out to be a performance issue.
+// Casts an HloInstruction pointer to one of its subclasses, dies if argument is
+// nullptr or runtime information does not match.
+//
+// Similar to LLVM's cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* Cast(const HloInstruction* instruction) {
+ CHECK(instruction != nullptr);
+ const T* casted = dynamic_cast<const T*>(instruction);
+ CHECK(casted != nullptr);
+ return casted;
+}
+
+// Non-const overload of Cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* Cast(HloInstruction* instruction) {
+ return const_cast<T*>(
+ Cast<T>(const_cast<const HloInstruction*>(instruction)));
+}
+
+// Works just like the Cast, except that it allows for a null pointer as an
+// argument which it then propagates.
+//
+// Similar to LLVM's cast_or_null.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* CastOrNull(const HloInstruction* instruction) {
+ return instruction != nullptr ? Cast<T>(instruction) : nullptr;
+}
+
+// Non-const overload of CastOrNull.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* CastOrNull(HloInstruction* instruction) {
+ return const_cast<T*>(
+ CastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
+}
+
+// Casts an HloInstruction pointer to one of its subclasses, dies if argument is
+// nullptr, returns nullptr if runtime information does not match.
+//
+// Similar to LLVM's dyn_cast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* DynCast(const HloInstruction* instruction) {
+ CHECK(instruction != nullptr);
+ return dynamic_cast<const T*>(instruction);
+}
+
+// Non-const overload of DynCast.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* DynCast(HloInstruction* instruction) {
+ return const_cast<T*>(
+ DynCast<T>(const_cast<const HloInstruction*>(instruction)));
+}
+
+// Works just like the DynCast, except that it allows for a null pointer as an
+// argument which it then propagates.
+//
+// Similar to LLVM's dyn_cast_or_null.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+const T* DynCastOrNull(const HloInstruction* instruction) {
+ return instruction != nullptr ? DynCast<T>(instruction) : nullptr;
+}
+
+// Non-const overload of DynCastOrNull.
+template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
+T* DynCastOrNull(HloInstruction* instruction) {
+ return const_cast<T*>(
+ DynCastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc
new file mode 100644
index 0000000000..a336427540
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class DummyInstruction : public HloInstruction {
+ public:
+ DummyInstruction()
+ : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {}
+};
+
+class AnotherDummyInstruction : public HloInstruction {
+ public:
+ AnotherDummyInstruction()
+ : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {}
+};
+
+TEST(HloCastingUtilsTest, CastSucceeds) {
+ DummyInstruction instruction;
+ DummyInstruction* casted =
+ Cast<DummyInstruction>(static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, &instruction);
+}
+
+TEST(HloCastingUtilsTest, CastDiesForWrongType) {
+ AnotherDummyInstruction instruction;
+ ASSERT_DEATH(
+ Cast<DummyInstruction>(static_cast<HloInstruction*>(&instruction)), "");
+}
+
+TEST(HloCastingUtilsTest, CastDiesForNullptr) {
+ HloInstruction* null = nullptr;
+ ASSERT_DEATH(Cast<DummyInstruction>(null), "");
+}
+
+TEST(HloCastingUtilsTest, CastOrNullSucceeds) {
+ DummyInstruction instruction;
+ DummyInstruction* casted =
+ Cast<DummyInstruction>(static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, &instruction);
+}
+
+TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) {
+ AnotherDummyInstruction instruction;
+ ASSERT_DEATH(
+ Cast<DummyInstruction>(static_cast<HloInstruction*>(&instruction)), "");
+}
+
+TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) {
+ HloInstruction* null = nullptr;
+ DummyInstruction* casted = CastOrNull<DummyInstruction>(null);
+ ASSERT_EQ(casted, nullptr);
+}
+
+TEST(HloCastingUtilsTest, DynCastSucceeds) {
+ DummyInstruction instruction;
+ DummyInstruction* casted =
+ DynCast<DummyInstruction>(static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, &instruction);
+}
+
+TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) {
+ AnotherDummyInstruction instruction;
+ DummyInstruction* casted =
+ DynCast<DummyInstruction>(static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, nullptr);
+}
+
+TEST(HloCastingUtilsTest, DynCastDiesForNullptr) {
+ HloInstruction* null = nullptr;
+ ASSERT_DEATH(DynCast<DummyInstruction>(null), "");
+}
+
+TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) {
+ DummyInstruction instruction;
+ DummyInstruction* casted = DynCastOrNull<DummyInstruction>(
+ static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, &instruction);
+}
+
+TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) {
+ AnotherDummyInstruction instruction;
+ DummyInstruction* casted = DynCastOrNull<DummyInstruction>(
+ static_cast<HloInstruction*>(&instruction));
+ ASSERT_EQ(casted, nullptr);
+}
+
+TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) {
+ HloInstruction* null = nullptr;
+ DummyInstruction* casted = DynCastOrNull<DummyInstruction>(null);
+ ASSERT_EQ(casted, nullptr);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index b61eabbbf5..166a83fade 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -64,7 +64,7 @@ HloComputation::HloComputation(
const string& name, int parameter_count,
std::vector<std::unique_ptr<HloInstruction>>* instructions,
HloInstruction* root_instruction, HloInstruction* fusion_instruction)
- : name_(name),
+ : name_(NameUniquer::GetSanitizedName(name)),
unique_id_(-1),
root_instruction_(root_instruction),
fusion_instruction_(fusion_instruction) {
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+ const string param_underscore = ".param_";
+ size_t index = original_name.rfind(param_underscore);
+ if (index == string::npos) {
+ return original_name;
+ }
+ string after_param = original_name.substr(index + param_underscore.size());
+ int64 numeric_suffix;
+ if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ return StrCat(original_name.substr(0, index + param_underscore.size()),
+ new_param_no);
+ }
+ return original_name;
+}
+
+} // namespace
+
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name = param_instruction->name();
- // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
- // renumbering the parameters, so replace the final number in the name with
- // the updated value.
- const string param_underscore = ".param_";
- size_t index = param_name.rfind(param_underscore);
- if (index == string::npos) {
- string after_param = name().substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
- param_name =
- StrCat(param_name.substr(0, index), param_underscore, param_no);
- }
- }
-
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) {
return Status::OK();
}
+Status HloComputation::RemoveUnusedParameters() {
+ CHECK(IsFusionComputation());
+ int64 removed = 0;
+ for (int64 i = 0; i < param_instructions_.size(); ++i) {
+ HloInstruction* param_instruction = param_instructions_[i];
+ if (param_instruction->user_count() == 0 &&
+ param_instruction != root_instruction()) {
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ ++removed;
+ continue;
+ }
+
+ if (removed > 0) {
+ const int64 param_no = i - removed;
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_name));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ }
+ }
+ param_instructions_.resize(param_instructions_.size() - removed);
+ return Status::OK();
+}
+
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for
@@ -234,7 +273,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
auto inst_it = instruction_iterators_.at(instruction);
(*inst_it)->set_parent(nullptr);
- instruction->DetachFromOperands();
instructions_.erase(inst_it);
return Status::OK();
}
@@ -246,9 +284,8 @@ void HloComputation::set_root_instruction(
if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
- << new_root_instruction->shape().ShortDebugString()
- << " is incompatible with "
- << root_instruction_->shape().ShortDebugString();
+ << new_root_instruction->shape() << " is incompatible with "
+ << root_instruction_->shape();
}
bool root_found = false;
for (auto& instruction : instructions_) {
@@ -264,46 +301,11 @@ void HloComputation::set_root_instruction(
namespace {
-// Helper class which computes the post order of an expression rooted at a
-// particular instruction.
-class InstructionPostOrderer : public DfsHloVisitorWithDefault {
- public:
- // added_instructions is the set of instructions which have already been
- // accounted for in the post order in previous invocations of
- // GetOrder. Without this mechanism, instructions which are predecessors of
- // multiple root instructions of the computation can be added to the post
- // order more than once.
- static std::list<HloInstruction*> GetOrder(
- HloInstruction* root,
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
- InstructionPostOrderer orderer(added_instructions);
- TF_CHECK_OK(root->Accept(&orderer));
- return std::move(orderer.post_order_);
- }
-
- private:
- explicit InstructionPostOrderer(
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
- : added_instructions_(added_instructions) {}
- ~InstructionPostOrderer() override {}
-
- Status DefaultAction(HloInstruction* hlo_instruction) override {
- if (added_instructions_->count(hlo_instruction) == 0) {
- post_order_.push_back(hlo_instruction);
- added_instructions_->insert(hlo_instruction);
- }
- return Status::OK();
- }
-
- std::list<HloInstruction*> post_order_;
- tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
-};
-
// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::list<HloComputation*>* post_order) {
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -315,12 +317,53 @@ void ComputeComputationPostOrder(
}
}
+enum State { kVisiting, kVisited };
+
+void ComputeInstructionPostOrder(
+ std::vector<HloInstruction*>* post_order, HloInstruction* root,
+ tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
+ std::vector<HloInstruction*> dfs_stack;
+ dfs_stack.push_back(root);
+ while (!dfs_stack.empty()) {
+ const auto current = dfs_stack.back();
+ auto it = visited->find(current);
+ if (it != visited->end()) {
+ if (it->second == kVisited) {
+ // Already visited.
+ dfs_stack.pop_back();
+ continue;
+ }
+ // Visit this node.
+ CHECK_EQ(kVisiting, it->second);
+ dfs_stack.pop_back();
+ post_order->push_back(current);
+ it->second = kVisited;
+ continue;
+ }
+
+ visited->insert({current, kVisiting});
+
+ // Add the operands to the stack in reverse order so the first operand is
+ // processed first. This will produce a more natural ordering and a nicer
+ // result for thigns like HLO stringification.
+ const auto& operands = current->operands();
+ for (int64 i = operands.size() - 1; i >= 0; --i) {
+ dfs_stack.emplace_back(operands[i]);
+ }
+
+ for (HloInstruction* op : current->control_predecessors()) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+}
+
} // namespace
-std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
- std::list<HloInstruction*> post_order;
- std::list<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
+std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::vector<HloInstruction*> post_order;
+ post_order.reserve(instruction_count());
+ std::vector<HloInstruction*> trace_instructions;
+ tensorflow::gtl::FlatMap<HloInstruction*, State> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -328,21 +371,20 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- post_order.splice(post_order.end(),
- InstructionPostOrderer::GetOrder(instruction.get(),
- &added_instructions));
+ ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
}
}
- post_order.splice(post_order.end(), trace_instructions);
+ post_order.insert(post_order.end(), trace_instructions.begin(),
+ trace_instructions.end());
CHECK_EQ(instructions_.size(), post_order.size())
<< "number of instructions does not match post order size";
return post_order;
}
-std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
// 'this'. 'this' is immediately removed from the post order after
@@ -486,23 +528,11 @@ HloInstruction* HloComputation::CreateFusionInstruction(
}
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
- if (ShapeUtil::IsArray(instruction->shape())) {
- if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
- // Use kCopy to copy array elements
- HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- if (copies_added != nullptr) {
- *copies_added->mutable_element(*index) = copy;
- }
- return copy;
- } else {
- // Array elements which are not to be copied are passed through
- // transparently.
- return instruction;
- }
- } else if (ShapeUtil::IsTuple(instruction->shape())) {
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
+ if (ShapeUtil::IsTuple(instruction->shape())) {
std::vector<HloInstruction*> elements;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
i++) {
@@ -512,17 +542,22 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
instruction, i));
index->push_back(i);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element,
- DeepCopyHelper(gte, indices_to_copy, copies_added, index));
+ TF_ASSIGN_OR_RETURN(HloInstruction * element,
+ DeepCopyHelper(gte, index, copy_leaf));
elements.push_back(element);
index->pop_back();
}
return AddInstruction(HloInstruction::CreateTuple(elements));
- } else {
- return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
}
+ if (ShapeUtil::IsToken(instruction->shape())) {
+ // Tokens have no on-device representation and cannot be copied. Pass
+ // through transparently.
+ return instruction;
+ }
+
+ // Array shape.
+ TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape()));
+ return copy_leaf(instruction, *index, this);
}
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
@@ -544,7 +579,36 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
}
ShapeIndex index;
- return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index);
+ auto copy_leaf = [indices_to_copy, copies_added](
+ HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation) {
+ if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
+ HloInstruction* copy = computation->AddInstruction(
+ HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
+ if (copies_added != nullptr) {
+ *copies_added->mutable_element(leaf_index) = copy;
+ }
+ return copy;
+ }
+ // Elements which are not to be copied are passed through
+ // transparently.
+ return leaf;
+ };
+ return DeepCopyHelper(instruction, &index, copy_leaf);
+}
+
+StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
+ if (instruction->parent() != this) {
+ return FailedPrecondition(
+ "Can't deep copy instruction %s: instruction is not in computation %s",
+ instruction->name().c_str(), name().c_str());
+ }
+ ShapeIndex index;
+ return DeepCopyHelper(instruction, &index, copy_leaf);
}
ProgramShape HloComputation::ComputeProgramShape() const {
@@ -609,7 +673,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
- const std::list<HloInstruction*> all = MakeInstructionPostOrder();
+ const auto& all = MakeInstructionPostOrder();
auto result = MakeUnique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;
@@ -617,7 +681,7 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
- result->SetReachabilityToUnion(inputs, hlo);
+ result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
}
@@ -827,15 +891,6 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
}
}
context->MapComputation(this, result.get());
- // We cloned the elements of 'replacements', so they're all going to be
- // destroyed. HloInstructions need to be detached from their operands before
- // they're destroyed, otherwise they stick around in the operands' users lists
- // and cause use-after-frees.
- for (auto& kv : replacements) {
- if (std::unique_ptr<HloInstruction>& new_instr = kv.second) {
- new_instr->DetachFromOperands();
- }
- }
return result;
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0da4a305f3..abc1da4da3 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+#include <functional>
#include <list>
#include <memory>
#include <string>
@@ -113,6 +114,11 @@ class HloComputation {
// instruction.
Status RemoveParameter(int64 param_no);
+ // Remove unused parameters from the computation.
+ // Note this is only applicatable to the computation for the fusion
+ // instruction.
+ Status RemoveUnusedParameters();
+
// Add new parameter instruction to the computation.
// This should be a new parameter. Instruction will be appended to parameters
// and inserted to the instruction list.
@@ -199,7 +205,7 @@ class HloComputation {
// Compute and return a post-order of the instructions in the computation. In
// this order, definitions of values always appear before their uses.
- std::list<HloInstruction*> MakeInstructionPostOrder() const;
+ std::vector<HloInstruction*> MakeInstructionPostOrder() const;
// Computes and returns the reachability between HLO instructions in the
// computation. The returned HloReachabilityMap is constructed such that
@@ -221,7 +227,7 @@ class HloComputation {
// transitively. The embedded computations are sorted such that if computation
// A calls computation B (eg, via a map instruction) then A will appear after
// B in the list.
- std::list<HloComputation*> MakeEmbeddedComputationsList() const;
+ std::vector<HloComputation*> MakeEmbeddedComputationsList() const;
// Creates a fusion instruction containing the given instructions.
// `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
@@ -249,6 +255,14 @@ class HloComputation {
const ShapeTree<bool>* indices_to_copy = nullptr,
ShapeTree<HloInstruction*>* copies_added = nullptr);
+ // As above, but uses a custom function to copy the leaf nodes, which could
+ // create alternative HLOs other than kCopy, or even pass-throughs.
+ StatusOr<HloInstruction*> DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
+
// Computes and returns the ProgramShape of this computation (shape of
// parameters and result with layout).
ProgramShape ComputeProgramShape() const;
@@ -373,8 +387,10 @@ class HloComputation {
// Internal helper for recursive copying of an instruction. Creates and
// returns a deep copy of the given instruction.
StatusOr<HloInstruction*> DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index);
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf);
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 25469a54c4..e4c5470331 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <set>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -118,7 +118,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) {
// Test GetInstructionPostOrder for a computation with one instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant));
@@ -129,7 +129,7 @@ TEST_F(HloComputationTest, PostOrderSimple) {
// instructions.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto negate2 = builder.AddInstruction(
@@ -144,7 +144,7 @@ TEST_F(HloComputationTest, PostOrderTrace) {
// Test GetInstructionPostOrder for a computation with a trace instruction.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto trace =
@@ -163,13 +163,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(),
@@ -181,11 +181,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
// which are not connected.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -205,11 +205,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
// computation has multiple roots (dead code).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
// Add three disconnected add expressions.
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
constant1, constant2));
@@ -256,7 +256,7 @@ TEST_F(HloComputationTest, DeepCopyArray) {
// Test that DeepCopyInstruction properly copies an array.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
@@ -268,9 +268,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) {
// Test that DeepCopyInstruction properly copies a tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -289,7 +289,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) {
// copy are specified.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto computation = builder.Build();
{
@@ -314,9 +314,9 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
// specified by the given indices.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto computation = builder.Build();
@@ -371,11 +371,43 @@ TEST_F(HloComputationTest, DeepCopyTupleAtIndices) {
}
}
+TEST_F(HloComputationTest, DeepCopyToken) {
+ // Test that DeepCopyInstruction properly handles tokens which should not be
+ // copied.
+ auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ auto copy = computation->DeepCopyInstruction(token).ValueOrDie();
+
+ // No copy should be added.
+ EXPECT_THAT(copy, op::AfterAll());
+}
+
+TEST_F(HloComputationTest, DeepCopyTokenTuple) {
+ // Test that DeepCopyInstruction properly handles tokens which should not be
+ // copied.
+ auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({token, constant}));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ auto copy = computation->DeepCopyInstruction(tuple).ValueOrDie();
+
+ // Only the array (second tuple element) should be copied. The token is passed
+ // through transparently.
+ EXPECT_THAT(copy, op::Tuple(op::GetTupleElement(tuple),
+ op::Copy(op::GetTupleElement(tuple))));
+}
+
TEST_F(HloComputationTest, CycleDetection) {
// Test whether the visitor can detect cycles in the graph.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto add = builder.AddInstruction(
@@ -385,6 +417,9 @@ TEST_F(HloComputationTest, CycleDetection) {
// Add a control dependency to create a cycle.
ASSERT_IS_OK(add->AddControlDependencyTo(negate));
+ auto instructions = computation->MakeInstructionPostOrder();
+ EXPECT_EQ(3, instructions.size());
+
const auto visitor = [](HloInstruction* instruction) { return Status::OK(); };
auto visit_status = computation->Accept(visitor);
ASSERT_FALSE(visit_status.ok());
@@ -398,7 +433,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
// twice.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto dead_negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -421,9 +456,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
TEST_F(HloComputationTest, CloneWithControlDependency) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
@@ -467,9 +502,9 @@ TEST_F(HloComputationTest, Reachability) {
// There is a control dependency from 'add' to 'exp'.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto negate = builder.AddInstruction(
@@ -572,13 +607,14 @@ TEST_F(HloComputationTest, Stringification) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationIndent) {
@@ -604,13 +640,14 @@ TEST_F(HloComputationTest, StringificationIndent) {
auto options =
HloPrintOptions().set_print_metadata(false).set_indent_amount(2);
- EXPECT_EQ(computation->ToString(options),
- R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation =
+ R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
- })");
+ })";
+ EXPECT_EQ(computation->ToString(options), expected_computation);
}
TEST_F(HloComputationTest, StringificationCanonical) {
@@ -635,21 +672,23 @@ TEST_F(HloComputationTest, StringificationCanonical) {
auto* computation = module->AddEntryComputation(builder.Build());
auto options = HloPrintOptions().set_print_metadata(false);
- EXPECT_EQ(computation->ToString(options),
- R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
+ const string expected_computation1 =
+ R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] {
%x = f32[5,10]{1,0} parameter(0)
%y = f32[20,10]{1,0} parameter(1)
%transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0}
ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation1);
options = HloPrintOptions().Canonical();
- EXPECT_EQ(computation->ToString(options), R"(TransposeDot {
+ const string expected_computation2 = R"(TransposeDot {
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(computation->ToString(options), expected_computation2);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
index 35ecd4428d..7229031c0c 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
@@ -51,14 +51,18 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
computation->root_instruction() != instruction) {
continue;
}
- // Skip Constant, Parameter, Reduce operation.
+ // Skip Constant, Parameter, Reduce, and AfterAll operation.
// TODO(b/35975797): Enable Reduce operation once arbitrary computation
// are supported by the evaluator.
// TODO(b/64407269): Enable Tuple once the timeout issue is resolved.
+ // TODO(b/110532604): Enable AfterAll once AfterAll requires at least one
+ // operand in which case constant folding will be impossible and this
+ // special case is not necessary.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kTuple ||
- instruction->opcode() == HloOpcode::kReduce) {
+ instruction->opcode() == HloOpcode::kReduce ||
+ instruction->opcode() == HloOpcode::kAfterAll) {
continue;
}
// Skip instructions with non-constant operands.
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 5d05ccfc0b..64a42c1efc 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
@@ -62,7 +62,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
HloComputation::Builder builder(TestName());
HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
@@ -82,8 +82,8 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
HloComputation::Builder builder(TestName());
- HloInstruction* input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0f, 19.0f})));
+ HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
builder.AddInstruction(
HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
@@ -120,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
for (auto csize : test_config.concat_sizes) {
dimensions[test_config.concat_dimension] = csize;
concat_size += csize;
- auto literal = Literal::CreateFromDimensions(F32, dimensions);
+ auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
HloInstruction* insn = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
operands.push_back(insn);
@@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 94c9c7eabc..c49cf7f5db 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -164,7 +164,11 @@ Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleSelect(const HloInstruction*) {
+Status HloCostAnalysis::HandleSelect(const HloInstruction* hlo) {
+ return HandleElementwiseOp(hlo);
+}
+
+Status HloCostAnalysis::HandleTupleSelect(const HloInstruction*) {
return Status::OK();
}
@@ -172,15 +176,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleSlice(const HloInstruction*) {
+Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) {
+ current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2;
return Status::OK();
}
-Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) {
+Status HloCostAnalysis::HandleDynamicSlice(
+ const HloInstruction* dynamic_slice) {
+ current_properties_[kBytesAccessedKey] =
+ shape_size_(dynamic_slice->shape()) * 2;
return Status::OK();
}
-Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) {
+Status HloCostAnalysis::HandleDynamicUpdateSlice(
+ const HloInstruction* dynamic_update_slice) {
+ current_properties_[kBytesAccessedKey] =
+ shape_size_(dynamic_update_slice->operand(1)->shape()) * 2;
return Status::OK();
}
@@ -386,6 +397,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index d17678d20f..0181138a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -54,7 +54,8 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleConstant(const HloInstruction* constant) override;
Status HandleGetTupleElement(
const HloInstruction* get_tuple_element) override;
- Status HandleSelect(const HloInstruction* select) override;
+ Status HandleSelect(const HloInstruction* hlo) override;
+ Status HandleTupleSelect(const HloInstruction* hlo) override;
Status HandleCompare(const HloInstruction* compare) override;
Status HandleClamp(const HloInstruction* clamp) override;
Status HandleReducePrecision(const HloInstruction* hlo) override;
@@ -97,6 +98,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
+ Status HandleAfterAll(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
index 16fdda8a8b..9fd0363f57 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -59,9 +59,9 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a unary user function: x => exp(x + 0.5)
{
XlaBuilder builder("add_and_exp");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto half = builder.ConstantR0<float>(0.5);
- builder.Exp(builder.Add(x, half));
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto half = ConstantR0<float>(&builder, 0.5);
+ Exp(Add(x, half));
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
add_and_exp_ = computation_status.ConsumeValueOrDie();
@@ -70,9 +70,9 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary user function: (x, y) => x + y
{
XlaBuilder builder("add");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Add(x, y);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
add_ = computation_status.ConsumeValueOrDie();
@@ -81,9 +81,9 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
{
XlaBuilder builder("sigmoid");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto one = builder.ConstantR0<float>(1.0);
- builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = ConstantR0<float>(&builder, 1.0);
+ Div(one, Add(one, Exp(Neg(x))));
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
sigmoid_ = computation_status.ConsumeValueOrDie();
@@ -92,9 +92,9 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary max function: (x, y) => max (x, y)
{
XlaBuilder builder("max");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Max(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Max(x, y);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
max_ = computation_status.ConsumeValueOrDie();
@@ -103,9 +103,9 @@ class HloCostAnalysisTest : public ::testing::Test {
// Create a computation for a binary GT function: (x, y) => x > y
{
XlaBuilder builder("gt");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Gt(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Gt(x, y);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
gt_ = computation_status.ConsumeValueOrDie();
@@ -137,9 +137,9 @@ class HloCostAnalysisTest : public ::testing::Test {
TEST_F(HloCostAnalysisTest, MatrixMultiply) {
XlaBuilder builder("matrix_multiply");
- auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
- auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
+ auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
+ Dot(lhs, rhs);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -159,8 +159,8 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) {
TEST_F(HloCostAnalysisTest, Map) {
XlaBuilder builder("map");
- auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
- auto result = builder.Map({input}, add_and_exp_, {0});
+ auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "in");
+ Map(&builder, {input}, add_and_exp_, {0});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -176,17 +176,17 @@ TEST_F(HloCostAnalysisTest, Map) {
TEST_F(HloCostAnalysisTest, Convolution) {
XlaBuilder builder("convolution");
- auto input = builder.Parameter(
- 0,
+ auto input = Parameter(
+ &builder, 0,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
/*x_dim=*/20}),
"input");
- auto kernel = builder.Parameter(
- 1,
+ auto kernel = Parameter(
+ &builder, 1,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
/*x_dim=*/3}),
"kernel");
- auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
+ Conv(input, kernel, {1, 1}, Padding::kValid);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -206,9 +206,8 @@ TEST_F(HloCostAnalysisTest, Convolution) {
TEST_F(HloCostAnalysisTest, Reduce) {
XlaBuilder builder("reduce");
auto input =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
- auto result =
- builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ Reduce(input, ConstantR0<float>(&builder, 0.0f), add_, {1});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -224,9 +223,9 @@ TEST_F(HloCostAnalysisTest, Reduce) {
TEST_F(HloCostAnalysisTest, ReduceWindow) {
XlaBuilder builder("reduce_window");
auto input =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
- auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
- {4, 5}, {4, 5}, Padding::kValid);
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ ReduceWindow(input, ConstantR0<float>(&builder, 0), add_, {4, 5}, {4, 5},
+ Padding::kValid);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -241,12 +240,11 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) {
TEST_F(HloCostAnalysisTest, SelectAndScatter) {
XlaBuilder builder("select_and_scatter");
auto operand =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
auto source =
- builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
- auto result =
- builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
- source, builder.ConstantR0<float>(0), add_);
+ Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
+ SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid, source,
+ ConstantR0<float>(&builder, 0), add_);
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -261,7 +259,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) {
TEST_F(HloCostAnalysisTest, Broadcast) {
XlaBuilder b("broadcast");
- b.Broadcast(b.ConstantR0<float>(42), {10, 7});
+ Broadcast(ConstantR0<float>(&b, 42), {10, 7});
auto hlo_module = BuildHloGraph(&b);
HloCostAnalysis analysis(ShapeSize);
ASSERT_IS_OK(
@@ -273,13 +271,12 @@ TEST_F(HloCostAnalysisTest, Broadcast) {
TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
XlaBuilder builder("fully_connected_forward");
auto input =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
auto weight =
- builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
- auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
+ Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
+ auto bias = Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {20}), "bias");
// sigmoid(input * weight + bias)
- auto result = builder.Map(
- {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
+ Map(&builder, {Add(Dot(input, weight), bias, {1})}, sigmoid_, {0, 1});
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -297,11 +294,11 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
HloCostAnalysis conv_analysis(ShapeSize);
{
XlaBuilder builder("conv_looking_matmul");
- auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
- "input");
- auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
- "weights");
- builder.Conv(lhs, rhs, {1, 1}, Padding::kSame);
+ auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "input");
+ auto rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "weights");
+ Conv(lhs, rhs, {1, 1}, Padding::kSame);
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
&conv_analysis));
@@ -311,10 +308,10 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
{
XlaBuilder builder("matmul");
auto lhs =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
auto rhs =
- builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
- builder.Dot(lhs, rhs);
+ Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
+ Dot(lhs, rhs);
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
&matmul_analysis));
@@ -341,13 +338,13 @@ TEST_F(FusionCostAnalysis, LoopFusion) {
// tuple = Tuple({sub, sub, mul, C1})
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)));
auto c3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2F32Linspace(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace(
/*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, c1, c2));
@@ -394,9 +391,9 @@ TEST_F(FusionCostAnalysis, NoLayout) {
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(2, 3, 4, 5))));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3})));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(shape_without_layout, c2, {1}));
@@ -419,9 +416,9 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
HloCostAnalysis analysis(ShapeSize);
{
XlaBuilder builder("matmul");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y");
- auto tuple = builder.Tuple({x, y});
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {42}), "y");
+ Tuple(&builder, {x, y});
auto hlo_module = BuildHloGraph(&builder);
ASSERT_IS_OK(
@@ -435,21 +432,21 @@ TEST_F(HloCostAnalysisTest, TupleCost) {
TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
XlaBuilder builder("BaseDilatedConvolution");
- auto input = builder.Parameter(
- 0,
+ auto input = Parameter(
+ &builder, 0,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
/*x_dim=*/20}),
"input");
- auto kernel = builder.Parameter(
- 1,
+ auto kernel = Parameter(
+ &builder, 1,
ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
/*x_dim=*/3}),
"kernel");
- auto result = builder.ConvGeneralDilated(
- input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}},
- /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
- XlaBuilder::CreateDefaultConvDimensionNumbers(2));
+ ConvGeneralDilated(input, kernel, /*window_strides=*/{1, 1},
+ /*padding=*/{{1, 1}, {1, 1}},
+ /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11},
+ XlaBuilder::CreateDefaultConvDimensionNumbers(2));
// Run HLO cost analysis.
auto hlo_module = BuildHloGraph(&builder);
@@ -460,5 +457,51 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) {
EXPECT_EQ(analysis.flop_count(), 1472);
}
+TEST_F(HloCostAnalysisTest, Slice) {
+ // Test the analysis on a slice.
+ XlaBuilder builder("slice");
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
+ Slice(x, {0}, {1}, {1});
+ auto hlo_module = BuildHloGraph(&builder);
+
+ // Run HLO cost analysis.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.bytes_accessed(), 8);
+}
+
+TEST_F(HloCostAnalysisTest, DynamicSlice) {
+ // Test the analysis on a slice.
+ XlaBuilder builder("dynamic-slice");
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
+ DynamicSlice(x, ConstantR1<int32>(&builder, {1}), {1});
+ auto hlo_module = BuildHloGraph(&builder);
+
+ // Run HLO cost analysis.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.bytes_accessed(), 8);
+}
+
+TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) {
+ // Test the analysis on a slice.
+ XlaBuilder builder("dynamic-update-slice");
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "x");
+ DynamicUpdateSlice(x, ConstantR1<float>(&builder, {1.0}),
+ ConstantR1<int32>(&builder, {1}));
+ auto hlo_module = BuildHloGraph(&builder);
+
+ // Run HLO cost analysis.
+ HloCostAnalysis analysis(ShapeSize);
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ EXPECT_EQ(analysis.bytes_accessed(), 8);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 0fb65c845a..90d2be118d 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
@@ -261,9 +262,9 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
padding_config_dim.set_edge_padding_high(zeros_to_append);
*padding_config.add_dimensions() = padding_config_dim;
- HloInstruction* zero =
- computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(operand->shape().element_type()))));
+ HloInstruction* zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(
+ LiteralUtil::Zero(operand->shape().element_type()))));
return MakePadHlo(operand, zero, padding_config);
}
@@ -272,7 +273,7 @@ StatusOr<HloInstruction*> BroadcastZeros(
ArraySlice<int64> broadcast_dimensions) {
HloInstruction* zero =
computation->AddInstruction(HloInstruction::CreateConstant(
- MakeUnique<Literal>(Literal::Zero(element_type))));
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
index 7e7c4f95fe..60d3e71757 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc
@@ -60,8 +60,8 @@ TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({3, 4}));
+ *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({3, 4}));
}
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
@@ -82,10 +82,10 @@ TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
*module,
- {Literal::CreateR3<int32>(
+ {LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})}));
CHECK_EQ(*result_literal,
- *Literal::CreateR2<int32>(
+ *LiteralUtil::CreateR2<int32>(
{{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
@@ -103,10 +103,11 @@ TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9, 10}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(
+ *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9, 10}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
@@ -123,10 +124,11 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);
HloEvaluator evaluator;
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
- evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({9, 10})}));
- CHECK_EQ(*result_literal, *Literal::CreateR3<int32>({{{9, 10}}}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result_literal,
+ evaluator.Evaluate<std::unique_ptr<Literal>>(
+ *module, {LiteralUtil::CreateR1<int32>({9, 10})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
@@ -145,8 +147,8 @@ TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<int32>(9)}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{9}}));
+ *module, {LiteralUtil::CreateR0<int32>(9)}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{9}}));
}
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
@@ -166,9 +168,9 @@ TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
+ *module, {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
CHECK_EQ(*result_literal,
- *Literal::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
+ *LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
@@ -188,8 +190,8 @@ TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR1<int32>({3, 4})}));
- CHECK_EQ(*result_literal, *Literal::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
+ *module, {LiteralUtil::CreateR1<int32>({3, 4})}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
@@ -209,8 +211,8 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<int32>(0)}));
- CHECK_EQ(*result_literal, *Literal::CreateR2<int32>({{0, 0}, {0, 0}}));
+ *module, {LiteralUtil::CreateR0<int32>(0)}));
+ CHECK_EQ(*result_literal, *LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
@@ -230,9 +232,9 @@ TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
HloEvaluator evaluator;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
evaluator.Evaluate<std::unique_ptr<Literal>>(
- *module, {Literal::CreateR0<float>(0.0f)}));
+ *module, {LiteralUtil::CreateR0<float>(0.0f)}));
CHECK_EQ(*result_literal,
- *Literal::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index dab946a099..06484f4012 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -135,15 +135,14 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// instruction for each class.
tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
decltype(cse_equal)>
- representatives(/*N=*/1024, &CseHash, cse_equal);
-
+ representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
+ cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
// If the instruction has zero operands (constants, parameters, etc.) skip
// over it.
if (instruction->operand_count() == 0) {
continue;
}
-
// Skip instructions which have side effects.
if (instruction->HasSideEffect()) {
continue;
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e8c5ca347b..76b9c66651 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -32,10 +32,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
@@ -53,9 +53,9 @@ TEST_F(HloCseTest, CombineTwoConstants) {
// Test that two identical constants are commoned.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
EXPECT_EQ(42.0f, constant->literal().Get<float>({}));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR0<float>(84.0);
+ auto expected = LiteralUtil::CreateR0<float>(84.0);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -81,10 +81,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
// the pass is not layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
EXPECT_THAT(add, op::Add(first_operand, first_operand));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -113,10 +113,10 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
// if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
EXPECT_THAT(add, op::Add(constant1, constant2));
auto result = ExecuteAndTransfer(std::move(module), {});
- auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
@@ -144,20 +144,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
auto builder = HloComputation::Builder(TestName());
std::vector<HloInstruction*> constants;
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint64>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int64>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<double>(42.0))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0))));
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
// Duplicate the float constant to verify something happens.
constants.push_back(builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f))));
const Shape shape_r0 = ShapeUtil::MakeShape(F32, {});
for (int64 i = 0; i < constants.size(); ++i) {
@@ -188,13 +188,13 @@ TEST_F(HloCseTest, NonscalarConstants) {
// Test that identical nonscalar constants are merged.
auto builder = HloComputation::Builder(TestName());
auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
// Create a constant which has the same shape but a different value.
auto uncommon_constant =
builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
// Tie the constants together with a tuple. This makes it easier to refer to
// the constant instructions via their use.
@@ -223,7 +223,7 @@ TEST_F(HloCseTest, IdenticalInstructions) {
// Test that three identical instructions are commoned.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -253,7 +253,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
// commoned if the pass is layout sensitive.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
@@ -284,7 +284,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
// the pass is layout insensitive.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kExp, constant));
@@ -362,7 +362,7 @@ TEST_F(HloCseTest, IdenticalExpressions) {
// The *1 instructions should be merged with the *2 instructions.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kNegate, constant));
@@ -400,9 +400,9 @@ TEST_F(HloCseTest, DoNotCombineRng) {
// Test that two RNG ops are not commoned.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
{constant1, constant2}));
@@ -442,9 +442,9 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
auto builder = HloComputation::Builder(TestName() + "_rng_fun");
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto rng = builder.AddInstruction(HloInstruction::CreateRng(
scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2}));
auto param = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -459,7 +459,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
{
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({5.0f})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
auto rng1 = builder.AddInstruction(
HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
auto rng2 = builder.AddInstruction(
@@ -486,7 +486,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
}
TEST_F(HloCseTest, CompareComputations) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule m
add_computation {
@@ -521,9 +521,9 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
// in this case) are not collapsed.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<uint32>(42)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -536,5 +536,40 @@ TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) {
EXPECT_EQ(2, computation->instruction_count());
}
+TEST_F(HloCseTest, Domain) {
+ auto module = ParseHloString(R"(
+HloModule module
+ENTRY %entry {
+ %param = f32[] parameter(0), sharding={maximal device=0}
+ %domain.0 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.1 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+ %domain.2 = f32[] domain(%param),
+ domain={kind="sharding", entry={maximal device=0}, exit={maximal device=2}}
+ %negate.0 = f32[] negate(%domain.0)
+ %negate.1 = f32[] negate(%domain.1)
+ %negate.2 = f32[] negate(%domain.2)
+ %domain.3 = f32[] domain(%negate.0),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.4 = f32[] domain(%negate.1),
+ domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}}
+ %domain.5 = f32[] domain(%negate.2),
+ domain={kind="sharding", entry={maximal device=2}, exit={maximal device=0}}
+ %add = f32[] add(%domain.3, %domain.4)
+ ROOT %sub = f32[] subtract(%add, %domain.5)
+})")
+ .ValueOrDie();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+ LOG(INFO) << "AAAAA " << module->ToString();
+ const HloInstruction* sub = module->entry_computation()->root_instruction();
+ const HloInstruction* add = sub->operand(0);
+ EXPECT_EQ(add->operand(0), add->operand(1));
+ EXPECT_NE(add->operand(0), sub->operand(1));
+ EXPECT_NE(add->operand(1), sub->operand(1));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index cc130a4900..de1a32d8bd 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -34,16 +34,86 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
+
+// We have this pattern in dynamaic update slice fusion, which should be
+// supported:
+//
+// Parameters: p0, p1
+// Fusion
+// ds = DynamicSlice(p0, p1)
+// ROOT DynamicUpdateslice(p0, ds, p1)
+//
+// In this case, we should be able to reuse p0 and output, although p0 has
+// multiple uses.
+bool MultiDynamicSliceUseShareSameIndices(
+ tensorflow::gtl::ArraySlice<HloUse> uses) {
+ if (uses.empty()) {
+ return false;
+ }
+ const HloInstruction* indices = nullptr;
+ for (HloUse use : uses) {
+ auto user = use.instruction;
+ if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
+ if (indices == nullptr) {
+ indices = user->operand(2);
+ } else if (indices != user->operand(2)) {
+ return false;
+ }
+ if (use.operand_number != 0) {
+ return false;
+ }
+ } else if (user->opcode() == HloOpcode::kDynamicSlice) {
+ if (indices == nullptr) {
+ indices = user->operand(1);
+ } else if (indices != user->operand(1)) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
- bool bitcast_defines_value)
+HloDataflowAnalysis::HloDataflowAnalysis(
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer)
: module_(module),
ssa_form_(ssa_form),
bitcast_defines_value_(bitcast_defines_value),
- call_graph_(CallGraph::Build(&module)) {}
+ call_graph_(CallGraph::Build(&module)),
+ fusion_can_share_buffer_(fusion_can_share_buffer) {}
+
+bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
+ const HloInstruction* inst) {
+ tensorflow::gtl::FlatSet<const HloInstruction*> visited;
+ tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack;
+ stack.push_back(inst);
+ while (!stack.empty()) {
+ const HloInstruction* current = stack.back();
+ stack.pop_back();
+ visited.insert(current);
+ for (const HloInstruction* user : current->users()) {
+ // Found a user that is non-elementwise on current instruction.
+ for (const int64 use_index : user->OperandIndices(current)) {
+ if (!user->IsElementwiseOnOperand(use_index) &&
+ user->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ }
+ if (!visited.count(user)) {
+ stack.push_back(user);
+ }
+ }
+ }
+ return true;
+}
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
@@ -328,18 +398,17 @@ bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
bool changed = false;
- // RecvDone forwards the operand value at {0} to the output.
+ // RecvDone forwards the operand value at {0} to element {0} of its output.
for (auto& pair : GetInstructionValueSet(recv_done)) {
ShapeIndex& index = pair.first;
HloValueSet& value_set = pair.second;
- ShapeIndex operand_index = {0};
- for (int64 i : index) {
- operand_index.push_back(i);
+ if (index.empty() || index[0] != 0) {
+ continue;
}
const HloValueSet& operand_value_set =
- GetValueSet(recv_done->operand(0), operand_index);
+ GetValueSet(recv_done->operand(0), index);
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
@@ -396,6 +465,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
return changed;
}
+bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
+ // Domain instructions just forward their operand. Given that domains can have
+ // a tuple operand, we iterate through its indexes, like for copies.
+ // Unlike copies though we also propagate the top-level value.
+ CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(domain)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
bool changed = false;
@@ -490,17 +577,17 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
}
}
-bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) {
- CHECK_EQ(select->opcode(), HloOpcode::kSelect);
- // A phi value is not defined at a kSelect instruction because kSelect does
- // not create a new value. Rather it forwards a value from its operands. This
- // contrasts with kWhile instruction (which does define a phi value) which has
- // in-place update semantics.
+bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
+ CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
+ // A phi value is not defined at a kTupleSelect instruction because
+ // kTupleSelect does not create a new value. Rather it forwards a value from
+ // its operands. This contrasts with kWhile instruction (which does define a
+ // phi value) which has in-place update semantics.
bool changed = false;
for (auto& pair : GetInstructionValueSet(select)) {
const ShapeIndex& index = pair.first;
if (index.empty()) {
- // kSelect copies (not forwards) the top-level value.
+ // kTupleSelect copies (not forwards) the top-level value.
continue;
}
HloValueSet& value_set = pair.second;
@@ -556,12 +643,14 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateBitcastValueSet(instruction);
case HloOpcode::kSlice:
return UpdateSliceValueSet(instruction);
+ case HloOpcode::kDomain:
+ return UpdateDomainValueSet(instruction);
case HloOpcode::kCopy:
return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
return UpdateGetTupleElementValueSet(instruction);
- case HloOpcode::kSelect:
- return UpdateSelectValueSet(instruction);
+ case HloOpcode::kTupleSelect:
+ return UpdateTupleSelectValueSet(instruction);
case HloOpcode::kTuple:
return UpdateTupleValueSet(instruction);
case HloOpcode::kParameter:
@@ -734,6 +823,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kGetTupleElement:
+ case HloOpcode::kDomain:
// These instructions define no values. The values in their output
// flow from their operands or from cross computation dataflow.
break;
@@ -759,21 +849,25 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
break;
case HloOpcode::kCopy:
- case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
case HloOpcode::kTuple:
// These instructions only define their top-level values. Any other
// values flow from their operands.
define_top_level_only();
break;
case HloOpcode::kRecvDone:
- // RecvDone aliases its input tuple element {0}, therefore does not
- // define any values.
+ // RecvDone produces a two-element tuple. Element zero aliases its
+ // input tuple element {0}; element one is a token.
+ define_value_at(/*index=*/{});
+ define_value_at(/*index=*/{1});
break;
case HloOpcode::kSend:
- // Send produces a tuple of {aliased operand, U32 context}, therefore
- // only defines the top-level tuple and the tuple element at {1}.
+ // Send produces a tuple of {aliased operand, U32 context, token},
+ // therefore only defines the top-level tuple and the tuple elements
+ // at {1} and {2}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
+ define_value_at(/*index=*/{2});
break;
default:
define_all_values();
@@ -787,12 +881,13 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
- const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer) {
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
- auto dataflow_analysis = WrapUnique(
- new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
+ auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis(
+ module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
dataflow_analysis->Propagate();
@@ -915,6 +1010,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
+
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
@@ -927,20 +1023,27 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
const HloValue& value = GetValueDefinedAt(fusion_param, operand_index);
if (value.uses().size() != 1) {
+ if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
+ return true;
+ }
return false;
}
const HloUse& use = value.uses()[0];
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == 0;
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
+ user->fusion_kind() == HloInstruction::FusionKind::kInput) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ } else {
+ return AreTransitiveUsesElementwiseOrTuple(fusion_param);
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -965,8 +1068,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// index 'other_add_operand_index').
return use.instruction == user->fused_expression_root() &&
use.operand_number == other_add_operand_index;
+ } else if (fusion_can_share_buffer_ != nullptr &&
+ fusion_can_share_buffer_(user, operand)) {
+ return true;
}
}
+
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
@@ -998,8 +1105,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check in the beginning of this function.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 9868746b61..f4abc7a7c7 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -42,6 +42,20 @@ namespace xla {
// Analysis which identifies all HLO values and their uses in an HLO module.
class HloDataflowAnalysis {
public:
+ // Different backends can have very different ways to do fusion, so we give
+ // backends the flexibility to decide whether an fusion instruction can share
+ // buffer with it's operands. If this is not specified, a default strategy
+ // will be used; if this is specified, it will be applied *in addition* to the
+ // default strategy.
+ //
+ // The first parameter of the function should be the fusion instruction, the
+ // second parameter should be an operand of the fusion instruction.
+ //
+ // TODO(b/80315712): Find a better way to tell whether a fusion can share
+ // buffer.
+ using FusionCanShareBufferFunction = std::function<bool(
+ const HloInstruction* fusion, const HloInstruction* operand)>;
+
// Run dataflow analysis on the given module. Parameters:
//
// ssa_form : If true then new values are defined at the merge points of
@@ -61,7 +75,10 @@ class HloDataflowAnalysis {
// value of its operand.
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
const HloModule& module, bool ssa_form = false,
- bool bitcast_defines_value = false);
+ bool bitcast_defines_value = false,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
+
+ static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
// Returns true if 'instruction' defines an HLO value at the given shape index
// of its output.
@@ -136,8 +153,10 @@ class HloDataflowAnalysis {
const ShapeIndex& user_index) const;
protected:
- HloDataflowAnalysis(const HloModule& module, bool ssa_form,
- bool bitcast_defines_value = false);
+ HloDataflowAnalysis(
+ const HloModule& module, bool ssa_form,
+ bool bitcast_defines_value = false,
+ const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
// Returns a new HloValue defined at the given instruction and shape index.
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
@@ -166,10 +185,11 @@ class HloDataflowAnalysis {
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
+ bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
- bool UpdateSelectValueSet(HloInstruction* select);
+ bool UpdateTupleSelectValueSet(HloInstruction* select);
bool UpdateSendValueSet(HloInstruction* send);
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
@@ -221,6 +241,10 @@ class HloDataflowAnalysis {
// The Id to use for the next HloValue.
HloValue::Id next_value_id_ = 0;
+
+ // Backend specific function that decides whether a fusion can share buffer
+ // with its operand.
+ FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 5798326dcb..37bc2d2c9d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -101,9 +101,9 @@ TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
// Test the dataflow for a simple binary operation (Add).
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, constant1, constant2));
module_->AddEntryComputation(builder.Build());
@@ -198,9 +198,9 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) {
// Verify the dataflow through a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto nested_tuple = builder.AddInstruction(
@@ -259,9 +259,9 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
module_->AddEntryComputation(builder.Build());
@@ -308,9 +308,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -362,9 +362,9 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call1 = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, called_computation));
auto call2 = builder.AddInstruction(HloInstruction::CreateCall(
@@ -426,9 +426,9 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto call = builder.AddInstruction(HloInstruction::CreateCall(
scalar_shape_, {constant1, constant2}, outer_computation));
module_->AddEntryComputation(builder.Build());
@@ -493,15 +493,15 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -594,15 +594,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while0 = builder.AddInstruction(
@@ -653,7 +653,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
@@ -691,9 +691,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto entry_while = builder.AddInstruction(
@@ -780,15 +780,15 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto xla_while = builder.AddInstruction(
@@ -840,11 +840,11 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) {
// Test a kSelect of an array value.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2));
@@ -860,19 +860,18 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) {
}
TEST_P(HloDataflowAnalysisTest, TupleSelect) {
- // Test a kSelect of a tuple value. Non-top-level element flow through the
- // instruction.
+ // Test a kTupleSelect. Non-top-level element flow through the instruction.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
@@ -883,20 +882,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({constant4}));
const Shape tuple_shape = tuple1->shape();
auto select11 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple1));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple1));
auto select12 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto select34 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple3, tuple4));
+ tuple_shape, HloOpcode::kTupleSelect, pred, tuple3, tuple4));
auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, select12, select34));
+ tuple_shape, HloOpcode::kTupleSelect, pred, select12, select34));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- // Top-level value is always defined by a kSelect.
+ // Top-level value is always defined by a kTupleSelect.
EXPECT_TRUE(analysis.ValueIsDefinedAt(select11));
EXPECT_TRUE(analysis.ValueIsDefinedAt(select12));
EXPECT_TRUE(analysis.ValueIsDefinedAt(select34));
@@ -937,20 +936,20 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) {
}
TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
- // Test kSelect of a nested tuple.
+ // Test kTupleSelect of a nested tuple.
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto constant4 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4.0)));
auto constant5 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(5.0)));
auto inner_tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant2, constant3}));
auto tuple1 = builder.AddInstruction(
@@ -960,7 +959,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant4, inner_tuple2}));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
module_->AddEntryComputation(builder.Build());
@@ -983,7 +982,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) {
}
TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
- // Test a tuple-shaped kSelect feeding a kWhile instruction. HLO:
+ // Test a tuple-shaped kTupleSelect feeding a kWhile instruction. HLO:
//
// body((F32[], F32[]) %tuple_param):
// %add = Add(%tuple_param{0}, %tuple_param{1})
@@ -1026,24 +1025,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple1 =
builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
auto tuple2 =
builder.AddInstruction(HloInstruction::CreateTuple({constant2}));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
auto gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, select, 0));
auto tuple =
@@ -1089,7 +1088,7 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
// Test the bitcast_defines_value flag to the dataflow analysis.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kBitcast, constant));
@@ -1158,44 +1157,50 @@ TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(param, /*channel_id=*/0));
+ HloInstruction::CreateSend(param, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 4);
+ EXPECT_EQ(analysis.values().size(), 6);
EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{2}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
}
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
- // Test that a RecvDone forwards its operand tuple element at {0} to the
- // output.
+ // Test that a RecvDone forwards its operand tuple element at {0} to element
+ // {0} of the output.
auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(
- HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
+ HloInstruction::CreateRecv(scalar_shape_, token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
module_->AddEntryComputation(builder.Build());
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- EXPECT_EQ(analysis.values().size(), 3);
+ EXPECT_EQ(analysis.values().size(), 7);
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
- EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
- EXPECT_THAT(HloValuesAt(recv_done),
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{2}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{0}));
+ EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done, /*index=*/{1}));
+ EXPECT_THAT(HloValuesAt(recv_done, /*index=*/{0}),
UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
EXPECT_TRUE(
analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
@@ -1304,13 +1309,13 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
auto constant = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto exp = body_builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, exp, body_param));
auto dead_constant = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, dead_constant));
HloComputation* body = module_->AddEmbeddedComputation(
@@ -1320,7 +1325,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
@@ -1571,11 +1576,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
scalar_shape_, pred, constant1, true_computation, constant2,
false_computation));
@@ -1662,11 +1667,11 @@ TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
auto builder = HloComputation::Builder(TestName());
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(56.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(12.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
auto tuple_operand = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
@@ -1792,15 +1797,15 @@ TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
// Build entry computation.
auto builder = HloComputation::Builder(TestName());
auto pred1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto pred2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.2f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.3f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
auto tuple_operand = builder.AddInstruction(
HloInstruction::CreateTuple({pred2, constant1, constant2}));
auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
@@ -1880,9 +1885,14 @@ class HloDataflowAnalysisTestBase : public HloTestBase {
computation_ = module_->AddEntryComputation(std::move(computation));
}
- void RunAnalysis() {
+ void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
+ fusion_can_share_buffer = nullptr) {
CHECK_NOTNULL(module_.get());
- dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
+ dataflow_analysis_ =
+ HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
+ /*bitcast_defines_value=*/false,
+ fusion_can_share_buffer)
+ .ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@@ -1933,9 +1943,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -1974,6 +1984,114 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ NonElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "param0"));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
+
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape, neg, {0, 1}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {reverse, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ MultiOutputFusionCanAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+
+ auto copy0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
+ auto copy1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {0}));
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {1}));
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {0}));
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ ElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {exp, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+ Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});
+
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "param0"));
+ auto index = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 0})));
+ auto ds = builder.AddInstruction(
+ HloInstruction::CreateDynamicSlice(slice_shape, param, index, {1, 2, 2}));
+
+ auto dus = builder.AddInstruction(
+ HloInstruction::CreateDynamicUpdateSlice(data_shape, param, ds, index));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {dus, ds, index}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
auto builder = HloComputation::Builder(TestName());
@@ -2026,9 +2144,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -2048,6 +2166,45 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ FusedDynamicUpdateSliceWithConvertCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ auto convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape_bf16, gte1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape_bf16, convert1, update, starts));
+
+ auto convert2 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {convert2, dynamic_update_slice, starts, update, convert1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());
@@ -2080,9 +2237,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
@@ -2091,7 +2248,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -2113,7 +2270,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -2121,7 +2278,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
@@ -2136,6 +2293,33 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape, HloOpcode::kMultiply, operand, operand));
+ auto two = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, two, mul}, HloInstruction::FusionKind::kInput);
+ RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion,
+ const HloInstruction*) {
+ return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop;
+ });
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
@@ -2186,7 +2370,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index fcd723af14..7d35e251ca 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -41,20 +41,13 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
XLA_VLOG_LINES(2, module->ToString());
for (auto* computation : module->MakeComputationPostOrder()) {
- std::unordered_set<HloInstruction*> live_instructions;
- TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
- [&live_instructions](HloInstruction* instruction) {
- live_instructions.insert(instruction);
- return Status::OK();
- }));
-
// Remove any dead roots and their dead transitive operands. Collect them
// into a separate list first to avoid problems with iterating through the
// computation's instruction while simultaneously removing instructions.
std::vector<HloInstruction*> dead_roots;
for (auto* instruction : computation->instructions()) {
- if (instruction->user_count() == 0 &&
- live_instructions.count(instruction) == 0 &&
+ if (instruction != computation->root_instruction() &&
+ instruction->user_count() == 0 &&
computation->IsRemovable(instruction) &&
!instruction->HasSideEffect()) {
dead_roots.push_back(instruction);
@@ -85,8 +78,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
// Remove dead computations.
- std::list<HloComputation*> computations = module->MakeComputationPostOrder();
- for (auto* computation : computations) {
+ for (auto* computation : module->MakeComputationPostOrder()) {
if (live_computations.count(computation) == 0) {
TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
changed = true;
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index 5a56607a66..26e3736e01 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -53,9 +53,9 @@ TEST_F(HloDceTest, NoDeadCode) {
// Verify that no dead code is removed from a computation with no dead code.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
builder.AddInstruction(HloInstruction::CreateBinary(
constant1->shape(), HloOpcode::kAdd, constant1, constant2));
@@ -74,20 +74,21 @@ TEST_F(HloDceTest, InstructionsWithSideEffect) {
// Verify that side-effect instructions (Send in this test) are not removed.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
builder.AddInstruction(
- HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
builder.AddInstruction(HloInstruction::CreateTuple({}));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
- EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(4, computation->instruction_count());
HloDCE dce;
EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
- EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(4, computation->instruction_count());
}
TEST_F(HloDceTest, DeadParameters) {
@@ -126,9 +127,9 @@ TEST_F(HloDceTest, ControlDependencies) {
// Verify that instructions with control dependencies are not removed.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
// Create two dead instructions: a negate and an add.
auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -223,7 +224,7 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
auto param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "cond_param"));
auto constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, constant));
}
@@ -234,9 +235,9 @@ TEST_F(HloDceTest, CalledComputationWithSideEffect) {
{
auto param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
-
- auto infeed =
- body_builder.AddInstruction(HloInstruction::CreateInfeed(shape, ""));
+ auto token = body_builder.AddInstruction(HloInstruction::CreateToken());
+ auto infeed = body_builder.AddInstruction(
+ HloInstruction::CreateInfeed(shape, token, ""));
body_builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, infeed));
}
@@ -278,8 +279,10 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) {
{
auto param = nested_callee_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param"));
+ auto token =
+ nested_callee_builder.AddInstruction(HloInstruction::CreateToken());
nested_callee_builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape, param, ""));
+ HloInstruction::CreateOutfeed(shape, param, token, ""));
}
auto nested_called_computation =
module->AddEmbeddedComputation(nested_callee_builder.Build());
@@ -342,12 +345,12 @@ TEST_F(HloDceTest, RemoveDeadSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
// Add another instruction as the root of the computation.
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
module->AddEntryComputation(builder.Build());
EXPECT_EQ(module->MakeComputationPostOrder().size(), 2);
@@ -383,7 +386,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
// Add another instruction as the root of the computation that also uses
@@ -393,7 +396,7 @@ TEST_F(HloDceTest, KeepUsedSubcomputation) {
builder.AddInstruction(HloInstruction::CreateParameter(
/*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")),
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))),
/*dimensions_to_reduce=*/{0}, reduce_subcomp));
module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index e0c5718509..eded3e78ee 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -26,10 +26,10 @@ limitations under the License.
namespace xla {
// Domain isolation is the task of placing kDomain instructions between HLO
-// instructions having different shrading. A kDomain instruction is essentially
+// instructions having different sharding. A kDomain instruction is essentially
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
-// kDomain instruciton will be placed.
+// kDomain instruction will be placed.
class HloDomainIsolator : public HloPassInterface {
public:
// Creates a new kDomain instruction for the edge between the use instruction
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index ebd5adb5d5..9e096320db 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -41,11 +41,15 @@ namespace xla {
bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
HloInstruction* instruction2) const {
- int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1);
- int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1);
+ int64 domain_id1 = GetDomainId(instruction1);
+ int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
+int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+ return FindOrDefault(instruction_to_domain_, instruction, -1);
+}
+
Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
// We only check operands, so we are sure to not process the empty domain from
@@ -58,6 +62,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ auto domain = MakeUnique<DomainMetadata::Domain>();
+ domain->enter_domains.insert(instruction);
+ TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index e62ef763fb..1ca7159725 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -65,6 +65,10 @@ class HloDomainMap {
// currently processing.
bool IsDomainInstruction(HloInstruction* instruction) const;
+ // Retrieves the domain identifier of the instruction, or -1 in case
+ // instruction is not found within any domain.
+ int64 GetDomainId(HloInstruction* instruction) const;
+
private:
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
index 1d06040b0e..e2e820002b 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -43,46 +43,8 @@ class HloDomainRemover::RunContext {
Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain(
const DomainMetadata::Domain& domain) {
- // Verify that the whole kDomain frontier bounding the instruction reach set,
- // has matching metadata.
- // A kDomain instruction has two sides of metadata, a user facing and an
- // operand facing.
- // A reachable instruction set can make contact with a kDomain instruction on
- // a user facing side (the kDomain is operand of the instruction), or on a
- // operand facing side (the kDomain is user of the instruction).
- // And depending on the contact side, the proper metadata object
- // (user_side_metadata() vs. operand_side_metadata()) needs to be used for
- // consistency checks.
- const DomainMetadata* ref_metadata = nullptr;
- VLOG(4) << "Reach set:";
- for (HloInstruction* instruction : domain.instructions) {
- VLOG(4) << " " << instruction->name();
- }
- VLOG(4) << " Domains:";
- for (HloInstruction* instruction : domain.enter_domains) {
- const DomainMetadata& meta = instruction->user_side_metadata();
- VLOG(4) << " User side: " << instruction->name();
- VLOG(4) << " " << meta.ToString();
- if (ref_metadata == nullptr) {
- ref_metadata = &meta;
- } else {
- TF_RET_CHECK(meta.Matches(*ref_metadata))
- << "Metadata mismatch at instruction " << instruction->name() << " : "
- << meta.ToString() << " vs " << ref_metadata->ToString();
- }
- }
- for (HloInstruction* instruction : domain.exit_domains) {
- const DomainMetadata& meta = instruction->operand_side_metadata();
- VLOG(4) << " Operand side: " << instruction->name();
- VLOG(4) << " " << meta.ToString();
- if (ref_metadata == nullptr) {
- ref_metadata = &meta;
- } else {
- TF_RET_CHECK(meta.Matches(*ref_metadata))
- << "Metadata mismatch at instruction " << instruction->name() << " : "
- << meta.ToString() << " vs " << ref_metadata->ToString();
- }
- }
+ TF_ASSIGN_OR_RETURN(const DomainMetadata* ref_metadata,
+ HloDomainVerifier::VerifyDomain(domain));
if (ref_metadata != nullptr) {
VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString();
TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain));
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index f29aac29c0..00b2c860a7 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -17,16 +17,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
-class HloDomainTest : public HloTestBase {
+class HloDomainTest : public HloVerifiedTestBase {
protected:
bool FindUserViaDomainPath(HloInstruction* instruction,
HloInstruction* operand) const {
@@ -64,11 +65,11 @@ class HloDomainTest : public HloTestBase {
return false;
}
- StatusOr<std::unique_ptr<HloModule>> ParseModule(
- tensorflow::StringPiece hlo_string) {
+ StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return tools::Parse(hlo_string, config);
+ ParseAndVerifyModule(hlo_string, config);
+ return &module();
}
};
@@ -143,32 +144,31 @@ ENTRY entry {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator isolator(CreateShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
- EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover remover(ShardingMetadata::KindName(),
NormalizeShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
- EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
}
TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) {
@@ -186,12 +186,11 @@ ENTRY entry {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator isolator(CreateShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(!isolator_changed);
}
@@ -202,37 +201,38 @@ HloModule Module
ENTRY entry {
p0 = (f32[4]) parameter(0)
a = f32[4] get-tuple-element(p0), index=0
- b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0}
- c = () send-done(b), channel_id=1, sharding={maximal device=0}
- d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0}
- e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0}
- f = f32[4] add(a, e)
- g = f32[4] subtract(a, e)
+ token = token[] after-all()
+ b = (f32[4], u32[], token[]) send(a, token), channel_id=1, sharding={maximal device=0}
+ c = token[] send-done(b), channel_id=1, sharding={maximal device=0}
+ d = (f32[4], u32[], token[]) recv(token), channel_id=2, sharding={maximal device=0}
+ e = (f32[4], token[]) recv-done(d), channel_id=2, sharding={maximal device=0}
+ e_element = f32[4] get-tuple-element(e), index=0, sharding={maximal device=0}
+ f = f32[4] add(a, e_element)
+ g = f32[4] subtract(a, e_element)
ROOT h = (f32[4], f32[4]) tuple(f, g)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator isolator(CreateShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
- EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+ EXPECT_TRUE(HasDomainEdge(module, "b", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "f", "e_element"));
+ EXPECT_FALSE(HasDomainEdge(module, "a", "p0"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover remover(ShardingMetadata::KindName(),
NormalizeShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
- EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e"));
+ EXPECT_FALSE(HasDomainEdge(module, "b", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "f", "e_element"));
}
TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
@@ -240,20 +240,21 @@ TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) {
HloModule Module
ENTRY entry {
- a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1}
- b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1}
- c = f32[4] add(b, b), sharding={maximal device=-1}
- d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1}
- ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1}
+ token = token[] after-all(), sharding={maximal device=-1}
+ a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=-1}
+ b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=-1}
+ b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=-1}
+ c = f32[4] add(b_element, b_element), sharding={maximal device=-1}
+ d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=-1}
+ ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=-1}
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator isolator(CreateShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_FALSE(isolator_changed);
}
@@ -262,24 +263,25 @@ TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) {
HloModule Module
ENTRY entry {
- a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0}
- b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0}
- c = f32[4] add(b, b)
- d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0}
- ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0}
+ token = token[] after-all(), sharding={maximal device=0}
+ a = (f32[4], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=0}
+ b = (f32[4], token[]) recv-done(a), channel_id=1, sharding={maximal device=0}
+ b_element = f32[4] get-tuple-element(b), index=0, sharding={maximal device=0}
+ c = f32[4] add(b_element, b_element)
+ d = (f32[4], u32[], token[]) send(c, token), channel_id=2, sharding={maximal device=0}
+ ROOT e = token[] send-done(d), channel_id=2, sharding={maximal device=0}
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainRemover remover(ShardingMetadata::KindName(),
NormalizeShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_FALSE(remover_changed);
- HloInstruction* add = FindInstruction(module.get(), "c");
+ HloInstruction* add = FindInstruction(module, "c");
ASSERT_NE(add, nullptr);
auto device = add->sharding_unique_device();
EXPECT_TRUE(device.has_value());
@@ -302,42 +304,41 @@ ENTRY entry {
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator sharding_isolator(CreateShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
- sharding_isolator.Run(module.get()));
+ sharding_isolator.Run(module));
EXPECT_TRUE(sharding_isolator_changed);
HloDomainIsolator opname_isolator(OpNameDomainCreator);
TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
- opname_isolator.Run(module.get()));
+ opname_isolator.Run(module));
EXPECT_TRUE(opname_isolator_changed);
- EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "c", "b"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "a"));
+ EXPECT_TRUE(HasDomainEdge(module, "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "e", "d"));
HloDomainRemover sharding_remover(ShardingMetadata::KindName(),
NormalizeShardingDomain);
TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed,
- sharding_remover.Run(module.get()));
+ sharding_remover.Run(module));
EXPECT_TRUE(sharding_remover_changed);
HloDomainRemover opname_remover(OpNameMetadata::KindName(),
OpNameDomainNormalizer);
TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed,
- opname_remover.Run(module.get()));
+ opname_remover.Run(module));
EXPECT_TRUE(opname_remover_changed);
- EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "c", "b"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "a"));
+ EXPECT_FALSE(HasDomainEdge(module, "d", "c"));
}
TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
@@ -345,33 +346,35 @@ TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) {
HloModule Module
ENTRY entry {
- infeed = (f32[4], f32[4]) infeed(),
- sharding={{maximal device=1}, {maximal device=0}}
- gte0 = f32[4] get-tuple-element(infeed), index=0
- gte1 = f32[4] get-tuple-element(infeed), index=1
+ token = token[] after-all()
+ infeed = ((f32[4], f32[4]), token[]) infeed(token),
+ sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}}
+ infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0
+ gte0 = f32[4] get-tuple-element(infeed.data), index=0
+ gte1 = f32[4] get-tuple-element(infeed.data), index=1
copy0 = f32[4] copy(gte0)
copy1 = f32[4] copy(gte1)
ROOT add = f32[4] add(copy0, copy1)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseModule(hlo_string));
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
LOG(INFO) << "Original module:\n" << module->ToString();
HloDomainIsolator isolator(CreateShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
EXPECT_TRUE(isolator_changed);
- EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed"));
- EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0"));
- EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1"));
+ EXPECT_TRUE(HasDomainEdge(module, "infeed.data", "infeed"));
+ EXPECT_FALSE(HasDomainEdge(module, "copy0", "gte0"));
+ EXPECT_FALSE(HasDomainEdge(module, "copy1", "gte1"));
// Inject unassigned tuple/gte within the infeed domain, to simulate the
// HLO passes adding unexpected instructions.
//
// infeed
+ // |
+ // infeed.data (tuple element 0 of infeed)
// / \
// GTE0 GTE1
// / \
@@ -380,31 +383,36 @@ ENTRY entry {
// \ /
// TUPLE
// |
- // DOMAIN
- HloInstruction* infeed = FindInstruction(module.get(), "infeed");
+ HloInstruction* infeed = FindInstruction(module, "infeed");
ASSERT_NE(infeed, nullptr);
- auto infeed_users = infeed->users();
- HloInstruction* new_gte0 =
+ HloInstruction* infeed_data =
infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0));
+
+ auto infeed_data_users = infeed_data->users();
+ HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed_data->shape(), 0), infeed_data,
+ 0));
HloInstruction* new_copy0 =
- infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
new_gte0->shape(), HloOpcode::kCopy, new_gte0));
- HloInstruction* new_gte1 =
- infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1));
+ HloInstruction* new_gte1 = infeed_data->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(infeed_data->shape(), 1), infeed_data,
+ 1));
HloInstruction* new_copy1 =
- infeed->parent()->AddInstruction(HloInstruction::CreateUnary(
+ infeed_data->parent()->AddInstruction(HloInstruction::CreateUnary(
new_gte1->shape(), HloOpcode::kCopy, new_gte1));
- HloInstruction* new_tuple = infeed->parent()->AddInstruction(
+ HloInstruction* new_tuple = infeed_data->parent()->AddInstruction(
HloInstruction::CreateTuple({new_copy0, new_copy1}));
- for (HloInstruction* user : infeed_users) {
- TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple));
+ for (HloInstruction* user : infeed_data_users) {
+ TF_EXPECT_OK(infeed_data->ReplaceUseWith(user, new_tuple));
}
HloDomainRemover remover(ShardingMetadata::KindName(),
NormalizeShardingDomain);
- TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
EXPECT_TRUE(remover_changed);
struct Assignment {
@@ -418,7 +426,7 @@ ENTRY entry {
};
for (auto& assignment : assignments) {
auto device = assignment.instruction->sharding_unique_device();
- EXPECT_TRUE(device.has_value());
+ ASSERT_TRUE(device.has_value());
EXPECT_EQ(*device, assignment.device);
}
EXPECT_TRUE(new_tuple->has_sharding());
@@ -428,5 +436,64 @@ ENTRY entry {
HloSharding::AssignDevice(0)}));
}
+TEST_F(HloDomainTest, EmptyRootDomain) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+ %param = f32[1] parameter(0), sharding={maximal device=0}
+ %tuple = (f32[1]) tuple(%param),
+ sharding={maximal device=1}
+ ROOT %gte = f32[1] get-tuple-element(%tuple), index=0,
+ sharding={maximal device=1}
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
+
+ HloDomainIsolator isolator(CreateShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
+ EXPECT_TRUE(isolator_changed);
+
+ EXPECT_TRUE(HasDomainEdge(module, "tuple", "param"));
+ EXPECT_FALSE(HasDomainEdge(module, "gte", "tuple"));
+
+ // Remove %tuple and %gte (tuple simplification)
+ HloInstruction* gte = FindInstruction(module, "gte");
+ HloInstruction* tuple = FindInstruction(module, "tuple");
+ module->entry_computation()->set_root_instruction(tuple->mutable_operand(0));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(gte));
+ TF_EXPECT_OK(module->entry_computation()->RemoveInstruction(tuple));
+
+ HloDomainRemover remover(ShardingMetadata::KindName(),
+ NormalizeShardingDomain);
+ TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module));
+ EXPECT_TRUE(remover_changed);
+
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_TRUE(root->has_sharding());
+ EXPECT_EQ(root->sharding(), HloSharding::AssignDevice(1));
+}
+
+// Tests that text dumps of domain instructions can be parsed back, in the
+// specific case of null shardings.
+TEST_F(HloDomainTest, DumpParseNullSharding) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {});
+ auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr);
+ auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr);
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));
+ HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain(
+ shape, param, std::move(sharding_md_0), std::move(sharding_md_1)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, domain, domain));
+
+ auto module = CreateNewModule();
+ module->AddEntryComputation(builder.Build());
+
+ auto hlo_string = module->ToString();
+ ASSERT_TRUE(ParseModule(hlo_string).status().ok());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
new file mode 100644
index 0000000000..751fc677e2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_domain_verifier.h"
+
+#include <set>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+class HloDomainVerifier::RunContext {
+ public:
+ RunContext(HloModule* module, HloDomainVerifier* verifier)
+ : module_(module), verifier_(verifier) {}
+
+ Status Run();
+
+ private:
+ // If the verifier caller passed an empty vector for kinds, we collect all the
+ // avalable domain types.
+ Status PopulateDomainKinds();
+
+ HloModule* module_;
+ HloDomainVerifier* verifier_;
+};
+
+Status HloDomainVerifier::RunContext::PopulateDomainKinds() {
+ if (verifier_->kinds_.empty()) {
+ // The caller specified no domain kinds, collect all the ones available.
+ std::set<string> kinds;
+ for (HloComputation* computation : module_->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kDomain) {
+ TF_RET_CHECK(instruction->user_side_metadata().Kind() ==
+ instruction->operand_side_metadata().Kind())
+ << instruction->ToString();
+ kinds.insert(instruction->user_side_metadata().Kind().ToString());
+ }
+ }
+ }
+ verifier_->kinds_.insert(verifier_->kinds_.end(), kinds.begin(),
+ kinds.end());
+ }
+ return Status::OK();
+}
+
+Status HloDomainVerifier::RunContext::Run() {
+ VLOG(4) << "Running HLO Domain Verifier";
+ TF_RETURN_IF_ERROR(PopulateDomainKinds());
+ for (HloComputation* computation : module_->computations()) {
+ for (auto& kind : verifier_->kinds_) {
+ // First create the domain instruciton sets. A domain instruction set is
+ // the set of instructions whose edges never cross a kDomain instruction.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDomainMap> domain_map,
+ HloDomainMap::Create(computation, kind));
+ // Verify every domain populated within the map.
+ for (auto& domain : domain_map->GetDomains()) {
+ TF_RETURN_IF_ERROR(VerifyDomain(*domain).status());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<bool> HloDomainVerifier::Run(HloModule* module) {
+ RunContext run_context(module, this);
+ TF_RETURN_IF_ERROR(run_context.Run());
+ return false;
+}
+
+StatusOr<const DomainMetadata*> HloDomainVerifier::VerifyDomain(
+ const DomainMetadata::Domain& domain) {
+ const DomainMetadata* ref_metadata = nullptr;
+ VLOG(4) << "Reach set:";
+ for (HloInstruction* instruction : domain.instructions) {
+ VLOG(4) << " " << instruction->name();
+ }
+ VLOG(4) << " Domains:";
+ for (HloInstruction* instruction : domain.enter_domains) {
+ const DomainMetadata& meta = instruction->user_side_metadata();
+ VLOG(4) << " User side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ for (HloInstruction* instruction : domain.exit_domains) {
+ const DomainMetadata& meta = instruction->operand_side_metadata();
+ VLOG(4) << " Operand side: " << instruction->name();
+ VLOG(4) << " " << meta.ToString();
+ if (ref_metadata == nullptr) {
+ ref_metadata = &meta;
+ } else {
+ TF_RET_CHECK(meta.Matches(*ref_metadata))
+ << "Metadata mismatch at instruction " << instruction->name() << " : "
+ << meta.ToString() << " vs " << ref_metadata->ToString();
+ }
+ }
+ return ref_metadata;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
new file mode 100644
index 0000000000..8e53cf97f8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_domain_map.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// Verifies that the domain instructions are consistent, and the each domain is
+// surrounded by the same metadata.
+class HloDomainVerifier : public HloPassInterface {
+ public:
+ HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
+
+ tensorflow::StringPiece name() const override { return "domain_verifier"; }
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ // Verify that the whole kDomain frontier bounding the instruction reach set,
+ // has matching metadata.
+ // A kDomain instruction has two sides of metadata, a user facing and an
+ // operand facing.
+ // A reachable instruction set can make contact with a kDomain instruction on
+ // a user facing side (the kDomain is operand of the instruction), or on a
+ // operand facing side (the kDomain is user of the instruction).
+ // And depending on the contact side, the proper metadata object
+ // (user_side_metadata() vs. operand_side_metadata()) needs to be used for
+ // consistency checks.
+ // Returns the DomainMetadata pointer which surrounds the domain, and
+ // represents the common metadata within such domain. If the returned
+ // DomainMetadata pointer is nullptr, the input domain had no kDomain
+ // boundary.
+ static StatusOr<const DomainMetadata*> VerifyDomain(
+ const DomainMetadata::Domain& domain);
+
+ private:
+ class RunContext;
+
+ std::vector<string> kinds_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_VERIFIER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index abec29df43..c804f4364f 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
@@ -141,6 +141,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
// These are ops with embedded computations where it suffices to convert
// the embedded computations instead of converting the ops themselves.
if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
+ opcode == HloOpcode::kCrossReplicaSum ||
opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
opcode == HloOpcode::kSelectAndScatter ||
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
index 5c5a059e0f..c170e36c73 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
@@ -57,8 +57,10 @@ TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) {
const string& hlo_string = R"(
HloModule InfeedOutfeed
ENTRY RoundTrip16MiBR1.v2 {
- ROOT infeed = bf16[4]{0} infeed()
- outfeed = () outfeed(infeed)
+ token = token[] after-all()
+ infeed = (bf16[4]{0}, token[]) infeed(token)
+ ROOT infeed.data = bf16[4]{0} get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
}
)";
auto module = CreateModuleFromHloString(hlo_string);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 1e78d775c8..dfdfeb49a2 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -135,7 +136,6 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>(
} // namespace
-
HloEvaluator::HloEvaluator(int64 max_loop_iterations)
: max_loop_iterations_(max_loop_iterations) {
typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this);
@@ -300,12 +300,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateWithSubstitutions(
instruction->CloneWithNewOperands(instruction->shape(), operands);
auto result = Evaluate(cloned_instruction.get());
- // Clean up our cloned instructions before returning.
- cloned_instruction->DetachFromOperands();
- for (auto& operand : owned_operands) {
- operand->DetachFromOperands();
- }
-
return result;
}
@@ -321,7 +315,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseBinaryOp(
rhs_instr.get());
auto result = Evaluate(cloned_instruction.get());
- cloned_instruction->DetachFromOperands();
return result;
}
@@ -334,10 +327,27 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateElementwiseUnaryOp(
HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get());
auto result = Evaluate(cloned_instruction.get());
- cloned_instruction->DetachFromOperands();
return result;
}
+StatusOr<std::unique_ptr<Literal>> HloEvaluator::EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs) {
+ std::unique_ptr<HloInstruction> lhs_instr =
+ HloInstruction::CreateConstant(lhs.CloneToUnique());
+ std::unique_ptr<HloInstruction> rhs_instr =
+ HloInstruction::CreateConstant(rhs.CloneToUnique());
+
+ TF_ASSIGN_OR_RETURN(
+ Shape dot_shape,
+ ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
+
+ std::unique_ptr<HloInstruction> cloned_instruction =
+ HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(),
+ dim_numbers);
+ return Evaluate(cloned_instruction.get());
+}
+
Status HloEvaluator::HandleParameter(HloInstruction* parameter) {
CHECK_LT(parameter->parameter_number(), arg_literals_.size());
const Literal* input_literal = arg_literals_[parameter->parameter_number()];
@@ -372,7 +382,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
- CHECK(!ShapeUtil::IsTuple(reference_shape));
+ CHECK(ShapeUtil::IsArray(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
@@ -383,14 +393,14 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
- CHECK(!ShapeUtil::IsTuple(operand_shape));
+ CHECK(ShapeUtil::IsArray(operand_shape));
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
ShapeUtil::GetDimension(operand_shape, concat_dim);
}
- auto result_literal = Literal::CreateFromDimensions(
+ auto result_literal = LiteralUtil::CreateFromDimensions(
reference_shape.element_type(), concat_dimensions);
DimensionVector source_indices(rank, 0);
DimensionVector dest_indices(concat_dimensions.size(), 0);
@@ -541,7 +551,7 @@ Status HloEvaluator::HandleTuple(HloInstruction* tuple) {
operand_literals.push_back(&GetEvaluatedLiteralFor(operand));
}
- evaluated_[tuple] = Literal::MakeTuple(operand_literals);
+ evaluated_[tuple] = LiteralUtil::MakeTuple(operand_literals);
return Status::OK();
}
@@ -765,6 +775,12 @@ class OutputWindowIndexToInputIndex {
return ArraySlice<int64>(input_index_);
}
+ // Returns for a given 'input_dim' the corresponding output dimension index,
+ // or -1 if 'input_dim' is an elided window dimension.
+ int64 input_dim_value_to_output_index(int64 input_dim) {
+ return input_dim_value_to_output_index_[input_dim];
+ }
+
private:
// Propagates window dimensions from the output index to input_index_ by
// mutating input_index_ in place.
@@ -782,7 +798,7 @@ class OutputWindowIndexToInputIndex {
// input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
// the input index from the output index. See
- // PropagateOutputIndexToInputIndex.
+ // PropagateOutputIndexWindowDimsToInputIndex.
std::vector<int64> input_dim_value_to_output_index_;
// The result computed by this functor. operator() returns an ArraySlice into
@@ -835,6 +851,8 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
// corresponding index in the input shape.
std::vector<int64> input_index(operand.shape().dimensions_size());
std::vector<int64> output_index(gather->shape().dimensions_size());
+ std::vector<int64> input_gather_index_clamped(
+ operand.shape().dimensions_size());
OutputGatherIndexToInputIndex output_gather_index_to_input_index(
&gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
@@ -856,14 +874,26 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) {
output_index[i] = output_gather_index[i] + output_window_index[i];
DCHECK_LT(output_index[i], shape.dimensions(i));
}
+ for (int i = 0, e = input_gather_index.size(); i < e; i++) {
+ int64 output_dim =
+ output_window_index_to_input_index.input_dim_value_to_output_index(i);
+ // If 'output_dim' is -1, it means 'i' is an elided window dim. This means
+ // we set the iteration index to 0, so for the purpose of the following
+ // calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : shape.dimensions(output_dim);
+ // Clamp the gather index so that the gather region fits in the operand.
+ // input_gather_index_clamped[i] = clamp(input_gather_index[i], 0,
+ // operand_shape.dimensions(i) -
+ // output_dim_size);
+ input_gather_index_clamped[i] =
+ std::min(operand_shape.dimensions(i) - output_dim_size,
+ std::max(0LL, input_gather_index[i]));
+ }
for (int i = 0, e = input_index.size(); i < e; i++) {
- // TODO(b/74360564): We should implement whatever out of bounds behavior
- // we decide for dynamic-slice here as well.
- input_index[i] = (input_gather_index[i] + input_window_index[i]) %
- operand_shape.dimensions(i);
- if (input_index[i] < 0) {
- input_index[i] += operand_shape.dimensions(i);
- }
+ input_index[i] = input_gather_index_clamped[i] + input_window_index[i];
+ DCHECK_GE(input_index[i], 0);
+ DCHECK_LT(input_index[i], operand_shape.dimensions(i));
}
TF_RETURN_IF_ERROR(
result->CopyElementFrom(operand, input_index, output_index));
@@ -910,6 +940,11 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+Status HloEvaluator::HandleAfterAll(HloInstruction* token) {
+ evaluated_[token] = LiteralUtil::CreateToken();
+ return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
@@ -1027,8 +1062,6 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
const auto& on_false = GetEvaluatedLiteralFor(select->operand(2));
// If predicate is of scalar type, no element-wise selection would be needed.
- // This would also handle output array of tuple types as the DefaultAction
- // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples.
if (ShapeUtil::IsScalar(pred.shape())) {
if (pred.Get<bool>({})) {
evaluated_[select] = on_true.CloneToUnique();
@@ -1041,6 +1074,19 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) {
return DefaultAction(select);
}
+Status HloEvaluator::HandleTupleSelect(HloInstruction* tuple_select) {
+ const auto& pred = GetEvaluatedLiteralFor(tuple_select->operand(0));
+ const auto& on_true = GetEvaluatedLiteralFor(tuple_select->operand(1));
+ const auto& on_false = GetEvaluatedLiteralFor(tuple_select->operand(2));
+
+ if (pred.Get<bool>({})) {
+ evaluated_[tuple_select] = on_true.CloneToUnique();
+ } else {
+ evaluated_[tuple_select] = on_false.CloneToUnique();
+ }
+ return Status::OK();
+}
+
Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
HloComputation* cond_comp = while_hlo->while_condition();
HloComputation* body_comp = while_hlo->while_body();
@@ -1071,6 +1117,107 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) {
return Status::OK();
}
+// Key-value sort is a special snowflake: it's templated on two different
+// element types, one for the keys, and one for the values. Jump through some
+// hoops to make this work.
+namespace {
+template <typename KeyType, typename ValueType>
+std::unique_ptr<Literal> EvaluateSortInternal(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
+ CHECK_EQ(sort->operand_count(), 2);
+ // We need to sort and array of keys and an array of values, where the
+ // sorted order of the values is determined by the keys. The simplest(?)
+ // way to do this is to go to an array-of-pairs representation, sort the
+ // array using the keys, and then go back to pair-of-arrays.
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
+ const auto& keys_data = keys_literal.data<KeyType>();
+ const auto& values_data = values_literal.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ CHECK_EQ(keys_data.size(), values_data.size());
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ auto result_keys_literal = MakeUnique<Literal>(sort->operand(0)->shape());
+ result_keys_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<KeyType>(result_keys));
+ auto result_values_literal = MakeUnique<Literal>(sort->operand(1)->shape());
+ result_values_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ValueType>(result_values));
+ auto result_tuple = LiteralUtil::MakeTuple(
+ {result_keys_literal.get(), result_values_literal.get()});
+ VLOG(3) << "HandleSort result_tuple: " << result_tuple->ToString();
+ return result_tuple;
+}
+
+template <typename KeyType>
+StatusOr<std::unique_ptr<Literal>> EvaluateSortCurried(
+ HloInstruction* sort, const Literal& keys_literal,
+ const Literal& values_literal) {
+ switch (sort->operand(1)->shape().element_type()) {
+ case F32:
+ return EvaluateSortInternal<KeyType, float>(sort, keys_literal,
+ values_literal);
+ case U32:
+ return EvaluateSortInternal<KeyType, uint32>(sort, keys_literal,
+ values_literal);
+ case S32:
+ return EvaluateSortInternal<KeyType, int32>(sort, keys_literal,
+ values_literal);
+ case BF16:
+ return EvaluateSortInternal<KeyType, bfloat16>(sort, keys_literal,
+ values_literal);
+ default:
+ return InvalidArgument("Unsupported type for Sort");
+ }
+}
+
+StatusOr<std::unique_ptr<Literal>> EvaluateSort(HloInstruction* sort,
+ const Literal& keys_literal,
+ const Literal& values_literal) {
+ switch (sort->operand(0)->shape().element_type()) {
+ case F32:
+ return EvaluateSortCurried<float>(sort, keys_literal, values_literal);
+ case U32:
+ return EvaluateSortCurried<uint32>(sort, keys_literal, values_literal);
+ case S32:
+ return EvaluateSortCurried<int32>(sort, keys_literal, values_literal);
+ case BF16:
+ return EvaluateSortCurried<bfloat16>(sort, keys_literal, values_literal);
+ default:
+ return InvalidArgument("Unsupported type for Sort");
+ }
+}
+} // namespace
+
+Status HloEvaluator::HandleSort(HloInstruction* sort) {
+ if (!ShapeUtil::IsTuple(sort->shape())) {
+ return DefaultAction(sort);
+ } else {
+ auto result = EvaluateSort(sort, GetEvaluatedLiteralFor(sort->operand(0)),
+ GetEvaluatedLiteralFor(sort->operand(1)));
+ if (result.ok()) {
+ evaluated_[sort] = std::move(result.ValueOrDie());
+ return Status::OK();
+ } else {
+ return result.status();
+ }
+ }
+}
+
Status HloEvaluator::Preprocess(HloInstruction* hlo) {
VLOG(2) << "About to visit HLO: " << hlo->ToString();
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index b53d5644de..a4c37ef328 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -115,6 +116,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
StatusOr<std::unique_ptr<Literal>> EvaluateElementwiseUnaryOp(
HloOpcode opcode, const Literal& operand);
+ StatusOr<std::unique_ptr<Literal>> EvaluateDotOp(
+ const DotDimensionNumbers& dim_numbers, const Literal& lhs,
+ const Literal& rhs);
+
protected:
// Make HloEvaluatorTypedVisitor a friend because it is logically part of this
// class.
@@ -172,8 +177,14 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
+
Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleAfterAll(HloInstruction* token) override;
+
+ Status HandleSort(HloInstruction* sort) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 84b4ead2dd..5f575b24a1 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
@@ -112,9 +112,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
// Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
// with 3 operands.
TEST_P(HloEvaluatorTest, DoesClamp) {
- auto low = Literal::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
- auto value = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
- auto high = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
+ auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
+ auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
+ auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
Shape shape = low->shape();
HloComputation::Builder b(TestName());
@@ -127,15 +127,15 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
- auto low = Literal::CreateR0<float>(0.f);
- auto value = Literal::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
- auto high = Literal::CreateR0<float>(1.f);
+ auto low = LiteralUtil::CreateR0<float>(0.f);
+ auto value = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
+ auto high = LiteralUtil::CreateR0<float>(1.f);
Shape shape = value->shape();
HloComputation::Builder b(TestName());
@@ -148,7 +148,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -156,9 +156,9 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
// with 3 operands.
TEST_P(HloEvaluatorTest, DoesSelect) {
- auto pred = Literal::CreateR2<bool>({{true, false}, {false, true}});
- auto on_true = Literal::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
- auto on_false = Literal::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
+ auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
+ auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
+ auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
Shape shape = on_true->shape();
HloComputation::Builder b(TestName());
@@ -173,7 +173,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
std::unique_ptr<Literal> result = Evaluate({});
- auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -181,37 +181,46 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise addition with 2 operands.
TEST_P(HloEvaluatorTest, DoesAdd) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{3, 4}, {-96, 8}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-96, 8}});
TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise and with 2 operands.
TEST_P(HloEvaluatorTest, DoesAnd) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{0, 0}, {4, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {4, 4}});
TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise or with 2 operands.
TEST_P(HloEvaluatorTest, DoesOr) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{3, 4}, {-100, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-100, 4}});
TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
std::move(rhs));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
+// element-wise or with 2 operands.
+TEST_P(HloEvaluatorTest, DoesXor) {
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{3, 4}, {-104, 0}});
+ TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
+ std::move(rhs));
+}
+// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise multiply with 2 operands.
TEST_P(HloEvaluatorTest, DoesMultiply) {
- auto lhs = Literal::CreateR2<int32>({{-1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int32>(
+ auto lhs = LiteralUtil::CreateR2<int32>({{-1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 4}, {4, 4}});
- auto expected = Literal::CreateR2<int32>(
+ auto expected = LiteralUtil::CreateR2<int32>(
{{std::numeric_limits<int32>::min(), 0}, {-400, 16}});
TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
std::move(rhs));
@@ -219,17 +228,17 @@ TEST_P(HloEvaluatorTest, DoesMultiply) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise divide with 2 operands.
TEST_P(HloEvaluatorTest, DoesDivideInt64) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto expected = Literal::CreateR2<int64>({{0, 0}, {-25, 1}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{0, 0}, {-25, 1}});
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
TEST_P(HloEvaluatorTest, DoesDivideDouble) {
- auto lhs = Literal::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
- auto rhs = Literal::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
+ auto lhs = LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
+ auto rhs = LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
auto expected =
- Literal::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
+ LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
std::move(rhs));
}
@@ -237,54 +246,54 @@ TEST_P(HloEvaluatorTest, DoesDivideDouble) {
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise abs op with 1 operand.
TEST_P(HloEvaluatorTest, DoesAbsR2) {
- auto operand = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
- auto expected = Literal::CreateR2<int64>({{1, 20}, {100, 4}});
+ auto operand = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
+ auto expected = LiteralUtil::CreateR2<int64>({{1, 20}, {100, 4}});
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesAbsR0) {
- auto operand = Literal::CreateR0<float>(-1.0f);
- auto expected = Literal::CreateR0<float>(1.0f);
+ auto operand = LiteralUtil::CreateR0<float>(-1.0f);
+ auto expected = LiteralUtil::CreateR0<float>(1.0f);
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) {
- auto operand = Literal::CreateR1<float>({});
- auto expected = Literal::CreateR1<float>({});
+ auto operand = LiteralUtil::CreateR1<float>({});
+ auto expected = LiteralUtil::CreateR1<float>({});
TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesNegateR2) {
- auto operand = Literal::CreateR2<int32>(
+ auto operand = LiteralUtil::CreateR2<int32>(
{{0, std::numeric_limits<int32>::min()}, {-1, 4}});
- auto expected =
- Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()}, {1, -4}});
+ auto expected = LiteralUtil::CreateR2<int32>(
+ {{0, std::numeric_limits<int>::min()}, {1, -4}});
TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
TEST_P(HloEvaluatorTest, DoesCosR2) {
- auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
- auto expected = Literal::CreateR2<float>({{1, -1}, {-1, 1}});
+ auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
+ auto expected = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand),
use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesSinR2) {
- auto operand = Literal::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
- auto expected = Literal::CreateR2<float>({{0, 0}, {0, 0}});
+ auto operand = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
+ auto expected = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand),
use_bfloat16_ ? 0.031250 : 9.5367431640625E-7);
}
TEST_P(HloEvaluatorTest, DoesNotR2) {
auto operand =
- Literal::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
- {-1, std::numeric_limits<int>::max()}});
+ LiteralUtil::CreateR2<int32>({{0, std::numeric_limits<int>::min()},
+ {-1, std::numeric_limits<int>::max()}});
auto expected =
- Literal::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
- {0, std::numeric_limits<int>::min()}});
+ LiteralUtil::CreateR2<int32>({{-1, std::numeric_limits<int>::max()},
+ {0, std::numeric_limits<int>::min()}});
TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
// Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
// constant operands.
TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
- auto lhs = Literal::CreateR2<int64>({{1, 0}, {-100, 4}});
- auto rhs = Literal::CreateR2<int64>({{2, 4}, {4, 4}});
- auto rhs2 = Literal::CreateR2<int64>({{1, -20}, {-100, 4}});
+ auto lhs = LiteralUtil::CreateR2<int64>({{1, 0}, {-100, 4}});
+ auto rhs = LiteralUtil::CreateR2<int64>({{2, 4}, {4, 4}});
+ auto rhs2 = LiteralUtil::CreateR2<int64>({{1, -20}, {-100, 4}});
std::vector<const Literal*> args = {lhs.get(), rhs.get(), rhs2.get()};
Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
@@ -305,7 +314,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
std::unique_ptr<Literal> result = Evaluate(args);
- auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
+ auto expected = LiteralUtil::CreateR2<int64>({{4, -16}, {-196, 12}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -315,7 +324,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- Literal::CreateRandomLiteral<F32>(
+ LiteralUtil::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction =
@@ -340,8 +349,8 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
// Verifies Broadcast operation is correctly evaluated.
TEST_P(HloEvaluatorTest, DoesBroadcast) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
- auto output_literal = Literal::CreateR3<int32>(
+ auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
+ auto output_literal = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
HloInstruction* literal_instruction = b.AddInstruction(
HloInstruction::CreateConstant(std::move(input_literal)));
@@ -356,8 +365,8 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR0<int32>(111);
- auto output_literal = Literal::CreateR2<int32>(
+ auto input_literal = LiteralUtil::CreateR0<int32>(111);
+ auto output_literal = LiteralUtil::CreateR2<int32>(
{{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});
HloInstruction* literal_instruction = b.AddInstruction(
@@ -377,9 +386,9 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
HloComputation::Builder b(TestName());
HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int64>({{-1, -2}, {100, 200}})));
+ LiteralUtil::CreateR2<int64>({{-1, -2}, {100, 200}})));
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int64>({{-2, -3}, {-100, -200}})));
+ LiteralUtil::CreateR2<int64>({{-2, -3}, {-100, -200}})));
std::vector<HloInstruction*> operands = {operand1, operand2};
@@ -390,8 +399,8 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected =
- Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
+ auto expected = LiteralUtil::CreateR2<int64>(
+ {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -399,9 +408,9 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
HloComputation::Builder b(TestName());
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({100, 200})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({100, 200})));
HloInstruction* operand2 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({})));
std::vector<HloInstruction*> operands = {operand1, operand2};
@@ -412,16 +421,16 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<int64>({100, 200});
+ auto expected = LiteralUtil::CreateR1<int64>({100, 200});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
+ auto input_literal = LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}});
auto expected =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
@@ -438,9 +447,9 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
HloComputation::Builder b(TestName());
- auto input_literal = Literal::CreateR2WithLayout<int32>(
+ auto input_literal = LiteralUtil::CreateR2WithLayout<int32>(
{{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
- auto expected = Literal::CreateR2WithLayout<float>(
+ auto expected = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(),
expected->shape()));
@@ -469,13 +478,13 @@ PaddingConfig CreatePaddingConfig(
}
TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
- auto operand = Literal::CreateR2<int32>({{}, {}});
+ auto operand = LiteralUtil::CreateR2<int32>({{}, {}});
HloComputation::Builder b(TestName());
auto operand_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(operand)));
constexpr int32 kPadValue = 10;
- auto pad_value = Literal::CreateR0<int32>(kPadValue);
+ auto pad_value = LiteralUtil::CreateR0<int32>(kPadValue);
auto padding_value_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
@@ -487,7 +496,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<int32>(
+ auto expected = LiteralUtil::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -497,11 +506,11 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
HloComputation::Builder b(TestName());
Array4D<float> input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
constexpr float kPadValue = 1.5;
- auto pad_value = Literal::CreateR0<float>(kPadValue);
+ auto pad_value = LiteralUtil::CreateR0<float>(kPadValue);
HloInstruction* pad_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(pad_value)));
@@ -523,7 +532,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
(*expected_array)(7, 0, 0, 0) = 5.0f;
(*expected_array)(7, 2, 0, 0) = 6.0f;
- auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -540,12 +549,12 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
// }
auto input_array = MakeUnique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
- auto input = Literal::CreateR2FromArray2D<float>(*input_array);
+ auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
auto pad_value_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.718f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
auto r2_padding_on_dim0_dim1 =
CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
@@ -565,7 +574,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 2) = 2.718f;
(*expected_array)(0, 3) = 2.718f;
(*expected_array)(0, 4) = 2.718f;
- auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250)));
}
@@ -581,12 +590,12 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
// }
auto input_array = MakeUnique<Array2D<float>>(4, 3);
input_array->FillUnique(1.0f);
- auto input = Literal::CreateR2FromArray2D<float>(*input_array);
+ auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
HloInstruction* input_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
auto pad_value_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.718f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));
PaddingConfig padding_config = MakeNoPaddingConfig(2);
@@ -604,7 +613,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
std::unique_ptr<Literal> result = Evaluate();
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
- auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -621,13 +630,13 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// }
auto lhs_array = MakeUnique<Array2D<float>>(4, 1);
lhs_array->FillUnique(1.0f);
- auto lhs_literal = Literal::CreateR2FromArray2D<float>(*lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
// rhs:
// f32[2] { 1, 2 },
- auto rhs_literal = Literal::CreateR2<float>({{1, 2}});
+ auto rhs_literal = LiteralUtil::CreateR2<float>({{1, 2}});
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -649,7 +658,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
{4.f, 8.f},
});
// clang-format on
- auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -660,7 +669,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// lhs:
// f32[3]
// { 1, 2, 3 },
- auto lhs_literal = Literal::CreateR1<float>({1, 2, 3});
+ auto lhs_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -672,7 +681,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
// }
auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
- auto rhs_literal = Literal::CreateR2FromArray2D<float>(*rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -686,7 +695,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<float>({22.f, 28.f});
+ auto expected = LiteralUtil::CreateR1<float>({22.f, 28.f});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -703,7 +712,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// }
auto lhs_array = MakeUnique<Array2D<float>>(4, 3);
lhs_array->FillUnique(1.0f);
- auto lhs_literal = Literal::CreateR2FromArray2D<float>(*lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -715,7 +724,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
// }
auto rhs_array = MakeUnique<Array2D<float>>(3, 2);
rhs_array->FillUnique(1.0f);
- auto rhs_literal = Literal::CreateR2FromArray2D<float>(*rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -735,7 +744,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
{94.f, 124.f},
{130.f, 172.f},
});
- auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -744,12 +753,12 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
HloComputation::Builder b(TestName());
Array3D<float> lhs_array = {{{1, 2, 3}}};
- auto lhs_literal = Literal::CreateR3FromArray3D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR3FromArray3D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
Array3D<float> rhs_array = {{{3.f, 4.f}}};
- auto rhs_literal = Literal::CreateR3FromArray3D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR3FromArray3D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -783,7 +792,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
std::unique_ptr<Literal> result = Evaluate();
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
- auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR3FromArray3D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -800,7 +809,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -811,7 +820,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -845,7 +854,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
{149, 160, 171, 80},
}));
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -875,11 +884,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
}});
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
@@ -924,7 +933,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
Array4D<float> expected_array({{{{2514, 2685}}}});
Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -955,11 +964,11 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
}});
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(weight);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1001,7 +1010,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
Array4D<float> expected_array({{{{2514, 2685}}}});
Array4D<float> expected_array_bf16({{{{2512, 2672}}}});
// clang-format on
- auto expected = Literal::CreateR4FromArray4D<float>(
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
@@ -1019,7 +1028,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1030,7 +1039,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1065,7 +1074,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
{91, 112, 98, 120, 105, 128, 112},
{65, 84, 70, 90, 75, 96, 80},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1082,7 +1091,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1093,7 +1102,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{7, 8},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1129,7 +1138,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
{104, 91, 112, 98, 120, 105, 128, 112},
{78, 65, 84, 70, 90, 75, 96, 80},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1147,7 +1156,7 @@ TEST_P(HloEvaluatorTest,
{13, 14, 15, 16},
}));
// clang-format on
- auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs_array);
+ auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
HloInstruction* lhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
@@ -1158,7 +1167,7 @@ TEST_P(HloEvaluatorTest,
{8, 9, 10},
}));
// clang-format on
- auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs_array);
+ auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
HloInstruction* rhs_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
@@ -1201,7 +1210,7 @@ TEST_P(HloEvaluatorTest,
{0, 0, 0},
{91, 98, 105},
}));
- auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1216,9 +1225,9 @@ TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
std::vector<float> v(kNumElements, 1.0f);
HloInstruction* arg_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
HloInstruction* init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1248,14 +1257,14 @@ void BM_ReducePrecisely(int num_iters) {
HloComputation::Builder b("BM_ReducePrecisely");
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config);
+ HloModule module("BM_ReducePrecisely", config);
constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
std::vector<float> v(kNumElements, 1.0f);
HloInstruction* arg_instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(v)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1290,13 +1299,13 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1317,7 +1326,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR1<float>({6, 18});
+ auto expected = LiteralUtil::CreateR1<float>({6, 18});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1332,13 +1341,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder max_computation("max");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1369,7 +1378,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{6, 7}});
+ auto expected = LiteralUtil::CreateR2<float>({{6, 7}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1383,13 +1392,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
// }
auto arg_array = MakeUnique<Array2D<float>>(2, 3);
arg_array->FillUnique(1.0f);
- auto arg_literal = Literal::CreateR2FromArray2D<float>(*arg_array);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1426,7 +1435,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
+ auto expected = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
@@ -1436,13 +1445,13 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
// arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
std::vector<int64> input_dims(6, 4);
std::unique_ptr<Literal> arg_literal =
- Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
HloInstruction* arg_instruction =
b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
auto init_value = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
HloComputation::Builder add_computation("add");
Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -1489,7 +1498,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
- Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
}
@@ -1504,7 +1513,8 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(3, 5);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
@@ -1518,7 +1528,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{3},
{19},
});
@@ -1536,13 +1546,14 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0, 1})));
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
@@ -1551,7 +1562,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
@@ -1571,13 +1582,14 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
// }
auto operand_array = MakeUnique<Array2D<float>>(2, 4);
operand_array->FillUnique(1.0f);
- auto operand_literal = Literal::CreateR2FromArray2D<float>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2, 1})));
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand,
@@ -1586,7 +1598,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<float>({
+ auto expected = LiteralUtil::CreateR2<float>({
{2, 3, 4},
{6, 7, 8},
});
@@ -1604,16 +1616,17 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
// }
auto operand_array = MakeUnique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
- auto operand_literal = Literal::CreateR2FromArray2D<double>(*operand_array);
+ auto operand_literal =
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
auto start_indices = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto update = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
+ LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));
Shape shape = ShapeUtil::MakeShape(F64, {2, 3});
b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
@@ -1622,7 +1635,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<double>({
+ auto expected = LiteralUtil::CreateR2<double>({
{1, -2, -3},
{5, -6, -7},
});
@@ -1640,12 +1653,13 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
// }
auto operand_array = MakeUnique<Array2D<double>>(2, 3);
operand_array->FillUnique(1.0);
- auto operand_literal2 = Literal::CreateR2FromArray2D<double>(*operand_array);
+ auto operand_literal2 =
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
HloInstruction* operand2 = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal2)));
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto tuple =
b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
@@ -1657,7 +1671,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
std::unique_ptr<Literal> result = Evaluate();
- auto expected = Literal::CreateR2<double>({
+ auto expected = LiteralUtil::CreateR2<double>({
{1, 2, 3},
{5, 6, 7},
});
@@ -1677,9 +1691,9 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
operand_array->FillUnique(1.0);
HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D<double>(*operand_array)));
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array)));
HloInstruction* operand1 = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int64>({0, 1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64>({0, 1})));
auto tuple1 =
b.AddInstruction(HloInstruction::CreateTuple({operand1, operand2}));
@@ -1697,8 +1711,8 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
std::unique_ptr<Literal> result = Evaluate();
auto result_inner_literal =
- Literal::CreateR2FromArray2D<double>(*operand_array);
- auto expected = Literal::MakeTuple({
+ LiteralUtil::CreateR2FromArray2D<double>(*operand_array);
+ auto expected = LiteralUtil::MakeTuple({
result_inner_literal.get(),
result_inner_literal.get(),
});
@@ -1726,7 +1740,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
{{23.0f}, {24.0f}}},
});
// clang-format on
- auto operand_literal = Literal::CreateR4FromArray4D<float>(input);
+ auto operand_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
HloInstruction* operand = b.AddInstruction(
HloInstruction::CreateConstant(std::move(operand_literal)));
@@ -1737,7 +1751,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
std::unique_ptr<Literal> result = Evaluate();
// clang-format off
- auto expected = Literal::CreateR4FromArray4D<float>({
+ auto expected = LiteralUtil::CreateR4FromArray4D<float>({
{{{23.0f}, {24.0f}},
{{21.0f}, {22.0f}},
{{19.0f}, {20.0f}}},
@@ -1773,11 +1787,11 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
// Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
HloEvaluator evaluator;
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
- {square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{param0, LiteralUtil::CreateR1<float>({1, 2, 3, 4}).get()},
+ {square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1790,18 +1804,18 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kMultiply, param0, param0));
- HloInstruction* constant = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
+ HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
HloInstruction* add = b.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square));
// Evaluate add with square = {10, 20, 30, 40}.
HloEvaluator evaluator;
auto result = evaluator.EvaluateWithSubstitutions(
- add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
+ add, {{square, LiteralUtil::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
+ *LiteralUtil::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1821,11 +1835,12 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1845,10 +1860,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1869,11 +1885,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<int32>(
+ *LiteralUtil::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1895,13 +1911,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-1, 1}, {-4, 4}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1923,13 +1939,13 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{-2, 2}, {-1, 1}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1950,10 +1966,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{5}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1974,11 +1991,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{8}}, {{5}}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -1998,10 +2015,11 @@ ENTRY main {
}
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{}, {}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -2022,11 +2040,11 @@ ENTRY main {
)";
ParseAndVerifyModule(hlo_text);
- std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR2<int32>({{0, 1}, {2, 1}}),
*Evaluate({operand.get(), gather_indices.get()})));
}
@@ -2034,14 +2052,14 @@ ENTRY main {
// element-wise comparison with 2 bfloat16 operands.
TEST_P(HloEvaluatorTest, DoesCompareBF16) {
// lhs >= rhs
- auto lhs = Literal::CreateR2<bfloat16>(
+ auto lhs = LiteralUtil::CreateR2<bfloat16>(
{{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
{bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
- auto rhs = Literal::CreateR2<bfloat16>(
+ auto rhs = LiteralUtil::CreateR2<bfloat16>(
{{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
{bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
auto expected =
- Literal::CreateR2<bool>({{false, true, true}, {false, true, true}});
+ LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs),
std::move(rhs));
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b1b58642ec..2ae5f8bf36 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/core/lib/core/casts.h"
@@ -34,6 +35,37 @@ using is_complex_t = std::is_same<T, complex64>;
template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
+// It's UB to use std::sort with std::less<float>, because of NaNs. Define
+// "safe" less functions which are actually strict weak orders.
+template <
+ typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ return a < b;
+}
+
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_floating_point<NativeT>::value ||
+ std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (std::isnan(b)) {
+ return !std::isnan(a);
+ } else {
+ return a < b;
+ }
+}
+
+template <typename NativeT, typename std::enable_if<std::is_same<
+ NativeT, Eigen::half>::value>::type* = nullptr>
+bool SafeLess(const NativeT& a, const NativeT& b) {
+ if (Eigen::half_impl::isnan(b)) {
+ return !Eigen::half_impl::isnan(a);
+ } else {
+ return a < b;
+ }
+}
+
// Templated DfsHloVisitor for use by HloEvaluator.
//
// Typically ReturnT here indicates the resulting literal type of each evaluated
@@ -610,12 +642,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleAnd(HloInstruction* and_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[and_],
- ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el && rhs_el;
- }));
- return Status::OK();
+ return InvalidArgument("Unsupported type for And");
}
template <
@@ -644,12 +671,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleOr(HloInstruction* or_) {
- TF_ASSIGN_OR_RETURN(
- parent_->evaluated_[or_],
- ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
- return lhs_el || rhs_el;
- }));
- return Status::OK();
+ return InvalidArgument("Unsupported type for Or");
}
template <
@@ -664,6 +686,35 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
+ Status HandleXor(HloInstruction* xor_) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[xor_],
+ ElementWiseBinaryOp(xor_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) {
+ return lhs_el ^ rhs_el;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleXor(HloInstruction* xor_) {
+ return InvalidArgument("Unsupported type for Xor");
+ }
+
+ template <
+ typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
+ Status HandleXor(HloInstruction* xor_) {
+ return InvalidArgument("Unsupported type for Xor");
+ }
+
+ Status HandleXor(HloInstruction* xor_) override {
+ return HandleXor<ElementwiseT>(xor_);
+ }
+
+ template <typename NativeT,
typename std::enable_if<
std::is_integral<NativeT>::value &&
!std::is_same<NativeT, bool>::value>::type* = nullptr>
@@ -778,7 +829,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override {
CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
- CHECK(!ShapeUtil::IsTuple(select->shape()));
+ CHECK(ShapeUtil::IsArray(select->shape()));
std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
[](bool pred, ReturnT on_true, ReturnT on_false) {
if (pred) {
@@ -1006,83 +1057,47 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
CHECK_EQ(dnums.lhs_batch_dimensions_size(),
dnums.rhs_batch_dimensions_size());
- std::vector<int64> lhs_non_contracting_dims;
+ DimensionVector lhs_index(lhs_rank);
+ DimensionVector rhs_index(rhs_rank);
+
+ // result_index_locations[i] contains one or two pointers to the locations
+ // in lhs_index or rhs_index where the i'th result index should go.
+ tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank>
+ result_index_locations;
+ result_index_locations.reserve(lhs_rank + rhs_rank - 2);
+
+ // The first components in the output shape are the LHS and RHS batch
+ // dimensions:
+ for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); i++) {
+ result_index_locations.push_back(
+ {&lhs_index[dnums.lhs_batch_dimensions(i)],
+ &rhs_index[dnums.rhs_batch_dimensions(i)]});
+ }
+
+ // Then we have the LHS and RHS non-contracting dimensions, if any:
for (int64 i = 0; i < lhs_rank; i++) {
- if (i != lhs_contracting_dimension) {
- lhs_non_contracting_dims.push_back(i);
+ if (i != lhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.lhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&lhs_index[i], nullptr});
}
}
-
- std::vector<int64> rhs_non_batch_non_contracting_dims;
- tensorflow::gtl::FlatSet<int64> batch_dims_set(
- dnums.rhs_batch_dimensions().begin(),
- dnums.rhs_batch_dimensions().end());
for (int64 i = 0; i < rhs_rank; i++) {
- if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) {
- rhs_non_batch_non_contracting_dims.push_back(i);
+ if (i != rhs_contracting_dimension &&
+ !ArrayContains(AsInt64Slice(dnums.rhs_batch_dimensions()), i)) {
+ result_index_locations.push_back({&rhs_index[i], nullptr});
}
}
- const int64 batch_dim_size = dnums.lhs_batch_dimensions_size();
- const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size();
-
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
auto result = MakeUnique<Literal>(dot->shape());
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
[&](tensorflow::gtl::ArraySlice<int64> result_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
- // Find the corresponding non-contracting indices for lhs and rhs.
- //
- // For `result_index`, its batch dimension, if exists, will be at the
- // same dimension as the batch dimension of lhs and rhs. More
- // specifically:
- // - For lhs, the non-contracting dimensions, including the batch
- // dimension have the same index as the `result_index`.
- // - For rhs, the batch dimension is set seperately from other
- // non-contracting dimensions, since these other non-contracting
- // dimensions in rhs follow the non-contracting dimensions of lhs in
- // the resulting index.
- //
- // As an example, for a resulting index:
- // result_index [result_batch, result_x, result_y]
- // the effecting lhs and rhs indices are:
- // lhs [result_batch, lhs_non_contracting_dim, contracting_dim
- // rhs [result_batch, contracting_dim, rhs_non_contracting_dim]
- // `result_x` is only affected by the lhs_non_contracting_dim and
- // likewise `result_y` only depends on rhs_non_contracting_dim.
- //
- // so we can look up the lhs and rhs indices by:
- //
- // lhs:
- // batch index is the same as `result_batch`.
- // non-contracting dimension is the same as
- // result_index[lhs_non_contracting_dim]
- // rhs:
- // batch index: the same as `result_batch`.
- // non-contracting dimension index: *not* the same as
- // result_index[rhs_non_contractng_dim], since the
- // non-contracting dimensions of lhs are included in the
- // result_index first. Instead, the non_contracting_dim of rhs must
- // be calculated as following:
- // lhs_non_contracting_dimensions_size +
- // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1
- //
- // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is
- // the index offset to the result_index that only depends on
- // the non_batch and non-contracting dimensions of rhs. -1 at the
- // end translates size to index.
- for (auto i : lhs_non_contracting_dims) {
- lhs_index[i] = result_index[i];
- }
- for (auto i : dnums.rhs_batch_dimensions()) {
- rhs_index[i] = result_index[i];
- }
- for (auto i : rhs_non_batch_non_contracting_dims) {
- const int64 rhs_non_batch_non_contracting_dim =
- lhs_non_contracting_size + (i - batch_dim_size) - 1;
- rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim];
+ for (int64 i = 0; i < result_index.size(); i++) {
+ *result_index_locations[i].first = result_index[i];
+ if (result_index_locations[i].second) {
+ *result_index_locations[i].second = result_index[i];
+ }
}
// Accumulates resulting product along the contracted dimension.
@@ -1103,7 +1118,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
Status HandlePad(HloInstruction* pad) override {
- CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
+ CHECK(ShapeUtil::IsArray(pad->operand(0)->shape()));
// Padding value must be scalar.
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
@@ -1116,7 +1131,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
/*padding_config=*/pad->padding_config()));
CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(pad->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
// Create new HLO of padded shape with padding value.
@@ -1182,7 +1197,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
dynamic_slice->dynamic_slice_sizes()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
@@ -1237,7 +1252,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
operand->shape(), update->shape(), start_indices->shape()));
TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(result_shape)
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
TF_RET_CHECK(
primitive_util::IsIntegralType(start_indices->shape().element_type()));
@@ -1302,7 +1317,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(operand);
auto curr_val = arg_literal.Get<NativeT>(multi_index);
- auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<NativeT>(curr_val);
arg_literals.push_back(std::move(curr_val_literal));
}
@@ -1378,6 +1393,46 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ template <typename NativeT,
+ typename std::enable_if<
+ !is_complex_t<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleSort(HloInstruction* sort) {
+ auto keys = sort->operand(0);
+ TF_RET_CHECK(ShapeUtil::Rank(keys->shape()) == 1)
+ << "Sort is only supported for R1 shapes";
+ TF_RET_CHECK(sort->operand_count() == 1)
+ << "Typed visitor does not support key-value sort";
+
+ const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
+ VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
+ const auto& keys_data = keys_literal.data<ReturnT>();
+
+ std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const ReturnT& a, const ReturnT& b) {
+ return SafeLess<ReturnT>(a, b);
+ });
+ auto result_literal = MakeUnique<Literal>(sort->shape());
+ result_literal->PopulateR1(
+ tensorflow::gtl::ArraySlice<ReturnT>(result_data));
+ VLOG(3) << "HandleSort result_literal: " << result_literal->ToString();
+ parent_->evaluated_[sort] = std::move(result_literal);
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<is_complex_t<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleSort(HloInstruction* sort) {
+ return InvalidArgument("Unsupported type for Sort");
+ }
+
+ Status HandleSort(HloInstruction* sort) override {
+ return HandleSort<ReturnT>(sort);
+ }
+
Status HandleReduce(HloInstruction* reduce) override {
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
@@ -1393,7 +1448,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
/*to_apply=*/function->ComputeProgramShape()));
TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape))
<< "return shape is set to: " << ShapeUtil::HumanString(reduce->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanString(inferred_return_shape);
const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg);
@@ -1450,8 +1505,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
auto curr_val = arg_literal.Get<ReturnT>(input_index);
// Evaluate computation with specified literal operands.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
- auto result_val_literal = Literal::CreateR0<ReturnT>(result_val);
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(curr_val);
+ auto result_val_literal =
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
@@ -1529,10 +1585,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid
// dynamic memory allocations.
- auto curr_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto selected_val_literal = Literal::CreateR0<ReturnT>(ReturnT());
- auto source_literal_scatter = Literal::CreateR0<ReturnT>(ReturnT());
- auto scattered_literal = Literal::CreateR0<ReturnT>(ReturnT());
+ auto curr_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto selected_val_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto source_literal_scatter = LiteralUtil::CreateR0<ReturnT>(ReturnT());
+ auto scattered_literal = LiteralUtil::CreateR0<ReturnT>(ReturnT());
do {
// For each element in `source`, we place a window in `operand`. For each
// window placement, we iterate inside the window twice:
@@ -1613,7 +1669,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape))
<< "return shape is set to: "
<< ShapeUtil::HumanStringWithLayout(reduce_window->shape())
- << "but is inferred to be: "
+ << " but is inferred to be: "
<< ShapeUtil::HumanStringWithLayout(inferred_return_shape);
const Literal& operand_literal =
@@ -1653,9 +1709,9 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Evaluate computation with specified literal operands.
const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
+ LiteralUtil::CreateR0<ReturnT>(curr_val);
const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
+ LiteralUtil::CreateR0<ReturnT>(result_val);
std::unique_ptr<Literal> computed_result =
embedded_evaluator
.Evaluate<const Literal*>(
@@ -1700,7 +1756,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = Literal::CreateFromDimensions(
+ auto result = LiteralUtil::CreateFromDimensions(
shape.element_type(), AsInt64Slice(shape.dimensions()));
TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
@@ -1962,7 +2018,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// TODO(b/74360564): This is implementation defined behavior, but is
// currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
+ // to officially document different behavior.
for (int64 i = 0; i < start.size(); ++i) {
start[i] = std::min<int64>(
std::max(int64{0}, start[i]),
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index 4900c813fd..eba80c0f19 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -29,7 +29,7 @@ using ::testing::ContainsRegex;
class HloExecutionProfileTest : public HloTestBase {};
TEST_F(HloExecutionProfileTest, Basic) {
- auto hlo_module = tools::Parse(R"(
+ auto hlo_module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
lhs = f32[30,30]{1,0} parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 05adb45713..57cf34d7de 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -27,7 +27,9 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -590,15 +592,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
const HloInstruction* parent_instr) {
VLOG(2) << "Dumping subcomputation " << subcomp->name();
- const char* computation_fmt = R"(subgraph %s {
-%s
-label = <%s>;
-labelloc = t;
-tooltip = " ";
-%s
-} // %s
+ // Add an edge from the subcomputation to its parent node. If subcomp
+ // belongs to a fusion node, it's drawn in place of the fusion instruction,
+ // so there's no need to link those.
+ if (parent_instr->opcode() != HloOpcode::kFusion) {
+ const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
+ VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
+ << " as " << next_edge_id_;
+ edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
+ const char* edge_fmt =
+ R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
+ edges_.push_back(Printf(
+ edge_fmt, InstructionId(from), InstructionId(parent_instr),
+ SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
+ }
-)";
+ // Have we already dumped this subcomputation? If so, generating the edge
+ // linking it and parent_instr is all we want to do in this function.
+ if (cluster_ids_.find(subcomp) != cluster_ids_.end()) {
+ return "";
+ }
cluster_ids_[subcomp] = next_cluster_id_++;
@@ -645,25 +658,16 @@ tooltip = " ";
string comp_body = DumpComputation(subcomp);
- // Add an edge from the subcomputation to its parent node. If subcomp
- // belongs to a fusion node, it's drawn in place of the fusion instruction,
- // so there's no need to link those.
- if (parent_instr->opcode() != HloOpcode::kFusion) {
- const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction());
- VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name()
- << " as " << next_edge_id_;
- edge_ids_.insert({{from, parent_instr}, next_edge_id_++});
- const char* edge_fmt =
- R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)";
- edges_.push_back(Printf(
- edge_fmt, InstructionId(from), InstructionId(parent_instr),
- SubcomputationId(subcomp), subcomp->name(), parent_instr->name()));
- }
-
- string computation =
- Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
+ const char* computation_fmt = R"(subgraph %s {
+%s
+label = <%s>;
+labelloc = t;
+tooltip = " ";
+%s
+} // %s
- return computation;
+)";
+ return Printf(computation_fmt, id, style, subcomp_label, comp_body, id);
}
string HloDotDumper::DumpComputation(const HloComputation* comp) {
@@ -721,11 +725,25 @@ string HloDotDumper::DumpRootTag() {
to_id, node_body, node_shape, NodeColorAttributes(color));
}
+static const HloConstantInstruction* TryGetFusionParameterConstant(
+ const HloInstruction* instr) {
+ if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
+ return nullptr;
+ }
+ const HloInstruction* fusion = instr->parent()->FusionInstruction();
+ const HloInstruction* operand = fusion->operand(instr->parameter_number());
+ return DynCast<HloConstantInstruction>(operand);
+}
+
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
// If a node:
//
- // - is a tuple-shaped parameter,
- // - is not a parameter to a fusion node,
+ // - is a parameter of a fusion node which is bound to a constant,
+ //
+ // or
+ //
+ // - is a tuple-shaped parameter, and
+ // - is not a parameter to a fusion node, and
// - has at least kMinUsersToOmit users shown, and
// - all of the shown users are get-tuple-elements,
//
@@ -733,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
//
// This helps us handle the common case where a while loop body has one big
// tuple-shaped parameter.
+ if (TryGetFusionParameterConstant(instr) != nullptr) {
+ return true;
+ }
const int kMinUsersToOmit = 3;
return instr->opcode() == HloOpcode::kParameter &&
ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() &&
@@ -804,26 +825,26 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) {
- auto stringify_constant = [](const HloInstruction* constant) {
+ auto stringify_constant = [](const HloConstantInstruction* constant) {
const auto& shape = constant->shape();
// If the shape has a dimension of size zero, print it as e.g.
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
- if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
optional<int64> elem_count;
- if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
+ if (ShapeUtil::IsArray(shape)) {
elem_count = 1;
for (int64 dim : shape.dimensions()) {
*elem_count *= dim;
}
}
- if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
+ if (elem_count.has_value() && *elem_count <= 8) {
return Printf("%s (%s)", constant->literal().ToString(),
ShapeUtil::HumanString(constant->shape()));
}
@@ -839,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
ShapeUtil::HumanString(constant->shape()));
};
- // Special case: If instr is a parameter to a fusion node, check whether the
- // corresponding operand to the fusion node is a constant.
- if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) {
- const HloInstruction* fusion = instr->parent()->FusionInstruction();
- const HloInstruction* operand = fusion->operand(instr->parameter_number());
- if (operand->opcode() != HloOpcode::kConstant) {
- return "";
- }
- return StrCat("<b>constant</b> ", stringify_constant(operand));
- }
-
std::vector<string> lines;
for (int64 i = 0; i < instr->operand_count(); ++i) {
const HloInstruction* operand = instr->operand(i);
+ const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
optional<string> operand_str;
- if (operand->opcode() == HloOpcode::kConstant) {
- operand_str = stringify_constant(operand);
+ if (constant_operand != nullptr) {
+ operand_str = stringify_constant(constant_operand);
} else if (ShouldMergeIntoUsers(operand)) {
- // Special case: If the operand is a parameter, use its parameter number
- // rather than its name, because that's generally how people think of the
- // node.
+ // Special case: If the operand is a parameter to a fusion node and it
+ // always has a constant value, display it like a regular constant.
+ //
+ // For other parameters, use the parameter number rather than the proper
+ // name, because that's generally how people think of the node.
if (operand->opcode() == HloOpcode::kParameter) {
- operand_str = Printf("Parameter %lld", operand->parameter_number());
+ if (const HloConstantInstruction* constant =
+ TryGetFusionParameterConstant(operand)) {
+ operand_str = stringify_constant(constant);
+ } else {
+ operand_str = Printf("Parameter %lld", operand->parameter_number());
+ }
} else {
operand_str = operand->name();
}
@@ -895,11 +913,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
const auto kParameterColor = kOrange;
// Special case: If this instruction has a parameter merged into it, paint it
- // the same color as a parameter.
+ // the same color as a parameter. Unless the merged-in parameter is a
+ // parameter to a fusion node that is bound to a constant -- these aren't
+ // "real" parameters from the user's perspective.
if (std::any_of(instr->operands().begin(), instr->operands().end(),
[&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter &&
- ShouldMergeIntoUsers(operand);
+ ShouldMergeIntoUsers(operand) &&
+ TryGetFusionParameterConstant(operand) == nullptr;
})) {
return kParameterColor;
}
@@ -939,11 +960,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kPower:
case HloOpcode::kReal:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSelect:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
@@ -962,6 +985,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kBitcast:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
@@ -973,13 +997,12 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
}
return kGreen;
case HloOpcode::kConcatenate:
- case HloOpcode::kCopy:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
- case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
case HloOpcode::kTranspose:
// De-emphasize scalar-shaped data movement ops and all data movement ops
// inside fusion nodes, both of which are essentially free.
@@ -995,6 +1018,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kWhite;
}
return kGreen;
+ case HloOpcode::kCopy:
+ // Emphasize copy nodes, which are either physical transposes (and thus
+ // significant), or copies of read-only buffers (and thus dead weight).
+ return kGreen;
case HloOpcode::kConvolution:
case HloOpcode::kDot:
case HloOpcode::kFft:
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
index 8e52d926d8..1d7a062c55 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -120,8 +121,8 @@ TEST(HloGraphDumperTest, NestedFusion) {
TEST(HloGraphDumperTest, Constant) {
HloComputation::Builder b("b");
auto instruction = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
- instruction->set_name("i_am_a_constant_root_instruction");
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-42)));
+ instruction->SetAndSanitizeName("i_am_a_constant_root_instruction");
HloModuleConfig config;
HloModule m(TestName(), config);
HloComputation* root_computation = m.AddEntryComputation(b.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 4095b3d337..830ebfb125 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -16,25 +16,25 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include <algorithm>
-#include <deque>
#include <ostream>
#include <set>
#include <unordered_set>
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -60,107 +60,366 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
- auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
- for (const int64 operand_id : proto.operand_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
- << "No instruction with id " << operand_id;
- instruction->AppendOperand(instruction_map.at(operand_id));
- }
- for (const int64 predecessor_id : proto.control_predecessor_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
- << "No instruction with id " << predecessor_id;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
- ->AddControlDependencyTo(instruction.get()));
- }
-
- // In the proto, fused computations are held exclusively within the
- // HloInstructionProto and do not appear as an HloComputationProto within the
- // HloModuleProto.
- if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RET_CHECK(!proto.fusion_kind().empty());
- TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
- StringToFusionKind(proto.fusion_kind()));
-
- // Find the fused computation and set its fusion instruction.
- TF_RET_CHECK(proto.called_computation_ids_size() == 1)
- << "Expect 1 called computation for fusion instruction, but sees "
- << proto.called_computation_ids_size();
- const int64 fusion_id = proto.called_computation_ids(0);
- auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
- TF_RET_CHECK(fused_computation != nullptr)
- << "No fusion computation with id " << fusion_id;
- fused_computation->SetFusionInstruction(instruction.get());
- instruction->called_computations_.push_back(fused_computation);
- } else {
- for (const int64 computation_id : proto.called_computation_ids()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_id))
- << "No computation with id " << computation_id;
- instruction->called_computations_.push_back(
- computation_map.at(computation_id));
+ std::unique_ptr<HloInstruction> instruction;
+ const auto operands = [&instruction_map, &proto](int index) {
+ return instruction_map.at(proto.operand_ids(index));
+ };
+ const auto all_operands = [&instruction_map, &proto]() {
+ std::vector<HloInstruction*> result(proto.operand_ids_size());
+ std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
+ result.begin(), [&instruction_map](int64 operand_id) {
+ return instruction_map.at(operand_id);
+ });
+ return result;
+ };
+ const auto computations = [&computation_map, &proto](int index) {
+ return computation_map.at(proto.called_computation_ids(index));
+ };
+ switch (opcode) {
+ // Ops migrated to subclasses.
+ case HloOpcode::kBatchNormTraining:
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "BatchNormTraining instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
+ instruction = CreateBatchNormTraining(
+ proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(),
+ proto.feature_index());
+ break;
+ case HloOpcode::kBatchNormInference:
+ TF_RET_CHECK(proto.operand_ids_size() == 5)
+ << "BatchNormInference instruction should have 5 operands but sees "
+ << proto.operand_ids_size();
+ instruction = CreateBatchNormInference(
+ proto.shape(), operands(0), operands(1), operands(2), operands(3),
+ operands(4), proto.epsilon(), proto.feature_index());
+ break;
+ case HloOpcode::kBatchNormGrad:
+ TF_RET_CHECK(proto.operand_ids_size() == 5)
+ << "BatchNormGrad instruction should have 5 operands but sees "
+ << proto.operand_ids_size();
+ instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
+ operands(2), operands(3), operands(4),
+ proto.epsilon(), proto.feature_index());
+ break;
+ case HloOpcode::kFft: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Fft instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ std::vector<int64> fft_length(proto.fft_length().begin(),
+ proto.fft_length().end());
+ instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
+ tensorflow::gtl::ArraySlice<int64>(fft_length));
+ break;
+ }
+ case HloOpcode::kSend:
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Send instruction should have 2 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateSend(operands(0), operands(1), proto.channel_id());
+ break;
+ case HloOpcode::kSendDone:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "SendDone instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateSendDone(operands(0));
+ break;
+ case HloOpcode::kRecv:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Recv instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
+ proto.channel_id());
+ break;
+ case HloOpcode::kRecvDone:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "RecvDone instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateRecvDone(operands(0));
+ break;
+ case HloOpcode::kReverse:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Reverse instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateReverse(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kConcatenate:
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Concatenate instruction should have 1 dimension but sees "
+ << proto.dimensions_size();
+ instruction =
+ CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
+ break;
+ case HloOpcode::kReduce:
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Reduce instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Reduce instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ instruction = CreateReduce(proto.shape(), operands(0), operands(1),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()),
+ computations(0));
+ break;
+ case HloOpcode::kSort: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1 ||
+ proto.operand_ids_size() == 2)
+ << "Sort instruction should have 1 or 2 operands but has "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.dimensions().size() == 1)
+ << "Sort instruction should have 1 dimension";
+ HloInstruction* keys = operands(0);
+ HloInstruction* values =
+ proto.operand_ids_size() == 2 ? operands(1) : nullptr;
+ instruction =
+ CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ break;
+ }
+ case HloOpcode::kTranspose:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Transpose instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction =
+ CreateTranspose(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kBroadcast:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Broadcast instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction =
+ CreateBroadcast(proto.shape(), operands(0),
+ std::vector<int64>(proto.dimensions().begin(),
+ proto.dimensions().end()));
+ break;
+ case HloOpcode::kMap:
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Map instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ instruction = CreateMap(proto.shape(), all_operands(), computations(0));
+ break;
+ case HloOpcode::kSlice: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Slice instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ std::vector<int64> slice_starts, slice_limits, slice_strides;
+ for (const HloInstructionProto::SliceDimensions& slice_dimensions :
+ proto.slice_dimensions()) {
+ slice_starts.push_back(slice_dimensions.start());
+ slice_limits.push_back(slice_dimensions.limit());
+ slice_strides.push_back(slice_dimensions.stride());
+ }
+ instruction = CreateSlice(proto.shape(), operands(0), slice_starts,
+ slice_limits, slice_strides);
+ break;
+ }
+ case HloOpcode::kConstant: {
+ // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
+ if (proto.has_literal()) {
+ TF_ASSIGN_OR_RETURN(auto literal,
+ Literal::CreateFromProto(proto.literal()));
+ instruction = CreateConstant(std::move(literal));
+ } else {
+ instruction = MakeUnique<HloConstantInstruction>(proto.shape());
+ }
+ break;
+ }
+ case HloOpcode::kTrace: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Trace instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_literal());
+ TF_ASSIGN_OR_RETURN(auto literal,
+ Literal::CreateFromProto(proto.literal()));
+ instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
+ break;
+ }
+ case HloOpcode::kFusion: {
+ // In the proto, fused computations are held exclusively within the
+ // HloInstructionProto and do not appear as an HloComputationProto within
+ // the HloModuleProto.
+ TF_RET_CHECK(!proto.fusion_kind().empty());
+ TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
+ StringToFusionKind(proto.fusion_kind()));
+
+ // Find the fused computation and set its fusion instruction.
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Expect 1 called computation for fusion instruction but sees "
+ << proto.called_computation_ids_size();
+ const int64 fusion_id = proto.called_computation_ids(0);
+ auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
+ TF_RET_CHECK(fused_computation != nullptr)
+ << "No fusion computation with id " << fusion_id;
+ instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
+ fused_computation);
+ break;
+ }
+ case HloOpcode::kRng:
+ instruction =
+ CreateRng(proto.shape(), proto.distribution(), all_operands());
+ break;
+ case HloOpcode::kParameter:
+ instruction = CreateParameter(proto.parameter_number(), proto.shape(),
+ proto.name());
+ break;
+ case HloOpcode::kGetTupleElement:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "GetTupleElement instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
+ instruction = CreateGetTupleElement(proto.shape(), operands(0),
+ proto.tuple_index());
+ break;
+ case HloOpcode::kReducePrecision:
+ instruction =
+ CreateReducePrecision(proto.shape(), operands(0),
+ proto.exponent_bits(), proto.mantissa_bits());
+ break;
+ case HloOpcode::kInfeed: {
+ const Shape& data_shape =
+ ShapeUtil::GetTupleElementShape(proto.shape(), 0);
+ if (proto.operand_ids_size() == 0) {
+ // TODO(b/80000000): Remove this when all uses of infeed are
+ // converted to take tokens.
+ instruction = CreateInfeed(data_shape, proto.infeed_config());
+ } else {
+ CHECK_EQ(proto.operand_ids_size(), 1);
+ instruction =
+ CreateInfeed(data_shape, operands(0), proto.infeed_config());
+ }
+ } break;
+ case HloOpcode::kOutfeed:
+ if (proto.operand_ids_size() == 1) {
+ // TODO(b/80000000): Remove this when all uses of outfeed are
+ // converted to take tokens.
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ proto.outfeed_config());
+ } else {
+ CHECK_EQ(proto.operand_ids_size(), 2);
+ instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
+ operands(1), proto.outfeed_config());
+ }
+ break;
+ case HloOpcode::kCrossReplicaSum: {
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "CrossReplicaSum should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ tensorflow::gtl::optional<int64> all_reduce_id;
+ if (proto.all_reduce_id() > 0) {
+ all_reduce_id = proto.all_reduce_id();
+ }
+ instruction = CreateCrossReplicaSum(
+ proto.shape(), all_operands(), computations(0),
+ /*replica_group_ids=*/
+ std::vector<int64>(proto.replica_group_ids().begin(),
+ proto.replica_group_ids().end()),
+ /*barrier=*/proto.cross_replica_sum_barrier(),
+ /*all_reduce_id=*/all_reduce_id);
+ break;
+ }
+ case HloOpcode::kConvolution:
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Convolution instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_window());
+ TF_RET_CHECK(proto.has_convolution_dimension_numbers());
+ instruction =
+ CreateConvolve(proto.shape(), operands(0), operands(1),
+ proto.window(), proto.convolution_dimension_numbers());
+ break;
+ case HloOpcode::kReduceWindow:
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "ReduceWindow instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "ReduceWindow should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1),
+ proto.window(), computations(0));
+ break;
+ case HloOpcode::kSelectAndScatter:
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "SelectAndScatter instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 2)
+ << "SelectAndScatter should have 2 called computations but sees "
+ << proto.called_computation_ids_size();
+ instruction = CreateSelectAndScatter(
+ proto.shape(), operands(0), computations(0), proto.window(),
+ operands(1), operands(2), computations(1));
+ break;
+ case HloOpcode::kCustomCall:
+ instruction = CreateCustomCall(proto.shape(), all_operands(),
+ proto.custom_call_target());
+ if (proto.has_window()) {
+ static_cast<HloCustomCallInstruction*>(instruction.get())
+ ->set_window(proto.window());
+ }
+ if (proto.has_convolution_dimension_numbers()) {
+ static_cast<HloCustomCallInstruction*>(instruction.get())
+ ->set_convolution_dimension_numbers(
+ proto.convolution_dimension_numbers());
+ }
+ break;
+ case HloOpcode::kHostCompute:
+ instruction =
+ CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(),
+ proto.cost_estimate_ns());
+ break;
+ case HloOpcode::kPad:
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Pad instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_padding_config());
+ instruction = CreatePad(proto.shape(), operands(0), operands(1),
+ proto.padding_config());
+ break;
+ case HloOpcode::kDynamicSlice: {
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "DynamicSlice instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
+ c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
+ instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
+ slice_sizes);
+ break;
+ }
+ default: {
+ instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
+ for (const int64 operand_id : proto.operand_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
+ << "No instruction with id " << operand_id;
+ instruction->AppendOperand(instruction_map.at(operand_id));
+ }
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
+ ->AddControlDependencyTo(instruction.get()));
+ }
+ if (instruction->opcode() != HloOpcode::kFusion) {
+ for (const int64 computation_id : proto.called_computation_ids()) {
+ TF_RET_CHECK(ContainsKey(computation_map, computation_id))
+ << "No computation with id " << computation_id;
+ instruction->called_computations_.push_back(
+ computation_map.at(computation_id));
+ }
+ }
+ break;
}
- }
-
- if (instruction->opcode() == HloOpcode::kTrace) {
- TF_RET_CHECK(instruction->operands().size() == 1)
- << "Trace instruction should have 1 operand but sees "
- << instruction->operands().size();
- instruction->mutable_operand(0)->set_tracing(instruction.get());
}
TF_RET_CHECK(!proto.name().empty());
- instruction->name_ = proto.name();
-
+ instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
- if (proto.has_literal()) {
- TF_ASSIGN_OR_RETURN(instruction->literal_,
- Literal::CreateFromProto(proto.literal()));
- }
- instruction->parameter_number_ = proto.parameter_number();
- instruction->tuple_index_ = proto.tuple_index();
- for (int64 dimension : proto.dimensions()) {
- instruction->dimensions_.push_back(dimension);
- }
- if (proto.has_window()) {
- instruction->window_ = MakeUnique<Window>(proto.window());
- }
- if (proto.has_convolution_dimension_numbers()) {
- instruction->convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(
- proto.convolution_dimension_numbers());
- }
if (proto.has_dot_dimension_numbers()) {
instruction->dot_dimension_numbers_ =
MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
}
- for (const HloInstructionProto::SliceDimensions& slice_dimensions :
- proto.slice_dimensions()) {
- instruction->slice_starts_.push_back(slice_dimensions.start());
- instruction->slice_limits_.push_back(slice_dimensions.limit());
- instruction->slice_strides_.push_back(slice_dimensions.stride());
- }
- instruction->exponent_bits_ = proto.exponent_bits();
- instruction->mantissa_bits_ = proto.mantissa_bits();
- for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) {
- instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size);
- }
- if (proto.has_padding_config()) {
- instruction->padding_config_ =
- MakeUnique<PaddingConfig>(proto.padding_config());
- }
- instruction->outfeed_config_ = proto.outfeed_config();
- instruction->distribution_ = proto.distribution();
- instruction->epsilon_ = proto.epsilon();
- instruction->feature_index_ = proto.feature_index();
- instruction->channel_id_ = proto.channel_id();
- instruction->infeed_config_ = proto.infeed_config();
- instruction->custom_call_target_ = proto.custom_call_target();
- instruction->outfeed_shape_ = proto.outfeed_shape();
- instruction->fft_type_ = proto.fft_type();
- for (int64 fft_len : proto.fft_length()) {
- instruction->fft_length_.push_back(fft_len);
- }
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -175,61 +434,34 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
for (int64 bound : proto.gather_window_bounds()) {
instruction->gather_window_bounds_.push_back(bound);
}
-
- instruction->channel_name_ = proto.channel_name();
- instruction->cost_estimate_ns_ = proto.cost_estimate_ns();
-
return std::move(instruction);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
int64 parameter_number, const Shape& shape, const string& name) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
- instruction->parameter_number_ = parameter_number;
- instruction->name_ = name;
- return instruction;
+ return MakeUnique<HloParameterInstruction>(parameter_number, shape, name);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
const string& tag, HloInstruction* operand) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
- instruction->operands_.push_back(operand);
- instruction->literal_ = Literal::CreateR1U8(tag);
- operand->set_tracing(instruction.get());
- return instruction;
+ return MakeUnique<HloTraceInstruction>(tag, operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
std::unique_ptr<Literal> literal) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
- instruction->literal_ = std::move(literal);
- return instruction;
+ return MakeUnique<HloConstantInstruction>(std::move(literal));
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
HloInstruction* operand, int64 index) {
- CHECK(ShapeUtil::IsTuple(operand->shape()));
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
- instruction->tuple_index_ = index;
- instruction->AppendOperand(operand);
- return instruction;
+ return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape));
- instruction->distribution_ = distribution;
- instruction->shape_ = shape;
- for (HloInstruction* param : parameters) {
- instruction->AppendOperand(param);
- }
- return instruction;
+ return MakeUnique<HloRngInstruction>(shape, distribution, parameters);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
@@ -271,7 +503,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
break;
default:
@@ -306,6 +537,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
@@ -323,8 +555,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
// Only certain opcodes are supported with CreateTernary: opcodes of ternary
// instructions with no auxiliary fields.
switch (opcode) {
- case (HloOpcode::kClamp):
- case (HloOpcode::kSelect):
+ case HloOpcode::kClamp:
+ case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
break;
default:
LOG(FATAL) << "Invalid ternary instruction opcode "
@@ -342,45 +575,22 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
- CHECK(static_operands.empty()) << "static_operands not yet supported";
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->called_computations_.push_back(map_computation);
- return instruction;
+ HloComputation* map_computation) {
+ return MakeUnique<HloMapInstruction>(shape, operands, map_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape));
- if (window_util::HasBaseDilation(window)) {
- instruction->name_ = instruction->name() + "-base-dilated";
- }
- if (window_util::HasWindowDilation(window)) {
- instruction->name_ = instruction->name() + "-window-dilated";
- }
- instruction->AppendOperand(lhs);
- instruction->AppendOperand(rhs);
- instruction->window_ = MakeUnique<Window>(window);
- instruction->convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(dimension_numbers);
- return instruction;
+ return MakeUnique<HloConvolutionInstruction>(shape, lhs, rhs, window,
+ dimension_numbers);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape));
- instruction->AppendOperand(operand);
- instruction->fft_type_ = fft_type;
- instruction->fft_length_.assign(fft_length.begin(), fft_length.end());
- return instruction;
+ return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
@@ -413,96 +623,95 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
const int exponent_bits,
const int mantissa_bits) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
- instruction->AppendOperand(operand);
- instruction->exponent_bits_ = exponent_bits;
- instruction->mantissa_bits_ = mantissa_bits;
- return instruction;
+ return MakeUnique<HloReducePrecisionInstruction>(
+ shape, operand, exponent_bits, mantissa_bits);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id) {
+ return MakeUnique<HloAllReduceInstruction>(
+ shape, operands, reduce_computation, replica_group_ids, barrier,
+ all_reduce_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
- const Shape& shape, const string& config) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
- instruction->set_infeed_config(config);
- return instruction;
+ const Shape& infeed_shape, HloInstruction* token_operand,
+ const string& config) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
+ const Shape& infeed_shape, const string& config) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape, config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
- const Shape& shape, HloInstruction* operand,
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) {
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
+ token_operand, outfeed_config);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
+ const Shape& outfeed_shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
- CHECK(ShapeUtil::Compatible(operand->shape(), shape))
- << "Outfeed shape " << shape << " must be compatible with operand shape "
- << operand->shape();
- instruction->AppendOperand(operand);
- instruction->outfeed_config_ = std::string(outfeed_config);
- instruction->outfeed_shape_ = shape;
- return instruction;
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand,
+ outfeed_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
- HloInstruction* operand, int64 channel_id) {
- // Send instruction produces a tuple of {aliased operand, U32 context}.
- Shape output_shape = ShapeUtil::MakeTupleShape(
- {operand->shape(), ShapeUtil::MakeShape(U32, {})});
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = channel_id;
- return instruction;
+ HloInstruction* operand, HloInstruction* token, int64 channel_id) {
+ return MakeUnique<HloSendInstruction>(operand, token, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kSend)
+ auto send_operand = DynCast<HloSendInstruction>(operand);
+ CHECK(send_operand != nullptr)
<< "SendDone must take the context operand from Send";
- auto instruction = WrapUnique(
- new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
+ return MakeUnique<HloSendDoneInstruction>(send_operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
- const Shape& shape, int64 channel_id) {
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
- Shape output_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
- instruction->channel_id_ = channel_id;
- return instruction;
+ const Shape& shape, HloInstruction* token, int64 channel_id) {
+ return MakeUnique<HloRecvInstruction>(shape, token, channel_id);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kRecv)
+ auto recv_operand = DynCast<HloRecvInstruction>(operand);
+ CHECK(recv_operand != nullptr)
<< "RecvDone must take the context operand from Recv";
- Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
+ return MakeUnique<HloRecvDoneInstruction>(recv_operand);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
+ return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ CHECK(!operands.empty());
+ auto instruction = WrapUnique(
+ new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
+ return WrapUnique(
+ new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
@@ -536,30 +745,15 @@ HloInstruction::CreateCrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
- instruction->AppendOperand(operand);
- instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
- instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
- instruction->slice_strides_.assign(strides.begin(), strides.end());
- // For backward compatibility with old serialized computations: if there are
- // no strides, assume all strides are 1.
- // TODO(b/63317920): remove this code.
- if (instruction->slice_strides_.empty()) {
- instruction->slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
- }
- return instruction;
+ return MakeUnique<HloSliceInstruction>(shape, operand, start_indices,
+ limit_indices, strides);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(start_indices);
- instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(),
- slice_sizes.end());
- return instruction;
+ return MakeUnique<HloDynamicSliceInstruction>(shape, operand, start_indices,
+ slice_sizes);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -578,13 +772,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
int64 dimension) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->dimensions_.push_back(dimension);
- return instruction;
+ return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
@@ -607,25 +795,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
- instruction->AppendOperand(arg);
- instruction->AppendOperand(init_value);
- instruction->dimensions_.assign(dimensions_to_reduce.begin(),
- dimensions_to_reduce.end());
- instruction->called_computations_.push_back(reduce_computation);
- return instruction;
+ return MakeUnique<HloReduceInstruction>(
+ shape, arg, init_value, dimensions_to_reduce, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(init_value);
- instruction->called_computations_.push_back(reduce_computation);
- instruction->window_ = MakeUnique<Window>(window);
- return instruction;
+ return MakeUnique<HloReduceWindowInstruction>(shape, operand, init_value,
+ window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -634,14 +812,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* scale,
HloInstruction* offset, float epsilon,
int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(offset);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormTrainingInstruction>(
+ shape, operand, scale, offset, epsilon, feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -649,16 +821,8 @@ HloInstruction::CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(offset);
- instruction->AppendOperand(mean);
- instruction->AppendOperand(variance);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormInferenceInstruction>(
+ shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -667,16 +831,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
HloInstruction* variance,
HloInstruction* grad_output, float epsilon,
int64 feature_index) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(scale);
- instruction->AppendOperand(mean);
- instruction->AppendOperand(variance);
- instruction->AppendOperand(grad_output);
- instruction->epsilon_ = epsilon;
- instruction->feature_index_ = feature_index;
- return instruction;
+ return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean,
+ variance, grad_output, epsilon,
+ feature_index);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -684,27 +841,15 @@ HloInstruction::CreateSelectAndScatter(
const Shape& shape, HloInstruction* operand, HloComputation* select,
const Window& window, HloInstruction* source, HloInstruction* init_value,
HloComputation* scatter) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(source);
- instruction->AppendOperand(init_value);
- // Select comes before scatter in the vector.
- instruction->called_computations_.push_back(select);
- instruction->called_computations_.push_back(scatter);
- instruction->window_ = MakeUnique<Window>(window);
- return instruction;
+ return MakeUnique<HloSelectAndScatterInstruction>(
+ shape, operand, select, window, source, init_value, scatter);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(broadcast_dimensions.begin(),
- broadcast_dimensions.end());
- return instruction;
+ return MakeUnique<HloBroadcastInstruction>(shape, operand,
+ broadcast_dimensions);
}
/* static */ std::unique_ptr<HloInstruction>
@@ -762,11 +907,8 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
const PaddingConfig& padding_config) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape));
- instruction->AppendOperand(operand);
- instruction->AppendOperand(padding_value);
- instruction->padding_config_ = MakeUnique<PaddingConfig>(padding_config);
- return instruction;
+ return MakeUnique<HloPadInstruction>(shape, operand, padding_value,
+ padding_config);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
@@ -783,53 +925,34 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
- CHECK_EQ(shape.dimensions().size(), dimensions.size());
- CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
- CHECK(std::equal(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end(),
- Permute(dimensions, shape.dimensions()).begin()))
- << "shape: " << ShapeUtil::HumanString(shape)
- << ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << Join(dimensions, ", ") << "}";
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
- instruction->AppendOperand(operand);
- instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
- return instruction;
+ return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
+ const Shape& shape, int64 dimension, HloInstruction* keys,
+ HloInstruction* values) {
+ return MakeUnique<HloSortInstruction>(shape, dimension, keys, values);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
- instruction->fusion_kind_ = fusion_kind;
- instruction->name_ = "fusion";
- instruction->set_parent(fused_root->parent());
- instruction->set_metadata(fused_root->metadata());
- instruction->CloneAndFuseInternal(fused_root);
- return instruction;
+ return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* fusion_computation) {
- auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->fusion_kind_ = fusion_kind;
- instruction->name_ = "fusion";
- instruction->called_computations_.push_back(fusion_computation);
- fusion_computation->SetFusionInstruction(instruction.get());
- return instruction;
+ return MakeUnique<HloFusionInstruction>(shape, fusion_kind, operands,
+ fusion_computation);
}
-void HloInstruction::set_device_sharding(int64 device) {
- HloSharding device_sharding = HloSharding::AssignDevice(device);
+void HloInstruction::set_single_sharding(const HloSharding& sharding) {
+ CHECK(!sharding.IsTuple()) << sharding;
if (ShapeUtil::IsTuple(shape())) {
- set_sharding(HloSharding::Tuple(device_sharding.GetAsShapeTree(shape())));
+ set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
} else {
- set_sharding(device_sharding);
+ set_sharding(sharding);
}
}
@@ -843,289 +966,6 @@ void HloInstruction::SetupDerivedInstruction(
derived_instruction->set_metadata(metadata_);
}
-HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
- CHECK_EQ(opcode(), HloOpcode::kFusion);
- CHECK_EQ(operand_count(),
- fused_instructions_computation()->parameter_instructions().size());
- const int64 param_no = operand_count();
- // Name the parameter after the instruction it represents in the outer
- // (non-fusion) computation.
- string param_name = StrCat(new_operand->name(), ".param_", param_no);
- HloInstruction* fused_parameter =
- fused_instructions_computation()->AddParameter(
- HloInstruction::CreateParameter(param_no, new_operand->shape(),
- param_name));
- AppendOperand(new_operand);
- return fused_parameter;
-}
-
-void HloInstruction::MergeFusionInstruction(
- HloInstruction* instruction_to_merge) {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
- CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
- operands().end());
- // Clone the instruction from which to merge fused instructions.
- std::unique_ptr<HloInstruction> clone = instruction_to_merge->Clone();
- // Replace uses of fused parameters with the corresponding operand of the
- // fusion. Add all non-parameter fused instructions to 'unfused_instructions'
- // to be merged into 'this'. This is done in reverse post order.
- std::vector<HloInstruction*> unfused_instructions;
- auto fused_instructions =
- clone->fused_instructions_computation()->MakeInstructionPostOrder();
- for (auto fused_it = fused_instructions.rbegin();
- fused_it != fused_instructions.rend(); ++fused_it) {
- auto fused_instruction = *fused_it;
- if (fused_instruction->opcode() == HloOpcode::kParameter) {
- TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith(
- clone->mutable_operand(fused_instruction->parameter_number())));
- } else {
- unfused_instructions.push_back(fused_instruction);
- }
- }
- CHECK(unfused_instructions.front() == clone->fused_expression_root());
- // Replace instruction_to_merge use of 'this' with unfused_root.
- TF_CHECK_OK(
- instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
- // Fuse 'unfused_instructions' into 'this'.
- for (auto& instruction : unfused_instructions) {
- FuseInstruction(instruction);
- instruction->DetachFromOperands();
- }
- CHECK_EQ(0, clone->user_count());
- clone->DetachFromOperands();
- TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
- clone->fused_instructions_computation()));
-}
-
-void HloInstruction::MergeFusionInstructionIntoMultiOutput(
- HloInstruction* instruction_to_merge) {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion);
- // Add all non-parameter fused instructions to 'unfused_instructions' to be
- // merged into 'this'. `old_to_new' maps the instructions in the fused node
- // to the disaseembled fusion instructions.
- // Note that we add the unfused instructions to this->parent_ computation.
- // This is necessary because the unique_id needs for an instruction and
- // it's only added when inserting to the computation.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
- std::vector<HloInstruction*> unfused_instructions;
- auto computation_to_merge =
- instruction_to_merge->fused_instructions_computation();
- auto post_order = computation_to_merge->MakeInstructionPostOrder();
- for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
- auto fused_instruction = *rit;
- if (fused_instruction->opcode() == HloOpcode::kParameter) {
- InsertOrDie(&old_to_new, fused_instruction,
- instruction_to_merge->mutable_operand(
- fused_instruction->parameter_number()));
- continue;
- }
-
- // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
- // which clones again. This can be improved.
- auto cloned_instruction =
- parent_->AddInstruction(fused_instruction->Clone());
- unfused_instructions.push_back(cloned_instruction);
- InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
- }
- for (auto unfused_instruction : unfused_instructions) {
- for (int64 index = 0; index < unfused_instruction->operand_count();
- index++) {
- auto new_operand =
- FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
- TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
- }
- }
-
- HloInstruction* unfused_root = unfused_instructions.front();
- TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
-
- TF_CHECK_OK(
- instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
- if (GetModule()) {
- TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
- }
-
- // Fuse the root instruction and generate multiple outputs.
- FuseInstructionIntoMultiOutput(unfused_root);
- TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
- // The rest instructions are of normal fusing.
- for (int64 i = 1; i < unfused_instructions.size(); i++) {
- auto instruction = unfused_instructions[i];
- FuseInstruction(instruction);
- TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
- }
-}
-
-HloInstruction* HloInstruction::FuseInstructionInternal(
- HloInstruction* instruction_to_fuse, bool add_output) {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
-
- // When add_output is false, this fusion instruction must be a user of
- // instruction_to_fuse.
- if (!add_output) {
- CHECK(IsUserOf(instruction_to_fuse));
- }
- HloInstruction* fused_instruction =
- CloneAndFuseInternal(instruction_to_fuse, add_output);
- return fused_instruction;
-}
-
-HloInstruction* HloInstruction::CloneAndFuseInternal(
- HloInstruction* instruction_to_fuse, bool add_output) {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
- VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
- HloInstruction* clone = nullptr;
- if (called_computations_.empty()) {
- // New fusion instruction. It should not be a multioutput instruction.
- CHECK(!add_output);
- auto builder = HloComputation::Builder("fused_computation", this);
- builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
- called_computations_.push_back(
- CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
- clone = fused_expression_root();
- } else {
- clone = fused_instructions_computation()->AddInstruction(
- instruction_to_fuse->Clone(/*suffix=*/""));
- // When add_output is false, instruction_to_fuse is necessarily an operand
- // of the fusion instruction. After fusion this will no longer be the case.
- // Remove the operand from the operand list and remove its corresponding
- // fused parameter instruction. Renumber parameters as necessary to make
- // parameter numbers consistent with their index in the
- // fused_parameter_ vector.
- bool in_operand_list = std::find(operands_.begin(), operands_.end(),
- instruction_to_fuse) != operands_.end();
- CHECK(add_output || in_operand_list);
- const std::vector<HloInstruction*>& fused_parameters =
- fused_instructions_computation()->parameter_instructions();
- for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
- if (instruction_to_fuse == operands_[operand_num]) {
- // replace the fused parameter instruction's uses with the clone.
- HloInstruction* fused_parameter = fused_parameters[operand_num];
- TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
-
- // Remove the corresponding fused parameter and operand from their
- // respective vectors.
- TF_CHECK_OK(
- fused_instructions_computation()->RemoveParameter(operand_num));
- operands_.erase(operands_.begin() + operand_num);
- break;
- }
- }
- // We've cloned instruction_to_fuse into this fusion instruction, so this
- // fusion instruction is no longer a use of instruction_to_fuse.
- if (in_operand_list) {
- instruction_to_fuse->RemoveUser(this);
- // When the instruction_to_fuse does not have other users, we don't need
- // to generate a multioutput fusion instruction.
- if (instruction_to_fuse->user_count() == 0) {
- add_output = false;
- }
- }
- }
-
- // Reread the parameters in the computation.
- const std::vector<HloInstruction*>& fused_parameters =
- fused_instructions_computation()->parameter_instructions();
-
- // Add each operand of the clone as an operand of the fusion instruction. A
- // complication is that some clone operands may already be operands of the
- // fusion instruction.
- for (int64 operand_num = 0; operand_num < clone->operand_count();
- ++operand_num) {
- HloInstruction* operand = clone->mutable_operand(operand_num);
-
- // See if this operand is already an operand of the fusion node.
- CHECK_EQ(operands_.size(), fused_parameters.size());
- HloInstruction* fused_param = nullptr;
- for (int64 i = 0; i < operands_.size(); ++i) {
- if (operands_[i] == operand) {
- fused_param = fused_parameters[i];
- break;
- }
- }
-
- if (fused_param == nullptr) {
- // Clone's operand was not already an operand of the fusion
- // instruction. Add it as an operand and add a corresponding fused
- // parameter instruction.
- fused_param = AddFusionOperand(operand);
- }
- TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
- }
-
- if (add_output) {
- CHECK_GT(instruction_to_fuse->user_count(), 0);
- // If this is already a multioutput fusion instruction, expand the root
- // tuple by 1.
- HloInstruction* fused_root = fused_expression_root();
- HloInstruction::InstructionVector tuple_elements;
- bool newly_created_tuple_instr = false;
- if (fused_root->opcode() == HloOpcode::kTuple) {
- tuple_elements = fused_root->operands();
- } else {
- tuple_elements.push_back(fused_root);
- newly_created_tuple_instr = true;
- }
- if (clone->opcode() == HloOpcode::kTuple) {
- for (auto inst : clone->operands()) {
- tuple_elements.push_back(inst);
- }
- } else {
- tuple_elements.push_back(clone);
- }
- HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
- HloInstruction::CreateTuple(tuple_elements));
- fused_instructions_computation()->set_root_instruction(new_root);
- shape_ = new_root->shape();
- if (fused_root->opcode() == HloOpcode::kTuple) {
- TF_CHECK_OK(
- fused_instructions_computation()->RemoveInstruction(fused_root));
- }
-
- // If this is a newly created multioutput instruction, we need to update
- // the use of the original fusion instruction.
- if (newly_created_tuple_instr) {
- HloInstruction* new_instr = parent_->AddInstruction(
- HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
- TF_CHECK_OK(ReplaceAllUsesWith(new_instr));
- }
- int64 index = tuple_elements.size();
- if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
- index -= instruction_to_fuse->operand_count();
- std::vector<HloInstruction*> to_be_removed;
- for (auto old_gte : instruction_to_fuse->users()) {
- CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
- int64 old_tuple_index = old_gte->tuple_index();
- HloInstruction* new_gte =
- parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
- old_gte->shape(), this, index + old_tuple_index));
- TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
- to_be_removed.push_back(old_gte);
- }
- for (auto old_gte : to_be_removed) {
- TF_CHECK_OK(parent_->RemoveInstruction(old_gte));
- }
- TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone));
- } else {
- HloInstruction* new_gte =
- parent_->AddInstruction(HloInstruction::CreateGetTupleElement(
- clone->shape(), this, index - 1));
- TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
- }
- }
-
- VLOG(2) << "New clone:\n" << clone->ToString();
- return clone;
-}
-
-RandomDistribution HloInstruction::random_distribution() const {
- CHECK_EQ(opcode_, HloOpcode::kRng);
- return distribution_;
-}
-
bool HloInstruction::HasSideEffectNoRecurse() const {
switch (opcode_) {
case HloOpcode::kSend:
@@ -1171,26 +1011,15 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->custom_call_target_ = std::string(custom_call_target);
- return instruction;
+ return MakeUnique<HloCustomCallInstruction>(shape, operands,
+ custom_call_target);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
- std::unique_ptr<HloInstruction> instruction =
- WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape));
- for (auto operand : operands) {
- instruction->AppendOperand(operand);
- }
- instruction->channel_name_ = std::string(channel_name);
- instruction->cost_estimate_ns_ = cost_estimate_ns;
- return instruction;
+ return MakeUnique<HloHostComputeInstruction>(shape, operands, channel_name,
+ cost_estimate_ns);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
@@ -1263,6 +1092,43 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
// in the face of code changes than copying fields explicitly. This also
// properly sets the user fields of the operands.
switch (opcode_) {
+ // Ops migrated to subclasses.
+ // TODO(b/80131774): Remove this switch when migration is complete.
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kFft:
+ case HloOpcode::kSend:
+ case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReverse:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReduce:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kMap:
+ case HloOpcode::kSlice:
+ case HloOpcode::kConstant:
+ case HloOpcode::kTrace:
+ case HloOpcode::kFusion:
+ case HloOpcode::kRng:
+ case HloOpcode::kParameter:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kHostCompute:
+ case HloOpcode::kPad:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kSort:
+ clone = CloneWithNewOperandsImpl(shape, new_operands, context);
+ break;
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -1283,7 +1149,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh:
CHECK_EQ(new_operands.size(), 1);
clone = CreateUnary(shape, opcode_, new_operands[0]);
@@ -1307,6 +1172,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
@@ -1316,36 +1182,15 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
CHECK_EQ(new_operands.size(), 3);
clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
new_operands[2]);
break;
// Other supported ops.
- case HloOpcode::kBroadcast:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateBroadcast(shape, new_operands[0], dimensions_);
- break;
case HloOpcode::kCall:
clone = CreateCall(shape, new_operands, to_apply());
break;
- case HloOpcode::kCustomCall:
- clone = CreateCustomCall(shape, new_operands, custom_call_target_);
- if (window_ != nullptr) {
- clone->window_ = MakeUnique<Window>(*window_);
- }
- if (convolution_dimension_numbers_ != nullptr) {
- clone->convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(
- *convolution_dimension_numbers_);
- }
- break;
- case HloOpcode::kHostCompute:
- clone = CreateHostCompute(shape, new_operands, channel_name_,
- cost_estimate_ns_);
- break;
- case HloOpcode::kConcatenate:
- clone = CreateConcatenate(shape, new_operands, dimensions(0));
- break;
case HloOpcode::kConvert:
CHECK_EQ(new_operands.size(), 1);
clone = CreateConvert(shape, new_operands[0]);
@@ -1354,85 +1199,20 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateBitcastConvert(shape, new_operands[0]);
break;
- case HloOpcode::kReducePrecision:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_,
- mantissa_bits_);
- break;
- case HloOpcode::kConvolution:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
- *convolution_dimension_numbers_);
- break;
case HloOpcode::kDot:
CHECK_EQ(new_operands.size(), 2);
clone = CreateDot(shape, new_operands[0], new_operands[1],
*dot_dimension_numbers_);
break;
- case HloOpcode::kFft:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
- break;
- case HloOpcode::kCrossReplicaSum:
- clone = CreateCrossReplicaSum(shape, new_operands);
- break;
- case HloOpcode::kGetTupleElement:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateGetTupleElement(shape, new_operands[0], tuple_index());
- break;
- case HloOpcode::kMap:
- clone = CreateMap(shape, new_operands, to_apply());
- break;
- case HloOpcode::kPad:
- CHECK_EQ(new_operands.size(), 2);
- clone =
- CreatePad(shape, new_operands[0], new_operands[1], *padding_config_);
- break;
- case HloOpcode::kReduce:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_,
- to_apply());
- break;
- case HloOpcode::kReduceWindow:
- CHECK_EQ(new_operands.size(), 2);
- clone = CreateReduceWindow(shape, new_operands[0], new_operands[1],
- *window_, to_apply());
- break;
- case HloOpcode::kSelectAndScatter:
- CHECK_EQ(new_operands.size(), 3);
- clone =
- CreateSelectAndScatter(shape, new_operands[0], select(), *window_,
- new_operands[1], new_operands[2], scatter());
- break;
- case HloOpcode::kReverse:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateReverse(shape, new_operands[0], dimensions_);
- break;
- case HloOpcode::kRng:
- clone = CreateRng(shape, distribution_, new_operands);
- break;
case HloOpcode::kReshape:
CHECK_EQ(new_operands.size(), 1);
clone = CreateReshape(shape, new_operands[0]);
break;
- case HloOpcode::kSlice:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_,
- slice_strides_);
- break;
- case HloOpcode::kDynamicSlice:
- clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1],
- dynamic_slice_sizes_);
- break;
case HloOpcode::kDynamicUpdateSlice:
CHECK_EQ(new_operands.size(), 3);
clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
new_operands[2]);
break;
- case HloOpcode::kTranspose:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateTranspose(shape, new_operands[0], dimensions_);
- break;
case HloOpcode::kTuple:
clone = CreateTuple(new_operands);
*clone->mutable_shape() = shape;
@@ -1442,78 +1222,12 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone =
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
break;
- case HloOpcode::kConstant:
- clone = CreateConstant(literal_->CloneToUnique());
- break;
- case HloOpcode::kFusion: {
- HloModule* module = context != nullptr ? context->module() : GetModule();
- HloComputation* new_fused_computation = nullptr;
- if (context != nullptr) {
- new_fused_computation =
- context->FindComputation(fused_instructions_computation());
- }
- if (new_fused_computation == nullptr) {
- new_fused_computation = module->AddEmbeddedComputation(
- fused_instructions_computation()->Clone("clone", context));
- }
- clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(),
- /*operands=*/new_operands,
- /*fusion_computation=*/new_fused_computation);
- break;
- }
- case HloOpcode::kParameter:
- clone = CreateParameter(parameter_number_, shape, name_);
- break;
- case HloOpcode::kBatchNormTraining:
- CHECK_EQ(new_operands.size(), 3);
- clone =
- CreateBatchNormTraining(shape, new_operands[0], new_operands[1],
- new_operands[2], epsilon(), feature_index());
- break;
- case HloOpcode::kBatchNormInference:
- CHECK_EQ(new_operands.size(), 5);
- clone = CreateBatchNormInference(
- shape, new_operands[0], new_operands[1], new_operands[2],
- new_operands[3], new_operands[4], epsilon(), feature_index());
- break;
- case HloOpcode::kInfeed:
- CHECK_EQ(new_operands.size(), 0);
- clone = CreateInfeed(shape, infeed_config());
- break;
- case HloOpcode::kOutfeed:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config());
- break;
- case HloOpcode::kBatchNormGrad:
- CHECK_EQ(new_operands.size(), 5);
- clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1],
- new_operands[2], new_operands[3],
- new_operands[4], epsilon(), feature_index());
- break;
case HloOpcode::kConditional:
CHECK_EQ(new_operands.size(), 3);
clone = CreateConditional(shape, new_operands[0], new_operands[1],
true_computation(), new_operands[2],
false_computation());
break;
- case HloOpcode::kSend:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSend(new_operands[0], channel_id());
- break;
- case HloOpcode::kSendDone:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateSendDone(new_operands[0]);
- break;
- case HloOpcode::kRecv:
- CHECK_EQ(new_operands.size(), 0);
- // The shape is a tuple, but CreateRecv() wants the raw data shape.
- clone =
- CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id());
- break;
- case HloOpcode::kRecvDone:
- CHECK_EQ(new_operands.size(), 1);
- clone = CreateRecvDone(new_operands[0]);
- break;
case HloOpcode::kGather:
CHECK_EQ(new_operands.size(), 2);
clone = CreateGather(shape, new_operands[0], new_operands[1],
@@ -1525,8 +1239,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
break;
- case HloOpcode::kTrace:
- LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
+ case HloOpcode::kAfterAll:
+ if (new_operands.empty()) {
+ clone = CreateToken();
+ } else {
+ clone = CreateAfterAll(new_operands);
+ }
+ break;
}
SetupDerivedInstruction(clone.get());
clone->set_parent(parent_);
@@ -1542,7 +1261,29 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
return clone;
}
-HloInstruction::~HloInstruction() {}
+HloInstruction::~HloInstruction() {
+ // Detach from operands. An instruction may be repeated as an operand. To
+ // avoid calling RemoveUser twice on the same operand, check before remove.
+ for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
+ HloInstruction* operand = operands_[operand_num];
+ if (operand == nullptr) {
+ continue;
+ }
+ if (operand->user_set_.find(this) != operand->user_set_.end()) {
+ operand->RemoveUser(this);
+ }
+ operands_[operand_num] = nullptr;
+ }
+
+ // Update users. Set `nullptr` to the correpsonding operand slot for users.
+ for (auto& user : this->users()) {
+ for (int i = 0; i < user->operand_count(); ++i) {
+ if (user->operands_[i] == this) {
+ user->operands_[i] = nullptr;
+ }
+ }
+ }
+}
std::unique_ptr<HloInstruction> HloInstruction::Clone(
const string& suffix, HloCloneContext* context) const {
@@ -1607,40 +1348,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
return hlo;
}
-const Literal& HloInstruction::literal() const {
- CHECK_EQ(HloOpcode::kConstant, opcode_);
- return *literal_;
-}
-
-bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }
-
-bool HloInstruction::CanHaveDimensionsField() const {
- return (opcode() == HloOpcode::kReverse ||
- opcode() == HloOpcode::kConcatenate ||
- opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
- opcode() == HloOpcode::kTranspose);
-}
-
-const std::vector<int64>& HloInstruction::dimensions() const {
- CHECK(CanHaveDimensionsField());
- return dimensions_;
-}
-
-int64 HloInstruction::dimensions(int64 index) const {
- return dimensions()[index];
-}
-
-int64 HloInstruction::concatenate_dimension() const {
- CHECK(opcode() == HloOpcode::kConcatenate);
- CHECK_EQ(1, dimensions_.size());
- return dimensions(0);
-}
-
-int64 HloInstruction::tuple_index() const {
- CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
- return tuple_index_;
-}
-
const HloInstruction* HloInstruction::operand(int64 i) const {
return operands_[i];
}
@@ -1722,6 +1429,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
operand->AddUser(this);
}
+void HloInstruction::RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ if (ascending_indices.empty()) {
+ return;
+ }
+ int next_index = 0;
+ int removed_count = 0;
+ for (int to_remove : ascending_indices) {
+ while (next_index < to_remove) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_LT(to_remove, operands_.size());
+ ++removed_count;
+ ++next_index;
+ }
+ while (next_index < operands_.size()) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_EQ(removed_count, ascending_indices.size());
+ operands_.resize(operands_.size() - removed_count);
+}
+
void HloInstruction::AddUser(HloInstruction* user) {
if (!ContainsKey(user_set_, user)) {
user_set_.insert(user);
@@ -1729,10 +1460,6 @@ void HloInstruction::AddUser(HloInstruction* user) {
}
}
-bool HloInstruction::IsConstant() const {
- return opcode_ == HloOpcode::kConstant;
-}
-
bool HloInstruction::HasConstantOperand() const {
for (const HloInstruction* operand : operands_) {
if (operand->IsConstant()) {
@@ -1762,9 +1489,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
- case HloOpcode::kCrossReplicaSum:
case HloOpcode::kDivide:
- case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kEq:
case HloOpcode::kExp:
@@ -1780,6 +1505,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kAnd:
case HloOpcode::kNot:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kLt:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -1800,50 +1526,14 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
return true;
- // Broadcast, Concatenate, and Transpose need the same dimensions field.
- case HloOpcode::kBroadcast:
- case HloOpcode::kConcatenate:
- case HloOpcode::kTranspose:
- return dimensions() == other.dimensions();
-
- case HloOpcode::kFusion:
- return fusion_kind() == other.fusion_kind() &&
- eq_computations(fused_instructions_computation(),
- other.fused_instructions_computation());
-
// These opcodes have complex or special behavior so just return false.
- case HloOpcode::kDomain:
- case HloOpcode::kRng:
- case HloOpcode::kTrace:
case HloOpcode::kWhile:
+ case HloOpcode::kAfterAll:
return false;
- case HloOpcode::kParameter:
- return parameter_number() == other.parameter_number();
-
- case HloOpcode::kBatchNormTraining:
- case HloOpcode::kBatchNormInference:
- case HloOpcode::kBatchNormGrad:
- return feature_index() == other.feature_index() &&
- epsilon() == other.epsilon();
-
- // A constant is defined by the value in the literal.
- case HloOpcode::kConstant:
- return literal() == other.literal();
-
- // A reduce-precision operation is determined by the bit sizes.
- case HloOpcode::kReducePrecision:
- return exponent_bits() == other.exponent_bits() &&
- mantissa_bits() == other.mantissa_bits();
-
- // Convolution has a window and dimensions.
- case HloOpcode::kConvolution:
- return protobuf_util::ProtobufEquals(window(), other.window()) &&
- protobuf_util::ProtobufEquals(
- convolution_dimension_numbers(),
- other.convolution_dimension_numbers());
// Check dot dimension numbers.
case HloOpcode::kDot:
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
@@ -1854,83 +1544,57 @@ bool HloInstruction::IdenticalSlowPath(
other.gather_dimension_numbers()) &&
gather_window_bounds() == other.gather_window_bounds();
- // FFT has various types & lengths.
- case HloOpcode::kFft:
- return fft_type() == other.fft_type() &&
- fft_length() == other.fft_length();
-
- // Reduction results are determined by the reduction dimension and the
- // reduction computation.
- case HloOpcode::kReduce:
- return dimensions() == other.dimensions() &&
- eq_computations(to_apply(), other.to_apply());
- case HloOpcode::kReduceWindow:
- return eq_computations(to_apply(), other.to_apply()) &&
- protobuf_util::ProtobufEquals(window(), other.window());
-
- // SelectAndScatter is determined by both select and scatter
- // computation as well as the window configuration.
- case HloOpcode::kSelectAndScatter:
- return eq_computations(select(), other.select()) &&
- eq_computations(scatter(), other.scatter()) &&
- protobuf_util::ProtobufEquals(window(), other.window());
-
-
// Remaining instructions with special values.
- case HloOpcode::kGetTupleElement:
- return tuple_index() == other.tuple_index();
- case HloOpcode::kPad:
- return protobuf_util::ProtobufEquals(padding_config(),
- other.padding_config());
- case HloOpcode::kSlice:
- return slice_starts_ == other.slice_starts_ &&
- slice_limits_ == other.slice_limits_ &&
- slice_strides_ == other.slice_strides_;
case HloOpcode::kCall:
- case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
- case HloOpcode::kCustomCall:
- if ((window_ == nullptr) != (other.window_ == nullptr) ||
- (window_ != nullptr &&
- !protobuf_util::ProtobufEquals(window(), other.window()))) {
- return false;
- }
- if ((convolution_dimension_numbers_ == nullptr) !=
- (other.convolution_dimension_numbers_ == nullptr) ||
- (convolution_dimension_numbers_ != nullptr &&
- !protobuf_util::ProtobufEquals(
- convolution_dimension_numbers(),
- other.convolution_dimension_numbers()))) {
- return false;
- }
- return custom_call_target_ == other.custom_call_target_;
- case HloOpcode::kReverse:
- return dimensions() == other.dimensions();
case HloOpcode::kConditional:
return eq_computations(true_computation(), other.true_computation()) &&
eq_computations(false_computation(), other.false_computation());
- // These opcodes are not yet supported.
- case HloOpcode::kInfeed:
- case HloOpcode::kOutfeed:
- case HloOpcode::kSort:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
+ case HloOpcode::kDomain:
+ return operand_side_metadata().Matches(other.operand_side_metadata()) &&
+ user_side_metadata().Matches(other.user_side_metadata());
+
+ // Ops migrated to subclasses should never come to this line.
+ // TODO(b/80131774): Remove this switch when migration is complete.
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
+ case HloOpcode::kBatchNormGrad:
+ case HloOpcode::kFft:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
+ case HloOpcode::kRecvDone:
+ case HloOpcode::kReverse:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReduce:
+ case HloOpcode::kSort:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kMap:
+ case HloOpcode::kSlice:
+ case HloOpcode::kConstant:
+ case HloOpcode::kTrace:
+ case HloOpcode::kFusion:
+ case HloOpcode::kRng:
+ case HloOpcode::kParameter:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kOutfeed:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
case HloOpcode::kHostCompute:
- return false;
+ case HloOpcode::kPad:
+ case HloOpcode::kDynamicSlice:
+ LOG(FATAL) << "Base class impl called for opcode with subclass: "
+ << opcode();
}
}
-bool HloInstruction::IsRank2Transpose() const {
- return (opcode_ == HloOpcode::kTranspose) &&
- dimensions_ == std::vector<int64>({1, 0}) &&
- shape_.dimensions_size() == 2 &&
- std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
- operands_[0]->shape_.dimensions().rbegin());
-}
-
void HloInstruction::RemoveUser(HloInstruction* user) {
auto set_it = user_set_.find(user);
CHECK(set_it != user_set_.end());
@@ -1960,6 +1624,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
return Status::OK();
}
@@ -1968,10 +1636,14 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
TF_RET_CHECK(operand_num >= 0);
TF_RET_CHECK(operand_num < operand_count());
HloInstruction* old_operand = mutable_operand(operand_num);
+ if (old_operand == new_operand) {
+ return Status::OK();
+ }
+
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
new_operand->shape()))
- << old_operand->shape().ShortDebugString() << " is not compatible with "
- << new_operand->shape().ShortDebugString();
+ << old_operand->shape() << " is not compatible with "
+ << new_operand->shape();
operands_[operand_num] = new_operand;
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
@@ -1998,6 +1670,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
}
}
users_.clear();
@@ -2012,28 +1688,13 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
return Status::OK();
}
-void HloInstruction::DetachFromOperands() {
- VLOG(3) << "DetachFromOperands:\n " << ToString();
- CHECK_EQ(0, user_count());
- // An instruction may be repeated as an operand. To avoid calling RemoveUser
- // twice on the same operand, keep a set of already detached operands.
- std::set<HloInstruction*> detached_operands;
- for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
- HloInstruction* operand = operands_[operand_num];
- if (!ContainsKey(detached_operands, operand)) {
- operand->RemoveUser(this);
- detached_operands.insert(operand);
- }
- operands_[operand_num] = nullptr;
- }
-}
-
HloComputation* HloInstruction::to_apply() const {
switch (opcode_) {
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -2051,6 +1712,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@@ -2060,16 +1722,6 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
}
}
-const string& HloInstruction::custom_call_target() const {
- CHECK_EQ(opcode_, HloOpcode::kCustomCall);
- return custom_call_target_;
-}
-
-const string& HloInstruction::outfeed_config() const {
- CHECK_EQ(opcode_, HloOpcode::kOutfeed);
- return outfeed_config_;
-}
-
HloComputation* HloInstruction::while_condition() const {
CHECK_EQ(HloOpcode::kWhile, opcode_);
return called_computations_[kConditionComputationIndex];
@@ -2096,32 +1748,6 @@ void HloInstruction::set_while_body(HloComputation* computation) {
called_computations_[kBodyComputationIndex] = computation;
}
-HloComputation* HloInstruction::select() const {
- CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return called_computations_[kSelectComputationIndex];
-}
-
-HloComputation* HloInstruction::scatter() const {
- CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- return called_computations_[kScatterComputationIndex];
-}
-
-void HloInstruction::set_select(HloComputation* computation) {
- // Don't allow changing the computation for fused instructions so we don't
- // have to recompute called_instructions for the entire fusion instruction.
- CHECK(!IsFused());
- CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- called_computations_[kSelectComputationIndex] = computation;
-}
-
-void HloInstruction::set_scatter(HloComputation* computation) {
- // Don't allow changing the computation for fused instructions so we don't
- // have to recompute called_instructions for the entire fusion instruction.
- CHECK(!IsFused());
- CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
- called_computations_[kScatterComputationIndex] = computation;
-}
-
HloComputation* HloInstruction::true_computation() const {
CHECK_EQ(HloOpcode::kConditional, opcode_);
return called_computations_[kTrueComputationIndex];
@@ -2169,6 +1795,74 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
return ToStringWithCanonicalNameMap(options, &new_map);
}
+bool HloInstruction::IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const {
+ switch (opcode_) {
+ // Unary elementwise operations.
+ case HloOpcode::kAbs:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClz:
+ case HloOpcode::kConvert:
+ case HloOpcode::kBitcastConvert:
+ case HloOpcode::kCopy:
+ case HloOpcode::kCos:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kFloor:
+ case HloOpcode::kImag:
+ case HloOpcode::kIsFinite:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kNot:
+ case HloOpcode::kNegate:
+ case HloOpcode::kReal:
+ case HloOpcode::kReducePrecision:
+ case HloOpcode::kSign:
+ case HloOpcode::kSin:
+ case HloOpcode::kTanh:
+ CHECK_EQ(1, operand_count());
+ return true;
+
+ // Binary elementwise operations, the same as in IsElementwiseBinary().
+ case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kComplex:
+ case HloOpcode::kDivide:
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kAnd:
+ case HloOpcode::kOr:
+ case HloOpcode::kXor:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ CHECK_EQ(2, operand_count());
+ return true;
+
+ // Ternary elementwise operations.
+ case HloOpcode::kSelect:
+ case HloOpcode::kClamp:
+ return true;
+
+ case HloOpcode::kDynamicUpdateSlice:
+ return operand_idx.has_value() && operand_idx.value() == 0;
+
+ default:
+ return false;
+ }
+}
+
string HloInstruction::ToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
@@ -2219,112 +1913,45 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const {
string operands;
- if (opcode() == HloOpcode::kConstant) {
- // For constants, show the actual value in place of an empty operand list.
- //
- // In HloInstruction, sometimes a constant literal is not constructed due
- // to its size. Skip the printing in this case.
- if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) &&
- ShapeUtil::ElementsIn(shape()) <= 10) ||
- options.print_large_constants())) {
- // Literal::ToString emits multidimensional arrays over multiple
- // lines. Compact this into one line by stripping out white space.
- string tmp = literal().ToString();
- std::replace(tmp.begin(), tmp.end(), '\n', ' ');
- std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
- bool first = true;
- // Concatenate elements in "v" with spaces separating them, but ignoring
- // empty entries.
- for (const auto& s : v) {
- if (s.empty()) {
- continue;
- }
- StrAppend(&operands, (first ? "" : " "), s);
- first = false;
- }
- } else {
- // Do not show large constants or tuples.
- operands = "{...}";
+ tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
+ const int64 kMaxOperandsToShowIfCompact = 4;
+ if (options.compact_operands() &&
+ slice.size() > kMaxOperandsToShowIfCompact) {
+ slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
+ }
+ operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
+ // If operand is already been deleted, put `null` to the string output.
+ if (operand == nullptr) {
+ StrAppend(out, "null ");
+ return;
}
- } else if (opcode() == HloOpcode::kParameter) {
- StrAppend(&operands, parameter_number_);
- } else {
- tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
- const int64 kMaxOperandsToShowIfCompact = 4;
- if (options.compact_operands() &&
- slice.size() > kMaxOperandsToShowIfCompact) {
- slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
+ std::vector<string> str;
+ if (options.print_operand_shape()) {
+ str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
}
- operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) {
- std::vector<string> str;
- if (options.print_operand_shape()) {
- str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
- }
- // In a top-level HloInstruction::ToString() call, the operand name is not
- // part of the canonical string.
- if (options.canonicalize_instruction_names() &&
- options.is_in_nested_computation()) {
- str.push_back(PrintName(
- canonical_name_map->LookupOrInsert(operand->name()), options));
- } else if (!options.compact_operands()) {
- str.push_back(PrintName(operand->name(), options));
- }
- StrAppend(out, Join(str, " "));
- });
- const int64 remaining = operands_.size() - slice.size();
- if (slice.size() != operands_.size()) {
- StrAppend(&operands, ", ...(+", remaining, ")");
+ // In a top-level HloInstruction::ToString() call, the operand name is not
+ // part of the canonical string.
+ if (options.canonicalize_instruction_names() &&
+ options.is_in_nested_computation()) {
+ str.push_back(PrintName(
+ canonical_name_map->LookupOrInsert(operand->name()), options));
+ } else if (!options.compact_operands()) {
+ str.push_back(PrintName(operand->name(), options));
}
+ StrAppend(out, Join(str, " "));
+ });
+ const int64 remaining = operands_.size() - slice.size();
+ if (slice.size() != operands_.size()) {
+ StrAppend(&operands, ", ...(+", remaining, ")");
}
return operands;
}
std::vector<string> HloInstruction::ExtraAttributesToString(
const HloPrintOptions& options) const {
- std::vector<string> extra;
- if (opcode() == HloOpcode::kFusion) {
- extra.push_back(StrCat("kind=", xla::ToString(fusion_kind())));
- }
- if (CanHaveDimensionsField()) {
- extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
- }
- if (window_ != nullptr && window_->dimensions_size() != 0) {
- extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
- }
- if (padding_config_ != nullptr) {
- extra.push_back(
- StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
- }
- if (opcode() == HloOpcode::kSlice) {
- std::vector<string> bounds;
- bounds.reserve(slice_starts_.size());
- const bool omit_stride =
- std::all_of(slice_strides_.begin(), slice_strides_.end(),
- [](int64 stride) { return stride == 1; });
- for (int i = 0; i < slice_starts_.size(); ++i) {
- string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
- bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i],
- stride_str, "]"));
- }
- extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
- }
- if (opcode() == HloOpcode::kDynamicSlice) {
- extra.push_back(
- StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
- }
- if (opcode() == HloOpcode::kBatchNormTraining ||
- opcode() == HloOpcode::kBatchNormInference ||
- opcode() == HloOpcode::kBatchNormGrad) {
- extra.push_back(StrCat("epsilon=", epsilon()));
- extra.push_back(StrCat("feature_index=", feature_index()));
- }
+ std::vector<string> extra = ExtraAttributesToStringImpl(options);
- if (convolution_dimension_numbers_ != nullptr) {
- extra.push_back(StrCat(
- "dim_labels=",
- ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
- }
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
@@ -2333,10 +1960,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
extra.push_back(
StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
}
- if (opcode() == HloOpcode::kFft) {
- extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
- extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
- }
if (options.print_subcomputation_mode() ==
HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
@@ -2356,7 +1979,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
PrintName(false_computation()->name(), options)));
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
- opcode() == HloOpcode::kReduce) {
+ opcode() == HloOpcode::kReduce ||
+ opcode() == HloOpcode::kCrossReplicaSum) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
@@ -2391,6 +2015,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
@@ -2406,14 +2031,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
break;
}
}
- if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
- opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
- extra.push_back(StrCat("channel_id=", channel_id_));
- }
- if (opcode() == HloOpcode::kGetTupleElement) {
- extra.push_back(StrCat("index=", tuple_index()));
- }
if (has_sharding()) {
extra.push_back(StrCat("sharding=", sharding().ToString()));
}
@@ -2426,34 +2044,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
}),
"}"));
}
- if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) {
- extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\""));
- }
- if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) {
- extra.push_back(
- StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\""));
- }
- if (opcode() == HloOpcode::kRng) {
- extra.push_back(
- StrCat("distribution=", RandomDistributionToString(distribution_)));
- }
- if (opcode() == HloOpcode::kReducePrecision) {
- extra.push_back(StrCat("exponent_bits=", exponent_bits_));
- extra.push_back(StrCat("mantissa_bits=", mantissa_bits_));
- }
- if (operand_side_metadata_ != nullptr) {
- extra.push_back(
- StrCat("operand_side=", operand_side_metadata_->ToString()));
- }
- if (user_side_metadata_ != nullptr) {
- extra.push_back(StrCat("user_side=", user_side_metadata_->ToString()));
- }
- // By contract, we print the custom call target even if
- // options.print_subcomputation_mode() == kOff, because the call target is not
- // an HloComputation.
- if (opcode() == HloOpcode::kCustomCall) {
- extra.push_back(
- StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) {
+ extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(),
+ "\", entry=", user_side_metadata_->ToString(),
+ ", exit=", operand_side_metadata_->ToString(), "}"));
}
return extra;
@@ -2486,31 +2080,12 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_metadata() = metadata_;
proto.set_backend_config(backend_config_);
- if (literal_ != nullptr) {
- *proto.mutable_literal() = literal_->ToProto();
- }
- proto.set_parameter_number(parameter_number_);
- if (opcode() == HloOpcode::kFusion) {
- proto.set_fusion_kind(xla::ToString(fusion_kind()));
- proto.add_called_computation_ids(
- fused_instructions_computation()->unique_id());
- } else {
+ if (opcode() != HloOpcode::kFusion) {
for (const HloComputation* computation : called_computations_) {
proto.add_called_computation_ids(computation->unique_id());
}
}
- proto.set_tuple_index(tuple_index_);
- for (int64 dimension : dimensions_) {
- proto.add_dimensions(dimension);
- }
- if (window_ != nullptr) {
- *proto.mutable_window() = *window_;
- }
- if (convolution_dimension_numbers_ != nullptr) {
- *proto.mutable_convolution_dimension_numbers() =
- *convolution_dimension_numbers_;
- }
if (dot_dimension_numbers_ != nullptr) {
*proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
}
@@ -2522,42 +2097,11 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.add_gather_window_bounds(bound);
}
}
- for (int i = 0; i < slice_starts_.size(); ++i) {
- auto* slice_dimension = proto.add_slice_dimensions();
- slice_dimension->set_start(slice_starts_[i]);
- slice_dimension->set_limit(slice_limits_[i]);
- slice_dimension->set_stride(slice_strides_[i]);
- }
- proto.set_exponent_bits(exponent_bits_);
- proto.set_mantissa_bits(mantissa_bits_);
- for (int64 slice_size : dynamic_slice_sizes_) {
- proto.add_dynamic_slice_sizes(slice_size);
- }
- if (padding_config_ != nullptr) {
- *proto.mutable_padding_config() = *padding_config_;
- }
- proto.set_outfeed_config(outfeed_config_);
- if (opcode() == HloOpcode::kRng) {
- proto.set_distribution(distribution_);
- }
- proto.set_epsilon(epsilon_);
- proto.set_feature_index(feature_index_);
- proto.set_channel_id(channel_id_);
- proto.set_infeed_config(infeed_config_);
- proto.set_custom_call_target(custom_call_target_);
- *proto.mutable_outfeed_shape() = outfeed_shape_;
- proto.set_fft_type(fft_type_);
- for (int64 fft_len : fft_length_) {
- proto.add_fft_length(fft_len);
- }
if (has_sharding()) {
*proto.mutable_sharding() = sharding().ToProto();
}
- proto.set_channel_name(channel_name_);
- proto.set_cost_estimate_ns(cost_estimate_ns_);
-
return proto;
}
@@ -2567,35 +2111,6 @@ string HloInstruction::ToCategory() const {
return "data formatting";
}
- if (opcode() == HloOpcode::kConvolution) {
- string category = "convolution";
- if (window_util::HasBaseDilation(window())) {
- category += " base-dilated";
- }
- if (window_util::HasWindowDilation(window())) {
- category += " window-dilated";
- }
- return category;
- }
-
- // Give transpose-dot and backwards-conv fusions the categories "dot" and
- // "convolution" so they match the categories of proper kDot and kConvolution
- // ops. These fusion categories are really just a way of expressing a
- // particular kind of dot or conv, so they should have the same category as a
- // vanilla dot/conv.
- if (opcode() == HloOpcode::kFusion) {
- switch (fusion_kind()) {
- case FusionKind::kLoop:
- return "loop fusion";
- case FusionKind::kInput:
- return "input fusion";
- case FusionKind::kOutput:
- return "output fusion";
- case FusionKind::kCustom:
- return "custom fusion";
- }
- }
-
if (IsElementwise()) {
return "non-fusion elementwise";
}
@@ -2609,12 +2124,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
trace_instruction_ = trace_instruction;
}
-string HloInstruction::TracingTag() const {
- CHECK_EQ(HloOpcode::kTrace, opcode());
- CHECK(literal_ != nullptr);
- return literal_->GetR1U8AsString();
-}
-
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
bool HloInstruction::IsFusable() const {
@@ -2633,51 +2142,6 @@ bool HloInstruction::IsFusable() const {
}
}
-HloComputation* HloInstruction::fused_instructions_computation() const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- CHECK(!called_computations_.empty());
- auto* fused_instructions_computation = called_computations_.front();
- CHECK(fused_instructions_computation->IsFusionComputation())
- << "Computation " << fused_instructions_computation->name()
- << " is not a fusion kind";
- return fused_instructions_computation;
-}
-
-HloInstruction* HloInstruction::fused_expression_root() const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_computation()->root_instruction();
-}
-
-HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_computation()->parameter_instruction(
- parameter_number);
-}
-
-const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_computation()->parameter_instructions();
-}
-
-const tensorflow::gtl::iterator_range<UnwrappingIterator<
- std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
-HloInstruction::fused_instructions() const {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- const HloComputation* subcomp = fused_instructions_computation();
- return subcomp->instructions();
-}
-
-const tensorflow::gtl::iterator_range<
- UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
-HloInstruction::fused_instructions() {
- CHECK_EQ(opcode_, HloOpcode::kFusion);
- return fused_instructions_computation()->instructions();
-}
-
-int64 HloInstruction::fused_instruction_count() const {
- return fused_instructions_computation()->instruction_count();
-}
-
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
: unique_id_(-1),
opcode_(opcode),
@@ -2732,6 +2196,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleAnd(this);
case HloOpcode::kOr:
return visitor->HandleOr(this);
+ case HloOpcode::kXor:
+ return visitor->HandleXor(this);
case HloOpcode::kShiftLeft:
return visitor->HandleShiftLeft(this);
case HloOpcode::kShiftRightArithmetic:
@@ -2756,6 +2222,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleRemainder(this);
case HloOpcode::kSelect:
return visitor->HandleSelect(this);
+ case HloOpcode::kTupleSelect:
+ return visitor->HandleTupleSelect(this);
case HloOpcode::kConvolution:
return visitor->HandleConvolution(this);
case HloOpcode::kFft:
@@ -2856,6 +2324,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGather(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
+ case HloOpcode::kAfterAll:
+ return visitor->HandleAfterAll(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -3096,12 +2566,6 @@ Status HloInstruction::AcceptOrdered(
return visitor->FinishVisit(this);
}
-const Shape& HloInstruction::outfeed_shape() const {
- DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
- return outfeed_shape_;
-}
-
const Shape& HloInstruction::shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
return shape_;
@@ -3123,87 +2587,7 @@ bool HloInstruction::IsElementwiseBinary() const {
}
bool HloInstruction::IsElementwise() const {
- switch (opcode_) {
- // Nullary elementwise operations.
- case HloOpcode::kConstant:
- return true;
-
- // Unary elementwise operations.
- case HloOpcode::kAbs:
- case HloOpcode::kRoundNearestAfz:
- case HloOpcode::kCeil:
- case HloOpcode::kClz:
- case HloOpcode::kConvert:
- case HloOpcode::kBitcastConvert:
- case HloOpcode::kCopy:
- case HloOpcode::kCos:
- case HloOpcode::kExp:
- case HloOpcode::kExpm1:
- case HloOpcode::kFloor:
- case HloOpcode::kImag:
- case HloOpcode::kIsFinite:
- case HloOpcode::kLog:
- case HloOpcode::kLog1p:
- case HloOpcode::kNot:
- case HloOpcode::kNegate:
- case HloOpcode::kReal:
- case HloOpcode::kReducePrecision:
- case HloOpcode::kSign:
- case HloOpcode::kSin:
- case HloOpcode::kTanh:
- CHECK_EQ(1, operand_count());
- return true;
-
- // Binary elementwise operations, the same as in IsElementwiseBinary().
- case HloOpcode::kAdd:
- case HloOpcode::kAtan2:
- case HloOpcode::kComplex:
- case HloOpcode::kDivide:
- case HloOpcode::kEq:
- case HloOpcode::kGe:
- case HloOpcode::kGt:
- case HloOpcode::kLe:
- case HloOpcode::kLt:
- case HloOpcode::kMaximum:
- case HloOpcode::kMinimum:
- case HloOpcode::kMultiply:
- case HloOpcode::kNe:
- case HloOpcode::kPower:
- case HloOpcode::kRemainder:
- case HloOpcode::kSubtract:
- case HloOpcode::kAnd:
- case HloOpcode::kOr:
- case HloOpcode::kShiftLeft:
- case HloOpcode::kShiftRightArithmetic:
- case HloOpcode::kShiftRightLogical:
- CHECK_EQ(2, operand_count());
- return true;
-
- // Ternary elementwise operations.
- case HloOpcode::kSelect:
- return !ShapeUtil::IsTuple(shape_);
- case HloOpcode::kClamp:
- return true;
-
- // Other operations.
- case HloOpcode::kRng:
- case HloOpcode::kMap:
- return true;
- case HloOpcode::kFusion:
- if (fusion_kind() != FusionKind::kLoop) {
- return false;
- }
- for (auto* fused : fused_instructions()) {
- if (fused->opcode() != HloOpcode::kParameter &&
- !fused->IsElementwise()) {
- return false;
- }
- }
- return true;
-
- default:
- return false;
- }
+ return IsElementwiseImpl(tensorflow::gtl::nullopt);
}
bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
@@ -3211,54 +2595,8 @@ bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const {
return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape());
}
-namespace {
-bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
- const HloInstruction* operand) {
- std::vector<int64> operand_indices = instruction->OperandIndices(operand);
- return std::all_of(
- operand_indices.begin(), operand_indices.end(),
- [instruction](int64 operand_index) {
- return instruction->IsElementwiseOnOperand(operand_index);
- });
-}
-} // namespace
-
bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
- // For all instructions other than kFusion, being elementwise on one of the
- // operands is equivalent to being elementwise on all the operands.
- if (opcode() != HloOpcode::kFusion) {
- return IsElementwise();
- }
-
- CHECK_EQ(HloOpcode::kFusion, opcode());
- if (fusion_kind() != FusionKind::kLoop) {
- return false;
- }
-
- // A loop-fusion is elementwise on an operand if all operations (computed
- // using BFS) between the operand and the fused root are elementwise.
- std::deque<HloInstruction*> worklist;
- std::unordered_set<const HloInstruction*> visited;
- worklist.push_back(fused_parameter(operand_idx));
- visited.insert(fused_parameter(operand_idx));
- while (!worklist.empty()) {
- HloInstruction* operand = worklist.front();
- worklist.pop_front();
- for (HloInstruction* user : operand->users()) {
- CHECK_GE(user->unique_id(), 0);
- if (ContainsKey(visited, user)) {
- continue;
- }
- if (user->IsElementwise() ||
- IsInstructionElementwiseOnOperand(user, operand)) {
- worklist.push_back(user);
- visited.insert(user);
- } else {
- return false;
- }
- }
- }
- return true;
+ return IsElementwiseImpl(operand_idx);
}
// A helper class for memoized, recursive computation of HloOpcode::kFusion
@@ -3280,8 +2618,10 @@ class HloInstruction::FusionReusesParamElements {
static UseKind ComputeInternal(
int64 i, const HloInstruction& hlo,
tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
- if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) {
- return UseKind::kUse;
+ if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
+ if (hlo_param->parameter_number() == i) {
+ return UseKind::kUse;
+ }
}
auto p = cache->emplace(&hlo, UseKind{});
@@ -3590,21 +2930,264 @@ void HloInstruction::set_outer_dimension_partitions(
outer_dimension_partitions_ = outer_dimension_partitions;
}
+// TODO(b/80131774): Remove these temporary methods after transition.
+int64 HloInstruction::feature_index() const {
+ return Cast<HloBatchNormInstruction>(this)->feature_index();
+}
+
+float HloInstruction::epsilon() const {
+ return Cast<HloBatchNormInstruction>(this)->epsilon();
+}
+
+FftType HloInstruction::fft_type() const {
+ return Cast<HloFftInstruction>(this)->fft_type();
+}
+
+const std::vector<int64>& HloInstruction::fft_length() const {
+ return Cast<HloFftInstruction>(this)->fft_length();
+}
+
+int64 HloInstruction::channel_id() const {
+ return Cast<HloSendRecvInstruction>(this)->channel_id();
+}
+
+int64 HloInstruction::concatenate_dimension() const {
+ return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
+}
+
+bool HloInstruction::IsRank2Transpose() const {
+ auto transpose = DynCast<HloTransposeInstruction>(this);
+ return transpose != nullptr && transpose->IsRank2Transpose();
+}
+
+int64 HloInstruction::slice_starts(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_starts() const {
+ return Cast<HloSliceInstruction>(this)->slice_starts();
+}
+
+int64 HloInstruction::slice_limits(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_limits() const {
+ return Cast<HloSliceInstruction>(this)->slice_limits();
+}
+
+int64 HloInstruction::slice_strides(int64 dimension) const {
+ return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
+}
+
+const std::vector<int64>& HloInstruction::slice_strides() const {
+ return Cast<HloSliceInstruction>(this)->slice_strides();
+}
+
+bool HloInstruction::IsInPlaceSlice() const {
+ return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
+}
+
+const Literal& HloInstruction::literal() const {
+ return Cast<HloConstantInstruction>(this)->literal();
+}
+
+bool HloInstruction::IsConstant() const {
+ return DynCast<HloConstantInstruction>(this) != nullptr;
+}
+
void HloInstruction::RelayoutConstant(const Layout& new_layout,
const ShapeIndex& shape_index) {
- CHECK_EQ(opcode(), HloOpcode::kConstant);
- Shape* mutable_array_subshape =
- ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
- CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
+ Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
+}
+
+string HloInstruction::TracingTag() const {
+ return Cast<HloTraceInstruction>(this)->TracingTag();
+}
+
+HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
+ return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
+}
+
+// Delegates to HloFusionInstruction::MergeFusionInstruction.
+void HloInstruction::MergeFusionInstruction(
+ HloInstruction* instruction_to_merge) {
+ return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
+ Cast<HloFusionInstruction>(instruction_to_merge));
+}
- // Normally array_subshape will always have a layout, but this invariant is
- // temporarily broken in LayoutAssignment::AssignLayouts.
+// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
+void HloInstruction::MergeFusionInstructionIntoMultiOutput(
+ HloInstruction* instruction_to_merge) {
+ return Cast<HloFusionInstruction>(this)
+ ->MergeFusionInstructionIntoMultiOutput(
+ Cast<HloFusionInstruction>(instruction_to_merge));
+}
+
+HloInstruction* HloInstruction::FuseInstruction(
+ HloInstruction* instruction_to_fuse) {
+ return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
+}
+
+HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
+ HloInstruction* instruction_to_fuse) {
+ return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
+ instruction_to_fuse);
+}
+
+HloComputation* HloInstruction::fused_instructions_computation() const {
+ return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
+}
+
+HloInstruction* HloInstruction::fused_expression_root() const {
+ return Cast<HloFusionInstruction>(this)->fused_expression_root();
+}
+
+const tensorflow::gtl::iterator_range<UnwrappingIterator<
+ std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
+HloInstruction::fused_instructions() const {
+ return Cast<HloFusionInstruction>(this)->fused_instructions();
+}
+
+const tensorflow::gtl::iterator_range<
+ UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
+HloInstruction::fused_instructions() {
+ return Cast<HloFusionInstruction>(this)->fused_instructions();
+}
+
+int64 HloInstruction::fused_instruction_count() const {
+ return Cast<HloFusionInstruction>(this)->fused_instruction_count();
+}
+
+HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
+ return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
+}
+
+const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
+ return Cast<HloFusionInstruction>(this)->fused_parameters();
+}
+
+const bool HloInstruction::IsMultiOutputFusion() const {
+ const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
+ return fusion != nullptr && fusion->IsMultiOutputFusion();
+}
+
+HloInstruction::FusionKind HloInstruction::fusion_kind() const {
+ return Cast<HloFusionInstruction>(this)->fusion_kind();
+}
+
+void HloInstruction::set_fusion_kind(FusionKind kind) {
+ return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
+}
- if (!mutable_array_subshape->has_layout() ||
- !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
- literal_ = literal_->Relayout(new_layout, shape_index);
- *mutable_array_subshape->mutable_layout() = new_layout;
+RandomDistribution HloInstruction::random_distribution() const {
+ return Cast<HloRngInstruction>(this)->random_distribution();
+}
+
+int64 HloInstruction::parameter_number() const {
+ return Cast<HloParameterInstruction>(this)->parameter_number();
+}
+
+int64 HloInstruction::tuple_index() const {
+ return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
+}
+
+int32 HloInstruction::exponent_bits() const {
+ return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
+}
+
+int32 HloInstruction::mantissa_bits() const {
+ return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
+}
+
+string HloInstruction::infeed_config() const {
+ return Cast<HloInfeedInstruction>(this)->infeed_config();
+}
+
+void HloInstruction::set_infeed_config(const string& config) {
+ return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
+}
+
+const Shape& HloInstruction::outfeed_shape() const {
+ return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
+}
+
+const string& HloInstruction::outfeed_config() const {
+ return Cast<HloOutfeedInstruction>(this)->outfeed_config();
+}
+
+const std::vector<int64>& HloInstruction::replica_group_ids() const {
+ return Cast<HloAllReduceInstruction>(this)->replica_group_ids();
+}
+
+string HloInstruction::cross_replica_sum_barrier() const {
+ return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier();
+}
+
+void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) {
+ return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier(
+ barrier);
+}
+
+tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const {
+ return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
+}
+
+const ConvolutionDimensionNumbers&
+HloInstruction::convolution_dimension_numbers() const {
+ if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
+ return convolution->convolution_dimension_numbers();
+ }
+ if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
+ return custom_call->convolution_dimension_numbers();
+ }
+ LOG(FATAL) << "Unimplemented method.";
+}
+
+void HloInstruction::set_convolution_dimension_numbers(
+ const ConvolutionDimensionNumbers& dnums) {
+ if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
+ convolution->set_convolution_dimension_numbers(dnums);
+ } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
+ custom_call->set_convolution_dimension_numbers(dnums);
+ } else {
+ LOG(FATAL) << "Unimplemented method.";
}
}
+HloComputation* HloInstruction::select() const {
+ return Cast<HloSelectAndScatterInstruction>(this)->select();
+}
+
+HloComputation* HloInstruction::scatter() const {
+ return Cast<HloSelectAndScatterInstruction>(this)->scatter();
+}
+
+void HloInstruction::set_select(HloComputation* computation) {
+ return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
+}
+
+void HloInstruction::set_scatter(HloComputation* computation) {
+ return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
+}
+
+const string& HloInstruction::custom_call_target() const {
+ return Cast<HloCustomCallInstruction>(this)->custom_call_target();
+}
+
+const string& HloInstruction::channel_name() const {
+ return Cast<HloHostComputeInstruction>(this)->channel_name();
+}
+
+const PaddingConfig& HloInstruction::padding_config() const {
+ return Cast<HloPadInstruction>(this)->padding_config();
+}
+
+int64 HloInstruction::slice_sizes(int64 dimension) const {
+ return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
+}
+
+const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
+ return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d47af6c018..b392d65636 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -33,7 +33,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/iterator_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -322,7 +322,7 @@ class HloInstruction {
kCustom,
};
- ~HloInstruction();
+ virtual ~HloInstruction();
// Creates an instruction from the given proto. Arguments:
//
@@ -389,11 +389,10 @@ class HloInstruction {
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
- // at a given index) with the same `static_operands`.
+ // at a given index)
static std::unique_ptr<HloInstruction> CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloComputation* map_computation,
- tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
+ HloComputation* map_computation);
// Creates a convolution op, where rhs is the convolutional filter
// and window describes how the filter is applied to lhs.
@@ -426,10 +425,27 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, const int exponent_bits,
const int mantissa_bits);
- // Creates a cross replica sum op.
+ // Creates a cross replica reduction op.
+ //
+ // `reduction_computation`: the reduction function.
+ //
+ // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+ // replicas belong to one group. Allreduce will be applied within subgroups.
+ // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means,
+ // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
+ //
+ // `all_reduce_id`: for Allreduce nodes from different modules, if they have
+ // the same all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will
+ // not be applied cross modules.
+ //
+ // TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id =
+ tensorflow::gtl::nullopt);
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
@@ -442,19 +458,36 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand);
// Creates an infeed instruction, which reads data of the given shape from the
- // Infeed interface of the device.
- static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
+ // Infeed interface of the device. infeed_shape is the shape of the data
+ // received from the infeed *not* the shape of the infeed instruction which
+ // is a tuple containing the infeed_shape and the TOKEN.
+ static std::unique_ptr<HloInstruction> CreateInfeed(
+ const Shape& infeed_shape, HloInstruction* token_operand,
+ const string& config);
+ // Overload which does not require a token.
+ // TODO(b/80000000): Remove this overload when all uses of infeed are
+ // converted to take tokens.
+ static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& infeed_shape,
const string& config);
- // Creates an outfeed instruction, which outputs data.
+ // Creates an outfeed instruction, which outputs data. outfeed_shape is the
+ // shape of the data being outfed *not* the shape of the outfeed instruction
+ // which is a TOKEN.
static std::unique_ptr<HloInstruction> CreateOutfeed(
- const Shape& shape, HloInstruction* operand,
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config);
+ // Overload which does not require a token.
+ // TODO(b/80000000): Remove this overload when all uses of outfeed are
+ // converted to take tokens.
+ static std::unique_ptr<HloInstruction> CreateOutfeed(
+ const Shape& outfeed_shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
// another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
+ HloInstruction* token,
int64 channel_id);
// Blocks until data transfer for the Send instruction (operand) is complete.
@@ -466,6 +499,7 @@ class HloInstruction {
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
+ HloInstruction* token,
int64 channel_id);
// Blocks until data transfer for the Recv instruction (operand) is complete
@@ -579,6 +613,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates a sort op, with a keys operand, and an optional values operand.
+ static std::unique_ptr<HloInstruction> CreateSort(
+ const Shape& shape, int64 dimension, HloInstruction* keys,
+ HloInstruction* values = nullptr);
+
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
@@ -648,6 +687,19 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates a Afterall instruction used for joining or creating new values of
+ // token type which thread through side-effecting operations. Operands must
+ // all be tokens, and there must be at least one operand.
+ static std::unique_ptr<HloInstruction> CreateAfterAll(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
+ // Creates an AfterAll instruction which creates a token type out of thin air
+ // (no operands). This is a separate method from CreateAfterAll to facility
+ // the removal of operand-less AfterAll instructions.
+ // TODO(b/110532604): Remove this capability of creating a token from nothing
+ // when we plumb a primordial token from the entry computation.
+ static std::unique_ptr<HloInstruction> CreateToken();
+
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
tensorflow::gtl::ArraySlice<int64> output_window_dims,
@@ -786,15 +838,18 @@ class HloInstruction {
// Returns whether the instruction has a constant operand.
bool HasConstantOperand() const;
- // Returns whether this instruction does a rank-2 transposition.
- bool IsRank2Transpose() const;
-
// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
+ //
+ // If user is a fusion instruction, this function will remove any duplicated
+ // operands of it which could be created due to this replacement.
Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
// Replaces the specified operand with new_operand.
+ //
+ // This function does NOT remove duplicated operands even if this instruction
+ // is a fusion, so that the existing operand numbers do not change.
Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
// Replaces all uses of this instruction with the new producer. If
@@ -803,14 +858,10 @@ class HloInstruction {
//
// If this instruction is the root of its computation, sets the computation's
// root to new_producer.
- Status ReplaceAllUsesWith(HloInstruction* new_producer);
-
- // Detaches an instruction from its operands. That is, remove the instruction
- // from each operand's user set. This should only be called prior to
- // deallocating the instruction.
//
- // TODO(b/78305363): Make this automatic when deleting an instruction.
- void DetachFromOperands();
+ // If a user is a fusion instruction, this function will remove any duplicated
+ // operands of it which could be created due to this replacement.
+ Status ReplaceAllUsesWith(HloInstruction* new_producer);
// Performs a postorder DFS visit using this node as the root. If
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
@@ -857,38 +908,6 @@ class HloInstruction {
template <typename HloInstructionPtr>
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
- // Returns the literal associated with this instruction.
- //
- // Note: only constant and parameter opcodes have an associated literal.
- const Literal& literal() const;
-
- // Returns whether there is literal associated with this instruction.
- bool HasLiteral() const;
-
- // Returns the parameter number associated with this instruction.
- //
- // Note: only parameter opcodes have an associated parameter number.
- int64 parameter_number() const {
- CHECK_EQ(HloOpcode::kParameter, opcode_);
- return parameter_number_;
- }
-
- // Returns the dimension sizes or numbers associated with this instruction.
- //
- // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
- // and reverse.
- const std::vector<int64>& dimensions() const;
- int64 dimensions(int64 index) const;
-
- // Accessor for the dimension in which a concatenate HLO should occur.
- // Precondition: opcode() == HloOpcode::kConcatenate
- int64 concatenate_dimension() const;
-
- // Returns the tuple index associated with this instruction.
- //
- // Precondition: opcode() == HloOpcode::kGetTupleElement
- int64 tuple_index() const;
-
// Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
// (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
@@ -916,18 +935,6 @@ class HloInstruction {
HloComputation* to_apply() const;
void set_to_apply(HloComputation* to_apply);
- // Returns the custom_call_target for CustomCall.
- // Precondition: opcode() == HloOpcode::kCustomCall
- const string& custom_call_target() const;
-
- // Returns the config for the Outfeed instruction.
- // Precondition: opcode() == HloOpcode::kOutfeed
- const string& outfeed_config() const;
-
- // Returns the shape for the Outfeed instruction.
- // Precondition: opcode() == HloOpcode::kOutfeed
- const Shape& outfeed_shape() const;
-
// Gets/sets the while_condition or while_body HloComputation for While. The
// setters should only be called by HloModule or HloComputation methods.
//
@@ -937,15 +944,6 @@ class HloInstruction {
void set_while_condition(HloComputation* while_condition);
void set_while_body(HloComputation* while_body);
- // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
- // setters should only be called by HloModule or HloComputation methods.
- //
- // Precondition: opcode() == HloOpcode::kSelectAndScatter.
- HloComputation* select() const;
- HloComputation* scatter() const;
- void set_select(HloComputation* select);
- void set_scatter(HloComputation* scatter);
-
// Gets/sets the true and false HloComputation for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
//
@@ -983,11 +981,11 @@ class HloInstruction {
string ToShortString() const;
// Returns a serialized representation of this instruction.
- HloInstructionProto ToProto() const;
+ virtual HloInstructionProto ToProto() const;
// Returns a category for the HLO. This could be something like "convolution"
// or "elementwise".
- string ToCategory() const;
+ virtual string ToCategory() const;
// Returns a logging instruction, if the output of this instruction is logged.
//
@@ -995,111 +993,14 @@ class HloInstruction {
HloInstruction* tracing() const;
void set_tracing(HloInstruction* trace_instruction);
- // Returns the channel id associated with the instruction. The id is
- // shared between each Send/Recv pair and is globally unique to identify each
- // channel.
- //
- // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
- int64 channel_id() const { return channel_id_; }
-
- // Returns the channel name associated with the instruction. The name is
- // used to identify host Send/Recv operations.
- //
- // Precondition: opcode() == HloOpcode::kHostCompute
- string channel_name() const { return channel_name_; }
-
- // Returns feature_index field associated with the instruction. The index
- // represents the index of the feature dimension.
- //
- // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
- // or kBatchNormGrad.
- int64 feature_index() const { return feature_index_; }
-
- // Returns a epsilon value associated with the instruction. The is a small
- // number added to the variance to avoid divide-by-zero error.
- //
- // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
- // or kBatchNormGrad.
- float epsilon() const { return epsilon_; }
-
- // Returns the infeed configuration string. The infeed configuration includes
- // any metadata needed for the backend compiler (e.g., infeed buffer address)
- // and is target-dependent.
- string infeed_config() const { return infeed_config_; }
- void set_infeed_config(const string& config) { infeed_config_ = config; }
-
- // Returns a tag to be used in tracing.
- //
- // Precondition: opcode() == HloOpcode::kTrace
- string TracingTag() const;
-
- // Returns whether the instruction is a constant.
- bool IsConstant() const;
-
// Returns true if this instruction is fused, ie contained within a fusion
// instruction.
bool IsFused() const;
- // Returns the computation for this fused instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloComputation* fused_instructions_computation() const;
-
// Returns true if this instruction can be legally fused into a fusion
// instruction.
bool IsFusable() const;
- // Returns the root instruction of the fused expression contained within this
- // fusion instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloInstruction* fused_expression_root() const;
-
- // Returns the list of fused instructions inside this fusion instruction. The
- // returned type is a range of HloInstruction*s.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- const tensorflow::gtl::iterator_range<UnwrappingIterator<
- std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
- fused_instructions() const;
-
- const tensorflow::gtl::iterator_range<
- UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
- fused_instructions();
-
- // Gets the number of instructions inside this fusion instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- int64 fused_instruction_count() const;
-
- // Returns the fused parameter instruction in this fusion instruction
- // corresponding to the given parameter number.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloInstruction* fused_parameter(int64 parameter_number) const;
-
- // Returns the vector of fused parameters inside this fusion instruction.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- const std::vector<HloInstruction*>& fused_parameters() const;
-
- // Returns true if this instruction is a fusion instruction that generates
- // multiple outputs.
- const bool IsMultiOutputFusion() const {
- return opcode() == HloOpcode::kFusion &&
- fused_expression_root()->opcode() == HloOpcode::kTuple;
- }
-
- FusionKind fusion_kind() const {
- CHECK_EQ(HloOpcode::kFusion, opcode_);
- return fusion_kind_;
- }
-
- void set_fusion_kind(FusionKind kind) {
- CHECK_EQ(HloOpcode::kFusion, opcode_);
- fusion_kind_ = kind;
- }
-
// Returns the sharding applied to this operator.
// REQUIRES: has_sharding() is true.
const HloSharding& sharding() const {
@@ -1124,8 +1025,11 @@ class HloInstruction {
void set_sharding(const HloSharding& sharding) {
sharding_ = MakeUnique<HloSharding>(sharding);
}
+ void set_single_sharding(const HloSharding& sharding);
// Sets a sharding that assigns the current instruction to device.
- void set_device_sharding(int64 device);
+ void set_device_sharding(int64 device) {
+ set_single_sharding(HloSharding::AssignDevice(device));
+ }
// Remove any sharding from this operator.
void clear_sharding() { sharding_ = nullptr; }
// Return true if this operator has a sharding assigned.
@@ -1155,167 +1059,17 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
- // Adds a new operand the fusion instruction.
- HloInstruction* AddFusionOperand(HloInstruction* new_operand);
-
- // Merges the fused instructions from 'instruction_to_merge' into the
- // fused instruction set of 'this', updating operands as necessary.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- // Predondition: 'instruction_to_merge' must be an operand of 'this'.
- void MergeFusionInstruction(HloInstruction* instruction_to_merge);
-
- // Merges the fused instructions from instruction_to_merge into the fused
- // instruction set of 'this' and generates multioutput fusion instructions.
- // All the users of instruction_to_merge will be redirected to 'this'
- // instruction. instruction_to_merge will be removed from its parent
- // computation.
- //
- // Precondition: opcode() == HloOpcode::kFusion
- void MergeFusionInstructionIntoMultiOutput(
- HloInstruction* instruction_to_merge);
-
- // Fuses the given instruction in this fusion instruction. instruction_to_fuse
- // is cloned and the clone is placed in the fusion
- // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
- // than moved to cleanly handle the case where the instruction has a use
- // outside the fusion instruction. Moving such an instruction into a fusion
- // instruction would violate the single-result invariant of HLO instructions
- // and significantly complicate code generation.
- //
- // Precondition: this->opcode() == HloOpcode::kFusion
- HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
- return FuseInstructionInternal(instruction_to_fuse);
- }
-
- // Fuses the given instruction in this fusion instruction and generate
- // multioutput fusion instruction. A clone of the instruction_to_fuse will
- // be part of the output of fusion instructions. The users of
- // instruction_to_fuse will be redirected to this fusion instructions.
- // instruction_to_fuse will be removed from its parent computation.
- //
- // Precondition: this->opcode() == HloOpcode::kFusion
- HloInstruction* FuseInstructionIntoMultiOutput(
- HloInstruction* instruction_to_fuse) {
- return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
- }
-
- // Returns the start index in the given dimension for a slice node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_starts(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_starts_[dimension];
+ // TODO(b/80249101): Remove these methods once HLO scheduling and copy
+ // insertion are integrated, and we don't need to run a separate pass
+ // of copy elision anymore.
+ bool CopyElisionAllowed() const {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ return copy_elision_allowed_;
}
- const std::vector<int64>& slice_starts() const { return slice_starts_; }
- // Returns the (exclusive) limit index in the given dimension for a slice
- // node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_limits(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_limits_[dimension];
- }
- const std::vector<int64>& slice_limits() const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_limits_;
- }
-
- // Returns the stride in the given dimension for a slice node.
- //
- // Precondition: opcode() == HloOpcode::kSlice
- int64 slice_strides(int64 dimension) const {
- CHECK_EQ(HloOpcode::kSlice, opcode_);
- return slice_strides_[dimension];
- }
- const std::vector<int64>& slice_strides() const { return slice_strides_; }
-
- // Returns the flag that describes whether a slice must be lowered into an
- // offset into the original operand.
- bool IsInPlaceSlice() const { return is_in_place_slice_; }
-
- // Sets and returns the flag that describes whether a slice must be lowered
- // into an offset into the original operand.
- bool SetIsInPlaceSlice(bool value) {
- is_in_place_slice_ = value;
- return value;
- }
-
- // Returns the size of the slice in the given dimension for a dynamic
- // slice node.
- //
- // Precondition: opcode() == HloOpcode::kDynamicSlice
- int64 slice_sizes(int64 dimension) const {
- CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
- return dynamic_slice_sizes_[dimension];
- }
- const std::vector<int64>& dynamic_slice_sizes() const {
- CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
- return dynamic_slice_sizes_;
- }
-
- // Returns the number of exponent bits for a reduce-precision node.
- //
- // Precondition: opcode() == HloOpcode::kReducePrecision
- int32 exponent_bits() const {
- CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
- return exponent_bits_;
- }
-
- // Returns the number of mantissa bits for a reduce-precision node.
- //
- // Precondition: opcode() == HloOpcode::kReducePrecision
- int32 mantissa_bits() const {
- CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
- return mantissa_bits_;
- }
-
- // Returns data on the window in a windowed operation such as
- // convolution.
- const Window& window() const {
- CHECK(window_ != nullptr);
- return *window_;
- }
-
- // Sets the window data in a windowed operation such as convolution.
- void set_window(const Window& window) {
- window_ = MakeUnique<Window>(window);
- }
-
- // Returns the padding configuration for a pad node.
- //
- // Precondition: opcode() == HloOpcode::kPad
- const PaddingConfig& padding_config() const {
- CHECK(padding_config_ != nullptr);
- return *padding_config_;
- }
-
- // Returns data on the dimension numbers used for a convolution operation,
- // which may be a kConvolution instruction or a kCustomCall that implements a
- // convolution.
- const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
- CHECK(convolution_dimension_numbers_ != nullptr);
- return *convolution_dimension_numbers_;
- }
-
- // Sets the convolution dimension numbers on this instruction. In general you
- // shouldn't need to call this; instead, specify the convolution dimension
- // numbers when you create the instruction.
- void set_convolution_dimension_numbers(
- const ConvolutionDimensionNumbers& dnums) {
- convolution_dimension_numbers_ =
- MakeUnique<ConvolutionDimensionNumbers>(dnums);
- }
-
- FftType fft_type() const {
- CHECK_EQ(HloOpcode::kFft, opcode_);
- return fft_type_;
- }
-
- const std::vector<int64>& fft_length() const {
- CHECK_EQ(HloOpcode::kFft, opcode_);
- return fft_length_;
+ void SetCopyElisionAllowed(bool value) {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ copy_elision_allowed_ = value;
}
// Returns data on the dimension numbers used for a dot operation.
@@ -1340,11 +1094,6 @@ class HloInstruction {
// Returns the dump string of the gather dimension numbers.
string GatherDimensionNumbersToString() const;
- // Returns the random distribution for this rng node.
- //
- // Precondition: opcode() == HloOpcode::kRng
- RandomDistribution random_distribution() const;
-
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1355,7 +1104,8 @@ class HloInstruction {
// Clones the HLO instruction as above but with new shape and operands.
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
HloCloneContext* context = nullptr) const;
// Returns the computations this instruction directly calls (if any).
@@ -1426,9 +1176,14 @@ class HloInstruction {
std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
- // Gets/sets the string identifier for this instruction.
+ // Gets the string identifier for this instruction.
const string& name() const { return name_; }
- void set_name(tensorflow::StringPiece name) { name_ = std::string(name); }
+
+ // Sets the string identifier for this instruction. Name will be sanitized to
+ // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
+ void SetAndSanitizeName(const string& name) {
+ name_ = NameUniquer::GetSanitizedName(name);
+ }
// Use the given NameUniquer to select a unique name for the instruction based
// on the instruction's existing name.
@@ -1509,13 +1264,273 @@ class HloInstruction {
void set_outer_dimension_partitions(
const std::vector<int64>& outer_dimension_partitions);
- // Change the layout for an Constant Hlo instruction to match new_layout. For
- // tuple shaped constants shape_index is the path to the internal array
- // subshape whose layout needs to be changed.
+ // Old methods kept for smooth subclassing transition BEGIN.
+ // TODO(b/80131774): Remove this code.
+
+ // Delegates to HloBatchNormInstruction::feature_index.
+ int64 feature_index() const;
+
+ // Delegates to HloBatchNormInstruction::epsilon.
+ float epsilon() const;
+
+ // Delegates to HloFftInstruction::fft_type.
+ FftType fft_type() const;
+
+ // Delegates to HloFftInstruction::fft_length.
+ const std::vector<int64>& fft_length() const;
+
+ // Delegates to HloSendRecvInstruction::channel_id.
+ int64 channel_id() const;
+
+ // Returns the dimension sizes or numbers associated with this instruction.
+ virtual const std::vector<int64>& dimensions() const {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+ virtual int64 dimensions(int64 index) const {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Delegates to HloConcatenateInstruction::concatenate_dimension.
+ int64 concatenate_dimension() const;
+
+ // Returns whether this instruction does a rank-2 transposition.
+ bool IsRank2Transpose() const;
+
+ // Delegates to HloSliceInstruction::slice_start.
+ int64 slice_starts(int64 dimension) const;
+ const std::vector<int64>& slice_starts() const;
+
+ // Delegates to HloSliceInstruction::slice_limits.
+ int64 slice_limits(int64 dimension) const;
+ const std::vector<int64>& slice_limits() const;
+
+ // Delegates to HloSliceInstruction::slice_strides.
+ int64 slice_strides(int64 dimension) const;
+ const std::vector<int64>& slice_strides() const;
+
+ // Delegates to HloSliceInstruction::IsInPlaceSlice.
+ bool IsInPlaceSlice() const;
+
+ // Returns the literal associated with this instruction.
+ const Literal& literal() const;
+
+ // Returns whether the instruction is a constant.
+ bool IsConstant() const;
+
+ // Delegate to HloConstantInstruction::RelayoutConstant.
void RelayoutConstant(const Layout& new_layout,
const ShapeIndex& shape_index = {});
+ // Delegates to HloTraceInstruction::TracingTag.
+ string TracingTag() const;
+
+ // Delegates to HloFusionInstruction::AddFusionOperand.
+ HloInstruction* AddFusionOperand(HloInstruction* new_operand);
+
+ // Delegates to HloFusionInstruction::MergeFusionInstruction.
+ void MergeFusionInstruction(HloInstruction* instruction_to_merge);
+
+ // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
+ void MergeFusionInstructionIntoMultiOutput(
+ HloInstruction* instruction_to_merge);
+
+ // Delegates to HloFusionInstruction::FuseInstruction.
+ HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
+
+ // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
+ HloInstruction* FuseInstructionIntoMultiOutput(
+ HloInstruction* instruction_to_fuse);
+
+ // Delegates to HloFusionInstruction::fused_instruction.
+ HloComputation* fused_instructions_computation() const;
+
+ // Delegates to HloFusionInstruction::fused_expression_root.
+ HloInstruction* fused_expression_root() const;
+
+ // Delegates to HloFusionInstruction::fused_instructions.
+ const tensorflow::gtl::iterator_range<UnwrappingIterator<
+ std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
+ fused_instructions() const;
+
+ const tensorflow::gtl::iterator_range<
+ UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
+ fused_instructions();
+
+ // Delegates to HloFusionInstruction::fused_instruction_count.
+ int64 fused_instruction_count() const;
+
+ // Delegates to HloFusionInstruction::fused_parameter.
+ HloInstruction* fused_parameter(int64 parameter_number) const;
+
+ // Delegates to HloFusionInstruction::fused_parameters.
+ const std::vector<HloInstruction*>& fused_parameters() const;
+
+ // Returns true if this instruction is a fusion instruction that generates
+ // multiple outputs.
+ const bool IsMultiOutputFusion() const;
+
+ // Delegates to HloFusionInstruction::fusion_kind.
+ FusionKind fusion_kind() const;
+
+ // Delegates to HloFusionInstruction::set_fusion_kind.
+ void set_fusion_kind(FusionKind kind);
+
+ // Delegates to HloRngInstruction::random_distribution.
+ RandomDistribution random_distribution() const;
+
+ // Delegates to HloParameterInstruction::parameter_number.
+ int64 parameter_number() const;
+
+ // Delegates to HloGetTupleElementInstruction::tuple_index.
+ int64 tuple_index() const;
+
+ // Delegates to HloReducePrecisionInstruction::exponent_bits.
+ int32 exponent_bits() const;
+
+ // Delegates to HloReducePrecisionInstruction::mantissa_bits.
+ int32 mantissa_bits() const;
+
+ // Delegates to HloInfeedInstruction::infeed_config.
+ string infeed_config() const;
+
+ // Delegates to HloInfeedInstruction::set_infeed_config.
+ void set_infeed_config(const string& config);
+
+ // Returns the config for the Outfeed instruction.
+ const string& outfeed_config() const;
+
+ // Returns the shape for the Outfeed instruction.
+ const Shape& outfeed_shape() const;
+
+ // Delegates to HloAllReduceInstruction::replica_group_ids.
+ const std::vector<int64>& replica_group_ids() const;
+
+ // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier.
+ string cross_replica_sum_barrier() const;
+ void set_cross_replica_sum_barrier(const string& barrier);
+
+ // Delegates to HloAllReduceInstruction::all_reduce_id.
+ tensorflow::gtl::optional<int64> all_reduce_id() const;
+
+ // Returns data on the window in a windowed operation such as
+ // convolution.
+ virtual const Window& window() const {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Sets the window data in a windowed operation such as convolution.
+ virtual void set_window(const Window& window) {
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Returns data on the dimension numbers used for a convolution operation,
+ // which may be a kConvolution instruction or a kCustomCall that implements a
+ // convolution.
+ const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
+
+ // Sets the convolution dimension numbers on this instruction. In general you
+ // shouldn't need to call this; instead, specify the convolution dimension
+ // numbers when you create the instruction.
+ void set_convolution_dimension_numbers(
+ const ConvolutionDimensionNumbers& dnums);
+
+ // Delegates to HloSelectAndScatterInstruction::select.
+ HloComputation* select() const;
+
+ // Delegates to HloSelectAndScatterInstruction::scatter.
+ HloComputation* scatter() const;
+
+ // Delegates to HloSelectAndScatterInstruction::set_select.
+ void set_select(HloComputation* computation);
+
+ // Delegates to HloSelectAndScatterInstruction::set_scatter.
+ void set_scatter(HloComputation* computation);
+
+ // Delegates to HloCustomCallInstruction::custom_call_target.
+ const string& custom_call_target() const;
+
+ // Delegates to HloHostComputeInstruction::channel_name.
+ const string& channel_name() const;
+
+ // Delegates to HloPadInstruction::padding_config.
+ const PaddingConfig& padding_config() const;
+
+ // Delegates to HloDynamicSliceInstruction::slice_sizes.
+ int64 slice_sizes(int64 dimension) const;
+
+ // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
+ const std::vector<int64>& dynamic_slice_sizes() const;
+ // Old methods kept for smooth subclassing transition END.
+
+ protected:
+ enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
+ // Helper class for computing OperandElementUse for kFusion.
+ class FusionReusesParamElements;
+
+ // Internal constructor for a given opcode/shape, other fields must be filled
+ // by factory methods.
+ HloInstruction(HloOpcode opcode, const Shape& shape);
+
+ // Appends operand to the list of operands and adds this instruction as a user
+ // of the operand.
+ void AppendOperand(HloInstruction* operand);
+
+ void RemoveOperandAt(int index) {
+ operands_.erase(operands_.begin() + index);
+ }
+
+ // Removes a list of operands with the given indices in ascending order.
+ void RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices);
+
+ void AppendComputation(HloComputation* computation) {
+ called_computations_.push_back(computation);
+ }
+
+ void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
+
+ void set_called_computation(int index, HloComputation* computation) {
+ called_computations_[index] = computation;
+ }
+ // Indices of computations in called_computations_ for instructions which call
+ // multiple computations.
+ enum {
+ // kWhile computations.
+ kBodyComputationIndex = 0,
+ kConditionComputationIndex = 1,
+
+ // kSelectAndScatter computations.
+ kSelectComputationIndex = 0,
+ kScatterComputationIndex = 1,
+
+ // kConditional computations.
+ kTrueComputationIndex = 0,
+ kFalseComputationIndex = 1,
+ };
+
private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ // TODO(b/80131774): This should be pure virtual.
+ LOG(FATAL) << "Unimplemented method.";
+ }
+
+ // Implementation for non-common logic of ExtraAttributesToString.
+ virtual std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {};
+ }
+
+ // Implementation for IsElementwise if operand_idx is nullopt and for
+ // IsElementwiseOnOperand if otherwise.
+ //
+ // NOTE: For all instructions other than kFusion, being elementwise on one of
+ // the operands is equivalent to being elementwise on all the operands.
+ virtual bool IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const;
// Prints an instruction to a string.
//
// The canonical string representation needs to name operands and instruction
@@ -1526,7 +1541,7 @@ class HloInstruction {
CanonicalNameMap* canonical_name_map) const;
// Prints an operand to a string.
- string OperandsToStringWithCanonicalNameMap(
+ virtual string OperandsToStringWithCanonicalNameMap(
const HloPrintOptions& options,
CanonicalNameMap* canonical_name_map) const;
@@ -1534,13 +1549,8 @@ class HloInstruction {
// OperandsToStringWithCanonicalNameMap() functions.
friend class HloComputation;
- enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
-
- // Helper class for computing OperandElementUse for kFusion.
- class FusionReusesParamElements;
-
// See comments on Identical().
- bool IdenticalSlowPath(
+ virtual bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const;
@@ -1550,52 +1560,12 @@ class HloInstruction {
const Shape& shape, HloOpcode opcode,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
- // Appends operand to the list of operands and adds this instruction as a user
- // of the operand.
- void AppendOperand(HloInstruction* operand);
-
// Adds a user for this instruction.
void AddUser(HloInstruction* user);
// Removes a user for this instruction.
void RemoveUser(HloInstruction* user);
- // Internal constructor for a given opcode/shape, other fields must be filled
- // by factory methods.
- HloInstruction(HloOpcode opcode, const Shape& shape);
-
- // Fuses the given instruction into this fusion instruction. When add_output
- // is false (which is the default), instruction_to_fuse is cloned and the
- // clone is placed in the fusion instruction. instruction_to_fuse is
- // unchanged.
- //
- // When add_output is true, a clone of the instruction_to_fuse will be part
- // of the output of fusion instructions. The users of instruction_to_fuse
- // will be redirected to this fusion instructions. instruction_to_fuse will
- // be removed from its parent computation.
- //
- // Precondition: this->opcode() == HloOpcode::kFusion
- HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
- bool add_output = false);
-
- // Clones the given instruction_to_fuse and insert the clone into this fusion
- // instruction. If add_output is true, a clone of instruction_to_fuse will
- // be in the output of the this fusion instruction (part of the tuple of the
- // fusion root).
- //
- // Precondition: opcode() == HloOpcode::kFusion
- HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
- bool add_output = false);
-
- // Clones a fusion instruction with a new shape and operands.
- std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- HloCloneContext* context = nullptr) const;
-
- // Returns true if this instruction can legally have the dimensions field
- // set. Used for checking precondition of dimensions field accessors.
- bool CanHaveDimensionsField() const;
-
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
@@ -1627,62 +1597,17 @@ class HloInstruction {
// The computation in which this instruction is contained.
HloComputation* parent_ = nullptr;
- // Shape of outfeed request.
- Shape outfeed_shape_;
-
// Result shape of this instruction.
Shape shape_;
- // Literal, only present for kConstant.
- std::unique_ptr<Literal> literal_;
-
- // Constant index, only present for kGetTupleElement.
- int64 tuple_index_ = -1;
-
- // Dimensions present for some operations that require reshaping or
- // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
- std::vector<int64> dimensions_;
-
- // Describes the window in a windowed operation such as convolution.
- std::unique_ptr<Window> window_;
-
- // Describes the dimension numbers used for a convolution.
- std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
-
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
- // Describes FFT type for an FFT instruction.
- FftType fft_type_ = FftType::FFT;
-
- // Indicates the FFT length for an FFT instruction.
- std::vector<int64> fft_length_;
-
- // Describes the [begin, end) index range for a slice.
- std::vector<int64> slice_starts_;
- std::vector<int64> slice_limits_;
- std::vector<int64> slice_strides_;
-
- // Describes whether the slice can be lowered to an offset into the operand.
- bool is_in_place_slice_ = false;
-
- // The bit sizes for a reduce-precision operation.
- int32 exponent_bits_ = 0;
- int32 mantissa_bits_ = 0;
-
- // Describes the [start, start + size) range size for a dynamic slice
- // ('start' is specified dynamically in the second operand of the operation).
- std::vector<int64> dynamic_slice_sizes_;
-
- // The padding configuration that describes the edge padding and interior
- // padding of this pad instruction. Only set for pad instructions.
- std::unique_ptr<PaddingConfig> padding_config_;
-
- // The type of the fusion. Used by kFusion only.
- FusionKind fusion_kind_;
+ // Used to tag kCopy instructions that are eligible for copy elision.
+ bool copy_elision_allowed_ = true;
// The sharding, if one exists.
std::unique_ptr<HloSharding> sharding_;
@@ -1691,65 +1616,15 @@ class HloInstruction {
std::unique_ptr<DomainMetadata> operand_side_metadata_;
std::unique_ptr<DomainMetadata> user_side_metadata_;
- // For parameter instructions this field holds the parameter number.
- int64 parameter_number_ = 0;
-
- // Name of a global symbol to call, only present for kCustomCall.
- string custom_call_target_;
-
- // Name to use for host send/recv channels, only present for kHostCompute.
- string channel_name_;
-
- // Estimate of the duration of a host computation in nanoseconds.
- int64 cost_estimate_ns_ = 0;
-
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
- // Indices of computations in called_computations_ for instructions which call
- // multiple computations.
- enum {
- // kWhile computations.
- kBodyComputationIndex = 0,
- kConditionComputationIndex = 1,
-
- // kSelectAndScatter computations.
- kSelectComputationIndex = 0,
- kScatterComputationIndex = 1,
-
- // kConditional computations.
- kTrueComputationIndex = 0,
- kFalseComputationIndex = 1,
- };
-
- // Outfeed configuration information, only present for kOutfeed.
- string outfeed_config_;
-
// A trace instruction that consumes this instruction.
//
// Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
// an operand.
HloInstruction* trace_instruction_ = nullptr;
- // The distribution requested for random number generation.
- // Only present for kRng.
- RandomDistribution distribution_;
-
- // A small float number added to the variance to avoid divide-by-zero error.
- // Only present for kBatchNormTraining.
- float epsilon_ = 0.0f;
-
- // An integer value representing the index of the feature dimension.
- // Only present for kBatchNormTraining.
- int64 feature_index_ = -1;
-
- // Represents a unique identifier for each Send/Recv instruction pair.
- // Only present for kSend or kRecv.
- int64 channel_id_ = -1;
-
- // The string representation of the infeed configuration.
- string infeed_config_;
-
// The backend-specific configuration for how a backend should compile this
// HLO. See the documentation on backend_config().
string backend_config_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index a1a8814384..87c048930f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -20,15 +20,15 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -249,7 +249,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0f32_, "param1"));
auto c0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto addleft = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0));
auto addright = builder.AddInstruction(
@@ -294,7 +294,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, r0f32_, "param1"));
auto c0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto neg1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0));
auto addleft = builder.AddInstruction(
@@ -334,7 +334,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
auto param = embedded_builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32, "x"));
auto value = embedded_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build());
@@ -342,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) {
// Builds a parameter and feeds it to the map.
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10, ""));
+ HloInstruction::CreateParameter(0, f32a100x10, "p"));
auto map = builder.AddInstruction(
HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
module->AddEntryComputation(builder.Build());
@@ -381,11 +381,11 @@ TEST_F(HloInstructionTest, TrivialReduce) {
// Builds a parameter and an initial value and feeds them to the reduce.
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, f32a100x10, ""));
+ HloInstruction::CreateParameter(0, f32a100x10, "p"));
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto reduce = builder.AddInstruction(
HloInstruction::CreateReduce(f32v100, param0, const0,
/*dimensions_to_reduce=*/{1}, add_f32));
@@ -626,7 +626,7 @@ TEST_F(HloInstructionTest, SingletonFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single unary operation.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto module = CreateNewModule();
@@ -642,9 +642,9 @@ TEST_F(HloInstructionTest, BinaryFusionOp) {
HloComputation::Builder builder(TestName());
// Create a fusion instruction containing a single binary operation.
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant2));
auto module = CreateNewModule();
@@ -661,7 +661,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
@@ -682,7 +682,7 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
HloComputation::Builder builder(TestName());
// Create a chain of fused unary ops.
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto exp1 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
auto exp2 = builder.AddInstruction(
@@ -710,16 +710,17 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto outfeed10 = builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape10, constant, ""));
+ HloInstruction::CreateOutfeed(shape10, constant, token, ""));
auto outfeed01 = builder.AddInstruction(
- HloInstruction::CreateOutfeed(shape01, constant, ""));
+ HloInstruction::CreateOutfeed(shape01, constant, token, ""));
auto clone01 = builder.AddInstruction(outfeed01->Clone());
auto clone10 = builder.AddInstruction(outfeed10->Clone());
@@ -731,7 +732,7 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
HloComputation::Builder builder(TestName());
auto* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
{1, 2},
{3, 4},
})));
@@ -762,13 +763,13 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
HloComputation::Builder builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
- auto map_1_x = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {constant}, computation_x, /*static_operands=*/{}));
- auto map_2_x = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{}));
- auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap(
- scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{}));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
+ auto map_1_x = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
+ auto map_2_x = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
+ auto map_3_y = builder.AddInstruction(
+ HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
auto* computation = module->AddEntryComputation(builder.Build());
auto* fusion = computation->CreateFusionInstruction(
@@ -797,11 +798,11 @@ TEST_F(HloInstructionTest, ComplexFusionOp) {
// Notable complexities are repeated operands in the same instruction,
// different shapes, use of value in different expressions.
auto c1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
auto c2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.1f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
auto c3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(9.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
@@ -872,11 +873,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) {
// Create a set of random constant operands to use below. Make them matrices
// so dimensions are interesting.
auto operand1 = HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
auto operand2 = HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
- auto vector_operand =
- HloInstruction::CreateConstant(Literal::CreateR1<float>({42.0, 123.0}));
+ LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
+ auto vector_operand = HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0, 123.0}));
Shape shape = operand1->shape();
// Convenient short names for the operands.
@@ -923,6 +924,40 @@ TEST_F(HloInstructionTest, IdenticalInstructions) {
*HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
}
+TEST_F(HloInstructionTest, IdenticalCallInstructions) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+subcomp1 (x: f32[]) -> f32[] {
+ x = f32[] parameter(0)
+ ROOT n = f32[] sine(x)
+}
+
+subcomp2 (x: f32[]) -> f32[] {
+ x = f32[] parameter(0)
+ ROOT n = f32[] cosine(x)
+}
+
+ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
+ p = f32[] parameter(0)
+ t1 = f32[] call(p), to_apply=subcomp1
+ t2 = f32[] call(p), to_apply=subcomp1
+ t3 = f32[] call(p), to_apply=subcomp2
+ ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
+ }
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_string));
+
+ auto* root = module->entry_computation()->root_instruction();
+ auto* t1 = root->operand(0);
+ auto* t2 = root->operand(1);
+ auto* t3 = root->operand(2);
+
+ EXPECT_TRUE(StructuralEqual(*t1, *t2));
+ EXPECT_FALSE(StructuralEqual(*t1, *t3));
+}
+
TEST_F(HloInstructionTest, FunctionVisitor) {
// Verify the function visitor HloInstruction::Accept visits all instructions
// from a root properly given the following graph:
@@ -980,6 +1015,23 @@ TEST_F(HloInstructionTest, FullyElementwise) {
}
}
+TEST_F(HloInstructionTest, MapIsElementwise) {
+ auto module = CreateNewModule();
+ const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
+ HloComputation::Builder builder(TestName());
+ HloComputation::Builder map_builder("id");
+ map_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x"));
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {x}, map_computation));
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(map->IsElementwise());
+}
+
TEST_F(HloInstructionTest, PartiallyElementwise) {
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});
@@ -1119,6 +1171,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
}
+TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
+ // Fused expression:
+ //
+ // x y
+ // | |
+ // | transpose
+ // \ /
+ // dot
+ const Shape s = ShapeUtil::MakeShape(F32, {10, 10});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok());
+
+ EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
+ EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
+}
+
TEST_F(HloInstructionTest, FusionEquality) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
@@ -1148,9 +1234,9 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
// Build a nested fusion computation.
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto b_t = builder.AddInstruction(
HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
DotDimensionNumbers dot_dnums;
@@ -1159,7 +1245,7 @@ TEST_F(HloInstructionTest, NestedFusionEquality) {
auto dot = builder.AddInstruction(
HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -1256,7 +1342,7 @@ TEST_F(HloInstructionTest, Stringification) {
"condition=%TransposeDot, body=%TransposeDot");
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
sout, pred, x, computation, x, computation));
@@ -1369,15 +1455,15 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
HloInstruction* fusion = computation->CreateFusionInstruction(
{dot, reshape}, HloInstruction::FusionKind::kLoop);
- EXPECT_EQ(
- fusion->ToString(options),
+ const string expected_fusion =
R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
-})");
+})";
+ EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
@@ -1409,8 +1495,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
HloInstruction::CreateWhile(sout, computation, computation, x));
auto options = HloPrintOptions().Canonical();
- EXPECT_EQ(loop->ToString(options),
- R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+ const string expected_loop =
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
tmp_1 = f32[20,10]{1,0} parameter(1)
@@ -1432,7 +1518,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) {
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
-})");
+})";
+ EXPECT_EQ(loop->ToString(options), expected_loop);
}
TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
@@ -1464,13 +1551,12 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
HloInstruction::CreateWhile(sout, computation, computation, x));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloInstruction* conditional =
builder.AddInstruction(HloInstruction::CreateConditional(
sout, pred, x, computation, x, computation));
auto options = HloPrintOptions().Canonical();
- EXPECT_EQ(
- conditional->ToString(options),
+ const string expected_conditional =
R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
tmp_0 = f32[5,10]{1,0} parameter(0)
@@ -1493,7 +1579,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) {
tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
-})");
+})";
+ EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
TEST_F(HloInstructionTest, CheckDeepClone) {
@@ -1533,7 +1620,7 @@ ENTRY entry (param: s32[]) -> s32[] {
// Check that deep clones really deep clones every instruction and
// computations, without leaving dangling pointers to the old module.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
std::unique_ptr<HloModule> clone = module->Clone();
for (HloComputation* computation : clone->computations()) {
EXPECT_EQ(computation->parent(), clone.get());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
new file mode 100644
index 0000000000..7ea42caa7b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -0,0 +1,1917 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+
+#include <deque>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+namespace {
+
+using ::tensorflow::str_util::CEscape;
+using ::tensorflow::str_util::Join;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
+ const HloInstruction* operand) {
+ std::vector<int64> operand_indices = instruction->OperandIndices(operand);
+ return std::all_of(
+ operand_indices.begin(), operand_indices.end(),
+ [instruction](int64 operand_index) {
+ return instruction->IsElementwiseOnOperand(operand_index);
+ });
+}
+} // namespace
+
+HloBatchNormInstruction::HloBatchNormInstruction(
+ HloOpcode opcode, const Shape& shape, HloInstruction* operand,
+ HloInstruction* scale, float epsilon, int64 feature_index)
+ : HloInstruction(opcode, shape),
+ epsilon_(epsilon),
+ feature_index_(feature_index) {
+ AppendOperand(operand);
+ AppendOperand(scale);
+}
+
+bool HloBatchNormInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloBatchNormInstruction&>(other);
+ return feature_index() == casted_other.feature_index() &&
+ epsilon() == casted_other.epsilon();
+}
+
+HloInstructionProto HloBatchNormInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_epsilon(epsilon_);
+ proto.set_feature_index(feature_index_);
+ return proto;
+}
+
+std::vector<string> HloBatchNormInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("epsilon=", epsilon()),
+ StrCat("feature_index=", feature_index())};
+}
+
+HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand,
+ scale, epsilon, feature_index) {
+ AppendOperand(offset);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloBatchNormTrainingInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], epsilon(),
+ feature_index());
+}
+
+HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand,
+ scale, epsilon, feature_index) {
+ AppendOperand(offset);
+ AppendOperand(mean);
+ AppendOperand(variance);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 5);
+ return MakeUnique<HloBatchNormInferenceInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
+ new_operands[4], epsilon(), feature_index());
+}
+
+HloBatchNormGradInstruction::HloBatchNormGradInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output,
+ float epsilon, int64 feature_index)
+ : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale,
+ epsilon, feature_index) {
+ AppendOperand(mean);
+ AppendOperand(variance);
+ AppendOperand(grad_output);
+}
+
+std::unique_ptr<HloInstruction>
+HloBatchNormGradInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 5);
+ return MakeUnique<HloBatchNormGradInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3],
+ new_operands[4], epsilon(), feature_index());
+}
+
+HloFftInstruction::HloFftInstruction(
+ const Shape& shape, HloInstruction* operand, FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length)
+ : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) {
+ fft_length_.assign(fft_length.begin(), fft_length.end());
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloFftInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_fft_type(fft_type_);
+ for (int64 fft_len : fft_length_) {
+ proto.add_fft_length(fft_len);
+ }
+ return proto;
+}
+
+std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("fft_type=", FftType_Name(fft_type())),
+ StrCat("fft_length={", Join(fft_length(), ","), "}")};
+}
+
+bool HloFftInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloFftInstruction&>(other);
+ return fft_type() == casted_other.fft_type() &&
+ fft_length() == casted_other.fft_length();
+}
+
+std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_,
+ fft_length_);
+}
+
+HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
+ const Shape& shape,
+ int64 channel_id)
+ : HloInstruction(opcode, shape), channel_id_(channel_id) {}
+
+HloInstructionProto HloSendRecvInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_channel_id(channel_id_);
+ return proto;
+}
+
+std::vector<string> HloSendRecvInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("channel_id=", channel_id_)};
+}
+
+bool HloSendRecvInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+// Send instruction produces a tuple of {aliased operand, U32 context}.
+HloSendInstruction::HloSendInstruction(HloInstruction* operand,
+ HloInstruction* token, int64 channel_id)
+ : HloSendRecvInstruction(
+ HloOpcode::kSend,
+ ShapeUtil::MakeTupleShape({CHECK_NOTNULL(operand)->shape(),
+ ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
+ channel_id) {
+ AppendOperand(operand);
+ AppendOperand(token);
+}
+
+std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1],
+ channel_id());
+}
+
+HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand)
+ : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeTokenShape(),
+ CHECK_NOTNULL(operand)->channel_id()) {
+ AppendOperand(operand);
+}
+
+std::unique_ptr<HloInstruction>
+HloSendDoneInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloSendDoneInstruction>(
+ Cast<HloSendInstruction>(new_operands[0]));
+}
+
+// Recv instruction produces a tuple of {receive buffer, U32 context}.
+HloRecvInstruction::HloRecvInstruction(const Shape& shape,
+ HloInstruction* token, int64 channel_id)
+ : HloSendRecvInstruction(
+ HloOpcode::kRecv,
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}),
+ channel_id) {
+ AppendOperand(token);
+}
+
+std::unique_ptr<HloInstruction> HloRecvInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloRecvInstruction>(
+ ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id());
+}
+
+HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand)
+ : HloSendRecvInstruction(
+ HloOpcode::kRecvDone,
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::GetTupleElementShape(operand->shape(), 0),
+ ShapeUtil::MakeTokenShape()}),
+ CHECK_NOTNULL(operand)->channel_id()) {
+ AppendOperand(operand);
+}
+
+std::unique_ptr<HloInstruction>
+HloRecvDoneInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloRecvDoneInstruction>(
+ Cast<HloRecvInstruction>(new_operands[0]));
+}
+
+HloAllReduceInstruction::HloAllReduceInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id)
+ : HloInstruction(HloOpcode::kCrossReplicaSum, shape),
+ replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()),
+ cross_replica_sum_barrier_(barrier.begin(), barrier.end()),
+ all_reduce_id_(all_reduce_id) {
+ // TODO(b/79737069): Remove the CHECK when supported.
+ CHECK(!all_reduce_id_);
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloAllReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 i : replica_group_ids_) {
+ proto.add_replica_group_ids(i);
+ }
+ // Proto3 is so sad.
+ if (all_reduce_id_) {
+ proto.set_all_reduce_id(*all_reduce_id_);
+ }
+ proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_);
+ return proto;
+}
+
+std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& /*options*/) const {
+ std::vector<string> result = {
+ StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")};
+ if (!cross_replica_sum_barrier().empty()) {
+ result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+ }
+ if (all_reduce_id_) {
+ result.push_back(StrCat("all_reduce_id=", *all_reduce_id_));
+ }
+ return result;
+}
+
+bool HloAllReduceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other);
+ return replica_group_ids() == casted_other.replica_group_ids() &&
+ eq_computations(to_apply(), casted_other.to_apply()) &&
+ cross_replica_sum_barrier() ==
+ casted_other.cross_replica_sum_barrier() &&
+ all_reduce_id() == casted_other.all_reduce_id();
+}
+
+std::unique_ptr<HloInstruction>
+HloAllReduceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* /*context*/) const {
+ return MakeUnique<HloAllReduceInstruction>(
+ shape, new_operands, to_apply(), replica_group_ids(),
+ cross_replica_sum_barrier(), all_reduce_id());
+}
+
+HloReverseInstruction::HloReverseInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : HloInstruction(HloOpcode::kReverse, shape),
+ dimensions_(dimensions.begin(), dimensions.end()) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloReverseInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloReverseInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloReverseInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloReverseInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloReverseInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloConcatenateInstruction::HloConcatenateInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension)
+ : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloInstructionProto HloConcatenateInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloConcatenateInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloConcatenateInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloConcatenateInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloConcatenateInstruction>(shape, new_operands,
+ dimensions(0));
+}
+
+HloReduceInstruction::HloReduceInstruction(
+ const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation)
+ : HloInstruction(HloOpcode::kReduce, shape),
+ dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) {
+ AppendOperand(arg);
+ AppendOperand(init_value);
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloReduceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloReduceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloReduceInstruction&>(other);
+ // Reduction results are determined by the reduction dimension and the
+ // reduction computation.
+ return dimensions() == casted_other.dimensions() &&
+ eq_computations(to_apply(), casted_other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloReduceInstruction>(
+ shape, new_operands[0], new_operands[1], dimensions(), to_apply());
+}
+
+HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values)
+ : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+ AppendOperand(keys);
+ if (values) {
+ AppendOperand(values);
+ }
+}
+
+HloInstructionProto HloSortInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloSortInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloSortInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ HloInstruction* keys = new_operands[0];
+ HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
+ return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values);
+}
+
+HloTransposeInstruction::HloTransposeInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions)
+ : HloInstruction(HloOpcode::kTranspose, shape),
+ dimensions_(dimensions.begin(), dimensions.end()) {
+ CHECK_EQ(shape.dimensions().size(), dimensions.size());
+ CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
+ CHECK(std::equal(operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(dimensions, shape.dimensions()).begin()))
+ << "shape: " << ShapeUtil::HumanString(shape)
+ << ", operand->shape(): " << ShapeUtil::HumanString(shape)
+ << ", dimensions: {" << Join(dimensions, ", ") << "}";
+ AppendOperand(operand);
+}
+
+bool HloTransposeInstruction::IsRank2Transpose() const {
+ return dimensions() == std::vector<int64>({1, 0}) &&
+ shape().dimensions_size() == 2 &&
+ std::equal(shape().dimensions().begin(), shape().dimensions().end(),
+ operand(0)->shape().dimensions().rbegin());
+}
+
+HloInstructionProto HloTransposeInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloTransposeInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloTransposeInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloTransposeInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloTransposeInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloBroadcastInstruction::HloBroadcastInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimension)
+ : HloInstruction(HloOpcode::kBroadcast, shape),
+ dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloBroadcastInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloBroadcastInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloBroadcastInstruction&>(other);
+ return dimensions() == casted_other.dimensions();
+}
+
+std::unique_ptr<HloInstruction>
+HloBroadcastInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0],
+ dimensions());
+}
+
+HloMapInstruction::HloMapInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation)
+ : HloInstruction(HloOpcode::kMap, shape) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ AppendComputation(map_computation);
+ // TODO(b/65689298) Remove code below once Map is generalized to accept
+ // arbitrary map dimensions.
+ dimensions_.resize(ShapeUtil::Rank(shape));
+ std::iota(dimensions_.begin(), dimensions_.end(), 0);
+}
+
+HloInstructionProto HloMapInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 dimension : dimensions_) {
+ proto.add_dimensions(dimension);
+ }
+ return proto;
+}
+
+bool HloMapInstruction::IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const {
+ if (!dimensions().empty()) {
+ // Check that the map is executed in elementwise compatible dimensions.
+ if (dimensions().size() != shape().dimensions_size()) {
+ return false;
+ }
+ for (int i = 0; i < dimensions().size(); ++i) {
+ if (dimensions()[i] != i) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("dimensions={", Join(dimensions(), ","), "}")};
+}
+
+bool HloMapInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return eq_computations(to_apply(), other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply());
+}
+
+HloSliceInstruction::HloSliceInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides)
+ : HloInstruction(HloOpcode::kSlice, shape),
+ slice_starts_(start_indices.begin(), start_indices.end()),
+ slice_limits_(limit_indices.begin(), limit_indices.end()),
+ slice_strides_(strides.begin(), strides.end()) {
+ AppendOperand(operand);
+ // For backward compatibility with old serialized computations: if there are
+ // no strides, assume all strides are 1.
+ // TODO(b/63317920): remove this code.
+ if (slice_strides_.empty()) {
+ slice_strides_ = std::vector<int64>(start_indices.size(), 1LL);
+ }
+}
+
+HloInstructionProto HloSliceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int i = 0; i < slice_starts_.size(); ++i) {
+ auto* slice_dimension = proto.add_slice_dimensions();
+ slice_dimension->set_start(slice_starts_[i]);
+ slice_dimension->set_limit(slice_limits_[i]);
+ slice_dimension->set_stride(slice_strides_[i]);
+ }
+ return proto;
+}
+
+std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> bounds;
+ bounds.reserve(slice_starts_.size());
+ const bool omit_stride =
+ std::all_of(slice_strides_.begin(), slice_strides_.end(),
+ [](int64 stride) { return stride == 1; });
+ for (int i = 0; i < slice_starts_.size(); ++i) {
+ string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
+ bounds.push_back(
+ StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]"));
+ }
+ return {StrCat("slice={", Join(bounds, ", "), "}")};
+}
+
+bool HloSliceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
+ return slice_starts_ == other_slice.slice_starts_ &&
+ slice_limits_ == other_slice.slice_limits_ &&
+ slice_strides_ == other_slice.slice_strides_;
+}
+
+std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
+ slice_limits_, slice_strides_);
+}
+
+HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
+ : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
+ literal_(std::move(literal)) {}
+
+HloConstantInstruction::HloConstantInstruction(const Shape& shape)
+ : HloInstruction(HloOpcode::kConstant, shape) {}
+
+HloInstructionProto HloConstantInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ if (literal_ != nullptr) {
+ *proto.mutable_literal() = literal_->ToProto();
+ }
+ return proto;
+}
+
+bool HloConstantInstruction::IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const {
+ return true;
+}
+
+void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
+ const ShapeIndex& shape_index) {
+ Shape* mutable_array_subshape =
+ ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
+ CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
+
+ // Normally array_subshape will always have a layout, but this invariant is
+ // temporarily broken in LayoutAssignment::AssignLayouts.
+
+ if (!mutable_array_subshape->has_layout() ||
+ !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
+ literal_ = literal_->Relayout(new_layout, shape_index);
+ *mutable_array_subshape->mutable_layout() = new_layout;
+ }
+}
+
+bool HloConstantInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
+ return literal() == other_slice.literal();
+}
+
+std::unique_ptr<HloInstruction>
+HloConstantInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
+}
+
+string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const {
+ string operands;
+ // For constants, show the actual value in place of an empty operand list.
+ if (literal_ != nullptr &&
+ ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
+ options.print_large_constants())) {
+ // Literal::ToString emits multidimensional arrays over multiple
+ // lines. Compact this into one line by stripping out white space.
+ string tmp = literal().ToString();
+ std::replace(tmp.begin(), tmp.end(), '\n', ' ');
+ std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
+ bool first = true;
+ // Concatenate elements in "v" with spaces separating them, but ignoring
+ // empty entries.
+ for (const auto& s : v) {
+ if (s.empty()) {
+ continue;
+ }
+ StrAppend(&operands, (first ? "" : " "), s);
+ first = false;
+ }
+ } else {
+ // Do not show large constants or tuples.
+ operands = "{...}";
+ }
+ return operands;
+}
+
+HloTraceInstruction::HloTraceInstruction(const string& tag,
+ HloInstruction* operand)
+ : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
+ literal_(LiteralUtil::CreateR1U8(tag)) {
+ AppendOperand(operand);
+ operand->set_tracing(this);
+}
+
+HloInstructionProto HloTraceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_literal() = literal_->ToProto();
+ return proto;
+}
+
+bool HloTraceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
+}
+
+HloFusionInstruction::HloFusionInstruction(const Shape& shape,
+ FusionKind fusion_kind,
+ HloInstruction* fused_root)
+ : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
+ CHECK(fused_root != nullptr);
+ SetAndSanitizeName("fusion");
+ set_parent(fused_root->parent());
+ set_metadata(fused_root->metadata());
+ CloneAndFuseInternal(fused_root);
+}
+
+HloFusionInstruction::HloFusionInstruction(
+ const Shape& shape, FusionKind fusion_kind,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* fusion_computation)
+ : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+ SetAndSanitizeName("fusion");
+ AppendComputation(fusion_computation);
+ fusion_computation->SetFusionInstruction(this);
+}
+
+string HloFusionInstruction::ToCategory() const {
+ switch (fusion_kind()) {
+ case FusionKind::kLoop:
+ return "loop fusion";
+ case FusionKind::kInput:
+ return "input fusion";
+ case FusionKind::kOutput:
+ return "output fusion";
+ case FusionKind::kCustom:
+ return "custom fusion";
+ }
+}
+
+HloInstructionProto HloFusionInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_fusion_kind(xla::ToString(fusion_kind()));
+ proto.add_called_computation_ids(
+ fused_instructions_computation()->unique_id());
+ return proto;
+}
+
+bool HloFusionInstruction::IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const {
+ if (!operand_idx.has_value()) {
+ for (auto* fused : fused_instructions()) {
+ if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) {
+ return false;
+ }
+ }
+ return true;
+ }
+ // A loop-fusion is elementwise on an operand if all operations (computed
+ // using BFS) between the operand and the fused root are elementwise.
+ std::deque<HloInstruction*> worklist;
+ std::unordered_set<const HloInstruction*> visited;
+ worklist.push_back(fused_parameter(operand_idx.value()));
+ visited.insert(fused_parameter(operand_idx.value()));
+ while (!worklist.empty()) {
+ HloInstruction* operand = worklist.front();
+ worklist.pop_front();
+ for (HloInstruction* user : operand->users()) {
+ CHECK_GE(user->unique_id(), 0);
+ if (ContainsKey(visited, user)) {
+ continue;
+ }
+ if (user->IsElementwise() ||
+ IsInstructionElementwiseOnOperand(user, operand)) {
+ worklist.push_back(user);
+ visited.insert(user);
+ } else {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+HloInstruction* HloFusionInstruction::AddFusionOperand(
+ HloInstruction* new_operand) {
+ CHECK_EQ(operand_count(),
+ fused_instructions_computation()->parameter_instructions().size());
+ const int64 param_no = operand_count();
+ // Name the parameter after the instruction it represents in the outer
+ // (non-fusion) computation.
+ string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ HloInstruction* fused_parameter =
+ fused_instructions_computation()->AddParameter(
+ HloInstruction::CreateParameter(param_no, new_operand->shape(),
+ param_name));
+ AppendOperand(new_operand);
+ return fused_parameter;
+}
+
+void HloFusionInstruction::MergeFusionInstruction(
+ HloFusionInstruction* instruction_to_merge) {
+ CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
+ operands().end());
+ // Clone the instruction from which to merge fused instructions.
+ std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
+ HloFusionInstruction* cloned_fusion =
+ static_cast<HloFusionInstruction*>(cloned.get());
+ // Replace uses of fused parameters with the corresponding operand of the
+ // fusion. Add all non-parameter fused instructions to
+ // 'unfused_instructions' to be merged into 'this'. This is done in reverse
+ // post order.
+ std::vector<HloInstruction*> unfused_instructions;
+ auto fused_instructions = cloned_fusion->fused_instructions_computation()
+ ->MakeInstructionPostOrder();
+ for (auto fused_it = fused_instructions.rbegin();
+ fused_it != fused_instructions.rend(); ++fused_it) {
+ auto fused_instruction = *fused_it;
+ if (fused_instruction->opcode() == HloOpcode::kParameter) {
+ TF_CHECK_OK(
+ fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand(
+ fused_instruction->parameter_number())));
+ } else {
+ unfused_instructions.push_back(fused_instruction);
+ }
+ }
+ CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
+ // Replace instruction_to_merge use of 'this' with unfused_root.
+ TF_CHECK_OK(
+ instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
+ // Fuse 'unfused_instructions' into 'this'.
+ for (auto& instruction : unfused_instructions) {
+ FuseInstruction(instruction);
+ }
+ CHECK_EQ(0, cloned_fusion->user_count());
+ TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
+ cloned_fusion->fused_instructions_computation()));
+}
+
+void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
+ HloFusionInstruction* instruction_to_merge) {
+ // Add all non-parameter fused instructions to 'unfused_instructions' to be
+ // merged into 'this'. `old_to_new' maps the instructions in the fused node
+ // to the disaseembled fusion instructions.
+ // Note that we add the unfused instructions to this->parent_ computation.
+ // This is necessary because the unique_id needs for an instruction and
+ // it's only added when inserting to the computation.
+ tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
+ std::vector<HloInstruction*> unfused_instructions;
+ auto computation_to_merge =
+ instruction_to_merge->fused_instructions_computation();
+ auto post_order = computation_to_merge->MakeInstructionPostOrder();
+ for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) {
+ auto fused_instruction = *rit;
+ if (fused_instruction->opcode() == HloOpcode::kParameter) {
+ InsertOrDie(&old_to_new, fused_instruction,
+ instruction_to_merge->mutable_operand(
+ fused_instruction->parameter_number()));
+ continue;
+ }
+
+ // Here we clone the insertion and call FuseInstructionIntoMultiOutput()
+ // which clones again. This can be improved.
+ auto cloned_instruction =
+ parent()->AddInstruction(fused_instruction->Clone());
+ unfused_instructions.push_back(cloned_instruction);
+ InsertOrDie(&old_to_new, fused_instruction, cloned_instruction);
+ }
+ for (auto unfused_instruction : unfused_instructions) {
+ for (int64 index = 0; index < unfused_instruction->operand_count();
+ index++) {
+ auto new_operand =
+ FindOrDie(old_to_new, unfused_instruction->mutable_operand(index));
+ TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand));
+ }
+ }
+
+ HloInstruction* unfused_root = unfused_instructions.front();
+ TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
+
+ TF_CHECK_OK(
+ instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge));
+ if (GetModule()) {
+ TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge));
+ }
+
+ // Fuse the root instruction and generate multiple outputs.
+ FuseInstructionIntoMultiOutput(unfused_root);
+ TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
+ // The rest instructions are of normal fusing.
+ for (int64 i = 1; i < unfused_instructions.size(); i++) {
+ auto instruction = unfused_instructions[i];
+ FuseInstruction(instruction);
+ TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
+ }
+}
+
+HloComputation* HloFusionInstruction::fused_instructions_computation() const {
+ CHECK(!called_computations().empty());
+ auto* fused_instructions_computation = called_computations().front();
+ CHECK(fused_instructions_computation->IsFusionComputation())
+ << "Computation " << fused_instructions_computation->name()
+ << " is not a fusion kind";
+ return fused_instructions_computation;
+}
+
+HloInstruction* HloFusionInstruction::fused_expression_root() const {
+ return fused_instructions_computation()->root_instruction();
+}
+
+HloInstruction* HloFusionInstruction::fused_parameter(
+ int64 parameter_number) const {
+ return fused_instructions_computation()->parameter_instruction(
+ parameter_number);
+}
+
+const std::vector<HloInstruction*>& HloFusionInstruction::fused_parameters()
+ const {
+ return fused_instructions_computation()->parameter_instructions();
+}
+
+const tensorflow::gtl::iterator_range<UnwrappingIterator<
+ std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
+HloFusionInstruction::fused_instructions() const {
+ const HloComputation* subcomp = fused_instructions_computation();
+ return subcomp->instructions();
+}
+
+const tensorflow::gtl::iterator_range<
+ UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
+HloFusionInstruction::fused_instructions() {
+ return fused_instructions_computation()->instructions();
+}
+
+int64 HloFusionInstruction::fused_instruction_count() const {
+ return fused_instructions_computation()->instruction_count();
+}
+
+HloInstruction* HloFusionInstruction::FuseInstructionInternal(
+ HloInstruction* instruction_to_fuse, bool add_output) {
+ // When add_output is false, this fusion instruction must be a user of
+ // instruction_to_fuse.
+ if (!add_output) {
+ CHECK(IsUserOf(instruction_to_fuse));
+ }
+ HloInstruction* fused_instruction =
+ CloneAndFuseInternal(instruction_to_fuse, add_output);
+ return fused_instruction;
+}
+
+HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
+ HloInstruction* instruction_to_fuse, bool add_output) {
+ CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString();
+ VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString();
+ HloInstruction* clone = nullptr;
+ if (called_computations().empty()) {
+ // New fusion instruction. It should not be a multioutput instruction.
+ CHECK(!add_output);
+ auto builder = HloComputation::Builder("fused_computation", this);
+ builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/""));
+ AppendComputation(
+ CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build()));
+ clone = fused_expression_root();
+ } else {
+ // When add_output is false, instruction_to_fuse is necessarily an operand
+ // of the fusion instruction. After fusion this will no longer be the
+ // case. Remove the operand from the operand list and remove its
+ // corresponding fused parameter instruction. Renumber parameters as
+ // necessary to make parameter numbers consistent with their index in the
+ // fused_parameter_ vector.
+ bool in_operand_list = std::find(operands().begin(), operands().end(),
+ instruction_to_fuse) != operands().end();
+ CHECK(add_output || in_operand_list);
+ if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
+ // We assume all uses of a kTuple operation are GTE ops, not another
+ // fusion node. In this case, we don't need to clone
+ // 'instruction_to_fuse'.
+ CHECK(!in_operand_list);
+ clone = instruction_to_fuse;
+ } else {
+ clone = fused_instructions_computation()->AddInstruction(
+ instruction_to_fuse->Clone(/*suffix=*/""));
+ }
+ const std::vector<HloInstruction*>& fused_parameters =
+ fused_instructions_computation()->parameter_instructions();
+ for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
+ if (instruction_to_fuse == operand(operand_num)) {
+ // replace the fused parameter instruction's uses with the clone.
+ HloInstruction* fused_parameter = fused_parameters[operand_num];
+ TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone));
+
+ // Remove the corresponding fused parameter and operand from their
+ // respective vectors.
+ TF_CHECK_OK(
+ fused_instructions_computation()->RemoveParameter(operand_num));
+ RemoveOperandAt(operand_num);
+ break;
+ }
+ }
+ // We've cloned instruction_to_fuse into this fusion instruction, so this
+ // fusion instruction is no longer a use of instruction_to_fuse.
+ if (in_operand_list) {
+ DetachFrom(instruction_to_fuse);
+ // When the instruction_to_fuse does not have other users, we don't need
+ // to generate a multioutput fusion instruction.
+ if (instruction_to_fuse->user_count() == 0) {
+ add_output = false;
+ }
+ }
+ }
+
+ // Reread the parameters in the computation.
+ const std::vector<HloInstruction*>& fused_parameters =
+ fused_instructions_computation()->parameter_instructions();
+
+ // Add each operand of the clone as an operand of the fusion instruction. A
+ // complication is that some clone operands may already be operands of the
+ // fusion instruction.
+ for (int64 operand_num = 0; operand_num < clone->operand_count();
+ ++operand_num) {
+ HloInstruction* operand = clone->mutable_operand(operand_num);
+
+ // See if this operand is already an operand of the fusion node.
+ CHECK_EQ(operands().size(), fused_parameters.size());
+ HloInstruction* fused_param = nullptr;
+ for (int64 i = 0; i < operands().size(); ++i) {
+ if (this->operand(i) == operand) {
+ fused_param = fused_parameters[i];
+ break;
+ }
+ }
+
+ if (fused_param == nullptr) {
+ // Clone's operand was not already an operand of the fusion
+ // instruction. Add it as an operand and add a corresponding fused
+ // parameter instruction.
+ fused_param = AddFusionOperand(operand);
+ }
+ TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param));
+ }
+
+ if (add_output) {
+ CHECK_GT(instruction_to_fuse->user_count(), 0);
+ // If this is already a multioutput fusion instruction, expand the root
+ // tuple by 1.
+ HloInstruction* fused_root = fused_expression_root();
+ HloInstruction::InstructionVector tuple_elements;
+ bool newly_created_tuple_instr = false;
+ if (fused_root->opcode() == HloOpcode::kTuple) {
+ tuple_elements = fused_root->operands();
+ } else {
+ tuple_elements.push_back(fused_root);
+ newly_created_tuple_instr = true;
+ }
+ if (clone->opcode() == HloOpcode::kTuple) {
+ for (auto inst : clone->operands()) {
+ tuple_elements.push_back(inst);
+ }
+ } else {
+ tuple_elements.push_back(clone);
+ }
+ HloInstruction* new_root = fused_instructions_computation()->AddInstruction(
+ HloInstruction::CreateTuple(tuple_elements));
+ fused_instructions_computation()->set_root_instruction(new_root);
+ *mutable_shape() = new_root->shape();
+ if (fused_root->opcode() == HloOpcode::kTuple) {
+ TF_CHECK_OK(
+ fused_instructions_computation()->RemoveInstruction(fused_root));
+ }
+
+ // If this is a newly created multioutput instruction, we need to update
+ // the use of the original fusion instruction.
+ if (newly_created_tuple_instr) {
+ HloInstruction* new_instr = parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0));
+ TF_CHECK_OK(ReplaceAllUsesWith(new_instr));
+ }
+ int64 index = tuple_elements.size();
+ if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
+ CHECK_EQ(clone, instruction_to_fuse);
+ index -= clone->operand_count();
+ std::vector<HloInstruction*> to_be_removed;
+ for (auto old_gte : clone->users()) {
+ CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement);
+ int64 old_tuple_index = old_gte->tuple_index();
+ HloInstruction* new_gte =
+ parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ old_gte->shape(), this, index + old_tuple_index));
+ TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte));
+ to_be_removed.push_back(old_gte);
+ }
+ for (auto old_gte : to_be_removed) {
+ TF_CHECK_OK(parent()->RemoveInstruction(old_gte));
+ }
+ } else {
+ HloInstruction* new_gte =
+ parent()->AddInstruction(HloInstruction::CreateGetTupleElement(
+ clone->shape(), this, index - 1));
+ TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte));
+ }
+ }
+
+ if (clone != instruction_to_fuse) {
+ VLOG(2) << "New clone:\n" << clone->ToString();
+ }
+ return clone;
+}
+
+std::vector<string> HloFusionInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("kind=", xla::ToString(fusion_kind()))};
+}
+
+bool HloFusionInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return fusion_kind() == other.fusion_kind() &&
+ eq_computations(fused_instructions_computation(),
+ other.fused_instructions_computation());
+}
+
+std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ HloModule* module = context != nullptr ? context->module() : GetModule();
+ HloComputation* new_fused_computation = nullptr;
+ if (context != nullptr) {
+ new_fused_computation =
+ context->FindComputation(fused_instructions_computation());
+ }
+ if (new_fused_computation == nullptr) {
+ new_fused_computation = module->AddEmbeddedComputation(
+ fused_instructions_computation()->Clone("clone", context));
+ }
+ return MakeUnique<HloFusionInstruction>(shape, fusion_kind(), new_operands,
+ new_fused_computation);
+}
+
+Status HloFusionInstruction::DeduplicateFusionOperands() {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ std::vector<int> operands_to_remove;
+ for (int i = 0; i < operand_count(); ++i) {
+ auto emplace_result = operand_indices.emplace(operand(i), i);
+ if (!emplace_result.second) {
+ TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
+ fused_parameter(emplace_result.first->second)));
+ operands_to_remove.push_back(i);
+ }
+ }
+ if (operands_to_remove.empty()) {
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ fused_instructions_computation()->RemoveUnusedParameters());
+ RemoveOperandsAtAscendingIndices(operands_to_remove);
+ return Status::OK();
+}
+
+HloRngInstruction::HloRngInstruction(
+ const Shape& shape, RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
+ : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) {
+ for (HloInstruction* param : parameters) {
+ AppendOperand(param);
+ }
+}
+
+HloInstructionProto HloRngInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_distribution(distribution_);
+ return proto;
+}
+
+std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("distribution=", RandomDistributionToString(distribution_))};
+}
+
+bool HloRngInstruction::IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const {
+ return true;
+}
+
+bool HloRngInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloRngInstruction>(shape, distribution_, new_operands);
+}
+
+HloParameterInstruction::HloParameterInstruction(int64 parameter_number,
+ const Shape& shape,
+ const string& name)
+ : HloInstruction(HloOpcode::kParameter, shape),
+ parameter_number_(parameter_number) {
+ SetAndSanitizeName(name);
+}
+
+HloInstructionProto HloParameterInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_parameter_number(parameter_number_);
+ return proto;
+}
+
+string HloParameterInstruction::OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const {
+ return StrCat(parameter_number_);
+}
+
+bool HloParameterInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloParameterInstruction&>(other);
+ return parameter_number() == casted_other.parameter_number();
+}
+
+std::unique_ptr<HloInstruction>
+HloParameterInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name());
+}
+
+HloGetTupleElementInstruction::HloGetTupleElementInstruction(
+ const Shape& shape, HloInstruction* operand, int64 index)
+ : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
+ CHECK(ShapeUtil::IsTuple(operand->shape()));
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloGetTupleElementInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_tuple_index(tuple_index_);
+ return proto;
+}
+
+std::vector<string> HloGetTupleElementInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("index=", tuple_index())};
+}
+
+bool HloGetTupleElementInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloGetTupleElementInstruction&>(other);
+ return tuple_index() == casted_other.tuple_index();
+}
+
+std::unique_ptr<HloInstruction>
+HloGetTupleElementInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloGetTupleElementInstruction>(shape, new_operands[0],
+ tuple_index());
+}
+
+HloReducePrecisionInstruction::HloReducePrecisionInstruction(
+ const Shape& shape, HloInstruction* operand, const int exponent_bits,
+ const int mantissa_bits)
+ : HloInstruction(HloOpcode::kReducePrecision, shape),
+ exponent_bits_(exponent_bits),
+ mantissa_bits_(mantissa_bits) {
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloReducePrecisionInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_exponent_bits(exponent_bits_);
+ proto.set_mantissa_bits(mantissa_bits_);
+ return proto;
+}
+
+std::vector<string> HloReducePrecisionInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("exponent_bits=", exponent_bits_),
+ StrCat("mantissa_bits=", mantissa_bits_)};
+}
+
+bool HloReducePrecisionInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloReducePrecisionInstruction&>(other);
+ // A reduce-precision operation is determined by the bit sizes.
+ return exponent_bits() == casted_other.exponent_bits() &&
+ mantissa_bits() == casted_other.mantissa_bits();
+}
+
+std::unique_ptr<HloInstruction>
+HloReducePrecisionInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloReducePrecisionInstruction>(
+ shape, new_operands[0], exponent_bits(), mantissa_bits());
+}
+
+HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
+ HloInstruction* token_operand,
+ const string& config)
+ : HloInstruction(HloOpcode::kInfeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed_shape, ShapeUtil::MakeTokenShape()})),
+ infeed_config_(config) {
+ AppendOperand(token_operand);
+}
+
+HloInfeedInstruction::HloInfeedInstruction(const Shape& infeed_shape,
+ const string& config)
+ : HloInstruction(HloOpcode::kInfeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed_shape, ShapeUtil::MakeTokenShape()})),
+ infeed_config_(config) {}
+
+HloInstructionProto HloInfeedInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_infeed_config(infeed_config_);
+ return proto;
+}
+
+std::vector<string> HloInfeedInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (infeed_config_.empty()) {
+ return {};
+ }
+ return {StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")};
+}
+
+bool HloInfeedInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ if (new_operands.empty()) {
+ return MakeUnique<HloInfeedInstruction>(infeed_shape(), infeed_config());
+ } else {
+ CHECK_EQ(new_operands.size(), 1);
+ return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0],
+ infeed_config());
+ }
+}
+
+HloOutfeedInstruction::HloOutfeedInstruction(
+ const Shape& outfeed_shape, HloInstruction* operand,
+ HloInstruction* token_operand, tensorflow::StringPiece outfeed_config)
+ : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
+ outfeed_shape_(outfeed_shape),
+ outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
+ << "Outfeed shape " << outfeed_shape
+ << " must be compatible with operand shape " << operand->shape();
+ AppendOperand(operand);
+ AppendOperand(token_operand);
+}
+
+HloOutfeedInstruction::HloOutfeedInstruction(
+ const Shape& outfeed_shape, HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config)
+ : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
+ outfeed_shape_(outfeed_shape),
+ outfeed_config_(outfeed_config.begin(), outfeed_config.end()) {
+ CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
+ << "Outfeed shape " << outfeed_shape
+ << " must be compatible with operand shape " << operand->shape();
+ AppendOperand(operand);
+}
+
+HloInstructionProto HloOutfeedInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_outfeed_config(outfeed_config());
+ *proto.mutable_outfeed_shape() = outfeed_shape();
+ return proto;
+}
+
+std::vector<string> HloOutfeedInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ if (outfeed_config_.empty()) {
+ return {};
+ }
+ return {StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")};
+}
+
+bool HloOutfeedInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ if (new_operands.size() == 1) {
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
+ outfeed_config());
+ } else {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0],
+ new_operands[1], outfeed_config());
+ }
+}
+
+HloConvolutionInstruction::HloConvolutionInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window, const ConvolutionDimensionNumbers& dimension_numbers)
+ : HloInstruction(HloOpcode::kConvolution, shape),
+ window_(window),
+ convolution_dimension_numbers_(dimension_numbers) {
+ if (window_util::HasBaseDilation(window)) {
+ SetAndSanitizeName(StrCat(name(), "-base-dilated"));
+ }
+ if (window_util::HasWindowDilation(window)) {
+ SetAndSanitizeName(StrCat(name(), "-window-dilated"));
+ }
+ AppendOperand(lhs);
+ AppendOperand(rhs);
+}
+
+string HloConvolutionInstruction::ToCategory() const {
+ string category = "convolution";
+ if (window_util::HasBaseDilation(window())) {
+ category += " base-dilated";
+ }
+ if (window_util::HasWindowDilation(window())) {
+ category += " window-dilated";
+ }
+ return category;
+}
+
+HloInstructionProto HloConvolutionInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_window() = window_;
+ *proto.mutable_convolution_dimension_numbers() =
+ convolution_dimension_numbers_;
+ return proto;
+}
+
+std::vector<string> HloConvolutionInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra;
+ if (window_.dimensions_size() != 0) {
+ extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
+ }
+ extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString(
+ convolution_dimension_numbers_)));
+ return extra;
+}
+
+bool HloConvolutionInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloConvolutionInstruction&>(other);
+ return protobuf_util::ProtobufEquals(window(), casted_other.window()) &&
+ protobuf_util::ProtobufEquals(
+ convolution_dimension_numbers(),
+ casted_other.convolution_dimension_numbers());
+}
+
+std::unique_ptr<HloInstruction>
+HloConvolutionInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloConvolutionInstruction>(shape, new_operands[0],
+ new_operands[1], window(),
+ convolution_dimension_numbers_);
+}
+
+HloReduceWindowInstruction::HloReduceWindowInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
+ const Window& window, HloComputation* reduce_computation)
+ : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
+ AppendOperand(operand);
+ AppendOperand(init_value);
+ AppendComputation(reduce_computation);
+}
+
+HloInstructionProto HloReduceWindowInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_window() = window_;
+ return proto;
+}
+
+std::vector<string> HloReduceWindowInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra;
+ if (window_.dimensions_size() != 0) {
+ extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
+ }
+ return extra;
+}
+
+bool HloReduceWindowInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloReduceWindowInstruction&>(other);
+ return eq_computations(to_apply(), casted_other.to_apply()) &&
+ protobuf_util::ProtobufEquals(window(), casted_other.window());
+}
+
+std::unique_ptr<HloInstruction>
+HloReduceWindowInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloReduceWindowInstruction>(
+ shape, new_operands[0], new_operands[1], window(), to_apply());
+}
+
+HloSelectAndScatterInstruction::HloSelectAndScatterInstruction(
+ const Shape& shape, HloInstruction* operand, HloComputation* select,
+ const Window& window, HloInstruction* source, HloInstruction* init_value,
+ HloComputation* scatter)
+ : HloInstruction(HloOpcode::kSelectAndScatter, shape), window_(window) {
+ AppendOperand(operand);
+ AppendOperand(source);
+ AppendOperand(init_value);
+ // Select comes before scatter in the vector.
+ AppendComputation(select);
+ AppendComputation(scatter);
+}
+
+HloInstructionProto HloSelectAndScatterInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_window() = window_;
+ return proto;
+}
+
+std::vector<string> HloSelectAndScatterInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra;
+ if (window_.dimensions_size() != 0) {
+ extra.push_back(StrCat("window={", window_util::ToString(window()), "}"));
+ }
+ return extra;
+}
+
+bool HloSelectAndScatterInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloSelectAndScatterInstruction&>(other);
+ return eq_computations(select(), casted_other.select()) &&
+ eq_computations(scatter(), casted_other.scatter()) &&
+ protobuf_util::ProtobufEquals(window(), casted_other.window());
+}
+
+std::unique_ptr<HloInstruction>
+HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloSelectAndScatterInstruction>(
+ shape, new_operands[0], select(), window(), new_operands[1],
+ new_operands[2], scatter());
+}
+
+HloCustomCallInstruction::HloCustomCallInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target)
+ : HloInstruction(HloOpcode::kCustomCall, shape),
+ custom_call_target_(custom_call_target.begin(),
+ custom_call_target.end()) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloInstructionProto HloCustomCallInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ if (window_ != nullptr) {
+ *proto.mutable_window() = *window_;
+ }
+ if (convolution_dimension_numbers_ != nullptr) {
+ *proto.mutable_convolution_dimension_numbers() =
+ *convolution_dimension_numbers_;
+ }
+ proto.set_custom_call_target(custom_call_target_);
+ return proto;
+}
+
+std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ std::vector<string> extra;
+ if (window_ != nullptr && window_->dimensions_size() != 0) {
+ extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
+ }
+ if (convolution_dimension_numbers_ != nullptr) {
+ extra.push_back(StrCat(
+ "dim_labels=",
+ ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_)));
+ }
+ // By contract, we print the custom call target even if
+ // options.print_subcomputation_mode() == kOff, because the call target is not
+ // an HloComputation.
+ extra.push_back(
+ StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ return extra;
+}
+
+bool HloCustomCallInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other =
+ static_cast<const HloCustomCallInstruction&>(other);
+ if ((window_ == nullptr) != (casted_other.window_ == nullptr) ||
+ (window_ != nullptr &&
+ !protobuf_util::ProtobufEquals(*window_, *casted_other.window_))) {
+ return false;
+ }
+ if ((convolution_dimension_numbers_ == nullptr) !=
+ (casted_other.convolution_dimension_numbers_ == nullptr) ||
+ (convolution_dimension_numbers_ != nullptr &&
+ !protobuf_util::ProtobufEquals(
+ convolution_dimension_numbers(),
+ casted_other.convolution_dimension_numbers()))) {
+ return false;
+ }
+ return custom_call_target_ == casted_other.custom_call_target_;
+}
+
+std::unique_ptr<HloInstruction>
+HloCustomCallInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ auto cloned = MakeUnique<HloCustomCallInstruction>(shape, new_operands,
+ custom_call_target());
+ if (window_ != nullptr) {
+ cloned->set_window(*window_);
+ }
+ if (convolution_dimension_numbers_ != nullptr) {
+ cloned->set_convolution_dimension_numbers(*convolution_dimension_numbers_);
+ }
+ return std::move(cloned);
+}
+
+HloHostComputeInstruction::HloHostComputeInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece channel_name, const int64 cost_estimate_ns)
+ : HloInstruction(HloOpcode::kHostCompute, shape),
+ channel_name_(channel_name.begin(), channel_name.end()),
+ cost_estimate_ns_(cost_estimate_ns) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloInstructionProto HloHostComputeInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ proto.set_channel_name(channel_name_);
+ proto.set_cost_estimate_ns(cost_estimate_ns_);
+ return proto;
+}
+
+bool HloHostComputeInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ // Not yet supported.
+ return false;
+}
+
+std::unique_ptr<HloInstruction>
+HloHostComputeInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ return MakeUnique<HloHostComputeInstruction>(
+ shape, new_operands, channel_name_, cost_estimate_ns_);
+}
+
+HloPadInstruction::HloPadInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* padding_value,
+ const PaddingConfig& padding_config)
+ : HloInstruction(HloOpcode::kPad, shape), padding_config_(padding_config) {
+ AppendOperand(operand);
+ AppendOperand(padding_value);
+}
+
+HloInstructionProto HloPadInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_padding_config() = padding_config_;
+ return proto;
+}
+
+std::vector<string> HloPadInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {StrCat("padding=", xla::PaddingConfigToString(padding_config_))};
+}
+
+bool HloPadInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloPadInstruction&>(other);
+ return protobuf_util::ProtobufEquals(padding_config(),
+ casted_other.padding_config());
+}
+
+std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloPadInstruction>(shape, new_operands[0], new_operands[1],
+ padding_config_);
+}
+
+HloDynamicSliceInstruction::HloDynamicSliceInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes)
+ : HloInstruction(HloOpcode::kDynamicSlice, shape),
+ dynamic_slice_sizes_(slice_sizes.begin(), slice_sizes.end()) {
+ AppendOperand(operand);
+ AppendOperand(start_indices);
+}
+
+HloInstructionProto HloDynamicSliceInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ for (int64 slice_size : dynamic_slice_sizes_) {
+ proto.add_dynamic_slice_sizes(slice_size);
+ }
+ return proto;
+}
+
+std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {
+ StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")};
+}
+
+bool HloDynamicSliceInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ return true;
+}
+
+std::unique_ptr<HloInstruction>
+HloDynamicSliceInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 2);
+ return MakeUnique<HloDynamicSliceInstruction>(
+ shape, new_operands[0], new_operands[1], dynamic_slice_sizes_);
+}
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
new file mode 100644
index 0000000000..e922d94234
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -0,0 +1,1153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// All HloInstruction subclasses are put in this file.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+class HloBatchNormInstruction : public HloInstruction {
+ public:
+ // Returns feature_index field associated with the instruction. The index
+ // represents the index of the feature dimension.
+ int64 feature_index() const { return feature_index_; }
+
+ // Returns a epsilon value associated with the instruction. The is a small
+ // number added to the variance to avoid divide-by-zero error.
+ float epsilon() const { return epsilon_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ protected:
+ explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* scale, float epsilon,
+ int64 feature_index);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // A small float number added to the variance to avoid divide-by-zero error.
+ float epsilon_ = 0.0f;
+
+ // An integer value representing the index of the feature dimension.
+ int64 feature_index_ = -1;
+};
+
+class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormTrainingInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* scale,
+ HloInstruction* offset,
+ float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormInferenceInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
+ float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloBatchNormGradInstruction : public HloBatchNormInstruction {
+ public:
+ explicit HloBatchNormGradInstruction(
+ const Shape& shape, HloInstruction* operand, HloInstruction* scale,
+ HloInstruction* mean, HloInstruction* variance,
+ HloInstruction* grad_output, float epsilon, int64 feature_index);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloFftInstruction : public HloInstruction {
+ public:
+ explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
+ FftType fft_type,
+ tensorflow::gtl::ArraySlice<int64> fft_length);
+ FftType fft_type() const { return fft_type_; }
+
+ const std::vector<int64>& fft_length() const { return fft_length_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Describes FFT type for an FFT instruction.
+ FftType fft_type_ = FftType::FFT;
+
+ // Indicates the FFT length for an FFT instruction.
+ std::vector<int64> fft_length_;
+};
+
+class HloSendRecvInstruction : public HloInstruction {
+ public:
+ // Returns the channel id associated with the instruction. The id is
+ // shared between each Send/Recv pair and is globally unique to identify each
+ // channel.
+ int64 channel_id() const { return channel_id_; }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ protected:
+ explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
+ int64 channel_id);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Represents a unique identifier for each Send/Recv instruction pair.
+ int64 channel_id_;
+};
+
+class HloSendInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
+ int64 channel_id);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloSendDoneInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloSendDoneInstruction(HloSendInstruction* operand);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloRecvInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
+ int64 channel_id);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloRecvDoneInstruction : public HloSendRecvInstruction {
+ public:
+ explicit HloRecvDoneInstruction(HloRecvInstruction* operand);
+
+ private:
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+};
+
+class HloAllReduceInstruction : public HloInstruction {
+ public:
+ explicit HloAllReduceInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
+ const tensorflow::gtl::optional<int64>& all_reduce_id =
+ tensorflow::gtl::nullopt);
+
+ // Returns the group ids of each replica for CrossReplicaSum op.
+ const std::vector<int64>& replica_group_ids() const {
+ return replica_group_ids_;
+ }
+
+ // Returns the barrier config used for the CrossReplicaSum implementation of
+ // each backend.
+ string cross_replica_sum_barrier() const {
+ return cross_replica_sum_barrier_;
+ }
+ void set_cross_replica_sum_barrier(string barrier) {
+ cross_replica_sum_barrier_ = barrier;
+ }
+
+ tensorflow::gtl::optional<int64> all_reduce_id() const {
+ return all_reduce_id_;
+ }
+
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The group id of each replica for CrossReplicaSum.
+ std::vector<int64> replica_group_ids_;
+
+ // The string representation of the barrier config used for CrossReplicaSum.
+ string cross_replica_sum_barrier_;
+
+ // For Allreduce nodes from different modules, if they have the same
+ // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
+ // applied cross modules.
+ tensorflow::gtl::optional<int64> all_reduce_id_;
+};
+
+class HloReverseInstruction : public HloInstruction {
+ public:
+ explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloConcatenateInstruction : public HloInstruction {
+ public:
+ explicit HloConcatenateInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Accessor for the dimension in which a concatenate HLO should occur.
+ int64 concatenate_dimension() const { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloReduceInstruction : public HloInstruction {
+ public:
+ explicit HloReduceInstruction(
+ const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloSortInstruction : public HloInstruction {
+ public:
+ explicit HloSortInstruction(const Shape& shape, int64 dimension,
+ HloInstruction* keys,
+ HloInstruction* values = nullptr);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns the sort dimension for this instruction
+ int64 sort_dimension() { return dimensions(0); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloTransposeInstruction : public HloInstruction {
+ public:
+ explicit HloTransposeInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns whether this instruction does a rank-2 transposition.
+ bool IsRank2Transpose() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloBroadcastInstruction : public HloInstruction {
+ public:
+ explicit HloBroadcastInstruction(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimension);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloMapInstruction : public HloInstruction {
+ public:
+ explicit HloMapInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation);
+ // Returns the dimension sizes or numbers associated with this instruction.
+ const std::vector<int64>& dimensions() const override { return dimensions_; }
+ int64 dimensions(int64 index) const override { return dimensions()[index]; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ bool IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::vector<int64> dimensions_;
+};
+
+class HloSliceInstruction : public HloInstruction {
+ public:
+ explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices,
+ tensorflow::gtl::ArraySlice<int64> strides);
+
+ HloInstructionProto ToProto() const override;
+
+ // Returns the start index in the given dimension for a slice node.
+ int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
+ const std::vector<int64>& slice_starts() const { return slice_starts_; }
+
+ // Returns the (exclusive) limit index in the given dimension for a slice
+ // node.
+ int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
+ const std::vector<int64>& slice_limits() const { return slice_limits_; }
+
+ // Returns the stride in the given dimension for a slice node.
+ int64 slice_strides(int64 dimension) const {
+ return slice_strides_[dimension];
+ }
+ const std::vector<int64>& slice_strides() const { return slice_strides_; }
+
+ // Returns the flag that describes whether a slice must be lowered into an
+ // offset into the original operand.
+ bool IsInPlaceSlice() const { return is_in_place_slice_; }
+
+ // Sets and returns the flag that describes whether a slice must be lowered
+ // into an offset into the original operand.
+ bool SetIsInPlaceSlice(bool value) {
+ is_in_place_slice_ = value;
+ return value;
+ }
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Describes the [begin, end) index range for a slice.
+ std::vector<int64> slice_starts_;
+ std::vector<int64> slice_limits_;
+ std::vector<int64> slice_strides_;
+
+ // Describes whether the slice can be lowered to an offset into the operand.
+ bool is_in_place_slice_ = false;
+};
+
+class HloConstantInstruction : public HloInstruction {
+ public:
+ explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
+ // Used when the literal is too large and dropped.
+ explicit HloConstantInstruction(const Shape& shape);
+ // Returns the literal associated with this instruction.
+ const Literal& literal() const { return *literal_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Change the layout for an Constant Hlo instruction to match new_layout. For
+ // tuple shaped constants shape_index is the path to the internal array
+ // subshape whose layout needs to be changed.
+ void RelayoutConstant(const Layout& new_layout,
+ const ShapeIndex& shape_index = {});
+
+ private:
+ bool IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ string OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ // TODO(b/36360764): Remove unique_ptr wrapping.
+ std::unique_ptr<Literal> literal_;
+};
+
+class HloTraceInstruction : public HloInstruction {
+ public:
+ explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
+ // Returns a tag to be used in tracing.
+ string TracingTag() const { return literal_->GetR1U8AsString(); }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ // TODO(b/36360764): Remove unique_ptr wrapping.
+ std::unique_ptr<Literal> literal_;
+};
+
+class HloFusionInstruction : public HloInstruction {
+ public:
+ explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
+ HloInstruction* fused_root);
+
+ explicit HloFusionInstruction(
+ const Shape& shape, FusionKind fusion_kind,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* fusion_computation);
+
+ string ToCategory() const override;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Adds a new operand the fusion instruction.
+ HloInstruction* AddFusionOperand(HloInstruction* new_operand);
+
+ // Merges the fused instructions from 'instruction_to_merge' into the
+ // fused instruction set of 'this', updating operands as necessary.
+ //
+ // Predondition: 'instruction_to_merge' must be an operand of 'this'.
+ void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
+
+ // Merges the fused instructions from instruction_to_merge into the fused
+ // instruction set of 'this' and generates multioutput fusion instructions.
+ // All the users of instruction_to_merge will be redirected to 'this'
+ // instruction. instruction_to_merge will be removed from its parent
+ // computation.
+ void MergeFusionInstructionIntoMultiOutput(
+ HloFusionInstruction* instruction_to_merge);
+
+ // Fuses the given instruction in this fusion instruction. instruction_to_fuse
+ // is cloned and the clone is placed in the fusion
+ // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
+ // than moved to cleanly handle the case where the instruction has a use
+ // outside the fusion instruction. Moving such an instruction into a fusion
+ // instruction would violate the single-result invariant of HLO instructions
+ // and significantly complicate code generation.
+ HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
+ return FuseInstructionInternal(instruction_to_fuse);
+ }
+
+ // Fuses the given instruction in this fusion instruction and generate
+ // multioutput fusion instruction. A clone of the instruction_to_fuse will
+ // be part of the output of fusion instructions. The users of
+ // instruction_to_fuse will be redirected to this fusion instructions.
+ // instruction_to_fuse will be removed from its parent computation.
+ HloInstruction* FuseInstructionIntoMultiOutput(
+ HloInstruction* instruction_to_fuse) {
+ return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
+ }
+
+ // Returns the computation for this fused instruction.
+ HloComputation* fused_instructions_computation() const;
+
+ // Returns the root instruction of the fused expression contained within this
+ // fusion instruction.
+ HloInstruction* fused_expression_root() const;
+
+ // Returns the list of fused instructions inside this fusion instruction. The
+ // returned type is a range of HloInstruction*s.
+ const tensorflow::gtl::iterator_range<UnwrappingIterator<
+ std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
+ fused_instructions() const;
+
+ const tensorflow::gtl::iterator_range<
+ UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
+ fused_instructions();
+
+ // Gets the number of instructions inside this fusion instruction.
+ int64 fused_instruction_count() const;
+
+ // Returns the fused parameter instruction in this fusion instruction
+ // corresponding to the given parameter number.
+ HloInstruction* fused_parameter(int64 parameter_number) const;
+
+ // Returns the vector of fused parameters inside this fusion instruction.
+ const std::vector<HloInstruction*>& fused_parameters() const;
+
+ // Returns true if this instruction is a fusion instruction that generates
+ // multiple outputs.
+ const bool IsMultiOutputFusion() const {
+ return fused_expression_root()->opcode() == HloOpcode::kTuple;
+ }
+
+ FusionKind fusion_kind() const { return fusion_kind_; }
+
+ void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
+
+ // If multiple operands are the same instruction, keeps only one of them.
+ Status DeduplicateFusionOperands();
+
+ private:
+ // Fuses the given instruction into this fusion instruction. When add_output
+ // is false (which is the default), instruction_to_fuse is cloned and the
+ // clone is placed in the fusion instruction. instruction_to_fuse is
+ // unchanged.
+ //
+ // When add_output is true, a clone of the instruction_to_fuse will be part
+ // of the output of fusion instructions. The users of instruction_to_fuse
+ // will be redirected to this fusion instructions. instruction_to_fuse will
+ // be removed from its parent computation.
+ HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
+ bool add_output = false);
+ // Clones the given instruction_to_fuse and insert the clone into this fusion
+ // instruction. If add_output is true, a clone of instruction_to_fuse will
+ // be in the output of the this fusion instruction (part of the tuple of the
+ // fusion root).
+ HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
+ bool add_output = false);
+
+ bool IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The type of the fusion. Used by kFusion only.
+ FusionKind fusion_kind_;
+};
+
+class HloRngInstruction : public HloInstruction {
+ public:
+ explicit HloRngInstruction(
+ const Shape& shape, RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+ // Returns the random distribution for this rng node.
+ RandomDistribution random_distribution() const { return distribution_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ bool IsElementwiseImpl(
+ const tensorflow::gtl::optional<int64>& operand_idx) const override;
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The distribution requested for random number generation.
+ RandomDistribution distribution_;
+};
+
+class HloParameterInstruction : public HloInstruction {
+ public:
+ explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
+ const string& name);
+ int64 parameter_number() const { return parameter_number_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ string OperandsToStringWithCanonicalNameMap(
+ const HloPrintOptions& options,
+ CanonicalNameMap* canonical_name_map) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ int64 parameter_number_ = 0;
+};
+
+class HloGetTupleElementInstruction : public HloInstruction {
+ public:
+ explicit HloGetTupleElementInstruction(const Shape& shape,
+ HloInstruction* operand, int64 index);
+ // Returns the tuple index associated with this instruction.
+ int64 tuple_index() const { return tuple_index_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ int64 tuple_index_ = -1;
+};
+
+class HloReducePrecisionInstruction : public HloInstruction {
+ public:
+ explicit HloReducePrecisionInstruction(const Shape& shape,
+ HloInstruction* operand,
+ const int exponent_bits,
+ const int mantissa_bits);
+ // Returns the number of exponent bits for a reduce-precision node.
+ int32 exponent_bits() const { return exponent_bits_; }
+ // Returns the number of mantissa bits for a reduce-precision node.
+ int32 mantissa_bits() const { return mantissa_bits_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The bit sizes for a reduce-precision operation.
+ int32 exponent_bits_ = 0;
+ int32 mantissa_bits_ = 0;
+};
+
+class HloInfeedInstruction : public HloInstruction {
+ public:
+ explicit HloInfeedInstruction(const Shape& infeed_shape,
+ HloInstruction* token_operand,
+ const string& config);
+ // TODO(b/80000000): Remove this constructor when all uses of infeed are
+ // converted to take tokens.
+ explicit HloInfeedInstruction(const Shape& infeed_shape,
+ const string& config);
+ // Returns the infeed configuration string. The infeed configuration includes
+ // any metadata needed for the backend compiler (e.g., infeed buffer address)
+ // and is target-dependent.
+ string infeed_config() const { return infeed_config_; }
+ void set_infeed_config(const string& config) { infeed_config_ = config; }
+ // Returns the shape of the data received by the infeed. This is not the same
+ // as the shape of the infeed instruction which produces a tuple containing
+ // the infeed data shape and a TOKEN.
+ const Shape& infeed_shape() const {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
+ return ShapeUtil::GetSubshape(shape(), {0});
+ }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The string representation of the infeed configuration.
+ string infeed_config_;
+};
+
+class HloOutfeedInstruction : public HloInstruction {
+ public:
+ explicit HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ HloInstruction* token_operand,
+ tensorflow::StringPiece outfeed_config);
+ // TODO(b/80000000): Remove this constructor when all uses of outfeed are
+ // converted to take tokens.
+ explicit HloOutfeedInstruction(const Shape& outfeed_shape,
+ HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config);
+
+ // Returns the shape for the Outfeed instruction.
+ const Shape& outfeed_shape() const {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
+ return outfeed_shape_;
+ }
+ // Returns the config for the Outfeed instruction.
+ const string& outfeed_config() const { return outfeed_config_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Shape of outfeed request.
+ Shape outfeed_shape_;
+ // Outfeed configuration information, only present for kOutfeed.
+ string outfeed_config_;
+};
+
+class HloConvolutionInstruction : public HloInstruction {
+ public:
+ explicit HloConvolutionInstruction(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+ const Window& window() const override { return window_; }
+ void set_window(const Window& window) override { window_ = window; }
+ const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
+ return convolution_dimension_numbers_;
+ }
+ void set_convolution_dimension_numbers(
+ const ConvolutionDimensionNumbers& dnums) {
+ convolution_dimension_numbers_ = dnums;
+ }
+ string ToCategory() const override;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ Window window_;
+ // Describes the dimension numbers used for a convolution.
+ ConvolutionDimensionNumbers convolution_dimension_numbers_;
+};
+
+class HloReduceWindowInstruction : public HloInstruction {
+ public:
+ explicit HloReduceWindowInstruction(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* init_value,
+ const Window& window,
+ HloComputation* reduce_computation);
+ const Window& window() const override { return window_; }
+ void set_window(const Window& window) override { window_ = window; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ Window window_;
+};
+
+class HloSelectAndScatterInstruction : public HloInstruction {
+ public:
+ explicit HloSelectAndScatterInstruction(
+ const Shape& shape, HloInstruction* operand, HloComputation* select,
+ const Window& window, HloInstruction* source, HloInstruction* init_value,
+ HloComputation* scatter);
+ const Window& window() const override { return window_; }
+ void set_window(const Window& window) override { window_ = window; }
+ // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
+ // setters should only be called by HloModule or HloComputation methods.
+ HloComputation* select() const {
+ return called_computations()[kSelectComputationIndex];
+ }
+
+ HloComputation* scatter() const {
+ return called_computations()[kScatterComputationIndex];
+ }
+
+ void set_select(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
+ set_called_computation(kSelectComputationIndex, computation);
+ }
+
+ void set_scatter(HloComputation* computation) {
+ // Don't allow changing the computation for fused instructions so we don't
+ // have to recompute called_instructions for the entire fusion instruction.
+ CHECK(!IsFused());
+ set_called_computation(kScatterComputationIndex, computation);
+ }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ Window window_;
+};
+
+class HloCustomCallInstruction : public HloInstruction {
+ public:
+ explicit HloCustomCallInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target);
+ const Window& window() const override {
+ CHECK(window_ != nullptr);
+ return *window_;
+ }
+
+ void set_window(const Window& window) override {
+ window_ = MakeUnique<Window>(window);
+ }
+
+ const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
+ CHECK(convolution_dimension_numbers_ != nullptr);
+ return *convolution_dimension_numbers_;
+ }
+
+ void set_convolution_dimension_numbers(
+ const ConvolutionDimensionNumbers& dnums) {
+ convolution_dimension_numbers_ =
+ MakeUnique<ConvolutionDimensionNumbers>(dnums);
+ }
+ const string& custom_call_target() const { return custom_call_target_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ // Name of a global symbol to call, only present for kCustomCall.
+ string custom_call_target_;
+ // Describes the window in a windowed operation such as convolution.
+ std::unique_ptr<Window> window_;
+ // Describes the dimension numbers used for a convolution.
+ std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+};
+
+class HloHostComputeInstruction : public HloInstruction {
+ public:
+ explicit HloHostComputeInstruction(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
+ // Returns the channel name associated with the instruction. The name is
+ // used to identify host Send/Recv operations.
+ const string& channel_name() const { return channel_name_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+ // Name to use for host send/recv channels.
+ string channel_name_;
+ // Estimate of the duration of a host computation in nanoseconds.
+ int64 cost_estimate_ns_ = 0;
+};
+
+class HloPadInstruction : public HloInstruction {
+ public:
+ explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
+ HloInstruction* padding_value,
+ const PaddingConfig& padding_config);
+ // Returns the padding configuration for a pad node.
+ const PaddingConfig& padding_config() const { return padding_config_; }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // The padding configuration that describes the edge padding and interior
+ // padding of this pad instruction.
+ PaddingConfig padding_config_;
+};
+
+class HloDynamicSliceInstruction : public HloInstruction {
+ public:
+ explicit HloDynamicSliceInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+ // Old methods kept for smooth subclassing transition END.
+ // Returns the size of the slice in the given dimension for a dynamic
+ // slice node.
+ int64 slice_sizes(int64 dimension) const {
+ return dynamic_slice_sizes_[dimension];
+ }
+ const std::vector<int64>& dynamic_slice_sizes() const {
+ return dynamic_slice_sizes_;
+ }
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ // Describes the [start, start + size) range size for a dynamic slice
+ // ('start' is specified dynamically in the second operand of the operation).
+ std::vector<int64> dynamic_slice_sizes_;
+};
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 350db12653..f0d9fdbc8f 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include <unordered_map>
@@ -26,9 +26,8 @@ limitations under the License.
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-namespace tools {
-using tensorflow::StringPiece;
+using ::tensorflow::StringPiece;
namespace {
@@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
-StringPiece HloLexer::StringPieceFromPointers(const char* begin,
- const char* end) const {
+tensorflow::StringPiece HloLexer::StringPieceFromPointers(
+ const char* begin, const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
- return StringPiece(begin, end - begin);
+ return tensorflow::StringPiece(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
@@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kAttributeName;
}
- StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);
+ tensorflow::StringPiece identifier =
+ StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
#define KEYWORD(STR) \
@@ -332,23 +332,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no_cache_.last_query = ptr;
line_no_cache_.line_no_of_query = line_no;
size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
- if (line_offset == StringPiece::npos) {
+ if (line_offset == tensorflow::StringPiece::npos) {
line_offset = 0;
}
return {line_no, ptr - start - line_offset};
}
-StringPiece HloLexer::GetLine(LocTy loc) const {
+tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
if (!CanDereference(loc)) {
return "LINE OUT OF RANGE";
}
size_t line_start =
StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
- const char* start = line_start == StringPiece::npos
+ const char* start = line_start == tensorflow::StringPiece::npos
? buf_.begin()
: buf_.begin() + line_start + 1;
size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
- const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end;
+ const char* end =
+ line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end;
return StringPieceFromPointers(start, end);
}
@@ -370,7 +371,7 @@ TokKind HloLexer::LexString() {
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
- StringPiece raw =
+ tensorflow::StringPiece raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
@@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) {
}
}
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index 27880b9b8a..ceb674f25e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
#include <string>
-#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
+#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -27,9 +27,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace tools {
// Lexer for the HloModule::ToString() format text.
+//
+// This class is meant to be used by hlo_parser.cc. You shouldn't need to use
+// it directly.
class HloLexer {
public:
explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
@@ -57,7 +59,7 @@ class HloLexer {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
- int64 GetInt64Val() const {
+ tensorflow::int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
}
@@ -114,7 +116,7 @@ class HloLexer {
TokKind current_kind_;
string str_val_;
Shape shape_val_;
- int64 int64_val_;
+ tensorflow::int64 int64_val_;
double decimal_val_;
struct LineNoCacheTy {
@@ -125,7 +127,6 @@ class HloLexer {
mutable LineNoCacheTy line_no_cache_{nullptr, 0};
};
-} // namespace tools
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 8e2e2c7627..01b625c29c 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -15,15 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -59,7 +59,7 @@ class HloLivenessAnalysisTest : public HloTestBase {
// Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -75,7 +75,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
// Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest, DeadAdd) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -94,7 +94,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) {
// Test that all output shape indices of entry root tuple (and defining
// instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -113,7 +113,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
// Tests that all outputs of nested tuple and entry root (and defining
// instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(1)
@@ -140,7 +140,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
// Tests that GTE at entry root of Tuple instruction only propgates liveness
// to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -162,7 +162,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
// Tests that GTE at entry root of nested Tuple instruction only propgates
// liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -199,7 +199,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
// propgates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -240,7 +240,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
// Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -291,7 +291,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
// Tests that a tuple element live in while.cond computation, propagates
// liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -345,7 +345,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
// Tests that a use of while.result{0} propagates liveness to
// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[], s32[]) parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index dfefad3634..b57c940238 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -187,6 +187,7 @@ HLO_MATCHER(Exp);
HLO_MATCHER(Floor);
HLO_MATCHER(Fusion);
HLO_MATCHER(Ge);
+HLO_MATCHER(AfterAll);
HLO_MATCHER(Gt);
HLO_MATCHER(Infeed);
HLO_MATCHER(IsFinite);
@@ -195,6 +196,7 @@ HLO_MATCHER(Log);
HLO_MATCHER(And);
HLO_MATCHER(Not);
HLO_MATCHER(Or);
+HLO_MATCHER(Xor);
HLO_MATCHER(Lt);
HLO_MATCHER(Map);
HLO_MATCHER(Maximum);
@@ -329,7 +331,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
tensorflow::StringPiece sharding) {
return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
- xla::tools::ParseSharding(sharding).ValueOrDie()));
+ ParseSharding(sharding).ValueOrDie()));
}
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 1d10e3c4fe..7de59acc1e 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace op = xla::testing::opcode_matchers;
@@ -74,8 +76,10 @@ TEST(HloMatchersTest, Test) {
}
TEST(HloMatchersTest, CustomCallMatcher) {
- auto c1 = HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3}));
- auto c2 = HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3}));
+ auto c1 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1, 2, 3}));
+ auto c2 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3}));
auto call = HloInstruction::CreateCustomCall(
ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target");
@@ -194,7 +198,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index e63424c2df..55ff073d3f 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -32,15 +32,6 @@ limitations under the License.
namespace xla {
-HloModule::HloModule(const string& name,
- const VersionedComputationHandle& entry_computation_handle,
- const HloModuleConfig& config)
- : name_(NameUniquer::GetSanitizedName(name)),
- config_(config),
- has_entry_computation_handle_(true),
- entry_computation_handle_(entry_computation_handle),
- unique_id_(next_unique_module_id_++) {}
-
HloModule::HloModule(const string& name, const HloModuleConfig& config)
: name_(NameUniquer::GetSanitizedName(name)),
config_(config),
@@ -67,7 +58,7 @@ HloComputation* HloModule::AddComputationInternal(
// If the module configuration has no entry layout computation set, create a
// default one based on the program shape.
- if (!config_.has_host_entry_computation_layout()) {
+ if (!config_.has_entry_computation_layout()) {
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
@@ -234,21 +225,17 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
- const HloModuleProto& proto, const HloModuleConfig& module_config,
- const VersionedComputationHandle& entry_computation_handle) {
+ const HloModuleProto& proto, const HloModuleConfig& module_config) {
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
<< "No program shape found in the proto";
const auto& expected_program_shape = proto.program_shape();
- TF_RET_CHECK(
- expected_program_shape.parameters_size() ==
- module_config.device_entry_computation_layout().parameter_count());
+ TF_RET_CHECK(expected_program_shape.parameters_size() ==
+ module_config.entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
- module_config.device_entry_computation_layout()
- .parameter_layout(i)
- .shape();
+ module_config.entry_computation_layout().parameter_layout(i).shape();
TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
@@ -258,7 +245,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
}
const Shape& result_shape =
- module_config.device_entry_computation_layout().result_layout().shape();
+ module_config.entry_computation_layout().result_layout().shape();
TF_RET_CHECK(
ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
@@ -287,8 +274,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(entry != nullptr);
- auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
- module_config);
+ auto module = MakeUnique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
@@ -338,7 +324,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
ComputationLayout* entry_layout =
- module_config.mutable_host_entry_computation_layout();
+ module_config.mutable_entry_computation_layout();
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
TF_RETURN_IF_ERROR(
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
@@ -346,9 +332,6 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
}
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
program_shape.result()));
- *module_config.mutable_device_entry_computation_layout() =
- module_config.host_entry_computation_layout();
-
return module_config;
}
@@ -401,7 +384,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
// as a parameter in the new function.
arguments.push_back(old_operand);
*operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
- parameter_count, old_operand->shape(), ""));
+ parameter_count, old_operand->shape(), "p"));
++parameter_count;
}
TF_CHECK_OK(
@@ -462,7 +445,7 @@ int64 HloModule::instruction_count() const {
return n;
}
-std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
+std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// First determine all root computations by building a set of nonroot
// computations (computations which are called by an instruction in the
// module).
@@ -480,7 +463,7 @@ std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
// order. This prevents duplication as an embedded computation may be called
// from two different root computations.
std::set<HloComputation*> added_computations;
- std::list<HloComputation*> post_order;
+ std::vector<HloComputation*> post_order;
for (auto& computation : computations_) {
if (nonroot_computations.count(computation.get()) == 0) {
for (HloComputation* embedded_computation :
@@ -525,8 +508,6 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_);
- module->entry_computation_handle_ = entry_computation_handle_;
- module->has_entry_computation_handle_ = has_entry_computation_handle_;
HloCloneContext context(module.get(), suffix);
auto cloned_computation = entry_computation_->Clone(suffix, &context);
@@ -556,10 +537,11 @@ uint64 HloModule::RandomNew64() const {
HloComputation* HloModule::GetComputationWithName(
tensorflow::StringPiece name) {
- auto it = c_find_if(computations(), [&](HloComputation* computation) {
+ auto computations_in_module = computations();
+ auto it = c_find_if(computations_in_module, [&](HloComputation* computation) {
return computation->name() == name;
});
- return it == computations().end() ? nullptr : *it;
+ return it == computations_in_module.end() ? nullptr : *it;
}
/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index c93c74d34a..d2e726a0db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -57,10 +56,6 @@ namespace xla {
// attached to.
class HloModule {
public:
- HloModule(const string& name,
- const VersionedComputationHandle& entry_computation_handle,
- const HloModuleConfig& config);
-
// Constructor without a versioned computation handle. This constructor should
// only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation
@@ -110,24 +105,19 @@ class HloModule {
return entry_computation_;
}
- ComputationLayout* mutable_host_entry_computation_layout() {
- return config_.mutable_host_entry_computation_layout();
- }
-
- const ComputationLayout& host_entry_computation_layout() const {
- return config_.host_entry_computation_layout();
+ // Creates the ComputationLayout which describes the current status of the HLO
+ // module entry computation.
+ ComputationLayout compute_computation_layout() const {
+ return ComputationLayout(entry_computation()->ComputeProgramShape(),
+ /*ignore_layouts=*/false);
}
- ComputationLayout* mutable_device_entry_computation_layout() {
- return config_.mutable_device_entry_computation_layout();
+ ComputationLayout* mutable_entry_computation_layout() {
+ return config_.mutable_entry_computation_layout();
}
- const ComputationLayout& device_entry_computation_layout() const {
- return config_.device_entry_computation_layout();
- }
-
- const VersionedComputationHandle& entry_computation_handle() const {
- return entry_computation_handle_;
+ const ComputationLayout& entry_computation_layout() const {
+ return config_.entry_computation_layout();
}
// Gets the computations in this module.
@@ -163,7 +153,7 @@ class HloModule {
// Compute and return a post order of all computations in the module. The sort
// is defined like so: if computation A has an instruction which calls
// computation B, then A will appear after B in the sort.
- std::list<HloComputation*> MakeComputationPostOrder() const;
+ std::vector<HloComputation*> MakeComputationPostOrder() const;
// Gets the computations in this module which aren't for fusion nodes.
//
@@ -188,9 +178,7 @@ class HloModule {
// Convert an HloModule to or from a proto.
HloModuleProto ToProto() const;
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
- const HloModuleProto& proto, const HloModuleConfig& module_config,
- const VersionedComputationHandle& entry_computation_handle =
- VersionedComputationHandle());
+ const HloModuleProto& proto, const HloModuleConfig& module_config);
// Creates and returns an HloModuleConfig with an appropriate program shape
// for the HLO module in the given proto.
@@ -264,10 +252,6 @@ class HloModule {
mutable std::mt19937_64 rng_{42};
mutable tensorflow::mutex rng_mutex_;
- // Versioned handle of the entry computation of the module.
- bool has_entry_computation_handle_ = false;
- VersionedComputationHandle entry_computation_handle_;
-
// Unique name generator for computation and instruction names, which are
// unique per module.
NameUniquer computation_name_uniquer_{/*separator=*/"."};
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index dae5578a31..07a8c798db 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -28,16 +28,14 @@ namespace xla {
using tensorflow::strings::StrAppend;
-HloModuleConfig::HloModuleConfig() {}
-
-HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
- : host_entry_computation_layout_(program_shape),
- device_entry_computation_layout_(program_shape) {}
+HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
+ bool ignore_layouts)
+ : entry_computation_layout_(
+ ComputationLayout(program_shape, ignore_layouts)) {}
void HloModuleConfig::SetDefaultComputationLayout(
const ProgramShape& program_shape) {
- host_entry_computation_layout_ = ComputationLayout(program_shape);
- device_entry_computation_layout_ = ComputationLayout(program_shape);
+ entry_computation_layout_ = ComputationLayout(program_shape);
}
string HloModuleConfig::compilation_cache_key() const {
@@ -46,18 +44,11 @@ string HloModuleConfig::compilation_cache_key() const {
StrAppend(&key, "::(");
std::vector<string> params;
for (const ShapeLayout& param_layout :
- host_entry_computation_layout_->parameter_layouts()) {
+ entry_computation_layout_->parameter_layouts()) {
params.push_back(param_layout.shape().DebugString());
}
StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ",
- host_entry_computation_layout_->result_shape().SerializeAsString());
- for (const ShapeLayout& param_layout :
- device_entry_computation_layout_->parameter_layouts()) {
- params.push_back(param_layout.shape().DebugString());
- }
- StrAppend(
- &key, tensorflow::str_util::Join(params, ", "), ") => ",
- device_entry_computation_layout_->result_shape().SerializeAsString());
+ entry_computation_layout_->result_shape().SerializeAsString());
if (seed() != 0) {
// TODO(b/32083678): force recompilation to reset global state.
static std::atomic<int> counter{0};
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index cdb0b29a23..074e9c9070 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -37,48 +37,34 @@ class HloModuleConfig {
// ComputationLayout. The default ctor creates it without -- in this case
// accessing entry_computation_layout will CHECK-fail. The ctor accepting a
// ProgramShape creates a computation layout using this shape.
- HloModuleConfig();
- explicit HloModuleConfig(const ProgramShape& program_shape);
+ // The layouts in the ProgramShape will be reset to default unless
+ // ignore_layouts is set to false.
+ HloModuleConfig() = default;
- // Checks if this config has an entry computation layout already.
- bool has_host_entry_computation_layout() const {
- return host_entry_computation_layout_.has_value();
- }
+ explicit HloModuleConfig(const ProgramShape& program_shape,
+ bool ignore_layouts = true);
- bool has_device_entry_computation_layout() const {
- return device_entry_computation_layout_.has_value();
+ // Checks if this config has an entry computation layout already.
+ bool has_entry_computation_layout() const {
+ return entry_computation_layout_.has_value();
}
// Sets the entry computation layout for this config. If the entry computation
// layout already exists, it is silently replaced.
void SetDefaultComputationLayout(const ProgramShape& program_shape);
- // Returns a constant reference to the on-host layout of the entry
- // computation. Assumes the layout was set.
- const ComputationLayout& host_entry_computation_layout() const {
- CHECK(host_entry_computation_layout_.has_value());
- return *host_entry_computation_layout_;
- }
-
- // Returns a mutable pointer to the layout of the on-host entry computation.
+ // Returns a constant reference to the layout of the entry computation.
// Assumes the layout was set.
- ComputationLayout* mutable_host_entry_computation_layout() {
- CHECK(host_entry_computation_layout_.has_value());
- return &(*host_entry_computation_layout_);
- }
-
- // Returns a constant reference to the on-device layout of the entry
- // computation. Assumes the layout was set.
- const ComputationLayout& device_entry_computation_layout() const {
- CHECK(device_entry_computation_layout_.has_value());
- return *device_entry_computation_layout_;
+ const ComputationLayout& entry_computation_layout() const {
+ CHECK(entry_computation_layout_.has_value());
+ return *entry_computation_layout_;
}
- // Returns a mutable pointer to the layout of the on-device entry computation.
+ // Returns a mutable pointer to the layout of the entry computation.
// Assumes the layout was set.
- ComputationLayout* mutable_device_entry_computation_layout() {
- CHECK(device_entry_computation_layout_.has_value());
- return &(*device_entry_computation_layout_);
+ ComputationLayout* mutable_entry_computation_layout() {
+ CHECK(entry_computation_layout_.has_value());
+ return &(*entry_computation_layout_);
}
// Returns whether to enable HLO-level profiling.
@@ -127,8 +113,7 @@ class HloModuleConfig {
private:
// If you add new members, be sure to update compilation_cache_key.
- tensorflow::gtl::optional<ComputationLayout> host_entry_computation_layout_;
- tensorflow::gtl::optional<ComputationLayout> device_entry_computation_layout_;
+ tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_;
// Whether this is a 'host module'.
bool is_host_module_ = false;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 53b7d0ed39..363862e490 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
@@ -73,7 +73,7 @@ class HloModuleDceTest : public HloTestBase {
// Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -110,7 +110,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
// Tests a while loop with one unused output (which is used in the while loop
// body by an instruction with side-effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], f32[]) parameter(0)
@@ -150,7 +150,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
// Tests that a while loop with one dead tuple element at {1} has its while
// loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -193,7 +193,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
// dead in while.body{1} and at while.result{1}) propgates liveness of this
// tuple element to while.body{1} and at while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[]) parameter(0)
@@ -235,7 +235,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
// Tests that HloModuleDCE can remove a dead tuple element at index {1} between
// two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -303,7 +303,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
// while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index f6fa45a6b7..6bcd7b042d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -113,6 +113,9 @@ Status HloModuleGroupMetadata::Build() {
}
}
TF_RETURN_IF_ERROR(VerifyCompanionSets());
+ if (VLOG_IS_ON(4)) {
+ DumpCollectedStats();
+ }
return Status::OK();
}
@@ -124,9 +127,14 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const {
for (HloInstruction* instruction : *companions) {
// Go through all the communicating instructions (send, recv) of the given
// companion, and record their device.
+ auto it = tracked_instructions_comms_.find(instruction);
+ if (it == tracked_instructions_comms_.end()) {
+ // Companions can be added even if they have no communicating
+ // instructions, if they are parent of companions.
+ continue;
+ }
std::unordered_set<int64> comm_devices;
- for (HloInstruction* comm_instruction :
- tracked_instructions_comms_.at(instruction)) {
+ for (HloInstruction* comm_instruction : it->second) {
auto device = GetInstructionDevice(*comm_instruction);
TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
<< " does not have a device";
@@ -315,6 +323,7 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TF_RETURN_IF_ERROR(computation->Accept(visitor));
}
}
+ VLOG(2) << "Created " << channels_.size() << " channels";
return Status::OK();
}
@@ -373,7 +382,8 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
// Check if the shapes match for each channel.
for (const Channel& channel : channels_) {
const Shape& send_shape = channel.send->operand(0)->shape();
- const Shape& recv_shape = channel.recv_done->shape();
+ const Shape& recv_shape =
+ ShapeUtil::GetTupleElementShape(channel.recv_done->shape(), 0);
if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
return FailedPrecondition("send/recv shapes do not match");
}
@@ -445,4 +455,36 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
return FailedPrecondition("channel is used in disallowed computation");
}
+void HloModuleGroupMetadata::DumpCollectedStats() const {
+ std::map<std::pair<int64, int64>, int64> communication_histogram;
+ for (auto& channel : channels_) {
+ auto from_device = GetInstructionDevice(*channel.send);
+ auto to_device = GetInstructionDevice(*channel.recv);
+ LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device
+ << " to_device=" << *to_device << " send=" << channel.send->name()
+ << " send_done=" << channel.send_done->name()
+ << " recv=" << channel.recv->name()
+ << " recv_done=" << channel.recv_done->name();
+ communication_histogram[std::pair<int64, int64>(*from_device,
+ *to_device)] += 1;
+ }
+ for (auto& fromto_count : communication_histogram) {
+ LOG(INFO) << "From " << fromto_count.first.first << " to "
+ << fromto_count.first.second << ": " << fromto_count.second;
+ }
+ for (auto& companion_set : companion_sets_) {
+ LOG(INFO) << "Companion set:";
+ for (HloInstruction* instruction : *companion_set) {
+ LOG(INFO) << " " << instruction->name();
+ }
+ }
+ for (auto& instruction_comm : tracked_instructions_comms_) {
+ LOG(INFO) << "Communicating instruction " << instruction_comm.first->name();
+ for (HloInstruction* instruction : instruction_comm.second) {
+ auto device = GetInstructionDevice(*instruction);
+ LOG(INFO) << " " << instruction->name() << " on device " << *device;
+ }
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index f68d4028dc..ffde3a332d 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -230,6 +230,9 @@ class HloModuleGroupMetadata {
return it != tracked_instructions_.end() ? &it->second : nullptr;
}
+ // Dump all the collected module group statistics to the logs.
+ void DumpCollectedStats() const;
+
// List of all companion instructions sets in the module.
std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
companion_sets_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 5a0d1e264e..df1d562048 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -277,7 +277,7 @@ Status HloModuleGroupUtil::VerifyComputations(
StatusOr<std::unique_ptr<HloReachabilityMap>>
HloModuleGroupUtil::ComputeReachability(
tensorflow::gtl::ArraySlice<HloComputation*> computations) {
- std::list<HloInstruction*> post_order;
+ std::vector<HloInstruction*> post_order;
auto visit_function =
[&](HloInstruction* instruction,
const std::vector<HloInstruction*>& instruction_group) {
@@ -292,7 +292,7 @@ HloModuleGroupUtil::ComputeReachability(
}
auto reachability = MakeUnique<HloReachabilityMap>(post_order);
for (HloInstruction* hlo : post_order) {
- reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
+ reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo);
}
return std::move(reachability);
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 7f28a804bf..236f450086 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase {
std::unique_ptr<HloComputation> CreateConstantComputation() {
auto builder = HloComputation::Builder("Constant");
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
return builder.Build();
}
@@ -122,7 +122,7 @@ TEST_F(HloModuleTest, CloneHasFusion) {
{
auto b = HloComputation::Builder("Entry");
auto input = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
b.AddInstruction(
HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
/*operands=*/{input}, fused_computation));
@@ -173,7 +173,7 @@ TEST_F(HloModuleTest, LargeConstantToString) {
auto builder = HloComputation::Builder("Constant");
std::vector<float> values(16, 42.0);
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(values)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
module->AddEntryComputation(builder.Build());
EXPECT_EQ(
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 1fe06ee0c0..39e12c4815 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -81,6 +81,7 @@ namespace xla {
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
+ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kHostCompute, "host-compute") \
@@ -93,6 +94,7 @@ namespace xla {
V(kAnd, "and") \
V(kNot, "not") \
V(kOr, "or") \
+ V(kXor, "xor") \
V(kLt, "less-than", kHloOpcodeIsComparison) \
V(kMap, "map", kHloOpcodeIsVariadic) \
V(kMaximum, "maximum") \
@@ -131,6 +133,7 @@ namespace xla {
V(kTrace, "trace") \
V(kTranspose, "transpose") \
V(kTuple, "tuple", kHloOpcodeIsVariadic) \
+ V(kTupleSelect, "tuple-select") \
V(kWhile, "while")
enum class HloOpcode {
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index cd2ce5c69f..6f3f83f63a 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTuple:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index dcd4725fe7..6c1e015f77 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition(
<< " and def is in FALSE computation";
return true;
}
+ if (value.defining_instruction() == use.instruction) {
+ VLOG(4) << " use is conditional " << use << " and def is "
+ << value.ToShortString();
+ return true;
+ }
}
VLOG(4) << " use is not before value";
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index ee526d8dd7..985f3fa64d 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
// interference is reduced relative to DependencyHloOrdering.
class SequentialHloOrdering : public HloOrdering {
public:
+ // TODO(dimvar): HloModuleSequence is not a good name because it sounds like
+ // a sequence of modules, instead of a map of schedules for all computations
+ // in a module. We should change it at some point.
+ //
// A sequence of instructions for each computation in the module.
using HloModuleSequence =
tensorflow::gtl::FlatMap<const HloComputation*,
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 37a7fbad97..126d3a2d9c 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -57,7 +57,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) {
auto builder_c = HloComputation::Builder("C");
HloInstruction* c = builder_c.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
HloComputation* computation_c =
module->AddEmbeddedComputation(builder_c.Build());
@@ -145,7 +145,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
module->AddEntryComputation(builder.Build());
@@ -208,7 +208,7 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -310,7 +310,7 @@ ENTRY while.v11 {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
DependencyHloOrdering ordering(module.get());
ordering.ToString(); // Shouldn't crash.
}
@@ -347,7 +347,7 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ef10ca4bff..f162d52d3c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -13,10 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -24,18 +27,17 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
-namespace tools {
namespace {
-using tensorflow::StringPiece;
-using tensorflow::gtl::optional;
-using tensorflow::str_util::Join;
-using tensorflow::str_util::Split;
-using tensorflow::str_util::SplitAndParseAsInts;
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using ::tensorflow::StringPiece;
+using ::tensorflow::gtl::optional;
+using ::tensorflow::str_util::Join;
+using ::tensorflow::str_util::Split;
+using ::tensorflow::str_util::SplitAndParseAsInts;
+using ::tensorflow::strings::Printf;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
const double kF16max = 65504;
@@ -83,11 +85,15 @@ class HloParser {
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
- bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal);
- bool SetValueInLiteral(double value, int64 linear_index, Literal* literal);
- bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal);
+ bool SetValueInLiteral(tensorflow::int64 value,
+ tensorflow::int64 linear_index, Literal* literal);
+ bool SetValueInLiteral(double value, tensorflow::int64 linear_index,
+ Literal* literal);
+ bool SetValueInLiteral(bool value, tensorflow::int64 linear_index,
+ Literal* literal);
template <typename LiteralNativeT, typename ParsedElemT>
- bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
+ bool SetValueInLiteralHelper(ParsedElemT value,
+ tensorflow::int64 linear_index,
Literal* literal);
bool ParseOperands(std::vector<HloInstruction*>* operands);
@@ -99,9 +105,15 @@ class HloParser {
// Describes the start, limit, and stride on every dimension of the operand
// being sliced.
struct SliceRanges {
- std::vector<int64> starts;
- std::vector<int64> limits;
- std::vector<int64> strides;
+ std::vector<tensorflow::int64> starts;
+ std::vector<tensorflow::int64> limits;
+ std::vector<tensorflow::int64> strides;
+ };
+
+ // The data parsed for the kDomain instruction.
+ struct DomainData {
+ std::unique_ptr<DomainMetadata> entry_metadata;
+ std::unique_ptr<DomainMetadata> exit_metadata;
};
// Types of attributes.
@@ -122,6 +134,7 @@ class HloParser {
kMetadata,
kFusionKind,
kDistribution,
+ kDomain,
};
struct AttrConfig {
@@ -178,14 +191,18 @@ class HloParser {
bool ParseSharding(OpSharding* sharding);
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
+ // Parses the metadata behind a kDOmain instruction.
+ bool ParseDomain(DomainData* domain);
+
// Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
- bool ParseDxD(const string& name, std::vector<int64>* result);
+ bool ParseDxD(const string& name, std::vector<tensorflow::int64>* result);
// Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
- bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
+ bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
bool ParseInt64List(const TokKind start, const TokKind end,
- const TokKind delim, std::vector<int64>* result);
+ const TokKind delim,
+ std::vector<tensorflow::int64>* result);
bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
@@ -197,7 +214,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
- bool ParseInt64(int64* result);
+ bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
@@ -311,22 +328,15 @@ bool HloParser::ParseComputations() {
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
- ->mutable_parameter_layout(p)
- ->CopyLayoutFromShape(param_shape));
- TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->CopyLayoutFromShape(result_shape));
- TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
-
return true;
}
@@ -455,7 +465,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
- int64 parameter_number;
+ tensorflow::int64 parameter_number;
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number) ||
@@ -488,7 +498,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kClz:
case HloOpcode::kCopy:
case HloOpcode::kCos:
- case HloOpcode::kDomain:
case HloOpcode::kExp:
case HloOpcode::kExpm1:
case HloOpcode::kImag:
@@ -501,7 +510,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kReal:
case HloOpcode::kSign:
case HloOpcode::kSin:
- case HloOpcode::kSort:
case HloOpcode::kTanh: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -530,6 +538,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
@@ -543,7 +552,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
// Ternary ops.
case HloOpcode::kClamp:
- case HloOpcode::kSelect: {
+ case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect: {
if (!ParseOperands(&operands, /*expected_size=*/3) ||
!ParseAttributes(attrs)) {
return false;
@@ -572,11 +582,31 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
+ optional<HloComputation*> to_apply;
+ optional<std::vector<int64>> replica_group_ids;
+ optional<string> barrier;
+ optional<int64> all_reduce_id;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &to_apply};
+ attrs["replica_group_ids"] = {
+ /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+ attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
+ attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64,
+ &all_reduce_id};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateCrossReplicaSum(shape, operands));
+ if (replica_group_ids) {
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply, *replica_group_ids,
+ barrier ? *barrier : "", all_reduce_id));
+ } else {
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply, {}, barrier ? *barrier : "",
+ all_reduce_id));
+ }
break;
}
case HloOpcode::kReshape: {
@@ -588,6 +618,44 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateReshape(shape, operands[0]));
break;
}
+ case HloOpcode::kAfterAll: {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ if (operands.empty()) {
+ instruction = builder->AddInstruction(HloInstruction::CreateToken());
+ } else {
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateAfterAll(operands));
+ }
+ break;
+ }
+ case HloOpcode::kSort: {
+ auto loc = lexer_.GetLoc();
+
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+ &dimensions};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
+ dimensions->size() != 1) {
+ return false;
+ }
+ switch (operands.size()) {
+ case 1:
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0), /*keys=*/operands[0]));
+ break;
+ case 2:
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0),
+ /*keys=*/operands[0], /*values=*/operands[1]));
+ break;
+ default:
+ return Error(loc, StrCat("expects either 1 or 2 operands, but has ",
+ operands.size(), " operands"));
+ }
+ break;
+ }
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
@@ -611,18 +679,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kRecv: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/0) ||
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
+ instruction = builder->AddInstruction(HloInstruction::CreateRecv(
+ shape.tuple_shapes(0), operands[0], *channel_id));
break;
}
case HloOpcode::kRecvDone: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -636,18 +704,18 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kSend: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], *channel_id));
+ HloInstruction::CreateSend(operands[0], operands[1], *channel_id));
break;
}
case HloOpcode::kSendDone: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -661,7 +729,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGetTupleElement: {
- optional<int64> index;
+ optional<tensorflow::int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -719,7 +787,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kFft: {
optional<FftType> fft_type;
- optional<std::vector<int64>> fft_length;
+ optional<std::vector<tensorflow::int64>> fft_length;
attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
&fft_length};
@@ -732,7 +800,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kBroadcast: {
- optional<std::vector<int64>> broadcast_dimensions;
+ optional<std::vector<tensorflow::int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&broadcast_dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -744,7 +812,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConcatenate: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
@@ -759,6 +827,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
+ &dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
@@ -770,7 +841,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
- optional<std::vector<int64>> dimensions_to_reduce;
+ optional<std::vector<tensorflow::int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
@@ -783,7 +854,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReverse: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -827,7 +898,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kDynamicSlice: {
- optional<std::vector<int64>> dynamic_slice_sizes;
+ optional<std::vector<tensorflow::int64>> dynamic_slice_sizes;
attrs["dynamic_slice_sizes"] = {
/*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
@@ -851,7 +922,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kTranspose: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -865,7 +936,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormTraining: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
@@ -881,7 +952,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormInference: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
@@ -898,7 +969,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormGrad: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
@@ -938,23 +1009,53 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kInfeed: {
optional<string> config;
attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands, /*expected_size=*/0) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateInfeed(shape, config ? *config : ""));
+ // We need to know the infeed data shape to construct the infeed
+ // instruction. This is the zero-th element of the tuple-shaped output of
+ // the infeed instruction. ShapeUtil::GetTupleElementShape will check fail
+ // if the shape is not a non-empty tuple, so add guard so an error message
+ // can be emitted instead of a check fail
+ if (!ShapeUtil::IsTuple(shape) && !ShapeUtil::IsEmptyTuple(shape)) {
+ return Error(lexer_.GetLoc(),
+ "infeed must have a non-empty tuple shape");
+ }
+
+ if (operands.empty()) {
+ // TODO(b/80000000): Remove this when all uses of infeed are
+ // converted to take tokens.
+ instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::GetTupleElementShape(shape, 0), config ? *config : ""));
+ } else if (operands.size() == 1) {
+ instruction = builder->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::GetTupleElementShape(shape, 0), operands[0],
+ config ? *config : ""));
+ } else {
+ return Error(lexer_.GetLoc(),
+ "infeed must have exactly zero or one operands");
+ }
break;
}
case HloOpcode::kOutfeed: {
optional<string> config;
attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
- operands[0]->shape(), operands[0], config ? *config : ""));
+ if (operands.size() == 1) {
+ // TODO(b/80000000): Remove this when all uses of outfeed are
+ // converted to take tokens.
+ instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
+ operands[0]->shape(), operands[0], config ? *config : ""));
+ } else if (operands.size() == 2) {
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateOutfeed(operands[0]->shape(), operands[0],
+ operands[1], config ? *config : ""));
+ } else {
+ return Error(lexer_.GetLoc(),
+ "outfeed must have exactly one or two operands");
+ }
break;
}
case HloOpcode::kRng: {
@@ -969,8 +1070,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReducePrecision: {
- optional<int64> exponent_bits;
- optional<int64> mantissa_bits;
+ optional<tensorflow::int64> exponent_bits;
+ optional<tensorflow::int64> mantissa_bits;
attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
&exponent_bits};
attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
@@ -1015,7 +1116,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kHostCompute: {
optional<string> channel_name;
- optional<int64> cost_estimate_ns;
+ optional<tensorflow::int64> cost_estimate_ns;
attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
&channel_name};
attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
@@ -1028,16 +1129,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kDot: {
- optional<std::vector<int64>> lhs_contracting_dims;
+ optional<std::vector<tensorflow::int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
- optional<std::vector<int64>> rhs_contracting_dims;
+ optional<std::vector<tensorflow::int64>> rhs_contracting_dims;
attrs["rhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
- optional<std::vector<int64>> lhs_batch_dims;
+ optional<std::vector<tensorflow::int64>> lhs_batch_dims;
attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&lhs_batch_dims};
- optional<std::vector<int64>> rhs_batch_dims;
+ optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
@@ -1069,20 +1170,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGather: {
- optional<std::vector<int64>> output_window_dims;
+ optional<std::vector<tensorflow::int64>> output_window_dims;
attrs["output_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<int64>> elided_window_dims;
+ optional<std::vector<tensorflow::int64>> elided_window_dims;
attrs["elided_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<int64>> gather_dims_to_operand_dims;
+ optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&gather_dims_to_operand_dims};
- optional<int64> index_vector_dim;
+ optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<int64>> window_bounds;
+ optional<std::vector<tensorflow::int64>> window_bounds;
attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
&window_bounds};
@@ -1102,12 +1203,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
dim_numbers, *window_bounds));
break;
}
+ case HloOpcode::kDomain: {
+ DomainData domain;
+ attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
+ if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateDomain(
+ shape, operands[0], std::move(domain.exit_metadata),
+ std::move(domain.entry_metadata)));
+ break;
+ }
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
- instruction->set_name(name);
+ instruction->SetAndSanitizeName(name);
+ if (instruction->name() != name) {
+ return Error(name_loc,
+ StrCat("illegal instruction name: ", name,
+ "; suggest renaming to: ", instruction->name()));
+ }
// Add shared attributes like metadata to the instruction, if they were seen.
if (sharding) {
@@ -1178,8 +1296,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
LocTy loc = lexer_.GetLoc();
bool maximal = false;
bool replicated = false;
- std::vector<int64> devices;
- std::vector<int64> tile_assignment_dimensions;
+ std::vector<tensorflow::int64> devices;
+ std::vector<tensorflow::int64> tile_assignment_dimensions;
Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
@@ -1206,7 +1324,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
do {
- int64 dim;
+ tensorflow::int64 dim;
if (!ParseInt64(&dim)) {
return false;
}
@@ -1218,7 +1336,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return false;
}
do {
- int64 device;
+ tensorflow::int64 device;
if (!ParseInt64(&device)) {
return false;
}
@@ -1277,10 +1395,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
*sharding->mutable_tile_shape() = tile_shape;
- for (int64 dim : tile_assignment_dimensions) {
+ for (tensorflow::int64 dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
- for (int64 device : devices) {
+ for (tensorflow::int64 device : devices) {
sharding->add_tile_assignment_devices(device);
}
}
@@ -1289,6 +1407,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return true;
}
+// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ','
+// 'exit=' exit_sharding '}'
+bool HloParser::ParseDomain(DomainData* domain) {
+ std::unordered_map<string, AttrConfig> attrs;
+ optional<string> kind;
+ optional<OpSharding> entry_sharding;
+ optional<OpSharding> exit_sharding;
+ attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind};
+ attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding};
+ attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding};
+ if (!ParseSubAttributes(attrs)) {
+ return false;
+ }
+ if (*kind == ShardingMetadata::KindName()) {
+ auto entry_sharding_ptr = MakeUnique<HloSharding>(
+ HloSharding::FromProto(*entry_sharding).ValueOrDie());
+ auto exit_sharding_ptr = MakeUnique<HloSharding>(
+ HloSharding::FromProto(*exit_sharding).ValueOrDie());
+ domain->entry_metadata =
+ MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr));
+ domain->exit_metadata =
+ MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr));
+ } else {
+ return TokenError(StrCat("unsupported domain kind: ", *kind));
+ }
+ return true;
+}
+
// '{' name+ '}'
bool HloParser::ParseInstructionNames(
std::vector<HloInstruction*>* instructions) {
@@ -1315,40 +1461,50 @@ bool HloParser::ParseInstructionNames(
"expects '}' at the end of instruction name list");
}
-bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
+bool HloParser::SetValueInLiteral(tensorflow::int64 value,
+ tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case S8:
- return SetValueInLiteralHelper<int8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int8>(value, linear_index,
+ literal);
case S16:
- return SetValueInLiteralHelper<int16>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int16>(value, linear_index,
+ literal);
case S32:
- return SetValueInLiteralHelper<int32>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int32>(value, linear_index,
+ literal);
case S64:
- return SetValueInLiteralHelper<int64>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int64>(value, linear_index,
+ literal);
case U8:
- return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint8>(value, linear_index,
+ literal);
case U16:
- return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint16>(value, linear_index,
+ literal);
case U32:
- return SetValueInLiteralHelper<uint32>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint32>(value, linear_index,
+ literal);
case U64:
- return SetValueInLiteralHelper<uint64>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
+ literal);
default:
LOG(FATAL) << "unknown integral primitive type "
<< PrimitiveType_Name(shape.element_type());
}
}
-bool HloParser::SetValueInLiteral(double value, int64 linear_index,
+bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case F16:
- return SetValueInLiteralHelper<half>(value, linear_index, literal);
+ return SetValueInLiteralHelper<Eigen::half>(value, linear_index, literal);
case BF16:
- return SetValueInLiteralHelper<bfloat16>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::bfloat16>(value, linear_index,
+ literal);
case F32:
return SetValueInLiteralHelper<float>(value, linear_index, literal);
case F64:
@@ -1359,7 +1515,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index,
}
}
-bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
+bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
@@ -1372,7 +1528,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
}
template <typename LiteralNativeT, typename ParsedElemT>
-bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
+bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
+ tensorflow::int64 linear_index,
Literal* literal) {
// Check that linear_index is in range.
if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
@@ -1462,7 +1619,7 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
}
}
}
- *literal = Literal::MakeTupleOwned(std::move(elements));
+ *literal = LiteralUtil::MakeTupleOwned(std::move(elements));
return ParseToken(TokKind::kRparen,
StrCat("expects ')' at the end of the tuple with ",
ShapeUtil::TupleElementCount(shape), "elements"));
@@ -1484,16 +1641,16 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
- const int64 rank = ShapeUtil::Rank(shape);
+ const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
}
// Create a literal with the given shape in default layout.
- *literal = Literal::CreateFromDimensions(shape.element_type(),
- AsInt64Slice(shape.dimensions()));
- int64 nest_level = 0;
- int64 linear_index = 0;
+ *literal = LiteralUtil::CreateFromDimensions(
+ shape.element_type(), AsInt64Slice(shape.dimensions()));
+ tensorflow::int64 nest_level = 0;
+ tensorflow::int64 linear_index = 0;
// elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
// the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
// when we are parsing the 2nd '{' (right before '1'), we are seeing a
@@ -1501,14 +1658,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// the first '}' (right after '3'), it means the sub-array ends, and the
// sub-array is supposed to contain exactly 3 elements, so check if
// elems_seen_per_dim[1] is 3.
- std::vector<int64> elems_seen_per_dim(rank);
+ std::vector<tensorflow::int64> elems_seen_per_dim(rank);
auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
- std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
- elems_seen_per_dim.begin() + dim);
+ std::vector<tensorflow::int64> elems_seen_until_dim(
+ elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim);
return StrCat("[",
Join(elems_seen_until_dim, ",",
- [](string* out, const int64& num_elems) {
- tensorflow::strings::StrAppend(out, num_elems - 1);
+ [](string* out, const tensorflow::int64& num_elems) {
+ StrAppend(out, num_elems - 1);
}),
"]");
};
@@ -1584,7 +1741,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type())) {
LocTy loc = lexer_.GetLoc();
- int64 value;
+ tensorflow::int64 value;
if (!ParseInt64(&value)) {
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
@@ -1624,29 +1781,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
switch (shape.element_type()) {
case PRED:
- return ParseSparseLiteralHelper<uint8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape);
case S8:
- return ParseSparseLiteralHelper<int8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int8>(literal, shape);
case S16:
- return ParseSparseLiteralHelper<int16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int16>(literal, shape);
case S32:
- return ParseSparseLiteralHelper<int32>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int32>(literal, shape);
case S64:
- return ParseSparseLiteralHelper<int64>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int64>(literal, shape);
case U8:
- return ParseSparseLiteralHelper<uint8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape);
case U16:
- return ParseSparseLiteralHelper<uint16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint16>(literal, shape);
case U32:
- return ParseSparseLiteralHelper<uint32>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint32>(literal, shape);
case U64:
- return ParseSparseLiteralHelper<uint64>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint64>(literal, shape);
case F16:
- return ParseSparseLiteralHelper<half>(literal, shape);
+ return ParseSparseLiteralHelper<Eigen::half>(literal, shape);
case F32:
return ParseSparseLiteralHelper<float>(literal, shape);
case BF16:
- return ParseSparseLiteralHelper<bfloat16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::bfloat16>(literal, shape);
case F64:
return ParseSparseLiteralHelper<double>(literal, shape);
default:
@@ -1659,9 +1816,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape) {
- std::vector<int64> index;
+ std::vector<tensorflow::int64> index;
- int64 rank = ShapeUtil::Rank(shape);
+ tensorflow::int64 rank = ShapeUtil::Rank(shape);
*literal = MakeUnique<Literal>(shape);
@@ -1679,7 +1836,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
LocTy index_loc = lexer_.GetLoc();
index.clear();
if (lexer_.GetKind() == TokKind::kInt) {
- int64 single_index = lexer_.GetInt64Val();
+ tensorflow::int64 single_index = lexer_.GetInt64Val();
lexer_.Lex();
if (rank != 1) {
return Error(
@@ -1712,7 +1869,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true);
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type())) {
- int64 value_s64;
+ tensorflow::int64 value_s64;
if (!ParseInt64(&value_s64)) {
return Error(value_loc,
StrCat("expects integer for primitive type: ",
@@ -1885,23 +2042,24 @@ bool HloParser::ParseAttributeHelper(
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
case AttrTy::kInt64: {
- int64 result;
+ tensorflow::int64 result;
if (!ParseInt64(&result)) {
return false;
}
- static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
+ static_cast<optional<tensorflow::int64>*>(attr_out_ptr)
+ ->emplace(result);
return true;
}
case AttrTy::kInt32: {
- int64 result;
+ tensorflow::int64 result;
if (!ParseInt64(&result)) {
return false;
}
- if (result != static_cast<int32>(result)) {
+ if (result != static_cast<tensorflow::int32>(result)) {
return Error(attr_loc, "value out of range for int32");
}
- static_cast<optional<int32>*>(attr_out_ptr)
- ->emplace(static_cast<int32>(result));
+ static_cast<optional<tensorflow::int32>*>(attr_out_ptr)
+ ->emplace(static_cast<tensorflow::int32>(result));
return true;
}
case AttrTy::kFloat: {
@@ -1977,12 +2135,12 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kBracedInt64List: {
- std::vector<int64> result;
+ std::vector<tensorflow::int64> result;
if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
&result)) {
return false;
}
- static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
+ static_cast<optional<std::vector<tensorflow::int64>>*>(attr_out_ptr)
->emplace(result);
return true;
}
@@ -2027,6 +2185,9 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kDomain: {
+ return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
+ }
}
}();
if (!success) {
@@ -2157,7 +2318,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
<< str;
}
- const int64 rank = lhs_rhs_out[0].length();
+ const tensorflow::int64 rank = lhs_rhs_out[0].length();
if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
return TokenError(
"convolution lhs, rhs, and output must have the same rank");
@@ -2271,7 +2432,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
return false;
}
- std::vector<std::vector<int64>> ranges;
+ std::vector<std::vector<tensorflow::int64>> ranges;
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
@@ -2305,7 +2466,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
// ::= int64_val (delim int64_val)*
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
- std::vector<int64>* result) {
+ std::vector<tensorflow::int64>* result) {
if (!ParseToken(start, StrCat("expects an int64 list starting with ",
TokKindToString(start)))) {
return false;
@@ -2314,7 +2475,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
// empty
} else {
do {
- int64 i;
+ tensorflow::int64 i;
if (!ParseInt64(&i)) {
return false;
}
@@ -2431,7 +2592,8 @@ bool HloParser::ParseString(string* result) {
return true;
}
-bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
+bool HloParser::ParseDxD(const string& name,
+ std::vector<tensorflow::int64>* result) {
LocTy loc = lexer_.GetLoc();
if (!result->empty()) {
return Error(loc,
@@ -2439,7 +2601,7 @@ bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
- int64 number;
+ tensorflow::int64 number;
if (!ParseInt64(&number)) {
return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
}
@@ -2459,7 +2621,8 @@ bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
return TokenError("expects token type kInt or kDxD");
}
-bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
+bool HloParser::ParseWindowPad(
+ std::vector<std::vector<tensorflow::int64>>* pad) {
LocTy loc = lexer_.GetLoc();
if (!pad->empty()) {
return Error(loc, "sub-attribute 'pad=' already exists");
@@ -2470,7 +2633,7 @@ bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
string str = lexer_.GetStrVal();
std::vector<string> padding_str = Split(str, 'x');
for (int i = 0; i < padding_str.size(); i++) {
- std::vector<int64> low_high;
+ std::vector<tensorflow::int64> low_high;
if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
@@ -2494,7 +2657,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
string str = lexer_.GetStrVal();
std::vector<string> padding_str = Split(str, 'x');
for (const auto& padding_dim_str : padding_str) {
- std::vector<int64> padding_dim;
+ std::vector<tensorflow::int64> padding_dim;
if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
@@ -2516,7 +2679,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) {
optional<string> op_type;
optional<string> op_name;
optional<string> source_file;
- optional<int32> source_line;
+ optional<tensorflow::int32> source_line;
attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
@@ -2603,7 +2766,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
-bool HloParser::ParseInt64(int64* result) {
+bool HloParser::ParseInt64(tensorflow::int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
@@ -2726,8 +2889,8 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
} // namespace
-StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
- const HloModuleConfig& config) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str, const HloModuleConfig& config) {
HloParser parser(str, config);
if (!parser.Run()) {
return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
@@ -2735,9 +2898,10 @@ StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
return parser.ConsumeHloModule();
}
-StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str) {
HloModuleConfig config;
- return Parse(str, config);
+ return ParseHloString(str, config);
}
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
@@ -2759,5 +2923,4 @@ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
return parser.ParseConvolutionDimensionNumbersOnly();
}
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 902c45cebc..3f3a51215e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -13,28 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
-namespace tools {
+
+// For details about the syntax accepted by this parser, see
+// g3doc/hlo_parser.md.
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with the given config.
-StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
- const HloModuleConfig& config);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str, const HloModuleConfig& config);
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with default config.
-StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
@@ -47,7 +50,10 @@ StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
tensorflow::StringPiece str);
-} // namespace tools
+// ParseHloString sharding from str. str is supposed to contain the body of the
+// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 3c5957b96a..f06c705c42 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <string>
#include "tensorflow/compiler/xla/window_util.h"
@@ -23,10 +23,10 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace xla {
-namespace tools {
+
namespace {
-using tensorflow::StringPiece;
+using ::tensorflow::StringPiece;
struct TestData {
string test_name;
@@ -236,6 +236,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3
)"
},
+{
+"DomainParsing",
+R"(HloModule DomainParsing_module
+
+ENTRY %DomainParsing (v1: f32[]) -> f32[] {
+ %v1 = f32[] parameter(0)
+ ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
+}
+
+)"
+},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
@@ -266,12 +277,13 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module
-ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
- ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
+ %token = token[] after-all()
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15, sharding={maximal device=1}
+ ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)"
@@ -754,7 +766,7 @@ add_F32.v3 {
ENTRY MapBinaryAdder.v3 {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
- ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3
+ ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}
)"
@@ -784,10 +796,14 @@ ENTRY ReduceR3ToR2.v3 {
R"(HloModule outfeed_module
ENTRY InfeedToOutfeed {
- infeed = (u32[3]{0}, pred[]) infeed()
- outfeed = () outfeed(infeed)
- ROOT infeed.1 = (u32[3]{0}, pred[]) infeed()
- outfeed.1 = () outfeed(infeed.1)
+ token = token[] after-all()
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}
)"
@@ -817,6 +833,56 @@ ENTRY ReducePrecision {
)"
},
+// Sort (Key)
+{
+"SortKey",
+R"(HloModule sort
+
+ENTRY Sort {
+ x = f32[1024]{0} parameter(0)
+ ROOT sorted = f32[1024]{0} sort(x), dimensions={0}
+}
+
+)"
+},
+// Sort (Key, Value)
+{
+"SortKeyValue",
+R"(HloModule sort
+
+ENTRY Sort {
+ keys = f32[1024]{0} parameter(0)
+ values = s32[1024]{0} parameter(1)
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
+}
+
+)"
+},
+// R2 Sort (Key)
+{
+"SortKeyR2",
+R"(HloModule sort
+
+ENTRY Sort {
+ x = f32[1024,16]{0,1} parameter(0)
+ ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}
+}
+
+)"
+},
+// R2 Sort (Key, Value)
+{
+"SortKeyValueR2",
+R"(HloModule sort
+
+ENTRY Sort {
+ keys = f32[1024,16]{0,1} parameter(0)
+ values = s32[1024,16]{0,1} parameter(1)
+ ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}
+}
+
+)"
+},
// Conditional
{
"Conditional",
@@ -889,6 +955,42 @@ ENTRY Gather {
)"
},
+// cross-replica-sum
+{
+"CrossReplicaSum",
+R"(HloModule CRS
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CRS {
+ input = f32[8]{0} parameter(0)
+ ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add
+}
+
+)"
+},
+// cross-replica-sum with subgroups
+{
+"CrossReplicaSumWithSubgroups",
+R"(HloModule CRS_Subgroups
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CrossReplicaSumWithSubgroups {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add
+}
+
+)"
+}
});
// clang-format on
}
@@ -901,12 +1003,12 @@ class HloParserTest : public ::testing::Test,
<< "'" << s << "' does not contain '" << expected << "'";
}
- // Expects "ToString(Parse(string)) == string", that is, parses the string,
- // asserts that it succeeded, stringifies the parsed module, and checks that
- // the it equals the original string.
+ // Expects "ToString(ParseHloString(string)) == string", that is, parses the
+ // string, asserts that it succeeded, stringifies the parsed module, and
+ // checks that the it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ(original, result.ValueOrDie()->ToString(
HloPrintOptions().set_print_large_constants(true)));
@@ -917,7 +1019,7 @@ class HloParserShortTest : public HloParserTest {
protected:
void ExpectEqualShort() {
const string& original = GetParam().module_string;
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ(original,
result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
@@ -938,13 +1040,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
TEST_F(HloParserTest, Empty) {
const string original = "";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -958,7 +1060,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -970,7 +1072,7 @@ ENTRY %blabla (x: g32[]) -> g32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -983,7 +1085,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -994,7 +1096,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -1009,7 +1111,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same.
@@ -1020,7 +1122,7 @@ TEST_F(HloParserTest, ConfigurationField) {
ENTRY %configuration_test() -> s32[] {
%constant = s32[] constant(42), backend_config="foo bar"
})";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ("foo bar", result.ValueOrDie()
->entry_computation()
@@ -1036,7 +1138,7 @@ ENTRY %some_2 () -> f32[2] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 1, but sees larger");
@@ -1050,7 +1152,7 @@ ENTRY %some_2x3 () -> f32[2,3] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 2, but sees 1");
@@ -1064,7 +1166,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects 3 elements in the [0]th element");
@@ -1079,7 +1181,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"is out of range for literal's primitive type F16");
@@ -1093,7 +1195,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be
@@ -1111,7 +1213,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, InvalidDimLabels) {
@@ -1127,32 +1229,34 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
)";
+ ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat(
+ prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
+
ExpectHasSubstr(
- Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ ParseHloString(tensorflow::strings::StrCat(
+ prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
- "expects dim labels pattern");
-
- ExpectHasSubstr(Parse(tensorflow::strings::StrCat(
- prefix, ",dim_labels=010_1100->010", suffix))
- .status()
- .error_message(),
- "must have the same rank");
+ "must have the same rank");
}
TEST_F(HloParserTest, UnexpectedAttribute) {
const string original = R"(HloModule unexpected_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, calls=%recv
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"unexpected attribute \"calls\"");
}
@@ -1160,15 +1264,16 @@ TEST_F(HloParserTest, MissingAttribute) {
const string original = R"(HloModule missing_attr_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
- %send = (f32[], u32[]) send(f32[] %constant)
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token)
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"attribute channel_id is expected but not seen");
}
@@ -1176,15 +1281,16 @@ TEST_F(HloParserTest, PredecessorUndefined) {
const string original = R"(HloModule pre_not_found_module
ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
- %recv = (f32[], u32[]) recv(), channel_id=15
- %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
+ %token = token[] after-all()
+ %recv = (f32[], u32[], token[]) recv(token[] %token), channel_id=15
+ %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
- %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
- %send-done = () send-done((f32[], u32[]) %send), channel_id=16
+ %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token), channel_id=16, control-predecessors={%done}
+ %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"'done' is not defined");
}
@@ -1197,7 +1303,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
@@ -1211,7 +1317,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"expects padding_low and padding_high separated by '_'");
}
@@ -1223,7 +1329,7 @@ ENTRY %test_comma.v4 () -> f32[] {
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
@@ -1233,7 +1339,7 @@ ENTRY %CustomCall () -> f32[1] {
%constant = f32[1]{0} constant({12345})
ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"Shape of computation CustomCall, f32[1], is not compatible "
"with that of its root instruction foo, f32[1,2,3]");
}
@@ -1252,9 +1358,9 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
- auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
+ auto program_layout = module.ValueOrDie()->entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1);
auto param_layout = program_layout.parameter_layout(0).layout();
auto result_layout = program_layout.result_layout().layout();
@@ -1275,7 +1381,7 @@ c1 {
c2 {
const2 = f32[1]{0} constant({67890})
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
}
@@ -1286,7 +1392,7 @@ ENTRY consts {
first = f32[1]{0} constant({12345})
last = f32[1]{0} constant({67890})
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
EXPECT_EQ(
module.ValueOrDie()->entry_computation()->root_instruction()->name(),
@@ -1301,7 +1407,7 @@ ENTRY c1 {
ENTRY c2 {
const2 = f32[1]{0} constant({67890})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"expects only one ENTRY");
}
@@ -1311,7 +1417,7 @@ ENTRY consts {
ROOT const1 = f32[1]{0} constant({12345})
ROOT const2 = f32[1]{0} constant({12345})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"one computation should have only one ROOT");
}
@@ -1323,7 +1429,7 @@ comp {
comp {
const2 = f32[1]{0} constant({67890})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
R"(was parsing 2:1: error: computation previously defined here
comp {
^)");
@@ -1346,7 +1452,7 @@ ENTRY entry {
ROOT call1 = s32[] call(param), to_apply=tcallb
})";
ExpectHasSubstr(
- Parse(original).status().error_message(),
+ ParseHloString(original).status().error_message(),
"was parsing 8:39: error: instruction does not exist: aparam");
}
@@ -1370,6 +1476,15 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
}
+TEST_F(HloParserTest, NontupleInfeed) {
+ const string original = R"(HloModule nontuple_infeed:
+ENTRY nontuple_infeed {
+ token = token[] after-all()
+ ROOT infeed = pred[] infeed(token)
+})";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "infeed must have a non-empty tuple shape");
+}
+
} // namespace
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index d45038f1f4..2a07b6fcbc 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_query.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -61,7 +61,7 @@ bool AllOperandsAreConstants(const HloInstruction& instruction) {
}
HloInstruction* GetMatchingOperand(
- std::function<bool(const HloInstruction*)> matcher,
+ const std::function<bool(const HloInstruction*)>& matcher,
HloInstruction* instruction) {
for (HloInstruction* op : instruction->operands()) {
if (matcher(op)) {
@@ -72,7 +72,7 @@ HloInstruction* GetMatchingOperand(
}
bool MatchBinaryInstructionOperand(
- std::function<bool(const HloInstruction*)> matcher,
+ const std::function<bool(const HloInstruction*)>& matcher,
HloInstruction* instruction, HloInstruction** matching_operand,
HloInstruction** other_operand) {
CHECK_EQ(instruction->operand_count(), 2);
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
index c79347bbf9..c0826a6aee 100644
--- a/tensorflow/compiler/xla/service/hlo_query.h
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -45,7 +45,7 @@ bool IsScalarConstant(const HloInstruction* instruction);
// multiple matching operands, then the first matching operand is returned. If
// there are no matching operands then nullptr is returned.
HloInstruction* GetMatchingOperand(
- std::function<bool(const HloInstruction*)> matcher,
+ const std::function<bool(const HloInstruction*)>& matcher,
HloInstruction* instruction);
// Returns whether a binary instruction has a matching operand. Sets
@@ -53,7 +53,7 @@ HloInstruction* GetMatchingOperand(
// other_operand. Note: in the case where both operands match, the first operand
// of the instruction is returned.
bool MatchBinaryInstructionOperand(
- std::function<bool(const HloInstruction*)> matcher,
+ const std::function<bool(const HloInstruction*)>& matcher,
HloInstruction* instruction, HloInstruction** matching_operand,
HloInstruction** other_operand);
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc
index 4738e46f8a..01b088a957 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace xla {
HloReachabilityMap::HloReachabilityMap(
- const std::list<HloInstruction*>& instructions)
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions)
: size_(instructions.size()) {
bit_vectors_.reserve(size_);
for (const HloInstruction* hlo : instructions) {
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index 69bb2b3cee..48215d32a8 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -41,7 +41,8 @@ class HloReachabilityMap {
public:
// Sets up a graph with no edges and where the nodes correspond to the given
// instructions.
- explicit HloReachabilityMap(const std::list<HloInstruction*>& instructions);
+ explicit HloReachabilityMap(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
diff --git a/tensorflow/compiler/xla/service/hlo_reachability_test.cc b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
index 657a9ee83d..585c95972b 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_reachability_test.cc
@@ -39,15 +39,15 @@ TEST_F(HloReachabilityTest, Reachability) {
*/
auto builder = HloComputation::Builder(TestName());
auto a = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto b = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto c = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto d = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
auto e = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
builder.Build();
HloReachabilityMap reachability({a, b, c, d, e});
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 39b85de0f1..59a8800a7d 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -71,6 +72,20 @@ bool IsRematerializable(const HloInstruction* instruction) {
}
}
+// Checks whether an instruction can be rematerialized, by looking up the
+// cache before, and eventually calling the IsRematerializable() API.
+bool CanBeRematerialized(
+ const HloInstruction* instruction,
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ auto it = remat_able->find(instruction);
+ if (it != remat_able->end()) {
+ return it->second;
+ }
+ bool rematerializable = IsRematerializable(instruction);
+ (*remat_able)[instruction] = rematerializable;
+ return rematerializable;
+}
+
// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>;
@@ -843,9 +858,10 @@ int64 RematerializationCost(const HloInstruction* instruction,
// candidate which reduce memory use at the program point of the current
// instruction as indicated by memory_tracker. nullptr is returned if no
// candidate can be found.
-Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
- const InstructionList& instruction_list,
- int64 memory_limit_bytes) {
+Item* PickRematerializationCandidate(
+ const MemoryUsageTracker& memory_tracker,
+ const InstructionList& instruction_list, int64 memory_limit_bytes,
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
@@ -869,8 +885,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker,
<< " is excluded from rematerialization";
continue;
}
-
- if (!IsRematerializable(candidate)) {
+ if (!CanBeRematerialized(candidate, remat_able)) {
VLOG(5) << "candidate " << candidate->name()
<< " not viable: is not rematerializable";
continue;
@@ -974,6 +989,9 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// blacklist.
tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
+ // The map from instructions to their rematerializable status.
+ tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able;
+
// The peak memory of the computation at any point in the instruction
// sequence.
int64 peak_memory = memory_tracker.memory_usage();
@@ -1011,7 +1029,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
Item* best_item = PickRematerializationCandidate(
- memory_tracker, instruction_list, memory_limit_bytes);
+ memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
if (best_item == nullptr) {
VLOG(3) << "Unable to find rematerialization candidate at program "
@@ -1184,7 +1202,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes) {
+ int64 memory_limit_bytes, RematerializationSizes* sizes,
+ bool run_copy_elision) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@@ -1213,12 +1232,21 @@ StatusOr<bool> HloRematerialization::Run(
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
// Create initial sequence of HLO instructions.
- TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence(
+ TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule(
*module,
[this](const BufferValue& buffer) {
return size_function_(buffer.shape());
},
scheduler_algorithm_));
+ if (run_copy_elision) {
+ // We run a separate pass of copy elision here because the sequential
+ // ordering from the HLO schedule allows for more copies to be eliminated.
+ // TODO(b/80249101): Instead of a separate copy elision pass, use the
+ // ordering from the HLO schedule directly for copy insertion.
+ SequentialHloOrdering ordering(module, *sequence);
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module));
+ }
+
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@@ -1321,9 +1349,10 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes) {
+ RematerializationSizes* sizes, bool run_copy_elision) {
HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
+ return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ run_copy_elision);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ee2dd0571..59b4cf5dcc 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -57,6 +57,12 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
+ // run_copy_elision: Enable copy elision. This pass is used to eliminate
+ // copies that were inserted before HLO scheduling.
+ //
+ // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
+ // insertion is integrated with HLO scheduling.
+ //
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@@ -68,7 +74,7 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes = nullptr);
+ RematerializationSizes* sizes, bool run_copy_elision = true);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -83,7 +89,8 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit, RematerializationSizes* sizes);
+ int64 memory_limit, RematerializationSizes* sizes,
+ bool run_copy_elision);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index 83de54f3fa..cd131147e6 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase {
// Creates and returns a computation which can benefit from
// rematerialization. The computation looks like:
//
- // F32[] %param = {...}
+ // F32[1] %param = {...}
+ // F32[] %reshape = reshape(F32[], param)
// F32[1024] %bcast = broadcast(%param)
// F32[1024] %negate = negate(%bcast)
// F32[2048] %concat_1 = concat({%negate, %negate})
@@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase {
const string& suffix = "") {
auto builder = HloComputation::Builder(TestName() + suffix);
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+ auto reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(scalar_shape_, param));
auto bcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+ HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast));
auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate(
@@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase {
const string& suffix = "") {
auto builder = HloComputation::Builder(TestName() + suffix);
auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
+ HloInstruction::CreateParameter(0, vec1_shape_, "param"));
+ auto reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(scalar_shape_, param));
auto bcast = builder.AddInstruction(
- HloInstruction::CreateBroadcast(vec1024_shape_, param, {}));
+ HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {}));
auto slice_1 = builder.AddInstruction(
HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0},
/*limit_indices=*/{1},
@@ -126,7 +132,7 @@ class HloRematerializationTest : public HloTestBase {
builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
return builder.Build();
}
@@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
+ StatusOr<bool> RunHloRematerialization(
+ int64 memory_limit_bytes, HloModule* module,
+ SequentialHloOrdering::HloModuleSequence* sequence) {
+ TF_EXPECT_OK(verifier().Run(module).status());
+ return HloRematerialization::RematerializeAndSchedule(
+ ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
+ sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
+ }
+
// Various shapes used in the canned computations.
const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {});
const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1});
@@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
SequentialHloOrdering::HloModuleSequence sequence;
// Computation requires 16KB without rematerialization, but uses only 12KB
// with rematerialization so pick a memory limit between these values (14KB).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/14 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/14 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// Root should not have changed.
@@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) {
HloComputation* computation =
module->AddEntryComputation(MakeRematerializableComputation());
- EXPECT_EQ(computation->instruction_count(), 7);
+ EXPECT_EQ(computation->instruction_count(), 8);
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/20 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/20 * 1024,
+ module.get(), &sequence));
// No instructions should have been materialized.
EXPECT_FALSE(changed);
- EXPECT_EQ(computation->instruction_count(), 7);
+ EXPECT_EQ(computation->instruction_count(), 8);
}
// Test rematerialization of a computation which calls another computation via a
@@ -215,7 +226,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
// The body computation uses 16KB and the entry computation uses 2KB at the
// while so the peak memory use of the module is 18KB. Set the memory limit a
// bit lower (17KB) to force rematerialization of the entry computation.
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/17 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/17 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// Only the entry computation should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 8);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
}
// Test rematerialization of a computation which calls another computation via a
@@ -254,7 +263,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/body_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(body_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(body_computation->instruction_count(), 8);
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/15 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/15 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
- // Both computations should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(body_computation->instruction_count(), 8);
+ // Both computations should have rematerialized instructions added.
+ EXPECT_EQ(entry_computation->instruction_count(), 9);
+ EXPECT_EQ(body_computation->instruction_count(), 9);
}
// Test rematerialization of a doubly nested computation. All computations
@@ -289,7 +296,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, vec1_shape_, "param"));
cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
HloComputation* while_cond =
module->AddEmbeddedComputation(cond_builder.Build());
@@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) {
module->AddEntryComputation(MakeRematerializableWhileComputation(
while_cond, /*while_body=*/middle_computation));
- EXPECT_EQ(entry_computation->instruction_count(), 6);
- EXPECT_EQ(middle_computation->instruction_count(), 6);
- EXPECT_EQ(inner_computation->instruction_count(), 7);
+ EXPECT_EQ(entry_computation->instruction_count(), 7);
+ EXPECT_EQ(middle_computation->instruction_count(), 7);
+ EXPECT_EQ(inner_computation->instruction_count(), 8);
// If all computations are maximally rematerialized then peak memory usage is
// ~12K so pick something slightly larger.
SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/13 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/13 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
- // All computations should have a rematerialized instruction added.
- EXPECT_EQ(entry_computation->instruction_count(), 7);
- EXPECT_EQ(middle_computation->instruction_count(), 7);
- EXPECT_EQ(inner_computation->instruction_count(), 8);
+ // All computations should have rematerialized instructions added.
+ EXPECT_EQ(entry_computation->instruction_count(), 9);
+ EXPECT_EQ(middle_computation->instruction_count(), 9);
+ EXPECT_EQ(inner_computation->instruction_count(), 9);
}
TEST_F(HloRematerializationTest, RngNotRematerialized) {
@@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) {
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
TF_ASSERT_OK_AND_ASSIGN(
- bool changed, HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
+ bool changed, RunHloRematerialization(
/*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_),
- module.get(), DefaultMemoryScheduler, &sequence));
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// The rng should not have been rematerialized.
EXPECT_EQ(count_rngs(entry_computation), 1);
@@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) {
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024,
+ module.get(), &sequence));
EXPECT_TRUE(changed);
// The broadcast should have been rematerialized 3 times.
@@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
// Pick a memory limit some where between 24KB (initial peak memory including
// parameter and output) and 20KB (peak memory possible with
// rematerialization).
- TF_ASSERT_OK_AND_ASSIGN(bool changed,
- HloRematerialization::RematerializeAndSchedule(
- ByteSizeOf,
- /*memory_limit_bytes=*/22 * 1024, module.get(),
- DefaultMemoryScheduler, &sequence));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization(
+ /*memory_limit_bytes=*/22 * 1024,
+ module.get(), &sequence));
// Rematerialization should only occur if the rematerializable instruction has
// no indirect uses.
if (indirectly_used) {
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 31e13da0c0..4f0569f405 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -36,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
- return tools::Parse(hlo_string, config);
+ return ParseHloString(hlo_string, config);
}
namespace {
@@ -80,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
filename, &hlo_string));
HloModuleConfig config;
config.set_debug_options(debug_options);
- return tools::Parse(hlo_string, config);
+ return ParseHloString(hlo_string, config);
}
HloRunner::HloRunner(se::Platform* platform) {
@@ -98,8 +98,10 @@ StatusOr<ScopedShapedBuffer> HloRunner::TransferLiteralToDevice(
backend().transfer_manager()->AllocateScopedShapedBuffer(
literal.shape(), backend().memory_allocator(),
backend().default_device_ordinal()));
+ TF_ASSIGN_OR_RETURN(
+ auto stream, backend().BorrowStream(backend().default_stream_executor()));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- backend().default_stream_executor(), literal, buffer));
+ stream.get(), literal, buffer));
return std::move(buffer);
}
@@ -127,8 +129,10 @@ StatusOr<std::vector<ScopedShapedBuffer>> HloRunner::TransferLiteralsToDevice(
StatusOr<std::unique_ptr<Literal>> HloRunner::TransferLiteralFromDevice(
const ShapedBuffer& buffer) {
- return backend().transfer_manager()->TransferLiteralFromDevice(
- backend().default_stream_executor(), buffer);
+ TF_ASSIGN_OR_RETURN(
+ auto stream, backend().BorrowStream(backend().default_stream_executor()));
+ return backend().transfer_manager()->TransferLiteralFromDevice(stream.get(),
+ buffer);
}
StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
@@ -237,7 +241,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
backend().transfer_manager()->AllocateScopedShapedBuffer(
argument->shape(), backend().memory_allocator(), device));
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
- executor, *argument, argument_buffer));
+ streams.back().get(), *argument, argument_buffer));
argument_buffers.push_back(std::move(argument_buffer));
argument_buffer_ptrs[index++] = &argument_buffers.back();
}
@@ -307,7 +311,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
for (int64 i = 0; i < options.num_replicas; ++i) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
backend().transfer_manager()->TransferLiteralFromDevice(
- streams[i]->parent(), results[i]));
+ streams[i].get(), results[i]));
exec_results.push_back(std::move(literal));
}
return std::move(exec_results);
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 68b2cde83a..c6d3909af6 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -36,29 +36,6 @@ using ::tensorflow::strings::HumanReadableNumBytes;
namespace xla {
-StatusOr<int64> MinimumMemoryForSequence(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
- return 0;
- }
-
- const HloModule* module = module_sequence.begin()->first->parent();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
- TuplePointsToAnalysis::Run(module));
-
- // The absolute minimum memory required for a given sequence of instructions
- // is determined by the sequence of Alloc and Free calls on a simulated heap,
- // ignoring fragmentation. We run the heap simulation on the whole module,
- // rather than summing each computation, since it gives us a better lower
- // bound, by minimizing the liveness of sub-computations.
- TF_ASSIGN_OR_RETURN(
- HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
- return result.heap_size;
-}
-
namespace {
// Class implementing a list scheduler of HLO instructions which produces a
@@ -398,7 +375,7 @@ int64 SumLogicalBufferSizes(
return size;
}
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+StatusOr<std::vector<const HloInstruction*>> ScheduleComputationHelper(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
@@ -416,30 +393,15 @@ StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
} // namespace
-StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
- TF_ASSIGN_OR_RETURN(
- HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
- return result.heap_size;
-}
-
StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
memory_by_computation) {
- // This ordering is based on DFS post-order, with a heuristic to decide which
- // operand to visit first. The heuristic is based on 'extra_users', which is
- // simply users-1 for each instruction. By subtracting 1, we're saying that
- // instructions with no users or a single user don't count; instructions with
- // lots of fan-out will be visited earlier.
+ // These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
+ int64 total_hlos = computation.parent()->NumUniqueInstructionIds();
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
@@ -448,6 +410,11 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
total_sizes[hlo] = 0;
continue;
}
+ // This ordering is based on DFS post-order, with a heuristic to decide
+ // which operand to visit first. The heuristic is based on 'extra_users',
+ // which is simply users-1 for each instruction. By subtracting 1, we're
+ // saying that instructions with no users or a single user don't count;
+ // instructions with lots of fan-out will be visited earlier.
extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
int64 logical_buffer_size = SumLogicalBufferSizes(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
@@ -463,10 +430,13 @@ StatusOr<std::vector<const HloInstruction*>> DFSMemoryScheduler(
// lead to it. But computation is a DAG, so we are double-counting nodes,
// which can lead to overflows for large programs.
// cumulative_total_size caps the size to prevent overflows.
+ // Same for total_hlos: it prevents overflows on very large and branchy
+ // models, where the number of paths is exponential to the number of nodes.
// NOTE(dimvar): this is quite ugly and should be changed. It's unclear
// why we care about transitive sizes; when scheduling a node, its input
// and output buffers should be all that matters, not its "history".
total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size);
+ extra_users[hlo] = std::min(extra_users[hlo], total_hlos);
}
CHECK_EQ(extra_users.size(), computation.instruction_count());
CHECK_EQ(total_sizes.size(), computation.instruction_count());
@@ -533,29 +503,29 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
std::vector<const HloInstruction*> list_sequence,
ListMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 list_memory,
- MinimumMemoryForComputation(computation, list_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 list_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, list_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory);
TF_ASSIGN_OR_RETURN(std::vector<const HloInstruction*> dfs_sequence,
DFSMemoryScheduler(computation, points_to_analysis,
size_function, memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 dfs_memory,
- MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
- size_function));
+ TF_ASSIGN_OR_RETURN(const int64 dfs_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, dfs_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory);
TF_ASSIGN_OR_RETURN(
std::vector<const HloInstruction*> post_order_sequence,
PostOrderMemoryScheduler(computation, points_to_analysis, size_function,
memory_by_computation));
- TF_ASSIGN_OR_RETURN(
- const int64 post_order_memory,
- MinimumMemoryForComputation(computation, post_order_sequence,
- points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(const int64 post_order_memory,
+ HeapSimulator::MinimumMemoryForComputation(
+ computation, post_order_sequence, points_to_analysis,
+ size_function, &memory_by_computation));
VLOG(2) << "Min-memory post order sequence: "
<< HumanReadableNumBytes(post_order_memory);
@@ -576,10 +546,9 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
}
}
-StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(const HloModule& module,
- const LogicalBuffer::SizeFunction& size_function,
- const MemorySchedulerAlgorithm& algorithm) {
+StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+ const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm) {
SequentialHloOrdering::HloModuleSequence sequence;
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
@@ -587,12 +556,13 @@ CreateMemoryMinimizingSequence(const HloModule& module,
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(auto one_computation_sequence,
- CreateMemoryMinimizingSequence(
+ ScheduleComputationHelper(
*computation, *points_to_analysis, size_function,
algorithm, memory_by_computation));
memory_by_computation[computation] =
- MinimumMemoryForComputation(*computation, one_computation_sequence,
- *points_to_analysis, size_function)
+ HeapSimulator::MinimumMemoryForComputation(
+ *computation, one_computation_sequence, *points_to_analysis,
+ size_function, &memory_by_computation)
.ValueOrDie();
sequence[computation] = std::move(one_computation_sequence);
}
@@ -600,15 +570,15 @@ CreateMemoryMinimizingSequence(const HloModule& module,
return sequence;
}
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function) {
CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
- return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
- size_function, nullptr, empty_map);
+ return ScheduleComputationHelper(computation, *points_to_analysis,
+ size_function, nullptr, empty_map);
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index 49b927eefd..2b33ccc8bf 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -28,20 +28,6 @@ limitations under the License.
namespace xla {
-// Returns the minimum memory required to compute the given module sequence,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForSequence(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function);
-
-// Returns the minimum memory required to compute the given computation,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function);
-
// A memory scheduler computes an execution sequence for the HLO instructions in
// 'computation' that minimizes peak memory, given a points-to analysis result
// that describes buffer aliasing, together with a target-specific size function
@@ -89,14 +75,13 @@ StatusOr<std::vector<const HloInstruction*>> DefaultMemoryScheduler(
// Returns an HloModuleSequence which seeks to minimize the memory required for
// the computation. size_function is the function returning the number of bytes
// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(const HloModule& module,
- const LogicalBuffer::SizeFunction& size_function,
- const MemorySchedulerAlgorithm& algorithm = {});
+StatusOr<SequentialHloOrdering::HloModuleSequence> ScheduleComputationsInModule(
+ const HloModule& module, const LogicalBuffer::SizeFunction& size_function,
+ const MemorySchedulerAlgorithm& algorithm = {});
-// Overload of above that computes the sequence for a single computation.
+// Computes the schedule for a single computation.
// Currently only used by the GPU backend.
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
const HloComputation& computation,
const LogicalBuffer::SizeFunction& size_function);
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 0bc930f9ea..cf9ceed5b2 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -18,78 +18,20 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-class MinimumMemoryForSequenceTest : public HloTestBase {};
-
-TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
- auto module = CreateNewModule();
- const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
- const Shape tuple_shape =
- ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
-
- auto cond_builder = HloComputation::Builder("WhileCond");
- // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
- HloInstruction* cond_param = cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
- HloInstruction* cond_iter = cond_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
- HloInstruction* cond_data = cond_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
- // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
- HloInstruction* cond_lt = cond_builder.AddInstruction(
- HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
- HloOpcode::kLt, cond_iter, cond_data));
- HloComputation* cond_computation =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto body_builder = HloComputation::Builder("WhileBody");
- // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
- HloInstruction* body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
- HloComputation* body_computation =
- module->AddEmbeddedComputation(body_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- // Entry params: 8 bytes (4 bytes per param), TOTAL=8
- HloInstruction* iter = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
- HloInstruction* data = builder.AddInstruction(
- HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
- // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
- HloInstruction* tuple =
- builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
- // While: 8 bytes (4 bytes per element), TOTAL=32
- // Both cond and body use a max of 24 bytes, TOTAL=56
- HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
- tuple_shape, cond_computation, body_computation, tuple));
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
-
- auto size_fn = [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- };
-
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56,
- MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
-}
-
class HloSchedulingTest : public HloTestBase {};
TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
@@ -124,7 +66,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) {
+ ScheduleComputationsInModule(*module, [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape());
}));
// Verify that all instructions are in the sequence.
@@ -158,14 +100,14 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
};
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler));
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
sequence.at(module->entry_computation()).size());
@@ -203,7 +145,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
// ROOT %subtract = f32[4]{0} subtract(
// f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1)
// }
- // %SubcomputationsNotAccounted () -> f32[2,4] {
+ // %ListAccountsForSubcomputations () -> f32[2,4] {
// %constant.3 = f32[2,4]{1,0} constant(
// f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } })
// %transpose = f32[2,4]{1,0} transpose(
@@ -226,8 +168,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
auto cond_builder = HloComputation::Builder("WhileCond");
HloInstruction* cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "cond_param"));
- HloInstruction* zero_vector = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{0, 0, 0, 0}})));
+ HloInstruction* zero_vector =
+ cond_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
cond_builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
@@ -237,16 +180,18 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
auto body_builder = HloComputation::Builder("WhileBody");
HloInstruction* body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, r1f32, "body_param"));
- HloInstruction* one_vector = body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* one_vector =
+ body_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
body_builder.AddInstruction(HloInstruction::CreateBinary(
r1f32, HloOpcode::kSubtract, body_param, one_vector));
auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
// transpose(matrix) + bcast(while)
auto builder = HloComputation::Builder(TestName());
- HloInstruction* while_init = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>({{1, 1, 1, 1}})));
+ HloInstruction* while_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
// Creates 16 bytes, ignoring subcomputations
HloInstruction* while_loop =
builder.AddInstruction(HloInstruction::CreateWhile(
@@ -257,7 +202,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
HloInstruction::CreateBroadcast(r2f32, while_loop, {0}));
HloInstruction* matrix = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>(
{{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}})));
// Creates 32 bytes
HloInstruction* transpose = builder.AddInstruction(
@@ -269,16 +214,16 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
module->AddEntryComputation(builder.Build());
- TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(
- *module,
- [](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape());
- },
- ListMemoryScheduler));
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
// Verify that all instructions are in the sequence.
- EXPECT_EQ(module->entry_computation()->instruction_count(),
- sequence.at(module->entry_computation()).size());
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
SequentialHloOrdering ordering(module.get(), sequence);
// This schedule is an example of List's greedy heuristics being suboptimal.
// The while_loop is more expensive than transpose, so it would have been
@@ -287,6 +232,24 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast));
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // HeapSimulator doesn't account for subcomputations
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // HeapSimulator accounts for subcomputations. The max mem doesn't change
+ // because the while body isn't live during the peak.
+ EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
}
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
@@ -297,7 +260,7 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
// Wrap lit in abs because constants are considered free by
// IgnoreInstruction, and it skews the accounting.
auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 1, 1, 1, 1, 1})));
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
auto abs_const = builder.AddInstruction(
HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));
@@ -318,12 +281,12 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(
SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(*module,
- [&TUPLE_SIZE](const BufferValue& buffer) {
- return ShapeUtil::ByteSizeOf(
- buffer.shape(), TUPLE_SIZE);
- },
- ListMemoryScheduler));
+ ScheduleComputationsInModule(*module,
+ [&TUPLE_SIZE](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(
+ buffer.shape(), TUPLE_SIZE);
+ },
+ ListMemoryScheduler));
// Verify that all instructions are in the sequence.
EXPECT_EQ(module->entry_computation()->instruction_count(),
@@ -340,11 +303,11 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
HloComputation::Builder builder(TestName());
auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 1, 1, 1, 1})));
+ LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1, 2, 3, 4, 5})));
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({0, 2, 4, 6, 8})));
+ LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
@@ -368,7 +331,7 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
{tuple, mul, add}, HloInstruction::FusionKind::kLoop);
TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence,
- CreateMemoryMinimizingSequence(
+ ScheduleComputationsInModule(
*module,
[](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
@@ -384,5 +347,73 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}
+TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
+ auto module = CreateNewModule();
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4});
+
+ // param != 0
+ // Needs 17 bytes
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "cond_param"));
+ HloInstruction* zero_vector =
+ cond_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0, 0, 0, 0}})));
+ cond_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector));
+ auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
+
+ // param - 1
+ // Needs 16 bytes
+ auto body_builder = HloComputation::Builder("WhileBody");
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "body_param"));
+ HloInstruction* one_vector =
+ body_builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+ body_builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, body_param, one_vector));
+ auto body_computation = module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* while_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
+ // Creates 16 bytes, ignoring subcomputations
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ r1f32, cond_computation, body_computation, while_init));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ TF_ASSERT_OK_AND_ASSIGN(
+ SequentialHloOrdering::HloModuleSequence sequence,
+ ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler));
+ // Verify that all instructions are in the sequence.
+ auto entry_computation = module->entry_computation();
+ EXPECT_EQ(entry_computation->instruction_count(),
+ sequence.at(entry_computation).size());
+
+ tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ memory_by_computation[cond_computation] = 17;
+ memory_by_computation[body_computation] = 16;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
+ TuplePointsToAnalysis::Run(module.get()).ValueOrDie();
+
+ // HeapSimulator doesn't account for subcomputations
+ EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn)
+ .ValueOrDie());
+ // HeapSimulator accounts for subcomputations
+ EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation(
+ *entry_computation, sequence.at(entry_computation),
+ *points_to_analysis, size_fn, &memory_by_computation)
+ .ValueOrDie());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 58224ef870..393944c20f 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -39,6 +39,55 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
return HloSharding(tile_shape, assignment);
}
+HloSharding HloSharding::Tuple(const ShapeTree<HloSharding>& sub_shardings) {
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(sub_shardings.leaf_count());
+ for (const auto& index_to_sharding : sub_shardings.leaves()) {
+ flattened_list.push_back(index_to_sharding.second);
+ }
+ if (flattened_list.empty()) {
+ // Empty tuple sharding ends up having no leaves, but we want to allow
+ // empty tuple HLO instruction results to have sharding, so we fetch the
+ // root ({}) sharding value from the ShapeTree.
+ // A ShapeTree created with ShapeTree<HloSharding>(shape, init) will have
+ // init as value at its root.
+ flattened_list.push_back(sub_shardings.element(ShapeIndex({})));
+ }
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::Tuple(
+ const Shape& tuple_shape,
+ tensorflow::gtl::ArraySlice<HloSharding> shardings) {
+ CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ for (auto& sharding : shardings) {
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ }
+ std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
+ CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape))
+ << "Flat list has " << flattened_list.size() << ", required "
+ << RequiredLeaves(tuple_shape);
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding) {
+ CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape);
+ CHECK(!sharding.IsTuple()) << sharding.ToString();
+ int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape);
+ std::vector<HloSharding> flattened_list;
+ flattened_list.reserve(leaf_count);
+ for (int64 i = 0; i < leaf_count; ++i) {
+ flattened_list.push_back(sharding);
+ }
+ return HloSharding(flattened_list);
+}
+
+HloSharding HloSharding::Single(const Shape& shape,
+ const HloSharding& sharding) {
+ return ShapeUtil::IsTuple(shape) ? SingleTuple(shape, sharding) : sharding;
+}
+
string HloSharding::ToString() const {
if (IsTuple()) {
std::vector<string> parts;
@@ -72,6 +121,29 @@ bool HloSharding::UsesDevice(int64 device) const {
std::find(devices.begin(), devices.end(), device) != devices.end();
}
+std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
+ int64 element_count = 1;
+ std::map<int64, int64> device_map;
+ if (IsTuple()) {
+ for (auto& tuple_element_sharding : tuple_elements()) {
+ auto unique_device = tuple_element_sharding.UniqueDevice();
+ if (unique_device.ok()) {
+ device_map[unique_device.ValueOrDie()] += 1;
+ }
+ }
+ element_count = tuple_elements().size();
+ } else {
+ auto unique_device = UniqueDevice();
+ if (unique_device.ok()) {
+ device_map[unique_device.ValueOrDie()] += 1;
+ }
+ }
+ if (count != nullptr) {
+ *count = element_count;
+ }
+ return device_map;
+}
+
std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
CHECK(!ShapeUtil::IsTuple(tile_shape_));
CHECK(!maximal_);
@@ -123,24 +195,49 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
return index;
}
+int64 HloSharding::RequiredLeaves(const Shape& shape) {
+ // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are
+ // concerned, but they do have a single tuple_elements_ entry since we want
+ // to allow empty tuple results to have sharding.
+ return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape);
+}
+
+Status HloSharding::CheckLeafCount(const Shape& shape) const {
+ int64 shape_leaves = RequiredLeaves(shape);
+ TF_RET_CHECK(shape_leaves == tuple_elements_.size())
+ << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves
+ << " leaf nodes while this sharding has " << tuple_elements_.size();
+ return Status::OK();
+}
+
StatusOr<ShapeTree<HloSharding>> HloSharding::AsShapeTree(
const Shape& shape) const {
if (IsTuple()) {
ShapeTree<HloSharding> result(shape, HloSharding::Replicate());
- int64 num_leaves = result.leaf_count();
- TF_RET_CHECK(num_leaves == tuple_elements_.size())
- << "Shape " << ShapeUtil::HumanString(shape) << " has " << num_leaves
- << " leaf nodes while this sharding has " << tuple_elements_.size();
+ TF_RETURN_IF_ERROR(CheckLeafCount(shape));
auto it = tuple_elements_.begin();
for (auto& index_to_sharding : result.leaves()) {
index_to_sharding.second = *it++;
}
+ if (ShapeUtil::IsEmptyTuple(shape)) {
+ // Empty tuples have no leaves, but we want to assign them a sharding
+ // anyway, so we use the root element sharding.
+ *result.mutable_element(ShapeIndex({})) = *it;
+ }
return std::move(result);
} else {
return ShapeTree<HloSharding>(shape, *this);
}
}
+StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const {
+ if (IsTuple()) {
+ TF_RETURN_IF_ERROR(CheckLeafCount(shape));
+ return *this;
+ }
+ return Tuple(ShapeTree<HloSharding>(shape, *this));
+}
+
StatusOr<int64> HloSharding::UniqueDevice() const {
if (IsTuple()) {
if (tuple_elements_.empty()) {
@@ -182,28 +279,12 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
return tensorflow::errors::InvalidArgument(
StrCat("Sharding is tuple-shaped but validation shape is not."));
}
- // The easiest way to get the number of elements in a nested tuple is just to
- // create a shape tree. We could call GetAsShapeTree, but that will try and
- // apply our tuple_shardings_ to the shape tree, and that might cause a crash
- // at this point as we haven't validated them.
- ShapeTree<bool> bool_shape_tree(shape, false);
- int64 num_leaves =
- std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end());
- if (num_leaves != tuple_elements_.size()) {
- return tensorflow::errors::InvalidArgument(
- StrCat("Validation tuple shape has ", num_leaves,
- " leaf elements, but this sharding contains ",
- tuple_elements_.size(), " elements."));
- }
+ TF_RETURN_IF_ERROR(CheckLeafCount(shape));
// Now we've validated the number of tuple elements, it's safe to request a
// shape tree.
ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
for (const auto& index_to_sharding : shape_tree.leaves()) {
- if (index_to_sharding.first.empty()) {
- // An empty tuple has a ShapeTree with a single leaf at the empty index.
- continue;
- }
Status status = index_to_sharding.second.ValidateNonTuple(
ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
if (!status.ok()) {
@@ -389,6 +470,40 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape,
: sub_shape_tree.element(ShapeIndex({}));
}
+tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding()
+ const {
+ if (!IsTuple()) {
+ return *this;
+ }
+ for (int64 i = 1; i < tuple_elements_.size(); ++i) {
+ if (tuple_elements_[0] != tuple_elements_[i]) {
+ return tensorflow::gtl::optional<HloSharding>();
+ }
+ }
+ return tuple_elements_.front();
+}
+
+size_t HloSharding::Hash() const {
+ if (!tuple_) {
+ size_t h = 0;
+ for (const auto& element : tuple_elements_) {
+ h = tensorflow::Hash64Combine(h, element.Hash());
+ }
+ return h;
+ }
+ if (replicated_) {
+ return 0;
+ }
+ size_t h = 0;
+ for (uint32 v : tile_assignment_) {
+ h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
+ }
+ for (uint32 v : tile_shape_.dimensions()) {
+ h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
+ }
+ return h;
+}
+
std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) {
out << sharding.ToString();
return out;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index f4a0fb626f..6f672b0f28 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -19,10 +19,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_
+#include <map>
#include <string>
+#include <vector>
#include "tensorflow/compiler/xla/array.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -70,26 +72,22 @@ class HloSharding {
// Creates a new sharding for a tuple type. The given ShapeTree must have
// elements for every leaf shape contained in the tuple.
- static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
- std::vector<HloSharding> flattened_list;
- flattened_list.reserve(
- std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
- for (const auto& index_to_sharding : sub_shardings.leaves()) {
- flattened_list.push_back(index_to_sharding.second);
- }
- return HloSharding(flattened_list);
- }
+ static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings);
- // Creates a new sharding for a tuple type. The requested tuple shape must not
- // be nested. For nested tuples, use the ShapeTree overload.
+ // Creates a new sharding for a tuple type. The number of elements in
+ // shardings must match the number of leaf nodes in tuple_shape. For
+ // empty tuples, the shardings array must have one element.
static HloSharding Tuple(const Shape& tuple_shape,
- tensorflow::gtl::ArraySlice<HloSharding> shardings) {
- CHECK(ShapeUtil::IsTuple(tuple_shape));
- CHECK(!ShapeUtil::IsNestedTuple(tuple_shape));
- std::vector<HloSharding> flattened_list(shardings.begin(), shardings.end());
- CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape));
- return HloSharding(flattened_list);
- }
+ tensorflow::gtl::ArraySlice<HloSharding> shardings);
+
+ // Creates a new sharding for a tuple type, with a single input sharding
+ // repeated on each leaf.
+ static HloSharding SingleTuple(const Shape& tuple_shape,
+ const HloSharding& sharding);
+
+ // If shape is an array, returns sharding, otherwise returns the tuple shaped
+ // sharding with all the leaf nodes having the same input sharding.
+ static HloSharding Single(const Shape& shape, const HloSharding& sharding);
// Create a new sharding from a protobuf OpSharding.
static StatusOr<HloSharding> FromProto(const OpSharding& proto);
@@ -131,6 +129,14 @@ class HloSharding {
// Returns true if the sharding defines an operation on the given device.
bool UsesDevice(int64 device) const;
+ // Retrieves an histogram of the devices used by the sharding. The returned
+ // map has the device number as key, and the occurrence count as value.
+ // If a sharding does not have a device, it will not be incuded in the
+ // histogram. The count argument, if not nullptr, will receive the total
+ // number of elements this sharding is made of (one for array, N leaves for
+ // tuples).
+ std::map<int64, int64> UsedDevices(int64* count) const;
+
// Returns the tile that should be executed on the given device.
// REQUIRES: !IsTuple()
std::vector<int64> TileIndexForDevice(int64 device) const;
@@ -172,6 +178,18 @@ class HloSharding {
// REQUIRES: IsTuple()
HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const;
+ // If the current sharding is a tuple sharding, return itself as result.
+ // Otherwise returns a tuple sharding for the input shape, with all the leaves
+ // having this object sharding.
+ StatusOr<HloSharding> GetTupleSharding(const Shape& shape) const;
+
+ // Extracts the sharding that is common within the current sharding.
+ // If the current sharding is not a tuple sharding, the current sharding will
+ // be returned. If it is a tuple, and all the tuple elements are common, the
+ // common element will be returned. Otherwise the optional will contain no
+ // value.
+ tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const;
+
bool operator==(const HloSharding& other) const {
return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
ShapeUtil::Compatible(tile_shape_, other.tile_shape_) &&
@@ -180,26 +198,7 @@ class HloSharding {
}
bool operator!=(const HloSharding& other) const { return !(*this == other); }
- size_t Hash() const {
- if (!tuple_) {
- size_t h = 0;
- for (const auto& element : tuple_elements_) {
- h = tensorflow::Hash64Combine(h, element.Hash());
- }
- return h;
- }
- if (replicated_) {
- return 0;
- }
- size_t h = 0;
- for (uint32 v : tile_assignment_) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- for (uint32 v : tile_shape_.dimensions()) {
- h = tensorflow::Hash64Combine(h, std::hash<uint32>{}(v));
- }
- return h;
- }
+ size_t Hash() const;
struct Hasher {
size_t operator()(const HloSharding& sharding) const {
@@ -241,6 +240,12 @@ class HloSharding {
tuple_(false),
tile_shape_(),
tile_assignment_({0}) {}
+ // device_id values:
+ // -2: magic number to mean unassigned device, used by spatial partitioning
+ // -1: the id of the host
+ // 0 or positive: the id of a device
+ // NOTE(dimvar): -1 is needed for outside compilation. It can be removed once
+ // we have fully switched to the side-effect tokens.
explicit HloSharding(int64 device_id)
: replicated_(false),
maximal_(true),
@@ -260,11 +265,19 @@ class HloSharding {
tile_assignment_({0}),
tuple_elements_(tuple_shardings) {}
+ // Checks that the number of elements in tuple_elements_ is consistent with
+ // the tuple shape passes as argument.
+ Status CheckLeafCount(const Shape& shape) const;
+
// Internal helper to validate a tuple sharding.
Status ValidateTuple(const Shape& shape, int64 num_devices) const;
+
// Internal helper to validate a non-tuple (leaf) sharding.
Status ValidateNonTuple(const Shape& shape, int64 num_devices) const;
+ // Returns the number of tuple_elements_ entries to fit the shape.
+ static int64 RequiredLeaves(const Shape& shape);
+
bool replicated_;
bool maximal_;
bool tuple_;
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 82cff2a4b7..4f91d619ef 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -31,32 +31,22 @@ struct PassThrough {
HloInstruction* operand = nullptr;
};
-void SetDeviceSharding(HloInstruction* instruction, int64 device) {
- VLOG(4) << " " << instruction->name() << " to device " << device;
- instruction->set_device_sharding(device);
-}
-
-tensorflow::gtl::optional<int64> ShardingUniqueDevice(
- const HloSharding& sharding) {
- if (sharding.IsTileMaximal()) {
- auto device = sharding.UniqueDevice();
- if (device.ok()) {
- return device.ValueOrDie();
- }
- }
- return tensorflow::gtl::optional<int64>();
+void SetSingleSharding(HloInstruction* instruction,
+ const HloSharding& sharding) {
+ VLOG(4) << " " << instruction->name() << " to " << sharding;
+ instruction->set_single_sharding(sharding);
}
bool ShardingMatches(const HloSharding& sharding1,
const HloSharding& sharding2) {
- auto device1 = ShardingUniqueDevice(sharding1);
- if (device1) {
- auto device2 = ShardingUniqueDevice(sharding2);
- if (device2) {
- return *device1 == *device2;
+ auto single_sharding1 = sharding1.ExtractSingleSharding();
+ if (single_sharding1) {
+ auto single_sharding2 = sharding2.ExtractSingleSharding();
+ if (single_sharding2) {
+ return *single_sharding1 == single_sharding2;
}
}
- // Anything which is not tile maximal with unique device, gets a full sharding
+ // Anything which is not unique across all elements, gets a full sharding
// compare.
return sharding1 == sharding2;
}
@@ -98,6 +88,12 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
VLOG(2) << " " << instruction->ToString();
}
}
+ if (instruction == instruction->parent()->root_instruction()) {
+ pass_through.emplace_back(nullptr, instruction);
+ VLOG(2) << "Found passthrough domain link:";
+ VLOG(2) << " <root>";
+ VLOG(2) << " " << instruction->ToString();
+ }
}
return pass_through;
}
@@ -111,29 +107,33 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
tuple, 0));
gte->set_sharding(sharding);
- TF_RETURN_IF_ERROR(
- pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ if (pass_through.user != nullptr) {
+ TF_RETURN_IF_ERROR(
+ pass_through.operand->ReplaceUseWith(pass_through.user, gte));
+ } else {
+ pass_through.operand->parent()->set_root_instruction(gte);
+ }
}
return Status::OK();
}
std::unique_ptr<HloSharding> CloneShardingForDomain(
const HloSharding& sharding) {
- auto device = ShardingUniqueDevice(sharding);
- if (!device) {
+ auto single_sharding = sharding.ExtractSingleSharding();
+ if (!single_sharding) {
return MakeUnique<HloSharding>(sharding);
}
- return MakeUnique<HloSharding>(HloSharding::AssignDevice(*device));
+ return MakeUnique<HloSharding>(*single_sharding);
}
-Status ApplyDomainDeviceSharding(const DomainMetadata::Domain& domain,
- int64 device) {
- VLOG(4) << "Applying device " << device << " sharding";
+Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
+ const HloSharding& sharding) {
+ VLOG(4) << "Applying " << sharding << " sharding";
for (HloInstruction* instruction : domain.instructions) {
// We only change instructions without sharding, since otherwise we might
// mess up with eventual HLO passes which has knowledge of it.
if (!instruction->has_sharding()) {
- SetDeviceSharding(instruction, device);
+ SetSingleSharding(instruction, sharding);
} else {
VLOG(4) << " " << instruction->name() << " already has sharding "
<< instruction->sharding();
@@ -186,12 +186,15 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
const HloSharding* tuple_sharding =
GetOperandSharding(tuple, domain, sharding);
if (tuple_sharding != nullptr) {
- TF_RET_CHECK(tuple_sharding->IsTuple()) << tuple->ToString();
- HloSharding sub_sharding = tuple_sharding->GetSubSharding(
- tuple->shape(), {instruction->tuple_index()});
- VLOG(4) << " " << instruction->name() << " to sharding "
- << sub_sharding;
- instruction->set_sharding(sub_sharding);
+ if (tuple_sharding->IsTuple()) {
+ HloSharding sub_sharding = tuple_sharding->GetSubSharding(
+ tuple->shape(), {instruction->tuple_index()});
+ VLOG(4) << " " << instruction->name() << " to sharding "
+ << sub_sharding;
+ instruction->set_sharding(sub_sharding);
+ } else {
+ SetSingleSharding(instruction, *tuple_sharding);
+ }
++assigned;
}
} else if (instruction->opcode() == HloOpcode::kTuple) {
@@ -242,12 +245,29 @@ StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
const HloSharding& sharding) {
- auto device = ShardingUniqueDevice(sharding);
- if (device) {
- // Shortcut the simple case. We have a unique device sharding, so we call
- // the ApplyDomainDeviceSharding() API which will apply array or tuple
- // shaped device sharding to the domain instructions.
- return ApplyDomainDeviceSharding(domain, *device);
+ // Here is the place to call external sharding normalizers, which are
+ // implemented in other modules (ie, spatial partitioning).
+ // The signature of the external normalizer function should be something
+ // like:
+ //
+ // StatusOr<bool> Normalizer(const DomainMetadata::Domain&,
+ // const HloSharding& sharding);
+ //
+ // The function should return true if it has processed the domain
+ // normalization, false if domain was not one recognized by it, or an error.
+ // We will call the functions in order below, and fall back to local code if
+ // none of the external normalizers acted on the domain.
+ // External normalizers should not handle the cases that are already handled
+ // locally.
+
+ // None of the external normalizers handled the domain sharding, try to see
+ // whether this is a single sharding first.
+ auto single_sharding = sharding.ExtractSingleSharding();
+ if (single_sharding) {
+ // Shortcut the simple case. We have a unique sharding, so we call
+ // the ApplyDomainSingleSharding() API which will apply array or tuple
+ // shaped sharding to the domain instructions.
+ return ApplyDomainSingleSharding(domain, *single_sharding);
}
VLOG(1) << "Assigning non-trivial sharding " << sharding;
for (;;) {
@@ -367,7 +387,7 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const {
}
string ShardingMetadata::ToString() const {
- return sharding_ != nullptr ? sharding_->ToString() : "None";
+ return sharding_ != nullptr ? sharding_->ToString() : "{}";
}
Status ShardingMetadata::NormalizeInstructions(
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 94d1a3226b..7baa927d0e 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -18,12 +18,12 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -311,18 +311,20 @@ TEST_F(HloShardingTest, OstreamTest) {
EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
}
-TEST_F(HloShardingTest, Parse) {
+TEST_F(HloShardingTest, ParseHloString) {
auto check = [](const HloSharding& sharding) {
TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
- tools::ParseSharding(sharding.ToString()));
+ ParseSharding(sharding.ToString()));
EXPECT_EQ(sharding, parsed_sharding);
};
check(HloSharding::Replicate());
check(HloSharding::AssignDevice(2));
check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}),
Array4D<int64>({{{{0}, {1}}}})));
- // Empty tuple.
- check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), {}));
+ // Empty tuple. One sharding is required for empty tuples, as we need to be
+ // able to assign sharding to them, even though they have no leaves.
+ check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}),
+ {HloSharding::Replicate()}));
{
// Non-nested tuple.
auto tuple_shape =
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
index 7b601f9a95..45c684d667 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
auto x = builder.AddInstruction(
HloInstruction::CreateCall(r0s32_, {constant}, callee1));
auto y = builder.AddInstruction(
@@ -112,9 +112,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(5)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(3)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
auto x = builder.AddInstruction(
HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1));
auto y = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index 3dc733940f..48f676db85 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index be156d765d..1e2b31a1f2 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -90,7 +90,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
TEST_F(HloTfGraphBuilderTest, CheckScalarValue) {
auto builder = HloComputation::Builder("Const");
HloInstruction *instruction = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0(123)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(123)));
OpMetadata metadata;
metadata.set_op_name("x");
metadata.set_op_type("y");
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h
index 7928bee5c2..533429608b 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/service/hlo_token.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
#include <string>
@@ -22,9 +22,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace tools {
// Defines different kinds of tokens in a hlo module string.
+//
+// You shouldn't need to use this directly unless you're using HloLexer
+// directly, and you probably don't need to do that. Use hlo_parser instead.
enum class TokKind {
// Markers
kEof,
@@ -72,7 +74,6 @@ enum class TokKind {
string TokKindToString(TokKind kind);
-} // namespace tools
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 7b27dbfec3..4e3c9df3a0 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -125,7 +125,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
// transparently.
CHECK_EQ(operand_number, 0);
return index.empty();
- case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
// Select does not use any nested elements of its selected-from operands
// (operand 1 and 2)
CHECK_GE(operand_number, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9cfd8a9bf7..f896773729 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include <set>
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -39,6 +41,10 @@ Status ShapeVerifier::HandleSelect(HloInstruction* select) {
return CheckTernaryShape(select);
}
+Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
+ return CheckTernaryShape(tuple_select);
+}
+
Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : concatenate->operands()) {
@@ -106,22 +112,57 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
reduce_precision->mantissa_bits()));
}
-Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); }
+namespace {
+
+Status CheckIsTokenOperand(const HloInstruction* instruction,
+ int64 operand_no) {
+ const HloInstruction* token = instruction->operand(operand_no);
+ if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
+ return InternalError(
+ "Expected operand %lld to be token-shaped, actual shape is"
+ "%s:\n%s",
+ operand_no, ShapeUtil::HumanString(token->shape()).c_str(),
+ instruction->ToString().c_str());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
+ HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
+ // Infeed has an optional single token operand.
+ // TODO(b/80000000): Update when token is not optional.
+ if (infeed->operand_count() == 1) {
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
+ }
+
+ // The output of infeed is a tuple containing the data value and a token.
+ return CheckShape(infeed,
+ ShapeUtil::MakeTupleShape(
+ {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
+}
+
+Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
+ HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
+ // Outfeed has an optional token operand (operand 1).
+ // TODO(b/80000000): Update when token is not optional.
+ if (outfeed->operand_count() == 2) {
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
+ }
-Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
// Outfeed has a separate shape field for the value which is outfed to the
- // host. The shape of the instruction itself is always nil because the outfeed
- // produces no HLO value in the graph.
+ // host. The shape of the instruction itself is always a token.
if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
outfeed->operand(0)->shape())) {
return InternalError(
- "Expected outfeed to have shape compatible with operand's shape %s, "
+ "Expected outfeed shape to be compatible with operand's shape %s, "
"actual shape is %s:\n%s",
ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(),
outfeed->ToString().c_str());
}
- return CheckShape(outfeed, ShapeUtil::MakeNil());
+ return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
@@ -137,7 +178,16 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
- return CheckUnaryShape(sort);
+ if (sort->operand_count() == 2 &&
+ !ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+ sort->operand(1)->shape())) {
+ return InternalError(
+ "Expected sort to have to have the same dimensions for the keys and "
+ "the values. Keys shape is: %s\n, Values shape is: %s",
+ ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(),
+ ShapeUtil::HumanString(sort->operand(1)->shape()).c_str());
+ }
+ return CheckVariadicShape(sort);
}
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
@@ -299,9 +349,11 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
const HloInstruction* send_done = send->users().front();
TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
- return CheckShape(
- send, ShapeUtil::MakeTupleShape(
- {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(send, 1));
+ return CheckShape(send,
+ ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
+ ShapeUtil::MakeShape(U32, {}),
+ ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
@@ -309,7 +361,8 @@ Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
const HloInstruction* send = send_done->operand(0);
TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
- return CheckShape(send_done, ShapeUtil::MakeNil());
+
+ return CheckShape(send_done, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
@@ -317,9 +370,11 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
const HloInstruction* recv_done = recv->users().front();
TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
- return CheckShape(recv,
- ShapeUtil::MakeTupleShape(
- {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(recv, 0));
+ return CheckShape(
+ recv, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::GetTupleElementShape(recv_done->shape(), 0),
+ ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
@@ -327,7 +382,9 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
const HloInstruction* recv = recv_done->operand(0);
TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
- return CheckShape(recv_done, recv->shape().tuple_shapes(0));
+ return CheckShape(recv_done,
+ ShapeUtil::MakeTupleShape({recv->shape().tuple_shapes(0),
+ ShapeUtil::MakeTokenShape()}));
}
Status ShapeVerifier::HandleBatchNormTraining(
@@ -386,6 +443,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kRecvDone:
case HloOpcode::kReducePrecision:
case HloOpcode::kSelect:
+ case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTuple:
@@ -426,6 +484,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
+Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : token->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
+ return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes));
+}
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
@@ -440,16 +506,10 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
switch (instruction->opcode()) {
- case HloOpcode::kSelect:
- if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
- // Select only defines the top-level buffer, which in this case is the
- // tuple, so we cannot allow mixed precision.
- compatible =
- ShapeUtil::Compatible(instruction->shape(), inferred_shape);
- } else {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
- instruction->shape(), inferred_shape);
- }
+ case HloOpcode::kTupleSelect:
+ // TupleSelect only defines the top-level buffer, which in this case is
+ // the tuple, so we cannot allow mixed precision.
+ compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
break;
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
@@ -777,8 +837,7 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
- if (!ShapeUtil::IsScalar(operand_shape) &&
- !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
+ if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
return FailedPrecondition(
"Implicit broadcast is not allowed in HLO."
"Found non-compatible shapes for instruction %s.\n"
@@ -791,6 +850,39 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal.
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ // Tokens cannot be passed as entry parameters.
+ // TODO(b/80000000): Remove this constraint.
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()).c_str());
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
@@ -832,7 +924,9 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
} else if (instruction->opcode() == HloOpcode::kWhile) {
TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
- } else if (instruction->IsElementwise()) {
+ } else if (instruction->opcode() !=
+ HloOpcode::kRng /* Rng operands are always scalar. */
+ && instruction->IsElementwise()) {
TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
}
@@ -851,6 +945,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
}
+ TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 1392a78097..12c047850e 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -35,6 +35,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleClamp(HloInstruction* clamp) override;
Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConvert(HloInstruction* convert) override;
Status HandleBitcastConvert(HloInstruction* convert) override;
@@ -81,6 +82,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
+ Status HandleAfterAll(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
index dc3bfce0c4..d7458c338e 100644
--- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
+++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc
@@ -169,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const {
StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_)));
}
}
+
+ if (total_bytes > 0) {
+ MetricTableReport table;
+ table.SetMetricName("MiB read+written");
+ table.SetEntryName("ops");
+ table.SetShowCategoryTable();
+ for (const auto& op : op_infos_) {
+ MetricTableReport::Entry entry;
+ entry.text = op.name;
+ entry.short_text = op.short_name;
+ entry.category_text = op.category;
+ entry.metric = static_cast<double>(op.bytes_accessed) / (1 << 20);
+ table.AddEntry(std::move(entry));
+ }
+ StrAppend(&s,
+ table.MakeReport(static_cast<double>(total_bytes) / (1 << 20)));
+ }
return s;
}
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
index 8c7b38dd1b..f85d31d522 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 8b3fa6c157..8b2df32567 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -28,6 +29,7 @@ namespace {
using Analysis = IndexedArrayAnalysis;
using UnknownArray = Analysis::UnknownArray;
using ConstantArray = Analysis::ConstantArray;
+using ReshapedArray = Analysis::ReshapedArray;
using ScalarIndexedArray = Analysis::ScalarIndexedArray;
using tensorflow::gtl::ArraySlice;
using tensorflow::str_util::Join;
@@ -52,6 +54,13 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
"(constant ", ShapeUtil::HumanString(root->shape()), ")");
}
+ case Array::kReshaped: {
+ ReshapedArray* reshaped_array = root->as<ReshapedArray>();
+ return tensorflow::strings::StrCat(
+ "(reshape ", ToString(reshaped_array->operand(), print_constants),
+ " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")");
+ }
+
case Array::kScalarIndexedConstant:
case Array::kScalarIndexed: {
auto* indexed_array = root->as<ScalarIndexedArray>();
@@ -152,6 +161,12 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
computed_array,
ComputeArrayForReshape(instr->shape(),
FindOrDie(cache_, instr->operand(0))));
+ } else if (instr->opcode() == HloOpcode::kDot) {
+ TF_ASSIGN_OR_RETURN(
+ computed_array,
+ ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(),
+ FindOrDie(cache_, instr->operand(0)),
+ FindOrDie(cache_, instr->operand(1))));
} else {
computed_array = nullptr;
}
@@ -239,15 +254,40 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
Array* indices) {
if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) {
+ VLOG(3) << "ComputeArrayForGather: indices are not scalar";
return nullptr;
}
CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1);
- if (!c_binary_search(dim_numbers.elided_window_dims(),
- dim_numbers.gather_dims_to_operand_dims(0))) {
+
+ // We can also handle dim_numbers.elided_window_dims_size() == 0 here, should
+ // it become relevant.
+
+ if (dim_numbers.elided_window_dims_size() != 1 ||
+ dim_numbers.elided_window_dims(0) !=
+ dim_numbers.gather_dims_to_operand_dims(0)) {
+ VLOG(3) << "ComputeArrayForGather: gather operations must elide "
+ "gather_dims_to_operand_dims[0] and "
+ "gather_dims_to_operand_dims[0] only";
return nullptr;
}
+ // ScalarIndexedArray cannot represent gathers that "slice" along some
+ // dimensions -- for instance it cannot represent a gather that picks 5 [2,3]
+ // arrays from an array of size [7,4,6]. We check that condition down below:
+
+ for (int64 i = 0, e = source->shape().dimensions_size(); i < e; i++) {
+ if (i != dim_numbers.elided_window_dims(0) &&
+ source->shape().dimensions(i) != window_bounds[i]) {
+ VLOG(3) << "ComputeArrayForGather: window_bounds[" << i
+ << "] != source->shape().dimensions(" << i << ") -- "
+ << source->shape().dimensions(i) << " vs. " << window_bounds[i]
+ << " with dim_numbers.elided_window_dims(0) = "
+ << dim_numbers.elided_window_dims(0);
+ return nullptr;
+ }
+ }
+
int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0);
std::vector<int64> output_dims;
for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
@@ -257,8 +297,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
}
if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) {
- auto it = c_find(indexed->output_dims(), source_dim);
- if (it != indexed->output_dims().end()) {
+ if (c_linear_search(indexed->output_dims(), source_dim)) {
return FoldGatherOfGather(indexed, indices, source_dim, output_dims,
shape);
}
@@ -336,7 +375,11 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
// result_subarray_size does not include the elements in the current
// `result_dim` dimension (we multiply in result_shape[result_dim] at the
// end of loop body) so candidate_operand_dim can never be zero.
- CHECK_NE(candidate_operand_dim, 0);
+ CHECK_NE(candidate_operand_dim, 0)
+ << "result_dim = " << result_dim
+ << ", result_subarray_size = " << result_subarray_size
+ << ", result_shape = [" << Join(result_shape, ",") << "]"
+ << ", operand_shape = [" << Join(operand_shape, ",") << "]";
if (candidate_operand_dim != -1 &&
result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) {
@@ -357,7 +400,7 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs(
});
VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to ["
<< Join(result_shape, ",") << "] passthrough indices are ["
- << Join(result_strings, ",") << "]";
+ << Join(result_strings, ",") << "] (legend: `result`->`operand`)";
}
DCHECK(c_is_sorted(
@@ -398,6 +441,10 @@ int64 MapPassthroughOperandDimToResultDim(
int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
ArraySlice<int64> result_shape,
int64 source_passthrough_dim) {
+ VLOG(3) << "FindSourcePositionForPassthroughResultDim(["
+ << Join(operand_shape, ",") << "], [" << Join(result_shape, ",")
+ << "], " << source_passthrough_dim << ")";
+
int64 indexed_source_subarray_size =
std::accumulate(operand_shape.begin() + source_passthrough_dim + 1,
operand_shape.end(), 1, std::multiplies<int64>());
@@ -405,15 +452,191 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
return FindSuffixWithProduct(result_shape, indexed_source_subarray_size);
}
+Shape StripDegenerateDimensions(const Shape& shape) {
+ DimensionVector new_dims;
+ c_copy_if(shape.dimensions(), std::back_inserter(new_dims),
+ [](int64 dim) { return dim != 1; });
+ return ShapeUtil::MakeShape(shape.element_type(), new_dims);
+}
}; // namespace
-StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
- const Shape& shape, Array* operand) {
- auto* scalar_indexed = dynamic_cast<ScalarIndexedConstantArray*>(operand);
- if (!scalar_indexed) {
+StatusOr<ScalarIndexedArray*>
+IndexedArrayAnalysis::ReshapeToRemoveDegenerateDims(
+ ScalarIndexedArray* operand) {
+ const Shape& shape = operand->shape();
+ if (!ShapeUtil::HasDegenerateDimensions(shape)) {
+ return operand;
+ }
+
+ // We only need to reshape out the degenerate dims from the indices and the
+ // source (except the source dim).
+
+ const Shape& source_shape = operand->source()->shape();
+ DimensionVector new_source_shape_dims;
+ for (int64 i = 0, e = source_shape.dimensions_size(); i < e; i++) {
+ if (i == operand->source_dim() || source_shape.dimensions(i) != 1) {
+ new_source_shape_dims.push_back(source_shape.dimensions(i));
+ }
+ }
+
+ Shape new_source_shape =
+ ShapeUtil::MakeShape(shape.element_type(), new_source_shape_dims);
+ Shape new_indices_shape =
+ StripDegenerateDimensions(operand->indices()->shape());
+
+ TF_ASSIGN_OR_RETURN(
+ Array* const new_source,
+ ComputeArrayForReshape(new_source_shape, operand->source()));
+ TF_ASSIGN_OR_RETURN(
+ Array* const new_indices,
+ ComputeArrayForReshape(new_indices_shape, operand->indices()));
+
+ // Build the new output dims while keeping track of the degenerate dims that
+ // will no longer be present.
+ DimensionVector new_output_dims;
+ int64 degenerate_dims_seen = 0;
+ for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
+ if (shape.dimensions(i) == 1) {
+ degenerate_dims_seen++;
+ } else if (ArrayContains(operand->output_dims(), i)) {
+ new_output_dims.push_back(i - degenerate_dims_seen);
+ }
+ }
+
+ // Similarly, build the new source dim while keeping track of the degenerate
+ // dims that will no longer be present.
+ int64 degenerate_dims_before_source_dim =
+ std::count(source_shape.dimensions().begin(),
+ source_shape.dimensions().begin() + operand->source_dim(), 1);
+ int64 new_source_dim =
+ operand->source_dim() - degenerate_dims_before_source_dim;
+
+ return ConstructScalarIndexedArray(
+ new_source, new_indices, new_source_dim,
+ InlinedVectorToVector(new_output_dims),
+ StripDegenerateDimensions(operand->shape()));
+}
+
+StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims(
+ ScalarIndexedArray* operand,
+ tensorflow::gtl::ArraySlice<int64> degenerate_dims) {
+ if (degenerate_dims.empty()) {
+ return operand;
+ }
+
+ CHECK(!ShapeUtil::HasDegenerateDimensions(operand->shape()));
+
+ DimensionVector new_output_dims = [&]() {
+ // To make things easy we use a "scratch" buffer of bools where the i'th
+ // element is true iff the i'th component of the result index is an output
+ // index.
+
+ gtl::InlinedVector<bool, 6> output_dims_bitvector(
+ operand->shape().dimensions_size());
+ for (int64 output_dim : operand->output_dims()) {
+ output_dims_bitvector[output_dim] = true;
+ }
+
+ for (int64 degenerate_dim : degenerate_dims) {
+ InsertAt(&output_dims_bitvector, degenerate_dim, false);
+ }
+
+ DimensionVector result;
+ result.reserve(operand->output_dims().size());
+ for (int64 i = 0, e = output_dims_bitvector.size(); i < e; i++) {
+ if (output_dims_bitvector[i]) {
+ result.push_back(i);
+ }
+ }
+
+ return result;
+ }();
+
+ DimensionVector new_result_shape_dims;
+ c_copy(operand->shape().dimensions(),
+ std::back_inserter(new_result_shape_dims));
+ for (int64 degenerate_dim : degenerate_dims) {
+ InsertAt(&new_result_shape_dims, degenerate_dim, 1);
+ }
+
+ DimensionVector new_source_shape_dims = new_result_shape_dims;
+ for (int64 output_dim : new_output_dims) {
+ EraseAt(&new_source_shape_dims, output_dim);
+ }
+
+ int64 new_source_dim = [&]() {
+ for (int i = 0, e = new_source_shape_dims.size(); i < e; i++) {
+ int64 non_degenerate_dims_seen = 0;
+ if (non_degenerate_dims_seen == operand->source_dim()) {
+ return i;
+ }
+ if (new_source_shape_dims[new_source_dim] != 1) {
+ non_degenerate_dims_seen++;
+ }
+ }
+ LOG(FATAL) << "Did not find source dim in " << ToString(operand);
+ }();
+
+ int64 source_dim_size =
+ operand->source()->shape().dimensions(operand->source_dim());
+ InsertAt(&new_source_shape_dims, /*index=*/new_source_dim,
+ /*value=*/source_dim_size);
+
+ Shape new_source_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
+ new_source_shape_dims);
+ Shape new_result_shape = ShapeUtil::MakeShape(operand->shape().element_type(),
+ new_result_shape_dims);
+
+ TF_ASSIGN_OR_RETURN(
+ Array* const new_source,
+ ComputeArrayForReshape(new_source_shape, operand->source()));
+ return ConstructScalarIndexedArray(
+ new_source, operand->indices(), new_source_dim,
+ InlinedVectorToVector(new_output_dims), new_result_shape);
+}
+
+StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldReshapeOfGather(
+ const Shape& shape, ScalarIndexedConstantArray* operand) {
+ VLOG(3) << "FoldReshapeOfGather(" << ToString(operand) << ")";
+
+ // To make things easier on ourselves, instead of directly trying to fold the
+ // reshape of `operand` to `shape`, we call
+ // `FoldReshapeOfGatherNoDegenerateDims` on shapes without degenerate dims and
+ // handle the degenerate dimensions here by inserting reshapes.
+
+ TF_ASSIGN_OR_RETURN(ScalarIndexedArray* const operand_without_degenerate_dims,
+ ReshapeToRemoveDegenerateDims(operand));
+
+ Shape output_shape_without_degenerate_dims = StripDegenerateDimensions(shape);
+ TF_ASSIGN_OR_RETURN(
+ ScalarIndexedArray* const folded_reshape_without_degenerate_dims,
+ FoldReshapeOfGatherNoDegenerateDims(
+ output_shape_without_degenerate_dims,
+ operand_without_degenerate_dims->as<ScalarIndexedConstantArray>()));
+
+ if (folded_reshape_without_degenerate_dims == nullptr) {
return nullptr;
}
+ DimensionVector degenerate_result_dims;
+ for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) {
+ if (shape.dimensions(i) == 1) {
+ degenerate_result_dims.push_back(i);
+ }
+ }
+
+ return ReshapeToAddDegenerateDims(folded_reshape_without_degenerate_dims,
+ degenerate_result_dims);
+}
+
+StatusOr<ScalarIndexedArray*>
+IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims(
+ const Shape& shape, ScalarIndexedConstantArray* scalar_indexed) {
+ VLOG(3) << "FoldReshapeOfGatherNoDegenerateDims(" << ToString(scalar_indexed)
+ << ")";
+ CHECK(!ShapeUtil::HasDegenerateDimensions(shape));
+ CHECK(!ShapeUtil::HasDegenerateDimensions(scalar_indexed->shape()));
+
// Try to fold Reshape(ScalarIndexed(Const, Indices))
// => ScalarIndexed(Const', Indices)
//
@@ -464,7 +687,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
std::vector<ReshapePassthroughDimPair> reshape_passthrough_dims =
ComputeReshapePassthroughDimPairs(
- /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()),
+ /*operand_shape=*/AsInt64Slice(scalar_indexed->shape().dimensions()),
/*result_shape=*/AsInt64Slice(shape.dimensions()));
auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) {
@@ -474,6 +697,8 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
if (!c_all_of(scalar_indexed->output_dims(),
is_reshape_passthrough_operand_dim)) {
+ VLOG(3) << "Not all output dims are passthrough dims "
+ << ToString(scalar_indexed);
return nullptr;
}
@@ -527,6 +752,11 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
// (a.k.a. isn't pass-through) than the [3,5,2] array.
if (source_dim_for_new_scalar_indexed_node == -1) {
+ VLOG(3) << "Could not compute the source dim for the new scalar indexed "
+ "node: scalar_indexed_source_shape = ["
+ << Join(scalar_indexed_source_shape.dimensions(), ",")
+ << "] and new_scalar_indexed_source_shape = ["
+ << Join(new_scalar_indexed_source_shape, ",") << "]";
return nullptr;
}
@@ -534,6 +764,10 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
&new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node,
scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim()));
+ CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1l,
+ std::multiplies<int64>()),
+ ShapeUtil::ElementsIn(scalar_indexed_source_shape));
+
CHECK(IsReshapePassthroughOperandDim(
ComputeReshapePassthroughDimPairs(
/*operand_shape=*/AsInt64Slice(
@@ -564,6 +798,31 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
output_dims_for_new_scalar_indexed_node, shape);
}
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
+ const Shape& shape, Array* operand) {
+ if (ShapeUtil::Compatible(operand->shape(), shape)) {
+ return operand;
+ }
+
+ if (auto* scalar_indexed =
+ dynamic_cast<ScalarIndexedConstantArray*>(operand)) {
+ TF_ASSIGN_OR_RETURN(Analysis::Array * reshape_folded_into_gather,
+ FoldReshapeOfGather(shape, scalar_indexed));
+ if (reshape_folded_into_gather) {
+ return reshape_folded_into_gather;
+ }
+ }
+
+ if (auto* constant_array = dynamic_cast<ConstantArray*>(operand)) {
+ TF_ASSIGN_OR_RETURN(Literal* const new_literal,
+ TakeOwnership(constant_array->literal()->Reshape(
+ AsInt64Slice(shape.dimensions()))));
+ return Construct<ConstantArray>(new_literal);
+ }
+
+ return Construct<ReshapedArray>(operand, shape);
+}
+
StatusOr<Analysis::Array*>
IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
Array* lhs,
@@ -703,11 +962,177 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode,
return Construct<ScalarIndexedConstantArray>(
new_source, scalar_indexed_const->indices(),
scalar_indexed_const->source_dim(),
- std::vector<int64>(scalar_indexed_const->output_dims().begin(),
- scalar_indexed_const->output_dims().end()),
+ ArraySliceToVector(scalar_indexed_const->output_dims()),
scalar_indexed_const->shape());
}
+namespace {
+
+// Returns the non-contracting non-batch dimension (as per `contracting_dims`
+// and `batch_dims`) if there is exactly one, otherwise returns nullopt.
+gtl::optional<int64> GetOnlyNonContractingNonBatchDim(
+ int64 rank, ArraySlice<int64> contracting_dims,
+ ArraySlice<int64> batch_dims) {
+ gtl::optional<int64> result;
+ for (int64 dim = 0; dim < rank; dim++) {
+ if (!ArrayContains(contracting_dims, dim) &&
+ !ArrayContains(batch_dims, dim)) {
+ if (result.has_value()) {
+ return gtl::nullopt;
+ }
+ result = dim;
+ }
+ }
+ return result;
+}
+
+// Returns true if `indexed_array`, which is either the LHS or the RHS of a Dot
+// HLO, can be folded into the dot operation. For now these conditions are both
+// necessary and sufficient.
+//
+// `tag` describes the caller. Used only for logging.
+//
+// `contracting_dims` and `batch_dims` are the contracting and batch dimensions
+// of whatever operand `indexed_array` is to the dot (LHS or RHS).
+bool CanFoldDotIntoIndexedArray(
+ tensorflow::StringPiece tag,
+ Analysis::ScalarIndexedConstantArray* indexed_array,
+ ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) {
+ gtl::optional<int64> non_contracting_non_batch_dim =
+ GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()),
+ contracting_dims, batch_dims);
+ if (!non_contracting_non_batch_dim.has_value()) {
+ VLOG(3) << tag << ": multiple or no non-contracting non-batch dimensions";
+ return false;
+ }
+
+ if (indexed_array->output_dims().size() != 1 ||
+ indexed_array->output_dims()[0] != *non_contracting_non_batch_dim) {
+ VLOG(3) << tag << ": output dims != the lhs non-contracting non-batch dim";
+ return false;
+ }
+
+ int64 indexed_array_rank = ShapeUtil::Rank(indexed_array->shape());
+ if (indexed_array->source_dim() < (indexed_array_rank - 2)) {
+ // This restriction can be lifted by inserting reshape nodes.
+ VLOG(3) << tag
+ << ": source dim is not in the low two dims, won't be able to form "
+ "a matmul";
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ScalarIndexedConstantArray* lhs, ConstantArray* rhs) {
+ VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " "
+ << ToString(rhs);
+ if (!CanFoldDotIntoIndexedArray(
+ "ComputeArrayForDotWithIndexedLhs", lhs, /*contracting_dims=*/
+ AsInt64Slice(dim_numbers.lhs_contracting_dimensions()),
+ /*batch_dims=*/AsInt64Slice(dim_numbers.lhs_batch_dimensions()))) {
+ return nullptr;
+ }
+
+ int64 lhs_rank = ShapeUtil::Rank(lhs->shape());
+ DotDimensionNumbers new_dim_numbers = dim_numbers;
+ new_dim_numbers.set_lhs_contracting_dimensions(
+ 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1));
+
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, lhs->literal(), *rhs->literal())));
+
+ // The new source dimension is wherever the non-batch non-contracting LHS
+ // dimension "went".
+ int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
+ dim_numbers.rhs_batch_dimensions_size();
+
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, lhs->indices(), new_source_dim,
+ ArraySliceToVector(lhs->output_dims()), shape);
+}
+
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ConstantArray* lhs, ScalarIndexedConstantArray* rhs) {
+ VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " "
+ << ToString(rhs);
+ if (!CanFoldDotIntoIndexedArray(
+ "ComputeArrayForDotWithIndexedRhs", rhs, /*contracting_dims=*/
+ AsInt64Slice(dim_numbers.rhs_contracting_dimensions()),
+ /*batch_dims=*/AsInt64Slice(dim_numbers.rhs_batch_dimensions()))) {
+ return nullptr;
+ }
+
+ int64 rhs_rank = ShapeUtil::Rank(rhs->shape());
+
+ DotDimensionNumbers new_dim_numbers = dim_numbers;
+ new_dim_numbers.set_rhs_contracting_dimensions(
+ 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1));
+
+ TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+ TakeOwnership(HloEvaluator{}.EvaluateDotOp(
+ new_dim_numbers, *lhs->literal(), rhs->literal())));
+
+ // The new source dimension is wherever the non-batch non-contracting RHS
+ // dimension "went".
+ int64 new_source_dim = dim_numbers.lhs_batch_dimensions_size() +
+ dim_numbers.rhs_batch_dimensions_size() + 1;
+
+ ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
+ return Construct<ScalarIndexedConstantArray>(
+ new_source, rhs->indices(), new_source_dim,
+ ArraySliceToVector(rhs->output_dims()), shape);
+}
+
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs,
+ Array* rhs) {
+ // Intuitively, if
+ //
+ // - The LHS of a dot product is a gathered sequence of rows from a constant
+ // array (i.e. LHS[I,J] = Const[Indices[I],J]) and the RHS is a constant
+ //
+ // OR
+ //
+ // - If the RHS of a dot product is a gathered sequence of columns from a
+ // constant array (i.e. RHS[I,J] = Const[I, Indices[J]]) and the LHS is a
+ // constant
+ //
+ // then the result of the dot product itself is a gather from a constant
+ // array. E.g. Dot(LHS, ConstRhs) where LHS[I,J] = Const[Indices[I],J] can be
+ // rewritten as Result where Result[I,J] = Dot(Const, ConstRhs)[Indices[I],
+ // J].
+ //
+ // We do a general version of this rewrite here.
+ VLOG(3) << "ComputeArrayForDot(" << ToString(lhs) << " " << ToString(rhs);
+ if (auto* lhs_indexed_array =
+ dynamic_cast<ScalarIndexedConstantArray*>(lhs)) {
+ if (auto* rhs_constant = dynamic_cast<ConstantArray*>(rhs)) {
+ return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers,
+ lhs_indexed_array, rhs_constant);
+ }
+ }
+
+ if (auto* rhs_indexed_array =
+ dynamic_cast<ScalarIndexedConstantArray*>(rhs)) {
+ if (auto* lhs_constant = dynamic_cast<ConstantArray*>(lhs)) {
+ return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant,
+ rhs_indexed_array);
+ }
+ }
+
+ return nullptr;
+}
+
tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const {
return "indexed-array-analysis-printer-pass";
}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index ce92fd2919..e923dc39f7 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -39,7 +39,13 @@ class IndexedArrayAnalysis {
// Array instances are immutable once created.
class Array {
public:
- enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed };
+ enum Kind {
+ kUnknown,
+ kConstant,
+ kReshaped,
+ kScalarIndexedConstant,
+ kScalarIndexed
+ };
virtual Kind kind() const = 0;
virtual const Shape& shape() const = 0;
@@ -96,6 +102,27 @@ class IndexedArrayAnalysis {
friend class IndexedArrayAnalysis;
};
+ // Represents an Array that is a reshape of another Array.
+ class ReshapedArray : public Array {
+ public:
+ Kind kind() const override { return kReshaped; }
+
+ // The array to reshape.
+ Array* operand() const { return operand_; }
+
+ // The output shape.
+ const Shape& shape() const override { return shape_; }
+
+ private:
+ explicit ReshapedArray(Array* operand, Shape shape)
+ : operand_(operand), shape_(shape) {}
+
+ Array* operand_;
+ const Shape shape_;
+
+ friend class IndexedArrayAnalysis;
+ };
+
// ---------------------------------------------------------------------------
// Indexed Array Overview
// ---------------------------------------------------------------------------
@@ -241,6 +268,18 @@ class IndexedArrayAnalysis {
tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
Array* indices);
+ StatusOr<Array*> ComputeArrayForDotWithIndexedLhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ScalarIndexedConstantArray* lhs, ConstantArray* rhs);
+
+ StatusOr<Array*> ComputeArrayForDotWithIndexedRhs(
+ const Shape& shape, const DotDimensionNumbers& dim_numbers,
+ ConstantArray* lhs, ScalarIndexedConstantArray* rhs);
+
+ StatusOr<Array*> ComputeArrayForDot(const Shape& shape,
+ const DotDimensionNumbers& dim_numbers,
+ Array* lhs, Array* rhs);
+
// This tries to fold a ScalarIndexedArray which has another
// ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
// ScalarIndexedArray as indices. If `source` happened to be a
@@ -266,6 +305,21 @@ class IndexedArrayAnalysis {
ScalarIndexedArray* source, Array* indices, int64 source_dim,
tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+ // Reshapes a scalar-indexed node to remove the degenerate dimensions in its
+ // output. The result is always a scalar-indexed node.
+ StatusOr<ScalarIndexedArray*> ReshapeToRemoveDegenerateDims(
+ ScalarIndexedArray* operand);
+
+ // Reshapes a scalar-indexed node such that the result has the degenerate
+ // dimensions `degenerate_dims`. The result is always a scalar-indexed node.
+ StatusOr<ScalarIndexedArray*> ReshapeToAddDegenerateDims(
+ ScalarIndexedArray* operand,
+ tensorflow::gtl::ArraySlice<int64> degenerate_dims);
+
+ StatusOr<ScalarIndexedArray*> FoldReshapeOfGather(
+ const Shape& shape, ScalarIndexedConstantArray* operand);
+ StatusOr<ScalarIndexedArray*> FoldReshapeOfGatherNoDegenerateDims(
+ const Shape& shape, ScalarIndexedConstantArray* scalar_indexed);
StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(HloOpcode opcode,
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 373556ebeb..5f4b42799b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <ctype.h>
+
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -34,6 +36,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
}
private:
+ // Replaces seqences of whitespace with a single space. This makes the
+ // strings being matched against "whitespace insensitive" which lets us indent
+ // them for readability.
+ string CanonicalizeWhitespace(const string& text) {
+ string result;
+
+ for (char c : text) {
+ if (!isspace(c)) {
+ result.push_back(c);
+ } else if (!result.empty() && result.back() != ' ') {
+ result.push_back(' ');
+ }
+ }
+
+ while (!result.empty() && result.back() == ' ') {
+ result.pop_back();
+ }
+
+ return result;
+ }
+
void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
const string& root_expression,
bool print_constants) {
@@ -44,10 +67,10 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
IndexedArrayAnalysis::Array* const array_result,
indexed_tensor_analysis.GetArrayFor(
module().entry_computation()->root_instruction()));
- string string_result =
- indexed_tensor_analysis.ToString(array_result, print_constants);
+ string string_result = CanonicalizeWhitespace(
+ indexed_tensor_analysis.ToString(array_result, print_constants));
LOG(INFO) << string_result;
- ASSERT_EQ(string_result, root_expression);
+ ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression));
}
};
@@ -91,6 +114,82 @@ ENTRY main {
hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])");
}
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
+ indices = s32[5,2] parameter(0)
+ ROOT gather = s32[5] gather(operand, indices),
+ output_window_dims={},
+ elided_window_dims={0,1},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3,1] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,3] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0,2},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,3,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3,1] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,2,3] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={2},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={2,3,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,2] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,2}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
string hlo_text = R"(
HloModule SimpleGather
@@ -273,7 +372,157 @@ ENTRY main {
"(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])");
}
-TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) {
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,6] constant(s32[2,6]{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,6] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,6])
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
+
+ i.0 = s64[1,3]{1,0} parameter(0)
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
+ elided_window_dims={0}, gather_dims_to_operand_dims={0},
+ index_vector_dim=2, window_bounds={1,3}
+
+ i.1 = s64[1] parameter(1)
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
+ elided_window_dims={1}, gather_dims_to_operand_dims={1},
+ index_vector_dim=1, window_bounds={1,1,3}
+
+ ROOT reshape = s32[1,3]{1,0} reshape(g.1)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,3])
+ (reshape
+ (scalar-indexed %i.0 %i.1 1->[1])
+ to s64[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,6] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[1,1,1,6])
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[1,2,6] constant(s32[1,2,6]{{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,1,6] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={1,1,6}
+ ROOT reshape = s32[1,1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,1,6] s32[2,1,1,1,6] {
+ { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } },
+ { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } })
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text,
+ expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,6] constant(s32[2,6]{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}})
+ indices = s32[1,5] parameter(0)
+ gather = s32[1,5,6] gather(operand, indices),
+ output_window_dims={2},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,5,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,6] s32[2,1,1,6] {
+ { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } },
+ { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } })
+ (reshape %indices to s32[5])
+ 0->[2])
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text,
+ expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) {
string hlo_text = R"(
HloModule ReshapeOfGather
@@ -290,10 +539,19 @@ ENTRY main {
}
)";
- AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,4])
+ %indices
+ 0->[0,2])
+ to s32[5,2,2,2,3])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
-TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) {
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) {
string hlo_text = R"(
HloModule ReshapeOfGather
@@ -313,7 +571,48 @@ ENTRY main {
}
)";
- AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,5,2])
+ %indices
+ 1->[2])
+ to s32[6,7])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,4,1] constant(s32[3,4,1]{
+ {{1},{2},{3},{4}},
+ {{1},{2},{3},{4}},
+ {{1},{2},{3},{4}}})
+ indices = s32[5,6] parameter(0)
+ gather = s32[5,4,6,1] gather(operand, indices),
+ output_window_dims={1,3},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,4,1}
+ ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,4,1])
+ %indices
+ 0->[0,2])
+ to s32[5,2,2,2,3,1])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {
@@ -500,5 +799,170 @@ ENTRY main {
AssertArrayForRootExpressionIs(hlo_text, "%add");
}
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_0) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_lhs = s32[5,4] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,4}
+ ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[3,3] s32[3,3] {
+ { 70, 80, 90 },
+ { 158, 184, 210 },
+ { 246, 288, 330 } })
+ %indices 0->[0]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_1) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[3,3] constant(s32[3,3]{{1,2,3},{4,5,6},{7,8,9}})
+ indices = s32[5] parameter(0)
+ dot_lhs = s32[3,5] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[5,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={0}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,3] s32[4,3] {
+ { 84, 99, 114 },
+ { 96, 114, 132 },
+ { 108, 129, 150 },
+ { 120, 144, 168 } })
+ %indices 0->[1]))");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_2) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_rhs = s32[3,5] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,4] s32[4,4] {
+ { 38, 44, 50, 56 },
+ { 83, 98, 113, 128 },
+ { 128, 152, 176, 200 },
+ { 173, 206, 239, 272 } })
+ %indices 1->[1])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpBasic_3) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ dot_lhs_constant = s32[4,3] constant(s32[4,3]{{1,2,3},{4,5,6},{7,8,9},{10,11,12}})
+ indices = s32[5] parameter(0)
+ dot_rhs = s32[5,3] gather(gather_operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,3}
+ ROOT dot = s32[4,5] dot(dot_lhs_constant, dot_rhs), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[4,4] s32[4,4] {
+ { 14, 32, 50, 68 },
+ { 32, 77, 122, 167 },
+ { 50, 122, 194, 266 },
+ { 68, 167, 266, 365 } })
+ %indices 1->[0])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpWithBatch) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[2,3,2] constant(s32[2,3,2]{{{1,2},{3,4},{5,6}},{{7,8},{9,10},{11,12}}})
+ dot_lhs_constant = s32[2,2,3] constant(s32[2,2,3]{{{1,2,3},{4,5,6}},{{7,8,9},{10,11,12}}})
+ indices = s32[4] parameter(0)
+ dot_rhs = s32[2,3,4] gather(gather_operand, indices),
+ output_window_dims={0,1},
+ elided_window_dims={2},
+ gather_dims_to_operand_dims={2},
+ index_vector_dim=1,
+ window_bounds={2,3,1}
+ ROOT dot = s32[2,2,4] dot(dot_lhs_constant, dot_rhs),
+ lhs_contracting_dims={2}, rhs_contracting_dims={1},
+ lhs_batch_dims={0}, rhs_batch_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, R"(
+(scalar-indexed-const
+ (constant s32[2,2,2] s32[2,2,2] {
+ { { 22, 28 },
+ { 49, 64 } },
+ { { 220, 244 },
+ { 301, 334 } } })
+ %indices 3->[2])
+)");
+}
+
+TEST_F(IndexedArrayAnalysisTest, DotOpNegative) {
+ string hlo_text = R"(
+HloModule DotOp
+
+ENTRY main {
+ gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{5,6,7,8},{9,10,11,12}})
+ dot_rhs_constant = s32[2,3] constant(s32[2,3]{{1,2,3},{4,5,6}})
+ indices = s32[2] parameter(0)
+ dot_lhs = s32[3,2] gather(gather_operand, indices),
+ output_window_dims={0},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={3,1}
+ ROOT dot = s32[3,3] dot(dot_lhs, dot_rhs_constant), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text, "%dot");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index d2af261008..32937b33b3 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) {
auto max_f32 = max_builder.Build();
auto builder = HloComputation::Builder("MapMaxFunction");
- auto lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
- auto rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1})));
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
@@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
+ auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
@@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) {
HloInstruction::CreateParameter(0, r0f32, "x"));
(void)param1;
const2_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
auto const2_f32 = const2_builder.Build();
auto builder = HloComputation::Builder("MapConstFunction");
auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
@@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
@@ -123,10 +123,10 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
auto max_f32 = max_builder.Build();
auto builder = HloComputation::Builder("MapSubFunction");
- auto lhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1, 2, 3, 4})));
- auto rhs = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({4, 3, 2, 1})));
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
builder.AddInstruction(
HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
@@ -142,7 +142,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
+ auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 429c850343..da91262130 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -83,6 +83,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kNegate:
case HloOpcode::kNot:
case HloOpcode::kOr:
+ case HloOpcode::kXor:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
case HloOpcode::kReal:
@@ -96,8 +97,10 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
+ case HloOpcode::kAfterAll:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
+ case HloOpcode::kTupleSelect:
return false;
// Cheap instructions for reals, but expensive for complex.
@@ -236,6 +239,30 @@ InstructionFusion::ComputeGloballyUnfusable(
if (EffectivelyAtMostUnary(producer)) {
continue;
}
+
+ // If the total size of the inputs is less than or equal to the total size
+ // of the outputs for the producer then duplicating it won't increase the
+ // memory traffic. In that case, we do not forbid fusion of the operation
+ // here.
+ auto total_size = [](const Shape& shape) {
+ int64 size = 0;
+ ShapeUtil::ForEachSubshape(
+ shape,
+ [&size](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ size += ShapeUtil::ElementsIn(subshape);
+ }
+ });
+ return size;
+ };
+ int64 operands_size = 0;
+ for (const HloInstruction* op : producer->operands()) {
+ operands_size += total_size(op->shape());
+ }
+ if (operands_size <= total_size(producer->shape())) {
+ continue;
+ }
+
// Otherwise we will forbid fusing the op unless we can fuse it into
// all of its consumers on all paths.
//
@@ -280,10 +307,8 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// map from HloInstruction* to the instruction's index in the vector. An
// instruction is "removed" from the vector by setting it's element to
// nullptr.
- std::list<HloInstruction*> post_order_list =
+ std::vector<HloInstruction*> post_order =
computation_->MakeInstructionPostOrder();
- std::vector<HloInstruction*> post_order(post_order_list.begin(),
- post_order_list.end());
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
for (size_t i = 0; i < post_order.size(); ++i) {
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index df109df787..9e7a15f033 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
@@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion {
};
TEST_F(InstructionFusionTest, FuseInstructions) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) {
}
TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
fused_computation {
p1 = f32[4,3] parameter(0)
@@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
}
TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@@ -167,7 +167,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
@@ -195,7 +196,7 @@ static int Count(const HloModule& module, HloOpcode op) {
}
TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -220,7 +221,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
//
// p0 -> add -------------------------> sub
// \-> abs1 -> rng -> abs2 -/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -251,14 +252,15 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// p0 -> add -------------------------> sub
// \-> abs1 -> log -> abs2 -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
add = f32[4,3]{1,0} add(p0, p0)
abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})")
@@ -282,13 +284,14 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// \ \-> add2 -/
// \-> log -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
add1 = f32[4,3]{1,0} add(p0, p0)
log = f32[4,3]{1,0} log(p0)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})")
@@ -314,14 +317,15 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// \------> sub1
// log -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
add1 = f32[4,3]{1,0} add(p0, p0)
add2 = f32[4,3]{1,0} add(add1, add1)
log = f32[4,3]{1,0} log(add2)
- send = f32[4,3]{1,0} send(log), channel_id=0
+ token = token[] after-all()
+ send = f32[4,3]{1,0} send(log, token), channel_id=0
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
@@ -352,7 +356,8 @@ TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
HloInstruction* unary1 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kFloor, param0));
- builder.AddInstruction(HloInstruction::CreateSend(unary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ builder.AddInstruction(HloInstruction::CreateSend(unary1, token, 0));
HloInstruction* unary2 = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, unary1));
@@ -375,7 +380,8 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
HloInstruction* binary1 = builder.AddInstruction(
HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
- builder.AddInstruction(HloInstruction::CreateSend(binary1, 0));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ builder.AddInstruction(HloInstruction::CreateSend(binary1, token, 0));
HloInstruction* unary = builder.AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kAbs, binary1));
@@ -390,7 +396,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
TEST_F(InstructionFusionTest,
WideningConvertsAreAlwaysDuplicableIntoConsumers) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY Test {
p0 = f16[100] parameter(0)
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 524d3234eb..8652599dc6 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -74,7 +74,7 @@ cc_library(
hdrs = ["executable.h"],
deps = [
":executor",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index c166653068..9f8f4bda87 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -44,7 +44,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<LayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout());
+ hlo_module->mutable_entry_computation_layout());
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 029e71058a..8d40c08d55 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -75,9 +75,9 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
// consumes.
std::vector<std::unique_ptr<Literal>> arg_literals;
for (int64 p = 0; p < computation->num_parameters(); ++p) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> arg_literal,
- transfer_manager->TransferLiteralFromDevice(executor, *arguments[p]));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> arg_literal,
+ transfer_manager->TransferLiteralFromDevice(
+ run_options->stream(), *arguments[p]));
arg_literals.push_back(std::move(arg_literal));
}
@@ -96,7 +96,7 @@ StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
result_literal->shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
- executor, *result_literal, result));
+ run_options->stream(), *result_literal, result));
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc
index 97e9fa2c8e..4fb67bd0b7 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executor.cc
@@ -53,6 +53,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, void *host_dst,
AsExecutorStream(stream)->EnqueueTask([this, host_dst, dev_src, size]() {
port::Status ok = SynchronousMemcpy(host_dst, dev_src, size);
});
+ AsExecutorStream(stream)->BlockUntilDone();
return true;
}
@@ -61,6 +62,7 @@ bool XlaInterpreterExecutor::Memcpy(Stream *stream, DeviceMemoryBase *dev_dst,
AsExecutorStream(stream)->EnqueueTask([this, dev_dst, host_src, size]() {
port::Status ok = SynchronousMemcpy(dev_dst, host_src, size);
});
+ AsExecutorStream(stream)->BlockUntilDone();
return true;
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 7067b6f86a..fedc83c8f8 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -175,41 +175,32 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
TF_RETURN_IF_ERROR(
LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
- const BufferLayoutConstraint* curr_constraint =
- GetBufferLayoutConstraint(buffer);
- if (curr_constraint != nullptr) {
- if (LayoutUtil::Equal(curr_constraint->layout(), layout)) {
+ auto iter = buffer_constraints_.find(&buffer);
+ if (iter != buffer_constraints_.end()) {
+ const BufferLayoutConstraint& curr_constraint = iter->second;
+ if (LayoutUtil::Equal(curr_constraint.layout(), layout)) {
// New constraint matches existing constraint. Nothing to do.
return Status::OK();
}
- if (curr_constraint->mandatory()) {
+ if (curr_constraint.mandatory()) {
return FailedPrecondition(
"Buffer %s already has the layout constraint %s, cannot add "
"incompatible constraint %s",
buffer.ToString().c_str(),
- LayoutUtil::HumanString(curr_constraint->layout()).c_str(),
+ LayoutUtil::HumanString(curr_constraint.layout()).c_str(),
LayoutUtil::HumanString(layout).c_str());
}
- }
-
- auto iter = buffer_constraints_.find(&buffer);
- bool overwrite = iter != buffer_constraints_.end();
- if (!overwrite) {
+ iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
+ } else {
+ TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1)
+ << buffer.ToString();
iter = buffer_constraints_
.insert(std::make_pair(
&buffer,
BufferLayoutConstraint(layout, buffer, mandatory, dfs)))
.first;
- } else {
- iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs);
}
added_constraints_.push_back(&iter->second);
-
- // Remove buffer from the set of unconstrained buffers.
- TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) ==
- static_cast<int>(!overwrite));
- unconstrained_buffer_ids_.erase(buffer.id());
-
return Status::OK();
}
@@ -716,7 +707,8 @@ Status CheckParameterLayout(HloInstruction* parameter,
const ComputationLayout& computation_layout) {
const ShapeLayout& parameter_layout =
computation_layout.parameter_layout(parameter->parameter_number());
- if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
+ if (parameter_layout.LayoutIsSet() &&
+ !parameter_layout.MatchesLayoutInShape(parameter->shape())) {
return InternalError(
"parameter instruction %s does not match layout of computation "
"shape: %s",
@@ -936,14 +928,15 @@ LayoutAssignment::LayoutAssignment(
ComputationLayout* entry_computation_layout,
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
+ saved_entry_computation_layout_(*entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
+ if (channel_layout_constraints_ != nullptr) {
+ // Save a copy of the input ChannelLayoutConstraints so that we can reset it
+ // if we have to undo previous operations (ClearPreviousPassSideEffects()).
+ channel_constraints_ = *channel_layout_constraints_;
+ }
VLOG(1) << "Entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
- // Layouts of all parameter instructions must be set.
- for (const ShapeLayout& parameter_layout :
- entry_computation_layout_->parameter_layouts()) {
- CHECK(parameter_layout.LayoutIsSet());
- }
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1572,6 +1565,13 @@ Status LayoutAssignment::RunOnComputation(
// Propagates layouts from mandatory and backend constraints.
TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+ // Prior to applying default layouts, we take note of all HLO instructions
+ // which lack a layout constraint.
+ for (LogicalBuffer::Id buffer_id : constraints.unconstrained_buffer_ids()) {
+ unconstrained_layout_instructions_.insert(
+ points_to_analysis.GetBuffer(buffer_id).instruction());
+ }
+
// While any unconstrained buffers remain, pick an arbitrary buffer, give it a
// layout and propagate the change.
while (!constraints.unconstrained_buffer_ids().empty()) {
@@ -1614,13 +1614,58 @@ Status LayoutAssignment::RunOnComputation(
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
+ if (channel_constraints != nullptr) {
+ TF_RETURN_IF_ERROR(
+ ConstrainChannelLayouts(computation, channel_constraints));
+ }
+ return Status::OK();
+}
+
+Status LayoutAssignment::ConstrainChannelLayouts(
+ HloComputation* computation,
+ ChannelLayoutConstraints* channel_constraints) {
+ // We go through the kRecvDone before. These must either impose their layout,
+ // of find a matching one already existing (ConstrainChannel() returns
+ // nullptr).
for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kRecvDone) {
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(),
+ ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
+ TF_RET_CHECK(layout == nullptr)
+ << instruction->ToString()
+ << " cannot constrain layout as it was set to "
+ << LayoutUtil::HumanString(*layout);
+ }
+ }
+ // After that we go through the kSend. These are likely going to have a kCopy
+ // as operand (otherwise we add it), so in case the constrained layout does
+ // not match, we can change the kCopy layout (and the kSend one as well).
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
- channel_constraints->ConstrainChannel(
- instruction->channel_id(), instruction->operand(0)->shape().layout());
- } else if (instruction->opcode() == HloOpcode::kRecvDone) {
- channel_constraints->ConstrainChannel(instruction->channel_id(),
- instruction->shape().layout());
+ HloInstruction* operand = instruction->mutable_operand(0);
+ const Layout* layout = channel_constraints->ConstrainChannel(
+ instruction->channel_id(), operand->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the kSend wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ Shape* send_shape =
+ ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
+ *send_shape = shape;
+ }
}
}
return Status::OK();
@@ -1672,13 +1717,14 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
// when seen from an outer instruction, which has across-computation
// constraints to impose.
// For example, the kWhile instruction needs to enforce the same layouts for
- // the parameters and root of the bosy, as well as the condition parameters.
+ // the parameters and root of the body, as well as the condition parameters.
// Similarly, the kConditional instruction needs to enforce the same layouts
// for the root of the true and false computations.
// So in the first pass, while allowing the layouts to flow to parameters and
// root, we also fix up the eventually inconsistent ComputationLayout, which
// will be then made mandatory by the second pass.
for (int64 i = 0; i < 2; ++i) {
+ VLOG(5) << "Running " << (i == 0 ? "un" : "") << "constrained pass";
TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
TF_ASSIGN_OR_RETURN(auto points_to_analysis,
TuplePointsToAnalysis::Run(module));
@@ -1716,10 +1762,12 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
Status LayoutAssignment::Init() {
computation_layouts_.clear();
+ *entry_computation_layout_ = saved_entry_computation_layout_;
return Status::OK();
}
Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
+ VLOG(5) << "Clearing previous side effects";
// Clear all the copies which have been added, and all the related
// instructions (like GTE and tuples).
int64 removed_copies = 0;
@@ -1737,12 +1785,14 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
}
}
added_copies_.clear();
+ unconstrained_layout_instructions_.clear();
if (removed_copies > 0) {
TupleSimplifier tuple_simplifier;
HloDCE dce;
TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
TF_RETURN_IF_ERROR(dce.Run(module).status());
}
+ ResetChannelConstraints();
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index c287cca0c5..b75ecb311a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -249,25 +249,30 @@ class ChannelLayoutConstraints {
// Given `shape`, apply the layout for `channel_id`. `channel_id` must already
// be constrained.
Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const {
- CHECK(IsChannelConstrained(channel_id));
- *shape.mutable_layout() = constraints_.at(channel_id);
+ auto it = constraints_.find(channel_id);
+ CHECK(it != constraints_.end()) << "Channel " << channel_id;
+ *shape.mutable_layout() = it->second;
return shape;
}
// Returns the layout constraint for `channel_id`, which must already be
// constrained.
- Layout LayoutForChannel(int64 channel_id) const {
- CHECK(IsChannelConstrained(channel_id));
- return constraints_.at(channel_id);
+ const Layout& LayoutForChannel(int64 channel_id) const {
+ auto it = constraints_.find(channel_id);
+ CHECK(it != constraints_.end()) << "Channel " << channel_id;
+ return it->second;
}
// Adds a new layout constraint for `channel_id`. If a constraint for
- // `channel_id` already exists, this operation requires that the new layout is
- // the same as the previously constrained layout.
- void ConstrainChannel(int64 channel_id, const Layout& layout) {
- CHECK(!IsChannelConstrained(channel_id) ||
- LayoutUtil::Equal(layout, constraints_[channel_id]));
- constraints_[channel_id] = layout;
+ // `channel_id` has been added, this API returns nullptr, otherwise returns
+ // the layout which has already been set for the channel.
+ const Layout* ConstrainChannel(int64 channel_id, const Layout& layout) {
+ auto it = constraints_.emplace(std::make_pair(channel_id, layout));
+ if (it.second) {
+ return nullptr;
+ }
+ return LayoutUtil::Equal(layout, it.first->second) ? nullptr
+ : &it.first->second;
}
private:
@@ -427,8 +432,13 @@ class LayoutAssignment : public HloPassInterface {
Status PropagateComputationLayouts(HloComputation* computation,
ComputationLayout* computation_layout);
+ // The pointer to the ComputationLayout passed as constructor parameter.
ComputationLayout* entry_computation_layout_;
+ // A copy of entry_computation_layout_ used to reset it to the initial values
+ // during the multiple passes done by the layout assignment operation.
+ ComputationLayout saved_entry_computation_layout_;
+
protected:
// Sets up the copy instruction according to the characteristic (sharding,
// metadata, ...) of the reference instruction. The index argument is used
@@ -464,6 +474,20 @@ class LayoutAssignment : public HloPassInterface {
// itself).
Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
+ // Apply the channel layout constraints by populating the channel_constraints
+ // data structure passed in at constructor time. Eventually adds copies in
+ // case two ends of a channel ended up with a different leyout.
+ Status ConstrainChannelLayouts(HloComputation* computation,
+ ChannelLayoutConstraints* channel_constraints);
+
+ // Resets the input ChannelLayoutConstraints to the original copy received
+ // from the constructor input.
+ void ResetChannelConstraints() {
+ if (channel_layout_constraints_ != nullptr) {
+ *channel_layout_constraints_ = channel_constraints_;
+ }
+ }
+
// Map containing the layouts of all computations assigned so
// far. Computations are handled in a topological sort where computations are
// handled before their caller instructions so the layouts of caller
@@ -474,7 +498,19 @@ class LayoutAssignment : public HloPassInterface {
// here.
tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
- ChannelLayoutConstraints* channel_layout_constraints_;
+ // The pointer to the channel layout constraints passed in with the
+ // constructor. If not nullptr, this is an input/output argument.
+ ChannelLayoutConstraints* channel_layout_constraints_ = nullptr;
+
+ // A copy of the input layout constraints used to reset the above pointer in
+ // case we have to undo operations due to the multiple passes over the
+ // computations/instructions.
+ ChannelLayoutConstraints channel_constraints_;
+
+ // The set of HLO instructions which lacked any layout constraint, thus
+ // receiving propagated default layouts.
+ tensorflow::gtl::FlatSet<const HloInstruction*>
+ unconstrained_layout_instructions_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7508013199..a16fa75e30 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -29,13 +29,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -52,10 +52,18 @@ using ::testing::ElementsAre;
class LayoutAssignmentTest : public HloTestBase {
protected:
void AssignLayouts(HloModule* module,
- ComputationLayout* entry_computation_layout) {
- LayoutAssignment layout_assignment(entry_computation_layout);
+ ComputationLayout* entry_computation_layout,
+ ChannelLayoutConstraints* channel_constraints = nullptr) {
+ LayoutAssignment layout_assignment(
+ entry_computation_layout, /*channel_constraints=*/channel_constraints);
EXPECT_IS_OK(layout_assignment.Run(module).status());
}
+
+ std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) {
+ auto minor_to_major =
+ FindInstruction(module, name)->shape().layout().minor_to_major();
+ return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
+ }
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
@@ -133,9 +141,9 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
for (auto& minor_to_major : minor_to_majors) {
auto builder = HloComputation::Builder(TestName());
- auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal1 = LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
- auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+ auto constant_literal2 = LiteralUtil::CreateR2WithLayout<float>(
{{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
Shape ashape = constant_literal1->shape();
@@ -184,10 +192,10 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
// match their source).
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -221,10 +229,10 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
// Verify layouts of a select with tuple operands is assigned properly.
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
auto tuple0 = builder.AddInstruction(
HloInstruction::CreateTuple({constant0, constant1}));
@@ -232,7 +240,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) {
HloInstruction::CreateTuple({constant0, constant1}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
@@ -266,7 +274,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
// tuple and assigning the layouts of the copied arrays as needed.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
auto inner_tuple =
builder.AddInstruction(HloInstruction::CreateTuple({constant}));
auto nested_tuple = builder.AddInstruction(
@@ -576,7 +584,7 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
auto builder = HloComputation::Builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(input_shape, constant, {}));
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -651,7 +659,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = tools::Parse(module_str).ValueOrDie();
+ auto module = ParseHloString(module_str).ValueOrDie();
module =
backend()
@@ -691,7 +699,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = tools::Parse(module_str).ValueOrDie();
+ auto module = ParseHloString(module_str).ValueOrDie();
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
@@ -707,17 +715,10 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
LayoutUtil::MakeLayout({2, 1, 0}));
AssignLayouts(module.get(), &computation_layout);
- auto layout_of = [&](tensorflow::StringPiece name) {
- return FindInstruction(module.get(), name)
- ->shape()
- .layout()
- .minor_to_major();
- };
-
- EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2));
- EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0));
- EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1));
- EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "gte0"), ElementsAre(0, 1, 2));
+ EXPECT_THAT(LayoutOf(module.get(), "gte1a"), ElementsAre(1, 2, 0));
+ EXPECT_THAT(LayoutOf(module.get(), "gte1b"), ElementsAre(2, 0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "fresult"), ElementsAre(2, 1, 0));
EXPECT_THAT(FindInstruction(module.get(), "gte1")
->shape()
.tuple_shapes(0)
@@ -769,9 +770,12 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
false_builder.AddInstruction(
HloInstruction::CreateParameter(0, tshape, "param"));
// Using infeed as layout assignment does not mess up with it.
- auto infeed =
- false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, ""));
- false_builder.AddInstruction(HloInstruction::CreateTuple({infeed}));
+ auto token = false_builder.AddInstruction(HloInstruction::CreateToken());
+ auto infeed = false_builder.AddInstruction(
+ HloInstruction::CreateInfeed(xshape, token, ""));
+ auto infeed_data = false_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(xshape, infeed, 0));
+ false_builder.AddInstruction(HloInstruction::CreateTuple({infeed_data}));
}
HloComputation* false_computation =
module->AddEmbeddedComputation(false_builder.Build());
@@ -798,7 +802,7 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
builder.AddInstruction(HloInstruction::CreateUnary(
constant0->shape(), HloOpcode::kBitcast, constant0));
@@ -816,5 +820,46 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
"Unexpected bitcast operation seen during layout assignment"));
}
+TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
+ // Pin non matching layouts to parameter and root.
+ const char* module_str = R"(
+ HloModule test_module
+
+ ENTRY entry_computation {
+ param = (f32[2,2]) parameter(0)
+ gte = f32[2,2] get-tuple-element(param), index=0
+ token = token[] after-all()
+ recv = (f32[2,2], u32[], token[]) recv(token), channel_id=1, sharding={maximal device=1}
+ recv-done = (f32[2,2], token[]) recv-done(recv), channel_id=1,
+ sharding={maximal device=1}
+ ROOT root = f32[2,2] get-tuple-element(recv-done), index=0
+ send = (f32[2,2], u32[], token[]) send(gte, token), channel_id=1,
+ sharding={maximal device=0}
+ send-done = token[] send-done(send), channel_id=1, sharding={maximal device=0}
+ }
+ )";
+
+ auto module = ParseHloString(module_str).ValueOrDie();
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ Shape param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+ TF_ASSERT_OK(
+ computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+ param_shape));
+ computation_layout.mutable_result_layout()->ResetLayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ ChannelLayoutConstraints channel_constraints;
+ AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+ EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "root"), ElementsAre(1, 0));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::GetSubshape(
+ FindInstruction(module.get(), "send")->shape(), {0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index f1e7fc2953..6f1e04a1c6 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -21,6 +21,11 @@ filegroup(
]),
)
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
cc_library(
name = "alias_analysis",
srcs = ["alias_analysis.cc"],
@@ -37,12 +42,25 @@ cc_library(
],
)
+tf_cc_test(
+ name = "alias_analysis_test",
+ srcs = ["alias_analysis_test.cc"],
+ deps = [
+ ":alias_analysis",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "llvm_util",
srcs = ["llvm_util.cc"],
hdrs = ["llvm_util.h"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
@@ -107,11 +125,30 @@ cc_library(
)
cc_library(
+ name = "kernel_tiling",
+ srcs = ["kernel_tiling.cc"],
+ hdrs = ["kernel_tiling.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
name = "fused_ir_emitter",
srcs = ["fused_ir_emitter.cc"],
hdrs = ["fused_ir_emitter.h"],
deps = [
":ir_array",
+ ":kernel_tiling",
":llvm_util",
":loop_emitter",
":tuple_ops",
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index 21bca1d6be..93a8c130e1 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -32,15 +32,17 @@ static const BufferAllocation* kParameterAllocation = new BufferAllocation(
LogicalBuffer::Color(0));
void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
- llvm_ir::IrArray* array) {
+ llvm_ir::IrArray* array,
+ const ShapeIndex& index) {
BufferAllocation::Slice buffer_slice;
- if (hlo.opcode() == HloOpcode::kParameter) {
- // Parameters may alias with each other but may not alias with our temporary
- // buffers.
+ if (hlo.opcode() == HloOpcode::kParameter &&
+ hlo.parent() == hlo.parent()->parent()->entry_computation()) {
+ // Entry computation parameters may alias with each other but may not alias
+ // with our temporary buffers.
buffer_slice = BufferAllocation::Slice(kParameterAllocation, 0, 0);
} else {
const std::set<BufferAllocation::Slice> slices =
- assignment_.GetAllSlices(&hlo, /*index=*/{});
+ assignment_.GetAllSlices(&hlo, index);
if (slices.empty() || slices.size() > 1) {
// Skip HLOs which don't have a buffer assigned or for which the
// buffer can't be determined statically. We cannot determine their
@@ -137,16 +139,18 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
// 2. Operands of users of the given hlo.
// 3. Operands of the given hlo.
//
- // This set can be increased as we need. For now only consider top-level
- // buffers (index = {}) not buffers nested within the instruction's
- // operands/output which are not typically touched.
+ // This set can be increased as we need.
std::vector<const LogicalBuffer*> worklist;
auto add_buffers_to_worklist =
[&worklist, &assignment](const HloInstruction* instruction) {
- for (const LogicalBuffer* buffer :
- assignment.GetSourceBuffers(instruction, /*index=*/{})) {
- worklist.push_back(buffer);
- }
+ ShapeUtil::ForEachSubshape(
+ instruction->shape(),
+ [&](const Shape& /*shape*/, const ShapeIndex& index) {
+ for (const LogicalBuffer* buffer :
+ assignment.GetSourceBuffers(instruction, index)) {
+ worklist.push_back(buffer);
+ }
+ });
};
for (HloInstruction* user : hlo.users()) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index 5244ac61e5..fe9eab93aa 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -38,7 +38,8 @@ class AliasAnalysis {
// Augments IrArray with aliasing information.
void AddAliasingInformationToIrArray(const HloInstruction& hlo,
- llvm_ir::IrArray* array);
+ llvm_ir::IrArray* array,
+ const ShapeIndex& index = {});
private:
// Returns a unique alias domain for this emitter.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
new file mode 100644
index 0000000000..2552ff4a6a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class AliasAnalysisTest : public CpuCodegenTest {};
+
+void FakeCustomCallTarget(float* out, float** in) {}
+
+REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
+
+TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) {
+ const char* hlo_string = R"(
+HloModule while
+
+body {
+ const.0.125 = f32[] constant(0.125)
+ body.state = f32[] parameter(0)
+ ROOT add.2.2 = f32[] add(const.0.125, body.state)
+}
+
+condition {
+ const.100 = f32[] constant(100)
+ condition.state = f32[] parameter(0)
+ addend = f32[] custom-call(condition.state), custom_call_target="FakeCustomCallTarget"
+ add = f32[] add(addend, condition.state)
+ ROOT greater-than = pred[] greater-than(const.100, add)
+}
+
+ENTRY while3 {
+ const.0 = f32[] constant(0)
+ ROOT while = f32[] while(const.0), condition=condition, body=body
+}
+)";
+
+ CompileAndVerifyIr(hlo_string, R"(
+; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
+; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
+; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:.*]]
+;
+; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
+; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
+; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
+; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
+;
+; CHECK-LABEL: @while3(
+
+![[alias_scope_md_for_store]] = !{![[buffer_idx_0:.*]]}
+![[buffer_idx_0]] = !{!"buffer: {index:0, offset:0, size:4}", ![[aa_md_root:.*]]}
+![[aa_md_root]] = !{!"XLA global AA domain"}
+![[buffer_idx_1:.*]] = !{!"buffer: {index:1, offset:0, size:4}", !3}
+![[buffer_idx_1_offset_16:.*]] = !{!"buffer: {index:1, offset:16, size:1}", !3}
+![[noalias_md_for_load]] = !{![[buffer_idx_1_offset_16]], ![[buffer_idx_1]]}
+}
+)");
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
index f172b1d87c..b12ce97e28 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
*ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
/*Name=*/"");
+ llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
+ global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
generators_[constant] = [=](const IrArray::Index& index) {
- return IrArray(global, constant->shape())
+ return IrArray(shape_constant, constant->shape())
.EmitReadArrayElement(index, ir_builder_);
};
@@ -117,7 +119,24 @@ Status FusedIrEmitter::HandleGetTupleElement(
}
Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
- generators_[parameter] = [=](const IrArray::Index& index) {
+ generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* {
+ if (tiled_parameter_info_) {
+ if (llvm::Value* param_tile_buffer =
+ tiled_parameter_info_->GetBufferForParameter(
+ parameter->parameter_number())) {
+ // TODO(jlebar): Add AA metadata to this load. Tile buffers are global
+ // variables, so LLVM's points-to analysis doesn't help us much. And we
+ // want the AA info to be present before address spaces are inferred
+ // (which is pretty late in the pipeline), so even if we had
+ // address-space-based AA in LLVM, it wouldn't help us much here.
+ return ir_builder_->CreateLoad(
+ ir_builder_->CreateGEP(
+ param_tile_buffer,
+ {index.GetConstantWithIndexType(0), tiled_parameter_info_->x(),
+ tiled_parameter_info_->y()}),
+ "tiled_buffer");
+ }
+ }
return parameter_arrays_[parameter->parameter_number()]
.EmitReadArrayElement(index, ir_builder_);
};
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
index b3b6026ef1..a6ceec7b23 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -56,6 +57,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
ElementalIrEmitter* elemental_emitter)
: parameter_arrays_(parameter_arrays),
+ tiled_parameter_info_(nullptr),
elemental_emitter_(elemental_emitter),
ir_builder_(elemental_emitter->ir_builder()),
module_(elemental_emitter->module()) {}
@@ -86,9 +88,14 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
return it->second;
}
+ void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) {
+ tiled_parameter_info_ = info;
+ }
+
private:
// Arrays of parameters of fusion instruction
tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+ const llvm_ir::TiledParameterInfo* tiled_parameter_info_;
ElementalIrEmitter* elemental_emitter_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
index 7323abeb20..dcf9838d80 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -29,9 +29,9 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
-static void Delinearize(std::vector<llvm::Value*>* multidim,
- llvm::Value* linear, const Shape& shape,
- llvm::IRBuilder<>* ir_builder) {
+void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
+ llvm::Value* linear, const Shape& shape,
+ llvm::IRBuilder<>* ir_builder) const {
int64 divisor = 1;
const Layout& layout = shape.layout();
for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
@@ -48,10 +48,11 @@ static void Delinearize(std::vector<llvm::Value*>* multidim,
// useful because cuda-memcheck can't help us much in XLA: Most of our
// memory lives in one big allocation, so cuda-memcheck can't detect
// out-of-bounds accesses.
- auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor));
+ auto* quot =
+ ir_builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
if (i < layout.minor_to_major_size() - 1) {
(*multidim)[dimension] = ir_builder->CreateURem(
- quot, ir_builder->getInt64(size_of_current_dimension));
+ quot, GetConstantWithIndexType(size_of_current_dimension));
} else {
(*multidim)[dimension] = quot;
}
@@ -65,6 +66,8 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
linear_(linear),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ CHECK_NE(linear, nullptr);
+ index_type_ = linear->getType();
CHECK(LayoutUtil::HasLayout(shape))
<< "Shape " << ShapeUtil::HumanStringWithLayout(shape)
<< " should have a layout.";
@@ -77,6 +80,13 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
linear_(linear),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ if (size()) {
+ index_type_ = multidim_[0]->getType();
+ } else {
+ CHECK_NE(linear_, nullptr);
+ index_type_ = linear_->getType();
+ }
+ CHECK_NE(index_type_, nullptr);
CHECK_EQ(shape.dimensions_size(), multidim.size());
CHECK(LayoutUtil::HasLayout(shape))
<< "Shape " << ShapeUtil::HumanStringWithLayout(shape)
@@ -88,6 +98,9 @@ IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
: multidim_(multidim.begin(), multidim.end()),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ CHECK_GT(multidim_.size(), 0);
+ index_type_ = multidim[0]->getType();
+ CHECK_NE(index_type_, nullptr);
CHECK_EQ(shape.dimensions_size(), multidim.size());
CHECK(LayoutUtil::HasLayout(shape));
}
@@ -130,15 +143,15 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
std::vector<llvm::Value*> source_multidim_index(
- ShapeUtil::Rank(input_shape),
- llvm::UndefValue::get(builder->getInt64Ty()));
+ ShapeUtil::Rank(input_shape), llvm::UndefValue::get(index_type_));
// We compute the source indices in each common factor from only the target
// indices in the same common factor.
for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
llvm::Value* logical_linear_index =
Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
multidim_, common_factors[k].second,
- common_factors[k + 1].second - common_factors[k].second))
+ common_factors[k + 1].second - common_factors[k].second),
+ index_type_)
.Linearize(
tensorflow::gtl::ArraySlice<int64>(
AsInt64Slice(output_shape.dimensions()),
@@ -150,9 +163,10 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
// linear index by each dimension size.
for (int64 i = common_factors[k + 1].first - 1;
i >= common_factors[k].first; --i) {
- llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i));
+ llvm::Value* divisor =
+ GetConstantWithIndexType(input_shape.dimensions(i));
if (input_shape.dimensions(i) == 1) {
- source_multidim_index[i] = builder->getInt64(0);
+ source_multidim_index[i] = GetConstantWithIndexType(0);
} else if (i == common_factors[k].first) {
source_multidim_index[i] = logical_linear_index;
} else {
@@ -168,14 +182,14 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
return Index(source_multidim_index, linear(), input_shape);
}
- return Index(source_multidim_index);
+ return Index(source_multidim_index, index_type_);
}
IrArray::Index IrArray::Index::SourceIndexOfSlice(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
tensorflow::gtl::ArraySlice<int64> strides,
llvm::IRBuilder<>* builder) const {
- Index source_index(multidim_.size());
+ Index source_index(index_type_, multidim_.size());
for (int i = 0; i < multidim_.size(); ++i) {
int64 stride = strides[i];
auto type = multidim_[i]->getType();
@@ -224,11 +238,12 @@ IrArray::Index IrArray::Index::SourceIndexOfBitcast(
// the physical index of the element in the buffer. This is like Linearize,
// but takes the layout into account.
int64 scale = 1;
- llvm::Value* linear_index = builder->getInt64(0);
+ llvm::Value* linear_index = GetConstantWithIndexType(0);
for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
linear_index = builder->CreateAdd(
linear_index,
- builder->CreateMul(multidim_[dimension], builder->getInt64(scale), "",
+ builder->CreateMul(multidim_[dimension],
+ GetConstantWithIndexType(scale), "",
/*HasNUW=*/true, /*HasNSW=*/true),
"", /*HasNUW=*/true, /*HasNSW=*/true);
scale *= shape.dimensions(dimension);
@@ -252,7 +267,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
}
if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
!LayoutUtil::HasLayout(shape)) {
- return Index(source_index);
+ return Index(source_index, index_type_);
}
// High-level idea: we can reuse the linear index if the broadcasted
// dimensions are contiguous, and this part of the operation is a bitcast.
@@ -274,7 +289,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
bool contiguous_broadcast_dimensions =
max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
if (!contiguous_broadcast_dimensions) {
- return Index(source_index);
+ return Index(source_index, index_type_);
}
// Check if the mapped dimensions are a bitcast.
std::vector<int64> operand_logical_to_physical =
@@ -282,7 +297,7 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
for (int64 i = 0; i < rank; ++i) {
if (operand_logical_to_physical[i] !=
logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
- return Index(source_index);
+ return Index(source_index, index_type_);
}
}
llvm::Value* linear = linear_;
@@ -291,7 +306,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
}
if (divisor > 1) {
- linear = builder->CreateUDiv(linear, builder->getInt64(divisor));
+ linear = builder->CreateUDiv(
+ linear,
+ IrArray::Index(linear->getType()).GetConstantWithIndexType(divisor));
}
if (min_broadcasted_dimension > 0) {
int64 mod = 1;
@@ -299,7 +316,9 @@ IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
++i) {
mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
}
- linear = builder->CreateURem(linear, builder->getInt64(mod));
+ linear = builder->CreateURem(
+ linear,
+ IrArray::Index(linear->getType()).GetConstantWithIndexType(mod));
}
return Index(source_index, linear, operand_shape);
}
@@ -309,12 +328,13 @@ llvm::Value* IrArray::Index::Linearize(
llvm::IRBuilder<>* builder) const {
// Each dimension is multiplied by the product of the sizes of all
// earlier dimensions and added to the accumulator logical_linear_index.
- llvm::Value* logical_linear_index = builder->getInt64(0);
+ llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
int64 multiplier = 1;
for (ssize_t i = size() - 1; i >= 0; --i) {
llvm::Value* addend =
- builder->CreateMul((*this)[i], builder->getInt64(multiplier), "",
+ builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
/*HasNUW=*/true, /*HasNSW=*/true);
+ addend = builder->CreateZExtOrTrunc(addend, index_type_);
logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
/*HasNUW=*/true, /*HasNSW=*/true);
multiplier *= dimensions[i];
@@ -349,7 +369,8 @@ llvm::Value* IrArray::EmitArrayElementAddress(
// index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
// produce better code in some cases.
auto dim = shape_->dimensions(i);
- actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
+ actual_index.push_back(
+ dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
}
// "base_ptr_" has the type of "<ir_type_for_its_shape>*"
@@ -357,7 +378,9 @@ llvm::Value* IrArray::EmitArrayElementAddress(
// should be computed by
//
// getelementptr base_ptr_, 0, most major index, ..., most minor index
- std::vector<llvm::Value*> gep_indices(1, ir_builder->getInt64(0));
+ CHECK_GT(index.size(), 0);
+ std::vector<llvm::Value*> gep_indices(
+ 1, llvm::ConstantInt::get(index[0]->getType(), 0));
for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) {
int64 dimension = LayoutUtil::Major(shape_->layout(), i);
gep_indices.push_back(actual_index[dimension]);
@@ -399,9 +422,11 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
llvm::IRBuilder<>* ir_builder) const {
llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
- return IrArray(
+ IrArray new_irarray(
ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
new_shape);
+ new_irarray.metadata_ = metadata_;
+ return new_irarray;
}
/* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
@@ -410,7 +435,9 @@ IrArray IrArray::CastToShape(const Shape& new_shape,
llvm::IRBuilder<>* ir_builder) {
Index new_index = index;
new_index[which_dimension] = ir_builder->CreateAdd(
- index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true,
+ index[which_dimension],
+ llvm::ConstantInt::get(index[which_dimension]->getType(), addend), "",
+ /*HasNUW=*/true,
/*HasNSW=*/true);
return new_index;
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
index 4c3195c29c..0777c49923 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -53,18 +53,38 @@ class IrArray {
// multidimensional index, which LLVM DCE can delete.
class Index {
public:
- // Constructs an empty zero-dimensional index.
- Index() {}
-
// Constructs an index of rank "size". Each dimension of the index is
// initialized to "value".
- explicit Index(size_t size, llvm::Value* value = nullptr)
- : multidim_(size, value) {}
+ explicit Index(size_t size, llvm::Value* value)
+ : multidim_(size, value), index_type_(value->getType()) {
+ CHECK_NE(index_type_, nullptr);
+ }
+
+ // Constructs an index of rank "size". Each dimension of the index is
+ // initialized to nullptr.
+ explicit Index(llvm::Type* index_ty, size_t size = 0)
+ : multidim_(size, nullptr), index_type_(index_ty) {
+ CHECK(index_ty->isIntegerTy());
+ }
// Constructs an index from multi-dimensional index "multidim". The linear
// index is set to nullptr.
- explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim)
- : multidim_(multidim.begin(), multidim.end()) {}
+ explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ llvm::Type* index_ty = nullptr)
+ : multidim_(multidim.begin(), multidim.end()) {
+ if (size() == 0) {
+ index_type_ = index_ty;
+ } else {
+ index_type_ = (*this)[0]->getType();
+ if (index_ty != nullptr) {
+ CHECK_EQ(index_type_, index_ty);
+ }
+ }
+ CHECK_NE(index_type_, nullptr);
+ CHECK(c_all_of(multidim, [&](llvm::Value* v) {
+ return index_type_ == v->getType();
+ }));
+ }
// Constructs an index from linear index "linear" and computes the
// multi-dimensional index from "linear" and "shape". "ir_builder" is the IR
@@ -94,19 +114,19 @@ class IrArray {
size_t size() const { return multidim().size(); }
llvm::Value* operator[](size_t i) const { return multidim()[i]; }
- llvm::Value*& operator[](size_t i) { return multidim()[i]; }
+ llvm::Value*& operator[](size_t i) { return mutable_multidim()[i]; }
- void push_back(llvm::Value* value) { multidim().push_back(value); }
+ void push_back(llvm::Value* value) { mutable_multidim().push_back(value); }
void InsertAt(int64 index, llvm::Value* value) {
CHECK_LE(index, size());
- multidim().insert(multidim().begin() + index, value);
+ mutable_multidim().insert(mutable_multidim().begin() + index, value);
}
using iterator = std::vector<llvm::Value*>::iterator;
using const_iterator = std::vector<llvm::Value*>::const_iterator;
- iterator begin() { return multidim().begin(); }
- iterator end() { return multidim().end(); }
+ iterator begin() { return mutable_multidim().begin(); }
+ iterator end() { return mutable_multidim().end(); }
const_iterator begin() const { return multidim().begin(); }
const_iterator end() const { return multidim().end(); }
@@ -154,13 +174,25 @@ class IrArray {
llvm::Value* Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,
llvm::IRBuilder<>* builder) const;
+ llvm::Type* GetType() const { return index_type_; }
+
+ llvm::Constant* GetConstantWithIndexType(int64 c) const {
+ // The LLVM function makes sure that the value can be represented by the
+ // specified type, see ConstantInt::ConstantInt(IntegerType *Ty, const
+ // APInt &V).
+ return llvm::ConstantInt::get(index_type_, c);
+ }
+
private:
// Changing the multi-dimensional index invalidates the linear index.
- std::vector<llvm::Value*>& multidim() {
+ std::vector<llvm::Value*>& mutable_multidim() {
linear_ = nullptr;
return multidim_;
}
+ void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
+ const Shape& shape, llvm::IRBuilder<>* ir_builder) const;
+
std::vector<llvm::Value*> multidim_;
// These values are purely for efficiency; `multidim_` is enough to find the
@@ -177,6 +209,8 @@ class IrArray {
llvm::Value* linear_ = nullptr;
Layout layout_;
std::vector<int64> dims_;
+
+ llvm::Type* index_type_;
};
// Default constructor. Constructs an IrArray in a null status.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index 23d2d4e87d..98d0ceb3e2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -15,53 +15,58 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
-#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
namespace xla {
-void KernelSupportLibrary::For(
+Status KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
- const std::function<void(llvm::Value*, bool)>& for_body_generator) {
- If(ir_builder_->CreateICmpSLT(start, end), [&]() {
- for_body_generator(start, /*is_first_iteration=*/true);
- For(name, ir_builder_->CreateAdd(start, step), end, step,
- [&](llvm::Value* iv) { for_body_generator(iv, false); });
+ const std::function<Status(llvm::Value*, bool)>& for_body_generator) {
+ return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status {
+ TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true));
+ return For(name, ir_builder_->CreateAdd(start, step), end, step,
+ [&](llvm::Value* iv) { return for_body_generator(iv, false); });
});
}
-void KernelSupportLibrary::For(
+Status KernelSupportLibrary::For(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step, bool peel_first_iteration,
- const std::function<void(llvm::Value*, llvm::Value*)>& for_body_generator) {
+ const std::function<Status(llvm::Value*, llvm::Value*)>&
+ for_body_generator) {
if (peel_first_iteration) {
- For(name, start, end, step, true,
- [&](llvm::Value* indvar, bool is_first_iteration) {
- for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration));
- });
+ return For(name, start, end, step, true,
+ [&](llvm::Value* indvar, bool is_first_iteration) -> Status {
+ return for_body_generator(
+ indvar, ir_builder_->getInt1(is_first_iteration));
+ });
} else {
std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
name, start, end, step, ir_builder_,
- /*prevent_unrolling=*/prevent_unrolling_,
+ /*unroll_mode=*/unroll_mode_,
/*prevent_vectorization=*/prevent_vectorization_);
ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
- for_body_generator(loop->GetIndVarValue(),
- /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
- loop->GetIndVarValue(), start));
+ TF_RETURN_IF_ERROR(
+ for_body_generator(loop->GetIndVarValue(),
+ /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
+ loop->GetIndVarValue(), start)));
llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_);
+ return Status::OK();
}
}
-void KernelSupportLibrary::If(
- llvm::Value* condition, const std::function<void()>& true_block_generator,
- const std::function<void()>& false_block_generator) {
+Status KernelSupportLibrary::If(
+ tensorflow::StringPiece name, llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator) {
llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
+ llvm_ir::EmitIfThenElse(condition, name, ir_builder_);
ir_builder_->SetInsertPoint(&if_data.true_block->back());
- true_block_generator();
+ TF_RETURN_IF_ERROR(true_block_generator());
ir_builder_->SetInsertPoint(&if_data.false_block->back());
- false_block_generator();
+ TF_RETURN_IF_ERROR(false_block_generator());
llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
+ return Status::OK();
}
void KernelSupportLibrary::EmitAndCallOutlinedKernel(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 64b935bbf1..9d770cc4c3 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -30,13 +31,14 @@ namespace xla {
class KernelSupportLibrary {
public:
// `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
- // If `prevent_unrolling` is true then unrolling is explicitly disabled on
- // every loop generated by this instance of KernelSupportLibrary.
- explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling = true,
- bool prevent_vectorization = true)
+ // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop
+ // generated by this instance of KernelSupportLibrary.
+ explicit KernelSupportLibrary(
+ llvm::IRBuilder<>* ir_builder,
+ llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll,
+ bool prevent_vectorization = true)
: ir_builder_(ir_builder),
- prevent_unrolling_(prevent_unrolling),
+ unroll_mode_(unroll_mode),
prevent_vectorization_(prevent_vectorization) {}
// Generates the following control flow structure:
@@ -46,19 +48,41 @@ class KernelSupportLibrary {
// for (i64 i = `start` + `step`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
// }
- void For(
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<Status(llvm::Value* ind_var,
+ bool is_first_iteration)>& for_body_generator);
+
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
- for_body_generator);
+ for_body_generator) {
+ CHECK_EQ(Status::OK(),
+ For(name, start, end, step,
+ [&](llvm::Value* ind_var, bool is_first_iteration) -> Status {
+ for_body_generator(ind_var, is_first_iteration);
+ return Status::OK();
+ }));
+ }
+
+ Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<Status(llvm::Value* ind_var,
+ bool is_first_iteration)>&
+ for_body_generator) {
+ return For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ }
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
for_body_generator) {
- For(name, /*start=*/ir_builder_->getInt64(start),
- /*end=*/ir_builder_->getInt64(end),
- /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure if `peel_first_iteration` is
@@ -75,46 +99,102 @@ class KernelSupportLibrary {
// for (i64 i = `start`; i s< `end`; i += `step`)
// `for_body_generator(/*ind_var=*/,i,
// /*is_first_iteration=*/,(i != `start`))`;
- void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- llvm::Value* step, bool peel_first_iteration,
- const std::function<void(llvm::Value* ind_var,
- llvm::Value* is_first_iteration)>&
- for_body_generator);
-
- void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
- int64 step, bool peel_first_iteration,
- const std::function<void(llvm::Value* ind_var,
- llvm::Value* is_first_iteration)>&
- for_body_generator) {
- For(name, /*start=*/start, /*end=*/end,
- /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
- for_body_generator);
- }
-
- void For(
+ Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step, bool peel_first_iteration,
+ const std::function<Status(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator);
+
+ void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ llvm::Value* end, llvm::Value* step,
+ bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ TF_CHECK_OK(For(
+ name, start, end, step, peel_first_iteration,
+ [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status {
+ for_body_generator(ind_var, is_first_iteration);
+ return Status::OK();
+ }));
+ }
+
+ Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step, bool peel_first_iteration,
+ const std::function<Status(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ return For(name, /*start=*/start, /*end=*/end,
+ /*step=*/llvm::ConstantInt::get(start->getType(), step),
+ peel_first_iteration, for_body_generator);
+ }
+
+ void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start,
+ llvm::Value* end, int64 step, bool peel_first_iteration,
+ const std::function<void(llvm::Value* ind_var,
+ llvm::Value* is_first_iteration)>&
+ for_body_generator) {
+ ForReturnVoid(name, /*start=*/start, /*end=*/end,
+ /*step=*/llvm::ConstantInt::get(start->getType(), step),
+ peel_first_iteration, for_body_generator);
+ }
+
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ llvm::Value* step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, start, end, step,
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) -> Status {
+ return for_body_generator(indvar);
+ });
+ }
+
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
llvm::Value* step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, start, end, step,
- /*peel_first_iteration=*/false,
- [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ ForReturnVoid(name, start, end, step,
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) {
+ return for_body_generator(indvar);
+ });
+ }
+
+ Status For(
+ tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+ int64 step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
+ /*peel_first_iteration=*/false,
+ [&](llvm::Value* indvar, llvm::Value*) -> Status {
+ return for_body_generator(indvar);
+ });
}
- void For(
+ void ForReturnVoid(
tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, start, end, ir_builder_->getInt64(step),
- /*peel_first_iteration=*/false,
- [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+ ForReturnVoid(name, start, end,
+ llvm::ConstantInt::get(start->getType(), step),
+ for_body_generator);
}
- void For(
+ Status For(
+ tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+ const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
+ return For(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ }
+
+ void ForReturnVoid(
tensorflow::StringPiece name, int64 start, int64 end, int64 step,
const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
- For(name, /*start=*/ir_builder_->getInt64(start),
- /*end=*/ir_builder_->getInt64(end),
- /*step=*/ir_builder_->getInt64(step), for_body_generator);
+ ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start),
+ /*end=*/ir_builder_->getInt64(end),
+ /*step=*/ir_builder_->getInt64(step), for_body_generator);
}
// Generates the following control flow structure:
@@ -123,9 +203,39 @@ class KernelSupportLibrary {
// `true_block_generator()`;
// else
// `false_block_generator()`;
- void If(llvm::Value* condition,
- const std::function<void()>& true_block_generator,
- const std::function<void()>& false_block_generator = []() {});
+ Status If(tensorflow::StringPiece name, llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator =
+ []() -> Status { return Status::OK(); });
+
+ Status If(llvm::Value* condition,
+ const std::function<Status()>& true_block_generator,
+ const std::function<Status()>& false_block_generator =
+ []() -> Status { return Status::OK(); }) {
+ return If("", condition, true_block_generator, false_block_generator);
+ }
+
+ void IfReturnVoid(llvm::Value* condition,
+ const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator = []() {
+ }) {
+ IfReturnVoid("", condition, true_block_generator, false_block_generator);
+ }
+
+ void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition,
+ const std::function<void()>& true_block_generator,
+ const std::function<void()>& false_block_generator = []() {
+ }) {
+ TF_CHECK_OK(If(name, condition,
+ [&]() {
+ true_block_generator();
+ return Status::OK();
+ },
+ [&]() {
+ false_block_generator();
+ return Status::OK();
+ }));
+ }
using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
@@ -183,7 +293,7 @@ class KernelSupportLibrary {
private:
llvm::IRBuilder<>* ir_builder_;
- bool prevent_unrolling_;
+ llvm_ir::UnrollMode unroll_mode_;
bool prevent_vectorization_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
new file mode 100644
index 0000000000..533b75cdae
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace llvm_ir {
+
+namespace {
+// Returns the indices of the first elements of all consecutive subarrays of the
+// given array. For example:
+// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
+std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+ std::vector<size_t> is = {0};
+ for (size_t i = 1; i < xs.size(); ++i) {
+ if (1 != xs[i] - xs[i - 1]) {
+ is.push_back(i);
+ }
+ }
+ return is;
+}
+
+// Merges the sequences of dimensions of the given shape which start at the
+// given indices `segs`.
+Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
+ const Shape& shape) {
+ std::vector<int64> dimensions;
+ for (size_t i = 1; i <= segs.size(); ++i) {
+ dimensions.push_back(std::accumulate(
+ shape.dimensions().begin() + segs[i - 1],
+ shape.dimensions().begin() +
+ (segs.size() == i ? shape.dimensions().size() : segs[i]),
+ 1, std::multiplies<int64>()));
+ }
+ return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
+ dimensions);
+}
+} // namespace
+
+tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(
+ const Shape& a, const Shape& b) {
+ if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
+ return tensorflow::gtl::nullopt;
+ }
+
+ std::vector<int64> perm(a.dimensions().size());
+ {
+ auto layout_a_orig = LayoutUtil::MinorToMajor(a);
+ std::vector<int64> layout_a(layout_a_orig.rbegin(), layout_a_orig.rend());
+ auto layout_b_orig = LayoutUtil::MinorToMajor(b);
+ std::vector<int64> layout_b(layout_b_orig.rbegin(), layout_b_orig.rend());
+ for (size_t i = 0; i < perm.size(); ++i) {
+ perm[i] = PositionInContainer(layout_b, layout_a[i]);
+ }
+ }
+ auto segs = ConsecutiveSegments(perm);
+ if ((3 == segs.size() && 0 == perm[0]) || 2 == segs.size()) {
+ Shape norm_a =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
+ Shape reduced_a = MergeDimensions(segs, norm_a);
+ auto reduced_a_dims = reduced_a.dimensions();
+ std::vector<int64> dims_021;
+ if (2 == segs.size()) {
+ // The logical component-0 is of size one.
+ dims_021 = {1, reduced_a_dims[1], reduced_a_dims[0]};
+ } else {
+ dims_021 = {reduced_a_dims[0], reduced_a_dims[2], reduced_a_dims[1]};
+ }
+
+ return dims_021;
+ }
+
+ return tensorflow::gtl::nullopt;
+}
+
+IrArray::Index GetUnreducedOutputIndex(
+ const IrArray::Index& reduced_output_index,
+ const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
+ llvm::IRBuilder<>* ir_builder) {
+ auto bounds = reduced_output_shape.dimensions();
+ auto minor_to_major = reduced_output_shape.layout().minor_to_major();
+ llvm::Value* linear_index = reduced_output_index.GetConstantWithIndexType(0);
+ int64 multiplier = 1;
+ for (int i = 0; i < reduced_output_index.size(); ++i) {
+ int64 dim = minor_to_major[i];
+ llvm::Value* addend = ir_builder->CreateMul(
+ reduced_output_index[dim],
+ reduced_output_index.GetConstantWithIndexType(multiplier),
+ "linearizing",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ linear_index = ir_builder->CreateAdd(linear_index, addend, "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ multiplier *= bounds[dim];
+ }
+
+ return IrArray::Index(linear_index, unreduced_output_shape, ir_builder);
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
new file mode 100644
index 0000000000..6f1268fffb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
+
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// About 0-2-1 transpose:
+//
+// If a shape can be viewed as three logical components 0-1-2 in the order of
+// major to minor, a 0-2-1-transpose changes the order of such logical
+// components to 0-2-1. We call the shape being transposed the input shape and
+// the transposed shape the output shape. The logical view of the input and
+// output shapes for the transpose are called the 0-1-2 shape or reduced input
+// shape and the 0-2-1 shape or the reduced output shape respectively. The
+// original input and output shapes are called the unreduced input and output
+// shapes.
+
+// If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the
+// reduced shape of `b` or the 0-2-1 shape.
+tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
+ const Shape& b);
+
+// Return the unreduced output index corresponding to the given reduced output
+// index.
+IrArray::Index GetUnreducedOutputIndex(
+ const IrArray::Index& reduced_output_index,
+ const Shape& reduced_output_shape, const Shape& unreduced_output_shape,
+ llvm::IRBuilder<>* ir_builder);
+
+// A class to represent information for tiled parameters to support IR emission
+// for 021 transpose.
+class TiledParameterInfo {
+ public:
+ TiledParameterInfo(tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers,
+ llvm::Value* y, llvm::Value* x)
+ : param_buffers_(param_buffers), y_(y), x_(x) {}
+
+ llvm::Value* x() const { return x_; }
+ llvm::Value* y() const { return y_; }
+
+ void set_x(llvm::Value* x) { x_ = x; }
+ void set_y(llvm::Value* y) { y_ = y; }
+
+ llvm::Value* GetBufferForParameter(int64 index) const {
+ return param_buffers_[index];
+ }
+
+ private:
+ // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
+ // if the parameter is not tiled.
+ tensorflow::gtl::ArraySlice<llvm::Value*> param_buffers_;
+ // The y coordinate within a tile.
+ llvm::Value* y_;
+ // The x coordinate within a tile.
+ llvm::Value* x_;
+};
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_TILING_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 497b48ff22..c9ae7d3afd 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -34,7 +34,7 @@ namespace llvm_ir {
ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index,
- llvm::Value* step, bool prevent_unrolling,
+ llvm::Value* step, UnrollMode unroll_mode,
bool prevent_vectorization)
: prefix_(std::string(prefix)),
suffix_(std::string(suffix)),
@@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
end_index_(end_index),
step_(step),
insert_before_bb_(nullptr),
- prevent_unrolling_(prevent_unrolling),
+ unroll_mode_(unroll_mode),
prevent_vectorization_(prevent_vectorization) {}
/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling, bool prevent_vectorization) {
+ UnrollMode unroll_mode, bool prevent_vectorization) {
std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
- end_index, step, prevent_unrolling,
+ end_index, step, unroll_mode,
prevent_vectorization));
loop->Emit(ir_builder);
return loop;
@@ -97,7 +97,7 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
ir_builder->SetInsertPoint(&func->getEntryBlock(),
func->getEntryBlock().getFirstInsertionPt());
llvm::Value* indvar_address =
- ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr,
+ ir_builder->CreateAlloca(start_index_->getType(), nullptr,
AsStringRef(GetQualifiedName("invar_address")));
// Preheader basic block.
@@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
llvm::IRBuilder<>* ir_builder) {
const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
+ const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
llvm::LLVMContext* ctx = &start_index_->getContext();
std::vector<llvm::Metadata*> result;
- if (prevent_unrolling_) {
+ if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
result.push_back(llvm::MDNode::get(
*ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
}
@@ -162,6 +163,10 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
}
+ if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
+ result.push_back(llvm::MDNode::get(
+ *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
+ }
return result;
}
@@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
- return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
- prevent_unrolling, prevent_vectorization);
+ return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
+ unroll_mode, prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
llvm::Value* start_index,
llvm::Value* end_index,
llvm::Value* stride,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
if (inner_loop_body_bb_ != nullptr) {
// Create this loop inside the previous one.
ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
}
std::unique_ptr<ForLoop> loop(new ForLoop(
- /*prefix=*/name_, suffix, start_index, end_index, stride,
- prevent_unrolling, prevent_vectorization));
+ /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
+ prevent_vectorization));
loop->Emit(ir_builder_);
if (outer_loop_preheader_bb_ == nullptr) {
@@ -215,23 +220,23 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
- return AddLoop(suffix, ir_builder_->getInt64(start_index),
- ir_builder_->getInt64(end_index), prevent_unrolling,
+ return AddLoop(suffix, GetConstantWithIndexType(start_index),
+ GetConstantWithIndexType(end_index), unroll_mode,
prevent_vectorization);
}
std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
int64 end_index, int64 stride,
tensorflow::StringPiece suffix,
- bool prevent_unrolling,
+ UnrollMode unroll_mode,
bool prevent_vectorization) {
CHECK_LE(start_index, end_index);
- return AddLoop(suffix, ir_builder_->getInt64(start_index),
- ir_builder_->getInt64(end_index),
- ir_builder_->getInt64(stride), prevent_unrolling,
+ return AddLoop(suffix, GetConstantWithIndexType(start_index),
+ GetConstantWithIndexType(end_index),
+ GetConstantWithIndexType(stride), unroll_mode,
prevent_vectorization);
}
@@ -245,7 +250,7 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::StringPiece suffix) {
- llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr);
+ llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size());
for (int64 dimension : dimensions) {
std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
/*start_index=*/0,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index d915f95db1..0dd5b9d3b2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -34,6 +34,12 @@ limitations under the License.
namespace xla {
namespace llvm_ir {
+enum class UnrollMode {
+ kDefaultUnroll,
+ kFullyUnroll,
+ kNoUnroll,
+};
+
// A class for constructing a for-loop in LLVM IR.
class ForLoop {
public:
@@ -69,12 +75,13 @@ class ForLoop {
// LLVM IR. If non-empty, it is prepended to the name of the induction
// variable value and each basic block created for the loop.
//
- // If `prevent_unrolling` is true then emit metadata that directs LLVM to not
- // unroll the generated loop.
+ // `unroll_mode` specifies the desired LLVM unrolling behavior for generated
+ // loop.
static std::unique_ptr<ForLoop> EmitForLoop(
tensorflow::StringPiece prefix, llvm::Value* start_index,
llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
- bool prevent_unrolling = false, bool prevent_vectorization = false);
+ UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// The names of the blocks follow LLVM's conventions. Control flow amongst the
// blocks for the example C code looks like:
@@ -128,7 +135,7 @@ class ForLoop {
ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
- bool prevent_unrolling, bool prevent_vectorization);
+ UnrollMode unroll_mode, bool prevent_vectorization);
// Emit the loop at the insert point of the builder.
void Emit(llvm::IRBuilder<>* ir_builder);
@@ -161,7 +168,7 @@ class ForLoop {
llvm::BasicBlock* body_bb_;
llvm::BasicBlock* exit_bb_;
llvm::Value* indvar_;
- bool prevent_unrolling_;
+ UnrollMode unroll_mode_;
bool prevent_vectorization_;
TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
@@ -170,46 +177,52 @@ class ForLoop {
// A simple class for constructing nested for-loops.
class ForLoopNest {
public:
- explicit ForLoopNest(llvm::IRBuilder<>* ir_builder)
- : ForLoopNest(/*name=*/"", ir_builder) {}
+ explicit ForLoopNest(llvm::IRBuilder<>* ir_builder,
+ llvm::Type* index_ty = nullptr)
+ : ForLoopNest(/*name=*/"", ir_builder) {
+ SetIndexType(index_ty);
+ }
- ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder)
+ ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder,
+ llvm::Type* index_ty = nullptr)
: name_(std::string(name)),
outer_loop_preheader_bb_(nullptr),
outer_loop_exit_bb_(nullptr),
inner_loop_body_bb_(nullptr),
- ir_builder_(ir_builder) {}
+ ir_builder_(ir_builder) {
+ SetIndexType(index_ty);
+ }
// Adds a loop to the nest. If no loop has been added yet then emit a loop at
// the current insert point of the given builder. If one or more loops have
- // been added then emit loop inside the body of the last added loop. If
- // prevent_unrolling is true, then metadata is emitting directing LLVM to not
- // unroll this loop.
- std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index, llvm::Value* stride,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ // been added then emit loop inside the body of the last added loop.
+ // unroll_mode is used to emit metadata that controls LLVM unrolling.
+ std::unique_ptr<ForLoop> AddLoop(
+ tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* stride,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
- std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
- llvm::Value* start_index,
- llvm::Value* end_index,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// A convenient wrapper of the other flavor of AddLoop. The given start and
// end index are constant.
- std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
- int64 stride, tensorflow::StringPiece suffix,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ int64 start_index, int64 end_index, int64 stride,
+ tensorflow::StringPiece suffix,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Like the above, except that it defaults to a stride of one.
- std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
- tensorflow::StringPiece suffix,
- bool prevent_unrolling = false,
- bool prevent_vectorization = false);
+ std::unique_ptr<ForLoop> AddLoop(
+ int64 start_index, int64 end_index, tensorflow::StringPiece suffix,
+ UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll,
+ bool prevent_vectorization = false);
// Add loops to iterate through the indices within the specified
// shape. The returned index collects the induction variables of the
@@ -245,6 +258,14 @@ class ForLoopNest {
llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; }
private:
+ void SetIndexType(llvm::Type* index_ty) {
+ index_type_ = index_ty == nullptr ? ir_builder_->getInt64Ty() : index_ty;
+ }
+
+ llvm::Constant* GetConstantWithIndexType(int64 c) const {
+ return llvm::ConstantInt::get(index_type_, c);
+ }
+
// Human-friendly name of the loop nest.
string name_;
@@ -259,6 +280,8 @@ class ForLoopNest {
llvm::IRBuilder<>* ir_builder_;
+ llvm::Type* index_type_;
+
TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest);
};
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index ec04239b4f..6c55361b44 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -87,18 +88,10 @@ llvm::Value* EmitCallToIntrinsic(
tensorflow::gtl::ArraySlice<llvm::Value*> operands,
tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
llvm::IRBuilder<>* ir_builder) {
- std::vector<llvm::Type*> types;
- for (auto type : overloaded_types) {
- types.push_back(type);
- }
llvm::Module* module = ModuleFromIRBuilder(ir_builder);
- llvm::Function* intrinsic =
- llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
- std::vector<llvm::Value*> operands_vec;
- for (auto operand : operands) {
- operands_vec.push_back(operand);
- }
- return ir_builder->CreateCall(intrinsic, operands_vec);
+ llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(
+ module, intrinsic_id, AsArrayRef(overloaded_types));
+ return ir_builder->CreateCall(intrinsic, AsArrayRef(operands));
}
llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
@@ -201,6 +194,10 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
// An Opaque is like a void*, use i8*.
case OPAQUE:
return llvm::Type::getInt8PtrTy(module->getContext());
+ case TOKEN:
+ // Tokens do not have a physical representation, but the compiler needs
+ // some placeholder type, so use int8*.
+ return llvm::Type::getInt8PtrTy(module->getContext());
default:
LOG(FATAL) << "unsupported type " << element_type;
}
@@ -253,130 +250,14 @@ StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
return shape;
}
-namespace {
-
-// Recursively construct a multidimensional LLVM constant which represents the
-// given literal. The minor-to-major dimension ordering in the constant matches
-// that of the literal. For example, given a [2 x 3 x 4] Literal (dimension 0
-// has size 4, dimension 1 has size 3, etc) of primitive type F32 with a
-// minor_to_major value of [2, 1, 0] (column major), a LLVM constant of type
-// [4 x [3 x [2 x float]] will be returned.
-//
-// multi_index is a multidimensional index into the array. dimension_index is an
-// index into the minor_to_major field in the literal shape. This determines
-// which dimension is iterated over in this level of the recursion. Dimensions
-// are iterated from most major down to most minor (highest dimension_index
-// value down to zero).
-llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
- std::vector<int64>* multi_index,
- llvm::Module* module) {
- const Shape& shape = literal.shape();
- llvm::Type* ir_element_type =
- llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module);
- if (dimension_index == -1) {
- // Base case of the recursion. Index into the data field of the protobuf
- // with the multi index.
- llvm::Constant* value;
- switch (shape.element_type()) {
- case PRED:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<bool>(*multi_index));
- break;
- case U8:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint8>(*multi_index));
- break;
- case S32:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<int32>(*multi_index));
- break;
- case U32:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint32>(*multi_index));
- break;
- case S64:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<int64>(*multi_index));
- break;
- case U64:
- value = llvm::ConstantInt::get(ir_element_type,
- literal.Get<uint64>(*multi_index));
- break;
- case F32:
- value = llvm::ConstantFP::get(ir_element_type,
- literal.Get<float>(*multi_index));
- break;
- case BF16:
- value = llvm::ConstantInt::get(
- ir_element_type,
- tensorflow::bit_cast<uint16>(literal.Get<bfloat16>(*multi_index)));
- break;
- case F16:
- value = llvm::ConstantFP::get(
- ir_element_type,
- static_cast<float>(literal.Get<half>(*multi_index)));
- break;
- case F64:
- value = llvm::ConstantFP::get(ir_element_type,
- literal.Get<double>(*multi_index));
- break;
- case C64: {
- complex64 x = literal.Get<complex64>(*multi_index);
- value = llvm::ConstantStruct::get(
- static_cast<llvm::StructType*>(ir_element_type),
- llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
- x.real()),
- llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module),
- x.imag()));
- break;
- }
- default:
- LOG(FATAL) << "unsupported type " << shape.element_type();
- }
- return value;
- }
-
- // The dimension index starts at the one less than the rank of the array and
- // decrements with each recursive call. We want to iterate through the
- // dimensions in major-to-minor order as we recurse so just index into
- // minor_to_major to get the dimension number for this level of the recursion.
- int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index);
-
- // Recursively call LiteralToConstant to construct subarrays for the
- // more-minor dimensions. Gather the subarrays into a vector for bundling into
- // a new (higher-dimensional) ConstantArray.
- std::vector<llvm::Constant*> elements;
- for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
- (*multi_index)[dimension] = i;
- elements.push_back(
- LiteralToConstant(literal, dimension_index - 1, multi_index, module));
- }
-
- llvm::Type* element_type;
- if (elements.empty()) {
- element_type = ir_element_type;
- for (int i = 0; i < dimension_index; ++i) {
- int64 index = LayoutUtil::Minor(shape.layout(), i);
- element_type =
- llvm::ArrayType::get(element_type, shape.dimensions(index));
- }
- } else {
- element_type = elements[0]->getType();
- }
- llvm::ArrayType* aggregate_type =
- llvm::ArrayType::get(element_type, shape.dimensions(dimension));
- return llvm::ConstantArray::get(aggregate_type, elements);
-}
-
-} // namespace
-
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
llvm::Module* module) {
- std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
- llvm::Constant* value = LiteralToConstant(
- literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
- &multi_index, module);
- return value;
+ const char* data = static_cast<const char*>(literal.untyped_data());
+ CHECK_EQ(module->getDataLayout().isLittleEndian(),
+ tensorflow::port::kLittleEndian);
+ return llvm::ConstantDataArray::getString(
+ module->getContext(), llvm::StringRef(data, literal.size_bytes()),
+ /*AddNull=*/false);
}
llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 4a10ec466d..9c51861eac 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -27,7 +27,7 @@ limitations under the License.
#include "llvm/IR/Module.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/raw_ostream.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 0728ccfff7..e8b0605b9d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -83,16 +83,19 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
// Sanity check: In multi-output fusion, all shapes produced must have the
// same dimensions.
for (const IrArray& array : target_arrays) {
- CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()));
+ CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()))
+ << ": '" << shape_.ShortDebugString() << "' does not match '"
+ << array.GetShape().ShortDebugString() << "'";
}
}
std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name) {
+ tensorflow::StringPiece loop_name, llvm::Type* index_type) {
+ CHECK_NE(index_type, nullptr);
if (ShapeUtil::IsScalar(shape_)) {
// No loop needed, so set exit_bb_ to nullptr.
exit_bb_ = nullptr;
- return {IrArray::Index()};
+ return {IrArray::Index(index_type)};
}
// Create loop nest with one for-loop for each dimension of the target shape.
@@ -100,7 +103,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
// class so emit loops in order from most-major dimension down to most-minor
// dimension (of the target shape).
ForLoopNest loop_nest(loop_name, ir_builder_);
- IrArray::Index array_index(shape_.dimensions_size());
+ IrArray::Index array_index(index_type, shape_.dimensions_size());
for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
int64 dimension = LayoutUtil::Major(shape_.layout(), i);
std::unique_ptr<ForLoop> loop = loop_nest.AddLoop(
@@ -123,9 +126,14 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
return {array_index};
}
-Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) {
+Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name,
+ llvm::Type* index_type) {
+ if (index_type == nullptr) {
+ index_type = ir_builder_->getInt64Ty();
+ }
+
for (const IrArray::Index& array_index :
- EmitIndexAndSetExitBasicBlock(loop_name)) {
+ EmitIndexAndSetExitBasicBlock(loop_name, index_type)) {
TF_RETURN_IF_ERROR(body_emitter_(array_index));
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index b70d28ecd3..6be1c2fba2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -65,13 +65,16 @@ class LoopEmitter {
// specifies the element, will return multiple indices if the loop is
// unrolled.
std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock() {
- return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"");
+ return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"",
+ ir_builder_->getInt64Ty());
}
+
virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
- tensorflow::StringPiece loop_name);
+ tensorflow::StringPiece loop_name, llvm::Type* index_type);
// Emits a complete loop nest for every element in the given shape.
- Status EmitLoop(tensorflow::StringPiece loop_name = "");
+ Status EmitLoop(tensorflow::StringPiece loop_name = "",
+ llvm::Type* index_type = nullptr);
protected:
// An IR emitter that generates the loop body.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
index dacc54742c..3b298f4746 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
@@ -45,7 +45,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
// Read start indices from start_indices_generator.
const int64 rank = ShapeUtil::Rank(output_shape);
- IrArray::Index start_index(rank);
+ IrArray::Index start_index(ir_builder->getInt64Ty(), rank);
for (int64 i = 0; i < rank; ++i) {
IrArray::Index dim_index({ir_builder->getInt64(i)});
TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index));
@@ -79,7 +79,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
//
// output_index[dim] = start_index[dim] + update_index[dim]
//
- IrArray::Index output_index(rank);
+ IrArray::Index output_index(start_index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast(
start_index[i], update_index[i]->getType());
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 375c4a6780..53efc30c36 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -108,6 +107,11 @@ ExecutionOptions CreateExecutionOptions(
->set_xla_dump_optimized_hlo_proto_to(
build_options.dump_optimized_hlo_proto_to().value());
}
+ if (build_options.dump_unoptimized_hlo_proto_to().has_value()) {
+ execution_options.mutable_debug_options()
+ ->set_xla_dump_unoptimized_hlo_proto_to(
+ build_options.dump_unoptimized_hlo_proto_to().value());
+ }
if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
execution_options.mutable_debug_options()
->set_xla_dump_per_pass_hlo_proto_to(
@@ -150,7 +154,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
for (int i = 0; i < argument_layouts.size(); ++i) {
const Shape& argument_shape = *argument_layouts[i];
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
tensorflow::gtl::optional<const OpMetadata*> metadata =
ParameterMetadata(computation, /*parameter_number=*/i);
@@ -174,8 +179,8 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
}
}
if (build_options.result_layout() != nullptr) {
- TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(
- *build_options.result_layout(), program_shape.result()));
+ TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
+ program_shape.result()));
}
ExecutionOptions execution_options =
@@ -185,6 +190,9 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, argument_layouts, &execution_options));
+ VLOG(3) << "Computation Layout: "
+ << module_config->entry_computation_layout().ToString();
+
TF_ASSIGN_OR_RETURN(
se::StreamExecutor * executor,
execute_backend_->stream_executor(build_options.device_ordinal()));
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index f410921b4b..d631fb5ee4 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -131,18 +131,23 @@ Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
return Status::OK();
}
-Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
- // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
- // tuple element at {0} to its output.
+Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
+ // RecvDone produces a two-element tuple containing the data value (which
+ // aliases part of its operand) and a token. Only the tuple index table and
+ // the token are defined by the RecvDone.
+ NewLogicalBuffer(recv_done, /*index=*/{});
+ NewLogicalBuffer(recv_done, /*index=*/{1});
return Status::OK();
}
Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
- // Send creates new buffers for the top-level tuple and the context (tuple
- // element at {1}). Tuple element at {0} is an alias of the Send operand, so
- // we don't need to create a new Logical Buffer for that.
+ // Send creates new buffers for the top-level tuple, the context (tuple
+ // element at {1}), and the token (tuple element at {2}). Tuple element at {0}
+ // is an alias of the Send operand, so we don't need to create a new Logical
+ // Buffer for that.
NewLogicalBuffer(send, /*index=*/{});
NewLogicalBuffer(send, /*index=*/{1});
+ NewLogicalBuffer(send, /*index=*/{2});
return Status::OK();
}
@@ -152,10 +157,10 @@ Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select) {
+Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
// Select allocates a new buffer and then shallow copies the on_true or
// on_false buffer into this new buffer.
- NewLogicalBuffer(select, /*index=*/{});
+ NewLogicalBuffer(tuple_select, /*index=*/{});
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index b5ef396787..81f524d84a 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -63,7 +63,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
- Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
// A map from the buffer ID to the logical buffer
std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_;
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
new file mode 100644
index 0000000000..4166ef5baf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -0,0 +1,338 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
+ bool changed = false;
+
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ computation_ = computation;
+ RecomputeReachability();
+ candidates_.clear();
+ candidates_index_.clear();
+ all_fusion_candidates_.clear();
+
+ int64 index = 0;
+ for (auto it : computation_->MakeInstructionPostOrder()) {
+ candidates_.emplace_back(it);
+ InsertOrDie(&candidates_index_, it, index++);
+ }
+
+ // Create the initial candidate list for each Node.
+ for (auto& node : candidates_) {
+ HloInstruction* instruction = node.hlo;
+ int64 instruction_id = get_candidate_id(instruction);
+ FusionCandidate& instr_node = candidates_[instruction_id];
+ if (!IsFusible(instruction)) {
+ continue;
+ }
+ all_fusion_candidates_.push_back(instruction);
+
+ std::vector<HloInstruction*> candidates;
+ tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
+ VLOG(10) << "Looking at instruction: " << instruction->name();
+ for (auto operand : instruction->operands()) {
+ // Filter out the non-interesting instructions -- they
+ // will not generate the savings.
+ if (!IsProfitableOperand(operand)) {
+ VLOG(10) << "Operand not profitable: " << operand->name();
+ continue;
+ }
+ VLOG(10) << "Operand profitable: " << operand->name();
+ for (auto user : operand->users()) {
+ VLOG(10) << "User: " << user->name();
+ if (user == instruction || !IsFusible(user)) {
+ VLOG(10) << "User is not fusible, or is the instruction itself: "
+ << user->name();
+ continue;
+ }
+ int64 user_id = get_candidate_id(user);
+ if (is_connected(instruction, user)) {
+ VLOG(10) << "User is connected: " << user->name();
+ continue;
+ }
+ if (instruction_id < user_id &&
+ user->opcode() == HloOpcode::kFusion) {
+ VLOG(10) << "User ID for user: " << user->name() << " is "
+ << user_id << " which is higher than " << instruction_id;
+ continue;
+ }
+ if (!LegalToFuse(instruction, user)) {
+ VLOG(10) << "User not legal to fuse: " << user->name();
+ continue;
+ }
+ if (candidates_set.insert(user).second) {
+ VLOG(10) << "User added to candidate list: " << user->name();
+ candidates.push_back(user);
+ }
+ }
+ }
+
+ // Iterate over candidates rather than candidates_set to avoid
+ // nondeterminism.
+ for (auto candidate : candidates) {
+ int64 profit = GetProfit(instruction, candidate);
+ if (profit > 0) {
+ FusionCandidate& candidate_node =
+ candidates_[get_candidate_id(candidate)];
+ instr_node.fusibles.emplace_back(candidate, profit);
+ candidate_node.fusibles.emplace_back(instruction, profit);
+ worklist_.emplace(instruction, candidate, profit);
+ }
+ }
+ }
+ if (Perform()) {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ HloInstruction* remaining = instr1;
+ HloInstruction* fused = instr2;
+ // Make sure that if only one of the instructions is a fusion, or if only one
+ // of the instructions is a multi-output fusion, it's what will be fused into.
+ if (fused->opcode() == HloOpcode::kFusion) {
+ std::swap(remaining, fused);
+ }
+ if (fused->IsMultiOutputFusion()) {
+ std::swap(remaining, fused);
+ }
+
+ if (fused->opcode() == HloOpcode::kFusion) {
+ remaining->MergeFusionInstructionIntoMultiOutput(fused);
+ } else {
+ remaining->FuseInstructionIntoMultiOutput(fused);
+ }
+ return remaining;
+}
+
+bool MultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
+ // kConstant instruction will not have memory reads, so it won't be a profit
+ // source. Skip them.
+ if (instr->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsEffectiveScalar(instr->shape())) {
+ return false;
+ }
+ // We don't target to fuse producer/consumer instructions -- this should
+ // be taken care of by the instruction_fusion pass. If instr has only
+ // one user, it will not have sibling instructions. We won't consider it.
+ if (instr->user_count() < 2) {
+ return false;
+ }
+ return true;
+}
+
+void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
+ HloInstruction* fusion = instr1;
+ HloInstruction* fused = instr2;
+ if (is_fused(instr1)) {
+ fusion = instr2;
+ fused = instr1;
+ }
+
+ // Insert the newly created instruction (if any), to candidates_.
+ for (auto use : fusion->users()) {
+ if (candidates_index_.find(use) == candidates_index_.end()) {
+ int64 index = candidates_.size();
+ candidates_.emplace_back(use);
+ InsertOrDie(&candidates_index_, use, index++);
+ }
+ }
+ FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)];
+ FusionCandidate& fused_node = candidates_[get_candidate_id(fused)];
+
+ // Update the reachability graph.
+ UpdateReachability(fusion, fused, all_fusion_candidates_,
+ [this](HloInstruction* instr) { return is_fused(instr); });
+
+ // Update the fusible list for fusion. Variable new_fusibles keeps
+ // track of the new or changed entries.
+ std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
+ tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ auto it = fusion_node.fusibles.begin();
+ while (it != fusion_node.fusibles.end()) {
+ HloInstruction* instr = it->first;
+ if (is_fused(instr) || is_connected(fusion, instr)) {
+ it = fusion_node.fusibles.erase(it);
+ continue;
+ }
+ in_list.insert(instr);
+ int64 profit = GetProfit(instr, fusion);
+ if (profit > it->second) {
+ it->second = profit;
+ new_fusibles.emplace_back(instr, profit);
+ }
+ ++it;
+ }
+
+ // Fused_node has been fused into fusion_node. Take the fusion candidates
+ // (fusibles) from fused_nodes and add them to the fusion_node's. Filter
+ // out those fusibles that no longer valid (or already in the list).
+ for (const auto& it : fused_node.fusibles) {
+ HloInstruction* instr = it.first;
+ if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) {
+ continue;
+ }
+ if (in_list.count(instr) > 0) {
+ continue;
+ }
+ int64 profit = GetProfit(instr, fusion);
+ fusion_node.fusibles.emplace_back(instr, profit);
+ new_fusibles.emplace_back(instr, profit);
+ }
+ fused_node.fusibles.clear();
+
+ // Update the worklist_.
+ for (auto it : new_fusibles) {
+ worklist_.emplace(fusion, it.first, it.second);
+ }
+}
+
+bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1,
+ HloInstruction* instr2) {
+ if (instr1 == instr2) {
+ return false;
+ }
+ if (instr1->opcode() != HloOpcode::kFusion) {
+ return false;
+ }
+
+ // Fusing nodes with 0 user makes no sense and the rest of the implementation
+ // doesn't support it either.
+ if (instr1->user_count() == 0 || instr2->user_count() == 0) {
+ return false;
+ }
+
+ // Check if the users of multioutput fusion is not a get-tuple-element.
+ // If this is the case, we bail out because the transformation assumes
+ // the users are get-tuple-element.
+ auto multioutput_user_is_not_gte = [](HloInstruction* instr) {
+ if (!instr->IsMultiOutputFusion()) {
+ return false;
+ }
+ for (auto user : instr->users()) {
+ if (user->opcode() != HloOpcode::kGetTupleElement) {
+ return true;
+ }
+ }
+ return false;
+ };
+ if (multioutput_user_is_not_gte(instr1) ||
+ multioutput_user_is_not_gte(instr2)) {
+ return false;
+ }
+
+ if (is_connected(instr1, instr2)) {
+ return false;
+ }
+ if (!ShapesCompatibleForFusion(instr1, instr2)) {
+ return false;
+ }
+
+ return true;
+}
+
+void MultiOutputFusion::RecomputeReachability() {
+ reachability_ = computation_->ComputeReachability();
+}
+
+void MultiOutputFusion::UpdateReachability(
+ HloInstruction* instr1, HloInstruction* instr2,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ const std::function<bool(HloInstruction*)>& skip) {
+ for (auto instr : instrs_to_update) {
+ if (skip != nullptr && skip(instr)) {
+ continue;
+ }
+ if (reachability_->IsReachable(instr2, instr) &&
+ reachability_->IsReachable(instr1, instr)) {
+ // If a candidate was already reachable by both, no update needed.
+ continue;
+ }
+ if (reachability_->IsReachable(instr2, instr)) {
+ reachability_->FastSetReachabilityToUnion({instr, instr1}, instr);
+ }
+ if (reachability_->IsReachable(instr1, instr)) {
+ reachability_->FastSetReachabilityToUnion({instr, instr2}, instr);
+ }
+ }
+}
+
+bool MultiOutputFusion::Perform() {
+ int changed = false;
+ // Pick the top candidate from queue and try to merge.
+ while (!worklist_.empty()) {
+ if (fuel_ <= 0) {
+ VLOG(2) << "No fusing: run out of fuel.";
+ break;
+ }
+ ToBeFused candidate = worklist_.top();
+ worklist_.pop();
+
+ HloInstruction* instr1 = candidate.instr1;
+ HloInstruction* instr2 = candidate.instr2;
+
+ if (is_fused(instr1) || is_fused(instr2)) {
+ continue;
+ }
+
+ VLOG(1) << "Considering candidate profit_score=" << candidate.score
+ << "\n\t\tinstr1 = " << instr1->ToString()
+ << "\n\t\tinstr2 = " << instr2->ToString();
+
+ if (LegalToFuse(instr1, instr2)) {
+ VLOG(1) << "Fuse!";
+ VLOG(2) << "Before multi_output_fusion:";
+ VLOG(2) << "instr1: " << instr1->ToString();
+ VLOG(2) << "\n"
+ << instr1->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ VLOG(2) << "instr2: " << instr2->ToString();
+ if (instr2->opcode() == HloOpcode::kFusion) {
+ VLOG(2) << "\n"
+ << instr2->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ }
+ HloInstruction* ret = Fuse(instr1, instr2);
+ set_is_fused(ret == instr1 ? instr2 : instr1);
+ Update(instr1, instr2);
+ changed = true;
+ VLOG(2) << "After fusion, \t this: " << ret->name() << "\n"
+ << ret->fused_instructions_computation()->ToString(
+ HloPrintOptions().set_indent_amount(1));
+ auto users = ret->users();
+ --fuel_;
+ }
+ }
+ if (DoProducerConsumerMultiOutputFusion()) {
+ changed = true;
+ }
+ return changed;
+}
+
+bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; }
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
new file mode 100644
index 0000000000..0019cd7254
--- /dev/null
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -0,0 +1,169 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
+
+#include <queue>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// This class implements the fusing of sibling fusion instructions that sharing
+// common operands.
+// It constructs the following associated data structures.
+// (1) candidates_: stores the instruction and the set of instructions it can
+// fuse to.
+// (2) candidates_index_: maps instruction to id.
+// (3) reachability_: reachability map in this computation.
+// (4) all_fusion_candidates_: the vector of candidate instructions.
+// (5) worklist_: a priority queue that contains pairs of instructions to be
+// fused and their fusion profit scores.
+//
+// Function Perform() applies the optimization. It picks up the most profitable
+// pair in the worklist_, check if it's legal to fuse and fuse the pair.
+// After fusion, it updates the associated structure such as reachability_,
+// candidates_ and worklist_.
+// Note that the reachability map is updated based on the original computation.
+// This works because the reachability is monotonically increasing with
+// instruction fusion.
+class MultiOutputFusion : public HloPassInterface {
+ public:
+ MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
+
+ tensorflow::StringPiece name() const override {
+ return "multi_output_fusion";
+ }
+
+ // Run multi-output fusion on the given module. Returns whether the module
+ // was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // Main entry for the optimization. Returns true if the optimization happens.
+ bool Perform();
+
+ // Test if instr1 and instr2 have the compatible shapes that can be legally
+ // fused.
+ virtual bool ShapesCompatibleForFusion(HloInstruction* instr1,
+ HloInstruction* instr2) = 0;
+
+ // Whether the instruction is a candidate for fusion.
+ virtual bool IsFusible(HloInstruction* instr) = 0;
+
+ // This function estimates the savings by merging instr1 and instr2 into one
+ // multi-output fusion instruction.
+ virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0;
+
+ // Whether fusing the instruction can reduce memory reads.
+ virtual bool IsProfitableOperand(HloInstruction* instr);
+
+ // Test if it's legal to fuse instr1 and instr2 into one fusion instruction.
+ virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction.
+ // The other instruction is removed from its parent computation.
+ virtual HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Recompute reachability for the current computation.
+ void RecomputeReachability();
+
+ // Returns the reachability map for the current computation.
+ HloReachabilityMap* reachability() const { return reachability_.get(); }
+
+ // Returns the computation for the pass.
+ HloComputation* computation() const { return computation_; }
+
+ // Update the reachability map after fusing instr1 and instr2.
+ void UpdateReachability(
+ HloInstruction* instr1, HloInstruction* instr2,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instrs_to_update,
+ const std::function<bool(HloInstruction*)>& skip = nullptr);
+
+ // Hook for multi-output fusion along producer-consumer edges.
+ // Returns whether any instructions were fused.
+ //
+ // TODO(b/80420762): Perform producer-consumer multi-output fusion in
+ // InstructionFusion instead.
+ virtual bool DoProducerConsumerMultiOutputFusion();
+
+ private:
+ // Update the internal data structures after instr1 and instr2 are fused into
+ // one fusion instruction.
+ void Update(HloInstruction* instr1, HloInstruction* instr2);
+
+ // Optimization fuel is a compiler debugging technique that makes an
+ // optimization pass stop what it is doing after having made N changes to the
+ // program, where N is the fuel. By varying N, this can be used to find the
+ // first single change that makes a test fail.
+ int64 fuel_;
+
+ // Computation for the pass.
+ HloComputation* computation_;
+
+ // An internal data structure for each instruction in current computation.
+ // When an instruction is removed, member 'hlo' is set to nullptr.
+ struct FusionCandidate {
+ HloInstruction* hlo;
+ std::list<std::pair<HloInstruction*, int64>> fusibles;
+ explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {}
+ };
+ std::vector<FusionCandidate> candidates_;
+
+ // A map that maps an instruction to the index_.
+ tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_;
+
+ // The reachability map of current computation.
+ std::unique_ptr<HloReachabilityMap> reachability_;
+
+ // This stores all the candidate instructions in current computation.
+ std::vector<HloInstruction*> all_fusion_candidates_;
+
+ // The pair of candidates to be fused and the profit score.
+ struct ToBeFused {
+ HloInstruction* instr1;
+ HloInstruction* instr2;
+ int64 score;
+ ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score)
+ : instr1(instr1), instr2(instr2), score(score) {}
+ bool operator<(const ToBeFused& rhs) const { return score < rhs.score; }
+ };
+ std::priority_queue<ToBeFused> worklist_;
+
+ int64 get_candidate_id(HloInstruction* instr) {
+ return FindOrDie(candidates_index_, instr);
+ }
+
+ bool is_fused(HloInstruction* instr) {
+ return candidates_[get_candidate_id(instr)].hlo == nullptr;
+ }
+
+ void set_is_fused(HloInstruction* instr) {
+ candidates_[get_candidate_id(instr)].hlo = nullptr;
+ }
+
+ bool is_connected(HloInstruction* instr1, HloInstruction* instr2) {
+ return reachability_->IsConnected(instr1, instr2);
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index 3a6a7c25f4..f6e7578a89 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -67,22 +67,17 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
has_numeric_suffix = true;
// Remove numeric suffix from root.
root = root.substr(0, separator_index);
- // Update count to at least the numeric suffix value to avoid future
- // colisions with this name.
- generated_names_[root] = std::max(generated_names_[root], numeric_suffix);
}
}
- int64* count = &(generated_names_[root]);
- if (*count == 0) {
- *count = 1;
+
+ SequentialIdGenerator& id_generator = generated_names_[root];
+ numeric_suffix = id_generator.RegisterId(numeric_suffix);
+ if (numeric_suffix == 0) {
return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0)
: root;
- } else {
- tensorflow::strings::StrAppend(&root, separator_, *count);
- // Increment lookup under old 'root' name.
- (*count)++;
- return root;
}
+ tensorflow::strings::StrAppend(&root, separator_, numeric_suffix);
+ return root;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 4139c2700b..4423d61069 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -17,10 +17,11 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_
#include <string>
-#include <unordered_map>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -44,13 +45,40 @@ class NameUniquer {
static string GetSanitizedName(const string& name);
private:
+ // Used to track and generate new identifiers for the same instruction name
+ // root.
+ class SequentialIdGenerator {
+ public:
+ SequentialIdGenerator() = default;
+
+ // Tries to register id as used identifier. If id is not already used, the
+ // id itself will be returned. Otherwise a new one will be generated, and
+ // returned.
+ int64 RegisterId(int64 id) {
+ if (used_.insert(id).second) {
+ return id;
+ }
+ while (!used_.insert(next_).second) {
+ ++next_;
+ }
+ return next_++;
+ }
+
+ private:
+ // The next identifier to be tried.
+ int64 next_ = 0;
+
+ // Set of all the identifiers which has been used.
+ tensorflow::gtl::FlatSet<int64> used_;
+ };
+
// The string to use to separate the prefix of the name from the uniquing
// integer value.
string separator_;
- // Map from name prefix to the number of names generated using that prefix
- // so far.
- std::unordered_map<string, int64> generated_names_;
+ // Map from name prefix to the generator data structure which tracks used
+ // identifiers and generates new ones.
+ tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_;
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
};
diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc
index 2ec255558c..3e2592c6ac 100644
--- a/tensorflow/compiler/xla/service/name_uniquer_test.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc
@@ -54,12 +54,13 @@ TEST_F(NameUniquerTest, NumericSuffixes) {
EXPECT_EQ("foo", uniquer.GetUniqueName("foo"));
EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54"));
- EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo"));
EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1"));
- EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1"));
- EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000"));
- EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000"));
- EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1"));
+ EXPECT_EQ("foo.55.0", uniquer.GetUniqueName("foo.55.1"));
+ EXPECT_EQ("bar.1000", uniquer.GetUniqueName("bar.1000"));
+ EXPECT_EQ("bar.2000", uniquer.GetUniqueName("bar.2000"));
+ EXPECT_EQ("bar.-2000", uniquer.GetUniqueName("bar.-2000"));
+ EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.1"));
}
TEST_F(NameUniquerTest, PrefixHasSuffix) {
@@ -77,12 +78,12 @@ TEST_F(NameUniquerTest, Sanitize) {
EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54"));
EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54"));
EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1"));
- EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo_2", uniquer.GetUniqueName("foo"));
// Invalid characters will be replaced with '_'.
- EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000"));
- EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000"));
- EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1"));
+ EXPECT_EQ("bar_1000", uniquer.GetUniqueName("bar<1000"));
+ EXPECT_EQ("bar_2000", uniquer.GetUniqueName("bar<2000"));
+ EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar_1"));
// Separator is only recognized in the middle of the prefix.
EXPECT_EQ("_10", uniquer.GetUniqueName(
@@ -93,5 +94,15 @@ TEST_F(NameUniquerTest, Sanitize) {
EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_"));
}
+TEST_F(NameUniquerTest, KeepNamesInRandomOrder) {
+ NameUniquer uniquer(".");
+
+ EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11"));
+ EXPECT_EQ("foo.10", uniquer.GetUniqueName("foo.10"));
+ EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo.1"));
+ EXPECT_EQ("foo.12", uniquer.GetUniqueName("foo.12"));
+ EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo.3"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index d3bc47e61e..ac6ea4c72f 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -86,8 +86,8 @@ namespace xla {
// are provided below.
//
// Example nullary instruction:
-// Recv() == Op().WithOpcode(HloOpcode::kRecv)
-// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv)
+// Param() == Op().WithOpcode(HloOpcode::kParam)
+// Param(&a) == Op(&a).WithOpcode(HloOpcode::kParam)
//
// Example unary instruction:
// Abs() == Op().WithOpcode(HloOpcode::kAbs)
@@ -204,7 +204,7 @@ class LayoutPattern {
// Modifies the pattern to match only if the layout equals the given proto.
// The layout must outlive the returned pattern.
constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo(
- const Layout* layout) const {
+ const ::xla::Layout* layout) const {
return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>(
LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_);
}
@@ -726,6 +726,32 @@ class HloInstructionPatternFusionKindImpl {
::xla::HloInstruction::FusionKind kind_;
};
+// An HloInstructionPattern implementation that matches only if the instruction
+// is a kGetTupleElement with a particular tuple index.
+template <typename Previous>
+class HloInstructionPatternTupleIndexImpl {
+ public:
+ explicit constexpr HloInstructionPatternTupleIndexImpl(
+ const Previous& previous, int64 tuple_index)
+ : previous_(previous), tuple_index_(tuple_index) {}
+
+ bool Match(const ::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) &&
+ inst->opcode() == HloOpcode::kGetTupleElement &&
+ inst->tuple_index() == tuple_index_;
+ }
+
+ bool Match(::xla::HloInstruction* inst) const {
+ return previous_.Match(inst) &&
+ inst->opcode() == HloOpcode::kGetTupleElement &&
+ inst->tuple_index() == tuple_index_;
+ }
+
+ private:
+ Previous previous_;
+ int64 tuple_index_;
+};
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
@@ -841,6 +867,17 @@ class HloInstructionPattern {
HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
}
+ // Modifies the pattern to match only if the instruction is a
+ // get-tuple-element with the given tuple index.
+ constexpr HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternTupleIndexImpl<Impl>>
+ WithTupleIndex(int64 tuple_index) const {
+ return HloInstructionPattern<HloInstructionType,
+ HloInstructionPatternTupleIndexImpl<Impl>>(
+ HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index),
+ matched_inst_);
+ }
+
private:
Impl impl_;
HloInstructionType** matched_inst_;
@@ -880,9 +917,7 @@ Op(::xla::HloInstruction** matched_inst) {
return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \
}
XLA_NULLOP_PATTERN(Constant)
-XLA_NULLOP_PATTERN(Infeed)
XLA_NULLOP_PATTERN(Parameter)
-XLA_NULLOP_PATTERN(Recv)
#undef XLA_NULLOP_PATTERN
// Helpers for unary instructions.
@@ -919,18 +954,21 @@ XLA_UNOP_PATTERN(Cos)
XLA_UNOP_PATTERN(Exp)
XLA_UNOP_PATTERN(Fft)
XLA_UNOP_PATTERN(Floor)
+XLA_UNOP_PATTERN(GetTupleElement)
XLA_UNOP_PATTERN(Imag)
+XLA_UNOP_PATTERN(Infeed)
XLA_UNOP_PATTERN(IsFinite)
XLA_UNOP_PATTERN(Log)
XLA_UNOP_PATTERN(Not)
XLA_UNOP_PATTERN(Negate)
-XLA_UNOP_PATTERN(Outfeed)
XLA_UNOP_PATTERN(Real)
+XLA_UNOP_PATTERN(Recv)
+XLA_UNOP_PATTERN(RecvDone)
XLA_UNOP_PATTERN(Reduce)
XLA_UNOP_PATTERN(ReducePrecision)
XLA_UNOP_PATTERN(Reshape)
XLA_UNOP_PATTERN(Reverse)
-XLA_UNOP_PATTERN(Send)
+XLA_UNOP_PATTERN(SendDone)
XLA_UNOP_PATTERN(Sign)
XLA_UNOP_PATTERN(Sin)
XLA_UNOP_PATTERN(Sort)
@@ -981,8 +1019,10 @@ XLA_BINOP_PATTERN(Maximum)
XLA_BINOP_PATTERN(Minimum)
XLA_BINOP_PATTERN(Multiply)
XLA_BINOP_PATTERN(Ne)
+XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
+XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
XLA_BINOP_PATTERN(And)
XLA_BINOP_PATTERN(Or)
@@ -1040,6 +1080,32 @@ inline auto NonConstant(HloInstructionType** matched_inst)
return Op(matched_inst).IsNonConstant();
}
+// Add overloads for GetTupleElement which take a int64 specifying which tuple
+// element is selected.
+template <typename Arg>
+inline auto GetTupleElement(Arg&& arg, int64 tuple_index)
+ -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index)) {
+ return Op()
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index);
+}
+
+template <typename HloInstructionType, typename Arg>
+inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
+ int64 tuple_index)
+ -> decltype(Op(matched_inst)
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index)) {
+ return Op(matched_inst)
+ .WithOpcode(HloOpcode::kGetTupleElement)
+ .WithOperand(0, std::forward<Arg>(arg))
+ .WithTupleIndex(tuple_index);
+}
+
} // namespace match
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index 204e8c9920..a530581c34 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) {
ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
const HloInstruction* matched_inst;
HloInstruction* matched_operand;
@@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) {
p0 = f32[] parameter(0)
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
})";
- TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(
@@ -193,5 +193,23 @@ TEST(PatternMatcherTest, FusionKind) {
HloInstruction::FusionKind::kLoop)));
}
+TEST(PatternMatcherTest, GetTupleElement) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+
+ ENTRY while.v11 {
+ p0 = (f32[], f32[], f32[]) parameter(0)
+ ROOT gte = f32[] get-tuple-element(p0), index=1
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(0)));
+ EXPECT_TRUE(Match(root, match::Op().WithTupleIndex(1)));
+ EXPECT_FALSE(Match(root, match::Op().WithTupleIndex(2)));
+ EXPECT_FALSE(Match(root, match::GetTupleElement(match::Op(), 0)));
+ EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
index 0f26a025bf..ca86c5d13e 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -155,20 +155,15 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand,
case HloOpcode::kConstant: {
if (first_reshape_operand->opcode() == HloOpcode::kReshape) {
VLOG(5) << "Adding reshape to kConstant operand";
- HloInstruction* reshape = computation->AddInstruction(
+ return computation->AddInstruction(
HloInstruction::CreateReshape(new_shape, operand));
- operand->SetupDerivedInstruction(reshape);
- return reshape;
} else {
CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose);
VLOG(5) << "Adding transpose to kConstant operand";
std::vector<int64> inverse_permutation =
InversePermutation(first_reshape_operand->dimensions());
- HloInstruction* transpose =
- computation->AddInstruction(HloInstruction::CreateTranspose(
- new_shape, operand, inverse_permutation));
- operand->SetupDerivedInstruction(transpose);
- return transpose;
+ return computation->AddInstruction(HloInstruction::CreateTranspose(
+ new_shape, operand, inverse_permutation));
}
}
case HloOpcode::kRng: {
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
index 13e2d3258e..ad3b662c20 100644
--- a/tensorflow/compiler/xla/service/reshape_mover_test.cc
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -175,8 +175,9 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) {
TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {2, 3});
- auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<bool>({{true, true, false}, {false, false, true}})));
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, true, false}, {false, false, true}})));
auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1"));
@@ -255,12 +256,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) {
HloComputation::Builder builder(TestName());
auto root_shape = ShapeUtil::MakeShape(F32, {3, 2});
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1));
@@ -309,7 +310,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) {
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0"));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<float>({{1, 2, 3}, {4, 5, 6}})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
builder.AddInstruction(HloInstruction::CreateBinary(
@@ -348,7 +349,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) {
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 3}), "param0"));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({9, 8, 7})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({9, 8, 7})));
auto reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
auto reshape1 =
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 82be6bcf4f..da3b622bfa 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
@@ -62,55 +61,28 @@ namespace xla {
namespace {
-// Records the arguments used to invoke a computation in a SessionModule
-// proto.
-Status RecordArguments(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::StreamExecutor* executor, TransferManager* transfer_manager,
- SessionModule* module) {
- module->clear_arguments();
- for (const ShapedBuffer* argument : arguments) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, *argument));
- *module->add_arguments() = literal->ToProto();
- }
- return Status::OK();
-}
-
-// Records the result of a computation in a SessionModule proto.
-Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
- TransferManager* transfer_manager, SessionModule* module) {
- module->clear_result();
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, result));
- *module->mutable_result() = literal->ToProto();
- return Status::OK();
-}
-
// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- se::StreamExecutor* executor, TransferManager* transfer_manager,
+ se::Stream* stream, TransferManager* transfer_manager,
HloSnapshot* module) {
module->clear_arguments();
for (const ShapedBuffer* argument : arguments) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, *argument));
+ transfer_manager->TransferLiteralFromDevice(stream, *argument));
*module->add_arguments() = literal->ToProto();
}
return Status::OK();
}
// Records the result of a computation in a HloSnapshot proto.
-Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor,
+Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
TransferManager* transfer_manager, HloSnapshot* module) {
module->clear_result();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> literal,
- transfer_manager->TransferLiteralFromDevice(executor, result));
+ transfer_manager->TransferLiteralFromDevice(stream, result));
*module->mutable_result() = literal->ToProto();
return Status::OK();
}
@@ -219,21 +191,17 @@ Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
return Status::OK();
}
-Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout,
- const Shape& result_shape) const {
- if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) {
+Status Service::ValidateResultShape(const Shape& client_shape,
+ const Shape& result_shape) const {
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
+ if (!ShapeUtil::Compatible(client_shape, result_shape)) {
return InvalidArgument(
"Shape used to set computation result layout %s is not compatible "
"with result shape %s",
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanStringWithLayout(client_shape).c_str(),
ShapeUtil::HumanString(result_shape).c_str());
}
- if (!LayoutUtil::HasLayout(shape_with_layout)) {
- return InvalidArgument(
- "Shape used to set computation result layout %s does not have layout",
- ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
- }
- return ShapeUtil::ValidateShape(shape_with_layout);
+ return Status::OK();
}
StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
@@ -276,10 +244,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,
const ExecutionOptions* execution_options) {
auto config = MakeUnique<HloModuleConfig>(program_shape);
- ComputationLayout* host_computation_layout =
- config->mutable_host_entry_computation_layout();
- ComputationLayout* device_computation_layout =
- config->mutable_device_entry_computation_layout();
+ ComputationLayout* computation_layout =
+ config->mutable_entry_computation_layout();
if (program_shape.parameters_size() != argument_shapes.size()) {
return InvalidArgument("computation takes %d parameters, but %zu given",
program_shape.parameters_size(),
@@ -296,32 +262,22 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
ShapeUtil::HumanString(*argument_shapes[i]).c_str());
}
- TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i)
- ->CopyLayoutFromShape(*argument_shapes[i]));
- TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i)
- ->CopyLayoutFromShape(*argument_shapes[i]));
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ *argument_shapes[i]));
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
const auto& shape_with_output_layout =
execution_options->shape_with_output_layout();
- TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout,
- program_shape.result()));
TF_RETURN_IF_ERROR(
- host_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
- shape_with_output_layout));
+ ValidateResultShape(shape_with_output_layout, program_shape.result()));
TF_RETURN_IF_ERROR(
- device_computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
// If the result layout is not set, then choose the default.
- // TODO(b/29118294): Allow the compiler to choose a better layout in this
- // case.
- // TODO(b/78356948): We are forcing the default layout here. We should fix
- // clients which expect a default layout, to be explicit about it, by
- // passing the proper ExecutionOptions with shape_with_output_layout set.
- host_computation_layout->mutable_result_layout()->SetToDefaultLayout();
- device_computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
config->set_replica_count(options_.number_of_replicas());
@@ -376,8 +332,8 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
module_protos[i]->entry_computation_name().c_str());
TF_RETURN_IF_ERROR(
Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot));
- hlo_snapshots.push_back(std::move(hlo_snapshot));
}
+ hlo_snapshots.push_back(std::move(hlo_snapshot));
}
VLOG(1) << "Computations:";
@@ -409,22 +365,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
return std::move(executables);
}
-Status Service::ValidateEntryComputationLayout(HloModule* module) {
- const ComputationLayout& on_device =
- module->device_entry_computation_layout();
- for (int64 i = 0; i < on_device.parameter_count(); ++i) {
- TF_RET_CHECK(ShapeUtil::Equal(
- on_device.parameter_shape(i),
- execute_backend_->transfer_manager()->HostShapeToDeviceShape(
- module->host_entry_computation_layout().parameter_shape(i))));
- }
- TF_RET_CHECK(ShapeUtil::Equal(
- module->device_entry_computation_layout().result_shape(),
- execute_backend_->transfer_manager()->HostShapeToDeviceShape(
- module->host_entry_computation_layout().result_shape())));
- return Status::OK();
-}
-
StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
tensorflow::gtl::ArraySlice<Executable*> executables,
@@ -526,7 +466,7 @@ Service::ExecuteParallelAndRegisterResult(
HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
TF_RETURN_IF_ERROR(
- executable->PopulateExecutionProfile(&hlo_profile, stream->parent()));
+ executable->PopulateExecutionProfile(&hlo_profile, stream));
XLA_LOG_LINES(
tensorflow::INFO,
hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription()));
@@ -720,7 +660,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
request.execution_options()));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
- << module_config->host_entry_computation_layout().ToString();
+ << module_config->entry_computation_layout().ToString();
// Adds to the vectors to build and execute the computations after the loop.
all_arguments.push_back(replicated_arguments);
@@ -749,6 +689,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
executable_ptrs.push_back(executable.get());
}
+ for (int i = 0; i < executable_ptrs.size(); i++) {
+ if (executable_ptrs[i]->dumping_snapshot()) {
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(
+ all_executors[i][0]->device_ordinal()));
+ TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
+ execute_backend_->transfer_manager(),
+ executable_ptrs[i]->hlo_snapshot()));
+ }
+ }
+
// Execute the generated executables in parallel and return the device
// handles for each computation's output.
ExecutionProfile profile;
@@ -764,6 +715,20 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
*result->add_responses() = response;
}
+ for (int i = 0; i < executable_ptrs.size(); i++) {
+ if (executable_ptrs[i]->dumping_snapshot()) {
+ TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
+ allocation_tracker_.ResolveForReplica(outputs[i], 0));
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(all_executors[i][0]));
+ TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
+ execute_backend_->transfer_manager(),
+ executable_ptrs[i]->hlo_snapshot()));
+ // Dump out the ith snapshot.
+ TF_RETURN_IF_ERROR(executable_ptrs[i]->DumpHloSnapshot());
+ }
+ }
+
VLOG(1) << "successfully completed 'execute-graph-parallel' request";
return Status::OK();
}
@@ -856,13 +821,15 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
device_allocator));
- // Check that on-host and on-device shapes are consistent.
- TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get()));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend->compiler()->RunBackend(
std::move(module), executor, device_allocator));
+ if (!execution_directory_path.empty()) {
+ executable->set_hlo_snapshot(std::move(hlo_snapshot));
+ }
+
return std::move(executable);
}
@@ -900,12 +867,14 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
execute_backend_->default_stream_executor(),
/*device_allocator=*/nullptr));
+ TF_ASSIGN_OR_RETURN(auto stream,
+ execute_backend_->BorrowStream(
+ execute_backend_->default_stream_executor()));
if (executable->dumping_snapshot()) {
executable->hlo_snapshot()->set_execution_platform(
execute_backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(
- replicated_arguments.front(),
- execute_backend_->default_stream_executor(),
+ replicated_arguments.front(), stream.get(),
execute_backend_->transfer_manager(), executable->hlo_snapshot()));
}
@@ -919,9 +888,9 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg,
TF_ASSIGN_OR_RETURN(
const ShapedBuffer* result_buffer,
allocation_tracker_.ResolveForReplica(result->output(), 0));
- TF_RETURN_IF_ERROR(RecordResult(
- *result_buffer, execute_backend_->default_stream_executor(),
- execute_backend_->transfer_manager(), executable->hlo_snapshot()));
+ TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
+ execute_backend_->transfer_manager(),
+ executable->hlo_snapshot()));
TF_RETURN_IF_ERROR(executable->DumpHloSnapshot());
}
@@ -959,14 +928,13 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
return_shape = &shaped_buffer->on_host_shape();
}
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- execute_backend_->stream_executor(shaped_buffer->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
+ shaped_buffer->device_ordinal()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Literal> result_literal,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
- executor, *shaped_buffer));
+ stream.get(), *shaped_buffer));
if (LayoutUtil::LayoutsInShapesEqual(*return_shape,
result_literal->shape())) {
@@ -1016,9 +984,10 @@ Status Service::TransferToServer(const TransferToServerRequest* arg,
execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
shape, execute_backend_->memory_allocator(),
executor->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
- executor, *literal, shaped_buffer));
+ stream.get(), *literal, shaped_buffer));
replicated_buffers.emplace_back(std::move(shaped_buffer));
}
TF_ASSIGN_OR_RETURN(*result->mutable_data(),
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 422bb95657..47d196fb2a 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -26,15 +26,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
-#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -196,9 +193,6 @@ class Service : public ServiceInterface {
const ExecutionOptions& execution_options,
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments);
- // Assert that host- and device-shapes are in a consistent state.
- Status ValidateEntryComputationLayout(HloModule* module);
-
protected:
friend class LocalExecutable;
@@ -269,11 +263,11 @@ class Service : public ServiceInterface {
// will be the result of this computation.
Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result);
- // Convenience function which checks whether the given shape_with_layout
+ // Convenience function which checks whether the given client_shape
// (presumably passed by the client to set the result layout) is valid for the
// given computation result shape.
- Status ValidateResultShapeWithLayout(const Shape& shape_with_layout,
- const Shape& result_shape) const;
+ Status ValidateResultShape(const Shape& client_shape,
+ const Shape& result_shape) const;
// Returns the stream executors assigned to the replicas represented by the
// given device handle. Each device_handle is a virtual replicated device that
@@ -298,9 +292,6 @@ class Service : public ServiceInterface {
// Tracks asynchronously launched executions via the API.
ExecutionTracker execution_tracker_;
- // Cache containing previously built Executables.
- CompilationCache compilation_cache_;
-
// Backend to compile and execute computations on.
std::unique_ptr<Backend> execute_backend_;
diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto
deleted file mode 100644
index bb8d1cd2a1..0000000000
--- a/tensorflow/compiler/xla/service/session.proto
+++ /dev/null
@@ -1,85 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// This proto file defines messages which store the state of XLA
-// computations within the XLA service. A computation is stored as a record
-// of the operation requests used to build it.
-syntax = "proto3";
-
-import "tensorflow/compiler/xla/xla_data.proto";
-
-package xla;
-
-// Describes a single operation request.
-message OperationRequest {
- ComputationDataHandle output_handle = 1;
- Shape output_shape = 2;
-
- // For operations which call embedded computations such as "Map", these are
- // the version(s) that the embedded computation should be called at. A version
- // value of a computation is the ComputationDataHandle of the root of the
- // computation at the point in time.
- //
- // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single
- // embedded computation so this field will have a single value for those
- // operations.
- //
- // "While" operation takes two; index 0 is the "condition" version and index 1
- // is the "body" version.
- repeated int64 embedded_computation_versions = 3;
-
- // The actual request, which in itself is a tagged union of all possible
- // operation request types.
- OpRequest request = 4;
-}
-
-// Describes a sequence of operation requests which define an XLA
-// computation.
-message SessionComputation {
- string name = 1;
-
- // The ComputationHandle used to refer to this computation in the XLA
- // service.
- ComputationHandle computation_handle = 2;
-
- // Map from ComputationDataHandle value to operation request. The highest
- // ComputationDataHandle value corresponds to the root of the computation.
- map<int64, OperationRequest> requests = 3;
-}
-
-// Describes a group of SessionComputations with an "entry point" computation
-// that may refer to the other non-entry (AKA embedded) computations.
-//
-// This message is used to serialize a computation that has been built via the
-// XLA service API, along with its dependencies, for purposes such as
-// analysis/replay/file-storage.
-message SessionModule {
- // The entry computation, which was requested for serialization. This may have
- // referred to embedded computations, which are reflected below.
- SessionComputation entry = 1;
-
- // Embedded computations that are transitively referred to by the entry
- // computation.
- repeated SessionComputation embedded_computations = 2;
-
- // The arguments passed to the computation.
- repeated LiteralProto arguments = 3;
-
- // The result of the computation.
- LiteralProto result = 4;
-
- // The name of the platform used to run the computation.
- string execution_platform = 5;
-}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index d624f548b1..70edf7883f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -44,147 +44,18 @@ namespace xla {
namespace {
-// Return the UnaryOperation proto enum value associated with the given HLO
-// opcode.
-UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAbs:
- return UNOP_ABS;
- case HloOpcode::kCeil:
- return UNOP_CEIL;
- case HloOpcode::kClz:
- return UNOP_CLZ;
- case HloOpcode::kCos:
- return UNOP_COS;
- case HloOpcode::kExp:
- return UNOP_EXP;
- case HloOpcode::kExpm1:
- return UNOP_EXPM1;
- case HloOpcode::kFloor:
- return UNOP_FLOOR;
- case HloOpcode::kImag:
- return UNOP_IMAG;
- case HloOpcode::kIsFinite:
- return UNOP_IS_FINITE;
- case HloOpcode::kLog:
- return UNOP_LOG;
- case HloOpcode::kLog1p:
- return UNOP_LOG1P;
- case HloOpcode::kNot:
- return UNOP_NOT;
- case HloOpcode::kNegate:
- return UNOP_NEGATE;
- case HloOpcode::kReal:
- return UNOP_REAL;
- case HloOpcode::kRoundNearestAfz:
- return UNOP_ROUND_NEAREST_AFZ;
- case HloOpcode::kSign:
- return UNOP_SIGN;
- case HloOpcode::kSin:
- return UNOP_SIN;
- case HloOpcode::kSort:
- return UNOP_SORT;
- case HloOpcode::kTanh:
- return UNOP_TANH;
- default:
- LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
- << opcode;
- }
-}
-
-// Return the BinaryOperation proto enum value associated with the given HLO
-// opcode.
-BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAtan2:
- return BINOP_ATAN2;
- case HloOpcode::kComplex:
- return BINOP_COMPLEX;
- case HloOpcode::kMultiply:
- return BINOP_MUL;
- case HloOpcode::kAdd:
- return BINOP_ADD;
- case HloOpcode::kSubtract:
- return BINOP_SUB;
- case HloOpcode::kDivide:
- return BINOP_DIV;
- case HloOpcode::kEq:
- return BINOP_EQ;
- case HloOpcode::kGe:
- return BINOP_GE;
- case HloOpcode::kGt:
- return BINOP_GT;
- case HloOpcode::kLe:
- return BINOP_LE;
- case HloOpcode::kLt:
- return BINOP_LT;
- case HloOpcode::kNe:
- return BINOP_NE;
- case HloOpcode::kMaximum:
- return BINOP_MAX;
- case HloOpcode::kMinimum:
- return BINOP_MIN;
- case HloOpcode::kPower:
- return BINOP_POW;
- case HloOpcode::kRemainder:
- return BINOP_REM;
- case HloOpcode::kOr:
- return BINOP_OR;
- case HloOpcode::kAnd:
- return BINOP_AND;
- case HloOpcode::kShiftLeft:
- return BINOP_SHIFT_LEFT;
- case HloOpcode::kShiftRightArithmetic:
- return BINOP_SHIFT_RIGHT_ARITHMETIC;
- case HloOpcode::kShiftRightLogical:
- return BINOP_SHIFT_RIGHT_LOGICAL;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the TernaryOperation proto enum value associated with the given HLO
-// opcode.
-TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kClamp:
- return TRIOP_CLAMP;
- case HloOpcode::kSelect:
- return TRIOP_SELECT;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the VariadicOperation proto enum value associated with the given HLO
-// opcode.
-VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kTuple:
- return VAROP_TUPLE;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
// Returns true if no element is present in slice more than once.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectNotTupleOrOpaque(const Shape& shape,
- tensorflow::StringPiece op_type) {
- if (ShapeUtil::IsTuple(shape)) {
- return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
- } else if (ShapeUtil::IsOpaque(shape)) {
- return InvalidArgument("Expected non-opaque argument for %s, but got %s.",
+Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument("Expected array argument for %s, but got %s.",
std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
- } else {
- return Status::OK();
}
+ return Status::OK();
}
Status VerifyReducerShape(const ProgramShape& reducer_shape,
@@ -198,11 +69,11 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape,
}
const Shape& accumulator_shape = reducer_shape.result();
- if (ShapeUtil::Rank(accumulator_shape) != 0) {
+ if (!ShapeUtil::IsArray(accumulator_shape) ||
+ ShapeUtil::Rank(accumulator_shape) != 0) {
return InvalidArgument(
- "Reduction function must have rank 0 (rank %lld reduction function "
- "given).",
- ShapeUtil::Rank(accumulator_shape));
+ "Reduction function must produce a scalar but has shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
}
// Check that the accumulator can be passed in as the first argument.
@@ -321,84 +192,79 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape);
-}
+ TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
-/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
- UnaryOperation operation, const Shape& arg) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
-
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
- switch (operation) {
- case UNOP_FLOOR:
- case UNOP_CEIL:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ switch (opcode) {
+ case HloOpcode::kFloor:
+ case HloOpcode::kCeil:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating for floor/ceil "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_COS:
- case UNOP_SIN:
- case UNOP_EXP:
- case UNOP_EXPM1:
- case UNOP_LOG:
- case UNOP_LOG1P:
- case UNOP_TANH:
- if (!ShapeUtil::ElementIsFloating(arg) &&
- !ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kCos:
+ case HloOpcode::kSin:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kTanh:
+ if (!ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
"sin/cos/exp/log/tanh operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_REAL:
- case UNOP_IMAG:
- if (!ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kReal:
+ case HloOpcode::kImag:
+ if (!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be complex for real/imag "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, F32);
- case UNOP_ABS:
- if (ShapeUtil::ElementIsComplex(arg)) {
+ return ShapeUtil::ChangeElementType(shape, F32);
+ case HloOpcode::kAbs:
+ if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
- arg, primitive_util::ComplexComponentType(arg.element_type()));
+ shape, primitive_util::ComplexComponentType(shape.element_type()));
}
- return arg;
- case UNOP_CLZ:
- case UNOP_NEGATE:
- case UNOP_ROUND_NEAREST_AFZ:
- case UNOP_SIGN:
- case UNOP_SORT:
- return arg;
-
- case UNOP_NOT:
- if (arg.element_type() != PRED &&
- !primitive_util::IsIntegralType(arg.element_type())) {
+ return shape;
+ case HloOpcode::kClz:
+ case HloOpcode::kNegate:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSign:
+ return shape;
+
+ case HloOpcode::kNot:
+ if (shape.element_type() != PRED &&
+ !primitive_util::IsIntegralType(shape.element_type())) {
return InvalidArgument(
"Expected pred or an integral element type in argument to Not "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
+ return shape;
- case UNOP_IS_FINITE:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ case HloOpcode::kIsFinite:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating point for IsFinite "
+ "Expected element type in shape to be floating "
+ "point for IsFinite "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, PRED);
+ return ShapeUtil::ChangeElementType(shape, PRED);
default:
return InvalidArgument(
"Unknown operation for unary shape inference: \"%s\".",
- UnaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -415,8 +281,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Shape* arg_shape = nullptr;
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
+ TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
element_type = arg_shape->element_type();
@@ -463,6 +328,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ for (const Shape* arg_shape : arg_shapes) {
+ if (arg_shape->element_type() != TOKEN) {
+ return InvalidArgument(
+ "Operands of token instructions must be TOKEN types.");
+ }
+ }
+ return ShapeUtil::MakeTokenShape();
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
@@ -473,12 +349,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
- "Convert does not allow tuples, so cannot convert from %s to %s.",
+ "Convert does not allow non-arrays, so cannot convert from %s to %s.",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
@@ -495,7 +372,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
@@ -542,7 +420,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
const Shape& operand_shape, const Shape& padding_value_shape,
const PaddingConfig& padding_config) {
- if (ShapeUtil::IsTuple(operand_shape)) {
+ if (!ShapeUtil::IsArray(operand_shape)) {
return InvalidArgument(
"Pad operation does not support tuple-shape operands.");
}
@@ -681,8 +559,8 @@ Status ValidateDotDimensionNumbers(
/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
string message = tensorflow::strings::Printf(
@@ -768,8 +646,9 @@ Status ValidateDotDimensionNumbers(
}
/* static */ StatusOr<Shape>
-ShapeInference::InferDegenerateDimensionBroadcastShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
+ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
+ const Shape& lhs,
+ const Shape& rhs) {
TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
// The shapes have to be compatible. That is, if some dimension d has a
@@ -787,7 +666,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
+ HloOpcodeString(operation).c_str(),
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -797,8 +676,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
- BinaryOperation operation, const Shape& smaller_shape,
- const Shape& larger_shape,
+ const Shape& smaller_shape, const Shape& larger_shape,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
@@ -899,18 +777,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Binary op %s with different element types: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(),
+ HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -943,10 +818,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
// After InDim broadcasting, perform degenerate dimensions broadcasting.
- TF_ASSIGN_OR_RETURN(
- Shape indim_broadcast_shape,
- InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
- broadcast_dimensions));
+ TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
+ InferInDimBroadcastShape(smaller_shape, larger_shape,
+ broadcast_dimensions));
return InferDegenerateDimensionBroadcastShape(
operation, indim_broadcast_shape, larger_shape);
@@ -955,51 +829,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(),
- rhs->shape(), /*broadcast_dimensions=*/{});
+ return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
+ /*broadcast_dimensions=*/{});
}
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs,
- broadcast_dimensions);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
VLOG(2) << tensorflow::strings::Printf(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
+ HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str(),
Join(broadcast_dimensions, ", ").c_str());
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- BinaryOperation_Name(operation))));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- BinaryOperation_Name(operation))));
- switch (operation) {
- case BINOP_MAX:
- case BINOP_MIN:
- case BINOP_SUB:
- case BINOP_ADD:
- case BINOP_ATAN2:
- case BINOP_POW:
- case BINOP_DIV:
- case BINOP_REM:
- case BINOP_MUL:
- case BINOP_SHIFT_LEFT:
- case BINOP_SHIFT_RIGHT_ARITHMETIC:
- case BINOP_SHIFT_RIGHT_LOGICAL:
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ TF_RETURN_IF_ERROR(
+ ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+ HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+ HloOpcodeString(opcode))));
+ switch (opcode) {
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kPower:
+ case HloOpcode::kDivide:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_COMPLEX: {
+ case HloOpcode::kComplex: {
if (!ShapeUtil::ElementIsFloating(lhs)) {
return InvalidArgument(
"Expected element type in shape to be floating for complex compose "
@@ -1007,7 +874,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(lhs.element_type()).c_str());
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
@@ -1015,8 +882,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return Unimplemented("Complex component type is not implemented.");
}
}
- case BINOP_AND:
- case BINOP_OR:
+ case HloOpcode::kAnd:
+ case HloOpcode::kOr:
+ case HloOpcode::kXor:
if (lhs.element_type() != PRED &&
!primitive_util::IsIntegralType(lhs.element_type())) {
return InvalidArgument(
@@ -1024,24 +892,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"got %s.",
PrimitiveType_Name(lhs.element_type()).c_str());
}
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_EQ:
- case BINOP_GE:
- case BINOP_GT:
- case BINOP_LE:
- case BINOP_LT:
- case BINOP_NE: {
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kNe: {
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
return ShapeUtil::ChangeElementType(shape, PRED);
}
default:
return Unimplemented(
"Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
- BinaryOperation_Name(operation).c_str(),
- lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
+ HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(),
+ rhs.ShortDebugString().c_str());
}
}
@@ -1053,23 +921,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
- return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
- TernaryOperation operation, const Shape& lhs, const Shape& rhs,
- const Shape& ehs) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
- switch (operation) {
- case TRIOP_CLAMP:
+ switch (opcode) {
+ case HloOpcode::kClamp:
return InferClampShape(lhs, rhs, ehs);
- case TRIOP_SELECT:
+ case HloOpcode::kSelect:
return InferSelectShape(lhs, rhs, ehs);
+ case HloOpcode::kTupleSelect:
+ return InferTupleSelectShape(lhs, rhs, ehs);
default:
return InvalidArgument("Unknown operation %s.",
- TernaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -1077,6 +941,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const HloInstruction*> operands) {
std::vector<const Shape*> operand_shapes;
+ operand_shapes.reserve(operands.size());
for (const HloInstruction* operand : operands) {
operand_shapes.push_back(&operand->shape());
}
@@ -1086,27 +951,30 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
- return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
- operand_shapes);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- VariadicOperation operation,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
- switch (operation) {
- case VAROP_TUPLE: {
+ switch (opcode) {
+ case HloOpcode::kTuple: {
Shape result = ShapeUtil::MakeTupleShape({});
+ result.mutable_tuple_shapes()->Reserve(operand_shapes.size());
for (const Shape* shape : operand_shapes) {
ShapeUtil::AppendShapeToTuple(*shape, &result);
}
return result;
}
+ case HloOpcode::kSort: {
+ if (operand_shapes.size() == 1) {
+ return *operand_shapes[0];
+ } else if (operand_shapes.size() == 2) {
+ return ShapeUtil::MakeTupleShape(
+ {*operand_shapes[0], *operand_shapes[1]});
+ }
+ return InvalidArgument("Unexpected number of operands for sort");
+ }
default:
return InvalidArgument("Unknown operation %s.",
- VariadicOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -1121,15 +989,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// All arguments must have the same shape.
const Shape* arg_shape = arg_shapes[0];
for (size_t i = 1; i < arg_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
+ TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
- if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
- !ShapeUtil::IsTuple(*arg_shape) &&
- ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+ if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
*arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
@@ -1212,11 +1077,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& offset_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm training"));
+ ExpectArray(operand_shape, "operand of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm training"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1318,11 +1183,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
const Shape& offset_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm inference"));
+ ExpectArray(operand_shape, "operand of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1465,16 +1330,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
const Shape& output_grad_shape, int64 feature_index) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
+ ExpectArray(scale_shape, "scale input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- output_grad_shape, "output_grad input of batch norm grad"));
+ ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
@@ -1623,8 +1485,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -1859,7 +1721,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
+ ExpectArray(*operand_shape, "operand of cross replica sum"));
}
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
@@ -1901,8 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
operand_shape.element_type()));
return InferWindowOutputShape(operand_shape, window,
@@ -1915,7 +1776,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
const Window& window, const Shape& source_shape,
const Shape& init_value_shape, const ProgramShape& scatter_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
+ ExpectArray(operand_shape, "operand of select-and-scatter"));
// Check if the select function has a proper shape of (T,T) -> PRED.
if (select_shape.parameters_size() != 2) {
@@ -1980,7 +1841,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
Join(starts, ",").c_str(), Join(limits, ",").c_str(),
Join(strides, ",").c_str());
};
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
@@ -2039,10 +1900,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
- "start indices of dynamic slice"));
+ ExpectArray(start_indices_shape, "start indices of dynamic slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
@@ -2100,11 +1960,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
const Shape& operand_shape, const Shape& update_shape,
const Shape& start_indices_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
+ ExpectArray(operand_shape, "operand of dynamic update slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- start_indices_shape, "start indices of dynamic update slice"));
+ ExpectArray(update_shape, "update of dynamic update slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
+ "start indices of dynamic update slice"));
VLOG(2) << tensorflow::strings::Printf(
"updating slice of shape %s at dynamic start_indices %s with update "
@@ -2172,8 +2032,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
}
@@ -2303,7 +2162,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
return InvalidArgument("Broadcast with negative dimension size %lld.",
@@ -2322,7 +2181,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
ShapeUtil::MakeShape(operand.element_type(), new_sizes);
@@ -2354,7 +2213,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::iota(indices.begin(), indices.end(), 0);
@@ -2375,9 +2234,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// "degenerate" cases, as with binary elementwise ops.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
const Shape& min, const Shape& operand, const Shape& max) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+ TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
+ TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
@@ -2410,15 +2269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
- bool compatible;
- if (ShapeUtil::IsTuple(on_true)) {
- // Select only defines the top-level buffer, so if it's a tuple, the two
- // input must match exactly.
- compatible = ShapeUtil::Compatible(on_true, on_false);
- } else {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
- }
- if (!compatible) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
return InvalidArgument(
"Operands to select must be the same shape; got %s and %s.",
ShapeUtil::HumanString(on_true).c_str(),
@@ -2430,7 +2281,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::HumanString(pred).c_str());
}
if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
- ShapeUtil::Rank(pred) == 0) {
+ ShapeUtil::IsScalar(pred)) {
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
@@ -2444,6 +2295,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
}
+/* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
+ const Shape& pred, const Shape& on_true, const Shape& on_false) {
+ // Select only defines the top-level buffer, so if it's a tuple, the two
+ // input must match exactly.
+ if (!ShapeUtil::Compatible(on_true, on_false)) {
+ return InvalidArgument(
+ "Operands to tuple-select must be the same shape; got %s and %s.",
+ ShapeUtil::HumanString(on_true).c_str(),
+ ShapeUtil::HumanString(on_false).c_str());
+ }
+ if (pred.element_type() != PRED) {
+ return InvalidArgument(
+ "TupleSelect's pred operand must have PRED element type; got %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ if (!ShapeUtil::IsScalar(pred)) {
+ return InvalidArgument(
+ "TupleSelect operation with non-scalar predicate: %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ return on_true;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
const ProgramShape& to_apply) {
@@ -2576,9 +2450,9 @@ static Status ValidateGatherDimensionNumbers(
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(input_shape, "input tensor operand gather op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 9da2c99b41..1a5684e3c3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -46,8 +46,6 @@ class ShapeInference {
public:
// Infers the shape produced by applying the given unary operation to the
// given input shape.
- static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
- const Shape& arg);
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
const Shape& shape);
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
@@ -56,9 +54,6 @@ class ShapeInference {
// Infers the shape produced by applying the given binary operation to the
// given input shapes.
static StatusOr<Shape> InferBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
- static StatusOr<Shape> InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
@@ -67,9 +62,6 @@ class ShapeInference {
// Infers the shape produced by applying the given ternary operation to the
// given input shapes.
- static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
- const Shape& lhs, const Shape& rhs,
- const Shape& ehs);
static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
const Shape& rhs,
const Shape& ehs);
@@ -81,9 +73,6 @@ class ShapeInference {
// Infers the shape produced by applying the given variadic operation to the
// given input operand shapes.
static StatusOr<Shape> InferVariadicOpShape(
- VariadicOperation operation,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
- static StatusOr<Shape> InferVariadicOpShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes);
static StatusOr<Shape> InferVariadicOpShape(
@@ -227,6 +216,13 @@ class ShapeInference {
static StatusOr<Shape> InferConcatOpShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ // Infers the shape produced by a kAfterAll. Trivially this shape is always a
+ // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
+ // and checking operand shapes. This method verifies that the operand shapes
+ // are all TOKENs.
+ static StatusOr<Shape> InferAfterAllShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
// the shape is identical except for the element type.
@@ -279,7 +275,7 @@ class ShapeInference {
// the LHS and a single element in the RHS to produce a single output element,
// even in the presence of broadcasting of one of the operands over the other.
static StatusOr<Shape> InferElementwiseBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
// Helper for inferring the shape of Clamp ops.
@@ -290,12 +286,16 @@ class ShapeInference {
static StatusOr<Shape> InferSelectShape(const Shape& pred,
const Shape& on_true,
const Shape& on_false);
+ // Helper for inferring the shape of TupleSelect ops.
+ static StatusOr<Shape> InferTupleSelectShape(const Shape& pred,
+ const Shape& on_true,
+ const Shape& on_false);
// Helper for inferring shapes of binary operations which use degenerate
// dimension broadcasting (a dimension of size 1 in one operand is broadcast
// up to match the size of the dimension in the other operand).
static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs);
+ HloOpcode operation, const Shape& lhs, const Shape& rhs);
// Helper for inferring shapes of binary operations using "InDim"
// broadcasting. This is the broadcasting used in the *InDim binary operations
@@ -303,8 +303,7 @@ class ShapeInference {
// lower-rank shape than larger_shape. Returns the shape that the
// smaller_shape is broadcast to.
static StatusOr<Shape> InferInDimBroadcastShape(
- BinaryOperation operation, const Shape& smaller_shape,
- const Shape& larger_shape,
+ const Shape& smaller_shape, const Shape& larger_shape,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 0e61994a78..bafe14d6f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
- auto inferred_status = ShapeInference::InferUnaryOpShape(
- UnaryOperation::UNOP_NEGATE, matrix_shape);
+ auto inferred_status =
+ ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}
@@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple);
+ HloOpcode::kSelect, pred_, tuple, tuple);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
@@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, SelectBadShapes) {
auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
+ HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Operands to select must be the same shape"));
auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
HasSubstr("pred operand must have PRED"));
auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
- matrix_64_48_, matrix_64_48_);
+ HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
+ matrix_64_48_);
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("with non-scalar predicate with dimensionality"));
// Tuples have a TUPLE element type and cannot be the pred of a select.
auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
+ HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
ShapeUtil::MakeTupleShape({f32_, f32_}),
ShapeUtil::MakeTupleShape({f32_, f32_}));
ASSERT_FALSE(inferred_status_error4.ok());
@@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) {
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_,
- matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampAllScalar) {
- auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_);
+ auto inferred_status =
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_);
+ HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_);
+ HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
auto inferred_status = ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_);
+ HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}
TEST_F(ShapeInferenceTest, ClampBadShapes) {
// Type mismatch
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_)
- .ok());
- ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_)
- .ok());
- // Dimension mismatch
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_64_, vector_32_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_64_, vector_32_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
.ok());
ASSERT_FALSE(
- ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP,
- vector_32_, vector_32_, vector_64_)
+ ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
.ok());
- // Dimension mismatch, where one operand is a scalar
+ // Dimension mismatch
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_)
+ HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
.ok());
ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
- TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_)
+ HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
+ .ok());
+ // Dimension mismatch, where one operand is a scalar
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, vector_32_, f32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
+ vector_64_, f32_, vector_32_)
+ .ok());
+ ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
+ vector_64_, vector_32_)
.ok());
}
TEST_F(ShapeInferenceTest, Complex) {
auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
const tensorflow::gtl::ArraySlice<int64>& bcast) {
- return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX,
- lhs, rhs, bcast);
+ return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
+ bcast);
};
// Inputs must be FP.
ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
@@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) {
}
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
- StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
- VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
+ StatusOr<Shape> result =
+ ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
ASSERT_IS_OK(result.status());
ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
ShapeUtil::MakeTupleShape({s32_, f32_})));
@@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) {
TEST_F(ShapeInferenceTest, InferPowShape) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
- auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ HloOpcode::kPower, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}
@@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) {
TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
+ ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {});
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
inferred_status.ValueOrDie()));
@@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
- auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {1});
+ auto inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec8, {0});
+ auto inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
ASSERT_FALSE(inferred_status_mismatch.ok());
- inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {0});
+ inferred_status_match =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
- inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, mat, vec16, {1});
+ inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
ASSERT_FALSE(inferred_status_mismatch.ok());
}
@@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2});
+ HloOpcode::kAdd, cube, matrix8_4, {1, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2});
+ HloOpcode::kAdd, cube, matrix16_4, {0, 2});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1});
+ HloOpcode::kAdd, cube, matrix16_8, {0, 1});
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
}
@@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
// "magical" broadcast rejected
- auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {});
+ auto inferred_status_error1 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
ASSERT_FALSE(inferred_status_error1.ok());
ASSERT_THAT(inferred_status_error1.status().error_message(),
HasSubstr("Automatic"));
// broadcast_dimension out of bounds for tensor's rank
- auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {3});
+ auto inferred_status_error2 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
ASSERT_FALSE(inferred_status_error2.ok());
ASSERT_THAT(inferred_status_error2.status().error_message(),
ContainsRegex("Broadcast dimension number .* too large"));
// broadcast_dimension doesn't match corresponding dimension
- auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, vec8, {0});
+ auto inferred_status_error3 =
+ ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
HasSubstr("Broadcast dimension 0 mismatch"));
// broadcast_dimensions list too long
auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
+ HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(inferred_status_error4.status().error_message(),
HasSubstr("broadcast_dimensions has to match"));
// there's a dimension above the rank of the tensor
auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
+ HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
ASSERT_FALSE(inferred_status_error5.ok());
ASSERT_THAT(inferred_status_error5.status().error_message(),
ContainsRegex("dimension number .* too large"));
// broadcasting dimensions don't match in this order
auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
+ HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
ASSERT_FALSE(inferred_status_error6.ok());
ASSERT_THAT(inferred_status_error6.status().error_message(),
HasSubstr("dimension 0 mismatch"));
@@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
// in a proper (strictly increasing) order, even if the lower-rank array
// matches the higher-rank array in many different ways.
auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
ASSERT_FALSE(inferred_status_error7.ok());
ASSERT_THAT(inferred_status_error7.status().error_message(),
HasSubstr("dimensions order is wrong"));
auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
+ HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
ASSERT_FALSE(inferred_status_error8.ok());
ASSERT_THAT(inferred_status_error8.status().error_message(),
HasSubstr("dimensions order is wrong"));
@@ -1315,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(
inferred_status_error4.status().error_message(),
- HasSubstr("Expected non-tuple argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
@@ -1391,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- HasSubstr("Expected non-tuple argument"));
+ HasSubstr("Expected array argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1690,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for input"))
+ HasSubstr("Expected array argument for input"))
<< statusor.status();
}
@@ -1704,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for gather indices"))
+ HasSubstr("Expected array argument for gather indices"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index c4d01562c4..7232c658b3 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -22,8 +22,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/notification.h"
+
+using ::tensorflow::strings::StrCat;
namespace xla {
/* static */ tensorflow::mutex
@@ -36,8 +40,75 @@ TransferManager::GetPlatformTransferManagers() {
return r;
}
+StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ StatusOr<std::unique_ptr<Literal>> ret;
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ substream->ThenWaitFor(stream);
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+
+ tensorflow::Notification n;
+ TransferLiteralFromDevice(substream, device_buffer,
+ [&](StatusOr<std::unique_ptr<Literal>> arg) {
+ ret = std::move(arg);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
+Status TransferManager::TransferLiteralToDevice(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ substream->ThenWaitFor(stream);
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDeviceAsync(substream, literal, device_buffer));
+ return substream->BlockHostUntilDone();
+}
+
+StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ StatusOr<std::unique_ptr<Literal>> ret;
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+
+ tensorflow::Notification n;
+ TransferArrayFromDevice(substream, shape, source,
+ [&](StatusOr<std::unique_ptr<Literal>> arg) {
+ ret = std::move(arg);
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+}
+
Status TransferManager::TransferArrayToDevice(
- se::StreamExecutor* executor, const LiteralSlice& literal,
+ se::Stream* stream, const LiteralSlice& literal,
+ const se::DeviceMemoryBase& dest) {
+ // Implement the synchronous version by waiting on the asynchronous version.
+ // Use a substream so that if we are called from a HostCallback we don't
+ // deadlock.
+ se::Stream* substream = stream->GetOrCreateSubStream();
+ auto cleanup = tensorflow::gtl::MakeCleanup(
+ [&]() { stream->ReturnSubStream(substream); });
+ TF_RETURN_IF_ERROR(TransferArrayToDeviceAsync(substream, literal, dest));
+ return substream->BlockHostUntilDone();
+}
+
+Status TransferManager::TransferArrayToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest) {
const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape))
@@ -51,28 +122,32 @@ Status TransferManager::TransferArrayToDevice(
dest.size(), GetByteSizeRequirement(on_device_shape));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
- executor->platform(), executor->device_ordinal());
+ stream->parent()->platform(),
+ stream->parent()->device_ordinal());
shaped_buffer.set_buffer(dest, /*index=*/{});
- return TransferLiteralToDevice(executor, literal, shaped_buffer);
+ return TransferLiteralToDevice(stream, literal, shaped_buffer);
}
-StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
- se::StreamExecutor* executor, const Shape& shape,
- const se::DeviceMemoryBase& source) {
- TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape))
- << "Shape " << ShapeUtil::HumanString(shape)
- << " has a differently shaped representation on-device: "
- << ShapeUtil::HumanString(HostShapeToDeviceShape(shape));
+void TransferManager::TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape, const se::DeviceMemoryBase& source,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) {
+ if (!ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) {
+ auto error = StrCat("Shape ", ShapeUtil::HumanString(shape),
+ " has a differently shaped representation on-device: ",
+ ShapeUtil::HumanString(HostShapeToDeviceShape(shape)));
+ return done(FailedPrecondition("%s", error.c_str()));
+ }
if (source.size() < GetByteSizeRequirement(shape)) {
- return FailedPrecondition(
- "Allocation on device not large enough for array: "
- "%lld < %lld",
- source.size(), GetByteSizeRequirement(shape));
+ return done(
+ FailedPrecondition("Allocation on device not large enough for array: "
+ "%lld < %lld",
+ source.size(), GetByteSizeRequirement(shape)));
}
ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
- executor->platform(), executor->device_ordinal());
+ stream->parent()->platform(),
+ stream->parent()->device_ordinal());
shaped_buffer.set_buffer(source, /*index=*/{});
- return TransferLiteralFromDevice(executor, shaped_buffer);
+ return TransferLiteralFromDevice(stream, shaped_buffer, std::move(done));
}
/* static */ void TransferManager::RegisterTransferManager(
@@ -108,10 +183,14 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
}
Status TransferManager::WriteTupleIndexTables(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
- VLOG(2) << "Writing tuple index tables for " << device_buffer;
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ TF_RETURN_IF_ERROR(WriteTupleIndexTablesAsync(stream, device_buffer));
+ return stream->BlockHostUntilDone();
+}
- TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
+Status TransferManager::WriteTupleIndexTablesAsync(
+ se::Stream* stream, const ShapedBuffer& device_buffer) {
+ VLOG(2) << "Writing tuple index tables for " << device_buffer;
return ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_device_shape(),
@@ -129,7 +208,7 @@ Status TransferManager::WriteTupleIndexTables(
elements.push_back(device_buffer.buffer(element_index));
element_index.pop_back();
}
- return WriteSingleTupleIndexTable(executor, elements, device_subshape,
+ return WriteSingleTupleIndexTable(stream, elements, device_subshape,
&device_memory);
}
@@ -138,26 +217,20 @@ Status TransferManager::WriteTupleIndexTables(
}
Status TransferManager::TransferBufferFromDevice(
- se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
- int64 size, void* destination) {
+ se::Stream* stream, const se::DeviceMemoryBase& source, int64 size,
+ void* destination) {
if (source.size() < size) {
return FailedPrecondition(
"Source allocation on device not large enough for data tranfer: "
"%lld < %lld",
source.size(), size);
}
- auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination);
- if (!copy_status.ok()) {
- return AddStatus(
- Status(static_cast<tensorflow::error::Code>(copy_status.code()),
- copy_status.error_message()),
- "failed transfer from device to buffer");
- }
+ stream->ThenMemcpy(destination, source, size);
return Status::OK();
}
Status TransferManager::TransferBufferToDevice(
- se::StreamExecutor* executor, int64 size, const void* source,
+ se::Stream* stream, int64 size, const void* source,
se::DeviceMemoryBase* destination) {
if (destination->size() < size) {
return FailedPrecondition(
@@ -165,13 +238,7 @@ Status TransferManager::TransferBufferToDevice(
"%lld < %lld",
destination->size(), size);
}
- auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination);
- if (!copy_status.ok()) {
- return AddStatus(
- Status(static_cast<tensorflow::error::Code>(copy_status.code()),
- copy_status.error_message()),
- "failed transfer of buffer to device");
- }
+ stream->ThenMemcpy(destination, source, size);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 43a8092b06..249bdcc1f5 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <set>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -52,30 +52,65 @@ class TransferManager {
return host_shape;
}
- // Returns a literal containing the data held in the given ShapedBuffer.
- // using the provided executor. The optional literal_shape will be the shape
- // for the literal. The shape of the ShapedBuffer and
- // DeviceShape(literal_shape) must be compatible, but need not have the same
- // layout.
+ // Returns a literal containing the data held in the given ShapedBuffer
+ // using the provided executor. This operation is performed synchronously
+ // without waiting for any other operation on a stream to complete.
+ //
+ // This function should be avoided in favor of the asynchronous version below.
virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
- se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0;
+ se::Stream* stream, const ShapedBuffer& device_buffer);
+
+ // Begins transferring a literal containing the data held in the given
+ // ShapedBuffer using the provided executor.
+ //
+ // This operation is performed asynchronously on the given stream. It returns
+ // once the transfer is enqueued. 'done' is invoked with the result when
+ // complete.
+ //
+ // device_buffer is copied by reference and must live at least until done() is
+ // invoked.
+ virtual void TransferLiteralFromDevice(
+ se::Stream* stream, const ShapedBuffer& device_buffer,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done) = 0;
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
- // but need not have the same layout
- virtual Status TransferLiteralToDevice(se::StreamExecutor* executor,
+ // but need not have the same layout.
+ //
+ // This operation is performed synchronously without waiting for any other
+ // operation on a stream to complete. This function should be avoided in favor
+ // of the asynchronous version below.
+ virtual Status TransferLiteralToDevice(se::Stream* stream,
const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) = 0;
+ const ShapedBuffer& device_buffer);
+
+ // Transfers the given literal into the previously allocated device memory
+ // represented by the given ShapedBuffer using the given executor. The shape
+ // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
+ // but need not have the same layout.
+ //
+ // This operation is performed asynchronously on the given stream. It returns
+ // once the transfer is enqueued.
+ virtual Status TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) = 0;
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to
// transfer an array at a known address.
- Status TransferArrayToDevice(se::StreamExecutor* executor,
- const LiteralSlice& literal,
+ Status TransferArrayToDevice(se::Stream* stream, const LiteralSlice& literal,
const se::DeviceMemoryBase& dest);
+ void TransferArrayFromDevice(
+ se::Stream* stream, const Shape& shape,
+ const se::DeviceMemoryBase& source,
+ std::function<void(StatusOr<std::unique_ptr<Literal>>)> done);
+
+ Status TransferArrayToDeviceAsync(se::Stream* stream,
+ const LiteralSlice& literal,
+ const se::DeviceMemoryBase& dest);
StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
- se::StreamExecutor* executor, const Shape& shape,
+ se::Stream* stream, const Shape& shape,
const se::DeviceMemoryBase& source);
// Transfers the given literal into the Infeed interface of the device,
@@ -96,8 +131,10 @@ class TransferManager {
// Given an allocated ShapedBuffer, constructs the tuple index table(s) in
// each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
// ShapedBuffer is array-shaped this method does nothing.
- Status WriteTupleIndexTables(se::StreamExecutor* executor,
+ Status WriteTupleIndexTables(se::Stream* stream,
const ShapedBuffer& device_buffer);
+ Status WriteTupleIndexTablesAsync(se::Stream* stream,
+ const ShapedBuffer& device_buffer);
// Determines the byte size requirement for the given shape on the underlying
// architecture. This will be used to allocate an appropriately sized memory
@@ -144,7 +181,7 @@ class TransferManager {
// 'destination' buffer.
//
// size is the size to transfer to destination in bytes.
- virtual Status TransferBufferFromDevice(se::StreamExecutor* executor,
+ virtual Status TransferBufferFromDevice(se::Stream* stream,
const se::DeviceMemoryBase& source,
int64 size, void* destination);
@@ -152,15 +189,15 @@ class TransferManager {
// destination of the device.
//
// size is the size to transfer from source in bytes.
- virtual Status TransferBufferToDevice(se::StreamExecutor* executor,
- int64 size, const void* source,
+ virtual Status TransferBufferToDevice(se::Stream* stream, int64 size,
+ const void* source,
se::DeviceMemoryBase* destination);
// Writes the given device-memory pointers in 'elements' to the given region
// to construct a tuple index table in the platform-specific tuple
// representation.
virtual Status WriteSingleTupleIndexTable(
- se::StreamExecutor* executor,
+ se::Stream* stream,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
const Shape& shape, se::DeviceMemoryBase* region) = 0;
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index ba16dc640e..49e1f87319 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -178,7 +178,6 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
- convolution.SetupDerivedInstruction(new_conv.get());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index f73f1227aa..7051a4cf51 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -20,19 +20,19 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
@@ -69,7 +69,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
@@ -91,7 +91,7 @@ ENTRY entry_computation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
@@ -119,7 +119,7 @@ ENTRY entry_computation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
@@ -147,7 +147,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
@@ -160,11 +160,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
auto builder = HloComputation::Builder("entry");
// (1.0 + 2.0) * (2.0 - 3.0)
HloInstruction* const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
HloInstruction* const2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
HloInstruction* const3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
const1->shape(), HloOpcode::kAdd, const1, const2));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) {
HloComputation* entry_computation =
module->AddEntryComputation(builder.Build(mul));
HloInstruction* call = module->OutlineExpressionFromComputation(
- {add, sub, mul}, "", entry_computation);
+ {add, sub, mul}, "entry", entry_computation);
EXPECT_EQ(call, entry_computation->root_instruction());
HloComputation* callee_computation = call->to_apply();
// The arguments to the call should be const1, const2, and const3.
@@ -205,7 +205,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
const HloComputation* callee = module->GetComputationWithName("callee");
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index bb634e6573..990dfc410c 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -121,7 +122,6 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index,
}
namespace {
-
// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
void GatherFusionInstructions(
HloInstruction* instruction,
@@ -292,22 +292,29 @@ Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) {
}
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
- // RecvDone aliases its input (Recv) tuple element {0} to its output.
+ // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
+ // output. The other indices ({} and {1}) define their own buffers.
PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
+ points_to_set.AddPointedToBuffer(
+ logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{}),
+ /*index=*/{});
+ points_to_set.AddPointedToBuffer(
+ logical_buffer_analysis_->GetBuffer(recv_done, /*index=*/{1}),
+ /*index=*/{1});
+
const PointsToSet& operand_points_to_set =
GetPointsToSet(recv_done->operand(0));
- // Recursively copy the points to set of the operand tuple {0}.
+ // Recursively copy the points to set of the operand tuple {0} to the output
+ // element {0}.
points_to_set.ForEachMutableElement(
[this, &points_to_set, &operand_points_to_set](
const ShapeIndex& index, PointsToSet::BufferList* buffers) {
- ShapeIndex src_index({0});
- for (auto element : index) {
- src_index.push_back(element);
+ if (index.empty() || index[0] != 0) {
+ return;
}
- *buffers = operand_points_to_set.element(src_index);
- for (auto& tuple_source :
- operand_points_to_set.tuple_sources(src_index)) {
+ *buffers = operand_points_to_set.element(index);
+ for (auto& tuple_source : operand_points_to_set.tuple_sources(index)) {
points_to_set.add_tuple_source(index, tuple_source);
}
});
@@ -315,7 +322,7 @@ Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
}
Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
- // Send creates a tuple of {aliased operand, U32 context}.
+ // Send creates a tuple of {aliased operand, U32 context, token}.
PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
// Creates the points to set for the tuple and its element at {1}.
@@ -328,6 +335,10 @@ Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
context_buffer->push_back(
&logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
+ auto token_buffer = points_to_set.mutable_element(ShapeIndex({2}));
+ token_buffer->push_back(
+ &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({2})));
+
// Recursively copy the points to set of the operand to output tuple {0}.
const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
operand_points_to_set.ForEachElement(
@@ -388,7 +399,7 @@ Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) {
+Status TuplePointsToAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
// Select allocates a new buffer and then shallow copies the on_true or
// on_false buffer into this new buffer. Which side is chosen cannot be
// determined statically so conservatively set the points-to set to the union
@@ -396,9 +407,9 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) {
//
// First create a copy of the on_true points-to set (and tuple sources), then
// add in elements of the on_false points-to set (tuple sources).
- auto on_true = select->operand(1);
- auto on_false = select->operand(2);
- PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true);
+ auto on_true = tuple_select->operand(1);
+ auto on_false = tuple_select->operand(2);
+ PointsToSet& points_to_set = CreateCopiedPointsToSet(tuple_select, on_true);
const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set;
points_to_set.ForEachMutableElement(
[&](const ShapeIndex& index, PointsToSet::BufferList* buffers) {
@@ -416,7 +427,7 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) {
// respective element in the points-to set should contain only itself.
points_to_set.mutable_element({})->clear();
points_to_set.AddPointedToBuffer(
- logical_buffer_analysis_->GetBuffer(select, /*index=*/{}),
+ logical_buffer_analysis_->GetBuffer(tuple_select, /*index=*/{}),
/*index=*/{});
return Status::OK();
}
@@ -723,15 +734,22 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return false;
}
if (user->opcode() == HloOpcode::kFusion) {
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
+ user->fusion_kind() == HloInstruction::FusionKind::kInput) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0);
+ } else {
+ HloInstruction* fusion_param =
+ user->fused_parameter(user->operand_index(operand));
+ return HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
+ fusion_param);
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -789,8 +807,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
return param_uses.size() == 1 && param_uses[0].first == callee_root &&
callee_root->IsElementwiseOnOperand(param_uses[0].second);
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check in the beginning of this function.
+ //
+ // Multi-output fusion will fail the check here as tuples are not considered
+ // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index c0d8241480..686bb05328 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -253,7 +253,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
Status HandleCopy(HloInstruction* copy) override;
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
- Status HandleSelect(HloInstruction* select) override;
+ Status HandleTupleSelect(HloInstruction* tuple_select) override;
string ToString() const;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index f558316b05..0ac8df4271 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase {
TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
@@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
// tuple.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto constant3 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({inner_tuple, constant3}));
@@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
// Create a tuple which contains duplicate elements.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant, constant, constant}));
@@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
// the same.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto copy = builder.AddInstruction(
@@ -317,9 +317,10 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
// Send forwards its operand to the output tuple at {0}.
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto send = builder.AddInstruction(
- HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ HloInstruction::CreateSend(constant, token, /*channel_id=*/0));
auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
BuildModuleAndRunAnalysis(builder.Build());
@@ -342,8 +343,9 @@ TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
// RecvDone forwards its operand tuple element at {0} to the output.
auto builder = HloComputation::Builder(TestName());
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
- ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
+ ShapeUtil::MakeShape(F32, {1, 2, 3}), token, /*channel_id=*/0));
auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
BuildModuleAndRunAnalysis(builder.Build());
@@ -355,7 +357,7 @@ TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
ExpectHasTopLevelBuffers(
points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
- ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
+ ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {0}}});
}
TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
@@ -363,18 +365,18 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
// set containing the union of both sides.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant2, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -401,9 +403,9 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
auto param1 = builder.AddInstruction(
HloInstruction::CreateParameter(1, tuple_shape, "param1"));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, param0, param1));
+ tuple_shape, HloOpcode::kTupleSelect, pred, param0, param1));
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));
@@ -441,18 +443,18 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
// Select from two identical tuples. The result should not be ambiguous.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple2 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -472,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
// the right values.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple1 = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto inner_tuple2 = builder.AddInstruction(
@@ -486,9 +488,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));
auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
auto select = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ tuple1->shape(), HloOpcode::kTupleSelect, pred, tuple1, tuple2));
BuildModuleAndRunAnalysis(builder.Build());
@@ -519,9 +521,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
// have the operand of the bitcast in its points-to set.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
constant2->shape(), HloOpcode::kBitcast, constant2));
auto tuple =
@@ -555,9 +557,10 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
// Construct a tuple constant and kCopy it. Verify the points-to set of the
// copy correctly correctly points into the nested elements of the constant.
auto builder = HloComputation::Builder(TestName());
- auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
- Literal::CreateR1<float>({2.0, 42}).get()})));
+ auto tuple_constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
@@ -577,9 +580,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
// times. Verify buffer alias sets.
auto builder = HloComputation::Builder(TestName());
auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto inner_tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
auto tuple = builder.AddInstruction(
@@ -618,7 +621,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
auto tuple_element1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1));
auto ones = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f})));
// Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones)
auto update = builder.AddInstruction(HloInstruction::CreateBinary(
update_shape, HloOpcode::kAdd, tuple_element1, ones));
@@ -866,9 +869,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -960,9 +963,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// Create a DynamicUpdateSlice instruction of tuple element 1.
auto starts = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
@@ -1014,9 +1017,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto a = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
@@ -1025,7 +1028,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto add_operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1047,7 +1050,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
auto one = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto operand = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape, one, {1}));
@@ -1055,7 +1058,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
auto two = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
@@ -1120,7 +1123,7 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
auto sub_param = sub_builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "sub_param"));
auto one = sub_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto ones = sub_builder.AddInstruction(
HloInstruction::CreateBroadcast(shape, one, {1}));
auto add = sub_builder.AddInstruction(
@@ -1148,5 +1151,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
call, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) {
+ Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32});
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16});
+
+ auto builder = HloComputation::Builder(TestName() + "_fusion");
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, full_shape, "full"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, broadcast_shape, "small"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(full_shape, param1, {0}));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ full_shape, HloOpcode::kAdd, param0, broadcast));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {add, broadcast}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc
index d668855084..77bdcc9de0 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc
@@ -30,10 +30,17 @@ limitations under the License.
namespace xla {
+TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) :
+ exclude_entry_computation_(exclude_entry_computation) {}
+
StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Initially add all GTE and Tuple instructions to the worklist.
std::queue<HloInstruction*> worklist;
for (auto* computation : module->computations()) {
+ if (exclude_entry_computation_ &&
+ computation == module->entry_computation()) {
+ continue;
+ }
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kTuple ||
instruction->opcode() == HloOpcode::kGetTupleElement) {
@@ -69,7 +76,6 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Tuple
//
HloInstruction* top_tuple = nullptr;
- HloInstruction* first_gte = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0;
operand_number < instruction->operand_count(); ++operand_number) {
@@ -79,17 +85,10 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
can_simplify = false;
break;
}
- if (first_gte == nullptr) {
- first_gte = operand;
- } else if (!first_gte->has_compatible_sharding(operand)) {
- can_simplify = false;
- break;
- }
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(),
- instruction->shape()) ||
- !instruction->has_compatible_sharding(top_tuple)) {
+ instruction->shape())) {
can_simplify = false;
break;
}
@@ -118,14 +117,12 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
HloInstruction* element_source =
instruction->mutable_operand(0)->mutable_operand(
instruction->tuple_index());
- if (instruction->has_compatible_sharding(element_source)) {
- changed = true;
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
- for (HloInstruction* user : element_source->users()) {
- if (user->opcode() == HloOpcode::kTuple ||
- user->opcode() == HloOpcode::kGetTupleElement) {
- worklist.push(user);
- }
+ changed = true;
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
+ for (HloInstruction* user : element_source->users()) {
+ if (user->opcode() == HloOpcode::kTuple ||
+ user->opcode() == HloOpcode::kGetTupleElement) {
+ worklist.push(user);
}
}
}
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index e5e9b10b5b..7509501883 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -27,13 +27,20 @@ namespace xla {
// the module.
class TupleSimplifier : public HloPassInterface {
public:
- TupleSimplifier() {}
+ TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
+ explicit TupleSimplifier(bool exclude_entry_computation);
~TupleSimplifier() override {}
tensorflow::StringPiece name() const override { return "tuple-simplifier"; }
// Run tuple simplification on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ // When set, this pipeline stage will perform optimization of all computations
+ // apart from the module's entry computation. This is used by Graphcore's
+ // backend.
+ bool exclude_entry_computation_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
index ca9ae91281..39b693872d 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase {
TF_ASSERT_OK(changed_status.status());
EXPECT_EQ(change_expected, changed_status.ValueOrDie());
}
+ void Run(HloModule* module, bool change_expected, bool exclude_entry) {
+ TupleSimplifier simplifier(exclude_entry);
+ auto changed_status = simplifier.Run(module);
+ TF_ASSERT_OK(changed_status.status());
+ EXPECT_EQ(change_expected, changed_status.ValueOrDie());
+ }
const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
@@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) {
EXPECT_THAT(computation->root_instruction(), tuple);
}
+TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) {
+ // Verify that the root computation can be excluded
+ auto module = CreateNewModule();
+
+ HloInstruction* p0;
+ HloInstruction* p1;
+ HloComputation* c0;
+ HloComputation* c1;
+ HloComputation* entry;
+
+ {
+ HloComputation::Builder builder(TestName() + "_1");
+ p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
+
+ c0 = module->AddEmbeddedComputation(builder.Build());
+ }
+ {
+ HloComputation::Builder builder(TestName() + "_2");
+ p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2}));
+
+ c1 = module->AddEmbeddedComputation(builder.Build());
+ }
+ {
+ HloComputation::Builder builder(TestName() + "_Entry");
+ HloInstruction* tuple_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape_, "param"));
+ HloInstruction* call0 = builder.AddInstruction(
+ HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0));
+ HloInstruction* call1 = builder.AddInstruction(
+ HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1));
+ HloInstruction* tuple0 =
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+ HloInstruction* gte2 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0));
+ HloInstruction* gte3 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3}));
+
+ entry = module->AddEntryComputation(builder.Build());
+ }
+
+ Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true);
+
+ EXPECT_THAT(c0->root_instruction(), p0);
+ EXPECT_THAT(c1->root_instruction(), p1);
+ EXPECT_THAT(entry->instruction_count(), 9);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc
index 754fd8ef16..d33d5bb8f3 100644
--- a/tensorflow/compiler/xla/service/tuple_util_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_util_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -37,7 +37,7 @@ ENTRY entry {
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h
deleted file mode 100644
index 5732a56caf..0000000000
--- a/tensorflow/compiler/xla/service/versioned_computation_handle.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
-
-#include <ostream>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
-namespace xla {
-
-// A data structure encapsulating a ComputationHandle and version value of that
-// computation. This object is used to unambiguously refer to a particular
-// computation in the service.
-struct VersionedComputationHandle {
- // A version value unambiguously specifying the state of the computation at a
- // particular point in time as it is being built. This value is the
- // ComputationDataHandle of the current root instruction.
- using Version = int64;
-
- ComputationHandle handle;
- Version version;
-
- string ToString() const;
- bool operator==(const VersionedComputationHandle& other) const {
- return (handle.handle() == other.handle.handle()) &&
- (version == other.version);
- }
- bool operator<(const VersionedComputationHandle& other) const {
- return ((handle.handle() < other.handle.handle()) ||
- ((handle.handle() == other.handle.handle()) &&
- (version < other.version)));
- }
-};
-
-std::ostream& operator<<(std::ostream& out,
- const VersionedComputationHandle& versioned_handle);
-
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
index 0d2288d8ea..393e758038 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -55,7 +55,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -95,7 +95,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -136,7 +136,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -184,7 +184,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index e1ec12192f..32e69c335b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -53,7 +53,7 @@ HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
builder.AddInstruction(
HloInstruction::CreateParameter(0, param_shape, "param"));
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
return module->AddEmbeddedComputation(builder.Build());
}
@@ -125,7 +125,7 @@ TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
builder.AddInstruction(HloInstruction::CreateUnary(
scalar_s32, HloOpcode::kNegate, mul_result));
HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(4)));
HloInstruction* sub_result =
builder.AddInstruction(HloInstruction::CreateBinary(
scalar_s32, HloOpcode::kSubtract, negate_result, constant));
@@ -248,7 +248,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest,
TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -258,25 +260,32 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_s32, gte_0, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateToken());
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
-
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());
TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
WhileLoopInvariantCodeMotion{}.Run(&module()));
- EXPECT_FALSE(simplified_loop);
+ ASSERT_FALSE(simplified_loop);
EXPECT_THAT(while_inst->while_body()->instructions(),
Contains(op::Outfeed()));
@@ -287,7 +296,9 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
// bitcast either.
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
- Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ Shape while_shape =
+ ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, token_shape});
HloComputation* while_body = [&]() {
HloComputation::Builder builder(TestName() + ".while_body");
@@ -297,21 +308,29 @@ TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
HloInstruction* gte_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ HloInstruction* in_token = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(token_shape, param, 2));
HloInstruction* bitcast_inst = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
+ HloInstruction* out_token = builder.AddInstruction(
+ HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, in_token, ""));
builder.AddInstruction(
- HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
- builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+ HloInstruction::CreateTuple({gte_0, gte_1, out_token}));
return module().AddEmbeddedComputation(builder.Build());
}();
HloComputation::Builder builder(TestName());
+ auto* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_s32, "param"));
+ auto* token = builder.AddInstruction(HloInstruction::CreateToken());
auto* init_value = builder.AddInstruction(
- HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ HloInstruction::CreateTuple({scalar_param, scalar_param, token}));
auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
while_body, init_value));
+ builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, while_inst, 0));
module().AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 619e87caa5..2e1571943e 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -157,7 +157,7 @@ TEST_F(WhileLoopSimplifierTest,
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* true_op = while_op->while_body()->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
TF_ASSERT_OK(true_op->AddControlDependencyTo(
while_op->while_body()->root_instruction()));
ASSERT_TRUE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -175,9 +175,11 @@ TEST_F(WhileLoopSimplifierTest, LoopWithSendNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
while_body->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))),
+ token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateSendDone(send));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -190,8 +192,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithRecvNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
+ auto* token = while_body->AddInstruction(HloInstruction::CreateToken());
auto* recv = while_body->AddInstruction(
- HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
+ HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), token,
/*channel_id=*/0));
while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
@@ -208,8 +211,9 @@ TEST_F(WhileLoopSimplifierTest, LoopWithInfeedNotSimplified) {
auto* while_op = computation->root_instruction();
ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
- while_body->AddInstruction(
- HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config"));
+ auto token = while_body->AddInstruction(HloInstruction::CreateToken());
+ while_body->AddInstruction(HloInstruction::CreateInfeed(
+ ShapeUtil::MakeShape(F32, {1}), token, "config"));
EXPECT_FALSE(WhileLoopSimplifier().Run(the_module).ValueOrDie());
}
diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc
index 473eab2ea8..1ef17b9d7d 100644
--- a/tensorflow/compiler/xla/service/while_util.cc
+++ b/tensorflow/compiler/xla/service/while_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
@@ -38,7 +39,7 @@ static StatusOr<HloComputation*> WidenWhileCondition(
// the root instruction later. We later change the root instruction to
// something more appropriate.
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
return narrow_condition->parent()->AddEmbeddedComputation(builder.Build());
}();
@@ -154,7 +155,7 @@ MakeCountedLoopConditionComputation(const Shape& loop_state_shape,
{&loop_state_shape}, scalar_pred, "while_cond"));
HloInstruction* trip_count_constant = cond_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(trip_count)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(trip_count)));
HloInstruction* param = cond_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
@@ -175,7 +176,7 @@ static StatusOr<std::unique_ptr<HloComputation>> MakeCountedLoopBodyComputation(
CreateComputationWithSignature(
{&loop_state_shape}, loop_state_shape, "while_body"));
HloInstruction* one = body_computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
HloInstruction* param = body_computation->parameter_instruction(0);
TF_ASSIGN_OR_RETURN(HloInstruction * indvar,
MakeGetTupleElementHlo(param, 0));
@@ -203,7 +204,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues(
std::vector<HloInstruction*> init_values_with_indvar;
init_values_with_indvar.reserve(init_values.size() + 1);
HloInstruction* zero = computation->AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
init_values_with_indvar.push_back(zero);
c_copy(init_values, std::back_inserter(init_values_with_indvar));
return computation->AddInstruction(
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index bcc545c61d..2ccb919acf 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -50,7 +50,7 @@ ENTRY entry {
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
@@ -151,7 +151,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* while_body = module->GetComputationWithName("body");
@@ -179,7 +179,9 @@ body {
cond {
param.c = (s32[], s32[]) parameter(0)
- ROOT condition = pred[] infeed()
+ token = token[] after-all()
+ infeed = (pred[], token[]) infeed(token)
+ ROOT condition = pred[] get-tuple-element(infeed), index=0
}
ENTRY main {
@@ -190,7 +192,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* main = module->GetComputationWithName("main");
HloInstruction* while_instr = main->root_instruction();
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
index aa40b5cb26..83d696fe09 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -32,11 +32,12 @@ StatusOr<bool> ZeroSizedHloElimination::Run(HloModule* module) {
for (HloComputation* comp : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (instruction->HasSideEffect() ||
- ShapeUtil::IsTuple(instruction->shape())) {
+ !ShapeUtil::IsArray(instruction->shape()) ||
+ instruction->opcode() == HloOpcode::kConstant) {
continue;
}
if (comp->IsRemovable(instruction) &&
- ShapeUtil::HasZeroElements(instruction->shape())) {
+ ShapeUtil::IsZeroElementArray(instruction->shape())) {
TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(
Literal::CreateFromShape(instruction->shape()))));
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
index f5331280ee..b9ef18892d 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -67,7 +67,16 @@ TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) {
}
TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) {
- builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0));
+ auto token = builder_.AddInstruction(HloInstruction::CreateToken());
+ builder_.AddInstruction(
+ HloInstruction::CreateSend(zero_sized_param_, token, 0));
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination());
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateConstant) {
+ builder_.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1({})));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination());
EXPECT_FALSE(changed);
}
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 5b14953ebb..4aacc87b78 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -47,6 +47,9 @@ struct ShapeTreeNode {
// Children of this node, as indices into the container's nodes_ array.
std::vector<size_t> children;
+ // Tells whether this is a leaf node.
+ bool is_leaf = true;
+
explicit ShapeTreeNode(ShapeIndex index)
: ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNode(ShapeIndex index, T data)
@@ -102,8 +105,8 @@ class ShapeTree {
// Returns the data element associated with the array in the shape at the
// given index (see ShapeUtil::GetSubshape for how indexes are defined).
- const T& element(const ShapeIndex& index) const;
- T* mutable_element(const ShapeIndex& index);
+ const T& element(ShapeIndexView index) const;
+ T* mutable_element(ShapeIndexView index);
// Return the shape represented with this ShapeTree.
const Shape& shape() const { return *shape_; }
@@ -122,9 +125,7 @@ class ShapeTree {
// Returns true if the node at the given index is a leaf node (an array
// shape).
- bool IsLeaf(const ShapeIndex& index) const {
- return Lookup(index)->children.empty();
- }
+ bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
ShapeTree(const ShapeTree&) = default;
ShapeTree& operator=(const ShapeTree&) = default;
@@ -210,12 +211,12 @@ class ShapeTree {
// Returns an iterator pointing to the given ShapeIndex.
// REQUIRES: index must exist in the ShapeTree.
- iterator find(const ShapeIndex& index) {
+ iterator find(ShapeIndexView index) {
Node* element = Lookup(index);
return iterator(&nodes_, typename std::vector<Node>::iterator(element),
/*iterate_leaves_only=*/false);
}
- const_iterator find(const ShapeIndex& index) const {
+ const_iterator find(ShapeIndexView index) const {
Node* element = Lookup(index);
return iterator(&nodes_,
typename std::vector<Node>::const_iterator(element),
@@ -284,8 +285,8 @@ class ShapeTree {
static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
// Return the tree node at the given index.
- Node* Lookup(const ShapeIndex& index);
- const Node* Lookup(const ShapeIndex& index) const;
+ Node* Lookup(ShapeIndexView index);
+ const Node* Lookup(ShapeIndexView index) const;
// The nodes in this shape tree.
std::vector<Node> nodes_;
@@ -311,16 +312,14 @@ class ShapeTreeIterator
: nodes_(nodes),
node_(std::move(node)),
iterate_leaves_only_(iterate_leaves_only) {
- while (iterate_leaves_only && node_ != nodes_->end() &&
- !node_->children.empty()) {
+ while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
}
ShapeTreeIterator& operator++() {
++node_;
- while (iterate_leaves_only_ && node_ != nodes_->end() &&
- !node_->children.empty()) {
+ while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
++node_;
}
return *this;
@@ -333,8 +332,7 @@ class ShapeTreeIterator
ShapeTreeIterator& operator--() {
--node_;
- while (iterate_leaves_only_ && node_ > nodes_->begin() &&
- !node_->children.empty()) {
+ while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
--node_;
}
return *this;
@@ -358,7 +356,7 @@ class ShapeTreeIterator
ContainerType* nodes_;
IteratorType node_;
// True if we should not include interior nodes in our walk.
- bool iterate_leaves_only_;
+ const bool iterate_leaves_only_;
};
template <typename T>
@@ -379,6 +377,7 @@ void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
node->children.reserve(size);
+ node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
for (int i = 0; i < size; ++i) {
@@ -395,6 +394,7 @@ void ShapeTree<T>::InitChildren(const Shape& shape, Node* node) {
if (ShapeUtil::IsTuple(shape)) {
const int64 size = ShapeUtil::TupleElementCount(shape);
node->children.reserve(size);
+ node->is_leaf = false;
ShapeIndex shape_index = node->data.first;
shape_index.push_back(0);
for (int i = 0; i < size; ++i) {
@@ -463,17 +463,17 @@ ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
}
template <typename T>
-const T& ShapeTree<T>::element(const ShapeIndex& index) const {
+const T& ShapeTree<T>::element(ShapeIndexView index) const {
return Lookup(index)->data.second;
}
template <typename T>
-T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
+T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
return &Lookup(index)->data.second;
}
template <typename T>
-internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
+internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
Node* node = &nodes_[0];
for (const int64 i : index) {
CHECK_GE(i, 0);
@@ -485,7 +485,7 @@ internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(const ShapeIndex& index) {
template <typename T>
const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
- const ShapeIndex& index) const {
+ ShapeIndexView index) const {
return const_cast<ShapeTree*>(this)->Lookup(index);
}
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
index dc5facf158..51de82e957 100644
--- a/tensorflow/compiler/xla/shape_tree_test.cc
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -116,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) {
TestInitValueConstructor(nested_tuple_shape_, 10);
}
+TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) {
+ ShapeTree<int> shape_tree{ShapeUtil::MakeTupleShape({})};
+ EXPECT_EQ(0, shape_tree.leaf_count());
+}
+
TEST_F(ShapeTreeTest, ArrayShape) {
ShapeTree<int> shape_tree{array_shape_};
*shape_tree.mutable_element({}) = 42;
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index e8a28d76e9..f4668c0f55 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
@@ -42,30 +43,17 @@ limitations under the License.
namespace xla {
-string ShapeIndex::ToString() const {
- return tensorflow::strings::StrCat(
- "{", tensorflow::str_util::Join(indices_, ","), "}");
-}
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+string ShapeIndex::ToString() const { return ShapeIndexView(*this).ToString(); }
string ShapeIndexView::ToString() const {
- return tensorflow::strings::StrCat(
- "{",
- tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
- ","),
- "}");
+ return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
- if (size() != other.size()) {
- return false;
- }
- for (auto it = begin(), other_it = other.begin(); it != end();
- ++it, ++other_it) {
- if (*it != *other_it) {
- return false;
- }
- }
- return true;
+ return indices_ == other.indices_;
}
bool ShapeIndexView::operator!=(const ShapeIndexView& other) const {
@@ -84,18 +72,34 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
namespace {
+// Returns whether the given primitive type corresponds to an array shape.
+bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
+ return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
+ primitive_type != OPAQUE && primitive_type != TOKEN;
+}
+
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
-bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
- if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
- return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts,
+ bool ignore_fp_precision) {
+ if ((ignore_fp_precision &&
+ !ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) ||
+ (!ignore_fp_precision && !ShapeUtil::SameElementType(lhs, rhs))) {
+ VLOG(3) << "CompareShapes: lhs element type != rhs element type";
+ return false;
+ }
+
+ if (ShapeUtil::IsTuple(lhs)) {
+ return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) {
- return CompareShapes(l, r, compare_layouts);
+ return CompareShapes(l, r, compare_layouts,
+ ignore_fp_precision);
});
- } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
- return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
+ } else if (!ShapeUtil::IsArray(lhs)) {
+ // Non-tuple, non-array tupes such as opaque and token types are trivially
+ // the same.
+ return true;
}
if (compare_layouts) {
@@ -125,10 +129,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
- if (!ShapeUtil::SameElementType(lhs, rhs)) {
- VLOG(3) << "CompareShapes: lhs element type != rhs element type";
- return false;
- }
return true;
}
@@ -161,7 +161,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
} // namespace
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
- bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
+ bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
+ /*ignore_fp_precision=*/false);
if (!equal && VLOG_IS_ON(3)) {
VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
<< ", rhs = " << rhs.ShortDebugString();
@@ -170,9 +171,21 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal;
}
+/* static */ bool ShapeUtil::EqualIgnoringFpPrecision(const Shape& lhs,
+ const Shape& rhs) {
+ bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true,
+ /*ignore_fp_precision=*/true);
+ if (!equal && VLOG_IS_ON(3)) {
+ VLOG(3) << "ShapeUtil::EqualIgnoringFpPrecision differ: lhs = "
+ << lhs.ShortDebugString() << ", rhs = " << rhs.ShortDebugString();
+ }
+
+ return equal;
+}
+
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
- CHECK(!ShapeUtil::IsTuple(shape))
- << "Tuples do not have a rank, shape: " << shape;
+ CHECK(ShapeUtil::IsArray(shape))
+ << "Non-arrays do not have a rank, shape: " << shape;
return shape.dimensions_size();
}
@@ -199,8 +212,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShape(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
- DCHECK_NE(TUPLE, element_type);
- DCHECK_NE(OPAQUE, element_type);
+ CHECK(IsArrayPrimitiveType(element_type));
Shape result;
PopulateShape(element_type, dimensions, &result);
return result;
@@ -223,8 +235,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) {
- DCHECK_NE(TUPLE, element_type);
- DCHECK_NE(OPAQUE, element_type);
+ CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
@@ -257,6 +268,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
tensorflow::gtl::ArraySlice<Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
+ result.mutable_tuple_shapes()->Reserve(shapes.size());
for (const auto& shape : shapes) {
AppendShapeToTuple(shape, &result);
}
@@ -271,6 +283,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return result;
}
+/* static */ Shape ShapeUtil::MakeTokenShape() {
+ Shape result;
+ result.set_element_type(TOKEN);
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
+ return result;
+}
+
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
Shape* tuple_shape) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
@@ -294,7 +313,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
- if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
+ if (!IsArray(shape)) {
return false;
}
return primitive_util::BitWidth(shape.element_type()) == bits;
@@ -320,6 +339,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
case C64:
case TUPLE:
case OPAQUE:
+ case TOKEN:
return false;
default:
@@ -335,6 +355,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return primitive_util::IsFloatingPointType(shape.element_type());
}
+/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
+ return IsArrayPrimitiveType(shape.element_type());
+}
+
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
shape.tuple_shapes().end(), IsTuple);
@@ -345,7 +369,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::IsNil(const Shape& shape) {
- return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape);
+ return IsEmptyTuple(shape);
}
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
@@ -361,6 +385,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.tuple_shapes(index);
}
+/* static */ int64 ShapeUtil::SubshapeCount(const Shape& shape) {
+ int64 n = 0;
+ ForEachSubshape(shape, [&](const Shape& literal_subshape,
+ const ShapeIndex& index) { ++n; });
+ return n;
+}
+
/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
int64 limit) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
@@ -388,37 +419,31 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
+ CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
CHECK_EQ(shape.dimensions_size(), Rank(shape));
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
std::multiplies<int64>());
}
-/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
- return ElementsIn(shape) == 0;
+/* static */ int64 ShapeUtil::ElementsInRecursive(const Shape& shape) {
+ CHECK(IsArray(shape) || IsTuple(shape));
+ if (IsArray(shape)) {
+ return ElementsIn(shape);
+ }
+ int64 count = 0;
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ count += ElementsInRecursive(element_shape);
+ }
+ return count;
}
-/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
- return shape.element_type() == F32 && Rank(shape) == 0;
+/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
+ return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
-/* static */ string ShapeUtil::HumanString(const Shape& shape) {
- if (IsTuple(shape)) {
- string text = "(";
- const char* prefix = "";
- for (const Shape& elem_shape : shape.tuple_shapes()) {
- tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
- prefix = ", ";
- }
- text += ")";
- return text;
- } else {
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Lowercase(
- PrimitiveType_Name(shape.element_type())),
- "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
- }
+/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
+ return shape.element_type() == F32 && Rank(shape) == 0;
}
namespace {
@@ -470,48 +495,56 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
} // namespace
-/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
+/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
- tensorflow::strings::StrAppend(&text, prefix,
- HumanStringWithLayout(elem_shape));
+ StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
- } else {
- string result = tensorflow::strings::StrCat(
- LowercasePrimitiveTypeName(shape.element_type()), "[");
- for (int i = 0; i < shape.dimensions().size(); i++) {
- tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
- shape.dimensions(i));
+ }
+ return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
+ tensorflow::str_util::Join(shape.dimensions(), ","), "]");
+}
+
+/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
+ if (IsTuple(shape)) {
+ string text = "(";
+ const char* prefix = "";
+ for (const Shape& elem_shape : shape.tuple_shapes()) {
+ StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
+ prefix = ", ";
}
- result += "]";
- if (!IsScalar(shape) && !IsOpaque(shape)) {
- if (LayoutUtil::HasLayout(shape)) {
- tensorflow::strings::StrAppend(&result,
- LayoutUtil::HumanString(shape.layout()));
- }
+ text += ")";
+ return text;
+ }
+ string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
+ for (int i = 0; i < shape.dimensions().size(); i++) {
+ StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
+ }
+ result += "]";
+ if (!IsScalar(shape) && IsArray(shape)) {
+ if (LayoutUtil::HasLayout(shape)) {
+ StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
}
- return result;
}
+ return result;
}
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
std::vector<string> parameters;
for (auto& shape : program_shape.parameters()) {
const int i = parameters.size();
- parameters.push_back(
- tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
- ? program_shape.parameter_names(i)
- : "(unknown)",
- ": ", HumanString(shape)));
+ parameters.push_back(StrCat(i < program_shape.parameter_names_size()
+ ? program_shape.parameter_names(i)
+ : "(unknown)",
+ ": ", HumanString(shape)));
}
- return tensorflow::strings::StrCat(
- "(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
- HumanString(program_shape.result()));
+ return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
+ HumanString(program_shape.result()));
}
namespace {
@@ -545,12 +578,11 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so
// we convert in to the RE2-consumable type and then consume the corresponding
// amount from our StringPiece type.
+ static LazyRE2 shape_pattern = {
+ "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?"};
tensorflow::RegexpStringPiece s_consumable(s->data(), s->size());
- if (RE2::Consume(
- &s_consumable,
- "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?",
- &element_type_string, &dimensions_string, &format_string,
- &layout_string)) {
+ if (RE2::Consume(&s_consumable, *shape_pattern, &element_type_string,
+ &dimensions_string, &format_string, &layout_string)) {
size_t consumed = s->size() - s_consumable.size();
s->remove_prefix(consumed);
auto string_to_int64 = [&s](const string& input) -> StatusOr<int64> {
@@ -581,14 +613,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the primitive element type.
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
StringToPrimitiveType(element_type_string));
- if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
- primitive_type == OPAQUE) {
+ if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}
Shape result;
- if (format_string.empty() && layout_string.empty()) {
+ if (primitive_type == OPAQUE) {
+ result = ShapeUtil::MakeOpaqueShape();
+ } else if (primitive_type == TOKEN) {
+ result = ShapeUtil::MakeTokenShape();
+ } else if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else if (format_string == "sparse") {
@@ -633,43 +668,37 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
- return rhs.element_type() == TUPLE &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
- }
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
+ return CompareShapes(lhs, rhs, /*compare_layouts=*/false,
+ /*ignore_fp_precision=*/false);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
+ if (IsArray(lhs)) {
+ return IsArray(rhs) && SameDimensions(lhs, rhs);
+ } else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType);
+ } else {
+ // Opaque, token, etc types are vacuously compatible.
+ return true;
}
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
+ if (IsArray(lhs)) {
+ return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) &&
+ CompatibleIgnoringElementType(lhs, rhs);
+ } else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision);
+ } else {
+ // Opaque, token, etc types are vacuously compatible.
+ return true;
}
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
- return CompatibleIgnoringElementType(lhs, rhs);
- }
- return false;
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
@@ -691,10 +720,6 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
switch (primitive_type) {
case PRED:
return sizeof(int8);
- case TUPLE:
- LOG(FATAL) << "tuples have no definitive size";
- case OPAQUE:
- LOG(FATAL) << "opaque have no definitive size";
case S8:
return sizeof(int8);
case S16:
@@ -721,6 +746,13 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(double);
case C64:
return sizeof(complex64);
+ case TOKEN:
+ // Tokens require no space.
+ return 0;
+ case TUPLE:
+ case OPAQUE:
+ LOG(FATAL) << PrimitiveType_Name(primitive_type)
+ << " primitive type has no definitive size";
default:
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
}
@@ -729,28 +761,32 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size);
+ } else if (IsArray(shape)) {
+ int64 byte_size = ByteSizeOfElements(shape);
+ if (LayoutUtil::IsSparseArray(shape)) {
+ byte_size += ByteSizeOfSparseIndices(shape);
+ }
+ return byte_size;
+ } else if (shape.element_type() == TOKEN) {
+ return 0;
}
- int64 byte_size = ByteSizeOfElements(shape);
- if (LayoutUtil::IsSparseArray(shape)) {
- byte_size += ByteSizeOfSparseIndices(shape);
- }
- return byte_size;
+ LOG(FATAL) << PrimitiveType_Name(shape.element_type())
+ << " primitive type has no definitive size";
}
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK_EQ(TUPLE, shape.element_type());
+ CHECK_EQ(TUPLE, shape.element_type());
CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size();
}
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK(ShapeUtil::IsArray(shape));
+ CHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count;
if (LayoutUtil::IsSparseArray(shape)) {
@@ -775,13 +811,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK(LayoutUtil::IsSparseArray(shape));
+ CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64);
}
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ return InvalidArgument("shape has invalid element type: %s",
+ shape.ShortDebugString().c_str());
+ }
if (shape.element_type() == TUPLE) {
if (shape.dimensions_size() != 0) {
return InvalidArgument("tuples must not have dimensions specified");
@@ -797,10 +837,24 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape.tuple_shapes_size() > 0) {
return InvalidArgument("non-tuple shape has tuple_shapes field");
}
- if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
- return InvalidArgument("shape has invalid element type: %s",
- shape.ShortDebugString().c_str());
+
+ // Tokens and opaques can should not have layout or dimensions.
+ if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
+ if (shape.dimensions_size() != 0) {
+ return InvalidArgument(
+ "shape has %s element type, but has dimensions field: %s",
+ LowercasePrimitiveTypeName(shape.element_type()).c_str(),
+ shape.ShortDebugString().c_str());
+ }
+ if (shape.has_layout()) {
+ return InvalidArgument(
+ "shape has %s element type, but has layout field: %s",
+ LowercasePrimitiveTypeName(shape.element_type()).c_str(),
+ shape.ShortDebugString().c_str());
+ }
+ return Status::OK();
}
+
if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument(
"shape's rank is mismatched with dimension count; rank=%lld "
@@ -817,6 +871,60 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
}
+ TF_RETURN_IF_ERROR(ValidateShapeSize(shape));
+ return Status::OK();
+}
+
+/* static */ Status ShapeUtil::ValidateShapeSize(const Shape& shape) {
+ VLOG(3) << "Validating shape size: " << ShapeUtil::HumanString(shape);
+
+ if (!IsArray(shape)) {
+ return Status::OK();
+ }
+
+ int64 shape_size = [&shape]() {
+ int64 shape_size;
+ if (LayoutUtil::IsSparseArray(shape)) {
+ shape_size = LayoutUtil::MaxSparseElements(shape.layout());
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ shape_size = MultiplyWithoutOverflow(shape_size, ShapeUtil::Rank(shape));
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ shape_size = MultiplyWithoutOverflow(shape_size, sizeof(int64));
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ }
+
+ shape_size = 1;
+
+ // This is intentionally unconditional: even if the shape is sparse, we want
+ // to verify the densified version has a reasonable size.
+ if (shape.dimensions().empty()) {
+ return shape_size;
+ }
+
+ for (int64 dim : shape.dimensions()) {
+ shape_size = MultiplyWithoutOverflow(shape_size, dim);
+ if (shape_size < 0) {
+ return shape_size;
+ }
+ }
+ shape_size = MultiplyWithoutOverflow(
+ shape_size, ByteSizeOfPrimitiveType(shape.element_type()));
+
+ return shape_size;
+ }();
+
+ if (shape_size < 0) {
+ return InvalidArgument("Shape %s size may overflow int64.",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+
+ VLOG(3) << "Shape size is valid: " << shape_size;
return Status::OK();
}
@@ -865,6 +973,21 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return *return_shape;
}
+/* static */ StatusOr<const Shape*> ShapeUtil::TryGetSubshape(
+ const Shape& shape, ShapeIndexView index) {
+ const Shape* return_shape = &shape;
+ for (auto i : index) {
+ if (!IsTuple(*return_shape) || i < 0 ||
+ i >= return_shape->tuple_shapes_size()) {
+ return InvalidArgument(
+ "Shape index %s not a valid subshape index for tuple with shape %s",
+ index.ToString().c_str(), shape.DebugString().c_str());
+ }
+ return_shape = &return_shape->tuple_shapes(i);
+ }
+ return return_shape;
+}
+
/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
ShapeIndexView index) {
Shape* return_shape = shape;
@@ -901,64 +1024,9 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
return leaves;
}
-/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
- std::vector<int64> dimension_sizes;
- std::vector<int64> degenerate_dimensions;
- for (int64 i = 0; i < shape.dimensions_size(); ++i) {
- if (shape.dimensions(i) == 1) {
- degenerate_dimensions.push_back(i);
- } else {
- dimension_sizes.push_back(shape.dimensions(i));
- }
- }
-
- // Construct minor_to_major of stripped shape. The order of the non-degenerate
- // dimensions should be preserved from the original shape. First, create
- // vector of the non-degenerate dimensions from the original minor_to_major
- // array.
- std::vector<int64> minor_to_major;
- for (int64 i : shape.layout().minor_to_major()) {
- if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
- i) == degenerate_dimensions.end()) {
- minor_to_major.push_back(i);
- }
- }
-
- // The dimensions in minor_to_major need to be renumbered to account for the
- // degenerate dimensions which have removed. Decrement each dimension number
- // once for each degenerate dimension which has a smaller number.
- for (int i = 0; i < minor_to_major.size(); ++i) {
- int adjustment = 0;
- for (int64 dim : degenerate_dimensions) {
- if (minor_to_major[i] > dim) {
- adjustment++;
- }
- }
- minor_to_major[i] -= adjustment;
- }
-
- {
- std::vector<int64> dims(minor_to_major.size());
- std::iota(dims.begin(), dims.end(), 0);
- DCHECK(minor_to_major.size() == dims.size() &&
- std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
- dims.begin()));
- }
- Shape stripped_shape;
- if (LayoutUtil::IsDenseArray(shape)) {
- stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes,
- minor_to_major);
- } else if (LayoutUtil::IsSparseArray(shape)) {
- stripped_shape =
- MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes,
- shape.layout().max_sparse_elements());
- } else {
- stripped_shape = MakeShape(shape.element_type(), dimension_sizes);
- }
-
- VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
- VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
- return stripped_shape;
+/* static */ bool ShapeUtil::HasDegenerateDimensions(const Shape& shape) {
+ CHECK(ShapeUtil::IsArray(shape));
+ return ArrayContains<int64>(AsInt64Slice(shape.dimensions()), 1);
}
namespace {
@@ -1044,12 +1112,41 @@ Status ForEachMutableSubshapeHelper(
for (auto dim : Permute(permutation, shape.dimensions())) {
new_shape.add_dimensions(dim);
}
+
+ // If `shape` has a layout, by contract we choose a new layout such that the
+ // transpose defined by this permutation is a bitcast.
+ //
+ // Some formalism helps to understand the correct way to do this. We're going
+ // to do algebra in the group of permutations of the dimensions of `shape`.
+ //
+ // Since the order of `shape`'s dimensions is not permuted relative to itself,
+ // `shape`'s list of dimensions is isomorphic to the identity I.
+ //
+ // Let `shape`'s layout be L. A layout is a permutation which maps a
+ // minor-to-major physical layout to the order of a shape's logical dims.
+ // Therefore inverse of a layout maps from logical to physical dims, and so
+ // the physical layout of I is simply L'.I = L', where L' is the inverse of L.
+ //
+ // Let the argument `permutation` be P. This is a permutation over `shape`'s
+ // dimensions, so our return value will be a shape with dims P.I = P. Our
+ // goal is to construct a layout permutation L* that we can apply to P such
+ // that that the physical dimension ordering of the returned shape is the same
+ // as that of the original shape, namely L'.
+ //
+ // Our returned shape has dims P and layout L*, so its in-memory layout is
+ // L*'.P. Setting this equal to L' and solving for L*, we get:
+ //
+ // L*'.P = L' =>
+ // L*' = L'P' =>
+ // L* = P.L
+ //
if (shape.has_layout()) {
CHECK(LayoutUtil::IsDenseArray(shape));
Layout* new_layout = new_shape.mutable_layout();
new_layout->set_format(DENSE);
new_layout->clear_minor_to_major();
- for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
+ for (auto index : ComposePermutations(
+ permutation, AsInt64Slice(shape.layout().minor_to_major()))) {
new_layout->add_minor_to_major(index);
}
if (shape.layout().padded_dimensions_size() > 0) {
@@ -1059,6 +1156,13 @@ Status ForEachMutableSubshapeHelper(
new_layout->add_padded_dimensions(dim);
}
}
+ // The permutation accepted by TransposeIsBitcast is the inverse of the
+ // permutation here.
+ CHECK(TransposeIsBitcast(shape, new_shape, InversePermutation(permutation)))
+ << "shape=" << HumanStringWithLayout(shape)
+ << ", new_shape=" << HumanStringWithLayout(new_shape)
+ << ", permutation={" << tensorflow::str_util::Join(permutation, ",")
+ << "}";
}
return new_shape;
}
@@ -1066,6 +1170,9 @@ Status ForEachMutableSubshapeHelper(
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post) {
+ CHECK(IsArray(shape_pre));
+ CHECK(IsArray(shape_post));
+
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
std::vector<int64> deleted_indices;
@@ -1123,6 +1230,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
const Shape& output_shape) {
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+
// Unmodified dimensions are merely common factors of rank 1.
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
@@ -1176,8 +1286,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
const Shape& output_shape) {
- CHECK(LayoutUtil::HasLayout(input_shape) &&
- LayoutUtil::HasLayout(output_shape));
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+ CHECK(LayoutUtil::HasLayout(input_shape));
+ CHECK(LayoutUtil::HasLayout(output_shape));
if (!SameElementType(input_shape, output_shape)) {
return false;
@@ -1339,6 +1451,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
const Shape& input_shape, const Shape& output_shape) {
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+
int64 input_rank = Rank(input_shape);
int64 output_rank = Rank(output_shape);
@@ -1473,6 +1588,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
Shape shape) {
+ CHECK(IsArray(shape));
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout();
@@ -1494,6 +1610,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) {
+ CHECK(IsArray(shape));
std::vector<int64> dims_to_delete;
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
if (!p(i)) {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 9df31d5d21..d576be724e 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -62,6 +62,8 @@ class ShapeIndex {
public:
ShapeIndex() = default;
ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
+ template <typename InputIt>
+ ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {}
bool empty() const { return indices_.empty(); }
size_t size() const { return indices_.size(); }
@@ -108,30 +110,33 @@ class ShapeIndex {
class ShapeIndexView {
public:
ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0)
- : ShapeIndexView(shape_index.data() + offset,
- shape_index.data() + shape_index.size()) {
+ : indices_(shape_index.data() + offset, shape_index.size()) {
CHECK_LE(offset, shape_index.size());
}
- ShapeIndexView(std::initializer_list<int64> indices)
- : ShapeIndexView(indices.begin(), indices.end()) {}
+ ShapeIndexView(std::initializer_list<int64> indices) : indices_(indices) {}
ShapeIndexView(const ShapeIndexView& other) = default;
using iterator = const int64*;
- iterator begin() const { return begin_; }
- iterator end() const { return end_; }
- int64 size() const { return std::distance(begin_, end_); }
- bool empty() const { return begin_ == end_; }
+ iterator begin() const { return indices_.begin(); }
+ iterator end() const { return indices_.end(); }
+ int64 size() const { return indices_.size(); }
+ bool empty() const { return indices_.empty(); }
int64 front() const {
CHECK(!empty());
- return *begin_;
+ return indices_.front();
}
ShapeIndexView ConsumeFront() const {
- CHECK(!empty());
- auto new_begin = begin_;
- ++new_begin;
- return ShapeIndexView(new_begin, end_);
+ ShapeIndexView result = *this;
+ result.indices_.pop_front();
+ return result;
+ }
+ ShapeIndexView ConsumeBack() const {
+ ShapeIndexView result = *this;
+ result.indices_.pop_back();
+ return result;
}
+ ShapeIndex ToShapeIndex() const { return ShapeIndex(begin(), end()); }
bool operator==(const ShapeIndexView& other) const;
bool operator!=(const ShapeIndexView& other) const;
@@ -139,10 +144,7 @@ class ShapeIndexView {
string ToString() const;
private:
- ShapeIndexView(iterator begin, iterator end) : begin_(begin), end_(end) {}
-
- iterator begin_;
- iterator end_;
+ tensorflow::gtl::ArraySlice<int64> indices_;
};
std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index);
@@ -169,24 +171,25 @@ class ShapeUtil {
// may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape.
- // Precondition: !IsTuple(shape)
+ // Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape);
- // Returns true if 'shape' has zero elements.
- static bool HasZeroElements(const Shape& shape);
+ // As ElementsIn(), but recurses through tuples.
+ static int64 ElementsInRecursive(const Shape& shape);
+
+ // Returns true if 'shape' is an array with zero elements.
+ static bool IsZeroElementArray(const Shape& shape);
// Returns the number of bytes required for an allocation of shape. The
// |pointer_size| parameter is used for calculating the size of tuple
// shapes. This includes only the size of the top-level buffer. For example, a
// tuple is stored as an array of pointers to other buffers. In this case,
// this method only returns the size of the pointer array.
- // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
- // !ShapeUtil::IsOpaque(shape)
static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
// Returns the number of bytes used to store the primitive_type.
//
- // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
+ // Precondition: ShapeUtil::IsArray(shape)
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
// Returns the number of bytes required to store the tuple member pointers for
@@ -245,7 +248,7 @@ class ShapeUtil {
}
// Returns the higher-precision element type if a and b are both floating
- // point types; otherwise, checks that they have the same element type
+ // point types; otherwise, checks that that they have the same element type
// and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) {
@@ -276,6 +279,9 @@ class ShapeUtil {
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);
+ // As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
+ static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
+
// Returns the rank (number of dimensions) of the given shape.
// Precondition: !IsTuple(shape)
static int64 Rank(const Shape& shape);
@@ -293,10 +299,10 @@ class ShapeUtil {
// Scalar-specific
static bool IsScalar(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
+ return IsArray(shape) && Rank(shape) == 0;
}
static bool IsEffectiveScalar(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
+ return IsArray(shape) && TrueRank(shape) == 0;
}
static bool IsScalarF32(const Shape& shape);
@@ -325,13 +331,17 @@ class ShapeUtil {
// into a custom operation.
static Shape MakeOpaqueShape();
+ // Creates a token shape. Values of this shape are used for ordering
+ // side-effecting operations.
+ static Shape MakeTokenShape();
+
// Appends a shape to the given tuple.
static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
// Appends a major dimension to the shape with the given bound.
static void AppendMajorDimension(int bound, Shape* shape);
- // Returns an empty tuple shape. Can be used to indicate side-effects.
+ // Returns an empty tuple shape. Can be used as a sentinel Shape value.
static Shape MakeNil() { return MakeTupleShape({}); }
// Checks whether the shape is initialized.
@@ -424,11 +434,15 @@ class ShapeUtil {
return shape.element_type() == OPAQUE;
}
+ // Returns whether the shape is an token value used for ordering
+ // side-effecting operations.
+ static bool IsToken(const Shape& shape) {
+ return shape.element_type() == TOKEN;
+ }
+
// Returns whether the shape is an array. Note that scalars are considered
// arrays.
- static bool IsArray(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape);
- }
+ static bool IsArray(const Shape& shape);
// Returns whether the shape is a tuple with at least one element which is
// also a tuple.
@@ -437,7 +451,7 @@ class ShapeUtil {
// Returns true if shape is an empty tuple.
static bool IsEmptyTuple(const Shape& shape);
- // Returns true if shape is an empty tuple, or is an array with no elements.
+ // Returns true if shape is the nil shape (an empty tuple).
static bool IsNil(const Shape& shape);
// Returns the number of elements in the given tuple shape.
@@ -448,6 +462,9 @@ class ShapeUtil {
// Precondition: IsTuple(shape) && TupleElementCount(shape) > index
static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
+ // Returns the number of elements, recursively, in the given shape.
+ static int64 SubshapeCount(const Shape& shape);
+
// Slices tuple elements in the range [start, limit) and returns a new tuple
// shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
@@ -467,8 +484,11 @@ class ShapeUtil {
static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
// GetSubshape and GetMutableSubshape return a particular nested Shape within
- // the given Shape argument.
+ // the given Shape argument. The non-Try variants check fail if index is
+ // invalid.
static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);
+ static StatusOr<const Shape*> TryGetSubshape(const Shape& shape,
+ ShapeIndexView index);
static Shape* GetMutableSubshape(Shape* shape, ShapeIndexView index);
// Returns whether the given index in the given shape is a leaf element of the
@@ -504,28 +524,18 @@ class ShapeUtil {
static Status ForEachMutableSubshapeWithStatus(
Shape* shape, const MutatingStatusVisitorFunction& func);
- // Removes all degenerate dimensions (size one) from the given shape. The
- // stripped minor_to_major preserves the relative ordering of non-degenerate
- // dimensions. The stripped shape has the property that the underlying
- // representation (bits in memory) for the stripped shape is the same as the
- // original shape modulo padding. Examples:
- //
- // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2}
- // stripped shape: F32 [2], minor_to_major = {0}
- //
- // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1}
- // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
- //
- // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
- // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
- //
- // input shape: F32 [1, 1], minor_to_major = {0, 1}
- // stripped shape: F32 [], minor_to_major = {}
- // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
- static Shape StripDegenerateDimensions(const Shape& shape);
+ // Returns true if `shape` (which must be an array) with degenerate dimensions
+ // (dimensions with bound 1).
+ static bool HasDegenerateDimensions(const Shape& shape);
// Permutes the dimensions by the given permutation, so
- // return_value.dimensions[permutation[i]] = argument.dimensions[i]
+ // return_value.dimensions[permutation[i]] = argument.dimensions[i].
+ //
+ // Postcondition: For any valid permutation,
+ //
+ // !HasLayout(shape) ||
+ // TransposeIsBitcast(shape, PermuteDimensions(permutation, shape),
+ // InversePermutation(permutation)).
static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
const Shape& shape);
@@ -697,6 +707,10 @@ class ShapeUtil {
static size_t Hash(const Shape& shape);
private:
+ // Validates the shape size is sane. This makes sure it's safe to do
+ // calculations in int64 without overflowing.
+ static Status ValidateShapeSize(const Shape& shape);
+
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public method.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
@@ -708,7 +722,7 @@ class ShapeUtil {
tensorflow::gtl::ArraySlice<int64> incr,
const FnType& visitor_function,
bool parallel = false) {
- if (ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Status::OK();
}
CHECK_EQ(Rank(shape), base.size());
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index f7675e97da..6cdb46d674 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
+#include <numeric>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -22,6 +23,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
@@ -93,12 +96,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
}
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
- string shape_string = "(f32[1],(f32[2]), f32[3])";
+ string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
- ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
+ ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
@@ -136,6 +141,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
+TEST(ShapeUtilTest, ParseOpaqueType) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString("opaque[]"));
+ Shape expected = ShapeUtil::MakeOpaqueShape();
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, ParseTokenType) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
+ Shape expected = ShapeUtil::MakeTokenShape();
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
TEST(ShapeUtilTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
@@ -153,6 +175,41 @@ TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2));
}
+TEST(ShapeUtilTest, TokenCompatibility) {
+ EXPECT_TRUE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_FALSE(ShapeUtil::Compatible(ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()})));
+}
+
+TEST(ShapeUtilTest, TokensEqualShapes) {
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_FALSE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeTokenShape()));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})})));
+ EXPECT_FALSE(ShapeUtil::Equal(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {1, 0})})));
+}
+
TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
auto layout_1 = shape_1.mutable_layout();
@@ -188,6 +245,24 @@ TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
}
+TEST(ShapeUtilTest, EqualIgnoringFpPrecision) {
+ EXPECT_TRUE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {4, 3}, {0, 1})));
+}
+
+TEST(ShapeUtilTest, UnequalIgnoringFpPrecision) {
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {0, 1})));
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 4}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(F16, {3, 4}, {1, 0})));
+ EXPECT_FALSE(ShapeUtil::EqualIgnoringFpPrecision(
+ ShapeUtil::MakeShapeWithLayout(F32, {4, 3}, {0, 1}),
+ ShapeUtil::MakeShapeWithLayout(PRED, {4, 3}, {0, 1})));
+}
+
TEST(ShapeUtilTest, CompatibleTuples) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
@@ -295,6 +370,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
+
+ EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN));
+ EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape()));
}
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
@@ -307,6 +385,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
}
+TEST(ShapeUtilTest, NilShape) {
+ EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3})));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})})));
+}
+
TEST(ShapeUtilTest, NestedTuple) {
EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
EXPECT_FALSE(ShapeUtil::IsNestedTuple(
@@ -337,25 +425,30 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
-TEST(ShapeUtilTest, HasZeroElements) {
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
- EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
+TEST(ShapeUtilTest, IsZeroElementArray) {
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17})));
+
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})})));
}
TEST(ShapeUtilTest, SameDimensions) {
@@ -449,19 +542,21 @@ TEST(ShapeUtilTest, IsLeafIndex) {
TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
+ Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
- Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
+ Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
+ EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
ShapeUtil::HumanString(tuple));
- EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(nested_tuple));
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
@@ -470,8 +565,10 @@ TEST(ShapeUtilTest, HumanString) {
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
ShapeUtil::HumanStringWithLayout(tuple));
- EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})",
- ShapeUtil::HumanStringWithLayout(nested_tuple));
+ EXPECT_EQ(
+ "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
+ "token[])",
+ ShapeUtil::HumanStringWithLayout(nested_tuple));
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
@@ -481,8 +578,9 @@ TEST(ShapeUtilTest, HumanString) {
"(unknown): u32[1,2], "
"(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
- "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0");
@@ -497,8 +595,10 @@ TEST(ShapeUtilTest, HumanString) {
"matrix: u32[1,2], "
"matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
- "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
+ "token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
}
@@ -713,14 +813,37 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
}
-TEST(ShapeUtilTest, StripDegenerateDimensions) {
- EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions(
- ShapeUtil::MakeShape(F32, {3, 1, 2})),
- ShapeUtil::MakeShape(F32, {3, 2})));
- EXPECT_TRUE(ShapeUtil::Equal(
- ShapeUtil::StripDegenerateDimensions(
- ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)),
- ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10)));
+TEST(ShapeUtilTest, HasDegenerateDimensions) {
+ EXPECT_TRUE(
+ ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 2})));
+ EXPECT_TRUE(
+ ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 1, 1})));
+ EXPECT_FALSE(
+ ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 3, 5})));
+ EXPECT_FALSE(
+ ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5})));
+}
+
+TEST(ShapeUtilTest, PermuteDimensionsLayout) {
+ std::vector<int64> layout(3);
+ std::iota(layout.begin(), layout.end(), 0);
+ do {
+ Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout);
+ SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s)));
+
+ std::vector<int64> permutation(3);
+ std::iota(permutation.begin(), permutation.end(), 0);
+ do {
+ SCOPED_TRACE(tensorflow::strings::StrCat(
+ "permutation=", tensorflow::str_util::Join(permutation, ",")));
+
+ // TransposeIsBitcast takes the inverse of the permutation that
+ // PermuteDimensions takes.
+ EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
+ s, ShapeUtil::PermuteDimensions(permutation, s),
+ InversePermutation(permutation)));
+ } while (std::next_permutation(permutation.begin(), permutation.end()));
+ } while (std::next_permutation(layout.begin(), layout.end()));
}
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index a62d49e9c7..6a75aa6794 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -65,6 +65,7 @@ cc_library(
srcs = ["test_utils.cc"],
hdrs = ["test_utils.h"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
@@ -88,6 +89,7 @@ cc_library(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:error_spec",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test",
@@ -117,11 +119,11 @@ cc_library(
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
"//tensorflow/compiler/xla/service:platform_util",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@@ -138,8 +140,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -179,6 +181,7 @@ cc_library(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -209,6 +212,7 @@ cc_library(
deps = [
":codegen_test_base",
":filecheck",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:test",
@@ -302,7 +306,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -345,7 +349,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -406,7 +410,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -435,7 +439,7 @@ xla_test(
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -531,6 +535,7 @@ xla_test(
srcs = ["scalar_computations_test.cc"],
shard_count = 32,
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -573,7 +578,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -599,7 +604,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -645,7 +650,7 @@ xla_test(
tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -697,8 +702,9 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -763,6 +769,7 @@ xla_test(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
@@ -779,7 +786,7 @@ xla_test(
CONVOLUTION_TEST_DEPS = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -826,7 +833,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:local_client",
@@ -873,7 +880,7 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -885,6 +892,7 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
@@ -905,7 +913,7 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -938,7 +946,7 @@ xla_test(
],
deps = [
":test_utils",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1029,6 +1037,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1077,6 +1086,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1147,7 +1157,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1174,7 +1184,7 @@ xla_test(
deps = [
":client_library_test_base",
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1195,9 +1205,25 @@ xla_test(
],
deps = [
":client_library_test_base",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "token_hlo_test",
+ srcs = ["token_hlo_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ ],
+ deps = [
+ ":client_library_test_base",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -1210,6 +1236,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -1228,10 +1255,12 @@ xla_test(
name = "custom_call_test",
srcs = ["custom_call_test.cc"],
deps = [
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -1272,6 +1301,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -1349,7 +1379,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1372,7 +1402,7 @@ xla_test(
name = "prng_test",
srcs = ["prng_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
@@ -1397,6 +1427,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1511,7 +1542,7 @@ xla_test(
name = "cross_replica_sum_test",
srcs = ["cross_replica_sum_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
@@ -1520,11 +1551,11 @@ xla_test(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@@ -1555,7 +1586,7 @@ xla_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1595,7 +1626,7 @@ xla_test(
name = "compute_constant_test",
srcs = ["compute_constant_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1670,7 +1701,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1695,7 +1726,7 @@ xla_test(
"enable_for_xla_interpreter",
],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1712,6 +1743,7 @@ tf_cc_test(
srcs = ["llvm_compiler_test.cc"],
tags = ["requires-gpu-sm35"],
deps = [
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:cpu_plugin",
@@ -1732,7 +1764,7 @@ xla_test(
name = "round_trip_packed_literal_test",
srcs = ["round_trip_packed_literal_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:packed_literal_reader",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
@@ -1755,7 +1787,7 @@ xla_test(
],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1764,6 +1796,7 @@ xla_test(
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
@@ -1782,7 +1815,7 @@ xla_test(
srcs = ["multioutput_fusion_test.cc"],
deps = [
"//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1822,7 +1855,7 @@ xla_test(
name = "local_client_allocation_test",
srcs = ["local_client_allocation_test.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
@@ -1845,7 +1878,7 @@ xla_test(
shard_count = 30,
tags = ["optonly"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
@@ -1891,7 +1924,7 @@ xla_test(
srcs = ["round_trip_transfer_test.cc"],
deps = [
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1912,7 +1945,7 @@ xla_test(
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1960,7 +1993,7 @@ xla_test(
":literal_test_util",
":local_client_test_base",
":xla_internal_test_main",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -1970,6 +2003,7 @@ xla_test(
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
],
)
@@ -2021,6 +2055,7 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 36a7064969..3ae96fa1bc 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -51,16 +51,16 @@ class ArrayElementwiseOpTestParamCount
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- builder.Neg(a);
+ auto a = ConstantR1<float>(&builder, {});
+ Neg(a);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
- builder.Neg(a);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ Neg(a);
ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
error_spec_);
@@ -68,10 +68,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
- std::numeric_limits<int32>::min(),
- std::numeric_limits<int32>::max()});
- builder.Neg(a);
+ auto a = ConstantR1<int32>(&builder,
+ {-1, 0, 1, 324, std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max()});
+ Neg(a);
// -min == min for int32 due to an overflow. In C++ it is undefined behavior
// to do this calculation. For XLA we have not specified that, so it
@@ -84,17 +84,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>({});
- builder.Neg(a);
+ auto a = ConstantR1<complex64>(&builder, {});
+ Neg(a);
ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>(
- {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
- builder.Neg(a);
+ auto a = ConstantR1<complex64>(
+ &builder, {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}});
+ Neg(a);
ComputeAndCompareR1<complex64>(
&builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}},
@@ -103,16 +103,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) {
XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int64>({
- -1,
- 1,
- 0,
- 0x12345678,
- static_cast<int64>(0xffffffff12345678l),
- static_cast<int64>(0x8000000000000000LL),
- static_cast<int64>(0x8000000000000001LL),
- });
- builder.Neg(a);
+ auto a =
+ ConstantR1<int64>(&builder, {
+ -1,
+ 1,
+ 0,
+ 0x12345678,
+ static_cast<int64>(0xffffffff12345678l),
+ static_cast<int64>(0x8000000000000000LL),
+ static_cast<int64>(0x8000000000000001LL),
+ });
+ Neg(a);
LOG(INFO) << -static_cast<int64>(0x7FFFFFFFFFFFFFFFLL);
ComputeAndCompareR1<int64>(&builder,
@@ -130,8 +131,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) {
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- builder.IsFinite(a);
+ auto a = ConstantR1<float>(&builder, {});
+ IsFinite(a);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
@@ -141,21 +142,21 @@ static const float kNonCanonicalNaN = tensorflow::bit_cast<float>(0x7FD01234);
XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteScalarF32) {
XlaBuilder builder(TestName());
- builder.IsFinite(builder.ConstantR0<float>(NAN));
+ IsFinite(ConstantR0<float>(&builder, NAN));
ComputeAndCompareR0<bool>(&builder, false, {});
EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
- builder.IsFinite(builder.ConstantR0<float>(kNonCanonicalNaN));
+ IsFinite(ConstantR0<float>(&builder, kNonCanonicalNaN));
ComputeAndCompareR0<bool>(&builder, false, {});
const float inf = std::numeric_limits<float>::infinity();
- builder.IsFinite(builder.ConstantR0<float>(inf));
+ IsFinite(ConstantR0<float>(&builder, inf));
ComputeAndCompareR0<bool>(&builder, false, {});
- builder.IsFinite(builder.ConstantR0<float>(-inf));
+ IsFinite(ConstantR0<float>(&builder, -inf));
ComputeAndCompareR0<bool>(&builder, false, {});
- builder.IsFinite(builder.ConstantR0<float>(0.0f));
+ IsFinite(ConstantR0<float>(&builder, 0.0f));
ComputeAndCompareR0<bool>(&builder, true, {});
}
@@ -163,9 +164,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
XlaBuilder builder(TestName());
const float inf = std::numeric_limits<float>::infinity();
EXPECT_TRUE(std::isnan(kNonCanonicalNaN));
- auto a = builder.ConstantR1<float>(
- {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}});
- builder.IsFinite(a);
+ auto a = ConstantR1<float>(&builder,
+ {{NAN, 7.0f, kNonCanonicalNaN, -1.0f, inf, -inf}});
+ IsFinite(a);
ComputeAndCompareR1<bool>(&builder, {false, true, false, true, false, false},
{});
@@ -173,9 +174,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteR1F32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
- auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
- builder.Add(a, b);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+ Add(a, b);
ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
error_spec_);
@@ -183,20 +184,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.Add(a, b);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ Add(a, b);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>(
- {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
- auto b = builder.ConstantR1<complex64>(
- {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
- builder.Add(a, b);
+ auto a = ConstantR1<complex64>(
+ &builder, {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}});
+ auto b = ConstantR1<complex64>(
+ &builder, {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}});
+ Add(a, b);
ComputeAndCompareR1<complex64>(
&builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {},
@@ -205,9 +206,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) {
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>({});
- auto b = builder.ConstantR1<complex64>({});
- builder.Add(a, b);
+ auto a = ConstantR1<complex64>(&builder, {});
+ auto b = ConstantR1<complex64>(&builder, {});
+ Add(a, b);
ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
@@ -224,8 +225,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0x8000000000000000LL,
0x8000000000000000LL,
1};
- std::unique_ptr<Literal> lhs_literal = Literal::CreateR1<uint64>({lhs});
- auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param");
+ std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<uint64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
@@ -238,12 +239,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
0,
1,
0x8000000000000000LL};
- std::unique_ptr<Literal> rhs_literal = Literal::CreateR1<uint64>({rhs});
- auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param");
+ std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<uint64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
- b.Add(lhs_param, rhs_param);
+ Add(lhs_param, rhs_param);
std::vector<uint64> expected(lhs.size());
for (int64 i = 0; i < lhs.size(); ++i) {
@@ -264,8 +265,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
1,
0,
-1};
- std::unique_ptr<Literal> lhs_literal = Literal::CreateR1<int64>({lhs});
- auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param");
+ std::unique_ptr<Literal> lhs_literal = LiteralUtil::CreateR1<int64>({lhs});
+ auto lhs_param = Parameter(&b, 0, lhs_literal->shape(), "lhs_param");
std::unique_ptr<GlobalData> lhs_data =
client_->TransferToServer(*lhs_literal).ConsumeValueOrDie();
@@ -277,12 +278,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
0x7FFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL,
0x7FFFFFFFFFFFFFFFLL};
- std::unique_ptr<Literal> rhs_literal = Literal::CreateR1<int64>({rhs});
- auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param");
+ std::unique_ptr<Literal> rhs_literal = LiteralUtil::CreateR1<int64>({rhs});
+ auto rhs_param = Parameter(&b, 1, rhs_literal->shape(), "rhs_param");
std::unique_ptr<GlobalData> rhs_data =
client_->TransferToServer(*rhs_literal).ConsumeValueOrDie();
- auto sub = b.Sub(lhs_param, rhs_param);
+ Sub(lhs_param, rhs_param);
std::vector<int64> expected(lhs.size());
for (int64 i = 0; i < lhs.size(); ++i) {
@@ -302,26 +303,26 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
b_values.push_back(2 * i / static_cast<float>(count + 2));
}
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({a_values});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a_constant = builder.ConstantR1<float>(a_values);
- auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+ auto a_constant = ConstantR1<float>(&builder, a_values);
+ auto a_param = Parameter(&builder, 0, a_literal->shape(), "a_param");
- std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({b_values});
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param");
- auto b_param = builder.ConstantR1<float>(b_values);
+ auto b_constant = Parameter(&builder, 1, a_literal->shape(), "b_param");
+ auto b_param = ConstantR1<float>(&builder, b_values);
- auto sum1 = builder.Add(a_constant, b_constant);
- auto sum2 = builder.Add(a_constant, b_param);
- auto sum3 = builder.Add(a_param, b_constant);
- auto sum4 = builder.Add(a_param, b_param);
+ auto sum1 = Add(a_constant, b_constant);
+ auto sum2 = Add(a_constant, b_param);
+ auto sum3 = Add(a_param, b_constant);
+ auto sum4 = Add(a_param, b_param);
- auto sum = builder.Add(sum1, sum2);
- sum = builder.Add(sum, sum3);
- sum = builder.Add(sum, sum4);
+ auto sum = Add(sum1, sum2);
+ sum = Add(sum, sum3);
+ sum = Add(sum, sum4);
std::vector<float> expected;
for (int64 i = 0; i < count; ++i) {
@@ -334,9 +335,9 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
- auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
- builder.Sub(a, b);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = ConstantR1<float>(&builder, {100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+ Sub(a, b);
ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
{}, error_spec_);
@@ -344,38 +345,38 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.Sub(a, b);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ Sub(a, b);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
- auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
- builder.Sub(a, b);
+ auto a = ConstantR1<int32>(&builder, {-1, 0, 2, 1000000000});
+ auto b = ConstantR1<int32>(&builder, {-1, 2, 1, -1});
+ Sub(a, b);
ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- auto b = builder.ConstantR1<int32>({});
- builder.Sub(a, b);
+ auto a = ConstantR1<int32>(&builder, {});
+ auto b = ConstantR1<int32>(&builder, {});
+ Sub(a, b);
ComputeAndCompareR1<int32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>(
- {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
- auto b = builder.ConstantR1<complex64>(
- {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
- builder.Sub(a, b);
+ auto a = ConstantR1<complex64>(&builder,
+ {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}});
+ auto b = ConstantR1<complex64>(
+ &builder, {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}});
+ Sub(a, b);
ComputeAndCompareR1<complex64>(
&builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {},
@@ -384,18 +385,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) {
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>({});
- auto b = builder.ConstantR1<complex64>({});
- builder.Sub(a, b);
+ auto a = ConstantR1<complex64>(&builder, {});
+ auto b = ConstantR1<complex64>(&builder, {});
+ Sub(a, b);
ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
- auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
- builder.Div(a, b);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = ConstantR1<float>(&builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
+ Div(a, b);
ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
error_spec_);
@@ -403,9 +404,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.Div(a, b);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ Div(a, b);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -442,7 +443,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
auto divisor_data =
CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
- builder.Div(dividend, divisor);
+ Div(dividend, divisor);
ComputeAndCompareR1<int32>(&builder, quotients,
{dividend_data.get(), divisor_data.get()});
@@ -454,7 +455,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
XlaOp dividend;
auto dividend_data =
CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- builder.Div(dividend, builder.ConstantR1<int32>(divisors));
+ Div(dividend, ConstantR1<int32>(&builder, divisors));
ComputeAndCompareR1<int32>(&builder, quotients, {dividend_data.get()});
}
@@ -467,7 +468,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
auto divisor_data =
CreateR1Parameter<int32>(divisors, 1, "divisor", &builder, &divisor);
- builder.Rem(dividend, divisor);
+ Rem(dividend, divisor);
ComputeAndCompareR1<int32>(&builder, remainders,
{dividend_data.get(), divisor_data.get()});
@@ -479,7 +480,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivS32s) {
XlaOp dividend;
auto dividend_data =
CreateR1Parameter<int32>(dividends, 0, "dividend", &builder, &dividend);
- builder.Rem(dividend, builder.ConstantR1<int32>(divisors));
+ Rem(dividend, ConstantR1<int32>(&builder, divisors));
ComputeAndCompareR1<int32>(&builder, remainders, {dividend_data.get()});
}
@@ -513,7 +514,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
&builder, &dividend);
auto divisor_data =
CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
- builder.Div(dividend, divisor);
+ Div(dividend, divisor);
ComputeAndCompareR1<uint32>(&builder, quotients,
{dividend_data.get(), divisor_data.get()});
@@ -524,7 +525,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
XlaOp dividend;
auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
&builder, &dividend);
- builder.Div(dividend, builder.ConstantR1<uint32>(divisors));
+ Div(dividend, ConstantR1<uint32>(&builder, divisors));
ComputeAndCompareR1<uint32>(&builder, quotients, {dividend_data.get()});
}
@@ -537,7 +538,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
&builder, &dividend);
auto divisor_data =
CreateR1Parameter<uint32>(divisors, 1, "divisor", &builder, &divisor);
- builder.Rem(dividend, divisor);
+ Rem(dividend, divisor);
ComputeAndCompareR1<uint32>(&builder, remainders,
{dividend_data.get(), divisor_data.get()});
@@ -548,7 +549,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
XlaOp dividend;
auto dividend_data = CreateR1Parameter<uint32>(dividends, 0, "dividend",
&builder, &dividend);
- builder.Rem(dividend, builder.ConstantR1<uint32>(divisors));
+ Rem(dividend, ConstantR1<uint32>(&builder, divisors));
ComputeAndCompareR1<uint32>(&builder, remainders, {dividend_data.get()});
}
@@ -556,11 +557,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>(
- {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
- auto b = builder.ConstantR1<complex64>(
- {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
- builder.Div(a, b);
+ auto a = ConstantR1<complex64>(
+ &builder, {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}});
+ auto b = ConstantR1<complex64>(&builder,
+ {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}});
+ Div(a, b);
ComputeAndCompareR1<complex64>(
&builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_);
@@ -568,20 +569,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) {
XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>({});
- auto b = builder.ConstantR1<complex64>({});
- builder.Div(a, b);
+ auto a = ConstantR1<complex64>(&builder, {});
+ auto b = ConstantR1<complex64>(&builder, {});
+ Div(a, b);
ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>(
- {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
- auto b = builder.ConstantR1<float>(
- {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
- builder.Rem(a, b);
+ auto a = ConstantR1<float>(
+ &builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
+ auto b = ConstantR1<float>(
+ &builder, {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
+ Rem(a, b);
ComputeAndCompareR1<float>(
&builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
@@ -590,20 +591,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.Rem(a, b);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ Rem(a, b);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<double>(
- {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
- auto b = builder.ConstantR1<double>(
- {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
- builder.Rem(a, b);
+ auto a = ConstantR1<double>(
+ &builder, {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
+ auto b = ConstantR1<double>(
+ &builder, {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
+ Rem(a, b);
ComputeAndCompareR1<double>(
&builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
@@ -612,9 +613,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
- auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- builder.Mul(a, b);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ Mul(a, b);
ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
{}, error_spec_);
@@ -622,9 +623,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.Mul(a, b);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ Mul(a, b);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -648,18 +649,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>(a_data);
- auto b = builder.ConstantR1<int32>(b_data);
- builder.Mul(a, b);
+ auto a = ConstantR1<int32>(&builder, a_data);
+ auto b = ConstantR1<int32>(&builder, b_data);
+ Mul(a, b);
ComputeAndCompareR1<int32>(&builder, expected, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- auto b = builder.ConstantR1<int32>({});
- builder.Mul(a, b);
+ auto a = ConstantR1<int32>(&builder, {});
+ auto b = ConstantR1<int32>(&builder, {});
+ Mul(a, b);
ComputeAndCompareR1<int32>(&builder, {}, {});
}
@@ -679,20 +680,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>(a_data);
- auto b = builder.ConstantR1<uint32>(b_data);
- builder.Mul(a, b);
+ auto a = ConstantR1<uint32>(&builder, a_data);
+ auto b = ConstantR1<uint32>(&builder, b_data);
+ Mul(a, b);
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>(
- {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
- auto b = builder.ConstantR1<complex64>(
- {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
- builder.Mul(a, b);
+ auto a = ConstantR1<complex64>(
+ &builder, {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}});
+ auto b = ConstantR1<complex64>(&builder,
+ {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}});
+ Mul(a, b);
ComputeAndCompareR1<complex64>(
&builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {},
@@ -701,27 +702,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) {
XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<complex64>({});
- auto b = builder.ConstantR1<complex64>({});
- builder.Mul(a, b);
+ auto a = ConstantR1<complex64>(&builder, {});
+ auto b = ConstantR1<complex64>(&builder, {});
+ Mul(a, b);
ComputeAndCompareR1<complex64>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({false, false, true, true});
- auto b = builder.ConstantR1<bool>({false, true, false, true});
- builder.And(a, b);
+ auto a = ConstantR1<bool>(&builder, {false, false, true, true});
+ auto b = ConstantR1<bool>(&builder, {false, true, false, true});
+ And(a, b);
ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
- auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
- builder.And(a, b);
+ auto a = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
+ auto b = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
+ And(a, b);
Array2D<bool> expected_array({{false, false}, {false, true}});
ComputeAndCompareR2<bool>(&builder, expected_array, {});
@@ -729,27 +730,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) {
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({});
- auto b = builder.ConstantR1<bool>({});
- builder.And(a, b);
+ auto a = ConstantR1<bool>(&builder, {});
+ auto b = ConstantR1<bool>(&builder, {});
+ And(a, b);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({0, -1, -8});
- auto b = builder.ConstantR1<int32>({5, -7, 12});
- builder.And(a, b);
+ auto a = ConstantR1<int32>(&builder, {0, -1, -8});
+ auto b = ConstantR1<int32>(&builder, {5, -7, 12});
+ And(a, b);
ComputeAndCompareR1<int32>(&builder, {0, -7, 8}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<int32>({{0, -5}, {-1, 5}});
- auto b = builder.ConstantR2<int32>({{1, -6}, {4, 5}});
- builder.And(a, b);
+ auto a = ConstantR2<int32>(&builder, {{0, -5}, {-1, 5}});
+ auto b = ConstantR2<int32>(&builder, {{1, -6}, {4, 5}});
+ And(a, b);
Array2D<int32> expected_array({{0, -6}, {4, 5}});
ComputeAndCompareR2<int32>(&builder, expected_array, {});
@@ -757,27 +758,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- auto b = builder.ConstantR1<int32>({});
- builder.And(a, b);
+ auto a = ConstantR1<int32>(&builder, {});
+ auto b = ConstantR1<int32>(&builder, {});
+ And(a, b);
ComputeAndCompareR1<int32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({0, 1, 8});
- auto b = builder.ConstantR1<int32>({5, 7, 12});
- builder.And(a, b);
+ auto a = ConstantR1<int32>(&builder, {0, 1, 8});
+ auto b = ConstantR1<int32>(&builder, {5, 7, 12});
+ And(a, b);
ComputeAndCompareR1<int32>(&builder, {0, 1, 8}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<uint32>({{0, 1}, {3, 8}});
- auto b = builder.ConstantR2<uint32>({{1, 0}, {7, 6}});
- builder.And(a, b);
+ auto a = ConstantR2<uint32>(&builder, {{0, 1}, {3, 8}});
+ auto b = ConstantR2<uint32>(&builder, {{1, 0}, {7, 6}});
+ And(a, b);
Array2D<uint32> expected_array({{0, 0}, {3, 0}});
ComputeAndCompareR2<uint32>(&builder, expected_array, {});
@@ -785,27 +786,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>({});
- auto b = builder.ConstantR1<uint32>({});
- builder.And(a, b);
+ auto a = ConstantR1<uint32>(&builder, {});
+ auto b = ConstantR1<uint32>(&builder, {});
+ And(a, b);
ComputeAndCompareR1<uint32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({false, false, true, true});
- auto b = builder.ConstantR1<bool>({false, true, false, true});
- builder.Or(a, b);
+ auto a = ConstantR1<bool>(&builder, {false, false, true, true});
+ auto b = ConstantR1<bool>(&builder, {false, true, false, true});
+ Or(a, b);
ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<bool>({{false, false}, {true, true}});
- auto b = builder.ConstantR2<bool>({{false, true}, {false, true}});
- builder.Or(a, b);
+ auto a = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
+ auto b = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
+ Or(a, b);
Array2D<bool> expected_array({{false, true}, {true, true}});
ComputeAndCompareR2<bool>(&builder, expected_array, {});
@@ -813,27 +814,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) {
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({});
- auto b = builder.ConstantR1<bool>({});
- builder.Or(a, b);
+ auto a = ConstantR1<bool>(&builder, {});
+ auto b = ConstantR1<bool>(&builder, {});
+ Or(a, b);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({0, -1, 8});
- auto b = builder.ConstantR1<int32>({5, -7, 4});
- builder.Or(a, b);
+ auto a = ConstantR1<int32>(&builder, {0, -1, 8});
+ auto b = ConstantR1<int32>(&builder, {5, -7, 4});
+ Or(a, b);
ComputeAndCompareR1<int32>(&builder, {5, -1, 12}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<int32>({{0, -1}, {8, 8}});
- auto b = builder.ConstantR2<int32>({{5, -7}, {4, 1}});
- builder.Or(a, b);
+ auto a = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
+ auto b = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
+ Or(a, b);
Array2D<int32> expected_array({{5, -1}, {12, 9}});
ComputeAndCompareR2<int32>(&builder, expected_array, {});
@@ -841,27 +842,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- auto b = builder.ConstantR1<int32>({});
- builder.Or(a, b);
+ auto a = ConstantR1<int32>(&builder, {});
+ auto b = ConstantR1<int32>(&builder, {});
+ Or(a, b);
ComputeAndCompareR1<int32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>({0, 1, 8});
- auto b = builder.ConstantR1<uint32>({5, 7, 4});
- builder.Or(a, b);
+ auto a = ConstantR1<uint32>(&builder, {0, 1, 8});
+ auto b = ConstantR1<uint32>(&builder, {5, 7, 4});
+ Or(a, b);
ComputeAndCompareR1<uint32>(&builder, {5, 7, 12}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<uint32>({{0, 1}, {8, 8}});
- auto b = builder.ConstantR2<uint32>({{5, 7}, {4, 1}});
- builder.Or(a, b);
+ auto a = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
+ auto b = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
+ Or(a, b);
Array2D<uint32> expected_array({{5, 7}, {12, 9}});
ComputeAndCompareR2<uint32>(&builder, expected_array, {});
@@ -869,25 +870,108 @@ XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>({});
- auto b = builder.ConstantR1<uint32>({});
- builder.Or(a, b);
+ auto a = ConstantR1<uint32>(&builder, {});
+ auto b = ConstantR1<uint32>(&builder, {});
+ Or(a, b);
ComputeAndCompareR1<uint32>(&builder, {}, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, XorPredR1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<bool>(&builder, {false, false, true, true});
+ auto b = ConstantR1<bool>(&builder, {false, true, false, true});
+ Xor(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, false}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorPredR2) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR2<bool>(&builder, {{false, false}, {true, true}});
+ auto b = ConstantR2<bool>(&builder, {{false, true}, {false, true}});
+ Xor(a, b);
+
+ Array2D<bool> expected_array({{false, true}, {true, false}});
+ ComputeAndCompareR2<bool>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementPredR1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<bool>(&builder, {});
+ auto b = ConstantR1<bool>(&builder, {});
+ Xor(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorS32R1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {0, -1, 8});
+ auto b = ConstantR1<int32>(&builder, {5, -7, 4});
+ Xor(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {5, 6, 12}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorS32R2) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR2<int32>(&builder, {{0, -1}, {8, 8}});
+ auto b = ConstantR2<int32>(&builder, {{5, -7}, {4, 1}});
+ Xor(a, b);
+
+ Array2D<int32> expected_array({{5, 6}, {12, 9}});
+ ComputeAndCompareR2<int32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementS32R1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<int32>(&builder, {});
+ auto b = ConstantR1<int32>(&builder, {});
+ Xor(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorU32R1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {0, 1, 8});
+ auto b = ConstantR1<uint32>(&builder, {5, 7, 4});
+ Xor(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {5, 6, 12}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorU32R2) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR2<uint32>(&builder, {{0, 1}, {8, 8}});
+ auto b = ConstantR2<uint32>(&builder, {{5, 7}, {4, 1}});
+ Xor(a, b);
+
+ Array2D<uint32> expected_array({{5, 6}, {12, 9}});
+ ComputeAndCompareR2<uint32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, XorZeroElementU32R1) {
+ XlaBuilder builder(TestName());
+ auto a = ConstantR1<uint32>(&builder, {});
+ auto b = ConstantR1<uint32>(&builder, {});
+ Xor(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {}, {});
+}
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({false, true, true, false});
- builder.Not(a);
+ auto a = ConstantR1<bool>(&builder, {false, true, true, false});
+ Not(a);
ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<bool>({{false, true}, {true, false}});
- builder.Not(a);
+ auto a = ConstantR2<bool>(&builder, {{false, true}, {true, false}});
+ Not(a);
Array2D<bool> expected_array({{true, false}, {false, true}});
ComputeAndCompareR2<bool>(&builder, expected_array, {});
@@ -895,24 +979,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) {
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({});
- builder.Not(a);
+ auto a = ConstantR1<bool>(&builder, {});
+ Not(a);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({-1, 0, 1});
- builder.Not(a);
+ auto a = ConstantR1<int32>(&builder, {-1, 0, 1});
+ Not(a);
ComputeAndCompareR1<int32>(&builder, {0, -1, -2}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<int32>({{-1, 0}, {1, 8}});
- builder.Not(a);
+ auto a = ConstantR2<int32>(&builder, {{-1, 0}, {1, 8}});
+ Not(a);
Array2D<int32> expected_array({{0, -1}, {-2, -9}});
ComputeAndCompareR2<int32>(&builder, expected_array, {});
@@ -920,24 +1004,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- builder.Not(a);
+ auto a = ConstantR1<int32>(&builder, {});
+ Not(a);
ComputeAndCompareR1<int32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>({0, 4294967295});
- builder.Not(a);
+ auto a = ConstantR1<uint32>(&builder, {0, 4294967295});
+ Not(a);
ComputeAndCompareR1<uint32>(&builder, {4294967295, 0}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<uint32>({{0, 4294967295}, {1, 4294967294}});
- builder.Not(a);
+ auto a = ConstantR2<uint32>(&builder, {{0, 4294967295}, {1, 4294967294}});
+ Not(a);
Array2D<uint32> expected_array({{4294967295, 0}, {4294967294, 1}});
ComputeAndCompareR2<uint32>(&builder, expected_array, {});
@@ -945,19 +1029,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) {
XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>({});
- builder.Not(a);
+ auto a = ConstantR1<uint32>(&builder, {});
+ Not(a);
ComputeAndCompareR1<uint32>(&builder, {}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({static_cast<int32>(0x12345678),
- static_cast<int32>(0xF0001000), 1, 3, 77,
- 1, -3, 77});
- auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15, 32, 100, -1});
- builder.ShiftLeft(a, b);
+ auto a = ConstantR1<int32>(
+ &builder, {static_cast<int32>(0x12345678), static_cast<int32>(0xF0001000),
+ 1, 3, 77, 1, -3, 77});
+ auto b = ConstantR1<int32>(&builder, {4, 8, 2, 7, 15, 32, 100, -1});
+ ShiftLeft(a, b);
ComputeAndCompareR1<int32>(&builder,
{static_cast<int32>(0x23456780), 0x00100000, 0x4,
@@ -967,11 +1051,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
- static_cast<int32>(0x10001000), 1, 3, 77,
- 1, -3, 77});
- auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2, 32, 100, -1});
- builder.ShiftRightArithmetic(a, b);
+ auto a = ConstantR1<int32>(
+ &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
+ 1, 3, 77, 1, -3, 77});
+ auto b = ConstantR1<int32>(&builder, {4, 8, 2, 7, 2, 32, 100, -1});
+ ShiftRightArithmetic(a, b);
ComputeAndCompareR1<int32>(
&builder,
@@ -982,11 +1066,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
- static_cast<int32>(0x10001000), 1, 3, 77,
- 1, -3, 77});
- auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5, 32, 100, -1});
- builder.ShiftRightLogical(a, b);
+ auto a = ConstantR1<int32>(
+ &builder, {static_cast<int32>(0x92345678), static_cast<int32>(0x10001000),
+ 1, 3, 77, 1, -3, 77});
+ auto b = ConstantR1<int32>(&builder, {4, 8, 2, 7, 5, 32, 100, -1});
+ ShiftRightLogical(a, b);
ComputeAndCompareR1<int32>(&builder,
{0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
@@ -994,10 +1078,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>(
- {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
- auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15, 32, 100, ~0u});
- builder.ShiftLeft(a, b);
+ auto a = ConstantR1<uint32>(&builder,
+ {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77});
+ auto b = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 15, 32, 100, ~0u});
+ ShiftLeft(a, b);
ComputeAndCompareR1<uint32>(
&builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {});
@@ -1005,10 +1089,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>(
- {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
- auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2, 32, 100, ~0u});
- builder.ShiftRightArithmetic(a, b);
+ auto a = ConstantR1<uint32>(&builder,
+ {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
+ auto b = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 2, 32, 100, ~0u});
+ ShiftRightArithmetic(a, b);
ComputeAndCompareR1<uint32>(
&builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {});
@@ -1016,10 +1100,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>(
- {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
- auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5, 32, 100, ~0u});
- builder.ShiftRightLogical(a, b);
+ auto a = ConstantR1<uint32>(&builder,
+ {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77});
+ auto b = ConstantR1<uint32>(&builder, {4, 8, 2, 7, 5, 32, 100, ~0u});
+ ShiftRightLogical(a, b);
ComputeAndCompareR1<uint32>(&builder,
{0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {});
@@ -1028,18 +1112,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 2.25f, 10.0f, NAN});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({});
- auto rhs = builder.ConstantR1<float>({});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {});
+ auto rhs = ConstantR1<float>(&builder, {});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
@@ -1047,9 +1131,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
- builder.Ge(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ Ge(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
@@ -1057,9 +1141,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
- builder.Gt(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ Gt(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
}
@@ -1067,9 +1151,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
- builder.Le(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 5.0f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ Le(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
}
@@ -1077,9 +1161,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
- builder.Lt(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ Lt(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
}
@@ -1088,9 +1172,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Eq(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, false, false, false, true, false, false, false, true},
@@ -1099,9 +1184,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({});
- auto rhs = builder.ConstantR1<int32>({});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<int32>(&builder, {});
+ auto rhs = ConstantR1<int32>(&builder, {});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
@@ -1109,26 +1194,26 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
- {1.0f, 25.5f},
- {2.25f, -3.0f},
- {NAN, 0.0f},
- {1.0f, 6.0f}});
- auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
- {1.0f, 5.0f},
- {2.25f, -3.0f},
- {10.0f, 0.0f},
- {1.0f, NAN}});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
+ {1.0f, 25.5f},
+ {2.25f, -3.0f},
+ {NAN, 0.0f},
+ {1.0f, 6.0f}});
+ auto rhs = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
+ {1.0f, 5.0f},
+ {2.25f, -3.0f},
+ {10.0f, 0.0f},
+ {1.0f, NAN}});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<complex64>({});
- auto rhs = builder.ConstantR1<complex64>({});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<complex64>(&builder, {});
+ auto rhs = ConstantR1<complex64>(&builder, {});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {}, {});
}
@@ -1138,17 +1223,17 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<complex64>({{-2.5f, 10.0f},
- {1.0f, 25.5f},
- {2.25f, -3.0f},
- {NAN, 0.0f},
- {1.0f, 6.0f}});
- auto rhs = builder.ConstantR1<complex64>({{0.0f, 10.0f},
- {1.0f, 5.0f},
- {2.25f, -3.0f},
- {10.0f, 0.0f},
- {1.0f, NAN}});
- builder.Ne(lhs, rhs);
+ auto lhs = ConstantR1<complex64>(&builder, {{-2.5f, 10.0f},
+ {1.0f, 25.5f},
+ {2.25f, -3.0f},
+ {NAN, 0.0f},
+ {1.0f, 6.0f}});
+ auto rhs = ConstantR1<complex64>(&builder, {{0.0f, 10.0f},
+ {1.0f, 5.0f},
+ {2.25f, -3.0f},
+ {10.0f, 0.0f},
+ {1.0f, NAN}});
+ Ne(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {true, true, false, true, true}, {});
}
@@ -1158,9 +1243,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({10.0f, 25.5f, 1.0f, 10.0f, NAN});
- builder.Ne(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {10.0f, 25.5f, 1.0f, 10.0f, NAN});
+ Ne(lhs, rhs);
ComputeAndCompareR1<bool>(&builder, {true, false, true, true, true}, {});
}
@@ -1169,9 +1254,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Ne(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Ne(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, true, true, true, false, true, true, true, false}, {});
@@ -1181,9 +1267,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Ge(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Ge(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, false, false, true, true, false, true, true, true}, {});
@@ -1193,9 +1280,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Gt(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Gt(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, false, false, true, false, false, true, true, false},
@@ -1206,9 +1294,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Le(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Le(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, true, true, false, true, true, false, false, true}, {});
@@ -1218,9 +1307,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
- auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
- builder.Lt(lhs, rhs);
+ auto lhs =
+ ConstantR1<int32>(&builder, {min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = ConstantR1<int32>(&builder, {min, 0, max, -1, 0, 1, min, 0, max});
+ Lt(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, true, true, false, false, true, false, false, false},
@@ -1230,9 +1320,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Eq(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Eq(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, false, false, false, true, false, false, false, true},
@@ -1242,9 +1332,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Ne(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Ne(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, true, true, true, false, true, true, true, false}, {});
@@ -1253,9 +1343,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Ge(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Ge(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, false, false, true, true, false, true, true, true}, {});
@@ -1264,9 +1354,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Gt(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Gt(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, false, false, true, false, false, true, true, false},
@@ -1276,9 +1366,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Le(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Le(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {true, true, true, false, true, true, false, false, true}, {});
@@ -1287,9 +1377,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
- auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
- builder.Lt(lhs, rhs);
+ auto lhs = ConstantR1<uint32>(&builder, {0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = ConstantR1<uint32>(&builder, {0, 1, max, 4, 5, 6, 0, 1, max});
+ Lt(lhs, rhs);
ComputeAndCompareR1<bool>(
&builder, {false, true, true, false, false, true, false, false, false},
@@ -1300,10 +1390,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
auto lhs =
- builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
+ ConstantR1<float>(&builder, {4.0f, 2.0f, 2.0f, NAN, 6.0f, -2.0f, -2.0f});
auto rhs =
- builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
- builder.Pow(lhs, rhs);
+ ConstantR1<float>(&builder, {2.0f, -2.0f, 3.0f, 10.0f, NAN, 3.0f, 4.0f});
+ Pow(lhs, rhs);
ComputeAndCompareR1<float>(
&builder, {16.0f, 0.25f, 8.0f, NAN, NAN, -8.0f, 16.0f}, {}, error_spec_);
@@ -1312,9 +1402,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
SetFastMathDisabled(true);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({-2.0f, -0.6f, -0.6f, 0.0f});
- auto rhs = builder.ConstantR1<float>({0.5f, 0.6f, -0.6f, -0.6f});
- builder.Pow(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {-2.0f, -0.6f, -0.6f, 0.0f});
+ auto rhs = ConstantR1<float>(&builder, {0.5f, 0.6f, -0.6f, -0.6f});
+ Pow(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {NAN, NAN, NAN, INFINITY}, {},
error_spec_);
@@ -1322,9 +1412,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowNonIntegerF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({});
- auto rhs = builder.ConstantR1<float>({});
- builder.Pow(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {});
+ auto rhs = ConstantR1<float>(&builder, {});
+ Pow(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -1336,14 +1426,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> param_literal = Literal::CreateR1<float>(values);
+ std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
std::unique_ptr<GlobalData> param_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
- auto sum = b.ConstantR0<float>(0.0f);
- auto param = b.Parameter(0, param_literal->shape(), "param");
+ auto sum = ConstantR0<float>(&b, 0.0f);
+ auto param = Parameter(&b, 0, param_literal->shape(), "param");
for (float exponent : exponents) {
- sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent)));
+ sum = Add(sum, Pow(param, ConstantR0<float>(&b, exponent)));
}
std::vector<float> expected;
@@ -1364,15 +1454,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- b.Pow(b.Exp(param0), param1);
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ Pow(Exp(param0), param1);
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1389,15 +1479,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- b.Log(b.Pow(param0, param1));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ Log(Pow(param0, param1));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1414,15 +1504,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- b.Mul(b.Exp(param0), b.Exp(param1));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ Mul(Exp(param0), Exp(param1));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1439,15 +1529,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- b.Div(param0, b.Exp(param1));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ Div(param0, Exp(param1));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1465,21 +1555,21 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- auto param2 = b.Parameter(2, literal2->shape(), "param2");
- b.Div(b.Div(param0, param1), param2);
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ Div(Div(param0, param1), param2);
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1497,22 +1587,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- auto param2 = b.Parameter(2, literal2->shape(), "param2");
- b.Div(param0, b.Div(param1, param2));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ Div(param0, Div(param1, param2));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1530,22 +1620,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 9.5f, -11.0f, -0.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- auto param2 = b.Parameter(2, literal2->shape(), "param2");
- b.Div(param0, b.Pow(param1, param2));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ Div(param0, Pow(param1, param2));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1564,27 +1654,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
std::vector<float> values2 = {0.1f, 1.1f, 6.9f, 12.5f, -15.0f, -0.5f};
std::vector<float> values3 = {2.1f, 3.1f, 9.9f, -4.5f, -11.0f, -21.5f};
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>(values0);
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>(values0);
std::unique_ptr<GlobalData> data0 =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>(values1);
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>(values1);
std::unique_ptr<GlobalData> data1 =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal2 = Literal::CreateR1<float>(values2);
+ std::unique_ptr<Literal> literal2 = LiteralUtil::CreateR1<float>(values2);
std::unique_ptr<GlobalData> data2 =
client_->TransferToServer(*literal2).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal3 = Literal::CreateR1<float>(values3);
+ std::unique_ptr<Literal> literal3 = LiteralUtil::CreateR1<float>(values3);
std::unique_ptr<GlobalData> data3 =
client_->TransferToServer(*literal3).ConsumeValueOrDie();
- auto param0 = b.Parameter(0, literal0->shape(), "param0");
- auto param1 = b.Parameter(1, literal1->shape(), "param1");
- auto param2 = b.Parameter(2, literal2->shape(), "param2");
- auto param3 = b.Parameter(3, literal3->shape(), "param2");
- b.Div(b.Div(param0, param1), b.Div(param2, param3));
+ auto param0 = Parameter(&b, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&b, 1, literal1->shape(), "param1");
+ auto param2 = Parameter(&b, 2, literal2->shape(), "param2");
+ auto param3 = Parameter(&b, 3, literal3->shape(), "param2");
+ Div(Div(param0, param1), Div(param2, param3));
std::vector<float> expected(values0.size());
for (int64 i = 0; i < values0.size(); ++i) {
@@ -1604,8 +1694,8 @@ TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
for (int i = 0; i < count; ++i) {
values.push_back(i / static_cast<float>(count));
}
- auto x = builder.ConstantR1<float>(values);
- builder.Pow(x, builder.ConstantR0<float>(2.0f));
+ auto x = ConstantR1<float>(&builder, values);
+ Pow(x, ConstantR0<float>(&builder, 2.0f));
std::vector<float> expected;
expected.reserve(values.size());
@@ -1630,8 +1720,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
Array4D<float> expected(2, 2, 2, 2, expected_vector);
- auto x = builder.ConstantR4FromArray4D<float>(values);
- builder.Pow(x, builder.ConstantR0<float>(2.0f));
+ auto x = ConstantR4FromArray4D<float>(&builder, values);
+ Pow(x, ConstantR0<float>(&builder, 2.0f));
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
@@ -1641,8 +1731,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
Array4D<float> values(2, 2, 0, 2);
Array4D<float> expected(2, 2, 0, 2);
- auto x = builder.ConstantR4FromArray4D<float>(values);
- builder.Pow(x, builder.ConstantR0<float>(2.0f));
+ auto x = ConstantR4FromArray4D<float>(&builder, values);
+ Pow(x, ConstantR0<float>(&builder, 2.0f));
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
@@ -1650,9 +1740,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
XlaBuilder builder(TestName());
SetFastMathDisabled(true);
- auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
- builder.Min(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
+ Min(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {1.0f, -5.0f, 1.0f, NAN, NAN}, {},
error_spec_);
@@ -1660,18 +1750,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({});
- auto rhs = builder.ConstantR1<float>({});
- builder.Min(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {});
+ auto rhs = ConstantR1<float>(&builder, {});
+ Min(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
XlaBuilder builder(TestName());
SetFastMathDisabled(true);
- auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
- auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
- builder.Min(lhs, rhs);
+ auto lhs = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
+ Min(lhs, rhs);
ComputeAndCompareR1<double>(&builder, {1.0, -5.0, 1.0, NAN, NAN}, {},
error_spec_);
@@ -1680,9 +1770,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
XlaBuilder builder(TestName());
SetFastMathDisabled(true);
- auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
- auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
- builder.Max(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = ConstantR1<float>(&builder, {2.0f, -5.0f, 1.0f, 10.0f, NAN});
+ Max(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {2.0f, 1.0f, 2.25f, NAN, NAN}, {},
error_spec_);
@@ -1690,18 +1780,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<float>({});
- auto rhs = builder.ConstantR1<float>({});
- builder.Max(lhs, rhs);
+ auto lhs = ConstantR1<float>(&builder, {});
+ auto rhs = ConstantR1<float>(&builder, {});
+ Max(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
XlaBuilder builder(TestName());
SetFastMathDisabled(true);
- auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
- auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
- builder.Max(lhs, rhs);
+ auto lhs = ConstantR1<double>(&builder, {1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = ConstantR1<double>(&builder, {2.0, -5.0, 1.0, 10.0, NAN});
+ Max(lhs, rhs);
ComputeAndCompareR1<double>(&builder, {2.0, 1.0, 2.25, NAN, NAN}, {},
error_spec_);
@@ -1711,11 +1801,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>(
- {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
- auto y = builder.ConstantR1<int32>(
- {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
- builder.Max(x, y);
+ auto x = ConstantR1<int32>(
+ &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = ConstantR1<int32>(
+ &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ Max(x, y);
std::vector<int32> expected = {min, max, 0, -1, 0, 0, 0,
1, 1, 10, max, max, max};
@@ -1726,11 +1816,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
const int32 min = std::numeric_limits<int32>::min();
const int32 max = std::numeric_limits<int32>::max();
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>(
- {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
- auto y = builder.ConstantR1<int32>(
- {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
- builder.Min(x, y);
+ auto x = ConstantR1<int32>(
+ &builder, {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = ConstantR1<int32>(
+ &builder, {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ Min(x, y);
std::vector<int32> expected = {min, min, min, -10, -1, -1, 0,
0, 0, 1, 0, max, min};
@@ -1740,9 +1830,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
- auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
- builder.Max(x, y);
+ auto x = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, max, max, max});
+ auto y = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, max});
+ Max(x, y);
std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
ComputeAndCompareR1<uint32>(&builder, expected, {});
@@ -1751,9 +1841,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
const uint32 max = std::numeric_limits<uint32>::max();
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
- auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
- builder.Min(x, y);
+ auto x = ConstantR1<uint32>(&builder, {0, 0, 1, 1, 1, max, max, max});
+ auto y = ConstantR1<uint32>(&builder, {0, 1, 0, 1, 10, 0, 234234, max});
+ Min(x, y);
std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
ComputeAndCompareR1<uint32>(&builder, expected, {});
@@ -1761,11 +1851,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
- auto y = builder.ConstantR1<float>(
- {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
- builder.Max(x, y);
+ auto x = ConstantR1<float>(
+ &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = ConstantR1<float>(
+ &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ Max(x, y);
std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0, 9.0};
@@ -1774,9 +1864,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
XlaBuilder builder(TestName());
- auto u = builder.ConstantR1<float>({3.5});
- auto v = builder.ConstantR1<float>({});
- builder.Max(u, v);
+ auto u = ConstantR1<float>(&builder, {3.5});
+ auto v = ConstantR1<float>(&builder, {});
+ Max(u, v);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -1784,9 +1874,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
for (int broadcast_dim : {0, 1}) {
XlaBuilder builder(TestName());
- auto u = builder.ConstantR1<float>({3.5});
- auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
- builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
+ auto u = ConstantR1<float>(&builder, {3.5});
+ auto v = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));
+ Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
}
@@ -1794,10 +1884,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- auto m =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- builder.Max(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ auto m = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ Max(v, m, /*broadcast_dimensions=*/{1});
Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -1805,9 +1895,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({});
- auto m = builder.ConstantR2<float>({{}, {}});
- builder.Max(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<float>(&builder, {});
+ auto m = ConstantR2<float>(&builder, {{}, {}});
+ Max(v, m, /*broadcast_dimensions=*/{1});
Array2D<float> expected({{}, {}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -1815,10 +1905,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
XlaBuilder builder(TestName());
- auto scalar = builder.ConstantR0<int32>(2);
+ auto scalar = ConstantR0<int32>(&builder, 2);
Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
- auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
- builder.Max(array, scalar, /*broadcast_dimensions=*/{});
+ auto array = ConstantR3FromArray3D<int32>(&builder, a_3d);
+ Max(array, scalar, /*broadcast_dimensions=*/{});
Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
ComputeAndCompareR3<int32>(&builder, expected, {});
@@ -1826,10 +1916,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
XlaBuilder builder(TestName());
- auto scalar = builder.ConstantR0<int32>(2);
+ auto scalar = ConstantR0<int32>(&builder, 2);
Array3D<int32> a_3d(2, 0, 3);
- auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
- builder.Max(array, scalar, /*broadcast_dimensions=*/{});
+ auto array = ConstantR3FromArray3D<int32>(&builder, a_3d);
+ Max(array, scalar, /*broadcast_dimensions=*/{});
Array3D<int32> expected(2, 0, 3);
ComputeAndCompareR3<int32>(&builder, expected, {});
@@ -1837,10 +1927,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
XlaBuilder builder(TestName());
- auto m =
- builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
- auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
- builder.Min(m, v, /*broadcast_dimensions=*/{0});
+ auto m = ConstantR2<float>(&builder,
+ {{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
+ auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
+ Min(m, v, /*broadcast_dimensions=*/{0});
Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -1848,9 +1938,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantR2<float>({{}, {}});
- auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
- builder.Min(m, v, /*broadcast_dimensions=*/{0});
+ auto m = ConstantR2<float>(&builder, {{}, {}});
+ auto v = ConstantR1<float>(&builder, {-10.2f, 16.4f});
+ Min(m, v, /*broadcast_dimensions=*/{0});
Array2D<float> expected({{}, {}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -1859,11 +1949,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
XlaBuilder builder(TestName());
auto array2d =
- builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
- auto array4d = builder.ConstantR4FromArray4D<float>(
- {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
- {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
- builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+ ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ auto array4d = ConstantR4FromArray4D<float>(
+ &builder, {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
+ {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
+ Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
Array4D<float> expected(
{{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
@@ -1874,10 +1964,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
XlaBuilder builder(TestName());
auto array2d =
- builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ ConstantR2<float>(&builder, {{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
Array4D<float> arg(2, 2, 0, 3);
- auto array4d = builder.ConstantR4FromArray4D<float>(arg);
- builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+ auto array4d = ConstantR4FromArray4D<float>(&builder, arg);
+ Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
Array4D<float> expected(2, 2, 0, 3);
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -1885,9 +1975,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
- auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
- builder.Min(x, y);
+ auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ Min(x, y);
std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -1895,9 +1985,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
- auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
- builder.Max(x, y);
+ auto x = ConstantR1<int32>(&builder, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = ConstantR1<int32>(&builder, {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ Max(x, y);
std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -1905,19 +1995,20 @@ XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
- auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
- builder.Rem(a, b);
+ auto a = ConstantR1<int32>(&builder, {-3, 26, 2, -1, 1});
+ auto b = ConstantR1<int32>(&builder, {10, 5, 1, 10, -10});
+ Rem(a, b);
ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
XlaBuilder builder(TestName());
- auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
- auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
- auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
- builder.Clamp(minimum, argument, maximum);
+ auto minimum = ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto argument =
+ ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
+ auto maximum = ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});
+ Clamp(minimum, argument, maximum);
ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
error_spec_);
@@ -1925,10 +2016,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
XlaBuilder builder(TestName());
- auto minimum = builder.ConstantR0<float>(0.0f);
- auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
- auto maximum = builder.ConstantR0<float>(5.0f);
- builder.Clamp(minimum, argument, maximum);
+ auto minimum = ConstantR0<float>(&builder, 0.0f);
+ auto argument = ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto maximum = ConstantR0<float>(&builder, 5.0f);
+ Clamp(minimum, argument, maximum);
ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
error_spec_);
@@ -1936,16 +2027,19 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
XlaBuilder builder(TestName());
- auto min_scalar = builder.ConstantR0<float>(0.0f);
- auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
- auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
- auto max_scalar = builder.ConstantR0<float>(3.0f);
- auto max_vector = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0});
+ auto min_scalar = ConstantR0<float>(&builder, 0.0f);
+ auto min_vector =
+ ConstantR1<float>(&builder, {1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto arg_vector =
+ ConstantR1<float>(&builder, {2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto max_scalar = ConstantR0<float>(&builder, 3.0f);
+ auto max_vector =
+ ConstantR1<float>(&builder, {3.0f, 0.5f, 25.5f, 5.0f, 123.0});
// Perform clamp with broadcasted scalar and vector.
- builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
- builder.Clamp(min_scalar, arg_vector, max_vector)),
- builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
- builder.Clamp(min_scalar, arg_vector, max_scalar)));
+ Add(Add(Clamp(min_vector, arg_vector, max_scalar),
+ Clamp(min_scalar, arg_vector, max_vector)),
+ Add(Clamp(min_vector, arg_vector, max_vector),
+ Clamp(min_scalar, arg_vector, max_scalar)));
ComputeAndCompareR1<float>(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {},
error_spec_);
@@ -1953,52 +2047,52 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) {
XlaBuilder builder(TestName());
- auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0, -5});
- auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4, 10});
- auto max_vector = builder.ConstantR1<int32>({3, 0, 25, 5, 123, -1});
- builder.Clamp(min_vector, arg_vector, max_vector);
+ auto min_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0, -5});
+ auto arg_vector = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4, 10});
+ auto max_vector = ConstantR1<int32>(&builder, {3, 0, 25, 5, 123, -1});
+ Clamp(min_vector, arg_vector, max_vector);
ComputeAndCompareR1<int32>(&builder, {2, 0, 1, 2, 4, -1}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) {
XlaBuilder builder(TestName());
- auto min_scalar = builder.ConstantR0<int32>(0);
- auto min_vector = builder.ConstantR1<int32>({1, -6, 1, 2, 0});
- auto arg_vector = builder.ConstantR1<int32>({2, 10, -5, 1, 4});
- auto max_scalar = builder.ConstantR0<int32>(3);
- auto max_vector = builder.ConstantR1<int32>({3, 1, 25, 5, 123});
+ auto min_scalar = ConstantR0<int32>(&builder, 0);
+ auto min_vector = ConstantR1<int32>(&builder, {1, -6, 1, 2, 0});
+ auto arg_vector = ConstantR1<int32>(&builder, {2, 10, -5, 1, 4});
+ auto max_scalar = ConstantR0<int32>(&builder, 3);
+ auto max_vector = ConstantR1<int32>(&builder, {3, 1, 25, 5, 123});
// Perform clamp with broadcasted scalar and vector.
- builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
- builder.Clamp(min_scalar, arg_vector, max_vector)),
- builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
- builder.Clamp(min_scalar, arg_vector, max_scalar)));
+ Add(Add(Clamp(min_vector, arg_vector, max_scalar),
+ Clamp(min_scalar, arg_vector, max_vector)),
+ Add(Clamp(min_vector, arg_vector, max_vector),
+ Clamp(min_scalar, arg_vector, max_scalar)));
ComputeAndCompareR1<int32>(&builder, {8, 8, 2, 6, 14}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) {
XlaBuilder builder(TestName());
- auto min_vector = builder.ConstantR1<uint32>({1, 2, 1, 2, 0, ~0u - 4});
- auto arg_vector = builder.ConstantR1<uint32>({2, 10, 5, 1, 4, 10});
- auto max_vector = builder.ConstantR1<uint32>({3, 5, 25, 5, 123, ~0u});
- builder.Clamp(min_vector, arg_vector, max_vector);
+ auto min_vector = ConstantR1<uint32>(&builder, {1, 2, 1, 2, 0, ~0u - 4});
+ auto arg_vector = ConstantR1<uint32>(&builder, {2, 10, 5, 1, 4, 10});
+ auto max_vector = ConstantR1<uint32>(&builder, {3, 5, 25, 5, 123, ~0u});
+ Clamp(min_vector, arg_vector, max_vector);
ComputeAndCompareR1<uint32>(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) {
XlaBuilder builder(TestName());
- auto min_scalar = builder.ConstantR0<uint32>(0);
- auto min_vector = builder.ConstantR1<uint32>({1, 0, 1, 2, 0});
- auto arg_vector = builder.ConstantR1<uint32>({2, 10, 0, 1, 4});
- auto max_scalar = builder.ConstantR0<uint32>(3);
- auto max_vector = builder.ConstantR1<uint32>({3, 1, 25, 5, 123});
+ auto min_scalar = ConstantR0<uint32>(&builder, 0);
+ auto min_vector = ConstantR1<uint32>(&builder, {1, 0, 1, 2, 0});
+ auto arg_vector = ConstantR1<uint32>(&builder, {2, 10, 0, 1, 4});
+ auto max_scalar = ConstantR0<uint32>(&builder, 3);
+ auto max_vector = ConstantR1<uint32>(&builder, {3, 1, 25, 5, 123});
// Perform clamp with broadcasted scalar and vector.
- builder.Add(builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
- builder.Clamp(min_scalar, arg_vector, max_vector)),
- builder.Add(builder.Clamp(min_vector, arg_vector, max_vector),
- builder.Clamp(min_scalar, arg_vector, max_scalar)));
+ Add(Add(Clamp(min_vector, arg_vector, max_scalar),
+ Clamp(min_scalar, arg_vector, max_vector)),
+ Add(Clamp(min_vector, arg_vector, max_vector),
+ Clamp(min_scalar, arg_vector, max_scalar)));
ComputeAndCompareR1<uint32>(&builder, {8, 8, 2, 6, 14}, {});
}
@@ -2007,18 +2101,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
+ LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Add(p0, p1);
+ auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Add(p0, p1);
ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
{param0_data.get(), param1_data.get()},
@@ -2029,18 +2123,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Add(p0, p1);
+ auto p0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Add(p0, p1);
Array3D<float> expected(0, 7, 0);
ComputeAndCompareR3<float>(
@@ -2051,13 +2145,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
- auto p = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Add(a, p);
+ auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
+ auto p = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Add(a, p);
ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
{param0_data.get()}, error_spec_);
@@ -2065,8 +2159,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
- builder.Cos(a);
+ auto a = ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
+ Cos(a);
ComputeAndCompareR1<float>(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {},
error_spec_);
@@ -2074,8 +2168,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({3.14159f, 0.0f, 1.570796f, -0.78539f});
- builder.Sin(a);
+ auto a = ConstantR1<float>(&builder, {3.14159f, 0.0f, 1.570796f, -0.78539f});
+ Sin(a);
ComputeAndCompareR1<float>(&builder, {0.0f, 0.0f, 1.0f, -0.707107f}, {},
error_spec_);
@@ -2083,9 +2177,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
- auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
- builder.Atan2(a, b);
+ auto a = ConstantR1<float>(&builder, {0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
+ auto b = ConstantR1<float>(&builder, {6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
+ Atan2(a, b);
ComputeAndCompareR1<float>(
&builder,
@@ -2095,8 +2189,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
- builder.Tanh(a);
+ auto a = ConstantR1<float>(&builder, {-2.5f, 3.14f, 2.25f});
+ Tanh(a);
ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026}, {},
error_spec_);
@@ -2107,7 +2201,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
// the input tensor is large enough to exercise the vectorized tanh
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>(
+ auto input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16,
-0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32,
-1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85,
@@ -2118,8 +2212,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) {
TF_ASSERT_OK_AND_ASSIGN(auto input_data,
client_->TransferToServer(*input_literal));
- auto input = builder.Parameter(0, input_literal->shape(), "input");
- builder.Tanh(input);
+ auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ Tanh(input);
ComputeAndCompareR1<float>(
&builder,
@@ -2149,7 +2243,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
// Just to help make sense of the scales here -- exp(89) saturates float32 and
// exp(-10) is smaller than our error spec.
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
{1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31,
-1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5,
-193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4,
@@ -2164,8 +2258,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
- auto input = builder.Parameter(0, input_literal->shape(), "input");
- builder.Exp(input);
+ auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ Exp(input);
std::vector<float> expected_result;
int64 input_size = input_literal->shape().dimensions(0);
@@ -2183,7 +2277,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// implementation on XLA CPU.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
@@ -2202,8 +2296,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
- auto input = builder.Parameter(0, input_literal->shape(), "input");
- builder.Log(input);
+ auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ Log(input);
std::vector<float> expected_result;
int64 input_size = input_literal->shape().dimensions(0);
@@ -2218,9 +2312,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint32>(
- {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678});
- builder.Clz(a);
+ auto a = ConstantR1<uint32>(
+ &builder, {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678});
+ Clz(a);
ComputeAndCompareR1<uint32>(&builder, {32, 31, 27, 15, 9, 3, 0}, {});
}
@@ -2228,8 +2322,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) {
XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) {
XlaBuilder builder(TestName());
auto a =
- builder.ConstantR1<int64>({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
- builder.Clz(a);
+ ConstantR1<int64>(&builder, {0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1});
+ Clz(a);
ComputeAndCompareR1<int64>(&builder, {64, 63, 32, 1, 0}, {});
}
@@ -2241,12 +2335,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
// c---------------------/
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
- auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
- auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+ auto a = ConstantR1<float>(&builder, {1.1f, 2.2f, 3.3f, 4.4f});
+ auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
- auto add = builder.Add(a, b);
- builder.Add(add, c);
+ auto add = Add(a, b);
+ Add(add, c);
ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
error_spec_);
@@ -2259,12 +2353,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
// a---------------------/
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
- auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
- auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+ auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
- auto add = builder.Add(b, c);
- builder.Add(a, add);
+ auto add = Add(b, c);
+ Add(a, add);
ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
error_spec_);
@@ -2276,12 +2370,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
// b ----- (neg) ----/
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
- auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
- auto neg_a = builder.Neg(a);
- auto neg_b = builder.Neg(b);
- builder.Add(neg_a, neg_b);
+ auto neg_a = Neg(a);
+ auto neg_b = Neg(b);
+ Add(neg_a, neg_b);
ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
error_spec_);
@@ -2297,14 +2391,14 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
// d -----/
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
- auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
- auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
- auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});
+ auto a = ConstantR1<float>(&builder, {91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = ConstantR1<float>(&builder, {2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = ConstantR1<float>(&builder, {-3.3f, -15.5f, -7.7f, -29.9f});
+ auto d = ConstantR1<float>(&builder, {-19.0f, 10.0f, -40.0f, 20.2f});
- auto add_ab = builder.Add(a, b);
- auto add_cd = builder.Add(c, d);
- builder.Add(add_ab, add_cd);
+ auto add_ab = Add(a, b);
+ auto add_cd = Add(c, d);
+ Add(add_ab, add_cd);
ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
error_spec_);
@@ -2312,11 +2406,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto b =
- builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
- builder.Add(a, b);
+ auto a = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto b = ConstantR2<float>(&builder,
+ {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
+ Add(a, b);
Array2D<float> expected_array(
{{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
@@ -2326,10 +2420,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
// Add a scalar + matrix.
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto scalar = builder.ConstantR0<float>(3.0f);
- builder.Add(scalar, a);
+ auto a = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = ConstantR0<float>(&builder, 3.0f);
+ Add(scalar, a);
Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
@@ -2338,10 +2432,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
XLA_TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
// Add a matrix + scalar.
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto scalar = builder.ConstantR0<float>(3.0f);
- builder.Add(a, scalar);
+ auto a = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = ConstantR0<float>(&builder, 3.0f);
+ Add(a, scalar);
Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
@@ -2351,13 +2445,13 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
// Test simple broadcasting of a R1F32 over R2F32. The vector's size matches
// only dim 0 of the matrix.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
+ auto v = ConstantR1<float>(&builder, {20.0f, 40.0f, 60.0f});
// clang-format off
- auto m = builder.ConstantR2<float>({
+ auto m = ConstantR2<float>(&builder, {
{-2.5f, 3.14f, 1.0f},
{2.25f, -10.0f, 3.33f}});
// clang-format on
- builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ Add(v, m, /*broadcast_dimensions=*/{1});
Array2D<float> expected_array(
{{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
@@ -2366,27 +2460,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
// Test broadcasting in Eq comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({42, 73});
- auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
+ auto v = ConstantR1<int32>(&builder, {42, 73});
+ auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
// This test exercises both possible broadcast dimensions for a vector/matrix
// comparison.
- auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
- auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
- auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
+ auto cmp_dim_0 = Eq(v, m, /*broadcast_dimensions=*/{1});
+ auto cmp_dim_1 = Eq(v, m, /*broadcast_dimensions=*/{0});
+ Tuple(&builder, {cmp_dim_0, cmp_dim_1});
- auto expected = Literal::MakeTuple(
- {Literal::CreateR2<bool>({{true, true}, {true, false}}).get(),
- Literal::CreateR2<bool>({{true, false}, {false, false}}).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
// Test broadcasting in Ne comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({42, 73});
- auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
- builder.Ne(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<int32>(&builder, {42, 73});
+ auto m = ConstantR2<int32>(&builder, {{42, 73}, {42, 52}});
+ Ne(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,2] {
{ 00 },
@@ -2398,9 +2492,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
// Test broadcasting in Ge comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
- auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
- builder.Ge(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
+ auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
+ Ge(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 1100 },
@@ -2412,9 +2506,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
// Test broadcasting in Gt comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
- auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
- builder.Gt(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
+ auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
+ Gt(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 0100 },
@@ -2426,9 +2520,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
// Test broadcasting in Le comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
- auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
- builder.Le(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
+ auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
+ Le(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 1011 },
@@ -2440,9 +2534,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
// Test broadcasting in Lt comparison.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
- auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
- builder.Lt(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<int32>(&builder, {1, 2, 3, 4});
+ auto m = ConstantR2<int32>(&builder, {{1, 0, 5, 6}, {42, 52, 10, 4}});
+ Lt(v, m, /*broadcast_dimensions=*/{1});
const string expected = R"(pred[2,4] {
{ 0011 },
@@ -2455,9 +2549,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
// Test simple broadcasting of a R1F32 over R2F32 when the order of binary op
// arguments is reversed.
XlaBuilder builder(TestName());
- auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
- auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
- builder.Mul(m, v, /*broadcast_dimensions=*/{1});
+ auto m =
+ ConstantR2<float>(&builder, {{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
+ auto v = ConstantR1<float>(&builder, {2.0f, 4.0f, 6.0f});
+ Mul(m, v, /*broadcast_dimensions=*/{1});
Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
@@ -2468,10 +2563,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
// m's shape in XLA notation is {3, 2}
// md's shape in XLA notation is {3, 1}
// The result has shape {3, 2}, where md is broadcast over m
- auto m =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
- builder.Add(m, md);
+ auto m = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = ConstantR2<float>(&builder, {{10.0f, 20.0f, 30.0f}});
+ Add(m, md);
Array2D<float> expected_array(
{{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
@@ -2483,10 +2578,10 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
// m's shape in XLA notation is {3, 2}
// md's shape in XLA notation is {1, 2}
// The result has shape {3, 2}, where md is broadcast over m
- auto m =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
- builder.Add(m, md);
+ auto m = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = ConstantR2<float>(&builder, {{10.0f}, {20.0f}});
+ Add(m, md);
Array2D<float> expected_array(
{{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
@@ -2501,9 +2596,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
// a's shape in XLA notation is {1, 4}
// b's shape in XLA notation is {3, 1}
// The result has shape {3, 4}.
- auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
- auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
- builder.Add(a, b);
+ auto a = ConstantR2<float>(&builder, {{0.0f}, {10.0f}, {20.0f}, {30.0f}});
+ auto b = ConstantR2<float>(&builder, {{1.0f, 2.0f, 3.0f}});
+ Add(a, b);
Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
{11.0f, 12.0f, 13.0f},
{21.0f, 22.0f, 23.0f},
@@ -2515,9 +2610,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
// Add together a (2,2) array and a (2) array, using dimension 0 for
// broadcasting (though there are two ways to broadcast these shapes).
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({20.0f, 40.0f});
- auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
- builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
+ auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
+ Add(v, m, /*broadcast_dimensions=*/{1});
Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
@@ -2526,9 +2621,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
// Add together a (2,2) array and a (2) array, using dimension 1 for
// broadcasting (though there are two ways to broadcast these shapes).
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({20.0f, 40.0f});
- auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
- builder.Add(v, m, /*broadcast_dimensions=*/{0});
+ auto v = ConstantR1<float>(&builder, {20.0f, 40.0f});
+ auto m = ConstantR2<float>(&builder, {{10.0f, 50.0f}, {77.0f, 88.0f}});
+ Add(v, m, /*broadcast_dimensions=*/{0});
Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
}
@@ -2538,12 +2633,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
XlaBuilder builder(TestName());
Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
{{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
- auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
{{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
- auto b = builder.ConstantR3FromArray3D<float>(b_3d);
- builder.Add(a, b);
+ auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
+ Add(a, b);
Array3D<float> expected_3d(
{{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
@@ -2565,9 +2660,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
{11.0f, 12.0f}},
});
// clang-format on
- auto a = builder.ConstantR3FromArray3D<float>(a_3d);
- auto v = builder.ConstantR1<float>({10.0f, 20.0f});
- builder.Add(a, v, /*broadcast_dimensions=*/{2});
+ auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
+ auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
+ Add(a, v, /*broadcast_dimensions=*/{2});
Array3D<float> expected_3d(
{{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
@@ -2589,9 +2684,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
{11.0f, 12.0f}},
});
// clang-format on
- auto a = builder.ConstantR3FromArray3D<float>(a_3d);
- auto v = builder.ConstantR1<float>({10.0f, 20.0f});
- builder.Add(a, v, /*broadcast_dimensions=*/{0});
+ auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
+ auto v = ConstantR1<float>(&builder, {10.0f, 20.0f});
+ Add(a, v, /*broadcast_dimensions=*/{0});
// clang-format off
Array3D<float> expected_3d({
@@ -2619,12 +2714,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
{9.0f, 10.0f},
{11.0f, 12.0f}},
});
- auto a = builder.ConstantR3FromArray3D<float>(a_3d);
- auto m = builder.ConstantR2<float>({
+ auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
+ auto m = ConstantR2<float>(&builder, {
{10.0f, 20.0f, 30.0f},
{40.0f, 50.0f, 60.0f},
});
- builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});
+ Add(a, m, /*broadcast_dimensions=*/{0, 1});
Array3D<float> expected_3d({
{{11.0f, 12.0f},
@@ -2644,12 +2739,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim2) {
XlaBuilder builder(TestName());
Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
{{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
- auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto a = ConstantR3FromArray3D<float>(&builder, a_3d);
Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
- auto b = builder.ConstantR3FromArray3D<float>(b_3d);
+ auto b = ConstantR3FromArray3D<float>(&builder, b_3d);
- builder.Gt(a, b);
+ Gt(a, b);
Array3D<int> expected_3d(
{{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
@@ -2684,9 +2779,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
}
}
- auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
- auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d);
- builder.Add(a, b);
+ auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
+ auto b = ConstantR4FromArray4D<float>(&builder, *operand_b_4d);
+ Add(a, b);
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
}
@@ -2712,9 +2807,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
}
}
- auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
- auto b = builder.ConstantR1<float>(operand_b_1d);
- builder.Add(a, b, {1});
+ auto a = ConstantR4FromArray4D<float>(&builder, *operand_a_4d);
+ auto b = ConstantR1<float>(&builder, operand_b_1d);
+ Add(a, b, {1});
ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
}
@@ -2730,11 +2825,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
std::iota(r1.begin(), r1.end(), 1.0);
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR4FromArray4DWithLayout(
- r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
- auto a = builder.ConstantLiteral(*a_literal);
- auto b = builder.ConstantR1<float>(r1);
- builder.Add(a, b, {1});
+ std::unique_ptr<Literal> a_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ r4, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ auto a = ConstantLiteral(&builder, *a_literal);
+ auto b = ConstantR1<float>(&builder, r1);
+ Add(a, b, {1});
for (int i0 = 0; i0 < d0; ++i0) {
for (int i1 = 0; i1 < d1; ++i1) {
@@ -2752,22 +2848,22 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
XlaBuilder builder(TestName());
auto shape = ShapeUtil::MakeOpaqueShape();
- auto x = builder.Parameter(0, shape, "x");
- builder.Add(x, x);
+ auto x = Parameter(&builder, 0, shape, "x");
+ Add(x, x);
auto computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
::testing::ContainsRegex(
- "Expected non-opaque argument for lhs of binary operation"));
+ "Expected array argument for lhs of binary operation"));
}
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto b =
- builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
- builder.Add(a, b, /*broadcast_dimensions=*/{0, 1});
+ auto a = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto b = ConstantR2<float>(&builder,
+ {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
+ Add(a, b, /*broadcast_dimensions=*/{0, 1});
Array2D<float> expected_array(
{{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
@@ -2776,11 +2872,11 @@ XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
- auto b =
- builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
- builder.Add(a, b, /*broadcast_dimensions=*/{1, 0});
+ auto a = ConstantR2<float>(&builder,
+ {{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto b = ConstantR2<float>(&builder,
+ {{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}});
+ Add(a, b, /*broadcast_dimensions=*/{1, 0});
auto computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
@@ -2792,15 +2888,15 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
// broadcast.
XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
XlaBuilder builder(TestName());
- auto x_literal = Literal::CreateR1<float>({1, 2, 3});
- auto y_literal = Literal::CreateR1<float>({4, 5});
+ auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
- auto x = builder.Parameter(0, x_literal->shape(), "x");
- auto y = builder.Parameter(1, y_literal->shape(), "y");
- auto slice = builder.Slice(x, {1}, {2}, {1});
- builder.Sub(slice, y);
+ auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto y = Parameter(&builder, 1, y_literal->shape(), "y");
+ auto slice = Slice(x, {1}, {2}, {1});
+ Sub(slice, y);
ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
error_spec_);
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
index fcd9ff55e3..8d15b7841b 100644
--- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -29,10 +29,10 @@ class AxpySimpleTest : public ClientLibraryTestBase {};
TEST_F(AxpySimpleTest, AxTenValues) {
XlaBuilder builder("ax_10");
- auto alpha = builder.ConstantR0<float>(3.1415926535);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- builder.Mul(alpha, x);
+ auto alpha = ConstantR0<float>(&builder, 3.1415926535);
+ auto x = ConstantR1<float>(
+ &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ Mul(alpha, x);
std::vector<float> expected = {
-3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796,
@@ -42,11 +42,11 @@ TEST_F(AxpySimpleTest, AxTenValues) {
XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
XlaBuilder builder("axpy_10");
- auto alpha = builder.ConstantR0<float>(3.1415926535);
- auto x = builder.ConstantR1<float>({});
- auto y = builder.ConstantR1<float>({});
- auto ax = builder.Mul(alpha, x);
- builder.Add(ax, y);
+ auto alpha = ConstantR0<float>(&builder, 3.1415926535);
+ auto x = ConstantR1<float>(&builder, {});
+ auto y = ConstantR1<float>(&builder, {});
+ auto ax = Mul(alpha, x);
+ Add(ax, y);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -54,13 +54,13 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
TEST_F(AxpySimpleTest, AxpyTenValues) {
XlaBuilder builder("axpy_10");
- auto alpha = builder.ConstantR0<float>(3.1415926535);
- auto x = builder.ConstantR1<float>(
- {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
- auto y = builder.ConstantR1<float>(
- {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
- auto ax = builder.Mul(alpha, x);
- builder.Add(ax, y);
+ auto alpha = ConstantR0<float>(&builder, 3.1415926535);
+ auto x = ConstantR1<float>(
+ &builder, {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto y = ConstantR1<float>(
+ &builder, {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
+ auto ax = Mul(alpha, x);
+ Add(ax, y);
TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape());
diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
index 22c3394e6f..8c227df7f0 100644
--- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
+++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
@@ -35,10 +35,10 @@ class BadRngShapeValidationTest : public ClientLibraryTestBase {};
TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR0<float>(0.0);
- auto one = builder.ConstantR0<float>(1.0);
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto one = ConstantR0<float>(&builder, 1.0);
Shape default_constructed;
- builder.RngUniform(zero, one, default_constructed);
+ RngUniform(zero, one, default_constructed);
StatusOr<XlaComputation> computation = builder.Build();
EXPECT_FALSE(computation.ok());
@@ -49,13 +49,13 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR0<float>(0.0);
- auto one = builder.ConstantR0<float>(1.0);
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto one = ConstantR0<float>(&builder, 1.0);
Shape sans_layout;
sans_layout.set_element_type(F32);
sans_layout.add_dimensions(1);
- builder.RngUniform(zero, one, sans_layout);
+ RngUniform(zero, one, sans_layout);
StatusOr<XlaComputation> computation = builder.Build();
ASSERT_TRUE(computation.ok());
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index f3dac75a44..6a024798f9 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -20,10 +20,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -62,7 +63,7 @@ class BatchNormalizationTest
{5.0f, 4.4f}, // p2
});
input_array_.FillWithPZ(pz);
- input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_));
+ input_literal_ = std::move(*LiteralUtil::CreateR4FromArray4D(input_array_));
CHECK_EQ(kSamples, input_array_.planes());
CHECK_EQ(kZ, input_array_.depth());
CHECK_EQ(kY, input_array_.height());
@@ -101,9 +102,9 @@ INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
XlaBuilder builder("subtract_in_z_one_sample");
- auto x = builder.ConstantLiteral(input_literal_);
- auto y = builder.ConstantR1<float>({3.14, 4.25});
- builder.Sub(x, y, /*broadcast_dimensions=*/{1});
+ auto x = ConstantLiteral(&builder, input_literal_);
+ auto y = ConstantR1<float>(&builder, {3.14, 4.25});
+ Sub(x, y, /*broadcast_dimensions=*/{1});
Array4D<float> expected(kSamples, kZ, kY, kX);
Array2D<float> pz({
@@ -117,8 +118,8 @@ XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
XlaBuilder builder("square_tesseract_elementwise");
- auto x = builder.ConstantLiteral(input_literal_);
- builder.SquareF32(x);
+ auto x = ConstantLiteral(&builder, input_literal_);
+ Square(x);
using tensorflow::MathUtil;
@@ -134,11 +135,10 @@ XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
XLA_TEST_P(BatchNormalizationTest, SumToZ) {
XlaBuilder builder("sum_to_z");
- auto input_activations = builder.ConstantLiteral(input_literal_);
+ auto input_activations = ConstantLiteral(&builder, input_literal_);
XlaComputation add = CreateScalarAddComputation(F32, &builder);
// Reduce all but the Z dimension.
- builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
- {0, 2, 3});
+ Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});
std::vector<float> expected = {6, 12.6};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -146,13 +146,13 @@ XLA_TEST_P(BatchNormalizationTest, SumToZ) {
XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
XlaBuilder builder("square_and_reduce");
- auto input_activations = builder.ConstantLiteral(input_literal_);
- auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
- auto activation_deviations = builder.Sub(input_activations, set_means,
- /*broadcast_dimensions=*/{1});
+ auto input_activations = ConstantLiteral(&builder, input_literal_);
+ auto set_means = ConstantR1<float>(&builder, {2.f, 4.2f});
+ auto activation_deviations = Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto dev_squares = builder.SquareF32(activation_deviations);
- builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
+ auto dev_squares = Square(activation_deviations);
+ Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});
std::vector<float> expected = {18, 0.06};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -160,8 +160,8 @@ XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XlaBuilder builder("variance_to_stddev");
- auto variance = builder.ConstantR1<float>({6.f, .02f});
- builder.SqrtF32(variance);
+ auto variance = ConstantR1<float>(&builder, {6.f, .02f});
+ Sqrt(variance);
std::vector<float> expected = {2.44948974f, 0.14142136f};
ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
@@ -172,50 +172,50 @@ XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
XlaBuilder builder("batch_normalize_per_spec");
auto input_activations =
- CheckShape(&builder, builder.ConstantLiteral(input_literal_),
+ CheckShape(&builder, ConstantLiteral(&builder, input_literal_),
ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
- auto gamma = builder.ConstantR1<float>({1.0, 1.0});
- auto beta = builder.ConstantR1<float>({0.0, 0.0});
+ auto gamma = ConstantR1<float>(&builder, {1.0, 1.0});
+ auto beta = ConstantR1<float>(&builder, {0.0, 0.0});
XlaComputation add = CreateScalarAddComputation(F32, &builder);
// Reduce all dimensions except dimension 1.
Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
auto sum = CheckShape(
&builder,
- builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
- /*dimensions_to_reduce=*/{0, 2, 3}),
+ Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
+ /*dimensions_to_reduce=*/{0, 2, 3}),
TwoElementVectorF32);
auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
- auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(input_shape) /
- ShapeUtil::ElementsIn(sum_shape));
- auto set_means = builder.Div(sum, count);
+ auto count =
+ ConstantR0<float>(&builder, ShapeUtil::ElementsIn(input_shape) /
+ ShapeUtil::ElementsIn(sum_shape));
+ auto set_means = Div(sum, count);
const float kEpsilon = 1e-9f;
- auto epsilon = builder.ConstantR0<float>(kEpsilon);
- auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
- auto activation_deviations = builder.Sub(input_activations, set_means,
- /*broadcast_dimensions=*/{1});
- auto dev_squares = builder.SquareF32(activation_deviations);
- auto sum_of_squares = CheckShape(
- &builder,
- builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
- /*dimensions_to_reduce=*/{0, 2, 3}),
- TwoElementVectorF32);
- auto variance = builder.Div(sum_of_squares, count);
- auto standard_deviation = builder.SqrtF32(variance);
+ auto epsilon = ConstantR0<float>(&builder, kEpsilon);
+ auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
+ auto activation_deviations = Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
+ auto dev_squares = Square(activation_deviations);
+ auto sum_of_squares =
+ CheckShape(&builder,
+ Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
+ /*dimensions_to_reduce=*/{0, 2, 3}),
+ TwoElementVectorF32);
+ auto variance = Div(sum_of_squares, count);
+ auto standard_deviation = Sqrt(variance);
auto standard_deviation_above_epsilon =
- CheckShape(&builder, builder.Gt(standard_deviation, epsilon),
+ CheckShape(&builder, Gt(standard_deviation, epsilon),
ShapeUtil::MakeShape(PRED, {2}));
- auto gt_eps = builder.Select(standard_deviation_above_epsilon,
- standard_deviation, epsilon2);
- auto normalization_factors = builder.ReciprocalF32(gt_eps);
+ auto gt_eps =
+ Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
+ auto normalization_factors = Reciprocal(gt_eps);
auto normalized_input_activations =
- builder.Mul(activation_deviations, normalization_factors,
- /*broadcast_dimensions=*/{1});
- /* auto output_activations = */ builder.Add(
- builder.Mul(normalized_input_activations, gamma,
- /*broadcast_dimensions=*/{1}),
- beta, /*broadcast_dimensions=*/{1});
+ Mul(activation_deviations, normalization_factors,
+ /*broadcast_dimensions=*/{1});
+ /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma,
+ /*broadcast_dimensions=*/{1}),
+ beta, /*broadcast_dimensions=*/{1});
Array4D<float> expected(kSamples, kZ, kY, kX);
Array2D<float> pz({
@@ -232,46 +232,47 @@ XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
const int kFeatureIndex = 3;
XlaBuilder builder(TestName());
- auto operand = builder.ConstantR4FromArray4D<float>(
- {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
+ auto operand = ConstantR4FromArray4D<float>(
+ &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});
- auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
+ auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
- auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
+ auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
- builder.BatchNormTraining(operand, scale, offset,
- /*epsilon=*/0.001, kFeatureIndex);
+ BatchNormTraining(operand, scale, offset,
+ /*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
- {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
+ {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
-XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) {
+XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());
- auto operand = builder.ConstantR4FromArray4D<float>(
+ auto operand = ConstantR4FromArray4D<float>(
+ &builder,
{{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
- auto scale = builder.ConstantR1<float>({2.0f, 3.0f});
+ auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});
- auto offset = builder.ConstantR1<float>({1.0f, 2.0f});
+ auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});
- builder.BatchNormTraining(operand, scale, offset,
- /*epsilon=*/0.001, kFeatureIndex);
+ BatchNormTraining(operand, scale, offset,
+ /*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
- {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
+ {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}})
.get(),
- Literal::CreateR1<float>({4, 5}).get(),
- Literal::CreateR1<float>({5, 5}).get()});
+ LiteralUtil::CreateR1<float>({4, 5}).get(),
+ LiteralUtil::CreateR1<float>({5, 5}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -294,14 +295,14 @@ XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
/*parameter_number=*/2, "offset", &builder, &h2);
- builder.BatchNormTraining(h0, h1, h2,
- /*epsilon=*/1, kFeatureIndex);
+ BatchNormTraining(h0, h1, h2,
+ /*epsilon=*/1, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f))
.get(),
- Literal::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -327,14 +328,15 @@ XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
/*parameter_number=*/2, "offset", &builder, &h2);
// var = 125, mean = 15, epsilon = -100
- builder.BatchNormTraining(h0, h1, h2,
- /*epsilon=*/-100, kFeatureIndex);
+ BatchNormTraining(h0, h1, h2,
+ /*epsilon=*/-100, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR3FromArray3D<float>({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR3FromArray3D<float>(
+ {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}})
.get(),
- Literal::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
- Literal::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)).get(),
+ LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f)).get()});
ComputeAndCompareTuple(&builder, *expected,
{operand.get(), scale.get(), offset.get()},
@@ -346,26 +348,27 @@ XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
XlaBuilder builder(TestName());
auto operand =
- builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 2, 2, 1, 0.0f));
+ ConstantR4FromArray4D<float>(&builder, Array4D<float>(2, 2, 2, 1, 0.0f));
- auto scale = builder.ConstantR1<float>({1.0f, 1.0f});
+ auto scale = ConstantR1<float>(&builder, {1.0f, 1.0f});
- auto mean = builder.ConstantR1<float>({0.0f, 0.0f});
+ auto mean = ConstantR1<float>(&builder, {0.0f, 0.0f});
- auto var = builder.ConstantR1<float>({1.0f, 1.0f});
+ auto var = ConstantR1<float>(&builder, {1.0f, 1.0f});
- auto grad_output = builder.ConstantR4FromArray4D<float>(
+ auto grad_output = ConstantR4FromArray4D<float>(
+ &builder,
{{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});
- builder.BatchNormGrad(operand, scale, mean, var, grad_output,
- /*epsilon=*/0.0, kFeatureIndex);
+ BatchNormGrad(operand, scale, mean, var, grad_output,
+ /*epsilon=*/0.0, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
- {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
+ {{{1.f}, {1.f}}, {{3.f}, {3.f}}}})
.get(),
- Literal::CreateR1<float>({0, 0}).get(),
- Literal::CreateR1<float>({16, 20}).get()});
+ LiteralUtil::CreateR1<float>({0, 0}).get(),
+ LiteralUtil::CreateR1<float>({16, 20}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1));
}
@@ -511,22 +514,23 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
+ auto expected_normalized =
+ LiteralUtil::CreateR4FromArray4D<float>(normalized);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- builder.Parameter(0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal->shape(), "input");
auto scale_activations =
- builder.Parameter(1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal->shape(), "offset");
auto offset_activations =
- builder.Parameter(2, offset_literal->shape(), "scale");
+ Parameter(&builder, 2, offset_literal->shape(), "scale");
- auto expected = Literal::MakeTuple({expected_normalized.get(),
- Literal::CreateR1<float>(mean).get(),
- Literal::CreateR1<float>(var).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_normalized.get(), LiteralUtil::CreateR1<float>(mean).get(),
+ LiteralUtil::CreateR1<float>(var).get()});
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -535,8 +539,8 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
std::unique_ptr<GlobalData> offset_data =
client_->TransferToServer(*offset_literal).ConsumeValueOrDie();
- builder.BatchNormTraining(input_activations, scale_activations,
- offset_activations, epsilon, feature_index);
+ BatchNormTraining(input_activations, scale_activations, offset_activations,
+ epsilon, feature_index);
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
@@ -611,21 +615,21 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
scale4D, offset4D, epsilon);
- auto offset_literal = Literal::CreateR1<float>(offset);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
+ auto offset_literal = LiteralUtil::CreateR1<float>(offset);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
auto input_activations =
- builder.Parameter(0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal->shape(), "input");
auto scale_activations =
- builder.Parameter(1, scale_literal->shape(), "offset");
+ Parameter(&builder, 1, scale_literal->shape(), "offset");
auto offset_activations =
- builder.Parameter(2, offset_literal->shape(), "scale");
- auto mean_activations = builder.Parameter(3, mean_literal->shape(), "mean");
+ Parameter(&builder, 2, offset_literal->shape(), "scale");
+ auto mean_activations = Parameter(&builder, 3, mean_literal->shape(), "mean");
auto variance_activations =
- builder.Parameter(4, var_literal->shape(), "variance");
+ Parameter(&builder, 4, var_literal->shape(), "variance");
Array4D<float> expected = normalized;
@@ -640,9 +644,9 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
std::unique_ptr<GlobalData> variance_data =
client_->TransferToServer(*var_literal).ConsumeValueOrDie();
- builder.BatchNormInference(input_activations, scale_activations,
- offset_activations, mean_activations,
- variance_activations, epsilon, feature_index);
+ BatchNormInference(input_activations, scale_activations, offset_activations,
+ mean_activations, variance_activations, epsilon,
+ feature_index);
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
@@ -798,21 +802,23 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
});
auto expected_grad_activation =
- Literal::CreateR4FromArray4D<float>(grad_activation);
+ LiteralUtil::CreateR4FromArray4D<float>(grad_activation);
- auto input_literal = Literal::CreateR4FromArray4D<float>(input_array);
- auto scale_literal = Literal::CreateR1<float>(scale);
- auto mean_literal = Literal::CreateR1<float>(mean);
- auto var_literal = Literal::CreateR1<float>(var);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ auto scale_literal = LiteralUtil::CreateR1<float>(scale);
+ auto mean_literal = LiteralUtil::CreateR1<float>(mean);
+ auto var_literal = LiteralUtil::CreateR1<float>(var);
auto grad_output_literal =
- Literal::CreateR4FromArray4D<float>(grad_output_array);
-
- auto input_parameter = builder.Parameter(0, input_literal->shape(), "input");
- auto scale_parameter = builder.Parameter(1, scale_literal->shape(), "scale");
- auto mean_parameter = builder.Parameter(2, mean_literal->shape(), "mean");
- auto var_parameter = builder.Parameter(3, var_literal->shape(), "variance");
+ LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);
+
+ auto input_parameter =
+ Parameter(&builder, 0, input_literal->shape(), "input");
+ auto scale_parameter =
+ Parameter(&builder, 1, scale_literal->shape(), "scale");
+ auto mean_parameter = Parameter(&builder, 2, mean_literal->shape(), "mean");
+ auto var_parameter = Parameter(&builder, 3, var_literal->shape(), "variance");
auto grad_output_parameter =
- builder.Parameter(4, grad_output_literal->shape(), "grad_output");
+ Parameter(&builder, 4, grad_output_literal->shape(), "grad_output");
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -825,14 +831,13 @@ XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
std::unique_ptr<GlobalData> grad_output_data =
client_->TransferToServer(*grad_output_literal).ConsumeValueOrDie();
- builder.BatchNormGrad(input_parameter, scale_parameter, mean_parameter,
- var_parameter, grad_output_parameter, epsilon,
- feature_index);
+ BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
+ grad_output_parameter, epsilon, feature_index);
auto expected =
- Literal::MakeTuple({expected_grad_activation.get(),
- Literal::CreateR1<float>(grad_scale).get(),
- Literal::CreateR1<float>(grad_offset).get()});
+ LiteralUtil::MakeTuple({expected_grad_activation.get(),
+ LiteralUtil::CreateR1<float>(grad_scale).get(),
+ LiteralUtil::CreateR1<float>(grad_offset).get()});
// Run all HLO passes during this test. In particular, ClientLibraryTestBase
// disables constant folding, but we want it enabled for our zero-sized tensor
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index ca337e7884..747c82b502 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -51,9 +51,9 @@ class Bfloat16Test : public ClientLibraryTestBase {
XLA_TEST_F(Bfloat16Test, ScalarOperation) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
- auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
- builder.Add(x, y);
+ auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.0f));
+ auto y = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(1.0f));
+ Add(x, y);
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
error_spec_);
@@ -61,8 +61,8 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) {
XLA_TEST_F(Bfloat16Test, LogOperation) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f));
- builder.Log(x);
+ auto x = ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(4.0f));
+ Log(x);
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {},
error_spec_);
@@ -70,7 +70,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) {
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
XlaBuilder builder(TestName());
- builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
+ Neg(ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(2.1f)));
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
error_spec_);
@@ -80,33 +80,33 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());
- auto operand = builder.ConstantR4FromArray4D<bfloat16>(
+ auto operand = ConstantR4FromArray4D<bfloat16>(
+ &builder,
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
- auto scale = builder.ConstantR1<bfloat16>(
- {static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});
+ auto scale = ConstantR1<bfloat16>(
+ &builder, {static_cast<bfloat16>(2.0f), static_cast<bfloat16>(3.0f)});
- auto offset = builder.ConstantR1<bfloat16>(
- {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
+ auto offset = ConstantR1<bfloat16>(
+ &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(2.0f)});
- auto tuple = builder.BatchNormTraining(operand, scale, offset,
- /*epsilon=*/0.001, kFeatureIndex);
+ BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<bfloat16>(
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-1.6875f)},
{static_cast<bfloat16>(-2.04f)}},
{{static_cast<bfloat16>(0.105f)}, {static_cast<bfloat16>(0.66f)}}},
{{{static_cast<bfloat16>(1.89f)}, {static_cast<bfloat16>(3.35f)}},
{{static_cast<bfloat16>(3.7f)}, {static_cast<bfloat16>(6.04f)}}}})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(4), static_cast<bfloat16>(5)})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(5), static_cast<bfloat16>(5)})
.get()});
@@ -117,38 +117,39 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
const int kFeatureIndex = 2;
XlaBuilder builder(TestName());
- auto operand = builder.ConstantR4FromArray4D<bfloat16>(
- Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
+ auto operand = ConstantR4FromArray4D<bfloat16>(
+ &builder, Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
- auto scale = builder.ConstantR1<bfloat16>(
- {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
+ auto scale = ConstantR1<bfloat16>(
+ &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
- auto mean = builder.ConstantR1<bfloat16>(
- {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});
+ auto mean = ConstantR1<bfloat16>(
+ &builder, {static_cast<bfloat16>(0.0f), static_cast<bfloat16>(0.0f)});
- auto var = builder.ConstantR1<bfloat16>(
- {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
+ auto var = ConstantR1<bfloat16>(
+ &builder, {static_cast<bfloat16>(1.0f), static_cast<bfloat16>(1.0f)});
- auto grad_output = builder.ConstantR4FromArray4D<bfloat16>(
+ auto grad_output = ConstantR4FromArray4D<bfloat16>(
+ &builder,
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(4.f)}}},
{{{static_cast<bfloat16>(5.f)}, {static_cast<bfloat16>(6.f)}},
{{static_cast<bfloat16>(7.f)}, {static_cast<bfloat16>(8.f)}}}});
- builder.BatchNormGrad(operand, scale, mean, var, grad_output,
- /*epsilon=*/0.0, kFeatureIndex);
+ BatchNormGrad(operand, scale, mean, var, grad_output,
+ /*epsilon=*/0.0, kFeatureIndex);
- auto expected = Literal::MakeTuple(
- {Literal::CreateR4<bfloat16>(
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR4<bfloat16>(
{{{{static_cast<bfloat16>(-3.f)}, {static_cast<bfloat16>(-3.f)}},
{{static_cast<bfloat16>(-1.f)}, {static_cast<bfloat16>(-1.f)}}},
{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(1.f)}},
{{static_cast<bfloat16>(3.f)}, {static_cast<bfloat16>(3.f)}}}})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(0), static_cast<bfloat16>(0)})
.get(),
- Literal::CreateR1<bfloat16>(
+ LiteralUtil::CreateR1<bfloat16>(
{static_cast<bfloat16>(16), static_cast<bfloat16>(20)})
.get()});
diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
index 48203b1d40..20cb989751 100644
--- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc
+++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
@@ -33,9 +33,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) {
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
- auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<float>(&builder, *alhs);
+ auto rhs = ConstantR2FromArray2D<float>(&builder, *arhs);
+ Add(lhs, rhs);
auto aexpected = ReferenceUtil::MapWithIndexArray2D(
*alhs, [&](float lhs_value, int64 row, int64 col) {
@@ -49,9 +49,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) {
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
- auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<float>(&builder, *alhs);
+ auto rhs = ConstantR2FromArray2D<float>(&builder, *arhs);
+ Add(lhs, rhs);
auto aexpected = ReferenceUtil::MapWithIndexArray2D(
*alhs, [&](float lhs_value, int64 row, int64 col) {
@@ -65,9 +65,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) {
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
- auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<float>(&builder, *alhs);
+ auto rhs = ConstantR2FromArray2D<float>(&builder, *arhs);
+ Add(lhs, rhs);
auto aexpected = ReferenceUtil::MapWithIndexArray2D(
*alhs, [&](float lhs_value, int64 row, int64 col) {
@@ -81,9 +81,9 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1);
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
- auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<float>(&builder, *alhs);
+ auto rhs = ConstantR2FromArray2D<float>(&builder, *arhs);
+ Add(lhs, rhs);
auto aexpected = ReferenceUtil::MapWithIndexArray2D(
*alhs, [&](float lhs_value, int64 row, int64 col) {
@@ -94,11 +94,12 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
TEST_F(BinopScalingTest, R0PlusR2F32) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR0<float>(42.0);
- auto rhs = builder.ConstantR2<float>({
- {1.0, 2.0}, {3.0, 4.0},
- });
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR0<float>(&builder, 42.0);
+ auto rhs = ConstantR2<float>(&builder, {
+ {1.0, 2.0},
+ {3.0, 4.0},
+ });
+ Add(lhs, rhs);
Array2D<float> expected(2, 2);
expected(0, 0) = 42.0 + 1.0;
@@ -129,9 +130,9 @@ TEST_F(BinopScalingTest, R4PlusR0S32) {
});
// clang-format on
- auto lhs = builder.ConstantR4FromArray4D(lhs_array);
- auto rhs = builder.ConstantR0<int>(42);
- builder.Add(lhs, rhs);
+ auto lhs = ConstantR4FromArray4D(&builder, lhs_array);
+ auto rhs = ConstantR0<int>(&builder, 42);
+ Add(lhs, rhs);
ComputeAndCompareR4<int>(&builder, expected, {});
}
diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
index bff60f25ec..d531e8fa82 100644
--- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
+++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc
@@ -43,8 +43,8 @@ class BitcastConvertTest : public ClientLibraryTestBase {
TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({42, 64});
- builder.BitcastConvertType(a, S32);
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ BitcastConvertType(a, S32);
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -52,8 +52,8 @@ TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) {
TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0f, 64.0f});
- builder.BitcastConvertType(a, F32);
+ auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
+ BitcastConvertType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -62,10 +62,10 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) {
TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) {
XlaBuilder builder(TestName());
auto a =
- builder.ConstantR1<int32>({0, static_cast<int32>(0x80000000), 0x3F800000,
- static_cast<int32>(0xBF800000), 0x3F000000,
- static_cast<int32>(0xBF000000)});
- builder.BitcastConvertType(a, F32);
+ ConstantR1<int32>(&builder, {0, static_cast<int32>(0x80000000),
+ 0x3F800000, static_cast<int32>(0xBF800000),
+ 0x3F000000, static_cast<int32>(0xBF000000)});
+ BitcastConvertType(a, F32);
std::vector<float> expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -73,8 +73,8 @@ TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) {
XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- builder.BitcastConvertType(a, F32);
+ auto a = ConstantR1<int32>(&builder, {});
+ BitcastConvertType(a, F32);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -82,8 +82,8 @@ XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) {
TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.6, 64.4});
- builder.BitcastConvertType(a, S32);
+ auto a = ConstantR1<float>(&builder, {42.6, 64.4});
+ BitcastConvertType(a, S32);
std::vector<int32> expected = {0x422a6666, 0x4280cccd};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -91,9 +91,9 @@ TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) {
TEST_F(BitcastConvertTest, ConvertS32Extremes) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>(
- {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
- builder.BitcastConvertType(a, F32);
+ auto a = ConstantR1<int32>(&builder, {std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max()});
+ BitcastConvertType(a, F32);
std::vector<float> expected = {-0.0f, NAN};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0, 0));
@@ -102,10 +102,10 @@ TEST_F(BitcastConvertTest, ConvertS32Extremes) {
TEST_F(BitcastConvertTest, ConvertMapToS32) {
XlaBuilder builder(TestName());
auto b = builder.CreateSubBuilder("convert");
- auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
- b->BitcastConvertType(param, S32);
- auto a = builder.ConstantR1<float>({42.0f, 64.0f});
- builder.Map({a}, b->BuildAndNoteError(), {0});
+ auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in");
+ BitcastConvertType(param, S32);
+ auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
+ Map(&builder, {a}, b->BuildAndNoteError(), {0});
std::vector<int32> expected = {0x42280000, 0x42800000};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -114,10 +114,10 @@ TEST_F(BitcastConvertTest, ConvertMapToS32) {
TEST_F(BitcastConvertTest, ConvertMapToF32) {
XlaBuilder builder(TestName());
auto b = builder.CreateSubBuilder("convert");
- auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
- b->BitcastConvertType(param, F32);
- auto a = builder.ConstantR1<int32>({0x42280000, 0x42800000});
- builder.Map({a}, b->BuildAndNoteError(), {0});
+ auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in");
+ BitcastConvertType(param, F32);
+ auto a = ConstantR1<int32>(&builder, {0x42280000, 0x42800000});
+ Map(&builder, {a}, b->BuildAndNoteError(), {0});
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -130,9 +130,9 @@ TEST_F(BitcastConvertTest, ConvertMapToF32) {
// the new convert should have the same element type as the old convert.
TEST_F(BitcastConvertTest, ConvertReshape) {
XlaBuilder builder(TestName());
- auto input = builder.ConstantR1<int32>({0x42280000});
- auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
- builder.BitcastConvertType(reshape, F32);
+ auto input = ConstantR1<int32>(&builder, {0x42280000});
+ auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
+ BitcastConvertType(reshape, F32);
ComputeAndCompareR0<float>(&builder, 42.0f, {});
}
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 34c86e007b..50dd574624 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -37,17 +38,17 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
XlaBuilder* builder) {
switch (op) {
case HloOpcode::kMinimum: {
- return builder->Min(lhs, rhs);
+ return Min(lhs, rhs);
}
case HloOpcode::kMaximum: {
- return builder->Max(lhs, rhs);
+ return Max(lhs, rhs);
}
case HloOpcode::kMultiply: {
- return builder->Mul(lhs, rhs);
+ return Mul(lhs, rhs);
}
default: {
// Default to Add
- return builder->Add(lhs, rhs);
+ return Add(lhs, rhs);
}
}
}
@@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array3D<float>* r3_array, float start, float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
client_->TransferToServer(*r3_data).ConsumeValueOrDie();
@@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array2D<float>* r2_array, float start, float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
client_->TransferToServer(*r2_data).ConsumeValueOrDie();
@@ -104,13 +105,13 @@ using ::testing::HasSubstr;
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR0<float>(1.5), {});
+ Broadcast(ConstantR0<float>(&b, 1.5), {});
ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
+ Broadcast(ConstantR0<float>(&b, 2.25), {2, 3});
Array2D<float> expected(2, 3, 2.25);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
@@ -122,7 +123,7 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
/*builder=*/&b, /*data_handle=*/&src);
- b.Broadcast(src, {2, 3});
+ Broadcast(src, {2, 3});
Array2D<float> expected(2, 3, 2.25);
ComputeAndCompareR2<float>(&b, expected, {param_data.get()},
ErrorSpec(0.0001));
@@ -130,21 +131,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
+ Broadcast(ConstantR0<float>(&b, 2.25), {2, 0});
Array2D<float> expected(2, 0);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
+ Broadcast(ConstantR0<float>(&b, 2.25), {0, 2});
Array2D<float> expected(0, 2);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
+ Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {2});
Array2D<float> expected(2, 3);
expected(0, 0) = 1;
@@ -156,6 +157,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {1});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {0});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 1;
+ expected(1, 0) = 2;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 1.0;
+ expected(1, 0, 1) = 2.0;
+ expected(0, 1, 0) = 5.0;
+ expected(1, 1, 0) = 6.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 5.0;
+ expected(1, 0, 1) = 6.0;
+ expected(0, 1, 0) = 1.0;
+ expected(1, 1, 0) = 2.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {3, 2}), {1});
+
+ Array2D<float> expected(3, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+ expected(2, 0) = 1;
+ expected(2, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
// Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
XlaBuilder b(TestName());
@@ -172,7 +253,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
XlaOp x, y;
auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
- b.And(x, y, /*broadcast_dimensions=*/{1, 2});
+ And(x, y, /*broadcast_dimensions=*/{1, 2});
Array3D<bool> expected(2, 2, 1);
expected(0, 0, 0) = false;
@@ -185,7 +266,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR1<float>({}), {2});
+ Broadcast(ConstantR1<float>(&b, {}), {2});
Array2D<float> expected(2, 0);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
@@ -193,7 +274,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
XlaBuilder b(TestName());
- b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
+ Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {0});
Array2D<float> expected(0, 3);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
@@ -209,14 +290,14 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
// dimensions.
XlaBuilder b(TestName());
- b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
- b.ConstantLiteral(*Literal::CreateR3<float>(
- {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
- /*broadcast_dimensions=*/{1, 2});
+ Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
auto expected =
- Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
- {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
+ LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
+ {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -260,9 +341,10 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
&r3_implicit_array, 1.0, 0.2, 56789);
- auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
- auto r3_parameter = builder.Parameter(1, r3_shape, "input");
- XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
+ auto r3_implicit_parameter =
+ Parameter(&builder, 0, r3_implicit_shape, "input");
+ auto r3_parameter = Parameter(&builder, 1, r3_shape, "input");
+ BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
spec.output_bounds[2]);
@@ -284,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
}
}
- auto expected = Literal::CreateR3FromArray3D(expected_array);
+ auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r3_implicit_global_data.get(), r3_global_data.get()},
@@ -306,10 +388,10 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
- b.Add(r3h, r1h);
+ Add(r3h, r1h);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
@@ -317,79 +399,81 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
- auto r1 =
- b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}}));
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1);
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -509,14 +593,14 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
&r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
auto r2_implicit_parameter1 =
- builder.Parameter(0, r2_implicit_shape1, "input0");
- auto r2_parameter = builder.Parameter(1, r2_shape, "input1");
+ Parameter(&builder, 0, r2_implicit_shape1, "input0");
+ auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1");
auto r2_implicit_parameter2 =
- builder.Parameter(2, r2_implicit_shape2, "input2");
+ Parameter(&builder, 2, r2_implicit_shape2, "input2");
XlaOp op1 =
BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
- XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
+ BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
@@ -530,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
*v = ApplyOpToFloats(spec.op2, tmp, v3);
});
- auto expected = Literal::CreateR2FromArray2D(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
@@ -544,80 +628,82 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}}));
- auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
- b.Add(r2, r1);
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}}));
- auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
- b.Add(r2, r1);
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantR1<float>({10, 20});
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r3, r1, {0});
+ auto r1 = ConstantR1<float>(&b, {10, 20});
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r3, r1, {0});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantR1<float>({10, 20});
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r1, r3, {1});
+ auto r1 = ConstantR1<float>(&b, {10, 20});
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r1, r3, {1});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
- auto r1 = b.ConstantR1<float>({10, 20});
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
- b.Add(r1, r3, {2});
+ auto r1 = ConstantR1<float>(&b, {10, 20});
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ Add(r1, r3, {2});
- auto expected =
- Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
XlaBuilder b(TestName());
- auto r1_0 = b.ConstantR1<float>({1000, 2000});
- auto r1_1 = b.ConstantR1<float>({100, 200});
- auto r1_2 = b.ConstantR1<float>({10, 20});
- auto r3 = b.ConstantLiteral(
- *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ auto r1_0 = ConstantR1<float>(&b, {1000, 2000});
+ auto r1_1 = ConstantR1<float>(&b, {100, 200});
+ auto r1_2 = ConstantR1<float>(&b, {10, 20});
+ auto r3 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
- r3 = b.Add(r1_0, r3, {0});
- r3 = b.Add(r3, r1_1, {1});
- r3 = b.Add(r1_2, r3, {2});
+ r3 = Add(r1_0, r3, {0});
+ r3 = Add(r3, r1_1, {1});
+ r3 = Add(r1_2, r3, {2});
}
- r3 = b.Mul(r3, b.ConstantR0<float>(-2));
+ r3 = Mul(r3, ConstantR0<float>(&b, -2));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
@@ -626,19 +712,19 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
XlaBuilder b(TestName());
- auto r1_0 = b.ConstantR1<float>({1000, 2000});
- auto r1_1 = b.ConstantR1<float>({100, 200});
- auto r1_2 = b.ConstantR1<float>({10, 20});
- auto r0 = b.ConstantR0<float>(3);
- auto r3 = b.Broadcast(r0, {2, 2, 2});
+ auto r1_0 = ConstantR1<float>(&b, {1000, 2000});
+ auto r1_1 = ConstantR1<float>(&b, {100, 200});
+ auto r1_2 = ConstantR1<float>(&b, {10, 20});
+ auto r0 = ConstantR0<float>(&b, 3);
+ auto r3 = Broadcast(r0, {2, 2, 2});
for (int i = 0; i < 3; ++i) {
- r3 = b.Add(r1_0, r3, {0});
- r3 = b.Add(r3, r1_1, {1});
- r3 = b.Add(r1_2, r3, {2});
+ r3 = Add(r1_0, r3, {0});
+ r3 = Add(r3, r1_1, {1});
+ r3 = Add(r1_2, r3, {2});
}
- r3 = b.Mul(r3, b.ConstantR0<float>(-1));
+ r3 = Mul(r3, ConstantR0<float>(&b, -1));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
@@ -650,10 +736,10 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
// results in a shape incompatible with the lhs [2, 3, 1].
XlaBuilder b(TestName());
- b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
- b.ConstantLiteral(*Literal::CreateR3<float>(
- {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
- /*broadcast_dimensions=*/{1, 2});
+ Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
auto result_status = Execute(&b, {});
EXPECT_FALSE(result_status.ok());
@@ -665,26 +751,26 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
// Test invalid broadcasting with [1, 2] and [2, 3] inputs.
XlaBuilder b(TestName());
- b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
- b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+ Add(ConstantR2<float>(&b, {{1.0, 2.0}}),
+ ConstantR2<float>(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
auto result_status = Execute(&b, {});
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(result_status.status().error_message(),
- HasSubstr("op BINOP_ADD with incompatible shapes"));
+ HasSubstr("op add with incompatible shapes"));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
// Test invalid broadcasting with [1, 2] and [2, 3] inputs.
XlaBuilder b(TestName());
- b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
- b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+ Add(ConstantR2<float>(&b, {{1.0, 2.0}}),
+ ConstantR2<float>(&b, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
auto result_status = Execute(&b, {});
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(result_status.status().error_message(),
- HasSubstr("op BINOP_ADD with incompatible shapes"));
+ HasSubstr("op add with incompatible shapes"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
index 51b9f0d3e3..c7b94b5bba 100644
--- a/tensorflow/compiler/xla/tests/broadcast_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -37,7 +37,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
// Test degenerate case of broadcasting a scalar into a scalar.
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {}), input, {}));
@@ -46,14 +46,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
- error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR0<float>(42.0),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
@@ -63,14 +63,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
// Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple
// to enable testing of the results.
@@ -86,18 +86,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
LiteralSlice(*result, {0}), error_spec_));
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
LiteralSlice(*result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
@@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
@@ -116,7 +116,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
// the dimensions, ie transpose.
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
@@ -125,15 +125,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(
- LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
- *result, error_spec_));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
@@ -143,15 +143,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
*result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>({1.0, 2.0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));
// Broadcast vector in dimension 1.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -166,8 +166,9 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -176,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
int64 r1_size = input_data.size();
std::iota(input_data.begin(), input_data.end(), 0.0f);
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(input_data)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));
// Broadcast vector in dimension 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -196,8 +197,9 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -207,7 +209,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
std::vector<float> r1_array(64, 42.0);
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<float>(r1_array)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));
// Broadcast vector in dimension 1.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -218,14 +220,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
+ EXPECT_TRUE(LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D(r4_array),
*result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
auto builder = HloComputation::Builder(TestName());
auto input = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
@@ -238,15 +240,16 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
auto builder = HloComputation::Builder(TestName());
Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D<float>(to_broadcast)));
+ LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));
// Broadcast vector in dimensions 2 and 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -260,8 +263,9 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -280,7 +284,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
}
}
auto input = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR3FromArray3D<float>(input_vals)));
+ LiteralUtil::CreateR3FromArray3D<float>(input_vals)));
// Broadcast vector in dimensions 2 and 3.
builder.AddInstruction(HloInstruction::CreateBroadcast(
@@ -291,8 +295,9 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*LiteralUtil::CreateR4FromArray4D<float>(expected),
+ *result, error_spec_));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index 5fd33b50c9..2086e38b91 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -34,7 +35,7 @@ class CallOpTest : public ClientLibraryTestBase {
protected:
XlaComputation CreateR0F32IdentityComputation() {
XlaBuilder builder("Identity");
- builder.Parameter(0, r0f32_, "x");
+ Parameter(&builder, 0, r0f32_, "x");
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -42,9 +43,9 @@ class CallOpTest : public ClientLibraryTestBase {
XlaComputation CreateR1S0F32AdditionComputation() {
XlaBuilder builder("Addition");
- auto x = builder.Parameter(0, r1s0f32_, "x");
- auto y = builder.Parameter(1, r1s0f32_, "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, r1s0f32_, "x");
+ auto y = Parameter(&builder, 1, r1s0f32_, "y");
+ Add(x, y);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -52,9 +53,9 @@ class CallOpTest : public ClientLibraryTestBase {
XlaComputation CreateR1S2F32AdditionComputation() {
XlaBuilder builder("Addition");
- auto x = builder.Parameter(0, r1s2f32_, "x");
- auto y = builder.Parameter(1, r1s2f32_, "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, r1s2f32_, "x");
+ auto y = Parameter(&builder, 1, r1s2f32_, "y");
+ Add(x, y);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -62,7 +63,7 @@ class CallOpTest : public ClientLibraryTestBase {
XlaComputation CreateR0F32TupleComputation() {
XlaBuilder builder("Tuple");
- builder.Tuple({builder.Parameter(0, r0f32_, "x")});
+ Tuple(&builder, {Parameter(&builder, 0, r0f32_, "x")});
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -76,8 +77,9 @@ class CallOpTest : public ClientLibraryTestBase {
XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32IdentityComputation();
- auto constant = builder.ConstantLiteral(*Literal::CreateR0<float>(42.0));
- builder.Call(callee, {constant});
+ auto constant =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR0<float>(42.0));
+ Call(&builder, callee, {constant});
ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
}
@@ -85,9 +87,9 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) {
XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S0F32AdditionComputation();
- auto x = builder.ConstantLiteral(*Literal::CreateR1<float>({}));
- auto y = builder.ConstantLiteral(*Literal::CreateR1<float>({}));
- builder.Call(callee, {x, y});
+ auto x = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ auto y = ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({}));
+ Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
}
@@ -95,9 +97,11 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) {
XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR1S2F32AdditionComputation();
- auto x = builder.ConstantLiteral(*Literal::CreateR1<float>({1.0f, 2.0f}));
- auto y = builder.ConstantLiteral(*Literal::CreateR1<float>({2.0f, 3.0f}));
- builder.Call(callee, {x, y});
+ auto x =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ auto y =
+ ConstantLiteral(&builder, *LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ Call(&builder, callee, {x, y});
ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
}
@@ -105,40 +109,40 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) {
XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) {
XlaBuilder builder("inner");
{
- auto x = builder.Parameter(0, r0f32_, "x");
- builder.Add(x, builder.ConstantR0<float>(1.0));
+ auto x = Parameter(&builder, 0, r0f32_, "x");
+ Add(x, ConstantR0<float>(&builder, 1.0));
}
TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build());
XlaBuilder builder2("outer");
{
- auto x = builder2.Parameter(0, r0f32_, "x");
- x = builder2.Call(inner, {x});
- x = builder2.Call(inner, {x});
- x = builder2.Call(inner, {x});
+ auto x = Parameter(&builder2, 0, r0f32_, "x");
+ x = Call(&builder2, inner, {x});
+ x = Call(&builder2, inner, {x});
+ x = Call(&builder2, inner, {x});
}
TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build());
XlaBuilder builder3("outermost");
{
- auto x = builder3.Parameter(0, r0f32_, "x");
- x = builder3.Call(outer, {x});
- x = builder3.Call(outer, {x});
- x = builder3.Call(outer, {x});
+ auto x = Parameter(&builder3, 0, r0f32_, "x");
+ x = Call(&builder3, outer, {x});
+ x = Call(&builder3, outer, {x});
+ x = Call(&builder3, outer, {x});
}
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> start,
- client_->TransferToServer(*Literal::CreateR0<float>(1.0f)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(1.0f)));
ComputeAndCompareR0<float>(&builder3, 10.0f, {start.get()}, ErrorSpec(0.0f));
}
XLA_TEST_F(CallOpTest, CallR0F32Tuple) {
XlaBuilder builder(TestName());
XlaComputation callee = CreateR0F32TupleComputation();
- auto elem = Literal::CreateR0<float>(42.0);
- auto tuple = Literal::MakeTuple({elem.get()});
- builder.Call(callee, {builder.ConstantLiteral(*elem)});
+ auto elem = LiteralUtil::CreateR0<float>(42.0);
+ auto tuple = LiteralUtil::MakeTuple({elem.get()});
+ Call(&builder, callee, {ConstantLiteral(&builder, *elem)});
ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
}
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
index 660ff0cad5..0bc8facfe2 100644
--- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -36,11 +36,11 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {};
TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XlaBuilder builder("add_two_params");
- auto param_literal = Literal::CreateR1<float>({1.1f, 2.2f});
+ auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
- auto p0 = builder.Parameter(0, param_literal->shape(), "param0");
- auto p1 = builder.Parameter(1, param_literal->shape(), "param1");
- auto add = builder.Add(p0, p1);
+ auto p0 = Parameter(&builder, 0, param_literal->shape(), "param0");
+ auto p1 = Parameter(&builder, 1, param_literal->shape(), "param1");
+ Add(p0, p1);
auto param0_data =
client_->TransferToServer(*param_literal).ConsumeValueOrDie();
@@ -77,20 +77,20 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
XlaBuilder builder("add_two_params");
- auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
- auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1");
- auto add = builder.Mul(p0, p1);
+ auto p0 = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto p1 = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {4}), "param1");
+ Mul(p0, p1);
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
auto computation = computation_status.ConsumeValueOrDie();
- auto f32_literal = Literal::CreateR0<float>(1.1f);
+ auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
- auto f32_4_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
auto f32_4_data =
client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
- auto u8_4_literal = Literal::CreateR1U8("hola");
+ auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
// Match
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index bf8ed4d9fb..ef784da457 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
@@ -156,7 +157,7 @@ string ClientLibraryTestBase::ExecuteToString(
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -294,7 +295,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -346,7 +347,7 @@ Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = Literal::ConvertF32ToBF16(expected);
+ converted_expected = LiteralUtil::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
@@ -388,7 +389,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8(
auto actual = actual_status.ConsumeValueOrDie();
// Turn the expected value into a literal.
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1U8(expected);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
VLOG(1) << "expected: " << expected_literal->ToString();
VLOG(1) << "actual: " << actual->ToString();
@@ -486,11 +487,11 @@ ClientLibraryTestBase::ComputeValueAndReference(
XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
XlaBuilder builder("relu");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
- auto z_value = builder.Parameter(0, shape, "z_value");
+ auto z_value = Parameter(&builder, 0, shape, "z_value");
auto zero = use_bfloat16_
- ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
- : builder.ConstantR0<float>(0.0f);
- builder.Max(z_value, zero);
+ ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
+ : ConstantR0<float>(&builder, 0.0f);
+ Max(z_value, zero);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -499,9 +500,9 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
XlaComputation ClientLibraryTestBase::CreateScalarMax() {
XlaBuilder builder("max");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
- auto x = builder.Parameter(0, shape, "x");
- auto y = builder.Parameter(1, shape, "y");
- builder.Max(x, y);
+ auto x = Parameter(&builder, 0, shape, "x");
+ auto y = Parameter(&builder, 1, shape, "y");
+ Max(x, y);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -510,13 +511,13 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() {
XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
XlaBuilder builder("relu_sensitivity");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
- auto activation = builder.Parameter(0, shape, "activation");
- auto backprop = builder.Parameter(1, shape, "backprop");
+ auto activation = Parameter(&builder, 0, shape, "activation");
+ auto backprop = Parameter(&builder, 1, shape, "backprop");
auto zero = use_bfloat16_
- ? builder.ConstantR0<bfloat16>(static_cast<bfloat16>(0.0f))
- : builder.ConstantR0<float>(0.0f);
- auto activation_gtz = builder.Gt(activation, zero);
- builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
+ ? ConstantR0<bfloat16>(&builder, static_cast<bfloat16>(0.0f))
+ : ConstantR0<float>(&builder, 0.0f);
+ auto activation_gtz = Gt(activation, zero);
+ Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
@@ -559,8 +560,9 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
- return builder->ConstantLiteral(
- use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
+ return ConstantLiteral(builder, use_bfloat16_
+ ? *LiteralUtil::ConvertF32ToBF16(literal)
+ : literal);
}
std::unique_ptr<GlobalData>
@@ -581,14 +583,14 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {
- converted_literal = Literal::ConvertF32ToBF16(literal);
+ converted_literal = LiteralUtil::ConvertF32ToBF16(literal);
param_literal = converted_literal.get();
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*param_literal, device_handle)
.ConsumeValueOrDie();
*data_handle =
- builder->Parameter(parameter_number, param_literal->shape(), name);
+ Parameter(builder, parameter_number, param_literal->shape(), name);
return data;
}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 0499fec589..fcc9347db5 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -284,7 +285,7 @@ class ClientLibraryTestBase : public ::testing::Test {
template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
- return AddParam(*Literal::CreateFromArray(argument), builder);
+ return AddParam(*LiteralUtil::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
@@ -299,13 +300,14 @@ class ClientLibraryTestBase : public ::testing::Test {
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
XlaBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
+ return CreateConstantFromLiteral(*LiteralUtil::CreateFromArray(array),
+ builder);
}
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
+ return CreateConstantFromLiteral(*LiteralUtil::CreateR0<NativeT>(value),
builder);
}
@@ -373,6 +375,13 @@ class ClientLibraryTestBase : public ::testing::Test {
// The float type used in this test, BF16 or F32 according to use_bfloat16.
PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; }
+ // Executes the computation and calculates the expected reference value using
+ // the reference client. Returns two literals in the order of (expected,
+ // actual).
+ StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+ ComputeValueAndReference(XlaBuilder* builder,
+ tensorflow::gtl::ArraySlice<Literal> arguments);
+
Client* client_;
Client* ref_client_; // To compute reference result.
ExecutionOptions execution_options_;
@@ -390,13 +399,6 @@ class ClientLibraryTestBase : public ::testing::Test {
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
- // Executes the computation and calculates the expected reference value using
- // the reference client. Returns two literals in the order of (expected,
- // actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<Literal> arguments);
-
// Whether to run tests with all float-type input/output converted to
// bfloat16.
bool use_bfloat16_ = false;
@@ -410,7 +412,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR0<NativeT>(expected);
+ LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -426,7 +428,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR0<NativeT>(expected);
+ LiteralUtil::CreateR0<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -436,7 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<NativeT>(expected);
+ LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -452,7 +454,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<NativeT>(expected);
+ LiteralUtil::CreateR1<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -462,7 +464,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
XlaBuilder* builder, const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2FromArray2D<NativeT>(expected);
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -478,7 +480,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2FromArray2D<NativeT>(expected);
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -488,7 +490,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
XlaBuilder* builder, const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR3FromArray3D<NativeT>(expected);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -504,7 +506,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR3FromArray3D<NativeT>(expected);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -514,7 +516,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
XlaBuilder* builder, const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR4FromArray4D<NativeT>(expected);
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments);
}
@@ -530,7 +532,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4(
std::is_same<NativeT, complex64>::value,
"Float or complex type required when specifying an ErrorSpec");
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR4FromArray4D<NativeT>(expected);
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
arguments, error);
}
@@ -539,13 +541,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR0(value);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
return data;
}
@@ -553,13 +555,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR1(values);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
return data;
}
@@ -567,13 +569,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
return data;
}
@@ -581,13 +583,13 @@ template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
- std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = Literal::ConvertF32ToBF16(*literal);
+ literal = LiteralUtil::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
- *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ *data_handle = Parameter(builder, parameter_number, literal->shape(), name);
return data;
}
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 08671cf624..6ce2f844a3 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -43,8 +43,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
for (const std::vector<int64>& execute_layout : layouts) {
for (const std::vector<int64>& transfer_layout : layouts) {
- b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}}));
+ Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
+ ConstantR2<int32>(&b, {{10, 20}, {30, 40}}));
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
ExecutionOptions execution_options = execution_options_;
@@ -56,7 +56,7 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
client_->Execute(computation, {}, &execution_options));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2WithLayout<int32>(
+ LiteralUtil::CreateR2WithLayout<int32>(
{{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
TF_ASSERT_OK_AND_ASSIGN(
@@ -72,8 +72,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
XlaBuilder b(TestName());
- b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}})});
+ Tuple(&b, {ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
+ ConstantR2<int32>(&b, {{10, 20}, {30, 40}})});
TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
@@ -112,13 +112,13 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<GlobalData> const_arg,
- client_->TransferToServer(*Literal::CreateR2<int32>({{5, 6}, {7, 8}})));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> const_arg,
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2<int32>({{5, 6}, {7, 8}})));
XlaBuilder b(TestName() + ".add");
- b.Add(b.Parameter(0, shape, "param_0"),
- b.ConstantR2<int32>({{1, 2}, {3, 4}}));
+ Add(Parameter(&b, 0, shape, "param_0"),
+ ConstantR2<int32>(&b, {{1, 2}, {3, 4}}));
TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build());
// We can't really test parallel execution on CPU since all of the cores in a
@@ -136,7 +136,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
TF_ASSERT_OK_AND_ASSIGN(auto results,
client_->ExecuteParallel(computation_instances));
- auto expected_result = Literal::CreateR2<int32>({{6, 8}, {10, 12}});
+ auto expected_result = LiteralUtil::CreateR2<int32>({{6, 8}, {10, 12}});
TF_ASSERT_OK_AND_ASSIGN(
auto result_literal,
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 50a0069648..ff38246286 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -50,7 +50,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
&execution_profile)
.ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR0<float>(expected_result), *result, error_spec_));
+ *LiteralUtil::CreateR0<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -67,7 +67,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>(expected_result), *result, error_spec_));
+ *LiteralUtil::CreateR2<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
@@ -77,7 +77,7 @@ class CompilationCacheTest : public ClientLibraryTestBase {
// TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XlaBuilder builder(TestName());
- builder.Neg(builder.ConstantR0<float>(42.0));
+ Neg(ConstantR0<float>(&builder, 42.0));
XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
@@ -89,17 +89,17 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) {
XLA_TEST_F(CompilationCacheTest,
DISABLED_ComputationCalledWithDifferentParameters) {
std::unique_ptr<GlobalData> data_42 =
- client_->TransferToServer(*Literal::CreateR0<float>(42.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_123 =
- client_->TransferToServer(*Literal::CreateR0<float>(123.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> data_456 =
- client_->TransferToServer(*Literal::CreateR0<float>(456.0f))
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
- builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Neg(Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param"));
XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
@@ -115,16 +115,16 @@ XLA_TEST_F(CompilationCacheTest,
// TODO(b/74197823): Disabled because there is no cache in the new design.
XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) {
XlaBuilder builder_neg(TestName() + "_neg");
- builder_neg.Neg(builder_neg.ConstantR0<float>(42.0));
+ Neg(ConstantR0<float>(&builder_neg, 42.0));
XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie();
XlaBuilder builder_exp(TestName() + "_exp");
- builder_exp.Exp(builder_exp.ConstantR0<float>(1.0));
+ Exp(ConstantR0<float>(&builder_exp, 1.0));
XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie();
XlaBuilder builder_add(TestName() + "_add");
- builder_add.Add(builder_add.ConstantR0<float>(2.0),
- builder_add.ConstantR0<float>(3.0));
+ Add(ConstantR0<float>(&builder_add, 2.0),
+ ConstantR0<float>(&builder_add, 3.0));
XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie();
ExecuteComputationR0F32(computation_neg, {}, -42.0,
@@ -143,18 +143,18 @@ XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) {
// layouts. Use these arrays as parameters to a simple computation. If the
// layout of the array changes then computation should be recompiled (cache
// miss).
- auto rowmaj_array = Literal::CreateR2WithLayout(
+ auto rowmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
auto rowmaj_handle =
client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
- auto colmaj_array = Literal::CreateR2WithLayout(
+ auto colmaj_array = LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
auto colmaj_handle =
client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecuteComputationR2F32(computation, {colmaj_handle.get()},
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index ba22530f1c..64bf8b3b38 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -99,7 +99,7 @@ TEST_F(ComputeConstantTest, ScalarInt32Literal) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.ConstantR0<int32>(42);
+ auto computation = ConstantR0<int32>(&b, 42);
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<int32>(client, computation, &b);
@@ -113,7 +113,7 @@ TEST_F(ComputeConstantTest, ScalarFloatAdd) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
+ Add(ConstantR0<float>(&b, 42.5f), ConstantR0<float>(&b, 1.5f));
EXPECT_TRUE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -127,8 +127,8 @@ TEST_F(ComputeConstantTest, ScalarRng) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
- ShapeUtil::MakeShape(F32, {}));
+ RngUniform(ConstantR0<float>(&b, 1.1f), ConstantR0<float>(&b, 2.1f),
+ ShapeUtil::MakeShape(F32, {}));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -141,7 +141,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
+ auto computation = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param");
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -156,8 +156,8 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR0<float>(1.0f),
- b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Add(ConstantR0<float>(&b, 1.0f),
+ Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "param"));
EXPECT_FALSE(IsConstant(computation, &b));
auto value = ComputeConstantScalar<float>(client, computation, &b);
@@ -174,18 +174,18 @@ TEST_F(ComputeConstantTest, UnrelatedParam) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto param_a = Parameter(&b, 10, ShapeUtil::MakeShape(F32, {}), "param0");
auto constant_4 =
- b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
- auto not_constant_a = b.Add(constant_4, param_a);
+ Add(ConstantR0<float>(&b, 2.5f), ConstantR0<float>(&b, 1.5f));
+ auto not_constant_a = Add(constant_4, param_a);
- auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
+ auto param_b = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "param1");
auto constant_9 =
- b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
- auto not_constant_b = b.Add(param_b, constant_9);
+ Mul(ConstantR0<float>(&b, 2.0f), ConstantR0<float>(&b, 4.5f));
+ auto not_constant_b = Add(param_b, constant_9);
- auto constant_13 = b.Add(constant_4, constant_9);
- b.Add(not_constant_b, b.Add(constant_13, not_constant_a));
+ auto constant_13 = Add(constant_4, constant_9);
+ Add(not_constant_b, Add(constant_13, not_constant_a));
EXPECT_TRUE(IsConstant(constant_13, &b));
@@ -201,13 +201,13 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
XlaBuilder b(TestName());
auto computation =
- b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
+ Add(ConstantR1<int32>(&b, {1, 2}), ConstantR1<int32>(&b, {3, 4}));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR1<int32>({4, 6});
+ LiteralUtil::CreateR1<int32>({4, 6});
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -216,12 +216,12 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
for (ClientType client_type : client_types) {
Client* client = ClientOrDie(platform_, client_type);
XlaBuilder b(TestName());
- auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
+ auto computation = Div(ConstantR0<int32>(&b, 15), ConstantR0<int32>(&b, 3));
EXPECT_TRUE(IsConstant(computation, &b));
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
- std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
@@ -237,13 +237,13 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
TF_ASSERT_OK_AND_ASSIGN(
auto computed, ComputeConstantLiteral(
client,
- b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
- b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ Add(ConstantR2<int32>(&b, {{1, 2}, {3, 4}}),
+ ConstantR2<int32>(&b, {{10, 20}, {30, 40}})),
&b, &layout_proto));
std::unique_ptr<Literal> expected_literal =
- Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
- LayoutUtil::MakeLayout(layout));
+ LiteralUtil::CreateR2WithLayout<int32>(
+ {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout));
ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
expected_literal->shape(), computed->shape()));
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index a4c8a83eb1..9f288634c0 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -39,7 +39,7 @@ using ::testing::HasSubstr;
// Concatenate expects at least one argument.
XLA_TEST_F(ConcatTest, Concat_Nothing) {
XlaBuilder builder(TestName());
- builder.ConcatInDim({}, 0);
+ ConcatInDim(&builder, {}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
@@ -49,8 +49,8 @@ XLA_TEST_F(ConcatTest, Concat_Nothing) {
// Concatenate with one argument works.
XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0, 64.0});
- builder.ConcatInDim({a}, 0);
+ auto a = ConstantR1<float>(&builder, {42.0, 64.0});
+ ConcatInDim(&builder, {a}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -58,8 +58,8 @@ XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- builder.ConcatInDim({a}, 0);
+ auto a = ConstantR1<float>(&builder, {});
+ ConcatInDim(&builder, {a}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -69,9 +69,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_Nothing) {
// to concatenate on.
XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR0<float>(42.0);
- auto b = builder.ConstantR0<float>(64.0);
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR0<float>(&builder, 42.0);
+ auto b = ConstantR0<float>(&builder, 64.0);
+ ConcatInDim(&builder, {a, b}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
@@ -80,9 +80,9 @@ XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({});
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {});
+ ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -90,9 +90,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({});
- auto b = builder.ConstantR1<float>({256.0});
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR1<float>(&builder, {});
+ auto b = ConstantR1<float>(&builder, {256.0});
+ ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -100,9 +100,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0, 64.0});
- auto b = builder.ConstantR1<float>({});
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR1<float>(&builder, {42.0, 64.0});
+ auto b = ConstantR1<float>(&builder, {});
+ ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {42, 64};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -110,9 +110,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0, 64.0});
- auto b = builder.ConstantR1<float>({256.0});
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR1<float>(&builder, {42.0, 64.0});
+ auto b = ConstantR1<float>(&builder, {256.0});
+ ConcatInDim(&builder, {a, b}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -130,9 +130,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>(lhs);
- auto b = builder.ConstantR1<float>(rhs);
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR1<float>(&builder, lhs);
+ auto b = ConstantR1<float>(&builder, rhs);
+ ConcatInDim(&builder, {a, b}, 0);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
@@ -140,9 +140,9 @@ XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
for (int dim : {0, 1}) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
- auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
- builder.ConcatInDim({a, b}, dim);
+ auto a = ConstantR2FromArray2D(&builder, Array2D<float>(0, 0));
+ auto b = ConstantR2FromArray2D(&builder, Array2D<float>(0, 0));
+ ConcatInDim(&builder, {a, b}, dim);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
ErrorSpec(0.0001));
@@ -153,9 +153,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
- auto a = builder.ConstantR2FromArray2D(*a_array);
- auto b = builder.ConstantR2FromArray2D(*b_array);
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR2FromArray2D(&builder, *a_array);
+ auto b = ConstantR2FromArray2D(&builder, *b_array);
+ ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected({
{0},
@@ -168,9 +168,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(1, 1);
auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
- auto a = builder.ConstantR2FromArray2D(*a_array);
- auto b = builder.ConstantR2FromArray2D(*b_array);
- builder.ConcatInDim({a, b}, 1);
+ auto a = ConstantR2FromArray2D(&builder, *a_array);
+ auto b = ConstantR2FromArray2D(&builder, *b_array);
+ ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected({
{0, 64},
@@ -181,9 +181,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
XlaBuilder builder(TestName());
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
- auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
- auto b = builder.ConstantR2FromArray2D(*b_array);
- builder.ConcatInDim({a, b}, 1);
+ auto a = ConstantR2FromArray2D(&builder, Array2D<float>(2, 0));
+ auto b = ConstantR2FromArray2D(&builder, *b_array);
+ ConcatInDim(&builder, {a, b}, 1);
ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
}
@@ -192,9 +192,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(2, 3);
auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
- auto a = builder.ConstantR2FromArray2D(*a_array);
- auto b = builder.ConstantR2FromArray2D(*b_array);
- builder.ConcatInDim({a, b}, 1);
+ auto a = ConstantR2FromArray2D(&builder, *a_array);
+ auto b = ConstantR2FromArray2D(&builder, *b_array);
+ ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected({
{0, 1, 2, 64, 65, 66, 67, 68},
@@ -206,9 +206,9 @@ XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
- auto a = builder.ConstantR2FromArray2D(*a_array);
- auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR2FromArray2D(&builder, *a_array);
+ auto b = ConstantR2FromArray2D(&builder, Array2D<float>(0, 2));
+ ConcatInDim(&builder, {a, b}, 0);
ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
}
@@ -217,9 +217,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
XlaBuilder builder(TestName());
auto a_array = CreatePatternedMatrix(3, 2);
auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
- auto a = builder.ConstantR2FromArray2D(*a_array);
- auto b = builder.ConstantR2FromArray2D(*b_array);
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR2FromArray2D(&builder, *a_array);
+ auto b = ConstantR2FromArray2D(&builder, *b_array);
+ ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected({
{0, 1},
@@ -236,9 +236,9 @@ XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
- auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
- builder.ConcatInDim({a, b}, 2);
+ auto a = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 2));
+ auto b = ConstantR3FromArray3D(&builder, Array3D<float>(3, 0, 1));
+ ConcatInDim(&builder, {a, b}, 2);
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
ErrorSpec(0.0001));
}
@@ -257,9 +257,9 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
{{7}},
{{8}},
});
- auto a = builder.ConstantR3FromArray3D(a_array);
- auto b = builder.ConstantR3FromArray3D(b_array);
- builder.ConcatInDim({a, b}, 2);
+ auto a = ConstantR3FromArray3D(&builder, a_array);
+ auto b = ConstantR3FromArray3D(&builder, b_array);
+ ConcatInDim(&builder, {a, b}, 2);
Array3D<float> expected({
{{0, 1, 6}},
@@ -271,10 +271,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0});
- auto b = builder.ConstantR1<float>({64.0});
- auto c = builder.ConstantR1<float>({256.0});
- builder.ConcatInDim({a, b, c}, 0);
+ auto a = ConstantR1<float>(&builder, {42.0});
+ auto b = ConstantR1<float>(&builder, {64.0});
+ auto c = ConstantR1<float>(&builder, {256.0});
+ ConcatInDim(&builder, {a, b, c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -300,10 +300,10 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
{{7}},
{{11}},
});
- auto a = builder.ConstantR3FromArray3D(a_array);
- auto b = builder.ConstantR3FromArray3D(b_array);
- auto c = builder.ConstantR3FromArray3D(c_array);
- builder.ConcatInDim({a, b, c}, 2);
+ auto a = ConstantR3FromArray3D(&builder, a_array);
+ auto b = ConstantR3FromArray3D(&builder, b_array);
+ auto c = ConstantR3FromArray3D(&builder, c_array);
+ ConcatInDim(&builder, {a, b, c}, 2);
Array3D<float> expected({
{{0, 1, 2, 3}},
@@ -315,11 +315,11 @@ XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0});
- auto b = builder.ConstantR1<float>({64.0});
- auto c = builder.ConstantR1<float>({256.0});
+ auto a = ConstantR1<float>(&builder, {42.0});
+ auto b = ConstantR1<float>(&builder, {64.0});
+ auto c = ConstantR1<float>(&builder, {256.0});
// concatenated = (a concat b) concat c
- builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
+ ConcatInDim(&builder, {ConcatInDim(&builder, {a, b}, 0), c}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -327,11 +327,11 @@ XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0});
- auto b = builder.ConstantR1<float>({64.0});
- auto c = builder.ConstantR1<float>({256.0});
+ auto a = ConstantR1<float>(&builder, {42.0});
+ auto b = ConstantR1<float>(&builder, {64.0});
+ auto c = ConstantR1<float>(&builder, {256.0});
// concatenated = a concat (b concat c)
- builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
+ ConcatInDim(&builder, {a, ConcatInDim(&builder, {b, c}, 0)}, 0);
std::vector<float> expected = {42, 64, 256};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -346,9 +346,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2FromArray2D<float>(lhs);
- auto b = builder.ConstantR2FromArray2D<float>(rhs);
- builder.ConcatInDim({a, b}, 0);
+ auto a = ConstantR2FromArray2D<float>(&builder, lhs);
+ auto b = ConstantR2FromArray2D<float>(&builder, rhs);
+ ConcatInDim(&builder, {a, b}, 0);
Array2D<float> expected(2, 1024);
for (int i = 0; i < 1024; ++i) {
@@ -367,9 +367,9 @@ XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2FromArray2D<float>(lhs);
- auto b = builder.ConstantR2FromArray2D<float>(rhs);
- builder.ConcatInDim({a, b}, 1);
+ auto a = ConstantR2FromArray2D<float>(&builder, lhs);
+ auto b = ConstantR2FromArray2D<float>(&builder, rhs);
+ ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected(1, 2048);
for (int i = 0; i < 1024; ++i) {
@@ -392,9 +392,9 @@ XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
}
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2FromArray2D<float>(lhs);
- auto b = builder.ConstantR2FromArray2D<float>(rhs);
- builder.ConcatInDim({a, b}, 1);
+ auto a = ConstantR2FromArray2D<float>(&builder, lhs);
+ auto b = ConstantR2FromArray2D<float>(&builder, rhs);
+ ConcatInDim(&builder, {a, b}, 1);
Array2D<float> expected(64, 66);
for (int i0 = 0; i0 < 64; ++i0) {
@@ -410,22 +410,37 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
XlaBuilder builder(TestName());
auto opaque_shape = ShapeUtil::MakeOpaqueShape();
auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
- auto x = builder.Parameter(0, r1f32, "x");
- auto y = builder.Parameter(1, opaque_shape, "y");
- builder.ConcatInDim({x, y}, 0);
+ auto x = Parameter(&builder, 0, r1f32, "x");
+ auto y = Parameter(&builder, 1, opaque_shape, "y");
+ ConcatInDim(&builder, {x, y}, 0);
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
- HasSubstr("Expected non-opaque argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
+}
+
+// Show that we can't concatenate with tokens.
+XLA_TEST_F(ConcatTest, CannotConcatTokens) {
+ XlaBuilder builder(TestName());
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
+ auto x = Parameter(&builder, 0, r1f32, "x");
+ auto y = Parameter(&builder, 1, token_shape, "y");
+ ConcatInDim(&builder, {x, y}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_THAT(
+ computation_status.status().ToString(),
+ HasSubstr("Expected array argument for operand of concatenation"));
}
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
XlaBuilder builder(TestName());
- auto p0 = builder.ConstantR1<bool>({true});
- auto p1 = builder.ConstantR1<bool>({false});
- auto p2 = builder.ConstantR1<bool>({true});
- builder.ConcatInDim({p0, p1, p2}, 0);
+ auto p0 = ConstantR1<bool>(&builder, {true});
+ auto p1 = ConstantR1<bool>(&builder, {false});
+ auto p2 = ConstantR1<bool>(&builder, {true});
+ ConcatInDim(&builder, {p0, p1, p2}, 0);
bool expected[] = {true, false, true};
ComputeAndCompareR1<bool>(&builder, expected, {});
@@ -433,11 +448,11 @@ XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
XlaBuilder builder(TestName());
- auto a0 = builder.ConstantR1<int32>({1});
- auto a1 = builder.ConstantR1<int32>({2, 3});
- auto a2 = builder.ConstantR1<int32>({4, 5, 6});
- auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
- builder.ConcatInDim({a0, a1, a2, a3}, 0);
+ auto a0 = ConstantR1<int32>(&builder, {1});
+ auto a1 = ConstantR1<int32>(&builder, {2, 3});
+ auto a2 = ConstantR1<int32>(&builder, {4, 5, 6});
+ auto a3 = ConstantR1<int32>(&builder, {7, 8, 9, 10});
+ ConcatInDim(&builder, {a0, a1, a2, a3}, 0);
std::vector<int32> expected(10);
std::iota(expected.begin(), expected.end(), 1);
@@ -472,7 +487,7 @@ XLA_TEST_F(ConcatTest, ConcatR3WeirdDims) {
auto p1 = CreateR3Parameter<float>(arr1, /*parameter_number=*/1, "p1",
&builder, &h1);
- builder.ConcatInDim({h0, h1}, 2);
+ ConcatInDim(&builder, {h0, h1}, 2);
ComputeAndCompareR3<float>(&builder, expected, {p0.get(), p1.get()});
}
@@ -499,9 +514,9 @@ TEST_P(ConcatR2BinaryTest, DoIt) {
rhs.FillUnique(1000);
XlaBuilder builder(TestName());
- auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
- auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
- builder.ConcatInDim({a0, a1}, spec.concat_dimension);
+ auto a0 = ConstantR2FromArray2D<int32>(&builder, lhs);
+ auto a1 = ConstantR2FromArray2D<int32>(&builder, rhs);
+ ConcatInDim(&builder, {a0, a1}, spec.concat_dimension);
std::unique_ptr<Array2D<int32>> expected =
ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
@@ -519,19 +534,19 @@ TEST_P(ConcatR2BinaryTest, DoIt) {
// concat
XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
- auto x_literal = Literal::CreateR0<float>(2.f);
- auto y_literal = Literal::CreateR0<float>(3.f);
+ auto x_literal = LiteralUtil::CreateR0<float>(2.f);
+ auto y_literal = LiteralUtil::CreateR0<float>(3.f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, f32_scalar, "x");
- auto y = builder.Parameter(1, f32_scalar, "y");
- auto mul = builder.Mul(x, y);
- auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f}));
- auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f}));
- auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f}));
- builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0);
+ auto x = Parameter(&builder, 0, f32_scalar, "x");
+ auto y = Parameter(&builder, 1, f32_scalar, "y");
+ auto mul = Mul(x, y);
+ auto add1 = Add(mul, ConstantR1<float>(&builder, {1.f, 2.f}));
+ auto add2 = Add(mul, ConstantR1<float>(&builder, {3.f, 4.f}));
+ auto add3 = Add(mul, ConstantR1<float>(&builder, {5.f, 6.f}));
+ ConcatInDim(&builder, {add1, add2, add3}, /*dimension=*/0);
ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
{x_data.get(), y_data.get()}, ErrorSpec(1e-4));
@@ -541,21 +556,21 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
// produces the correct result in rank 1.
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
- auto x_literal = Literal::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
- auto y_literal = Literal::CreateR0<float>(1.5f);
- auto z_literal = Literal::CreateR0<float>(5.5f);
+ auto x_literal = LiteralUtil::CreateR1<float>({2.0f, 3.0f, 5.0f, 6.0f});
+ auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
+ auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, x_literal->shape(), "x");
- auto y = builder.Parameter(1, f32_scalar, "y");
- auto z = builder.Parameter(2, f32_scalar, "z");
- auto bcast = builder.Broadcast(y, {5});
- auto bcast2 = builder.Broadcast(z, {3});
- auto concat = builder.ConcatInDim({bcast, x}, /*dimension=*/0);
- builder.ConcatInDim({concat, bcast2}, /*dimension=*/0);
+ auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto y = Parameter(&builder, 1, f32_scalar, "y");
+ auto z = Parameter(&builder, 2, f32_scalar, "z");
+ auto bcast = Broadcast(y, {5});
+ auto bcast2 = Broadcast(z, {3});
+ auto concat = ConcatInDim(&builder, {bcast, x}, /*dimension=*/0);
+ ConcatInDim(&builder, {concat, bcast2}, /*dimension=*/0);
ComputeAndCompareR1<float>(
&builder,
@@ -569,21 +584,21 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) {
XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) {
auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
Array3D<float> x3d(3, 5, 7, 3.14f);
- auto x_literal = Literal::CreateR3FromArray3D<float>(x3d);
- auto y_literal = Literal::CreateR0<float>(1.5f);
- auto z_literal = Literal::CreateR0<float>(5.5f);
+ auto x_literal = LiteralUtil::CreateR3FromArray3D<float>(x3d);
+ auto y_literal = LiteralUtil::CreateR0<float>(1.5f);
+ auto z_literal = LiteralUtil::CreateR0<float>(5.5f);
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, x_literal->shape(), "x");
- auto y = builder.Parameter(1, f32_scalar, "y");
- auto z = builder.Parameter(2, f32_scalar, "y");
- auto y_bcast = builder.Broadcast(y, {1, 5, 7});
- auto z_bcast = builder.Broadcast(z, {4, 1, 7});
- auto concat = builder.ConcatInDim({y_bcast, x}, /*dimension=*/0);
- builder.ConcatInDim({concat, z_bcast}, /*dimension=*/1);
+ auto x = Parameter(&builder, 0, x_literal->shape(), "x");
+ auto y = Parameter(&builder, 1, f32_scalar, "y");
+ auto z = Parameter(&builder, 2, f32_scalar, "y");
+ auto y_bcast = Broadcast(y, {1, 5, 7});
+ auto z_bcast = Broadcast(z, {4, 1, 7});
+ auto concat = ConcatInDim(&builder, {y_bcast, x}, /*dimension=*/0);
+ ConcatInDim(&builder, {concat, z_bcast}, /*dimension=*/1);
Array3D<float> y_bcast3d(1, 5, 7, 1.5f);
Array3D<float> z_bcast3d(4, 1, 7, 5.5f);
auto concat0 = ReferenceUtil::Concat3D(y_bcast3d, x3d, 0);
diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc
index 7ff6706935..35f1400fb2 100644
--- a/tensorflow/compiler/xla/tests/conditional_test.cc
+++ b/tensorflow/compiler/xla/tests/conditional_test.cc
@@ -26,8 +26,8 @@ class ConditionalOpTest : public ClientLibraryTestBase {
protected:
XlaComputation CreateR0ConstantComputation(float value) {
XlaBuilder builder("Constant");
- builder.Parameter(0, empty_tuple_, "tuple");
- builder.ConstantR0<float>(value);
+ Parameter(&builder, 0, empty_tuple_, "tuple");
+ ConstantR0<float>(&builder, value);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -35,7 +35,7 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateR0IdentityComputation() {
XlaBuilder builder("Identity");
- builder.Parameter(0, r0f32_, "x");
+ Parameter(&builder, 0, r0f32_, "x");
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -43,8 +43,8 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateCeilComputation(const Shape& shape) {
XlaBuilder builder("Ceil");
- auto param = builder.Parameter(0, shape, "param");
- builder.Ceil(param);
+ auto param = Parameter(&builder, 0, shape, "param");
+ Ceil(param);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -60,8 +60,8 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateFloorComputation(const Shape& shape) {
XlaBuilder builder("Floor");
- auto param = builder.Parameter(0, shape, "param");
- builder.Floor(param);
+ auto param = Parameter(&builder, 0, shape, "param");
+ Floor(param);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -78,12 +78,12 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateTupleCeilComputation(const string& computation_name,
const Shape& tuple_shape) {
XlaBuilder builder(computation_name);
- auto tuple = builder.Parameter(0, tuple_shape, "tuple");
- auto x = builder.GetTupleElement(tuple, 0);
- auto y = builder.GetTupleElement(tuple, 1);
- auto x_ceil = builder.Ceil(x);
- auto y_ceil = builder.Ceil(y);
- builder.Tuple({x_ceil, y_ceil});
+ auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
+ auto x = GetTupleElement(tuple, 0);
+ auto y = GetTupleElement(tuple, 1);
+ auto x_ceil = Ceil(x);
+ auto y_ceil = Ceil(y);
+ Tuple(&builder, {x_ceil, y_ceil});
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -100,12 +100,12 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateTupleFloorComputation(const string& computation_name,
const Shape& tuple_shape) {
XlaBuilder builder(computation_name);
- auto tuple = builder.Parameter(0, tuple_shape, "tuple");
- auto x = builder.GetTupleElement(tuple, 0);
- auto y = builder.GetTupleElement(tuple, 1);
- auto x_floor = builder.Floor(x);
- auto y_floor = builder.Floor(y);
- builder.Tuple({x_floor, y_floor});
+ auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
+ auto x = GetTupleElement(tuple, 0);
+ auto y = GetTupleElement(tuple, 1);
+ auto x_floor = Floor(x);
+ auto y_floor = Floor(y);
+ Tuple(&builder, {x_floor, y_floor});
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -122,10 +122,10 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateTupleAddComputation(const string& computation_name,
const Shape& tuple_shape) {
XlaBuilder builder(computation_name);
- auto tuple = builder.Parameter(0, tuple_shape, "tuple");
- auto x = builder.GetTupleElement(tuple, 0);
- auto y = builder.GetTupleElement(tuple, 1);
- builder.Add(x, y);
+ auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
+ auto x = GetTupleElement(tuple, 0);
+ auto y = GetTupleElement(tuple, 1);
+ Add(x, y);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -142,10 +142,10 @@ class ConditionalOpTest : public ClientLibraryTestBase {
XlaComputation CreateTupleSubComputation(const string& computation_name,
const Shape& tuple_shape) {
XlaBuilder builder(computation_name);
- auto tuple = builder.Parameter(0, tuple_shape, "tuple");
- auto x = builder.GetTupleElement(tuple, 0);
- auto y = builder.GetTupleElement(tuple, 1);
- builder.Sub(x, y);
+ auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
+ auto x = GetTupleElement(tuple, 0);
+ auto y = GetTupleElement(tuple, 1);
+ Sub(x, y);
auto build_status = builder.Build();
EXPECT_IS_OK(build_status.status());
return build_status.ConsumeValueOrDie();
@@ -172,12 +172,11 @@ class ConditionalOpTest : public ClientLibraryTestBase {
// Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest, Parameters0) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operands = builder.Tuple({});
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operands = Tuple(&builder, {});
auto true_computation = CreateR0ConstantComputation(56.0f);
auto false_computation = CreateR0ConstantComputation(12.0f);
- builder.Conditional(pred, operands, true_computation, operands,
- false_computation);
+ Conditional(pred, operands, true_computation, operands, false_computation);
ComputeAndCompareR0<float>(&builder, 56.0f, {}, error_spec_);
}
@@ -185,11 +184,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters0) {
// Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest, Parameters1) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.0f);
- auto operand2 = builder.ConstantR0<float>(12.0f);
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.0f);
+ auto operand2 = ConstantR0<float>(&builder, 12.0f);
auto identity = CreateR0IdentityComputation();
- builder.Conditional(pred, operand1, identity, operand2, identity);
+ Conditional(pred, operand1, identity, operand2, identity);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -198,11 +197,11 @@ XLA_TEST_F(ConditionalOpTest, Parameters1) {
// that take in different arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.4f);
- auto operand2 = builder.ConstantR0<float>(12.6f);
- builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
- CreateR0FloorComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.4f);
+ auto operand2 = ConstantR0<float>(&builder, 12.6f);
+ Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -211,10 +210,10 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
// that take in the same arguments.
XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand = builder.ConstantR0<float>(12.6f);
- builder.Conditional(pred, operand, CreateR0CeilComputation(), operand,
- CreateR0FloorComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand = ConstantR0<float>(&builder, 12.6f);
+ Conditional(pred, operand, CreateR0CeilComputation(), operand,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -223,11 +222,11 @@ XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
// take in different arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.4f);
- auto operand2 = builder.ConstantR0<float>(12.6f);
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.4f);
+ auto operand2 = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
- builder.Conditional(pred, operand1, floor, operand2, floor);
+ Conditional(pred, operand1, floor, operand2, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -236,10 +235,10 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
// take in the same arguments.
XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand = builder.ConstantR0<float>(12.6f);
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand = ConstantR0<float>(&builder, 12.6f);
auto floor = CreateR0FloorComputation();
- builder.Conditional(pred, operand, floor, operand, floor);
+ Conditional(pred, operand, floor, operand, floor);
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -248,11 +247,11 @@ XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
// and false cases.
XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.4f);
- auto operand2 = builder.ConstantR0<float>(12.6f);
- builder.Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
- CreateR0FloorComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.4f);
+ auto operand2 = ConstantR0<float>(&builder, 12.6f);
+ Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -261,19 +260,19 @@ XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
XlaBuilder inner_builder(TestName() + ".inner_conditional");
- auto pred_cond = inner_builder.Parameter(0, r0bool, "param0");
- auto true_operand = inner_builder.Parameter(1, r0f32_, "param1");
- auto false_operand = inner_builder.Parameter(2, r0f32_, "param2");
- inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
- false_operand, CreateR0FloorComputation());
+ auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
+ auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
+ auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
+ Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
+ CreateR0FloorComputation());
auto inner_builder_result = inner_builder.Build();
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.4f);
- auto operand2 = builder.ConstantR0<float>(12.6f);
- builder.Call(inner_builder_result.ConsumeValueOrDie(),
- {pred, operand1, operand2});
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.4f);
+ auto operand2 = ConstantR0<float>(&builder, 12.6f);
+ Call(&builder, inner_builder_result.ConsumeValueOrDie(),
+ {pred, operand1, operand2});
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -282,12 +281,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
// true.
XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operand1 = builder.ConstantR0<float>(56.0f);
- auto operand2 = builder.ConstantR0<float>(12.0f);
- auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
- CreateR0TupleSubComputation());
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operand1 = ConstantR0<float>(&builder, 56.0f);
+ auto operand2 = ConstantR0<float>(&builder, 12.0f);
+ auto operands = Tuple(&builder, {operand1, operand2});
+ Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 68.0f, {}, error_spec_);
}
@@ -296,12 +295,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
// false.
XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(56.0f);
- auto operand2 = builder.ConstantR0<float>(12.0f);
- auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
- CreateR0TupleSubComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 56.0f);
+ auto operand2 = ConstantR0<float>(&builder, 12.0f);
+ auto operands = Tuple(&builder, {operand1, operand2});
+ Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
ComputeAndCompareR0<float>(&builder, 44.0f, {}, error_spec_);
}
@@ -310,12 +309,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
// predicate is true.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
- auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
- auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
- CreateR1TupleSubComputation());
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
+ auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
+ auto operands = Tuple(&builder, {operand1, operand2});
+ Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {}, error_spec_);
}
@@ -324,12 +323,12 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
// predicate is false.
XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR1<float>({24.0f, 56.0f});
- auto operand2 = builder.ConstantR1<float>({10.0f, 11.0f});
- auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
- CreateR1TupleSubComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
+ auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
+ auto operands = Tuple(&builder, {operand1, operand2});
+ Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR1TupleSubComputation());
ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {}, error_spec_);
}
@@ -337,32 +336,34 @@ XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
// Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operands = builder.Tuple(
- {builder.ConstantR0<float>(12.2f), builder.ConstantR0<float>(25.6f)});
- builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
- CreateR0TupleFloorComputation());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
+ ConstantR0<float>(&builder, 25.6f)});
+ Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
+ CreateR0TupleFloorComputation());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(12.0f).get(),
- Literal::CreateR0<float>(25.0f).get()}),
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12.0f).get(),
+ LiteralUtil::CreateR0<float>(25.0f).get()}),
{}, error_spec_);
}
// Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operands = builder.Tuple({builder.ConstantR1<float>({12.2f, 15.8f}),
- builder.ConstantR1<float>({25.6f, 29.2f})});
- builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
- CreateR1TupleFloorComputation());
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operands =
+ Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
+ ConstantR1<float>(&builder, {25.6f, 29.2f})});
+ Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
+ CreateR1TupleFloorComputation());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR1<float>({13.0f, 16.0f}).get(),
- Literal::CreateR1<float>({26.0f, 30.0f}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({13.0f, 16.0f}).get(),
+ LiteralUtil::CreateR1<float>({26.0f, 30.0f}).get()}),
{}, error_spec_);
}
@@ -371,37 +372,38 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
XlaBuilder true_builder(TestName() + ".true");
{
- true_builder.Parameter(0, empty_tuple_, "tuple");
- auto true_pred = true_builder.ConstantR0<bool>(true);
- auto true_scalar = true_builder.ConstantR0<float>(12.2f);
- auto true_array = true_builder.ConstantR1<float>({12.8f, 14.6f});
- true_builder.Tuple({true_pred, true_scalar, true_array});
+ Parameter(&true_builder, 0, empty_tuple_, "tuple");
+ auto true_pred = ConstantR0<bool>(&true_builder, true);
+ auto true_scalar = ConstantR0<float>(&true_builder, 12.2f);
+ auto true_array = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
+ Tuple(&true_builder, {true_pred, true_scalar, true_array});
}
auto true_builder_result = true_builder.Build();
EXPECT_IS_OK(true_builder_result.status());
XlaBuilder false_builder(TestName() + ".false");
{
- false_builder.Parameter(0, empty_tuple_, "tuple");
- auto false_pred = false_builder.ConstantR0<bool>(false);
- auto false_scalar = false_builder.ConstantR0<float>(25.6f);
- auto false_array = false_builder.ConstantR1<float>({26.4f, 32.6f});
- false_builder.Tuple({false_pred, false_scalar, false_array});
+ Parameter(&false_builder, 0, empty_tuple_, "tuple");
+ auto false_pred = ConstantR0<bool>(&false_builder, false);
+ auto false_scalar = ConstantR0<float>(&false_builder, 25.6f);
+ auto false_array = ConstantR1<float>(&false_builder, {26.4f, 32.6f});
+ Tuple(&false_builder, {false_pred, false_scalar, false_array});
}
auto false_builder_result = false_builder.Build();
EXPECT_IS_OK(false_builder_result.status());
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operands = builder.Tuple({});
- builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
- operands, false_builder_result.ConsumeValueOrDie());
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operands = Tuple(&builder, {});
+ Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
+ false_builder_result.ConsumeValueOrDie());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<bool>(true).get(),
- Literal::CreateR0<float>(12.2f).get(),
- Literal::CreateR1<float>({12.8f, 14.6f}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<bool>(true).get(),
+ LiteralUtil::CreateR0<float>(12.2f).get(),
+ LiteralUtil::CreateR1<float>({12.8f, 14.6f}).get()}),
{}, error_spec_);
}
@@ -409,45 +411,48 @@ XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
XlaBuilder true_builder(TestName() + ".true");
{
- true_builder.Parameter(0, empty_tuple_, "tuple");
- auto true_constant1 = true_builder.ConstantR0<float>(12.2f);
- auto true_constant2 = true_builder.ConstantR1<float>({12.8f, 14.6f});
- auto true_constant3 = true_builder.ConstantR1<float>({25.4f, 29.8f});
- auto true_constant4 = true_builder.ConstantR0<float>(35.6f);
- true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}),
- true_builder.Tuple({true_constant3, true_constant4})});
+ Parameter(&true_builder, 0, empty_tuple_, "tuple");
+ auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
+ auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
+ auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
+ auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
+ Tuple(&true_builder,
+ {Tuple(&true_builder, {true_constant1, true_constant2}),
+ Tuple(&true_builder, {true_constant3, true_constant4})});
}
auto true_builder_result = true_builder.Build();
EXPECT_IS_OK(true_builder_result.status());
XlaBuilder false_builder(TestName() + ".false");
{
- false_builder.Parameter(0, empty_tuple_, "tuple");
- auto false_constant1 = false_builder.ConstantR0<float>(46.6f);
- auto false_constant2 = false_builder.ConstantR1<float>({54.4f, 58.4f});
- auto false_constant3 = false_builder.ConstantR1<float>({62.1f, 67.4f});
- auto false_constant4 = false_builder.ConstantR0<float>(9.3f);
- false_builder.Tuple(
- {false_builder.Tuple({false_constant1, false_constant2}),
- false_builder.Tuple({false_constant3, false_constant4})});
+ Parameter(&false_builder, 0, empty_tuple_, "tuple");
+ auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
+ auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
+ auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
+ auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
+ Tuple(&false_builder,
+ {Tuple(&false_builder, {false_constant1, false_constant2}),
+ Tuple(&false_builder, {false_constant3, false_constant4})});
}
auto false_builder_result = false_builder.Build();
EXPECT_IS_OK(false_builder_result.status());
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto operands = builder.Tuple({});
- builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(),
- operands, false_builder_result.ConsumeValueOrDie());
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto operands = Tuple(&builder, {});
+ Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), operands,
+ false_builder_result.ConsumeValueOrDie());
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple(
- {Literal::MakeTuple({Literal::CreateR0<float>(46.6f).get(),
- Literal::CreateR1<float>({54.4f, 58.4f}).get()})
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(46.6f).get(),
+ LiteralUtil::CreateR1<float>({54.4f, 58.4f}).get()})
.get(),
- Literal::MakeTuple({Literal::CreateR1<float>({62.1f, 67.4f}).get(),
- Literal::CreateR0<float>(9.3f).get()})
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({62.1f, 67.4f}).get(),
+ LiteralUtil::CreateR0<float>(9.3f).get()})
.get()}),
{}, error_spec_);
}
@@ -464,8 +469,8 @@ XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
auto operand2_param =
CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
- builder.Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
- CreateR0FloorComputation());
+ Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
+ CreateR0FloorComputation());
ComputeAndCompareR0<float>(
&builder, 57.0f,
@@ -484,8 +489,8 @@ XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
&builder, &operand1);
auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
&builder, &operand2);
- builder.Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
- CreateR1FloorComputation());
+ Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
+ CreateR1FloorComputation());
ComputeAndCompareR1<float>(
&builder, {10.0f, 11.0f},
@@ -499,27 +504,25 @@ XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
{
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
- auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
- auto pred_cond = inner_builder.GetTupleElement(param0, 0);
- auto true_operand = inner_builder.GetTupleElement(param0, 1);
- auto false_operand = inner_builder.GetTupleElement(param0, 2);
- inner_builder.Conditional(pred_cond, true_operand,
- CreateR0CeilComputation(), false_operand,
- CreateR0FloorComputation());
+ auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
+ auto pred_cond = GetTupleElement(param0, 0);
+ auto true_operand = GetTupleElement(param0, 1);
+ auto false_operand = GetTupleElement(param0, 2);
+ Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
+ false_operand, CreateR0FloorComputation());
}
auto inner_builder_result = inner_builder.Build();
EXPECT_IS_OK(inner_builder_result.status());
XlaBuilder builder(TestName());
- auto pred1 = builder.ConstantR0<bool>(true);
- auto pred2 = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(1.1f);
- auto operand2 = builder.ConstantR0<float>(12.2f);
- auto operand3 = builder.ConstantR0<float>(43.3f);
- auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
- builder.Conditional(pred1, tuple_operand,
- inner_builder_result.ConsumeValueOrDie(), operand3,
- CreateR0IdentityComputation());
+ auto pred1 = ConstantR0<bool>(&builder, true);
+ auto pred2 = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 1.1f);
+ auto operand2 = ConstantR0<float>(&builder, 12.2f);
+ auto operand3 = ConstantR0<float>(&builder, 43.3f);
+ auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
+ Conditional(pred1, tuple_operand, inner_builder_result.ConsumeValueOrDie(),
+ operand3, CreateR0IdentityComputation());
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -529,23 +532,22 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
{
Shape r0bool = ShapeUtil::MakeShape(PRED, {});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
- auto param0 = inner_builder.Parameter(0, tuple_shape, "param0");
- auto pred_cond = inner_builder.GetTupleElement(param0, 0);
- auto true_operand = inner_builder.GetTupleElement(param0, 1);
- auto false_operand = inner_builder.GetTupleElement(param0, 2);
- inner_builder.Conditional(pred_cond, true_operand,
- CreateR0CeilComputation(), false_operand,
- CreateR0FloorComputation());
+ auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
+ auto pred_cond = GetTupleElement(param0, 0);
+ auto true_operand = GetTupleElement(param0, 1);
+ auto false_operand = GetTupleElement(param0, 2);
+ Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
+ false_operand, CreateR0FloorComputation());
}
auto inner_builder_result = inner_builder.Build();
EXPECT_IS_OK(inner_builder_result.status());
XlaBuilder builder(TestName());
- auto pred2 = builder.ConstantR0<bool>(false);
- auto operand1 = builder.ConstantR0<float>(1.1f);
- auto operand2 = builder.ConstantR0<float>(12.2f);
- auto tuple_operand = builder.Tuple({pred2, operand1, operand2});
- builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
+ auto pred2 = ConstantR0<bool>(&builder, false);
+ auto operand1 = ConstantR0<float>(&builder, 1.1f);
+ auto operand2 = ConstantR0<float>(&builder, 12.2f);
+ auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
+ Call(&builder, inner_builder_result.ConsumeValueOrDie(), {tuple_operand});
ComputeAndCompareR0<float>(&builder, 12.0f, {}, error_spec_);
}
@@ -553,12 +555,12 @@ XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
// Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto operand1 = builder.ConstantR0<float>(56.0f);
- auto operand2 = builder.ConstantR0<float>(12.0f);
- auto operands = builder.Tuple({operand1, operand2});
- builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
- CreateR0TupleSubComputation());
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto operand1 = ConstantR0<float>(&builder, 56.0f);
+ auto operand2 = ConstantR0<float>(&builder, 12.0f);
+ auto operands = Tuple(&builder, {operand1, operand2});
+ Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
+ CreateR0TupleSubComputation());
auto result = builder.Build();
EXPECT_FALSE(result.ok());
@@ -572,45 +574,45 @@ XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
XlaComputation swapper;
{
XlaBuilder builder(TestName() + ".swapper");
- auto param0 = builder.Parameter(0, tuple_shape, "sp0");
- auto x = builder.GetTupleElement(param0, 0);
- auto y = builder.GetTupleElement(param0, 1);
- builder.Tuple({y, x});
+ auto param0 = Parameter(&builder, 0, tuple_shape, "sp0");
+ auto x = GetTupleElement(param0, 0);
+ auto y = GetTupleElement(param0, 1);
+ Tuple(&builder, {y, x});
swapper = builder.Build().ConsumeValueOrDie();
}
XlaComputation forwarder;
{
XlaBuilder builder(TestName() + ".forwarder");
- auto param0 = builder.Parameter(0, tuple_shape, "fp0");
- auto x = builder.GetTupleElement(param0, 0);
- auto y = builder.GetTupleElement(param0, 1);
- builder.Tuple({x, y});
+ auto param0 = Parameter(&builder, 0, tuple_shape, "fp0");
+ auto x = GetTupleElement(param0, 0);
+ auto y = GetTupleElement(param0, 1);
+ Tuple(&builder, {x, y});
forwarder = builder.Build().ConsumeValueOrDie();
}
XlaComputation main;
{
XlaBuilder builder(TestName() + ".main");
- auto param0 = builder.Parameter(0, tuple_shape, "mp0");
- auto x = builder.GetTupleElement(param0, 0);
- auto y = builder.GetTupleElement(param0, 1);
- auto lt_pred = builder.Lt(x, y);
- auto res = builder.Conditional(lt_pred, param0, forwarder, param0, swapper);
- auto ge_pred = builder.Ge(x, y);
- builder.Conditional(ge_pred, res, swapper, res, forwarder);
+ auto param0 = Parameter(&builder, 0, tuple_shape, "mp0");
+ auto x = GetTupleElement(param0, 0);
+ auto y = GetTupleElement(param0, 1);
+ auto lt_pred = Lt(x, y);
+ auto res = Conditional(lt_pred, param0, forwarder, param0, swapper);
+ auto ge_pred = Ge(x, y);
+ Conditional(ge_pred, res, swapper, res, forwarder);
main = builder.Build().ConsumeValueOrDie();
}
auto test_swap = [&](float a, float b) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR0<float>(a);
- auto y = builder.ConstantR0<float>(b);
- auto tuple_operand = builder.Tuple({x, y});
- builder.Call(main, {tuple_operand});
+ auto x = ConstantR0<float>(&builder, a);
+ auto y = ConstantR0<float>(&builder, b);
+ auto tuple_operand = Tuple(&builder, {x, y});
+ Call(&builder, main, {tuple_operand});
ComputeAndCompareTuple(
&builder,
- *Literal::MakeTuple({Literal::CreateR0<float>(a).get(),
- Literal::CreateR0<float>(b).get()}),
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(a).get(),
+ LiteralUtil::CreateR0<float>(b).get()}),
{}, error_spec_);
};
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
index 916ffadbc7..71d72a9828 100644
--- a/tensorflow/compiler/xla/tests/constants_test.cc
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -39,7 +40,7 @@ class ConstantsTest : public ClientLibraryTestBase {
TEST_F(ConstantsTest, ZeroCellF32) {
XlaBuilder builder(TestName());
- builder.ConstantR1<float>({});
+ ConstantR1<float>(&builder, {});
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -48,7 +49,7 @@ TEST_F(ConstantsTest, OneCellF32) {
std::vector<float> constant = {2.0};
XlaBuilder builder(TestName());
- builder.ConstantR1<float>(constant);
+ ConstantR1<float>(&builder, constant);
ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
@@ -57,7 +58,7 @@ TEST_F(ConstantsTest, OneCellS32) {
std::vector<int32> constant = {2};
XlaBuilder builder(TestName());
- builder.ConstantR1<int32>(constant);
+ ConstantR1<int32>(&builder, constant);
ComputeAndCompareR1<int32>(&builder, constant, {});
}
@@ -66,7 +67,7 @@ TEST_F(ConstantsTest, OneCellU32) {
std::vector<uint32> constant = {2};
XlaBuilder builder(TestName());
- builder.ConstantR1<uint32>(constant);
+ ConstantR1<uint32>(&builder, constant);
ComputeAndCompareR1<uint32>(&builder, constant, {});
}
@@ -75,7 +76,7 @@ TEST_F(ConstantsTest, EightCells) {
std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
XlaBuilder builder(TestName());
- builder.ConstantR1<float>(constant);
+ ConstantR1<float>(&builder, constant);
ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
@@ -85,14 +86,14 @@ TEST_F(ConstantsTest, SixteenCells) {
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
XlaBuilder builder(TestName());
- builder.ConstantR1<float>(constant);
+ ConstantR1<float>(&builder, constant);
ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
}
TEST_F(ConstantsTest, Empty_0x2) {
XlaBuilder builder(TestName());
- builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 2));
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
}
@@ -102,15 +103,15 @@ TEST_F(ConstantsTest, Small_2x2) {
MakeLinspaceArray2D(100.0, 200.0, 2, 2);
XlaBuilder builder(TestName());
- builder.ConstantR2FromArray2D<float>(*constant);
+ ConstantR2FromArray2D<float>(&builder, *constant);
ComputeAndCompareR2<float>(&builder, *constant, {}, error_spec_);
}
TEST_F(ConstantsTest, Empty_3x0x2) {
XlaBuilder builder(TestName());
- auto constant = builder.ConstantLiteral(
- *Literal::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2)));
+ ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(
+ Array3D<float>(3, 0, 2)));
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
}
@@ -125,8 +126,7 @@ TEST_F(ConstantsTest, Small_2x2x2) {
{{5.f, 6.f}, // y0
{7.f, 8.f}}, // y1
});
- auto constant =
- builder.ConstantLiteral(*Literal::CreateR3FromArray3D<float>(array3d));
+ ConstantLiteral(&builder, *LiteralUtil::CreateR3FromArray3D<float>(array3d));
ComputeAndCompareR3<float>(&builder, array3d, {});
}
@@ -141,17 +141,17 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
});
input_array.FillWithPZ(pz);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4D(input_array);
+ LiteralUtil::CreateR4FromArray4D(input_array);
{
XlaBuilder builder(TestName());
- builder.ConstantLiteral(*input_literal);
+ ConstantLiteral(&builder, *input_literal);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
{
XlaBuilder builder(TestName());
- builder.ConstantR4FromArray4D<float>(input_array);
+ ConstantR4FromArray4D<float>(&builder, input_array);
ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
}
}
@@ -159,17 +159,26 @@ TEST_F(ConstantsTest, Small_3x2x1x1) {
// TODO(b/29263943): Support tuple constants.
TEST_F(ConstantsTest, DISABLED_TupleConstant) {
XlaBuilder builder(TestName());
- builder.ConstantLiteral(
- *Literal::MakeTuple({Literal::CreateR2<float>({{1.0}, {2.0}}).get(),
- Literal::CreateR1<float>({2.0, 42}).get()}));
+ ConstantLiteral(&builder,
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
std::unique_ptr<Literal> result =
ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie();
- LiteralTestUtil::ExpectR2Near<float>(
- {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_);
- LiteralTestUtil::ExpectR1Near<float>(
- {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_);
+ LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
+ LiteralSlice(*result, {0}), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, LiteralSlice(*result, {1}),
+ error_spec_);
+}
+
+TEST_F(ConstantsTest, Token) {
+ XlaBuilder builder(TestName());
+ ConstantLiteral(&builder, *LiteralUtil::CreateToken());
+ // TODO(b/80000000): tokens cannot be returned from computations.
+ Tuple(&builder, {});
+ TF_ASSERT_OK(Execute(&builder, {}).status());
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 722d882471..dca57fd1c7 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -45,8 +45,8 @@ class ConvertTest : public ClientLibraryTestBase {
TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({42, 64});
- builder.ConvertElementType(a, S32);
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ ConvertElementType(a, S32);
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -54,8 +54,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.0f, 64.0f});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
+ ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -63,8 +63,8 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({42, 64});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ ConvertElementType(a, F32);
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -72,8 +72,8 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({true, false, true});
- builder.ConvertElementType(a, S32);
+ auto a = ConstantR1<bool>(&builder, {true, false, true});
+ ConvertElementType(a, S32);
std::vector<int32> expected = {1, 0, 1};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -81,8 +81,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) {
TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({true, false, true});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<bool>(&builder, {true, false, true});
+ ConvertElementType(a, F32);
std::vector<float> expected = {1., 0., 1.};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -90,8 +90,8 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) {
XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>({});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<int32>(&builder, {});
+ ConvertElementType(a, F32);
std::vector<float> expected = {};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -99,8 +99,8 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({42.6, 64.4});
- builder.ConvertElementType(a, S32);
+ auto a = ConstantR1<float>(&builder, {42.6, 64.4});
+ ConvertElementType(a, S32);
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -145,12 +145,12 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
static_cast<int64>(0x8000008000000000LL),
static_cast<int64>(0x8000010000000000LL),
};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int64>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int64>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, F32);
+ ConvertElementType(arg_param, F32);
std::vector<float> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -164,12 +164,12 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) {
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff,
0x80000000, 0x80000001, 0x80000002, 0x80000003,
0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, F32);
+ ConvertElementType(arg_param, F32);
std::vector<float> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -182,12 +182,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XlaBuilder builder(TestName());
std::vector<float> arg{0.0f, 1.0f, 16777216.0f,
16777218.0f, 2147483647.0f, 4294967040.0f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, U32);
+ ConvertElementType(arg_param, U32);
std::vector<uint32> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -199,12 +199,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<uint32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, S64);
+ ConvertElementType(arg_param, S64);
std::vector<int64> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -216,12 +216,12 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) {
XlaBuilder builder(TestName());
std::vector<int32> arg{0, 1, 0x1000, -1, -0x1000};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<int32>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<int32>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, S64);
+ ConvertElementType(arg_param, S64);
std::vector<int64> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -253,12 +253,12 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
9223370937343148032.f,
-9223371487098961920.f,
-9223370937343148032.f};
- std::unique_ptr<Literal> arg_literal = Literal::CreateR1<float>({arg});
- auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR1<float>({arg});
+ auto arg_param = Parameter(&builder, 0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
client_->TransferToServer(*arg_literal).ConsumeValueOrDie();
- builder.ConvertElementType(arg_param, S64);
+ ConvertElementType(arg_param, S64);
std::vector<int64> expected(arg.size());
for (int64 i = 0; i < arg.size(); ++i) {
@@ -269,8 +269,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) {
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint8_t>({32, 64});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<uint8_t>(&builder, {32, 64});
+ ConvertElementType(a, F32);
std::vector<float> expected = {32.0, 64.0};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -278,8 +278,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint8_t>({32, 64});
- builder.ConvertElementType(a, S32);
+ auto a = ConstantR1<uint8_t>(&builder, {32, 64});
+ ConvertElementType(a, S32);
std::vector<int32_t> expected = {32, 64};
ComputeAndCompareR1<int32_t>(&builder, expected, {});
@@ -287,8 +287,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<uint8_t>({32, 64});
- builder.ConvertElementType(a, U32);
+ auto a = ConstantR1<uint8_t>(&builder, {32, 64});
+ ConvertElementType(a, U32);
std::vector<uint32_t> expected = {32, 64};
ComputeAndCompareR1<uint32_t>(&builder, expected, {});
@@ -296,8 +296,8 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<float>({32.0f, 64.0f});
- builder.ConvertElementType(a, F64);
+ auto a = ConstantR1<float>(&builder, {32.0f, 64.0f});
+ ConvertElementType(a, F64);
std::vector<double> expected = {32.0, 64.0};
ComputeAndCompareR1<double>(&builder, expected, {});
@@ -305,8 +305,8 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<double>({32.0, 64.0});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<double>(&builder, {32.0, 64.0});
+ ConvertElementType(a, F32);
std::vector<float> expected = {32.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -314,9 +314,9 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
TEST_F(ConvertTest, ConvertS32Extremes) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<int32>(
- {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
- builder.ConvertElementType(a, F32);
+ auto a = ConstantR1<int32>(&builder, {std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max()});
+ ConvertElementType(a, F32);
std::vector<float> expected = {
static_cast<float>(std::numeric_limits<int32>::min()),
@@ -327,10 +327,10 @@ TEST_F(ConvertTest, ConvertS32Extremes) {
TEST_F(ConvertTest, ConvertMapToS32) {
XlaBuilder builder(TestName());
auto b = builder.CreateSubBuilder("convert");
- auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
- b->ConvertElementType(param, S32);
- auto a = builder.ConstantR1<float>({42.0f, 64.0f});
- builder.Map({a}, b->BuildAndNoteError(), {0});
+ auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "in");
+ ConvertElementType(param, S32);
+ auto a = ConstantR1<float>(&builder, {42.0f, 64.0f});
+ Map(&builder, {a}, b->BuildAndNoteError(), {0});
std::vector<int32> expected = {42, 64};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -339,10 +339,10 @@ TEST_F(ConvertTest, ConvertMapToS32) {
TEST_F(ConvertTest, ConvertMapToF32) {
XlaBuilder builder(TestName());
auto b = builder.CreateSubBuilder("convert");
- auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
- b->ConvertElementType(param, F32);
- auto a = builder.ConstantR1<int32>({42, 64});
- builder.Map({a}, b->BuildAndNoteError(), {0});
+ auto param = Parameter(b.get(), 0, ShapeUtil::MakeShape(S32, {}), "in");
+ ConvertElementType(param, F32);
+ auto a = ConstantR1<int32>(&builder, {42, 64});
+ Map(&builder, {a}, b->BuildAndNoteError(), {0});
std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -355,9 +355,9 @@ TEST_F(ConvertTest, ConvertMapToF32) {
// the new convert should have the same element type as the old convert.
TEST_F(ConvertTest, ConvertReshape) {
XlaBuilder builder(TestName());
- auto input = builder.ConstantR1<int32>({42});
- auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
- builder.ConvertElementType(reshape, F32);
+ auto input = ConstantR1<int32>(&builder, {42});
+ auto reshape = Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
+ ConvertElementType(reshape, F32);
ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
}
@@ -391,13 +391,13 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<half>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<half>(input)));
XlaBuilder builder(TestName());
- builder.ConvertElementType(
- builder.Parameter(
- 0, ShapeUtil::MakeShape(F16, {static_cast<int64>(input.size())}),
- "param"),
+ ConvertElementType(
+ Parameter(&builder, 0,
+ ShapeUtil::MakeShape(F16, {static_cast<int64>(input.size())}),
+ "param"),
F32);
ComputeAndCompareR1<float>(&builder, expected_output, {dot_lhs_handle.get()});
@@ -411,13 +411,13 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> dot_lhs_handle,
- client_->TransferToServer(*Literal::CreateR1<float>(input)));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>(input)));
XlaBuilder builder(TestName());
- builder.ConvertElementType(
- builder.Parameter(
- 0, ShapeUtil::MakeShape(F32, {static_cast<int64>(input.size())}),
- "param"),
+ ConvertElementType(
+ Parameter(&builder, 0,
+ ShapeUtil::MakeShape(F32, {static_cast<int64>(input.size())}),
+ "param"),
F16);
ComputeAndCompareR1<half>(&builder, expected_output, {dot_lhs_handle.get()});
@@ -426,28 +426,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) {
XLA_TEST_F(ConvertTest, ConvertC64ToC64) {
XlaBuilder builder(TestName());
std::vector<complex64> x = {{42.0f, 64.0f}};
- builder.ConvertElementType(builder.ConstantR1<complex64>(x), C64);
+ ConvertElementType(ConstantR1<complex64>(&builder, x), C64);
ComputeAndCompareR1<complex64>(&builder, x, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(ConvertTest, ConvertS64S64) {
XlaBuilder builder(TestName());
std::vector<int64> x = {{-42, 64}};
- builder.ConvertElementType(builder.ConstantR1<int64>(x), S64);
+ ConvertElementType(ConstantR1<int64>(&builder, x), S64);
ComputeAndCompareR1<int64>(&builder, x, {});
}
XLA_TEST_F(ConvertTest, ConvertU64U64) {
XlaBuilder builder(TestName());
std::vector<uint64> x = {{42, 64}};
- builder.ConvertElementType(builder.ConstantR1<uint64>(x), U64);
+ ConvertElementType(ConstantR1<uint64>(&builder, x), U64);
ComputeAndCompareR1<uint64>(&builder, x, {});
}
XLA_TEST_F(ConvertTest, ConvertU64S64) {
XlaBuilder builder(TestName());
std::vector<uint64> unsigned_x = {{42, UINT64_MAX}};
- builder.ConvertElementType(builder.ConstantR1<uint64>(unsigned_x), S64);
+ ConvertElementType(ConstantR1<uint64>(&builder, unsigned_x), S64);
std::vector<int64> signed_x = {{42, -1}};
ComputeAndCompareR1<int64>(&builder, signed_x, {});
}
@@ -455,11 +455,31 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) {
XLA_TEST_F(ConvertTest, ConvertS64U64) {
XlaBuilder builder(TestName());
std::vector<int64> signed_x = {{42, -1, INT64_MIN}};
- builder.ConvertElementType(builder.ConstantR1<int64>(signed_x), U64);
+ ConvertElementType(ConstantR1<int64>(&builder, signed_x), U64);
std::vector<uint64> unsigned_x = {
{42, UINT64_MAX, tensorflow::MathUtil::IPow<uint64>(2, 63)}};
ComputeAndCompareR1<uint64>(&builder, unsigned_x, {});
}
+XLA_TEST_F(ConvertTest, ConvertBF16F32) {
+ XlaBuilder builder(TestName());
+
+ std::vector<bfloat16> all_bfloats(1 << 16);
+ for (int i = 0; i < all_bfloats.size(); ++i) {
+ all_bfloats[i].value = i;
+ }
+
+ std::vector<uint32> expected(all_bfloats.size());
+ for (int i = 0; i < expected.size(); ++i) {
+ expected[i] = (1U << 16) * i;
+ }
+
+ // Exhaustively test all bf16 to f32 conversions.
+ xla::XlaOp all_bfloats_bf16 = ConstantR1<bfloat16>(&builder, all_bfloats);
+ xla::XlaOp all_bfloats_f32 = ConvertElementType(all_bfloats_bf16, F32);
+ BitcastConvertType(all_bfloats_f32, U32);
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index b5a42e3059..944366410b 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -93,14 +93,15 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
weight_array->FillWithMultiples(0.2);
auto weight_data =
- client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array))
+ client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(*input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, *input_array);
auto weight =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight");
- auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid);
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight");
+ auto conv1 = Conv(input, weight, {1, 1}, Padding::kValid);
ConvolutionDimensionNumbers dim_nums =
XlaBuilder::CreateDefaultConvDimensionNumbers();
@@ -117,8 +118,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
dim_nums.set_kernel_input_feature_dimension(
dim_nums.kernel_output_feature_dimension());
dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim);
- builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid,
- dim_nums);
+ ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums);
auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array,
{1, 1}, Padding::kValid);
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 947959beb1..a8b8f74ca9 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase {
#if XLA_TEST_BACKEND_GPU
// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
// convolution. So relax the absolute error threshold.
- ErrorSpec error_spec_ = ErrorSpec(1e-2);
+ ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4);
#else
- ErrorSpec error_spec_ = ErrorSpec(1e-4);
+ ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4);
#endif
};
@@ -89,9 +89,9 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
ASSERT_EQ(2, arhs->height());
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR4FromArray4D<T>(*alhs);
- auto rhs = builder.ConstantR4FromArray4D<T>(*arhs);
- builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
+ auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
+ Conv(lhs, rhs, {1, 1}, Padding::kValid);
ComputeAndCompare(&builder, {}, error_spec_);
}
@@ -109,9 +109,9 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<T> input_data(1, 1, 1, 2);
input_data.FillWithYX(Array2D<T>({
@@ -123,8 +123,8 @@ class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -140,9 +140,9 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({
@@ -157,8 +157,8 @@ class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest {
{7.0f, 8.0f},
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -174,9 +174,9 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({
@@ -192,8 +192,8 @@ class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -210,9 +210,9 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 1, 3, 3});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<T> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<T>({{1.0f, 2.0f, 3.0f, 4.0f},
@@ -224,8 +224,8 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
{{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}}));
// clang-format on
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
};
@@ -238,9 +238,9 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1}, Padding::kValid);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1}, Padding::kValid);
}
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
@@ -249,10 +249,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
Array3D<float> expected({{{510, 610, 710, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -268,10 +268,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
{
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
+ ConvGeneralDilated(
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
@@ -284,10 +284,10 @@ class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -304,10 +304,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
+ ConvGeneralDilated(
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
@@ -319,10 +319,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -335,10 +335,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
+ ConvGeneralDilated(
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
@@ -350,10 +350,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<float>(&builder, expected,
@@ -369,10 +369,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Convolution dimensions are bf0_oi0->bo0.
- builder.ConvGeneralDilated(
+ ConvGeneralDilated(
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
@@ -386,10 +386,10 @@ class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
auto input_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(input))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
.ConsumeValueOrDie();
auto filter_literal =
- client_->TransferToServer(*Literal::CreateR3FromArray3D(filter))
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR3<T>(&builder, expected,
@@ -408,8 +408,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
{
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Tensorflow dimension numbers for 3D convolution.
ConvolutionDimensionNumbers dnums;
@@ -429,21 +429,20 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
dnums.set_kernel_input_feature_dimension(3);
dnums.set_kernel_output_feature_dimension(4);
- builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid,
- dnums);
+ ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid, dnums);
}
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
iota(input_elems.begin(), input_elems.end(), 1.0f);
- auto input_r1 = Literal::CreateR1<float>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota(filter_elems.begin(), filter_elems.end(), 1.0f);
- auto filter_r1 = Literal::CreateR1<float>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<float>(
+ auto expected_r1 = LiteralUtil::CreateR1<float>(
{19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie();
@@ -475,8 +474,8 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
{
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Tensorflow dimension numbers for 2D convolution.
ConvolutionDimensionNumbers dnums;
@@ -493,21 +492,20 @@ class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest {
dnums.set_kernel_input_feature_dimension(2);
dnums.set_kernel_output_feature_dimension(3);
- builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid,
- dnums);
+ ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums);
}
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape));
iota_int_init_value(input_elems, 1);
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape));
iota_int_init_value(filter_elems, 1);
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
- auto expected_r1 = Literal::CreateR1<T>(
+ auto expected_r1 = LiteralUtil::CreateR1<T>(
{static_cast<T>(92115), static_cast<T>(93150), static_cast<T>(94185)});
auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie();
@@ -541,8 +539,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29});
Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
ConvolutionDimensionNumbers dnums;
dnums.set_input_feature_dimension(0);
@@ -551,7 +549,7 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
dnums.set_kernel_output_feature_dimension(1);
dnums.set_output_batch_dimension(0);
dnums.set_output_feature_dimension(1);
- builder.ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);
+ ConvWithGeneralDimensions(input, filter, {}, Padding::kValid, dnums);
Array2D<float> param0(4, 29);
param0.FillUnique();
@@ -563,8 +561,8 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization,
expected_result.Fill(0);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(param0)),
- std::move(*Literal::CreateFromArray(param1))},
+ {std::move(*LiteralUtil::CreateFromArray(param0)),
+ std::move(*LiteralUtil::CreateFromArray(param1))},
error_spec_);
}
@@ -599,8 +597,8 @@ class Convolve1D1WindowTestBase
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
{
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
// Tensorflow dimension numbers for 1D convolution.
ConvolutionDimensionNumbers dnums;
@@ -614,24 +612,23 @@ class Convolve1D1WindowTestBase
dnums.set_kernel_input_feature_dimension(1);
dnums.set_kernel_output_feature_dimension(2);
- builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid,
- dnums);
+ ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
}
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
static_cast<T>(1.0f));
- auto input_r1 = Literal::CreateR1<T>(input_elems);
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
static_cast<T>(1.0f));
- auto filter_r1 = Literal::CreateR1<T>(filter_elems);
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
std::vector<T> expect_elems(batch * output_feature * num_windows,
static_cast<T>(window_size * input_feature));
- auto expected_r1 = Literal::CreateR1<T>(expect_elems);
+ auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
auto expected_r3 =
expected_r1->Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
@@ -726,9 +723,9 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<bfloat16> input_data(1, 1, 1, 2);
input_data.FillWithYX(Array2D<bfloat16>({
@@ -740,8 +737,8 @@ XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
}));
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))},
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))},
error_spec_);
}
@@ -754,9 +751,9 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
XlaBuilder builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> input_data(1, 1, 1, 2);
input_data.FillIota(0);
@@ -764,8 +761,8 @@ XLA_TEST_F(ConvolutionTest, NoCudnnAlgorithmPicker) {
filter_data.FillIota(10);
ComputeAndCompare(&builder,
- {std::move(*Literal::CreateFromArray(input_data)),
- std::move(*Literal::CreateFromArray(filter_data))});
+ {std::move(*LiteralUtil::CreateFromArray(input_data)),
+ std::move(*LiteralUtil::CreateFromArray(filter_data))});
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index fea850dc13..8792e7781b 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -55,12 +55,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
XlaBuilder builder(TestName());
const Array4D<float> input_array(1, 1, 1, 1, {2});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {3});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
const Array4D<float> expected(1, 1, 1, 1, {6});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -70,12 +70,12 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
XlaBuilder builder(TestName());
const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {2});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
const Array4D<float> expected(5, 1, 1, 1, {2, 4, 6, 8, 10});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -86,12 +86,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
Array4D<float> input_array(2, 1, 3, 4);
input_array.FillWithMultiples(1);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {2.3});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(2, 1, 3, 4);
expected.FillWithMultiples(2.3);
@@ -102,12 +102,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 2, 1, 1, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 3, 1, 1, {12, 34, 56});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -117,12 +117,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 2, {1, 2});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 1, {12});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -132,12 +132,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 2, {12, 23});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -147,12 +147,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 2, 1, {12, 34});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -162,12 +162,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 2, 1, {10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 2, {13, 24});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -177,12 +177,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 2, 2, {1000, 100, 10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 1, {1234});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -194,13 +194,13 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
Array4D<float> input_array(
2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0
0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(
2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(
2, 2, 2, 2,
@@ -213,12 +213,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {10});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 2}, Padding::kValid);
+ Conv(input, filter, {1, 2}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 2, {10, 30});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -228,12 +228,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {10});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 2}, Padding::kValid);
+ Conv(input, filter, {1, 2}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 3, {10, 30, 50});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -243,12 +243,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 2}, Padding::kValid);
+ Conv(input, filter, {1, 2}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 1, {123});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -258,12 +258,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 2}, Padding::kValid);
+ Conv(input, filter, {1, 2}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 2, {123, 345});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -273,12 +273,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 1, {10});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {2, 2}, Padding::kValid);
+ Conv(input, filter, {2, 2}, Padding::kValid);
Array4D<float> expected(1, 1, 2, 2, {10, 30, 70, 90});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -288,12 +288,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 1, {1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 3, {10, 20, 30});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> expected(1, 1, 1, 1, {20});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -303,12 +303,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> expected(1, 1, 1, 3, {123, 1230, 12300});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -318,15 +318,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 3, 3,
{10000, 0, 1000, // row 0
0, 100, 0, // row 1
10, 0, 1}); // row 2
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> expected(1, 1, 2, 2, {104, 230, 2300, 10400});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -336,12 +336,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 2, 1, 1, {10, 1});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> expected(1, 1, 1, 2, {13, 24});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -351,12 +351,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 2, 2, {7, 13, 17, 23});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 2, 2, {216, 276, 396, 456});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -366,12 +366,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
const Array4D<float> filter_array(1, 1, 1, 2, {7, 13});
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 1, 1, 2, {33, 53});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -383,15 +383,15 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
std::vector<float> input_data(64);
std::iota(input_data.begin(), input_data.end(), 0.0);
Array4D<float> input_array(1, 1, 8, 8, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(128);
std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0);
std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0);
const Array4D<float> filter_array(2, 1, 8, 8, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 2, 1, 1, {2016, 4032});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -403,14 +403,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
std::vector<float> input_data(16 * 1 * 1 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(16, 1, 1, 1, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * 1 * 1);
std::iota(filter_data.begin(), filter_data.end(), 1.0);
const Array4D<float> filter_array(1, 1, 1, 1, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data = {1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16};
@@ -432,14 +432,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
}
}
}
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * ky * kx);
std::iota(filter_data.begin(), filter_data.end(), 1.0);
const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data(bs);
for (int i = 0; i < bs; ++i) {
@@ -463,14 +463,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
}
}
}
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * ky * kx);
std::iota(filter_data.begin(), filter_data.end(), 1.0);
const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data = {
23,
@@ -492,14 +492,14 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
}
}
}
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * 8 * 8);
std::iota(filter_data.begin(), filter_data.end(), 1.0);
const Array4D<float> filter_array(1, 1, 8, 8, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data = {
19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224,
@@ -515,7 +515,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
std::vector<float> input_data(2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
Array4D<float> input_array(1, 2, 8, 8, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(2 * 2 * 8 * 8);
std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
@@ -527,9 +527,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
4.0);
const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(1, 2, 1, 1, {14240, 30496});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -541,7 +541,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
std::vector<float> input_data(2 * 2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
Array4D<float> input_array(2, 2, 8, 8, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(2 * 2 * 8 * 8);
std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
@@ -553,9 +553,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
4.0);
const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(2, 2, 1, 1, {14240, 30496, 38816, 87840});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
@@ -567,7 +567,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
std::vector<float> input_data(32 * 2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
Array4D<float> input_array(32, 2, 8, 8, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(2 * 2 * 8 * 8);
std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
@@ -579,9 +579,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
4.0);
const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::vector<float> expected_data = {
14240, 30496, 38816, 87840, 63392, 145184, 87968,
@@ -613,9 +613,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
}
}
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> expected(16, 16, 1, 1);
for (int i0 = 0; i0 < 16; ++i0) {
@@ -635,9 +635,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) {
Array4D<float> input_array(1, 1, 4, 6, input_data);
Array4D<float> filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneralDilated(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2},
XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -654,9 +654,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneralDilated(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{},
XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -677,9 +677,9 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
200, 20, 2, //
300, 30, 3, //
400, 40, 4});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneralDilated(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1},
/*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2},
/*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -699,9 +699,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneral(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-1, -1}},
XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -718,9 +718,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneral(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-1, 2}},
XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -737,9 +737,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneral(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {2, -1}},
XlaBuilder::CreateDefaultConvDimensionNumbers());
@@ -756,9 +756,9 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneralDilated(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {3, 2}},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
@@ -781,9 +781,9 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
Array4D<float> input_array(1, 1, 1, 5, input_data);
Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.ConvGeneralDilated(
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-3, -2}},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
@@ -821,9 +821,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
input_array, filter_array, {1, 1}, Padding::kValid);
@@ -854,9 +854,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
input_array, filter_array, {1, 1}, Padding::kValid);
@@ -887,9 +887,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
input_array, filter_array, {1, 1}, Padding::kValid);
@@ -920,9 +920,9 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
input_array, filter_array, {1, 1}, Padding::kValid);
@@ -954,9 +954,9 @@ XLA_TEST_F(ConvolutionVariantsTest,
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
XlaBuilder builder(TestName());
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
- builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
+ Conv(input, filter, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
input_array, filter_array, {1, 1}, Padding::kValid);
@@ -970,12 +970,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(1, 2, 3, 1, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 2 * 1 * 1);
std::iota(filter_data.begin(), filter_data.end(), 1.0);
Array4D<float> filter_array(1, 2, 1, 1, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
ConvolutionDimensionNumbers dnums;
// NHWC input format.
@@ -995,7 +995,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
dnums.set_kernel_output_feature_dimension(3);
// Tests padding sizes that don't correspond either to SAME or VALID padding.
- builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+ ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
std::vector<float> expected_data = {
0, 0, 0, 0, 0, 0, 0, //
@@ -1014,12 +1014,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(1, 2, 3, 1, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * 1 * 1);
std::iota(filter_data.begin(), filter_data.end(), 2.0);
Array4D<float> filter_array(1, 1, 1, 1, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
ConvolutionDimensionNumbers dnums;
// NHWC input format.
@@ -1039,7 +1039,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
dnums.set_kernel_output_feature_dimension(3);
// Tests padding sizes that don't correspond either to SAME or VALID padding.
- builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+ ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
std::vector<float> expected_data = {
0, 0, 0, 0, 0, 0, 0, 0, //
@@ -1058,12 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(1, 2, 3, 1, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * 1 * 1);
std::iota(filter_data.begin(), filter_data.end(), 2.0);
Array4D<float> filter_array(1, 1, 1, 1, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
ConvolutionDimensionNumbers dnums;
// NHWC input format.
@@ -1083,7 +1083,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
dnums.set_kernel_output_feature_dimension(3);
// Tests zero padding sizes. This can use matmul for computation.
- builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+ ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
std::vector<float> expected_data = {
2, 4, 6, //
@@ -1099,12 +1099,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
std::vector<float> input_data(1 * 2 * 3 * 2);
std::iota(input_data.begin(), input_data.end(), 1.0);
Array4D<float> input_array(1, 2, 3, 2, input_data);
- auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto input = ConstantR4FromArray4D<float>(&builder, input_array);
std::vector<float> filter_data(1 * 1 * 2 * 3);
std::iota(filter_data.begin(), filter_data.end(), 2.0);
Array4D<float> filter_array(1, 1, 2, 3, filter_data);
- auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ auto filter = ConstantR4FromArray4D<float>(&builder, filter_array);
ConvolutionDimensionNumbers dnums;
// NHWC input format.
@@ -1124,7 +1124,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
dnums.set_kernel_output_feature_dimension(3);
// Tests zero padding sizes. This can use matmul for computation.
- builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+ ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
std::vector<float> expected_data = {
12, 15, 18, //
@@ -1148,14 +1148,14 @@ XLA_TEST_F(ConvolutionVariantsTest,
BackwardInputLowPaddingLessThanHighPadding) {
XlaBuilder builder(TestName());
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
- auto weights = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 2, /*values=*/{5, 6}));
- auto mirrored_weights = builder.Rev(weights, {2, 3});
- builder.ConvWithGeneralPadding(gradients, mirrored_weights,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {1, 0}});
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 2, /*values=*/{5, 6}));
+ auto mirrored_weights = Rev(weights, {2, 3});
+ ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 0}});
ComputeAndCompareR4<float>(&builder, {{{{5, 16, 27}}}}, {}, error_spec_);
}
@@ -1167,16 +1167,16 @@ XLA_TEST_F(ConvolutionVariantsTest,
BackwardInputLowPaddingGreaterThanHighPadding) {
XlaBuilder builder(TestName());
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
- auto weights = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
- auto mirrored_weights = builder.Rev(weights, {2, 3});
- builder.ConvGeneralDilated(gradients, mirrored_weights,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {0, 3}},
- /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
- XlaBuilder::CreateDefaultConvDimensionNumbers());
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = Rev(weights, {2, 3});
+ ConvGeneralDilated(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 3}},
+ /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
ComputeAndCompareR4<float>(&builder, {{{{100, 0}}}}, {}, error_spec_);
}
@@ -1187,14 +1187,14 @@ XLA_TEST_F(ConvolutionVariantsTest,
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
XlaBuilder builder(TestName());
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
- auto weights = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
- auto mirrored_weights = builder.Rev(weights, {2, 3});
- builder.ConvWithGeneralPadding(gradients, mirrored_weights,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {1, 1}});
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = Rev(weights, {2, 3});
+ ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 1}});
ComputeAndCompareR4<float>(&builder, {{{{10}}}}, {}, error_spec_);
}
@@ -1208,14 +1208,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
XlaBuilder builder(TestName());
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
- auto weights = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 2, /*values=*/{1, 10}));
- auto mirrored_weights = builder.Rev(weights, {2, 3});
- builder.ConvWithGeneralPadding(gradients, mirrored_weights,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {0, 2}});
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 2, /*values=*/{1, 10}));
+ auto mirrored_weights = Rev(weights, {2, 3});
+ ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 2}});
ComputeAndCompareR4<float>(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_);
}
@@ -1229,17 +1229,17 @@ XLA_TEST_F(ConvolutionVariantsTest,
// weight gradients: 24,130,240
//
// This pattern will be fused to backward convolution with padding=(1,2).
- auto activations = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
- auto forward_conv = builder.ConvGeneralDilated(
- activations, gradients,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {1, 2}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- XlaBuilder::CreateDefaultConvDimensionNumbers());
- builder.Transpose(forward_conv, {0, 1, 2, 3});
+ auto activations = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv =
+ ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 2}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
+ Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_);
}
@@ -1255,17 +1255,17 @@ XLA_TEST_F(ConvolutionVariantsTest,
// This pattern will be fused to backward convolution with padding=(2,1).
// Note: both (2,1) and (2,0) are valid padding for the backward convolution
// because the stride is 2.
- auto activations = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
- auto forward_conv = builder.ConvGeneralDilated(
- activations, gradients,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {2, 0}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- XlaBuilder::CreateDefaultConvDimensionNumbers());
- builder.Transpose(forward_conv, {0, 1, 2, 3});
+ auto activations = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv =
+ ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 0}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
+ Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_);
}
@@ -1282,17 +1282,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
// because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN
// supports even padding only -- using (2,1) would need extra effort of
// canonicalization.
- auto activations = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
- auto gradients = builder.ConstantR4FromArray4D<float>(
- Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
- auto forward_conv = builder.ConvGeneralDilated(
- activations, gradients,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {2, 1}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- XlaBuilder::CreateDefaultConvDimensionNumbers());
- builder.Transpose(forward_conv, {0, 1, 2, 3});
+ auto activations = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = ConstantR4FromArray4D<float>(
+ &builder, Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv =
+ ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
+ Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_);
}
@@ -1300,14 +1300,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
XlaBuilder builder(TestName());
- auto gradients = builder.ConstantR3FromArray3D<float>(
- Array3D<float>(1, 1, 1, /*value=*/1));
+ auto gradients = ConstantR3FromArray3D<float>(
+ &builder, Array3D<float>(1, 1, 1, /*value=*/1));
auto weights =
- builder.ConstantR3FromArray3D<float>(Array3D<float>({{{1, 10, 100}}}));
- auto mirrored_weights = builder.Rev(weights, {2});
- builder.ConvWithGeneralPadding(gradients, mirrored_weights,
- /*window_strides=*/{1},
- /*padding=*/{{1, 1}});
+ ConstantR3FromArray3D<float>(&builder, Array3D<float>({{{1, 10, 100}}}));
+ auto mirrored_weights = Rev(weights, {2});
+ ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1},
+ /*padding=*/{{1, 1}});
ComputeAndCompareR3<float>(&builder, {{{10}}}, {}, error_spec_);
}
@@ -1315,17 +1315,17 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
XlaBuilder builder(TestName());
auto activations =
- builder.ConstantR3FromArray3D<float>(Array3D<float>({{{1, 2, 3, 4}}}));
+ ConstantR3FromArray3D<float>(&builder, Array3D<float>({{{1, 2, 3, 4}}}));
auto gradients =
- builder.ConstantR3FromArray3D<float>(Array3D<float>({{{100, 10, 1}}}));
+ ConstantR3FromArray3D<float>(&builder, Array3D<float>({{{100, 10, 1}}}));
auto forward_conv =
- builder.ConvGeneralDilated(activations, gradients,
- /*window_strides=*/{1},
- /*padding=*/{{2, 1}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{2},
- XlaBuilder::CreateDefaultConvDimensionNumbers(
- /*num_spatial_dims=*/1));
- builder.Transpose(forward_conv, {0, 1, 2});
+ ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1},
+ /*padding=*/{{2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers(
+ /*num_spatial_dims=*/1));
+ Transpose(forward_conv, {0, 1, 2});
ComputeAndCompareR3<float>(&builder, {{{13, 24, 130}}}, {}, error_spec_);
}
@@ -1333,52 +1333,52 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
XlaBuilder builder(TestName());
- auto gradients_flat = Literal::CreateR1<float>({1});
+ auto gradients_flat = LiteralUtil::CreateR1<float>({1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto gradients = builder.ConstantLiteral(*gradients_literal);
+ auto gradients = ConstantLiteral(&builder, *gradients_literal);
- auto weights_flat = Literal::CreateR1<float>({1, 10, 100});
+ auto weights_flat = LiteralUtil::CreateR1<float>({1, 10, 100});
auto weights_literal =
weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto weights = builder.ConstantLiteral(*weights_literal);
+ auto weights = ConstantLiteral(&builder, *weights_literal);
- auto expected_flat = Literal::CreateR1<float>({10});
+ auto expected_flat = LiteralUtil::CreateR1<float>({10});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie();
- auto mirrored_weights = builder.Rev(weights, {2, 3, 4});
- builder.ConvWithGeneralPadding(gradients, mirrored_weights,
- /*window_strides=*/{1, 1, 1},
- /*padding=*/{{0, 0}, {0, 0}, {1, 1}});
+ auto mirrored_weights = Rev(weights, {2, 3, 4});
+ ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1, 1},
+ /*padding=*/{{0, 0}, {0, 0}, {1, 1}});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
XlaBuilder builder(TestName());
- auto activations_flat = Literal::CreateR1<float>({1, 2, 3, 4});
+ auto activations_flat = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie();
- auto activations = builder.ConstantLiteral(*activations_literal);
+ auto activations = ConstantLiteral(&builder, *activations_literal);
- auto gradients_flat = Literal::CreateR1<float>({100, 10, 1});
+ auto gradients_flat = LiteralUtil::CreateR1<float>({100, 10, 1});
auto gradients_literal =
gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto gradients = builder.ConstantLiteral(*gradients_literal);
+ auto gradients = ConstantLiteral(&builder, *gradients_literal);
- auto expected_flat = Literal::CreateR1<float>({13, 24, 130});
+ auto expected_flat = LiteralUtil::CreateR1<float>({13, 24, 130});
auto expected_literal =
expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie();
- auto forward_conv = builder.ConvGeneralDilated(
- activations, gradients,
- /*window_strides=*/{1, 1, 1},
- /*padding=*/{{0, 0}, {0, 0}, {2, 1}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2},
- XlaBuilder::CreateDefaultConvDimensionNumbers(
- /*num_spatial_dims=*/3));
- builder.Transpose(forward_conv, {0, 1, 2, 3, 4});
+ auto forward_conv =
+ ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1, 1, 1},
+ /*padding=*/{{0, 0}, {0, 0}, {2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers(
+ /*num_spatial_dims=*/3));
+ Transpose(forward_conv, {0, 1, 2, 3, 4});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
index 2b3390ca98..1dc6ff0f4f 100644
--- a/tensorflow/compiler/xla/tests/copy_test.cc
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -58,37 +58,38 @@ class CopyOpTest : public HloTestBase {
};
XLA_TEST_F(CopyOpTest, CopyR0Bool) {
- TestCopyOp(*Literal::CreateR0<bool>(true));
+ TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
}
XLA_TEST_F(CopyOpTest, CopyR1S0U32) {
- TestCopyOp(*Literal::CreateR1<uint32>({}));
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
}
XLA_TEST_F(CopyOpTest, CopyR1S3U32) {
- TestCopyOp(*Literal::CreateR1<uint32>({1, 2, 3}));
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
}
XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
- TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
- {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+ TestCopyOp(
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
- TestCopyOp(*Literal::CreateR4(
+ TestCopyOp(*LiteralUtil::CreateR4(
{{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
{{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
}
XLA_TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
- TestCopyOp(*Literal::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+ TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
}
XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
auto builder = HloComputation::Builder(TestName());
// Copy literal to device to use as parameter.
- auto literal = Literal::CreateR0<float>(42.0);
+ auto literal = LiteralUtil::CreateR0<float>(42.0);
Shape shape = literal->shape();
auto param0 = builder.AddInstruction(
@@ -109,7 +110,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) {
XLA_TEST_F(CopyOpTest, CopyConstantR2Twice) {
auto builder = HloComputation::Builder(TestName());
- auto literal = Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -131,7 +132,7 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
HloComputation::Builder builder(TestName());
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
// Reverse the minor-to-major order of the literal.
Layout* literal_layout =
literal->mutable_shape_do_not_use()->mutable_layout();
@@ -168,7 +169,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(a);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -202,7 +203,7 @@ void CopyOpTest::TestCopyConstantLayoutR4(
HloComputation::Builder builder(TestName());
- std::unique_ptr<Literal> literal = Literal::CreateR4FromArray4D(a);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
@@ -248,7 +249,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
auto empty = Literal::CreateFromShape(in_shape);
XlaBuilder builder(TestName());
- auto param0 = builder.Parameter(0, in_shape, "input");
+ Parameter(&builder, 0, in_shape, "input");
auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie();
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index b159887765..d12a4e7fcd 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -32,28 +32,44 @@ class TrivialCrossReplicaSumTest : public HloTestBase {};
XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p = f32[3] parameter(0)
- ROOT crs = f32[3] cross-replica-sum(p)
+ ROOT crs = f32[3] cross-replica-sum(p), to_apply=add
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal = Literal::CreateR1<float>({1, 2, 3});
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
}
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] parameter(1)
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal0 = Literal::CreateR1<float>({1, 2, 3});
- auto literal1 = Literal::CreateR1<float>({10, 20});
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
EXPECT_EQ(
- *Literal::MakeTuple({literal0.get(), literal1.get()}),
+ *LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get(), literal1.get()}));
}
@@ -63,15 +79,23 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] constant({10, 20})
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
- auto literal0 = Literal::CreateR1<float>({1, 2, 3});
- auto literal1 = Literal::CreateR1<float>({10, 20});
- EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}),
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto literal0 = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ EXPECT_EQ(*LiteralUtil::MakeTuple({literal0.get(), literal1.get()}),
*ExecuteAndTransfer(std::move(module), {literal0.get()}));
}
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index b43d5c9ff5..90f3d1b874 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
@@ -73,7 +74,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
@@ -94,7 +95,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
array(1, 1) = 4.0f;
auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array)));
builder.AddInstruction(
HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
@@ -110,7 +111,7 @@ XLA_TEST_F(CustomCallTest,
auto b = HloComputation::Builder(TestName());
auto input = b.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(
Array2D<float>{{1.0f, 2.0f}, {3.0f, 4.0f}})));
auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall(
ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues"));
@@ -135,8 +136,8 @@ class CustomCallClientAPITest : public ClientLibraryTestBase {};
// are reserved for internal use.
XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) {
XlaBuilder builder(TestName());
- builder.CustomCall("$illegal", /*operands=*/{},
- ShapeUtil::MakeShape(F32, {1}));
+ CustomCall(&builder, "$illegal", /*operands=*/{},
+ ShapeUtil::MakeShape(F32, {1}));
StatusOr<std::unique_ptr<GlobalData>> result =
Execute(&builder, /*arguments=*/{});
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
index bfe688e20d..d4b3aac85b 100644
--- a/tensorflow/compiler/xla/tests/deallocation_test.cc
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -48,7 +48,7 @@ class DeallocationTest : public ClientLibraryTestBase {
TEST_F(DeallocationTest, DeallocateScalar) {
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0);
+ ConstantR0<float>(&builder, 42.0);
auto global_data = ExecuteAndCheckTransfer(&builder, {});
// A result can be transferred an arbitrary number of times. Add an extra
@@ -66,7 +66,7 @@ TEST_F(DeallocationTest, DeallocateScalar) {
TEST_F(DeallocationTest, DeallocateVector) {
XlaBuilder builder(TestName());
- builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
ASSERT_IS_OK(client_->Unregister(*global_data));
@@ -79,7 +79,7 @@ TEST_F(DeallocationTest, DeallocateVector) {
TEST_F(DeallocationTest, DeallocateEmptyVector) {
XlaBuilder builder(TestName());
- builder.ConstantR1<float>({});
+ ConstantR1<float>(&builder, {});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
ASSERT_IS_OK(client_->Unregister(*global_data));
@@ -92,8 +92,8 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) {
XLA_TEST_F(DeallocationTest, DeallocateTuple) {
XlaBuilder builder(TestName());
- builder.Tuple({builder.ConstantR0<float>(42.0),
- builder.ConstantR1<float>({1.0, 2.0, 3.0})});
+ Tuple(&builder, {ConstantR0<float>(&builder, 42.0),
+ ConstantR1<float>(&builder, {1.0, 2.0, 3.0})});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
ASSERT_IS_OK(client_->Unregister(*global_data));
@@ -106,9 +106,10 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) {
XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) {
XlaBuilder builder(TestName());
- auto element = builder.ConstantR0<float>(42.0);
- auto inner_tuple = builder.Tuple({builder.ConstantR0<float>(42.0), element});
- builder.Tuple({element, inner_tuple, element});
+ auto element = ConstantR0<float>(&builder, 42.0);
+ auto inner_tuple =
+ Tuple(&builder, {ConstantR0<float>(&builder, 42.0), element});
+ Tuple(&builder, {element, inner_tuple, element});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
ASSERT_IS_OK(client_->Unregister(*global_data));
@@ -122,9 +123,9 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) {
XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) {
XlaBuilder builder(TestName());
auto inner_tuple =
- builder.Tuple({builder.ConstantR0<float>(42.0),
- builder.ConstantR1<float>({1.0, 2.0, 3.0})});
- builder.Tuple({inner_tuple, builder.ConstantR1<float>({0.123, 0.456})});
+ Tuple(&builder, {ConstantR0<float>(&builder, 42.0),
+ ConstantR1<float>(&builder, {1.0, 2.0, 3.0})});
+ Tuple(&builder, {inner_tuple, ConstantR1<float>(&builder, {0.123, 0.456})});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
ASSERT_IS_OK(client_->Unregister(*global_data));
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
index 12789fe665..a6a233e71a 100644
--- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -54,9 +54,9 @@ class DeconstructTupleTest : public ClientLibraryTestBase {
TEST_F(DeconstructTupleTest, DeconstructTuple) {
XlaBuilder builder(TestName());
- auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
- builder.Tuple({const1, const2});
+ auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
+ auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
+ Tuple(&builder, {const1, const2});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status = client_->DeconstructTuple(*global_data);
@@ -73,9 +73,9 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) {
TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
XlaBuilder builder(TestName());
- auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
- builder.Tuple({const1, const2});
+ auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
+ auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
+ Tuple(&builder, {const1, const2});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status1 = client_->DeconstructTuple(*global_data);
@@ -103,9 +103,9 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
XlaBuilder builder(TestName());
- auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
- builder.Tuple({const1, const2, const2, const1});
+ auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
+ auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
+ Tuple(&builder, {const1, const2, const2, const1});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status = client_->DeconstructTuple(*global_data);
@@ -129,9 +129,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
XlaBuilder builder(TestName());
- auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
- builder.Tuple({const1, const2, const1});
+ auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
+ auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
+ Tuple(&builder, {const1, const2, const1});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status = client_->DeconstructTuple(*global_data);
@@ -159,7 +159,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XlaBuilder builder(TestName());
- builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status = client_->DeconstructTuple(*global_data);
@@ -171,11 +171,11 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({3.14f, -100.25f});
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
- builder.Tuple({p});
+ auto p = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
+ Tuple(&builder, {p});
auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
auto result_status = client_->DeconstructTuple(*global_data);
@@ -186,9 +186,9 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) {
XlaBuilder builder(TestName());
- auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
- builder.Tuple({builder.Tuple({const1, const2}), const1});
+ auto const1 = ConstantR1<float>(&builder, {1.0, 2.0, 3.0, 4.0});
+ auto const2 = ConstantR1<float>(&builder, {2.0, 4.0, 6.0, 8.0});
+ Tuple(&builder, {Tuple(&builder, {const1, const2}), const1});
auto global_data = ExecuteAndCheckTransfer(&builder, {});
auto result_status = client_->DeconstructTuple(*global_data);
diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc
index 085a5105ac..810947ab01 100644
--- a/tensorflow/compiler/xla/tests/deep_graph_test.cc
+++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc
@@ -30,7 +30,7 @@ TEST_F(ClientLibraryTestBase, DeepGraph) {
auto y_data = CreateR0Parameter<int32>(1, 1, "y", &b, &y);
XlaOp z = x;
for (int i = 0; i < kDepth; ++i) {
- z = b.Add(z, y);
+ z = Add(z, y);
}
ComputeAndCompareR0<int32>(&b, /*expected=*/kDepth + 3,
{x_data.get(), y_data.get()});
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0fd846cef8..d86fd7cc2d 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
"arg0", &builder, &param);
- auto lhs = builder.GetTupleElement(param, 0);
- auto rhs = builder.GetTupleElement(param, 1);
- builder.Dot(lhs, rhs);
+ auto lhs = GetTupleElement(param, 0);
+ auto rhs = GetTupleElement(param, 1);
+ Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+ *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -87,9 +88,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR1<T>({});
- auto rhs = builder.ConstantR1<T>({});
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR1<T>(&builder, {});
+ auto rhs = ConstantR1<T>(&builder, {});
+ Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(0.0), {},
this->error_spec_);
@@ -102,9 +103,9 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64);
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR2FromArray2D<T>({{3.0f, 4.0f}});
- auto rhs = builder.ConstantFromArray<T>({3.0f, 4.0f});
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, {{3.0f, 4.0f}});
+ auto rhs = ConstantFromArray<T>(&builder, {3.0f, 4.0f});
+ Dot(lhs, rhs);
this->template ComputeAndCompareR1<T>(&builder, {static_cast<T>(25.0f)}, {},
this->error_spec_);
@@ -113,9 +114,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR1<T>({static_cast<T>(2.0f)});
- auto rhs = builder.ConstantR1<T>({static_cast<T>(3.0f)});
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR1<T>(&builder, {static_cast<T>(2.0f)});
+ auto rhs = ConstantR1<T>(&builder, {static_cast<T>(3.0f)});
+ Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(6.0f), {},
this->error_spec_);
@@ -124,9 +125,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantFromArray<T>({1.0f, 2.5f, 42.0f});
- auto rhs = builder.ConstantFromArray<T>({11.0f, -1.0f, 0.5f});
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantFromArray<T>(&builder, {1.0f, 2.5f, 42.0f});
+ auto rhs = ConstantFromArray<T>(&builder, {11.0f, -1.0f, 0.5f});
+ Dot(lhs, rhs);
this->template ComputeAndCompareR0<T>(&builder, static_cast<T>(29.5f), {},
this->error_spec_);
@@ -139,9 +140,9 @@ std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
- auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
+ auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
+ Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 0), {},
this->error_spec_);
@@ -150,10 +151,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
- auto rhs = builder.ConstantR2FromArray2D<T>(
- {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
+ auto rhs = ConstantR2FromArray2D<T>(
+ &builder, {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}});
+ Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(0, 3), {},
this->error_spec_);
@@ -162,10 +163,10 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR2FromArray2D<T>(
- {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
- auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(
+ &builder, {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}});
+ auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
+ Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(&builder, Array2D<T>(3, 0), {},
this->error_spec_);
@@ -174,9 +175,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) {
XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto lhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(2, 0));
- auto rhs = builder.ConstantR2FromArray2D<T>(Array2D<T>(0, 2));
- auto result = builder.Dot(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(2, 0));
+ auto rhs = ConstantR2FromArray2D<T>(&builder, Array2D<T>(0, 2));
+ Dot(lhs, rhs);
this->template ComputeAndCompareR2<T>(
&builder, Array2D<T>(2, 2, static_cast<T>(0.0f)), {}, this->error_spec_);
@@ -186,19 +187,19 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
auto param0 =
- builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
+ Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 4}), "arg0");
auto param1 =
- builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
- auto exp0 = builder.Exp(param0);
- auto result = builder.Dot(exp0, param1);
+ Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({4, 1}), "arg1");
+ auto exp0 = Exp(param0);
+ Dot(exp0, param1);
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -217,23 +218,22 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
- auto result = builder.Dot(
- builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
- builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
+ Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
+ Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
Array2D<T> expected({{15.0f, -2.0f}, {-25.0f, 34.0f}});
ComputeAndCompareR2<T>(&builder, expected,
@@ -287,9 +287,10 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
+ std::unique_ptr<Literal> dot_lhs_lit =
+ LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
+ param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
@@ -298,7 +299,7 @@ void ParametricDotTest::TestImpl() {
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
std::unique_ptr<Literal> dot_rhs_lit =
- Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
+ LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
@@ -308,7 +309,7 @@ void ParametricDotTest::TestImpl() {
if (param.has_addend) {
addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
- addend_lit = Literal::CreateR2FromArray2DWithLayout(
+ addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
@@ -316,26 +317,26 @@ void ParametricDotTest::TestImpl() {
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<NativeT>();
- auto result = builder.Dot(
- builder.Parameter(0,
- ShapeUtil::MakeShapeWithLayout(
- prim_type, {param.m, param.k},
- MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
- "dot_lhs"),
- builder.Parameter(1,
- ShapeUtil::MakeShapeWithLayout(
- prim_type, {param.k, param.n},
- MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
- "dot_rhs"));
+ auto result =
+ Dot(Parameter(&builder, 0,
+ ShapeUtil::MakeShapeWithLayout(
+ prim_type, {param.m, param.k},
+ MinorToMajorForIsRowMajor(param.dot_lhs_row_major)),
+ "dot_lhs"),
+ Parameter(&builder, 1,
+ ShapeUtil::MakeShapeWithLayout(
+ prim_type, {param.k, param.n},
+ MinorToMajorForIsRowMajor(param.dot_rhs_row_major)),
+ "dot_rhs"));
if (param.has_addend) {
- result = builder.Add(
- result, builder.Parameter(
- 2,
- ShapeUtil::MakeShapeWithLayout(
- prim_type, {param.m, param.n},
- MinorToMajorForIsRowMajor(param.addend_row_major)),
- "addend"));
+ result =
+ Add(result,
+ Parameter(&builder, 2,
+ ShapeUtil::MakeShapeWithLayout(
+ prim_type, {param.m, param.n},
+ MinorToMajorForIsRowMajor(param.addend_row_major)),
+ "addend"));
}
std::unique_ptr<Array2D<NativeT>> expected;
@@ -477,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -492,9 +493,8 @@ class NonsquareMatrixDot : public DotOperationTest {
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
- auto result = builder.Dot(
- builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
- builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
+ Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
+ Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
@@ -512,21 +512,20 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
- auto result = builder.Dot(
- builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
- builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
+ Dot(Parameter(&builder, 0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
+ Parameter(&builder, 1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
Array2D<complex64> expected({{30.0, -2.0}});
@@ -538,11 +537,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto matrix1 = builder.ConstantR2FromArray2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}});
- auto matrix2 = builder.ConstantR2FromArray2D<T>({{5.0f, 6.0f}, {7.0f, 8.0f}});
- auto matrix12 = builder.Dot(matrix1, matrix2);
- auto matrix21 = builder.Dot(matrix2, matrix1);
- builder.Add(matrix12, matrix21);
+ auto matrix1 =
+ ConstantR2FromArray2D<T>(&builder, {{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto matrix2 =
+ ConstantR2FromArray2D<T>(&builder, {{5.0f, 6.0f}, {7.0f, 8.0f}});
+ auto matrix12 = Dot(matrix1, matrix2);
+ auto matrix21 = Dot(matrix2, matrix1);
+ Add(matrix12, matrix21);
Array2D<T> expected({{42.0f, 56.0f}, {74.0f, 96.0f}});
this->template ComputeAndCompareR2<T>(&builder, expected, {},
@@ -559,32 +560,32 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64);
XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
- auto x =
- builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), "x");
- auto y =
- builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}), "y");
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "y");
- auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
- auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
+ auto x_flat = Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
+ auto y_flat = Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
// Slice batches into individual matrices and multiply them.
std::vector<XlaOp> out_slices;
for (int i = 0; i < 4; ++i) {
// Slice off individual matrices and reshape to 2D tensors.
- auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
- x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
- auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
- y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
+ auto x_slice = Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
+ x_slice = Reshape(x_slice, {0, 1, 2}, {2, 2});
+ auto y_slice = Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1});
+ y_slice = Reshape(y_slice, {0, 1, 2}, {2, 2});
- auto out = builder.Dot(x_slice, y_slice);
- out = builder.Reshape(out, {0, 1}, {1, 2, 2});
+ auto out = Dot(x_slice, y_slice);
+ out = Reshape(out, {0, 1}, {1, 2, 2});
out_slices.push_back(out);
}
- auto out_flat = builder.ConcatInDim(out_slices, 0);
- builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
+ auto out_flat = ConcatInDim(&builder, out_slices, 0);
+ Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -592,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -616,9 +617,9 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
XlaBuilder builder(this->TestName());
auto x =
- builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
+ Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
auto y =
- builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");
+ Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");
DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
@@ -626,17 +627,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
- auto out = builder.DotGeneral(x, y, dnums);
+ DotGeneral(x, y, dnums);
auto x_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -665,32 +666,36 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
}
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *lhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *lhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *rhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *rhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
XlaBuilder builder(this->TestName());
auto prim_type = primitive_util::NativeToPrimitiveType<T>();
- auto lhs_arg = builder.Parameter(
- 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
+ auto lhs_arg = Parameter(
+ &builder, 0,
+ ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
"lhs");
- auto rhs_arg = builder.Parameter(
- 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
+ auto rhs_arg = Parameter(
+ &builder, 1,
+ ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
"rhs");
if (transpose_lhs) {
- lhs_arg = builder.Transpose(lhs_arg, {1, 0});
+ lhs_arg = Transpose(lhs_arg, {1, 0});
}
if (transpose_rhs) {
- rhs_arg = builder.Transpose(rhs_arg, {1, 0});
+ rhs_arg = Transpose(rhs_arg, {1, 0});
}
- auto result = builder.Dot(lhs_arg, rhs_arg);
+ Dot(lhs_arg, rhs_arg);
Array2D<T> expected({{26.0f, 0.0f}, {-12.0f, 10.0f}});
VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
@@ -713,15 +718,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
{6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}}));
XlaBuilder builder(this->TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}),
- "rhs_arg_0");
- auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}),
- "rhs_arg_1");
- auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}),
- "rhs_arg_2");
- auto result = builder.Dot(
- lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_arg_0 = Parameter(
+ &builder, 0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0");
+ auto rhs_arg_1 = Parameter(
+ &builder, 1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs_arg_1");
+ auto rhs_arg_2 = Parameter(
+ &builder, 2, ShapeUtil::MakeShape(prim_type, {1, 2}), "rhs_arg_2");
+ Dot(lhs_constant,
+ ConcatInDim(&builder, {rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0));
std::unique_ptr<Array2D<T>> arg_0_value_array(
new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
@@ -732,15 +737,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -761,15 +766,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
{2.0f, 1.0f}}));
XlaBuilder builder(this->TestName());
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({2, 2}),
- "lhs_arg_0");
- auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShapeWithType<T>({2, 3}),
- "lhs_arg_1");
- auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShapeWithType<T>({2, 1}),
- "lhs_arg_2");
- auto result = builder.Dot(
- builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto lhs_arg_0 = Parameter(
+ &builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "lhs_arg_0");
+ auto lhs_arg_1 = Parameter(
+ &builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 3}), "lhs_arg_1");
+ auto lhs_arg_2 = Parameter(
+ &builder, 2, ShapeUtil::MakeShapeWithType<T>({2, 1}), "lhs_arg_2");
+ Dot(ConcatInDim(&builder, {lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1),
+ rhs_constant);
std::unique_ptr<Array2D<T>> arg_0_value_array(
new Array2D<T>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
@@ -781,15 +786,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
@@ -811,16 +816,15 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) {
// Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({1, 0});
- auto dynamic_slice =
- builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {1, 0});
+ auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{96.0, 105.0, 114.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -839,25 +843,23 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
// Dot result to slice from: {{114, 105, 96}, {96, 105, 114}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({0, 1});
- auto dynamic_slice =
- builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {0, 1});
+ auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{105.0}, {105.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+
+ DotOfGatherOptimizationWithConstRHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -870,25 +872,21 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({0, 1});
- auto dynamic_slice =
- builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {0, 1});
+ auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{105.0, 105.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -901,25 +899,21 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({1, 0});
- auto dynamic_slice =
- builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {1, 0});
+ auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{96.0}, {105.0}, {114.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -937,25 +931,21 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({0, 1});
- auto dynamic_slice =
- builder.DynamicSlice(lhs_constant, start_constant, {6, 1});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {0, 1});
+ auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {6, 1});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{126.0, 129.0, 132.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -973,25 +963,21 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{132, 129, 126}, {126, 129, 132}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({0, 1});
- auto dynamic_slice =
- builder.DynamicSlice(rhs_constant, start_constant, {6, 1});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {0, 1});
+ auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {6, 1});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(0);
dot_dnums.add_rhs_contracting_dimensions(0);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{129.0}, {129.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1001,25 +987,21 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({1, 0});
- auto dynamic_slice =
- builder.DynamicSlice(lhs_constant, start_constant, {1, 6});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {1, 0});
+ auto dynamic_slice = DynamicSlice(lhs_constant, start_constant, {1, 6});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
+ DotGeneral(dynamic_slice, rhs_constant, dot_dnums);
Array2D<float> expected({{56.0, 168.0, 91.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(
- DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1029,19 +1011,41 @@ XLA_TEST_F(DotOperationTest,
// Dot result to slice from: {{91, 168, 56}, {56, 168, 91}}
XlaBuilder builder(TestName());
- auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array);
- auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array);
- auto start_constant = builder.ConstantR1<int32>({1, 0});
- auto dynamic_slice =
- builder.DynamicSlice(rhs_constant, start_constant, {1, 6});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, *constant_lhs_array);
+ auto rhs_constant = ConstantR2FromArray2D(&builder, *constant_rhs_array);
+ auto start_constant = ConstantR1<int32>(&builder, {1, 0});
+ auto dynamic_slice = DynamicSlice(rhs_constant, start_constant, {1, 6});
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(1);
- auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
+ DotGeneral(lhs_constant, dynamic_slice, dot_dnums);
Array2D<float> expected({{168.0}, {168.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
+
+XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
+ XlaBuilder builder(TestName());
+
+ Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);
+
+ Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
+ auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);
+
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ DotGeneral(lhs_constant, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({
+ {26.f, 30.f},
+ {38.f, 44.f},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 49f3a10d22..b063b6bdef 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -124,11 +124,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
// vector<bool> is special so that it cannot be an ArraySlice<bool>, which
// is what the code below wants. So instead we do this.
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -138,8 +138,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- builder.DynamicSlice(input, starts, slice_sizes);
+ auto input = ConstantLiteral(&builder, input_values);
+ DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -150,11 +150,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -164,8 +164,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- builder.DynamicSlice(input, starts, slice_sizes);
+ auto input = ConstantLiteral(&builder, input_values);
+ DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -176,11 +176,11 @@ class DynamicSliceTest : public ClientLibraryTestBase {
const std::vector<int64>& slice_sizes,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -190,8 +190,8 @@ class DynamicSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- builder.DynamicSlice(input, starts, slice_sizes);
+ auto input = ConstantLiteral(&builder, input_values);
+ DynamicSlice(input, starts, slice_sizes);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -349,15 +349,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
void RunR0(int input_value_int, int update_value_int,
const std::vector<IndexT> slice_starts, int expected_value_int) {
Literal input_value =
- std::move(*Literal::CreateR0(input_value_int)
+ std::move(*LiteralUtil::CreateR0(input_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_value =
- std::move(*Literal::CreateR0(update_value_int)
+ std::move(*LiteralUtil::CreateR0(update_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_value =
- std::move(*Literal::CreateR0(expected_value_int)
+ std::move(*LiteralUtil::CreateR0(expected_value_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -367,9 +367,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_value);
- auto update = builder.ConstantLiteral(update_value);
- builder.DynamicUpdateSlice(input, update, starts);
+ auto input = ConstantLiteral(&builder, input_value);
+ auto update = ConstantLiteral(&builder, update_value);
+ DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_value, {start_data.get()});
}
@@ -380,15 +380,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
tensorflow::gtl::ArraySlice<int> expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR1(input_values_int)
+ std::move(*LiteralUtil::CreateR1(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR1(update_values_int)
+ std::move(*LiteralUtil::CreateR1(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR1(expected_values_int)
+ std::move(*LiteralUtil::CreateR1(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -398,9 +398,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- auto update = builder.ConstantLiteral(update_values);
- builder.DynamicUpdateSlice(input, update, starts);
+ auto input = ConstantLiteral(&builder, input_values);
+ auto update = ConstantLiteral(&builder, update_values);
+ DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -411,15 +411,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array2D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR2FromArray2D(input_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR2FromArray2D(update_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR2FromArray2D(expected_values_int)
+ std::move(*LiteralUtil::CreateR2FromArray2D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -429,9 +429,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- auto update = builder.ConstantLiteral(update_values);
- builder.DynamicUpdateSlice(input, update, starts);
+ auto input = ConstantLiteral(&builder, input_values);
+ auto update = ConstantLiteral(&builder, update_values);
+ DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -442,15 +442,15 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
const std::vector<IndexT> slice_starts,
const Array3D<int>& expected_values_int) {
Literal input_values =
- std::move(*Literal::CreateR3FromArray3D(input_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(input_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal update_values =
- std::move(*Literal::CreateR3FromArray3D(update_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(update_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
Literal expected_values =
- std::move(*Literal::CreateR3FromArray3D(expected_values_int)
+ std::move(*LiteralUtil::CreateR3FromArray3D(expected_values_int)
->Convert(primitive_util::NativeToPrimitiveType<DataT>())
.ValueOrDie());
@@ -460,9 +460,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
slice_starts, 0, "slice_starts", &builder, &starts);
// Build dynamic slice computation.
- auto input = builder.ConstantLiteral(input_values);
- auto update = builder.ConstantLiteral(update_values);
- builder.DynamicUpdateSlice(input, update, starts);
+ auto input = ConstantLiteral(&builder, input_values);
+ auto update = ConstantLiteral(&builder, update_values);
+ DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()});
}
@@ -508,8 +508,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
XlaOp update;
std::unique_ptr<GlobalData> update_data = CreateR3Parameter<T>(
update_values, 1, "update_values", &builder, &update);
- auto starts = builder.ConstantR1<int32>({index, 0, 0});
- builder.DynamicUpdateSlice(input, update, starts);
+ auto starts = ConstantR1<int32>(&builder, {index, 0, 0});
+ DynamicUpdateSlice(input, update, starts);
// Run computation and compare against expected values.
ComputeAndCompareR3<T>(&builder, expected_values,
@@ -520,7 +520,7 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <typename NativeT>
void DumpArray(const string& name, const Array3D<NativeT> values) {
std::unique_ptr<Literal> literal =
- Literal::CreateR3FromArray3D<NativeT>(values);
+ LiteralUtil::CreateR3FromArray3D<NativeT>(values);
LOG(INFO) << name << ":" << literal->ToString();
}
};
@@ -695,17 +695,17 @@ void BM_DynamicSlice(int num_iters) {
XlaBuilder builder("DynamicSlice");
// Create input as a constant: shape [1, 2, 3, 4]
- auto input_literal = Literal::CreateR4(
+ auto input_literal = LiteralUtil::CreateR4(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
- auto input = builder.ConstantLiteral(*input_literal);
+ auto input = ConstantLiteral(&builder, *input_literal);
// Create dynamic slice start indices as a parameter: shape [4]
auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
auto start_indices =
- builder.Parameter(0, start_indices_shape, "start_indices");
+ Parameter(&builder, 0, start_indices_shape, "start_indices");
// Add DynamicSlice op to the computatation.
- builder.DynamicSlice(input, start_indices, {1, 1, 1, 1});
+ DynamicSlice(input, start_indices, {1, 1, 1, 1});
auto computation = builder.Build().ConsumeValueOrDie();
// Initialize and transfer parameter buffer.
@@ -715,9 +715,11 @@ void BM_DynamicSlice(int num_iters) {
start_indices_shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
- auto start_indices_literal = Literal::CreateR1<int32>({0, 1, 2, 3});
+ auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
+ auto stream =
+ client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- executors[device_ordinal], *start_indices_literal, buffer));
+ stream.get(), *start_indices_literal, buffer));
std::unique_ptr<LocalExecutable> executable =
client
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index a6ba6db5d3..ebba13c5b3 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -31,10 +31,10 @@ XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> input,
client_->TransferToServer(
- *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
+ *LiteralUtil::CreateR2F32Linspace(1e0, 1e5, 256, 256)));
XlaBuilder b(TestName() + ".add");
- b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1"));
+ Dot(Parameter(&b, 0, shape, "param_0"), Parameter(&b, 1, shape, "param_1"));
TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build());
ExecutionProfile execution_profile;
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 0a37e4d423..86bfaea4ef 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -39,7 +39,7 @@ class ExhaustiveF32ElementwiseOpTest
XlaBuilder builder(TestName());
std::unique_ptr<Literal> input_literal =
- Literal::CreateFromDimensions(F32, {input_size});
+ LiteralUtil::CreateFromDimensions(F32, {input_size});
for (int64 i = begin; i < end; i++) {
if (i >= known_incorrect_range.first &&
i < known_incorrect_range.second) {
@@ -54,7 +54,7 @@ class ExhaustiveF32ElementwiseOpTest
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
- auto input = builder.Parameter(0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal->shape(), "input");
enqueue_op(&builder, input);
std::vector<float> expected_result;
@@ -79,8 +79,8 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) {
#endif
ExhaustivelyTestF32Op(
- [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); },
- std::log, known_incorrect_range);
+ [](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log,
+ known_incorrect_range);
}
XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) {
@@ -95,14 +95,14 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) {
#endif
ExhaustivelyTestF32Op(
- [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); },
- std::exp, known_incorrect_range);
+ [](XlaBuilder* builder, const XlaOp& input) { Exp(input); }, std::exp,
+ known_incorrect_range);
}
XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) {
ExhaustivelyTestF32Op(
- [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); },
- std::tanh, /*known_incorrect_range=*/{0, 0});
+ [](XlaBuilder* builder, const XlaOp& input) { Tanh(input); }, std::tanh,
+ /*known_incorrect_range=*/{0, 0});
}
std::vector<std::pair<int64, int64>> CreateExhaustiveParameters() {
diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc
index 93d1c921c4..dcb469087e 100644
--- a/tensorflow/compiler/xla/tests/filecheck.cc
+++ b/tensorflow/compiler/xla/tests/filecheck.cc
@@ -76,6 +76,11 @@ StatusOr<bool> RunFileCheck(const string& input, const string& pattern) {
XLA_LOG_LINES(tensorflow::WARNING, input);
LOG(WARNING) << "FileCheck pattern was:";
XLA_LOG_LINES(tensorflow::WARNING, pattern);
+ } else if (!standard_error.empty()) {
+ LOG(INFO) << "FileCheck stderr:";
+ XLA_LOG_LINES(tensorflow::INFO, standard_error);
+ LOG(INFO) << "FileCheck input was:";
+ XLA_LOG_LINES(tensorflow::INFO, input);
}
return succeeded;
}
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
index 71eb914a8e..30dc639f11 100644
--- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -42,12 +42,12 @@ class FloorCeilTest : public ClientLibraryTestBase {
LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ")
<< "}";
XlaBuilder builder(TestName());
- auto c = builder.ConstantR1<float>(input);
+ auto c = ConstantR1<float>(&builder, input);
if (f == kCeil) {
- builder.Ceil(c);
+ Ceil(c);
} else {
ASSERT_EQ(kFloor, f);
- builder.Floor(c);
+ Floor(c);
}
ComputeAndCompareR1<float>(&builder, expected, /*arguments=*/{});
}
@@ -55,12 +55,12 @@ class FloorCeilTest : public ClientLibraryTestBase {
void TestR0F32(float input, float expected, Function f) {
LOG(INFO) << "input: " << expected;
XlaBuilder builder(TestName());
- auto c = builder.ConstantR0<float>(input);
+ auto c = ConstantR0<float>(&builder, input);
if (f == kCeil) {
- builder.Ceil(c);
+ Ceil(c);
} else {
ASSERT_EQ(kFloor, f);
- builder.Floor(c);
+ Floor(c);
}
ComputeAndCompareR0<float>(&builder, expected, /*arguments=*/{});
}
diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc
index 73f029b59b..0254ae1baa 100644
--- a/tensorflow/compiler/xla/tests/fmax_test.cc
+++ b/tensorflow/compiler/xla/tests/fmax_test.cc
@@ -28,11 +28,11 @@ class FmaxSimpleTest : public ClientLibraryTestBase {};
TEST_F(FmaxSimpleTest, FmaxTenValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
- auto y = builder.ConstantR1<float>(
- {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
- builder.Max(x, y);
+ auto x = ConstantR1<float>(
+ &builder, {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = ConstantR1<float>(
+ &builder, {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ Max(x, y);
std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0, 9.0};
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index e6f79b5ac5..dc64477935 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -26,13 +26,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -89,7 +90,7 @@ class FusionTest : public HloTestBase {
HloInstruction* hlos[4];
for (int i = 0; i < Arity; ++i) {
hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2FromArray2D(operand_data[i])));
+ LiteralUtil::CreateR2FromArray2D(operand_data[i])));
}
auto answer_shape =
ShapeUtil::MakeShape(prim_type, {test_width, test_height});
@@ -115,7 +116,7 @@ class FusionTest : public HloTestBase {
ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
HloInstruction::FusionKind::kLoop);
- auto expected = Literal::CreateR2FromArray2D(answer_data);
+ auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
@@ -186,27 +187,28 @@ XLA_TEST_F(FusionTest, Test) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.62, 2.72, 3.14}})));
+ LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
- auto const10 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<bool>({{true, false, true}, {false, true, false}})));
+ LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+ auto const10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, false, true}, {false, true, false}})));
auto select11 = builder.AddInstruction(
HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
HloOpcode::kSelect, const10, add8, const9));
@@ -222,7 +224,7 @@ XLA_TEST_F(FusionTest, Test) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.5}, {2.72}}),
+ *LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -233,11 +235,11 @@ XLA_TEST_F(FusionTest, Parameter) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{1.0, 2.0, 3.0}})));
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+ LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
// add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
@@ -248,7 +250,7 @@ XLA_TEST_F(FusionTest, Parameter) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -269,7 +271,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
auto hlo_module = CreateNewModule();
auto two = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
auto x =
builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {}));
auto y = builder.AddInstruction(
@@ -292,9 +294,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR1<float>({1.0, 2.0, 3.0})));
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+ LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
// add2 = broadcast(const_vector) + const_array
@@ -308,7 +310,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Near(
- *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
*ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
@@ -316,14 +318,14 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto single_element_array = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR2<int32>({{5}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {}), single_element_array));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(5),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -331,14 +333,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -346,14 +348,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -361,14 +363,14 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR3<int32>({{{7}}})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -376,14 +378,14 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR3<int32>({{{7}}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -391,14 +393,14 @@ XLA_TEST_F(FusionTest, Reshape__) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(7)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(7),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -406,14 +408,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -421,14 +423,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -436,14 +438,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -451,7 +453,7 @@ XLA_TEST_F(FusionTest, Reverse) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
hlo_module->AddEntryComputation(builder.Build())
@@ -459,7 +461,7 @@ XLA_TEST_F(FusionTest, Reverse) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -467,7 +469,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
ShapeUtil::MakeShape(S32, {3}), const0, {0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -477,7 +479,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-3, -2, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -485,7 +487,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
ShapeUtil::MakeShape(S32, {2}), const0, {}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -495,15 +497,15 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -1}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(S32, {2}), const0, {0}, {4}, {2}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -513,17 +515,17 @@ XLA_TEST_F(FusionTest, SliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-1, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1})));
auto dynamic_slice2 =
builder.AddInstruction(HloInstruction::CreateDynamicSlice(
ShapeUtil::MakeShape(S32, {2}), const0, const1, {2}));
@@ -535,15 +537,15 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({-2, -3}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 3, 4})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4})));
auto reshape1 = builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {2, 2}), const0));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -552,17 +554,16 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
-// TODO(b/64070202): Investigate failure.
-XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
+XLA_TEST_F(FusionTest, TransposeNegate) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{1, 2}, {3, 4}})));
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}})));
auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {2, 2}), const0, {1, 0}));
auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
@@ -571,9 +572,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {})));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -591,10 +592,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -603,7 +604,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -611,10 +612,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
auto hlo_module = CreateNewModule();
auto builder = HloComputation::Builder(TestName());
- auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({1, 2, 4, 8})));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
@@ -625,7 +626,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR0<int32>(-15),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -633,9 +634,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
auto builder = HloComputation::Builder(TestName());
auto hlo_module = CreateNewModule();
auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
- Literal::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+ LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<int32>(1)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
Window window;
ASSERT_TRUE(
tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
@@ -675,7 +676,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
HloInstruction::FusionKind::kLoop);
EXPECT_TRUE(LiteralTestUtil::Equal(
- *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -687,9 +688,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
auto builder = HloComputation::Builder(TestName());
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({0})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({0})));
auto const1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0));
auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
@@ -711,7 +712,7 @@ XLA_TEST_F(FusionTest, SharedConstant) {
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
EXPECT_TRUE(
- LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+ LiteralTestUtil::Equal(*LiteralUtil::CreateR1<int32>({8}),
*ExecuteAndTransfer(std::move(hlo_module), {})));
}
@@ -765,6 +766,39 @@ XLA_TEST_F(FusionTest, Clamp2D) {
TestElementwise2D<float, 3>(HloOpcode::kClamp);
}
+// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast.
+XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) {
+ const string hlo_text = R"(
+HloModule Cluster
+
+fusion_c {
+ fusion.arg = f32[2,2]{1,0} parameter(0)
+ bitcast.0 = f32[2,2,1]{2,1,0} bitcast(fusion.arg)
+ tanh.0 = f32[2,2,1]{0,2,1} tanh(bitcast.0)
+ ROOT bitcast.2 = f32[2,2,1]{1,2,0} bitcast(tanh.0)
+}
+
+ENTRY main {
+ arg = f32[2,2]{1,0} parameter(0)
+ ROOT fusion = f32[2,2,1]{1,2,0} fusion(arg), kind=kLoop, calls=fusion_c
+}
+)";
+
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<float>({{0., 0.}, {1., 0.}});
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ test_runner_.Execute(std::move(module), {operand.get()},
+ /*run_hlo_passes=*/false));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}),
+ *result));
+}
+
void BM_ParallelFusion(int num_iters) {
// Simple element-wise computation to benchmark parallel task partitioning.
tensorflow::testing::StopTiming();
@@ -793,31 +827,31 @@ void BM_ParallelFusion(int num_iters) {
// Create computation.
XlaBuilder builder("ParallelFusion");
Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1});
- auto param0 = builder.Parameter(0, shape0, "param0");
+ auto param0 = Parameter(&builder, 0, shape0, "param0");
Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1});
- auto param1 = builder.Parameter(1, shape1, "param1");
+ auto param1 = Parameter(&builder, 1, shape1, "param1");
Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1});
- auto param2 = builder.Parameter(2, shape2, "param2");
+ auto param2 = Parameter(&builder, 2, shape2, "param2");
- auto x = builder.Mul(param0, param1);
- auto y = builder.Add(x, param2);
+ auto x = Mul(param0, param1);
+ Add(x, param2);
auto computation = builder.Build().ConsumeValueOrDie();
// Transfer literals to device.
auto param0_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
ScopedShapedBuffer buffer0 =
client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
ScopedShapedBuffer buffer1 =
client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
- Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
+ LiteralUtil::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
ScopedShapedBuffer buffer2 =
client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 4854c649c1..c5ca64fa3f 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -13,16 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
-
-// NB! TODO(b/74360564): These tests do not test out of bounds behavior since
-// that hasn't been specced yet.
namespace xla {
namespace {
@@ -41,7 +39,7 @@ class GatherOperationTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text, config));
+ ParseHloString(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
@@ -62,8 +60,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -83,8 +82,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -104,9 +104,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -126,9 +126,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -148,9 +148,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -170,11 +170,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -194,11 +194,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -218,8 +218,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -239,9 +240,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -260,18 +261,15 @@ ENTRY main {
window_bounds={1, 0}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OutOfBoundsIndex) {
// Out of bounds indices must not crash, and the indices in range should
// produce the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for OOB accesses,
- // we should get rid of the mask and check that backends produce the same
- // value for OOB indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -285,29 +283,45 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, OutOfBoundsUnsignedIndex) {
+ // Out of bounds indices must not crash, and the indices in range should
+ // produce the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ gather = s32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = s32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, NegativeIndex) {
// Negative indices must not crash, and the indices in range should produce
// the same values across all backends.
- //
- // TODO(b/74360564): Once we have a well defined semantics for negative
- // accesses, we should get rid of the mask and check that backends produce the
- // same value for negative indices too.
const string hlo_text = R"(
HloModule BatchDynamicSlice
@@ -321,20 +335,40 @@ ENTRY main {
gather_dims_to_operand_dims={0,1},
index_vector_dim=1,
window_bounds={1,1}
- gather_reshaped = s32[6]{0} reshape(gather)
- in_bounds_mask = s32[6]{0} parameter(2)
- ROOT result = s32[6]{0} multiply(gather_reshaped, in_bounds_mask)
+ ROOT result = s32[6]{0} reshape(gather)
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR2<int32>(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
{{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
- std::unique_ptr<Literal> in_bounds_mask =
- Literal::CreateR1<int32>({0, 1, 1, 0, 0, 1});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
+}
+
+XLA_TEST_F(GatherOperationTest, NegativeIndexIntoUnsignedOperand) {
+ // Negative indices must not crash, and the indices in range should produce
+ // the same values across all backends.
- RunTest(hlo_text,
- {operand.get(), gather_indices.get(), in_bounds_mask.get()});
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+ENTRY main {
+ operand = u32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ gather = u32[6,1,1]{2,1,0} gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+ ROOT result = u32[6]{0} reshape(gather)
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<uint32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR2<int32>(
+ {{2, -1}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ RunTest(hlo_text, operand.get(), gather_indices.get());
}
XLA_TEST_F(GatherOperationTest, OneScalarIndex) {
@@ -352,9 +386,9 @@ ENTRY main {
window_bounds={1,3,2}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR3<int32>(
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
{{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -373,8 +407,8 @@ ENTRY main {
window_bounds={1}
}
)";
- std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({1, 2, 3, 4});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR0<int32>(1);
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR0<int32>(1);
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -394,8 +428,8 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices = LiteralUtil::CreateR1<int32>({});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -418,8 +452,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -442,9 +477,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 2}, {2, 1}});
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -467,9 +502,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
+ LiteralUtil::CreateR3<int32>({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -492,11 +527,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -520,11 +555,11 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
- {{-4, 4}, {-5, 5}, {-6, 6}}, //
- {{-7, 7}, {-8, 8}, {-9, 9}}});
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{0, 0}, {1, 0}});
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -547,8 +582,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
- std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> gather_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -571,9 +607,9 @@ ENTRY main {
}
)";
std::unique_ptr<Literal> operand =
- Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
- Literal::CreateR2<int32>({{2, 1}, {1, 1}});
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
RunTest(hlo_text, operand.get(), gather_indices.get());
}
@@ -598,22 +634,23 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
Shape operand_shape = ShapeUtil::MakeShape(S32, {3, 3});
Shape indices_shape = ShapeUtil::MakeShape(S32, {2});
- auto operand = builder.Parameter(0, operand_shape, "operand");
- auto indices = builder.Parameter(1, indices_shape, "indices");
+ auto operand = Parameter(&builder, 0, operand_shape, "operand");
+ auto indices = Parameter(&builder, 1, indices_shape, "indices");
GatherDimensionNumbers dim_numbers;
dim_numbers.add_output_window_dims(1);
dim_numbers.add_elided_window_dims(0);
dim_numbers.add_gather_dims_to_operand_dims(0);
dim_numbers.set_index_vector_dim(1);
- builder.Gather(operand, indices, dim_numbers, {1, 3});
+ Gather(operand, indices, dim_numbers, {1, 3});
std::vector<int32> expected = {};
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> operand_arg,
- client_->TransferToServer(*Literal::CreateR2<int32>(
- {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<GlobalData> operand_arg,
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> indices_arg,
- client_->TransferToServer(*Literal::CreateR1<int32>({0, 2})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<int32>({0, 2})));
TF_ASSERT_OK_AND_ASSIGN(std::vector<xla::DeviceHandle> devices,
client_->GetDeviceHandles(1));
xla::ExecutionOptions execution_options = CreateDefaultExecutionOptions();
@@ -629,8 +666,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
client_->ExecuteParallel(computation_instances));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
client_->Transfer(*(result_data[0])));
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2, 3}, {7, 8, 9}},
+ *result_literal);
}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc
index 76bf47845c..73a47eda72 100644
--- a/tensorflow/compiler/xla/tests/half_test.cc
+++ b/tensorflow/compiler/xla/tests/half_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -37,8 +37,7 @@ class HalfTestBase : public ClientLibraryTestBase {
static const int kNumElements = 4;
};
-using UnaryBuildFuncTy =
- std::function<void(xla::XlaBuilder*, const xla::XlaOp& src)>;
+using UnaryBuildFuncTy = std::function<void(const xla::XlaOp& src)>;
struct UnaryOpTestParam {
std::function<half(half)> compute_func;
@@ -62,7 +61,7 @@ XLA_TEST_P(UnaryOpTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<half>(&builder, expected, {x_data.get()}, error_spec_);
}
@@ -79,18 +78,17 @@ half round_imp(half value) {
INSTANTIATE_TEST_CASE_P(
half, UnaryOpTest,
::testing::Values(
- UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs},
- UnaryOpTestParam{[](half x) { return round_imp(x); },
- &XlaBuilder::Round},
- UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil},
- UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos},
- UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp},
- UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor},
- UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log},
- UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg},
- UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign},
- UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin},
- UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh}
+ UnaryOpTestParam{[](half x) { return abs(x); }, &Abs},
+ UnaryOpTestParam{[](half x) { return round_imp(x); }, &Round},
+ UnaryOpTestParam{[](half x) { return ceil(x); }, &Ceil},
+ UnaryOpTestParam{[](half x) { return cos(x); }, &Cos},
+ UnaryOpTestParam{[](half x) { return exp(x); }, &Exp},
+ UnaryOpTestParam{[](half x) { return floor(x); }, &Floor},
+ UnaryOpTestParam{[](half x) { return log(x); }, &Log},
+ UnaryOpTestParam{[](half x) { return -x; }, &Neg},
+ UnaryOpTestParam{[](half x) { return sign_imp(x); }, &Sign},
+ UnaryOpTestParam{[](half x) { return sin(x); }, &Sin},
+ UnaryOpTestParam{[](half x) { return tanh(x); }, &Tanh}
));
@@ -118,19 +116,18 @@ XLA_TEST_P(UnaryPredTest, Ops) {
}
UnaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd);
+ build_func(x_opnd);
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get()});
}
INSTANTIATE_TEST_CASE_P(half, UnaryPredTest,
::testing::Values(UnaryPredTestParam{
- [](half x) { return isfinite(x); },
- &XlaBuilder::IsFinite}));
+ [](half x) { return isfinite(x); }, &IsFinite}));
-using BinaryBuildFuncTy = std::function<void(
- xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y,
- tensorflow::gtl::ArraySlice<int64>)>;
+using BinaryBuildFuncTy =
+ std::function<void(const xla::XlaOp& x, const xla::XlaOp& y,
+ tensorflow::gtl::ArraySlice<int64>)>;
struct BinaryOpTestParam {
std::function<half(half, half)> compute_func;
@@ -159,7 +156,7 @@ XLA_TEST_P(BinaryOpTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<half>(&builder, expected, {x_data.get(), y_data.get()},
error_spec_);
@@ -173,22 +170,15 @@ half atan2_imp(half x, half y) {
INSTANTIATE_TEST_CASE_P(
half, BinaryOpTest,
::testing::Values(
- BinaryOpTestParam{[](half x, half y) { return x + y; },
- &XlaBuilder::Add},
+ BinaryOpTestParam{[](half x, half y) { return x + y; }, &Add},
BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); },
- &XlaBuilder::Atan2},
- BinaryOpTestParam{[](half x, half y) { return x / y; },
- &XlaBuilder::Div},
- BinaryOpTestParam{[](half x, half y) { return max(x, y); },
- &XlaBuilder::Max},
- BinaryOpTestParam{[](half x, half y) { return min(x, y); },
- &XlaBuilder::Min},
- BinaryOpTestParam{[](half x, half y) { return x * y; },
- &XlaBuilder::Mul},
- BinaryOpTestParam{[](half x, half y) { return pow(x, y); },
- &XlaBuilder::Pow},
- BinaryOpTestParam{[](half x, half y) { return x - y; },
- &XlaBuilder::Sub}
+ &Atan2},
+ BinaryOpTestParam{[](half x, half y) { return x / y; }, &Div},
+ BinaryOpTestParam{[](half x, half y) { return max(x, y); }, &Max},
+ BinaryOpTestParam{[](half x, half y) { return min(x, y); }, &Min},
+ BinaryOpTestParam{[](half x, half y) { return x * y; }, &Mul},
+ BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, &Pow},
+ BinaryOpTestParam{[](half x, half y) { return x - y; }, &Sub}
));
@@ -221,27 +211,22 @@ XLA_TEST_P(BinaryPredTest, Ops) {
}
BinaryBuildFuncTy build_func = GetParam().build_func;
- build_func(&builder, x_opnd, y_opnd, {});
+ build_func(x_opnd, y_opnd, {});
ComputeAndCompareR1<bool>(&builder, expected, {x_data.get(), y_data.get()});
}
INSTANTIATE_TEST_CASE_P(
half, BinaryPredTest,
- ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; },
- &XlaBuilder::Eq},
- BinaryPredTestParam{[](half x, half y) { return x != y; },
- &XlaBuilder::Ne},
- BinaryPredTestParam{[](half x, half y) { return x >= y; },
- &XlaBuilder::Ge},
- BinaryPredTestParam{[](half x, half y) { return x > y; },
- &XlaBuilder::Gt},
- BinaryPredTestParam{[](half x, half y) { return x <= y; },
- &XlaBuilder::Le},
- BinaryPredTestParam{[](half x, half y) { return x < y; },
- &XlaBuilder::Lt}
-
- ));
+ ::testing::Values(
+ BinaryPredTestParam{[](half x, half y) { return x == y; }, &Eq},
+ BinaryPredTestParam{[](half x, half y) { return x != y; }, &Ne},
+ BinaryPredTestParam{[](half x, half y) { return x >= y; }, &Ge},
+ BinaryPredTestParam{[](half x, half y) { return x > y; }, &Gt},
+ BinaryPredTestParam{[](half x, half y) { return x <= y; }, &Le},
+ BinaryPredTestParam{[](half x, half y) { return x < y; }, &Lt}
+
+ ));
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
index cf971dd61b..4d82442f7e 100644
--- a/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
+++ b/tensorflow/compiler/xla/tests/hlo_metadata_test.cc
@@ -30,9 +30,9 @@ class HloMetadataTest : public LocalClientTestBase {
}
void BuildAddComputation(XlaBuilder* builder) {
- auto x = builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder->Add(x, y);
+ auto x = Parameter(builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Add(x, y);
}
OpMetadata metadata_;
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 36e19e6507..b662e83716 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -23,11 +23,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -94,8 +94,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform,
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
- return MakeUnique<HloModule>(name, VersionedComputationHandle(),
- GetModuleConfigForTest());
+ return MakeUnique<HloModule>(name, GetModuleConfigForTest());
}
/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() {
@@ -277,9 +276,10 @@ StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal(
HloComputation* HloTestBase::FindComputation(HloModule* module,
tensorflow::StringPiece name) {
- auto it = c_find_if(module->computations(),
+ auto computations = module->computations();
+ auto it = c_find_if(computations,
[&](HloComputation* c) { return c->name() == name; });
- if (it == module->computations().end()) {
+ if (it == computations.end()) {
return nullptr;
}
return *it;
@@ -288,9 +288,10 @@ HloComputation* HloTestBase::FindComputation(HloModule* module,
HloInstruction* HloTestBase::FindInstruction(HloModule* module,
tensorflow::StringPiece name) {
for (const HloComputation* c : module->computations()) {
- auto it = c_find_if(c->instructions(),
+ auto instructions = c->instructions();
+ auto it = c_find_if(instructions,
[&](HloInstruction* i) { return i->name() == name; });
- if (it != c->instructions().end()) {
+ if (it != instructions.end()) {
return *it;
}
}
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index eb3a2ea76a..9009d67cea 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -66,6 +66,15 @@ namespace xla {
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
+ public:
+ // Creates a new HLO module for a test. The module created will have
+ // TestName() for its name; it will also automatically populate its debug
+ // options from command-line flags. If you want a fresh HloModule object and
+ // then add HloComputations to it, it's recommended to use this method in your
+ // tests.
+ static std::unique_ptr<HloModule> CreateNewModule(
+ const string& name = TestName());
+
protected:
// This uses the interpreter backend as the reference backend and
// automatically finds another supported backend as the test backend. If the
@@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test {
~HloTestBase() override {}
- // Creates a new HLO module for a test. The module created will have
- // TestName() for its name; it will also automatically populate its debug
- // options from command-line flags. If you want a fresh HloModule object and
- // then add HloComputations to it, it's recommended to use this method in your
- // tests.
- static std::unique_ptr<HloModule> CreateNewModule(
- const string& name = TestName());
-
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
@@ -184,13 +185,9 @@ class HloTestBase : public ::testing::Test {
// 'layout'.
void ForceParameterLayout(HloModule* module, int64 param_no,
const Layout& layout) {
- ASSERT_LT(
- param_no,
- module->mutable_host_entry_computation_layout()->parameter_count());
- module->mutable_host_entry_computation_layout()
- ->mutable_parameter_layout(param_no)
- ->ResetLayout(layout);
- module->mutable_device_entry_computation_layout()
+ ASSERT_LT(param_no,
+ module->mutable_entry_computation_layout()->parameter_count());
+ module->mutable_entry_computation_layout()
->mutable_parameter_layout(param_no)
->ResetLayout(layout);
}
@@ -198,10 +195,7 @@ class HloTestBase : public ::testing::Test {
// Convenience method to force the layout of the computation result in a
// module. The result layout of 'module' is set to 'layout'.
void ForceResultLayout(HloModule* module, const Layout& layout) {
- module->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->ResetLayout(layout);
- module->mutable_device_entry_computation_layout()
+ module->mutable_entry_computation_layout()
->mutable_result_layout()
->ResetLayout(layout);
}
@@ -209,10 +203,7 @@ class HloTestBase : public ::testing::Test {
// Convenience method to clear the layout of the computation result in
// 'module'.
void ForceClearResultLayout(HloModule* module) {
- module->mutable_host_entry_computation_layout()
- ->mutable_result_layout()
- ->Clear();
- module->mutable_device_entry_computation_layout()
+ module->mutable_entry_computation_layout()
->mutable_result_layout()
->Clear();
}
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index da4cf4ae0c..ad1f5b9eed 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() {
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- VerifyModule();
+ VerifyModule(module_.get());
+ }
+ for (int i = 0; i < modules_.size(); ++i) {
+ VerifyModule(modules_.at(i).get());
}
HloTestBase::TearDown();
}
-void HloVerifiedTestBase::VerifyModule() {
- HloVerifier verifier;
- xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+void HloVerifiedTestBase::VerifyModule(HloModule* module) {
+ HloVerifier verifier(/*allow_mixed_precision=*/true);
+ xla::StatusOr<bool> mutated = verifier.Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() {
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = CreateNewModule();
+ module_ = HloTestBase::CreateNewModule();
}
return *module_;
}
-void HloVerifiedTestBase::ParseAndVerifyModule(
- tensorflow::StringPiece hlo_text) {
+HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
+ modules_.emplace_back(HloTestBase::CreateNewModule());
+ return modules_.back().get();
+}
+
+void HloVerifiedTestBase::ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+ const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text));
- VerifyModule();
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
+ VerifyModule(module_.get());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index e5bb14a883..5b28c01c36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -44,7 +44,8 @@ class HloVerifiedTestBase : public HloTestBase {
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
HloModule& module();
- void ParseAndVerifyModule(tensorflow::StringPiece hlo_text);
+ void ParseAndVerifyModule(tensorflow::StringPiece hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
// Sets the shape-size function used during hlo verification. If this isn't
// called, a default ShapeVerifier is used instead.
@@ -52,11 +53,23 @@ class HloVerifiedTestBase : public HloTestBase {
shape_verifier_ = std::move(shape_verifier);
}
+ // Creates a new module for a test, and stores it in modules_ so it can be
+ // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
+ // creation of unverified modules.
+ HloModule* CreateNewModule(const string& name = TestName());
+
+ // It is confusing to store modules created by module() and CreateNewModule()
+ // in different fields, but it allows us to migrate tests to
+ // HloVerifiedTestBase more easily, so it's a win because we can verify more
+ // modules. See b/80488902.
private:
- std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
+ // Lazily populated. Access via module().
+ std::unique_ptr<HloModule> module_;
+ // Populated by calls to CreateNewModule.
+ std::vector<std::unique_ptr<HloModule>> modules_;
std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
- void VerifyModule();
+ static void VerifyModule(HloModule* module);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index d1b8a6cf0b..31a099c15f 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/error_spec.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -154,20 +155,20 @@ class LiteralTestUtil {
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
@@ -175,46 +176,46 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual) {
- EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
+ EXPECT_TRUE(Equal(*LiteralUtil::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -222,7 +223,7 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
@@ -231,28 +232,28 @@ template <typename NativeT>
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
const LiteralSlice& actual, const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
+ EXPECT_TRUE(Near(*LiteralUtil::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
index bbac7285ae..f297b2b847 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -31,8 +31,9 @@ namespace xla {
namespace {
TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple({
- Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
});
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
}
@@ -42,11 +43,13 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
// un-fail an assertion failure. The CHECK-failure is death, so we can make a
// death assertion.
auto unequal_things_are_equal = [] {
- std::unique_ptr<Literal> lhs = Literal::MakeTuple({
- Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
+ std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
});
- std::unique_ptr<Literal> rhs = Literal::MakeTuple({
- Literal::CreateR0<int32>(64).get(), Literal::CreateR0<int32>(42).get(),
+ std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(64).get(),
+ LiteralUtil::CreateR0<int32>(42).get(),
});
CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
};
@@ -55,8 +58,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
auto dummy_lambda = [] {
- auto two = Literal::CreateR0<float>(2);
- auto four = Literal::CreateR0<float>(4);
+ auto two = LiteralUtil::CreateR0<float>(2);
+ auto four = LiteralUtil::CreateR0<float>(4);
ErrorSpec error(0.001);
CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
};
@@ -98,8 +101,8 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
}
TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
- auto expected = Literal::CreateR1<int32>({1, 2, 3});
- auto actual = Literal::CreateR1<int32>({4, 5, 6});
+ auto expected = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ auto actual = LiteralUtil::CreateR1<int32>({4, 5, 6});
::testing::AssertionResult result =
LiteralTestUtil::Equal(*expected, *actual);
EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
@@ -107,25 +110,26 @@ TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
}
TEST(LiteralTestUtilTest, NearComparatorR1) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- auto b =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtilTest, NearComparatorR1Nan) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
- auto b =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
+ auto b = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8});
EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
}
TEST(LiteralTestUtil, NearComparatorDifferentLengths) {
- auto a =
- Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
- auto b = Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
+ auto a = LiteralUtil::CreateR1<float>(
+ {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+ auto b =
+ LiteralUtil::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7});
EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001}));
EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001}));
}
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 2f46ee0be2..13df83ffff 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
@@ -64,7 +65,7 @@ class LLVMCompilerTest : public ::testing::Test {
// Create HLO module, and run the compiler.
auto builder = HloComputation::Builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
auto hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
@@ -86,7 +87,7 @@ class LLVMCompilerTest : public ::testing::Test {
void TestMultiModuleCompilation(LLVMCompiler *compiler) {
HloComputation::Builder builder(TestName());
builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
std::unique_ptr<HloModule> hlo_module = CreateNewModule();
hlo_module->AddEntryComputation(builder.Build());
@@ -124,8 +125,7 @@ class LLVMCompilerTest : public ::testing::Test {
static std::unique_ptr<HloModule> CreateNewModule() {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
- config);
+ return MakeUnique<HloModule>(TestName(), config);
}
};
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
index 2c45f19c09..6fc1115097 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include <utility>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -25,28 +26,28 @@ limitations under the License.
namespace xla {
-void LLVMIRGenTestBase::SetIrHook(bool match_optimized_ir) {
+void LlvmIrGenTestBase::SetIrHook(bool match_optimized_ir) {
auto llvm_compiler = GetLLVMCompiler();
using std::placeholders::_1;
// Add the IR inspection hook to the LLVM compiler.
if (match_optimized_ir) {
llvm_compiler->SetPostOptimizationHook(
- std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ std::bind(&LlvmIrGenTestBase::IrHook, this, _1));
} else {
llvm_compiler->SetPreOptimizationHook(
- std::bind(&LLVMIRGenTestBase::IrHook, this, _1));
+ std::bind(&LlvmIrGenTestBase::IrHook, this, _1));
}
}
-void LLVMIRGenTestBase::ResetIrHook() {
+void LlvmIrGenTestBase::ResetIrHook() {
auto llvm_compiler = GetLLVMCompiler();
llvm_compiler->RemovePreOptimizationHook();
llvm_compiler->RemovePostOptimizationHook();
}
-void LLVMIRGenTestBase::CompileAndVerifyIr(
+void LlvmIrGenTestBase::CompileAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const string& pattern,
bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
@@ -58,7 +59,17 @@ void LLVMIRGenTestBase::CompileAndVerifyIr(
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
-void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
+void LlvmIrGenTestBase::CompileAndVerifyIr(const string& hlo_text,
+ const string& expected_llvm_ir,
+ bool match_optimized_ir) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_optimized_ir);
+}
+
+void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr(
std::unique_ptr<HloModule> hlo_module, const AotCompilationOptions& options,
const string& pattern, bool match_optimized_ir) {
SetIrHook(match_optimized_ir);
@@ -71,11 +82,11 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr(
EXPECT_TRUE(filecheck_result.ValueOrDie());
}
-LLVMCompiler* LLVMIRGenTestBase::GetLLVMCompiler() {
+LLVMCompiler* LlvmIrGenTestBase::GetLLVMCompiler() {
return static_cast<LLVMCompiler*>(backend().compiler());
}
-Status LLVMIRGenTestBase::IrHook(const llvm::Module& module) {
+Status LlvmIrGenTestBase::IrHook(const llvm::Module& module) {
ir_ = llvm_ir::DumpModuleToString(module);
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
index 74cbb5f5df..018f9546af 100644
--- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
+++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// Tests that verify IR emitted by the CPU/GPU backend is as expected.
-class LLVMIRGenTestBase : public CodegenTestBase {
+class LlvmIrGenTestBase : public CodegenTestBase {
protected:
// Compiles the given HLO module to LLVM IR and verifies the IR matches the
// given pattern. `pattern` is in the FileCheck pattern matching syntax
@@ -38,6 +38,12 @@ class LLVMIRGenTestBase : public CodegenTestBase {
void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
const string& pattern, bool match_optimized_ir);
+ // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create
+ // an HLO module.
+ void CompileAndVerifyIr(const string& hlo_text,
+ const string& expected_llvm_ir,
+ bool match_optimized_ir = false);
+
// Compiles the given HLO module to LLVM IR and verifies the IR matches the
// given pattern. `pattern` is in the FileCheck pattern matching syntax
// (http://llvm.org/docs/CommandGuide/FileCheck.html).
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index f21f83992f..0df50150ae 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -38,14 +38,14 @@ class LocalClientAllocationTest : public LocalClientTestBase {
XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
- auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- builder.Add(x, y);
+ auto x = ConstantR1<float>(&builder, {0.0f, 1.0f, 2.0f});
+ auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ Add(x, y);
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
int64 allocation_count_before = allocator_->allocation_count();
@@ -74,9 +74,9 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
// Run a computation on every device on the system. Verify that allocation
// occurs on the proper device.
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
- auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- builder.Add(x, y);
+ auto x = ConstantR1<float>(&builder, {0.0f, 1.0f, 2.0f});
+ auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ Add(x, y);
auto computation = builder.Build().ConsumeValueOrDie();
TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform());
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index a366afe826..70612e7c49 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -37,8 +37,8 @@ using xla::string;
xla::XlaComputation Doubler() {
xla::XlaBuilder builder("doubler");
auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
- auto x = builder.Parameter(0, r0f32, "x");
- builder.Mul(x, builder.ConstantR0<float>(2.0));
+ auto x = xla::Parameter(&builder, 0, r0f32, "x");
+ xla::Mul(x, xla::ConstantR0<float>(&builder, 2.0));
return std::move(builder.Build().ValueOrDie());
}
@@ -51,10 +51,10 @@ int main(int argc, char** argv) {
xla::XlaBuilder builder("aot_test_helper");
auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
- auto opaque_param = builder.Parameter(0, opaque_shape, "x");
+ auto opaque_param = Parameter(&builder, 0, opaque_shape, "x");
auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
- auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32);
- builder.Call(Doubler(), {sum});
+ auto sum = CustomCall(&builder, "SumStructElements", {opaque_param}, r0f32);
+ Call(&builder, Doubler(), {sum});
if (argc != 2) {
LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 96858c00d6..7c003fb81f 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -54,7 +54,7 @@ class LocalClientExecuteTest : public LocalClientTestBase {
XLA_TEST_F(LocalClientExecuteTest, Constant) {
XlaBuilder builder(TestName());
- auto y = builder.ConstantR0<float>(123.0f);
+ ConstantR0<float>(&builder, 123.0f);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
@@ -64,11 +64,11 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.ConstantR0<float>(123.0f);
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = ConstantR0<float>(&builder, 123.0f);
+ Add(x, y);
- auto x_value = LiteralToShapedBuffer(*Literal::CreateR0<float>(42.0f));
+ auto x_value = LiteralToShapedBuffer(*LiteralUtil::CreateR0<float>(42.0f));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value});
LiteralTestUtil::ExpectR0Near<float>(165.f, *ShapedBufferToLiteral(result),
@@ -77,11 +77,11 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x");
- auto y = builder.ConstantR1<float>({});
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "x");
+ auto y = ConstantR1<float>(&builder, {});
+ Add(x, y);
- auto x_array = LiteralToShapedBuffer(*Literal::CreateR1<float>({}));
+ auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>({}, *ShapedBufferToLiteral(result),
@@ -90,12 +90,12 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
- auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});
LiteralTestUtil::ExpectR1Near<float>(
@@ -104,12 +104,12 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
- auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ExecutionProfile profile;
ScopedShapedBuffer result = ExecuteLocallyOrDie(
builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(),
@@ -122,19 +122,19 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ Add(x, y);
auto computation = builder.Build().ConsumeValueOrDie();
// Create x as a col-major array.
- auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ auto x_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
{{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({0, 1})));
// Create y as a row-major array.
- auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+ auto y_array = LiteralToShapedBuffer(*LiteralUtil::CreateR2WithLayout(
{{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(),
LayoutUtil::MakeLayout({1, 0})));
@@ -155,15 +155,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ Add(x, y);
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
// Run with col-major result layout.
ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie(
@@ -192,15 +192,15 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
- builder.Tuple({x, y, x});
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ Tuple(&builder, {x, y, x});
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -209,27 +209,26 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {2}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
- auto inner_tuple = builder.Tuple({x, y, x});
- builder.Tuple({inner_tuple, x});
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ auto inner_tuple = Tuple(&builder, {x, y, x});
+ Tuple(&builder, {inner_tuple, x});
auto computation = builder.Build().ConsumeValueOrDie();
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
auto y_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
+ *LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));
ScopedShapedBuffer result =
ExecuteLocallyOrDie(computation, {&x_array, &y_array});
@@ -238,28 +237,25 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{10.0f, 20.0f}, {30.0f, 40.0f}},
- LiteralSlice(*result_literal, {0, 1}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}},
- LiteralSlice(*result_literal, {0, 2}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0, 0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
+ LiteralSlice(*result_literal, {0, 1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0, 2}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
// Verify setting the result layout of a computation with a tuple output.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
- builder.Tuple({x, y});
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
+ Tuple(&builder, {x, y});
auto array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
+ *LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
ExecutableBuildOptions options = DefaultExecutableBuildOptions();
Shape shape_with_layout = ShapeUtil::MakeTupleShape(
@@ -273,10 +269,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
options, DefaultExecutableRunOptions());
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
@@ -291,23 +287,23 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
// Computation adds the respective array and vector elements from each tuple
// argument and returns the results as a tuple.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, tuple_shape0, "x");
- auto y = builder.Parameter(1, tuple_shape1, "y");
- auto x_0 = builder.GetTupleElement(x, 0);
- auto x_1 = builder.GetTupleElement(x, 1);
- auto y_0 = builder.GetTupleElement(y, 0);
- auto y_1 = builder.GetTupleElement(y, 1);
- auto array_sum = builder.Add(x_0, y_1);
- auto vector_diff = builder.Sub(x_1, y_0);
- builder.Tuple({array_sum, vector_diff});
+ auto x = Parameter(&builder, 0, tuple_shape0, "x");
+ auto y = Parameter(&builder, 1, tuple_shape1, "y");
+ auto x_0 = GetTupleElement(x, 0);
+ auto x_1 = GetTupleElement(x, 1);
+ auto y_0 = GetTupleElement(y, 0);
+ auto y_1 = GetTupleElement(y, 1);
+ auto array_sum = Add(x_0, y_1);
+ auto vector_diff = Sub(x_1, y_0);
+ Tuple(&builder, {array_sum, vector_diff});
auto computation = builder.Build().ConsumeValueOrDie();
- auto x_literal = Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR1<float>({42.0, 75.0, 123.0}).get()});
- auto y_literal = Literal::MakeTuple(
- {Literal::CreateR1<float>({2.0, 4.0, 6.0}).get(),
- Literal::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
+ auto x_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()});
+ auto y_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}).get(),
+ LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}}).get()});
auto x_buffer = LiteralToShapedBuffer(*x_literal);
auto y_buffer = LiteralToShapedBuffer(*y_literal);
@@ -319,11 +315,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{56.0f, 46.0f}, {36.0f, 26.0f}},
- LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
@@ -338,32 +333,32 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
// Computation negates the array element and sums the two vector elements in
// the nested tuple. The resulting array and vector are returned as a tuple.
XlaBuilder builder(TestName());
- auto param = builder.Parameter(0, nested_tuple_shape, "param");
- auto inner_tuple = builder.GetTupleElement(param, 0);
- auto inner_array = builder.GetTupleElement(inner_tuple, 0);
- auto inner_vector = builder.GetTupleElement(inner_tuple, 1);
- auto outer_vector = builder.GetTupleElement(param, 1);
-
- auto negate_array = builder.Neg(inner_array);
- auto vector_sum = builder.Add(inner_vector, outer_vector);
- builder.Tuple({negate_array, vector_sum});
+ auto param = Parameter(&builder, 0, nested_tuple_shape, "param");
+ auto inner_tuple = GetTupleElement(param, 0);
+ auto inner_array = GetTupleElement(inner_tuple, 0);
+ auto inner_vector = GetTupleElement(inner_tuple, 1);
+ auto outer_vector = GetTupleElement(param, 1);
+
+ auto negate_array = Neg(inner_array);
+ auto vector_sum = Add(inner_vector, outer_vector);
+ Tuple(&builder, {negate_array, vector_sum});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = Literal::MakeTuple(
- {Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR1<float>({42.0, 75.0, 123.0}).get()})
+ auto arg_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0}).get()})
.get(),
- Literal::CreateR1<float>({222.0, -2.0, 10.0}).get()});
+ LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0}).get()});
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
std::unique_ptr<Literal> result_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
+ LiteralSlice(*result_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
+ LiteralSlice(*result_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
@@ -376,31 +371,30 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
ShapeUtil::MakeTupleShape({array_shape, array_shape});
XlaBuilder builder(TestName());
- auto param = builder.Parameter(0, tuple_shape, "param");
- auto element_0 = builder.GetTupleElement(param, 0);
- auto element_1 = builder.GetTupleElement(param, 1);
- builder.Tuple({builder.Neg(element_0), builder.Add(element_1, element_1)});
+ auto param = Parameter(&builder, 0, tuple_shape, "param");
+ auto element_0 = GetTupleElement(param, 0);
+ auto element_1 = GetTupleElement(param, 1);
+ Tuple(&builder, {Neg(element_0), Add(element_1, element_1)});
auto computation = builder.Build().ConsumeValueOrDie();
- auto arg_literal = Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
- Literal::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
+ auto arg_literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}).get(),
+ LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}}).get()});
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer});
std::unique_ptr<Literal> result_0_literal = ShapedBufferToLiteral(result_0);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{-1.0, -2.0}, {-3.0, -4.0}},
- LiteralSlice(*result_0_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
+ LiteralSlice(*result_0_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
+ LiteralSlice(*result_0_literal, {1}));
ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0});
std::unique_ptr<Literal> result_1_literal = ShapedBufferToLiteral(result_1);
- LiteralTestUtil::ExpectR2Equal<float>(
- {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0}));
- LiteralTestUtil::ExpectR2Equal<float>(
- {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1}));
+ LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
+ LiteralSlice(*result_1_literal, {0}));
+ LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
+ LiteralSlice(*result_1_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
@@ -420,26 +414,25 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
XlaBuilder builder(TestName());
- auto param = builder.Parameter(0, tuple_shape, "param");
+ auto param = Parameter(&builder, 0, tuple_shape, "param");
// Add each element's tuple index value to every element.
std::vector<XlaOp> result_elements;
for (int i = 0; i < kElementCount; ++i) {
- auto element = builder.GetTupleElement(param, i);
- result_elements.push_back(
- builder.Add(element, builder.ConstantR0<float>(i)));
+ auto element = GetTupleElement(param, i);
+ result_elements.push_back(Add(element, ConstantR0<float>(&builder, i)));
}
- builder.Tuple(result_elements);
+ Tuple(&builder, result_elements);
auto computation = builder.Build().ConsumeValueOrDie();
// Feed in a tuple where each two-element vector element is {tuple_index,
// -tuple_index}.
std::vector<std::unique_ptr<Literal>> arg_elements;
for (int i = 0; i < kElementCount; ++i) {
- arg_elements.push_back(Literal::CreateR1<float>({1.0f * i, -1.0f * i}));
+ arg_elements.push_back(LiteralUtil::CreateR1<float>({1.0f * i, -1.0f * i}));
}
std::unique_ptr<Literal> arg_literal =
- Literal::MakeTupleOwned(std::move(arg_elements));
+ LiteralUtil::MakeTupleOwned(std::move(arg_elements));
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -447,8 +440,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
for (int i = 0; i < kElementCount; ++i) {
LiteralTestUtil::ExpectR1Near<float>(
- {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}),
- error_spec_);
+ {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_);
}
}
@@ -465,22 +457,22 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes);
XlaBuilder builder(TestName());
- auto param = builder.Parameter(0, tuple_shape, "param");
+ auto param = Parameter(&builder, 0, tuple_shape, "param");
// The computation increments each leaf value by an amount equal to the leaf's
// ordinal position in a traversal of the tuple.
std::vector<XlaOp> result_elements;
for (int i = 0; i < kFanout; ++i) {
- auto outer_element = builder.GetTupleElement(param, i);
+ auto outer_element = GetTupleElement(param, i);
std::vector<XlaOp> inner_result_elements;
for (int j = 0; j < kFanout; ++j) {
- auto inner_element = builder.GetTupleElement(outer_element, j);
- inner_result_elements.push_back(builder.Add(
- inner_element, builder.ConstantR0<float>(i * kFanout + j)));
+ auto inner_element = GetTupleElement(outer_element, j);
+ inner_result_elements.push_back(
+ Add(inner_element, ConstantR0<float>(&builder, i * kFanout + j)));
}
- result_elements.push_back(builder.Tuple(inner_result_elements));
+ result_elements.push_back(Tuple(&builder, inner_result_elements));
}
- builder.Tuple(result_elements);
+ Tuple(&builder, result_elements);
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
@@ -488,12 +480,13 @@ XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
for (int i = 0; i < kFanout; ++i) {
std::vector<std::unique_ptr<Literal>> inner_tuple_elements;
for (int j = 0; j < kFanout; ++j) {
- inner_tuple_elements.push_back(Literal::CreateR0<float>(i + j));
+ inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
}
outer_tuple_elements.push_back(
- Literal::MakeTupleOwned(std::move(inner_tuple_elements)));
+ LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements)));
}
- auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements));
+ auto arg_literal =
+ LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
@@ -520,23 +513,23 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
}
XlaBuilder builder(TestName());
- auto element = builder.Parameter(0, shape, "param");
+ auto element = Parameter(&builder, 0, shape, "param");
for (int i = 0; i < kTupleDepth; ++i) {
- element = builder.GetTupleElement(element, 0);
+ element = GetTupleElement(element, 0);
}
- auto output = builder.Add(element, builder.ConstantR0<float>(42.0));
+ auto output = Add(element, ConstantR0<float>(&builder, 42.0));
for (int i = 0; i < kTupleDepth; ++i) {
- output = builder.Tuple({output});
+ output = Tuple(&builder, {output});
}
auto computation = builder.Build().ConsumeValueOrDie();
// Construct the argument to pass to the computation.
- std::unique_ptr<Literal> arg_literal = Literal::CreateR0<float>(123.0);
+ std::unique_ptr<Literal> arg_literal = LiteralUtil::CreateR0<float>(123.0);
for (int i = 0; i < kTupleDepth; ++i) {
std::vector<std::unique_ptr<Literal>> arg_vector;
arg_vector.push_back(std::move(arg_literal));
- arg_literal = Literal::MakeTupleOwned(std::move(arg_vector));
+ arg_literal = LiteralUtil::MakeTupleOwned(std::move(arg_vector));
}
auto arg_buffer = LiteralToShapedBuffer(*arg_literal);
@@ -547,19 +540,19 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
for (int i = 0; i < kTupleDepth; ++i) {
index.push_back(0);
}
- LiteralTestUtil::ExpectR0Equal<float>(
- 165.0, LiteralSlice(*result_literal, index));
+ LiteralTestUtil::ExpectR0Equal<float>(165.0,
+ LiteralSlice(*result_literal, index));
}
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
// Test passing in an invalid number of arguments.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {3}), "y");
+ Add(x, y);
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({1.0f, 2.0f, 3.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -571,11 +564,11 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
// Test passing in an argument with the wrong shape.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
- builder.Neg(x);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
+ Neg(x);
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status =
ExecuteLocally(builder.Build().ValueOrDie(), {&x_array});
@@ -588,11 +581,11 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
// Test passing in an invalid result layout parameter.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
- builder.Neg(x);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
+ Neg(x);
auto x_array = LiteralToShapedBuffer(
- *Literal::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
+ *LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {&x_array},
DefaultExecutableBuildOptions().set_result_layout(
@@ -611,7 +604,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
// Try to run a trivial computation on every device on the system. If a
// specific device is not supported, check that the right error is returned.
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0f);
+ ConstantR0<float>(&builder, 42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
for (int d = 0; d < local_client_->device_count(); ++d) {
if (!local_client_->device_ordinal_supported(d)) {
@@ -638,7 +631,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
// Try running computations on devices with device ordinal values which do not
// exist.
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0f);
+ ConstantR0<float>(&builder, 42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
auto execute_status =
@@ -655,7 +648,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
// Run a computation on a specific stream on each device on the system.
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0f);
+ ConstantR0<float>(&builder, 42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
for (int d = 0; d < local_client_->device_count(); ++d) {
@@ -691,7 +684,7 @@ XLA_TEST_F(LocalClientExecuteTest,
wrong_stream.Init();
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0f);
+ ConstantR0<float>(&builder, 42.0f);
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions().set_stream(&wrong_stream));
@@ -708,7 +701,7 @@ XLA_TEST_F(LocalClientExecuteTest,
TestAllocator allocator(wrong_platform);
XlaBuilder builder(TestName());
- auto y = builder.ConstantR0<float>(123.0f);
+ ConstantR0<float>(&builder, 123.0f);
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
@@ -721,7 +714,7 @@ XLA_TEST_F(LocalClientExecuteTest,
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
// Try to run a computation on a stream that has not been initialized.
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(42.0f);
+ ConstantR0<float>(&builder, 42.0f);
LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
se::StreamExecutor* executor =
@@ -744,26 +737,26 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto tuple12 = builder.Tuple(
- {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
- auto tuple21 = builder.Tuple(
- {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
+ ConstantR1<float>(&builder, vec2)});
+ auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
+ ConstantR1<float>(&builder, vec1)});
+ Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
ScopedShapedBuffer result =
ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
std::unique_ptr<Literal> tuple_literal = ShapedBufferToLiteral(result);
- LiteralTestUtil::ExpectR1Equal<float>(
- {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0}));
- LiteralTestUtil::ExpectR1Equal<float>(
- {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1}));
+ LiteralTestUtil::ExpectR1Equal<float>({2.0f, 4.0f, 6.0f},
+ LiteralSlice(*tuple_literal, {0}));
+ LiteralTestUtil::ExpectR1Equal<float>({1.0f, 2.0f, 3.0f},
+ LiteralSlice(*tuple_literal, {1}));
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
- auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
+ auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
+ Add(x, y);
Shape argument_layout =
ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
@@ -775,7 +768,7 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
executable_status.ConsumeValueOrDie();
auto x_array =
- LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
+ LiteralToShapedBuffer(*LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
ScopedShapedBuffer result =
executable->Run({&x_array}, DefaultExecutableRunOptions())
.ConsumeValueOrDie();
@@ -799,29 +792,29 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
};
// Array shapes.
- test_to_device_and_back(*Literal::CreateR0<float>(42.0));
- test_to_device_and_back(*Literal::CreateR0<bool>(true));
- test_to_device_and_back(*Literal::CreateR1<float>({1.0, 42.0, 744.4}));
+ test_to_device_and_back(*LiteralUtil::CreateR0<float>(42.0));
+ test_to_device_and_back(*LiteralUtil::CreateR0<bool>(true));
+ test_to_device_and_back(*LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
test_to_device_and_back(
- *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*Literal::CreateR2<int32>({{2, 1}, {4444, 56}}));
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(*LiteralUtil::CreateR2<int32>({{2, 1}, {4444, 56}}));
// Null shape (empty tuple).
- test_to_device_and_back(*Literal::MakeTuple({}));
+ test_to_device_and_back(*LiteralUtil::MakeTuple({}));
// Non-nested tuples.
test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR0<float>(12223.0).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(12223.0).get()}));
test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
- Literal::CreateR0<float>(123456.0).get()}));
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<float>(123456.0).get()}));
// Nested tuple.
- test_to_device_and_back(*Literal::MakeTuple(
- {Literal::MakeTuple({Literal::CreateR1<float>({1.0, -42.0}).get(),
- Literal::CreateR0<float>(123456.0).get()})
+ test_to_device_and_back(*LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<float>(123456.0).get()})
.get(),
- Literal::CreateR0<bool>(false).get()}));
+ LiteralUtil::CreateR0<bool>(false).get()}));
}
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
@@ -839,13 +832,38 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
};
test_to_device_and_back(
- *Literal::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
- test_to_device_and_back(*Literal::CreateR2<int64>({{2, 1}, {4444, 56}}));
+ *LiteralUtil::CreateR2<double>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
+ test_to_device_and_back(*LiteralUtil::CreateR2<int64>({{2, 1}, {4444, 56}}));
test_to_device_and_back(
- *Literal::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
- test_to_device_and_back(
- *Literal::MakeTuple({Literal::CreateR1<double>({1.0, -42.0}).get(),
- Literal::CreateR0<int64>(123456789000LL).get()}));
+ *LiteralUtil::CreateR2<uint64>({{20000000000ULL, 1}, {4444, 56}}));
+ test_to_device_and_back(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<double>({1.0, -42.0}).get(),
+ LiteralUtil::CreateR0<int64>(123456789000LL).get()}));
+}
+
+XLA_TEST_F(LocalClientExecuteTest, InfeedTest) {
+ XlaBuilder builder(TestName());
+ const Shape shape = ShapeUtil::MakeShape(F32, {3});
+ auto in = Infeed(&builder, shape);
+ auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
+ Add(in, constant);
+
+ std::unique_ptr<Literal> result;
+ std::unique_ptr<tensorflow::Thread> thread(
+ tensorflow::Env::Default()->StartThread(
+ tensorflow::ThreadOptions(), "execute_thread", [&] {
+ result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
+ builder.Build().ValueOrDie(), /*arguments=*/{}));
+ }));
+
+ ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
+ *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
+ local_client_->default_device_ordinal()));
+
+ // Join the thread.
+ thread.reset();
+
+ LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, *result);
}
// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel.
@@ -853,10 +871,10 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3});
- auto in = builder.Infeed(shape);
- auto constant = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f});
- auto sum = builder.Add(in, constant);
- builder.Outfeed(sum, shape, /*outfeed_config=*/"");
+ auto in = Infeed(&builder, shape);
+ auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
+ auto sum = Add(in, constant);
+ Outfeed(sum, shape, /*outfeed_config=*/"");
std::unique_ptr<tensorflow::Thread> thread(
tensorflow::Env::Default()->StartThread(
@@ -864,7 +882,7 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
[&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
- *Literal::CreateR1<float>({-5.0, 123.0, 42.0}),
+ *LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
@@ -891,17 +909,19 @@ void BM_LocalClientOverhead(int num_iters) {
// Use a tiny add operation as the computation.
XlaBuilder builder("Add");
auto shape = ShapeUtil::MakeShape(F32, {2, 3});
- auto x = builder.Parameter(0, shape, "x");
- builder.Add(x, x);
+ auto x = Parameter(&builder, 0, shape, "x");
+ Add(x, x);
auto computation = builder.Build().ConsumeValueOrDie();
auto buffer =
transfer_manager
->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
.ConsumeValueOrDie();
- auto literal = Literal::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
- ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
- executors[device_ordinal], *literal, buffer));
+ auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
+ auto stream =
+ client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(stream.get(), *literal,
+ buffer));
const int kWarmups = 2;
@@ -911,11 +931,8 @@ void BM_LocalClientOverhead(int num_iters) {
std::unique_ptr<LocalExecutable> executable =
executable_status.ConsumeValueOrDie();
- se::Stream stream(executors[client->default_device_ordinal()]);
- stream.Init();
-
ExecutableRunOptions run_options;
- run_options.set_allocator(&allocator).set_stream(&stream);
+ run_options.set_allocator(&allocator).set_stream(stream.get());
for (int i = 0; i < kWarmups; ++i) {
auto result = executable->Run({&buffer}, run_options);
diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc
index c0c02e584c..cdf70ee418 100644
--- a/tensorflow/compiler/xla/tests/log_test.cc
+++ b/tensorflow/compiler/xla/tests/log_test.cc
@@ -30,8 +30,8 @@ class LogTest : public ClientLibraryTestBase {};
XLA_TEST_F(LogTest, LogZeroValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR3FromArray3D<float>(Array3D<float>(3, 0, 0));
- builder.Log(x);
+ auto x = ConstantR3FromArray3D<float>(&builder, Array3D<float>(3, 0, 0));
+ Log(x);
ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 0), {},
ErrorSpec(0.0001));
@@ -42,8 +42,8 @@ TEST_F(LogTest, LogTenValues) {
5.0, 6.0, -7.0, -8.0, 9.0};
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(input);
- builder.Log(x);
+ auto x = ConstantR1<float>(&builder, input);
+ Log(x);
std::vector<float> expected;
expected.reserve(input.size());
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 7df45bebeb..7ddc636931 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -52,9 +52,9 @@ class MapTest : public ClientLibraryTestBase {
// 1.0f ---------/
XlaComputation CreateAdderToOne() {
XlaBuilder mapped_builder(TestName());
- auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto one = mapped_builder.ConstantR0<float>(1.0);
- mapped_builder.Add(x, one);
+ auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = ConstantR0<float>(&mapped_builder, 1.0);
+ Add(x, one);
auto computation_status = mapped_builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -62,9 +62,9 @@ class MapTest : public ClientLibraryTestBase {
XlaComputation CreateMax() {
XlaBuilder b(TestName());
- auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- b.Max(lhs, rhs);
+ auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Max(lhs, rhs);
auto computation_status = b.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -75,8 +75,8 @@ class MapTest : public ClientLibraryTestBase {
template <class T>
XlaComputation CreateScalarOne() {
XlaBuilder mapped_builder("scalar_one");
- (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- mapped_builder.ConstantR0<T>(1);
+ (void)Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ ConstantR0<T>(&mapped_builder, 1);
auto computation_status = mapped_builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -89,9 +89,9 @@ class MapTest : public ClientLibraryTestBase {
// 2.0f ---------/
XlaComputation CreateMulByTwo() {
XlaBuilder mapped_builder(TestName());
- auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto two = mapped_builder.ConstantR0<float>(2.0);
- mapped_builder.Mul(x, two);
+ auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto two = ConstantR0<float>(&mapped_builder, 2.0);
+ Mul(x, two);
auto computation_status = mapped_builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -107,10 +107,10 @@ class MapTest : public ClientLibraryTestBase {
// 1.0f ---------/
XlaComputation CreateAdderToOneTimesItself() {
XlaBuilder mapped_builder(TestName());
- auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto one = mapped_builder.ConstantR0<float>(1.0);
- auto adder_to_one = mapped_builder.Add(x, one);
- mapped_builder.Mul(x, adder_to_one);
+ auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = ConstantR0<float>(&mapped_builder, 1.0);
+ auto adder_to_one = Add(x, one);
+ Mul(x, adder_to_one);
auto computation_status = mapped_builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -125,10 +125,10 @@ class MapTest : public ClientLibraryTestBase {
XlaComputation CreateMapPlusN(const XlaComputation& embedded_computation,
float n) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto map = builder.Map({x}, embedded_computation, {});
- auto constant_n = builder.ConstantR0<float>(n);
- builder.Add(map, constant_n);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto map = Map(&builder, {x}, embedded_computation, {});
+ auto constant_n = ConstantR0<float>(&builder, n);
+ Add(map, constant_n);
auto computation_status = builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -138,9 +138,9 @@ class MapTest : public ClientLibraryTestBase {
// defined by (x, y) -> x > y.
XlaComputation CreateGt() {
XlaBuilder b("Gt");
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- b.Gt(x, y);
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Gt(x, y);
auto computation_status = b.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -155,11 +155,11 @@ class MapTest : public ClientLibraryTestBase {
// z {R0F32} ---------------/
XlaComputation CreateTernaryAdder() {
XlaBuilder mapped_builder("TernaryAdder");
- auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z");
- auto xy = mapped_builder.Add(x, y);
- mapped_builder.Add(xy, z);
+ auto x = Parameter(&mapped_builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&mapped_builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto z = Parameter(&mapped_builder, 2, ShapeUtil::MakeShape(F32, {}), "z");
+ auto xy = Add(x, y);
+ Add(xy, z);
auto computation_status = mapped_builder.Build();
TF_CHECK_OK(computation_status.status());
return computation_status.ConsumeValueOrDie();
@@ -169,12 +169,12 @@ class MapTest : public ClientLibraryTestBase {
TEST_F(MapTest, MapEachElemPlusOneR0) {
// Applies lambda (x) (+ x 1)) to an input scalar.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(42.0);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateAdderToOne(), {});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateAdderToOne(), {});
ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
ErrorSpec(0.01f));
@@ -183,12 +183,12 @@ TEST_F(MapTest, MapEachElemPlusOneR0) {
XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateAdderToOne(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -198,12 +198,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateAdderToOne(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateAdderToOne(), {0});
ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
{param0_data.get()}, ErrorSpec(0.01f));
@@ -212,12 +212,12 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) {
TEST_F(MapTest, MapEachF32ElementToS32Constant) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateScalarOne<int32>(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateScalarOne<int32>(), {0});
ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
}
@@ -225,12 +225,12 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) {
TEST_F(MapTest, MapEachF32ElementToU32Constant) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateScalarOne<uint32>(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateScalarOne<uint32>(), {0});
ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
}
@@ -239,12 +239,12 @@ TEST_F(MapTest, MapEachElemLongerChainR1) {
// Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
+ LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateAdderToOneTimesItself(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateAdderToOneTimesItself(), {0});
ComputeAndCompareR1<float>(
&builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
@@ -255,13 +255,13 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
// Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
- builder.Map({map1}, CreateMulByTwo(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
+ Map(&builder, {map1}, CreateMulByTwo(), {0});
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -272,13 +272,13 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
// maps (lambda (x) (* x 2)) on the result.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- auto map1 = builder.Map({param}, CreateAdderToOne(), {0});
- builder.Map({map1}, CreateMulByTwo(), {0});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto map1 = Map(&builder, {param}, CreateAdderToOne(), {0});
+ Map(&builder, {map1}, CreateMulByTwo(), {0});
ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
{param0_data.get()}, ErrorSpec(0.01f));
@@ -287,13 +287,13 @@ TEST_F(MapTest, MapMultipleMapsR1S4) {
TEST_F(MapTest, MapEachElemPlusOneR2) {
// Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
{{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param}, CreateAdderToOne(), {0, 1});
+ auto param = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param}, CreateAdderToOne(), {0, 1});
Array2D<float> expected_array(
{{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
@@ -319,10 +319,10 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
auto embed3 = CreateMapPlusN(embed1, 4.0);
XlaBuilder embed4_builder("embed4");
- auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x");
- auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {});
- auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {});
- embed4_builder.Add(embed4_map_lhs, embed4_map_rhs);
+ auto embed4_param = Parameter(&embed4_builder, 0, scalar_shape, "x");
+ auto embed4_map_lhs = Map(&embed4_builder, {embed4_param}, embed2, {});
+ auto embed4_map_rhs = Map(&embed4_builder, {embed4_param}, embed3, {});
+ Add(embed4_map_lhs, embed4_map_rhs);
auto embed4_status = embed4_builder.Build();
ASSERT_IS_OK(embed4_status.status());
auto embed4 = embed4_status.ConsumeValueOrDie();
@@ -330,11 +330,11 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
auto embed5 = CreateMapPlusN(embed2, 6.0);
XlaBuilder builder(TestName());
- auto constant_42 = builder.ConstantR0<float>(42.0);
- auto constant_7 = builder.ConstantR0<float>(7.0);
- auto map_42 = builder.Map({constant_42}, embed5, {});
- auto map_7 = builder.Map({constant_7}, embed4, {});
- builder.Add(map_42, map_7);
+ auto constant_42 = ConstantR0<float>(&builder, 42.0);
+ auto constant_7 = ConstantR0<float>(&builder, 7.0);
+ auto map_42 = Map(&builder, {constant_42}, embed5, {});
+ auto map_7 = Map(&builder, {constant_7}, embed4, {});
+ Add(map_42, map_7);
ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
@@ -343,17 +343,18 @@ TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder), {0});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, CreateScalarAddComputation(F32, &builder),
+ {0});
ComputeAndCompareR1<float>(&builder, {7.3f, 7.7, 4.3f, 0},
{param0_data.get(), param1_data.get()},
@@ -364,20 +365,20 @@ TEST_F(MapTest, MapBinaryAdder) {
// for Map that used to fail in shape inference (b/28989438).
XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2WithLayout(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR2WithLayout(
{{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder),
- {0, 1});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
+ {0, 1});
Array2D<int32> expected(2, 2);
expected(0, 0) = 11;
@@ -391,19 +392,19 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) {
XLA_TEST_F(MapTest, AddR3_3x0x2) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder),
- {0, 1, 2});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, CreateScalarAddComputation(S32, &builder),
+ {0, 1, 2});
ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
{param0_data.get(), param1_data.get()});
@@ -413,22 +414,22 @@ TEST_F(MapTest, MapTernaryAdder) {
// Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param2_literal =
- Literal::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
+ LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
std::unique_ptr<GlobalData> param2_data =
client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- auto param2 = builder.Parameter(2, param2_literal->shape(), "param2");
- builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ auto param2 = Parameter(&builder, 2, param2_literal->shape(), "param2");
+ Map(&builder, {param0, param1, param2}, CreateTernaryAdder(), {0});
ComputeAndCompareR1<float>(
&builder, {-2.7f, -92.3f, -895.7f, -400.0f},
@@ -440,7 +441,8 @@ TEST_F(MapTest, MapGt) {
// Maps (x,y) -> x > y onto two R1F32 vectors.
XlaBuilder b(TestName());
auto gt = CreateGt();
- b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt, {0});
+ Map(&b, {ConstantR1<float>(&b, {1, 20}), ConstantR1<float>(&b, {10, 2})}, gt,
+ {0});
ComputeAndCompareR1<bool>(&b, {false, true}, {});
}
@@ -449,15 +451,15 @@ TEST_F(MapTest, NestedBinaryMap) {
{
// max_with_square(x) = do max(x, x^2) via a map.
XlaBuilder b("max_with_square");
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- b.Map({x, b.Mul(x, x)}, CreateMax(), {});
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ Map(&b, {x, Mul(x, x)}, CreateMax(), {});
auto computation_status = b.Build();
ASSERT_IS_OK(computation_status.status());
max_with_square = computation_status.ConsumeValueOrDie();
}
XlaBuilder b(TestName());
- auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
- b.Map({input}, max_with_square, {0});
+ auto input = ConstantR1<float>(&b, {0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
+ Map(&b, {input}, max_with_square, {0});
ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
}
@@ -468,30 +470,29 @@ TEST_F(MapTest, MapOperantionWithBuildError) {
XlaBuilder builder(TestName());
auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
- auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y");
- sub_builder->Add(x, y);
+ auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(U16, {}), "y");
+ Add(x, y);
auto error_add = sub_builder->BuildAndNoteError();
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<Literal> param1_literal =
- Literal::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, error_add, {0});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, error_add, {0});
StatusOr<XlaComputation> computation_status = builder.Build();
ASSERT_TRUE(!computation_status.ok());
- EXPECT_THAT(
- computation_status.status().ToString(),
- ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with "
- "different element types: f32[] and u16[]"));
+ EXPECT_THAT(computation_status.status().ToString(),
+ ::testing::HasSubstr("error from: ErrorAdd: Binary op add with "
+ "different element types: f32[] and u16[]"));
}
// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
@@ -507,21 +508,21 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) {
XlaBuilder builder(TestName());
auto sub_builder = builder.CreateSubBuilder("power");
- auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- sub_builder->Pow(x, y);
+ auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Pow(x, y);
auto power = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, power, {});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, power, {});
ComputeAndCompareR0<float>(&builder, 32.0f,
{param0_data.get(), param1_data.get()},
@@ -534,21 +535,21 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) {
XlaBuilder builder(TestName());
auto sub_builder = builder.CreateSubBuilder("power");
- auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- sub_builder->Sub(y, x); // note that this is y - x, not x - y
+ auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(sub_builder.get(), 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Sub(y, x); // note that this is y - x, not x - y
auto sub_opposite = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(2.0f);
- std::unique_ptr<Literal> param1_literal = Literal::CreateR0<float>(5.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
- builder.Map({param0, param1}, sub_opposite, {});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, param1_literal->shape(), "param1");
+ Map(&builder, {param0, param1}, sub_opposite, {});
ComputeAndCompareR0<float>(
&builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f));
@@ -560,16 +561,16 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
XlaBuilder builder(TestName());
auto sub_builder = builder.CreateSubBuilder("power");
- auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- sub_builder->Mul(x, x);
+ auto x = Parameter(sub_builder.get(), 0, ShapeUtil::MakeShape(F32, {}), "x");
+ Mul(x, x);
auto square = sub_builder->BuildAndNoteError();
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(10.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(10.0f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
- builder.Map({param0}, square, {});
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
+ Map(&builder, {param0}, square, {});
ComputeAndCompareR0<float>(&builder, 100.0f, {param0_data.get()},
ErrorSpec(0.01f));
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index 27fd36e06a..069b8a881f 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -56,15 +56,15 @@ TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32);
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
using T = TypeParam;
XlaBuilder builder("exp_2x2");
- auto data = builder.ConstantR2FromArray2D<T>({
- {1.0f, 0.0f}, // row 0
- {-1.0f, 0.5f}, // row 1
- });
- builder.Exp(data);
+ auto data = ConstantR2FromArray2D<T>(&builder, {
+ {1.0f, 0.0f}, // row 0
+ {-1.0f, 0.5f}, // row 1
+ });
+ Exp(data);
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
- {0.36788f, 1.64872f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{2.71828f, 1.00000f}, // row 0
+ {0.36788f, 1.64872f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
}
@@ -76,43 +76,43 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
// add_half(x) = x + 0.5
XlaBuilder builder("add_half");
auto x_value =
- builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({}), "x_value");
- auto half = builder.ConstantR0<T>(static_cast<T>(0.5));
- builder.Add(x_value, half);
+ Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({}), "x_value");
+ auto half = ConstantR0<T>(&builder, static_cast<T>(0.5));
+ Add(x_value, half);
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
add_half = computation_status.ConsumeValueOrDie();
}
XlaBuilder builder("map_2x2");
- auto data = builder.ConstantR2FromArray2D<T>({
- {1.0f, 0.0f}, // row 0
- {-1.0f, 0.5f}, // row 1
- });
- auto map = builder.Map({data}, add_half, {0, 1});
+ auto data = ConstantR2FromArray2D<T>(&builder, {
+ {1.0f, 0.0f}, // row 0
+ {-1.0f, 0.5f}, // row 1
+ });
+ Map(&builder, {data}, add_half, {0, 1});
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
- {-0.5f, 1.0f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{1.5f, 0.5f}, // row 0
+ {-0.5f, 1.0f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
}
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
using T = TypeParam;
XlaBuilder builder("max_2x2");
- auto lhs = builder.ConstantR2FromArray2D<T>({
- {7.0f, 2.0f}, // row 0
- {3.0f, -4.0f}, // row 1
- });
- auto rhs = builder.ConstantR2FromArray2D<T>({
- {5.0f, 6.0f}, // row 0
- {1.0f, -8.0f}, // row 1
- });
- auto max = builder.Max(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, {
+ {7.0f, 2.0f}, // row 0
+ {3.0f, -4.0f}, // row 1
+ });
+ auto rhs = ConstantR2FromArray2D<T>(&builder, {
+ {5.0f, 6.0f}, // row 0
+ {1.0f, -8.0f}, // row 1
+ });
+ Max(lhs, rhs);
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
- {3.0f, -4.0f}}); // row 1
+ LiteralUtil::CreateR2FromArray2D<T>({{7.0f, 6.0f}, // row 0
+ {3.0f, -4.0f}}); // row 1
this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
}
@@ -137,9 +137,9 @@ class TestLinspaceMaxParametric
XlaBuilder builder(
tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
- auto lhs = builder.ConstantR2FromArray2D<T>(*alhs);
- auto rhs = builder.ConstantR2FromArray2D<T>(*arhs);
- auto max = builder.Max(lhs, rhs);
+ auto lhs = ConstantR2FromArray2D<T>(&builder, *alhs);
+ auto rhs = ConstantR2FromArray2D<T>(&builder, *arhs);
+ Max(lhs, rhs);
Array2D<T> expected(rows, cols);
for (int row = 0; row < rows; ++row) {
@@ -200,31 +200,33 @@ class MatOpsDotAddTest
TF_ASSERT_OK_AND_ASSIGN(
auto lhs_handle,
- client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ lhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
TF_ASSERT_OK_AND_ASSIGN(
auto rhs_handle,
- client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
+ client_->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
XlaBuilder builder(TestName());
- auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
+ auto lhs_arg = Parameter(&builder, 0, lhs_shape, "lhs");
auto lhs_mat_arg = lhs_arg;
if (transpose) {
- lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0});
+ lhs_mat_arg = Transpose(lhs_mat_arg, {1, 0});
}
- auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs");
- auto result = builder.Dot(lhs_mat_arg, rhs_arg);
+ auto rhs_arg = Parameter(&builder, 1, rhs_shape, "rhs");
+ auto result = Dot(lhs_mat_arg, rhs_arg);
Array2D<T> expected;
if (add_lhs) {
- result = builder.Add(result, lhs_arg);
+ result = Add(result, lhs_arg);
if (transpose) {
expected = Array2D<T>({{47.0f, 52.0f}, {71.0f, 78.0f}});
} else {
expected = Array2D<T>({{35.0f, 39.0f}, {81.0f, 89.0f}});
}
} else {
- result = builder.Add(result, rhs_arg);
+ result = Add(result, rhs_arg);
if (transpose) {
expected = Array2D<T>({{56.0f, 61.0f}, {80.0f, 87.0f}});
} else {
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index 0791a71aac..e576f000ef 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -33,9 +33,10 @@ class SliceTest : public ClientLibraryTestBase {};
XLA_TEST_F(SliceTest, Slice2D) {
XlaBuilder builder("slice_2d");
- auto original = builder.ConstantR2<float>(
+ auto original = ConstantR2<float>(
+ &builder,
{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
- builder.Slice(original, {2, 1}, {4, 3}, {1, 1});
+ Slice(original, {2, 1}, {4, 3}, {1, 1});
Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -45,8 +46,8 @@ XLA_TEST_F(SliceTest, Slice3D) {
XlaBuilder builder("slice_3d");
Array3D<float> array_3d(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
- auto original = builder.ConstantR3FromArray3D<float>(array_3d);
- builder.Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1});
+ auto original = ConstantR3FromArray3D<float>(&builder, array_3d);
+ Slice(original, {0, 0, 1}, {2, 1, 2}, {1, 1, 1});
Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 3cbb2452fb..eb06b115da 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -60,7 +60,7 @@ class MultiOutputFusionTest : public HloTestBase {
const Shape elem_shape2 = ShapeUtil::MakeShape(F32, {size, size});
auto const0 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(8.0f)));
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(8.0f)));
auto param0 = builder.AddInstruction(
HloInstruction::CreateParameter(0, elem_shape0, "0"));
@@ -105,8 +105,9 @@ class MultiOutputFusionTest : public HloTestBase {
Literal expect(ShapeUtil::MakeShape(F32, {size, size}));
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
- auto actual = ExecuteAndTransfer(
- std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
+ auto actual =
+ ExecuteAndTransfer(std::move(hlo_module),
+ {LiteralUtil::CreateR0<float>(-9.0f).get(), &arg1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -165,7 +166,8 @@ class MultiOutputFusionTest : public HloTestBase {
Literal input1(ShapeUtil::MakeShape(F64, {size}));
input1.PopulateWithValue(1.);
- Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
+ Literal expect =
+ std::move(*LiteralUtil::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
@@ -198,16 +200,16 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(
- Literal::MakeTupleOwned(Literal::CreateR0<int32>(42)),
- Literal::CreateR0<float>(1.0)),
- Literal::MakeTupleOwned(Literal::CreateR0<float>(3.0),
- Literal::CreateR0<int32>(4)));
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
+ auto param = LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)),
+ LiteralUtil::CreateR0<float>(1.0)),
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<float>(3.0),
+ LiteralUtil::CreateR0<int32>(4)));
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *result, *Literal::MakeTupleOwned(Literal::CreateR0<int32>(42))));
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR0<int32>(42)), *result));
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
@@ -232,11 +234,10 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *result, *Literal::CreateR1<float>({0.0, 4.0, 9.0, 1.0})));
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, -1.0});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0, 1.0}, *result);
}
XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
@@ -266,11 +267,10 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR1<float>({1.0, 2.0, 3.0});
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
- EXPECT_TRUE(LiteralTestUtil::Equal(
- *result, *Literal::CreateR1<float>({0.0, 4.0, 9.0})));
+ auto param = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ LiteralTestUtil::ExpectR1Equal<float>({0.0, 4.0, 9.0}, *result);
}
const char* const kScalarOps = R"(
@@ -310,13 +310,15 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *result,
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{3, 7}, {11, 15}}),
- Literal::CreateR2<float>({{5, 16}, {36, 64}}))));
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
+ *result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -340,13 +342,15 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *result, *Literal::MakeTupleOwned(
- Literal::CreateR2<float>({{6, 8}, {10, 12}}),
- Literal::CreateR2<float>({{25, 36}, {49, 64}}))));
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
+ *result));
}
XLA_TEST_F(MultiOutputFusionTest,
@@ -357,9 +361,9 @@ XLA_TEST_F(MultiOutputFusionTest,
c0 = f32[] constant(0)
r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
- c1 = f32[] constant(5)
+ c1 = f32[] constant(1.17549e-38)
r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max
- r3 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Add
+ r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add
ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3)
}
@@ -371,13 +375,196 @@ XLA_TEST_F(MultiOutputFusionTest,
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
- TF_ASSERT_OK_AND_ASSIGN(auto result,
- Execute(std::move(module), {param.get()}));
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR1<float>({36, 64}),
+ LiteralUtil::CreateR1<float>({66, 138})),
+ *result));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0})
+ tuple(p0, r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}),
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}})),
+ *result));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0})
+ tuple(r1, mul, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{6, 8}, {10, 12}}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR2<float>({{25, 36}, {49, 64}})),
+ *result));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ c0 = f32[] constant(0)
+ r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(p0, p0)
+ c1 = f32[] constant(5)
+ b1 = f32[2,2,2]{2,1,0} broadcast(c1), dimensions={}
+ mul2 = f32[2,2,2]{2,1,0} multiply(p0, b1)
+ ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0})
+ tuple(r1, mul, mul2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param =
+ LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR1<float>({14, 22}),
+ LiteralUtil::CreateR3<float>(
+ {{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}),
+ LiteralUtil::CreateR3<float>(
+ {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})),
+ *result));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce {
+ p0 = f32[2,2,2]{2,1,0} parameter(0)
+ init1 = f32[] parameter(1)
+ init2 = f32[] parameter(2)
+ r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add
+ r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2)
+ }
+
+ ENTRY reduce {
+ p = f32[2,2,2]{2,1,0} parameter(0)
+ i = f32[] parameter(1)
+ j = f32[] parameter(2)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput,
+ calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param =
+ LiteralUtil::CreateR3<float>({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}});
+ auto init1 = LiteralUtil::CreateR0<float>(5);
+ auto init2 = LiteralUtil::CreateR0<float>(6);
+ std::unique_ptr<Literal> result = ExecuteNoHloPasses(
+ std::move(module), {param.get(), init1.get(), init2.get()});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{167, 172}, {176, 180}}),
+ LiteralUtil::CreateR2<float>({{6, 6}, {6, 8}})),
+ *result));
+}
+
+XLA_TEST_F(MultiOutputFusionTest,
+ DISABLED_ON_CPU(MultiOutputReduceFusionDifferentElementTypes)) {
+ const string testcase = tensorflow::strings::StrCat(kScalarOps, R"(
+ fused_reduce (p0: f16[2,2,2]) -> (f32[2,2], f32[2,2], f16[2,2,2]) {
+ p0 = f16[2,2,2]{2,1,0} parameter(0)
+ convert = f32[2,2,2]{2,1,0} convert(p0)
+ c0 = f32[] constant(0)
+ r1 = f32[2,2]{1,0} reduce(convert, c0), dimensions={2}, to_apply=Add
+ mul = f32[2,2,2]{2,1,0} multiply(convert, convert)
+ c1 = f32[] constant(5)
+ r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max
+ ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0})
+ tuple(r1, r2, p0)
+ }
+
+ ENTRY reduce {
+ p = f16[2,2,2]{2,1,0} parameter(0)
+ ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}, f16[2,2,2]{2,1,0}) fusion(p),
+ kind=kInput, calls=fused_reduce
+ })");
+ auto module =
+ HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+ .ValueOrDie();
+ auto param = LiteralUtil::CreateR3<Eigen::half>(
+ {{{Eigen::half(1), Eigen::half(2)}, {Eigen::half(3), Eigen::half(4)}},
+ {{Eigen::half(5), Eigen::half(6)}, {Eigen::half(7), Eigen::half(8)}}});
+ std::unique_ptr<Literal> result =
+ ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *result, *Literal::MakeTupleOwned(Literal::CreateR1<float>({14, 22}),
- Literal::CreateR1<float>({36, 64}),
- Literal::CreateR1<float>({391, 463}))));
+ *LiteralUtil::MakeTupleOwned(
+ LiteralUtil::CreateR2<float>({{3, 7}, {11, 15}}),
+ LiteralUtil::CreateR2<float>({{5, 16}, {36, 64}}),
+ LiteralUtil::CreateR3<Eigen::half>(
+ {{{Eigen::half(1), Eigen::half(2)},
+ {Eigen::half(3), Eigen::half(4)}},
+ {{Eigen::half(5), Eigen::half(6)},
+ {Eigen::half(7), Eigen::half(8)}}})),
+ *result));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
index ce295b832d..e428fa9b5e 100644
--- a/tensorflow/compiler/xla/tests/pad_test.cc
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -93,8 +93,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(0);
- b.Pad(AddParam(*Literal::CreateR1<float>({}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, {}, {}, DefaultErrorSpec());
}
@@ -108,8 +108,8 @@ XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) {
dimension->set_edge_padding_high(4);
dimension->set_interior_padding(7);
- b.Pad(AddParam(*Literal::CreateR1<float>({}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
DefaultErrorSpec());
}
@@ -123,16 +123,17 @@ XLA_TEST_P(PadTestFloat, Pad1DS3Array) {
dimension->set_edge_padding_high(0);
dimension->set_interior_padding(1);
- b.Pad(AddParam(*Literal::CreateR1<float>({1, 2, 3}), &b),
- AddParam(*Literal::CreateR0<float>(0.1), &b), padding_config);
+ Pad(AddParam(*LiteralUtil::CreateR1<float>({1, 2, 3}), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(0.1), &b), padding_config);
std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
ComputeAndCompareR1<float>(&b, expected, {}, DefaultErrorSpec());
}
XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) {
XlaBuilder b(TestName());
- b.Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
- AddParam(*Literal::CreateR0<float>(1.5), &b), r4_padding_on_dim0_dim1_);
+ Pad(AddParam(Array4D<float>(2, 0, 3, 2), &b),
+ AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ r4_padding_on_dim0_dim1_);
ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
DefaultErrorSpec());
}
@@ -147,8 +148,8 @@ TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) {
});
input->FillWithYX(input_xy);
- b.Pad(AddParam(*input, &b), AddParam(*Literal::CreateR0<float>(1.5), &b),
- r4_padding_on_dim0_dim1_);
+ Pad(AddParam(*input, &b), AddParam(*LiteralUtil::CreateR0<float>(1.5), &b),
+ r4_padding_on_dim0_dim1_);
auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
expected->Fill(1.5);
@@ -166,8 +167,9 @@ TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) {
const float pad_value = 1.5f;
Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
- b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0<float>(pad_value), &b),
- r4_padding_on_dim0_dim1_);
+ Pad(AddParam(input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b),
+ r4_padding_on_dim0_dim1_);
auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
expected->Fill(pad_value);
@@ -205,11 +207,11 @@ TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) {
const float pad_value = -5.123f;
Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
input = input->Relayout(layout);
- b.Pad(AddParam(*input, &b),
- AddParam(*Literal::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(*input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 1, 5, 8);
expected_array.Fill(pad_value);
@@ -251,11 +253,11 @@ XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
input_array(0, 0, 0, 0) = 1.0f;
input_array(0, 24, 6, 6) = 2.0f;
input_array(0, 17, 2, 5) = 3.0f;
- auto input = Literal::CreateR4FromArray4D<float>(input_array);
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
input = input->Relayout(layout);
- b.Pad(AddParam(*input, &b),
- AddParam(*Literal::CreateR0<float>(pad_value), &b), padding_config);
+ Pad(AddParam(*input, &b),
+ AddParam(*LiteralUtil::CreateR0<float>(pad_value), &b), padding_config);
Array4D<float> expected_array(1, 25, 17, 11);
expected_array.Fill(pad_value);
@@ -275,8 +277,8 @@ XLA_TEST_F(PadTest, Pad4DU8Array) {
});
input->FillWithYX(input_xy);
- b.Pad(AddParam(*input, &b), b.ConstantR0<uint8>(35),
- r4_padding_on_dim0_dim1_);
+ Pad(AddParam(*input, &b), ConstantR0<uint8>(&b, 35),
+ r4_padding_on_dim0_dim1_);
auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
expected->Fill(35);
@@ -294,16 +296,16 @@ XLA_TEST_F(PadTest, Pad4DPredArray) {
// Since bool is currently not well supported, use Broadcast operation to
// create the operand for Pad.
- auto input = b.Broadcast(b.ConstantR0<bool>(true), {1, 1, 3, 2});
+ auto input = Broadcast(ConstantR0<bool>(&b, true), {1, 1, 3, 2});
auto padded =
- b.Pad(input, b.ConstantR0<bool>(false), r4_padding_on_dim0_dim1_);
+ Pad(input, ConstantR0<bool>(&b, false), r4_padding_on_dim0_dim1_);
// For the same reason, use Select to convert boolean values to int32.
auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
zeros->Fill(0);
ones->Fill(1);
- b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b));
+ Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b));
auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
expected->Fill(0);
@@ -329,7 +331,7 @@ XLA_TEST_P(PadTestFloat, Large2DPad) {
padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
100 * dim);
}
- b.Pad(input, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b), padding_config);
auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -351,7 +353,8 @@ XLA_TEST_P(PadTestFloat, AllTypes2DPad) {
padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
padding_config.mutable_dimensions(1)->set_interior_padding(2);
- b.Pad(input, AddParam(*Literal::CreateR0<float>(3.14f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(3.14f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
ComputeAndCompareR2<float>(&b, *expected, {}, DefaultErrorSpec());
@@ -376,7 +379,8 @@ XLA_TEST_P(PadTestFloat, High2DPad) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -403,7 +407,8 @@ XLA_TEST_P(PadTestFloat, NegativePadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding);
}
- b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -430,7 +435,8 @@ XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) {
padding_config.mutable_dimensions(dim)->set_interior_padding(
interior_padding[dim]);
}
- b.Pad(input, AddParam(*Literal::CreateR0<float>(2.718f), &b), padding_config);
+ Pad(input, AddParam(*LiteralUtil::CreateR0<float>(2.718f), &b),
+ padding_config);
auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
@@ -446,12 +452,13 @@ XLA_TEST_P(PadTestFloat, ReducePad) {
XlaComputation add = CreateScalarAddComputation(FloatType(), &b);
auto reduce =
- b.Reduce(input, AddParam(*Literal::CreateR0<float>(0.0), &b), add, {0});
+ Reduce(input, AddParam(*LiteralUtil::CreateR0<float>(0.0), &b), add, {0});
PaddingConfig padding_config = MakeNoPaddingConfig(3);
padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
- b.Pad(reduce, AddParam(*Literal::CreateR0<float>(0.0f), &b), padding_config);
+ Pad(reduce, AddParam(*LiteralUtil::CreateR0<float>(0.0f), &b),
+ padding_config);
Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
{{2.0, 2.0}, {2.0, 2.0}},
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index 838f1b4e2f..8ba1d11b33 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -42,11 +42,12 @@ class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "param0");
ComputeAndCompareR0<float>(&builder, 3.14159f, {param0_data.get()},
ErrorSpec(0.0001f));
@@ -54,11 +55,11 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {0}), "param0");
ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -67,11 +68,11 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({3.14f, -100.25f});
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2}), "param0");
ComputeAndCompareR1<float>(&builder, {3.14f, -100.25f}, {param0_data.get()},
ErrorSpec(0.01f));
@@ -80,12 +81,13 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XlaBuilder builder(TestName());
string str("hello world");
- std::unique_ptr<Literal> param0_literal = Literal::CreateR1U8(str);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(
- 0, ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}), "param0");
+ Parameter(&builder, 0,
+ ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}),
+ "param0");
ComputeAndCompareR1U8(&builder, str, {param0_data.get()});
}
@@ -93,11 +95,11 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0),
{param0_data.get()}, ErrorSpec(0.01f));
@@ -105,12 +107,12 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
Array2D<float> expected_array(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
@@ -121,28 +123,28 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
XLA_TEST_F(ParamsTest, TwoParameters) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+ auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
// Use both parameters
//
// {1, 2} + {10, 20} = {11, 22}
- auto sum = builder.Add(param0, param1);
- sum = builder.Add(param0, param1);
+ auto sum = Add(param0, param1);
+ sum = Add(param0, param1);
// Use only the second parameter again, to show that it can be used
// twice and to make the computation asymmetric in the two
// parameters to test that the parameters are not swapped.
//
// {11, 22} * {10, 20} = {110, 440}
- auto prod = builder.Mul(sum, param1);
+ Mul(sum, param1);
ComputeAndCompareR1<float>(&builder, {110, 440},
{param0_data.get(), param1_data.get()},
@@ -152,12 +154,12 @@ XLA_TEST_F(ParamsTest, TwoParameters) {
XLA_TEST_F(ParamsTest, MissingParameter) {
// Test that an error is returned when a computation with an incomplete set of
// parameters (parameter numbers not contiguous from 0) is executed.
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(3.14159f);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
XlaBuilder builder(TestName());
- auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
+ Parameter(&builder, 2, ShapeUtil::MakeShape(F32, {}), "param2");
auto computation_status = builder.Build();
ASSERT_NE(computation_status.status(), Status::OK());
@@ -166,15 +168,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
XLA_TEST_F(ParamsTest, UnusedParameter) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+ Parameter(&builder, 0, literal0->shape(), "param0");
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20});
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+ Parameter(&builder, 1, literal1->shape(), "param1");
ComputeAndCompareR1<float>(&builder, {10, 20},
{param0_data.get(), param1_data.get()},
@@ -186,22 +188,23 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// unused expression.
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*literal0).ConsumeValueOrDie();
- std::unique_ptr<Literal> literal1 = Literal::CreateR1<float>({10, 20, 30});
+ std::unique_ptr<Literal> literal1 =
+ LiteralUtil::CreateR1<float>({10, 20, 30});
std::unique_ptr<GlobalData> param1_data =
client_->TransferToServer(*literal1).ConsumeValueOrDie();
- auto param0 = builder.Parameter(0, literal0->shape(), "param0");
- auto param1 = builder.Parameter(1, literal1->shape(), "param1");
- auto param2 = builder.Parameter(2, literal1->shape(), "param2");
+ auto param0 = Parameter(&builder, 0, literal0->shape(), "param0");
+ auto param1 = Parameter(&builder, 1, literal1->shape(), "param1");
+ auto param2 = Parameter(&builder, 2, literal1->shape(), "param2");
// This add is unused.
- builder.Add(param1, param2);
+ Add(param1, param2);
- builder.Neg(param0);
+ Neg(param0);
ComputeAndCompareR1<float>(
&builder, {-1, -2},
@@ -215,7 +218,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> init_value = {{0, 1}};
init_value.resize(size);
- XlaOp sum_handle = builder.ConstantR1<float>(init_value);
+ XlaOp sum_handle = ConstantR1<float>(&builder, init_value);
std::vector<float> sum = {{0, 1}};
sum.resize(size);
@@ -230,11 +233,11 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::vector<float> sum_value = {{entry0, entry1}};
sum_value.resize(size);
- std::unique_ptr<Literal> literal = Literal::CreateR1<float>(sum_value);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
param_data_owner.push_back(
client_->TransferToServer(*literal).ConsumeValueOrDie());
- XlaOp param = builder.Parameter(i, literal->shape(), "param");
- sum_handle = builder.Add(sum_handle, param);
+ XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ sum_handle = Add(sum_handle, param);
}
std::vector<GlobalData*> param_data;
@@ -260,16 +263,16 @@ XLA_TEST_F(ParamsTest,
XlaBuilder builder(TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
- XlaOp sum_handle = builder.ConstantR0<float>(0.0f);
+ XlaOp sum_handle = ConstantR0<float>(&builder, 0.0f);
float target = 0.0;
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(i);
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = builder.Parameter(i, literal->shape(), "param");
- sum_handle = builder.Add(sum_handle, param);
+ XlaOp param = Parameter(&builder, i, literal->shape(), "param");
+ sum_handle = Add(sum_handle, param);
}
std::vector<GlobalData*> param_data;
@@ -291,26 +294,26 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
XlaBuilder builder(TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
- XlaOp sum_handle = builder.ConstantR1<int32>({0, 0});
+ XlaOp sum_handle = ConstantR1<int32>(&builder, {0, 0});
int32 target = 0;
constexpr int kParamCount = 3000;
std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = Parameter(&builder, i, literal->shape(), "param");
params.push_back(param);
- sum_handle = builder.Add(sum_handle, param);
+ sum_handle = Add(sum_handle, param);
}
std::vector<XlaOp> outputs;
for (int i = 0; i < kParamCount; ++i) {
- outputs.push_back(builder.Add(params[i], sum_handle));
+ outputs.push_back(Add(params[i], sum_handle));
}
- builder.Tuple(outputs);
+ Tuple(&builder, outputs);
std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
@@ -321,10 +324,10 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({target + i, target + i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({target + i, target + i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
// Test large number of parameters flowing into a while-loop.
@@ -353,25 +356,25 @@ XLA_TEST_F(ParamsTest,
std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- XlaOp param = builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = Parameter(&builder, i, literal->shape(), "param");
params.push_back(param);
parameter_shapes.push_back(literal->shape());
}
// Add bool parameter for the loop condition. Use a parameter HLO instead of a
// constant because DCE may eliminate the while-body otherwise.
- std::unique_ptr<Literal> bool_literal = Literal::CreateR0<bool>(false);
+ std::unique_ptr<Literal> bool_literal = LiteralUtil::CreateR0<bool>(false);
param_data_owner.push_back(
std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
XlaOp bool_param =
- builder.Parameter(kParamCount, bool_literal->shape(), "bool_param");
+ Parameter(&builder, kParamCount, bool_literal->shape(), "bool_param");
params.push_back(bool_param);
parameter_shapes.push_back(bool_literal->shape());
- auto init = builder.Tuple(params);
+ auto init = Tuple(&builder, params);
// Create a computation for the condition: while(bool_param).
Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes);
@@ -379,8 +382,8 @@ XLA_TEST_F(ParamsTest,
{
XlaBuilder builder("condition");
auto condition_parameter =
- builder.Parameter(0, while_shape, "condition_parameter");
- builder.GetTupleElement(condition_parameter, kParamCount);
+ Parameter(&builder, 0, while_shape, "condition_parameter");
+ GetTupleElement(condition_parameter, kParamCount);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -389,27 +392,27 @@ XLA_TEST_F(ParamsTest,
XlaComputation body;
{
XlaBuilder builder("body");
- auto body_parameter = builder.Parameter(0, while_shape, "body_parameter");
+ auto body_parameter = Parameter(&builder, 0, while_shape, "body_parameter");
std::vector<XlaOp> updates;
for (int i = 0; i < kParamCount; ++i) {
- auto add = builder.Add(builder.GetTupleElement(body_parameter, i),
- builder.ConstantR1<int32>({1, 1}));
+ auto add = Add(GetTupleElement(body_parameter, i),
+ ConstantR1<int32>(&builder, {1, 1}));
updates.push_back(add);
}
// Add bool parameter.
- updates.push_back(builder.GetTupleElement(body_parameter, kParamCount));
+ updates.push_back(GetTupleElement(body_parameter, kParamCount));
- builder.Tuple(updates);
+ Tuple(&builder, updates);
body = builder.Build().ConsumeValueOrDie();
}
- auto loop = builder.While(condition, body, init);
+ auto loop = While(condition, body, init);
std::vector<XlaOp> outputs;
for (int i = 0; i < kParamCount; ++i) {
- outputs.push_back(builder.GetTupleElement(loop, i));
+ outputs.push_back(GetTupleElement(loop, i));
}
- builder.Tuple(outputs);
+ Tuple(&builder, outputs);
std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
@@ -420,10 +423,10 @@ XLA_TEST_F(ParamsTest,
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
- elements.push_back(Literal::CreateR1<int32>({i, i}));
+ elements.push_back(LiteralUtil::CreateR1<int32>({i, i}));
ptrs.push_back(elements.back().get());
}
- ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
+ ComputeAndCompareTuple(&builder, *LiteralUtil::MakeTuple(ptrs), param_data);
}
#endif
@@ -433,16 +436,16 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3});
- auto input = builder.Parameter(0, tuple_shape, "input");
- auto lhs = builder.GetTupleElement(input, 0);
- auto rhs = builder.GetTupleElement(input, 1);
- builder.Add(lhs, rhs);
+ auto input = Parameter(&builder, 0, tuple_shape, "input");
+ auto lhs = GetTupleElement(input, 0);
+ auto rhs = GetTupleElement(input, 1);
+ Add(lhs, rhs);
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*Literal::MakeTuple({
- Literal::CreateR1<float>({1, 2, 3}).get(),
- Literal::CreateR1<float>({4, 5, 6}).get(),
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
}))
.ConsumeValueOrDie();
@@ -454,10 +457,10 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
// when (transferred to the server and) passed through a parameter.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
XlaBuilder builder(TestName());
- builder.Parameter(0, literal->shape(), "input");
+ Parameter(&builder, 0, literal->shape(), "input");
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -466,10 +469,10 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
// As above, but for {1, 0} layout.
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
XlaBuilder builder(TestName());
- builder.Parameter(0, literal->shape(), "input");
+ Parameter(&builder, 0, literal->shape(), "input");
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -477,8 +480,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
}
XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
- std::unique_ptr<Literal> literal = Literal::CreateR2<float>({
- {1, 3}, {2, 4},
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 3},
+ {2, 4},
});
const Shape original = literal->shape();
{
@@ -494,9 +498,9 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
}
// Use the original shape in building the computation.
XlaBuilder builder(TestName());
- auto input = builder.Parameter(0, original, "input");
+ auto input = Parameter(&builder, 0, original, "input");
// Use the slice operator to get an off-diagonal element.
- builder.Slice(input, {0, 1}, {1, 2}, {1, 1});
+ Slice(input, {0, 1}, {1, 2}, {1, 1});
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 77159efb26..5c351b2d11 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -29,64 +29,63 @@ namespace {
class PredTest : public ClientLibraryTestBase {
protected:
- void TestCompare(
- bool lhs, bool rhs, bool expected,
- XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ void TestCompare(bool lhs, bool rhs, bool expected,
+ std::function<XlaOp(const xla::XlaOp&, const xla::XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
- XlaOp lhs_op = builder.ConstantR0<bool>(lhs);
- XlaOp rhs_op = builder.ConstantR0<bool>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ XlaOp lhs_op = ConstantR0<bool>(&builder, lhs);
+ XlaOp rhs_op = ConstantR0<bool>(&builder, rhs);
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
};
TEST_F(PredTest, ConstantR0PredTrue) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR0<bool>(true);
+ ConstantR0<bool>(&builder, true);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, ConstantR0PredFalse) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR0<bool>(false);
+ ConstantR0<bool>(&builder, false);
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, ConstantR0PredCompareEq) {
- TestCompare(true, false, false, &XlaBuilder::Eq);
+ TestCompare(true, false, false, &Eq);
}
TEST_F(PredTest, ConstantR0PredCompareNe) {
- TestCompare(true, false, true, &XlaBuilder::Ne);
+ TestCompare(true, false, true, &Ne);
}
TEST_F(PredTest, ConstantR0PredCompareLe) {
- TestCompare(true, false, false, &XlaBuilder::Le);
+ TestCompare(true, false, false, &Le);
}
TEST_F(PredTest, ConstantR0PredCompareLt) {
- TestCompare(true, false, false, &XlaBuilder::Lt);
+ TestCompare(true, false, false, &Lt);
}
TEST_F(PredTest, ConstantR0PredCompareGe) {
- TestCompare(true, false, true, &XlaBuilder::Ge);
+ TestCompare(true, false, true, &Ge);
}
TEST_F(PredTest, ConstantR0PredCompareGt) {
- TestCompare(true, false, true, &XlaBuilder::Gt);
+ TestCompare(true, false, true, &Gt);
}
TEST_F(PredTest, ConstantR1Pred) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({true, false, false, true});
+ ConstantR1<bool>(&builder, {true, false, false, true});
ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
TEST_F(PredTest, ConstantR2Pred) {
XlaBuilder builder(TestName());
- auto a =
- builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
+ ConstantR2<bool>(&builder, {{false, true, true}, {true, false, false}});
const string expected = R"(pred[2,3] {
{ 011 },
{ 100 }
@@ -96,44 +95,44 @@ TEST_F(PredTest, ConstantR2Pred) {
TEST_F(PredTest, AnyR1True) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({true, false});
- TF_ASSERT_OK(Any(a, &builder).status());
+ auto a = ConstantR1<bool>(&builder, {true, false});
+ Any(a);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, AnyR1False) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({false, false});
- TF_ASSERT_OK(Any(a, &builder).status());
+ auto a = ConstantR1<bool>(&builder, {false, false});
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, AnyR1VacuouslyFalse) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR1<bool>({});
- TF_ASSERT_OK(Any(a, &builder).status());
+ auto a = ConstantR1<bool>(&builder, {});
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, AnyR2True) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<bool>({
- {false, false, false},
- {false, false, false},
- {false, false, true},
- });
- TF_ASSERT_OK(Any(a, &builder).status());
+ auto a = ConstantR2<bool>(&builder, {
+ {false, false, false},
+ {false, false, false},
+ {false, false, true},
+ });
+ Any(a);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, AnyR2False) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<bool>({
- {false, false, false},
- {false, false, false},
- {false, false, false},
- });
- TF_ASSERT_OK(Any(a, &builder).status());
+ auto a = ConstantR2<bool>(&builder, {
+ {false, false, false},
+ {false, false, false},
+ {false, false, false},
+ });
+ Any(a);
ComputeAndCompareR0<bool>(&builder, false, {});
}
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 1a2de6937c..5ebf8344d2 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -53,8 +53,8 @@ template <typename T>
std::unique_ptr<Literal> PrngTest::UniformTest(
T a, T b, tensorflow::gtl::ArraySlice<int64> dims, int64 seed) {
XlaBuilder builder(TestName());
- builder.RngUniform(
- builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
+ RngUniform(
+ ConstantR0<T>(&builder, a), ConstantR0<T>(&builder, b),
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
SetSeed(seed);
@@ -141,9 +141,9 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
int32 sample_size = range_size * expected_count;
XlaBuilder builder(TestName());
- builder.RngUniform(builder.ConstantR0<int32>(0),
- builder.ConstantR0<int32>(range_size),
- ShapeUtil::MakeShape(S32, {sample_size}));
+ RngUniform(ConstantR0<int32>(&builder, 0),
+ ConstantR0<int32>(&builder, range_size),
+ ShapeUtil::MakeShape(S32, {sample_size}));
SetSeed(seed);
auto actual =
@@ -184,21 +184,22 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
// Build a x -> (x + U[0,1)) computation.
auto build_sum_rng = [this](XlaBuilder& builder) {
auto b = builder.CreateSubBuilder("sum_with_rng");
- auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input");
- b->Add(x, b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
- ShapeUtil::MakeShape(F32, {})));
+ auto x = Parameter(b.get(), 0, ShapeUtil::MakeShape(F32, {}), "input");
+ Add(x,
+ RngUniform(ConstantR0<float>(b.get(), 0), ConstantR0<float>(b.get(), 1),
+ ShapeUtil::MakeShape(F32, {})));
return b->BuildAndNoteError();
};
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
+ LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
client_->TransferToServer(*param0_literal));
- auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param0 = Parameter(&builder, 0, param0_literal->shape(), "param0");
auto fn = build_sum_rng(builder);
- builder.Map({param0}, fn, {0});
+ Map(&builder, {param0}, fn, {0});
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
@@ -226,9 +227,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
// Build a U[0,1) computation.
auto build_computation = [this]() {
XlaBuilder builder(TestName());
- builder.RngUniform(builder.ConstantR0<float>(0),
- builder.ConstantR0<float>(1),
- ShapeUtil::MakeShape(F32, {10}));
+ RngUniform(ConstantR0<float>(&builder, 0), ConstantR0<float>(&builder, 1),
+ ShapeUtil::MakeShape(F32, {10}));
return builder.Build();
};
@@ -282,8 +282,8 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
XLA_TEST_F(PrngTest, TenValuesN01) {
XlaBuilder builder(TestName());
- builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
- ShapeUtil::MakeShape(F32, {10}));
+ RngNormal(ConstantR0<float>(&builder, 0), ConstantR0<float>(&builder, 1),
+ ShapeUtil::MakeShape(F32, {10}));
SetSeed(42);
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
@@ -294,9 +294,9 @@ XLA_TEST_F(PrngTest, RngUniformCrash) {
XlaBuilder builder(TestName());
// This used to crash XLA during LLVM IR generation for CPUs.
- auto rng_uniform = builder.RngUniform(builder.ConstantR0<int32>(0),
- builder.ConstantR0<int32>(1000 * 1000),
- ShapeUtil::MakeShape(S32, {}));
+ RngUniform(ConstantR0<int32>(&builder, 0),
+ ConstantR0<int32>(&builder, 1000 * 1000),
+ ShapeUtil::MakeShape(S32, {}));
SetSeed(0);
ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
}
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
index f95e756483..526a38e8d1 100644
--- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -31,8 +31,8 @@ class QueryInferredShapeTest : public ClientLibraryTestBase {};
TEST_F(QueryInferredShapeTest, OnePlusOneShape) {
XlaBuilder builder("one_plus_one");
- auto one = builder.ConstantR0<float>(1.0);
- auto result = builder.Add(one, one);
+ auto one = ConstantR0<float>(&builder, 1.0);
+ auto result = Add(one, one);
StatusOr<Shape> shape_status = builder.GetShape(result);
ASSERT_IS_OK(shape_status.status());
auto shape = shape_status.ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index c0a2c0ca4c..a080dd1732 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <array>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -73,7 +73,7 @@ ENTRY reduce.1 {
}
)";
- return tools::Parse(hlo_string);
+ return ParseHloString(hlo_string);
}
// TODO(b/72454718): XLA:GPU does not support executing code compiled without
@@ -95,21 +95,21 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) {
*reduce_input_shape->mutable_layout() =
LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major);
- std::unique_ptr<Literal> reduce_input =
- Literal::CreateR4<float>({{ /*i0=0*/
- {/*i1=0*/
- {-0.246092796, -0.179497838, -0.161181688},
- {-0.151643038, -0.240213156, -0.198156}},
- {/*i1=1*/
- {-0.14222312, -0.162200093, -0.193907976},
- {-0.239411, -0.198166847, -0.172471642}}},
- { /*i0=1*/
- {/*i1=0*/
- {-0.22965157, -0.218723893, -0.129257083},
- {-0.188762426, -0.16123569, -0.181166649}},
- {/*i1=1*/
- {-0.241772294, -0.245131493, -0.160247207},
- {-0.179881215, -0.23383224, -0.121976733}}}});
+ std::unique_ptr<Literal> reduce_input = LiteralUtil::CreateR4<float>(
+ {{ /*i0=0*/
+ {/*i1=0*/
+ {-0.246092796, -0.179497838, -0.161181688},
+ {-0.151643038, -0.240213156, -0.198156}},
+ {/*i1=1*/
+ {-0.14222312, -0.162200093, -0.193907976},
+ {-0.239411, -0.198166847, -0.172471642}}},
+ { /*i0=1*/
+ {/*i1=0*/
+ {-0.22965157, -0.218723893, -0.129257083},
+ {-0.188762426, -0.16123569, -0.181166649}},
+ {/*i1=1*/
+ {-0.241772294, -0.245131493, -0.160247207},
+ {-0.179881215, -0.23383224, -0.121976733}}}});
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5)));
}
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index b311785449..04c7f31646 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -230,12 +230,13 @@ XLA_TEST_P(ReducePrecisionAccuracyTest, ReducePrecisionF32) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({input_values});
+ std::unique_ptr<Literal> a_literal =
+ LiteralUtil::CreateR1<float>({input_values});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
- builder.ReducePrecision(a, exponent_bits, mantissa_bits);
+ ReducePrecision(a, exponent_bits, mantissa_bits);
ComputeAndCompareR1<float>(&builder, expected_values, {a_data.get()});
}
@@ -253,18 +254,18 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
// Abs doesn't affect resolution.
- auto abs = builder.Abs(a);
+ auto abs = Abs(a);
// Near 1.0, Log(x) approximates x - 1; this lets us confirm that the
// reduce-precision operation showed up in the correct place in the
// graph.
- builder.Log(abs);
+ Log(abs);
// Insert precision-reduction after the Abs(x) operation, rounding that
// result to exactly 1.0f.
@@ -282,14 +283,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
// These two operations should be fused by any reasonable backend.
- auto abs = builder.Abs(a);
- builder.Neg(abs);
+ auto abs = Abs(a);
+ Neg(abs);
// Add a pass after operation fusion, suffixing kAbs operations. This
// should not see into the fusion nodes and thus should not affect the
@@ -308,14 +309,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
// These two operations should be fused by any reasonable backend.
- auto abs = builder.Abs(a);
- builder.Neg(abs);
+ auto abs = Abs(a);
+ Neg(abs);
// Add a pass after operation fusion, suffixing kFusion operations.
auto reduce_precision_pass = execution_options_.mutable_debug_options()
@@ -332,14 +333,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
// These two operations should be fused by any reasonable backend.
- auto abs = builder.Abs(a);
- builder.Neg(abs);
+ auto abs = Abs(a);
+ Neg(abs);
// Add a pass suffixing fusion nodes containing kCos operations. This
// should have no effect.
@@ -357,14 +358,14 @@ XLA_TEST_F(ReducePrecisionInsertionTest,
DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR1<float>({1.00001});
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({1.00001});
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
- auto a = builder.Parameter(0, a_literal->shape(), "a");
+ auto a = Parameter(&builder, 0, a_literal->shape(), "a");
// These two operations should be fused by any reasonable backend.
- auto abs = builder.Abs(a);
- builder.Neg(abs);
+ auto abs = Abs(a);
+ Neg(abs);
// Add a pass suffixing fusion nodes containing kAbs operations. This
// should see the kAbs operation within the above fusion node.
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index d671d40456..1407fca72f 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -67,12 +67,12 @@ class ReduceTest : public ClientLibraryTestBase {
ReduceTest() {
// Implementation note: laid out z >> y >> x by default.
// clang-format off
- literal_2d_ = Literal::CreateR2<float>({
+ literal_2d_ = LiteralUtil::CreateR2<float>({
// x0 x1 x2
{ 1.f, 2.f, 3.f}, // y0
{ 4.f, 5.f, 6.f}, // y1
});
- literal_3d_ = Literal::CreateR3Projected<float>({
+ literal_3d_ = LiteralUtil::CreateR3Projected<float>({
// x0 x1 x2
{ 1.f, 2.f, 3.f}, // y0
{ 4.f, 5.f, 6.f}, // y1
@@ -89,9 +89,9 @@ class ReduceTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
std::vector<float> input_data(element_count);
for (int64 i = 0; i < element_count; ++i) {
@@ -101,7 +101,7 @@ class ReduceTest : public ClientLibraryTestBase {
}
}
std::unique_ptr<Literal> input_literal =
- Literal::CreateR1(AsSlice(input_data));
+ LiteralUtil::CreateR1(AsSlice(input_data));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -118,22 +118,22 @@ class ReduceTest : public ClientLibraryTestBase {
const int element_count = input_data.size();
XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
- auto input_par = builder.Parameter(0, input_shape, "input");
+ auto input_par = Parameter(&builder, 0, input_shape, "input");
auto pred_values =
- builder.Eq(input_par, builder.ConstantR1<int>(element_count, 1));
+ Eq(input_par, ConstantR1<int>(&builder, element_count, 1));
XlaOp init_value;
XlaComputation reduce;
if (and_reduce) {
- init_value = builder.ConstantR0<bool>(true);
+ init_value = ConstantR0<bool>(&builder, true);
reduce = CreateScalarAndComputation(&builder);
} else {
- init_value = builder.ConstantR0<bool>(false);
+ init_value = ConstantR0<bool>(&builder, false);
reduce = CreateScalarOrComputation(&builder);
}
- builder.Reduce(pred_values, init_value, reduce,
- /*dimensions_to_reduce=*/{0});
+ Reduce(pred_values, init_value, reduce,
+ /*dimensions_to_reduce=*/{0});
- std::unique_ptr<Literal> input_literal = Literal::CreateR1(input_data);
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR1(input_data);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -156,26 +156,26 @@ class ReduceTest : public ClientLibraryTestBase {
int64 major = 0) {
XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto input_pred = builder.Eq(input, builder.ConstantR0<uint8>(1));
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto input_pred = Eq(input, ConstantR0<uint8>(&builder, 1));
XlaOp init_value;
XlaComputation reduce_op;
if (and_reduce) {
- init_value = builder.ConstantR0<bool>(true);
+ init_value = ConstantR0<bool>(&builder, true);
reduce_op = CreateScalarAndComputation(&builder);
} else {
- init_value = builder.ConstantR0<bool>(false);
+ init_value = ConstantR0<bool>(&builder, false);
reduce_op = CreateScalarOrComputation(&builder);
}
- builder.Reduce(input_pred, init_value, reduce_op,
- /*dimensions_to_reduce=*/{0});
+ Reduce(input_pred, init_value, reduce_op,
+ /*dimensions_to_reduce=*/{0});
Array2D<uint8> input_data(rows, cols);
input_data.FillRandom(0, 1);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -202,14 +202,14 @@ class ReduceTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -230,14 +230,14 @@ class ReduceTest : public ClientLibraryTestBase {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -287,15 +287,15 @@ class ReduceTest : public ClientLibraryTestBase {
XlaComputation reduction_function = reduction_function_generator(&builder);
const Shape input_shape = ShapeUtil::MakeShape(
xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<NativeT>(initial_value);
- builder.Reduce(input, zero, reduction_function,
- /*dimensions_to_reduce=*/{0});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<NativeT>(&builder, initial_value);
+ Reduce(input, zero, reduction_function,
+ /*dimensions_to_reduce=*/{0});
Array2D<NativeT> input_data(rows, cols);
input_data.FillUnique(initial_value);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout({minor, major}));
std::unique_ptr<GlobalData> input_global_data =
@@ -442,15 +442,15 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- auto log_ = builder.Log(input);
- builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto log_ = Log(input);
+ Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0});
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -473,16 +473,16 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- auto log_ = builder.Log(input);
- auto transpose = builder.Transpose(log_, {1, 0});
- builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto log_ = Log(input);
+ auto transpose = Transpose(log_, {1, 0});
+ Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2D(input_data);
+ LiteralUtil::CreateR2FromArray2D(input_data);
input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -505,10 +505,10 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50});
- XlaOp input = builder.Parameter(0, input_shape, "input");
- XlaOp zero = builder.ConstantR0<float>(0.0);
- XlaOp transpose = builder.Transpose(input, /*permutation=*/{1, 0, 2});
- builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
+ XlaOp input = Parameter(&builder, 0, input_shape, "input");
+ XlaOp zero = ConstantR0<float>(&builder, 0.0);
+ XlaOp transpose = Transpose(input, /*permutation=*/{1, 0, 2});
+ Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> input_data,
MakeFakeLiteral(input_shape));
@@ -522,16 +522,16 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
XlaBuilder builder(TestName());
XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto zero = builder.ConstantR0<float>(0.0);
- auto log_ = builder.Tanh(input);
- auto reshape = builder.Reshape(log_, {rows, cols});
- builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto log_ = Tanh(input);
+ auto reshape = Reshape(log_, {rows, cols});
+ Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR3FromArray3D(input_data);
+ LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
@@ -568,9 +568,9 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) {
XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
XlaBuilder builder(TestName());
auto add = CreateScalarAddComputation(F32, &builder);
- auto scalar = builder.ConstantR0<float>(42.0);
- auto broadcasted = builder.Broadcast(scalar, {500, 500});
- builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), add, {0, 1});
+ auto scalar = ConstantR0<float>(&builder, 42.0);
+ auto broadcasted = Broadcast(scalar, {500, 500});
+ Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
float expected = 42.0f * static_cast<float>(500 * 500);
ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -580,9 +580,9 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
XlaBuilder builder(TestName());
auto max = CreateScalarMaxComputation(F32, &builder);
- auto scalar = builder.ConstantR0<float>(42.0);
- auto broadcasted = builder.Broadcast(scalar, {500, 500});
- builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), max, {0, 1});
+ auto scalar = ConstantR0<float>(&builder, 42.0);
+ auto broadcasted = Broadcast(scalar, {500, 500});
+ Reduce(broadcasted, ConstantR0<float>(&builder, 0.0f), max, {0, 1});
float expected = 42.0f;
ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -594,9 +594,9 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
auto max = CreateScalarMaxComputation(F32, &builder);
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input);
- builder.Reduce(builder.ConstantLiteral(*input_literal),
- builder.ConstantR0<float>(FLT_MIN), max, {0, 1});
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ Reduce(ConstantLiteral(&builder, *input_literal),
+ ConstantR0<float>(&builder, FLT_MIN), max, {0, 1});
auto input_max = FLT_MIN;
input.Each(
[&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
@@ -609,9 +609,9 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
auto min = CreateScalarMinComputation(F32, &builder);
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input);
- builder.Reduce(builder.ConstantLiteral(*input_literal),
- builder.ConstantR0<float>(FLT_MAX), min, {0, 1});
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ Reduce(ConstantLiteral(&builder, *input_literal),
+ ConstantR0<float>(&builder, FLT_MAX), min, {0, 1});
auto input_min = FLT_MAX;
input.Each(
@@ -623,12 +623,11 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto min = CreateScalarMinComputation(U32, &builder);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
auto initial_value =
- builder.ConstantR0<uint32>(std::numeric_limits<uint32>::max());
+ ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::max());
- builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, min,
- {0, 1});
+ Reduce(ConstantLiteral(&builder, *input_literal), initial_value, min, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
@@ -636,21 +635,20 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto max = CreateScalarMaxComputation(U32, &builder);
- auto input_literal = Literal::CreateR2FromArray2D(input);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
auto initial_value =
- builder.ConstantR0<uint32>(std::numeric_limits<uint32>::min());
+ ConstantR0<uint32>(&builder, std::numeric_limits<uint32>::min());
- builder.Reduce(builder.ConstantLiteral(*input_literal), initial_value, max,
- {0, 1});
+ Reduce(ConstantLiteral(&builder, *input_literal), initial_value, max, {0, 1});
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_2d_);
+ auto m = ConstantLiteral(&builder, *literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
std::vector<float> expected = {6.f, 15.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -659,9 +657,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_2d_);
+ auto m = ConstantLiteral(&builder, *literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
}
@@ -669,9 +667,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XlaBuilder builder("reduce_among_y");
- auto m = builder.ConstantLiteral(*literal_2d_);
+ auto m = ConstantLiteral(&builder, *literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
std::vector<float> expected = {5.f, 7.f, 9.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -679,9 +677,9 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1, 2});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1, 2});
std::vector<float> expected = {21.f, 21.f, 21.f, 21.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -689,9 +687,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1});
std::vector<float> expected = {20.f, 28.f, 36.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -699,9 +697,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1, 2});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0, 1, 2});
float expected = 21.0f * 4.0;
ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
@@ -709,9 +707,9 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {0});
// clang-format off
Array2D<float> expected({
@@ -724,9 +722,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {1});
// clang-format off
Array2D<float> expected({
@@ -741,9 +739,9 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
XlaBuilder builder(TestName());
- auto m = builder.ConstantLiteral(*literal_3d_);
+ auto m = ConstantLiteral(&builder, *literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
- builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {2});
+ Reduce(m, ConstantR0<float>(&builder, 0.0f), add, {2});
// clang-format off
Array2D<float> expected({
@@ -820,17 +818,17 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
// input_array.FillRandom(3.14f, 0.05);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR3FromArray3D(input_array);
+ auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
input_literal =
input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout));
std::unique_ptr<GlobalData> input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
auto input_activations =
- builder.Parameter(0, input_literal->shape(), "input");
+ Parameter(&builder, 0, input_literal->shape(), "input");
XlaComputation add = CreateScalarAddComputation(F32, &builder);
- auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f),
- add, GetParam().reduce_dims);
+ Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
+ GetParam().reduce_dims);
auto expected =
ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
@@ -871,14 +869,15 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
XlaBuilder builder(TestName());
XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
- auto a = builder.ConstantR0<float>(2.0f);
- auto a2 = builder.Abs(a);
+ auto a = ConstantR0<float>(&builder, 2.0f);
+ auto a2 = Abs(a);
- std::unique_ptr<Literal> b_literal = Literal::CreateR1<float>({1.0f, 4.0f});
+ std::unique_ptr<Literal> b_literal =
+ LiteralUtil::CreateR1<float>({1.0f, 4.0f});
std::unique_ptr<GlobalData> b_data =
client_->TransferToServer(*b_literal).ConsumeValueOrDie();
- auto b = builder.Parameter(0, b_literal->shape(), "b");
- auto max = builder.Reduce(b, a2, max_f32, {0});
+ auto b = Parameter(&builder, 0, b_literal->shape(), "b");
+ Reduce(b, a2, max_f32, {0});
ComputeAndCompareR0<float>(&builder, 4.0f, {b_data.get()});
}
@@ -900,13 +899,13 @@ class ReduceInitializerTest : public ReduceTest {
XlaComputation max_fn = CreateScalarMaxComputation(
primitive_util::NativeToPrimitiveType<T>(), &builder);
- auto init = builder.ConstantR0<T>(initializer);
+ auto init = ConstantR0<T>(&builder, initializer);
std::vector<T> input_arr(num_elems, std::numeric_limits<T>::lowest());
- auto input_literal = Literal::CreateR1<T>(input_arr);
+ auto input_literal = LiteralUtil::CreateR1<T>(input_arr);
auto input_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- builder.Reduce(builder.Parameter(0, input_literal->shape(), "input"), init,
- max_fn, {0});
+ Reduce(Parameter(&builder, 0, input_literal->shape(), "input"), init,
+ max_fn, {0});
ComputeAndCompareR0<T>(&builder, initializer, {input_data.get()});
}
@@ -939,23 +938,24 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
XLA_TEST_F(ReduceTest, ReduceIdentity) {
XlaBuilder builder(TestName());
Shape single_float = ShapeUtil::MakeShape(F32, {});
- builder.Parameter(0, single_float, "lhs-unused");
- builder.Parameter(1, single_float, "rhs-used");
+ Parameter(&builder, 0, single_float, "lhs-unused");
+ Parameter(&builder, 1, single_float, "rhs-used");
auto computation_status = builder.Build();
TF_ASSERT_OK(computation_status.status());
Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
- builder.Reduce(builder.Parameter(0, operand_shape, "operand"),
- builder.Parameter(1, single_float, "init"),
- computation_status.ValueOrDie(), {0});
+ Reduce(Parameter(&builder, 0, operand_shape, "operand"),
+ Parameter(&builder, 1, single_float, "init"),
+ computation_status.ValueOrDie(), {0});
float operand[] = {42.0f};
float init = 58.5f;
float expected = 42.0f;
- std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(operand);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR1<float>(operand);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
- std::unique_ptr<Literal> input_literal2 = Literal::CreateR0<float>(init);
+ std::unique_ptr<Literal> input_literal2 = LiteralUtil::CreateR0<float>(init);
std::unique_ptr<GlobalData> input_global_data2 =
client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
ComputeAndCompareR0<float>(
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 266760e820..c2681f70f7 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -70,31 +70,33 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_);
- builder_.ReduceWindow(input, init,
- CreateScalarAddComputation(FloatType(), &builder_),
- window_dimensions, window_strides, padding);
+ auto init = CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f),
+ &builder_);
+ ReduceWindow(input, init,
+ CreateScalarAddComputation(FloatType(), &builder_),
+ window_dimensions, window_strides, padding);
}
void ReduceWindowMax(const XlaOp& input,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_);
- builder_.ReduceWindow(input, init,
- CreateScalarMaxComputation(FloatType(), &builder_),
- window_dimensions, window_strides, padding);
+ auto init =
+ CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
+ ReduceWindow(input, init,
+ CreateScalarMaxComputation(FloatType(), &builder_),
+ window_dimensions, window_strides, padding);
}
void ReduceWindowMin(const XlaOp& input,
tensorflow::gtl::ArraySlice<int64> window_dimensions,
tensorflow::gtl::ArraySlice<int64> window_strides,
Padding padding) {
- auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_);
- builder_.ReduceWindow(input, init,
- CreateScalarMinComputation(FloatType(), &builder_),
- window_dimensions, window_strides, padding);
+ auto init =
+ CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
+ ReduceWindow(input, init,
+ CreateScalarMinComputation(FloatType(), &builder_),
+ window_dimensions, window_strides, padding);
}
XlaBuilder builder_;
@@ -102,14 +104,14 @@ class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({1, 1, 1, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
const auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0), &builder_);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0), &builder_);
TF_ASSERT_OK(builder_.first_error());
- builder_.ReduceWindow(input, init_value,
- CreateScalarAddComputation(FloatType(), &builder_),
- /*window_dimensions=*/{1, 2},
- /*window_strides=*/{1}, Padding::kValid);
+ ReduceWindow(input, init_value,
+ CreateScalarAddComputation(FloatType(), &builder_),
+ /*window_dimensions=*/{1, 2},
+ /*window_strides=*/{1}, Padding::kValid);
ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT)
<< builder_.first_error();
ASSERT_THAT(builder_.first_error().error_message(),
@@ -119,33 +121,32 @@ TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
// Regression test for b/68964348.
TEST_P(ReduceWindowTest, R0ReduceWindow) {
const auto input =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(42.0), &builder_);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(42.0), &builder_);
const auto init =
- CreateConstantFromLiteral(*Literal::CreateR0<float>(1.0), &builder_);
- builder_.ReduceWindow(input, init,
- CreateScalarAddComputation(FloatType(), &builder_),
- /*window_dimensions=*/{},
- /*window_strides=*/{}, Padding::kSame);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR0<float>(43.0), {},
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(1.0), &builder_);
+ ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
+ /*window_dimensions=*/{},
+ /*window_strides=*/{}, Padding::kSame);
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR0<float>(43.0), {},
ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride2) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, {3}, {2}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({100, 1}), {},
- ErrorSpec(0.00001));
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({100, 1}),
+ {}, ErrorSpec(0.00001));
}
TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
+ *LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1},
Padding::kSame);
ComputeAndCompareLiteral(&builder_,
- *Literal::CreateR1<float>({1000, 100, 10, 1, 1}), {},
- ErrorSpec(0.00001));
+ *LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1}),
+ {}, ErrorSpec(0.00001));
}
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
@@ -157,7 +158,7 @@ XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -172,7 +173,7 @@ TEST_P(ReduceWindowTest, NonSquareSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -186,7 +187,7 @@ TEST_P(ReduceWindowTest, MiddleDimsSmall) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
{1, 2, 2, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -203,7 +204,7 @@ TEST_P(ReduceWindowTest, Along2ndMinorDim) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res), {},
DefaultErrorSpec());
}
@@ -225,8 +226,8 @@ TEST_P(ReduceWindowTest, AmongMajor2Dims) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
@@ -248,8 +249,8 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
// Tests the super windowing logic w.r.t handling prime number of windows in a
@@ -273,8 +274,8 @@ TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
@@ -290,8 +291,8 @@ TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, 1, 11}, {1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
// Tests a reduction function that is not a simple add/min/max/etc.
@@ -306,15 +307,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
Padding padding = Padding::kValid;
const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
auto b = builder_.CreateSubBuilder("unusual");
- auto lhs = b->Parameter(0, scalar, "lhs");
- auto rhs = b->Parameter(1, scalar, "rhs");
- b->Min(b->Add(lhs, rhs),
- CreateConstantFromLiteral(*Literal::CreateR0<float>(8.0f), b.get()));
+ auto lhs = Parameter(b.get(), 0, scalar, "lhs");
+ auto rhs = Parameter(b.get(), 1, scalar, "rhs");
+ Min(Add(lhs, rhs),
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(8.0f), b.get()));
XlaComputation reduce_fn = b->BuildAndNoteError();
- builder_.ReduceWindow(
+ ReduceWindow(
input,
- CreateConstantFromLiteral(*Literal::CreateR0<float>(0.0f), &builder_),
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0<float>(0.0f), &builder_),
reduce_fn,
/*window_dimensions=*/{1, 1, 2, 1},
/*window_strides=*/{1, 1, 1, 1}, padding);
@@ -328,15 +329,15 @@ XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
/*window=*/{1, 1, 2, 1},
/*stride=*/{1, 1, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*expected),
+ {}, DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R4UnitWindow) {
Array4D<float> input_array(13, 12, 8, 15);
input_array.FillRandom(2.f, 2.f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -348,7 +349,7 @@ TEST_P(ReduceWindowTest, R4UnitWindow) {
auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1},
{1, 4, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -377,7 +378,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
auto shape = ShapeUtil::MakeShape(F32, input_dims);
std::unique_ptr<Literal> arg_literal =
- Literal::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
@@ -386,7 +387,7 @@ XLA_TEST_P(ReduceWindowTest, R6Add) {
std::vector<int64> output_dims = {8, 8, 6, 6, 8, 8};
std::unique_ptr<Literal> expected =
- Literal::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
+ LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
}
@@ -395,7 +396,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
Array4D<float> input_array(2, 1, 27, 119);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -409,7 +410,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -417,7 +418,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
Array4D<float> input_array(3, 2, 4, 64);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -431,7 +432,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -439,7 +440,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
Array4D<float> input_array(1, 3, 12, 200);
input_array.FillRandom(2.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp input;
auto input_data = CreateParameterAndTransferLiteral(
@@ -453,7 +454,7 @@ XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
auto res = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res),
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*res),
{input_data.get()}, DefaultErrorSpec());
}
@@ -474,18 +475,18 @@ TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
auto result = ReferenceUtil::ReduceWindow4DAdd(
input_array, 0.0f, {win_len, win_len, 1, 1},
{win_stride, win_stride, 1, 1}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {},
- DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateFromArray(*result),
+ {}, DefaultErrorSpec());
}
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
std::vector<float> input_vector(128 * 9, 1);
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {32}, {128}, Padding::kValid);
ComputeAndCompareLiteral(
&builder_,
- *Literal::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
+ *LiteralUtil::CreateR1<float>({32, 32, 32, 32, 32, 32, 32, 32, 32}), {},
DefaultErrorSpec());
}
@@ -500,9 +501,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {128}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -517,9 +518,9 @@ XLA_TEST_P(ReduceWindowTest, Add128In128) {
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
const auto input = CreateConstantFromLiteral(
- *Literal::CreateR1<float>(input_vector), &builder_);
+ *LiteralUtil::CreateR1<float>(input_vector), &builder_);
ReduceWindowAdd(input, {128}, {1}, Padding::kValid);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateR1<float>({1088}), {},
+ ComputeAndCompareLiteral(&builder_, *LiteralUtil::CreateR1<float>({1088}), {},
DefaultErrorSpec());
}
@@ -536,14 +537,15 @@ TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(
input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_,
+ *LiteralUtil::CreateFromArray<float>(*res), {},
+ DefaultErrorSpec());
}
TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
Array2D<float> input_array(6, 4, 1.0f);
- XlaOp input = builder_.Broadcast(
- CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4});
+ XlaOp input = Broadcast(
+ CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
Padding padding = Padding::kSame;
ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
@@ -551,8 +553,9 @@ TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3},
padding);
- ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray<float>(*res),
- {}, DefaultErrorSpec());
+ ComputeAndCompareLiteral(&builder_,
+ *LiteralUtil::CreateFromArray<float>(*res), {},
+ DefaultErrorSpec());
}
INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
@@ -610,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
param.base_bounds[2], param.base_bounds[3]);
input.FillIota(1);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
@@ -622,12 +625,12 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
}
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
CHECK(param.reducer == kAdd || param.reducer == kMax);
auto computation = param.reducer == kAdd
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
- b.ReduceWindowWithGeneralPadding(
+ ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
/*computation=*/computation,
@@ -648,7 +651,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*stride=*/param.strides,
/*padding=*/padding);
std::unique_ptr<Literal> expected_literal =
- Literal::CreateFromArray(*expected);
+ LiteralUtil::CreateFromArray(*expected);
const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
input_literal->shape().element_type(),
AsInt64Slice(expected_literal->shape().dimensions()), param.layout);
@@ -960,25 +963,25 @@ TEST_P(R3ReduceWindowTest, Add) {
Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
param.base_bounds[2], 1.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR3FromArray3DWithLayout(
+ LiteralUtil::CreateR3FromArray3DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
- b.ReduceWindow(/*operand=*/parameter,
- /*init_value=*/init_value,
- /*computation=*/CreateScalarAddComputation(FloatType(), &b),
- /*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/param.padding);
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ ReduceWindow(/*operand=*/parameter,
+ /*init_value=*/init_value,
+ /*computation=*/CreateScalarAddComputation(FloatType(), &b),
+ /*window_dimensions=*/param.window_bounds,
+ /*window_strides=*/param.strides, /*padding=*/param.padding);
auto expected = ReferenceUtil::ReduceWindow3DAdd(
/*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/param.padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
@@ -1094,7 +1097,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
const float kInitValue = 0.0f;
Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR2FromArray2DWithLayout(
+ LiteralUtil::CreateR2FromArray2DWithLayout(
input, LayoutUtil::MakeLayout(param.layout));
XlaOp parameter;
@@ -1108,8 +1111,8 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
- b.ReduceWindowWithGeneralPadding(
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
/*computation=*/computation,
@@ -1124,7 +1127,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*window=*/param.window_bounds,
/*stride=*/param.strides, /*padding=*/padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
};
@@ -1293,7 +1296,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
std::vector<float> input_vector(param.base_bounds[0]);
std::iota(std::begin(input_vector), std::end(input_vector), 0);
std::unique_ptr<Literal> input_literal =
- Literal::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
+ LiteralUtil::CreateR1(tensorflow::gtl::ArraySlice<float>(input_vector));
XlaOp parameter;
auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0",
&b, &parameter);
@@ -1305,8 +1308,8 @@ TEST_P(R1ReduceWindowTest, DoIt) {
? CreateScalarAddComputation(FloatType(), &b)
: CreateScalarMaxComputation(FloatType(), &b);
auto init_value =
- CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b);
- b.ReduceWindowWithGeneralPadding(
+ CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b);
+ ReduceWindowWithGeneralPadding(
/*operand=*/parameter,
/*init_value=*/init_value,
/*computation=*/computation,
@@ -1324,7 +1327,7 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*stride=*/param.strides,
/*padding=*/padding);
- ComputeAndCompareLiteral(&b, *Literal::CreateR1<float>(*expected),
+ ComputeAndCompareLiteral(&b, *LiteralUtil::CreateR1<float>(*expected),
{input_arg.get()}, DefaultErrorSpec());
}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index 36d763b0f7..d544968648 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -39,8 +39,8 @@ class ReplayTest : public ClientLibraryTestBase {};
TEST_F(ReplayTest, TwoPlusTwoReplay) {
// Make 2+2 computation.
XlaBuilder builder(TestName());
- auto two = builder.ConstantR0<int32>(2);
- builder.Add(two, two);
+ auto two = ConstantR0<int32>(&builder, 2);
+ Add(two, two);
XlaComputation computation = builder.Build().ConsumeValueOrDie();
// Serialize it out.
@@ -70,9 +70,9 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Make computation.
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "y");
+ Add(x, y);
XlaComputation computation = builder.Build().ConsumeValueOrDie();
// Serialize it out.
@@ -91,10 +91,10 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Run it.
std::unique_ptr<GlobalData> x_data =
- client_->TransferToServer(*Literal::CreateR0<int32>(2))
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> y_data =
- client_->TransferToServer(*Literal::CreateR0<int32>(3))
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
.ConsumeValueOrDie();
std::unique_ptr<Literal> literal =
client_
@@ -111,13 +111,13 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
// As above, but with map(+2) over some constant array.
XlaBuilder plus_two_builder("plus two");
auto input =
- plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
- plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
+ Parameter(&plus_two_builder, 0, ShapeUtil::MakeShape(S32, {}), "input");
+ Add(input, ConstantR0<int32>(&plus_two_builder, 2));
XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
XlaBuilder mapper_builder(TestName());
- auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
- mapper_builder.Map({original}, plus_two, {0});
+ auto original = ConstantR1<int32>(&mapper_builder, {1, 2, 3});
+ Map(&mapper_builder, {original}, plus_two, {0});
XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index da1b588ec4..7c0389cfa3 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -44,11 +44,11 @@ using ReshapeMotionTest = ClientLibraryTestBase;
TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) {
XlaBuilder builder(TestName());
- auto a = builder.ConstantR2<int32>({{2, 3, 5}, {7, 11, 13}});
- auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}});
- auto c = builder.Reshape(a, {6});
- auto d = builder.Reshape(b, {6});
- auto e = builder.Mul(c, d);
+ auto a = ConstantR2<int32>(&builder, {{2, 3, 5}, {7, 11, 13}});
+ auto b = ConstantR2<int32>(&builder, {{17, 19}, {23, 29}, {31, 37}});
+ auto c = Reshape(a, {6});
+ auto d = Reshape(b, {6});
+ Mul(c, d);
ComputeAndCompareR1<int32>(&builder, {34, 57, 115, 203, 341, 481}, {});
}
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index a4580cd71d..46d91711a5 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -55,39 +55,39 @@ XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) {
XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0});
- auto expected_literal = Literal::CreateR1<float>({1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -97,15 +97,15 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
XlaBuilder builder(TestName());
Array2D<float> input_array(1, 1);
input_array.Fill(1.0f);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter",
&builder, &parameter);
- auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
- /*new_sizes=*/{});
+ auto reshape = Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{});
auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<float>(1.0f);
+ auto expected_literal = LiteralUtil::CreateR0<float>(1.0f);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -113,63 +113,54 @@ XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) {
XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(1.0f);
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(1.0f);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
- auto a = builder.Neg(parameter);
- builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
+ auto a = Neg(parameter);
+ Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1});
- auto expected_literal = Literal::CreateR1<float>({-1.0f});
+ auto expected_literal = LiteralUtil::CreateR1<float>({-1.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) {
+XLA_TEST_P(ReshapeTest, Trivial0x3) {
XlaBuilder builder(TestName());
Array2D<float> input_array(0, 3);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-05-15
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) {
+XLA_TEST_P(ReshapeTest, Trivial0x3WithParameter) {
XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
- Literal::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(0, 3));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
+XLA_TEST_P(ReshapeTest, Trivial3x0) {
XlaBuilder builder(TestName());
Array2D<float> input_array(3, 0);
- auto input_literal = Literal::CreateR2FromArray2D(input_array);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input_array);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ auto expected_literal = LiteralUtil::CreateR1<float>({});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -177,12 +168,12 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) {
// Collapses a 2-dimensional row vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial1x3) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto input_literal = LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -190,30 +181,26 @@ XLA_TEST_P(ReshapeTest, Trivial1x3) {
// Collapses a 2-dimensional column vector to 1 dimension.
XLA_TEST_P(ReshapeTest, Trivial3x1) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
+ auto input_literal = LiteralUtil::CreateR2<float>({{1.0f}, {2.0f}, {3.0f}});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
- auto expected_literal = Literal::CreateR1<float>({1.0f, 2.0f, 3.0f});
+ Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1});
+ auto expected_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Splits an empty vector into an empty matrix.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) {
+XLA_TEST_P(ReshapeTest, R1ToR2_0_To_2x0) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateR1<float>({});
+ auto input_literal = LiteralUtil::CreateR1<float>({});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0},
- /*new_sizes=*/{2, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0},
+ /*new_sizes=*/{2, 0});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -222,32 +209,28 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) {
XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) {
XlaBuilder builder(TestName());
auto input_literal =
- Literal::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0},
- /*new_sizes=*/{2, 3});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0},
+ /*new_sizes=*/{2, 3});
auto expected_literal =
- Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Transposes a 2x0 array to a 0x2 array.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) {
+XLA_TEST_P(ReshapeTest, Reshape0x2To2x0) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 2));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 2));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
- /*new_sizes=*/{2, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 0});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -256,15 +239,15 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) {
XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
XlaBuilder builder(TestName());
auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
- auto input_literal = Literal::CreateFromArray(*simple);
+ auto input_literal = LiteralUtil::CreateFromArray(*simple);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
- /*new_sizes=*/{3, 1});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{3, 1});
auto expected = ReferenceUtil::TransposeArray2D(*simple);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -273,32 +256,28 @@ XLA_TEST_P(ReshapeTest, ReshapeRowToCol) {
XLA_TEST_P(ReshapeTest, TransposeAsReshape) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
- /*new_sizes=*/{3, 4});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 4});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Transposes a 0x4 array with XlaBuilder::Transpose.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) {
+XLA_TEST_P(ReshapeTest, Transpose0x4) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 4));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 4));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Transpose(parameter, {1, 0});
- auto expected_literal = Literal::CreateR2<float>({{}, {}, {}, {}});
+ Transpose(parameter, {1, 0});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{}, {}, {}, {}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -307,49 +286,43 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) {
XLA_TEST_P(ReshapeTest, Transpose4x3) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Transpose(parameter, {1, 0});
+ Transpose(parameter, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*a4x3);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
// Reshapes an empty 2-dimensional array with dimensions that are not just a
// rearrangement of the originals (split), but no reordering (no shuffle).
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(6, 0));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(6, 0));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
- /*new_sizes=*/{2, 3, 0, 0});
- auto expected_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 0, 0));
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 3, 0, 0});
+ auto expected_literal =
+ LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 0, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeR4ToR2ZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array4D<float>(2, 3, 4, 0));
+ auto input_literal = LiteralUtil::CreateFromArray(Array4D<float>(2, 3, 4, 0));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
- /*new_sizes=*/{24, 0});
- auto expected_literal = Literal::CreateFromArray(Array2D<float>(24, 0));
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{24, 0});
+ auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(24, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -359,32 +332,28 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) {
XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
- /*new_sizes=*/{2, 6});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 6});
auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
- auto expected_literal = Literal::CreateFromArray(*expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(*expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
-// TODO(b/29185393): Make this work with the GPU backend. The GPU backend
-// does not handle zero-sized shapes correctly. Failed last on 2017-11-30
-// with an incorrect result rank.
-//
-XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
+XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(Array2D<float>(0, 6));
+ auto input_literal = LiteralUtil::CreateFromArray(Array2D<float>(0, 6));
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
- /*new_sizes=*/{3, 0});
- auto expected_literal = Literal::CreateFromArray(Array2D<float>(3, 0));
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 0});
+ auto expected_literal = LiteralUtil::CreateFromArray(Array2D<float>(3, 0));
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -394,15 +363,15 @@ XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) {
XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) {
XlaBuilder builder(TestName());
auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
- auto input_literal = Literal::CreateFromArray(*a4x3);
+ auto input_literal = LiteralUtil::CreateFromArray(*a4x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
- /*new_sizes=*/{2, 6});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{2, 6});
Array2D<float> expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
{8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
- auto expected_literal = Literal::CreateFromArray(expected);
+ auto expected_literal = LiteralUtil::CreateFromArray(expected);
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -420,13 +389,13 @@ static Array3D<float> ArrayForDocR3Tests() {
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
- /*new_sizes=*/{24});
- auto expected_literal = Literal::CreateR1<float>(
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{24});
+ auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -435,33 +404,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) {
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
- /*new_sizes=*/{8, 3});
- auto expected_literal = Literal::CreateR2<float>({{10, 11, 12},
- {15, 16, 17},
- {20, 21, 22},
- {25, 26, 27},
- {30, 31, 32},
- {35, 36, 37},
- {40, 41, 42},
- {45, 46, 47}});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{8, 3});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{10, 11, 12},
+ {15, 16, 17},
+ {20, 21, 22},
+ {25, 26, 27},
+ {30, 31, 32},
+ {35, 36, 37},
+ {40, 41, 42},
+ {45, 46, 47}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
- /*new_sizes=*/{24});
- auto expected_literal = Literal::CreateR1<float>(
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{24});
+ auto expected_literal = LiteralUtil::CreateR1<float>(
{10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -470,33 +439,33 @@ XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) {
XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
- /*new_sizes=*/{8, 3});
- auto expected_literal = Literal::CreateR2<float>({{10, 20, 30},
- {40, 11, 21},
- {31, 41, 12},
- {22, 32, 42},
- {15, 25, 35},
- {45, 16, 26},
- {36, 46, 17},
- {27, 37, 47}});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{8, 3});
+ auto expected_literal = LiteralUtil::CreateR2<float>({{10, 20, 30},
+ {40, 11, 21},
+ {31, 41, 12},
+ {22, 32, 42},
+ {15, 25, 35},
+ {45, 16, 26},
+ {36, 46, 17},
+ {27, 37, 47}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
XlaBuilder builder(TestName());
- auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests());
+ auto input_literal = LiteralUtil::CreateFromArray(ArrayForDocR3Tests());
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
- /*new_sizes=*/{2, 6, 2});
- auto expected_literal = Literal::CreateR3<float>(
+ Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{2, 6, 2});
+ auto expected_literal = LiteralUtil::CreateR3<float>(
{{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
{{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
@@ -523,12 +492,12 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) {
Array4D<float> t2x2x2x3(2, 2, 2, 3);
auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
t2x2x2x3.FillWithYX(*filler2x3);
- auto input_literal = Literal::CreateFromArray(t2x2x2x3);
+ auto input_literal = LiteralUtil::CreateFromArray(t2x2x2x3);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
- auto expected_literal = Literal::CreateR2<float>(
+ Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3});
+ auto expected_literal = LiteralUtil::CreateR2<float>(
{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
6.0f}});
@@ -548,15 +517,15 @@ XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) {
t(1, 0, 0, 1) = 5;
t(1, 0, 1, 0) = 6;
t(1, 0, 1, 1) = 7;
- auto input_literal = Literal::CreateFromArray(t);
+ auto input_literal = LiteralUtil::CreateFromArray(t);
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
- /*new_sizes=*/{2, 4});
+ Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{2, 4});
auto expected_literal =
- Literal::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
+ LiteralUtil::CreateR2<float>({{0, 1, 2, 3}, {4, 5, 6, 7}});
ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -575,9 +544,9 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, input_literal, "input",
&b, &parameter);
- b.Reshape(parameter, dimensions, {});
+ Reshape(parameter, dimensions, {});
- auto expected_literal = Literal::CreateR0<float>(83.0f);
+ auto expected_literal = LiteralUtil::CreateR0<float>(83.0f);
ComputeAndCompareLiteral(&b, *expected_literal, {input.get()},
zero_error_spec_);
}
@@ -585,11 +554,11 @@ XLA_TEST_P(ReshapeTest, ToScalar) {
XLA_TEST_P(ReshapeTest, BadDimensions) {
XlaBuilder b(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
- b.Reshape(parameter, {}, {});
+ Reshape(parameter, {}, {});
EXPECT_THAT(
ExecuteToString(&b, {}),
::testing::HasSubstr("not a permutation of the operand dimensions"));
@@ -597,11 +566,11 @@ XLA_TEST_P(ReshapeTest, BadDimensions) {
XLA_TEST_P(ReshapeTest, BadNewSizes) {
XlaBuilder b(TestName());
- auto input_literal = Literal::CreateR1<float>({1.0f, 2.0f});
+ auto input_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b,
&parameter);
- b.Reshape(parameter, {1}, {});
+ Reshape(parameter, {1}, {});
EXPECT_THAT(ExecuteToString(&b, {}),
::testing::HasSubstr("mismatched element counts"));
}
@@ -609,7 +578,8 @@ XLA_TEST_P(ReshapeTest, BadNewSizes) {
XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
XlaBuilder builder(TestName());
// clang-format off
- auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D<float>{
+ auto input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
+ Array4D<float>{
{
{
{0, 1},
@@ -637,7 +607,7 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
+ Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
Array2D<float> expected_array({
{0, 1, 2, 3, 100, 101, 102, 103},
@@ -654,16 +624,16 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
.ConsumeValueOrDie();
std::unique_ptr<Literal> expected =
- Literal::CreateR2FromArray2D<float>(expected_array);
+ LiteralUtil::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = Literal::ConvertF32ToBF16(*expected);
+ expected = LiteralUtil::ConvertF32ToBF16(*expected);
}
EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
@@ -671,10 +641,10 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
+ Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
// clang-format off
- auto expected_literal = Literal::CreateR4<float>({
+ auto expected_literal = LiteralUtil::CreateR4<float>({
{{{0, 1, 2, 3}},
{{4, 5, 6, 7}}},
{{{100, 101, 102, 103}},
@@ -690,7 +660,7 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> input_literal = Literal::CreateR2<float>({
+ std::unique_ptr<Literal> input_literal = LiteralUtil::CreateR2<float>({
{0, 1, 2, 3, 4, 5, 6, 7},
{100, 101, 102, 103, 104, 105, 106, 107},
{200, 201, 202, 203, 204, 205, 206, 207},
@@ -698,10 +668,10 @@ XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input",
&builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
+ Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
// clang-format off
- auto expected_literal = Literal::CreateR4<float>({
+ auto expected_literal = LiteralUtil::CreateR4<float>({
{{{0, 100, 200, 1}},
{{101, 201, 2, 102}}},
{{{202, 3, 103, 203}},
@@ -723,15 +693,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
+ Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
+ LiteralUtil::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -745,15 +715,15 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
+ Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
+ LiteralUtil::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -768,20 +738,20 @@ XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
- /*new_sizes=*/{5, 60});
+ Reshape(parameter, /*dimensions=*/{0, 2, 1, 3},
+ /*new_sizes=*/{5, 60});
Array2D<float> expected_array(5, 60);
input.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* cell) {
expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
*cell;
});
- auto expected = Literal::CreateR2FromArray2D(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
@@ -795,13 +765,13 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
- /*new_sizes=*/{7, 2, 3, 5});
+ Reshape(parameter, /*dimensions=*/{3, 0, 1, 2},
+ /*new_sizes=*/{7, 2, 3, 5});
XlaComputation computation = builder.Build().ConsumeValueOrDie();
ExecutionOptions execution_options = execution_options_;
@@ -817,7 +787,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = Literal::ConvertF32ToBF16(*input_literal);
+ auto expected = LiteralUtil::ConvertF32ToBF16(*input_literal);
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else {
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
@@ -826,21 +796,21 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) {
XlaBuilder builder(TestName());
- auto literal_1x2x3x4 = Literal::CreateR4<float>(
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4<float>(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
&builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
- /*new_sizes=*/{1, 2, 3, 4});
+ Reshape(parameter, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{1, 2, 3, 4});
ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()});
}
XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
- auto literal_1x2x3x4 = Literal::CreateR4<float>(
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4<float>(
{{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
@@ -848,11 +818,11 @@ XLA_TEST_P(ReshapeTest, R4ToR4Reshape) {
XlaOp parameter;
auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input",
&builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
- /*new_sizes=*/{2, 4, 3, 1});
+ Reshape(parameter, /*dimensions=*/{1, 3, 2, 0},
+ /*new_sizes=*/{2, 4, 3, 1});
// clang-format off
- auto expected_2x4x3x1 = Literal::CreateR4<float>(
+ auto expected_2x4x3x1 = LiteralUtil::CreateR4<float>(
{{{{1}, {5}, {9}},
{{2}, {6}, {10}},
{{3}, {7}, {11}},
@@ -876,17 +846,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
- /*new_sizes=*/new_bounds);
+ Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
+ /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -905,17 +875,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
- /*new_sizes=*/new_bounds);
+ Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
+ /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -934,17 +904,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
- /*new_sizes=*/new_bounds);
+ Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
+ /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -964,17 +934,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
XlaBuilder builder(TestName());
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
- /*new_sizes=*/new_bounds);
+ Reshape(parameter, /*dimensions=*/{0, 1, 3, 2},
+ /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
@@ -993,17 +963,17 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
[&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
float* cell) { *cell = distribution(rng); });
std::unique_ptr<Literal> input_literal =
- Literal::CreateR4FromArray4DWithLayout(
+ LiteralUtil::CreateR4FromArray4DWithLayout(
input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
XlaOp parameter;
auto input_data = CreateParameterAndTransferLiteral(
0, *input_literal, "input", &builder, &parameter);
- builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
- /*new_sizes=*/new_bounds);
+ Reshape(parameter, /*dimensions=*/{1, 0, 2, 3},
+ /*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
+ LiteralUtil::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
->Relayout(input_literal->shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index e7bd142dc9..23f0d26d93 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -82,12 +82,12 @@ TEST_P(FloatReverseTest, Reverses) {
std::vector<float> input_vector(
ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims)));
std::iota(input_vector.begin(), input_vector.end(), 0.0);
- auto r1_literal = Literal::CreateR1<float>(input_vector);
+ auto r1_literal = LiteralUtil::CreateR1<float>(input_vector);
auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie();
XlaBuilder builder(TestName());
auto a = AddParam(*input_literal, &builder);
- builder.Rev(a, spec.reversal);
+ Rev(a, spec.reversal);
std::unique_ptr<Literal> expected = input_literal->CloneToUnique();
std::vector<int64> output_indices(spec.input_dims.size());
@@ -127,7 +127,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
}});
// clang-format on
- b.Rev(b.ConstantR4FromArray4D<uint8>(input), {0, 3});
+ Rev(ConstantR4FromArray4D<uint8>(&b, input), {0, 3});
// clang-format off
Array4D<uint8> expected({{
@@ -163,7 +163,7 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
});
// clang-format on
- b.Rev(b.ConstantR4FromArray4D<float>(input), {0, 1});
+ Rev(ConstantR4FromArray4D<float>(&b, input), {0, 1});
// clang-format off
Array4D<float> expected({
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
index 7cfca781ac..a620fe1908 100644
--- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/packed_literal_reader.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
index f334a8c131..a8193c2eac 100644
--- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -46,61 +46,62 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
};
TEST_F(RoundTripTransferTest, R0S32) {
- RoundTripTest(*Literal::CreateR0<int32>(42));
+ RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
}
TEST_F(RoundTripTransferTest, R0F32) {
- RoundTripTest(*Literal::CreateR0<float>(42.0));
+ RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
}
TEST_F(RoundTripTransferTest, R1F32_Len0) {
- RoundTripTest(*Literal::CreateR1<float>({}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({}));
}
TEST_F(RoundTripTransferTest, R1F32_Len2) {
- RoundTripTest(*Literal::CreateR1<float>({42.0, 64.0}));
+ RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
}
TEST_F(RoundTripTransferTest, R1F32_Len256) {
std::vector<float> values(256);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1024) {
std::vector<float> values(1024);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len1025) {
std::vector<float> values(1025);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R1F32_Len4096) {
std::vector<float> values(4096);
std::iota(values.begin(), values.end(), 1.0);
- RoundTripTest(*Literal::CreateR1<float>(values));
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
}
TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
- RoundTripTest(*Literal::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+ RoundTripTest(
+ *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
}
TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
- RoundTripTest(*Literal::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+ RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
}
TEST_F(RoundTripTransferTest, R3F32) {
RoundTripTest(
- *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
- {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
}
TEST_F(RoundTripTransferTest, R4F32) {
- RoundTripTest(*Literal::CreateR4<float>({{
+ RoundTripTest(*LiteralUtil::CreateR4<float>({{
{{10, 11, 12, 13}, {14, 15, 16, 17}},
{{18, 19, 20, 21}, {22, 23, 24, 25}},
{{26, 27, 28, 29}, {30, 31, 32, 33}},
@@ -108,33 +109,36 @@ TEST_F(RoundTripTransferTest, R4F32) {
}
TEST_F(RoundTripTransferTest, EmptyTuple) {
- RoundTripTest(*Literal::MakeTuple({}));
+ RoundTripTest(*LiteralUtil::MakeTuple({}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({1, 2}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR1<float>({}).get(),
- Literal::CreateR1<float>({3, 4}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
}
TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
- RoundTripTest(*Literal::MakeTuple({Literal::CreateR0<float>(1.0).get(),
- Literal::CreateR1<int>({2, 3}).get()}));
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
+ LiteralUtil::CreateR1<int>({2, 3}).get()}));
}
// Below two tests are added to identify the cost of large data transfers.
TEST_F(RoundTripTransferTest, R2F32_Large) {
- RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+ RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
}
TEST_F(RoundTripTransferTest, R4F32_Large) {
Array4D<float> array4d(2, 2, 256, 256);
array4d.FillWithMultiples(1.0f);
- RoundTripTest(*Literal::CreateR4FromArray4D<float>(array4d));
+ RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 308d3fc78a..3b603c0d31 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -44,74 +45,75 @@ class ScalarComputationsTest : public ClientLibraryTestBase {
protected:
// A template for building and running a binary comparison test.
template <typename NativeT>
- void TestCompare(
- NativeT lhs, NativeT rhs, bool expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
- XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs);
- XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
+ XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
template <typename NativeT>
void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
- XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&,
- tensorflow::gtl::ArraySlice<int64>)) {
+ std::function<XlaOp(const XlaOp&, const XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)>
+ op) {
XlaBuilder builder(TestName());
- XlaOp lhs_op = builder.ConstantR0<NativeT>(lhs);
- XlaOp rhs_op = builder.ConstantR0<NativeT>(rhs);
- XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
+ XlaOp lhs_op = ConstantR0<NativeT>(&builder, lhs);
+ XlaOp rhs_op = ConstantR0<NativeT>(&builder, rhs);
+ op(lhs_op, rhs_op, {});
ComputeAndCompareR0<NativeT>(&builder, expected, {});
}
};
XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) {
XlaBuilder builder(TestName());
- builder.ConstantR0<float>(2.1f);
+ ConstantR0<float>(&builder, 2.1f);
ComputeAndCompareR0<float>(&builder, 2.1f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) {
XlaBuilder builder(TestName());
- builder.Neg(builder.ConstantR0<float>(2.1f));
+ Neg(ConstantR0<float>(&builder, 2.1f));
ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) {
XlaBuilder builder(TestName());
- builder.Neg(builder.ConstantR0<int32>(2));
+ Neg(ConstantR0<int32>(&builder, 2));
ComputeAndCompareR0<int32>(&builder, -2, {});
}
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
XlaBuilder builder(TestName());
- builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+ Add(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f));
ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
XlaBuilder builder(TestName());
- builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+ Add(ConstantR0<int32>(&builder, 2), ConstantR0<int32>(&builder, 5));
ComputeAndCompareR0<int32>(&builder, 7, {});
}
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
XlaBuilder builder(TestName());
- builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));
+ Add(ConstantR0<uint32>(&builder, 35), ConstantR0<uint32>(&builder, 57));
ComputeAndCompareR0<uint32>(&builder, 92, {});
}
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
XlaBuilder builder(TestName());
- builder.Add(builder.ConstantR0<uint8>(35), builder.ConstantR0<uint8>(57));
+ Add(ConstantR0<uint8>(&builder, 35), ConstantR0<uint8>(&builder, 57));
ComputeAndCompareR0<uint8>(&builder, 92, {});
}
@@ -120,7 +122,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
XlaBuilder builder(TestName());
const uint64 a = static_cast<uint64>(1) << 63;
const uint64 b = a + 1;
- builder.Add(builder.ConstantR0<uint64>(a), builder.ConstantR0<uint64>(b));
+ Add(ConstantR0<uint64>(&builder, a), ConstantR0<uint64>(&builder, b));
ComputeAndCompareR0<uint64>(&builder, a + b, {});
}
@@ -129,40 +131,39 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
XlaBuilder builder(TestName());
const int64 a = static_cast<int64>(1) << 62;
const int64 b = a - 1;
- builder.Add(builder.ConstantR0<int64>(a), builder.ConstantR0<int64>(b));
+ Add(ConstantR0<int64>(&builder, a), ConstantR0<int64>(&builder, b));
ComputeAndCompareR0<int64>(&builder, a + b, {});
}
XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
XlaBuilder builder(TestName());
- builder.Add(builder.ConstantR0<double>(0.25),
- builder.ConstantR0<double>(3.5));
+ Add(ConstantR0<double>(&builder, 0.25), ConstantR0<double>(&builder, 3.5));
ComputeAndCompareR0<double>(&builder, 3.75, {});
}
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
XlaBuilder builder(TestName());
- builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+ Sub(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f));
ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
XlaBuilder builder(TestName());
- builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+ Sub(ConstantR0<int32>(&builder, 2), ConstantR0<int32>(&builder, 5));
ComputeAndCompareR0<int32>(&builder, -3, {});
}
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
XlaBuilder builder(TestName());
- auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
- builder.ConvertElementType(a, F32);
+ auto a = Parameter(&builder, 0, ShapeUtil::MakeShape(S64, {}), "a");
+ ConvertElementType(a, F32);
int64 value = 3LL << 35;
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
@@ -171,9 +172,8 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
XlaBuilder builder(TestName());
- builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
- builder.ConstantR0<float>(5.5f)),
- builder.ConstantR0<float>(0.5f));
+ Mul(Mul(ConstantR0<float>(&builder, 2.1f), ConstantR0<float>(&builder, 5.5f)),
+ ConstantR0<float>(&builder, 0.5f));
ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
}
@@ -190,7 +190,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
for (int32 x : data) {
for (int32 y : data) {
XlaBuilder builder(TestName());
- builder.Mul(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+ Mul(ConstantR0<int32>(&builder, x), ConstantR0<int32>(&builder, y));
// Signed integer overflow is undefined behavior in C++. Convert the input
// integers to unsigned, perform the multiplication unsigned, and convert
@@ -209,7 +209,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
for (uint32 x : data) {
for (uint32 y : data) {
XlaBuilder builder(TestName());
- builder.Mul(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+ Mul(ConstantR0<uint32>(&builder, x), ConstantR0<uint32>(&builder, y));
uint32 expected = x * y;
ComputeAndCompareR0<uint32>(&builder, expected, {});
@@ -219,18 +219,17 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
XlaBuilder builder(TestName());
- builder.Mul(
- builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
- builder.ConstantR0<int32>(1));
+ Mul(Mul(ConstantR0<int32>(&builder, 2), ConstantR0<int32>(&builder, 5)),
+ ConstantR0<int32>(&builder, 1));
ComputeAndCompareR0<int32>(&builder, 10, {});
}
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XlaBuilder builder(TestName());
- std::unique_ptr<Literal> a_literal = Literal::CreateR0<float>(2.1f);
- std::unique_ptr<Literal> b_literal = Literal::CreateR0<float>(5.5f);
- std::unique_ptr<Literal> c_literal = Literal::CreateR0<float>(0.5f);
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
@@ -239,10 +238,10 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
std::unique_ptr<GlobalData> c_data =
client_->TransferToServer(*c_literal).ConsumeValueOrDie();
- XlaOp a = builder.Parameter(0, a_literal->shape(), "a");
- XlaOp b = builder.Parameter(1, b_literal->shape(), "b");
- XlaOp c = builder.Parameter(2, c_literal->shape(), "c");
- builder.Mul(builder.Mul(a, b), c);
+ XlaOp a = Parameter(&builder, 0, a_literal->shape(), "a");
+ XlaOp b = Parameter(&builder, 1, b_literal->shape(), "b");
+ XlaOp c = Parameter(&builder, 2, c_literal->shape(), "c");
+ Mul(Mul(a, b), c);
ComputeAndCompareR0<float>(&builder, 5.775f,
{a_data.get(), b_data.get(), c_data.get()},
@@ -251,14 +250,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
XlaBuilder builder(TestName());
- builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));
+ Div(ConstantR0<float>(&builder, 5.0f), ConstantR0<float>(&builder, 2.5f));
ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
XlaBuilder builder(TestName());
- builder.Rem(builder.ConstantR0<float>(2.5f), builder.ConstantR0<float>(5.0f));
+ Rem(ConstantR0<float>(&builder, 2.5f), ConstantR0<float>(&builder, 5.0f));
ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
}
@@ -281,8 +280,8 @@ class DivS32Test : public ClientLibraryTestBase,
XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
DivS32Params p = GetParam();
XlaBuilder builder(TestName());
- builder.Div(builder.ConstantR0<int32>(p.dividend),
- builder.ConstantR0<int32>(p.divisor));
+ Div(ConstantR0<int32>(&builder, p.dividend),
+ ConstantR0<int32>(&builder, p.divisor));
ComputeAndCompareR0<int32>(&builder, p.quotient, {});
}
@@ -290,8 +289,8 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) {
XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) {
DivS32Params p = GetParam();
XlaBuilder builder(TestName());
- builder.Rem(builder.ConstantR0<int32>(p.dividend),
- builder.ConstantR0<int32>(p.divisor));
+ Rem(ConstantR0<int32>(&builder, p.dividend),
+ ConstantR0<int32>(&builder, p.divisor));
ComputeAndCompareR0<int32>(&builder, p.remainder, {});
}
@@ -305,7 +304,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) {
CreateR0Parameter<int32>(p.dividend, 0, "dividend", &builder, &dividend);
auto divisord =
CreateR0Parameter<int32>(p.divisor, 1, "divisor", &builder, &divisor);
- builder.Div(dividend, divisor);
+ Div(dividend, divisor);
ComputeAndCompareR0<int32>(&builder, p.quotient,
{dividendd.get(), divisord.get()});
@@ -320,7 +319,7 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) {
CreateR0Parameter<int32>(p.dividend, 0, "dividend", &builder, &dividend);
auto divisord =
CreateR0Parameter<int32>(p.divisor, 1, "divisor", &builder, &divisor);
- builder.Rem(dividend, divisor);
+ Rem(dividend, divisor);
ComputeAndCompareR0<int32>(&builder, p.remainder,
{dividendd.get(), divisord.get()});
@@ -367,18 +366,18 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
XlaBuilder builder(TestName());
XlaOp dividend =
- builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
XlaOp divisor =
- builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor");
- builder.Div(dividend, divisor);
+ Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
+ Div(dividend, divisor);
TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build());
}
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -389,7 +388,8 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend / divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -408,18 +408,18 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
XlaBuilder builder(TestName());
XlaOp dividend =
- builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend");
+ Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {}), "dividend");
XlaOp divisor =
- builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor");
- builder.Rem(dividend, divisor);
+ Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "divisor");
+ Rem(dividend, divisor);
TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build());
}
for (uint32 divisor : vals) {
if (divisor != 0) {
for (uint32 dividend : vals) {
- auto dividend_literal = Literal::CreateR0<uint32>(dividend);
- auto divisor_literal = Literal::CreateR0<uint32>(divisor);
+ auto dividend_literal = LiteralUtil::CreateR0<uint32>(dividend);
+ auto divisor_literal = LiteralUtil::CreateR0<uint32>(divisor);
TF_ASSERT_OK_AND_ASSIGN(auto dividend_data,
client_->TransferToServer(*dividend_literal));
TF_ASSERT_OK_AND_ASSIGN(auto divisor_data,
@@ -430,7 +430,8 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
{dividend_data.get(), divisor_data.get()},
&execution_options_)
.ConsumeValueOrDie();
- auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
+ auto expected_literal =
+ LiteralUtil::CreateR0<uint32>(dividend % divisor);
EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
@@ -439,10 +440,10 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) {
XlaBuilder builder(TestName());
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
- builder.Rem(x, builder.ConstantR0<int32>(80000));
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(S32, {}), "x");
+ Rem(x, ConstantR0<int32>(&builder, 80000));
- std::unique_ptr<Literal> literal = Literal::CreateR0<int32>(87919);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(*literal));
ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
}
@@ -451,15 +452,15 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
XlaBuilder builder(TestName());
// This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32
// as S32, it would output -2 / 2 = -1 (0xFFFFFFFF).
- builder.Div(builder.ConstantR0<uint32>(0xFFFFFFFE),
- builder.ConstantR0<uint32>(2));
+ Div(ConstantR0<uint32>(&builder, 0xFFFFFFFE),
+ ConstantR0<uint32>(&builder, 2));
ComputeAndCompareR0<uint32>(&builder, 0x7FFFFFFF, {});
}
XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) {
XlaBuilder builder(TestName());
- builder.Rem(builder.ConstantR0<uint32>(11), builder.ConstantR0<uint32>(3));
+ Rem(ConstantR0<uint32>(&builder, 11), ConstantR0<uint32>(&builder, 3));
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
@@ -468,7 +469,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) {
for (bool x : {false, true}) {
for (bool y : {false, true}) {
XlaBuilder builder(TestName());
- builder.And(builder.ConstantR0<bool>(x), builder.ConstantR0<bool>(y));
+ And(ConstantR0<bool>(&builder, x), ConstantR0<bool>(&builder, y));
ComputeAndCompareR0<bool>(&builder, x && y, {});
}
@@ -479,7 +480,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) {
for (int32 x : {0, 8}) {
for (int32 y : {1, -16}) {
XlaBuilder builder(TestName());
- builder.And(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+ And(ConstantR0<int32>(&builder, x), ConstantR0<int32>(&builder, y));
ComputeAndCompareR0<int32>(&builder, x & y, {});
}
@@ -490,7 +491,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) {
for (uint32 x : {0, 8}) {
for (uint32 y : {1, 16}) {
XlaBuilder builder(TestName());
- builder.And(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+ And(ConstantR0<uint32>(&builder, x), ConstantR0<uint32>(&builder, y));
ComputeAndCompareR0<uint32>(&builder, x & y, {});
}
@@ -501,7 +502,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) {
for (bool x : {false, true}) {
for (bool y : {false, true}) {
XlaBuilder builder(TestName());
- builder.Or(builder.ConstantR0<bool>(x), builder.ConstantR0<bool>(y));
+ Or(ConstantR0<bool>(&builder, x), ConstantR0<bool>(&builder, y));
ComputeAndCompareR0<bool>(&builder, x || y, {});
}
@@ -512,7 +513,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) {
for (int32 x : {0, 8}) {
for (int32 y : {1, -16}) {
XlaBuilder builder(TestName());
- builder.Or(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+ Or(ConstantR0<int32>(&builder, x), ConstantR0<int32>(&builder, y));
ComputeAndCompareR0<int32>(&builder, x | y, {});
}
@@ -523,7 +524,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) {
for (uint32 x : {0, 8}) {
for (uint32 y : {1, 16}) {
XlaBuilder builder(TestName());
- builder.Or(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+ Or(ConstantR0<uint32>(&builder, x), ConstantR0<uint32>(&builder, y));
ComputeAndCompareR0<uint32>(&builder, x | y, {});
}
@@ -533,7 +534,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) {
XLA_TEST_F(ScalarComputationsTest, NotBool) {
for (bool x : {false, true}) {
XlaBuilder builder(TestName());
- builder.Not(builder.ConstantR0<bool>(x));
+ Not(ConstantR0<bool>(&builder, x));
ComputeAndCompareR0<bool>(&builder, !x, {});
}
@@ -542,7 +543,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) {
XLA_TEST_F(ScalarComputationsTest, NotS32) {
for (int32 x : {-1, 0, 1}) {
XlaBuilder builder(TestName());
- builder.Not(builder.ConstantR0<int32>(x));
+ Not(ConstantR0<int32>(&builder, x));
ComputeAndCompareR0<int32>(&builder, ~x, {});
}
@@ -551,7 +552,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) {
XLA_TEST_F(ScalarComputationsTest, NotU32) {
for (uint32 x : {0, 1, 2}) {
XlaBuilder builder(TestName());
- builder.Not(builder.ConstantR0<uint32>(x));
+ Not(ConstantR0<uint32>(&builder, x));
ComputeAndCompareR0<uint32>(&builder, ~x, {});
}
@@ -559,18 +560,18 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) {
XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) {
XlaBuilder builder(TestName());
- builder.Select(builder.ConstantR0<bool>(true), // The predicate.
- builder.ConstantR0<float>(123.0f), // The value on true.
- builder.ConstantR0<float>(42.0f)); // The value on false.
+ Select(ConstantR0<bool>(&builder, true), // The predicate.
+ ConstantR0<float>(&builder, 123.0f), // The value on true.
+ ConstantR0<float>(&builder, 42.0f)); // The value on false.
ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
XlaBuilder builder(TestName());
- builder.Select(builder.ConstantR0<bool>(false), // The predicate.
- builder.ConstantR0<float>(123.0f), // The value on true.
- builder.ConstantR0<float>(42.0f)); // The value on false.
+ Select(ConstantR0<bool>(&builder, false), // The predicate.
+ ConstantR0<float>(&builder, 123.0f), // The value on true.
+ ConstantR0<float>(&builder, 42.0f)); // The value on false.
ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
}
@@ -579,313 +580,311 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) {
// templatized comparison tests.
XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) {
XlaBuilder builder(TestName());
- builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
+ Gt(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 1.0f));
ComputeAndCompareR0<bool>(&builder, true, {});
}
// S32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<int32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
- TestCompare<int32>(3, 3, true, &XlaBuilder::Eq);
+ TestCompare<int32>(3, 3, true, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<int32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeS32) {
- TestCompare<int32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<int32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtS32) {
- TestCompare<int32>(1, 5, false, &XlaBuilder::Gt);
+ TestCompare<int32>(1, 5, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeS32) {
- TestCompare<int32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<int32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtS32) {
- TestCompare<int32>(9, 7, false, &XlaBuilder::Lt);
+ TestCompare<int32>(9, 7, false, &Lt);
TestCompare<int32>(std::numeric_limits<int32>::min(),
- std::numeric_limits<int32>::max(), true, &XlaBuilder::Lt);
+ std::numeric_limits<int32>::max(), true, &Lt);
}
// U32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Eq);
+ TestCompare<uint32>(2, 1, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeU32) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ne);
+ TestCompare<uint32>(2, 1, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
- TestCompare<uint32>(2, 1, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(2, 1, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
- TestCompare<uint32>(3, 3, true, &XlaBuilder::Ge);
+ TestCompare<uint32>(3, 3, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtU32) {
- TestCompare<uint32>(1, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 5, false, &XlaBuilder::Gt);
- TestCompare<uint32>(5, 1, true, &XlaBuilder::Gt);
+ TestCompare<uint32>(1, 5, false, &Gt);
+ TestCompare<uint32>(5, 5, false, &Gt);
+ TestCompare<uint32>(5, 1, true, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeU32) {
- TestCompare<uint32>(2, 1, false, &XlaBuilder::Le);
+ TestCompare<uint32>(2, 1, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtU32) {
- TestCompare<uint32>(9, 7, false, &XlaBuilder::Lt);
- TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
- &XlaBuilder::Lt);
+ TestCompare<uint32>(9, 7, false, &Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true, &Lt);
}
// F32 comparisons.
XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) {
- TestCompare<float>(2.0, 1.3, false, &XlaBuilder::Eq);
+ TestCompare<float>(2.0, 1.3, false, &Eq);
}
XLA_TEST_F(ScalarComputationsTest, CompareNeF32) {
- TestCompare<float>(2.0, 1.3, true, &XlaBuilder::Ne);
+ TestCompare<float>(2.0, 1.3, true, &Ne);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
- TestCompare<float>(2.0, 1.9, true, &XlaBuilder::Ge);
+ TestCompare<float>(2.0, 1.9, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
- TestCompare<float>(3.5, 3.5, true, &XlaBuilder::Ge);
+ TestCompare<float>(3.5, 3.5, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGtF32) {
- TestCompare<float>(1.0, 5.2, false, &XlaBuilder::Gt);
+ TestCompare<float>(1.0, 5.2, false, &Gt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLeF32) {
- TestCompare<float>(2.0, 1.2, false, &XlaBuilder::Le);
+ TestCompare<float>(2.0, 1.2, false, &Le);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32) {
- TestCompare<float>(9.0, 7.2, false, &XlaBuilder::Lt);
+ TestCompare<float>(9.0, 7.2, false, &Lt);
}
// F32 comparisons with exceptional values. The test names encode the
// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, true, &XlaBuilder::Lt);
+ TestCompare<float>(-INFINITY, -0.0, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, false, &XlaBuilder::Lt);
+ TestCompare<float>(-0.0, 0.0, false, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, true, &XlaBuilder::Lt);
+ TestCompare<float>(0.0, INFINITY, true, &Lt);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
- TestCompare<float>(-INFINITY, -0.0, false, &XlaBuilder::Ge);
+ TestCompare<float>(-INFINITY, -0.0, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
// Comparisons of 0.0 to -0.0 consider them equal in IEEE 754.
- TestCompare<float>(-0.0, 0.0, true, &XlaBuilder::Ge);
+ TestCompare<float>(-0.0, 0.0, true, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
- TestCompare<float>(0.0, INFINITY, false, &XlaBuilder::Ge);
+ TestCompare<float>(0.0, INFINITY, false, &Ge);
}
XLA_TEST_F(ScalarComputationsTest, ExpScalar) {
XlaBuilder builder(TestName());
- builder.Exp(builder.ConstantR0<float>(2.0f));
+ Exp(ConstantR0<float>(&builder, 2.0f));
ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, LogScalar) {
XlaBuilder builder("log");
- builder.Log(builder.ConstantR0<float>(2.0f));
+ Log(ConstantR0<float>(&builder, 2.0f));
ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, TanhScalar) {
XlaBuilder builder(TestName());
- builder.Tanh(builder.ConstantR0<float>(2.0f));
+ Tanh(ConstantR0<float>(&builder, 2.0f));
ComputeAndCompareR0<float>(&builder, 0.96402758, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
XlaBuilder builder(TestName());
- builder.Tanh(builder.ConstantR0<double>(2.0));
+ Tanh(ConstantR0<double>(&builder, 2.0));
ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, PowScalar) {
XlaBuilder builder(TestName());
- builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
+ Pow(ConstantR0<float>(&builder, 2.0f), ConstantR0<float>(&builder, 3.0f));
ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound.
- builder.ConstantR0<int32>(5), // The operand to be clamped.
- builder.ConstantR0<int32>(3)); // The upper bound.
+ Clamp(ConstantR0<int32>(&builder, -1), // The lower bound.
+ ConstantR0<int32>(&builder, 5), // The operand to be clamped.
+ ConstantR0<int32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<int32>(&builder, 3, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound.
- builder.ConstantR0<int32>(2), // The operand to be clamped.
- builder.ConstantR0<int32>(3)); // The upper bound.
+ Clamp(ConstantR0<int32>(&builder, -1), // The lower bound.
+ ConstantR0<int32>(&builder, 2), // The operand to be clamped.
+ ConstantR0<int32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<int32>(&builder, 2, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<int32>(-1), // The lower bound.
- builder.ConstantR0<int32>(-5), // The operand to be clamped.
- builder.ConstantR0<int32>(3)); // The upper bound.
+ Clamp(ConstantR0<int32>(&builder, -1), // The lower bound.
+ ConstantR0<int32>(&builder, -5), // The operand to be clamped.
+ ConstantR0<int32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<int32>(&builder, -1, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound.
- builder.ConstantR0<uint32>(5), // The operand to be clamped.
- builder.ConstantR0<uint32>(3)); // The upper bound.
+ Clamp(ConstantR0<uint32>(&builder, 1), // The lower bound.
+ ConstantR0<uint32>(&builder, 5), // The operand to be clamped.
+ ConstantR0<uint32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<uint32>(&builder, 3, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound.
- builder.ConstantR0<uint32>(2), // The operand to be clamped.
- builder.ConstantR0<uint32>(3)); // The upper bound.
+ Clamp(ConstantR0<uint32>(&builder, 1), // The lower bound.
+ ConstantR0<uint32>(&builder, 2), // The operand to be clamped.
+ ConstantR0<uint32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<uint32>(&builder, 2, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<uint32>(1), // The lower bound.
- builder.ConstantR0<uint32>(0), // The operand to be clamped.
- builder.ConstantR0<uint32>(3)); // The upper bound.
+ Clamp(ConstantR0<uint32>(&builder, 1), // The lower bound.
+ ConstantR0<uint32>(&builder, 0), // The operand to be clamped.
+ ConstantR0<uint32>(&builder, 3)); // The upper bound.
ComputeAndCompareR0<uint32>(&builder, 1, {});
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
- builder.ConstantR0<float>(5.0f), // The operand to be clamped.
- builder.ConstantR0<float>(3.0f)); // The upper bound.
+ Clamp(ConstantR0<float>(&builder, 2.0f), // The lower bound.
+ ConstantR0<float>(&builder, 5.0f), // The operand to be clamped.
+ ConstantR0<float>(&builder, 3.0f)); // The upper bound.
ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
- builder.ConstantR0<float>(2.5f), // The operand to be clamped.
- builder.ConstantR0<float>(3.0f)); // The upper bound.
+ Clamp(ConstantR0<float>(&builder, 2.0f), // The lower bound.
+ ConstantR0<float>(&builder, 2.5f), // The operand to be clamped.
+ ConstantR0<float>(&builder, 3.0f)); // The upper bound.
ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) {
XlaBuilder builder(TestName());
- builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
- builder.ConstantR0<float>(-5.0f), // The operand to be clamped.
- builder.ConstantR0<float>(3.0f)); // The upper bound.
+ Clamp(ConstantR0<float>(&builder, 2.0f), // The lower bound.
+ ConstantR0<float>(&builder, -5.0f), // The operand to be clamped.
+ ConstantR0<float>(&builder, 3.0f)); // The upper bound.
ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
}
XLA_TEST_F(ScalarComputationsTest, MinS32Above) {
- TestMinMax<int32>(10, 3, 3, &XlaBuilder::Min);
+ TestMinMax<int32>(10, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinS32Below) {
- TestMinMax<int32>(-100, 3, -100, &XlaBuilder::Min);
+ TestMinMax<int32>(-100, 3, -100, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Above) {
- TestMinMax<int32>(10, 3, 10, &XlaBuilder::Max);
+ TestMinMax<int32>(10, 3, 10, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxS32Below) {
- TestMinMax<int32>(-100, 3, 3, &XlaBuilder::Max);
+ TestMinMax<int32>(-100, 3, 3, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, 3, &XlaBuilder::Min);
+ TestMinMax<uint32>(large, 3, 3, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinU32Below) {
- TestMinMax<uint32>(0, 5, 0, &XlaBuilder::Min);
+ TestMinMax<uint32>(0, 5, 0, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Above) {
const uint32 large = std::numeric_limits<int32>::max();
- TestMinMax<uint32>(large, 3, large, &XlaBuilder::Max);
+ TestMinMax<uint32>(large, 3, large, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxU32Below) {
- TestMinMax<uint32>(0, 5, 5, &XlaBuilder::Max);
+ TestMinMax<uint32>(0, 5, 5, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 3.1f, &XlaBuilder::Min);
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min);
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Min);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Min);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Min);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Min);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Above) {
- TestMinMax<float>(10.1f, 3.1f, 10.1f, &XlaBuilder::Max);
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxF32Below) {
- TestMinMax<float>(-100.1f, 3.1f, 3.1f, &XlaBuilder::Max);
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &Max);
}
XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) {
SetFastMathDisabled(true);
- TestMinMax<float>(NAN, 3.1f, NAN, &XlaBuilder::Max);
- TestMinMax<float>(-3.1f, NAN, NAN, &XlaBuilder::Max);
+ TestMinMax<float>(NAN, 3.1f, NAN, &Max);
+ TestMinMax<float>(-3.1f, NAN, NAN, &Max);
}
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
// Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
XlaBuilder b(TestName());
- b.Div(
- b.Sub(b.Mul(b.ConstantR0<float>(1),
- b.Mul(b.Sub(b.ConstantR0<float>(3), b.ConstantR0<float>(1)),
- b.Add(b.ConstantR0<float>(7), b.ConstantR0<float>(0)))),
- b.ConstantR0<float>(4)),
- b.ConstantR0<float>(20));
+ Div(Sub(Mul(ConstantR0<float>(&b, 1),
+ Mul(Sub(ConstantR0<float>(&b, 3), ConstantR0<float>(&b, 1)),
+ Add(ConstantR0<float>(&b, 7), ConstantR0<float>(&b, 0)))),
+ ConstantR0<float>(&b, 4)),
+ ConstantR0<float>(&b, 20));
ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
}
@@ -893,30 +892,18 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
// Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
XlaBuilder b(TestName());
- b.Sub(b.Mul(b.ConstantR0<int32>(1),
- b.Mul(b.Sub(b.ConstantR0<int32>(3), b.ConstantR0<int32>(1)),
- b.Add(b.ConstantR0<int32>(7), b.ConstantR0<int32>(0)))),
- b.ConstantR0<int32>(4));
+ Sub(Mul(ConstantR0<int32>(&b, 1),
+ Mul(Sub(ConstantR0<int32>(&b, 3), ConstantR0<int32>(&b, 1)),
+ Add(ConstantR0<int32>(&b, 7), ConstantR0<int32>(&b, 0)))),
+ ConstantR0<int32>(&b, 4));
ComputeAndCompareR0<int32>(&b, 10, {});
}
-XLA_TEST_F(ScalarComputationsTest, SqrtF320) {
- XlaBuilder builder(TestName());
- Literal zero_literal = Literal::Zero(PrimitiveType::F32);
-
- std::unique_ptr<GlobalData> zero_data =
- client_->TransferToServer(zero_literal).ConsumeValueOrDie();
-
- XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero");
- builder.SqrtF32(zero);
-
- ComputeAndCompareR0<float>(&builder, 0.0f, {zero_data.get()}, error_spec_);
-}
XLA_TEST_F(ScalarComputationsTest, RoundScalar) {
XlaBuilder builder(TestName());
- builder.Round(builder.ConstantR0<float>(1.4f));
+ Round(ConstantR0<float>(&builder, 1.4f));
ComputeAndCompareR0<float>(&builder, 1.0f, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index 7015e5a6a3..b1f1e69d3c 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -73,16 +73,16 @@ XLA_TEST_P(SelectAndScatterTest, ParamTest) {
auto operand_shape = GetParam().operand_shape;
Array<float> o(operand_shape);
o.FillRandom(1.5f);
- auto operand = builder_.ConstantFromArray(o);
+ auto operand = ConstantFromArray(&builder_, o);
auto source_shape = GetParam().source_shape;
Array<float> s(source_shape);
s.FillRandom(12.0f);
- auto source = builder_.ConstantFromArray(s);
+ auto source = ConstantFromArray(&builder_, s);
- builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions,
- GetParam().window_strides, GetParam().padding_type,
- source, builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions,
+ GetParam().window_strides, GetParam().padding_type, source,
+ ConstantR0<float>(&builder_, 0.0f), add_f32_);
ComputeAndCompare(&builder_, {}, ErrorSpec(1e-5));
}
@@ -197,110 +197,110 @@ INSTANTIATE_TEST_CASE_P(
// Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
- const auto operand = builder_.ConstantR1<float>({});
- const auto source = builder_.ConstantR1<float>({});
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
- /*window_strides=*/{3}, Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ const auto operand = ConstantR1<float>(&builder_, {});
+ const auto source = ConstantR1<float>(&builder_, {});
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ ConstantR0<float>(&builder_, 0.0f), add_f32_);
ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
}
// Test for F32 1D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest, R1F32) {
const auto operand =
- builder_.ConstantR1<float>({1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
- const auto source = builder_.ConstantR1<float>({34.f, 42.f});
+ ConstantR1<float>(&builder_, {1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
+ const auto source = ConstantR1<float>(&builder_, {34.f, 42.f});
const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
- /*window_strides=*/{3}, Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ ConstantR0<float>(&builder_, 0.0f), add_f32_);
ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
// Test for S32 1D array, when windows do not overlap and the init value is 1.
XLA_TEST_F(SelectAndScatterTest, R1S32) {
- const auto operand = builder_.ConstantR1<int32>({-1, 0, 6, 4, -4, 10});
- const auto source = builder_.ConstantR1<int32>({-10, 20});
+ const auto operand = ConstantR1<int32>(&builder_, {-1, 0, 6, 4, -4, 10});
+ const auto source = ConstantR1<int32>(&builder_, {-10, 20});
const std::vector<int32> expected = {1, 1, -9, 1, 1, 21};
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
- /*window_strides=*/{3}, Padding::kValid, source,
- builder_.ConstantR0<int32>(1), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ ConstantR0<int32>(&builder_, 1), add_s32_);
ComputeAndCompareR1<int32>(&builder_, expected, {});
}
// Test for S32 1D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
- const auto operand = builder_.ConstantR1<int32>({1, 9, 3, 7, 5, 6});
- const auto source = builder_.ConstantR1<int32>({34, 42, 53, 19});
+ const auto operand = ConstantR1<int32>(&builder_, {1, 9, 3, 7, 5, 6});
+ const auto source = ConstantR1<int32>(&builder_, {34, 42, 53, 19});
const std::vector<int32> expected = {0, 76, 0, 72, 0, 0};
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
- /*window_strides=*/{1}, Padding::kValid, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR1<int32>(&builder_, expected, {});
}
// Test for S32 2D array, when windows do not overlap.
XLA_TEST_F(SelectAndScatterTest, R2S32) {
const auto operand =
- builder_.ConstantR2<int32>({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
- const auto source = builder_.ConstantR2<int32>({{2, 6}});
+ ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
+ const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
- /*window_strides=*/{2, 3}, Padding::kValid, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR2<int32>(&builder_, expected, {});
}
// Test for tie breaking rule in ge_f32_. When a tie is present, the operand
// that has the lower lexicographical order (smaller index) should be chosen.
XLA_TEST_F(SelectAndScatterTest, R2F32Tie) {
- const auto operand = builder_.ConstantR2<float>(
- {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
- const auto source = builder_.ConstantR2<float>(
- {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
+ const auto operand = ConstantR2<float>(
+ &builder_, {{0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}, {0.f, 0.f, 0.f}});
+ const auto source = ConstantR2<float>(
+ &builder_, {{1.0f, 2.0f, 3.0f}, {4.f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
Array2D<float> expected(
{{12.f, 9.f, 0.f}, {15.f, 9.f, 0.f}, {0.f, 0.f, 0.f}});
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
- /*window_strides=*/{1, 1}, Padding::kSame, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3, 3},
+ /*window_strides=*/{1, 1}, Padding::kSame, source,
+ ConstantR0<float>(&builder_, 0.0f), add_f32_);
ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
// Similar to SelectAndScatterTest.R2S32 but the input is transposed.
XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
- const auto operand = builder_.ConstantR2<int32>(
- {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
+ const auto operand = ConstantR2<int32>(
+ &builder_, {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
const auto reshape =
- builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
- const auto source = builder_.ConstantR2<int32>({{2, 6}});
+ Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
+ const auto source = ConstantR2<int32>(&builder_, {{2, 6}});
Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
- builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
- /*window_strides=*/{2, 3}, Padding::kValid, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR2<int32>(&builder_, expected, {});
}
// Test for S32 2D array, when windows overlap with each other.
XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
const auto operand =
- builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
- const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
- /*window_strides=*/{1, 1}, Padding::kValid, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR2<int32>(&builder_, expected, {});
}
// Test for S32 2D array, when the padding is Padding::kSAME.
XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
const auto operand =
- builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
- const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = ConstantR2<int32>(&builder_, {{2, 6, 4}});
Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
- /*window_strides=*/{2, 2}, Padding::kSame, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{2, 2}, Padding::kSame, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR2<int32>(&builder_, expected, {});
}
@@ -308,25 +308,26 @@ XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
// with each other.
XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
const auto operand =
- builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ ConstantR2<int32>(&builder_, {{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
const auto source =
- builder_.ConstantR2<int32>({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
+ ConstantR2<int32>(&builder_, {{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
- builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
- /*window_strides=*/{1, 1}, Padding::kSame, source,
- builder_.ConstantR0<int32>(0), add_s32_);
+ SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kSame, source,
+ ConstantR0<int32>(&builder_, 0), add_s32_);
ComputeAndCompareR2<int32>(&builder_, expected, {});
}
XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
- const auto operand = builder_.ConstantR2<float>(
- {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
- const auto source = builder_.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ const auto operand = ConstantR2<float>(
+ &builder_, {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
+ const auto source =
+ ConstantR2<float>(&builder_, {{1.0f, 2.0f}, {3.0f, 4.0f}});
Array2D<float> expected(
{{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}});
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2},
- /*window_strides=*/{1, 1}, Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ ConstantR0<float>(&builder_, 0.0f), add_f32_);
ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
@@ -342,16 +343,16 @@ TEST_F(SelectAndScatterTest, R4F32Valid) {
{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
Array4D<float> o(4, 6, 15, 220);
o.FillWithPZ(pzo);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = ConstantR4FromArray4D(&builder_, o);
Array4D<float> e(4, 6, 15, 220);
e.FillWithPZ(pze);
Array4D<float> s(2, 2, 15, 220);
s.FillWithPZ(pzs);
- auto source = builder_.ConstantR4FromArray4D(s);
+ auto source = ConstantR4FromArray4D(&builder_, s);
s.FillWithPZ(pzs);
- builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
- Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
+ add_f32_);
ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
@@ -367,16 +368,16 @@ TEST_F(SelectAndScatterTest, R4F32Overlap) {
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
Array4D<float> o(4, 5, 17, 128);
o.FillWithPZ(pzo);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = ConstantR4FromArray4D(&builder_, o);
Array4D<float> e(4, 5, 17, 128);
e.FillWithPZ(pze);
Array4D<float> s(2, 2, 17, 128);
s.FillWithPZ(pzs);
- auto source = builder_.ConstantR4FromArray4D(s);
+ auto source = ConstantR4FromArray4D(&builder_, s);
s.FillWithPZ(pzs);
- builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
- Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
+ add_f32_);
ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
@@ -392,16 +393,16 @@ TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
{0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
Array4D<float> o(4, 5, 1, 1);
o.FillWithPZ(pzo);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = ConstantR4FromArray4D(&builder_, o);
Array4D<float> e(4, 5, 1, 1);
e.FillWithPZ(pze);
Array4D<float> s(2, 2, 1, 1);
s.FillWithPZ(pzs);
- auto source = builder_.ConstantR4FromArray4D(s);
+ auto source = ConstantR4FromArray4D(&builder_, s);
s.FillWithPZ(pzs);
- builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
- Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
+ add_f32_);
ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
}
@@ -414,39 +415,39 @@ TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
Array4D<float> o(4, 6, 4, 4);
o.FillWithPZ(pzo);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = ConstantR4FromArray4D(&builder_, o);
Array4D<float> s(2, 2, 4, 4);
s.FillWithPZ(pzs);
- auto source = builder_.ConstantR4FromArray4D(s);
+ auto source = ConstantR4FromArray4D(&builder_, s);
s.FillWithPZ(pzs);
- builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
- Padding::kValid, source,
- builder_.ConstantR0<float>(0.0f), add_f32_);
+ SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source, ConstantR0<float>(&builder_, 0.0f),
+ add_f32_);
auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
{2, 3, 1, 1}, false);
ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
}
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
- const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
- const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
+ const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
+ const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
const std::vector<float> expected = {0, 0, 0, 53, 0, 0, 0};
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
- /*window_strides=*/{1}, Padding::kValid, source,
- builder_.ConstantR0<float>(0), max_f32_);
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ ConstantR0<float>(&builder_, 0), max_f32_);
ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
- const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
- const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
+ const auto operand = ConstantR1<float>(&builder_, {1, 2, 3, 100, 3, 2, 1});
+ const auto source = ConstantR1<float>(&builder_, {34, 42, 53, 19});
const float max_float = std::numeric_limits<float>::max();
const std::vector<float> expected = {max_float, max_float, max_float, 19,
max_float, max_float, max_float};
- builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
- /*window_strides=*/{1}, Padding::kValid, source,
- builder_.ConstantR0<float>(max_float), min_f32_);
+ SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ ConstantR0<float>(&builder_, max_float), min_f32_);
ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
}
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
index 72707f2244..59409ab26e 100644
--- a/tensorflow/compiler/xla/tests/select_test.cc
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -35,50 +35,52 @@ class SelectTest : public ClientLibraryTestBase {
TEST_F(SelectTest, SelectScalarF32True) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto on_true = builder.ConstantR0<float>(123.0f);
- auto on_false = builder.ConstantR0<float>(42.0f);
- auto result = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto on_true = ConstantR0<float>(&builder, 123.0f);
+ auto on_false = ConstantR0<float>(&builder, 42.0f);
+ Select(pred, on_true, on_false);
ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
}
TEST_F(SelectTest, SelectScalarS32True) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto on_true = builder.ConstantR0<int32>(-42);
- auto on_false = builder.ConstantR0<int32>(42);
- auto result = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto on_true = ConstantR0<int32>(&builder, -42);
+ auto on_false = ConstantR0<int32>(&builder, 42);
+ Select(pred, on_true, on_false);
ComputeAndCompareR0<int32>(&builder, -42, {});
}
TEST_F(SelectTest, SelectScalarF32False) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto on_true = builder.ConstantR0<float>(123.0f);
- auto on_false = builder.ConstantR0<float>(42.0f);
- auto result = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto on_true = ConstantR0<float>(&builder, 123.0f);
+ auto on_false = ConstantR0<float>(&builder, 42.0f);
+ Select(pred, on_true, on_false);
ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
}
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR1<bool>({});
- auto on_true = builder.ConstantR1<float>({});
- auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR1<bool>(&builder, {});
+ auto on_true = ConstantR1<float>(&builder, {});
+ auto on_false = ConstantR1<float>(&builder, {});
+ Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
- auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
- auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR1<bool>(&builder, {false, true, false, true, false});
+ auto on_true =
+ ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false =
+ ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
error_spec_);
@@ -88,12 +90,12 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
// Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
// is not a constant, but rather the result of comparing two other vectors.
XlaBuilder builder(TestName());
- auto v1 = builder.ConstantR1<int32>({});
- auto v2 = builder.ConstantR1<int32>({});
- auto cmp = builder.Eq(v1, v2);
- auto on_true = builder.ConstantR1<float>({});
- auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(cmp, on_true, on_false);
+ auto v1 = ConstantR1<int32>(&builder, {});
+ auto v2 = ConstantR1<int32>(&builder, {});
+ auto cmp = Eq(v1, v2);
+ auto on_true = ConstantR1<float>(&builder, {});
+ auto on_false = ConstantR1<float>(&builder, {});
+ Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -102,12 +104,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
// Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
// not a constant, but rather the result of comparing two other vectors.
XlaBuilder builder(TestName());
- auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
- auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
- auto cmp = builder.Eq(v1, v2);
- auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
- auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ auto v1 = ConstantR1<int32>(&builder, {1, 2, 3, 4, 5});
+ auto v2 = ConstantR1<int32>(&builder, {9, 2, 9, 4, 9});
+ auto cmp = Eq(v1, v2);
+ auto on_true =
+ ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false =
+ ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
error_spec_);
@@ -116,12 +120,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
// Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
XlaBuilder builder(TestName());
- auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
- auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
- auto cmp = builder.Gt(v1, v2);
- auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
- auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ auto v1 = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ auto v2 = ConstantR1<float>(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
+ auto cmp = Gt(v1, v2);
+ auto on_true =
+ ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false =
+ ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
error_spec_);
@@ -140,8 +146,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
{21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto cmp = builder.Gt(v1, v2);
- auto select = builder.Select(cmp, v1, v2);
+ auto cmp = Gt(v1, v2);
+ Select(cmp, v1, v2);
ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -181,8 +187,8 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto cmp = builder.Gt(v1, v2);
- auto select = builder.Select(cmp, v1, v2);
+ auto cmp = Gt(v1, v2);
+ Select(cmp, v1, v2);
ComputeAndCompareR1<float>(&builder, expected_vec,
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -192,14 +198,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
// "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
// select between two R1F32s.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
- auto s = builder.ConstantR0<int32>(0);
- auto cmp = builder.Gt(v, s);
+ auto v = ConstantR1<int32>(&builder, {1, -1, 2, -2});
+ auto s = ConstantR0<int32>(&builder, 0);
+ auto cmp = Gt(v, s);
- auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
auto on_false =
- builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
+ Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
error_spec_);
@@ -209,14 +215,14 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
// "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
// select between two R1F32s.
XlaBuilder builder(TestName());
- auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
- auto s = builder.ConstantR0<float>(2.5f);
- auto cmp = builder.Gt(v, s);
+ auto v = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f});
+ auto s = ConstantR0<float>(&builder, 2.5f);
+ auto cmp = Gt(v, s);
- auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
auto on_false =
- builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
- auto select = builder.Select(cmp, on_true, on_false);
+ ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
+ Select(cmp, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
error_spec_);
@@ -225,10 +231,10 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
for (bool which : {false, true}) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(which);
- auto on_true = builder.ConstantR1<float>({});
- auto on_false = builder.ConstantR1<float>({});
- auto select = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, which);
+ auto on_true = ConstantR1<float>(&builder, {});
+ auto on_false = ConstantR1<float>(&builder, {});
+ Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
@@ -236,20 +242,20 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(true);
- auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
- auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, true);
+ auto on_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
+ auto on_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
+ Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
}
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
XlaBuilder builder(TestName());
- auto pred = builder.ConstantR0<bool>(false);
- auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
- auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
- auto select = builder.Select(pred, on_true, on_false);
+ auto pred = ConstantR0<bool>(&builder, false);
+ auto on_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
+ auto on_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
+ Select(pred, on_true, on_false);
ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 5653bf11a7..48138e7b07 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -42,8 +42,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
values.FillIota(0);
XlaBuilder builder(TestName());
- auto original = builder.ConstantR3FromArray3D<float>(values);
- builder.Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1});
+ auto original = ConstantR3FromArray3D<float>(&builder, values);
+ Slice(original, {0, 0, 0}, {3, 3, 1}, {1, 1, 1});
Array3D<float> expected{
{{0.0}, {3.0}, {6.0}}, {{9.0}, {12.0}, {15.0}}, {{18.0}, {21.0}, {24.0}}};
@@ -55,8 +55,8 @@ TEST_F(SliceTest, Slice3x3x3_To_3x1x3_F32) {
values.FillIota(0);
XlaBuilder builder(TestName());
- auto original = builder.ConstantR3FromArray3D<float>(values);
- builder.Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1});
+ auto original = ConstantR3FromArray3D<float>(&builder, values);
+ Slice(original, {0, 0, 0}, {3, 1, 3}, {1, 1, 1});
Array3D<float> expected{
{{0.0, 1.0, 2.0}}, {{9.0, 10.0, 11.0}}, {{18.0, 19.0, 20.0}}};
@@ -68,8 +68,8 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
values.FillIota(0);
XlaBuilder builder(TestName());
- auto original = builder.ConstantR3FromArray3D<float>(values);
- builder.Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1});
+ auto original = ConstantR3FromArray3D<float>(&builder, values);
+ Slice(original, {0, 0, 0}, {1, 3, 3}, {1, 1, 1});
Array3D<float> expected{
{{{0.0, 1.0, 2.0}, {3.0, 4.0, 5.0}, {6.0, 7.0, 8.0}}}};
@@ -78,24 +78,24 @@ TEST_F(SliceTest, Slice3x3x3_To_1x3x3_F32) {
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 0));
+ Slice(original, {0, 0}, {0, 0}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
}
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
- builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 20));
+ Slice(original, {0, 15}, {0, 20}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
}
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
- builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, Array2D<float>(3, 0));
+ Slice(original, {1, 0}, {3, 0}, {1, 1});
ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
}
@@ -109,8 +109,8 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
}
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, values);
+ Slice(original, {128, 128}, {256, 256}, {1, 1});
Array2D<float> expected(128, 128);
for (int row = 0; row < 128; ++row) {
@@ -127,8 +127,8 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
std::iota(values.data(), values.data() + 4096, 0.0);
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, values);
+ Slice(original, {0, 3072}, {1, 4096}, {1, 1});
Array2D<float> expected(1, 1024);
std::iota(expected.data(), expected.data() + 1024, 3072.0);
@@ -148,8 +148,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
}
}
XlaBuilder builder(TestName());
- auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
+ auto original = ConstantR2FromArray2D<float>(&builder, values);
+ Slice(original, {0, 0}, {16, 2}, {1, 1});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
@@ -160,8 +160,8 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
auto expected = ReferenceUtil::Slice4D(
values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}});
XlaBuilder builder(TestName());
- auto original = builder.ConstantR4FromArray4D(values);
- builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
+ auto original = ConstantR4FromArray4D(&builder, values);
+ Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
@@ -170,11 +170,11 @@ XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
values.FillRandom(3.14f);
auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
/*strides=*/{{1, 1, 2, 1}});
- auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
+ auto expected_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
*expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
XlaBuilder builder(TestName());
- auto original = builder.ConstantR4FromArray4D(values);
- builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
+ auto original = ConstantR4FromArray4D(&builder, values);
+ Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
&expected_literal->shape());
}
@@ -197,12 +197,12 @@ class SliceR1Test : public ClientLibraryTestBase,
// vector<bool>.
tensorflow::gtl::InlinedVector<NativeT, 1> input(spec.input_dim0);
std::iota(input.begin(), input.end(), NativeT());
- auto literal = Literal::CreateR1<NativeT>(input);
+ auto literal = LiteralUtil::CreateR1<NativeT>(input);
XlaBuilder builder(TestName());
- auto original = builder.Parameter(0, literal->shape(), "p0");
- builder.Slice(original, {spec.slice_start}, {spec.slice_limit},
- {spec.slice_stride});
+ auto original = Parameter(&builder, 0, literal->shape(), "p0");
+ Slice(original, {spec.slice_start}, {spec.slice_limit},
+ {spec.slice_stride});
// Ditto.
tensorflow::gtl::InlinedVector<NativeT, 1> expected;
@@ -368,12 +368,12 @@ XLA_TEST_P(SliceR2Test, DoIt) {
const R2Spec& spec = GetParam();
Array2D<int32> input(spec.input_dim0, spec.input_dim1);
input.FillUnique();
- auto literal = Literal::CreateR2FromArray2DWithLayout(
+ auto literal = LiteralUtil::CreateR2FromArray2DWithLayout(
input, LayoutUtil::MakeLayout(spec.layout));
XlaBuilder builder(TestName());
- auto a = builder.Parameter(0, literal->shape(), "p0");
- builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
+ auto a = Parameter(&builder, 0, literal->shape(), "p0");
+ Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
client_->TransferToServer(*literal));
@@ -463,13 +463,12 @@ class SliceR4Test : public ClientLibraryTestBase,
auto expected = ReferenceUtil::Slice4D(
values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
XlaBuilder builder(TestName());
- auto literal = Literal::CreateR4FromArray4DWithLayout(
+ auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(
values, LayoutUtil::MakeLayout(spec.input_layout));
- auto parameter = builder.Parameter(0, literal->shape(), "p0");
+ auto parameter = Parameter(&builder, 0, literal->shape(), "p0");
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
client_->TransferToServer(*literal));
- builder.Slice(parameter, spec.slice_starts, spec.slice_limits,
- spec.slice_strides);
+ Slice(parameter, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
}
};
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index dd7c541733..2647937013 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
@@ -110,7 +111,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
MakeFakeLiteralInternal(element_shape, engine));
elements.push_back(std::move(element));
}
- return Literal::MakeTupleOwned(std::move(elements));
+ return LiteralUtil::MakeTupleOwned(std::move(elements));
}
if (engine == nullptr) {
return Literal::CreateFromShape(shape);
@@ -161,6 +162,9 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteralInternal(
}));
break;
}
+ // Token requires no data.
+ case TOKEN:
+ break;
default:
return Unimplemented("Unsupported type for fake literal generation: %s",
ShapeUtil::HumanString(shape).c_str());
@@ -217,7 +221,7 @@ std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
start_indices[i] = generator(*engine);
}
}
- return Literal::CreateR1<int32>(start_indices);
+ return LiteralUtil::CreateR1<int32>(start_indices);
}
// Use dataflow analysis on each parameter to see if there are uses that would
@@ -270,14 +274,22 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr &&
- !ShapeUtil::Equal(needs_index->shape(), use->shape())) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ if (needs_index != nullptr) {
+ auto needs_index_shape = needs_index->shape();
+ auto use_shape = use->shape();
+ if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
+ needs_index_shape = needs_index->operand(0)->shape();
+ }
+ if (use->opcode() == HloOpcode::kDynamicSlice) {
+ use_shape = use->operand(0)->shape();
+ }
+ if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
+ return Unimplemented(
+ "Conflicting operand generation slice index constraints\n");
+ }
}
needs_index = use;
break;
-
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
needs_constant = use;
@@ -307,9 +319,9 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
} else if (needs_constant != nullptr) {
switch (constant_type) {
case ConstantType::kZero:
- return Literal::Zero(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
case ConstantType::kOne:
- return Literal::One(param.shape().element_type()).CloneToUnique();
+ return LiteralUtil::One(param.shape().element_type()).CloneToUnique();
case ConstantType::kUnknown:
// We want the identity element for the computation, but we don't really
// know what it is - so any value we generate will be just as wrong.
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index a8689f6498..e59f215a9a 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <random>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 59afd28a80..8f424ae81f 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -31,16 +32,16 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
XlaBuilder builder(TestName());
// Make the reduction lambda.
Shape single_float = ShapeUtil::MakeShape(F32, {});
- builder.Parameter(0, single_float, "unused");
- builder.Parameter(1, single_float, "used");
+ Parameter(&builder, 0, single_float, "unused");
+ Parameter(&builder, 1, single_float, "used");
auto computation_status = builder.Build();
TF_ASSERT_OK(computation_status.status());
// Make the reduction.
Shape pair_float = ShapeUtil::MakeShape(F32, {2});
- builder.Reduce(builder.Parameter(0, pair_float, "operand"),
- builder.Parameter(1, single_float, "init"),
- computation_status.ValueOrDie(), {0});
+ Reduce(Parameter(&builder, 0, pair_float, "operand"),
+ Parameter(&builder, 1, single_float, "init"),
+ computation_status.ValueOrDie(), {0});
computation_status = builder.Build();
TF_ASSERT_OK(computation_status.status());
@@ -53,5 +54,23 @@ XLA_TEST_F(TestUtilsTest, UnusedParam) {
TF_ASSERT_OK(MakeFakeArguments(&module).status());
}
+XLA_TEST_F(TestUtilsTest, Token) {
+ auto module = ParseHloString(
+ R"(HloModule outfeed_module
+
+ ENTRY InfeedToOutfeed {
+ token = token[] parameter(0)
+ infeed = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
+ outfeed = token[] outfeed(infeed.data, token)
+ ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token)
+ infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
+ infeed.1.token = token[] get-tuple-element(infeed.1), index=1
+ outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
new file mode 100644
index 0000000000..2bdbd08309
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -0,0 +1,206 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <array>
+
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class TokenHloTest : public HloTestBase {};
+
+XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateToken());
+
+ module->AddEntryComputation(builder.Build());
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+}
+
+XLA_TEST_F(TokenHloTest, TokenTree) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto token2 = builder.AddInstruction(HloInstruction::CreateToken());
+ builder.AddInstruction(
+ HloInstruction::CreateAfterAll({token0, token0, token1, token2}));
+
+ module->AddEntryComputation(builder.Build());
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *LiteralUtil::CreateToken()));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 1 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}),
+ "param"));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(HloInstruction::CreateAfterAll({param}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr(
+ "Operands of token instructions must be TOKEN types"));
+}
+
+XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
+ // Thread a token around a while loop. Token is created and consumed by a
+ // AfterAll instruction in the while body.
+ string module_string = R"(
+HloModule TokenInWhileLoop
+
+%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
+ %param.1 = (s32[], token[]) parameter(0)
+ %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
+ %constant.1 = s32[] constant(1)
+ %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
+ %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
+ %after-all = token[] after-all(token[] %get-tuple-element.2)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
+}
+
+%Cond (param: (s32[], token[])) -> pred[] {
+ %param = (s32[], token[]) parameter(0)
+ %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
+ %constant = s32[] constant(42)
+ ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant)
+}
+
+ENTRY %TokenInWhileLoop () -> s32[] {
+ %zero = s32[] constant(0)
+ %init_token = token[] after-all()
+ %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
+ %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
+}
+)";
+
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ // Module DCE pass removes the generate token instructions.
+ debug_options.add_xla_disable_hlo_passes("hlo-module-dce");
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(module_string, debug_options));
+
+ EXPECT_TRUE(RunAndCompare(std::move(module), error_spec_));
+}
+
+XLA_TEST_F(TokenHloTest, TokenInConditional) {
+ string module_string = R"(
+HloModule TokenInConditional
+
+%True (param.1: token[]) -> (s32[], token[]) {
+ %param.1 = token[] parameter(0)
+ %forty_two = s32[] constant(42)
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %forty_two, token[] %param.1)
+}
+
+%False (param.2: s32[]) -> (s32[], token[]) {
+ %param.2 = s32[] parameter(0)
+ %new_token = token[] after-all()
+ ROOT %tuple = (s32[], token[]) tuple(s32[] %param.2, token[] %new_token)
+}
+
+ENTRY %TokenInConditional (param.3: pred[]) -> s32[] {
+ %param.3 = pred[] parameter(0)
+ %init_token = token[] after-all()
+ %seven = s32[] constant(7)
+ %cond = (s32[], token[]) conditional(pred[] %param.3, token[] %init_token, s32[] %seven), true_computation=True, false_computation=False
+ ROOT %root = s32[] get-tuple-element((s32[], token[]) %cond), index=0
+}
+)";
+
+ DebugOptions debug_options = GetDebugOptionsForTest();
+ // Module DCE pass removes the generate token instructions.
+ debug_options.add_xla_disable_hlo_passes("hlo-module-dce");
+
+ {
+ // True case.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(module_string, debug_options));
+ auto arg = LiteralUtil::CreateR0<bool>(true);
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {arg.get()}));
+ EXPECT_EQ(42, result->Get<int32>({}));
+ }
+
+ {
+ // False case.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(module_string, debug_options));
+ auto arg = LiteralUtil::CreateR0<bool>(false);
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
+ Execute(std::move(module), {arg.get()}));
+ EXPECT_EQ(7, result->Get<int32>({}));
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
index 0063e7ad41..0f86b7f20f 100644
--- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc
+++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -41,7 +42,12 @@ class TransferManagerTest : public LocalClientTestBase {
TransferManagerTest()
: shape_size_fn_([this](const Shape& shape) {
return transfer_manager_->GetByteSizeRequirement(shape);
- }) {}
+ }) {
+ stream_ptr_ = local_client_->mutable_backend()
+ ->BorrowStream(stream_executor_)
+ .ValueOrDie();
+ stream_ = stream_ptr_.get();
+ }
~TransferManagerTest() override = default;
@@ -53,37 +59,41 @@ class TransferManagerTest : public LocalClientTestBase {
.ValueOrDie();
}
+ protected:
+ Backend::StreamPtr stream_ptr_;
+ se::Stream* stream_;
+
private:
std::function<int64(const Shape&)> shape_size_fn_;
};
XLA_TEST_F(TransferManagerTest, TransferR0U32) {
- std::unique_ptr<Literal> literal = Literal::CreateR0<uint32>(42);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR0Equal<uint32>(42, *result);
}
XLA_TEST_F(TransferManagerTest, TransferR1F32) {
std::unique_ptr<Literal> literal =
- Literal::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
+ LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>({1.25f, 2.5f, -17.0f, -20.125f},
*result);
@@ -92,48 +102,48 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) {
XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) {
std::vector<float> test_vector(1024 * 1024);
std::iota(test_vector.begin(), test_vector.end(), 0);
- std::unique_ptr<Literal> literal = Literal::CreateR1<float>(test_vector);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(test_vector);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR1Equal<float>(test_vector, *result);
}
XLA_TEST_F(TransferManagerTest, TransferR1U8) {
const char* test_string = "0123456789abcdef";
- std::unique_ptr<Literal> literal = Literal::CreateR1U8(test_string);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(test_string);
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_EQ(result->GetR1U8AsString(), test_string);
}
XLA_TEST_F(TransferManagerTest, TransferR2F32) {
std::unique_ptr<Literal> literal =
- Literal::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
const Shape& shape = literal->shape();
auto device_buffer = AllocateDeviceBuffer(shape);
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
LiteralTestUtil::ExpectR2Equal<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result);
@@ -141,7 +151,7 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) {
XLA_TEST_F(TransferManagerTest,
TransferR2F32AndChangeLayoutTransferringToDevice) {
- std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2WithLayout<float>(
{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1}));
const Shape ondevice_shape =
ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});
@@ -149,11 +159,11 @@ XLA_TEST_F(TransferManagerTest,
// Round trip literal through device. Set the on-device layout to something
// different than the literal layout.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_FALSE(
LayoutUtil::Equal(result->shape().layout(), literal->shape().layout()));
@@ -162,89 +172,237 @@ XLA_TEST_F(TransferManagerTest,
}
XLA_TEST_F(TransferManagerTest, TransferTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR0<float>(123.0f).get(),
- Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple({});
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR0<float>(123.0f).get(),
- Literal::MakeTuple(
- {Literal::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
- Literal::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
.get(),
- Literal::CreateR1<float>({-10.0f, 123.0f}).get()});
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
- std::unique_ptr<Literal> literal = Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
- std::unique_ptr<Literal> literal = Literal::MakeTuple(
- {Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<complex64>(
{complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)})
.get(),
- Literal::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
- Literal::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
+ LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6}).get(),
+ LiteralUtil::CreateR0<complex64>(complex64(0.3f, -0.4f)).get()});
auto device_buffer = AllocateDeviceBuffer(literal->shape());
// Round trip literal through device.
- ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(
- stream_executor_, *literal, device_buffer));
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result,
- transfer_manager_->TransferLiteralFromDevice(
- stream_executor_, device_buffer));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
+XLA_TEST_F(TransferManagerTest, TransferTokenFromDevice) {
+ // "Copy" a token from the device. The token has no physical representation so
+ // no copying is actually performed, but it shouldn't fail.
+ // TODO(b/110532604): Add transferring the token to device when this is
+ // supported.
+ auto device_buffer = AllocateDeviceBuffer(ShapeUtil::MakeTokenShape());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*LiteralUtil::CreateToken(), *result));
+}
+
+XLA_TEST_F(TransferManagerTest, MultiStreamRoundTripSoak) {
+ const int64 kIterationCount = 5000;
+ std::unique_ptr<Literal> literal1 = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(123.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -10.0f, 3333333.3f}).get()})
+ .get(),
+ LiteralUtil::CreateR1<float>({-10.0f, 123.0f}).get()});
+ std::unique_ptr<Literal> literal2 = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(456.0f).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{5.0f, 7.0f}, {9.0f, 4.0f}}).get(),
+ LiteralUtil::CreateR1<float>({44.0f, -11.0f, 3333333.3f}).get()})
+ .get(),
+ LiteralUtil::CreateR1<float>({-98.0f, 153.0f}).get()});
+
+ auto device_buffer1 = AllocateDeviceBuffer(literal1->shape());
+ auto device_buffer2 = AllocateDeviceBuffer(literal2->shape());
+
+ auto stream1 = stream_;
+ auto stream2 = stream_->GetOrCreateSubStream();
+
+ std::unique_ptr<Literal> result1, result2;
+
+ // Round trip literals through device in multiple streams asynchronously.
+ for (int i = 0; i < kIterationCount; ++i) {
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream1, *literal1,
+ device_buffer1));
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice(stream2, *literal2,
+ device_buffer2));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> this_result1,
+ transfer_manager_->TransferLiteralFromDevice(stream1, device_buffer1));
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> this_result2,
+ transfer_manager_->TransferLiteralFromDevice(stream2, device_buffer2));
+ result1 = std::move(this_result1);
+ result2 = std::move(this_result2);
+ }
+
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal1, *result1));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal2, *result2));
+}
+
+class TransferDeviceToHostBenchmark : public TransferManagerTest {
+ public:
+ using TransferManagerTest::TransferManagerTest;
+ ~TransferDeviceToHostBenchmark() override {}
+
+ void Run(int iters, int num_tuple_elements, int array_size) {
+ tensorflow::testing::StopTiming();
+ SetUp();
+
+ std::vector<std::unique_ptr<Literal>> tuple_elements;
+ for (int i = 0; i < num_tuple_elements; ++i) {
+ tuple_elements.push_back(
+ LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ }
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Literal> result,
+ transfer_manager_->TransferLiteralFromDevice(stream_, device_buffer));
+ }
+ tensorflow::testing::StopTiming();
+ TearDown();
+ }
+
+ void TestBody() override {}
+};
+
+class TransferHostToDeviceBenchmark : public TransferManagerTest {
+ public:
+ using TransferManagerTest::TransferManagerTest;
+ ~TransferHostToDeviceBenchmark() override {}
+
+ void Run(int iters, int num_tuple_elements, int array_size) {
+ tensorflow::testing::StopTiming();
+ SetUp();
+
+ std::vector<std::unique_ptr<Literal>> tuple_elements;
+ for (int i = 0; i < num_tuple_elements; ++i) {
+ tuple_elements.push_back(
+ LiteralUtil::CreateR2F32Linspace(0.0f, 1.0f, array_size, array_size));
+ }
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::MakeTupleOwned(std::move(tuple_elements));
+ auto device_buffer = AllocateDeviceBuffer(literal->shape());
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(stream_, *literal,
+ device_buffer));
+ }
+ tensorflow::testing::StopTiming();
+ TearDown();
+ }
+
+ void TestBody() override {}
+};
+
+void BM_TransferDeviceToHost(int iters, int num_tuple_elements,
+ int array_size) {
+ TransferDeviceToHostBenchmark bm;
+ bm.Run(iters, num_tuple_elements, array_size);
+}
+
+void BM_TransferHostToDevice(int iters, int num_tuple_elements,
+ int array_size) {
+ TransferHostToDeviceBenchmark bm;
+ bm.Run(iters, num_tuple_elements, array_size);
+}
+
+BENCHMARK(BM_TransferHostToDevice)
+ ->ArgPair(1, 256)
+ ->ArgPair(1, 257)
+ ->ArgPair(100, 256)
+ ->ArgPair(100, 257);
+
+BENCHMARK(BM_TransferDeviceToHost)
+ ->ArgPair(1, 256)
+ ->ArgPair(1, 257)
+ ->ArgPair(100, 256)
+ ->ArgPair(100, 257);
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ tensorflow::testing::RunBenchmarks();
+ return RUN_ALL_TESTS();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc
index fe1e3da7ec..6ebb4324f8 100644
--- a/tensorflow/compiler/xla/tests/transpose_test.cc
+++ b/tensorflow/compiler/xla/tests/transpose_test.cc
@@ -38,34 +38,35 @@ class TransposeTest : public ClientLibraryTestBase {
XLA_TEST_F(TransposeTest, Transpose0x0) {
XlaBuilder builder("Transpose");
- auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- auto result = builder.Transpose(lhs, {1, 0});
+ auto lhs = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 0));
+ Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
}
XLA_TEST_F(TransposeTest, Transpose0x42) {
XlaBuilder builder("Transpose");
- auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42));
- auto result = builder.Transpose(lhs, {1, 0});
+ auto lhs = ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 42));
+ Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_);
}
XLA_TEST_F(TransposeTest, Transpose7x0) {
XlaBuilder builder("Transpose");
- auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0));
- auto result = builder.Transpose(lhs, {1, 0});
+ auto lhs = ConstantR2FromArray2D<float>(&builder, Array2D<float>(7, 0));
+ Transpose(lhs, {1, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_);
}
TEST_F(TransposeTest, Transpose2x2) {
XlaBuilder builder("Transpose");
- auto lhs = builder.ConstantR2<float>({
- {1.0, 2.0}, {3.0, 4.0},
- });
- auto result = builder.Transpose(lhs, {1, 0});
+ auto lhs = ConstantR2<float>(&builder, {
+ {1.0, 2.0},
+ {3.0, 4.0},
+ });
+ Transpose(lhs, {1, 0});
Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}});
@@ -74,16 +75,18 @@ TEST_F(TransposeTest, Transpose2x2) {
XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
XlaBuilder builder("Transpose");
- auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3));
- auto result = builder.Transpose(operand, {1, 2, 0});
+ auto operand =
+ ConstantR3FromArray3D<int32>(&builder, Array3D<int32>(0, 2, 3));
+ Transpose(operand, {1, 2, 0});
ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {});
}
TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
XlaBuilder builder("Transpose");
- auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {1, 2, 0});
+ auto operand =
+ ConstantR3FromArray3D<int32>(&builder, {{{1, 2, 3}, {4, 5, 6}}});
+ Transpose(operand, {1, 2, 0});
Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}});
@@ -92,8 +95,9 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
XlaBuilder builder("Transpose");
- auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {2, 1, 0});
+ auto operand =
+ ConstantR3FromArray3D<int32>(&builder, {{{1, 2, 3}, {4, 5, 6}}});
+ Transpose(operand, {2, 1, 0});
Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}});
@@ -102,8 +106,9 @@ TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
TEST_F(TransposeTest, Transpose1x2x3_1x2x3) {
XlaBuilder builder("Transpose");
- auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
- auto result = builder.Transpose(operand, {0, 1, 2});
+ auto operand =
+ ConstantR3FromArray3D<int32>(&builder, {{{1, 2, 3}, {4, 5, 6}}});
+ Transpose(operand, {0, 1, 2});
Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}});
@@ -116,9 +121,9 @@ TEST_F(TransposeTest, MultiTranspose3x2) {
for (int transposes = 0; transposes <= 10; ++transposes) {
XlaBuilder builder("Transpose");
- auto computed = builder.ConstantR2FromArray2D<float>(input);
+ auto computed = ConstantR2FromArray2D<float>(&builder, input);
for (int i = 0; i < transposes; ++i) {
- computed = builder.Transpose(computed, {1, 0});
+ computed = Transpose(computed, {1, 0});
}
const Array2D<float>& expected = transposes % 2 == 0 ? input : transposed;
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
@@ -130,8 +135,8 @@ TEST_F(TransposeTest, Small_1x1) {
auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1);
XlaBuilder builder("transpose_1x1");
- auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
- builder.Transpose(operand, {1, 0});
+ auto operand = ConstantR2FromArray2D<float>(&builder, *aoperand);
+ Transpose(operand, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
@@ -142,8 +147,8 @@ TEST_F(TransposeTest, Small_2x2) {
auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2);
XlaBuilder builder("transpose_2x2");
- auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
- builder.Transpose(operand, {1, 0});
+ auto operand = ConstantR2FromArray2D<float>(&builder, *aoperand);
+ Transpose(operand, {1, 0});
auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
@@ -162,8 +167,8 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) {
}
XlaBuilder builder(TestName());
- auto operand = builder.ConstantR3FromArray3D(aoperand);
- builder.Transpose(operand, {0, 2, 1});
+ auto operand = ConstantR3FromArray3D(&builder, aoperand);
+ Transpose(operand, {0, 2, 1});
ComputeAndCompareR3<int32>(&builder, expected, {});
}
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 41189231b9..bf86c5dfb6 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -49,12 +49,12 @@ XLA_TEST_F(TupleTest, TupleConstant) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get(),
- Literal::CreateR2<float>(constant_matrix).get()});
+ auto value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
- builder.ConstantLiteral(*value);
+ ConstantLiteral(&builder, *value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
@@ -64,11 +64,11 @@ XLA_TEST_F(TupleTest, TupleScalarConstant) {
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
- auto value =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
- Literal::CreateR0<float>(constant_scalar2).get()});
+ auto value = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar1).get(),
+ LiteralUtil::CreateR0<float>(constant_scalar2).get()});
- builder.ConstantLiteral(*value);
+ ConstantLiteral(&builder, *value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
@@ -82,14 +82,14 @@ XLA_TEST_F(TupleTest, TupleCreate) {
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- builder.Tuple({builder.ConstantR0<float>(constant_scalar),
- builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
-
- auto expected =
- Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
- Literal::CreateR1<float>(constant_vector).get(),
- Literal::CreateR2<float>(constant_matrix).get()});
+ Tuple(&builder, {ConstantR0<float>(&builder, constant_scalar),
+ ConstantR1<float>(&builder, constant_vector),
+ ConstantR2<float>(&builder, constant_matrix)});
+
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -97,19 +97,20 @@ XLA_TEST_F(TupleTest, TupleCreate) {
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
XlaBuilder builder(TestName());
- builder.Tuple(
- {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
+ Tuple(&builder,
+ {ConstantR0<float>(&builder, 7.0), ConstantR1<float>(&builder, {})});
- auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
- Literal::CreateR1<float>({}).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
+ LiteralUtil::CreateR1<float>({}).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
// Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
XlaBuilder builder(TestName());
- builder.Tuple({});
- auto expected = Literal::MakeTuple({});
+ Tuple(&builder, {});
+ auto expected = LiteralUtil::MakeTuple({});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -121,9 +122,10 @@ XLA_TEST_F(TupleTest, GetTupleElement) {
{1.f, 2.f, 3.f}, // row 0
{4.f, 5.f, 6.f}, // row 1
};
- auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
- builder.GetTupleElement(tuple_data, 1);
+ auto tuple_data =
+ Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
+ ConstantR2<float>(&builder, constant_matrix)});
+ GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
error_spec_);
}
@@ -131,17 +133,18 @@ XLA_TEST_F(TupleTest, GetTupleElement) {
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
XlaBuilder builder(TestName());
- auto tuple_data = builder.Tuple(
- {builder.ConstantR1<float>({}),
- builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
- builder.GetTupleElement(tuple_data, 1);
+ auto tuple_data =
+ Tuple(&builder,
+ {ConstantR1<float>(&builder, {}),
+ ConstantR2FromArray2D<float>(&builder, Array2D<float>(0, 101))});
+ GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
XlaBuilder builder(TestName());
- auto value = builder.ConstantR1<float>({4.5f});
- builder.GetTupleElement(value, 1);
+ auto value = ConstantR1<float>(&builder, {4.5f});
+ GetTupleElement(value, 1);
auto result_status = builder.Build();
EXPECT_FALSE(result_status.ok());
EXPECT_THAT(
@@ -158,14 +161,15 @@ XLA_TEST_F(TupleTest, AddTupleElements) {
{1.f, 2.f, 3.f}, // row 0
{4.f, 5.f, 6.f}, // row 1
};
- auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
- auto vector_element = builder.GetTupleElement(tuple_data, 0);
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ auto tuple_data =
+ Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
+ ConstantR2<float>(&builder, constant_matrix)});
+ auto vector_element = GetTupleElement(tuple_data, 0);
+ auto matrix_element = GetTupleElement(tuple_data, 1);
auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
- builder.Add(matrix_element, vector_element,
- /*broadcast_dimensions=*/{1});
+ Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{2.f, 4.f, 6.f}, // row 0
@@ -185,13 +189,14 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
{1.f, 2.f, 3.f}, // row 0
{4.f, 5.f, 6.f}, // row 1
};
- auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
- builder.Tuple({builder.GetTupleElement(tuple_data, 1),
- builder.GetTupleElement(tuple_data, 0)});
- auto expected =
- Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
- Literal::CreateR1<float>(constant_vector).get()});
+ auto tuple_data =
+ Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
+ ConstantR2<float>(&builder, constant_matrix)});
+ Tuple(&builder,
+ {GetTupleElement(tuple_data, 1), GetTupleElement(tuple_data, 0)});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>(constant_matrix).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -206,14 +211,14 @@ XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
std::unique_ptr<GlobalData> v2_data =
CreateR0Parameter<float>(1.0f, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&b, /*data_handle=*/&v2);
- auto v1_gt = b.Gt(v1, v2); // false
- auto v2_gt = b.Gt(v2, v1); // true
- auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
- auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
- b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
+ auto v1_gt = Gt(v1, v2); // false
+ auto v2_gt = Gt(v2, v1); // true
+ auto v1_v2 = Tuple(&b, {v1_gt, v2_gt}); // {false, true}
+ auto v2_v1 = Tuple(&b, {v2_gt, v1_gt}); // {true, false}
+ Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
- Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
- Literal::CreateR0<bool>(!direction).get()});
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<bool>(direction).get(),
+ LiteralUtil::CreateR0<bool>(!direction).get()});
ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()},
error_spec_);
@@ -243,22 +248,23 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
{1.f, 2.f, 3.f}, // row 0
{4.f, 5.f, 6.f}, // row 1
};
- auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
- auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0),
- builder.GetTupleElement(tuple_data, 1)});
- auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
- builder.GetTupleElement(tuple_data, 0)});
- auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0);
- auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1);
- auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1);
- auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0);
-
- auto addvectors = builder.Add(vector_from_01, vector_from_10);
- auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
-
- builder.Add(addmatrices, addvectors,
- /*broadcast_dimensions=*/{1});
+ auto tuple_data =
+ Tuple(&builder, {ConstantR1<float>(&builder, constant_vector),
+ ConstantR2<float>(&builder, constant_matrix)});
+ auto new_tuple01 = Tuple(&builder, {GetTupleElement(tuple_data, 0),
+ GetTupleElement(tuple_data, 1)});
+ auto new_tuple10 = Tuple(&builder, {GetTupleElement(tuple_data, 1),
+ GetTupleElement(tuple_data, 0)});
+ auto vector_from_01 = GetTupleElement(new_tuple01, 0);
+ auto vector_from_10 = GetTupleElement(new_tuple10, 1);
+ auto matrix_from_01 = GetTupleElement(new_tuple01, 1);
+ auto matrix_from_10 = GetTupleElement(new_tuple10, 0);
+
+ auto addvectors = Add(vector_from_01, vector_from_10);
+ auto addmatrices = Add(matrix_from_01, matrix_from_10);
+
+ Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{4.f, 8.f, 12.f}, // row 0
@@ -273,14 +279,15 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto tuple12 = builder.Tuple(
- {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
- auto tuple21 = builder.Tuple(
- {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
-
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
- Literal::CreateR1<float>(vec1).get()});
+ auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
+ ConstantR1<float>(&builder, vec2)});
+ auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
+ ConstantR1<float>(&builder, vec1)});
+
+ Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -292,22 +299,22 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
// Need to put a select in there to prevent HLO-level optimizations from
// optimizing out the tuples.
XlaBuilder b("sort_square");
- auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto x2 = b.Mul(x, x);
- auto x_smaller_tuple = b.Tuple({x, x2});
- auto x2_smaller_tuple = b.Tuple({x2, x});
- auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple);
- auto smaller = b.GetTupleElement(sorted, 0);
- auto greater = b.GetTupleElement(sorted, 1);
- b.Add(greater, b.Mul(b.ConstantR0<float>(100.0f), smaller));
+ auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto x2 = Mul(x, x);
+ auto x_smaller_tuple = Tuple(&b, {x, x2});
+ auto x2_smaller_tuple = Tuple(&b, {x2, x});
+ auto sorted = Select(Lt(x, x2), x_smaller_tuple, x2_smaller_tuple);
+ auto smaller = GetTupleElement(sorted, 0);
+ auto greater = GetTupleElement(sorted, 1);
+ Add(greater, Mul(ConstantR0<float>(&b, 100.0f), smaller));
auto computation_status = b.Build();
ASSERT_IS_OK(computation_status.status());
tuple_computation = computation_status.ConsumeValueOrDie();
}
XlaBuilder b(TestName());
- auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
- b.Map({input}, tuple_computation, {0});
+ auto input = ConstantR1<float>(&b, {-1.0f, 1.0f, 2.1f});
+ Map(&b, {input}, tuple_computation, {0});
ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
}
@@ -317,14 +324,15 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto tuple12 = builder.Tuple(
- {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
- auto tuple21 = builder.Tuple(
- {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
-
- builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
- Literal::CreateR1<float>(vec2).get()});
+ auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
+ ConstantR1<float>(&builder, vec2)});
+ auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
+ ConstantR1<float>(&builder, vec1)});
+
+ Select(ConstantR0<bool>(&builder, true), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
+ LiteralUtil::CreateR1<float>(vec2).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -335,14 +343,13 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto tuple12 = builder.Tuple(
- {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
- auto tuple21 = builder.Tuple(
- {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+ auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
+ ConstantR1<float>(&builder, vec2)});
+ auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
+ ConstantR1<float>(&builder, vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
- builder.GetTupleElement(select, 0);
+ auto select = Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
+ GetTupleElement(select, 0);
ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
@@ -371,19 +378,16 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto pred_tuple = builder.Tuple(
- {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)});
- auto tuple12 = builder.Tuple(
- {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
- auto tuple21 = builder.Tuple(
- {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+ auto pred_tuple = Tuple(&builder, {ConstantR0<bool>(&builder, true),
+ ConstantR0<bool>(&builder, false)});
+ auto tuple12 = Tuple(&builder, {ConstantR1<float>(&builder, vec1),
+ ConstantR1<float>(&builder, vec2)});
+ auto tuple21 = Tuple(&builder, {ConstantR1<float>(&builder, vec2),
+ ConstantR1<float>(&builder, vec1)});
- auto select1 =
- builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
- auto select2 =
- builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
- builder.Add(builder.GetTupleElement(select2, 0),
- builder.GetTupleElement(select2, 1));
+ auto select1 = Select(GetTupleElement(pred_tuple, 0), tuple12, tuple21);
+ auto select2 = Select(GetTupleElement(pred_tuple, 1), tuple21, select1);
+ Add(GetTupleElement(select2, 0), GetTupleElement(select2, 1));
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
@@ -395,31 +399,32 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
- auto c1 = builder.ConstantR1<float>(vec1);
- auto c2 = builder.ConstantR1<float>(vec2);
- auto tuple12 = builder.Tuple({c1, c2});
- auto tuple21 = builder.Tuple({c2, c1});
+ auto c1 = ConstantR1<float>(&builder, vec1);
+ auto c2 = ConstantR1<float>(&builder, vec2);
+ auto tuple12 = Tuple(&builder, {c1, c2});
+ auto tuple21 = Tuple(&builder, {c2, c1});
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ Select(ConstantR0<bool>(&builder, false), tuple12, tuple21);
- auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
- Literal::CreateR1<float>(vec1).get()});
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
XlaBuilder builder(TestName());
- auto inner_tuple = builder.Tuple(
- {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
- builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+ auto inner_tuple = Tuple(&builder, {ConstantR1<float>(&builder, {1.0, 2.0}),
+ ConstantR0<float>(&builder, 42.0)});
+ Tuple(&builder, {inner_tuple, ConstantR1<float>(&builder, {22.0, 44.0})});
- auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
- auto expected_s = Literal::CreateR0<float>(42.0);
+ auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
+ auto expected_s = LiteralUtil::CreateR0<float>(42.0);
auto expected_inner_tuple =
- Literal::MakeTuple({expected_v1.get(), expected_s.get()});
- auto expected_v2 = Literal::CreateR1<float>({22.0, 44.0});
+ LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
auto expected =
- Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+ LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
@@ -432,21 +437,21 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
Shape outer_tuple_shape =
ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape});
- auto input = builder.Parameter(0, outer_tuple_shape, "input");
- auto gte0 = builder.GetTupleElement(input, 0);
- auto gte1 = builder.GetTupleElement(gte0, 1);
- builder.Add(gte1, builder.ConstantR1<float>({10.0, 11.0, 12.0}));
+ auto input = Parameter(&builder, 0, outer_tuple_shape, "input");
+ auto gte0 = GetTupleElement(input, 0);
+ auto gte1 = GetTupleElement(gte0, 1);
+ Add(gte1, ConstantR1<float>(&builder, {10.0, 11.0, 12.0}));
std::unique_ptr<GlobalData> data =
client_
- ->TransferToServer(*Literal::MakeTuple({
- Literal::MakeTuple(
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::MakeTuple(
{
- Literal::CreateR1<float>({1.0, 2.0, 3.0}).get(),
- Literal::CreateR1<float>({4.0, 5.0, 6.0}).get(),
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
})
.get(),
- Literal::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
}))
.ConsumeValueOrDie();
@@ -463,25 +468,26 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2});
Shape arg0_shape = ShapeUtil::MakeTupleShape(
{c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})});
- auto input0 = builder.Parameter(0, arg0_shape, "input0");
- auto t0 = builder.GetTupleElement(input0, 0);
- auto t1 = builder.GetTupleElement(input0, 1);
- auto t10 = builder.GetTupleElement(t1, 0);
- auto t11 = builder.GetTupleElement(t1, 1);
- auto sum = builder.Add(builder.Add(t10, t11, {1}), t0);
- auto input1 = builder.Parameter(1, c64r1, "input1");
- auto prod = builder.Mul(input1, sum, {1});
- builder.Tuple({builder.Tuple({prod, sum}),
- builder.ConstantR0<complex64>({123, 456})});
+ auto input0 = Parameter(&builder, 0, arg0_shape, "input0");
+ auto t0 = GetTupleElement(input0, 0);
+ auto t1 = GetTupleElement(input0, 1);
+ auto t10 = GetTupleElement(t1, 0);
+ auto t11 = GetTupleElement(t1, 1);
+ auto sum = Add(Add(t10, t11, {1}), t0);
+ auto input1 = Parameter(&builder, 1, c64r1, "input1");
+ auto prod = Mul(input1, sum, {1});
+ Tuple(&builder, {Tuple(&builder, {prod, sum}),
+ ConstantR0<complex64>(&builder, {123, 456})});
}
std::unique_ptr<GlobalData> arg0 =
client_
- ->TransferToServer(*Literal::MakeTuple(
- {Literal::CreateR0<complex64>({1, 2}).get(),
- Literal::MakeTuple(
- {Literal::CreateR1<complex64>({{10, 20}, {30, 40}}).get(),
- Literal::CreateR2<complex64>(
+ ->TransferToServer(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<complex64>({1, 2}).get(),
+ LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR1<complex64>({{10, 20}, {30, 40}})
+ .get(),
+ LiteralUtil::CreateR2<complex64>(
{{{100, 200}, {300, 400}},
{{1000, 2000}, {3000, 4000}},
{{10000, 20000}, {30000, 40000}}})
@@ -490,11 +496,13 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
.ConsumeValueOrDie();
std::unique_ptr<GlobalData> arg1 =
client_
- ->TransferToServer(*Literal::CreateR1<complex64>({{1, 2}, {1, -2}}))
+ ->TransferToServer(
+ *LiteralUtil::CreateR1<complex64>({{1, 2}, {1, -2}}))
.ConsumeValueOrDie();
- auto sum = Literal::CreateR2<complex64>({{{111, 222}, {331, 442}},
- {{1011, 2022}, {3031, 4042}},
- {{10011, 20022}, {30031, 40042}}});
+ auto sum =
+ LiteralUtil::CreateR2<complex64>({{{111, 222}, {331, 442}},
+ {{1011, 2022}, {3031, 4042}},
+ {{10011, 20022}, {30031, 40042}}});
auto prod = MakeUnique<Literal>(sum->shape());
ASSERT_TRUE(prod->Populate<complex64>(
[&sum](tensorflow::gtl::ArraySlice<int64> indexes) {
@@ -504,9 +512,9 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
: complex64(1, -2));
})
.ok());
- auto expected =
- Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(),
- Literal::CreateR0<complex64>({123, 456}).get()});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::MakeTuple({prod.get(), sum.get()}).get(),
+ LiteralUtil::CreateR0<complex64>({123, 456}).get()});
ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()},
error_spec_);
}
@@ -529,11 +537,12 @@ XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
auto module =
HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
.ValueOrDie();
- auto param = Literal::MakeTupleOwned(Literal::CreateR1<float>({1, 2, 3}));
+ auto param =
+ LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1<float>({1, 2, 3}));
auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
- *result,
- *Literal::MakeTupleOwned(Literal::CreateR2<float>({{1, 2, 3}}))));
+ *LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR2<float>({{1, 2, 3}})),
+ *result));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index c3abe22797..a90a6fb0a5 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -38,8 +38,8 @@ class UnaryOpTest : public ClientLibraryTestBase {
template <typename T>
void AbsSize0TestHelper() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<T>({});
- auto abs = builder.Abs(arg);
+ auto arg = ConstantR1<T>(&builder, {});
+ Abs(arg);
if (primitive_util::NativeToPrimitiveType<T>() == C64) {
ComputeAndCompareR1<float>(&builder, {}, {});
@@ -51,8 +51,8 @@ class UnaryOpTest : public ClientLibraryTestBase {
template <typename T>
void AbsTestHelper() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()});
- auto abs = builder.Abs(arg);
+ auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123, inf<T>(), -inf<T>()});
+ Abs(arg);
ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
}
@@ -60,9 +60,9 @@ class UnaryOpTest : public ClientLibraryTestBase {
template <typename T>
void SignTestHelper() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<T>(
- {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
- auto sign = builder.Sign(arg);
+ auto arg = ConstantR1<T>(
+ &builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
+ Sign(arg);
ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
}
@@ -70,10 +70,10 @@ class UnaryOpTest : public ClientLibraryTestBase {
template <typename T>
void SignAbsTestHelper() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<T>({-2, 25, 0, -123});
- auto sign = builder.Sign(arg);
- auto abs = builder.Abs(arg);
- builder.Sub(builder.Mul(sign, abs), arg);
+ auto arg = ConstantR1<T>(&builder, {-2, 25, 0, -123});
+ auto sign = Sign(arg);
+ auto abs = Abs(arg);
+ Sub(Mul(sign, abs), arg);
ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
}
@@ -92,27 +92,28 @@ int64 UnaryOpTest::inf<int64>() {
template <>
void UnaryOpTest::AbsTestHelper<complex64>() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<complex64>({{-2, 0},
- {0, 25},
- {0, 0},
- {-0.3f, 0.4f},
- {0, inf<float>()},
- {-inf<float>(), 0}});
- auto abs = builder.Abs(arg);
+ auto arg = ConstantR1<complex64>(&builder, {{-2, 0},
+ {0, 25},
+ {0, 0},
+ {-0.3f, 0.4f},
+ {0, inf<float>()},
+ {-inf<float>(), 0}});
+ Abs(arg);
std::unique_ptr<Literal> expected =
- Literal::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
+ LiteralUtil::CreateR1<float>({2, 25, 0, 0.5, inf<float>(), inf<float>()});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
template <>
void UnaryOpTest::SignTestHelper<complex64>() {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<complex64>(
+ auto arg = ConstantR1<complex64>(
+ &builder,
{{-2, 0}, {0, 25}, {0, 0}, {static_cast<float>(-0.0), 0}, {-1, 1}});
- auto sign = builder.Sign(arg);
+ Sign(arg);
- std::unique_ptr<Literal> expected = Literal::CreateR1<complex64>(
+ std::unique_ptr<Literal> expected = LiteralUtil::CreateR1<complex64>(
{{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -121,13 +122,13 @@ template <>
void UnaryOpTest::SignAbsTestHelper<complex64>() {
XlaBuilder builder(TestName());
auto arg =
- builder.ConstantR1<complex64>({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
- auto sign = builder.Sign(arg);
- auto abs = builder.Abs(arg);
- builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg);
+ ConstantR1<complex64>(&builder, {{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}});
+ auto sign = Sign(arg);
+ auto abs = Abs(arg);
+ Sub(Mul(sign, ConvertElementType(abs, C64)), arg);
std::unique_ptr<Literal> expected =
- Literal::CreateR1<complex64>({0, 0, 0, 0});
+ LiteralUtil::CreateR1<complex64>({0, 0, 0, 0});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -145,37 +146,34 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) {
XLA_TEST_F(UnaryOpTest, AbsTestR0) {
XlaBuilder builder(TestName());
- auto argi = builder.ConstantR0<int>(-5);
- auto absi = builder.Abs(argi);
- auto argf = builder.ConstantR0<float>(-3.0f);
- auto absf = builder.Abs(argf);
- auto argf0 = builder.ConstantR0<float>(-0.0f);
- auto absf0 = builder.Abs(argf0);
- auto argc = builder.ConstantR0<complex64>({-0.3f, 0.4f});
- auto absc = builder.Abs(argc);
- builder.Add(builder.Add(absc, absf0),
- builder.Add(absf, builder.ConvertElementType(absi, F32)));
+ auto argi = ConstantR0<int>(&builder, -5);
+ auto absi = Abs(argi);
+ auto argf = ConstantR0<float>(&builder, -3.0f);
+ auto absf = Abs(argf);
+ auto argf0 = ConstantR0<float>(&builder, -0.0f);
+ auto absf0 = Abs(argf0);
+ auto argc = ConstantR0<complex64>(&builder, {-0.3f, 0.4f});
+ auto absc = Abs(argc);
+ Add(Add(absc, absf0), Add(absf, ConvertElementType(absi, F32)));
ComputeAndCompareR0<float>(&builder, 8.5f, {});
}
XLA_TEST_F(UnaryOpTest, SignTestR0) {
XlaBuilder builder(TestName());
- auto argi = builder.ConstantR0<int>(-5);
- auto sgni = builder.Sign(argi); // -1
- auto argf = builder.ConstantR0<float>(-4.0f);
- auto sgnf = builder.Sign(argf); // -1
- auto argf0 = builder.ConstantR0<float>(-0.0f);
- auto sgnf0 = builder.Sign(argf0); // 0
- auto argc = builder.ConstantR0<complex64>({-.3, .4});
- auto sgnc = builder.Sign(argc); // (-.6, .8)
- builder.Add(sgnc, builder.ConvertElementType(
- builder.Add(builder.Add(sgnf0, sgnf),
- builder.ConvertElementType(sgni, F32)),
- C64));
+ auto argi = ConstantR0<int>(&builder, -5);
+ auto sgni = Sign(argi); // -1
+ auto argf = ConstantR0<float>(&builder, -4.0f);
+ auto sgnf = Sign(argf); // -1
+ auto argf0 = ConstantR0<float>(&builder, -0.0f);
+ auto sgnf0 = Sign(argf0); // 0
+ auto argc = ConstantR0<complex64>(&builder, {-.3, .4});
+ auto sgnc = Sign(argc); // (-.6, .8)
+ Add(sgnc, ConvertElementType(
+ Add(Add(sgnf0, sgnf), ConvertElementType(sgni, F32)), C64));
std::unique_ptr<Literal> expected =
- Literal::CreateR0<complex64>({-2.6f, 0.8f});
+ LiteralUtil::CreateR0<complex64>({-2.6f, 0.8f});
ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f));
}
@@ -194,9 +192,9 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) {
XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<unsigned int>(
- {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- auto abs = builder.Abs(arg);
+ auto arg = ConstantR1<unsigned int>(
+ &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ Abs(arg);
ComputeAndCompareR1<unsigned int>(
&builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
@@ -204,37 +202,37 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR1<unsigned int>(
- {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
- auto sign = builder.Sign(arg);
+ auto arg = ConstantR1<unsigned int>(
+ &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ Sign(arg);
ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
}
XLA_TEST_F(UnaryOpTest, SignAbsTestR2) {
XlaBuilder builder(TestName());
- auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
- auto sign = builder.Sign(arg);
- auto abs = builder.Abs(arg);
- builder.Sub(builder.Mul(sign, abs), arg);
+ auto arg = ConstantR2<float>(&builder, {{1.0, -2.0}, {-3.0, 4.0}});
+ auto sign = Sign(arg);
+ auto abs = Abs(arg);
+ Sub(Mul(sign, abs), arg);
ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
}
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({0, 1});
- auto rhs = builder.ConstantR1<int32>({1, 1});
- builder.ConvertElementType(builder.Eq(lhs, rhs), S32);
+ auto lhs = ConstantR1<int32>(&builder, {0, 1});
+ auto rhs = ConstantR1<int32>(&builder, {1, 1});
+ ConvertElementType(Eq(lhs, rhs), S32);
ComputeAndCompareR1<int32>(&builder, {0, 1}, {});
}
XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) {
XlaBuilder builder(TestName());
- auto lhs = builder.ConstantR1<int32>({0, 1});
- auto rhs = builder.ConstantR1<int32>({1, 1});
- builder.ConvertElementType(builder.Eq(lhs, rhs), F32);
+ auto lhs = ConstantR1<int32>(&builder, {0, 1});
+ auto rhs = ConstantR1<int32>(&builder, {1, 1});
+ ConvertElementType(Eq(lhs, rhs), F32);
ComputeAndCompareR1<float>(&builder, {0.0, 1.0}, {});
}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
index 82d301983f..ea3aba6df1 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -46,7 +46,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase {
{{1.0, 2.0, 3.0}, // } plane 2 in dim 0
{4.0, 5.0, 6.0}}});
// clang-format on
- return builder_.ConstantR3FromArray3D<float>(x3d);
+ return ConstantR3FromArray3D<float>(&builder_, x3d);
}
XlaBuilder builder_;
@@ -56,11 +56,10 @@ class VecOpsReduceTest : public ClientLibraryTestBase {
TEST_F(VecOpsReduceTest, AddReduceR1F32) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
- auto x = builder_.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ auto x = ConstantR1<float>(
+ &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, -4.2f, {}, errspec_);
}
@@ -71,10 +70,9 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) {
std::vector<float> input(3000);
std::iota(input.begin(), input.end(), 100.0f);
- auto x = builder_.ConstantR1<float>(input);
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ auto x = ConstantR1<float>(&builder_, input);
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
float expected = std::accumulate(input.begin(), input.end(), 0.0f);
ComputeAndCompareR0<float>(&builder_, expected, {}, errspec_);
@@ -83,11 +81,10 @@ TEST_F(VecOpsReduceTest, AddReduceBigR1F32) {
TEST_F(VecOpsReduceTest, MaxReduceR1F32) {
auto max_reducer = CreateScalarMax();
- auto x = builder_.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto max_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer,
- /*dimensions_to_reduce=*/{0});
+ auto x = ConstantR1<float>(
+ &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, 2.6f, {}, errspec_);
}
@@ -95,11 +92,10 @@ TEST_F(VecOpsReduceTest, MaxReduceR1F32) {
TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) {
auto max_reducer = CreateScalarMax();
- auto x = builder_.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto max_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer,
- /*dimensions_to_reduce=*/{0});
+ auto x = ConstantR1<float>(
+ &builder_, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Reduce(x, ConstantR0<float>(&builder_, 4.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&builder_, 4.0f, {}, errspec_);
}
@@ -108,15 +104,14 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
// clang-format off
- auto x = builder_.ConstantR2<float>({
+ auto x = ConstantR2<float>(&builder_, {
{1.0, 2.0, 3.0}, // | dim 0
{4.0, 5.0, 6.0}}); // |
// ------ dim 1 ----------
// clang-format on
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
ComputeAndCompareR1<float>(&builder_, {6.0, 15.0}, {}, errspec_);
}
@@ -125,13 +120,12 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
// clang-format off
- auto x = builder_.ConstantR2<float>({
+ auto x = ConstantR2<float>(&builder_, {
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0}});
// clang-format on
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
ComputeAndCompareR1<float>(&builder_, {5.0, 7.0, 9.0}, {}, errspec_);
}
@@ -139,9 +133,8 @@ TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{2});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{2});
Array2D<float> expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}});
@@ -151,9 +144,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
Array2D<float> expected_array(
{{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}});
@@ -164,9 +156,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
Array2D<float> expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}});
@@ -176,9 +167,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{1, 2});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1, 2});
ComputeAndCompareR1<float>(&builder_, {21.0, 21.0, 21.0}, {}, errspec_);
}
@@ -186,9 +176,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 2});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 2});
ComputeAndCompareR1<float>(&builder_, {18.0, 45.0}, {}, errspec_);
}
@@ -196,9 +185,8 @@ XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 1});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1});
ComputeAndCompareR1<float>(&builder_, {15.0, 21.0, 27.0}, {}, errspec_);
}
@@ -206,9 +194,8 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) {
auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
auto x = BuildSampleConstantCube();
- auto add_reduce =
- builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
- /*dimensions_to_reduce=*/{0, 1, 2});
+ Reduce(x, ConstantR0<float>(&builder_, 0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1, 2});
ComputeAndCompareR0<float>(&builder_, 63.0, {}, errspec_);
}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 5cce7a2bf8..79bae22dac 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -50,9 +50,9 @@ class VecOpsSimpleTest : public ClientLibraryTestBase {
XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto exp = builder.Exp(x);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Exp(x);
std::vector<float> expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02,
8.1662, 9.9742, 6.7379e-03, 4.0657e-01,
@@ -69,8 +69,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int i = 0; i < count; ++i) {
exponents.push_back(i / static_cast<float>(count));
}
- auto x = builder.ConstantR1<float>(exponents);
- auto exp = builder.Exp(x);
+ auto x = ConstantR1<float>(&builder, exponents);
+ Exp(x);
std::vector<float> expected;
expected.reserve(exponents.size());
@@ -98,8 +98,8 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
Array4D<float> expected(2, 2, 2, 2, expected_vector);
- auto x = builder.ConstantR4FromArray4D<float>(exponents);
- auto exp = builder.Exp(x);
+ auto x = ConstantR4FromArray4D<float>(&builder, exponents);
+ Exp(x);
ComputeAndCompareR4<float>(&builder, expected, {},
ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
@@ -107,9 +107,9 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- builder.Neg(x);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Neg(x);
std::vector<float> expected = {-2.1, 2.6, -2.6, 4.0, -2.1,
-2.3, 5.0, 0.9, 2.4, -1.6};
@@ -118,8 +118,8 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
- builder.Neg(x);
+ auto x = ConstantR1<int32>(&builder, {2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
+ Neg(x);
std::vector<int> expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -127,59 +127,19 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<uint32>(
- {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
- builder.Neg(x);
+ auto x = ConstantR1<uint32>(
+ &builder, {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
+ Neg(x);
std::vector<uint32> expected = {0, static_cast<uint32>(-1),
static_cast<uint32>(-42), 1, 12};
ComputeAndCompareR1<uint32>(&builder, expected, {});
}
-XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
- XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- builder.SquareF32(x);
-
- std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
- 5.29, 25., 0.81, 5.76, 2.56};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
- XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- builder.ReciprocalF32(x);
-
- std::vector<float> expected = {
- 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
- 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
- XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>({0.0, -0.0});
- auto exp = builder.SqrtF32(x);
-
- ComputeAndCompareR1<float>(&builder, {0, 0}, {}, error_spec_);
-}
-
-XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
- XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
- auto exp = builder.SqrtF32(x);
-
- std::vector<float> expected = {4, 1, 32, 0.4, 0.4472, 111.1080};
- ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
-}
-
XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
XlaBuilder builder(TestName());
- auto x =
- builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345});
- auto exp = builder.Pow(x, builder.ConstantR0<float>(-.5f));
+ auto x = ConstantR1<float>(&builder,
+ {16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345});
+ Pow(x, ConstantR0<float>(&builder, -.5f));
std::vector<float> expected = {.25, 1, .03125, 2.5,
2.23607, .009000, .900025};
@@ -191,11 +151,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
XlaBuilder builder(TestName());
auto add = CreateScalarAddComputation(F32, &builder);
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto y = builder.ConstantR1<float>(
- {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto max = builder.Map({x, y}, add, {0});
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = ConstantR1<float>(
+ &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ Map(&builder, {x, y}, add, {0});
std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9,
0.1, -6.8, 4., -1., 2.2};
@@ -204,11 +164,11 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto y = builder.ConstantR1<float>(
- {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto max = builder.Max(x, y);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = ConstantR1<float>(
+ &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ Max(x, y);
std::vector<float> expected = {2.1, -0.6, 2.6, 0.2, 3.8,
2.3, -1.8, 4.9, 1.4, 1.6};
@@ -227,7 +187,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
{21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto max = builder.Max(v1, v2);
+ Max(v1, v2);
ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -267,7 +227,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
/*builder=*/&builder, /*data_handle=*/&v2);
- auto max = builder.Max(v1, v2);
+ Max(v1, v2);
ComputeAndCompareR1<float>(&builder, expected_vec,
{param0_data.get(), param1_data.get()},
error_spec_);
@@ -275,10 +235,10 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto y = builder.ConstantR0<float>(0);
- auto max = builder.Max(x, y);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = ConstantR0<float>(&builder, 0);
+ Max(x, y);
std::vector<float> expected = {2.1, 0.0, 2.6, 0.0, 2.1,
2.3, 0.0, 0.0, 0.0, 1.6};
@@ -287,11 +247,11 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto y = builder.ConstantR1<float>(
- {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
- auto min = builder.Min(x, y);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = ConstantR1<float>(
+ &builder, {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ Min(x, y);
std::vector<float> expected = {-0.4, -2.6, -3.0, -4.0, 2.1,
-2.2, -5.0, -0.9, -2.4, 0.6};
@@ -300,11 +260,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR0<float>(0);
- auto one = builder.ConstantR0<float>(1);
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Min(builder.Max(x, zero), one);
+ auto zero = ConstantR0<float>(&builder, 0);
+ auto one = ConstantR0<float>(&builder, 1);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ Min(Max(x, zero), one);
std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
0.9, 0.0, 0.1, 0.0, 0.6};
@@ -313,11 +273,11 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR0<float>(0);
- auto one = builder.ConstantR0<float>(1);
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Clamp(zero, x, one);
+ auto zero = ConstantR0<float>(&builder, 0);
+ auto one = ConstantR0<float>(&builder, 1);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ Clamp(zero, x, one);
std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
0.9, 0.0, 0.1, 0.0, 0.6};
@@ -326,10 +286,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
- auto one = builder.ConstantR1<float>({1.0f, 1.0f});
- auto x = builder.ConstantR1<float>({2.1, -2.6});
- auto clamp = builder.Clamp(zero, x, one);
+ auto zero = ConstantR1<float>(&builder, {0.0f, 0.0f});
+ auto one = ConstantR1<float>(&builder, {1.0f, 1.0f});
+ auto x = ConstantR1<float>(&builder, {2.1, -2.6});
+ Clamp(zero, x, one);
std::vector<float> expected = {1.0, 0.0};
ComputeAndCompareR1<float>(&builder, expected, {});
@@ -337,11 +297,11 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
XlaBuilder builder(TestName());
- auto one = builder.ConstantR0<float>(1);
- auto two = builder.ConstantR0<float>(2);
- auto x = builder.ConstantR1<float>(
- {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
- auto clamp = builder.Clamp(one, x, two);
+ auto one = ConstantR0<float>(&builder, 1);
+ auto two = ConstantR0<float>(&builder, 2);
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ Clamp(one, x, two);
std::vector<float> expected = {2.0, 1.0, 2.0, 1.0, 2.0,
1.0, 1.0, 1.0, 1.0, 1.0};
@@ -350,10 +310,10 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) {
XlaBuilder builder(TestName());
- auto zero = builder.ConstantR0<int64>(0);
- auto one = builder.ConstantR0<int64>(10);
- auto x = builder.ConstantR1<int64>({-3, 3, 9, 13});
- auto clamp = builder.Clamp(zero, x, one);
+ auto zero = ConstantR0<int64>(&builder, 0);
+ auto one = ConstantR0<int64>(&builder, 10);
+ auto x = ConstantR1<int64>(&builder, {-3, 3, 9, 13});
+ Clamp(zero, x, one);
std::vector<int64> expected = {0, 3, 9, 10};
ComputeAndCompareR1<int64>(&builder, expected, {});
@@ -365,9 +325,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
// add_half(x) = x + 0.5
XlaBuilder builder("add_half");
auto x_value =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
- auto half = builder.ConstantR0<float>(0.5);
- builder.Add(x_value, half);
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto half = ConstantR0<float>(&builder, 0.5);
+ Add(x_value, half);
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
add_half = computation_status.ConsumeValueOrDie();
@@ -378,9 +338,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
// clamp(y) = clamp<0,5>(y)
XlaBuilder builder("clamp");
auto y_value =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value");
- auto zero = builder.ConstantR0<float>(0.0);
- auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0<float>(5));
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "y_value");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ Clamp(zero, y_value, ConstantR0<float>(&builder, 5));
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
clamp = computation_status.ConsumeValueOrDie();
@@ -391,13 +351,13 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
// mult_relu_add(z) = clamp(add_half(2 * max(z, 0)))
XlaBuilder builder("mult_relu_add");
auto z_value =
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
- auto zero = builder.ConstantR0<float>(0.0);
- auto two = builder.ConstantR0<float>(2.0);
- auto max = builder.Max(z_value, zero);
- auto mult = builder.Mul(two, max);
- auto inner = builder.Map({mult}, add_half, {});
- builder.Map({inner}, clamp, {});
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "z_value");
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto two = ConstantR0<float>(&builder, 2.0);
+ auto max = Max(z_value, zero);
+ auto mult = Mul(two, max);
+ auto inner = Map(&builder, {mult}, add_half, {});
+ Map(&builder, {inner}, clamp, {});
auto computation_status = builder.Build();
ASSERT_IS_OK(computation_status.status());
mult_relu_add = computation_status.ConsumeValueOrDie();
@@ -405,9 +365,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
XlaBuilder builder("map10");
{
- auto x = builder.ConstantR1<float>(
- {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
- auto activations = builder.Map({x}, mult_relu_add, {0});
+ auto x = ConstantR1<float>(
+ &builder, {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ Map(&builder, {x}, mult_relu_add, {0});
}
std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7,
@@ -417,9 +377,9 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<int32>({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4});
- auto y = builder.ConstantR0<int32>(3);
- builder.Rem(x, y);
+ auto x = ConstantR1<int32>(&builder, {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4});
+ auto y = ConstantR0<int32>(&builder, 3);
+ Rem(x, y);
std::vector<int32> expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1};
ComputeAndCompareR1<int32>(&builder, expected, {});
@@ -427,9 +387,9 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<bool>({false, true});
- auto y = builder.ConstantR1<bool>({true, false});
- builder.Eq(x, y);
+ auto x = ConstantR1<bool>(&builder, {false, true});
+ auto y = ConstantR1<bool>(&builder, {true, false});
+ Eq(x, y);
std::array<bool, 2> expected = {{false, false}};
ComputeAndCompareR1<bool>(&builder, expected, {});
@@ -437,9 +397,9 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) {
XlaBuilder builder(TestName());
- auto x = builder.ConstantR1<bool>({false, true});
- auto y = builder.ConstantR1<bool>({true, false});
- builder.Ne(x, y);
+ auto x = ConstantR1<bool>(&builder, {false, true});
+ auto y = ConstantR1<bool>(&builder, {true, false});
+ Ne(x, y);
std::array<bool, 2> expected = {{true, true}};
ComputeAndCompareR1<bool>(&builder, expected, {});
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index c463f3eac5..29befef92e 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -55,8 +55,8 @@ TEST_F(WhileTest, WhileWithScalarS32Result) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Gt(builder.ConstantR0<int32>(5), prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Gt(ConstantR0<int32>(&builder, 5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -64,16 +64,16 @@ TEST_F(WhileTest, WhileWithScalarS32Result) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR0<int32>(1);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR0<int32>(&builder, 1);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.ConstantR0<int32>(0);
- builder.While(condition, body, init);
+ auto init = ConstantR0<int32>(&builder, 0);
+ While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -91,8 +91,8 @@ TEST_F(WhileTest, WhileWithScalarS64Result) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Gt(builder.ConstantR0<int64>(5), prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Gt(ConstantR0<int64>(&builder, 5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -100,16 +100,16 @@ TEST_F(WhileTest, WhileWithScalarS64Result) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR0<int64>(1);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR0<int64>(&builder, 1);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.ConstantR0<int64>(0);
- builder.While(condition, body, init);
+ auto init = ConstantR0<int64>(&builder, 0);
+ While(condition, body, init);
ComputeAndCompareR0<int64>(&builder, 5, {});
}
@@ -122,8 +122,8 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Gt(builder.ConstantR0<int32>(5), prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Gt(ConstantR0<int32>(&builder, 5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -131,18 +131,18 @@ TEST_F(WhileTest, WhileWithScalarResultNonConstInit) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR0<int32>(1);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR0<int32>(&builder, 1);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
- builder.ConstantR0<int32>(0),
- CreateScalarAddComputation(S32, &builder), {0});
- builder.While(condition, body, init);
+ auto init =
+ Reduce(ConstantR1<int32>(&builder, 2, 1), ConstantR0<int32>(&builder, 0),
+ CreateScalarAddComputation(S32, &builder), {0});
+ While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -154,8 +154,8 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Ne(builder.ConstantR0<bool>(true), prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Ne(ConstantR0<bool>(&builder, true), prev);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -163,16 +163,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Or(prev, builder.ConstantR0<bool>(true));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Or(prev, ConstantR0<bool>(&builder, true));
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.Ne(builder.ConstantR0<bool>(false),
- builder.ConstantR0<bool>(true));
- builder.While(condition, body, init);
+ auto init =
+ Ne(ConstantR0<bool>(&builder, false), ConstantR0<bool>(&builder, true));
+ While(condition, body, init);
ComputeAndCompareR0<bool>(&builder, true, {});
}
@@ -184,17 +184,16 @@ TEST_F(WhileTest, WhileWithPredicateResult) {
// while (result.sum() < 15.5f) {
// result = result + vector<float>(0);
// }
-// TODO(b/29185393): does not terminate on CPU.
-TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
+TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithEmptyVectorResult)) {
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
// Create a computation for the reduction.
XlaComputation add;
{
XlaBuilder builder("add");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Add(x, y);
add = builder.Build().ConsumeValueOrDie();
}
@@ -203,10 +202,10 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
- /*dimensions_to_reduce=*/{0});
- builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ Gt(ConstantR0<float>(&builder, 15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -215,16 +214,16 @@ TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR1<float>({});
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR1<float>(&builder, {});
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.ConstantR1<float>({});
- auto result = builder.While(condition, body, init);
+ auto init = ConstantR1<float>(&builder, {});
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -246,9 +245,9 @@ TEST_F(WhileTest, WhileWithVectorResult) {
XlaComputation add;
{
XlaBuilder builder("add");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Add(x, y);
add = builder.Build().ConsumeValueOrDie();
}
@@ -257,10 +256,10 @@ TEST_F(WhileTest, WhileWithVectorResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
- /*dimensions_to_reduce=*/{0});
- builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ Gt(ConstantR0<float>(&builder, 15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -269,16 +268,16 @@ TEST_F(WhileTest, WhileWithVectorResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR1<float>(8, 0.125f);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR1<float>(&builder, 8, 0.125f);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.ConstantR1<float>(8, 0.f);
- auto result = builder.While(condition, body, init);
+ auto init = ConstantR1<float>(&builder, 8, 0.f);
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -306,9 +305,9 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
XlaComputation add;
{
XlaBuilder builder("add");
- auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
- auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
- builder.Add(x, y);
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {}), "y");
+ Add(x, y);
add = builder.Build().ConsumeValueOrDie();
}
@@ -317,10 +316,10 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
- /*dimensions_to_reduce=*/{0});
- builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto sum = Reduce(prev, ConstantR0<float>(&builder, 0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ Gt(ConstantR0<float>(&builder, 15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -329,27 +328,27 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR1<float>(8, 0.125f);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR1<float>(&builder, 8, 0.125f);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.ConstantR1<float>(8, 0.f);
- auto result = builder.While(condition, body, init);
+ auto init = ConstantR1<float>(&builder, 8, 0.f);
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- builder.Tuple({result});
+ Tuple(&builder, {result});
// Individual elements with increase by 1/8 each time through the loop, so
// the sum will increase by 1.0. It will first be >15.5 when the elements
// have all reached 2.0.
auto expected_data =
- Literal::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
- auto expected = Literal::MakeTuple({expected_data.get()});
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f});
+ auto expected = LiteralUtil::MakeTuple({expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -366,9 +365,9 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(N), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, N), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -377,32 +376,34 @@ TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto w1 = builder.GetTupleElement(prev, 1);
- auto w2 = builder.GetTupleElement(prev, 2);
- auto w3 = builder.GetTupleElement(prev, 3);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto w1 = GetTupleElement(prev, 1);
+ auto w2 = GetTupleElement(prev, 2);
+ auto w3 = GetTupleElement(prev, 3);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
- builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
- auto result = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 3, 1.f),
+ ConstantR1<float>(&builder, 3, 2.f),
+ ConstantR1<float>(&builder, 3, 3.f)});
+ auto result = While(condition, body, init);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(N);
- auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
- auto expected_w2 = Literal::CreateR1<float>({2.0f, 2.0f, 2.0f});
- auto expected_w3 = Literal::CreateR1<float>({3.0f, 3.0f, 3.0f});
- auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(),
- expected_w3.get(), expected_w1.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(N);
+ auto expected_w1 = LiteralUtil::CreateR1<float>({1.0f, 1.0f, 1.0f});
+ auto expected_w2 = LiteralUtil::CreateR1<float>({2.0f, 2.0f, 2.0f});
+ auto expected_w3 = LiteralUtil::CreateR1<float>({3.0f, 3.0f, 3.0f});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_w2.get(),
+ expected_w3.get(), expected_w1.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -419,9 +420,9 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(N), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, N), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -430,26 +431,27 @@ TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto w1 = builder.GetTupleElement(prev, 1);
- auto w2 = builder.GetTupleElement(prev, 2);
- auto w3 = builder.GetTupleElement(prev, 3);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto w1 = GetTupleElement(prev, 1);
+ auto w2 = GetTupleElement(prev, 2);
+ auto w3 = GetTupleElement(prev, 3);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
- builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
- auto xla_while = builder.While(condition, body, init);
-
- auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1),
- builder.GetTupleElement(xla_while, 2));
- auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 3, 1.f),
+ ConstantR1<float>(&builder, 3, 2.f),
+ ConstantR1<float>(&builder, 3, 3.f)});
+ auto xla_while = While(condition, body, init);
+
+ auto add12 =
+ Add(GetTupleElement(xla_while, 1), GetTupleElement(xla_while, 2));
+ auto result = Add(add12, GetTupleElement(xla_while, 3));
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -474,9 +476,9 @@ TEST_F(WhileTest, WhileWithTupleResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, 5), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -486,30 +488,30 @@ TEST_F(WhileTest, WhileWithTupleResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto weights = builder.GetTupleElement(prev, 1);
- auto input = builder.ConstantR1<float>(10, 1.f);
- auto new_weights = builder.Add(weights, input);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto weights = GetTupleElement(prev, 1);
+ auto input = ConstantR1<float>(&builder, 10, 1.f);
+ auto new_weights = Add(weights, input);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
- auto result = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 10, 0.f)});
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -524,9 +526,9 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, 5), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -535,29 +537,28 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto pred = builder.GetTupleElement(prev, 1);
- auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto pred = GetTupleElement(prev, 1);
+ auto new_pred = Or(pred, ConstantR0<bool>(&builder, true));
+ Tuple(&builder, {Add(iteration, ConstantR0<int32>(&builder, 1)), new_pred});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple({builder.ConstantR0<int32>(0),
- builder.Ne(builder.ConstantR0<bool>(false),
- builder.ConstantR0<bool>(true))});
- auto result = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ Ne(ConstantR0<bool>(&builder, false),
+ ConstantR0<bool>(&builder, true))});
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_predicate = Literal::CreateR0<bool>(true);
- auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_predicate.get()});
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_predicate = LiteralUtil::CreateR0<bool>(true);
+ auto expected = LiteralUtil::MakeTuple(
+ {expected_counter.get(), expected_predicate.get()});
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0));
}
@@ -571,9 +572,9 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, 5), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -583,26 +584,26 @@ TEST_F(WhileTest, WhileWithTupleConstantScalarResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
- builder.ConstantR0<int32>(7)});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Tuple(&builder, {Add(iteration, ConstantR0<int32>(&builder, 1)),
+ ConstantR0<int32>(&builder, 7)});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
- auto result = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR0<int32>(&builder, 7)});
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR0<int32>(7);
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR0<int32>(7);
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -632,9 +633,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
const int c1 = 5;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c1));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
@@ -642,9 +643,9 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
const int c2 = 7;
{
XlaBuilder builder("condition2");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c2));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c2));
TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
}
@@ -654,43 +655,43 @@ TEST_F(WhileTest, TwoWhileWithTupleResult) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto weights = builder.GetTupleElement(prev, 1);
- auto input = builder.ConstantR1<float>(10, 1.f);
- auto new_weights = builder.Add(weights, input);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto weights = GetTupleElement(prev, 1);
+ auto input = ConstantR1<float>(&builder, 10, 1.f);
+ auto new_weights = Add(weights, input);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
XlaComputation body2;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto weights = builder.GetTupleElement(prev, 1);
- auto input = builder.ConstantR1<float>(10, 1.f);
- auto new_weights = builder.Add(weights, input);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto weights = GetTupleElement(prev, 1);
+ auto input = ConstantR1<float>(&builder, 10, 1.f);
+ auto new_weights = Add(weights, input);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
- auto while1 = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 10, 0.f)});
+ auto while1 = While(condition, body, init);
- auto while2 = builder.While(condition2, body2, while1);
+ auto while2 = While(condition2, body2, while1);
- auto while_result1 = builder.GetTupleElement(while1, 1);
- auto while_result2 = builder.GetTupleElement(while2, 1);
+ auto while_result1 = GetTupleElement(while1, 1);
+ auto while_result2 = GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
builder.GetShape(while_result2).ConsumeValueOrDie());
- auto result = builder.Add(while_result1, while_result2);
+ auto result = Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -711,9 +712,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
const int c1 = 5;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c1));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
@@ -721,9 +722,9 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
const int c2 = 7;
{
XlaBuilder builder("condition2");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c2));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c2));
TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
}
@@ -733,30 +734,30 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto weights = builder.GetTupleElement(prev, 1);
- auto input = builder.ConstantR1<float>(10, 1.f);
- auto new_weights = builder.Add(weights, input);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto weights = GetTupleElement(prev, 1);
+ auto input = ConstantR1<float>(&builder, 10, 1.f);
+ auto new_weights = Add(weights, input);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
- auto while1 = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 10, 0.f)});
+ auto while1 = While(condition, body, init);
- auto while2 = builder.While(condition2, body, while1);
+ auto while2 = While(condition2, body, while1);
- auto while_result1 = builder.GetTupleElement(while1, 1);
- auto while_result2 = builder.GetTupleElement(while2, 1);
+ auto while_result1 = GetTupleElement(while1, 1);
+ auto while_result2 = GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
builder.GetShape(while_result2).ConsumeValueOrDie());
- auto result = builder.Add(while_result1, while_result2);
+ auto result = Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -778,9 +779,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
const int c1 = 5;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c1));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
@@ -788,9 +789,9 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
const int c2 = 7;
{
XlaBuilder builder("condition2");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(c2));
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, c2));
TF_ASSERT_OK_AND_ASSIGN(condition2, builder.Build());
}
@@ -800,29 +801,29 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- auto weights = builder.GetTupleElement(prev, 1);
- auto input = builder.ConstantR1<float>(10, 1.f);
- auto new_weights = builder.Add(weights, input);
- builder.Tuple(
- {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ auto weights = GetTupleElement(prev, 1);
+ auto input = ConstantR1<float>(&builder, 10, 1.f);
+ auto new_weights = Add(weights, input);
+ Tuple(&builder,
+ {Add(iteration, ConstantR0<int32>(&builder, 1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
- auto while1 = builder.While(condition, body, init);
- auto while2 = builder.While(condition2, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 10, 0.f)});
+ auto while1 = While(condition, body, init);
+ auto while2 = While(condition2, body, init);
- auto while_result1 = builder.GetTupleElement(while1, 1);
- auto while_result2 = builder.GetTupleElement(while2, 1);
+ auto while_result1 = GetTupleElement(while1, 1);
+ auto while_result2 = GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
builder.GetShape(while_result2).ConsumeValueOrDie());
- auto result = builder.Add(while_result1, while_result2);
+ auto result = Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
@@ -844,9 +845,9 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Gt(ConstantR0<int32>(&builder, 5), iteration);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -856,38 +857,37 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
// TupleElement 0
- auto iteration = builder.GetTupleElement(prev, 0);
- auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
+ auto iteration = GetTupleElement(prev, 0);
+ auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
// TupleElement 1
- auto input = builder.GetTupleElement(prev, 1);
+ auto input = GetTupleElement(prev, 1);
// Update.
- auto update = builder.ConvertElementType(builder.Broadcast(out0, {2}), F32);
+ auto update = ConvertElementType(Broadcast(out0, {2}), F32);
// Starts = iteration * 2;
- auto starts = builder.Reshape(
- builder.Mul(iteration, builder.ConstantR0<int32>(2)), {1});
+ auto starts = Reshape(Mul(iteration, ConstantR0<int32>(&builder, 2)), {1});
// UpdateSlice.
- auto out1 = builder.DynamicUpdateSlice(input, update, starts);
+ auto out1 = DynamicUpdateSlice(input, update, starts);
- builder.Tuple({out0, out1});
+ Tuple(&builder, {out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder("while");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
- auto result = builder.While(condition, body, init);
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ ConstantR1<float>(&builder, 10, 0.f)});
+ auto result = While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
builder.GetShape(result).ConsumeValueOrDie());
- auto expected_counter = Literal::CreateR0<int32>(5);
- auto expected_data = Literal::CreateR1<float>(
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
{1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f});
auto expected =
- Literal::MakeTuple({expected_counter.get(), expected_data.get()});
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
@@ -913,10 +913,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
// Create a computation for the condition: repeat for count iterations.
auto build_condition = [this, v6s32](int count) {
XlaBuilder builder(TestName());
- auto prev = builder.Reshape(
- builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
- {});
- builder.Gt(builder.ConstantR0<int32>(count), prev);
+ auto prev = Reshape(
+ Slice(Parameter(&builder, 0, v6s32, "prev"), {0}, {1}, {1}), {0}, {});
+ Gt(ConstantR0<int32>(&builder, count), prev);
return builder.Build().ConsumeValueOrDie();
};
@@ -924,22 +923,22 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, v6s32, "prev");
- auto inc = builder.ConcatInDim(
- {builder.ConstantR1<int32>({1}),
- builder.RngUniform(builder.ConstantR0<int32>(0),
- builder.ConstantR0<int32>(100),
- ShapeUtil::MakeShape(S32, {5}))},
- 0);
- builder.Add(inc, prev);
+ auto prev = Parameter(&builder, 0, v6s32, "prev");
+ auto inc = ConcatInDim(&builder,
+ {ConstantR1<int32>(&builder, {1}),
+ RngUniform(ConstantR0<int32>(&builder, 0),
+ ConstantR0<int32>(&builder, 100),
+ ShapeUtil::MakeShape(S32, {5}))},
+ 0);
+ Add(inc, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto while_loop = [this, &body, build_condition](int count) {
XlaBuilder builder(TestName());
- auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
- builder.While(build_condition(count), body, init);
+ auto init = ConstantR1<int32>(&builder, {0, 0, 0, 0, 0, 0});
+ While(build_condition(count), body, init);
return builder.Build();
};
@@ -958,33 +957,30 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) {
auto element_shape = ShapeUtil::MakeShape(F32, {2});
XlaBuilder outer("outer");
- auto p = outer.Parameter(0, element_shape, "param");
- auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})});
+ auto p = Parameter(&outer, 0, element_shape, "param");
+ auto t = Tuple(&outer, {p, ConstantR1<float>(&outer, {1, 1})});
TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t));
XlaBuilder cond("cond");
- auto cond_t = cond.Parameter(0, tuple_shape, "t");
- TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
- cond.ConstantR1<float>({42, 42})),
- &cond)
- .status());
+ auto cond_t = Parameter(&cond, 0, tuple_shape, "t");
+ Any(Eq(GetTupleElement(cond_t, 0), ConstantR1<float>(&cond, {42, 42})));
XlaBuilder body("body");
- auto body_t = body.Parameter(0, tuple_shape, "t");
- auto e = body.GetTupleElement(body_t, 1);
- body.Tuple({e, e});
+ auto body_t = Parameter(&body, 0, tuple_shape, "t");
+ auto e = GetTupleElement(body_t, 1);
+ Tuple(&body, {e, e});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
- outer.While(cond_computation, body_computation, t);
+ While(cond_computation, body_computation, t);
- auto expected_element = Literal::CreateR1<float>({1, 1});
+ auto expected_element = LiteralUtil::CreateR1<float>({1, 1});
auto expected =
- Literal::MakeTuple({expected_element.get(), expected_element.get()});
+ LiteralUtil::MakeTuple({expected_element.get(), expected_element.get()});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -993,24 +989,23 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) {
auto element_shape = ShapeUtil::MakeShape(F32, {2});
XlaBuilder outer("outer");
- auto p = outer.Parameter(0, element_shape, "param");
+ auto p = Parameter(&outer, 0, element_shape, "param");
XlaBuilder cond("cond");
- auto cond_t = cond.Parameter(0, element_shape, "t");
- TF_ASSERT_OK(
- Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status());
+ auto cond_t = Parameter(&cond, 0, element_shape, "t");
+ Any(Eq(cond_t, ConstantR1<float>(&cond, {42, 42})));
XlaBuilder body("body");
- auto body_t = body.Parameter(0, element_shape, "t");
- auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2});
+ Parameter(&body, 0, element_shape, "t");
+ Broadcast(ConstantR0<float>(&body, 1.0), {2});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
- outer.While(cond_computation, body_computation, p);
+ While(cond_computation, body_computation, p);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
+ client_->TransferToServer(*LiteralUtil::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1019,25 +1014,24 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
auto element_shape = ShapeUtil::MakeShape(F32, {});
XlaBuilder outer("outer");
- auto p = outer.Parameter(0, element_shape, "param");
+ auto p = Parameter(&outer, 0, element_shape, "param");
XlaBuilder cond("cond");
- auto cond_t = cond.Parameter(0, element_shape, "t");
- cond.Eq(cond_t, cond.ConstantR0<float>(42));
+ auto cond_t = Parameter(&cond, 0, element_shape, "t");
+ Eq(cond_t, ConstantR0<float>(&cond, 42));
XlaBuilder body("body");
- auto body_t = body.Parameter(0, element_shape, "t");
- auto tuple =
- body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))});
- auto e = body.GetTupleElement(tuple, 1);
+ auto body_t = Parameter(&body, 0, element_shape, "t");
+ auto tuple = Tuple(&body, {body_t, Add(body_t, ConstantR0<float>(&body, 1))});
+ GetTupleElement(tuple, 1);
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
- outer.While(cond_computation, body_computation, p);
+ While(cond_computation, body_computation, p);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<float>(42)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1056,33 +1050,31 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) {
XlaBuilder outer("outer");
auto p =
- outer.Tuple({outer.ConstantR0<int32>(0),
- outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")});
+ Tuple(&outer, {ConstantR0<int32>(&outer, 0),
+ Parameter(&outer, 0, ShapeUtil::MakeShape(S32, {}), "t")});
XlaBuilder cond("cond");
- auto params = cond.Parameter(0, result_shape, "prev");
- auto cond_t = cond.Add(cond.GetTupleElement(params, 1),
- cond.GetTupleElement(params, 0));
- cond.Lt(cond_t, cond.ConstantR0<int32>(30));
+ auto params = Parameter(&cond, 0, result_shape, "prev");
+ auto cond_t = Add(GetTupleElement(params, 1), GetTupleElement(params, 0));
+ Lt(cond_t, ConstantR0<int32>(&cond, 30));
XlaBuilder body("body");
- auto body_t = body.Parameter(0, result_shape, "t");
+ auto body_t = Parameter(&body, 0, result_shape, "t");
- auto tuple = body.Tuple(
- {body.Add(body.GetTupleElement(body_t, 0), body.ConstantR0<int32>(1)),
- body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0<int32>(1))});
+ Tuple(&body, {Add(GetTupleElement(body_t, 0), ConstantR0<int32>(&body, 1)),
+ Add(GetTupleElement(body_t, 1), ConstantR0<int32>(&body, 1))});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
- outer.While(cond_computation, body_computation, p);
+ While(cond_computation, body_computation, p);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
- client_->TransferToServer(*Literal::CreateR0<int32>(1)));
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(1)));
- auto add1 = Literal::CreateR0<int32>(15);
- auto add2 = Literal::CreateR0<int32>(16);
- auto expected = Literal::MakeTuple({add1.get(), add2.get()});
+ auto add1 = LiteralUtil::CreateR0<int32>(15);
+ auto add2 = LiteralUtil::CreateR0<int32>(16);
+ auto expected = LiteralUtil::MakeTuple({add1.get(), add2.get()});
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
@@ -1105,9 +1097,9 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
XlaComputation inner_condition;
{
XlaBuilder builder("inner_condition");
- auto params = builder.Parameter(0, inner_result_shape, "prev");
- auto i = builder.GetTupleElement(params, 0);
- builder.Lt(i, builder.ConstantR0<int32>(7));
+ auto params = Parameter(&builder, 0, inner_result_shape, "prev");
+ auto i = GetTupleElement(params, 0);
+ Lt(i, ConstantR0<int32>(&builder, 7));
inner_condition = builder.Build().ConsumeValueOrDie();
}
@@ -1116,8 +1108,8 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
XlaComputation outer_condition;
{
XlaBuilder builder("outer_condition");
- auto prev = builder.Parameter(0, outer_result_shape, "prev");
- builder.Lt(prev, builder.ConstantR0<int32>(30));
+ auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
+ Lt(prev, ConstantR0<int32>(&builder, 30));
outer_condition = builder.Build().ConsumeValueOrDie();
}
@@ -1126,12 +1118,12 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
XlaComputation inner_body;
{
XlaBuilder builder("inner_body");
- auto params = builder.Parameter(0, inner_result_shape, "prev");
- auto i = builder.GetTupleElement(params, 0);
- auto result = builder.GetTupleElement(params, 1);
- i = builder.Add(builder.ConstantR0<int32>(1), i);
- result = builder.Add(builder.ConstantR0<int32>(2), result);
- builder.Tuple({i, result});
+ auto params = Parameter(&builder, 0, inner_result_shape, "prev");
+ auto i = GetTupleElement(params, 0);
+ auto result = GetTupleElement(params, 1);
+ i = Add(ConstantR0<int32>(&builder, 1), i);
+ result = Add(ConstantR0<int32>(&builder, 2), result);
+ Tuple(&builder, {i, result});
inner_body = builder.Build().ConsumeValueOrDie();
}
@@ -1139,17 +1131,17 @@ XLA_TEST_F(WhileTest, NestedWhileWithScalarResult) {
XlaComputation outer_body;
{
XlaBuilder builder("outer_body");
- auto prev = builder.Parameter(0, outer_result_shape, "prev");
- auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
- auto result = builder.While(inner_condition, inner_body, init);
- builder.GetTupleElement(result, 1);
+ auto prev = Parameter(&builder, 0, outer_result_shape, "prev");
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), prev});
+ auto result = While(inner_condition, inner_body, init);
+ GetTupleElement(result, 1);
outer_body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.ConstantR0<int32>(0);
- builder.While(outer_condition, outer_body, init);
+ auto init = ConstantR0<int32>(&builder, 0);
+ While(outer_condition, outer_body, init);
ComputeAndCompareR0<int32>(&builder, 42, {});
}
@@ -1167,8 +1159,8 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
XlaComputation condition_callee;
{
XlaBuilder builder("condition_callee");
- auto prev = builder.Parameter(0, result_shape, "prev");
- builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ Tuple(&builder, {Gt(ConstantR0<int32>(&builder, 5), prev)});
condition_callee = builder.Build().ConsumeValueOrDie();
}
@@ -1176,9 +1168,9 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto result = builder.Call(condition_callee, {prev});
- builder.GetTupleElement(result, 0);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto result = Call(&builder, condition_callee, {prev});
+ GetTupleElement(result, 0);
condition = builder.Build().ConsumeValueOrDie();
}
@@ -1186,16 +1178,16 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithCallInsideCondition)) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, result_shape, "prev");
- auto input = builder.ConstantR0<int32>(1);
- builder.Add(input, prev);
+ auto prev = Parameter(&builder, 0, result_shape, "prev");
+ auto input = ConstantR0<int32>(&builder, 1);
+ Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
XlaBuilder builder(TestName());
- auto init = builder.ConstantR0<int32>(0);
- builder.While(condition, body, init);
+ auto init = ConstantR0<int32>(&builder, 0);
+ While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
@@ -1210,34 +1202,34 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto state = builder.Parameter(0, while_shape, "state");
- builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
+ auto state = Parameter(&builder, 0, while_shape, "state");
+ Gt(ConstantR0<int32>(&builder, 5), GetTupleElement(state, 0));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
XlaComputation body;
{
XlaBuilder builder("body");
- auto state = builder.Parameter(0, while_shape, "state");
- auto indvar = builder.GetTupleElement(state, 0);
- auto input_0 = builder.GetTupleElement(state, 1);
- auto input_1 = builder.GetTupleElement(state, 2);
- auto output = builder.Tanh(builder.Dot(input_0, input_1));
- auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
- builder.Tuple({indvar_next, input_0, input_1, output});
+ auto state = Parameter(&builder, 0, while_shape, "state");
+ auto indvar = GetTupleElement(state, 0);
+ auto input_0 = GetTupleElement(state, 1);
+ auto input_1 = GetTupleElement(state, 2);
+ auto output = Tanh(Dot(input_0, input_1));
+ auto indvar_next = Add(indvar, ConstantR0<int32>(&builder, 1));
+ Tuple(&builder, {indvar_next, input_0, input_1, output});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
XlaBuilder builder(TestName());
- auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
- auto init = builder.Tuple(
- {builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
- auto while_instruction = builder.While(condition, body, init);
- builder.GetTupleElement(while_instruction, 3);
+ auto matrix_input = Parameter(&builder, 0, matrix_shape, "matrix");
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), matrix_input,
+ matrix_input, matrix_input});
+ auto while_instruction = While(condition, body, init);
+ GetTupleElement(while_instruction, 3);
- TF_ASSERT_OK_AND_ASSIGN(auto param_value,
- client_->TransferToServer(*Literal::CreateR2<float>(
- {{1.0, 2.0}, {-1.0, -2.0}})));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto param_value, client_->TransferToServer(*LiteralUtil::CreateR2<float>(
+ {{1.0, 2.0}, {-1.0, -2.0}})));
ComputeAndCompareR2<float>(
&builder, {{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}},
@@ -1264,9 +1256,9 @@ void BM_WhileLoop(int num_iters) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto prev = builder.Parameter(0, loop_state_shape, "prev");
- auto iteration = builder.GetTupleElement(prev, 0);
- builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
+ auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
+ auto iteration = GetTupleElement(prev, 0);
+ Lt(iteration, ConstantR0<int32>(&builder, loop_limit));
condition = builder.Build().ConsumeValueOrDie();
}
@@ -1274,29 +1266,29 @@ void BM_WhileLoop(int num_iters) {
XlaComputation body;
{
XlaBuilder builder("body");
- auto prev = builder.Parameter(0, loop_state_shape, "prev");
+ auto prev = Parameter(&builder, 0, loop_state_shape, "prev");
// TupleElement 0
- auto iteration = builder.GetTupleElement(prev, 0);
- auto out0 = builder.Add(iteration, builder.ConstantR0<int32>(1));
+ auto iteration = GetTupleElement(prev, 0);
+ auto out0 = Add(iteration, ConstantR0<int32>(&builder, 1));
// TupleElement 1
- auto input = builder.GetTupleElement(prev, 1);
+ auto input = GetTupleElement(prev, 1);
// Update.
- auto one = builder.ConstantR0<float>(1.0);
- auto update = builder.Broadcast(one, {1, 1024, 1024});
+ auto one = ConstantR0<float>(&builder, 1.0);
+ auto update = Broadcast(one, {1, 1024, 1024});
// Starts = iteration * 2;
- auto starts = builder.ConstantR1<int32>({0, 0, 0});
+ auto starts = ConstantR1<int32>(&builder, {0, 0, 0});
// UpdateSlice.
- auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- builder.Tuple({out0, out1});
+ auto out1 = DynamicUpdateSlice(input, update, starts);
+ Tuple(&builder, {out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While instruction.
XlaBuilder builder("while");
- auto zero = builder.ConstantR0<float>(0.0);
- auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
- auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});
- builder.While(condition, body, init);
+ auto zero = ConstantR0<float>(&builder, 0.0);
+ auto input = Broadcast(zero, {seq_len, 1024, 1024});
+ auto init = Tuple(&builder, {ConstantR0<int32>(&builder, 0), input});
+ While(condition, body, init);
auto computation = builder.Build().ConsumeValueOrDie();
std::unique_ptr<LocalExecutable> executable =
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 3c9a01653c..4d4dd62a3f 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -79,7 +79,9 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
- gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results) {
+ gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
+ tensorflow::gtl::ArraySlice<tensorflow::StringPiece> opcodes_to_ignore =
+ {}) {
string separator = "[^:]*:: +";
string match_percentage = "\\d+\\.\\d\\d%";
string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)";
@@ -113,7 +115,9 @@ Status ParseOneProfileOutputLine(
", Regexp: ", regexp_pattern);
}
- InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ if (!c_linear_search(opcodes_to_ignore, parsed_line.opcode)) {
+ InsertOrDie(parsed_results, parsed_line.opcode, parsed_line);
+ }
return Status::OK();
}
@@ -128,20 +132,23 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
se::StreamExecutor* executor = backend->default_stream_executor();
DeviceMemoryAllocator* allocator = backend->memory_allocator();
auto* transfer_manager = backend->transfer_manager();
+ TF_ASSERT_OK_AND_ASSIGN(
+ Backend::StreamPtr stream_ptr,
+ backend->BorrowStream(backend->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer lhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
+ stream_ptr.get(), *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
ScopedShapedBuffer rhs_arg,
transfer_manager->AllocateScopedShapedBuffer(
rhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
- executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ stream_ptr.get(), *Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
@@ -153,9 +160,6 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
&executable->hlo_profile_printer_data(),
&executable->hlo_profile_index_map());
- TF_ASSERT_OK_AND_ASSIGN(
- Backend::StreamPtr stream_ptr,
- backend->BorrowStream(backend->default_device_ordinal()));
ExecutableRunOptions exec_run_options;
exec_run_options.set_stream(stream_ptr.get());
exec_run_options.set_allocator(backend->memory_allocator());
@@ -187,9 +191,9 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
ClientLibrary::GetOrCreateLocalClient(platform));
XlaBuilder builder(TestName());
- auto result = builder.Tanh(builder.Add(
- builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
- builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
+ Tanh(Add(
+ Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"),
+ Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs")));
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
@@ -239,9 +243,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
EXPECT_TRUE(HasTrops(tanh_profile));
}
-// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo
-// instructions "interior" to while nodes.
-XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
+XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
const int64 size = 256;
Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
Shape while_result_shape =
@@ -255,30 +257,30 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
XlaComputation condition;
{
XlaBuilder builder("condition");
- auto state = builder.Parameter(0, while_result_shape, "state");
- auto iteration = builder.GetTupleElement(state, 0);
- builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ auto state = Parameter(&builder, 0, while_result_shape, "state");
+ auto iteration = GetTupleElement(state, 0);
+ Gt(ConstantR0<int32>(&builder, 5), iteration);
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
XlaComputation body;
{
XlaBuilder builder("body");
- auto state = builder.Parameter(0, while_result_shape, "state");
- auto matrix = builder.GetTupleElement(state, 1);
- auto next_iteration = builder.Add(builder.GetTupleElement(state, 0),
- builder.ConstantR0<int32>(1));
- builder.Tuple({next_iteration, builder.Add(matrix, matrix)});
+ auto state = Parameter(&builder, 0, while_result_shape, "state");
+ auto matrix = GetTupleElement(state, 1);
+ auto next_iteration =
+ Add(GetTupleElement(state, 0), ConstantR0<int32>(&builder, 1));
+ Tuple(&builder, {next_iteration, Mul(matrix, matrix)});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
XlaBuilder builder(TestName());
auto initial_while_state =
- builder.Tuple({builder.ConstantR0<int32>(0),
- builder.Parameter(0, matrix_shape, "initial_value")});
- auto while_result = builder.While(condition, body, initial_while_state);
- builder.Add(builder.GetTupleElement(while_result, 1),
- builder.Parameter(1, matrix_shape, "other_value"));
+ Tuple(&builder, {ConstantR0<int32>(&builder, 0),
+ Parameter(&builder, 0, matrix_shape, "initial_value")});
+ auto while_result = While(condition, body, initial_while_state);
+ Add(GetTupleElement(while_result, 1),
+ Parameter(&builder, 1, matrix_shape, "other_value"));
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
@@ -290,36 +292,50 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
tensorflow::str_util::Split(profile_output, '\n');
auto while_body_profile_start =
- std::find_if(profile_output_lines.begin(), profile_output_lines.end(),
+ c_find_if(profile_output_lines, [](tensorflow::StringPiece s) {
+ return tensorflow::str_util::StartsWith(s,
+ "Execution profile for body");
+ });
+
+ ASSERT_NE(while_body_profile_start, profile_output_lines.cend());
+
+ auto while_body_profile_end =
+ std::find_if(while_body_profile_start, profile_output_lines.end(),
[](tensorflow::StringPiece s) {
return tensorflow::str_util::StartsWith(
- s, "Execution profile for body");
+ s, "********** microseconds report **********");
});
- ASSERT_NE(while_body_profile_start, profile_output_lines.end());
+ // We emit a blank line before the "********** microseconds report **********"
+ // line.
+ while_body_profile_end--;
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ ASSERT_NE(while_body_profile_end, profile_output_lines.end());
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1),
- /*expect_hlo=*/false, &parsed_profile_lines));
+ gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
- TF_ASSERT_OK(
- ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2),
- /*expect_hlo=*/true, &parsed_profile_lines));
+ for (auto while_body_profile_i = while_body_profile_start + 1;
+ while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
+ // There are multiple "get-tuple-element" instructions in the while body so
+ // we ignore them -- we don't want parsed_profile_lines to be a multi-map.
+ TF_ASSERT_OK(ParseOneProfileOutputLine(
+ *while_body_profile_i,
+ /*expect_hlo=*/while_body_profile_i != (while_body_profile_start + 1),
+ &parsed_profile_lines, {"get-tuple-element"}));
+ }
TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile,
MaybeFind(parsed_profile_lines, "[total]"));
- TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile,
- MaybeFind(parsed_profile_lines, "add"));
+ TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine multiply_profile,
+ MaybeFind(parsed_profile_lines, "multiply"));
EXPECT_GT(total_while_body_profile.cycles, 0);
EXPECT_EQ(total_while_body_profile.opcode, "[total]");
EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%");
- EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles);
- EXPECT_NE(dot_profile.cycles_percentage, "0.00%");
- EXPECT_NE(dot_profile.cycles_percentage, "100.00%");
+ EXPECT_GT(total_while_body_profile.cycles, multiply_profile.cycles);
+ EXPECT_NE(multiply_profile.cycles_percentage, "0.00%");
+ EXPECT_NE(multiply_profile.cycles_percentage, "100.00%");
}
} // namespace
} // namespace xla
@@ -336,8 +352,11 @@ static std::pair<int, char**> AddXlaHloProfileFlag(int argc, char** argv) {
new_argv[argc] = strdup("--xla_hlo_profile");
// Fusion can change the Hlo instructions that show up in the final Hlo
- // executable, so block it here.
- new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion");
+ // executable, so block it here. Also block the WhileLoopInvariantCodeMotion
+ // pass, otherwise a while loop is transformed and we could not match the
+ // original name in the ProfileWhileComputation test.
+ new_argv[argc + 1] = strdup(
+ "--xla_disable_hlo_passes=fusion,while-loop-invariant-code-motion");
return {argc + 2, new_argv};
}
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index a9f2915b45..a075195618 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -49,6 +49,7 @@ GTEST_API_ int main(int argc, char** argv) {
}
// Unfortunately Google's internal benchmark infrastructure has a
// different API than Tensorflow's.
+ testing::InitGoogleTest(&argc, argv);
#if defined(PLATFORM_GOOGLE)
base::SetFlag(&FLAGS_benchmarks, pattern);
RunSpecifiedBenchmarks();
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
index 56702feab9..897123d760 100644
--- a/tensorflow/compiler/xla/text_literal_reader.cc
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
index e45e5291c9..708e8c80d8 100644
--- a/tensorflow/compiler/xla/text_literal_reader.h
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
index 23070b6638..92f9b4f9f0 100644
--- a/tensorflow/compiler/xla/text_literal_reader_test.cc
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 373c0d2d8d..24e0784741 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
index 0a1235b5e0..159ac1b7e1 100644
--- a/tensorflow/compiler/xla/text_literal_writer.h
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
index 70cf2fb1b8..4ea02faffc 100644
--- a/tensorflow/compiler/xla/text_literal_writer_test.cc
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -30,8 +31,9 @@ namespace xla {
namespace {
TEST(TextLiteralWriterTest, WritesFloatLiteral) {
- auto literal = Literal::CreateR2<float>({
- {3.14, 2.17}, {1.23, 4.56},
+ auto literal = LiteralUtil::CreateR2<float>({
+ {3.14, 2.17},
+ {1.23, 4.56},
});
string path =
tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index d73bcdaf82..55501827f2 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -74,7 +74,7 @@ cc_library(
srcs = ["replay_computation.cc"],
deps = [
"//tensorflow/compiler/xla:execution_options_util",
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -85,6 +85,7 @@ cc_library(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:testing",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service/gpu:infeed_manager",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -122,7 +123,7 @@ tf_cc_binary(
name = "show_literal",
srcs = ["show_literal.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
@@ -135,7 +136,7 @@ tf_cc_binary(
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
],
)
@@ -144,7 +145,7 @@ tf_cc_binary(
name = "show_text_literal",
srcs = ["show_text_literal.cc"],
deps = [
- "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:text_literal_reader",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc
index fe03a6e7bd..14d01b5bfb 100644
--- a/tensorflow/compiler/xla/tools/convert_computation.cc
+++ b/tensorflow/compiler/xla/tools/convert_computation.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <unistd.h>
#include <string>
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/env.h"
@@ -33,7 +33,7 @@ namespace xla {
namespace tools {
void RealMain(const string& mode, const string& path) {
- SessionModule module;
+ HloSnapshot module;
tensorflow::Env* env = tensorflow::Env::Default();
if (mode == "txt2bin") {
TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module));
diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD
deleted file mode 100644
index 76f35afd53..0000000000
--- a/tensorflow/compiler/xla/tools/parser/BUILD
+++ /dev/null
@@ -1,73 +0,0 @@
-# Build file for the Hlo parser.
-
-licenses(["notice"]) # Apache 2.0
-
-package(
- default_visibility = [":friends"],
-)
-
-package_group(
- name = "friends",
- includes = [
- "//tensorflow/compiler/xla:friends",
- ],
-)
-
-# Filegroup used to collect source files for dependency checking.
-filegroup(
- name = "c_srcs",
- data = glob([
- "**/*.cc",
- "**/*.h",
- ]),
-)
-
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "hlo_lexer",
- srcs = ["hlo_lexer.cc"],
- hdrs = [
- "hlo_lexer.h",
- "hlo_token.h",
- ],
- deps = [
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- ],
-)
-
-cc_library(
- name = "hlo_parser",
- srcs = ["hlo_parser.cc"],
- hdrs = ["hlo_parser.h"],
- deps = [
- ":hlo_lexer",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-tf_cc_test(
- name = "hlo_parser_test",
- size = "small",
- srcs = ["hlo_parser_test.cc"],
- deps = [
- ":hlo_parser",
- "//tensorflow/compiler/xla:window_util",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index be094b7890..854e797ec2 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -24,6 +24,9 @@ limitations under the License.
// passing --use_fake_data on the command line. If the real data is available
// in the proto and --use_fake_data is false, the real data is used.
//
+// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a
+// textual HLO string.
+//
// The output format is:
//
// file_path: computation_name :: type:literal_str
@@ -40,9 +43,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/testing.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -170,6 +174,11 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
client->Compile(computation, argument_layouts, ExecutableBuildOptions())
.ValueOrDie();
+ // Do not attmept to run the executable, if num_runs is less than 1.
+ if (opts.num_runs < 1) {
+ return Cancelled("Cancelled after compilation since --num_runs < 1.");
+ }
+
// Run the computation num_runs times, and return the result from the last
// execution.
StreamExecutorMemoryAllocator allocator(
@@ -187,33 +196,50 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< static_cast<double>(profile.compute_time_ns()) / 1e9 << "s";
}
- // Check that --num_runs > 0, otherwise *result below will fail with an
- // unhelpful error (because the loop didn't run any iterations).
- CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0";
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result_literal,
client->ShapedBufferToLiteral(*result));
return std::move(*result_literal);
}
+StatusOr<HloSnapshot> ParseInputFile(const string& filename,
+ const Options& opts) {
+ tensorflow::Env* env = tensorflow::Env::Default();
+ HloSnapshot snapshot;
+ if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) {
+ return snapshot;
+ }
+ CHECK(opts.use_fake_data)
+ << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto "
+ "and textual HLO don't carry real data.";
+ fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n",
+ filename.c_str());
+
+ if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) {
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str());
+ string contents;
+ TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents));
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(contents);
+ if (module.ok()) {
+ *snapshot.mutable_hlo()->mutable_hlo_module() =
+ module.ValueOrDie()->ToProto();
+ return snapshot;
+ }
+ fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n",
+ filename.c_str());
+ return InvalidArgument("Could not parse %s.", filename.c_str());
+}
+
int RealMain(tensorflow::gtl::ArraySlice<char*> args, const Options& opts) {
LocalClient* client = ClientLibrary::LocalClientOrDie();
- tensorflow::Env* env = tensorflow::Env::Default();
int exit_status = EXIT_SUCCESS;
for (char* arg : args) {
- HloSnapshot snapshot;
- auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot);
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", arg);
- status = tensorflow::ReadBinaryProto(env, arg, snapshot.mutable_hlo());
- if (!status.ok()) {
- fprintf(stderr, "%s: is not HloSnapshot or HloProto: %s.\n", arg,
- status.ToString().c_str());
- continue;
- }
- CHECK(opts.use_fake_data)
- << "HloProto input must be handled with --use_fake_data";
+ StatusOr<HloSnapshot> maybe_snapshot = ParseInputFile(arg, opts);
+ if (!maybe_snapshot.ok()) {
+ continue;
}
-
+ HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie();
StatusOr<Literal> result_status = ReplayComputation(snapshot, client, opts);
if (!result_status.ok()) {
fprintf(stderr, "%s: error: %s\n", arg,
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
index fe8e72ba32..51909190a3 100644
--- a/tensorflow/compiler/xla/tools/show_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include <stdio.h>
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
index 8525873e91..48c8374811 100644
--- a/tensorflow/compiler/xla/tools/show_text_literal.cc
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <memory>
#include <string>
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/text_literal_reader.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index b4f45cc972..5ae099a462 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -499,17 +500,17 @@ bool c_is_sorted(const C& c, Compare&& comp) {
}
template <typename C>
-auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) {
+auto c_adjacent_find(C& c) -> decltype(std::begin(c)) {
return std::adjacent_find(std::begin(c), std::end(c));
}
template <typename C, typename Pred>
-auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) {
+auto c_find_if(C& c, Pred&& pred) -> decltype(std::begin(c)) {
return std::find_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
}
template <typename C, typename Value>
-auto c_find(const C& c, Value&& value) -> decltype(std::begin(c)) {
+auto c_find(C& c, Value&& value) -> decltype(std::begin(c)) {
return std::find(std::begin(c), std::end(c), std::forward<Value>(value));
}
@@ -533,12 +534,24 @@ c_count_if(const C& c, Pred&& pred) {
return std::count_if(std::begin(c), std::end(c), std::forward<Pred>(pred));
}
+// Determines whether `value` is present in `c`.
+template <typename C, typename T>
+bool c_linear_search(const C& c, T&& value) {
+ auto last = std::end(c);
+ return std::find(std::begin(c), last, std::forward<T>(value)) != last;
+}
+
template <typename C, typename Value>
int64 FindIndex(const C& c, Value&& value) {
auto it = c_find(c, std::forward<Value>(value));
return std::distance(c.begin(), it);
}
+template <typename T>
+bool ArrayContains(tensorflow::gtl::ArraySlice<T> c, const T& value) {
+ return c_find(c, value) != c.end();
+}
+
template <typename C, typename Value>
void InsertAt(C* c, int64 index, Value&& value) {
c->insert(c->begin() + index, std::forward<Value>(value));
@@ -549,6 +562,17 @@ void EraseAt(C* c, int64 index) {
c->erase(c->begin() + index);
}
+template <typename T>
+std::vector<T> ArraySliceToVector(tensorflow::gtl::ArraySlice<T> slice) {
+ return std::vector<T>(slice.begin(), slice.end());
+}
+
+template <typename T, int N>
+std::vector<T> InlinedVectorToVector(
+ const tensorflow::gtl::InlinedVector<T, N>& inlined_vector) {
+ return std::vector<T>(inlined_vector.begin(), inlined_vector.end());
+}
+
// Returns true if `x` fits in 32-bits.
template <typename T>
bool IsInt32(T x) {
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index f619b8dc24..6f07e4606b 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -17,7 +17,6 @@ syntax = "proto3";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
-import "tensorflow/compiler/xla/service/session.proto";
package xla;
@@ -226,22 +225,6 @@ message ExecutionOptions {
repeated DeviceHandle device_handles = 5;
}
-message SnapshotComputationRequest {
- ComputationHandle computation = 1;
-}
-
-message SnapshotComputationResponse {
- SessionModule module = 1;
-}
-
-message LoadComputationSnapshotRequest {
- SessionModule module = 1;
-}
-
-message LoadComputationSnapshotResponse {
- ComputationHandle computation = 1;
-}
-
message GetDeviceHandlesRequest {
int64 device_count = 1;
}
@@ -300,11 +283,6 @@ message ResetDeviceRequest {
message ResetDeviceResponse {
}
-message ComputationStatsRequest {
- ComputationHandle computation = 1;
- DebugOptions debug_options = 2;
-}
-
message ComputationGraphStatsRequest {
HloModuleProto computation = 1;
DebugOptions debug_options = 2;
@@ -314,14 +292,6 @@ message ComputationStatsResponse {
ComputationStats stats = 1;
}
-message ComputationRequest {
- string name = 1;
-}
-
-message ComputationResponse {
- ComputationHandle computation = 1;
-}
-
message CreateChannelHandleRequest {
}
@@ -336,24 +306,6 @@ message UnregisterRequest {
message UnregisterResponse {
}
-message SetReturnValueRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
-}
-
-message SetReturnValueResponse {
-}
-
-message ExecuteRequest {
- reserved 3, 4;
-
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-
- // Options that affect how XLA compiles and runs code to service this request.
- ExecutionOptions execution_options = 5;
-}
-
message ExecuteGraphRequest {
HloModuleProto computation = 1;
repeated GlobalDataHandle arguments = 2;
@@ -362,10 +314,6 @@ message ExecuteGraphRequest {
ExecutionOptions execution_options = 3;
}
-message ExecuteParallelRequest {
- repeated ExecuteRequest requests = 1;
-}
-
message ExecuteGraphParallelRequest {
repeated ExecuteGraphRequest requests = 1;
}
@@ -379,21 +327,6 @@ message ExecuteParallelResponse {
repeated ExecuteResponse responses = 1;
}
-message ExecuteAsyncRequest {
- reserved 3, 4;
-
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-
- // Options that affect how XLA compiles and runs code to service this request.
- ExecutionOptions execution_options = 6;
-}
-
-message ExecuteAsyncResponse {
- // A handle to the execution launched asynchronously.
- ExecutionHandle execution = 1;
-}
-
message WaitForExecutionRequest {
ExecutionHandle execution = 1;
}
@@ -403,31 +336,13 @@ message WaitForExecutionResponse {
ExecutionProfile profile = 2;
}
-message IsConstantRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
- int64 num_parameters = 3;
-}
-
-message IsConstantResponse {
- bool is_constant = 1;
-}
-
-message ComputeConstantRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
- Layout output_layout = 3;
- repeated LiteralProto parameters = 4;
-}
-
message ComputeConstantGraphRequest {
HloModuleProto computation = 1;
Layout output_layout = 2;
}
message ComputeConstantResponse {
- // A LiteralProto is returned directly for this request, instead of a
- // ComputationDataHandle.
+ // A LiteralProto is returned directly for this request.
LiteralProto literal = 1;
}
@@ -469,14 +384,6 @@ message LoadDataResponse {
int64 nanoseconds = 5;
}
-message SpecializeRequest {
- ComputationHandle computation = 1;
- repeated GlobalDataHandle arguments = 2;
-}
-
-message SpecializeResponse {
-}
-
message GetShapeRequest {
GlobalDataHandle data = 1;
}
@@ -485,14 +392,6 @@ message GetShapeResponse {
Shape shape = 1;
}
-message GetComputationShapeRequest {
- ComputationHandle computation = 1;
-}
-
-message GetComputationShapeResponse {
- ProgramShape program_shape = 1;
-}
-
message UnpackRequest {
GlobalDataHandle data = 1;
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index b895ac045c..c7472173a7 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -66,11 +66,16 @@ enum PrimitiveType {
// in the dimensions field.
TUPLE = 13;
- // An opaque type used for passing context specific data to a custom
- // operation.
+ // An opaque type used for passing context-specific data to a custom
+ // operation. Shapes of this primitive type will have empty dimensions and
+ // tuple_shapes fields.
OPAQUE = 14;
- // Next = 17
+ // A token type threaded between side-effecting operations. Shapes of this
+ // primitive type will have empty dimensions and tuple_shapes fields.
+ TOKEN = 17;
+
+ // Next = 18
}
// Describes the value held inside padding elements.
@@ -269,12 +274,9 @@ message ExecutionProfile {
// for the input data transfer since the memory is initialized with the proper
// values before the execution.
int64 compute_and_transfer_time_ns = 5;
-}
-// Handle given to a user that represents a computation that the user builds up
-// before execution.
-message ComputationHandle {
- int64 handle = 1;
+ // The size of the binary code in the executable.
+ int64 executable_size_in_bytes = 6;
}
// Handle given to a user that represents an execution that the user launched
@@ -290,13 +292,6 @@ message GlobalDataHandle {
int64 handle = 1;
}
-// Handle given to a user that represents a data result in a computation.
-// This is used to pass to subsequent computations that depends upon the data as
-// an operand.
-message ComputationDataHandle {
- int64 handle = 1;
-}
-
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
@@ -436,44 +431,6 @@ message GatherDimensionNumbers {
int64 index_vector_dim = 4;
}
-// Operation requests that are all collected as a tagged union with a oneof
-// field in OpRequest.
-
-message ConstantRequest {
- LiteralProto literal = 2;
-}
-
-message GetTupleElementRequest {
- ComputationDataHandle operand = 2;
- int64 index = 3;
-}
-
-message SliceRequest {
- ComputationDataHandle operand = 2;
- repeated int64 start_indices = 3;
- repeated int64 limit_indices = 4;
- repeated int64 strides = 5;
-}
-
-message DynamicSliceRequest {
- // Operand from which to slice at dynamic 'start_indices'.
- ComputationDataHandle operand = 2;
- // Dynamically computed 'start_indices' for slice operation.
- ComputationDataHandle start_indices = 3;
- // Slice sizes for each dimension (note that indices calculations are computed
- // modulo dimension sizes to avoid out-of-bound array accesses).
- repeated int64 slice_sizes = 4;
-}
-
-message DynamicUpdateSliceRequest {
- // Operand on which slice 'update' is to be applied.
- ComputationDataHandle operand = 2;
- // The slice update to apply to 'operand'.
- ComputationDataHandle update = 3;
- // Dynamically computed start indices for the update slice operation.
- ComputationDataHandle start_indices = 4;
-}
-
message ConvolutionDimensionNumbers {
// The number of the dimension that represents batch in the input.
int64 input_batch_dimension = 7;
@@ -511,13 +468,6 @@ message ConvolutionDimensionNumbers {
// Next = 13
};
-message ConvolveRequest {
- ComputationDataHandle lhs = 2;
- ComputationDataHandle rhs = 3; // This is the filter/kernel.
- Window window = 4; // Describes the filter/kernel.
- ConvolutionDimensionNumbers dimension_numbers = 5;
-}
-
enum FftType {
FFT = 0; // Forward FFT; complex in, complex out.
IFFT = 1; // Inverse FFT; complex in, complex out.
@@ -526,56 +476,6 @@ enum FftType {
// fft_length real out
}
-message FftRequest {
- FftType fft_type = 1;
- repeated int64 fft_length = 2; // Multivalent for higher-order FFT.
- ComputationDataHandle operand = 3;
-}
-
-message InfeedRequest {
- // The shape of the data returned by reading the device's infeed buffer.
- Shape shape = 2;
-
- // Additional infeed configuration for the backend.
- bytes config = 3;
-}
-
-message OutfeedRequest {
- // The shape of the data returned by reading the device's outfeed buffer.
- Shape shape = 1;
-
- // Operand to the Outfeed. Supports tuple.
- ComputationDataHandle operand = 2;
-
- // Backend-specific information for how to perform the outfeed.
- bytes outfeed_config = 3;
-}
-
-message CallRequest {
- ComputationHandle to_apply = 2;
- repeated ComputationDataHandle operands = 3;
-}
-
-message CustomCallRequest {
- string call_target_name = 2;
- repeated ComputationDataHandle operands = 3;
- Shape shape = 4;
-}
-
-message HostComputeRequest {
- // Operand to the HostCompute. Supports tuple.
- repeated ComputationDataHandle operands = 1;
-
- // Name used to identify HostSend/Recv channels.
- string channel_name = 2;
-
- // Cost estimate in nanoseconds.
- int64 cost_estimate_ns = 3;
-
- // The shape of any data returned by host.
- Shape shape = 4;
-}
-
message DotDimensionNumbers {
// The dimension numbers that represent the 'lhs' contracting dimensions.
repeated int64 lhs_contracting_dimensions = 1;
@@ -587,297 +487,6 @@ message DotDimensionNumbers {
repeated int64 rhs_batch_dimensions = 4;
};
-message DotRequest {
- ComputationDataHandle lhs = 2;
- ComputationDataHandle rhs = 3;
- DotDimensionNumbers dimension_numbers = 4;
-}
-
-message MapRequest {
- repeated ComputationDataHandle operands = 2;
- ComputationHandle to_apply = 3;
- repeated ComputationDataHandle static_operands = 4;
- // The dimensions over which to map.
- // Example mapping a Dot operation along the batch dimension 0:
- // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3]
- // Map({operand0, operand1}, Dot, {0})
- repeated int64 dimensions = 5;
-}
-
-message ReduceRequest {
- // Operand to the reduction.
- ComputationDataHandle operand = 2;
-
- // Initial value for the reduction. This must be consistent with the result
- // shape of to_apply.
- ComputationDataHandle init_value = 3;
-
- // The dimensions to reduce over.
- repeated int64 dimensions = 4;
-
- // The computation to apply in the reduction.
- ComputationHandle to_apply = 5;
-}
-
-message ReduceWindowRequest {
- ComputationDataHandle operand = 2;
- ComputationDataHandle init_value = 3;
- Window window = 4;
- ComputationHandle to_apply = 5;
-}
-
-message BatchNormTrainingRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle offset = 3;
- float epsilon = 4;
- int64 feature_index = 5;
-}
-
-message BatchNormInferenceRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle offset = 3;
- ComputationDataHandle mean = 4;
- ComputationDataHandle variance = 5;
- float epsilon = 6;
- int64 feature_index = 7;
-}
-
-message BatchNormGradRequest {
- ComputationDataHandle operand = 1;
- ComputationDataHandle scale = 2;
- ComputationDataHandle mean = 3;
- ComputationDataHandle variance = 4;
- ComputationDataHandle grad_output = 5;
- float epsilon = 6;
- int64 feature_index = 7;
-}
-
-message CrossReplicaSumRequest {
- ComputationDataHandle operand = 2;
-}
-
-message SelectAndScatterRequest {
- // Operand array on which the windows slide.
- ComputationDataHandle operand = 2;
-
- // Source array for the data to scatter.
- ComputationDataHandle source = 3;
-
- // Initial scalar value for each element in the output.
- ComputationDataHandle init_value = 4;
-
- // Window configuration.
- Window window = 5;
-
- // Binary function used to select an element from each window.
- ComputationHandle select = 6;
-
- // Binary function used to combine each scattered value from source with the
- // current output value at the selected location.
- ComputationHandle scatter = 7;
-}
-
-message ReverseRequest {
- ComputationDataHandle operand = 2;
- repeated int64 dimensions = 3;
-}
-
-message BroadcastRequest {
- ComputationDataHandle operand = 2;
- repeated int64 broadcast_sizes = 3;
-}
-
-message PadRequest {
- ComputationDataHandle operand = 2;
- ComputationDataHandle padding_value = 3;
- PaddingConfig padding_config = 4;
-}
-
-message ReshapeRequest {
- ComputationDataHandle operand = 2;
-
- // The dimension order for collapse (from fastest-changing to slowest).
- repeated int64 dimensions = 3;
-
- // The new dimension sizes (from dimension 0 to n-1).
- repeated int64 new_sizes = 4;
-}
-
-message TransposeRequest {
- ComputationDataHandle operand = 2;
-
- // The permutation of the operand's dimensions (in the range 0 to n-1).
- repeated int64 dimensions = 3;
-}
-
-message ParameterRequest {
- Shape shape = 2;
- int64 parameter = 3;
- string name = 4;
-}
-
-message GetLocalShapeRequest {
- ComputationHandle computation = 1;
- ComputationDataHandle operand = 2;
-}
-
-message GetLocalShapeResponse {
- Shape shape = 1;
-}
-
-message TraceRequest {
- string tag = 2;
- ComputationDataHandle operand = 3;
-}
-
-message ConvertRequest {
- ComputationDataHandle operand = 2;
- PrimitiveType new_element_type = 3;
-}
-
-message ConcatenateRequest {
- repeated ComputationDataHandle operands = 2;
- // The dimension in which we concatenate; e.g. if you had dimension arrays of
- // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
- // Attempting to concatenate those in dimension 1 would produce an error, as
- // 4 != 5 (and there is no ragged array support).
- int64 dimension = 3;
-}
-
-message ConditionalRequest {
- ComputationDataHandle predicate = 2;
- ComputationDataHandle true_operand = 3;
- ComputationHandle true_computation = 4;
- ComputationDataHandle false_operand = 5;
- ComputationHandle false_computation = 6;
-}
-
-message WhileRequest {
- ComputationHandle condition = 2;
- ComputationHandle body = 3;
- ComputationDataHandle init = 4;
-}
-
-enum UnaryOperation {
- UNOP_INVALID = 0;
-
- // Elementwise, logical negation on booleans and bitwise negation on ints.
- UNOP_NOT = 1;
-
- // Elementwise, computes e^x.
- UNOP_EXP = 2;
-
- // Elementwise, computes -x.
- UNOP_NEGATE = 3;
-
- // Puts the elements in the operand into sorted order.
- UNOP_SORT = 4;
-
- // Elementwise, computes tanh(x).
- UNOP_TANH = 5;
-
- // Elementwise, computes the natural logarithm of x.
- UNOP_LOG = 6;
-
- // Elementwise, computes the floor of x.
- UNOP_FLOOR = 7;
-
- // Elementwise, computes the ceil of x.
- UNOP_CEIL = 8;
-
- // Elementwise, computes the abs of x.
- UNOP_ABS = 9;
-
- // Elementwise, computes the sign of x.
- UNOP_SIGN = 10;
-
- // Elementwise, tests if values are finite (not NaN or inf)
- UNOP_IS_FINITE = 11;
-
- // Elementwise, computes the cosine of x.
- UNOP_COS = 12;
-
- // Elementwise, computes the sine of x.
- UNOP_SIN = 13;
-
- // Elementwise, rounds x to nearest integral value, rounding half-way cases
- // away from zero.
- UNOP_ROUND_NEAREST_AFZ = 14;
-
- // Elementwise, extract real component of complex x.
- UNOP_REAL = 15;
-
- // Elementwise, extract real component of complex x.
- UNOP_IMAG = 16;
-
- // Elementwise, computes clz(x).
- UNOP_CLZ = 17;
-
- // Elementwise, computes exp(x)-1.
- UNOP_EXPM1 = 18;
-
- // Elementwise, computes log(x+1).
- UNOP_LOG1P = 19;
-}
-
-message UnaryOpRequest {
- UnaryOperation unop = 2;
- ComputationDataHandle operand = 3;
-}
-
-enum BinaryOperation {
- BINOP_INVALID = 0;
-
- // Arithmetic operations.
- BINOP_ADD = 1;
- BINOP_DIV = 2;
- BINOP_MUL = 3;
- BINOP_SUB = 4;
-
- // Comparison operators.
- BINOP_EQ = 5;
- BINOP_GE = 6;
- BINOP_GT = 7;
- BINOP_LE = 8;
- BINOP_LT = 9;
- BINOP_NE = 10;
-
- // Element-wise maximum.
- BINOP_MAX = 14;
-
- // Element-wise minimum.
- BINOP_MIN = 15;
-
- // Raises the left-hand-side to the right-hand-side power.
- BINOP_POW = 16;
-
- // Remainder operation.
- BINOP_REM = 17;
-
- // Element-wise, logical operators on booleans and bitwise operators on ints.
- BINOP_AND = 18;
- BINOP_OR = 19;
-
- BINOP_SHIFT_LEFT = 20;
- BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
- BINOP_SHIFT_RIGHT_LOGICAL = 22;
-
- // Complex from real, imag.
- BINOP_COMPLEX = 23;
-
- // Computes the 4-quadrant arctangent of the y, x input arguments.
- BINOP_ATAN2 = 24;
-}
-
-message BinaryOpRequest {
- BinaryOperation binop = 2;
- ComputationDataHandle lhs = 3;
- ComputationDataHandle rhs = 4;
- repeated int64 broadcast_dimensions = 5;
-}
-
enum RandomDistribution {
RNG_INVALID = 0;
@@ -892,67 +501,6 @@ enum RandomDistribution {
// Next: 4
}
-message RngRequest {
- RandomDistribution distribution = 2;
- repeated ComputationDataHandle parameter = 3;
- Shape shape = 4;
-}
-
-enum TernaryOperation {
- TRIOP_INVALID = 0;
-
- // Given a predicate and two operands, selects operand0 if the predicate is
- // true and operand1 if the predicate is false.
- TRIOP_SELECT = 1;
-
- // Given a min, max and an operand returns the operand if between min and max,
- // else returns min if operand is less than min or max if operand is greater
- // than max.
- TRIOP_CLAMP = 3;
-}
-
-message TernaryOpRequest {
- TernaryOperation triop = 2;
- ComputationDataHandle lhs = 3;
- ComputationDataHandle rhs = 4;
- ComputationDataHandle ehs = 5;
-}
-
-enum VariadicOperation {
- VAROP_INVALID = 0;
-
- // Creates a tuple from its operands.
- VAROP_TUPLE = 1;
-}
-
-message VariadicOpRequest {
- VariadicOperation varop = 2;
- repeated ComputationDataHandle operands = 3;
-}
-
-message ReducePrecisionRequest {
- ComputationDataHandle operand = 1;
- int32 exponent_bits = 2;
- int32 mantissa_bits = 3;
-}
-
-message SendRequest {
- ComputationDataHandle operand = 1;
- ChannelHandle channel_handle = 2;
-}
-
-message RecvRequest {
- Shape shape = 1;
- ChannelHandle channel_handle = 2;
-}
-
-message GatherRequest {
- ComputationDataHandle input = 1;
- ComputationDataHandle gather_indices = 2;
- GatherDimensionNumbers dimension_numbers = 3;
- repeated int64 window_bounds = 4;
-}
-
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,
@@ -983,59 +531,3 @@ message OpSharding {
// to.
repeated OpSharding tuple_shardings = 5;
}
-
-message OpRequest {
- ComputationHandle computation = 1;
- OpMetadata metadata = 33;
- OpSharding sharding = 40;
-
- oneof op {
- BinaryOpRequest binary_op_request = 2;
- BroadcastRequest broadcast_request = 3;
- CallRequest call_request = 4;
- ConcatenateRequest concatenate_request = 5;
- ConstantRequest constant_request = 6;
- ConvertRequest convert_request = 7;
- ConvolveRequest convolve_request = 8;
- CrossReplicaSumRequest cross_replica_sum_request = 9;
- CustomCallRequest custom_call_request = 10;
- DotRequest dot_request = 43;
- DynamicSliceRequest dynamic_slice_request = 11;
- DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
- GetTupleElementRequest get_tuple_element_request = 13;
- InfeedRequest infeed_request = 14;
- MapRequest map_request = 15;
- PadRequest pad_request = 16;
- ParameterRequest parameter_request = 17;
- ReducePrecisionRequest reduce_precision_request = 36;
- ReduceRequest reduce_request = 18;
- ReduceWindowRequest reduce_window_request = 19;
- ReshapeRequest reshape_request = 20;
- ReverseRequest reverse_request = 21;
- RngRequest rng_request = 22;
- SelectAndScatterRequest select_and_scatter_request = 23;
- SliceRequest slice_request = 24;
- TernaryOpRequest ternary_op_request = 25;
- TraceRequest trace_request = 26;
- TransposeRequest transpose_request = 34;
- UnaryOpRequest unary_op_request = 27;
- VariadicOpRequest variadic_op_request = 28;
- WhileRequest while_request = 29;
- SendRequest send_request = 30;
- RecvRequest recv_request = 31;
- OutfeedRequest outfeed_request = 32;
- BatchNormTrainingRequest batch_norm_training_request = 35;
- BatchNormGradRequest batch_norm_grad_request = 37;
- BatchNormInferenceRequest batch_norm_inference_request = 38;
- FftRequest fft_request = 41;
- ConvertRequest bitcast_convert_request = 42;
- ConditionalRequest conditional_request = 44;
- HostComputeRequest host_compute_request = 45;
- GatherRequest gather_request = 46;
- // Next: 47
- }
-}
-
-message OpResponse {
- ComputationDataHandle output = 1;
-}
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 0f9c80404a..c039624daa 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -9,6 +9,7 @@ load("//third_party/mpi:mpi.bzl", "if_mpi")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda")
py_library(
name = "contrib_py",
@@ -26,25 +27,24 @@ py_library(
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/boosted_trees:init_py",
"//tensorflow/contrib/checkpoint/python:checkpoint",
- "//tensorflow/contrib/cloud:cloud_py",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/coder:coder_py",
"//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/contrib/autograph",
"//tensorflow/contrib/constrained_optimization",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
"//tensorflow/contrib/data",
- "//tensorflow/contrib/distribute:distribute",
"//tensorflow/contrib/deprecated:deprecated_py",
+ "//tensorflow/contrib/distribute:distribute",
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/feature_column:feature_column_py",
"//tensorflow/contrib/framework:framework_py",
- "//tensorflow/contrib/fused_conv:fused_conv_py",
"//tensorflow/contrib/gan",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/contrib/grid_rnn:grid_rnn_py",
@@ -83,7 +83,6 @@ py_library(
"//tensorflow/contrib/proto",
"//tensorflow/contrib/quantization:quantization_py",
"//tensorflow/contrib/quantize:quantize_graph",
- "//tensorflow/contrib/autograph",
"//tensorflow/contrib/receptive_field:receptive_field_py",
"//tensorflow/contrib/recurrent:recurrent_py",
"//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py",
@@ -114,6 +113,7 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:util",
+ "//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([
"//tensorflow/contrib/tensorrt:init_py",
]) + select({
@@ -122,7 +122,17 @@ py_library(
"//tensorflow/contrib/kafka",
],
"//conditions:default": [],
- }) + if_not_windows([
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis",
+ ],
+ "//conditions:default": [],
+ }) + if_not_windows_cuda([
+ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
+ ]) + if_not_windows([
+ "//tensorflow/contrib/bigtable", # depends on bigtable
+ "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/lite/python:lite", # unix dependency, need to fix code
]),
@@ -153,6 +163,12 @@ cc_library(
"//tensorflow/contrib/kafka:dataset_kernels",
],
"//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis:dataset_kernels",
+ ],
+ "//conditions:default": [],
}),
)
@@ -182,5 +198,11 @@ cc_library(
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
],
"//conditions:default": [],
+ }) + select({
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support": [
+ "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ ],
+ "//conditions:default": [],
}),
)
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9aad772f0a..ded05da718 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -25,7 +25,8 @@ import os
from tensorflow.contrib import batching
from tensorflow.contrib import bayesflow
from tensorflow.contrib import checkpoint
-from tensorflow.contrib import cloud
+if os.name != "nt":
+ from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
from tensorflow.contrib import coder
from tensorflow.contrib import compiler
diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD
index 62d1b1cf07..881808a98b 100644
--- a/tensorflow/contrib/all_reduce/BUILD
+++ b/tensorflow/contrib/all_reduce/BUILD
@@ -12,6 +12,16 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_py_test")
py_library(
+ name = "all_reduce_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":all_reduce",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
name = "all_reduce",
srcs = [
"python/all_reduce.py",
diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py
new file mode 100644
index 0000000000..f9824f4cfb
--- /dev/null
+++ b/tensorflow/contrib/all_reduce/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""All-reduce implementations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.contrib.all_reduce.python.all_reduce import *
+
+from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long,wildcard-import
+
+_allowed_symbols = [
+ 'build_ring_all_reduce',
+ 'build_recursive_hd_all_reduce',
+ 'build_shuffle_all_reduce',
+ 'build_nccl_all_reduce',
+ 'build_nccl_then_ring',
+ 'build_nccl_then_recursive_hd',
+ 'build_nccl_then_shuffle',
+ 'build_shuffle_then_ring',
+ 'build_shuffle_then_shuffle'
+]
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD
index c10179ba8b..f0b1c92cf7 100644
--- a/tensorflow/contrib/android/BUILD
+++ b/tensorflow/contrib/android/BUILD
@@ -1,6 +1,8 @@
# Description:
# JNI-based Java inference interface for TensorFlow.
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD
index 30dd846893..ad700ac4a0 100644
--- a/tensorflow/contrib/autograph/BUILD
+++ b/tensorflow/contrib/autograph/BUILD
@@ -23,9 +23,9 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/autograph/impl",
+ "//tensorflow/contrib/autograph/lang",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/utils",
- "@gast_archive//:gast",
- "@six_archive//:six",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md
index a4aec8c74a..06fb7b03d5 100644
--- a/tensorflow/contrib/autograph/CONTRIBUTING.md
+++ b/tensorflow/contrib/autograph/CONTRIBUTING.md
@@ -1,4 +1,4 @@
-# How to Contribute
+# How to contribute
We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
@@ -46,3 +46,50 @@ bazel test --config=opt --copt=-O3 --copt=-march=native \
```
from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md)
+
+## Developer info
+
+### Module structure
+
+The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat):
+
+```dot
+digraph d_modules {
+ autograph [style=filled];
+ converters;
+ core;
+ impl;
+ lang;
+ operators;
+
+ autograph -> impl
+ autograph -> lang
+
+ impl -> converters
+ impl -> core
+ impl -> operators
+
+ lang -> operators
+
+ converters -> core
+ converters -> lang
+}
+```
+
+`autograph` is the sole user-visible module.
+
+A short description of the modules:
+
+ * `autograph`: the main module imported by the user and by the generated code; only contains declarations
+ * `impl`: high level code and the implementation of the api frontend
+ * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py`
+ * `lang`: special user-visible functions that serve as extensions to the Python language
+ * `converters`: collection of source code transformation modules specialized for particular AutoGraph features
+ * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow
+
+There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph:
+
+ * `pyct`: a general purpose Python source code transformation library
+ * `utils`: the kitchen sync; deprecated
+
+Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library.
diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/contrib/autograph/LIMITATIONS.md
new file mode 100644
index 0000000000..d8b1cb7616
--- /dev/null
+++ b/tensorflow/contrib/autograph/LIMITATIONS.md
@@ -0,0 +1,50 @@
+# Capabilities and Limitations
+
+TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`.
+
+Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support.
+
+# Python Language Support Status
+
+Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved.
+
+ Construct | Supported now? | Plan to support? | Notes
+ :--------- | :--------------: | :----------------: | :-----
+If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error.
+For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations.
+While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations.
+Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests.
+Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested.
+Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`.
+Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so.
+Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists.
+Static function calls | Yes | | Non-recursive function calls
+Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion.
+Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant.
+Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html).
+List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation.
+Function variables | Yes | | e.g. `f_new = f_orig; f_new()`
+Lambda functions | No | Yes | Planned feature.
+Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods.
+Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported.
+Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case.
+Dynamic code / exec | No | |
+Reflection | No | |
+Try / Except | No | No | No current sane TF equivalent.
+Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code.
+Functions with side effects | Some | | Side effects are allowed, under certain circumstances.
+Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples.
+List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority.
+Custom context managers | No | Yes | Currently low priority. Left unconverted currently.
+Generators | No | Maybe | Could be achievable using queues; very low priority.
+Assertions | Yes | | As `tf.Assert`
+Deletion | Yes | Maybe | Currently unconverted. If new semanti cs are required for `del`, we are able to add it in.
+Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority.
+Async | No | No |
+
+## Extra capabilities
+
+ - We liberally add name scopes to generated functions
+ - Operations get decent default names everywhere (planned)
+ - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially.
+
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 47b1d4a99a..679ab48e5c 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -4,7 +4,7 @@ IMPORTANT: AutoGraph is alpha software, and under active development. Expect rou
AutoGraph is a Python to TensorFlow compiler.
-With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops.
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
For example, this Python function:
@@ -120,3 +120,15 @@ You can use the functional API to inspect the generated code as well:
print(ag.to_code(f))
# Output: <Python and TensorFlow code>
```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md
index 866e5f583a..7e6b0cc27d 100644
--- a/tensorflow/contrib/autograph/STYLE_GUIDE.md
+++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md
@@ -20,7 +20,17 @@ Naming conventions:
Below are AutoGraph-specific conventions. In the event of conflict,
it supercedes all previous conventions.
-1. __Citations in Docstrings.__ Write a `#### References` subsection at the
+1. __Types in docstrings.__ Use [PEP 484][https://www.python.org/dev/peps/pep-0484/]
+ notation to describe the type for args, return values and attributes.
+
+ Example:
+
+ ```
+ Args:
+ foo: Dict[str, List[int]], a dictionary of sorts
+ ```
+
+2. __Citations in Docstrings.__ Write a `#### References` subsection at the
bottom of any docstring with citations. Use ICLR’s bibliography style to
write references; for example, order entries by the first author's last
name. Add a link to the paper if the publication is open source (ideally,
@@ -60,12 +70,12 @@ it supercedes all previous conventions.
https://arxiv.org/abs/1803.04386
```
-2. Avoid LaTeX in docstrings.
+3. Avoid LaTeX in docstrings.
* It is not rendered in many (if not most) editors and can be hard to read
for both LaTeX experts and non-experts.
-3. Write docstring and comment math using ASCII friendly notation; python using
+4. Write docstring and comment math using ASCII friendly notation; python using
operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`,
`sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx:
x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`.
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index c86f7e4ede..361cf2d77c 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -30,6 +30,9 @@ from tensorflow.contrib.autograph.impl.api import do_not_convert
from tensorflow.contrib.autograph.impl.api import RunMode
from tensorflow.contrib.autograph.impl.api import to_code
from tensorflow.contrib.autograph.impl.api import to_graph
+from tensorflow.contrib.autograph.lang.directives import set_element_type
+from tensorflow.contrib.autograph.lang.directives import set_loop_options
+from tensorflow.contrib.autograph.lang.special_functions import stack
from tensorflow.contrib.autograph.pyct.transformer import AutographParseError
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,8 +44,11 @@ _allowed_symbols = [
'do_not_convert',
'to_code',
'to_graph',
- # Special functions and overloaded operators
+ # Overloaded operators
'operators',
+ # Python language "extensions"
+ 'set_element_type',
+ 'set_loop_options',
'stack',
# Exceptions
'AutographParseError',
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD
index 8f9bffa55e..b2e2e27673 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/contrib/autograph/converters/BUILD
@@ -31,29 +31,17 @@ py_library(
"name_scopes.py",
"side_effect_guards.py",
"single_return.py",
+ "slices.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "@gast_archive//:gast",
- ],
-)
-
-py_library(
- name = "test_lib",
- srcs = [
- "converter_test_base.py",
- ],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
- deps = [
- ":converters",
- "//tensorflow/contrib/autograph/operators",
+ "//tensorflow/contrib/autograph/core",
+ "//tensorflow/contrib/autograph/lang",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python:util",
"@gast_archive//:gast",
- "@six_archive//:six",
],
)
@@ -63,7 +51,8 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_windows"],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -73,7 +62,8 @@ py_test(
srcs = ["break_statements_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -84,7 +74,8 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_windows"],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -96,7 +87,8 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_windows"],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/contrib/autograph/impl",
"//tensorflow/python:client_testlib",
],
@@ -107,7 +99,8 @@ py_test(
srcs = ["continue_statements_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -117,7 +110,8 @@ py_test(
srcs = ["control_flow_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -126,8 +120,13 @@ py_test(
name = "decorators_test",
srcs = ["decorators_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -136,7 +135,8 @@ py_test(
name = "name_scopes_test",
srcs = ["name_scopes_test.py"],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
],
@@ -147,7 +147,8 @@ py_test(
srcs = ["list_comprehension_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -157,7 +158,8 @@ py_test(
srcs = ["lists_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -167,7 +169,8 @@ py_test(
srcs = ["logical_expressions_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -182,7 +185,8 @@ py_test(
"notap",
],
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
],
)
@@ -192,7 +196,8 @@ py_test(
srcs = ["single_return_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
],
@@ -203,7 +208,20 @@ py_test(
srcs = ["ifexp_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":test_lib",
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "slices_test",
+ srcs = ["slices_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":converters",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
],
diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py
index 3b0db677ce..e664a403a5 100644
--- a/tensorflow/contrib/autograph/converters/asserts.py
+++ b/tensorflow/contrib/autograph/converters/asserts.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-class AssertsTransformer(transformer.Base):
+class AssertsTransformer(converter.Base):
"""Transforms Print nodes to Call so they can be handled as functions."""
def visit_Assert(self, node):
@@ -45,5 +45,5 @@ class AssertsTransformer(transformer.Base):
raise NotImplementedError('can only convert string messages for now.')
-def transform(node, context):
- return AssertsTransformer(context).visit(node)
+def transform(node, ctx):
+ return AssertsTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/contrib/autograph/converters/asserts_test.py
index cc913febe8..2cd0e626bc 100644
--- a/tensorflow/contrib/autograph/converters/asserts_test.py
+++ b/tensorflow/contrib/autograph/converters/asserts_test.py
@@ -21,11 +21,11 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.platform import test
-class AssertsTest(converter_test_base.TestCase):
+class AssertsTest(converter_testing.TestCase):
def test_transform(self):
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 775d92c1d9..a990e359a2 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -29,7 +29,7 @@ BREAK_USED = 'break_used'
CONTROL_VAR_NAME = 'control_var_name'
-class BreakStatementTransformer(transformer.Base):
+class BreakStatementTransformer(converter.Base):
"""Canonicalizes break statements into additional conditionals."""
def visit_Break(self, node):
@@ -67,7 +67,7 @@ class BreakStatementTransformer(transformer.Base):
def visit_While(self, node):
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.context.namer.new_symbol('break_', scope.referenced)
+ break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
node.test = self.visit(node.test)
node.body, break_used = self._track_body(node.body, break_var)
@@ -97,7 +97,7 @@ class BreakStatementTransformer(transformer.Base):
def visit_For(self, node):
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- break_var = self.context.namer.new_symbol('break_', scope.referenced)
+ break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
node.target = self.visit(node.target)
node.iter = self.visit(node.iter)
@@ -137,5 +137,5 @@ class BreakStatementTransformer(transformer.Base):
return node
-def transform(node, context):
- return BreakStatementTransformer(context).visit(node)
+def transform(node, ctx):
+ return BreakStatementTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py
index 1af59e9b52..dcff1c54c2 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/break_statements_test.py
@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.platform import test
-class BreakCanonicalizationTest(converter_test_base.TestCase):
+class BreakCanonicalizationTest(converter_testing.TestCase):
def test_basic_while(self):
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index 46e39da16a..b26c52294c 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-class BuiltinFunctionTransformer(transformer.Base):
+class BuiltinFunctionTransformer(converter.Base):
"""Handles builtin functions.
This transformer only covers functions that are translated into a
@@ -48,7 +48,7 @@ class BuiltinFunctionTransformer(transformer.Base):
# TODO(mdan): This won't work if the function was hidden.
# TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
if (isinstance(node.func, gast.Name) and
- node.func.id in ('len', 'range', 'xrange')):
+ node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
return self._convert_builtin(node)
# Print needs to be handled separately because it can be read as statement.
if isinstance(node.func, gast.Name) and node.func.id == 'print':
@@ -68,5 +68,5 @@ class BuiltinFunctionTransformer(transformer.Base):
return self.visit(function_call)
-def transform(node, context):
- return BuiltinFunctionTransformer(context).visit(node)
+def transform(node, ctx):
+ return BuiltinFunctionTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index 30272409df..e9000e518c 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -23,13 +23,13 @@ import sys
import six
from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class BuiltinFunctionsTest(converter_test_base.TestCase):
+class BuiltinFunctionsTest(converter_testing.TestCase):
def test_len(self):
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py
index b6ecdcb780..a36b3d77a9 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/contrib/autograph/converters/call_trees.py
@@ -26,12 +26,12 @@ from collections import namedtuple
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -45,6 +45,9 @@ KNOWN_NUMPY_FUNCTIONS = {
}
+# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer.
+
+
class FunctionNamer(object):
"""Describes the interface for CallTreeTransformer's namer."""
@@ -76,20 +79,18 @@ class FunctionNamer(object):
raise NotImplementedError()
-class CallTreeTransformer(transformer.Base):
- """Transforms the call tree by renaming transformed symbols."""
+# TODO(mdan): Rename to CallsTransformer.
- def __init__(self, context, uncompiled_modules, nocompile_decorators):
- super(CallTreeTransformer, self).__init__(context)
- self.uncompiled_modules = uncompiled_modules
- self.nocompile_decorators = nocompile_decorators
+
+class CallTreeTransformer(converter.Base):
+ """Transforms the call tree by renaming transformed symbols."""
def _resolve_name(self, node):
"""Used to resolve decorator info."""
if isinstance(node, gast.Call):
return self._resolve_name(node.func)
if isinstance(node, gast.Name):
- return self.context.namespace.get(node.id)
+ return self.ctx.namespace.get(node.id)
if isinstance(node, gast.Attribute):
parent = self._resolve_name(node.value)
if parent is not None:
@@ -119,12 +120,12 @@ class CallTreeTransformer(transformer.Base):
"""Determines whether an entity should be compiled in the context."""
# TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
module_name = fqn[0]
- for mod in self.uncompiled_modules:
+ for mod in self.ctx.program.uncompiled_modules:
if module_name.startswith(mod[0] + '.'):
return False
for i in range(1, len(fqn)):
- if fqn[:i] in self.uncompiled_modules:
+ if fqn[:i] in self.ctx.program.uncompiled_modules:
return False
# Check for local decorations
@@ -140,7 +141,7 @@ class CallTreeTransformer(transformer.Base):
if hasattr(target_entity, '__pyct_is_compile_decorator'):
return False
- if target_entity in self.nocompile_decorators:
+ if target_entity in self.ctx.program.autograph_decorators:
return False
# Inspect the target function decorators. If any include a @convert
@@ -159,7 +160,7 @@ class CallTreeTransformer(transformer.Base):
for dec in target_node.decorator_list:
decorator_fn = self._resolve_name(dec)
if (decorator_fn is not None and
- decorator_fn in self.nocompile_decorators):
+ decorator_fn in self.ctx.program.autograph_decorators):
return False
return True
@@ -174,7 +175,7 @@ class CallTreeTransformer(transformer.Base):
return node
if anno.hasanno(node, 'is_constructor'):
- new_name = self.context.namer.compiled_class_name(
+ new_name = self.ctx.namer.compiled_class_name(
target_fqn, live_entity=target_entity)
do_rename = True
else:
@@ -183,7 +184,7 @@ class CallTreeTransformer(transformer.Base):
else:
# Fallback - not reliable.
owner_type = inspect_utils.getmethodclass(target_entity)
- new_name, do_rename = self.context.namer.compiled_function_name(
+ new_name, do_rename = self.ctx.namer.compiled_function_name(
target_fqn, live_entity=target_entity, owner_type=owner_type)
if do_rename:
@@ -264,15 +265,16 @@ class CallTreeTransformer(transformer.Base):
return node
def visit_Call(self, node):
- # If the function is wrapped by one of the marker decorators,
+ # If the function call is wrapped by one of the marker decorators,
# consider it graph ready.
if anno.hasanno(node.func, 'live_val'):
target_entity = anno.getanno(node.func, 'live_val')
- if target_entity in self.nocompile_decorators:
+ if target_entity in self.ctx.program.autograph_decorators:
if len(node.args) < 1:
raise ValueError(
'Found call to decorator function "%s", but it had no arguments. '
- 'A decorator needs at least an argument.')
+ 'A decorator needs at least one positional argument.' %
+ target_entity)
anno.setanno(node.args[0], 'graph_ready', True)
self.generic_visit(node)
@@ -309,27 +311,20 @@ class CallTreeTransformer(transformer.Base):
# ensure that they return the correct value.
return node
- if self.context.recursive:
+ if self.ctx.program.recursive:
node = self._insert_dynamic_conversion(node)
return node
-def transform(node, context, uncompiled_modules, nocompile_decorators):
+def transform(node, ctx):
"""Transform function call to the compiled counterparts.
Args:
- node: AST to transform.
- context: An EntityContext object.
- uncompiled_modules: set of string tuples, each tuple represents the fully
- qualified name of a package containing functions that will not be
- compiled.
- nocompile_decorators: A tuple containing decorators to be stripped from
- functions during conversion.
+ node: AST
+ ctx: EntityContext
Returns:
A tuple (node, new_names):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
- t = CallTreeTransformer(context, uncompiled_modules, nocompile_decorators)
- node = t.visit(node)
- return node
+ return CallTreeTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py
index 303dd54a4e..27d8281b85 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/contrib/autograph/converters/call_trees_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -29,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class CallTreesTest(converter_test_base.TestCase):
+class CallTreesTest(converter_testing.TestCase):
def test_basic(self):
@@ -43,7 +43,7 @@ class CallTreesTest(converter_test_base.TestCase):
return test_fn_1(a) + 1
node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
- node = call_trees.transform(node, self.ctx, (), ())
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node) as result:
# Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1
@@ -60,7 +60,7 @@ class CallTreesTest(converter_test_base.TestCase):
return f() + 3
node = self.parse_and_analyze(test_fn_2, {})
- node = call_trees.transform(node, self.ctx, (), ())
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node) as result:
# 10 = 7 (from the mock) + 3 (from test_fn_2)
@@ -78,9 +78,9 @@ class CallTreesTest(converter_test_base.TestCase):
node = self.parse_and_analyze(
TestClass.test_fn_2, {'TestClass': TestClass},
- namer=converter_test_base.FakeNoRenameNamer(),
+ namer=converter_testing.FakeNoRenameNamer(),
arg_types={'self': (TestClass.__name__, TestClass)})
- node = call_trees.transform(node, self.ctx, (), ())
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node) as result:
tc = TestClass()
@@ -92,7 +92,7 @@ class CallTreesTest(converter_test_base.TestCase):
setattr(a, 'foo', 'bar')
node = self.parse_and_analyze(test_fn, {'setattr': setattr})
- node = call_trees.transform(node, self.ctx, (), ())
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node) as result:
with self.test_session() as sess:
@@ -115,7 +115,7 @@ class CallTreesTest(converter_test_base.TestCase):
return np.random.binomial(2, 0.5)
node = self.parse_and_analyze(test_fn, {'np': np})
- node = call_trees.transform(node, self.ctx, (), ())
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node, dtypes.int64) as result:
result.np = np
@@ -130,13 +130,13 @@ class CallTreesTest(converter_test_base.TestCase):
a = math_ops.add(a, constant_op.constant(1))
return a
- node = self.parse_and_analyze(test_fn, {
- 'math_ops': math_ops,
- 'constant_op': constant_op
- })
- node = call_trees.transform(node, self.ctx,
- set(((math_ops.__name__,),
- (constant_op.__name__,))), ())
+ node = self.parse_and_analyze(
+ test_fn, {
+ 'math_ops': math_ops,
+ 'constant_op': constant_op
+ },
+ arg_types=set(((math_ops.__name__,), (constant_op.__name__,))))
+ node = call_trees.transform(node, self.ctx)
with self.compiled(node) as result:
result.math_ops = math_ops
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py
index 0417817a77..958bde0a58 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -31,7 +31,7 @@ GUARD_CREATED = 'guard_created'
CREATE_GUARD_NEXT = 'create_guard_next'
-class ContinueCanonicalizationTransformer(transformer.Base):
+class ContinueCanonicalizationTransformer(converter.Base):
"""Canonicalizes continue statements into additional conditionals."""
def visit_Continue(self, node):
@@ -85,7 +85,7 @@ class ContinueCanonicalizationTransformer(transformer.Base):
def _visit_loop_body(self, node, nodes):
self.enter_local_scope()
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
- continue_var = self.context.namer.new_symbol('continue_', scope.referenced)
+ continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced)
self.set_local(CONTROL_VAR_NAME, continue_var)
nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
@@ -135,5 +135,5 @@ class ContinueCanonicalizationTransformer(transformer.Base):
return node
-def transform(node, namer):
- return ContinueCanonicalizationTransformer(namer).visit(node)
+def transform(node, ctx):
+ return ContinueCanonicalizationTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/contrib/autograph/converters/continue_statements_test.py
index bcbb316d74..2ce1837972 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/contrib/autograph/converters/continue_statements_test.py
@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.platform import test
-class ContinueCanonicalizationTest(converter_test_base.TestCase):
+class ContinueCanonicalizationTest(converter_testing.TestCase):
def test_basic_continue(self):
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index d7ddbe8a04..f4a8710627 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import cfg
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -45,9 +45,8 @@ class SymbolNamer(object):
raise NotImplementedError()
-class ControlFlowTransformer(transformer.Base):
+class ControlFlowTransformer(converter.Base):
"""Transforms control flow structures like loops an conditionals."""
-
def _create_cond_branch(self, body_name, aliased_orig_names,
aliased_new_names, body, returns):
if aliased_orig_names:
@@ -141,10 +140,10 @@ class ControlFlowTransformer(transformer.Base):
aliased_orelse_orig_names = tuple(orelse_scope.modified -
orelse_scope.created)
aliased_body_new_names = tuple(
- self.context.namer.new_symbol(s.ssf(), body_scope.referenced)
+ self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
for s in aliased_body_orig_names)
aliased_orelse_new_names = tuple(
- self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced)
+ self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
for s in aliased_orelse_orig_names)
alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
@@ -165,9 +164,8 @@ class ControlFlowTransformer(transformer.Base):
else:
results = gast.Tuple([s.ast() for s in modified], None)
- body_name = self.context.namer.new_symbol('if_true', body_scope.referenced)
- orelse_name = self.context.namer.new_symbol('if_false',
- orelse_scope.referenced)
+ body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
+ orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
if modified:
def build_returns(aliased_names, alias_map, scope):
@@ -235,7 +233,7 @@ class ControlFlowTransformer(transformer.Base):
raise ValueError('cannot convert while loop: no outputs')
state_ssf = [
- self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
]
ssf_map = {
name: ssf
@@ -267,11 +265,9 @@ class ControlFlowTransformer(transformer.Base):
state=state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- test_name=self.context.namer.new_symbol('loop_test',
- body_scope.referenced),
+ test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced),
test=test,
- body_name=self.context.namer.new_symbol('loop_body',
- body_scope.referenced),
+ body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced),
body=node_body,
extra_deps=tuple(s.ast() for s in cond_closure),
)
@@ -288,7 +284,7 @@ class ControlFlowTransformer(transformer.Base):
state = list(body_closure)
state_ssf = [
- self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+ self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
]
ssf_map = {
name: ssf
@@ -326,17 +322,16 @@ class ControlFlowTransformer(transformer.Base):
state_ast_tuple=state_ast_tuple,
iter_=node.iter,
iterate=node.target,
- extra_test_name=self.context.namer.new_symbol('extra_test',
- all_referenced),
+ extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
extra_test_expr=extra_test,
- body_name=self.context.namer.new_symbol('loop_body', all_referenced),
+ body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
body=node_body)
return node
-def transform(node, context):
- cfg.run_analyses(node, cfg.Liveness(context))
- cfg.run_analyses(node, cfg.Defined(context))
- node = ControlFlowTransformer(context).visit(node)
+def transform(node, ctx):
+ cfg.run_analyses(node, cfg.Liveness(ctx.info))
+ cfg.run_analyses(node, cfg.Defined(ctx.info))
+ node = ControlFlowTransformer(ctx).visit(node)
return node
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 1a863590f9..735eb92a0d 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.converters import converter_test_base
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -27,7 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
-class ControlFlowTest(converter_test_base.TestCase):
+class ControlFlowTest(converter_testing.TestCase):
def test_simple_while(self):
@@ -42,7 +42,7 @@ class ControlFlowTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- with self.compiled(node, control_flow_ops.while_loop) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
self.assertEqual((10, 5, 5),
sess.run(result.test_fn(constant_op.constant(5))))
@@ -57,7 +57,7 @@ class ControlFlowTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- with self.compiled(node, control_flow_ops.while_loop) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
@@ -75,7 +75,7 @@ class ControlFlowTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- with self.compiled(node, control_flow_ops.cond) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
self.assertEqual((-1, 0),
sess.run(result.test_fn(constant_op.constant(1))))
@@ -92,7 +92,7 @@ class ControlFlowTest(converter_test_base.TestCase):
node = self.parse_and_analyze(test_fn, {})
node = control_flow.transform(node, self.ctx)
- with self.compiled(node, control_flow_ops.cond) as result:
+ with self.compiled(node) as result:
with self.test_session() as sess:
self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/contrib/autograph/converters/decorators.py
index 92445f3174..3471bd11d6 100644
--- a/tensorflow/contrib/autograph/converters/decorators.py
+++ b/tensorflow/contrib/autograph/converters/decorators.py
@@ -24,19 +24,14 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.util import tf_inspect
-class DecoratorsTransformer(gast.NodeTransformer):
+class DecoratorsTransformer(converter.Base):
"""Converts or removes decorators."""
- def __init__(self, remove_decorators):
- self.remove_decorators = remove_decorators
- self.additional_dependencies = set()
-
- # pylint:disable=invalid-name
-
def visit_FunctionDef(self, node):
self.generic_visit(node)
kept_decorators = []
@@ -58,31 +53,53 @@ class DecoratorsTransformer(gast.NodeTransformer):
# This is currently verified by tests.
continue
- if not anno.hasanno(dec_func, 'live_val'):
- raise ValueError(
- 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func))
-
+ original_dec = anno.getanno(dec_func, anno.Basic.QN)
dec_value = anno.getanno(dec_func, 'live_val')
- if dec_value not in self.remove_decorators:
- kept_decorators.append((dec, dec_value))
- for _, dec_value in kept_decorators:
- if dec_value.__module__ == '__main__':
+ if dec_value in self.ctx.program.autograph_decorators:
+ # AutoGraph decorators do not need to be preserved.
+ continue
+
+ # When using foo.bar.baz, we only really need to grab foo and import
+ # that.
+ dec_support_node = dec_func
+ while isinstance(dec_support_node, gast.Attribute):
+ dec_support_node = dec_support_node.value
+
+ if not anno.hasanno(dec_support_node, 'live_val'):
raise ValueError(
- 'decorator "%s" was not allowed because it is declared '
- 'in the module "%s". To fix this, declare it in a separate '
- 'module that we can import it from.' % (dec_value,
- dec_value.__module__))
+ 'could not resolve symbol "%s" when looking up decorator "%s"' %
+ (anno.getanno(dec_support_node, anno.Basic.QN), original_dec))
+
+ dec_support = anno.getanno(dec_support_node, 'live_val')
+ # The tuple contains:
+ # * the AST that represents the decorator
+ # * the entity supporting the decorator (i.e., what we need to import)
+ # * the name of the module that needs to be imported for this decorator
+ # to properly resolve.
+ # Examples:
+ # for foo.bar, the tuple is (<ast>, <module foo>, 'foo')
+ # for baz, the tuple is (<ast>, <module baz.__module__>, 'baz')
+ kept_decorators.append((dec, dec_support,
+ anno.getanno(dec_support_node, anno.Basic.QN)))
+
+ for _, dec_support, name in kept_decorators:
+ if tf_inspect.ismodule(dec_support):
+ self.ctx.program.additional_imports.add(
+ 'import %s as %s' % (dec_support.__name__, name))
else:
- self.additional_dependencies.add(dec_value)
-
- node.decorator_list = [dec for dec, _ in kept_decorators]
+ if dec_support.__module__ == '__main__':
+ raise ValueError(
+ 'decorator "%s" was not allowed because it is declared '
+ 'in the module "%s". To fix this, declare it in a separate '
+ 'module that we can import it from.' % (dec_support,
+ dec_support.__module__))
+ self.ctx.program.additional_imports.add(
+ 'from %s import %s' % (dec_support.__module__, name))
+
+ node.decorator_list = [dec for dec, _, _ in kept_decorators]
return node
- # pylint:enable=invalid-name
-
-def transform(node, remove_decorators):
- transformer = DecoratorsTransformer(remove_decorators)
- node = transformer.visit(node)
- return node, transformer.additional_dependencies
+def transform(node, ctx):
+ return DecoratorsTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/contrib/autograph/converters/decorators_test.py
index 9c01f68912..d41c7fde24 100644
--- a/tensorflow/contrib/autograph/converters/decorators_test.py
+++ b/tensorflow/contrib/autograph/converters/decorators_test.py
@@ -20,9 +20,10 @@ from __future__ import print_function
from functools import wraps
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import decorators
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.platform import test
@@ -39,28 +40,35 @@ def simple_decorator(f):
return lambda a: f(a) + 1
-def self_removing_decorator(removing_wrapper):
+def self_transform_decorator(transform):
+
def decorator(f):
@wraps(f)
def wrapper(*args):
# This removing wrapper is defined in the test below. This setup is so
- # intricate just to simulate how we use the transformer in practice.
- transformed_f = removing_wrapper(f, (self_removing_decorator,))
+ # intricate in order to simulate how we use the transformer in practice.
+ transformed_f = transform(f, (self_transform_decorator,))
return transformed_f(*args) + 1
return wrapper
return decorator
-class DecoratorsTest(converter_test_base.TestCase):
+class DecoratorsTest(converter_testing.TestCase):
- def _remover_wrapper(self, f, remove_decorators):
+ def _transform(self, f, autograph_decorators):
namespace = {
- 'self_removing_decorator': self_removing_decorator,
- 'simple_decorator': simple_decorator
+ 'self_transform_decorator': self_transform_decorator,
+ 'simple_decorator': simple_decorator,
+ 'converter_testing': converter_testing,
}
- node = self.parse_and_analyze(f, namespace)
- node, _ = decorators.transform(node, remove_decorators=remove_decorators)
- result, _ = compiler.ast_to_object(node)
+ node = self.parse_and_analyze(
+ f,
+ namespace,
+ recursive=False,
+ autograph_decorators=autograph_decorators)
+ node = decorators.transform(node, self.ctx)
+ import_line = '\n'.join(self.ctx.program.additional_imports)
+ result, _ = compiler.ast_to_object(node, source_prefix=import_line)
return getattr(result, f.__name__)
def test_noop(self):
@@ -69,15 +77,14 @@ class DecoratorsTest(converter_test_base.TestCase):
return a
node = self.parse_and_analyze(test_fn, {})
- node, deps = decorators.transform(node, remove_decorators=())
+ node = decorators.transform(node, self.ctx)
result, _ = compiler.ast_to_object(node)
- self.assertFalse(deps)
self.assertEqual(1, result.test_fn(1))
def test_function(self):
- @self_removing_decorator(self._remover_wrapper)
+ @self_transform_decorator(self._transform)
def test_fn(a):
return a
@@ -88,7 +95,7 @@ class DecoratorsTest(converter_test_base.TestCase):
class TestClass(object):
- @self_removing_decorator(self._remover_wrapper)
+ @self_transform_decorator(self._transform)
def test_fn(self, a):
return a
@@ -101,38 +108,39 @@ class DecoratorsTest(converter_test_base.TestCase):
# Note that reversing the order of this two doesn't work.
@classmethod
- @self_removing_decorator(self._remover_wrapper)
+ @self_transform_decorator(self._transform)
def test_fn(cls, a):
return a
# 2 = 1 (a) + 1 (decorator applied exactly once)
self.assertEqual(2, TestClass.test_fn(1))
- def test_nested_decorators(self):
+ def test_nested_decorators_local(self):
- @self_removing_decorator(self._remover_wrapper)
+ @self_transform_decorator(self._transform)
def test_fn(a):
@simple_decorator
def inner_fn(b):
return b + 11
return inner_fn(a)
- with self.assertRaises(ValueError):
+ # Expected to fail because simple_decorator cannot be imported.
+ with self.assertRaises(transformer.AutographParseError):
test_fn(1)
- # TODO(mdan): Uncomment this test once converter_test_base is updated.
- # (can't do it now because it has unrelated pending changes)
- # def test_nested_decorators(self):
- #
- # @self_removing_decorator(self._remover_wrapper)
- # def test_fn(a):
- # @imported_decorator
- # def inner_fn(b):
- # return b + 11
- # return inner_fn(a)
- #
- # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
- # self.assertEqual(14, test_fn(1))
+ def test_nested_decorators_imported(self):
+
+ @self_transform_decorator(self._transform)
+ def test_fn(a):
+
+ @converter_testing.imported_decorator
+ def inner_fn(b):
+ return b + 11
+
+ return inner_fn(a)
+
+ # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
+ self.assertEqual(14, test_fn(1))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py
index 616d222762..e996138498 100644
--- a/tensorflow/contrib/autograph/converters/ifexp.py
+++ b/tensorflow/contrib/autograph/converters/ifexp.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-class IfExp(transformer.Base):
+class IfExp(converter.Base):
"""Canonicalizes all IfExp nodes into plain conditionals."""
def visit_IfExp(self, node):
@@ -34,16 +34,16 @@ class IfExp(transformer.Base):
return desugared_ifexp
-def transform(node, context):
+def transform(node, ctx):
"""Desugar IfExp nodes into plain conditionals.
Args:
- node: an AST node to transform
- context: a context object
+ node: ast.AST, the node to transform
+ ctx: converter.EntityContext
Returns:
new_node: an AST with no IfExp nodes, only conditionals.
"""
- node = IfExp(context).visit(node)
+ node = IfExp(ctx).visit(node)
return node
diff --git a/tensorflow/contrib/autograph/converters/ifexp_test.py b/tensorflow/contrib/autograph/converters/ifexp_test.py
index ac6849dcb4..cdd5a2f591 100644
--- a/tensorflow/contrib/autograph/converters/ifexp_test.py
+++ b/tensorflow/contrib/autograph/converters/ifexp_test.py
@@ -19,12 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import ifexp
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.platform import test
-class IfExpTest(converter_test_base.TestCase):
+class IfExpTest(converter_testing.TestCase):
def compiled_fn(self, test_fn, *args):
node = self.parse_and_analyze(test_fn, {})
diff --git a/tensorflow/contrib/autograph/converters/list_comprehension.py b/tensorflow/contrib/autograph/converters/list_comprehension.py
index d7f2920151..c4a13ee822 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehension.py
+++ b/tensorflow/contrib/autograph/converters/list_comprehension.py
@@ -31,17 +31,14 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-class ListCompCanonicalizationTransformer(transformer.Base):
+class ListCompCanonicalizationTransformer(converter.Base):
"""NodeTransformer to canonicalize list comprehensions."""
- def __init__(self, context):
- super(ListCompCanonicalizationTransformer, self).__init__(context)
-
def make_update_list_node(self, list_, elt):
return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0]
@@ -76,5 +73,5 @@ class ListCompCanonicalizationTransformer(transformer.Base):
return make_list + loop_body
-def transform(node, context):
- return ListCompCanonicalizationTransformer(context).visit(node)
+def transform(node, ctx):
+ return ListCompCanonicalizationTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/list_comprehension_test.py b/tensorflow/contrib/autograph/converters/list_comprehension_test.py
index 4758671f5e..2bbee93412 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehension_test.py
+++ b/tensorflow/contrib/autograph/converters/list_comprehension_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import list_comprehension
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.platform import test
-class ListCompTest(converter_test_base.TestCase):
+class ListCompTest(converter_testing.TestCase):
def test_basic(self):
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py
index b49521b2c3..d77a044798 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/contrib/autograph/converters/lists.py
@@ -32,85 +32,196 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.python.framework import dtypes
+from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-class ListTransformer(transformer.Base):
+# Tags for local state.
+POP_USES = 'pop_uses'
+
+
+class ListTransformer(converter.Base):
"""Converts lists and related operations to their TF counterpart."""
- def _empty_list(self, node):
- if not anno.hasanno(node, 'element_type'):
- raise NotImplementedError(
- 'type inference for empty lists is not yet supported; '
- 'use set_element_type(<list>, <dtype>) to continue')
- dtype = anno.getanno(node, 'element_type')
- if not isinstance(dtype, dtypes.DType):
- # TODO(mdan): Allow non-TF dtypes?
- # That would be consistent with the dynamic dispatch pattern, but
- # we must make sure that doesn't become confusing.
- raise NotImplementedError('element type "%s" not yet supported' % dtype)
-
- dtype_name = dtype.name
- # TODO(mdan): Does it ever make sense not to use tensor lists?
+ def visit_List(self, node):
+ node = self.generic_visit(node)
template = """
- tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True)
+ ag__.new_list(elements)
"""
- return templates.replace_as_expression(template, dtype_name=dtype_name)
+ return templates.replace_as_expression(template, elements=node)
- def _pre_populated_list(self, node):
- raise NotImplementedError('pre-populated lists')
+ def _replace_append_call(self, node):
+ assert len(node.args) == 1
+ assert isinstance(node.func, gast.Attribute)
+ template = """
+ target = ag__.list_append(target, element)
+ """
+ return templates.replace(
+ template,
+ target=node.func.value,
+ element=node.args[0])
+
+ def _replace_pop_call(self, node):
+ # Expressions that use pop() are converted to a statement + expression.
+ #
+ # For example:
+ #
+ # print(target.pop())
+ #
+ # ... is converted to:
+ #
+ # target, target_pop = ag__.list_pop(target)
+ # print(target_pop)
+ #
+ # Here, we just generate the variable name and swap it in,
+ # and _generate_pop_operation will handle the rest.
+ #
+ # Multiple uses of pop() are allowed:
+ #
+ # print(tartget.pop(), target.pop())
+ # print(tartget.pop().pop())
+ #
+ assert isinstance(node.func, gast.Attribute)
+ scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+ target_node = node.func.value
+
+ # Attempt to use a related name if can get one. Otherwise use something
+ # generic.
+ if anno.hasanno(target_node, anno.Basic.QN):
+ target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
+ else:
+ target_name = 'list'
+ pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)
+
+ pop_uses = self.get_local(POP_USES, [])
+ pop_uses.append((node, pop_var_name))
+ self.set_local(POP_USES, pop_uses)
+
+ return templates.replace_as_expression('var_name', var_name=pop_var_name)
+
+ def _replace_stack_call(self, node):
+ assert len(node.args) == 1
+ dtype = anno.getanno(
+ node.args[0],
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ template = """
+ ag__.list_stack(
+ target,
+ opts=ag__.ListStackOpts(
+ element_dtype=dtype,
+ original_call=orig_call))
+ """
+ return templates.replace_as_expression(
+ template,
+ dtype=dtype,
+ target=node.args[0],
+ orig_call=node.func)
- def visit_Expr(self, node):
+ def visit_Call(self, node):
node = self.generic_visit(node)
- if isinstance(node.value, gast.Call):
- call_node = node.value
-
- if not anno.hasanno(call_node.func, anno.Basic.QN):
- return node
- qn = anno.getanno(call_node.func, anno.Basic.QN)
-
- if qn.qn[-1] == 'append' and (len(call_node.args) == 1):
- template = """
- target = ag__.utils.dynamic_list_append(target, element)
- """
- node = templates.replace(
- template,
- target=qn.parent.ast(),
- element=call_node.args[0])
+
+ # TODO(mdan): This is insufficient if target is a function argument.
+ # In the case of function arguments, we need to add the list to the
+ # function's return value, because it is being modified.
+ # TODO(mdan): Checking just the name is brittle, can it be improved?
+ if isinstance(node.func, gast.Attribute):
+ func_name = node.func.attr
+ if func_name == 'append' and (len(node.args) == 1):
+ node = self._replace_append_call(node)
+ elif func_name == 'pop' and (len(node.args) <= 1):
+ node = self._replace_pop_call(node)
+ elif func_name == 'stack' and (len(node.args) == 1):
+ node = self._replace_stack_call(node)
+
return node
- def _replace_list_constructors(self, targets, values):
- for target in targets:
- if (isinstance(target, (gast.Tuple, gast.List)) and
- isinstance(values, (gast.Tuple, gast.List))):
- n_targets = len(target.elts)
- for i in range(n_targets):
- target_el, value_el = target.elts[i], values.elts[i]
- values.elts[i] = self._replace_list_constructors(
- (target_el,), value_el)
- return values
- if isinstance(values, gast.List):
- if values.elts:
- return self._pre_populated_list(values)
- else:
- return self._empty_list(values)
- return values
-
- def visit_Assign(self, node):
- node = self.generic_visit(node)
+ def _generate_pop_operation(self, original_call_node, pop_var_name):
+ assert isinstance(original_call_node.func, gast.Attribute)
+
+ if original_call_node.args:
+ pop_element = original_call_node.args[0]
+ else:
+ pop_element = parser.parse_expression('None')
+ # The call will be something like "target.pop()", and the dtype is hooked to
+ # target, hence the func.value.
+ dtype = anno.getanno(
+ original_call_node.func.value,
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+ shape = anno.getanno(
+ original_call_node.func.value,
+ 'element_shape',
+ default=templates.replace_as_expression('None'))
+
+ template = """
+ target, pop_var_name = ag__.list_pop(
+ target, element,
+ opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
+ """
+ return templates.replace(
+ template,
+ target=original_call_node.func.value,
+ pop_var_name=pop_var_name,
+ element=pop_element,
+ dtype=dtype,
+ shape=shape)
+
+ def _postprocess_statement(self, node):
+ """Inserts any separate pop() calls that node may use."""
+ pop_uses = self.get_local(POP_USES, None)
+ if pop_uses:
+ replacements = []
+ for original_call_node, pop_var_name in pop_uses:
+ replacements.extend(
+ self._generate_pop_operation(original_call_node, pop_var_name))
+ replacements.append(node)
+ node = replacements
+ self.exit_local_scope()
+ return node, None
+
+ # TODO(mdan): Should we have a generic visit_block instead?
+ # Right now it feels that a visit_block would add too much magic that's
+ # hard to follow.
+
+ def _visit_and_process_block(self, block):
+ return self.visit_block(
+ block,
+ before_visit=self.enter_local_scope,
+ after_visit=self._postprocess_statement)
+
+ def visit_FunctionDef(self, node):
+ node.args = self.generic_visit(node.args)
+ node.decorator_list = self.visit_block(node.decorator_list)
+ node.body = self._visit_and_process_block(node.body)
+ return node
+
+ def visit_For(self, node):
+ node.target = self.visit(node.target)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_While(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
+
+ def visit_If(self, node):
+ node.test = self.visit(node.test)
+ node.body = self._visit_and_process_block(node.body)
+ node.orelse = self._visit_and_process_block(node.orelse)
+ return node
- # Only convert lists when they are assigned to a variable, e.g.:
- # l = []
- # TODO(mdan): A similar pattern exists in type_info.py
- # We should add a generic "unpack_assignment" function to the base
- # transformer, that has the same effect as applying some logic to the SSA
- # form.
- node.value = self._replace_list_constructors(node.targets, node.value)
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body = self._visit_and_process_block(node.body)
return node
-def transform(node, context):
- return ListTransformer(context).visit(node)
+def transform(node, ctx):
+ return ListTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py
index 74c6dc64f1..ea04097b28 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/contrib/autograph/converters/lists_test.py
@@ -19,77 +19,129 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import lists
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
-class ListTest(converter_test_base.TestCase):
+class ListTest(converter_testing.TestCase):
- def test_empty_annotated_list(self):
+ def test_empty_list(self):
def test_fn():
- l = []
- utils.set_element_type(l, dtypes.int32)
- l.append(1)
- return l
+ return []
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
- # TODO(mdan): Attach these additional modules automatically.
- result.utils = utils
- result.dtypes = dtypes
+ with self.compiled(node) as result:
+ tl = result.test_fn()
+ # Empty tensor lists cannot be evaluated or stacked.
+ self.assertTrue(isinstance(tl, ops.Tensor))
+ self.assertEqual(tl.dtype, dtypes.variant)
+
+ def test_initialized_list(self):
+
+ def test_fn():
+ return [1, 2, 3]
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
with self.test_session() as sess:
- self.assertAllEqual([1], sess.run(result.test_fn().stack()))
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
- def test_empty_annotated_lists_unpacked(self):
+ def test_list_append(self):
def test_fn():
- l, m = [], []
- utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
+ l = [1]
+ l.append(2)
+ l.append(3)
+ return l
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ node = self.parse_and_analyze(test_fn, {})
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node) as result:
+ with self.test_session() as sess:
+ tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2, 3])
+
+ def test_list_pop(self):
+
+ def test_fn():
+ l = [1, 2, 3]
+ utils.set_element_type(l, dtypes.int32, ())
+ s = l.pop()
+ return s, l
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ ts, tl = result.test_fn()
+ r = list_ops.tensor_list_stack(tl, dtypes.int32)
+ self.assertAllEqual(sess.run(r), [1, 2])
+ self.assertAllEqual(sess.run(ts), 3)
+
+ def test_double_list_pop(self):
- def test_empty_annotated_lists_list_unpacked(self):
+ def test_fn(l):
+ s = l.pop().pop()
+ return s
+
+ node = self.parse_and_analyze(test_fn, {})
+ node = lists.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ test_input = [1, 2, [1, 2, 3]]
+ # TODO(mdan): Pass a list of lists of tensor when we fully support that.
+ # For now, we just pass a regular Python list of lists just to verify that
+ # the two pop calls are sequenced properly.
+ self.assertAllEqual(result.test_fn(test_input), 3)
+
+ def test_list_stack(self):
+
+ tf = None # Will be replaced with a mock.
def test_fn():
- [l, m] = [], []
+ l = [1, 2, 3]
utils.set_element_type(l, dtypes.int32)
- utils.set_element_type(m, dtypes.int32)
- l.append(1)
- m.append(2)
- return l, m
-
- node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils})
+ return tf.stack(l)
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
node = lists.transform(node, self.ctx)
- with self.compiled(node, tensor_array_ops.TensorArray,
- dtypes.int32) as result:
+ with self.compiled(node, array_ops.stack, dtypes.int32) as result:
result.utils = utils
result.dtypes = dtypes
with self.test_session() as sess:
- res_l, res_m = result.test_fn()
- self.assertEqual([1], sess.run(res_l.stack()))
- self.assertEqual([2], sess.run(res_m.stack()))
+ self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py
index 3a795a315a..16eb1f0e3f 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions.py
@@ -23,10 +23,10 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
# TODO(mdan): Properly extrack boolean ops according to lazy eval rules.
@@ -39,11 +39,11 @@ from tensorflow.contrib.autograph.pyct import transformer
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
-class LogicalExpressionTransformer(transformer.Base):
+class LogicalExpressionTransformer(converter.Base):
"""Converts logical expressions to corresponding TF calls."""
- def __init__(self, context):
- super(LogicalExpressionTransformer, self).__init__(context)
+ def __init__(self, ctx):
+ super(LogicalExpressionTransformer, self).__init__(ctx)
# TODO(mdan): Look into replacing with bitwise operators instead.
# TODO(mdan): Skip replacing if the function is trivial.
self.op_mapping = {
@@ -128,5 +128,5 @@ class LogicalExpressionTransformer(transformer.Base):
return right
-def transform(node, context):
- return LogicalExpressionTransformer(context).visit(node)
+def transform(node, ctx):
+ return LogicalExpressionTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
index 2814060c4d..48186024a9 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import logical_expressions
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class GradientsFunctionTest(converter_test_base.TestCase):
+class GradientsFunctionTest(converter_testing.TestCase):
def test_equals(self):
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/contrib/autograph/converters/name_scopes.py
index dfee529aba..dd6c6bf960 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes.py
+++ b/tensorflow/contrib/autograph/converters/name_scopes.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-class FunctionNameScopeTransformer(transformer.Base):
+class FunctionNameScopeTransformer(converter.Base):
"""Wrap a function body with a `name_scope` of the function name."""
def _name_for_current_scope(self):
@@ -70,5 +70,5 @@ class FunctionNameScopeTransformer(transformer.Base):
return node
-def transform(node, context):
- return FunctionNameScopeTransformer(context).visit(node)
+def transform(node, ctx):
+ return FunctionNameScopeTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/contrib/autograph/converters/name_scopes_test.py
index 17692cbd88..444d0bcd46 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/contrib/autograph/converters/name_scopes_test.py
@@ -18,14 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import name_scopes
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class FunctionNameScopeTransformer(converter_test_base.TestCase):
+class FunctionNameScopeTransformer(converter_testing.TestCase):
def test_basic(self):
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py
index 3bcb2d3c42..b808604f0a 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py
@@ -36,11 +36,11 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -59,14 +59,9 @@ class SymbolNamer(object):
raise NotImplementedError()
-class SideEffectGuardTransformer(transformer.Base):
+class SideEffectGuardTransformer(converter.Base):
"""Adds control dependencies to functions with side effects."""
- def __init__(self, context):
- super(SideEffectGuardTransformer, self).__init__(context)
-
- # pylint:disable=invalid-name
-
def _visit_and_reindent(self, nodes):
new_nodes = []
current_dest = new_nodes
@@ -149,7 +144,7 @@ class SideEffectGuardTransformer(transformer.Base):
s for s in guarded_args if s not in args_scope.parent.modified)
aliased_new_names = tuple(
qual_names.QN(
- self.context.namer.new_symbol(
+ self.ctx.namer.new_symbol(
s.ssf(), args_scope.parent.referenced)) for s in need_alias)
alias_map = dict(zip(need_alias, aliased_new_names))
if len(guarded_args) == 1:
@@ -183,8 +178,6 @@ class SideEffectGuardTransformer(transformer.Base):
(node.body, alias_map))
return node
- # pylint:enable=invalid-name
-
-def transform(node, context):
- return SideEffectGuardTransformer(context).visit(node)
+def transform(node, ctx):
+ return SideEffectGuardTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
index ce0ce33243..a7ad8efed4 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import side_effect_guards
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@@ -29,7 +29,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class SideEffectGuardsTest(converter_test_base.TestCase):
+class SideEffectGuardsTest(converter_testing.TestCase):
def test_side_effect_on_return_only_variable(self):
diff --git a/tensorflow/contrib/autograph/converters/single_return.py b/tensorflow/contrib/autograph/converters/single_return.py
index bcc9ca9dfe..a351cd81b8 100644
--- a/tensorflow/contrib/autograph/converters/single_return.py
+++ b/tensorflow/contrib/autograph/converters/single_return.py
@@ -20,21 +20,21 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Move this logic into transformer_base.
-class BodyVisitor(transformer.Base):
+class BodyVisitor(converter.Base):
"""Walks breadth- or depth-first the list-of-nodes bodies of AST nodes."""
- def __init__(self, context, depth_first=False):
+ def __init__(self, ctx, depth_first=False):
+ super(BodyVisitor, self).__init__(ctx)
self.depth_first = depth_first
self.changes_made = False
- super(BodyVisitor, self).__init__(context)
def visit_nodelist(self, nodelist):
for node in nodelist:
@@ -144,13 +144,13 @@ def contains_return(node):
return False
-class LiftReturn(transformer.Base):
+class LiftReturn(converter.Base):
"""Move return statements out of If and With blocks."""
- def __init__(self, context):
+ def __init__(self, ctx):
+ super(LiftReturn, self).__init__(ctx)
self.changes_made = False
self.common_return_name = None
- super(LiftReturn, self).__init__(context)
def visit_If(self, node):
# Depth-first traversal of if statements
@@ -195,8 +195,8 @@ class LiftReturn(transformer.Base):
last_return_name = self.common_return_name
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
referenced_names = body_scope.referenced
- self.common_return_name = self.context.namer.new_symbol(
- 'return_', referenced_names)
+ self.common_return_name = self.ctx.namer.new_symbol('return_',
+ referenced_names)
node = self.generic_visit(node)
self.common_return_name = last_return_name
return node
@@ -265,7 +265,7 @@ class DetectReturnInFunctionDef(gast.NodeVisitor):
'Each function definition should contain at least one return.')
-def transform(node, context):
+def transform(node, ctx):
"""Ensure a function has only a single return.
This transforms an AST node with multiple returns successively into containing
@@ -280,8 +280,8 @@ def transform(node, context):
this is an error.
Args:
- node: an AST node to transform
- context: a context object
+ node: ast.AST
+ ctx: converter.EntityContext
Returns:
new_node: an AST with a single return value
@@ -301,10 +301,10 @@ def transform(node, context):
while True:
# Try to lift all returns out of if statements and with blocks
- lr = LiftReturn(context)
+ lr = LiftReturn(ctx)
node = lr.visit(node)
changes_made = lr.changes_made
- fe = FoldElse(context)
+ fe = FoldElse(ctx)
node = fe.visit(node)
changes_made = changes_made or fe.changes_made
diff --git a/tensorflow/contrib/autograph/converters/single_return_test.py b/tensorflow/contrib/autograph/converters/single_return_test.py
index d483005a09..1f0de4310e 100644
--- a/tensorflow/contrib/autograph/converters/single_return_test.py
+++ b/tensorflow/contrib/autograph/converters/single_return_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import converter_test_base
from tensorflow.contrib.autograph.converters import single_return
+from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework.ops import name_scope
from tensorflow.python.platform import test
-class SingleReturnTest(converter_test_base.TestCase):
+class SingleReturnTest(converter_testing.TestCase):
def compiled_fn(self, test_fn, *args):
node = self.parse_and_analyze(test_fn, {})
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py
new file mode 100644
index 0000000000..3f5fc57125
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/slices.py
@@ -0,0 +1,83 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for slice operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import templates
+
+
+class SliceTransformer(converter.Base):
+ """Converts slicing operations to their TF counterpart.
+
+ Currently, relying on the default slice operator that Tensor uses is
+ insufficient, because TensorArray and tensor lists use dedicated index read
+ and write functions.
+ """
+
+ def _process_single_assignment(self, target, value):
+ if not isinstance(target, gast.Subscript):
+ return None
+
+ template = """
+ target = ag__.set_item(target, key, item)
+ """
+ return templates.replace(
+ template, target=target.value, key=target.slice, item=value)
+
+ def visit_Assign(self, node):
+ node = self.generic_visit(node)
+ # TODO(mdan): Support unpackings and multiple assignments.
+ if len(node.targets) != 1:
+ raise NotImplementedError('multiple assignment')
+ replacement = self._process_single_assignment(node.targets[0], node.value)
+ if replacement is not None:
+ return replacement
+ return node
+
+ def visit_Subscript(self, node):
+ node = self.generic_visit(node)
+ if not isinstance(node.slice, gast.Index):
+ # TODO(mdan): It might make more sense to wave them through.
+ raise NotImplementedError('non-index slice')
+
+ if not isinstance(node.ctx, gast.Load):
+ # Index writes are handled at a higher level, one at which the rvalue is
+ # also available.
+ return node
+
+ dtype = anno.getanno(
+ node.value,
+ 'element_type',
+ default=templates.replace_as_expression('None'))
+
+ template = """
+ ag__.get_item(
+ target,
+ key,
+ opts=ag__.GetItemOpts(element_dtype=dtype))
+ """
+ return templates.replace_as_expression(
+ template, target=node.value, key=node.slice, dtype=dtype)
+
+
+def transform(node, ctx):
+ return SliceTransformer(ctx).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py
new file mode 100644
index 0000000000..df9a4c8bab
--- /dev/null
+++ b/tensorflow/contrib/autograph/converters/slices_test.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.converters import slices
+from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SliceTest(converter_testing.TestCase):
+
+ def test_index_access(self):
+
+ def test_fn(l):
+ utils.set_element_type(l, dtypes.int32)
+ return l[1]
+
+ node = self.parse_and_analyze(
+ test_fn,
+ {
+ 'utils': utils,
+ 'dtypes': dtypes
+ },
+ include_type_analysis=True,
+ )
+ node = slices.transform(node, self.ctx)
+
+ with self.compiled(node, dtypes.int32) as result:
+ result.utils = utils
+ result.dtypes = dtypes
+ with self.test_session() as sess:
+ tl = list_ops.tensor_list_from_tensor(
+ [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
+ y = result.test_fn(tl)
+ self.assertEqual(2, sess.run(y))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/contrib/autograph/core/BUILD
new file mode 100644
index 0000000000..833f9dced8
--- /dev/null
+++ b/tensorflow/contrib/autograph/core/BUILD
@@ -0,0 +1,59 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "core",
+ srcs = [
+ "config.py",
+ "converter.py",
+ "naming.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/contrib/autograph/pyct/static_analysis",
+ "//tensorflow/contrib/autograph/utils",
+ ],
+)
+
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_testing.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":core",
+ "//tensorflow/contrib/autograph/operators",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/contrib/autograph/pyct/static_analysis",
+ "//tensorflow/contrib/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/core/annos.py b/tensorflow/contrib/autograph/core/annos.py
new file mode 100644
index 0000000000..b8937ce36a
--- /dev/null
+++ b/tensorflow/contrib/autograph/core/annos.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Annotations specific to AutoGraph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from enum import Enum
+
+
+class NoValue(Enum):
+
+ def __repr__(self):
+ return self.name
+
+
+class NodeAnno(NoValue):
+ """Additional annotations used by AutoGraph converters.
+
+ These are in addition to the basic annotations declared in pyct/anno.py and
+ pyct/static_analysis/annos.py.
+ """
+
+ # The directives collection - see directives.py
+ DIRECTIVES = (
+ 'Dict depicting static directive calls. See the directives converter.')
diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/core/config.py
index 878bb7e12f..878bb7e12f 100644
--- a/tensorflow/contrib/autograph/impl/config.py
+++ b/tensorflow/contrib/autograph/core/config.py
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py
new file mode 100644
index 0000000000..54e6aa0f3b
--- /dev/null
+++ b/tensorflow/contrib/autograph/core/converter.py
@@ -0,0 +1,210 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter construction support.
+
+This module contains a base class for all converters, as well as supporting
+structures. These structures are referred to as contexts.
+
+The class hierarchy is as follows:
+
+ <your converter>
+ [extends] converter.Base
+ [extends] transformer.Base
+ [extends] gast.nodeTransformer
+ [uses] transfomer.SourceInfo
+ [uses] converter.EntityContext
+ [uses] converter.ProgramContext
+ [uses] transfomer.SourceInfo
+
+converter.Base is a specialization of transformer.Base for AutoGraph. It's a
+very lightweight subclass that adds a `ctx` attribute holding the corresponding
+EntityContext object (see below). Note that converters are not reusable, and
+`visit` will raise an error if called more than once.
+
+converter.EntityContext contains mutable state associated with an entity that
+the converter processes.
+
+converter.ProgramContext contains mutable state across related entities. For
+example, when converting several functions that call one another, the
+ProgramContext should be shared across these entities.
+
+Below is the overal flow at conversion:
+
+ program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
+ while <program_ctx has more entities to convert>:
+ entity, source_info = <get next entity from program_ctx>
+ entity_ctx = EntityContext(program_ctx, source_info)
+ for <each ConverterClass>:
+ converter = ConverterClass(entity_ctx)
+
+ # May update entity_ctx and program_ctx
+ entity = converter.visit(entity)
+
+ <add entity's dependencies to program_ctx>
+
+Note that pyct contains a small number of transformers used for static analysis.
+These implement transformer.Base, rather than converter.Base, to avoid a
+dependency on AutoGraph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.contrib.autograph.core import config
+from tensorflow.contrib.autograph.core import naming
+from tensorflow.contrib.autograph.pyct import transformer
+
+# TODO(mdan): These contexts can be refactored into first class objects.
+# For example, we could define Program and Entity abstractions that hold on
+# to the actual entity and have conversion methods.
+
+
+class ProgramContext(object):
+ """ProgramContext keeps track of converting function hierarchies.
+
+ This object is mutable, and is updated during conversion. Not thread safe.
+
+ Attributes:
+ recursive: bool, whether to recursively convert any functions that the
+ decorator function may call.
+ autograph_decorators: Tuple[Callable, ...], decorator functions that belong
+ to AutoGraph. These require special treatment.
+ dependency_cache: Dict[Any, ast.AST], the original entities mapped to their
+ converted AST
+ additional_imports: Set[Any], additional entities which for any reason
+ cannot be attached after loading and need to be explicitly imported
+ in the generated code
+ name_map: Dict[str, str], map of original entity name to the name of
+ their converted counterparts
+ autograph_module: Module, a reference to the autograph module. This
+ needs to be specified by the caller to avoid circular dependencies.
+ uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the
+ fully qualified name of a package containing functions that will not be
+ compiled.
+ required_imports: str, containing an import statement on each line. These
+ are all the imports necessary for the compiled code to run, in addition
+ to the closures of each entity, which are attached dynamically.
+ """
+
+ def __init__(
+ self,
+ recursive,
+ autograph_decorators,
+ partial_types,
+ autograph_module,
+ uncompiled_modules,
+ ):
+ self.recursive = recursive
+ self.autograph_decorators = autograph_decorators
+ self.partial_types = partial_types if partial_types else ()
+ self.autograph_module = autograph_module
+ self.uncompiled_modules = uncompiled_modules
+
+ # Required to output dependencies in discovery order, which should match
+ # the reverse dependency order.
+ self.dependency_cache = collections.OrderedDict()
+ self.additional_imports = set()
+ self.name_map = {}
+
+ @property
+ def required_imports(self):
+ """Returns a block containing all imports required by the converted code."""
+ # TODO(mdan): Check that these don't clobber one another.
+ return '\n'.join(config.COMPILED_IMPORT_STATEMENTS +
+ tuple(self.additional_imports))
+
+ def new_namer(self, namespace):
+ return naming.Namer(namespace, self.recursive, self.name_map,
+ self.partial_types)
+
+ def update_name_map(self, namer):
+ """Updates renamed_calls based on the recent activity from the namer.
+
+ Whenever we convert a new entity, any references to other entities are being
+ renamed to match their soon-to-be-converted counterparts. The namer keeps
+ track of these renames. When conversion is complete, we copy those renames
+ so that when those referenced entities are being converted, their new name
+ matches.
+
+ Args:
+ namer: naming.Namer
+
+ Raises:
+ ValueError: when an entity was renamed twice and to different names.
+ """
+ # TODO(mdan): Have call_trees do this directly.
+ # This is done so indirectly, via the namer, for historic reasons. But
+ # now we can have the converter that does the rename record the new name
+ # as well and skip this step altogether.
+ for o, name in namer.renamed_calls.items():
+ if o in self.name_map:
+ if self.name_map[o] != name:
+ raise ValueError(
+ 'Calls to %s were converted using multiple names (%s). This is '
+ 'possible when an entity with one of these names already '
+ 'existed. To fix, avoid using any of these names.' %
+ (o, (name, self.name_map[o])))
+ else:
+ self.name_map[o] = name
+
+ def add_to_cache(self, original_entity, converted_ast):
+ self.dependency_cache[original_entity] = converted_ast
+
+
+class EntityContext(object):
+ """Tracks the conversion of a single entity.
+
+ This object is mutable, and is updated during conversion. Not thread safe.
+
+ Attributes:
+ namer: Namer
+ info: transformer.EntityInfo
+ program: ProgramContext
+ """
+
+ def __init__(self, namer, entity_info, program_ctx):
+ self.namer = namer
+ self.info = entity_info
+ self.program = program_ctx
+
+
+class Base(transformer.Base):
+ """All converters should inherit from this class.
+
+ Attributes:
+ ctx: EntityContext
+ """
+
+ def __init__(self, ctx):
+ super(Base, self).__init__(ctx.info)
+ self.ctx = ctx # Keeping this short because it's used frequently.
+
+ self._used = False
+ self._ast_depth = 0
+
+ def visit(self, node):
+ if not self._ast_depth:
+ if self._used:
+ raise ValueError('converter objects cannot be reused')
+ self._used = True
+
+ self._ast_depth += 1
+ try:
+ return super(Base, self).visit(node)
+ finally:
+ self._ast_depth -= 1
diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/core/converter_testing.py
index 41c2e71702..0e46aacc12 100644
--- a/tensorflow/contrib/autograph/converters/converter_test_base.py
+++ b/tensorflow/contrib/autograph/core/converter_testing.py
@@ -23,17 +23,24 @@ import imp
from tensorflow.contrib.autograph import operators
from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.core import config
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import pretty_printer
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis import live_values
from tensorflow.contrib.autograph.pyct.static_analysis import type_info
from tensorflow.python.platform import test
+def imported_decorator(f):
+ return lambda a: f(a) + 1
+
+
+# TODO(mdan): We might be able to use the real namer here.
class FakeNamer(object):
"""A fake namer that uses a global counter to generate unique names."""
@@ -114,23 +121,32 @@ class TestCase(test.TestCase):
arg_types=None,
include_type_analysis=True,
owner_type=None,
- recursive=True):
+ recursive=True,
+ autograph_decorators=()):
node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=namer or FakeNamer(),
+
+ if namer is None:
+ namer = FakeNamer()
+ program_ctx = converter.ProgramContext(
+ recursive=recursive,
+ autograph_decorators=autograph_decorators,
+ partial_types=None,
+ autograph_module=None,
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+ entity_info = transformer.EntityInfo(
source_code=source,
- source_file=None,
+ source_file='<fragment>',
namespace=namespace,
arg_values=None,
arg_types=arg_types,
- owner_type=owner_type,
- recursive=recursive,
- type_annotation_func=utils.set_element_type)
+ owner_type=owner_type)
+ ctx = converter.EntityContext(namer, entity_info, program_ctx)
+
node = qual_names.resolve(node)
- node = activity.resolve(node, ctx)
- node = live_values.resolve(node, ctx, {})
+ node = activity.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
if include_type_analysis:
- node = type_info.resolve(node, ctx)
- node = live_values.resolve(node, ctx, {})
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
self.ctx = ctx
return node
diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/core/naming.py
index b1d3f76be7..b1d3f76be7 100644
--- a/tensorflow/contrib/autograph/impl/naming.py
+++ b/tensorflow/contrib/autograph/core/naming.py
diff --git a/tensorflow/contrib/autograph/impl/naming_test.py b/tensorflow/contrib/autograph/core/naming_test.py
index 73fc089465..d2bebd0478 100644
--- a/tensorflow/contrib/autograph/impl/naming_test.py
+++ b/tensorflow/contrib/autograph/core/naming_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.impl import naming
+from tensorflow.contrib.autograph.core import naming
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
new file mode 100644
index 0000000000..1368ce244c
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -0,0 +1,29 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_test(
+ name = "keras_test",
+ srcs = [
+ "keras_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
new file mode 100644
index 0000000000..a2fc7c550e
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
@@ -0,0 +1,37 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras integration tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+class MinimalKeras(tf.keras.Model):
+
+ def call(self, x):
+ return x * 3
+
+
+class KerasTest(tf.test.TestCase):
+
+ def test_basic(self):
+ MinimalKeras()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb
new file mode 100644
index 0000000000..fff673921a
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/notebooks/autograph_vs_eager_mnist_benchmark.ipynb
@@ -0,0 +1,666 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Pa2qpEmoVOGe"
+ },
+ "outputs": [],
+ "source": [
+ "from __future__ import absolute_import\n",
+ "from __future__ import division\n",
+ "from __future__ import print_function\n",
+ "\n",
+ "import os\n",
+ "import time\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import six\n",
+ "\n",
+ "from tensorflow.contrib import autograph\n",
+ "from tensorflow.contrib.eager.python import tfe\n",
+ "from tensorflow.python.eager import context\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "YfnHJbBOBKae"
+ },
+ "outputs": [],
+ "source": [
+ "import gzip\n",
+ "import shutil\n",
+ "\n",
+ "from six.moves import urllib\n",
+ "\n",
+ "\n",
+ "def download(directory, filename):\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " zipped_filepath = filepath + '.gz'\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [784])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8)\n",
+ " label = tf.reshape(label, [])\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def mnist_train(directory):\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte')\n",
+ "\n",
+ "def mnist_test(directory):\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')\n",
+ "\n",
+ "def setup_mnist_data(is_training, hp, batch_size):\n",
+ " if is_training:\n",
+ " ds = mnist_train('/tmp/autograph_mnist_data')\n",
+ " ds = ds.cache()\n",
+ " ds = ds.shuffle(batch_size * 10)\n",
+ " else:\n",
+ " ds = mnist_test('/tmp/autograph_mnist_data')\n",
+ " ds = ds.cache()\n",
+ " ds = ds.repeat()\n",
+ " ds = ds.batch(batch_size)\n",
+ " return ds\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "x_MU13boiok2"
+ },
+ "outputs": [],
+ "source": [
+ "def mlp_model(input_shape):\n",
+ " model = tf.keras.Sequential((\n",
+ " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
+ " tf.keras.layers.Dense(100, activation='relu'),\n",
+ " tf.keras.layers.Dense(10, activation='softmax')))\n",
+ " model.build()\n",
+ " return model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "kfZk9EFZ5TeQ"
+ },
+ "outputs": [],
+ "source": [
+ "# Test-only parameters. Test checks successful completion not correctness. \n",
+ "burn_ins = 1\n",
+ "trials = 1\n",
+ "max_steps = 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "gWXV8WHn43iZ"
+ },
+ "outputs": [],
+ "source": [
+ "#@test {\"skip\": true} \n",
+ "burn_ins = 3\n",
+ "trials = 10\n",
+ "max_steps = 500"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "DXt4GoTxtvn2"
+ },
+ "source": [
+ "# Autograph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "W51sfbONiz_5"
+ },
+ "outputs": [],
+ "source": [
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "CsAD0ajbi9iZ"
+ },
+ "outputs": [],
+ "source": [
+ "def fit(m, x, y, opt):\n",
+ " l, accuracy = predict(m, x, y)\n",
+ " opt.minimize(l)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "RVw57HdTjPzi"
+ },
+ "outputs": [],
+ "source": [
+ "def get_next_batch(ds):\n",
+ " itr = ds.make_one_shot_iterator()\n",
+ " image, label = itr.get_next()\n",
+ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
+ " y = tf.one_hot(tf.squeeze(label), 10)\n",
+ " return x, y\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "UUI0566FjZPx"
+ },
+ "outputs": [],
+ "source": [
+ "def train(train_ds, test_ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " train_losses = []\n",
+ " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n",
+ " test_losses = []\n",
+ " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n",
+ " train_accuracies = []\n",
+ " train_accuracies = autograph.utils.set_element_type(train_accuracies,\n",
+ " tf.float32)\n",
+ " test_accuracies = []\n",
+ " test_accuracies = autograph.utils.set_element_type(test_accuracies,\n",
+ " tf.float32)\n",
+ " i = tf.constant(0)\n",
+ " while i \u003c hp.max_steps:\n",
+ " train_x, train_y = get_next_batch(train_ds)\n",
+ " test_x, test_y = get_next_batch(test_ds)\n",
+ " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ "\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " return (autograph.stack(train_losses), autograph.stack(test_losses), autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 789
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 11529,
+ "status": "ok",
+ "timestamp": 1531163743912,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 240
+ },
+ "id": "K1m8TwOKjdNd",
+ "outputId": "59db8f19-23a5-413a-e9d0-fb756b0e4757"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Duration: 0.592790126801\n",
+ "Duration: 0.594069957733\n",
+ "Duration: 0.591835975647\n",
+ "Duration: 0.592386007309\n",
+ "Duration: 0.595040082932\n",
+ "Duration: 0.594245910645\n",
+ "Duration: 0.624264001846\n",
+ "Duration: 0.6021900177\n",
+ "Duration: 0.592960119247\n",
+ "Duration: 0.599496841431\n",
+ "Mean duration: 0.597927904129 +/- 0.0093268291102\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAEcCAYAAAAydkhNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd8FGX+wPHPbMum90IKvQSQ3jtSbYCAqHee9TxPT0VF\njztRT+9UzvMOsdzPUxTO3gURsYsgTRBFmvROQkJ63T7z+2OS3Wx2EwIkBC7f9+vFi+zO7Mwzz84+\n33nKPKNomqYhhBBCBGFo7gQIIYQ4d0mQEEIIUScJEkIIIeokQUIIIUSdJEgIIYSokwQJIYQQdZIg\nIYQQok4SJISow6ZNm7j44oubOxknlZWVRWZmJqqqNndSxP8gCRLilI0ZM4YePXpQXFzs9/6UKVPI\nzMwkOzsbgD//+c9kZmaybds27zpHjhwhMzPT+/raa6/lgw8+8L5+4YUXGDt2LH379mX06NHMmjUL\ngMsuu4y+ffvSt29funXrRs+ePenTpw99+/ZlwYIFAWn897//zezZs8/oOPv3789nn312Sp958cUX\nmT9/Phs3bmTUqFFntP9qtfMoGEVRGmVfQtRmau4EiPNTeno6y5cv55prrgFgz549OBwOv8JKURRi\nYmJ4+umnWbhwod/7wSxZsoRly5bx6quvkp6eTkFBAStWrADgk08+8a537bXXcvnllzN9+vQzOgZN\n0xq9cF21ahX33XcfLpdLCm7xP0FqEuK0TJkyhSVLlnhfL1myhKlTpwasN3XqVHbv3s2mTZtOus3t\n27czfPhw0tPTAYiPj2fGjBlB161vNpnVq1fzwgsv8Omnn9KnTx8uv/xyQA8u8+fP51e/+hW9e/fm\n2LFjLF68mEsuuYS+ffsyfvx43n33Xe92atcGxowZw6JFi5g8eTIDBgxg1qxZOJ1O7/LS0lIOHz5M\nt27duOWWWzhx4oS3tpOXl4emaSxYsIDx48czePBg7rnnHkpLSwFwOp388Y9/ZNCgQQwYMIAZM2ZQ\nWFjI/Pnz+fHHH3n00Ufp27cvjz322Enz8cSJE9x2220MGjSIiRMn8v7773uXbd26lenTp9OvXz+G\nDx/OP/7xj3r3D1BeXs4DDzzA8OHDGTVqFE8//bQ3/48cOcK1115L//79GTJkiLfmJ/53SE1CnJZe\nvXqxdOlSDhw4QNu2bfn888956623mD9/vt96VquVW2+9laeeeoq33nrrpNt8/PHHSUpKYtCgQXTr\n1g2D4dSvY0aMGMGtt97KkSNHePLJJ/2WLVu2jJdeeol27dqhqirx8fEsWLCA9PR0Nm3axM0330zP\nnj3p2rUrEFjr+fzzz1m0aBEWi4Wrr76aJUuWcNVVVwGwZs0aBg8ejNVq5aWXXmL27NmsXLnS+9lX\nXnmFFStW8OabbxIbG8tjjz3GX//6V+bNm8eSJUsoLy9n9erVmM1mdu7cSUhICPfccw8//fQTU6ZM\n4YorrmjQ8c+aNYsuXbrw7LPPsn//fm688UYyMjIYPHgwc+fO5frrr2fy5MnYbDb27t0LUOf+AWbP\nnk1SUhLffPMNFRUV3HrrraSmpnLllVfyzDPPMHz4cF5//XWcTifbt28/5e9LnNukJiFO25QpU/jo\no49Yu3Yt7du3JykpKeh6V155JcePH2f16tX1bm/y5Mk89NBDrF27lmuvvZahQ4cG7W84E1OnTqVD\nhw4YDAZMJhOjRo3y1lz69+/PsGHD6q31XHfddSQkJBAVFcWFF17Izp07vctWrlxZbz/Ee++9x913\n301SUhJms5nbb7+dL774AlVVMZlMFBcXc/DgQRRFoVu3boSHh5/y8R0/fpzNmzdz3333YTabyczM\nZMaMGSxduhQAk8nEkSNHKCoqIjQ0lJ49e3rfD7b/goICVq9ezZw5cwgJCSEuLo7rr7+e5cuXez+X\nlZVFbm4uFouFvn37nnKaxblNahLitE2ePJnf/OY3HDt2jClTptS5nsVi4Q9/+APPPPMM8+bNq3eb\nl112GZdddhkej4evv/6ae++9l+7duzNs2LBGSXNKSorf61WrVvH8889z6NAhVFXFbrfTpUuXOj8f\nHx/v/Ts0NJS8vDxAb/5at24d999/f52fzc7O5o477vDWjjRNw2QykZ+fz5QpU8jJyWHWrFmUlZUx\nadIkZs2ahdFoPKXjy8vLIzo6mtDQUO97qamp7NixA4C5c+fyzDPPcPHFF5ORkcHtt9/O6NGjA/Y/\nefJk7rnnHrKysnC73QwfPtybZk3TaNWqFaDXMp5++mmuuOIKYmJiuOGGG864r0icWyRIiNOWmppK\nWloa3333HXPnzq133WnTpvHyyy/z1VdfNWjbRqORiRMnsmDBAvbu3dtoQaJm85HT6eSuu+7in//8\nJ2PHjsVgMHD77bfX299Rl23btpGenk5sbGzAfqq1atWKuXPn0qdPn6DbuP3227n99tvJzs7md7/7\nHe3bt2f69Omn1AGelJRESUkJlZWVhIWFAXrtorqW17p1a2+g/uKLL5g5cyYbN27EarUG7L9du3aM\nHDmSkJAQNmzYEDQd8fHxPProowD8+OOP3HjjjQwcOJCMjIwGp1mc26S5SZyRuXPn8uqrr2K1Wutd\nz2g0cscdd/DSSy/Vuc6SJUtYtWoVFRUVaJrGqlWr2L9/v7dJ5FQkJCSQlZVVb4HvcrlwuVzExsZi\nMBhYtWoVa9euPeV9gd7UNHLkSO/r+Ph4iouLKS8v97531VVX8dRTT3mHCBcWFvLNN98AsGHDBvbs\n2YOqqoSFhWEymby1iISEBI4ePVrv/quPMyUlhT59+vDUU0/hdDrZtWsXH3zwAZMnTwbg448/9nZI\nR0ZGoigKBoOhzv0nJiYybNgw5s6dS3l5OZqmcfToUX744QdA76PJzc0FICoqCoPBcFr9SOLc1aQ1\niTlz5rBy5Uri4+NZtmyZ9/3XX3+dN998E7PZzKhRo7jvvvuaMhmikdW8oqx9xVjfVe9ll13GggUL\nKCsrC7p+REQEL7zwAgcOHMDj8ZCamsojjzwS0M7dkCvriy66iI8//phBgwaRnp7O4sWLAz4XHh7O\nAw88wF133YXL5eLCCy9k7NixdW6zvv2uWrWKv/3tb97X7du359JLL2Xs2LFomsby5cu5/vrrAbjp\nppvIy8sjPj6eiy++mLFjx5Kfn8/DDz9Mbm4u4eHhXHLJJd6C/brrruNPf/oT77zzDpMnT+aBBx6o\nN23z5s3j4YcfZsSIEURHR3PXXXcxZMgQQB/59cQTT2C320lLS2P+/PlYLJZ69/+Pf/yDf/3rX1x6\n6aVUVlaSkZHBzTffDOg1qOoAkpCQwAMPPEBaWlq93404vyhN+WS6TZs2ER4ezuzZs71BYsOGDbz4\n4ossWLAAk8lEYWEhcXFxTZUEIZpcQUEBl19++Uk75oU4HzVpvbB///5ERUX5vff222/zu9/9DpNJ\nr8RIgBDnu7Kysno7rIU4n531xsNDhw6xadMmrrzySq699lq/KRuEOB+1bduWSy65pLmTIUSTOOuj\nmzweD6Wlpbz33nts3bqVu+++29t5J4QQ4txy1msSKSkpTJgwAYCePXtiMBgoKio66eeasOtECCFE\nHZq8JlG7cB83bhzr169nwIABHDx4ELfb7R1bXh9FUcjLKzvpei1BYmKk5EUVyQsfyQsfyQufxMTI\nM/p8kwaJe++9lw0bNlBcXMzo0aO58847mT59Ovfffz+TJk3CbDZ7JxgTQghx7mnSIbCNTa4MdHKV\n5CN54SN54SN54XOmNQm5NVIIIUSdJEgIIYSokwQJIYQQdZIgIYQQok4SJM4zTo+LVza/T7GjpFG3\nuyn3Z37I2dyo22xqdredV356jzJn+clXPgUbjv/ITye2Nuo2m1qlq5JXfnqPSldlo253bfYGtubt\naNRtNrUyZzmvbH4fu9veqNv97tg6dhTsbtRtng8kSJxndhXu4dM9K1ifffJnRp+K/+54i1d+ebtR\nt9nUtufv5NO937Ix56dG3e5rO99l4fY3GnWbTe3nvO18uvdbfjyxpdG2qWoqb+36kBe3vdpo2zwb\nfszdwqd7VrDlFIJbeXk5S5Z8UOdyl8fFu3s+4vktCwOWzZ59NxUVDb9QWbRoAe+8c/6cXxIkzjO2\nqqujEmdpo23To3q8f2eX57CnaD+bT2zD5razLf+XRttPY7N5qvLC0Xh5YXc7vH8fLctmT9F+tuRt\np9JVyY6CXY22n8bmPS8aMS8qatRKjpQeY0/Rfrbl/0K5s4KdhXsabT+N7XTyoqyslCVL3g+6TFVV\nSp2+4bQHS46wt2g/2/N3UuIo47d/voPw8IgzS/Q5TJ5Md55xePRCrDELgzKX7yro8Y1PBSy/reeN\nXJDQtdH211iqC/STBUxN0xr8dLfSGtt64oenA5bf0/c2Osa0O4VUnh32Jjgvam7rH5ueDVh+/4C7\nSY9MbbT9NRbvb+QULqReeOHfZGdncdNN19C//yCGDBnGf//7EvHxCezbt4e//vtJDr69FVeJg9+5\nryVhSAbx/fRj3/nUOhYufJ1Qzcp9982kR4/ebN++hcTEZJ54Yh4Wi6XO/e7du5t//esJHA4HaWlp\n3H//w0RERPD++++wdOliTCYTbdu245FHHmfz5h959tl5Veeywv/930t+j6ltKhIkzpCmaVS6bYSb\nw87K/uwN+AFUuioJNYX6P6rT40QDQoyBJ+zJCpYiR3GD0na28sLudLN8/WGUVjag/vS//90vfLbu\nOM/eNZKIUDMADo8TBbAEzYv6b8CqeUVZH1VTsbvthDVSXmiaRoXd7T2G2nwFY93pq3BVBnw3drcD\ng2LAYgzc7skK2XJXxcmSDVTnhYMwc8MKtPdW7OOHXScatG4wFW4LDvcoVmwzs/HrdQAMyEziyjEd\nfevUyovbbruTAwf38+LLr2I2mNi8+Ud27vyF119/j+jYBP67bgWtL++KMdSM6vKw98VNRHdLxBRq\nBgUq3TZCjVaOHTvKX//6d/70pwe4/e57eP29pfz2NzP88sKtur2vH3vsEWbN+hO9evVm4cIX+e9/\nF3DnnbN4881X+eCDZZhMJm9T1jvvvMG99/6ZCy7oid1urzf4NCZpbjpDa7K/Z/bqR065KeLjNQf5\neV/+Ke/P4XECUFpHYXao9Ah/XP0Inx9a4ff+X7//J/ev+Rvf/nSM1Vuy/ZadLEgYGnCafL8jh/9b\ns5TZqx9hX/HBk65/Jj5YuZ/l6w+zeV8OUHfBvadoHyvdr2BMPszOw75JJB9c+zgPrw8+HczJCsaa\nP/D6fHrwK/64+hGOlB1r0Pons2brcWY+s7rOc8ZRVasqDfJdbj9YwIsrVzB79SOsy97ot+yPqx9m\nbpDaI5w8YFYHppNZvO8T/rj6YXIqcimtcFJS3rDPna7qSSTUOiaT2JSzmdmrH+HH3J+976mayvGK\nHP616d/e97p1605KSgqb9+ax+VAWeeuPsvv5jex96UdcpQ6cBbaqHYLNpTdxpbRKJTw2FVXTyCqP\n5ONv/fuI3t71IV8c/pZKVyUVFeVUVJTTq1dvAC666FJ+/lkfPNKxYyceeeQBvvzyMwwG/TG2PXr0\n4tlnn+KDD96hrKz0rD0mVmoSp2HBxzvYe6yEf/5hqLcw/jF3C93jMxv0+bJKJx+t0QvSRX8ec0r7\ndtRoYlE1lez8SkLMRr79KYvIcDPuhJ0AfHLwCy5u53sUZ/VoqDdW/ozmDGNEr1Q278nj25+z6DOo\n/hExFQ0YMbNg2S9Y+3+PYoCteTvqbJL5YuMRsvMruOHizAY3AdWWX2KvSpcNrHUHuR9z9R+oOW0/\n+SU27/uV7uoaSBnRIf5TFgQrZGtqSF4AfHZIn/5+V+FeWkemN+gz9fl84xEA3vxyN+1aRREd7n8V\nWV3DLA4S5J56dwuWjpsxxsFnB74lM6IncVFW3KobVVPJsxVQ6bIFXOmXniRgNjQvvj26BoADJYd5\n+TW9M7m+8/7KMR39rvpP1YKtr7Ilfwfx1jj+NvTPActXZ30PwMpj6+iXrBfQDrd+8XWsPNt7IWa1\nWvn+lxxe/mQnDsMhyg8W0emW/hhMBvYt+gnV7evLq3BV6udihYc5C77n4RsGgKKA6vFr7lx3XH82\neHFVAK5rVqR//vMZfv75J9asWcUrr7zMG2+8z29+cwNDh45g/fo1/P73N/L008/TunWb086nhpKa\nxClwq26e2byAH078SEGpHYfT472aCjGGeNf75sh3vLj1Vb8TQFU13B6VnIpc/rX5GZRQ31Xatz8d\n4+/vr+TR7+dxsOQIa7cdx+HST8D/W7KNj1Yf8K5bXRiomkqFq5K/vPYdf1n9L778ZQuLt6zh88O+\nGsTS/Z+xL6uEQ8d9P3Zr7+8wZeyiuNzBc4u3sf1AIUeK6q/RbMr9mbkb5/sNNV26/zNe/+U9lu7/\njEXb3gJAMah6Xph8efHujs94euN/ASgstfPuin2s3nqcrPwKjpZl8/iGpzhRmee3v6Nl2Ty2YR7Z\n5Tne93KLKtm63z+dmuL25ond7aDAVsjcjfM5UnqM9cc3sSZ7g54uk4tNRWsB/5rAnLWPsmz/537b\nDFbI1rT++A/8bd1T/HI01/veB3s+5u3di/lw7zLe3Ok/QsZa47xYduALXv/lvaDbPVhymMc3PEWB\nzVfj+WTdIf65dAWPbZiH26TnfUGpg4cXbQz4fHXBVu6swKN6yK04weMbniK7PAdj4hGMcXp6C50F\nzPnoTUBvgqz2x9UP89lB/+e6nKyG+d2xdTy56Tm/zv63d33Ih3uX8fbuxby3Z6nf+jWbOt/asZS3\ndy8Out09RfuYu3F+wP73Fu3n8Q1PUWiv/9EC3tq2sxRVVckqP151nuWz4sh37CvRL9AOlBxi5VH9\nvDBZTagO/Td376qHvKPElq87DIDqKccYasZgMmDPq6DymH/aVmWt48Utr+Fw6efXL4cKMcYfxxBV\nwKJt7/DRvk/91jcaDISHRxAVFcX8T57jw73L+OKLT+ndW3+ee25uDn369GPUlRPILTpBXkk+WVnH\naNeuPddccz2p7TKY99WzJ63tNQapSZyCnIoT7Cnah6U92PLTKat0en8gZoOvTXfxvk8Avc020hKB\nzeHmLws3UlLhoP3wneQ78jC3dePcOQiA17/cQ0i39RgqS3hr2yfsX5OJ060yuFsyP+7O48fdeVw6\npC1mk4Gf9h2HcH0/eeXFmNP2Ywgvw9JpM4rZ6ZfeLw9/y9KNIaCohA7wvW9udYgdBwu9r/fmnIB6\nmjezKo4DekHWM7G7d9t+DON8f6tGlq8/xMSBrfkuV1+v0mnn281ZvjxadYATSV9S5D7Bkn2f8vue\n13uXvbztNfLthXx5+Ftu6P4rAF77fDe7jhTx9J3Dvc0ImsFX4Jc6S1m6/3Oyyo/z6s53yanwFeIA\nOZaf0bRf+RWMAJ8fXsGkDhfhcquYjArlzvrb2bPK9bx4atkqXv7DlWiaxrfH1vitM639ZO/fRsXo\n21dV7eKarlfgdmuYTAYMVVeYL29/g2JHCV8c/oZfZ16h59F3B7D2/Rqlwo0lej/kdtKPtcL/GMBX\nw9TQKHOV89buD8muyOHdPUuwtPNv/jOm70ZVNW9hWu2Tg18wvvWFAJiMBspO0udwtFxvtnzmkzUc\nO2TmsVsGegNztRmdfHlRZvPtb22uXjhf3XlqQI3y+S2LcKluvj26hovbTCTErOfh05tfBGD98U1c\n2m48ALuPFHGi2MaInr4O9OoLKZfqZua/vyWmzyaKXAUs/Gkxx5z7/Pb1/t6lDE8dginUQljraHb/\n3wYiO8VT1qmQeHcYefl6HkR2TKR4j8bu5zcSkhBGeEaNxzIr+nmhOt1g0M/N3ccKMISXgtnDT/l6\nE1JbbaD3I9UjCufMeYTbH7oV1a1iNbbnuXlP4Ha7+dvfHqKiooKs8mzihqSyrWwnP727nq9WriUq\nwoojvpLWrbuy5siPXNppNJt2nWDBsh08evMgkmMbt09QgsQpcKo1flRmOw98thBTsn5SHMwug05Q\nXKO99VD+CY45NnAiK4SC0qorPbu+XDHbMbfZQZG9n/66qmbhcIIhooifStYRltMdU9pe3Fkd2X20\niHatoqh02akudnLKCqGqoNRUI8Eab4xJhzHGBnYCHsotxZSxG09eGnllpZjiT378Jc5SPKqHd/d8\nFLDM0vlH79/rtmdxfKf/8v15eazKWYkhIhq1PI6f9+UTYq3EEAZF9iLe3r2YSe0mEm4OI9+uB7Cs\nvEp2F+7jYMkR9pcWYmxVRG5RX28h6cHlS5ujzFerMwSPeB/t/oYDFYFDNw8cL+KJFW9zUYcRVIQ2\nrAlFMTsoLKvk2R9eC1h2z/JnMMbofwdrt88qLuDR5R8S7kjjst59GNsv3VvDyass4I1fPmRC+jhA\nQzFVfb9oGKLzUKwVaPZw5q/8kLEZo+nZQf/i7DX2U+Io9QaNEEMIwXy460t2lwQOb571/CqsbfZx\n54gplDkaNvZ/b24uqjOR//z8SsCy/+54y/v3m9/8ArT2W17hquTrI6vol9yLjMg0QC/cAbZnH+az\nH/7Db/tNZc+RUm+7h8vjYmveDoocJbz31TGchnKslssYkJkUkBc2tQJPeQWGEDiUXYkpITD9d727\nEEvCCdpc0d3v/ZFdb+azQ1/iyu6AMcRD+2t7BT3+rvcM1f8IM9P+sutRy9zsMn1B0jD/Y/3Pj29h\nSoKUC9sxrKsejDt27ESnW/rrad00jle2f8rvYy7i+edfBuD2FbMBOFh0lMjR7chQ+4ECof2/AuBY\nYREbsn7m5bVb0KLMvPnTl8wae3nQdJ4uCRKnoGYbrDltH6YkX6dkYUUFJ4pt/PmF9YRWXTA89+UK\nzK33YHEkAnowsLtcYAKD1YbBepR3dn0EhlQUo95UU+YuIaTbIQ6ocOLQLsxpZaglCRzMLtVHthh9\n7aB5FUUo1a89vivWmixtdwZ9/5eyzZhbHcQYdxzNHt6g4/9h3xGiLJGsrXW1CGCM8jUBFJTrV18H\nskuh6nlSH21ZiydxN0nx6VT8kkxZpQsUPcAeLc/maHk2HtXD5R18z4o+WnKCZ39eAICSEYrZamNP\n3lGKyvRCwK25vO2lr3+zBWua/v3UbO6q6evsL4O+P3/VYsytDvFVXgGt4qKCrlObYnHw+sZvyVMO\nBCwzxviaz9b+coztG7fSKsF3dffNgQ2YUw9gK6zgza/CGd6jlXfZnuL9wH427shHsfj6MjymckK6\n6PtS7WHsUyvZvszC87dfQojZSIXTd3dxsaOU/LJyUPxruDWtzAn+yGBH/C94Ig7z+MpFGMwulAZc\nlCoWB8bEoxy1B+ZFzZv7vOcqqve99cd/4KsjKylxljI67lKS43z9IsfdBzElw8vfL8dTmIK1p/7+\nCVs+Xx1ZqW+pVSgWq421u7oxIDOJVT9nkVdS5i3ZFIsDxVhV4/TUUdyl7CWwbgafHfsUU0o2isUB\nJleQNQIpZgemlEN+v4dqNcuL6lpczYsIU9IR8i17WHUskl9lTvf77M8F+gwAxqSuqGUx3vcLHYW8\ntns9lnagOqzsx065czwRlob9phvC+MgjjzzSaFtrYpWVwb7Ks+dAyWG25usdb4Zw/zZJm1bGVxty\nUEwuTAl6NVwJsaGYXHg0D1HGeLSoXNyR2WDw/UgqXXYcDg1jjN7e7tZcKFVVVqemH6/mslBuc7P8\nl42Ya5xoRwrz8eBGsThQ7WEYQho+DYGt3IwSVqIX1KoBxXLyvM0vUNlS9JNfM08wSkglmtuMS7Hh\njDwKQKm7BMXkIizESNu4ZE549BpO9bGCHkC//fkITqte81EMqrdgqb6iLiyAwjIbhpgTmBJ9zVfl\n7grsaiWqwYmzPAynoeFj5F12M4awcjTViN3lRDOePC80l4Vi6x4weOpdr9RVQk6ek30ncjDF630s\n+bZCPIoTNAXNGcIvBbsoVI56gyaAR1XB6PYWNqqmevNAqSqwNEcoTs3OcechNuf7RuqU2MvJrSxA\nMahY3LGUeAoanBeax4jBagO3BUxO7z7rz4sQTMlH/L7LYBSLHTx6Xlf3kezNO45mcIJqYOmqY3y+\ncxPG6Nrp1VDMToyR+lBsg6J4h996a1r2cBIS4J0f1uKO9J0XitGNIVyvpau2CO/fDaPvV3OEYbBW\nevvc6mM1hKElHORkYzKK7MWEmUM5VHrUOzKyurywGq2YFCM7Cnazq2hvwGcNFjuGCH0gSnGpSw9i\nVOWFAqnhKWSX5fHV7p8IDzWSEZ9yCsccSB46dAq+PrKKJfuWn/X9qg4ritlR74/Q6orHbm54YeAp\niccYXYBqCwODekoBpjmp9lCUEDuKUndeeMpivAVKQ1Sv7ymLwWCtQDE37KrxbNI0Agoe1R6GElJZ\nb4HkKY0NelVbF9UWjiG0Ak9REoaofG8N95yiKX4BFcDgCkc119+Hcqp5obnMKGYX7rw0jAlZJy34\nm4OmKgHlQrgSQ4Wmn/9RxlhevmLuGe1DRjedgoYO+atLvJKBY08f7FtH0J0JeAqT/ZarDmvQzxlC\n7CgGDdVWdxXSHO7f9q1WBD6NynW8LVHGOAAUa9WxeMz1Bgj7z6Ow/zwKrUZzln3bUFxHO3tfR+YN\nDPbRevVI6EYXzzguib0ex56+eIoS/ZZrzuBNRgarDUXRSAlLqnPbhhD/7ynCEBOwjiu7Pao9tGqb\nVeurxnoDhG3zaD0vavwm7VuH48rq4H2tHAvebl0fd0GKngc7RhJybBCeEv8OIs0ZErSAMlj1AFHf\neWGw+heccdbA58m7jnX05rdiqToXFK3eAGH/eRT2LSP939s6HFe2b+iz89Cp36Xvzk/Fsacv9q3D\n6eIZj6fUP72aMyQgQADeAFH9nQYTHu3/G9GcgX1XriNd0NxmvfCtOhcUs7PeAGH/eSSGPaP93vvL\n4D8yMm2I97Vy/DTy4kS6Ny8ce/qilvs3haoOa9ALx+oA4cpqT6ZnwinvtzYJEqegooF3mNZU86Tt\nndaRSHcGwzp15IL4bqiV/gW5pyBwigPV7msUbu0aVOd+as+EqlZEB6wzoFM6mQn6j9gQot8rUF1t\nrYvmDNWX+CEEAAAgAElEQVT/ufUflKcwGc0WhafYV6ifONywdvwQzXe87aPbMHP8BC7t050h6T1R\nbf5z37jz0wI+H2uJ8/49o/OUOvdTu+ksxRqYr5rL4s2j6lFhgc0ctbisaM5QcJur0tgKzR7BbWP1\nTkjNY6TyRFx9W/AyeXzHO65rT4a36Y2zIozi7FjC8Q9qwfIiMdQXSK7ObHheVHcO16S5Qogx6kG3\nunmvZr9KMJozFM0Rhqbqpaf7RDqaPQK1TC/UFbcVQ3lyfZvwqvkbUUvjUIuT0OwRzBw/HrPH/zcS\nLC9q/kZcR+oujO2a/8WDWhl43mquENTKSL/C9+R5EUZFse+i5sKM4SSHJdI+uq2+H1s4tvyGnRf+\neRFPa2sHnrxxAjFqBqrD/2LAk19/eeHOaUduzplXf5o0SMyZM4ehQ4cyadKkgGULFy4kMzOT4uKG\nNws0txK7HiSc+3riPHBBwHJ3rm80g0kxMrHNGK7v6RtpkBgew5O3DuWGSzIZkJmEBf9eQa0yArVC\nP3E1txlrcRfu7nsLIYVdcB7O5OpBQ7ilx3WMib7K70q+pgszhhObeyFakE66CIsVax2dusE4dumj\nLtISw/l1l2l0CumD65g+DFOzReI82B37lhGgBu7LneO7ycegmXBldaBnuO/KKtri+4EO79kKzeWf\nLrU8mhBVL8Q1p4XBCcO4p9/vmdhmDFd1nkpmXCda20Zh3zGE0OIuQdPvym6HfcdgIq1Bri5VI3ER\nDe/cc+z0jSF2HuhBO1MvrulxGXfP6EnvVp0w5/TAsW243pZfy7jWo3wvPGZcWR3I0Hp73+qYlEy/\nLr6g2ybBv1allsZ5a5kxIdFc1HYsd/a+hQltLuSazCsY1aE3N3b7FTd1vhV3bkbQ9F/Udix/GjAz\n6LQsqEZaxTb8OciOX3wXK7+94BouTB+BO1uvTd13yQRMORdwRer1PH3b2IDPTmzju4lOc1kYlz6W\nXtG+mmiUJZKkmFCuv0j/Tnu39Q8KanECmks/BtVhxZXVAeeuAbiy2+Hc34P/u2kGkzMux75tGO78\nVgRzabvx3D/gbvp2DGyr1zzGOgeBBGPfMbjqLwXnvl70iBjo/b77JffCmNMdx85BaM7AVoKaNVDV\nYcV1tBOeE74y5OaL+vDAdf2Ij7bStXVsQO26Z2J3NLf+21PtYSQ5ejHAPEnPi329MCsh5BbZOFNN\nGiSmTZvGwoWBU+vm5OSwbt06UlPPvcnB6lNaNSTQU5jCr/sF/gA8RUmYNP1kGNyqP5M7XESneF9h\nGW2Jwlw1Nj40xMTYnv53lY7r0554RV/ffbwdtw28gi4pqTw44RpmjppCu1bR9Eq8gFsvGs2Q1L4B\n+7cYLUzpcAndEzuCGniiR4eFBi8k6qCWJtA5PZpHfzuIEe17cfewX9Exwfej9eRloDmCF7Ttw3wF\n99jWI7ip7xTGd/cNMYwO8QWJLq1juWxArbvV3RZaW/Uf0YiUkVzbcwrxobFM7nARI9P1YHPjsDG0\nj23NNf0D7941ahbcxzqhVcQQExaYxoGdU+nRNjHg/eAU1DLflbtakkTfiFGM7NqJnh0SUBQFa2lH\nvZYRZCBy/2RfQBiRPJJ0T18m9KiZF5H0aB/PA9f2487pPRjQwX/opOa2oBbrV/qXthvPpPYTiQ+N\nZUqHixmaqhew/VP6MDyzG+68wLu7oy1RXNZuAq0j0/1u7vPyGIkNb9jYejNW1HJfE1C/lJ5c0XlS\n1bFDZps4nvn1dYy+oAOh5sB9DWrVz/t3iqMPUztPJMnqazq8uF9nnrh1CKN66+dZp2T/gtyshGK2\n6e+5j3bGndUJzRmK+1gXopztCbEYubDtIDRbJJ4TgQEzMTSei9uOIz0ylVCTr+D22N3kbzwGqpGo\nBk6al2hNJC1Mz++iA2tw5SUwIX0CMSH6xY1BMeDJbcvR1Yvo0yrwvHDXSJ/7aGfcxzv41ajbJSRi\nrJp646qxnRjS2f/u6mtG98BToo/p7RcxkocvvoabxvfjotYTGNiqNw9e15/fXdatQcdSnyYNEv37\n9ycqKrBKN3fuXGbPnt2Uu24Sdre9KnIb6N0xAeeBC3Adb+tdrrlCvEMOnarenhlr9TUd1CwYAdKi\n/augQ7q05t5xk8iwdqBP0gW0T9XXj4uy0qO9fzt1sKvjtlGtMRtMXD6iHWN6tw1YnhQdGXQiN4BB\nyb4r5R4J3fDs0a8WQ0P8awmpCXqB2yY5kumj2gNgtRgZETfB7yqxf4cMbz+GBzeDu6eQFOYbpB5l\n8b9ybZ/oX2DfNLEXv+47ngviM5nQZQDBJMWE8vQ9o2mbFNjOnmpNp/r0DgtSWA3skkqYOXgfUBK+\nK7wL4rsywBTYnFM7Xyod+iibET1bcUWnyd6bvQCiatSaYqPM/OWGAXRKTA1Y3iEtmj6dEompdZ7E\nhUXgzm1N97jMeqd+iQg1c8flvQPe7xjTznvDWkiQIHHLZb2wBAsewAU19ucpTmRm35v58zV9uWRw\nG349rpN32dSR7bk6yFQal3e4xG9Yc3SN731YD72wbxXh++7jQ/2bSaNqTZvisBm5od/FhDvTCPf4\n1xTiIvVjsFTdfKcFuVDqUDMvatSqPTYXBRuzQDXSvU3w/q4orSpgadA9PpPf9fwNf/vtQB67eRCu\nnO/p3ymW1sn+zaa3XX4BUWEWRvdJ49J245na4TLvsh6tfenXtKo01xiOXvO8iQg1k5nu33wXbg7D\nfbwdnqIkojTftqaN7MDvJnUnIymCzDaBv41Tddbvk1ixYgWtWrWiS5fgTQTnMpvHjuYxcdnQNsRG\nhnD7qEvILark43J9UrBhXdpy1LoHm60Mp0cPEgbFF4drF4yt4xOgxs2wVpOVWGsMfx76+5OmRfUE\nxvekqnbq0BATaXHRUKsp1WwwB5351J2bQaeu7diQq88rc2vPG3js500cKC7lRLF/dTU1Xj+JI8LM\nXDSoNf0zk/zu8PyialqQoZltWZZnxoXHe5ezuUaAql0QJkf6t8P3aJNCpCWC23rdVHcmVAl2TB0T\n0zhgVOjRPh6rMbBJM8RkCRowx2SMQHWGcCJ3PwC39bqRNVuP8x07SY4LI7ewsupY/PPfVhUkwqwm\nLswYDsDyg/oNT5E1xqxXXzyEmnxBPrrWeRFaK3g98KuhoBmJiTh5U2HbpFjwv6mYpDBfIRwsSMRH\nhpNlD8yLi9qOxaN62F6wC0014NzTj/bTWkMMdM7w/74mDW0bND3j24wG4KP9+rQU1hpX7waT3u6f\nFOEryOLC/M8Lk1KroHeb6ZPRgT4Zd1FYaueTdYcoLHOwdX8B7VNr9cMFCRLJfnnhO2+Of70fZ5GN\nrG9e5aeSNjAklBNrjlC8IxfNozFw6BC6TEjj+2NZZL2/i0r1GBvUL7n++pspLMynoqyIH5bPZ+/a\nGJ555j/e7V7QLp7UhHAsJiOXtBvPV199zu4X9PuMLhiXAfGgqRrH132GpaKcUKuJskojiUMyWLZk\nid904Rf/fprfsViMFq4Y2Jf3V0YzaNiZzw9Wl7MaJOx2Oy+88AKLFi3yvncejcDFoerjvKunCejd\nSb8yzv2lH9sLdnLjhT35pdDK81sW+rVDD08dxPaCXQFBIikyBqshDLtaidUYQmxIYGdzXRKjwqFW\n2RcT4vvhBmtW0DQ1aHNT/05paJr/SJZRvVI5kF1KXJR/gZWaqBd4kaFmjAZDwBQAvRN7cLj0KKHm\nEG7ocSUvbXuNETVGeQxI7sPBksN+hSRArDWWMFMolW4bEebwU5pu3BLkhrGksASeu2sgRqPiDX41\nGRVj0OASbg7HYFGgxqweQy9IoazSyYCuScz+z3oAbE7/+wc6pkWz+2gxrZN833G3uC4U2IswKAau\n73Y1r/7yDgNTfM2EvRIvIK8y3y94AqSEJWMxWnB6nMSERBMdFtrgyRCDHVPNGlywPqkQoyXoeRFu\nDvNODKmgMLBr3SPKTqZDdDtcVQHy6i7TeGf3Ynom6E1ukeEWPMUJ+n00VjOL933C5hPbAPzOS01V\nCOm1iofW1Zi7KgrUCI3YeA87Qow8tE7Pp5ThGk63G1utAVo1A2bN30ir8R2wn6jgv68uZEPuJj78\n5kMchZV0/v0ANE3jxNLjxOxPpPRoAdaYMP47T7+TvLKygrCwcN59922ee+7FoC0n1fLz83nhhX8z\n8s5LiYyIZNfrW+macgE/VmynbXQoz73yKgZF4R9rn8YaFsabz/pPF+4xaygoaGjeAQgXDWrNhX3T\nsFqarig/q0HiyJEjZGVlMWXKFDRNIzc3l+nTp/P+++8TH3/yeSESExveudbYVE3FqTrQPGHExYb5\npeXeUTd7/05K6s/ozP5+n52ZeEOd231txrzTSs8V4zP5oNaDtNITEr3piqoMbI4Kj7R4O4Nr6tI6\nkXCr7weTmBjJ1LGdiYsNo0fHBOKjfdsaGGml87pDDOuTHvT7mDPmD96/xycOYXy3IX7L/zj6ljqP\n6ZXpwaesPplWyYFV6tZJyaSn6UEzrDwwiERFhxBvCPxBJ8fGUOH01Z6qj/G6SfpAhTk3DODtL3cz\nYUg7IsJ8BetDNw/mx125XNgvw1ugPzL+bu/ySxNHcWmPGh3YwANjbq/jiCJ544pn6lhWv7TkwFE0\nHVulk5igH4fxROBFWWxsGDFB+pZaxcXjKKq6i91i4qFfDQlYp6H+fpGveXla4nim9fY1x4VHWnHu\n0X8z6VfHEFZhwWioDopGEsPi0DQoKncQGm6qsaxqDYOC2WQIeM9kMmOrNWo9M60NidF6XriPBd40\nmZYcR3hZCGX7CynfX8ie/2xE0yDaGEFFUQmhyRGc+Oogr722gFGjRtG/v55ugwHi48OJiQn8TZjN\nRmJjw8jOPsDQoUN4YsajAHxQ8QH79+/nrdue54rPruClBc8yatQo/jnlQRRF4Xfv7+Lvf3+YcePG\nMW7cOMLCwnj3qucbkNuNq8mDRM2aQufOnVm7dq339ZgxY1iyZAnR0Q27gm7Om+mqH4mIx4Tb4W7W\ntCQmRlJQEDivTqIh2Zsui0sv2N0FKWSkWjjuOEKoO5KcysB5nBw2D1FVtZBu8V282+jeOgbVGXis\nf/61fjXc3Dc3gp4X+fmBeRGlxnrTF1o1jHJoqwHsKThKvjMHg91KaVngyA9bhYd4q3612S+pV8Ax\ndkyJ5KHr+mOrcGCr8B9336NNbNC0nC2JiZEUFQQek8UZ7j2OKEUPqGMyRrA9fycnbPm4yg0Ulgam\n21WpkWLR+076J/Vusu+7ZhlRUWbjorQJXJR2ZuP7ExMjOZ5bxF0r5/i9b7BZyat6/kiCUf+eL2k7\njq926M2kFSUuCstKQdNIG92RqD567emuPr+nwFbIG7ve56a/3Ul8bgT/+Mc/GThwMDfccDOqqlFQ\nUI7LFdjE5XJ5KCqqpKSkEpvN6c3HsjI7lZVOHA6FhQvfZMOG9fz3v6+yZMnH3H//X3j88Xne6cKf\ne+7fvPHG+6f1DIkzvbhu0iBx7733smHDBoqLixk9ejR33nkn06f75iRRFOW8aW6yVwUJzW0ixNLw\nIXJnw+UdLqFnQjeSw33NAe2i22DfPhStMoJx3buSlqGQHpnK4bKjAZ83Goy0i27DnwfcRXI9N6md\nD67oNJnu8V38bhrLjOvEnwbMJD0iFafHRYG9kMSweOwnAm8iNBmMdI3rzJ/6z6RVxJlNZ9AcajZL\n/brLdDrHdiTC7Ksl9Erozp8GzCQjIo1L2o2jyF5CdEik9/w2G8zeZiGTwUSPuM7M7n8naRHBh5M2\ndppNxsYbS2My+Iq367peRYeYdn79UANS+tAqIpmMiDT6RffkD69sw2qyYnc7iOwYT/7KI4R3j8do\nMWIrrqBLdHt+1/4auqR0JrRXKKGhVj77TJ+BISwsnIqKCqKi6r7g7dbtAp599ilKS0sID4/g66+/\n4IorrqakpBiz2cyoUReSmprG3//+V8A3XXiPHr34+usvsNkqm+VZ2k0aJObNq78p5Ztvgk8ydi6q\nflANHrN39MS5IsoS6RcgqoW4Y7HjIToslIxIvRmiui8gMTQeDci3FXgL1GA3Wp1vokOi/Nqdq1U/\n+MdqCvEWeNWFZ5vIDPJtBVS4K70d6q2jmq4j8GyJsUaTGObfjKsoijcvQk2hhEbo50N1f1mH6Lbs\nLtqHhkakJRJFUWgTFfzei8aUnhiuT/rYROKsMSSE+jfFGRSDNy9S4lPo3asv119/NfFdWhE5OI5E\nezRbXtoEwHPxOfztkb/jPGHj1odvwlDVnHXfffcDMHny5dx330wSEhL9Oq7BFwTj4xP4/e9v5847\n9YEpQ4YMZ/jwkezbt5e5c/+KpqkoisKtt97pN104aFx11TXNEiBAZoFtsOrmJs1j8nZcnyvqmvX0\n4RsHsHV/gd8wuD6JPbi6y1R6JlyAqnnYUbCLXgndg37+fBT0PoA6DEkdgEtz0z+pNw6Pg91F++kc\ne/pPRDvXBBvJVJcxGSMwGUwMbtWPMmc5B0uOkBF59u5j+ttv655NoDHU9Rup6S9/0fsKXKqb1Vnr\nGTZqEEW/LeJYWTb9U/oAkJqaxsCBgwM+O336VUyfflXQ7T777Avev8eNm8i4cRP9lnfs2IlFi94I\n+Fz1dOHNTYJEA9mqaxJuE9ZzrLnJbAj+NSbHhjG+v/8oIUVR/EYbDU8LPOHPZ6dyR7lBMTA6fRgA\nEYQzNLRhUyecL04lYBoNRu/Q3VBTaNDa2PnsVPLCbDAxJmMEACnhyaSEN2x6kf9VMndTA9WsSZxr\nzU1K0McNtUwmRa57qp1KTeJ/neTF6ZMg0UB27+gmMyHmcyvbGjqGviWQvPAxGyVgVpMgcfrOrdLu\nHFZUdVOR5gpp0htXTsWENvrso9Wdby3ZqKpmo8TQIM+nbGEGt9LH7tcc1dRS9U3qiclgqnM6GnFy\n8tChBlqw9VW25O/A9tOFvHj3BMym5mtySkyM9OaFqql+U3+0NJIXPpIXPpIXPmd6n0TLzblTdLQ0\nB8VjRvFYGnUs95lqySd/bZIXPpIXPpIXZ0ZyrwFKKuwU2AtxV4bTNiVa2r2FEC2GBIkGyC4tQDFo\naI5QxvQ9/284E0KIhpIg0QA2pz4RWKg5hCEXnH9TNQghxOmSINEAdrc+XUByTDgGaWoSQrQgEiQa\nwO7yTXgmhBAtiQSJBnC4q56sFuThNkII8b9MgkQDVNck5A5WIURLI0GiAZxVfRIWCRJCiBZGgkQD\nODzVQUKam4QQLYsEiQZweqQmIYRomSRINIDT7QYgxGQ5yZpCCPG/RYJEA1TXJELM0twkhGhZmjRI\nzJkzh6FDhzJp0iTve08++SQXX3wxU6ZM4c4776S8vLwpk3DaDpcepdJVCYDLo9ckrNInIYRoYZo0\nSEybNo2FCxf6vTd8+HCWL1/O0qVLadOmDS+++GJTJuG0fPLjLzy56Tn+svopnl+yDZeqBwmL1CSE\nEC1MkwaJ/v37ExUV5ffe0KFDMRj03fbu3ZucnJymTMJpWbJ2NwA2Stm0O88bJKzSJyGEaGGatU/i\ngw8+YOTIkc2ZhAZxVt1xHWaRmoQQomVptjGd//nPfzCbzX79FSdzpk9Yaij/ViUNZ1WfREpi7FlL\nw8mcK+k4F0he+Ehe+EheNI5mCRJLlixh1apVvPbaa6f0ubP1+FLFqPpemFzeaTls5U7ylOZ7hGq1\nmo9mbOkkL3wkL3wkL3zONFg2eZCo/Qjt7777jpdffpk33ngDi+XcbOM3GHxBQrHYcHpcmACzzAIr\nhGhhmrTUu/fee9mwYQPFxcWMHj2aO++8kxdffBGXy8VNN90EQK9evXjkkUeaMhmnTFVUb2eNEmKH\nqqAhU4ULIVqaJi315s2bF/De9OnTm3KXjULV3N4gYQipBKU6SEjHtRCiZZE7rmtRNQ0Vj/e1ElKJ\nUlWTMBuMzZUsIYRoFtJ+UovHo3qblwAUayWKUR/dJM1NQoiWRkq9Wlxuzdu8BGCMLvD+LUFCCNHS\nSHNTLW6P6m1eqs2gSHYJIVoWKfVqcXtUMOh9EjVH75qVc3O4rhBCNCUJErW4PKqvucntG800NfU3\nzZQiIYRoPhIkanG7fc1NmttXe+iakdBcSRJCiGYjQaIWt8fXca3VqElYTSHNlSQhhGg2EiRqcdUc\nAlsjSFgM0ichhGh5JEjU4qnZcV2juckiT6UTQrRAEiRqcXlUlCDNTTL8VQjREknJV4vbrXmbm6YP\n69rMqRFCiOYlQaIWd40+iQhzWDOnRgghmpcEiVqq75NQMMiIJiFEiydBohb9PgkPRsVIiFGChBCi\nZZMgUYvbo4LRjVmxYJbnRwghWjgJErW4PBqK0U2IIQRFae7UCCFE85IgUYvL7QGjG4tBmpqEEEKC\nRC02pxPFoGE1WkkJTwagT2KPZk6VEEI0jyZ9is6cOXNYuXIl8fHxLFu2DICSkhLuuecesrKySE9P\n5+mnnyYyMrIpk3FKyl2VYIBQs5UoSyT/HPFXGeUkhGixmrQmMW3aNBYuXOj33oIFCxgyZAhffPEF\ngwYN4sUXX2zKJJyyCqcNgDBzqPd/udtaCNFSNWnp179/f6Kiovze++abb5g6dSoAU6dO5euvv27K\nJJyySpcdgAhLaDOnRAghmt9Zv0QuLCwkIUF/NkNiYiJFRUVnOwn1srklSAghRLUm7ZNobImJTd93\n4dIcACTHxZ6V/Z2uczltZ5vkhY/khY/kReM460EiPj6e/Px8EhISyMvLIy4ursGfzcsra8KU6aqb\nm9x25azs73QkJkaes2k72yQvfCQvfCQvfM40WDZ5c5OmaX6vx4wZw+LFiwFYsmQJY8eObeoknBKn\nqtckQk3WZk6JEEI0vyYNEvfeey9XX301Bw8eZPTo0Xz44YfccsstrFu3jokTJ7J+/XpuueWWpkzC\nKatubgo1SpAQQogmbW6aN29e0PdfeeWVptztaXO5VTSDC4AwmSZcCCHkjuuabA43mKqChElGNwkh\nhASJGmxON0p1kDBLkBBCCAkSNdgcbjC6QFOwyrMkhBBCgkRNNrtekzArISgyT7gQQkiQqKnS4UEx\nurEoMrJJCCFAgoSfSrsLTC55bKkQQlSRIFFDudOOYlAJNUqntRBCgAQJP2WOCgDCTHKPhBBCgAQJ\nP2VOPUhEWCRICCEESJDwU+4sByA6RGaPFEIIkCDhp8Kj1yRiQ6ObOSVCCHFukCBRg60qSMSFRp1k\nTSGEaBkkSNTg0CoBiJUgIYQQgAQJPy7FBkifhBBCVJMgUYPHoD+VLtIc0cwpEUKIc0ODgsSnn35K\nebk+8ueZZ57ht7/9Ldu3b2/ShDUH1WgDjxmz0dzcSRFCiHNCg4LEf/7zHyIiIti6dStr1qzh8ssv\n57HHHmvqtJ1VTo8LzVKOySn9EUIIUa1BQcJk0h9gt3btWmbMmMGkSZNwOBxNmrCzLacyFxQwu2Oa\nOylCCHHOaFCQUBSFjz/+mOXLlzNkyBAAXC5XkybsbMsqzwEgxCNBQgghqjUoSDz44IN8/vnnzJgx\ng4yMDA4dOsSgQYPOaMevvPIKl112GZMmTeLee+/F6XSe0fbOVKGtCIAQTUY2CSFEtQYFib59+/L8\n889z/fXXA9C2bVseeuih095pbm4ur7/+OosXL2bZsmV4PB4+/fTT095eY3B53ACYDKZmTYcQQpxL\nGhQknnjiCcrKynC73fz617+md+/eLF269Ix2rKoqNpsNt9uN3W4nKSnpjLZ3plweDwBmo7FZ0yGE\nEOeSBgWJdevWERkZyZo1a0hOTuaLL75g0aJFp73T5ORkbrzxRkaPHs3IkSOJjIxk6NChp729xiBB\nQgghAp1S28oPP/zA+PHjSU5OPqNnQJeWlvLNN9/w7bffEhkZycyZM1m2bBmTJk2q93OJiU3XX2C0\n6McTHmpt0v00lvMhjWeL5IWP5IWP5EXjaFCQiI+P58EHH2Tt2rXccsstuN1uPFVX3qdj3bp1ZGRk\nEBOjjyQaP348mzdvPmmQyMsrO+19nkx5hT6kV3VrTbqfxpCYGHnOp/FskbzwkbzwkbzwOdNg2aDm\npnnz5tGxY0fmz59PdHQ0OTk53Hjjjae909TUVLZs2YLD4UDTNL7//ns6dOhw2ttrDG5V77i2mKTj\nWgghqjWoRIyLi+M3v/kNBw8eZN++fbRt25Zp06ad9k579uzJxIkTufzyyzGZTHTr1o0rr7zytLfX\nGNyq9EkIIURtDQoS27ZtY+bMmVgsFjRNw+1289xzz9G9e/fT3vEdd9zBHXfccdqfb2zVQcJilJqE\nEEJUa1CJ+PjjjzN37lzv3dbff/89jz76KO+8806TJu5s8kiQEEKIAA3qk7DZbN4AATB48GBsNluT\nJao5eFQVkD4JIYSoqUFBIjQ0lO+//977euPGjYSGhjZZopqDt7lJgoQQQng1qEScM2cOd911FxaL\nBdAn93v22WebNGFnm0eT5iYhhKitQSViz549+fLLLzl48CCaptGuXTsmTJjAypUrmzh5Z4+3ucks\nQUIIIao1uEQ0m8107tzZ+1rTtCZJUHOprkmESHOTEEJ4nfYzrs9kWo5zkaqpaBqESE1CCCG86i0R\n9+3bV+cyt9vd6IlpTh7NA5qC2XTacVMIIf7n1BskbrnlljqXhYSENHpimpOqqaAZJEgIIUQN9QaJ\nFStWnK10NDsVVa9JGCVICCFENSkRq+g1CWluEkKImqRErKJR3dwkE/wJIUQ1CRJV9NFNUpMQQoia\npESsokmfhBBCBJASsUp1kDCZ/rfu/xBCiDMhQaKKhgoYMBokS4QQopqUiFU0RUOR7BBCCD9SKlbR\nUFE0aWoSQoiaJEhUU1SpSQghRC3NViqWlZUxc+ZMLr74Yi699FK2bNnSXEmpomFQJEgIIURNzTbl\n6eOPP86oUaN49tlncbvd2O325kqKfre1gtQkhBCilmYpFcvLy9m0aRPTp08HwGQyERER0RxJAcCj\n6eCf+HkAABL8SURBVA8cMkiQEEIIP81SKh47dozY2Fjuv/9+pk6dykMPPdSsNQmn2wUgzU1CCFGL\nojXDI+a2b9/OVVddxTvvvEOPHj14/PHHiYyMZObMmWc7KQDc/I9PKW27jAhXOot+80CzpEEIIc5F\nzdInkZKSQkpKCj169ABg4sSJvPzyyyf9XF5eWZOkJ7ewgtC24HY13T4aU2Ji5HmRzrNB8sJH8sJH\n8sInMTHyjD7fLO0rCQkJtGrVioMHDwLw/fff06FDh+ZIik7R+yQczv+t53YLIcSZarbRTQ8++CD3\n3XcfbrebjIwM/v73vzdXUlAUPTi43BIkhBCipmYLEpmZmXz44YfNtXt/VUEiKvR/65GsQghxpmQ4\nD2C16v/3bJ/YvAkRQohzjAQJwGPQh9/GhUU3c0qEEOLc0uKDhKZp3iARZWm+G/qEEOJc1OKDhKpp\nYHICEGk5s6FiQgjxv6bFBwm3R0MxOwCIkiAhhBB+WnyQ8HjUGkFCmpuEEKKmFh8kXB4NxSzNTUII\nEUyLDxJ6TcKJQTMRYrQ0d3KEEOKc0uKDhNujgtGNEXNzJ0UIIc45LT5IuDwaisGDQYKEEEIEaPFB\nwuNRweDB2HwzlAghxDmrxQcJt0cDgweTIjUJIYSorcUHCYfbhWLQMCpSkxBCiNpafJCwu/R7JEzS\nJyGEEAEkSLj1eyRMBqlJCCFEbS0+SDiqgoRZkXskhBCithYfJOwevbnJbJDmJiGEqK3FBwmnxwVI\nkBBCiGBafJDwNjdJkBBCiADNGiRUVWXq1KnceuutzZYGp6oHCZm3SQghAjVrkHjttdfo0KFDcybB\n29xkMUpNQgghamu2IJGTk8OqVauYMWNGcyUB8NUkLMaQZk2HEEKci5otSMydO5fZs2ejKEpzJQEA\npyo1CSGEqEuz3EG2cuVKEhIS6Nq1Kxs2bGjw5xITG/+hQIrRA25IiI5qku03lfMprU1N8sJH8sJH\n8qJxNEuQ+Omnn1ixYgWrVq3C4XBQUVHB7NmzefLJJ+v9XF5eWaOnpcJhB8DtUJtk+00hMTHyvElr\nU5O88JG88JG88DnTYNksQWLWrFnMmjULgI0bN7Jo0aKTBoim4lLdYACrWUY3CSFEbS3+PgmP6gbA\napIgIYQQtTX7rHYDBw5k4MCBzbZ/t+YBIFRqEkIIEaDF1yTc1TUJCRJCCBGgxQcJD3qQCLVIkBBC\niNpafJBQpblJCCHqJEECPUiEyM10QggRoMUHCY+mNzfJk+mEECJQiw8SGh5QDc0+PYgQQpyLWnyQ\nUBUPaC0+G4QQIqgWXzpqeFA0Y3MnQwghzkkSJBQVBQkSQggRTIsPEigqBgkSQggRVIsOEm6PCgZp\nbhJCiLq06CDhcqtgUDEoEiSEECKYFh0knC4PikHFKM1NQggRVIsOEjaX/nxroyI30gkhRDAtOkjY\nXfrzrY3S3CSEEEG16CBRWqk/utRskHmbhBAimPMmSLg8LrLLcwBQNRW724GqqWe0zRKbDYAQkwQJ\nIYQI5rxpjL9l8UNUqCVYK9NwhRTgMdoxYuKC+K5M7XQJiWHxp7zNUptek5AgIYQQwTVLTSInJ4fr\nrruOSy65hEmTJvHaa6+d9DMVagkA9rAsPEY7amUELruZLQXbeHLj8xTaik45HcW2cgDCLaGn/Fkh\nhGgJmqUmYTQauf/+++natSsVFRVMmzaNYcOG0aFDhzo/MzTqUoa37cH3ew6SEBnJ/7d390FR1f8e\nwN+7KynyoCIrGJKDOPhTygdMsOCiFwkMQXYn0IlxakbNMgt5SMKdUeeq6Uw4zNRtHDMrs7g5eUt/\nU/izudH4dMW1SLQGLdExWIpdEZAnZV32c//gsoayiLl4kH2//trztPs9n+Hw3u+ec76nuXEIrDc7\n8H3NEbQF/YatJ7cjL+qVe+pRNLY3AQD8ho28730iIhqMFAkJrVYLrVYLAPDy8kJoaCgsFkuvIZH1\nbDKuXGnGeH+/bvP/vSEIm//nv9Dm/xt2nP4Mhqdeg0bdebWSpa0OgGDMcK1jfXOrBaOGjcIjGg80\nW5uAR4DRwxkSREQ9UfzEtclkwvnz5zF16tS/tb121HBk/dsi2K+NRm17DTaWFuJi42WUmc9gk3Eb\n/uNkAX6trwQAXL5WhY3GbfjvC/8EALTYOn9uCvBmSBAR9UTRE9etra3IzMyEwWCAl5fX336fkLG+\nmD92IQ5W/wt1o2tR+NP2bsvfLd+JkCHTUWW+BowG/vePU8j4RxpuSCsAQOvFkCAi6oliIWGz2ZCZ\nmYnU1FTEx8f3aRut1sfpsuUpkZhW8RgKSoogI6vQcU0LsQ4D7GoMefQSLtnPAD4e6Hr+nHXITdyw\nd4ZE6LggDBsy9H536YHqrRbuhrW4hbW4hbVwDZWIiBIfnJeXh1GjRmHt2rV93ubKlea7rtN24yaq\nLS34s74NlaZrmBg0Akfr/wWz6rfuK3YMATQ2qMUD/znvrXttvqK0Wp8+1cIdsBa3sBa3sBa33G9Y\nKtKTKCsrw9dff42wsDDodDqoVCpkZ2cjNjb2vt97+DAPTHpsFCY9NgpzpwcBAOzVk/Dlhc6QiBoT\nBaPFCGhsAABP9d//mYuIaLBTJCRmzpyJc+fOPbDPG+sV4HidEBLTGRL/z8+T5yOIiJx5aO64vh9h\nI0Mxd1w0IgMjoPX077Ys0MfPyVZEROQWIaFRa5AeluqYHj1sFK7e6LxDe8RQntwiInJG8fsklGCI\nzHG8fkTziIItISIa2NwyJP56uatCF3cRET0U3DIkACBxfBwAIHz0JIVbQkQ0cLnFOYmeJE9IwJxx\n0TwnQUTUC7ftSahVagYEEdFduG1IEBHR3TEkiIjIKYYEERE5xZAgIiKnGBJEROQUQ4KIiJxiSBAR\nkVMMCSIicoohQURETjEkiIjIKYYEERE5xZAgIiKnFAuJo0ePYv78+UhMTMTOnTuVagYREfVCkZCw\n2+3YtGkTPvzwQ3zzzTcoLi7GxYsXlWgKERH1QpGQOHv2LMaPH4+goCB4eHhgwYIFKCkpUaIpRETU\nC0VCwmw2Y+zYsY7pgIAAWCwWJZpCRES9UCQk+FxpIqKHgyKPLw0MDMQff/zhmDabzRgzZsxdt9Nq\n+SS5LqzFLazFLazFLayFayjSk3jiiSdQVVWFmpoaWK1WFBcXY968eUo0hYiIeqFIT0Kj0WDdunVY\nunQpRARpaWkIDQ1VoilERNQLlfAEAREROcE7romIyCmGBBEROcWQICIipwZ8SLjjGE8GgwFPP/00\nUlJSHPOuXbuGpUuXIjExEcuWLUNzc7Nj2ebNm5GQkIDU1FScO3dOiSb3i9raWrzwwgtISkpCSkoK\n9uzZA8A9a2G1WpGeng6dToeUlBS89957AACTyYRFixYhMTEROTk5sNlsjvWzs7ORkJCAxYsXd7vk\nfLCw2+3Q6/V45ZVXALhvLeLi4rBw4ULodDqkpaUBcPExIgNYR0eHxMfHi8lkEqvVKgsXLpTKykql\nm9XvfvjhB6moqJDk5GTHvLffflt27twpIiLvv/++FBQUiIjI4cOH5aWXXhIRkfLycklPT3/wDe4n\nFotFKioqRESkpaVFEhISpLKy0i1rISLS1tYmIiI2m03S09OlvLxcVq9eLQcPHhQRkfXr18vnn38u\nIiJFRUWyYcMGEREpLi6WrKwsRdrcnz7++GPJzc2Vl19+WUTEbWsRFxcnjY2N3ea58hgZ0D0Jdx3j\n6cknn4Svr2+3eSUlJdDr9QAAvV7vqENJSQl0Oh0AYNq0aWhubkZdXd2DbXA/0Wq1mDx5MgDAy8sL\noaGhMJvNblkLAPD09ATQ+c3YZrNBpVLBaDQiMTERQGctvvvuOwDd/14SExNRWlqqTKP7SW1tLY4c\nOYL09HTHvJMnT7plLUQEdru92zxXHiMDOiQ4xtMt9fX18Pf3B9D5z7O+vh4AYLFYEBgY6FgvICAA\nZrNZkTb2J5PJhPPnz2PatGm4evWqW9bCbrdDp9MhOjoa0dHRCA4Ohq+vL9TqzsM4MDDQsb9/rYVG\no4Gvry8aGxsVa7urbdmyBXl5eVCpVACAhoYGjBgxwi1roVKpsGzZMjz33HPYt28fALj0GFHkZrq+\nEt7CcVc91ajrwBksWltbkZmZCYPBAC8vL6f7N9hroVarceDAAbS0tGDVqlU9Dq/ftb+310JEBk0t\nDh8+DH9/f0yePBlGoxFA5/7dvs/uUAsA2Lt3ryMIli5dipCQEJceIwM6JP7uGE+D0ejRo1FXVwd/\nf39cuXIFfn5+ADq/CdTW1jrWq62tHVQ1stlsyMzMRGpqKuLj4wG4by26eHt7Y9asWThz5gyamppg\nt9uhVqu77W9XLQICAtDR0YGWlhaMGDFC4Za7xk8//YTvv/8eR44cQXt7O1pbW7FlyxY0Nze7XS2A\nzp4CAPj5+SE+Ph5nz5516TEyoH9ucucxnm5P/Li4OHz11VcAgP379zvqMG/ePBw4cAAAUF5eDl9f\nX0c3czAwGAyYOHEiXnzxRcc8d6xFfX294wqVGzduoLS0FBMnTkRUVBQOHToEoHst4uLisH//fgDA\noUOHMHv2bGUa3g9ycnJw+PBhlJSUoLCwEFFRUdi2bZtb1uL69etobW0FALS1teH48eMICwtz6TEy\n4IflOHr0KN566y3HGE8rVqxQukn9Ljc3F0ajEY2NjfD398frr7+O+Ph4rF69Gn/++SceffRRvPPO\nO46T2xs3bsSxY8fg6emJrVu3Ijw8XOE9cI2ysjIsWbIEYWFhUKlUUKlUyM7OxtSpU5GVleVWtfj1\n11+Rn58Pu90Ou92OpKQkrFy5EtXV1cjJyUFTUxMmT56MgoICeHh4wGq1Ys2aNTh37hxGjhyJwsJC\njBs3TundcLlTp07ho48+wo4dO9yyFtXV1XjttdegUqnQ0dGBlJQUrFixAo2NjS47RgZ8SBARkXIG\n9M9NRESkLIYEERE5xZAgIiKnGBJEROQUQ4KIiJxiSBARkVMMCXroLFq0CHq9HgsWLEB4eDj0ej30\nej0MBsM9v9fy5cv7NHT02rVrUV5e/neae08qKirw7bff9vvnEPUV75Ogh1ZNTQ3S0tJ6HdWza5iG\nh8W+fftQWlqKwsJCpZtCBGCAj91EdK9KS0tRUFCA6dOno6KiAqtWrUJ9fT2KioocD6HJz89HZGQk\nAGDOnDnYvXs3QkJCkJGRgRkzZuD06dOwWCxITk5GVlYWACAjIwOvvvoqYmJisGbNGnh7e+PixYsw\nm82IiIjA1q1bAXSOhZOXl4eGhgYEBwejo6MDcXFxWLx4cbd21tXVITc3Fw0NDQCAmJgYLF++HNu3\nb0dbWxv0ej2ioqKQn5+P06dPo7CwENevXwcAZGZmIjY2FlVVVcjIyEBycjLKyspgtVqxYcMGRERE\nPJBak5u4n4ddECnJZDLJ7Nmzu807ceKETJkyRX7++WfHvL8+kKWyslLmzp3rmI6NjZVLly6JiMjz\nzz8vubm5IiLS1NQkkZGRYjKZHMuOHTsmIiJvvPGGLFmyRG7evCnt7e0yf/58MRqNIiKycuVK+eCD\nD0REpLq6WmbMmCF79+69o+27du2S9evXO6abmppEROSLL76QnJycbm3X6XRy9epVERGpra2V2NhY\naWlpkd9//10mTZokxcXFjn2fO3eu2Gy2vheR6C7Yk6BBZ8KECXj88ccd05cvX8a7774Li8UCjUYD\ni8WCxsZGjBw58o5tn332WQCAj48PQkJCUFVVhaCgoDvWe+aZZzBkSOfhM2XKFFRVVSEyMhJGoxGb\nN28GAIwbN87RY7nd9OnT8dlnn2Hbtm2YNWsWYmJielyvrKwMJpMJy5Ytcwz6qNFoUF1djeHDh8PT\n0xNJSUkAgKeeegoajQaXL19GaGhoX8tF1CuGBA06Xl5e3aazs7OxYcMGzJkzB3a7HVOnTkV7e3uP\n2w4dOtTxWq1Wo6Oj457W6+tzCmbOnIn9+/fjxIkT+PLLL7Fr1y58+umnd6wnIggPD8fu3bvvWFZV\nVXXHPLvdPqielUDKe3jO6BH1QPpw3UVLS4tj1M+9e/c6/cfvCpGRkY4hmmtqanDq1Kke1zOZTPD2\n9kZSUhLy8/Pxyy+/AOh8VsRfH1ofERGByspK/Pjjj455Z8+edby+fv06Dh48CKDz8Z0AMH78eNfu\nFLk19iToodaXb80GgwErVqzA2LFjERUVBR8fnx63v/29nC3rbb1169bhzTffRHFxMSZMmICIiIhu\nn9eltLQUe/bsgUajgYhg06ZNAIDo6Gh88skn0Ol0mD17NvLz87F9+3YUFBSgubkZN2/eRHBwMHbs\n2AEA8Pf3x4ULF5Ceng6r1YrCwkJoNJq71oSor3gJLJELtbe3w8PDA2q1GmazGenp6SgqKkJwcLDL\nP6vr6qbjx4+7/L2JurAnQeRCly5dwtq1ayEisNvtyM7O7peAIHpQ2JMgIiKneOKaiIicYkgQEZFT\nDAkiInKKIUFERE4xJIiIyCmGBBEROfV/smX5vm0Z6kkAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f970d490590\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test_accuracy 0.1\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXl4FFW6/79V1Vv2BEhIAG/AuCAIsgjoCEFgFDSsio7I\n6Dg4l/GODgpu4wxcnWHEHYXBDQVllJ/LRUAhDKCgYd+XsCVhS0IWOnvSSXqtqt8f1V1d1V2d7iwN\nIbyf5+Ghu6q66tRJ1fmedznnMKIoiiAIgiCIILCXuwAEQRDElQEJBkEQBBESJBgEQRBESJBgEARB\nECFBgkEQBEGEBAkGQRAEERIkGARBEERIkGAQVx0HDhzAPffcc7mL0eEZOHAgioqKLncxiDaEBIOQ\nGT16NPr164eamhrV9kmTJqF3794oKSkBAPzlL39B7969cezYMfmYwsJC9O7dW/7+yCOPYNWqVfL3\njz76CGPGjMGgQYNw5513Ys6cOQCA8ePHY9CgQRg0aBD69OmD/v37Y+DAgRg0aBCWLl3qV8YlS5bg\nhRdeaNV93nrrrfjPf/7TrN98/PHHePfdd7Fv3z6MHDmyVdf34FtHHY3Dhw+jR48el7sYRBuiu9wF\nINoXPXr0QGZmJqZPnw4AyMvLg91uB8Mw8jEMwyA+Ph7vvfceli1bptquxZo1a7Bu3TqsWLECPXr0\nQGVlJbZu3QoAWL9+vXzcI488gsmTJ+P+++9v1T2IohiwLC0lKysLzz33HJxOZ5ufu73C8zw4jrvc\nxSDaEWRhEComTZqENWvWyN/XrFmDKVOm+B03ZcoU5Obm4sCBA0HPefz4cQwfPlzubXbu3BkPPPCA\n5rFNzVSzfft2fPTRR9iwYQMGDhyIyZMnA5CE5t1338W0adMwYMAAFBUVYfXq1bj33nsxaNAg3HXX\nXfjmm2/k8/haCaNHj8by5csxceJEDBkyBHPmzIHD4ZD319XVoaCgAH369MHMmTNRVlYmW0Hl5eUQ\nRRFLly7FXXfdhdtuuw2zZ89GXV0dAMDhcOD555/HsGHDMGTIEDzwwAOoqqrCu+++i4MHD2L+/PkY\nNGgQ/vnPf2re89NPP43hw4djyJAheOSRR3DmzBl5n91ux+uvv47Ro0djyJAhmD59ulzuAwcO4KGH\nHsKQIUMwatQorF27Vq4rpVWzZs0aPPzww/L33r17Y+XKlRg7dizGjh0LAHj11Vdx5513YvDgwbj/\n/vtVf3NBEPDRRx/hrrvukvebzWb5XBcuXJDr4Y033sCoUaMwfPhwvPLKK3JZq6ur8cQTT2DIkCEY\nNmwYfvvb3wZ8BojLCwkGoeKWW25BQ0MDzp07B0EQsHHjRkycONGvITeZTHjiiSewcOHCkM65du1a\nLFu2DMePH4cgCC0q24gRI/DEE0/g3nvvxeHDh+VGEADWrVuHf/7znzh06BBSUlLQuXNnLF26FIcO\nHcJrr72G1157DadOnZKP97USNm7ciOXLl2PLli3IyclRieaOHTtw2223wWQy4ZNPPkFSUhIOHz6M\nQ4cOITExEStWrMDWrVuxcuVKbN++HbGxsfj73/8OQGqQ6+vrsX37duzbtw9///vfYTQaMXv2bAwe\nPBjz5s3DoUOHMHfuXM17HjlyJH788Ufs2rULffr0wXPPPSfve/3113Hy5El888032LdvH55//nkw\nDIPS0lLMnDkTjz76KPbs2YO1a9eq3IW++NbF1q1bsWrVKmzYsAEA0L9/f/zwww/Yv38/JkyYgGee\neUZu7JcvX44NGzbg008/xcGDB7FgwQKYTCa/87711lsoKCjADz/8gM2bN8NsNuP9998HAHz22WdI\nTk7G3r17sWvXLsyePTtgWYnLCwkG4cekSZOwdu1a7Ny5E9deey2SkpI0j3vwwQdRWlqK7du3N3m+\niRMnYt68edi5cyceeeQR/OpXv9KMT7SGKVOmIC0tDSzLQqfTYeTIkbJFc+utt+KOO+5o0hp69NFH\n0aVLF8TGxmLUqFEqcfnll1+ajFt8++23eOaZZ5CUlAS9Xo8nn3wSmzZtgiAI0Ol0qKmpwfnz58Ew\nDPr06YOoqKiQ7+u+++5DRESEfN6cnBzU19dDFEWsXr0ac+fORWJiIhiGwYABA6DX67Fu3Trccccd\nuPfee8FxHOLi4poUDF/++Mc/IiYmBgaDAQAwYcIExMbGgmVZPPbYY3A4HDh//jwAYNWqVZg9ezZS\nU1MBADfeeCPi4uIAqK3FVatW4aWXXkJMTAwiIyMxc+ZM2R2p0+lQXl6OoqIicByHwYMHh1xW4tJC\nMQzCj4kTJ+K3v/0tioqKMGnSpIDHGQwG/OlPf8KiRYvwzjvvNHnO8ePHY/z48eB5Hj/99BOeffZZ\n9O3bF3fccUeblDk5OVn1PSsrCx988AHy8/MhCAJsNhtuvPHGgL/v3Lmz/DkiIgLl5eUApEZv165d\neOmllwL+tqSkBE899RRYlpV/o9PpUFFRgUmTJuHixYuYM2cOLBYLJkyYgDlz5oQUGxAEAQsXLsSm\nTZtQXV0NhmHAMAyqq6vhcDjgcDhwzTXX+P2utLRUc3uo+Nbl8uXLsWrVKrlOGhoaUF1dDQC4ePFi\n0GtVVVXBarWqYlOCIMiC8vjjj2PJkiWYMWMGGIbBAw88gJkzZ7a4/ET4IAuD8KNbt27o3r07tm3b\nhrvvvrvJY++77z5YLBb8+OOPIZ2b4ziMHTsWN954I06fPt0WxQWgdn84HA48/fTT+MMf/oDdu3dj\n//79SE9PbzI+Eohjx46hR48eSEhI8LuOh5SUFHzyySfYt28f9u3bh/379+PIkSNISkqCTqfDk08+\niczMTHz99df45ZdfZFdasOD5unXr8PPPP2PFihU4cOAAtm7dKt9DQkICjEYjCgsLNcujtR0AIiMj\nYbPZ5O8eEVCiLNeBAwfw6aefYvHixdi/fz/279+P6OhouRzJyckBr+UhISEBERERWL9+vVxHBw4c\nwMGDBwEAUVFRePHFF/HTTz/ho48+wueff449e/Y0eU7i8kCCQWiyYMECrFixQvZHB4LjODz11FP4\n5JNPAh6zZs0aZGVloaGhAaIoIisrC2fPnkX//v2bXa4uXbqguLi4ycbf6XTC6XQiISEBLMsiKysL\nO3fubPa1AMkdlZ6eLn/v3LkzampqUF9fL2/7zW9+g4ULF8ppx1VVVdiyZQsAYO/evcjLy4MgCIiM\njIROp5Otiy5dushBYS0aGhpgMBgQGxuLxsZGvPPOO3JjzjAM7rvvPrz++usoKyuDIAg4cuQInE4n\nJkyYgN27d2Pjxo3geR41NTXIyckBIAWiN2/eDJvNhoKCAnz33XdN3n9DQwN0Oh3i4+PhcDiwZMkS\nNDQ0yPsfeOABLFq0CAUFBQCA3Nxc1NbWqs7hsRoWLFiAqqoqAIDZbMaOHTvkOvaITmRkJDiOo+ys\ndkrYBWPbtm0YN24cxo4dG9BvvWHDBmRkZGDChAmqoB5xaVH2LK+55hr07dtXc58v48ePR1JSkl/q\nrYfo6Gh89NFHcjbPO++8g1deeQWDBg0KeP1AjBs3DqIoYtiwYbjvvvs0fxcVFYW//e1vePrppzF0\n6FBs2LABY8aMCXjOpq6blZWlil9ce+21yMjIwJgxYzB06FCUl5fjd7/7HcaMGYMZM2Zg8ODBeOih\nh5CdnQ0AqKiowKxZszB48GCMHz8ew4YNw8SJEwFIcZONGzdi2LBhePXVV/2uPXnyZKSkpCA9PR3j\nx4/HwIEDVftffPFF3HDDDZg6dSqGDRuGd955B6IoIiUlBUuXLsXy5csxdOhQTJkyRRaMxx57DHq9\nHnfccQdeeuklTJgwocm6GDFiBEaMGIGxY8dizJgxiIiIULmsfv/73+Oee+6R733u3LmyBaM813PP\nPYfU1FQ8+OCDuPXWWzFjxgzk5+cDAPLz8/HYY49h4MCBmDZtGqZPn44hQ4YE/JsQlw8mnCvuCYKA\nsWPH4vPPP0dSUhKmTp2KhQsXIi0tTT6moKAAs2fPxr///W9ER0ejqqoKnTp1CleRCCJkKisrMXny\n5KBBfYK4WgirhZGdnY3U1FR0794der0eGRkZsqnu4dtvv8XDDz+M6OhoACCxINoNFoulyWA3QVxt\nhDVLymw2IyUlRf7etWtX1XQSAGSzdNq0aRBFEU8++SRGjBgRzmIRREj07NkTPXv2vNzFIIh2Q1gF\nIxRvF8/zKCwsxMqVK1FSUoLp06cjMzNTtjgIgiCI9kFYXVLJycly5gggWRy+g8C6du2KMWPGgGVZ\n9OjRA7169ZKtjkCEMexCEARBBCCsFka/fv1QWFiI4uJiJCYmIjMz028qiV//+tfIzMzE5MmTUVVV\nhYKCgqADgRiGQXm5JZxFv2JITIyhunBDdeGF6sIL1YWXxMSYVv0+rILBcRzmzZuHGTNmQBRFTJ06\nFWlpaVi8eDH69euHUaNGYcSIEdi5cycyMjLAcRxeeOEFeWoBgiAIov0Q1rTacEI9BgnqPXmhuvBC\ndeGF6sJLay0MGulNEARBhAQJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBEGEBAkGQRAEERIkGARB\nEERIkGAQBEEQIUGCQRAEQYQECQZBEAQREiQYBEEQREiQYBAEQRAhQYJBEARBhAQJBoA6hwX/PvkN\nNpz/8XIXhSAIot0S1vUwrhROVuZi78WDAIDR16TDpDNe5hIRBEG0P8jCAMCLvPxZEIXLWBKCIIj2\nCwkGAEGxhhQJBkEQhDYkGFCLhAASDIIgCC1IMOAjGGRhEARBaEKCAUBUiAQvkGAQBEFocVUKRo29\nFg7eKX/nFYIhkkuKIAhCk6tOMGwuO/6281W8tv9deZuoCHrz5JIiCILQ5OoTDN4GAChrrJC3KQPd\nIgkGQRCEJledYDAat6y0KsjCIAiC0KZDC0aj04ptRbtU8QqtGIWoypIS/fYTBEEQHXxqkK9yv8Oh\nsmzUOiyYcO1YANpZUOqBe7zffoIgCKKDWxhFlhIAQLkiXqFlYQhkYRAEQQSlQwuGRwhYxnubWjEK\nGrhHEAQRnLALxrZt2zBu3DiMHTsWS5cu9du/Zs0a3H777ZgyZQqmTJmCVatWtdm1BUjWglIwtARB\nmSVFgkEQBKFNWGMYgiBg/vz5+Pzzz5GUlISpU6dizJgxSEtLUx2XkZGBuXPntvn1PeMrGDDeMmla\nGDT5IEEQRDDCamFkZ2cjNTUV3bt3h16vR0ZGBrZs2eJ3nBimuIHXJeUVDK2pzMklRRAEEZywCobZ\nbEZKSor8vWvXrigrK/M7bvPmzZg0aRKefvppXLx4sc2u73E1MQFcUk7B5beNpywpgiAITcIqGKFY\nDqNHj8bWrVvx/fff4/bbb8eLL77Y5tcPFMNw8A73cd5tDXbvmA2CIAjCS1hjGMnJySgpKZG/m81m\nJCUlqY6Ji4uTPz/44IN4++23Qzp3YmJM8IPcnqjICIN8fAVM8u7YeCO6RMVAf46Tt8XEGEM7dzvi\nSitvOKG68EJ14YXqom0Iq2D069cPhYWFKC4uRmJiIjIzM7Fw4ULVMeXl5UhMTAQAbNmyBdddd11I\n5y4vtwQ9hhck95Ld5pKPr6qul/dfLK+G2KhHo9Uhb6uubQjp3O2FxMSYK6q84YTqwgvVhReqCy+t\nFc6wCgbHcZg3bx5mzJgBURQxdepUpKWlYfHixejXrx9GjRqFL774Alu3boVOp0NcXBxee+21Nrt+\nsHEYDncMQzmYT6D1MAiCIDQJ+9Qg6enpSE9PV22bNWuW/HnOnDmYM2dOWK7tSZdlGO20WqcgWRZK\nEXEJFPQmCILQokOP9BY9A/cQKOgtBbhpxT2CIIjgdFjBqKi1yu4lRjUOQ2lhSIIhqBZQIguDIAhC\niw4rGC98uFsWBzbASG+tcRgusjAIgiA06bCCAQAKw0JGaxyGQC4pgiCIoHRowfDgmYQQ8LUmtEZ6\nk2AQBEFocXUIRgBB8GREKbdRWi1BEIQ2V51gKFfUcwpO/FSYhdM1Z+VtFPQmCILQpkMu0eobhxAD\nTF/uElxYf35zk78lCIIgJDqkhWF3qK0E5QJJvEYMQwnFMAiCILTpkIJh8xGMBpt3riilteHUEAyK\nYRAEQWhzVQjGmeIa+TOvimFoWRgUwyAIgtDiqhCMQEuwarukwrP6H0EQxJVOBxUMtRAEWoJV2yVF\nFgZBEIQWHVQw1I0+H3Achv/qehT0JgiC0KaDCoYLUIzuFgNYGC5Rw8IgwSAIgtCkQwqG3cEDjFIk\ntGMYTp7SagmCIEKlQwqGwyUAjLZIBLMwRBIMgiAITTrkSG8XLwCst+G3O1348PtjiE7LwxnLGXm7\ndlotCQZBEIQWHVIweF5UuaTAiDhwthCmmL2q47TSaimGQRAEoU2HdEm5BAEMq8iU0lgXQzqOBIMg\nCCJUOqZg8KLKJRVl4jSP0xyHQQP3CIIgNOmYguESAIWFkdQpAso0WwBgwcpreitRTlRIEARBeOmY\ngiGIYBQWBsOIqqwpAICokxdQUkIuKYIgCG06pmDwagtDSzAYgdMc6U2CQRAEoU0HFQwX2Ngq+ftF\n60UweofqGJHnNGMYNA6DIAhCmw4pGMW6Q9CnnJe/23g7jDftUx0j8tqBcAEU9CYIgtCiQwqGhSsO\negzv1B6CQhYGQRCENh1SMBhRH/QY0aUtGJQlRRAEoU2HFAwI2u4mJSIfQDDIwiAIgtAk7IKxbds2\njBs3DmPHjsXSpUsDHrdx40b07t0bJ06caPU1GSGEGU98BCMlojtEgYFIMQyCIAhNwioYgiBg/vz5\nWLZsGdavX4/MzEycPXvW77iGhgZ8+eWXGDBgQBtdOLhgKC2MCFs3PNzzdwAYimEQBEEEIKyCkZ2d\njdTUVHTv3h16vR4ZGRnYsmWL33GLFi3Cf//3f0OvDx57CIkA7iYPDBjVMRwM0LEcIDKUJUUQBBGA\nsAqG2WxGSkqK/L1r164oKytTHXPq1ClcvHgRI0eObLPrBkqZ9aBnjBAVcQ6O4cCyDCAyECnoTRAE\noUlYpzcXg0zkJ4oiFixYgDfeeCPk33hITIwJfN4gv43QG1GvEBW9jkOXztEAGDBM0+duj1xp5Q0n\nVBdeqC68UF20DWEVjOTkZJSUlMjfzWYzkpKS5O8NDQ04c+YMHnnkEYiiiIqKCvzpT3/Chx9+iL59\n+zZ57vJyS8B9vMYcUUoMPhaGKDCorWkERAa8wDd57vZGYmLMFVXecEJ14YXqwgvVhZfWCmdYBaNf\nv34oLCxEcXExEhMTkZmZiYULF8r7o6OjsXv3bvn7I488gpdeegl9+vRp1XWDTVEeoTepUm85hiWX\nFEEQRBDCKhgcx2HevHmYMWMGRFHE1KlTkZaWhsWLF6Nfv34YNWqU6niGYUJ2STVFsLEUPWJScFpp\nhIgsWIaBKFJaLUEQRCDCvkRreno60tPTVdtmzZqleey///3vNrmmCAEMgFE9huPnoh2qfcmRSXjg\nhonY+mOm4geMZGGALAyCIIhAdMiR3h6X1K+6DfXbd0+vX8PA6QFVJhUDTnZJkYVBEAShRYcTDFEU\nZSuBZfxvz7NNFfT2WBgkGARBEAHpcILh4r0NPsswfvs5t2BE6k3ejQIjHSsyCJ6USxAEcXXSAQVD\nkFfXa8rCeO0Pw+Vtouh2SVEMgyAIIiAdTjB4wbscK6NxeywjuaJiIg3ejeSSIgiCCEqHEwwXL8Dj\nVmrKJaVEEL1Bb9+1vwmCIAiJDicYgsLCYBkWg5NuUe3XEozrr4kiC4MgCCIIHU4weEEEoxCMGTdP\nx8ged8j7PS4pADByklsqMtojEgxAMQyCIAhNOpxgCILXQmDcLikd6xUJZSA8UhcJAGhwNkrHgwEY\nWnWPIAhCi6CCYTabL0U52gxl0Jt1356eUax9oRCMLhGdAACe2Ug8QfIXtr8SdAJDgiCIq42ggnH/\n/ffjz3/+s2qSwPaMSjDcFgYXwMJ45KbfYHDSLZiUNs69RTre6rLB4qy/NAUmCIK4QggqGFu3bsWY\nMWPw3nvv4d5778XKlStRX99+G1PJJaUeh6FjvRaGUjA6RyRgxs3TkWCKB6BOw9Uaw0EQBHE1E7RV\nNBgMmDx5Mr755hv885//xCeffIL09HTMnz8flZWVl6KMzUI1DkNDMJTWhi8svGm4bTFrLkEQREci\npG50cXEx3nnnHTz77LO4/fbb8emnn6Jz5854/PHHw12+ZqNKq3ULgC5ADMMfr2BQ4JsgCEJN0OnN\nn3jiCeTl5eGhhx7C6tWrkZCQAAAYNGgQNmzYEPYCNhdeUA7ck8Qh2hAl72/K1cSChSfUzZNgEARB\nqAgqGJMmTcLdd98NjvN35axfvz4shWoNXguDkdNqY/TR8v6mBINR7BNEypIiCIJQEtQlFRcXh8bG\nRvl7XV1du86Y4kVp4B6jcC/FGryCwTGBYxgMuaQIgiACElQw3nzzTURHexvc6OhovPnmm2EtVGvg\nec/Eg97GP8bgXfi8qRiG0voglxRBEISaoIIhiqLs2gEAlmXB8+3XXeNxSSlTZCN03rUvmnRJKUSG\nJ5cUQRCEiqCCERUVhaNHj8rfjx49isjIyLAWqjXw7nEYysZfJXhNuKRYKGMYZGEQBEEoCRr0fv75\n5/Hkk0/iuuuuAwCcOXMGS5YsCXvBWoogui0MjanNAe0pzz2og94kGARBEEqCCsbAgQORmZmJI0eO\nQBRFDBw4EHFxcZeibC3CM3CP9TGeJl47Dudq85sMeivFhOaSIgiCUBNUMAApU2rkyJHhLkubIGi4\npABgbM/RQX+rtEoo6E0QBKEmaAwjJycHv/nNb3DLLbfgpptukv+1VyQLAwFdUk2iWG3vqy25KC5v\nv3NmEQRBXGqCCsYrr7yCZ555BqmpqcjKysLMmTMxe/bsS1G2FqGVJRUqjEIwiivqsXxDTlsWjSAI\n4oomaKvqcDhw++23QxRFJCUlYfbs2di+ffulKFuL8GRJNRXcDoTqJ4wIp4vcUgRBEB6CCgbLSofE\nxcUhJycH1dXVKC4uDnvBWopnidaWWBhKlxQY0T0vFUEQBAGEEPTOyMhAdXU1Zs6ciWnTpkEQBMya\nNetSlK1FCK2yMETVZ8+ocYIgCCKIYAiCgNtvvx0JCQlIT0/Hvn37YLfbVVOFBGPbtm1YsGABRFHE\n/fffj5kzZ6r2f/3111i5ciU4jkNUVBT+8Y9/IC0trWV3A/dstYzYsgWQfFxSZGEQBEF4abJVZVkW\nf/vb3+Tver2+WWIhCALmz5+PZcuWYf369cjMzMTZs2dVx0yYMAHr1q3D2rVr8fjjj+O1115r5i34\nXrMVQW8oXVICXAJZGARBEB6CtqppaWkoKipq0cmzs7ORmpqK7t27Q6/XIyMjA1u2bFEdExXlXaui\nsbFRjpm0FHngXossDJ8YBrmkCIIgZILGMKqqqjBx4kQMHjxYNYfUokWLgp7cbDYjJSVF/t61a1cc\nO3bM77iVK1fi888/h8vlwooVK0ItuyaC6Fk8qXXjMMCI4LucxvGKLri5S/sdd0IQBHGpCCnonZGR\n0aKTh7ou9vTp0zF9+nRkZmbigw8+wOuvvx70N4mJMZrbjUY94BCh1+kCHhMInY4BXNJnhnMBKXn4\nMDsH3/7mw2ad51LT3PvsyFBdeKG68EJ10TYEFYwpU6a0+OTJyckoKSmRv5vNZiQlJQU8/t5778XL\nL78c0rnLyy2a2y31dgAiRCHwMYHgeUWQW+cIeq32QGJiTLsu36WE6sIL1YUXqgsvrRXOoIIxa9Ys\nzWk2QnFJ9evXD4WFhSguLkZiYiIyMzOxcOFC1TEFBQVITU0FAPz888/o2bNniEXXRhBEgG1ZWq3I\niJ7lwMHonK0qB0EQREcjqGCMGjVK/my327Fp06aQ0145jsO8efMwY8YMiKKIqVOnIi0tDYsXL0a/\nfv0watQofPnll9i9ezf0ej1iY2PxxhtvtPxuALhEAQzT9EJJAcurmMmWBIMgCEJNs11S9913H/7n\nf/4n5Aukp6cjPT1dtU058E+ZttsWCO6xEy0RjP6mdGyuKwTD8SqXFEEQBBFCWq0vDMO0OM32UuCS\nBaP5LqkYXRwceYMBAAwJBkEQhIpmxTBEUURubi5uv/32sBespbTGwuBYBhCleyWXFEEQhJpmxTA4\njsOMGTMwYMCAsBaqNXim8+BaKBiiRzD0JBgEQRBKwppWezngW2Fh6DhWtjAIgiAINUFb1WnTpqG2\ntlb+XlNTg+nTp4e1UK3B5V5atSVTjBgNHCD6/47W9yYIgghBMBobGxEXFyd/j4+PR319+126VGiF\nS8pk4DQtDJdIgkEQBBG0VRUEAY2NjfL3hoYG8Hz7bUB5d+POsVyQI/0xGXTagiG4Wl0ugiCIK52g\nMYzx48djxowZmDZtGgDgq6++wsSJE8NesJbiEqV0WBNnaPZvA1kYToEC4ARBEEEF449//COSkpKw\ndetWiKKIhx56CJMnT74UZWsRLlFq3I06Y7N/azJwUK+i5D4nWRgEQRDBBQOQMqWulGwpwT3dbEst\nDJFcUgRBEJoEjWH8+c9/Rk1Njfy9uroaTz/9dFgL1RpccFsYXPMtjEBZUk4SDIIgiOCCceHCBcTH\nx8vfExISUFhYGNZCtQaXKDXuhhZYGByrPQ7DwVMMgyAIIqhg8DyvyopyOp1wONrvPEs8pLIZWyAY\nADQFw+4kwSAIgggawxg+fDhmz56NRx99FACwYsUKv9ln2xOeGEZLLAwAmoJhdbZfgSQIgrhUBBWM\nOXPm4OOPP5aXTR01ahSGDRsW9oK1FF6OYbRUMPyNLhIMgiCIEFxSer0eTz31FN5//33cdddd+OGH\nH/DXv/71UpStRQiMZGG0JOgtwUAU1NXicJFLiiAIokkLw+VyYevWrfjuu+9w5MgRuFwuLFu2rF3P\nVutxSbXYwgAAgQNY7/reNhIMgiCIwBbGa6+9hjvvvBNff/01xo8fj6ysLMTFxbVrsQCUFkbLBOOv\nvx2MCL36txabDau3nUNtA7mmCIK4egloYXz11VcYOHAgZs6cidtuuw0A5IWU2jMi4wKDlge9r+sR\nh+hCI2yoDujAAAAgAElEQVS2BnnbgbwSVJ5lcMFswdMP3NJGJSUIgriyCCgYO3bswLp16/Dmm2+i\ntrYWkydPbteTDnoQWUkwWh7DUE6NzgAQYbHbAACVdfZWl48gCOJKJaBLKjY2FtOnT8fq1avx/vvv\no7a2FjabDdOnT8fXX399KcvYLES29TEMxl0tBkYSHY+b6wowsAiCIMJGSItG9O7dG3PnzsX27dsx\nffp0bNmyJdzlahGCKIJheUBkWjS9uQfWrQwGxgQA0jmhNS0hQRDE1UNIkw960Ov1uPfee3HvvfeG\nqzytgudFgBHAiC0XC8C7vKue1QMCAM49lxQpBkEQVzHNX5auHePiBYAVwKJ1gsG4lUHHSHoqWxjk\nkyII4iqmQwkGL7gtjFYKhsclpWP07g28e3urTksQBHFF0+EEg2EFsK10SXmC3hzLQRQBcLy8hyAI\n4mqlYwmGxyXFtI2FAYiAwIFhKUuKIAiiQwmGy+2SanUMg/FUiwjwOtnCIMEgCOJqJuyCsW3bNowb\nNw5jx47F0qVL/fZ//vnnyMjIwKRJk/D73/8epaWlLb6WZGHwbRb0FhkRosAp0mpJMQiCuHoJq2AI\ngoD58+dj2bJlWL9+PTIzM3H27FnVMX369MHq1avx/fff4+6778abb77Z4us5XTwYBuCaly3sh8ol\nxXNyWi1ZGARBXM2EVTCys7ORmpqK7t27Q6/XIyMjw2/Q39ChQ2E0SiOqBwwYALPZ3OLr2d1LqbY2\nhqF0SYkC586SEsm+IAjiqiasgmE2m5GSkiJ/79q1K8rKygIev2rVqlat5md3L3TEtTbo7XFJuWMY\nDAMpXZdMDIIgrmJa57sJgiiKIR/7/fff48SJE/jiiy9COj4xMcZvm6myAgBg1Bk194fKr3oNQk71\naQxM6Y8LxYekjRwPg0HXqvOGi/ZYpssF1YUXqgsvVBdtQ1gFIzk5GSUlJfJ3s9mMpKQkv+N27dqF\npUuX4ssvv4Rerw/p3OXlFr9tFVW1AABRYDT3h8rAuIH429Du4BzR+F48LG1kBDidrladNxwkJsa0\nuzJdLqguvFBdeKG68NJa4QyrS6pfv34oLCxEcXExHA4HMjMzMWbMGNUxJ0+exMsvv4wPP/wQCQkJ\nrbqeJ4aha/U4DBbdopOh43SAe7lWhlxSBEFc5YTVwuA4DvPmzcOMGTMgiiKmTp2KtLQ0LF68GP36\n9cOoUaPw1ltvwWq14umnn4YoiujWrRs++OCDFl3PyUvZTJ45oFqLjmMA0T3V+fWHwdti2+S8BEEQ\nVyJhFQwASE9P9wtkz5o1S/782Weftdm1HG4Lg2Pb5rY4jpUFg42yoNywF0DLg/IEQRBXMh1qpLfD\nbWHo20owWAai4K0igXG2yXkJgiCuRDqUYDgFdwyjDQXDY2FIUAyDIIirlw4lGK42jmFwHCMHvQFp\napAfC37Bq3sXwim42uQaBEEQVwodSjAcotslxbWNYLAMA4hKq4JBTtVplDRcRJ2d0vQIgri66FCC\n4XIHvfVsaGM5gsEwjHoxJhGwOOsBAA7B0SbXIAiCuFLoUIJhd0mCYdS1jWAAUM98KzKoc0iWhZ23\nt9k12gKn4EJ+XWGzRtd3VBy8EwV1Fy53MdoFdt6Bwrqiy12MdoHNZcMFS/HlLsYVTYcSjEaH1OuP\ndk9m2BawiioSIaLe0QAAcPDty8JYeWoV3jqwBMcqTl7uolx2lp9YiTcP/At51WeDH9zB+fDocrxx\nYDE1lADePfQRXt+/CGWN5Ze7KFcsHUowrC6pEY+JMLXZOZUz3wqMQ5qQEFLPrT1xwCxNYZJPPWtZ\nNIuokcTpmnMAgNKGls8C3VEoqpemKaqwVl3mkly5dCjBsDsll1SbCobCJcVzXjdUexMMz7QlHkEj\nQDVBaELPRcvpUIJhc1sYkQZDm51TKRgi6xWJ9hbDoNUA/SHx9EKxLSVUFy2lQwmGZ2qQthq4BwRe\nW6PdWRju/6lhIIimoXek5XQYwRBFUZ58sK3SaoHAq/fll1Uj70JNm12n1dBMun5Qw+CFrC0vVBct\np8MIht3JQwAPoO1Gekvn0haMPaeK8frKQ6i3to/5pbyrkNPLQPhDT4UXXhQudxGuWDqMYNRbnQAr\nPQhtNdIbCOySYlhJnBrajWA0HfQ+W1KLV784gGqLN/bCCzyWHPkUe0sP4oKlGG8fWNKhMkiaI55O\nwYVFh5fiUFk2ztcW4O0D76PGXhvG0l1axGY0knbegXcPfYhjFSeRV30W7xx8X04nvxQIogi7kw/b\n+V3NmNan0WnFOwc/wKmqPJyszMXCgx/A6rKGrWztnbBPb36psDl4MIz0UrRpDIMNsBgTJz3Q4Xyw\nm4XHJRWgjVy8KhuWRic27C7A9LtvAACUNJhxqioPp6rykBLVFaUNZvxw9j+YcfP0S1ToMNOMbvXZ\nmvPIqz6DvOoziDXEoM5hwab8n/GbGyeHr3yXkOY0kscqTuJMzXmcqTkPlmEhiAKyincho9ddYSyh\nl7e/Ooycwhosff5O6Li279M2Zx64/ebDOFebjyVHPpW37bt4GCN7/KrNy3Ul0GEsDLuDly2MtnRJ\ncYHO5bYwbI52IhhuAvWqBcF/u15DDHmxfd1Pa2iO60HZyfDUYUeqi+Y0klodLkG4dHWRUyjFBsP1\nbrmE0L0CWi7pjvRcNJcrVjB4gVdlKtkcPMBIf8i2Wg8D8HdJiU4pZZdxWxiOdmJhsB6XlChqZnBJ\n8R0RDCOZ2YBk+su/Z6RHwdPIWl02CFe4r9cpODVH5Dt4h3xvnrpQ/p09nwW5Lqwdti7s7roQRVGu\nC2XSiP9zYb1kyQROV3jq3Mlr14XNZYcoihBFUXY76Tn/BBrls3O1JVZcsYLxjz1vYU7WXPm7zeEC\nWAEMmMBupBZQWmFTfRftEdIHzuW+bvsQDE/Y++eiHZiTNRfnawvkPQ3ORqD/f6C/NhvFTDae3/4y\nTlXlwSV6e52cu2EQRAE2lw3Pbftf/Ethhl+JbCrYitlZc1FcXypvq7XXYXbWXHyduwbrz23G89tf\nxrnafFWvkVXUhcVRj+e2vYxPjn1xycvflqw7twmzs+aqpsWosFZhTtZcrD2zAavPrMfz219GcX2p\n3PkAvB0RQRRQbavBc9texoqTX1+SMjtd4Xm3vjuzHrOz5qLa5s1yLG0w49lt87Dh/I/4Kvc7PLft\nZZQ1VgSwtgRcbCjD89tfxjd5a8NSxvbKFSsYFTYpOOtRe5uDB8MK6skC24CGRvVDK7oMEGwRYE31\nAMIbnGsOvlm1+y4ekj97gre6LqUoZrIBAEfKj4MXlI2kt1dd655gMa/6TDiLfMk4XHZM/lzWWAEA\n2FmyF//J/wkAcKIy16cuPL1qHuVW6fjsihOXqrhh5WRlnvzZM1XGlgvbsPXCdgBAXvVZVUdCfi4g\nyPNR7XdPQxNuHGGyMDycrc33fq45DwDYkP8TdpbsAwAU1F3QTBYQIOCc+7fbi3eHtYztjStWMDxY\n3Wa05JIS2jR+AQBjh/RUfTewRojWGDB6J6B3tBvB8A1dWF1ey0hpNovuxAAWjCoQyroVhxcFzTHj\nVrsLLl798giiiCWrj+GnA+17/iqboi60rE8GjE8j6XXvNbXKYqHZglc+24eKWv+smUZb+8ie80X5\nXJg4/0k6GTBwKcRTaXkyl3isTyCXlCiK+GDtcWzcW9iq8yvrItYQ47efYdR14UFyz12d456uSMHI\nyfemflocVjh4J/KteQArgGvD+AUATBl+neq7iTNBaJQeLjbCIgXbFdTYa+XJ7wRRwKGybFUansVR\nj1NVedCiuL4UedVncaTsWLN8o5W1NjTa1eVQvgzKxtAFyXfr+zJ4PgsiD9+XodHmxJPvbsPSH9S9\nbEujE4fyynEoT3v2zypbNU5U5gKQYk4HzUdhc3nTemvtloBWTKGlCKerzyK7vPU9e1VdaAR/pbrw\nbmfg9ds3NeXK+2uOodBcj9Xbzqm27zlxEU+9tx0HcsrkbRXWSvnv7hRcOGg+qvKj19hrcbpafR4P\nBXUXcLr6LI5XnGrqNkPCynufRa26YBnfjoR2XRRcbPkCYubGcvnv7uAdOGg+KgflRVEEY7CCja7W\nFIzztQXINufhUOkJfPtz6yxgm+od8RcG346E5/4FgW8z8SxtMOOM27qxuew4aD6qsnYrrJU4X6st\njGdqziOv+ixyqk63SVlC4YoUjL9+uFP+XN3YgG/z1uKQYyNYU2ObWxg6nx5plF4hGJEWPwvj9f2L\n8FH25yipv4jdJfux7PiX+PLU/8n7Fx78AEuOfIpCi/8aBQv2vYtFhz/GJ8e/aNbU3BfK6jUsDG/D\n4BkBr4QFq3oZPI1XlcXmZ4afKZZcWgdy1cJQ4x7TESiO8/c9b+GDo8tQbavBz0U7sPzESvzf6e/l\n/Qv2LcSiw0s1p5t+Y/9ivHf4Y3x8bEWr13NQ1YVWIwlWJZ4WmycpQPBz9YmiiHU7z+NMUS3sTqme\nDDr1M/LzYcl1s+Wgt9wv734DS458ikanFRvP/4TlJ1Zi3blN3v27Xsd7hz9Crc9KjqIo4s0D/8J7\nhz/Gh9mftXpqbmUjqVUXDMOqGk9PxpggqJ+Jv3++v8Vl+Meet7Do8FLwAo81ZzZg+YmV2FzwMwDA\nxQswDciCsc9eNDjUlptLcOHtg+9j6cllMN54END5z+dmd/BY+WOearxRXmE11mzzf58aVe+Iv0XI\nMqymq9IpuNrMvvjn3nfw7qEPAQBf567B8hMrsU3h5np59xt4++ASv6SLRmcj3j30IRYd/hj/OvLJ\nJZuq6IoUDGXPo87WgNOKxrUtB+0B0kMyb9hz8vdoYyREqyQYTKRFbixdvIB6qxMWh7QiX3Fthezz\n9UwxDQBlbp94sCVeq+yhTzui0/k/vqH1qr0vg+eBq6hthM2pfnnyA/QmaxuklzKQW85z3TqHBWer\npSC8Mhhf75QGgzU41Q2D78tR66jTPH+oWF02rNych882nNJMqfTtVTfYbe5y8H6B1+LyBqzZfh4L\nvjwoZ8gZ9OrXyKCXBETLB2912WTfuXKRJ08jbePVSRYOn/L61lVzaQzyXPi6Kj3PhUt0+fWqeaF1\nMQYbb8cZ97vhSUxQdj58BUP5TAPeTEUPdQ0OZO7Jx5aDRVj0f0fl7c8u2oZ1u8/7XT+QFe5h+9FS\nVV14EiPsgqPNJ/t08k7kVkuWwoYjx5F1RD01v83n3n3rwsE7UFzRgL98vBtF5fVtWjYlV6RgKLHY\nG1UpkW2ZUushOSpJ/hylj8DYW24ECx3YCIvcaHy49jhmLdouH/fxumMoqZL+cKFYPXaXuofg8R2H\ngsMp+LlUg70MDMOAV7wMckPFiKi1qh/G/FJJMDrFqn3eNfVSmW0OHrwgYM/Ji35xDkBaCfHwGck9\no5V14tuI+84EHGg+r1BpdNmw5VARtmeXajaSosioc+vdDREvCigqV4ulskHzCKVRry6fQSf97Rwa\nWT523i5fS6su/BsGdaNpaXTg2LlKv99poZUKHNzCYFS9aqf7byNZoL6uytDHdmhhddnkZ9Mz3kHp\n4m1wNt1IesZCAcCR0xV45l87sHmfJMLFFT4j0xn/uvB07gDtusg+VwGHhnVudzlk67qtsPI2ud7r\n6l1YsTEXh097rUnfe7f5vCN23oGVm3NRVm3FF5ty27RsSq5IwdD18FaIxdaA+kbvw6CVN92WROhM\neHDUDegW1RVMRD1+OVqEilorjpaehq6HNzbBsALqGqU/slag1bdhqKhXP+AsGDgFFz47ugqf/LRb\nsyH2YLW7/F4Ia5CGgQULp6KR9BzPRtXhZIXXJ7rmTCYuOiWrwGRQN3A19W4Lw8Fj874LWPrDScx8\n6xdsOH4Qmec2y8eVVNXJ5dOaasX3ZWj0szh4OHgnvsldi4sNZWgKrdiPxe6tW6tTYyyCg1fHMNwD\nQE9V5aHI6vUfr8r7AWdqvdas51K+o5E9FobTKeBkZS425W/1Xt9lky27UOrC9zn5dMMxvLvqID49\n8g0qrE0Lh1bA1mPVSfv9rS1RFDU7GEfKjyO/zlsX+tSTOFUh1cWBnDKV++1gbjl+3O+fCHGs4iR+\nKsySv1tdVjhd0rUEQRIjq8N77UZno+r3vnUBlkdlfR1e+XEZ1uyR4oYeq473HajKagmGtzOgaYWz\nAup8Ok8AsN98CHkV3vv7Kuc7Vd2UVjbgi025ftZpWXUj/r0pF1a7C4fKsvHLBa9r/bON2bB6LHtR\nqot/rT4i728M8o44eAecsELf8zhELnxLL1yRU4Pou3nNy3qnDQLPyHdivASCAQBdIrqgqKEYjN6O\n1VnnYOyzV30gy0trjBu8vSdlY2b1cT1U1qt7snbeiUMlOThQuQ+uslSkHb0Gowf10CyTJ0NMtY23\nSWZuYR2+2HUSSPH/DR+h3UPcXu5t4H4qzAK6AigY5zdI0WNhNNpdqpl7M8u+UR13zlwll88TE1L2\nfn0byYp6tUm9YvNJpPevwbayXTDpjJiUdo9muStrrfIU90osTgukIA+Depv/y9RodyEmUtuttqdq\nm/z556Id7k/jVMf4irnn71xRa8P7R5ep9lldVtnC4FgOlbU2REcxiv0+DYNfQ2EHl1CPw1XH0MPc\nBeN6jtEsNwAsyzwORKu3VVqrpMCyjxvOQ1b2BQy6qZPm+Tac/1H+rOtaiBVnPsOGLVNRaJb+XqMH\ndQfDMHh/jZTGPGZwD7Cs994+yv5cdT6ry4Z6uwNggNOFdUB/oMZqUe1X3buPtcWwPD7ZsQXlhlw4\nXAyAnprllg72FwzlvGmaU6cwAsw12nNo5dgOyp93lOzFjpK9eHXofPyw4zx+OSKlK3dPjFK9s59t\nyEHuhRpwLINd3Jeq82UXmGG8yQWGAURB6oAwesXAZJ9793Vd2nk7GkwF0MUXwVbnfdkdTh42B4/Y\nqLZZI+iKtDCUNDqsqh68oY1jGL6Y9JJbJlLvXtWP5REVoSFSnEsSDHhdD0pXi+/LUNWgbiR3nSzC\nsi173OdyykG8eqsTe0+a1eJjd4Jh/XvWFxvLcPR0pWYjeaGsXrMH2hS+PnmlWX70bODe7skLZsBd\nPo7hsOt4KaoavHEJ37ow16on/attbMSmbClbqqC8StOKKDRb8Ng/NuOrrf7muAAejEnqrTbY/evi\nXGk1LBo9yVBxOH3E2uHJOPMvp2RhSI2Tpd6F5z/chR8Pn1HtBwBzdSP2nTLLaeMeGJYHEyk1qr4N\nqBKeF7A/76Lf9kaXVR6Xo2V55ptrYXWEHkD1iAXgH8uqD5JabHVZIbpnZ6ixSGWptgUWDD8Lg+NR\naZcsTkbX9LUYDcGoddTJFpfm1CmsgBP5FU2eV8mcJTtlsdDC8zyczPef4JPhXN53WHQ3y3rt9sLh\n5HHkXKny57DzDjj0UqdNUCz09q/Vx/DMv3bgfGmdKmuvpVz5guGywil6K6gt18LQwshJSm3SuRWb\n42EyaOT2cy44eG9PEpCCvx58X4Zaq1owTpdUgYmwyOdqtEsP9NIfTuDjH05g9wmpMcgvrcPq7drp\nhcX1pdLvNF4WvV47ttEUdQ0OrPwxT0468AS9g/7O1ii/sJW1Dny6/hS+yjou71fWhdXuwsqtJ9Un\n4FzgoqX6OVFYhiNn/F/i00VSI7gtW3sdb09dNmg0hgVldfjpYMtz+n1dD02N/le6pKrqpEYu6+R5\nxX4raurtePXfB/HR9yewJ9fnfjgXWLdgWJ2BRa6i1ia71nzxBJg1G0lGQKWlZTPT+sY0ausd+H8/\n5uGERgMJSHUhwlvGaosdtQrBsPE2rN1+Tvbl1/uJpwt2TmokGc7/XmwOlzdBJkBdlNRL75FTa34p\nRoCIlo+z8nWLJcRInc3Sykb/g5Xld78rSgtD+Y5k7i7AtuPq5zWvuAINkOq5rM6CI6eld+TEeWnb\nu98exQdrj6O1hN0ltW3bNixYsACiKOL+++/HzJkzVfsPHDiABQsWIDc3F++++y7uvvvuZp3/jOMw\nlIO723KmWi0i3BaG0T3oSZ9yDj/VHQcbpT5Ol3Ieojt4esFSjKd/fglDEofJ+60uG+oaHWAZBrwg\nYt3eMzBerzgB620YuE5l2C0sR3XW7TgbfxCmgQJOFMUhB1uxP6cM+l7aZf3i1LcwGDqBYRP99rkE\nV7MsDOMtv8B+dCS2HCxCgXgQxdwRONleAG4AIMLQex+E+ni4im70+62u+xk5o6VKLILp1hKcr0sD\nIjx14W0IDuSUQWDULy/D8oBJskh0XUqxrPAd3G8aj435W8AyLP46dDZ2NayDvpcDjF5bxIzXHwFf\nlwCro7f/TlaA3eVEqF0NY7/tsB8bIZWnRy726jeh9KfBEC6m4flpt6C002bohAS4iq/3++33ZzfI\nAcs6w3mYBhfAXuP9A1pdNhw7Vymvs5JbXA50UZaVBxshieeu0n3YZz6EKddlIPPcZhg4A/42dA6W\nHf8SUboEGK7L0Sz/h9mfoW/n3ugWlaxZF1WW0DOxjH13wX5Cmrn1q1NrcarhMLiuN4I398T6PeeQ\nzfyAnTu6YX6nh/1++395P4CH2wrvWohX9v8DaRE3y/uPnr+Ig6WRAIDlfxmN/Xm+4snDZagFA0CX\nXAAu8QKcF26EvsdpiC499uddj12N30N3DQuuk1mz/IsOf4wBXfrByET67WMYIaDQaGHovReOHOkd\n1/c6hu/rNsFUPAUjut8GO+9AbsT34JK6gS/7L//f9vI25vpu56HrWgi+3OvO2pNbhBuib0ZCjBFl\nNVY/gVy39wz0PS3u35/D0vx34Np+A0yD8yA6TKg/fgdMhtZ3psNqYQiCgPnz52PZsmVYv349MjMz\ncfasOh+6W7dueP311zFhwoRWX++G+DQMS7m11efRwp43CC7zNegaIWVMeSwNrpMZbJR/2qlvyp9L\n5LGr2Ju73ui04pnFO/Dy8n3IL63zewAYnROMydvTY1gRObb9gN4ORu9Efv15HCw7CrZTKXRd1Oap\nEoe+SvOhd/Au2TUi8sEfA9Zok3tBhY35ACOC6+R2eXAucLHV0Hc7j+t7xPn91rcuGFaAI8abXqvs\nPVXX2wHORzD0dvA6b12IjID/nP8J9c4G1DksyK0+g4uu89AlFoOLD+xC4GKrNd0tyoYhpLqIaPAG\n8eMqAUZEfmMeThfV4qKlBi5jFfTdz/pllQH+2S0MJ4CP8/YWyyx1qKhRWFw+bifWaAVj8J7DJbiw\n4fyPsqsppyoPOdWncbB8H9jowOnIJypzAvjtRdS6kzVEPnh2GhtVB88goOOVORAhgkuQGucD5/Kl\n/d1y8NyHO/x+6+uHd8GJHIuiF+zzHJRb1O8ZG1EPRqdIVuAE6LufAaNzgTVZsWLndhTWF0Kfkg/W\nGFgEj1Qcw84TGpYpKwKMdG+iEDyNlout9n6OL4MIUR54WmQpgUNXC0PPU2prwlN2nc/7z/HgOnvd\nWycKzXj2/Z1wuni4XALg44JjI+tUbmmG4+WOGhvRACbCgi5xEUHvIRhhFYzs7Gykpqaie/fu0Ov1\nyMjIwJYtW1THdOvWDTfccEObjJx8etAf0bezfw+3LRBqkuAs6CtnwHgEozkweu8fuapBevirLXZY\nGp1+DxEbVecXl1D+vhL5Aa/j55Zj/R9QycJwC4ZDepBEgYVgjfI7Vr6+u4zy/6ZGgOFVYhdhCvhz\nFSKnMLcVDcfFqkY/8WSj/RcyanB5zfpD5qN+++XruNQWp+8YEwAAI4DxNAzuuhDsJgj2Jm7GU0Z3\n3TKR9QAEfLjOm9kSGer7qXj59+YUyVONdIo1wi74pBhH+4/PaVBkEy3d/qPffg++dRHIDVNRJ51P\ndEj3LzTEQHQ10Tt1p7d6FhWTLGNRTk8GoNlIaqGMRSifA0EQYfcRGK26UL4jus6BO1IRnM/fVsuS\nUFoYLkn8+drOclBaGxFRJp2cgexx/ymTPIb01U4q8Lu84l489feXj/fgTEmtxjvSdF2wkRYkxof4\ncjZBWAXDbDYjJcUbse/atSvKylofeLmceATD0ALBUFLeUA02tgKMwYq9F05IvVYFWo2kEq5T4Hrk\nGE41sEjrYSqutOBAgWTteRoGhhUAIXCvkjFYwRisgLs3xDAAl1AGNtrbs7JFNX9UdqW1CvNXb8Ci\n7/di/4WT4CLUPt5gdXG0iYkBRd4nFVjQqDfOG0iW60LnbLoujFZAb5N7hgwrgI2rQJldEfSMD9xg\nBTyvqRF7L5wAY7QisUcDGGPz6qKp50J0qZ/Z0gZ/N01kBMC4XV6eqfzBuZq0Njx14WnUGJ0LbEwV\n2Cjvc+exOpoDY2oAG1sB6G34n6XfoYFXW0zB6yLwNRle3XhKk4mqiYpSbHd5ljVwAU3UBWu0Ys70\n3rJ1VOuoQ2FdkSrtNqVn8wfWsRH1YGMrUG2rhYUtAWPwFc+m64KNq4CDa/l0Lh7C6vC/lHPF63h/\nH2RbMvvBW3DsXCW6dZauY9SYuK051KMKxt5SQOocAF2IvVHRqYco6Jo0sR28EzpGB6coPbRKU9mD\nrvNFeLZ6GkmgaTeEsY80i6eyh2W4Tt27v2DcieZSaCkC4otwEYBBI8QQCMFuAsMKquBgMCysf+aQ\nrou3kRed0t+V4XiITQiGqa80fYPSVWG88ZDqmMq4fSGXywMXUwOu9wEAQCEAXXTTx3sQbJFgdA4/\n14b6IHX/UDlbqwdnfL43JOhpJHUur3hoYOrn/zc33qSeOsTQq/lzgnFxVeDimr9ksGiNhj7Kpu1y\nc+OwsYBCM7Tcd474s3KPWnAYwEUC0DkhChwYaGdlGW/ZhrePbVNNKfPGgcWqY368+J9Qb0WGi69o\n0t0aiO7RKSiuvwhd54vIRyZ8U8KbS1gFIzk5GSUl3pfRbDYjKSmpiV+0DMFuwi2GSUhM9J9xsq0Y\nnRiD0cN6yt+TXP6+eiXOwhsh2KJgvEFqRETROwW57sKtiO3sRKUuR+WPBgDHuZthuDZwNoNgjQFn\nvtSov7cAACAASURBVAlOUxn013gH2Ik8B74mCbrOpRDAQ+ARsv2oFIymetUeGFaA0BArWRz6ptMZ\nAcBR0Bs6PhrstVJDqKyLJ2/9A1Zs2QNLVK5PaiQDR/5NMPQ86X9CNzpnHKyFqTAm1AIp3nRa0aUH\n6juDib8ITic0Z6VWVcNo0hkRTIoYVgRviQcXaVG7YLTOLQLOwt4Q7ZHyc+EeHgJRBBynByE6wQpH\npzxV3EfkWbiKr4f+vwKP4BUaYjGwy63o05fFN8d/kLcnmOJQYdaBi6tE5zgjqp2h9269FoYTsIXW\nIePrOoGNqZZdfAHPLbBwXrgRoj3C7x0xsHpYcm4GG1EPXfezqmwv0aWD62JP6HtImYEMGL9VJjl7\nAl4YOxmL1/2M+jjvuyTYIgBeDzaqDnanCF+vVJPlVVgYTYmnqhwNSeCjgntUInQmTOs/CZ0i4vH2\nzo+l67nrIkofiT8OmY4SixnfHFunulfRYYSrops8Lk0UWL/MuL7J1+PhpEnIyjmF265XT6TaEsLq\nkurXrx8KCwtRXFwMh8OBzMxMjBkTeKBRSy0SobYLOhvjUV5uuWT/rPWKF9qp9u/ytZ1xW9KvINR4\nxVGwJLgPZlB/sTMcRb3AOvy7jwZLapM+UrExBjd06olZI+4DX9VV3j6g8wCIivhDc6rSxHobA18X\nTsByOI3ISNUeQKc6t7UbeHNPRNq7y9s8daFnDDh+gEV5bg+IPg0SyzB444HfeK+ncT8pkcmIZ1Ng\nvdALfI03lchV3h2dTPEAgGYnhijcNjGmEM0+pxGdGwcGPYyvSgZv7ql6LroaesjXFWqSMOPWCUiO\n7uLzQz1cZdd4v2uIumiNwU0JaUhPGo4Ywftc3J48RLaaWK5575fcSDJo0tpS/cZuguvCDUGP4ytT\nwJtTVXXRKyYVABBtiMZN8b0xrtdoxBli1ed3mFR1EaX3FzK9Mw7dddfgGnEghAbv7/mya7wrZmqk\nmjd5Xwr3XCidKgBwNJjgLAreSA9NHoTB8YPRy5gmb/O8I9G6aKSZrseIxOGINqjji4ItEryiLvSi\nvwJ24rqgp+la/G5ABm6Man18N6yCwXEc5s2bhxkzZmD8+PHIyMhAWloaFi9ejJ9/lmanPHbsGEaO\nHImNGzfi5ZdfbnG2VJe41gd0moMy6M3XqV/waKMRv79H7VsRLFKgSyeaIIoMymtsMDD+DRLLMPID\nyYoaq301xiA6Qo8eSdEqYYmLiAr5QfYlWq8QrhDP0aNTPPp3D5DPq8CzJnqUydtyJ+ok8eAEEyxW\nqQ+v1WvrFBPhnRvMpbFfn4iYCPd2RV2kxMXhum5SfWtNW90UUTrvS9k1LjSLVXTp0SMmJehxqUnx\nuHvINZh4R09527VxUh167j8hxoR4k8Z1FX8XTlC4EN3TSAiNMeifJj2HKZ28jWSEzoTru0uNT3PW\nslaWCUCTfnsVvB6CNQQ/msZz1ruL1LjGGmLw7EMDcV96GjpHxvodp/xtjMF7LUaQnhUjL93vrwf3\nQJTB6zoWeZ13UFwz0mUBAO6gP8MKIYun4OL8OkJaaK7q524vYo3eZyFa75uQwqjKYlLk9nv+dt2j\ngz+XzSHsA/fS09OxadMmbN68WR6DMWvWLIwaNQqAZIVkZWXh8OHD2LNnD9atW9fsa/CVKUiMb33K\nWHNQCUZliqoH7OKlqRfmPupN8R3V93p0jUhCDOMVlwifwRsjuqbD4RLkjBQjEwlnaU8AwJCkWyEK\nDIT6BMRGGRAdofc+/ACiDRFBsjcAoV7bjZZolHp5fHViyIJxXUoXJEcGdy96Jg5U+nRv6poK0R4B\nV0M0KmvdKZxOdUxo4rWSrzVSF+neb8CAToOk+yi/BqLAontUD0SapJdNVNTFqFt6IjFW+h0fQDDi\nWe2yj+s7AADgqkhBbERoz5TI69A3uWfQ43p1jcNDY67H5BHXytuu75wK0WGE0Cg1jF3iTH695ju6\njgDAyFlOejECrkqpIeDLu0PkOSQZU+SBYUadV5xNughc00X6uwcadxPPddHcLtRK22+JH4S+qf5j\nebQQeZ2qVx/wOI1nNS2uJ6L1UeihaOSUsULe5kL5zxYfwYjBzZ1vAgBwlu4QXTpEuweu9E5NwH8l\nSfd+7suj4K2C97qMdl0IARr4G+Ikq8lZ2rN54lmfEPQwrcHGQl0niC49ronpJm9jfSYkdZlTAV7x\nt2Yj0TO6p3TpmkQwvEF7rE0ruKJHeg/pOgjWQ6MhWDpfBgvD+yALNYmwHR4tN8icTlKPa7t5X5zr\nuyXghSFPYWrqA/K2CL23QbIeGo1f9xgNp0uQXUscy8B14Ub0a5yGaTfeB9uRURBtUeiRGC1NeKd4\n6WKMkSoB0cKeeytsx37lF9hOik6A9eAYOM4MDCn3HpB6rnpOj7dG/N1vn+PczZK/GN6pv5Vz2XSK\nikavugxYTvVFbqE7k8Z9LybOhNeH/y9+/V8jAQBdIjoDkHrnM/o/gLdGvILOllthO3InkqO7yIKh\nrIsInSnoiP+HUh/FX4fOhu8MrCP7XIdOhRMx7YapoSc28Dpc0zkefx86129XVNlQGFnpPFqLeyVG\nxeKGhkkYl5KBJc+MQIRRJzcMXSI647Xh8/DwgLsAeEVVzxjhPNcP1oNj4Mzvg/8d+iJenj5cPqey\nxxqpM8lWmmYaLYCpPR7Bi0Nm+W0fe8tNGGB7GI8PeACdogKnW6vgdYDLiGHib/12/bHf7+TPCdH+\nYhxnjMXLt72AqTdM8tvXPbI7os6MgL40H+/NGiH3tmP0Ufjvfo/gzeEvgynuB9vRkRjRN1X+nefe\nr/3tLWC5qKAWxq3Mfegv+ns5pgwdAOvBMXBduDF095xLB9ERAduRkX777DlD5M+cxrLSoiMCMQV3\na86bdnPnm/Dszc9BqO6qeuc5wYTZt85Er4oH4Mzvi8TSe2DStS45x5crcvJBDxzLyq6K+Ji2rZhg\nqNNqGcQao5GSHI/C+lokd/YvC8fqYNKZ0LdnEgApkBtl1EOeecBlQKRRauQ8QWi7aAXAIMoQCaNe\nJ99rjyS3Ga54WCL1EcGtA14H0RorNa6KoGqXmEhvT0UI7ZHwTMIYqfd/8QVrtFyW7okm9L61BzJu\n74m/uudnjDQacH23KJw6Xw8R0pQJid1icUGQLBGlmyHW/bnR1QiO5RDJRqJrQhRKKqyIidAj0ugu\nr+i99widSTVOQ4teyQmIjtCDY1iVFWIy6DH/Manx/f5saFMpiLwe8TFGxEb6u82eGT8ciw+fhF2w\nq6aTl6+nM+GZ+wMPNlUuHSo6jUBEAwTWLv3t3YMMk+PiVb9RTvFvUghGIPdccnwskqP9e9YPjvb6\n30NNI/dYxz27dMFen+nFukWnyEHqO25OwVqfBQQjdCbN5wkA9DoO0ZX7UV5Wimee+j2c3VmwPU3I\n/PJbXEg9hTNn8vC/ry3Fq/94Cf/O/RKfOhx44IFp0PWU7v3Uwl1IHdUfMXE8sv+1B7GpXWAprIQ+\n1oieD/cH656S/vF7+sPBO/H40s9gzsqHyIvgInWIfNMhWQwuB4o374Wt6iLAMEi+sxfi+iSi7nQl\nLv50DqIoQhepR9pjA1F+6AAYZz5uHT4eJQByl+xFr9/eAkBE1c8rwZ8R0XihDqNe+hXefvt15Oae\nxPnKAsT1TUJc3Gjc0ee/cDo3F4sX///2zjwgynJ7/J+ZYdgZkE1kEVEUccEdUMktrpgrXEWvZup1\nrVxyqUS+Zd+ytG96vdXtdjXNTLMsb9mvm7bp1dJETZOstMUV0QABkX0GmOf3xzADA4MMCiLwfP5i\n3vV5D+/7nOc85zzn/I2iomIydVn4P9SVr/7+EYOejDTJ5vzmU/iO6YitnwIbpQ2d/Dw5cyGP0IDW\nFiR5ZzRthaFQsmZeJEXaUlMd5rtF1YV7M0Z05nBBeSoGC4kAjVla1TZKEh/qQ+r1fHKdfuZcpSzQ\nduU5qYwjyTJRitpGiZ+n+eiuTXlo76Du/hy5blgx7WBT+5SUh8aBrNxibO2gcr48D03FR9on2IfT\nxbWXfDQqDIuUVkyX6SljSrS5E9TTxQkP/4pOblAPX0q8s7iSWt3sdrUzWGk6fUW8Ukd/N36+mI2P\nh2MNFoZDrXVRnMsTRtY0ZQV1WJxZamO6XlVcbJ1NI35LU0KWLCFjidiq77TQGd4LPbdeBFe5/oqj\njQM25fewVB8DwENjX+vC2arWVklKCGXZ1ac7nNVO4KjiP99eorjYfGT94i9nKSoeBMAXP1W00TgC\n/0yfxoPR5lM4xvdBqVDwyCMLuXTpAlu27OC15M18d/IYWZczmLd6AT4+hrb8428v4eLiglarZc6c\naQxYEGO4kALG39eZayWlnMguos3kQHzHBnPpg5+4eSaDVmEVz2KjVOEU6EbHuQZFnnXyGh/9+11a\nufTlt+N7UTnbEzLfkAKkrLiU0gIdqZ/8QvCsPti62VNWntbF+E56t3LkWnkbjGRl/EH7sb3wHx2C\nq4cb8+bNx8XFhUf3PcH5raeYMdONmHB/pj44gVWr/o+QkM68cHgdV4vT6TiwG/v37SXhwVno8q+z\n/AclDq2daeVm6D9G929HgJczYcEeNf9Db5MmrjBUtG7VsOsvakKpUBLh04fDxw2LYezUSpNSqGz6\nz+gymYOp3xLSqiKvULCfK8F+ruSXuHH2xq/8/p3hZVUplSyOD+M/R+2xdSlhVPtoggZ2MCmSByLa\noivRm+ovuDk7QnmNFQcbe/Q3PdHnu1Ka1g6hszdLuT6gTTgTB0VSWqbn8W8/N20vve6Lb1jFiN5W\neetO0sfRG7VKTadWFaPPv4T8mTNZv9LFI4Sfs85yXOtoUl6VE9xNDhnP0T9O0MmzLcK94uvx93Im\nOHAoF3IvE99xrNn9YtoN43LuFTPTfHi/AAb39MXBzgbHcme6qDIl1ce7J99eO84D7e7HwcaeV069\nYdpvn1sRjVKZKN8Is9+1TUn5OvmgVKiY8ZcHTJ37hI5juXDzEu1d23E+5yLOaieLU0ITO8WSfP0n\nPByqz3GP6zCC9MIMpnSeYNo2d0wXvjljh975JJqcXmTXsA4AzKtO2tvYE+UXwanrpxnXYSR//yDZ\nUN60nJI/2pkWo1ZmqH+U2W/7WmRhTLBpo7JBWa7wHNQOlOnLsFGqKNWXoUCBQqFACPNAWDuVPXpR\natH5G99pHFt+3sGDnSdAXoXC+0tIHJfPXsC+c6hJWQB88MG7HDpkqLmRkZFBQWaeYb2FgB5BrelU\npOFD97dZ/KcFFJYU8sKh59HdMPjRHihPFa9UKNHdLOba++cozdfiqHTgYuB5fPsNJjnzdwL6RWH8\n8FT2NuT+molzOzdsy1dSt/MKRKlQkVpyBYUSfNwdKLnQAVHyPX1ahYO6mJutU3D0MwyGSvQl7N//\nBZ988jE5RTmU5WhxFHmkXrmMp6cXISGGAJq/9pzK9rPvMzU+nmUPL2D+/MVs2vQO8eMmkuacR2z5\nN6JUKujVyTqfU11pkgrDaNZWHY3ebaZ1mcTBTwy1I1QqpWkkV3nBUD+fXvTzsRxy6ax24vG+85nz\n34OmbWEdPMujXQZUOz5+qHmIXuVRtIONPQvjelGk7c7mMwZbf1HPubyabOgoHww1dD5qmwqZ6S52\npex6ABonW8I6eHD6fJbBF3KL2Zy2Gn+md/mL2bb7/CK5zy/S9Pfxvf81WRiVZRHlF0GUX3mnrDQ4\neDNvFhPg7YSrnSNP9l1Y7X4aWxce77vAbJtSqcChfCpKbSxeJMwVhqPagYR+j5m2ze72EJt/2g6A\nT1F4tfvM6DK52v/pllYU0N41kMmdx5ttGxoQxdCAKNPfQCULo0IWg/0HMNi/+v8YoJW9WzWfQmRX\nHyK7+gAD2fbFr4Ah99Hi+LBq59tUeS9cbJ3L/TWgv3kV3cUu2AYZpkVn9x5f7fy53afRw6ub2Tb7\nKrJQt/3VbF3IsID7GN+x9gjHxMOruKnLY0CbcPb/P0Mk0HMzBtSY58jb0dP0f0zLq1g57+ngwZTO\n49mZXFFX4tSpk3z//QneeGMrtra2LFw4j8rGmL2NA2pbG3w0renUyjBoCPUM4ec0w/cyun2M6dir\ne37De2BbEiclUHw5j7fe2kRg6/LpwSrTtlVDvrt6hjKmfQzjNvwPQgjUNipKr3ZEX6xioO9AvFzt\nSXLYh1KhRC/0ZKZfZ+/OXbz55nacnJxZvfpZdDpttev6OvuwvFwW/fpFcOjQQQ4c2MfmzdtxcWm4\nNWiVaZpO7/LBaWMrjMooFFQaSVqfNlylvP1nqNox9OroxYBubVgzL5LnZlbvFKuiLzS8ZPa2Khb8\nuTuLJoTRJeDWkU9WO4ItWBhVeWp6X5ZP6YV3fViJVaakqu2uNPVkmsaqhKXww9oUhrWysLmN9+JW\nGG0zJ3sbUyhtZdQK8/ei+gUqeqLw0Orz3LcnC+um7yrLomuQQWG0stL/6OjoSGFhzaOZgoJ8XFxc\nsLW15fLlS/z8808mqx/AVmmMqKvcE1tem6LXlmGjscPPuQ2fffYpAKP6B9ItrA/ZP1WsWi8rKsEp\nQEP+pRx0OYbsC/oiw/95YJ/OaG9epXt7d4pvplJSdANHu4piakZZFBYW4uDggKOjE9nZWRw9egSA\nwMB2ZGVl8ssvZ03H6cvrqI8ePY6XX15HaGjXu6YsoIlaGEoUlCEslrhsLOzUKtztDPPybna3XgVe\nlRHhbU3RRHXB3MKo6CSN03QXbt46Umh8eA/Ss3SmKa6ewZ5cyr21s9jGQrnZqswaFcpnqRe4QSau\ndjW/zBpHWzRt66cSWGULw1JkiLFzF0JhSA5XhdaO1U14S4qnMtY6gr0cPEjJSzVz5t8Jrs6G+3q3\nstw+o6WrVqqrTfMsndSDXT/kkImluH4D7vbVp8nqS3l6O3qRVXwDZ1tHpsX3oEwvrB40aTSudO/e\ng+nT/0JExAD69x9otj8iYgAff/whM2ZMoW3bQLp1627qIxQKpclPU9lfY5yCrfr/bz20HZd3/khi\n0jK6dOlGWtofONjZ8NcZs0l8IYFfXzsGSgU+Q4NwDfXCf2xnLr33I0JAgfc1xr0+ikUz47n221Ge\nXj6Pm8Vu2Dp5mSxjhUKBj6M3KXmpBAa1I7tjCA89NAlfXz/CwnoAYGNjw7PPruHvf38JrVaLvb09\nL7/8Ovb29oSEdMbJyYlRo+48y3ddaJIKw7D0lHrJcHunPDWtLz9fzCLA2xlvj6GUCT2D/PvX6RqV\no1HqgrFjUCqU2FpwngZp2jKyXTTdvbqYbV/Ycw452ptEtqk+l19bJ1lioQRqVQZ2b0Ovzg/y5eUD\npmmZhiKgtaET9nXXcB3DXLsly7OLRwhlf3SgJNMHx7AKWT3aYyZFJUUW667X1kmKGpzIVZnYKZZW\n9m6mUOE7JaZfW7S6Mob29rO43+jDsNT+bkEehAaO49OLDkT49DHbN7f7dMpEmUX51fZeqKy09h8K\nnciBK4cZHjgUpVJhVsLVGlauXGX2u1evimdQq9WsW2eet2nvxa84e/EiUStGotG4otG48vbbO037\nl89NYM/FrxhYxX+1dMISbOJVhHl1Ndvu6a7BN3I0duX5voxoOnqg6WhwMj/Y2RA6b+jg/wnAzBcN\nU9e+bQzrKt5+eyfZxTf4JjWJoQH38UBitMXn7dw5lI0b36q2PTPzOkII+vWLtHBWw9E0FcY9RHtf\njWm9hb2NHbHBI+/avdXloycHG8tRLgqFglHtqxek6uxevbCPkdo6Sa0VCgMM4bZ3QxZd27mzYmpv\nrul/4YNz39XYsSkVSsZ3Gsl7V36nf9cKJ2lXj5qzHdbWSWr11iU9dLZ1Ii54lFXHWoOdraqaP6sy\nxiipmtqvUqosxvf3qNI5Vqa290Jn5SpyVzvNXf1GjBaWQ03huiq1xfb09q7uGwJDgMbU+7uy6+oJ\ni/vBfPrTyEsP96ekSu13d/tWtyWLzz/fw6ZN/2LRoqV1PvdOaZIKw5i6+25mw70XsVEZRsrVcvvf\nAbVdS1dmXVnWu0lHfzdupBmmFm7VsUX39WdYHz+rp0BqVZ6l1mfJvZsYpypra39dqH0gca/Kov6/\nkVB/L2PMgUV0FmThWY+ZKEaMGMWIEfU3AKkL947XuA40/kTUvUFDdAxqlWW/h/HDC27V3uL+xsbY\nvqrRPJVRKBR1CjKoKZTU6LsIcg20uL+xsbmLCsMoi7Yu/hb3NzYNIosarBXjtHAb5/pNx3Ev0SQt\nDGNioqppjVsatU093CkrIx6nVJRRUFJAB9cgLuam0P6e7SQN03OO9dgxVPZrPNd/BcVlxRSUFNDe\ntR2Xcq/QwbVdvd2rPmmITrKyU/v5AYkUlhZRUFJIkGsgKbmpdHBrV2/3qk8qZFF/30hla+W5/ivQ\nlmkpKCkgUNOWq/nX7tmBRH3QJBWG0cJo6QrjVs7N+sDN3s0sXDLYrfbstI2FNRbGneBu72bmJ7q3\nZVH/70XlZ29l70YrKlbq36vKAhreCq+68LI5KwtoslNS5S9vy9YXDTJ6qoy1kS/3AkZZODaQLO6F\niDxrMc3bN5AsmhINoTBaMk2nR7BAS7cw3O1b4ah2IMDFcnjl7RLqbsj9dC+tc6kNTwd37G3s6l0W\nHVzbWQxZvpfxcvTEzsaOAGff2g+uA/7Ovrio62ctye2Qn5/P7t3/rtM53o5e2Kls8Xfx5YMP3kOr\nrZ+gDU8HDzzt3evlWk0JhWiCoUbTPlxMcanW6nQEzRl3D0eys2692K6uCCHQC73FtQn3MlIWFTSU\nLBozJc8ff1xj+fIlbNv2fp3OM8oiPn4sb765HY2mbgtrLWFM5FhXWZSVlaFSNd67dKdlrJuoD0M6\nvY00REemUCialHVhRMqigoaShaIRYxQ3bHiNa9euMnPmg/TtG8Gjjy7i3Xe3c+DAV5SUlDJo0BBm\nzpxLcXExK1cmcP16Bnq9noULF3DpUiqZmddZuPBh3NzceOWVf5lde+vWzXz77SF0Oi3duoXxxBOJ\nAFy9msratavJyclBpVKxatWL+Pr68d672/nyy89QKpVERg5k3rz5LFw4jwULlhAS0pmbN3OYPXsa\nu3Z9wmeffcqRI4fR6bQUF2t58cW/kZCwjPz8PEpLS5kz52Giosoz9n72KTt37kCpVNChQ0eWLl3O\n9OmT2bnzI1QqFYWFBeW/dzeK4mmSCqOS11sikTQCH537lFMZP9brNXt5d+fPwaNr3F85vTnAd98d\nJTU1hU2btiGEYPnypfzwQzI5Odl4enrx0ksvA+DgoKBvX8H777/HP/6xEY2mekXA8eMnMWPGbABW\nrVrJkSOHGTAgimeffYpp0/5KVNRgSkpK0Ov1HD16hMOHv2HTpm3Y2tqSl5dXQ4srlOvPP//Itm3v\n4+zsjF6vZ82adTg6OnLzZg7z5hmuf+HCed55Zyv/+tcWNBoNeXl5ODo60rt3H5KSDhMVNZh9+75k\nyJD7G81KaZIKQ1oYEonk+PFjfPfdcWbOfBAhBEVFxaSmphAW1pN//vMVNmx4jf79o4iOvo+iojwM\nI0zLfcbJk8d5993taLXF5OXl0b59B3r27E1m5nXT6F+tNviyTpw4zqhRY7C1NUQQWpP8r1+/CJyd\nDf4fvV7Pxo2vkZx8CqVSQWbmdW7cyObUqRMMGXK/SaEZrzt69DjefXc7UVGD2bv3PyxfXr2y492i\nSSoMiUTSuPw5ePQtrYG7gRCChx6awdixcdX2vfnmOyQlfcvGja/x228/Eh//UI3X0el0rF//Elu2\nvIOnpxdbtryBTqejJuVicPtWn5pTqVSm/GKG8ytwqFQf/quvPicnJ4e33tqBUqkkPn4sWq2uxswV\n3bv3IC3t/0hO/h69Xk9QUOMtnm2SUVJyRkoiaXlUTW8eERHJnj2fUFRkSCtuGKnfIDMzEzs7O4YP\nH8HkyVM5c+ZM+flOFBQUVLuuTqdDoTBkwy0sLOTgwf2m4729W3Po0EEASkpK0GqLCQ833FerNRRe\nys3NBaBNGz9++cVwrwMH9tX4HPn5+bRq5Y5SqeT770+Qlmao89GnTzgHDuwjN/em2XUBYmJG8r//\n+z+MGjXW4jXvFk3Swujg3o7T6WfxtJCGWSKRNE+qpjd/9NFFXLp0iYcf/itgUChPP72K1NQr/POf\nr6BUKrCxUfPCC4YMt2PHxvL444vw9PQyc3o7OzszZkwc06ZNok0bX0JDK5IwPvXUs6xdu5rNmzei\nVqtZtepFIiL6c+7cb8yaNQ1bWzWRkQOZO/dRJk9+kKefXsEXX3xGnz79anyO4cNHsHz5UubMmUZw\ncAiBgYZFoEFB7Zk2bSYLFsxFpVLRsWMIiYnPlJ/zAJs3byA6unoy0btJkwyrzdXm89WZIwzw7Wex\nrGNLwsvLhevXa3K6tSykLCqQsqigOcjiwIF9fPvtIZ566tk7uk6LDKvV2DnXueaERCKRNEVefnkt\nR48msW7dK43dlKapMCQSiaSlsHjxE43dBBNN0uktkUgkkruPVBgSiUQisQqpMCQSiURiFQ2uML75\n5htGjBhBTEwMb7zxRrX9Op2OJUuWMHz4cCZNmsS1a9caukkSiUQiuQ0aVGHo9XpWrVrFm2++yaef\nfsqePXs4f/682TH//ve/cXV15csvv2T69OmsXbu2IZskkUgkktukQRXG6dOnCQwMxM/PD7VazahR\no9i/f7/ZMfv37ycuzrC0PyYmhqSkpIZskkQikUhukwZVGOnp6bRp08b0u3Xr1mRkZJgdk5GRgY+P\noWi6SqVCo9GQk5PTkM2SSCQSyW3QoArDmkXkVY8RQjSpcpgSiUTSUmjQhXs+Pj5mTuz09HS8vb2r\nHZOWlkbr1q0pKysjPz8fV9faK2Ld6RL35oSURQVSFhVIWVQgZVE/NKiF0b17d1JSUrh69So6nY49\ne/Zw//33mx0zdOhQdu/eDcDnn39OZGRkQzZJIpFIJLdJgycf/Oabb3jhhRcQQjBhwgTmzp3Ls8Qs\nGwAACZ5JREFUq6++Svfu3Rk6dCg6nY4nnniCs2fP4ubmxvr16/H392/IJkkkEonkNmiS2WolEolE\ncveRK70lEolEYhVSYUgkEonEKqTCkEgkEolVNDmFUVtuquZGYmIiAwYMYMyYMaZtN2/eZObMmcTE\nxDBr1izy8iqqiT3//PMMHz6ccePGcfbs2cZocoOQlpbGtGnTGDlyJGPGjGHbtm1Ay5SFTqcjPj6e\n2NhYxowZw2uvvQZAamoqEydOJCYmhqVLl1JaWmo6vrnna9Pr9cTFxfHwww8DLVcWw4YNY+zYscTG\nxjJhwgSgnr8R0YQoKysT0dHRIjU1Veh0OjF27Fhx7ty5xm5Wg/Ldd9+JM2fOiNGjR5u2vfTSS+KN\nN94QQgixceNGsXbtWiGEEAcPHhRz5swRQgiRnJws4uPj736DG4iMjAxx5swZIYQQ+fn5Yvjw4eLc\nuXMtUhZCCFFYWCiEEKK0tFTEx8eL5ORk8dhjj4m9e/cKIYRYuXKleO+994QQQuzYsUM888wzQggh\n9uzZIxYvXtwobW5I3nrrLbFs2TIxb948IYRosbIYNmyYyMnJMdtWn99Ik7IwrMlN1dzo27cvGo3G\nbFvl/FtxcXEmGezfv5/Y2FgAevToQV5eHpmZmXe3wQ2El5cXoaGhADg5OdGhQwfS09NbpCwAHBwc\nAMOIubS0FIVCwbFjx4iJiQEMsti3bx/Q/PO1paWl8fXXXxMfH2/advTo0RYpCyEEer3ebFt9fiNN\nSmFYk5uqJZCdnY2npydg6Eizs7MB87xcYJBPenp6o7SxIUlNTeWXX36hR48eZGVltUhZ6PV6YmNj\nGThwIAMHDiQgIACNRoNSafikfXx8TM/b3PO1rV69mieffNKUUujGjRu4urq2SFkoFApmzZrF+PHj\n2bVrF0C9fiNNqqa3kEtGbokl+TS3vFwFBQUsWrSIxMREnJycany+5i4LpVLJxx9/TH5+PvPnz69W\nNgAqnreqLEQzytd28OBBPD09CQ0N5dixY4Dh+ao+c0uQBcDOnTtNSmHmzJkEBQXV6zfSpBSGNbmp\nWgIeHh5kZmbi6enJ9evXcXd3BwwjhLS0NNNxaWlpzUo+paWlLFq0iHHjxhEdHQ20XFkYcXZ2pl+/\nfvzwww/k5uai1+tRKpVmz2uURV3ztTUFvv/+e/773//y9ddfo9VqKSgoYPXq1eTl5bU4WYDBggBw\nd3cnOjqa06dP1+s30qSmpKzJTdUcqToSGDZsGB999BEAu3fvNsng/vvv5+OPPwYgOTkZjUZjMkWb\nA4mJiQQHBzN9+nTTtpYoi+zsbFOkS3FxMUlJSQQHBxMREcHnn38OmMti2LBhzTZf29KlSzl48CD7\n9+9n/fr1REREsG7duhYpi6KiIgoKCgAoLCzk8OHDdOrUqV6/kSaXGsRSbqrmzLJlyzh27Bg5OTl4\nenqycOFCoqOjeeyxx/jjjz/w9fXllVdeMTnGn3vuOQ4dOoSDgwNr1qyha9eujfwE9cPJkyeZOnUq\nnTp1QqFQoFAoWLJkCWFhYSxevLhFyeLXX38lISEBvV6PXq9n5MiRPPLII1y5coWlS5eSm5tLaGgo\na9euRa1Wt5h8bcePH2fLli1s2LChRcriypUrLFiwAIVCQVlZGWPGjGHu3Lnk5OTU2zfS5BSGRCKR\nSBqHJjUlJZFIJJLGQyoMiUQikViFVBgSiUQisQqpMCQSiURiFVJhSCQSicQqpMKQSCQSiVVIhSFp\n0kycOJG4uDhGjRpF165diYuLIy4ujsTExDpfa/bs2Valu16xYgXJycm309w6cebMGb744osGv49E\nYi1yHYakWXD16lUmTJhwy+yjxlQRTYVdu3aRlJTE+vXrG7spEgnQxHJJSSR1ISkpibVr19KzZ0/O\nnDnD/Pnzyc7OZseOHaaCOgkJCYSHhwMwePBgtm7dSlBQEFOmTKFXr16cOnWKjIwMRo8ezeLFiwGY\nMmUKjz76KFFRUTzxxBM4Oztz/vx50tPT6d27N2vWrAEMuXmefPJJbty4QUBAAGVlZQwbNoxJkyaZ\ntTMzM5Nly5Zx48YNAKKiopg9ezavv/46hYWFxMXFERERQUJCAqdOnWL9+vUUFRUBsGjRIgYNGkRK\nSgpTpkxh9OjRnDx5Ep1OxzPPPEPv3r3viqwlLYQ7KdYhkdwrpKamisjISLNtR44cEV26dBE//vij\naVvl4jLnzp0TQ4YMMf0eNGiQuHDhghBCiMmTJ4tly5YJIYTIzc0V4eHhIjU11bTv0KFDQgghHn/8\ncTF16lRRUlIitFqtGDFihDh27JgQQohHHnlEbNq0SQghxJUrV0SvXr3Ezp07q7V98+bNYuXKlabf\nubm5QgghPvjgA7F06VKztsfGxoqsrCwhhBBpaWli0KBBIj8/X1y+fFmEhISIPXv2mJ59yJAhorS0\n1HohSiS1IC0MSbOmffv2dOvWzfT70qVLvPrqq2RkZKBSqcjIyCAnJwc3N7dq5z7wwAMAuLi4EBQU\nREpKCn5+ftWO+9Of/oSNjeFT6tKlCykpKYSHh3Ps2DGef/55APz9/U2WTFV69uzJO++8w7p16+jX\nrx9RUVEWjzt58iSpqanMmjXLlJBSpVJx5coVHB0dcXBwYOTIkQD0798flUrFpUuX6NChg7Xikkhu\niVQYkmaNk5OT2e8lS5bwzDPPMHjwYPR6PWFhYWi1Wovn2tnZmf5WKpWUlZXV6Thr6yz06dOH3bt3\nc+TIET788EM2b97M9u3bqx0nhKBr165s3bq12r6UlJRq2/R6fbOq9SBpfJqOB1AiqQVhRfxGfn6+\nKTvpzp07a1QC9UF4eLgprfTVq1c5fvy4xeNSU1NxdnZm5MiRJCQk8NNPPwGGWhfGNOYAvXv35ty5\nc5w4ccK07fTp06a/i4qK2Lt3L2AoUQoQGBhYvw8ladFIC0PSbLBmNJ2YmMjcuXNp06YNERERuLi4\nWDy/6rVq2ner455++mmWL1/Onj17aN++Pb179za7n5GkpCS2bduGSqVCCMGqVasAGDhwIG+//Tax\nsbFERkaSkJDA66+/ztq1a8nLy6OkpISAgAA2bNgAgKenJ7///jvx8fHodDrWr1+PSqWqVSYSibXI\nsFqJpIHQarWo1WqUSiXp6enEx8ezY8cOAgIC6v1exiipw4cP1/u1JRIj0sKQSBqICxcusGLFCoQQ\n6PV6lixZ0iDKQiK5W0gLQyKRSCRWIZ3eEolEIrEKqTAkEolEYhVSYUgkEonEKqTCkEgkEolVSIUh\nkUgkEquQCkMikUgkVvH/AcQ/YGad+SX7AAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f971b401110\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "with tf.Graph().as_default():\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=max_steps,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 500)\n",
+ " test_ds = setup_mnist_data(False, hp, 100)\n",
+ " tf_train = autograph.to_graph(train)\n",
+ " (train_losses_, test_losses_, train_accuracies_,\n",
+ " test_accuracies_) = tf_train(train_ds, test_ds, hp)\n",
+ "\n",
+ " with tf.Session() as sess:\n",
+ " durations = []\n",
+ " for t in range(burn_ins + trials):\n",
+ " sess.run(tf.global_variables_initializer())\n",
+ " start = time.time()\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = sess.run([train_losses_, \n",
+ " test_losses_, \n",
+ " train_accuracies_,\n",
+ " test_accuracies_])\n",
+ " if t \u003c burn_ins:\n",
+ " continue\n",
+ " duration = time.time() - start\n",
+ " durations.append(duration)\n",
+ " print('Duration:', duration)\n",
+ "\n",
+ " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " print('test_accuracy', test_accuracies[-1])\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "A06kdgtZtlce"
+ },
+ "source": [
+ "# Eager"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "hBKOKGrWty4e"
+ },
+ "outputs": [],
+ "source": [
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(tf.cast(y, tf.float32), y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "HCgTZ0MTt6vt"
+ },
+ "outputs": [],
+ "source": [
+ "def train(ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " train_losses = []\n",
+ " test_losses = []\n",
+ " train_accuracies = []\n",
+ " test_accuracies = []\n",
+ " i = 0\n",
+ " train_test_itr = tfe.Iterator(ds)\n",
+ " for (train_x, train_y), (test_x, test_y) in train_test_itr:\n",
+ " train_x = tf.to_float(tf.reshape(train_x, (-1, 28 * 28)))\n",
+ " train_y = tf.one_hot(tf.squeeze(train_y), 10)\n",
+ " test_x = tf.to_float(tf.reshape(test_x, (-1, 28 * 28)))\n",
+ " test_y = tf.one_hot(tf.squeeze(test_y), 10)\n",
+ " if i \u003e hp.max_steps:\n",
+ " break\n",
+ " with tf.GradientTape() as tape:\n",
+ " step_train_loss, step_train_accuracy = predict(m, train_x, train_y)\n",
+ " grad = tape.gradient(step_train_loss, m.variables)\n",
+ " opt.apply_gradients(zip(grad, m.variables))\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ "\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " return train_losses, test_losses, train_accuracies, test_accuracies\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ },
+ "height": 789
+ },
+ "colab_type": "code",
+ "executionInfo": {
+ "elapsed": 56025,
+ "status": "ok",
+ "timestamp": 1531163800231,
+ "user": {
+ "displayName": "",
+ "photoUrl": "",
+ "userId": ""
+ },
+ "user_tz": 240
+ },
+ "id": "plv_yrn_t8Dy",
+ "outputId": "68be955d-61dd-43e4-b540-3794e3c8f990"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Duration: 4.2232978344\n",
+ "Duration: 4.2386469841\n",
+ "Duration: 4.24286484718\n",
+ "Duration: 4.24036884308\n",
+ "Duration: 4.25758385658\n",
+ "Duration: 4.23242998123\n",
+ "Duration: 4.4213449955\n",
+ "Duration: 4.29613113403\n",
+ "Duration: 4.28209114075\n",
+ "Duration: 4.24192905426\n",
+ "Mean duration: 4.26766886711 +/- 0.055508619589\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXdgFGX6x78zW7KbTSE9JIA0pQkIooCgqBx2qiK/O0XU\n8zyFAw/w8MSCFcuJCHqKoFiwIHIgIgooaGjSCU1aaCEJ6W1btszM74/ZmZ2tWchuQjbP55/s7szO\nvDPZeb/vU97nZQRBEEAQBEEQ9cA2dQMIgiCI5gEJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBEGE\nBAkGQRAEERIkGARBEERIkGAQRIjs3r0bt99+e1M3o14KCwvRtWtX8Dzf1E0hogwSDKLB3HzzzejZ\nsyeqq6s9Ph85ciS6du2KoqIiAMC///1vdO3aFQcPHpT3yc/PR9euXeX348ePx/Lly+X3CxYswNCh\nQ9G3b1/ceOONmDZtGgDgrrvuQt++fdG3b190794dvXr1Qp8+fdC3b18sXLjQp43vvfceZsyY0aDr\n7NevH3766acL+s6HH36IuXPnYufOnRgyZEiDzi/hfY/8wTBMWM5FEErUTd0AIjpo06YN1qxZg/vu\nuw8AcPz4cdhsNo+Oi2EYtGrVCu+88w4+/vhjj8/9sXLlSqxevRqfffYZ2rRpg4qKCmzcuBEA8MMP\nP8j7jR8/HqNGjcLdd9/doGsQBCHsHW1OTg6efPJJOBwO6sSJZg9ZGERYGDlyJFauXCm/X7lyJUaP\nHu2z3+jRo3Hs2DHs3r273mMeOnQIgwcPRps2bQAAKSkpGDt2rN99g1W42bx5MxYsWIAff/wRffr0\nwahRowCIQjN37lz8+c9/xlVXXYWCggKsWLECd9xxB/r27Ythw4bhm2++kY/jbSXcfPPNWLx4MUaM\nGIFrrrkG06ZNg91ul7fX1tbi7Nmz6N69Ox599FGUlpbKVlBZWRkEQcDChQsxbNgwDBgwAFOnTkVt\nbS0AwG6341//+hf69++Pa665BmPHjkVlZSXmzp2LPXv24OWXX0bfvn3xyiuv1HsfS0tL8fjjj6N/\n//649dZb8e2338rbDhw4gLvvvhtXX301Bg8ejDfeeCPo+QHAZDLhmWeeweDBgzFkyBC888478v3P\nz8/H+PHj0a9fPwwcOFC2CInogCwMIiz07t0bq1atwqlTp9C+fXusXbsWX331FebOneuxn06nw2OP\nPYa3334bX331Vb3HfPXVV5Geno7+/fuje/fuYNkLH+Ncf/31eOyxx5Cfn48333zTY9vq1auxaNEi\ndOjQATzPIyUlBQsXLkSbNm2we/duPPLII+jVqxe6desGwNcaWrt2LRYvXgytVov/+7//w8qVKzFu\n3DgAwJYtWzBgwADodDosWrQIM2bMwG+//SZ/99NPP8XGjRvx5ZdfIikpCa+88gpefPFFzJkzBytX\nroTJZMLmzZuh0Whw5MgRxMTEYOrUqdi7dy9GjhyJe+65J6TrnzZtGrp06YL58+fj5MmTeOihh9C2\nbVsMGDAAs2fPxoQJEzBixAhYrVacOHECAAKeHwBmzJiB9PR0bNiwAWazGY899hiysrJw7733Yt68\neRg8eDCWLFkCu92OQ4cOXfD/i7h0IQuDCBsjR47Ed999h61bt6Jjx45IT0/3u9+9996L8+fPY/Pm\nzUGPN2LECDz33HPYunUrxo8fj+uuu85vfKIhjB49Gp06dQLLslCr1RgyZIhs0fTr1w+DBg0Kag09\n8MADSE1NRUJCAm666SYcOXJE3vbbb78FjVssW7YM//znP5Geng6NRoNJkyZh3bp14HkearUa1dXV\nOH36NBiGQffu3WEwGC74+s6fP499+/bhySefhEajQdeuXTF27FisWrUKAKBWq5Gfn4+qqiro9Xr0\n6tVL/tzf+SsqKrB582bMnDkTMTExSE5OxoQJE7BmzRr5e4WFhSgpKYFWq0Xfvn0vuM3EpQtZGETY\nGDFiBO6//34UFBRg5MiRAffTarWYOHEi5s2bhzlz5gQ95l133YW77roLHMfhl19+wfTp09GjRw8M\nGjQoLG3OzMz0eJ+Tk4P3338fZ86cAc/zqKurQ5cuXQJ+PyUlRX6t1+tRVlYGQHSRbdu2DU8//XTA\n7xYVFeEf//iHbDUJggC1Wo3y8nKMHDkSxcXFmDZtGoxGI4YPH45p06ZBpVJd0PWVlZUhMTERer1e\n/iwrKwuHDx8GAMyePRvz5s3D7bffjrZt22LSpEm48cYbfc4/YsQITJ06FYWFhXA6nRg8eLDcZkEQ\n0Lp1awCi9fHOO+/gnnvuQatWrfDggw82OLZEXDqQYBBhIysrC9nZ2di0aRNmz54ddN8xY8bgo48+\nws8//xzSsVUqFW699VYsXLgQJ06cCJtgKF1MdrsdTzzxBP7zn/9g6NChYFkWkyZNChofCcTBgwfR\npk0bJCUl+ZxHonXr1pg9ezb69Onj9xiTJk3CpEmTUFRUhL/97W/o2LEj7r777gsKnqenp6OmpgYW\niwWxsbEARKtDsv7atWsni/a6deswZcoU7Ny5Ezqdzuf8HTp0wA033ICYmBjs2LHDbztSUlLw8ssv\nAwD27NmDhx56CNdeey3atm0bcpuJSxdySRFhZfbs2fjss8+g0+mC7qdSqfCPf/wDixYtCrjPypUr\nkZOTA7PZDEEQkJOTg5MnT8pukwshNTUVhYWFQTt/h8MBh8OBpKQksCyLnJwcbN269YLPBYjuqBtu\nuEF+n5KSgurqaphMJvmzcePG4e2335bTjisrK7FhwwYAwI4dO3D8+HHwPI/Y2Fio1WrZukhNTcW5\nc+eCnl+6zszMTPTp0wdvv/027HY7jh49iuXLl2PEiBEAgO+//14OZsfHx4NhGLAsG/D8aWlpGDRo\nEGbPng2TyQRBEHDu3Dns2rULgBjTKSkpAQAkJCSAZdmLijsRlyYRtTCKi4sxY8YMlJeXQ6VSYezY\nsXjggQc89tm5cycmTpwoj0CGDRuGiRMnRrJZRJhRjjS9R5LBRsN33XUXFi5cCKPR6Hf/uLg4LFiw\nAKdOnQLHccjKysILL7zg4xcPZcR922234fvvv0f//v3Rpk0brFixwud7BoMBzzzzDJ544gk4HA7c\ndNNNGDp0aMBjBjtvTk4OXnrpJfl9x44dceedd2Lo0KEQBAFr1qzBhAkTAAAPP/wwysrKkJKSgttv\nvx1Dhw5FeXk5Zs2ahZKSEhgMBtxxxx1yJ//AAw/gqaeewtKlSzFixAg888wzQds2Z84czJo1C9df\nfz0SExPxxBNPYODAgQDEDLLXX38ddXV1yM7Oxty5c6HVaoOe/4033sBbb72FO++8ExaLBW3btsUj\njzwCQLSsJDFJTU3FM888g+zs7KD/G6L5wERyxb2ysjKUl5ejW7duMJvNGDNmDN5//3106tRJ3mfn\nzp1YvHgxFixYEKlmEESjUlFRgVGjRtUb1CeI5kZEbcW0tDQ5HdFgMKBTp04oLS2N5CkJoskxGo1B\ng90E0VxptKB3QUEBjh496tf/nJubi1GjRiE9PR0zZsxA586dG6tZBBF22rdvj/bt2zd1Mwgi7ETU\nJSVhNpsxfvx4TJw4EX/60598trEsC71ej5ycHMyePRvr1q2LdJMIgiCICyTi6QtOpxNTpkzByJEj\nfcQCEF1VUo74kCFD4HA4fIrYedMIGkcQBEF4EXGX1MyZM9G5c2c5I8Sb8vJypKamAhDr2gBAq1at\ngh6TYRiUlRmD7tNSSEuLp3vhgu6FG7oXbuheuElLi2/Q9yMqGHv27MHq1atxxRVXYNSoUWAYBlOn\nTkVRUREYhsG4ceOwbt06fP3111Cr1dDpdD61hwiCIIhLg0aJYUQCGjGI0OjJDd0LN3Qv3NC9cNNQ\nC4OmYBIEQRAhQYJBEARBhAQJBkEQBBESJBgEQRBESJBgEARBECFBgkEQBKHAZDJh5crlF/XdGTP+\nCbPZVP+OLhYvXoilS7+4qHM1BSQYBEEQCozGWqxc+a3fbTzPB/3um2++A4MhLhLNuiSgFfcIgiAU\nLFjwHoqKCvHww/ehX7/+GDhwED75ZBFSUlKRl3ccS5Ysw9NPP4myslLY7TaMHftnDB8+CgAwduwI\nfPzxElgsFjz55BT07HkVDh3aj7S0DLz++hxotdqA5z1x4hjeeut12Gw2ZGdn4+mnZyEuLg7ffrsU\nq1atgFqtRvv2HfDCC69i3749mD9/jmvdEwb//e8ij2V4IwUJBkEQlyzLNuZh19GGLYmgUjHgOPf8\n5Gu6puPemwNXxH788ck4c+YUFi/+EgCwb98eHDnyB5YsWSavAT9z5izEx8fDZrPhb397AEOG3IyE\nhAQA7oWrCgrO4cUXX8NTTz2D559/Gr/9thG33HJbwPO+8soLmDbtKfTufRU+/vhDfPLJQkyePA1f\nfvkZli9fDbVaLbu7li79AtOn/xtXXtkLdXV1QYUonDRLl9T3m042dRMIgmhBdO/eQxYLAFi27Cs8\n+OBf8Pe/P4TS0lIUFOS7triFqXXrLHTqJApTly5dUVxcFPD4ZrMJZrMJvXtfBQC47bY7kZu7DwDQ\nufPleOGFZ7B+/U9gWXGZ3p49e2P+/LexfPlSGI21jbYMbrO0MBatOoTOrQcirVXkTTCCIJqOe2/u\nHNQaCIVwlAZRrlG/b98e7N27GwsXfgqtVovJk/8Ou93u8x3lqJ9lVX73URKoStN//jMPubl7sWVL\nDj799CN88cW3uP/+B3Hdddfj99+34O9/fwjvvPM+2rW77CKvLnSapYUBAAVloWciEARBhEpsbCws\nFkvA7WazCfHx8dBqtTh79gwOHz7kd78LKdNnMMQhISEBBw7kAgDWrfsRV10lrl1fUlKMPn2uxuOP\nT4HZbILVakFhYQE6duyE++6bgC5duiE//0zoF9gAmqWFAQAFZWb0uTytqZtBEESUkZCQiJ49e2PC\nhP9D//7XYeDAQR7b+/e/Dt999z88+OBf0K7dZbjyyp6Kre4YhhiQDp2ZM1/AW2+9BpvNhqysbMyc\nOQtOpxMvvfQczGYzAAHjxt0HgyEOixZ9gL17d0OlUqF9+44YMGBQvccPB82yWu3w6atwbbd0PDby\nyqZuSpNDlTjd0L1wQ/fCDd0LNy2yWq1Wo0JJpbWpm0EQBNGiaJYuqYTsMpTXNk4aGUEQBCHSLC0M\nc/pO8J23wGbnmropBEEQLYZmKRgSFbV1Td0EgiCIFgMJBkEQBBESzVIwkrViOm1JTW0Tt4QgCKLl\n0CwFIzsuGwBQYa5p4pYQBBFtNKS8OQAsW/Y1bDab322TJ/8dx44dvehjNzXNUjCS9YkAgCorWRgE\nQYSXYOXNQ+Hbb7+GzRad7vJmmVabFp8EAKi1U3kQgiDCi3d584kTp+Crr5bg119/hsPhxA033IiH\nH34UdXV1eP75f6OsrBQ8z2PChEdQWVmO8vIyTJ78GFq1aoV58z4IeJ6ff16LL774FAAwYMAgPP74\nZPA8j9dffxnHjh0BwODOO0fg3nv/7LfEeVPQLAWjdWIyAMDoJJcUQUQzK/J+wL7Sgw06hoplwPHu\nghZ90ntiTOe7Au7vXd58167tKCjIx6JFn0MQBDz11DTs35+L6upKpKam4c033wEAWCxmxMYa8M03\nX+Pddz90lTv3T3l5ORYseA+ffPIl4uLiMXXqJGzZkoO0tAyUlZXis8+WAoBcztxfifOmoFm6pDol\ntwUAmJmKJm4JQRDRzs6dO7Br1048/PB9ePjh+5CffxYFBfno2LEzdu/eiQUL3sP+/bmIjTW4viFA\nWebcH0ePHkbfvv2QkJAIlmUxbNhtyM3dh6ysbJw/X4R33nkLO3b8Lh/TX4nzpqBZWhiZ8ekAp4Zd\nU9nUTSEIIoKM6XxXUGsgFBpaS0oQBIwf/yBGjBjts+3jj7/A779vxYcfvodrrx2ABx98JORj+ivj\nFx8fj08//Ro7dvyOFSuWYePGn/H008/7LXHeWGtgKGmWFgbLsNA6kiHEmGB1RGdwiSCIpsG7vHn/\n/gOwZs33sFrF+nXl5WWoqqpCeXk5YmJicMstt+HPf74fx48fc33f4KouG5ju3a/E/v37UFtbA47j\n8Msv63DVVX1RU1MNnucwZMhNeOSRx3HihHhMfyXOm4JmaWEAgA7xsKMUJaZKtE/KaurmEAQRJXiX\nN584cQrOnDmDxx57CIAoKM899zIKCs7hv/+dB5ZloFZr8OSTTwMARowYhSefnILU1DSfoLdU8jwl\nJRV///skTJ78dwDAwIGDMXjwDcjLO4HZs1+EIPBgGAaPPTY5YInzpqBZljcHgH9+tQBFqv24r8ME\nXNehR1M3p8mg0s1u6F64oXvhhu6FmxZZ3hwAWunECy+urWrilhAEQbQMmq1gpBrEyXtlJkqtJQiC\naAyarWBkJoiT96rqyNQkCIJoDJqtYGS3EgXDaCfBIAiCaAyarWBkuMqDWLmmSS8jCIJoaTRbwTBo\nYwEBcIDW9iYIgmgMmq1gsAwLhteCY2x+Z0wSBEEQ4SWiglFcXIwHHngAd9xxB4YPH47PP//c736v\nvPIKbrnlFowcORJHjhwJ+fhq6AC1A1Ybre1NEAQRaSI601ulUuHpp59Gt27dYDabMWbMGAwaNAid\nOnWS98nJyUF+fj7Wr1+P/fv3Y9asWVi2bFlIx9cyMbCrjKgx2xCra7aT1gmCIJoFEbUw0tLS0K1b\nNwCAwWBAp06dUFpa6rHPhg0bMGrUKABA7969YTQaUV5eHtLxdaweDCug3EiZUgRBEJGm0WIYBQUF\nOHr0KHr16uXxeWlpKTIzM+X3GRkZKCkpCemYsepYAEC5iVbeIwiCiDSN4scxm82YMmUKZs6cCYPB\n4LHNX8BaKtAVjLS0eCTHJeBcNWCFrcE1UpozLfnavaF74YbuhRu6F+Eh4oLhdDoxZcoUjBw5En/6\n0598tmdkZKC4uFh+X1xcjPT09HqPW1ZmRCyrAwAUlFe02OJiVFjNDd0LN3Qv3NC9cHPJFx+cOXMm\nOnfujAkTJvjdPnToUHz33XcAgNzcXCQkJCA1NTWkYycbxCUQq620tjdBEESkiaiFsWfPHqxevRpX\nXHEFRo0aBYZhMHXqVBQVFYFhGIwbNw5DhgxBTk4Ohg0bBr1ej9deey3k46fFiYJRS+VBCIIgIk5E\nBePqq68OaV7F888/f1HHz4gTLRETRxVrCYIgIk2znekNAGl6UTDsrJFmexMEQUSYZi0YerUOKl4H\nQWuGxeZs6uYQBEFENc1aMABAJySAibHCaLU1dVMIgiCimuYvGEwcGEZAhYUm7xEEQUSSZi8YWlYL\nALDY6pq4JQRBENFN8xcMlQYAYCLBIAiCiCjNXjBiVC4Lw06CQRAEEUmav2CoYwAAFoe9iVtCEAQR\n3TR7wdCpRQvD6iALgyAIIpI0e8HQa0QLo85JabUEQRCRpNkLRqwsGOSSIgiCiCTNXjAkC8NGgkEQ\nBBFRmr1gGGLENTFsPAkGQRBEJIkawbBzJBgEQRCRpNkLRrxWFAwH72jilhAEQUQ3zV4w9FoxhkEW\nBkEQRGRp9oKhkyfu2cDTmhgEQRARo9kLhlR8UGCcqKihyXsEQRCRotkLhlRLCiyP8xXmpm0MQRBE\nFNPsBUPFqsBCBUblREmltambQxAEEbU0e8EAgBiVDlA7YKVlWgmCICJGVAiGXqUDo3LQut4EQRAR\nJCoEI1YTC6idsNhoLgZBEESkiArBMGhiwTACzDaKYRAEQUSKqBCMeK0eAGB2UlotQRBEpIgKwTBo\nDQAAq9PSxC0hCIKIXqJCMGLVooVhJQuDIAgiYkSFYBg0sQAAG08xDIIgiEgRFYIhWRh2gZZpJQiC\niBRRIRgJMfEAAE5lAcfzTdwagiCI6CQqBCPLkAkAYPUmWG1cE7eGIAgiOokKwUjQxkMlxICJNVJ5\nEIIgiAgRFYLBMAwMSAITY0GNlVJrCYIgIkFUCAYAxLFJYBig1FTR1E0hCIKISqJGMOLVCQCACmt1\nE7eEIAgiOomoYMycORPXXXcdhg8f7nf7zp070a9fP4wePRqjR4/G+++/f9HnStSKglFlI8EgCIKI\nBOpIHnzMmDEYP348ZsyYEXCffv36YcGCBQ0+V6uYVoAZqLHXNvhYBEEQhC8RtTD69euHhISESJ5C\nJlnfCgBgcpJgEARBRIImj2Hk5uZi1KhRePTRR5GXl3fRx0mLTQIAmDljuJpGEARBKIioS6o+evTo\ngV9//RV6vR45OTmYNGkS1q1bd1HHitfpIDjVsKmonhRBEEQkaFLBMBgM8ushQ4bgxRdfRHV1NVq1\nalXvd9PS4j3eMxo1hF0aOFQ2n23RTku73mDQvXBD98IN3YvwEHHBEAQh4Lby8nKkpqYCAA4cOAAA\nIYkFAJSVebqerDYnwGnghNlnWzSTlhbfoq43GHQv3NC9cEP3wk1DhTOigjF9+nTs2LED1dXVuPHG\nGzF58mQ4HA4wDINx48Zh3bp1+Prrr6FWq6HT6TB37tyLPpdOq4Lg1EBgONg5B7QqTRivhCAIgoio\nYMyZMyfo9vvuuw/33XdfWM7FMAzUgg4CAIvTAq0qMSzHJQiCIESaPEsqnGjZGACA2UH1pAiCIMJN\nVAmGjhVX3qutMzVxSwiCIKKPqBIMaeW99w4sgpOnMucEQRDhJLoEQxsjv66xUVYEQRBEOIkqwcjW\nXSa/5gRaeY8gCCKcRJVgZMalwlnSFgAJBkEQRLiJKsGI02sBgQEAcDwJBkEQRDiJKsHQaVUQBPGS\nyMIgCIIIL1ElGDEaldvCIMEgCIIIK1ElGFoNC0gWBrmkCIIgwkpIgvHjjz/CZBInw82bNw9//etf\ncejQoYg27GIQLQzxkpxkYRAEQYSVkATjgw8+QFxcHA4cOIAtW7Zg1KhReOWVVyLdtgtGq1EBPAW9\nCYIgIkFIgqFWizUKt27dirFjx2L48OGw2WwRbdjFEKNhKehNEAQRIUISDIZh8P3332PNmjUYOHAg\nAMDhcES0YReD1iPozTdxawiCIKKLkATj2Wefxdq1azF27Fi0bdsWZ86cQf/+/SPdtgtGrWLBgoLe\nBEEQkSCk9TD69u2L999/X37fvn17PPfccxFrVENQs+IlUdCbIAgivIRkYbz++uswGo1wOp34y1/+\ngquuugqrVq2KdNsuCjWrAgDwZGEQBEGElZAEY9u2bYiPj8eWLVuQkZGBdevWYfHixZFu20WhUYmC\nQRYGQRBEeLmgiXu7du3CsGHDkJGRAYZhItWmBqFWiS4pypIiCIIILyEJRkpKCp599ln8+OOPGDRo\nEJxOJzju0uyQta4YBgW9CYIgwktIgjFnzhx07twZc+fORWJiIoqLi/HQQw9Fum0XhUYtuqTsHK24\nRxAEEU5CEozk5GTcf//9MBgMyMvLQ2ZmJsaMGRPptl0UGpdLyu4kwSAIgggnIaXVHjx4EFOmTIFW\nq4UgCHA6nXj33XfRo0ePSLfvgolRawAANuelN7GQIAiiOROSYLz66quYPXu2PMt7+/btePnll7F0\n6dKINu5i0GlEwSALgyAIIryE5JKyWq2yWADAgAEDYLVaI9aohhCjcbmkKIZBEAQRVkISDL1ej+3b\nt8vvd+7cCb1eH7FGNQS9RgsAsJGFQRAEEVZCcknNnDkTTzzxBLRasTN2OByYP39+RBt2scRqtYAd\ncJBgEARBhJWQBKNXr15Yv349Tp8+DUEQ0KFDB9xyyy347bffIty8C0cfowZMgIMnwSAIgggnIQkG\nAGg0GlxxxRXye0EQItKghhKrjQEAOCiGQRAEEVYuek3vS7U0iCHG5Ta7RGeiEwRBNFeCWhh5eXkB\ntzkv0RhBrEswnFQahCAIIqwEFYxHH3004LaYmJiwNyYcGHTiPAwnxTAIgiDCSlDB2LhxY2O1I2wY\nXEJG1WoJgiDCy0XHMC5VpBgGVaslCIIIL1EnGFqNCgLPws6aYeeonhRBEES4iDrBAABVdTvwagt2\nFu9p6qYQBEFEDREVjJkzZ+K6667D8OHDA+7zyiuv4JZbbsHIkSNx5MiRsJzXYGsDAKi1G8NyPIIg\nCCLCgjFmzBh8/PHHAbfn5OQgPz8f69evx0svvYRZs2aF5bx6jQ4AYHPaw3I8giAIIsKC0a9fPyQk\nJATcvmHDBowaNQoA0Lt3bxiNRpSXlzf4vAaXYJjtdQ0+FkEQBCHSpDGM0tJSZGZmyu8zMjJQUlLS\n4OPG6UgwCIIgwk2TCoa/elThKDkSFyOWXrc6bA0+FkEQBCEScvHBSJCRkYHi4mL5fXFxMdLT00P6\nblpafMBtrVNbAeWAE46g+wXCzjnwxf4VGNbperRNzLrg7zc2F3ON0QrdCzd0L9zQvQgPEReMYFVt\nhw4dii+//BJ33HEHcnNzkZCQgNTU1JCOW1YWOANKzbMQBMBkswbdLxCbCrZh7YnfsOXMLrxxfXgC\n8ZEiLS3+oq4xGqF74YbuhRu6F24aKpwRFYzp06djx44dqK6uxo033ojJkyfD4XCAYRiMGzcOQ4YM\nQU5ODoYNGwa9Xo/XXnstLOeN02sAXgU7d3FZUmaHuPysyWEOS3sIgiCigYgKxpw5c+rd5/nnnw/7\neRMMWoBTw666OMHgqQ4VQRCED1E50zsxLgYCr4JDcJcGKbdWYMf50GZ+8wIPAGCZqLw9BEEQF0WT\nBr0jRaJBC3AqcIJV/mzW728AANrGZyMrLjPQVwEAnCQYuDQXiSIIgmgKonIIHaNRgRXUEBgnzhkL\nPSrXStZDMHiQhUEQBOFNVFoYAKBmtHAywOu75uGWy26SPw9lnQy3S0oVsfYRBEE0N6J2CK1hNfLr\nXcX75NehLN0qCYaKLAyCIAiZqO0RdSq9/Fq5XKuDr3+NDCmGEY5Z5wRBENFC1ApGK1WK/FoZtwhl\nrW+BLAyCIAgforZHTNG6S4ywrPsyQxEMjmIYBEEQPkStYGTo3amzVoc7vdYRimDwlCVFEAThTdT2\niK1iDXCWtwYAOBWZUaEIhtMV5yCXFEEQhJuo7RENOg0c57r4fB6KS0oKjDMkGARBEDJR2yMaXAUI\nvXGGkCXCp1jIAAAgAElEQVRld4kKS1lSBEEQMtErGDo1IPheXiguKYerym3gwuwEQRAtj+gVDL0G\n4C9OMOwuK4T3M8nvaOUJvLN3AaxOq882giCIaCZ6BUOnBsAAgqdb6UJiGP7KiLybuwgnqk9hx/m9\nYWknQRBEcyFqBUPFstDHqHzcUiEJBud07Ru4jIhADiuCIFoYUSsYgJgp5R34DqU0iCQqoRQqJAiC\naClEvWAIvLdLKvTig1yQUuhkYRAE0dKIbsHQqyF4Bb5DKz7Iefz1i0CCQRBEyyK6BUOnkWMYrOtS\nL6SWlL8sKQmSC4IgWhrRLRiK1NoYVgcgxFpSLsvCKXAQAlgS5JIiCKKlEd2CoZy8x4kLKtVnYQiC\n4FEOPZQlXRsbk92M8+aSpm4GQRAtjKgWDH2MGoJLMBw2FipGVa+F4S0Qmwp/xz9/m4kamzFi7bxQ\nntn2Kl7ZMScka4kgCCJcRLVg1Nk5MBobAMBm0ULNqOsNentnRi0/8T0cvBMHyw97fB7IVdUYSFaS\ng6s/gE8QBBEuolowEmI1YLRiCQ/epocKGtQ564J+hw+QGcXg0itESPNECIJoTKJaMG7qmw2GFS0B\nwa6H4FTDytXBZuewcPVhFJabfb4TbO6Fkksh6B1KxhdBEES4iGrBUCmWZtVy8bDZWFiddVi78yy2\nHy7Bf77yrQcVapC7KV1SEqFMQiQIgggXUS0YADD96okY1u5GXNHqcjhsLHiBh91Vvtxo8YwBcDyH\n1afWhXTcxnIH8QKPUkuZX4FyCmRhEATReES9YHRMbI9Rne9AWis9BE4NAMgXDoBtVeLjVPr9/C5s\nLdoR0nG5Rhrdrz/7K17c/h/8fn63zzZySQG1diNKzKVN3QyCaBFEvWBIxMaoAac4F+MUvxsxV+zz\n2cdo941pSHjHLEKNdTSUPSX7AQCHKo74bCPBAJ7e8jJe2vFWUzeDIFoELUYwDDqNbGHIsKFbCd7x\ngkshQ4kEw82lEFMiiGinxQhGrE4NeAkGE+NtUQTudLznbzS6YPiLYVDQW+ZSEHCCiHZalGAIrvIg\nEqw+sAvKG+/OubE6a4YJPP+Dgt5uSDwJIvK0HMGIUfsYEIzOgnOlJuUn8ist6ykuzqa2MPzQWKVB\nmoO7J9CES4IgwkeLEQxlqXMZlQOzFu9UfODuGPVqnceuDq/RPMc3blFCf112fTGMrYU78N/cj33m\nlhypOI5lx1eFJARFpmL849ensOP8ngtpbqPTWEkIBNGSaTGCEatTgyvPgiP/Ctj+6A8AYNSi1eCv\n89ep9R7vvTtnrpHdQf5mltfnhsktP4Q/Ko/B7LB4fP7e/o+QU7AVxZb601G3nRcF9ZvjKy+gtY0P\nJQAQROSJuGBs2rQJt912G2699VYsXLjQZ/vKlSsxcOBAjB49GqNHj8by5csj0o5YnRoAC2dxR/B1\nBgAAoxI7mYpam+/+3hYG5y0YTT+ira+TtDnFCYoNcZ9J1gnLqOrZs2m5FP4fBBHtqOvf5eLheR4v\nv/wyPv30U6Snp+Oee+7B0KFD0alTJ4/97rzzTjz77LORbApiNIoOz5UtpTaYwPTKwYdbq9Cv9ZUo\nV1nlXXReguEdYG6siXvBqC/obeNEIQxkiYRyDbzLbaViLm1j9FKIKRFEtBNRwThw4AAuu+wyZGdn\nAxCFYcOGDT6C0RhBVYZhMLhnazh5HtsPl0DgVECMGSyAYuzAD9U74CjqCE2WuL9PDMPHJdVIWVJy\nIF68R8p4RL0WhkswuAD72UNY31wKJrOXumBcAgJOENFORHuBkpIStG7dWn6fkZGB0lJfv/n69esx\ncuRIPPHEEyguLo5Yex6+sxsevqOb+MZ7Eh8ANq5Kfu0tGD4xjAvooOqcdTAFmEXu5J04VXMmYNFD\n76Ra7gIEo04SjADHDmU9DU52SV3igkEWBkFEnIhaGKFYDjfffDPuuusuaDQaLF26FE899RQ+++yz\ner+XlhbfoLapoAUPz9iFKsEtGMnxCR7bzJwJKSkG+b2d57D5UDEcTh7jhnUJeq57v5kBAFg27gOf\nbR/vWYp1eTl4/JrxuKnjdT7b1WrRlabRqpGWFo86h3s9D61O3BboXkgWRHxiDNKSxX2U/xN9vLre\n+6g96Tq/StXgex5JEhJFgb+U29jY0L1wQ/ciPERUMDIzM1FUVCS/LykpQXp6usc+iYmJ8ut7770X\nb70VWl2gsrKLXzJ1wfQhmL//CE7XBj6GYPcM8p6qysdHO5a53xdV4cgffwAAbr4qK6Tz+mvzznNi\nrah9547gyviePtudnDjCt9ucKCszwuJwx1lWHlmLUd1uhana11LgBR42pyiI5ZW1iOfEc9tclXoB\noLyqFmWa4Pex1iJaRoLANOieR5qyylp0TmnY76K5wws8BEGAihXFvSXfCyV0L9w0VDgj6mfo2bMn\n8vPzUVhYCLvdjjVr1mDo0KEe+5SVlcmvN2zYgM6dO0eySQAArUaFWI0+6D7eLikA+CU/x/2GcY/U\nrTYnVm05jYIyk8936kNy9YSa5ePtevnm0Gq/+9kVwqAMeludVr/7BMLqWqHwUgx6K914NHEPeGPX\nfEzNiWzyCNGyiaiFoVKp8Nxzz+Hhhx+GIAi455570KlTJ8yfPx89e/bETTfdhCVLlmDjxo1Qq9VI\nTEzEa6+9FskmyfgTBCU6dUzwAzDuzmrN72fx4/az+P1wMV7/+8CAXymrsiAtKdbjM6kjrr/DEwXK\nWzACxSGk+IX3d6yKJWrrW99cuf+lOM9BKRj1TaS0OutgcViRok+KdLOajAJTUf07EUQDiKhgAMAN\nN9yAG264weOzKVOmyK+nTZuGadOmRboZPsRpDEG3/7b3PBDMCFFYGL/uKwQAlFZZfXZTjuif+nAb\npo/rix4dkuXPWFkwgge9pbN575cS678DVLqeLE4rquqqkaRr5SEY9hCC3lL7bSFYI42NRwJAPSnG\nL/7+JowOE+bdOBtq1v2zLzGXotRajp6p3SPWzsamOZRyIZonl56foZFI0rUKuj2voB6fp8LCsNrc\nnVVRdSU2F/4OO+dAta0GT26apfiOgEOnKzwOU59geOM9klZ2fkqk+AUAfHzoCzy7bTZMDvNFWxj+\nBEMQBNQpjtfYKDPV6nPpGR2iu9A7PfqlHW9hwYFPfWbD+8PqtOLzP75BhbXyIlrbeDjJPUdEiBYr\nGCm65OA7CAycFZngqlMx54aX8UC3cR6bGdZ/B/W/wxux9NhKzN37Ac6bSzw3MoJP9VlVkBhGjckG\nh9Pzc2/XVSBXkY3znb2+pXAHamw18vtQ0mqlOIeDd/iI2rLjqzB90/Mot1b4+2rEUbraQk1zDnS/\nQonnfH9yHXYU78HHh74MrYFNxKXoPiSig4i7pC5Vkr0sDN6mAxujHC0zcJy8CgCgZbXINHhmd0Hl\n+VBe2y0duXnlyCs7DyQA+cYC6FRecRKWh3e1cqnkhj8LY+p7WxHT3Qw2zl1LyltYvEuWSNT5EYzV\np9Z6vA9l4p5ytGrn7B4z4DcVbgMAnKw+g1R9Sr3HCjceMYwQR9UNEQyTy0qxOOu3RpoSEgwiUrRY\nCyNZ5+n7F2zeAQu3H7isxoo4lWc6GqPiPPbpnJ2I7NQ42AWl6HjXU/f1LdfvkmI8DiUJRvuEdgAC\nlzgPJeZQn2BwPOfRrkDHbKpJc6FaGMoONFCZFBtf//2S7r3qEq+rVZ9gVFirkF9b0EitIZqaalsN\nqhWehYbQYgXDO+gt2Dyzl27sm4Ur2ohzRJ7+cDvmfn3U9yAKK6NVXAwMejWgltJQVT6dE8PwqLN7\nfsa6TI5QO13JJRWj0gIAnAHcSsoYRiCCuaQsDisWH/Z0vfhzcwFNt3iRMp4T7P4p4xPK4LiHGDpD\nEAzXdarYS08wlIHu+mJTXx/7H+bt+zDSTWpSjlXm4VD5kaZuxiXBM1tfxTNbXw3LsVqsYDAMg8d6\nPSi/97YwDDo1+nV1u6GKyn3dEIzKgZf+ei1u6puN3p1TEafXgNG4C/6dr/QKnDMCrHWeoz+3heFp\nfQRKE5U6Rq1LMCQLw8bZsf7Mr6hzCUWgzl1JMAtjw7lNyC075PFZiaUMB8oOB2xTY6M8b7BAr1Iw\njledxObC3wF4phjbQ7IwXIJxCVoYSvGrb2GtKlsN6jibj1X2R8UxbMjfFJH2NTbzcxfigwOfNHUz\noo4WKxgAPFIp777OM62ydWoskhM8YxBclVccQ+1Em7Q4jL+lCzRqFgadCtCIHQ/DAHvzvIPePCw2\nz4dZFSCGYbJ4duZyDMMlJLKF4Xrolx9fhVWnfsL3p34C4D+GIZERmwYAcChcTA7OgV/yc2C0i356\nf4Kz4MCn+PDgZzhZfcbj86bymXtM3Ati5Zgd7jpey45/h6XHVsLssHgISSguPOmeMD4Vvpoeh4fb\nLfj/w+K6bm9h+e/+j7Ei7wd50EEQ3rRowVCSluBZO0rNAu3S46BWsejdKQWPDu8OR35X8DYdeKvo\nzrqqS6LHdzQ6p0dQ21jnNS+DEXDwVAXyCmtQWi1uqzWLwuAtGEYvwTBZXYs9ebmkpIf+ZM0ZAJBT\nPoN1gFL8xq7oMH49twUr89bg08Nfi00N0imWeC28FErAOBJ4xDCCpNWa/KTMztj8AnaV7JPfhyIY\nUgFJK+c736bYXIKcgm1NNgdC6WoLZmEIgiALRqC5K6GkGEcLds7hMVcq2gj377HFZkn5Y1BWf2wt\n2gEAiFHHILWVHv+degM0alFX+1xxO5Zu6AxD+3P4teRnDL4q1eP7lapTgKKfL60xAcpYOcNDEIDZ\nS/YgwaDFv/7cB+fKjFAleqbLWm1O5Jd6urOksiOSsGjlGIYTJocZJRaxxIqaVcPisOCcsTDgdcZp\n4qBlNXLHAQCVtmoAwDmT7/c0rMbDL27xesCCWTNKzhmLkBGbKre9oXi4pIJYGJYAHeCPp3+WX4ci\netJcDmU9L4k3d78LG2dHRmwauiZfXu+xwo0zRAvDwTtk951yP2XHYnKYomZGfH3p1s9ufRVmpwX/\nvfnNRmpR46J8RsIhHi3ewojXxgEAErTx+EvXu/HMtdNwW/uh6JHSFQBksQDERZgm3NYV2UliSm4d\n79lRnnEcgsCpwBld2707IUWWlIk9j+NFJfJndo7DkvXHsOdYGWYu3I6PfvAM2Dk5Hk6Od1sYrGRh\nOPDt8VXyflW2Gry+az6OVeUFvOZYjR7JuiRU1rmr80rBd3/ZWglazwyxyjpRXCQrxBrC5L2ztefw\n+q53sPDg5363VxltKKu+sJGeMs7zw+l1KLf4n1Bn5epvX7CYjyAIKLGUyddpddb5PHyShVJsrn/Z\n20igFMxgguGRAKDYz6xIFfZnkYWTs7XnsOb0zxGzxpTHrS8T0NyAFOkySwUWHVyCKtfzcCmiHOiF\nI9bY4i2MGf0m40TVKXRKbA8AyIrLRFZcZtDv6F2FCyVTluM5LD22AkauGoI1EXCIdagY1hWgFmJh\nZyxol2HA2VPA4GvisYdZi2+LDgCMmJ1VXmvGmX2F+HVvYMvgx9/PIqOD+ANgoQEgPvSVio6y0lol\nj4SVZBkyUWQW1xqJVeuRrE9CsaUU5dZKrMxbA6tD7PAkwVC6pBK0caioc5+jwjVRT8Wq4OSdIc32\nliygI5XH/W6f/t+tAIDF/7653mNJeE9iXH74R9zdfqTPfqH45INZGHtKcvHJH1/L7zmBg513yG5B\nQExe4AUelbYqf4eIOMpFsoK5pJTW4cZzm+HknfhL13tQY6uVP1fGfCLBm7vfBQB0T+6CDontwn58\np8e9qH+uESA+wxea/fbl0W9xovoUVAyLh6+874K+21jYudB+F6HS4i2MZF0S+re+2mcGdjCk9b6l\nEefRqjxsO78LACDYdRAE1211CUasVtz/nps64L1/3oD4BHEExGhtYFwlRjiNCWCD/UMFfLflNL76\n5RgAYOfhcgDixD2T3YxkXRK6Jl3uVywA4OqM3vLrU/kWJKpFK2jx4S+RW3YQx6rF43I8L8dLJOK9\nLAyp85dmqQeyMEwOM/aUiOXbA5Uw8b7GMkvgWeMlljIsObJMdjF5xy10AVxddSFZGIEF41DFMZ/P\njHZPl2FyjHg/KwOMNo9UHsdHB5eIAl9XhUkbZ2BL4fZ62xUqTiE0C0Ppnssp2IatRTvh4J0egmEK\nQTA2FfyOY5WBrdhQECDAyTuxMX8TamzhKz+uFIlgqeOemWWe+31z7Du8vef9oOepC1I2JxR4gceB\nssMRTRpRXlc4ztPiBeNi0KtFC0N6+DSKkYmejQMEl/i4BEOvliwOAbE6NZLjFSm8CjeVpp17rsdl\nGfEY3Ku1z34Wm/jjLCqzQsWo4OCdMNqNiNfGITkmcLmTGJW7+u7BE0YcPyUe52ztOY/9HByHqe9u\n8VjqL8HltpMos1bAZDfL0xKVMYw6uxN1dvGHOW/vh1h8+EscrTwBdZBUVMmFoG57DC9sfyOgFfLR\nwSXYfn431p39FYCviR2r9V8tMpQ5KcqHnuM5HCz/Qy7O6G8sccbrvsVqREuxIkCZlPdyP8K+soPI\nNxZgvys1+etjK+ptV6g4L8LCkKhz1qFaaWEEWB0SEDvg38/vxjfHV2J+7sKQyssEghd4bC3aif/l\n/YBFAVyVF4PSDRXMwlC6IVfk/YBFB5fI7w+UH8bJmjNBxUD6/VucVvzvxOoLDp7nFGzDhwc/w/IT\n/pcoCAeOEO9FqJBgXARS3ENKQVVWfb217+W4tovo0hJngwN6jdhZSx1cWqJSMNyjHFXKeQzskYGJ\no67ErIeuwd1DFGufS8Ii/RVYsGBRW2eEU+CQoI2Dw+w5+VBJnMbd6QucBtZa/6NxhhXAXnbAY/Tl\nbWEAwKmaM3JnoXRJTXx7E6bM2wxe4GUXmNlhCVpcUSreqE4XO+HD5X4mScI98pUsGu+AZqBg+rmK\nwD7m7DhRlJWdx8Zzm7HgwKdYmbcGgKd7Ttr/lCsrTULqpPy5v5QWmCAAWpUmYHsuFs+02sAdg9lP\nwP7fW17C/rKD8vtgFsaPZ37BF0fcC4ntU3zvQrFxdnkGcr4xfDPPleVygsUwlMkLW4t2IrfsIIrN\npR7t8rYk/XGq5gw2ntss/15C5axrtv3hCv+/93CgFHRySTUR0ixx95wF9yhEzaoQqxM7LrVa7CSl\ntTUkF4pe5+6ADLHukTdvTMLfhveQJwwmGrRITxbdWR2yXJ22LBgMnE5GDvSWlfPYsivwjztGUMxs\nd2qgsicG3FedVugRRFXOio9jxJpRO4r3yHNDvF1STk7A3L0L5PecwHn8WJXBdgAwSi4wTrwXoUyi\nA3wD9NJs7RqTDUvWHYOlzoFqkw1nSgPHFeQUY8WDddbVeW0q3IbcskMe7spOiR2gYlTIry2A1WmV\nvyfNafE3Ij1ZfVp+beftcsJCOAk1SypQHaxDik4rmGB4lxTJU1zbhWLjbBdcrTkUPF1SF2Ztvbzj\nLfx6bov8vjaIYAhepX9MfiyzIlMx9pX6F1XJMyH9vywOK7459l2DqiGXmEux/syv8v20k0uq6VGz\nasSq9XK8QBkwbROXJfv2e14udsrSyFcaESs7z1idCq1ixP3aZvq6VGK00r+IR0KsBnBZLe0zWoHn\nGFmEzhU5INTF+XxfQnC4JyEKPAsVFxd0EakCxahco1LLbazKF91eylngUozA5nCLjHIE/sXGQ/hh\n+0n5/XPbPBfJkiYpCrxLMAK4Obw9QzUWT6GycXYsP/49nt3+MjYVbcXKLSfF+SyqwA9Ksi4JDBiY\nHWbwAo8jFZ7usEUHPwerOHOsRo9WMYk4XZuPJzfNwqeuYLjUZn9ip3Rf2TmHd4WxsBCyS8qPhSEh\n/U4tF1Cy/nTNWZ/PVp9ah8m//tuvi0b5v7VxdlkwvDvfhhCqGybQvZAqAQBA7QXEVjR+LMdXd76N\njw75z6RSueJ60v9u7dkN2FS4DYsPfxXyOb15ffd8rDr1E/5wxd3IJXWJEK+N97EwhrW7EV2TL5dn\nb0sTo6RsGs5P/jsncFCzamhYNTR+PBXSSIETeKQk6sGoxXON6H8FILj/fYJDi9v6XBGwvTaz4uBO\nDVgwaBvfJuD+J8vd2Vqck8Hw9L+g7vBAOM939D02Zwcv8O7Z6V7B+zrOiqLKwMvXykF2WTA8O90q\now3Lfs0DL/UprpjH4TOe8QK7045DFUfAs3ZoLzuKU/wu1JrtYLwE4/FeD8mv4zVxSNYlodRajg35\nm/De/o+wr/SAx/5KC0PLajwqHR921SuSihfaOLtPuqgyTnS2tAq/7PXtZC+UIlMxlhxZJrvAQg16\nB0sjzTZkQqvSXpAv/ry5xCdLbu2ZDeAFHqf8iIkyA8vO2SNSZsXOhSgYAa5TmQAQzMLwRs0ETuyo\nqPO1ciWxlP5fVod4H6v87Bsq0rMjXZuHS6oB8SYJEoyLJF5rgNlhAcdz8j/p8iQx5iA9BNLnsoXh\n6vyVI0CO56BiVNCyWr8/bqnGlCAI6N05Ra5VlRrXCjq1wrXhiEHrFIOniHDuhzG/xIwXBjyFbNMQ\nCPZYgAF6p/YIeH2M3v1gf/5THhb+7zQEcyJ8x/kitVYL3l1xANDYoO/3i+ex1A45xVjim40n8NbS\nfdifV+4jGJVmz07tvysPYu2OfNSaxftZXiM+WBVG8aEQHKIYHj5bCqvdAYFTga+LRbHqEEqM1T4W\nhkHjjvXEavRIj02F0W7yqZ2luAL5lValRasYt2BwAg+O5+SHkRd4n7pWysmQq7efxKlid4fgLS5O\njkdFTf0j/PdyP8L287uxqUAsMa8UiR/P/AKjzb9AW4NYGMm6JMSq9bA4rDhaeQKTNs6o1+UkQEB5\nABeKt+sRgFc5FltY1op38k6PtWeUz1GwVSUDueeU1s6FCEawisf+3ExS3Ez6vbBs4LVxLhZPC4Nc\nUk1GvDYeAgSYHBb5Hy9ZEtJD4BYMsUOT6h15+JsFJ1QMC41K43cugKCwMG7vfxlSUljX+eOg07qt\nBsGuQ2ZyLFSu+RlcTQrq9ruXxt19tBSp+mTE2kSrwmx14I+9Bqh5/5lFUsBePLjnz8Rxzncm88b9\nZ5FfYoIqsVz+7NrMvuILlcMjuA8A63adxR9nqjBv+QG5DIrkkjpTKprvuSfK8d6KgzhVVOvxXXOd\nuH+VSez8JKsnv7QKRrsJgtUAriwbYAScNp72sTCUgqFX6eTaWkWm837vhTK4rlVpEKd1f1/sMCs8\nOhnl/5EXeE/fNst5LL7lHfP46ufj+NcH23D6vOc1eyN1ZFKGmrdVsTV/N+ycHW/smo/NivRdY5AM\nqDitAVomBhanFd+d/BEAsO7sxoD7S6nSyjk6Skot5T6f1Sg6YBtnv6B09kB8/sc3eGXHHLnGWUNd\nUkqCxjC8xD5QRQEAKLX63gvJOpS8CNJAM5yCYQ/RVRkqJBgXSbxGypQyyg+9JBisK5gljTiklFZ/\nJRmszjqoWRW0XuU3JKQfDy9w0KhZpCQzYBkWsWo9DFq3hSHYY5CRrIeaEQVDcMQAzhjYT/cA8gai\ntNqKFZtO4dAp8eGuNtmx+0gljLuvB28LHMsQD8bg0eHd8ehwsUCj83xHOMtbe+xSXF3jOq9bxDJi\nxeA9o3YA3isUKiyOo/mukSgv/hwZlsO+E2WY/78D2Hu8zKc5VpuYumu0ivdXcLrOqXaAUfEQnFpw\ntWLZlmO2XWDUTgic+6fOcu77tvC7E+CtogAEyqhR/l9ES9DzwSu2eLZRKRhWZx0ECLK7gmF5D/H0\nniT3W26R2O780GYPS0LlLRh6tQ7Hq04i31iApcdWYHdJLowWO44WBp6JrmV0OF/mgNXhDuaXmMvw\n780v+V0/o7Xr/+s9ek5yWWD+Zr1vK9opv7b5qZh7MewpFef6nDWKrj9HiC6pGntgUZYyIWvtRtTY\narH65Np6Kxr4EwzJ7VTmRzyVrjzR0yAlAISv+rPTI+hNLqkmQ6q1U2otlwVDK1sYomBIIyzvGIZ3\nh6NiVNCoNKixG7Hx3GaPbbwsGOJfo92EOI0BLMNCrxCMB4f1RnysFipB7DylUTVX1hYPDL4O+hgV\n1vx+1qeMOsACXPBJdXcO6IABPTKRKqcDM4hTeWZZ5Ve4On2lMNh0EHgWsQbBQyDE/dzvD5wtAiC4\nM8BYDut2es5zUFJSZcX+vAq54xVc7We04ohRx+qRrE4DBAZmeI1+BeDpD/a43zq1yDsWPGvpWKH7\nYeccLK5OFydBSpaJdzFGKY7x+dqjWL5dLHAYKy3AxXIe114TYATr5HhY6hx44ZOd+GW3773wHpl7\n19Kqc9o8XEKfHP4Kp4trRGsvAJxdDTjVAAOYXPG5irpKGB0mfHlEnDOitKQyDRkAgJ/zc7DkyDL5\nNyp1kiavSaS8wONg+R/y82F12MI7ac3121aOqpceW4kKi/+YQHVd4EWFLotvCzWrRq3NhC+OfIu1\nZzdi+Ynvsfrk2oBVAcyKmAgvCHj3fweghvjb8mepKEvWmJ2WiFgYlFZ7idAmLgsAUGAsCuiSktCy\n7iypM7X5+PbEKo/tKpeFAQD/O7Haw+8qPaBSR2+0m+XRj+QSiFFpcX3PtgDcJUOULqWMZD06tPas\nxqtEUAhG2/g2MDCey9cmxYsWkkHv3q+Dqg+cFa3hrBDnnJQbxc5BGatYsek04NRAq3P6rIF+x3Vt\nkJKgAxNjgb7vr8joc0zuSBmVE8fPBR9hf/j9YberidNA4FkwWvEBTI1LRFK8HoLdbTk5Czu7rlUD\nTtkUpxrxbBKuSOoc8FzKTBmjmccVSZ3wxuBZGNnxDgDAqpM/eexv42z442wVck4cwU6H+L82Vrvi\nSSznIarKAKdUYBIA6uwczhYbkV9iwle/nMD5Cv+uJOlzKcGiZ2o3AKJlU+iaByNxrqpCtPYCYLOq\nIHDi78c7OJ5fVo3P1h7FiQL3/0XKyqu1G7H9/G6UWsphcVhlq8nG2cHxPH7dWwCrzYlauxGcwKF9\ngp5oOOIAACAASURBVPhb3ZtXHHQdk/rILzHi87W+cxi8rYrdhQd89gEQdBW6VrpEJGjjUWs3oszl\nTtp+fjfWnt0ou+u8UVoYNSY79p0oh80h/l/8xVKUc3bqnHXyIMDJcXByDRMNwY94kmA0IZJgnDMV\nyqmUsmB41aSJUQS939nru9KZaGG4R7nVNvdDyStcUg7OgTquTnaHSSM8pR/8xvbXAABu63a1/Jle\n67kYlA+u2IGW1eDf10xB70zPGEVGsmhZGHRud9O9N3ZFRu110NSJrh9G5cSgnplIT3Ffh7MqHQKn\nhsBy0Lg+7p7cBQDQ/8pUvPH4QDx0t3gfazVnkJHsmhGvtfmMhNtlxEGjEn+ukoBJGWOJOgPAq8Cw\n4kPSp0MWurZLguBw3ffaJDgrxPPI7isXAqfG+UoLLk8ILBjSGicAUFhqxQ/bzmDOV39g8Xdn/O5+\ntrQam/cXQRXvtm4cVnd9MaV4StkzgiDg+Y/d7poft5/F5+vcJUmKK7zcHa6B/t7jZeAFQR5JsjZR\n7MtrjSg0esZkfij6Fow6SMqthREtDH+onMjJLYLD6bYwtu2rggbuCgKv75qHf22eJcdV6pw2/LQ9\nH0vWH8dHPx3Egv3igkaJajE12+qou6BUT++YwSuf75FdeEq8j2nQxsLisOD1XfOw8ODn2FYklvGp\nCiIYBnUsErTxqLJV+0zGlFxw3u2p42zILTuE/NoCVyKHIKfB+0u3Vrqk6jibey4KI3gMHi4Gd4IN\nzcO4JIjTGtAqJhGFxiJ5wphkSXinCmoVLil/D4hoYbgfVGU9ImVareRzTYgR3RvXZw8AALR2CQcA\n3NbxBjzVbwqGX+Eu4hejVWFI7yz8c6y7nhQA9OjgKiXisjCkH1m7BK90W9dzEatztzG9lR4v/bU/\n+nfJdl2EExNu64qh/cTYRrZpiChEvAoO3o6+XcRzSQFnB28HyzAex9QoPEOsoQa9O6Wg7xVpmP/E\n9XjhoWsRFyvu0L2Dq/S2a7TMcFqPjLAEXRwG9cyUG945qxX6XS7eIz2rR5u0OGRX34K72t6FyzPT\nUVJpwffrA1s0jEIwNueWYMWmUzhbbISxxn/n+vnPf+Do2SpZsADI1k5KksYjhiH9rw+f9g0cl1S5\nXRw1Fs8OR9lXlVVbsf246LbatV+0hn7cfgJnazw7U9bg6xaRKisDgMkICLz/a2Jj6qBuc8xzXXpe\nBdbhntTp/du2cXbkl4jnPG07iHMmsT0qh8s9p+LgVEysyyuowRfrj+G3fYU+9czeWroP73zraSl4\nj8Klxcm800ctDisOVRzFOWMh9pcdwpdHv8X6M7/6WBgjO90uv47V6OUqzd712Q5VHMU3x1bKMTQA\ncvHSRQc/xxu756OwugxgBDByNWpfwVC6pGzO8LrnpGOFGs8JFRKMBtAmLgs1diPK6yqhYdWyZeHj\nkvLKkvJGxag8ROb387vk0YskGIIgyJN/pKBia0MGXrp5Oib2flj+LsMwaJfQBizDYtq9vXHnwMuQ\naNCCYRj06pSCuf8YhFZxYkd2W3+xUqjUmUkxlnZe8zPaxouioFb5/lx6tBU74lv6t4ZaxcqB41v6\ndUC7jDhkJsXDzjnkY0uCIVlFyrIbyh83a6jBwCsz8Y8xPRGn97IKwGPGn/ugXZY4ur3vph5gBPf9\ni9fGISMpFhkpomWk02rwyJ09wTIsumW3xkt/vRYzx/wJt19+Azq71m231/qWP5HvqZc7rX/3DDw3\noR9iNQGSBVgOtRbPQL+0BHDHbIOHu7DCWonCcjPeXrY/4PkBYEXOKZRWiVZGrcXunpMCYMfhEhTX\n1Hich9GbwTMBgvjn28uv7Sf6yq/PFtqgjg2c6aPJOu3hchR4Vs7K80ed04bjLheWco7Clj01oguR\n5WB1uNs4+8td2Li3EJ+vO4YPvjsEQRDwyY9HsGTdMfxxpgoHT1WgpDJw+77fKqYAe5ezt9itOFzu\nWUBylWtlSo+qzGp3XM6gjgXnDNw9bir83cPF0ye9F2IYd8ZhbuU+j1iVOFdJQG5eOTieh51zeIhI\nHWfzOJ7VduGuui0H3BalFOBWpjF7u04vBhKMBtA2XnRzVNZVedQx8rEwWLdLyh8qRgWHYvWz3SW5\n2OvK/JBiF5zAyyZ0kmLiWNe0znJ5C2+u7JiCu4d08giQJsbFYOb4qzH13t7o0T4ZN/fNxg0dPS2P\nrDh3BtT8G19DnNY9ipw0+kpMH3eV/D5BL25jY0SzXfLVJsXG4oWHrkVKnAECBNn8lgRD2i/QDGVG\nZ8aVHVL8XpeTd6LrZUnQ6DioGRWu6pSJtqnuhz1VJ1ozIy+/BQBwU9vroVVpMfmqRzCm83CPY/W5\nXAxcQ1BhSpfpSFSL91LgGXBVab4n59T489DL0aF1AjKT3fdFa0uFI190t0mxlaREdycpxYlsvN1j\nguahgkJ89pPSDy8AfmY9m6wOvP7lXgDAb/sK3fswYhVjyT0nCQYbJ7q6HIWd4Ch01yRL4LPAlSkG\nBAoXXU0N0DnG/b/1hxQnAgDwKhhNQWZoMwJqXbPxVayiq3HEiGVgWA4HTisyzBQd7JGzVfjw+8PY\nfOA8ck7lQttlF8A68dOOfJQHWTfF5uDkOSh3d74LALD7RAF2FfiuRQ8AYy6/S3596A/Fb9GmwoFz\ngRMvAICHYrKknYWzzu2eO2M+6XE9ds6OX3YXYP7yA/jqtwN4dturHseqc9o8LACr7cKsDUEQsPhH\n9xo6UhJErSN8VYABEowGIcUxAPeoH3BniUhIMYxApcfVrMrHjD5SeQIAwMMdw3BbGIHrQIVCaqIe\nPTuKnfH9t3TB2GuuFdvhEjoNq8bjvR7C9Ksn+cRjru6S7nZlQbRGtCotDpQdhiAIsq9WsqqkYL4U\nRI2VBUOaGe32Dzt5p3xt6rQiLM37xm/7pYfB7LDAoIkFwzAea1Ok6sX29UnviblDXkGPFLEjvyKp\ns89Kch2zEtCvSxr+b+jl6JKdgSSdmBwg2HXgje7rFATAfrInBJsBCQat6z66LYwJfYaDd3XWOh3Q\nKk6L63q5V2S8vpM4SdLBOaBz1RLjbTowGjvyCsWBABNjgf7adVBnncTrfx/gc93VJjtKq61Yt/Oc\nPDKWrB9GbYeKUePqTqIYMCrX78acgPuvHiofQy3o5fRlEQa8yfV74tS4odOVuFn1NzhL2vqcH3DF\nlyR4Vly22BK4JI00adIuKOam1BmgZrSAygmrXXE8ZSYd60Su8ycwhhrEdNkDVWIFVClF2LS/CDMW\nuEt3eMByWPDdIew4kQ8A6JjYAQBwpCwPjNbXJZSh6oCuSe543bY97s71fKkDzoLgKydyikHemUIr\nHA7FhD+hwmP+j5134Gi+6HbcVrJNHvnzZvH3dryoHFa7u42WOv/WYbG5FL8VbAXnyqKTz+e1pPNJ\n1+RQk90ENatGliH4Gj+hQoLRAC5LcD9UbeLd4uHdyUrvt5/f7fc48Zo4H/+lVE5CDnpDkJdRVVoY\n4UCr0uLpa/6J5wf8S/7sytRu6Jh4WQjf1aBnSjeU11Uip3CbbDlIQiFZXmaHBQwYxLpKw0uCoSyN\n7uAdiNfGyYK7p3S/h59Z6iTdxdosMLgKI2YrrCKdokZWfcvBsgyDiaN74pZrxP9lgk48HqOxe8RF\nBEsCuIps/N/N7uB4crwOzhLRrdcpuS2GDxBH8tdfnYSX/tpf7gCnXz0RE/7UGxpWDTvngJQNLdh1\ngNoOyVro3Uf8q2mTh/SkWEwfdxWm3NPLo70vf7oLVpsTLCvei/atXZMI1Q7EaWLx+IjeHgMWwRaL\nHm2y5ffdsjIxqKf7t9o+Mx4JRTfCunsYAAadshPx/+2deVwV57nHfzNzVg5nAQ77JqsiKosKLkQR\nCbihUEEbkza9as1iNKJZDPfT2BtTc29MbZO0ualNW5PWW1vbmn760U+allSjDcFoJGpQEzSKGAHZ\nZD/bvPePOTPMcEBRIQq833/kzHZmXs+8z/u8z/P+nonR/nBcSgSxCyNm4lAjtKNnEah0bcKB2LyQ\nZ30Yhro02M5NBukVNBen37odwv9znCMLcGqg5TTCPlk8R+69qAIvgfO5Bm1Cec9+WYbX1cYOT80B\nlsdn5xtBVN0gThWOfCp00Jyx7/gU59LDSyW0n2A0e674/se14Nv84LjiKYUjIvcwyk8rZWoYlgdr\nUK7zOPWVkH7tsvd4daLBOHSyGqdk3laHree9IITg9FeNuN5uw39/8ir2fvFX7PzgQzzx08NobhOO\n612t8kRVvbBWyd4OL9YAa5fyd3S7UINxB8g77iCvniyk3lNS/ckf6DihY5sTniHN/U+0jke0eQyu\ndtTBxbsU6zDEvHq5NzNYhBlD4Kfvv57GjciLnge9So/3L37gIYciehodjg6oWZW0XXxeeYaXg3dC\nzaolowIA+y/8Q1rcJV+kxhMeXc5ueLmrH4qyLHdKsv9EAO6OThYAJg4tCu6LQk5aT4U4J8/DcSkB\ntmM5MGq8kZuYBL1Kj+MNn0KrgSx7Tuh4NawGNt4OTkVAeAZwaIRaGyoHOJZBcoxyCiw0hMMvq18G\nF9ijydTR7cS89AhpiifIXweDTgVG7YC3xuD2tnqmRhKCQ+Bn6mlPs16PlfenwnI9Gd2fT0eo1YA1\niydImXI+Ri1iQk1ISwiAUSecp2I0eHC2ctoSALzcWQrTxgfC4owCf90ffKcyfXvt0gSEWg1wQmiL\nU+e63OfqBM0xWZxHrhIgBtcZlhfaCnAnORCwxkb856/+7TFxJ8ZXGLUdxKHBB5/UKfbbKtNhvzRO\n+nypxoGNPz2G7opZsJ1JUxwrZdO5biBFL4/nuFRwNgiG2KoR+gLWolyf4xINjCxxQBrocE5FfZc/\nHDqHr6624vi5a1i/cz/+96v/wVN7/iBNW50i74PRt+HVsv/D2x8cx4ef9coWY3gcOXkVzV1taGkB\nPjndf0bYrUANxh2SEiBYbrm30dtA9Bdj2Dj5MZSkFcNP7yv9EERxOwKC6/ZWRfC7svEcfLQWqZO8\nV/D38kOUKQLX7W3SQjRRuVM0EDaXHQQ9noetjykpAgI1q1ZU5/vo6lH84tTbwn638XQSJzocnSAg\nkocxzicWLMNibrjnSPhWmBqUgimBybBfHA9Xqy/CtNGwXxoH+4UJCPBR1htJTwgEwGBFttAJ6VQ6\nTAlMRrujA1931ErZc+J0mU6lQ7ezG3odAxWrgr9RmAoK8ldhTLARKq5noOHiXSit/lBow8gz7pgE\nD4Bgycwo6Tgn78SP1qSB4ZySDL3ObTC81QY8tUxIsxYHGeK6hycz85EcEoOiObEIdD9XsJ/wL8ey\neHTJBPga3OnbZnOfv+HiolQ8/70p8DXpoFa5f/O8crBkNnGIDjH1dK7u/RaDl5AGLXbyYBA/vu9p\nGDUjPA+r6wQXcBnahE+gDvvS80CGF9pIZQfH6wGeAyE9XgPf5Q1X3RjpsxhXInYvmLz0eHZFirTP\n18uIZx5I8UjD7o8N30rF+sw8lKQVY1PaI9CwGqj8lOtgGLcop+gpWfV+SAoRPFaGcyoMEKPtxNa3\nj+Hn+07BbrootIOswBqjckI38d+4pjqDcvu7ioC38F08/lZeBbA8DGoDHsicOKDnuBnUYNwhDycs\nx6bJjyNeNsKVT0k9NXktNJwGP83c5nFugN4qjTASfYVOJ94nBmatMEprsV33kH0eihrIg0GgQRgd\n17QJQntioF/DypMBWMmA/O3Ce/jHpYNSpyiiYlWSYRCpbDwHF++SYhdO3ikV3BFXW3upvbC78DUU\nxC68o+dgGRb/kbgCrvoIwKHDipgHhU7GqUVyrFVxbEyoGT8vnoU5KT1TPqKnea2rUbagU+u+Rz06\nHZ1w8k7o1RpMiRXiDQ8tiMJTy1MUWTPvnj+gWPWvHV8Ofdr7CJh6AloNJ02e2HkHCCecJxoMsY1N\nssJXa5NXIdFvHLLC7wMABPh4Yd3SSTAZNPDWq7FtzTT853d61u4I1xE6S51KA2+1AepeZXYtBh3G\nBAm/VdFgsHaD4hiby4as1DBpPQJxqcCxjJRhpvfiwTEcfHUWNNtasLZgAnpj1rsNmU89NGMqhe/x\nEbyHmRN65ubjI41YlhMBhgGCzT4AGCllnCVqyVtI9he+g3Tr8eO1M1GYGYONy5IwNsJHkrbZtCwV\n8eEWxaJWAODbTXDWek7VhvlZMDFaeJ9NGiMWx8zzOCY8SI+XH50uxXXWJa9GpNVtiDnl4lbOKFud\n7u4COK5v3S1GY4M1uhaKZAmGl1brR/pZMXPczaeXBwI1GHeImlMj2p2DLSKfkopyxwF6v2yA0rAs\niV2ADSmPYEZImlR7QszRH+sTi0luZdlYS/9zqncTUTdKlFUWn1deXW7FuEJp9AugzxWzalYlZZPN\nDElHSsAkt8hjh7Sa2cm7cNadFDBWtkKbY7lBEbMDgJhQoSO0WnSYkxqKB7LjoNV4SnHrtSrFd4oB\n94auRpmWmNCBe6n0sPMOdDm7oWbVUgfvQDe0Gk6hVdRbIkakjalX1qJ2OaQ4j2ggwtyDEPkUY7Ah\nEI8nrZRUAnoT5OsFL51yND0teCqMam+M840DwzAeK4XlnuDksYLhnheRg3mRWZg/Rgi021x2RAYZ\nERUidPrfnjMO6wsnSW2i0bmgZlUwaUxotbchJd4q1H2RxTbkXqgIq+1GTFq1lBoOACZvFcbHCt/j\nbxA8Ko4I3yMmU0QGGvG9xBVYFLgMq+6bDR+jFgumRSIiUGi7ZMcyeF9YiCBfL7Asg4mRymCxr7cR\n8yZ7LvIUi6SJzAnP8DhmyewIWC16RITo3OfoEC4aDFYZz2Hdiz6FMs3C70vFMSB83122JvQr7Fg3\nQ/qs1fasHwo0WqBTaW8azxsINxYRotwW/cUsfpC+CTaXHS8fex2AMptKzaqkeXipWJHbYIij3s+u\nnUZqwOAErwYbeQxHw6qlTlT+I51gTYCaVSEveh4qG8/ifK8ypwCgYtWSweAYVqon3mpvV3gYVS0X\noGI4D2M9WDz17RR02Zww6NT4Ts7YAZ/nrxeyz75ur8Wl60I2k9o9DSfGZq7bWxFkCJQMhqi51OXq\nCVwa1F6KHPpoc6RUX6Ld0aGQyhc1y/y9BA/oO+OXY1xQNCJ1Y275ueVMD56C6cFTpM8+WsELKIpf\nAgaMwoPJmBiMyEAjwgK8wTKxKHOvphZrS+j1AGzA3ORIcCyH02c10rN4qw0waY3gW3l0ODqx+Tsp\n2PX5WVx2O1ydzi6EegfjSi814a9RiRBrj0fjghPN7sSQaP9ATP3WePyr9Qucb70AG+lC8bIkRAYa\noWZVmJ84BX2xZpEyVrN4egxe6ZEeQ0yAFeGWQMCtWB9likSzrUURNxJ5dNL3sPPUO4izRONccxXO\nd55BCuLgbQDQIigla9zTUHERBrQ4HHAQb7R1OMB4CVO7S2ZG4WKZHo0QastzLAO5/23RmkEIQYej\nE1rZLXAckZIIArwFoxRqUAqG3g7UYHyDiFIeWeH39VtDAAAs7ikp0cNgGAYaTo2pQSn9nnO3iTZH\nSh2KXaHu2jNqFUeV88ZkwaI19Wkw1CwnqXWyDCtTDW2VgoI2lx1NthYEewcNSX1sANCqOWjVt17c\nx1fvCwaMpKAK9AgFymNPk6zjpfUtf/zir4qYBeAp4xBhDEOEMQwHa/6N5u4WmZClQ5LOFo2VmlVh\nSUIOrl0b3Bz8J5JXobajHskBnvPhDMNIo3QAUjXHvV/+FVMCk9HtErwq0auWDyTaHR0wu41Pq70N\nH1w9jMv2nhgFT3iYtSbJYCRZE/FZw+fSuSJO3olmt6Cgj9aMlAh/uOqm4fznFxBhDJNSyW8F+X3O\nH5ONtKAUWPV+aOpuRrfThrzoXOn5ezPROh6vZm5DafWHONdchX9dPoLJAUm40n4VWk4DjuXAEhYc\nw4FwDrAuHmqiAt+tAWtswv88NhV+Zh3iIr3ReFWs1kjAdxjBaLvAqJyIMIbBwTtwpukLhVy70ZtD\nu7t2jtifPJmy5pafvzfUYAwBNxNUWxqXd8P9ooch1hlgh8HMIcdyyI6cjb1fKIUV+3ODk/0n4lhd\nBc40KUuiBhoCEGWKxNnmLxHo5S9Ne8hLXIoSKfJU2nsFNatCiHeQx2gYgJTCCQi1QkRj6uSd+MMX\n7yqOtbnsYMBIMSwtp5WmPeTV29rs7ah3y6sHeCljLINNkCFQGvTcjLG+cbBozWixXcf2Y6+jobtJ\n4VH3HpGbNEKndt3Wio9rPdPP5ZlzsZYoWHQWHKr5N2plhZPsLofkYfi4g/STA5Lga/GGH26gpXYD\nQgxBWBh1P8b6xCHGMkbanhM5Z0Dn916T9crxnwPomc5jGAZmrQkttuvgCQ8dp0VSaDhOtzXBxrYB\nMEqGQPwtJIWNwXV7C6o7LiPMGIKGLiGlV66N5WfRYoyPBcebAbO7P+mrhOytMuQ90Ycffoh58+Yh\nNzcXO3fu9Nhvt9tRXFyMnJwcLF++HF9/7SkmNty4U41/i9YMHafD541CVkR/dRruNWaHzsCCMdko\nilsibZN7GHJ0Ki2eSF7tsT3ZfwJWTXgQK8YuRUboNMnD6KsmcvAAO69vmm+P/Vaf2+WdXoDeCj+9\nLzZP3dDvdWaGpkt/q1gOZnen2tDZk/Pfam/DsboKsAw7JOnWt4tepcP65O8DABrcAx957EXbayBh\n1vZ4GH1fr6fttCot/NwGQV6LxME7PBa3MgyD9LAUKZHkVmEYBgui7lcYi1ul3eGpNCz3IEXD2uHo\nhIpVYVyQkHFZ11mP67ZWnGxQrlL3NRjh6yU8X5h3iCRGKn9HCFzSlJTlNp+9L4bUw+B5Hlu3bsWu\nXbsQEBCAwsJCzJ07FzExPRlFf/rTn2A2m/H+++/jwIED2L59O37yk58M5W0NOao+Aty3AsuwCDT4\nS4v36js9iwjdizAMg4XROYptfB8yF3JK0orh5J041XAGHY5OWN3TKmJnKc6Tn7h2yuPcCX7jPLbd\nC0SbI/Fa5kt45vAPFenWetmUlDg1E24MQV50Lg5f+VgKXmtYNUxaE+aEZUDHafHP6kNI9Bsnqab+\nq+aIx3fGW2I8FozebQINAZgbPgttjnYcrf1UsU9uMNSsWurQ93/1jz6vZVB7warzRUN3E8waEwxq\nYVBWJyvSVN1Wg5r2r8Ey7G0biKFgbsQs1HXW41TDGagYzmMGQjRuLuJCp7NLSlr49ef/1+f1vFR6\n+GgtONd8HtHmSKl/+MMX+6Rjvmy5AAAesaY7ZUgNxsmTJxEZGYnQUCHtcOHChSgtLVUYjNLSUqxf\nvx4AkJubixdeeGEob+kbId4nBvMis5B8BwFqcdQAAMvi8wfjtu4KvUuQ9kacVpJ3rHLE/P86mdFc\nHp+PWWEz+jz+XoFjObx83w8VUxJsP/XQ542Zi/sjMrH+4HMAgB/P3goGDBiGwZKY+cgKvw9mrUmK\ne8lH4WaNCQ7eccfpxEPFt+IWgSc8jtZ+ijhZhp98hP1f05+FQe2FsT6xONdc5XENHadDRkg67gud\nhsrGL5DoN06q4d3bePKExxhThMdU0N3EpDHi0Un/geu2Vmg5DT68UqaQ6rDIpH4SfOMRY4mCXqVH\nl7NvzayxPnGItUQhK/w+cCwneeF9VQQ0qL0GdSAxpAajrq4OwcE988yBgYE4dUo5Uqyvr0dQkNB4\nHMfBZDKhpaUFFsu9417fKizDIq+PPOxboTBuMbScBkvjFkvu+nAkziIsMhvonG9veqeBzh+Tfc8b\nC5H+XlT59Ir8WLPGBC2nURoZ2WjZr9fiudzIrD7z/e81WIbFj2dthUrWHvLqdOLzPZG8GicbKlHb\nUY+/XXhP2r9ywoOSqsKMEGEhYn/TkfMis5DZR0rrvYD4nDd6Fx4cVwiGYfB40kqcvPY5ciIz8fTh\nHwIQ4ikTrAmI8xEMr5i+31+qNADMCp0+SHcvMKQGo3eBkYEcQwgZtFz64Yy/lx9WTnjwbt/GHWPV\n++GnmdskYcPbYWlcHv785d+wZuJ3pfUow5G04Mm40lGL2f28xFtnPHfD3z7DMJgckITj9Z9hScx8\nzAm7NzvGvui9TiHRbyzeu1gqZRkBgmFJ9p8A+ANTA1PgIi64iKtP48AwDMb5xOFss5BNtSw+H+lB\nqQodseGCWH9mTliG9P8fbY6UtNz+O+N5qFiVlHnWmyiTclGeWWPCjJCpmBqYgkDD7QX7+4MhA+nV\nb5OKigq8/vrr+NWvfgUAUtB7zZqe9K7Vq1dj3bp1SEpKgsvlQkZGBsrK+lGjpFAoFMpdY0gn+iZO\nnIjq6mpcuXIFdrsd+/fvx9y5cxXHzJkzB/v2CcGa9957D9Omeco6UygUCuXuM6QeBiCk1f7oRz8C\nIQSFhYVYs2YNXnvtNUycOBFz5syB3W7H008/jTNnzsBisWDHjh0ICwu7+YUpFAqF8o0y5AaDQqFQ\nKCODeyf3jEKhUCj3NNRgUCgUCmVAUINBoVAolAEx7AzGzbSpRholJSWYMWMG8vJ6BAuvX7+OlStX\nIjc3F6tWrUJbW8/K3xdffBE5OTlYsmQJzpw5czdueUiora3Fd7/7XSxYsAB5eXl45513AIzOtrDb\n7SgqKkJ+fj7y8vLws5/9DABQU1ODZcuWITc3Fxs3boTT6ZSOH2l6bb3heR4FBQV49NFHAYzetsjK\nysLixYuRn5+PwsJCAIP8jpBhhMvlItnZ2aSmpobY7XayePFiUlVVdbdva0j55JNPSGVlJVm0aJG0\n7eWXXyY7d+4khBDyi1/8gmzfvp0QQsjBgwfJ97//fUIIIRUVFaSoqOibv+Ehor6+nlRWVhJCCGlv\nbyc5OTmkqqpqVLYFIYR0dnYSQghxOp2kqKiIVFRUkCeffJIcOHCAEELI888/T37/+98TQgjZvXs3\n2bJlCyGEkP3795MNGzbclXseSn7zm9+QTZs2kUceeYQQQkZtW2RlZZGWlhbFtsF8R4aVhyHXT7us\nDgAACDZJREFUplKr1ZI21UhmypQpMJmUQmqlpaUoKCgAABQUFEhtUFpaivx8QXcqKSkJbW1taGho\n+GZveIjw9/dHQkICAMBgMCAmJgZ1dXWjsi0AQK8X5EXsdjucTicYhkF5eTlyc4WV0wUFBfjnP/8J\nQPl7yc3NHXELY2tra3Ho0CEUFRVJ2z7++ONR2RaEEPC8ssTxYL4jw8pg9KVNVV9ff4MzRiZNTU2w\nWoXaB/7+/mhqEkTp5LpcgNA+dXV1fV5jOFNTU4OzZ88iKSkJjY2No7IteJ5Hfn4+Zs6ciZkzZyI8\nPBwmkwksK7zSQUFB0vP2p9c2Uti2bRueeeYZSVajubkZZrN5VLYFwzBYtWoVli5dir179wLAoL4j\nw6qAEqFLRm5IX+0z0nS5Ojo6sH79epSUlMBgMPT7fCO9LViWxbvvvov29nasXbsW58+f9zhGfN7e\nbUFGkF7bwYMHYbVakZCQgPLycgDC8/V+5tHQFgCwZ88eySisXLkSUVFRg/qODCuDERQUpAhS1dXV\nISBgcMW1hgN+fn5oaGiA1WrFtWvX4OvrC0AYIdTW1krH1dbWjqj2cTqdWL9+PZYsWYLs7GwAo7ct\nRLy9vTF16lR89tlnaG1tBc/zYFlW8bxiWwQGBsLlcqG9vR1ms/kmVx4efPrpp/jggw9w6NAh2Gw2\ndHR0YNu2bWhraxt1bQEIHgQA+Pr6Ijs7GydPnhzUd2RYTUkNRJtqJNJ7JJCVlYW//OUvAIB9+/ZJ\nbTB37ly8+65Q6rOiogImk0lyRUcCJSUliI2NxcMPPyxtG41t0dTUJGW6dHd3o6ysDLGxsUhPT8d7\n7wmy4PK2yMrKGrF6bRs3bsTBgwdRWlqKHTt2ID09Ha+88sqobIuuri50dAjV/To7O3HkyBHEx8cP\n6jsy7KRB+tKmGsls2rQJ5eXlaGlpgdVqxbp165CdnY0nn3wSV69eRUhICF599VUpMP7CCy/g8OHD\n0Ov1eOmll5CYOHzlwOUcP34cDz30EOLj48EwQnGh4uJiTJo0CRs2bBhVbXHu3Dls3rwZPM+D53ks\nWLAAjz32GC5fvoyNGzeitbUVCQkJ2L59O9Rq9ajRazt69Ch+/etf48033xyVbXH58mU88cQTYBgG\nLpcLeXl5WLNmDVpaWgbtHRl2BoNCoVAod4dhNSVFoVAolLsHNRgUCoVCGRDUYFAoFAplQFCDQaFQ\nKJQBQQ0GhUKhUAYENRgUCoVCGRDUYFCGNcuWLUNBQQEWLlyIxMREFBQUoKCgACUlJbd8rdWrVw9I\n7vq5555DRUXF7dzuLVFZWYm///3vQ/49FMpAoeswKCOCK1euoLCw8Ibqo6JUxHBh7969KCsrw44d\nO+72rVAoAIaZlhSFciuUlZVh+/btSE5ORmVlJdauXYumpibs3r1bKqizefNmpKWlAQBmz56NXbt2\nISoqCitWrEBKSgpOnDiB+vp6LFq0CBs2bAAArFixAo8//jgyMjLw9NNPw9vbG+fPn0ddXR1SU1Px\n0ksvARC0eZ555hk0NzcjPDwcLpcLWVlZWL58ueI+GxoasGnTJjQ3NwMAMjIysHr1arzxxhvo7OxE\nQUEB0tPTsXnzZpw4cQI7duxAV1cXAGD9+vWYNWsWqqursWLFCixatAjHjx+H3W7Hli1bkJqa+o20\nNWWUcCfFOiiUe4Wamhoybdo0xbaPPvqIjB8/npw6dUraJi8uU1VVRTIzM6XPs2bNIhcuXCCEEPLA\nAw+QTZs2EUIIaW1tJWlpaaSmpkbad/jwYUIIIU899RR56KGHiMPhIDabjcybN4+Ul5cTQgh57LHH\nyC9/+UtCCCGXL18mKSkpZM+ePR73/tZbb5Hnn39e+tza2koIIeSPf/wj2bhxo+Le8/PzSWNjIyGE\nkNraWjJr1izS3t5OLl26RMaOHUv2798vPXtmZiZxOp0Db0QK5SZQD4MyoomOjsaECROkzxcvXsRr\nr72G+vp6cByH+vp6tLS0wGKxeJw7f/58AIDRaERUVBSqq6sRGhrqcdz9998PlUp4lcaPH4/q6mqk\npaWhvLwcL774IgAgLCxM8mR6k5ycjN/97nd45ZVXMHXqVGRkZPR53PHjx1FTU4NVq1ZJgpQcx+Hy\n5cvw8vKCXq/HggULAADTp08Hx3G4ePEiYmJiBtpcFMoNoQaDMqIxGAyKz8XFxdiyZQtmz54Nnucx\nadIk2Gy2Ps/VarXS3yzLwuVy3dJxA62zMHnyZOzbtw8fffQR/vznP+Ott97Cb3/7W4/jCCFITEzE\nrl27PPZVV1d7bON5fkTVeqDcfYZPBJBCuQlkAPkb7e3tkjrpnj17+jUCg0FaWpokK33lyhUcPXq0\nz+Nqamrg7e2NBQsWYPPmzTh9+jQAodaFKGMOAKmpqaiqqsKxY8ekbSdPnpT+7urqwoEDBwAIJUoB\nIDIycnAfijKqoR4GZcQwkNF0SUkJ1qxZg+DgYKSnp8NoNPZ5fu9r9bfvRsf94Ac/wLPPPov9+/cj\nOjoaqampiu8TKSsrwzvvvAOO40AIwdatWwEAM2fOxNtvv438/HxMmzYNmzdvxhtvvIHt27ejra0N\nDocD4eHhePPNNwEAVqsVX375JYqKimC327Fjxw5wHHfTNqFQBgpNq6VQhgibzQa1Wg2WZVFXV4ei\noiLs3r0b4eHhg/5dYpbUkSNHBv3aFIoI9TAolCHiwoULeO6550AIAc/zKC4uHhJjQaF8U1APg0Kh\nUCgDgga9KRQKhTIgqMGgUCgUyoCgBoNCoVAoA4IaDAqFQqEMCGowKBQKhTIgqMGgUCgUyoD4f001\n1ZxdsABYAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f96f1241810\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "test_accuracy tf.Tensor(0.99, shape=(), dtype=float32)\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYwAAAEcCAYAAADUX4MJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXdgFGXex79TtiabZJNsGoGEBEihhRJAQBQQQaqABQue\n4h3HqYeK7ThFz0NRz8N2nHIoqKe+dyIKJ4KA0gQpofcaSO/JJluSrTPvH7PTtiSBJIIwn792Z6c8\n88zs83t+9SFYlmWhoKCgoKDQAuSVboCCgoKCwq8DRWAoKCgoKLQKRWAoKCgoKLQKRWAoKCgoKLQK\nRWAoKCgoKLQKRWAoKCgoKLQKRWAoKCgoKLQKRWAoXHfs378ft91225VuxjVPv379UFJScqWbodCO\nKAJDQWDUqFHo3bs36uvrZdunTJmCzMxMlJWVAQD+9Kc/ITMzE8eOHRP2KSoqQmZmpvB95syZWLVq\nlfB96dKlGD16NPr374+bb74Z8+bNAwBMnDgR/fv3R//+/ZGdnY0+ffqgX79+6N+/P5YtWxbQxiVL\nluDZZ59t030OHDgQ33///SUd869//Qtvv/028vLycNNNN7Xp+jz+fXStcejQISQnJ1/pZii0I/SV\nboDC1UVycjLWrVuH++67DwBw9uxZOJ1OEAQh7EMQBKKiovDOO+9g+fLlsu3BWL16NdauXYtPP/0U\nycnJqK2txZYtWwAA3333nbDfzJkzcfvtt2P69OltugeWZUO25XLZvn07nn76abjd7nY/99WK1+sF\nRVFXuhkKVxGKhqEgY8qUKVi9erXwffXq1Zg6dWrAflOnTsWZM2ewf//+Fs95/PhxDB8+XJhtxsTE\n4M477wy6b3OVanbs2IGlS5di/fr16NevH26//XYAnKB5++23cc899yAnJwclJSX45ptvMH78ePTv\n3x9jxozBl19+KZzHX0sYNWoUVqxYgcmTJyM3Nxfz5s2Dy+USfrdYLCgsLER2djZmz56NqqoqQQuq\nrq4Gy7JYtmwZxowZgyFDhuDJJ5+ExWIBALhcLjzzzDMYPHgwcnNzceedd6Kurg5vv/02Dhw4gIUL\nF6J///545ZVXgt7z448/juHDhyM3NxczZ87E+fPnhd+cTidef/11jBo1Crm5ubjvvvuEdu/fvx8z\nZsxAbm4uRo4ciTVr1gh9JdVqVq9ejXvvvVf4npmZiS+++AJjx47F2LFjAQCvvvoqbr75ZgwYMADT\np0+XPXOGYbB06VKMGTNG+L2yslI4V3FxsdAPb7zxBkaOHInhw4fjL3/5i9BWs9mMOXPmIDc3F4MH\nD8b9998f8h1QuLIoAkNBRt++fWG323HhwgUwDIMNGzZg8uTJAQO5VqvFnDlz8NZbb7XqnGvWrMHy\n5ctx/PhxMAxzWW278cYbMWfOHIwfPx6HDh0SBkEAWLt2LV555RUcPHgQiYmJiImJwbJly3Dw4EG8\n9tpreO2113Dq1Clhf38tYcOGDVixYgU2b96M06dPy4Tmzp07MWTIEGi1Wnz44YeIi4vDoUOHcPDg\nQZhMJnz66afYsmULvvjiC+zYsQMRERF4+eWXAXADss1mw44dO5CXl4eXX34ZGo0GTz75JAYMGIAF\nCxbg4MGDeOGFF4Le80033YQffvgBu3btQnZ2Np5++mnht9dffx0nT57El19+iby8PDzzzDMgCALl\n5eWYPXs2HnjgAezZswdr1qyRmQv98e+LLVu2YNWqVVi/fj0AoE+fPvj222+xb98+TJo0CU888YQw\n2K9YsQLr16/HRx99hAMHDmDRokXQarUB533zzTdRWFiIb7/9Fps2bUJlZSX++c9/AgA+/vhjJCQk\nYO/evdi1axeefPLJkG1VuLIoAkMhgClTpmDNmjX4+eefkZaWhri4uKD73XXXXSgvL8eOHTuaPd/k\nyZOxYMEC/Pzzz5g5cyaGDh0a1D/RFqZOnYr09HSQJAmapnHTTTcJGs3AgQMxbNiwZrWhBx54ALGx\nsYiIiMDIkSNlwmXbtm3N+i1WrlyJJ554AnFxcVCpVHj00UexceNGMAwDmqZRX1+PixcvgiAIZGdn\nIywsrNX3NW3aNOh0OuG8p0+fhs1mA8uy+Oabb/DCCy/AZDKBIAjk5ORApVJh7dq1GDZsGMaPHw+K\nohAZGdmswPDn97//PQwGA9RqNQBg0qRJiIiIAEmSePDBB+FyuXDx4kUAwKpVq/Dkk08iJSUFAJCR\nkYHIyEgAcm1x1apVmD9/PgwGA/R6PWbPni2YI2maRnV1NUpKSkBRFAYMGNDqtir8sig+DIUAJk+e\njPvvvx8lJSWYMmVKyP3UajUeeeQRvPvuu1i8eHGz55w4cSImTpwIr9eLH3/8EU899RR69uyJYcOG\ntUubExISZN+3b9+O999/HwUFBWAYBg6HAxkZGSGPj4mJET7rdDpUV1cD4Aa9Xbt2Yf78+SGPLSsr\nw2OPPQaSJIVjaJpGTU0NpkyZgoqKCsybNw9WqxWTJk3CvHnzWuUbYBgGb731FjZu3Aiz2QyCIEAQ\nBMxmM1wuF1wuFzp37hxwXHl5edDtrcW/L1esWIFVq1YJfWK322E2mwEAFRUVLV6rrq4OTU1NMt8U\nwzCCQHn44YexZMkSzJo1CwRB4M4778Ts2bMvu/0KHYeiYSgEkJSUhE6dOuGnn37Crbfe2uy+06ZN\ng9VqxQ8//NCqc1MUhbFjxyIjIwPnzp1rj+YCkJs/XC4XHn/8cfz2t7/F7t27sW/fPowYMaJZ/0go\njh07huTkZBiNxoDr8CQmJuLDDz9EXl4e8vLysG/fPhw+fBhxcXGgaRqPPvoo1q1bh//+97/Ytm2b\nYEpryXm+du1abN26FZ9++in279+PLVu2CPdgNBqh0WhQVFQUtD3BtgOAXq+Hw+EQvvNCQIq0Xfv3\n78dHH32E9957D/v27cO+ffsQHh4utCMhISHktXiMRiN0Oh2+++47oY/279+PAwcOAADCwsLw3HPP\n4ccff8TSpUvxySefYM+ePc2eU+HKoAgMhaAsWrQIn376qWCPDgVFUXjsscfw4Ycfhtxn9erV2L59\nO+x2O1iWxfbt25Gfn48+ffpccrtiY2NRWlra7ODvdrvhdrthNBpBkiS2b9+On3/++ZKvBXDmqBEj\nRgjfY2JiUF9fD5vNJmy7++678dZbbwlhx3V1ddi8eTMAYO/evTh79iwYhoFerwdN04J2ERsbKziF\ng2G326FWqxEREYHGxkYsXrxYGMwJgsC0adPw+uuvo6qqCgzD4PDhw3C73Zg0aRJ2796NDRs2wOv1\nor6+HqdPnwbAOaI3bdoEh8OBwsJCfP31183ev91uB03TiIqKgsvlwpIlS2C324Xf77zzTrz77rso\nLCwEAJw5cwYNDQ2yc/Baw6JFi1BXVwcAqKysxM6dO4U+5oWOXq8HRVFKdNZVSocKjD//+c8YOnQo\nJk2aFHKfV155BbfeeiumTJkisxsr/PJIZ5adO3dGz549g/7mz8SJExEXFxcQessTHh6OpUuXCtE8\nixcvxl/+8hf0798/5PVDMW7cOLAsi8GDB2PatGlBjwsLC8Pzzz+Pxx9/HIMGDcL69esxevTokOds\n7rrbt2+X+S/S0tIwYcIEjB49GoMGDUJ1dTV+85vfYPTo0Zg1axYGDBiAGTNm4OjRowCAmpoazJ07\nFwMGDMDEiRMxePBgTJ48GQDnN9mwYQMGDx6MV199NeDat99+OxITEzFixAhMnDgR/fr1k/3+3HPP\noUePHrjjjjswePBgLF68GCzLIjExEcuWLcOKFSswaNAgTJ06VRAYDz74IFQqFYYNG4b58+cH/Df9\n++LGG2/EjTfeiLFjx2L06NHQ6XQyk9VDDz2E2267Tbj3F154QdBgpOd6+umnkZKSgrvuugsDBw7E\nrFmzUFBQAAAoKCjAgw8+iH79+uGee+7Bfffdh9zc3JDPROHKQXTkinv79+9HWFgYnn32Waxduzbg\n9+3bt+OLL77AsmXLcOTIEbz66qtYuXJlRzVHQeGSqK2txe23396iU19B4XqhQzWMgQMHIiIiIuTv\nmzdvFmLp+/btC6vVipqamo5skoJCq7Farc06uxUUrjeuaJRUVVWVTL2Nj49HZWUlYmNjr2CrFBQ4\nUlNTkZqaeqWboaBw1XBFnd7BrGHXS9kFBQUFhV8bV1TDiI+PR0VFhfC9oqIiZJKYlI6oFdSefHpo\nFdad3QwtrcG/p79zWee468s/CJ+jdVFYOvm19mpeh1FurcLj618CAKy8+wMAwHu7V2Bn0T7EhcVg\nycTg5S8W/7wMe0sOAQA+mvI3RGgNwm9bLvyMpfs+DzjmnxNfgSksJmD71cxft76D41Vn0D2mK169\nhSugWGGtwlxfn/nzybS3oFfphO8sy+LulY8I33vG9cBLI6/erOjT1efx4hZ5fg7/Xizc9g6OVZ5B\nt+hULBrzHABgT/FBvLXrQ2hoDT6+/e+4d9UfAQDzhv4OQzrLAyQ2nNuGFQe/RDAitRH4cMob7X07\nbeJ8bQH+/CPXpi/vel82fj2z8VUU1otVfVfe/YHw/3/hprnok5CFnYX78N6eFQCA/9y5BBQpRpF9\neWwtvj65Puh1aZLGx1MXQ0Or2+U+OlxgNOdTHz16NL744guMHz8ehw8fRkRERKvMUQRBoLra2p7N\nbFdcDi8AwMN4L6udDU6L7HtaRGrI85hMhqumL87WivH4fJtcLq4v3J7gffG//O8FYQEAP58/jAh1\nOPIqDuGejGm4UFUa9FrPblyEV4Y+DzWlEra1R194GS/+e+YbDIjPQWZ090s+nmVZfHZqJeocZmhp\nLYYkDEBOXG8AgMvtey/cDKqrrWhwWvDS7tAD2+5zR0CRFI5UH8eMjGmosFfJfj9RdRZPr38FRk0U\nHup5L1Tt3Bdt5Xx5YGlzvk1a6Ll96grww8ndyDH1wt4CLrLM6XHi6e/FqLG3dn2Il4Y8ixJbGc6Y\nz+PuHrejrC60r7PBYUFpRS3UlBpWlw3fFHyLMUmjkBSeEPKYULi8bnx+aiVGdxmBlIjLS4Yss1Xg\nrYMfCN/nb3wD47uOQVZ0DwCAjhAnBUZNlOy5LdnzKSLVEahsFJ/9/gsnUeeoR4GlCNO7TUJhbVnI\na3sYD/acP4rsmAysPLsGjw6beVn3wNOhJqmnnnoKM2bMwMWLF3HzzTfj66+/xn//+1+hEBxfvmHM\nmDF48cUX8dJLwWdavzYogutWL+O9rONP18kT2rSUps1t+iWoaaoN2Eb6XjEWgRMHs6Memwq3Bpzj\n3UPLsLt8H87W56O2iYvbD6P1sv3s7kYUWppPGLscTtSexq7yffjH4dB5Jc3R4LJgb8UBnKu/gGM1\nJ/Hh8c9C7ru1eCfcjDvk7yW2Mrx/ZAV+LstDsbUUp+vOBuxTZC3FkZoTOGM+H+QMV5bqIO8Dj0Ed\nLnz+8Ni/Acjf+zJ7hWz/U3Vnsfz459hZugd1DjOsruDCMCEsHgBQ43tvVp9fh7ySw/jPmW8u6x4O\nVh3Bgaoj+Nv+f1zW8QDwwdGP0eRpEr5faCjEksMfCd/rnGbhc7TWCDfjEb43uhtRaC2Gw+sUtlXa\nq/Hxif/D1uKdsLhsqG2qAwECNCmf/9/ShQsHr2qsgcPjxPaSXZd9DzwdKjAWL16MnTt34vjx49i2\nbRumT5+OGTNm4O677xb2efHFF/HDDz/g22+/lcX9/5ohfeoiCxb1Ti6JqcRaJggQs6MeDU75C+/y\nulBm4/4kp83cH+d3vbjZgJdl4GE8KLaGnklcaRiWwa7yfQHbSV54sty9syyLQksxGJbB2gsbA/b3\nSIRsibUMlY3VoEkaTw98LGDfat+g0J6YnQ0t79QMlhADGQCwLF90kROeUnOTFL7PaiT3d67+Ak76\nBMYfc34XcMy6i5uEgabOYYa5qW33cbkwLIMiSwk8jAe7g7wP/D35TyCqG2tR66hDplGu1fGDoFSA\n1DTVweLiEief6Pd7YfsNibkYkjDAtw8nrEps/H+mddkDdRYHzFZxcFZToinnWM1JVNirUGGvwvGa\nUzhWcxIuryvYaQBwE8Yia0nISUFtkxkMy6CuyYyUiM5Qkyq4GTfOmy8AAAbG5+Dtm1+FilTJjiuQ\nTJRqHbWobqpFrC4a83OfELaPSx2N7lFpALixxeqyoT1Qakl1ABQh2hef//lVPNL3Ybx/ZDluSMzF\n/Vl34oVdiwAA/xz1N2G/5ce/wPHaU3h24B9xvv4iwlVh6GzoBIAbbL8v2IwNBZvx/KB5l6VadzTH\na06h2MqZj6QaEW+q5U2TJ2pP44OjHyNKEykIUym8YAGANfmcXTZBHyczPWkpDRxeJ2qbmcFeLv7m\nwEuluT9mo2+Wyc8WvSGq9kZrolDrMMs0ttXn1wHgZtAmnei7yTB2wxnzeRRZS/Fz6V7c3HkY3j20\nDGEaHZ7tP7dN93IplFTb8PW2fPTKteObC2sQ7e6GelXg831p9+v4x8jX4fKKg2iMNhr5DVwxwz6m\nnrjQUACXb5A1aiJR3VSLY9UnhP2rm2phddlAExQiNWLYvppSIUYXDQAotVRj87bDqDJxpiuPn7Zf\nU98EmiYRFS7X3p9+n5uFr/jTKADA7hOioFp69JOA+xmfegsmpAUvn/Ppyf/iQNWRoL8BwIu7X8Or\nw56Hh/UiRmtEbVMdah11WHKE0z6KS904qKqGs5EGqRX7a2fZXuFzibUcNrcdnQ2doKXFe1GTavyY\nVw5oASfjgtXdPuZJpTRIB0CT8rIGxVbOlhtsxsVzvJbLcr/QUAizox6JYfGCY4thGRyrOQkAqG66\nOvNUpIObRjIrc/pmYIxvdl3la38wYQFwNld/0iJTZbMs3uzQnMnjcuEFRqSac7x7GQb/XH0Mu49X\nNHeYgKUZgcELE6uLqzZ74Dw3+52UNhYP97offWI5DZsiKURro4Ka+LKM3WW+Cj2tw9RuEwBwfepl\nvKhpqkVhfQnqHOaA49uTH/cX46utnCns3a+O4Eh+LfZc4LSgGoZ7529LvSXgOLu7USYwaJIW3odY\nXQxujRDXw6B8GkaDRHOrbaqD1WVFuDoc63aJfhI1qUa4iqsEfKqkCscuVgmze/++fHbpbsxbIi8X\n4/aIAtza6EJxlQ2H85t/7rVN9SF/a05Y8Jgd3H2HqcKgIlWwuxuF30rKXVjyzTGA4cYBLREecPze\n4uMAuH7TkKLAyC9qxLHz3LmdXmez7+WloGgYbeRg1VEcrjqG32TPEAZ4f0e/nhZNDwv3/F34vDZ/\nA1SUGuNSRwnbVp37FgD3AvCaitnRgFJbOYDmB6S28lPJLhRbS3FfVvDFjZpD2q4GlxX/Pvkl7s+6\nEw4PN5vmBYaOEmtTqSk1Mo3dcbRGnD1KNQyeblFdZfbZWF00SqylMpPNt/kbcOrAGUxOvQ1ZMT0u\nuf08dc564R7mb34LpohwHL+QjANnqtE9ORKxUeKzLKm24b+bz+GeW3qgUyw3UAWzrf/f6VWYkTEN\nNjdXg8nubsT8D3fDbKgHHQ/0NfVCYlg8DldxS94SBIlYXUxQv0QPYzpUkr6w2LzYcsgFJHOai9Ut\nPofTdecwNGlQ0Ps8U2TGpn3F+O3EbOg0wYeB7y5sxL7yY4jWRYAFgyaPAwzLQEfr0MXQCet/5Aaw\niUNTUW/jJgaNbANAAKSGKw+SY+qF7wt+lJ3X4rLCzXD7EwyNysYqwTz57bZSnDsHqLOiQBnqg/rv\nCswVsLhsiNfHYceuKuh81dBVlApNjZxK64EThEo0LTV6mvDmyjzcOSIT8UbRH3ax3IIIvRoxkVqY\nbU7Qnc4BXhrnSnrjbHE9QDbvh9xbuR+RWgOmpN+GAksRNhRsxm+y72l1octKO1f8sbLGBcbrN3+n\n5JOnRisF0k9mXLRdAEFx/4ldR8VCkkXlTQDDTdzqbHYcbWgfc7aiYbSRTYVbcaDqCCokUQz+g56W\nFgdJ6X4bCrdg7YUNcHsDbZycwOAeD6+uA8EHpPbiy7NrsKt8nzC4Xwr87FnnE457Kw6g2FoKh5cb\nOPg+kTr0RiYPl5magOAaRh9TT9ASM5+aVCMxLB6ltjLBhry1ZCeKGkqxvfTyigzy1DvEGaOFqEC+\n9TzUXblZ3PnSBtTbnCiv5Qb+XccqcLLAjAUf7QXjGyD4fsiNF8NAfy7LQ4W9SuhXFiyqnZXCgPDN\n1iJsPViCKenj0cXQCQ9l34P+cYGFGRl7BGpKw1FRLXGA1rpQWcP1gcPjkJnE/IMnpCz+8jAOnavB\njqPlQX93ed34vmAzapxVOFt/HufqL6DEVoYyewXyGy5ia8lOYd+iSiu8DHf/UgcuADQ1EtDV9ZJt\nyztbJGgYXpf8+Z8r4N4X98VeiNMmYKZk8kKAAOuhcc56Gm7GDT0ihdk3AOw8VIn3vuQ09WqrBYRK\n7l84XVGK91cfR53ER7Hw0/145oNd8DIMzBYHVJ3yoepyBkv/dxyb9hWDoFoOXNlUuBV/+7+DePvA\nv3Cs5hS2l+zCgdLgdfFU9Wny+63mBvIT5y0wN8jffY1Fvi/rDRTsfPs276rDFz+Iz7vW7BH2P3i+\nAjtPFrR4H61BERhtwOayo8TniJaqvP720pbIbygI2BatjRKcn1KkA0JNUy12Fe2XRWC0B8EG7Zaw\n+GykRk2kuM1lFTQMN+NBdWOtYCKY3fs3mJw+DoTfK9jolt/LY31/Cx2tlcWdqygaGdHd4WG9OF/P\nCVM+qv2c+cJlR6fxbfaHiqoBaBfqbS7MW/Iznv9wLxiGhcPF9xOLjScOg2VZ4fgp6eOQFpkqnKPc\nXik/Z0QNCJ/AOHDSjM82nQXp0YM8PwKEMxLDOw1BnF4eYu48MRSfb8jHwk8PCNusNq8wMDS6HbL2\nHyo/jcIKuU/m+MVavL3yCDxeboCvqGvEiYI6vPb5ASz6/AAaHW7YHW68vqZ15epBMDhUykdvMSDU\n8uf3zdZiNBWnwmsxCtvWHzgnmCrh8csP8H1nHeGYYnoAbnsYWA93f16nFt4GsU+OHyUBVsxnqKl3\ng/VwAsjqsoOK5sxJjJObxNCmUphtTahsaAARJjeJ/ufHc9h58ajYDC8n3DUarp+cpwI1NakScbqo\nHh7fu731cAE+3x184mI52w2sS9SaTpRx1YpZLw2WEf8LjsMjMH2oXNDCK/4HXOf7wms2Cd+rKuX/\nI9ZLi8KU9Mq0rbagCIw2cMZ8Toj2kJpH/DWMlgawiw2FAdvi9LEy5zn/WWr6ef/Ix3hn93JsKNhy\n6Y1vBvdlCAyrywY1qUKYSlT36xz1cHjEtRe+OP2VIDDUPp8EyZfr9g35/iY3qSOPhyZpZBi7AeCE\nLcuywozV4XWi1B581hwMl9uLd746ggNnquD2umXhi1JUyeewcqtoIqq3OYWZKt35DL6r+g92leeh\nys6FOIarw2W+nHJflE+vmCzuvg31osnB98d+e+VhnCo04/3VnGmqpyEnRKvFQdLjJgCfwLA4GoX+\nI1gKDOXEe9+JA1d5rR1vfXkExy6Ik5tth0qx+L+Hca6kAedLGlBQYcWWAyUoaQxddl2EharLKexo\n+hpUdDkItQMEKY6iLEPgTIEVTU4PwrSiJkGonGh0OcEyZOCsmSXRN51z6m8+UIxV2/KF+2UdeoS7\nk4RdGUuMrC/AUABDgWUIUFE1oBO4/xVjjQIA0PFFIDudwr/2rYIme7dMuG07fwQHveuE7zf04drg\nZrn3ShPENEawcl8ly3LDqdnuAGkI5T8i4a0XB/oGt28/6QAPYGhGV4zom4S//eEGDO7MCY5ErZgH\nwro1gEU8D+uUR9yxXhpd431CmvSgU6f2GeoVH0YbkKr8cg1DPuC2NAC7/MLuJqeNQ2pEF5lpKEoT\nGRB/bvY5NWsvMby0zuJAk9OD+Gg9ln17AgMz4zAoK17S3tC5AaGwumwwqA0yX0NNUy0cXid0tA5N\nniY0eRxCX/D78WY3kiDhZb0B0RxScx6PiqCFaJgGpwVuxi0L07S57LL9nS4viqqs6J4cFXCus8X1\nOJpfi6P5tfjDnWkBv/P4z9BqLQ7UWZzQqCgQJs7xes58AYXWYrD2CLhdrExgFPpsyEZtFDcrptyI\nCKdg81LgB72Saq7dZqsTRZVWrF9LgYwcANajBtwhMnUZCnFGPSxeCnZXk/B+UA4jPLoaWJkGeBkG\nBEHgpRV5Ie+Px9bkxs5j5SBi7MF3YAmA8PU1yYCK5e6LNJi5dkrQUlo4fPemogkIUwfahXp7I8BQ\niNLr4K/TZaYYcSS/FicKuPdbN4h7HxmrEXfn3gzSmAG3U4UP8+SJjNyASwBeFUCK5ijGFgXEcpMI\nOqEIrEsDggBuH52AtIiu6GQKw6INX0F6x+kpWuw+CpiiVagD0D0pBv4GPoIQ37pbB3XCT17f5Idk\nQGgb4Q9xhsuLcBdmgfWooEq6CELLXVWn0iI+NhJl7lqwXhI39U0GSRCIjdRhZu8pGGTuiazoHnhs\nK6ddPjV9ENSMAbVkHzRaaHgjo/HFD2KeTlJUBIb2TMLXZgIE7YGZKQ/QWC8HRcNoA6fN56H1OXGr\nmzFJ8Xb8UHhZr0ybGBDPzSwJyexJRakQrgpDfkOBEDHFz3qks/L8sgb87f8Owtoot9+6PQxOFNSB\nZVk8/+FeLFieh1XHf8AR51Ys/d8J2b5v7l+CmqZabC3eiY+OfSY48BiWxbJvT2DHEbkDbfme71Hv\nbADrVqOkUvyjbCneAZvbjnh9LGhGhxJbGTYUcAsL8b4LXnDwUVBVjfIoMA2pxrELtWh0SKNqVIjw\nJX5ZXVaYG+V/Tj589Wh+LQ6fr8Fnm87gtc8P4uDZwNXl+EEapAcf5zeTrOfngDxX0oCSahv0nYtA\n0Nxv52vKQBAsvA0xOFVYL3umpyqLxPtmKNAqBjodEKbWYdygLrJzO1xefLnlPAACTIMJrD0SrEuH\nOGNg3gbLUOjVNRrw0mh0O1Bj40wtTWYuyotVNaKm3oFNecWCGQqkB+qMfSCjKqHJ3g1N9m6A4vq3\nuMoGc/hR0DEVINjA4UHmyyUYkL5XlI4vgiFFrpWEq8W1yylKslaK2gm7ywmCodApNjDyJ6NLoGAH\nOIGRlhiNiARjAAAgAElEQVSJgQn9kJucHbQv/jxzAGjItQHGHin7Tqg54b+h9kswhgosPbUU9uij\nsn2i4zxIufEw6lScmJh2Y2AghXSScnP/RMBnUgozNoEgGTBN4v3H601gHb57ZSl4azlNiX93/jAp\nB0nRvhBhhkZEuCh8KZJCdkyGrJxISkws0pOMGJTQHzf36INeadGytv12Qm9kpxpBsiqQYRa4GBcy\njZcfDMKjaBiXSaO7CXUOM7JjMpBffxE2yaDtZeWDC2/HDwXDMKBICl4vJ2j4wZAgCFAEBS/rBU1Q\n6GFMx4GqIzhafQK9Y7MFDUQ6K3/3q6OwNbmxYW8RGJZFo8ODh8ZnYdW2fPywvxgP3pYJp69ExY7a\nzaDjAHexfK3remcDPjr2GYp9SU9uxg01pUZdgwN7TlZiz8lK3NiXe+Evlluwr/IQKANQcSEKlMEM\nSv7uIpyKhMtVC1KiLPCCYnzXMSi3V2JK+m34+4F/BvRNQVkTlnzFLf6j85mRfz5aCcvFElAEBYvL\nhq+2nwb04P6wJINGdxOKq2x45yt5WOPhczXo38Mk21ZYyfUdGVkDgg6tWRF+AoMzlbBwxB4TttU5\n6kCouBntP1cfQ+oQ8bmwak4wqUgVWC8FimLg8Dpg0IRhZGYnbMiTZ62fKgw0afxhSi98vP4Uiqok\nZjuGwg09E7DnqApNHgd2nSkEogDGZgRwEaS2EQs/3Y9GJ9f+4X0ScaTuEDyRtaAixUkOGV4PpsGE\nw/kVUKVyiWNhap0Q2QUAzrP9oe4mlnEB6QVJEuCnRy6dPAQ1QiMOmCBYIXeOCrMAlBfGsHDQtDgI\nTus6Fd179hKimAgAzz8wEIdrSWw8fRCM1YiYSO4loqlAYXZzn87o1ikSUfnhqPVYfecgsHTuVHx9\nmsDO0j0BCYPBcisA4GTtaVQ5xfsxGcICdyLEc4XrKcGf4lJx/co0GkDquP5TU2poVBSanL4EVpdc\nc9artMLEjPXSiAoLrlHOyJiGgoYiIbhEaF+UDkOy40HrR8GlL0fn6BiQBIkovR5mJzd5TDYkBj3n\npaBoGJdJjYN7KUy6GE71lti+Q2kYj/R9OOi5vCwjMz9Js0t5kw1FUrgnk1thjtco3D7BJHWEe3U1\nAOVG3qlKbMwrxo6j5SivteOH/dzs78CZIGs4awKd5sU2UYvgHZR86CShbsLn+3/A01+sxMJ/54HU\nNIFx6OCtTBUcd6zEQWfwJskcdgDw2ffn8b+dF2FQh+OJ/nNkdXqkM7MlXwVGm1TWObExrxiMS40G\npxVHL3CmCcb3JzxfWY2SahtIQ60wcwaAXQUn8PK/d+F/h/bjpf9swtnqYpwvqYdeQ6Nr99AZuwBA\n6KwgdFb86T4x+onQy40phIq7Vve4JBAEUFYbGAJ9rsgGlqHgVdlgc9uhpTUwRekw766+sv3UKhK3\n5op9Mu+uvkhJMODpe/phwW8GijsyFExROhjDwgDKAw/BvWs9Yrr62miBQ10F0liJxFQr7hndHdOH\nB840CQ2npZU7xbwGqbAYn3oL0BAvP4bywItAcyuvGetVOrzwwEB0T46EXis+f0JnA6FyIUytFfx9\nKRGdMbrrDegSb4BGTeGlB3Px1mPDkJYUgYmZN8KdnwMVTTdbdHRoTy7RNT5STOZ7ov8cJMVGYEbG\nVDwuyQpviQq7/H+iJpsv3udm3IAvYoklfNFwjWIRTQ2lxh+n90Fmlyj8/ZGhePauQbJwey2lgdM3\nhoSrtVCrgi9Re2OnIZiZfVdAP5AEgdmTe2LWkHGY0+chIWBGahb1FzKXgyIwLhPeyR2ri4GW1sq0\nCH8fRpPP8SsNDZXiZtwhI5NI3zE0QUNLaUGTtCAgeGd6k8cBt9eNqsZqoNtuaHruQq1FbM/zH4qZ\noaU1gYMYoWls1jH//PJdyDtViTordx+qlJPYbfkBTYn7QcWWglA7wTp9zm7ejMGSoF2cKWDrT06w\njPzezxZZsX6P6OyXRoSxMnt9kAHCJ5S8Tq64HENyAzU/ayuqqcOpmnPQZO2DOp3TMgi9BZqsPFQl\nr8Em80rUxP+Id4/9A7UWJ/p2i4FHVxV4HR80qwFBstD2/hlpSRHolRaNznHhmDou0HRCgMC8qUPx\ntzlDkRyvD/j91MUGwUkNQEg0M0nMTS88MBB/fXgwpgzvKmwz6Lk+Cdep0DVRHBBZhoRBr4LJYABB\nMoiOYaGjtHjzkVugp/WgDPXQZOVB0/0Q6uN+RoOnFlpt4N9+UA6n1VKRokmwn6m35MYIvP77IaBI\n8XmQ+sCIMpqkkeoT/lGaSKQlRWD+/QNkznAeNaUS3nv//0ZKggGRvixstYrCX2cNwt/m3CDbZ/bk\nbAzKEqtb82ZNad0xaR5H18guCKP1AXXJgiEt9gdAFqUXDIfXGRCCKzWFqSk1uiZG4Nl7+yM6QovM\nFCNidaIqrqXFSWdStNyE1hakk8/2qEmnCIzLhHdyx2qjoaU1Mj8FP2uK13PmD16YSF+6EZ1uwH2Z\ndwCAEBabFJaAv94gX+GNn0jQJAWCIGBQhcPisoJhGVk01v78Ylh9zl5S24RQtXPqfILEaBBfHlLT\nhLe/OhR0fwCwO51Y+r8TKKvhzk/oxJknGc7lLagZ32zKN5jr1DSeGfQHOI4PBevSyQZJAABLwu1h\nsOVgCfJLG3D8oiQT1z/U0h+GxLBeCWDdajDwAj6HNC8wqiwWnK7jIpqoKG4AJOhQGgSLzO5aVDXV\nIDs6A0/0myMrvQEAsWGSPzDBYN5dOXh51iCwFPfMe0WJORNRmkioSBoxkVoYwoJYfBlKJjzvzZwO\nAIiJ4NquVpFIS4pAXJQOOg0tPCfp8/I/H0EQQnCAxWOGQcMN/iZ9YPn3Wkd9UBOph+ImEmRELQiW\nwpP9/4D7/RI4Y6N0MvlNRnGz8Elp44RtYbQeD/e6Hw/1vBeT08XtwXJ7VKRKeIcpsnnreHJcuCBA\neIZkJ2DOFDH0lNdspIUNpUETNEnjqQGPYG6/2SGvw9+LVLtqDXyFgM7hSfhN9gzM7vUg+sSLpt5g\nUVaxkvdMS2vh9D0XLRUY6HG5SDWMYAEkl4oiMC4TqYaho7RwMx5htsSbpLr5in/xAkHqBM2K7iGU\nz270aSBJ4QmI0Ynx6oCYNU6RFEqrbdCSeljdNhw6L58B/XvzMdQ1ijM+QheoScy81WeKIBhQ3UWt\nQ9XlDC7EfBX6Zn3Zrkfya+Efa0+buPpREwdk4ZHbewl2XRVNITk6Gl189bBovwJqg3pw9tTPN53F\nq58dwFtfiv4Gf23En37d45FkCuNCCwFouh/mjvMJDIZwwUKKobXanK3QxAQvI9I9RY+vyrnaPT1j\nMtHdmCYk4fFInbcurwuFlmK8lvcOCiycmW9wsigwpLNGKkgeDRhKMM+pSBpRvrwVmiLx6u8G4405\nQ2W7L3x4EF56MBcRIWzaN/l8SfzskQULg4oT3v6CDwDeP7IcR6qPB2yvcdQBKgdIvQ3xqmR0i+oK\nLa0Vssr5d1caiBEex/lZhiQOEO4jTKWHURuFgfE5gvYkPV4adq2mVMJ26cDWVuQCQz5Qx4fFoVN4\naFv+gLi+IX9rDkFgGJIxKKE/+sZlY+508VyaICYtqcBQS/4f0j5qKxpFw7g64J3cERoDNL6Xklcp\necGh80l0fjtNUvhjzu/QP64PsmMyBDOMwydQgtlJGZ+mQBM0FizPQ0kZZ756f+0B2X5u1oll6w4L\n3wltI/qkiy/kotlDcHM/bvAmdFbYabmDMpjJQPyNExiFFVYx1r5JvlZ7H1M2BmSYkBjLvex8WfPo\nCK5vKFY+gzRFBnEi+khLMGJk8nCMTBgtbJtwQ4rweVBWHBKM+oAY/pRon3mCdsvMJYTaCZguIhgT\nRkcJs9x+vnUr/GfDpETQO70u/OPwRyixlQlVhWO0opCPUIt263sypiMrugf6SbK2WS8lxNtThLz9\niTFhiPQTDHqtCikJBoRiYAbnV5AODHzQRLTWGPQYvt1SKu1V0Edxs+qsBLGvn+z/B/SOzcJNyUMD\njnEwTYhQGxCliRRm0KHs5A9k340exm6YkTFN2Kajdbg/6y5kRffAjIypIe+xJWb1vBc5pl5CUU6D\n5BkEGyQJgsCtKSODnitSI+/rgfE5mJJ+GwBgbs5spEWmIi0yBX/M+R2eGvCosB+fQBoqdDU2iPDO\nMfVC5/AkDO80BARB4MGe9yI7JkOoDdYe6CTmN52iYVw5mnxCQEdphZeSV/X5AYhXLfnkNYqgkBnd\nHQ/3uh80SYvJeE7O4Rh0VSyJhgFAmFUTanmoLkEygCTKJyNNh8nDuoKMqgQVUwozSsTw2yDhkoHX\nFWeS94xJ40IojRWCOcpIibbjGzvdgISweBAEgbQk7g/HC0OSt3n7JTnRFPc9XKfC/Pv7y+oZdUuM\nxh09JmNaFlcFlCCA6Tely47vEm+QtREApt+UBS2lBaltBEF5oWODD5hSLG4uDHV81zFC5VP/SBq3\npIS10+uUZdbTBCUTEtI/ZYzOiMdyfitoWQA4k5RPw6DIdvj7+bpAWpCQHzCpED4zf/qaeoEFi/65\nnKAMV4uDTEpEZ8zp81DIwUa89+ZLtncKT8Tj/WbLZvcGdThiddF4LOe3goZyOQyIz8Hvej8gvHMR\nEg3Df40IninptwlFLKWoSJXMn3JH98mCcMmI7oanBjyCpwY8iszo7kiLTMFknwmLD3XPjA4euhps\nMa6UiM7406AncI9PiCaExeHRvg/LNKS2EtGMtnU5KALjMmBYFnU2K1QkDYqkBNugU9AwuLwKq110\nSgOBVWzF4oKcH0BH63DgTBVe/mSfkHfAD16sLymIL33gLzCmjOgiC/3M7haGiCgvND0OQZ1+DEsO\nf4Q3+UVgWiioBshfrthoNdSpp6DpfhiaFC45qHuMOAuVvuB8e/k/r94nCFS0/FUb2b8TeqfF4Jl7\n+qF7chT+8lCu8JtKkgX+tzk3YPGjw2THEgAXXukn+LQqDQzqMCHOPkmdgpao8/W9NGLFX8OQJlZW\nNsqjZwxqg6ySribIn1IqUITkMoQOgmgNyeGcKSpGy5nApNopb0LrFtU18MAg9PVVyT1v4TSP5gaW\n3Ph+su/8s/f6+kztZ3r0R9rPsn5pR1o74PJtkZrZCIKAXmIS0rcQWcQLaqvbBj2tQye/pQf4SaNs\n0vALIu2LYH6US6XD8zB++uknLFq0CCzLYvr06Zg9W+5wKisrw5///GfU1dUhKioKb775JuLjAyX/\n1cSP+0tQZbFCreW6T+d7ELxg8LAeUASFH/LKoE4VTVLSGd+u4+X4Pu8ikCqag1irEZuPl6CwworC\nCiuyUqMFH4aHlwW8fd8vL8Bk1GDUoDjsKOOcvRaXFS5G7tw0O+u5UhwSgTGmy83IieuFN/cvke2r\nV+kER76bcYMy+ZKytJypZ2DXFOT5cp2ksxh+sOVLfky7KR3WRjf0XapwSJKQHqFX40lJKKlJUgVW\nWpBQWh2Wh5//d4oxQOrJoQgK6ZFdhSTKPp07o1Nxb/zkDlz1LtPYHafN54QS4M0JDGnme4lNXnbE\noA6Xze51QRyW/KAOcFFNvJ+nJUdvczzZfw6qm+oEE4ha1gbufcyM7o75uU9AS2tAgMCLu18Peq7U\nSC5xsMHF2eGbs3XPyJwGFiz2VnAmUX7QZ4Xn3rwQlPazQdV+M2kprRVEYT5tiCAILLxhvtB2Pa2D\nxWWFhlK3GB0lnSwYg9R/e3HIM/CynhbP01FI+yJYbbpLpUM1DIZhsHDhQixfvhzfffcd1q1bh/z8\nfNk+b7zxBqZOnYpvv/0Wjz76KBYvXhzibFcPRZVWEJQXXrevTr1Pw3j74AdwMx54GS9nw/dFDPE+\nDelL8/X2CyitFjOUWYbEN9+bcbqIm/Gabb6y4L7h8WKZzXcObpDhtQn+pf/3qS+xs2y3cD6ryxa0\nxIe622GAFAdEkiCRGtElYD+pw1ZaD4onQmLrldqMeQHHx4lHhqkx944+0Gha/6q1NEvlX/zBWfLZ\nnNPrlKn+iYY4GLVRQvE6KXwJdEFgSEwp/jMx6aBQ5icwItThsnLjwSJRpNFKM0ZmIj6aO39bNAwt\nrUVng1hXSaphSEMpkw1JiNXFCKVUghGtNcp8D81F06hIWgjmAAI1DLKZPAlA/h/oMA2jlYJI77Pv\n0yQNozZK8F/w70Jr8hakzz7Y/URqDCF9Sb8E7WneAjpYYBw9ehQpKSno1KkTVCoVJkyYgM2bN8v2\nyc/Px5AhQwAAgwcPDvj9aoQiCYASywdLVfgfj52Gm/Fwhcj8on1on5OTYVgu81Zig2cdepmdf2Ne\nMSrqGsH4ykY3WDkBMaA7p311T+FeBOlAJ7W9W11W2SI1PISxXKZhhKobJZ2NVNgDcxSkL2JQDcPv\n1Wrt+gAAoAoRMTM3ZzZyTL3RO5YrC6Gm5YJlYFIf9I7NQnZMBrJjMpAemYKMzlEy53jPmExMShsr\nzKJ5k5Q0MmV27wdk5324l7igD78uSW58P6REdMaQxFxZXwWbnUsHkv7dEpAcz5mMyHacdUo1jFCm\nh9vTxwfdriJpmKQ5AS2YLqSDJP8eiBpG64eU9h7MeCiSwm2po3F/ZvPruuTG90NqRBdM6ipfMY+f\n4Blb4VeRTiY66n7agqGdhXKHmqQqKyuRmCg6ueLj43Hs2DHZPpmZmdi0aRNmzpyJTZs2obGxEQ0N\nDYiMbL/klUuhsrEalfYq9DGJ64t7GS/2Vx5GuDoMMVojKMpXh943EDklTtGvd5xBeGYTAEJWrhgA\nVqw7DQpqJETr4XR5uX1YX66Fnz2+uMqGPy/bA20uCwJiJcy4yDDADnTtpENBCaBX64FG0dbDVUoN\nw4WGQlQ1BWZ1A8DtNyfj+1IuoopfT0JF0rIiidI/fqldHlEFyGdx4ZLPjJ8PQ9h+CQIjlIaREd0N\nGdHdhO/SGfp9mXdCRamgpbV4VJJRn5KgRVykATUOTksalzoKaZGp2F/J3b/ZGejDSDYkYU6fB4Wy\nEfF6E+7Pugufn1opmLumpN/GFRL0I9jsXNoXakoFb4hktbYgHbjUIQTumJSbsbFwK5o8TQhXhcly\nDWJ1MSjyLbHbUjQNHWRW7b0MgRGh6RgNAwAmpo1tcZ+smB5BF9sq9i1Z0NmQ3OI51NTVLTDaW4vr\nUIHRmlnls88+i4ULF2L16tUYOHAg4uPjQVEt/5FMpvbriItlDVjwr134y+9uwF/3vAkA+GjKm3j1\no0PomhSJrr3N+PepL4X9E2unAQkAvDRMJgN6etOxmq98Tbnh8no4E4Ff2Ofhs3UB0UJC9U82uCrP\na/jJJgOevnsk8puOA2UAVNwf1D8qJVytR6zeCKvLhnUXNwU9Z2Q0CXBjA4Z0zYHJZICaVsPtEgXG\n6G7D8PmRbwAAlY3ytRwM6jAkxEchOSIRJZZypCclQqviBpnBKX1wpPo4hnfNlT2jwal9sK/yINdF\nJNXs84uOCm/V842yiKG5Rl+YbrDjjPow1Dg44RkfEwWT0YB4t9xM0DnehCideGysVwwbNpkMiHOI\nExiKINEtuVPQwTEhxhi0DSZ9NKob65CcEIvhzoE4Xnsao7oNbbf3ONYptjc+Jkpotz86lYYTGBq9\nIDBMJgO6xCTiYBXnlEoyRcMUEbpdsW7xt86mOJhMBtzSbTjWnNqIYWn9Wrwnoy4S5qYGdE1MANke\nkWKt4FL6+Zb04fghfwduzRwGU2zzx0nfk0RjbLuOS+1BNMNpzjG64O/lpdKhAiMhIQFlZWJNosrK\nSsTFxcn2iYuLwz/+wUXvNDY2YtOmTQgPb1lSV1e338pzH6w6ggabC/9ceRjwBTOcLazCyYt1OHmx\nDr0ZefnwixW10CYAjJdCdbUViVQyRnW+EVuKd3DVJwkGHg9Au4xgGULMcfDTIoZkx+MwSwBgYdBp\nkZMdj9QEg69SqRwtTUNPE2iycyakeht3/2F+AiNMFY4Hs+7DS7tfR4MzeB9VmH2z5LTb0FWTjupq\nK2i/V2FI9GAQWTQ+O7Uy4DxhqnBUV1vxZM4f0ORxwFrvhhVcu/oY+uL5QQlICIuTPaNMfRaeHzQP\nakoFHa1r9vlZrc5WPd8muyjgGm3c9YMdR7HiLNBmcaPaY4XTLndsN1oYuG3isVZJaZXqaiscdtGM\np6f1qK0JngnssHmDtuFPA59Ak8eBhjoHeoX3xoLBTyFeH9du77HDJravyeoB4oP3Bf+c9ZQerw9/\nEQC3n54V/3N2iwfVId4dALBbRTOmysU9y9EJI9Ensg/i1aYW7+n53Hlwel2orb20bOrLxWQyXFI/\nT0geh2GmoTCyMS0eJ+0L0qVq13GpvXh12PPQUGpUV1vbLDQ6VLz37t0bRUVFKC0thcvlwrp16zB6\n9GjZPmazWdBE/vWvf2H69Okd2aQAVp37FsWxawCwcHvFQaTW0gQirB7agRtxplae9KXpxS1KI3Wm\npho4xzHd6TwI2g2vh0CSMdJXNZRH1CLuvaU7Fxnk29TFFInfT+6JsYO6IJjfUKfmrsWbA/jIK38N\nI0IVjhitUWai8IefWXaWhPr5JywRBCH7HRAzhw2+DF4trQ0wyxAEgaTwhIDZN789VhfTYiZr0Azp\nFvZrLgpFaibizV3+Ic6qFiKWpH6BULkG/tfy3873FUEQQt5Ke6FqhQ8DENdmMagNMKjDBTNKrFZS\npqKF0hRSk5RRw90TSZBCKZyW0NG6NuVddDQqShW0rEowpObTjnLit5UoTWS7FB4EOlhgUBSFBQsW\nYNasWZg4cSImTJiA9PR0vPfee9i6dSsAIC8vD+PGjcO4ceNQV1eHOXPmdGSTAthavBNeqgmg3Sis\nEGcHtRY7VMnnQJAsqCi5L4CvYc94SWw7VAqHywPGJzxI38IpnppOiI7QyipW8lnPE25IwS0DO0Or\nocAHiUoHsHf+OBzvPX6j/JqEPHafTxL0H7wM6nAQBCFzRPeKycSwJHGJSb54oXSQeajnvUJsP480\n8oYAgUlpY9E9Kg03JcvzItqb1trBpWGpzfkDpEEJ/D13MSRjeNJgDEro36rMWqmturnY/PYov3A5\nSNsXyocBiD4rg0qeac9nIhMgWizTIRWuVypc9GohKTwBA+L6IsfUC10jW877+bXT4XkYI0aMwIgR\nI2Tb5s6dK3weO3Ysxo5t2UHV4RByE0V1gz1gm6c2EXSMGFZJ6uz498YzWLPjAqyohdZXB411q+Gt\nTEF0sgasVRxA/v7IMFSZGxEbyQ04Xi8raBjSMhF8ZdL3Hr8Rz+3eAACIjfCtSyxoGJwTV6+WD168\nw9KgNqDWFzI6NGkw+pp6IlITifUXf8AZM2fykg4ycXoTHup5DxbuFcOapb+rKBUGxOcIizt1JC3F\n8vNIhWxzA5c0N4IXgiRB4p7M0Nqs/9xfrmGE1pDao8Db5SAV7s0N+E6GExj+CYZGbSQogoKaUrWo\n+VxKAMO1Dk3SmNXrvivdjF8MJdPbB+GX/Xy+3MyV25DANsp9K14z54+xNLoBiXmKcHEDSpIpDGlx\n8toycUa9UC4jK9WIYBoGT7hOkhRk8C0cwwsMn4bhb97hB05pxAavNvsPJP61q/gQPF7TkIdqtl9x\nuFBkGrkcigR9XAt7ckgTIZvTMKTFA1syPfGYfAlx6ZFctrQ0ciiYhsGbWH6JfgpGa58VHzLqXwyP\nJEikRCQjTteyWSnKV0KlV0zm5TRV4VeMsuIej59wKKpqgLqrXIjwdZwAgLFFwlstht2xXvEP27dL\nF4wZOBApCQYYk3pg2fFdQS+ZnhQJgltMLmTNGx6+fEGAhqGSz2j5QTRYPR3/gcTfzxGm0uOVoX8W\nqoxKZ9XBqm22N4/0nYUGl6XViU5SgdFc1jQ/6ANotd8gShOJhUPnC3ZpaRhxMB/Gi0OegcPjaJds\n2stB+iyb81/xBCth8oc+s1p1LaM2CouGLQgIuFC49rkuNYyLDUV4ctvzOFEtqdrpX1+JZBCml89a\npYvcM049ZIYLSQhtosGErokRIAkCRq28qmsoWhQYvA+D9PNh+Jmk+MFemrDDzz79naHSWSmPURsl\n2PlJghTs4cEGmPaGIqlLyoqVmqGac5R3jQzMZG8N0Vqj8FykgiaYhqGh1ELxwiuB1G/RnFDk7ydY\nNrRepWvWoS8lUmNo8Z1VuPa4Lp/4tpKdcDFufHTsc2EbQXrlNUpJBrSKBXw5eRRB4Y+398Oyk1wu\ngX8mM0CgK3JhiG3EIEmBttYm86iIVgoM3358VrdUCPSOzcaozjcGXJefcfo7Q5tzjvJoKQ1cXleL\nS1ReCaRmvOYGL5qkcX/WXSFXNbxULiVr/Zeitaa2ZwY8hr0VBzDwF/BFKVx7XJcCIymMq0HkgmQt\naz8NIz5aI8t81tIaGHSi+Sc90YiTkrJYqQkGPD1qVMC1WiswWjtb899PRYnfH8i6S9AOpCF+oU1S\nLV9TFeLYqwFpoEBLpbxvSBzY7O+tgVu73YFGT+Aa6Fea1prCkg1JSDYktbyjgkIQrkuTVNAoD5LB\ntBGirXvCsM6yOkvcetrioKShxcFq4cOD8Nx9/YNeix+sW4o7D+b0ljXPZ/7yH+SlA7lUY5CaHEST\nVKCjsyV438nVKDDkGkbHh3dOSBsDgFs/4mqlNf4LBYXL5brSMHYfr0DnuPCgBfe0WmDs4GR8v923\ngWACNAxaMqPVSArfRUdooVGFHrAWj1jYYjJaixoGQQTdL0IbqEkAwZ3e0t8Xj1jY/PWEywa/7tWA\nLHGvHesyhWJk8nDkxve7KmsGAcDfR7z8i/SDwvXL1TcKdBAWuwsffsetijVheqDAYDsfxpGabOG7\nw+uUVX/VUhrZoCmtlKpRN/8nbc1KV6H+6FpKA4fXKeQS+O8XFaKAm8zp7Zt1EhKFsrWrb/ECg8HV\nZ7eXmaR+AQ2DIIirVlgArSvHraDQFq4bgeFwS0t6B3d+fnzi/4TP58wXZL/paJ3MHCQ1SbW0BkBr\nCPOJaOkAACAASURBVKWBPDXgUewqy0NuAudIlwktUgWtSouHsu9Bo9+aFTpaC5qg4GG9wjHJ4YmY\nlDYWWSGWkQwG79y/Gh29MpOUMrNWUOhwrnmB4fYwUNEkHE5RSJTUNLR4HL9GL0+YSi8brLWq9rUV\nEyEERlJ4Au7oMVn4Lh0keS1iYEK/gOO42bABVpdV8FUQBIFxqaMD9m2+Xb6lYf1WobsakIXVXoUm\nMwWFa41r2um9+3gFfv/3bThxsQ4Ol6hhnCura+YoDtbPBKOndbLBWku3rxO4tQllJEEGXew+GCkR\nnZEYZKH7S4G/1tVokqJbmemtoKDQPlzT07LvdhcAALYfKcONfRK50FnCK5T8YL0UtxBSK9CpdDKn\nt1atgpCk0Q6QAdWLQsOvatfSalqzet4bIPguFT5K6qrUMCRC4kplWCsoXE9c0wJDit3phDZnGwja\nDdbLDTSsR9VqgRFG62UmEJqkMWt8ulCBtq2EMkk1R0sO2PZwBMfoolFiK7uiWcyhkN5fe5YKV1BQ\nCM51IzAsDjsImouOEoSE34p4BAjZjJwmaSE7WK/SyWaxNEnjhj6JaC8uRcPg8V+voiOYkTEVJl0M\nbk0Z2eHXulQUrUJB4ZfluvnHNbnk5iOWIQG/Nbf91zKWRkX51w9qb5v55WgYWdHd27UNwYhQGzC1\n24QWFz1SUFC49rluBQYYEp4KMbO7d2x2gMCQRkX5F2Vr77j/SzGpdI9KQ7gqTFj0RkFBQeGXoMNN\nUj/99BMWLVoElmUxffp0zJ49W/Z7eXk5nnvuOVitVjAMg3nz5uGmm25ql2uzhAf8ehNNbr9kPYZC\ndlQvPDLiTqF0xuv73gVgFnaROrlV5KXXYboULsUk9Xi/3wuObwUFBYVfig7VMBiGwcKFC7F8+XJ8\n9913WLduHfLz82X7fPDBBxg/fjxWr16Nt956Cy+//HK7XLvR3YiGtG+hSjsKsCwcnkCTVFyUDhpa\nDYIgQBBEwBKl0sJ+ejr4uhNtha//dCkmH4IgrvulMRUUFH55OlTDOHr0KFJSUtCpE+ecnTBhAjZv\n3oz09HRhH4IgYLNxa0xbLBbEx7ctb4CnxMYtpUrHlsNTx8LrdsvuNkKnxR3D02XH3J1xO0z6GKy9\nsBEAV8jt2YF/xMWGIsToomX7tlexu+cGzsWR6uPIjslol/Ndb/y218wW63QpKCi0Dx0qMCorK5GY\nKEYSxcfH49ixY7J9HnvsMcyaNQufffYZHA4HPv7443a5tsVlFT5X06cR5omS/R4drg8oGKim1BiX\nOhqbCrfC6XWBJmikRHRGSkTngPO3VzG+hLA4JIQFlkVXaB394npf6SYoKFw3dKjAaE39oXXr1mH6\n9Ol48MEHcfjwYTzzzDNYt25di8eZTM0nrTmq7cLnuoj9sJcMBiSH6DTakOfQqrRwel3QazUh94mN\njoApuvk2/FK01BfXE0pfiCh9IaL0RfvQoQIjISEBZWVlwvfKykrExcXJ9lm1ahWWL18OAMjJyYHT\n6URdXR2io+UmIH+qq63N/l5YUyb7bnXZIXVbE14y5DlI1ldwz0OE3MdS70C1t/k2/BKYTIYW++J6\nQekLEaUvRJS+EGmr4OxQ42/v3r1RVFSE0tJSuFwurFu3DqNHy4vfJSUlYdeuXQCA/Px8uFyuFoVF\na6h1mGXfNTp5RrfUoe0PnxDWnNlJSRpTUFC43uhQDYOiKCxYsACzZs0Cy7K44447kJ6ejvfeew+9\ne/fGyJEj8dxzz+GFF17AJ598ApIk8cYbb7TLtZ1eeVRUdrcwnBCtVOgcHjpLmmqFwGhrjSYFBQWF\nXxsdnocxYsQIjBgxQrZt7ty5wuf09HT85z//affrerxyjYJWuwGJwMhsZk0IoRx4u7dKQUFB4dfL\nNWtXaWiULyjkJpwAAJMuBrG6GKQGiXziEUp6B0mOG5k8XDiPgoKCwvXENVt80O50Qerltrs59WJq\nt4noa+rZ7LG8wPAGERh39JiM6d0nKdVRFRQUrjuuSQ3D4fLAw8hNUjafwGhN/kRzGgaglNJWUFC4\nPrkmBUZJlR0g5E5pm5vLJm9NlVmqBYGhoKCgcD1yTQqM4iorCEI+2PNRU63RMAhFYCgoKCgEcE0K\nDLPNGaBh8LSmBhTVjA9DQUFB4XrlmhQYDTYXQLAwaePw0pBnZb+1hw9DQUFB4Xrk2hQYdhdAMFBT\ndLOLIoWiiyEZAJBsSGphTwUFBYXrh2syrLbB7gKMLGiSChQYRMu3PL7rGMSHxaGfSamEqqCgoMBz\nTQoMi90FgmBAkRRokoaaUsN1CU5vNaXCDYkDO7qZCgoKCr8qrjmTFMOysNidACE6r/W0uB63Slmp\nTkFBQeGyuOYEhr3JLUQ38cuoSgVGey18pKCgoHC9cc0JDN7hDUBY91q6XnZ7rcWtoKCgcL1xjQoM\nLgfDX8NQkbQgRBQUFBQULo1rTmBYbFKBwd2emuKqEOppfcjjFBQUFBSa55oTGMFMUg4vV+pcr9KF\nPE5BQUFBoXk63AP8008/YdGiRWBZFtOnT8fs2bNlv7/22mvYu3cvCIJAY2MjzGYz8vLyLvt69TYn\nCD+TVKO7CYDc+a2goKCgcGl0qMBgGAYLFy7EJ598gri4ONxxxx0YPXo00tPThX3mz58vfP78889x\n6tSpNl3TItEw+BIf3aPSkN9QgF6xWW06t4KCgsL1TIcKjKNHjyIlJQWdOnHrZ0+YMAGbN2+WCQwp\n3333HR5//PE2XVPu9OYExviuY5Ae1RVZzSzLqqCgoKDQPB3qw6isrERiYqLwPT4+HlVVVUH3LSsr\nQ2lpKYYMGdKma9bbHdBm7QMg+jAokkJ2TIay8JGCgoJCG+hQDYNlg5cYD8a6deswduzYVg/qJpMh\n6Hartw5Qcet3h+t1Ife7lrge7rG1KH0hovSFiNIX7UOHCoyEhASUlZUJ3ysrKxEXFxd03/Xr1+Ol\nl15q9bmrq60B29weBvZGBny5QZfDG3S/awmTyXDN32NrUfpCROkLEaUvRNoqODvUJNW7d28UFRWh\ntLQULpcL69atw+jRowP2u3DhAiwWC3Jyctp0PWujS/ad92EoKCgoKLSdDtUwKIrCggULMGvWLLAs\nizvuuAPp6el477330Lt3b4wcORIAp11MmDChzdeTOrwBgFSyuhUUFBTajQ7PwxgxYgRGjBgh2zZ3\n7lzZ98cee6xdrtVgkwsMpW6UgoKCQvtxTdls6u1OIQcDUExSCgoKCu3JNTWiXii1yE1SisBQUFBQ\naDeumRGVYVkcvVALvVa8JTfjuYItUlBQULi2uGYExsot52Gxu9A1SQwbczPuK9giBQUFhWuLa0Zg\n7G/4Cepuh3BzPzGz3O1VBIaCgoJCe3HNrFfaFHUGFAC1WswUVzQMBQUFhfbjmtAwjteIFW6bPA7h\ns0sRGAoKCgrtxq9eYJRbq/HB0Y+F702eJuFzj6jgVXEVFP6/vTsPbKpKHz7+TdK0LC2b3QCZikVB\nsAqoLMKUdYChBVoBFas4U6SAQNlEFgXGqQNYmAr8FBVBQUBRXwGFMOpYQUAqKIIwLDrgQGmRlq3Q\njaTJPe8fLSmhQFJoUtM+n79yb05Ozn2g98k5595zhRDl5/UJI+PsRYftgpIeRmTjjrQLbVsZTRJC\niCrJacLIysryRDtumk6vOWxv+PVzAG73byTLmQshRAVymjAGDhzI2LFjSUtL80R7yu16V0LJOlJC\nCFGxnCaMr7/+mh49erBgwQL69u3L6tWrycvL80TbXHLJZr7mflkWRAghKpbTs6qvry8xMTF8+OGH\nvPzyy7z99ttERkaSlJTE2bNnPdHGGzJbr93DkIQhhBAVy6WzamZmJv/85z+ZNGkSHTt2ZOnSpdx2\n220MGzbM3e1zymIrfgaGKvJ12C8r1QohRMVyeuPeyJEj+eWXX3j88cdZu3Yt9evXB6Bt27Zs2rTJ\n7Q10xlwyh6Hl18VQ77R9vyw8KIQQFctpwhgwYAC9evXCYCj7i33jxo1Ov2Dr1q3Mnj0bpRQDBw4k\nISGhTJlNmzbx+uuvo9frad68OfPnz3ex+Vf2MIwO+w0y6S2EEBXKacKoW7cuBQUFBAQUL+p38eJF\nDhw4QMeOHZ1WrmkaSUlJLF++nODgYAYNGkSPHj0IDy+9oe748eMsXbqUDz/8EH9/f86dO1euA7CU\n9DCU1c9hvwxJCSFExXI6bpOcnIy/v79929/fn+TkZJcq37dvH2FhYTRu3Bij0UhUVBSpqakOZT76\n6COeeOIJ+3c0aNCgPO2nSCt5jvdVcxgyJCWEEBXL6VlVKeVwA5xer8dms7lUeVZWFg0blq4eGxIS\nQnZ2tkOZY8eO8b///Y8hQ4bw+OOPs23bNlfbDpSuF9Whxe0O+6WHIYQQFcvpkFTt2rX56aefuP/+\n+wH46aefqFWrlkuVK6WclrHZbKSnp7N69WpOnjxJXFwcJpPJoVdzI5dv3PMzOA5JSQ9DCCEqltOE\nMXnyZEaPHk2zZs0AOHLkCK+99ppLlYeGhnLy5En7dlZWFsHBwQ5lQkJCaNOmDXq9nttvv52mTZty\n7Ngx7r333hvWHRRU8qAkHw3McGeDO6gd/Ef+fbS4hxLYIICgBgE3qKHqsMdCSCyuILEoJbGoGE4T\nRps2bTCZTOzduxelFG3atKFu3bouVR4REUF6ejqZmZkEBQVhMplISUlxKNOzZ09MJhMxMTGcO3eO\n48eP06RJE6d1nz6dC0ChuXixwaJLiu53dbUnjIsXLnHalutSO71ZUFCAPRbVncSilMSilMSi1K0m\nTpceoFS3bl26dOlS7soNBgMzZswgPj4epRSDBg0iPDycRYsWERERQbdu3fjjH//It99+S1RUFAaD\ngeeff97lhARgVcXP7a7h44tRX3o4cqe3EEJULKcJ4/Dhw8yaNYvDhw9jsVjs+w8dOnSDT5WKjIwk\nMjLSYV9iYqLD9tSpU5k6dapL9V3NqornMHx9jPjoSg9HL5PeQghRoZz+DP/b3/7G+PHjCQsL45tv\nviEhIYEJEyZ4om0usSorStPhZ/DBx6GHIQlDCCEqktOEYbFY6NixI0opgoODmTBhQrkvfXUnqyoC\nzYDBoHe4/NeglyEpIYSoSE7PqvqSE2/dunU5fPgw58+fJzMz0+0Nc5WN4oThY3A8FOlhCCFExXI6\nhxEVFcX58+dJSEhgyJAhaJpWZg6iMtlUEcrmg4/B8el6ch+GEEJUrBsmDE3T6NixI/Xr1ycyMpJd\nu3ZhNptdvqnOE2xYQfPFUKaHIQlDCCEq0g3Pqnq9nhdeeMG+bTQaf1fJQlMams6KshnK9DBkSEoI\nISqW05/h4eHhZGRkeKIt5VakFd+DgWbAoJchKSGEcCencxjnzp2jf//+PPDAAw5rSC1cuNCtDXOF\nueR53sVzGI4JQhKGEEJULJcmvaOiojzRlnIzW0tuJLziKqm764XzS85Rh0tshRBC3DqnCSM2NtYT\n7bgpFq00YVwekkpsk4CmtEpslRBCVE1OE0ZiYuI1f63/voakSnsYOp1OJryFEMINnCaMbt262V+b\nzWa++OILh0esViaz7XIPwweDQYaghBDCnco9JPXII48watQotzWoPC4nDJ1mQC9zFkII4VblvpRI\np9P9bi6ztZQkDKPeWMktEUKIqq9ccxhKKX7++Wc6duzo9oa54pK1eA4joIZrj4wVQghx88o1h2Ew\nGIiPj6d169ZubZSrLhQUAFDPxWeMCyGEuHluv6x269atzJ49G6UUAwcOJCEhweH9devWkZycTGho\nKABxcXEMGjTIpbovFBQCUN+/5i21UQghhHNO5zCGDBnChQsX7Ns5OTnExcW5VLmmaSQlJbFs2TI2\nbtyIyWTi6NGjZcpFRUWxbt061q1b53KyALhYWDwk1cBfehhCCOFuThNGQUGBwzO269WrR15enkuV\n79u3j7CwMBo3bozRaCQqKorU1NQy5ZRS5WhyqcKi4oRRt1aNm/q8EEII1zlNGJqmUVAyVwCQn5+P\nzWZzqfKsrCwaNmxo3w4JCSE7O7tMuS+//JIBAwYwbtw4Tp065VLdAFatuB1+Bqcja0IIIW6R0zNt\ndHQ08fHxDBkyBIAPPviA/v37u1S5Kz2H7t27Ex0djdFoZM2aNUyZMoUVK1a4VL+1ZLVaPx8/l8oL\nIYS4eU4TxogRIwgODubrr79GKcXjjz9OTEyMS5WHhoZy8uRJ+3ZWVhbBwcEOZa4c7nr00UeZP3++\nS3UHBQWAQYENgm+rU7xdTVXnY7+axKKUxKKUxKJiuDSWExsbe1NXS0VERJCenk5mZiZBQUGYTCZS\nUlIcypw+fZqgoCAAUlNTadasmUt1nz6di7nIAnowFxRx+nRuudtXFQQFBVTbY7+axKKUxKKUxKLU\nrSZOp3MYY8eOJScnx759/vx5xo0b51LlBoOBGTNmEB8fT3R0NFFRUYSHh7No0SI2b94MwMqVK4mO\njiYmJoZVq1YxZ84clxtvU8VzGDWMcqe3EEK4m9MexokTJ6hXr559u379+qSnp7v8BZGRkURGRjrs\nS0xMtL+eOHEiEydOdLm+K11OGH4+vjf1eSGEEK5z2sOw2WwOV0UVFRVhsVjc2ihXaap40lt6GEII\n4X5OexidO3dmwoQJDB06FIAVK1aU6TFUFhsyJCWEEJ7iNGFMnDiRt956i7lz5wLFa0u1b9/e7Q1z\nhYYNpekx+sgDk4QQwt2cDkkZjUbGjBnD66+/zp/+9Cc+++wzpk+f7om2OaVhA01vfzyrEEII97lh\nD8NqtfL111/zySefsHfvXqxWK8uWLfvdrFar0EDpr/kIWSGEEBXruj2MOXPm0LVrV9asWUN0dDTf\nfPMNdevW/d0kCwCFDZ0q9zOghBBC3ITr9jA++OAD2rRpQ0JCAh06dAD43f2SVzoNNJm/EEIIT7hu\nwti+fTsbNmwgOTmZCxcuEBMT4/Kig56idDZ0yBVSQgjhCdcdz6lTpw5xcXGsXbuW119/nQsXLnDp\n0iXi4uJYs2aNJ9t4fToNPdLDEEIIT3BpAqBFixa8+OKLbNu2jbi4uGs+06JS6DR0ShKGEEJ4Qrke\nJGE0Gunbty99+/Z1V3tcpikNdEp6GEII4SFee4lRka0IQBKGEEJ4iNcmjEtFJQlDJwlDCCE8wWsT\nRmFR8QKIBulhCCGER3htwrhklR6GEEJ4ktcmDHPJkJSPrlzz9kIIIW6S2xPG1q1b6dOnD71792bJ\nkiXXLff555/TokULDhw44FK9l6wlQ1J66WEIIYQnuDVhaJpGUlISy5YtY+PGjZhMJo4ePVqmXH5+\nPqtWrSrXOlXSwxBCCM9ya8LYt28fYWFhNG7cGKPRSFRU1DVv+lu4cCHDhw/HWI4HIZlLLqv1kR6G\nEEJ4hFsTRlZWFg0bNrRvh4SEkJ2d7VDm0KFDnDp1ii5dupSrbnPJpLdBLz0MIYTwBLeebZVSTt+f\nPXs2r7zyisufuczoV5zravv5ERQUcPONrAKq+/FfSWJRSmJRSmJRMdyaMEJDQzl58qR9Oysri+Dg\nYPt2fn4+R44c4amnnkIpxZkzZ3j22Wd54403aNWq1Q3rPncxHwBl03H6dK57DsALBAUFVOvjv5LE\nopTEopTEotStJk63JoyIiAjS09PJzMwkKCgIk8lESkqK/X1/f3/S0tLs20899RTTpk2jZcuWTusu\nKhmSMsqQlBBCeIRbz7YGg4EZM2YQHx+PUopBgwYRHh7OokWLiIiIoFu3bg7ldTqdy0NSFs0KgNEg\nCUMIITzB7WfbyMhIIiMjHfYlJiZes+x7773ncr0W2+UehjxASQghPMFr7/S2lvQwfKWHIYQQHuG1\nCaPIVpIwfKSHIYQQnuC9CUN6GEII4VFemzAuD0n5SQ9DCCE8wusThgxJCSGEZ3hvwlDFCaOGJAwh\nhPAIr00YNs0GgJ+PbyW3RAghqgevTRiXexgyhyGEEJ7htQnDpop7GDXKsSS6EEKIm1cFEoYMSQkh\nhCd4bcLQkB6GEEJ4kvcmjJI5jJoy6S2EEB7hvQkDDaXA6CN3egshhCd4ccKwgfLa5gshhNfx2jOu\n0mnoNENlN0MIIaoN700Y0sMQQgiPcvsZd+vWrfTp04fevXuzZMmSMu+vWbOGfv36ERMTQ1xcHEeP\nHnWpXqWThCGEEJ7k1jOupmkkJSWxbNkyNm7ciMlkKpMQ+vXrx4YNG1i/fj3Dhg1jzpw5LtWtdBo6\nJUNSQgjhKW5NGPv27SMsLIzGjRtjNBqJiooiNTXVoUzt2rXtrwsKCtDrXWySTkPnvSNqQgjhddx6\nTWpWVhYNGza0b4eEhLB///4y5VavXs3y5cuxWq2sWLHCtcp1NulhCCGEB7k1YSilXCoXFxdHXFwc\nJpOJxYsXM3fuXOcf0mnodQaCggJusZXeT2JQSmJRSmJRSmJRMdyaMEJDQzl58qR9Oysri+Dg4OuW\n79u3L7NmzXJar02zgQ50Ss/p07kV0lZvFRQUUO1jcJnEopTEopTEotStJk63TgJERESQnp5OZmYm\nFosFk8lEjx49HMocP37c/nrz5s3ccccdTustshUByByGEEJ4kFt7GAaDgRkzZhAfH49SikGDBhEe\nHs6iRYuIiIigW7durFq1irS0NIxGI3Xq1OGVV15xWq+lJGHokTkMIYTwFLcvxBQZGUlkZKTDvsTE\nRPvrF154odx1mq2SMIQQwtO8ckznUpEFkIQhhBCe5JUJw1xUvLS5JAwhhPAcr0wYl6zFPQyDThKG\nEEJ4ilcmDPschiQMIYTwGO9MGCVzGAYZkhJCCI/xzoRhLZ7DMOjkaXtCCOEpXpkwLPY5DEkYQgjh\nKd6ZMGyXexgyJCWEEJ7ilQnDXHKnt49eehhCCOEpXpkwikqukpIehhBCeI5XJozLQ1LSwxBCCM/x\n0oRRPOktCUOI6iMvL4916/7fTX32+efHk5+fV8Etqn68MmEUXe5hyJCUENVGbu5F1q37+JrvaZp2\nw88mJy+gdm1/dzTrlrn6oLnfA6/8iX55eXOjwVjJLRFCeMqbb77GyZOZxMfH8eCD7enYsRPvvvs2\nt90WyJEjv7By5UdMm/Ycp09nY7GYGTx4CP36xQAweHB/li1bSUFBAc89l0hERGv+85+fCAoKYe7c\nf+Lr6+vwXd9+u40VK5ZhtVqpW7cuM2e+TP369SksLOTVV5P5+edD6HR6/vrX4XTp0o3vvtvBkiWL\n0TSNevXqsWDBYt55Zwm1atXi8cefBGDo0MdITl4IKJ57LpE2bR7kwIH9zJkzn5Url/Pzzwcxm810\n7dqD+PgEAA4dOsCiRf+ksPASvr6+LFiwmMmTxzFhwvM0a3YXAKNGDWPy5GnceWczt/8beHfC0EvC\nEKIyfPT1Eb4/nF2hdT7UIphHu1//pDdq1FiOHfuVd95ZDcCePbs5dOggK1d+RGhoKADTp88iICAA\ns9nM8OFD6dKle8lT5nT2ejIyTvDSS3OYMuUFZs6cxpYtX9OrVx+H77r//jYsWbIcgI0b1/P+++8x\nevQ4li9fSkBAACtWrAGKh8lycnJITv4HixcvIzQ0lNzcaz/dT6crbcOJE+m88MLfmDRpCgAjRowm\nICAATdMYN24Uv/56hD/84Q5mzZpOUtIrNG/egoKCAvz8/OjXL4ZNmz4jMXESJ06kY7UWeSRZgJcm\njCKteEjKKHMYQlRrLVu2sicLgI8+ep9t274BIDs7m4yMdMLDGwOlwz4NGzYiPLz4BNu8eQtOnTrJ\n1bKzTzFz5gLOnj2D1WqlYcNGAPzwwy7+/vc59nL+/v58++022rRpa29HQMC1H4N65dBTSEgo99zT\nyr6dmvoFn322HpvNxrlzZ/nf//4HQGBgEM2btwCgVq1aAHTr1oPly5cxevR4TKbP+POf+7kYrVvn\n9jPu1q1bmT17NkopBg4cSEJCgsP7y5cv5+OPP8bHx4cGDRowe/ZsGjZseMM6i0omvWVISojK8Wj3\nZjfsDXhKjRo17K/37NnNjz/+wJIly/H19WXs2BFYLJYyn7ly+EmvN1yzzKuvzmPIkKd4+OHO7Nmz\nm3fffRu49nzD9eYgDAYDmlb63pXfU7NmTfvr3347yZo1q1m2bCW1a/sze/ZLWCxmrje14edXg4ce\nas+2bVvYvPkrli5dee2CbuDWSW9N00hKSmLZsmVs3LgRk8nE0aNHHcq0bNmStWvX8umnn9KrVy+S\nk5Od1nu5h+ErQ1JCVBu1atWioKDguu/n5+cREBCAr68vx48f48CB/1yznCuTzPn5+QQGBgLwr39t\ntO9v164Dn3zyoX07NzeXe++9j71793Dq1G8AXLx4ESjuyfzyy2EAfv75ML/9VtqTubIN+fn51KxZ\nk1q1anPu3Fm++24HAGFhd3D27BkOHz4EQEFBgX1yPzp6AAsWzOeee1pdt0fjDm7tYezbt4+wsDAa\nN24MQFRUFKmpqYSHh9vLtGvXzv66devWbNiwwWm9Vq14DsNXehhCVBt16tQlIuJ+nn76cdq3f5iO\nHTs5vN++/cOsX/8Jf/nLE/zhD2Hce2/EFe+Wzh9cOZdwPfHxw3nxxSkEB4fQsuW99mTw9NPDSEl5\nhaFDH8NgMPDXvyYQGdmV559/genTn0MpRf36DUhJeY0uXbrz+ecm4uPjaNGiJU2ahF2zDc2a3cVd\ndzXnqaceo1Gjxtx33/0A+Pj48NJLc3j11WTMZjM1atRgwYLF1KhRg+bNW1C7dm2iojw3HAWgU268\npuuLL75g+/btJCUlAfDpp5+yf/9+XnzxxWuWT0pKIigoiJEjR96w3omfzifj0lGGhIyhc6s/VHi7\nvUlQUACnT197kq26kViUkliUqoqxOHPmNImJI3n//U/K9bniCwBunlt7GOXJRZ9++ikHDhxg5Urn\n43GXexiB9evccgCqAolBKYlFKYlFqaoUi/Xr17Nw4UKmTZvm8eNya8IIDQ3l5MnScbusrCyCMd1e\n6gAAERdJREFUg4PLlNuxYwdLlixh1apVGI3Oh5msyopSUJhXVOV+OZRXVfz1dLMkFqUkFqWqWiw6\ndepBp049AMp9XLeaYNw66R0REUF6ejqZmZlYLBZMJhM9evRwKHPw4EFmzZrFG2+8Qf369V2q16pZ\nQTPg6yt3egshhKe4tYdhMBiYMWMG8fHxKKUYNGgQ4eHhLFq0iIiICLp168a8efMoLCxk3LhxKKVo\n1KgRixcvvmG9xQlDTw1JGEII4TFuvw8jMjKSyMhIh32JiYn21++++26567SqIpRmoIZREoYQQniK\nVy4+aFNWUHr8pIchhBAe45UJQ8MGmoEavrI0iBDVxa0sbw7w0UcfYDabK7BF1Y+XJgyZwxCiurnR\n8uau+PjjDzCbL1Vgi8rPZrNV6vffKq/8ia50GigDPgavzHdCiJtw9fLmzz6byPvvr2Tz5n9TVGQl\nMrIr8fEJXLp0iZkzp3L6dDaapjF27BiOHcvgzJnTjB07knr16rFw4RsOdS9fvpRvv92GxWLm3nvv\nY/Lk6QBkZmYwb95scnJyMBgMJCXNpVGjxqxevYIvv/wXer2eDh06MWLEaMaOHcGYMRNo3rwFFy7k\n8MwzQ/n448/41782smPHdiwWM5cumZk7959MnTqJvLxcrFYrw4ePpHPnLkDxMiRr1qxGr9cRHn4X\nEydO4emnh7BmzVoMBgMFBfkl2+swGDz/g9krEwaAXknvQojKsvbIRvZk76/QOtsER/BIs+jrvn/1\n8ubff/8dGRnpvP32eyilmDJlIj/9tJecnHMEBgaRnLwAgJo1dTz4oOLDDz/g//7vLerUqVOm7oED\nH+Mvf3kGgKSkmezYsZ2HH+7MSy+9yNChf6Vz5y4UFRWhaRrffbeD7du38vbb7+Hr63vd5cyvXI7k\nwIH9vPfeh/j7+6NpGnPmzKdWrVpcuJDDiBHF9f/661FWrVrOG2+8Q506dcjNzaVWrVq0bfsAaWnb\n6dy5C1999SVdu/aolGQBXpwwDPK0PSGqtV27dvL997uIj49DKUVh4SUyMtK5777WvP76Qt588zU6\nduxMz55/pLAwl+Ilzq+9+sTu3bt4//2VmM2XyM3N5c47w2ndui1nzpy2//q/fFPxDz/sIiqqn33V\nW1cW/3voofb4+xc/8U/TNN566zX27t2DXq/jzJnTnD9/jj17fqBr1x72hHa53ujoAbz//ko6d+7C\npk0bmDLl2ksreYLXJgy9ThYeFKKyPNIs+oa9AU9QSvHUU3+hf//YMu8tW7aKtLRveeut1/jll/0M\nHvzUdeuxWCykpCTzzjurCAwM4p13lpQsRX7t5FK85FHZBQwNBgNKafY6r3Tlcub//vfn5OTk8O67\nq9Hr9Qwe3B+z2XLdpZQiIu7n1KlX2Lv3RzRNo2nTO697LO7mtZMAvqqm80JCiCrj6uXN27fvgMn0\nGYWFhQAlv9TPc+bMGfz8/OjVqw9DhjzJwYMHSz5fm/z8/DL1WiwWdLri1XALCgrYsiXVXj44OIRt\n27YAUFRUhNl8iXbtir/38gR66XLmjTl8uPi7Nm/+6rrHkZeXR/36DdDr9fz44w/2lXAfeKAdmzd/\nxcWLFxzqBejduy9/+9sLREX1L3/gKpBX9jDMhx/iD/XvqOxmCCE86OrlzZ99NpFjx44xcuRfgeKE\nMmNGEhkZJ3j99YXo9Tp8fIz84x/Fq2X37x/Dc88lEhgY5DDp7e/vT79+sQwd+hgNGzZyeBLeiy++\nxLx5s1m69C2MRiNJSXNp374jR478wrBhQ/H1NdKhQycSEp5lyJA4ZsyYxhdf/IsHHnjousfRq1cf\npkyZyPDhQ2nWrDlhYU0BaNr0ToYOjWfMmAQMBgN33dWc6dNnlXzmzyxd+iY9e/aq8LiWh1uXN3eX\nfpM+pc1dgYwdeF9lN6XSVbWF1W6FxKKUxKJUVYjF5s1f8e2323jxxZduqZ7f9fLm7iR3eQshqoMF\nC+bx3XdpzJ+/sLKb4r0JI6Cmr/NCQgjh5caPn1zZTbDz2knv+gF+ld0EIYSoVrw2YdTzlx6GEEJ4\nktcmDOlhCCGEZ7k9YWzdupU+ffrQu3dvlixZUub9H374gUceeYRWrVrx5ZdfulxvPX9JGEII4Ulu\nTRiappGUlMSyZcvYuHEjJpOJo0ePOpRp1KgRc+fOpV+/fuWqu570MIQQwqPcepXUvn37CAsLo3Hj\nxgBERUWRmppKeHi4vUyjRo0A0OnK3mp/PQ1vq42fPG1PCCE8yq09jKysLBo2bGjfDgkJITs7+5br\nfXVCl1uuQwghRPm4NWG46yby2jVl4UEhhPA0tw5JhYaGcvLkSft2VlYWwcHBFVL3rd7iXpVILEpJ\nLEpJLEpJLCqGW3sYERERpKenk5mZicViwWQy0aNHj+uW98JlrYQQotpw++KDW7du5R//+AdKKQYN\nGkRCQgKLFi0iIiKCbt26sX//fsaMGcPFixfx8/MjKCiIDRs2uLNJQgghboJXrlYrhBDC87z2Tm8h\nhBCeJQlDCCGESyRhCCGEcInXJQxna1NVNdOnT+fhhx92WDrlwoULxMfH07t3b4YNG0ZubunTxF5+\n+WV69erFgAEDOHToUGU02S1OnTrF0KFD6du3L/369eO9994DqmcsLBYLgwcPJiYmhn79+vHaa68B\nkJGRwaOPPkrv3r2ZOHEiVqvVXn7ChAn06tWLxx57zOFS96pC0zRiY2MZOXIkUH1j0b17d/r3709M\nTAyDBg0CKvhvRHkRm82mevbsqTIyMpTFYlH9+/dXR44cqexmudX333+vDh48qKKjo+37kpOT1ZIl\nS5RSSr311ltq3rx5SimltmzZooYPH66UUmrv3r1q8ODBnm+wm2RnZ6uDBw8qpZTKy8tTvXr1UkeO\nHKmWsVBKqYKCAqWUUlarVQ0ePFjt3btXjRs3Tm3atEkppdTMmTPVBx98oJRSavXq1WrWrFlKKaVM\nJpMaP358pbTZnd599101adIkNWLECKWUqrax6N69u8rJyXHYV5F/I17Vw7hybSqj0Whfm6oqe/DB\nB6lTp47DvtTUVGJjYwGIjY21xyA1NZWYmBgA7r//fnJzczlz5oxnG+wmQUFB3HPPPQDUrl2b8PBw\nsrKyqmUsAGrWrAkU/2K2Wq3odDp27txJ7969geJYfPXVV4Dj/5fevXuTlpZWOY12k1OnTvHNN98w\nePBg+77vvvuuWsZCKYWmaQ77KvJvxKsShrvWpvI2586dIzAwECg+kZ47dw6A7OxsQkND7eVCQkLI\nysqqlDa6U0ZGBocPH+b+++/n7Nmz1TIWmqYRExNDp06d6NSpE02aNKFOnTro9cV/0qGhofbjvTIW\nBoOBOnXqkJOTU2ltr2izZ8/m+eefty9gev78eerWrVstY6HT6Rg2bBgDBw7k448/BqjQvxGveqa3\nkltGbuha8SnPKsDeID8/n8TERKZPn07t2rWve3xVPRZ6vZ7169eTl5fH6NGjyzw2AEqP9+pYKKWq\nTCy2bNlCYGAg99xzDzt37gSKj+/qY64OsQBYs2aNPSnEx8fTtGnTCv0b8aqE4c61qbzJbbfdxpkz\nZwgMDOT06dM0aNAAKP6FcOrUKXu5U6dOVan4WK1WEhMTGTBgAD179gSqbywu8/f356GHHuKnn37i\n4sWLaJqGXq93ON7LsQgJCcFms5GXl0fdunUrueUV48cff+Trr7/mm2++wWw2k5+fz+zZs8nNza12\nsYDiHgRAgwYN6NmzJ/v27avQvxGvGpIq79pUVcXVvwS6d+/O2rVrAVi3bp09Bj169GD9+vUA7N27\nlzp16ti7olXB9OnTadasGU8//bR9X3WMxblz5+xXuly6dIm0tDSaNWtG+/bt+fzzzwHHWHTv3p11\n69YB8Pnnn9OhQ4fKabgbTJw4kS1btpCamkpKSgrt27dn/vz51TIWhYWF5OfnA1BQUMD27du5++67\nK/RvxOuWBrnW2lRV2aRJk9i5cyc5OTkEBgYyduxYevbsybhx4/jtt99o1KgRCxcutE+M//3vf2fb\ntm3UrFmTOXPm0KpVq0o+goqxe/dunnzySe6++250Oh06nY4JEyZw3333MX78+GoVi59//pmpU6ei\naRqaptG3b19GjRrFiRMnmDhxIhcvXuSee+5h3rx5GI1GLBYLkydP5tChQ9SrV4+UlBRuv/32yj6M\nCrdr1y7eeecd3nzzzWoZixMnTjBmzBh0Oh02m41+/fqRkJBATk5Ohf2NeF3CEEIIUTm8akhKCCFE\n5ZGEIYQQwiWSMIQQQrhEEoYQQgiXSMIQQgjhEkkYQgghXCIJQ3i1Rx99lNjYWKKiomjVqhWxsbHE\nxsYyffr0ctf1zDPPuLTc9bRp09i7d+/NNLdcDh48yBdffOH27xHCVXIfhqgSMjMzGTRo0A1XH728\nVIS3+Pjjj0lLSyMlJaWymyIE4GVrSQlRHmlpacybN4/WrVtz8OBBRo8ezblz51i9erX9gTpTp06l\nXbt2AHTp0oXly5fTtGlTnnjiCdq0acOePXvIzs4mOjqa8ePHA/DEE0/w7LPP0rlzZyZPnoy/vz9H\njx4lKyuLtm3bMmfOHKB4bZ7nn3+e8+fP06RJE2w2G927d+exxx5zaOeZM2eYNGkS58+fB6Bz5848\n88wzLF68mIKCAmJjY2nfvj1Tp05lz549pKSkUFhYCEBiYiKRkZGkp6fzxBNPEB0dze7du7FYLMya\nNYu2bdt6JNaimriVh3UI8XuRkZGhOnTo4LBvx44dqmXLlmr//v32fVc+XObIkSOqa9eu9u3IyEj1\n66+/KqWUGjJkiJo0aZJSSqmLFy+qdu3aqYyMDPt727ZtU0op9dxzz6knn3xSFRUVKbPZrPr06aN2\n7typlFJq1KhR6u2331ZKKXXixAnVpk0btWbNmjJtX7p0qZo5c6Z9++LFi0oppT766CM1ceJEh7bH\nxMSos2fPKqWUOnXqlIqMjFR5eXnq+PHjqnnz5spkMtmPvWvXrspqtboeRCGckB6GqNLuvPNO7r33\nXvv2sWPHWLRoEdnZ2RgMBrKzs8nJyaFevXplPvvnP/8ZgICAAJo2bUp6ejqNGzcuU+5Pf/oTPj7F\nf0otW7YkPT2ddu3asXPnTl5++WUAbr/9dntP5mqtW7dm1apVzJ8/n4ceeojOnTtfs9zu3bvJyMhg\n2LBh9gUpDQYDJ06coFatWtSsWZO+ffsC0LFjRwwGA8eOHSM8PNzVcAlxQ5IwRJVWu3Zth+0JEyYw\na9YsunTpgqZp3HfffZjN5mt+1s/Pz/5ar9djs9nKVc7V5yw88MADrFu3jh07dvDJJ5+wdOlSVq5c\nWaacUopWrVqxfPnyMu+lp6eX2adpWpV61oOofN4zAyiEE8qF6zfy8vLsq5OuWbPmukmgIrRr186+\nrHRmZia7du26ZrmMjAz8/f3p27cvU6dO5T//+Q9Q/KyLy8uYA7Rt25YjR47www8/2Pft27fP/rqw\nsJBNmzYBxY8oBQgLC6vYgxLVmvQwRJXhyq/p6dOnk5CQQMOGDWnfvj0BAQHX/PzVdV3vvRuVmzFj\nBlOmTMFkMnHnnXfStm1bh++7LC0tjffeew+DwYBSiqSkJAA6derEihUriImJoUOHDkydOpXFixcz\nb948cnNzKSoqokmTJrz55psABAYG8t///pfBgwdjsVhISUnBYDA4jYkQrpLLaoVwE7PZjNFoRK/X\nk5WVxeDBg1m9ejVNmjSp8O+6fJXU9u3bK7xuIS6THoYQbvLrr78ybdo0lFJomsaECRPckiyE8BTp\nYQghhHCJTHoLIYRwiSQMIYQQLpGEIYQQwiWSMIQQQrhEEoYQQgiXSMIQQgjhkv8PZHg4l1eLyCQA\nAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "\u003cmatplotlib.figure.Figure at 0x7f96f7389490\u003e"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "with context.eager_mode():\n",
+ " durations = []\n",
+ " for t in range(burn_ins + trials):\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=max_steps,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 500)\n",
+ " test_ds = setup_mnist_data(False, hp, 100)\n",
+ " ds = tf.data.Dataset.zip((train_ds, test_ds))\n",
+ " start = time.time()\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = train(ds, hp)\n",
+ " if t \u003c burn_ins:\n",
+ " continue\n",
+ " train_losses[-1].numpy()\n",
+ " test_losses[-1].numpy()\n",
+ " train_accuracies[-1].numpy()\n",
+ " test_accuracies[-1].numpy()\n",
+ " duration = time.time() - start\n",
+ " durations.append(duration)\n",
+ " print('Duration:', duration)\n",
+ "\n",
+ "\n",
+ " print('Mean duration:', np.mean(durations), '+/-', np.std(durations))\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " print('test_accuracy', test_accuracies[-1])\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "default_view": {},
+ "last_runtime": {
+ "build_target": "",
+ "kind": "local"
+ },
+ "name": "Autograph vs. Eager MNIST benchmark",
+ "provenance": [
+ {
+ "file_id": "1tAQW5tHUgAc8M4-iwwJm6Xs6dV9nEqtD",
+ "timestamp": 1530297010607
+ },
+ {
+ "file_id": "18dCjshrmHiPTIe1CNsL8tnpdGkuXgpM9",
+ "timestamp": 1530289467317
+ },
+ {
+ "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
+ "timestamp": 1522272821237
+ },
+ {
+ "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
+ "timestamp": 1522238054357
+ },
+ {
+ "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
+ "timestamp": 1521743157199
+ },
+ {
+ "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
+ "timestamp": 1520522344607
+ }
+ ],
+ "version": "0.3.2",
+ "views": {}
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
index d62390494b..0702273fac 100644
--- a/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
+++ b/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb
@@ -570,7 +570,7 @@
" autograph.utils.set_element_type(numbers, tf.int32)\n",
" for i in range(n):\n",
" numbers.append(i)\n",
- " return numbers.stack() # Stack the list so that it can be used as a Tensor\n",
+ " return autograph.stack(numbers) # Stack the list so that it can be used as a Tensor\n",
"\n",
"\n",
"tf_f = autograph.to_graph(f)\n",
@@ -648,7 +648,7 @@
" if not is_prime:\n",
" continue\n",
" primes.append(i)\n",
- " all_primes = primes.stack()\n",
+ " all_primes = autograph.stack(primes)\n",
"\n",
" print('The prime numbers less than', n, 'are:')\n",
" print(all_primes)\n",
@@ -953,8 +953,9 @@
" train_accuracies.append(step_train_accuracy)\n",
" test_accuracies.append(step_test_accuracy)\n",
" i += 1\n",
- " return (train_losses.stack(), test_losses.stack(), train_accuracies.stack(),\n",
- " test_accuracies.stack())"
+ " return (autograph.stack(train_losses), autograph.stack(test_losses),\n",
+ " autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))"
],
"execution_count": 0,
"outputs": []
@@ -1236,7 +1237,7 @@
" cell_output, (state, output) = cell.call(ch, (state, output))\n",
" hidden_outputs.append(cell_output)\n",
" i += 1\n",
- " hidden_outputs = hidden_outputs.stack()\n",
+ " hidden_outputs = autograph.stack(hidden_outputs)\n",
" if training:\n",
" hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
" return hidden_outputs\n",
diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
index 324b23c24b..44532cb078 100644
--- a/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
+++ b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb
@@ -190,7 +190,6 @@
" self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n",
" self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n",
"\n",
- "\n",
" def _rnn_layer(self, chars, cell, batch_size, training):\n",
" \"\"\"A single RNN layer.\n",
"\n",
@@ -203,13 +202,12 @@
" Returns:\n",
" A Tensor of shape (max_sequence_length, batch_size, output_size).\n",
" \"\"\"\n",
- " hidden_outputs = []\n",
- " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n",
+ " hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n",
" state, output = cell.zero_state(batch_size, tf.float32)\n",
" for ch in chars:\n",
" cell_output, (state, output) = cell.call(ch, (state, output))\n",
" hidden_outputs.append(cell_output)\n",
- " hidden_outputs = hidden_outputs.stack()\n",
+ " hidden_outputs = autograph.stack(hidden_outputs)\n",
" if training:\n",
" hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
" return hidden_outputs\n",
@@ -223,7 +221,7 @@
"\n",
"\n",
" def call(self, inputs, training=False):\n",
- " \"\"\"The RNN model code. Uses Eager and \n",
+ " \"\"\"The RNN model code. Uses Eager.\n",
"\n",
" The model consists of two RNN layers (made by lower_cell and upper_cell),\n",
" followed by a fully connected layer with ReLU activation.\n",
@@ -243,7 +241,8 @@
" seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n",
"\n",
" # Grab just the end-of-sequence from each output.\n",
- " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n",
+ " indices = (length - 1, range(batch_size))\n",
+ " indices = tf.stack(indices, 1)\n",
" sequence_ends = tf.gather_nd(seq, indices)\n",
" return self.relu_layer(sequence_ends)\n",
"\n",
@@ -381,7 +380,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 107,
"metadata": {
"colab": {
"autoexec": {
@@ -392,9 +391,9 @@
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 10604,
+ "elapsed": 5454,
"status": "ok",
- "timestamp": 1524095272039,
+ "timestamp": 1529952160455,
"user": {
"displayName": "",
"photoUrl": "",
@@ -403,7 +402,7 @@
"user_tz": 240
},
"id": "2pg1AfbxBJQq",
- "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660",
+ "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb",
"slideshow": {
"slide_type": "-"
}
@@ -413,7 +412,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Eval loss at step 100: 0.0674834\n"
+ "Eval loss at step 100: 0.0705221\n"
]
}
],
@@ -423,8 +422,8 @@
" 'learning_rate': 0.01,\n",
"}\n",
"\n",
- "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv\"\n",
- "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv\"\n",
+ "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n",
+ "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n",
"data_dir = \"tmp/rnn/data\"\n",
"\n",
"regressor = tf.estimator.Estimator(\n",
@@ -457,7 +456,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 108,
"metadata": {
"colab": {
"autoexec": {
@@ -468,9 +467,9 @@
},
"colab_type": "code",
"executionInfo": {
- "elapsed": 7990,
+ "elapsed": 3432,
"status": "ok",
- "timestamp": 1524095280105,
+ "timestamp": 1529952163923,
"user": {
"displayName": "",
"photoUrl": "",
@@ -479,7 +478,7 @@
"user_tz": 240
},
"id": "dxHex2tUN_10",
- "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101",
+ "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2",
"slideshow": {
"slide_type": "slide"
}
@@ -491,12 +490,12 @@
"\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -507,12 +506,12 @@
"\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -520,15 +519,15 @@
{
"data": {
"text/html": [
- "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e"
+ "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -536,16 +535,16 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n",
- "//# sourceURL=js_71b9087b6d"
+ "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n",
+ "//# sourceURL=js_dc5d7f2784"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -553,16 +552,16 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_e390445f33"
+ "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_be7950150b"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e"
]
},
"metadata": {
"tags": [
- "outputarea_id1"
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -570,17 +569,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_241dd76d85"
+ "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_d0c3bd4eaa"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -588,17 +587,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_60c64e3d50"
+ "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_f10f6eba86"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -606,17 +605,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_14ea437cbd"
+ "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ff29697179"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -624,17 +623,17 @@
{
"data": {
"application/javascript": [
- "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_09294c2226"
+ "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_ff85295dc7"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fcd0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -642,17 +641,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_e5e8266997"
+ "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ed7aabfedb"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -660,17 +659,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_07a097f0ee"
+ "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_c86f8feaf4"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -678,17 +677,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_790d669ca8"
+ "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_4d0fde6662"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -696,17 +695,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_d30df771f0"
+ "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_3f66d52720"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -714,32 +713,32 @@
{
"data": {
"application/javascript": [
- "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_8a43a2da4b"
+ "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_375f5ae6d7"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
},
{
"data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n",
"text/plain": [
- "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e"
+ "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -748,17 +747,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_893ad561f4"
+ "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_34b0509660"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -766,17 +765,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
- "//# sourceURL=js_2d99e0ac17"
+ "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
+ "//# sourceURL=js_518a0f26fe"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -784,17 +783,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n",
- "//# sourceURL=js_5c19462e32"
+ "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
+ "//# sourceURL=js_17eb3ff612"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -802,17 +801,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_b9c8b7567b"
+ "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_99da807c8e"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -820,17 +819,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n",
- "//# sourceURL=js_fd05186348"
+ "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
+ "//# sourceURL=js_dee01cb4b6"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -838,16 +837,16 @@
{
"data": {
"text/html": [
- "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
+ "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
],
"text/plain": [
- "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e"
+ "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -856,17 +855,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
- "//# sourceURL=js_efef96e882"
+ "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
+ "//# sourceURL=js_8c378be329"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -875,17 +874,17 @@
{
"data": {
"application/javascript": [
- "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
- "//# sourceURL=js_6eca889864"
+ "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
+ "//# sourceURL=js_f0b946600c"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -894,17 +893,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n",
- "//# sourceURL=js_f02070cc60"
+ "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n",
+ "//# sourceURL=js_9e21b1373a"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -913,17 +912,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n",
- "//# sourceURL=js_ed9faba660"
+ "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n",
+ "//# sourceURL=js_a7764968c6"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -932,17 +931,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n",
- "//# sourceURL=js_f3458d7074"
+ "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
+ "//# sourceURL=js_74279d3ff0"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -951,17 +950,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
- "//# sourceURL=js_3ffd97bd6f"
+ "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
+ "//# sourceURL=js_82b6c34cdb"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1",
+ "id3_content_0",
+ "outputarea_id3",
"user_output"
]
},
@@ -970,17 +969,17 @@
{
"data": {
"application/javascript": [
- "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n",
- "//# sourceURL=js_7f73e8bcca"
+ "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
+ "//# sourceURL=js_ff6144734a"
],
"text/plain": [
- "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e"
+ "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
]
},
"metadata": {
"tags": [
- "id1_content_0",
- "outputarea_id1"
+ "id3_content_0",
+ "outputarea_id3"
]
},
"output_type": "display_data"
@@ -1043,28 +1042,6 @@
"kind": "local"
},
"name": "RNN Colorbot using Keras and Estimators",
- "provenance": [
- {
- "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl",
- "timestamp": 1523579810961
- },
- {
- "file_id": "1DcfimonWU11tmyivKBGVrbpAl3BIOaRG",
- "timestamp": 1523016192637
- },
- {
- "file_id": "1wCZUh73zTNs1jzzYjqoxMIdaBWCdKJ2K",
- "timestamp": 1522238054357
- },
- {
- "file_id": "1_HpC-RrmIv4lNaqeoslUeWaX8zH5IXaJ",
- "timestamp": 1521743157199
- },
- {
- "file_id": "1mjO2fQ2F9hxpAzw2mnrrUkcgfb7xSGW-",
- "timestamp": 1520522344607
- }
- ],
"version": "0.3.2",
"views": {}
},
diff --git a/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb
new file mode 100644
index 0000000000..e8f16b431d
--- /dev/null
+++ b/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb
@@ -0,0 +1,1093 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "qWUV0FYjDSKj"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "from tensorflow.contrib import autograph\n",
+ "\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "kGXS3UWBBNoc"
+ },
+ "source": [
+ "# 1. AutoGraph writes graph code for you\n",
+ "\n",
+ "[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large chunk of the Python language, which is growing. [Please see this document for what we currently support, and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n",
+ "\n",
+ "Here's a quick example of how it works:\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "aA3gOodCBkOw"
+ },
+ "outputs": [],
+ "source": [
+ "# Autograph can convert functions like this...\n",
+ "def g(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " else:\n",
+ " x = 0.0\n",
+ " return x\n",
+ "\n",
+ "# ...into graph-building functions like this:\n",
+ "def tf_g(x):\n",
+ " with tf.name_scope('g'):\n",
+ " \n",
+ " def if_true():\n",
+ " with tf.name_scope('if_true'):\n",
+ " x_1, = x,\n",
+ " x_1 = x_1 * x_1\n",
+ " return x_1,\n",
+ "\n",
+ " def if_false():\n",
+ " with tf.name_scope('if_false'):\n",
+ " x_1, = x,\n",
+ " x_1 = 0.0\n",
+ " return x_1,\n",
+ "\n",
+ " x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n",
+ " return x\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "I1RtBvoKBxq5"
+ },
+ "outputs": [],
+ "source": [
+ "# You can run your plain-Python code in graph mode,\n",
+ "# and get the same results out, but with all the benfits of graphs:\n",
+ "print('Original value: %2.2f' % g(9.0))\n",
+ "\n",
+ "# Generate a graph-version of g and call it:\n",
+ "tf_g = autograph.to_graph(g)\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " # The result works like a regular op: takes tensors in, returns tensors.\n",
+ " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
+ " g_ops = tf_g(tf.constant(9.0))\n",
+ " with tf.Session() as sess:\n",
+ " print('Autograph value: %2.2f\\n' % sess.run(g_ops))\n",
+ " \n",
+ " \n",
+ "# You can view, debug and tweak the generated code:\n",
+ "print(autograph.to_code(g))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "m-jWmsCmByyw"
+ },
+ "source": [
+ "#### Automatically converting complex control flow\n",
+ "\n",
+ "AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n",
+ "AutoGraph will automatically convert most Python control flow statements into their correct graph equivalent. \n",
+ " \n",
+ "We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "toxKBOXbB1ro"
+ },
+ "outputs": [],
+ "source": [
+ "# Continue in a loop\n",
+ "def f(l):\n",
+ " s = 0\n",
+ " for c in l:\n",
+ " if c % 2 \u003e 0:\n",
+ " continue\n",
+ " s += c\n",
+ " return s\n",
+ "\n",
+ "print('Original value: %d' % f([10,12,15,20]))\n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n",
+ " \n",
+ "print(autograph.to_code(f))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "FUJJ-WTdCGeq"
+ },
+ "source": [
+ "Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well! \n",
+ " \n",
+ "Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "IAOgh62zCPZ4"
+ },
+ "outputs": [],
+ "source": [
+ "def f(x):\n",
+ " assert x != 0, 'Do not pass zero!'\n",
+ " return x * x\n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " try:\n",
+ " print(tf_f(tf.constant(0)).eval())\n",
+ " except tf.errors.InvalidArgumentError as e:\n",
+ " print('Got error message:\\n%s' % e.message)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "KRu8iIPBCQr5"
+ },
+ "source": [
+ "You can also use plain Python `print` functions in in-graph"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ySTsuxnqCTQi"
+ },
+ "outputs": [],
+ "source": [
+ "def f(n):\n",
+ " if n \u003e= 0:\n",
+ " while n \u003c 5:\n",
+ " n += 1\n",
+ " print(n)\n",
+ " return n\n",
+ " \n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default():\n",
+ " with tf.Session():\n",
+ " tf_f(tf.constant(0)).eval()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "NqF0GT-VCVFh"
+ },
+ "source": [
+ "Appending to lists in loops also works (we create a `TensorArray` for you behind the scenes)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ABX070KwCczR"
+ },
+ "outputs": [],
+ "source": [
+ "def f(n):\n",
+ " z = []\n",
+ " # We ask you to tell us the element dtype of the list\n",
+ " z = autograph.utils.set_element_type(z, tf.int32)\n",
+ " for i in range(n):\n",
+ " z.append(i)\n",
+ " # when you're done with the list, stack it\n",
+ " # (this is just like np.stack)\n",
+ " return autograph.stack(z) \n",
+ "\n",
+ "tf_f = autograph.to_graph(f)\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session():\n",
+ " print(tf_f(tf.constant(3)).eval())\n",
+ "\n",
+ "print('\\n\\n'+autograph.to_code(f))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "iu5IF7n2Df7C"
+ },
+ "outputs": [],
+ "source": [
+ "def fizzbuzz(num):\n",
+ " if num % 3 == 0 and num % 5 == 0:\n",
+ " print('FizzBuzz')\n",
+ " elif num % 3 == 0:\n",
+ " print('Fizz')\n",
+ " elif num % 5 == 0:\n",
+ " print('Buzz')\n",
+ " else:\n",
+ " print(num)\n",
+ " return num"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "EExAjWuwDPpR"
+ },
+ "outputs": [],
+ "source": [
+ "tf_g = autograph.to_graph(fizzbuzz)\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " # The result works like a regular op: takes tensors in, returns tensors.\n",
+ " # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
+ " g_ops = tf_g(tf.constant(15))\n",
+ " with tf.Session() as sess:\n",
+ " sess.run(g_ops) \n",
+ " \n",
+ "# You can view, debug and tweak the generated code:\n",
+ "print('\\n')\n",
+ "print(autograph.to_code(fizzbuzz))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "SzpKGzVpBkph"
+ },
+ "source": [
+ "# De-graphify Exercises\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "8k23dxcSmmXq"
+ },
+ "source": [
+ "#### Easy print statements"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "dE1Vsmp-mlpK"
+ },
+ "outputs": [],
+ "source": [
+ "# See what happens when you turn AutoGraph off.\n",
+ "# Do you see the type or the value of x when you print it?\n",
+ "\n",
+ "# @autograph.convert()\n",
+ "def square_log(x):\n",
+ " x = x * x\n",
+ " print('Squared value of x =', x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_log(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "_R-Q7BbxmkBF"
+ },
+ "source": [
+ "#### Now some exercises. Convert the TensorFlow code into AutoGraph'd Python code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "SwA11tO-yCvg"
+ },
+ "outputs": [],
+ "source": [
+ "def square_if_positive(x):\n",
+ " x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n",
+ " return x\n",
+ "\n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "GPmx4CNhyPI_"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_if_positive(x):\n",
+ " ... # \u003c\u003c\u003c fill it in!\n",
+ " \n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qqsjik-QyA9R"
+ },
+ "source": [
+ "#### Uncollapse to see answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "DaSmaWUEvMRv"
+ },
+ "outputs": [],
+ "source": [
+ "# Simple cond\n",
+ "@autograph.convert()\n",
+ "def square_if_positive(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_if_positive(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qj7am2I_xvTJ"
+ },
+ "source": [
+ "#### Nested If statement"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "4yyNOf-Twr6s"
+ },
+ "outputs": [],
+ "source": [
+ "def nearest_odd_square(x):\n",
+ "\n",
+ " def if_positive():\n",
+ " x1 = x * x\n",
+ " x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n",
+ " return x1,\n",
+ "\n",
+ " x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "hqmh5b2VyU9w"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def nearest_odd_square(x):\n",
+ " ... # \u003c\u003c\u003c fill it in!\n",
+ " \n",
+ "with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "b9AXIkNLxp6J"
+ },
+ "source": [
+ "#### Uncollapse to reveal answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "8RlCVEpNxD91"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def nearest_odd_square(x):\n",
+ " if x \u003e 0:\n",
+ " x = x * x\n",
+ " if x % 2 == 0:\n",
+ " x = x + 1\n",
+ " return x\n",
+ "\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(nearest_odd_square(tf.constant(4))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "jXAxjeBr1qWK"
+ },
+ "source": [
+ "#### Convert a while loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "kWkv7anlxoee"
+ },
+ "outputs": [],
+ "source": [
+ "# Convert a while loop\n",
+ "def square_until_stop(x, y):\n",
+ " x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n",
+ " return x\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "zVUsc1eA1u2K"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_until_stop(x, y):\n",
+ " ... # fill it in!\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L2psuzPI02S9"
+ },
+ "source": [
+ "#### Uncollapse for the answer\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "ucmZyQVL03bF"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def square_until_stop(x, y):\n",
+ " while x \u003c y:\n",
+ " x = x * x\n",
+ " return x\n",
+ " \n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "FXB0Zbwl13PY"
+ },
+ "source": [
+ "#### Nested loop and conditional"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "clGymxdf15Ig"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " current_sum = 0.0\n",
+ " idx = 0\n",
+ " \n",
+ " for i in range(len(x)):\n",
+ " idx = i\n",
+ " if current_sum \u003e= threshold:\n",
+ " break\n",
+ " current_sum += x[i]\n",
+ " return idx\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "i7PF-uId9lp5"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " ...\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "weKFXAb615Vp"
+ },
+ "source": [
+ "#### Uncollapse to see answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "1sjaFcL717Ig"
+ },
+ "outputs": [],
+ "source": [
+ "@autograph.convert()\n",
+ "def argwhere_cumsum(x, threshold):\n",
+ " current_sum = 0.0\n",
+ " idx = 0\n",
+ " for i in range(len(x)):\n",
+ " idx = i\n",
+ " if current_sum \u003e= threshold:\n",
+ " break\n",
+ " current_sum += x[i]\n",
+ " return idx\n",
+ "\n",
+ "N = 10\n",
+ "with tf.Graph().as_default(): \n",
+ " with tf.Session() as sess:\n",
+ " idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
+ " print(sess.run(idx))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "4LfnJjm0Bm0B"
+ },
+ "source": [
+ "# 3. Training MNIST in-graph\n",
+ "\n",
+ "Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well! \n",
+ "\n",
+ "Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Em5dzSUOtLRP"
+ },
+ "source": [
+ "#### Download data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "xqoxumv0ssQW"
+ },
+ "outputs": [],
+ "source": [
+ "import gzip\n",
+ "import os\n",
+ "import shutil\n",
+ "\n",
+ "from six.moves import urllib\n",
+ "\n",
+ "\n",
+ "def download(directory, filename):\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " zipped_filepath = filepath + '.gz'\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [784])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8)\n",
+ " label = tf.reshape(label, [])\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def mnist_train(directory):\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte')\n",
+ "\n",
+ "def mnist_test(directory):\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "znmy4l8ntMvW"
+ },
+ "source": [
+ "#### Define the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "Pe-erWQdBoC5"
+ },
+ "outputs": [],
+ "source": [
+ "def mlp_model(input_shape):\n",
+ " model = tf.keras.Sequential((\n",
+ " tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
+ " tf.keras.layers.Dense(100, activation='relu'),\n",
+ " tf.keras.layers.Dense(10, activation='softmax')))\n",
+ " model.build()\n",
+ " return model\n",
+ "\n",
+ "\n",
+ "def predict(m, x, y):\n",
+ " y_p = m(x)\n",
+ " losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
+ " l = tf.reduce_mean(losses)\n",
+ " accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
+ " accuracy = tf.reduce_mean(accuracies)\n",
+ " return l, accuracy\n",
+ "\n",
+ "\n",
+ "def fit(m, x, y, opt):\n",
+ " l, accuracy = predict(m, x, y)\n",
+ " opt.minimize(l)\n",
+ " return l, accuracy\n",
+ "\n",
+ "\n",
+ "def setup_mnist_data(is_training, hp, batch_size):\n",
+ " if is_training:\n",
+ " ds = mnist_train('/tmp/autograph_mnist_data')\n",
+ " ds = ds.shuffle(batch_size * 10)\n",
+ " else:\n",
+ " ds = mnist_test('/tmp/autograph_mnist_data')\n",
+ " ds = ds.repeat()\n",
+ " ds = ds.batch(batch_size)\n",
+ " return ds\n",
+ "\n",
+ "\n",
+ "def get_next_batch(ds):\n",
+ " itr = ds.make_one_shot_iterator()\n",
+ " image, label = itr.get_next()\n",
+ " x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
+ " y = tf.one_hot(tf.squeeze(label), 10)\n",
+ " return x, y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "oeYV6mKnJGMr"
+ },
+ "source": [
+ "#### Define the training loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "3xtg_MMhJETd"
+ },
+ "outputs": [],
+ "source": [
+ "def train(train_ds, test_ds, hp):\n",
+ " m = mlp_model((28 * 28,))\n",
+ " opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
+ " \n",
+ " # We'd like to save our losses to a list. In order for AutoGraph\n",
+ " # to convert these lists into their graph equivalent,\n",
+ " # we need to specify the element type of the lists.\n",
+ " train_losses = []\n",
+ " train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n",
+ " test_losses = []\n",
+ " test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n",
+ " train_accuracies = []\n",
+ " train_accuracies = autograph.utils.set_element_type(train_accuracies, tf.float32)\n",
+ " test_accuracies = []\n",
+ " test_accuracies = autograph.utils.set_element_type(test_accuracies, tf.float32)\n",
+ " \n",
+ " # This entire training loop will be run in-graph.\n",
+ " i = tf.constant(0)\n",
+ " while i \u003c hp.max_steps:\n",
+ " train_x, train_y = get_next_batch(train_ds)\n",
+ " test_x, test_y = get_next_batch(test_ds)\n",
+ " # add get next\n",
+ " step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
+ " step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
+ " if i % (hp.max_steps // 10) == 0:\n",
+ " print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n",
+ " step_test_loss, 'train accuracy:', step_train_accuracy,\n",
+ " 'test accuracy:', step_test_accuracy)\n",
+ " train_losses.append(step_train_loss)\n",
+ " test_losses.append(step_test_loss)\n",
+ " train_accuracies.append(step_train_accuracy)\n",
+ " test_accuracies.append(step_test_accuracy)\n",
+ " i += 1\n",
+ " \n",
+ " # We've recorded our loss values and accuracies \n",
+ " # to a list in a graph with AutoGraph's help.\n",
+ " # In order to return the values as a Tensor, \n",
+ " # we need to stack them before returning them.\n",
+ " return (autograph.stack(train_losses), autograph.stack(test_losses), autograph.stack(train_accuracies),\n",
+ " autograph.stack(test_accuracies))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ },
+ "colab_type": "code",
+ "id": "HYh6MSZyJOag"
+ },
+ "outputs": [],
+ "source": [
+ "with tf.Graph().as_default():\n",
+ " hp = tf.contrib.training.HParams(\n",
+ " learning_rate=0.05,\n",
+ " max_steps=500,\n",
+ " )\n",
+ " train_ds = setup_mnist_data(True, hp, 50)\n",
+ " test_ds = setup_mnist_data(False, hp, 1000)\n",
+ " tf_train = autograph.to_graph(train)\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = tf_train(train_ds, test_ds, hp)\n",
+ "\n",
+ " with tf.Session() as sess:\n",
+ " sess.run(tf.global_variables_initializer())\n",
+ " (train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n",
+ " test_accuracies])\n",
+ " plt.title('MNIST train/test losses')\n",
+ " plt.plot(train_losses, label='train loss')\n",
+ " plt.plot(test_losses, label='test loss')\n",
+ " plt.legend()\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Loss')\n",
+ " plt.show()\n",
+ " plt.title('MNIST train/test accuracies')\n",
+ " plt.plot(train_accuracies, label='train accuracy')\n",
+ " plt.plot(test_accuracies, label='test accuracy')\n",
+ " plt.legend(loc='lower right')\n",
+ " plt.xlabel('Training step')\n",
+ " plt.ylabel('Accuracy')\n",
+ " plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [
+ "qqsjik-QyA9R",
+ "b9AXIkNLxp6J",
+ "L2psuzPI02S9",
+ "weKFXAb615Vp",
+ "Em5dzSUOtLRP"
+ ],
+ "default_view": {},
+ "name": "AutoGraph Workshop.ipynb",
+ "provenance": [
+ {
+ "file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP",
+ "timestamp": 1530563781803
+ }
+ ],
+ "version": "0.3.2",
+ "views": {}
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD
index 54424e2647..a5438592c3 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/contrib/autograph/impl/BUILD
@@ -18,18 +18,19 @@ py_library(
name = "impl",
srcs = [
"api.py",
- "config.py",
"conversion.py",
- "naming.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib/autograph/converters",
+ "//tensorflow/contrib/autograph/core",
"//tensorflow/contrib/autograph/operators",
"//tensorflow/contrib/autograph/pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis",
"//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
"@gast_archive//:gast",
"@six_archive//:six",
],
@@ -59,13 +60,3 @@ py_test(
"@gast_archive//:gast",
],
)
-
-py_test(
- name = "naming_test",
- srcs = ["naming_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":impl",
- "//tensorflow/python:client_testlib",
- ],
-)
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 24f87b2c14..c7401c7df1 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -27,14 +27,15 @@ import gast
import six
# pylint:enable=g-bad-import-order
-from tensorflow.contrib.autograph.impl import config
+from tensorflow.contrib.autograph.core import config
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
# TODO(mdan): Properly document the type hints.
@@ -70,6 +71,8 @@ def convert(recursive=False, verbose=False, arg_types=None):
def wrapper(*args, **kwargs):
return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)
+ wrapper = tf_decorator.make_decorator(f, wrapper)
+
# Sometimes the decorator is just desugared, making it impossible to detect.
# This attribute makes detection easier.
setattr(wrapper, '__pyct_is_compile_decorator', True)
@@ -230,20 +233,20 @@ def to_graph(e,
A function with a signature identical to `o`, but which when executed it
creates TF a graph that has the same functionality as the original entity.
"""
- conversion_map = conversion.ConversionMap(
+ program_ctx = converter.ProgramContext(
recursive=recursive,
- nocompile_decorators=(convert, do_not_convert, converted_call),
+ autograph_decorators=(convert, do_not_convert, converted_call),
partial_types=partial_types,
- api_module=tf_inspect.getmodule(to_graph))
- _, name, namespace = conversion.entity_to_graph(e, conversion_map, arg_values,
+ autograph_module=tf_inspect.getmodule(to_graph),
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+ _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
arg_types)
module = gast.Module([])
- for import_line in config.COMPILED_IMPORT_STATEMENTS:
- module.body.extend(parser.parse_str(import_line).body)
- for dep in reversed(conversion_map.dependency_cache.values()):
+ for dep in reversed(program_ctx.dependency_cache.values()):
module.body.append(dep)
- compiled_node, compiled_src = compiler.ast_to_object(module)
+ compiled_node, compiled_src = compiler.ast_to_object(
+ module, source_prefix=program_ctx.required_imports)
# The compiled code should see everything the entry entity saw.
# TODO(mdan): This might not work well if the call tree spans modules?
@@ -280,17 +283,16 @@ def to_code(e,
Returns:
String.
"""
- conversion_map = conversion.ConversionMap(
+ program_ctx = converter.ProgramContext(
recursive=recursive,
- nocompile_decorators=(convert, do_not_convert, converted_call),
+ autograph_decorators=(convert, do_not_convert, converted_call),
partial_types=partial_types,
- api_module=tf_inspect.getmodule(to_graph))
- conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)
+ autograph_module=tf_inspect.getmodule(to_graph),
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+ conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
- imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)
code = '\n'.join(
compiler.ast_to_source(dep, indentation)
- for dep in reversed(tuple(
- six.itervalues(conversion_map.dependency_cache))))
+ for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
- return imports + '\n\n' + code
+ return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index a7737b7f44..9943093332 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -21,12 +21,13 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.impl import config
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
@@ -154,6 +155,22 @@ class ApiTest(test.TestCase):
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
+ def test_decorator_preserves_argspec(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ called_member_converted = api.convert()(called_member)
+
+ tc = TestClass()
+ self.assertListEqual(
+ list(tf_inspect.getfullargspec(tc.called_member)),
+ list(tf_inspect.getfullargspec(tc.called_member_converted)))
+
def test_convert_call_site_decorator(self):
class TestClass(object):
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 55a30dc127..776d19f672 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""High level conversion support."""
+"""Core conversion logic, serves as main point of access."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
import imp
import gast
@@ -38,77 +37,23 @@ from tensorflow.contrib.autograph.converters import logical_expressions
from tensorflow.contrib.autograph.converters import name_scopes
from tensorflow.contrib.autograph.converters import side_effect_guards
from tensorflow.contrib.autograph.converters import single_return
-from tensorflow.contrib.autograph.impl import config
-from tensorflow.contrib.autograph.impl import naming
+from tensorflow.contrib.autograph.converters import slices
+from tensorflow.contrib.autograph.core import config
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis import live_values
from tensorflow.contrib.autograph.pyct.static_analysis import type_info
-from tensorflow.contrib.autograph.utils import type_hints
from tensorflow.python.util import tf_inspect
# TODO(mdan): Might we not need any renaming at all?
-class ConversionMap(object):
- """ConversionMap keeps track of converting function hierarchies.
-
- This object is mutable, and is updated as functions are converted.
-
- Attributes:
- recursive: Whether to recursively convert any functions that the decorator
- function may call.
- nocompile_decorators: tuple of decorator functions that toggle compilation
- off.
- dependency_cache: dict[object]: ast; maps original entities to their
- converted AST
- additional_imports: set(object); additional entities which for any reason
- cannot be attached after loading and need to be explicitly imported
- in the generated code
- name_map: dict[string]: string; maps original entities to the name of
- their converted counterparts
- api_module: A reference to the api module. The reference needs to be passed
- to avoid circular dependencies.
- """
-
- # TODO(mdan): Rename to ConversionContext, and pull in additional flags.
-
- def __init__(self, recursive, nocompile_decorators, partial_types,
- api_module):
- self.recursive = recursive
- self.nocompile_decorators = nocompile_decorators
- self.partial_types = partial_types if partial_types else ()
- # Required to output dependencies in discovery order, which should match
- # the reverse dependency order.
- self.dependency_cache = collections.OrderedDict()
- self.additional_imports = set()
- self.name_map = {}
- self.api_module = api_module
-
- def new_namer(self, namespace):
- return naming.Namer(namespace, self.recursive, self.name_map,
- self.partial_types)
-
- def update_name_map(self, namer):
- for o, name in namer.renamed_calls.items():
- if o in self.name_map:
- if self.name_map[o] != name:
- raise ValueError(
- 'Calls to %s were converted using multiple names (%s). This is '
- 'possible when an entity with one of these names already '
- 'existed. To fix, avoid using any of these names.')
- else:
- self.name_map[o] = name
-
- def add_to_cache(self, original_entity, converted_ast):
- self.dependency_cache[original_entity] = converted_ast
-
-
def is_whitelisted_for_graph(o):
"""Check whether an entity is whitelisted for use in graph mode.
@@ -127,7 +72,7 @@ def is_whitelisted_for_graph(o):
return False
-def entity_to_graph(o, conversion_map, arg_values, arg_types):
+def entity_to_graph(o, program_ctx, arg_values, arg_types):
"""Compile a Python entity into equivalent TensorFlow.
The function will also recursively compile all the entities that `o`
@@ -138,7 +83,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types):
Args:
o: A Python entity.
- conversion_map: A ConversionMap object.
+ program_ctx: A ProgramContext object.
arg_values: A dict containing value hints for symbols like function
parameters.
arg_types: A dict containing type hints for symbols like function
@@ -156,7 +101,7 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types):
ValueError: if the entity type is not supported.
"""
if tf_inspect.isclass(o):
- node, name, ns = class_to_graph(o, conversion_map)
+ node, name, ns = class_to_graph(o, program_ctx)
elif tf_inspect.isfunction(o):
# TODO(mdan): This is not a reliable mechanism.
# The most reliable way is to check the source code, the AST will contain
@@ -166,36 +111,35 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types):
'lambda functions are not yet supported; declare the function'
' using def instead: %s' % o)
else:
- node, name, ns = function_to_graph(o, conversion_map, arg_values,
- arg_types)
+ node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
elif tf_inspect.ismethod(o):
- node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types)
+ node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
else:
raise ValueError(
'Entity "%s" has unsupported type "%s". Only functions and classes are '
'supported for now.' % (o, type(o)))
- conversion_map.add_to_cache(o, node)
- if conversion_map.recursive:
+ program_ctx.add_to_cache(o, node)
+ if program_ctx.recursive:
while True:
candidate = None
- for obj in conversion_map.name_map.keys():
- if obj not in conversion_map.dependency_cache:
+ for obj in program_ctx.name_map.keys():
+ if obj not in program_ctx.dependency_cache:
candidate = obj
break
if candidate is None:
break
if (hasattr(candidate, 'im_class') and
- getattr(candidate, 'im_class') not in conversion_map.partial_types):
+ getattr(candidate, 'im_class') not in program_ctx.partial_types):
# Class members are converted with their objects, unless they're
# only converted partially.
continue
- entity_to_graph(candidate, conversion_map, {}, {})
+ entity_to_graph(candidate, program_ctx, {}, {})
return node, name, ns
-def class_to_graph(c, conversion_map):
+def class_to_graph(c, program_ctx):
"""Specialization of `entity_to_graph` for classes."""
converted_members = {}
method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
@@ -210,7 +154,7 @@ def class_to_graph(c, conversion_map):
continue
node, _, namespace = function_to_graph(
m,
- conversion_map=conversion_map,
+ program_ctx=program_ctx,
arg_values={},
arg_types={'self': (c.__name__, c)},
owner_type=c)
@@ -219,14 +163,14 @@ def class_to_graph(c, conversion_map):
else:
class_namespace.update(namespace)
converted_members[m] = node
- namer = conversion_map.new_namer(class_namespace)
+ namer = program_ctx.new_namer(class_namespace)
class_name = namer.compiled_class_name(c.__name__, c)
# TODO(mdan): This needs to be explained more thoroughly.
# Process any base classes: if the sueprclass if of a whitelisted type, an
# absolute import line is generated. Otherwise, it is marked for conversion
# (as a side effect of the call to namer.compiled_class_name() followed by
- # conversion_map.update_name_map(namer)).
+ # program_ctx.update_name_map(namer)).
output_nodes = []
renames = {}
bases = []
@@ -246,7 +190,7 @@ def class_to_graph(c, conversion_map):
alias = namer.compiled_class_name(base.__name__, base)
bases.append(alias)
renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
- conversion_map.update_name_map(namer)
+ program_ctx.update_name_map(namer)
# Generate the definition of the converted class.
output_nodes.append(
@@ -278,14 +222,14 @@ def _add_reserved_symbol(namespace, name, entity):
ag_internal = None
-def _add_self_references(namespace, api_module):
+def _add_self_references(namespace, autograph_module):
"""Adds namespace references to the module that exposes the api itself."""
global ag_internal
if ag_internal is None:
# Craft a module that exposes parts of the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
- ag_internal.converted_call = api_module.converted_call
+ ag_internal.converted_call = autograph_module.converted_call
ag_internal.utils = utils
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
@@ -295,27 +239,24 @@ def _add_self_references(namespace, api_module):
_add_reserved_symbol(namespace, 'ag__', ag_internal)
-def function_to_graph(f, conversion_map, arg_values, arg_types,
- owner_type=None):
+def function_to_graph(f, program_ctx, arg_values, arg_types, owner_type=None):
"""Specialization of `entity_to_graph` for callable functions."""
node, source = parser.parse_entity(f)
node = node.body[0]
namespace = inspect_utils.getnamespace(f)
- _add_self_references(namespace, conversion_map.api_module)
- namer = conversion_map.new_namer(namespace)
+ _add_self_references(namespace, program_ctx.autograph_module)
+ namer = program_ctx.new_namer(namespace)
- ctx = context.EntityContext(
- namer=namer,
+ entity_info = transformer.EntityInfo(
source_code=source,
source_file='<fragment>',
namespace=namespace,
arg_values=arg_values,
arg_types=arg_types,
- owner_type=owner_type,
- recursive=conversion_map.recursive,
- type_annotation_func=type_hints.set_element_type)
- node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators)
+ owner_type=owner_type)
+ context = converter.EntityContext(namer, entity_info, program_ctx)
+ node = node_to_graph(node, context)
# TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
@@ -325,29 +266,28 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
raise NotImplementedError('Strange corner case. Send us offending code!')
node.name = new_name
- conversion_map.update_name_map(namer)
+ program_ctx.update_name_map(namer)
# TODO(mdan): Use this at compilation.
- conversion_map.additional_imports.update(deps)
return node, new_name, namespace
-def _static_analysis_pass(node, ctx):
+def _apply_transformer(node, context, converter_module):
+ # TODO(mdan): Clear static analysis here.
node = qual_names.resolve(node)
- node = activity.resolve(node, ctx, None)
- node = live_values.resolve(node, ctx, config.PYTHON_LITERALS)
- node = type_info.resolve(node, ctx)
+ node = activity.resolve(node, context.info, None)
+ node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
+ node = type_info.resolve(node, context.info)
+ node = converter_module.transform(node, context)
return node
-def node_to_graph(node, ctx, nocompile_decorators):
+def node_to_graph(node, context):
"""Convert Python code to equivalent TF graph mode code.
Args:
- node: A Python AST node representing the code to convert.
- ctx: An EntityContext object.
- nocompile_decorators: A tuple containing decorators to be stripped from
- functions during conversion.
+ node: AST, the code to convert.
+ context: converter.EntityContext
Returns:
A tuple (node, deps):
@@ -357,53 +297,26 @@ def node_to_graph(node, ctx, nocompile_decorators):
"""
# TODO(mdan): Verify arguments for correctness.
- # TODO(mdan): Factor out common elements.
- # These include:
- # * code move between blocks
- # * visiting blocks in transformers
-
- # Certain steps, especially canonicalization, insert new symbols into the
- # tree, which must be accounted. Although less efficient, it is most robust
- # to re-run the analysis.
-
- node = _static_analysis_pass(node, ctx)
-
- # TODO(mdan): Clean this up.
- # Some intermediate analyses are not required, and some comments got orphaned.
-
+ node = _apply_transformer(node, context, ifexp)
# Past this point, line numbers are no longer accurate so we ignore the
# source.
# TODO(mdan): Is it feasible to reconstruct intermediate source code?
- ctx.source_code = None
- node = ifexp.transform(node, ctx)
- node, deps = decorators.transform(node, nocompile_decorators)
- node = break_statements.transform(node, ctx)
- node = _static_analysis_pass(node, ctx)
-
- node = asserts.transform(node, ctx)
-
+ context.info.source_code = None
+ node = _apply_transformer(node, context, decorators)
+ node = _apply_transformer(node, context, break_statements)
+ node = _apply_transformer(node, context, asserts)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
- node = continue_statements.transform(node, ctx)
- ctx.namespace['len'] = len
-
- node = _static_analysis_pass(node, ctx)
- node = single_return.transform(node, ctx)
-
- node = _static_analysis_pass(node, ctx)
- node = lists.transform(node, ctx)
- node = builtin_functions.transform(node, ctx)
-
- node = _static_analysis_pass(node, ctx)
- node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
- nocompile_decorators)
- node = control_flow.transform(node, ctx)
-
- # control_flow may create new symbols and change scopes.
- node = _static_analysis_pass(node, ctx)
- node = logical_expressions.transform(node, ctx)
- node = side_effect_guards.transform(node, ctx)
- node = name_scopes.transform(node, ctx)
-
- return node, deps
+ node = _apply_transformer(node, context, continue_statements)
+ context.info.namespace['len'] = len
+ node = _apply_transformer(node, context, single_return)
+ node = _apply_transformer(node, context, lists)
+ node = _apply_transformer(node, context, slices)
+ node = _apply_transformer(node, context, builtin_functions)
+ node = _apply_transformer(node, context, call_trees)
+ node = _apply_transformer(node, context, control_flow)
+ node = _apply_transformer(node, context, logical_expressions)
+ node = _apply_transformer(node, context, side_effect_guards)
+ node = _apply_transformer(node, context, name_scopes)
+ return node
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index bc61498b54..f5279298af 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph import utils
+from tensorflow.contrib.autograph.core import config
+from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import api
from tensorflow.contrib.autograph.impl import conversion
from tensorflow.python.framework import constant_op
@@ -30,8 +32,13 @@ from tensorflow.python.platform import test
class ConversionTest(test.TestCase):
- def _simple_conversion_map(self):
- return conversion.ConversionMap(True, (), (), api)
+ def _simple_program_ctx(self):
+ return converter.ProgramContext(
+ recursive=True,
+ autograph_decorators=(),
+ partial_types=(),
+ autograph_module=api,
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
def test_is_whitelisted_for_graph(self):
@@ -44,16 +51,16 @@ class ConversionTest(test.TestCase):
def test_entity_to_graph_unsupported_types(self):
with self.assertRaises(ValueError):
- conversion_map = self._simple_conversion_map()
- conversion.entity_to_graph('dummy', conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ conversion.entity_to_graph('dummy', program_ctx, None, None)
def test_entity_to_graph_callable(self):
b = 2
def f(a):
return a + b
- conversion_map = self._simple_conversion_map()
- ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ ast, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
self.assertTrue(isinstance(ast, gast.FunctionDef), ast)
self.assertEqual('tf__f', name)
self.assertTrue(ns['b'] is b)
@@ -66,18 +73,17 @@ class ConversionTest(test.TestCase):
def f(a):
return g(a)
- conversion_map = self._simple_conversion_map()
- conversion.entity_to_graph(f, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ conversion.entity_to_graph(f, program_ctx, None, None)
- self.assertTrue(f in conversion_map.dependency_cache)
- self.assertTrue(g in conversion_map.dependency_cache)
- self.assertEqual('tf__f', conversion_map.dependency_cache[f].name)
+ self.assertTrue(f in program_ctx.dependency_cache)
+ self.assertTrue(g in program_ctx.dependency_cache)
+ self.assertEqual('tf__f', program_ctx.dependency_cache[f].name)
# need the extra .body[0] in order to step past the with tf.name_scope('f')
# that is added automatically
self.assertEqual(
- 'tf__g',
- conversion_map.dependency_cache[f].body[0].body[0].value.func.id)
- self.assertEqual('tf__g', conversion_map.dependency_cache[g].name)
+ 'tf__g', program_ctx.dependency_cache[f].body[0].body[0].value.func.id)
+ self.assertEqual('tf__g', program_ctx.dependency_cache[g].name)
def test_entity_to_graph_class_hierarchy(self):
@@ -104,16 +110,15 @@ class ConversionTest(test.TestCase):
def baz(self):
return self.y
- conversion_map = self._simple_conversion_map()
- conversion.entity_to_graph(TestSubclass, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ conversion.entity_to_graph(TestSubclass, program_ctx, None, None)
- self.assertTrue(TestBase in conversion_map.dependency_cache)
- self.assertTrue(TestSubclass in conversion_map.dependency_cache)
+ self.assertTrue(TestBase in program_ctx.dependency_cache)
+ self.assertTrue(TestSubclass in program_ctx.dependency_cache)
self.assertEqual('TfTestBase',
- conversion_map.dependency_cache[TestBase].body[-1].name)
- self.assertEqual(
- 'TfTestSubclass',
- conversion_map.dependency_cache[TestSubclass].body[-1].name)
+ program_ctx.dependency_cache[TestBase].body[-1].name)
+ self.assertEqual('TfTestSubclass',
+ program_ctx.dependency_cache[TestSubclass].body[-1].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
@@ -126,24 +131,23 @@ class ConversionTest(test.TestCase):
def call(self, x):
return 3 * x
- conversion_map = self._simple_conversion_map()
- conversion.entity_to_graph(TestSubclass, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ conversion.entity_to_graph(TestSubclass, program_ctx, None, None)
- self.assertTrue(TestSubclass in conversion_map.dependency_cache)
- self.assertFalse(training.Model in conversion_map.dependency_cache)
+ self.assertTrue(TestSubclass in program_ctx.dependency_cache)
+ self.assertFalse(training.Model in program_ctx.dependency_cache)
self.assertEqual(
'Model',
- conversion_map.dependency_cache[TestSubclass].body[0].names[0].name)
- self.assertEqual(
- 'TfTestSubclass',
- conversion_map.dependency_cache[TestSubclass].body[-1].name)
+ program_ctx.dependency_cache[TestSubclass].body[0].names[0].name)
+ self.assertEqual('TfTestSubclass',
+ program_ctx.dependency_cache[TestSubclass].body[-1].name)
def test_entity_to_graph_lambda(self):
f = lambda a: a
with self.assertRaises(NotImplementedError):
- conversion_map = self._simple_conversion_map()
- conversion.entity_to_graph(f, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ conversion.entity_to_graph(f, program_ctx, None, None)
def test_ag_module_cached(self):
def callee():
@@ -152,11 +156,11 @@ class ConversionTest(test.TestCase):
def caller(a):
return a()
- conversion_map = self._simple_conversion_map()
- _, _, callee_ns = conversion.entity_to_graph(
- callee, conversion_map, None, None)
- _, _, caller_ns = conversion.entity_to_graph(
- caller, conversion_map, None, None)
+ program_ctx = self._simple_program_ctx()
+ _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None,
+ None)
+ _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None,
+ None)
self.assertTrue(callee_ns['ag__'] is caller_ns['ag__'])
diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/contrib/autograph/lang/BUILD
new file mode 100644
index 0000000000..77a2184e22
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/BUILD
@@ -0,0 +1,40 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "lang",
+ srcs = [
+ "directives.py",
+ "special_functions.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/autograph/operators",
+ ],
+)
+
+py_test(
+ name = "special_functions_test",
+ srcs = ["special_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":lang",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/contrib/autograph/lang/directives.py
new file mode 100644
index 0000000000..aabe5d9939
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/directives.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Directives are special no-op functions that serve as compilation markers.
+
+They provide static information like type hints, compilation and TensorFlow
+overrides.
+
+These serve as annotations in the compiled code, allowing the user some control
+over the compilation process. They have no functional role at runtime.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+UNSPECIFIED = object()
+
+
+def set_element_type(entity, dtype, shape=UNSPECIFIED):
+ """Indicates that the entity is expected hold items of specified type/shape.
+
+ The staged TensorFlow ops will reflect and assert this data type. Ignored
+ otherwise.
+
+ Args:
+ entity: The entity to annotate.
+ dtype: TensorFlow dtype value to assert for entity.
+ shape: Optional shape to assert for entity.
+ """
+ del entity
+ del dtype
+ del shape
+
+
+def set_loop_options(
+ parallel_iterations=UNSPECIFIED,
+ back_prop=UNSPECIFIED,
+ swap_memory=UNSPECIFIED,
+ maximum_iterations=UNSPECIFIED):
+ """Specifies additional arguments to be passed to the enclosing while_loop.
+
+ The parameters apply to and only to the immediately enclosing loop. It only
+ has effect if the loop is staged as a TF while_loop; otherwise the parameters
+ have no effect.
+
+ Args:
+ parallel_iterations: See tf.while_loop.
+ back_prop: See tf.while_loop.
+ swap_memory: See tf.while_loop.
+ maximum_iterations: See tf.while_loop.
+ """
+ del parallel_iterations
+ del back_prop
+ del swap_memory
+ del maximum_iterations
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/contrib/autograph/lang/special_functions.py
new file mode 100644
index 0000000000..11135295a7
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/special_functions.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special functions that only make sense for AutoGraph.
+
+These functions are meant to ensure feature parity between Python and AutoGraph,
+so that the exact same code works in both modes. In general, AutoGraph will
+replace these calls.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import data_structures
+
+
+def stack(list_or_tensor, element_dtype=None, strict=True):
+ """Stacks the input, if it admits the notion of stacking.
+
+ For example, a list of tensors can be stacked into a larger tensor. This
+ function is similar to tf.stack, but it accepts non-lists and lists of
+ non-tensors as arguments. In the latter case, the function does nothing.
+
+ Args:
+ list_or_tensor: Any
+ element_dtype: tf.DType, optional dtypedtype for the elements in the list.
+ Required if the input is stackable, and the list is untyped.
+ strict: bool, if True an error is raised if the input is not stackable.
+ Otherwise the function is a no-op.
+
+ Returns:
+ Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
+ if strict=False, the result will be list_or_tensor.
+
+ Raises:
+ ValueError: if strict=True and the input is not stackable.
+ """
+ if strict:
+ def raise_error(x):
+ raise ValueError('%s must be stackable when strict=True' % x)
+ original_call = raise_error
+ else:
+ original_call = lambda x: x
+ return data_structures.list_stack(
+ list_or_tensor,
+ data_structures.ListStackOpts(
+ element_dtype=element_dtype, original_call=original_call))
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/contrib/autograph/lang/special_functions_test.py
new file mode 100644
index 0000000000..a49cb64075
--- /dev/null
+++ b/tensorflow/contrib/autograph/lang/special_functions_test.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for special_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SpecialFunctionsTest(test.TestCase):
+
+ def test_basic(self):
+ self.assertEqual(special_functions.stack(1, strict=False), 1)
+ self.assertListEqual(
+ special_functions.stack([1, 2, 3], strict=False), [1, 2, 3])
+ # TODO(mdan): This should probably forward to tf.stack.
+ self.assertTrue(
+ isinstance(
+ special_functions.stack(
+ [constant_op.constant(1),
+ constant_op.constant(2)], strict=False), list))
+
+ with self.assertRaises(ValueError):
+ special_functions.stack([1, 2, 3])
+
+ t = constant_op.constant([1.0, 2.0])
+ l = list_ops.tensor_list_from_tensor(
+ t, element_shape=constant_op.constant([], dtype=dtypes.int32))
+ self.assertTrue(
+ tensor_util.is_tensor(
+ special_functions.stack(l, element_dtype=dtypes.float32)))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 0c6ab65505..332d5dab19 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -28,7 +28,15 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
"//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index c900fd6af2..392cb60bcc 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""This module implements operators that we overload.
+"""This module implements operators that AutoGraph overloads.
Note that "operator" is used loosely here, and includes control structures like
conditionals and loops, implemented in functional form, using for example
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 671c9ccc13..988df70157 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -51,7 +51,7 @@ def for_stmt(iter_, extra_test, body, init_state):
Args:
iter_: The entity being iterated over.
extra_test: Callable with the state as arguments, and boolean return type.
- An additionnal loop condition.
+ An additional loop condition.
body: Callable with the iterate and the state as arguments, and
state as return type. The actual loop body.
init_state: Tuple containing the initial state.
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD
index 796ab445c7..f77a6ab392 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/contrib/autograph/pyct/BUILD
@@ -22,9 +22,10 @@ py_library(
"__init__.py",
"anno.py",
"ast_util.py",
+ "cfg.py",
"compiler.py",
- "context.py",
"inspect_utils.py",
+ "origin_info.py",
"parser.py",
"pretty_printer.py",
"qual_names.py",
@@ -38,6 +39,8 @@ py_library(
"@gast_archive//:gast",
"@six_archive//:six",
"@termcolor_archive//:termcolor",
+ # TODO(mdan): Remove this dependency.
+ "//tensorflow/python:util",
],
)
@@ -63,6 +66,17 @@ py_test(
)
py_test(
+ name = "cfg_test",
+ srcs = ["cfg_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
name = "compiler_test",
srcs = ["compiler_test.py"],
srcs_version = "PY2AND3",
@@ -130,6 +144,7 @@ py_test(
name = "transformer_test",
srcs = ["transformer_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py
index cc4a7edf02..92f1370e05 100644
--- a/tensorflow/contrib/autograph/pyct/anno.py
+++ b/tensorflow/contrib/autograph/pyct/anno.py
@@ -44,10 +44,19 @@ class Basic(NoValue):
'be indented below it. The annotation contains a tuple '
'(new_body, name_map), where `new_body` is the new indented block and '
'`name_map` allows renaming symbols.')
+ ORIGIN = ('Contains OriginInfo objects specific to the annotated node. See '
+ 'origin_information.py for definition.')
-def getanno(node, key, field_name='___pyct_anno'):
- return getattr(node, field_name)[key]
+FAIL = object()
+
+
+def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
+ if (default is FAIL or (hasattr(node, field_name) and
+ (key in getattr(node, field_name)))):
+ return getattr(node, field_name)[key]
+ else:
+ return default
def hasanno(node, key, field_name='___pyct_anno'):
@@ -73,5 +82,9 @@ def delanno(node, key, field_name='___pyct_anno'):
def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
- if hasanno(from_node, key, field_name):
- setanno(to_node, key, getanno(from_node, key, field_name), field_name)
+ if hasanno(from_node, key, field_name=field_name):
+ setanno(
+ to_node,
+ key,
+ getanno(from_node, key, field_name=field_name),
+ field_name=field_name)
diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py
index 1d4d9d119e..f2c0c8cf05 100644
--- a/tensorflow/contrib/autograph/pyct/anno_test.py
+++ b/tensorflow/contrib/autograph/pyct/anno_test.py
@@ -38,12 +38,14 @@ class AnnoTest(test.TestCase):
anno.setanno(node, 'foo', 3)
self.assertTrue(anno.hasanno(node, 'foo'))
- self.assertEqual(3, anno.getanno(node, 'foo'))
+ self.assertEqual(anno.getanno(node, 'foo'), 3)
+ self.assertEqual(anno.getanno(node, 'bar', default=7), 7)
anno.delanno(node, 'foo')
self.assertFalse(anno.hasanno(node, 'foo'))
with self.assertRaises(AttributeError):
anno.getanno(node, 'foo')
+ self.assertIsNone(anno.getanno(node, 'foo', default=None))
def test_copyanno(self):
node_1 = ast.Name()
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/contrib/autograph/pyct/cfg.py
new file mode 100644
index 0000000000..666328781f
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/cfg.py
@@ -0,0 +1,733 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow graph (CFG) structure for Python AST representation.
+
+The CFG is a digraph with edges representing valid control flow. Each
+node is associated with exactly one AST node, but not all AST nodes may have
+a corresponding CFG counterpart.
+
+Once built, the CFG itself is immutable, but the values it holds need not be;
+they are usually annotated with information extracted by walking the graph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from enum import Enum
+
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
+
+from tensorflow.contrib.autograph.pyct import compiler
+
+
+class Node(object):
+ """A node in the CFG.
+
+ Although new instances of this class are mutable, the objects that a user
+ finds in the CFG are typically not.
+
+ The nodes represent edges in the CFG graph, and maintain pointers to allow
+ efficient walking in both forward and reverse order. The following property
+ holds for all nodes: "child in node.next" iff "node in child.prev".
+
+ Attributes:
+ next: FrozenSet[Node, ...], the nodes that follow this node, in control
+ flow order
+ prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse
+ control flow order
+ ast_node: ast.AST, the AST node corresponding to this CFG node
+ """
+
+ def __init__(self, next_, prev, ast_node):
+ self.next = next_
+ self.prev = prev
+ self.ast_node = ast_node
+
+ def freeze(self):
+ self.next = frozenset(self.next)
+ self.prev = frozenset(self.prev)
+
+ def __repr__(self):
+ return compiler.ast_to_source(self.ast_node).strip()
+
+
+class Graph(
+ collections.namedtuple('Graph', ['entry', 'exit', 'error', 'index'])):
+ """A Control Flow Graph.
+
+ The CFG maintains an index to allow looking up a CFG node by the AST node to
+ which it is associated. The index can also be enumerated in top-down, depth
+ first order.
+
+ Walking the graph in forward or reverse order is supported by double
+ parent-child links.
+
+ Note: the error nodes are not wired to their corresponding finally guards,
+ because these are shared, and wiring them would create a reverse path from
+ normal control flow into the error nodes, which we want to avoid.
+
+ Attributes:
+ entry: Node, the entry node
+ exit: FrozenSet[Node, ...], the exit nodes
+ error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised
+ error (errors propagated from function calls are not accounted)
+ index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG
+ node
+ """
+
+ def __repr__(self):
+ result = 'digraph CFG {\n'
+ for node in self.index.values():
+ result += ' %s [label="%s"];\n' % (id(node), node)
+ for node in self.index.values():
+ if node.next:
+ result += ' %s -> {%s};\n' % (id(node), ', '.join(
+ repr(id(n)) for n in node.next))
+ result += '}'
+ return result
+
+
+class _WalkMode(Enum):
+ FORWARD = 1
+ REVERSE = 2
+
+
+class GraphVisitor(object):
+ """Base class for a CFG visitors.
+
+ This implementation is not thread safe.
+
+ The visitor has some facilities to simplify dataflow analyses. In particular,
+ it allows revisiting the nodes at the decision of the subclass. This can be
+ used to visit the graph until the state reaches a fixed point.
+
+ For more details on dataflow analysis, see
+ https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf
+
+ Note: the literature generally suggests visiting successor nodes only when the
+ state of the current node changed, regardless of whether that successor has
+ ever been visited. This implementation visits every successor at least once.
+
+ Attributes:
+ graph: Graph
+ in_: Dict[Node, Any], stores node-keyed state during a visit
+ out: Dict[Node, Any], stores node-keyed state during a visit
+ """
+
+ def reset(self):
+ self.in_ = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+ self.out = {
+ node: self.init_state(node) for node in self.graph.index.values()
+ }
+
+ def init_state(self, node):
+ """State initialization function. Optional to overload.
+
+ An in/out state slot will be created for each node in the graph. Subclasses
+ may overload this to control what that is initialized to.
+
+ Args:
+ node: Node
+ """
+ del node
+ return None
+
+ def visit_node(self, node):
+ """Visitor function.
+
+ Args:
+ node: Node
+ Returns:
+ bool, whether the node should be revisited; subclasses can visit every
+ reachable node exactly once by always returning False
+ """
+ raise NotImplementedError('Subclasses must implement this.')
+
+ def _visit_internal(self, mode):
+ """Visits the CFG, depth-first."""
+ assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
+ if mode == _WalkMode.FORWARD:
+ open_ = [self.graph.entry]
+ elif mode == _WalkMode.REVERSE:
+ open_ = list(self.graph.exit)
+ closed = set()
+ self.reset()
+
+ while open_:
+ node = open_.pop(0)
+ closed.add(node)
+
+ should_revisit = self.visit_node(node)
+
+ if mode == _WalkMode.FORWARD:
+ children = node.next
+ elif mode == _WalkMode.REVERSE:
+ children = node.prev
+
+ for next_ in children:
+ if should_revisit or next_ not in closed:
+ open_.append(next_)
+
+ def visit_forward(self, graph):
+ self.graph = graph
+ self._visit_internal(_WalkMode.FORWARD)
+
+ def visit_reverse(self, graph):
+ self.graph = graph
+ self._visit_internal(_WalkMode.REVERSE)
+
+
+class GraphBuilder(object):
+ """Builder that constructs a CFG from a given AST.
+
+ This GraphBuilder facilitates constructing the DAG that forms the CFG when
+ nodes
+ are supplied in lexical order (i.e., top-down, depth first). Under these
+ conditions, it supports building patterns found in typical structured
+ programs.
+
+ This builder ignores the flow generated by exceptions, which are assumed to
+ always be catastrophic and present purely for diagnostic purposes (e.g. to
+ print debug information). Statements like raise and try/catch sections are
+ allowed and will generate control flow edges, but ordinaty statements are
+ assumed not to raise exceptions.
+
+ Finally sections are also correctly interleaved between break/continue/return
+ nodes and their subsequent statements.
+
+ Important concepts:
+ * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly
+ * leaf set - since the graph is constructed gradually, a leaf set maintains
+ the CFG nodes that will precede the node that the builder expects to
+ receive next; when an ordinary node is added, it is connected to the
+ existing leaves and it in turn becomes the new leaf
+ * jump nodes - nodes that should generate edges other than what
+ ordinary nodes would; these correspond to break, continue and return
+ statements
+ * sections - logical delimiters for subgraphs that require special
+ edges; there are various types of nodes, each admitting various
+ types of jump nodes; sections are identified by their corresponding AST
+ node
+ """
+
+ # TODO(mdan): Perhaps detail this in a markdown doc.
+ # TODO(mdan): Add exception support.
+
+ def __init__(self, parent_ast_node):
+ self.reset()
+ self.parent = parent_ast_node
+
+ def reset(self):
+ """Resets the state of this factory."""
+ self.head = None
+ self.errors = set()
+ self.node_index = collections.OrderedDict()
+
+ # TODO(mdan): Too many primitives. Use classes.
+ self.leaves = set()
+
+ self.finally_sections = {}
+ self.finally_section_subgraphs = {} # Values are [begin_node, exit_nodes]
+ # Whether the guard section can be reached from the statement that precedes
+ # it.
+ self.finally_section_has_direct_flow = {}
+ # Finally sections that await their first node.
+ self.pending_finally_sections = set()
+
+ # Exit jumps keyed by the section they affect.
+ self.exits = {}
+
+ # The entry of loop sections, keyed by the section.
+ self.section_entry = {}
+ # Continue jumps keyed by the section they affect.
+ self.continues = {}
+
+ # The entry of conditional sections, keyed by the section.
+ self.cond_entry = {}
+ # Lists of leaf nodes corresponding to each branch in the section.
+ self.cond_leaves = {}
+
+ def _connect_nodes(self, first, second):
+ """Connects nodes to signify that control flows from first to second.
+
+ Args:
+ first: Union[Set[Node, ...], Node]
+ second: Node
+ """
+ if isinstance(first, Node):
+ first.next.add(second)
+ second.prev.add(first)
+ else:
+ for node in first:
+ self._connect_nodes(node, second)
+
+ def _add_new_node(self, ast_node):
+ """Grows the graph by adding a CFG node following the current leaves."""
+ if ast_node is self.node_index:
+ raise ValueError('%s added twice' % ast_node)
+ node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ self.node_index[ast_node] = node
+
+ if self.head is None:
+ self.head = node
+
+ for leaf in self.leaves:
+ self._connect_nodes(leaf, node)
+
+ # If any finally section awaits its first node, populate it.
+ for section_id in self.pending_finally_sections:
+ self.finally_section_subgraphs[section_id][0] = node
+ self.pending_finally_sections = set()
+
+ return node
+
+ def add_ordinary_node(self, ast_node):
+ """Grows the graph by adding an ordinary CFG node.
+
+ Ordinary nodes are followed by the next node, in lexical order, that is,
+ they become the new leaf set.
+
+ Args:
+ ast_node: ast.AST
+ Returns:
+ Node
+ """
+ node = self._add_new_node(ast_node)
+ self.leaves = set((node,))
+ return node
+
+ def _add_jump_node(self, ast_node, guards):
+ """Grows the graph by adding a jump node.
+
+ Jump nodes are added to the current leaf set, and the leaf set becomes
+ empty. If the jump node is the last in a cond section, then it may be added
+ back to the leaf set by a separate mechanism.
+
+ Args:
+ ast_node: ast.AST
+ guards: Tuple[ast.AST, ...], the finally sections active for this node
+ Returns:
+ Node
+ """
+ node = self._add_new_node(ast_node)
+ self.leaves = set()
+ # The guards themselves may not yet be complete, and will be wired later.
+ self.finally_sections[node] = guards
+ return node
+
+ def _connect_jump_to_finally_sections(self, node):
+ """Connects a jump node to the finally sections protecting it."""
+ cursor = set((node,))
+ for guard_section_id in self.finally_sections[node]:
+ guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id]
+ self._connect_nodes(cursor, guard_begin)
+ cursor = guard_ends
+ del self.finally_sections[node]
+ # TODO(mdan): Should garbage-collect finally_section_subgraphs.
+ return cursor
+
+ def add_exit_node(self, ast_node, section_id, guards):
+ """Grows the graph by adding an exit node.
+
+ This node becomes an exit for the current section.
+
+ Args:
+ ast_node: ast.AST
+ section_id: Hashable, the node for which ast_node should be considered
+ to be an exit node
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.exits[section_id].add(node)
+
+ def add_continue_node(self, ast_node, section_id, guards):
+ """Grows the graph by adding a reentry node.
+
+ This node causes control flow to go back to the loop section's entry.
+
+ Args:
+ ast_node: ast.AST
+ section_id: Hashable, the node for which ast_node should be considered
+ to be an exit node
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.continues[section_id].add(node)
+
+ def add_error_node(self, ast_node, guards):
+ """Grows the graph by adding an error node.
+
+ This node becomes an exit for the entire graph.
+
+ Args:
+ ast_node: ast.AST
+ guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+ """
+ node = self._add_jump_node(ast_node, guards)
+ self.errors.add(node)
+ self.leaves = set()
+
+ def enter_section(self, section_id):
+ """Enters a regular section.
+
+ Regular sections admit exit jumps, which end the section.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+ ast_node arg passed to add_exit_node
+ """
+ assert section_id not in self.exits
+ self.exits[section_id] = set()
+
+ def exit_section(self, section_id):
+ """Exits a regular section."""
+
+ # Exits are jump nodes, which may be protected.
+ for exit_ in self.exits[section_id]:
+ self.leaves |= self._connect_jump_to_finally_sections(exit_)
+
+ del self.exits[section_id]
+
+ def enter_loop_section(self, section_id, entry_node):
+ """Enters a loop section.
+
+ Loop sections define an entry node. The end of the section always flows back
+ to the entry node. These admit continue jump nodes which also flow to the
+ entry node.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+ ast_node arg passed to add_continue_node
+ entry_node: ast.AST, the entry node into the loop (e.g. the test node
+ for while loops)
+ """
+ assert section_id not in self.section_entry
+ assert section_id not in self.continues
+ self.continues[section_id] = set()
+ node = self.add_ordinary_node(entry_node)
+ self.section_entry[section_id] = node
+
+ def exit_loop_section(self, section_id):
+ """Exits a loop section."""
+ self._connect_nodes(self.leaves, self.section_entry[section_id])
+
+ # continues are jump nodes, which may be protected.
+ for reentry in self.continues[section_id]:
+ guard_ends = self._connect_jump_to_finally_sections(reentry)
+ self._connect_nodes(guard_ends, self.section_entry[section_id])
+
+ # Loop nodes always loop back.
+ self.leaves = set((self.section_entry[section_id],))
+
+ del self.continues[section_id]
+ del self.section_entry[section_id]
+
+ def enter_cond_section(self, section_id):
+ """Enters a conditional section.
+
+ Conditional sections define an entry node, and one or more branches.
+
+ Args:
+ section_id: Hashable, the same node that will be used in calls to the
+ section_id arg passed to new_cond_branch
+ """
+
+ assert section_id not in self.cond_entry
+ assert section_id not in self.cond_leaves
+ self.cond_leaves[section_id] = []
+
+ def new_cond_branch(self, section_id):
+ """Begins a new branch in a cond section."""
+ assert section_id in self.cond_leaves
+
+ if section_id in self.cond_entry:
+ # Subsequent splits move back to the split point, and memorize the
+ # current leaves.
+ self.cond_leaves[section_id].append(self.leaves)
+ self.leaves = self.cond_entry[section_id]
+ else:
+ # If this is the first time we split a section, just remember the split
+ # point.
+ self.cond_entry[section_id] = self.leaves
+
+ def exit_cond_section(self, section_id):
+ """Exits a conditional section."""
+ for split in self.cond_leaves[section_id]:
+ self.leaves |= split
+ del self.cond_entry[section_id]
+ del self.cond_leaves[section_id]
+
+ def enter_finally_section(self, section_id):
+ """Enters a finally section."""
+ # TODO(mdan): This, not the caller, should track the active sections.
+ self.finally_section_subgraphs[section_id] = [None, None]
+ if self.leaves:
+ self.finally_section_has_direct_flow[section_id] = True
+ else:
+ self.finally_section_has_direct_flow[section_id] = False
+ self.pending_finally_sections.add(section_id)
+
+ def exit_finally_section(self, section_id):
+ """Exits a finally section."""
+ assert section_id not in self.pending_finally_sections, 'Empty finally?'
+ self.finally_section_subgraphs[section_id][1] = self.leaves
+ # If the guard can only be reached by a jump, then it will not flow
+ # into the statement that follows it.
+ if not self.finally_section_has_direct_flow[section_id]:
+ self.leaves = set()
+ del self.finally_section_has_direct_flow[section_id]
+
+ def build(self):
+ """Returns the CFG accumulated so far and resets the builder.
+
+ Returns:
+ Graph
+ """
+ # Freeze the nodes.
+ for node in self.node_index.values():
+ node.freeze()
+
+ result = Graph(
+ entry=self.head,
+ exit=self.leaves,
+ error=self.errors,
+ index=self.node_index)
+
+ # Reset the state.
+ self.reset()
+
+ return result
+
+
+class AstToCfg(gast.NodeVisitor):
+ """Converts an AST to CFGs.
+
+ A separate CFG will be constructed for each function.
+ """
+
+ # TODO(mdan): Figure out how to deal with closures.
+
+ def __init__(self):
+ super(AstToCfg, self).__init__()
+
+ self.builder_stack = []
+ self.builder = None
+ self.cfgs = {}
+
+ self.lexical_scopes = []
+
+ def _enter_lexical_scope(self, node):
+ self.lexical_scopes.append(node)
+
+ def _exit_lexical_scope(self, node):
+ leaving_node = self.lexical_scopes.pop()
+ assert node == leaving_node
+
+ def _get_enclosing_scopes(self, include, stop_at):
+ included = []
+ for node in reversed(self.lexical_scopes):
+ if isinstance(node, include):
+ included.append(node)
+ if isinstance(node, stop_at):
+ return node, included
+ return None, included
+
+ def _process_basic_statement(self, node):
+ self.generic_visit(node)
+ self.builder.add_ordinary_node(node)
+
+ def _process_exit_statement(self, node, *exits_nodes_of_type):
+ # Note: this is safe because we process functions separately.
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=tuple(exits_nodes_of_type),
+ )
+ if try_node is None:
+ raise ValueError(
+ '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type))
+ self.builder.add_exit_node(node, try_node, guards)
+
+ def _process_continue_statement(self, node, *loops_to_nodes_of_type):
+ # Note: this is safe because we process functions separately.
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=tuple(loops_to_nodes_of_type),
+ )
+ if try_node is None:
+ raise ValueError('%s that is not enclosed by any of %s' %
+ (node, loops_to_nodes_of_type))
+ self.builder.add_continue_node(node, try_node, guards)
+
+ def visit_FunctionDef(self, node):
+ self.builder_stack.append(self.builder)
+ self.builder = GraphBuilder(node)
+
+ self._enter_lexical_scope(node)
+ self.builder.enter_section(node)
+
+ self._process_basic_statement(node.args)
+ for stmt in node.body:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+ self._exit_lexical_scope(node)
+
+ self.cfgs[node] = self.builder.build()
+ self.builder = self.builder_stack.pop()
+
+ def visit_Lambda(self, node):
+ # TODO(mdan): Treat like FunctionDef? That would be a separate CFG.
+ raise NotImplementedError()
+
+ def visit_Return(self, node):
+ self._process_exit_statement(node, gast.FunctionDef)
+
+ def visit_Expr(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Assign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_AnnAssign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_AugAssign(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Print(self, node):
+ self._process_basic_statement(node)
+
+ def visit_Raise(self, node):
+ try_node, guards = self._get_enclosing_scopes(
+ include=(gast.Try,),
+ stop_at=(gast.FunctionDef,),
+ )
+ if try_node is None:
+ raise ValueError('%s that is not enclosed by any FunctionDef' % node)
+ self.builder.add_error_node(node, try_node, guards)
+
+ def visit_Assert(self, node):
+ # Ignoring the effect of exceptions.
+ self._process_basic_statement(node)
+
+ def visit_Delete(self, node):
+ self._process_basic_statement(node)
+
+ def visit_If(self, node):
+ # No need to track ifs as lexical scopes, for now.
+ # Lexical scopes are generally tracked in order to be able to resolve the
+ # targets of jump statements like break/continue/etc. Since there is no
+ # statement that can interrupt a conditional, we don't need to track their
+ # lexical scope. That may change in the future.
+
+ self.builder.enter_cond_section(node)
+ self._process_basic_statement(node.test)
+
+ self.builder.new_cond_branch(node)
+ for stmt in node.body:
+ self.visit(stmt)
+
+ self.builder.new_cond_branch(node)
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_cond_section(node)
+
+ def visit_While(self, node):
+ self._enter_lexical_scope(node)
+
+ self.builder.enter_section(node)
+
+ self.builder.enter_loop_section(node, node.test)
+ for stmt in node.body:
+ self.visit(stmt)
+ self.builder.exit_loop_section(node)
+
+ # Note: although the orelse is technically part of the loop node,
+ # the statements inside it don't affect the loop itself. For example, a
+ # break in the loop's orelse will not affect the loop itself.
+ self._exit_lexical_scope(node)
+
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+
+ def visit_For(self, node):
+ self._enter_lexical_scope(node)
+
+ self.builder.enter_section(node)
+
+ # TODO(mdan): Strictly speaking, this should be node.target + node.iter.
+ # A blind dataflow analysis would have to process both node.target and
+ # node.iter to properly process read and write access.
+ self.builder.enter_loop_section(node, node.iter)
+ for stmt in node.body:
+ self.visit(stmt)
+ self.builder.exit_loop_section(node)
+
+ # Note: although the orelse is technically part of the loop node,
+ # they don't count as loop bodies. For example, a break in the loop's
+ # orelse will affect the parent loop, not the current one.
+ self._exit_lexical_scope(node)
+
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ self.builder.exit_section(node)
+
+ def visit_Break(self, node):
+ self._process_exit_statement(node, gast.While, gast.For)
+
+ def visit_Continue(self, node):
+ self._process_continue_statement(node, gast.While, gast.For)
+
+ def visit_Try(self, node):
+ self._enter_lexical_scope(node)
+
+ for stmt in node.body:
+ self.visit(stmt)
+ # Unlike loops, the orelse is a simple continuation of the body.
+ for stmt in node.orelse:
+ self.visit(stmt)
+
+ if node.handlers:
+ # TODO(mdan): Should we still support bare try/except? Might be confusing.
+ raise NotImplementedError('exceptions are not yet supported')
+
+ self._exit_lexical_scope(node)
+
+ self.builder.enter_finally_section(node)
+ for stmt in node.finalbody:
+ self.visit(stmt)
+ self.builder.exit_finally_section(node)
+
+ def visit_With(self, node):
+ # TODO(mdan): Mark the context manager's exit call as exit guard.
+ self._process_basic_statement(node.items)
+ for stmt in node.body:
+ self.visit(stmt)
+
+
+def build(node):
+ builder = AstToCfg()
+ builder.visit(node)
+ return builder.cfgs
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/contrib/autograph/pyct/cfg_test.py
new file mode 100644
index 0000000000..00afadd521
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/cfg_test.py
@@ -0,0 +1,790 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cfg module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import cfg
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class CountingVisitor(cfg.GraphVisitor):
+
+ def __init__(self):
+ self.counts = {}
+
+ def visit_node(self, node):
+ self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1
+ return False # visit only once
+
+
+class GraphVisitorTest(test.TestCase):
+
+ def _build_cfg(self, fn):
+ node, _ = parser.parse_entity(fn)
+ cfgs = cfg.build(node)
+ return cfgs, node
+
+ def test_basic_coverage_forward(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor()
+ visitor.visit_forward(graph)
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+ # The return node should be unreachable in forward direction.
+ self.assertTrue(fn_node.body[0].body[2] not in visitor.counts)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+ def test_basic_coverage_reverse(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor()
+ visitor.visit_reverse(graph)
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+ self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+
+class AstToCfgTest(test.TestCase):
+
+ def _build_cfg(self, fn):
+ node, _ = parser.parse_entity(fn)
+ cfgs = cfg.build(node)
+ return cfgs
+
+ def _repr_set(self, node_set):
+ return set(repr(n) for n in node_set)
+
+ def _as_set(self, elements):
+ if elements is None:
+ return frozenset()
+ elif isinstance(elements, str):
+ return frozenset((elements,))
+ else:
+ return frozenset(elements)
+
+ def assertGraphMatches(self, graph, edges):
+ """Tests whether the CFG contains the specified edges."""
+ for prev, node_repr, next_ in edges:
+ matched = False
+ for cfg_node in graph.index.values():
+ if repr(cfg_node) == node_repr:
+ if (self._as_set(prev) == set(map(repr, cfg_node.prev)) and
+ self._as_set(next_) == set(map(repr, cfg_node.next))):
+ matched = True
+ break
+ if not matched:
+ self.fail(
+ 'match failed for node "%s" in graph:\n%s' % (node_repr, graph))
+
+ def test_straightline(self):
+
+ def test_fn(a):
+ a += 1
+ a = 2
+ a = 3
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'a += 1'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', 'return'),
+ ('a = 3', 'return', None),
+ ),
+ )
+
+ def test_straightline_no_return(self):
+
+ def test_fn(a, b):
+ a = b + 1
+ a += max(a)
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a, b', 'a = b + 1'),
+ ('a = b + 1', 'a += max(a)', None),
+ ),
+ )
+
+ def test_unreachable_code(self):
+
+ def test_fn(a):
+ return
+ a += 1 # pylint:disable=unreachable
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'return'),
+ ('a', 'return', None),
+ (None, 'a += 1', None),
+ ),
+ )
+
+ def test_branch_straightline(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+ else:
+ a += -1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('(a > 0)', 'a = 1', None),
+ ('(a > 0)', 'a += -1', None),
+ ),
+ )
+
+ def test_branch_nested(self):
+
+ def test_fn(a):
+ if a > 0:
+ if a > 1:
+ a = 1
+ else:
+ a = 2
+ else:
+ if a > 2:
+ a = 3
+ else:
+ a = 4
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', ('(a > 1)', '(a > 2)')),
+ ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', None),
+ ('(a > 1)', 'a = 2', None),
+ ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')),
+ ('(a > 2)', 'a = 3', None),
+ ('(a > 2)', 'a = 4', None),
+ ),
+ )
+
+ def test_branch_straightline_semi(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', 'a = 1'),
+ ('(a > 0)', 'a = 1', None),
+ ),
+ )
+
+ def test_branch_return(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+ else:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', ('return', 'a = 1')),
+ ('(a > 0)', 'a = 1', 'a = 2'),
+ ('(a > 0)', 'return', None),
+ ('a = 1', 'a = 2', None),
+ ),
+ )
+
+ def test_branch_return_minimal(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'return'),
+ ('(a > 0)', 'return', None),
+ ),
+ )
+
+ def test_while_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', None),
+ ),
+ )
+
+ def test_while_else_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', '(a > 0)'),
+ ('a = 0', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+
+ def test_while_else_return(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 3:
+ continue
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')),
+ ('(a > 1)', '(a > 3)', ('continue', 'a = 1')),
+ ('(a > 3)', 'continue', '(a > 1)'),
+ ('(a > 3)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_while_nested_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 2:
+ break
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
+ ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'a = 1', '(a > 1)'),
+ (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', None),
+ ),
+ )
+
+ def test_for_else_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', 'range(0, a)'),
+ ('(a > 1)', 'a = 0', 'a = 1'),
+ ('a = 0', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+
+ def test_for_else_return(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')),
+ ('range(1, a)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 3:
+ continue
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)',
+ ('(a > 3)', 'a = 2')),
+ ('range(1, a)', '(a > 3)', ('continue', 'b += 1')),
+ ('(a > 3)', 'continue', 'range(1, a)'),
+ ('(a > 3)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_for_nested_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 2:
+ break
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')),
+ ('range(1, a)', '(a > 2)', ('break', 'b += 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'b += 1', 'range(1, a)'),
+ (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_complex(self):
+
+ def test_fn(a):
+ b = 0
+ while a > 0:
+ for b in range(0, a):
+ if a > 2:
+ break
+ if a > 3:
+ if a > 4:
+ continue
+ else:
+ max(a)
+ break
+ b += 1
+ else: # for b in range(0, a):
+ return a
+ a = 2
+ for a in range(1, a):
+ return b
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')),
+ (
+ ('(a > 0)', 'continue', 'b += 1'),
+ 'range(0, a)',
+ ('(a > 2)', 'return a'),
+ ),
+ ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')),
+ ('(a > 3)', '(a > 4)', ('continue', 'max(a)')),
+ ('(a > 4)', 'max(a)', 'break'),
+ ('max(a)', 'break', 'a = 2'),
+ ('(a > 4)', 'continue', 'range(0, a)'),
+ ('(a > 3)', 'b += 1', 'range(0, a)'),
+ ('range(0, a)', 'return a', None),
+ ('break', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')),
+ ('range(1, a)', 'return b', None),
+ ('range(1, a)', 'a = 3', None),
+ ),
+ )
+
+ def test_finally_straightline(self):
+
+ def test_fn(a):
+ try:
+ a += 1
+ finally:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'a += 1', 'a = 2'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_return_finally(self):
+
+ def test_fn(a):
+ try:
+ return a
+ finally:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'return a', 'a = 1'),
+ ('return a', 'a = 1', None),
+ (None, 'a = 2', None),
+ ),
+ )
+
+ def test_break_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ break
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'break'),
+ ('(a > 0)', 'break', 'a = 1'),
+ ('break', 'a = 1', None),
+ ),
+ )
+
+ def test_continue_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ continue
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', 'continue'),
+ ('(a > 0)', 'continue', 'a = 1'),
+ ('continue', 'a = 1', '(a > 0)'),
+ ),
+ )
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
new file mode 100644
index 0000000000..ca1441cf6f
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -0,0 +1,38 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "common_transformers",
+ srcs = [
+ "anf.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/autograph/pyct",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "anf_test",
+ srcs = ["anf_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":common_transformers",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
new file mode 100644
index 0000000000..cc039986c2
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Conversion to A-normal form."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import transformer
+
+
+class DummyGensym(object):
+ """A dumb gensym that suffixes a stem by sequential numbers from 1000."""
+
+ def __init__(self, entity_info):
+ del entity_info
+ # A proper implementation needs to account for:
+ # * entity_info.namespace
+ # * all the symbols defined in the AST
+ # * the symbols generated so far
+ self._idx = 0
+
+ def new_name(self, stem):
+ self._idx += 1
+ return stem + '_' + str(1000 + self._idx)
+
+
+class AnfTransformer(transformer.Base):
+ """Performs the actual conversion."""
+
+ # TODO(mdan): Link to a reference.
+ # TODO(mdan): Implement.
+
+ def __init__(self, entity_info):
+ """Creates a transformer.
+
+ Args:
+ entity_info: transformer.EntityInfo
+ """
+ super(AnfTransformer, self).__init__(entity_info)
+ self._gensym = DummyGensym(entity_info)
+
+
+def transform(node, entity_info):
+ return AnfTransformer(entity_info).visit(node)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
new file mode 100644
index 0000000000..81983a5ecb
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for anf module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.contrib.autograph.pyct.common_transformers import anf
+from tensorflow.python.platform import test
+
+
+class AnfTransformerTest(test.TestCase):
+
+ def _simple_source_info(self):
+ return transformer.EntityInfo(
+ source_code=None,
+ source_file=None,
+ namespace=None,
+ arg_values=None,
+ arg_types=None,
+ owner_type=None)
+
+ def test_basic(self):
+
+ def test_function():
+ a = 0
+ return a
+
+ node, _ = parser.parse_entity(test_function)
+ node = anf.transform(node, self._simple_source_info())
+ result, _ = compiler.ast_to_object(node)
+
+ self.assertEqual(test_function(), result.test_function())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/pyct/context.py b/tensorflow/contrib/autograph/pyct/context.py
deleted file mode 100644
index b34015cfd2..0000000000
--- a/tensorflow/contrib/autograph/pyct/context.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Conversion context containers."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-class EntityContext(object):
- """Contains information about an entity, like source code.
-
- In general, objects of this class should be considered immutable.
-
- Attributes:
- namer: Namer that matches the contract of all converters.
- source_code: The entity's source code.
- source_file: The entity's source file.
- namespace: Dict[str->*], containing symbols visible to the entity
- (excluding parameters).
- arg_values: Dict[str->*], containing parameter values, if known.
- arg_types: Dict[str->*], containing parameter types, if known.
- owner_type: The surrounding class type of the function, if present.
- """
-
- # TODO(mdan): Remove the default and update tests.
- def __init__(self, namer, source_code, source_file, namespace, arg_values,
- arg_types, owner_type, recursive, type_annotation_func=None):
- self.namer = namer
- self.source_code = source_code
- self.source_file = source_file
- self.namespace = namespace
- self.arg_values = {} if arg_values is None else arg_values
- self.arg_types = {} if arg_types is None else arg_types
- self.owner_type = owner_type
- self.recursive = recursive
- self.type_annotation_func = type_annotation_func
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/contrib/autograph/pyct/origin_info.py
new file mode 100644
index 0000000000..2b05836e46
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/origin_info.py
@@ -0,0 +1,35 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Container for origin source code information before AutoGraph compilation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+
+class OriginInfo(
+ namedtuple('OriginInfo', ('file_path', 'function_name', 'line_number',
+ 'column_offset', 'source_code_line'))):
+ """Container for information about the source code before conversion.
+
+ Instances of this class contain information about the source code that
+ transformed code originated from. Examples include:
+ * line number
+ * file name
+ * original user code
+ """
+
+ pass
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py
index 583cf7ecd7..da07013cf4 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names.py
@@ -205,6 +205,7 @@ class QnResolver(gast.NodeTransformer):
return node
def visit_Subscript(self, node):
+ # TODO(mdan): This may no longer apply if we overload getitem.
node = self.generic_visit(node)
s = node.slice
if not isinstance(s, gast.Index):
@@ -216,7 +217,11 @@ class QnResolver(gast.NodeTransformer):
elif isinstance(s.value, gast.Str):
subscript = QN(StringLiteral(s.value.s))
else:
- subscript = anno.getanno(node.slice.value, anno.Basic.QN)
+ # The index may be an expression, case in which a name doesn't make sense.
+ if anno.hasanno(node.slice.value, anno.Basic.QN):
+ subscript = anno.getanno(node.slice.value, anno.Basic.QN)
+ else:
+ return node
if anno.hasanno(node.value, anno.Basic.QN):
anno.setanno(node, anno.Basic.QN,
QN(anno.getanno(node.value, anno.Basic.QN),
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
index 8064a967cd..bcf2dacec2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
@@ -27,6 +27,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/contrib/autograph/utils",
"@gast_archive//:gast",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
index c325e19f28..9a82de735d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
@@ -18,10 +18,14 @@ This module contains utilities to help annotate AST nodes with as much runtime
information as can be possibly extracted without actually executing the code,
under that assumption that the context in which the code will run is known.
-Note: It's a fair bet that this analysis cannot be reused across contexts
-without re-running it. In most cases, the context usually means referenced
-modules, which should be static enough to allow reuse, but that is not being
-reliably verified.
+Overall, the different analyses have the functions listed below:
+
+ * activity: inventories symbols read, written to, params, etc. at different
+ levels
+ * liveness, reaching_definitions: dataflow analyses based on the program's CFG
+ and using the symbol information gathered by activity analysis
+ * live_values, type_info: type and value inference based on dataflow
+ analysis and context information
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
index fdbd349af9..bc22be0a27 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
@@ -21,9 +21,9 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.qual_names import QN
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -112,18 +112,16 @@ class ActivityAnalyzerTest(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=None,
+ entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
namespace={},
arg_values=None,
arg_types=None,
- owner_type=None,
- recursive=True)
+ owner_type=None)
node = qual_names.resolve(node)
- node = activity.resolve(node, ctx)
- return node, ctx
+ node = activity.resolve(node, entity_info)
+ return node, entity_info
def test_local_markers(self):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
index ad97fdfa8e..4acc4ed66a 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py
@@ -276,9 +276,9 @@ class Forward(object):
taken).
"""
- def __init__(self, label, context, transfer_fn=operator.or_):
+ def __init__(self, label, source_info, transfer_fn=operator.or_):
self.transfer_fn = transfer_fn
- self.context = context
+ self.source_info = source_info
self.out_label = label + '_out'
self.in_label = label + '_in'
self.gen_label = label + '_gen'
@@ -286,7 +286,7 @@ class Forward(object):
# TODO(alexbw): see if we can simplify by visiting breadth-first
def visit(self, node):
- """Depth-first walking the CFG, applying dataflow information propagtion."""
+ """Depth-first walking the CFG, applying dataflow info propagation."""
# node.value is None only for the exit CfgNode.
if not node.value:
return
@@ -399,18 +399,18 @@ class Liveness(Backward):
later in the program.
"""
- def __init__(self, context):
- super(Liveness, self).__init__('live', context)
+ def __init__(self, source_info):
+ super(Liveness, self).__init__('live', source_info)
def get_gen_kill(self, node, _):
# A variable's parents are live if it is live
# e.g. x is live if x.y is live. This means gen needs to return
# all parents of a variable (if it's an Attribute or Subscript).
# This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x)
- gen = activity.get_read(node.value, self.context)
+ gen = activity.get_read(node.value, self.source_info)
gen = functools.reduce(lambda left, right: left | right.support_set, gen,
gen)
- kill = activity.get_updated(node.value, self.context)
+ kill = activity.get_updated(node.value, self.source_info)
return gen, kill
@@ -420,11 +420,11 @@ class ReachingDefinitions(Forward):
Each statement is annotated with a set of (variable, definition) pairs.
"""
- def __init__(self, context):
- super(ReachingDefinitions, self).__init__('definitions', context)
+ def __init__(self, source_info):
+ super(ReachingDefinitions, self).__init__('definitions', source_info)
def get_gen_kill(self, node, incoming):
- definitions = activity.get_updated(node.value, self.context)
+ definitions = activity.get_updated(node.value, self.source_info)
gen = frozenset((id_, node.value) for id_ in definitions)
kill = frozenset(def_ for def_ in incoming if def_[0] in definitions)
return gen, kill
@@ -437,9 +437,10 @@ class Defined(Forward):
be defined at that point.
"""
- def __init__(self, context):
- super(Defined, self).__init__('defined', context, transfer_fn=operator.and_)
+ def __init__(self, source_info):
+ super(Defined, self).__init__(
+ 'defined', source_info, transfer_fn=operator.and_)
def get_gen_kill(self, node, _):
- gen = activity.get_updated(node.value, self.context)
+ gen = activity.get_updated(node.value, self.source_info)
return gen, frozenset()
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
index fc07fa3447..428ebbedca 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py
@@ -23,29 +23,26 @@ import functools
import gast
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import cfg
from tensorflow.python.platform import test
class CFGTest(test.TestCase):
- def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
- arg_types = arg_types or {}
+ def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=None,
+ entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
- namespace=namespace,
+ namespace={},
arg_values=None,
- arg_types=arg_types,
- owner_type=None,
- recursive=True)
+ arg_types=None,
+ owner_type=None)
node = qual_names.resolve(node)
- return node, ctx
+ return node, entity_info
def _check_anno_matches(self, node, anno_name, var_names):
if isinstance(var_names, str):
@@ -73,7 +70,7 @@ class CFGTest(test.TestCase):
x = x
return x
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
body = node.body[0].body
# Only the argument reaches the expression
@@ -106,7 +103,7 @@ class CFGTest(test.TestCase):
y = 2 # pylint: disable=unused-variable
return x
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.Defined(ctx))
body = node.body[0].body
# only x is for sure defined at the end
@@ -116,7 +113,7 @@ class CFGTest(test.TestCase):
self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y'))
def _get_live_annotated_fnbody(self, f):
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.Liveness(ctx))
body = node.body[0].body
return body
@@ -226,7 +223,7 @@ class CFGTest(test.TestCase):
return g(x)
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.Defined(ctx))
body = node.body[0].body
@@ -253,7 +250,7 @@ class CFGTest(test.TestCase):
return g() # y is not defined here
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.Defined(ctx))
body = node.body[0].body
self.assertEqual(
@@ -282,7 +279,7 @@ class CFGTest(test.TestCase):
return x, y
for f in (for_orelse, while_orelse):
- node, ctx = self._parse_and_analyze(f, {})
+ node, ctx = self._parse_and_analyze(f)
cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
body = node.body[0].body
return_node = body[-1]
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 53ae154590..9ccb98f79a 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -39,7 +39,7 @@ class LiveValueResolver(transformer.Base):
def visit_ClassDef(self, node):
self.generic_visit(node)
- anno.setanno(node, 'live_val', self.context.namespace[node.name])
+ anno.setanno(node, 'live_val', self.entity_info.namespace[node.name])
return node
def visit_Name(self, node):
@@ -55,8 +55,8 @@ class LiveValueResolver(transformer.Base):
if not symbol_is_local and not symbol_is_param:
if node.id in self.literals:
anno.setanno(node, 'live_val', self.literals[node.id])
- elif node.id in self.context.namespace:
- obj = self.context.namespace[node.id]
+ elif node.id in self.entity_info.namespace:
+ obj = self.entity_info.namespace[node.id]
anno.setanno(node, 'live_val', obj)
if hasattr(obj, '__name__'):
anno.setanno(node, 'fqn', (obj.__name__,))
@@ -80,8 +80,8 @@ class LiveValueResolver(transformer.Base):
# TODO(mdan): Use type annotations as fallback.
if not symbol_is_modified:
- if node.id in self.context.arg_values:
- obj = self.context.arg_values[node.id]
+ if node.id in self.entity_info.arg_values:
+ obj = self.entity_info.arg_values[node.id]
anno.setanno(node, 'live_val', obj)
anno.setanno(node, 'fqn', (obj.__class__.__name__,))
return node
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
index 69e428bde1..38af792777 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
@@ -21,9 +21,9 @@ from __future__ import print_function
import six
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis import live_values
from tensorflow.contrib.autograph.pyct.static_analysis import type_info
@@ -39,22 +39,19 @@ class LiveValuesResolverTest(test.TestCase):
literals=None,
arg_types=None):
literals = literals or {}
- arg_types = arg_types or {}
node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=None,
+ entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
namespace=namespace,
arg_values=None,
arg_types=arg_types,
- owner_type=None,
- recursive=True)
+ owner_type=None)
node = qual_names.resolve(node)
- node = activity.resolve(node, ctx)
- node = live_values.resolve(node, ctx, literals)
- node = type_info.resolve(node, ctx)
- node = live_values.resolve(node, ctx, literals)
+ node = activity.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, literals)
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, literals)
return node
def test_literals(self):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index d6555dc7e0..a229c288a8 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -17,8 +17,8 @@
This analyzer uses known live values to further infer object types. This
may include for instance constructed objects and object member functions.
-In addition, the analyzer will also process annotations for TF (staged) type
-annotations.
+In addition, the analyzer also handles user annotations made in the code (for
+example, the autograph.set_element_type function).
Requires annotations generated by LiveValuesResolver.
"""
@@ -43,7 +43,9 @@ from __future__ import print_function
import gast
+from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -51,6 +53,7 @@ from tensorflow.python.util import tf_inspect
# TODO(mdan): Remove the duplication between this and activity.py.
# In particular, the symbol definitions we track here could as well be tracked
# there because they follow the same rules for visibility.
+# TODO(mdan): Use a CFG based Defined analysis instead.
class Scope(object):
"""Tracks symbol value references.
@@ -134,37 +137,40 @@ class TypeInfoResolver(transformer.Base):
node.orelse = self._visit_block(node.orelse)
return node
- def _process_function_arg(self, arg_name):
- str_name = str(arg_name)
- type_holder = arg_name.ast()
- self.scope.setval(arg_name, type_holder)
- if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types:
+ def _process_function_arg(self, arg_node):
+ qn = anno.getanno(arg_node, anno.Basic.QN)
+ arg_name = str(qn)
+ self.scope.setval(qn, arg_node)
+ if (len(self.enclosing_entities) == 1 and
+ arg_name in self.entity_info.arg_types):
# Forge a node to hold the type information, so that method calls on
# it can resolve the type.
- type_string, type_obj = self.context.arg_types[str_name]
- anno.setanno(type_holder, 'type', type_obj)
- anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.')))
+ type_string, type_obj = self.entity_info.arg_types[arg_name]
+ anno.setanno(arg_node, 'type', type_obj)
+ anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))
def visit_arg(self, node):
- self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN))
+ self._process_function_arg(node.arg)
return node
def visit_Name(self, node):
self.generic_visit(node)
- qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Param):
- self._process_function_arg(qn)
- elif isinstance(node.ctx, gast.Load) and self.scope.hasval(qn):
- # E.g. if we had
- # a = b
- # then for future references to `a` we should have definition = `b`
- definition = self.scope.getval(qn)
- if anno.hasanno(definition, 'type'):
- anno.setanno(node, 'type', anno.getanno(definition, 'type'))
- anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn'))
- if anno.hasanno(definition, 'element_type'):
- anno.setanno(node, 'element_type',
- anno.getanno(definition, 'element_type'))
+ self._process_function_arg(node)
+ elif isinstance(node.ctx, gast.Load):
+ qn = anno.getanno(node, anno.Basic.QN)
+ if self.scope.hasval(qn):
+ # E.g. if we had
+ # a = b
+ # then for future references to `a` we should have definition = `b`
+ definition = self.scope.getval(qn)
+ anno.copyanno(definition, node, 'type')
+ anno.copyanno(definition, node, 'type_fqn')
+ anno.setanno(node, 'definition', definition)
+
+ # TODO(mdan): Remove this when the directives module is in.
+ anno.copyanno(definition, node, 'element_type')
+ anno.copyanno(definition, node, 'element_shape')
return node
def _process_variable_assignment(self, target, value):
@@ -204,30 +210,27 @@ class TypeInfoResolver(transformer.Base):
node.targets, node.value, self._process_variable_assignment)
return node
+ # TODO(mdan): Remove as soon as the new directives module is ready.
def visit_Call(self, node):
if anno.hasanno(node.func, 'live_val'):
# Symbols targeted by the "set_type" marker function are assigned the data
# type that it specified.
- if (anno.getanno(node.func, 'live_val') is
- self.context.type_annotation_func):
+ if anno.getanno(node.func, 'live_val') is utils.set_element_type:
- if len(node.args) != 2:
- raise ValueError('"%s" must have exactly two parameters'
+ if len(node.args) < 2 or len(node.args) > 3:
+ raise ValueError('"%s" must have either two or three parameters'
% self.context.type_annotation_func)
- target_arg, type_arg = node.args
- if not anno.hasanno(target_arg, anno.Basic.QN):
- raise ValueError('the first argument of "%s" must by a symbol'
- % self.context.type_annotation_func)
- if isinstance(type_arg, gast.Str):
- element_type = type_arg.s
- elif isinstance(type_arg, gast.Num):
- element_type = type_arg.n
+ if len(node.args) == 2:
+ target_arg, type_arg = node.args
+ shape_arg = parser.parse_expression('None')
else:
- if not anno.hasanno(type_arg, 'live_val'):
- raise ValueError(
- 'the second argument of "%s" must be statically resolvable' %
- self.context.type_annotation_func)
- element_type = anno.getanno(type_arg, 'live_val')
+ target_arg, type_arg, shape_arg = node.args
+ if not anno.hasanno(target_arg, anno.Basic.QN):
+ raise ValueError('the first argument of "%s" must by a symbol' %
+ utils.set_element_type)
+ # TODO(mdan): This is vulnerable to symbol renaming.
+ element_type = type_arg
+ element_shape = shape_arg
target_symbol = anno.getanno(target_arg, anno.Basic.QN)
# Find the definition of this symbol and annotate it with the given
@@ -235,7 +238,9 @@ class TypeInfoResolver(transformer.Base):
# to receive the same type annotation.
definition = self.scope.getval(target_symbol)
anno.setanno(node, 'element_type', element_type)
+ anno.setanno(node, 'element_shape', element_shape)
anno.setanno(definition, 'element_type', element_type)
+ anno.setanno(definition, 'element_shape', element_shape)
# TODO(mdan): Should we update references between definition and here?
return self.generic_visit(node)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
index 95cbf5ca79..32b1148ab2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis import activity
from tensorflow.contrib.autograph.pyct.static_analysis import live_values
from tensorflow.contrib.autograph.pyct.static_analysis import type_info
@@ -62,21 +61,18 @@ class TypeInfoResolverTest(test.TestCase):
namespace,
arg_types=None):
node, source = parser.parse_entity(test_fn)
- ctx = context.EntityContext(
- namer=None,
+ entity_info = transformer.EntityInfo(
source_code=source,
source_file=None,
namespace=namespace,
arg_values=None,
arg_types=arg_types,
- owner_type=None,
- recursive=True,
- type_annotation_func=utils.set_element_type)
+ owner_type=None)
node = qual_names.resolve(node)
- node = activity.resolve(node, ctx)
- node = live_values.resolve(node, ctx, {})
- node = type_info.resolve(node, ctx)
- node = live_values.resolve(node, ctx, {})
+ node = activity.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
return node
def test_constructor_detection(self):
@@ -147,7 +143,7 @@ class TypeInfoResolverTest(test.TestCase):
opt.minimize(0)
node = self._parse_and_analyze(
- test_fn, {'training': training},
+ test_fn, {},
arg_types={
'opt': (training.GradientDescentOptimizer.__name__,
training.GradientDescentOptimizer)
@@ -180,35 +176,6 @@ class TypeInfoResolverTest(test.TestCase):
method_call = node.body[0].body[1].value.func
self.assertFalse(anno.hasanno(method_call, 'live_val'))
- def test_type_annotation(self):
-
- class Foo(object):
- pass
-
- def test_fn():
- f = []
- f = utils.set_element_type(f, Foo)
- return f
-
- node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
- f_def = node.body[0].body[0].value
- self.assertEqual(anno.getanno(f_def, 'element_type'), Foo)
- f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
-
- def test_type_annotation_args(self):
-
- class Foo(object):
- pass
-
- def test_fn(f):
- utils.set_element_type(f, Foo)
- return f
-
- node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils})
- f_ref = node.body[0].body[1].value
- self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo)
-
def test_nested_unpacking(self):
class Foo(object):
@@ -223,32 +190,13 @@ class TypeInfoResolverTest(test.TestCase):
node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
a, b, c = node.body[0].body[1].value.elts
- self.assertEquals(Foo, anno.getanno(a, 'type'))
- self.assertEquals(Bar, anno.getanno(b, 'type'))
- self.assertEquals(Foo, anno.getanno(c, 'type'))
+ self.assertEquals(anno.getanno(a, 'type'), Foo)
+ self.assertEquals(anno.getanno(b, 'type'), Bar)
+ self.assertEquals(anno.getanno(c, 'type'), Foo)
self.assertFalse(anno.hasanno(a, 'live_val'))
self.assertFalse(anno.hasanno(b, 'live_val'))
self.assertFalse(anno.hasanno(c, 'live_val'))
- def test_inner_scope(self):
-
- def test_fn():
- a = []
- utils.set_element_type(a, 1)
- for _ in a:
- b = []
- utils.set_element_type(b, 2)
- return a, b
-
- node = self._parse_and_analyze(test_fn, {'utils': utils})
- a, b = node.body[0].body[2].body[2].value.elts
- self.assertEquals(1, anno.getanno(a, 'element_type'))
- self.assertEquals(2, anno.getanno(b, 'element_type'))
- self.assertFalse(anno.hasanno(a, 'type'))
- self.assertFalse(anno.hasanno(b, 'type'))
- self.assertFalse(anno.hasanno(a, 'live_val'))
- self.assertFalse(anno.hasanno(b, 'live_val'))
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index baf7923fff..9c479ebc2f 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements):
raise ValueError(
'single expression expected; for more general templates use replace')
node = replacement[0]
- if not isinstance(node, gast.Expr):
- raise ValueError(
- 'the template is expected to generate an expression node; instead '
- 'found %s' % node)
- return node.value
+ node = qual_names.resolve(node)
+
+ if isinstance(node, gast.Expr):
+ return node.value
+ elif isinstance(node, gast.Name):
+ return node
+
+ raise ValueError(
+ 'the template is expected to generate an expression or a name node;'
+ ' instead found %s' % node)
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index 60bca8b38d..7655811830 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -32,15 +32,40 @@ class AutographParseError(SyntaxError):
pass
-def try_ast_to_source(node):
- try:
- return compiler.ast_to_source(node)
- except AssertionError:
- return '<could not convert AST to source>'
+# TODO(mdan): Use namedtuple.
+class EntityInfo(object):
+ """Contains information about a Python entity. Immutable.
+
+ Examples of entities include functions and classes.
+
+ Attributes:
+ source_code: The entity's source code.
+ source_file: The entity's source file.
+ namespace: Dict[str, ], containing symbols visible to the entity
+ (excluding parameters).
+ arg_values: dict[str->*], containing parameter values, if known.
+ arg_types: dict[str->*], containing parameter types, if known.
+ owner_type: The surrounding class type of the function, if present.
+ """
+
+ # TODO(mdan): Remove the default and update tests.
+ def __init__(self, source_code, source_file, namespace, arg_values, arg_types,
+ owner_type):
+ self.source_code = source_code
+ self.source_file = source_file
+ self.namespace = namespace
+ self.arg_values = {} if arg_values is None else arg_values
+ self.arg_types = {} if arg_types is None else arg_types
+ self.owner_type = owner_type
class Base(gast.NodeTransformer):
- """Base class for specialized transformers.
+ """Base class for general-purpose code transformers transformers.
+
+ This is an extension of ast.NodeTransformer that provides a few additional
+ functions, like state tracking within the scope of arbitrary node, helpers
+ for processing code blocks, debugging, mapping of transformed code to
+ original code, and others.
Scope-local state tracking: to keep state across nodes, at the level of
(possibly nested) scopes, use enter/exit_local_scope and set/get_local.
@@ -48,15 +73,17 @@ class Base(gast.NodeTransformer):
when they are not properly paired.
"""
- def __init__(self, context):
+ # TODO(mdan): Document all extra features.
+
+ def __init__(self, entity_info):
"""Initialize the transformer. Subclasses should call this.
Args:
- context: An EntityContext.
+ entity_info: An EntityInfo object.
"""
self._lineno = 0
self._col_offset = 0
- self.context = context
+ self.entity_info = entity_info
self._enclosing_entities = []
# A stack that allows keeping mutable, scope-local state where scopes may be
@@ -191,7 +218,7 @@ class Base(gast.NodeTransformer):
# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
def apply_to_single_assignments(self, targets, values, apply_fn):
- """Applies a fuction to each individual assignment.
+ """Applies a function to each individual assignment.
This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
It tries to break down the unpacking if possible. In effect, it has the same
@@ -219,7 +246,7 @@ class Base(gast.NodeTransformer):
targets field of an ast.Assign node.
values: an AST node.
apply_fn: a function of a single argument, which will be called with the
- respective nodes of each single assignment. The signaure is
+ respective nodes of each single assignment. The signature is
apply_fn(target, value), no return value.
"""
if not isinstance(targets, (list, tuple)):
@@ -237,9 +264,15 @@ class Base(gast.NodeTransformer):
# TODO(mdan): Look into allowing to rewrite the AST here.
apply_fn(target, values)
+ def _get_source(self, node):
+ try:
+ return compiler.ast_to_source(node)
+ except AssertionError:
+ return '<could not convert AST to source>'
+
def visit(self, node):
- source_code = self.context.source_code
- source_file = self.context.source_file
+ source_code = self.entity_info.source_code
+ source_file = self.entity_info.source_file
did_enter_function = False
local_scope_size_at_entry = len(self._local_scope_state)
@@ -275,7 +308,7 @@ class Base(gast.NodeTransformer):
except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
- e.__class__.__name__, str(e), try_ast_to_source(node),
+ e.__class__.__name__, str(e), self._get_source(node),
pretty_printer.fmt(node, color=False))
if source_code:
line = source_code.splitlines()[self._lineno - 1]
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py
index f110e79605..baf04653ae 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/contrib/autograph/pyct/transformer_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.platform import test
@@ -29,16 +28,14 @@ from tensorflow.python.platform import test
class TransformerTest(test.TestCase):
- def _context_for_testing(self):
- return context.EntityContext(
- namer=None,
+ def _simple_source_info(self):
+ return transformer.EntityInfo(
source_code=None,
source_file=None,
namespace=None,
arg_values=None,
arg_types=None,
- owner_type=None,
- recursive=False)
+ owner_type=None)
def test_entity_scope_tracking(self):
@@ -55,7 +52,7 @@ class TransformerTest(test.TestCase):
anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
return self.generic_visit(node)
- tr = TestTransformer(self._context_for_testing())
+ tr = TestTransformer(self._simple_source_info())
def test_function():
a = 0
@@ -118,7 +115,7 @@ class TransformerTest(test.TestCase):
def visit_For(self, node):
return self._annotate_result(node)
- tr = TestTransformer(self._context_for_testing())
+ tr = TestTransformer(self._simple_source_info())
def test_function(a):
"""Docstring."""
@@ -157,7 +154,7 @@ class TransformerTest(test.TestCase):
self.exit_local_scope()
return node
- tr = TestTransformer(self._context_for_testing())
+ tr = TestTransformer(self._simple_source_info())
def no_exit(a):
if a > 0:
@@ -196,7 +193,7 @@ class TransformerTest(test.TestCase):
z = y
return z
- tr = TestTransformer(self._context_for_testing())
+ tr = TestTransformer(self._simple_source_info())
node, _ = parser.parse_entity(test_function)
node = tr.visit(node)
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index d3a1b94688..d82c17bf2a 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -33,6 +33,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:list_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python/data/ops:dataset_ops",
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index 211e8eaee9..998087e056 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.contrib.autograph.utils import type_check
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
@@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs):
return dynamic_range(*args, **kwargs)
if f is range:
return dynamic_range(*args, **kwargs)
- raise ValueError('%s is not supported' % f)
+ if f is int:
+ return dynamic_int(*args, **kwargs)
+ if f is float:
+ return dynamic_float(*args, **kwargs)
+
+ raise NotImplementedError(
+ 'The "%s" builtin is not yet supported.' % f.__name__)
def dynamic_len(list_or_tensor):
@@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor):
return len(list_or_tensor)
+def dynamic_int(num_or_tensor, **kwargs):
+ """Implementation of int() using dynamic dispatch."""
+ if tensor_util.is_tensor(num_or_tensor):
+ return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
+ return int(num_or_tensor)
+
+
+def dynamic_float(num_or_tensor, **kwargs):
+ """Implementation of float() using dynamic dispatch."""
+ if tensor_util.is_tensor(num_or_tensor):
+ return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
+ return float(num_or_tensor)
+
+
def dynamic_range(start_or_stop, stop=None, step=None):
"""Implementation of range using dynamic dispatch."""
if type_check.is_tensor(start_or_stop, stop, step):
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
index 163e698407..0c2312178a 100644
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ b/tensorflow/contrib/autograph/utils/builtins_test.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib.autograph.utils import builtins
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
@@ -77,7 +78,7 @@ class BuiltinsTest(test.TestCase):
return x
# Functions that just have the names of builtins are rejected.
- with self.assertRaises(ValueError):
+ with self.assertRaises(NotImplementedError):
self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
if six.PY2:
self.assertListEqual(
@@ -87,6 +88,20 @@ class BuiltinsTest(test.TestCase):
self.assertListEqual(
list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
+ def test_casts(self):
+ i = constant_op.constant(2, dtype=dtypes.int32)
+ f = constant_op.constant(1.0, dtype=dtypes.float32)
+
+ self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
+ self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
+ self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
+ self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
+
+ self.assertEqual(builtins.dynamic_builtin(int, True), 1)
+ self.assertEqual(builtins.dynamic_builtin(int, False), 0)
+ self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
+ self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
+
def test_dynamic_print_tf(self):
try:
out_capturer = six.StringIO()
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index b6dae3cc1f..b27a19b16c 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -50,6 +50,14 @@ cc_library(
)
cc_library(
+ name = "serial_device_batch_scheduler",
+ hdrs = ["serial_device_batch_scheduler.h"],
+ deps = [
+ "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler",
+ ],
+)
+
+cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],
deps = [
diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py
index 44fa5f42a7..1e503a097a 100644
--- a/tensorflow/contrib/batching/__init__.py
+++ b/tensorflow/contrib/batching/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Ops and modules related to batch.
+@@batch_function_v1
@@batch_function
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index 921d6917a4..55faad983f 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
@@ -57,8 +58,6 @@ def batch_function(num_batch_threads,
max_batch_size,
batch_timeout_micros,
allowed_batch_sizes=None,
- grad_timeout_micros=60 * 1000 * 1000,
- unbatch_timeout_micros=60 * 1000 * 1000,
max_enqueued_batches=10):
"""Batches the computation done by the decorated function.
@@ -93,6 +92,66 @@ def batch_function(num_batch_threads,
does nothing. Otherwise, supplies a list of batch sizes, causing the op
to pad batches up to one of those sizes. The entries must increase
monotonically, and the final entry must equal max_batch_size.
+ max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
+
+ Returns:
+ The decorated function will return the unbatched computation output Tensors.
+ """
+
+ def decorator(fn): # pylint: disable=missing-docstring
+
+ def decorated(*args): # pylint: disable=missing-docstring
+ types = [arg.dtype for arg in args]
+
+ @function.Defun(*types)
+ def computation(*computation_args):
+ return fn(*computation_args)
+
+ with ops.name_scope("batch") as name:
+ for a in args:
+ if not isinstance(a, ops.Tensor):
+ raise ValueError("All arguments to functions decorated with "
+ "`batch_function` are supposed to be Tensors; "
+ "found %s" % repr(a))
+ return gen_batch_ops.batch_function(
+ num_batch_threads=num_batch_threads,
+ max_batch_size=max_batch_size,
+ batch_timeout_micros=batch_timeout_micros,
+ allowed_batch_sizes=allowed_batch_sizes,
+ max_enqueued_batches=max_enqueued_batches,
+ shared_name=name,
+ f=computation,
+ in_tensors=list(args),
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ return decorated
+
+ return decorator
+
+
+def batch_function_v1(num_batch_threads,
+ max_batch_size,
+ batch_timeout_micros,
+ allowed_batch_sizes=None,
+ grad_timeout_micros=60 * 1000 * 1000,
+ unbatch_timeout_micros=60 * 1000 * 1000,
+ max_enqueued_batches=10):
+ """Batches the computation done by the decorated function.
+
+ This is the older version of batch_function(). Please use the former instead
+ of this.
+
+ Args:
+ num_batch_threads: Number of scheduling threads for processing batches
+ of work. Determines the number of batches processed in parallel.
+ max_batch_size: Batch sizes will never be bigger than this.
+ batch_timeout_micros: Maximum number of microseconds to wait before
+ outputting an incomplete batch.
+ allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
+ does nothing. Otherwise, supplies a list of batch sizes, causing the op
+ to pad batches up to one of those sizes. The entries must increase
+ monotonically, and the final entry must equal max_batch_size.
grad_timeout_micros: The timeout to use for the gradient. See the
documentation of the unbatch op for more details. Defaults to 60s.
unbatch_timeout_micros: The timeout to use for unbatching. See the
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index e22f978dde..7846814546 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -23,7 +23,10 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -185,12 +188,62 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBasicUnbatchV1Decorated(self):
+ """Tests that the batch_function_v1 decorator works."""
+ with self.test_session() as sess:
+ @batch_ops.batch_function_v1(1, 10, 100000)
+ def computation(in_t):
+ return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
with self.test_session() as sess:
+ # TODO(apassos): Removing this line causes test flakiness! Ideally should
+ # be investigated.
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
+
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchDecoratedWithCapturedInput(self):
+ """Tests that the batch_function decorator works."""
+ with self.test_session() as sess:
+ captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
+ captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+
+ @batch_ops.batch_function(1, 10, 100000)
+ def computation(in_t):
+ return in_t + captured_inp0 - captured_inp1
+
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []
@@ -205,6 +258,114 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBatchFunctionOp(self):
+ """Tests that the batch_function op works."""
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.int32)
+ def computation(in_t):
+ return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = gen_batch_ops.batch_function(
+ [inp],
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000,
+ Tout=[dtypes.int32],
+ f=computation,
+ captured_tensors=computation.captured_inputs)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchFunctionOpWithCapturedInput(self):
+ """Tests that batch_function op works with captured input."""
+ with self.test_session() as sess:
+ captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
+ captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+
+ @function.Defun(dtypes.int32)
+ def computation(inp):
+ return inp + captured_inp0 - captured_inp1
+
+ result = gen_batch_ops.batch_function(
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ allowed_batch_sizes=[3, 10],
+ batching_queue="",
+ f=computation,
+ in_tensors=[inp],
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchFunctionOpWithInputError(self):
+ """Tests that batch_function op works with error in the inputs."""
+ with self.test_session() as sess:
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def computation(in0, in1):
+ return in0 + in1
+
+ result = gen_batch_ops.batch_function(
+ [inp], # computation actually expects 2 inputs.
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ batching_queue="",
+ f=computation,
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ with self.assertRaisesRegexp(InvalidArgumentError,
+ ".*2 arguments.*but 1.*"):
+ sess.run([result], feed_dict={inp: [2]})
+
+ def testBasicUnbatchDecoratedWithReshape(self):
+ """Tests that the batch_function decorator works."""
+ with self.test_session() as sess:
+
+ @batch_ops.batch_function(1, 10, 100000)
+ def computation(in_t):
+ return array_ops.reshape(in_t, [-1]) + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [[2]]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h
new file mode 100644
index 0000000000..bf6b708361
--- /dev/null
+++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h
@@ -0,0 +1,21 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
+
+#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h"
+
+#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index d9e23646d8..9e6a146f67 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
-from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
@@ -256,50 +255,6 @@ class ExpectationTest(test.TestCase):
gradq_approx_kl_normal_normal_,
rtol=0.01, atol=0.)
- def test_docstring_example_gamma(self):
- with self.test_session() as sess:
- num_draws = int(1e5)
- concentration_p = constant_op.constant(1.)
- concentration_q = constant_op.constant(2.)
- p = gamma_lib.Gamma(concentration=concentration_p, rate=1.)
- q = gamma_lib.Gamma(concentration=concentration_q, rate=3.)
- approx_kl_gamma_gamma = monte_carlo_lib.expectation(
- f=lambda x: p.log_prob(x) - q.log_prob(x),
- samples=p.sample(num_draws, seed=42),
- log_prob=p.log_prob,
- use_reparametrization=(p.reparameterization_type
- == distribution_lib.FULLY_REPARAMETERIZED))
- exact_kl_gamma_gamma = kullback_leibler.kl_divergence(p, q)
- [exact_kl_gamma_gamma_, approx_kl_gamma_gamma_] = sess.run([
- exact_kl_gamma_gamma, approx_kl_gamma_gamma])
- self.assertEqual(
- False,
- p.reparameterization_type == distribution_lib.FULLY_REPARAMETERIZED)
- self.assertAllClose(exact_kl_gamma_gamma_, approx_kl_gamma_gamma_,
- rtol=0.01, atol=0.)
-
- # Compare gradients. (Not present in `docstring`.)
- gradp = lambda fp: gradients_impl.gradients(fp, concentration_p)[0]
- gradq = lambda fq: gradients_impl.gradients(fq, concentration_q)[0]
- [
- gradp_exact_kl_gamma_gamma_,
- gradq_exact_kl_gamma_gamma_,
- gradp_approx_kl_gamma_gamma_,
- gradq_approx_kl_gamma_gamma_,
- ] = sess.run([
- gradp(exact_kl_gamma_gamma),
- gradq(exact_kl_gamma_gamma),
- gradp(approx_kl_gamma_gamma),
- gradq(approx_kl_gamma_gamma),
- ])
- # Notice that variance (i.e., `rtol`) is higher when using score-trick.
- self.assertAllClose(gradp_exact_kl_gamma_gamma_,
- gradp_approx_kl_gamma_gamma_,
- rtol=0.05, atol=0.)
- self.assertAllClose(gradq_exact_kl_gamma_gamma_,
- gradq_approx_kl_gamma_gamma_,
- rtol=0.03, atol=0.)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
index 5770bcdd70..68fa415eea 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py
@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Monte Carlo integration and helpers.
-
-See the @{$python/contrib.bayesflow.monte_carlo} guide.
-"""
+"""Monte Carlo integration and helpers."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 032b859d46..68ead2f760 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -192,7 +192,7 @@ def _logspace_mean(log_values):
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
- """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
+ r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
This function computes the Monte-Carlo approximation of an expectation, i.e.,
diff --git a/tensorflow/contrib/bigtable/BUILD b/tensorflow/contrib/bigtable/BUILD
new file mode 100644
index 0000000000..71538e0770
--- /dev/null
+++ b/tensorflow/contrib/bigtable/BUILD
@@ -0,0 +1,213 @@
+# Cloud Bigtable client for TensorFlow
+
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_copts",
+ "tf_custom_op_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_cc_test",
+ "tf_py_test",
+)
+
+tf_custom_op_py_library(
+ name = "bigtable",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ dso = [
+ ":python/ops/_bigtable.so",
+ ],
+ kernels = [
+ ":bigtable_kernels",
+ ":bigtable_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":bigtable_ops",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data",
+ ],
+)
+
+KERNEL_FILES = [
+ "kernels/bigtable_kernels.cc",
+ "kernels/bigtable_lookup_dataset_op.cc",
+ "kernels/bigtable_prefix_key_dataset_op.cc",
+ "kernels/bigtable_range_key_dataset_op.cc",
+ "kernels/bigtable_sample_keys_dataset_op.cc",
+ "kernels/bigtable_sample_key_pairs_dataset_op.cc",
+ "kernels/bigtable_scan_dataset_op.cc",
+]
+
+tf_custom_op_library(
+ name = "python/ops/_bigtable.so",
+ srcs = KERNEL_FILES + [
+ "ops/bigtable_ops.cc",
+ ],
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_range_helpers",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "bigtable_ops",
+ deps = [":bigtable_ops_op_lib"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
+ "bigtable_ops",
+ "bigtable_test_ops",
+ ],
+)
+
+tf_kernel_library(
+ name = "bigtable_kernels",
+ srcs = KERNEL_FILES,
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_range_helpers",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+# A library for use in the bigtable kernels.
+cc_library(
+ name = "bigtable_lib_cc",
+ srcs = ["kernels/bigtable_lib.cc"],
+ hdrs = ["kernels/bigtable_lib.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+cc_library(
+ name = "bigtable_range_helpers",
+ srcs = ["kernels/bigtable_range_helpers.cc"],
+ hdrs = ["kernels/bigtable_range_helpers.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ ],
+)
+
+cc_library(
+ name = "bigtable_test_client",
+ srcs = ["kernels/test_kernels/bigtable_test_client.cc"],
+ hdrs = ["kernels/test_kernels/bigtable_test_client.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "@com_github_googleapis_googleapis//:bigtable_protos",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+tf_cc_test(
+ name = "bigtable_test_client_test",
+ srcs = ["kernels/test_kernels/bigtable_test_client_test.cc"],
+ tags = ["manual"],
+ deps = [
+ ":bigtable_test_client",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/bigtable:bigtable_client",
+ ],
+)
+
+tf_cc_test(
+ name = "bigtable_range_helpers_test",
+ size = "small",
+ srcs = ["kernels/bigtable_range_helpers_test.cc"],
+ deps = [
+ ":bigtable_range_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "bigtable_test_ops",
+ deps = [":bigtable_test_ops_op_lib"],
+)
+
+tf_custom_op_library(
+ name = "python/kernel_tests/_bigtable_test.so",
+ srcs = [
+ "kernels/test_kernels/bigtable_test_client_op.cc",
+ "ops/bigtable_test_ops.cc",
+ ],
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_test_client",
+ "@com_googlesource_code_re2//:re2",
+ ],
+)
+
+# Don't use tf_kernel_library because it prevents access to strings/stringprintf.h
+cc_library(
+ name = "bigtable_test_kernels",
+ srcs = [
+ "kernels/test_kernels/bigtable_test_client_op.cc",
+ ],
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [
+ ":bigtable_lib_cc",
+ ":bigtable_test_client",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@com_googlesource_code_re2//:re2",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "bigtable_test_py",
+ dso = [
+ ":python/kernel_tests/_bigtable_test.so",
+ ],
+ kernels = [
+ ":bigtable_test_kernels",
+ ":bigtable_test_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":bigtable_test_ops",
+ ],
+)
+
+tf_py_test(
+ name = "bigtable_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bigtable_ops_test.py"],
+ additional_deps = [
+ ":bigtable",
+ ":bigtable_test_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+ tags = ["manual"],
+)
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
new file mode 100644
index 0000000000..ef3c60069e
--- /dev/null
+++ b/tensorflow/contrib/bigtable/README.md
@@ -0,0 +1,10 @@
+# Bigtable #
+
+[Google Cloud Bigtable](https://cloud.google.com/bigtable/) is a high
+performance storage system that can store and serve training data. This contrib
+package contains an experimental integration with TensorFlow.
+
+> **Status: Highly experimental.** The current implementation is very much in
+> flux. Please use at your own risk! :-)
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
diff --git a/tensorflow/contrib/bigtable/__init__.py b/tensorflow/contrib/bigtable/__init__.py
new file mode 100644
index 0000000000..7df054637c
--- /dev/null
+++ b/tensorflow/contrib/bigtable/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Cloud Bigtable Client for TensorFlow.
+
+This contrib package allows TensorFlow to interface directly with Cloud Bigtable
+for high-speed data loading.
+
+@@BigtableClient
+@@BigTable
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable
+from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'BigTable',
+ 'BigtableClient',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
new file mode 100644
index 0000000000..70923e6287
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -0,0 +1,355 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+namespace {
+
+class BigtableClientOp : public OpKernel {
+ public:
+ explicit BigtableClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("project_id", &project_id_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("instance_id", &instance_id_));
+ OP_REQUIRES(ctx, !project_id_.empty(),
+ errors::InvalidArgument("project_id must be non-empty"));
+ OP_REQUIRES(ctx, !instance_id_.empty(),
+ errors::InvalidArgument("instance_id must be non-empty"));
+
+ OP_REQUIRES_OK(
+ ctx, ctx->GetAttr("connection_pool_size", &connection_pool_size_));
+ // If left unset by the client code, set it to a default of 100. Note: the
+ // cloud-cpp default of 4 concurrent connections is far too low for high
+ // performance streaming.
+ if (connection_pool_size_ == -1) {
+ connection_pool_size_ = 100;
+ }
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_receive_message_size",
+ &max_receive_message_size_));
+ // If left unset by the client code, set it to a default of 100. Note: the
+ // cloud-cpp default of 4 concurrent connections is far too low for high
+ // performance streaming.
+ if (max_receive_message_size_ == -1) {
+ max_receive_message_size_ = 1 << 24; // 16 MBytes
+ }
+ OP_REQUIRES(ctx, max_receive_message_size_ > 0,
+ errors::InvalidArgument("connection_pool_size must be > 0"));
+ }
+
+ ~BigtableClientOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableClientResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ BigtableClientResource* resource;
+ OP_REQUIRES_OK(
+ ctx,
+ mgr->LookupOrCreate<BigtableClientResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](
+ BigtableClientResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ auto client_options =
+ google::cloud::bigtable::ClientOptions()
+ .set_connection_pool_size(connection_pool_size_)
+ .set_data_endpoint("batch-bigtable.googleapis.com");
+ auto channel_args = client_options.channel_arguments();
+ channel_args.SetMaxReceiveMessageSize(
+ max_receive_message_size_);
+ channel_args.SetUserAgentPrefix("tensorflow");
+ client_options.set_channel_arguments(channel_args);
+ std::shared_ptr<google::cloud::bigtable::DataClient> client =
+ google::cloud::bigtable::CreateDefaultDataClient(
+ project_id_, instance_id_, std::move(client_options));
+ *ret = new BigtableClientResource(project_id_, instance_id_,
+ std::move(client));
+ return Status::OK();
+ }));
+ core::ScopedUnref resource_cleanup(resource);
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableClientResource>()));
+ }
+
+ private:
+ string project_id_;
+ string instance_id_;
+ int64 connection_pool_size_;
+ int32 max_receive_message_size_;
+
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableClient").Device(DEVICE_CPU),
+ BigtableClientOp);
+
+class BigtableTableOp : public OpKernel {
+ public:
+ explicit BigtableTableOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name", &table_));
+ OP_REQUIRES(ctx, !table_.empty(),
+ errors::InvalidArgument("table_name must be non-empty"));
+ }
+
+ ~BigtableTableOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableTableResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+
+ BigtableClientResource* client_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &client_resource));
+ core::ScopedUnref unref_client(client_resource);
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(
+ ctx, mgr->LookupOrCreate<BigtableTableResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, client_resource](BigtableTableResource** ret) {
+ *ret = new BigtableTableResource(client_resource, table_);
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableTableResource>()));
+ }
+
+ private:
+ string table_; // Note: this is const after construction.
+
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
+ BigtableTableOp);
+
+class ToBigtableOp : public AsyncOpKernel {
+ public:
+ explicit ToBigtableOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("to_bigtable_op_", SanitizeThreadSuffix(name())),
+ /* num_threads = */ 1, /* low_latency_hint = */ false)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ thread_pool_->Schedule([this, ctx, done]() {
+ const Tensor* column_families_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("column_families", &column_families_tensor), done);
+ OP_REQUIRES_ASYNC(
+ ctx, column_families_tensor->dims() == 1,
+ errors::InvalidArgument("`column_families` must be a vector."), done);
+
+ const Tensor* columns_tensor;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("columns", &columns_tensor), done);
+ OP_REQUIRES_ASYNC(ctx, columns_tensor->dims() == 1,
+ errors::InvalidArgument("`columns` must be a vector."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx,
+ columns_tensor->NumElements() ==
+ column_families_tensor->NumElements(),
+ errors::InvalidArgument("len(column_families) != len(columns)"),
+ done);
+
+ std::vector<string> column_families;
+ column_families.reserve(column_families_tensor->NumElements());
+ std::vector<string> columns;
+ columns.reserve(column_families_tensor->NumElements());
+ for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) {
+ column_families.push_back(column_families_tensor->flat<string>()(i));
+ columns.push_back(columns_tensor->flat<string>()(i));
+ }
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(1), &dataset), done);
+
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "ToBigtableOpIterator", &iterator),
+ done);
+
+ int64 timestamp_int;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ParseScalarArgument<int64>(ctx, "timestamp", &timestamp_int),
+ done);
+ OP_REQUIRES_ASYNC(ctx, timestamp_int >= -1,
+ errors::InvalidArgument("timestamp must be >= -1"),
+ done);
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &resource), done);
+ core::ScopedUnref resource_cleanup(resource);
+
+ std::vector<Tensor> components;
+ components.reserve(dataset->output_dtypes().size());
+ bool end_of_sequence = false;
+ do {
+ ::google::cloud::bigtable::BulkMutation mutation;
+ // TODO(saeta): Make # of mutations configurable.
+ for (uint64 i = 0; i < 100 && !end_of_sequence; ++i) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
+ done);
+ if (!end_of_sequence) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CreateMutation(std::move(components), column_families, columns,
+ timestamp_int, &mutation),
+ done);
+ }
+ components.clear();
+ }
+ grpc::Status mutation_status;
+ std::vector<::google::cloud::bigtable::FailedMutation> failures =
+ resource->table().BulkApply(std::move(mutation), mutation_status);
+ if (!mutation_status.ok()) {
+ LOG(ERROR) << "Failure applying mutation: "
+ << mutation_status.error_code() << " - "
+ << mutation_status.error_message() << " ("
+ << mutation_status.error_details() << ").";
+ }
+ if (!failures.empty()) {
+ for (const auto& failure : failures) {
+ LOG(ERROR) << "Failure applying mutation on row ("
+ << failure.original_index()
+ << "): " << failure.mutation().row_key()
+ << " - error: " << failure.status().error_message()
+ << " (Details: " << failure.status().error_details()
+ << ").";
+ }
+ }
+ OP_REQUIRES_ASYNC(
+ ctx, failures.empty() && mutation_status.ok(),
+ errors::Unknown("Failure while writing to BigTable: ",
+ mutation_status.error_code(), " - ",
+ mutation_status.error_message(), " (",
+ mutation_status.error_details(),
+ "), # of mutation failures: ", failures.size(),
+ ". See the log for the specific error details."),
+ done);
+ } while (!end_of_sequence);
+ done();
+ });
+ }
+
+ private:
+ static string SanitizeThreadSuffix(string suffix) {
+ string clean;
+ for (int i = 0; i < suffix.size(); ++i) {
+ const char ch = suffix[i];
+ if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
+ (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
+ clean += ch;
+ } else {
+ clean += '_';
+ }
+ }
+ return clean;
+ }
+
+ Status CreateMutation(
+ std::vector<Tensor> tensors, const std::vector<string>& column_families,
+ const std::vector<string>& columns, int64 timestamp_int,
+ ::google::cloud::bigtable::BulkMutation* bulk_mutation) {
+ if (tensors.size() != column_families.size() + 1) {
+ return errors::InvalidArgument(
+ "Iterator produced a set of Tensors shorter than expected");
+ }
+ ::google::cloud::bigtable::SingleRowMutation mutation(
+ std::move(tensors[0].scalar<string>()()));
+ std::chrono::milliseconds timestamp(timestamp_int);
+ for (size_t i = 1; i < tensors.size(); ++i) {
+ if (!TensorShapeUtils::IsScalar(tensors[i].shape())) {
+ return errors::Internal("Output tensor ", i, " was not a scalar");
+ }
+ if (timestamp_int == -1) {
+ mutation.emplace_back(::google::cloud::bigtable::SetCell(
+ column_families[i - 1], columns[i - 1],
+ std::move(tensors[i].scalar<string>()())));
+ } else {
+ mutation.emplace_back(::google::cloud::bigtable::SetCell(
+ column_families[i - 1], columns[i - 1], timestamp,
+ std::move(tensors[i].scalar<string>()())));
+ }
+ }
+ bulk_mutation->emplace_back(std::move(mutation));
+ return Status::OK();
+ }
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
+ ToBigtableOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
new file mode 100644
index 0000000000..2514575f30
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+
+namespace tensorflow {
+
+Status GrpcStatusToTfStatus(const ::grpc::Status& status) {
+ if (status.ok()) {
+ return Status::OK();
+ }
+ auto grpc_code = status.error_code();
+ if (status.error_code() == ::grpc::StatusCode::ABORTED ||
+ status.error_code() == ::grpc::StatusCode::UNAVAILABLE ||
+ status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) {
+ grpc_code = ::grpc::StatusCode::INTERNAL;
+ }
+ return Status(
+ static_cast<::tensorflow::error::Code>(status.error_code()),
+ strings::StrCat("Error reading from BigTable: ", status.error_message(),
+ " (Details: ", status.error_details(), ")"));
+}
+
+string RegexFromStringSet(const std::vector<string>& strs) {
+ CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty.";
+ std::unordered_set<string> uniq(strs.begin(), strs.end());
+ if (uniq.size() == 1) {
+ return *uniq.begin();
+ }
+ return str_util::Join(uniq, "|");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
new file mode 100644
index 0000000000..12d8256dea
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
+
+// Note: we use bigtable/client/internal/table.h as this is the no-exception API
+
+#include "google/cloud/bigtable/data_client.h"
+#include "google/cloud/bigtable/internal/table.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+Status GrpcStatusToTfStatus(const ::grpc::Status& status);
+
+string RegexFromStringSet(const std::vector<string>& strs);
+
+class BigtableClientResource : public ResourceBase {
+ public:
+ BigtableClientResource(
+ string project_id, string instance_id,
+ std::shared_ptr<google::cloud::bigtable::DataClient> client)
+ : project_id_(std::move(project_id)),
+ instance_id_(std::move(instance_id)),
+ client_(std::move(client)) {}
+
+ std::shared_ptr<google::cloud::bigtable::DataClient> get_client() {
+ return client_;
+ }
+
+ string DebugString() override {
+ return strings::StrCat("BigtableClientResource(project_id: ", project_id_,
+ ", instance_id: ", instance_id_, ")");
+ }
+
+ private:
+ const string project_id_;
+ const string instance_id_;
+ std::shared_ptr<google::cloud::bigtable::DataClient> client_;
+};
+
+class BigtableTableResource : public ResourceBase {
+ public:
+ BigtableTableResource(BigtableClientResource* client, string table_name)
+ : client_(client),
+ table_name_(std::move(table_name)),
+ table_(client->get_client(), table_name_) {
+ client_->Ref();
+ }
+
+ ~BigtableTableResource() override { client_->Unref(); }
+
+ ::google::cloud::bigtable::noex::Table& table() { return table_; }
+
+ string DebugString() override {
+ return strings::StrCat(
+ "BigtableTableResource(client: ", client_->DebugString(),
+ ", table: ", table_name_, ")");
+ }
+
+ private:
+ BigtableClientResource* client_; // Ownes one ref.
+ const string table_name_;
+ ::google::cloud::bigtable::noex::Table table_;
+};
+
+// BigtableReaderDatasetIterator is an abstract class for iterators from
+// datasets that are "readers" (source datasets, not transformation datasets)
+// that read from Bigtable.
+template <typename Dataset>
+class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
+ public:
+ explicit BigtableReaderDatasetIterator(
+ const typename DatasetIterator<Dataset>::Params& params)
+ : DatasetIterator<Dataset>(params), iterator_(nullptr, false) {}
+
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureIteratorInitialized());
+ if (iterator_ == reader_->end()) {
+ grpc::Status status = reader_->Finish();
+ if (status.ok()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ return GrpcStatusToTfStatus(status);
+ }
+ *end_of_sequence = false;
+ google::cloud::bigtable::Row& row = *iterator_;
+ Status s = ParseRow(ctx, row, out_tensors);
+ // Ensure we always advance.
+ ++iterator_;
+ return s;
+ }
+
+ protected:
+ virtual ::google::cloud::bigtable::RowRange MakeRowRange() = 0;
+ virtual ::google::cloud::bigtable::Filter MakeFilter() = 0;
+ virtual Status ParseRow(IteratorContext* ctx,
+ const ::google::cloud::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) = 0;
+
+ private:
+ Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (reader_) {
+ return Status::OK();
+ }
+
+ auto rows = MakeRowRange();
+ auto filter = MakeFilter();
+
+ // Note: the this in `this->dataset()` below is necessary due to namespace
+ // name conflicts.
+ reader_.reset(new ::google::cloud::bigtable::RowReader(
+ this->dataset()->table()->table().ReadRows(rows, filter)));
+ iterator_ = reader_->begin();
+ return Status::OK();
+ }
+
+ mutex mu_;
+ std::unique_ptr<::google::cloud::bigtable::RowReader> reader_ GUARDED_BY(mu_);
+ ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
new file mode 100644
index 0000000000..9e49fa35db
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -0,0 +1,221 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ using UnaryDatasetOpKernel::UnaryDatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ BigtableTableResource* table;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table));
+
+ std::vector<string> column_families;
+ std::vector<string> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ OP_REQUIRES(
+ ctx, column_families.size() == columns.size(),
+ errors::InvalidArgument("len(columns) != len(column_families)"));
+
+ const uint64 num_outputs = columns.size() + 1;
+ std::vector<PartialTensorShape> output_shapes;
+ output_shapes.reserve(num_outputs);
+ DataTypeVector output_types;
+ output_types.reserve(num_outputs);
+ for (uint64 i = 0; i < num_outputs; ++i) {
+ output_shapes.push_back({});
+ output_types.push_back(DT_STRING);
+ }
+
+ *output =
+ new Dataset(ctx, input, table, std::move(column_families),
+ std::move(columns), output_types, std::move(output_shapes));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ BigtableTableResource* table,
+ std::vector<string> column_families,
+ std::vector<string> columns,
+ const DataTypeVector& output_types,
+ std::vector<PartialTensorShape> output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ table_(table),
+ column_families_(std::move(column_families)),
+ columns_(std::move(columns)),
+ output_types_(output_types),
+ output_shapes_(std::move(output_shapes)),
+ filter_(MakeFilter(column_families_, columns_)) {
+ table_->Ref();
+ input_->Ref();
+ }
+
+ ~Dataset() override {
+ table_->Unref();
+ input_->Unref();
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableLookupDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "BigtableLookupDatasetOp::Dataset";
+ }
+
+ private:
+ static ::google::cloud::bigtable::Filter MakeFilter(
+ const std::vector<string>& column_families,
+ const std::vector<string>& columns) {
+ string column_family_regex = RegexFromStringSet(column_families);
+ string column_regex = RegexFromStringSet(columns);
+
+ return ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1),
+ ::google::cloud::bigtable::Filter::FamilyRegex(column_family_regex),
+ ::google::cloud::bigtable::Filter::ColumnRegex(column_regex));
+ }
+
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_); // Sequence requests.
+ std::vector<Tensor> input_tensors;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ return Status::OK();
+ }
+ if (input_tensors.size() != 1) {
+ return errors::InvalidArgument(
+ "Upstream iterator (", dataset()->input_->DebugString(),
+ ") did not produce a single `tf.string` `tf.Tensor`. It "
+ "produced ",
+ input_tensors.size(), " tensors.");
+ }
+ if (input_tensors[0].NumElements() == 0) {
+ return errors::InvalidArgument("Upstream iterator (",
+ dataset()->input_->DebugString(),
+ ") return an empty set of keys.");
+ }
+ if (input_tensors[0].NumElements() == 1) {
+ // Single key lookup.
+ ::grpc::Status status;
+ auto pair = dataset()->table_->table().ReadRow(
+ input_tensors[0].scalar<string>()(), dataset()->filter_, status);
+ if (!status.ok()) {
+ return GrpcStatusToTfStatus(status);
+ }
+ if (!pair.first) {
+ return errors::DataLoss("Row key '",
+ input_tensors[0].scalar<string>()(),
+ "' not found.");
+ }
+ TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors));
+ } else {
+ // Batched get.
+ return errors::Unimplemented(
+ "BigtableLookupDataset doesn't yet support batched retrieval.");
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status ParseRow(IteratorContext* ctx,
+ const ::google::cloud::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) {
+ out_tensors->reserve(dataset()->columns_.size() + 1);
+ Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
+ row_key_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(row_key_tensor));
+
+ if (row.cells().size() > 2 * dataset()->columns_.size()) {
+ LOG(WARNING) << "An excessive number of columns ("
+ << row.cells().size()
+ << ") were retrieved when reading row: "
+ << row.row_key();
+ }
+
+ for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
+ Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
+ bool found_column = false;
+ for (auto cell_itr = row.cells().begin();
+ !found_column && cell_itr != row.cells().end(); ++cell_itr) {
+ if (cell_itr->family_name() == dataset()->column_families_[i] &&
+ string(cell_itr->column_qualifier()) ==
+ dataset()->columns_[i]) {
+ col_tensor.scalar<string>()() = string(cell_itr->value());
+ found_column = true;
+ }
+ }
+ if (!found_column) {
+ return errors::DataLoss("Column ", dataset()->column_families_[i],
+ ":", dataset()->columns_[i],
+ " not found in row: ", row.row_key());
+ }
+ out_tensors->emplace_back(std::move(col_tensor));
+ }
+ return Status::OK();
+ }
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ BigtableTableResource* table_;
+ const std::vector<string> column_families_;
+ const std::vector<string> columns_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const ::google::cloud::bigtable::Filter filter_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
+ BigtableLookupDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
new file mode 100644
index 0000000000..e960719614
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ *output = new Dataset(ctx, resource, std::move(prefix));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix)
+ : GraphDatasetBase(ctx), table_(table), prefix_(std::move(prefix)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtablePrefixKeyDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtablePrefixKeyDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::google::cloud::bigtable::RowRange MakeRowRange() override {
+ return ::google::cloud::bigtable::RowRange::Prefix(dataset()->prefix_);
+ }
+ ::google::cloud::bigtable::Filter MakeFilter() override {
+ return ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::CellsRowLimit(1),
+ ::google::cloud::bigtable::Filter::StripValueTransformer());
+ }
+ Status ParseRow(IteratorContext* ctx,
+ const ::google::cloud::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
+ output_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(output_tensor));
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* const table_;
+ const string prefix_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
+ BigtablePrefixKeyDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
new file mode 100644
index 0000000000..51965f6214
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.cc
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+string MakePrefixEndKey(const string& prefix) {
+ string end = prefix;
+ while (true) {
+ if (end.empty()) {
+ return end;
+ }
+ ++end[end.size() - 1];
+ if (end[end.size() - 1] == 0) {
+ // Handle wraparound case.
+ end = end.substr(0, end.size() - 1);
+ } else {
+ return end;
+ }
+ }
+}
+
+} // namespace
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromPrefix(string prefix) {
+ string end = MakePrefixEndKey(prefix);
+ VLOG(1) << "Creating MultiModeKeyRange from Prefix: " << prefix
+ << ", with end key: " << end;
+ return MultiModeKeyRange(std::move(prefix), std::move(end));
+}
+
+/* static */ MultiModeKeyRange MultiModeKeyRange::FromRange(string begin,
+ string end) {
+ return MultiModeKeyRange(std::move(begin), std::move(end));
+}
+
+const string& MultiModeKeyRange::begin_key() const { return begin_; }
+
+const string& MultiModeKeyRange::end_key() const { return end_; }
+
+bool MultiModeKeyRange::contains_key(StringPiece key) const {
+ if (StringPiece(begin_) > key) {
+ return false;
+ }
+ if (StringPiece(end_) <= key && !end_.empty()) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
new file mode 100644
index 0000000000..44c628e366
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Represents a continuous range of keys defined by either a prefix or a range.
+//
+// Ranges are represented as "half-open", where the beginning key is included
+// in the range, and the end_key is the first excluded key after the range.
+//
+// The range of keys can be specified either by a key prefix, or by an explicit
+// begin key and end key. All methods on this class are valid no matter which
+// way the range was specified.
+//
+// Example:
+// MultiModeKeyRange range = MultiModeKeyRange::FromPrefix("myPrefix");
+// if (range.contains_key("myPrefixedKey")) {
+// LOG(INFO) << "range from " << range.begin_key() << " to "
+// << range.end_key() << "contains \"myPrefixedKey\"";
+// }
+// if (!range.contains_key("randomKey")) {
+// LOG(INFO) << "range does not contain \"randomKey\"";
+// }
+// range = MultiModeKeyRange::FromRange("a_start_key", "z_end_key");
+class MultiModeKeyRange {
+ public:
+ static MultiModeKeyRange FromPrefix(string prefix);
+ static MultiModeKeyRange FromRange(string begin, string end);
+
+ // The first valid key in the range.
+ const string& begin_key() const;
+ // The first invalid key after the valid range.
+ const string& end_key() const;
+ // Returns true if the provided key is a part of the range, false otherwise.
+ bool contains_key(StringPiece key) const;
+
+ private:
+ MultiModeKeyRange(string begin, string end)
+ : begin_(std::move(begin)), end_(std::move(end)) {}
+
+ const string begin_;
+ const string end_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_RANGE_HELPERS_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
new file mode 100644
index 0000000000..1bfc547271
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_helpers_test.cc
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(MultiModeKeyRangeTest, SimplePrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("prefix");
+ EXPECT_EQ("prefix", r.begin_key());
+ EXPECT_EQ("prefiy", r.end_key());
+ EXPECT_TRUE(r.contains_key("prefixed_key"));
+ EXPECT_FALSE(r.contains_key("not-prefixed-key"));
+ EXPECT_FALSE(r.contains_key("prefi"));
+ EXPECT_FALSE(r.contains_key("prefiy"));
+ EXPECT_FALSE(r.contains_key("early"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, Range) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("a", "b");
+ EXPECT_EQ("a", r.begin_key());
+ EXPECT_EQ("b", r.end_key());
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("ab"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key("bc"));
+ EXPECT_FALSE(r.contains_key("A"));
+ EXPECT_FALSE(r.contains_key("B"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, InvertedRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("b", "a");
+ EXPECT_FALSE(r.contains_key("a"));
+ EXPECT_FALSE(r.contains_key("b"));
+ EXPECT_FALSE(r.contains_key(""));
+}
+
+TEST(MultiModeKeyRangeTest, EmptyPrefix) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix("");
+ EXPECT_EQ("", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key(""));
+ EXPECT_TRUE(r.contains_key("a"));
+ EXPECT_TRUE(r.contains_key("z"));
+ EXPECT_TRUE(r.contains_key("A"));
+ EXPECT_TRUE(r.contains_key("ZZZZZZ"));
+}
+
+TEST(MultiModeKeyRangeTest, HalfRange) {
+ MultiModeKeyRange r = MultiModeKeyRange::FromRange("start", "");
+ EXPECT_EQ("start", r.begin_key());
+ EXPECT_EQ("", r.end_key());
+ EXPECT_TRUE(r.contains_key("start"));
+ EXPECT_TRUE(r.contains_key("starting"));
+ EXPECT_TRUE(r.contains_key("z-end"));
+ EXPECT_FALSE(r.contains_key(""));
+ EXPECT_FALSE(r.contains_key("early"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixWrapAround) {
+ string prefix = "abc\xff";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abd", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\xff\x07"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x15"));
+ EXPECT_TRUE(r.contains_key("abc\xff\x61"));
+ EXPECT_TRUE(r.contains_key("abc\xff\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abd"));
+}
+
+TEST(MultiModeKeyRangeTest, PrefixSignedWrapAround) {
+ string prefix = "abc\x7f";
+ MultiModeKeyRange r = MultiModeKeyRange::FromPrefix(prefix);
+ EXPECT_EQ(prefix, r.begin_key());
+ EXPECT_EQ("abc\x80", r.end_key());
+
+ EXPECT_TRUE(r.contains_key("abc\x7f\x07"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x15"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\x61"));
+ EXPECT_TRUE(r.contains_key("abc\x7f\xff"));
+ EXPECT_FALSE(r.contains_key("abc\0"));
+ EXPECT_FALSE(r.contains_key("abc\x01"));
+ EXPECT_FALSE(r.contains_key("abd"));
+ EXPECT_FALSE(r.contains_key("ab\x80"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
new file mode 100644
index 0000000000..96d3565d9b
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ *output =
+ new Dataset(ctx, resource, std::move(start_key), std::move(end_key));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string start_key, string end_key)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ start_key_(std::move(start_key)),
+ end_key_(std::move(end_key)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableRangeKeyDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableRangeKeyDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::google::cloud::bigtable::RowRange MakeRowRange() override {
+ return ::google::cloud::bigtable::RowRange::Range(dataset()->start_key_,
+ dataset()->end_key_);
+ }
+ ::google::cloud::bigtable::Filter MakeFilter() override {
+ return ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::CellsRowLimit(1),
+ ::google::cloud::bigtable::Filter::StripValueTransformer());
+ }
+ Status ParseRow(IteratorContext* ctx,
+ const ::google::cloud::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ Tensor output_tensor(ctx->allocator({}), DT_STRING, {});
+ output_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(output_tensor));
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* const table_;
+ const string start_key_;
+ const string end_key_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
+ BigtableRangeKeyDatasetOp);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
new file mode 100644
index 0000000000..a1a63a975a
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/contrib/bigtable/kernels/bigtable_range_helpers.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
+ errors::InvalidArgument(
+ "Only one of prefix and start_key can be provided"));
+ if (!prefix.empty()) {
+ OP_REQUIRES(ctx, end_key.empty(),
+ errors::InvalidArgument(
+ "If prefix is specified, end_key must be empty."));
+ }
+
+ *output = new Dataset(ctx, resource, std::move(prefix),
+ std::move(start_key), std::move(end_key));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix, string start_key, string end_key)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ key_range_(MakeMultiModeKeyRange(
+ std::move(prefix), std::move(start_key), std::move(end_key))) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableSampleKeyPairsDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableSampleKeyPairsDatasetOp::Dataset";
+ }
+
+ private:
+ static MultiModeKeyRange MakeMultiModeKeyRange(string prefix,
+ string start_key,
+ string end_key) {
+ if (!start_key.empty()) {
+ return MultiModeKeyRange::FromRange(std::move(start_key),
+ std::move(end_key));
+ }
+ return MultiModeKeyRange::FromPrefix(std::move(prefix));
+ }
+
+ BigtableTableResource& table() const { return *table_; }
+
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ // Computes split points (`keys_`) to use when scanning the table.
+ //
+ // Initialize first retrieves the sample keys from the table (`row_keys`),
+ // as these often form good split points within the table. We then iterate
+ // over them, and copy them to `keys_` if they fall within the requested
+ // range to scan (`dataset()->key_range_`). Because the requested range
+ // might start between elements of the sampled keys list, care is taken to
+ // ensure we don't accidentally miss any subsets of the requested range by
+ // including `begin_key()` and `end_key()` as appropriate.
+ Status Initialize(IteratorContext* ctx) override {
+ grpc::Status status;
+ std::vector<google::cloud::bigtable::RowKeySample> row_keys =
+ dataset()->table().table().SampleRows(status);
+ if (!status.ok()) {
+ return GrpcStatusToTfStatus(status);
+ }
+
+ for (size_t i = 0; i < row_keys.size(); ++i) {
+ string row_key(row_keys[i].row_key);
+ if (dataset()->key_range_.contains_key(row_key)) {
+ // First key: check to see if we need to add the begin_key.
+ if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+ keys_.push_back(std::move(row_key));
+ } else if (!keys_.empty()) {
+ // If !keys_.empty(), then we have found at least one element of
+ // `row_keys` that is within our requested range
+ // (`dataset()->key_range_`). Because `row_keys` is sorted, if we
+ // have found an element that's not within our key range, then we
+ // are after our requested range (ranges are contiguous) and can end
+ // iteration early.
+ break;
+ }
+ }
+
+ // Handle the case where we skip over the selected range entirely.
+ if (keys_.empty()) {
+ keys_.push_back(dataset()->key_range_.begin_key());
+ }
+
+ // Last key: check to see if we need to add the end_key.
+ if (keys_.back() != dataset()->key_range_.end_key()) {
+ keys_.push_back(dataset()->key_range_.end_key());
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ > keys_.size() - 2) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ *end_of_sequence = false;
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_];
+
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() = keys_[index_ + 1];
+ ++index_;
+
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ size_t index_ GUARDED_BY(mu_) = 0;
+ // Note: we store the keys_ on the iterator instead of the dataset
+ // because we want to re-sample the row keys in case there have been
+ // tablet rebalancing operations since the dataset was created.
+ //
+ // Note: keys_ is readonly after Initialize, and thus does not need a
+ // guarding lock.
+ std::vector<string> keys_;
+ };
+
+ BigtableTableResource* const table_;
+ const MultiModeKeyRange key_range_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BigtableSampleKeyPairsDataset").Device(DEVICE_CPU),
+ BigtableSampleKeyPairsDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
new file mode 100644
index 0000000000..a5a47cfe2d
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -0,0 +1,113 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ *output = new Dataset(ctx, resource);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table)
+ : GraphDatasetBase(ctx), table_(table) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableSampleKeysDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override {
+ return "BigtableRangeKeyDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ ::grpc::Status status;
+ row_keys_ = dataset()->table()->table().SampleRows(status);
+ if (!status.ok()) {
+ row_keys_.clear();
+ return GrpcStatusToTfStatus(status);
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ < row_keys_.size()) {
+ out_tensors->emplace_back(ctx->allocator({}), DT_STRING,
+ TensorShape({}));
+ out_tensors->back().scalar<string>()() =
+ string(row_keys_[index_].row_key);
+ *end_of_sequence = false;
+ index_++;
+ } else {
+ *end_of_sequence = true;
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ size_t index_ = 0;
+ std::vector<::google::cloud::bigtable::RowKeySample> row_keys_;
+ };
+
+ BigtableTableResource* const table_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
+ BigtableSampleKeysDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
new file mode 100644
index 0000000000..13cb868167
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -0,0 +1,219 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BigtableScanDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "prefix", &prefix));
+ string start_key;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "start_key", &start_key));
+ string end_key;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "end_key", &end_key));
+
+ OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()),
+ errors::InvalidArgument(
+ "Either prefix or start_key must be specified"));
+ OP_REQUIRES(ctx, prefix.empty() || start_key.empty(),
+ errors::InvalidArgument(
+ "Only one of prefix and start_key can be provided"));
+ if (!prefix.empty()) {
+ OP_REQUIRES(ctx, end_key.empty(),
+ errors::InvalidArgument(
+ "If prefix is specified, end_key must be empty."));
+ }
+
+ std::vector<string> column_families;
+ std::vector<string> columns;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "column_families",
+ &column_families));
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "columns", &columns));
+ OP_REQUIRES(
+ ctx, column_families.size() == columns.size(),
+ errors::InvalidArgument("len(columns) != len(column_families)"));
+ OP_REQUIRES(ctx, !column_families.empty(),
+ errors::InvalidArgument("`column_families` is empty"));
+
+ float probability = 0;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<float>(ctx, "probability", &probability));
+ OP_REQUIRES(
+ ctx, probability > 0 && probability <= 1,
+ errors::InvalidArgument(
+ "Probability outside the range of (0, 1]. Got: ", probability));
+
+ BigtableTableResource* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+
+ const uint64 num_outputs = columns.size() + 1;
+ std::vector<PartialTensorShape> output_shapes;
+ output_shapes.reserve(num_outputs);
+ DataTypeVector output_types;
+ output_types.reserve(num_outputs);
+ for (uint64 i = 0; i < num_outputs; ++i) {
+ output_shapes.push_back({});
+ output_types.push_back(DT_STRING);
+ }
+
+ *output = new Dataset(ctx, resource, std::move(prefix),
+ std::move(start_key), std::move(end_key),
+ std::move(column_families), std::move(columns),
+ probability, output_types, std::move(output_shapes));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table,
+ string prefix, string start_key, string end_key,
+ std::vector<string> column_families,
+ std::vector<string> columns, float probability,
+ const DataTypeVector& output_types,
+ std::vector<PartialTensorShape> output_shapes)
+ : GraphDatasetBase(ctx),
+ table_(table),
+ prefix_(std::move(prefix)),
+ start_key_(std::move(start_key)),
+ end_key_(std::move(end_key)),
+ column_families_(std::move(column_families)),
+ columns_(std::move(columns)),
+ column_family_regex_(RegexFromStringSet(column_families_)),
+ column_regex_(RegexFromStringSet(columns_)),
+ probability_(probability),
+ output_types_(output_types),
+ output_shapes_(std::move(output_shapes)) {
+ table_->Ref();
+ }
+
+ ~Dataset() override { table_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::BigtableScanDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "BigtableScanDatasetOp::Dataset";
+ }
+
+ BigtableTableResource* table() const { return table_; }
+
+ private:
+ class Iterator : public BigtableReaderDatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : BigtableReaderDatasetIterator<Dataset>(params) {}
+
+ ::google::cloud::bigtable::RowRange MakeRowRange() override {
+ if (!dataset()->prefix_.empty()) {
+ DCHECK(dataset()->start_key_.empty());
+ return ::google::cloud::bigtable::RowRange::Prefix(
+ dataset()->prefix_);
+ } else {
+ DCHECK(!dataset()->start_key_.empty())
+ << "Both prefix and start_key were empty!";
+ return ::google::cloud::bigtable::RowRange::Range(
+ dataset()->start_key_, dataset()->end_key_);
+ }
+ }
+ ::google::cloud::bigtable::Filter MakeFilter() override {
+ // TODO(saeta): Investigate optimal ordering here.
+ return ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1),
+ ::google::cloud::bigtable::Filter::FamilyRegex(
+ dataset()->column_family_regex_),
+ ::google::cloud::bigtable::Filter::ColumnRegex(
+ dataset()->column_regex_),
+ dataset()->probability_ != 1.0
+ ? ::google::cloud::bigtable::Filter::RowSample(
+ dataset()->probability_)
+ : ::google::cloud::bigtable::Filter::PassAllFilter());
+ }
+ Status ParseRow(IteratorContext* ctx,
+ const ::google::cloud::bigtable::Row& row,
+ std::vector<Tensor>* out_tensors) override {
+ out_tensors->reserve(dataset()->columns_.size() + 1);
+ Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {});
+ row_key_tensor.scalar<string>()() = string(row.row_key());
+ out_tensors->emplace_back(std::move(row_key_tensor));
+
+ if (row.cells().size() > 2 * dataset()->columns_.size()) {
+ LOG(WARNING) << "An excessive number of columns ("
+ << row.cells().size()
+ << ") were retrieved when reading row: "
+ << row.row_key();
+ }
+
+ for (uint64 i = 0; i < dataset()->columns_.size(); ++i) {
+ Tensor col_tensor(ctx->allocator({}), DT_STRING, {});
+ bool found_column = false;
+ for (auto cell_itr = row.cells().begin();
+ !found_column && cell_itr != row.cells().end(); ++cell_itr) {
+ if (cell_itr->family_name() == dataset()->column_families_[i] &&
+ string(cell_itr->column_qualifier()) ==
+ dataset()->columns_[i]) {
+ col_tensor.scalar<string>()() = string(cell_itr->value());
+ found_column = true;
+ }
+ }
+ if (!found_column) {
+ return errors::InvalidArgument(
+ "Column ", dataset()->column_families_[i], ":",
+ dataset()->columns_[i], " not found in row: ", row.row_key());
+ }
+ out_tensors->emplace_back(std::move(col_tensor));
+ }
+ return Status::OK();
+ }
+ };
+
+ BigtableTableResource* table_;
+ const string prefix_;
+ const string start_key_;
+ const string end_key_;
+ const std::vector<string> column_families_;
+ const std::vector<string> columns_;
+ const string column_family_regex_;
+ const string column_regex_;
+ const float probability_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
+ BigtableScanDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
new file mode 100644
index 0000000000..f083ce6f44
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
@@ -0,0 +1,374 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+
+#include "google/bigtable/v2/data.pb.h"
+#include "google/protobuf/wrappers.pb.h"
+#include "re2/re2.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/util/ptr_util.h"
+// #include "util/task/codes.pb.h"
+
+namespace tensorflow {
+namespace {
+
+void UpdateRow(const ::google::bigtable::v2::Mutation& mut,
+ std::map<string, string>* row) {
+ if (mut.has_set_cell()) {
+ CHECK(mut.set_cell().timestamp_micros() >= -1)
+ << "Timestamp_micros: " << mut.set_cell().timestamp_micros();
+ auto col =
+ strings::Printf("%s:%s", mut.set_cell().family_name().c_str(),
+ string(mut.set_cell().column_qualifier()).c_str());
+ (*row)[col] = string(mut.set_cell().value());
+ } else if (mut.has_delete_from_column()) {
+ auto col = strings::Printf(
+ "%s:%s", mut.delete_from_column().family_name().c_str(),
+ string(mut.delete_from_column().column_qualifier()).c_str());
+ row->erase(col);
+ } else if (mut.has_delete_from_family()) {
+ auto itr = row->lower_bound(mut.delete_from_family().family_name());
+ auto prefix =
+ strings::Printf("%s:", mut.delete_from_family().family_name().c_str());
+ while (itr != row->end() && itr->first.substr(0, prefix.size()) == prefix) {
+ row->erase(itr);
+ }
+ } else if (mut.has_delete_from_row()) {
+ row->clear();
+ } else {
+ LOG(ERROR) << "Unknown mutation: " << mut.ShortDebugString();
+ }
+}
+
+} // namespace
+
+class SampleRowKeysResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::SampleRowKeysResponse> {
+ public:
+ explicit SampleRowKeysResponse(BigtableTestClient* client)
+ : client_(client) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ mutex_lock l2(client_->mu_);
+ if (num_messages_sent_ * 2 < client_->table_.rows.size()) {
+ *sz = 10000; // A sufficiently high enough value to not worry about.
+ return true;
+ }
+ return false;
+ }
+
+ bool Read(google::bigtable::v2::SampleRowKeysResponse* resp) override {
+ // Send every other key from the table.
+ mutex_lock l(mu_);
+ mutex_lock l2(client_->mu_);
+ *resp = google::bigtable::v2::SampleRowKeysResponse();
+ auto itr = client_->table_.rows.begin();
+ for (uint64 i = 0; i < 2 * num_messages_sent_; ++i) {
+ ++itr;
+ if (itr == client_->table_.rows.end()) {
+ return false;
+ }
+ }
+ resp->set_row_key(itr->first);
+ resp->set_offset_bytes(100 * num_messages_sent_);
+ num_messages_sent_++;
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
+ mutex mu_;
+ int64 num_messages_sent_ GUARDED_BY(mu_) = 0;
+ BigtableTestClient* client_; // Not owned.
+};
+
+class ReadRowsResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::ReadRowsResponse> {
+ public:
+ ReadRowsResponse(BigtableTestClient* client,
+ google::bigtable::v2::ReadRowsRequest const& request)
+ : client_(client), request_(request) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ *sz = 10000000; // A sufficiently high enough value to not worry about.
+ return true;
+ }
+
+ bool Read(google::bigtable::v2::ReadRowsResponse* resp) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ sent_first_message_ = true;
+ RowFilter filter = MakeRowFilter();
+
+ mutex_lock l2(client_->mu_);
+ *resp = google::bigtable::v2::ReadRowsResponse();
+ // Send all contents in first response.
+ for (auto itr = client_->table_.rows.begin();
+ itr != client_->table_.rows.end(); ++itr) {
+ if (filter.AllowRow(itr->first)) {
+ ::google::bigtable::v2::ReadRowsResponse_CellChunk* chunk = nullptr;
+ bool sent_first = false;
+ for (auto col_itr = itr->second.columns.begin();
+ col_itr != itr->second.columns.end(); ++col_itr) {
+ if (filter.AllowColumn(col_itr->first)) {
+ chunk = resp->add_chunks();
+ if (!sent_first) {
+ sent_first = true;
+ chunk->set_row_key(itr->first);
+ }
+ auto colon_idx = col_itr->first.find(":");
+ CHECK(colon_idx != string::npos)
+ << "No ':' found in: " << col_itr->first;
+ chunk->mutable_family_name()->set_value(
+ string(col_itr->first, 0, colon_idx));
+ chunk->mutable_qualifier()->set_value(
+ string(col_itr->first, ++colon_idx));
+ if (!filter.strip_values) {
+ chunk->set_value(col_itr->second);
+ }
+ if (filter.only_one_column) {
+ break;
+ }
+ }
+ }
+ if (sent_first) {
+ // We are sending this row, so set the commit flag on the last chunk.
+ chunk->set_commit_row(true);
+ }
+ }
+ }
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
+ struct RowFilter {
+ std::set<string> row_set;
+ std::vector<std::pair<string, string>> row_ranges;
+ double row_sample = 0.0; // Note: currently ignored.
+ std::unique_ptr<RE2> col_filter;
+ bool strip_values = false;
+ bool only_one_column = false;
+
+ bool AllowRow(const string& row) {
+ if (row_set.find(row) != row_set.end()) {
+ return true;
+ }
+ for (const auto& range : row_ranges) {
+ if (range.first <= row && range.second > row) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool AllowColumn(const string& col) {
+ if (col_filter) {
+ return RE2::FullMatch(col, *col_filter);
+ } else {
+ return true;
+ }
+ }
+ };
+
+ RowFilter MakeRowFilter() {
+ RowFilter filter;
+ for (auto i = request_.rows().row_keys().begin();
+ i != request_.rows().row_keys().end(); ++i) {
+ filter.row_set.insert(string(*i));
+ }
+ for (auto i = request_.rows().row_ranges().begin();
+ i != request_.rows().row_ranges().end(); ++i) {
+ if (i->start_key_case() !=
+ google::bigtable::v2::RowRange::kStartKeyClosed ||
+ i->end_key_case() != google::bigtable::v2::RowRange::kEndKeyOpen) {
+ LOG(WARNING) << "Skipping row range that cannot be processed: "
+ << i->ShortDebugString();
+ continue;
+ }
+ filter.row_ranges.emplace_back(std::make_pair(
+ string(i->start_key_closed()), string(i->end_key_open())));
+ }
+ if (request_.filter().has_chain()) {
+ string family_filter;
+ string qualifier_filter;
+ for (auto i = request_.filter().chain().filters().begin();
+ i != request_.filter().chain().filters().end(); ++i) {
+ switch (i->filter_case()) {
+ case google::bigtable::v2::RowFilter::kFamilyNameRegexFilter:
+ family_filter = i->family_name_regex_filter();
+ break;
+ case google::bigtable::v2::RowFilter::kColumnQualifierRegexFilter:
+ qualifier_filter = i->column_qualifier_regex_filter();
+ break;
+ case google::bigtable::v2::RowFilter::kCellsPerColumnLimitFilter:
+ if (i->cells_per_column_limit_filter() != 1) {
+ LOG(ERROR) << "Unexpected cells_per_column_limit_filter: "
+ << i->cells_per_column_limit_filter();
+ }
+ break;
+ case google::bigtable::v2::RowFilter::kStripValueTransformer:
+ filter.strip_values = i->strip_value_transformer();
+ break;
+ case google::bigtable::v2::RowFilter::kRowSampleFilter:
+ LOG(INFO) << "Ignoring row sample directive.";
+ break;
+ case google::bigtable::v2::RowFilter::kPassAllFilter:
+ break;
+ case google::bigtable::v2::RowFilter::kCellsPerRowLimitFilter:
+ filter.only_one_column = true;
+ break;
+ default:
+ LOG(WARNING) << "Ignoring unknown filter type: "
+ << i->ShortDebugString();
+ }
+ }
+ if (family_filter.empty() || qualifier_filter.empty()) {
+ LOG(WARNING) << "Missing regex!";
+ } else {
+ string regex = strings::Printf("%s:%s", family_filter.c_str(),
+ qualifier_filter.c_str());
+ filter.col_filter.reset(new RE2(regex));
+ }
+ } else {
+ LOG(WARNING) << "Read request did not have a filter chain specified: "
+ << request_.filter().DebugString();
+ }
+ return filter;
+ }
+
+ mutex mu_;
+ bool sent_first_message_ GUARDED_BY(mu_) = false;
+ BigtableTestClient* client_; // Not owned.
+ const google::bigtable::v2::ReadRowsRequest request_;
+};
+
+class MutateRowsResponse : public grpc::ClientReaderInterface<
+ google::bigtable::v2::MutateRowsResponse> {
+ public:
+ explicit MutateRowsResponse(size_t num_successes)
+ : num_successes_(num_successes) {}
+
+ bool NextMessageSize(uint32_t* sz) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ *sz = 10000000; // A sufficiently high enough value to not worry about.
+ return true;
+ }
+
+ bool Read(google::bigtable::v2::MutateRowsResponse* resp) override {
+ mutex_lock l(mu_);
+ if (sent_first_message_) {
+ return false;
+ }
+ sent_first_message_ = true;
+ *resp = google::bigtable::v2::MutateRowsResponse();
+ for (size_t i = 0; i < num_successes_; ++i) {
+ auto entry = resp->add_entries();
+ entry->set_index(i);
+ }
+ return true;
+ }
+
+ grpc::Status Finish() override { return grpc::Status::OK; }
+
+ void WaitForInitialMetadata() override {} // Do nothing.
+
+ private:
+ const size_t num_successes_;
+
+ mutex mu_;
+ bool sent_first_message_ = false;
+};
+
+grpc::Status BigtableTestClient::MutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ google::bigtable::v2::MutateRowResponse* response) {
+ mutex_lock l(mu_);
+ auto* row = &table_.rows[string(request.row_key())];
+ for (int i = 0; i < request.mutations_size(); ++i) {
+ UpdateRow(request.mutations(i), &row->columns);
+ }
+ *response = google::bigtable::v2::MutateRowResponse();
+ return grpc::Status::OK;
+}
+grpc::Status BigtableTestClient::CheckAndMutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::CheckAndMutateRowRequest const& request,
+ google::bigtable::v2::CheckAndMutateRowResponse* response) {
+ return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
+ "CheckAndMutateRow not implemented.");
+}
+grpc::Status BigtableTestClient::ReadModifyWriteRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadModifyWriteRowRequest const& request,
+ google::bigtable::v2::ReadModifyWriteRowResponse* response) {
+ return grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
+ "ReadModifyWriteRow not implemented.");
+}
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
+BigtableTestClient::ReadRows(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadRowsRequest const& request) {
+ return MakeUnique<ReadRowsResponse>(this, request);
+}
+
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
+BigtableTestClient::SampleRowKeys(
+ grpc::ClientContext* context,
+ google::bigtable::v2::SampleRowKeysRequest const& request) {
+ return MakeUnique<SampleRowKeysResponse>(this);
+}
+std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
+BigtableTestClient::MutateRows(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowsRequest const& request) {
+ mutex_lock l(mu_);
+ for (auto i = request.entries().begin(); i != request.entries().end(); ++i) {
+ auto* row = &table_.rows[string(i->row_key())];
+ for (auto mut = i->mutations().begin(); mut != i->mutations().end();
+ ++mut) {
+ UpdateRow(*mut, &row->columns);
+ }
+ }
+ return MakeUnique<MutateRowsResponse>(request.entries_size());
+}
+
+std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() {
+ LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely "
+ "cause a crash!";
+ return nullptr;
+}
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
new file mode 100644
index 0000000000..dac2b16a21
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
+#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
+
+#include "google/cloud/bigtable/data_client.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class BigtableTestClient : public ::google::cloud::bigtable::DataClient {
+ public:
+ std::string const& project_id() const override { return project_id_; }
+ std::string const& instance_id() const override { return instance_id_; }
+ void reset() override {
+ mutex_lock l(mu_);
+ table_ = Table();
+ }
+
+ grpc::Status MutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ google::bigtable::v2::MutateRowResponse* response) override;
+
+ grpc::Status CheckAndMutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::CheckAndMutateRowRequest const& request,
+ google::bigtable::v2::CheckAndMutateRowResponse* response) override;
+
+ grpc::Status ReadModifyWriteRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::ReadModifyWriteRowRequest const& request,
+ google::bigtable::v2::ReadModifyWriteRowResponse* response) override;
+
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::ReadRowsResponse>>
+ ReadRows(grpc::ClientContext* context,
+ google::bigtable::v2::ReadRowsRequest const& request) override;
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::SampleRowKeysResponse>>
+ SampleRowKeys(
+ grpc::ClientContext* context,
+ google::bigtable::v2::SampleRowKeysRequest const& request) override;
+
+ std::unique_ptr<
+ grpc::ClientReaderInterface<google::bigtable::v2::MutateRowsResponse>>
+ MutateRows(grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowsRequest const& request) override;
+
+ std::shared_ptr<grpc::Channel> Channel() override;
+
+ private:
+ friend class SampleRowKeysResponse;
+ friend class ReadRowsResponse;
+ friend class MutateRowsResponse;
+
+ struct Row {
+ string row_key;
+ std::map<string, string> columns;
+ };
+ struct Table {
+ std::map<string, Row> rows;
+ };
+
+ mutex mu_;
+ const std::string project_id_ = "testproject";
+ const std::string instance_id_ = "testinstance";
+ Table table_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_TEST_KERNELS_BIGTABLE_TEST_CLIENT_H_
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc
new file mode 100644
index 0000000000..fa3e587b90
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_op.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace tensorflow {
+
+namespace {
+
+class BigtableTestClientOp : public OpKernel {
+ public:
+ explicit BigtableTestClientOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ ~BigtableTestClientOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<BigtableClientResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ BigtableClientResource* resource;
+ OP_REQUIRES_OK(
+ ctx,
+ mgr->LookupOrCreate<BigtableClientResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](BigtableClientResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::shared_ptr<google::cloud::bigtable::DataClient> client(
+ new BigtableTestClient());
+ // Note: must make explicit copies to sequence
+ // them before the move of client.
+ string project_id = client->project_id();
+ string instance_id = client->instance_id();
+ *ret = new BigtableClientResource(std::move(project_id),
+ std::move(instance_id),
+ std::move(client));
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<BigtableClientResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BigtableTestClient").Device(DEVICE_CPU),
+ BigtableTestClientOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
new file mode 100644
index 0000000000..32611e2590
--- /dev/null
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client_test.cc
@@ -0,0 +1,345 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h"
+#include "google/cloud/bigtable/internal/table.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+void WriteCell(const string& row, const string& family, const string& column,
+ const string& value,
+ ::google::cloud::bigtable::noex::Table* table) {
+ ::google::cloud::bigtable::SingleRowMutation mut(row);
+ mut.emplace_back(::google::cloud::bigtable::SetCell(family, column, value));
+ table->Apply(std::move(mut));
+}
+
+TEST(BigtableTestClientTest, EmptyRowRead) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ ::google::cloud::bigtable::RowSet rowset;
+ rowset.Append("r1");
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ EXPECT_EQ(rows.begin(), rows.end()) << "Some rows were returned in response!";
+ EXPECT_TRUE(rows.Finish().ok()) << "Error reading rows.";
+}
+
+TEST(BigtableTestClientTest, SingleRowWriteAndRead) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+
+ ::google::cloud::bigtable::RowSet rowset("r1");
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+ EXPECT_NE(itr, rows.end()) << "No rows were returned in response!";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end());
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndSingleRowRead) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ ::google::cloud::bigtable::RowSet rowset("r1");
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndRead) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ ::google::cloud::bigtable::RowSet rowset("r1", "r2", "r3");
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1));
+ auto rows = table.ReadRows(std::move(rowset), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, MultiRowWriteAndPrefixRead) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1));
+ auto rows =
+ table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, ColumnFiltering) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ // Extra cells
+ WriteCell("r1", "f2", "c1", "v1", &table);
+ WriteCell("r2", "f2", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c2", "v3", &table);
+
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1),
+ ::google::cloud::bigtable::Filter::FamilyRegex("f1"),
+ ::google::cloud::bigtable::Filter::ColumnRegex("c1"));
+ auto rows =
+ table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v1");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v2");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "v3");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, RowKeys) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+
+ // Extra cells
+ WriteCell("r1", "f2", "c1", "v1", &table);
+ WriteCell("r2", "f2", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c2", "v3", &table);
+
+ auto filter = ::google::cloud::bigtable::Filter::Chain(
+ ::google::cloud::bigtable::Filter::Latest(1),
+ ::google::cloud::bigtable::Filter::CellsRowLimit(1),
+ ::google::cloud::bigtable::Filter::StripValueTransformer());
+ auto rows =
+ table.ReadRows(::google::cloud::bigtable::RowRange::Prefix("r"), filter);
+ auto itr = rows.begin();
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r1");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r2");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+
+ EXPECT_NE(itr, rows.end()) << "Missing rows";
+ EXPECT_EQ(itr->row_key(), "r3");
+ EXPECT_EQ(itr->cells().size(), 1);
+ EXPECT_EQ(itr->cells()[0].family_name(), "f1");
+ EXPECT_EQ(itr->cells()[0].column_qualifier(), "c1");
+ EXPECT_EQ(itr->cells()[0].value(), "");
+
+ ++itr;
+ EXPECT_EQ(itr, rows.end()) << "Extra rows in the response.";
+ EXPECT_TRUE(rows.Finish().ok());
+}
+
+TEST(BigtableTestClientTest, SampleKeys) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+ WriteCell("r4", "f1", "c1", "v4", &table);
+ WriteCell("r5", "f1", "c1", "v5", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(3, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+ EXPECT_EQ(0, resp[0].offset_bytes);
+ EXPECT_EQ("r3", string(resp[1].row_key));
+ EXPECT_EQ(100, resp[1].offset_bytes);
+ EXPECT_EQ("r5", string(resp[2].row_key));
+ EXPECT_EQ(200, resp[2].offset_bytes);
+}
+
+TEST(BigtableTestClientTest, SampleKeysShort) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(1, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+}
+
+TEST(BigtableTestClientTest, SampleKeysEvenNumber) {
+ std::shared_ptr<::google::cloud::bigtable::DataClient> client_ptr =
+ std::make_shared<BigtableTestClient>();
+ ::google::cloud::bigtable::noex::Table table(client_ptr, "test_table");
+
+ WriteCell("r1", "f1", "c1", "v1", &table);
+ WriteCell("r2", "f1", "c1", "v2", &table);
+ WriteCell("r3", "f1", "c1", "v3", &table);
+ WriteCell("r4", "f1", "c1", "v4", &table);
+
+ grpc::Status status;
+ auto resp = table.SampleRows(status);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(2, resp.size());
+ EXPECT_EQ("r1", string(resp[0].row_key));
+ EXPECT_EQ("r3", string(resp[1].row_key));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
new file mode 100644
index 0000000000..416b719e30
--- /dev/null
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -0,0 +1,107 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+// TODO(saeta): Add support for setting ClientOptions values.
+REGISTER_OP("BigtableClient")
+ .Attr("project_id: string")
+ .Attr("instance_id: string")
+ .Attr("connection_pool_size: int")
+ .Attr("max_receive_message_size: int = -1")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("client: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// TODO(saeta): Add support for Application Profiles.
+// See https://cloud.google.com/bigtable/docs/app-profiles for more info.
+REGISTER_OP("BigtableTable")
+ .Input("client: resource")
+ .Attr("table_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("table: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("DatasetToBigtable")
+ .Input("table: resource")
+ .Input("input_dataset: variant")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Input("timestamp: int64")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+REGISTER_OP("BigtableLookupDataset")
+ .Input("keys_dataset: variant")
+ .Input("table: resource")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Output("handle: variant")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtablePrefixKeyDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableRangeKeyDataset")
+ .Input("table: resource")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableSampleKeysDataset")
+ .Input("table: resource")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("BigtableSampleKeyPairsDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// TODO(saeta): Support continuing despite bad data (e.g. empty string, or
+// skip incomplete row.)
+REGISTER_OP("BigtableScanDataset")
+ .Input("table: resource")
+ .Input("prefix: string")
+ .Input("start_key: string")
+ .Input("end_key: string")
+ .Input("column_families: string")
+ .Input("columns: string")
+ .Input("probability: float")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc
new file mode 100644
index 0000000000..f7d02458f6
--- /dev/null
+++ b/tensorflow/contrib/bigtable/ops/bigtable_test_ops.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("BigtableTestClient")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Output("client: resource")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py
new file mode 100644
index 0000000000..292d8f4e51
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains tests for the bigtable integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
new file mode 100644
index 0000000000..2f20064619
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -0,0 +1,272 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Bigtable Ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import bigtable
+from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.bigtable.ops import gen_bigtable_test_ops
+from tensorflow.contrib.bigtable.python.ops import bigtable_api
+from tensorflow.contrib.util import loader
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_bigtable_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_bigtable_test.so"))
+
+
+def _ListOfTuplesOfStringsToBytes(values):
+ return [(compat.as_bytes(i[0]), compat.as_bytes(i[1])) for i in values]
+
+
+class BigtableOpsTest(test.TestCase):
+ COMMON_ROW_KEYS = ["r1", "r2", "r3"]
+ COMMON_VALUES = ["v1", "v2", "v3"]
+
+ def setUp(self):
+ self._client = gen_bigtable_test_ops.bigtable_test_client()
+ table = gen_bigtable_ops.bigtable_table(self._client, "testtable")
+ self._table = bigtable.BigTable("testtable", None, table)
+
+ def _makeSimpleDataset(self):
+ output_rows = dataset_ops.Dataset.from_tensor_slices(self.COMMON_ROW_KEYS)
+ output_values = dataset_ops.Dataset.from_tensor_slices(self.COMMON_VALUES)
+ return dataset_ops.Dataset.zip((output_rows, output_values))
+
+ def _writeCommonValues(self, sess):
+ output_ds = self._makeSimpleDataset()
+ write_op = self._table.write(output_ds, ["cf1"], ["c1"])
+ sess.run(write_op)
+
+ def runReadKeyTest(self, read_ds):
+ itr = read_ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected = list(self.COMMON_ROW_KEYS)
+ expected.reverse()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i in range(3):
+ output = sess.run(n)
+ want = expected.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output),
+ "Unequal at step %d: want: %s, got: %s" % (i, want, output))
+
+ def testReadPrefixKeys(self):
+ self.runReadKeyTest(self._table.keys_by_prefix_dataset("r"))
+
+ def testReadRangeKeys(self):
+ self.runReadKeyTest(self._table.keys_by_range_dataset("r1", "r4"))
+
+ def runScanTest(self, read_ds):
+ itr = read_ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_keys = list(self.COMMON_ROW_KEYS)
+ expected_keys.reverse()
+ expected_values = list(self.COMMON_VALUES)
+ expected_values.reverse()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i in range(3):
+ output = sess.run(n)
+ want = expected_keys.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output[0]),
+ "Unequal keys at step %d: want: %s, got: %s" % (i, want, output[0]))
+ want = expected_values.pop()
+ self.assertEqual(
+ compat.as_bytes(want), compat.as_bytes(output[1]),
+ "Unequal values at step: %d: want: %s, got: %s" % (i, want,
+ output[1]))
+
+ def testScanPrefixStringCol(self):
+ self.runScanTest(self._table.scan_prefix("r", cf1="c1"))
+
+ def testScanPrefixListCol(self):
+ self.runScanTest(self._table.scan_prefix("r", cf1=["c1"]))
+
+ def testScanPrefixTupleCol(self):
+ self.runScanTest(self._table.scan_prefix("r", columns=("cf1", "c1")))
+
+ def testScanRangeStringCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", cf1="c1"))
+
+ def testScanRangeListCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", cf1=["c1"]))
+
+ def testScanRangeTupleCol(self):
+ self.runScanTest(self._table.scan_range("r1", "r4", columns=("cf1", "c1")))
+
+ def testLookup(self):
+ ds = self._table.keys_by_prefix_dataset("r")
+ ds = ds.apply(self._table.lookup_columns(cf1="c1"))
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_keys = list(self.COMMON_ROW_KEYS)
+ expected_values = list(self.COMMON_VALUES)
+ expected_tuples = zip(expected_keys, expected_values)
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i, elem in enumerate(expected_tuples):
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(elem[0]), compat.as_bytes(output[0]),
+ "Unequal keys at step %d: want: %s, got: %s" %
+ (i, compat.as_bytes(elem[0]), compat.as_bytes(output[0])))
+ self.assertEqual(
+ compat.as_bytes(elem[1]), compat.as_bytes(output[1]),
+ "Unequal values at step %d: want: %s, got: %s" %
+ (i, compat.as_bytes(elem[1]), compat.as_bytes(output[1])))
+
+ def testSampleKeys(self):
+ ds = self._table.sample_keys()
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ expected_key = self.COMMON_ROW_KEYS[0]
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(self.COMMON_ROW_KEYS[0]), compat.as_bytes(output),
+ "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
+ self.COMMON_ROW_KEYS[0]), compat.as_bytes(output)))
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(self.COMMON_ROW_KEYS[2]), compat.as_bytes(output),
+ "Unequal keys: want: %s, got: %s" % (compat.as_bytes(
+ self.COMMON_ROW_KEYS[2]), compat.as_bytes(output)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+ def runSampleKeyPairsTest(self, ds, expected_key_pairs):
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ for i, elems in enumerate(expected_key_pairs):
+ output = sess.run(n)
+ self.assertEqual(
+ compat.as_bytes(elems[0]), compat.as_bytes(output[0]),
+ "Unequal key pair (first element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[0]), compat.as_bytes(output[0])))
+ self.assertEqual(
+ compat.as_bytes(elems[1]), compat.as_bytes(output[1]),
+ "Unequal key pair (second element) at step %d; want: %s, got %s" %
+ (i, compat.as_bytes(elems[1]), compat.as_bytes(output[1])))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+ def testSampleKeyPairsSimplePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="")
+ expected_key_pairs = [("r", "r1"), ("r1", "r3"), ("r3", "s")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSimpleRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r1", end="r3")
+ expected_key_pairs = [("r1", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangePrefix(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r2", start="", end="")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsSkipRangeRange(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r3")
+ expected_key_pairs = [("r2", "r3")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsOffsetRanges(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="r2", end="r4")
+ expected_key_pairs = [("r2", "r3"), ("r3", "r4")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairEverything(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="", start="", end="")
+ expected_key_pairs = [("", "r1"), ("r1", "r3"), ("r3", "")]
+ self.runSampleKeyPairsTest(ds, expected_key_pairs)
+
+ def testSampleKeyPairsPrefixAndStartKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="r1", end="")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testSampleKeyPairsPrefixAndEndKey(self):
+ ds = bigtable_api._BigtableSampleKeyPairsDataset(
+ self._table, prefix="r", start="", end="r3")
+ itr = ds.make_initializable_iterator()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(itr.initializer)
+
+ def testParallelScanPrefix(self):
+ ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
+ def testParallelScanRange(self):
+ ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.test_session() as sess:
+ self._writeCommonValues(sess)
+ sess.run(itr.initializer)
+ expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
+ actual_values = []
+ for _ in range(len(expected_values)):
+ output = sess.run(n)
+ actual_values.append(output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+ self.assertItemsEqual(
+ _ListOfTuplesOfStringsToBytes(expected_values),
+ _ListOfTuplesOfStringsToBytes(actual_values))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/bigtable/python/ops/__init__.py b/tensorflow/contrib/bigtable/python/ops/__init__.py
new file mode 100644
index 0000000000..36d75b0d70
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/ops/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""This module contains the Python API for the Cloud Bigtable integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
new file mode 100644
index 0000000000..9f73b7223c
--- /dev/null
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -0,0 +1,741 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Python API for TensorFlow's Bigtable integration.
+
+TensorFlow has support for reading from and writing to Cloud Bigtable. To use
+the Bigtable TensorFlow integration, first create a BigtableClient (which
+configures your connection to Cloud Bigtable), and then open a Table. The Table
+object then allows you to create numerous @{tf.data.Dataset}s to read data, or
+write a @{tf.data.Dataset} object to the underlying Bigtable Table.
+
+For background on Google Cloud Bigtable, see: https://cloud.google.com/bigtable.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six import iteritems
+from six import string_types
+
+from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.util import loader
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import resource_loader
+
+_bigtable_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_bigtable.so"))
+
+
+class BigtableClient(object):
+ """BigtableClient is the entrypoint for interacting with Cloud Bigtable in TF.
+
+ BigtableClient encapsulates a connection to Cloud Bigtable, and exposes the
+ `table` method to open a Bigtable Table.
+ """
+
+ def __init__(self,
+ project_id,
+ instance_id,
+ connection_pool_size=None,
+ max_receive_message_size=None):
+ """Creates a BigtableClient that can be used to open connections to tables.
+
+ Args:
+ project_id: A string representing the GCP project id to connect to.
+ instance_id: A string representing the Bigtable instance to connect to.
+ connection_pool_size: (Optional.) A number representing the number of
+ concurrent connections to the Cloud Bigtable service to make.
+ max_receive_message_size: (Optional.) The maximum bytes received in a
+ single gRPC response.
+
+ Raises:
+ ValueError: if the arguments are invalid (e.g. wrong type, or out of
+ expected ranges (e.g. negative).)
+ """
+ if not isinstance(project_id, str):
+ raise ValueError("`project_id` must be a string")
+ self._project_id = project_id
+
+ if not isinstance(instance_id, str):
+ raise ValueError("`instance_id` must be a string")
+ self._instance_id = instance_id
+
+ if connection_pool_size is None:
+ connection_pool_size = -1
+ elif connection_pool_size < 1:
+ raise ValueError("`connection_pool_size` must be positive")
+
+ if max_receive_message_size is None:
+ max_receive_message_size = -1
+ elif max_receive_message_size < 1:
+ raise ValueError("`max_receive_message_size` must be positive")
+
+ self._connection_pool_size = connection_pool_size
+
+ self._resource = gen_bigtable_ops.bigtable_client(
+ project_id, instance_id, connection_pool_size, max_receive_message_size)
+
+ def table(self, name, snapshot=None):
+ """Opens a table and returns a `BigTable` object.
+
+ Args:
+ name: A `tf.string` `tf.Tensor` name of the table to open.
+ snapshot: Either a `tf.string` `tf.Tensor` snapshot id, or `True` to
+ request the creation of a snapshot. (Note: currently unimplemented.)
+
+ Returns:
+ A `BigTable` python object representing the operations available on the
+ table.
+ """
+ # TODO(saeta): Implement snapshot functionality.
+ table = gen_bigtable_ops.bigtable_table(self._resource, name)
+ return BigTable(name, snapshot, table)
+
+
+class BigTable(object):
+ """BigTable is the entrypoint for reading and writing data in Cloud Bigtable.
+
+ This BigTable class is the python representation of the Cloud Bigtable table
+ within TensorFlow. Methods on this class allow data to be read from and
+ written to the Cloud Bigtable service in flexible and high performance
+ manners.
+ """
+
+ # TODO(saeta): Investigate implementing tf.contrib.lookup.LookupInterface.
+ # TODO(saeta): Consider variant tensors instead of resources (while supporting
+ # connection pooling).
+
+ def __init__(self, name, snapshot, resource):
+ self._name = name
+ self._snapshot = snapshot
+ self._resource = resource
+
+ def lookup_columns(self, *args, **kwargs):
+ """Retrieves the values of columns for a dataset of keys.
+
+ Example usage:
+ ```
+ table = bigtable_client.table("my_table")
+ key_dataset = table.get_keys_prefix("imagenet")
+ images = key_dataset.apply(table.lookup_columns(("cf1", "image"),
+ ("cf2", "label"),
+ ("cf2", "boundingbox")))
+ training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
+ ```
+
+ Alternatively, you can use keyword arguments to specify the columns to
+ capture. Example (same as above, rewritten):
+ ```
+ table = bigtable_client.table("my_table")
+ key_dataset = table.get_keys_prefix("imagenet")
+ images = key_dataset.apply(table.lookup_columns(
+ cf1="image", cf2=("label", "boundingbox")))
+ training_data = images.map(parse_and_crop, num_parallel_calls=64).batch(128)
+ ```
+
+ Note: certain kwargs keys are reserved, and thus some column families cannot
+ be identified using the kwargs syntax. Instead, please use the args syntax.
+ This list includes:
+ - 'name'
+ This list can change at any time.
+
+ Args:
+ *args: A list of tuples containing (column family, column name) pairs.
+ **kwargs: Column families and
+
+ Returns:
+ A function that can be passed to `tf.data.Dataset.apply` to retrieve the
+ values of columns for the rows.
+ """
+ table = self # Capture self
+ normalized = args
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ normalized = list(normalized)
+ for key, value in iteritems(kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, str):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+
+ def _apply_fn(dataset):
+ # TODO(saeta): Verify dataset's types are correct!
+ return _BigtableLookupDataset(dataset, table, normalized)
+
+ return _apply_fn
+
+ def keys_by_range_dataset(self, start, end):
+ """Retrieves all row keys between start and end.
+
+ Note: it does NOT retrieve the values of columns.
+
+ Args:
+ start: The start row key. The row keys for rows after start (inclusive)
+ will be retrieved.
+ end: (Optional.) The end row key. Rows up to (but not including) end will
+ be retrieved. If end is None, all subsequent row keys will be retrieved.
+
+ Returns:
+ A @{tf.data.Dataset} containing `tf.string` Tensors corresponding to all
+ of the row keys between `start` and `end`.
+ """
+ # TODO(saeta): Make inclusive / exclusive configurable?
+ if end is None:
+ end = ""
+ return _BigtableRangeKeyDataset(self, start, end)
+
+ def keys_by_prefix_dataset(self, prefix):
+ """Retrieves the row keys matching a given prefix.
+
+ Args:
+ prefix: All row keys that begin with `prefix` in the table will be
+ retrieved.
+
+ Returns:
+ A @{tf.data.Dataset}. containing `tf.string` Tensors corresponding to all
+ of the row keys matching that prefix.
+ """
+ return _BigtablePrefixKeyDataset(self, prefix)
+
+ def sample_keys(self):
+ """Retrieves a sampling of row keys from the Bigtable table.
+
+ This dataset is most often used in conjunction with
+ @{tf.contrib.data.parallel_interleave} to construct a set of ranges for
+ scanning in parallel.
+
+ Returns:
+ A @{tf.data.Dataset} returning string row keys.
+ """
+ return _BigtableSampleKeysDataset(self)
+
+ def scan_prefix(self, prefix, probability=None, columns=None, **kwargs):
+ """Retrieves row (including values) from the Bigtable service.
+
+ Rows with row-key prefixed by `prefix` will be retrieved.
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.scan_prefix("row_prefix", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ prefix: The prefix all row keys must match to be retrieved for prefix-
+ based scans.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ return _BigtableScanDataset(self, prefix, "", "", normalized, probability)
+
+ def scan_range(self, start, end, probability=None, columns=None, **kwargs):
+ """Retrieves rows (including values) from the Bigtable service.
+
+ Rows with row-keys between `start` and `end` will be retrieved.
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.scan_range("row_start", "row_end", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.scan_range("row_start", "row_end", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ start: The start of the range when scanning by range.
+ end: (Optional.) The end of the range when scanning by range.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ return _BigtableScanDataset(self, "", start, end, normalized, probability)
+
+ def parallel_scan_prefix(self,
+ prefix,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+ """Retrieves row (including values) from the Bigtable service at high speed.
+
+ Rows with row-key prefixed by `prefix` will be retrieved. This method is
+ similar to `scan_prefix`, but by constrast performs multiple sub-scans in
+ parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_prefix("row_prefix", columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_prefix("row_prefix", cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ prefix: The prefix all row keys must match to be retrieved for prefix-
+ based scans.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, prefix, "", "")
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
+
+ def parallel_scan_range(self,
+ start,
+ end,
+ num_parallel_scans=None,
+ probability=None,
+ columns=None,
+ **kwargs):
+ """Retrieves rows (including values) from the Bigtable service.
+
+ Rows with row-keys between `start` and `end` will be retrieved. This method
+ is similar to `scan_range`, but by constrast performs multiple sub-scans in
+ parallel in order to achieve higher performance.
+
+ Note: The dataset produced by this method is not deterministic!
+
+ Specifying the columns to retrieve for each row is done by either using
+ kwargs or in the columns parameter. To retrieve values of the columns "c1",
+ and "c2" from the column family "cfa", and the value of the column "c3"
+ from column family "cfb", the following datasets (`ds1`, and `ds2`) are
+ equivalent:
+
+ ```
+ table = # ...
+ ds1 = table.parallel_scan_range("row_start",
+ "row_end",
+ columns=[("cfa", "c1"),
+ ("cfa", "c2"),
+ ("cfb", "c3")])
+ ds2 = table.parallel_scan_range("row_start", "row_end",
+ cfa=["c1", "c2"], cfb="c3")
+ ```
+
+ Note: only the latest value of a cell will be retrieved.
+
+ Args:
+ start: The start of the range when scanning by range.
+ end: (Optional.) The end of the range when scanning by range.
+ num_parallel_scans: (Optional.) The number of concurrent scans against the
+ Cloud Bigtable instance.
+ probability: (Optional.) A float between 0 (exclusive) and 1 (inclusive).
+ A non-1 value indicates to probabilistically sample rows with the
+ provided probability.
+ columns: The columns to read. Note: most commonly, they are expressed as
+ kwargs. Use the columns value if you are using column families that are
+ reserved. The value of columns and kwargs are merged. Columns is a list
+ of tuples of strings ("column_family", "column_qualifier").
+ **kwargs: The column families and columns to read. Keys are treated as
+ column_families, and values can be either lists of strings, or strings
+ that are treated as the column qualifier (column name).
+
+ Returns:
+ A @{tf.data.Dataset} returning the row keys and the cell contents.
+
+ Raises:
+ ValueError: If the configured probability is unexpected.
+ """
+ probability = _normalize_probability(probability)
+ normalized = _normalize_columns(columns, kwargs)
+ ds = _BigtableSampleKeyPairsDataset(self, "", start, end)
+ return self._make_parallel_scan_dataset(ds, num_parallel_scans, probability,
+ normalized)
+
+ def write(self, dataset, column_families, columns, timestamp=None):
+ """Writes a dataset to the table.
+
+ Args:
+ dataset: A @{tf.data.Dataset} to be written to this table. It must produce
+ a list of number-of-columns+1 elements, all of which must be strings.
+ The first value will be used as the row key, and subsequent values will
+ be used as cell values for the corresponding columns from the
+ corresponding column_families and columns entries.
+ column_families: A @{tf.Tensor} of `tf.string`s corresponding to the
+ column names to store the dataset's elements into.
+ columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
+ to store the dataset's elements into.
+ timestamp: (Optional.) An int64 timestamp to write all the values at.
+ Leave as None to use server-provided timestamps.
+
+ Returns:
+ A @{tf.Operation} that can be run to perform the write.
+
+ Raises:
+ ValueError: If there are unexpected or incompatible types, or if the
+ number of columns and column_families does not match the output of
+ `dataset`.
+ """
+ if timestamp is None:
+ timestamp = -1 # Bigtable server provided timestamp.
+ for tensor_type in nest.flatten(dataset.output_types):
+ if tensor_type != dtypes.string:
+ raise ValueError("Not all elements of the dataset were `tf.string`")
+ for shape in nest.flatten(dataset.output_shapes):
+ if not shape.is_compatible_with(tensor_shape.scalar()):
+ raise ValueError("Not all elements of the dataset were scalars")
+ if len(column_families) != len(columns):
+ raise ValueError("len(column_families) != len(columns)")
+ if len(nest.flatten(dataset.output_types)) != len(columns) + 1:
+ raise ValueError("A column name must be specified for every component of "
+ "the dataset elements. (e.g.: len(columns) != "
+ "len(dataset.output_types))")
+ return gen_bigtable_ops.dataset_to_bigtable(
+ self._resource,
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ column_families,
+ columns,
+ timestamp)
+
+ def _make_parallel_scan_dataset(self, ds, num_parallel_scans,
+ normalized_probability, normalized_columns):
+ """Builds a parallel dataset from a given range.
+
+ Args:
+ ds: A `_BigtableSampleKeyPairsDataset` returning ranges of keys to use.
+ num_parallel_scans: The number of concurrent parallel scans to use.
+ normalized_probability: A number between 0 and 1 for the keep probability.
+ normalized_columns: The column families and column qualifiers to retrieve.
+
+ Returns:
+ A @{tf.data.Dataset} representing the result of the parallel scan.
+ """
+ if num_parallel_scans is None:
+ num_parallel_scans = 50
+
+ ds = ds.shuffle(buffer_size=10000) # TODO(saeta): Make configurable.
+
+ def _interleave_fn(start, end):
+ return _BigtableScanDataset(
+ self,
+ prefix="",
+ start=start,
+ end=end,
+ normalized=normalized_columns,
+ probability=normalized_probability)
+
+ # Note prefetch_input_elements must be set in order to avoid rpc timeouts.
+ ds = ds.apply(
+ interleave_ops.parallel_interleave(
+ _interleave_fn,
+ cycle_length=num_parallel_scans,
+ sloppy=True,
+ prefetch_input_elements=1))
+ return ds
+
+
+def _normalize_probability(probability):
+ if probability is None:
+ probability = 1.0
+ if isinstance(probability, float) and (probability <= 0.0 or
+ probability > 1.0):
+ raise ValueError("probability must be in the range (0, 1].")
+ return probability
+
+
+def _normalize_columns(columns, provided_kwargs):
+ """Converts arguments (columns, and kwargs dict) to C++ representation.
+
+ Args:
+ columns: a datastructure containing the column families and qualifier to
+ retrieve. Valid types include (1) None, (2) list of tuples, (3) a tuple of
+ strings.
+ provided_kwargs: a dictionary containing the column families and qualifiers
+ to retrieve
+
+ Returns:
+ A list of pairs of column family+qualifier to retrieve.
+
+ Raises:
+ ValueError: If there are no cells to retrieve or the columns are in an
+ incorrect format.
+ """
+ normalized = columns
+ if normalized is None:
+ normalized = []
+ if isinstance(normalized, tuple):
+ if len(normalized) == 2:
+ normalized = [normalized]
+ else:
+ raise ValueError("columns was a tuple of inappropriate length")
+ for key, value in iteritems(provided_kwargs):
+ if key == "name":
+ continue
+ if isinstance(value, string_types):
+ normalized.append((key, value))
+ continue
+ for col in value:
+ normalized.append((key, col))
+ if not normalized:
+ raise ValueError("At least one column + column family must be specified.")
+ return normalized
+
+
+class _BigtableKeyDataset(dataset_ops.Dataset):
+ """_BigtableKeyDataset is an abstract class representing the keys of a table.
+ """
+
+ def __init__(self, table):
+ """Constructs a _BigtableKeyDataset.
+
+ Args:
+ table: a Bigtable class.
+ """
+ super(_BigtableKeyDataset, self).__init__()
+ self._table = table
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.TensorShape([])
+
+ @property
+ def output_types(self):
+ return dtypes.string
+
+
+class _BigtablePrefixKeyDataset(_BigtableKeyDataset):
+ """_BigtablePrefixKeyDataset represents looking up keys by prefix.
+ """
+
+ def __init__(self, table, prefix):
+ super(_BigtablePrefixKeyDataset, self).__init__(table)
+ self._prefix = prefix
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_prefix_key_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ prefix=self._prefix)
+
+
+class _BigtableRangeKeyDataset(_BigtableKeyDataset):
+ """_BigtableRangeKeyDataset represents looking up keys by range.
+ """
+
+ def __init__(self, table, start, end):
+ super(_BigtableRangeKeyDataset, self).__init__(table)
+ self._start = start
+ self._end = end
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_range_key_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ start_key=self._start,
+ end_key=self._end)
+
+
+class _BigtableSampleKeysDataset(_BigtableKeyDataset):
+ """_BigtableSampleKeysDataset represents a sampling of row keys.
+ """
+
+ # TODO(saeta): Expose the data size offsets into the keys.
+
+ def __init__(self, table):
+ super(_BigtableSampleKeysDataset, self).__init__(table)
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_sample_keys_dataset(
+ table=self._table._resource) # pylint: disable=protected-access
+
+
+class _BigtableLookupDataset(dataset_ops.Dataset):
+ """_BigtableLookupDataset represents a dataset that retrieves values for keys.
+ """
+
+ def __init__(self, dataset, table, normalized):
+ self._num_outputs = len(normalized) + 1 # 1 for row key
+ self._dataset = dataset
+ self._table = table
+ self._normalized = normalized
+ self._column_families = [i[0] for i in normalized]
+ self._columns = [i[1] for i in normalized]
+
+ @property
+ def output_classes(self):
+ return tuple([ops.Tensor] * self._num_outputs)
+
+ @property
+ def output_shapes(self):
+ return tuple([tensor_shape.TensorShape([])] * self._num_outputs)
+
+ @property
+ def output_types(self):
+ return tuple([dtypes.string] * self._num_outputs)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_bigtable_ops.bigtable_lookup_dataset(
+ keys_dataset=self._dataset._as_variant_tensor(),
+ table=self._table._resource,
+ column_families=self._column_families,
+ columns=self._columns)
+
+
+class _BigtableScanDataset(dataset_ops.Dataset):
+ """_BigtableScanDataset represents a dataset that retrieves keys and values.
+ """
+
+ def __init__(self, table, prefix, start, end, normalized, probability):
+ self._table = table
+ self._prefix = prefix
+ self._start = start
+ self._end = end
+ self._column_families = [i[0] for i in normalized]
+ self._columns = [i[1] for i in normalized]
+ self._probability = probability
+ self._num_outputs = len(normalized) + 1 # 1 for row key
+
+ @property
+ def output_classes(self):
+ return tuple([ops.Tensor] * self._num_outputs)
+
+ @property
+ def output_shapes(self):
+ return tuple([tensor_shape.TensorShape([])] * self._num_outputs)
+
+ @property
+ def output_types(self):
+ return tuple([dtypes.string] * self._num_outputs)
+
+ def _as_variant_tensor(self):
+ return gen_bigtable_ops.bigtable_scan_dataset(
+ table=self._table._resource, # pylint: disable=protected-access
+ prefix=self._prefix,
+ start_key=self._start,
+ end_key=self._end,
+ column_families=self._column_families,
+ columns=self._columns,
+ probability=self._probability)
+
+
+class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+ """_BigtableKeyRangeDataset returns key pairs from the Bigtable.
+ """
+
+ def __init__(self, table, prefix, start, end):
+ self._table = table
+ self._prefix = prefix
+ self._start = start
+ self._end = end
+
+ @property
+ def output_classes(self):
+ return (ops.Tensor, ops.Tensor)
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return (dtypes.string, dtypes.string)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_bigtable_ops.bigtable_sample_key_pairs_dataset(
+ table=self._table._resource,
+ prefix=self._prefix,
+ start_key=self._start,
+ end_key=self._end)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index 8cff1a3bb1..ef0e80cd09 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -15,8 +15,9 @@ py_library(
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
- "custom_export_strategy",
+ ":custom_export_strategy",
":custom_loss_head",
+ ":distillation_loss",
":estimator",
":model",
":trainer_hooks",
@@ -144,6 +145,7 @@ py_library(
srcs = ["dnn_tree_combined_estimator.py"],
srcs_version = "PY2AND3",
deps = [
+ ":distillation_loss",
":estimator_utils",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
@@ -156,6 +158,17 @@ py_library(
],
)
+py_library(
+ name = "distillation_loss",
+ srcs = ["distillation_loss.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/learn",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
py_test(
name = "dnn_tree_combined_estimator_test",
size = "medium",
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py
new file mode 100644
index 0000000000..9aacc55343
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/distillation_loss.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utill functions for distillation loss.
+
+The distillation loss_fn will be called with the following:
+
+Args:
+ dnn_logits: Tensor of logits from the dnn, treated as the "target". This will
+ be the output of a call to tf.stop_gradient().
+ tree_logits: Tensor of logits from the tree, treated as the "predictions".
+ example_weights: Tensor of example weights, or a single scalar.
+
+Returns:
+ A scalar indicating the reduced loss for that batch of examples.
+
+Note: we calls the loss_fn defined in contrib head, which is computing two
+losses, first one for training and second one for reporting. We only take the
+first one here.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+
+def _logits_to_label_for_tree(logits, n_classes):
+ if n_classes == 2:
+ return math_ops.sigmoid(logits)
+ else:
+ return nn.softmax(logits)
+
+
+def create_dnn_to_tree_squared_loss_fn(n_classes):
+ """Returns a squared loss function for dnn to tree distillation."""
+
+ def _dnn_to_tree_squared_loss(dnn_logits, tree_logits, example_weights):
+ return head_lib._mean_squared_loss( # pylint: disable=protected-access
+ labels=_logits_to_label_for_tree(dnn_logits, n_classes),
+ logits=_logits_to_label_for_tree(tree_logits, n_classes),
+ weights=example_weights)[0]
+
+ return _dnn_to_tree_squared_loss
+
+
+def create_dnn_to_tree_cross_entropy_loss_fn(n_classes):
+ """Returns a cross entropy loss function for dnn to tree distillation."""
+
+ def _dnn_to_tree_cross_entropy_loss(dnn_logits, tree_logits, example_weights):
+ if n_classes == 2:
+ return head_lib._log_loss_with_two_classes( # pylint: disable=protected-access
+ labels=_logits_to_label_for_tree(dnn_logits, n_classes),
+ logits=tree_logits,
+ weights=example_weights)[0]
+ else:
+ return head_lib._softmax_cross_entropy_loss( # pylint: disable=protected-access
+ labels=_logits_to_label_for_tree(dnn_logits, n_classes),
+ logits=tree_logits,
+ weights=example_weights)[0]
+
+ return _dnn_to_tree_cross_entropy_loss
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index 758754feac..7eb429b636 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -24,7 +24,9 @@ from __future__ import division
from __future__ import print_function
import six
+
from tensorflow.contrib import layers
+from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -35,11 +37,13 @@ from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
@@ -77,6 +81,7 @@ def _dnn_tree_combined_model_fn(features,
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
use_core_versions=False):
"""DNN and GBDT combined model_fn.
@@ -117,6 +122,13 @@ def _dnn_tree_combined_model_fn(features,
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
@@ -132,6 +144,12 @@ def _dnn_tree_combined_model_fn(features,
if not dnn_feature_columns:
raise ValueError("dnn_feature_columns must be specified")
+ if dnn_to_tree_distillation_param:
+ if not predict_with_tree_only:
+ logging.warning("update predict_with_tree_only to True since distillation"
+ "is specified.")
+ predict_with_tree_only = True
+
# Build DNN Logits.
dnn_parent_scope = "dnn"
dnn_partitioner = dnn_input_layer_partitioner or (
@@ -225,6 +243,25 @@ def _dnn_tree_combined_model_fn(features,
def _tree_train_op_fn(loss):
"""Returns the op to optimize the loss."""
+ if dnn_to_tree_distillation_param:
+ loss_weight, loss_fn = dnn_to_tree_distillation_param
+ weight_tensor = head_lib._weight_tensor( # pylint: disable=protected-access
+ features, head.weight_column_name)
+ dnn_logits_fixed = array_ops.stop_gradient(dnn_logits)
+
+ if loss_fn is None:
+ # we create the loss_fn similar to the head loss_fn for
+ # multi_class_head used previously as the default one.
+ n_classes = 2 if head.logits_dimension == 1 else head.logits_dimension
+ loss_fn = distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn(
+ n_classes)
+
+ dnn_to_tree_distillation_loss = loss_weight * loss_fn(
+ dnn_logits_fixed, tree_logits, weight_tensor)
+ summary.scalar("dnn_to_tree_distillation_loss",
+ dnn_to_tree_distillation_loss)
+ loss += dnn_to_tree_distillation_loss
+
update_op = gbdt_model.train(loss, predictions_dict, labels)
with ops.control_dependencies(
[update_op]), (ops.colocate_with(global_step)):
@@ -232,7 +269,13 @@ def _dnn_tree_combined_model_fn(features,
return update_op
if predict_with_tree_only:
- tree_train_logits = tree_logits
+ if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.INFER:
+ tree_train_logits = tree_logits
+ else:
+ tree_train_logits = control_flow_ops.cond(
+ global_step > dnn_steps_to_train,
+ lambda: tree_logits,
+ lambda: dnn_logits)
else:
tree_train_logits = dnn_logits + tree_logits
@@ -325,6 +368,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
use_core_versions=False):
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
@@ -372,6 +416,13 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
"""
@@ -403,6 +454,7 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
+ dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions)
super(DNNBoostedTreeCombinedClassifier, self).__init__(
@@ -436,6 +488,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
use_core_versions=False):
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
@@ -483,6 +536,13 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
"""
@@ -519,6 +579,7 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
+ dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
@@ -553,6 +614,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
predict_with_tree_only=False,
tree_feature_columns=None,
tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
use_core_versions=False):
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
@@ -595,6 +657,13 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
set to True, these features are in addition to dnn_feature_columns.
tree_center_bias: Whether a separate tree should be created for
first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
"""
@@ -620,6 +689,7 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
predict_with_tree_only=predict_with_tree_only,
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
+ dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
use_core_versions=use_core_versions)
super(DNNBoostedTreeCombinedEstimator, self).__init__(
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index f495edc62f..9b7acfa664 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -131,6 +131,30 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
classifier.fit(input_fn=_train_input_fn, steps=15)
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ def testFitAndEvaluateWithDistillation(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.DNNBoostedTreeCombinedClassifier(
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[feature_column.real_valued_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ n_classes=2,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=False,
+ tree_feature_columns=[feature_column.real_valued_column("x")],
+ dnn_to_tree_distillation_param=(1, None))
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 89d0d611d2..59a78515c6 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
# supports second order derivative.
def loss_fn(labels, logits, weights=None):
result = losses.per_example_maxent_loss(
- labels=labels, logits=logits, weights=weights,
+ labels=labels,
+ logits=logits,
+ weights=weights,
num_classes=n_classes)
return math_ops.reduce_mean(result[0])
else:
@@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
},
model_dir=model_dir,
config=config,
@@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': False,
},
model_dir=model_dir,
config=config,
@@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- use_core_libs=False):
+ use_core_libs=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
the bias.
use_core_libs: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. For example,
+ result_dict = classifier.predict(...)
+ for example_prediction_result in result_dict:
+ # access leaf index list by example_prediction_result["leaf_index"]
+ # which contains one leaf index per tree
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -233,6 +264,92 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
'use_core_libs': use_core_libs,
+ 'output_leaf_index': False,
+ },
+ model_dir=model_dir,
+ config=config,
+ feature_engineering_fn=feature_engineering_fn)
+
+
+class GradientBoostedDecisionTreeRanker(estimator.Estimator):
+ """A ranking estimator using gradient boosted decision trees."""
+
+ def __init__(
+ self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ ):
+ """Initializes a GradientBoostedDecisionTreeRanker instance.
+
+ This is an estimator that can be trained off the pairwise data and can be
+ used for inference on non-paired data. This is essentially LambdaMart.
+ Args:
+ learner_config: A config for the learner.
+ examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ head: `Head` instance.
+ ranking_model_pair_keys: Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ num_trees: An int, number of trees to build.
+ feature_columns: A list of feature columns.
+ weight_column_name: Name of the column for weights, or None if not
+ weighted.
+ model_dir: Directory for model exports, etc.
+ config: `RunConfig` object to configure the runtime settings.
+ label_keys: Optional list of strings with size `[n_classes]` defining the
+ label vocabulary. Only supported for `n_classes` > 2.
+ feature_engineering_fn: Feature engineering function. Takes features and
+ labels which are the output of `input_fn` and returns features and
+ labels which will be fed into the model.
+ logits_modifier_function: A modifier function for the logits.
+ center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
+ output_leaf_index: whether to output leaf indices along with predictions
+ during inference. The leaf node indexes are available in predictions
+ dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is
+ [batch_size, num_trees].
+ For example,
+ result_iter = classifier.predict(...)
+ for result_dict in result_iter:
+ # access leaf index list by result_dict["leaf_index"]
+ # which contains one leaf index per tree
+
+ Raises:
+ ValueError: If learner_config is not valid.
+ """
+ super(GradientBoostedDecisionTreeRanker, self).__init__(
+ model_fn=model.ranking_model_builder,
+ params={
+ 'head': head,
+ 'n_classes': 2,
+ 'feature_columns': feature_columns,
+ 'learner_config': learner_config,
+ 'num_trees': num_trees,
+ 'weight_column_name': weight_column_name,
+ 'examples_per_layer': examples_per_layer,
+ 'center_bias': center_bias,
+ 'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': use_core_libs,
+ 'output_leaf_index': output_leaf_index,
+ 'ranking_model_pair_keys': ranking_model_pair_keys,
},
model_dir=model_dir,
config=config,
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 0d58317bd5..2c2dcb039d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -37,12 +37,31 @@ def _train_input_fn():
return features, label
+def _ranking_train_input_fn():
+ features = {
+ "a.f1": constant_op.constant([[3.], [0.3], [1.]]),
+ "a.f2": constant_op.constant([[0.1], [3.], [1.]]),
+ "b.f1": constant_op.constant([[13.], [0.4], [5.]]),
+ "b.f2": constant_op.constant([[1.], [3.], [0.01]]),
+ }
+ label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32)
+ return features, label
+
+
def _eval_input_fn():
features = {"x": constant_op.constant([[1.], [2.], [2.]])}
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
return features, label
+def _infer_ranking_train_input_fn():
+ features = {
+ "f1": constant_op.constant([[3.], [2], [1.]]),
+ "f2": constant_op.constant([[0.1], [3.], [1.]])
+ }
+ return features, None
+
+
class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -68,6 +87,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
classifier.export(self._export_dir_base)
+ def testThatLeafIndexIsInPredictions(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=True)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ result_iter = classifier.predict(input_fn=_eval_input_fn)
+ for prediction_dict in result_iter:
+ self.assertTrue("leaf_index" in prediction_dict)
+ self.assertTrue("logits" in prediction_dict)
+
def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
@@ -133,6 +174,34 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
regressor.evaluate(input_fn=_eval_input_fn, steps=1)
regressor.export(self._export_dir_base)
+ def testRankingDontThrowExceptionForForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ model = estimator.GradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ use_core_libs=True,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ model.fit(input_fn=_ranking_train_input_fn, steps=1000)
+ model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ model.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 15ab6d8145..0e8a56e6e9 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import copy
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -28,7 +29,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
-
def model_builder(features, labels, mode, params, config):
"""Multi-machine batch gradient descent tree model.
@@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config):
num_trees = params["num_trees"]
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
+ output_leaf_index = params["output_leaf_index"]
+
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config):
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
features=training_features,
- use_core_columns=use_core_libs)
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config):
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
if num_trees:
if center_bias:
num_trees += 1
@@ -135,3 +141,184 @@ def model_builder(features, labels, mode, params, config):
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
finalized_trees))
return model_fn_ops
+
+
+def ranking_model_builder(features, labels, mode, params, config):
+ """Multi-machine batch gradient descent tree model for ranking.
+
+ Args:
+ features: `Tensor` or `dict` of `Tensor` objects.
+ labels: Labels used to train on.
+ mode: Mode we are in. (TRAIN/EVAL/INFER)
+ params: A dict of hyperparameters.
+ The following hyperparameters are expected:
+ * head: A `Head` instance.
+ * learner_config: A config for the learner.
+ * feature_columns: An iterable containing all the feature columns used by
+ the model.
+ * examples_per_layer: Number of examples to accumulate before growing a
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
+ * weight_column_name: The name of weight column.
+ * center_bias: Whether a separate tree should be created for first fitting
+ the bias.
+ * ranking_model_pair_keys (Optional): Keys to distinguish between features
+ for left and right part of the training pairs for ranking. For example,
+ for an Example with features "a.f1" and "b.f1", the keys would be
+ ("a", "b").
+ config: `RunConfig` of the estimator.
+
+ Returns:
+ A `ModelFnOps` object.
+ Raises:
+ ValueError: if inputs are not valid.
+ """
+ head = params["head"]
+ learner_config = params["learner_config"]
+ examples_per_layer = params["examples_per_layer"]
+ feature_columns = params["feature_columns"]
+ weight_column_name = params["weight_column_name"]
+ num_trees = params["num_trees"]
+ use_core_libs = params["use_core_libs"]
+ logits_modifier_function = params["logits_modifier_function"]
+ output_leaf_index = params["output_leaf_index"]
+ ranking_model_pair_keys = params["ranking_model_pair_keys"]
+
+ if features is None:
+ raise ValueError("At least one feature must be specified.")
+
+ if config is None:
+ raise ValueError("Missing estimator RunConfig.")
+
+ center_bias = params["center_bias"]
+
+ if isinstance(features, ops.Tensor):
+ features = {features.name: features}
+
+ # Make a shallow copy of features to ensure downstream usage
+ # is unaffected by modifications in the model function.
+ training_features = copy.copy(features)
+ training_features.pop(weight_column_name, None)
+ global_step = training_util.get_global_step()
+ with ops.device(global_step.device):
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config="", # Initialize an empty ensemble.
+ name="ensemble_model")
+
+ # Extract the features.
+ if mode == learn.ModeKeys.TRAIN or mode == learn.ModeKeys.EVAL:
+ # For ranking pairwise training, we extract two sets of features.
+ if len(ranking_model_pair_keys) != 2:
+ raise ValueError("You must provide keys for ranking.")
+ left_pair_key = ranking_model_pair_keys[0]
+ right_pair_key = ranking_model_pair_keys[1]
+ if left_pair_key is None or right_pair_key is None:
+ raise ValueError("Both pair keys should be provided for ranking.")
+
+ features_1 = {}
+ features_2 = {}
+ for name in training_features:
+ feature = training_features[name]
+ new_name = name[2:]
+ if name.startswith(left_pair_key + "."):
+ features_1[new_name] = feature
+ else:
+ assert name.startswith(right_pair_key + ".")
+ features_2[new_name] = feature
+
+ main_features = features_1
+ supplementary_features = features_2
+ else:
+ # For non-ranking or inference ranking, we have only 1 set of features.
+ main_features = training_features
+
+ # Create GBDT model.
+ gbdt_model_main = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=main_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ with ops.name_scope("gbdt", "gbdt_optimizer"):
+ # Logits for inference.
+ if mode == learn.ModeKeys.INFER:
+ predictions_dict = gbdt_model_main.predict(mode)
+ logits = predictions_dict[gbdt_batch.PREDICTIONS]
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+ else:
+ gbdt_model_supplementary = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=config.is_chief,
+ num_ps_replicas=config.num_ps_replicas,
+ ensemble_handle=ensemble_handle,
+ center_bias=center_bias,
+ examples_per_layer=examples_per_layer,
+ learner_config=learner_config,
+ feature_columns=feature_columns,
+ logits_dimension=head.logits_dimension,
+ features=supplementary_features,
+ use_core_columns=use_core_libs,
+ output_leaf_index=output_leaf_index)
+
+ # Logits for train and eval.
+ if not supplementary_features:
+ raise ValueError("Features for ranking must be specified.")
+
+ predictions_dict_1 = gbdt_model_main.predict(mode)
+ predictions_1 = predictions_dict_1[gbdt_batch.PREDICTIONS]
+
+ predictions_dict_2 = gbdt_model_supplementary.predict(mode)
+ predictions_2 = predictions_dict_2[gbdt_batch.PREDICTIONS]
+
+ logits = predictions_1 - predictions_2
+ if logits_modifier_function:
+ logits = logits_modifier_function(logits, features, mode)
+
+ predictions_dict = predictions_dict_1
+ predictions_dict[gbdt_batch.PREDICTIONS] = logits
+
+ def _train_op_fn(loss):
+ """Returns the op to optimize the loss."""
+ update_op = gbdt_model_main.train(loss, predictions_dict, labels)
+ with ops.control_dependencies(
+ [update_op]), (ops.colocate_with(global_step)):
+ update_op = state_ops.assign_add(global_step, 1).op
+ return update_op
+
+ create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+ if num_trees:
+ if center_bias:
+ num_trees += 1
+ finalized_trees, attempted_trees = (
+ gbdt_model_main.get_number_of_trees_tensor())
+ model_fn_ops.training_hooks.append(
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees))
+ return model_fn_ops
diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
index b3fe38614e..9493c1a139 100644
--- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
@@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout";
const char* kApplyAveragingAttributeName = "apply_averaging";
const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights";
const char* kPredictionsTensorName = "predictions";
+const char* kLeafIndexTensorName = "leaf_index";
void CalculateTreesToInclude(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
@@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel {
core::ScopedUnref unref_me(ensemble_resource);
if (use_locking_) {
tf_shared_lock l(*ensemble_resource->get_mutex());
- DoCompute(context, ensemble_resource);
+ DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/false);
} else {
- DoCompute(context, ensemble_resource);
+ DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/false);
}
}
- private:
- void DoCompute(OpKernelContext* context,
- DecisionTreeEnsembleResource* ensemble_resource) {
+ protected:
+ // return_output_leaf_index is a boolean variable indicating whether to output
+ // leaf index in prediction. Though this class invokes only with this param
+ // value as false, the subclass GradientTreesPredictionVerboseOp will invoke
+ // with the true value.
+ virtual void DoCompute(OpKernelContext* context,
+ DecisionTreeEnsembleResource* ensemble_resource,
+ const bool return_output_leaf_index) {
// Read dense float features list;
OpInputList dense_float_features_list;
OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures(
@@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel {
&output_predictions_t));
auto output_predictions = output_predictions_t->matrix<float>();
+ // Allocate output leaf index matrix.
+ Tensor* output_leaf_index_t = nullptr;
+ if (return_output_leaf_index) {
+ OP_REQUIRES_OK(context, context->allocate_output(
+ kLeafIndexTensorName,
+ {batch_size, ensemble_resource->num_trees()},
+ &output_leaf_index_t));
+ }
// Run predictor.
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
@@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel {
i, weight * (num_ensembles - i + start_averaging) / num_ensembles);
}
MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features,
- worker_threads, output_predictions);
+ worker_threads, output_predictions,
+ output_leaf_index_t);
} else {
MultipleAdditiveTrees::Predict(
ensemble_resource->decision_tree_ensemble(), trees_to_include,
- batch_features, worker_threads, output_predictions);
+ batch_features, worker_threads, output_predictions,
+ output_leaf_index_t);
}
// Output dropped trees and original weights.
@@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel {
{2, static_cast<int64>(dropped_trees.size())},
&output_dropout_info_t));
auto output_dropout_info = output_dropout_info_t->matrix<float>();
-
for (int32 i = 0; i < dropped_trees.size(); ++i) {
output_dropout_info(0, i) = dropped_trees[i];
output_dropout_info(1, i) = original_weights[i];
@@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU),
GradientTreesPredictionOp);
+// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp
+// and have an additional output of tensor of rank 2 containing leaf ids for
+// each tree where an instance ended up with.
+class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp {
+ public:
+ explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context)
+ : GradientTreesPredictionOp(context) {}
+
+ protected:
+ void DoCompute(OpKernelContext* context,
+ DecisionTreeEnsembleResource* ensemble_resource,
+ bool return_output_leaf_index) override {
+ GradientTreesPredictionOp::DoCompute(context, ensemble_resource,
+ /*return_output_leaf_index=*/true);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU),
+ GradientTreesPredictionVerboseOp);
+
class GradientTreesPartitionExamplesOp : public OpKernel {
public:
explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
index 56ff00b390..1b7f59ea42 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
@@ -37,6 +37,7 @@ class BaseSplitHandler(object):
gradient_shape,
hessian_shape,
multiclass_strategy,
+ loss_uses_sum_reduction=False,
name=None):
"""Constructor for BaseSplitHandler.
@@ -51,6 +52,8 @@ class BaseSplitHandler(object):
gradient_shape: A TensorShape, containing shape of gradients.
hessian_shape: A TensorShape, containing shape of hessians.
multiclass_strategy: Strategy describing how to treat multiclass problems.
+ loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
+ SUM or MEAN reduction was used for the loss.
name: An optional handler name.
"""
self._l1_regularization = l1_regularization
@@ -62,6 +65,7 @@ class BaseSplitHandler(object):
self._multiclass_strategy = multiclass_strategy
self._hessian_shape = hessian_shape
self._gradient_shape = gradient_shape
+ self._loss_uses_sum_reduction = loss_uses_sum_reduction
def scheduled_reads(self):
"""Returns the list of `ScheduledOp`s required for update_stats."""
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index 9f78ab2024..bf686237ff 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -44,6 +45,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
multiclass_strategy,
init_stamp_token=0,
+ loss_uses_sum_reduction=False,
name=None):
"""Initialize the internal state for this split handler.
@@ -62,6 +64,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy: Strategy describing how to treat multiclass problems.
init_stamp_token: A tensor containing an scalar for initial stamp of the
stamped objects.
+ loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
+ SUM or MEAN reduction was used for the loss.
name: An optional handler name.
"""
super(EqualitySplitHandler, self).__init__(
@@ -73,6 +77,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
multiclass_strategy=multiclass_strategy,
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
name=name)
self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
init_stamp_token,
@@ -173,6 +178,11 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
# pair.
num_minibatches, partition_ids, feature_ids, gradients, hessians = (
self._stats_accumulator.flush(stamp_token, next_stamp_token))
+ # For sum_reduction, we don't need to divide by number of minibatches.
+
+ num_minibatches = control_flow_ops.cond(
+ ops.convert_to_tensor(self._loss_uses_sum_reduction),
+ lambda: math_ops.to_int64(1), lambda: num_minibatches)
partition_ids, gains, split_infos = (
split_handler_ops.build_categorical_equality_splits(
num_minibatches=num_minibatches,
@@ -187,7 +197,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
tree_complexity_regularization=self._tree_complexity_regularization,
min_node_weight=self._min_node_weight,
bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy,))
+ multiclass_strategy=self._multiclass_strategy))
# There are no warm-up rounds needed in the equality column handler. So we
# always return ready.
are_splits_ready = constant_op.constant(True)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index 0b65eba2a7..ef253e7cec 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -90,7 +90,17 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
empty_hessians,
example_weights,
is_active=array_ops.constant([True, True]))
- with ops.control_dependencies([update_1]):
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+
+ with ops.control_dependencies([update_1, update_2]):
are_splits_ready, partitions, gains, splits = (
split_handler.make_splits(0, 1, class_id))
are_splits_ready, partitions, gains, splits = (sess.run(
@@ -159,6 +169,129 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
+ def testGenerateFeatureSplitCandidatesSumReduction(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 0 | 1,2 |
+ # i1 | (-0.5, 0.07) | 0 | |
+ # i2 | (1.2, 0.2) | 0 | 2 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [0, 0, 0, 1]
+ indices = [[0, 0], [0, 1], [2, 0], [3, 0]]
+ values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0,
+ loss_uses_sum_reduction=True)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1, update_2]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([0, 1], partitions)
+
+ # Check the split on partition 0.
+ # -(0.4 + 2.4 - 0.1) / (0.24 + 0.4 + 1)
+ expected_left_weight = -1.6463414634146338
+
+ # (0.4 + 2.4 - 0.1) ** 2 / (0.24 + 0.4 + 1)
+ expected_left_gain = 4.445121951219511
+
+ # -(-1 + 0.1) / (0.14 + 1)
+ expected_right_weight = 0.789473684211
+
+ # (-1 + 0.1) ** 2 / (0.14 + 1)
+ expected_right_gain = 0.710526315789
+
+ # (0.4 + -1 + 2.4 - 0.1) ** 2 / (0.24 + 0.14 + 0.4 + 1)
+ expected_bias_gain = 1.6235955056179772
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(2, split_node.feature_id)
+
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
+ # Check the split on partition 1.
+ # (-8 + 0.1) / (0.26 + 1)
+ expected_left_weight = -6.26984126984
+ # (-8 + 0.1) ** 2 / (0.26 + 1)
+ expected_left_gain = 49.5317460317
+ expected_right_weight = 0
+ expected_right_gain = 0
+ # (-8 + 0.1) ** 2 / (0.26 + 1)
+ expected_bias_gain = 49.5317460317
+
+ # Verify candidate for partition 1, there's only one active feature here
+ # so zero gain is expected.
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[1])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+ self.assertAllClose(0.0, gains[1], 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(1, split_node.feature_id)
+
def testGenerateFeatureSplitCandidatesMulticlass(self):
with self.test_session() as sess:
# Batch size is 4, 2 gradients per each instance.
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 409a2d8f46..df0bec1fe3 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -99,6 +99,7 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
multiclass_strategy,
init_stamp_token=0,
+ loss_uses_sum_reduction=False,
name=None):
"""Initialize the internal state for this split handler.
@@ -117,6 +118,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy: Strategy describing how to treat multiclass problems.
init_stamp_token: A tensor containing an scalar for initial stamp of the
stamped objects.
+ loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
+ SUM or MEAN reduction was used for the loss.
name: An optional handler name.
"""
super(InequalitySplitHandler, self).__init__(
@@ -128,7 +131,8 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
feature_column_group_id=feature_column_group_id,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=multiclass_strategy)
+ multiclass_strategy=multiclass_strategy,
+ loss_uses_sum_reduction=loss_uses_sum_reduction)
self._stats_accumulator = stats_accumulator_ops.StatsAccumulator(
init_stamp_token,
gradient_shape,
@@ -160,6 +164,7 @@ class DenseSplitHandler(InequalitySplitHandler):
hessian_shape,
multiclass_strategy,
init_stamp_token=0,
+ loss_uses_sum_reduction=False,
name=None):
"""Initialize the internal state for this split handler.
@@ -179,6 +184,8 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy: Strategy describing how to treat multiclass problems.
init_stamp_token: A tensor containing an scalar for initial stamp of the
stamped objects.
+ loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
+ SUM or MEAN reduction was used for the loss.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@@ -193,7 +200,8 @@ class DenseSplitHandler(InequalitySplitHandler):
name=name,
gradient_shape=gradient_shape,
hessian_shape=hessian_shape,
- multiclass_strategy=multiclass_strategy)
+ multiclass_strategy=multiclass_strategy,
+ loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
@@ -255,15 +263,15 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
- self._min_node_weight))
+ self._min_node_weight, self._loss_uses_sum_reduction))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
- stamp_token, next_stamp_token, multiclass_strategy,
- class_id, feature_column_id, l1_regularization,
- l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional):
+def _make_dense_split(
+ quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -291,7 +299,10 @@ def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
stats_accumulator_handle, stamp_token, next_stamp_token))
-
+ # For sum_reduction, we don't need to divide by number of minibatches.
+ num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction,
+ lambda: math_ops.to_int64(1),
+ lambda: num_minibatches)
# Put quantile and stats accumulator flushing in the dependency path.
with ops.control_dependencies([flush_quantiles, partition_ids]):
are_splits_ready = array_ops.identity(are_splits_ready)
@@ -329,6 +340,7 @@ class SparseSplitHandler(InequalitySplitHandler):
hessian_shape,
multiclass_strategy,
init_stamp_token=0,
+ loss_uses_sum_reduction=False,
name=None):
"""Initialize the internal state for this split handler.
@@ -348,6 +360,8 @@ class SparseSplitHandler(InequalitySplitHandler):
multiclass_strategy: Strategy describing how to treat multiclass problems.
init_stamp_token: A tensor containing an scalar for initial stamp of the
stamped objects.
+ loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
+ SUM or MEAN reduction was used for the loss.
name: An optional handler name.
"""
super(SparseSplitHandler, self).__init__(
@@ -362,6 +376,7 @@ class SparseSplitHandler(InequalitySplitHandler):
hessian_shape=hessian_shape,
multiclass_strategy=multiclass_strategy,
init_stamp_token=init_stamp_token,
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
name=name)
self._sparse_float_column = sparse_float_column
@@ -424,15 +439,15 @@ class SparseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
- self._min_node_weight))
+ self._min_node_weight, self._loss_uses_sum_reduction))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
- stamp_token, next_stamp_token, multiclass_strategy,
- class_id, feature_column_id, l1_regularization,
- l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional):
+def _make_sparse_split(
+ quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
"""Function that builds splits for a sparse feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -460,7 +475,9 @@ def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
stats_accumulator_handle, stamp_token, next_stamp_token))
-
+ num_minibatches = control_flow_ops.cond(loss_uses_sum_reduction,
+ lambda: math_ops.to_int64(1),
+ lambda: num_minibatches)
# Put quantile and stats accumulator flushing in the dependency path.
with ops.control_dependencies([flush_quantiles, partition_ids]):
are_splits_ready = array_ops.identity(are_splits_ready)
@@ -498,17 +515,18 @@ def _specialize_make_split(func, is_multi_dimentional):
dtypes.float32,
dtypes.float32,
dtypes.float32,
+ dtypes.bool,
noinline=True)
def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
- min_node_weight):
+ min_node_weight, loss_uses_sum_reduction):
"""Function that builds splits for a sparse feature column."""
- return func(
- quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
- next_stamp_token, multiclass_strategy, class_id, feature_column_id,
- l1_regularization, l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional)
+ return func(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy, class_id,
+ feature_column_id, l1_regularization, l2_regularization,
+ tree_complexity_regularization, min_node_weight,
+ is_multi_dimentional, loss_uses_sum_reduction)
return f
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 2f2c230211..d59732cf92 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -182,6 +182,144 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
+ def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Dense Quantile |
+ # i0 | (0.2, 0.12) | 0 | 1 |
+ # i1 | (-0.5, 0.07) | 0 | 1 |
+ # i2 | (1.2, 0.2) | 0 | 0 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ class_id = -1
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ split_handler = ordinal_split_handler.DenseSplitHandler(
+ l1_regularization=0.2,
+ l2_regularization=2.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
+ epsilon=0.001,
+ num_quantiles=10,
+ feature_column_group_id=0,
+ dense_float_column=dense_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ loss_uses_sum_reduction=True)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
+ with ops.control_dependencies([are_splits_ready]):
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_3 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2, update_3]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+ are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+ sess.run([
+ are_splits_ready, are_splits_ready2, partitions, gains, splits
+ ]))
+
+ # During the first iteration, inequality split handlers are not going to
+ # have any splits. Make sure that we return not_ready in that case.
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+
+ self.assertAllEqual([0, 1], partitions)
+
+ # Check the split on partition 0.
+ # -(2.4 - 0.2) / (0.4 + 2)
+ expected_left_weight = -0.91666
+
+ # expected_left_weight * -(2.4 - 0.2)
+ expected_left_gain = 2.016666666666666
+
+ # -(-1 + 0.4 + 0.2) / (0.38 + 2)
+ expected_right_weight = 0.1680672
+
+ # expected_right_weight * -(-1 + 0.4 + 0.2)
+ expected_right_gain = 0.0672268907563025
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain = 0.9208633093525178
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.dense_float_binary_split
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertAllClose(0.3, split_node.threshold, 0.00001)
+
+ # Check the split on partition 1.
+ # (-8 + 0.2) / (0.26 + 2)
+ expected_left_weight = -3.4513274336283186
+ expected_right_weight = 0
+
+ # Verify candidate for partition 1, there's only one active bucket here
+ # so zero gain is expected.
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[1])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.dense_float_binary_split
+ self.assertAllClose(0.0, gains[1], 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertAllClose(0.52, split_node.threshold, 0.00001)
+
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
with self.test_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
@@ -798,6 +936,139 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
+ def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
+ with self.test_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Sparse Quantile |
+ # i0 | (0.2, 0.12) | 0 | 1 |
+ # i1 | (-0.5, 0.07) | 0 | N/A |
+ # i2 | (1.2, 0.2) | 0 | 0 |
+ # i3 | (4.0, 0.13) | 1 | 1 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ example_partitions = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ indices = array_ops.constant([[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64)
+ values = array_ops.constant([0.52, 0.3, 0.52])
+ sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = ordinal_split_handler.SparseSplitHandler(
+ l1_regularization=0.0,
+ l2_regularization=4.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
+ epsilon=0.01,
+ num_quantiles=2,
+ feature_column_group_id=0,
+ sparse_float_column=sparse_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ loss_uses_sum_reduction=True)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ example_partitions,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+ with ops.control_dependencies([are_splits_ready]):
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ example_partitions,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_3 = split_handler.update_stats_sync(
+ 1,
+ example_partitions,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2, update_3]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+ are_splits_ready, are_splits_ready2, partitions, gains, splits = (
+ sess.run([
+ are_splits_ready, are_splits_ready2, partitions, gains, splits
+ ]))
+
+ # During the first iteration, inequality split handlers are not going to
+ # have any splits. Make sure that we return not_ready in that case.
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+
+ self.assertAllEqual([0, 1], partitions)
+ # Check the split on partition 0.
+ # -(0.4 + 2.4) / (0.24 + 0.4 + 4)
+ expected_left_weight = -0.603448275862069
+ # (0.4 + 2.4) ** 2 / (0.24 + 0.4 + 4)
+ expected_left_gain = 1.689655172413793
+ # 1 / (0.14 + 4)
+ expected_right_weight = 0.24154589371980678
+ # 1 ** 2 / (0.14 + 4)
+ expected_right_gain = 0.24154589371980678
+ # (0.4 + 2.4 - 1) ** 2 / (0.24 + 0.4 + 0.14 + 4)
+ expected_bias_gain = 0.6778242677824265
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.sparse_float_binary_split_default_right
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0])
+
+ self.assertAllClose([expected_left_weight], left_child.value)
+
+ self.assertAllClose([expected_right_weight], right_child.value)
+
+ self.assertEqual(0, split_node.split.feature_column)
+
+ self.assertAllClose(0.52, split_node.split.threshold)
+
+ # Check the split on partition 1.
+ expected_left_weight = -1.8779342723004695
+ expected_right_weight = 0
+
+ # Verify candidate for partition 1, there's only one active bucket here
+ # so zero gain is expected.
+ split_info.ParseFromString(splits[1])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.sparse_float_binary_split_default_left
+
+ self.assertAllClose(0.0, gains[1])
+
+ self.assertAllClose([expected_left_weight], left_child.value)
+
+ self.assertAllClose([expected_right_weight], right_child.value)
+
+ self.assertEqual(0, split_node.split.feature_column)
+
+ self.assertAllClose(0.52, split_node.split.threshold)
+
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
with self.test_session() as sess:
# Batch is 4, 2 classes
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
index 43b00d4c6d..c9223afeab 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc
@@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict(
const std::vector<int32>& trees_to_include,
const boosted_trees::utils::BatchFeatures& features,
tensorflow::thread::ThreadPool* const worker_threads,
- tensorflow::TTypes<float>::Matrix output_predictions) {
+ tensorflow::TTypes<float>::Matrix output_predictions,
+ Tensor* const output_leaf_index) {
// Zero out predictions as the model is additive.
output_predictions.setZero();
@@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict(
// Lambda for doing a block of work.
auto update_predictions = [&config, &features, &trees_to_include,
- &output_predictions](int64 start, int64 end) {
+ &output_predictions,
+ &output_leaf_index](int64 start, int64 end) {
auto examples_iterable = features.examples_iterable(start, end);
+ Tensor dummy_tensor(DT_INT32, TensorShape({1, 1}));
+ tensorflow::TTypes<int>::Matrix output_leaf_index_mat =
+ output_leaf_index != nullptr ? output_leaf_index->matrix<int>()
+ : dummy_tensor.matrix<int>();
for (const auto& example : examples_iterable) {
for (const int32 tree_idx : trees_to_include) {
const boosted_trees::trees::DecisionTreeConfig& tree =
@@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict(
const float tree_weight = config.tree_weights(tree_idx);
const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example);
QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString();
+ // Checks if output leaf tree index is required.
+ if (output_leaf_index != nullptr) {
+ output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx;
+ }
const auto& leaf_node = tree.nodes(leaf_idx);
QCHECK(leaf_node.has_leaf())
<< "Invalid leaf node: " << leaf_node.DebugString();
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
index cc3dc226cd..940531c4ba 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h
@@ -33,12 +33,17 @@ class MultipleAdditiveTrees {
public:
// Predict runs tree ensemble on the given batch and updates
// output predictions accordingly, for the given list of trees.
+ // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not
+ // nullptr, this method fills output_leaf_indices with a per-tree leaf id
+ // where each of the instances from 'features' ended up in. Its shape is num
+ // examples X num of trees.
static void Predict(
const boosted_trees::trees::DecisionTreeEnsembleConfig& config,
const std::vector<int32>& trees_to_include,
const boosted_trees::utils::BatchFeatures& features,
tensorflow::thread::ThreadPool* const worker_threads,
- tensorflow::TTypes<float>::Matrix output_predictions);
+ tensorflow::TTypes<float>::Matrix output_predictions,
+ Tensor* const output_leaf_index);
};
} // namespace models
diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
index 4ca18bedb1..462a9ac86f 100644
--- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
+++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc
@@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) {
tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test",
kNumThreadsSingleThreaded);
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix,
+ /*output_leaf_index=*/nullptr);
EXPECT_EQ(0, output_matrix(0, 0));
EXPECT_EQ(0, output_matrix(1, 0));
}
@@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ /*output_leaf_index=*/nullptr);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
}
+ // Normal case with leaf node.
+ {
+ // Initialize output leaf index tensor, since leaf index is positive in this
+ // case, initialize with the value of -1. Since there are 2 examples and
+ // there are 2 trees, initialize leaf output index by 2 * 2.
+ Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2}));
+ MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
+ batch_features_, &threads, output_matrix,
+ &output_leaf_index_tensor);
+ EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2).
+ EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1).
+ EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
+ 0, 0)); // 1st leaf for the first example
+ EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix<int>()(
+ 1, 0)); // 1st leaf for the second example
+ EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix<int>()(
+ 0, 1)); // 2nd leaf for the first example
+ EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix<int>()(
+ 1, 1)); // 2nd leaf for the second example
+ }
// Weighted case
{
DecisionTreeEnsembleConfig weighted = tree_ensemble_config;
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
- output_matrix);
+ output_matrix, nullptr);
// -0.4 (bias) + 0.2 (leaf 2).
EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0));
// -0.4 (bias) + 0.9 (leaf 1).
@@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) {
// Drop first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2).
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1).
}
// Drop second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias).
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias).
}
// Drop all trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0));
}
@@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1)
@@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
weighted.set_tree_weights(0, 6.0);
weighted.set_tree_weights(1, 3.2);
MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads,
- output_matrix);
+ output_matrix, nullptr);
// bias
EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0));
// bias + leaf 2
@@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Dropout first tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2)
EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2)
@@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Dropout second tree.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias)
EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias)
@@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) {
// Drop both trees.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_,
- &threads, output_matrix);
+ &threads, output_matrix, nullptr);
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0));
EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1));
EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0));
@@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) {
// Normal case.
{
MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1},
- batch_features_, &threads, output_matrix);
+ batch_features_, &threads, output_matrix,
+ nullptr);
EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2)
EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2)
EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2)
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index a7e7bfc13c..69bb8fd4ad 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -51,7 +51,7 @@ class WeightedQuantilesSummary {
SummaryEntry() {
memset(this, 0, sizeof(*this));
- value = 0;
+ value = ValueType();
weight = 0;
min_rank = 0;
max_rank = 0;
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
index d66f645f62..6491d58794 100644
--- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
@@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
return Status::OK();
}
+static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) {
+ string learner_config_str;
+ c->GetAttr("learner_config", &learner_config_str).IgnoreError();
+ LearnerConfig learner_config;
+ ParseProtoUnlimited(&learner_config, learner_config_str);
+
+ bool reduce_dim;
+ c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
+ // Sets the shape of the output as a matrix.
+ c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
+ reduce_dim ? learner_config.num_classes() - 1
+ : learner_config.num_classes())});
+ c->set_output(1, {c->UnknownShape()});
+ c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim)});
+ return Status::OK();
+}
+
REGISTER_OP("GradientTreesPrediction")
.Attr("learner_config: string")
.Attr("num_dense_float_features: int >= 0")
@@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
and original weights of those trees during prediction.
)doc");
+REGISTER_OP("GradientTreesPredictionVerbose")
+ .Attr("learner_config: string")
+ .Attr("num_dense_float_features: int >= 0")
+ .Attr("num_sparse_float_features: int >= 0")
+ .Attr("num_sparse_int_features: int >= 0")
+ .Attr("use_locking: bool = false")
+ .Attr("apply_dropout: bool")
+ .Attr("apply_averaging: bool")
+ .Attr("center_bias: bool")
+ .Attr("reduce_dim: bool")
+ .Input("tree_ensemble_handle: resource")
+ .Input("seed: int64")
+ .Input("dense_float_features: num_dense_float_features * float")
+ .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
+ .Input("sparse_float_feature_values: num_sparse_float_features * float")
+ .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
+ .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_values: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
+ .Output("predictions: float")
+ .Output("drop_out_tree_indices_weights: float")
+ .Output("leaf_index: int32")
+ .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn)
+ .Doc(R"doc(
+Runs multiple additive regression forests predictors on input instances
+and computes the final prediction for each class, and outputs a matrix of
+leaf ids per each tree in an ensemble.
+
+learner_config: Config for the learner of type LearnerConfig proto. Prediction
+ops for now uses only LearningRateDropoutDrivenConfig config from the learner.
+num_dense_float_features: Number of dense float features.
+num_sparse_float_features: Number of sparse float features.
+num_sparse_int_features: Number of sparse int features.
+use_locking: Whether to use locking.
+seed: random seed to be used for dropout.
+reduce_dim: whether to reduce the dimension (legacy impl) or not.
+apply_dropout: whether to apply dropout during prediction.
+apply_averaging: whether averaging of tree ensembles should take place. If set
+to true, will be based on AveragingConfig from learner_config.
+tree_ensemble_handle: The handle to the tree ensemble.
+dense_float_features: Rank 2 Tensors containing dense float feature values.
+sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
+sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
+sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
+sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
+sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
+sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
+predictions: Rank 2 Tensor containing predictions per example per class.
+drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
+leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up.
+)doc");
+
REGISTER_OP("GradientTreesPartitionExamples")
.Attr("num_dense_float_features: int >= 0")
.Attr("num_sparse_float_features: int >= 0")
diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
index 58f0d36b0f..7f6e55ae58 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py
@@ -21,4 +21,5 @@ from __future__ import print_function
from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader
from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples
from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction
+from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 5dd2e0c7f2..1ee7f2395e 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -46,6 +46,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import device_setter
@@ -58,8 +59,16 @@ NUM_LAYERS_ATTEMPTED = "num_layers"
NUM_TREES_ATTEMPTED = "num_trees"
NUM_USED_HANDLERS = "num_used_handlers"
USED_HANDLERS_MASK = "used_handlers_mask"
+LEAF_INDEX = "leaf_index"
_FEATURE_NAME_TEMPLATE = "%s_%d"
+# Keys in Training state.
+GBDTTrainingState = collections.namedtuple("GBDTTrainingState", [
+ "num_layer_examples", "num_layer_steps", "num_layers", "active_tree",
+ "active_layer", "continue_centering", "bias_stats_accumulator",
+ "steps_accumulator", "handlers"
+])
+
def _get_column_by_index(tensor, indices):
"""Returns columns from a 2-D tensor by index."""
@@ -71,18 +80,24 @@ def _get_column_by_index(tensor, indices):
return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1])
-def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats,
- used_handlers):
+def _make_predictions_dict(stamp,
+ logits,
+ partition_ids,
+ ensemble_stats,
+ used_handlers,
+ leaf_index=None):
"""Returns predictions for the given logits and n_classes.
Args:
stamp: The ensemble stamp.
- logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1].
- that contains predictions when no dropout was applied.
+ logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that
+ contains predictions when no dropout was applied.
partition_ids: A rank 1 `Tensor` with shape [batch_size].
ensemble_stats: A TreeEnsembleStatsOp result tuple.
used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a
- boolean mask..
+ boolean mask.
+ leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that
+ contains leaf id for each example prediction.
Returns:
A dict of predictions.
@@ -95,6 +110,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats,
result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees
result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers
result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask
+ if leaf_index is not None:
+ result[LEAF_INDEX] = leaf_index
return result
@@ -267,8 +284,10 @@ class GradientBoostedDecisionTreeModel(object):
learner_config,
features,
logits_dimension,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS,
feature_columns=None,
- use_core_columns=False):
+ use_core_columns=False,
+ output_leaf_index=False):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -276,13 +295,18 @@ class GradientBoostedDecisionTreeModel(object):
num_ps_replicas: Number of parameter server replicas, can be 0.
ensemble_handle: A handle to the ensemble variable.
center_bias: Whether to center the bias before growing trees.
- examples_per_layer: Number of examples to accumulate before growing
- a tree layer. It can also be a function that computes the number of
- examples based on the depth of the layer that's being built.
+ examples_per_layer: Number of examples to accumulate before growing a tree
+ layer. It can also be a function that computes the number of examples
+ based on the depth of the layer that's being built.
learner_config: A learner config.
features: `dict` of `Tensor` objects.
logits_dimension: An int, the dimension of logits.
+ loss_reduction: Either `SUM_OVER_NONZERO_WEIGHTS` (mean) or `SUM`.
feature_columns: A list of feature columns.
+ use_core_columns: A boolean specifying whether core feature columns are
+ used.
+ output_leaf_index: A boolean variable indicating whether to output leaf
+ index into predictions dictionary.
Raises:
ValueError: if inputs are not valid.
@@ -303,6 +327,13 @@ class GradientBoostedDecisionTreeModel(object):
self._center_bias = center_bias
self._examples_per_layer = examples_per_layer
+ # Check loss reduction value.
+ if (loss_reduction != losses.Reduction.SUM and
+ loss_reduction != losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
+ raise ValueError(
+ "Invalid loss reduction is provided: %s." % loss_reduction)
+ self._loss_reduction = loss_reduction
+
# Fill in the defaults.
if (learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.MULTI_CLASS_STRATEGY_UNSPECIFIED):
@@ -313,6 +344,19 @@ class GradientBoostedDecisionTreeModel(object):
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
+ if logits_dimension == 1 or learner_config.multi_class_strategy == (
+ learner_pb2.LearnerConfig.TREE_PER_CLASS):
+ self._gradient_shape = tensor_shape.scalar()
+ self._hessian_shape = tensor_shape.scalar()
+ else:
+ self._gradient_shape = tensor_shape.TensorShape([logits_dimension])
+ if (learner_config.multi_class_strategy ==
+ learner_pb2.LearnerConfig.FULL_HESSIAN):
+ self._hessian_shape = tensor_shape.TensorShape(
+ ([logits_dimension, logits_dimension]))
+ else:
+ # Diagonal hessian strategy.
+ self._hessian_shape = tensor_shape.TensorShape(([logits_dimension]))
if (learner_config.growing_mode ==
learner_pb2.LearnerConfig.GROWING_MODE_UNSPECIFIED):
learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
@@ -347,6 +391,7 @@ class GradientBoostedDecisionTreeModel(object):
sparse_int_values, sparse_int_shapes) = extract_features(
features, self._feature_columns, use_core_columns)
logging.info("Active Feature Columns: " + str(fc_names))
+ logging.info("Learner config: " + str(learner_config))
self._fc_names = fc_names
self._dense_floats = dense_floats
self._sparse_float_indices = sparse_float_indices
@@ -359,6 +404,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.multi_class_strategy ==
learner_pb2.LearnerConfig.TREE_PER_CLASS and
learner_config.num_classes == 2)
+ self._output_leaf_index = output_leaf_index
def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode):
"""Runs prediction and returns a dictionary of the prediction results.
@@ -388,22 +434,44 @@ class GradientBoostedDecisionTreeModel(object):
# Make sure ensemble stats run. This will check that the ensemble has
# the right stamp.
with ops.control_dependencies(ensemble_stats):
- predictions, _ = prediction_ops.gradient_trees_prediction(
- ensemble_handle,
- seed,
- self._dense_floats,
- self._sparse_float_indices,
- self._sparse_float_values,
- self._sparse_float_shapes,
- self._sparse_int_indices,
- self._sparse_int_values,
- self._sparse_int_shapes,
- learner_config=self._learner_config_serialized,
- apply_dropout=apply_dropout,
- apply_averaging=mode != learn.ModeKeys.TRAIN,
- use_locking=True,
- center_bias=self._center_bias,
- reduce_dim=self._reduce_dim)
+ leaf_index = None
+ # Only used in infer (predict), not used in train and eval.
+ if self._output_leaf_index and mode == learn.ModeKeys.INFER:
+ predictions, _, leaf_index = (
+ prediction_ops).gradient_trees_prediction_verbose(
+ ensemble_handle,
+ seed,
+ self._dense_floats,
+ self._sparse_float_indices,
+ self._sparse_float_values,
+ self._sparse_float_shapes,
+ self._sparse_int_indices,
+ self._sparse_int_values,
+ self._sparse_int_shapes,
+ learner_config=self._learner_config_serialized,
+ apply_dropout=apply_dropout,
+ apply_averaging=mode != learn.ModeKeys.TRAIN,
+ use_locking=True,
+ center_bias=self._center_bias,
+ reduce_dim=self._reduce_dim)
+ else:
+ leaf_index = None
+ predictions, _ = prediction_ops.gradient_trees_prediction(
+ ensemble_handle,
+ seed,
+ self._dense_floats,
+ self._sparse_float_indices,
+ self._sparse_float_values,
+ self._sparse_float_shapes,
+ self._sparse_int_indices,
+ self._sparse_int_values,
+ self._sparse_int_shapes,
+ learner_config=self._learner_config_serialized,
+ apply_dropout=apply_dropout,
+ apply_averaging=mode != learn.ModeKeys.TRAIN,
+ use_locking=True,
+ center_bias=self._center_bias,
+ reduce_dim=self._reduce_dim)
partition_ids = prediction_ops.gradient_trees_partition_examples(
ensemble_handle,
self._dense_floats,
@@ -416,7 +484,7 @@ class GradientBoostedDecisionTreeModel(object):
use_locking=True)
return _make_predictions_dict(ensemble_stamp, predictions, partition_ids,
- ensemble_stats, used_handlers)
+ ensemble_stats, used_handlers, leaf_index)
def predict(self, mode):
"""Returns predictions given the features and mode.
@@ -487,17 +555,30 @@ class GradientBoostedDecisionTreeModel(object):
return self._predict_and_return_dict(self._ensemble_handle,
ensemble_stamp, mode)
- def train(self, loss, predictions_dict, labels):
- """Grows a new tree and adds it to the ensemble.
+ def _get_class_id(self, predictions_dict):
+ # Handle different multiclass strategies.
+ if (self._learner_config.multi_class_strategy ==
+ learner_pb2.LearnerConfig.TREE_PER_CLASS and
+ self._logits_dimension != 1):
+ # Choose the class for which the tree is built (one vs rest).
+ return math_ops.to_int32(
+ predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension)
+ return constant_op.constant(-1, dtype=dtypes.int32)
+
+ def update_stats(self, loss, predictions_dict):
+ """Update the accumulators with stats from this batch.
Args:
loss: A scalar tensor representing average loss of examples.
predictions_dict: Dictionary of Rank 2 `Tensor` representing information
about predictions per example.
- labels: Rank 2 `Tensor` representing labels per example.
Returns:
- An op that adds a new tree to the ensemble.
+ Three values:
+ - An op that adds a new tree to the ensemble, and
+ - An op that increments the stamp but removes all the trees and resets
+ the handlers. This can be used to reset the state of the ensemble.
+ - A dict containing the training state.
Raises:
ValueError: if inputs are not valid.
@@ -521,13 +602,10 @@ class GradientBoostedDecisionTreeModel(object):
aggregation_method=None)[0]
strategy = self._learner_config.multi_class_strategy
- class_id = constant_op.constant(-1, dtype=dtypes.int32)
+ class_id = self._get_class_id(predictions_dict)
# Handle different multiclass strategies.
if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS:
# We build one vs rest trees.
- gradient_shape = tensor_shape.scalar()
- hessian_shape = tensor_shape.scalar()
-
if self._logits_dimension == 1:
# We have only 1 score, gradients is of shape [batch, 1].
hessians = gradients_impl.gradients(
@@ -544,11 +622,6 @@ class GradientBoostedDecisionTreeModel(object):
hessian_list = self._diagonal_hessian(gradients, predictions)
# Assemble hessian list into a tensor.
hessians = array_ops.stack(hessian_list, axis=1)
-
- # Choose the class for which the tree is built (one vs rest).
- class_id = math_ops.to_int32(
- predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension)
-
# Use class id tensor to get the column with that index from gradients
# and hessians.
squeezed_gradients = array_ops.squeeze(
@@ -557,15 +630,10 @@ class GradientBoostedDecisionTreeModel(object):
_get_column_by_index(hessians, class_id))
else:
# Other multiclass strategies.
- gradient_shape = tensor_shape.TensorShape([self._logits_dimension])
-
if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN:
- hessian_shape = tensor_shape.TensorShape(
- ([self._logits_dimension, self._logits_dimension]))
hessian_list = self._full_hessian(gradients, predictions)
else:
# Diagonal hessian strategy.
- hessian_shape = tensor_shape.TensorShape(([self._logits_dimension]))
hessian_list = self._diagonal_hessian(gradients, predictions)
squeezed_gradients = gradients
@@ -573,7 +641,7 @@ class GradientBoostedDecisionTreeModel(object):
squeezed_hessians = hessians
# Get the weights for each example for quantiles calculation,
- weights = self._get_weights(hessian_shape, squeezed_hessians)
+ weights = self._get_weights(self._hessian_shape, squeezed_hessians)
# Create all handlers ensuring resources are evenly allocated across PS.
fc_name_idx = 0
@@ -587,6 +655,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config.regularization.tree_complexity, dtypes.float32)
min_node_weight = constant_op.constant(
self._learner_config.constraints.min_node_weight, dtypes.float32)
+ loss_uses_sum_reduction = self._loss_reduction == losses.Reduction.SUM
+ loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
epsilon = 0.01
num_quantiles = 100
strategy_tensor = constant_op.constant(strategy)
@@ -600,15 +670,18 @@ class GradientBoostedDecisionTreeModel(object):
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- feature_column_group_id=dense_float_column_idx,
+ feature_column_group_id=constant_op.constant(
+ dense_float_column_idx),
epsilon=epsilon,
num_quantiles=num_quantiles,
dense_float_column=self._dense_floats[dense_float_column_idx],
name=fc_name,
- gradient_shape=gradient_shape,
- hessian_shape=hessian_shape,
+ gradient_shape=self._gradient_shape,
+ hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
- init_stamp_token=init_stamp_token))
+ init_stamp_token=init_stamp_token,
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
+ ))
fc_name_idx += 1
# Create handlers for sparse float columns.
@@ -620,7 +693,8 @@ class GradientBoostedDecisionTreeModel(object):
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- feature_column_group_id=sparse_float_column_idx,
+ feature_column_group_id=constant_op.constant(
+ sparse_float_column_idx),
epsilon=epsilon,
num_quantiles=num_quantiles,
sparse_float_column=sparse_tensor.SparseTensor(
@@ -628,10 +702,11 @@ class GradientBoostedDecisionTreeModel(object):
self._sparse_float_values[sparse_float_column_idx],
self._sparse_float_shapes[sparse_float_column_idx]),
name=fc_name,
- gradient_shape=gradient_shape,
- hessian_shape=hessian_shape,
+ gradient_shape=self._gradient_shape,
+ hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
- init_stamp_token=init_stamp_token))
+ init_stamp_token=init_stamp_token,
+ loss_uses_sum_reduction=loss_uses_sum_reduction))
fc_name_idx += 1
# Create handlers for sparse int columns.
@@ -643,32 +718,20 @@ class GradientBoostedDecisionTreeModel(object):
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- feature_column_group_id=sparse_int_column_idx,
+ feature_column_group_id=constant_op.constant(
+ sparse_int_column_idx),
sparse_int_column=sparse_tensor.SparseTensor(
self._sparse_int_indices[sparse_int_column_idx],
self._sparse_int_values[sparse_int_column_idx],
self._sparse_int_shapes[sparse_int_column_idx]),
name=fc_name,
- gradient_shape=gradient_shape,
- hessian_shape=hessian_shape,
+ gradient_shape=self._gradient_shape,
+ hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
- init_stamp_token=init_stamp_token))
+ init_stamp_token=init_stamp_token,
+ loss_uses_sum_reduction=loss_uses_sum_reduction))
fc_name_idx += 1
- # Create steps accumulator.
- steps_accumulator = stats_accumulator_ops.StatsAccumulator(
- stamp_token=0,
- gradient_shape=tensor_shape.scalar(),
- hessian_shape=tensor_shape.scalar(),
- name="StepsAccumulator")
-
- # Create bias stats accumulator.
- bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator(
- stamp_token=0,
- gradient_shape=gradient_shape,
- hessian_shape=hessian_shape,
- name="BiasAccumulator")
-
# Create ensemble stats variables.
num_layer_examples = variables.Variable(
initial_value=array_ops.zeros([], dtypes.int64),
@@ -690,7 +753,23 @@ class GradientBoostedDecisionTreeModel(object):
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
-
+ # Variable that becomes false once bias centering is done.
+ continue_centering = variables.Variable(
+ initial_value=self._center_bias,
+ name="continue_centering",
+ trainable=False)
+ # Create bias stats accumulator.
+ bias_stats_accumulator = stats_accumulator_ops.StatsAccumulator(
+ stamp_token=0,
+ gradient_shape=self._gradient_shape,
+ hessian_shape=self._hessian_shape,
+ name="BiasAccumulator")
+ # Create steps accumulator.
+ steps_accumulator = stats_accumulator_ops.StatsAccumulator(
+ stamp_token=0,
+ gradient_shape=tensor_shape.scalar(),
+ hessian_shape=tensor_shape.scalar(),
+ name="StepsAccumulator")
# Create ensemble stats summaries.
summary.scalar("layer_stats/num_examples", num_layer_examples)
summary.scalar("layer_stats/num_steps", num_layer_steps)
@@ -699,16 +778,13 @@ class GradientBoostedDecisionTreeModel(object):
# Update bias stats.
stats_update_ops = []
- continue_centering = variables.Variable(
- initial_value=self._center_bias,
- name="continue_centering",
- trainable=False)
+
stats_update_ops.append(
control_flow_ops.cond(
continue_centering,
- self._make_update_bias_stats_fn(ensemble_stamp, predictions,
- gradients, bias_stats_accumulator),
- control_flow_ops.no_op))
+ self._make_update_bias_stats_fn(
+ ensemble_stamp, predictions, gradients,
+ bias_stats_accumulator), control_flow_ops.no_op))
# Update handler stats.
handler_reads = collections.OrderedDict()
@@ -765,8 +841,8 @@ class GradientBoostedDecisionTreeModel(object):
lambda: active_handlers))
# Prepare empty gradients and hessians when handlers are not ready.
- empty_hess_shape = [1] + hessian_shape.as_list()
- empty_grad_shape = [1] + gradient_shape.as_list()
+ empty_hess_shape = [1] + self._hessian_shape.as_list()
+ empty_grad_shape = [1] + self._gradient_shape.as_list()
empty_gradients = constant_op.constant(
[], dtype=dtypes.float32, shape=empty_grad_shape)
@@ -788,34 +864,86 @@ class GradientBoostedDecisionTreeModel(object):
per_handler_updates, ensemble_stamp, worker_device)
for update in update_results.values():
stats_update_ops += update
+
+ training_state = GBDTTrainingState(
+ num_layer_examples=num_layer_examples,
+ num_layer_steps=num_layer_steps,
+ num_layers=num_layers,
+ active_tree=active_tree,
+ active_layer=active_layer,
+ continue_centering=continue_centering,
+ bias_stats_accumulator=bias_stats_accumulator,
+ steps_accumulator=steps_accumulator,
+ handlers=handlers)
+
+ reset_op = control_flow_ops.no_op()
+ if self._is_chief:
+ # Advance the ensemble stamp to throw away staggered workers.
+ stamp_token, _ = model_ops.tree_ensemble_serialize(self._ensemble_handle)
+ next_stamp_token = stamp_token + 1
+
+ reset_ops = []
+ for handler in handlers:
+ reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0))
+ if self._center_bias:
+ reset_ops.append(
+ bias_stats_accumulator.flush(stamp_token, next_stamp_token))
+ reset_ops.append(steps_accumulator.flush(stamp_token, next_stamp_token))
+ reset_ops.append(self._finalized_trees.assign(0).op)
+ reset_ops.append(self._attempted_trees.assign(0).op)
+ reset_ops.append(
+ model_ops.tree_ensemble_deserialize(
+ self._ensemble_handle,
+ stamp_token=next_stamp_token,
+ tree_ensemble_config="",
+ name="reset_gbdt"))
+
+ reset_op = control_flow_ops.group([reset_ops])
+
+ return stats_update_ops, reset_op, training_state
+
+ def increment_step_counter_and_maybe_update_ensemble(self, predictions_dict,
+ training_state):
+ """Increments number of visited examples and grows the ensemble.
+
+ If the number of visited examples reaches the target examples_per_layer,
+ ensemble is updated.
+
+ Args:
+ predictions_dict: Dictionary of Rank 2 `Tensor` representing information
+ about predictions per example.
+ training_state: `dict` returned by update_stats.
+
+ Returns:
+ An op that updates the counters and potientially grows the ensemble.
+ """
+ batch_size = math_ops.cast(
+ array_ops.shape(predictions_dict[PREDICTIONS])[0], dtypes.float32)
+ ensemble_stamp = predictions_dict[ENSEMBLE_STAMP]
# Accumulate a step after updating stats.
- batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32)
- with ops.control_dependencies(stats_update_ops):
- add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]],
- [batch_size], [1.0])
- # Determine learning rate.
- learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof(
- "tuner")
- if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout":
- tuner = getattr(self._learner_config.learning_rate_tuner,
- learning_rate_tuner)
- learning_rate = tuner.learning_rate
- else:
- # TODO(nponomareva, soroush) do the line search.
- raise ValueError("Line search learning rate is not yet supported.")
+ steps_accumulator = training_state.steps_accumulator
+ num_layer_examples = training_state.num_layer_examples
+ num_layer_steps = training_state.num_layer_steps
+ active_layer = training_state.active_layer
+ add_step_op = steps_accumulator.add(
+ ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0])
# After adding the step, decide if further processing is needed.
ensemble_update_ops = [add_step_op]
+ class_id = self._get_class_id(predictions_dict)
+
with ops.control_dependencies([add_step_op]):
if self._is_chief:
dropout_seed = predictions_dict[NUM_TREES_ATTEMPTED]
# Get accumulated steps and examples for the current layer.
- _, _, _, _, acc_examples, acc_steps = steps_accumulator.serialize()
+ _, _, _, _, acc_examples, acc_steps = (
+ steps_accumulator.serialize())
acc_examples = math_ops.cast(acc_examples[0], dtypes.int64)
acc_steps = math_ops.cast(acc_steps[0], dtypes.int64)
- ensemble_update_ops.append(num_layer_examples.assign(acc_examples))
+ ensemble_update_ops.append(
+ num_layer_examples.assign(acc_examples))
ensemble_update_ops.append(num_layer_steps.assign(acc_steps))
# Determine whether we need to update tree ensemble.
examples_per_layer = self._examples_per_layer
@@ -824,18 +952,172 @@ class GradientBoostedDecisionTreeModel(object):
ensemble_update_ops.append(
control_flow_ops.cond(
acc_examples >= examples_per_layer,
- self._make_update_ensemble_fn(
- ensemble_stamp, steps_accumulator, bias_stats_accumulator,
- continue_centering, learning_rate, handlers, num_layers,
- active_tree, active_layer, dropout_seed, class_id),
+ self.make_update_ensemble_fn(ensemble_stamp, training_state,
+ dropout_seed, class_id),
control_flow_ops.no_op))
- # Calculate the loss to be reported.
# Note, the loss is calculated from the prediction considering dropouts, so
# that the value might look staggering over steps when the dropout ratio is
# high. eval_loss might be referred instead in the aspect of convergence.
return control_flow_ops.group(*ensemble_update_ops)
+ def make_update_ensemble_fn(self, ensemble_stamp, training_state,
+ dropout_seed, class_id):
+ """A method to create the function which updates the tree ensemble."""
+ # Determine learning rate.
+ learning_rate_tuner = self._learner_config.learning_rate_tuner.WhichOneof(
+ "tuner")
+ if learning_rate_tuner == "fixed" or learning_rate_tuner == "dropout":
+ tuner = getattr(self._learner_config.learning_rate_tuner,
+ learning_rate_tuner)
+ learning_rate = tuner.learning_rate
+ else:
+ # TODO(nponomareva, soroush) do the line search.
+ raise ValueError("Line search learning rate is not yet supported.")
+
+ def _update_ensemble():
+ """A method to update the tree ensemble."""
+ # Get next stamp token.
+ next_ensemble_stamp = ensemble_stamp + 1
+ # Finalize bias stats.
+ _, _, _, bias_grads, bias_hess = (
+ training_state.bias_stats_accumulator.flush(ensemble_stamp,
+ next_ensemble_stamp))
+
+ # Finalize handler splits.
+ are_splits_ready_list = []
+ partition_ids_list = []
+ gains_list = []
+ split_info_list = []
+
+ for handler in training_state.handlers:
+ (are_splits_ready,
+ partition_ids, gains, split_info) = handler.make_splits(
+ ensemble_stamp, next_ensemble_stamp, class_id)
+ are_splits_ready_list.append(are_splits_ready)
+ partition_ids_list.append(partition_ids)
+ gains_list.append(gains)
+ split_info_list.append(split_info)
+ # Stack all the inputs to one tensor per type.
+ # This is a workaround for the slowness of graph building in tf.cond.
+ # See (b/36554864).
+ split_sizes = array_ops.reshape(
+ array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
+ partition_ids = array_ops.concat(partition_ids_list, axis=0)
+ gains = array_ops.concat(gains_list, axis=0)
+ split_infos = array_ops.concat(split_info_list, axis=0)
+
+ # Determine if all splits are ready.
+ are_all_splits_ready = math_ops.reduce_all(
+ array_ops.stack(
+ are_splits_ready_list, axis=0, name="stack_handler_readiness"))
+
+ # Define bias centering update operation.
+ def _center_bias_fn():
+ # Center tree ensemble bias.
+ delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess,
+ array_ops.zeros_like(bias_grads))
+ center_bias = training_ops.center_tree_ensemble_bias(
+ tree_ensemble_handle=self._ensemble_handle,
+ stamp_token=ensemble_stamp,
+ next_stamp_token=next_ensemble_stamp,
+ delta_updates=delta_updates,
+ learner_config=self._learner_config_serialized)
+ return training_state.continue_centering.assign(center_bias)
+
+ # Define ensemble growing operations.
+ def _grow_ensemble_ready_fn():
+ # Grow the ensemble given the current candidates.
+ sizes = array_ops.unstack(split_sizes)
+ partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
+ gains_list = list(array_ops.split(gains, sizes, axis=0))
+ split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
+ return training_ops.grow_tree_ensemble(
+ tree_ensemble_handle=self._ensemble_handle,
+ stamp_token=ensemble_stamp,
+ next_stamp_token=next_ensemble_stamp,
+ learning_rate=learning_rate,
+ partition_ids=partition_ids_list,
+ gains=gains_list,
+ splits=split_info_list,
+ learner_config=self._learner_config_serialized,
+ dropout_seed=dropout_seed,
+ center_bias=self._center_bias)
+
+ def _grow_ensemble_not_ready_fn():
+ # Don't grow the ensemble, just update the stamp.
+ return training_ops.grow_tree_ensemble(
+ tree_ensemble_handle=self._ensemble_handle,
+ stamp_token=ensemble_stamp,
+ next_stamp_token=next_ensemble_stamp,
+ learning_rate=0,
+ partition_ids=[],
+ gains=[],
+ splits=[],
+ learner_config=self._learner_config_serialized,
+ dropout_seed=dropout_seed,
+ center_bias=self._center_bias)
+
+ def _grow_ensemble_fn():
+ # Conditionally grow an ensemble depending on whether the splits
+ # from all the handlers are ready.
+ return control_flow_ops.cond(are_all_splits_ready,
+ _grow_ensemble_ready_fn,
+ _grow_ensemble_not_ready_fn)
+
+ # Update ensemble.
+ update_ops = [are_all_splits_ready]
+ if self._center_bias:
+ update_model = control_flow_ops.cond(training_state.continue_centering,
+ _center_bias_fn, _grow_ensemble_fn)
+ else:
+ update_model = _grow_ensemble_fn()
+ update_ops.append(update_model)
+
+ # Update ensemble stats.
+ with ops.control_dependencies([update_model]):
+ stats = training_ops.tree_ensemble_stats(
+ self._ensemble_handle, stamp_token=next_ensemble_stamp)
+ update_ops.append(self._finalized_trees.assign(stats.num_trees))
+ update_ops.append(self._attempted_trees.assign(stats.attempted_trees))
+ update_ops.append(training_state.num_layers.assign(stats.num_layers))
+ update_ops.append(training_state.active_tree.assign(stats.active_tree))
+ update_ops.append(
+ training_state.active_layer.assign(stats.active_layer))
+
+ # Flush step stats.
+ update_ops.extend(
+ training_state.steps_accumulator.flush(ensemble_stamp,
+ next_ensemble_stamp))
+ return control_flow_ops.group(*update_ops, name="update_ensemble")
+
+ return _update_ensemble
+
+ def get_number_of_trees_tensor(self):
+ return self._finalized_trees, self._attempted_trees
+
+ def train(self, loss, predictions_dict, labels):
+ """Updates the accumalator stats and grows the ensemble.
+
+ Args:
+ loss: A scalar tensor representing average loss of examples.
+ predictions_dict: Dictionary of Rank 2 `Tensor` representing information
+ about predictions per example.
+ labels: Rank 2 `Tensor` representing labels per example. Has no effect
+ on the training and is only kept for backward compatibility.
+
+ Returns:
+ An op that adds a new tree to the ensemble.
+
+ Raises:
+ ValueError: if inputs are not valid.
+ """
+ del labels # unused; kept for backward compatibility.
+ update_op, _, training_state = self.update_stats(loss, predictions_dict)
+ with ops.control_dependencies(update_op):
+ return self.increment_step_counter_and_maybe_update_ensemble(
+ predictions_dict, training_state)
+
def _get_weights(self, hessian_shape, hessians):
"""Derives weights to be used based on hessians and multiclass strategy."""
if hessian_shape == tensor_shape.scalar():
@@ -951,127 +1233,3 @@ class GradientBoostedDecisionTreeModel(object):
return control_flow_ops.group(*[add_stats_op], name="update_bias_stats")
return _update_bias_stats
-
- def _make_update_ensemble_fn(self, ensemble_stamp, steps_accumulator,
- bias_stats_accumulator, continue_centering,
- learning_rate, handlers, num_layers, active_tree,
- active_layer, dropout_seed, class_id):
- """A method to create the function which updates the tree ensemble."""
-
- def _update_ensemble():
- """A method to update the tree ensemble."""
- # Get next stamp token.
- next_ensemble_stamp = ensemble_stamp + 1
- # Finalize bias stats.
- _, _, _, bias_grads, bias_hess = bias_stats_accumulator.flush(
- ensemble_stamp, next_ensemble_stamp)
-
- # Finalize handler splits.
- are_splits_ready_list = []
- partition_ids_list = []
- gains_list = []
- split_info_list = []
-
- for handler in handlers:
- (are_splits_ready,
- partition_ids, gains, split_info) = handler.make_splits(
- ensemble_stamp, next_ensemble_stamp, class_id)
- are_splits_ready_list.append(are_splits_ready)
- partition_ids_list.append(partition_ids)
- gains_list.append(gains)
- split_info_list.append(split_info)
- # Stack all the inputs to one tensor per type.
- # This is a workaround for the slowness of graph building in tf.cond.
- # See (b/36554864).
- split_sizes = array_ops.reshape(
- array_ops.shape_n(partition_ids_list), [len(partition_ids_list)])
- partition_ids = array_ops.concat(partition_ids_list, axis=0)
- gains = array_ops.concat(gains_list, axis=0)
- split_infos = array_ops.concat(split_info_list, axis=0)
-
- # Determine if all splits are ready.
- are_all_splits_ready = math_ops.reduce_all(
- array_ops.stack(
- are_splits_ready_list, axis=0, name="stack_handler_readiness"))
-
- # Define bias centering update operation.
- def _center_bias_fn():
- # Center tree ensemble bias.
- delta_updates = array_ops.where(bias_hess > 0, -bias_grads / bias_hess,
- array_ops.zeros_like(bias_grads))
- center_bias = training_ops.center_tree_ensemble_bias(
- tree_ensemble_handle=self._ensemble_handle,
- stamp_token=ensemble_stamp,
- next_stamp_token=next_ensemble_stamp,
- delta_updates=delta_updates,
- learner_config=self._learner_config_serialized)
- return continue_centering.assign(center_bias)
-
- # Define ensemble growing operations.
- def _grow_ensemble_ready_fn():
- # Grow the ensemble given the current candidates.
- sizes = array_ops.unstack(split_sizes)
- partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
- gains_list = list(array_ops.split(gains, sizes, axis=0))
- split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
- return training_ops.grow_tree_ensemble(
- tree_ensemble_handle=self._ensemble_handle,
- stamp_token=ensemble_stamp,
- next_stamp_token=next_ensemble_stamp,
- learning_rate=learning_rate,
- partition_ids=partition_ids_list,
- gains=gains_list,
- splits=split_info_list,
- learner_config=self._learner_config_serialized,
- dropout_seed=dropout_seed,
- center_bias=self._center_bias)
-
- def _grow_ensemble_not_ready_fn():
- # Don't grow the ensemble, just update the stamp.
- return training_ops.grow_tree_ensemble(
- tree_ensemble_handle=self._ensemble_handle,
- stamp_token=ensemble_stamp,
- next_stamp_token=next_ensemble_stamp,
- learning_rate=0,
- partition_ids=[],
- gains=[],
- splits=[],
- learner_config=self._learner_config_serialized,
- dropout_seed=dropout_seed,
- center_bias=self._center_bias)
-
- def _grow_ensemble_fn():
- # Conditionally grow an ensemble depending on whether the splits
- # from all the handlers are ready.
- return control_flow_ops.cond(are_all_splits_ready,
- _grow_ensemble_ready_fn,
- _grow_ensemble_not_ready_fn)
-
- # Update ensemble.
- update_ops = [are_all_splits_ready]
- if self._center_bias:
- update_model = control_flow_ops.cond(continue_centering,
- _center_bias_fn, _grow_ensemble_fn)
- else:
- update_model = _grow_ensemble_fn()
- update_ops.append(update_model)
-
- # Update ensemble stats.
- with ops.control_dependencies([update_model]):
- stats = training_ops.tree_ensemble_stats(
- self._ensemble_handle, stamp_token=next_ensemble_stamp)
- update_ops.append(self._finalized_trees.assign(stats.num_trees))
- update_ops.append(self._attempted_trees.assign(stats.attempted_trees))
- update_ops.append(num_layers.assign(stats.num_layers))
- update_ops.append(active_tree.assign(stats.active_tree))
- update_ops.append(active_layer.assign(stats.active_layer))
-
- # Flush step stats.
- update_ops.extend(
- steps_accumulator.flush(ensemble_stamp, next_ensemble_stamp))
- return control_flow_ops.group(*update_ops, name="update_ensemble")
-
- return _update_ensemble
-
- def get_number_of_trees_tensor(self):
- return self._finalized_trees, self._attempted_trees
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 289fb195db..f7867d882d 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -19,19 +19,17 @@ from __future__ import division
from __future__ import print_function
from google.protobuf import text_format
-
from tensorflow.contrib import layers
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.boosted_trees.python.utils import losses
-
-from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn
-
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -782,6 +780,118 @@ class GbdtTest(test_util.TensorFlowTestCase):
[[0.25], [0.25], [0.25], [0.25]])
self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
+ def testPredictFnWithLeafIndexAdvancedLeft(self):
+ """Tests the predict function with output leaf ids."""
+ with self.test_session() as sess:
+ # Create ensemble with one bias node.
+ ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ dense_float_binary_split {
+ threshold: 1.0
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 0
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.15
+ }
+ }
+ }
+ }
+ trees {
+ nodes {
+ dense_float_binary_split {
+ threshold: 0.99
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 00
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.23
+ }
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }""", ensemble_config)
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=3,
+ tree_ensemble_config=ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.constraints.min_node_weight = 0
+ features = {}
+ features["dense_float"] = array_ops.constant(
+ [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32)
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=False,
+ num_ps_replicas=0,
+ center_bias=True,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features,
+ output_leaf_index=True)
+
+ # Create predict op.
+ mode = model_fn.ModeKeys.INFER
+ predictions_dict = sess.run(gbdt_model.predict(mode))
+ self.assertEquals(predictions_dict["ensemble_stamp"], 3)
+ # here are how the numbers in expected results are calculated,
+ # 0.5 = 0.25 + 0.25
+ # 0.48 = 0.25 + 0.23
+ # 0.38 = 0.15 + 0.23
+ # 0.38 = 0.15 + 0.23
+ self.assertAllClose(predictions_dict["predictions"],
+ [[0.5], [0.48], [0.38], [0.38]])
+ self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0])
+ self.assertAllClose(predictions_dict["leaf_index"],
+ [[1, 1], [1, 2], [2, 2], [2, 2]])
+
def testTrainFnMulticlassFullHessian(self):
"""Tests the GBDT train for multiclass full hessian."""
with self.test_session() as sess:
@@ -1451,6 +1561,301 @@ class GbdtTest(test_util.TensorFlowTestCase):
self.assertEquals(output.growing_metadata.num_layers_attempted, 2)
+ def testResetModelBeforeAndAfterSplit(self):
+ """Tests whether resetting works."""
+ with self.test_session():
+ # First build a small tree and train it to verify training works.
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ features = {}
+ features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions = array_ops.constant(
+ [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
+ partition_ids = array_ops.zeros([4], dtypes.int32)
+ ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle)
+
+ predictions_dict = {
+ "predictions": predictions,
+ "predictions_no_dropout": predictions,
+ "partition_ids": partition_ids,
+ "ensemble_stamp": ensemble_stamp,
+ "num_trees": 12,
+ "max_tree_depth": 4,
+ }
+
+ labels = array_ops.ones([4, 1], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+ loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions))
+
+ # Create train op.
+ update_op, reset_op, training_state = gbdt_model.update_stats(
+ loss, predictions_dict)
+ with ops.control_dependencies(update_op):
+ train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble(
+ predictions_dict, training_state)
+
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ original_stamp = ensemble_stamp.eval()
+ expected_tree = """
+ nodes {
+ dense_float_binary_split {
+ threshold: 1.0
+ left_id: 1
+ right_id: 2
+ }
+ node_metadata {
+ gain: 0
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.0
+ }
+ }
+ }"""
+
+ def _train_once_and_check(expect_split):
+ stamp = ensemble_stamp.eval()
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(stamp_token.eval(), stamp + 1)
+ if expect_split:
+ # State of the ensemble after a split occurs.
+ self.assertEquals(len(output.trees), 1)
+ self.assertProtoEquals(expected_tree, output.trees[0])
+ else:
+ # State of the ensemble after a single accumulation but before any
+ # splitting occurs
+ self.assertEquals(len(output.trees), 0)
+ self.assertProtoEquals("""
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }""", output)
+
+ def _run_reset():
+ stamp_before_reset = ensemble_stamp.eval()
+ reset_op.run()
+ stamp_after_reset = ensemble_stamp.eval()
+ self.assertNotEquals(stamp_after_reset, stamp_before_reset)
+
+ _, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertProtoEquals("", output)
+
+ return stamp_after_reset
+
+ # Exit after one train_op, so no new layer are created but the handlers
+ # contain enough information to split on the next call to train.
+ _train_once_and_check(expect_split=False)
+ self.assertEquals(ensemble_stamp.eval(), original_stamp + 1)
+
+ # Reset the handlers so it still requires two training calls to split.
+ stamp_after_reset = _run_reset()
+
+ _train_once_and_check(expect_split=False)
+ _train_once_and_check(expect_split=True)
+ self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2)
+
+ # This time, test that the reset_op works right after splitting.
+ stamp_after_reset = _run_reset()
+
+ # Test that after resetting, the tree can be trained as normal.
+ _train_once_and_check(expect_split=False)
+ _train_once_and_check(expect_split=True)
+ self.assertEquals(ensemble_stamp.eval(), stamp_after_reset + 2)
+
+ def testResetModelNonChief(self):
+ """Tests the reset function on a non-chief worker."""
+ with self.test_session():
+ # Create ensemble with one bias node.
+ ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: false
+ }""", ensemble_config)
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ features = {}
+ features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=False,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions = array_ops.constant(
+ [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
+ partition_ids = array_ops.zeros([4], dtypes.int32)
+ ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle)
+
+ predictions_dict = {
+ "predictions": predictions,
+ "predictions_no_dropout": predictions,
+ "partition_ids": partition_ids,
+ "ensemble_stamp": ensemble_stamp
+ }
+
+ labels = array_ops.ones([4, 1], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+ loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions))
+
+ # Create reset op.
+ _, reset_op, _ = gbdt_model.update_stats(
+ loss, predictions_dict)
+
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Reset op doesn't do anything because this is a non-chief worker.
+ reset_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertEquals(len(output.tree_weights), 1)
+ self.assertEquals(stamp_token.eval(), 0)
+
+ def testResetModelWithCenterBias(self):
+ """Tests the reset function running on chief with bias centering."""
+ with self.test_session():
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.learning_rate_tuner.fixed.learning_rate = 0.1
+ learner_config.num_classes = 2
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 1
+ learner_config.constraints.min_node_weight = 0
+ features = {}
+ features["dense_float"] = array_ops.ones([4, 1], dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=True,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions = array_ops.constant(
+ [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
+ partition_ids = array_ops.zeros([4], dtypes.int32)
+ ensemble_stamp = model_ops.tree_ensemble_stamp_token(ensemble_handle)
+
+ predictions_dict = {
+ "predictions": predictions,
+ "predictions_no_dropout": predictions,
+ "partition_ids": partition_ids,
+ "ensemble_stamp": ensemble_stamp,
+ "num_trees": 12,
+ }
+
+ labels = array_ops.ones([4, 1], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+ loss = math_ops.reduce_mean(_squared_loss(labels, weights, predictions))
+
+ # Create train op.
+ update_op, reset_op, training_state = gbdt_model.update_stats(
+ loss, predictions_dict)
+ with ops.control_dependencies(update_op):
+ train_op = gbdt_model.increment_step_counter_and_maybe_update_ensemble(
+ predictions_dict, training_state)
+
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect bias to be centered.
+ def train_and_check():
+ train_op.run()
+ _, serialized = model_ops.tree_ensemble_serialize(ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ expected_tree = """
+ nodes {
+ leaf {
+ vector {
+ value: 0.25
+ }
+ }
+ }"""
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllEqual(output.tree_weights, [1.0])
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
+ train_and_check()
+ self.assertEquals(ensemble_stamp.eval(), 1)
+
+ reset_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 2)
+
+ train_and_check()
+ self.assertEquals(ensemble_stamp.eval(), 3)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index ab7ac2aba6..b5ebaf1999 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -23,6 +23,12 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+
+
+def per_example_squared_hinge_loss(labels, weights, predictions):
+ loss = losses.hinge_loss(labels=labels, logits=predictions, weights=weights)
+ return math_ops.square(loss), control_flow_ops.no_op()
def per_example_logistic_loss(labels, weights, predictions):
@@ -126,7 +132,7 @@ def per_example_squared_loss(labels, weights, predictions):
def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
- """Exponential loss given labels, example weights and predictions.
+ """Trimmed exponential loss given labels, example weights and predictions.
Note that this is only for binary classification.
If logistic loss tries to make sure that the classifier is certain of its
@@ -211,3 +217,62 @@ def per_example_exp_loss(labels, weights, predictions, name=None, eps=0.1):
unweighted_loss = exp_with_logits(
name=name, eps=eps, labels=labels, logits=predictions)
return unweighted_loss * weights, control_flow_ops.no_op()
+
+
+def per_example_full_exp_loss(labels, weights, predictions, name=None):
+ """Full exponential loss given labels, example weights and predictions.
+
+ Note that this is only for binary classification.
+ The loss returns is exp(-targets*logits), where targets are converted to -1
+ and 1.
+
+ Args:
+ labels: Rank 2 (N, D) tensor of per-example labels.
+ weights: Rank 2 (N, 1) tensor of per-example weights.
+ predictions: Rank 2 (N, D) tensor of per-example predictions.
+ name: A name for the operation (optional).
+
+ Returns:
+ loss: A Rank 2 (N, 1) tensor of per-example exp loss
+ update_op: An update operation to update the loss's internal state.
+ """
+
+ def full_exp_with_logits(name, labels=None, logits=None):
+ """Computes exponential loss given `logits`.
+
+ Args:
+ name: A name for the operation (optional).
+ labels: A `Tensor` of the same type and shape as `logits`.
+ logits: A `Tensor` of type `float32` or `float64`.
+
+ Returns:
+ A `Tensor` of the same shape as `logits` with the componentwise
+ exponential losses.
+
+ Raises:
+ ValueError: If `logits` and `labels` do not have the same shape.
+ """
+ with ops.name_scope(name, "exp_loss", [logits, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ labels = ops.convert_to_tensor(labels, name="labels")
+ try:
+ labels.get_shape().merge_with(logits.get_shape())
+ except ValueError:
+ raise ValueError("logits and labels must have the same shape (%s vs %s)"
+ % (logits.get_shape(), labels.get_shape()))
+
+ # Default threshold of 0 to switch between classes
+ zeros = array_ops.zeros_like(logits, dtype=logits.dtype)
+ ones = array_ops.ones_like(logits, dtype=logits.dtype)
+ neg_ones = -array_ops.ones_like(logits, dtype=logits.dtype)
+
+ # Convert labels to 1 and -1
+ cond_labels = (labels > zeros)
+ labels_converted = array_ops.where(cond_labels, ones, neg_ones)
+
+ return math_ops.exp(-1.0 * logits * labels_converted)
+
+ labels = math_ops.to_float(labels)
+ unweighted_loss = full_exp_with_logits(
+ name=name, labels=labels, logits=predictions)
+ return unweighted_loss * weights, control_flow_ops.no_op()
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 8ae493ba99..2fbaa31d5e 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -16,10 +16,13 @@
Visualization and inspection:
@@dot_graph_from_checkpoint
+@@list_objects
@@object_metadata
Managing dependencies:
+@@capture_dependencies
@@Checkpointable
+@@CheckpointableBase
@@CheckpointableObjectGraph
@@NoDependency
@@split_dependency
@@ -38,13 +41,15 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
-from tensorflow.python.training.checkpointable.base import Checkpointable
-from tensorflow.python.training.checkpointable.base import NoDependency
+from tensorflow.python.training.checkpointable.base import CheckpointableBase
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
+from tensorflow.python.training.checkpointable.data_structures import NoDependency
+from tensorflow.python.training.checkpointable.tracking import Checkpointable
+from tensorflow.python.training.checkpointable.util import capture_dependencies
+from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(module_name=__name__)
-
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 3717d7f583..ac85c7be80 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -26,13 +26,14 @@ from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
class UniqueNameTrackerTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNames(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -48,11 +49,11 @@ class UniqueNameTrackerTests(test.TestCase):
slots.track(y, "y")
self.evaluate((x1.initializer, x2.initializer, x3.initializer,
y.initializer))
- save_root = checkpointable_utils.Checkpoint(slots=slots)
+ save_root = util.Checkpoint(slots=slots)
save_path = save_root.save(checkpoint_prefix)
- restore_slots = checkpointable.Checkpointable()
- restore_root = checkpointable_utils.Checkpoint(
+ restore_slots = tracking.Checkpointable()
+ restore_root = util.Checkpoint(
slots=restore_slots)
status = restore_root.restore(save_path)
restore_slots.x = resource_variable_ops.ResourceVariable(0.)
@@ -65,9 +66,9 @@ class UniqueNameTrackerTests(test.TestCase):
self.assertEqual(4., self.evaluate(restore_slots.x_1_1))
self.assertEqual(5., self.evaluate(restore_slots.y))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testExample(self):
- class SlotManager(checkpointable.Checkpointable):
+ class SlotManager(tracking.Checkpointable):
def __init__(self):
self.slotdeps = containers.UniqueNameTracker()
@@ -79,15 +80,15 @@ class UniqueNameTrackerTests(test.TestCase):
resource_variable_ops.ResourceVariable(4.), "y"))
slots.append(slotdeps.track(
resource_variable_ops.ResourceVariable(5.), "x"))
- self.slots = slots
+ self.slots = data_structures.NoDependency(slots)
manager = SlotManager()
self.evaluate([v.initializer for v in manager.slots])
- checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
+ checkpoint = util.Checkpoint(slot_manager=manager)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_path = checkpoint.save(checkpoint_prefix)
- metadata = checkpointable_utils.object_metadata(save_path)
+ metadata = util.object_metadata(save_path)
dependency_names = []
for node in metadata.nodes:
for child in node.children:
@@ -97,7 +98,7 @@ class UniqueNameTrackerTests(test.TestCase):
dependency_names,
["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayers(self):
tracker = containers.UniqueNameTracker()
tracker.track(layers.Dense(3), "dense")
diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
index 69dc0b9be2..00a805af25 100644
--- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py
+++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
@@ -23,8 +23,9 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
def _split_variable_closure(variable):
@@ -43,7 +44,7 @@ def _combine_variable_closure(variable):
return _consume_restore_buffer_fn
-class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+class SaveTensorSlicesAsDeps(base.CheckpointableBase):
def __init__(self):
self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
@@ -58,14 +59,14 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
self._track_checkpointable(dep, name=name)
-class HasRegularDeps(checkpointable.Checkpointable):
+class HasRegularDeps(tracking.Checkpointable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
-class OnlyOneDep(checkpointable.Checkpointable):
+class OnlyOneDep(tracking.Checkpointable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
@@ -73,9 +74,9 @@ class OnlyOneDep(checkpointable.Checkpointable):
class SplitTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveRestoreSplitDep(self):
- save_checkpoint = checkpointable_utils.Checkpoint(
+ save_checkpoint = util.Checkpoint(
dep=SaveTensorSlicesAsDeps())
self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
checkpoint_directory = self.get_temp_dir()
@@ -83,7 +84,7 @@ class SplitTests(test.TestCase):
save_path = save_checkpoint.save(checkpoint_prefix)
regular_deps = HasRegularDeps()
- regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+ regular_restore_checkpoint = util.Checkpoint(
dep=regular_deps)
regular_restore_checkpoint.restore(
save_path).assert_consumed().run_restore_ops()
@@ -91,7 +92,7 @@ class SplitTests(test.TestCase):
self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
one_dep = OnlyOneDep()
- one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+ one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep)
status = one_dep_restore_checkpoint.restore(save_path)
with self.assertRaises(AssertionError):
# Missing the second dependency.
@@ -99,7 +100,7 @@ class SplitTests(test.TestCase):
status.run_restore_ops()
self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
- restore_checkpoint = checkpointable_utils.Checkpoint()
+ restore_checkpoint = util.Checkpoint()
status = restore_checkpoint.restore(save_path)
restore_checkpoint.dep = SaveTensorSlicesAsDeps()
status.assert_consumed().run_restore_ops()
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index 42ba368531..523a9efcf0 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -50,6 +50,7 @@ py_library(
deps = [
":gen_bigquery_reader_ops",
":gen_gcs_config_ops",
+ "//tensorflow/contrib/bigtable",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
@@ -74,3 +75,14 @@ tf_py_test(
],
tags = ["manual"],
)
+
+tf_py_test(
+ name = "gcs_config_ops_test",
+ size = "small",
+ srcs = ["python/ops/gcs_config_ops_test.py"],
+ additional_deps = [
+ ":cloud_py",
+ "//tensorflow/python:client_testlib",
+ ],
+ tags = ["manual"],
+)
diff --git a/tensorflow/contrib/cloud/README.md b/tensorflow/contrib/cloud/README.md
new file mode 100644
index 0000000000..134ce057f4
--- /dev/null
+++ b/tensorflow/contrib/cloud/README.md
@@ -0,0 +1,18 @@
+# Cloud #
+
+## BigTable ##
+
+[Google Cloud BigTable](https://cloud.google.com/bigtable/) is a high
+performance storage system that can store and serve training data. This contrib
+package contains an experimental integration with TensorFlow.
+
+> **Status: Highly experimental.** The current implementation is very much in
+> flux. Please use at your own risk! :-)
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
+
+## Cloud Storage (GCS) ##
+
+The Google Cloud Storage ops allow the user to configure the GCS File System.
+
+<!-- TODO(saeta): Document usage / methods / etc. -->
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
index ef7aa7624c..af81106a68 100644
--- a/tensorflow/contrib/cloud/__init__.py
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -18,15 +18,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=line-too-long,wildcard-import
+import os
+
+# pylint: disable=line-too-long,wildcard-import,g-import-not-at-top
from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import *
from tensorflow.contrib.cloud.python.ops.gcs_config_ops import *
-# pylint: enable=line-too-long,wildcard-import
+
+if os.name != 'nt':
+ from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigTable
+ from tensorflow.contrib.bigtable.python.ops.bigtable_api import BigtableClient
+
+del os
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'BigQueryReader',
+ 'BigTable',
+ 'BigtableClient',
'BlockCacheParams',
'configure_colab_session',
'configure_gcs',
diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
index 5e31a15498..9cf85f5f18 100644
--- a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
+++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
@@ -21,12 +21,50 @@ namespace tensorflow {
REGISTER_OP("GcsConfigureCredentials")
.Input("json: string")
- .SetShapeFn(shape_inference::NoOutputs);
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Configures the credentials used by the GCS client of the local TF runtime.
+
+The json input can be of the format:
+
+1. Refresh Token:
+{
+ "client_id": "<redacted>",
+ "client_secret": "<redacted>",
+ "refresh_token: "<redacted>",
+ "type": "authorized_user",
+}
+
+2. Service Account:
+{
+ "type": "service_account",
+ "project_id": "<redacted>",
+ "private_key_id": "<redacted>",
+ "private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n",
+ "client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
+ "client_id": "<REDACTED>",
+ # Some additional fields elided
+}
+
+Note the credentials established through this method are shared across all
+sessions run on this runtime.
+
+Note be sure to feed the inputs to this op to ensure the credentials are not
+stored in a constant op within the graph that might accidentally be checkpointed
+or in other ways be persisted or exfiltrated.
+)doc");
REGISTER_OP("GcsConfigureBlockCache")
.Input("max_cache_size: uint64")
.Input("block_size: uint64")
.Input("max_staleness: uint64")
- .SetShapeFn(shape_inference::NoOutputs);
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Re-configures the GCS block cache with the new configuration values.
+
+If the values are the same as already configured values, this op is a no-op. If
+they are different, the current contents of the block cache is dropped, and a
+new block cache is created fresh.
+)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
index 4f7300fd1f..95e7e744d3 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
@@ -124,6 +124,7 @@ class ConfigureGcsHook(training.SessionRunHook):
self._credentials_placeholder)
else:
self._credentials_op = None
+
if self._block_cache:
self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache(
max_cache_size=self._block_cache.max_bytes,
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
new file mode 100644
index 0000000000..9b6c056d6c
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -0,0 +1,44 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the gcs_config_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.cloud.python.ops import gcs_config_ops
+from tensorflow.python.platform import test
+
+
+class GcsConfigOpsTest(test.TestCase):
+
+ def testSetBlockCache(self):
+ cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
+ with self.test_session() as sess:
+ gcs_config_ops.configure_gcs(sess, block_cache=cfg)
+
+ def testConfigureGcsHook(self):
+ creds = {'client_id': 'fake_client',
+ 'refresh_token': 'fake_token',
+ 'client_secret': 'fake_secret',
+ 'type': 'authorized_user'}
+ hook = gcs_config_ops.ConfigureGcsHook(credentials=creds)
+ hook.begin()
+ with self.test_session() as sess:
+ sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None
+ hook.after_create_session(sess, None)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 935ad5ff37..8f521ffee4 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -36,6 +36,7 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
@@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver):
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
- def _gkeMaster():
- return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ def _gkeEndpoints():
+ return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _envVarFallback():
@@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver):
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
- tpu = self._gkeMaster()
+ tpu = self._gkeEndpoints()
else:
tpu = self._envVarFallback()
@@ -170,10 +171,11 @@ class TPUClusterResolver(ClusterResolver):
if service is None and should_resolve:
if not _GOOGLE_API_CLIENT_INSTALLED:
- raise ImportError('googleapiclient must be installed before using the '
- 'TPU cluster resolver. Execute: `pip install '
- '--upgrade google-api-python-client` to install with '
- 'pip.')
+ raise ImportError('googleapiclient and oauth2client must be installed '
+ 'before using the TPU cluster resolver. Execute: '
+ '`pip install --upgrade google-api-python-client` '
+ 'and `pip install --upgrade oauth2client` to '
+ 'install with pip.')
final_discovery_url = self._discoveryUrl() or discovery_url
if final_discovery_url:
@@ -213,7 +215,7 @@ class TPUClusterResolver(ClusterResolver):
ValueError: If none of the TPUs specified exists.
"""
if not self._shouldResolve():
- return self._tpu
+ return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
job_tasks = self.cluster_spec().job_tasks(self._job_name)
if not job_tasks:
@@ -258,6 +260,7 @@ class TPUClusterResolver(ClusterResolver):
if 'state' in response and response['state'] != 'READY':
raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
(self._tpu, response['state']))
+
if 'health' in response and response['health'] != 'HEALTHY':
raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
response['health']))
@@ -278,8 +281,12 @@ class TPUClusterResolver(ClusterResolver):
# Case 3.
return None
# Case 2.
- cluster_spec = {self._job_name: [self._tpu[len(
- compat.as_bytes('grpc://')):]]}
+ cluster_spec = {
+ self._job_name: [
+ x[len(compat.as_bytes('grpc://')):]
+ for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))
+ ]
+ }
if self._coordinator_address:
# {1, 2}.a
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 7e002cc72f..ad4f643263 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -157,7 +157,7 @@ class TPUClusterResolverTest(test.TestCase):
job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
-
+
@mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
mock_request_compute_metadata)
def testUnhealthyCloudTpu(self):
@@ -402,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase):
compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
self.assertEqual(None, tpu_cluster_resolver.cluster_spec())
- def testGkeEnvironment(self):
+ def testGkeEnvironmentForDonut(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
- self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
+
+ self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(TPUClusterResolver._inGke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(TPUClusterResolver._gkeMaster()))
+ compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
+
+ tpu_cluster_resolver = TPUClusterResolver()
+ self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470'),
+ compat.as_bytes(tpu_cluster_resolver.master()))
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.120.27.5:8470' }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+ del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
+
+ def testGkeEnvironmentForPod(self):
+ os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
+ 'grpc://10.120.27.6:8470,'
+ 'grpc://10.120.27.7:8470,'
+ 'grpc://10.120.27.8:8470')
+
+ self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
+ self.assertTrue(TPUClusterResolver._inGke())
+ self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470,'
+ 'grpc://10.120.27.6:8470,'
+ 'grpc://10.120.27.7:8470,'
+ 'grpc://10.120.27.8:8470'),
+ compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
+
+ tpu_cluster_resolver = TPUClusterResolver()
+ self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470'),
+ compat.as_bytes(tpu_cluster_resolver.master()))
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.120.27.5:8470' }
+ tasks { key: 1 value: '10.120.27.6:8470' }
+ tasks { key: 2 value: '10.120.27.7:8470' }
+ tasks { key: 3 value: '10.120.27.8:8470' }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testDiscoveryUrl(self):
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 0708d6b7b9..a0a5b0e00c 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW)
# Options
option(tensorflow_VERBOSE "Enable for verbose output" OFF)
+
+if(WIN32)
+# BoringSSL is disabled for windows as it currently doesn't build with
+# MSBuild. (Ninja is required.)
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF)
+else()
+# BoringSSL is enabled for gRPC.
+option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON)
+endif()
+
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
@@ -290,17 +299,20 @@ include_directories(
${double_conversion_INCLUDE_DIR}
)
-if(tensorflow_ENABLE_SSL_SUPPORT)
- include(boringssl)
- list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES})
- list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl)
- include_directories(${boringssl_INCLUDE_DIR})
-endif()
if(tensorflow_ENABLE_GRPC_SUPPORT)
+ if(tensorflow_ENABLE_SSL_SUPPORT)
+ include(boringssl)
+ include_directories(${boringssl_INCLUDE_DIR})
+ endif()
include(grpc)
+ include_directories(${GRPC_INCLUDE_DIRS})
+ # Place boringssl after grpc as grpc depends on boringssl.
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${grpc_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES grpc)
- include_directories(${GRPC_INCLUDE_DIRS})
+ if(tensorflow_ENABLE_SSL_SUPPORT)
+ list(APPEND tensorflow_EXTERNAL_LIBRARIES ${boringssl_STATIC_LIBRARIES})
+ list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl)
+ endif()
endif()
if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
include(jemalloc)
@@ -327,40 +339,14 @@ endif()
# MKL Support
if (tensorflow_ENABLE_MKL_SUPPORT)
add_definitions(-DINTEL_MKL -DEIGEN_USE_VML)
- if (WIN32)
- find_path(MKL_HOME_PLATFORM mkl
- PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
- $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
- PATH_SUFFIXES windows)
- set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
- set(MKL_LINK_DIRS
- ${MKL_HOME_PLATFORM}/mkl/lib/intel64
- ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt
- ${MKL_HOME_PLATFORM}/compiler/lib/intel64
- ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib)
- set(MKL_REDIST_DLL_DIRS
- ${MKL_HOME_PLATFORM}/redist/intel64/mkl
- ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt
- ${MKL_HOME_PLATFORM}/redist/intel64/compiler)
- list(APPEND tensorflow_EXTERNAL_LIBRARIES
- mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64)
- endif()
- if (UNIX)
- # Fix me: complete the path on linux
- find_path(MKL_HOME_PLATFORM mkl
- HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../
- $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../
- PATH_SUFFIXES linux)
- set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include)
- set(MKL_LINK_DIRS) # incompleted
- set(MKL_REDIST_SO_DIRS) # incompleted
- endif()
- include_directories(${MKL_INCLUDE_DIRS})
- link_directories(${MKL_LINK_DIRS})
+ include(mkl)
+ list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkl_STATIC_LIBRARIES})
+ list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkl_copy_shared_to_destination)
+ include_directories(${mkl_INCLUDE_DIRS})
if (tensorflow_ENABLE_MKLDNN_SUPPORT)
include(mkldnn)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
- list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn)
+ list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
else (tensorflow_ENABLE_MKLDNN_SUPPORT)
add_definitions(-DINTEL_MKL_ML)
diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake
index 3c4bb01e24..fbb14b2515 100644
--- a/tensorflow/contrib/cmake/external/boringssl.cmake
+++ b/tensorflow/contrib/cmake/external/boringssl.cmake
@@ -17,7 +17,7 @@ include (ExternalProject)
set(boringssl_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src/boringssl/include)
#set(boringssl_EXTRA_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/boringssl/src)
set(boringssl_URL https://boringssl.googlesource.com/boringssl)
-set(boringssl_TAG ee7aa02)
+set(boringssl_TAG 7f8c553d7f4db0a6ce727f2986d41bf8fe8ec4bf)
set(boringssl_BUILD ${CMAKE_BINARY_DIR}/boringssl/src/boringssl-build)
#set(boringssl_LIBRARIES ${boringssl_BUILD}/obj/so/libboringssl.so)
set(boringssl_STATIC_LIBRARIES
diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake
index 527ccdc8d8..5c5adaf579 100644
--- a/tensorflow/contrib/cmake/external/double_conversion.cmake
+++ b/tensorflow/contrib/cmake/external/double_conversion.cmake
@@ -16,15 +16,15 @@ include (ExternalProject)
set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion)
set(double_conversion_URL https://github.com/google/double-conversion.git)
-set(double_conversion_TAG 5664746)
+set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8)
set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR})
set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so)
set(double_conversion_INCLUDES ${double_conversion_BUILD})
if(WIN32)
- set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib)
+ set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib)
else()
- set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a)
+ set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a)
endif()
set(double_conversion_HEADERS
diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake
index 693dc7cd67..b1e64aa55c 100644
--- a/tensorflow/contrib/cmake/external/grpc.cmake
+++ b/tensorflow/contrib/cmake/external/grpc.cmake
@@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc)
set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f)
if(WIN32)
+ # We use unsecure gRPC because boringssl does not build on windows
+ set(grpc_TARGET grpc++_unsecure)
+ set(grpc_DEPENDS protobuf zlib)
+ set(grpc_SSL_PROVIDER NONE)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
set(grpc_STATIC_LIBRARIES
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib
@@ -32,9 +36,12 @@ if(WIN32)
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib)
endif()
else()
+ set(grpc_TARGET grpc++)
+ set(grpc_DEPENDS boringssl protobuf zlib)
+ set(grpc_SSL_PROVIDER module)
set(grpc_STATIC_LIBRARIES
- ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
- ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
+ ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a
+ ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
@@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0)
ExternalProject_Add(grpc
PREFIX grpc
- DEPENDS protobuf zlib
+ DEPENDS ${grpc_DEPENDS}
GIT_REPOSITORY ${GRPC_URL}
GIT_TAG ${GRPC_TAG}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
BUILD_IN_SOURCE 1
BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES}
- BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure
+ BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET}
COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin
INSTALL_COMMAND ""
CMAKE_CACHE_ARGS
@@ -59,7 +66,7 @@ ExternalProject_Add(grpc
-DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS}
-DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES}
-DZLIB_ROOT:STRING=${ZLIB_INSTALL}
- -DgRPC_SSL_PROVIDER:STRING=NONE
+ -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER}
)
# grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h.
diff --git a/tensorflow/contrib/cmake/external/mkl.cmake b/tensorflow/contrib/cmake/external/mkl.cmake
new file mode 100644
index 0000000000..a172e3a41a
--- /dev/null
+++ b/tensorflow/contrib/cmake/external/mkl.cmake
@@ -0,0 +1,68 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+include (ExternalProject)
+
+# NOTE: Different from mkldnn.cmake, this file is meant to download mkl libraries
+set(mkl_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include)
+set(mkl_BIN_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkl/bin)
+set(mkl_WIN mklml_win_2018.0.3.20180406.zip) # match for v0.14
+set(mkl_MAC mklml_mac_2018.0.3.20180406.tgz)
+set(mkl_LNX mklml_lnx_2018.0.3.20180406.tgz)
+set(mkl_TAG v0.14)
+set(mkl_URL https://github.com/intel/mkl-dnn/releases)
+
+if (WIN32)
+ set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_WIN})
+ list(APPEND mkl_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.lib)
+ list(APPEND mkl_STATIC_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.lib)
+ list(APPEND mkl_SHARED_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/mklml.dll)
+ list(APPEND mkl_SHARED_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5md.dll)
+elseif (UNIX)
+ set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_LNX})
+ list(APPEND mkl_SHARED_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libiomp5.so)
+ list(APPEND mkl_SHARED_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_gnu.so)
+ list(APPEND mkl_SHARED_LIBRARIES
+ ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/lib/libmklml_intel.so)
+elseif (APPLE)
+ set(mkl_DOWNLOAD_URL ${mkl_URL}/download/${mkl_TAG}/${mkl_MAC})
+ #TODO need more information
+endif ()
+
+ExternalProject_Add(mkl
+ PREFIX mkl
+ URL ${mkl_DOWNLOAD_URL}
+ DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND ""
+ INSTALL_COMMAND "")
+
+# put mkl dynamic libraries in one bin directory
+add_custom_target(mkl_create_destination_dir
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${mkl_BIN_DIRS}
+ DEPENDS mkl)
+
+add_custom_target(mkl_copy_shared_to_destination DEPENDS mkl_create_destination_dir)
+
+foreach(dll_file ${mkl_SHARED_LIBRARIES})
+ add_custom_command(TARGET mkl_copy_shared_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dll_file} ${mkl_BIN_DIRS})
+endforeach()
diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake b/tensorflow/contrib/cmake/external/mkldnn.cmake
index a639fdee36..8123ee1f39 100644
--- a/tensorflow/contrib/cmake/external/mkldnn.cmake
+++ b/tensorflow/contrib/cmake/external/mkldnn.cmake
@@ -22,8 +22,11 @@ set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib)
+ set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.dll)
+ set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release)
else()
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib)
+ set(mkldnn_SHARED_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.dll)
endif()
else()
set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a)
@@ -31,6 +34,7 @@ endif()
ExternalProject_Add(mkldnn
PREFIX mkldnn
+ DEPENDS mkl
GIT_REPOSITORY ${mkldnn_URL}
GIT_TAG ${mkldnn_TAG}
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
@@ -40,5 +44,11 @@ ExternalProject_Add(mkldnn
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
- -DMKLINC:STRING=${MKL_INCLUDE_DIRS}
+ -DMKLINC:STRING=${mkl_INCLUDE_DIRS}
)
+
+# since mkldnn depends on mkl, copy the mkldnn.dll together with mklml.dll to mkl_bin_dirs
+add_custom_target(mkldnn_copy_shared_to_destination DEPENDS mkldnn)
+
+add_custom_command(TARGET mkldnn_copy_shared_to_destination PRE_BUILD
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different ${mkldnn_SHARED_LIBRARIES} ${mkl_BIN_DIRS})
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake
index b9d1dd88d4..eba3bcfc79 100644
--- a/tensorflow/contrib/cmake/external/nsync.cmake
+++ b/tensorflow/contrib/cmake/external/nsync.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public)
set(nsync_URL https://github.com/google/nsync)
-set(nsync_TAG 0559ce013feac8db639ee1bf776aca0325d28777)
+set(nsync_TAG 1.20.0)
set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync)
set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install)
diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake
index ab464bc99a..f56fb35a0f 100644
--- a/tensorflow/contrib/cmake/external/protobuf.cmake
+++ b/tensorflow/contrib/cmake/external/protobuf.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
set(PROTOBUF_URL https://github.com/google/protobuf.git)
-set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9)
+set(PROTOBUF_TAG v3.6.0)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fece56c412..40041d9c88 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -14,6 +14,7 @@ tensorflow/examples/tutorials
tensorflow/examples/tutorials/mnist
tensorflow/python
tensorflow/python/client
+tensorflow/python/compat
tensorflow/python/data
tensorflow/python/data/ops
tensorflow/python/data/util
@@ -35,6 +36,7 @@ tensorflow/python/keras
tensorflow/python/keras/applications
tensorflow/python/keras/datasets
tensorflow/python/keras/engine
+tensorflow/python/keras/estimator
tensorflow/python/keras/layers
tensorflow/python/keras/preprocessing
tensorflow/python/keras/utils
@@ -85,6 +87,8 @@ tensorflow/contrib/batching/python/ops
tensorflow/contrib/bayesflow
tensorflow/contrib/bayesflow/python
tensorflow/contrib/bayesflow/python/ops
+# tensorflow/contrib/bigtable/python
+# tensorflow/contrib/bigtable/python/ops
tensorflow/contrib/boosted_trees
tensorflow/contrib/boosted_trees/estimator_batch
tensorflow/contrib/boosted_trees/kernels
@@ -129,6 +133,7 @@ tensorflow/contrib/data
tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
+tensorflow/contrib/data/python/kernel_tests/serialization
tensorflow/contrib/data/python/ops
tensorflow/contrib/decision_trees
tensorflow/contrib/decision_trees/proto
@@ -236,6 +241,8 @@ tensorflow/contrib/keras/api/keras/wrappers/scikit_learn
tensorflow/contrib/kernel_methods
tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
+tensorflow/contrib/kinesis/python
+tensorflow/contrib/kinesis/python/ops
tensorflow/contrib/kfac
tensorflow/contrib/kfac/examples
tensorflow/contrib/kfac/python
diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index 2e0a2fcef4..7a30eb94f5 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -36,16 +36,3 @@ add_dependencies(
tf_cc_while_loop
tf_core_lib
tf_protos_cc)
-
-if(tensorflow_BUILD_PYTHON_BINDINGS)
- add_library(tf_c_python_api OBJECT
- "${tensorflow_source_dir}/tensorflow/c/python_api.cc"
- "${tensorflow_source_dir}/tensorflow/c/python_api.h"
- )
- add_dependencies(
- tf_c_python_api
- tf_c
- tf_core_lib
- tf_core_framework
- tf_protos_cc)
-endif()
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index dac84ccb0d..067c299a71 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -125,6 +125,7 @@ endfunction()
file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
+ "${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto"
)
@@ -233,15 +234,6 @@ if(WIN32)
list(APPEND tf_core_lib_srcs ${tf_core_platform_windows_srcs})
endif(WIN32)
-if(tensorflow_ENABLE_SSL_SUPPORT)
- # Cloud libraries require boringssl.
- file(GLOB tf_core_platform_cloud_srcs
- "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.h"
- "${tensorflow_source_dir}/tensorflow/core/platform/cloud/*.cc"
- )
- list(APPEND tf_core_lib_srcs ${tf_core_platform_cloud_srcs})
-endif()
-
if (tensorflow_ENABLE_HDFS_SUPPORT)
list(APPEND tf_core_platform_hdfs_srcs
"${tensorflow_source_dir}/tensorflow/core/platform/hadoop/hadoop_file_system.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 2d76bf530a..844f62649d 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -134,14 +134,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs})
endif(tensorflow_BUILD_CONTRIB_KERNELS)
-if(NOT tensorflow_ENABLE_SSL_SUPPORT)
- # Cloud libraries require boringssl.
- file(GLOB tf_core_kernels_cloud_srcs
- "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h"
- "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc"
- )
+# Cloud libraries require curl and boringssl.
+# Curl is not supported yet anyway so we remove for now.
+file(GLOB tf_core_kernels_cloud_srcs
+ "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.h"
+ "${tensorflow_source_dir}/tensorflow/contrib/cloud/kernels/*.cc"
+)
list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_cloud_srcs})
-endif()
file(GLOB_RECURSE tf_core_kernels_exclude_srcs
"${tensorflow_source_dir}/tensorflow/core/kernels/*test*.h"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 1959ad028a..8a9172b43c 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -456,6 +456,18 @@ add_custom_command(
COMMENT "Running SWIG to generate Python wrappers"
VERBATIM )
+add_library(tf_c_python_api OBJECT
+ "${tensorflow_source_dir}/tensorflow/c/python_api.cc"
+ "${tensorflow_source_dir}/tensorflow/c/python_api.h"
+)
+add_dependencies(
+ tf_c_python_api
+ tf_c
+ tf_core_lib
+ tf_core_framework
+ tf_protos_cc
+ tf_python_protos_cc)
+
set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h"
"${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc"
@@ -743,29 +755,113 @@ set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt")
file(WRITE "${api_init_list_file}" "${api_init_files}")
# Run create_python_api.py to generate __init__.py files.
+
+### TODO
+# In order to download and compile MKL/MKL-DNN automatically in cmake script, mkl-built libraries should be added to system path
+# to be loaded by python executor. However `add_custom_command` has an issue with `COMMAND ${CMAKE_COMMAND} -E env PATH=`, where
+# arguments of multiple paths (such as D:/;D:/mkl) will be parsed in to seperate string without semicolon and that command fail to
+# recongnize paths. As CUDA isn't built with MKL, the MKL built directory is the only path to this command to work around that issue.
+# To not override the CUDA and system path in other circumstances, `if-else` branch used here to handle this problem,
+# and should be removed if the path issue can be resolved.
+###
+
+if (tensorflow_ENABLE_MKL_SUPPORT)
+ # add mkl dist dlls to system path for python
+ # TODO: In current cmake version, PY_RUNTIME_ENV behaves strange with multiple paths,
+ # so we have to specify only one path in it to work around the issue. We need this if/else
+ # to protect overwriting CUDA environments
+ set(PY_RUNTIME_ENV ${mkl_BIN_DIRS})
+ add_custom_command(
+ OUTPUT ${api_init_files}
+ DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
+
+ # tensorflow/__init__.py depends on files generated in this step. So, remove it while
+ # this step is running since the files aren't there yet.
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+
+ # Run create_python_api.py to generate API init files.
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python PATH=${PY_RUNTIME_ENV} ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
+ "--package=tensorflow.python"
+ "--apiname=tensorflow"
+ "${api_init_list_file}"
+
+ COMMENT "Generating __init__.py files for Python API."
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
+ VERBATIM
+ )
+else (tensorflow_ENABLE_MKL_SUPPORT)
+ add_custom_command(
+ OUTPUT ${api_init_files}
+ DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
+
+ # tensorflow/__init__.py depends on files generated in this step. So, remove it while
+ # this step is running since the files aren't there yet.
+ COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
+
+ # Run create_python_api.py to generate API init files.
+ COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
+ "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
+ "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
+ "--package=tensorflow.python"
+ "--apiname=tensorflow"
+ "${api_init_list_file}"
+
+ COMMENT "Generating __init__.py files for Python API."
+ WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
+ )
+endif (tensorflow_ENABLE_MKL_SUPPORT)
+
+add_custom_target(tf_python_api SOURCES ${api_init_files})
+add_dependencies(tf_python_api tf_python_ops)
+
+# TODO(mikecase): This can be removed once tf.estimator is moved
+# out of TensorFlow.
+########################################################
+# Generate API __init__.py files for tf.estimator.
+########################################################
+
+# Parse tensorflow/tools/api/generator/BUILD to get list of generated files.
+FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text)
+STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text})
+string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text})
+string(REPLACE "," ";" api_init_files_list ${api_init_files_text})
+
+set(api_init_files "")
+foreach(api_init_file ${api_init_files_list})
+ string(STRIP "${api_init_file}" api_init_file)
+ if(api_init_file)
+ string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes
+ list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}")
+ endif()
+endforeach(api_init_file)
+set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt")
+file(WRITE "${estimator_api_init_list_file}" "${api_init_files}")
+
+# Run create_python_api.py to generate __init__.py files.
add_custom_command(
OUTPUT ${api_init_files}
DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops
- # tensorflow/__init__.py depends on files generated in this step. So, remove it while
- # this step is running since the files aren't there yet.
- COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py
-
# Run create_python_api.py to generate API init files.
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE}
"${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py"
- "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py"
- "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow"
- "${api_init_list_file}"
+ "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api"
+ "--package=tensorflow.python.estimator"
+ "--apiname=estimator"
+ "--output_package=tensorflow.python.estimator.api"
+ "${estimator_api_init_list_file}"
COMMENT "Generating __init__.py files for Python API."
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python"
)
-add_custom_target(tf_python_api SOURCES ${api_init_files})
-add_dependencies(tf_python_api tf_python_ops)
-
-
+add_custom_target(estimator_python_api SOURCES ${api_init_files})
+add_dependencies(estimator_python_api tf_python_ops)
############################################################
# Build a PIP package containing the TensorFlow runtime.
############################################################
@@ -776,6 +872,7 @@ add_dependencies(tf_python_build_pip_package
tf_python_touchup_modules
tf_python_ops
tf_python_api
+ estimator_python_api
tf_extension_ops)
# Fix-up Python files that were not included by the add_python_module() macros.
diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake
index 38f40452b5..fdf522f1fd 100644
--- a/tensorflow/contrib/cmake/tf_shared_lib.cmake
+++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake
@@ -145,3 +145,8 @@ install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/
# unsupported Eigen directory
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/
DESTINATION include/unsupported/Eigen)
+# mkl
+if (tensorflow_ENABLE_MKL_SUPPORT)
+ install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkl/src/mkl/include/
+ DESTINATION include/mkl)
+endif (tensorflow_ENABLE_MKL_SUPPORT)
diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake
index 2f70e59d54..6d634cb170 100644
--- a/tensorflow/contrib/cmake/tf_stream_executor.cmake
+++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake
@@ -64,8 +64,6 @@ file(GLOB tf_stream_executor_srcs
if (tensorflow_ENABLE_GPU)
file(GLOB tf_stream_executor_gpu_srcs
"${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc"
- "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h"
- "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc"
)
if (NOT tensorflow_BUILD_CC_TESTS)
file(GLOB tf_stream_executor_gpu_tests
diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py
index 0fbe3081af..0c997bd4fd 100644
--- a/tensorflow/contrib/coder/python/layers/entropybottleneck.py
+++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py
@@ -28,7 +28,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras import engine
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import init_ops
@@ -40,7 +40,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.summary import summary
-class EntropyBottleneck(engine.Layer):
+class EntropyBottleneck(base_layer.Layer):
"""Entropy bottleneck layer.
This layer can be used to model the entropy (the amount of information
@@ -262,7 +262,7 @@ class EntropyBottleneck(engine.Layer):
self._range_coder_precision = int(range_coder_precision)
self._data_format = data_format
self._channel_axis(2) # trigger ValueError early
- self.input_spec = engine.InputSpec(min_ndim=2)
+ self.input_spec = base_layer.InputSpec(min_ndim=2)
@property
def init_scale(self):
@@ -357,7 +357,7 @@ class EntropyBottleneck(engine.Layer):
channels = input_shape[channel_axis].value
if channels is None:
raise ValueError("The channel dimension of the inputs must be defined.")
- self.input_spec = engine.InputSpec(
+ self.input_spec = base_layer.InputSpec(
ndim=input_shape.ndims, axes={channel_axis: channels})
filters = (1,) + self.filters + (1,)
scale = self.init_scale ** (1 / (len(self.filters) + 1))
diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md
index c65a150464..cb1dd7d836 100644
--- a/tensorflow/contrib/constrained_optimization/README.md
+++ b/tensorflow/contrib/constrained_optimization/README.md
@@ -46,7 +46,7 @@ document.
Imagine that we want to constrain the recall of a binary classifier to be at
least 90%. Since the recall is proportional to the number of true positive
classifications, which itself is a sum of indicator functions, this constraint
-is non-differentible, and therefore cannot be used in a problem that will be
+is non-differentiable, and therefore cannot be used in a problem that will be
optimized using a (stochastic) gradient-based algorithm.
For this and similar problems, TFCO supports so-called *proxy constraints*,
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
index 04014ab4ae..3791dae8d7 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -169,8 +169,8 @@ def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
del old_inactive # Needed by the condition, but not the body.
iteration += 1
scale = (1.0 - standard_ops.reduce_sum(
- matrix, axis=0, keep_dims=True)) / standard_ops.maximum(
- 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True))
+ matrix, axis=0, keepdims=True)) / standard_ops.maximum(
+ 1.0, standard_ops.reduce_sum(inactive, axis=0, keepdims=True))
matrix += scale * inactive
new_inactive = standard_ops.to_float(matrix > 0)
matrix *= new_inactive
@@ -206,10 +206,10 @@ def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
# For numerical reasons, make sure that the largest matrix element is zero
# before exponentiating.
- log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True)
+ log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keepdims=True)
log_matrix -= standard_ops.log(
standard_ops.reduce_sum(
- standard_ops.exp(log_matrix), axis=0, keep_dims=True))
+ standard_ops.exp(log_matrix), axis=0, keepdims=True))
return log_matrix
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index 102bc460fd..a0dd3881a8 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''):
new_control_inputs, input_types, new_original_op,
op_def)
#Use Graph's hidden methods to add the op
- to_graph._add_op(new_op) # pylint: disable=protected-access
to_graph._record_op_seen_by_control_dependencies(new_op)
for device_function in reversed(to_graph._device_function_stack):
new_op._set_device(device_function(new_op))
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 8285ea0492..252ea1560d 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -768,7 +768,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLSTMCheckpointableSingleLayer(self):
num_units = 2
direction = CUDNN_RNN_UNIDIRECTION
@@ -781,7 +781,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGRUCheckpointableSingleLayer(self):
num_units = 2
direction = CUDNN_RNN_UNIDIRECTION
@@ -826,7 +826,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCudnnCompatibleLSTMCheckpointablMultiLayer(self):
num_units = 2
num_layers = 3
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 8822a7523f..748d7cd011 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import saver
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import tracking as checkpointable_lib
CUDNN_RNN_UNIDIRECTION = "unidirectional"
CUDNN_RNN_BIDIRECTION = "bidirectional"
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 3510e7b1ad..675330716b 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,19 +25,27 @@ See @{$guide/datasets$Importing Data} for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@RandomDataset
+@@Reducer
@@SqlDataset
+@@TFRecordWriter
@@assert_element_shape
@@batch_and_drop_remainder
@@bucket_by_sequence_length
@@choose_from_datasets
+@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
+
+@@get_single_element
+@@group_by_reducer
@@group_by_window
@@ignore_errors
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
+
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
@@ -50,8 +58,7 @@ See @{$guide/datasets$Importing Data} for an overview.
@@sliding_window_batch
@@sloppy_interleave
@@unbatch
-
-@@get_single_element
+@@unique
"""
from __future__ import absolute_import
@@ -71,14 +78,18 @@ from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset
from tensorflow.contrib.data.python.ops.error_ops import ignore_errors
from tensorflow.contrib.data.python.ops.get_single_element import get_single_element
from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length
+from tensorflow.contrib.data.python.ops.grouping import group_by_reducer
from tensorflow.contrib.data.python.ops.grouping import group_by_window
+from tensorflow.contrib.data.python.ops.grouping import Reducer
from tensorflow.contrib.data.python.ops.interleave_ops import choose_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
+from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
@@ -88,6 +99,8 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.unique import unique
+from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 97cc0bc6c9..4657807785 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
namespace tensorflow {
@@ -103,12 +102,11 @@ class CSVDatasetOp : public DatasetOpKernel {
OP_REQUIRES(
ctx, select_cols.empty() || select_cols.front() >= 0,
errors::InvalidArgument("select_cols should be non-negative indices"));
- bool select_all_cols = select_cols.empty();
- *output = new Dataset(
- ctx, std::move(filenames), header, buffer_size, output_types_,
- output_shapes_, std::move(record_defaults), std::move(select_cols),
- select_all_cols, use_quote_delim, delim[0], std::move(na_value));
+ *output = new Dataset(ctx, std::move(filenames), header, buffer_size,
+ output_types_, output_shapes_,
+ std::move(record_defaults), std::move(select_cols),
+ use_quote_delim, delim[0], std::move(na_value));
}
private:
@@ -118,8 +116,7 @@ class CSVDatasetOp : public DatasetOpKernel {
int64 buffer_size, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
- bool select_all_cols, bool use_quote_delim, char delim,
- string na_value)
+ bool use_quote_delim, char delim, string na_value)
: GraphDatasetBase(ctx),
filenames_(std::move(filenames)),
header_(header),
@@ -128,7 +125,6 @@ class CSVDatasetOp : public DatasetOpKernel {
output_shapes_(output_shapes),
record_defaults_(std::move(record_defaults)),
select_cols_(std::move(select_cols)),
- select_all_cols_(select_all_cols),
use_quote_delim_(use_quote_delim),
delim_(delim),
na_value_(std::move(na_value)) {}
@@ -166,11 +162,24 @@ class CSVDatasetOp : public DatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ bool select_all = dataset()->select_cols_.empty();
do {
// We are currently processing a file, so try to read the next record
- if (buffered_input_stream_) {
- Status s = ReadRecord(ctx, out_tensors);
- if (s.ok() || !errors::IsOutOfRange(s)) {
+ if (input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors, select_all,
+ dataset()->select_cols_);
+ if (s.ok()) {
+ // Validate output
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument(
+ "Expect ", dataset()->out_type_.size(), " fields but have ",
+ out_tensors->size(), " in record");
+ }
+
+ *end_of_sequence = false;
+ return s;
+ }
+ if (!errors::IsOutOfRange(s)) {
// Not at the end of file, return OK or non-EOF errors to caller.
*end_of_sequence = false;
return s;
@@ -203,145 +212,317 @@ class CSVDatasetOp : public DatasetOpKernel {
}
private:
- // Reads a record by parsing the input buffer, and converting extracted
+ // Reads an entire CSV row from the input stream, either from the
+ // existing buffer or by filling the buffer as needed. Converts extracted
// fields to output tensors as we go.
- Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors)
+ //
+ // When this function is called, pos_ should be the index of the first
+ // character of the record in buffer_, or past the end of the buffer.
+ // Note: ctx and out_tensors are only used in this function
+ // when fields are included in the record.
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool select_all, const std::vector<int64>& selected)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // Extracts fields from line(s) from the buffered input stream.
- out_tensors->reserve(dataset()->record_defaults_.size());
-
- string input;
- TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input));
-
- size_t current_idx = 0;
- size_t num_fields_parsed = 0;
- size_t selector_idx = 0; // Keep track of index into select_cols
-
- while (current_idx < input.size()) {
- // In each iteration, parse one field
- if (input[current_idx] == '\n' || input[current_idx] == '\r') {
- // This should never happen, because buffered input reader splits
- // input on newlines.
- return errors::InvalidArgument("Parsing error.");
- }
+ if (pos_ >= buffer_.size()) {
+ // At the end of the file, this will return errors::OutOfRange
+ TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
+ pos_ = 0;
+ }
+
+ // The first character may be \n if this is the continuation of a
+ // \r\n linebreak between this and the previous record. If so, skip it.
+
+ bool end_of_record = false; // Keep track of when we find \n, \r or EOF
+ size_t num_parsed = 0;
+ size_t num_selected_parsed = 0;
+
+ Status result;
- bool quoted = false;
+ while (!end_of_record) { // Read till we reach \n, \r or EOF
bool include =
- (dataset()->select_all_cols_ ||
- dataset()->select_cols_[selector_idx] == num_fields_parsed);
+ select_all || (num_selected_parsed < selected.size() &&
+ selected[num_selected_parsed] == num_parsed);
- if (dataset()->use_quote_delim_ && input[current_idx] == '"') {
- quoted = true;
- current_idx++;
+ // Don't fail fast, so that the next call to GetNext may still return
+ // a valid record
+ result.Update(
+ ParseOneField(ctx, out_tensors, &end_of_record, include));
+
+ num_parsed++;
+ if (include) num_selected_parsed++;
+ }
+
+ return result;
+ }
+
+ // Parses one field from position pos_ in the buffer. Fields are
+ // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
+ // the next field.
+ Status ParseOneField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // If we get here, this means the previous field's end coincided
+ // with the end of the buffer. We can fill the buffer without abandon.
+ Status s = FillBuffer(&buffer_);
+
+ if (errors::IsOutOfRange(s)) {
+ // Reached EOF, and last field is empty
+ *end_of_record = true;
+ if (include) {
+ return FieldToOutput(ctx, StringPiece(), out_tensors);
+ } else {
+ return Status::OK();
+ }
+ } else if (!s.ok()) {
+ return s; // Surface other errors back to caller
}
- // Parse the body of the field
- string field;
- if (!quoted) {
- while (current_idx < input.size() &&
- input[current_idx] != dataset()->delim_) {
- if ((dataset()->use_quote_delim_ && input[current_idx] == '"') ||
- input[current_idx] == '\n' || input[current_idx] == '\r') {
- return errors::InvalidArgument(
- "Unquoted fields cannot have quotes/CRLFs inside");
+ pos_ = 0;
+ }
+
+ if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
+ return ParseQuotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ // For keeping track of relevant parts of a field from a previous buffer
+ struct Piece {
+ size_t start;
+ size_t len;
+ string buffer;
+
+ Piece(string buffer, size_t start, size_t len)
+ : start(start), len(len), buffer(std::move(buffer)) {}
+ };
+
+ // Given that pos_ exceeds the buffer, saves the relevant part of the
+ // current buffer (if necessary), fills the buffer, and resets indices to
+ // 0.
+ Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
+ size_t* start, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ string temp_buffer;
+
+ buffer_.swap(temp_buffer);
+ if (include && pos_ > *start) {
+ earlier_pieces->push_back(
+ Piece(std::move(temp_buffer), *start, pos_ - *start));
+ }
+ pos_ = 0;
+ *start = 0;
+ return FillBuffer(&buffer_);
+ }
+
+ // Parses unquoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseQuotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ pos_++; // Starting quotation mark
+
+ Status parse_result;
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument(
+ "Reached end of file without closing quoted field in "
+ "record");
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+ if (ch == '"') {
+ // When we encounter a quote, we look ahead to the next character to
+ // decide what to do
+ pos_++;
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ // This was the last field. We are done
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(), out_tensors, earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s;
}
- if (include) field += input[current_idx];
- current_idx++;
- } // Exit condition: end of input, or current index at delim
+ }
+
+ char next = buffer_[pos_];
+ pos_++;
+ if (next == dataset()->delim_) {
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ return parse_result;
+
+ } else if (next == '\n' || next == '\r') {
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ if (next == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ } else if (next != '"') {
+ // Take note of the error, but keep going to end of field.
+ include = false; // So we don't get funky errors when trying to
+ // unescape the quotes.
+ parse_result.Update(errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another quote"));
+ }
- // Go to next field or the end
- current_idx++;
} else {
- // Quoted field needs to be ended with '"' and delim or end
- while (true) {
- if (current_idx >= input.size() - 1 || input.empty()) {
- if (current_idx == input.size() - 1 &&
- input[current_idx] == '"') {
- // We're at the end of the input, and the quote terminates the
- // record. Go to end.
- current_idx++;
- break;
- }
- // If there's no terminating quote, it means our buffered record
- // line reader split a record up. This can happen if there is a
- // newline encased in quotes. The next line is also part of the
- // record, so we read it and reset the index.
- if (include && current_idx == input.size() - 1) {
- // TODO(rachelim): Instead of building up a string, keep track
- // of terminal indices (or starting char* and length)
- // Also look into using /lib/strings/Scanner
- field += input[current_idx];
- }
- if (include) {
- field += '\n';
- }
- current_idx = 0;
- Status s = buffered_input_stream_->ReadLine(&input);
- if (!s.ok()) {
- return errors::InvalidArgument(
- "Quoted field has to end with quote followed by delim, "
- "CRLF, or EOF");
- }
- } else if (input[current_idx] == '"' &&
- input[current_idx + 1] == dataset()->delim_) {
- // End of field, go to next field or end
- current_idx += 2;
- break;
- } else if (input[current_idx] == '"') {
- // Current char is a quote. Since we're not at end of field,
- // the next character must also be a quote.
- if (input[current_idx + 1] != '"') {
- return errors::InvalidArgument(
- "Quote inside a string has to be escaped by another "
- "quote");
- }
- if (include) field += '"';
- current_idx += 2;
- } else {
- if (include) field += input[current_idx];
- current_idx++;
- }
+ pos_++;
+ }
+ }
+ }
+
+ // Converts quoted field to an output tensor, removing the starting
+ // and ending quotes from it and unescaping double quotations if
+ // necessary.
+ Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ if (field.find('\"', 1) == field.size() - 1) {
+ // `field` contains no escaped quotation marks.
+ // Exclude framing quotation marks
+ field.remove_prefix(1);
+ field.remove_suffix(1);
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+ }
+ string field_complete;
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ field_complete.reserve(str_len);
+
+ // This bool flips every time we see a quote, so that we skip the second
+ // quote of every pair of adjacent quotes in the field. We need to track
+ // this across iterations of the for loop because adjacent double quotes
+ // may be in different buffers. Initialize to true because we also skip
+ // the opening quotation mark of the quoted field.
+ bool skip_next_quote = true;
+ for (const Piece& p : earlier_pieces) {
+ AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
+ &field_complete, &skip_next_quote);
+ }
+ AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
+ StringPiece result = StringPiece(field_complete);
+ result.remove_suffix(1); // Skip final quote
+
+ return FieldToOutput(ctx, result, out_tensors);
+ }
+
+ void AppendUnescapedPiece(StringPiece piece, string* field_complete,
+ bool* skip_next_quote) {
+ size_t from = 0;
+ size_t found = piece.find('\"', from);
+ while (found != string::npos) {
+ if (!*skip_next_quote) {
+ // This is the first quote in a pair of adjacent double quotes
+ field_complete->append(piece.data() + from, found + 1 - from);
+ }
+ *skip_next_quote = !*skip_next_quote;
+ from = found + 1;
+ found = piece.find('\"', from);
+ }
+ // Include the chunk after the last quotation mark in the string
+ if (from < piece.size()) {
+ field_complete->append(piece.data() + from, piece.size() - from);
+ }
+ }
+
+ // Parses unquoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseUnquotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ Status parse_result;
+
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ // Handle errors
+ if (errors::IsOutOfRange(s)) {
+ // Whatever we have is the last field of the last record
+ *end_of_record = true;
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
}
}
- num_fields_parsed++;
+ char ch = buffer_[pos_];
- if (include) {
- // Add the tensor to the result
- TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field),
- selector_idx, out_tensors));
- selector_idx++;
- // Terminate early if we have all the fields we want
- if (selector_idx == dataset()->select_cols_.size())
- return Status::OK();
+ if (ch == dataset()->delim_) {
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ pos_++;
+ return parse_result;
}
- } // Exit condition: current_idx has reached the end of record
-
- // Check if the last field is empty, and include it if necessary
- bool include =
- (dataset()->select_all_cols_ ||
- dataset()->select_cols_[selector_idx] == num_fields_parsed);
- if (include && !input.empty() &&
- input[input.size() - 1] == dataset()->delim_) {
- TF_RETURN_IF_ERROR(
- FieldToOutput(ctx, string(), selector_idx, out_tensors));
+ if (ch == '\n' || ch == '\r') {
+ // need special case to skip over first \n of record if the line
+ // breaks are \r\n
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ *end_of_record = true;
+ pos_++;
+ if (ch == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ }
+ if (dataset()->use_quote_delim_ && ch == '"') {
+ // Take note of the error, but keep going to end of field.
+ parse_result.Update(errors::InvalidArgument(
+ "Unquoted fields cannot have quotes inside"));
+ }
+ // Otherwise, go to next character
+ pos_++;
}
+ }
- // Check that number of fields matches
- if (out_tensors->size() != dataset()->out_type_.size()) {
- return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
- " fields but have ",
- out_tensors->size(), " in record");
+ Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ result->clear();
+ Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result);
+
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ // Ignore OutOfRange error when ReadNBytes read < N bytes.
+ return Status::OK();
}
- return Status::OK();
+ return s;
}
- // Given a string field, and its index in the output,
- // converts it to a Tensor of the right type and adds it to the
- // out_tensors vector.
- Status FieldToOutput(IteratorContext* ctx, string field,
- size_t output_idx,
+ // Given a field, converts it to the right output tensor type
+ Status FieldToOutput(IteratorContext* ctx, StringPiece field,
std::vector<Tensor>* out_tensors) {
+ size_t output_idx = out_tensors->size();
if (output_idx >= dataset()->out_type_.size()) {
// We can get here if we're selecting all columns, but the number of
// fields exceeds the number of defaults provided
@@ -397,7 +578,7 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->record_defaults_[output_idx].flat<float>()(0);
} else {
float value;
- if (!strings::safe_strtof(field.c_str(), &value)) {
+ if (!strings::safe_strtof(field, &value)) {
return errors::InvalidArgument(
"Field ", output_idx,
" in record is not a valid float: ", field);
@@ -412,7 +593,7 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->record_defaults_[output_idx].flat<double>()(0);
} else {
double value;
- if (!strings::safe_strtod(field.c_str(), &value)) {
+ if (!strings::safe_strtod(field, &value)) {
return errors::InvalidArgument(
"Field ", output_idx,
" in record is not a valid double: ", field);
@@ -426,7 +607,7 @@ class CSVDatasetOp : public DatasetOpKernel {
component.scalar<string>()() =
dataset()->record_defaults_[output_idx].flat<string>()(0);
} else {
- component.scalar<string>()() = std::move(field);
+ component.scalar<string>()() = field.ToString();
}
break;
}
@@ -439,6 +620,50 @@ class CSVDatasetOp : public DatasetOpKernel {
return Status::OK();
}
+ // Records can be delimited by "\r\n" line breaks. When we encounter a
+ // '\r', we have to check the next character to see if it is part of the
+ // linebreak, and ignore it if so.
+ void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ Status s = FillBuffer(&buffer_);
+ pos_ = 0;
+ // If we failed to fill buffer, it doesn't matter because we're done
+ // with the record
+ if (!s.ok()) return;
+ }
+ if (buffer_[pos_] == '\n') {
+ pos_++;
+ }
+ }
+
+ // Given a string field, and its index in the output,
+ // converts it to a Tensor of the right type and adds it to the
+ // out_tensors vector.
+ Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ string field_complete;
+ field_complete.reserve(str_len);
+
+ for (const Piece& p : earlier_pieces) {
+ field_complete.append(p.buffer, p.start, p.len);
+ }
+
+ field_complete.append(field.data(), field.size());
+ return FieldToOutput(ctx, field_complete, out_tensors);
+ }
+
// Sets up reader streams to read from the file at `current_file_index_`.
Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (current_file_index_ >= dataset()->filenames_.size()) {
@@ -452,16 +677,18 @@ class CSVDatasetOp : public DatasetOpKernel {
dataset()->filenames_[current_file_index_], &file_));
input_stream_.reset(
new io::RandomAccessInputStream(file_.get(), false));
- // TODO(rachelim): Maintain our own buffer so we don't read every record
- // twice
- buffered_input_stream_.reset(new io::BufferedInputStream(
- input_stream_.get(), dataset()->buffer_size_, false));
+ buffer_.clear();
+ pos_ = 0;
if (dataset()->header_) {
- // Ignore header line
- string str;
- Status s = buffered_input_stream_->ReadLine(&str);
- if (errors::IsOutOfRange(s)) {
- return errors::InvalidArgument("Can't read header of empty file");
+ // Read one line, but don't include it. Pass nullptrs as dummy
+ // pointers to objects that shouldn't be invoked anyway
+ // We need to process this as a record here instead of just finding
+ // the first newline because it might contain quoted fields with
+ // newlines in the header as well
+ std::vector<int64> empty;
+ Status s = ReadRecord(nullptr, nullptr, false, empty);
+ if (!s.ok()) {
+ return errors::InvalidArgument("Can't read header of file");
}
}
return Status::OK();
@@ -470,15 +697,15 @@ class CSVDatasetOp : public DatasetOpKernel {
// Resets all reader streams.
void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
input_stream_.reset();
- buffered_input_stream_.reset();
file_.reset();
}
mutex mu_;
+ string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
+ size_t pos_ GUARDED_BY(
+ mu_); // Index into the buffer must be maintained between iters
std::unique_ptr<io::RandomAccessInputStream> input_stream_
GUARDED_BY(mu_);
- std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
- GUARDED_BY(mu_);
size_t current_file_index_ GUARDED_BY(mu_) = 0;
std::unique_ptr<RandomAccessFile> file_
GUARDED_BY(mu_); // must outlive input_stream_
@@ -491,7 +718,6 @@ class CSVDatasetOp : public DatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const std::vector<Tensor> record_defaults_;
const std::vector<int64> select_cols_;
- const bool select_all_cols_;
const bool use_quote_delim_;
const char delim_;
const string na_value_;
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index a2bfce0362..b3d464d716 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -40,7 +40,8 @@ class FunctionBufferingResource : public ResourceBase {
const NameAttrList& func, int64 buffer_size,
const string& source_device,
const string& target_device,
- const std::vector<Tensor>& func_args)
+ const std::vector<Tensor>& func_args,
+ const DataTypeVector& output_types)
: lib_(lib),
pflr_(std::move(pflr)),
func_(func),
@@ -48,6 +49,7 @@ class FunctionBufferingResource : public ResourceBase {
source_device_(source_device),
target_device_(target_device),
func_args_(func_args),
+ output_types_(output_types),
handle_(kInvalidHandle),
is_buffering_(false),
end_of_sequence_(false),
@@ -176,6 +178,13 @@ class FunctionBufferingResource : public ResourceBase {
AllocatorAttributes arg_alloc_attr;
arg_alloc_attr.set_on_host(true);
opts.args_alloc_attrs.push_back(arg_alloc_attr);
+ for (const auto& dtype : output_types_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
if (opts.source_device != target_device_) {
opts.remote_execution = true;
}
@@ -233,6 +242,7 @@ class FunctionBufferingResource : public ResourceBase {
const string source_device_;
const string target_device_;
const std::vector<Tensor> func_args_;
+ const DataTypeVector output_types_;
FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
@@ -250,6 +260,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
}
~FunctionBufferResourceHandleOp() override {
@@ -269,18 +280,20 @@ class FunctionBufferResourceHandleOp : public OpKernel {
std::vector<Tensor> func_args;
func_args.push_back(*string_arg);
+ const string& source_device = ctx->device()->name();
+
// Obtain and canonicalize target_device.
const Tensor* target_arg;
OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
- const string& target_device =
- DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar<string>()());
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr,
errors::Internal("No function library is provided."));
- const string& source_device = ctx->device()->name();
-
mutex_lock l(mu_);
if (!initialized_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
@@ -297,7 +310,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
this](FunctionBufferingResource** ptr) {
*ptr = new FunctionBufferingResource(
clone_lib, std::move(pflr), func_, buffer_size_,
- source_device, target_device, func_args);
+ source_device, target_device, func_args, output_types_);
return Status::OK();
}));
core::ScopedUnref s(buffer);
@@ -319,6 +332,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
int64 buffer_size_;
string container_;
string name_;
+ DataTypeVector output_types_;
};
REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index 3dfc3741c2..141706f393 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace {
@@ -24,19 +25,32 @@ namespace {
class ThreadPoolResource : public ResourceBase {
public:
ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
- const string& name, int num_threads, bool low_latency_hint)
- : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) {
- }
+ const string& name, int num_threads, bool low_latency_hint,
+ int max_intra_op_parallelism)
+ : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {}
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn) {
- thread_pool_.Schedule(std::move(fn));
+ if (max_intra_op_parallelism_ < 0) {
+ thread_pool_.Schedule(std::move(fn));
+ } else {
+ thread_pool_.Schedule(std::bind(
+ [this](std::function<void()> bound_fn) {
+ // TODO(mrry): Consider moving this thread-local configuration to
+ // the threads themselves.
+ ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
+ bound_fn();
+ },
+ std::move(fn)));
+ }
}
string DebugString() override { return "ThreadPoolResource"; }
private:
thread::ThreadPool thread_pool_;
+ const int max_intra_op_parallelism_;
};
// Creates a handle to a ThreadPool resource. Note that we don't use
@@ -48,6 +62,8 @@ class ThreadPoolHandleOp : public OpKernel {
explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+ &max_intra_op_parallelism_));
OP_REQUIRES(
ctx, num_threads_ > 0,
errors::InvalidArgument("`num_threads` must be greater than zero."));
@@ -78,7 +94,7 @@ class ThreadPoolHandleOp : public OpKernel {
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
*ret = new ThreadPoolResource(
ctx->env(), {}, display_name_,
- num_threads_,
+ num_threads_, max_intra_op_parallelism_,
false /* low_latency_hint */);
return Status::OK();
}));
@@ -95,6 +111,7 @@ class ThreadPoolHandleOp : public OpKernel {
bool initialized_ GUARDED_BY(mu_) = false;
string display_name_;
int num_threads_;
+ int max_intra_op_parallelism_;
};
class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index f271d269ab..8413fcaf87 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -104,6 +104,7 @@ REGISTER_OP("FunctionBufferingResource")
.Attr("container: string")
.Attr("f: func")
.Attr("buffer_size: int")
+ .Attr("output_types: list(type)")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Creates a resource that fills up a buffer by making function calls.
@@ -117,6 +118,7 @@ container: If non-empty, this resource is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this resource will be shared under the given name
across multiple sessions.
+output_types: The type list for the return values.
)doc");
REGISTER_OP("FunctionBufferingResourceGetNext")
@@ -158,6 +160,7 @@ REGISTER_OP("ThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
@@ -166,6 +169,8 @@ Creates a custom thread pool with the given number of threads.
handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
num_threads: The number of threads in the thread pool.
+max_intra_op_parallelism: The maximum degree of parallelism to use within
+ operations that execute on this threadpool.
display_name: A human-readable name for the threads that may be visible in
some visualizations.
)doc");
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index c483a43769..9a454efc4c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test")
py_test(
name = "batch_dataset_op_test",
@@ -16,20 +16,23 @@ py_test(
"no_pip",
],
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -39,7 +42,6 @@ py_test(
srcs = ["bucketing_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -48,24 +50,33 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
- name = "concatenate_dataset_op_test",
+ name = "csv_dataset_op_test",
size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
+ srcs = ["csv_dataset_op_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:error_ops",
+ "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:readers",
"//third_party/py/numpy",
],
)
@@ -80,103 +91,44 @@ py_test(
"nomac", # b/62040583
],
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
],
)
-py_library(
- name = "dataset_serialization_test",
- srcs = [
- "dataset_serialization_test_base.py",
- ],
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:iterator_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "csv_dataset_op_test",
- size = "small",
- srcs = ["csv_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
- name = "filter_dataset_op_test",
+ name = "get_single_element_test",
size = "small",
- srcs = ["filter_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "optonly",
- ],
+ srcs = ["get_single_element_test.py"],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-tf_py_test(
- name = "flat_map_dataset_op_test",
- size = "medium",
- srcs = ["flat_map_dataset_op_test.py"],
- additional_deps = [
- ":dataset_serialization_test",
- "//third_party/py/numpy",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/contrib/data/python/ops:get_single_element",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
],
- grpc_enabled = True,
- tags = ["no_pip"],
)
py_test(
@@ -191,10 +143,8 @@ py_test(
"notap",
],
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:array_ops",
- "//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@@ -202,43 +152,28 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
+ "@six_archive//:six",
],
)
py_test(
- name = "directed_interleave_dataset_test",
- size = "medium",
- srcs = ["directed_interleave_dataset_test.py"],
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-tf_py_test(
- name = "get_single_element_test",
- size = "small",
- srcs = ["get_single_element_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/contrib/data/python/ops:get_single_element",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python:array_ops",
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:model_fn",
],
)
@@ -253,91 +188,114 @@ py_test(
"optonly",
],
deps = [
- ":dataset_serialization_test",
+ "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/contrib/data/python/ops:error_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
"//tensorflow/python:io_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
"//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
- name = "prefetch_dataset_op_test",
+ name = "optimize_dataset_op_test",
size = "small",
- srcs = ["prefetch_dataset_op_test.py"],
+ srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/python:platform",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
],
)
+cuda_py_test(
+ name = "prefetching_ops_test",
+ size = "small",
+ srcs = ["prefetching_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
py_test(
name = "range_dataset_op_test",
size = "small",
srcs = ["range_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:counter",
"//tensorflow/contrib/data/python/ops:enumerate_ops",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
],
)
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
+ "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
py_test(
name = "reader_dataset_ops_test",
size = "medium",
srcs = ["reader_dataset_ops_test.py"],
- shard_count = 4,
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
+ ":reader_dataset_ops_test_base",
"//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
"//third_party/py/numpy",
],
)
@@ -364,6 +322,7 @@ py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
],
)
@@ -374,13 +333,14 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:scan_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -388,55 +348,55 @@ py_test(
)
py_test(
- name = "sequence_dataset_op_test",
+ name = "shuffle_dataset_op_test",
size = "medium",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/python:array_ops",
+ "//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
- name = "serialization_integration_test",
+ name = "slide_dataset_op_test",
size = "small",
- srcs = ["serialization_integration_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ srcs = ["slide_dataset_op_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/contrib/data/python/ops:sliding",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
)
-py_test(
- name = "shuffle_dataset_op_test",
- size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
+py_library(
+ name = "sql_dataset_op_test_base",
+ srcs = ["sql_dataset_op_test_base.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ visibility = [
+ "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
+ "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
+ ],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
+ "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//third_party/py/numpy",
+ "@org_sqlite//:python",
],
)
@@ -445,14 +405,12 @@ py_test(
size = "small",
srcs = ["sql_dataset_op_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
+ ":sql_dataset_op_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "@org_sqlite//:python",
],
)
@@ -463,11 +421,15 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
+ ":reader_dataset_ops_test_base",
"//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
)
@@ -481,8 +443,12 @@ py_test(
"//tensorflow/contrib/data/python/ops:threadpool",
"//tensorflow/contrib/data/python/ops:unique",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -493,87 +459,49 @@ py_test(
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
- ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/contrib/stateless",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
py_test(
- name = "zip_dataset_op_test",
- size = "small",
- srcs = ["zip_dataset_op_test.py"],
+ name = "window_dataset_op_test",
+ size = "medium",
+ srcs = ["window_dataset_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_test",
- size = "small",
- srcs = ["prefetching_ops_test.py"],
- additional_deps = [
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
+ tags = [
+ "no_pip",
],
-)
-
-tf_py_test(
- name = "slide_dataset_op_test",
- size = "small",
- srcs = ["slide_dataset_op_test.py"],
- additional_deps = [
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/contrib/data/python/ops:sliding",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
-tf_py_test(
+py_test(
name = "writer_ops_test",
size = "small",
srcs = ["writer_ops_test.py"],
- additional_deps = [
+ deps = [
"//tensorflow/contrib/data/python/ops:writers",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
"//tensorflow/python:lib",
- "//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index b5fbc45ad3..42adfd17f0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import math
import time
+from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
@@ -40,7 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase):
+class BatchDatasetTest(test.TestCase, parameterized.TestCase):
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
@@ -293,7 +293,7 @@ class BatchDatasetTest(test.TestCase):
ph2: np.arange(8).astype(np.int32)
})
with self.assertRaises(errors.InvalidArgumentError):
- print(sess.run(next_element))
+ sess.run(next_element)
# No 0th dimension (i.e. scalar value) for one component.
sess.run(
@@ -303,7 +303,7 @@ class BatchDatasetTest(test.TestCase):
ph2: 7
})
with self.assertRaises(errors.InvalidArgumentError):
- print(sess.run(next_element))
+ sess.run(next_element)
def testBatchAndDropRemainder(self):
components = (np.arange(7),
@@ -427,9 +427,13 @@ class BatchDatasetTest(test.TestCase):
self.assertEqual([None], dataset.output_shapes[1][0].as_list())
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
- def _testMapAndBatchDatasetHelper(self,
- num_parallel_calls=None,
- num_parallel_batches=None):
+ @parameterized.named_parameters(
+ ("default", None, None),
+ ("sequential_calls", 1, None),
+ ("parallel_calls", 2, None),
+ ("parallel_batches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset ->
# RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
@@ -500,19 +504,11 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
- def testMapAndBatch(self):
- return self._testMapAndBatchDatasetHelper()
-
- def testMapAndBatchWithParallelBatches(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_batches=10)
-
- def testMapAndBatchWithSequentialCalls(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_calls=1)
-
- def testMapAndBatchWithParallelCalls(self):
- return self._testMapAndBatchDatasetHelper(num_parallel_calls=2)
-
- def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False):
+ @parameterized.named_parameters(
+ ("even", False),
+ ("uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
dataset_ops.Dataset.range(10).apply(
batching.map_and_batch(
@@ -532,12 +528,6 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- def testMapAndBatchPartialBatch(self):
- return self._testMapAndBatchPartialBatchHelper()
-
- def testMapAndBatchPartialBatchDropRemainder(self):
- return self._testMapAndBatchPartialBatchHelper(drop_remainder=True)
-
def testMapAndBatchYieldsPartialBatch(self):
iterator = (dataset_ops.Dataset.range(10)
.apply(batching.map_and_batch(
@@ -614,7 +604,7 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testMapAndBatchDatasetFails(self):
+ def testMapAndBatchFails(self):
"""Test a dataset that maps a TF function across its input elements."""
dataset = dataset_ops.Dataset.from_tensors(
array_ops.check_numerics(
@@ -628,7 +618,7 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
- def testMapAndBatchDatasetShapeMismatch(self):
+ def testMapAndBatchShapeMismatch(self):
"""Test a dataset that maps a TF function across its input elements."""
def generator():
@@ -651,173 +641,79 @@ class BatchDatasetTest(test.TestCase):
"number of elements does not match"):
sess.run(get_next)
+ def testMapAndBatchImplicitDispose(self):
+ # Tests whether a map and batch dataset will be cleaned up correctly when
+ # the pipeline does not run it until exhaustion.
+ # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
+ # MapAndBatchDataset(f=square_3, batch_size=100).
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
-class BatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
- return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
+ 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
+ dataset = dataset.prefetch(5)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len // batch_size
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
+ with self.test_session() as sess:
+ for _ in range(3):
+ sess.run(get_next)
- def _build_dataset_dense_to_sparse(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
+ @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ def testMapAndBatchOutOfRangeError(self, threshold):
- def testDenseToSparseBatchDatasetCore(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
-
- num_outputs = len(components) // 4
- self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
- lambda: self._build_dataset_dense_to_sparse(diff_comp),
- num_outputs)
-
- def _sparse(self, i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
+ def raising_py_fn(i):
+ if i >= threshold:
+ raise StopIteration()
+ else:
+ return i
- def _build_dataset_sparse(self, batch_size=5):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
-
- def testSparseCore(self):
- self.run_core_tests(self._build_dataset_sparse,
- lambda: self._build_dataset_sparse(2), 2)
-
- def _build_dataset_nested_sparse(self):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
-
- def testNestedSparseCore(self):
- self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+ iterator = (
+ dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10)).make_one_shot_iterator())
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ for i in range(threshold // 10):
+ self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
+ if threshold % 10 != 0:
+ self.assertAllEqual(
+ [threshold // 10 * 10 + j for j in range(threshold % 10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
-class UnbatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ @parameterized.parameters(
+ (False, dtypes.bool),
+ (-42, dtypes.int8),
+ (-42, dtypes.int16),
+ (-42, dtypes.int32),
+ (-42, dtypes.int64),
+ (42, dtypes.uint8),
+ (42, dtypes.uint16),
+ (42.0, dtypes.float16),
+ (42.0, dtypes.float32),
+ (42.0, dtypes.float64),
+ (b"hello", dtypes.string),
+ )
+ def testMapAndBatchTypes(self, element, dtype):
+ def gen():
+ yield element
+
+ dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+ batching.map_and_batch(lambda x: x, batch_size=10))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
-
- return dataset_ops.Dataset.from_tensor_slices(components).batch(
- batch_size).apply(batching.unbatch())
-
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
-
-
-class MapAndBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testNumParallelBatches(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_batches = 2
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_batches=num_parallel_batches,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
- def testNumParallelCalls(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_calls = 7
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
-
-class PaddedBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testPaddedBatch(self):
-
- def build_dataset(seq_lens):
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- lambda x: array_ops.fill([x], x)).padded_batch(
- 4, padded_shapes=[-1])
-
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
-
- def testPaddedBatchNonDefaultPadding(self):
-
- def build_dataset(seq_lens):
-
- def fill_tuple(x):
- filled = array_ops.fill([x], x)
- return (filled, string_ops.as_string(filled))
-
- padded_shape = [-1]
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- fill_tuple).padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>"))
-
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
+ with self.test_session() as sess:
+ for _ in range(10):
+ self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
class RestructuredDatasetTest(test.TestCase):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index bd3e034211..2022c1f2bd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -21,7 +21,6 @@ import random
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -68,7 +67,7 @@ class GroupByReducerTest(test.TestCase):
reducer = grouping.Reducer(
init_func=lambda _: (0.0, 0.0),
reduce_func=reduce_fn,
- finalize_func=lambda x: x[0])
+ finalize_func=lambda x, _: x)
for i in range(1, 11):
dataset = dataset_ops.Dataset.range(2 * i).apply(
grouping.group_by_reducer(
@@ -121,7 +120,7 @@ class GroupByReducerTest(test.TestCase):
reducer = grouping.Reducer(
init_func=lambda x: ([0], 1),
reduce_func=reduce_fn,
- finalize_func=lambda x: x)
+ finalize_func=lambda x, y: (x, y))
for i in range(1, 11):
dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
@@ -176,37 +175,27 @@ class GroupByReducerTest(test.TestCase):
dataset.apply(
grouping.group_by_reducer(lambda _: "wrong", reducer))
+ def testTuple(self):
+ def init_fn(_):
+ return np.array([], dtype=np.int64), np.int64(0)
-class GroupByReducerSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ def reduce_fn(state, value):
+ s1, s2 = state
+ v1, v2 = value
+ return array_ops.concat([s1, [v1]], 0), s2 + v2
- def _build_dataset(self, components):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
+ def finalize_fn(s1, s2):
+ return s1, s2
- return dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_reducer(lambda x: x % 5, reducer))
-
- def testCoreGroupByReducer(self):
- components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 5,
- verify_exhausted=True)
+ reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
+ grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual(x, np.asarray([x for x in range(10)]))
+ self.assertEqual(y, 45)
class GroupByWindowTest(test.TestCase):
@@ -353,34 +342,6 @@ class GroupByWindowTest(test.TestCase):
self.assertEqual(len(components), sum(counts))
-class GroupByWindowSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
-
- def testCoreGroupByWindow(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 12,
- verify_exhausted=False)
-
-
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
@@ -655,7 +616,44 @@ class BucketBySequenceLength(test.TestCase):
batch_sizes = batch_sizes[:-1]
self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(boundaries), sorted(lengths_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.test_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
def testTupleElements(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 8c138c7081..df115175f5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -25,6 +25,7 @@ import time
import numpy as np
+from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers as core_readers
@@ -32,7 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@@ -61,12 +62,12 @@ class CsvDatasetOpTest(test.TestCase):
op2 = sess.run(next2)
self.assertAllEqual(op1, op2)
- def setup_files(self, inputs):
+ def setup_files(self, inputs, linebreak='\n'):
filenames = []
for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i)
- with open(fn, 'w') as f:
- f.write('\n'.join(ip))
+ fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
+ with open(fn, 'wb') as f:
+ f.write(linebreak.join(ip).encode('utf-8'))
filenames.append(fn)
return filenames
@@ -75,7 +76,7 @@ class CsvDatasetOpTest(test.TestCase):
filenames = self.setup_files(inputs)
dataset_expected = core_readers.TextLineDataset(filenames)
dataset_expected = dataset_expected.map(
- lambda l: gen_parsing_ops.decode_csv(l, **kwargs))
+ lambda l: parsing_ops.decode_csv(l, **kwargs))
dataset_actual = readers.CsvDataset(filenames, **kwargs)
return (dataset_actual, dataset_expected)
@@ -86,38 +87,47 @@ class CsvDatasetOpTest(test.TestCase):
inputs, **kwargs)
self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+ def _verify_output_or_err(self,
+ sess,
+ dataset,
+ expected_output=None,
+ expected_err_re=None):
+ nxt = dataset.make_one_shot_iterator().get_next()
+ if expected_err_re is None:
+ # Verify that output is expected, without errors
+ expected_output = [[
+ v.encode('utf-8') if isinstance(v, str) else v for v in op
+ ] for op in expected_output]
+ for value in expected_output:
+ op = sess.run(nxt)
+ self.assertAllEqual(op, value)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(nxt)
+ else:
+ # Verify that OpError is produced as expected
+ with self.assertRaisesOpError(expected_err_re):
+ while True:
+ try:
+ sess.run(nxt)
+ except errors.OutOfRangeError:
+ break
+
def _test_dataset(self,
inputs,
expected_output=None,
expected_err_re=None,
+ linebreak='\n',
**kwargs):
"""Checks that elements produced by CsvDataset match expected output."""
# Convert str type because py3 tf strings are bytestrings
- filenames = self.setup_files(inputs)
+ filenames = self.setup_files(inputs, linebreak)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
dataset = readers.CsvDataset(filenames, **kwargs)
- nxt = dataset.make_one_shot_iterator().get_next()
- if expected_err_re is None:
- # Verify that output is expected, without errors
- expected_output = [[
- v.encode('utf-8') if isinstance(v, str) else v for v in op
- ] for op in expected_output]
- for value in expected_output:
- op = sess.run(nxt)
- self.assertAllEqual(op, value)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
- else:
- # Verify that OpError is produced as expected
- with self.assertRaisesOpError(expected_err_re):
- while True:
- try:
- sess.run(nxt)
- except errors.OutOfRangeError:
- break
-
- def testCsvDataset_floatRequired(self):
+ self._verify_output_or_err(sess, dataset, expected_output,
+ expected_err_re)
+
+ def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4
inputs = [['1,2,3,4']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
@@ -137,10 +147,55 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
- def testCsvDataset_withQuoted(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
+ def testCsvDataset_withEmptyFields(self):
+ record_defaults = [[0]] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Unquoted fields cannot have quotes inside',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['"a"b","c","d"']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Quote inside a string has to be escaped by another quote',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
+ filenames = self.setup_files(inputs)
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
+ filenames = self.setup_files(inputs)
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
def testCsvDataset_mixedTypes(self):
record_defaults = [
@@ -164,11 +219,6 @@ class CsvDatasetOpTest(test.TestCase):
self._test_by_comparison(
inputs, record_defaults=record_defaults, field_delim=':')
- def testCsvDataset_withEmptyValues(self):
- record_defaults = [[0]] * 4
- inputs = [['1,,3,4', ',6,7,8']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
def testCsvDataset_withNaValue(self):
record_defaults = [[0]] * 4
inputs = [['1,NA,3,4', 'NA,6,7,8']]
@@ -176,8 +226,8 @@ class CsvDatasetOpTest(test.TestCase):
inputs, record_defaults=record_defaults, na_value='NA')
def testCsvDataset_withSelectCols(self):
- record_defaults = [[0]] * 2
- inputs = [['1,2,3,4', '5,6,7,8']]
+ record_defaults = [['']] * 2
+ inputs = [['1,2,3,4', '"5","6","7","8"']]
self._test_by_comparison(
inputs, record_defaults=record_defaults, select_cols=[1, 2])
@@ -190,27 +240,17 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults=record_defaults,
select_cols=[3, 4])
+ def testCsvDataset_withOneCol(self):
+ record_defaults = [['NA']]
+ inputs = [['0', '', '2']]
+ self._test_dataset(
+ inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
+
def testCsvDataset_withMultipleFiles(self):
record_defaults = [[0]] * 4
inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
self._test_by_comparison(inputs, record_defaults=record_defaults)
- def testCsvDataset_withNewLine(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_withMultipleNewLines(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
def testCsvDataset_withLeadingAndTrailingSpaces(self):
record_defaults = [[0.0]] * 4
inputs = [['0, 1, 2, 3']]
@@ -266,9 +306,10 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_errorWithHeaderEmptyFile(self):
record_defaults = [[0]] * 2
inputs = [[]]
+ expected_err_re = "Can't read header of file"
self._test_dataset(
inputs,
- expected_err_re="Can't read header of empty file",
+ expected_err_re=expected_err_re,
record_defaults=record_defaults,
header=True,
)
@@ -284,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['', '1,2']] # First record is empty
self._test_dataset(
inputs,
- expected_err_re='Expect 2 fields but have 0 in record',
+ expected_err_re='Expect 2 fields but have 1 in record',
record_defaults=record_defaults)
def testCsvDataset_withChainedOps(self):
@@ -301,7 +342,7 @@ class CsvDatasetOpTest(test.TestCase):
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
- record_defaults = [dtypes.float32, dtypes.float32]
+ record_defaults = [dtypes.float32, [0.0]]
inputs = [['1.0,2.0', '3.0,4.0']]
self._test_dataset(
inputs,
@@ -326,6 +367,162 @@ class CsvDatasetOpTest(test.TestCase):
self.assertEqual(result, sorted(result))
+## The following tests exercise parsing logic for quoted fields
+
+ def testCsvDataset_withQuoted(self):
+ record_defaults = [['']] * 4
+ inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withOneColAndQuotes(self):
+ record_defaults = [['']]
+ inputs = [['"0"', '"1"', '"2"']]
+ self._test_dataset(
+ inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLine(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLineInUnselectedCol(self):
+ record_defaults = [['']]
+ inputs = [['1,"2\n3",4', '5,6,7']]
+ self._test_dataset(
+ inputs,
+ expected_output=[['1'], ['5']],
+ record_defaults=record_defaults,
+ select_cols=[0])
+
+ def testCsvDataset_withMultipleNewLines(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithTerminateMidRecord(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,"a']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Reached end of file without closing quoted field in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withEscapedQuotes(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+
+## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
+## and different types of line breaks
+
+ def testCsvDataset_withInvalidBufferSize(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,d']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='buffer_size should be positive',
+ record_defaults=record_defaults,
+ buffer_size=0)
+
+ def testCsvDataset_withBufferSize(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs, expected, record_defaults=record_defaults, buffer_size=i + 1)
+
+ def testCsvDataset_withCR(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak='\r',
+ record_defaults=record_defaults,
+ buffer_size=i + 1)
+
+ def testCsvDataset_withCRLF(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ record_defaults=record_defaults,
+ buffer_size=i + 1)
+
+ def testCsvDataset_withBufferSizeAndQuoted(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak='\n',
+ record_defaults=record_defaults,
+ buffer_size=i + 1)
+ self._test_dataset(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRAndQuoted(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak='\r',
+ record_defaults=record_defaults,
+ buffer_size=i + 1)
+ self._test_dataset(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRLFAndQuoted(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ for i in range(20):
+ # Test a range of buffer sizes that should all work
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ record_defaults=record_defaults,
+ buffer_size=i + 1)
+ self._test_dataset(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.
@@ -343,7 +540,7 @@ class CsvDatasetBenchmark(test.Benchmark):
self._filenames = []
for n in self._num_cols:
fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
- with open(fn, 'w') as f:
+ with open(fn, 'wb') as f:
# Just write 100 rows and use `repeat`... Assumes the cost
# of creating an iterator is not significant
row = ','.join([str_val for _ in range(n)])
@@ -384,7 +581,7 @@ class CsvDatasetBenchmark(test.Benchmark):
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [[0.0]] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
self._tearDown()
@@ -394,7 +591,7 @@ class CsvDatasetBenchmark(test.Benchmark):
num_cols = self._num_cols[i]
kwargs = {'record_defaults': [['']] * num_cols}
dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
self._tearDown()
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index a842502cc6..a2ab3de52e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -17,14 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -70,63 +66,5 @@ class DatasetConstructorTest(test.TestCase):
# pylint: enable=protected-access
-class DatasetConstructorSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_tensor_dataset(self, variable_array):
- components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
-
- return dataset_ops.Dataset.from_tensors(components)
-
- def testFromTensorsCore(self):
- # Equal length components
- arr = np.array(1)
- num_outputs = 1
- diff_arr = np.array(2)
- self.run_core_tests(lambda: self._build_tensor_dataset(arr),
- lambda: self._build_tensor_dataset(diff_arr),
- num_outputs)
-
- def _build_tensor_slices_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components)
-
- def testFromTensorSlicesCore(self):
- # Equal length components
- components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0]))
-
- diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[5], [6], [7], [8]]), 22),
- np.array([1.0, 2.0, 3.0, 4.0]))
-
- dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
-
- self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
- lambda: self._build_tensor_slices_dataset(diff_comp), 4)
- self.run_core_tests(
- lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
-
- def _build_sparse_tensor_slice_dataset(self, slices):
- indices = np.array(
- [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
- dtype=np.int64)
- values = np.array([val for s in slices for val in s], dtype=np.float64)
- dense_shape = np.array(
- [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
- sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
- return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
-
- def testFromSparseTensorSlicesCore(self):
- slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
- diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
-
- self.run_core_tests(
- lambda: self._build_sparse_tensor_slice_dataset(slices),
- lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
- 9,
- sparse_tensors=True)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 34b6a080c0..9b1857de1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
@@ -34,8 +33,8 @@ class DirectedInterleaveDatasetTest(test.TestCase):
input_datasets = [
dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
]
- dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset,
- input_datasets)
+ dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
@@ -144,24 +143,5 @@ class DirectedInterleaveDatasetTest(test.TestCase):
], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
-class SampleFromDatasetsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, probs, num_samples):
- dataset = interleave_ops.sample_from_datasets(
- [
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(len(probs))
- ],
- probs,
- seed=1813)
- return dataset.take(num_samples)
-
- def testSerializationCore(self):
- self.run_core_tests(
- lambda: self._build_dataset([0.5, 0.5], 100),
- lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index bee561e3e2..44c3325a3d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -22,10 +22,8 @@ import math
import threading
import time
-import numpy as np
from six.moves import zip_longest
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
@@ -38,132 +36,6 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
- repeat_count = 2
- return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
- repeat_count).interleave(
- lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
-
- def testSerializationCore(self):
- input_values = np.array([4, 5, 6], dtype=np.int64)
- num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
- num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # pylint: enable=g-long-lambda
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
- _interleave_fn, cycle_length=1)
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
-class ParallelInterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self.input_values = np.array([4, 5, 6], dtype=np.int64)
- self.num_repeats = 2
- self.num_outputs = np.sum(self.input_values) * 2
-
- def _build_ds(self, cycle_length, block_length, sloppy=False):
- return (dataset_ops.Dataset.from_tensor_slices(
- self.input_values).repeat(self.num_repeats).apply(
- interleave_ops.parallel_interleave(
- lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
- cycle_length, block_length, sloppy)))
-
- def testSerializationCore(self):
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
- self.run_core_tests(
- lambda: self._build_ds(cycle_length, block_length),
- lambda: self._build_ds(cycle_length * 2, block_length * 1),
- self.num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
-
- def testSerializationWithSloppy(self):
- break_points = self.gen_break_points(self.num_outputs, 10)
- expected_outputs = np.repeat(
- np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
- self.num_repeats).tolist()
-
- def run_test(cycle_length, block_length):
- actual = self.gen_outputs(
- lambda: self._build_ds(cycle_length, block_length, True),
- break_points, self.num_outputs)
- self.assertSequenceEqual(sorted(actual), expected_outputs)
-
- # cycle_length > 1, block_length > 1
- run_test(2, 3)
- # cycle_length = 1
- run_test(1, 3)
- # block_length = 1
- run_test(2, 1)
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).apply(
- interleave_ops.parallel_interleave(_interleave_fn, 1))
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
class ParallelInterleaveDatasetTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 30a993b1f7..30a993b1f7 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index 8d40429279..b7025f3802 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -17,27 +17,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import hashlib
+import itertools
import os
+import time
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.util import compat
+_NUMPY_RANDOM_SEED = 42
+
class MapDatasetTest(test.TestCase):
@@ -143,229 +144,125 @@ class MapDatasetTest(test.TestCase):
sess.run(get_next)
-class MapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 14
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (
- dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(self._num_epochs))
-
- def testSaveRestoreCore(self):
- self.run_core_tests(
- self._build_ds,
- lambda: self._build_ds(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(_map_fn)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1)))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testSparseCore(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1]))
-
- def _build_ds(num_outputs):
- return dataset_ops.Dataset.range(num_outputs).map(_sparse)
-
- num_outputs = 10
- self.run_core_tests(lambda: _build_ds(num_outputs),
- lambda: _build_ds(int(num_outputs / 2)), num_outputs)
-
-
-class ParallelMapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 1
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
-
- def _build_ds_with_prefetch(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
-
- def testSaveRestoreCore(self):
- for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
- self.run_core_tests(
- ds_fn,
- lambda: ds_fn(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(
- _map_fn, num_parallel_calls=2).prefetch(2)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1),
- num_parallel_calls=2).prefetch(2))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
-
-class IgnoreErrorsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors())
-
- def testIgnoreErrorsCore(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
- diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
- num_outputs = 4
- self.run_core_tests(lambda: self._build_ds(components),
- lambda: self._build_ds(diff_components), num_outputs)
-
+class MapDatasetBenchmark(test.Benchmark):
+
+ # The purpose of this benchmark is to compare the performance of chaining vs
+ # fusing of the map and batch transformations across various configurations.
+ #
+ # NOTE: It is recommended to build the benchmark with
+ # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ # and execute it on a machine with at least 32 CPU cores.
+ def benchmarkMapAndBatch(self):
+
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+ hashlib.sha1(label).hexdigest(),
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+
+ print("%s:" % label)
+ for num_calls, inter_op, element_size, batch_size in series:
+
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+ element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
+
+ chained_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+ for _ in range(5):
+ sess.run(chained_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(chained_get_next.op)
+ end = time.time()
+ chained_deltas.append(end - start)
+
+ fused_dataset = dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
+ for _ in range(5):
+ sess.run(fused_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(fused_get_next.op)
+ end = time.time()
+ fused_deltas.append(end - start)
+
+ print(
+ "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print("")
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..3bb9723bbc
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test.TestCase):
+
+ def testDefaultOptimizations(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize())
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testEmptyOptimizations(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize([]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimization(self):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch(
+ 10).apply(optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(
+ any([node.op == "MapAndBatchDatasetV2" for node in graph.node]))
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testFunctionLibraryDefinitionModification(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply(
+ optimization.optimize(["_test_only_function_rename"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(errors.NotFoundError,
+ "Function .* is not defined."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index b08132cd72..82543b1039 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -21,6 +21,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -68,6 +69,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
with ops.device(device1):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_remote_fn,
+ output_types=[dtypes.float32],
target_device=target,
string_arg=ds_iterator_handle,
buffer_size=3,
@@ -85,8 +87,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
return (prefetch_op, reset_op, destroy_op)
def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
@@ -125,8 +126,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
"/job:localhost/replica:0/task:0/gpu:0")
def testReinitialization(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -166,8 +166,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
device0 = "/job:localhost/replica:0/task:0/cpu:0"
device1 = "/job:localhost/replica:0/task:0/cpu:1"
@@ -201,6 +200,49 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.test_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
class PrefetchToDeviceTest(test.TestCase):
@@ -227,14 +269,43 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.test_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
device_dataset = host_dataset.apply(
@@ -258,8 +329,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element["a"].dtype)
self.assertEqual([], next_element["a"].shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual({"a": i}, sess.run(next_element))
@@ -292,8 +362,7 @@ class PrefetchToDeviceTest(test.TestCase):
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
for i in range(10):
actual = sess.run(next_element)
@@ -343,8 +412,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- worker_config = config_pb2.ConfigProto()
- worker_config.device_count["CPU"] = 2
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
with self.test_session(config=worker_config) as sess:
sess.run(iterator.initializer)
for i in range(5):
@@ -377,5 +445,467 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
+class CopyToDeviceTest(test.TestCase):
+
+ def testCopyToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceInt32(self):
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int32, next_element.dtype)
+ self.assertEqual((4,), next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDevice(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDeviceWithPrefetch(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32AndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStringsAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDevicePingPongCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with compat.forward_compatibility_horizon(2018, 8, 4):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
+ back_to_cpu_dataset = device_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = back_to_cpu_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInitAndPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInitAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 80e1cb0041..592642da0c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -17,21 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import counter
from tensorflow.contrib.data.python.ops import enumerate_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -81,88 +73,5 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual(-2, sess.run(negative_get_next))
-class RangeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _iterator_checkpoint_prefix_local(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_prefix_local(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_prefix_local()),
- dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def testSaveRestore(self):
-
- def _build_graph(start, stop):
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Saving and restoring in same session.
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def _build_range_dataset(self, start, stop):
- return dataset_ops.Dataset.range(start, stop)
-
- def testRangeCore(self):
- start = 2
- stop = 10
- stop_1 = 8
- self.run_core_tests(lambda: self._build_range_dataset(start, stop),
- lambda: self._build_range_dataset(start, stop_1),
- stop - start)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index e0237198b7..9df403ef50 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -17,426 +17,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import gzip
import os
-import zlib
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-class TextLineDatasetTestBase(test.TestCase):
-
- def _lineText(self, f, l):
- return compat.as_bytes("%d: %d" % (f, l))
-
- def _createFiles(self,
- num_files,
- num_lines,
- crlf=False,
- compression_type=None):
- filenames = []
- for i in range(num_files):
- fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
- filenames.append(fn)
- contents = []
- for j in range(num_lines):
- contents.append(self._lineText(i, j))
- # Always include a newline after the record unless it is
- # at the end of the file, in which case we include it
- if j + 1 != num_lines or i == 0:
- contents.append(b"\r\n" if crlf else b"\n")
- contents = b"".join(contents)
-
- if not compression_type:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
-
- return filenames
-
-
-class TextLineDatasetSerializationTest(
- TextLineDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, test_filenames, compression_type=None):
- return core_readers.TextLineDataset(
- test_filenames, compression_type=compression_type, buffer_size=10)
-
- def testTextLineCore(self):
- compression_types = [None, "GZIP", "ZLIB"]
- num_files = 5
- lines_per_file = 5
- num_outputs = num_files * lines_per_file
- for compression_type in compression_types:
- test_filenames = self._createFiles(
- num_files,
- lines_per_file,
- crlf=True,
- compression_type=compression_type)
- # pylint: disable=cell-var-from-loop
- self.run_core_tests(
- lambda: self._build_iterator_graph(test_filenames, compression_type),
- lambda: self._build_iterator_graph(test_filenames), num_outputs)
- # pylint: enable=cell-var-from-loop
-
-
-class FixedLengthRecordReaderTestBase(test.TestCase):
-
- def setUp(self):
- super(FixedLengthRecordReaderTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self._header_bytes = 5
- self._record_bytes = 3
- self._footer_bytes = 2
-
- def _record(self, f, r):
- return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with open(fn, "wb") as f:
- f.write(b"H" * self._header_bytes)
- for j in range(self._num_records):
- f.write(self._record(i, j))
- f.write(b"F" * self._footer_bytes)
- return filenames
-
-
-class FixedLengthRecordDatasetSerializationTest(
- FixedLengthRecordReaderTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, num_epochs, compression_type=None):
- filenames = self._createFiles()
- return core_readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes,
- self._footer_bytes).repeat(num_epochs)
-
- def testFixedLengthRecordCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
-
-class TFRecordDatasetTestBase(test.TestCase):
-
- def setUp(self):
- super(TFRecordDatasetTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- self.test_filenames = self._createFiles()
-
- self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
- self.num_epochs = array_ops.placeholder_with_default(
- constant_op.constant(1, dtypes.int64), shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
- self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = core_readers.TFRecordDataset(
- self.filenames, self.compression_type).repeat(self.num_epochs)
- batch_dataset = repeat_dataset.batch(self.batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- self.init_op = iterator.make_initializer(repeat_dataset)
- self.init_batch_op = iterator.make_initializer(batch_dataset)
- self.get_next = iterator.get_next()
-
- def _record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
-
-
-class TFRecordDatasetSerializationTest(
- TFRecordDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self,
- num_epochs,
- batch_size=1,
- compression_type=None,
- buffer_size=None):
- filenames = self._createFiles()
- if compression_type is "ZLIB":
- zlib_files = []
- for i, fn in enumerate(filenames):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
- filenames = zlib_files
-
- elif compression_type is "GZIP":
- gzip_files = []
- for i, fn in enumerate(self.test_filenames):
- with open(fn, "rb") as f:
- gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(gzfn, "wb") as gzf:
- gzf.write(f.read())
- gzip_files.append(gzfn)
- filenames = gzip_files
-
- return core_readers.TFRecordDataset(
- filenames, compression_type,
- buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
-
- def testTFRecordWithoutBufferCore(self):
- num_epochs = 5
- batch_size = num_epochs
- num_outputs = num_epochs * self._num_files * self._num_records // batch_size
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, batch_size,
- buffer_size=0),
- lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
- num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
- num_outputs * batch_size)
- # pylint: enable=g-long-lambda
-
- def testTFRecordWithBufferCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
- def testTFRecordWithCompressionCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
-
-
-def _interleave(iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
-
-class ReadBatchFeaturesTest(test.TestCase):
-
- def setUp(self):
- super(ReadBatchFeaturesTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self.test_filenames = self._createFiles()
-
- def _read_batch_features(self,
- filenames,
- num_epochs,
- batch_size,
- reader_num_threads=1,
- parser_num_threads=1,
- shuffle=False,
- shuffle_seed=None,
- drop_final_batch=False):
- self.filenames = filenames
- self.num_epochs = num_epochs
- self.batch_size = batch_size
-
- return readers.make_batched_features_dataset(
- file_pattern=self.filenames,
- batch_size=self.batch_size,
- features={
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
- },
- reader=core_readers.TFRecordDataset,
- num_epochs=self.num_epochs,
- shuffle=shuffle,
- shuffle_seed=shuffle_seed,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads,
- drop_final_batch=drop_final_batch).make_one_shot_iterator(
- ).get_next()
-
- def _record(self, f, r):
- example = example_pb2.Example(
- features=feature_pb2.Features(
- feature={
- "file":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[f])),
- "record":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[r])),
- "keywords":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
- }))
- return example.SerializeToString()
-
- def _get_keywords(self, f, r):
- num_keywords = 1 + (f + r) % 2
- keywords = []
- for index in range(num_keywords):
- keywords.append(compat.as_bytes("keyword%d" % index))
- return keywords
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
-
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
- return sess.run([
- file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
- ])
-
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=1):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return _interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for record in next_records:
- f = record[0]
- r = record[1]
- file_batch.append(f)
- record_batch.append(r)
- keywords = self._get_keywords(f, r)
- keywords_batch_values.extend(keywords)
- keywords_batch_indices.extend(
- [[batch_index, i] for i in range(len(keywords))])
- batch_index += 1
- keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
- if len(file_batch) == batch_size:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
- ]
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- if file_batch:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
- ]
-
- def _verify_records(self,
- sess,
- batch_size,
- file_index=None,
- num_epochs=1,
- interleave_cycle_length=1):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
- for i in range(len(expected_batch)):
- self.assertAllEqual(expected_batch[i], actual_batch[i])
+class ReadBatchFeaturesTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
def testRead(self):
for batch_size in [1, 2]:
@@ -444,33 +42,33 @@ class ReadBatchFeaturesTest(test.TestCase):
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from file 0.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from file 1.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
# Basic test: read from both files.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
- batch_size=batch_size)
- self._verify_records(sess, batch_size, num_epochs=num_epochs)
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
with self.assertRaises(errors.OutOfRangeError):
self._next_actual_batch(sess)
@@ -504,18 +102,18 @@ class ReadBatchFeaturesTest(test.TestCase):
# Test that shuffling with same seed produces the same result.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- outputs1 = self._read_batch_features(
+ outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
- outputs2 = self._read_batch_features(
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
+ shuffle_seed=5).make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
batch2 = self._run_actual_batch(outputs2, sess)
@@ -525,18 +123,18 @@ class ReadBatchFeaturesTest(test.TestCase):
# Test that shuffling with different seeds produces a different order.
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- outputs1 = self._read_batch_features(
+ outputs1 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=5)
- outputs2 = self._read_batch_features(
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
shuffle=True,
- shuffle_seed=15)
+ shuffle_seed=15).make_one_shot_iterator().get_next()
all_equal = True
for _ in range(total_records // batch_size):
batch1 = self._run_actual_batch(outputs1, sess)
@@ -552,13 +150,14 @@ class ReadBatchFeaturesTest(test.TestCase):
for parser_num_threads in [2, 4]:
with ops.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads)
- self._verify_records(
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
sess,
batch_size,
num_epochs=num_epochs,
@@ -571,11 +170,11 @@ class ReadBatchFeaturesTest(test.TestCase):
for num_epochs in [1, 10]:
with ops.Graph().as_default():
# Basic test: read from file 0.
- self.outputs = self._read_batch_features(
+ self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
num_epochs=num_epochs,
batch_size=batch_size,
- drop_final_batch=True)
+ drop_final_batch=True).make_one_shot_iterator().get_next()
for _, tensor in self.outputs.items():
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
@@ -1069,7 +668,30 @@ class MakeCsvDatasetTest(test.TestCase):
self.assertFalse(all_equal)
-class MakeTFRecordDatasetTest(TFRecordDatasetTestBase):
+class MakeTFRecordDatasetTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
def _next_expected_batch(self,
file_indices,
@@ -1085,8 +707,8 @@ class MakeTFRecordDatasetTest(TFRecordDatasetTestBase):
yield j, i
def _next_record_interleaved(file_indices, cycle_length):
- return _interleave([_next_record([i]) for i in file_indices],
- cycle_length)
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
record_batch = []
batch_index = 0
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
new file mode 100644
index 0000000000..e63bc4c720
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -0,0 +1,331 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing reader datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class FixedLengthRecordDatasetTestBase(test.TestCase):
+ """Base class for setting up and testing FixedLengthRecordDataset."""
+
+ def setUp(self):
+ super(FixedLengthRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self._header_bytes = 5
+ self._record_bytes = 3
+ self._footer_bytes = 2
+
+ def _record(self, f, r):
+ return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+ with open(fn, "wb") as f:
+ f.write(b"H" * self._header_bytes)
+ for j in range(self._num_records):
+ f.write(self._record(i, j))
+ f.write(b"F" * self._footer_bytes)
+ return filenames
+
+
+class ReadBatchFeaturesTestBase(test.TestCase):
+ """Base class for setting up and testing `make_batched_feature_dataset`."""
+
+ def setUp(self):
+ super(ReadBatchFeaturesTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self.test_filenames = self._createFiles()
+
+ def make_batch_feature(self,
+ filenames,
+ num_epochs,
+ batch_size,
+ reader_num_threads=1,
+ parser_num_threads=1,
+ shuffle=False,
+ shuffle_seed=None,
+ drop_final_batch=False):
+ self.filenames = filenames
+ self.num_epochs = num_epochs
+ self.batch_size = batch_size
+
+ return readers.make_batched_features_dataset(
+ file_pattern=self.filenames,
+ batch_size=self.batch_size,
+ features={
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ reader=core_readers.TFRecordDataset,
+ num_epochs=self.num_epochs,
+ shuffle=shuffle,
+ shuffle_seed=shuffle_seed,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch)
+
+ def _record(self, f, r):
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "file":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[f])),
+ "record":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[r])),
+ "keywords":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=self._get_keywords(f, r)))
+ }))
+ return example.SerializeToString()
+
+ def _get_keywords(self, f, r):
+ num_keywords = 1 + (f + r) % 2
+ keywords = []
+ for index in range(num_keywords):
+ keywords.append(compat.as_bytes("keyword%d" % index))
+ return keywords
+
+ def _sum_keywords(self, num_files):
+ sum_keywords = 0
+ for i in range(num_files):
+ for j in range(self._num_records):
+ sum_keywords += 1 + (i + j) % 2
+ return sum_keywords
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
+
+ def _run_actual_batch(self, outputs, sess):
+ file_op = outputs["file"]
+ keywords_indices_op = outputs["keywords"].indices
+ keywords_values_op = outputs["keywords"].values
+ keywords_dense_shape_op = outputs["keywords"].dense_shape
+ record_op = outputs["record"]
+ return sess.run([
+ file_op, keywords_indices_op, keywords_values_op,
+ keywords_dense_shape_op, record_op
+ ])
+
+ def _next_actual_batch(self, sess):
+ return self._run_actual_batch(self.outputs, sess)
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=1):
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for record in next_records:
+ f = record[0]
+ r = record[1]
+ file_batch.append(f)
+ record_batch.append(r)
+ keywords = self._get_keywords(f, r)
+ keywords_batch_values.extend(keywords)
+ keywords_batch_indices.extend(
+ [[batch_index, i] for i in range(len(keywords))])
+ batch_index += 1
+ keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
+ if len(file_batch) == batch_size:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [batch_size, keywords_batch_max_len], record_batch
+ ]
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ if file_batch:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [len(file_batch), keywords_batch_max_len], record_batch
+ ]
+
+ def verify_records(self,
+ sess,
+ batch_size,
+ file_index=None,
+ num_epochs=1,
+ interleave_cycle_length=1):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length):
+ actual_batch = self._next_actual_batch(sess)
+ for i in range(len(expected_batch)):
+ self.assertAllEqual(expected_batch[i], actual_batch[i])
+
+
+class TextLineDatasetTestBase(test.TestCase):
+ """Base class for setting up and testing TextLineDataset."""
+
+ def _lineText(self, f, l):
+ return compat.as_bytes("%d: %d" % (f, l))
+
+ def _createFiles(self,
+ num_files,
+ num_lines,
+ crlf=False,
+ compression_type=None):
+ filenames = []
+ for i in range(num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ contents = []
+ for j in range(num_lines):
+ contents.append(self._lineText(i, j))
+ # Always include a newline after the record unless it is
+ # at the end of the file, in which case we include it
+ if j + 1 != num_lines or i == 0:
+ contents.append(b"\r\n" if crlf else b"\n")
+ contents = b"".join(contents)
+
+ if not compression_type:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+
+ return filenames
+
+
+class TFRecordDatasetTestBase(test.TestCase):
+ """Base class for setting up and testing TFRecordDataset."""
+
+ def setUp(self):
+ super(TFRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ self.test_filenames = self._createFiles()
+
+ self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
+ self.num_epochs = array_ops.placeholder_with_default(
+ constant_op.constant(1, dtypes.int64), shape=[])
+ self.compression_type = array_ops.placeholder_with_default("", shape=[])
+ self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = core_readers.TFRecordDataset(
+ self.filenames, self.compression_type).repeat(self.num_epochs)
+ batch_dataset = repeat_dataset.batch(self.batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ self.init_op = iterator.make_initializer(repeat_dataset)
+ self.init_batch_op = iterator.make_initializer(batch_dataset)
+ self.get_next = iterator.get_next()
+
+ def _record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index bdc003a8a5..c5cfddb72b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
import time
+
from absl.testing import parameterized
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.data.python.ops import resampling
from tensorflow.python.data.ops import dataset_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index eb2ceff893..42cada0b97 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -21,7 +21,6 @@ import itertools
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
@@ -64,7 +63,7 @@ class ScanDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFibonacci(self):
iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
@@ -168,18 +167,5 @@ class ScanDatasetTest(test.TestCase):
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-class ScanDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_elements):
- return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
- scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
-
- def testScanCore(self):
- num_output = 5
- self.run_core_tests(lambda: self._build_dataset(num_output),
- lambda: self._build_dataset(2), num_output)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
new file mode 100644
index 0000000000..686788522a
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -0,0 +1,526 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "dataset_serialization_test_base",
+ srcs = [
+ "dataset_serialization_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "cache_dataset_serialization_test",
+ size = "small",
+ srcs = ["cache_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "concatenate_dataset_serialization_test",
+ size = "small",
+ srcs = ["concatenate_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_serialization_test",
+ size = "medium",
+ srcs = ["dataset_constructor_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_serialization_test",
+ size = "medium",
+ srcs = ["filter_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fixed_length_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["fixed_length_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "flat_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["flat_map_dataset_serialization_test.py"],
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "group_by_reducer_serialization_test",
+ size = "medium",
+ srcs = ["group_by_reducer_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:grouping",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "group_by_window_serialization_test",
+ size = "medium",
+ srcs = ["group_by_window_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:grouping",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "ignore_errors_serialization_test",
+ size = "small",
+ srcs = ["ignore_errors_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:error_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_and_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_and_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_serialization_test",
+ size = "small",
+ srcs = ["optimize_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "padded_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["padded_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "prefetch_dataset_serialization_test",
+ size = "small",
+ srcs = ["prefetch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "range_dataset_serialization_test",
+ size = "small",
+ srcs = ["range_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sample_from_datasets_serialization_test",
+ size = "medium",
+ srcs = ["sample_from_datasets_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_serialization_test",
+ size = "small",
+ srcs = ["scan_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:scan_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sequence_dataset_serialization_test",
+ size = "medium",
+ srcs = ["sequence_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "serialization_integration_test",
+ size = "small",
+ srcs = ["serialization_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_and_repeat_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:shuffle_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_serialization_test",
+ size = "small",
+ srcs = ["sql_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_serialization_test",
+ size = "medium",
+ srcs = ["stats_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "textline_dataset_serialization_test",
+ size = "medium",
+ srcs = ["textline_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "tf_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["tf_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "unbatch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["unbatch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_serialization_test",
+ size = "small",
+ srcs = ["unique_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/ops:unique",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "zip_dataset_serialization_test",
+ size = "small",
+ srcs = ["zip_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..af87d8b608
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the BatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class BatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len // batch_size
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+ def _build_dataset_dense_to_sparse(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+
+ def testDenseToSparseBatchDatasetCore(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
+
+ num_outputs = len(components) // 4
+ self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
+ lambda: self._build_dataset_dense_to_sparse(diff_comp),
+ num_outputs)
+
+ def _sparse(self, i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ def _build_dataset_sparse(self, batch_size=5):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
+
+ def testSparseCore(self):
+ self.run_core_tests(self._build_dataset_sparse,
+ lambda: self._build_dataset_sparse(2), 2)
+
+ def _build_dataset_nested_sparse(self):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
+
+ def testNestedSparseCore(self):
+ self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
new file mode 100644
index 0000000000..a0a1100893
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CacheDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class CacheDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self.range_size = 10
+ self.num_repeats = 3
+ self.num_outputs = self.range_size * self.num_repeats
+ self.cache_file_prefix = 'test'
+
+ def ds_fn(self):
+ return dataset_ops.Dataset.range(self.range_size).cache(
+ os.path.join(self.get_temp_dir(),
+ self.cache_file_prefix)).repeat(self.num_repeats)
+
+ def expected_outputs(self):
+ return list(range(self.range_size)) * self.num_repeats
+
+ def testCheckpointBeforeOneEpoch(self):
+ # Generate 5 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ def testCheckpointBeforeOneEpochThenRunFewSteps(self):
+ # Generate 8 entries from iterator but save checkpoint after producing
+ # 5.
+ outputs = self.gen_outputs(
+ self.ds_fn, [5],
+ 8,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, range(8))
+
+ # Restoring from checkpoint and running GetNext should return a
+ # `AlreadExistsError` now because the lockfile already exists.
+ with self.assertRaises(errors.AlreadyExistsError):
+ self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+
+ def testCheckpointAfterOneEpoch(self):
+ # Generate 15 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ def testCheckpointAfterOneEpochThenRunFewSteps(self):
+ # Generate 18 entries from iterator but save checkpoint after producing
+ # 15.
+ outputs = self.gen_outputs(
+ self.ds_fn, [15],
+ 18,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
+
+ outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ def testCheckpointBeforeOneEpochButRunCompleteEpoch(self):
+ # Generate 13 entries from iterator but save checkpoint after producing
+ # 5.
+ outputs = self.gen_outputs(
+ self.ds_fn, [5],
+ 13,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
+
+ # Since we ran for more than one epoch, the cache was completely written.
+ # The ckpt was saved when the iterator was in cache-write mode. Test that
+ # the iterator falls back to read mode after restoring if the cache has
+ # been completely written.
+
+ outputs = list(range(5)) + self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ def testCheckpointUnusedWriterIterator(self):
+ # Checkpoint before get_next is called even once.
+ outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, [])
+
+ outputs = self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ def testCheckpointUnusedMidwayWriterIterator(self):
+ # Produce 5 elements and checkpoint.
+ outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint, then produce no elements and checkpoint.
+ outputs.extend(
+ self.gen_outputs(
+ self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce rest of the elements.
+ outputs.extend(
+ self.gen_outputs(
+ self.ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ def testUnusedCheckpointError(self):
+ # Produce 5 elements and save ckpt.
+ outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Since the complete cache has not been written, a new iterator which does
+ # not restore the checkpoint will throw an error since there is a partial
+ # cache shard.
+ with self.assertRaises(errors.AlreadyExistsError):
+ outputs = self.gen_outputs(
+ self.ds_fn, [], self.num_outputs, verify_exhausted=False)
+
+ def testIgnoreCheckpointIfCacheWritten(self):
+ # Produce 15 elements and save ckpt. This will write the complete cache.
+ outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Build the iterator again but do not restore from ckpt. Since the cache
+ # has already been written we should be able to use it.
+ outputs = self.gen_outputs(
+ self.ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
index 17f2980157..96f13d75a3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the ConcatenateDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
new file mode 100644
index 0000000000..2139b5c33d
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the dataset constructors serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class FromTensorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_dataset(self, variable_array):
+ components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
+
+ return dataset_ops.Dataset.from_tensors(components)
+
+ def testFromTensorsCore(self):
+ # Equal length components
+ arr = np.array(1)
+ num_outputs = 1
+ diff_arr = np.array(2)
+ self.run_core_tests(lambda: self._build_tensor_dataset(arr),
+ lambda: self._build_tensor_dataset(diff_arr),
+ num_outputs)
+
+
+class FromTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_slices_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components)
+
+ def testFromTensorSlicesCore(self):
+ # Equal length components
+ components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0]))
+
+ diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[5], [6], [7], [8]]), 22),
+ np.array([1.0, 2.0, 3.0, 4.0]))
+
+ dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
+
+ self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
+ lambda: self._build_tensor_slices_dataset(diff_comp), 4)
+ self.run_core_tests(
+ lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
+
+
+class FromSparseTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_sparse_tensor_slice_dataset(self, slices):
+ indices = np.array(
+ [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
+ dtype=np.int64)
+ values = np.array([val for s in slices for val in s], dtype=np.float64)
+ dense_shape = np.array(
+ [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
+ sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
+ return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
+
+ def testFromSparseTensorSlicesCore(self):
+ slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
+ diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
+
+ self.run_core_tests(
+ lambda: self._build_sparse_tensor_slice_dataset(slices),
+ lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
+ 9,
+ sparse_tensors=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 78ecce8f7d..393f08850b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase):
ckpt_saved=False,
init_before_restore=False,
sparse_tensors=False,
- verify_exhausted=True):
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
"""Generates elements from input dataset while stopping at break points.
Produces `num_outputs` outputs and saves the state of the iterator in the
@@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase):
sparse_tensors: Whether dataset is built from SparseTensor(s).
verify_exhausted: Whether to verify that the iterator has been exhausted
after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+ outputs. If False, checkpoints are saved each break point but not at the
+ end. Note that checkpoints overwrite each other so there is always only
+ a single checkpoint available. Defaults to True.
Returns:
A list of `num_outputs` items.
@@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase):
if i == len(break_points) and verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next_op)
- self._save(sess, saver)
- ckpt_saved = True
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
return outputs
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
index b572d6ed77..7c170078a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the FilterDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
@@ -35,7 +35,7 @@ class FilterDatasetSerializationTest(
def testFilterCore(self):
div = 3
- num_outputs = np.sum([x % 3 is not 2 for x in range(100)])
+ num_outputs = np.sum([x % 3 != 2 for x in range(100)])
self.run_core_tests(lambda: self._build_filter_range_graph(div),
lambda: self._build_filter_range_graph(div * 2),
num_outputs)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..34392d88d4
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FixedLengthRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class FixedLengthRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, num_epochs, compression_type=None):
+ filenames = self._createFiles()
+ return core_readers.FixedLengthRecordDataset(
+ filenames, self._record_bytes, self._header_bytes,
+ self._footer_bytes).repeat(num_epochs)
+
+ def testFixedLengthRecordCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
index f3feecef32..16051ffd3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the FlatMapDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
new file mode 100644
index 0000000000..571e0899bb
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByReducer serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ return dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+ def testCoreGroupByReducer(self):
+ components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 5,
+ verify_exhausted=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
new file mode 100644
index 0000000000..f86af4084e
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByWindow serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByWindowSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
+
+ def testCoreGroupByWindow(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 12,
+ verify_exhausted=False)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
new file mode 100644
index 0000000000..65ae9923b8
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the IgnoreErrors input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class IgnoreErrorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors())
+
+ def testIgnoreErrorsCore(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+ diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
+ num_outputs = 4
+ self.run_core_tests(lambda: self._build_ds(components),
+ lambda: self._build_ds(diff_components), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..ac3892fe81
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the InterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class InterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ repeat_count = 2
+ return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ repeat_count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length)
+
+ def testSerializationCore(self):
+ input_values = np.array([4, 5, 6], dtype=np.int64)
+ num_outputs = np.sum(input_values) * 2
+ # cycle_length > 1, block_length > 1
+ cycle_length = 2
+ block_length = 3
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length, block_length),
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length * 2, block_length * 1),
+ num_outputs)
+ # cycle_length = 1
+ cycle_length = 1
+ block_length = 3
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length, block_length),
+ None, num_outputs)
+ # block_length = 1
+ cycle_length = 2
+ block_length = 1
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length, block_length),
+ None, num_outputs)
+ # pylint: enable=g-long-lambda
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
+ _interleave_fn, cycle_length=1)
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..c9cd211328
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,88 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testNumParallelBatches(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_batches = 2
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_batches=num_parallel_batches,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
new file mode 100644
index 0000000000..ab783e5cce
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class MapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 14
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(self._num_epochs))
+
+ def testSaveRestoreCore(self):
+ self.run_core_tests(
+ self._build_ds,
+ lambda: self._build_ds(multiplier=15.0),
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(_map_fn)
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1)))
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testSparseCore(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def _build_ds(num_outputs):
+ return dataset_ops.Dataset.range(num_outputs).map(_sparse)
+
+ num_outputs = 10
+ self.run_core_tests(lambda: _build_ds(num_outputs),
+ lambda: _build_ds(int(num_outputs / 2)), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
new file mode 100644
index 0000000000..d5c03495e3
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the OptimizeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+
+ def build_dataset(num_elements, batch_size):
+ return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
+ batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
+
+ self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..9ac42a461a
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the PaddedBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class PaddedBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testPaddedBatch(self):
+
+ def build_dataset(seq_lens):
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=[-1])
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+ def testPaddedBatchNonDefaultPadding(self):
+
+ def build_dataset(seq_lens):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ padded_shape = [-1]
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ fill_tuple).padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>"))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..1f8a584df9
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelInterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class ParallelInterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self.input_values = np.array([4, 5, 6], dtype=np.int64)
+ self.num_repeats = 2
+ self.num_outputs = np.sum(self.input_values) * 2
+
+ def _build_ds(self, cycle_length, block_length, sloppy=False):
+ return (dataset_ops.Dataset.from_tensor_slices(
+ self.input_values).repeat(self.num_repeats).apply(
+ interleave_ops.parallel_interleave(
+ lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
+ cycle_length, block_length, sloppy)))
+
+ def testSerializationCore(self):
+ # cycle_length > 1, block_length > 1
+ cycle_length = 2
+ block_length = 3
+ self.run_core_tests(
+ lambda: self._build_ds(cycle_length, block_length),
+ lambda: self._build_ds(cycle_length * 2, block_length * 1),
+ self.num_outputs)
+ # cycle_length = 1
+ cycle_length = 1
+ block_length = 3
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+ # block_length = 1
+ cycle_length = 2
+ block_length = 1
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+
+ def testSerializationWithSloppy(self):
+ break_points = self.gen_break_points(self.num_outputs, 10)
+ expected_outputs = np.repeat(
+ np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
+ self.num_repeats).tolist()
+
+ def run_test(cycle_length, block_length):
+ actual = self.gen_outputs(
+ lambda: self._build_ds(cycle_length, block_length, True),
+ break_points, self.num_outputs)
+ self.assertSequenceEqual(sorted(actual), expected_outputs)
+
+ # cycle_length > 1, block_length > 1
+ run_test(2, 3)
+ # cycle_length = 1
+ run_test(1, 3)
+ # block_length = 1
+ run_test(2, 1)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).apply(
+ interleave_ops.parallel_interleave(_interleave_fn, 1))
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
new file mode 100644
index 0000000000..3fb7605be1
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelMapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class ParallelMapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 1
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
+
+ def _build_ds_with_prefetch(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
+
+ def testSaveRestoreCore(self):
+ for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
+ self.run_core_tests(
+ ds_fn,
+ lambda: ds_fn(multiplier=15.0),
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(
+ _map_fn, num_parallel_calls=2).prefetch(2)
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1),
+ num_parallel_calls=2).prefetch(2))
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
index 3d120a3071..c802402461 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the PrefetchDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
new file mode 100644
index 0000000000..e4f5b6cf5d
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the RangeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RangeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _iterator_checkpoint_prefix_local(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix_local(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix_local()),
+ dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def testSaveRestore(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Saving and restoring in same session.
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def _build_range_dataset(self, start, stop):
+ return dataset_ops.Dataset.range(start, stop)
+
+ def testRangeCore(self):
+ start = 2
+ stop = 10
+ stop_1 = 8
+ self.run_core_tests(lambda: self._build_range_dataset(start, stop),
+ lambda: self._build_range_dataset(start, stop_1),
+ stop - start)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
new file mode 100644
index 0000000000..fdb35ea624
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SampleFromDatasets serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class SampleFromDatasetsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, probs, num_samples):
+ dataset = interleave_ops.sample_from_datasets(
+ [
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(len(probs))
+ ],
+ probs,
+ seed=1813)
+ return dataset.take(num_samples)
+
+ def testSerializationCore(self):
+ self.run_core_tests(
+ lambda: self._build_dataset([0.5, 0.5], 100),
+ lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
new file mode 100644
index 0000000000..af9ef48c0f
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ScanDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ScanDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_elements):
+ return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
+
+ def testScanCore(self):
+ num_output = 5
+ self.run_core_tests(lambda: self._build_dataset(num_output),
+ lambda: self._build_dataset(2), num_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
index d0cb203a3a..2afebca0f5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the sequence datasets serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
-class SequenceDatasetSerializationTest(
+class SkipDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_skip_dataset(self, count):
@@ -52,6 +52,10 @@ class SequenceDatasetSerializationTest(
'Shape must be rank 0 but is rank 1'):
self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
+
+class TakeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
def _build_take_dataset(self, count):
components = (np.arange(10),)
return dataset_ops.Dataset.from_tensor_slices(components).take(count)
@@ -79,6 +83,10 @@ class SequenceDatasetSerializationTest(
'Shape must be rank 0 but is rank 1'):
self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
+
+class RepeatDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
def _build_repeat_dataset(self, count, take_count=3):
components = (np.arange(10),)
return dataset_ops.Dataset.from_tensor_slices(components).take(
@@ -117,5 +125,5 @@ class SequenceDatasetSerializationTest(
None, 0)
-if __name__ == "__main__":
+if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
index 0a6b74dc3e..992d996a48 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Integration test for input pipeline serialization."""
+"""Integration test for dataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
-class MultipleInputPipelinesTest(test.TestCase):
+class SerializationIntegrationTest(test.TestCase):
def _build_input_pipeline(self, name, num_outputs):
with ops.name_scope(name):
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
new file mode 100644
index 0000000000..f199ec835e
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleAndRepeatDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ShuffleAndRepeatSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, seed):
+ return dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
+
+ def testCore(self):
+ self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
+ 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
new file mode 100644
index 0000000000..d46c762aaa
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class ShuffleDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_shuffle_dataset(
+ self,
+ range_limit=10,
+ num_repeats=5,
+ buffer_size=5,
+ seed=None,
+ reshuffle_each_iteration=None,
+ ):
+ return dataset_ops.Dataset.range(range_limit).shuffle(
+ buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
+
+ def testShuffleCore(self):
+
+ seed = 55
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ # pylint: disable=cell-var-from-loop
+ # pylint: disable=g-long-lambda
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+ self.run_core_tests(
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=10,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ num_outputs)
+ # pylint: enable=cell-var-from-loop
+ # pylint: enable=g-long-lambda
+
+ def testNonDeterministicSeeding(self):
+
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ # We checkpoint the initial state of the Dataset so that we can restore
+ # the seeds in the next run. Since the seeding is non-deterministic
+ # the dataset gets initialized with different seeds each time.
+ expected = self.gen_outputs(
+ ds_fn,
+ break_points=[0],
+ num_outputs=num_outputs,
+ ckpt_saved=False,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points=self.gen_break_points(num_outputs),
+ num_outputs=num_outputs,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.match(expected, actual)
+
+ def testMultipleIterators(self):
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ with ops.Graph().as_default() as g:
+ ds = ds_fn()
+ iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
+ get_next_ops = [it.get_next() for it in iterators]
+ saveables = [
+ contrib_iterator_ops.make_saveable_from_iterator(it)
+ for it in iterators
+ ]
+ for saveable in saveables:
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ saver = saver_lib.Saver(allow_empty=True)
+ with self.test_session(graph=g) as sess:
+ self._save(sess, saver)
+ expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self._restore(saver, sess)
+ actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self.match(expected, actual)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
new file mode 100644
index 0000000000..93b26ed58a
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SqlDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetSerializationTest(
+ sql_dataset_op_test_base.SqlDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_repeats):
+ data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
+ "first_name DESC")
+ output_types = (dtypes.string, dtypes.string, dtypes.string)
+ return readers.SqlDataset(driver_name, data_source_name, query,
+ output_types).repeat(num_repeats)
+
+ def testSQLSaveable(self):
+ num_repeats = 4
+ num_outputs = num_repeats * 2
+ self.run_core_tests(lambda: self._build_dataset(num_repeats),
+ lambda: self._build_dataset(num_repeats // 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
new file mode 100644
index 0000000000..14cd3e9c4a
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the StatsDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
+# transformation `stats_ops.set_stats_aggregator`, since we don't support
+# serializing StatsAggregator yet.
+class StatsDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset_bytes_stats(self, num_elements):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
+ stats_ops.bytes_produced_stats("bytes_produced"))
+
+ def test_bytes_produced_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.bytes_produced_stats(["bytes_produced"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testBytesStatsDatasetSaveableCore(self):
+ num_outputs = 100
+ self.run_core_tests(
+ lambda: self._build_dataset_bytes_stats(num_outputs),
+ lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+
+ def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag))
+
+ def _build_dataset_multiple_tags(self,
+ num_elements,
+ tag1="record_latency",
+ tag2="record_latency_2"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
+
+ def test_latency_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats(["record_latency", "record_latency_2"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testLatencyStatsDatasetSaveableCore(self):
+ num_outputs = 100
+
+ self.run_core_tests(
+ lambda: self._build_dataset_latency_stats(num_outputs),
+ lambda: self._build_dataset_latency_stats(num_outputs // 10),
+ num_outputs)
+
+ self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
+ None, num_outputs)
+
+ tag1 = "record_latency"
+ tag2 = "record_latency"
+ self.run_core_tests(
+ lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
+ None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
new file mode 100644
index 0000000000..2483787f44
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TextLineDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TextLineDatasetSerializationTest(
+ reader_dataset_ops_test_base.TextLineDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, test_filenames, compression_type=None):
+ return core_readers.TextLineDataset(
+ test_filenames, compression_type=compression_type, buffer_size=10)
+
+ def testTextLineCore(self):
+ compression_types = [None, "GZIP", "ZLIB"]
+ num_files = 5
+ lines_per_file = 5
+ num_outputs = num_files * lines_per_file
+ for compression_type in compression_types:
+ test_filenames = self._createFiles(
+ num_files,
+ lines_per_file,
+ crlf=True,
+ compression_type=compression_type)
+ # pylint: disable=cell-var-from-loop
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(test_filenames, compression_type),
+ lambda: self._build_iterator_graph(test_filenames), num_outputs)
+ # pylint: enable=cell-var-from-loop
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..55a6257a27
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TFRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TFRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self,
+ num_epochs,
+ batch_size=1,
+ compression_type=None,
+ buffer_size=None):
+ filenames = self._createFiles()
+ if compression_type == "ZLIB":
+ zlib_files = []
+ for i, fn in enumerate(filenames):
+ with open(fn, "rb") as f:
+ cdata = zlib.compress(f.read())
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ zlib_files.append(zfn)
+ filenames = zlib_files
+
+ elif compression_type == "GZIP":
+ gzip_files = []
+ for i, fn in enumerate(self.test_filenames):
+ with open(fn, "rb") as f:
+ gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
+ with gzip.GzipFile(gzfn, "wb") as gzf:
+ gzf.write(f.read())
+ gzip_files.append(gzfn)
+ filenames = gzip_files
+
+ return core_readers.TFRecordDataset(
+ filenames, compression_type,
+ buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
+
+ def testTFRecordWithoutBufferCore(self):
+ num_epochs = 5
+ batch_size = num_epochs
+ num_outputs = num_epochs * self._num_files * self._num_records // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, batch_size,
+ buffer_size=0),
+ lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
+ num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
+ num_outputs * batch_size)
+ # pylint: enable=g-long-lambda
+
+ def testTFRecordWithBufferCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+ def testTFRecordWithCompressionCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
new file mode 100644
index 0000000000..b2a5a8a20d
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UnbatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UnbatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(
+ batch_size).apply(batching.unbatch())
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
new file mode 100644
index 0000000000..22f15b8846
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UniqueDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UniqueDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testUnique(self):
+
+ def build_dataset(num_elements, unique_elem_range):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: x % unique_elem_range).apply(unique.unique())
+
+ self.run_core_tests(lambda: build_dataset(200, 100),
+ lambda: build_dataset(40, 100), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
index e39fa957f0..340a6ff72e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the ZipDataset serialization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index bcc644c097..3c11d7a97f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
@@ -27,60 +26,25 @@ from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_shuffle_dataset(
- self,
- range_limit=10,
- num_repeats=5,
- buffer_size=5,
- seed=None,
- reshuffle_each_iteration=None,
- ):
- return dataset_ops.Dataset.range(range_limit).shuffle(
- buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
-
- def testShuffleCore(self):
-
- seed = 55
- range_limit = 10
- num_repeats = 5
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 8, 10, 25, 50]
- reshuffle_each_iteration = False
- # pylint: disable=cell-var-from-loop
- # pylint: disable=g-long-lambda
- for buffer_size in buffer_sizes:
- self.run_core_tests(
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration),
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=10,
- reshuffle_each_iteration=reshuffle_each_iteration),
- num_outputs)
- # pylint: enable=cell-var-from-loop
- # pylint: enable=g-long-lambda
-
-
-class ShuffleAndRepeatTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+class ShuffleAndRepeatTest(test.TestCase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
+ def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
+ get_next = ds_fn().make_one_shot_iterator().get_next()
+ outputs = []
+ with self.test_session() as sess:
+ for _ in range(num_outputs):
+ outputs.append(sess.run(get_next))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ return outputs
+
def testCorrectOutput(self):
- output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertSequenceEqual(
sorted(output), sorted(
np.array([range(20) for _ in range(5)]).flatten()))
@@ -89,53 +53,53 @@ class ShuffleAndRepeatTest(
def testReshuffling(self):
# Check that the output orders of different epochs are indeed different.
- output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
for i in range(4):
epoch1 = output[i * 20:(i + 1) * 20]
epoch2 = output[(i + 1) * 20:(i + 2) * 20]
self.assertNotEqual(epoch1, epoch2)
def testSameOrderForSameSeeds(self):
- output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
- output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
self.assertEqual(output1, output2)
def testDifferentOrderForDifferentSeeds(self):
- output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
- output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100)
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
def testCountNone(self):
- output1 = self.gen_outputs(
- lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False)
- output2 = self.gen_outputs(
- lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False)
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=None), 100, verify_exhausted=False)
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
def testCountMinusOne(self):
- output1 = self.gen_outputs(
- lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False)
- output2 = self.gen_outputs(
- lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False)
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False)
self.assertNotEqual(output1, output2)
self.assertEqual(sorted(output1), sorted(output2))
def testInfiniteOutputs(self):
# Asserting the iterator is exhausted after producing 100 items should fail.
with self.assertRaises(AssertionError):
- self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100)
+ self._gen_outputs(lambda: self._build_ds(10, count=None), 100)
with self.assertRaises(AssertionError):
- self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100)
+ self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
def testInfiniteEmpty(self):
with self.assertRaises(errors.OutOfRangeError):
- self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
- [], 100)
+ self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
+ 100)
with self.assertRaises(errors.OutOfRangeError):
- self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [],
- 100)
+ self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
+ 100)
def testLargeBufferSize(self):
with ops.Graph().as_default() as g:
@@ -146,17 +110,5 @@ class ShuffleAndRepeatTest(
sess.run(get_next_op)
-class ShuffleAndRepeatSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, seed):
- return dataset_ops.Dataset.range(20).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
-
- def testCore(self):
- self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
- 100)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 33c48e20be..5590a4bf78 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -58,6 +58,7 @@ class SlideDatasetTest(test.TestCase):
[t.shape.as_list() for t in get_next])
with self.test_session() as sess:
+ # stride < window_size.
# Slide over a finite input, where the window_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 7})
@@ -71,11 +72,9 @@ class SlideDatasetTest(test.TestCase):
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
-
# Slide over a finite input, where the window_size does not
# divide the total number of elements.
sess.run(init_op, feed_dict={count: 20, window_size: 17, stride: 9})
-
num_batches = (20 * 7 - 17) // 9 + 1
for i in range(num_batches):
result = sess.run(get_next)
@@ -86,6 +85,41 @@ class SlideDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ # stride == window_size.
+ sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 14})
+ num_batches = 20 * 7 // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # stride > window_size.
+ sess.run(init_op, feed_dict={count: 20, window_size: 10, stride: 14})
+ num_batches = 20 * 7 // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(10):
+ self.assertAllEqual(component[(i*14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ # Drop the last batch which is smaller than window_size.
+ sess.run(init_op, feed_dict={count: 20, window_size: 14, stride: 19})
+ num_batches = (20 * 7 - 7) // 19 # = 19 * 7 // 19
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i*19 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
# Slide over a finite input, which is less than window_size,
# should fail straight away.
sess.run(init_op, feed_dict={count: 1, window_size: 10, stride: 4})
@@ -108,10 +142,6 @@ class SlideDatasetTest(test.TestCase):
# Invalid stride should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 0})
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 3})
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, window_size: 3, stride: 5})
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 4148addf28..2c2cfbebff 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -18,83 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-
-import sqlite3
-
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
-
- def _createSqlDataset(self, output_types, num_repeats=1):
- dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
- self.query, output_types).repeat(num_repeats)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return init_op, get_next
-
- def setUp(self):
- self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- self.driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- self.query = array_ops.placeholder(dtypes.string, shape=[])
-
- conn = sqlite3.connect(self.data_source_name)
- c = conn.cursor()
- c.execute("DROP TABLE IF EXISTS students")
- c.execute("DROP TABLE IF EXISTS people")
- c.execute("DROP TABLE IF EXISTS townspeople")
- c.execute(
- "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
- "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
- "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
- "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
- "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
- "account_balance INTEGER, registration_complete INTEGER)")
- c.executemany(
- "INSERT INTO students (first_name, last_name, motto, school_id, "
- "favorite_nonsense_word, desk_number, income, favorite_number, "
- "favorite_big_number, favorite_negative_number, "
- "favorite_medium_sized_number, brownie_points, account_balance, "
- "registration_complete) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
- 9223372036854775807, -2, 32767, 0, 0, 1),
- ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
- -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
- c.execute(
- "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
- c.executemany(
- "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
- [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
- "California")])
- c.execute(
- "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
- "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
- "FLOAT, accolades FLOAT, triumphs FLOAT)")
- c.executemany(
- "INSERT INTO townspeople (first_name, last_name, victories, "
- "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
- [("George", "Washington", 20.00,
- 1331241.321342132321324589798264627463827647382647382643874,
- 9007199254740991.0),
- ("John", "Adams", -19.95,
- 1331241321342132321324589798264627463827647382647382643874.0,
- 9007199254740992.0)])
- conn.commit()
- conn.close()
-
-
-class SqlDatasetTest(SqlDatasetTestBase):
+class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
@@ -656,27 +586,5 @@ class SqlDatasetTest(SqlDatasetTestBase):
sess.run(get_next)
-class SqlDatasetSerializationTest(
- SqlDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_repeats):
- data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
- "first_name DESC")
- output_types = (dtypes.string, dtypes.string, dtypes.string)
- return readers.SqlDataset(driver_name, data_source_name, query,
- output_types).repeat(num_repeats)
-
- def testSQLSaveable(self):
- num_repeats = 4
- num_outputs = num_repeats * 2
- self.run_core_tests(lambda: self._build_dataset(num_repeats),
- lambda: self._build_dataset(num_repeats // 2),
- num_outputs)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
new file mode 100644
index 0000000000..1f5c725a92
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
@@ -0,0 +1,96 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing SqlDataset."""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import sqlite3
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetTestBase(test.TestCase):
+ """Base class for setting up and testing SqlDataset."""
+
+ def _createSqlDataset(self, output_types, num_repeats=1):
+ dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
+ self.query, output_types).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return init_op, get_next
+
+ def setUp(self):
+ self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ self.driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ self.query = array_ops.placeholder(dtypes.string, shape=[])
+
+ conn = sqlite3.connect(self.data_source_name)
+ c = conn.cursor()
+ c.execute("DROP TABLE IF EXISTS students")
+ c.execute("DROP TABLE IF EXISTS people")
+ c.execute("DROP TABLE IF EXISTS townspeople")
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
+ "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
+ "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
+ "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
+ "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
+ "account_balance INTEGER, registration_complete INTEGER)")
+ c.executemany(
+ "INSERT INTO students (first_name, last_name, motto, school_id, "
+ "favorite_nonsense_word, desk_number, income, favorite_number, "
+ "favorite_big_number, favorite_negative_number, "
+ "favorite_medium_sized_number, brownie_points, account_balance, "
+ "registration_complete) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
+ 9223372036854775807, -2, 32767, 0, 0, 1),
+ ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
+ -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
+ c.executemany(
+ "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
+ [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
+ "California")])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
+ "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
+ "FLOAT, accolades FLOAT, triumphs FLOAT)")
+ c.executemany(
+ "INSERT INTO townspeople (first_name, last_name, victories, "
+ "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
+ [("George", "Washington", 20.00,
+ 1331241.321342132321324589798264627463827647382647382643874,
+ 9007199254740991.0),
+ ("John", "Adams", -19.95,
+ 1331241321342132321324589798264627463827647382647382643874.0,
+ 9007199254740992.0)])
+ conn.commit()
+ conn.close()
+
+
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 5c74ed6ae7..b4945685c1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import dataset_ops
@@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class StatsDatasetTest(test.TestCase):
+class StatsDatasetTestBase(test.TestCase):
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
@@ -49,6 +49,9 @@ class StatsDatasetTest(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+class StatsDatasetTest(StatsDatasetTestBase):
+
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
@@ -193,68 +196,44 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-class StatsDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset_bytes_stats(self, num_elements):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced"))
-
- def test_bytes_produced_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.bytes_produced_stats(["bytes_produced"])),
- None, 100)
-
- def testBytesStatsDatasetSaveableCore(self):
- num_outputs = 100
- self.run_core_tests(
- lambda: self._build_dataset_bytes_stats(num_outputs),
- lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+class FeatureStatsDatasetTest(
+ StatsDatasetTestBase,
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
- def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag))
-
- def _build_dataset_multiple_tags(self,
- num_elements,
- tag1="record_latency",
- tag2="record_latency_2"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
-
- def test_latency_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats(["record_latency", "record_latency_2"])),
- None, 100)
-
- def testLatencyStatsDatasetSaveableCore(self):
- num_outputs = 100
-
- self.run_core_tests(
- lambda: self._build_dataset_latency_stats(num_outputs),
- lambda: self._build_dataset_latency_stats(num_outputs // 10),
- num_outputs)
-
- self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
- None, num_outputs)
+ def testFeaturesStats(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ batch_size = 2
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5,
+ drop_final_batch=True).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
- tag1 = "record_latency"
- tag2 = "record_latency"
- self.run_core_tests(
- lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
- None, num_outputs)
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(total_records // batch_size):
+ sess.run(next_element)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats:features", total_records)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats:feature-values", total_records)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats:features", total_records * 3)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats:feature-values",
+ self._sum_keywords(1) * num_epochs + 2 * total_records)
-# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
-# transformation `stats_ops.set_stats_aggregator`, since we don't support
-# serializing StatsAggregator yet.
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 9167cb3379..0486e2bce2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import threading
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
@@ -30,9 +31,11 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase):
+class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- def testNumThreads(self):
+ @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
+ (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
# Python creates a dummy thread object to represent the current
@@ -42,35 +45,35 @@ class OverrideThreadpoolDatasetTest(test.TestCase):
# identifier that maps one-to-one with the underlying OS thread.
return np.array(threading.current_thread().ident).astype(np.int64)
- for num_threads in [1, 2, 4, 8, 16]:
+ dataset = (
+ dataset_ops.Dataset.range(1000).map(
+ lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
+ num_parallel_calls=32).apply(unique.unique()))
- dataset = (
- dataset_ops.Dataset.range(1000).map(
- lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
- num_parallel_calls=32).apply(unique.unique()))
+ dataset = threadpool.override_threadpool(
+ dataset,
+ threadpool.PrivateThreadPool(
+ num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name="private_thread_pool_%d" % num_threads))
- dataset = threadpool.override_threadpool(
- dataset,
- threadpool.PrivateThreadPool(
- num_threads, display_name="private_thread_pool_%d" % num_threads))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- thread_ids = []
- try:
- while True:
- thread_ids.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- self.assertEqual(len(thread_ids), len(set(thread_ids)))
- self.assertGreater(len(thread_ids), 0)
- # NOTE(mrry): We don't control the thread pool scheduling, and
- # so cannot guarantee that all of the threads in the pool will
- # perform work.
- self.assertLessEqual(len(thread_ids), num_threads)
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ thread_ids = []
+ try:
+ while True:
+ thread_ids.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ self.assertEqual(len(thread_ids), len(set(thread_ids)))
+ self.assertGreater(len(thread_ids), 0)
+ # NOTE(mrry): We don't control the thread pool scheduling, and
+ # so cannot guarantee that all of the threads in the pool will
+ # perform work.
+ self.assertLessEqual(len(thread_ids), num_threads)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index 3c436f7a0b..d79a842e7a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import unique
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
@@ -79,18 +78,5 @@ class UniqueDatasetTest(test.TestCase):
])
-class UniqueSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testUnique(self):
-
- def build_dataset(num_elements, unique_elem_range):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: x % unique_elem_range).apply(unique.unique())
-
- self.run_core_tests(lambda: build_dataset(200, 100),
- lambda: build_dataset(40, 100), 100)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
new file mode 100644
index 0000000000..33d95d6754
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -0,0 +1,523 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+
+ def _structuredDataset(self, structure, shape, dtype):
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self._structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def _structuredElement(self, structure, shape, dtype):
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self._structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
+
+ def _assertEqual(self, xs, ys):
+ self.assertEqual(type(xs), type(ys))
+ if isinstance(xs, tuple) and isinstance(ys, tuple):
+ self.assertEqual(len(xs), len(ys))
+ for x, y in zip(xs, ys):
+ self._assertEqual(x, y)
+ elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray):
+ self.assertAllEqual(xs, ys)
+ else:
+ self.assertEqual(xs, ys)
+
+ @parameterized.parameters(
+ (None, np.int32([]), dtypes.bool),
+ (None, np.int32([]), dtypes.int32),
+ (None, np.int32([]), dtypes.float32),
+ (None, np.int32([]), dtypes.string),
+ (None, np.int32([2]), dtypes.int32),
+ (None, np.int32([2, 2]), dtypes.int32),
+ ((None, None, None), np.int32([]), dtypes.int32),
+ ((None, (None, None)), np.int32([]), dtypes.int32),
+ )
+ def testWindowDatasetFlatMap(self, structure, shape, dtype):
+ """Tests windowing by chaining it with flat map.
+
+ Args:
+ structure: the input structure
+ shape: the input shape
+ dtype: the input data type
+ """
+
+ def fn(*args):
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ return args[0]
+ return dataset_ops.Dataset.zip(
+ tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
+
+ dataset = self._structuredDataset(structure, shape, dtype).apply(
+ grouping.window_dataset(5)).flat_map(fn)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ expected = sess.run(self._structuredElement(structure, shape, dtype))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (None, np.int32([]), dtypes.bool),
+ (None, np.int32([]), dtypes.int32),
+ (None, np.int32([]), dtypes.float32),
+ (None, np.int32([]), dtypes.string),
+ (None, np.int32([2]), dtypes.int32),
+ (None, np.int32([2, 2]), dtypes.int32),
+ ((None, None, None), np.int32([]), dtypes.int32),
+ ((None, (None, None)), np.int32([]), dtypes.int32),
+ )
+ def testWindowDatasetBatchDense(self, structure, shape, dtype):
+ """Tests batching of dense tensor windows.
+
+ Args:
+ structure: the input structure
+ shape: the input shape
+ dtype: the input data type
+ """
+
+ def fn(*args):
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ return batching.batch_window(args[0])
+
+ return tuple([
+ fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
+ for arg in args
+ ])
+
+ dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
+ grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ expected = sess.run(
+ self._structuredElement(structure, np.concatenate(
+ ([5], shape), axis=0), dtype))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int32([]),),
+ (np.int32([1]),),
+ (np.int32([1, 2, 3]),),
+ )
+ def testWindowDatasetBatchDenseDynamicShape(self, shape):
+ """Tests batching of dynamically shaped dense tensor windows.
+
+ Args:
+ shape: the input shape
+ """
+
+ shape_t = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape_t)).repeat(5).apply(
+ grouping.window_dataset(5)).apply(
+ grouping._map_x_dataset(batching.batch_window))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op, {shape_t: shape})
+ expected = sess.run(
+ self._structuredElement(None, np.concatenate(([5], shape), axis=0),
+ dtypes.int32))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ def _make_dense_to_sparse_fn(self, is_scalar):
+
+ def dense_to_sparse_scalar(tensor):
+ indices = [[]]
+ values = array_ops.expand_dims(tensor, 0)
+ shape = []
+ return sparse_tensor.SparseTensorValue(indices, values, shape)
+
+ def dense_to_sparse_non_scalar(tensor):
+ indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool))
+ values = array_ops.gather_nd(tensor, indices)
+ shape = array_ops.shape(tensor, out_type=dtypes.int64)
+ return sparse_tensor.SparseTensorValue(indices, values, shape)
+
+ if is_scalar:
+ return dense_to_sparse_scalar
+ return dense_to_sparse_non_scalar
+
+ def _structuredSparseDataset(self, structure, shape, dtype):
+ dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ dense_to_sparse(array_ops.zeros(shape, dtype=dtype)))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self._structuredSparseDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def _structuredSparseElement(self, structure, shape, dtype):
+ dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
+ if structure is None:
+ return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
+ else:
+ return tuple([
+ self._structuredSparseElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
+
+ @parameterized.parameters(
+ (None, np.int32([]), dtypes.bool),
+ (None, np.int32([]), dtypes.int32),
+ (None, np.int32([]), dtypes.float32),
+ (None, np.int32([]), dtypes.string),
+ (None, np.int32([2]), dtypes.int32),
+ (None, np.int32([2, 2]), dtypes.int32),
+ ((None, None, None), np.int32([]), dtypes.int32),
+ ((None, (None, None)), np.int32([]), dtypes.int32),
+ )
+ def testWindowDatasetBatchSparse(self, structure, shape, dtype):
+ """Tests batching of sparse tensor windows.
+
+ Args:
+ structure: the input structure
+ shape: the input shape
+ dtype: the input data type
+ """
+
+ def fn(*args):
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ return batching.batch_window(args[0])
+
+ return tuple([
+ fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
+ for arg in args
+ ])
+
+ dataset = self._structuredSparseDataset(
+ structure, shape, dtype).repeat(5).apply(
+ grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ expected = sess.run(
+ self._structuredSparseElement(structure,
+ np.concatenate(([5], shape), axis=0),
+ dtype))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int32([]),),
+ (np.int32([1]),),
+ (np.int32([1, 2, 3]),),
+ )
+ def testWindowDatasetBatchSparseDynamicShape(self, shape):
+ """Tests batching of dynamically shaped sparse tensor windows.
+
+ Args:
+ shape: the input shape
+ """
+
+ shape_t = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map(
+ self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test
+ grouping.window_dataset(5)).apply(
+ grouping._map_x_dataset(batching.batch_window))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op, {shape_t: shape})
+ expected = sess.run(
+ self._structuredSparseElement(None,
+ np.concatenate(([5], shape), axis=0),
+ dtypes.int32))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ def _structuredRaggedDataset(self, structure, shapes, dtype):
+
+ if structure is None:
+ return dataset_ops.Dataset.from_tensor_slices(shapes).map(
+ lambda shape: array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self._structuredRaggedDataset(substructure, shapes, dtype)
+ for substructure in structure
+ ]))
+
+ @parameterized.parameters(
+ (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ )
+ def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
+ padded_shape):
+ """Tests padded batching of dense tensor windows.
+
+ Args:
+ structure: the input structure
+ shapes: the input shapes
+ dtype: the input data type
+ padded_shape: the shape to pad the output to
+ """
+
+ def fn(*args):
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ return batching.padded_batch_window(args[0], padded_shape)
+
+ return tuple([
+ fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
+ arg, padded_shape) for arg in args
+ ])
+
+ dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply(
+ grouping.window_dataset(len(shapes))).apply(
+ grouping._map_x_dataset(fn))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
+ expected = sess.run(
+ self._structuredElement(
+ structure,
+ np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int32([[1], [2], [3]]), [-1]),
+ (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ )
+ def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
+ """Tests padded batching of dynamically shaped dense tensor windows.
+
+ Args:
+ shapes: the input shapes
+ padded_shape: the shape to pad the output to
+ """
+
+ shapes_t = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
+ lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
+ grouping.window_dataset(len(shapes))).apply(
+ grouping._map_x_dataset(
+ lambda x: batching.padded_batch_window(x, padded_shape)))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op, {shapes_t: shapes})
+ expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
+ expected = sess.run(
+ self._structuredElement(
+ None, np.concatenate((np.int32([len(shapes)]), expected_shape)),
+ dtypes.int32))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int32([[1]]), np.int32([0])),
+ (np.int32([[10], [20]]), np.int32([15])),
+ )
+ def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
+ """Tests invalid padded batching of dense tensor windows.
+
+ Args:
+ shapes: the input shapes
+ padded_shape: the shape to pad the output to
+ """
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
+ lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
+ grouping.window_dataset(len(shapes))).apply(
+ grouping._map_x_dataset(
+ lambda x: batching.padded_batch_window(x, padded_shape)))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def _structuredRaggedSparseDataset(self, structure, shapes, dtype):
+
+ def map_fn(shape):
+ dense_to_sparse = self._make_dense_to_sparse_fn(False)
+ return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
+
+ if structure is None:
+ return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn)
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self._structuredRaggedSparseDataset(substructure, shapes, dtype)
+ for substructure in structure
+ ]))
+
+ def _structuredRaggedSparseElement(self, structure, shapes, dtype,
+ padded_shape):
+ if structure is None:
+ dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
+ values = []
+ for shape in shapes:
+ dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
+ sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
+ padded_sparse = sparse_tensor.SparseTensor(sparse.indices,
+ sparse.values, dense_shape)
+ reshaped_sparse = sparse_ops.sparse_reshape(
+ padded_sparse,
+ array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0))
+ values.append(reshaped_sparse)
+ return sparse_ops.sparse_concat(0, values)
+ else:
+ return tuple([
+ self._structuredRaggedSparseElement(substructure, shapes, dtype,
+ padded_shape)
+ for substructure in structure
+ ])
+
+ @parameterized.parameters(
+ (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ )
+ def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
+ padded_shape):
+ """Tests padded batching of sparse tensor windows.
+
+ Args:
+ structure: the input structure
+ shapes: the input shapes
+ dtype: the input data type
+ padded_shape: the shape to pad the output to
+ """
+
+ def fn(*args):
+ if len(args) == 1 and not isinstance(args[0], tuple):
+ return batching.padded_batch_window(args[0], padded_shape)
+
+ return tuple([
+ fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
+ arg, padded_shape) for arg in args
+ ])
+
+ dataset = self._structuredRaggedSparseDataset(
+ structure, shapes, dtype).apply(grouping.window_dataset(
+ len(shapes))).apply(grouping._map_x_dataset(fn))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ expected = sess.run(
+ self._structuredRaggedSparseElement(structure, shapes, dtype,
+ padded_shape))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int64([[1], [2], [3]]), [-1]),
+ (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ )
+ def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
+ padded_shape):
+ """Tests padded batching of dynamically shaped sparse tensor windows.
+
+ Args:
+ shapes: the input shapes
+ padded_shape: the shape to pad the output to
+ """
+
+ shapes_t = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
+ lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
+ self._make_dense_to_sparse_fn(False)
+ ).apply(grouping.window_dataset(len(shapes))).apply(
+ grouping._map_x_dataset(
+ lambda x: batching.padded_batch_window(x, padded_shape)))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op, {shapes_t: shapes})
+ expected = sess.run(
+ self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
+ padded_shape))
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
+
+ @parameterized.parameters(
+ (np.int64([[1]]), [0]),
+ (np.int64([[10], [20]]), [15]),
+ )
+ def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
+ """Tests invalid padded batching of sparse tensor windows.
+
+ Args:
+ shapes: the input shapes
+ padded_shape: the shape to pad the output to
+ """
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
+ lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
+ self._make_dense_to_sparse_fn(False)
+ ).apply(grouping.window_dataset(len(shapes))).apply(
+ grouping._map_x_dataset(
+ lambda x: batching.padded_batch_window(x, padded_shape)))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index eceecfd174..160d7fe22a 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -49,26 +49,6 @@ py_library(
],
)
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
-
py_library(
name = "random_ops",
srcs = [
@@ -96,8 +76,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
+ ":gen_dataset_ops",
":interleave_ops",
":shuffle_ops",
+ ":stats_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
@@ -106,12 +88,12 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
@@ -133,6 +115,8 @@ py_library(
srcs = ["batching.py"],
srcs_version = "PY2AND3",
deps = [
+ ":get_single_element",
+ ":grouping",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
@@ -142,6 +126,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
@@ -209,6 +194,20 @@ py_library(
)
py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":contrib_op_loader",
+ ":gen_dataset_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
name = "resampling",
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
@@ -368,6 +367,7 @@ py_library(
":get_single_element",
":grouping",
":interleave_ops",
+ ":optimization",
":prefetching_ops",
":readers",
":resampling",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index b9393de4e9..a4914f4cde 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,18 +17,133 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
from tensorflow.contrib.framework import with_shape
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util import deprecation
+
+
+def batch_window(dataset):
+ """Batches a window of tensors.
+
+ Args:
+ dataset: the input dataset.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+ """
+ if isinstance(dataset.output_classes, tuple):
+ raise TypeError("Input dataset expected to have a single component")
+ if dataset.output_classes is ops.Tensor:
+ return _batch_dense_window(dataset)
+ elif dataset.output_classes is sparse_tensor.SparseTensor:
+ return _batch_sparse_window(dataset)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _batch_dense_window(dataset):
+ """Batches a window of dense tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return array_ops.shape(first_element)
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, array_ops.shape(value))
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat([[0], shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _batch_sparse_window(dataset):
+ """Batches a window of sparse tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return first_element.dense_shape
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, value.dense_shape)
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64),
+ math_ops.cast(shape, dtypes.int64)], 0))
+
+ def batch_reduce_fn(state, value):
+ return sparse_ops.sparse_concat(0, [state, value])
+
+ def reshape_fn(value):
+ return sparse_ops.sparse_reshape(
+ value,
+ array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(reshape_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
def dense_to_sparse_batch(batch_size, row_shape):
@@ -75,17 +190,168 @@ def dense_to_sparse_batch(batch_size, row_shape):
"""
def _apply_fn(dataset):
- return DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
return _apply_fn
-class UnbatchDataset(dataset_ops.Dataset):
+def padded_batch_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of tensors with padding.
+
+ Args:
+ dataset: the input dataset.
+ padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the shape to which the input elements should be padded
+ prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
+ `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
+ maximum size of that dimension in each batch.
+ padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
+ padding value to use. Defaults are `0` for numeric types and the empty
+ string for string types. If `dataset` contains `tf.SparseTensor`, this
+ value is ignored.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
+ if not issubclass(dataset.output_classes,
+ (ops.Tensor, sparse_tensor.SparseTensor)):
+ raise TypeError("Input dataset expected to have a single tensor component")
+ if issubclass(dataset.output_classes, (ops.Tensor)):
+ return _padded_batch_dense_window(dataset, padded_shape, padding_value)
+ elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
+ if padding_value is not None:
+ raise ValueError("Padding value not allowed for sparse tensors")
+ return _padded_batch_sparse_window(dataset, padded_shape)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of dense tensors with padding."""
+
+ padded_shape = math_ops.cast(
+ convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return padded_shape
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(array_ops.shape(value), padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ",
+ array_ops.shape(value), padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, array_ops.shape(value))
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ if padding_value is None:
+ if dataset.output_types == dtypes.string:
+ padding_value = ""
+ elif dataset.output_types == dtypes.bool:
+ padding_value = False
+ elif dataset.output_types == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ padding_value = 0
+
+ def batch_init_fn(_):
+ return array_ops.fill(
+ array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
+ constant_op.constant(padding_value, dtype=dataset.output_types))
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ def pad_fn(value):
+ shape = array_ops.shape(value)
+ left = array_ops.zeros_like(shape)
+ right = padded_shape - shape
+ return array_ops.pad(
+ value, array_ops.stack([left, right], 1), constant_values=padding_value)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(pad_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _padded_batch_sparse_window(dataset, padded_shape):
+ """Batches a window of sparse tensors with padding."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return convert.partial_shape_to_tensor(padded_shape)
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(value.dense_shape, padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ", value.dense_shape,
+ padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, value.dense_shape)
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
+ 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64), padded_shape], 0))
+
+ def batch_reduce_fn(state, value):
+ padded_value = sparse_tensor.SparseTensor(
+ indices=value.indices, values=value.values, dense_shape=padded_shape)
+ reshaped_value = sparse_ops.sparse_reshape(
+ padded_value,
+ array_ops.concat(
+ [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
+ return sparse_ops.sparse_concat(0, [state, reshaped_value])
+
+ reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
+
+
+class _UnbatchDataset(dataset_ops.Dataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__()
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -101,10 +367,7 @@ class UnbatchDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return gen_dataset_ops.unbatch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
@@ -145,7 +408,7 @@ def unbatch():
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
if not sparse.any_sparse(dataset.output_classes):
- return UnbatchDataset(dataset)
+ return _UnbatchDataset(dataset)
# NOTE(mrry): We must ensure that any SparseTensors in `dataset`
# are normalized to the rank-1 dense representation, so that the
@@ -171,12 +434,12 @@ def unbatch():
dataset.output_shapes,
dataset.output_classes,
allow_unsafe_cast=True)
- return UnbatchDataset(restructured_dataset)
+ return _UnbatchDataset(restructured_dataset)
return _apply_fn
-def filter_irregular_batches(batch_size):
+def _filter_irregular_batches(batch_size):
"""Transformation that filters out batches that are not of size batch_size."""
def _apply_fn(dataset):
@@ -218,6 +481,8 @@ def filter_irregular_batches(batch_size):
return _apply_fn
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.")
def batch_and_drop_remainder(batch_size):
"""A batching transformation that omits the final small batch (if present).
@@ -250,12 +515,16 @@ def batch_and_drop_remainder(batch_size):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
+ # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time
+ # after 6/30/2018.
batched = dataset.batch(batch_size)
- return filter_irregular_batches(batch_size)(batched)
+ return _filter_irregular_batches(batch_size)(batched)
return _apply_fn
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.")
def padded_batch_and_drop_remainder(batch_size,
padded_shapes,
padding_values=None):
@@ -284,19 +553,21 @@ def padded_batch_and_drop_remainder(batch_size,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
+ # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)`
+ # any time after 6/30/2018.
batched = dataset.padded_batch(
batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
- return filter_irregular_batches(batch_size)(batched)
+ return _filter_irregular_batches(batch_size)(batched)
return _apply_fn
-class DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.Dataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__()
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@@ -309,11 +580,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset):
return gen_dataset_ops.dense_to_sparse_batch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._batch_size,
- row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
@@ -490,10 +758,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset):
batch_size=self._batch_size_t,
num_parallel_calls=self._num_parallel_calls_t,
drop_remainder=self._drop_remainder_t,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
# pylint: enable=protected-access
@property
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 6c21e489f7..d46d96c461 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -20,8 +20,6 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
def ignore_errors():
@@ -48,26 +46,23 @@ def ignore_errors():
"""
def _apply_fn(dataset):
- return IgnoreErrorsDataset(dataset)
+ return _IgnoreErrorsDataset(dataset)
return _apply_fn
-class IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.Dataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__()
self._input_dataset = input_dataset
def _as_variant_tensor(self):
return gen_dataset_ops.ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index 3a07df5727..0f4cd8e20c 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -64,10 +64,7 @@ def get_single_element(dataset):
nested_ret = nest.pack_sequence_as(
dataset.output_types, gen_dataset_ops.dataset_to_single_element(
dataset._as_variant_tensor(), # pylint: disable=protected-access
- output_types=nest.flatten(sparse.as_dense_types(
- dataset.output_types, dataset.output_classes)),
- output_shapes=nest.flatten(sparse.as_dense_shapes(
- dataset.output_shapes, dataset.output_classes))))
+ **dataset_ops.flat_structure(dataset)))
return sparse.deserialize_sparse_tensors(
nested_ret, dataset.output_types, dataset.output_shapes,
dataset.output_classes)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index ea229b5b27..bd8d398c58 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -21,12 +21,9 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -58,7 +55,7 @@ def group_by_reducer(key_func, reducer):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByReducerDataset(dataset, key_func, reducer)
+ return _GroupByReducerDataset(dataset, key_func, reducer)
return _apply_fn
@@ -116,8 +113,8 @@ def group_by_window(key_func,
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
return _apply_fn
@@ -152,9 +149,9 @@ def bucket_by_sequence_length(element_length_func,
@{tf.data.Dataset.padded_batch}. Defaults to padding with 0.
pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
size to maximum length in batch. If `True`, will pad dimensions with
- unknown size to bucket boundary, and caller must ensure that the source
- `Dataset` does not contain any elements with length longer than
- `max(bucket_boundaries)`.
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -206,7 +203,7 @@ def bucket_by_sequence_length(element_length_func,
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length <= max(bucket_boundaries).")
+ "length < max(bucket_boundaries).")
check = check_ops.assert_less(
bucket_id,
constant_op.constant(len(bucket_batch_sizes) - 1,
@@ -216,7 +213,7 @@ def bucket_by_sequence_length(element_length_func,
boundaries = constant_op.constant(bucket_boundaries,
dtype=dtypes.int64)
bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary
+ none_filler = bucket_boundary - 1
shapes = make_padded_shapes(
padded_shapes or grouped_dataset.output_shapes,
none_filler=none_filler)
@@ -230,39 +227,56 @@ def bucket_by_sequence_length(element_length_func,
return _apply_fn
-class _VariantDataset(dataset_ops.Dataset):
- """A Dataset wrapper for a tf.variant-typed function argument."""
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
- def __init__(self, dataset_variant, output_types, output_shapes,
- output_classes):
- super(_VariantDataset, self).__init__()
- self._dataset_variant = dataset_variant
- self._output_types = output_types
- self._output_shapes = output_shapes
- self._output_classes = output_classes
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
- def _as_variant_tensor(self):
- return self._dataset_variant
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
- @property
- def output_classes(self):
- return self._output_classes
+ Returns:
+ Dataset: A `Dataset`.
+ """
- @property
- def output_shapes(self):
- return self._output_shapes
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
- @property
- def output_types(self):
- return self._output_types
+ return _apply_fn
+
+
+def window_dataset(window_size):
+ """A transformation that creates window datasets from the input dataset.
+
+ The resulting datasets will contain `window_size` elements (or
+ `N % window_size` for the last dataset if `window_size` does not divide the
+ number of input elements `N` evenly).
+
+ Args:
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of the input dataset to combine into a window.
+
+ Returns:
+ Dataset: A `Dataset`.
+ """
+ def _apply_fn(dataset):
+ return _WindowDataset(dataset, window_size)
-class GroupByReducerDataset(dataset_ops.Dataset):
+ return _apply_fn
+
+
+class _GroupByReducerDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__()
self._input_dataset = input_dataset
@@ -273,67 +287,27 @@ class GroupByReducerDataset(dataset_ops.Dataset):
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_key_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- # pylint: disable=protected-access
- if dataset_ops._should_unpack_args(nested_args):
- ret = key_func(*nested_args)
- # pylint: enable=protected-access
- else:
- ret = key_func(nested_args)
- ret = ops.convert_to_tensor(ret)
- if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
- return ret
-
- self._key_func = tf_key_func
- self._key_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
def _make_init_func(self, init_func):
"""Make wrapping Defun for init_func."""
-
- @function.Defun(dtypes.int64)
- def tf_init_func(key):
- """A wrapper for Defun that facilitates shape inference."""
- key.set_shape([])
- ret = init_func(key)
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._state_classes = sparse.get_classes(ret)
- self._state_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._state_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._init_func = tf_init_func
- self._init_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
@@ -343,83 +317,47 @@ class GroupByReducerDataset(dataset_ops.Dataset):
need_to_rerun = True
while need_to_rerun:
- # Create a list in which `tf_reduce_func` will store the new shapes.
- flat_new_state_shapes = []
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(
- self._state_types, self._state_classes)) + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
- def tf_reduce_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))
- + nest.flatten(
- sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes))):
- arg.set_shape(shape)
-
- pivot = len(nest.flatten(self._state_shapes))
- nested_state_args = nest.pack_sequence_as(self._state_types,
- args[:pivot])
- nested_state_args = sparse.deserialize_sparse_tensors(
- nested_state_args, self._state_types, self._state_shapes,
- self._state_classes)
- nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
- nested_input_args = sparse.deserialize_sparse_tensors(
- nested_input_args, input_dataset.output_types,
- input_dataset.output_shapes, input_dataset.output_classes)
-
- ret = reduce_func(nested_state_args, nested_input_args)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- # Extract shape information from the returned values.
- flat_new_state = nest.flatten(ret)
- flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
-
- # Extract and validate type information from the returned values.
- for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
- if t.dtype != dtype:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types,
- nest.pack_sequence_as(self._state_types,
- [t.dtype for t in flat_new_state])))
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret,
- [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- # Use the private method that will execute `tf_reduce_func` but delay
- # adding it to the graph in case we need to rerun the function.
- tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access
-
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
weakened_state_shapes = [
- old.most_specific_compatible_shape(new)
- for old, new in zip(flat_state_shapes, flat_new_state_shapes)
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
]
need_to_rerun = False
- for old_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if old_shape.ndims is not None and (
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
weakened_shape.ndims is None or
- old_shape.as_list() != weakened_shape.as_list()):
+ original_shape.as_list() != weakened_shape.as_list()):
need_to_rerun = True
break
@@ -427,50 +365,19 @@ class GroupByReducerDataset(dataset_ops.Dataset):
self._state_shapes = nest.pack_sequence_as(self._state_shapes,
weakened_state_shapes)
- self._reduce_func = tf_reduce_func
+ self._reduce_func = wrapped_func.function
self._reduce_func.add_to_graph(ops.get_default_graph())
def _make_finalize_func(self, finalize_func):
"""Make wrapping Defun for finalize_func."""
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes))))
- def tf_finalize_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
-
- ret = finalize_func(nested_args)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._finalize_func = tf_finalize_func
- self._finalize_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func, "tf.contrib.data.group_by_reducer()",
+ input_classes=self._state_classes, input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
@property
def output_classes(self):
@@ -495,18 +402,15 @@ class GroupByReducerDataset(dataset_ops.Dataset):
init_func=self._init_func,
reduce_func=self._reduce_func,
finalize_func=self._finalize_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
-class GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__()
self._input_dataset = input_dataset
@@ -516,74 +420,48 @@ class GroupByWindowDataset(dataset_ops.Dataset):
def _make_window_size_func(self, window_size_func):
"""Make wrapping Defun for window_size_func."""
-
- @function.Defun(dtypes.int64)
- def tf_window_size_func(key):
- key.set_shape([])
- window_size = ops.convert_to_tensor(
- window_size_func(key), dtype=dtypes.int64)
- if window_size.dtype != dtypes.int64:
- raise ValueError(
- "`window_size_func` must return a single tf.int64 tensor.")
- return window_size
-
- self._window_size_func = tf_window_size_func
- self._window_size_func.add_to_graph(ops.get_default_graph())
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper, "tf.contrib.data.group_by_window()",
+ input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
def _make_key_func(self, key_func, input_dataset):
"""Make wrapping Defun for key_func."""
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_key_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- # pylint: disable=protected-access
- if dataset_ops._should_unpack_args(nested_args):
- ret = key_func(*nested_args)
- # pylint: enable=protected-access
- else:
- ret = key_func(nested_args)
- ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
- if ret.dtype != dtypes.int64:
- raise ValueError("`key_func` must return a single tf.int64 tensor.")
- return ret
-
- self._key_func = tf_key_func
- self._key_func.add_to_graph(ops.get_default_graph())
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping Defun for reduce_func."""
-
- @function.Defun(dtypes.int64, dtypes.variant)
- def tf_reduce_func(key, window_dataset_variant):
- """A wrapper for Defun that facilitates shape inference."""
- key.set_shape([])
- window_dataset = _VariantDataset(
- window_dataset_variant, input_dataset.output_types,
- input_dataset.output_shapes, input_dataset.output_classes)
- if not isinstance(window_dataset, dataset_ops.Dataset):
- raise TypeError("`window_dataset` must return a `Dataset` object.")
- output_dataset = reduce_func(key, window_dataset)
- if not isinstance(output_dataset, dataset_ops.Dataset):
- raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = output_dataset.output_classes
- self._output_types = output_dataset.output_types
- self._output_shapes = output_dataset.output_shapes
- return output_dataset._as_variant_tensor() # pylint: disable=protected-access
-
- self._reduce_func = tf_reduce_func
- self._reduce_func.add_to_graph(ops.get_default_graph())
+ nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func, "tf.contrib.data.reduce_by_window()",
+ input_classes=(ops.Tensor, nested_dataset),
+ input_shapes=(tensor_shape.scalar(), nested_dataset),
+ input_types=(dtypes.int64, nested_dataset),
+ experimental_nested_dataset_support=True)
+ if not isinstance(
+ wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ raise TypeError("`reduce_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._reduce_func = wrapped_func.function
@property
def output_classes(self):
@@ -606,10 +484,7 @@ class GroupByWindowDataset(dataset_ops.Dataset):
key_func=self._key_func,
reduce_func=self._reduce_func,
window_size_func=self._window_size_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
class Reducer(object):
@@ -637,3 +512,85 @@ class Reducer(object):
@property
def finalize_func(self):
return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.Dataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.contrib.data.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+class _WindowDataset(dataset_ops.Dataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, window_size):
+ """See `window_dataset()` for more details."""
+ super(_WindowDataset, self).__init__()
+ self._input_dataset = input_dataset
+ self._window_size = ops.convert_to_tensor(
+ window_size, dtype=dtypes.int64, name="window_size")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._window_size,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index be66fbac50..bcc959594a 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -24,7 +24,6 @@ from tensorflow.contrib.data.python.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -154,7 +153,7 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
return _apply_fn
-class DirectedInterleaveDataset(dataset_ops.Dataset):
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
"""A substitute for `Dataset.interleave()` on a fixed list of datasets."""
def __init__(self, selector_input, data_inputs):
@@ -171,10 +170,7 @@ class DirectedInterleaveDataset(dataset_ops.Dataset):
return gen_dataset_ops.directed_interleave_dataset(
self._selector_input._as_variant_tensor(),
[data_input._as_variant_tensor() for data_input in self._data_inputs],
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
# pylint: enable=protected-access
@property
@@ -240,7 +236,7 @@ def sample_from_datasets(datasets, weights=None, seed=None):
selector_input = dataset_ops.Dataset.zip(
(logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
- return DirectedInterleaveDataset(selector_input, datasets)
+ return _DirectedInterleaveDataset(selector_input, datasets)
def choose_from_datasets(datasets, choice_dataset):
@@ -284,4 +280,4 @@ def choose_from_datasets(datasets, choice_dataset):
and choice_dataset.output_classes == ops.Tensor):
raise TypeError("`choice_dataset` must be a dataset of scalar "
"`tf.int64` tensors.")
- return DirectedInterleaveDataset(choice_dataset, datasets)
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
new file mode 100644
index 0000000000..cf89657226
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+def optimize(optimizations=None):
+ """A transformation that applies optimizations.
+
+ Args:
+ optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
+ optimizations to use. If not specified, the default set of optimizations
+ is applied.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _OptimizeDataset(dataset, optimizations)
+
+ return _apply_fn
+
+
+class _OptimizeDataset(dataset_ops.Dataset):
+ """A `Dataset` that acts as an identity, and applies optimizations."""
+
+ def __init__(self, input_dataset, optimizations):
+ """See `optimize()` for details."""
+ super(_OptimizeDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if optimizations is None:
+ optimizations = []
+ self._optimizations = ops.convert_to_tensor(
+ optimizations, dtype=dtypes.string, name="optimizations")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.optimize_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._optimizations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index e4c9f8b58a..50212d3b52 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -26,21 +26,42 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
-# TODO(rohanj): Add a python class that constructs resource in the __init__
-# method and provides a get_next() that calls the prefetch op.
def function_buffering_resource(string_arg,
target_device,
f,
buffer_size,
+ output_types,
container="",
shared_name=None,
name=None):
+ """Creates a FunctionBufferingResource.
+
+ A FunctionBufferingResource fills up a buffer by calling a function `f` on
+ `target_device`. `f` should take in only a single string argument as input.
+
+ Args:
+ string_arg: The single string argument to the function.
+ target_device: The device to run `f` on.
+ f: The function to be executed.
+ buffer_size: Size of the buffer to be populated.
+ output_types: The output types generated by the function.
+ container: (Optional) string. Defaults to "".
+ shared_name: (Optional) string.
+ name: (Optional) string to name the op.
+
+ Returns:
+ Handle to a FunctionBufferingResource.
+ """
if shared_name is None:
shared_name = ""
return gen_dataset_ops.function_buffering_resource(
@@ -50,7 +71,8 @@ def function_buffering_resource(string_arg,
f=f,
buffer_size=buffer_size,
container=container,
- name=name)
+ name=name,
+ output_types=output_types)
def function_buffering_resource_get_next(function_buffer_resource,
@@ -123,7 +145,10 @@ class _PrefetchToDeviceIterator(object):
target_device=iterator_device,
string_arg=input_iterator_handle,
buffer_size=buffer_size,
- shared_name=shared_name)
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)))
if not self._one_shot:
reset_op = function_buffering_resource_reset(self._buffering_resource)
@@ -212,6 +237,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
with ops.device(device):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
+ output_types=self._flat_output_types,
target_device=gen_dataset_ops.iterator_get_device(self._resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
@@ -323,3 +349,172 @@ def prefetch_to_device(device, buffer_size=None):
return _PrefetchToDeviceDataset(dataset, device, buffer_size)
return _apply_fn
+
+
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.Dataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied
+ target_device: The name of the device to which elements would be copied.
+ source_device: Device where input_dataset would be placed.
+ """
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec().from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = core_gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return core_gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ Tensor constant 0
+ """
+ iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+ # pylint: enable=protected-scope
+
+ # The one_shot_iterator implementation needs a 0 arg _make_dataset function
+ # that thereby captures all the inputs required to create the dataset. Since
+ # there are strings that are inputs to the GeneratorDataset which can't be
+ # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
+ # GPU
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.contrib.data.copy_to_device()` on GPU. Please use "
+ "`Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return core_gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index 28ef5e50f3..e670c4c835 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -18,9 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -39,10 +37,7 @@ class RandomDataset(dataset_ops.Dataset):
return gen_dataset_ops.random_dataset(
seed=self._seed,
seed2=self._seed2,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f938153f5f..9373e37f5f 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
@@ -539,11 +540,11 @@ class CsvDataset(dataset_ops.Dataset):
The expected output of its iterations is:
```python
- next = dataset.make_one_shot_iterator().get_next()
+ next_element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
- print(sess.run(nxt))
+ print(sess.run(next_element))
except tf.errors.OutOfRangeError:
break
@@ -754,6 +755,8 @@ def make_batched_features_dataset(file_pattern,
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+ dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
+
if drop_final_batch:
dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size))
else:
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index bad6edd514..182a5c6ff3 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -291,4 +291,4 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
# TODO(joelshor): Simplify fraction, if possible.
a_i = (ratio_l - m) / (max_ratio - m)
- return a_i, m \ No newline at end of file
+ return a_i, m
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index e911ad0fa0..ea9dcfe68f 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -22,7 +22,6 @@ import collections
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
@@ -67,102 +66,45 @@ class _ScanDataset(dataset_ops.Dataset):
need_to_rerun = True
while need_to_rerun:
- # Create a list in which `tf_scan_func` will store the new shapes.
- flat_new_state_shapes = []
-
- @function.Defun(*(nest.flatten(
- sparse.as_dense_types(
- self._state_types, self._state_classes)) + nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes))))
- def tf_scan_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the state and input_dataset.
- for arg, shape in zip(
- args,
- nest.flatten(
- sparse.as_dense_shapes(self._state_shapes, self._state_classes))
- + nest.flatten(
- sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes))):
- arg.set_shape(shape)
-
- pivot = len(nest.flatten(self._state_shapes))
- print(self._state_classes)
- nested_state_args = nest.pack_sequence_as(self._state_types,
- args[:pivot])
- nested_state_args = sparse.deserialize_sparse_tensors(
- nested_state_args, self._state_types, self._state_shapes,
- self._state_classes)
- print(input_dataset.output_classes)
- nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
- args[pivot:])
- nested_input_args = sparse.deserialize_sparse_tensors(
- nested_input_args, input_dataset.output_types,
- input_dataset.output_shapes, input_dataset.output_classes)
-
- ret = scan_func(nested_state_args, nested_input_args)
- if not isinstance(ret, collections.Sequence) or len(ret) != 2:
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
- new_state, output_value = ret
-
- # Extract and validate class information from the returned values.
- for t, clazz in zip(
- nest.flatten(new_state), nest.flatten(self._state_classes)):
- if not isinstance(t, clazz):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes,
- nest.pack_sequence_as(
- self._state_types,
- [type(t) for t in nest.flatten(new_state)])))
- self._output_classes = sparse.get_classes(output_value)
-
- # Extract shape information from the returned values.
- flat_new_state_shapes.extend(
- [t.get_shape() for t in nest.flatten(new_state)])
- self._output_shapes = nest.pack_sequence_as(
- output_value, [t.get_shape() for t in nest.flatten(output_value)])
-
- # Extract and validate type information from the returned values.
- for t, dtype in zip(
- nest.flatten(new_state), nest.flatten(self._state_types)):
- if t.dtype != dtype:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types,
- nest.pack_sequence_as(
- self._state_types,
- [t.dtype for t in nest.flatten(new_state)])))
- self._output_types = nest.pack_sequence_as(
- output_value, [t.dtype for t in nest.flatten(output_value)])
-
- # Serialize any sparse tensors.
- new_state = nest.pack_sequence_as(new_state, [
- t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
- ])
- output_value = nest.pack_sequence_as(output_value, [
- t for t in nest.flatten(
- sparse.serialize_sparse_tensors(output_value))
- ])
- return nest.flatten(new_state) + nest.flatten(output_value)
-
- # Use the private method that will execute `tf_scan_func` but delay
- # adding it to the graph in case we need to rerun the function.
- tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func, "tf.contrib.data.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
weakened_state_shapes = [
original.most_specific_compatible_shape(new)
for original, new in zip(flat_state_shapes, flat_new_state_shapes)
@@ -178,12 +120,10 @@ class _ScanDataset(dataset_ops.Dataset):
break
if need_to_rerun:
- # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
- # `tf_scan_func`.
self._state_shapes = nest.pack_sequence_as(self._state_shapes,
weakened_state_shapes)
- self._scan_func = tf_scan_func
+ self._scan_func = wrapped_func.function
self._scan_func.add_to_graph(ops.get_default_graph())
def _as_variant_tensor(self):
@@ -193,10 +133,7 @@ class _ScanDataset(dataset_ops.Dataset):
nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
self._scan_func.captured_inputs,
f=self._scan_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index f35795abd3..d7f8a73fe3 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -18,9 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -56,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
count=self._count,
seed=self._seed,
seed2=self._seed2,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
# pylint: enable=protected-access
@property
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 19cc3cb89f..3f3c5ca17c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -19,7 +19,6 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -43,10 +42,7 @@ class _SlideDataset(dataset_ops.Dataset):
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
window_size=self._window_size,
stride=self._stride,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
@@ -90,7 +86,7 @@ def sliding_window_batch(window_size, stride=1):
elements in the sliding window.
stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
steps moving the sliding window forward for one iteration. The default
- is `1`. It must be in `[1, window_size)`.
+ is `1`. It must be positive.
Returns:
A `Dataset` transformation function, which can be passed to
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3cbaab5aff..97931f75bd 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -18,13 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
@@ -97,10 +97,7 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_shapes(self):
@@ -115,7 +112,8 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def set_stats_aggregator(stats_aggregator):
"""Set the given stats_aggregator for aggregating the input dataset stats.
@@ -133,6 +131,8 @@ def set_stats_aggregator(stats_aggregator):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def bytes_produced_stats(tag):
"""Records the number of bytes produced by each element of the input dataset.
@@ -155,6 +155,8 @@ def bytes_produced_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
@@ -176,6 +178,29 @@ def latency_stats(tag):
return _apply_fn
+# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
+def feature_stats(tag):
+ """Records the features stats from `Example` records of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
+ Args:
+ tag: String. All statistics recorded by the returned transformation will be
+ associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag)
+
+ return _apply_fn
+
+
class _StatsDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
@@ -189,10 +214,7 @@ class _StatsDataset(dataset_ops.Dataset):
return self._op_function(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._tag,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_shapes(self):
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index 56f67e1766..9af1e784ff 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -22,8 +22,6 @@ import threading
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops
@@ -39,22 +37,28 @@ def _generate_shared_name(prefix):
return "{}{}".format(prefix, uid)
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
class PrivateThreadPool(object):
"""A stateful resource that represents a private thread pool."""
- def __init__(self, num_threads, display_name=None):
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
self._resource = gen_dataset_ops.thread_pool_handle(
num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
shared_name=shared_name)
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
self._resource = gen_dataset_ops.thread_pool_handle(
- num_threads=num_threads, display_name=display_name)
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
class _ThreadPoolDataset(dataset_ops.Dataset):
@@ -69,10 +73,7 @@ class _ThreadPoolDataset(dataset_ops.Dataset):
return gen_dataset_ops.thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_shapes(self):
@@ -87,6 +88,8 @@ class _ThreadPoolDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
+# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
+# or make private / remove.
def override_threadpool(dataset, thread_pool):
"""Returns a new dataset that uses the given thread pool for its operations.
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index 765ef3f9b6..e0ce0a4ef1 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -20,8 +20,6 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -44,17 +42,17 @@ def unique():
"""
def _apply_fn(dataset):
- return UniqueDataset(dataset)
+ return _UniqueDataset(dataset)
return _apply_fn
-class UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.Dataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__()
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
@@ -65,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return gen_dataset_ops.unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **dataset_ops.flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index 74b2cd90a1..1126f76f58 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -30,6 +30,7 @@ py_library(
"//tensorflow/contrib/distribute/python:monitor",
"//tensorflow/contrib/distribute/python:one_device_strategy",
"//tensorflow/contrib/distribute/python:step_fn",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
],
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 76711baf3a..2e2c3be853 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrat
from tensorflow.contrib.distribute.python.monitor import Monitor
from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceStrategy
from tensorflow.contrib.distribute.python.step_fn import *
+from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
from tensorflow.python.training.distribute import *
from tensorflow.python.util.all_util import remove_undocumented
@@ -41,6 +42,7 @@ _allowed_symbols = [
'StandardInputStep',
'StandardSingleLossStep',
'TowerContext',
+ 'TPUStrategy',
'get_cross_tower_context',
'get_distribution_strategy',
'get_loss_reduction',
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 3118deaa47..40dbfa3dd2 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -77,6 +77,7 @@ py_library(
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
@@ -148,6 +149,7 @@ py_library(
],
deps = [
":mirrored_strategy",
+ ":multi_worker_strategy",
":one_device_strategy",
":tpu_strategy",
"//tensorflow/contrib/optimizer_v2:training",
@@ -446,8 +448,10 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":values",
+ "//tensorflow/contrib/all_reduce:all_reduce_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
],
@@ -495,6 +499,7 @@ cuda_py_test(
additional_deps = [
":combinations",
":cross_tower_ops",
+ ":multi_worker_test_base",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
@@ -504,6 +509,7 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
+ shard_count = 15,
tags = [
"multi_and_single_gpu",
"no_pip",
@@ -581,6 +587,26 @@ cuda_py_test(
],
tags = [
"multi_and_single_gpu",
+ "no_windows_gpu",
"notsan",
],
)
+
+cuda_py_test(
+ name = "metrics_v1_test",
+ srcs = ["metrics_v1_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:test",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
+ ],
+)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index e400fa5be2..9a8ea4aa48 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -46,9 +46,10 @@ import unittest
from absl.testing import parameterized
import six
-from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.contrib.distribute.python import one_device_strategy
-from tensorflow.contrib.distribute.python import tpu_strategy
+from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
+from tensorflow.contrib.distribute.python import multi_worker_strategy
+from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
+from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
@@ -289,9 +290,9 @@ class NamedObject(object):
class NamedDistribution(object):
"""Translates DistributionStrategy and its data into a good name."""
- def __init__(self, name, distribution, required_gpus=None,
+ def __init__(self, name, distribution_fn, required_gpus=None,
required_tpu=False):
- self._distribution = distribution
+ self._distribution_fn = distribution_fn
self._name = name
self._required_gpus = required_gpus
self._required_tpu = required_tpu
@@ -301,7 +302,7 @@ class NamedDistribution(object):
@property
def strategy(self):
- return self._distribution
+ return self._distribution_fn()
@property
def required_gpus(self):
@@ -312,32 +313,56 @@ class NamedDistribution(object):
return self._required_tpu
+# pylint: disable=g-long-lambda
default_strategy = NamedDistribution(
"Default",
- distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
+ lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access
required_gpus=None)
one_device_strategy = NamedDistribution(
- "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"),
+ "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
required_gpus=None)
-tpu_strategy_single_iteration = NamedDistribution(
- "TPUSingleIteration",
- tpu_strategy.TPUStrategy(iterations_per_step=1),
- required_tpu=True)
-tpu_strategy = NamedDistribution(
- "TPU", tpu_strategy.TPUStrategy(), required_tpu=True)
+tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
"MirroredCPUAndGPU",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/cpu:0"], prefetch_on_device=False),
required_gpus=1)
mirrored_strategy_with_two_gpus = NamedDistribution(
"Mirrored2GPUs",
- mirrored_strategy.MirroredStrategy(
+ lambda: mirrored_lib.MirroredStrategy(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
+multi_worker_strategy_with_cpu = NamedDistribution(
+ "MultiWorkerCPU",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=0), 0)
+multi_worker_strategy_with_one_gpu = NamedDistribution(
+ "MultiWorker1GPU",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=1), 1)
+multi_worker_strategy_with_two_gpus = NamedDistribution(
+ "MultiWorker2GPUs",
+ lambda: multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={
+ "worker": [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ },
+ num_gpus_per_worker=2), 2)
+
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index a411b880e8..b0baf0dad1 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import six
from tensorflow.contrib.distribute.python import cross_tower_utils
@@ -27,11 +28,12 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_util
-def _validate_destinations(destinations):
+def validate_destinations(destinations):
if not isinstance(destinations,
(value_lib.DistributedValues, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
@@ -54,7 +56,7 @@ def _validate_value_destination_pairs(value_destination_pairs):
# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps.
-def _get_devices_from(destinations):
+def get_devices_from(destinations):
if isinstance(destinations, value_lib.DistributedValues):
return list(destinations.devices)
elif isinstance(destinations, six.string_types):
@@ -64,7 +66,7 @@ def _get_devices_from(destinations):
def _devices_match(left, right):
- return set(_get_devices_from(left)) == set(_get_devices_from(right))
+ return set(get_devices_from(left)) == set(get_devices_from(right))
def _all_devices_match(value_destination_pairs):
@@ -79,7 +81,7 @@ def _all_devices_match(value_destination_pairs):
def _simple_broadcast(value, destinations):
index = {}
- devices = _get_devices_from(destinations)
+ devices = get_devices_from(destinations)
for d in devices:
index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device(
value, d)
@@ -87,7 +89,7 @@ def _simple_broadcast(value, destinations):
def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
- method_string):
+ aggregation):
# pylint: disable=g-missing-docstring
all_values = []
count = 0
@@ -111,11 +113,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn,
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices(
all_values, accumulation_fn)
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(
reduced, count)
- elif method_string != "sum":
- raise ValueError("`method_string` must be 'sum' or 'mean'")
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise ValueError("`aggregation` must be VariableAggregation.SUM "
+ "or VariableAggregation.MEAN.")
return reduced
@@ -125,14 +128,15 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, method_string, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations=None):
"""Reduce `per_device_value` to `destinations`.
- It runs the reduction operation defined by `method_string` and put the
+ It runs the reduction operation defined by `aggregation` and put the
result on `destinations`.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
per_device_value: a PerDevice object.
destinations: the reduction destinations.
@@ -145,17 +149,18 @@ class CrossTowerOps(object):
if not isinstance(per_device_value, value_lib.PerDevice):
raise ValueError("`per_device_value` must be a `PerDevice` object.")
if destinations is not None:
- _validate_destinations(destinations)
- return self._reduce(method_string, per_device_value, destinations)
+ validate_destinations(destinations)
+ return self._reduce(aggregation, per_device_value, destinations)
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Reduce PerDevice objects in a batch.
Reduce each first element in `value_destination_pairs` to each second
element which indicates the destinations.
Args:
- method_string: either 'sum' or 'mean' specifying the reduction method.
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
and destinations. If a destination is None, then the destinations
are set to match the devices of the input PerDevice object.
@@ -172,9 +177,9 @@ class CrossTowerOps(object):
"tuples of PerDevice objects and destinations")
for _, d in value_destination_pairs:
if d is not None:
- _validate_destinations(d)
+ validate_destinations(d)
- return self._batch_reduce(method_string, value_destination_pairs)
+ return self._batch_reduce(aggregation, value_destination_pairs)
def broadcast(self, tensor, destinations):
"""Broadcast the `tensor` to destinations.
@@ -186,14 +191,14 @@ class CrossTowerOps(object):
Returns:
a Mirrored object.
"""
- _validate_destinations(destinations)
+ validate_destinations(destinations)
return self._broadcast(tensor, destinations)
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
raise NotImplementedError(
"_reduce method must be implemented in descendants.")
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
raise NotImplementedError(
"_batch_reduce method must be implemented in descendants.")
@@ -219,22 +224,30 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps):
self.accumulation_fn = accumulation_fn
super(ReductionToOneDeviceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
- devices = _get_devices_from(destinations or per_device_value)
+ def _reduce(self, aggregation, per_device_value, destinations):
+ devices = get_devices_from(destinations or per_device_value)
reduce_to_device = self.reduce_to_device or devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- self.accumulation_fn, method_string)
+ self.accumulation_fn, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self._reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _group_value_by_device(per_device_values):
"""Group values into sublists by their devices.
- This grouping is needed to call the all-reduce library.
+ This grouping is needed to call the all-reduce library because it expects a
+ list of the following form:
+ [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...
+ (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...
+ (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ...
+ ...
+ ]
Args:
per_device_values: a list of PerDevice obejcts.
@@ -253,18 +266,19 @@ def _group_value_by_device(per_device_values):
return grouped
-def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
+def _ungroup_and_make_mirrored(grouped_reduced, destinations, aggregation):
"""Ungroup results from all-reduce and make Mirrored objects.
Each all-reduce result will be divided by the number of destinations before
- Mirrored objects are created if method_string is "mean".
+ Mirrored objects are created if aggregation is "mean".
Args:
grouped_reduced: a list of lists, each sublist has components for each
device, paired with a None. It is the result from
cross_tower_utils.aggregate_gradients_using*.
destinations: a list of device strings for returned Mirrored objects.
- method_string: "mean" or "sum".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
Returns:
a list of Mirrored objects.
@@ -272,7 +286,7 @@ def _ungroup_and_make_mirrored(grouped_reduced, destinations, method_string):
index = [{} for _ in range(len(grouped_reduced[0]))]
for d, per_device_reduced in enumerate(grouped_reduced):
for i, (v, _) in enumerate(per_device_reduced):
- if method_string == "mean":
+ if aggregation == vs.VariableAggregation.MEAN:
index[i][destinations[d]] = v / len(destinations)
else:
index[i][destinations[d]] = v
@@ -322,7 +336,17 @@ class ConcatAndSplitPacker(object):
# TODO(zhengxq): it is also possible to optimize away all the concat
# as well.
num_splits = self.num_packs
- total_grad_size = array_ops.size(concat_grads)
+
+ # The array_ops.size function will sometimes remove static shapes. So if
+ # all gradient shapes are defined, we use another method to get the
+ # total size.
+ # TODO(yuefengz): move this logic to array_ops.size.
+ if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]):
+ total_grad_size = sum(
+ [g.shape.num_elements() for g, _ in tower_grads_and_vars])
+ else:
+ total_grad_size = array_ops.size(concat_grads)
+
split_size = total_grad_size // num_splits
split_size_last = total_grad_size - split_size * (num_splits - 1)
split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
@@ -412,6 +436,31 @@ class AggregateSmallTensorPacker(object):
self.packing)
+def _pack_tensors(device_grads,
+ num_packs=0,
+ agg_small_grads_max_bytes=0,
+ agg_small_grads_max_group=0):
+ """Pack tensors if specified."""
+ if num_packs > 0:
+ tensor_packer = ConcatAndSplitPacker(num_packs)
+ device_grad_packs = tensor_packer.pack(device_grads)
+ elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0:
+ tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes,
+ agg_small_grads_max_group)
+ device_grad_packs = tensor_packer.pack(device_grads)
+ else:
+ tensor_packer = None
+ device_grad_packs = device_grads
+ return device_grad_packs, tensor_packer
+
+
+def _unpack_tensors(reduced, tensor_packer=None):
+ """Unpack tensors if they are packed before all-reduce."""
+ if tensor_packer:
+ return tensor_packer.unpack(reduced)
+ return reduced
+
+
class AllReduceCrossTowerOps(CrossTowerOps):
"""Reduction using all reduce."""
@@ -440,38 +489,38 @@ class AllReduceCrossTowerOps(CrossTowerOps):
agg_small_grads_max_group: see above.
tensors.
"""
- self.all_reduce_alg = all_reduce_alg
- self.num_packs = num_packs
- self.agg_small_grads_max_bytes = agg_small_grads_max_bytes
- self.agg_small_grads_max_group = agg_small_grads_max_group
+ self._all_reduce_alg = all_reduce_alg
+ self._num_packs = num_packs
+ self._agg_small_grads_max_bytes = agg_small_grads_max_bytes
+ self._agg_small_grads_max_group = agg_small_grads_max_group
super(AllReduceCrossTowerOps, self).__init__()
- def _reduce(self, method_string, per_device_value, destinations):
+ def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
if ((destinations is None or _devices_match(per_device_value, destinations))
and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string, [per_device_value])[0]
+ return self._batch_all_reduce(aggregation, [per_device_value])[0]
else:
if contains_indexed_slices:
logging.log_first_n(
logging.WARN,
"Efficient allreduce is not supported for IndexedSlices.", 10)
- devices = _get_devices_from(destinations or per_device_value)
+ devices = get_devices_from(destinations or per_device_value)
reduce_to_device = devices[0]
reduced = _simple_reduce(per_device_value, reduce_to_device,
- math_ops.add_n, method_string)
+ math_ops.add_n, aggregation)
return self.broadcast(reduced, devices)
- def _batch_reduce(self, method_string, value_destination_pairs):
+ def _batch_reduce(self, aggregation, value_destination_pairs):
all_devices_match = _all_devices_match(value_destination_pairs)
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
value_destination_pairs)
if (all_devices_match and not context.executing_eagerly()
and not contains_indexed_slices):
- return self._batch_all_reduce(method_string,
+ return self._batch_all_reduce(aggregation,
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
@@ -479,43 +528,30 @@ class AllReduceCrossTowerOps(CrossTowerOps):
"destinations are different.")
return [
- self._reduce(method_string, t, destinations=v)
+ self._reduce(aggregation, t, destinations=v)
for t, v in value_destination_pairs
]
- def _batch_all_reduce(self, method_string, per_device_values):
+ def _batch_all_reduce(self, aggregation, per_device_values):
"""All reduce algorithm in a batch."""
+ logging.info(
+ "batch_all_reduce invoked for batches size = %d with "
+ "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and "
+ "agg_small_grads_max_group = %d", len(per_device_values),
+ self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
destinations = per_device_values[0].devices
grouped = _group_value_by_device(per_device_values)
- if self.num_packs > 0:
- logging.info(
- "batch_all_reduce invoked for batches size = %d with "
- "algorithm = %s and num_packs = %d", len(per_device_values),
- self.all_reduce_alg, self.num_packs)
- tensor_packer = ConcatAndSplitPacker(self.num_packs)
- device_grad_packs = tensor_packer.pack(grouped)
- elif (self.agg_small_grads_max_bytes > 0 and
- self.agg_small_grads_max_group > 0):
- logging.info(
- "batch_all_reduce invoked for batches size = %d with "
- "algorithm = %s, agg_small_grads_max_bytes = %d and "
- "agg_small_grads_max_group = %d", len(per_device_values),
- self.all_reduce_alg, self.agg_small_grads_max_bytes,
- self.agg_small_grads_max_group)
- tensor_packer = AggregateSmallTensorPacker(
- self.agg_small_grads_max_bytes, self.agg_small_grads_max_group)
- device_grad_packs = tensor_packer.pack(grouped)
- else:
- logging.info(
- "batch_all_reduce invoked for batches size = %d with algorithm = %s",
- len(per_device_values), self.all_reduce_alg)
- tensor_packer = None
- device_grad_packs = grouped
+
+ device_grad_packs, tensor_packer = _pack_tensors(
+ grouped, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
# The actual aggregation of the repacked gradients. Note that they are
# sharded among different aggregation trees. So it is important to strike
# the balance on num_splits.
- if self.all_reduce_alg == "nccl":
+ if self._all_reduce_alg == "nccl":
+ # TODO(yuefengz): merge this into the all-reduce library.
reduced = cross_tower_utils.aggregate_gradients_using_nccl(
device_grad_packs)
else:
@@ -525,11 +561,135 @@ class AllReduceCrossTowerOps(CrossTowerOps):
cross_tower_utils.aggregate_gradients_using_hierarchical_copy(
destinations, device_grad_packs))
- if tensor_packer:
- reduced = tensor_packer.unpack(reduced)
-
+ reduced = _unpack_tensors(reduced, tensor_packer)
return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices,
- method_string)
+ aggregation)
+
+
+AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
+ "alg shards limit")
+
+
+class MultiWorkerAllReduce(AllReduceCrossTowerOps):
+ """All-reduce algorithms for distributed TensorFlow."""
+
+ def __init__(self,
+ worker_devices,
+ num_gpus_per_worker,
+ all_reduce_spec=("pscpu/pscpu", 2, -1),
+ num_packs=0,
+ agg_small_grads_max_bytes=0,
+ agg_small_grads_max_group=10):
+ """Initialize the all-reduce algorithm.
+
+ Args:
+ worker_devices: a list of device strings for workers participating in
+ all-reduce.
+ num_gpus_per_worker: number of GPU devices per worker.
+ all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
+ the all-reduce algorithm.
+ 1. The first element of a tuple is the name of the all-reduce algorithm.
+ Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
+ "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
+ a "/" are hierarchical, so two all-reduces are executed, the first one
+ aggregates tensors within a worker and the second aggregates across
+ workers.
+ 2. The second element of a tuple is the number of shards when doing
+ all-reduce. Let's say its values is M, each tensor after packing will be
+ split into M shards and then M parallel all-reduces would be performed
+ before finally they are concatenated backed into a complete tensor.
+ 3. The third element is the maximum size of tensors that will be
+ applicable for the algorithm specified by the first element. For
+ example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
+ tensors with size not larger than 1024 bytes will be applied a 2-shard
+ "nccl" all-reduce and other tensors will be applied a 2-shard
+ "pscpu/pscpu" algorithm. The third elements should be in increasing
+ order across tuples and end with -1 which indicates infinity.
+ num_packs: see AllReduceCrossTowerOps.
+ agg_small_grads_max_bytes: see AllReduceCrossTowerOps.
+ agg_small_grads_max_group: see AllReduceCrossTowerOps.
+ """
+ self._worker_devices = worker_devices
+ self._num_gpus_per_worker = num_gpus_per_worker
+ super(MultiWorkerAllReduce, self).__init__(
+ num_packs=num_packs,
+ agg_small_grads_max_bytes=agg_small_grads_max_bytes,
+ agg_small_grads_max_group=agg_small_grads_max_group)
+
+ def validate_and_complete_spec(spec):
+ """Validate and complete the all-reduce spec."""
+ # TODO(yuefengz): support namedtuple.
+ if not isinstance(spec, tuple):
+ raise ValueError(
+ "A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
+ if not spec or len(spec) > 3:
+ raise ValueError(
+ "Too many elements in the all-reduce spec tuple: %r" % spec)
+ if len(spec) == 1:
+ return AllReduceSpecTuple(spec[0], 1, -1)
+ elif len(spec) == 2:
+ return AllReduceSpecTuple(spec[0], spec[1], -1)
+ else:
+ return AllReduceSpecTuple(*spec)
+
+ self._all_reduce_spec = []
+ if isinstance(all_reduce_spec, six.string_types):
+ self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
+ elif isinstance(all_reduce_spec, tuple):
+ self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
+ elif isinstance(all_reduce_spec, list):
+ self._all_reduce_spec = [
+ validate_and_complete_spec(spec) for spec in all_reduce_spec
+ ]
+
+ def _batch_all_reduce(self, aggregation, per_device_values):
+ """All reduce algorithm in a batch."""
+ logging.info(
+ "distributed batch_all_reduce invoked for batches size = %d with "
+ "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d "
+ "and agg_small_grads_max_group = %d", len(per_device_values),
+ self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
+
+ destinations = sorted(per_device_values[0].devices)
+ device_grads = _group_value_by_device(per_device_values)
+
+ # The all reduce library requires fully defined shapes.
+ # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
+ # required as well.
+ for device_grad in device_grads:
+ for grad, _ in device_grad:
+ if not grad.shape.is_fully_defined():
+ raise ValueError("Shape is unknown for node %r" % grad)
+
+ remaining_grads = device_grads
+ aggregated_grads = []
+ for spec_tuple in self._all_reduce_spec:
+ if spec_tuple.limit < 0:
+ this_grads = remaining_grads
+ remaining_grads = []
+ else:
+ (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size(
+ spec_tuple.limit, remaining_grads)
+ if this_grads:
+ device_grad_packs, tensor_packer = _pack_tensors(
+ this_grads, self._num_packs, self._agg_small_grads_max_bytes,
+ self._agg_small_grads_max_group)
+ range_agg_grads = cross_tower_utils.sum_gradients_all_reduce(
+ self._worker_devices, device_grad_packs, len(self._worker_devices),
+ spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
+ range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
+
+ if not aggregated_grads:
+ aggregated_grads = range_agg_grads
+ else:
+ assert len(aggregated_grads) == len(range_agg_grads)
+ for i in range(len(aggregated_grads)):
+ aggregated_grads[i] += range_agg_grads[i]
+ assert not remaining_grads
+
+ return _ungroup_and_make_mirrored(aggregated_grads, destinations,
+ aggregation)
_dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index b3bc0bac59..6a780ff60f 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -31,11 +32,12 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
def _make_per_device(values, devices):
- devices = cross_tower_ops_lib._get_devices_from(devices)
+ devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
index = {}
for d, v in zip(devices, values):
@@ -52,7 +54,7 @@ def _fake_mirrored(value, devices):
All components of the returned Mirrored have the same objects, which is not
true in reality.
"""
- devices = cross_tower_ops_lib._get_devices_from(devices)
+ devices = cross_tower_ops_lib.get_devices_from(devices)
return value_lib.Mirrored(
{d: v for d, v in zip(devices, [value] * len(devices))})
@@ -75,7 +77,7 @@ def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
_cpu_device = "/device:CPU:0"
-class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
+class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
def _assert_indexed_slices_equal(self, left, right):
self.assertIsInstance(left, ops.IndexedSlices)
@@ -92,7 +94,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self._assert_values_equal(l, r)
else:
self.assertEqual(type(left), type(right))
- self.assertEqual(left.devices, right.devices)
+ self.assertEqual(set(left.devices), set(right.devices))
if isinstance(list(left._index.values())[0], ops.IndexedSlices):
for (d, v) in left._index.items():
self._assert_indexed_slices_equal(v, right._index[d])
@@ -104,51 +106,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(
sess.run(list(left._index.values())), list(right._index.values()))
- # TODO(yuefengz): decouple the num_gpus check from distribution in
- # combinations module so that we can pass in devices instead of a distribution
- # strategy.
- reduction_to_one_combinations = combinations.combine(
- cross_tower_ops=[
- combinations.NamedObject(
- "DefaultReductionToOneDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
- combinations.NamedObject(
- "ReductionToCPUDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
- reduce_to_device=_cpu_device)),
- combinations.NamedObject(
- "AccumulateNCrossTowerOp",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
- accumulation_fn=math_ops.accumulate_n)),
- ],
- distribution=[
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus
- ],
- mode=["graph", "eager"])
- allreduce_combinations = combinations.combine(
- cross_tower_ops=[
- combinations.NamedObject(
- "AllReduce",
- cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
- combinations.NamedObject(
- "HierarchicalCopy",
- cross_tower_ops_lib.AllReduceCrossTowerOps(
- "hierarchical_copy", 8, 0, 0)),
- combinations.NamedObject(
- "AllReduceNoGradientRepacking",
- cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
- combinations.NamedObject(
- "HierarchicalCopyAggregateSmallTensors",
- cross_tower_ops_lib.AllReduceCrossTowerOps(
- "hierarchical_copy", 0, 100, 10))
- ],
- distribution=[combinations.mirrored_strategy_with_two_gpus],
- mode=["graph", "eager"])
-
- @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
- def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ def _testReductionAndBroadcast(self, cross_tower_ops, distribution):
devices = distribution.worker_devices
values = [constant_op.constant(float(d)) for d in range(len(devices))]
@@ -172,32 +130,45 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
# test reduce()
for destinations in all_destinations:
self._assert_values_equal(
- cross_tower_ops.reduce("mean", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.MEAN,
+ per_device,
+ destinations=destinations),
_fake_mirrored(mean, destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "mean", per_device_2, destinations=destinations),
+ vs.VariableAggregation.MEAN,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2, destinations or per_device))
self._assert_values_equal(
- cross_tower_ops.reduce("sum", per_device, destinations=destinations),
+ cross_tower_ops.reduce(
+ vs.VariableAggregation.SUM, per_device,
+ destinations=destinations),
_fake_mirrored(mean * len(devices), destinations or per_device))
self._assert_values_equal(
cross_tower_ops.reduce(
- "sum", per_device_2, destinations=destinations),
+ vs.VariableAggregation.SUM,
+ per_device_2,
+ destinations=destinations),
_fake_mirrored(mean_2 * len(devices), destinations or per_device))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "mean", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean, d1 or per_device),
+ _fake_mirrored(mean_2, d2 or per_device_2)
+ ])
self._assert_values_equal(
- cross_tower_ops.batch_reduce(
- "sum", [(per_device, d1), (per_device_2, d2)]),
- [_fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)])
+ cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
+ [(per_device, d1), (per_device_2, d2)]),
+ [
+ _fake_mirrored(mean * len(devices), d1 or per_device),
+ _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ ])
# test broadcast()
for destinations in all_destinations:
@@ -208,20 +179,70 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
_fake_mirrored(1., destinations))
+
+class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
+ # TODO(yuefengz): decouple the num_gpus check from distribution in
+ # combinations module so that we can pass in devices instead of a distribution
+ # strategy.
+ reduction_to_one_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "DefaultReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "ReductionToCPUDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+ reduce_to_device=_cpu_device)),
+ combinations.NamedObject(
+ "AccumulateNCrossTowerOp",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
+ accumulation_fn=math_ops.accumulate_n)),
+ ],
+ distribution=[
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
+ ],
+ mode=["graph", "eager"])
+ allreduce_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "AllReduce",
+ cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)),
+ combinations.NamedObject(
+ "HierarchicalCopy",
+ cross_tower_ops_lib.AllReduceCrossTowerOps(
+ "hierarchical_copy", 8, 0, 0)),
+ combinations.NamedObject(
+ "AllReduceNoGradientRepacking",
+ cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)),
+ combinations.NamedObject(
+ "HierarchicalCopyAggregateSmallTensors",
+ cross_tower_ops_lib.AllReduceCrossTowerOps(
+ "hierarchical_copy", 0, 100, 10))
+ ],
+ distribution=[combinations.mirrored_strategy_with_two_gpus],
+ mode=["graph", "eager"])
+
+ @combinations.generate(reduction_to_one_combinations + allreduce_combinations)
+ def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ with distribution.scope():
+ self._testReductionAndBroadcast(cross_tower_ops, distribution)
+
def testChooseAlgorithm(self):
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7],
[0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
- self.assertEqual(result.num_packs, 8)
+ self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
+ self.assertEqual(result._num_packs, 8)
# if there are only 4 devices
device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "nccl")
- self.assertEqual(result.num_packs, 1)
+ self.assertEqual(result._all_reduce_alg, "nccl")
+ self.assertEqual(result._num_packs, 1)
# if devices links contain each device itself
device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6],
@@ -229,16 +250,16 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
[2, 4, 5, 6, 7], [3, 4, 5, 6, 7]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "hierarchical_copy")
- self.assertEqual(result.num_packs, 8)
+ self.assertEqual(result._all_reduce_alg, "hierarchical_copy")
+ self.assertEqual(result._num_packs, 8)
# if not dgx1-like links
device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7],
[1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]]
result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links)
self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)
- self.assertEqual(result.all_reduce_alg, "nccl")
- self.assertEqual(result.num_packs, 1)
+ self.assertEqual(result._all_reduce_alg, "nccl")
+ self.assertEqual(result._num_packs, 1)
@combinations.generate(combinations.combine(
mode=["graph", "eager"],
@@ -248,8 +269,8 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0])
t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1])
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
- result = cross_tower_ops_lib._simple_reduce(per_device, devices[0],
- math_ops.add_n, "sum")
+ result = cross_tower_ops_lib._simple_reduce(
+ per_device, devices[0], math_ops.add_n, vs.VariableAggregation.SUM)
# Test that the result is semantically equal to both the concatenated
# IndexedSlices with and without duplicate indices.
@@ -260,21 +281,22 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self._assert_indexed_slices_equal(total_with_dups, result)
self._assert_indexed_slices_equal(total_without_dups, result)
- @combinations.generate(combinations.combine(
- cross_tower_ops_instance=[
- combinations.NamedObject(
- "ReductionToOneDeviceCrossTowerOps",
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
- combinations.NamedObject(
- "AllReduceCrossTowerOps",
- cross_tower_ops_lib.AllReduceCrossTowerOps())
- ],
- method_string=["sum", "mean"],
- batch_reduce=[True, False],
- mode=["graph", "eager"],
- required_gpus=1))
- def testIndexedSlicesAllReduce(self, cross_tower_ops_instance,
- method_string, batch_reduce):
+ @combinations.generate(
+ combinations.combine(
+ cross_tower_ops_instance=[
+ combinations.NamedObject(
+ "ReductionToOneDeviceCrossTowerOps",
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()),
+ combinations.NamedObject(
+ "AllReduceCrossTowerOps",
+ cross_tower_ops_lib.AllReduceCrossTowerOps())
+ ],
+ aggregation=[vs.VariableAggregation.SUM, vs.VariableAggregation.MEAN],
+ batch_reduce=[True, False],
+ mode=["graph", "eager"],
+ required_gpus=1))
+ def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, aggregation,
+ batch_reduce):
devices = ["/cpu:0", "/gpu:0"]
dense_shape = [5, 2]
t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0])
@@ -283,20 +305,19 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1})
if batch_reduce:
- result = cross_tower_ops_instance.batch_reduce(method_string,
+ result = cross_tower_ops_instance.batch_reduce(aggregation,
[(per_device, devices)])
else:
- result = cross_tower_ops_instance.reduce(method_string, per_device,
- devices)
+ result = cross_tower_ops_instance.reduce(aggregation, per_device, devices)
total_indices_with_dups = [1, 1, 3]
total_indices_without_dups = [1, 3]
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]]
total_values_without_dups = [[4., 6.], [5., 6.]]
else:
- assert method_string == "mean"
+ assert aggregation == vs.VariableAggregation.MEAN
total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]]
total_values_without_dups = [[2., 3.], [2.5, 3.]]
@@ -316,5 +337,44 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase):
self._assert_values_equal(total_mirrored_without_dups, result)
+class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
+ CrossTowerOpsTestBase):
+
+ worker_devices = [
+ "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
+ ]
+ multi_worker_allreduce_combinations = combinations.combine(
+ cross_tower_ops=[
+ combinations.NamedObject(
+ "MultiWorkerAllReduce",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)),
+ combinations.NamedObject(
+ "MultiWorkerAllReducePack",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)),
+ combinations.NamedObject(
+ "MultiWorkerAllReduceAggregation",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)),
+ combinations.NamedObject(
+ "MultiWorkerAllReduceMultipleSpecs",
+ cross_tower_ops_lib.MultiWorkerAllReduce(
+ worker_devices, 2, [("pscpu/pscpu", 2, 100),
+ ("xring", 2, -1)], 0, 0, 0)),
+ ],
+ distribution=[
+ combinations.multi_worker_strategy_with_cpu,
+ combinations.multi_worker_strategy_with_one_gpu,
+ combinations.multi_worker_strategy_with_two_gpus
+ ],
+ mode=["graph"])
+
+ @combinations.generate(multi_worker_allreduce_combinations)
+ def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ with distribution.scope():
+ self._testReductionAndBroadcast(cross_tower_ops, distribution)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 137fabf4c7..2bb088e704 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections as pycoll
from tensorflow.contrib import nccl
+from tensorflow.contrib.all_reduce.python import all_reduce
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -158,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
return (grad, v), None
+def group_device_names(devices, group_size):
+ """Group device names into groups of group_size.
+
+ Args:
+ devices: a list of canonical device strings.
+ group_size: integer which is equal to or greater than 1.
+
+ Returns:
+ list of lists of devices, where each inner list is group_size long,
+ and each device appears at least once in an inner list. If
+ len(devices) % group_size == 0 then each device will appear exactly once.
+
+ Raises:
+ ValueError: if group_size > len(devices)
+ """
+ num_devices = len(devices)
+ if group_size > num_devices:
+ raise ValueError(
+ 'only %d devices, but group_size=%d' % (num_devices, group_size))
+ num_groups = (
+ num_devices // group_size + (1 if (num_devices % group_size != 0) else 0))
+ groups = [[] for i in range(num_groups)]
+ for i in range(num_groups * group_size):
+ groups[i % num_groups].append(devices[i % num_devices])
+ return groups
+
+
+def split_grads_by_size(threshold_size, device_grads):
+ """Break gradients into two sets according to tensor size.
+
+ Args:
+ threshold_size: int size cutoff for small vs large tensor.
+ device_grads: List of lists of (gradient, variable) tuples. The outer
+ list is over devices. The inner list is over individual gradients.
+
+ Returns:
+ small_grads: Subset of device_grads where shape is <= threshold_size
+ elements.
+ large_grads: Subset of device_grads where shape is > threshold_size
+ elements.
+ """
+ small_grads = []
+ large_grads = []
+ for dl in device_grads:
+ small_dl = []
+ large_dl = []
+ for (g, v) in dl:
+ tensor_size = g.get_shape().num_elements()
+ if tensor_size <= threshold_size:
+ small_dl.append([g, v])
+ else:
+ large_dl.append([g, v])
+ if small_dl:
+ small_grads.append(small_dl)
+ if large_dl:
+ large_grads.append(large_dl)
+ return small_grads, large_grads
+
+
+def sum_grad_and_var_all_reduce(grad_and_vars,
+ num_workers,
+ alg,
+ gpu_indices,
+ aux_devices=None,
+ num_shards=1):
+ """Apply all-reduce algorithm over specified gradient tensors."""
+ with ops.name_scope('allreduce'):
+ # Note that each grad_and_vars looks like the following:
+ # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+ scaled_grads = [g for g, _ in grad_and_vars]
+ if alg == 'nccl':
+ summed_grads = nccl.all_sum(scaled_grads)
+ elif alg == 'xring':
+ summed_grads = all_reduce.build_ring_all_reduce(
+ scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add)
+ elif alg == 'nccl/xring':
+ summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards,
+ math_ops.add)
+ elif alg == 'nccl/rechd':
+ summed_grads = all_reduce.build_nccl_then_recursive_hd(
+ scaled_grads, math_ops.add)
+ elif alg == 'nccl/pscpu':
+ summed_grads = all_reduce.build_nccl_then_shuffle(
+ scaled_grads, aux_devices, math_ops.add, math_ops.add_n)
+ elif alg == 'pscpu/pscpu':
+ second_gather_devices = aux_devices[:num_shards]
+ summed_grads = all_reduce.build_shuffle_then_shuffle(
+ scaled_grads, aux_devices, second_gather_devices, math_ops.add_n)
+ elif alg in ['pscpu', 'psgpu']:
+ summed_grads = all_reduce.build_shuffle_all_reduce(
+ scaled_grads, aux_devices, math_ops.add_n)
+ else:
+ raise ValueError('unsupported all_reduce alg: ', alg)
+
+ result = []
+ for (_, v), g in zip(grad_and_vars, summed_grads):
+ result.append([g, v])
+ return result
+
+
+def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg,
+ num_shards, gpu_indices):
+ """Apply all-reduce algorithm over specified gradient tensors.
+
+ Args:
+ dev_prefixes: list of prefix strings to use to generate PS device names.
+ tower_grads: the gradients to reduce.
+ num_workers: number of worker processes across entire job.
+ alg: the all-reduce algorithm to apply.
+ num_shards: alg-specific sharding factor.
+ gpu_indices: indices of local GPUs in order usable for ring-reduce.
+
+ Returns:
+ list of reduced tensors
+ """
+ alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']])
+ is_hierarchical = '/' in alg
+ if 'pscpu' in alg:
+ aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
+ elif 'psgpu' in alg:
+ aux_devices = [
+ prefix + '/gpu:%d' % i
+ for i in range(len(gpu_indices))
+ for prefix in dev_prefixes
+ ]
+ else:
+ aux_devices = ['/job:localhost/cpu:0']
+ # Auxiliary devices for hierarchical all-reduces.
+ aux_device_groups = group_device_names(
+ aux_devices, num_shards if alg_contains_shuffle else 1)
+ group_index = 0
+ reduced_gv_list = []
+ for grad_and_vars in zip(*tower_grads):
+ reduced_gv_list.append(
+ sum_grad_and_var_all_reduce(
+ grad_and_vars, num_workers, alg, gpu_indices, aux_devices
+ if is_hierarchical else aux_device_groups[group_index], num_shards))
+ group_index = (group_index + 1) % len(aux_device_groups)
+ new_tower_grads = [list(x) for x in zip(*reduced_gv_list)]
+ return new_tower_grads
+
+
def extract_ranges(index_list, range_size_limit=32):
"""Extract consecutive ranges and singles from index_list.
@@ -330,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing):
for dev_idx, gv_list in enumerate(tower_grads):
gv_list = list(gv_list)
new_gv_list = gv_list[num_packed:]
- for i in xrange(0, num_packed):
+ for i in range(num_packed):
k = '%d:%d' % (dev_idx, i)
gpt = packing[k]
gv = unpack_grad_tuple(gv_list[i], gpt)
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
index 4ef8db6815..d25964fa41 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py
@@ -38,7 +38,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
self.evaluate(ops.convert_to_tensor(left)),
self.evaluate(ops.convert_to_tensor(right)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAggregateTensors(self):
t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]])
@@ -46,7 +46,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1])
self._assert_values_equal(total, result)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAggregateIndexedSlices(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
@@ -57,7 +57,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(result, ops.IndexedSlices)
self._assert_values_equal(total, result)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDivideTensor(self):
t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]])
n = 2
@@ -65,7 +65,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n)
self._assert_values_equal(expected, result)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDivideIndexedSlices(self):
t = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
@@ -75,13 +75,13 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(result, ops.IndexedSlices)
self._assert_values_equal(expected, result)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testIsIndexedSlices(self):
t = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices(t))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_List(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
@@ -89,7 +89,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_Tuple(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
@@ -97,7 +97,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
constant_op.constant([[0., 0.], [5, 6], [7., 8.]]))
self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_PerDevice(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
@@ -106,7 +106,7 @@ class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1})
self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testContainsIndexedSlices_PerDeviceMapOutput(self):
t0 = math_ops._as_indexed_slices(
constant_op.constant([[1., 2.], [0, 0], [3., 4.]]))
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
new file mode 100644
index 0000000000..6c6bf14309
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -0,0 +1,438 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for V1 metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import variables
+
+
+def _labeled_dataset_fn():
+ # First four batches of x: labels, predictions -> (labels == predictions)
+ # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False
+ # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False
+ # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
+ # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
+ return dataset_ops.Dataset.range(1000).map(
+ lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+
+
+def _boolean_dataset_fn():
+ # First four batches of labels, predictions: {TP, FP, TN, FN}
+ # with a threshold of 0.5:
+ # T, T -> TP; F, T -> FP; T, F -> FN
+ # F, F -> TN; T, T -> TP; F, T -> FP
+ # T, F -> FN; F, F -> TN; T, T -> TP
+ # F, T -> FP; T, F -> FN; F, F -> TN
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [True, False, True, False],
+ "predictions": [True, True, False, False]}).repeat().batch(3)
+
+
+def _threshold_dataset_fn():
+ # First four batches of labels, predictions: {TP, FP, TN, FN}
+ # with a threshold of 0.5:
+ # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN
+ # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP
+ # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP
+ # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [True, False, True, False],
+ "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+
+
+def _regression_dataset_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [1., .5, 1., 0.],
+ "predictions": [1., .75, .25, 0.]}).repeat()
+
+
+def all_combinations():
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus],
+ mode=["graph"])
+
+
+# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
+# metrics.precision_at_k
+class MetricsV1Test(test.TestCase, parameterized.TestCase):
+
+ def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
+ with ops.Graph().as_default(), distribution.scope():
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
+ value, update = distribution.call_for_each_tower(
+ metric_fn, iterator.get_next())
+ update = distribution.group(update)
+ self.evaluate(variables.local_variables_initializer())
+ # TODO(josh11b): Once we switch to using a global batch size for input,
+ # replace "distribution.num_towers" with "1".
+ batches_per_update = distribution.num_towers
+
+ # Update variables using the first `num_towers` batches.
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
+ 0.001, msg="After first update")
+
+ # Update variables using the second `num_towers` batches.
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(2 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After second update")
+
+ if batches_per_update == 1: # Consume 4 input batches
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(3 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After third update")
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(4 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After fourth update")
+
+ @combinations.generate(all_combinations())
+ def testMean(self, distribution):
+ def _dataset_fn():
+ return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+
+ def _expected_fn(num_batches):
+ # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
+ return num_batches * 2 - 0.5
+
+ self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAccuracy(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.accuracy(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [3./4, 3./8, 3./12, 4./16][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanPerClassAccuracy(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_per_class_accuracy(
+ labels, predictions, num_classes=5)
+
+ def _expected_fn(num_batches):
+ mean = lambda x: sum(x) / len(x)
+ return [mean([1., 1., 1., 0., 0.]),
+ mean([0.5, 0.5, 0.5, 0., 0.]),
+ mean([1./3, 1./3, 0.5, 0., 0.]),
+ mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanIOU(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_iou(
+ labels, predictions, num_classes=5)
+
+ def _expected_fn(num_batches):
+ mean = lambda x: sum(x) / len(x)
+ return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch
+ mean([1./4, 1./4, 1./3, 0., 0.]),
+ mean([1./6, 1./6, 1./5, 0., 0.]),
+ mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanTensor(self, distribution):
+ def _dataset_fn():
+ dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
+ # Want to produce a fixed, known shape, so drop remainder when batching.
+ dataset = dataset.apply(batching.batch_and_drop_remainder(4))
+ return dataset
+
+ def _expected_fn(num_batches):
+ # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
+ # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
+ # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
+ # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
+ first = 2. * num_batches - 2.
+ return [first, first + 1., first + 2., first + 3.]
+
+ self._test_metric(
+ distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAUCROC(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
+ summation_method="careful_interpolation")
+
+ def _expected_fn(num_batches):
+ return [0.5, 7./9, 0.8, 0.75][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAUCPR(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
+ summation_method="careful_interpolation")
+
+ def _expected_fn(num_batches):
+ return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalseNegatives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_negatives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 1., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalseNegativesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_negatives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [1.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTrueNegatives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_negatives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 1., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTrueNegativesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_negatives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[0.], [1.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalsePositives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_positives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 2., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalsePositivesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_positives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [2.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTruePositives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_positives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 2., 3., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTruePositivesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_positives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [2.], [3.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testPrecision(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.precision(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0.5, 0.5, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testPrecisionAtThreshold(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.precision_at_thresholds(labels, predictions, [0.5])
+
+ def _expected_fn(num_batches):
+ return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRecall(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.recall(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRecallAtThreshold(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.recall_at_thresholds(labels, predictions, [0.5])
+
+ def _expected_fn(num_batches):
+ return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanSquaredError(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_squared_error(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 1./32, 0.208333, 0.15625][num_batches - 1]
+
+ self._test_metric(
+ distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRootMeanSquaredError(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.root_mean_squared_error(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
+
+ self._test_metric(
+ distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testSensitivityAtSpecificity(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
+
+ def _expected_fn(num_batches):
+ return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testSpecificityAtSensitivity(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
+
+ def _expected_fn(num_batches):
+ return [0., 1./3, 0.5, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 5c056a7c73..aeeb9553e6 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -56,6 +56,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
is_tpu=[True]))
def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss,
is_tpu):
+ # TODO(priyag): Remove this once the step TPU Strategy is stable.
+ if is_tpu:
+ self.skipTest("TPU tests are WIP.")
+
with distribution.scope():
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
@@ -84,8 +88,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- weights.append(self.evaluate(distribution.fetch(layer.kernel)))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ weights.append(self.evaluate(layer.kernel))
+ biases.append(self.evaluate(layer.bias))
if is_tpu:
with self.test_session() as sess:
@@ -111,6 +115,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
is_tpu=[True]))
def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu):
+ # TODO(priyag): Remove this once the step TPU Strategy is stable.
+ if is_tpu:
+ self.skipTest("TPU tests are WIP.")
+
created_variables = []
trainable_variables = []
@@ -186,7 +194,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
# towers will re-execute UPDATE_OPS of previous towers.
update_ops_in_cross_tower_mode=[True])) +
combinations.combine(
- distribution=[combinations.tpu_strategy_single_iteration],
+ distribution=[combinations.tpu_strategy],
optimizer_fn=[
combinations.gradient_descent_optimizer_v1_fn,
combinations.gradient_descent_optimizer_v2_fn
@@ -198,6 +206,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
renorm, is_tpu,
update_ops_in_cross_tower_mode):
"""Verifies that moving mean updates are reduced across towers."""
+ # TODO(priyag): Remove this once the step TPU Strategy is stable.
+ if is_tpu:
+ self.skipTest("TPU tests are WIP.")
+
with distribution.scope():
num_towers = len(distribution.worker_devices)
model_fn, dataset_fn, batchnorm = batchnorm_example(
@@ -242,7 +254,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- moving_means = self.evaluate(distribution.fetch(batchnorm.moving_mean))
+ moving_means = self.evaluate(batchnorm.moving_mean)
# We make sure that the moving_mean is updated as if the sample mean is
# calculated over all towers.
@@ -279,12 +291,16 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
mode=["graph"], use_callable_loss=[True, False]) +
combinations.combine(mode=["eager"], use_callable_loss=[True])) +
combinations.combine(
- distribution=[combinations.tpu_strategy_single_iteration],
+ distribution=[combinations.tpu_strategy],
is_tpu=[True],
mode=["graph"],
use_callable_loss=[True, False])))
def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction,
use_callable_loss, is_tpu):
+ # TODO(priyag): Remove this once the step TPU Strategy is stable.
+ if is_tpu:
+ self.skipTest("TPU tests are WIP.")
+
with distribution.scope():
all_vars = []
@@ -329,7 +345,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
v = all_vars[0]
self.assertTrue(all([v is vi for vi in all_vars[1:]]))
- weight = numpy.squeeze(self.evaluate(distribution.fetch(v)))
+ weight = numpy.squeeze(self.evaluate(v))
# Our model is:
# predict = x * w
# loss = (predict - y)^2
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 14dbbd6e27..dcbc6b0878 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -84,9 +84,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert len(set(devices)) == len(devices), (
"No duplicates allowed in `devices` argument.")
# TODO(josh11b): Require at least 2 devices?
- self._devices = devices
- self._canonical_device_set = set(
- [device_util.canonicalize(d) for d in devices])
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
dict((d, i) for i, d in enumerate(devices)))
self._cross_tower_ops = cross_tower_ops
@@ -105,9 +104,39 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
colocate_with = kwargs.pop("colocate_with", None)
devices = self._get_devices_from(colocate_with)
- tower_local = kwargs.pop("tower_local_reduce_method", None)
- if tower_local is not None:
+ # Get synchronization value
+ synchronization = kwargs.get(
+ "synchronization", variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are tower local.
+ is_tower_local = True
kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_tower_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in [
+ variable_scope.VariableAggregation.NONE,
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]:
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
@@ -119,7 +148,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if i > 0:
# Give replicas meaningful distinct names:
var0name = index[devices[0]].name.split(":")[0]
- kwargs["name"] = "%s/replica_%d" % (var0name, i)
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
# Initialize replicas with the same value:
if context.executing_eagerly():
kwargs["initial_value"] = array_ops.identity(
@@ -134,11 +166,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
assert not isinstance(v, values.DistributedVariable)
index[d] = v
- if tower_local is None:
- result = values.MirroredVariable(index, index[devices[0]])
+ if is_tower_local:
+ result = values.TowerLocalVariable(index, index[devices[0]],
+ aggregation)
else:
- result = values.TowerLocalVariable(
- index, index[devices[0]], tower_local)
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
if not context.executing_eagerly():
g = ops.get_default_graph()
@@ -259,8 +291,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
{t.device: t.merge_args for t in threads})
merge_kwargs = values.regroup(
{t.device: t.merge_kwargs for t in threads})
- merge_result = threads[0].merge_fn(
- self, *merge_args, **merge_kwargs)
+ # We capture the name_scope of the MTT when we call merge_fn
+ # to ensure that if we have opened a name scope in the MTT,
+ # it will be respected when executing the merge function. We only
+ # capture the name_scope from the first MTT and assume it is
+ # the same for all other MTTs.
+ mtt_captured_name_scope = threads[0].captured_name_scope
+ with ops.name_scope(mtt_captured_name_scope):
+ merge_result = threads[0].merge_fn(
+ self, *merge_args, **merge_kwargs)
for t in threads:
t.merge_result = values.select_device(t.device, merge_result)
finally:
@@ -273,8 +312,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def map(self, map_over, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
index = {}
- i = 0
- for m in map_over:
+ for i, m in enumerate(map_over):
d = self._devices[i % len(self._devices)]
with ops.device(d):
l = index.get(d, [])
@@ -297,27 +335,46 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
return self._cross_tower_ops
- def _reduce(self, method_string, value, destinations):
- if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
- value = values.PerDevice({self._devices[0]: value})
- assert isinstance(value, values.PerDevice)
+ def _reduce(self, aggregation, value, destinations):
+ assert not isinstance(value, values.Mirrored)
+ if not isinstance(value, values.PerDevice):
+ if value == 0:
+ return 0
+ if aggregation == variable_scope.VariableAggregation.MEAN:
+ return self._broadcast(value, destinations)
+
+ cross_tower_ops_lib.validate_destinations(destinations)
+ if len(self._devices) == 1:
+ if destinations:
+ # TODO(anjalisridhar): Moves these methods to a device utility file?
+ devices = cross_tower_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ with ops.device(devices[0]):
+ return array_ops.identity(value)
+ else:
+ value_updates = {}
+ for d in devices:
+ with ops.device(d):
+ value_updates[d] = array_ops.identity(value)
+ return values.Mirrored(value_updates)
+ raise ValueError("A non PerDevice value cannot be reduced with the given "
+ "aggregation.")
return self._get_cross_tower_ops().reduce(
- method_string, value, destinations=destinations)
+ aggregation, value, destinations=destinations)
- def _batch_reduce(self, method_string, value_destination_pairs):
- return self._get_cross_tower_ops().batch_reduce(method_string,
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
- # kwargs don't need to be mirrored.
- assert isinstance(var, values.MirroredVariable)
# TODO(josh11b): In eager mode, use one thread per device.
+ assert isinstance(var, values.DistributedVariable)
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
@@ -334,32 +391,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**values.select_device_mirrored(d, kwargs))
return values.regroup(updates, values.Mirrored)
- def _fetch(self, val, destination, fn):
- """Return a copy of `val` or `fn(val)` on `destination`."""
- if isinstance(val, values.TowerLocalVariable):
- val = self.reduce(val.reduce_method, val, destinations=destination)
- with ops.device(destination):
- return fn(self.unwrap(val)[0])
-
- assert isinstance(val, values.Mirrored), (
- "val = %s (type %s)" % (val, val.__class__.__name__))
- if val.on_device(destination):
- with ops.device(destination):
- # Use an identity here to make sure we are returning a tensor
- # instead of e.g. a variable object.
- return array_ops.identity(fn(val.get(destination)))
- device = None
- for d in self._devices:
- if val.on_device(d):
- device = d
- break
- assert device is not None, (
- "Could not find destination %s in list of devices %s." %
- (destination, val.devices))
- with ops.device(device):
- v = fn(val.get(device))
- with ops.device(destination):
- return array_ops.identity(v)
+ def read_var(self, tower_local_var):
+ """Read the aggregate value of a tower-local variable."""
+ if isinstance(tower_local_var, values.TowerLocalVariable):
+ return tower_local_var._get_cross_tower() # pylint: disable=protected-access
+ assert isinstance(tower_local_var, values.Mirrored)
+ return array_ops.identity(tower_local_var.get())
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
@@ -400,7 +437,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return list(colocate_with._index.keys())
elif isinstance(colocate_with, six.string_types):
- return [colocate_with]
+ return [device_util.resolve(colocate_with)]
+ elif isinstance(colocate_with, list):
+ return [device_util.resolve(d) for d in colocate_with]
else:
return colocate_with
@@ -427,6 +466,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self.merge_args = None
self.merge_kwargs = None
self.merge_result = None
+ self.captured_name_scope = None
# We use a thread.Event for the main thread to signal when this
# thread should start running (`should_run`), and another for
# this thread to transfer control back to the main thread
@@ -450,13 +490,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._variable_creator_stack = self.graph._variable_creator_stack[:]
self._captured_var_scope = variable_scope.get_variable_scope()
# Adding a "/" at end lets us re-enter this scope later.
- self._captured_name_scope = self.graph.get_name_scope()
- if self._captured_name_scope:
- self._captured_name_scope += "/"
+ self._name_scope = self.graph.get_name_scope()
+ if self._name_scope:
+ self._name_scope += "/"
if self.tower_id > 0:
- if not self._captured_name_scope:
- self._captured_name_scope = ""
- self._captured_name_scope += "tower_%d/" % self.tower_id
+ if not self._name_scope:
+ self._name_scope = ""
+ self._name_scope += "tower_%d/" % self.tower_id
def run(self):
# pylint: disable=protected-access
@@ -472,7 +512,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
_enter_graph(self.graph), \
MirroredTowerContext(self.distribution, self.tower_id), \
ops.device(self.device), \
- ops.name_scope(self._captured_name_scope), \
+ ops.name_scope(self._name_scope), \
variable_scope.variable_scope(
self._captured_var_scope, reuse=self.tower_id > 0), \
variable_scope.variable_creator_scope(self.variable_creator_fn):
@@ -498,6 +538,10 @@ class MirroredTowerContext(distribute_lib.TowerContext):
t.merge_fn = fn
t.merge_args = args
t.merge_kwargs = kwargs
+ t.captured_name_scope = t.graph.get_name_scope()
+ # Adding a "/" at end lets us re-enter this scope later.
+ if t.captured_name_scope:
+ t.captured_name_scope += "/"
t.has_paused.set()
t.should_run.wait()
t.should_run.clear()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 3f9a02b249..6a14b833d2 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -32,12 +32,14 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.layers import core
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
+
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -83,13 +85,13 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
self.skipTest("Not GPU test")
self.assertEqual(2, self._get_distribution_strategy().num_towers)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCallAndMergeExceptions(self):
if not GPU_TEST:
self.skipTest("Not GPU test")
self._test_call_and_merge_exceptions(self._get_distribution_strategy())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRunRegroupError(self):
def run_fn(device_id):
@@ -101,7 +103,7 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
with dist.scope(), self.assertRaises(AssertionError):
dist.call_for_each_tower(run_fn, dist.worker_device_index)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testReduceToCpu(self):
if not GPU_TEST:
self.skipTest("Not GPU test")
@@ -112,12 +114,35 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
dist = self._get_distribution_strategy()
with dist.scope():
result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
- reduced = dist.reduce("sum", result, destinations="/device:CPU:0")
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ result,
+ destinations="/device:CPU:0")
unwrapped = dist.unwrap(reduced)
self.assertEqual(1, len(unwrapped))
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes()
+ def testReduceToMultipleDestinations(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ devices = ["/device:GPU:0"]
+ if GPU_TEST:
+ self.assertGreater(context.num_gpus(), 0)
+ print(self.id().split(".")[-1], "devices:", ", ".join(devices))
+
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.SUM,
+ 1.0,
+ destinations=["/device:CPU:0", "/device:GPU:0"])
+ unwrapped = dist.unwrap(reduced)
+ self.assertEqual(2, len(unwrapped))
+ self.assertEqual(1.0, self.evaluate(unwrapped[0]))
+
class MirroredStrategyVariableCreationTest(test.TestCase):
@@ -264,18 +289,68 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("common/dense" + suffix + "/bias:0", bias.name)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testWithVariableAndVariableScope(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def model_fn():
+ v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
+ with variable_scope.variable_scope("common"):
+ v1 = variable_scope.variable(1.0, name="var1")
+ # This will pause the current thread, and execute the other thread.
+ distribute_lib.get_tower_context().merge_call(lambda _: _)
+ v2 = variable_scope.variable(
+ 1.0,
+ name="var2",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.variable(
+ 1.0,
+ name="var3",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+
+ return v0, v1, v2, v3
+
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v = variable_scope.variable(1.0, name="var-main0")
+ self.assertEquals("var-main0:0", v.name)
+
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
+ self.assertIsInstance(v0, values.MirroredVariable)
+ self.assertEquals("var0:0", v0.name)
+ self.assertIsInstance(v1, values.MirroredVariable)
+ self.assertEquals("common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM, v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN, v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testWithGetVariableAndVariableScope(self):
self._skip_eager_if_gpus_less_than(1)
def model_fn():
- v0 = variable_scope.get_variable("var-thread0", [1])
+ v0 = variable_scope.get_variable("var0", [1])
with variable_scope.variable_scope("common"):
- v1 = variable_scope.get_variable("var-thread1", [1])
+ v1 = variable_scope.get_variable("var1", [1])
# This will pause the current thread, and execute the other thread.
distribute_lib.get_tower_context().merge_call(lambda _: _)
- v2 = variable_scope.get_variable("var-thread2", [1])
+ v2 = variable_scope.get_variable(
+ "var2", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v3 = variable_scope.get_variable(
+ "var3", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- return v0, v1, v2
+ return v0, v1, v2, v3
devices = ["/device:CPU:0", "/device:GPU:0"]
dist = mirrored_strategy.MirroredStrategy(devices)
@@ -285,14 +360,89 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("main/var-main0:0", v.name)
result = dist.call_for_each_tower(model_fn, run_concurrently=False)
- self.assertEquals(3, len(result))
- v0, v1, v2 = result
+ self.assertEquals(4, len(result))
+ v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
- self.assertEquals("main/var-thread0:0", v0.name)
+ self.assertEquals("main/var0:0", v0.name)
self.assertIsInstance(v1, values.MirroredVariable)
- self.assertEquals("main/common/var-thread1:0", v1.name)
- self.assertIsInstance(v2, values.MirroredVariable)
- self.assertEquals("main/common/var-thread2:0", v2.name)
+ self.assertEquals("main/common/var1:0", v1.name)
+ self.assertIsInstance(v2, values.TowerLocalVariable)
+ self.assertEquals("main/common/var2:0", v2.name)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ v2.aggregation)
+ self.assertIsInstance(v3, values.MirroredVariable)
+ self.assertEquals("main/common/var3:0", v3.name)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ v3.aggregation)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testNoneSynchronizationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testNoneSynchronizationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please change "
+ "the `synchronization` for variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.NONE)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidSynchronizationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable synchronization mode: Invalid for "
+ "variable: v"):
+ variable_scope.variable(1.0, name="v", synchronization="Invalid")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithGetVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.get_variable(
+ "v", [1],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testInvalidAggregationWithVariable(self):
+ self._skip_eager_if_gpus_less_than(1)
+ devices = ["/device:CPU:0", "/device:GPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ with self.assertRaisesRegexp(
+ ValueError, "Invalid variable aggregation mode: invalid for "
+ "variable: v"):
+ variable_scope.variable(
+ 1.0,
+ name="v",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation="invalid")
@test_util.run_in_graph_and_eager_modes(config=config)
def testThreeDevices(self):
@@ -337,34 +487,51 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
all_v_sum = {}
all_v_mean = {}
+ components_sum = {}
+ components_mean = {}
def model_fn(device_id):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope("sum"):
- v_sum = variable_scope.variable(1.0)
- with tower_context.tower_local_var_scope("mean"):
- v_mean = variable_scope.variable(4.0)
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ v_mean = variable_scope.variable(
+ 4.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.MEAN)
self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
updates = [v_sum.assign_add(2.0 + device_id),
v_mean.assign(6.0 * device_id)]
all_v_sum[device_id] = v_sum
all_v_mean[device_id] = v_mean
- return updates, v_sum, v_mean
+ c_sum = v_sum.get()
+ c_mean = v_mean.get()
+ components_sum[device_id] = c_sum
+ components_mean[device_id] = c_mean
+ self.assertIsNot(v_sum, c_sum)
+ self.assertIsNot(v_mean, c_mean)
+ return updates, v_sum, v_mean, c_sum, c_mean
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
with dist.scope():
# Create "sum" and "mean" versions of TowerLocalVariables.
- ret_ops, ret_v_sum, ret_v_mean = dist.call_for_each_tower(
- model_fn, dist.worker_device_index, run_concurrently=False)
+ ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
+ dist.call_for_each_tower(
+ model_fn, dist.worker_device_index, run_concurrently=False))
# Should see the same wrapping instance in all towers.
self.assertIs(all_v_sum[0], ret_v_sum)
self.assertIs(all_v_mean[0], ret_v_mean)
- for i in range(1, dist.num_towers):
- self.assertIs(all_v_sum[0], all_v_sum[1])
- self.assertIs(all_v_mean[0], all_v_mean[1])
+ self.assertIs(all_v_sum[0], all_v_sum[1])
+ self.assertIs(all_v_mean[0], all_v_mean[1])
+
+ # Regroup should recover the same wrapper.
+ self.assertIs(ret_v_sum, regrouped_sum)
+ self.assertIs(ret_v_mean, regrouped_mean)
+ self.assertIsNot(components_sum[0], components_sum[1])
+ self.assertIsNot(components_mean[0], components_mean[1])
# Apply updates
self.evaluate(variables.global_variables_initializer())
@@ -385,14 +552,13 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
# Without get(device), should return the value you get by
# applying the reduction across all towers (whether you use
- # fetch(), get(), or nothing).
- self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
- self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+ # read_var(), get(), or nothing).
+ self.assertEqual(expected_sum, self.evaluate(dist.read_var(ret_v_sum)))
+ self.assertEqual(expected_mean, self.evaluate(dist.read_var(ret_v_mean)))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
- if not context.executing_eagerly():
- self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
- self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
@@ -438,6 +604,74 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
self.assertEquals("foo/" + name + ":0", v0.name)
self.assertEquals("tower_1/foo/" + name + ":0", v1.name)
+ # variable_scope.variable() respects name scopes when creating
+ # variables. On the other hand variable_scope.get_variable() ignores name
+ # scopes when creating variables. We test both methods of creating variables
+ # to make sure that we have the same variable names in both cases.
+ def testNameScopeWithVariable(self):
+ def in_cross_tower(_):
+ c = variable_scope.variable(1.0, name="c")
+ return c
+
+ def model_fn():
+ b = variable_scope.variable(1.0, name="b")
+ with ops.name_scope("foo"):
+ c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ return b, c
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with context.graph_mode(), dist.scope():
+ with ops.name_scope("main"):
+ a = variable_scope.variable(1.0, name="a")
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ result_b = result[0]
+ result_c = result[1]
+ self.assertIsInstance(result_b, values.DistributedValues)
+ self.assertIsInstance(result_c, values.DistributedValues)
+ a0, a1 = dist.unwrap(a)
+ b0, b1 = dist.unwrap(result_b)
+ c0, c1 = dist.unwrap(result_c)
+ self.assertEquals("main/a:0", a0.name)
+ self.assertEquals("main/a/replica_1:0", a1.name)
+ self.assertEquals("main/b:0", b0.name)
+ self.assertEquals("main/b/replica_1:0", b1.name)
+ self.assertEquals("main/foo/c:0", c0.name)
+ self.assertEquals("main/foo/c/replica_1:0", c1.name)
+
+ def testNameScopeWithGetVariable(self):
+ def in_cross_tower(_):
+ c = variable_scope.get_variable("c", [1])
+ return c
+
+ def model_fn():
+ b = variable_scope.get_variable("b", [1])
+ with ops.name_scope("foo"):
+ c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
+ return b, c
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with context.graph_mode(), dist.scope():
+ with ops.name_scope("main"):
+ a = variable_scope.get_variable("a", [1])
+ result = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ result_b = result[0]
+ result_c = result[1]
+ self.assertIsInstance(result_b, values.DistributedValues)
+ self.assertIsInstance(result_c, values.DistributedValues)
+ a0, a1 = dist.unwrap(a)
+ b0, b1 = dist.unwrap(result_b)
+ c0, c1 = dist.unwrap(result_c)
+ self.assertEquals("a:0", a0.name)
+ self.assertEquals("a/replica_1:0", a1.name)
+ self.assertEquals("b:0", b0.name)
+ self.assertEquals("b/replica_1:0", b1.name)
+ self.assertEquals("c:0", c0.name)
+ self.assertEquals("c/replica_1:0", c1.name)
+
def testDynamicRnnVariables(self):
def model_fn():
inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]])
@@ -462,6 +696,276 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
_, v1 = dist.unwrap(v)
self.assertStartsWith(v1.name, "tower_1/")
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testTowerLocalVariableUpdate(self):
+ with context.graph_mode():
+
+ def model_fn():
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+ return v_sum
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:GPU:1"])
+
+ def update(var, value):
+ return var.assign(value)
+
+ with dist.scope():
+ ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
+ update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+
+ # Initialize variables.
+ self.evaluate(variables.global_variables_initializer())
+ # Assert that the aggregated value of the tower local vars is the sum of
+ # the individual values before running the update ops.
+ self.assertEquals(1.0, self.evaluate(
+ ret_v_sum.get(dist._devices[0]).read_value()))
+ self.assertEquals(2.0, self.evaluate(ret_v_sum))
+
+ # Apply updates.
+ self.evaluate(update_ops)
+ # Assert that the aggregated value of the tower local vars is the sum of
+ # the individual values after running the update ops.
+ self.assertEquals(5.0, self.evaluate(
+ ret_v_sum.get(dist._devices[0]).read_value()))
+ self.assertEquals(10.0, self.evaluate(ret_v_sum))
+
+
+class MirroredVariableUpdateTest(test.TestCase):
+ # The following tests check assign, assign_add and assign_sub on Mirrored
+ # variables in tower and cross tower context.
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def _skip_eager_if_gpus_less_than(self, num_gpus):
+ if context.num_gpus() < num_gpus and context.executing_eagerly():
+ self.skipTest("Enough GPUs not available for this test in eager mode.")
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContextWithoutAggregationType(self):
+ # Test that we always have an aggregation type set on the mirrored variable
+ # if we assign to it in tower mode.
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+
+ def model_fn():
+ return mirrored_var.assign(5.0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "You must specify an aggregation method to update a "
+ "MirroredVariable in Tower Context."):
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContextWithSum(self):
+ # Test that we don't reduce a non-per-device value with the "sum"
+ # aggregation type.
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ v = variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.SUM)
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+
+ def model_fn():
+ return mirrored_var.assign(5.0)
+
+ with self.assertRaisesRegexp(
+ ValueError, "A non PerDevice value cannot be reduced with the given "
+ "aggregation."):
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign(6.0))
+ self.assertEquals(6.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(0.5, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignAddMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(1.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+ self.assertEquals(7.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignAddMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(1.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign_add(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(1.5, self.evaluate(mirrored_var))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignSubMirroredVarCrossTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(5.0, name="foo")
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+ mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
+ self.assertEquals(3.0, mirrored_var_result)
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testAssignSubMirroredVarTowerContext(self):
+ self._skip_eager_if_gpus_less_than(1)
+ def var_fn():
+ return variable_scope.variable(
+ 5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn, run_concurrently=False)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEquals(5.0, self.evaluate(mirrored_var))
+
+ def model_fn():
+ value = math_ops.cast(distribute_lib.get_tower_context().tower_id,
+ mirrored_var.dtype)
+ return mirrored_var.assign_sub(value)
+
+ self.evaluate(dist.unwrap(dist.call_for_each_tower(
+ model_fn, run_concurrently=False)))
+ self.assertEquals(4.5, self.evaluate(mirrored_var))
+
+
+class MirroredAndTowerLocalVariableInitializerTest(test.TestCase):
+ config = config_pb2.ConfigProto()
+ config.allow_soft_placement = True
+
+ def testAssignMirroredVarInitializer(self):
+ # This test is not eager compatible since in eager variables are initialized
+ # upon construction instead of once the initialization op is run.
+ with context.graph_mode():
+ def var_fn():
+ v = variable_scope.variable(1.0, name="foo")
+ return v
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ mirrored_var = dist.call_for_each_tower(var_fn)
+ self.assertIsInstance(mirrored_var, values.MirroredVariable)
+ self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
+ self.evaluate(mirrored_var.initializer)
+ self.assertTrue(self.evaluate(mirrored_var.is_initialized()))
+
+ def testAssignTowerLocalVarInitializer(self):
+ # This test is not eager compatible since in eager variables are initialized
+ # upon construction instead of once the initialization op is run.
+ with context.graph_mode():
+ def model_fn():
+ v_sum = variable_scope.variable(
+ 1.0,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
+ return v_sum
+
+ dist = mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0", "/device:CPU:0"])
+
+ with dist.scope():
+ tower_local_var = dist.call_for_each_tower(model_fn)
+ self.assertTrue(isinstance(tower_local_var, values.TowerLocalVariable))
+ self.assertFalse(self.evaluate(tower_local_var.is_initialized()))
+ self.evaluate(tower_local_var.initializer)
+ self.assertTrue(self.evaluate(tower_local_var.is_initialized()))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 61cbe6df81..a066adf124 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -47,7 +47,7 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
def testTowerId(self):
self._test_tower_id(self._get_distribution_strategy())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCallAndMergeExceptions(self):
self._test_call_and_merge_exceptions(self._get_distribution_strategy())
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 4fdb9bf69b..2892ce4394 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -52,11 +52,11 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
self.assertEqual(1, len(layer.trainable_variables))
mirrored_weight_variable = layer.trainable_variables[0]
- start_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+ start_error = self.evaluate(mirrored_weight_variable)
start_error = abs(numpy.array(start_error) - 1)
monitor.run_steps(9)
- end_error = self.evaluate(distribution.fetch(mirrored_weight_variable))
+ end_error = self.evaluate(mirrored_weight_variable)
end_error = abs(numpy.array(end_error) - 1)
self.assertGreaterEqual(start_error, end_error)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
index a552b370eb..cbfe5df61d 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
@@ -46,7 +46,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy):
* **In-graph replication**: the `client` creates a single `tf.Graph` that
specifies tasks for devices on all workers. The `client` then creates a
client session which will talk to the `master` service of a `worker`. Then
- the `master` will parition the graph and distribute the work to all
+ the `master` will partition the graph and distribute the work to all
participating workers.
* **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
physical machine. We will have multiple `worker`s with different `task`
@@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy):
worker: [device_util.canonicalize(worker, '/device:CPU:0')]
for worker in self._workers
}
- self._devices = nest.flatten(self._worker_device_map.values())
+ self._devices = nest.flatten(self._worker_device_map)
super(MultiWorkerMirroredStrategy, self).__init__(
devices=self._devices, prefetch_on_device=prefetch_on_device)
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 09b6d4a515..dbd3514aec 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import distribute as distribute_lib
@@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._default_device = device
def _create_variable(self, next_creator, *args, **kwargs):
- # No need to distinguish tower-local variables when not mirroring,
- # we just enforce that they are not trainable.
- if kwargs.pop("tower_local_reduce_method", None) is not None:
- kwargs["trainable"] = False
-
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
with ops.device(self._device):
@@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.device(self._device):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
if not isinstance(value, values.MapOutput):
return value
l = value.get()
assert l
with ops.device(self._device):
- if method_string == "sum":
+ if aggregation == vs.VariableAggregation.SUM:
return math_ops.add_n(l)
- elif method_string == "mean":
+ elif aggregation == vs.VariableAggregation.MEAN:
return math_ops.add_n(l) / len(l)
else:
assert False
@@ -102,12 +98,9 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
return fn(*args, **kwargs)
- def _fetch(self, val, destination, fn):
- """Return a copy of `val` or `fn(val)` on `destination`."""
- with ops.device(self._device):
- v = fn(val)
- with ops.device(destination):
- return array_ops.identity(v)
+ def read_var(self, tower_local_var):
+ """Read the aggregate value of a tower-local variable."""
+ return array_ops.identity(tower_local_var)
def _unwrap(self, value):
return [value]
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
index 7aad8a953c..4fdc0f72e6 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
@@ -44,7 +44,7 @@ class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase):
def testTowerId(self):
self._test_tower_id(self._get_distribution_strategy())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCallAndMergeExceptions(self):
self._test_call_and_merge_exceptions(self._get_distribution_strategy())
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index abd3a65ac4..a2d736e422 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -59,8 +59,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- weights.append(self.evaluate(distribution.fetch(layer.kernel)))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ weights.append(self.evaluate(layer.kernel))
+ biases.append(self.evaluate(layer.bias))
error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index 7b3670b45a..24cdc627a3 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -89,6 +89,9 @@ class _PrefetchToDeviceIterator(object):
with ops.device(device):
buffer_resource_handle = prefetching_ops.function_buffering_resource(
f=_prefetch_fn,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)),
target_device=target_device,
string_arg=input_iterator_handle,
buffer_size=buffer_size,
diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py
index a0b452fc2d..2a9ab51fcf 100644
--- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py
+++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py
@@ -46,7 +46,7 @@ class CanonicalizeVariableNameTest(test.TestCase):
class SharedVariableCreatorTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSharedVariable(self):
shared_variable_store = {}
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 75c5ec9659..2ee94d8f70 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,8 +50,8 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
run_step()
- weights.append(self.evaluate(distribution.fetch(layer.kernel)))
- biases.append(self.evaluate(distribution.fetch(layer.bias)))
+ weights.append(self.evaluate(layer.kernel))
+ biases.append(self.evaluate(layer.bias))
error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1)
is_not_increasing = all(y <= x for x, y in zip(error, error[1:]))
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 2b4ad9f146..baed0ebaae 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
@@ -106,13 +107,14 @@ class DistributionTestBase(test.TestCase):
before_list = []
after_list = []
for g, v in g_v:
- fetched = d.fetch(v)
+ fetched = d.read_var(v)
before_list.append(fetched)
# control_dependencies irrelevant but harmless in eager execution
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
- after_list.append(d.fetch(v))
+ after_list.append(d.read_var(v))
return before_list, after_list
for i in range(10):
@@ -159,12 +161,13 @@ class DistributionTestBase(test.TestCase):
before_list = []
after_list = []
for g, v in g_v:
- fetched = d.fetch(v)
+ fetched = d.read_var(v)
before_list.append(fetched)
with ops.control_dependencies([fetched]):
- g = d.reduce("sum", g, destinations=v)
+ g = d.reduce(
+ variable_scope.VariableAggregation.SUM, g, destinations=v)
with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
- after_list.append(d.fetch(v))
+ after_list.append(d.read_var(v))
return before_list, after_list
before_out, after_out = step()
@@ -184,7 +187,7 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.fetch(d.reduce("sum", map_out))
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 75441786a6..bc53898539 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -21,104 +21,126 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import itertools
-
from tensorflow.contrib import tpu
from tensorflow.contrib.distribute.python import one_device_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self,
- num_cores_per_host=2,
- iterations_per_step=2):
+ def __init__(self, num_cores_per_host=2):
# TODO(isaprykin): Generalize the defaults. They are currently tailored for
# the unit test.
super(TPUStrategy, self).__init__('/cpu:0')
# TODO(isaprykin): Auto-detect number of cores and hosts.
self._num_cores_per_host = num_cores_per_host
- # TODO(isaprykin): This might have to be per-call.
- self._iterations_per_step = iterations_per_step
+ # TODO(priyag): This should not be hardcoded here.
+ self._host = '/task:0/device:CPU:0'
def distribute_dataset(self, dataset_fn):
- return values.PerIterationDataset(
- self._call_dataset_fn(dataset_fn), self._iterations_per_step,
- self._num_cores_per_host)
-
- def _call_for_each_tower(self, fn, *args, **kwargs):
- kwargs.pop('run_concurrently', None)
-
- inputs = {'args': args, 'kwargs': kwargs}
- flat_inputs = nest.flatten(inputs)
-
- feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs]
-
- feeds = lambda: itertools.compress(flat_inputs, feed_mask)
- shapes = [f.get_shape() for f in feeds()]
+ # TODO(priyag): Perhaps distribute across cores here.
+ return self._call_dataset_fn(dataset_fn)
+
+ # TODO(priyag): Deal with OutOfRange errors.
+ # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
+ # a mechanism to infer the outputs of `fn`. Pending b/110550782.
+ def _run_steps_on_dataset(self, fn, iterator, iterations,
+ initial_loop_values=None):
+ # Enqueue ops
+ shapes = nest.flatten(iterator.output_shapes)
if any([not s.is_fully_defined() for s in shapes]):
raise ValueError(
'TPU currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
'dataset.apply(map_and_batch(..., drop_remainder=True)).')
- types = [f.get_dtype() for f in feeds()]
-
- def infeed_input(i):
- """Get input, split it and then enqueue."""
- iteration_inputs = [f.get(i) for f in feeds()]
- infeed_inputs = [[inputs_per_core[core_id]
- for inputs_per_core in iteration_inputs]
- for core_id in range(self._num_cores_per_host)]
-
- infeed_ops = []
- for core_id, infeed_input in enumerate(infeed_inputs):
- infeed_ops.append(
+ types = nest.flatten(iterator.output_types)
+
+ def enqueue_ops_fn():
+ """Enqueue ops for one iteration."""
+ control_deps = []
+ sharded_inputs = []
+ with ops.device(self._host):
+ for _ in range(self._num_cores_per_host):
+ # Use control dependencies to ensure a deterministic ordering.
+ with ops.control_dependencies(control_deps):
+ inputs = nest.flatten(iterator.get_next())
+ control_deps.extend(inputs)
+ sharded_inputs.append(inputs)
+
+ enqueue_ops = []
+ for core_id, shard_input in enumerate(sharded_inputs):
+ enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
- inputs=infeed_input, shapes=shapes, device_ordinal=core_id))
+ inputs=shard_input, shapes=shapes, device_ordinal=core_id))
+ return enqueue_ops
- with ops.control_dependencies(infeed_ops):
+ def enqueue_ops_loop_body(i):
+ with ops.control_dependencies(enqueue_ops_fn()):
return i + 1
- with ops.device('/task:0/device:CPU:0'):
+ with ops.device(self._host):
enqueue_ops = control_flow_ops.while_loop(
- lambda i: i < self._iterations_per_step,
- infeed_input, [constant_op.constant(0)],
+ lambda i: i < iterations,
+ enqueue_ops_loop_body,
+ [constant_op.constant(0)],
parallel_iterations=1)
- def dequeueing_fn(*args, **kwargs):
- """Dequeue input arguments and supply them to `fn`."""
- del args, kwargs
+ # Dequeue ops
+ def dequeue_fn():
dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
- dequeued = iter(dequeued)
+ return nest.pack_sequence_as(iterator.output_shapes, dequeued)
- fn_inputs = []
- for inp, is_feed in zip(flat_inputs, feed_mask):
- if is_feed:
- fn_inputs.append(next(dequeued))
- else:
- fn_inputs.append(inp)
+ # Wrap `fn` for repeat.
+ if initial_loop_values is None:
+ initial_loop_values = []
+ ctx = values.MultiStepContext(initial_loop_values)
+ def run_fn(*args, **kwargs):
+ del args, kwargs
+ fn_result = fn(ctx, dequeue_fn())
+ if ctx.last_step_outputs is None:
+ ctx.last_step_outputs = []
+ with ops.control_dependencies([fn_result]):
+ return array_ops.identity(ctx.last_step_outputs)
+
+ # Repeat
+ # TODO(sourabhbajaj): The input to while loop should be based on the output
+ # type of the step_fn
+ def iterate_on_tpu():
+ return tpu.repeat(iterations, run_fn, [initial_loop_values])
- fn_inputs = nest.pack_sequence_as(inputs, fn_inputs)
- return fn(*fn_inputs['args'], **fn_inputs['kwargs'])
+ # Re-write and distribute computation.
+ # TODO(sourabhbajaj): Convert the output to PerDevice variable and
+ # implement support for that in reduce.
+ last_step_tensor_outputs = tpu.batch_parallel(
+ iterate_on_tpu, [], num_shards=self._num_cores_per_host)
- def iterate_on_tpu():
- return tpu.repeat(self._iterations_per_step, dequeueing_fn, [])
+ # Take index [0] of last_step_tensor_outputs as we wrapped
+ # initial_loop_values in a list in the `repeat` call.
+ return (control_flow_ops.group(last_step_tensor_outputs, enqueue_ops),
+ last_step_tensor_outputs[0], ctx)
+ def _call_for_each_tower(self, fn, *args, **kwargs):
+ kwargs.pop('run_concurrently', None)
with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access
- tpu_result = tpu.batch_parallel(
- iterate_on_tpu, [], num_shards=self._num_cores_per_host)
+ return fn(*args, **kwargs)
+
+ def get_initialization_ops(self):
+ return [tpu.initialize_system()]
- return control_flow_ops.group(tpu_result, enqueue_ops)
+ def get_finalize_ops(self):
+ return [tpu.shutdown_system()]
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
- if method_string == 'mean':
+ if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self._num_cores_per_host)
return tpu_ops.cross_replica_sum(value)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 49b4e24daa..1b5e00bc79 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -23,10 +23,8 @@ from __future__ import print_function
import collections
import weakref
-
import six
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
@@ -35,6 +33,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver
@@ -43,7 +43,7 @@ from tensorflow.python.util import nest
# pylint: disable=line-too-long
-# TODO(josh11b): Should device values be strings or DeviceSpec objects
+# TODO(josh11b): Should device values be strings or DeviceSpec objects?
# Not sure DeviceSpec objects are usable as a dict key.
class DistributedValues(object):
"""Holds a map from device to values. Either PerDevice or Mirrored."""
@@ -65,9 +65,10 @@ class DistributedValues(object):
device = device_util.canonicalize(device)
try:
return self._index[device]
- except KeyError:
- raise ValueError("Device %s not found in %s (current device %s)" %
- (device, self._index.keys(), device_util.current()))
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
def on_device(self, device):
device = device_util.canonicalize(device)
@@ -162,9 +163,16 @@ class PerDevice(DistributedValues):
pass
-class Mirrored(DistributedValues):
+# Note that unlike PerDevice, Mirrored values inherit from
+# DistributedDelegate and so can be used directly in cross-tower mode.
+class Mirrored(DistributedDelegate):
"""Holds a map from device to values which are kept in sync."""
- pass
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return list(self._index.values())[0]
def _assign_on_device(device, variable, tensor):
@@ -185,6 +193,10 @@ class DistributedVariable(DistributedDelegate):
# Child class must set self._primary_var before calling
# super(...).__init__(index).
self._common_name = self._primary_var.name.split(":")[0]
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
super(DistributedVariable, self).__init__(index)
@property
@@ -237,35 +249,9 @@ class DistributedVariable(DistributedDelegate):
pass
-# Register a conversion function which reads the value of the variable,
-# allowing instances of the class to be used as tensors.
-def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
- # Try to avoid assignments to and other mutations of MirroredVariable
- # state except through a DistributionStrategy.update() call.
- assert not as_ref
- return ops.internal_convert_to_tensor(
- var.get(), dtype=dtype, name=name, as_ref=as_ref)
-
-
-ops.register_tensor_conversion_function(DistributedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(DistributedVariable)
-class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
- """Class for defining how to restore a MirroredVariable."""
-
- def __init__(self, mirrored_variable, primary_variable, name):
- self._mirrored_variable = mirrored_variable
- super(_MirroredSaveable, self).__init__(primary_variable, "", name)
-
- def restore(self, restored_tensors, restored_shapes):
- """Restore the same value into all variables."""
- tensor, = restored_tensors
- return control_flow_ops.group([
- _assign_on_device(d, v, tensor)
- for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access
-
-
def _get_update_device():
"""Validate we are in update/update_non_slot() and return current device.
@@ -286,34 +272,113 @@ def _get_update_device():
return device
+class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
+ """Class for defining how to restore a MirroredVariable."""
+
+ def __init__(self, mirrored_variable, primary_variable, name):
+ self._mirrored_variable = mirrored_variable
+ super(_MirroredSaveable, self).__init__(primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into all variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group([
+ _assign_on_device(d, v, tensor)
+ for d, v in six.iteritems(self._mirrored_variable._index)]) # pylint: disable=protected-access
+
+
class MirroredVariable(DistributedVariable, Mirrored,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are kept in sync."""
- def __init__(self, index, primary_var):
+ def __init__(self, index, primary_var, aggregation):
# Use a weakref to make it easy to map from the contained values
# to the container without introducing a reference cycle.
for v in six.itervalues(index):
v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
+ # tf.keras keeps track of variables initialized using this attribute. When
+ # tf.keras gets the default session, it initializes all uninitialized vars.
+ # We need to make _keras_initialized a member of MirroredVariable because
+ # without this it will use `__getattr__` which will delegate to a component
+ # variable.
+ self._keras_initialized = False
+ self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
- # We use _get_update_device() for the assign* methods to enforce
- # that we are in an update() function. The arguments to update() are
- # automatically unwrapped so the update() function would normally
- # see regular variables, not MirroredVariables. However, the update
- # function can still operate on wrapped MirroredVariables through
- # object members, captured arguments, etc. This is more likely in an
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
# update_non_slot() function (like OptimizerV2._finish), which can
# update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+ f = kwargs.pop("f")
+ if distribute_lib.get_cross_tower_context():
+ update_device = distribute_lib.get_update_device()
+ # We are calling update on the mirrored variable in cross tower context.
+ if update_device is not None:
+ # We are calling an assign function on the mirrored variable in cross
+ # tower context.
+ v = self.get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribute_lib.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ # We are calling an assign function on the mirrored variable in tower
+ # context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function on each of the mirrored variables with the reduced
+ # value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "MirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self))
+
+ return distribute_lib.get_tower_context().merge_call(merge_fn, *args,
+ **kwargs)
+
def assign_sub(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign_sub(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign_sub, *args, **kwargs)
def assign_add(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign_add(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign_add, *args, **kwargs)
def assign(self, *args, **kwargs):
- return self.get(device=_get_update_device()).assign(*args, **kwargs)
+ return self._assign_func(f=state_ops.assign, *args, **kwargs)
+
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For mirrored variables, the `is_initialized`
+ # op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+ @property
+ def initializer(self):
+ # return grouped ops of all the var initializations of component values of
+ # the mirrored variable
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
+ @property
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
device = device_util.canonicalize(device_util.current())
@@ -341,6 +406,20 @@ class MirroredVariable(DistributedVariable, Mirrored,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
+ # Try to avoid assignments to and other mutations of MirroredVariable
+ # state except through a DistributionStrategy.update() call.
+ assert not as_ref
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(MirroredVariable,
+ _tensor_conversion_mirrored)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
@@ -349,7 +428,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().fetch(
+ return distribute_lib.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,
@@ -364,7 +443,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# To preserve the sum across save and restore, we have to divide the
# total across all devices when restoring a variable that was summed
# when saving.
- if self._tower_local_variable.reduce_method == "sum":
+ if self._tower_local_variable.aggregation == vs.VariableAggregation.SUM:
tensor *= 1. / len(self._tower_local_variable.devices)
return control_flow_ops.group([
_assign_on_device(d, v, tensor)
@@ -381,9 +460,15 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
- def __init__(self, index, primary_var, reduce_method):
+ def __init__(self, index, primary_var, aggregation):
self._primary_var = primary_var
- self._reduce_method = reduce_method
+ self._aggregation = aggregation
+ # tf.keras keeps track of variables initialized using this attribute. When
+ # tf.keras gets the default session, it initializes all uninitialized vars.
+ # We need to make _keras_initialized a member of TowerLocalVariable because
+ # without this it will use `__getattr__` which will delegate to a component
+ # variable.
+ self._keras_initialized = False
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
@@ -398,15 +483,37 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
_assert_tower_context()
return self.get().assign(*args, **kwargs)
+ def is_initialized(self, name=None):
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = list(self._index.values())
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For tower local variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
@property
- def reduce_method(self):
- return self._reduce_method
+ def initializer(self):
+ # return grouped ops of all the var initializations of component values of
+ # the tower local variable
+ return control_flow_ops.group([v.initializer for v in self._index.values()])
+
+ @property
+ def aggregation(self):
+ return self._aggregation
def _get_cross_tower(self):
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
- if self._reduce_method == "mean":
+ if self._aggregation == vs.VariableAggregation.MEAN:
return total * (1./ len(all_components))
return total
@@ -430,6 +537,17 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+# Register a conversion function for TowerLocalVariable which allows as_ref to
+# be true.
+def _tensor_conversion_tower_local(var, dtype=None, name=None, as_ref=False):
+ return ops.internal_convert_to_tensor(
+ var.get(), dtype=dtype, name=name, as_ref=as_ref)
+
+
+ops.register_tensor_conversion_function(TowerLocalVariable,
+ _tensor_conversion_tower_local)
+
+
def _devices_match(d1, d2):
return device_util.canonicalize(d1) == device_util.canonicalize(d2)
@@ -477,40 +595,40 @@ def regroup(per_device, wrap_class=PerDevice):
same_id = False
break
# Consider three cases where same_id is true:
- # * If v0 is a MirroredVariable (and same_id means it is the same
- # across all devices), we want to return it. We check
- # MirroredVariable specifically since it can look like it
- # has a _mirrored_container member since its members do.
- # * If v0 is a member of a mirrored variable, in which case
- # hasattr(v0, "_mirrored_container") is true, we want to
- # return the MirroredVariable that contains it using the
- # _mirrored_container logic below. This case can trigger
+ # * If v0 is a DistributedVariable (a MirroredVariable or
+ # TowerLocalVariable, and same_id means it is the same across all
+ # devices), we want to return it. We check DistributedVariable
+ # specifically since it can look like it has a
+ # _distributed_container member since its members do.
+ # * If v0 is a member of a distributed variable, in which case
+ # hasattr(v0, "_distributed_container") is true, we want to
+ # return the DistributedVariable that contains it using the
+ # _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0.
- if same_id and (isinstance(v0, MirroredVariable) or
- not hasattr(v0, "_mirrored_container")):
+ if same_id and (isinstance(v0, DistributedVariable) or
+ not hasattr(v0, "_distributed_container")):
return v0
# Detect the case where each device has a parallel component of the
- # same MirroredVariable. In this case we want to return the
- # containing MirroredVariable, after a bunch of sanity checking.
- # In particular, each component should have the same container,
- # and the devices of the variables should match the keys of the
- # per-device dictionary.
- # TODO(josh11b): Do we need similar logic for TowerLocalVariables?
- if hasattr(v0, "_mirrored_container"):
+ # same MirroredVariable (or TowerLocalVariable). In this case we
+ # want to return the containing MirroredVariable, after a bunch of
+ # sanity checking. In particular, each component should have the
+ # same container, and the devices of the variables should match the
+ # keys of the per-device dictionary.
+ if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
"ids = %s, items = %s" % ([id(v[1]) for v in items], items))
assert _devices_match(v0.device, items[0][0]), (
"v0.device = %s, items = %s" % (v0.device, items))
- mirrored_container = v0._mirrored_container()
- assert mirrored_container is not None
+ distributed_container = v0._distributed_container()
+ assert distributed_container is not None
for d, v in items[1:]:
assert _devices_match(v.device, d), (
"v.device = %s, d = %s, items = %s" % (v.device, d, items))
- assert mirrored_container is v._mirrored_container()
- return mirrored_container
+ assert distributed_container is v._distributed_container()
+ return distributed_container
# pylint: enable=protected-access
return wrap_class(per_device)
@@ -592,8 +710,7 @@ class PerDeviceDataset(object):
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
- self._dataset = dataset.apply(
- batching.batch_and_drop_remainder(len(devices)))
+ self._dataset = dataset.batch(len(devices), drop_remainder=True)
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
@@ -804,3 +921,71 @@ class MapOutput(object):
def get(self):
return self._l
+
+
+class MultiStepContext(object):
+ """A context object that can be used to capture things when running steps.
+
+ This context object is useful when running multiple steps at a time using the
+ `run_steps_on_dataset` API. For e.g. it allows the user's step function to
+ specify which outputs to emit at what frequency. Currently it only supports
+ capturing output from the last step, but will soon be augmented to support
+ other use cases such as output each N steps.
+ """
+
+ def __init__(self, initial_loop_values=None):
+ """Initializes an output context.
+
+ Args:
+ initial_loop_values: Initial values passed to the run steps
+ while loop. The only purpose is to verify the shapes and types
+ when the actual output is set. This will be removed once we
+ automatically infer the output shapes and types (and do not need to
+ check for user error in specifying them manually).
+ Returns:
+ A context object.
+ """
+ self._last_step_outputs = None
+ self._non_tensor_outputs = None
+ self._initial_loop_values = initial_loop_values
+
+ @property
+ def last_step_outputs(self):
+ """Return the last step's outputs."""
+ return self._last_step_outputs
+
+ @last_step_outputs.setter
+ def last_step_outputs(self, outputs):
+ """Set the last step's outputs."""
+ self._verify_structure_shapes_types(outputs, self._initial_loop_values)
+ self._last_step_outputs = outputs
+
+ @property
+ def non_tensor_outputs(self):
+ """Return the non tensor outputs."""
+ return self._non_tensor_outputs
+
+ @non_tensor_outputs.setter
+ def non_tensor_outputs(self, outputs):
+ """Set any non tensor outputs."""
+ self._non_tensor_outputs = outputs
+
+ def _verify_structure_shapes_types(self, left, right):
+ """Verify that the structure, shapes and types of left are same as right."""
+ nest.assert_same_structure(left, right)
+ flat_left = nest.flatten(left)
+ flat_right = nest.flatten(right)
+ assert len(flat_left) == len(flat_right), (
+ "Length of left {} and right {} should be same.".
+ format(len(flat_left), len(flat_right)))
+
+ for o, i in zip(flat_left, flat_right):
+ # TODO(priyag): Add checks for other types like IndexedSlices.
+ if isinstance(o, ops.Tensor):
+ assert isinstance(i, ops.Tensor)
+ assert o.shape == i.shape, (
+ "Shape {} of left {} doesn't match shape {} of right {}.".
+ format(o.shape, o, i.shape, i))
+ assert o.dtype == i.dtype, (
+ "Dtype {} of left {} doesn't match dtype {} of right {}.".
+ format(o.dtype, o, i.dtype, i))
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 1c95758d96..8e44f2fea1 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -82,7 +82,7 @@ class DistributedValuesTest(test.TestCase):
class DistributedDelegateTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetAttr(self):
with ops.device("/device:CPU:0"):
@@ -97,7 +97,7 @@ class DistributedDelegateTest(test.TestCase):
with self.assertRaises(AttributeError):
_ = v.y
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOperatorOverride(self):
with ops.device("/device:CPU:0"):
v = values.DistributedDelegate({"/device:CPU:0": 7, "/device:GPU:0": 8})
@@ -158,7 +158,8 @@ def _make_mirrored():
v.append(variable_scope.get_variable(
name=n, initializer=init, use_resource=True))
index[d] = v[-1]
- mirrored = values.MirroredVariable(index, v[0])
+ mirrored = values.MirroredVariable(index, v[0],
+ variable_scope.VariableAggregation.SUM)
return v, devices, mirrored
@@ -277,7 +278,8 @@ class RegroupAndSelectDeviceTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
index = {d: v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.SUM)
result = values.regroup(index)
self.assertIs(mirrored, result)
@@ -363,7 +365,7 @@ class PerDeviceDatasetTest(test.TestCase):
self._test_iterator_no_prefetch(devices, dataset, expected_values)
self._test_iterator_with_prefetch(devices, dataset, expected_values)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOneDevice(self):
devices = ["/device:CPU:0"]
dataset = dataset_ops.Dataset.range(10)
@@ -581,7 +583,8 @@ class MirroredVariableTest(test.TestCase):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- mirrored = values.MirroredVariable(index, v)
+ mirrored = values.MirroredVariable(index, v,
+ variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, mirrored.name)
self.assertEquals(v.dtype, mirrored.dtype)
@@ -716,7 +719,9 @@ class MirroredVariableTest(test.TestCase):
with ops.device("/device:GPU:0"):
v = variable_scope.get_variable(
name="v", initializer=1., use_resource=True)
- mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ mirrored = values.MirroredVariable({
+ "/device:GPU:0": v
+ }, v, variable_scope.VariableAggregation.MEAN)
sess.run(variables_lib.global_variables_initializer())
sess.run({"complicated": mirrored})
@@ -746,24 +751,27 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
self.assertEquals(v[0].name, tower_local.name)
self.assertEquals(v[0].dtype, tower_local.dtype)
self.assertEquals(v[0].shape, tower_local.shape)
- self.assertEquals("sum", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.SUM,
+ tower_local.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
def testVariableOnAnotherDevice(self):
v = variable_scope.get_variable(
name="v", initializer=[1.], use_resource=True)
index = {"/job:foo/device:CPU:0": v}
- tower_local = values.TowerLocalVariable(index, v, "mean")
+ tower_local = values.TowerLocalVariable(
+ index, v, variable_scope.VariableAggregation.MEAN)
self.assertEquals(v.name, tower_local.name)
self.assertEquals(v.dtype, tower_local.dtype)
self.assertEquals(v.shape, tower_local.shape)
- self.assertEquals("mean", tower_local.reduce_method)
+ self.assertEquals(variable_scope.VariableAggregation.MEAN,
+ tower_local.aggregation)
def _assign_tower_local(self, devices, v, new):
for d, var, n in zip(devices, v, new):
@@ -789,7 +797,7 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -812,7 +820,8 @@ class TowerLocalVariableTest(test.TestCase):
self.skipTest("A GPU is not available for this test in eager mode.")
with self.test_session() as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -831,7 +840,8 @@ class TowerLocalVariableTest(test.TestCase):
def _save_tower_local_mean(self):
"""Save variables with mirroring, returns save_path."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [3., 4.])
@@ -893,7 +903,8 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_mean(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("mean")
+ v, tower_local = _make_tower_local(
+ variable_scope.VariableAggregation.MEAN)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -907,7 +918,7 @@ class TowerLocalVariableTest(test.TestCase):
def _restore_tower_local_sum(self, save_path):
"""Restore to variables with mirroring in a fresh graph."""
with self.test_session(graph=ops.Graph()) as sess:
- v, tower_local = _make_tower_local("sum")
+ v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
self._assign_tower_local(_devices, v, [7., 8.])
@@ -966,6 +977,18 @@ class TowerLocalVariableTest(test.TestCase):
save_path = self._save_normal()
self._restore_tower_local_sum(save_path)
+ def testTensorConversion(self):
+ with context.graph_mode():
+ _, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
+ converted = ops.internal_convert_to_tensor(tower_local, as_ref=False)
+ self.assertIsInstance(converted, ops.Tensor)
+ self.assertEqual(converted.dtype, tower_local.dtype)
+
+ converted = ops.internal_convert_to_tensor(tower_local, as_ref=True)
+ # Resources variable are converted to tensors as well when as_ref is True.
+ self.assertIsInstance(converted, ops.Tensor)
+ self.assertEqual(converted.dtype, tower_local.dtype)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 23d9dbcd91..ad00d1734d 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
name = "bijectors_py",
srcs = glob(["python/ops/bijectors/*.py"]),
+ deprecation = ("TensorFlow Distributions has migrated to " +
+ "TensorFlow Probability " +
+ "(https://github.com/tensorflow/probability). " +
+ "Deprecated copies remaining in tf.contrib.distributions " +
+ "are unmaintained, unsupported, and will be removed by " +
+ "late 2018. You should update all usage of " +
+ "`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/linalg:linalg_py",
@@ -42,6 +49,13 @@ py_library(
py_library(
name = "distributions_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ deprecation = ("TensorFlow Distributions has migrated to " +
+ "TensorFlow Probability " +
+ "(https://github.com/tensorflow/probability). " +
+ "Deprecated copies remaining in tf.contrib.distributions " +
+ "are unmaintained, unsupported, and will be removed by " +
+ "late 2018. You should update all usage of " +
+ "`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
":bijectors_py",
@@ -941,6 +955,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "fill_triangular_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "gumbel_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/gumbel_test.py"],
@@ -1119,6 +1152,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "scale_tril_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "sigmoid_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/sigmoid_test.py"],
@@ -1236,6 +1288,25 @@ cuda_py_test(
)
cuda_py_test(
+ name = "transform_diagonal_test",
+ size = "small",
+ srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"],
+ additional_deps = [
+ ":bijectors_py",
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/contrib/linalg:linalg_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "weibull_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/weibull_test.py"],
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 802538ba97..5cec93c4df 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -13,8 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Classes representing statistical distributions and ops for working with them.
-
-See the @{$python/contrib.distributions} guide.
"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index e281e81bdf..d1ce273499 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -61,6 +61,28 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
atol=0.,
rtol=1e-7)
+ def testNoBatchStaticJacobian(self):
+ x = np.eye(2)
+ bijector = bijectors.CholeskyOuterProduct()
+
+ # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
+ self.assertAllClose(
+ np.log(4),
+ self.evaluate(bijector.forward_log_det_jacobian(x, event_ndims=2)))
+
+ def testNoBatchDynamicJacobian(self):
+ x = np.eye(2)
+ bijector = bijectors.CholeskyOuterProduct()
+ x_pl = array_ops.placeholder(dtypes.float32)
+
+ with self.test_session():
+ log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2)
+
+ # The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
+ self.assertAllClose(
+ np.log(4),
+ log_det_jacobian.eval({x_pl: x}))
+
def testNoBatchStatic(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
new file mode 100644
index 0000000000..3530e142e4
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class FillTriangularBijectorTest(test.TestCase):
+ """Tests the correctness of the FillTriangular bijector."""
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBijector(self):
+ x = np.float32(np.array([1., 2., 3.]))
+ y = np.float32(np.array([[3., 0.],
+ [2., 1.]]))
+
+ b = bijectors.FillTriangular()
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ self.assertAllClose(fldj, 0.)
+
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(ildj, 0.)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testShape(self):
+ x_shape = tensor_shape.TensorShape([5, 4, 6])
+ y_shape = tensor_shape.TensorShape([5, 4, 3, 3])
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x = array_ops.ones(shape=x_shape, dtype=dtypes.float32)
+ y_ = b.forward(x)
+ self.assertAllEqual(y_.shape.as_list(), y_shape.as_list())
+ x_ = b.inverse(y_)
+ self.assertAllEqual(x_.shape.as_list(), x_shape.as_list())
+
+ y_shape_ = b.forward_event_shape(x_shape)
+ self.assertAllEqual(y_shape_.as_list(), y_shape.as_list())
+ x_shape_ = b.inverse_event_shape(y_shape)
+ self.assertAllEqual(x_shape_.as_list(), x_shape.as_list())
+
+ y_shape_tensor = self.evaluate(
+ b.forward_event_shape_tensor(x_shape.as_list()))
+ self.assertAllEqual(y_shape_tensor, y_shape.as_list())
+ x_shape_tensor = self.evaluate(
+ b.inverse_event_shape_tensor(y_shape.as_list()))
+ self.assertAllEqual(x_shape_tensor, x_shape.as_list())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testShapeError(self):
+
+ b = bijectors.FillTriangular(validate_args=True)
+
+ x_shape_bad = tensor_shape.TensorShape([5, 4, 7])
+ with self.assertRaisesRegexp(ValueError, "is not a triangular number"):
+ b.forward_event_shape(x_shape_bad)
+ with self.assertRaisesOpError("is not a triangular number"):
+ self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list()))
+
+ y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2])
+ with self.assertRaisesRegexp(ValueError, "Matrix must be square"):
+ b.inverse_event_shape(y_shape_bad)
+ with self.assertRaisesOpError("Matrix must be square"):
+ self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list()))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
index 1839703557..85d604e34a 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class MatrixInverseTriLBijectorTest(test.TestCase):
"""Tests the correctness of the Y = inv(tril) transformation."""
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testComputesCorrectValues(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
self.assertEqual("matrix_inverse_tril", inv.name)
@@ -51,7 +51,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOneByOneMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[5.]], dtype=np.float32)
@@ -70,7 +70,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testZeroByZeroMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.eye(0, dtype=np.float32)
@@ -89,7 +89,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBatch(self):
# Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
# (2, 1).
@@ -114,7 +114,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testErrorOnInputRankTooLow(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([0.1], dtype=np.float32)
@@ -149,7 +149,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
## square_error_msg):
## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testErrorOnInputNotLowerTriangular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 2.],
@@ -169,7 +169,7 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
triangular_error_msg):
inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testErrorOnInputSingular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 0.],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
index a5f5219588..cb42331a21 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
@@ -36,7 +36,7 @@ class OrderedBijectorTest(test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorVector(self):
with self.test_session():
ordered = Ordered()
@@ -82,7 +82,7 @@ class OrderedBijectorTest(test.TestCase):
atol=0.,
rtol=1e-7)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testShapeGetters(self):
with self.test_session():
x = tensor_shape.TensorShape([4])
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
new file mode 100644
index 0000000000..d5b3367f9a
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py
@@ -0,0 +1,69 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class ScaleTriLBijectorTest(test.TestCase):
+ """Tests the correctness of the ScaleTriL bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ def testComputesCorrectValues(self):
+ shift = 1.61803398875
+ x = np.float32(np.array([-1, .5, 2]))
+ y = np.float32(np.array([[np.exp(2) + shift, 0.],
+ [.5, np.exp(-1) + shift]]))
+
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(),
+ diag_shift=shift)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testInvertible(self):
+
+ # Generate random inputs from an unconstrained space, with
+ # event size 6 to specify 3x3 triangular matrices.
+ batch_shape = [2, 1]
+ x = np.float32(np.random.randn(*(batch_shape + [6])))
+ b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(),
+ diag_shift=3.14159)
+ y = self.evaluate(b.forward(x))
+ self.assertAllEqual(y.shape, batch_shape + [3, 3])
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllClose(fldj, -ildj)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
index 45760a29ee..795f1993ba 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
@@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.)
self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.)
- # Do the numpy calculation in float128 to avoid inf/nan.
- y_float128 = np.float128(y)
- self.assertAllClose(
- np.log(np.cosh(
- np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt(
- y_float128**2 + 1)) -
- np.log(tailweight),
- bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
- rtol=1e-4,
- atol=0.)
+ # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision.
+ # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and
+ # below test fails due to overflow error giving inf. So this check avoids that error by skipping square
+ # calculation and corresponding assert.
+
+ if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \
+ np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)):
+
+ # Do the numpy calculation in float128 to avoid inf/nan.
+ y_float128 = np.float128(y)
+ self.assertAllClose(
+ np.log(np.cosh(
+ np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt(
+ y_float128**2 + 1)) -
+ np.log(tailweight),
+ bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
+ rtol=1e-4,
+ atol=0.)
self.assertAllClose(
-bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(),
bijector.forward_log_det_jacobian(x, event_ndims=0).eval(),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 2ac06fce55..d0098c3c10 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -40,7 +40,7 @@ class SoftsignBijectorTest(test.TestCase):
def setUp(self):
self._rng = np.random.RandomState(42)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorBounds(self):
bijector = Softsign(validate_args=True)
with self.test_session():
@@ -54,7 +54,7 @@ class SoftsignBijectorTest(test.TestCase):
with self.assertRaisesOpError("less than 1"):
bijector.inverse_log_det_jacobian(3., event_ndims=0).eval()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorForwardInverse(self):
bijector = Softsign(validate_args=True)
self.assertEqual("softsign", bijector.name)
@@ -64,7 +64,7 @@ class SoftsignBijectorTest(test.TestCase):
self.assertAllClose(y, self.evaluate(bijector.forward(x)))
self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorLogDetJacobianEventDimsZero(self):
bijector = Softsign(validate_args=True)
y = self._rng.rand(2, 10)
@@ -74,7 +74,7 @@ class SoftsignBijectorTest(test.TestCase):
self.assertAllClose(ildj, self.evaluate(
bijector.inverse_log_det_jacobian(y, event_ndims=0)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorForwardInverseEventDimsOne(self):
bijector = Softsign(validate_args=True)
self.assertEqual("softsign", bijector.name)
@@ -83,7 +83,7 @@ class SoftsignBijectorTest(test.TestCase):
self.assertAllClose(y, self.evaluate(bijector.forward(x)))
self.assertAllClose(x, self.evaluate(bijector.inverse(y)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBijectorLogDetJacobianEventDimsOne(self):
bijector = Softsign(validate_args=True)
y = self._rng.rand(2, 10)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
new file mode 100644
index 0000000000..efc9f266d1
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import bijectors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class TransformDiagonalBijectorTest(test.TestCase):
+ """Tests correctness of the TransformDiagonal bijector."""
+
+ def setUp(self):
+ self._rng = np.random.RandomState(42)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBijector(self):
+ x = np.float32(np.random.randn(3, 4, 4))
+
+ y = x.copy()
+ for i in range(x.shape[0]):
+ np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :])))
+
+ exp = bijectors.Exp()
+ b = bijectors.TransformDiagonal(diag_bijector=exp)
+
+ y_ = self.evaluate(b.forward(x))
+ self.assertAllClose(y, y_)
+
+ x_ = self.evaluate(b.inverse(y))
+ self.assertAllClose(x, x_)
+
+ fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2))
+ ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2))
+ self.assertAllEqual(
+ fldj,
+ self.evaluate(exp.forward_log_det_jacobian(
+ np.array([np.diag(x_mat) for x_mat in x]),
+ event_ndims=1)))
+ self.assertAllEqual(
+ ildj,
+ self.evaluate(exp.inverse_log_det_jacobian(
+ np.array([np.diag(y_mat) for y_mat in y]),
+ event_ndims=1)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index 31d24aa9ea..181c46d2e5 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.linalg import linear_operator_diag
@@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase):
return False
+class TestMoveDimension(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_move_dimension_static_shape(self):
+
+ x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
+
+ x_perm = distribution_util.move_dimension(x, 1, 1)
+ self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6])
+
+ x_perm = distribution_util.move_dimension(x, 0, 3)
+ self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
+
+ x_perm = distribution_util.move_dimension(x, 0, -2)
+ self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
+
+ x_perm = distribution_util.move_dimension(x, 4, 2)
+ self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_move_dimension_dynamic_shape(self):
+
+ x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
+ x = array_ops.placeholder_with_default(input=x_, shape=None)
+
+ x_perm = distribution_util.move_dimension(x, 1, 1)
+ self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+ [200, 30, 4, 1, 6])
+
+ x_perm = distribution_util.move_dimension(x, 0, 3)
+ self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+ [30, 4, 1, 200, 6])
+
+ x_perm = distribution_util.move_dimension(x, 0, -2)
+ self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+ [30, 4, 1, 200, 6])
+
+ x_perm = distribution_util.move_dimension(x, 4, 2)
+ self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+ [200, 30, 6, 4, 1])
+
+ x_perm = distribution_util.move_dimension(x, -1, 2)
+ self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
+ [200, 30, 6, 4, 1])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
new file mode 100644
index 0000000000..42ecea034d
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
@@ -0,0 +1,51 @@
+# Description:
+# Internal testing utilities, e.g., computing the correct answer to
+# put in a unit test.
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "correlation_matrix_volumes_py",
+ srcs = [
+ "correlation_matrix_volumes_lib.py",
+ ],
+ deps = [
+ "//tensorflow/contrib/distributions:distributions_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_binary(
+ name = "correlation_matrix_volumes",
+ srcs = [
+ "correlation_matrix_volumes.py",
+ ],
+ deps = [
+ ":correlation_matrix_volumes_py",
+ ],
+)
+
+py_test(
+ name = "correlation_matrix_volumes_test",
+ size = "medium",
+ srcs = ["correlation_matrix_volumes_test.py"],
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+ deps = [
+ ":correlation_matrix_volumes_py",
+ # For statistical testing
+ "//tensorflow/contrib/distributions:distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ ],
+)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
new file mode 100644
index 0000000000..2eab51cd30
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
@@ -0,0 +1,98 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Executable to estimate the volume of various sets of correlation matrices.
+
+See correlation_matrix_volumes_lib.py for purpose and methodology.
+
+Invocation example:
+```
+python correlation_matrix_volumes.py --num_samples 1e7
+```
+
+This will compute 10,000,000-sample confidence intervals for the
+volumes of several sets of correlation matrices. Which sets, and the
+desired statistical significance, are hard-coded in this source file.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pprint
+
+from absl import app
+from absl import flags
+
+from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
+
+FLAGS = flags.FLAGS
+
+# Float to support giving the number of samples in scientific notation.
+# The production run used for the LKJ test used 1e7 samples.
+flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.')
+
+
+def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
+ # This wrapper undoes the batching in compute_true_volumes, because
+ # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM.
+ bounds = {}
+ for db in det_bounds:
+ bounds[db] = corr.compute_true_volumes(
+ [db], dim, num_samples, error_rate=error_rate, seed=seed)[db]
+ return bounds
+
+
+# The particular bounds in all three of these functions were chosen by
+# a somewhat arbitrary walk through an empirical tradeoff, for the
+# purpose of testing the LKJ distribution. Setting the determinant
+# bound lower
+# - Covers more of the testee's sample space, and
+# - Increases the probability that the rejection sampler will hit, thus
+# - Decreases the relative error (at a fixed sample count) in the
+# rejection-based volume estimate;
+# but also
+# - Increases the variance of the estimator used in the LKJ test.
+# This latter variance is also affected by the dimension and the
+# tested concentration parameter, and can be compensated for with more
+# compute (expensive) or a looser discrepancy limit (unsatisfying).
+# The values here are the projection of the points in that test design
+# space that ended up getting chosen.
+def compute_3x3_volumes(num_samples):
+ det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
+ return ctv_debatched(
+ det_bounds, 3, num_samples, error_rate=5e-7, seed=46)
+
+
+def compute_4x4_volumes(num_samples):
+ det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
+ return ctv_debatched(
+ det_bounds, 4, num_samples, error_rate=5e-7, seed=47)
+
+
+def compute_5x5_volumes(num_samples):
+ det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4]
+ return ctv_debatched(
+ det_bounds, 5, num_samples, error_rate=5e-7, seed=48)
+
+
+def main(_):
+ full_bounds = {}
+ full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples))
+ full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples))
+ full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples))
+ pprint.pprint(full_bounds)
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
new file mode 100644
index 0000000000..455e71f00c
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
@@ -0,0 +1,323 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Estimating the volume of the correlation matrices with bounded determinant.
+
+Why? Because lkj_test.py tests the sampler for the LKJ distribution
+by estimating the same volume another way.
+
+How? Rejection sampling. Or, more precisely, importance sampling,
+proposing from the uniform distribution on symmetric matrices with
+diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation
+matrix if and only if it is also positive semi-definite.
+
+The samples can then be converted into a confidence interval on the
+volume in question by the [Clopper-Pearson
+method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval),
+also implemented here.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import importlib
+import sys
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import uniform
+from tensorflow.python.ops.distributions import util
+from tensorflow.python.platform import tf_logging
+
+__all__ = [
+ "correlation_matrix_volume_rejection_samples",
+ "compute_true_volumes",
+]
+
+
+def try_import(name): # pylint: disable=invalid-name
+ module = None
+ try:
+ module = importlib.import_module(name)
+ except ImportError as e:
+ tf_logging.warning("Could not import %s: %s" % (name, str(e)))
+ return module
+
+optimize = try_import("scipy.optimize")
+stats = try_import("scipy.stats")
+
+
+def _psd_mask(x):
+ """Computes whether each square matrix in the input is positive semi-definite.
+
+ Args:
+ x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
+
+ Returns:
+ mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each
+ scalar is 1 if the corresponding matrix was PSD, otherwise 0.
+ """
+ # Allegedly
+ # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite
+ # it is more efficient to test for positive semi-definiteness by
+ # trying to compute the Cholesky decomposition -- the matrix is PSD
+ # if you succeed and not PSD if you fail. However, TensorFlow's
+ # Cholesky raises an exception if _any_ of the input matrices are
+ # not PSD, from which I don't know how to extract _which ones_, so I
+ # proceed by explicitly computing all the eigenvalues and checking
+ # whether they are all positive or not.
+ #
+ # Also, as was discussed in the answer, it is somewhat dangerous to
+ # treat SPD-ness as binary in floating-point arithmetic. Cholesky
+ # factorization can complete and 'look' like everything is fine
+ # (e.g., O(1) entries and a diagonal of all ones) but the matrix can
+ # have an exponential condition number.
+ eigenvalues, _ = linalg_ops.self_adjoint_eig(x)
+ return math_ops.cast(
+ math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype)
+
+
+def _det_large_enough_mask(x, det_bounds):
+ """Returns whether the input matches the given determinant limit.
+
+ Args:
+ x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
+ det_bounds: A floating-point `Tensor` that must broadcast to shape
+ `[B1, ..., Bn]`, giving the desired lower bound on the
+ determinants in `x`.
+
+ Returns:
+ mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each
+ scalar is 1 if the corresponding matrix had determinant above
+ the corresponding bound, otherwise 0.
+ """
+ # For the curious: I wonder whether it is possible and desirable to
+ # use a Cholesky decomposition-based algorithm for this, since the
+ # only matrices whose determinant this code cares about will be PSD.
+ # Didn't figure out how to code that in TensorFlow.
+ #
+ # Expert opinion is that it would be about twice as fast since
+ # Cholesky is roughly half the cost of Gaussian Elimination with
+ # Partial Pivoting. But this is less of an impact than the switch in
+ # _psd_mask.
+ return math_ops.cast(
+ linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype)
+
+
+def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
+ """Returns a uniformly random `Tensor` of "correlation-like" matrices.
+
+ A "correlation-like" matrix is a symmetric square matrix with all entries
+ between -1 and 1 (inclusive) and 1s on the main diagonal. Of these,
+ the ones that are positive semi-definite are exactly the correlation
+ matrices.
+
+ Args:
+ num_rows: Python `int` dimension of the correlation-like matrices.
+ batch_shape: `Tensor` or Python `tuple` of `int` shape of the
+ batch to return.
+ dtype: `dtype` of the `Tensor` to return.
+ seed: Random seed.
+
+ Returns:
+ matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]`
+ and dtype `dtype`. Each entry is in [-1, 1], and each matrix
+ along the bottom two dimensions is symmetric and has 1s on the
+ main diagonal.
+ """
+ num_entries = num_rows * (num_rows + 1) / 2
+ ones = array_ops.ones(shape=[num_entries], dtype=dtype)
+ # It seems wasteful to generate random values for the diagonal since
+ # I am going to throw them away, but `fill_triangular` fills the
+ # diagonal, so I probably need them.
+ # It's not impossible that it would be more efficient to just fill
+ # the whole matrix with random values instead of messing with
+ # `fill_triangular`. Then would need to filter almost half out with
+ # `matrix_band_part`.
+ unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed)
+ tril = util.fill_triangular(unifs)
+ symmetric = tril + array_ops.matrix_transpose(tril)
+ diagonal_ones = array_ops.ones(
+ shape=util.pad(batch_shape, axis=0, back=True, value=num_rows),
+ dtype=dtype)
+ return array_ops.matrix_set_diag(symmetric, diagonal_ones)
+
+
+def correlation_matrix_volume_rejection_samples(
+ det_bounds, dim, sample_shape, dtype, seed):
+ """Returns rejection samples from trying to get good correlation matrices.
+
+ The proposal being rejected from is the uniform distribution on
+ "correlation-like" matrices. We say a matrix is "correlation-like"
+ if it is a symmetric square matrix with all entries between -1 and 1
+ (inclusive) and 1s on the main diagonal. Of these, the ones that
+ are positive semi-definite are exactly the correlation matrices.
+
+ The rejection algorithm, then, is to sample a `Tensor` of
+ `sample_shape` correlation-like matrices of dimensions `dim` by
+ `dim`, and check each one for (i) being a correlation matrix (i.e.,
+ PSD), and (ii) having determinant at least the corresponding entry
+ of `det_bounds`.
+
+ Args:
+ det_bounds: A `Tensor` of lower bounds on the determinants of
+ acceptable matrices. The shape must broadcast with `sample_shape`.
+ dim: A Python `int` dimension of correlation matrices to sample.
+ sample_shape: Python `tuple` of `int` shape of the samples to
+ compute, excluding the two matrix dimensions.
+ dtype: The `dtype` in which to do the computation.
+ seed: Random seed.
+
+ Returns:
+ weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the
+ corresponding matrix was not a correlation matrix, or had too
+ small of a determinant. Otherwise, the entry is the
+ multiplicative inverse of the density of proposing that matrix
+ uniformly, i.e., the volume of the set of `dim` by `dim`
+ correlation-like matrices.
+ volume: The volume of the set of `dim` by `dim` correlation-like
+ matrices.
+ """
+ with ops.name_scope("rejection_sampler"):
+ rej_proposals = _uniform_correlation_like_matrix(
+ dim, sample_shape, dtype, seed=seed)
+ rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.)
+ # The density of proposing any given point is 1 / rej_proposal_volume;
+ # The weight of that point should be scaled by
+ # 1 / density = rej_proposal_volume.
+ rej_weights = rej_proposal_volume * _psd_mask(
+ rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds)
+ return rej_weights, rej_proposal_volume
+
+
+def _clopper_pearson_confidence_interval(samples, error_rate):
+ """Computes a confidence interval for the mean of the given 1-D distribution.
+
+ Assumes (and checks) that the given distribution is Bernoulli, i.e.,
+ takes only two values. This licenses using the CDF of the binomial
+ distribution for the confidence, which is tighter (for extreme
+ probabilities) than the DKWM inequality. The method is known as the
+ [Clopper-Pearson method]
+ (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
+
+ Assumes:
+
+ - The given samples were drawn iid from the distribution of interest.
+
+ - The given distribution is a Bernoulli, i.e., supported only on
+ low and high.
+
+ Guarantees:
+
+ - The probability (over the randomness of drawing the given sample)
+ that the true mean is outside the returned interval is no more
+ than the given error_rate.
+
+ Args:
+ samples: `np.ndarray` of samples drawn iid from the distribution
+ of interest.
+ error_rate: Python `float` admissible rate of mistakes.
+
+ Returns:
+ low: Lower bound of confidence interval.
+ high: Upper bound of confidence interval.
+
+ Raises:
+ ValueError: If `samples` has rank other than 1 (batch semantics
+ are not implemented), or if `samples` contains values other than
+ `low` or `high` (as that makes the distribution not Bernoulli).
+ """
+ # TODO(b/78025336) Migrate this confidence interval function
+ # to statistical_testing.py. In order to do that
+ # - Get the binomial CDF from the Binomial distribution
+ # - Implement scalar root finding in TF. Batch bisection search
+ # shouldn't be too hard, and is definitely good enough for this
+ # problem. Batching the Brent algorithm (from scipy) that is used
+ # here may be more involved, but may also not be necessary---it's
+ # only used here because scipy made it convenient. In particular,
+ # robustness is more important than speed here, which may make
+ # bisection search actively better.
+ # - The rest is just a matter of rewriting in the appropriate style.
+ if optimize is None or stats is None:
+ raise ValueError(
+ "Scipy is required for computing Clopper-Pearson confidence intervals")
+ if len(samples.shape) != 1:
+ raise ValueError("Batch semantics not implemented")
+ n = len(samples)
+ low = np.amin(samples)
+ high = np.amax(samples)
+ successes = np.count_nonzero(samples - low)
+ failures = np.count_nonzero(samples - high)
+ if successes + failures != n:
+ uniques = np.unique(samples)
+ msg = ("Purportedly Bernoulli distribution had distinct samples"
+ " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2]))
+ raise ValueError(msg)
+ def p_small_enough(p):
+ prob = stats.binom.logcdf(successes, n, p)
+ return prob - np.log(error_rate / 2.)
+ def p_big_enough(p):
+ prob = stats.binom.logsf(successes, n, p)
+ return prob - np.log(error_rate / 2.)
+ high_p = optimize.brentq(
+ p_small_enough, float(successes) / n, 1., rtol=1e-9)
+ low_p = optimize.brentq(
+ p_big_enough, 0., float(successes) / n, rtol=1e-9)
+ low_interval = low + (high - low) * low_p
+ high_interval = low + (high - low) * high_p
+ return (low_interval, high_interval)
+
+
+def compute_true_volumes(
+ det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
+ """Returns confidence intervals for the desired correlation matrix volumes.
+
+ The confidence intervals are computed by the [Clopper-Pearson method]
+ (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
+
+ Args:
+ det_bounds: A rank-1 numpy array of lower bounds on the
+ determinants of acceptable matrices. Entries must be unique.
+ dim: A Python `int` dimension of correlation matrices to sample.
+ num_samples: The number of samples to draw.
+ error_rate: The statistical significance of the returned
+ confidence intervals. The significance is broadcast: Each
+ returned interval separately may be incorrect with probability
+ (under the sample of correlation-like matrices drawn internally)
+ at most `error_rate`.
+ seed: Random seed.
+
+ Returns:
+ bounds: A Python `dict` mapping each determinant bound to the low, high
+ tuple giving the confidence interval.
+ """
+ bounds = {}
+ with session.Session() as sess:
+ rej_weights, _ = correlation_matrix_volume_rejection_samples(
+ det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed)
+ rej_weights = sess.run(rej_weights)
+ for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds):
+ template = ("Estimating volume of {}x{} correlation "
+ "matrices with determinant >= {}.")
+ print(template.format(dim, dim, det))
+ sys.stdout.flush()
+ bounds[det] = _clopper_pearson_confidence_interval(
+ rw, error_rate=error_rate)
+ return bounds
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
new file mode 100644
index 0000000000..8f99300e63
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
@@ -0,0 +1,150 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for correlation_matrix_volumes_lib.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
+from tensorflow.contrib.distributions.python.ops import statistical_testing as st
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.platform import test
+
+
+# NxN correlation matrices are determined by the N*(N-1)/2
+# lower-triangular entries. In addition to being between -1 and 1,
+# they must also obey the constraint that the determinant of the
+# resulting symmetric matrix is non-negative. In 2x2, we can even
+# analytically compute the volume when the determinant is bounded to >
+# epsilon, as that boils down to the one lower-triangular entry being
+# less than 1 - epsilon in absolute value.
+def two_by_two_volume(det_bound):
+ return 2 * np.sqrt(1.0 - det_bound)
+
+
+# The post
+# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/
+# derives (with elementary calculus) that the volume (with respect to
+# Lebesgue^3 measure) of the set of 3x3 correlation matrices is
+# pi^2/2. The same result is also obtained by [1].
+def three_by_three_volume():
+ return np.pi**2 / 2.
+
+
+# The volume of the unconstrained set of correlation matrices is also
+# the normalization constant of the LKJ distribution from [2]. As
+# part of defining the distribution, that reference a derives general
+# formula for this volume for all dimensions. A TensorFlow
+# computation thereof gave the below result for 4x4:
+def four_by_four_volume():
+ # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0]))
+ return 11.6973076
+
+# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of
+# correlation matrices." The American Statistician, 48(4), 276-279.
+
+# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating
+# random correlation matrices based on vines and extended onion
+# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001.
+
+
+class CorrelationMatrixVolumesTest(test.TestCase):
+
+ def testRejection2D(self):
+ num_samples = int(1e5) # Chosen for a small min detectable discrepancy
+ det_bounds = np.array(
+ [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
+ exact_volumes = two_by_two_volume(det_bounds)
+ (rej_weights,
+ rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+ det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
+ # shape of rej_weights: [num_samples, 9, 2, 2]
+ chk1 = st.assert_true_mean_equal_by_dkwm(
+ rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+ false_fail_rate=1e-6)
+ chk2 = check_ops.assert_less(
+ st.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=rej_proposal_volume,
+ # Correct the false fail rate due to different broadcasting
+ false_fail_rate=1.1e-7, false_pass_rate=1e-6),
+ 0.036)
+ with ops.control_dependencies([chk1, chk2]):
+ rej_weights = array_ops.identity(rej_weights)
+ self.evaluate(rej_weights)
+
+ def testRejection3D(self):
+ num_samples = int(1e5) # Chosen for a small min detectable discrepancy
+ det_bounds = np.array([0.0], dtype=np.float32)
+ exact_volumes = np.array([three_by_three_volume()], dtype=np.float32)
+ (rej_weights,
+ rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+ det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44)
+ # shape of rej_weights: [num_samples, 1, 3, 3]
+ chk1 = st.assert_true_mean_equal_by_dkwm(
+ rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+ false_fail_rate=1e-6)
+ chk2 = check_ops.assert_less(
+ st.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=rej_proposal_volume,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ # Going for about a 3% relative error
+ 0.15)
+ with ops.control_dependencies([chk1, chk2]):
+ rej_weights = array_ops.identity(rej_weights)
+ self.evaluate(rej_weights)
+
+ def testRejection4D(self):
+ num_samples = int(1e5) # Chosen for a small min detectable discrepancy
+ det_bounds = np.array([0.0], dtype=np.float32)
+ exact_volumes = [four_by_four_volume()]
+ (rej_weights,
+ rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
+ det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
+ # shape of rej_weights: [num_samples, 1, 4, 4]
+ chk1 = st.assert_true_mean_equal_by_dkwm(
+ rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
+ false_fail_rate=1e-6)
+ chk2 = check_ops.assert_less(
+ st.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=rej_proposal_volume,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ # Going for about a 10% relative error
+ 1.1)
+ with ops.control_dependencies([chk1, chk2]):
+ rej_weights = array_ops.identity(rej_weights)
+ self.evaluate(rej_weights)
+
+ def testVolumeEstimation2D(self):
+ # Test that the confidence intervals produced by
+ # corr.compte_true_volumes are sound, in the sense of containing
+ # the exact volume.
+ num_samples = int(1e5) # Chosen by symmetry with testRejection2D
+ det_bounds = np.array(
+ [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
+ volume_bounds = corr.compute_true_volumes(
+ det_bounds, 2, num_samples, error_rate=1e-6, seed=47)
+ exact_volumes = two_by_two_volume(det_bounds)
+ for det, volume in zip(det_bounds, exact_volumes):
+ computed_low, computed_high = volume_bounds[det]
+ self.assertLess(computed_low, volume)
+ self.assertGreater(computed_high, volume)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index 11ca90c483..bb9b8043b2 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class Autoregressive(distribution_lib.Distribution):
@@ -107,6 +108,14 @@ class Autoregressive(distribution_lib.Distribution):
https://arxiv.org/abs/1606.05328
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
distribution_fn,
sample0=None,
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 4714caad69..519077bc9a 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.util import deprecation
__all__ = [
@@ -71,6 +72,14 @@ class BatchReshape(distribution_lib.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
distribution,
batch_shape,
@@ -352,6 +361,14 @@ class BatchReshape(distribution_lib.Distribution):
return runtime_assertions
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
"""Calculates the reshaped dimensions (replacing up to one -1 in reshape)."""
batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
@@ -384,6 +401,14 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None):
return expanded_new_shape, batch_shape_static, validations
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def validate_init_args_statically(distribution, batch_shape):
"""Helper to __init__ which makes or raises assertions."""
if batch_shape.shape.ndims is not None:
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
index 4965381ef3..e141f8b5c6 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py
@@ -24,6 +24,7 @@
@@CholeskyOuterProduct
@@ConditionalBijector
@@Exp
+@@FillTriangular
@@Gumbel
@@Identity
@@Inline
@@ -36,12 +37,14 @@
@@PowerTransform
@@RealNVP
@@Reshape
+@@ScaleTriL
@@Sigmoid
@@SinhArcsinh
@@SoftmaxCentered
@@Softplus
@@Softsign
@@Square
+@@TransformDiagonal
@@Weibull
@@masked_autoregressive_default_template
@@ -64,6 +67,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
+from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import *
from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
@@ -75,12 +79,14 @@ from tensorflow.contrib.distributions.python.ops.bijectors.permute import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import *
from tensorflow.contrib.distributions.python.ops.bijectors.reshape import *
+from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import *
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import *
from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import *
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.contrib.distributions.python.ops.bijectors.softsign import *
from tensorflow.contrib.distributions.python.ops.bijectors.square import *
+from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import *
from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py
index c9e31d7712..4d6a46e735 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py
@@ -23,6 +23,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
"AbsoluteValue",
@@ -70,6 +71,14 @@ class AbsoluteValue(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="absolute_value"):
"""Instantiates the `AbsoluteValue` bijector.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
index b4c2939eb9..25f29452c3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -36,6 +37,14 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _as_tensor(x, name):
"""Convenience to convert to `Tensor` or leave as `None`."""
return None if x is None else ops.convert_to_tensor(x, name=name)
@@ -97,6 +106,14 @@ class Affine(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
shift=None,
scale_identity_multiplier=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py
index 59f9742d57..91301f15ad 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.util import deprecation
__all__ = [
@@ -88,6 +89,14 @@ class AffineLinearOperator(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
shift=None,
scale=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
index cd792e2c8c..460d906231 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -52,6 +53,14 @@ class AffineScalar(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
shift=None,
scale=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py
index 224cec8a63..f19f147dd6 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -34,6 +35,14 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _undo_batch_normalization(x,
mean,
variance,
@@ -128,6 +137,14 @@ class BatchNormalization(bijector.Bijector):
Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
batchnorm_layer=None,
training=True,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
index 16f959560c..910774ea5b 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -31,10 +32,26 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _use_static_shape(input_tensor, ndims):
return input_tensor.shape.is_fully_defined() and isinstance(ndims, int)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _compute_min_event_ndims(bijector_list, compute_forward=True):
"""Computes the min_event_ndims associated with the give list of bijectors.
@@ -142,6 +159,14 @@ class Chain(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, bijectors=None, validate_args=False, name=None):
"""Instantiates `Chain` bijector.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
index 268c8d0342..3e1e4fc829 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
@@ -69,6 +70,14 @@ class CholeskyOuterProduct(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="cholesky_outer_product"):
"""Instantiates the `CholeskyOuterProduct` bijector.
@@ -173,7 +182,20 @@ class CholeskyOuterProduct(bijector.Bijector):
axis=-1)
fldj = p_float * np.log(2.) + sum_weighted_log_diag
- return fldj
+ # We finally need to undo adding an extra column in non-scalar cases
+ # where there is a single matrix as input.
+ if x.get_shape().ndims is not None:
+ if x.get_shape().ndims == 2:
+ fldj = array_ops.squeeze(fldj, axis=-1)
+ return fldj
+
+ shape = array_ops.shape(fldj)
+ maybe_squeeze_shape = array_ops.concat([
+ shape[:-1],
+ distribution_util.pick_vector(
+ math_ops.equal(array_ops.rank(x), 2),
+ np.array([], dtype=np.int32), shape[-1:])], 0)
+ return array_ops.reshape(fldj, maybe_squeeze_shape)
def _make_columnar(self, x):
"""Ensures non-scalar input has at least one column.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py
index 9fc1bbf052..07627e1e45 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.bijectors import power_transform
+from tensorflow.python.util import deprecation
__all__ = [
@@ -47,6 +48,14 @@ class Exp(power_transform.PowerTransform):
over the event space.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
validate_args=False,
name="exp"):
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
new file mode 100644
index 0000000000..31a9ca27e5
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py
@@ -0,0 +1,165 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""FillTriangular bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.distributions import util as dist_util
+from tensorflow.python.util import deprecation
+
+
+__all__ = [
+ "FillTriangular",
+]
+
+
+class FillTriangular(bijector.Bijector):
+ """Transforms vectors to triangular.
+
+ Triangular matrix elements are filled in a clockwise spiral.
+
+ Given input with shape `batch_shape + [d]`, produces output with
+ shape `batch_shape + [n, n]`, where
+ `n = (-1 + sqrt(1 + 8 * d))/2`.
+ This follows by solving the quadratic equation
+ `d = 1 + 2 + ... + n = n * (n + 1)/2`.
+
+ #### Example
+
+ ```python
+ b = tfb.FillTriangular(upper=False)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[4, 0, 0],
+ # [6, 5, 0],
+ # [3, 2, 1]]
+
+ b = tfb.FillTriangular(upper=True)
+ b.forward([1, 2, 3, 4, 5, 6])
+ # ==> [[1, 2, 3],
+ # [0, 5, 6],
+ # [0, 0, 4]]
+
+ ```
+ """
+
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
+ def __init__(self,
+ upper=False,
+ validate_args=False,
+ name="fill_triangular"):
+ """Instantiates the `FillTriangular` bijector.
+
+ Args:
+ upper: Python `bool` representing whether output matrix should be upper
+ triangular (`True`) or lower triangular (`False`, default).
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._upper = upper
+ super(FillTriangular, self).__init__(
+ forward_min_event_ndims=1,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ return dist_util.fill_triangular(x, upper=self._upper)
+
+ def _inverse(self, y):
+ return dist_util.fill_triangular_inverse(y, upper=self._upper)
+
+ def _forward_log_det_jacobian(self, x):
+ return array_ops.zeros_like(x[..., 0])
+
+ def _inverse_log_det_jacobian(self, y):
+ return array_ops.zeros_like(y[..., 0, 0])
+
+ def _forward_event_shape(self, input_shape):
+ batch_shape, d = input_shape[:-1], input_shape[-1].value
+ if d is None:
+ n = None
+ else:
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return batch_shape.concatenate([n, n])
+
+ def _inverse_event_shape(self, output_shape):
+ batch_shape, n1, n2 = (output_shape[:-2],
+ output_shape[-2].value,
+ output_shape[-1].value)
+ if n1 is None or n2 is None:
+ m = None
+ elif n1 != n2:
+ raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2))
+ else:
+ m = n1 * (n1 + 1) / 2
+ return batch_shape.concatenate([m])
+
+ def _forward_event_shape_tensor(self, input_shape_tensor):
+ batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1]
+ n = vector_size_to_square_matrix_size(d, self.validate_args)
+ return array_ops.concat([batch_shape, [n, n]], axis=0)
+
+ def _inverse_event_shape_tensor(self, output_shape_tensor):
+ batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1]
+ if self.validate_args:
+ is_square_matrix = check_ops.assert_equal(
+ n, output_shape_tensor[-2], message="Matrix must be square.")
+ with ops.control_dependencies([is_square_matrix]):
+ n = array_ops.identity(n)
+ d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype)
+ return array_ops.concat([batch_shape, [d]], axis=0)
+
+
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
+def vector_size_to_square_matrix_size(d, validate_args, name=None):
+ """Convert a vector size to a matrix size."""
+ if isinstance(d, (float, int, np.generic, np.ndarray)):
+ n = (-1 + np.sqrt(1 + 8 * d)) / 2.
+ if float(int(n)) != n:
+ raise ValueError("Vector length is not a triangular number.")
+ return int(n)
+ else:
+ with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name:
+ n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2.
+ if validate_args:
+ with ops.control_dependencies([check_ops.assert_equal(
+ math_ops.to_float(math_ops.to_int32(n)), n,
+ message="Vector length is not a triangular number")]):
+ n = array_ops.identity(n)
+ return math_ops.cast(n, d.dtype)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
index e656a258e5..71e562a927 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
"Gumbel",
@@ -45,6 +46,14 @@ class Gumbel(bijector.Bijector):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=0.,
scale=1.,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py
index 2bde956d13..1504bd2720 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -43,6 +44,14 @@ class Inline(bijector.Bijector):
The above example is equivalent to the `Bijector` `Exp()`.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
forward_fn=None,
inverse_fn=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
index 84a3289ba2..a648676d4b 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
"Invert",
@@ -40,6 +41,14 @@ class Invert(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, bijector, validate_args=False, name=None):
"""Creates a `Bijector` which swaps the meaning of `inverse` and `forward`.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
index 97000c1726..33b75a04d3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
"Kumaraswamy",
@@ -44,6 +45,14 @@ class Kumaraswamy(bijector.Bijector):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
concentration1=None,
concentration0=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 83667b0e80..b8f2a4b2c7 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
from tensorflow.python.ops import variable_scope as variable_scope_lib
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -186,6 +187,14 @@ class MaskedAutoregressiveFlow(bijector.Bijector):
Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
shift_and_log_scale_fn,
is_constant_jacobian=False,
@@ -296,6 +305,14 @@ MASK_INCLUSIVE = "inclusive"
MASK_EXCLUSIVE = "exclusive"
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE):
"""Generate the slices for building an autoregressive mask."""
# TODO(b/67594795): Better support of dynamic shape.
@@ -313,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE):
return slices
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _gen_mask(num_blocks,
n_in,
n_out,
@@ -327,6 +352,14 @@ def _gen_mask(num_blocks,
return mask
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def masked_dense(inputs,
units,
num_blocks=None,
@@ -399,6 +432,14 @@ def masked_dense(inputs,
return layer.apply(inputs)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def masked_autoregressive_default_template(
hidden_layers,
shift_only=False,
@@ -515,6 +556,14 @@ def masked_autoregressive_default_template(
"masked_autoregressive_default_template", _fn)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None):
"""Clips input while leaving gradient unaltered."""
with ops.name_scope(name, "clip_by_value_preserve_grad",
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
index 71903f7052..49e6192f06 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -55,6 +56,14 @@ class MatrixInverseTriL(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="matrix_inverse_tril"):
"""Instantiates the `MatrixInverseTriL` bijector.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
index 3f03592f31..fb393218b6 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -57,6 +58,14 @@ class Ordered(bijector.Bijector):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="ordered"):
super(Ordered, self).__init__(
forward_min_event_ndims=1,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index 12a16a3f2b..f182a1adcb 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -74,6 +75,14 @@ class Permute(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, permutation, validate_args=False, name=None):
"""Creates the `Permute` bijector.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py
index 71f123f2a9..16264fe728 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -41,6 +42,14 @@ class PowerTransform(bijector.Bijector):
This bijector is equivalent to the `Exp` bijector when `c=0`.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
power=0.,
validate_args=False,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index 66e8a5b9b3..773ae24461 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import template as template_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -126,6 +127,14 @@ class RealNVP(bijector.Bijector):
Processing Systems_, 2017. https://arxiv.org/abs/1705.07057
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
num_masked,
shift_and_log_scale_fn,
@@ -228,6 +237,14 @@ class RealNVP(bijector.Bijector):
return math_ops.reduce_sum(log_scale, axis=-1)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def real_nvp_default_template(
hidden_layers,
shift_only=False,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index 5497c422e4..c8282229a3 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -36,10 +37,26 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _static_ndims_from_shape(shape):
return shape.shape.with_rank_at_least(1)[0].value
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _ndims_from_shape(shape):
return array_ops.shape(shape)[0]
@@ -86,6 +103,14 @@ class Reshape(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, event_shape_out, event_shape_in=(-1,),
validate_args=False, name=None):
"""Creates a `Reshape` bijector.
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
new file mode 100644
index 0000000000..6fbe866578
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ScaleTriL bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar
+from tensorflow.contrib.distributions.python.ops.bijectors import chain
+from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular
+from tensorflow.contrib.distributions.python.ops.bijectors import softplus
+from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal
+from tensorflow.python.util import deprecation
+
+__all__ = [
+ "ScaleTriL",
+]
+
+
+class ScaleTriL(chain.Chain):
+ """Transforms unconstrained vectors to TriL matrices with positive diagonal.
+
+ This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular`
+ followed by `tfb.TransformDiagonal`, and provided mostly as a
+ convenience. The default setup is somewhat opinionated, using a
+ Softplus transformation followed by a small shift (`1e-5`) which
+ attempts to avoid numerical issues from zeros on the diagonal.
+
+ #### Examples
+
+ ```python
+ tfb = tf.contrib.distributions.bijectors
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Exp(),
+ diag_shift=None)
+ b.forward(x=[0., 0., 0.])
+ # Result: [[1., 0.],
+ # [0., 1.]]
+ b.inverse(y=[[1., 0],
+ [.5, 2]])
+ # Result: [log(2), .5, log(1)]
+
+ # Define a distribution over PSD matrices of shape `[3, 3]`,
+ # with `1 + 2 + 3 = 6` degrees of freedom.
+ dist = tfd.TransformedDistribution(
+ tfd.Normal(tf.zeros(6), tf.ones(6)),
+ tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()]))
+
+ # Using an identity transformation, ScaleTriL is equivalent to
+ # tfb.FillTriangular.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Identity(),
+ diag_shift=None)
+
+ # For greater control over initialization, one can manually encode
+ # pre- and post- shifts inside of `diag_bijector`.
+ b = tfb.ScaleTriL(
+ diag_bijector=tfb.Chain([
+ tfb.AffineScalar(shift=1e-3),
+ tfb.Softplus(),
+ tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.)
+ # = log(expm1(1.)) = 0.5413
+ diag_shift=None)
+ ```
+ """
+
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
+ def __init__(self,
+ diag_bijector=None,
+ diag_shift=1e-5,
+ validate_args=False,
+ name="scale_tril"):
+ """Instantiates the `ScaleTriL` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance, used to transform the output diagonal
+ to be positive.
+ Default value: `None` (i.e., `tfb.Softplus()`).
+ diag_shift: Float value broadcastable and added to all diagonal entries
+ after applying the `diag_bijector`. Setting a positive
+ value forces the output diagonal entries to be positive, but
+ prevents inverting the transformation for matrices with
+ diagonal entries less than this value.
+ Default value: `1e-5` (i.e., no shift is applied).
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ Default value: `False` (i.e., arguments are not validated).
+ name: Python `str` name given to ops managed by this object.
+ Default value: `scale_tril`.
+ """
+
+ if diag_bijector is None:
+ diag_bijector = softplus.Softplus(validate_args=validate_args)
+
+ if diag_shift is not None:
+ diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift),
+ diag_bijector])
+
+ super(ScaleTriL, self).__init__(
+ [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
+ fill_triangular.FillTriangular()],
+ validate_args=validate_args,
+ name=name)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py
index 5df8c88631..194b318fce 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -31,6 +32,14 @@ __all__ = [
class Sigmoid(bijector.Bijector):
"""Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="sigmoid"):
super(Sigmoid, self).__init__(
forward_min_event_ndims=0,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
index 2a32e8abcd..241fba2cb7 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py
@@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
"SinhArcsinh",
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _sqrtx2p1(x):
"""Implementation of `sqrt(1 + x**2)` which is stable despite large `x`."""
return array_ops.where(
@@ -88,6 +97,14 @@ class SinhArcsinh(bijector.Bijector):
`Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
skewness=None,
tailweight=None,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
index f52b91550e..20ee0d3408 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -60,6 +61,14 @@ class SoftmaxCentered(bijector.Bijector):
makes the (forward) image non-open and the theorem does not directly apply.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
validate_args=False,
name="softmax_centered"):
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py
index 96a938c803..3df84ef8b0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import bijector
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
@@ -80,6 +81,14 @@ class Softplus(bijector.Bijector):
"hinge_softness": (
"Nonzero floating point `Tensor`. Controls the softness of what "
"would otherwise be a kink at the origin. Default is 1.0")})
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
hinge_softness=None,
validate_args=False,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py
index b4a658c171..f96a4bb01d 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py
@@ -22,6 +22,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -51,6 +52,14 @@ class Softsign(bijector.Bijector):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="softsign"):
super(Softsign, self).__init__(
forward_min_event_ndims=0,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py
index 2ccfdc9597..294460a80f 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -49,6 +50,14 @@ class Square(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="square"):
"""Instantiates the `Square` bijector.
@@ -81,4 +90,3 @@ class Square(bijector.Bijector):
is_valid = check_ops.assert_non_negative(
t, message="All elements must be non-negative.")
return control_flow_ops.with_dependencies([is_valid], t)
-
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
new file mode 100644
index 0000000000..9b7a3b026b
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py
@@ -0,0 +1,111 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TransformDiagonal bijector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
+
+__all__ = [
+ "TransformDiagonal",
+]
+
+
+class TransformDiagonal(bijector.Bijector):
+ """Applies a Bijector to the diagonal of a matrix.
+
+ #### Example
+
+ ```python
+ b = tfb.TransformDiagonal(diag_bijector=tfb.Exp())
+
+ b.forward([[1., 0.],
+ [0., 1.]])
+ # ==> [[2.718, 0.],
+ [0., 2.718]]
+ ```
+
+ """
+
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
+ def __init__(self,
+ diag_bijector,
+ validate_args=False,
+ name="transform_diagonal"):
+ """Instantiates the `TransformDiagonal` bijector.
+
+ Args:
+ diag_bijector: `Bijector` instance used to transform the diagonal.
+ validate_args: Python `bool` indicating whether arguments should be
+ checked for correctness.
+ name: Python `str` name given to ops managed by this object.
+ """
+ self._diag_bijector = diag_bijector
+ super(TransformDiagonal, self).__init__(
+ forward_min_event_ndims=2,
+ inverse_min_event_ndims=2,
+ validate_args=validate_args,
+ name=name)
+
+ def _forward(self, x):
+ diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x))
+ return array_ops.matrix_set_diag(x, diag)
+
+ def _inverse(self, y):
+ diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y))
+ return array_ops.matrix_set_diag(y, diag)
+
+ def _forward_log_det_jacobian(self, x):
+ # We formulate the Jacobian with respect to the flattened matrices
+ # `vec(x)` and `vec(y)`. Suppose for notational convenience that
+ # the first `n` entries of `vec(x)` are the diagonal of `x`, and
+ # the remaining `n**2-n` entries are the off-diagonals in
+ # arbitrary order. Then the Jacobian is a block-diagonal matrix,
+ # with the Jacobian of the diagonal bijector in the first block,
+ # and the identity Jacobian for the remaining entries (since this
+ # bijector acts as the identity on non-diagonal entries):
+ #
+ # J_vec(x) (vec(y)) =
+ # -------------------------------
+ # | J_diag(x) (diag(y)) 0 | n entries
+ # | |
+ # | 0 I | n**2-n entries
+ # -------------------------------
+ # n n**2-n
+ #
+ # Since the log-det of the second (identity) block is zero, the
+ # overall log-det-jacobian is just the log-det of first block,
+ # from the diagonal bijector.
+ #
+ # Note that for elementwise operations (exp, softplus, etc) the
+ # first block of the Jacobian will itself be a diagonal matrix,
+ # but our implementation does not require this to be true.
+ return self._diag_bijector.forward_log_det_jacobian(
+ array_ops.matrix_diag_part(x), event_ndims=1)
+
+ def _inverse_log_det_jacobian(self, y):
+ return self._diag_bijector.inverse_log_det_jacobian(
+ array_ops.matrix_diag_part(y), event_ndims=1)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
index a22560fe80..8903a70d98 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -47,6 +48,14 @@ class Weibull(bijector.Bijector):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
scale=1.,
concentration=1.,
diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py
index e4944beedc..b349e5966d 100644
--- a/tensorflow/contrib/distributions/python/ops/binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/binomial.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
_binomial_sample_note = """
@@ -42,6 +43,14 @@ to integer values.
"""
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _bdtr(k, n, p):
"""The binomial cumulative distribution function.
@@ -130,6 +139,14 @@ class Binomial(distribution.Distribution):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
logits=None,
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index 23b6a83c17..cb5223b055 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
__all__ = [
"Cauchy",
@@ -92,6 +93,14 @@ class Cauchy(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py
index 686ae1ba74..e9a7b39070 100644
--- a/tensorflow/contrib/distributions/python/ops/chi2.py
+++ b/tensorflow/contrib/distributions/python/ops/chi2.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.util import deprecation
__all__ = [
@@ -63,6 +64,14 @@ class Chi2(gamma.Gamma):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
validate_args=False,
@@ -114,6 +123,14 @@ class Chi2(gamma.Gamma):
class Chi2WithAbsDf(Chi2):
"""Chi2 with parameter transform `df = floor(abs(df))`."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
validate_args=False,
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index c44c76a133..ad853ee293 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
__all__ = [
"Deterministic",
@@ -43,6 +44,14 @@ __all__ = [
class _BaseDeterministic(distribution.Distribution):
"""Base class for Deterministic distributions."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
atol=None,
@@ -203,6 +212,14 @@ class Deterministic(_BaseDeterministic):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
atol=None,
@@ -308,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
atol=None,
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 289e1d50e1..6959b3e877 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -21,12 +21,19 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
+
+# The following two lines are redundant, in a sense. The first enables
+# good coding practice *within* this file (`util.prefer_static_value`
+# rather than `prefer_static_value`). The second ensures that users
+# also get the core utils when they import this file.
+from tensorflow.python.ops.distributions import util
from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import
@@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
def static_value(x):
"""Returns the static value of a `Tensor` or `None`."""
return tensor_util.constant_value(ops.convert_to_tensor(x))
+
+
+def move_dimension(x, source_idx, dest_idx):
+ """Move a single tensor dimension within its shape.
+
+ This is a special case of `tf.transpose()`, which applies
+ arbitrary permutations to tensor dimensions.
+
+ Args:
+ x: Tensor of rank `ndims`.
+ source_idx: Integer index into `x.shape` (negative indexing is
+ supported).
+ dest_idx: Integer index into `x.shape` (negative indexing is
+ supported).
+
+ Returns:
+ x_perm: Tensor of rank `ndims`, in which the dimension at original
+ index `source_idx` has been moved to new index `dest_idx`, with
+ all other dimensions retained in their original order.
+
+ Example:
+
+ ```python
+ x = tf.placeholder(shape=[200, 30, 4, 1, 6])
+ x_perm = _move_dimension(x, 1, 1) # no-op
+ x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6]
+ x_perm = _move_dimension(x, 0, -2) # equivalent to previous
+ x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1]
+ ```
+ """
+ ndims = util.prefer_static_rank(x)
+ if isinstance(source_idx, int):
+ dtype = dtypes.int32
+ else:
+ dtype = dtypes.as_dtype(source_idx.dtype)
+
+ # Handle negative indexing. Since ndims might be dynamic, this makes
+ # source_idx and dest_idx also possibly dynamic.
+ if source_idx < 0:
+ source_idx = ndims + source_idx
+ if dest_idx < 0:
+ dest_idx = ndims + dest_idx
+
+ # Construct the appropriate permutation of dimensions, depending
+ # whether the source is before or after the destination.
+ def move_left_permutation():
+ return util.prefer_static_value(
+ array_ops.concat([
+ math_ops.range(0, dest_idx, dtype=dtype),
+ [source_idx],
+ math_ops.range(dest_idx, source_idx, dtype=dtype),
+ math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))
+
+ def move_right_permutation():
+ return util.prefer_static_value(
+ array_ops.concat([
+ math_ops.range(0, source_idx, dtype=dtype),
+ math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
+ [source_idx],
+ math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))
+
+ def x_permuted():
+ return array_ops.transpose(
+ x, perm=smart_cond.smart_cond(source_idx < dest_idx,
+ move_right_permutation,
+ move_left_permutation))
+
+ # One final conditional to handle the special case where source
+ # and destination indices are equal.
+ return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
+ lambda: x,
+ x_permuted)
diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py
index 98edd337fe..bdec6527d5 100644
--- a/tensorflow/contrib/distributions/python/ops/estimator.py
+++ b/tensorflow/contrib/distributions/python/ops/estimator.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHea
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.util import deprecation
__all__ = [
@@ -30,6 +31,14 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def estimator_head_distribution_regression(make_distribution_fn,
label_dimension=1,
logits_dimension=None,
@@ -77,6 +86,14 @@ def estimator_head_distribution_regression(make_distribution_fn,
class _DistributionRegressionHead(_RegressionHead):
"""Creates a _RegressionHead instance from an arbitrary `Distribution`."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
make_distribution_fn,
label_dimension,
diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py
index e1e42ee95d..d62f024aa2 100644
--- a/tensorflow/contrib/distributions/python/ops/geometric.py
+++ b/tensorflow/contrib/distributions/python/ops/geometric.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class Geometric(distribution.Distribution):
@@ -55,6 +56,14 @@ class Geometric(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
logits=None,
probs=None,
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index 9d94fd11c6..acdea4d61d 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
class _Gumbel(distribution.Distribution):
@@ -96,6 +97,14 @@ class _Gumbel(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index 9c96254d1c..b02c403106 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util import deprecation
__all__ = [
@@ -85,6 +86,14 @@ class HalfNormal(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
scale,
validate_args=False,
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index cd6eaa8407..0672702b96 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import kullback_leibler
+from tensorflow.python.util import deprecation
class Independent(distribution_lib.Distribution):
@@ -94,6 +95,14 @@ class Independent(distribution_lib.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(
self, distribution, reinterpreted_batch_ndims=None,
validate_args=False, name=None):
@@ -258,6 +267,14 @@ class Independent(distribution_lib.Distribution):
@kullback_leibler.RegisterKL(Independent, Independent)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _kl_independent(a, b, name="kl_independent"):
"""Batched KL divergence `KL(a || b)` for Independent distributions.
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 208057b34d..70d050d7a6 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
@@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
@@ -274,6 +283,14 @@ class InverseGamma(distribution.Distribution):
class InverseGammaWithSoftplusConcentrationRate(InverseGamma):
"""`InverseGamma` with softplus of `concentration` and `rate`."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
index 66682b2ff5..e3712dd84e 100644
--- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
+++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py
@@ -31,7 +31,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import uniform
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util import deprecation
__all__ = [
"Kumaraswamy",
@@ -41,6 +41,14 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in
`[0, 1].` It must have a shape compatible with `self.batch_shape()`."""
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _harmonic_number(x):
"""Compute the harmonic number from its analytic continuation.
@@ -59,7 +67,6 @@ def _harmonic_number(x):
return math_ops.digamma(x + one) - math_ops.digamma(one)
-@tf_export("distributions.Kumaraswamy")
class Kumaraswamy(transformed_distribution.TransformedDistribution):
"""Kumaraswamy distribution.
@@ -125,6 +132,14 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
concentration1=None,
concentration0=None,
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 27aa863440..02e3bad51e 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
class Logistic(distribution.Distribution):
@@ -91,6 +92,14 @@ class Logistic(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index bfb53a06c0..3b7114ef06 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class Mixture(distribution.Distribution):
@@ -66,6 +67,14 @@ class Mixture(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
cat,
components,
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 112eefd369..8ffee940d0 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class MixtureSameFamily(distribution.Distribution):
@@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
mixture_distribution,
components_distribution,
@@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution):
return x
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _outer_squared_difference(x, y):
"""Convenience function analogous to tf.squared_difference."""
z = x - y
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index d2beb2aff0..cd0c282ba6 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import deprecation
__all__ = [
@@ -134,6 +135,14 @@ class MultivariateNormalDiag(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_diag=None,
@@ -218,6 +227,14 @@ class MultivariateNormalDiag(
class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
"""MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`."""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale_diag,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 5117379b04..d8401801f2 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -22,6 +22,7 @@ from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
__all__ = [
@@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_diag=None,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 57f47db50c..dbc4c1b3dc 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -24,6 +24,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.util import deprecation
__all__ = [
@@ -112,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
covariance_matrix=None,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index 6a0383db02..efe5a6d0d9 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.util import deprecation
__all__ = [
@@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale=None,
@@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator(
@kullback_leibler.RegisterKL(MultivariateNormalLinearOperator,
MultivariateNormalLinearOperator)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _kl_brute_force(a, b, name=None):
"""Batched KL divergence `KL(a || b)` for multivariate Normals.
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index c809ef3c1c..d9110947ec 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -22,6 +22,7 @@ from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
@@ -134,6 +135,14 @@ class MultivariateNormalTriL(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_tril=None,
diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
index 2bd11e24b3..6acfc5746a 100644
--- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py
+++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class NegativeBinomial(distribution.Distribution):
@@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution):
* `n!` is the factorial of `n`.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
logits=None,
diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
index 3e44c10fab..214c6dca4a 100644
--- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class OneHotCategorical(distribution.Distribution):
@@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(
self,
logits=None,
@@ -226,13 +235,21 @@ class OneHotCategorical(distribution.Distribution):
return x
return control_flow_ops.with_dependencies([
check_ops.assert_non_positive(x),
- distribution_util.assert_close(
+ check_ops.assert_near(
array_ops.zeros([], dtype=self.dtype),
math_ops.reduce_logsumexp(x, axis=[-1])),
], x)
@kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _kl_categorical_categorical(a, b, name=None):
"""Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical.
diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py
index 04de8106ee..3d055085cc 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
"Poisson",
@@ -65,6 +66,14 @@ class Poisson(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
rate=None,
log_rate=None,
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 7b10ba998f..7a7ad1be35 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib
+from tensorflow.python.util import deprecation
__all__ = [
@@ -42,6 +43,14 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def quadrature_scheme_lognormal_gauss_hermite(
loc, scale, quadrature_size,
validate_args=False, name=None): # pylint: disable=unused-argument
@@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite(
return grid, probs
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def quadrature_scheme_lognormal_quantiles(
loc, scale, quadrature_size,
validate_args=False, name=None):
@@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
validate_args=True)
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
@@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
axis=[-2, -1])
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def concat_vectors(*args):
"""Concatenates input vectors, statically if possible."""
args_ = [distribution_util.static_value(x) for x in args]
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 5ac6c34b53..ef3bdfa75f 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distributions
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = ["QuantizedDistribution"]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def _logsum_expbig_minus_expsmall(big, small):
"""Stable evaluation of `Log[exp{big} - exp{small}]`.
@@ -228,6 +237,14 @@ class QuantizedDistribution(distributions.Distribution):
https://arxiv.org/abs/1711.10433
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
distribution,
low=None,
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
index 4182ca2b56..7e1f64dc42 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py
@@ -19,15 +19,16 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import logistic
+from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
# Bijectors must be directly imported because `remove_undocumented` prevents
# individual file imports.
-from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
@@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution):
Gumbel-Softmax. 2016.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
temperature,
logits=None,
diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
index 5414f347cd..25aaac379a 100644
--- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
+++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class ExpRelaxedOneHotCategorical(distribution.Distribution):
@@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
A Continuous Relaxation of Discrete Random Variables. 2016.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(
self,
temperature,
@@ -290,7 +299,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution):
return x
return control_flow_ops.with_dependencies([
check_ops.assert_non_positive(x),
- distribution_util.assert_close(
+ check_ops.assert_near(
array_ops.zeros([], dtype=self.dtype),
math_ops.reduce_logsumexp(x, axis=[-1])),
], x)
@@ -368,6 +377,14 @@ class RelaxedOneHotCategorical(
A Continuous Relaxation of Discrete Random Variables. 2016.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(
self,
temperature,
diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py
index 6a7f28713a..4f348be280 100644
--- a/tensorflow/contrib/distributions/python/ops/shape.py
+++ b/tensorflow/contrib/distributions/python/ops/shape.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
class _DistributionShape(object):
@@ -166,6 +167,14 @@ class _DistributionShape(object):
"free," i.e., during graph construction.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
batch_ndims=None,
event_ndims=None,
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index a764544932..a9d0fb4ccf 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
+from tensorflow.python.util import deprecation
__all__ = [
"SinhArcsinh",
@@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 8d4914e16c..ece03fe4aa 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib
+from tensorflow.python.util import deprecation
__all__ = [
@@ -49,6 +50,14 @@ __all__ = [
]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def quadrature_scheme_softmaxnormal_gauss_hermite(
normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
@@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite(
return grid, probs
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def quadrature_scheme_softmaxnormal_quantiles(
normal_loc, normal_scale, quadrature_size,
validate_args=False, name=None):
@@ -318,6 +335,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
https://arxiv.org/abs/1801.03080
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
mix_loc,
temperature,
@@ -779,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution):
return array_ops.reshape(p, shape=expand_shape)
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def maybe_check_quadrature_param(param, name, validate_args):
"""Helper which checks validity of `loc` and `scale` init args."""
with ops.name_scope(name="check_" + name, values=[param]):
@@ -812,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args):
return param
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def determine_batch_event_shapes(grid, endpoint_affine):
"""Helper to infer batch_shape and event_shape."""
with ops.name_scope(name="determine_batch_event_shapes"):
@@ -850,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine):
return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def interpolate_loc(grid, loc):
"""Helper which interpolates between two locs."""
if len(loc) != 2:
@@ -876,6 +925,14 @@ def interpolate_loc(grid, loc):
return [x[..., k] for k in range(deg)] # list(shape:[B, e])
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def interpolate_scale(grid, scale):
"""Helper which interpolates between two scales."""
if len(scale) != 2:
@@ -892,6 +949,14 @@ def interpolate_scale(grid, scale):
])[0] for q in range(deg)]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def linop_scale(w, op):
# We assume w > 0. (This assumption only relates to the is_* attributes.)
with ops.name_scope("linop_scale", values=[w]):
@@ -927,6 +992,14 @@ def linop_scale(w, op):
"Unsupported Linop type ({})".format(type(op).__name__))
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def concat_vectors(*args):
"""Concatenates input vectors, statically if possible."""
args_ = [distribution_util.static_value(x) for x in args]
@@ -935,6 +1008,14 @@ def concat_vectors(*args):
return [val for vec in args_ for val in vec]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def add(x, y):
"""Adds inputs; interprets `None` as zero."""
if x is None:
@@ -944,11 +1025,27 @@ def add(x, y):
return x + y
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def vec_osquare(x):
"""Computes the outer-product of a (batch of) vector, i.e., x.T x."""
return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :]
+@deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def softmax(x, axis, name=None):
"""Equivalent to tf.nn.softmax but works around b/70297725."""
with ops.name_scope(name, "softmax", [x, axis]):
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index a75b3f3df1..73356a3625 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop
from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
__all__ = [
@@ -116,6 +117,14 @@ class VectorExponentialDiag(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_diag=None,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index a7d4c55be9..9a47b48557 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import exponential
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.util import deprecation
__all__ = ["VectorExponentialLinearOperator"]
@@ -138,6 +139,14 @@ class VectorExponentialLinearOperator(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale=None,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 4a53e7a621..e68ddc569c 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop
from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
__all__ = [
@@ -151,6 +152,14 @@ class VectorLaplaceDiag(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_diag=None,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 0566e04fec..3923161a33 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import laplace
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.util import deprecation
__all__ = [
@@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator(
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale=None,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index bb33cd0762..49ffff24ca 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
+from tensorflow.python.util import deprecation
__all__ = [
"VectorSinhArcsinhDiag",
@@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
```
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
loc=None,
scale_diag=None,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 21f84dcbde..f289b39e51 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.ops.distributions import transformed_distribution
+from tensorflow.python.util import deprecation
class _VectorStudentT(transformed_distribution.TransformedDistribution):
@@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
loc=None,
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index 88d4280759..f1accaaa4c 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
__all__ = [
"WishartCholesky",
@@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution):
this class.
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
scale_operator,
@@ -501,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
scale,
@@ -617,6 +634,14 @@ class WishartFull(_WishartLinearOperator):
"""
+ @deprecation.deprecated(
+ "2018-10-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.contrib.distributions`.",
+ warn_once=True)
def __init__(self,
df,
scale,
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index adf92c27ea..58c548d798 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -102,6 +102,7 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase):
with ops.device(self._device):
self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long
string_arg=iter_string_handle,
+ output_types=self._flat_output_types,
f=remote_fn,
target_device=target,
buffer_size=10,
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 1d9371c7ac..12155a459c 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -11,8 +11,12 @@ py_library(
"//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets",
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/resnet50",
+ "//tensorflow/contrib/eager/python/examples/revnet",
+ "//tensorflow/contrib/eager/python/examples/revnet:config",
"//tensorflow/contrib/eager/python/examples/rnn_colorbot",
"//tensorflow/contrib/eager/python/examples/rnn_ptb",
+ "//tensorflow/contrib/eager/python/examples/sagan",
+ "//tensorflow/contrib/eager/python/examples/sagan:config",
"//tensorflow/contrib/eager/python/examples/spinn:data",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/densenet/BUILD b/tensorflow/contrib/eager/python/examples/densenet/BUILD
new file mode 100644
index 0000000000..de2a817d17
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/BUILD
@@ -0,0 +1,29 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+ name = "densenet",
+ srcs = ["densenet.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ ],
+)
+
+cuda_py_test(
+ name = "densenet_test",
+ srcs = ["densenet_test.py"],
+ additional_deps = [
+ ":densenet",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet.py b/tensorflow/contrib/eager/python/examples/densenet/densenet.py
new file mode 100644
index 0000000000..3a2b2de250
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet.py
@@ -0,0 +1,274 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Densely Connected Convolutional Networks.
+
+Reference [
+Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+l2 = tf.keras.regularizers.l2
+
+
+class ConvBlock(tf.keras.Model):
+ """Convolutional Block consisting of (batchnorm->relu->conv).
+
+ Arguments:
+ num_filters: number of filters passed to a convolutional layer.
+ bottleneck: if True, then a 1x1 Conv is performed followed by 3x3 Conv.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_filters, bottleneck, weight_decay=1e-4,
+ dropout_rate=0):
+ super(ConvBlock, self).__init__()
+ self.bottleneck = bottleneck
+ inter_filter = num_filters * 4
+ # don't forget to set use_bias=False when using batchnorm
+ self.conv2 = tf.keras.layers.Conv2D(num_filters,
+ (3, 3),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.batchnorm1 = tf.keras.layers.BatchNormalization()
+ self.dropout = tf.keras.layers.Dropout(dropout_rate)
+
+ if self.bottleneck:
+ self.conv1 = tf.keras.layers.Conv2D(inter_filter,
+ (1, 1),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.batchnorm2 = tf.keras.layers.BatchNormalization()
+
+ def call(self, x, training=True):
+ output = self.batchnorm1(x, training=training)
+
+ if self.bottleneck:
+ output = self.conv1(tf.nn.relu(output))
+ output = self.batchnorm2(output, training=training)
+
+ output = self.conv2(tf.nn.relu(output))
+ output = self.dropout(output, training=training)
+
+ return output
+
+
+class TransitionBlock(tf.keras.Model):
+ """Transition Block to reduce the number of features.
+
+ Arguments:
+ num_filters: number of filters passed to a convolutional layer.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_filters, weight_decay=1e-4, dropout_rate=0):
+ super(TransitionBlock, self).__init__()
+ self.batchnorm = tf.keras.layers.BatchNormalization()
+ self.conv = tf.keras.layers.Conv2D(num_filters,
+ (1, 1),
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(weight_decay))
+ self.avg_pool = tf.keras.layers.AveragePooling2D()
+
+ def call(self, x, training=True):
+ output = self.batchnorm(x, training=training)
+ output = self.conv(tf.nn.relu(output))
+ output = self.avg_pool(output)
+ return output
+
+
+class DenseBlock(tf.keras.Model):
+ """Dense Block consisting of ConvBlocks where each block's
+ output is concatenated with its input.
+
+ Arguments:
+ num_layers: Number of layers in each block.
+ growth_rate: number of filters to add per conv block.
+ bottleneck: boolean, that decides which part of ConvBlock to call.
+ weight_decay: weight decay
+ dropout_rate: dropout rate.
+ """
+
+ def __init__(self, num_layers, growth_rate, bottleneck,
+ weight_decay=1e-4, dropout_rate=0):
+ super(DenseBlock, self).__init__()
+ self.num_layers = num_layers
+
+ self.blocks = []
+ for _ in range(int(self.num_layers)):
+ self.blocks.append(ConvBlock(growth_rate,
+ bottleneck,
+ weight_decay,
+ dropout_rate))
+
+ def call(self, x, training=True):
+ for i in range(int(self.num_layers)):
+ output = self.blocks[i](x, training=training)
+ x = tf.concat([x, output], axis=-1)
+
+ return x
+
+
+class DenseNet(tf.keras.Model):
+ """Creating the Densenet Architecture.
+
+ Arguments:
+ depth_of_model: number of layers in the model.
+ growth_rate: number of filters to add per conv block.
+ num_of_blocks: number of dense blocks.
+ output_classes: number of output classes.
+ num_layers_in_each_block: number of layers in each block.
+ If -1, then we calculate this by (depth-3)/4.
+ If positive integer, then the it is used as the
+ number of layers per block.
+ If list or tuple, then this list is used directly.
+ bottleneck: boolean, to decide which part of conv block to call.
+ compression: reducing the number of inputs(filters) to the transition block.
+ weight_decay: weight decay
+ rate: dropout rate.
+ pool_initial: If True add a 7x7 conv with stride 2 followed by 3x3 maxpool
+ else, do a 3x3 conv with stride 1.
+ include_top: If true, GlobalAveragePooling Layer and Dense layer are
+ included.
+ """
+
+ def __init__(self, depth_of_model, growth_rate, num_of_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5, weight_decay=1e-4,
+ dropout_rate=0, pool_initial=False, include_top=True):
+ super(DenseNet, self).__init__()
+ self.depth_of_model = depth_of_model
+ self.growth_rate = growth_rate
+ self.num_of_blocks = num_of_blocks
+ self.output_classes = output_classes
+ self.num_layers_in_each_block = num_layers_in_each_block
+ self.bottleneck = bottleneck
+ self.compression = compression
+ self.weight_decay = weight_decay
+ self.dropout_rate = dropout_rate
+ self.pool_initial = pool_initial
+ self.include_top = include_top
+
+ # deciding on number of layers in each block
+ if isinstance(self.num_layers_in_each_block, list) or isinstance(
+ self.num_layers_in_each_block, tuple):
+ self.num_layers_in_each_block = list(self.num_layers_in_each_block)
+ else:
+ if self.num_layers_in_each_block == -1:
+ if self.num_of_blocks != 3:
+ raise ValueError(
+ "Number of blocks must be 3 if num_layers_in_each_block is -1")
+ if (self.depth_of_model - 4) % 3 == 0:
+ num_layers = (self.depth_of_model - 4) / 3
+ if self.bottleneck:
+ num_layers //= 2
+ self.num_layers_in_each_block = [num_layers] * self.num_of_blocks
+ else:
+ raise ValueError("Depth must be 3N+4 if num_layer_in_each_block=-1")
+ else:
+ self.num_layers_in_each_block = [
+ self.num_layers_in_each_block] * self.num_of_blocks
+
+ # setting the filters and stride of the initial covn layer.
+ if self.pool_initial:
+ init_filters = (7, 7)
+ stride = (2, 2)
+ else:
+ init_filters = (3, 3)
+ stride = (1, 1)
+
+ self.num_filters = 2 * self.growth_rate
+
+ # first conv and pool layer
+ self.conv1 = tf.keras.layers.Conv2D(self.num_filters,
+ init_filters,
+ strides=stride,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ kernel_regularizer=l2(
+ self.weight_decay))
+ if self.pool_initial:
+ self.pool1 = tf.keras.layers.MaxPooling2D(pool_size=(3, 3),
+ strides=(2, 2),
+ padding="same")
+ self.batchnorm1 = tf.keras.layers.BatchNormalization()
+
+ self.batchnorm2 = tf.keras.layers.BatchNormalization()
+
+ # last pooling and fc layer
+ if self.include_top:
+ self.last_pool = tf.keras.layers.GlobalAveragePooling2D()
+ self.classifier = tf.keras.layers.Dense(self.output_classes)
+
+ # calculating the number of filters after each block
+ num_filters_after_each_block = [self.num_filters]
+ for i in range(1, self.num_of_blocks):
+ temp_num_filters = num_filters_after_each_block[i-1] + (
+ self.growth_rate * self.num_layers_in_each_block[i-1])
+ # using compression to reduce the number of inputs to the
+ # transition block
+ temp_num_filters = int(temp_num_filters * compression)
+ num_filters_after_each_block.append(temp_num_filters)
+
+ # dense block initialization
+ self.dense_blocks = []
+ self.transition_blocks = []
+ for i in range(self.num_of_blocks):
+ self.dense_blocks.append(DenseBlock(self.num_layers_in_each_block[i],
+ self.growth_rate,
+ self.bottleneck,
+ self.weight_decay,
+ self.dropout_rate))
+ if i+1 < self.num_of_blocks:
+ self.transition_blocks.append(
+ TransitionBlock(num_filters_after_each_block[i+1],
+ self.weight_decay,
+ self.dropout_rate))
+
+ def call(self, x, training=True):
+ output = self.conv1(x)
+
+ if self.pool_initial:
+ output = self.batchnorm1(output, training=training)
+ output = tf.nn.relu(output)
+ output = self.pool1(output)
+
+ for i in range(self.num_of_blocks - 1):
+ output = self.dense_blocks[i](output, training=training)
+ output = self.transition_blocks[i](output, training=training)
+
+ output = self.dense_blocks[
+ self.num_of_blocks - 1](output, training=training)
+ output = self.batchnorm2(output, training=training)
+ output = tf.nn.relu(output)
+
+ if self.include_top:
+ output = self.last_pool(output)
+ output = self.classifier(output)
+
+ return output
diff --git a/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
new file mode 100644
index 0000000000..56d3362f3b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/densenet/densenet_test.py
@@ -0,0 +1,83 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for various Densenet architectures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.densenet import densenet
+
+
+class DensenetTest(tf.test.TestCase):
+
+ def test_bottleneck_true(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 3
+ output_classes = 10
+ num_layers_in_each_block = -1
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=False, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+ def test_bottleneck_false(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 3
+ output_classes = 10
+ num_layers_in_each_block = -1
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=False, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=False, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+ def test_pool_initial_true(self):
+ depth = 7
+ growth_rate = 2
+ num_blocks = 4
+ output_classes = 10
+ num_layers_in_each_block = [1, 2, 2, 1]
+ batch_size = 1
+
+ model = densenet.DenseNet(depth, growth_rate, num_blocks,
+ output_classes, num_layers_in_each_block,
+ bottleneck=True, compression=0.5,
+ weight_decay=1e-4, dropout_rate=0,
+ pool_initial=True, include_top=True)
+
+ rand_input = tf.random_uniform((batch_size, 32, 32, 3))
+ output_shape = model(rand_input).shape
+ self.assertEqual(output_shape, (batch_size, output_classes))
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
new file mode 100644
index 0000000000..15e013f219
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -0,0 +1,1184 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "image_captioning_with_attention.ipynb",
+ "version": "0.3.2",
+ "views": {},
+ "default_view": {},
+ "provenance": [
+ {
+ "file_id": "1HI8OK2sMjcx9CTWVn0122QAHOuXaOaMg",
+ "timestamp": 1530222436922
+ }
+ ],
+ "private_outputs": true,
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "metadata": {
+ "id": "K2s1A9eLRPEj",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Cffg2i257iMS",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Image Captioning with Attention\n",
+ "\n",
+ "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
+ "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /><span>Run in Google Colab</span></a> \n",
+ "</td><td>\n",
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /><span>View source on GitHub</span></a></td></table>"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "QASbY_HGo4Lq",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Image captioning is the task of generating a caption for an image. Given an image like this:\n",
+ "\n",
+ "![Man Surfing](https://tensorflow.org/images/surf.jpg) \n",
+ "\n",
+ "[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
+ "\n",
+ "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
+ "\n",
+ "![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
+ "\n",
+ "This model architecture below is similar to [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044). \n",
+ "\n",
+ "The code uses [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager), which you can learn more about in the linked guides.\n",
+ "\n",
+ "This notebook is an end-to-end example. If you run it, it will download the [MS-COCO](http://cocodataset.org/#home) dataset, preprocess and cache a subset of the images using Inception V3, train an encoder-decoder model, and use it to generate captions on new images.\n",
+ "\n",
+ "The code requires TensorFlow version >=1.9. If you're running this in [Colab]()\n",
+ "\n",
+ "In this example, we're training on a relatively small amount of data as an example. On a single P100 GPU, this example will take about ~2 hours to train. We train on the first 30,000 captions (corresponding to about ~20,000 images depending on shuffling, as there are multiple captions per image in the dataset)\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "U8l4RJ0XRPEm",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Import TensorFlow and enable eager execution\n",
+ "# This code requires TensorFlow version >=1.9\n",
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "# We'll generate plots of attention in order to see which parts of an image\n",
+ "# our model focuses on during captioning\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Scikit-learn includes many helpful utilities\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.utils import shuffle\n",
+ "\n",
+ "import re\n",
+ "import numpy as np\n",
+ "import os\n",
+ "import time\n",
+ "import json\n",
+ "from glob import glob\n",
+ "from PIL import Image\n",
+ "import pickle"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "b6qbGw8MRPE5",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Download and prepare the MS-COCO dataset\n",
+ "\n",
+ "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n",
+ "\n",
+ "**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "krQuPYTtRPE7",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "annotation_zip = tf.keras.utils.get_file('captions.zip', \n",
+ " cache_subdir=os.path.abspath('.'),\n",
+ " origin = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip',\n",
+ " extract = True)\n",
+ "annotation_file = os.path.dirname(annotation_zip)+'/annotations/captions_train2014.json'\n",
+ "\n",
+ "name_of_zip = 'train2014.zip'\n",
+ "if not os.path.exists(os.path.abspath('.') + '/' + name_of_zip):\n",
+ " image_zip = tf.keras.utils.get_file(name_of_zip, \n",
+ " cache_subdir=os.path.abspath('.'),\n",
+ " origin = 'http://images.cocodataset.org/zips/train2014.zip',\n",
+ " extract = True)\n",
+ " PATH = os.path.dirname(image_zip)+'/train2014/'\n",
+ "else:\n",
+ " PATH = os.path.abspath('.')+'/train2014/'"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "aANEzb5WwSzg",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Optionally, limit the size of the training set for faster training\n",
+ "For this example, we'll select a subset of 30,000 captions and use these and the corresponding images to train our model. As always, captioning quality will improve if you choose to use more data."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4G3b8x8_RPFD",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# read the json file\n",
+ "with open(annotation_file, 'r') as f:\n",
+ " annotations = json.load(f)\n",
+ "\n",
+ "# storing the captions and the image name in vectors\n",
+ "all_captions = []\n",
+ "all_img_name_vector = []\n",
+ "\n",
+ "for annot in annotations['annotations']:\n",
+ " caption = '<start> ' + annot['caption'] + ' <end>'\n",
+ " image_id = annot['image_id']\n",
+ " full_coco_image_path = PATH + 'COCO_train2014_' + '%012d.jpg' % (image_id)\n",
+ " \n",
+ " all_img_name_vector.append(full_coco_image_path)\n",
+ " all_captions.append(caption)\n",
+ "\n",
+ "# shuffling the captions and image_names together\n",
+ "# setting a random state\n",
+ "train_captions, img_name_vector = shuffle(all_captions,\n",
+ " all_img_name_vector,\n",
+ " random_state=1)\n",
+ "\n",
+ "# selecting the first 30000 captions from the shuffled set\n",
+ "num_examples = 30000\n",
+ "train_captions = train_captions[:num_examples]\n",
+ "img_name_vector = img_name_vector[:num_examples]"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "mPBMgK34RPFL",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "len(train_captions), len(all_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "8cSW4u-ORPFQ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess the images using InceptionV3\n",
+ "Next, we will use InceptionV3 (pretrained on Imagenet) to classify each image. We will extract features from the last convolutional layer. \n",
+ "\n",
+ "First, we will need to convert the images into the format inceptionV3 expects by:\n",
+ "* Resizing the image to (299, 299)\n",
+ "* Using the [preprocess_input](https://www.tensorflow.org/api_docs/python/tf/keras/applications/inception_v3/preprocess_input) method to place the pixels in the range of -1 to 1 (to match the format of the images used to train InceptionV3)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "zXR0217aRPFR",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def load_image(image_path):\n",
+ " img = tf.read_file(image_path)\n",
+ " img = tf.image.decode_jpeg(img, channels=3)\n",
+ " img = tf.image.resize_images(img, (299, 299))\n",
+ " img = tf.keras.applications.inception_v3.preprocess_input(img)\n",
+ " return img, image_path"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "MDvIu4sXRPFV",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Initialize InceptionV3 and load the pretrained Imagenet weights\n",
+ "\n",
+ "To do so, we'll create a tf.keras model where the output layer is the last convolutional layer in the InceptionV3 architecture. \n",
+ "* Each image is forwarded through the network and the vector that we get at the end is stored in a dictionary (image_name --> feature_vector). \n",
+ "* We use the last convolutional layer because we are using attention in this example. The shape of the output of this layer is ```8x8x2048```. \n",
+ "* We avoid doing this during training so it does not become a bottleneck. \n",
+ "* After all the images are passed through the network, we pickle the dictionary and save it to disk."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "RD3vW4SsRPFW",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "image_model = tf.keras.applications.InceptionV3(include_top=False, \n",
+ " weights='imagenet')\n",
+ "new_input = image_model.input\n",
+ "hidden_layer = image_model.layers[-1].output\n",
+ "\n",
+ "image_features_extract_model = tf.keras.Model(new_input, hidden_layer)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "rERqlR3WRPGO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Caching the features extracted from InceptionV3\n",
+ "\n",
+ "We will pre-process each image with InceptionV3 and cache the output to disk. Caching the output in RAM would be faster but memory intensive, requiring 8 \\* 8 \\* 2048 floats per image. At the time of writing, this would exceed the memory limitations of Colab (although these may change, an instance appears to have about 12GB of memory currently). \n",
+ "\n",
+ "Performance could be improved with a more sophisticated caching strategy (e.g., by sharding the images to reduce random access disk I/O) at the cost of more code.\n",
+ "\n",
+ "This will take about 10 minutes to run in Colab with a GPU. If you'd like to see a progress bar, you could: install [tqdm](https://github.com/tqdm/tqdm) (```!pip install tqdm```), then change this line: \n",
+ "\n",
+ "```for img, path in image_dataset:``` \n",
+ "\n",
+ "to:\n",
+ "\n",
+ "```for img, path in tqdm(image_dataset):```."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Dx_fvbVgRPGQ",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# getting the unique images\n",
+ "encode_train = sorted(set(img_name_vector))\n",
+ "\n",
+ "# feel free to change the batch_size according to your system configuration\n",
+ "image_dataset = tf.data.Dataset.from_tensor_slices(\n",
+ " encode_train).map(load_image).batch(16)\n",
+ "\n",
+ "for img, path in image_dataset:\n",
+ " batch_features = image_features_extract_model(img)\n",
+ " batch_features = tf.reshape(batch_features, \n",
+ " (batch_features.shape[0], -1, batch_features.shape[3]))\n",
+ "\n",
+ " for bf, p in zip(batch_features, path):\n",
+ " path_of_feature = p.numpy().decode(\"utf-8\")\n",
+ " np.save(path_of_feature, bf.numpy())"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nyqH3zFwRPFi",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Preprocess and tokenize the captions\n",
+ "\n",
+ "* First, we'll tokenize the captions (e.g., by splitting on spaces). This will give us a vocabulary of all the unique words in the data (e.g., \"surfing\", \"football\", etc).\n",
+ "* Next, we'll limit the vocabulary size to the top 5,000 words to save memory. We'll replace all other words with the token \"UNK\" (for unknown).\n",
+ "* Finally, we create a word --> index mapping and vice-versa.\n",
+ "* We will then pad all sequences to the be same length as the longest one. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "HZfK8RhQRPFj",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# This will find the maximum length of any caption in our dataset\n",
+ "def calc_max_length(tensor):\n",
+ " return max(len(t) for t in tensor)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "oJGE34aiRPFo",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# The steps above is a general process of dealing with text processing\n",
+ "\n",
+ "# choosing the top 5000 words from the vocabulary\n",
+ "top_k = 5000\n",
+ "tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=top_k, \n",
+ " oov_token=\"<unk>\", \n",
+ " filters='!\"#$%&()*+.,-/:;=?@[\\]^_`{|}~ ')\n",
+ "tokenizer.fit_on_texts(train_captions)\n",
+ "train_seqs = tokenizer.texts_to_sequences(train_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "8Q44tNQVRPFt",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "tokenizer.word_index = {key:value for key, value in tokenizer.word_index.items() if value <= top_k}\n",
+ "# putting <unk> token in the word2idx dictionary\n",
+ "tokenizer.word_index[tokenizer.oov_token] = top_k + 1\n",
+ "tokenizer.word_index['<pad>'] = 0"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "0fpJb5ojRPFv",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# creating the tokenized vectors\n",
+ "train_seqs = tokenizer.texts_to_sequences(train_captions)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "olQArbgbRPF1",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# creating a reverse mapping (index -> word)\n",
+ "index_word = {value:key for key, value in tokenizer.word_index.items()}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AidglIZVRPF4",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# padding each vector to the max_length of the captions\n",
+ "# if the max_length parameter is not provided, pad_sequences calculates that automatically\n",
+ "cap_vector = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "gL0wkttkRPGA",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# calculating the max_length \n",
+ "# used to store the attention weights\n",
+ "max_length = calc_max_length(train_seqs)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "M3CD75nDpvTI",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Split the data into training and testing"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "iS7DDMszRPGF",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Create training and validation sets using 80-20 split\n",
+ "img_name_train, img_name_val, cap_train, cap_val = train_test_split(img_name_vector, \n",
+ " cap_vector, \n",
+ " test_size=0.2, \n",
+ " random_state=0)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "XmViPkRFRPGH",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "len(img_name_train), len(cap_train), len(img_name_val), len(cap_val)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "uEWM9xrYcg45",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Our images and captions are ready! Next, let's create a tf.data dataset to use for training our model.\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Q3TnZ1ToRPGV",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# feel free to change these parameters according to your system's configuration\n",
+ "\n",
+ "BATCH_SIZE = 64\n",
+ "BUFFER_SIZE = 1000\n",
+ "embedding_dim = 256\n",
+ "units = 512\n",
+ "vocab_size = len(tokenizer.word_index)\n",
+ "# shape of the vector extracted from InceptionV3 is (64, 2048)\n",
+ "# these two variables represent that\n",
+ "features_shape = 2048\n",
+ "attention_features_shape = 64"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "SmZS2N0bXG3T",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# loading the numpy files \n",
+ "def map_func(img_name, cap):\n",
+ " img_tensor = np.load(img_name.decode('utf-8')+'.npy')\n",
+ " return img_tensor, cap"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "FDF_Nm3tRPGZ",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset = tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))\n",
+ "\n",
+ "# using map to load the numpy files in parallel\n",
+ "# NOTE: Be sure to set num_parallel_calls to the number of CPU cores you have\n",
+ "# https://www.tensorflow.org/api_docs/python/tf/py_func\n",
+ "dataset = dataset.map(lambda item1, item2: tf.py_func(\n",
+ " map_func, [item1, item2], [tf.float32, tf.int32]), num_parallel_calls=8)\n",
+ "\n",
+ "# shuffling and batching\n",
+ "dataset = dataset.shuffle(BUFFER_SIZE)\n",
+ "# https://www.tensorflow.org/api_docs/python/tf/contrib/data/batch_and_drop_remainder\n",
+ "dataset = dataset.batch(BATCH_SIZE)\n",
+ "dataset = dataset.prefetch(1)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nrvoDphgRPGd",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Model\n",
+ "\n",
+ "Fun fact, the decoder below is identical to the one in the example for [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
+ "\n",
+ "The model architecture is inspired by the [Show, Attend and Tell](https://arxiv.org/pdf/1502.03044.pdf) paper.\n",
+ "\n",
+ "* In this example, we extract the features from the lower convolutional layer of InceptionV3 giving us a vector of shape (8, 8, 2048). \n",
+ "* We squash that to a shape of (64, 2048).\n",
+ "* This vector is then passed through the CNN Encoder(which consists of a single Fully connected layer).\n",
+ "* The RNN(here GRU) attends over the image to predict the next word."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "AAppCGLKRPGd",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def gru(units):\n",
+ " # If you have a GPU, we recommend using the CuDNNGRU layer (it provides a \n",
+ " # significant speedup).\n",
+ " if tf.test.is_gpu_available():\n",
+ " return tf.keras.layers.CuDNNGRU(units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ " else:\n",
+ " return tf.keras.layers.GRU(units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_activation='sigmoid', \n",
+ " recurrent_initializer='glorot_uniform')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "ja2LFTMSdeV3",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class BahdanauAttention(tf.keras.Model):\n",
+ " def __init__(self, units):\n",
+ " super(BahdanauAttention, self).__init__()\n",
+ " self.W1 = tf.keras.layers.Dense(units)\n",
+ " self.W2 = tf.keras.layers.Dense(units)\n",
+ " self.V = tf.keras.layers.Dense(1)\n",
+ " \n",
+ " def call(self, features, hidden):\n",
+ " # features(CNN_encoder output) shape == (batch_size, 64, embedding_dim)\n",
+ " \n",
+ " # hidden shape == (batch_size, hidden_size)\n",
+ " # hidden_with_time_axis shape == (batch_size, 1, hidden_size)\n",
+ " hidden_with_time_axis = tf.expand_dims(hidden, 1)\n",
+ " \n",
+ " # score shape == (batch_size, 64, hidden_size)\n",
+ " score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))\n",
+ " \n",
+ " # attention_weights shape == (batch_size, 64, 1)\n",
+ " # we get 1 at the last axis because we are applying score to self.V\n",
+ " attention_weights = tf.nn.softmax(self.V(score), axis=1)\n",
+ " \n",
+ " # context_vector shape after sum == (batch_size, hidden_size)\n",
+ " context_vector = attention_weights * features\n",
+ " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
+ " \n",
+ " return context_vector, attention_weights"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AZ7R1RxHRPGf",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class CNN_Encoder(tf.keras.Model):\n",
+ " # Since we have already extracted the features and dumped it using pickle\n",
+ " # This encoder passes those features through a Fully connected layer\n",
+ " def __init__(self, embedding_dim):\n",
+ " super(CNN_Encoder, self).__init__()\n",
+ " # shape after fc == (batch_size, 64, embedding_dim)\n",
+ " self.fc = tf.keras.layers.Dense(embedding_dim)\n",
+ " \n",
+ " def call(self, x):\n",
+ " x = self.fc(x)\n",
+ " x = tf.nn.relu(x)\n",
+ " return x"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "V9UbGQmERPGi",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class RNN_Decoder(tf.keras.Model):\n",
+ " def __init__(self, embedding_dim, units, vocab_size):\n",
+ " super(RNN_Decoder, self).__init__()\n",
+ " self.units = units\n",
+ "\n",
+ " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
+ " self.gru = gru(self.units)\n",
+ " self.fc1 = tf.keras.layers.Dense(self.units)\n",
+ " self.fc2 = tf.keras.layers.Dense(vocab_size)\n",
+ " \n",
+ " self.attention = BahdanauAttention(self.units)\n",
+ " \n",
+ " def call(self, x, features, hidden):\n",
+ " # defining attention as a separate model\n",
+ " context_vector, attention_weights = self.attention(features, hidden)\n",
+ " \n",
+ " # x shape after passing through embedding == (batch_size, 1, embedding_dim)\n",
+ " x = self.embedding(x)\n",
+ " \n",
+ " # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)\n",
+ " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
+ " \n",
+ " # passing the concatenated vector to the GRU\n",
+ " output, state = self.gru(x)\n",
+ " \n",
+ " # shape == (batch_size, max_length, hidden_size)\n",
+ " x = self.fc1(output)\n",
+ " \n",
+ " # x shape == (batch_size * max_length, hidden_size)\n",
+ " x = tf.reshape(x, (-1, x.shape[2]))\n",
+ " \n",
+ " # output shape == (batch_size * max_length, vocab)\n",
+ " x = self.fc2(x)\n",
+ "\n",
+ " return x, state, attention_weights\n",
+ "\n",
+ " def reset_state(self, batch_size):\n",
+ " return tf.zeros((batch_size, self.units))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Qs_Sr03wRPGk",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "encoder = CNN_Encoder(embedding_dim)\n",
+ "decoder = RNN_Decoder(embedding_dim, units, vocab_size)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "-bYN7xA0RPGl",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "optimizer = tf.train.AdamOptimizer()\n",
+ "\n",
+ "# We are masking the loss calculated for padding\n",
+ "def loss_function(real, pred):\n",
+ " mask = 1 - np.equal(real, 0)\n",
+ " loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask\n",
+ " return tf.reduce_mean(loss_)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "PHod7t72RPGn",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Training\n",
+ "\n",
+ "* We extract the features stored in the respective `.npy` files and then pass those features through the encoder.\n",
+ "* The encoder output, hidden state(initialized to 0) and the decoder input (which is the start token) is passed to the decoder.\n",
+ "* The decoder returns the predictions and the decoder hidden state.\n",
+ "* The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.\n",
+ "* Use teacher forcing to decide the next input to the decoder.\n",
+ "* Teacher forcing is the technique where the target word is passed as the next input to the decoder.\n",
+ "* The final step is to calculate the gradients and apply it to the optimizer and backpropagate.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Vt4WZ5mhJE-E",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# adding this in a separate cell because if you run the training cell \n",
+ "# many times, the loss_plot array will be reset\n",
+ "loss_plot = []"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "UlA4VIQpRPGo",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "EPOCHS = 20\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " start = time.time()\n",
+ " total_loss = 0\n",
+ " \n",
+ " for (batch, (img_tensor, target)) in enumerate(dataset):\n",
+ " loss = 0\n",
+ " \n",
+ " # initializing the hidden state for each batch\n",
+ " # because the captions are not related from image to image\n",
+ " hidden = decoder.reset_state(batch_size=target.shape[0])\n",
+ "\n",
+ " dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * BATCH_SIZE, 1)\n",
+ " \n",
+ " with tf.GradientTape() as tape:\n",
+ " features = encoder(img_tensor)\n",
+ " \n",
+ " for i in range(1, target.shape[1]):\n",
+ " # passing the features through the decoder\n",
+ " predictions, hidden, _ = decoder(dec_input, features, hidden)\n",
+ "\n",
+ " loss += loss_function(target[:, i], predictions)\n",
+ " \n",
+ " # using teacher forcing\n",
+ " dec_input = tf.expand_dims(target[:, i], 1)\n",
+ " \n",
+ " total_loss += (loss / int(target.shape[1]))\n",
+ " \n",
+ " variables = encoder.variables + decoder.variables\n",
+ " \n",
+ " gradients = tape.gradient(loss, variables) \n",
+ " \n",
+ " optimizer.apply_gradients(zip(gradients, variables), tf.train.get_or_create_global_step())\n",
+ " \n",
+ " if batch % 100 == 0:\n",
+ " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, \n",
+ " batch, \n",
+ " loss.numpy() / int(target.shape[1])))\n",
+ " # storing the epoch end loss value to plot later\n",
+ " loss_plot.append(total_loss / len(cap_vector))\n",
+ " \n",
+ " print ('Epoch {} Loss {:.6f}'.format(epoch + 1, \n",
+ " total_loss/len(cap_vector)))\n",
+ " print ('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1Wm83G-ZBPcC",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "plt.plot(loss_plot)\n",
+ "plt.xlabel('Epochs')\n",
+ "plt.ylabel('Loss')\n",
+ "plt.title('Loss Plot')\n",
+ "plt.show()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "xGvOcLQKghXN",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Caption!\n",
+ "\n",
+ "* The evaluate function is similar to the training loop, except we don't use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.\n",
+ "* Stop predicting when the model predicts the end token.\n",
+ "* And store the attention weights for every time step."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "RCWpDtyNRPGs",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def evaluate(image):\n",
+ " attention_plot = np.zeros((max_length, attention_features_shape))\n",
+ "\n",
+ " hidden = decoder.reset_state(batch_size=1)\n",
+ "\n",
+ " temp_input = tf.expand_dims(load_image(image)[0], 0)\n",
+ " img_tensor_val = image_features_extract_model(temp_input)\n",
+ " img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))\n",
+ "\n",
+ " features = encoder(img_tensor_val)\n",
+ "\n",
+ " dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)\n",
+ " result = []\n",
+ "\n",
+ " for i in range(max_length):\n",
+ " predictions, hidden, attention_weights = decoder(dec_input, features, hidden)\n",
+ "\n",
+ " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
+ "\n",
+ " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " result.append(index_word[predicted_id])\n",
+ "\n",
+ " if index_word[predicted_id] == '<end>':\n",
+ " return result, attention_plot\n",
+ "\n",
+ " dec_input = tf.expand_dims([predicted_id], 0)\n",
+ "\n",
+ " attention_plot = attention_plot[:len(result), :]\n",
+ " return result, attention_plot"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "fD_y7PD6RPGt",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "def plot_attention(image, result, attention_plot):\n",
+ " temp_image = np.array(Image.open(image))\n",
+ "\n",
+ " fig = plt.figure(figsize=(10, 10))\n",
+ " \n",
+ " len_result = len(result)\n",
+ " for l in range(len_result):\n",
+ " temp_att = np.resize(attention_plot[l], (8, 8))\n",
+ " ax = fig.add_subplot(len_result//2, len_result//2, l+1)\n",
+ " ax.set_title(result[l])\n",
+ " img = ax.imshow(temp_image)\n",
+ " ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())\n",
+ "\n",
+ " plt.tight_layout()\n",
+ " plt.show()"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "io7ws3ReRPGv",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# captions on the validation set\n",
+ "rid = np.random.randint(0, len(img_name_val))\n",
+ "image = img_name_val[rid]\n",
+ "real_caption = ' '.join([index_word[i] for i in cap_val[rid] if i not in [0]])\n",
+ "result, attention_plot = evaluate(image)\n",
+ "\n",
+ "print ('Real Caption:', real_caption)\n",
+ "print ('Prediction Caption:', ' '.join(result))\n",
+ "plot_attention(image, result, attention_plot)\n",
+ "# opening the image\n",
+ "Image.open(img_name_val[rid])"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Rprk3HEvZuxb",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Try it on your own images\n",
+ "For fun, below we've provided a method you can use to caption your own images with the model we've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for weird results!)\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "9Psd1quzaAWg",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "image_url = 'https://tensorflow.org/images/imcap_prediction.png'\n",
+ "image_extension = image_url[-4:]\n",
+ "image_path = tf.keras.utils.get_file('image'+image_extension, \n",
+ " origin=image_url)\n",
+ "\n",
+ "result, attention_plot = evaluate(image_path)\n",
+ "print ('Prediction Caption:', ' '.join(result))\n",
+ "plot_attention(image_path, result, attention_plot)\n",
+ "# opening the image\n",
+ "Image.open(image_path)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "VJZXyJco6uLO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Next steps\n",
+ "\n",
+ "Congrats! You've just trained an image captioning model with attention. Next, we recommend taking a look at this example [Neural Machine Translation with Attention]( https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb). It uses a similar architecture to translate between Spanish and English sentences. You can also experiment with training the code in this notebook on a different dataset."
+ ]
+ }
+ ]
+}
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
new file mode 100644
index 0000000000..b0c8773993
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -0,0 +1,689 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "text_generation.ipynb",
+ "version": "0.3.2",
+ "views": {},
+ "default_view": {},
+ "provenance": [],
+ "private_outputs": true,
+ "collapsed_sections": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "metadata": {
+ "id": "hcD2nPQvPOFM",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "##### Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
+ "\n",
+ "# Text Generation using a RNN\n",
+ "\n",
+ "<table align=\"left\"><td>\n",
+ "<a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\">\n",
+ " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n",
+ "</td><td>\n",
+ "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on Github</a></td></table>"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BwpJ5IffzRG6",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This notebook demonstrates how to generate text using an RNN using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). If you like, you can write a similar [model](https://github.com/fchollet/deep-learning-with-python-notebooks/blob/master/8.1-text-generation-with-lstm.ipynb) using less code. Here, we show a lower-level impementation that's useful to understand as prework before diving in to deeper examples in a similar, like [Neural Machine Translation with Attention](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb).\n",
+ "\n",
+ "This notebook is an end-to-end example. When you run it, it will download a dataset of Shakespeare's writing. We'll use a collection of plays, borrowed from Andrej Karpathy's excellent [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/). The notebook will train a model, and use it to generate sample output.\n",
+ " \n",
+ "Here is the output(with start string='w') after training a single layer GRU for 30 epochs with the default settings below:\n",
+ "\n",
+ "```\n",
+ "were to the death of him\n",
+ "And nothing of the field in the view of hell,\n",
+ "When I said, banish him, I will not burn thee that would live.\n",
+ "\n",
+ "HENRY BOLINGBROKE:\n",
+ "My gracious uncle--\n",
+ "\n",
+ "DUKE OF YORK:\n",
+ "As much disgraced to the court, the gods them speak,\n",
+ "And now in peace himself excuse thee in the world.\n",
+ "\n",
+ "HORTENSIO:\n",
+ "Madam, 'tis not the cause of the counterfeit of the earth,\n",
+ "And leave me to the sun that set them on the earth\n",
+ "And leave the world and are revenged for thee.\n",
+ "\n",
+ "GLOUCESTER:\n",
+ "I would they were talking with the very name of means\n",
+ "To make a puppet of a guest, and therefore, good Grumio,\n",
+ "Nor arm'd to prison, o' the clouds, of the whole field,\n",
+ "With the admire\n",
+ "With the feeding of thy chair, and we have heard it so,\n",
+ "I thank you, sir, he is a visor friendship with your silly your bed.\n",
+ "\n",
+ "SAMPSON:\n",
+ "I do desire to live, I pray: some stand of the minds, make thee remedies\n",
+ "With the enemies of my soul.\n",
+ "\n",
+ "MENENIUS:\n",
+ "I'll keep the cause of my mistress.\n",
+ "\n",
+ "POLIXENES:\n",
+ "My brother Marcius!\n",
+ "\n",
+ "Second Servant:\n",
+ "Will't ple\n",
+ "```\n",
+ "\n",
+ "Of course, while some of the sentences are grammatical, most do not make sense. But, consider:\n",
+ "\n",
+ "* Our model is character based (when we began training, it did not yet know how to spell a valid English word, or that words were even a unit of text).\n",
+ "\n",
+ "* The structure of the output resembles a play (blocks begin with a speaker name, in all caps similar to the original text). Sentences generally end with a period. If you look at the text from a distance (or don't read the invididual words too closely, it appears as if it's an excerpt from a play).\n",
+ "\n",
+ "As a next step, you can experiment training the model on a different dataset - any large text file(ASCII) will do, and you can modify a single line of code below to make that change. Have fun!\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "R3p22DBDsaCA",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Install unidecode library\n",
+ "A helpful library to convert unicode to ASCII."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "wZ6LOM12wKGH",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "!pip install unidecode"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "WGyKZj3bzf9p",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Import tensorflow and enable eager execution."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "yG_n40gFzf9s",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Import TensorFlow >= 1.9 and enable eager execution\n",
+ "import tensorflow as tf\n",
+ "\n",
+ "# Note: Once you enable eager execution, it cannot be disabled. \n",
+ "tf.enable_eager_execution()\n",
+ "\n",
+ "import numpy as np\n",
+ "import re\n",
+ "import random\n",
+ "import unidecode\n",
+ "import time"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "EHDoRoc5PKWz",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Download the dataset\n",
+ "\n",
+ "In this example, we will use the [shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt). You can use any other dataset that you like.\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "pD_55cOxLkAb",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/yashkatariya/shakespeare.txt')"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "UHjdCjDuSvX_",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Read the dataset\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "-E5JvY3wzf94",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "text = unidecode.unidecode(open(path_to_file).read())\n",
+ "# length of text is the number of characters in it\n",
+ "print (len(text))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Il9ww98izf-D",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Creating dictionaries to map from characters to their indices and vice-versa, which will be used to vectorize the inputs"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "IalZLbvOzf-F",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# unique contains all the unique characters in the file\n",
+ "unique = sorted(set(text))\n",
+ "\n",
+ "# creating a mapping from unique characters to indices\n",
+ "char2idx = {u:i for i, u in enumerate(unique)}\n",
+ "idx2char = {i:u for i, u in enumerate(unique)}"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1v_qUYfAzf-I",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# setting the maximum length sentence we want for a single input in characters\n",
+ "max_length = 100\n",
+ "\n",
+ "# length of the vocabulary in chars\n",
+ "vocab_size = len(unique)\n",
+ "\n",
+ "# the embedding dimension \n",
+ "embedding_dim = 256\n",
+ "\n",
+ "# number of RNN (here GRU) units\n",
+ "units = 1024\n",
+ "\n",
+ "# batch size \n",
+ "BATCH_SIZE = 64\n",
+ "\n",
+ "# buffer size to shuffle our dataset\n",
+ "BUFFER_SIZE = 10000"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "LFjSVAlWzf-N",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating the input and output tensors\n",
+ "\n",
+ "Vectorizing the input and the target text because our model cannot understand strings only numbers.\n",
+ "\n",
+ "But first, we need to create the input and output vectors.\n",
+ "Remember the max_length we set above, we will use it here. We are creating **max_length** chunks of input, where each input vector is all the characters in that chunk except the last and the target vector is all the characters in that chunk except the first.\n",
+ "\n",
+ "For example, consider that the string = 'tensorflow' and the max_length is 9\n",
+ "\n",
+ "So, the `input = 'tensorflo'` and `output = 'ensorflow'`\n",
+ "\n",
+ "After creating the vectors, we convert each character into numbers using the **char2idx** dictionary we created above."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "0UHJDA39zf-O",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "input_text = []\n",
+ "target_text = []\n",
+ "\n",
+ "for f in range(0, len(text)-max_length, max_length):\n",
+ " inps = text[f:f+max_length]\n",
+ " targ = text[f+1:f+1+max_length]\n",
+ "\n",
+ " input_text.append([char2idx[i] for i in inps])\n",
+ " target_text.append([char2idx[t] for t in targ])\n",
+ " \n",
+ "print (np.array(input_text).shape)\n",
+ "print (np.array(target_text).shape)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "MJdfPmdqzf-R",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating batches and shuffling them using tf.data"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "p2pGotuNzf-S",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "dataset = tf.data.Dataset.from_tensor_slices((input_text, target_text)).shuffle(BUFFER_SIZE)\n",
+ "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(BATCH_SIZE))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "m8gPwEjRzf-Z",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Creating the model\n",
+ "\n",
+ "We use the Model Subclassing API which gives us full flexibility to create the model and change it however we like. We use 3 layers to define our model.\n",
+ "\n",
+ "* Embedding layer\n",
+ "* GRU layer (you can use an LSTM layer here)\n",
+ "* Fully connected layer"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "P3KTiiInzf-a",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "class Model(tf.keras.Model):\n",
+ " def __init__(self, vocab_size, embedding_dim, units, batch_size):\n",
+ " super(Model, self).__init__()\n",
+ " self.units = units\n",
+ " self.batch_sz = batch_size\n",
+ "\n",
+ " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
+ "\n",
+ " if tf.test.is_gpu_available():\n",
+ " self.gru = tf.keras.layers.CuDNNGRU(self.units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ " else:\n",
+ " self.gru = tf.keras.layers.GRU(self.units, \n",
+ " return_sequences=True, \n",
+ " return_state=True, \n",
+ " recurrent_activation='sigmoid', \n",
+ " recurrent_initializer='glorot_uniform')\n",
+ "\n",
+ " self.fc = tf.keras.layers.Dense(vocab_size)\n",
+ " \n",
+ " def call(self, x, hidden):\n",
+ " x = self.embedding(x)\n",
+ "\n",
+ " # output shape == (batch_size, max_length, hidden_size) \n",
+ " # states shape == (batch_size, hidden_size)\n",
+ "\n",
+ " # states variable to preserve the state of the model\n",
+ " # this will be used to pass at every step to the model while training\n",
+ " output, states = self.gru(x, initial_state=hidden)\n",
+ "\n",
+ "\n",
+ " # reshaping the output so that we can pass it to the Dense layer\n",
+ " # after reshaping the shape is (batch_size * max_length, hidden_size)\n",
+ " output = tf.reshape(output, (-1, output.shape[2]))\n",
+ "\n",
+ " # The dense layer will output predictions for every time_steps(max_length)\n",
+ " # output shape after the dense layer == (max_length * batch_size, vocab_size)\n",
+ " x = self.fc(output)\n",
+ "\n",
+ " return x, states"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "trpqTWyvk0nr",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Call the model and set the optimizer and the loss function"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "7t2XrzEOzf-e",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "model = Model(vocab_size, embedding_dim, units, BATCH_SIZE)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "dkjWIATszf-h",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "optimizer = tf.train.AdamOptimizer()\n",
+ "\n",
+ "# using sparse_softmax_cross_entropy so that we don't have to create one-hot vectors\n",
+ "def loss_function(real, preds):\n",
+ " return tf.losses.sparse_softmax_cross_entropy(labels=real, logits=preds)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "lPrP0XMUzf-p",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Train the model\n",
+ "\n",
+ "Here we will use a custom training loop with the help of GradientTape()\n",
+ "\n",
+ "* We initialize the hidden state of the model with zeros and shape == (batch_size, number of rnn units). We do this by calling the function defined while creating the model.\n",
+ "\n",
+ "* Next, we iterate over the dataset(batch by batch) and calculate the **predictions and the hidden states** associated with that input.\n",
+ "\n",
+ "* There are a lot of interesting things happening here.\n",
+ " * The model gets hidden state(initialized with 0), lets call that **H0** and the first batch of input, lets call that **I0**.\n",
+ " * The model then returns the predictions **P1** and **H1**.\n",
+ " * For the next batch of input, the model receives **I1** and **H1**.\n",
+ " * The interesting thing here is that we pass **H1** to the model with **I1** which is how the model learns. The context learned from batch to batch is contained in the **hidden state**.\n",
+ " * We continue doing this until the dataset is exhausted and then we start a new epoch and repeat this.\n",
+ "\n",
+ "* After calculating the predictions, we calculate the **loss** using the loss function defined above. Then we calculate the gradients of the loss with respect to the model variables(input)\n",
+ "\n",
+ "* Finally, we take a step in that direction with the help of the optimizer using the apply_gradients function.\n",
+ "\n",
+ "Note:- If you are running this notebook in Colab which has a **Tesla K80 GPU** it takes about 23 seconds per epoch.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "d4tSNwymzf-q",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Training step\n",
+ "\n",
+ "EPOCHS = 30\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " start = time.time()\n",
+ " \n",
+ " # initializing the hidden state at the start of every epoch\n",
+ " hidden = model.reset_states()\n",
+ " \n",
+ " for (batch, (inp, target)) in enumerate(dataset):\n",
+ " with tf.GradientTape() as tape:\n",
+ " # feeding the hidden state back into the model\n",
+ " # This is the interesting step\n",
+ " predictions, hidden = model(inp, hidden)\n",
+ " \n",
+ " # reshaping the target because that's how the \n",
+ " # loss function expects it\n",
+ " target = tf.reshape(target, (-1,))\n",
+ " loss = loss_function(target, predictions)\n",
+ " \n",
+ " grads = tape.gradient(loss, model.variables)\n",
+ " optimizer.apply_gradients(zip(grads, model.variables), global_step=tf.train.get_or_create_global_step())\n",
+ "\n",
+ " if batch % 100 == 0:\n",
+ " print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch+1,\n",
+ " batch,\n",
+ " loss))\n",
+ " \n",
+ " print ('Epoch {} Loss {:.4f}'.format(epoch+1, loss))\n",
+ " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "DjGz1tDkzf-u",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Predicting using our trained model\n",
+ "\n",
+ "The below code block is used to generated the text\n",
+ "\n",
+ "* We start by choosing a start string and initializing the hidden state and setting the number of characters we want to generate.\n",
+ "\n",
+ "* We get predictions using the start_string and the hidden state\n",
+ "\n",
+ "* Then we use a multinomial distribution to calculate the index of the predicted word. **We use this predicted word as our next input to the model**\n",
+ "\n",
+ "* **The hidden state returned by the model is fed back into the model so that it now has more context rather than just one word.** After we predict the next word, the modified hidden states are again fed back into the model, which is how it learns as it gets more context from the previously predicted words.\n",
+ "\n",
+ "* If you see the predictions, the model knows when to capitalize, make paragraphs and the text follows a shakespeare style of writing which is pretty awesome!"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WvuwZBX5Ogfd",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Evaluation step(generating text using the model learned)\n",
+ "\n",
+ "# number of characters to generate\n",
+ "num_generate = 1000\n",
+ "\n",
+ "# You can change the start string to experiment\n",
+ "start_string = 'Q'\n",
+ "# converting our start string to numbers(vectorizing!) \n",
+ "input_eval = [char2idx[s] for s in start_string]\n",
+ "input_eval = tf.expand_dims(input_eval, 0)\n",
+ "\n",
+ "# empty string to store our results\n",
+ "text_generated = ''\n",
+ "\n",
+ "# low temperatures results in more predictable text.\n",
+ "# higher temperatures results in more surprising text\n",
+ "# experiment to find the best setting\n",
+ "temperature = 1.0\n",
+ "\n",
+ "# hidden state shape == (batch_size, number of rnn units); here batch size == 1\n",
+ "hidden = [tf.zeros((1, units))]\n",
+ "for i in range(num_generate):\n",
+ " predictions, hidden = model(input_eval, hidden)\n",
+ "\n",
+ " # using a multinomial distribution to predict the word returned by the model\n",
+ " predictions = predictions / temperature\n",
+ " predicted_id = tf.multinomial(tf.exp(predictions), num_samples=1)[0][0].numpy()\n",
+ " \n",
+ " # We pass the predicted word as the next input to the model\n",
+ " # along with the previous hidden state\n",
+ " input_eval = tf.expand_dims([predicted_id], 0)\n",
+ " \n",
+ " text_generated += idx2char[predicted_id]\n",
+ "\n",
+ "print (start_string + text_generated)"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "AM2Uma_-yVIq",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Next steps\n",
+ "\n",
+ "* Change the start string to a different character, or the start of a sentence.\n",
+ "* Experiment with training on a different, or with different parameters. [Project Gutenberg](http://www.gutenberg.org/ebooks/100), for example, contains a large collection of books.\n",
+ "* Experiment with the temperature parameter.\n",
+ "* Add another RNN layer.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gtEd86sX5cB2",
+ "colab_type": "code",
+ "colab": {
+ "autoexec": {
+ "startup": false,
+ "wait_interval": 0
+ }
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+}
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 98b4ce1b26..729d8525fa 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -57,11 +57,6 @@ class Dynamics(tf.keras.Model):
self.eps = tfe.Variable(
initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
- # TODO(lxuechen): Remove this after model.add_weight is in place
- self.vars_not_in_layers = [self.eps]
- self.vars_not_in_layers += self.position_fn.vars_not_in_layers
- self.vars_not_in_layers += self.momentum_fn.vars_not_in_layers
-
def apply_transition(self, position):
"""Propose a new state and perform the accept or reject step."""
@@ -290,86 +285,35 @@ class Dynamics(tf.keras.Model):
return grad
-# Defining loss and grads for training
-def compute_loss(x, dynamics, scale=.1, eps=1e-4):
- """Compute loss defined in equation (8)."""
-
- z = tf.random_normal(tf.shape(x))
- x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
- z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
-
- # Add eps for numerical stability; following released impl
- x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
- z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
-
- loss = tf.reduce_mean(
- (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
-
- return loss, x_out
-
-
-def loss_and_grads(x, dynamics):
- """Obtain loss value and gradients."""
-
- with tf.GradientTape() as tape:
- loss_val, x_out = compute_loss(x, dynamics)
-
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- grads = tape.gradient(loss_val, vars_)
-
- return loss_val, grads, x_out
-
-
-def warmup(dynamics, optimizer, n_iters=1, n_samples=200):
- """Warmup optimization to reduce overhead."""
-
- samples = tf.random_normal(
- shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
-
- for _ in range(n_iters):
- _, grads, samples = loss_and_grads(samples, dynamics)
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- optimizer.apply_gradients(zip(grads, vars_))
-
-
-def fit(dynamics,
- optimizer,
- n_samples=200,
- n_iters=5000,
- verbose=True,
- logdir=None):
- """Fit L2HMC sampler with given log-likelihood function."""
-
- if logdir:
- summary_writer = tf.contrib.summary.create_file_writer(logdir)
+# Examples of unnormalized log density/probabilities
+def get_scg_energy_fn():
+ """Get energy function for 2d strongly correlated Gaussian."""
- samples = tf.random_normal(
- shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+ # Avoid recreating tf constants on each invocation of gradients
+ mu = tf.constant([0., 0.])
+ sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+ sigma_inv = tf.matrix_inverse(sigma)
- tf.train.get_or_create_global_step()
- for i in range(n_iters):
- loss, grads, samples = loss_and_grads(samples, dynamics)
- # TODO(lxuechen): Proper learning rate decay
- grads_ = [grad * .96**(i // 1000) for grad in grads]
- vars_ = dynamics.variables + dynamics.vars_not_in_layers
- optimizer.apply_gradients(
- zip(grads_, vars_), global_step=tf.train.get_global_step())
+ def energy(x):
+ """Unnormalized log density/energy of 2d strongly correlated Gaussian."""
- if verbose:
- print("Iteration %d: loss %.4f" % (i, loss))
+ xmmu = x - mu
+ return .5 * tf.diag_part(
+ tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))
- if logdir:
- with summary_writer.as_default():
- with tf.contrib.summary.always_record_summaries():
- tf.contrib.summary.scalar("loss", loss)
+ return energy
-def get_scg_energy_fn():
+def get_multivariate_gaussian_energy_fn(x_dim=2):
"""Get energy function for 2d strongly correlated Gaussian."""
- # Avoid recreating tf constants on each invocation of gradients
- mu = tf.constant([0., 0.])
- sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
+ mu = tf.random_normal(shape=[x_dim])
+ # Lower triangularize and positive diagonal
+ l = tf.sigmoid(
+ tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0))
+ # Exploit Cholesky decomposition
+ sigma = tf.matmul(l, tf.transpose(l))
+ sigma *= 100. # Small covariance causes extreme numerical instability
sigma_inv = tf.matrix_inverse(sigma)
def energy(x):
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 522a7c9380..e33b4cae4c 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -32,16 +32,83 @@ def get_default_hparams():
n_samples=200,
n_steps=10,
eps=.1,
- n_iters=5,
- learning_rate=.001,
- n_warmup_iters=1)
+ n_iters=10,
+ learning_rate=.0003,
+ n_warmup_iters=3)
+
+
+# Relevant functions for benchmarking
+def compute_loss(dynamics, x, scale=.1, eps=1e-4):
+ """Compute loss defined in equation (8)."""
+
+ z = tf.random_normal(tf.shape(x))
+ x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
+ z_, _, z_accept_prob, _ = dynamics.apply_transition(z)
+
+ # Add eps for numerical stability; following released impl
+ x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
+ z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps
+
+ loss = tf.reduce_mean(
+ (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)
+
+ return loss, x_out
+
+
+def loss_and_grads(dynamics, x, loss_fn=compute_loss):
+ """Obtain loss value and gradients."""
+
+ with tf.GradientTape() as tape:
+ loss_val, x_out = loss_fn(dynamics, x)
+ grads = tape.gradient(loss_val, dynamics.variables)
+
+ return loss_val, grads, x_out
+
+
+def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss):
+ """Warmup optimization to reduce overhead."""
+
+ samples = tf.random_normal(
+ shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
+
+ for _ in range(n_iters):
+ _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+
+def fit(dynamics,
+ samples,
+ optimizer,
+ loss_fn=compute_loss,
+ n_iters=5000,
+ verbose=True,
+ logdir=None,
+ decay_lr=True):
+ """Fit L2HMC sampler with given log-likelihood function."""
+
+ if logdir:
+ summary_writer = tf.contrib.summary.create_file_writer(logdir)
+
+ for i in range(n_iters):
+ loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn)
+ # TODO(lxuechen): Proper learning rate decay
+ if decay_lr:
+ grads = [grad * .96**(i // 1000) for grad in grads]
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+ if verbose:
+ print("Iteration %d: loss %.4f" % (i, loss))
+
+ if logdir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("loss", loss)
class L2hmcTest(tf.test.TestCase):
"""Unit tests for l2hmc in both eager and graph mode."""
- def testComputeLoss(self):
- """Testing function l2hmc.compute_loss in both graph and eager mode."""
+ def test_apply_transition(self):
+ """Testing function `Dynamics.apply_transition` in graph and eager mode."""
# Eager mode testing
hparams = get_default_hparams()
@@ -51,12 +118,12 @@ class L2hmcTest(tf.test.TestCase):
n_steps=hparams.n_steps,
eps=hparams.eps)
samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(samples, dynamics)
+ x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples)
- # Check shape and numerical stability
+ self.assertEqual(x_.shape, v_.shape)
self.assertEqual(x_out.shape, samples.shape)
- self.assertEqual(loss.shape, [])
- self.assertAllClose(loss.numpy(), loss.numpy(), rtol=1e-5)
+ self.assertEqual(x_.shape, x_out.shape)
+ self.assertEqual(x_accept_prob.shape, (hparams.n_samples,))
# Graph mode testing
with tf.Graph().as_default():
@@ -66,65 +133,49 @@ class L2hmcTest(tf.test.TestCase):
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(x, dynamics)
+ x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x)
samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
- loss_np, x_out_np = sess.run([loss, x_out], feed_dict={x: samples})
+ np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run(
+ [x_, v_, x_accept_prob, x_out], feed_dict={x: samples})
- # Check shape and numerical stability
- self.assertEqual(x_out_np.shape, samples.shape)
- self.assertEqual(loss_np.shape, ())
- self.assertAllClose(loss_np, loss_np, rtol=1e-5)
+ self.assertEqual(np_x_.shape, np_v_.shape)
+ self.assertEqual(samples.shape, np_x_out.shape)
+ self.assertEqual(np_x_.shape, np_x_out.shape)
+ self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,))
class L2hmcBenchmark(tf.test.Benchmark):
"""Eager and graph benchmarks for l2hmc."""
- def benchmarkEagerL2hmc(self):
- """Benchmark Eager performance."""
-
- hparams = get_default_hparams()
- dynamics = l2hmc.Dynamics(
- x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
- n_steps=hparams.n_steps,
- eps=hparams.eps)
- # TODO(lxuechen): Add learning rate decay
- optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
-
- # Warmup to reduce initialization effect when timing
- l2hmc.warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters)
+ def _get_energy_fn(self):
+ """Get specific energy function according to FLAGS."""
- # Time
- start_time = time.time()
- l2hmc.fit(
- dynamics,
- optimizer,
- n_samples=hparams.n_samples,
- n_iters=hparams.n_iters)
- wall_time = time.time() - start_time
- examples_per_sec = hparams.n_samples / wall_time
+ if FLAGS.energy_fn == "scg":
+ energy_fn = l2hmc.get_scg_energy_fn()
+ elif FLAGS.energy_fn == "multivariate_gaussian":
+ energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim)
+ else:
+ raise ValueError("No such energy function %s" % FLAGS.energy_fn)
- self.report_benchmark(
- name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"),
- iters=hparams.n_iters,
- extras={"examples_per_sec": examples_per_sec},
- wall_time=wall_time)
+ return energy_fn
- def benchmarkGraphL2hmc(self):
+ def benchmark_graph(self):
"""Benchmark Graph performance."""
hparams = get_default_hparams()
+ tf.reset_default_graph()
with tf.Graph().as_default():
+ energy_fn = self._get_energy_fn()
dynamics = l2hmc.Dynamics(
x_dim=hparams.x_dim,
- loglikelihood_fn=l2hmc.get_scg_energy_fn(),
+ loglikelihood_fn=energy_fn,
n_steps=hparams.n_steps,
eps=hparams.eps)
x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim])
- loss, x_out = l2hmc.compute_loss(x, dynamics)
+ loss, x_out = compute_loss(dynamics, x)
global_step = tf.Variable(0., name="global_step", trainable=False)
learning_rate = tf.train.exponential_decay(
@@ -138,14 +189,15 @@ class L2hmcBenchmark(tf.test.Benchmark):
# Warmup to reduce initialization effect when timing
samples = npr.normal(size=[hparams.n_samples, hparams.x_dim])
for _ in range(hparams.n_warmup_iters):
- samples, _, _, _ = sess.run(
+ _, _, _, _ = sess.run(
[x_out, loss, train_op, learning_rate], feed_dict={x: samples})
- # Time
+ # Training
start_time = time.time()
- for _ in range(hparams.n_iters):
- samples, _, _, _ = sess.run(
+ for i in range(hparams.n_iters):
+ samples, loss_np, _, _ = sess.run(
[x_out, loss, train_op, learning_rate], feed_dict={x: samples})
+ print("Iteration %d: loss %.4f" % (i, loss_np))
wall_time = time.time() - start_time
examples_per_sec = hparams.n_samples / wall_time
@@ -156,7 +208,57 @@ class L2hmcBenchmark(tf.test.Benchmark):
extras={"examples_per_sec": examples_per_sec},
wall_time=wall_time)
+ def benchmark_eager(self):
+ self._benchmark_eager()
+
+ def benchmark_eager_defun(self):
+ self._benchmark_eager(defun=True)
+
+ def _benchmark_eager(self, defun=False):
+ """Benchmark Eager performance."""
+
+ hparams = get_default_hparams()
+ energy_fn = self._get_energy_fn()
+ dynamics = l2hmc.Dynamics(
+ x_dim=hparams.x_dim,
+ loglikelihood_fn=energy_fn,
+ n_steps=hparams.n_steps,
+ eps=hparams.eps)
+ optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
+ loss_fn = tfe.defun(compute_loss) if defun else compute_loss
+
+ # Warmup to reduce initialization effect when timing
+ warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
+
+ # Training
+ samples = tf.random_normal(
+ shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
+ start_time = time.time()
+ fit(dynamics,
+ samples,
+ optimizer,
+ loss_fn=loss_fn,
+ n_iters=hparams.n_iters,
+ decay_lr=True)
+ wall_time = time.time() - start_time
+ examples_per_sec = hparams.n_samples / wall_time
+
+ self.report_benchmark(
+ name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else
+ "cpu", "_defun" if defun else ""),
+ iters=hparams.n_iters,
+ extras={"examples_per_sec": examples_per_sec},
+ wall_time=wall_time)
+
+ del dynamics
+ del loss_fn
+
if __name__ == "__main__":
+ tf.flags.DEFINE_string("energy_fn", "scg",
+ ("The energy function/unnormalized log-probability. "
+ "Either be `scg` or `multivariate_gaussian`"))
+ tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.")
+ FLAGS = tf.flags.FLAGS
tf.enable_eager_execution()
tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
index c902e1f1f4..e230ad5e25 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -57,8 +57,6 @@ class GenericNet(tf.keras.Model):
initial_value=tf.zeros([1, x_dim]),
name='coeff_transformation',
trainable=True)
- # TODO(lxuechen): Remove this after model.add_weight is in place
- self.vars_not_in_layers = [self.coeff_scale, self.coeff_transformation]
def call(self, inputs):
v, x, t = inputs
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD
index 0c0e28dd95..68a84d5fbb 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD
+++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD
@@ -51,5 +51,6 @@ cuda_py_test(
"noasan",
"nomsan",
"notsan",
+ "optonly",
],
)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/BUILD b/tensorflow/contrib/eager/python/examples/revnet/BUILD
new file mode 100644
index 0000000000..0c0e4c0eb9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/BUILD
@@ -0,0 +1,115 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+# Model
+py_library(
+ name = "ops",
+ srcs = ["ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "config",
+ srcs = ["config.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "blocks",
+ srcs = ["blocks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "revnet",
+ srcs = ["revnet.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":blocks",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+# Tests
+cuda_py_test(
+ name = "ops_test",
+ size = "large",
+ srcs = ["ops_test.py"],
+ additional_deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "blocks_test",
+ size = "large",
+ srcs = ["blocks_test.py"],
+ additional_deps = [
+ ":blocks",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "optonly",
+ ],
+)
+
+cuda_py_test(
+ name = "revnet_test",
+ size = "large",
+ srcs = ["revnet_test.py"],
+ additional_deps = [
+ ":blocks_test",
+ ":config",
+ ":revnet",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "no_pip", # depends on blocks_test, which is not available in pip package
+ "optonly",
+ ],
+)
+
+# Training
+py_library(
+ name = "cifar_input",
+ srcs = ["cifar_input.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "cifar_tfrecords",
+ srcs = ["cifar_tfrecords.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
+ name = "main",
+ srcs = ["main.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cifar_input",
+ ":config",
+ ":revnet",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/README.md b/tensorflow/contrib/eager/python/examples/revnet/README.md
new file mode 100644
index 0000000000..21fc44febc
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/README.md
@@ -0,0 +1,45 @@
+# RevNet with TensorFlow eager execution
+
+This folder contains an TensorFlow eager implementation of the [Reversible Residual Network](https://arxiv.org/pdf/1707.04585.pdf) adapted from the released implementation by the authors. The presented implementation can be ran both in eager and graph mode. The code is considerably simplified with `tf.GradientTape`. Moreover, we reduce the step of reconstructing the outputs. This saves us from using `tf.stop_gradient` and makes the model run faster.
+
+## Content
+
+- `revnet.py`: The RevNet model.
+- `blocks.py`: The relevant reversible blocks.
+- `cifar_tfrecords.py`: Script to generate the TFRecords for both CIFAR-10 and CIFAR-100.
+- `cifar_input.py`: Script to read from TFRecords and generate dataset objects with the `tf.data` API.
+- `config.py`: Configuration file for network architectures and training hyperparameters.
+- `main.py`: Main training and evaluation script.
+- `ops.py`: Auxiliary downsampling operation.
+
+## To run
+- Make sure you have installed TensorFlow 1.9+ or the latest `tf-nightly`
+or `tf-nightly-gpu` pip package in order to access the eager execution feature.
+
+- First run
+
+```bash
+python cifar_tfrecords.py --data_dir ${PWD}/cifar
+```
+to download the cifar dataset and convert them
+to TFRecords. This produces TFRecord files for both CIFAR-10 and CIFAR-100.
+
+- To train a model run
+
+```bash
+python main.py --data_dir ${PWD}/cifar
+```
+
+- Optional arguments for `main.py` include
+ - `train_dir`: Directory to store eventfiles and checkpoints.
+ - `restore`: Restore the latest checkpoint.
+ - `validate`: Use validation set for training monitoring.
+ - `manual_grad`: Use the manually defined gradient map given by the authors.
+ - `dataset`: Use either `cifar-10` or `cifar-100`
+
+## Performance
+- With the current implementation, RevNet-38 achieves >92% on CIFAR-10 and >71% on CIFAR-100.
+
+## Reference
+The Reversible Residual Network: Backpropagation Without Storing Activations.
+Aidan N. Gomez, Mengye Ren, Raquel Urtasun, Roger B. Grosse. Neural Information Processing Systems (NIPS), 2017.
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
new file mode 100644
index 0000000000..306096e9f8
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py
@@ -0,0 +1,357 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Building blocks with manual backward gradient computation.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import ops
+
+
+class RevBlock(tf.keras.Model):
+ """Single reversible block containing several `_Residual` blocks.
+
+ Each `_Residual` block in turn contains two _ResidualInner blocks,
+ corresponding to the `F`/`G` functions in the paper.
+ """
+
+ def __init__(self,
+ n_res,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=False,
+ data_format="channels_first",
+ bottleneck=False,
+ fused=True,
+ dtype=tf.float32):
+ """Initialize RevBlock.
+
+ Args:
+ n_res: number of residual blocks
+ filters: list/tuple of integers for output filter sizes of each residual
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ bottleneck: use bottleneck residual if True
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+ """
+ super(RevBlock, self).__init__()
+ self.blocks = tf.contrib.checkpoint.List()
+ for i in range(n_res):
+ curr_batch_norm_first = batch_norm_first and i == 0
+ curr_strides = strides if i == 0 else (1, 1)
+ block = _Residual(
+ filters,
+ curr_strides,
+ input_shape,
+ batch_norm_first=curr_batch_norm_first,
+ data_format=data_format,
+ bottleneck=bottleneck,
+ fused=fused,
+ dtype=dtype)
+ self.blocks.append(block)
+
+ if data_format == "channels_first":
+ input_shape = (filters, input_shape[1] // curr_strides[0],
+ input_shape[2] // curr_strides[1])
+ else:
+ input_shape = (input_shape[0] // curr_strides[0],
+ input_shape[1] // curr_strides[1], filters)
+
+ def call(self, h, training=True):
+ """Apply reversible block to inputs."""
+
+ for block in self.blocks:
+ h = block(h, training=training)
+ return h
+
+ def backward_grads_and_vars(self, x, y, dy, training=True):
+ """Apply reversible block backward to outputs."""
+
+ grads_all = []
+ vars_all = []
+
+ for i in reversed(range(len(self.blocks))):
+ block = self.blocks[i]
+ if i == 0:
+ # First block usually contains downsampling that can't be reversed
+ with tf.GradientTape() as tape:
+ x = tf.identity(x)
+ tape.watch(x)
+ y = block(x, training=training)
+
+ grads_combined = tape.gradient(
+ y, [x] + block.trainable_variables, output_gradients=dy)
+ dy = grads_combined[0]
+ grads_all += grads_combined[1:]
+ vars_all += block.trainable_variables
+ else:
+ y, dy, grads, vars_ = block.backward_grads_and_vars(
+ y, dy, training=training)
+ grads_all += grads
+ vars_all += vars_
+
+ return dy, grads_all, vars_all
+
+
+class _Residual(tf.keras.Model):
+ """Single residual block contained in a _RevBlock. Each `_Residual` object has
+ two _ResidualInner objects, corresponding to the `F` and `G` functions in the
+ paper.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC",
+ bottleneck: use bottleneck residual if True
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+ """
+
+ def __init__(self,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ bottleneck=False,
+ fused=True,
+ dtype=tf.float32):
+ super(_Residual, self).__init__()
+
+ self.filters = filters
+ self.strides = strides
+ self.axis = 1 if data_format == "channels_first" else 3
+ if data_format == "channels_first":
+ f_input_shape = (input_shape[0] // 2,) + input_shape[1:]
+ g_input_shape = (filters // 2, input_shape[1] // strides[0],
+ input_shape[2] // strides[1])
+ else:
+ f_input_shape = input_shape[:2] + (input_shape[2] // 2,)
+ g_input_shape = (input_shape[0] // strides[0],
+ input_shape[1] // strides[1], filters // 2)
+
+ factory = _BottleneckResidualInner if bottleneck else _ResidualInner
+ self.f = factory(
+ filters=filters // 2,
+ strides=strides,
+ input_shape=f_input_shape,
+ batch_norm_first=batch_norm_first,
+ data_format=data_format,
+ fused=fused,
+ dtype=dtype)
+ self.g = factory(
+ filters=filters // 2,
+ strides=(1, 1),
+ input_shape=g_input_shape,
+ batch_norm_first=batch_norm_first,
+ data_format=data_format,
+ fused=fused,
+ dtype=dtype)
+
+ def call(self, x, training=True, concat=True):
+ """Apply residual block to inputs."""
+
+ x1, x2 = tf.split(x, num_or_size_splits=2, axis=self.axis)
+ f_x2 = self.f(x2, training=training)
+ x1_down = ops.downsample(
+ x1, self.filters // 2, self.strides, axis=self.axis)
+ x2_down = ops.downsample(
+ x2, self.filters // 2, self.strides, axis=self.axis)
+ y1 = f_x2 + x1_down
+ g_y1 = self.g(y1, training=training)
+ y2 = g_y1 + x2_down
+ if not concat: # For correct backward grads
+ return y1, y2
+
+ return tf.concat([y1, y2], axis=self.axis)
+
+ def backward_grads_and_vars(self, y, dy, training=True):
+ """Manually compute backward gradients given input and output grads."""
+ dy1, dy2 = tf.split(dy, num_or_size_splits=2, axis=self.axis)
+
+ with tf.GradientTape(persistent=True) as tape:
+ y = tf.identity(y)
+ tape.watch(y)
+ y1, y2 = tf.split(y, num_or_size_splits=2, axis=self.axis)
+ z1 = y1
+ gz1 = self.g(z1, training=training)
+ x2 = y2 - gz1
+ fx2 = self.f(x2, training=training)
+ x1 = z1 - fx2
+
+ grads_combined = tape.gradient(
+ gz1, [z1] + self.g.trainable_variables, output_gradients=dy2)
+ dz1 = dy1 + grads_combined[0]
+ dg = grads_combined[1:]
+ dx1 = dz1
+
+ grads_combined = tape.gradient(
+ fx2, [x2] + self.f.trainable_variables, output_gradients=dz1)
+ dx2 = dy2 + grads_combined[0]
+ df = grads_combined[1:]
+
+ del tape
+
+ grads = df + dg
+ vars_ = self.f.trainable_variables + self.g.trainable_variables
+
+ x = tf.concat([x1, x2], axis=self.axis)
+ dx = tf.concat([dx1, dx2], axis=self.axis)
+
+ return x, dx, grads, vars_
+
+
+def _BottleneckResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True,
+ dtype=tf.float32):
+ """Single bottleneck residual inner function contained in _Resdual.
+
+ Corresponds to the `F`/`G` functions in the paper.
+ Suitable for training on ImageNet dataset.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+
+ Returns:
+ A keras model
+ """
+
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=1,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters // 4,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ return model
+
+
+def _ResidualInner(filters,
+ strides,
+ input_shape,
+ batch_norm_first=True,
+ data_format="channels_first",
+ fused=True,
+ dtype=tf.float32):
+ """Single residual inner function contained in _ResdualBlock.
+
+ Corresponds to the `F`/`G` functions in the paper.
+
+ Args:
+ filters: output filter size
+ strides: length 2 list/tuple of integers for height and width strides
+ input_shape: length 3 list/tuple of integers
+ batch_norm_first: whether to apply activation and batch norm before conv
+ data_format: tensor data format, "NCHW"/"NHWC"
+ fused: use fused batch normalization if True
+ dtype: float16, float32, or float64
+
+ Returns:
+ A keras model
+ """
+
+ axis = 1 if data_format == "channels_first" else 3
+ model = tf.keras.Sequential()
+ if batch_norm_first:
+ model.add(
+ tf.keras.layers.BatchNormalization(
+ axis=axis, input_shape=input_shape, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=strides,
+ input_shape=input_shape,
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ model.add(
+ tf.keras.layers.BatchNormalization(axis=axis, fused=fused, dtype=dtype))
+ model.add(tf.keras.layers.Activation("relu"))
+ model.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=3,
+ strides=(1, 1),
+ data_format=data_format,
+ use_bias=False,
+ padding="SAME",
+ dtype=dtype))
+
+ return model
diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
new file mode 100644
index 0000000000..d74785c8fe
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/blocks_test.py
@@ -0,0 +1,304 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic building blocks used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks
+
+
+def compute_degree(g1, g2, eps=1e-7):
+ """Compute the degree between two vectors using their usual inner product."""
+
+ def _dot(u, v):
+ return tf.reduce_sum(u * v)
+
+ g1_norm = tf.sqrt(_dot(g1, g1))
+ g2_norm = tf.sqrt(_dot(g2, g2))
+ if g1_norm.numpy() == 0 and g2_norm.numpy() == 0:
+ cosine = 1. - eps
+ else:
+ g1_norm = 1. if g1_norm.numpy() == 0 else g1_norm
+ g2_norm = 1. if g2_norm.numpy() == 0 else g2_norm
+ cosine = _dot(g1, g2) / g1_norm / g2_norm
+ # Restrict to arccos range
+ cosine = tf.minimum(tf.maximum(cosine, eps - 1.), 1. - eps)
+ degree = tf.acos(cosine) * 180. / 3.141592653589793
+
+ return degree
+
+
+def _validate_block_call_channels_last(block_factory, test):
+ """Generic testing function for `channels_last` data format.
+
+ Completes a set of tests varying data format, stride, and batch normalization
+ configured train vs test time.
+ Args:
+ block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock,
+ blocks._ResidualInner
+ test: tf.test.TestCase object
+ """
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (8, 8, 128)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ block = block_factory(
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 8, 8, 128))
+ test.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = block_factory(
+ filters=128,
+ strides=(2, 2),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 4, 4, 128))
+ test.assertNotAllClose(y_tr, y_ev)
+
+
+def _validate_block_call_channels_first(block_factory, test):
+ """Generic testing function for `channels_first` data format.
+
+ Completes a set of tests varying data format, stride, and batch normalization
+ configured train vs test time.
+ Args:
+ block_factory: constructor of one of blocks.InitBlock, blocks.FinalBlock,
+ blocks._ResidualInner
+ test: tf.test.TestCase object
+ """
+ if not tf.test.is_gpu_available():
+ test.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (128, 8, 8)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride of 1
+ block = block_factory(filters=128, strides=(1, 1), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 128, 8, 8))
+ test.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = block_factory(filters=128, strides=(2, 2), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ test.assertEqual(y_tr.shape, y_ev.shape)
+ test.assertEqual(y_ev.shape, (16, 128, 4, 4))
+ test.assertNotAllClose(y_tr, y_ev)
+
+
+class RevBlockTest(tf.test.TestCase):
+
+ def test_call_channels_first(self):
+ """Test `call` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (128, 8, 8)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride of 1
+ block = blocks.RevBlock(
+ n_res=3, filters=128, strides=(1, 1), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 128, 8, 8))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = blocks.RevBlock(
+ n_res=3, filters=128, strides=(2, 2), input_shape=input_shape)
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, [16, 128, 4, 4])
+ self.assertNotAllClose(y_tr, y_ev)
+
+ def test_call_channels_last(self):
+ """Test `call` function with `channels_last` data format."""
+ with tf.device("/cpu:0"): # NHWC format
+ input_shape = (8, 8, 128)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape)
+
+ # Stride 1
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 8, 8, 128))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ # Stride of 2
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=128,
+ strides=(2, 2),
+ input_shape=input_shape,
+ data_format="channels_last")
+ y_tr, y_ev = block(x, training=True), block(x, training=False)
+ self.assertEqual(y_tr.shape, y_ev.shape)
+ self.assertEqual(y_ev.shape, (16, 4, 4, 128))
+ self.assertNotAllClose(y_tr, y_ev)
+
+ def _check_grad_angle(self, grads, grads_true, atol=1e0):
+ """Check the angle between two list of vectors are all close."""
+ for g1, g2 in zip(grads, grads_true):
+ degree = compute_degree(g1, g2)
+ self.assertLessEqual(degree, atol)
+
+ def test_backward_grads_and_vars_channels_first(self):
+ """Test `backward` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ # Stride 1
+ input_shape = (128, 8, 8)
+ data_shape = (16,) + input_shape
+ x = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ fused=False,
+ dtype=tf.float64)
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ y = block(x, training=True)
+ # Compute grads from reconstruction
+ dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ # Compute true grads
+ grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
+ self._check_grad_angle(dx_true, dx)
+ self._check_grad_angle(dw_true, dw)
+
+ # Stride 2
+ x = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=(16, 128, 4, 4), dtype=tf.float64)
+ block = blocks.RevBlock(
+ n_res=3,
+ filters=128,
+ strides=(2, 2),
+ input_shape=input_shape,
+ fused=False,
+ dtype=tf.float64)
+ with tf.GradientTape() as tape:
+ tape.watch(x)
+ y = block(x, training=True)
+ # Compute grads from reconstruction
+ dx, dw, vars_ = block.backward_grads_and_vars(x, y, dy, training=True)
+ # Compute true grads
+ grads = tape.gradient(y, [x] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
+ self._check_grad_angle(dx_true, dx)
+ self._check_grad_angle(dw_true, dw)
+
+
+class _ResidualTest(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function.
+
+ Varying downsampling and data format options.
+ """
+
+ _validate_block_call_channels_first(blocks._Residual, self)
+ _validate_block_call_channels_last(blocks._Residual, self)
+
+ def test_backward_grads_and_vars_channels_first(self):
+ """Test `backward_grads` function with `channels_first` data format."""
+ if not tf.test.is_gpu_available():
+ self.skipTest("GPU not available")
+
+ with tf.device("/gpu:0"): # Default NCHW format
+ input_shape = (128, 8, 8)
+ data_shape = (16,) + input_shape
+ # Use double precision for testing
+ x_true = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ dy = tf.random_normal(shape=data_shape, dtype=tf.float64)
+ residual = blocks._Residual(
+ filters=128,
+ strides=(1, 1),
+ input_shape=input_shape,
+ fused=False,
+ dtype=tf.float64)
+
+ with tf.GradientTape() as tape:
+ x_true = tf.identity(x_true)
+ tape.watch(x_true)
+ y = residual(x_true, training=True)
+
+ # Gradients computed due to reversibility
+ x, dx, dw, vars_ = residual.backward_grads_and_vars(
+ y, dy=dy, training=True)
+
+ # True gradients computed by the tape
+ grads = tape.gradient(y, [x_true] + vars_, output_gradients=dy)
+ dx_true, dw_true = grads[0], grads[1:]
+
+ self.assertAllClose(x_true, x)
+ self.assertAllClose(dx_true, dx)
+ self.assertAllClose(dw_true, dw)
+
+
+class _ResidualInnerTest(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function."""
+
+ _validate_block_call_channels_first(blocks._ResidualInner, self)
+ _validate_block_call_channels_last(blocks._ResidualInner, self)
+
+
+class _BottleneckResidualInner(tf.test.TestCase):
+
+ def test_call(self):
+ """Test `call` function."""
+
+ _validate_block_call_channels_first(blocks._BottleneckResidualInner, self)
+ _validate_block_call_channels_last(blocks._BottleneckResidualInner, self)
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
new file mode 100644
index 0000000000..b6d4c35bfd
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_input.py
@@ -0,0 +1,116 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script for reading and loading CIFAR-10."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+# Global constants describing the CIFAR data set.
+IMAGE_HEIGHT = 32
+IMAGE_WIDTH = 32
+NUM_CHANNEL = 3
+
+
+def get_ds_from_tfrecords(data_dir,
+ split,
+ data_aug=True,
+ batch_size=100,
+ epochs=None,
+ shuffle=True,
+ data_format="channels_first",
+ num_parallel_calls=12,
+ prefetch=0,
+ div255=True,
+ dtype=tf.float32):
+ """Returns a tf.train.Dataset object from reading tfrecords.
+
+ Args:
+ data_dir: Directory of tfrecords
+ split: "train", "validation", or "test"
+ data_aug: Apply data augmentation if True
+ batch_size: Batch size of dataset object
+ epochs: Number of epochs to repeat the dataset; default `None` means
+ repeating indefinitely
+ shuffle: Shuffle the dataset if True
+ data_format: `channels_first` or `channels_last`
+ num_parallel_calls: Number of threads for dataset preprocess
+ prefetch: Buffer size for prefetch
+ div255: Divide the images by 255 if True
+ dtype: Data type of images
+ Returns:
+ A tf.train.Dataset object
+
+ Raises:
+ ValueError: Unknown split
+ """
+
+ if split not in ["train", "validation", "test", "train_all"]:
+ raise ValueError("Unknown split {}".format(split))
+
+ def _parser(serialized_example):
+ """Parses a single tf.Example into image and label tensors."""
+ features = tf.parse_single_example(
+ serialized_example,
+ features={
+ "image": tf.FixedLenFeature([], tf.string),
+ "label": tf.FixedLenFeature([], tf.int64),
+ })
+ image = tf.decode_raw(features["image"], tf.uint8)
+ # Initially reshaping to [H, W, C] does not work
+ image = tf.reshape(image, [NUM_CHANNEL, IMAGE_HEIGHT, IMAGE_WIDTH])
+ # This is needed for `tf.image.resize_image_with_crop_or_pad`
+ image = tf.transpose(image, [1, 2, 0])
+
+ image = tf.cast(image, dtype)
+ label = tf.cast(features["label"], tf.int32)
+
+ if data_aug:
+ image = tf.image.resize_image_with_crop_or_pad(image, IMAGE_HEIGHT + 4,
+ IMAGE_WIDTH + 4)
+ image = tf.random_crop(image, [IMAGE_HEIGHT, IMAGE_WIDTH, NUM_CHANNEL])
+ image = tf.image.random_flip_left_right(image)
+
+ if data_format == "channels_first":
+ image = tf.transpose(image, [2, 0, 1])
+
+ if div255:
+ image /= 255.
+
+ return image, label
+
+ filename = os.path.join(data_dir, split + ".tfrecords")
+ dataset = tf.data.TFRecordDataset(filename)
+ dataset = dataset.repeat(epochs)
+ dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls)
+ dataset = dataset.prefetch(prefetch)
+
+ if shuffle:
+ # Find the right size according to the split
+ size = {
+ "train": 40000,
+ "validation": 10000,
+ "test": 10000,
+ "train_all": 50000
+ }[split]
+ dataset = dataset.shuffle(size)
+
+ dataset = dataset.batch(batch_size)
+
+ return dataset
diff --git a/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
new file mode 100644
index 0000000000..377844ad8f
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/cifar_tfrecords.py
@@ -0,0 +1,154 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Read CIFAR data from pickled numpy arrays and writes TFRecords.
+
+Generates tf.train.Example protos and writes them to TFRecord files from the
+python version of the CIFAR dataset downloaded from
+https://www.cs.toronto.edu/~kriz/cifar.html.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import tarfile
+
+from absl import flags
+from six.moves import cPickle as pickle
+from six.moves import urllib
+import tensorflow as tf
+
+BASE_URL = 'https://www.cs.toronto.edu/~kriz/'
+CIFAR_FILE_NAMES = ['cifar-10-python.tar.gz', 'cifar-100-python.tar.gz']
+CIFAR_DOWNLOAD_URLS = [BASE_URL + name for name in CIFAR_FILE_NAMES]
+CIFAR_LOCAL_FOLDERS = ['cifar-10', 'cifar-100']
+EXTRACT_FOLDERS = ['cifar-10-batches-py', 'cifar-100-python']
+
+
+def download_and_extract(data_dir, file_name, url):
+ """Download CIFAR if not already downloaded."""
+ filepath = os.path.join(data_dir, file_name)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(data_dir):
+ tf.gfile.MakeDirs(data_dir)
+
+ urllib.request.urlretrieve(url, filepath)
+ tarfile.open(os.path.join(filepath), 'r:gz').extractall(data_dir)
+ return filepath
+
+
+def _int64_feature(value):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _get_file_names(folder):
+ """Returns the file names expected to exist in the input_dir."""
+ assert folder in ['cifar-10', 'cifar-100']
+
+ file_names = {}
+ if folder == 'cifar-10':
+ file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
+ file_names['validation'] = ['data_batch_5']
+ file_names['train_all'] = ['data_batch_%d' % i for i in range(1, 6)]
+ file_names['test'] = ['test_batch']
+ else:
+ file_names['train_all'] = ['train']
+ file_names['test'] = ['test']
+ # Split in `convert_to_tfrecord` function
+ file_names['train'] = ['train']
+ file_names['validation'] = ['train']
+ return file_names
+
+
+def read_pickle_from_file(filename):
+ with tf.gfile.Open(filename, 'rb') as f:
+ if sys.version_info >= (3, 0):
+ data_dict = pickle.load(f, encoding='bytes')
+ else:
+ data_dict = pickle.load(f)
+ return data_dict
+
+
+def convert_to_tfrecord(input_files, output_file, folder):
+ """Converts files with pickled data to TFRecords."""
+ assert folder in ['cifar-10', 'cifar-100']
+
+ print('Generating %s' % output_file)
+ with tf.python_io.TFRecordWriter(output_file) as record_writer:
+ for input_file in input_files:
+ data_dict = read_pickle_from_file(input_file)
+ data = data_dict[b'data']
+ try:
+ labels = data_dict[b'labels']
+ except KeyError:
+ labels = data_dict[b'fine_labels']
+
+ if folder == 'cifar-100' and input_file.endswith('train.tfrecords'):
+ data = data[:40000]
+ labels = labels[:40000]
+ elif folder == 'cifar-100' and input_file.endswith(
+ 'validation.tfrecords'):
+ data = data[40000:]
+ labels = labels[40000:]
+
+ num_entries_in_batch = len(labels)
+
+ for i in range(num_entries_in_batch):
+ example = tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'image': _bytes_feature(data[i].tobytes()),
+ 'label': _int64_feature(labels[i])
+ }))
+ record_writer.write(example.SerializeToString())
+
+
+def main(_):
+ for file_name, url, folder, extract_folder in zip(
+ CIFAR_FILE_NAMES, CIFAR_DOWNLOAD_URLS, CIFAR_LOCAL_FOLDERS,
+ EXTRACT_FOLDERS):
+ print('Download from {} and extract.'.format(url))
+ data_dir = os.path.join(FLAGS.data_dir, folder)
+ download_and_extract(data_dir, file_name, url)
+ file_names = _get_file_names(folder)
+ input_dir = os.path.join(data_dir, extract_folder)
+
+ for mode, files in file_names.items():
+ input_files = [os.path.join(input_dir, f) for f in files]
+ output_file = os.path.join(data_dir, mode + '.tfrecords')
+ try:
+ os.remove(output_file)
+ except OSError:
+ pass
+ convert_to_tfrecord(input_files, output_file, folder)
+
+ print('Done!')
+
+
+if __name__ == '__main__':
+ FLAGS = flags.FLAGS
+ flags.DEFINE_string(
+ 'data_dir',
+ default=None,
+ help='Directory to download, extract and store TFRecords.')
+
+ tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/config.py b/tensorflow/contrib/eager/python/examples/revnet/config.py
new file mode 100644
index 0000000000..3d93fa955a
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/config.py
@@ -0,0 +1,140 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Configuration in format of tf.contrib.training.HParams.
+Supports CIFAR-10, CIFAR-100, and ImageNet datasets.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+tfe = tf.contrib.eager
+
+
+def get_hparams_cifar_38():
+ """RevNet-38 configurations for CIFAR-10/CIFAR-100."""
+
+ config = tf.contrib.training.HParams()
+ config.add_hparam("init_filters", 32)
+ config.add_hparam("init_kernel", 3)
+ config.add_hparam("init_stride", 1)
+ config.add_hparam("n_classes", 10)
+ config.add_hparam("n_rev_blocks", 3)
+ config.add_hparam("n_res", [3, 3, 3])
+ config.add_hparam("filters", [32, 64, 112])
+ config.add_hparam("strides", [1, 2, 2])
+ config.add_hparam("batch_size", 100)
+ config.add_hparam("bottleneck", False)
+ config.add_hparam("fused", True)
+ config.add_hparam("init_max_pool", False)
+ if tfe.num_gpus() > 0:
+ config.add_hparam("input_shape", (3, 32, 32))
+ config.add_hparam("data_format", "channels_first")
+ else:
+ config.add_hparam("input_shape", (32, 32, 3))
+ config.add_hparam("data_format", "channels_last")
+
+ # Training details
+ config.add_hparam("weight_decay", 2e-4)
+ config.add_hparam("momentum", .9)
+ config.add_hparam("lr_decay_steps", [40000, 60000])
+ config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3])
+ config.add_hparam("max_train_iter", 80000)
+ config.add_hparam("seed", 1234)
+ config.add_hparam("shuffle", True)
+ config.add_hparam("log_every", 500)
+ config.add_hparam("save_every", 500)
+ config.add_hparam("dtype", tf.float32)
+ config.add_hparam("eval_batch_size", 1000)
+ config.add_hparam("div255", True)
+ # This is imprecise, when training with validation set,
+ # we only have 40k images in training data
+ config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
+ config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
+
+ return config
+
+
+def get_hparams_cifar_110():
+ config = get_hparams_cifar_38()
+ config.filters = [32, 64, 128]
+ config.n_res = [9, 9, 9]
+
+ return config
+
+
+def get_hparams_cifar_164():
+ config = get_hparams_cifar_38()
+ config.filters = [32, 64, 128]
+ config.n_res = [9, 9, 9]
+ config.use_bottleneck = True
+ # Due to bottleneck residual blocks
+ filters = [f * 4 for f in config.filters]
+ config.filters = filters
+
+ return config
+
+
+def get_hparams_imagenet_56():
+ """RevNet-56 configurations for ImageNet."""
+
+ config = tf.contrib.training.HParams()
+ config.add_hparam("init_filters", 128)
+ config.add_hparam("init_kernel", 7)
+ config.add_hparam("init_stride", 2)
+ config.add_hparam("n_classes", 1000)
+ config.add_hparam("n_rev_blocks", 4)
+ config.add_hparam("n_res", [2, 2, 2, 2])
+ config.add_hparam("filters", [128, 256, 512, 832])
+ config.add_hparam("strides", [1, 2, 2, 2])
+ config.add_hparam("batch_size", 16)
+ config.add_hparam("bottleneck", True)
+ config.add_hparam("fused", True)
+ config.add_hparam("init_max_pool", True)
+ if tf.test.is_gpu_available():
+ config.add_hparam("input_shape", (3, 224, 224))
+ config.add_hparam("data_format", "channels_first")
+ else:
+ config.add_hparam("input_shape", (224, 224, 3))
+ config.add_hparam("data_format", "channels_last")
+
+ # Training details
+ config.add_hparam("weight_decay", 1e-4)
+ config.add_hparam("momentum", .9)
+ config.add_hparam("lr_decay_steps", [160000, 320000, 480000])
+ config.add_hparam("lr_list", [1e-1, 1e-2, 1e-3, 1e-4])
+ config.add_hparam("max_train_iter", 600000)
+ config.add_hparam("seed", 1234)
+ config.add_hparam("shuffle", True)
+ config.add_hparam("log_every", 50)
+ config.add_hparam("save_every", 50)
+ config.add_hparam("dtype", tf.float32)
+ config.add_hparam("eval_batch_size", 1000)
+ config.add_hparam("div255", True)
+ # TODO(lxuechen): Update this according to ImageNet data
+ config.add_hparam("iters_per_epoch", 50000 // config.batch_size)
+ config.add_hparam("epochs", config.max_train_iter // config.iters_per_epoch)
+ # Due to bottleneck residual blocks
+ filters = [f * 4 for f in config.filters]
+ config.filters = filters
+
+ return config
diff --git a/tensorflow/contrib/eager/python/examples/revnet/main.py b/tensorflow/contrib/eager/python/examples/revnet/main.py
new file mode 100644
index 0000000000..e2f43b03f9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/main.py
@@ -0,0 +1,256 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Eager execution workflow with RevNet train on CIFAR-10."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+from absl import flags
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import cifar_input
+from tensorflow.contrib.eager.python.examples.revnet import config as config_
+from tensorflow.contrib.eager.python.examples.revnet import revnet
+tfe = tf.contrib.eager
+
+
+def main(_):
+ """Eager execution workflow with RevNet trained on CIFAR-10."""
+ config = get_config()
+ ds_train, ds_train_one_shot, ds_validation, ds_test = get_datasets(config)
+ model = revnet.RevNet(config=config)
+ global_step = tf.train.get_or_create_global_step() # Ensure correct summary
+ global_step.assign(1)
+ learning_rate = tf.train.piecewise_constant(
+ global_step, config.lr_decay_steps, config.lr_list)
+ optimizer = tf.train.MomentumOptimizer(
+ learning_rate, momentum=config.momentum)
+ checkpointer = tf.train.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=global_step)
+
+ if FLAGS.train_dir:
+ summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
+ if FLAGS.restore:
+ latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
+ checkpointer.restore(latest_path)
+ print("Restored latest checkpoint at path:\"{}\" "
+ "with global_step: {}".format(latest_path, global_step.numpy()))
+ sys.stdout.flush()
+
+ if FLAGS.manual_grad:
+ print("Using manual gradients.")
+ else:
+ print("Not using manual gradients.")
+ sys.stdout.flush()
+
+ for x, y in ds_train:
+ train_one_iter(model, x, y, optimizer, global_step=global_step)
+
+ if global_step.numpy() % config.log_every == 0:
+ it_train = ds_train_one_shot.make_one_shot_iterator()
+ it_test = ds_test.make_one_shot_iterator()
+ acc_train, loss_train = evaluate(model, it_train)
+ acc_test, loss_test = evaluate(model, it_test)
+
+ if FLAGS.validate:
+ it_validation = ds_validation.make_one_shot_iterator()
+ acc_validation, loss_validation = evaluate(model, it_validation)
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "validation set accuracy {:.4f}, loss {:4.f}"
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_validation,
+ loss_validation, acc_test, loss_test))
+ else:
+ print("Iter {}, "
+ "training set accuracy {:.4f}, loss {:.4f}; "
+ "test accuracy {:.4f}, loss {:.4f}".format(
+ global_step.numpy(), acc_train, loss_train, acc_test,
+ loss_test))
+ sys.stdout.flush()
+
+ if FLAGS.train_dir:
+ with summary_writer.as_default():
+ with tf.contrib.summary.always_record_summaries():
+ tf.contrib.summary.scalar("Training accuracy", acc_train)
+ tf.contrib.summary.scalar("Test accuracy", acc_test)
+ tf.contrib.summary.scalar("Training loss", loss_train)
+ tf.contrib.summary.scalar("Test loss", loss_test)
+ if FLAGS.validate:
+ tf.contrib.summary.scalar("Validation accuracy", acc_validation)
+ tf.contrib.summary.scalar("Validation loss", loss_validation)
+
+ if global_step.numpy() % config.save_every == 0 and FLAGS.train_dir:
+ saved_path = checkpointer.save(
+ file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
+ print("Saved checkpoint at path: \"{}\" "
+ "with global_step: {}".format(saved_path, global_step.numpy()))
+ sys.stdout.flush()
+
+
+def get_config():
+ """Return configuration."""
+ print("Config: {}".format(FLAGS.config))
+ sys.stdout.flush()
+ config = {
+ "revnet-38": config_.get_hparams_cifar_38(),
+ "revnet-110": config_.get_hparams_cifar_110(),
+ "revnet-164": config_.get_hparams_cifar_164(),
+ }[FLAGS.config]
+
+ if FLAGS.dataset == "cifar-100":
+ config.n_classes = 100
+
+ return config
+
+
+def get_datasets(config):
+ """Return dataset."""
+ if FLAGS.data_dir is None:
+ raise ValueError("No supplied data directory")
+ if not os.path.exists(FLAGS.data_dir):
+ raise ValueError("Data directory {} does not exist".format(FLAGS.data_dir))
+ if FLAGS.dataset not in ["cifar-10", "cifar-100"]:
+ raise ValueError("Unknown dataset {}".format(FLAGS.dataset))
+
+ print("Training on {} dataset.".format(FLAGS.dataset))
+ sys.stdout.flush()
+ data_dir = os.path.join(FLAGS.data_dir, FLAGS.dataset)
+ if FLAGS.validate:
+ # 40k Training set
+ ds_train = cifar_input.get_ds_from_tfrecords(
+ data_dir=data_dir,
+ split="train",
+ data_aug=True,
+ batch_size=config.batch_size,
+ epochs=config.epochs,
+ shuffle=config.shuffle,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.batch_size)
+ # 10k Training set
+ ds_validation = cifar_input.get_ds_from_tfrecords(
+ data_dir=data_dir,
+ split="validation",
+ data_aug=False,
+ batch_size=config.eval_batch_size,
+ epochs=1,
+ shuffle=False,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.eval_batch_size)
+ else:
+ # 50k Training set
+ ds_train = cifar_input.get_ds_from_tfrecords(
+ data_dir=data_dir,
+ split="train_all",
+ data_aug=True,
+ batch_size=config.batch_size,
+ epochs=config.epochs,
+ shuffle=config.shuffle,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.batch_size)
+ ds_validation = None
+
+ # Always compute loss and accuracy on whole training and test set
+ ds_train_one_shot = cifar_input.get_ds_from_tfrecords(
+ data_dir=data_dir,
+ split="train_all",
+ data_aug=False,
+ batch_size=config.eval_batch_size,
+ epochs=1,
+ shuffle=False,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.eval_batch_size)
+
+ ds_test = cifar_input.get_ds_from_tfrecords(
+ data_dir=data_dir,
+ split="test",
+ data_aug=False,
+ batch_size=config.eval_batch_size,
+ epochs=1,
+ shuffle=False,
+ data_format=config.data_format,
+ dtype=config.dtype,
+ prefetch=config.eval_batch_size)
+
+ return ds_train, ds_train_one_shot, ds_validation, ds_test
+
+
+def train_one_iter(model, inputs, labels, optimizer, global_step=None):
+ """Train for one iteration."""
+ if FLAGS.manual_grad:
+ grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+ else: # For correctness validation
+ with tf.GradientTape() as tape:
+ logits, _ = model(inputs, training=True)
+ loss = model.compute_loss(logits=logits, labels=labels)
+ tf.logging.info("Logits are placed on device: {}".format(logits.device))
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(
+ zip(grads, model.trainable_variables), global_step=global_step)
+
+ return loss.numpy()
+
+
+def evaluate(model, iterator):
+ """Compute accuracy with the given dataset iterator."""
+ mean_loss = tfe.metrics.Mean()
+ accuracy = tfe.metrics.Accuracy()
+ for x, y in iterator:
+ logits, _ = model(x, training=False)
+ loss = model.compute_loss(logits=logits, labels=y)
+ accuracy(
+ labels=tf.cast(y, tf.int64),
+ predictions=tf.argmax(logits, axis=1, output_type=tf.int64))
+ mean_loss(loss)
+
+ return accuracy.result().numpy(), mean_loss.result().numpy()
+
+
+if __name__ == "__main__":
+ flags.DEFINE_string(
+ "data_dir", default=None, help="Directory to load tfrecords")
+ flags.DEFINE_string(
+ "train_dir",
+ default=None,
+ help="[Optional] Directory to store the training information")
+ flags.DEFINE_boolean(
+ "restore",
+ default=False,
+ help="[Optional] Restore the latest checkpoint from `train_dir` if True")
+ flags.DEFINE_boolean(
+ "validate",
+ default=False,
+ help="[Optional] Use the validation set or not for hyperparameter search")
+ flags.DEFINE_boolean(
+ "manual_grad",
+ default=False,
+ help="[Optional] Use manual gradient graph to save memory")
+ flags.DEFINE_string(
+ "dataset",
+ default="cifar-10",
+ help="[Optional] The dataset used; either `cifar-10` or `cifar-100`")
+ flags.DEFINE_string(
+ "config", default="revnet-38", help="[Optional] Architecture of network.")
+ FLAGS = flags.FLAGS
+ tf.enable_eager_execution()
+ tf.app.run(main)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops.py b/tensorflow/contrib/eager/python/examples/revnet/ops.py
new file mode 100644
index 0000000000..9ed5d363e6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Customized basic operations.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def downsample(x, filters, strides, axis=1):
+ """Downsample feature map with avg pooling, if filter size doesn't match."""
+
+ def pad_strides(strides, axis=1):
+ """Convert length 2 to length 4 strides.
+
+ Needed since `tf.layers.Conv2D` uses length 2 strides, whereas operations
+ such as `tf.nn.avg_pool` use length 4 strides.
+
+ Args:
+ strides: length 2 list/tuple strides for height and width
+ axis: integer specifying feature dimension according to data format
+ Returns:
+ length 4 strides padded with 1 on batch and channel dimension
+ """
+
+ assert len(strides) == 2
+
+ if axis == 1:
+ return [1, 1, strides[0], strides[1]]
+ return [1, strides[0], strides[1], 1]
+
+ assert len(x.shape) == 4 and (axis == 1 or axis == 3)
+
+ data_format = "NCHW" if axis == 1 else "NHWC"
+ strides_ = pad_strides(strides, axis=axis)
+
+ if strides[0] > 1:
+ x = tf.nn.avg_pool(
+ x, strides_, strides_, padding="VALID", data_format=data_format)
+
+ in_filter = x.shape[axis]
+ out_filter = filters
+
+ if in_filter < out_filter:
+ pad_size = [(out_filter - in_filter) // 2, (out_filter - in_filter) // 2]
+ if axis == 1:
+ x = tf.pad(x, [[0, 0], pad_size, [0, 0], [0, 0]])
+ else:
+ x = tf.pad(x, [[0, 0], [0, 0], [0, 0], pad_size])
+ # In case `tape.gradient(x, [x])` produces a list of `None`
+ return x + 0.
diff --git a/tensorflow/contrib/eager/python/examples/revnet/ops_test.py b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
new file mode 100644
index 0000000000..5bc2641faf
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/ops_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic ops used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import ops
+tfe = tf.contrib.eager
+
+
+class OpsTest(tf.test.TestCase):
+
+ def test_downsample(self):
+ """Test `possible_down_sample` function with mock object."""
+
+ batch_size = 100
+ # NHWC format
+ x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+ # HW doesn't change but number of features increased
+ y = ops.downsample(x, filters=5, strides=(1, 1), axis=3)
+ self.assertEqual(y.shape, [batch_size, 32, 32, 5])
+ # Feature map doesn't change but HW reduced
+ y = ops.downsample(x, filters=3, strides=(2, 2), axis=3)
+ self.assertEqual(y.shape, [batch_size, 16, 16, 3])
+ # Number of feature increased and HW reduced
+ y = ops.downsample(x, filters=5, strides=(2, 2), axis=3)
+ self.assertEqual(y.shape, [batch_size, 16, 16, 5])
+
+ # Test gradient flow
+ x = tf.random_normal(shape=[batch_size, 32, 32, 3])
+ with tfe.GradientTape() as tape:
+ tape.watch(x)
+ y = ops.downsample(x, filters=3, strides=(1, 1))
+ self.assertEqual(y.shape, x.shape)
+ dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ grad, = tape.gradient(y, [x], output_gradients=[dy])
+ self.assertEqual(grad.shape, x.shape)
+
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ # HW doesn't change but feature map reduced
+ y = ops.downsample(x, filters=5, strides=(1, 1))
+ self.assertEqual(y.shape, [batch_size, 5, 32, 32])
+ # Feature map doesn't change but HW reduced
+ y = ops.downsample(x, filters=3, strides=(2, 2))
+ self.assertEqual(y.shape, [batch_size, 3, 16, 16])
+ # Both feature map and HW reduced
+ y = ops.downsample(x, filters=5, strides=(2, 2))
+ self.assertEqual(y.shape, [batch_size, 5, 16, 16])
+
+ # Test gradient flow
+ x = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ with tfe.GradientTape() as tape:
+ tape.watch(x)
+ y = ops.downsample(x, filters=3, strides=(1, 1))
+ self.assertEqual(y.shape, x.shape)
+ dy = tf.random_normal(shape=[batch_size, 3, 32, 32])
+ grad, = tape.gradient(y, [x], output_gradients=[dy])
+ self.assertEqual(grad.shape, x.shape)
+
+
+if __name__ == '__main__':
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
new file mode 100644
index 0000000000..af0d20fa72
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reversible residual network compatible with eager execution.
+
+Code for main model.
+
+Reference [The Reversible Residual Network: Backpropagation
+Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import operator
+
+import six
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks
+
+
+class RevNet(tf.keras.Model):
+ """RevNet that depends on all the blocks."""
+
+ def __init__(self, config):
+ """Initialize RevNet with building blocks.
+
+ Args:
+ config: tf.contrib.training.HParams object; specifies hyperparameters
+ """
+ super(RevNet, self).__init__()
+ self.axis = 1 if config.data_format == "channels_first" else 3
+ self.config = config
+
+ self._init_block = self._construct_init_block()
+ self._block_list = self._construct_intermediate_blocks()
+ self._final_block = self._construct_final_block()
+
+ def _construct_init_block(self):
+ init_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.Conv2D(
+ filters=self.config.init_filters,
+ kernel_size=self.config.init_kernel,
+ strides=(self.config.init_stride, self.config.init_stride),
+ data_format=self.config.data_format,
+ use_bias=False,
+ padding="SAME",
+ input_shape=self.config.input_shape,
+ dtype=self.config.dtype),
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis,
+ fused=self.config.fused,
+ dtype=self.config.dtype),
+ tf.keras.layers.Activation("relu"),
+ ],
+ name="init")
+ if self.config.init_max_pool:
+ init_block.add(
+ tf.keras.layers.MaxPooling2D(
+ pool_size=(3, 3),
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format,
+ dtype=self.config.dtype))
+ return init_block
+
+ def _construct_final_block(self):
+ f = self.config.filters[-1] # Number of filters
+ r = functools.reduce(operator.mul, self.config.strides, 1) # Reduce ratio
+ r *= self.config.init_stride
+ if self.config.init_max_pool:
+ r *= 2
+
+ if self.config.data_format == "channels_first":
+ w, h = self.config.input_shape[1], self.config.input_shape[2]
+ input_shape = (f, w // r, h // r)
+ elif self.config.data_format == "channels_last":
+ w, h = self.config.input_shape[0], self.config.input_shape[1]
+ input_shape = (w // r, h // r, f)
+ else:
+ raise ValueError("Data format should be either `channels_first`"
+ " or `channels_last`")
+
+ final_block = tf.keras.Sequential(
+ [
+ tf.keras.layers.BatchNormalization(
+ axis=self.axis,
+ input_shape=input_shape,
+ fused=self.config.fused,
+ dtype=self.config.dtype),
+ tf.keras.layers.Activation("relu"),
+ tf.keras.layers.GlobalAveragePooling2D(
+ data_format=self.config.data_format, dtype=self.config.dtype),
+ tf.keras.layers.Dense(
+ self.config.n_classes, dtype=self.config.dtype)
+ ],
+ name="final")
+ return final_block
+
+ def _construct_intermediate_blocks(self):
+ # Precompute input shape after initial block
+ stride = self.config.init_stride
+ if self.config.init_max_pool:
+ stride *= 2
+ if self.config.data_format == "channels_first":
+ w, h = self.config.input_shape[1], self.config.input_shape[2]
+ input_shape = (self.config.init_filters, w // stride, h // stride)
+ else:
+ w, h = self.config.input_shape[0], self.config.input_shape[1]
+ input_shape = (w // stride, h // stride, self.config.init_filters)
+
+ # Aggregate intermediate blocks
+ block_list = tf.contrib.checkpoint.List()
+ for i in range(self.config.n_rev_blocks):
+ # RevBlock configurations
+ n_res = self.config.n_res[i]
+ filters = self.config.filters[i]
+ if filters % 2 != 0:
+ raise ValueError("Number of output filters must be even to ensure"
+ "correct partitioning of channels")
+ stride = self.config.strides[i]
+ strides = (self.config.strides[i], self.config.strides[i])
+
+ # Add block
+ rev_block = blocks.RevBlock(
+ n_res,
+ filters,
+ strides,
+ input_shape,
+ batch_norm_first=(i != 0), # Only skip on first block
+ data_format=self.config.data_format,
+ bottleneck=self.config.bottleneck,
+ fused=self.config.fused,
+ dtype=self.config.dtype)
+ block_list.append(rev_block)
+
+ # Precompute input shape for the next block
+ if self.config.data_format == "channels_first":
+ w, h = input_shape[1], input_shape[2]
+ input_shape = (filters, w // stride, h // stride)
+ else:
+ w, h = input_shape[0], input_shape[1]
+ input_shape = (w // stride, h // stride, filters)
+
+ return block_list
+
+ def call(self, inputs, training=True):
+ """Forward pass."""
+
+ if training:
+ saved_hidden = [inputs]
+
+ h = self._init_block(inputs, training=training)
+ if training:
+ saved_hidden.append(h)
+
+ for block in self._block_list:
+ h = block(h, training=training)
+ if training:
+ saved_hidden.append(h)
+
+ logits = self._final_block(h, training=training)
+
+ return (logits, saved_hidden) if training else (logits, None)
+
+ def compute_loss(self, logits, labels):
+ """Compute cross entropy loss."""
+
+ if self.config.dtype == tf.float32 or self.config.dtype == tf.float16:
+ cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
+ else:
+ # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel
+ # for float64, int32 pairs
+ labels = tf.one_hot(
+ labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype)
+ cross_ent = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=labels)
+
+ return tf.reduce_mean(cross_ent)
+
+ def compute_gradients(self, inputs, labels, training=True, l2_reg=True):
+ """Manually computes gradients.
+
+ When eager execution is enabled, this method also SILENTLY updates the
+ running averages of batch normalization when `training` is set to True.
+
+ Args:
+ inputs: Image tensor, either NHWC or NCHW, conforming to `data_format`
+ labels: One-hot labels for classification
+ training: Use the mini-batch stats in batch norm if set to True
+ l2_reg: Apply l2 regularization
+
+ Returns:
+ list of tuples each being (grad, var) for optimizer to use
+ """
+
+ # Run forward pass to record hidden states; avoid updating running averages
+ vars_and_vals = self.get_moving_stats()
+ _, saved_hidden = self.call(inputs, training=training)
+ self.restore_moving_stats(vars_and_vals)
+
+ grads_all = []
+ vars_all = []
+
+ # Manually backprop through last block
+ x = saved_hidden[-1]
+ with tf.GradientTape() as tape:
+ x = tf.identity(x)
+ tape.watch(x)
+ # Running stats updated below
+ logits = self._final_block(x, training=training)
+ loss = self.compute_loss(logits, labels)
+
+ grads_combined = tape.gradient(loss,
+ [x] + self._final_block.trainable_variables)
+ dy, grads_ = grads_combined[0], grads_combined[1:]
+ grads_all += grads_
+ vars_all += self._final_block.trainable_variables
+
+ # Manually backprop through intermediate blocks
+ for block in reversed(self._block_list):
+ y = saved_hidden.pop()
+ x = saved_hidden[-1]
+ dy, grads, vars_ = block.backward_grads_and_vars(
+ x, y, dy, training=training)
+ grads_all += grads
+ vars_all += vars_
+
+ # Manually backprop through first block
+ saved_hidden.pop()
+ x = saved_hidden.pop()
+ assert not saved_hidden # Cleared after backprop
+
+ with tf.GradientTape() as tape:
+ x = tf.identity(x)
+ # Running stats updated below
+ y = self._init_block(x, training=training)
+
+ grads_all += tape.gradient(
+ y, self._init_block.trainable_variables, output_gradients=dy)
+ vars_all += self._init_block.trainable_variables
+
+ # Apply weight decay
+ if l2_reg:
+ grads_all = self._apply_weight_decay(grads_all, vars_all)
+
+ return grads_all, vars_all, loss
+
+ def _apply_weight_decay(self, grads, vars_):
+ """Update gradients to reflect weight decay."""
+ # Don't decay bias
+ return [
+ g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
+ for g, v in zip(grads, vars_)
+ ]
+
+ def get_moving_stats(self):
+ """Get moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Returns:
+ A dictionary mapping variables for batch normalization moving averages
+ to their current values.
+ """
+ vars_and_vals = {}
+
+ def _is_moving_var(v):
+ n = v.name
+ return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")
+
+ for v in filter(_is_moving_var, self.variables):
+ vars_and_vals[v] = v.read_value()
+
+ return vars_and_vals
+
+ def restore_moving_stats(self, vars_and_vals):
+ """Restore moving averages of batch normalization.
+
+ This is needed to avoid updating the running average twice in one iteration.
+
+ Args:
+ vars_and_vals: The dictionary mapping variables to their previous values.
+ """
+ for var_, val in six.iteritems(vars_and_vals):
+ var_.assign(val)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
new file mode 100644
index 0000000000..b2ac4b67c9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -0,0 +1,332 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for basic building blocks used in eager mode RevNet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import time
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.revnet import blocks_test
+from tensorflow.contrib.eager.python.examples.revnet import config as config_
+from tensorflow.contrib.eager.python.examples.revnet import revnet
+from tensorflow.python.client import device_lib
+tfe = tf.contrib.eager
+
+
+def train_one_iter(model, inputs, labels, optimizer, global_step=None):
+ """Train for one iteration."""
+ grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
+ optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
+
+ return loss
+
+
+class RevNetTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(RevNetTest, self).setUp()
+ config = config_.get_hparams_cifar_38()
+ # Reconstruction could cause numerical error, use double precision for tests
+ config.dtype = tf.float64
+ config.fused = False # Fused batch norm does not support tf.float64
+ shape = (config.batch_size,) + config.input_shape
+ self.model = revnet.RevNet(config=config)
+ self.x = tf.random_normal(shape=shape, dtype=tf.float64)
+ self.t = tf.random_uniform(
+ shape=[config.batch_size],
+ minval=0,
+ maxval=config.n_classes,
+ dtype=tf.int64)
+ self.config = config
+
+ def tearDown(self):
+ del self.model
+ del self.x
+ del self.t
+ del self.config
+ super(RevNetTest, self).tearDown()
+
+ def test_call(self):
+ """Test `call` function."""
+
+ y, _ = self.model(self.x, training=False)
+ self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+ def _check_grad_angle_combined(self, grads, grads_true):
+ """Verify that the reconstructed gradients has correct direction.
+
+ Due to numerical imprecision, the magnitude may be slightly different.
+ Yet according to the paper, the angle should be roughly the same.
+
+ Args:
+ grads: list of gradients from reconstruction
+ grads_true: list of true gradients
+ """
+
+ def _combine(gs):
+ return [tf.reshape(g, [-1]) for g in gs]
+
+ g1_all = tf.concat(_combine(grads), axis=0)
+ g2_all = tf.concat(_combine(grads_true), axis=0)
+
+ self.assertEqual(len(g1_all.shape), 1)
+ self.assertEqual(len(g2_all.shape), 1)
+
+ degree = blocks_test.compute_degree(g1_all, g2_all)
+ self.assertLessEqual(degree, 1e0)
+
+ def test_compute_gradients(self):
+ """Test `compute_gradients` function."""
+ self.model(self.x, training=False) # Initialize model
+ grads, vars_, loss = self.model.compute_gradients(
+ inputs=self.x, labels=self.t, training=True, l2_reg=True)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+ self.assertEqual(len(grads), len(vars_))
+ for grad, var in zip(grads, vars_):
+ self.assertEqual(grad.shape, var.shape)
+
+ # Compare against the true gradient computed by the tape
+ with tf.GradientTape() as tape:
+ logits, _ = self.model(self.x, training=True)
+ loss_true = self.model.compute_loss(logits=logits, labels=self.t)
+ grads_true = tape.gradient(loss_true, vars_)
+ self.assertAllClose(loss, loss_true)
+ self.assertAllClose(grads, grads_true, rtol=1e-4, atol=1e-4)
+ self._check_grad_angle_combined(grads, grads_true)
+
+ def test_call_defun(self):
+ """Test `call` function with defun."""
+ y, _ = tfe.defun(self.model.call)(self.x, training=False)
+ self.assertEqual(y.shape, [self.config.batch_size, self.config.n_classes])
+
+ def test_compute_gradients_defun(self):
+ """Test `compute_gradients` function with defun."""
+ compute_gradients = tfe.defun(self.model.compute_gradients)
+ grads, vars_, _ = compute_gradients(self.x, self.t, training=True)
+ self.assertTrue(isinstance(grads, list))
+ self.assertTrue(isinstance(vars_, list))
+ self.assertEqual(len(grads), len(vars_))
+ for grad, var in zip(grads, vars_):
+ if grad is not None:
+ self.assertEqual(grad.shape, var.shape)
+
+ def test_training_graph(self):
+ """Test model training in graph mode."""
+ with tf.Graph().as_default():
+ config = config_.get_hparams_cifar_38()
+ x = tf.random_normal(
+ shape=(self.config.batch_size,) + self.config.input_shape)
+ t = tf.random_uniform(
+ shape=(self.config.batch_size,),
+ minval=0,
+ maxval=self.config.n_classes,
+ dtype=tf.int32)
+ global_step = tfe.Variable(0., trainable=False)
+ model = revnet.RevNet(config=config)
+ model(x)
+ updates = model.get_updates_for(x)
+
+ x_ = tf.identity(x)
+ grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True)
+ optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
+ with tf.control_dependencies(updates):
+ train_op = optimizer.apply_gradients(
+ zip(grads_all, vars_all), global_step=global_step)
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ for _ in range(1):
+ sess.run(train_op)
+
+
+# Benchmark related
+def device_and_data_format():
+ return ("/gpu:0",
+ "channels_first") if tf.test.is_gpu_available() else ("/cpu:0",
+ "channels_last")
+
+
+def random_batch(batch_size, config):
+ shape = (batch_size,) + config.input_shape
+ images = tf.random_uniform(shape)
+ labels = tf.random_uniform(
+ [batch_size], minval=0, maxval=config.n_classes, dtype=tf.int32)
+
+ return images, labels
+
+
+class MockIterator(object):
+
+ def __init__(self, tensors):
+ self._tensors = [tf.identity(x) for x in tensors]
+
+ def next(self):
+ return self._tensors
+
+
+class RevNetBenchmark(tf.test.Benchmark):
+ """Eager and graph benchmarks for RevNet."""
+
+ def _train_batch_sizes(self):
+ """Shamelessly copied from `resnet50_test.py`.
+
+ Note: This is targeted towards ImageNet. CIFAR-10 should allow more
+ aggressive batch sizes.
+
+ Returns:
+ A tuple of possible batch sizes
+ """
+ for device in device_lib.list_local_devices():
+ if tf.DeviceSpec.from_string(device.name).device_type == "GPU":
+ if "K20" in device.physical_device_desc:
+ return (16,)
+ if "P100" in device.physical_device_desc:
+ return (16, 32, 64)
+ if tf.DeviceSpec.from_string(device.name).device_type == "TPU":
+ return (32,)
+ return (16, 32)
+
+ def _force_device_sync(self):
+ """Shamelessly copied from `resnet50_test.py`."""
+ tf.constant(1.).cpu()
+
+ def _report(self, label, start, num_iters, device, batch_size, data_format):
+ avg_time = (time.time() - start) / num_iters
+ dev = tf.DeviceSpec.from_string(device).device_type.lower()
+ name = "%s_%s_batch_%d_%s" % (label, dev, batch_size, data_format)
+ extras = {"examples_per_sec": batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def _benchmark_eager_apply(self,
+ label,
+ device_and_format,
+ defun=False,
+ execution_mode=None,
+ compiled=False):
+ config = config_.get_hparams_imagenet_56()
+ with tfe.execution_mode(execution_mode):
+ device, data_format = device_and_format
+ model = revnet.RevNet(config=config)
+ if defun:
+ model.call = tfe.defun(model.call, compiled=compiled)
+ batch_size = 64
+ num_burn = 5
+ num_iters = 10
+ with tf.device(device):
+ images, _ = random_batch(batch_size, config)
+ for _ in range(num_burn):
+ model(images, training=False)
+ if execution_mode:
+ tfe.async_wait()
+ gc.collect()
+ start = time.time()
+ for _ in range(num_iters):
+ model(images, training=False)
+ if execution_mode:
+ tfe.async_wait()
+ self._report(label, start, num_iters, device, batch_size, data_format)
+
+ def benchmark_eager_apply_sync(self):
+ self._benchmark_eager_apply(
+ "eager_apply_sync", device_and_data_format(), defun=False)
+
+ def benchmark_eager_apply_async(self):
+ self._benchmark_eager_apply(
+ "eager_apply_async",
+ device_and_data_format(),
+ defun=False,
+ execution_mode=tfe.ASYNC)
+
+ def benchmark_eager_call_defun(self):
+ self._benchmark_eager_apply(
+ "eager_apply_with_defun", device_and_data_format(), defun=True)
+
+ def _benchmark_eager_train(self,
+ label,
+ make_iterator,
+ device_and_format,
+ defun=False,
+ execution_mode=None,
+ compiled=False):
+ config = config_.get_hparams_imagenet_56()
+ with tfe.execution_mode(execution_mode):
+ device, data_format = device_and_format
+ for batch_size in self._train_batch_sizes():
+ (images, labels) = random_batch(batch_size, config)
+ model = revnet.RevNet(config=config)
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ if defun:
+ model.call = tfe.defun(model.call)
+
+ num_burn = 3
+ num_iters = 10
+ with tf.device(device):
+ iterator = make_iterator((images, labels))
+ for _ in range(num_burn):
+ (images, labels) = iterator.next()
+ train_one_iter(model, images, labels, optimizer)
+ if execution_mode:
+ tfe.async_wait()
+ self._force_device_sync()
+ gc.collect()
+
+ start = time.time()
+ for _ in range(num_iters):
+ (images, labels) = iterator.next()
+ train_one_iter(model, images, labels, optimizer)
+ if execution_mode:
+ tfe.async_wait()
+ self._force_device_sync()
+ self._report(label, start, num_iters, device, batch_size, data_format)
+
+ def benchmark_eager_train_sync(self):
+ self._benchmark_eager_train(
+ "eager_train_sync", MockIterator, device_and_data_format(), defun=False)
+
+ def benchmark_eager_train_async(self):
+ self._benchmark_eager_train(
+ "eager_train_async",
+ MockIterator,
+ device_and_data_format(),
+ defun=False,
+ execution_mode=tfe.ASYNC)
+
+ def benchmark_eager_train_defun(self):
+ self._benchmark_eager_train(
+ "eager_train", MockIterator, device_and_data_format(), defun=False)
+
+ def benchmark_eager_train_datasets_with_defun(self):
+
+ def make_iterator(tensors):
+ with tf.device("/device:CPU:0"):
+ ds = tf.data.Dataset.from_tensors(tensors).repeat()
+ return tfe.Iterator(ds)
+
+ self._benchmark_eager_train(
+ "eager_train_dataset_with_defun",
+ make_iterator,
+ device_and_data_format(),
+ defun=True)
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/sagan/BUILD b/tensorflow/contrib/eager/python/examples/sagan/BUILD
new file mode 100644
index 0000000000..b470a41d81
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/BUILD
@@ -0,0 +1,59 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+# Model
+py_library(
+ name = "config",
+ srcs = ["config.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "ops",
+ srcs = ["ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "sagan",
+ srcs = ["sagan.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+# Tests
+cuda_py_test(
+ name = "ops_test",
+ size = "small",
+ srcs = ["ops_test.py"],
+ additional_deps = [
+ ":ops",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "sagan_test",
+ size = "large",
+ srcs = ["sagan_test.py"],
+ additional_deps = [
+ ":config",
+ ":sagan",
+ "//tensorflow:tensorflow_py",
+ ],
+ tags = [
+ "optonly",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/config.py b/tensorflow/contrib/eager/python/examples/sagan/config.py
new file mode 100644
index 0000000000..1967bbd867
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/config.py
@@ -0,0 +1,72 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial with eager execution.
+
+Configuration in format of tf.contrib.training.HParams.
+Supports default 128x128 ImageNet.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+tfe = tf.contrib.eager
+
+
+def get_hparams_imagenet():
+ """Configurations to train SAGAN on 128x128 ImageNet dataset."""
+ config = tf.contrib.training.HParams()
+ if tf.test.is_gpu_available():
+ config.add_hparam("image_shape", (3, 128, 128))
+ config.add_hparam("data_format", "channels_first")
+ config.add_hparam("g_init_shape", (512, 4, 4))
+ else:
+ config.add_hparam("image_shape", (128, 128, 3))
+ config.add_hparam("data_format", "channels_first")
+ config.add_hparam("g_init_shape", (4, 4, 512))
+
+ config.add_hparam("latent_dim", 128)
+ config.add_hparam("update_g_once_every", 1)
+ config.add_hparam("batch_size", 64)
+ config.add_hparam("d_init_filters", 32)
+ config.add_hparam("num_upsamples", 5)
+ # (512, 4, 4) -> (3, 128, 128)
+ return config
+
+
+def get_hparams_mock():
+ """Configurations of smaller networks for testing."""
+ config = tf.contrib.training.HParams()
+ if tf.test.is_gpu_available():
+ config.add_hparam("image_shape", (3, 16, 16))
+ config.add_hparam("data_format", "channels_first")
+ config.add_hparam("g_init_shape", (32, 2, 2))
+ else:
+ config.add_hparam("image_shape", (16, 16, 3))
+ config.add_hparam("data_format", "channels_last")
+ config.add_hparam("g_init_shape", (2, 2, 32))
+
+ config.add_hparam("latent_dim", 16)
+ config.add_hparam("update_g_once_every", 1)
+ config.add_hparam("batch_size", 2)
+ config.add_hparam("d_init_filters", 4)
+ config.add_hparam("num_upsamples", 3)
+ # (32, 2, 2) -> (3, 16, 16)
+ return config
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops.py b/tensorflow/contrib/eager/python/examples/sagan/ops.py
new file mode 100644
index 0000000000..9a03cab1d1
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/ops.py
@@ -0,0 +1,71 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial with eager execution.
+
+Auxiliary operations.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def flatten_hw(x, data_format="channels_first"):
+ """Flatten the input tensor across height and width dimensions."""
+ if data_format == "channels_last":
+ x = tf.transpose(x, perm=[0, 3, 1, 2]) # Convert to `channels_first`
+
+ old_shape = tf.shape(x)
+ new_shape = [old_shape[0], old_shape[2] * old_shape[3], old_shape[1]]
+
+ return tf.reshape(x, new_shape)
+
+
+def broaden_hw(x, h, w, c, data_format="channels_first"):
+ """Broaden dimension so that output has height and width."""
+ if data_format == "channels_first":
+ shape = [-1, c, h, w]
+ else:
+ shape = [-1, h, w, c]
+
+ return tf.reshape(x, shape)
+
+
+class BroadenHW(tf.keras.layers.Layer):
+ """Wrapper class so that `broaden_hw` can be used in `tf.keras.Sequential`."""
+
+ def __init__(self, h, w, c, data_format="channels_first"):
+ super(BroadenHW, self).__init__()
+ self.h = h
+ self.w = w
+ self.c = c
+ self.data_format = data_format
+
+ def call(self, x):
+ return broaden_hw(
+ x, h=self.h, w=self.w, c=self.c, data_format=self.data_format)
+
+ def compute_output_shape(self, input_shape):
+ input_shape = tf.TensorShape(input_shape).as_list()
+ if self.data_format == "channels_first":
+ output_shape = (input_shape[0], self.c, self.h, self.w)
+ else:
+ output_shape = (input_shape[0], self.h, self.w, self.c)
+
+ return tf.TensorShape(output_shape)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/ops_test.py b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
new file mode 100644
index 0000000000..3454985904
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/ops_test.py
@@ -0,0 +1,59 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for auxiliary operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import ops
+
+
+class OpsTest(tf.test.TestCase):
+
+ def test_flatten_hw(self):
+ """Test `flatten_hw` function with mock object."""
+
+ batch_size = 1
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ x = tf.random_normal(shape=(batch_size, 3, 4, 4))
+ y = ops.flatten_hw(x, data_format="channels_first")
+ self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
+
+ # NHWC format
+ x = tf.random_normal(shape=(batch_size, 4, 4, 3))
+ y = ops.flatten_hw(x, data_format="channels_last")
+ self.assertEqual(y.shape, (batch_size, 4 * 4, 3))
+
+ def test_broaden_hw(self):
+ """Test `broaden_hw` function with mock object."""
+
+ batch_size = 1
+ # NHWC format
+ x = tf.random_normal(shape=[batch_size, 4 * 4 * 16])
+ y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_last")
+ self.assertEqual(y.shape, (batch_size, 4, 4, 16))
+
+ # Default NCHW format
+ if tf.test.is_gpu_available():
+ y = ops.broaden_hw(x, h=4, w=4, c=16, data_format="channels_first")
+ self.assertEqual(y.shape, (batch_size, 16, 4, 4))
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
new file mode 100644
index 0000000000..561be36c91
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
@@ -0,0 +1,232 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Self-attention generative adversarial with eager execution.
+
+Code for main model.
+
+Reference [Self-Attention Generative Adversarial
+Networks](https://arxiv.org/pdf/1805.08318.pdf)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import ops
+tfe = tf.contrib.eager
+
+
+class SelfAttentionModule(tf.keras.Model):
+ """Self-attention module composed of convolutional layers."""
+
+ def __init__(self,
+ attention_features,
+ original_features,
+ data_format="channels_first"):
+ """Initialize the module.
+
+ Args:
+ attention_features: Number of filters for the attention computation.
+ original_features: Number of filters of the original Tensor.
+ data_format: Either 'channels_first' or 'channels_last'
+ """
+ super(SelfAttentionModule, self).__init__()
+ self.data_format = data_format
+ # Matrix multiplication implemented as 2D Convolution
+ self.f = tf.keras.layers.Conv2D(
+ filters=attention_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.g = tf.keras.layers.Conv2D(
+ filters=attention_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.h = tf.keras.layers.Conv2D(
+ filters=original_features,
+ kernel_size=1,
+ strides=(1, 1),
+ data_format=data_format)
+ self.scale = tfe.Variable(0., trainable=True)
+
+ def call(self, x):
+ f = self.f(x)
+ g = self.g(x)
+ h = self.h(x)
+
+ f_flatten = ops.flatten_hw(f, data_format=self.data_format)
+ g_flatten = ops.flatten_hw(g, data_format=self.data_format)
+ h_flatten = ops.flatten_hw(h, data_format=self.data_format)
+
+ s = tf.matmul(g_flatten, f_flatten, transpose_b=True)
+ b = tf.nn.softmax(s, axis=-1)
+ o = tf.matmul(b, h_flatten)
+ y = self.scale * tf.reshape(o, tf.shape(x)) + x
+
+ return y
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+class SAGAN(tf.contrib.checkpoint.Checkpointable):
+ """Self-attention generative adversarial network."""
+
+ def __init__(self, config):
+ """Initialize the model.
+
+ Args:
+ config: tf.contrib.training.HParams object; specifies hyperparameters
+ """
+ super(SAGAN, self).__init__()
+ self.config = config
+ self.generator = self._construct_generator()
+ self.discriminator = self._construct_discriminator()
+
+ def _construct_generator(self):
+ """Construct generator."""
+ # TODO(lxuechen): Add spectral normalization for WGAN
+ axis = 1 if self.config.data_format == "channels_first" else 3
+
+ generator = tf.keras.Sequential()
+ generator.add(
+ tf.keras.layers.InputLayer(input_shape=(self.config.latent_dim,)))
+ generator.add(
+ tf.keras.layers.Dense(
+ units=np.prod(self.config.g_init_shape), activation=tf.nn.relu))
+
+ if self.config.data_format == "channels_first":
+ c, h, w = self.config.g_init_shape
+ else:
+ h, w, c = self.config.g_init_shape
+
+ # Reshape to NHWC/NCHW
+ generator.add(
+ ops.BroadenHW(h=h, w=w, c=c, data_format=self.config.data_format))
+
+ filters_list = [c // 2**p for p in range(1, self.config.num_upsamples + 1)]
+ filters_list[-1] = 3 # Standard RGB images
+
+ for filters in filters_list[:len(filters_list) // 2]:
+ generator.add(
+ tf.keras.layers.Conv2DTranspose(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ use_bias=False,
+ padding="SAME",
+ data_format=self.config.data_format))
+ generator.add(tf.keras.layers.BatchNormalization(axis=axis))
+ generator.add(tf.keras.layers.Activation("relu"))
+
+ # pylint: disable=undefined-loop-variable
+ generator.add(
+ SelfAttentionModule(
+ original_features=filters,
+ attention_features=filters // 8,
+ data_format=self.config.data_format))
+ # pylint: enable=undefined-loop-variable
+
+ for filters in filters_list[len(filters_list) // 2:]:
+ generator.add(
+ tf.keras.layers.Conv2DTranspose(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ use_bias=False,
+ padding="SAME",
+ data_format=self.config.data_format))
+ if filters == 3:
+ # Assume Image rescaled to [-1, 1]
+ generator.add(tf.keras.layers.Activation("tanh"))
+ else:
+ generator.add(tf.keras.layers.BatchNormalization(axis=axis))
+ generator.add(tf.keras.layers.Activation("relu"))
+
+ return generator
+
+ def _construct_discriminator(self):
+ """Construct discriminator."""
+ # TODO(lxuechen): Add spectral normalization for WGAN
+ discriminator = tf.keras.Sequential()
+ discriminator.add(
+ tf.keras.layers.InputLayer(input_shape=self.config.image_shape))
+
+ filters_list = [
+ self.config.d_init_filters * 2**p
+ for p in range(self.config.num_upsamples)
+ ]
+
+ for filters in filters_list[:(len(filters_list) + 1) // 2]:
+ discriminator.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format))
+ discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
+
+ # pylint: disable=undefined-loop-variable
+ discriminator.add(
+ SelfAttentionModule(
+ original_features=filters,
+ attention_features=filters // 8,
+ data_format=self.config.data_format))
+ # pylint: enable=undefined-loop-variable
+
+ for filters in filters_list[(len(filters_list) + 1) // 2:]:
+ discriminator.add(
+ tf.keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=4,
+ strides=(2, 2),
+ padding="SAME",
+ data_format=self.config.data_format))
+ discriminator.add(tf.keras.layers.LeakyReLU(alpha=.1))
+
+ discriminator.add(tf.keras.layers.Flatten())
+ discriminator.add(tf.keras.layers.Dense(units=1))
+
+ return discriminator
+
+ def compute_loss_and_grads(self, real_images, noise, training=True):
+ """Compute loss and gradients for both generator and discriminator."""
+ # TODO(lxuechen): Add gradient penalty for discriminator
+ with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
+ real_logits = self.discriminator(real_images, training=training)
+
+ fake_images = self.generator.call(noise, training=training)
+ fake_logits = self.discriminator.call(fake_images)
+
+ g_loss = self.compute_g_loss(fake_logits)
+ d_loss = self.compute_d_loss(fake_logits, real_logits)
+
+ g_grads = g_tape.gradient(g_loss, self.generator.trainable_variables)
+ d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_variables)
+
+ return g_loss, d_loss, g_grads, d_grads
+
+ def compute_g_loss(self, fake_logits):
+ return -tf.reduce_mean(fake_logits) # Hinge loss
+
+ def compute_d_loss(self, fake_logits, real_logits):
+ # Hinge loss
+ real_loss = tf.reduce_mean(tf.nn.relu(1. - real_logits))
+ fake_loss = tf.reduce_mean(tf.nn.relu(1. + fake_logits))
+ return real_loss + fake_loss
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
new file mode 100644
index 0000000000..1834594510
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for self-attention generative adversarial network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.eager.python.examples.sagan import config as config_
+from tensorflow.contrib.eager.python.examples.sagan import sagan
+tfe = tf.contrib.eager
+
+
+class SAGANTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(SAGANTest, self).setUp()
+ config = config_.get_hparams_mock()
+ self.noise_shape = (config.batch_size, config.latent_dim)
+ self.logits_shape = (config.batch_size, 1)
+ self.images_shape = (config.batch_size,) + config.image_shape
+
+ self.model = sagan.SAGAN(config=config)
+ self.noise = tf.random_normal(shape=self.noise_shape)
+ self.real_images = tf.random_normal(shape=self.images_shape)
+ self.config = config
+
+ def tearDown(self):
+ del self.model
+ del self.noise
+ del self.real_images
+ super(SAGANTest, self).tearDown()
+
+ def test_generator_call(self):
+ """Test `generator.__call__` function."""
+ fake_images = self.model.generator(self.noise, training=False)
+ self.assertEqual(fake_images.shape, self.images_shape)
+
+ def test_generator_call_defun(self):
+ """Test `generator.__call__` function with defun."""
+ call_ = tfe.defun(self.model.generator.__call__)
+ fake_images = call_(self.noise, training=False)
+ self.assertEqual(fake_images.shape, self.images_shape)
+
+ def test_discriminator_call(self):
+ """Test `discriminator.__call__` function."""
+ real_logits = self.model.discriminator(self.real_images)
+ self.assertEqual(real_logits.shape, self.logits_shape)
+
+ def test_discriminator_call_defun(self):
+ """Test `discriminator.__call__` function with defun."""
+ call_ = tfe.defun(self.model.discriminator.__call__)
+ real_logits = call_(self.real_images)
+ self.assertEqual(real_logits.shape, self.logits_shape)
+
+ def test_compute_loss_and_grads(self):
+ """Test `compute_loss_and_grads` function."""
+ g_loss, d_loss, g_grads, d_grads = self.model.compute_loss_and_grads(
+ self.real_images, self.noise, training=False)
+ self.assertEqual(g_loss.shape, ())
+ self.assertEqual(d_loss.shape, ())
+ self.assertTrue(isinstance(g_grads, list))
+ self.assertTrue(isinstance(d_grads, list))
+ g_vars = self.model.generator.trainable_variables
+ d_vars = self.model.discriminator.trainable_variables
+
+ self.assertEqual(len(g_grads), len(g_vars))
+ self.assertEqual(len(d_grads), len(d_vars))
+
+ def test_compute_loss_and_grads_defun(self):
+ """Test `compute_loss_and_grads` function with defun."""
+ compute_loss_and_grads = tfe.defun(self.model.compute_loss_and_grads)
+ g_loss, d_loss, g_grads, d_grads = compute_loss_and_grads(
+ self.real_images, self.noise, training=False)
+ self.assertEqual(g_loss.shape, ())
+ self.assertEqual(d_loss.shape, ())
+ self.assertTrue(isinstance(g_grads, list))
+ self.assertTrue(isinstance(d_grads, list))
+ g_vars = self.model.generator.trainable_variables
+ d_vars = self.model.discriminator.trainable_variables
+
+ self.assertEqual(len(g_grads), len(g_vars))
+ self.assertEqual(len(d_grads), len(d_vars))
+
+
+if __name__ == "__main__":
+ tf.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
new file mode 100644
index 0000000000..75cb3f8227
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/1_basic.ipynb
@@ -0,0 +1,282 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "TFE Workshop: control flow",
+ "version": "0.3.2",
+ "provenance": [],
+ "include_colab_link": true
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/664b2f8700485ff6801f4d26293bd567/tfe-workshop-control-flow.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "9BpQzh9BvJlj",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "0b336886-8204-4815-89fa-5291a49d5784"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "import numpy as np\n",
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "0roIB19GvOjI",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Eager execution basics\n",
+ "\n",
+ "When eager execution is enabled TensorFlow immediately executes operations, and Tensors are always available. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "jeO8F-V-vN24",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "aeb3bdec-50b7-440d-93d8-5a171f091081"
+ },
+ "cell_type": "code",
+ "source": [
+ "t = tf.constant([[1, 2], [3, 4]])\n",
+ "t"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=0, shape=(2, 2), dtype=int32, numpy=\n",
+ "array([[1, 2],\n",
+ " [3, 4]], dtype=int32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 2
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Y17RwSFxvlDL",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "cfcc10c7-707b-4997-99b3-a5f382c5166b"
+ },
+ "cell_type": "code",
+ "source": [
+ "tf.matmul(t, t)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=2, shape=(2, 2), dtype=int32, numpy=\n",
+ "array([[ 7, 10],\n",
+ " [15, 22]], dtype=int32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Dab1bS3TvmRE",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "8a624f3d-a658-4359-c586-1c5f6bf4c8b7"
+ },
+ "cell_type": "code",
+ "source": [
+ "# It's also possible to have Python control flow which depends on the value of tensors.\n",
+ "if t[0, 0] > 0.5:\n",
+ " print(\"T is bigger\")\n",
+ "else:\n",
+ " print(\"T is smaller\")"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "T is bigger\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dPgptJcGwIon",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "c4f27f2b-0848-4475-dde5-2534dac65a5c"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Tensors are also usable as numpy arrays\n",
+ "np.prod(t)"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "24"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "p3DTfQXnwXzj",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Exercise\n",
+ "\n",
+ "The algorithm for bisecting line search is a pretty simple way to find a zero of a continuous scalar function in an interval [a,b] where f(a) and f(b) have different signs. Simply evaluate f((a+b)/2), and narrow the interval by replacing either a or b with (a+b)/2 such that the function when applied on the boundary of the interval still has different signs.\n",
+ "\n",
+ "Implement a python function `bisecting_line_search(f, a, b, epsilon)` which returns a value such that `tf.abs(f(value)) < epsilon`.\n",
+ "\n",
+ "One thing to keep in mind: python's `==` opertor is not overloaded on Tensors, so you need to use `tf.equal` to compare for equality."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "6eq0YuI6ykm5",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "# Example test harness to get you going\n",
+ "\n",
+ "def test_f(x):\n",
+ " return x - 0.1234\n",
+ "def bisecting_line_search(f, a, b, epsilon):\n",
+ " # Return x such that f(x) <= epsilon.\n",
+ " pass\n",
+ "a = tf.constant(0.0)\n",
+ "b = tf.constant(1.0)\n",
+ "epsilon = tf.constant(0.001)\n",
+ "x = bisecting_line_search(test_f, a, b, epsilon)\n"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "LcMmEfd_xvej",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 170
+ },
+ "outputId": "f402aa50-8ce3-4416-f755-8bbcd1af7809"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Double-click to see the solution\n",
+ "\n",
+ "def bisecting_line_search(f, a, b, epsilon):\n",
+ " f_a = f(a)\n",
+ " f_b = f(b)\n",
+ " probe = (a + b) / 2\n",
+ " f_probe = f(probe)\n",
+ " while tf.abs(f_probe) > epsilon:\n",
+ " if tf.equal(tf.sign(f_probe), tf.sign(f_a)):\n",
+ " a = probe\n",
+ " f_a = f_probe\n",
+ " else:\n",
+ " b = probe\n",
+ " f_b = f_probe\n",
+ " probe = (a + b) / 2\n",
+ " f_probe = f(probe)\n",
+ " print(\"new probe\", probe)\n",
+ " return probe\n",
+ "\n",
+ "bisecting_line_search(test_f, 0., 1., 0.001)"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "('new probe', 0.25)\n",
+ "('new probe', 0.125)\n",
+ "('new probe', 0.0625)\n",
+ "('new probe', 0.09375)\n",
+ "('new probe', 0.109375)\n",
+ "('new probe', 0.1171875)\n",
+ "('new probe', 0.12109375)\n",
+ "('new probe', 0.123046875)\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "0.123046875"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ }
+ ]
+}
diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
new file mode 100644
index 0000000000..4f1410e00b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
@@ -0,0 +1,1018 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "TFE Workshop: Models.ipynb",
+ "version": "0.3.2",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/5cfcffd408bd5103f5ae747bc97ab0b5/tfe-workshop-models.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BMxv1O6Q0SJL",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "8be9c556-ac7f-4142-e35e-19dc2b097121"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "tfe = tf.contrib.eager"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "lE1vJhxp0WR9",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Variables\n",
+ "\n",
+ "TensorFlow variables are useful to store the state in your program. They are integrated with other parts of the API (taking gradients, checkpointing, graph functions)."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "C4ztQNgc0VpW",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "8b63ae1f-2670-49c0-a31b-8cf7fc4194a1"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Creating variables\n",
+ "v = tfe.Variable(1.0)\n",
+ "v"
+ ],
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 2
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "H0daItGg1IAp",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "e47d5aab-16a1-4e29-c27d-7fbc0b94b5d3"
+ },
+ "cell_type": "code",
+ "source": [
+ "v.assign_add(1.0)\n",
+ "v"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 3
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "BJvBzcIG1hyK",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Layers: common sets of useful operations\n",
+ "\n",
+ "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n",
+ "\n",
+ "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n",
+ "\n",
+ "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "iSQTS3QW1YQQ",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "c5d8aa10-dcad-44f7-f0eb-0faf5249fd7e"
+ },
+ "cell_type": "code",
+ "source": [
+ "# In the tf.keras.layers package, layers are objects. To construct a layer,\n",
+ "# simply construct the object. Most layers take as a first argument the number\n",
+ "# of output dimensions / channels.\n",
+ "layer = tf.keras.layers.Dense(100)\n",
+ "\n",
+ "# The number of input dimensions is often unnecessary, as it can be inferred\n",
+ "# the first time the layer is used, but it can be provided if you want to \n",
+ "# specify it manually, which is useful in some complex models.\n",
+ "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))\n"
+ ],
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "nRuUogoS1liV",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 68
+ },
+ "outputId": "c352ce79-d519-45e4-a12e-1eaba76871a2"
+ },
+ "cell_type": "code",
+ "source": [
+ "layer(tf.zeros([2, 2]))"
+ ],
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=43, shape=(2, 10), dtype=float32, numpy=\n",
+ "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 5
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "JH4Kf4ka1mht",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 136
+ },
+ "outputId": "c34e2378-f83d-42c5-d30a-ebe55620368a"
+ },
+ "cell_type": "code",
+ "source": [
+ "layer.variables"
+ ],
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[<tf.Variable 'dense/kernel:0' shape=(2, 10) dtype=float32, numpy=\n",
+ " array([[-0.42494273, -0.2067694 , 0.4519381 , 0.6842533 , 0.04131705,\n",
+ " 0.70547956, 0.4021917 , -0.5939298 , -0.5671462 , 0.5586321 ],\n",
+ " [ 0.3709975 , -0.64126074, -0.5386696 , -0.42212513, 0.6550072 ,\n",
+ " 0.70081085, 0.08859557, -0.30801034, -0.31450653, 0.02522504]],\n",
+ " dtype=float32)>,\n",
+ " <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32, numpy=array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>]"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 6
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "DSI4NF0_1vn-",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n",
+ "Conv2D, LSTM, BatchNormalization, Dropout, and many others."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "hMgDBftJ12Bp",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Models: composing layers\n",
+ "\n",
+ "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n",
+ "\n",
+ "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model.\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "K3gVY6gj1nbe",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 190
+ },
+ "outputId": "6e9be0c4-960e-46c2-cdd9-7e94ad09d46b"
+ },
+ "cell_type": "code",
+ "source": [
+ "class ResnetIdentityBlock(tf.keras.Model):\n",
+ " def __init__(self, kernel_size, filters):\n",
+ " super(ResnetIdentityBlock, self).__init__(name='')\n",
+ " filters1, filters2, filters3 = filters\n",
+ "\n",
+ " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n",
+ " self.bn2a = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n",
+ " self.bn2b = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n",
+ " self.bn2c = tf.keras.layers.BatchNormalization()\n",
+ "\n",
+ " def call(self, input_tensor, training=False):\n",
+ " x = self.conv2a(input_tensor)\n",
+ " x = self.bn2a(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2b(x)\n",
+ " x = self.bn2b(x, training=training)\n",
+ " x = tf.nn.relu(x)\n",
+ "\n",
+ " x = self.conv2c(x)\n",
+ " x = self.bn2c(x, training=training)\n",
+ "\n",
+ " x += input_tensor\n",
+ " return tf.nn.relu(x)\n",
+ " \n",
+ "block = ResnetIdentityBlock(1, [1, 2, 3])\n",
+ "print(block(tf.zeros([1, 2, 3, 3])))\n",
+ "print([x.name for x in block.variables])"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "tf.Tensor(\n",
+ "[[[[0. 0. 0.]\n",
+ " [0. 0. 0.]\n",
+ " [0. 0. 0.]]\n",
+ "\n",
+ " [[0. 0. 0.]\n",
+ " [0. 0. 0.]\n",
+ " [0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n",
+ "['resnet_identity_block/conv2d/kernel:0', 'resnet_identity_block/conv2d/bias:0', 'resnet_identity_block/batch_normalization/gamma:0', 'resnet_identity_block/batch_normalization/beta:0', 'resnet_identity_block/conv2d_1/kernel:0', 'resnet_identity_block/conv2d_1/bias:0', 'resnet_identity_block/batch_normalization_1/gamma:0', 'resnet_identity_block/batch_normalization_1/beta:0', 'resnet_identity_block/conv2d_2/kernel:0', 'resnet_identity_block/conv2d_2/bias:0', 'resnet_identity_block/batch_normalization_2/gamma:0', 'resnet_identity_block/batch_normalization_2/beta:0', 'resnet_identity_block/batch_normalization/moving_mean:0', 'resnet_identity_block/batch_normalization/moving_variance:0', 'resnet_identity_block/batch_normalization_1/moving_mean:0', 'resnet_identity_block/batch_normalization_1/moving_variance:0', 'resnet_identity_block/batch_normalization_2/moving_mean:0', 'resnet_identity_block/batch_normalization_2/moving_variance:0']\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "LPXhHUIc1-sO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "5pXgzNAU17xk",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 173
+ },
+ "outputId": "03b7eaf8-9b35-482b-bcf0-a99af6c2c6a4"
+ },
+ "cell_type": "code",
+ "source": [
+ " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(2, 1, \n",
+ " padding='same'),\n",
+ " tf.keras.layers.BatchNormalization(),\n",
+ " tf.keras.layers.Conv2D(3, (1, 1)),\n",
+ " tf.keras.layers.BatchNormalization()])\n",
+ "my_seq(tf.zeros([1, 2, 3, 3]))\n"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=493, shape=(1, 2, 3, 3), dtype=float32, numpy=\n",
+ "array([[[[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]],\n",
+ "\n",
+ " [[0., 0., 0.],\n",
+ " [0., 0., 0.],\n",
+ " [0., 0., 0.]]]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "MZrns6p22GEQ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Exercise!\n",
+ "\n",
+ "Make a simple convolutional neural network model, useful for things such as MNIST which don't need too many parameters. A sequence of two or three convolutions with small output channels (say, 32 and 64) plus one or two fully connected layers is probably enough.\n",
+ "\n",
+ "The input shape should be [batch_size, 28, 28, 1]."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8CAUa3KNN916",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "97c0ff3c-c962-4c13-eee8-406101465761"
+ },
+ "cell_type": "code",
+ "source": [
+ "# TODO: Implement a convolutional model as described above, and assign it to\n",
+ "# model.\n",
+ "model = tf.keras.Sequential([\n",
+ " \n",
+ "])"
+ ],
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "vLDDduR32E82",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 34
+ },
+ "outputId": "09bb1d43-b4c6-44b5-916e-0d2903d10cf4"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Click to see the answer\n",
+ "\n",
+ "max_pool = tf.keras.layers.MaxPooling2D(\n",
+ " (2, 2), (2, 2), padding='same')\n",
+ " # The model consists of a sequential chain of layers, so tf.keras.Sequential\n",
+ " # (a subclass of tf.keras.Model) makes for a compact description.\n",
+ "model = tf.keras.Sequential(\n",
+ " [\n",
+ " tf.keras.layers.Conv2D(\n",
+ " 32,\n",
+ " 5,\n",
+ " padding='same',\n",
+ " activation=tf.nn.relu),\n",
+ " max_pool,\n",
+ " tf.keras.layers.Conv2D(\n",
+ " 64,\n",
+ " 5,\n",
+ " padding='same',\n",
+ " activation=tf.nn.relu),\n",
+ " max_pool,\n",
+ " tf.keras.layers.Flatten(),\n",
+ " tf.keras.layers.Dense(1024, activation=tf.nn.relu),\n",
+ " tf.keras.layers.Dropout(0.4),\n",
+ " tf.keras.layers.Dense(10)\n",
+ " ])\n",
+ "\n",
+ "model(tf.zeros([1, 28, 28, 1]))"
+ ],
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=625, shape=(1, 10), dtype=float32, numpy=array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "H_CKVBroik4M",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Stop here for now"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "_yRwuE6MMmzC",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Training\n",
+ "\n",
+ "When eager execution is enabled, you can write Pythonic training loops. Simply\n",
+ "\n",
+ "1. load your data into a `tf.data.Dataset`, which lets you construct functional pipelines for processing, shuffling, and batching your data,\n",
+ "2. iterate over the dataset using a Python `for` loop, and\n",
+ "3. perform an optimization step in the body of your `for` loop.\n",
+ "\n",
+ "This workflow is exemplified in the following exercise."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gj0-EkTc_Xt1",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "\n",
+ "## Exercise!\n",
+ "\n",
+ "In this exercise, you'll train the convolutional model you implemented for the previous exericse on the MNIST dataset. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "WOGm9HHn_byR",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17
+ },
+ "outputId": "bbccc7ad-33cd-446e-bcda-f358c7547e1b"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Utilities for downloading MNIST data (double-click to show code)\n",
+ "import gzip\n",
+ "import os\n",
+ "import tempfile\n",
+ "from six.moves import urllib\n",
+ "import shutil\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "def read32(bytestream):\n",
+ " \"\"\"Read 4 bytes from bytestream as an unsigned 32-bit integer.\"\"\"\n",
+ " dt = np.dtype(np.uint32).newbyteorder('>')\n",
+ " return np.frombuffer(bytestream.read(4), dtype=dt)[0]\n",
+ "\n",
+ "\n",
+ "def check_image_file_header(filename):\n",
+ " \"\"\"Validate that filename corresponds to images for the MNIST dataset.\"\"\"\n",
+ " with tf.gfile.Open(filename, 'rb') as f:\n",
+ " magic = read32(f)\n",
+ " read32(f) # num_images, unused\n",
+ " rows = read32(f)\n",
+ " cols = read32(f)\n",
+ " if magic != 2051:\n",
+ " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n",
+ " f.name))\n",
+ " if rows != 28 or cols != 28:\n",
+ " raise ValueError(\n",
+ " 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %\n",
+ " (f.name, rows, cols))\n",
+ "\n",
+ "\n",
+ "def check_labels_file_header(filename):\n",
+ " \"\"\"Validate that filename corresponds to labels for the MNIST dataset.\"\"\"\n",
+ " with tf.gfile.Open(filename, 'rb') as f:\n",
+ " magic = read32(f)\n",
+ " read32(f) # num_items, unused\n",
+ " if magic != 2049:\n",
+ " raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,\n",
+ " f.name))\n",
+ " \n",
+ "def download(directory, filename):\n",
+ " \"\"\"Download (and unzip) a file from the MNIST dataset if not already done.\"\"\"\n",
+ " filepath = os.path.join(directory, filename)\n",
+ " if tf.gfile.Exists(filepath):\n",
+ " return filepath\n",
+ " if not tf.gfile.Exists(directory):\n",
+ " tf.gfile.MakeDirs(directory)\n",
+ " # CVDF mirror of http://yann.lecun.com/exdb/mnist/\n",
+ " url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
+ " _, zipped_filepath = tempfile.mkstemp(suffix='.gz')\n",
+ " print('Downloading %s to %s' % (url, zipped_filepath))\n",
+ " urllib.request.urlretrieve(url, zipped_filepath)\n",
+ " with gzip.open(zipped_filepath, 'rb') as f_in, \\\n",
+ " tf.gfile.Open(filepath, 'wb') as f_out:\n",
+ " shutil.copyfileobj(f_in, f_out)\n",
+ " os.remove(zipped_filepath)\n",
+ " return filepath\n",
+ "\n",
+ "\n",
+ "def dataset(directory, images_file, labels_file):\n",
+ " \"\"\"Download and parse MNIST dataset.\"\"\"\n",
+ "\n",
+ " images_file = download(directory, images_file)\n",
+ " labels_file = download(directory, labels_file)\n",
+ "\n",
+ " check_image_file_header(images_file)\n",
+ " check_labels_file_header(labels_file)\n",
+ "\n",
+ " def decode_image(image):\n",
+ " # Normalize from [0, 255] to [0.0, 1.0]\n",
+ " image = tf.decode_raw(image, tf.uint8)\n",
+ " image = tf.cast(image, tf.float32)\n",
+ " image = tf.reshape(image, [28, 28, 1])\n",
+ " return image / 255.0\n",
+ "\n",
+ " def decode_label(label):\n",
+ " label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]\n",
+ " label = tf.reshape(label, []) # label is a scalar\n",
+ " return tf.to_int32(label)\n",
+ "\n",
+ " images = tf.data.FixedLengthRecordDataset(\n",
+ " images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
+ " labels = tf.data.FixedLengthRecordDataset(\n",
+ " labels_file, 1, header_bytes=8).map(decode_label)\n",
+ " return tf.data.Dataset.zip((images, labels))\n",
+ "\n",
+ "\n",
+ "def get_training_data(directory):\n",
+ " \"\"\"tf.data.Dataset object for MNIST training data.\"\"\"\n",
+ " return dataset(directory, 'train-images-idx3-ubyte',\n",
+ " 'train-labels-idx1-ubyte').take(1024)\n",
+ "\n",
+ "def get_test_data(directory):\n",
+ " \"\"\"tf.data.Dataset object for MNIST test data.\"\"\"\n",
+ " return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
+ ],
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "4ejmJ2dv_f0R",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 85
+ },
+ "outputId": "274c0381-e505-4e69-f910-3def6f8572a7"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Don't forget to run the cell above!\n",
+ "training_data = get_training_data(\"/tmp/mnist/train\")\n",
+ "test_data = get_test_data(\"/tmp/mnist/test\")"
+ ],
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/tmp4ull1xwa.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/tmp1eikhj1v.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/tmpcp8xah9c.gz\n",
+ "Downloading https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/tmpqww_1e74.gz\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "TANpFS6GKLMC",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Fill in the implementation of `train_one_epoch` below and run the cell to train your model. "
+ ]
+ },
+ {
+ "metadata": {
+ "id": "btKL0Ss9_rmC",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ },
+ "outputId": "56858516-86fc-424a-f00d-6f088f98bf9b"
+ },
+ "cell_type": "code",
+ "source": [
+ "EPOCHS = 5\n",
+ "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n",
+ "\n",
+ "def loss_fn(logits, labels):\n",
+ " return tf.reduce_mean(\n",
+ " tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits=tf.squeeze(logits), labels=labels))\n",
+ "\n",
+ "def train_one_epoch(model, training_data, optimizer):\n",
+ " # TODO: Implement an optimization step and return the average loss.\n",
+ " #\n",
+ " # Hint: Use `tf.GradientTape` to compute the gradient of the loss, and use\n",
+ " # `optimizer.apply_gradients` to update the model's variables, which are\n",
+ " # accessible as `model.variables`\n",
+ " average_loss = tfe.metrics.Mean('loss')\n",
+ " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n",
+ " pass\n",
+ " return average_loss.result()\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ " loss = train_one_epoch(model, training_data, optimizer)\n",
+ " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))"
+ ],
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Average loss after epoch 0: 2.2847\n",
+ "Average loss after epoch 1: 2.2305\n",
+ "Average loss after epoch 2: 2.1334\n",
+ "Average loss after epoch 3: 1.9115\n",
+ "Average loss after epoch 4: 1.4285\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "yAOFupJN_htg",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 102
+ },
+ "outputId": "67e711e4-76c9-4e3f-bb49-a14955dba03a"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Double-click to see a solution.\n",
+ "EPOCHS = 5\n",
+ "optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.5)\n",
+ "\n",
+ "def _loss_fn(logits, labels):\n",
+ " return tf.reduce_mean(\n",
+ " tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
+ " logits=tf.squeeze(logits), labels=labels))\n",
+ "\n",
+ "def _train_one_epoch(model, training_data):\n",
+ " average_loss = tfe.metrics.Mean(\"loss\")\n",
+ " for images, labels in training_data.shuffle(buffer_size=10000).batch(64):\n",
+ " with tf.GradientTape() as tape:\n",
+ " logits = model(images, training=True)\n",
+ " loss = _loss_fn(logits, labels)\n",
+ " average_loss(loss)\n",
+ " gradients = tape.gradient(loss, model.variables)\n",
+ " optimizer.apply_gradients(zip(gradients, model.variables))\n",
+ " return average_loss.result()\n",
+ " \n",
+ "for epoch in range(EPOCHS):\n",
+ " loss = _train_one_epoch(model, training_data)\n",
+ " print(\"Average loss after epoch %d: %.4f\" % (epoch, loss))"
+ ],
+ "execution_count": 15,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Average loss after epoch 0: 1.0563\n",
+ "Average loss after epoch 1: 0.8013\n",
+ "Average loss after epoch 2: 0.6306\n",
+ "Average loss after epoch 3: 0.5543\n",
+ "Average loss after epoch 4: 0.5037\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "uDy1DrYA_2Jz",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Run the below cell to qualitatively evaluate your model. Note how eager execution interoperates seamlessly with `matplotlib`."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "vR7rMtpu_3nB",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1752
+ },
+ "outputId": "b212aefa-f4b3-425c-f34d-2491429fa521"
+ },
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "sampled_data = test_data.batch(1).shuffle(buffer_size=10000).take(5)\n",
+ "for image, label in sampled_data:\n",
+ " plt.figure()\n",
+ " plt.imshow(tf.reshape(image, (28, 28)))\n",
+ " plt.show()\n",
+ " logits = model(image, training=False)\n",
+ " prediction = tf.argmax(logits, axis=1, output_type=tf.int64)\n",
+ " print(\"Prediction: %d\" % prediction)"
+ ],
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEwpJREFUeJzt3X1Ilff/x/HXmScxV2GZOmLVohXK\nKmLQjbUsy+pbI7rbaEm1IFhRSU1aE+kO3LqxCGrBMlsNkq0zZIM2Cu1mUTg1itXQbVnBQqKZNtcN\n2d3J3x9ffpLrNN/ndM65jn6fj7/m5cfrvI9XPHedc7zOcTU3NzcLAPCvXnJ6AABoD4glABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAG7kB/cOPGjbpw4YJcLpdyc3M1ZMiQYM4FABEloFieOXNGV69e\nlcfj0ZUrV5SbmyuPxxPs2QAgYgT0MLy8vFwZGRmSpP79++vWrVu6e/duUAcDgEgSUCwbGhrUvXv3\nlq979Oih+vr6oA0FAJEmKC/w8F4cADq6gGKZmJiohoaGlq9v3LihhISEoA0FAJEmoFiOHj1aJSUl\nkqTq6molJiaqS5cuQR0MACJJQK+Gv/nmm3rjjTf03nvvyeVyaf369cGeCwAiios3/wWAtnEFDwAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMHA7\nPQAQiAcPHpjX3rlzx+f2nj17qqGhodW2kydPmvb566+/mm//xx9/NK+13r4kjRgx4pltFRUVGjly\nZKttP/30k3mfL73E+dPz8JsBAIOAziwrKyu1YsUKDRgwQJI0cOBArV27NqiDAUAkCfhh+PDhw7Vz\n585gzgIAEYuH4QBgEHAsL1++rCVLlmju3LkqKysL5kwAEHFczc3Nzf7+UF1dnc6dO6cpU6aotrZW\nCxYsUGlpqaKjo0MxIwA4LqDnLJOSkjR16lRJUp8+fdSzZ0/V1dWpd+/eQR0OeB7+dIg/HQq3gH4z\nhw4d0hdffCFJqq+v182bN5WUlBTUwQAgkgR0Zjl+/HitWrVKx48f16NHj7RhwwYeggPo0AKKZZcu\nXbR79+5gzwIAESugF3gAf1RVVZnXfvfdd6Z1hw8fNu/zzJkzPrd7vV5FRUWZ99Me+LpPDx8+NP98\nR/t9BBPP5gKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAM+3RGtPO/qV5fL\n1ep7BQUF5n1mZWWZ1z558sS8NhRcLpdpnT9vZebPJYT9+vUzry0pKfG5/Y8//mj1NW+7Fhz8FgHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgCt40MrBgwd9bp87d26r7y1btsy8z1de\necW89q233jKte//99837/Dfff/99q68TExNNP/fqq6+ab8Of+x8MvXv3Duvt/a/gzBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4mp/3CVXoMB49emRe+/rrr/vcfvXqVfXt\n27fl68zMTPM+P/74Y/PauLg481ognDizBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABny6YztVX19vXjthwgTz2oEDB5q+l5eXZ96n223/Z/b48WPTuuvXr5v3efz4cZ/bFy5c\nqC+//NK8n0CNHTvWvLZfv34hnAQvwnRmWVNTo4yMDBUVFUn67z/U+fPnKzMzUytWrNDDhw9DOiQA\nOK3NWN67d095eXlKTU1t2bZz505lZmbqq6++Ut++fVVcXBzSIQHAaW3GMjo6WoWFha0+fL6ysrLl\noV16errKy8tDNyEARIA2n0xyu93PPOfU1NSk6OhoSVJ8fLxfz58BQHv0wi/w8HaYzkhISDCv/eWX\nX4Jym0ePHg3Kfv6N9cWg3r17m/e5cOHCgL4HPC2gWMbGxur+/fuKiYlRXV1dq4foCI9QvRqelJTk\nc/vRo0c1ceLElq+PHDli3ievhvNqeEcQ0N9Zjho1SiUlJZKk0tJSjRkzJqhDAUCkafN/+VVVVdqy\nZYuuXbsmt9utkpISbdu2TTk5OfJ4POrVq5dmzJgRjlkBwDFtxnLQoEE6cODAM9v3798fkoEAIBLx\ngWXt1A8//GBeO3v2bPPa572Ik5aWplOnTrV8ff78efM+J02aZF5rnfX333837/N5vF6voqKiAvrZ\nd99917x20KBB5rWrVq0yr42JiTGvxYvj2nAAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGDA5Y7tlD+X23377bcvfHv/vDTQn7cS8+ft1NLS0kzr/Ln/o0aN8rk9OTn5mcsmO3Xq\nZNrn7du3zbc/YsQI89q9e/ea1y5YsMC8Fi+OM0sAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyI\nJQAYEEsAMCCWAGDQ5kfhIjItXrzYvHb06NHmtRcvXnzu9z744IOW//bnUruhQ4ea11ovN3S7g/NP\nNzk5OaCfe/qTLtvi9XrNa/351E4udwwvziwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIAreNqpjIyMkKz9N59//nlQ9tMRPHjwwOkREGacWQKAAbEEAANiCQAGxBIADIglABgQSwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADEyxrKmpUUZGhoqKiiRJOTk5mjZtmubPn6/58+fr5MmT\noZwRABzX5rsO3bt3T3l5eUpNTW21PTs7W+np6SEbDAAiSZtnltHR0SosLFRiYmI45gGAiNTmmaXb\n7Zbb/eyyoqIi7d+/X/Hx8Vq7dq169OgRkgGBSDRx4kTzWq/XG8JJEC4Bvfnv9OnTFRcXp5SUFO3Z\ns0e7du3SunXrgj0bELGOHj1qXvuf//zHvHb27Nnmtd988415LV5cQK+Gp6amKiUlRZI0fvx41dTU\nBHUoAIg0AcUyKytLtbW1kqTKykoNGDAgqEMBQKRp82F4VVWVtmzZomvXrsntdqukpETz5s3TypUr\n1blzZ8XGxmrTpk3hmBUAHNNmLAcNGqQDBw48s33y5MkhGQgAIhGf7ggEgAsx/vdwuSMAGBBLADAg\nlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgckcgAKdPnw7JfqdNmxaS/eLFcWYJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAZcwQM85dSpU6Z1P//8s3mfL7/8snntuHHj\nzGsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAAy53RIf3999/+9we\nFxf3zPcyMjJM+/R6vebbP3jwoHlt7969zWsRXpxZAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAAy53DIMnT56Y1+bm5prWbdiwwbzPmJgY89r24u7du+a1b7/9ts/tZWVlz3zP\nehnjO++8Y7792bNnm9cicplimZ+fr3Pnzunx48davHixBg8erNWrV8vr9SohIUFbt25VdHR0qGcF\nAMe0GcuKigpdunRJHo9HjY2NmjlzplJTU5WZmakpU6Zo+/btKi4uVmZmZjjmBQBHtPmc5bBhw7Rj\nxw5JUrdu3dTU1KTKykpNmDBBkpSenq7y8vLQTgkADmszllFRUYqNjZUkFRcXKy0tTU1NTS0Pu+Pj\n41VfXx/aKQHAYeYXeI4dO6bi4mLt27dPkyZNatne3NwcksE6kpdesv/RwebNm0M4ScfRpUsX89qy\nsrKAvgc8zRTL06dPa/fu3dq7d6+6du2q2NhY3b9/XzExMaqrq1NiYmKo52zXeDU8+Px5NXzy5Mk+\nt5eVlWn06NGttlVUVJj26c+r4V9//bV5rT//Y0V4tXlk7ty5o/z8fBUUFCguLk6SNGrUKJWUlEiS\nSktLNWbMmNBOCQAOa/PM8vDhw2psbNTKlStbtm3evFlr1qyRx+NRr169NGPGjJAOCQBOazOWc+bM\n0Zw5c57Zvn///pAMBACRyNXMKzQh58+HW1n/uP/TTz817zM7Ozvotx8qv/32m2nd0qVLzfs8deqU\nz+1er1dRUVHm/TyturravDY5OTmg20Bk4dlkADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgwOWOYeDP5Y4JCQmmdbdu3TLvc+LEiea148aN87k9Jycn4PfavH//vnntJ598Ylrn\nzz/bbt26+dze2Nio7t27t9p28eJF0z6tx0mSXC6XeS0iF2eWAGBALAHAgFgCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgMsdI0xxcbFp3bJly8z7bGhoCHScFi/ySYj++Oflh88zefJk8z4/\n+ugjn9uHDh2q8+fPP7MN8IUzSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAw4Aqe\ndqqmpsa8Njs727z2yJEjPre/yBU8q1evNq8dPHiwaV1mZmZAswCB4swSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYcLkjABi4LYvy8/N17tw5PX78WIsXL9aJEydUXV2tuLg4\nSdKiRYs0bty4UM4JAI5qM5YVFRW6dOmSPB6PGhsbNXPmTI0cOVLZ2dlKT08Px4wA4Lg2Yzls2DAN\nGTJEktStWzc1NTXJ6/WGfDAAiCR+PWfp8Xh09uxZRUVFqb6+Xo8ePVJ8fLzWrl2rHj16hHJOAHCU\nOZbHjh1TQUGB9u3bp6qqKsXFxSklJUV79uzRn3/+qXXr1oV6VgBwjOlPh06fPq3du3ersLBQXbt2\nVWpqqlJSUiRJ48eP9+uNaAGgPWozlnfu3FF+fr4KCgpaXv3OyspSbW2tJKmyslIDBgwI7ZQA4LA2\nX+A5fPiwGhsbtXLlypZts2bN0sqVK9W5c2fFxsZq06ZNIR0SAJzGH6UDgAGXOwKAAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4nbjRjRs3\n6sKFC3K5XMrNzdWQIUOcGCOoKisrtWLFCg0YMECSNHDgQK1du9bhqQJXU1OjpUuXauHChZo3b56u\nX7+u1atXy+v1KiEhQVu3blV0dLTTY/rln/cpJydH1dXViouLkyQtWrRI48aNc3ZIP+Xn5+vcuXN6\n/PixFi9erMGDB7f74yQ9e79OnDjh+LEKeyzPnDmjq1evyuPx6MqVK8rNzZXH4wn3GCExfPhw7dy5\n0+kxXti9e/eUl5en1NTUlm07d+5UZmampkyZou3bt6u4uFiZmZkOTukfX/dJkrKzs5Wenu7QVC+m\noqJCly5dksfjUWNjo2bOnKnU1NR2fZwk3/dr5MiRjh+rsD8MLy8vV0ZGhiSpf//+unXrlu7evRvu\nMfAvoqOjVVhYqMTExJZtlZWVmjBhgiQpPT1d5eXlTo0XEF/3qb0bNmyYduzYIUnq1q2bmpqa2v1x\nknzfL6/X6/BUDsSyoaFB3bt3b/m6R48eqq+vD/cYIXH58mUtWbJEc+fOVVlZmdPjBMztdismJqbV\ntqamppaHc/Hx8e3umPm6T5JUVFSkBQsW6MMPP9Rff/3lwGSBi4qKUmxsrCSpuLhYaWlp7f44Sb7v\nV1RUlOPHypHnLJ/W3Nzs9AhB8dprr2n58uWaMmWKamtrtWDBApWWlrbL54va0lGO2fTp0xUXF6eU\nlBTt2bNHu3bt0rp165wey2/Hjh1TcXGx9u3bp0mTJrVsb+/H6en7VVVV5fixCvuZZWJiohoaGlq+\nvnHjhhISEsI9RtAlJSVp6tSpcrlc6tOnj3r27Km6ujqnxwqa2NhY3b9/X5JUV1fXIR7OpqamKiUl\nRZI0fvx41dTUODyR/06fPq3du3ersLBQXbt27TDH6Z/3KxKOVdhjOXr0aJWUlEiSqqurlZiYqC5d\nuoR7jKA7dOiQvvjiC0lSfX29bt68qaSkJIenCp5Ro0a1HLfS0lKNGTPG4YleXFZWlmprayX99znZ\n//9Lhvbizp07ys/PV0FBQcurxB3hOPm6X5FwrFzNDpyrb9u2TWfPnpXL5dL69euVnJwc7hGC7u7d\nu1q1apVu376tR48eafny5Ro7dqzTYwWkqqpKW7Zs0bVr1+R2u5WUlKRt27YpJydHDx48UK9evbRp\n0yZ16tTJ6VHNfN2nefPmac+ePercubNiY2O1adMmxcfHOz2qmcfj0WeffaZ+/fq1bNu8ebPWrFnT\nbo+T5Pt+zZo1S0VFRY4eK0diCQDtDVfwAIABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwOD/\nAKCzFeFbFn4BAAAAAElFTkSuQmCC\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd61cfd1e80>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 5\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEQ1JREFUeJzt3W9Ilff/x/HXSSd2VmKaRwiqjTBy\nq9gfap2iliaFQfRvsCXW1rpRRJGTCJG0MSHLIpbF8M9qN3L7cjZvNQiOVAQt7LQcBLqB1Y0QaXYs\naUa2mZ3fjS9ff7Vcvj2ec65jez7ueZ1P57wPlzy7Li8vjysUCoUEAHihcU4PAABjAbEEAANiCQAG\nxBIADIglABgQSwAwIJYAYEAsAcAgMdx/uH//fl27dk0ul0ulpaWaO3duJOcCgLgSViyvXLmiW7du\nyefz6ebNmyotLZXP54v0bAAQN8I6DW9ublZeXp4kacaMGbp//74ePHgQ0cEAIJ6EFcvu7m5NmjRp\n8Ou0tDQFg8GIDQUA8SYiF3j4WxwAXnZhxdLj8ai7u3vw6zt37igjIyNiQwFAvAkrlosWLZLf75ck\ntbW1yePxaMKECREdDADiSVhXw9955x29+eab+uijj+RyubRv375IzwUAccXFH/8FgOFxBw8AGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMEp0eAIgnP/30k2nd+vXrzc+Zl5dnXvvtt9+a1yK2OLIEAANiCQAGxBIADIgl\nABgQSwAwIJYAYEAsAcCAWAKAAbEEAAPu4AGecuzYMdO6YDBofk6XyxXuOIgjHFkCgEFYR5aBQEC7\ndu1SVlaWJGnmzJkqKyuL6GAAEE/CPg2fP3++qqurIzkLAMQtTsMBwCDsWN64cUPbtm3Thg0bdOnS\npUjOBABxxxUKhUIj/UddXV1qaWlRfn6+Ojo6tGnTJjU1NSkpKSkaMwKA48L6mWVmZqZWrlwpSZo2\nbZomT56srq4uTZ06NaLDAbH24Ycfmtb98MMP5ucsKCgwr21oaDCvRWyFdRp++vRpnThxQtJ/f9/s\n7t27yszMjOhgABBPwjqyzM3N1e7du3Xu3Dn19/fr888/5xQcwEstrFhOmDBBNTU1kZ4FAOIWtzsC\nT7lw4ULEn3PVqlURf07EHr9nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDLjdES89v98/5PYVK1Y899hIPrXRqre3N+LPidjjyBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLiDB2NSKBQyr21oaBhy+4oVK/7xsUh6++23o/4aiD6OLAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgIErNJL7xoA40dnZaV47derUIbc/efJE48aFd7zw7rvv\nmtf+/PPPYb0G4gtHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBPd8SY\nVFlZ6ejrb9682dHXR+yZjizb29uVl5c3+LGht2/f1saNG1VQUKBdu3bpr7/+iuqQAOC0YWP58OFD\nVVRUyOv1Dm6rrq5WQUGBvvvuO02fPl2NjY1RHRIAnDZsLJOSklRfXy+PxzO4LRAIaNmyZZKknJwc\nNTc3R29CAIgDw/7MMjExUYmJzy7r6+tTUlKSJCk9PV3BYDA60wFAnBj1BR7+HCaccPz48YisffLk\nSSTGwb9AWLF0u9169OiRkpOT1dXV9cwpOhALO3bsMK/96quvhtw+mj/+O5JYb9++PazXQHwJ6ztl\n4cKF8vv9kqSmpiYtXrw4okMBQLwZ9siytbVVBw8eVGdnpxITE+X3+3X48GGVlJTI5/NpypQpWrNm\nTSxmBQDHDBvL2bNn69SpU89t/+abb6IyEADEI+7gQVyxXnCJ1oeAWX/+XlhYGJXXR/zi3nAAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDA7Y6IKxUVFaZ10brd8dVXXzWt6+3t\nNT9nSkpKuOMgjnBkCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADLjdEXHl\nyy+/dPT1BwYGTOv8fr/5OT/99NNwx0Ec4cgSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAy4gwdR99tvv5nXjuSDwKzcbrf5sV9++cX0nGlpaaOaCWMPR5YAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYEAsAcCA2x0RFusHe0kj+xCyJ0+ehDPOC507d878GLcx4p9wZAkA\nBqZYtre3Ky8vTw0NDZKkkpISrVq1Shs3btTGjRt14cKFaM4IAI4b9jT84cOHqqiokNfrfWZ7cXGx\ncnJyojYYAMSTYY8sk5KSVF9fL4/HE4t5ACAuuUKhUMiy8NixY5o0aZIKCwtVUlKiYDCo/v5+paen\nq6ysjB+MA3iphXU1fPXq1UpNTVV2drbq6up0/PhxlZeXR3o2xLGRXA3fvn27eW19fX0447xQc3Pz\nkNvfe+89BQKB57YBQwnrarjX61V2drYkKTc3V+3t7REdCgDiTVix3Llzpzo6OiRJgUBAWVlZER0K\nAOLNsKfhra2tOnjwoDo7O5WYmCi/36/CwkIVFRVp/PjxcrvdqqysjMWsAOCYYWM5e/ZsnTp16rnt\nK1asiMpAABCPzFfDgafdu3fPvHby5MkRf/0PPvjAvPY///nPkNsTEhKeu1CVkJAwqrnw8uJ2RwAw\nIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMCnO+IZ//TpiuPGjXvmsc2bN0fl\n9V0ul2ndF198YX7OF93CyO2NsOLIEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\nuIMHz/jf58H/3fTp05957Mcff4zK6xcWFprWzZo1KyqvD/wTjiwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABtzviGRcuXBhy+8cff/zMY6FQKCqvX15eHpXnBUaLI0sAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDgCkXrvjXEjV9//dW8ds6cOUNuHxgY\nUEJCwuDXI/m2Wb9+vXmtz+czrRs3jv/nEVume8OrqqrU0tKix48fa+vWrZozZ4727NmjgYEBZWRk\n6NChQ0pKSor2rADgmGFjefnyZV2/fl0+n089PT1au3atvF6vCgoKlJ+fryNHjqixsVEFBQWxmBcA\nHDHsucy8efN09OhRSVJKSor6+voUCAS0bNkySVJOTo6am5ujOyUAOGzYWCYkJMjtdkuSGhsbtWTJ\nEvX19Q2edqenpysYDEZ3SgBwmPnvWZ49e1aNjY06efKkli9fPrid60Px74033jCvHRgYCOsx4GVn\niuXFixdVU1Ojr7/+WhMnTpTb7dajR4+UnJysrq4ueTyeaM+JUeBqODB6w37H9fb2qqqqSrW1tUpN\nTZUkLVy4UH6/X5LU1NSkxYsXR3dKAHDYsEeWZ86cUU9Pj4qKiga3HThwQHv37pXP59OUKVO0Zs2a\nqA4JAE7jl9L/BTgNB0aPDyz7F7AGSHpxBJ9+LCUlxfycJ06cMK8lgohXfGcCgAGxBAADYgkABsQS\nAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADbnf8F7hx44Z5rfV2x+TkZPNzjuTWSCBecWQJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMuN3xX6C4uNi89vvvv//HxxIT///b\n5a233hrVTMBYw5ElABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4Qi/6hCoAgCSO\nLAHAhFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAAD06c7VlVVqaWlRY8f\nP9bWrVt1/vx5tbW1KTU1VZK0ZcsWLV26NJpzAoCjho3l5cuXdf36dfl8PvX09Gjt2rVasGCBiouL\nlZOTE4sZAcBxw8Zy3rx5mjt3riQpJSVFfX19GhgYiPpgABBPRvQn2nw+n65evaqEhAQFg0H19/cr\nPT1dZWVlSktLi+acAOAocyzPnj2r2tpanTx5Uq2trUpNTVV2drbq6ur0+++/q7y8PNqzAoBjTFfD\nL168qJqaGtXX12vixInyer3Kzs6WJOXm5qq9vT2qQwKA04aNZW9vr6qqqlRbWzt49Xvnzp3q6OiQ\nJAUCAWVlZUV3SgBw2LAXeM6cOaOenh4VFRUNblu3bp2Kioo0fvx4ud1uVVZWRnVIAHAan8EDAAbc\nwQMABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGCQ6MSL7t+/X9euXZPL5VJpaanmzp3rxBgRFQgEtGvXLmVlZUmSZs6cqbKyMoenCl97e7u2\nb9+uTz75RIWFhbp9+7b27NmjgYEBZWRk6NChQ0pKSnJ6zBH5+3sqKSlRW1ubUlNTJUlbtmzR0qVL\nnR1yhKqqqtTS0qLHjx9r69atmjNnzpjfT9Lz7+v8+fOO76uYx/LKlSu6deuWfD6fbt68qdLSUvl8\nvliPERXz589XdXW102OM2sOHD1VRUSGv1zu4rbq6WgUFBcrPz9eRI0fU2NiogoICB6ccmaHekyQV\nFxcrJyfHoalG5/Lly7p+/bp8Pp96enq0du1aeb3eMb2fpKHf14IFCxzfVzE/DW9ublZeXp4kacaM\nGbp//74ePHgQ6zHwAklJSaqvr5fH4xncFggEtGzZMklSTk6OmpubnRovLEO9p7Fu3rx5Onr0qCQp\nJSVFfX19Y34/SUO/r4GBAYenciCW3d3dmjRp0uDXaWlpCgaDsR4jKm7cuKFt27Zpw4YNunTpktPj\nhC0xMVHJycnPbOvr6xs8nUtPTx9z+2yo9yRJDQ0N2rRpkz777DPdu3fPgcnCl5CQILfbLUlqbGzU\nkiVLxvx+koZ+XwkJCY7vK0d+Zvm0UCjk9AgR8dprr2nHjh3Kz89XR0eHNm3apKampjH586LhvCz7\nbPXq1UpNTVV2drbq6up0/PhxlZeXOz3WiJ09e1aNjY06efKkli9fPrh9rO+np99Xa2ur4/sq5keW\nHo9H3d3dg1/fuXNHGRkZsR4j4jIzM7Vy5Uq5XC5NmzZNkydPVldXl9NjRYzb7dajR48kSV1dXS/F\n6azX61V2drYkKTc3V+3t7Q5PNHIXL15UTU2N6uvrNXHixJdmP/39fcXDvop5LBctWiS/3y9Jamtr\nk8fj0YQJE2I9RsSdPn1aJ06ckCQFg0HdvXtXmZmZDk8VOQsXLhzcb01NTVq8eLHDE43ezp071dHR\nIem/P5P9328yjBW9vb2qqqpSbW3t4FXil2E/DfW+4mFfuUIOHKsfPnxYV69elcvl0r59+zRr1qxY\njxBxDx480O7du/XHH3+ov79fO3bs0Pvvv+/0WGFpbW3VwYMH1dnZqcTERGVmZurw4cMqKSnRn3/+\nqSlTpqiyslKvvPKK06OaDfWeCgsLVVdXp/Hjx8vtdquyslLp6elOj2rm8/l07Ngxvf7664PbDhw4\noL17947Z/SQN/b7WrVunhoYGR/eVI7EEgLGGO3gAwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAg\nlgBg8H/nb4OLnfGqVAAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd61bade5c0>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 1\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE1ZJREFUeJzt3X1olfX/x/HXccc1DyrLuY1GaRGL\nRqZSaE7zZmqKgnhDsVwqkYGRE29QW8tp4M102solNJ03fzSqgyPoBmFDIlg1Jw0xNsrZDbKGranD\nG5x3x33/+NF+rp153js751znrOfjv13n43Xex4NPrrPL61yujo6ODgEA7muA0wMAQCwglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAMiCUAGLiD/YM7duzQ6dOn5XK5lJ+fr9GjR4dyLgCIKkHF8uTJkzp3\n7py8Xq9+++035efny+v1hno2AIgaQX0Mr6mp0cyZMyVJjz/+uC5fvqxr166FdDAAiCZBxfLChQt6\n8MEHO38eNmyYWltbQzYUAESbkJzg4bs4APR3QcUyJSVFFy5c6Pz577//VnJycsiGAoBoE1QsJ02a\npMrKSklSQ0ODUlJSNHjw4JAOBgDRJKiz4c8884yeeuopvfzyy3K5XNqyZUuo5wKAqOLiy38BIDCu\n4AEAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUA\nGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJ\nAAbEEgAMiCUAGLiD+UO1tbVavXq10tPTJUlPPPGECgoKQjoYAESToGIpSePHj1dJSUkoZwGAqMXH\ncAAwCDqWv/76q9544w0tXrxY33//fShnAoCo4+ro6Ojo7R9qaWlRXV2d5syZo6amJi1btkxVVVWK\nj48Px4wA4LigjixTU1M1d+5cuVwujRgxQsOHD1dLS0uoZwOAqBFULL/88ksdOnRIktTa2qqLFy8q\nNTU1pIMBQDQJ6mP4tWvXtH79el25ckW3b99Wbm6upk6dGo75ACAqBBVLAPivCfr/WQL90alTp0zr\nSktLzfssKysLdpz78nec09HRIZfL1WVbbm6ueZ+9+b/T/36e/o7/ZwkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAy4Nhz93tmzZ/1uT09P7/bY4sWLTfu0XhYZaT6fT3FxcUH/\n+Vu3bpnX9uV5YhFHlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgwA3LEHa9uUjs\nzJkzpnXz588377Opqcnv9uvXr2vMmDFdtt28edO8Xyu32/7PrKCgwLw2Pj7e7/bCwsIuPz/77LPm\nfQ4YwPFTT/ibAQADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtywDEG5ffu2\nee1bb71lXrt3795gxgmKv5t7PfTQQ6Y/u3r1avPzLF++3Lz2yJEj5rW5ubndtj3wwAPdLtl84IEH\nzPtEzziyBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzdEV3cvXvX7/YB\nAwZ0eSwvL8+8z0hewujPokWLzI999NFHpn16PB7z8y9evNi89uuvvzavbW5u7ratuLhYb7/9drdt\n6DvTkWVjY6Nmzpyp8vJySdL58+e1dOlS5eTkaPXq1bp161ZYhwQApwWM5fXr17V161ZlZmZ2bisp\nKVFOTo4++eQTjRw5UhUVFWEdEgCcFjCW8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpqa8E0IAFEg\n4O8s3W633O6uy9rb2xUfHy9JSkpKUmtra3imA4Ao0ecTPHwdZv8yYEDPHzbufey9994z77M3ayPt\n6NGjYX+OL774IuzPcS9O6IRHULH0eDy6ceOGEhIS1NLS0uUjOmKb9Wz4hg0bzPv84IMP+jxXX/R0\nNvzo0aN66aWXumyLpbPh/r6AuLi4WOvWreu2DX0X1P+znDhxoiorKyVJVVVVmjx5ckiHAoBoE/DI\nsr6+Xrt27VJzc7PcbrcqKyu1Z88e5eXlyev1Ki0tTQsWLIjErADgmICxHDVqlD7++ONu23tzrxAA\niHVcwfMf8Ndff5nXzpo1y+/2n376SWPHju38uaGhoc9z+TN06FDTutLSUvM+X3zxxR4f++yzz7r8\nfL8TXPf69NNPzc/fm99D9kZaWlqvtqNvuDYcAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkA\nBsQSAAyIJQAYuDr4QsqYdPXqVfPaUaNGmdf++eeffrf7fD7FxcWZ93Ovf75V3+LQoUOmdY888khQ\nswRivUVKdnZ2WJ7/ny/Vtjh16lS3bU8++aR++eWXbtvQdxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA+7uGKPKy8vNa3u6hLEvlixZYl67Z88e89rk5GTTupaWFvM+X3/9\ndb/bv/rqK82bN6/LtsrKSvN+w6E3d43s6TJGLm8MD44sAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCAG5ZFmbt375rWvfDCC+Z9fvvtt+a1Pd0wq729XYMGDer8ubGx0bzPtLQ089qf\nf/7ZtG7Dhg3mfVZVVfnd3pebsIXLjRs3zGsHDhwYxknwbxxZAoABsQQAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA25YFmWsV5/25hLG3vD5fKbHiouLzfv8448/zGu/+uor89pYsWDB\nAvPaaLv8Ev+PI0sAMDDFsrGxUTNnzuy8/WpeXp7mzZunpUuXaunSpWE7ygGAaBHwY/j169e1detW\nZWZmdtm+bt06ZWVlhW0wAIgmAY8s4+PjVVZWppSUlEjMAwBRKeCRpdvtltvdfVl5ebmOHDmipKQk\nFRQUaNiwYWEZ8L/G+gv++52ICZdbt25F/DnDzYm/R8SmoM6Gz58/X4mJicrIyNCBAwe0b98+bd68\nOdSz/SdZ//H29CW9fdVTrG/dutXlOVeuXGneZ7SeDY/Ul//25mz40aNHzWsHDOD8bCQF9bedmZmp\njIwMSdL06dN79a3ZABCLgorlqlWr1NTUJEmqra1Venp6SIcCgGgT8GN4fX29du3apebmZrndblVW\nVmrJkiVas2aNBg0aJI/Ho8LCwkjMCgCOCRjLUaNG6eOPP+62ffbs2WEZCACiEZc7ogvr5Y4lJSWR\nGKdf6M0JHk7aRC/eGQAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYMDljlHG\nernbsWPHzPvszeV24fiC3958MfT69etN6/Lz84MdJyS2bdtmXvvKK6+EcRJECkeWAGBALAHAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAFTxRxuVymdb15u6ap06dMq+9dOlSj49VV1eb93Ov\nsWPHmtfW1dUF9RyhMmbMGNO6lStXmvfJTcj6B95FADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg4Oro6Ohwegj0b21tbea1kyZNMq07c+ZMsON08vl8iouL67Lthx9+MP3Z5557\nrs/Pj9jCkSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDg7o4Iu5MnT5rX\nhuIyxn/Ly8szPzZ+/PiQPz/6B1Msi4qKVFdXpzt37mjFihV6+umntXHjRvl8PiUnJ2v37t2Kj48P\n96wA4JiAsTxx4oTOnj0rr9ertrY2LVy4UJmZmcrJydGcOXNUXFysiooK5eTkRGJeAHBEwN9Zjhs3\nTnv37pUkDR06VO3t7aqtrdWMGTMkSVlZWaqpqQnvlADgsICxjIuLk8fjkSRVVFRoypQpam9v7/zY\nnZSUpNbW1vBOCQAOM5/gOX78uCoqKnT48GHNmjWrcztfh4lAZs+ebV7r8/nCOEl327dvj+jzIXaZ\nYlldXa3S0lIdPHhQQ4YMkcfj0Y0bN5SQkKCWlhalpKSEe07EsMrKSvPauXPnhvz5ezobvn37dr3z\nzjtdtm3bts20T5fL1ee5EFsCfgy/evWqioqKtH//fiUmJkqSJk6c2PkPoKqqSpMnTw7vlADgsIBH\nlseOHVNbW5vWrFnTuW3nzp3atGmTvF6v0tLStGDBgrAOCQBOCxjL7OxsZWdnd9t+5MiRsAwEANGI\nG5YhKL25CVlGRoZ5bTj+Z8Xvv//ud/vIkSN17ty5btsAf7g2HAAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGHDDMgSlrKzMvDYclzDm5uaa16alpQX1GHAvjiwBwIBYAoABsQQA\nA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABlzuiizt37vjd7na7uzz2+eefh+X5V61aZVr3\n/vvvm/fpcrl6fGzgwIHm/eC/jSNLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADBw\ndXR0dDg9BKLHd99953f7888/3+WxqVOnmvf58MMPm9eeOXPGtC4hIcG8TyAUOLIEAANiCQAGxBIA\nDIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG3LAMXQwZMiSox+5ny5Yt5rVcxohoZYplUVGR\n6urqdOfOHa1YsULffPONGhoalJiYKElavny5pk2bFs45AcBRAWN54sQJnT17Vl6vV21tbVq4cKEm\nTJigdevWKSsrKxIzAoDjAsZy3LhxGj16tCRp6NCham9vl8/nC/tgABBNAp7giYuLk8fjkSRVVFRo\nypQpiouLU3l5uZYtW6a1a9fq0qVLYR8UAJxk/j7L48ePa//+/Tp8+LDq6+uVmJiojIwMHThwQH/9\n9Zc2b94c7lkBwDGmEzzV1dUqLS3VwYMHNWTIEGVmZnY+Nn36dL377rvhmg8Rdvr0ab/bx4wZ0+Wx\nZ555xrzPsrIy89rXXnvNvBaIpIAfw69evaqioiLt37+/8+z3qlWr1NTUJEmqra1Venp6eKcEAIcF\nPLI8duyY2tratGbNms5tixYt0po1azRo0CB5PB4VFhaGdUgAcFrAWGZnZys7O7vb9oULF4ZlIACI\nRlzuCAAG3N0RAAw4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQA\nA2IJAAbEEgAM3E486Y4dO3T69Gm5XC7l5+dr9OjRTowRUrW1tVq9erXS09MlSU888YQKCgocnip4\njY2NevPNN/Xqq69qyZIlOn/+vDZu3Cifz6fk5GTt3r1b8fHxTo/ZK/9+TXl5eWpoaFBiYqIkafny\n5Zo2bZqzQ/ZSUVGR6urqdOfOHa1YsUJPP/10zL9PUvfX9c033zj+XkU8lidPntS5c+fk9Xr122+/\nKT8/X16vN9JjhMX48eNVUlLi9Bh9dv36dW3dulWZmZmd20pKSpSTk6M5c+aouLhYFRUVysnJcXDK\n3vH3miRp3bp1ysrKcmiqvjlx4oTOnj0rr9ertrY2LVy4UJmZmTH9Pkn+X9eECRMcf68i/jG8pqZG\nM2fOlCQ9/vjjunz5sq5duxbpMXAf8fHxKisrU0pKSue22tpazZgxQ5KUlZWlmpoap8YLir/XFOvG\njRunvXv3SpKGDh2q9vb2mH+fJP+vy+fzOTyVA7G8cOGCHnzwwc6fhw0bptbW1kiPERa//vqr3njj\nDS1evFjff/+90+MEze12KyEhocu29vb2zo9zSUlJMfee+XtNklReXq5ly5Zp7dq1unTpkgOTBS8u\nLk4ej0eSVFFRoSlTpsT8+yT5f11xcXGOv1eO/M7yXh0dHU6PEBKPPvqocnNzNWfOHDU1NWnZsmWq\nqqqKyd8XBdJf3rP58+crMTFRGRkZOnDggPbt26fNmzc7PVavHT9+XBUVFTp8+LBmzZrVuT3W36d7\nX1d9fb3j71XEjyxTUlJ04cKFzp///vtvJScnR3qMkEtNTdXcuXPlcrk0YsQIDR8+XC0tLU6PFTIe\nj0c3btyQJLW0tPSLj7OZmZnKyMiQJE2fPl2NjY0OT9R71dXVKi0tVVlZmYYMGdJv3qd/v65oeK8i\nHstJkyapsrJSktTQ0KCUlBQNHjw40mOE3JdffqlDhw5JklpbW3Xx4kWlpqY6PFXoTJw4sfN9q6qq\n0uTJkx2eqO9WrVqlpqYmSf/3O9l//idDrLh69aqKioq0f//+zrPE/eF98ve6ouG9cnU4cKy+Z88e\n/fjjj3K5XNqyZYuefPLJSI8QcteuXdP69et15coV3b59W7m5uZo6darTYwWlvr5eu3btUnNzs9xu\nt1JTU7Vnzx7l5eXp5s2bSktLU2FhoQYOHOj0qGb+XtOSJUt04MABDRo0SB6PR4WFhUpKSnJ6VDOv\n16sPP/xQjz32WOe2nTt3atOmTTH7Pkn+X9eiRYtUXl7u6HvlSCwBINZwBQ8AGBBLADAglgBgQCwB\nwIBYAoABsQQAA2IJAAbEEgAM/gepgR0uaefKmwAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd6199ef278>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 4\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEelJREFUeJzt3W9MlfX/x/HXEWJyhg5BIG1ZfR0u\nKr3hhopOE2Q23FxiN0xCdNmGa5pG6hhTtNn8g85NtI0/aS1Z29moG96wILM2dYDKDRu0hrpyzCkC\nkUocDeH8brQfk8R4czyH64DPx624+Hid99nFnl2H61wHl8/n8wkA8J/GOD0AAIwExBIADIglABgQ\nSwAwIJYAYEAsAcCAWAKAAbEEAINwf//h7t27denSJblcLhUUFGjGjBmBnAsAQopfsTx//ryuXbsm\nj8ejq1evqqCgQB6PJ9CzAUDI8OtleE1NjdLT0yVJU6dO1e3bt9XZ2RnQwQAglPgVy7a2Nk2YMKHv\n65iYGLW2tgZsKAAINQG5wMNncQAY7fyKZXx8vNra2vq+vnXrluLi4gI2FACEGr9iOW/ePFVVVUmS\nGhsbFR8fr6ioqIAOBgChxK+r4TNnztSrr76qt99+Wy6XSzt27Aj0XAAQUlx8+C8ADI47eADAgFgC\ngAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAz8+lO4AJz3yy+/PLLtlVdeeWT777//bt7ne++9Z147f/58\n0zqPx2PeZyjjzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABi4fD6fz+kh\ngNHsr7/+Mq+tr683r33rrbce2dba2qq4uLh+29rb2837XL16tXntp59+alrndrvN+wxlnFkCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF/sAzww/37981rMzMzzWtPnTplXvu4O2O8\nXm+/rysrK837XLJkiXnt2LFjzWtHA84sAcDArzPLuro6bdy4UYmJiZKkadOmafv27QEdDABCid8v\nw2fNmqXi4uJAzgIAIYuX4QBg4Hcsr1y5onXr1mnlypU6d+5cIGcCgJDj1+dZtrS0qL6+XhkZGWpu\nblZOTo6qq6sVERERjBkBwHF+/c4yISGh7y0GU6ZM0cSJE9XS0qLnn38+oMMBoWoobx1aunSpee2T\nvnWos7NTUVFR/bZ9+eWX5n3y1qHH8+tl+IkTJ3T06FFJ/3wyc3t7uxISEgI6GACEEr/OLNPS0rR5\n82b98MMP6u7u1s6dO3kJDmBU8yuWUVFRKikpCfQsABCyuN0ReIj1vcNbtmwx77O7u9u8dii/9//x\nxx8H3P7zzz/3+/p///ufeZ94PN5nCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAG\nxBIADPz6PEvAaT09Pea1x48fH3D7mjVr9MUXX/Tblpuba9pnb2+v+fE/+eQT89qcnBzz2kmTJpnX\n4slxZgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABtzBgxHpcXflDGT16tUDbu/t\n7dWYMf6dL+zcudO8trCw0K/HQGjhzBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAANiCQAGxBIA\nDIglABhwuyNCSnFxsWndRx99ZN7n4/642UC3O77zzjumff77D539l7CwMPNahC7OLAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAG3OyLovF6vee2kSZNM6+7cuePvOH0Gut2x\npqbG9G9nz579xI+PkcV0ZtnU1KT09HRVVFRIkm7cuKFVq1YpKytLGzdu1N9//x3UIQHAaYPGsqur\nS7t27VJKSkrftuLiYmVlZemrr77SCy+8oMrKyqAOCQBOGzSWERERKi8vV3x8fN+2uro6LVq0SJKU\nmppqfukCACNV+KALwsMVHt5/mdfrVUREhCQpNjZWra2twZkOAELEoLEcDNeHMJjIyEjz2j///DOI\nkzyqt7d3WB8PI5dfsXS73bp3757Gjh2rlpaWfi/RgX/jajhGA7/eZzl37lxVVVVJkqqrqzV//vyA\nDgUAoWbQM8uGhgbt27dP169fV3h4uKqqqnTgwAHl5+fL4/Fo8uTJWrZs2XDMCgCO4U3pCDpehmM0\neOILPHg6ffvtt+a1hw4dMq8NRASfRElJiWkdsXz6cG84ABgQSwAwIJYAYEAsAcCAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAw4HZH+MV6W6D0zydTWU2ZMsW07v79++Z9trS0mNcCj8OZJQAYEEsAMCCW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMOB2R/Rz4cKFAbcnJyf3+15tbW1QHv/77783\nrRvKX4FMTk72dxygD2eWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGDAHTzoZ8GC\nBQNu93q9/b43lD8YNhTWP1jm9XqD8vjA43BmCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEE\nAANiCQAGxBIADLjd8Slw5MgR89r/uo3R31scZ8yYYV7rcrn8eoxAuXnzpmldV1eXeZ9ut9vfcRBC\nOLMEAANTLJuampSenq6KigpJUn5+vpYuXapVq1Zp1apV+umnn4I5IwA4btCX4V1dXdq1a5dSUlL6\nbc/Ly1NqamrQBgOAUDLomWVERITKy8sVHx8/HPMAQEhy+Xw+n2Xh4cOHNWHCBGVnZys/P1+tra3q\n7u5WbGystm/frpiYmGDPCgCO8etq+Jtvvqno6GglJSWprKxMR44cUWFhYaBnQ4AM5Wr4Bx98MOD2\n3t5ejRnj3/XAoVwNP3/+vGndUK5GP+5/5AM9pzfeeMO0z6+//tr8+FwNHx38+ulPSUlRUlKSJCkt\nLU1NTU0BHQoAQo1fsdywYYOam5slSXV1dUpMTAzoUAAQagZ9Gd7Q0KB9+/bp+vXrCg8PV1VVlbKz\ns7Vp0yZFRkbK7XZrz549wzErADhm0Fi+9tprOn78+CPbrb/bAYDRgNsdnwLt7e2OPv6WLVvMayMi\nIkzrhnKBZyiqqqpM63799VfzPmfOnOnvOAgh3O4IAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMuN0RfomNjTWvTU5ODvjjnz17NuD7lNT30YODee6554Ly+AhdnFkCgAGxBAAD\nYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAF38MAv48ePN6999tlnA/74FRUVAd+nJM2aNcu0\nLiEhISiPj9DFmSUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADDgdkf45bff\nfjOv/eabb8xrs7OzTet6e3vN+/T5fH59D3gYZ5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQ\nSwAwIJYAYEAsAcCA2x0RdO+++25Q1lq5XC6/vgc8zBTLoqIi1dfX68GDB8rNzdX06dO1detW9fT0\nKC4uTvv371dERESwZwUAxwway9raWl2+fFkej0cdHR3KzMxUSkqKsrKylJGRoYMHD6qyslJZWVnD\nMS8AOGLQ31kmJyfr0KFDkqTx48fL6/Wqrq5OixYtkiSlpqaqpqYmuFMCgMMGjWVYWJjcbrckqbKy\nUgsWLJDX6+172R0bG6vW1tbgTgkADjNf4Dl16pQqKyt17NgxLV68uG87nwcY+nbs2BGQtUP5DMmR\nYjQ+JwSHKZZnzpxRSUmJPvvsM40bN05ut1v37t3T2LFj1dLSovj4+GDPiSfw8ccfP/Ha3t5ejRkz\nut5pNtBzWr16tenffv7558EYCSFs0J/+u3fvqqioSKWlpYqOjpYkzZ07V1VVVZKk6upqzZ8/P7hT\nAoDDBj2zPHnypDo6OrRp06a+bXv37tW2bdvk8Xg0efJkLVu2LKhDAoDTBo3lihUrtGLFike28zIE\nwNOEO3ieAnl5eea1Fy5ceOz3lixZ0vffZ8+eNe/zzp075rVAqBpdv7EHgCAhlgBgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg4PLxgZTww3fffWde+/Btkk543I+4z+d75A+W1dbWmvY5\ne/bsJ54LIwtnlgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIDbHQHAgDNL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbE\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAg3DLoqKiItXX1+vBgwfKzc3V6dOn1djYqOjoaEnS\n2rVrtXDhwmDOCQCOGjSWtbW1unz5sjwejzo6OpSZmak5c+YoLy9PqampwzEjADhu0FgmJydrxowZ\nkqTx48fL6/Wqp6cn6IMBQChx+Xw+n3Wxx+PRxYsXFRYWptbWVnV3dys2Nlbbt29XTExMMOcEAEeZ\nY3nq1CmVlpbq2LFjamhoUHR0tJKSklRWVqabN2+qsLAw2LMCgGNMV8PPnDmjkpISlZeXa9y4cUpJ\nSVFSUpIkKS0tTU1NTUEdEgCcNmgs7969q6KiIpWWlvZd/d6wYYOam5slSXV1dUpMTAzulADgsEEv\n8Jw8eVIdHR3atGlT37bly5dr06ZNioyMlNvt1p49e4I6JAA4bUgXeADgacUdPABgQCwBwIBYAoAB\nsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBg\nQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbhTjzo7t27\ndenSJblcLhUUFGjGjBlOjBFQdXV12rhxoxITEyVJ06ZN0/bt2x2eyn9NTU16//33tWbNGmVnZ+vG\njRvaunWrenp6FBcXp/379ysiIsLpMYfk388pPz9fjY2Nio6OliStXbtWCxcudHbIISoqKlJ9fb0e\nPHig3NxcTZ8+fcQfJ+nR53X69GnHj9Wwx/L8+fO6du2aPB6Prl69qoKCAnk8nuEeIyhmzZql4uJi\np8d4Yl1dXdq1a5dSUlL6thUXFysrK0sZGRk6ePCgKisrlZWV5eCUQzPQc5KkvLw8paamOjTVk6mt\nrdXly5fl8XjU0dGhzMxMpaSkjOjjJA38vObMmeP4sRr2l+E1NTVKT0+XJE2dOlW3b99WZ2fncI+B\n/xAREaHy8nLFx8f3baurq9OiRYskSampqaqpqXFqPL8M9JxGuuTkZB06dEiSNH78eHm93hF/nKSB\nn1dPT4/DUzkQy7a2Nk2YMKHv65iYGLW2tg73GEFx5coVrVu3TitXrtS5c+ecHsdv4eHhGjt2bL9t\nXq+37+VcbGzsiDtmAz0nSaqoqFBOTo4+/PBD/fHHHw5M5r+wsDC53W5JUmVlpRYsWDDij5M08PMK\nCwtz/Fg58jvLh/l8PqdHCIgXX3xR69evV0ZGhpqbm5WTk6Pq6uoR+fuiwYyWY/bmm28qOjpaSUlJ\nKisr05EjR1RYWOj0WEN26tQpVVZW6tixY1q8eHHf9pF+nB5+Xg0NDY4fq2E/s4yPj1dbW1vf17du\n3VJcXNxwjxFwCQkJWrJkiVwul6ZMmaKJEyeqpaXF6bECxu126969e5KklpaWUfFyNiUlRUlJSZKk\ntLQ0NTU1OTzR0J05c0YlJSUqLy/XuHHjRs1x+vfzCoVjNeyxnDdvnqqqqiRJjY2Nio+PV1RU1HCP\nEXAnTpzQ0aNHJUmtra1qb29XQkKCw1MFzty5c/uOW3V1tebPn+/wRE9uw4YNam5ulvTP72T//50M\nI8Xdu3dVVFSk0tLSvqvEo+E4DfS8QuFYuXwOnKsfOHBAFy9elMvl0o4dO/Tyyy8P9wgB19nZqc2b\nN+vOnTvq7u7W+vXr9frrrzs9ll8aGhq0b98+Xb9+XeHh4UpISNCBAweUn5+v+/fva/LkydqzZ4+e\neeYZp0c1G+g5ZWdnq6ysTJGRkXK73dqzZ49iY2OdHtXM4/Ho8OHDeumll/q27d27V9u2bRuxx0ka\n+HktX75cFRUVjh4rR2IJACMNd/AAgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA4P8ALqDX\nN3rmU3AAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd62944c6d8>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 1\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEqVJREFUeJzt3W9Ilff/x/HX+eWkpMQ0dQRrZdgm\nq24Miiz6Y0nrFKPVjZqiMgiW/SMX0ZxlDYJMiyALZrnqRlKc4a1u5B9cjIWZUbDA7ljWQqJMm1iR\nbSbne2P8/H7NY77P8Ryvoz0f97y8us777BpPrnMuP+e4vF6vVwCAd/o/pwcAgNGAWAKAAbEEAANi\nCQAGxBIADIglABgQSwAwIJYAYBAR6D88dOiQbt++LZfLpYKCAs2dOzeYcwFAWAkoljdu3NDDhw/l\n8XjU0tKigoICeTyeYM8GAGEjoJfhDQ0NSk9PlyTNnDlTXV1devnyZVAHA4BwElAsOzo6NHny5L6f\nY2Nj1d7eHrShACDcBOUGD5/FAWCsCyiWCQkJ6ujo6Pv56dOnio+PD9pQABBuAorlokWLVFNTI0m6\nc+eOEhISNHHixKAOBgDhJKC74Z9//rk+++wzff3113K5XDpw4ECw5wKAsOLiw38BYGis4AEAA2IJ\nAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBY\nAoABsQQAA2IJAAbEEgAMiCUAGBBLADAI6KtwgVC5ePGiab+9e/eaj/ngwQOf271er1wul/k4gWpp\naTHvm5SUFMJJMBxcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKAAbEEAAOWOyIg\n9+/fD8lxMzMzTfutWrXKfMzBljv6MmPGjKAf88mTJ+Z9We4YvriyBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAgFgCgAGxBAADVvAgIOnp6eZ9/VntYrV06VLzvh6PZ9DfdXV19fs5OjradMwtW7aY\nH3/27NnmfRG+uLIEAIOAriwbGxu1c+dOJScnS5JmzZqlwsLCoA4GAOEk4Jfh8+fPV2lpaTBnAYCw\nxctwADAIOJb37t1Tbm6uMjIyVF9fH8yZACDsuLxer9fff9TW1qZbt27J7XartbVVOTk5qq2tVWRk\nZChmBADHBfSeZWJiolavXi1JmjZtmqZMmaK2tjZ99NFHQR0O4cufD6kNxZ8OFRUVmffdunWrz+3R\n0dF6/vz5gG0W/vzpUHFxsXlf6+Nj5AX0MvzSpUs6c+aMJKm9vV3Pnj1TYmJiUAcDgHAS0JXl8uXL\ntXv3bv3666/q6enRjz/+yEtwAGNaQLGcOHGiysrKgj0LAIStgG7wYHR5+325d9m4caPP7VVVVXK7\n3X0/V1dXD3suX6zvRebn54fk8YHB8HeWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCW\nAGBALAHAgOWO7wF/Pk5ssDX/Xq9XLpcroMf35+PUWMaIcMWVJQAYEEsAMCCWAGBALAHAgFgCgAGx\nBAADYgkABsQSAAyIJQAYsIJnlLp27Zp530WLFg378d5ewXPhwgXzv83IyBj24wNO48oSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYRDg9APp7/vy5ab9gLGH0JTc31/Q7ljDi\nfcOVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMODbHcOM2+027VddXW0+\n5qpVq8z7ejwen9ujo6P7LcWMjo42HxMYC0xXls3NzUpPT1dFRYUk6fHjx8rOzlZmZqZ27typf/75\nJ6RDAoDThozlq1evdPDgQaWmpvZtKy0tVWZmpi5cuKCPP/5YlZWVIR0SAJw2ZCwjIyNVXl6uhISE\nvm2NjY1asWKFJCktLU0NDQ2hmxAAwsCQH9EWERGhiIj+u3V3dysyMlKSFBcXp/b29tBMBwBhYtif\nZ8n9oeCqqqpyeoRBcVMH77OAYhkVFaXXr19r/Pjxamtr6/cSHcPD3XAgPAX0d5YLFy5UTU2NJKm2\ntlaLFy8O6lAAEG6GvLJsampScXGxHj16pIiICNXU1Ojo0aPKz8+Xx+PR1KlT9dVXX43ErADgmCFj\nOXv2bJ0/f37A9nPnzoVkIAAIR6zgGQH379837ztz5sygP35LS4t536SkpKA/PjAWsDYcAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYDPvzLDG0I0eOBP2Yubm55n1ZwggMH1eW\nAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgOWOI6Cmpibox8zOzg76Mceq\nwb5dMykpacDvrEtT//zzT/PjT58+3byvP/+vfPLJJwO2VVVVye1299uWk5NjPuaaNWvM+0ZHR5v3\nHQu4sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA5fX6/U6PcRY588Xhj148MC0\nX0tLS0ge32kXL1407bd3717zMQf7b+r1euVyuczHGQ2G+5xWrVpl3tfj8Zj2GysrfbiyBAADYgkA\nBsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABix3HAFbtmwx71tWVmbabzSdtlAs9wyG\n4SwN9GdZYHV1dUCPEYiRXMJpXXI7mpbbvgtXlgBgYIplc3Oz0tPTVVFRIUnKz8/Xl19+qezsbGVn\nZ+u3334L5YwA4Lghvzf81atXOnjwoFJTU/tt37Vrl9LS0kI2GACEkyGvLCMjI1VeXq6EhISRmAcA\nwpL5Bs+JEyc0efJkZWVlKT8/X+3t7erp6VFcXJwKCwsVGxsb6lkBwDFDvgz3Ze3atYqJiVFKSopO\nnz6tkydPav/+/cGebczgbjh3w0cKd8NDJ6C74ampqUpJSZEkLV++XM3NzUEdCgDCTUCx3LFjh1pb\nWyVJjY2NSk5ODupQABBuhnwZ3tTUpOLiYj169EgRERGqqalRVlaW8vLyNGHCBEVFRamoqGgkZgUA\nxwwZy9mzZ+v8+fMDtn/xxRchGQgAwlFAN3gAt9tt3jcUN238eTWzYcOGQX/39k2KKVOmBDzTYEL1\n7YbPnz/3ub2rq6vfz99//735mNYbjJK0bds2035VVVXmY4YzljsCgAGxBAADYgkABsQSAAyIJQAY\nEEsAMCCWAGBALAHAgFgCgAGxBAADljuOUteuXTPvu3DhwmEfd+HChf1+F6rPaKyvrzft589zepfR\n/FmLgy2jfHv7Tz/9ZD6mP8sd3zdcWQKAAbEEAANiCQAGxBIADIglABgQSwAwIJYAYEAsAcCAWAKA\nASt4RkBxcbF535qaGtN+WVlZ5mP+8ccf5n19fZOn9O+KmcF+NxR/vlwsWCtz8F/+rPbyR2FhYUiO\nG664sgQAA2IJAAbEEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAYur9frdXoI/Jd1adqi\nRYtCPEl/Xq9XLpcroH/b1dVl3newL+HCQBcvXhywLSMjY8D2zMxM8zEvXLhg3nfNmjWm/cbKOeXK\nEgAMiCUAGBBLADAglgBgQCwBwIBYAoABsQQAA2IJAAbEEgAMiCUAGLDccZTy5xv7grE0cjjLHf35\ndseHDx+a9svOzjYf88MPP/S5PSkpSffv3++37ZdffjEdc8mSJebH98fBgwfN+1ZXVw/YNpzzJEn1\n9fXmfd+3b+I0fRVuSUmJbt26pTdv3mjz5s2aM2eO9uzZo97eXsXHx+vIkSOKjIwM9awA4JghY3n9\n+nXdvXtXHo9HnZ2dWrdunVJTU5WZmSm3261jx46psrLSr8X6ADDaDPme5bx583T8+HFJ/356SHd3\ntxobG7VixQpJUlpamhoaGkI7JQA4bMhYjhs3TlFRUZKkyspKLVmyRN3d3X0vu+Pi4tTe3h7aKQHA\nYab3LCWprq5OlZWVOnv2rFauXNm3nftDzvDnzfVgnaOxeK6TkpL6/Zyfn+/QJP+qqqoa9jHG4nkK\nB6ZYXr16VWVlZfr55581adIkRUVF6fXr1xo/frza2tqUkJAQ6jnxFu6Gczecu+Eja8iX4S9evFBJ\nSYlOnTqlmJgYSf/+R6qpqZEk1dbWavHixaGdEgAcNuSV5eXLl9XZ2am8vLy+bYcPH9a+ffvk8Xg0\ndepUffXVVyEdEgCcNmQsN27cqI0bNw7Yfu7cuZAMBADhiBU874G335d7l23btvncXlVVJbfb3fez\nr/fLRpvhvr/ntBkzZgzYdv/+/QE3rerq6szHnDJlinnfsfJFZFasDQcAA2IJAAbEEgAMiCUAGBBL\nADAglgBgQCwBwIBYAoABsQQAA2IJAAYsd0RA/PmIuPPnz5v3tX702u+//24+5g8//OBzu6/ljr6W\nEPry7bffmh9/w4YN5n398fayRoQWV5YAYEAsAcCAWAKAAbEEAANiCQAGxBIADIglABgQSwAwIJYA\nYEAsAcCA5Y4AYMCVJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBA\nLAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgEGEZaeSkhLdunVLb9680ebN\nm3XlyhXduXNHMTExkqRNmzZp2bJloZwTABw1ZCyvX7+uu3fvyuPxqLOzU+vWrdOCBQu0a9cupaWl\njcSMAOC4IWM5b948zZ07V5IUHR2t7u5u9fb2hnwwAAgnLq/X67Xu7PF4dPPmTY0bN07t7e3q6elR\nXFycCgsLFRsbG8o5AcBR5ljW1dXp1KlTOnv2rJqamhQTE6OUlBSdPn1aT5480f79+0M9KwA4xnQ3\n/OrVqyorK1N5ebkmTZqk1NRUpaSkSJKWL1+u5ubmkA4JAE4bMpYvXrxQSUmJTp061Xf3e8eOHWpt\nbZUkNTY2Kjk5ObRTAoDDhrzBc/nyZXV2diovL69v2/r165WXl6cJEyYoKipKRUVFIR0SAJzm1w0e\nAHhfsYIHAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHA\ngFgCgAGxBAADYgkABsQSAAyIJQAYEEsAMCCWAGBALAHAgFgCgAGxBAADYgkABsQSAAyIJQAYEEsA\nMCCWAGBALAHAIMKJBz106JBu374tl8ulgoICzZ0714kxgqqxsVE7d+5UcnKyJGnWrFkqLCx0eKrA\nNTc3a+vWrfrmm2+UlZWlx48fa8+ePert7VV8fLyOHDmiyMhIp8f0y9vPKT8/X3fu3FFMTIwkadOm\nTVq2bJmzQ/qppKREt27d0ps3b7R582bNmTNn1J8naeDzunLliuPnasRjeePGDT18+FAej0ctLS0q\nKCiQx+MZ6TFCYv78+SotLXV6jGF79eqVDh48qNTU1L5tpaWlyszMlNvt1rFjx1RZWanMzEwHp/SP\nr+ckSbt27VJaWppDUw3P9evXdffuXXk8HnV2dmrdunVKTU0d1edJ8v28FixY4Pi5GvGX4Q0NDUpP\nT5ckzZw5U11dXXr58uVIj4F3iIyMVHl5uRISEvq2NTY2asWKFZKktLQ0NTQ0ODVeQHw9p9Fu3rx5\nOn78uCQpOjpa3d3do/48Sb6fV29vr8NTORDLjo4OTZ48ue/n2NhYtbe3j/QYIXHv3j3l5uYqIyND\n9fX1To8TsIiICI0fP77ftu7u7r6Xc3FxcaPunPl6TpJUUVGhnJwcfffdd/rrr78cmCxw48aNU1RU\nlCSpsrJSS5YsGfXnSfL9vMaNG+f4uXLkPcv/5fV6nR4hKKZPn67t27fL7XartbVVOTk5qq2tHZXv\nFw1lrJyztWvXKiYmRikpKTp9+rROnjyp/fv3Oz2W3+rq6lRZWamzZ89q5cqVfdtH+3n63+fV1NTk\n+Lka8SvLhIQEdXR09P389OlTxcfHj/QYQZeYmKjVq1fL5XJp2rRpmjJlitra2pweK2iioqL0+vVr\nSVJbW9uYeDmbmpqqlJQUSdLy5cvV3Nzs8ET+u3r1qsrKylReXq5JkyaNmfP09vMKh3M14rFctGiR\nampqJEl37txRQkKCJk6cONJjBN2lS5d05swZSVJ7e7uePXumxMREh6cKnoULF/adt9raWi1evNjh\niYZvx44dam1tlfTve7L//5cMo8WLFy9UUlKiU6dO9d0lHgvnydfzCodz5fI6cK1+9OhR3bx5Uy6X\nSwcOHNCnn3460iME3cuXL7V79249f/5cPT092r59u5YuXer0WAFpampScXGxHj16pIiICCUmJuro\n0aPKz8/X33//ralTp6qoqEgffPCB06Oa+XpOWVlZOn36tCZMmKCoqCgVFRUpLi7O6VHNPB6PTpw4\noRkzZvRtO3z4sPbt2zdqz5Pk+3mtX79eFRUVjp4rR2IJAKMNK3gAwIBYAoABsQQAA2IJAAbEEgAM\niCUAGBBLADAglgBg8B9OkjtgR8VvdgAAAABJRU5ErkJggg==\n",
+ "text/plain": [
+ "<matplotlib.figure.Figure at 0x7fd619a40b00>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ }
+ },
+ {
+ "output_type": "stream",
+ "text": [
+ "Prediction: 6\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4SJizeJtNaAs",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "# Profiling\n",
+ "\n",
+ "If you want to drill down into the performance characteristics of your code, you can use native Python profilers like [`cProfile`](https://docs.python.org/3/library/profile.html). In the next exercise, you'll do just that."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "_2v0QnG8__PJ",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "## Exercise!\n",
+ "\n",
+ "This exercise does not require coding. If you have not completed the training exercise, replace `train_one_epoch` below with `_train_one_epoch`.\n",
+ "\n",
+ "Run the below cell and inspect the printed profiles. What parts of the code appear to be hotspots or\n",
+ "bottlenecks? How does sorting the profile by total time compare to sorting it\n",
+ "by cumulative time?\n",
+ "\n"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "IFypaYbG_9fB",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 714
+ },
+ "outputId": "d9c3596b-a165-4edd-fc6b-53ccd0d01d19"
+ },
+ "cell_type": "code",
+ "source": [
+ "import cProfile\n",
+ "import pstats\n",
+ "\n",
+ "cProfile.run(\"train_one_epoch(model, training_data, optimizer)\", \"training_profile\")\n",
+ "\n",
+ "stats = pstats.Stats(\"training_profile\").strip_dirs().sort_stats(\"tottime\")\n",
+ "stats.print_stats(10)\n",
+ "\n",
+ "stats.sort_stats(\"cumtime\").print_stats(10)"
+ ],
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Thu Jun 7 12:25:04 2018 training_profile\n",
+ "\n",
+ " 92209 function calls (91817 primitive calls) in 3.446 seconds\n",
+ "\n",
+ " Ordered by: internal time\n",
+ " List reduced from 672 to 10 due to restriction <10>\n",
+ "\n",
+ " ncalls tottime percall cumtime percall filename:lineno(function)\n",
+ " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n",
+ " 83 0.753 0.009 0.753 0.009 {built-in method _pywrap_tensorflow_internal.TFE_Py_Execute}\n",
+ " 16 0.006 0.000 1.019 0.064 network.py:736(_run_internal_graph)\n",
+ " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n",
+ " 2321 0.004 0.000 0.007 0.000 abc.py:178(__instancecheck__)\n",
+ " 288 0.004 0.000 0.009 0.000 inspect.py:2092(_signature_from_function)\n",
+ " 878 0.004 0.000 0.005 0.000 ops.py:5936(__enter__)\n",
+ " 288 0.004 0.000 0.016 0.000 inspect.py:1079(getfullargspec)\n",
+ " 11006 0.003 0.000 0.005 0.000 {built-in method builtins.isinstance}\n",
+ " 768 0.003 0.000 0.008 0.000 {built-in method _pywrap_tensorflow_internal.Flatten}\n",
+ "\n",
+ "\n",
+ "Thu Jun 7 12:25:04 2018 training_profile\n",
+ "\n",
+ " 92209 function calls (91817 primitive calls) in 3.446 seconds\n",
+ "\n",
+ " Ordered by: cumulative time\n",
+ " List reduced from 672 to 10 due to restriction <10>\n",
+ "\n",
+ " ncalls tottime percall cumtime percall filename:lineno(function)\n",
+ " 1 0.000 0.000 3.446 3.446 {built-in method builtins.exec}\n",
+ " 1 0.000 0.000 3.446 3.446 <string>:1(<module>)\n",
+ " 1 0.001 0.001 3.446 3.446 <ipython-input-14-bcffed60b545>:9(train_one_epoch)\n",
+ " 1080 2.552 0.002 2.552 0.002 {built-in method _pywrap_tensorflow_internal.TFE_Py_FastPathExecute}\n",
+ " 16 0.000 0.000 2.255 0.141 backprop.py:739(gradient)\n",
+ " 16 0.000 0.000 2.253 0.141 imperative_grad.py:31(imperative_grad)\n",
+ " 16 0.005 0.000 2.253 0.141 {built-in method _pywrap_tensorflow_internal.TFE_Py_TapeGradient}\n",
+ " 400 0.002 0.000 2.246 0.006 backprop.py:145(grad_fn)\n",
+ " 400 0.002 0.000 2.239 0.006 backprop.py:95(_magic_gradient_function)\n",
+ " 32 0.001 0.000 1.601 0.050 nn_grad.py:497(_Conv2DGrad)\n",
+ "\n",
+ "\n"
+ ],
+ "name": "stdout"
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<pstats.Stats at 0x7fd61f841710>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 17
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "8ixpnyCNNTI4",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ ""
+ ],
+ "execution_count": 0,
+ "outputs": []
+ }
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb
new file mode 100644
index 0000000000..64d19ec5c9
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/workshop/3_inspecting.ipynb
@@ -0,0 +1,443 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "Debugging \"graph-first\" models with eager execution",
+ "version": "0.3.2",
+ "provenance": [],
+ "include_colab_link": true
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "[View in Colaboratory](https://colab.research.google.com/gist/alextp/9568ab40f6ed6f9a3ba4736f6aef6127/debugging-graph-first-models-with-eager-execution.ipynb)"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "mm-t0GuIu1Dt",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This colab uses eager execution and the Python debugger to modify the execution of a translation model. This combination lets you quickly explore counterfactuals when researching and designing modifications to a model.\n",
+ "\n",
+ "The model, Transformer from [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), was originally written with graph building in mind. Executing it eagerly can still be helpful!"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "gxb1DvIDg4sv",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title License (double click to show)\n",
+ "# Copyright 2018 The TensorFlow Authors.\n",
+ "\n",
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "Gx3HA9N1ui64",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "f6986f34-f3e1-44e1-c902-2eb33081acad"
+ },
+ "cell_type": "code",
+ "source": [
+ "import tensorflow as tf\n",
+ "import pdb\n",
+ "tfe = tf.contrib.eager\n",
+ "\n",
+ "tf.enable_eager_execution()"
+ ],
+ "execution_count": 1,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "3LkOm2ct-Lmc",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 37
+ },
+ "outputId": "2edc74d9-6bc0-4e78-ab4e-83bf96099ef4"
+ },
+ "cell_type": "code",
+ "source": [
+ "!pip install -q -U tensor2tensor\n",
+ "from tensor2tensor.models import transformer"
+ ],
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "1Z3oMsqV0zB6",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 170
+ },
+ "outputId": "0a8186ee-c688-457f-c9f6-9a6c1477a93b"
+ },
+ "cell_type": "code",
+ "source": [
+ "#@title Create a tensor2tensor translation model, fetch a checkpoint (double click to show)\n",
+ "from tensor2tensor import problems\n",
+ "from tensor2tensor.utils import trainer_lib\n",
+ "from tensor2tensor.utils import registry\n",
+ "\n",
+ "import numpy as np\n",
+ "import os\n",
+ "\n",
+ "# Setup some directories\n",
+ "data_dir = os.path.expanduser(\"~/t2t/data\")\n",
+ "tmp_dir = os.path.expanduser(\"~/t2t/tmp\")\n",
+ "train_dir = os.path.expanduser(\"~/t2t/train\")\n",
+ "checkpoint_dir = os.path.expanduser(\"~/t2t/checkpoints\")\n",
+ "tf.gfile.MakeDirs(data_dir)\n",
+ "tf.gfile.MakeDirs(tmp_dir)\n",
+ "tf.gfile.MakeDirs(train_dir)\n",
+ "tf.gfile.MakeDirs(checkpoint_dir)\n",
+ "gs_data_dir = \"gs://tensor2tensor-data\"\n",
+ "gs_ckpt_dir = \"gs://tensor2tensor-checkpoints/\"\n",
+ "\n",
+ "# Fetch the problem\n",
+ "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n",
+ "\n",
+ "# Copy the vocab file locally so we can encode inputs and decode model outputs\n",
+ "# All vocabs are stored on GCS\n",
+ "vocab_name = \"vocab.ende.32768\"\n",
+ "vocab_file = os.path.join(gs_data_dir, vocab_name)\n",
+ "!gsutil cp {vocab_file} {data_dir}\n",
+ "\n",
+ "# Get the encoders from the problem\n",
+ "encoders = ende_problem.feature_encoders(data_dir)\n",
+ "\n",
+ "# Setup helper functions for encoding and decoding\n",
+ "def encode(input_str, output_str=None):\n",
+ " \"\"\"Input str to features dict, ready for inference\"\"\"\n",
+ " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n",
+ " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n",
+ " return {\"inputs\": batch_inputs}\n",
+ "\n",
+ "def decode(integers):\n",
+ " \"\"\"List of ints to str\"\"\"\n",
+ " integers = list(np.squeeze(integers))\n",
+ " if 1 in integers:\n",
+ " integers = integers[:integers.index(1)]\n",
+ " return encoders[\"inputs\"].decode(np.squeeze(integers))\n",
+ "\n",
+ "# Copy the pretrained checkpoint locally\n",
+ "ckpt_name = \"transformer_ende_test\"\n",
+ "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n",
+ "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n",
+ "checkpoint_path = tf.train.latest_checkpoint(\n",
+ " os.path.join(checkpoint_dir, ckpt_name))\n",
+ "\n",
+ "# Create hparams and the model\n",
+ "model_name = \"transformer\"\n",
+ "hparams_set = \"transformer_base\"\n",
+ "\n",
+ "hparams = trainer_lib.create_hparams(hparams_set, data_dir=data_dir, problem_name=\"translate_ende_wmt32k\")\n",
+ "\n",
+ "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n",
+ "# Layer and so subsequent instantiations will have different variable scopes\n",
+ "# that will not match the checkpoint.\n",
+ "translate_model = registry.model(model_name)(hparams, tf.estimator.ModeKeys.EVAL)"
+ ],
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "Copying gs://tensor2tensor-data/vocab.ende.32768...\n",
+ "/ [1 files][316.4 KiB/316.4 KiB] \n",
+ "Operation completed over 1 objects/316.4 KiB. \n",
+ "INFO:tensorflow:Setting T2TModel mode to 'eval'\n",
+ "INFO:tensorflow:Setting hparams.layer_prepostprocess_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.symbol_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.attention_dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.dropout to 0.0\n",
+ "INFO:tensorflow:Setting hparams.relu_dropout to 0.0\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "4IblPXLGjuCl",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We've created a Transformer model and fetched an existing training checkpoint. It hasn't created variables yet, and we want to load them from the checkpoint before they're used (restore-on-create) so the first run of the model outputs the correct value. The `tfe.restore_variables_on_create` API looks up variables by name on creation and restores their values."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "o3MWxcAqJoqG",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 51
+ },
+ "outputId": "fbc1b1bf-ffbe-4621-b3cb-5eb855fec3a8"
+ },
+ "cell_type": "code",
+ "source": [
+ "with tfe.restore_variables_on_create(checkpoint_path):\n",
+ " model_output = translate_model.infer(encode(\"Eager execution\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "INFO:tensorflow:Greedy Decoding\n",
+ "Hinrichtung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "xk5HV9Hhu9zO",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Using global variable names can get somewhat fragile, so for new code we recommend the object-based `tf.keras.Model.save_weights` or `tf.train.Checkpoint`. However, these require some small code changes to work with existing graph building code.\n",
+ "\n",
+ "The Transformer model translates \"Eager execution\" in English to \"Hinrichtung\" in German, which refers to capital punishment rather than getting things done. Transformer first encodes the English, then decodes to German. We'll add a debugging hook at the start of the decode phase (once the encodings have been finalized) and see if we can correct the translation."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "GUGwbYvXZ9-7",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "cell_type": "code",
+ "source": [
+ "previous_fast_decode = transformer.fast_decode\n",
+ "def debug_fn(*args, **kwargs):\n",
+ " pdb.set_trace()\n",
+ " return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "transformer.fast_decode = debug_fn # Add our debugging hook to Transformer"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "metadata": {
+ "id": "f61HlvECxJn0",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Now that we've \"monkey patched\" the model, we'll drop into a debugger just before decoding starts. In most cases it'd be simpler to add the `pdb.set_trace()` call to the code directly, but in this case we're working with prepackaged library code.\n",
+ "\n",
+ "First, let's find an encoding which represents the correct sense of \"execution\". Then we'll patch part of that encoding into the encoding of \"Eager execution\" to fix the translation. Feel free to poke around with the debugger (e.g. print a Tensor's value), but your main task is to save the encodings by assigning them to an attribute of the function:\n",
+ "\n",
+ "```\n",
+ "(running the next cell drops you into a pdb shell)\n",
+ "step\n",
+ "fast_decode.previous_encoding = encoder_output\n",
+ "continue\n",
+ "\n",
+ "```\n",
+ "\n",
+ "You can type `next` (or `n`) a few times before `continue` to watch the decoding ops run."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "dX4CPOGSpZrb",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "6de38c31-836f-40ef-b701-e42908172619"
+ },
+ "cell_type": "code",
+ "source": [
+ "model_output = translate_model.infer(encode(\"Immediate running\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "> <ipython-input-6-ee9b4225ba2a>(4)debug_fn()\n",
+ "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "(Pdb) step\n",
+ "--Call--\n",
+ "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n",
+ "-> def fast_decode(encoder_output,\n",
+ "(Pdb) fast_decode.previous_encoding = encoder_output\n",
+ "(Pdb) continue\n",
+ "Sofortige Durchführung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "-ZEZciV4FpLo",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Now we have an encoding saved which gets the correct sense for \"execution\"."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "QeC_oDVqHD_v",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "253c9af1-003e-46bd-8bf5-db968cf6a8cf"
+ },
+ "cell_type": "code",
+ "source": [
+ "# Assumes you followed the pdb instructions above!\n",
+ "transformer.fast_decode.previous_encoding"
+ ],
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "<tf.Tensor: id=9528, shape=(1, 4, 512), dtype=float32, numpy=\n",
+ "array([[[-0.15239455, 0.12273102, -0.11209048, ..., -0.12478986,\n",
+ " 0.37216735, -0.40987235],\n",
+ " [-0.2686283 , 0.51448774, 0.03650613, ..., 0.08731575,\n",
+ " 0.51110077, -0.6646815 ],\n",
+ " [-0.24441548, 0.36622533, 0.11685672, ..., 0.21941349,\n",
+ " -0.03304008, -0.579611 ],\n",
+ " [-0.03339856, -0.01185844, 0.00579634, ..., 0.00294734,\n",
+ " 0.00136655, -0.01362935]]], dtype=float32)>"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "execution_count": 8
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "bC9JjeDcHEav",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "Let's replace part of the encoding for \"Eager execution\" with the encoding of \"Immediate running\".\n",
+ "\n",
+ "Again we'll drop into a pdb shell. This time we'll run some TensorFlow operations to patch the encodings while the model is running.\n",
+ "\n",
+ "```\n",
+ "(running the next cell again drops you into a pdb shell)\n",
+ "step\n",
+ "encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n",
+ "continue\n",
+ "```"
+ ]
+ },
+ {
+ "metadata": {
+ "id": "t2as_Kn1h65G",
+ "colab_type": "code",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 179
+ },
+ "outputId": "5b4e546e-3bb4-4761-c545-467b631e3ffe"
+ },
+ "cell_type": "code",
+ "source": [
+ "model_output = translate_model.infer(encode(\"Eager execution\"))\n",
+ "print(decode(model_output[\"outputs\"]))"
+ ],
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ "> <ipython-input-6-ee9b4225ba2a>(4)debug_fn()\n",
+ "-> return previous_fast_decode(*args, **kwargs) # \"step\" in pdb to step in\n",
+ "(Pdb) step\n",
+ "--Call--\n",
+ "> /usr/local/lib/python2.7/dist-packages/tensor2tensor/models/transformer.py(427)fast_decode()\n",
+ "-> def fast_decode(encoder_output,\n",
+ "(Pdb) encoder_output = tf.concat([fast_decode.previous_encoding[:, :3], encoder_output[:, 3:]], axis=1)\n",
+ "(Pdb) continue\n",
+ "sofortige Ausführung\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "metadata": {
+ "id": "rK6tYZ23I2cm",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "We get a different decoding, with the correct sense of \"execution\". Likely we're keeping just the encoding of \"tion\" from \"Eager execution\", so no great breakthrough in translation modeling.\n",
+ "\n",
+ "Similarly it's possible to modify attention vectors, or change words during decoding to help debug a beam search."
+ ]
+ },
+ {
+ "metadata": {
+ "id": "Nb-4ipYNRWxA",
+ "colab_type": "text"
+ },
+ "cell_type": "markdown",
+ "source": [
+ "This colab was adapted from the [Tensor2Tensor colab](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb). Credit to Ankur Taly for its concept."
+ ]
+ }
+ ]
+} \ No newline at end of file
diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py
index 3e31004273..04b7b1165e 100644
--- a/tensorflow/contrib/eager/python/metrics.py
+++ b/tensorflow/contrib/eager/python/metrics.py
@@ -22,5 +22,6 @@ from __future__ import print_function
from tensorflow.contrib.eager.python.metrics_impl import *
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['Accuracy', 'Mean', 'Metric']
+_allowed_symbols = ['Accuracy', 'Mean', 'Metric', 'CategoricalAccuracy',
+ 'BinaryAccuracy', 'SparseAccuracy']
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index c947ed9dcc..efa6ba0626 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -345,9 +345,14 @@ class Mean(Metric):
class Accuracy(Mean):
- """Calculates how often `predictions` matches `labels`."""
+ """Calculates how often `predictions` matches `labels`.
+ Attributes:
+ name: name of the accuracy object
+ dtype: data type of the tensor
+ """
def __init__(self, name=None, dtype=dtypes.float64):
+ """Inits Accuracy class with name and dtype."""
super(Accuracy, self).__init__(name=name, dtype=dtype)
def call(self, labels, predictions, weights=None):
@@ -377,3 +382,146 @@ class Accuracy(Mean):
if weights is None:
return labels, predictions
return labels, predictions, weights
+
+
+class CategoricalAccuracy(Mean):
+ """Calculates how often `predictions` matches `labels`.
+
+ This class is compatible with `tf.keras.losses.categorical_crossentropy`,
+ `tf.nn.softmax_cross_entropy_with_logits_v2`,
+ `tf.losses.softmax_cross_entropy`.
+
+ Attributes:
+ name: name of the accuracy object.
+ dtype: data type of tensor.
+ """
+
+ def __init__(self, name=None, dtype=dtypes.float64):
+ """Inits CategoricalAccuracy with name and dtype."""
+ super(CategoricalAccuracy, self).__init__(name=name, dtype=dtype)
+
+ def call(self, labels, predictions, weights=None):
+ """Accumulate accuracy statistics.
+
+ `labels` and `predictions` should have the same shape.
+ As argmax is being done here, labels and predictions type
+ can be different.
+
+ Args:
+ labels: One-hot Tensor.
+ predictions: Tensor with the logits or probabilities for each example.
+ weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
+ """
+ check_ops.assert_equal(
+ array_ops.shape(labels), array_ops.shape(predictions),
+ message="Shapes of labels and predictions are unequal")
+ labels = math_ops.argmax(labels, axis=-1)
+ predictions = math_ops.argmax(predictions, axis=-1)
+ matches = math_ops.equal(labels, predictions)
+ matches = math_ops.cast(matches, dtypes.float64)
+ super(CategoricalAccuracy, self).call(matches, weights=weights)
+ if weights is None:
+ return labels, predictions
+ return labels, predictions, weights
+
+
+class BinaryAccuracy(Mean):
+ """Calculates how often `predictions` matches `labels`.
+
+ This class is compatible with `tf.keras.losses.binary_crossentropy`,
+ `tf.losses.sigmoid_cross_entropy`,
+ `tf.nn.sigmoid_cross_entropy_with_logits`.
+ If there is more than one label, this will become multi-label classification.
+
+ Attributes:
+ name: name of the accuracy object.
+ threshold: Used for rounding off the predictions.
+ If the predictions are,
+ 1. probabilities then set the threshold to 0.5.
+ 2. logits then set the threshold to 0.
+ You can set the threshold appropriately,
+ to trade off with precision and recall.
+ dtype: data type of tensor.
+ """
+
+ def __init__(self, threshold, name=None, dtype=dtypes.float64):
+ """Inits BinaryAccuracy with name, threshold and dtype."""
+
+ super(BinaryAccuracy, self).__init__(name=name, dtype=dtype)
+ self.threshold = threshold
+
+ def call(self, labels, predictions, weights=None):
+ """Accumulate accuracy statistics.
+
+ `labels` and `predictions` should have the same shape and type.
+
+ Args:
+ labels: Binary Tensor(containing 0 or 1).
+ predictions: Tensor with probabilities or logits.
+ weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
+ """
+ check_ops.assert_equal(
+ array_ops.shape(labels), array_ops.shape(predictions),
+ message="Shapes of labels and predictions are unequal")
+ predictions = ops.convert_to_tensor(predictions)
+ predictions = predictions > self.threshold
+ matches = math_ops.equal(labels, predictions)
+ matches = math_ops.cast(matches, dtypes.float64)
+ super(BinaryAccuracy, self).call(matches, weights=weights)
+ if weights is None:
+ return labels, predictions
+ return labels, predictions, weights
+
+
+class SparseAccuracy(Mean):
+ """Calculates how often `predictions` matches `labels`.
+
+ This class is compatible with
+ `tf.keras.losses.sparse_categorical_crossentropy`,
+ `tf.nn.sparse_softmax_cross_entropy_with_logits`,
+ `tf.losses.sparse_softmax_cross_entropy`.
+
+ Attributes:
+ name: name of the accuracy object
+ dtype: data type of tensor.
+ """
+
+ def __init__(self, name=None, dtype=dtypes.float64):
+ """Inits SparseAccuracy with name and dtype."""
+
+ super(SparseAccuracy, self).__init__(name=name, dtype=dtype)
+
+ def call(self, labels, predictions, weights=None):
+ """Accumulate accuracy statistics.
+
+ `labels` and `predictions` should have the same shape except the
+ predictions must have one additional trailing dimension equal to the
+ number of classes(you want to predict).
+
+ Type of labels and predictions can be different.
+
+ Args:
+ labels: Tensor of shape (batch_size, ) containing integers
+ predictions: Tensor with the logits or probabilities for each example.
+ weights: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ The arguments, for easy chaining.
+ """
+ check_ops.assert_equal(
+ array_ops.shape(labels), array_ops.shape(predictions)[0],
+ message="First axis of labels and predictions is unequal")
+ predictions = math_ops.argmax(predictions, axis=-1)
+ labels = math_ops.cast(labels, dtypes.int64)
+ matches = math_ops.equal(labels, predictions)
+ matches = math_ops.cast(matches, dtypes.float64)
+ super(SparseAccuracy, self).call(matches, weights=weights)
+ if weights is None:
+ return labels, predictions
+ return labels, predictions, weights
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 02ee054875..20d938d492 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -118,6 +118,39 @@ class MetricsTest(test.TestCase):
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
+ def testCategoricalAccuracy(self):
+ m = metrics.CategoricalAccuracy()
+ m([[1, 0, 0, 0], [0, 1, 0, 0]],
+ [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 1/2 correct
+ m([[0, 0, 0, 1]], [[0.25, 0.95, 0.25, 0.0]]) # 0/1 correct
+ m([[1, 0, 0, 0], [0, 1, 0, 0]],
+ [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct
+ self.assertEqual(2.0/5, m.result().numpy())
+ self.assertEqual(dtypes.float64, m.dtype)
+ self.assertEqual(dtypes.float64, m.result().dtype)
+
+ def testBinaryAccuracy(self):
+ m = metrics.BinaryAccuracy(threshold=0)
+ # as threshold is 0 hence the predictions are logits
+ m([[0, 0, 0, 0]],
+ [[-4.2, 4.5, 1.2, -1.1]]) # 2/4 correct
+ m([[0, 1]], [[-5.3, 11.65]]) # 2/2 correct
+ m([[0, 1], [1, 1]],
+ [[-5.3, 11.65], [-10.32, 56.38]]) # 3/4 correct
+ self.assertEqual(7.0/10, m.result().numpy())
+ self.assertEqual(dtypes.float64, m.dtype)
+ self.assertEqual(dtypes.float64, m.result().dtype)
+
+ def testSparseAccuracy(self):
+ m = metrics.SparseAccuracy()
+ m([0, 2],
+ [[0.6, 0.1, 0.25, 0.05], [0.4, 0.05, 0.45, 0.0]]) # 2/2 correct
+ m([1], [[0.25, 0.95, 0.25, 0.0]]) # 1/1 correct
+ m([0, 3], [[0.99, 0.01, 0.0, 0.0], [0.35, 0.35, 0.3, 0.0]]) # 1/2 correct
+ self.assertEqual(4.0/5, m.result().numpy())
+ self.assertEqual(dtypes.float64, m.dtype)
+ self.assertEqual(dtypes.float64, m.result().dtype)
+
def testAccuracyDifferentShapes(self):
m = metrics.Accuracy()
with self.assertRaises(errors.InvalidArgumentError):
@@ -173,7 +206,7 @@ class MetricsTest(test.TestCase):
sess.run(accumulate, feed_dict={p: 7})
self.assertAllEqual(m.result().eval(), 7)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGraphAndEagerTensor(self):
m = metrics.Mean()
inputs = ops.convert_to_tensor([1.0, 2.0])
@@ -221,7 +254,7 @@ class MetricsTest(test.TestCase):
self.assertAllEqual(m2.result().eval(), 2.0)
self.assertAllEqual(m1.result().eval(), 1.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveRestore(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index c92bd15b25..240f213c60 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -126,7 +126,7 @@ class NetworkTest(test.TestCase):
self.assertAllEqual([[17.0], [34.0]], self.evaluate(result))
# TODO(allenl): This test creates garbage in some Python versions
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNetworkSaveRestoreAlreadyBuilt(self):
net = MyNetwork(name="abcd")
with self.assertRaisesRegexp(
@@ -138,7 +138,7 @@ class NetworkTest(test.TestCase):
self._save_modify_load_network_built(net, global_step=10)
# TODO(allenl): This test creates garbage in some Python versions
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveRestoreDefaultGlobalStep(self):
net = MyNetwork(name="abcd")
net(constant_op.constant([[2.0]]))
@@ -149,7 +149,7 @@ class NetworkTest(test.TestCase):
self.assertIn("abcd-4242", save_path)
# TODO(allenl): This test creates garbage in some Python versions
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNetworkSaveAndRestoreIntoUnbuilt(self):
save_dir = self.get_temp_dir()
net1 = MyNetwork()
@@ -166,7 +166,7 @@ class NetworkTest(test.TestCase):
self.assertAllEqual(self.evaluate(net1.variables[0]),
self.evaluate(net2.variables[0]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNetworkMatchesLayerVariableNames(self):
zero = constant_op.constant([[0.]])
layer_one = core.Dense(1, use_bias=False)
@@ -193,7 +193,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("two_layer_net/" + layer_two.variables[0].name,
net.second.variables[0].name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLoadIntoUnbuiltSharedLayer(self):
class Owner(network.Network):
@@ -272,7 +272,7 @@ class NetworkTest(test.TestCase):
network.restore_network_checkpoint(
load_into, save_path, map_func=_restore_map_func)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRestoreIntoSubNetwork(self):
class Parent(network.Network):
@@ -327,7 +327,7 @@ class NetworkTest(test.TestCase):
# The checkpoint is incompatible.
network.restore_network_checkpoint(save_into_parent, checkpoint)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCustomMapCollisionErrors(self):
class Parent(network.Network):
@@ -372,7 +372,7 @@ class NetworkTest(test.TestCase):
network.restore_network_checkpoint(
loader, checkpoint, map_func=lambda n: "foo")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDefaultMapCollisionErrors(self):
one = constant_op.constant([[1.]])
@@ -571,7 +571,7 @@ class NetworkTest(test.TestCase):
expected_start="my_network_1/dense/",
actual=outside_net_after.trainable_weights[0].name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVariableScopeStripping(self):
with variable_scope.variable_scope("scope1"):
with variable_scope.variable_scope("scope2"):
@@ -596,7 +596,7 @@ class NetworkTest(test.TestCase):
self.assertAllEqual([[42.]],
self.evaluate(restore_net.variables[0]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayerNamesRespected(self):
class ParentNetwork(network.Network):
@@ -677,7 +677,7 @@ class NetworkTest(test.TestCase):
self.assertStartsWith(expected_start="my_network_1/dense/",
actual=net2.trainable_weights[0].name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestableAnonymous(self):
# The case where no explicit names are specified. We make up unique names,
@@ -721,7 +721,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("my_network", net2.first.name)
self.assertEqual("my_network_1", net2.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestableExplicit(self):
# We have explicit network names and everything is globally unique.
@@ -750,7 +750,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("first_unique_child_name", net.first.name)
self.assertEqual("second_unique_child_name", net.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayerNetworkNameInteractions(self):
# Same base name as core.Dense; Networks and non-Network Layers with the
@@ -801,7 +801,7 @@ class NetworkTest(test.TestCase):
actual=net.trainable_weights[4].name)
self.assertEqual("mixed_layer_network", net.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestableExplicitCollisions(self):
# We have explicit network names and they are unique within the layer
@@ -831,7 +831,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("nonunique_name", net.first.name)
self.assertEqual("second_unique_child_name", net.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestableExplicitWithAnonymousParent(self):
# A parent network is instantiated multiple times with explicitly named
@@ -873,7 +873,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("first_unique_child_name", net2.first.name)
self.assertEqual("second_unique_child_name", net2.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestableExplicitSameLayerCollisions(self):
# We have explicit network names and they are _not_ unique within the layer
@@ -891,7 +891,7 @@ class NetworkTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "nonunique_name"):
ParentNetwork()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAnonymousVariableSharing(self):
# Two "owned" Networks
@@ -989,7 +989,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("my_network", net4.first.name)
self.assertEqual("my_network", net4.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRecursiveLayerRenaming(self):
core.Dense(1) # Under default Layer naming, would change subsequent names.
@@ -1041,7 +1041,7 @@ class NetworkTest(test.TestCase):
self.assertEqual("dense", net.second.first.name)
self.assertEqual("dense_1", net.second.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCallInDifferentOrderThanConstruct(self):
shared_network = MyNetwork()
@@ -1091,7 +1091,7 @@ class NetworkTest(test.TestCase):
self.assertTrue(net2.first is net1.first)
self.assertEqual("my_network", net2.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayerCallInDifferentOrderThanConstruct(self):
# Same idea as testCallInDifferentOrderThanConstruct, but this time with a
# non-Network Layer shared between two Networks rather than a
@@ -1144,7 +1144,7 @@ class NetworkTest(test.TestCase):
self.assertTrue(net2.first is net1.first)
self.assertEqual("dense", net2.second.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayerAlreadyBuilt(self):
one = constant_op.constant([[1.]])
core.Dense(1, use_bias=False) # pre-built layers use global naming
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 5826700c73..ca6430253b 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -68,6 +68,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@async_clear_error
@@run_test_in_graph_and_eager_modes
+@@run_all_tests_in_graph_and_eager_modes
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@ -115,12 +116,13 @@ from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
+from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes
from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.ops.variable_scope import EagerVariableStore
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import template
-from tensorflow.python.training.checkpointable.base import Checkpointable
+from tensorflow.python.training.checkpointable.tracking import Checkpointable
from tensorflow.python.training.checkpointable.util import CheckpointableSaver
from tensorflow.python.training.checkpointable.util import Checkpoint
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 47c7b7fc19..11d40f5982 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":early_stopping",
":export",
":extenders",
":head",
@@ -117,7 +118,7 @@ py_library(
py_test(
name = "dnn_test",
- size = "small",
+ size = "medium",
srcs = ["python/estimator/dnn_test.py"],
srcs_version = "PY2AND3",
tags = [
@@ -312,6 +313,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:variables",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
@@ -589,3 +591,31 @@ py_test(
"@six_archive//:six",
],
)
+
+py_library(
+ name = "early_stopping",
+ srcs = ["python/estimator/early_stopping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python/estimator",
+ ],
+)
+
+py_test(
+ name = "early_stopping_test",
+ srcs = ["python/estimator/early_stopping_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":early_stopping",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/estimator",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 788ac5ca70..09fcfd66a1 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
+from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
@@ -63,6 +64,12 @@ _allowed_symbols = [
'RNNEstimator',
'export_saved_model_for_mode',
'export_all_saved_models',
+ 'make_early_stopping_hook',
+ 'read_eval_metrics',
+ 'stop_if_lower_hook',
+ 'stop_if_higher_hook',
+ 'stop_if_no_increase_hook',
+ 'stop_if_no_decrease_hook',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index d0e3e670f7..505c94e971 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -113,6 +113,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -141,6 +143,8 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -166,7 +170,9 @@ class BaselineEstimatorEvaluationTest(test.TestCase):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is bias which is [46, 58]
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index bd641014e9..43bfcffd79 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -49,7 +49,8 @@ class _BoostedTreesEstimator(estimator.Estimator):
l2_regularization=0.,
tree_complexity=0.,
min_node_weight=0.,
- config=None):
+ config=None,
+ center_bias=False):
"""Initializes a `BoostedTreesEstimator` instance.
Args:
@@ -82,17 +83,30 @@ class _BoostedTreesEstimator(estimator.Estimator):
considered. The value will be compared with sum(leaf_hessian)/
(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
- features, labels, mode, head, feature_columns, tree_hparams,
- n_batches_per_layer, config)
+ features,
+ labels,
+ mode,
+ head,
+ feature_columns,
+ tree_hparams,
+ n_batches_per_layer,
+ config=config)
super(_BoostedTreesEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -114,7 +128,8 @@ def boosted_trees_classifier_train_in_memory(
tree_complexity=0.,
min_node_weight=0.,
config=None,
- train_hooks=None):
+ train_hooks=None,
+ center_bias=False):
"""Trains a boosted tree classifier with in memory dataset.
Example:
@@ -186,7 +201,13 @@ def boosted_trees_classifier_train_in_memory(
considered. The value will be compared with sum(leaf_hessian)/
(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
- train_hooks: a list of Hook instances to be passed to estimator.train().
+ train_hooks: a list of Hook instances to be passed to estimator.train()
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -207,7 +228,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -247,7 +268,8 @@ def boosted_trees_regressor_train_in_memory(
tree_complexity=0.,
min_node_weight=0.,
config=None,
- train_hooks=None):
+ train_hooks=None,
+ center_bias=False):
"""Trains a boosted tree regressor with in memory dataset.
Example:
@@ -313,6 +335,12 @@ def boosted_trees_regressor_train_in_memory(
(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
train_hooks: a list of Hook instances to be passed to estimator.train().
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -332,7 +360,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index 76cbefe5e9..999c2aa5e2 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -115,6 +115,27 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['average_loss'], 1.008551)
+ def testTrainAndEvaluateEstimatorWithCenterBias(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ center_bias=True)
+
+ # It will stop after 11 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ # 10 steps for training and 2 step for bias centering.
+ self._assert_checkpoint(
+ est.model_dir, global_step=12, finalized_trees=2, attempted_layers=10)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 0.614642)
+
def testInferEstimator(self):
train_input_fn = _make_train_input_fn(is_classification=False)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -139,6 +160,33 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
[[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
[pred['predictions'] for pred in predictions])
+ def testInferEstimatorWithCenterBias(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True,
+ head=self._head)
+
+ # It will stop after 6 steps because of the max depth and num trees (5 for
+ # training and 2 for bias centering).
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(train_input_fn, steps=num_steps)
+ self._assert_checkpoint(
+ est.model_dir, global_step=7, finalized_trees=1, attempted_layers=5)
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+
+ self.assertAllClose(
+ [[1.634501], [1.325703], [1.187431], [2.019683], [2.832683]],
+ [pred['predictions'] for pred in predictions])
+
def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -159,14 +207,40 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
+ def testBinaryClassifierTrainInMemoryAndEvalAndInferWithCenterBias(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+ # It will stop after 5 steps + 3 for bias, because of the max depth and num
+ # trees.
+ self._assert_checkpoint(
+ est.model_dir, global_step=8, finalized_trees=1, attempted_layers=5)
+
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testBinaryClassifierTrainInMemoryWithDataset(self):
train_input_fn = _make_train_input_fn_dataset(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
est = boosted_trees.boosted_trees_classifier_train_in_memory(
- train_input_fn=train_input_fn, feature_columns=self._feature_columns,
- n_trees=1, max_depth=5)
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5)
# It will stop after 5 steps because of the max depth and num trees.
self._assert_checkpoint(
est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py
index 7ff25b95c0..9efa8f474d 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn.py
@@ -53,6 +53,25 @@ class DNNEstimator(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ optimizer=lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
+ # Or estimator with warm-starting from a previous checkpoint.
+ estimator = DNNEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ warm_start_from="/path/to/checkpoint/dir")
+
# Input builders
def input_fn_train: # returns x, y
pass
@@ -92,7 +111,9 @@ class DNNEstimator(estimator.Estimator):
activation_fn=nn.relu,
dropout=None,
input_layer_partitioner=None,
- config=None):
+ config=None,
+ warm_start_from=None,
+ batch_norm=False):
"""Initializes a `DNNEstimator` instance.
Args:
@@ -107,8 +128,9 @@ class DNNEstimator(estimator.Estimator):
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
@@ -116,6 +138,12 @@ class DNNEstimator(estimator.Estimator):
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or
+ a `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ batch_norm: Whether to use batch normalization after each hidden layer.
"""
def _model_fn(features, labels, mode, config):
return dnn_lib._dnn_model_fn( # pylint: disable=protected-access
@@ -129,6 +157,8 @@ class DNNEstimator(estimator.Estimator):
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm)
super(DNNEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn, model_dir=model_dir, config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
index ccaf1128bf..2eef60c39f 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
@@ -53,12 +53,19 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
dnn_hidden_units=[1000, 500, 100],
dnn_optimizer=tf.train.ProximalAdagradOptimizer(...))
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+ lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96)
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -103,7 +110,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
dnn_activation_fn=nn.relu,
dnn_dropout=None,
input_layer_partitioner=None,
- config=None):
+ config=None,
+ linear_sparse_combiner='sum'):
"""Initializes a DNNLinearCombinedEstimator instance.
Args:
@@ -116,12 +124,16 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
@@ -131,6 +143,11 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
input_layer_partitioner: Partitioner for input layer. Defaults to
`min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: RunConfig object to configure the runtime settings.
+ linear_sparse_combiner: A string specifying how to reduce the linear model
+ if a categorical column is multivalent. One of "mean", "sqrtn", and
+ "sum" -- these are effectively different ways to do example-level
+ normalization, which can be useful for bag-of-words features. For more
+ details, see @{tf.feature_column.linear_model$linear_model}.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -158,7 +175,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ linear_sparse_combiner=linear_sparse_combiner)
super(DNNLinearCombinedEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
index dd009a6753..51b9ce7005 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py
@@ -100,7 +100,8 @@ def _linear_only_estimator_fn(
weight_column=None,
optimizer='Ftrl',
config=None,
- partitioner=None):
+ partitioner=None,
+ sparse_combiner='sum'):
return dnn_linear_combined.DNNLinearCombinedEstimator(
head=head_lib.regression_head(
weight_column=weight_column, label_dimension=label_dimension,
@@ -110,7 +111,8 @@ def _linear_only_estimator_fn(
linear_feature_columns=feature_columns,
linear_optimizer=optimizer,
input_layer_partitioner=partitioner,
- config=config)
+ config=config,
+ linear_sparse_combiner=sparse_combiner)
class LinearOnlyEstimatorEvaluateTest(
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
index 75e3107670..050b0428bf 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_test.py
@@ -38,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
-def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
+def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
"""Returns a DNNEstimator that uses regression_head."""
return dnn.DNNEstimator(
head=head_lib.regression_head(
@@ -48,6 +48,12 @@ def _dnn_estimator_fn(weight_column=None, label_dimension=1, *args, **kwargs):
*args, **kwargs)
+def _dnn_estimator_classifier_fn(n_classes=3, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
+ """Returns a DNNEstimator that uses multi_class_head."""
+ return dnn.DNNEstimator(head=head_lib.multi_class_head(n_classes=n_classes),
+ *args, **kwargs)
+
+
class DNNEstimatorEvaluateTest(
dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
@@ -75,6 +81,15 @@ class DNNEstimatorTrainTest(
self, _dnn_estimator_fn)
+class DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_estimator_classifier_fn, _dnn_estimator_fn)
+
+
class DNNEstimatorIntegrationTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
new file mode 100644
index 0000000000..af4855e91e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -0,0 +1,468 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for early stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import operator
+import os
+
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'
+
+
+def make_early_stopping_hook(estimator,
+ should_stop_fn,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates early-stopping hook.
+
+ Returns a `SessionRunHook` that stops training when `should_stop_fn` returns
+ `True`.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ hook = early_stopping.make_early_stopping_hook(
+ estimator, should_stop_fn=make_stop_fn(...))
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ should_stop_fn: `callable`, function that takes no arguments and returns a
+ `bool`. If the function returns `True`, stopping will be initiated by the
+ chief.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ A `SessionRunHook` that periodically executes `should_stop_fn` and initiates
+ early stopping if the function returns `True`.
+
+ Raises:
+ TypeError: If `estimator` is not of type `tf.estimator.Estimator`.
+ ValueError: If both `run_every_secs` and `run_every_steps` are set.
+ """
+ if not isinstance(estimator, estimator_lib.Estimator):
+ raise TypeError('`estimator` must have type `tf.estimator.Estimator`. '
+ 'Got: {}'.format(type(estimator)))
+
+ if run_every_secs is not None and run_every_steps is not None:
+ raise ValueError('Only one of `run_every_secs` and `run_every_steps` must '
+ 'be set.')
+
+ if estimator.config.is_chief:
+ return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps)
+ else:
+ return _CheckForStoppingHook()
+
+
+def stop_if_higher_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is higher than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy becomes higher than 0.9.
+ hook = early_stopping.stop_if_higher_hook(estimator, "accuracy", 0.9)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is higher than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_lower_hook(estimator,
+ metric_name,
+ threshold,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if the given metric is lower than the threshold.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss becomes lower than 100.
+ hook = early_stopping.stop_if_lower_hook(estimator, "loss", 100)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ threshold: Numeric threshold for the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric is lower than specified threshold and initiates
+ early stopping if true.
+ """
+ return _stop_if_threshold_crossed_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ threshold=threshold,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_increase_hook(estimator,
+ metric_name,
+ max_steps_without_increase,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not increase within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if accuracy does not increase in over 100000 steps.
+ hook = early_stopping.stop_if_no_increase_hook(estimator, "accuracy", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_increase: `int`, maximum number of training steps with no
+ increase in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no increase over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_increase,
+ higher_is_better=True,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def stop_if_no_decrease_hook(estimator,
+ metric_name,
+ max_steps_without_decrease,
+ eval_dir=None,
+ min_steps=0,
+ run_every_secs=60,
+ run_every_steps=None):
+ """Creates hook to stop if metric does not decrease within given max steps.
+
+ Usage example:
+
+ ```python
+ estimator = ...
+ # Hook to stop training if loss does not decrease in over 100000 steps.
+ hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
+ train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
+ tf.estimator.train_and_evaluate(estimator, train_spec, ...)
+ ```
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance.
+ metric_name: `str`, metric to track. "loss", "accuracy", etc.
+ max_steps_without_decrease: `int`, maximum number of training steps with no
+ decrease in the given metric.
+ eval_dir: If set, directory containing summary files with eval metrics. By
+ default, `estimator.eval_dir()` will be used.
+ min_steps: `int`, stop is never requested if global step is less than this
+ value. Defaults to 0.
+ run_every_secs: If specified, calls `should_stop_fn` at an interval of
+ `run_every_secs` seconds. Defaults to 60 seconds. Either this or
+ `run_every_steps` must be set.
+ run_every_steps: If specified, calls `should_stop_fn` every
+ `run_every_steps` steps. Either this or `run_every_secs` must be set.
+
+ Returns:
+ An early-stopping hook of type `SessionRunHook` that periodically checks
+ if the given metric shows no decrease over given maximum number of
+ training steps, and initiates early stopping if true.
+ """
+ return _stop_if_no_metric_improvement_hook(
+ estimator=estimator,
+ metric_name=metric_name,
+ max_steps_without_improvement=max_steps_without_decrease,
+ higher_is_better=False,
+ eval_dir=eval_dir,
+ min_steps=min_steps,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def read_eval_metrics(eval_dir):
+ """Helper to read eval metrics from eval summary files.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Returns:
+ A `dict` with global steps mapping to `dict` of metric names and values.
+ """
+ eval_metrics_dict = {}
+ for event in _summaries(eval_dir):
+ if not event.HasField('summary'):
+ continue
+ metrics = {}
+ for value in event.summary.value:
+ if value.HasField('simple_value'):
+ metrics[value.tag] = value.simple_value
+ if metrics:
+ eval_metrics_dict[event.step] = metrics
+ return eval_metrics_dict
+
+
+def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
+ higher_is_better, eval_dir, min_steps,
+ run_every_secs, run_every_steps):
+ """Creates early-stopping hook to stop training if threshold is crossed."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ greater_or_lesser = 'greater than' if higher_is_better else 'less than'
+
+ def stop_if_threshold_crossed_fn():
+ """Returns `True` if the given metric crosses specified threshold."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if is_lhs_better(val, threshold):
+ tf_logging.info(
+ 'At step %s, metric "%s" has value %s which is %s the configured '
+ 'threshold (%s) for early stopping.', step, metric_name, val,
+ greater_or_lesser, threshold)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_threshold_crossed_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _stop_if_no_metric_improvement_hook(
+ estimator, metric_name, max_steps_without_improvement, higher_is_better,
+ eval_dir, min_steps, run_every_secs, run_every_steps):
+ """Returns hook to stop training if given metric shows no improvement."""
+
+ if eval_dir is None:
+ eval_dir = estimator.eval_dir()
+
+ is_lhs_better = operator.gt if higher_is_better else operator.lt
+ increase_or_decrease = 'increase' if higher_is_better else 'decrease'
+
+ def stop_if_no_metric_improvement_fn():
+ """Returns `True` if metric does not improve within max steps."""
+
+ eval_results = read_eval_metrics(eval_dir)
+
+ best_val = None
+ best_val_step = None
+ for step, metrics in eval_results.items():
+ if step < min_steps:
+ continue
+ val = metrics[metric_name]
+ if best_val is None or is_lhs_better(val, best_val):
+ best_val = val
+ best_val_step = step
+ if step - best_val_step >= max_steps_without_improvement:
+ tf_logging.info(
+ 'No %s in metric "%s" for %s steps, which is greater than or equal '
+ 'to max steps (%s) configured for early stopping.',
+ increase_or_decrease, metric_name, step - best_val_step,
+ max_steps_without_improvement)
+ return True
+ return False
+
+ return make_early_stopping_hook(
+ estimator=estimator,
+ should_stop_fn=stop_if_no_metric_improvement_fn,
+ run_every_secs=run_every_secs,
+ run_every_steps=run_every_steps)
+
+
+def _summaries(eval_dir):
+ """Yields `tensorflow.Event` protos from event files in the eval dir.
+
+ Args:
+ eval_dir: Directory containing summary files with eval metrics.
+
+ Yields:
+ `tensorflow.Event` object read from the event files.
+ """
+ for event_file in gfile.Glob(
+ os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):
+ for event in summary_iterator.summary_iterator(event_file):
+ yield event
+
+
+def _get_or_create_stop_var():
+ with variable_scope.variable_scope(
+ name_or_scope='signal_early_stopping',
+ values=[],
+ reuse=variable_scope.AUTO_REUSE):
+ return variable_scope.get_variable(
+ name='STOP',
+ shape=[],
+ dtype=dtypes.bool,
+ initializer=init_ops.constant_initializer(False),
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False)
+
+
+class _StopOnPredicateHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop when `should_stop_fn` returns `True`."""
+
+ def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):
+ if not callable(should_stop_fn):
+ raise TypeError('`should_stop_fn` must be callable.')
+
+ self._should_stop_fn = should_stop_fn
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=run_every_secs, every_steps=run_every_steps)
+ self._global_step_tensor = None
+ self._stop_var = None
+ self._stop_op = None
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ self._stop_var = _get_or_create_stop_var()
+ self._stop_op = state_ops.assign(self._stop_var, True)
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results
+ if self._timer.should_trigger_for_step(global_step):
+ self._timer.update_last_triggered_step(global_step)
+ if self._should_stop_fn():
+ tf_logging.info('Requesting early stopping at global step %d',
+ global_step)
+ run_context.session.run(self._stop_op)
+ run_context.request_stop()
+
+
+class _CheckForStoppingHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop if stop is requested by `_StopOnPredicateHook`."""
+
+ def __init__(self):
+ self._stop_var = None
+
+ def begin(self):
+ self._stop_var = _get_or_create_stop_var()
+
+ def before_run(self, run_context):
+ del run_context
+ return session_run_hook.SessionRunArgs(self._stop_var)
+
+ def after_run(self, run_context, run_values):
+ should_early_stop = run_values.results
+ if should_early_stop:
+ tf_logging.info('Early stopping requested, suspending run.')
+ run_context.request_stop()
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
new file mode 100644
index 0000000000..b5eee818fa
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping_test.py
@@ -0,0 +1,233 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for early_stopping."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+from tensorflow.contrib.estimator.python.estimator import early_stopping
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import run_config
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import training_util
+
+
+class _FakeRunConfig(run_config.RunConfig):
+
+ def __init__(self, is_chief):
+ super(_FakeRunConfig, self).__init__()
+ self._is_chief = is_chief
+
+ @property
+ def is_chief(self):
+ return self._is_chief
+
+
+def _dummy_model_fn(features, labels, params):
+ _, _, _ = features, labels, params
+
+
+class _FakeEstimator(estimator.Estimator):
+ """Fake estimator for testing."""
+
+ def __init__(self, config):
+ super(_FakeEstimator, self).__init__(
+ model_fn=_dummy_model_fn, config=config)
+
+
+def _write_events(eval_dir, params):
+ """Test helper to write events to summary files."""
+ for steps, loss, accuracy in params:
+ estimator._write_dict_to_summary(eval_dir, {
+ 'loss': loss,
+ 'accuracy': accuracy,
+ }, steps)
+
+
+class ReadEvalMetricsTest(test.TestCase):
+
+ def test_read_eval_metrics(self):
+ eval_dir = tempfile.mkdtemp()
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 1, 2),
+ (2000, 3, 4),
+ (3000, 5, 6),
+ ])
+ self.assertEqual({
+ 1000: {
+ 'loss': 1,
+ 'accuracy': 2
+ },
+ 2000: {
+ 'loss': 3,
+ 'accuracy': 4
+ },
+ 3000: {
+ 'loss': 5,
+ 'accuracy': 6
+ },
+ }, early_stopping.read_eval_metrics(eval_dir))
+
+
+class EarlyStoppingHooksTest(test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ config = _FakeRunConfig(is_chief=True)
+ self._estimator = _FakeEstimator(config=config)
+ eval_dir = self._estimator.eval_dir()
+ os.makedirs(eval_dir)
+ _write_events(
+ eval_dir,
+ [
+ # steps, loss, accuracy
+ (1000, 0.8, 0.5),
+ (2000, 0.7, 0.6),
+ (3000, 0.4, 0.7),
+ (3500, 0.41, 0.68),
+ ])
+
+ def run_session(self, hooks, should_stop):
+ hooks = hooks if isinstance(hooks, list) else [hooks]
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=hooks) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertEqual(mon_sess.should_stop(), should_stop)
+
+ @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True))
+ def test_stop_if_higher_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_higher_hook(
+ self._estimator,
+ metric_name='accuracy',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True))
+ def test_stop_if_lower_hook(self, threshold, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_lower_hook(
+ self._estimator,
+ metric_name='loss',
+ threshold=threshold,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_increase_hook(
+ self._estimator,
+ metric_name='accuracy',
+ max_steps_without_increase=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0, False), (500, 4000, False),
+ (500, 0, True))
+ def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop):
+ self.run_session(
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps,
+ min_steps=min_steps), should_stop)
+
+ @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True),
+ (500, 0.3, True))
+ def test_multiple_hooks(self, max_steps, loss_threshold, should_stop):
+ self.run_session([
+ early_stopping.stop_if_no_decrease_hook(
+ self._estimator,
+ metric_name='loss',
+ max_steps_without_decrease=max_steps),
+ early_stopping.stop_if_lower_hook(
+ self._estimator, metric_name='loss', threshold=loss_threshold)
+ ], should_stop)
+
+ @parameterized.parameters(False, True)
+ def test_make_early_stopping_hook(self, should_stop):
+ self.run_session([
+ early_stopping.make_early_stopping_hook(
+ self._estimator, should_stop_fn=lambda: should_stop)
+ ], should_stop)
+
+ def test_make_early_stopping_hook_typeerror(self):
+ with self.assertRaises(TypeError):
+ early_stopping.make_early_stopping_hook(
+ estimator=object(), should_stop_fn=lambda: True)
+
+ def test_make_early_stopping_hook_valueerror(self):
+ with self.assertRaises(ValueError):
+ early_stopping.make_early_stopping_hook(
+ self._estimator,
+ should_stop_fn=lambda: True,
+ run_every_secs=60,
+ run_every_steps=100)
+
+
+class StopOnPredicateHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: False, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ self.assertFalse(mon_sess.raw_session().run(hook._stop_var))
+
+ hook = early_stopping._StopOnPredicateHook(
+ should_stop_fn=lambda: True, run_every_secs=0)
+ with ops.Graph().as_default():
+ training_util.create_global_step()
+ no_op = control_flow_ops.no_op()
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertTrue(mon_sess.should_stop())
+ self.assertTrue(mon_sess.raw_session().run(hook._stop_var))
+
+
+class CheckForStoppingHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = early_stopping._CheckForStoppingHook()
+ with ops.Graph().as_default():
+ no_op = control_flow_ops.no_op()
+ assign_op = state_ops.assign(early_stopping._get_or_create_stop_var(),
+ True)
+ with monitored_session.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.run(no_op)
+ self.assertFalse(mon_sess.should_stop())
+ mon_sess.run(assign_op)
+ self.assertTrue(mon_sess.should_stop())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 8b97f86db1..c9d86ef4ab 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -529,11 +529,13 @@ def multi_label_head(n_classes,
applications, the shape is `[batch_size, n_classes]`.
Labels can be:
+
* A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
* An integer `SparseTensor` of class indices. The `dense_shape` must be
`[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
* If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
- must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
+ must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a
+ multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
If `weight_column` is specified, weights must be of shape
`[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
@@ -845,6 +847,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
train_op = train_op_fn(regularized_training_loss)
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
+ train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access
# Only summarize mean_loss for SUM reduction to preserve backwards
# compatibility. Otherwise skip it to avoid unnecessary computation.
if self._loss_reduction == losses.Reduction.SUM:
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index d6c158608b..7b884402d4 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -567,6 +568,33 @@ class MultiLabelHead(test.TestCase):
expected_loss=expected_loss,
expected_metrics=expected_metrics)
+ def test_eval_with_label_vocabulary_with_multi_hot_input(self):
+ n_classes = 2
+ head = head_lib.multi_label_head(
+ n_classes, label_vocabulary=['class0', 'class1'])
+ logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
+ labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
+ # loss = labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
+ keys = metric_keys.MetricKeys
+ expected_metrics = {
+ # Average loss over examples.
+ keys.LOSS_MEAN: expected_loss,
+ # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+ # this assert tests that the algorithm remains consistent.
+ keys.AUC: 0.3333,
+ keys.AUC_PR: 0.7639,
+ }
+ self._test_eval(
+ head=head,
+ logits=logits,
+ labels=labels_multi_hot,
+ expected_loss=expected_loss,
+ expected_metrics=expected_metrics)
+
def test_eval_with_thresholds(self):
n_classes = 2
thresholds = [0.25, 0.5, 0.75]
@@ -989,6 +1017,34 @@ class MultiLabelHead(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
+ def test_train_with_update_ops(self):
+ head = head_lib.multi_label_head(n_classes=2)
+
+ with ops.Graph().as_default():
+ w = variables.Variable(1)
+ update_op = w.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
+
+ t = variables.Variable('')
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return t.assign(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
+ labels=np.array([[1, 0], [1, 1]], dtype=np.int64),
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ sess.run(spec.train_op)
+ w_value, t_value = sess.run([w, t])
+ self.assertEqual(2, w_value)
+ self.assertEqual(expected_train_result, t_value)
+
def test_train_with_regularization_losses(self):
head = head_lib.multi_label_head(
n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py
index 3bf4abe83d..62a37abefb 100644
--- a/tensorflow/contrib/estimator/python/estimator/linear.py
+++ b/tensorflow/contrib/estimator/python/estimator/linear.py
@@ -39,6 +39,18 @@ class LinearEstimator(estimator.Estimator):
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearEstimator(
+ head=tf.contrib.estimator.multi_label_head(n_classes=3),
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=lambda: tf.train.FtrlOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator using the FTRL optimizer with regularization.
estimator = LinearEstimator(
head=tf.contrib.estimator.multi_label_head(n_classes=3),
@@ -87,7 +99,8 @@ class LinearEstimator(estimator.Estimator):
model_dir=None,
optimizer='Ftrl',
config=None,
- partitioner=None):
+ partitioner=None,
+ sparse_combiner='sum'):
"""Initializes a `LinearEstimator` instance.
Args:
@@ -99,10 +112,16 @@ class LinearEstimator(estimator.Estimator):
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. One of "mean", "sqrtn", and "sum" -- these are
+ effectively different ways to do example-level normalization, which can
+ be useful for bag-of-words features. for more details, see
+ @{tf.feature_column.linear_model$linear_model}.
"""
def _model_fn(features, labels, mode, config):
return linear_lib._linear_model_fn( # pylint: disable=protected-access
@@ -113,6 +132,7 @@ class LinearEstimator(estimator.Estimator):
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
- config=config)
+ config=config,
+ sparse_combiner=sparse_combiner)
super(LinearEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
index bb9b835889..7fcae5ad8e 100644
--- a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
+++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
@@ -62,10 +62,11 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
public:
explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
: OpKernel(context) {
- OP_REQUIRES_OK(context, context->MatchSignature(
- {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
- DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
- {DT_FLOAT, DT_FLOAT}));
+ OP_REQUIRES_OK(context,
+ context->MatchSignature(
+ {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64,
+ DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL},
+ {DT_FLOAT, DT_FLOAT}));
}
void Compute(OpKernelContext* context) override {
@@ -75,8 +76,9 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const Tensor& input_weights = context->input(3);
const Tensor& input_indices = context->input(4);
const Tensor& input_values = context->input(5);
- const Tensor& input_block_size = context->input(6);
- const Tensor& input_is_transpose = context->input(7);
+ const Tensor& entry_weights = context->input(6);
+ const Tensor& input_block_size = context->input(7);
+ const Tensor& input_is_transpose = context->input(8);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
InvalidArgument("Input factors should be a matrix."));
@@ -89,13 +91,33 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
InvalidArgument("Input input_weights should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
InvalidArgument("Input input_indices should be a matrix."));
+ OP_REQUIRES(
+ context, input_indices.dim_size(1) == 2,
+ InvalidArgument("Input input_indices should have shape (?, 2)."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
InvalidArgument("Input input_values should be a vector"));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(entry_weights.shape()),
+ InvalidArgument("Input entry_weights should be a vector"));
+ OP_REQUIRES(context, input_indices.dim_size(0) == input_values.dim_size(0),
+ InvalidArgument("Input input_values' length should match the "
+ "first dimension of Input input_indices "));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
InvalidArgument("Input input_block_size should be a scalar."));
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
InvalidArgument("Input input_is_transpose should be a scalar."));
+ OP_REQUIRES(
+ context,
+ ((input_weights.dim_size(0) > 0 &&
+ factor_weights.dim_size(0) == factors.dim_size(0) &&
+ entry_weights.dim_size(0) == 0) ||
+ (input_weights.dim_size(0) == 0 && factor_weights.dim_size(0) == 0 &&
+ entry_weights.dim_size(0) == input_indices.dim_size(0))),
+ InvalidArgument("To specify the weights for observed entries, either "
+ "(1) entry_weights must be set or (2) input_weights "
+ "and factor_weights must be set, but not both."));
+ // TODO(yifanchen): Deprecate the support of input_weights and
+ // factor_weights.
const int64 factor_dim = factors.dim_size(1);
const int64 factors_size = factors.dim_size(0);
@@ -105,6 +127,7 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const auto& input_weights_vec = input_weights.vec<float>();
const float w_0 = unobserved_weights.scalar<float>()();
const auto& input_values_vec = input_values.vec<float>();
+ const auto& entry_weights_vec = entry_weights.vec<float>();
ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
factor_dim, factors_size);
@@ -134,6 +157,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
};
+ const bool use_entry_weights = entry_weights_vec.size() > 0;
+
// TODO(rmlarsen): In principle, we should be using the SparseTensor class
// and machinery for iterating over groups, but the fact that class
// SparseTensor makes a complete copy of the matrix makes me reluctant to
@@ -195,6 +220,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
// map using the hash of the thread id as the key.
//
// TODO(jpoulson): Switch to try_emplace once C++17 is supported
+ // TODO(b/72952120): Check whether the 3 lock-unlock pairs can be
+ // consolidated into just one.
map_mutex.lock();
const auto key_count = factor_batch_map.count(id_hash);
map_mutex.unlock();
@@ -213,6 +240,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
CHECK_LE(shard.second, perm.size());
CHECK_LE(shard.first, shard.second);
const int64 input_index = get_input_index(perm[shard.first]);
+ const float input_weight =
+ use_entry_weights ? 1.0 : input_weights_vec(input_index);
// Accumulate the rhs and lhs terms in the normal equations
// for the non-zero elements in the row or column of the sparse matrix
// corresponding to input_index.
@@ -228,7 +257,8 @@ class WALSComputePartialLhsAndRhsOp : public OpKernel {
const int64 factor_index = get_factor_index(i);
const float input_value = input_values_vec(i);
const float weight =
- input_weights_vec(input_index) * factor_weights_vec(factor_index);
+ use_entry_weights ? entry_weights_vec(i)
+ : input_weight * factor_weights_vec(factor_index);
CHECK_GE(weight, 0);
factor_batch.col(num_batched) =
factors_mat.col(factor_index) * std::sqrt(weight);
diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc
index 11ea36946e..1d31bd38c8 100644
--- a/tensorflow/contrib/factorization/ops/factorization_ops.cc
+++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc
@@ -25,20 +25,33 @@ REGISTER_OP("WALSComputePartialLhsAndRhs")
.Input("input_weights: float32")
.Input("input_indices: int64")
.Input("input_values: float32")
+ .Input("entry_weights: float32")
.Input("input_block_size: int64")
.Input("input_is_transpose: bool")
.Output("partial_lhs: float32")
.Output("partial_rhs: float32")
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"(
-Computes the partial left-hand side and right-hand side of WALS update.
+Computes the partial left-hand side and right-hand side of WALS update. For
+observed entry input_indices[i]=[m, n] with value input_values[i]=v, the weight
+should be specified either through (1) entry_weights[i] or (2) through
+input_weights[m] * factor_weights[n] (if input_is_transpose is false) or
+input_weights[n] * factor_weights[m] (if input_is_transpose is true). Note it is
+not allowed to have both (1) and (2) specified at the same time: when one
+approach is used, the input tensors related to the other approach must be kept
+completely empty.
factors: Matrix of size m * k.
-factor_weights: Vector of size m. Corresponds to column weights
+factor_weights: Vector of size m. Corresponds to column weights. Should be empty
+ if entry_weights is used.
unobserved_weights: Scalar. Weight for unobserved input entries.
-input_weights: Vector of size n. Corresponds to row weights.
+input_weights: Vector of size n. Corresponds to row weights. Should be empty if
+ entry_weights is used.
input_indices: Indices for the input SparseTensor.
input_values: Values for the input SparseTensor.
+entry_weights: If not empty, this must be same length as input_vaues and is used
+ as the per-entry non-zero weight. If this is used, input_weights and
+ factor_weights must be empty.
input_block_size: Scalar. Number of rows spanned by input.
input_is_transpose: If true, logically transposes the input for processing.
partial_lhs: 3-D tensor with size input_block_size x k x k.
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index ba30fd9977..6c2f1d4608 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -55,7 +55,41 @@ class WalsSolverOpsTest(test.TestCase):
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, self._column_weights, self._unobserved_weights,
self._row_weights, sparse_block.indices, sparse_block.values,
- sparse_block.dense_shape[0], False)
+ [],
+ input_block_size=sparse_block.dense_shape[0],
+ input_is_transpose=False)
+ self.assertAllClose(lhs_tensor.eval(), [[
+ [0.014800, 0.017000, 0.019200],
+ [0.017000, 0.019600, 0.022200],
+ [0.019200, 0.022200, 0.025200],
+ ], [
+ [0.0064000, 0.0080000, 0.0096000],
+ [0.0080000, 0.0100000, 0.0120000],
+ [0.0096000, 0.0120000, 0.0144000],
+ ], [
+ [0.0099000, 0.0126000, 0.0153000],
+ [0.0126000, 0.0162000, 0.0198000],
+ [0.0153000, 0.0198000, 0.0243000],
+ ], [
+ [0.058800, 0.067200, 0.075600],
+ [0.067200, 0.076800, 0.086400],
+ [0.075600, 0.086400, 0.097200],
+ ]])
+ self.assertAllClose(rhs_matrix.eval(), [[0.019300, 0.023000, 0.026700],
+ [0.061600, 0.077000, 0.092400],
+ [0.160400, 0.220000, 0.279600],
+ [0.492800, 0.563200, 0.633600]])
+
+ def testWalsSolverLhsEntryWeights(self):
+ sparse_block = SparseBlock3x3()
+ with self.test_session():
+ [lhs_tensor,
+ rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
+ self._column_factors, [], self._unobserved_weights,
+ [], sparse_block.indices, sparse_block.values,
+ [0.01, 0.03, 0.04, 0.03, 0.06, 0.12],
+ input_block_size=sparse_block.dense_shape[0],
+ input_is_transpose=False)
self.assertAllClose(lhs_tensor.eval(), [[
[0.014800, 0.017000, 0.019200],
[0.017000, 0.019600, 0.022200],
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 8f73274c2a..7ab70fbcfd 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -943,6 +943,7 @@ class WALSModel(object):
row_weights_slice,
new_sp_input.indices,
new_sp_input.values,
+ [],
num_rows,
transpose_input,
name="wals_compute_partial_lhs_rhs"))
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index 555beddeaa..05bcdac2ca 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -95,7 +95,7 @@ def sequence_input_layer(
Raises:
ValueError: If any of the `feature_columns` is the wrong type.
"""
- feature_columns = fc._clean_feature_columns(feature_columns)
+ feature_columns = fc._normalize_feature_columns(feature_columns)
for c in feature_columns:
if not isinstance(c, fc._SequenceDenseColumn):
raise ValueError(
@@ -346,7 +346,8 @@ def sequence_numeric_column(
key,
shape=(1,),
default_value=0.,
- dtype=dtypes.float32):
+ dtype=dtypes.float32,
+ normalizer_fn=None):
"""Returns a feature column that represents sequences of numeric data.
Example:
@@ -370,6 +371,12 @@ def sequence_numeric_column(
default_value: A single value compatible with `dtype` that is used for
padding the sparse data into a dense `Tensor`.
dtype: The type of values.
+ normalizer_fn: If not `None`, a function that can be used to normalize the
+ value of the tensor after `default_value` is applied for parsing.
+ Normalizer function takes the input `Tensor` as its argument, and returns
+ the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
+ even though the most common use case of this function is normalization, it
+ can be used for any kind of Tensorflow transformations.
Returns:
A `_SequenceNumericColumn`.
@@ -383,12 +390,16 @@ def sequence_numeric_column(
if not (dtype.is_integer or dtype.is_floating):
raise ValueError('dtype must be convertible to float. '
'dtype: {}, key: {}'.format(dtype, key))
+ if normalizer_fn is not None and not callable(normalizer_fn):
+ raise TypeError(
+ 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
return _SequenceNumericColumn(
key,
shape=shape,
default_value=default_value,
- dtype=dtype)
+ dtype=dtype,
+ normalizer_fn=normalizer_fn)
def _assert_all_equal_and_return(tensors, name=None):
@@ -407,7 +418,7 @@ class _SequenceNumericColumn(
fc._SequenceDenseColumn,
collections.namedtuple(
'_SequenceNumericColumn',
- ['key', 'shape', 'default_value', 'dtype'])):
+ ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])):
"""Represents sequences of numeric data."""
@property
@@ -419,7 +430,10 @@ class _SequenceNumericColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- return inputs.get(self.key)
+ input_tensor = inputs.get(self.key)
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return input_tensor
@property
def _variable_shape(self):
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 88f5d53516..45d7b74046 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
@@ -109,7 +110,7 @@ class SequenceInputLayerTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
def test_embedding_column_with_non_sequence_categorical(self):
- """Tests that error is raised for non-sequence categorical column."""
+ """Tests that error is raised for non-sequence embedding column."""
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, ids [2]
@@ -131,6 +132,107 @@ class SequenceInputLayerTest(test.TestCase):
features={'aaa': sparse_input},
feature_columns=[embedding_column_a])
+ def test_shared_embedding_column(self):
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [2, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2))
+
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 4.), # id 1
+ (5., 6.) # id 2
+ )
+
+ def _get_initializer(embedding_dimension, embedding_values):
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ return _initializer
+
+ expected_input_layer = [
+ # example 0, ids_a [2], ids_b [1]
+ [[5., 6., 3., 4.], [0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [2, 0]
+ [[1., 2., 5., 6.], [3., 4., 1., 2.]],
+ ]
+ expected_sequence_length = [1, 2]
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ # Test that columns are reordered alphabetically.
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension,
+ initializer=_get_initializer(embedding_dimension, embedding_values))
+
+ input_layer, sequence_length = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ },
+ feature_columns=shared_embedding_columns)
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+ self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_shared_embedding_column_with_non_sequence_categorical(self):
+ """Tests that error is raised for non-sequence shared embedding column."""
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'In embedding_column: aaa_shared_embedding\. categorical_column must '
+ r'be of type _SequenceCategoricalColumn to use sequence_input_layer\.'):
+ _, _ = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b
+ },
+ feature_columns=shared_embedding_columns)
+
def test_indicator_column(self):
vocabulary_size_a = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -577,6 +679,182 @@ class SequenceEmbeddingColumnTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
+class SequenceSharedEmbeddingColumnTest(test.TestCase):
+
+ def test_get_sequence_dense_tensor(self):
+ vocabulary_size = 3
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [0, 2]
+ # example 2, ids [0]
+ # example 3, ids []
+ indices=((0, 0), (1, 0), (1, 1), (2, 0)),
+ values=(1, 0, 2, 0),
+ dense_shape=(4, 2))
+
+ expected_lookups_a = [
+ # example 0, ids [2]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 2.], [3., 5.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [1]
+ [[3., 5.], [0., 0.]],
+ ]
+
+ expected_lookups_b = [
+ # example 0, ids [1]
+ [[3., 5.], [0., 0.]],
+ # example 1, ids [0, 2]
+ [[1., 2.], [7., 11.]],
+ # example 2, ids [0]
+ [[1., 2.], [0., 0.]],
+ # example 3, ids []
+ [[0., 0.], [0., 0.]],
+ ]
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[0]
+ embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[0]
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+ self.assertAllEqual(
+ expected_lookups_a, embedding_lookup_a.eval(session=sess))
+ self.assertAllEqual(
+ expected_lookups_b, embedding_lookup_b.eval(session=sess))
+
+ def test_sequence_length(self):
+ vocabulary_size = 3
+
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ expected_sequence_length_a = [1, 2]
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [0, 2]
+ # example 1, ids [1]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ expected_sequence_length_b = [2, 1]
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[1]
+ sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[1]
+
+ with monitored_session.MonitoredSession() as sess:
+ sequence_length_a = sess.run(sequence_length_a)
+ self.assertAllEqual(expected_sequence_length_a, sequence_length_a)
+ self.assertEqual(np.int64, sequence_length_a.dtype)
+ sequence_length_b = sess.run(sequence_length_b)
+ self.assertAllEqual(expected_sequence_length_b, sequence_length_b)
+ self.assertEqual(np.int64, sequence_length_b.dtype)
+
+ def test_sequence_length_with_empty_rows(self):
+ """Tests _sequence_length when some examples do not have ids."""
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids []
+ # example 1, ids [2]
+ # example 2, ids [0, 1]
+ # example 3, ids []
+ # example 4, ids [1]
+ # example 5, ids []
+ indices=((1, 0), (2, 0), (2, 1), (4, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(6, 2))
+ expected_sequence_length_a = [0, 1, 2, 0, 1, 0]
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids []
+ # example 2, ids []
+ # example 3, ids []
+ # example 4, ids [1]
+ # example 5, ids [0, 1]
+ indices=((0, 0), (4, 0), (5, 0), (5, 1)),
+ values=(2, 1, 0, 1),
+ dense_shape=(6, 2))
+ expected_sequence_length_b = [1, 0, 0, 0, 1, 2]
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+
+ shared_embedding_columns = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+
+ sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'aaa': sparse_input_a
+ }))[1]
+ sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
+ _LazyBuilder({
+ 'bbb': sparse_input_b
+ }))[1]
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length_a, sequence_length_a.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length_b, sequence_length_b.eval(session=sess))
+
+
class SequenceIndicatorColumnTest(test.TestCase):
def test_get_sequence_dense_tensor(self):
@@ -670,6 +948,7 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertEqual((1,), a.shape)
self.assertEqual(0., a.default_value)
self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
def test_shape_saved_as_tuple(self):
a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
@@ -688,6 +967,10 @@ class SequenceNumericColumnTest(test.TestCase):
ValueError, 'dtype must be convertible to float'):
sfc.sequence_numeric_column('aaa', dtype=dtypes.string)
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable')
+
def test_get_sequence_dense_tensor(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1]]
@@ -708,6 +991,41 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
+ def test_get_sequence_dense_tensor_with_normalizer_fn(self):
+
+ def _increment_two(input_sparse_tensor):
+ return sparse_ops.sparse_add(
+ input_sparse_tensor,
+ sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2))
+ )
+
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+
+ # Before _increment_two:
+ # [[0.], [1.]],
+ # [[10.], [0.]],
+ # After _increment_two:
+ # [[2.], [1.]],
+ # [[10.], [2.]],
+ expected_dense_tensor = [
+ [[2.], [1.]],
+ [[10.], [2.]],
+ ]
+ numeric_column = sfc.sequence_numeric_column(
+ 'aaa', normalizer_fn=_increment_two)
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
def test_get_sequence_dense_tensor_with_shape(self):
"""Tests get_sequence_dense_tensor with shape !=(1,)."""
sparse_input = sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py
index daba965a98..484ffee3e7 100644
--- a/tensorflow/contrib/ffmpeg/__init__.py
+++ b/tensorflow/contrib/ffmpeg/__init__.py
@@ -28,7 +28,6 @@ from __future__ import print_function
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video
from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio
-from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index 020b5c99c6..b1b5126d9e 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -21,7 +21,6 @@ from __future__ import print_function
from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py
from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py
from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
-from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 10d1ecc738..dc49383c5c 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -119,14 +119,13 @@ from tensorflow.python.framework.smart_cond import smart_cond
from tensorflow.python.framework.smart_cond import smart_constant_value
from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
from tensorflow.python.framework.tensor_spec import TensorSpec
-from tensorflow.python.ops.array_ops import broadcast_to
from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal
from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d
from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d
from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['nest', 'broadcast_to']
+_allowed_symbols = ['nest']
_nest_allowed_symbols = [
'assert_same_structure',
'is_sequence',
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index df7d7e9dae..34fd5018af 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
class CriticalSectionTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCreateCriticalSection(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
@@ -53,7 +53,7 @@ class CriticalSectionTest(test.TestCase):
self.assertAllClose([2.0 * i for i in range(num_concurrent)],
sorted(r_value))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCriticalSectionWithControlFlow(self):
for outer_cond in [False, True]:
for inner_cond in [False, True]:
@@ -109,7 +109,7 @@ class CriticalSectionTest(test.TestCase):
with self.assertRaisesOpError("Error"):
self.evaluate(r)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCreateCriticalSectionFnReturnsOp(self):
cs = critical_section_ops.CriticalSection(shared_name="cs")
v = resource_variable_ops.ResourceVariable(0.0, name="v")
@@ -332,7 +332,7 @@ class CriticalSectionTest(test.TestCase):
self.evaluate(v.initializer)
self.assertEqual(10, self.evaluate(out))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInsideFunction(self):
cs = critical_section_ops.CriticalSection()
v = resource_variable_ops.ResourceVariable(1)
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 40ae01bfcc..e8e3180019 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -712,7 +712,8 @@ class VariableDeviceChooser(object):
num_tasks=0,
job_name='ps',
device_type='CPU',
- device_index=0):
+ device_index=0,
+ replica=None):
"""Initialize VariableDeviceChooser.
Usage:
@@ -733,12 +734,15 @@ class VariableDeviceChooser(object):
self._job_name = job_name
self._device_type = device_type
self._device_index = device_index
+ self._replica = replica
self._num_tasks = num_tasks
self._next_task_id = 0
def __call__(self, op):
- device_spec = tf_device.DeviceSpec(device_type=self._device_type,
- device_index=self._device_index)
+ device_spec = tf_device.DeviceSpec(
+ replica=self._replica,
+ device_type=self._device_type,
+ device_index=self._device_index)
if self._num_tasks > 0:
task_id = self._next_task_id
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 37ea6eb12a..7e0c7dbec1 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -506,6 +506,35 @@ class VariablesTest(test.TestCase):
self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+ def testVariableWithVariableDeviceChooserWithReplica(self):
+
+ with ops.Graph().as_default():
+ device_fn = variables_lib2.VariableDeviceChooser(replica=3, num_tasks=2)
+ with arg_scope([variables_lib2.variable], device=device_fn):
+ a = variables_lib2.variable('a', [])
+ b = variables_lib2.variable('b', [])
+ c = variables_lib2.variable('c', [], device='cpu:12')
+ d = variables_lib2.variable('d', [])
+ with ops.device('cpu:99'):
+ e_init = constant_op.constant(12)
+ e = variables_lib2.variable('e', initializer=e_init)
+ # The values below highlight how the VariableDeviceChooser puts initial
+ # values on the same device as the variable job.
+ self.assertDeviceEqual(a.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(a.initial_value.op.colocation_groups(),
+ a.op.colocation_groups())
+ self.assertDeviceEqual(b.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertEqual(b.initial_value.op.colocation_groups(),
+ b.op.colocation_groups())
+ self.assertDeviceEqual(c.device, '/cpu:12')
+ self.assertEqual(c.initial_value.op.colocation_groups(),
+ c.op.colocation_groups())
+ self.assertDeviceEqual(d.device, '/job:ps/replica:3/task:0/cpu:0')
+ self.assertEqual(d.initial_value.op.colocation_groups(),
+ d.op.colocation_groups())
+ self.assertDeviceEqual(e.device, '/job:ps/replica:3/task:1/cpu:0')
+ self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
+
def testVariableGPUPlacement(self):
with ops.Graph().as_default():
@@ -930,8 +959,8 @@ class AssignFromCheckpointTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
init_value0 = 10.0
init_value1 = 20.0
@@ -944,8 +973,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -960,8 +989,8 @@ class AssignFromCheckpointTest(test.TestCase):
# Tests restoring PartitionedVariables and tests using a dictionary
# of lists as the assign_from_checkpoint() var_list param.
def testLoadPartitionedVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_partitioned_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_partitioned_variables'))
init_value0 = np.array([[10.0, 11.0], [12.0, 13.0]])
init_value1 = np.array([20.0]) # Partitioned into 1 part, edge case.
@@ -974,15 +1003,14 @@ class AssignFromCheckpointTest(test.TestCase):
partitioner = partitioned_variables.variable_axis_size_partitioner(2)
var0 = variables_lib2.variable(
'var0', shape=init_value0.shape, partitioner=partitioner)
- var0full = variables_lib2.variable(
- 'var0full', shape=init_value0.shape)
+ var0full = variables_lib2.variable('var0full', shape=init_value0.shape)
var1 = variables_lib2.variable(
'var1', shape=init_value1.shape, partitioner=partitioner)
# Convert var0 and var1 into a list of underlying variables.
vars_to_restore = {'var0': list(var0) + [var0full], 'var1': list(var1)}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -992,16 +1020,18 @@ class AssignFromCheckpointTest(test.TestCase):
# Request and test the variable values. PartitionedVariables can't
# be evaled so we wrap them in an identity.
- self.assertTrue(np.array_equal(
- init_value0, array_ops.identity(var0).eval()))
- self.assertTrue(np.array_equal(
- init_value0, var0full.eval()))
- self.assertTrue(np.array_equal(
- init_value1, array_ops.identity(var1).eval()))
+ self.assertTrue(
+ np.array_equal(init_value0,
+ array_ops.identity(var0).eval()))
+ self.assertTrue(np.array_equal(init_value0, var0full.eval()))
+ self.assertTrue(
+ np.array_equal(init_value1,
+ array_ops.identity(var1).eval()))
def testRaisesValueErrorIfAVariableIsntFound(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'raises_value_error_if_var_isnt_found'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'raises_value_error_if_var_isnt_found'))
init_value0 = 10.0
init_value1 = 20.0
@@ -1019,8 +1049,9 @@ class AssignFromCheckpointTest(test.TestCase):
variables_lib2.assign_from_checkpoint(model_path, vars_to_restore)
def testInitFromCheckpointWithScopes(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'init_from_checkpoint_with_scopes'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'init_from_checkpoint_with_scopes'))
init_value0 = np.asarray(
[1.0, 3.0, 9.0], dtype=np.float32).reshape((1, 3, 1))
@@ -1038,8 +1069,8 @@ class AssignFromCheckpointTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=init_value1.shape)
vars_to_restore = {'layer0/v0': var0, 'layer1/v1': var1}
- op, feed_dict = variables_lib2.assign_from_checkpoint(model_path,
- vars_to_restore)
+ op, feed_dict = variables_lib2.assign_from_checkpoint(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1081,8 +1112,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
return saver.save(sess, checkpoint_dir, global_step=global_step)
def testLoadExistingVariables(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'load_existing_variables'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'load_existing_variables'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1097,8 +1128,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1111,8 +1142,9 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testLoadExistingVariablesDifferentShapeDefaultDoesNotAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(), 'load_existing_vars_no_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(),
+ 'load_existing_vars_no_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1127,8 +1159,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var1 = variables_lib2.variable('my_var1', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1138,9 +1170,10 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testLoadExistingVariablesDifferentShapeAllowReshape(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(
- self.get_temp_dir(),
- 'load_existing_variables_different_shape_allow_reshape'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(
+ self.get_temp_dir(),
+ 'load_existing_variables_different_shape_allow_reshape'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1169,8 +1202,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testNotFoundError(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'not_found_error'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'not_found_error'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1186,8 +1219,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
var2 = variables_lib2.variable('my_var2', shape=[])
vars_to_restore = {'v0': var0, 'v1': var1, 'v2': var2}
- init_fn = variables_lib2.assign_from_checkpoint_fn(model_path,
- vars_to_restore)
+ init_fn = variables_lib2.assign_from_checkpoint_fn(
+ model_path, vars_to_restore)
# Initialize the variables.
sess.run(variables_lib.global_variables_initializer())
@@ -1197,8 +1230,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
init_fn(sess)
def testMissingVariablesList(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_list'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_list'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1228,8 +1261,8 @@ class AssignFromCheckpointFnTest(test.TestCase):
self.assertEqual(init_value1, var1.eval())
def testMissingVariablesDict(self):
- model_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
- 'missing_variables_dict'))
+ model_dir = tempfile.mkdtemp(
+ prefix=os.path.join(self.get_temp_dir(), 'missing_variables_dict'))
if gfile.Exists(model_dir):
gfile.DeleteRecursively(model_dir)
@@ -1279,9 +1312,8 @@ class ZeroInitializerOpTest(test.TestCase):
def testZeroInitializer(self):
for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64):
for use_init in (False, True):
- self._testZeroInitializer(
- [10, 20], array_ops.ones(
- [10, 20], dtype=dtype), use_init)
+ self._testZeroInitializer([10, 20], array_ops.ones(
+ [10, 20], dtype=dtype), use_init)
class ZeroVarInitializerOpTest(test.TestCase):
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 0eb6889db1..0f0813c07f 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -75,6 +75,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:gpu_util_hdrs",
"//tensorflow/core/kernels:ops_util_hdrs",
"//third_party/eigen3",
+ "@local_config_cuda//cuda:cudnn_header",
],
alwayslink = 1,
)
@@ -94,6 +95,7 @@ tf_custom_op_library(
"//tensorflow/core/kernels:conv_ops_gpu_hdrs",
"//tensorflow/core/kernels:gpu_util_hdrs",
"//tensorflow/core/kernels:ops_util_hdrs",
+ "@local_config_cuda//cuda:cudnn_header",
],
)
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 3d0ed89932..4d62ac65ff 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -289,8 +289,8 @@ class FusedConv2DBiasActivationTest(test.TestCase):
conv = tensors[i]
value = values[i]
ref_value = ref_values[i]
- print("expected = ", ref_value)
- print("actual = ", value)
+ tf_logging.info("expected = ", ref_value)
+ tf_logging.info("actual = ", value)
tol = 1e-5
if value.dtype == np.float16:
tol = 1e-3
@@ -831,7 +831,8 @@ class FusedConvInt8Tests(test.TestCase):
vertical_stride, padding_type)
output_width = CalculateConvolvedOutputDim(input_width, filter_width,
horizontal_stride, padding_type)
- print("output_height=", output_height, ", output_width=", output_width)
+ tf_logging.info("output_height=", output_height, ", output_width=",
+ output_width)
side_input, _, _ = gen_array_ops.quantize_v2(
random_ops.random_uniform(
@@ -866,8 +867,8 @@ class FusedConvInt8Tests(test.TestCase):
with self.test_session(use_gpu=True) as sess:
actual_y, expected_y = sess.run([actual, expected])
- print("actual_y = ", actual_y)
- print("expected_y = ", expected_y)
+ tf_logging.info("actual_y = ", actual_y)
+ tf_logging.info("expected_y = ", expected_y)
self.assertTrue(np.array_equal(actual_y, expected_y))
def testFusedConvInt8(self):
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index b305f37791..10a8796bcb 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -45,6 +45,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
@@ -59,6 +60,7 @@ py_test(
deps = [
":features",
":namedtuples",
+ ":random_tensor_pool",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim:learning",
@@ -70,6 +72,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
@@ -188,6 +191,7 @@ py_test(
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":namedtuples",
":tuple_losses",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -344,9 +348,11 @@ py_library(
"//tensorflow/python:image_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "@six_archive//:six",
],
)
@@ -470,12 +476,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":head",
":namedtuples",
":summaries",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator",
@@ -498,16 +504,19 @@ py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
- "//tensorflow/python/estimator:head",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 4092b32004..8e4affb9b4 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -24,11 +24,11 @@ import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
-from tensorflow.contrib.gan.python.estimator.python import head as head_lib
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_inspect as inspect
@@ -154,94 +154,93 @@ class GANEstimator(estimator.Estimator):
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If loss functions aren't callable.
+ ValueError: If `use_loss_summaries` isn't boolean or `None`.
+ ValueError: If `get_hooks_fn` isn't callable or `None`.
"""
- # TODO(joelshor): Explicitly validate inputs.
+ if not callable(generator_loss_fn):
+ raise ValueError('generator_loss_fn must be callable.')
+ if not callable(discriminator_loss_fn):
+ raise ValueError('discriminator_loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
def _model_fn(features, labels, mode):
- gopt = (generator_optimizer() if callable(generator_optimizer) else
- generator_optimizer)
- dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
- else discriminator_optimizer)
- gan_head = head_lib.gan_head(
- generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn,
- get_eval_metric_ops_fn=get_eval_metric_ops_fn)
- return _gan_model_fn(
- features, labels, mode, generator_fn, discriminator_fn, gan_head,
+ """GANEstimator model function."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT]:
+ raise ValueError('Mode not recognized: %s' % mode)
+ real_data = labels # rename inputs for clarity
+ generator_inputs = features # rename inputs for clarity
+
+ # Make GANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
add_summaries)
+ # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
+ # metrics, and optimizers (if required).
+ return _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn)
+
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
-def _gan_model_fn(
- features,
- labels,
- mode,
- generator_fn,
- discriminator_fn,
- head,
- add_summaries=None,
- generator_scope_name='Generator'):
- """The `model_fn` for the GAN estimator.
-
- We make the following convention:
- features -> TFGAN's `generator_inputs`
- labels -> TFGAN's `real_data`
-
- Args:
- features: A dictionary to feed to generator. In the unconditional case,
- this might be just `noise`. In the conditional GAN case, this
- might be the generator's conditioning. The `generator_fn` determines
- what the required keys are.
- labels: Real data. Can be any structure, as long as `discriminator_fn`
- can accept it for the first argument.
- mode: Defines whether this is training, evaluation or prediction.
- See `ModeKeys`.
- generator_fn: A python lambda that takes `generator_inputs` as inputs and
- returns the outputs of the GAN generator.
- discriminator_fn: A python lambda that takes `real_data`/`generated data`
- and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
- head: A `Head` instance suitable for GANs.
- add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
- generator_scope_name: The name of the generator scope. We need this to be
- the same for GANModels produced by TFGAN's `train.gan_model` and the
- manually constructed ones for predictions.
-
- Returns:
- `ModelFnOps`
-
- Raises:
- ValueError: If `labels` isn't `None` during prediction.
- """
- real_data = labels
- generator_inputs = features
-
- if mode == model_fn_lib.ModeKeys.TRAIN:
- gan_model = _make_train_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- gan_model = _make_eval_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- else:
+def _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries, generator_scope='Generator'):
+ """Makes the GANModel tuple, which encapsulates the GAN model architecture."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
if real_data is not None:
raise ValueError('`labels` must be `None` when mode is `predict`. '
'Instead, found %s' % real_data)
gan_model = _make_prediction_gan_model(
- generator_inputs, generator_fn, generator_scope_name)
+ generator_inputs, generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(
+ generator_fn, discriminator_fn, real_data, generator_inputs,
+ generator_scope, add_summaries, mode)
- return head.create_estimator_spec(
- features=None,
- mode=mode,
- logits=gan_model,
- labels=None)
+ return gan_model
+
+
+def _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = tfgan_tuples.GANLoss(
+ generator_loss=generator_loss_fn(gan_model),
+ discriminator_loss=discriminator_loss_fn(gan_model))
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(
+ gan_model, gan_loss, get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (generator_optimizer() if callable(generator_optimizer) else
+ generator_optimizer)
+ dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
+ else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(
+ gan_model, gan_loss, gopt, dopt, get_hooks_fn)
+
+ return estimator_spec
def _make_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries, mode):
- """Make a `GANModel`, and optionally pass in `mode`."""
+ """Construct a `GANModel`, and optionally pass in `mode`."""
# If network functions have an argument `mode`, pass mode to it.
if 'mode' in inspect.getargspec(generator_fn).args:
generator_fn = functools.partial(generator_fn, mode=mode)
@@ -264,22 +263,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data,
return gan_model
-def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for training."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.TRAIN)
-
-
-def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for evaluation."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.EVAL)
-
-
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
"""Make a `GANModel` from just the generator."""
# If `generator_fn` has an argument `mode`, pass mode to it.
@@ -303,3 +286,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
discriminator_variables=None,
discriminator_scope=None,
discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss,
+ gan_loss.discriminator_loss]):
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(
+ gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 955482599b..9ac9c6ca9c 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -21,30 +21,30 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
from tensorflow.contrib import layers
-from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator
from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
-from tensorflow.python.training import monitored_session
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode):
return layers.fully_connected(data, 1)
-def mock_head(testcase, expected_generator_inputs, expected_real_data,
- generator_scope_name):
- """Returns a mock head that validates logits values and variable names."""
- discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults
- generator_var_names = set([
- '%s/fully_connected/weights:0' % generator_scope_name,
- '%s/fully_connected/biases:0' % generator_scope_name])
- discriminator_var_names = set([
- '%s/fully_connected/weights:0' % discriminator_scope_name,
- '%s/fully_connected/biases:0' % discriminator_scope_name])
-
- def _create_estimator_spec(features, mode, logits, labels):
- gan_model = logits # renaming for clarity
- is_predict = mode == model_fn_lib.ModeKeys.PREDICT
- testcase.assertIsNone(features)
- testcase.assertIsNone(labels)
- testcase.assertIsInstance(gan_model, namedtuples.GANModel)
-
- trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- expected_var_names = (generator_var_names if is_predict else
- generator_var_names | discriminator_var_names)
- testcase.assertItemsEqual(expected_var_names,
- [var.name for var in trainable_vars])
-
- assertions = []
- def _or_none(x):
- return None if is_predict else x
- testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs)
- # TODO(joelshor): Add check on `generated_data`.
- testcase.assertItemsEqual(
- generator_var_names,
- set([x.name for x in gan_model.generator_variables]))
- testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
- testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
- # TODO(joelshor): Add check on `discriminator_real_outputs`.
- # TODO(joelshor): Add check on `discriminator_gen_outputs`.
- if is_predict:
- testcase.assertIsNone(gan_model.discriminator_scope)
- else:
- testcase.assertEqual(discriminator_scope_name,
- gan_model.discriminator_scope.name)
-
- with ops.control_dependencies(assertions):
- if mode == model_fn_lib.ModeKeys.TRAIN:
- return model_fn_lib.EstimatorSpec(
- mode=mode, loss=array_ops.zeros([]),
- train_op=control_flow_ops.no_op(), training_hooks=[])
- elif mode == model_fn_lib.ModeKeys.EVAL:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data,
- loss=array_ops.zeros([]))
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data)
- else:
- testcase.fail('Invalid mode: {}'.format(mode))
-
- head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
- head.create_estimator_spec = test.mock.MagicMock(
- wraps=_create_estimator_spec)
-
- return head
-
-
-class GANModelFnTest(test.TestCase):
- """Tests that _gan_model_fn passes expected logits to mock head."""
-
- def setUp(self):
- self._model_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- if self._model_dir:
- writer_cache.FileWriterCache.clear()
- shutil.rmtree(self._model_dir)
+class GetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `GetGANModel` produces the correct model."""
- def _test_logits_helper(self, mode):
- """Tests that the expected logits are passed to mock head."""
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
with ops.Graph().as_default():
- training_util.get_or_create_global_step()
- generator_inputs = {'x': array_ops.zeros([5, 4])}
- real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
- array_ops.zeros([5, 4]))
- generator_scope_name = 'generator'
- head = mock_head(self,
- expected_generator_inputs=generator_inputs,
- expected_real_data=real_data,
- generator_scope_name=generator_scope_name)
- estimator_spec = estimator._gan_model_fn(
- features=generator_inputs,
- labels=real_data,
- mode=mode,
- generator_fn=generator_fn,
- discriminator_fn=discriminator_fn,
- generator_scope_name=generator_scope_name,
- head=head)
- with monitored_session.MonitoredTrainingSession(
- checkpoint_dir=self._model_dir) as sess:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- sess.run(estimator_spec.train_op)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- sess.run(estimator_spec.loss)
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- sess.run(estimator_spec.predictions)
- else:
- self.fail('Invalid mode: {}'.format(mode))
-
- def test_logits_predict(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT)
-
- def test_logits_eval(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.EVAL)
-
- def test_logits_train(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN)
+ generator_inputs = {'x': array_ops.ones([3, 4])}
+ real_data = (array_ops.zeros([3, 4]) if
+ mode != model_fn_lib.ModeKeys.PREDICT else None)
+ gan_model = estimator._get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries=False)
+
+ self.assertEqual(generator_inputs, gan_model.generator_inputs)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.real_data)
+ self.assertIsNone(gan_model.discriminator_real_outputs)
+ self.assertIsNone(gan_model.discriminator_gen_outputs)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertIsNotNone(gan_model.real_data)
+ self.assertIsNotNone(gan_model.discriminator_real_outputs)
+ self.assertIsNotNone(gan_model.discriminator_gen_outputs)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.GANModel(
+ generator_inputs=None,
+ generated_data=array_ops.ones([3, 4]),
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ real_data=array_ops.zeros([3, 4]),
+ discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
+ discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
+ gan_model.discriminator_gen_outputs)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ generator_loss_fn=dummy_loss_fn,
+ discriminator_loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
# TODO(joelshor): Add pandas test.
@@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase):
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
- def get_metrics(gan_model):
- return {
- 'mse_custom_metric': metrics_lib.mean_squared_error(
- gan_model.real_data, gan_model.generated_data)
- }
-
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index ff903a78cc..1a0ee6dfc4 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -24,18 +24,24 @@ from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
+from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.util import deprecation
__all__ = [
'GANHead',
'gan_head',
]
+
def _summary_key(head_name, val):
return '%s/%s' % (val, head_name) if head_name else val
+@deprecation.deprecated(
+ None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a '
+ 'GANHead.')
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
@@ -76,6 +82,9 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
class GANHead(head._Head): # pylint: disable=protected-access
"""`Head` for a GAN."""
+ @deprecation.deprecated(
+ None, 'Please use tf.contrib.gan.GANEstimator without explicitly making '
+ 'a GANHead.')
def __init__(self, generator_loss_fn, discriminator_loss_fn,
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
@@ -102,9 +111,20 @@ class GANHead(head._Head): # pylint: disable=protected-access
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
+
+ if not callable(generator_loss_fn):
+ raise TypeError('generator_loss_fn must be callable.')
+ if not callable(discriminator_loss_fn):
+ raise TypeError('discriminator_loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
+ if name is not None and not isinstance(name, str):
+ raise TypeError('name must be string.')
+
if get_hooks_fn is None:
get_hooks_fn = tfgan_train.get_sequential_train_hooks()
- # TODO(joelshor): Validate inputs.
if use_loss_summaries in [True, False]:
generator_loss_fn = functools.partial(
@@ -182,7 +202,10 @@ class GANHead(head._Head): # pylint: disable=protected-access
if mode == model_fn_lib.ModeKeys.PREDICT:
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.PREDICT,
- predictions=gan_model.generated_data)
+ predictions=gan_model.generated_data,
+ export_outputs={
+ 'predict': export_output.PredictOutput(gan_model.generated_data)
+ })
elif mode == model_fn_lib.ModeKeys.EVAL:
gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None)
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 6587f1fc60..8205bc889d 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -26,8 +26,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training
+_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
@@ -64,20 +67,22 @@ class GANHeadTest(test.TestCase):
generator_optimizer=training.GradientDescentOptimizer(1.0),
discriminator_optimizer=training.GradientDescentOptimizer(1.0),
get_eval_metric_ops_fn=self.get_metrics)
- self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ self.assertIsInstance(self.gan_head, head.GANHead)
def get_metrics(self, gan_model):
self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
return {}
def _test_modes_helper(self, mode):
- self.gan_head.create_estimator_spec(
+ return self.gan_head.create_estimator_spec(
features=None,
mode=mode,
logits=get_gan_model())
def test_modes_predict(self):
- self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
+ spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
+ self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'),
+ spec.export_outputs.keys())
def test_modes_eval(self):
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 81e70ae30a..1435e19109 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -34,8 +34,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/process_state.h"
#endif // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
@@ -274,7 +275,7 @@ Status GdrMemoryManager::Init() {
Allocator* allocators[] = {
#if GOOGLE_CUDA
- ProcessState::singleton()->GetCUDAHostAllocator(0),
+ GPUProcessState::singleton()->GetCUDAHostAllocator(0),
ProcessState::singleton()->GetCPUAllocator(0),
#endif // GOOGLE_CUDA
cpu_allocator(),
@@ -308,7 +309,8 @@ Status GdrMemoryManager::Init() {
if (IsGDRAvailable()) {
// Note we don't free allocated GPU memory so there is no free visitor
int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
- ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
}
#endif // GOOGLE_CUDA
@@ -430,7 +432,7 @@ void GdrMemoryManager::TransportOptionsFromTensor(
#if GOOGLE_CUDA
if (!on_host) {
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
GPUUtil::CopyGPUTensorToCPU(
device, device_context, &tensor, host_copy,
@@ -532,7 +534,7 @@ void GdrMemoryManager::TensorFromTransportOptions(
Tensor host_copy;
#if GOOGLE_CUDA
if (mr == nullptr && !on_host) {
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
buffer = DMAHelper::buffer(&host_copy);
addr = buffer->data();
diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc
index 1f9dd0decb..9025c992a4 100644
--- a/tensorflow/contrib/gdr/gdr_server_lib.cc
+++ b/tensorflow/contrib/gdr/gdr_server_lib.cc
@@ -57,7 +57,7 @@ Status GdrServer::Init() {
new GdrWorker(env, remote_memory_manager_.get()));
};
TF_RETURN_IF_ERROR(
- GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func));
+ GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func));
return remote_memory_manager_->Init();
}
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 592d37b432..026a3d1200 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None):
if op._original_op:
op_._original_op = op._original_op
- # Add op to the graph
- info.graph_._add_op(op_)
-
return op_, op_.outputs
@@ -492,7 +489,7 @@ class Transformer(object):
t_ = info.transformed_ts[t]
consumer_op_ = info.transformed_ops[consumer_op]
t_index_ = list(consumer_op_.inputs).index(tmp_t_)
- consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access
+ consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access
def _connect_control_inputs(self, info):
"""Connect the previously copied ops."""
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index c2e32da133..022e17d139 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -35,6 +35,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
template struct FillProjectiveTransform<CPUDevice, uint8>;
template struct FillProjectiveTransform<CPUDevice, int32>;
template struct FillProjectiveTransform<CPUDevice, int64>;
+template struct FillProjectiveTransform<CPUDevice, Eigen::half>;
template struct FillProjectiveTransform<CPUDevice, float>;
template struct FillProjectiveTransform<CPUDevice, double>;
@@ -99,6 +100,7 @@ class ImageProjectiveTransform : public OpKernel {
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index ad50133061..209aa24548 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
@@ -58,6 +59,11 @@ class ProjectiveGenerator {
? transforms_.data()
: &transforms_.data()[transforms_.dimension(1) * coords[0]];
float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
+ if (projection == 0) {
+ // Return the fill value (0) for infinite coordinates,
+ // which are outside the input image
+ return T(0);
+ }
const float input_x =
(transform[0] * output_x + transform[1] * output_y + transform[2]) /
projection;
@@ -105,21 +111,21 @@ class ProjectiveGenerator {
// f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
const float value_yfloor =
- (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
- DenseIndex(x_floor), channel,
- fill_value) +
- (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
- DenseIndex(x_ceil), channel,
- fill_value);
+ (x_ceil - x) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_floor), DenseIndex(x_floor),
+ channel, fill_value)) +
+ (x - x_floor) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_floor), DenseIndex(x_ceil),
+ channel, fill_value));
// f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
// + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
const float value_yceil =
- (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
- DenseIndex(x_floor), channel,
- fill_value) +
- (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
- DenseIndex(x_ceil), channel,
- fill_value);
+ (x_ceil - x) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_ceil), DenseIndex(x_floor),
+ channel, fill_value)) +
+ (x - x_floor) * static_cast<float>(read_with_fill_value(
+ batch, DenseIndex(y_ceil), DenseIndex(x_ceil),
+ channel, fill_value));
// f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
// + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index ebdcaea7ab..e59f1bf844 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -29,7 +29,7 @@ using shape_inference::ShapeHandle;
REGISTER_OP("ImageProjectiveTransform")
.Input("images: dtype")
.Input("transforms: float32")
- .Attr("dtype: {uint8, int32, int64, float32, float64}")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
.Attr("interpolation: string")
.Output("transformed_images: dtype")
.SetShapeFn([](InferenceContext* c) {
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index b50177ae56..62a22dcf34 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -30,7 +30,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
_DTYPES = set(
- [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
+ [dtypes.uint8, dtypes.int32, dtypes.int64,
+ dtypes.float16, dtypes.float32, dtypes.float64])
class ImageOpsTest(test_util.TensorFlowTestCase):
@@ -127,6 +128,23 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
[0, 1, 0, 1],
[0, 1, 1, 1]])
+ def test_extreme_projective_transform(self):
+ for dtype in _DTYPES:
+ with self.test_session():
+ image = constant_op.constant(
+ [[1, 0, 1, 0],
+ [0, 1, 0, 1],
+ [1, 0, 1, 0],
+ [0, 1, 0, 1]], dtype=dtype)
+ transformation = constant_op.constant([1, 0, 0, 0, 1, 0, -1, 0],
+ dtypes.float32)
+ image_transformed = image_ops.transform(image, transformation)
+ self.assertAllEqual(image_transformed.eval(),
+ [[1, 0, 0, 0],
+ [0, 0, 0, 0],
+ [1, 0, 0, 0],
+ [0, 0, 0, 0]])
+
def test_bilinear(self):
with self.test_session():
image = constant_op.constant(
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index cd984c8054..86b0ffe9a0 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -33,7 +33,8 @@ _image_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_image_ops.so"))
_IMAGE_DTYPES = set(
- [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64])
+ [dtypes.uint8, dtypes.int32, dtypes.int64,
+ dtypes.float16, dtypes.float32, dtypes.float64])
ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py
index b4a99867ed..61f78febfc 100644
--- a/tensorflow/contrib/integrate/python/ops/odes.py
+++ b/tensorflow/contrib/integrate/python/ops/odes.py
@@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
@@ -279,13 +278,27 @@ def _assert_increasing(t):
return ops.control_dependencies([assert_increasing])
-def _check_input_types(t, y0):
+def _check_input_types(y0, t, dt=None):
if not (y0.dtype.is_floating or y0.dtype.is_complex):
raise TypeError('`y0` must have a floating point or complex floating '
'point dtype')
if not t.dtype.is_floating:
raise TypeError('`t` must have a floating point dtype')
+ if dt is not None and not dt.dtype.is_floating:
+ raise TypeError('`dt` must have a floating point dtype')
+
+
+def _check_input_sizes(t, dt):
+ if len(t.get_shape().as_list()) > 1:
+ raise ValueError('t must be a 1D tensor')
+
+ if len(dt.get_shape().as_list()) > 1:
+ raise ValueError('t must be a 1D tensor')
+
+ if t.get_shape()[0] != dt.get_shape()[0] + 1:
+ raise ValueError('t and dt have incompatible lengths, must be N and N-1')
+
def _dopri5(func,
y0,
@@ -510,7 +523,7 @@ def odeint(func,
# avoiding the need to pack/unpack in user functions.
y0 = ops.convert_to_tensor(y0, name='y0')
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
- _check_input_types(t, y0)
+ _check_input_types(y0, t)
error_dtype = abs(y0).dtype
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
@@ -530,24 +543,74 @@ def odeint(func,
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
"""Base class for fixed-grid ODE integrators."""
- def integrate(self, evol_func, y0, time_grid):
- time_delta_grid = time_grid[1:] - time_grid[:-1]
-
- scan_func = self._make_scan_func(evol_func)
+ def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
+ """Returns integrated values of differential equation on the `time grid`.
+
+ Numerically integrates differential equation defined via time derivative
+ evaluator `evol_func` using fixed time steps specified in dt_grid.
+
+ Args:
+ evol_func: Callable, evaluates time derivative of y at a given time.
+ y0: N-D Tensor holds initial values of the solution.
+ time_grid: 1-D Tensor holding the time points at which the solution
+ will be recorded, must have a floating dtype.
+ dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
+ intervals. Must be a floating dtype and have one less element than that
+ of the time_grid.
+ steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
+ as dt_grid. Specifies number of steps needed for every interval. Assumes
+ steps_on_intervals * dt_grid == time intervals.
+
+ Returns:
+ (N+1)-D tensor, where the first dimension corresponds to different
+ time points. Contains the solved value of y for each desired time point in
+ `t`, with the initial value `y0` being the first element along the first
+ dimension.
+ """
- y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid),
- y0)
- return array_ops.concat([[y0], y_grid], axis=0)
+ iteration_func = self._make_iteration_func(evol_func, dt_grid)
+ integrate_interval = self._make_interval_integrator(iteration_func,
+ steps_on_intervals)
- def _make_scan_func(self, evol_func):
+ num_times = array_ops.size(time_grid)
+ current_time = time_grid[0]
+ solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times)
+ solution_array = solution_array.write(0, y0)
- def scan_func(y, t_and_dt):
- t, dt = t_and_dt
+ solution_array, _, _, _ = control_flow_ops.while_loop(
+ lambda _, __, ___, i: i < num_times,
+ integrate_interval,
+ (solution_array, y0, current_time, 1)
+ )
+ solution_array = solution_array.stack()
+ solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
+ return solution_array
+
+ def _make_iteration_func(self, evol_func, dt_grid):
+ """Returns a function that builds operations of a single time step."""
+
+ def iteration_func(y, t, dt_step, interval_step):
+ """Performs a single time step advance."""
+ dt = dt_grid[interval_step - 1]
dy = self._step_func(evol_func, t, dt, y)
dy = math_ops.cast(dy, dtype=y.dtype)
- return y + dy
+ return y + dy, t + dt, dt_step + 1, interval_step
+
+ return iteration_func
+
+ def _make_interval_integrator(self, iteration_func, interval_sizes):
+ """Returns a function that builds operations for interval integration."""
- return scan_func
+ def integrate_interval(solution_array, y, t, interval_num):
+ """Integrates y with fixed time step on interval `interval_num`."""
+ y, t, _, _ = control_flow_ops.while_loop(
+ lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1],
+ iteration_func,
+ (y, t, 0, interval_num)
+ )
+ return solution_array.write(interval_num, y), y, t, interval_num + 1
+
+ return integrate_interval
@abc.abstractmethod
def _step_func(self, evol_func, t, dt, y):
@@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
+ """Fixed grid integrator implementing midpoint scheme."""
def _step_func(self, evol_func, t, dt, y):
dt_cast = math_ops.cast(dt, y.dtype)
@@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
+ """Fixed grid integrator implementing RK4 scheme."""
def _step_func(self, evol_func, t, dt, y):
k1 = evol_func(y, t)
@@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator):
return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)
-def odeint_fixed(func, y0, t, method='rk4', name=None):
+def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
"""ODE integration on a fixed grid (with no step size control).
Useful in certain scenarios to avoid the overhead of adaptive step size
@@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
`y`. The initial time point should be the first element of this sequence,
and each time must be larger than the previous time. May have any floating
point dtype.
+ dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
+ integration intervals in `t`. 1-D Tensor should provide values
+ for all intervals, must have 1 less element than that of `t`.
+ If given a 0-D Tensor, the value is interpreted as time step suggestion
+ same for all intervals. If passed None, then time step is set to be the
+ t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
+ insuring an integer number of steps per interval, potentially reducing the
+ time step.
method: One of 'midpoint' or 'rk4'.
name: Optional name for the resulting operation.
@@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
Raises:
ValueError: Upon caller errors.
"""
- with ops.name_scope(name, 'odeint_fixed', [y0, t]):
+ with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
y0 = ops.convert_to_tensor(y0, name='y0')
- _check_input_types(t, y0)
+
+ intervals = t[1:] - t[:-1]
+ if dt is None:
+ dt = intervals
+ dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt')
+
+ steps_on_intervals = math_ops.ceil(intervals / dt)
+ dt = intervals / steps_on_intervals
+ steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32)
+
+ _check_input_types(y0, t, dt)
+ _check_input_sizes(t, dt)
with _assert_increasing(t):
with ops.name_scope(method):
if method == 'midpoint':
- return _MidpointFixedGridIntegrator().integrate(func, y0, t)
+ return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt,
+ steps_on_intervals)
elif method == 'rk4':
- return _RK4FixedGridIntegrator().integrate(func, y0, t)
+ return _RK4FixedGridIntegrator().integrate(func, y0, t, dt,
+ steps_on_intervals)
else:
raise ValueError('method not supported: {!s}'.format(method))
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index 3ec01212d2..c7b4e2faa8 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase):
class OdeIntFixedTest(test.TestCase):
- def _test_integrate_sine(self, method):
+ def _test_integrate_sine(self, method, t, dt=None):
def evol_func(y, t):
del t
return array_ops.stack([y[1], -y[0]])
y0 = [0., 1.]
- time_grid = np.linspace(0., 10., 200)
- y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
+ y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
with self.test_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
- y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)
+ y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2)
- def _test_integrate_gaussian(self, method):
+ def _test_integrate_gaussian(self, method, t, dt=None):
def evol_func(y, t):
return -math_ops.cast(t, dtype=y.dtype) * y[0]
y0 = [1.]
- time_grid = np.linspace(0., 2., 100)
- y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
+ y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
with self.test_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
- y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)
+ y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2)
+
+ def _test_integrate_sine_all(self, method):
+ uniform_time_grid = np.linspace(0., 10., 200)
+ non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0])
+ uniform_dt = 0.02
+ non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03])
+ self._test_integrate_sine(method, uniform_time_grid)
+ self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt)
+ self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt)
+
+ def _test_integrate_gaussian_all(self, method):
+ uniform_time_grid = np.linspace(0., 2., 100)
+ non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0])
+ uniform_dt = 0.01
+ non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03])
+ self._test_integrate_gaussian(method, uniform_time_grid)
+ self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt)
+ self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt)
def _test_everything(self, method):
- self._test_integrate_sine(method)
- self._test_integrate_gaussian(method)
+ self._test_integrate_sine_all(method)
+ self._test_integrate_gaussian_all(method)
def test_midpoint(self):
self._test_everything('midpoint')
@@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase):
def test_rk4(self):
self._test_everything('rk4')
+ def test_dt_size_exceptions(self):
+ times = np.linspace(0., 2., 100)
+ dt = np.ones(99) * 0.01
+ dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03])
+ dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0)
+ times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0)
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times, dt_wrong_length)
+
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times, dt_wrong_dim)
+
+ with self.assertRaises(ValueError):
+ self._test_integrate_gaussian('midpoint', times_wrong_dim, dt)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc
new file mode 100644
index 0000000000..8cdf16103b
--- /dev/null
+++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KafkaDataset")
+ .Input("topics: string")
+ .Input("servers: string")
+ .Input("group: string")
+ .Input("eof: bool")
+ .Input("timeout: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the messages of one or more Kafka topics.
+
+topics: A `tf.string` tensor containing one or more subscriptions,
+ in the format of [topic:partition:offset:length],
+ by default length is -1 for unlimited.
+servers: A list of bootstrap servers.
+group: The consumer group id.
+eof: If True, the kafka reader will stop on EOF.
+timeout: The timeout value for the Kafka Consumer to wait
+ (in millisecond).
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py
index 938c881fcb..3327a9f9a6 100644
--- a/tensorflow/contrib/keras/api/keras/layers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py
@@ -20,10 +20,10 @@ from __future__ import print_function
# Generic layers.
# pylint: disable=g-bad-import-order
-from tensorflow.python.keras.engine import Input
-from tensorflow.python.keras.engine import InputLayer
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_layer import Input
+from tensorflow.python.keras.engine.input_layer import InputLayer
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 762a2f0b57..102626925d 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,5 +1,10 @@
# K-FAC: Kronecker-Factored Approximate Curvature
+# <font color="red", size=10><u>WARNING: </u></font>
+# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
+# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.==
+# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->==
+
**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
approximate second-order optimization method, in TensorFlow. When applied to
feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
index b7f63d8d94..03b9da7933 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ b/tensorflow/contrib/kfac/python/ops/optimizer.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import warnings
+
# pylint disable=long-line
from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
from tensorflow.contrib.kfac.python.ops import estimator as est
@@ -107,6 +109,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
ValueError: If momentum is non-zero and momentum_type is not 'regular'
or 'adam'.
"""
+ warnings.warn(
+ "third_party.tensorflow.contrib.kfac is deprecated."
+ "This will be removed on 15-07-2018. Check README for further details.",
+ DeprecationWarning)
# Parameters to be passed to the Fisher estimator:
self._variables = var_list or tf_variables.trainable_variables
self._cov_ema_decay = cov_ema_decay
diff --git a/tensorflow/contrib/kinesis/BUILD b/tensorflow/contrib/kinesis/BUILD
new file mode 100644
index 0000000000..25443d0ad4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/BUILD
@@ -0,0 +1,113 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "kinesis",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/kinesis_dataset_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core/platform/s3:aws_crypto",
+ "//third_party/eigen3",
+ "@aws",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/kinesis_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kinesis_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "kinesis_op_loader",
+ srcs = ["python/ops/kinesis_op_loader.py"],
+ dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+tf_py_test(
+ name = "kinesis_test",
+ srcs = ["python/kernel_tests/kinesis_test.py"],
+ additional_deps = [
+ ":kinesis",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/python/training/checkpointable/data_structures_base.py b/tensorflow/contrib/kinesis/__init__.py
index f1b2cf105b..3824b8ae75 100644
--- a/tensorflow/python/training/checkpointable/data_structures_base.py
+++ b/tensorflow/contrib/kinesis/__init__.py
@@ -1,4 +1,3 @@
-"""A trivial base class to avoid circular imports for isinstance checks."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,15 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Kinesis Dataset.
+
+@@KinesisDataset
+"""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
-
+from tensorflow.python.util.all_util import remove_undocumented
-class CheckpointableDataStructureBase(checkpointable_lib.CheckpointableBase):
- """Base class for data structures which contain checkpointable objects."""
+_allowed_symbols = [
+ "KinesisDataset",
+]
- pass
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
new file mode 100644
index 0000000000..3212279c4c
--- /dev/null
+++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
@@ -0,0 +1,359 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <aws/core/Aws.h>
+#include <aws/core/config/AWSProfileConfigLoader.h>
+#include <aws/core/utils/Outcome.h>
+#include <aws/kinesis/KinesisClient.h>
+#include <aws/kinesis/model/DescribeStreamRequest.h>
+#include <aws/kinesis/model/GetRecordsRequest.h>
+#include <aws/kinesis/model/GetShardIteratorRequest.h>
+#include <aws/kinesis/model/PutRecordsRequest.h>
+#include <aws/kinesis/model/ShardIteratorType.h>
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
+
+namespace tensorflow {
+namespace {
+
+Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration config;
+ const char* endpoint = getenv("KINESIS_ENDPOINT");
+ if (endpoint) {
+ config.endpointOverride = Aws::String(endpoint);
+ }
+ const char* region = getenv("AWS_REGION");
+ if (region) {
+ config.region = Aws::String(region);
+ } else {
+ // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
+ // is set with a truthy value.
+ const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
+ string load_config =
+ load_config_env ? str_util::Lowercase(load_config_env) : "";
+ if (load_config == "true" || load_config == "1") {
+ Aws::String config_file;
+ // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
+ const char* config_file_env = getenv("AWS_CONFIG_FILE");
+ if (config_file_env) {
+ config_file = config_file_env;
+ } else {
+ const char* home_env = getenv("HOME");
+ if (home_env) {
+ config_file = home_env;
+ config_file += "/.aws/config";
+ }
+ }
+ Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
+ // Load the configuration. If successful, get the region.
+ // If the load is not successful, then generate a warning.
+ if (loader.Load()) {
+ auto profiles = loader.GetProfiles();
+ if (!profiles["default"].GetRegion().empty()) {
+ config.region = profiles["default"].GetRegion();
+ }
+ } else {
+ LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
+ }
+ }
+ }
+ const char* use_https = getenv("KINESIS_USE_HTTPS");
+ if (use_https) {
+ if (use_https[0] == '0') {
+ config.scheme = Aws::Http::Scheme::HTTP;
+ } else {
+ config.scheme = Aws::Http::Scheme::HTTPS;
+ }
+ }
+ const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
+ if (verify_ssl) {
+ if (verify_ssl[0] == '0') {
+ config.verifySSL = false;
+ } else {
+ config.verifySSL = true;
+ }
+ }
+ const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
+ if (connect_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(connect_timeout, &timeout)) {
+ config.connectTimeoutMs = timeout;
+ }
+ }
+ const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
+ if (request_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(request_timeout, &timeout)) {
+ config.requestTimeoutMs = timeout;
+ }
+ }
+
+ return &config;
+}
+
+Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
+ static Aws::Client::ClientConfiguration* config =
+ InitializeDefaultClientConfig();
+ return *config;
+}
+
+static mutex mu(LINKER_INITIALIZED);
+static unsigned count(0);
+void AwsInitAPI() {
+ mutex_lock lock(mu);
+ count++;
+ if (count == 1) {
+ Aws::SDKOptions options;
+ options.cryptoOptions.sha256Factory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
+ };
+ options.cryptoOptions.sha256HMACFactory_create_fn = []() {
+ return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
+ };
+ Aws::InitAPI(options);
+ }
+}
+void AwsShutdownAPI() {
+ mutex_lock lock(mu);
+ count--;
+ if (count == 0) {
+ Aws::SDKOptions options;
+ Aws::ShutdownAPI(options);
+ }
+}
+void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
+ if (client != nullptr) {
+ delete client;
+ AwsShutdownAPI();
+ }
+}
+}
+class KinesisDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ std::string stream = "";
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<std::string>(ctx, "stream", &stream));
+ std::string shard = "";
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<std::string>(ctx, "shard", &shard));
+ bool read_indefinitely = true;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "read_indefinitely",
+ &read_indefinitely));
+ int64 interval = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "interval", &interval));
+ OP_REQUIRES(ctx, (interval > 0),
+ errors::InvalidArgument(
+ "Interval value should be large than 0, got ", interval));
+ *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
+ const bool read_indefinitely, const int64 interval)
+ : GraphDatasetBase(ctx),
+ stream_(stream),
+ shard_(shard),
+ read_indefinitely_(read_indefinitely),
+ interval_(interval) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* stream = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
+ Node* shard = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
+ Node* read_indefinitely = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
+ Node* interval = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {stream, shard, read_indefinitely, interval}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ client_(nullptr, ShutdownClient) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (iterator_ == "") {
+ TF_RETURN_IF_ERROR(SetupStreamsLocked());
+ }
+ do {
+ Aws::Kinesis::Model::GetRecordsRequest request;
+ auto outcome = client_->GetRecords(
+ request.WithShardIterator(iterator_).WithLimit(1));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ if (outcome.GetResult().GetRecords().size() == 0) {
+ // If no records were returned then nothing is available at the
+ // moment.
+ if (!dataset()->read_indefinitely_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ // Continue the loop after a period of time.
+ ctx->env()->SleepForMicroseconds(dataset()->interval_);
+ continue;
+ }
+ if (outcome.GetResult().GetRecords().size() != 1) {
+ return errors::Unknown("invalid number of records ",
+ outcome.GetResult().GetRecords().size(),
+ " returned");
+ }
+
+ iterator_ = outcome.GetResult().GetNextShardIterator();
+
+ const auto& data = outcome.GetResult().GetRecords()[0].GetData();
+ StringPiece value(
+ reinterpret_cast<const char*>(data.GetUnderlyingData()),
+ data.GetLength());
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<std::string>()() = std::string(value);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ *end_of_sequence = false;
+ return Status::OK();
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented("SaveInternal is currently not supported");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "RestoreInternal is currently not supported");
+ }
+
+ private:
+ // Sets up Kinesis streams to read from.
+ Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ AwsInitAPI();
+ client_.reset(
+ new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
+
+ Aws::Kinesis::Model::DescribeStreamRequest request;
+ auto outcome = client_->DescribeStream(
+ request.WithStreamName(dataset()->stream_.c_str()));
+ if (!outcome.IsSuccess()) {
+ return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
+ outcome.GetError().GetMessage());
+ }
+ Aws::String shard;
+ Aws::String sequence;
+ if (dataset()->shard_ == "") {
+ if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
+ 1) {
+ return errors::InvalidArgument(
+ "shard has to be provided unless the stream only have one "
+ "shard, there are ",
+ outcome.GetResult().GetStreamDescription().GetShards().size(),
+ " shards in stream ", dataset()->stream_);
+ }
+ shard = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetShardId();
+ sequence = outcome.GetResult()
+ .GetStreamDescription()
+ .GetShards()[0]
+ .GetSequenceNumberRange()
+ .GetStartingSequenceNumber();
+ } else {
+ for (const auto& entry :
+ outcome.GetResult().GetStreamDescription().GetShards()) {
+ if (entry.GetShardId() == dataset()->shard_.c_str()) {
+ shard = entry.GetShardId();
+ sequence =
+ entry.GetSequenceNumberRange().GetStartingSequenceNumber();
+ break;
+ }
+ }
+ if (shard == "") {
+ return errors::InvalidArgument("no shard ", dataset()->shard_,
+ " in stream ", dataset()->stream_);
+ }
+ }
+
+ Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
+ auto iterator_outcome = client_->GetShardIterator(
+ iterator_request.WithStreamName(dataset()->stream_.c_str())
+ .WithShardId(shard)
+ .WithShardIteratorType(
+ Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
+ .WithStartingSequenceNumber(sequence));
+ if (!iterator_outcome.IsSuccess()) {
+ return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
+ ": ",
+ iterator_outcome.GetError().GetMessage());
+ }
+ iterator_ = iterator_outcome.GetResult().GetShardIterator();
+ return Status::OK();
+ }
+
+ mutex mu_;
+ Aws::String iterator_ GUARDED_BY(mu_);
+ std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
+ client_ GUARDED_BY(mu_);
+ };
+
+ const std::string stream_;
+ const std::string shard_;
+ const bool read_indefinitely_;
+ const int64 interval_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
+ KinesisDatasetOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/kinesis/ops/dataset_ops.cc b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
new file mode 100644
index 0000000000..54204513cf
--- /dev/null
+++ b/tensorflow/contrib/kinesis/ops/dataset_ops.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KinesisDataset")
+ .Input("stream: string")
+ .Input("shard: string")
+ .Input("read_indefinitely: bool")
+ .Input("interval: int64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that emits the messages of one or more Kinesis topics.
+
+stream: A `tf.string` tensor containing the name of the stream.
+shard: A `tf.string` tensor containing the id of the shard.
+read_indefinitely: If `True`, the Kinesis dataset will keep retry
+ again on `EOF` after the `interval` period. If `False`, then
+ the dataset will stop on `EOF`. The default value is `True`.
+interval: The interval for the Kinesis Client to wait before
+ it tries to get records again (in millisecond).
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
new file mode 100644
index 0000000000..7289b45c50
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -0,0 +1,139 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for KinesisDataset.
+NOTE: boto3 is needed and the test has to be invoked manually:
+```
+$ bazel test -s --verbose_failures --config=opt \
+ --action_env=AWS_ACCESS_KEY_ID=XXXXXX \
+ --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \
+ //tensorflow/contrib/kinesis:kinesis_test
+```
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import boto3
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class KinesisDatasetTest(test.TestCase):
+
+ def testKinesisDatasetOneShard(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Setup the Kinesis with 1 shard.
+ stream_name = "tf_kinesis_test_1"
+ client.create_stream(StreamName=stream_name, ShardCount=1)
+ # Wait until stream exists, default is 10 * 18 seconds.
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ # Basic test: read from shard 0 of stream 1.
+ sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
+ for i in range(10):
+ self.assertEqual("D" + str(i), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until stream deleted, default is 10 * 18 seconds.
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+ def testKinesisDatasetTwoShards(self):
+ client = boto3.client('kinesis', region_name='us-east-1')
+
+ # Setup the Kinesis with 2 shards.
+ stream_name = "tf_kinesis_test_2"
+ client.create_stream(StreamName=stream_name, ShardCount=2)
+ # Wait until stream exists, default is 10 * 18 seconds.
+ client.get_waiter('stream_exists').wait(StreamName=stream_name)
+
+ for i in range(10):
+ data = "D" + str(i)
+ client.put_record(
+ StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
+ response = client.describe_stream(StreamName=stream_name)
+ shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
+ shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]
+
+ stream = array_ops.placeholder(dtypes.string, shape=[])
+ shard = array_ops.placeholder(dtypes.string, shape=[])
+ num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = kinesis_dataset_ops.KinesisDataset(
+ stream, shard, read_indefinitely=False).repeat(num_epochs)
+ batch_dataset = repeat_dataset.batch(batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ init_op = iterator.make_initializer(repeat_dataset)
+ init_batch_op = iterator.make_initializer(batch_dataset)
+ get_next = iterator.get_next()
+
+ data = list()
+ with self.test_session() as sess:
+ # Basic test: read from shard 0 of stream 2.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_0, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ # Basic test: read from shard 1 of stream 2.
+ sess.run(
+ init_op, feed_dict={
+ stream: stream_name, shard: shard_id_1, num_epochs: 1})
+ with self.assertRaises(errors.OutOfRangeError):
+ # Use range(11) to guarantee the OutOfRangeError.
+ for i in range(11):
+ data.append(sess.run(get_next))
+
+ data.sort()
+ self.assertEqual(data, ["D" + str(i) for i in range(10)])
+
+ client.delete_stream(StreamName=stream_name)
+ # Wait until stream deleted, default is 10 * 18 seconds.
+ client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
new file mode 100644
index 0000000000..ca2df95ba4
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Kinesis Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class KinesisDataset(Dataset):
+ """A Kinesis Dataset that consumes the message.
+
+ Kinesis is a managed service provided by AWS for data streaming.
+ This dataset reads messages from Kinesis with each message presented
+ as a `tf.string`.
+
+ For example, we can construct and use the KinesisDataset as follows:
+ ```python
+ dataset = tf.contrib.kinesis.KinesisDataset(
+ "kinesis_stream_name", read_indefinitely=False)
+ next = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(nxt))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Since Kinesis is a data streaming service, data may not be available
+ at the time it is being read. The argument `read_indefinitely` is
+ used to control the behavior in this situation. If `read_indefinitely`
+ is `True`, then `KinesisDataset` will keep retrying to retrieve data
+ from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError`
+ is returned immediately instead.
+ """
+
+ def __init__(self,
+ stream,
+ shard="",
+ read_indefinitely=True,
+ interval=100000):
+ """Create a KinesisDataset.
+
+ Args:
+ stream: A `tf.string` tensor containing the name of the stream.
+ shard: A `tf.string` tensor containing the id of the shard.
+ read_indefinitely: If `True`, the Kinesis dataset will keep retry
+ again on `EOF` after the `interval` period. If `False`, then
+ the dataset will stop on `EOF`. The default value is `True`.
+ interval: The interval for the Kinesis Client to wait before
+ it tries to get records again (in millisecond).
+ """
+ super(KinesisDataset, self).__init__()
+ self._stream = ops.convert_to_tensor(
+ stream, dtype=dtypes.string, name="stream")
+ self._shard = ops.convert_to_tensor(
+ shard, dtype=dtypes.string, name="shard")
+ self._read_indefinitely = ops.convert_to_tensor(
+ read_indefinitely, dtype=dtypes.bool, name="read_indefinitely")
+ self._interval = ops.convert_to_tensor(
+ interval, dtype=dtypes.int64, name="interval")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.kinesis_dataset(
+ self._stream, self._shard, self._read_indefinitely, self._interval)
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.string
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
new file mode 100644
index 0000000000..c9ce9f3646
--- /dev/null
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python helper for loading kinesis ops and kernels."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
index 3ba1026383..2ede5daee7 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py
@@ -652,7 +652,8 @@ def map_fn(fn, labeled_tensor, name=None):
tensor_lt = core.LabeledTensor(tensor, original_axes)
return fn(tensor_lt).tensor
- map_op = functional_ops.map_fn(tf_fn, labeled_tensor.tensor)
+ map_op = functional_ops.map_fn(
+ tf_fn, labeled_tensor.tensor, dtype=first_map_lt.dtype)
map_lt = core.LabeledTensor(map_op, final_axes)
return core.identity(map_lt, name=scope)
diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py
index 00f03a111a..bc33596935 100644
--- a/tensorflow/contrib/layers/__init__.py
+++ b/tensorflow/contrib/layers/__init__.py
@@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide.
@@avg_pool2d
@@avg_pool3d
@@batch_norm
+@@convolution
+@@convolution1d
@@convolution2d
@@convolution3d
@@conv2d_in_plane
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index dd2395f8c9..7ede193029 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import itertools
import math
-import sys
import numpy as np
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
index 06060b99e7..a85cff4f70 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py
@@ -683,11 +683,12 @@ def parse_feature_columns_from_sequence_examples(
the serialized proto.
Returns:
- A tuple consisting of:
- context_features: a dict mapping `FeatureColumns` from
- `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
- sequence_features: a dict mapping `FeatureColumns` from
- `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
+ A tuple consisting of (context_features, sequence_features)
+
+ * context_features: a dict mapping `FeatureColumns` from
+ `context_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
+ * sequence_features: a dict mapping `FeatureColumns` from
+ `sequence_feature_columns` to their parsed `Tensors`/`SparseTensor`s.
"""
# Sequence example parsing requires a single (scalar) example.
try:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index a55d42c151..beeabd6b65 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages
__all__ = [
'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d',
'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution',
- 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose',
- 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse',
- 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn',
- 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d',
+ 'convolution1d', 'convolution2d', 'convolution2d_in_plane',
+ 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose',
+ 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN',
+ 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d',
'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat',
'scale_gradient', 'separable_conv2d', 'separable_convolution2d',
'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm',
@@ -2664,6 +2664,7 @@ def separable_convolution2d(
normalizer_fn=None,
normalizer_params=None,
weights_initializer=initializers.xavier_initializer(),
+ pointwise_initializer=None,
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
@@ -2705,7 +2706,9 @@ def separable_convolution2d(
`biases_regularizer` are ignored and `biases` are not created nor added.
default set to None for no normalizer function
normalizer_params: Normalization function parameters.
- weights_initializer: An initializer for the weights.
+ weights_initializer: An initializer for the depthwise weights.
+ pointwise_initializer: An initializer for the pointwise weights.
+ default set to None, means use weights_initializer.
weights_regularizer: Optional regularizer for the weights.
biases_initializer: An initializer for the biases. If None skip biases.
biases_regularizer: Optional regularizer for the biases.
@@ -2737,6 +2740,9 @@ def separable_convolution2d(
custom_getter=layer_variable_getter) as sc:
inputs = ops.convert_to_tensor(inputs)
+ if pointwise_initializer is None:
+ pointwise_initializer = weights_initializer
+
df = ('channels_first'
if data_format and data_format.startswith('NC') else 'channels_last')
if num_outputs is not None:
@@ -2752,7 +2758,7 @@ def separable_convolution2d(
depth_multiplier=depth_multiplier,
use_bias=not normalizer_fn and biases_initializer,
depthwise_initializer=weights_initializer,
- pointwise_initializer=weights_initializer,
+ pointwise_initializer=pointwise_initializer,
bias_initializer=biases_initializer,
depthwise_regularizer=weights_regularizer,
pointwise_regularizer=weights_regularizer,
@@ -3117,7 +3123,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
raise ValueError('number of features({}) is not '
'a multiple of num_units({})'.format(
num_channels, num_units))
- shape[axis] = num_units
+ shape[axis] = -1
shape += [num_channels // num_units]
# Dealing with batches with arbitrary sizes
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 56e9194ceb..c5c7269b1f 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
+ def testConv1dShape(self):
+ width = 7
+ with self.test_session():
+ images = random_ops.random_uniform((5, width, 3), seed=1)
+ output = layers_lib.convolution1d(images, 32, 3)
+ self.assertEqual(output.op.name, 'Conv/Relu')
+ self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
+
+ def testConvInferSpatialDims(self):
+ depth, height, width = 7, 9, 11
+ with self.test_session():
+ images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3])
+ self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
+ images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3, 3])
+ self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+ images = np.random.uniform(size=(5, depth, height, width,
+ 4)).astype(np.float32)
+ output = layers_lib.convolution(images, 32, [3, 3, 3])
+ self.assertListEqual(output.get_shape().as_list(),
+ [5, depth, height, width, 32])
+
class DenseToSparseTest(test.TestCase):
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0e35b1aa8b..dad3da3748 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -514,15 +514,15 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
original_vars = set(tape.watched_variables())
# Backward pass
- def grad_fn(*output_grads, **kwargs):
+ def _grad_fn(output_grads, variables=None):
"""Recompute outputs for gradient computation."""
- variables = []
+ variables = variables or []
if original_vars:
- variables = kwargs["variables"]
- if set(variables) != original_vars:
- raise ValueError(_WRONG_VARS_ERR)
- del kwargs
- inputs = list(args)
+ assert variables, ("Fn created variables but the variables were not "
+ "passed to the gradient fn.")
+ if set(variables) != original_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+ inputs = [array_ops.identity(x) for x in list(args)]
# Recompute outputs
with framework_ops.control_dependencies(output_grads):
if use_data_dep_:
@@ -538,7 +538,7 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if original_vars != recompute_vars:
raise ValueError(_WRONG_VARS_ERR)
- if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
+ if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
outputs = list(outputs)
grads = gradients_impl.gradients(outputs, inputs + variables,
@@ -554,6 +554,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
grad_vars = grads[len(inputs):]
return grad_inputs, grad_vars
+ # custom_gradient inspects the signature of the function to determine
+ # whether the user expects variables passed in the grad_fn. If the function
+ # created variables, the grad_fn should accept the "variables" kwarg.
+ if original_vars:
+ def grad_fn(*output_grads, **kwargs):
+ return _grad_fn(output_grads, kwargs["variables"])
+ else:
+ def grad_fn(*output_grads):
+ return _grad_fn(output_grads)
+
return outputs, grad_fn
return fn_with_recompute(*args)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index bc09ba8d43..d5971fb9d8 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -372,6 +372,26 @@ class RecomputeTest(test.TestCase):
self.assertEqual(2, len(update_ops))
self.assertEqual([False, True], kwarg_values)
+ def testWithoutVariables(self):
+
+ def concat_n(layer_list, num_inputs):
+ return math_ops.reduce_sum(
+ array_ops.concat([x for x in layer_list[-num_inputs:]], axis=-1),
+ axis=1, keepdims=True)
+
+ @rev_block_lib.recompute_grad
+ def concat_n_wrap(*args):
+ return concat_n(args, 3)
+
+ # DenseNet-style layers
+ layer_list = [random_ops.random_uniform((4, 8))]
+ for _ in range(5):
+ layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
+
+ grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
+ with self.test_session() as sess:
+ sess.run(grads)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 541da90617..f8a3709ee5 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -505,7 +505,7 @@ class Experiment(object):
eval_result = None
last_warning_time = 0
while (not predicate_fn or predicate_fn(
- eval_result, checkpoint_path=previous_path if eval_result else None)):
+ eval_result, checkpoint_path=previous_path)):
# Exit if we have already reached number of steps to train.
if self._has_training_stopped(eval_result):
logging.info("Exiting continuous eval, global_step=%s >= "
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d10927a0cd..fb16c94c29 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase):
noop_hook = _NoopHook()
def _predicate_fn(eval_result, checkpoint_path):
- self.assertEqual(not eval_result,
+ self.assertEqual(eval_result is None,
checkpoint_path is None)
return est.eval_count < 3 # pylint: disable=cell-var-from-loop
diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD
index 5b89c6cef9..fe0ba19fcb 100644
--- a/tensorflow/contrib/linear_optimizer/BUILD
+++ b/tensorflow/contrib/linear_optimizer/BUILD
@@ -41,6 +41,7 @@ py_test(
size = "medium",
srcs = ["python/kernel_tests/sdca_ops_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows_gpu"],
deps = [
":sdca_ops_py",
":sparse_feature_column_py",
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index d0c32b43cc..ef0e08a777 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -377,7 +377,10 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
train_op.run()
def testDistributedSimple(self):
- # Setup test data
+ # Distributed SDCA may not converge if the workers update concurrently the
+ # same example. In this test the examples are partitioned across workers.
+ # The examples are the same for all workers, just the example_ids are
+ # different.
example_protos = [
make_example_proto({
'age': [0],
@@ -389,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
}, 1),
]
example_weights = [1.0, 1.0]
+ examples = make_example_dict(example_protos, example_weights)
+ example_ids = array_ops.placeholder(
+ dtypes.string, shape=(len(example_weights),))
+ examples['example_ids'] = example_ids
+ variables = make_variable_dict(1, 1)
for num_shards in _SHARD_NUMBERS:
for num_loss_partitions in _NUM_LOSS_PARTITIONS:
with self._single_threaded_test_session():
- examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
options = dict(
- symmetric_l2_regularization=1,
+ # Keep the same solution as for TestSimple: since the number of
+ # examples is multplied by num_loss_partitions, multiply also
+ # L2 by the same value.
+ symmetric_l2_regularization=num_loss_partitions,
symmetric_l1_regularization=0,
loss_type='logistic_loss',
num_table_shards=num_shards,
@@ -411,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
train_op = lr.minimize()
- def minimize():
+ def minimize(worker_id):
with self._single_threaded_test_session():
+ feed_dict = {example_ids: [
+ str(i + worker_id*len(example_weights)) for i in range(
+ len(example_weights))]}
for _ in range(_MAX_ITERATIONS):
- train_op.run() # pylint: disable=cell-var-from-loop
+ train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop
threads = []
- for _ in range(num_loss_partitions):
- threads.append(threading.Thread(target=minimize))
+ for worker_id in range(num_loss_partitions):
+ threads.append(threading.Thread(target=minimize, args=(worker_id,)))
threads[-1].start()
for t in threads:
t.join()
- lr.update_weights(train_op).run()
-
- # The high tolerance in unregularized_loss comparisons is due to the
- # fact that it's possible to trade off unregularized_loss vs.
- # regularization and still have a sum that is quite close to the
- # optimal regularized_loss value. SDCA's duality gap only ensures
- # that the regularized_loss is within 0.01 of optimal.
- # 0.525457 is the optimal regularized_loss.
- # 0.411608 is the unregularized_loss at that optimum.
- self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)
- self.assertAllClose(0.525457, loss.eval(), atol=0.01)
+ lr.update_weights(train_op).run(feed_dict={
+ example_ids: [str(i) for i in range(len(example_weights))]})
+
+ # Test only the unregularized loss because the optimal value of the
+ # regularized loss depends on num_loss_partitions.
+ self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 1], predicted_labels.eval())
- self.assertTrue(lr.approximate_duality_gap().eval() < 0.02)
+ self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02)
def testSimpleNoL2(self):
# Same as test above (so comments from above apply) but without an L2.
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 55b984f260..73f5c1448d 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -90,6 +90,16 @@ cc_library(
deps = [":context"],
)
+cc_library(
+ name = "kernel_api",
+ hdrs = [
+ "builtin_op_data.h",
+ "builtin_ops.h",
+ "context.h",
+ "context_util.h",
+ ],
+)
+
exports_files(["builtin_ops.h"])
cc_library(
@@ -118,6 +128,7 @@ cc_library(
hdrs = [
"allocation.h",
"context.h",
+ "context_util.h",
"error_reporter.h",
"graph_info.h",
"interpreter.h",
@@ -174,6 +185,7 @@ cc_test(
deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"//tensorflow/contrib/lite/schema:schema_fbs",
diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile
index cc8a8035d1..a616138d33 100644
--- a/tensorflow/contrib/lite/Makefile
+++ b/tensorflow/contrib/lite/Makefile
@@ -17,7 +17,29 @@ else
endif
endif
-ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+
+# Self-hosting
+TARGET_ARCH := ${HOST_ARCH}
+
+# Cross compiling
+ifeq ($(CROSS),rpi)
+ TARGET_ARCH := armv7l
+ TARGET_TOOLCHAIN_PREFIX := arm-linux-gnueabihf-
+endif
+
+ifeq ($(CROSS),riscv)
+ TARGET_ARCH := riscv
+ TARGET_TOOLCHAIN_PREFIX := riscv32-unknown-elf-
+endif
+ifeq ($(CROSS),stm32f7)
+ TARGET_ARCH := armf7
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+endif
+ifeq ($(CROSS),stm32f1)
+ TARGET_ARCH := armm1
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+endif
# Where compiled objects are stored.
OBJDIR := $(MAKEFILE_DIR)/gen/obj/
@@ -25,11 +47,46 @@ BINDIR := $(MAKEFILE_DIR)/gen/bin/
LIBDIR := $(MAKEFILE_DIR)/gen/lib/
GENDIR := $(MAKEFILE_DIR)/gen/obj/
+LIBS :=
+ifeq ($(TARGET_ARCH),x86_64)
+ CXXFLAGS += -fPIC -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -pthread # -msse4.2
+endif
+
+ifeq ($(TARGET_ARCH),armv7l)
+ CXXFLAGS += -mfpu=neon -pthread -fPIC
+ LIBS += -ldl
+endif
+
+ifeq ($(TARGET_ARCH),riscv)
+# CXXFLAGS += -march=gap8
+ CXXFLAGS += -DTFLITE_MCU
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
+
+ifeq ($(TARGET_ARCH),armf7)
+ CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -DTFLITE_MCU
+ CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
+ CXXFLAGS += -funsigned-char -MMD
+ CXXFLAGS += -mcpu=cortex-m7 -mthumb -mfpu=fpv5-sp-d16 -mfloat-abi=softfp
+ CXXFLAGS += '-std=gnu++11' '-fno-rtti' '-Wvla' '-c' '-Wall' '-Wextra' '-Wno-unused-parameter' '-Wno-missing-field-initializers' '-fmessage-length=0' '-fno-exceptions' '-fno-builtin' '-ffunction-sections' '-fdata-sections' '-funsigned-char' '-MMD' '-fno-delete-null-pointer-checks' '-fomit-frame-pointer' '-Os'
+ LIBS += -ldl
+ BUILD_TYPE := micro
+endif
+ifeq ($(TARGET_ARCH),armm1)
+ CXXFLAGS += -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK -mcpu=cortex-m1 -mthumb -DTFLITE_MCU
+ CXXFLAGS += -fno-rtti -fmessage-length=0 -fno-exceptions -fno-builtin -ffunction-sections -fdata-sections
+ CXXFLAGS += -funsigned-char -MMD
+ LIBS += -ldl
+endif
+
# Settings for the host compiler.
-CXX := $(CC_PREFIX)gcc
-CXXFLAGS := --std=c++11 -O3 -DNDEBUG
-CC := $(CC_PREFIX)gcc
-CCFLAGS := -O3 -DNDEBUG
+CXX := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}g++
+CXXFLAGS += --std=c++11 -O3 -DNDEBUG
+CCFLAGS := ${CXXFLAGS}
+CC := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX) ${TARGET_TOOLCHAIN_PREFIX}ar
+CFLAGS :=
LDOPTS :=
LDOPTS += -L/usr/local/lib
ARFLAGS := -r
@@ -48,7 +105,7 @@ INCLUDES := \
# override local versions in the source tree.
INCLUDES += -I/usr/local/include
-LIBS := \
+LIBS += \
-lstdc++ \
-lpthread \
-lm \
@@ -70,6 +127,12 @@ LIB_PATH := $(LIBDIR)$(LIB_NAME)
# A small example program that shows how to link against the library.
MINIMAL_PATH := $(BINDIR)minimal
+# Benchmark static library and binary
+BENCHMARK_LIB_NAME := benchmark-lib.a
+BENCHMARK_BINARY_NAME := benchmark_model
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
MINIMAL_OBJS := $(addprefix $(OBJDIR), \
@@ -78,19 +141,29 @@ $(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
+PROFILER_SRCS := \
+ tensorflow/contrib/lite/profiling/time.cc
+PROFILE_SUMMARIZER_SRCS := \
+ tensorflow/contrib/lite/profiling/profile_summarizer.cc \
+ tensorflow/core/util/stats_calculator.cc
+
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
+$(wildcard tensorflow/contrib/lite/*.c)
+ifneq ($(BUILD_TYPE),micro)
+CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c) \
+$(PROFILER_SRCS) \
$(wildcard tensorflow/contrib/lite/kernels/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \
$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \
$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c)
+endif
# Remove any duplicates.
CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS))
CORE_CC_EXCLUDE_SRCS := \
@@ -100,6 +173,11 @@ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \
$(MINIMAL_SRCS)
+ifeq ($(BUILD_TYPE),micro)
+CORE_CC_EXCLUDE_SRCS += \
+tensorflow/contrib/lite/model.cc \
+tensorflow/contrib/lite/nnapi_delegate.cc
+endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# File names of the intermediate files target compilation generates.
@@ -107,18 +185,33 @@ TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
LIB_OBJS := $(TF_LITE_CC_OBJS)
+# Benchmark sources
+BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
+BENCHMARK_ALL_SRCS := $(TFLITE_CC_SRCS) \
+ $(wildcard $(BENCHMARK_SRCS_DIR)/*.cc) \
+ $(PROFILE_SUMMARIZER_SRCS)
+
+BENCHMARK_SRCS := $(filter-out \
+ $(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
+ $(BENCHMARK_ALL_SRCS))
+
+BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
+
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.cc
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
-
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.c
@mkdir -p $(dir $@)
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH)
+all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+
+# The target that's compiled for micro-controllers
+micro: $(LIB_PATH)
# Gathers together all the objects we've compiled into a single '.a' archive.
$(LIB_PATH): $(LIB_OBJS)
@@ -131,6 +224,21 @@ $(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
-o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
+
+$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
+ @mkdir -p $(dir $@)
+ $(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
+
+benchmark_lib: $(BENCHMARK_LIB)
+$(info $(BENCHMARK_BINARY))
+$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(BENCHMARK_BINARY) \
+ $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS)
+
+benchmark: $(BENCHMARK_BINARY)
+
# Gets rid of all generated files.
clean:
rm -rf $(MAKEFILE_DIR)/gen
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index a4772731ec..c42622ff02 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <fcntl.h>
+#ifndef TFLITE_MCU
#include <sys/mman.h>
+#endif
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
@@ -27,10 +29,13 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
namespace tflite {
+#ifndef TFLITE_MCU
MMAPAllocation::MMAPAllocation(const char* filename,
ErrorReporter* error_reporter)
: Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) {
@@ -111,6 +116,7 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes,
buffer_ = ptr;
buffer_size_bytes_ = num_bytes;
}
+#endif
MemoryAllocation::~MemoryAllocation() {}
diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc
index 4f836d3677..4257e754ad 100644
--- a/tensorflow/contrib/lite/arena_planner.cc
+++ b/tensorflow/contrib/lite/arena_planner.cc
@@ -31,16 +31,17 @@ struct AllocationInfo {
// The tensor index to be allocated or deallocated.
int tensor;
// Whether to allocate or deallocate
- enum { ALLOC, DEALLOC } type;
+ enum Type { ALLOC, DEALLOC } type;
};
ArenaPlanner::ArenaPlanner(TfLiteContext* context,
- std::unique_ptr<GraphInfo> graph_info)
+ std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs)
: context_(context),
graph_info_(std::move(graph_info)),
arena_(kDefaultArenaAlignment),
- persistent_arena_(kDefaultArenaAlignment) {}
-
+ persistent_arena_(kDefaultArenaAlignment),
+ preserve_inputs_(preserve_inputs) {}
ArenaPlanner::~ArenaPlanner() {}
int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) {
@@ -67,6 +68,33 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Keeps track of references to each tensor.
std::vector<int> refcounts(graph_info_->num_tensors(), 0);
+ // `allocated` and `deallocated` are technically list of boolean values.
+ // We're saving the compiled binary size by using `vector<int>`.
+ std::vector<int> allocated(graph_info_->num_tensors(), false);
+ std::vector<int> deallocated(graph_info_->num_tensors(), false);
+
+ auto allocate = [this, &allocated, &deallocated](int node,
+ int tensor) -> TfLiteStatus {
+ if (allocated[tensor]) {
+ return kTfLiteOk;
+ }
+ TF_LITE_ENSURE(context_, !deallocated[tensor]);
+ alloc_queue_.push_back({node, tensor, AllocationInfo::ALLOC});
+ allocated[tensor] = true;
+ return kTfLiteOk;
+ };
+
+ auto deallocate = [this, &allocated, &deallocated](
+ int node, int tensor) -> TfLiteStatus {
+ if (!allocated[tensor]) {
+ // Do not enqueue a DEALLOC if the tensor is never allocated.
+ // This happened with the constant tensors.
+ return kTfLiteOk;
+ }
+ TF_LITE_ENSURE(context_, !deallocated[tensor]);
+ alloc_queue_.push_back({node, tensor, AllocationInfo::DEALLOC});
+ return kTfLiteOk;
+ };
// There will be an entry in alloc_queue_ for the allocation of each tensor
// and another for their deallocation.
@@ -79,6 +107,32 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
refcounts[tensor_index]++;
}
+ // Variable tensors should are also never overwritten and need to be alive all
+ // the time.
+ for (int tensor_index : graph_info_->variables()) {
+ refcounts[tensor_index]++;
+ }
+
+ // Queue all graph inputs for allocation. If preserve_inputs_ is true, make
+ // sure they never be overwritten.
+ for (int tensor_index : graph_info_->inputs()) {
+ if (tensor_index != kOptionalTensor) {
+ if (preserve_inputs_) {
+ refcounts[tensor_index]++;
+ }
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
+ }
+ }
+
+ // Queue all graph variable tensors for allocation.
+ for (int tensor_index : graph_info_->variables()) {
+ if (tensor_index != kOptionalTensor) {
+ // Increase the reference count for input tensors by one, so it will
+ // never be deallocated.
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
+ }
+ }
+
// Count references to node input tensors.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
@@ -94,10 +148,9 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
// Queue all graph inputs for allocation.
for (int tensor_index : graph_info_->inputs()) {
if (tensor_index != kOptionalTensor) {
- alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC});
+ TF_LITE_ENSURE_STATUS(allocate(0, tensor_index));
}
}
-
// Go through the graph in execution order.
for (int i = 0; i < graph_info_->num_nodes(); ++i) {
const TfLiteNode& node = graph_info_->node(i);
@@ -106,7 +159,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
TfLiteIntArray* node_outputs = node.outputs;
for (int j = 0; j < node_outputs->size; ++j) {
int tensor_index = node_outputs->data[j];
- alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC});
+ TF_LITE_ENSURE_STATUS(allocate(i, tensor_index));
}
// Then update the ref-counts of the node's inputs, and if necessary queue
@@ -117,7 +170,7 @@ TfLiteStatus ArenaPlanner::PlanAllocations() {
if (tensor_index != kOptionalTensor) {
refcounts[tensor_index]--;
if (refcounts[tensor_index] == 0) {
- alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC});
+ TF_LITE_ENSURE_STATUS(deallocate(i, tensor_index));
}
}
}
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index e9d0fbc5a9..1d84950e91 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -43,8 +43,11 @@ struct AllocationInfo;
class ArenaPlanner : public MemoryPlanner {
public:
// Ownership of 'context' is not taken and it must remain util the
- // ArenaPlanner is destroyed.
- ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info);
+ // ArenaPlanner is destroyed. If 'preserve_inputs' is true the inputs to the
+ // graph will not share memory with any other tensor, effectively preserving
+ // them until the end of inference.
+ ArenaPlanner(TfLiteContext* context, std::unique_ptr<GraphInfo> graph_info,
+ bool preserve_inputs);
~ArenaPlanner() override;
ArenaPlanner(const ArenaPlanner&) = delete;
ArenaPlanner& operator=(const ArenaPlanner&) = delete;
@@ -100,6 +103,8 @@ class ArenaPlanner : public MemoryPlanner {
// Raw memory buffer that is allocated for persistent tensors that are
// declared as kTfLiteArenaRwPersistent.
SimpleMemoryArena persistent_arena_;
+
+ bool preserve_inputs_;
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc
index a8a8755e2c..f5bd1932f9 100644
--- a/tensorflow/contrib/lite/arena_planner_test.cc
+++ b/tensorflow/contrib/lite/arena_planner_test.cc
@@ -100,12 +100,18 @@ class TestGraph {
std::vector<TfLiteTensor>* tensors() { return &tensors_; }
const std::vector<int>& inputs() { return inputs_; }
const std::vector<int>& outputs() { return outputs_; }
+ const std::vector<int>& variables() { return variables_; }
+
+ void SetVariables(const std::vector<int>& variables) {
+ variables_ = variables;
+ }
private:
std::vector<TfLiteNode> nodes_;
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
+ std::vector<int> variables_;
};
// The GraphInfo for a TestGraph.
@@ -123,6 +129,9 @@ class TestGraphInfo : public GraphInfo {
}
const std::vector<int>& inputs() const override { return graph_->inputs(); }
const std::vector<int>& outputs() const override { return graph_->outputs(); }
+ const std::vector<int>& variables() const override {
+ return graph_->variables();
+ }
private:
TestGraph* graph_;
@@ -142,11 +151,12 @@ void ReportError(TfLiteContext* context, const char* format, ...) {
class ArenaPlannerTest : public ::testing::Test {
protected:
- void SetGraph(TestGraph* graph) {
+ void SetGraph(TestGraph* graph, bool preserve_inputs = false) {
graph_ = graph;
context_.ReportError = ReportError;
planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph))));
+ &context_, std::unique_ptr<GraphInfo>(new TestGraphInfo(graph)),
+ preserve_inputs));
CHECK(planner_->ResetAllocations() == kTfLiteOk);
CHECK(planner_->PlanAllocations() == kTfLiteOk);
}
@@ -209,11 +219,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) {
TestGraph graph({1}, {{{1}, {2}, {}}}, {2});
(*graph.tensors())[1].bytes = 0;
SetGraph(&graph);
- // TODO(ahentz): this is currently broken because the arena finds two
- // allocations with the same offset and returns an error.
- ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk);
- // EXPECT_EQ(GetOffset(1), 0);
- // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1));
+ ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk);
+ EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr);
}
TEST_F(ArenaPlannerTest, SimpleGraph) {
@@ -237,6 +244,30 @@ TEST_F(ArenaPlannerTest, SimpleGraph) {
EXPECT_EQ(GetOffset(3), 0);
}
+TEST_F(ArenaPlannerTest, SimpleGraphInputsPreserved) {
+ TestGraph graph({0, 1},
+ {
+ /* in, out, tmp */
+ {{0, 1}, {2}, {}}, // First op
+ {{2, 0}, {4, 5}, {}}, // Second op
+ {{4, 5}, {3}, {}} // Third op
+ },
+ {3});
+ SetGraph(&graph, /*preserve_inputs=*/true);
+ Execute(0, 10);
+
+ // Alloc(+) and dealloc(-) order: +0 +1 +2 +4 +5 -2 +3 -4 -5
+ EXPECT_EQ(GetOffset(0), 0);
+ EXPECT_EQ(GetOffset(1), GetOffsetAfter(0));
+ EXPECT_EQ(GetOffset(2), GetOffsetAfter(1));
+ EXPECT_EQ(GetOffset(4), GetOffsetAfter(2));
+ EXPECT_EQ(GetOffset(5), GetOffsetAfter(4));
+ // Because we are keeping the inputs alive until the end (due to
+ // preserve_inputs=true), the output tensor will not be able to use that
+ // space. It will end up using the same are as tensor #2.
+ EXPECT_EQ(GetOffset(3), GetOffsetAfter(1));
+}
+
TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) {
TestGraph graph({0, 1},
{
@@ -309,13 +340,15 @@ TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) {
{
/* in, out, tmp */
{{0, 1}, {2}, {}}, // First op
- {{2, 0}, {4}, {5}}, // Second op, with temporary
+ {{2, 0}, {4}, {5}}, // Second op, with persistent
{{4, -1}, {3}, {}} // Third op, with optional
},
{3});
// Make #1 persistent so it goes into its own arena.
(*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent;
+ // The only use case for kTfLiteArenaRwPersistent is variable tensor now.
+ graph.SetVariables({1});
SetGraph(&graph);
Execute(0, 10);
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index b9e40cc50c..b735d08b4b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -195,7 +195,7 @@ def json_to_tflite(name, src, out):
def generated_test_models():
return [
"add",
- "arg_max",
+ "arg_min_max",
"avg_pool",
"batch_to_space_nd",
"concat",
@@ -204,7 +204,9 @@ def generated_test_models():
"conv",
"depthwiseconv",
"div",
+ "equal",
"exp",
+ "expand_dims",
"floor",
"fully_connected",
"fused_batch_norm",
@@ -212,12 +214,14 @@ def generated_test_models():
"global_batch_norm",
"greater",
"greater_equal",
+ "sum",
"l2norm",
"l2_pool",
"less",
"less_equal",
"local_response_norm",
"log_softmax",
+ "log",
"lstm",
"max_pool",
"maximum",
@@ -225,14 +229,18 @@ def generated_test_models():
"minimum",
"mul",
"neg",
+ "not_equal",
"pad",
"padv2",
- # "prelu",
+ "prelu",
+ "pow",
"relu",
"relu1",
"relu6",
"reshape",
"resize_bilinear",
+ "rsqrt",
+ "shape",
"sigmoid",
"sin",
"slice",
@@ -241,13 +249,15 @@ def generated_test_models():
"space_to_depth",
"sparse_to_dense",
"split",
+ "sqrt",
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",
"sub",
+ "tile",
"topk",
"transpose",
- "transpose_conv",
+ #"transpose_conv", # disabled due to b/111213074
"where",
]
diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh
index 9f398f4a9f..e9531aef19 100755
--- a/tensorflow/contrib/lite/build_ios_universal_lib.sh
+++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh
@@ -19,22 +19,23 @@ set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR/../../.."
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 \
-$SCRIPT_DIR/gen/lib/ios_x86_64/libtensorflow-lite.a
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 \
-$SCRIPT_DIR/gen/lib/ios_i386/libtensorflow-lite.a
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 \
-$SCRIPT_DIR/gen/lib/ios_armv7/libtensorflow-lite.a
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 \
-$SCRIPT_DIR/gen/lib/ios_armv7s/libtensorflow-lite.a
-make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 \
-$SCRIPT_DIR/gen/lib/ios_arm64/libtensorflow-lite.a
+# Build library for supported architectures and packs them in a fat binary.
+make_library() {
+ for arch in x86_64 i386 armv7 armv7s arm64
+ do
+ make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=${arch} \
+ -j 8 \
+ $SCRIPT_DIR/gen/lib/ios_${arch}/${1}
+ done
+ lipo \
+ tensorflow/contrib/lite/gen/lib/ios_x86_64/${1} \
+ tensorflow/contrib/lite/gen/lib/ios_i386/${1} \
+ tensorflow/contrib/lite/gen/lib/ios_armv7/${1} \
+ tensorflow/contrib/lite/gen/lib/ios_armv7s/${1} \
+ tensorflow/contrib/lite/gen/lib/ios_arm64/${1} \
+ -create \
+ -output tensorflow/contrib/lite/gen/lib/${1}
+}
-lipo \
-tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \
-tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \
-tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \
-tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \
-tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \
--create \
--output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a
+make_library libtensorflow-lite.a
+make_library benchmark-lib.a
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 52ab9ee640..a58dde9a7b 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -92,8 +92,17 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteSequenceRNNParams;
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
typedef struct {
+ // Parameters for FullyConnected version 1 or above.
TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
} TfLiteFullyConnectedParams;
typedef enum {
@@ -148,10 +157,20 @@ typedef struct {
float beta;
} TfLiteLocalResponseNormParams;
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
typedef struct {
+ // Parameters for LSTM version 1.
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
} TfLiteLSTMParams;
typedef struct {
@@ -205,7 +224,7 @@ typedef struct {
typedef struct {
bool keep_dims;
-} TfLiteMeanParams;
+} TfLiteReducerParams;
typedef struct {
int num_splits;
@@ -231,6 +250,10 @@ typedef struct {
} TfLiteArgMaxParams;
typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
TfLitePadding padding;
int stride_width;
int stride_height;
@@ -240,6 +263,16 @@ typedef struct {
bool validate_indices;
} TfLiteSparseToDenseParams;
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ float min;
+ float max;
+ int num_bits;
+} TfLiteFakeQuantParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index c797e3589a..6bde5d2e6d 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
// DO NOT EDIT MANUALLY: This file is automatically generated by
-// `schema_builtin_ops_header_generator.py`.
+// `schema/builtin_ops_header/generator.cc`.
#ifdef __cplusplus
extern "C" {
@@ -94,6 +94,18 @@ typedef enum {
kTfLiteBuiltinSin = 66,
kTfLiteBuiltinTransposeConv = 67,
kTfLiteBuiltinSparseToDense = 68,
+ kTfLiteBuiltinTile = 69,
+ kTfLiteBuiltinExpandDims = 70,
+ kTfLiteBuiltinEqual = 71,
+ kTfLiteBuiltinNotEqual = 72,
+ kTfLiteBuiltinLog = 73,
+ kTfLiteBuiltinSum = 74,
+ kTfLiteBuiltinSqrt = 75,
+ kTfLiteBuiltinRsqrt = 76,
+ kTfLiteBuiltinShape = 77,
+ kTfLiteBuiltinPow = 78,
+ kTfLiteBuiltinArgMin = 79,
+ kTfLiteBuiltinFakeQuant = 80,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c
index 5c6f5e72a4..7f2aa316f4 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/context.c
@@ -76,7 +76,7 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable, TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
@@ -86,6 +86,7 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
tensor->bytes = size;
tensor->allocation_type = allocation_type;
tensor->allocation = allocation;
+ tensor->is_variable = is_variable;
}
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 4eb66cc225..1ff8843fa7 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -39,6 +39,26 @@ extern "C" {
typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteMaxExternalContexts = 2
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
// Forward declare so GetNode can use this is in Context.
typedef struct _TfLiteRegistration TfLiteRegistration;
typedef struct _TfLiteDelegate TfLiteDelegate;
@@ -138,6 +158,8 @@ typedef enum {
kTfLiteInt64 = 4,
kTfLiteString = 5,
kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
} TfLiteType;
// Parameters for asymmetric quantization. Quantized values can be converted
@@ -148,7 +170,7 @@ typedef struct {
int32_t zero_point;
} TfLiteQuantizationParams;
-// A union of points that points to memory for a given tensor.
+// A union of pointers that points to memory for a given tensor.
typedef union {
int* i32;
int64_t* i64;
@@ -157,6 +179,8 @@ typedef union {
const char* raw_const;
uint8_t* uint8;
bool* b;
+ int16_t* i16;
+ _Complex float* c64;
} TfLitePtrUnion;
// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
@@ -223,6 +247,9 @@ typedef struct {
// delegate buffer.
// WARNING: This is an // experimental interface that is subject to change.
bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
} TfLiteTensor;
// Free data memory of tensor `t`;
@@ -235,9 +262,11 @@ void TfLiteTensorFree(TfLiteTensor* t);
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, TfLiteTensor* tensor);
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
-// Resize the allocated data of a (dynamic) tensor.
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
// A structure representing an instance of a node.
@@ -330,10 +359,15 @@ typedef struct TfLiteContext {
// eigen.
int recommended_num_threads;
- // TODO(ahentz): we should create a more general mechanism for this sort of
- // library-global objects.
- void* gemm_context;
- void* eigen_context;
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+ // Set the value of a external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
} TfLiteContext;
typedef struct _TfLiteRegistration {
@@ -368,6 +402,14 @@ typedef struct _TfLiteRegistration {
// Returns kTfLiteOk on success.
TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+ // given op to appear multiple times is the profiling report. This is
+ // particularly useful for custom ops that can perform significantly
+ // different calculations depending on their `user-data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
// Builtin codes. If this kernel refers to a builtin this is the code
// of the builtin. This is so we can do marshaling to other frameworks like
// NN API.
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
new file mode 100644
index 0000000000..abe802e342
--- /dev/null
+++ b/tensorflow/contrib/lite/context_util.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This provides a few C++ helpers that are useful for manipulating C structures
+// in C++.
+#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
+#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite
+// C api uses. Can't use the google array_view, since we can't depend on even
+// absl for embedded device reasons.
+class TfLiteIntArrayView {
+ public:
+ // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null
+ // and this view does not take ownership of it.
+ explicit TfLiteIntArrayView(const TfLiteIntArray* int_array)
+ : int_array_(int_array) {}
+
+ TfLiteIntArrayView(const TfLiteIntArrayView&) = default;
+ TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default;
+
+ typedef const int* const_iterator;
+ const_iterator begin() const { return int_array_->data; }
+ const_iterator end() const { return &int_array_->data[int_array_->size]; }
+ size_t size() const { return end() - begin(); }
+
+ private:
+ const TfLiteIntArray* int_array_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
new file mode 100644
index 0000000000..35a8f6ca41
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "nnapi_delegate",
+ srcs = ["nnapi_delegate.cc"],
+ hdrs = ["nnapi_delegate.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "nnapi_delegate_test",
+ size = "small",
+ srcs = ["nnapi_delegate_test.cc"],
+ deps = [
+ ":nnapi_delegate",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
new file mode 100644
index 0000000000..f0d16575ec
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -0,0 +1,694 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/builtin_ops.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+#ifdef __ANDROID__
+#include <sys/system_properties.h>
+#endif
+
+namespace tflite {
+namespace {
+
+// TODO(b/80621585): Consider printing error string, but don't for now to
+// minimize binary size.
+#define CHECK_NN(context, code) \
+ if (code != ANEURALNETWORKS_NO_ERROR) { \
+ context->ReportError(context, "NN API returned error (%d).\n", code); \
+ return kTfLiteError; \
+ }
+
+namespace {
+int32_t GetAndroidSdkVersion() {
+#ifdef __ANDROID__
+ const char* sdkProp = "ro.build.version.sdk";
+ char sdkVersion[PROP_VALUE_MAX];
+ int length = __system_property_get(sdkProp, sdkVersion);
+ if (length != 0) {
+ for (int i = 0; i < length; ++i) {
+ int digit = sdkVersion[i] - '0';
+ if (digit < 0 || digit > 9) {
+ // Non-numeric SDK version, assume it's higher then expected;
+ return std::numeric_limits<int32_t>::max();
+ }
+ }
+ return atoi(sdkVersion);
+ }
+#endif // __ANDROID__
+ return 0;
+}
+
+constexpr int32_t kMinSdkVersionForNNAPI = 27;
+constexpr int32_t kMinSdkVersionForNNAPI11 = 28;
+static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+
+} // namespace
+
+// RAII NN API Model Destructor for use with std::unique_ptr
+struct NNFreeModel {
+ void operator()(ANeuralNetworksModel* model) {
+ ANeuralNetworksModel_free(model);
+ }
+};
+// RAII NN API Compilation Destructor for use with std::unique_ptr
+struct NNFreeCompilation {
+ void operator()(ANeuralNetworksCompilation* model) {
+ ANeuralNetworksCompilation_free(model);
+ }
+};
+
+// Track tensor indices to NN API tensor indices mapping.
+class OperandMapping {
+ public:
+ // Given a TFLite index return the ANN index. If it doesn't exist
+ // return -1.
+ int lite_index_to_ann(int index) const {
+ if (index < lite_tensor_to_ann_tensor_.size())
+ return lite_tensor_to_ann_tensor_[index];
+ else
+ return -1;
+ }
+
+ // NN API uses non tensor operands instead of structs. This creates one
+ // and returns the index. It uses a std::vector and resizes it as needed
+ // keeping -1 to unmapped values. Intermediate tensors likely will not
+ // be mapped.
+ int add_new_non_tensor_operand() { return next_ann_tensor_index_++; }
+
+ // Add a new mapping from `tflite_index` and return the NN API tensor index.
+ int add_new_ann_tensor_index(int tflite_index) {
+ if (tflite_index >= lite_tensor_to_ann_tensor_.size()) {
+ lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1);
+ }
+ int new_tensor_index = next_ann_tensor_index_++;
+ lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index;
+ return new_tensor_index;
+ }
+
+ private:
+ // Next index of ann tensor
+ int next_ann_tensor_index_ = 0;
+
+ // Mapping from lite index. Use a std::vector for speed and code size
+ // rather than a map.
+ std::vector<int> lite_tensor_to_ann_tensor_;
+};
+
+// Abstract builder for building an op in the NN API graph. This handles
+// the disparity between TFLite and NN API operand types. NN API has singular
+// operands for both tensors and parameters, and TFLite separates the two.
+class NNAPIOpBuilder {
+ public:
+ NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping,
+ ANeuralNetworksModel* nn_model)
+ : context_(context),
+ operand_mapping_(tensor_mapping),
+ nn_model_(nn_model) {}
+
+ TfLiteStatus AddScalarInt32Operand(int32_t value) {
+ return AddScalarOperand<int32_t>(value, ANEURALNETWORKS_INT32);
+ }
+
+ TfLiteStatus AddScalarFloat32Operand(float value) {
+ return AddScalarOperand<float>(value, ANEURALNETWORKS_FLOAT32);
+ }
+
+ TfLiteStatus AddVectorInt32Operand(const int32_t* values,
+ uint32_t num_values) {
+ return AddVectorOperand<int32_t>(values, num_values,
+ ANEURALNETWORKS_TENSOR_INT32);
+ }
+
+ TfLiteStatus AddPoolingParams(void* data) {
+ auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
+ AddScalarInt32Operand(builtin->padding);
+ AddScalarInt32Operand(builtin->stride_width);
+ AddScalarInt32Operand(builtin->stride_height);
+ AddScalarInt32Operand(builtin->filter_width);
+ AddScalarInt32Operand(builtin->filter_height);
+ AddScalarInt32Operand(builtin->activation);
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus AddTensorInput(int tensor_index) {
+ int ann_index;
+ TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
+ augmented_inputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus AddTensorOutput(int tensor_index) {
+ int ann_index;
+ TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
+ augmented_outputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
+
+ // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`.
+ // This returns the NN API tensor index corresponding to the created tensor.
+ // If another caller previously created a NN API tensor for `tensor_index`
+ // then the existing one is returned.
+ TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) {
+ int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
+ if (ann_tensor_index != -1) {
+ *ann_tensor_index_out = ann_tensor_index;
+ return kTfLiteOk;
+ }
+ // Allocate a new tensor index
+ ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index);
+
+ // Parameters needed for new type.
+ int32_t nn_type = 0;
+ float scale = 0.0f;
+ int32_t zeroPoint = 0;
+ TfLiteTensor* tensor = &context_->tensors[tensor_index];
+ switch (tensor->type) {
+ case kTfLiteNoType:
+ // Tensors added during initialization of Ops don't have a type yet and
+ // should not be registered with the NNAPI.
+ *ann_tensor_index_out = -1;
+ return kTfLiteOk;
+ case kTfLiteFloat32:
+ nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
+ break;
+ case kTfLiteUInt8:
+ nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ case kTfLiteInt32:
+ nn_type = ANEURALNETWORKS_TENSOR_INT32;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ default:
+ context_->ReportError(context_, "Logic error in NN API Delegate.\n");
+ return kTfLiteError;
+ }
+
+ ANeuralNetworksOperandType operand_type{
+ nn_type, static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
+ CHECK_NN(context_,
+ ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ // TODO(b/80630405): Use NNAPIAllocation.
+ CHECK_NN(context_, ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_tensor_index, tensor->data.raw,
+ tensor->bytes));
+ }
+
+ *ann_tensor_index_out = ann_tensor_index;
+ return kTfLiteOk;
+ }
+
+ // Finish emitting the op (of type `type`) into the NN API.
+ TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) {
+ // Actually add a NN API operation
+ CHECK_NN(context_, ANeuralNetworksModel_addOperation(
+ nn_model_, type,
+ static_cast<uint32_t>(augmented_inputs_.size()),
+ augmented_inputs_.data(),
+ static_cast<uint32_t>(augmented_outputs_.size()),
+ augmented_outputs_.data()));
+ augmented_inputs_.clear();
+ augmented_outputs_.clear();
+ return kTfLiteOk;
+ }
+
+ private:
+ template <typename T>
+ TfLiteStatus AddScalarOperand(T value, int32_t nn_type) {
+ ANeuralNetworksOperandType operand_type{.type = nn_type};
+ CHECK_NN(context_,
+ ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+ int ann_operand = operand_mapping_->add_new_non_tensor_operand();
+ CHECK_NN(context_, ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_operand, &value, sizeof(T)));
+ augmented_inputs_.push_back(ann_operand);
+ return kTfLiteOk;
+ }
+
+ template <typename T>
+ TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values,
+ int32_t nn_type) {
+ ANeuralNetworksOperandType operand_type{
+ .type = nn_type, .dimensionCount = 1, .dimensions = &num_values};
+ CHECK_NN(context_,
+ ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+ int ann_operand = operand_mapping_->add_new_non_tensor_operand();
+ CHECK_NN(context_,
+ ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_operand, values, sizeof(T) * num_values));
+ augmented_inputs_.push_back(ann_operand);
+ return kTfLiteOk;
+ }
+
+ // TfLiteContext for error handling. Must be named context for macros to
+ // work.
+ TfLiteContext* context_;
+
+ // Tracks relationship between indices
+ OperandMapping* operand_mapping_;
+
+ // The model
+ ANeuralNetworksModel* nn_model_;
+
+ // Inputs and outputs for the current op. These are augmented in the sense
+ // that NN API uses operands for all arguments, not just tensors, unlike
+ // TensorFlow lite.
+ std::vector<uint32_t> augmented_inputs_;
+ std::vector<uint32_t> augmented_outputs_;
+};
+
+// The kernel that represents the subgraph of TF Lite being run on NN API.
+class NNAPIDelegateKernel {
+ public:
+ NNAPIDelegateKernel() = default;
+
+ typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*,
+ NNAPIOpBuilder* builder,
+ TfLiteNode* node);
+
+ // Return a function that knows how to translate a node into its operands
+ // when called. You can use this function to see if a node is supported
+ // (i.e. that MappingFn is not nullptr).
+ MappingFn Map(TfLiteContext* context, int builtin_code, int version,
+ TfLiteNode* node) {
+ switch (builtin_code) {
+ case kTfLiteBuiltinAdd:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_ADD;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinMul:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_MUL;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinAveragePool2d:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ builder->AddPoolingParams(node->builtin_data);
+ return ANEURALNETWORKS_AVERAGE_POOL_2D;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinMaxPool2d:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ builder->AddPoolingParams(node->builtin_data);
+ return ANEURALNETWORKS_MAX_POOL_2D;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinL2Pool2d:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ builder->AddPoolingParams(node->builtin_data);
+ return ANEURALNETWORKS_L2_POOL_2D;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinConv2d:
+ if (version == 1) {
+ auto builtin =
+ reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ if (builtin->dilation_width_factor != 1 ||
+ builtin->dilation_height_factor != 1 || node->inputs->size != 3) {
+ // NNAPI does not support dilated Conv2D.
+ return nullptr;
+ }
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->padding);
+ builder->AddScalarInt32Operand(builtin->stride_width);
+ builder->AddScalarInt32Operand(builtin->stride_height);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_CONV_2D;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinDepthwiseConv2d:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(
+ node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->padding);
+ builder->AddScalarInt32Operand(builtin->stride_width);
+ builder->AddScalarInt32Operand(builtin->stride_height);
+ builder->AddScalarInt32Operand(builtin->depth_multiplier);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_DEPTHWISE_CONV_2D;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinFullyConnected:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(
+ node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_FULLY_CONNECTED;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinSoftmax:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+ builder->AddScalarFloat32Operand(builtin->beta);
+ return ANEURALNETWORKS_SOFTMAX;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinReshape:
+ if (version == 1) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_RESHAPE;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ case kTfLiteBuiltinSqueeze:
+ // Squeeze requires NNAPI1.1.
+ if (version == 1 && kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLiteSqueezeParams*>(node->builtin_data);
+ // Note that we add the squeeze dimensions even if the dimensions
+ // were unspecified (empty), as NNAPI requires the operand.
+ builder->AddVectorInt32Operand(
+ builtin->squeeze_dims,
+ static_cast<uint32_t>(builtin->num_squeeze_dims));
+ return ANEURALNETWORKS_SQUEEZE;
+ };
+ } else {
+ return nullptr;
+ }
+ case kTfLiteBuiltinTranspose:
+ // Transpose requires NNAPI1.1. Also note that the permutation input
+ // tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((version == 1) &&
+ (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) &&
+ (node->inputs->size > 1) &&
+ (context->tensors[node->inputs->data[1]].allocation_type ==
+ kTfLiteMmapRo)) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_TRANSPOSE;
+ };
+ } else {
+ return nullptr;
+ }
+ break;
+ default:
+ return nullptr;
+ }
+ }
+
+ // Initialize the kernel (a NN model).
+ TfLiteStatus Init(TfLiteContext* context,
+ const TfLiteDelegateParams* params) {
+ for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+ nodes_.push_back(node_index);
+ }
+
+ if (!nn_model_) {
+ ANeuralNetworksModel* model;
+ CHECK_NN(context, ANeuralNetworksModel_create(&model));
+ nn_model_.reset(model);
+
+ TF_LITE_ENSURE_STATUS(
+ BuildGraph(context, params->input_tensors, params->output_tensors));
+ }
+
+ if (!nn_compilation_) {
+ ANeuralNetworksCompilation* compilation;
+ CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(),
+ &compilation));
+ CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation));
+ nn_compilation_.reset(compilation);
+ }
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
+ ANeuralNetworksExecution* execution = nullptr;
+ CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(),
+ &execution));
+
+ // Set the input tensor buffers. Note: we access tflite tensors using
+ // absolute indices but NN api indices inputs by relative indices.
+ int relative_input_index = 0;
+ for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
+ TfLiteTensor* tensor = &context->tensors[absolute_input_index];
+ // TODO(miaowang): make sure the delegation works with dequantized weights
+ // as intermediate tensors.
+ if (tensor->allocation_type != kTfLiteMmapRo) {
+ CHECK_NN(context, ANeuralNetworksExecution_setInput(
+ execution, relative_input_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_input_index++;
+ }
+ }
+
+ // Set the output tensor buffers.
+ int relative_output_index = 0;
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ TfLiteTensor* tensor = &context->tensors[output_index];
+ CHECK_NN(context, ANeuralNetworksExecution_setOutput(
+ execution, relative_output_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_output_index++;
+ }
+ // Invoke ANN in blocking fashion.
+ ANeuralNetworksEvent* event = nullptr;
+ CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event));
+ CHECK_NN(context, ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
+ ANeuralNetworksExecution_free(execution);
+
+ return kTfLiteOk;
+ }
+
+ private:
+ // ANN API state.
+ std::unique_ptr<ANeuralNetworksModel, NNFreeModel> nn_model_;
+ std::unique_ptr<ANeuralNetworksCompilation, NNFreeCompilation>
+ nn_compilation_;
+ // Node indices that this delegate is responsible for. Indices here
+ // indexes into the nodes array in the TfLiteContext.
+ std::vector<int> nodes_;
+ // Track indices we use
+ OperandMapping operand_mapping_;
+
+ TfLiteStatus AddOpsAndTensors(TfLiteContext* context) {
+ // The operand builder allows creating a single op. We create it at this
+ // reduced power position rather than in the for loop to avoid reallocating
+ // the vectors.
+ NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get());
+ // Add Tensors
+ // allocate outside to avoid realloc
+ for (auto node_index : nodes_) {
+ // Obtain the op and registration.
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // Map inputs to NN API tensor indices.
+ for (auto input_index : TfLiteIntArrayView(node->inputs)) {
+ TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index));
+ }
+ // Get op type and operands
+ int nn_op_type = Map(context, reg->builtin_code, reg->version, node)(
+ context, &builder, node);
+ // Map outputs to NN API tensor indices.
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
+ }
+
+ builder.FinalizeAddOperation(nn_op_type);
+ }
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus BuildGraph(TfLiteContext* context,
+ const TfLiteIntArray* input_tensors,
+ const TfLiteIntArray* output_tensors) {
+ // Build the ops and tensors.
+ TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context));
+ // Map input and output tensor indices to ANN
+ std::vector<uint32_t> inputs;
+ inputs.reserve(input_tensors->size);
+ std::vector<uint32_t> outputs;
+ outputs.reserve(output_tensors->size);
+ // Make the TensorFlow lite inputs and outputs to ann_indices.
+ for (int i : TfLiteIntArrayView(input_tensors)) {
+ // Constant tensors are not NNAPI inputs.
+ if (context->tensors[i].allocation_type != kTfLiteMmapRo) {
+ inputs.push_back(operand_mapping_.lite_index_to_ann(i));
+ }
+ }
+ for (int i : TfLiteIntArrayView(output_tensors))
+ outputs.push_back(operand_mapping_.lite_index_to_ann(i));
+ // Tell ANN to declare inputs/outputs
+ CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
+ nn_model_.get(), inputs.size(), inputs.data(),
+ outputs.size(), outputs.data()));
+ // Finalize the model
+ CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get()));
+
+ return kTfLiteOk;
+ }
+};
+
+} // namespace
+
+// Return a NN API Delegate struct that can check for support of ops.
+TfLiteDelegate* NnApiDelegate() {
+ static TfLiteDelegate delegate = {
+ .data_ = nullptr,
+ .Prepare = [](TfLiteContext* context,
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ // Do not check nodes_ if NN API is unavailable.
+ if (kAndroidSdkVersion < kMinSdkVersionForNNAPI || !NNAPIExists()) {
+ return kTfLiteOk;
+ }
+
+ std::vector<int> supported_nodes(1);
+ // We don't care about all nodes_, we only care about ones in the
+ // current plan.
+ TfLiteIntArray* plan;
+ TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+ int total_supported_nodes = 0;
+
+ // Check for every node if it is supported
+ // TODO(b/80625235): Fix this to do more careful checking of versioning.
+ for (int node_index : TfLiteIntArrayView(plan)) {
+ TfLiteNode* node;
+ TfLiteRegistration* registration;
+ TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
+ context, node_index, &node, &registration));
+ NNAPIDelegateKernel dummy_kernel;
+ if (dummy_kernel.Map(context, registration->builtin_code,
+ registration->version, node)) {
+ supported_nodes.push_back(node_index);
+ }
+ total_supported_nodes += 1;
+ }
+ // Put the size at the beginning of the array.
+ supported_nodes[0] = supported_nodes.size() - 1;
+
+ // NN API Delegate Registration (the pseudo kernel that will invoke NN
+ // API subgraphs)
+ static const TfLiteRegistration nnapi_delegate_kernel = {
+ .init = [](TfLiteContext* context, const char* buffer,
+ size_t length) -> void* {
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+ NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel;
+ kernel_state->Init(context, params);
+ return kernel_state;
+ },
+
+ .free = [](TfLiteContext* context, void* buffer) -> void {
+ delete reinterpret_cast<NNAPIDelegateKernel*>(buffer);
+ },
+
+ .prepare = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ // Since the underlying resize happened ahead of delegation
+ // worked. This does nothing.
+ return kTfLiteOk;
+ },
+
+ .invoke = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ NNAPIDelegateKernel* state =
+ reinterpret_cast<NNAPIDelegateKernel*>(node->user_data);
+ return state->Invoke(context, node);
+ },
+
+ .builtin_code = kTfLiteBuiltinDelegate,
+ };
+
+ // Request TFLite to partition the graph and make kernels
+ // for each independent subgraph a new nnapi_delegate_kernel.
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, nnapi_delegate_kernel,
+ reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()),
+ delegate);
+ return kTfLiteOk;
+ }};
+
+ return &delegate;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 0e08a04370..44cca2fd28 100644
--- a/tensorflow/contrib/lite/java/src/main/native/duration_utils_jni.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -12,27 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
-#include <jni.h>
-#include <time.h>
+#include "tensorflow/contrib/lite/context.h"
namespace tflite {
-// Gets the elapsed wall-clock timespec.
-timespec getCurrentTime() {
- timespec time;
- clock_gettime(CLOCK_MONOTONIC, &time);
- return time;
-}
-
-// Computes the time diff from two timespecs. Returns '-1' if 'stop' is earlier
-// than 'start'.
-jlong timespec_diff_nanoseconds(struct timespec* start, struct timespec* stop) {
- jlong result = stop->tv_sec - start->tv_sec;
- if (result < 0) return -1;
- result = 1000000000 * result + (stop->tv_nsec - start->tv_nsec);
- if (result < 0) return -1;
- return result;
-}
-
+// Return a delegate that can be used to use the NN API.
+// e.g.
+// NnApiDelegate* delegate = NnApiDelegate();
+// interpreter->ModifyGraphWithDelegate(&delegate);
+// NnApiDelegate() returns a singleton, so you should not free this
+// pointer or worry about its lifetime.
+TfLiteDelegate* NnApiDelegate();
} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
new file mode 100644
index 0000000000..ab2181e8ff
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -0,0 +1,688 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+// TODO(b/110368244): figure out how to share the existing tests in kernels/ but
+// with the delegation on. Also, add more unit tests to improve code coverage.
+
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+ SingleOpModelWithNNAPI() {
+ this->SetApplyDelegate([](Interpreter* interpreter) {
+ interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false);
+ });
+ }
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
+ public:
+ FloatAddOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+ CreateAddOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+// Do a test with the NN API using no activation.
+TEST(NNAPIDelegate, AddWithNoActivation) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
+// Do a test with the NN api with relu.
+TEST(NNAPIDelegate, AddWithRelu) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3}));
+}
+
+class FloatMulOpModel : public SingleOpModel {
+ public:
+ FloatMulOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output,
+ ActivationFunctionType activation_type) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions,
+ CreateMulOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, MulWithNoActivation) {
+ FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
+}
+
+class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
+ public:
+ FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
+ int filter_width, int filter_height,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ type, BuiltinOptions_Pool2DOptions,
+ CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width,
+ filter_height, ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, AveragePoolWithNoActivation) {
+ FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75}));
+}
+
+TEST(NNAPIDelegate, MaxPoolWithNoActivation) {
+ FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10}));
+}
+
+TEST(NNAPIDelegate, L2PoolWithNoActivation) {
+ FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
+ /*filter_width=*/2, /*filter_height=*/2,
+ /*output=*/{TensorType_FLOAT32, {}});
+ m.SetInput({
+ 0, 6, 2, 4, //
+ 3, 2, 10, 7, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
+}
+
+class BaseConvolutionOpModel : public SingleOpModel {
+ public:
+ BaseConvolutionOpModel(
+ const TensorData& input, const TensorData& filter,
+ const TensorData& output, int stride_width = 2, int stride_height = 2,
+ enum Padding padding = Padding_VALID,
+ enum ActivationFunctionType activation = ActivationFunctionType_NONE,
+ int dilation_width_factor = 1, int dilation_height_factor = 1) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[0];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+ if (input.type != TensorType_FLOAT32) {
+ // The following is required by quantized inference. It is the unittest's
+ // responsibility to make sure the output scale falls into the correct
+ // range.
+ CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_));
+ }
+
+ SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions,
+ CreateConv2DOptions(
+ builder_, padding, stride_width, stride_height, activation,
+ dilation_width_factor, dilation_height_factor)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+class ConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ QuantizeAndPopulate<uint8_t>(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int32_t>(bias_, data);
+ }
+
+ std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+};
+
+// In this tests we set the input and output scales so that the results
+// match exactly the 'non-quantized' version.
+TEST(NNAPIDelegate, SimpleTestQuantized) {
+ QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
+ {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 1e-5)));
+ // For good measure, let's also verify the quantized values:
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 145, 129, 132, //
+ 145, 129, 132, //
+ 144, 131, 130, //
+ 164, 131, 130, //
+ }));
+}
+
+TEST(NNAPIDelegate, Conv2DWithNoActivation) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {3, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
+ public:
+ DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
+ const TensorData& output) {
+ input_ = AddInput(input);
+ filter_ = AddInput(filter);
+
+ int bias_size = GetShape(filter_)[3];
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+
+ int input_depth = GetShape(input_)[3];
+ int output_depth = GetShape(filter_)[3];
+ int depth_mul = output_depth / input_depth;
+
+ SetBuiltinOp(
+ BuiltinOperator_DEPTHWISE_CONV_2D,
+ BuiltinOptions_DepthwiseConv2DOptions,
+ CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
+ ActivationFunctionType_NONE)
+ .Union());
+
+ BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
+ }
+
+ void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int filter_;
+ int bias_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
+ DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_FLOAT32, {1, 2, 2, 4}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ 1, 2, 7, 8, // column 1
+ 3, 4, 9, 10, // column 2
+ 5, 6, 11, 12, // column 3
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, //
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ });
+ m.SetBias({1, 2, 3, 4});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 71, -34, 99, -20, //
+ 91, -26, 127, -4, //
+ }));
+}
+
+class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI {
+ public:
+ FloatFullyConnectedOpModel(int units, int batches, const TensorData& input,
+ const TensorData& output = {TensorType_FLOAT32})
+ : batches_(batches), units_(units) {
+ int total_input_size = 1;
+ for (int i = 0; i < input.shape.size(); ++i) {
+ total_input_size *= input.shape[i];
+ }
+ input_size_ = total_input_size / batches_;
+
+ input_ = AddInput(input);
+ weights_ =
+ AddInput({input.type, {units_, input_size_}, input.min, input.max});
+
+ if (input.type == TensorType_FLOAT32) {
+ bias_ = AddInput({TensorType_FLOAT32, {units_}});
+ } else {
+ // This is a quantized version. The scale of 'bias' depends on the scales
+ // of input and filter. Supposedly this is correctly set during quantized
+ // training.
+ auto bias_scale = GetScale(input_) * GetScale(weights_);
+ TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
+
+ output_ = AddOutput(output);
+
+ SetBuiltinOp(
+ BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ .Union());
+ BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ }
+
+ int input_size() { return input_size_; }
+ int num_units() { return units_; }
+ int num_batches() { return batches_; }
+
+ void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+
+ void SetWeights(std::initializer_list<float> f) {
+ PopulateTensor(weights_, f);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int weights_;
+ int bias_;
+ int output_;
+
+ int batches_;
+ int units_;
+ int input_size_;
+};
+
+TEST(NNAPIDelegate, FullyConnectedSimpleTest) {
+ FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 10}});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
+}
+
+class SoftmaxOpModel : public SingleOpModelWithNNAPI {
+ public:
+ SoftmaxOpModel(int batches, int size, float beta)
+ : batches_(batches), input_size_(size), beta_(beta) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
+ CreateSoftmaxOptions(builder_, beta_).Union());
+ BuildInterpreter({{batches_, input_size_}});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetInput(int offset, float* begin, float* end) {
+ PopulateTensor(input_, offset, begin, end);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ private:
+ int input_;
+ int output_;
+
+ int batches_;
+ int input_size_;
+ float beta_;
+};
+
+TEST(NNAPIDelegate, SoftmaxSimpleTest) {
+ SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0);
+ m.SetInput({
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647,
+ 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231},
+ 1e-6)));
+}
+
+class ReshapeOpModel : public SingleOpModelWithNNAPI {
+ public:
+ ReshapeOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> new_shape) {
+ input_ = AddInput(TensorType_FLOAT32);
+ new_shape_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions,
+ CreateReshapeOptions(builder_, builder_.CreateVector<int>(new_shape))
+ .Union());
+ BuildInterpreter({input_shape, {static_cast<int>(new_shape.size())}});
+ PopulateTensor<int>(new_shape_, new_shape);
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int new_shape_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, ReshapeSimpleTest) {
+ ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
+}
+
+class SqueezeOpModel : public SingleOpModelWithNNAPI {
+ public:
+ SqueezeOpModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(
+ BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
+ CreateSqueezeOptions(builder_, builder_.CreateVector<int>(axis))
+ .Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int new_shape_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, SqueezeSimpleTest) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
+ {});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+TEST(NNAPIDelegate, SqueezeWithAxisTest) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}}, {TensorType_FLOAT32, {24}},
+ {2});
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24}));
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
+ 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+ 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+class TransposeSimpleModel : public SingleOpModelWithNNAPI {
+ public:
+ TransposeSimpleModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> perm_shape,
+ std::initializer_list<int> perm) {
+ input_ = AddInput(TensorType_FLOAT32);
+ perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+ CreateTransposeOptions(builder_).Union());
+ BuildInterpreter({input_shape, perm_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int perm_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, TransposeSimpleTest) {
+ TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+ 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh
index 436c3e1d4c..840015a7fa 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/download_dependencies.sh
@@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then
fi
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
-# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once
-# the archive has been propagated in mirror.bazel.build.
-GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
+GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 5700007256..4d2437e7d3 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -1,6 +1,8 @@
# Description:
# TensorFlow camera demo app for Android.
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
@@ -24,28 +26,29 @@ cc_library(
android_binary(
name = "tflite_demo",
srcs = glob([
- "src/**/*.java",
+ "app/src/main/java/**/*.java",
]),
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
assets = [
- "//tensorflow/contrib/lite/examples/android/assets:labels_mobilenet_quant_v1_224.txt",
+ "//tensorflow/contrib/lite/examples/android/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
"@tflite_conv_actions_frozen//:conv_actions_frozen.tflite",
- "//tensorflow/contrib/lite/examples/android/assets:conv_actions_labels.txt",
+ "//tensorflow/contrib/lite/examples/android/app/src/main/assets:conv_actions_labels.txt",
"@tflite_mobilenet_ssd//:mobilenet_ssd.tflite",
- "//tensorflow/contrib/lite/examples/android/assets:box_priors.txt",
- "//tensorflow/contrib/lite/examples/android/assets:coco_labels_list.txt",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
+ "//tensorflow/contrib/lite/examples/android/app/src/main/assets:box_priors.txt",
+ "//tensorflow/contrib/lite/examples/android/app/src/main/assets:coco_labels_list.txt",
],
assets_dir = "",
custom_package = "org.tensorflow.lite.demo",
inline_constants = 1,
- manifest = "AndroidManifest.xml",
+ manifest = "app/src/main/AndroidManifest.xml",
nocompress_extensions = [
".tflite",
],
- resource_files = glob(["res/**"]),
+ resource_files = glob(["app/src/main/res/**"]),
tags = [
"manual",
"notap",
@@ -55,31 +58,3 @@ android_binary(
"//tensorflow/contrib/lite/java:tensorflowlite",
],
)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- "bin/**",
- "gen/**",
- "gradleBuild/**",
- "libs/**",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
-filegroup(
- name = "java_files",
- srcs = glob(["src/**/*.java"]),
-)
-
-filegroup(
- name = "resource_files",
- srcs = glob(["res/**"]),
-)
-
-exports_files(["AndroidManifest.xml"])
diff --git a/tensorflow/contrib/lite/examples/android/android.iml b/tensorflow/contrib/lite/examples/android/android.iml
new file mode 100644
index 0000000000..f0a5ac2bf4
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/android.iml
@@ -0,0 +1,19 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module external.linked.project.id="android" external.linked.project.path="$MODULE_DIR$" external.root.project.path="$MODULE_DIR$" external.system.id="GRADLE" type="JAVA_MODULE" version="4">
+ <component name="FacetManager">
+ <facet type="java-gradle" name="Java-Gradle">
+ <configuration>
+ <option name="BUILD_FOLDER_PATH" value="$MODULE_DIR$/build" />
+ <option name="BUILDABLE" value="false" />
+ </configuration>
+ </facet>
+ </component>
+ <component name="NewModuleRootManager" LANGUAGE_LEVEL="JDK_1_8" inherit-compiler-output="true">
+ <exclude-output />
+ <content url="file://$MODULE_DIR$">
+ <excludeFolder url="file://$MODULE_DIR$/.gradle" />
+ </content>
+ <orderEntry type="inheritedJdk" />
+ <orderEntry type="sourceFolder" forTests="false" />
+ </component>
+</module> \ No newline at end of file
diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle
new file mode 100644
index 0000000000..1ffb9dd377
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/app/build.gradle
@@ -0,0 +1,60 @@
+apply plugin: 'com.android.application'
+
+android {
+ compileSdkVersion 26
+ buildToolsVersion '26.0.2'
+ defaultConfig {
+ applicationId "org.tensorflow.lite.demo"
+ minSdkVersion 15
+ targetSdkVersion 26
+ versionCode 1
+ versionName "1.0"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
+
+ // Remove this block.
+ jackOptions {
+ enabled true
+ }
+ }
+ lintOptions {
+ abortOnError false
+ }
+ buildTypes {
+ release {
+ minifyEnabled false
+ proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
+ }
+ }
+ aaptOptions {
+ noCompress "tflite"
+ }
+
+ compileOptions {
+ sourceCompatibility JavaVersion.VERSION_1_8
+ targetCompatibility JavaVersion.VERSION_1_8
+ }
+}
+
+repositories {
+ maven {
+ url 'https://google.bintray.com/tensorflow'
+ }
+}
+
+// import DownloadModels task
+project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
+project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
+
+// Download default models; if you wish to use your own models then
+// place them in the "assets" directory and comment out this line.
+apply from: "download-models.gradle"
+
+dependencies {
+ compile fileTree(dir: 'libs', include: ['*.jar'])
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
+ exclude group: 'com.android.support', module: 'support-annotations'
+ })
+ compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
+
+ testCompile 'junit:junit:4.12'
+}
diff --git a/tensorflow/contrib/lite/examples/android/app/download-models.gradle b/tensorflow/contrib/lite/examples/android/app/download-models.gradle
new file mode 100644
index 0000000000..c100e37c16
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/app/download-models.gradle
@@ -0,0 +1,74 @@
+/*
+ * download-models.gradle
+ * Downloads model files from ${MODEL_URL} into application's asset folder
+ * Input:
+ * project.ext.TMP_DIR: absolute path to hold downloaded zip files
+ * project.ext.ASSET_DIR: absolute path to save unzipped model files
+ * Output:
+ * 3 model files will be downloaded into given folder of ext.ASSET_DIR
+ */
+// hard coded model files
+// LINT.IfChange
+
+def models = ['conv_actions_tflite.zip',
+ 'mobilenet_ssd_tflite_v1.zip',
+ 'mobilenet_v1_224_android_quant_2017_11_08.zip',
+ 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.zip']
+// LINT.ThenChange(//tensorflow/contrib/lite/examples/android/BUILD)
+
+// Root URL for model archives
+def MODEL_URL = 'https://storage.googleapis.com/download.tensorflow.org/models/tflite'
+
+buildscript {
+ repositories {
+ jcenter()
+ }
+ dependencies {
+ classpath 'de.undercouch:gradle-download-task:3.2.0'
+ }
+}
+
+import de.undercouch.gradle.tasks.download.Download
+task downloadFile(type: Download){
+ for (f in models) {
+ def modelUrl = MODEL_URL + "/" + f
+ println "Downloading ${f} from ${modelUrl}"
+ src modelUrl
+ }
+
+ dest new File(project.ext.TMP_DIR)
+ overwrite true
+}
+
+task extractModels(type: Copy) {
+ for (f in models) {
+ def localFile = f.split("/")[-1]
+ from zipTree(project.ext.TMP_DIR + '/' + localFile)
+ }
+
+ into file(project.ext.ASSET_DIR)
+ fileMode 0644
+ exclude '**/LICENSE'
+
+ def needDownload = false
+ for (f in models) {
+ def localFile = f.split("/")[-1]
+ if (!(new File(project.ext.TMP_DIR + '/' + localFile)).exists()) {
+ needDownload = true
+ }
+ }
+
+ if (needDownload) {
+ dependsOn downloadFile
+ }
+}
+
+tasks.whenTaskAdded { task ->
+ if (task.name == 'assembleDebug') {
+ task.dependsOn 'extractModels'
+ }
+ if (task.name == 'assembleRelease') {
+ task.dependsOn 'extractModels'
+ }
+}
+
diff --git a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml
index bc9574d646..bc9574d646 100644
--- a/tensorflow/contrib/lite/examples/android/AndroidManifest.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/AndroidManifest.xml
diff --git a/tensorflow/contrib/lite/examples/android/assets/BUILD b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD
index dd0cd6c98f..dd0cd6c98f 100644
--- a/tensorflow/contrib/lite/examples/android/assets/BUILD
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/BUILD
diff --git a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt
index 7246b073fe..7246b073fe 100644
--- a/tensorflow/contrib/lite/examples/android/assets/box_priors.txt
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/box_priors.txt
diff --git a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt
index 5a70ff82aa..5a70ff82aa 100644
--- a/tensorflow/contrib/lite/examples/android/assets/coco_labels_list.txt
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/coco_labels_list.txt
diff --git a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt
index ba416458b0..ba416458b0 100644
--- a/tensorflow/contrib/lite/examples/android/assets/conv_actions_labels.txt
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/conv_actions_labels.txt
diff --git a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
index fe811239d8..fe811239d8 100644
--- a/tensorflow/contrib/lite/examples/android/assets/labels_mobilenet_quant_v1_224.txt
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/labels_mobilenet_quant_v1_224.txt
diff --git a/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt
new file mode 100644
index 0000000000..d581f733e4
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/assets/pets_labels_list.txt
@@ -0,0 +1,38 @@
+???
+Abyssinian
+american_bulldog
+american_pit_bull_terrier
+basset_hound
+beagle
+Bengal
+Birman
+Bombay
+boxer
+British_Shorthair
+chihuahua
+Egyptian_Mau
+english_cocker_spaniel
+english_setter
+german_shorthaired
+great_pyrenees
+havanese
+japanese_chin
+keeshond
+leonberger
+Maine_Coon
+miniature_pinscher
+newfoundland
+Persian
+pomeranian
+pug
+Ragdoll
+Russian_Blue
+saint_bernard
+samoyed
+scottish_terrier
+shiba_inu
+Siamese
+Sphynx
+staffordshire_bull_terrier
+wheaten_terrier
+yorkshire_terrier
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java
index eff24afdba..eff24afdba 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/AutoFitTextureView.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/AutoFitTextureView.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java
index 15d5456f02..15d5456f02 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraActivity.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java
index 51a1adb538..51a1adb538 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/CameraConnectionFragment.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/CameraConnectionFragment.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java
index 07995febaf..07995febaf 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/Classifier.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/Classifier.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java
index dcbbefbeab..dcbbefbeab 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ClassifierActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ClassifierActivity.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
index de997e454a..87160f6b3f 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/DetectorActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/DetectorActivity.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -50,9 +50,10 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
// Configuration values for the prepackaged SSD model.
private static final int TF_OD_API_INPUT_SIZE = 300;
- private static final String TF_OD_API_MODEL_FILE = "mobilenet_ssd.tflite";
+ private static final boolean TF_OD_API_IS_QUANTIZED = true;
+ private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco_labels_list.txt";
-
+
// Which detection model to use: by default uses Tensorflow Object Detection API frozen
// checkpoints.
private enum DetectorMode {
@@ -107,7 +108,11 @@ public class DetectorActivity extends CameraActivity implements OnImageAvailable
try {
detector =
TFLiteObjectDetectionAPIModel.create(
- getAssets(), TF_OD_API_MODEL_FILE, TF_OD_API_LABELS_FILE, TF_OD_API_INPUT_SIZE);
+ getAssets(),
+ TF_OD_API_MODEL_FILE,
+ TF_OD_API_LABELS_FILE,
+ TF_OD_API_INPUT_SIZE,
+ TF_OD_API_IS_QUANTIZED);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
LOGGER.e("Exception initializing classifier!", e);
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java
index fd83029753..fd83029753 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/LegacyCameraConnectionFragment.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/LegacyCameraConnectionFragment.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java
index 0f8d109fb4..0f8d109fb4 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/OverlayView.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/OverlayView.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java
index 31a4b07c83..31a4b07c83 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognitionScoreView.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognitionScoreView.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java
index 9e91aea7ef..9e91aea7ef 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/RecognizeCommands.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/RecognizeCommands.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java
index 211d7e66fb..211d7e66fb 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/ResultsView.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/ResultsView.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java
index 9c9c30bc09..9c9c30bc09 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/SpeechActivity.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/SpeechActivity.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java
index d75c3ceada..d75c3ceada 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteImageClassifier.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteImageClassifier.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
index bfb4a0a04b..9eb21de9d0 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/TFLiteObjectDetectionAPIModel.java
@@ -25,15 +25,14 @@ import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.PriorityQueue;
-import java.util.StringTokenizer;
import java.util.Vector;
import org.tensorflow.demo.env.Logger;
import org.tensorflow.lite.Interpreter;
@@ -46,32 +45,35 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
- private static final int NUM_RESULTS = 1917;
- private static final int NUM_CLASSES = 91;
-
- private static final float Y_SCALE = 10.0f;
- private static final float X_SCALE = 10.0f;
- private static final float H_SCALE = 5.0f;
- private static final float W_SCALE = 5.0f;
-
+ private static final int NUM_DETECTIONS = 10;
+ private boolean isModelQuantized;
+ // Float model
+ private static final float IMAGE_MEAN = 128.0f;
+ private static final float IMAGE_STD = 128.0f;
+ // Number of threads in the java app
+ private static final int NUM_THREADS = 4;
// Config values.
private int inputSize;
-
- private final float[][] boxPriors = new float[4][NUM_RESULTS];
-
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
+ // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
+ // contains the location of detected boxes
private float[][][] outputLocations;
- private float[][][] outputClasses;
-
- float[][][][] img;
+ // outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the classes of detected boxes
+ private float[][] outputClasses;
+ // outputScores: array of shape [Batchsize, NUM_DETECTIONS]
+ // contains the scores of detected boxes
+ private float[][] outputScores;
+ // numDetections: array of shape [Batchsize]
+ // contains the number of detected boxes
+ private float[] numDetections;
+
+ private ByteBuffer imgData;
private Interpreter tfLite;
- private float expit(final float x) {
- return (float) (1. / (1. + Math.exp(-x)));
- }
/** Memory-map the model file in Assets. */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
@@ -84,77 +86,24 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
- private void loadCoderOptions(
- final AssetManager assetManager, final String locationFilename, final float[][] boxPriors)
- throws IOException {
- // Try to be intelligent about opening from assets or sdcard depending on prefix.
- final String assetPrefix = "file:///android_asset/";
- InputStream is;
- if (locationFilename.startsWith(assetPrefix)) {
- is = assetManager.open(locationFilename.split(assetPrefix, -1)[1]);
- } else {
- is = new FileInputStream(locationFilename);
- }
-
- final BufferedReader reader = new BufferedReader(new InputStreamReader(is));
-
- for (int lineNum = 0; lineNum < 4; ++lineNum) {
- String line = reader.readLine();
- final StringTokenizer st = new StringTokenizer(line, ", ");
- int priorIndex = 0;
- while (st.hasMoreTokens()) {
- final String token = st.nextToken();
- try {
- final float number = Float.parseFloat(token);
- boxPriors[lineNum][priorIndex++] = number;
- } catch (final NumberFormatException e) {
- // Silently ignore.
- }
- }
- if (priorIndex != NUM_RESULTS) {
- throw new RuntimeException(
- "BoxPrior length mismatch: " + priorIndex + " vs " + NUM_RESULTS);
- }
- }
-
- LOGGER.i("Loaded box priors!");
- }
-
- void decodeCenterSizeBoxes(float[][][] predictions) {
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float ycenter = predictions[0][i][0] / Y_SCALE * boxPriors[2][i] + boxPriors[0][i];
- float xcenter = predictions[0][i][1] / X_SCALE * boxPriors[3][i] + boxPriors[1][i];
- float h = (float) Math.exp(predictions[0][i][2] / H_SCALE) * boxPriors[2][i];
- float w = (float) Math.exp(predictions[0][i][3] / W_SCALE) * boxPriors[3][i];
-
- float ymin = ycenter - h / 2.f;
- float xmin = xcenter - w / 2.f;
- float ymax = ycenter + h / 2.f;
- float xmax = xcenter + w / 2.f;
-
- predictions[0][i][0] = ymin;
- predictions[0][i][1] = xmin;
- predictions[0][i][2] = ymax;
- predictions[0][i][3] = xmax;
- }
- }
-
/**
* Initializes a native TensorFlow session for classifying images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the model GraphDef protocol buffer.
* @param labelFilename The filepath of label file for classes.
+ * @param inputSize The size of image input
+ * @param isQuantized Boolean representing model is quantized or not
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
- final int inputSize) throws IOException {
+ final int inputSize,
+ final boolean isQuantized)
+ throws IOException {
final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
- d.loadCoderOptions(assetManager, "file:///android_asset/box_priors.txt", d.boxPriors);
-
InputStream labelsInput = null;
String actualFilename = labelFilename.split("file:///android_asset/")[1];
labelsInput = assetManager.open(actualFilename);
@@ -175,12 +124,23 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
throw new RuntimeException(e);
}
+ d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
- d.img = new float[1][inputSize][inputSize][3];
-
+ int numBytesPerChannel;
+ if (isQuantized) {
+ numBytesPerChannel = 1; // Quantized
+ } else {
+ numBytesPerChannel = 4; // Floating point
+ }
+ d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
+ d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
- d.outputLocations = new float[1][NUM_RESULTS][4];
- d.outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+
+ d.tfLite.setNumThreads(NUM_THREADS);
+ d.outputLocations = new float[1][NUM_DETECTIONS][4];
+ d.outputClasses = new float[1][NUM_DETECTIONS];
+ d.outputScores = new float[1][NUM_DETECTIONS];
+ d.numDetections = new float[1];
return d;
}
@@ -196,25 +156,37 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
- int pixel = intValues[j * inputSize + i];
- img[0][j][i][2] = (float) (pixel & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][1] = (float) ((pixel >> 8) & 0xFF) / 128.0f - 1.0f;
- img[0][j][i][0] = (float) ((pixel >> 16) & 0xFF) / 128.0f - 1.0f;
+ int pixelValue = intValues[i * inputSize + j];
+ if (isModelQuantized) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else { // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
}
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
- outputLocations = new float[1][NUM_RESULTS][4];
- outputClasses = new float[1][NUM_RESULTS][NUM_CLASSES];
+ outputLocations = new float[1][NUM_DETECTIONS][4];
+ outputClasses = new float[1][NUM_DETECTIONS];
+ outputScores = new float[1][NUM_DETECTIONS];
+ numDetections = new float[1];
- Object[] inputArray = {img};
+ Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
+ outputMap.put(2, outputScores);
+ outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
@@ -222,56 +194,26 @@ public class TFLiteObjectDetectionAPIModel implements Classifier {
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
- decodeCenterSizeBoxes(outputLocations);
-
- // Find the best detections.
- final PriorityQueue<Recognition> pq =
- new PriorityQueue<Recognition>(
- 1,
- new Comparator<Recognition>() {
- @Override
- public int compare(final Recognition lhs, final Recognition rhs) {
- // Intentionally reversed to put high confidence at the head of the queue.
- return Float.compare(rhs.getConfidence(), lhs.getConfidence());
- }
- });
-
- // Scale them back to the input size.
- for (int i = 0; i < NUM_RESULTS; ++i) {
- float topClassScore = -1000f;
- int topClassScoreIndex = -1;
-
- // Skip the first catch-all class.
- for (int j = 1; j < NUM_CLASSES; ++j) {
- float score = expit(outputClasses[0][i][j]);
-
- if (score > topClassScore) {
- topClassScoreIndex = j;
- topClassScore = score;
- }
- }
-
- if (topClassScore > 0.001f) {
- final RectF detection =
- new RectF(
- outputLocations[0][i][1] * inputSize,
- outputLocations[0][i][0] * inputSize,
- outputLocations[0][i][3] * inputSize,
- outputLocations[0][i][2] * inputSize);
-
- pq.add(
- new Recognition(
- "" + i,
- labels.get(topClassScoreIndex),
- outputClasses[0][i][topClassScoreIndex],
- detection));
- }
- }
-
- final ArrayList<Recognition> recognitions = new ArrayList<Recognition>();
- for (int i = 0; i < Math.min(pq.size(), 10); ++i) {
- Recognition recog = pq.poll();
- recognitions.add(recog);
+ // Show the best detections.
+ // after scaling them back to the input size.
+ final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
+ for (int i = 0; i < NUM_DETECTIONS; ++i) {
+ final RectF detection =
+ new RectF(
+ outputLocations[0][i][1] * inputSize,
+ outputLocations[0][i][0] * inputSize,
+ outputLocations[0][i][3] * inputSize,
+ outputLocations[0][i][2] * inputSize);
+ // SSD Mobilenet V1 Model assumes class 0 is background class
+ // in label file and class labels start from 1 to number_of_classes+1,
+ // while outputClasses correspond to class index from 0 to number_of_classes
+ int labelOffset = 1;
+ recognitions.add(
+ new Recognition(
+ "" + i,
+ labels.get((int) outputClasses[0][i] + labelOffset),
+ outputScores[0][i],
+ detection));
}
Trace.endSection(); // "recognizeImage"
return recognitions;
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java
index c50efdf889..c50efdf889 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/AssetUtils.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/AssetUtils.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java
index decfc3d879..decfc3d879 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/BorderedText.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/BorderedText.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java
index e02c655917..e02c655917 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/ImageUtils.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/ImageUtils.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java
index 0d984096a0..0d984096a0 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Logger.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Logger.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java
index ef15d14daa..ef15d14daa 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/Size.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/Size.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java
index 459b0a0d4d..459b0a0d4d 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/env/SplitTimer.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/env/SplitTimer.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java
index af6af2bc8f..af6af2bc8f 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/MultiBoxTracker.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/MultiBoxTracker.java
diff --git a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java
index 8b4248d8fb..8b4248d8fb 100644
--- a/tensorflow/contrib/lite/examples/android/src/org/tensorflow/demo/tracking/ObjectTracker.java
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/java/org/tensorflow/demo/tracking/ObjectTracker.java
diff --git a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml
index 891d8cc1d4..891d8cc1d4 100644
--- a/tensorflow/contrib/lite/examples/android/res/animator/color_animation.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/animator/color_animation.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png
index 32bd1aabca..32bd1aabca 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_action_info.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png
index b3113cd15c..b3113cd15c 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png
index 135862883e..135862883e 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-hdpi/tile.9.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-hdpi/tile.9.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png
index 8efbbf8b3c..8efbbf8b3c 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_action_info.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png
index 51f87ee650..51f87ee650 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-mdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-mdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png
index ba143ea7a8..ba143ea7a8 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_action_info.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png
index 6361d792da..6361d792da 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-xhdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png
index 394eb7e534..394eb7e534 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_action_info.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_action_info.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png
index 2e27bec978..2e27bec978 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable-xxhdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable-xxhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml
index dd1d64d1d6..dd1d64d1d6 100644
--- a/tensorflow/contrib/lite/examples/android/res/drawable/border.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/drawable/border.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml
index 1a22d4b33e..1a22d4b33e 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/activity_camera.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_camera.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml
index 2fe1338da5..2fe1338da5 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/activity_speech.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/activity_speech.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml
index a1bbdf1702..a1bbdf1702 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml
index 1cdb24cab0..1cdb24cab0 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_stylize.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_stylize.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml
index ca18ea075d..ca18ea075d 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/camera_connection_fragment_tracking.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/camera_connection_fragment_tracking.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml
index 526017fbb2..526017fbb2 100644
--- a/tensorflow/contrib/lite/examples/android/res/layout/list_text_item.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/layout/list_text_item.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml
index 820eda0e55..820eda0e55 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-dimens.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-dimens.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml
index 09303314e9..09303314e9 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-sw600dp/template-styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-sw600dp/template-styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml
index c2d1babc12..c2d1babc12 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-v11/styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml
index 1ad048439c..1ad048439c 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-v11/template-styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v11/template-styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml
index cc370849c0..cc370849c0 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-v14/styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v14/styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml
index c16da7c51c..c16da7c51c 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-v21/base-colors.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-colors.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml
index 8890d2f4a5..8890d2f4a5 100644
--- a/tensorflow/contrib/lite/examples/android/res/values-v21/base-template-styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values-v21/base-template-styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml
index 56e5beae76..56e5beae76 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/attrs.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/attrs.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml
index ebc5dc8423..ebc5dc8423 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/base-strings.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/base-strings.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/colors.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml
index 584ed6052d..584ed6052d 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/colors.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/colors.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/strings.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml
index ea20ee78e0..ea20ee78e0 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/strings.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/strings.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml
index dd1d973e9b..dd1d973e9b 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml
index 069977b6a6..069977b6a6 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/template-dimens.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-dimens.xml
diff --git a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml
index 1b87714a49..1b87714a49 100644
--- a/tensorflow/contrib/lite/examples/android/res/values/template-styles.xml
+++ b/tensorflow/contrib/lite/examples/android/app/src/main/res/values/template-styles.xml
diff --git a/tensorflow/contrib/lite/examples/android/build.gradle b/tensorflow/contrib/lite/examples/android/build.gradle
index 0d4de35815..a47fa4bbf6 100644
--- a/tensorflow/contrib/lite/examples/android/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/build.gradle
@@ -1,52 +1,23 @@
-apply plugin: 'com.android.application'
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
-android {
- compileSdkVersion 26
- buildToolsVersion "26.0.1"
- defaultConfig {
- applicationId "org.tensorflow.lite.demo"
- minSdkVersion 15
- targetSdkVersion 26
- versionCode 1
- versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
-
- // Remove this block.
- jackOptions {
- enabled true
- }
- }
- lintOptions {
- abortOnError false
- }
- buildTypes {
- release {
- minifyEnabled false
- proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro'
- }
- }
- aaptOptions {
- noCompress "tflite"
+buildscript {
+ repositories {
+ jcenter()
}
+ dependencies {
+ classpath 'com.android.tools.build:gradle:3.0.1'
- compileOptions {
- sourceCompatibility JavaVersion.VERSION_1_8
- targetCompatibility JavaVersion.VERSION_1_8
+ // NOTE: Do not place your application dependencies here; they belong
+ // in the individual module build.gradle files
}
}
-repositories {
- maven {
- url 'https://google.bintray.com/tensorflow'
+allprojects {
+ repositories {
+ jcenter()
}
}
-dependencies {
- compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
- compile 'org.tensorflow:tensorflow-lite:+'
-
- testCompile 'junit:junit:4.12'
+task clean(type: Delete) {
+ delete rootProject.buildDir
}
diff --git a/tensorflow/contrib/lite/examples/android/settings.gradle b/tensorflow/contrib/lite/examples/android/settings.gradle
new file mode 100644
index 0000000000..e7b4def49c
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/android/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD
index 9322e186a2..c61445114e 100644
--- a/tensorflow/contrib/lite/examples/label_image/BUILD
+++ b/tensorflow/contrib/lite/examples/label_image/BUILD
@@ -53,19 +53,18 @@ cc_library(
],
)
-# TODO(ahentz): Test disabled as it has a memory leek from read_bmp
-# cc_test(
-# name = "label_image_test",
-# srcs = [
-# "get_top_n.h",
-# "get_top_n_impl.h",
-# "label_image_test.cc",
-# ],
-# data = [
-# "testdata/grace_hopper.bmp",
-# ],
-# deps = [
-# ":bitmap_helpers",
-# "//testing/base/public:gunit",
-# ],
-# )
+cc_test(
+ name = "label_image_test",
+ srcs = [
+ "get_top_n.h",
+ "get_top_n_impl.h",
+ "label_image_test.cc",
+ ],
+ data = [
+ "testdata/grace_hopper.bmp",
+ ],
+ deps = [
+ ":bitmap_helpers",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
index 0b38cd38c8..2735d1f5ea 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc
@@ -28,8 +28,9 @@ limitations under the License.
namespace tflite {
namespace label_image {
-uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output,
- int width, int height, int channels, bool top_down) {
+std::vector<uint8_t> decode_bmp(const uint8_t* input, int row_size, int width,
+ int height, int channels, bool top_down) {
+ std::vector<uint8_t> output(height * width * channels);
for (int i = 0; i < height; i++) {
int src_pos;
int dst_pos;
@@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output,
}
}
}
-
return output;
}
-uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
- int* channels, Settings* s) {
+std::vector<uint8_t> read_bmp(const std::string& input_bmp_name, int* width,
+ int* height, int* channels, Settings* s) {
int begin, end;
std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary);
@@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
if (s->verbose) LOG(INFO) << "len: " << len << "\n";
- const uint8_t* img_bytes = new uint8_t[len];
+ std::vector<uint8_t> img_bytes(len);
file.seekg(0, std::ios::beg);
- file.read((char*)img_bytes, len);
+ file.read(reinterpret_cast<char*>(img_bytes.data()), len);
const int32_t header_size =
- *(reinterpret_cast<const int32_t*>(img_bytes + 10));
- *width = *(reinterpret_cast<const int32_t*>(img_bytes + 18));
- *height = *(reinterpret_cast<const int32_t*>(img_bytes + 22));
- const int32_t bpp = *(reinterpret_cast<const int32_t*>(img_bytes + 28));
+ *(reinterpret_cast<const int32_t*>(img_bytes.data() + 10));
+ *width = *(reinterpret_cast<const int32_t*>(img_bytes.data() + 18));
+ *height = *(reinterpret_cast<const int32_t*>(img_bytes.data() + 22));
+ const int32_t bpp =
+ *(reinterpret_cast<const int32_t*>(img_bytes.data() + 28));
*channels = bpp / 8;
if (s->verbose)
@@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
bool top_down = (*height < 0);
// Decode image, allocating tensor once the image size is known
- uint8_t* output = new uint8_t[abs(*height) * *width * *channels];
const uint8_t* bmp_pixels = &img_bytes[header_size];
- return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height),
- *channels, top_down);
+ return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels,
+ top_down);
}
} // namespace label_image
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
index 97343dde6b..5fc75b1f72 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h
@@ -22,8 +22,8 @@ limitations under the License.
namespace tflite {
namespace label_image {
-uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height,
- int* channels, Settings* s);
+std::vector<uint8_t> read_bmp(const std::string& input_bmp_name, int* width,
+ int* height, int* channels, Settings* s);
template <class T>
void resize(T* out, uint8_t* in, int image_height, int image_width,
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc
index 966fcd2a31..86d7d1cc4a 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc
@@ -138,8 +138,8 @@ void RunInference(Settings* s) {
int image_width = 224;
int image_height = 224;
int image_channels = 3;
- uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height,
- &image_channels, s);
+ std::vector<uint8_t> in = read_bmp(s->input_bmp_name, &image_width,
+ &image_height, &image_channels, s);
int input = interpreter->inputs()[0];
if (s->verbose) LOG(INFO) << "input: " << input << "\n";
@@ -168,12 +168,12 @@ void RunInference(Settings* s) {
switch (interpreter->tensor(input)->type) {
case kTfLiteFloat32:
s->input_floating = true;
- resize<float>(interpreter->typed_tensor<float>(input), in, image_height,
- image_width, image_channels, wanted_height, wanted_width,
- wanted_channels, s);
+ resize<float>(interpreter->typed_tensor<float>(input), in.data(),
+ image_height, image_width, image_channels, wanted_height,
+ wanted_width, wanted_channels, s);
break;
case kTfLiteUInt8:
- resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in,
+ resize<uint8_t>(interpreter->typed_tensor<uint8_t>(input), in.data(),
image_height, image_width, image_channels, wanted_height,
wanted_width, wanted_channels, s);
break;
diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
index ce35483f76..de7de21f77 100644
--- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
+++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc
@@ -27,20 +27,20 @@ namespace label_image {
TEST(LabelImageTest, GraceHopper) {
std::string lena_file =
- "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp";
+ "tensorflow/contrib/lite/examples/label_image/testdata/"
+ "grace_hopper.bmp";
int height, width, channels;
Settings s;
- uint8_t *data;
-
- data = read_bmp(lena_file, &width, &height, &channels, &s);
+ std::vector<uint8_t> input =
+ read_bmp(lena_file, &width, &height, &channels, &s);
ASSERT_EQ(height, 606);
ASSERT_EQ(width, 517);
ASSERT_EQ(channels, 3);
- uint8_t *out = new uint8_t[606 * 517 * 3];
- downsize<uint8_t>(out, data, 606, 517, 3, 214, 214, 3, &s);
- ASSERT_EQ(out[0], 0x15);
- ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12);
+ std::vector<uint8_t> output(606 * 517 * 3);
+ resize<uint8_t>(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s);
+ ASSERT_EQ(output[0], 0x15);
+ ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11);
}
TEST(LabelImageTest, GetTopN) {
diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD
new file mode 100644
index 0000000000..b403628d6c
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/minimal/BUILD
@@ -0,0 +1,27 @@
+# Description:
+# TensorFlow Lite minimal example.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
+
+tf_cc_binary(
+ name = "minimal",
+ srcs = [
+ "minimal.cc",
+ ],
+ linkopts = tflite_linkopts() + select({
+ "//tensorflow:android": [
+ "-pie", # Android 5.0 and later supports only PIE
+ "-lm", # some builtin ops, e.g., tanh, need -lm
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc
index 106e3b0270..8b65cde7b7 100644
--- a/tensorflow/contrib/lite/examples/minimal/minimal.cc
+++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/model.h"
+#include <cstdio>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
-#include <cstdio>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/optional_debug_tools.h"
// This is an example that is minimal to read a model
// from disk and perform inference. There is no data being loaded
@@ -29,23 +30,22 @@ limitations under the License.
using namespace tflite;
-#define TFLITE_MINIMAL_CHECK(x) \
- if(!(x)) { \
- fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
- exit(1); \
+#define TFLITE_MINIMAL_CHECK(x) \
+ if (!(x)) { \
+ fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
+ exit(1); \
}
-
-int main(int argc, char *argv[]) {
+int main(int argc, char* argv[]) {
if(argc != 2) {
- fprintf(stderr, "Usage: %s <model>\n");
+ fprintf(stderr, "minimal <tflite model>\n");
return 1;
}
const char* filename = argv[1];
// Load model
- std::unique_ptr<tflite::FlatBufferModel> model
- = tflite::FlatBufferModel::BuildFromFile(filename);
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(filename);
TFLITE_MINIMAL_CHECK(model != nullptr);
// Build the interpreter
@@ -57,12 +57,16 @@ int main(int argc, char *argv[]) {
// Allocate tensor buffers.
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
+ printf("=== Pre-invoke Interpreter State ===\n");
+ tflite::PrintInterpreterState(interpreter.get());
// Fill input buffers
// TODO(user): Insert code to fill input tensors
// Run inference
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
+ printf("\n\n=== Post-invoke Interpreter State ===\n");
+ tflite::PrintInterpreterState(interpreter.get());
// Read output buffers
// TODO(user): Insert getting data out code.
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index 50cc146a87..a591a353dd 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -7,6 +7,9 @@ no surprise that the APIs try to avoid unnecessary copies at the expense of
convenience. Similarly, consistency with TensorFlow APIs was not an explicit
goal and some variance is to be expected.
+There is also a Python API for TensorFlow Lite described
+[here](../toco/g3doc/python_api.md#interpreter).
+
## C++
In order to run the inference model in TensorFlow Lite, one has to load the
diff --git a/tensorflow/contrib/lite/g3doc/benchmarks.md b/tensorflow/contrib/lite/g3doc/benchmarks.md
new file mode 100644
index 0000000000..96536cba27
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/benchmarks.md
@@ -0,0 +1,178 @@
+# Performance Benchmark numbers
+
+This document contains the performance benchmark numbers for running a few well
+known models on some Android and iOS devices.
+
+The benchmark numbers were generated by running the [TFLite benchmark
+binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
+on Android and running the [iOS benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+on iOS.
+
+# Android benchmarks
+
+When running Android benchmarks, the CPU affinity is set to use big cores on the
+device to reduce variance (see
+[details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+Models are assumed to have been downloaded from the link, unzipped and pushed to
+`/data/local/tmp/tflite_models` folder. The benchmark binary is built according
+to instructions listed
+[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
+and is assumed to have been pushed to `/data/local/tmp`.
+
+The following command was used to run the benchmark:
+
+```
+adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
+ --num_threads=1 \
+ --graph=/data/local/tmp/tflite_models/${GRAPH} \
+ --warmup_runs=1 \
+ --num_runs=50 \
+ --use_nnapi=false
+```
+
+where `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
+chosen according to the following table:
+
+Device | CPU_MASK |
+-------| ----------
+Pixel 2 | f0 |
+Pixel xl | 0c |
+
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>166.5 ms (2.6 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>122.9 ms (1.8 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>69.5 ms (0.9 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>78.9 ms (2.2 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>273.8 ms (3.5 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>210.8 ms (4.2 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>234.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>158.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>2846.0 ms (15.0 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>1973.0 ms (15.0 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>3180.0 ms (11.7 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>2262.0 ms (21.0 ms) </td>
+ </tr>
+
+ </table>
+
+# iOS benchmarks
+
+For running iOS benchmarks, the [benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+was modified to include the appropriate model and `benchmark_params.json` was
+modified to set `num_threads` to 1.
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>32.2 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>24.4 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>60.3 ms (0.6 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>44.3 (0.7 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>iPhone 8</td>
+ <td>562.4 ms (18.2 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>661.0 ms (29.2 ms)</td>
+ </tr>
+ </table>
diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md
new file mode 100644
index 0000000000..bd2f797e6c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md
@@ -0,0 +1,206 @@
+# TensorFlow Lite Ops Versioning
+
+This document describes TensorFlow Lite's op versioning schema. Op
+versioning enables developers to add new functionalities and parameters into
+existing ops. In addition, it guarantees the following:
+
+* Backward compatibility: New TensorFlow Lite implementation should
+ handle an old model file.
+* Forward compatibility: Old TensorFlow Lite implementation should
+ handle a new model file produced by new version of TOCO, as long as no new
+ features are used.
+* Forward in-compatibility detection: If an old TensorFlow Lite implementation
+ reads a new model that contains a new version of an op which isn't
+ supported, it should report the error.
+
+## Example: Adding Dilation into Convolution
+
+The remainder of this document explains op versioning in TFLite by showing how
+to add dilation parameters to the convolution operation.
+
+Knowledge of dilation is not required to understand this document. Note that:
+
+* 2 new integer parameters will be added: `dilation_width_factor` and
+ `dilation_height_factor`.
+* Old convolution kernels that don't support dilation are equivalent to
+ setting the dilation factors to 1.
+
+### Change FlatBuffer Schema
+
+To add new parameters into an op, change the options table in
+`lite/schema/schema.fbs`.
+
+For example, the options table of convolution looks like this:
+
+```
+table Conv2DOptions {
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+}
+```
+
+When adding new parameters:
+
+* Add comments indicating which parameters are supported by which version.
+* When the new implementation gets the default values for newly added
+ parameters, it should work exactly the same as the old implementation.
+
+The table will be like this after the new parameters are added:
+
+```
+table Conv2DOptions {
+ // Parameters supported by version 1:
+ padding:Padding;
+ stride_w:int;
+ stride_h:int;
+ fused_activation_function:ActivationFunctionType;
+
+ // Parameters supported by version 2:
+ dilation_width_factor:int = 1;
+ dilation_height_factor:int = 1;
+}
+```
+
+### Change C Structures and Kernel Implementation
+
+In TensorFlow Lite, the kernel implementation is decoupled from
+FlatBuffer definition. The kernels read the parameter from C structures defined
+in `lite/builtin_op_data.h`.
+
+The original convolution parameter is as follows:
+
+```
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+```
+
+As with the FlatBuffer schema, add comments indicating which parameters are
+supported starting from which version. The result is seen below:
+
+```
+typedef struct {
+ // Parameters supported by version 1: TfLitePadding padding; int
+ stride_width;
+ int stride_height;
+ TfLiteFusedActivation activation;
+
+ // Parameters supported by version 2:
+ int dilation_width_factor;
+ int dilation_height_factor;
+} TfLiteConvParams;
+```
+
+Please also change the kernel implementation to read the newly added parameters
+from the C structures. The details are omitted here.
+
+### Change the FlatBuffer Reading Code
+
+The logic to read FlatBuffer and produce C structure is in `lite/model.cc`.
+
+Update the file to handle the new parameters, as shown below:
+
+```
+case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ params->dilation_width_factor = conv_params->dilation_width_factor();
+ params->dilation_height_factor = conv_params->dilation_height_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+}
+```
+
+It's not required to check the op version here. When the new implementation
+reads an old model file where dilation factors are missing, it will use 1 as
+the default value, and the new kernel will work consistently with the old
+kernel.
+
+### Change Kernel Registration
+
+The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions
+to register op kernels. The minimum and maximum version are 1 by default:
+
+```
+void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+```
+
+The built-in ops are registered in `lite/kernels/register.cc`. In this example,
+we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we
+need to change this line:
+
+```
+AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
+```
+
+to:
+
+```
+AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2);
+```
+
+### Change TOCO TFLite exporter
+
+The last step is to make TOCO populate the minimum version that's required to
+execute the op. In this example, it means:
+
+* Populate version=1 when dilation factors are all 1.
+* Populate version=2 otherwise.
+
+To do this, you need to override `GetVersion` function for the operator class in
+`lite/toco/tflite/operator.cc`.
+
+For ops with only one version, the `GetVersion` function is defined as:
+
+```
+int GetVersion(const Operator& op) const override { return 1; }
+```
+
+When supporting multiple versions, check the parameters and determine the
+version for the op, as shown in the following example:
+
+```
+int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const ConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+}
+```
+
+### Delegation Implementation
+
+TensorFlow Lite provides a delegation API which enables delegating ops to
+hardware backends. In Delegate's `Prepare` function, check if the version
+is supported for every node in Delegation code.
+
+```
+const int kMinVersion = 1;
+TfLiteNode* node;
+TfLiteRegistration;
+context->GetNodeAndRegistration(context, node_index, &node, &registration);
+
+if (registration->version > kMinVersion) {
+ // Reject the node if the version isn't supported.
+}
+```
+
+This is required even if the delegation only supports version 1 ops, so the
+delegation can detect incompatibility when getting a higher version op.
+
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index b2f6444e9e..49d00a66ba 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -42,6 +42,7 @@ counterparts:
*as long as the input tensor is 4D (1 batch + 2 spatial + 1 other) and the
crops attribute is not used*
* [tf.exp](https://www.tensorflow.org/api_docs/python/tf/exp)
+* [tf.fake_quant*](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul) - *as long
as the second argument is constant and transposition is not used*
* [tf.nn.avg_pool](https://www.tensorflow.org/api_docs/python/tf/nn/avg_pool)
@@ -95,11 +96,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph:
* [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide)
* [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args)
* [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars)
-* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater)
-* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal)
* [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity)
-* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less)
-* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal)
* [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum)
* [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum)
* [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply)
@@ -257,6 +254,19 @@ Options {
}
```
+**EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is
+ equal to the corresponding element of the second tensor.
+}
+```
+
**EXP**
```
@@ -420,6 +430,17 @@ Outputs {
}
```
+**LOG**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to log(input)
+}
+```
+
**LOG_SOFTMAX**
```
@@ -503,6 +524,19 @@ Options {
}
```
+**NOT_EQUAL**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: a tensor of type bool, true whenever an element of the first tensor is not
+ equal to the corresponding element of the second tensor.
+}
+```
+
**RELU**
```
@@ -551,6 +585,31 @@ Options {
}
```
+**RSQRT**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: result of computing element-wise reciprocal square root of the input tensor
+}
+```
+
+**SHAPE**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a 1D tensor representing the shape of the input tensor
+}
+Options {
+ out_type: the output type of the op (int32 or int64). Defaults to int32.
+}
+```
+
**SLICE**
```
@@ -637,6 +696,17 @@ Options {
}
```
+**SQRT**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: result of computing element-wise square root of the input tensor
+}
+```
+
**SQUEEZE**
```
@@ -709,6 +779,42 @@ Outputs {
}
```
+**POW**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: elementwise pow of the input tensors
+}
+```
+
+**ARG_MAX**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of maximum values.
+}
+```
+
+**ARG_MIN**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of minium values.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 313af5fb75..77268d7aeb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -46,6 +46,9 @@ class GraphInfo {
// Returns the indices of the output tensors.
virtual const std::vector<int>& outputs() const = 0;
+
+ // Returns the indices of the variable tensors.
+ virtual const std::vector<int>& variables() const = 0;
};
// Represents a subgraph of a TensorFlow Lite graph.
diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc
index ea38b43993..89a8f36b41 100644
--- a/tensorflow/contrib/lite/graph_info_test.cc
+++ b/tensorflow/contrib/lite/graph_info_test.cc
@@ -45,6 +45,7 @@ class SimpleTestGraph : public GraphInfo {
TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; }
const std::vector<int>& inputs() const override { return inputs_; }
const std::vector<int>& outputs() const override { return outputs_; }
+ const std::vector<int>& variables() const override { return variables_; }
void AddNode(const std::vector<int>& inputs,
const std::vector<int>& outputs) {
@@ -67,6 +68,7 @@ class SimpleTestGraph : public GraphInfo {
std::vector<TfLiteTensor> tensors_;
std::vector<int> inputs_;
std::vector<int> outputs_;
+ std::vector<int> variables_;
};
// Partition a graph to generate a list of subgraphs. This wraps the API call
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index ebb0aedc20..0641a08636 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -22,17 +22,21 @@ limitations under the License.
#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
-#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
+#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
+#endif
#include "tensorflow/contrib/lite/profiling/profiler.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
+#ifdef TFLITE_MCU
+class NNAPIDelegate {};
+#endif
namespace {
@@ -53,6 +57,19 @@ void SetForbiddenContextFunction(FunctionType* func) {
*func = reinterpret_cast<FunctionType>(ForbiddenContextFunction);
}
+// Returns true if at least one tensor in the given list is kTfLiteDynamic.
+template <typename TensorIntArray>
+bool HasDynamicTensorImpl(const TfLiteContext& context,
+ const TensorIntArray& int_array) {
+ for (int i : int_array) {
+ const TfLiteTensor& tensor = context.tensors[i];
+ if (tensor.allocation_type == kTfLiteDynamic) {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
// A trivial implementation of GraphInfo around the Interpreter.
@@ -82,6 +99,9 @@ class InterpreterInfo : public GraphInfo {
const std::vector<int>& outputs() const override {
return interpreter_->outputs();
}
+ const std::vector<int>& variables() const override {
+ return interpreter_->variables();
+ }
public:
Interpreter* interpreter_;
@@ -96,9 +116,9 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
- context_.eigen_context = nullptr;
- context_.gemm_context = nullptr;
context_.recommended_num_threads = -1;
+ context_.GetExternalContext = GetExternalContext;
+ context_.SetExternalContext = SetExternalContext;
// Invalid to call these these except from TfLiteDelegate
SetForbiddenContextFunction(&context_.GetNodeAndRegistration);
@@ -109,6 +129,11 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
tensors_.reserve(kTensorsReservedCapacity);
nodes_and_registration_.reserve(kTensorsReservedCapacity);
next_execution_plan_index_to_prepare_ = 0;
+
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ external_contexts_[i] = nullptr;
+ }
+
UseNNAPI(false);
}
@@ -266,6 +291,33 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
return kTfLiteOk;
}
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ TfLiteExternalContextType type) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ return external_contexts_[type];
+ }
+ return nullptr;
+}
+
+TfLiteExternalContext* Interpreter::GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type) {
+ return static_cast<Interpreter*>(context->impl_)->GetExternalContext(type);
+}
+
+void Interpreter::SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ if (type >= 0 && type < kTfLiteMaxExternalContexts) {
+ external_contexts_[type] = ctx;
+ }
+}
+
+void Interpreter::SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx) {
+ return static_cast<Interpreter*>(context->impl_)
+ ->SetExternalContext(type, ctx);
+}
+
// Gets an TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate prepare.
@@ -302,6 +354,13 @@ TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
return kTfLiteOk;
}
+TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
+ TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("variables", variables.data(),
+ variables.size()));
+ variables_ = std::move(variables);
+ return kTfLiteOk;
+}
+
TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
const int* indices, int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
@@ -334,6 +393,9 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
case kTfLiteFloat32:
*bytes = sizeof(float) * count;
break;
+ case kTfLiteInt16:
+ *bytes = sizeof(int16_t) * count;
+ break;
case kTfLiteInt32:
*bytes = sizeof(int32_t) * count;
break;
@@ -346,32 +408,65 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
case kTfLiteBool:
*bytes = sizeof(bool) * count;
break;
+ case kTfLiteComplex64:
+ *bytes = sizeof(std::complex<float>) * count;
+ break;
default:
- ReportError(
- &context_,
- "Only float32, int32, int64, uint8, bool supported currently.");
+ ReportError(&context_,
+ "Only float32, int16, int32, int64, uint8, bool, complex64 "
+ "supported currently.");
return kTfLiteError;
}
return kTfLiteOk;
}
TfLiteStatus Interpreter::AllocateTensors() {
- next_execution_plan_index_to_prepare_ = 0;
- if (memory_planner_) {
- TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
- }
-
if (!consistent_) {
ReportError(&context_, "AllocateTensors() called on inconsistent model.");
return kTfLiteError;
}
+ // Explicit (re)allocation is necessary if nodes have been changed or tensors
+ // have been resized. For inputs marked as dynamic, we can't short-circuit the
+ // allocation as the client may have done the resize manually.
+ if (state_ != kStateUninvokable && !HasDynamicTensorImpl(context_, inputs_)) {
+ return kTfLiteOk;
+ }
+
+ next_execution_plan_index_to_prepare_ = 0;
+ if (memory_planner_) {
+ TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
+ }
+
TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
- if (state_ == kStateUninvokable) {
- state_ = kStateInvokable;
+
+ state_ = kStateInvokable;
+
+ // Reset the variable tensors to zero after (re)allocating the tensors.
+ // Developers shouldn't rely on the side effect of this function to reset
+ // variable tesnsors. They should call `ResetVariableTensorsToZero` directly
+ // instead.
+ ResetVariableTensorsToZero();
+
+ return kTfLiteOk;
+}
+
+// TODO(ycling): Consider to provide other functions to initialize variable
+// tensors to non-zero values.
+TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
+ for (auto& tensor : tensors_) {
+ if (!tensor.is_variable) {
+ continue;
+ }
+
+ // Variable tensors have to be `kTfLiteArenaRwPersistent`, and must be
+ // allocated after the initial `PrepareOpsAndTensors()` is called.
+ TF_LITE_ENSURE_EQ(&context_, tensor.allocation_type,
+ kTfLiteArenaRwPersistent);
+ TF_LITE_ENSURE(&context_, tensor.data.raw != nullptr);
+
+ memset(tensor.data.raw, 0, tensor.bytes);
}
- TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
- state_ == kStateInvokableAndImmutable);
return kTfLiteOk;
}
@@ -445,26 +540,26 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
"ResizeInputTensor is disallowed when graph is immutable.");
return kTfLiteError;
}
- state_ = kStateUninvokable;
// TODO(aselle): All bounds checks can be implemented as one-sided bounds
// checks by casting to unsigned for efficiency. Profile before doing this.
TF_LITE_ENSURE(&context_,
tensor_index < context_.tensors_size && tensor_index >= 0);
- TfLiteIntArray* dims_lite = ConvertVectorToTfLiteIntArray(dims);
- return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
+ TfLiteTensor* tensor = &context_.tensors[tensor_index];
+
+ // Short-circuit the state change if the dimensions don't change, avoiding
+ // unnecessary (re)allocations.
+ if (EqualArrayAndTfLiteIntArray(tensor->dims, dims.size(), dims.data())) {
+ return kTfLiteOk;
+ }
+
+ state_ = kStateUninvokable;
+ return ResizeTensorImpl(tensor, ConvertVectorToTfLiteIntArray(dims));
}
-// Returns true if at least one tensor in the given list is kTfLiteDynamic.
bool HasDynamicTensor(const TfLiteContext& context,
- const TfLiteIntArray* tensors) {
- for (int i = 0; i < tensors->size; ++i) {
- const TfLiteTensor& tensor = context.tensors[tensors->data[i]];
- if (tensor.allocation_type == kTfLiteDynamic) {
- return true;
- }
- }
- return false;
+ const TfLiteIntArray* int_array) {
+ return HasDynamicTensorImpl(context, TfLiteIntArrayView{int_array});
}
TfLiteStatus Interpreter::PrepareOpsStartingAt(
@@ -477,6 +572,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
nodes_and_registration_[node_index].second;
EnsureTensorsVectorCapacity();
if (OpPrepare(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to prepare.\n",
+ node_index);
return kTfLiteError;
}
@@ -495,7 +592,8 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt(
TfLiteStatus Interpreter::PrepareOpsAndTensors() {
if (!memory_planner_) {
memory_planner_.reset(new ArenaPlanner(
- &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
+ &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this)),
+ /*preserve_inputs=*/true));
memory_planner_->PlanAllocations();
}
@@ -521,6 +619,7 @@ TfLiteStatus Interpreter::Invoke() {
}
TfLiteStatus status = kTfLiteOk;
+#ifndef TFLITE_MCU
if (nnapi_delegate_) {
if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
@@ -534,6 +633,7 @@ TfLiteStatus Interpreter::Invoke() {
return kTfLiteError;
}
}
+#endif
// Invocations are always done in node order.
// Note that calling Invoke repeatedly will cause the original memory plan to
@@ -572,9 +672,19 @@ TfLiteStatus Interpreter::Invoke() {
}
EnsureTensorsVectorCapacity();
+ tensor_resized_since_op_invoke_ = false;
if (OpInvoke(registration, &node) == kTfLiteError) {
+ context_.ReportError(&context_, "Node %d failed to invoke.\n",
+ node_index);
status = kTfLiteError;
}
+
+ // Force execution prep for downstream ops if the latest op triggered the
+ // resize of a dynamic tensor.
+ if (tensor_resized_since_op_invoke_ &&
+ HasDynamicTensor(context_, node.outputs)) {
+ next_execution_plan_index_to_prepare_ = execution_plan_index + 1;
+ }
}
if (!allow_buffer_handle_output_) {
@@ -687,7 +797,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
state_ = kStateUninvokable;
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization, const_cast<char*>(buffer), bytes,
- kTfLiteMmapRo, allocation, &tensor);
+ kTfLiteMmapRo, allocation, false, &tensor);
}
return kTfLiteOk;
}
@@ -698,7 +808,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
- const int* dims, TfLiteQuantizationParams quantization) {
+ const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
&context_,
@@ -716,11 +826,23 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite(
TF_LITE_ENSURE_OK(&context_,
BytesRequired(type, dims, rank, &required_bytes));
}
+
+ TfLiteAllocationType allocation_type = kTfLiteArenaRw;
+ if (type == kTfLiteString) {
+ if (is_variable) {
+ // We don't have a real use case for string variable tensor.
+ ReportError(&context_, "String variable tensor isn't supported.");
+ return kTfLiteError;
+ }
+ allocation_type = kTfLiteDynamic;
+ } else if (is_variable) {
+ allocation_type = kTfLiteArenaRwPersistent;
+ }
+
TfLiteTensorReset(type, name, ConvertArrayToTfLiteIntArray(rank, dims),
quantization,
- /*buffer=*/nullptr, required_bytes,
- type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
- nullptr, &context_.tensors[tensor_index]);
+ /*buffer=*/nullptr, required_bytes, allocation_type,
+ nullptr, is_variable, &context_.tensors[tensor_index]);
return kTfLiteOk;
}
@@ -736,7 +858,10 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
TfLiteIntArray* new_size) {
// Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
if (tensor->allocation_type == kTfLiteArenaRw ||
- tensor->allocation_type == kTfLiteDynamic) {
+ tensor->allocation_type == kTfLiteDynamic ||
+ tensor->allocation_type == kTfLiteArenaRwPersistent) {
+ tensor_resized_since_op_invoke_ |=
+ TfLiteIntArrayEqual(tensor->dims, new_size) == 0;
if (tensor->type != kTfLiteString) {
size_t bytesRequired;
TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
@@ -767,6 +892,7 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
}
void Interpreter::UseNNAPI(bool enable) {
+#ifndef TFLITE_MCU
// TODO(aselle): This is a workaround for finding if NNAPI exists.
// We also need to make sure getLibraryHandle() is renamed to be NNAPI
// prefixed.
@@ -776,15 +902,18 @@ void Interpreter::UseNNAPI(bool enable) {
} else if (!nnapi_delegate_) {
nnapi_delegate_.reset(new NNAPIDelegate);
}
+#endif
}
void Interpreter::SetNumThreads(int num_threads) {
context_.recommended_num_threads = num_threads;
- // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
- // be required in order to compile the framework.
- gemm_support::SetNumThreads(&context_, num_threads);
- eigen_support::SetNumThreads(&context_, num_threads);
+ for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
+ auto* c = external_contexts_[i];
+ if (c && c->Refresh) {
+ c->Refresh(&context_);
+ }
+ }
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
@@ -828,9 +957,10 @@ TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
TF_LITE_ENSURE_OK(&context_, status);
if (!allow_dynamic_tensors) {
+ // Reset the state to force tensor/op reallocation.
+ state_ = kStateUninvokable;
TF_LITE_ENSURE_OK(&context_, AllocateTensors());
- TF_LITE_ENSURE(&context_, state_ == kStateInvokable ||
- state_ == kStateInvokableAndImmutable);
+ TF_LITE_ENSURE_EQ(&context_, state_, kStateInvokable);
// After using a delegate which doesn't support dynamic tensors, make the
// entire graph immutable.
state_ = kStateInvokableAndImmutable;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7315d83606..b69c50fbfc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -17,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_
+#include <complex>
#include <cstdio>
#include <cstdlib>
#include <vector>
@@ -39,6 +40,10 @@ constexpr TfLiteType typeToTfLiteType<int>() {
return kTfLiteInt32;
}
template <>
+constexpr TfLiteType typeToTfLiteType<int16_t>() {
+ return kTfLiteInt16;
+}
+template <>
constexpr TfLiteType typeToTfLiteType<int64_t>() {
return kTfLiteInt64;
}
@@ -54,6 +59,10 @@ template <>
constexpr TfLiteType typeToTfLiteType<bool>() {
return kTfLiteBool;
}
+template <>
+constexpr TfLiteType typeToTfLiteType<std::complex<float>>() {
+ return kTfLiteComplex64;
+}
// Forward declare since NNAPIDelegate uses Interpreter.
class NNAPIDelegate;
@@ -118,6 +127,11 @@ class Interpreter {
// interpreter.
TfLiteStatus SetOutputs(std::vector<int> outputs);
+ // Provide a list of tensor indexes that are variable tensors.
+ // Each index is bound check and this modifies the consistent_ flag of the
+ // interpreter.
+ TfLiteStatus SetVariables(std::vector<int> variables);
+
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
@@ -160,13 +174,15 @@ class Interpreter {
// to Interpreter.
inline TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name,
- const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
+ const std::vector<int>& dims, TfLiteQuantizationParams quantization,
+ bool is_variable = false) {
return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
- dims.data(), quantization);
+ dims.data(), quantization, is_variable);
}
TfLiteStatus SetTensorParametersReadWrite(
int tensor_index, TfLiteType type, const char* name, const size_t rank,
- const int* dims, TfLiteQuantizationParams quantization);
+ const int* dims, TfLiteQuantizationParams quantization,
+ bool is_variable = false);
// Functions to access tensor data
@@ -182,6 +198,9 @@ class Interpreter {
// Read only access to list of outputs.
const std::vector<int>& outputs() const { return outputs_; }
+ // Read only access to list of variable tensors.
+ const std::vector<int>& variables() const { return variables_; }
+
// Return the name of a given output. The given index must be between 0 and
// outputs().size().
const char* GetOutputName(int index) const {
@@ -379,7 +398,20 @@ class Interpreter {
allow_buffer_handle_output_ = allow_buffer_handle_output;
}
+ // Reset all variable tensors to zero.
+ // WARNING: This is an experimental API and subject to change.
+ TfLiteStatus ResetVariableTensorsToZero();
+
+ // Retrieve an operator's description of its work, for profiling purposes.
+ const char* OpProfilingString(const TfLiteRegistration& op_reg,
+ const TfLiteNode* node) const {
+ if (op_reg.profiling_string == nullptr) return nullptr;
+ return op_reg.profiling_string(&context_, node);
+ }
+
private:
+ friend class InterpreterTest;
+
// Give 'op_reg' a chance to initialize itself using the contents of
// 'buffer'.
void* OpInit(const TfLiteRegistration& op_reg, const char* buffer,
@@ -492,6 +524,18 @@ class Interpreter {
static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context,
TfLiteIntArray** execution_plan);
+ // Retrieve an existing external context by type.
+ TfLiteExternalContext* GetExternalContext(TfLiteExternalContextType type);
+ static TfLiteExternalContext* GetExternalContext(
+ struct TfLiteContext* context, TfLiteExternalContextType type);
+
+ // Set the value of an external context.
+ void SetExternalContext(TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+ static void SetExternalContext(struct TfLiteContext* context,
+ TfLiteExternalContextType type,
+ TfLiteExternalContext* ctx);
+
// Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra
// capacity. Calling this function may invalidate existing pointers to
// tensors. After calling this function, adding `kTensorsCapacityHeadroom`
@@ -541,6 +585,9 @@ class Interpreter {
// interpreter.
std::vector<int> outputs_;
+ // Array of indices representing the tensors that are variable tensors.
+ std::vector<int> variables_;
+
// The error reporter delegate that tflite will forward queries errors to.
ErrorReporter* error_reporter_;
@@ -572,8 +619,16 @@ class Interpreter {
bool allow_buffer_handle_output_ = false;
+ // Tracking bit for whether a tensor was resized in the course of an op
+ // invocation. This is a useful hint to ensure that dynamic tensor outputs
+ // trigger downstream reallocation after op invocation.
+ bool tensor_resized_since_op_invoke_ = false;
+
// Profiler for this interpreter instance.
profiling::Profiler* profiler_;
+
+ // List of active external contexts.
+ TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 453c1ada1c..10119903fe 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -23,6 +23,21 @@ limitations under the License.
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
+
+// InterpreterTest is a friend of Interpreter, so it can access context_.
+class InterpreterTest : public ::testing::Test {
+ protected:
+ TfLiteContext* GetInterpreterContext() { return &interpreter_.context_; }
+
+ Interpreter interpreter_;
+};
+
+namespace ops {
+namespace builtin {
+TfLiteRegistration* Register_PADV2();
+TfLiteRegistration* Register_NEG();
+} // namespace builtin
+} // namespace ops
namespace {
// Make an interpreter that has no tensors and no nodes
@@ -42,6 +57,22 @@ TEST(BasicInterpreter, InvokeInvalidModel) {
ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
}
+TEST(BasicInterpreter, TestAllocateTensorsResetVariableTensors) {
+ Interpreter interpreter;
+ int tensor_index;
+ ASSERT_EQ(interpreter.AddTensors(1, &tensor_index), kTfLiteOk);
+ constexpr int kTensorSize = 16;
+ interpreter.SetTensorParametersReadWrite(tensor_index, kTfLiteFloat32, "",
+ {kTensorSize}, {}, true);
+ interpreter.SetVariables({tensor_index});
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ TfLiteTensor* tensor = interpreter.tensor(tensor_index);
+ // Ensure that variable tensors are reset to zero.
+ for (int i = 0; i < kTensorSize; ++i) {
+ ASSERT_EQ(tensor->data.f[i], 0.0f);
+ }
+}
+
// Test size accessor functions.
TEST(BasicInterpreter, TestSizeFunctions) {
Interpreter interpreter;
@@ -106,10 +137,9 @@ TEST(BasicInterpreter, CheckAllocate) {
TfLiteType type;
size_t size;
} cases[] = {
- {kTfLiteFloat32, sizeof(float)},
- {kTfLiteInt32, sizeof(int32_t)},
- {kTfLiteUInt8, sizeof(uint8_t)},
- {kTfLiteInt64, sizeof(int64_t)},
+ {kTfLiteFloat32, sizeof(float)}, {kTfLiteInt32, sizeof(int32_t)},
+ {kTfLiteUInt8, sizeof(uint8_t)}, {kTfLiteInt64, sizeof(int64_t)},
+ {kTfLiteInt16, sizeof(int16_t)},
};
for (auto test : cases) {
@@ -134,6 +164,7 @@ TEST(BasicInterpreter, CheckResize) {
const int32_t int32s[] = {-3, -4};
const uint8_t uint8s[] = {3, 4};
const int64_t int64s[] = {6, -7};
+ const int16_t int16s[] = {8, -9};
struct {
TfLiteType type;
@@ -144,6 +175,7 @@ TEST(BasicInterpreter, CheckResize) {
{kTfLiteInt32, sizeof(int32_t), reinterpret_cast<const char*>(int32s)},
{kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast<const char*>(uint8s)},
{kTfLiteInt64, sizeof(int64_t), reinterpret_cast<const char*>(int64s)},
+ {kTfLiteInt16, sizeof(int16_t), reinterpret_cast<const char*>(int16s)},
};
for (auto test : cases) {
@@ -179,10 +211,8 @@ TEST(BasicInterpreter, CheckAlignment) {
struct {
TfLiteType type;
} cases[] = {
- {kTfLiteFloat32},
- {kTfLiteInt32},
- {kTfLiteUInt8},
- {kTfLiteInt64},
+ {kTfLiteFloat32}, {kTfLiteInt32}, {kTfLiteUInt8},
+ {kTfLiteInt64}, {kTfLiteInt16},
};
for (auto test : cases) {
@@ -211,7 +241,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) {
TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
std::vector<int> sizes{2048, 4096, 1023, 2047, 1021,
- 2047, 1023, 2046, 1021, 2048};
+ 2047, 1023, 2046, 0, 2048};
for (int i = 0; i < sizes.size(); ++i) {
interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]},
quant);
@@ -226,31 +256,16 @@ TEST(BasicInterpreter, CheckArenaAllocation) {
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
- ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw);
- ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw);
-
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw);
ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw);
+ ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(2)->data.raw);
ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw);
-
- ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(4)->data.raw);
ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw);
- ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
+ ASSERT_LT(interpreter.tensor(5)->data.raw, interpreter.tensor(7)->data.raw);
+ ASSERT_EQ(interpreter.tensor(6)->data.raw, interpreter.tensor(2)->data.raw);
+ // #7 is the one with the largest pointer.
+ ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr);
+ ASSERT_EQ(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw);
}
TEST(BasicInterpreter, BufferAccess) {
@@ -286,6 +301,57 @@ TEST(BasicInterpreter, NoOpInterpreter) {
ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
}
+TEST(BasicInterpreter, RedundantAllocateTensors) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ const auto data_raw = interpreter.tensor(0)->data.raw;
+ ASSERT_NE(data_raw, nullptr);
+
+ // A redundant allocation request should have no impact.
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.tensor(0)->data.raw, data_raw);
+}
+
+TEST(BasicInterpreter, RedundantAllocateTensorsWithDynamicInputs) {
+ Interpreter interpreter;
+ TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
+ ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
+ interpreter.SetInputs({0});
+ interpreter.SetOutputs({1});
+ interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, &reg);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 1, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+
+ // Configure the input tensor as dynamic.
+ interpreter.tensor(0)->data.raw = nullptr;
+ interpreter.tensor(0)->allocation_type = kTfLiteDynamic;
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}),
+ kTfLiteOk);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+
+ // Reset the output tensor's buffer.
+ interpreter.tensor(1)->data.raw = nullptr;
+
+ // A redundant allocation request should be honored, as the input tensor
+ // was marked dynamic.
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr);
+}
+
TEST(BasicInterpreter, ResizingTensors) {
Interpreter interpreter;
ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
@@ -314,6 +380,18 @@ TEST(BasicInterpreter, ResizingTensors) {
EXPECT_EQ(tensor->bytes, 8 * sizeof(float));
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 1 * sizeof(float));
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 0);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 0);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
// TODO(ahentz): We shouldn't have to force reallocation, but
// ResizeInputTensor doesn't realloc dynamic tensors. Also note that
// TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op.
@@ -331,6 +409,37 @@ TEST(BasicInterpreter, ResizingTensors) {
tensor->data.f[15] = 0.123f;
}
+TEST(BasicInterpreter, NoopResizingTensors) {
+ Interpreter interpreter;
+ ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk);
+ ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk);
+
+ ASSERT_EQ(interpreter.SetTensorParametersReadWrite(
+ 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()),
+ kTfLiteOk);
+
+ int t = interpreter.inputs()[0];
+ TfLiteTensor* tensor = interpreter.tensor(t);
+
+ ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ tensor->data.f[5] = 0.123f;
+
+ // Resizing to the same size should not trigger re-allocation.
+ ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
+ ASSERT_NE(tensor->data.raw, nullptr);
+ ASSERT_EQ(tensor->data.f[5], 0.123f);
+
+ // Explicitly allocating should be a no-op, as no resize was performed.
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+ EXPECT_EQ(tensor->bytes, 6 * sizeof(float));
+ ASSERT_NE(tensor->data.raw, nullptr);
+ ASSERT_EQ(tensor->data.f[5], 0.123f);
+}
+
TEST(BasicInterpreter, OneOpInterpreter) {
Interpreter interpreter;
ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk);
@@ -603,6 +712,59 @@ TEST(BasicInterpreter, TestUnsupportedDelegateFunctions) {
EXPECT_EQ(interpreter.AllocateTensors(), kTfLiteError);
}
+TEST(BasicInterpreter, DynamicTensorsResizeDescendants) {
+ // Assemble a graph with a node that has dynamically sized output (via the
+ // pad op), followed by a node with a standard element-wise op (negate).
+ Interpreter interpreter;
+ interpreter.AddTensors(4);
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({3});
+ TfLiteQuantizationParams quant;
+ interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {2, 2, 1, 1},
+ quant);
+ interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "", {4, 2}, quant);
+ interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {}, quant);
+ interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {}, quant);
+
+ TfLiteRegistration* pad_op = tflite::ops::builtin::Register_PADV2();
+ TfLiteRegistration* neg_op = tflite::ops::builtin::Register_NEG();
+ interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, pad_op);
+ interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, neg_op);
+ ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
+
+ // Configure [[2,2],[4,4]] padding and execute the graph.
+ interpreter.typed_tensor<int>(1)[0] = 2;
+ interpreter.typed_tensor<int>(1)[1] = 2;
+ interpreter.typed_tensor<int>(1)[2] = 2;
+ interpreter.typed_tensor<int>(1)[3] = 2;
+ interpreter.typed_tensor<int>(1)[4] = 0;
+ interpreter.typed_tensor<int>(1)[5] = 0;
+ interpreter.typed_tensor<int>(1)[6] = 0;
+ interpreter.typed_tensor<int>(1)[7] = 0;
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+
+ // Both the output and intermediate tensor sizes should reflect the output
+ // from the dynamic pad operation.
+ ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 6 * 6);
+ ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 6 * 6);
+
+ // Now configure [[4,4],[6,6]] padding and execute the graph.
+ interpreter.typed_tensor<int>(1)[0] = 4;
+ interpreter.typed_tensor<int>(1)[1] = 4;
+ interpreter.typed_tensor<int>(1)[2] = 6;
+ interpreter.typed_tensor<int>(1)[3] = 6;
+ interpreter.typed_tensor<int>(1)[4] = 0;
+ interpreter.typed_tensor<int>(1)[5] = 0;
+ interpreter.typed_tensor<int>(1)[6] = 0;
+ interpreter.typed_tensor<int>(1)[7] = 0;
+ ASSERT_EQ(interpreter.Invoke(), kTfLiteOk);
+
+ // Again, the output and intermediate tensor sizes should reflect the *new*
+ // resize from the latest pad operation.
+ ASSERT_EQ(interpreter.tensor(2)->bytes, sizeof(float) * 10 * 14);
+ ASSERT_EQ(interpreter.tensor(3)->bytes, sizeof(float) * 10 * 14);
+}
+
TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) {
Interpreter interpreter;
ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity),
@@ -643,6 +805,47 @@ TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) {
ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk);
}
+struct TestExternalContext : public TfLiteExternalContext {
+ static const TfLiteExternalContextType kType = kTfLiteGemmLowpContext;
+
+ static TestExternalContext* Get(TfLiteContext* context) {
+ return reinterpret_cast<TestExternalContext*>(
+ context->GetExternalContext(context, kType));
+ }
+
+ static void Set(TfLiteContext* context, TestExternalContext* value) {
+ context->SetExternalContext(context, kType, value);
+ }
+
+ int num_refreshes = 0;
+};
+
+TEST_F(InterpreterTest, GetSetResetExternalContexts) {
+ auto* context = GetInterpreterContext();
+
+ TestExternalContext external_context;
+ external_context.Refresh = [](TfLiteContext* context) {
+ auto* ptr = TestExternalContext::Get(context);
+ if (ptr != nullptr) {
+ ++ptr->num_refreshes;
+ }
+ return kTfLiteOk;
+ };
+
+ EXPECT_EQ(TestExternalContext::Get(context), nullptr);
+ interpreter_.SetNumThreads(4);
+
+ TestExternalContext::Set(context, &external_context);
+ EXPECT_EQ(TestExternalContext::Get(context), &external_context);
+ interpreter_.SetNumThreads(4);
+ interpreter_.SetNumThreads(5);
+ EXPECT_EQ(external_context.num_refreshes, 2);
+
+ TestExternalContext::Set(context, nullptr);
+ EXPECT_EQ(TestExternalContext::Get(context), nullptr);
+ interpreter_.SetNumThreads(4);
+}
+
// Test fixture that allows playing with execution plans. It creates a two
// node graph that can be executed in either [0,1] order or [1,0] order.
// The CopyOp records when it is invoked in the class member run_order_
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index 4450bc9085..db837cf29e 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -1,5 +1,7 @@
"""Generate zipped aar file including different variants of .so in jni folder."""
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
def aar_with_jni(name, android_library):
# Generate dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app_for_so target below)
@@ -19,7 +21,7 @@ EOF
# Generate dummy apk including .so files and later we extract out
# .so files and throw away the apk.
- native.android_binary(
+ android_binary(
name = name + "_dummy_app_for_so",
manifest = name + "_generated_AndroidManifest.xml",
custom_package = "dummy.package.for.so",
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 2e818f728e..e3cea19e16 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -1,5 +1,14 @@
# TF Lite Android App
+## Building in Android Studio with TensorFlow Lite AAR from JCenter.
+The build.gradle is configured to use TensorFlow Lite's nightly build.
+
+If you see a build error related to compatibility with Tensorflow Lite's Java API (example: method X is
+undefined for type Interpreter), there has likely been a backwards compatible
+change to the API. You will need to pull new app code that's compatible with the
+nightly build and may need to first wait a few days for our external and internal
+code to merge.
+
## Building from Source with Bazel
1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index b76eaad8bb..49868c5a75 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -5,11 +5,12 @@ android {
buildToolsVersion "26.0.1"
defaultConfig {
applicationId "android.example.com.tflitecamerademo"
- minSdkVersion 15
+ // Required by Camera2 API.
+ minSdkVersion 21
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,7 +44,7 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'com.android.support:appcompat-v7:25.2.0'
@@ -52,7 +53,43 @@ dependencies {
compile 'com.android.support:support-annotations:25.3.1'
compile 'com.android.support:support-v13:25.2.0'
- compile 'org.tensorflow:tensorflow-lite:+'
+ compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
testCompile 'junit:junit:4.12'
}
+
+def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
+def localCache = "build/intermediates/mobilenet_v1_224_android_quant_2017_11_08.zip"
+def targetFolder = "src/main/assets"
+
+task downloadModel(type: DownloadUrlTask) {
+ doFirst {
+ println "Downloading ${modelDownloadUrl}"
+ }
+ sourceUrl = "${modelDownloadUrl}"
+ target = file("${localCache}")
+}
+
+task unzipModel(type: Copy, dependsOn: 'downloadModel') {
+ doFirst {
+ println "Unzipping ${localCache}"
+ }
+ from zipTree("${localCache}")
+ into "${targetFolder}"
+}
+
+// Ensure the model file is downloaded and extracted before every build
+preBuild.dependsOn unzipModel
+
+class DownloadUrlTask extends DefaultTask {
+ @Input
+ String sourceUrl
+
+ @OutputFile
+ File target
+
+ @TaskAction
+ void download() {
+ ant.get(src: sourceUrl, dest: target)
+ }
+}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index d6fbef9cc9..220d6c2159 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -1,3 +1,5 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 362d93636f..f232b00045 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -1,6 +1,8 @@
# Description:
# OVIC Benchmarker Java API.
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index 83974f4b33..a8d751ade2 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -1,3 +1,5 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
# Sample app for OVIC benchmarking.
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
index c5d19bad89..3f32d62e5c 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -9,7 +9,7 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
+ testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,7 +43,7 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', {
+ androidTestCompile('com.androidx.test.espresso:espresso-core:2.2.2', {
exclude group: 'com.android.support', module: 'support-annotations'
})
compile 'com.android.support:appcompat-v7:25.2.0'
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 56f3e7604a..1587c3c56f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -127,12 +127,8 @@ public final class OvicClassifierTest {
try {
testResult = classifier.classifyByteBuffer(testImage);
fail();
- } catch (RuntimeException e) {
- assertThat(e)
- .hasMessageThat()
- .contains(
- "Failed to get input dimensions. 0-th input should have 49152 bytes, "
- + "but found 150528 bytes.");
+ } catch (IllegalArgumentException e) {
+ // Success.
}
}
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
index 75334cd96e..94a1ec65d6 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java
@@ -27,10 +27,7 @@ enum DataType {
UINT8(3),
/** 64-bit signed integer. */
- INT64(4),
-
- /** A {@link ByteBuffer}. */
- BYTEBUFFER(999);
+ INT64(4);
private final int value;
@@ -69,8 +66,6 @@ enum DataType {
return 1;
case INT64:
return 8;
- case BYTEBUFFER:
- return 1;
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
@@ -87,8 +82,6 @@ enum DataType {
return "byte";
case INT64:
return "long";
- case BYTEBUFFER:
- return "ByteBuffer";
}
throw new IllegalArgumentException(
"DataType error: DataType " + this + " is not supported yet");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index 644ce4cb3e..7002f82677 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -17,6 +17,7 @@ package org.tensorflow.lite;
import java.io.File;
import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
@@ -104,6 +105,27 @@ public final class Interpreter implements AutoCloseable {
}
/**
+ * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
+ *
+ * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
+ * Interpreter}.
+ */
+ public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
+ wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ }
+
+ /**
+ * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
+ * specifies the number of threads used for inference.
+ *
+ * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
+ * Interpreter}.
+ */
+ public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
+ wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ }
+
+ /**
* Runs model inference if the model takes only one input, and provides only one output.
*
* <p>Warning: The API runs much faster if {@link ByteBuffer} is used as input data type. Please
@@ -113,7 +135,8 @@ public final class Interpreter implements AutoCloseable {
* including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large
* input data. When {@link ByteBuffer} is used, its content should remain unchanged until
* model inference is done.
- * @param output a multidimensional array of output data.
+ * @param output a multidimensional array of output data, or a {@link ByteBuffer} of primitive
+ * types including int, float, long, and byte.
*/
public void run(@NonNull Object input, @NonNull Object output) {
Object[] inputs = {input};
@@ -133,28 +156,16 @@ public final class Interpreter implements AutoCloseable {
* primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred
* way to pass large input data. When {@link ByteBuffer} is used, its content should remain
* unchanged until model inference is done.
- * @param outputs a map mapping output indices to multidimensional arrays of output data. It only
- * needs to keep entries for the outputs to be used.
+ * @param outputs a map mapping output indices to multidimensional arrays of output data or {@link
+ * ByteBuffer}s of primitive types including int, float, long, and byte. It only needs to keep
+ * entries for the outputs to be used.
*/
public void runForMultipleInputsOutputs(
@NonNull Object[] inputs, @NonNull Map<Integer, Object> outputs) {
if (wrapper == null) {
throw new IllegalStateException("Internal error: The Interpreter has already been closed.");
}
- Tensor[] tensors = wrapper.run(inputs);
- if (outputs == null || tensors == null || outputs.size() > tensors.length) {
- throw new IllegalArgumentException("Output error: Outputs do not match with model outputs.");
- }
- final int size = tensors.length;
- for (Integer idx : outputs.keySet()) {
- if (idx == null || idx < 0 || idx >= size) {
- throw new IllegalArgumentException(
- String.format(
- "Output error: Invalid index of output %d (should be in range [0, %d))",
- idx, size));
- }
- tensors[idx].copyTo(outputs.get(idx));
- }
+ wrapper.run(inputs, outputs);
}
/**
@@ -227,8 +238,19 @@ public final class Interpreter implements AutoCloseable {
/** Release resources associated with the {@code Interpreter}. */
@Override
public void close() {
- wrapper.close();
- wrapper = null;
+ if (wrapper != null) {
+ wrapper.close();
+ wrapper = null;
+ }
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ close();
+ } finally {
+ super.finalize();
+ }
}
NativeInterpreterWrapper wrapper;
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 2ae6c516b0..767a220f8c 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -15,10 +15,10 @@ limitations under the License.
package org.tensorflow.lite;
-import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -40,6 +40,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModel(modelPath, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/**
@@ -72,6 +74,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
isMemoryAllocated = true;
+ inputTensors = new Tensor[getInputCount(interpreterHandle)];
+ outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@@ -85,75 +89,63 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
+ Arrays.fill(inputTensors, null);
+ Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
- Tensor[] run(Object[] inputs) {
+ void run(Object[] inputs, Map<Integer, Object> outputs) {
+ inferenceDurationNanoseconds = -1;
if (inputs == null || inputs.length == 0) {
throw new IllegalArgumentException("Input error: Inputs should not be null or empty.");
}
- int[] dataTypes = new int[inputs.length];
- Object[] sizes = new Object[inputs.length];
- int[] numsOfBytes = new int[inputs.length];
+ if (outputs == null || outputs.isEmpty()) {
+ throw new IllegalArgumentException("Input error: Outputs should not be null or empty.");
+ }
+
+ // TODO(b/80431971): Remove implicit resize after deprecating multi-dimensional array inputs.
+ // Rather than forcing an immediate resize + allocation if an input's shape differs, we first
+ // flush all resizes, avoiding redundant allocations.
for (int i = 0; i < inputs.length; ++i) {
- DataType dataType = dataTypeOf(inputs[i]);
- dataTypes[i] = dataType.getNumber();
- if (dataType == DataType.BYTEBUFFER) {
- ByteBuffer buffer = (ByteBuffer) inputs[i];
- if (buffer == null || !buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()) {
- throw new IllegalArgumentException(
- "Input error: ByteBuffer should be a direct ByteBuffer that uses "
- + "ByteOrder.nativeOrder().");
- }
- numsOfBytes[i] = buffer.limit();
- sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]);
- } else if (isNonEmptyArray(inputs[i])) {
- int[] dims = shapeOf(inputs[i]);
- sizes[i] = dims;
- numsOfBytes[i] = dataType.elemByteSize() * numElements(dims);
- } else {
- throw new IllegalArgumentException(
- String.format(
- "Input error: %d-th element of the %d inputs is not an array or a ByteBuffer.",
- i, inputs.length));
+ Tensor tensor = getInputTensor(i);
+ int[] newShape = tensor.getInputShapeIfDifferent(inputs[i]);
+ if (newShape != null) {
+ resizeInput(i, newShape);
}
}
- inferenceDurationNanoseconds = -1;
- long[] outputsHandles =
- run(
- interpreterHandle,
- errorHandle,
- sizes,
- dataTypes,
- numsOfBytes,
- inputs,
- this,
- isMemoryAllocated);
- if (outputsHandles == null || outputsHandles.length == 0) {
- throw new IllegalStateException("Internal error: Interpreter has no outputs.");
+
+ if (!isMemoryAllocated) {
+ allocateTensors(interpreterHandle, errorHandle);
+ isMemoryAllocated = true;
+ // Allocation can trigger dynamic resizing of output tensors, so clear the
+ // output tensor cache.
+ Arrays.fill(outputTensors, null);
}
- isMemoryAllocated = true;
- Tensor[] outputs = new Tensor[outputsHandles.length];
- for (int i = 0; i < outputsHandles.length; ++i) {
- outputs[i] = Tensor.fromHandle(outputsHandles[i]);
+
+ for (int i = 0; i < inputs.length; ++i) {
+ getInputTensor(i).setTo(inputs[i]);
+ }
+
+ long inferenceStartNanos = System.nanoTime();
+ run(interpreterHandle, errorHandle);
+ long inferenceDurationNanoseconds = System.nanoTime() - inferenceStartNanos;
+
+ for (Map.Entry<Integer, Object> output : outputs.entrySet()) {
+ getOutputTensor(output.getKey()).copyTo(output.getValue());
}
- return outputs;
+
+ // Only set if the entire operation succeeds.
+ this.inferenceDurationNanoseconds = inferenceDurationNanoseconds;
}
- private static native long[] run(
- long interpreterHandle,
- long errorHandle,
- Object[] sizes,
- int[] dtypes,
- int[] numsOfBytes,
- Object[] values,
- NativeInterpreterWrapper wrapper,
- boolean memoryAllocated);
+ private static native boolean run(long interpreterHandle, long errorHandle);
/** Resizes dimensions of a specific input. */
void resizeInput(int idx, int[] dims) {
if (resizeInput(interpreterHandle, errorHandle, idx, dims)) {
isMemoryAllocated = false;
+ // Resizing will invalidate the Tensor's shape, so invalidate the Tensor handle.
+ inputTensors[idx] = null;
}
}
@@ -212,78 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- static int numElements(int[] shape) {
- if (shape == null) {
- return 0;
- }
- int n = 1;
- for (int i = 0; i < shape.length; i++) {
- n *= shape[i];
- }
- return n;
- }
-
- static boolean isNonEmptyArray(Object o) {
- return (o != null && o.getClass().isArray() && Array.getLength(o) != 0);
- }
-
- /** Returns the type of the data. */
- static DataType dataTypeOf(Object o) {
- if (o != null) {
- Class<?> c = o.getClass();
- while (c.isArray()) {
- c = c.getComponentType();
- }
- if (float.class.equals(c)) {
- return DataType.FLOAT32;
- } else if (int.class.equals(c)) {
- return DataType.INT32;
- } else if (byte.class.equals(c)) {
- return DataType.UINT8;
- } else if (long.class.equals(c)) {
- return DataType.INT64;
- } else if (ByteBuffer.class.isInstance(o)) {
- return DataType.BYTEBUFFER;
- }
- }
- throw new IllegalArgumentException(
- "DataType error: cannot resolve DataType of " + o.getClass().getName());
- }
-
- /** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
- int[] dimensions = new int[size];
- fillShape(o, 0, dimensions);
- return dimensions;
- }
-
- static int numDimensions(Object o) {
- if (o == null || !o.getClass().isArray()) {
- return 0;
- }
- if (Array.getLength(o) == 0) {
- throw new IllegalArgumentException("Array lengths cannot be 0.");
- }
- return 1 + numDimensions(Array.get(o, 0));
- }
-
- static void fillShape(Object o, int dim, int[] shape) {
- if (shape == null || dim == shape.length) {
- return;
- }
- final int len = Array.getLength(o);
- if (shape[dim] == 0) {
- shape[dim] = len;
- } else if (shape[dim] != len) {
- throw new IllegalArgumentException(
- String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
- }
- for (int i = 0; i < len; ++i) {
- fillShape(Array.get(o, i), dim + 1, shape);
- }
- }
-
/**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
@@ -293,26 +213,63 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
/**
- * Gets the dimensions of an input. It throws IllegalArgumentException if input index is invalid.
+ * Gets the quantization zero point of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- int[] getInputDims(int index) {
- return getInputDims(interpreterHandle, index, -1);
+ int getOutputQuantizationZeroPoint(int index) {
+ return getOutputQuantizationZeroPoint(interpreterHandle, index);
}
/**
- * Gets the dimensions of an input. If numBytes >= 0, it will check whether num of bytes match the
- * input.
+ * Gets the quantization scale of an output.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
*/
- private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes);
+ float getOutputQuantizationScale(int index) {
+ return getOutputQuantizationScale(interpreterHandle, index);
+ }
+
+ /**
+ * Gets the input {@link Tensor} for the provided input index.
+ *
+ * @throws IllegalArgumentException if the input index is invalid.
+ */
+ Tensor getInputTensor(int index) {
+ if (index < 0 || index >= inputTensors.length) {
+ throw new IllegalArgumentException("Invalid input Tensor index: " + index);
+ }
+ Tensor inputTensor = inputTensors[index];
+ if (inputTensor == null) {
+ inputTensor =
+ inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ }
+ return inputTensor;
+ }
- /** Gets the type of an output. It throws IllegalArgumentException if output index is invalid. */
- String getOutputDataType(int index) {
- int type = getOutputDataType(interpreterHandle, index);
- return DataType.fromNumber(type).toStringName();
+ /**
+ * Gets the output {@link Tensor} for the provided output index.
+ *
+ * @throws IllegalArgumentException if the output index is invalid.
+ */
+ Tensor getOutputTensor(int index) {
+ if (index < 0 || index >= outputTensors.length) {
+ throw new IllegalArgumentException("Invalid output Tensor index: " + index);
+ }
+ Tensor outputTensor = outputTensors[index];
+ if (outputTensor == null) {
+ outputTensor =
+ outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ }
+ return outputTensor;
}
private static native int getOutputDataType(long interpreterHandle, int outputIdx);
+ private static native int getOutputQuantizationZeroPoint(long interpreterHandle, int outputIdx);
+
+ private static native float getOutputQuantizationScale(long interpreterHandle, int outputIdx);
+
private static final int ERROR_BUFFER_SIZE = 512;
private long errorHandle;
@@ -321,18 +278,30 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private long modelHandle;
- private int inputSize;
-
private long inferenceDurationNanoseconds = -1;
private ByteBuffer modelByteBuffer;
+ // Lazily constructed maps of input and output names to input and output Tensor indexes.
private Map<String, Integer> inputsIndexes;
-
private Map<String, Integer> outputsIndexes;
+ // Lazily constructed and populated arrays of input and output Tensor wrappers.
+ private final Tensor[] inputTensors;
+ private final Tensor[] outputTensors;
+
private boolean isMemoryAllocated = false;
+ private static native long allocateTensors(long interpreterHandle, long errorHandle);
+
+ private static native long getInputTensor(long interpreterHandle, int inputIdx);
+
+ private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+
+ private static native int getInputCount(long interpreterHandle);
+
+ private static native int getOutputCount(long interpreterHandle);
+
private static native String[] getInputNames(long interpreterHandle);
private static native String[] getOutputNames(long interpreterHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 09e887aae3..2403570c52 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -15,6 +15,9 @@ limitations under the License.
package org.tensorflow.lite;
+import java.lang.reflect.Array;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.Arrays;
/**
@@ -29,30 +32,179 @@ final class Tensor {
return new Tensor(nativeHandle);
}
- /** Reads Tensor content into an array. */
- <T> T copyTo(T dst) {
- if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) {
+ /** Returns the {@link DataType} of elements stored in the Tensor. */
+ public DataType dataType() {
+ return dtype;
+ }
+
+ /** Returns the size, in bytes, of the tensor data. */
+ public int numBytes() {
+ return numBytes(nativeHandle);
+ }
+
+ /**
+ * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
+ * the Tensor, i.e., the sizes of each dimension.
+ *
+ * @return an array where the i-th element is the size of the i-th dimension of the tensor.
+ */
+ public int[] shape() {
+ return shapeCopy;
+ }
+
+ /**
+ * Copies the contents of the provided {@code src} object to the Tensor.
+ *
+ * <p>The {@code src} should either be a (multi-dimensional) array with a shape matching that of
+ * this tensor, or a {@link ByteByffer} of compatible primitive type with a matching flat size.
+ *
+ * @throws IllegalArgumentException if the tensor is a scalar or if {@code src} is not compatible
+ * with the tensor (for example, mismatched data types or shapes).
+ */
+ void setTo(Object src) {
+ throwExceptionIfTypeIsIncompatible(src);
+ if (isByteBuffer(src)) {
+ ByteBuffer srcBuffer = (ByteBuffer) src;
+ // For direct ByteBuffer instances we support zero-copy. Note that this assumes the caller
+ // retains ownership of the source buffer until inference has completed.
+ if (srcBuffer.isDirect() && srcBuffer.order() == ByteOrder.nativeOrder()) {
+ writeDirectBuffer(nativeHandle, srcBuffer);
+ } else {
+ buffer().put(srcBuffer);
+ }
+ return;
+ }
+ writeMultiDimensionalArray(nativeHandle, src);
+ }
+
+ /**
+ * Copies the contents of the tensor to {@code dst} and returns {@code dst}.
+ *
+ * @param dst the destination buffer, either an explicitly-typed array or a {@link ByteBuffer}.
+ * @throws IllegalArgumentException if {@code dst} is not compatible with the tensor (for example,
+ * mismatched data types or shapes).
+ */
+ Object copyTo(Object dst) {
+ throwExceptionIfTypeIsIncompatible(dst);
+ if (dst instanceof ByteBuffer) {
+ ByteBuffer dstByteBuffer = (ByteBuffer) dst;
+ dstByteBuffer.put(buffer());
+ return dst;
+ }
+ readMultiDimensionalArray(nativeHandle, dst);
+ return dst;
+ }
+
+ /** Returns the provided buffer's shape if specified and different from this Tensor's shape. */
+ // TODO(b/80431971): Remove this method after deprecating multi-dimensional array inputs.
+ int[] getInputShapeIfDifferent(Object input) {
+ // Implicit resizes based on ByteBuffer capacity isn't supported, so short-circuit that path.
+ // The ByteBuffer's size will be validated against this Tensor's size in {@link #setTo(Object)}.
+ if (isByteBuffer(input)) {
+ return null;
+ }
+ int[] inputShape = shapeOf(input);
+ if (Arrays.equals(shapeCopy, inputShape)) {
+ return null;
+ }
+ return inputShape;
+ }
+
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType error: cannot resolve DataType of " + o.getClass().getName());
+ }
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("Array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
+
+ private void throwExceptionIfTypeIsIncompatible(Object o) {
+ if (isByteBuffer(o)) {
+ ByteBuffer oBuffer = (ByteBuffer) o;
+ if (oBuffer.capacity() != numBytes()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot convert between a TensorFlowLite buffer with %d bytes and a "
+ + "ByteBuffer with %d bytes.",
+ numBytes(), oBuffer.capacity()));
+ }
+ return;
+ }
+ DataType oType = dataTypeOf(o);
+ if (oType != dtype) {
throw new IllegalArgumentException(
String.format(
- "Output error: Cannot convert an TensorFlowLite tensor with type %s to a Java "
- + "object of type %s (which is compatible with the TensorFlowLite type %s)",
- dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst)));
+ "Cannot convert between a TensorFlowLite tensor with type %s and a Java "
+ + "object of type %s (which is compatible with the TensorFlowLite type %s).",
+ dtype, o.getClass().getName(), oType));
}
- int[] dstShape = NativeInterpreterWrapper.shapeOf(dst);
- if (!Arrays.equals(dstShape, shapeCopy)) {
+
+ int[] oShape = shapeOf(o);
+ if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
- "Output error: Shape of output target %s does not match with the shape of the "
- + "Tensor %s.",
- Arrays.toString(dstShape), Arrays.toString(shapeCopy)));
+ "Cannot copy between a TensorFlowLite tensor with shape %s and a Java object "
+ + "with shape %s.",
+ Arrays.toString(shapeCopy), Arrays.toString(oShape)));
}
- readMultiDimensionalArray(nativeHandle, dst);
- return dst;
}
- final long nativeHandle;
- final DataType dtype;
- final int[] shapeCopy;
+ private static boolean isByteBuffer(Object o) {
+ return o instanceof ByteBuffer;
+ }
+
+ private final long nativeHandle;
+ private final DataType dtype;
+ private final int[] shapeCopy;
private Tensor(long nativeHandle) {
this.nativeHandle = nativeHandle;
@@ -60,11 +212,23 @@ final class Tensor {
this.shapeCopy = shape(nativeHandle);
}
+ private ByteBuffer buffer() {
+ return buffer(nativeHandle).order(ByteOrder.nativeOrder());
+ }
+
+ private static native ByteBuffer buffer(long handle);
+
+ private static native void writeDirectBuffer(long handle, ByteBuffer src);
+
private static native int dtype(long handle);
private static native int[] shape(long handle);
- private static native void readMultiDimensionalArray(long handle, Object value);
+ private static native int numBytes(long handle);
+
+ private static native void readMultiDimensionalArray(long handle, Object dst);
+
+ private static native void writeMultiDimensionalArray(long handle, Object src);
static {
TensorFlowLite.init();
diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD
index 4399ed2025..4b4e1c21d8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/BUILD
+++ b/tensorflow/contrib/lite/java/src/main/native/BUILD
@@ -11,7 +11,6 @@ licenses(["notice"]) # Apache 2.0
cc_library(
name = "native_framework_only",
srcs = [
- "duration_utils_jni.cc",
"exception_jni.cc",
"nativeinterpreterwrapper_jni.cc",
"tensor_jni.cc",
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index 1fb6997fb9..e2c1edd9af 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -16,9 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h"
namespace {
-const int kByteBufferValue = 999;
-const int kBufferSize = 256;
-
tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
@@ -62,22 +59,6 @@ std::vector<int> convertJIntArrayToVector(JNIEnv* env, jintArray inputs) {
return outputs;
}
-bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; }
-
-TfLiteType resolveDataType(jint data_type) {
- switch (data_type) {
- case 1:
- return kTfLiteFloat32;
- case 2:
- return kTfLiteInt32;
- case 3:
- return kTfLiteUInt8;
- case 4:
- return kTfLiteInt64;
- default:
- return kTfLiteNoType;
- }
-}
int getDataType(TfLiteType data_type) {
switch (data_type) {
@@ -108,64 +89,6 @@ void printDims(char* buffer, int max_size, int* dims, int num_dims) {
}
}
-TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- const int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values,
- jobjectArray sizes) {
- if (input_size != interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Expected num of inputs is %d but got %d",
- interpreter->inputs().size(), input_size);
- return kTfLiteError;
- }
- if (input_size != env->GetArrayLength(data_types) ||
- input_size != env->GetArrayLength(nums_of_bytes) ||
- input_size != env->GetArrayLength(values)) {
- throwException(env, kIllegalArgumentException,
- "Internal error: Arrays in arguments should be of the same "
- "length, but got %d sizes, %d data_types, %d nums_of_bytes, "
- "and %d values",
- input_size, env->GetArrayLength(data_types),
- env->GetArrayLength(nums_of_bytes),
- env->GetArrayLength(values));
- return kTfLiteError;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- int num_dims = static_cast<int>(env->GetArrayLength(dims));
- if (target->dims->size != num_dims) {
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input should have %d dimensions, but "
- "found %d dimensions",
- i, target->dims->size, num_dims);
- return kTfLiteError;
- }
- jint* ptr = env->GetIntArrayElements(dims, nullptr);
- for (int j = 1; j < num_dims; ++j) {
- if (target->dims->data[j] != ptr[j]) {
- std::unique_ptr<char[]> expected_dims(new char[kBufferSize]);
- std::unique_ptr<char[]> obtained_dims(new char[kBufferSize]);
- printDims(expected_dims.get(), kBufferSize, target->dims->data,
- num_dims);
- printDims(obtained_dims.get(), kBufferSize, ptr, num_dims);
- throwException(env, kIllegalArgumentException,
- "Input error: %d-th input dimension should be [%s], but "
- "found [%s]",
- i, expected_dims.get(), obtained_dims.get());
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- return kTfLiteError;
- }
- }
- env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT);
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
// Checks whether there is any difference between dimensions of a tensor and a
// given dimensions. Returns true if there is difference, else false.
bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
@@ -188,74 +111,6 @@ bool areDimsDifferent(JNIEnv* env, TfLiteTensor* tensor, jintArray dims) {
return false;
}
-bool areInputDimensionsTheSame(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- if (interpreter->inputs().size() != input_size) {
- return false;
- }
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteTensor* target = interpreter->tensor(input_idx);
- if (areDimsDifferent(env, target, dims)) return false;
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return false;
- }
- return true;
-}
-
-TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jobjectArray sizes) {
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- jintArray dims =
- static_cast<jintArray>(env->GetObjectArrayElement(sizes, i));
- TfLiteStatus status = interpreter->ResizeInputTensor(
- input_idx, convertJIntArrayToVector(env, dims));
- if (status != kTfLiteOk) {
- return status;
- }
- env->DeleteLocalRef(dims);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter,
- int input_size, jintArray data_types,
- jintArray nums_of_bytes, jobjectArray values) {
- jint* data_type = env->GetIntArrayElements(data_types, nullptr);
- jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr);
- for (int i = 0; i < input_size; ++i) {
- int input_idx = interpreter->inputs()[i];
- TfLiteTensor* target = interpreter->tensor(input_idx);
- jobject value = env->GetObjectArrayElement(values, i);
- bool is_byte_buffer = isByteBuffer(data_type[i]);
- if (is_byte_buffer) {
- writeByteBuffer(env, value, &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- } else {
- TfLiteType type = resolveDataType(data_type[i]);
- if (type != target->type) {
- throwException(env, kIllegalArgumentException,
- "Input error: DataType (%d) of input data does not "
- "match with the DataType (%d) of model inputs.",
- type, target->type);
- return kTfLiteError;
- }
- writeMultiDimensionalArray(env, value, target->type, target->dims->size,
- &(target->data.raw),
- static_cast<int>(num_bytes[i]));
- }
- env->DeleteLocalRef(value);
- if (env->ExceptionCheck()) return kTfLiteError;
- }
- env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT);
- env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT);
- return kTfLiteOk;
-}
-
// TODO(yichengfan): evaluate the benefit to use tflite verifier.
bool VerifyModel(const void* buf, size_t len) {
flatbuffers::Verifier verifier(static_cast<const uint8_t*>(buf), len);
@@ -287,6 +142,63 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env,
return names;
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ BufferErrorReporter* error_reporter =
+ convertLongToErrorReporter(env, error_handle);
+ if (error_reporter == nullptr) return;
+
+ if (interpreter->AllocateTensors() != kTfLiteOk) {
+ throwException(env, kNullPointerException,
+ "Internal error: Cannot allocate memory for the interpreter:"
+ " %s",
+ error_reporter->CachedErrorMessage());
+ }
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->inputs()[index]));
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return reinterpret_cast<jlong>(
+ interpreter->tensor(interpreter->outputs()[index]));
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->inputs().size());
+}
+
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ return static_cast<jint>(interpreter->outputs().size());
+}
+
JNIEXPORT jobjectArray JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
jclass clazz,
@@ -434,114 +346,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
}
// Sets inputs, runs inference, and returns outputs as long handles.
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated) {
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle) {
tflite::Interpreter* interpreter =
convertLongToInterpreter(env, interpreter_handle);
- if (interpreter == nullptr) return nullptr;
+ if (interpreter == nullptr) return;
BufferErrorReporter* error_reporter =
convertLongToErrorReporter(env, error_handle);
- if (error_reporter == nullptr) return nullptr;
- const int input_size = env->GetArrayLength(sizes);
- // validates inputs
- TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types,
- nums_of_bytes, values, sizes);
- if (status != kTfLiteOk) return nullptr;
- if (!memory_allocated ||
- !areInputDimensionsTheSame(env, interpreter, input_size, sizes)) {
- // resizes inputs
- status = resizeInputs(env, interpreter, input_size, sizes);
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not resize the input: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- // allocates memory
- status = interpreter->AllocateTensors();
- if (status != kTfLiteOk) {
- throwException(env, kNullPointerException,
- "Internal error: Can not allocate memory for the given "
- "inputs: %s",
- error_reporter->CachedErrorMessage());
- return nullptr;
- }
- }
- // sets inputs
- status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes,
- values);
- if (status != kTfLiteOk) return nullptr;
- timespec beforeInference = ::tflite::getCurrentTime();
- // runs inference
+ if (error_reporter == nullptr) return;
+
if (interpreter->Invoke() != kTfLiteOk) {
throwException(env, kIllegalArgumentException,
"Internal error: Failed to run on the given Interpreter: %s",
error_reporter->CachedErrorMessage());
- return nullptr;
- }
- timespec afterInference = ::tflite::getCurrentTime();
- jclass wrapper_clazz = env->GetObjectClass(wrapper);
- jfieldID fid =
- env->GetFieldID(wrapper_clazz, "inferenceDurationNanoseconds", "J");
- if (env->ExceptionCheck()) {
- env->ExceptionClear();
- } else if (fid != nullptr) {
- env->SetLongField(
- wrapper, fid,
- ::tflite::timespec_diff_nanoseconds(&beforeInference, &afterInference));
- }
- // returns outputs
- const std::vector<int>& results = interpreter->outputs();
- if (results.empty()) {
- throwException(
- env, kIllegalArgumentException,
- "Internal error: The Interpreter does not have any outputs.");
- return nullptr;
+ return;
}
- jlongArray outputs = env->NewLongArray(results.size());
- size_t size = results.size();
- for (int i = 0; i < size; ++i) {
- TfLiteTensor* source = interpreter->tensor(results[i]);
- jlong output = reinterpret_cast<jlong>(source);
- env->SetLongArrayRegion(outputs, i, 1, &output);
- }
- return outputs;
-}
-
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) {
- tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
- if (interpreter == nullptr) return nullptr;
- const int idx = static_cast<int>(input_idx);
- if (input_idx < 0 || input_idx >= interpreter->inputs().size()) {
- throwException(env, kIllegalArgumentException,
- "Input error: Out of range: Failed to get %d-th input out of"
- " %d inputs",
- input_idx, interpreter->inputs().size());
- return nullptr;
- }
- TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]);
- int size = target->dims->size;
- if (num_bytes >= 0) { // verifies num of bytes matches if num_bytes if valid.
- int expected_num_bytes = elementByteSize(target->type);
- for (int i = 0; i < size; ++i) {
- expected_num_bytes *= target->dims->data[i];
- }
- if (num_bytes != expected_num_bytes) {
- throwException(env, kIllegalArgumentException,
- "Input error: Failed to get input dimensions. %d-th input "
- "should have %d bytes, but found %d bytes.",
- idx, expected_num_bytes, num_bytes);
- return nullptr;
- }
- }
- jintArray outputs = env->NewIntArray(size);
- env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0]));
- return outputs;
}
JNIEXPORT jint JNICALL
@@ -561,6 +380,38 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
return static_cast<jint>(type);
}
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 0;
+ const int idx = static_cast<int>(output_idx);
+ if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to get %d-th output out of %d outputs", output_idx,
+ interpreter->outputs().size());
+ return 0;
+ }
+ TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
+ return static_cast<jint>(target->params.zero_point);
+}
+
+JNIEXPORT jfloat JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return 1.0f;
+ const int idx = static_cast<int>(output_idx);
+ if (output_idx < 0 || output_idx >= interpreter->outputs().size()) {
+ throwException(env, kIllegalArgumentException,
+ "Failed to get %d-th output out of %d outputs", output_idx,
+ interpreter->outputs().size());
+ return 1.0f;
+ }
+ TfLiteTensor* target = interpreter->tensor(interpreter->outputs()[idx]);
+ return static_cast<jfloat>(target->params.scale);
+}
+
JNIEXPORT jboolean JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput(
JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index eaa765cb34..618fba480e 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -29,9 +29,6 @@ limitations under the License.
namespace tflite {
// This is to be provided at link-time by a library.
extern std::unique_ptr<OpResolver> CreateOpResolver();
-extern timespec getCurrentTime();
-extern jlong timespec_diff_nanoseconds(struct timespec* start,
- struct timespec* stop);
} // namespace tflite
#ifdef __cplusplus
@@ -40,6 +37,57 @@ extern "C" {
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: allocateTensors
+ * Signature: (JJ)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
+ JNIEnv* env, jclass clazz, jlong handle, jlong error_handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputTensor
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint index);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getInputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method: getOutputCount
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputCount(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
* Signature: (J)[Ljava/lang/Object;
*/
@@ -118,38 +166,43 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method:
- * Signature:
- * (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;Ljava/lang/Object;Z)[J
+ * Method: run
+ * Signature: (JJ)V
*/
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
- JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle,
- jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes,
- jobjectArray values, jobject wrapper, jboolean memory_allocated);
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_run(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
- * Signature: (JII)[I
+ * Signature: (JI)I
*
- * Gets input dimensions. If num_bytes is non-negative, it will check whether
- * num_bytes matches num of bytes required by the input, and return null and
- * throw IllegalArgumentException if not.
+ * Gets output dimensions.
*/
-JNIEXPORT jintArray JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims(
- JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes);
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
* Signature: (JI)I
*
- * Gets output dimensions.
+ * Gets output quantization zero point.
*/
JNIEXPORT jint JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputDataType(
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationZeroPoint(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_idx);
+
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
+ * Signature: (JI)F
+ *
+ * Gets output quantization scale.
+ */
+JNIEXPORT jfloat JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputQuantizationScale(
JNIEnv* env, jclass clazz, jlong handle, jint output_idx);
/*
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 005dca0253..7ff96a3172 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -29,6 +29,35 @@ TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
return reinterpret_cast<TfLiteTensor*>(handle);
}
+size_t elementByteSize(TfLiteType data_type) {
+ // The code in this file makes the assumption that the
+ // TensorFlow TF_DataTypes and the Java primitive types
+ // have the same byte sizes. Validate that:
+ switch (data_type) {
+ case kTfLiteFloat32:
+ static_assert(sizeof(jfloat) == 4,
+ "Interal error: Java float not compatible with "
+ "kTfLiteFloat");
+ return 4;
+ case kTfLiteInt32:
+ static_assert(sizeof(jint) == 4,
+ "Interal error: Java int not compatible with kTfLiteInt");
+ return 4;
+ case kTfLiteUInt8:
+ static_assert(sizeof(jbyte) == 1,
+ "Interal error: Java byte not compatible with "
+ "kTfLiteUInt8");
+ return 1;
+ case kTfLiteInt64:
+ static_assert(sizeof(jlong) == 8,
+ "Interal error: Java long not compatible with "
+ "kTfLiteInt64");
+ return 8;
+ default:
+ return 0;
+ }
+}
+
size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
void* dst, size_t dst_size) {
jarray array = static_cast<jarray>(object);
@@ -43,31 +72,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type,
}
switch (type) {
case kTfLiteFloat32: {
- jfloatArray a = static_cast<jfloatArray>(array);
- jfloat* values = env->GetFloatArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseFloatArrayElements(a, values, JNI_ABORT);
+ jfloatArray float_array = static_cast<jfloatArray>(array);
+ jfloat* float_dst = static_cast<jfloat*>(dst);
+ env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst);
return to_copy;
}
case kTfLiteInt32: {
- jintArray a = static_cast<jintArray>(array);
- jint* values = env->GetIntArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseIntArrayElements(a, values, JNI_ABORT);
+ jintArray int_array = static_cast<jintArray>(array);
+ jint* int_dst = static_cast<jint*>(dst);
+ env->GetIntArrayRegion(int_array, 0, num_elements, int_dst);
return to_copy;
}
case kTfLiteInt64: {
- jlongArray a = static_cast<jlongArray>(array);
- jlong* values = env->GetLongArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseLongArrayElements(a, values, JNI_ABORT);
+ jlongArray long_array = static_cast<jlongArray>(array);
+ jlong* long_dst = static_cast<jlong*>(dst);
+ env->GetLongArrayRegion(long_array, 0, num_elements, long_dst);
return to_copy;
}
case kTfLiteUInt8: {
- jbyteArray a = static_cast<jbyteArray>(array);
- jbyte* values = env->GetByteArrayElements(a, nullptr);
- memcpy(dst, values, to_copy);
- env->ReleaseByteArrayElements(a, values, JNI_ABORT);
+ jbyteArray byte_array = static_cast<jbyteArray>(array);
+ jbyte* byte_dst = static_cast<jbyte*>(dst);
+ env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst);
return to_copy;
}
default: {
@@ -145,48 +170,6 @@ size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src,
}
}
-} // namespace
-
-size_t elementByteSize(TfLiteType data_type) {
- // The code in this file makes the assumption that the
- // TensorFlow TF_DataTypes and the Java primitive types
- // have the same byte sizes. Validate that:
- switch (data_type) {
- case kTfLiteFloat32:
- static_assert(sizeof(jfloat) == 4,
- "Interal error: Java float not compatible with "
- "kTfLiteFloat");
- return 4;
- case kTfLiteInt32:
- static_assert(sizeof(jint) == 4,
- "Interal error: Java int not compatible with kTfLiteInt");
- return 4;
- case kTfLiteUInt8:
- static_assert(sizeof(jbyte) == 1,
- "Interal error: Java byte not compatible with "
- "kTfLiteUInt8");
- return 1;
- case kTfLiteInt64:
- static_assert(sizeof(jlong) == 8,
- "Interal error: Java long not compatible with "
- "kTfLiteInt64");
- return 8;
- default:
- return 0;
- }
-}
-
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) {
- char* buf = static_cast<char*>(env->GetDirectBufferAddress(object));
- if (!buf) {
- throwException(env, kIllegalArgumentException,
- "Input ByteBuffer is not a direct buffer");
- return 0;
- }
- *dst = buf;
- return dst_size;
-}
-
size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
int dims_left, char** dst, int dst_size) {
if (dims_left <= 1) {
@@ -207,6 +190,37 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
}
}
+} // namespace
+
+JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return nullptr;
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Tensor hasn't been allocated.");
+ return nullptr;
+ }
+ return env->NewDirectByteBuffer(static_cast<void*>(tensor->data.raw),
+ static_cast<jlong>(tensor->bytes));
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+
+ char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
+ if (!src_data_raw) {
+ throwException(env, kIllegalArgumentException,
+ "Input ByteBuffer is not a direct buffer");
+ return;
+ }
+
+ tensor->data.raw = src_data_raw;
+}
+
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
@@ -224,6 +238,27 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
num_dims, static_cast<jarray>(value));
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src) {
+ TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return;
+ if (tensor->data.raw == nullptr) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Target Tensor hasn't been allocated.");
+ return;
+ }
+ if (tensor->dims->size == 0) {
+ throwException(env, kIllegalArgumentException,
+ "Internal error: Cannot copy empty/scalar Tensors.");
+ return;
+ }
+ writeMultiDimensionalArray(env, src, tensor->type, tensor->dims->size,
+ &tensor->data.raw, tensor->bytes);
+}
+
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
@@ -241,3 +276,11 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
env->SetIntArrayRegion(result, 0, num_dims, tensor->dims->data);
return result;
}
+
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ if (tensor == nullptr) return 0;
+ return static_cast<jint>(tensor->bytes);
+}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 3a4910dcc3..06e2546af8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -24,8 +24,25 @@ extern "C" {
#endif // __cplusplus
/*
- * Class: org_tensorflow_lite_TfLiteTensor
- * Method:
+ * Class: org_tensorflow_lite_Tensor
+ * Method: buffer
+ * Signature: (J)Ljava/nio/ByteBuffer;
+ */
+JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
+ * Method: writeDirectBuffer
+ * Signature: (JLjava/nio/ByteBuffer;)
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
+ JNIEnv* env, jclass clazz, jlong handle, jobject src);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
+ * Method: dtype
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
@@ -33,8 +50,8 @@ JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jlong handle);
/*
- * Class: org_tensorflow_lite_TfLiteTensor
- * Method:
+ * Class: org_tensorflow_lite_Tensor
+ * Method: shape
* Signature: (J)[I
*/
JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
@@ -42,31 +59,35 @@ JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env,
jlong handle);
/*
- * Class: org_tensorflow_lite_TfLiteTensor
- * Method:
+ * Class: org_tensorflow_lite_Tensor
+ * Method: numBytes
+ * Signature: (J)I
+ */
+JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
+ * Method: readMultiDimensionalArray
* Signature: (JLjava/lang/Object;)
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
- jobject value);
+ jobject dst);
/*
- * Finds the size of each data type.
- */
-size_t elementByteSize(TfLiteType data_type);
-
-/*
- * Writes data of a ByteBuffer into dest.
- */
-size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size);
-
-/*
- * Writes a multi-dimensional array into dest.
+ * Class: org_tensorflow_lite_Tensor
+ * Method: writeMultidimensionalArray
+ * Signature: (JLjava/lang/Object;)
*/
-size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
- int dims_left, char** dst, int dst_size);
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jobject src);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 82007a6ab5..d66a73db94 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -165,6 +165,24 @@ public final class InterpreterTest {
}
@Test
+ public void testRunWithByteBufferOutput() {
+ float[] oneD = {1.23f, 6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ ByteBuffer parsedOutput =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+ try (Interpreter interpreter = new Interpreter(MODEL_FILE)) {
+ interpreter.run(fourD, parsedOutput);
+ }
+ float[] outputOneD = {
+ parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
+ };
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+
+ @Test
public void testMobilenetRun() {
// Create a gray image.
float[][][][] img = new float[1][224][224][3];
@@ -203,7 +221,9 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ + " TensorFlowLite type INT32)");
}
interpreter.close();
}
@@ -223,8 +243,8 @@ public final class InterpreterTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the"
+ "Cannot convert between a TensorFlowLite tensor with type "
+ + "FLOAT32 and a Java object of type [[[[I (which is compatible with the"
+ " TensorFlowLite type INT32)");
}
interpreter.close();
@@ -311,4 +331,11 @@ public final class InterpreterTest {
interpreter.close();
fileChannel.close();
}
+
+ @Test
+ public void testRedundantClose() throws Exception {
+ Interpreter interpreter = new Interpreter(MODEL_FILE);
+ interpreter.close();
+ interpreter.close();
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 7c00d3196f..9c4a5acd79 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -20,6 +20,8 @@ import static org.junit.Assert.fail;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -41,6 +43,9 @@ public final class NativeInterpreterWrapperTest {
private static final String BYTE_MODEL_PATH =
"tensorflow/contrib/lite/java/src/testdata/uint8.bin";
+ private static final String QUANTIZED_MODEL_PATH =
+ "tensorflow/contrib/lite/java/src/testdata/quantized.bin";
+
private static final String INVALID_MODEL_PATH =
"tensorflow/contrib/lite/java/src/testdata/invalid_model.bin";
@@ -98,10 +103,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -109,6 +114,27 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testRunWithBufferOutput() {
+ try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
+ float[] oneD = {1.23f, -6.54f, 7.81f};
+ float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
+ float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
+ float[][][][] fourD = {threeD, threeD};
+ Object[] inputs = {fourD};
+ ByteBuffer parsedOutput =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutput);
+ wrapper.run(inputs, outputs);
+ float[] outputOneD = {
+ parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
+ };
+ float[] expected = {3.69f, -19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+ }
+
+ @Test
public void testRunWithInputsOfSameDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, -6.54f, 7.81f};
@@ -116,17 +142,16 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
float[][][][] parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
parsedOutputs = new float[2][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
outputOneD = parsedOutputs[0][0][0];
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
wrapper.close();
@@ -140,10 +165,10 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
int[][][][] parsedOutputs = new int[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
int[] outputOneD = parsedOutputs[0][0][0];
int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4};
assertThat(outputOneD).isEqualTo(expected);
@@ -158,10 +183,10 @@ public final class NativeInterpreterWrapperTest {
long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
long[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
long[][][][] parsedOutputs = new long[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
long[] outputOneD = parsedOutputs[0][0][0];
long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L,
-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L};
@@ -179,10 +204,10 @@ public final class NativeInterpreterWrapperTest {
Object[] inputs = {fourD};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0};
@@ -205,13 +230,14 @@ public final class NativeInterpreterWrapperTest {
}
}
}
+ bbuf.rewind();
Object[] inputs = {bbuf};
int[] inputDims = {2, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
byte[][][][] parsedOutputs = new byte[2][4][4][12];
- outputs[0].copyTo(parsedOutputs);
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
byte[] outputOneD = parsedOutputs[0][0][0];
byte[] expected = {
(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0,
@@ -237,21 +263,22 @@ public final class NativeInterpreterWrapperTest {
}
}
Object[] inputs = {bbuf};
+ float[][][][] parsedOutputs = new float[4][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes");
+ "Cannot convert between a TensorFlowLite buffer with 768 bytes and a "
+ + "ByteBuffer with 3072 bytes.");
}
int[] inputDims = {4, 8, 8, 3};
wrapper.resizeInput(0, inputDims);
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
- float[][][][] parsedOutputs = new float[4][8][8][3];
- outputs[0].copyTo(parsedOutputs);
+ wrapper.run(inputs, outputs);
float[] outputOneD = parsedOutputs[0][0][0];
float[] expected = {3.69f, -19.62f, 23.43f};
assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
@@ -264,14 +291,18 @@ public final class NativeInterpreterWrapperTest {
ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
bbuf.order(ByteOrder.nativeOrder());
Object[] inputs = {bbuf};
+ Map<Integer, Object> outputs = new HashMap<>();
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(2 * 7 * 8 * 3);
+ outputs.put(0, parsedOutput);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes.");
+ "Cannot convert between a TensorFlowLite buffer with 192 bytes and a "
+ + "ByteBuffer with 336 bytes.");
}
wrapper.close();
}
@@ -284,14 +315,18 @@ public final class NativeInterpreterWrapperTest {
int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
int[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ int[][][][] parsedOutputs = new int[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
.contains(
- "DataType (2) of input data does not match with the DataType (1) of model inputs.");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
wrapper.close();
}
@@ -305,8 +340,11 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter.");
@@ -318,7 +356,7 @@ public final class NativeInterpreterWrapperTest {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
try {
Object[] inputs = {};
- wrapper.run(inputs);
+ wrapper.run(inputs, null);
fail();
} catch (IllegalArgumentException e) {
assertThat(e).hasMessageThat().contains("Inputs should not be null or empty.");
@@ -334,11 +372,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD, fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2");
+ assertThat(e).hasMessageThat().contains("Invalid input Tensor index: 1");
}
wrapper.close();
}
@@ -350,13 +391,18 @@ public final class NativeInterpreterWrapperTest {
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
Object[] inputs = {threeD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input should have 4 dimensions, but found 3 dimensions");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@@ -369,92 +415,23 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
assertThat(e)
.hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ .contains(
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 7, 3] and a "
+ + "Java object with shape [2, 8, 8, 3].");
}
wrapper.close();
}
@Test
- public void testNumElements() {
- int[] shape = {2, 3, 4};
- int num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(24);
- shape = null;
- num = NativeInterpreterWrapper.numElements(shape);
- assertThat(num).isEqualTo(0);
- }
-
- @Test
- public void testIsNonEmtpyArray() {
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse();
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse();
- int[] emptyArray = {};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse();
- int[] validArray = {9, 5, 2, 1};
- assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue();
- }
-
- @Test
- public void testDataTypeOf() {
- float[] testEmtpyArray = {};
- DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[] testFloatArray = {0.783f, 0.251f};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- try {
- double[] testDoubleArray = {0.783, 0.251};
- NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
- }
- try {
- Float[] testBoxedArray = {0.783f, 0.251f};
- NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
- }
- }
-
- @Test
- public void testNumDimensions() {
- int scalar = 1;
- assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
- int[][] array = {{2, 4}, {1, 9}};
- assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
- try {
- int[] emptyArray = {};
- NativeInterpreterWrapper.numDimensions(emptyArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
- }
- }
-
- @Test
- public void testFillShape() {
- int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = NativeInterpreterWrapper.numDimensions(array);
- int[] shape = new int[num];
- NativeInterpreterWrapper.fillShape(array, 0, shape);
- assertThat(num).isEqualTo(3);
- assertThat(shape[0]).isEqualTo(2);
- assertThat(shape[1]).isEqualTo(3);
- assertThat(shape[2]).isEqualTo(1);
- }
-
- @Test
public void testGetInferenceLatency() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, 6.54f, 7.81f};
@@ -462,8 +439,10 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- assertThat(outputs.length).isEqualTo(1);
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
+ wrapper.run(inputs, outputs);
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isGreaterThan(0L);
wrapper.close();
}
@@ -483,13 +462,14 @@ public final class NativeInterpreterWrapperTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
+ float[][][][] parsedOutputs = new float[2][8][8][3];
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, parsedOutputs);
try {
- wrapper.run(inputs);
+ wrapper.run(inputs, outputs);
fail();
} catch (IllegalArgumentException e) {
- assertThat(e)
- .hasMessageThat()
- .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]");
+ // Expected.
}
assertThat(wrapper.getLastNativeInferenceDurationNanoseconds()).isNull();
wrapper.close();
@@ -499,41 +479,19 @@ public final class NativeInterpreterWrapperTest {
public void testGetInputDims() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
int[] expectedDims = {1, 8, 8, 3};
- assertThat(wrapper.getInputDims(0)).isEqualTo(expectedDims);
+ assertThat(wrapper.getInputTensor(0).shape()).isEqualTo(expectedDims);
wrapper.close();
}
@Test
- public void testGetInputDimsOutOfRange() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- try {
- wrapper.getInputDims(-1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
+ public void testGetOutputQuantizationParams() {
+ try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH)) {
+ assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(0);
+ assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.0f);
}
- try {
- wrapper.getInputDims(1);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Out of range");
+ try (NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(QUANTIZED_MODEL_PATH)) {
+ assertThat(wrapper.getOutputQuantizationZeroPoint(0)).isEqualTo(127);
+ assertThat(wrapper.getOutputQuantizationScale(0)).isWithin(1e-6f).of(0.25f);
}
- wrapper.close();
- }
-
- @Test
- public void testGetOutputDataType() {
- NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("float");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("long");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("int");
- wrapper.close();
- wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH);
- assertThat(wrapper.getOutputDataType(0)).contains("byte");
- wrapper.close();
}
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 94b6632bb8..71ef044943 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -18,6 +18,10 @@ package org.tensorflow.lite;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.HashMap;
+import java.util.Map;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -32,7 +36,7 @@ public final class TensorTest {
"tensorflow/contrib/lite/java/src/testdata/add.bin";
private NativeInterpreterWrapper wrapper;
- private long nativeHandle;
+ private Tensor tensor;
@Before
public void setUp() {
@@ -42,8 +46,10 @@ public final class TensorTest {
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
float[][][][] fourD = {threeD, threeD};
Object[] inputs = {fourD};
- Tensor[] outputs = wrapper.run(inputs);
- nativeHandle = outputs[0].nativeHandle;
+ Map<Integer, Object> outputs = new HashMap<>();
+ outputs.put(0, new float[2][8][8][3]);
+ wrapper.run(inputs, outputs);
+ tensor = wrapper.getOutputTensor(0);
}
@After
@@ -52,17 +58,16 @@ public final class TensorTest {
}
@Test
- public void testFromHandle() throws Exception {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
+ public void testBasic() throws Exception {
assertThat(tensor).isNotNull();
int[] expectedShape = {2, 8, 8, 3};
- assertThat(tensor.shapeCopy).isEqualTo(expectedShape);
- assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.shape()).isEqualTo(expectedShape);
+ assertThat(tensor.dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(tensor.numBytes()).isEqualTo(2 * 8 * 8 * 3 * 4);
}
@Test
public void testCopyTo() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[2][8][8][3];
tensor.copyTo(parsedOutputs);
float[] outputOneD = parsedOutputs[0][0][0];
@@ -71,8 +76,31 @@ public final class TensorTest {
}
@Test
+ public void testCopyToByteBuffer() {
+ ByteBuffer parsedOutput =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+ tensor.copyTo(parsedOutput);
+ assertThat(parsedOutput.position()).isEqualTo(2 * 8 * 8 * 3 * 4);
+ float[] outputOneD = {
+ parsedOutput.getFloat(0), parsedOutput.getFloat(4), parsedOutput.getFloat(8)
+ };
+ float[] expected = {3.69f, 19.62f, 23.43f};
+ assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
+ }
+
+ @Test
+ public void testCopyToInvalidByteBuffer() {
+ ByteBuffer parsedOutput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ try {
+ tensor.copyTo(parsedOutput);
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Expected.
+ }
+ }
+
+ @Test
public void testCopyToWrongType() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
int[][][][] parsedOutputs = new int[2][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -81,15 +109,13 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Cannot convert an TensorFlowLite tensor with type "
- + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite "
- + "type INT32)");
+ "Cannot convert between a TensorFlowLite tensor with type FLOAT32 and a Java object "
+ + "of type [[[[I (which is compatible with the TensorFlowLite type INT32)");
}
}
@Test
public void testCopyToWrongShape() {
- Tensor tensor = Tensor.fromHandle(nativeHandle);
float[][][][] parsedOutputs = new float[1][8][8][3];
try {
tensor.copyTo(parsedOutputs);
@@ -98,8 +124,104 @@ public final class TensorTest {
assertThat(e)
.hasMessageThat()
.contains(
- "Shape of output target [1, 8, 8, 3] does not match "
- + "with the shape of the Tensor [2, 8, 8, 3].");
+ "Cannot copy between a TensorFlowLite tensor with shape [2, 8, 8, 3] "
+ + "and a Java object with shape [1, 8, 8, 3].");
+ }
+ }
+
+ @Test
+ public void testSetTo() {
+ float[][][][] input = new float[2][8][8][3];
+ float[][][][] output = new float[2][8][8][3];
+ ByteBuffer inputByteBuffer =
+ ByteBuffer.allocateDirect(2 * 8 * 8 * 3 * 4).order(ByteOrder.nativeOrder());
+
+ input[0][0][0][0] = 2.0f;
+ tensor.setTo(input);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(2.0f);
+
+ inputByteBuffer.putFloat(0, 3.0f);
+ tensor.setTo(inputByteBuffer);
+ tensor.copyTo(output);
+ assertThat(output[0][0][0][0]).isEqualTo(3.0f);
+ }
+
+ @Test
+ public void testSetToInvalidByteBuffer() {
+ ByteBuffer input = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ try {
+ tensor.setTo(input);
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Success.
+ }
+ }
+
+ @Test
+ public void testGetInputShapeIfDifferent() {
+ ByteBuffer bytBufferInput = ByteBuffer.allocateDirect(3 * 4).order(ByteOrder.nativeOrder());
+ assertThat(tensor.getInputShapeIfDifferent(bytBufferInput)).isNull();
+
+ float[][][][] sameShapeInput = new float[2][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(sameShapeInput)).isNull();
+
+ float[][][][] differentShapeInput = new float[1][8][8][3];
+ assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
+ .isEqualTo(new int[] {1, 8, 8, 3});
+ }
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmptyArray = {};
+ DataType dataType = Tensor.dataTypeOf(testEmptyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ Tensor.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
}
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ Tensor.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
+ }
+ }
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ Tensor.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = Tensor.numDimensions(array);
+ int[] shape = new int[num];
+ Tensor.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
}
}
diff --git a/tensorflow/contrib/lite/java/src/testdata/quantized.bin b/tensorflow/contrib/lite/java/src/testdata/quantized.bin
new file mode 100644
index 0000000000..4062088cdf
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/testdata/quantized.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
index b524246d43..af1d99ef41 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD
@@ -1,6 +1,8 @@
# Description:
# Internal helper function to test TF Lite API.
+load("@build_bazel_rules_android//android:rules.bzl", "android_library")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 3aef0c3bb6..c23521c077 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -58,7 +58,7 @@ public class TestHelper {
*/
public static int[] getInputDims(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getInputDims(index);
+ return interpreter.wrapper.getInputTensor(index).shape();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get input dimensions.");
@@ -77,7 +77,7 @@ public class TestHelper {
*/
public static String getOutputDataType(Interpreter interpreter, int index) {
if (interpreter != null && interpreter.wrapper != null) {
- return interpreter.wrapper.getOutputDataType(index);
+ return interpreter.wrapper.getOutputTensor(index).dataType().toStringName();
} else {
throw new IllegalArgumentException(
"Interpreter has not initialized;" + " Failed to get output data type.");
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 0af659b5ca..edce73989c 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -46,11 +46,17 @@ cc_library(
hdrs = [
"eigen_support.h",
],
- copts = tflite_copts(),
+ copts = tflite_copts() + [
+ "-Wno-error=reorder",
+ ] + select({
+ "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"],
+ "//conditions:default": [
+ ],
+ }),
deps = [
":op_macros",
"//tensorflow/contrib/lite:context",
- "//third_party/eigen3",
+ "//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -130,7 +136,7 @@ cc_library(
srcs = [
"activations.cc",
"add.cc",
- "arg_max.cc",
+ "arg_min_max.cc",
"audio_spectrogram.cc",
"basic_rnn.cc",
"batch_to_space_nd.cc",
@@ -142,11 +148,14 @@ cc_library(
"conv.cc",
"depthwise_conv.cc",
"dequantize.cc",
+ "detection_postprocess.cc",
"div.cc",
"elementwise.cc",
"embedding_lookup.cc",
"embedding_lookup_sparse.cc",
"exp.cc",
+ "expand_dims.cc",
+ "fake_quant.cc",
"floor.cc",
"fully_connected.cc",
"gather.cc",
@@ -156,16 +165,18 @@ cc_library(
"lsh_projection.cc",
"lstm.cc",
"maximum_minimum.cc",
- "mean.cc",
"mfcc.cc",
"mul.cc",
"neg.cc",
"pad.cc",
"pooling.cc",
+ "pow.cc",
+ "reduce.cc",
"register.cc",
"reshape.cc",
"resize_bilinear.cc",
"select.cc",
+ "shape.cc",
"skip_gram.cc",
"slice.cc",
"space_to_batch_nd.cc",
@@ -176,6 +187,7 @@ cc_library(
"strided_slice.cc",
"sub.cc",
"svdf.cc",
+ "tile.cc",
"topk_v2.cc",
"transpose.cc",
"transpose_conv.cc",
@@ -245,6 +257,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "detection_postprocess_test",
+ size = "small",
+ srcs = ["detection_postprocess_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
@@ -271,9 +297,9 @@ tf_cc_test(
)
tf_cc_test(
- name = "arg_max_test",
+ name = "arg_min_max_test",
size = "small",
- srcs = ["arg_max_test.cc"],
+ srcs = ["arg_min_max_test.cc"],
tags = [
"tflite_not_portable_ios",
],
@@ -539,6 +565,19 @@ tf_cc_test(
)
tf_cc_test(
+ name = "fake_quant_test",
+ size = "small",
+ srcs = ["fake_quant_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "maximum_minimum_test",
size = "small",
srcs = ["maximum_minimum_test.cc"],
@@ -552,9 +591,9 @@ tf_cc_test(
)
tf_cc_test(
- name = "mean_test",
+ name = "reduce_test",
size = "small",
- srcs = ["mean_test.cc"],
+ srcs = ["reduce_test.cc"],
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
@@ -859,6 +898,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "tile_test",
+ size = "small",
+ srcs = ["tile_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "comparisons_test",
size = "small",
srcs = [
@@ -931,6 +984,21 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "expand_dims_test",
+ size = "small",
+ srcs = ["expand_dims_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
)
@@ -942,6 +1010,35 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "shape_test",
+ size = "small",
+ srcs = ["shape_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "pow_test",
+ size = "small",
+ srcs = ["pow_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index add36b46c0..99f81c4a8a 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -84,6 +84,38 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // These operators are implemented in fixed-point arithmetic,
+ // which intrinsically wants symmetric ranges (zero_point==0)
+ // and power-of-two scales (power-of-two is abbreviated below as POT).
+ // While more general support would be possible by means of rescaling,
+ // that would add some overhead and some loss of accuracy and wouldn't
+ // be used at the moment as current quantized LSTM applications are
+ // happy with symmetric, power-of-two-scales quantization. So we just
+ // implement that narrow case only for now.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // Support for shifts is limited until we have a parameterized version of
+ // SaturatingRoundingMultiplyByPOT().
+ TF_LITE_ENSURE(context, data->input_left_shift >= 0);
+ TF_LITE_ENSURE(context, data->input_left_shift <= 1);
}
return context->ResizeTensor(context, output,
@@ -114,6 +146,30 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {
&data->input_left_shift);
data->input_range_radius =
CalculateInputRadius(kInputIntegerBits, data->input_left_shift);
+ } else if (input->type == kTfLiteInt16) {
+ static constexpr int kInputIntegerBits = 3;
+ static constexpr int kOutputFractionalBits = 15;
+
+ // See comments in TanhPrepare about requiring zero_point==0
+ // and a power-of-two ("POT") scale.
+
+ TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input_scale_log2_rounded;
+ TF_LITE_ENSURE(context,
+ CheckedLog2(input->params.scale, &input_scale_log2_rounded));
+
+ int output_scale_log2_rounded;
+ TF_LITE_ENSURE(
+ context, CheckedLog2(output->params.scale, &output_scale_log2_rounded));
+ TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded,
+ -kOutputFractionalBits);
+
+ data->input_left_shift =
+ (15 - kInputIntegerBits) + input_scale_log2_rounded;
+ // The int16 logistic implementation does not support shifting of the input.
+ TF_LITE_ENSURE_EQ(context, data->input_left_shift, 0);
}
return context->ResizeTensor(context, output,
@@ -250,12 +306,19 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = std::tanh(*in);
return kTfLiteOk;
} break;
+ case kTfLiteInt16: {
+ optimized_ops::Tanh(GetTensorData<int16_t>(input), GetTensorShape(input),
+ data->input_left_shift,
+ GetTensorData<int16_t>(output),
+ GetTensorShape(output));
+ return kTfLiteOk;
+ } break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorShape(input),
input->params.zero_point, data->input_range_radius,
data->input_multiplier, data->input_left_shift,
GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
return kTfLiteOk;
} break;
default:
@@ -280,12 +343,18 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in));
break;
}
+ case kTfLiteInt16: {
+ optimized_ops::Logistic(
+ GetTensorData<int16>(input), GetTensorShape(input),
+ GetTensorData<int16_t>(output), GetTensorShape(output));
+ break;
+ }
case kTfLiteUInt8: {
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorDims(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(input),
input->params.zero_point, data->input_range_radius,
data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorDims(output));
+ GetTensorData<uint8_t>(output), GetTensorShape(output));
break;
}
default:
@@ -341,26 +410,26 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
optimized_ops::Softmax(GetTensorData<uint8_t>(input),
- GetTensorDims({batch_size, 1, 1, input_size}),
+ GetTensorShape({batch_size, 1, 1, input_size}),
data->input_multiplier, data->input_left_shift,
data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims({batch_size, 1, 1, input_size}));
+ GetTensorShape({batch_size, 1, 1, input_size}));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorDims(input),
+ optimized_ops::Softmax(GetTensorData<float>(input), GetTensorShape(input),
params->beta, GetTensorData<float>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorDims(input),
+ optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorShape(input),
data->input_multiplier, data->input_left_shift,
data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ GetTensorShape(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -415,8 +484,8 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
switch (input->type) {
case kTfLiteFloat32:
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorData<float>(input), GetTensorShape(input),
+ GetTensorData<float>(output), GetTensorShape(output));
return kTfLiteOk;
default:
context->ReportError(context, "Only float32 supported currently., got %d",
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index 50a84edd47..587e1303da 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -75,23 +75,42 @@ class FloatActivationsOpModel : public BaseActivationsOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-// TODO(ahentz): I don't quite understand the tradeoffs in the quantized
-// implementation of sigmoid and software, but a tolerance of twice the output
-// scale seems reasonable. We might want to change this if we have a better
-// theoretical bound.
+// Our fixed-point math function implementations have roughly 12 bits of
+// accuracy, when specialized to 16-bit fixed-point arithmetic.
+// That is purely an implementation compromise, it would have been possible
+// to get closer to 16 bits of accuracy but that would be more expensive,
+// and not needed for our purposes as ultimately the output is either
+// immediately down-quantized to 8 bits, or will typically be at the output
+// of the surrounding LSTM cell.
+// So we can require roughly 2^-12 accuracy when the output is 16-bit, and
+// we can more or less expect the full 2^-8 accuracy when the output is 8-bit.
+//
+// However, the representable output interval is often [-1, 1] (it has to be
+// for tanh, and even for logistic, when we implement it in fixed-point, we
+// typically have to do so on such a symmetric interval, e.g. ARM NEON only
+// has signed fixed-point arithmetic (SQRDMULH)). As the width of [-1, 1]
+// is 2, our representable values are often diluted by a factor of 2, whence
+// the factor of 2 below.
const float kQuantizedTolerance = 2 * (1. / 256);
+const float kQuantizedToleranceInt16 = 2 * (1. / 4096);
class QuantizedActivationsOpModel : public BaseActivationsOpModel {
public:
using BaseActivationsOpModel::BaseActivationsOpModel;
+ template <typename T>
void SetInput(std::initializer_list<float> data) {
- QuantizeAndPopulate<uint8_t>(input_, data);
+ QuantizeAndPopulate<T>(input_, data);
}
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ template <typename T>
+
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ template <typename T>
std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
}
};
@@ -152,24 +171,47 @@ TEST(FloatActivationsOpTest, Tanh) {
}
TEST(QuantizedActivationsOpTest, Tanh) {
+ const float kMin = -1;
+ const float kMax = 127.f / 128.f;
QuantizedActivationsOpModel m(
BuiltinOperator_TANH,
- /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -8, 8},
- /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, -1, 1});
- m.SetInput({
+ /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_UINT8, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
-4, -2, 8, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.0, -0.999987, 0.964027, 0.999329, //
- -0.996078, -0.96402, 0.99999, 0.76159, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
},
- 4 * (1. / 256))));
- EXPECT_THAT(m.GetOutput(),
- ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 226}));
+ kQuantizedTolerance)));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({128, 0, 251, 255, 0, 5, 255, 225}));
+}
+
+TEST(QuantizedActivationsOpTest, TanhInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_TANH,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ -4, -2, 8, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.0, -0.999987, 0.964027, 0.999329, //
+ -0.999329, -0.96402, 0.99999, 0.76159, //
+ },
+ kQuantizedToleranceInt16)));
}
TEST(FloatActivationsOpTest, Sigmoid) {
@@ -190,22 +232,43 @@ TEST(QuantizedActivationsOpTest, Sigmoid) {
QuantizedActivationsOpModel m(
BuiltinOperator_LOGISTIC,
/*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
0.5, 0.002473, 0.880797, 0.982014, //
0.952574, 0.119203, 0.999955, 0.731059, //
},
kQuantizedTolerance)));
- EXPECT_THAT(m.GetOutput(),
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188}));
}
+TEST(QuantizedActivationsOpTest, SigmoidInt16) {
+ const float kMin = -1;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedActivationsOpModel m(
+ BuiltinOperator_LOGISTIC,
+ /*input=*/{TensorType_INT16, {1, 2, 4, 1}, 8 * kMin, 8 * kMax},
+ /*output=*/{TensorType_INT16, {1, 2, 4, 1}, kMin, kMax});
+ m.SetInput<int16_t>({
+ 0, -6, 2, 4, //
+ 3, -2, 10, 1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.5, 0.002473, 0.880797, 0.982014, //
+ 0.952574, 0.119203, 0.999955, 0.731059, //
+ },
+ kQuantizedToleranceInt16)));
+}
+
TEST(FloatActivationsOpTest, Softmax4D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}});
@@ -241,12 +304,12 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m(
0.1,
/*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, // depth = 0
3, -2, 10, 1, // depth = 1
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -258,21 +321,22 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
QuantizedActivationsOpModel m2(
0.1,
/*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10});
- m2.SetInput({
+ m2.SetInput<uint8_t>({
0, -6, //
2, 4, //
3, -2, //
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
}
TEST(FloatActivationsOpTest, Softmax2D) {
@@ -309,12 +373,12 @@ TEST(FloatActivationsOpTest, Softmax2D) {
TEST(QuantizedActivationsOpTest, Softmax2D) {
QuantizedActivationsOpModel m(0.1,
/*input=*/{TensorType_UINT8, {2, 4}, -10, 10});
- m.SetInput({
+ m.SetInput<uint8_t>({
0, -6, 2, 4, //
3, -2, 10, 1, //
});
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
ElementsAreArray(ArrayFloatNear(
{
.23463, .12877, .28658, .35003, //
@@ -325,21 +389,22 @@ TEST(QuantizedActivationsOpTest, Softmax2D) {
// Same input, but a different shape.
QuantizedActivationsOpModel m2(0.1,
/*input=*/{TensorType_UINT8, {4, 2}, -10, 10});
- m2.SetInput({
+ m2.SetInput<uint8_t>({
0, -6, //
2, 4, //
3, -2, //
10, 1, //
});
m2.Invoke();
- EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
- {
- 0.645656, 0.354344, //
- 0.450166, 0.549834, //
- 0.622459, 0.377541, //
- 0.710949, 0.28905, //
- },
- kQuantizedTolerance)));
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
}
// This contains the same test values as the Softmax test, but reference answer
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index 7ca1e35489..f44d531cbf 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -39,6 +39,23 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+ // These fields are used in both the general 8-bit -> 8bit quantized path,
+ // and the special 16-bit -> 16bit quantized path
+ int input1_shift;
+ int input2_shift;
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+ // These fields are used only in the general 8-bit -> 8bit quantized path
+ int32 input1_multiplier;
+ int32 input2_multiplier;
+ int32 output_multiplier;
+ int output_shift;
+ int left_shift;
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +69,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -74,89 +92,169 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ // 8bit -> 8bit general quantized path, with general rescalings
+ data->input1_offset = -input1->params.zero_point;
+ data->input2_offset = -input2->params.zero_point;
+ data->output_offset = output->params.zero_point;
+ data->left_shift = 20;
+ const double twice_max_input_scale =
+ 2 * std::max(input1->params.scale, input2->params.scale);
+ const double real_input1_multiplier =
+ input1->params.scale / twice_max_input_scale;
+ const double real_input2_multiplier =
+ input2->params.scale / twice_max_input_scale;
+ const double real_output_multiplier =
+ twice_max_input_scale /
+ ((1 << data->left_shift) * output->params.scale);
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_input1_multiplier, &data->input1_multiplier, &data->input1_shift);
+ data->input1_shift *= -1;
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_input2_multiplier, &data->input2_multiplier, &data->input2_shift);
+ data->input2_shift *= -1;
+
+ QuantizeMultiplierSmallerThanOneExp(
+ real_output_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
+
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+
+ } else if (output->type == kTfLiteInt16) {
+ // 16bit -> 16bit special quantized path, supporting only a rather
+ // narrow case of quantization parameters: zero_points must all be 0
+ // ("symmetric quantization") and scales must be power-of-two (which
+ // we abbreviate as "POT" below). The intended use case for this path
+ // is in LSTM cells, where, due to the constraints of implementing
+ // some of the math in these LSTM cells in fixed-point arithmetic,
+ // we need to have such symmetric, power-of-two quantization
+ // (Fixed-point formats are inherently symmetric, power-of-two).
+ TF_LITE_ENSURE_EQ(context, input1->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, input2->params.zero_point, 0);
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+
+ int input1_scale_log2_rounded;
+ bool input1_scale_is_pot =
+ CheckedLog2(input1->params.scale, &input1_scale_log2_rounded);
+ TF_LITE_ENSURE(context, input1_scale_is_pot);
+
+ int input2_scale_log2_rounded;
+ bool input2_scale_is_pot =
+ CheckedLog2(input2->params.scale, &input2_scale_log2_rounded);
+ TF_LITE_ENSURE(context, input2_scale_is_pot);
+
+ int output_scale_log2_rounded;
+ bool output_scale_is_pot =
+ CheckedLog2(output->params.scale, &output_scale_log2_rounded);
+ TF_LITE_ENSURE(context, output_scale_is_pot);
+
+ data->input1_shift = output_scale_log2_rounded - input1_scale_log2_rounded;
+ data->input2_shift = output_scale_log2_rounded - input2_scale_log2_rounded;
+
+ // Shifting of one input is supported. The graph quantization should ensure
+ // that the other input matches the output.
+ TF_LITE_ENSURE(context, data->input1_shift == 0 || data->input2_shift == 0);
+ TF_LITE_ENSURE(context, data->input1_shift >= 0);
+ TF_LITE_ENSURE(context, data->input2_shift >= 0);
+
+ CalculateActivationRangeQuantized(context, params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
-void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_ADD(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
+void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_ADD(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, int32_t);
+ }
} else {
- TF_LITE_ADD(reference_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd, float);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, float);
+ }
} else {
- TF_LITE_ADD(optimized_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd, float);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, float);
+ }
}
}
#undef TF_LITE_ADD
}
template <KernelType kernel_type>
-void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
- const int left_shift = 20;
- const double twice_max_input_scale =
- 2 * std::max(input1->params.scale, input2->params.scale);
- const double real_input1_multiplier =
- input1->params.scale / twice_max_input_scale;
- const double real_input2_multiplier =
- input2->params.scale / twice_max_input_scale;
- const double real_output_multiplier =
- twice_max_input_scale / ((1 << left_shift) * output->params.scale);
-
- int32 input1_multiplier;
- int input1_shift;
- QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
- &input1_shift);
- int32 input2_multiplier;
- int input2_shift;
- QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
- &input2_shift);
- int32 output_multiplier;
- int output_shift;
- QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
- &output_shift);
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
-#define TF_LITE_ADD(type, opname) \
- type::opname(left_shift, GetTensorData<uint8_t>(input1), \
- GetTensorDims(input1), input1_offset, input1_multiplier, \
- input1_shift, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, input2_multiplier, \
- input2_shift, output_offset, output_multiplier, output_shift, \
- output_activation_min, output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output));
- // The quantized version of Add doesn't support activations, so we
- // always use BroadcastAdd.
- if (kernel_type == kReference) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
- } else {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
- }
+TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteAddParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ if (output->type == kTfLiteUInt8) {
+#define TF_LITE_ADD(type, opname) \
+ type::opname( \
+ data->left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ data->input1_offset, data->input1_multiplier, data->input1_shift, \
+ GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
+ data->input2_offset, data->input2_multiplier, data->input2_shift, \
+ data->output_offset, data->output_multiplier, data->output_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ // The quantized version of Add doesn't support activations, so we
+ // always use BroadcastAdd.
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd);
+ } else {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ }
+#undef TF_LITE_ADD
+ } else if (output->type == kTfLiteInt16) {
+#define TF_LITE_ADD(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ data->input1_shift, GetTensorData<int16_t>(input2), \
+ GetTensorDims(input2), data->input2_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<int16_t>(output), GetTensorDims(output));
+ // The quantized version of Add doesn't support activations, so we
+ // always use BroadcastAdd.
+ if (kernel_type == kReference) {
+ TF_LITE_ADD(reference_ops, Add);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add);
+ }
#undef TF_LITE_ADD
+ }
+
+ return kTfLiteOk;
}
template <KernelType kernel_type>
@@ -168,15 +266,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalAddFloat<kernel_type>(context, node, params, data, input1, input2,
- output);
- } else if (output->type == kTfLiteUInt8) {
- EvalAddQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(context,
+ EvalAddQuantized<kernel_type>(context, node, params, data,
+ input1, input2, output));
} else {
context->ReportError(context,
- "Inputs and outputs not all float|uint8 types.");
+ "Inputs and outputs not all float|uint8|int16 types.");
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc
index 956d05bed5..0b58443211 100644
--- a/tensorflow/contrib/lite/kernels/add_test.cc
+++ b/tensorflow/contrib/lite/kernels/add_test.cc
@@ -52,6 +52,13 @@ class FloatAddOpModel : public BaseAddOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
+class IntegerAddOpModel : public BaseAddOpModel {
+ public:
+ using BaseAddOpModel::BaseAddOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
class QuantizedAddOpModel : public BaseAddOpModel {
public:
using BaseAddOpModel::BaseAddOpModel;
@@ -60,15 +67,26 @@ class QuantizedAddOpModel : public BaseAddOpModel {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
+
+ std::vector<float> GetDequantizedOutputInt16() {
+ return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
};
// for quantized Add, the error shouldn't exceed 2*step
-float GetTolerance(int min, int max) {
+float GetTolerance(float min, float max) {
float kQuantizedStep = (max - min) / 255.0;
float kQuantizedTolerance = 2.0 * kQuantizedStep;
return kQuantizedTolerance;
}
+float GetToleranceInt16(float min, float max) {
+ float kQuantizedStep = (max - min) / 32767.f;
+ float kQuantizedTolerance = 2.0 * kQuantizedStep;
+ return kQuantizedTolerance;
+}
+
TEST(FloatAddOpModel, NoActivation) {
FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -122,6 +140,57 @@ TEST(FloatAddOpModel, WithBroadcast) {
}
}
+TEST(IntegerAddOpModel, NoActivation) {
+ IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 4, 10, 13}));
+}
+
+TEST(IntegerAddOpModel, ActivationRELU_N1_TO_1) {
+ IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
+ ActivationFunctionType_RELU_N1_TO_1);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 1, 1, 1}));
+}
+
+TEST(IntegerAddOpModel, VariousInputShapes) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1, 2, 3, 5, 11, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-19, 04, 10, 13, 22, 21}))
+ << "With shape number " << i;
+ }
+}
+
+TEST(IntegerAddOpModel, WithBroadcast) {
+ std::vector<std::initializer_list<int>> test_shapes = {
+ {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
+ for (int i = 0; i < test_shapes.size(); ++i) {
+ IntegerAddOpModel m({TensorType_INT32, test_shapes[i]},
+ {TensorType_INT32, {}}, // always a scalar
+ {TensorType_INT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<int32_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
+ m.PopulateTensor<int32_t>(m.input2(), {1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({-19, 3, 8, 9, 12, 21})))
+ << "With shape number " << i;
+ }
+}
+
TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {
@@ -144,6 +213,31 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) {
}
}
+TEST(QuantizedAddOpModel, QuantizedTestsNoActivationInt16) {
+ const float kMin = -1.f;
+ const float kMax = 32767.f / 32768.f;
+ float kQuantizedTolerance = GetToleranceInt16(kMin, kMax);
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}};
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}};
+ std::vector<std::initializer_list<float>> results = {
+ {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}};
+ for (int i = 0; i < inputs1.size(); ++i) {
+ QuantizedAddOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {}, kMin, kMax},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), inputs1[i]);
+ m.QuantizeAndPopulate<int16_t>(m.input2(), inputs2[i]);
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutputInt16(),
+ ElementsAreArray(ArrayFloatNear(results[i], kQuantizedTolerance)))
+ << "With test number " << i;
+ }
+}
+
TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<std::initializer_list<float>> inputs1 = {{-0.8, 0.2, 0.9, 0.7},
diff --git a/tensorflow/contrib/lite/kernels/arg_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 26f57e8896..4f30d09030 100644
--- a/tensorflow/contrib/lite/kernels/arg_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -23,7 +23,7 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace builtin {
-namespace arg_max {
+namespace arg_min_max {
constexpr int kInputTensor = 0;
constexpr int kAxis = 1;
@@ -80,30 +80,39 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, output_size);
}
+template <typename T>
+std::function<bool(T, T)> GetComparefunction(bool is_arg_max) {
+ if (is_arg_max) {
+ return std::greater<T>();
+ } else {
+ return std::less<T>();
+ }
+}
+
// The current impl actually ignores the axis argument.
// Only determine the index of the maximum value in the last dimension.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMax(GetTensorData<axis_type>(axis), \
- GetTensorData<data_type>(input), GetTensorDims(input), \
- GetTensorData<output_type>(output), \
- GetTensorDims(output))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
+ GetTensorDims(input), GetTensorData<output_type>(output), \
+ GetTensorDims(output), GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
break;
default:
return kTfLiteError;
@@ -112,13 +121,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int32_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
break;
default:
return kTfLiteError;
@@ -132,13 +141,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt32: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int32_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
break;
default:
return kTfLiteError;
@@ -147,13 +156,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteInt64: {
switch (input->type) {
case kTfLiteFloat32:
- TF_LITE_ARG_MAX(float, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
break;
case kTfLiteUInt8:
- TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
break;
case kTfLiteInt32:
- TF_LITE_ARG_MAX(int32_t, int64_t, int64_t);
+ TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
break;
default:
return kTfLiteError;
@@ -163,16 +172,30 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteError;
}
}
-#undef TF_LITE_ARG_MAX
+#undef TF_LITE_ARG_MIN_MAX
return kTfLiteOk;
}
-} // namespace arg_max
+TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, false);
+}
+
+TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, true);
+}
+
+} // namespace arg_min_max
TfLiteRegistration* Register_ARG_MAX() {
- static TfLiteRegistration r = {nullptr, nullptr, arg_max::Prepare,
- arg_max::Eval};
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMaxEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_ARG_MIN() {
+ static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
+ arg_min_max::ArgMinEval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/arg_max_test.cc b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
index 31b15fe19a..90e5fdc532 100644
--- a/tensorflow/contrib/lite/kernels/arg_max_test.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max_test.cc
@@ -24,16 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
template <typename T>
-class ArgMaxOpModel : public SingleOpModel {
+class ArgBaseOpModel : public SingleOpModel {
public:
- ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
- TensorType output_type, TensorType index_output_type) {
+ ArgBaseOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type) {
input_ = AddInput(input_type);
axis_ = AddInput(TensorType_INT32);
output_ = AddOutput(output_type);
- SetBuiltinOp(BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
- CreateArgMaxOptions(builder_, index_output_type).Union());
- BuildInterpreter({input_shape, {1, 1, 1, 1}});
}
int input() { return input_; }
@@ -42,12 +39,42 @@ class ArgMaxOpModel : public SingleOpModel {
std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
- private:
+ protected:
int input_;
int axis_;
int output_;
};
+template <typename T>
+class ArgMaxOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMaxOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MAX, BuiltinOptions_ArgMaxOptions,
+ CreateArgMaxOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
+template <typename T>
+class ArgMinOpModel : public ArgBaseOpModel<T> {
+ public:
+ ArgMinOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type, TensorType index_output_type)
+ : ArgBaseOpModel<T>(input_shape, input_type, output_type,
+ index_output_type) {
+ ArgBaseOpModel<T>::SetBuiltinOp(
+ BuiltinOperator_ARG_MIN, BuiltinOptions_ArgMinOptions,
+ CreateArgMinOptions(ArgBaseOpModel<T>::builder_, index_output_type)
+ .Union());
+ ArgBaseOpModel<T>::BuildInterpreter({input_shape, {1, 1, 1, 1}});
+ }
+};
+
TEST(ArgMaxOpTest, GetMaxArgFloat) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
TensorType_INT32, TensorType_INT32);
@@ -96,6 +123,54 @@ TEST(ArgMaxOpTest, GetMaxArgOutput64) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
}
+TEST(ArgMinOpTest, GetMinArgFloat) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_FLOAT32,
+ TensorType_INT32, TensorType_INT32);
+ model.PopulateTensor<float>(model.input(), {0.1, 0.9, 0.7, 0.3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgInt) {
+ ArgMinOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgMulDimensions) {
+ ArgMinOpModel<int32_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int>(model.input(), {1, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
+TEST(ArgMinOpTest, GetMinArgOutput64) {
+ ArgMinOpModel<int64_t> model({1, 1, 2, 4}, TensorType_INT32, TensorType_INT64,
+ TensorType_INT64);
+ model.PopulateTensor<int>(model.input(), {10, 2, 7, 8, 1, 9, 7, 3});
+ // Currently only support the last dimension.
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 60770ca0aa..8dd48af57f 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <algorithm>
+#include <complex>
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
@@ -53,6 +54,20 @@ void copyCast(const FromT* in, ToT* out, int num_elements) {
[](FromT a) { return static_cast<ToT>(a); });
}
+template <typename ToT>
+void copyCast(const std::complex<float>* in, ToT* out, int num_elements) {
+ std::transform(in, in + num_elements, out, [](std::complex<float> a) {
+ return static_cast<ToT>(std::real(a));
+ });
+}
+
+template <>
+void copyCast(const std::complex<float>* in, std::complex<float>* out,
+ int num_elements) {
+ std::transform(in, in + num_elements, out,
+ [](std::complex<float> a) { return a; });
+}
+
template <typename FromT>
TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
int num_elements) {
@@ -72,6 +87,10 @@ TfLiteStatus copyToTensor(const FromT* in, TfLiteTensor* out,
case kTfLiteBool:
copyCast(in, out->data.b, num_elements);
break;
+ case kTfLiteComplex64:
+ copyCast(in, reinterpret_cast<std::complex<float>*>(out->data.c64),
+ num_elements);
+ break;
default:
// Unsupported type.
return kTfLiteError;
@@ -95,6 +114,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return copyToTensor(input->data.f, output, num_elements);
case kTfLiteBool:
return copyToTensor(input->data.b, output, num_elements);
+ case kTfLiteComplex64:
+ return copyToTensor(
+ reinterpret_cast<std::complex<float>*>(input->data.c64), output,
+ num_elements);
default:
// Unsupported type.
return kTfLiteError;
diff --git a/tensorflow/contrib/lite/kernels/cast_test.cc b/tensorflow/contrib/lite/kernels/cast_test.cc
index 53e2000737..954f998206 100644
--- a/tensorflow/contrib/lite/kernels/cast_test.cc
+++ b/tensorflow/contrib/lite/kernels/cast_test.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <complex>
+
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -73,6 +75,71 @@ TEST(CastOpModel, CastBoolToFloat) {
ElementsAreArray({1.f, 1.0f, 0.f, 1.0f, 0.0f, 1.0f}));
}
+TEST(CastOpModel, CastComplex64ToFloat) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
+}
+
+TEST(CastOpModel, CastFloatToComplex64) {
+ CastOpModel m({TensorType_FLOAT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToInt) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_INT32, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int>(m.output()),
+ ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(CastOpModel, CastIntToComplex64) {
+ CastOpModel m({TensorType_INT32, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<int>(m.input(), {1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 0.0f), std::complex<float>(2.0f, 0.0f),
+ std::complex<float>(3.0f, 0.0f), std::complex<float>(4.0f, 0.0f),
+ std::complex<float>(5.0f, 0.0f), std::complex<float>(6.0f, 0.0f)}));
+}
+
+TEST(CastOpModel, CastComplex64ToComplex64) {
+ CastOpModel m({TensorType_COMPLEX64, {2, 3}}, {TensorType_COMPLEX64, {2, 3}});
+ m.PopulateTensor<std::complex<float>>(
+ m.input(),
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f), std::complex<float>(6.0f, 16.0f)});
+ m.Invoke();
+ EXPECT_THAT(
+ m.ExtractVector<std::complex<float>>(m.output()),
+ ElementsAreArray(
+ {std::complex<float>(1.0f, 11.0f), std::complex<float>(2.0f, 12.0f),
+ std::complex<float>(3.0f, 13.0f), std::complex<float>(4.0f, 14.0f),
+ std::complex<float>(5.0f, 15.0f),
+ std::complex<float>(6.0f, 16.0f)}));
+}
+
} // namespace
} // namespace tflite
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 3b81062cd4..f678f48fa5 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -23,6 +23,7 @@ namespace tflite {
namespace ops {
namespace builtin {
namespace comparisons {
+namespace {
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
@@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<type>(input2), GetTensorDims(input2), \
GetTensorData<bool>(output), GetTensorDims(output));
+TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// TODO(renjieliu): Refactor the logic to avoid duplications.
+TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ bool requires_broadcast = !HaveSameShapes(input1, input2);
+ // TODO(renjieliu): Support quantized data.
+ switch (input1->type) {
+ case kTfLiteFloat32:
+ TF_LITE_COMPARISON(float, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt32:
+ TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast);
+ break;
+ case kTfLiteInt64:
+ TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast);
+ break;
+ default:
+ context->ReportError(context,
+ "Does not support type %d, requires float|int",
+ input1->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
@@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+} // namespace
} // namespace comparisons
+TfLiteRegistration* Register_EQUAL() {
+ static TfLiteRegistration r = {
+ nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_NOT_EQUAL() {
+ static TfLiteRegistration r = {nullptr, nullptr,
+ comparisons::ComparisonPrepare,
+ comparisons::NotEqualEval};
+ return &r;
+}
+
TfLiteRegistration* Register_GREATER() {
static TfLiteRegistration r = {nullptr, nullptr,
comparisons::ComparisonPrepare,
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 835d238d36..bb02e1c812 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -21,18 +21,17 @@ limitations under the License.
namespace tflite {
namespace {
-using ::testing::ElementsAreArray;
+using ::testing::ElementsAre;
-class GreaterOpModel : public SingleOpModel {
+class ComparisonOpModel : public SingleOpModel {
public:
- GreaterOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
+ ComparisonOpModel(std::initializer_list<int> input1_shape,
+ std::initializer_list<int> input2_shape,
+ TensorType input_type, BuiltinOperator op) {
input1_ = AddInput(input_type);
input2_ = AddInput(input_type);
output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions,
- CreateGreaterOptions(builder_).Union());
+ ConfigureBuiltinOp(op);
BuildInterpreter({input1_shape, input2_shape});
}
@@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel {
int input1_;
int input2_;
int output_;
+
+ void ConfigureBuiltinOp(BuiltinOperator op) {
+ switch (op) {
+ case BuiltinOperator_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_EqualOptions,
+ CreateEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_NOT_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_NotEqualOptions,
+ CreateNotEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterOptions,
+ CreateGreaterOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_GREATER_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions,
+ CreateGreaterEqualOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS: {
+ SetBuiltinOp(op, BuiltinOptions_LessOptions,
+ CreateLessOptions(builder_).Union());
+ break;
+ }
+ case BuiltinOperator_LESS_EQUAL: {
+ SetBuiltinOp(op, BuiltinOptions_LessEqualOptions,
+ CreateLessEqualOptions(builder_).Union());
+ break;
+ }
+ default: { FAIL() << "We shouldn't get here."; }
+ }
+ }
};
-TEST(ComparisonsTest, GreaterFloat) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+TEST(ComparisonsTest, EqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterInt) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcast) {
- GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
-TEST(ComparisonsTest, GreaterBroadcastTwoD) {
- GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+TEST(ComparisonsTest, EqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false,
+ false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class GreaterEqualOpModel : public SingleOpModel {
- public:
- GreaterEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_GREATER_EQUAL,
- BuiltinOptions_GreaterEqualOptions,
- CreateGreaterEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
+TEST(ComparisonsTest, NotEqualFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
- int input1() { return input1_; }
- int input2() { return input2_; }
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+TEST(ComparisonsTest, NotEqualInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
- private:
- int input1_;
- int input2_;
- int output_;
-};
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, NotEqualBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_NOT_EQUAL);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, true, true, true, true, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
+
+TEST(ComparisonsTest, GreaterFloat) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+ model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterInt) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcast) {
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+ model.PopulateTensor<int>(model.input2(), {7});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
+}
+
+TEST(ComparisonsTest, GreaterBroadcastTwoD) {
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER);
+ model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
+ model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
+}
TEST(ComparisonsTest, GreaterEqualFloat) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualInt) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcast) {
- GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) {
- GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_GREATER_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false,
- false, true, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(false, true, true, false, false, true, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessOpModel : public SingleOpModel {
- public:
- LessOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape, TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
- CreateLessOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
TEST(ComparisonsTest, LessFloat) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessInt) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcast) {
- LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessBroadcastTwoD) {
- LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, false, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, false, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
-class LessEqualOpModel : public SingleOpModel {
- public:
- LessEqualOpModel(std::initializer_list<int> input1_shape,
- std::initializer_list<int> input2_shape,
- TensorType input_type) {
- input1_ = AddInput(input_type);
- input2_ = AddInput(input_type);
- output_ = AddOutput(TensorType_BOOL);
- SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions,
- CreateLessEqualOptions(builder_).Union());
- BuildInterpreter({input1_shape, input2_shape});
- }
-
- int input1() { return input1_; }
- int input2() { return input2_; }
-
- std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
- std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
-
- private:
- int input1_;
- int input2_;
- int output_;
-};
-
TEST(ComparisonsTest, LessEqualFloat) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualInt) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {1, 2, 7, 5});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcast) {
- LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
model.PopulateTensor<int>(model.input2(), {7});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4));
}
TEST(ComparisonsTest, LessEqualBroadcastTwoD) {
- LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+ ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32,
+ BuiltinOperator_LESS_EQUAL);
model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8});
model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
model.Invoke();
- EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
- true, false, true, false}));
- EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAre(true, false, false, true, true, false, true, false));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4));
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index ee42e5cdc8..a4fe9e5550 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// optimized_ops.h, in order to avoid a DCHECK(!im2col_data).
data->need_im2col =
(params->stride_width != 1 || params->stride_height != 1 ||
- filter_width != 1 || filter_height != 1);
+ params->dilation_width_factor != 1 ||
+ params->dilation_height_factor != 1 || filter_width != 1 ||
+ filter_height != 1);
// If we're using the optimized multithreaded EigenTensor implementation of
// convolution, it expects the filter weights to be transposed compared to
// the normal TF Lite buffer format. Typical TF Lite weights are
@@ -177,9 +179,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
- bool hasBias = node->inputs->size == 3;
+ bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
- TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2);
+ TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
@@ -202,9 +204,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO(ahentz): At this point the optimized versions require 'bias'. We can
// either change that or document that convolution requires it.
- TF_LITE_ENSURE(context, hasBias);
+ TF_LITE_ENSURE(context, has_bias);
- if (hasBias) {
+ if (has_bias) {
bias = &context->tensors[node->inputs->data[2]];
if (data_type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
@@ -224,29 +226,30 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto computeOutSize = [padding](int imageSize, int filterSize, int stride,
- int dilationRate) -> int {
- int effectiveFilterSize = (filterSize - 1) * dilationRate + 1;
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - effectiveFilterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int outWidth = computeOutSize(width, filter_width, params->stride_width,
- params->dilation_width_factor);
- int outHeight = computeOutSize(height, filter_height, params->stride_height,
- params->dilation_height_factor);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
+ int out_height =
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
data->padding.height =
ComputePadding(params->stride_height, params->dilation_height_factor,
- height, filter_height, outHeight);
+ height, filter_height, out_height);
data->padding.width =
ComputePadding(params->stride_width, params->dilation_width_factor, width,
- filter_width, outWidth);
+ filter_width, out_width);
- TF_LITE_ENSURE(context, hasBias);
+ TF_LITE_ENSURE(context, has_bias);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -255,8 +258,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
TF_LITE_ENSURE(context, real_multiplier < 1.0);
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
CalculateActivationRangeUint8(params->activation, output,
&data->output_activation_min,
&data->output_activation_max);
@@ -264,8 +268,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
output_size->data[0] = batches;
- output_size->data[1] = outHeight;
- output_size->data[2] = outWidth;
+ output_size->data[1] = out_height;
+ output_size->data[2] = out_width;
output_size->data[3] = channels_out;
auto output_status = context->ResizeTensor(context, output, output_size);
@@ -305,18 +309,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* hwcn_weights =
&context->tensors[node->temporaries->data[data->hwcn_weights_index]];
hwcn_weights->type = data_type;
- hwcn_weights->allocation_type = kTfLiteDynamic;
- // Make sure we release any previous allocations before we reallocate.
- // TODO(petewarden): Persistent arenas would be a better fit for this, but
- // they aren't fully implemented yet.
- if (hwcn_weights->data.raw) {
- free(hwcn_weights->data.raw);
- hwcn_weights->data.raw = nullptr;
- }
+ hwcn_weights->allocation_type = kTfLiteArenaRwPersistent;
- // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and
- // ResizeTensor will actually allocate space for it. The would be more
- // efficient if we placed hwcn_weights_status in the persistent arena.
auto hwcn_weights_status =
context->ResizeTensor(context, hwcn_weights, hwcn_weights_size);
if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status;
@@ -378,8 +372,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
KernelType effective_kernel_type;
if (((kernel_type == kMultithreadOptimized) ||
(kernel_type == kCblasOptimized)) &&
@@ -424,6 +418,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
filter_data = GetTensorData<float>(filter);
}
multithreaded_ops::Conv(
+ *eigen_support::GetThreadPoolDevice(context),
GetTensorData<float>(input), GetTensorDims(input), filter_data,
GetTensorDims(filter), GetTensorData<float>(bias),
GetTensorDims(bias), params->stride_width, params->stride_height,
@@ -455,9 +450,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
- bool hasBias = node->inputs->size == 3;
+ bool has_bias = node->inputs->size == 3;
TfLiteTensor* bias =
- hasBias ? &context->tensors[node->inputs->data[2]] : nullptr;
+ has_bias ? &context->tensors[node->inputs->data[2]] : nullptr;
TfLiteTensor* im2col =
data->need_im2col
? &context->tensors[node->temporaries->data[data->im2col_index]]
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index a308de055f..16e5f1d065 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -173,8 +173,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
const Dims<4>&, const float*, const Dims<4>&, int, int,
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
new file mode 100644
index 0000000000..0c532cac5a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -0,0 +1,591 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <numeric>
+#include <vector>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace detection_postprocess {
+
+// Input tensors
+constexpr int kInputTensorBoxEncodings = 0;
+constexpr int kInputTensorClassPredictions = 1;
+constexpr int kInputTensorAnchors = 2;
+
+// Output tensors
+constexpr int kOutputTensorDetectionBoxes = 0;
+constexpr int kOutputTensorDetectionClasses = 1;
+constexpr int kOutputTensorDetectionScores = 2;
+constexpr int kOutputTensorNumDetections = 3;
+
+constexpr size_t kNumCoordBox = 4;
+constexpr size_t kBatchSize = 1;
+
+// Object Detection model produces axis-aligned boxes in two formats:
+// BoxCorner represents the upper right (xmin, ymin) and
+// lower left corner (xmax, ymax).
+// CenterSize represents the center (xcenter, ycenter), height and width.
+// BoxCornerEncoding and CenterSizeEncoding are related as follows:
+// ycenter = y / y_scale * anchor.h + anchor.y;
+// xcenter = x / x_scale * anchor.w + anchor.x;
+// half_h = 0.5*exp(h/ h_scale)) * anchor.h;
+// half_w = 0.5*exp(w / w_scale)) * anchor.w;
+// ymin = ycenter - half_h
+// ymax = ycenter + half_h
+// xmin = xcenter - half_w
+// xmax = xcenter + half_w
+struct BoxCornerEncoding {
+ float ymin;
+ float xmin;
+ float ymax;
+ float xmax;
+};
+
+struct CenterSizeEncoding {
+ float y;
+ float x;
+ float h;
+ float w;
+};
+// We make sure that the memory allocations are contiguous with static assert.
+static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
+ "Size of BoxCornerEncoding is 4 float values");
+static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
+ "Size of CenterSizeEncoding is 4 float values");
+
+struct OpData {
+ int max_detections;
+ int max_classes_per_detection;
+ float non_max_suppression_score_threshold;
+ float intersection_over_union_threshold;
+ int num_classes;
+ CenterSizeEncoding scale_values;
+ // Indices of Temporary tensors
+ int decoded_boxes_index;
+ int scores_index;
+ int active_candidate_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+ op_data->max_detections = m["max_detections"].AsInt32();
+ op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
+ op_data->non_max_suppression_score_threshold =
+ m["nms_score_threshold"].AsFloat();
+ op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
+ op_data->num_classes = m["num_classes"].AsInt32();
+ op_data->scale_values.y = m["y_scale"].AsFloat();
+ op_data->scale_values.x = m["x_scale"].AsFloat();
+ op_data->scale_values.h = m["h_scale"].AsFloat();
+ op_data->scale_values.w = m["w_scale"].AsFloat();
+ context->AddTensors(context, 1, &op_data->decoded_boxes_index);
+ context->AddTensors(context, 1, &op_data->scores_index);
+ context->AddTensors(context, 1, &op_data->active_candidate_index);
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+// TODO(chowdhery): Add to kernel_util.h
+TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
+ std::initializer_list<int> values) {
+ TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
+ int index = 0;
+ for (int v : values) {
+ size->data[index] = v;
+ ++index;
+ }
+ return context->ResizeTensor(context, tensor, size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ // Inputs: box_encodings, scores, anchors
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* input_class_predictions =
+ GetInput(context, node, kInputTensorClassPredictions);
+ const TfLiteTensor* input_anchors =
+ GetInput(context, node, kInputTensorAnchors);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
+ // number of detected boxes
+ const int num_detected_boxes =
+ op_data->max_detections * op_data->max_classes_per_detection;
+
+ // Outputs: detection_boxes, detection_scores, detection_classes,
+ // num_detections
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
+ // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
+ TfLiteTensor* detection_boxes =
+ GetOutput(context, node, kOutputTensorDetectionBoxes);
+ detection_boxes->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_boxes,
+ {kBatchSize, num_detected_boxes, kNumCoordBox});
+
+ // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
+ TfLiteTensor* detection_classes =
+ GetOutput(context, node, kOutputTensorDetectionClasses);
+ detection_classes->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});
+
+ // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
+ TfLiteTensor* detection_scores =
+ GetOutput(context, node, kOutputTensorDetectionScores);
+ detection_scores->type = kTfLiteFloat32;
+ SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});
+
+ // Output Tensor num_detections: size is set to 1
+ TfLiteTensor* num_detections =
+ GetOutput(context, node, kOutputTensorNumDetections);
+ num_detections->type = kTfLiteFloat32;
+ // TODO (chowdhery): Make it a scalar when available
+ SetTensorSizes(context, num_detections, {1});
+
+ // Temporary tensors
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(3);
+ node->temporaries->data[0] = op_data->decoded_boxes_index;
+ node->temporaries->data[1] = op_data->scores_index;
+ node->temporaries->data[2] = op_data->active_candidate_index;
+
+ // decoded_boxes
+ TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
+ decoded_boxes->type = kTfLiteFloat32;
+ decoded_boxes->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, decoded_boxes,
+ {input_box_encodings->dims->data[1], kNumCoordBox});
+
+ // scores
+ TfLiteTensor* scores = &context->tensors[op_data->scores_index];
+ scores->type = kTfLiteFloat32;
+ scores->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, scores,
+ {input_class_predictions->dims->data[1],
+ input_class_predictions->dims->data[2]});
+
+ // active_candidate
+ TfLiteTensor* active_candidate =
+ &context->tensors[op_data->active_candidate_index];
+ active_candidate->type = kTfLiteUInt8;
+ active_candidate->allocation_type = kTfLiteArenaRw;
+ SetTensorSizes(context, active_candidate,
+ {input_box_encodings->dims->data[1]});
+
+ return kTfLiteOk;
+}
+
+class Dequantizer {
+ public:
+ Dequantizer(int zero_point, float scale)
+ : zero_point_(zero_point), scale_(scale) {}
+ float operator()(uint8 x) {
+ return (static_cast<float>(x) - zero_point_) * scale_;
+ }
+
+ private:
+ int zero_point_;
+ float scale_;
+};
+
+void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
+ float quant_zero_point, float quant_scale,
+ CenterSizeEncoding* box_centersize) {
+ const uint8* boxes =
+ GetTensorData<uint8>(input_box_encodings) + kNumCoordBox * idx;
+ Dequantizer dequantize(quant_zero_point, quant_scale);
+ box_centersize->y = dequantize(boxes[0]);
+ box_centersize->x = dequantize(boxes[1]);
+ box_centersize->h = dequantize(boxes[2]);
+ box_centersize->w = dequantize(boxes[3]);
+}
+
+template <class T>
+T ReInterpretTensor(const TfLiteTensor* tensor) {
+ // TODO (chowdhery): check float
+ const float* tensor_base = tensor->data.f;
+ return reinterpret_cast<T>(tensor_base);
+}
+
+template <class T>
+T ReInterpretTensor(TfLiteTensor* tensor) {
+ // TODO (chowdhery): check float
+ float* tensor_base = tensor->data.f;
+ return reinterpret_cast<T>(tensor_base);
+}
+
+TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
+ OpData* op_data) {
+ // Parse input tensor boxencodings
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
+ const int num_boxes = input_box_encodings->dims->data[1];
+ TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox);
+ const TfLiteTensor* input_anchors =
+ GetInput(context, node, kInputTensorAnchors);
+
+ // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
+ CenterSizeEncoding box_centersize;
+ CenterSizeEncoding scale_values = op_data->scale_values;
+ CenterSizeEncoding anchor;
+ for (int idx = 0; idx < num_boxes; ++idx) {
+ switch (input_box_encodings->type) {
+ // Quantized
+ case kTfLiteUInt8:
+ DequantizeBoxEncodings(
+ input_box_encodings, idx,
+ static_cast<float>(input_box_encodings->params.zero_point),
+ static_cast<float>(input_box_encodings->params.scale),
+ &box_centersize);
+ DequantizeBoxEncodings(
+ input_anchors, idx,
+ static_cast<float>(input_anchors->params.zero_point),
+ static_cast<float>(input_anchors->params.scale), &anchor);
+ break;
+ // Float
+ case kTfLiteFloat32:
+ box_centersize = ReInterpretTensor<const CenterSizeEncoding*>(
+ input_box_encodings)[idx];
+ anchor =
+ ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
+ break;
+ default:
+ // Unsupported type.
+ return kTfLiteError;
+ }
+
+ float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
+ float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
+ float half_h =
+ 0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
+ anchor.h;
+ float half_w =
+ 0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
+ anchor.w;
+ TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+ auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
+ box.ymin = ycenter - half_h;
+ box.xmin = xcenter - half_w;
+ box.ymax = ycenter + half_h;
+ box.xmax = xcenter + half_w;
+ }
+ return kTfLiteOk;
+}
+
+void DecreasingPartialArgSort(const float* values, int num_values,
+ int num_to_sort, int* indices) {
+ std::iota(indices, indices + num_values, 0);
+ std::partial_sort(
+ indices, indices + num_to_sort, indices + num_values,
+ [&values](const int i, const int j) { return values[i] > values[j]; });
+}
+
+void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
+ const float threshold,
+ std::vector<float>* keep_values,
+ std::vector<int>* keep_indices) {
+ for (int i = 0; i < values.size(); i++) {
+ if (values[i] >= threshold) {
+ keep_values->emplace_back(values[i]);
+ keep_indices->emplace_back(i);
+ }
+ }
+}
+
+bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
+ for (int i = 0; i < num_boxes; ++i) {
+ // ymax>=ymin, xmax>=xmin
+ auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
+ if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+ return false;
+ }
+ }
+ return true;
+}
+
+float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
+ const int i, const int j) {
+ auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
+ auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
+ const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
+ const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
+ if (area_i <= 0 || area_j <= 0) return 0.0;
+ const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
+ const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
+ const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
+ const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
+ const float intersection_area =
+ std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
+ std::max<float>(intersection_xmax - intersection_xmin, 0.0);
+ return intersection_area / (area_i + area_j - intersection_area);
+}
+
+// NonMaxSuppressionSingleClass() is O(n^2) pairwise comparison between boxes
+// It assumes all boxes are good in beginning and sorts based on the scores.
+// If lower-scoring box has too much overlap with a higher-scoring box,
+// we get rid of the lower-scoring box.
+TfLiteStatus NonMaxSuppressionSingleClassHelper(
+ TfLiteContext* context, TfLiteNode* node, OpData* op_data,
+ const std::vector<float>& scores, std::vector<int>* selected) {
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int max_detections = op_data->max_detections;
+ const float non_max_suppression_score_threshold =
+ op_data->non_max_suppression_score_threshold;
+ const float intersection_over_union_threshold =
+ op_data->intersection_over_union_threshold;
+ // Maximum detections should be positive.
+ TF_LITE_ENSURE(context, (max_detections >= 0));
+ // intersection_over_union_threshold should be positive
+ // and should be less than 1.
+ TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
+ (intersection_over_union_threshold <= 1.0f));
+ // Validate boxes
+ TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
+
+ // threshold scores
+ std::vector<int> keep_indices;
+ // TODO (chowdhery): Remove the dynamic allocation and replace it
+ // with temporaries, esp for std::vector<float>
+ std::vector<float> keep_scores;
+ SelectDetectionsAboveScoreThreshold(
+ scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);
+
+ int num_scores_kept = keep_scores.size();
+ std::vector<int> sorted_indices;
+ sorted_indices.resize(num_scores_kept);
+ DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept,
+ sorted_indices.data());
+
+ const int num_boxes_kept = num_scores_kept;
+ const int output_size = std::min(num_boxes_kept, max_detections);
+ selected->clear();
+ TfLiteTensor* active_candidate =
+ &context->tensors[op_data->active_candidate_index];
+ TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes);
+ int num_active_candidate = num_boxes_kept;
+ uint8_t* active_box_candidate = (active_candidate->data.uint8);
+ for (int row = 0; row < num_boxes_kept; row++) {
+ active_box_candidate[row] = 1;
+ }
+
+ for (int i = 0; i < num_boxes_kept; ++i) {
+ if (num_active_candidate == 0 || selected->size() >= output_size) break;
+ if (active_box_candidate[i] == 1) {
+ selected->push_back(keep_indices[sorted_indices[i]]);
+ active_box_candidate[i] = 0;
+ num_active_candidate--;
+ } else {
+ continue;
+ }
+ for (int j = i + 1; j < num_boxes_kept; ++j) {
+ if (active_box_candidate[j] == 1) {
+ float intersection_over_union = ComputeIntersectionOverUnion(
+ decoded_boxes, keep_indices[sorted_indices[i]],
+ keep_indices[sorted_indices[j]]);
+
+ if (intersection_over_union > intersection_over_union_threshold) {
+ active_box_candidate[j] = 0;
+ num_active_candidate--;
+ }
+ }
+ }
+ }
+ return kTfLiteOk;
+}
+
+// This function implements a fast version of Non Maximal Suppression for
+// multiple classes where
+// 1) we keep the top-k scores for each anchor and
+// 2) during NMS, each anchor only uses the highest class score for sorting.
+// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
+// instead of O(KN^2) where N is the number of anchors and K the number of
+// classes.
+TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
+ TfLiteNode* node,
+ OpData* op_data,
+ const float* scores) {
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* decoded_boxes =
+ &context->tensors[op_data->decoded_boxes_index];
+
+ TfLiteTensor* detection_boxes =
+ GetOutput(context, node, kOutputTensorDetectionBoxes);
+ TfLiteTensor* detection_classes =
+ GetOutput(context, node, kOutputTensorDetectionClasses);
+ TfLiteTensor* detection_scores =
+ GetOutput(context, node, kOutputTensorDetectionScores);
+ TfLiteTensor* num_detections =
+ GetOutput(context, node, kOutputTensorNumDetections);
+
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int num_classes = op_data->num_classes;
+ const int max_categories_per_anchor = op_data->max_classes_per_detection;
+ // The row index offset is 1 if background class is included and 0 otherwise.
+ const int label_offset = 1;
+ TF_LITE_ENSURE(context, (label_offset != -1));
+ TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
+ const int num_classes_with_background = num_classes + label_offset;
+ const int num_categories_per_anchor =
+ std::min(max_categories_per_anchor, num_classes);
+ std::vector<float> max_scores;
+ max_scores.resize(num_boxes);
+ std::vector<int> sorted_class_indices;
+ sorted_class_indices.resize(num_boxes * num_classes);
+ for (int row = 0; row < num_boxes; row++) {
+ const float* box_scores =
+ scores + row * num_classes_with_background + label_offset;
+ int* class_indices = sorted_class_indices.data() + row * num_classes;
+ DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
+ class_indices);
+ max_scores[row] = box_scores[class_indices[0]];
+ }
+ // Perform non-maximal suppression on max scores
+ std::vector<int> selected;
+ NonMaxSuppressionSingleClassHelper(context, node, op_data, max_scores,
+ &selected);
+ // Allocate output tensors
+ int output_box_index = 0;
+ for (const auto& selected_index : selected) {
+ const float* box_scores =
+ scores + selected_index * num_classes_with_background + label_offset;
+ const int* class_indices =
+ sorted_class_indices.data() + selected_index * num_classes;
+
+ for (int col = 0; col < num_categories_per_anchor; ++col) {
+ int box_offset = num_categories_per_anchor * output_box_index + col;
+ // detection_boxes
+ ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
+ ReInterpretTensor<const BoxCornerEncoding*>(
+ decoded_boxes)[selected_index];
+ // detection_classes
+ detection_classes->data.f[box_offset] = class_indices[col];
+ // detection_scores
+ detection_scores->data.f[box_offset] = box_scores[class_indices[col]];
+ output_box_index++;
+ }
+ }
+ num_detections->data.f[0] = output_box_index;
+ return kTfLiteOk;
+}
+
+void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
+ const int num_boxes,
+ const int num_classes_with_background,
+ const TfLiteTensor* scores) {
+ float quant_zero_point =
+ static_cast<float>(input_class_predictions->params.zero_point);
+ float quant_scale = static_cast<float>(input_class_predictions->params.scale);
+ Dequantizer dequantize(quant_zero_point, quant_scale);
+ const uint8* scores_quant = GetTensorData<uint8>(input_class_predictions);
+ for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
+ scores->data.f[idx] = dequantize(scores_quant[idx]);
+ }
+}
+
+TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
+ TfLiteNode* node, OpData* op_data) {
+ // Get the input tensors
+ const TfLiteTensor* input_box_encodings =
+ GetInput(context, node, kInputTensorBoxEncodings);
+ const TfLiteTensor* input_class_predictions =
+ GetInput(context, node, kInputTensorClassPredictions);
+ const int num_boxes = input_box_encodings->dims->data[1];
+ const int num_classes = op_data->num_classes;
+ TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
+ kBatchSize);
+ TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
+ const int num_classes_with_background =
+ input_class_predictions->dims->data[2];
+
+ TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1));
+
+ const TfLiteTensor* scores;
+ switch (input_class_predictions->type) {
+ case kTfLiteUInt8: {
+ TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
+ DequantizeClassPredictions(input_class_predictions, num_boxes,
+ num_classes_with_background, temporary_scores);
+ scores = temporary_scores;
+ } break;
+ case kTfLiteFloat32:
+ scores = input_class_predictions;
+ break;
+ default:
+ // Unsupported type.
+ return kTfLiteError;
+ }
+ NonMaxSuppressionMultiClassFastHelper(context, node, op_data,
+ GetTensorData<float>(scores));
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ // TODO(chowdhery): Generalize for any batch size
+ TF_LITE_ENSURE(context, (kBatchSize == 1));
+ auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+ // These two functions correspond to two blocks in the Object Detection model.
+ // In future, we would like to break the custom op in two blocks, which is
+ // currently not feasible because we would like to input quantized inputs
+ // and do all calculations in float. Mixed quantized/float calculations are
+ // currently not supported in TFLite.
+
+ // This fills in temporary decoded_boxes
+ // by transforming input_box_encodings and input_anchors from
+ // CenterSizeEncodings to BoxCornerEncoding
+ DecodeCenterSizeBoxes(context, node, op_data);
+ // This fills in the output tensors
+ // by choosing effective set of decoded boxes
+ // based on Non Maximal Suppression, i.e. selecting
+ // highest scoring non-overlapping boxes.
+ NonMaxSuppressionMultiClass(context, node, op_data);
+
+ return kTfLiteOk;
+}
+} // namespace detection_postprocess
+
+TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
+ static TfLiteRegistration r = {detection_postprocess::Init,
+ detection_postprocess::Free,
+ detection_postprocess::Prepare,
+ detection_postprocess::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
new file mode 100644
index 0000000000..4e0f8484a3
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+class BaseDetectionPostprocessOpModel : public SingleOpModel {
+ public:
+ BaseDetectionPostprocessOpModel(const TensorData& input1,
+ const TensorData& input2,
+ const TensorData& input3,
+ const TensorData& output1,
+ const TensorData& output2,
+ const TensorData& output3,
+ const TensorData& output4) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ input3_ = AddInput(input3);
+ output1_ = AddOutput(output1);
+ output2_ = AddOutput(output2);
+ output3_ = AddOutput(output3);
+ output4_ = AddOutput(output4);
+
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("max_detections", 3);
+ fbb.Int("max_classes_per_detection", 1);
+ fbb.Float("nms_score_threshold", 0.0);
+ fbb.Float("nms_iou_threshold", 0.5);
+ fbb.Int("num_classes", 2);
+ fbb.Float("y_scale", 10.0);
+ fbb.Float("x_scale", 10.0);
+ fbb.Float("h_scale", 5.0);
+ fbb.Float("w_scale", 5.0);
+ });
+ fbb.Finish();
+ SetCustomOp("TFLite_Detection_PostProcess", fbb.GetBuffer(),
+ Register_DETECTION_POSTPROCESS);
+ BuildInterpreter({GetShape(input1_), GetShape(input2_), GetShape(input3_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+ int input3() { return input3_; }
+
+ template <class T>
+ void SetInput1(std::initializer_list<T> data) {
+ PopulateTensor<T>(input1_, data);
+ }
+
+ template <class T>
+ void SetInput2(std::initializer_list<T> data) {
+ PopulateTensor<T>(input2_, data);
+ }
+
+ template <class T>
+ void SetInput3(std::initializer_list<T> data) {
+ PopulateTensor<T>(input3_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput1() {
+ return ExtractVector<T>(output1_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput2() {
+ return ExtractVector<T>(output2_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput3() {
+ return ExtractVector<T>(output3_);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput4() {
+ return ExtractVector<T>(output4_);
+ }
+
+ std::vector<int> GetOutputShape1() { return GetTensorShape(output1_); }
+ std::vector<int> GetOutputShape2() { return GetTensorShape(output2_); }
+ std::vector<int> GetOutputShape3() { return GetTensorShape(output3_); }
+ std::vector<int> GetOutputShape4() { return GetTensorShape(output4_); }
+
+ protected:
+ int input1_;
+ int input2_;
+ int input3_;
+ int output1_;
+ int output2_;
+ int output3_;
+ int output4_;
+};
+
+TEST(DetectionPostprocessOpTest, FloatTest) {
+ BaseDetectionPostprocessOpModel m(
+ {TensorType_FLOAT32, {1, 6, 4}}, {TensorType_FLOAT32, {1, 6, 3}},
+ {TensorType_FLOAT32, {6, 4}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}});
+
+ // six boxes in center-size encoding
+ m.SetInput1<float>({0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,
+ 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+ 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0});
+ // class scores - two classes with background
+ m.SetInput2<float>({0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0.,
+ .5, .4, 0., .3, .2});
+ // six anchors in center-size encoding
+ m.SetInput3<float>({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
+ 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0,
+ 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0});
+ // Same boxes in box-corner encoding:
+ // { 0.0, 0.0, 1.0, 1.0,
+ // 0.0, 0.1, 1.0, 1.1,
+ // 0.0, -0.1, 1.0, 0.9,
+ // 0.0, 10.0, 1.0, 11.0,
+ // 0.0, 10.1, 1.0, 11.1,
+ // 0.0, 100.0, 1.0, 101.0}
+ m.Invoke();
+ // detection_boxes
+ // in center-size
+ std::vector<int> output_shape1 = m.GetOutputShape1();
+ EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+ EXPECT_THAT(
+ m.GetOutput1<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0},
+ 1e-1)));
+ // detection_classes
+ std::vector<int> output_shape2 = m.GetOutputShape2();
+ EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput2<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
+ // detection_scores
+ std::vector<int> output_shape3 = m.GetOutputShape3();
+ EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput3<float>(),
+ ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
+ // num_detections
+ std::vector<int> output_shape4 = m.GetOutputShape4();
+ EXPECT_THAT(output_shape4, ElementsAre(1));
+ EXPECT_THAT(m.GetOutput4<float>(),
+ ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
+}
+
+TEST(DetectionPostprocessOpTest, QuantizedTest) {
+ BaseDetectionPostprocessOpModel m(
+ {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},
+ {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0},
+ {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+ {TensorType_FLOAT32, {}});
+ // six boxes in center-size encoding
+ std::vector<std::initializer_list<float>> inputs1 = {
+ {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0,
+ 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}};
+ m.QuantizeAndPopulate<uint8_t>(m.input1(), inputs1[0]);
+ // class scores - two classes with background
+ std::vector<std::initializer_list<float>> inputs2 = {
+ {0., .9, .8, 0., .75, .72, 0., .6, .5, 0., .93, .95, 0., .5, .4, 0., .3,
+ .2}};
+ m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[0]);
+ // six anchors in center-size encoding
+ std::vector<std::initializer_list<float>> inputs3 = {
+ {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
+ 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}};
+ m.QuantizeAndPopulate<uint8_t>(m.input3(), inputs3[0]);
+ m.Invoke();
+ // detection_boxes
+ // in center-size
+ std::vector<int> output_shape1 = m.GetOutputShape1();
+ EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+ EXPECT_THAT(
+ m.GetOutput1<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0},
+ 3e-1)));
+ // detection_classes
+ std::vector<int> output_shape2 = m.GetOutputShape2();
+ EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput2<float>(),
+ ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
+ // detection_scores
+ std::vector<int> output_shape3 = m.GetOutputShape3();
+ EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+ EXPECT_THAT(m.GetOutput3<float>(),
+ ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
+ // num_detections
+ std::vector<int> output_shape4 = m.GetOutputShape4();
+ EXPECT_THAT(output_shape4, ElementsAre(1));
+ EXPECT_THAT(m.GetOutput4<float>(),
+ ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
+}
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d264821e30..bc5c3783fd 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
#define TF_LITE_DIV(type, opname) \
type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
GetTensorData<float>(input2), GetTensorDims(input2), \
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index f1fdb42624..4f0d020793 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -14,31 +14,89 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
-#include "third_party/eigen3/Eigen/Core"
+#include <utility>
+
+#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace eigen_support {
+namespace {
+
+// We have a single global threadpool for all convolution operations. This means
+// that inferences started from different threads may block each other, but
+// since the underlying resource of CPU cores should be consumed by the
+// operations anyway, it shouldn't affect overall performance.
+class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
+ public:
+ // Takes ownership of 'pool'
+ explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
+ ~EigenThreadPoolWrapper() override {}
+
+ void Schedule(std::function<void()> fn) override {
+ pool_->Schedule(std::move(fn));
+ }
+ int NumThreads() const override { return pool_->NumThreads(); }
+ int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
+
+ private:
+ std::unique_ptr<Eigen::ThreadPool> pool_;
+};
-struct RefCountedEigenContext {
+struct RefCountedEigenContext : public TfLiteExternalContext {
+ std::unique_ptr<Eigen::ThreadPoolInterface> thread_pool_wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
int num_references = 0;
};
+RefCountedEigenContext* GetEigenContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedEigenContext*>(
+ context->GetExternalContext(context, kTfLiteEigenContext));
+}
+
+void InitDevice(TfLiteContext* context, RefCountedEigenContext* ptr) {
+ int num_threads = 4;
+ if (context->recommended_num_threads != -1) {
+ num_threads = context->recommended_num_threads;
+ }
+ ptr->device.reset(); // destroy before we invalidate the thread pool
+ ptr->thread_pool_wrapper.reset(
+ new EigenThreadPoolWrapper(new Eigen::ThreadPool(num_threads)));
+ ptr->device.reset(
+ new Eigen::ThreadPoolDevice(ptr->thread_pool_wrapper.get(), num_threads));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ Eigen::setNbThreads(context->recommended_num_threads);
+
+ auto* ptr = GetEigenContext(context);
+ if (ptr != nullptr) {
+ InitDevice(context, ptr);
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
if (context->recommended_num_threads != -1) {
Eigen::setNbThreads(context->recommended_num_threads);
}
ptr = new RefCountedEigenContext;
+ ptr->type = kTfLiteEigenContext;
+ ptr->Refresh = Refresh;
ptr->num_references = 0;
- context->eigen_context = ptr;
+ InitDevice(context, ptr);
+ context->SetExternalContext(context, kTfLiteEigenContext, ptr);
}
ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedEigenContext*>(context->eigen_context);
+ auto* ptr = GetEigenContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
@@ -46,14 +104,17 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
- context->eigen_context = nullptr;
+ context->SetExternalContext(context, kTfLiteEigenContext, nullptr);
}
}
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- Eigen::setNbThreads(num_threads);
- DecrementUsageCounter(context);
+const Eigen::ThreadPoolDevice* GetThreadPoolDevice(TfLiteContext* context) {
+ auto* ptr = GetEigenContext(context);
+ if (ptr == nullptr) {
+ TF_LITE_FATAL(
+ "Call to GetFromContext() not preceded by IncrementUsageCounter()");
+ }
+ return ptr->device.get();
}
} // namespace eigen_support
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index aa8c351fd8..ec77856b10 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -17,6 +17,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
+namespace EigenForTFLite {
+class ThreadPoolDevice;
+}
+
namespace tflite {
namespace eigen_support {
@@ -28,8 +32,8 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by Eigen.
-void SetNumThreads(TfLiteContext* context, int num_threads);
+const EigenForTFLite::ThreadPoolDevice* GetThreadPoolDevice(
+ TfLiteContext* context);
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 0bd5046950..59bab3c4ec 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -23,7 +23,7 @@ namespace ops {
namespace builtin {
namespace elementwise {
-TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
@@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArrayCopy(input->dims));
}
-TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node,
+ float float_func(float)) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
@@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
const float* in = GetTensorData<float>(input);
const float* in_end = in + elements;
float* out = output->data.f;
- for (; in < in_end; in++, out++) *out = std::sin(*in);
+ for (; in < in_end; in++, out++) *out = float_func(*in);
return kTfLiteOk;
}
default: {
@@ -55,14 +56,48 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::sin);
+}
+
+TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::log);
+}
+
+TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, std::sqrt);
+}
+
+TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
+ return Eval(context, node, [](float f) { return 1.f / std::sqrt(f); });
+}
+
} // namespace elementwise
TfLiteRegistration* Register_SIN() {
- static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare,
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
elementwise::SinEval};
return &r;
}
+TfLiteRegistration* Register_LOG() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
+ elementwise::LogEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_SQRT() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
+ elementwise::SqrtEval};
+ return &r;
+}
+
+TfLiteRegistration* Register_RSQRT() {
+ static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare,
+ elementwise::RsqrtEval};
+ return &r;
+}
+
} // namespace builtin
} // namespace ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index 412ffb04b9..ce4c602ee5 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -24,12 +24,13 @@ namespace {
using ::testing::ElementsAreArray;
-class SinOpModel : public SingleOpModel {
+class ElementWiseOpModel : public SingleOpModel {
public:
- SinOpModel(std::initializer_list<int> input_shape) {
+ ElementWiseOpModel(BuiltinOperator op,
+ std::initializer_list<int> input_shape) {
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0);
+ SetBuiltinOp(op, BuiltinOptions_NONE, 0);
BuildInterpreter({input_shape});
}
@@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel {
};
TEST(ElementWise, Sin) {
- SinOpModel m({1, 1, 4, 1});
+ ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1});
m.PopulateTensor<float>(m.input(), {0, 3.1415926, -3.1415926, 1});
m.Invoke();
EXPECT_THAT(m.ExtractVector<float>(m.output()),
@@ -50,6 +51,33 @@ TEST(ElementWise, Sin) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Log) {
+ ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 3.1415926, 1, 1});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, Sqrt) {
+ ElementWiseOpModel m(BuiltinOperator_SQRT, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {0, 1, 2, 4});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({0, 1, 1.41421, 2})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
+TEST(ElementWise, Rsqrt) {
+ ElementWiseOpModel m(BuiltinOperator_RSQRT, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 4, 9});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 0.7071, 0.5, 0.33333})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index 7539c0b30d..0ba170a4da 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -24,7 +24,8 @@ limitations under the License.
// Output:
// Output.dim[0] == Tensor[0].dim[0], num of lookups
// Output.dim[1] == Tensor[1].dim[1], num of items per row
-// Each item in output is a raw bytes copy of corresponding item in input.
+// Each item in output is a raw bytes copy of the corresponding item in input,
+// or a dequantized value in the case of a uint8 input.
// When indices are out of bound, the ops will not succeed.
//
@@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- TfLiteTensor* output = GetOutput(context, node, 0);
- const TfLiteTensor* lookup = GetInput(context, node, 0);
- const TfLiteTensor* value = GetInput(context, node, 1);
-
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* lookup, const TfLiteTensor* value,
+ TfLiteTensor* output) {
+ const int row_size = SizeOfDimension(value, 0);
+ const double scaling_factor = value->params.scale;
+
+ // col_size after we flatten tensor into 2D.
+ int col_size = 1;
+ for (int i = 1; i < NumDimensions(value); i++) {
+ col_size *= SizeOfDimension(value, i);
+ }
+
+ for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
+ int idx = lookup->data.i32[i];
+ if (idx >= row_size || idx < 0) {
+ context->ReportError(context, "Embedding Lookup: index out of bounds.");
+ return kTfLiteError;
+ } else {
+ // Dequantize embedding values.
+ // TODO(alanchiao): refactor scalar multiply into separate function
+ // for ease of adding a neon equivalent if ever necessary.
+ for (int j = 0; j < col_size; j++) {
+ output->data.f[j + i * col_size] =
+ value->data.uint8[j + idx * col_size] * scaling_factor;
+ }
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* lookup = GetInput(context, node, 0);
+ const TfLiteTensor* value = GetInput(context, node, 1);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ switch (value->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, lookup, value, output);
+ case kTfLiteUInt8:
+ return EvalHybrid(context, node, lookup, value, output);
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+}
+
} // namespace embedding_lookup
TfLiteRegistration* Register_EMBEDDING_LOOKUP() {
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
index 9b501878f1..04657fd863 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc
@@ -7,13 +7,14 @@ You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
+distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
+for the specific language governing permissions and limitations under the
+License.
==============================================================================*/
// Unit test for TFLite Lookup op.
+#include <initializer_list>
#include <iomanip>
#include <vector>
@@ -29,12 +30,13 @@ namespace {
using ::testing::ElementsAreArray;
-class EmbeddingLookupOpModel : public SingleOpModel {
+class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
- EmbeddingLookupOpModel(std::initializer_list<int> index_shape,
- std::initializer_list<int> weight_shape) {
+ BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape,
+ TensorType weight_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
- weight_ = AddInput(TensorType_FLOAT32);
+ weight_ = AddInput(weight_type);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
@@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int weight_;
+ int output_;
+};
+
+class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
+
void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
TfLiteTensor* tensor = interpreter_->tensor(weight_);
int rows = tensor->dims->data[0];
@@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel {
}
}
}
+};
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
+ public:
+ HybridEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
+ std::initializer_list<int> weight_shape)
+ : BaseEmbeddingLookupOpModel(index_shape, weight_shape,
+ TensorType_UINT8) {}
- private:
- int input_;
- int weight_;
- int output_;
+ void SetWeight(std::initializer_list<float> data) {
+ SymmetricQuantizeAndPopulate(weight_, data);
+ }
};
// TODO(ahentz): write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
EmbeddingLookupOpModel m({3}, {3, 2, 4});
- m.PopulateTensor<int>(0, {1, 0, 2});
+ m.SetInput({1, 0, 2});
m.Set3DWeightMatrix(
[](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
@@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) {
})));
}
+TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 8});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 4});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
+TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) {
+ HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2});
+ m.SetInput({1, 0, 2});
+ m.SetWeight({
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
+ 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
+ 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
+ },
+ 7.41e-03)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
new file mode 100644
index 0000000000..ed33012864
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -0,0 +1,113 @@
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace expand_dims {
+constexpr int kInput = 0;
+constexpr int kAxis = 1;
+constexpr int kOutput = 0;
+
+namespace {
+TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input,
+ int axis, TfLiteTensor* output) {
+ const TfLiteIntArray& input_dims = *input.dims;
+ if (axis < 0) {
+ axis = input_dims.size + 1 + axis;
+ }
+ TF_LITE_ENSURE(context, axis <= input_dims.size);
+
+ TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1);
+ for (int i = 0; i < output_dims->size; ++i) {
+ if (i < axis) {
+ output_dims->data[i] = input_dims.data[i];
+ } else if (i == axis) {
+ output_dims->data[i] = 1;
+ } else {
+ output_dims->data[i] = input_dims.data[i - 1];
+ }
+ }
+
+ return context->ResizeTensor(context, output, output_dims);
+}
+
+TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
+ const TfLiteTensor& axis, int* axis_value) {
+ TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1);
+ switch (axis.type) {
+ case kTfLiteInt32:
+ *axis_value = *GetTensorData<int32_t>(&axis);
+ return kTfLiteOk;
+ case kTfLiteInt64:
+ *axis_value = *GetTensorData<int64_t>(&axis);
+ return kTfLiteOk;
+ default:
+ return kTfLiteError;
+ }
+}
+
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ if (IsConstantTensor(axis)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ return ExpandTensorDim(context, *input, axis_value, output);
+ }
+ SetTensorToDynamic(output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ // Just copy input to output.
+ const TfLiteTensor* input = GetInput(context, node, kInput);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const TfLiteTensor* axis = GetInput(context, node, kAxis);
+ if (IsDynamicTensor(output)) {
+ int axis_value;
+ TF_LITE_ENSURE_OK(context,
+ GetAxisValueFromTensor(context, *axis, &axis_value));
+ TF_LITE_ENSURE_OK(context,
+ ExpandTensorDim(context, *input, axis_value, output));
+ }
+ memcpy(output->data.raw, input->data.raw, input->bytes);
+ return kTfLiteOk;
+}
+
+} // namespace expand_dims
+TfLiteRegistration* Register_EXPAND_DIMS() {
+ static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare,
+ expand_dims::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
new file mode 100644
index 0000000000..50dc860e5a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -0,0 +1,83 @@
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ExpandDimsOpModel : public SingleOpModel {
+ public:
+ ExpandDimsOpModel(std::initializer_list<int> input_shape,
+ TensorType input_type) {
+ input_ = AddInput(input_type);
+ axis_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(input_type);
+ SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions,
+ 0);
+ BuildInterpreter({input_shape, {1}});
+ }
+ void SetInputFloat(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ void SetAxis(int axis) { PopulateTensor<int32_t>(axis_, {axis}); }
+ std::vector<float> GetValuesFloat() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int axis_;
+ int output_;
+};
+
+TEST(ExpandDimsOpTest, DifferentAxis) {
+ ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32);
+ std::initializer_list<float> values = {-1.f, 1.f, -2.f, 2.f};
+ m.SetInputFloat(values);
+ m.SetAxis(0);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2}));
+
+ m.SetAxis(1);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2}));
+
+ m.SetAxis(2);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1}));
+
+ m.SetAxis(-1);
+ m.Invoke();
+ EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1}));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
new file mode 100644
index 0000000000..f8927a0799
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace fake_quant {
+
+// This file has reference implementation of FakeQuant.
+enum KernelType {
+ kReference,
+};
+
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ const TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpContext op_context(context, node);
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);
+ op_context.output->type = op_context.input->type;
+ return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+
+ const auto* params =
+ reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
+
+ reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
+ GetTensorDims(op_context.input), params->min,
+ params->max, params->num_bits,
+ GetTensorData<float>(op_context.output),
+ GetTensorDims(op_context.output));
+
+ return kTfLiteOk;
+}
+
+} // namespace fake_quant
+
+TfLiteRegistration* Register_FAKE_QUANT_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare,
+ fake_quant::Eval<fake_quant::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/fake_quant_test.cc b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
new file mode 100644
index 0000000000..11a02f7ed7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/fake_quant_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FakeQuantOpModel : public SingleOpModel {
+ public:
+ FakeQuantOpModel(const TensorData& input, const TensorType& output, float min,
+ float max, int num_bits) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FAKE_QUANT, BuiltinOptions_FakeQuantOptions,
+ CreateFakeQuantOptions(builder_, min, max, num_bits).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ template <class T>
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor(input_, data);
+ }
+
+ template <class T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(FakeQuantOpTest, FloatPositiveRange8Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(
+ m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({0, 1, 0.25098, 0.498039, 0.443137, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange8Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 8);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.896471, 0.247059, 0.501176, 0.444706, 0})));
+}
+
+TEST(FakeQuantOpTest, FloatPositiveRange16Test) {
+ std::initializer_list<float> data = {0.0, 1.0, 0.25,
+ 0.50, 0.4444444, 0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, 0.0f,
+ 1.0f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, 1, 0.250004, 0.500008, 0.44445, 1.5259e-05})));
+}
+
+TEST(FakeQuantOpTest, FloatNegativeRange16Test) {
+ std::initializer_list<float> data = {0.0, -0.9, 0.25,
+ 0.50, 0.4444444, -0.00001};
+ FakeQuantOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32, -0.9f,
+ 0.9f, 16);
+ m.SetInput<float>(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear(
+ {0, -0.900014, 0.249998, 0.499995, 0.444431, 0})));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 989920622d..3b203dd480 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -63,6 +63,7 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
+constexpr int kShuffledInputWorkspaceTensor = 1;
constexpr int kScratchBufferTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -87,7 +88,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ // Shuffled formats need a workspace to store the shuffled input activations.
+ const int expected_outputs_count =
+ params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
+ : 2;
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
@@ -105,7 +110,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
const int batch_size = input_size / filter->dims->data[1];
const int num_units = filter->dims->data[0];
- TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]);
+ TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]);
if (bias) {
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
@@ -118,11 +123,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
TF_LITE_ENSURE(context, real_multiplier < 1.0);
- QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier,
- &data->output_shift);
- CalculateActivationRangeUint8(params->activation, output,
- &data->output_activation_min,
- &data->output_activation_max);
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
}
// If we have to perform on-the-fly quantization (with quantized weights and
@@ -277,30 +283,49 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type) \
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
type::FullyConnected( \
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
data->output_multiplier, data->output_shift, \
data->output_activation_min, data->output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+ GetTensorData<output_data_type>(output), GetTensorDims(output), \
+ gemm_context)
if (kernel_type == kReference) {
- TF_LITE_FULLY_CONNECTED(reference_ops);
- } else if (kernel_type == kPie) {
- if (input->type == kTfLiteFloat32) {
- // Pie currently only supports quantized models and float inputs/outputs.
- TfLiteTensor* input_quantized =
- &context->tensors[node->temporaries->data[0]];
- return EvalPieQuantized(context, node, params, data, input, filter, bias,
- input_quantized, output);
- } else {
- // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
- // we just defer to the MINI ones.
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
}
+ } else if (kernel_type == kPie && input->type == kTfLiteFloat32) {
+ // Pie currently only supports quantized models and float inputs/outputs.
+ TfLiteTensor* input_quantized =
+ &context->tensors[node->temporaries->data[0]];
+ return EvalPieQuantized(context, node, params, data, input, filter, bias,
+ input_quantized, output);
} else {
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
+ }
}
#undef TF_LITE_FULLY_CONNECTED
@@ -308,13 +333,51 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params,
+ OpData* data, const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,
+ TfLiteTensor* output,
+ TfLiteTensor* shuffled_input_workspace) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+
+ // TODO(b/110697972) decide more consistently if / how / where we want
+ // to perform this kind of runtime data type checks.
+ if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 ||
+ bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 ||
+ shuffled_input_workspace->type != kTfLiteUInt8) {
+ context->ReportError(context, "Unexpected data type");
+ return kTfLiteError;
+ }
+
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ type::ShuffledFullyConnected( \
+ GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), \
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), \
+ data->output_multiplier, data->output_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<int16_t>(output), GetTensorDims(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context)
+ if (kernel_type == kReference) {
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
+ } else {
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_SHUFFLED_FULLY_CONNECTED
+
+ return kTfLiteOk;
+}
+
+template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
const TfLiteTensor* bias, TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
#define TF_LITE_FULLY_CONNECTED(type) \
type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
GetTensorData<float>(filter), GetTensorDims(filter), \
@@ -351,8 +414,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloat<kernel_type>(context, node, params, data, input, filter,
bias, output);
case kTfLiteUInt8:
- return EvalQuantized<kernel_type>(context, node, params, data, input,
- filter, bias, output);
+ if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
+ TfLiteTensor* shuffled_input_workspace =
+ GetOutput(context, node, kShuffledInputWorkspaceTensor);
+ return EvalShuffledQuantized<kernel_type>(context, node, params, data,
+ input, filter, bias, output,
+ shuffled_input_workspace);
+ } else if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatDefault) {
+ return EvalQuantized<kernel_type>(context, node, params, data, input,
+ filter, bias, output);
+ } else {
+ context->ReportError(context,
+ "Unhandled fully-connected weights format");
+ return kTfLiteError;
+ }
default:
context->ReportError(context, "Type %d not currently supported.",
filter->type);
diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
index 05dd028b48..ec94905697 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
// Unit test for TFLite FULLY_CONNECTED op.
#include <iomanip>
+#include <random>
#include <vector>
#include <gmock/gmock.h>
@@ -133,9 +134,12 @@ static float fully_connected_golden_output[] = {
class BaseFullyConnectedOpModel : public SingleOpModel {
public:
// TODO(ahentz): test different activation types too.
- BaseFullyConnectedOpModel(TfLiteRegistration* registration, int units,
- int batches, const TensorData& input,
- const TensorData& output = {TensorType_FLOAT32})
+ BaseFullyConnectedOpModel(
+ TfLiteRegistration* registration, int units, int batches,
+ const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
+ ActivationFunctionType activation_func = ActivationFunctionType_RELU,
+ FullyConnectedOptionsWeightsFormat weights_format =
+ FullyConnectedOptionsWeightsFormat_DEFAULT)
: batches_(batches), units_(units) {
int total_input_size = 1;
for (int i = 0; i < input.shape.size(); ++i) {
@@ -159,10 +163,13 @@ class BaseFullyConnectedOpModel : public SingleOpModel {
}
output_ = AddOutput(output);
+ if (weights_format != FullyConnectedOptionsWeightsFormat_DEFAULT) {
+ AddOutput({TensorType_UINT8, input.shape});
+ }
SetBuiltinOp(
BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
+ CreateFullyConnectedOptions(builder_, activation_func, weights_format)
.Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, registration);
@@ -188,13 +195,11 @@ class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel {
public:
using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
- void SetWeights(std::initializer_list<float> f) {
- PopulateTensor(weights_, f);
- }
+ void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
- void SetInput(std::initializer_list<float> data) {
+ void SetInput(const std::vector<float>& data) {
PopulateTensor(input_, data);
}
void SetInput(int offset, float* begin, float* end) {
@@ -208,20 +213,50 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel {
public:
using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel;
- void SetBias(std::initializer_list<float> data) {
+ void SetBias(const std::vector<float>& data) {
QuantizeAndPopulate<int32_t>(bias_, data);
}
- void SetWeights(std::initializer_list<float> data) {
+ void SetWeights(const std::vector<float>& data) {
QuantizeAndPopulate<uint8_t>(weights_, data);
}
- void SetInput(std::initializer_list<float> data) {
+ void ShuffleAndSetWeights(const std::vector<float>& data, int input_depth,
+ int output_depth) {
+ std::vector<float> shuffled_data(data.size());
+ CHECK_EQ(input_depth % 16, 0);
+ CHECK_EQ(output_depth % 4, 0);
+ float* shuffled_data_ptr = shuffled_data.data();
+ for (int block_o = 0; block_o < output_depth; block_o += 4) {
+ for (int block_i = 0; block_i < input_depth; block_i += 16) {
+ for (int o = 0; o < 4; o++) {
+ for (int i = 0; i < 16; i++) {
+ *shuffled_data_ptr++ =
+ data[(block_o + o) * input_depth + block_i + i];
+ }
+ }
+ }
+ }
+ TfLiteTensor* t = interpreter_->tensor(weights_);
+ auto quantized_data =
+ Quantize<uint8_t>(shuffled_data, t->params.scale, t->params.zero_point);
+ for (uint8_t& q : quantized_data) {
+ q ^= 0x80;
+ }
+ PopulateTensor(weights_, 0, quantized_data.data(),
+ quantized_data.data() + quantized_data.size());
+ }
+ void SetInput(const std::vector<float>& data) {
QuantizeAndPopulate<uint8_t>(input_, data);
}
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ template <typename T>
std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
+ return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
+ GetZeroPoint(output_));
}
};
@@ -256,12 +291,12 @@ class HybridFullyConnectedOpModel : public SingleOpModel {
ops::builtin::Register_FULLY_CONNECTED_PIE());
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
}
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
- void SetWeights(std::initializer_list<float> data) {
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+ void SetWeights(const std::vector<float>& data) {
SymmetricQuantizeAndPopulate(weights_, data);
}
- void SetInput(std::initializer_list<float> f) { PopulateTensor(input_, f); }
+ void SetInput(const std::vector<float>& f) { PopulateTensor(input_, f); }
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -340,6 +375,24 @@ TEST_P(FloatFullyConnectedOpTest, SimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
+TEST_P(FloatFullyConnectedOpTest, SimpleTest2) {
+ FloatFullyConnectedOpModel m(GetRegistration(), /*units=*/1, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 2}});
+ m.SetWeights({
+ 2, 4, // u = 0
+ });
+ m.SetBias({1});
+
+ m.SetInput({
+ 1, 2, // b = 0
+ 2, 1, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
+}
+
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
@@ -350,7 +403,7 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
- 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
});
m.SetBias({1, 2, 3});
@@ -361,11 +414,136 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
- 24, 25, 26, //
- 58, 59, 60, //
- })));
- EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(151, 152, 153, 185, 186, 187));
+}
+
+void SimpleTestQuantizedInt16OutputCase(
+ TfLiteRegistration* registration, int input_depth, int output_depth,
+ int batches, FullyConnectedOptionsWeightsFormat weights_format) {
+ const uint8_t kWeightsZeroPoint = 128;
+ const float kWeightsScale = 1.f / 128.f;
+ const uint8_t kInputZeroPoint = 128;
+ const float kInputScale = 1.f / 128.f;
+ const float kInputMin = (0 - kInputZeroPoint) * kInputScale;
+ const float kInputMax = (255 - kInputZeroPoint) * kInputScale;
+ // Output ranges in [-8..8] encoded as int16
+ const float kOutputScale = 8.f / 32768.f;
+ const float kOutputMin = -32768 * kOutputScale;
+ const float kOutputMax = 32767 * kOutputScale;
+
+ QuantizedFullyConnectedOpModel m(
+ registration, output_depth, batches,
+ /*input=*/
+ {TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
+ /*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
+ /*activation_func=*/ActivationFunctionType_NONE, weights_format);
+
+ std::mt19937 random_engine;
+ std::uniform_int_distribution<uint8_t> weights_dist;
+
+ std::vector<float> weights_data(input_depth * output_depth);
+ for (auto& w : weights_data) {
+ uint8_t q = weights_dist(random_engine);
+ w = (q - kWeightsZeroPoint) * kWeightsScale;
+ }
+
+ // Based on weights_format, enforce any shape requirement for that format/path
+ // and set the (possibly shuffled) weights.
+ switch (weights_format) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ m.SetWeights(weights_data);
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ // The shuffled path currently supports only a restrictive subset of
+ // shapes, described by the following assertions:
+ CHECK_EQ(input_depth % 16, 0);
+ CHECK_EQ(output_depth % 4, 0);
+ CHECK(batches == 1 || batches == 4);
+ m.ShuffleAndSetWeights(weights_data, input_depth, output_depth);
+ break;
+ default:
+ LOG(FATAL) << "Unhandled weights format";
+ }
+
+ std::uniform_int_distribution<uint8_t> input_dist;
+ std::vector<float> input_data(input_depth * batches);
+ for (auto& i : input_data) {
+ uint8_t q = input_dist(random_engine);
+ i = (q - kInputZeroPoint) * kInputScale;
+ }
+
+ std::vector<float> bias_data(output_depth);
+ // As the output ranges in [-8, 8], it's reasonable to have bias values
+ // in [-1, 1], this won't result in too much saturation.
+ std::uniform_real_distribution<float> bias_dist(-1.f, 1.f);
+ for (auto& b : bias_data) {
+ b = bias_dist(random_engine);
+ }
+
+ m.SetBias(bias_data);
+ m.SetInput(input_data);
+
+ m.Invoke();
+
+ std::vector<float> expected_output_data(output_depth * batches);
+ for (int b = 0; b < batches; b++) {
+ for (int o = 0; o < output_depth; o++) {
+ float accum = bias_data[o];
+ for (int i = 0; i < input_depth; i++) {
+ accum +=
+ input_data[b * input_depth + i] * weights_data[o * input_depth + i];
+ }
+ accum = std::min(accum, kOutputMax);
+ accum = std::max(accum, kOutputMin);
+ expected_output_data[b * output_depth + o] = accum;
+ }
+ }
+
+ EXPECT_THAT(m.GetDequantizedOutput<int16_t>(),
+ ElementsAreArray(ArrayFloatNear(expected_output_data, 3e-4f)));
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTestQuantizedInt16OutputDefaultWeights) {
+ for (int input_depth : {1, 3, 10, 100}) {
+ for (int output_depth : {1, 3, 10, 100}) {
+ for (int batch : {1, 3, 10, 100}) {
+ SimpleTestQuantizedInt16OutputCase(
+ GetRegistration(), input_depth, output_depth, batch,
+ FullyConnectedOptionsWeightsFormat_DEFAULT);
+ }
+ }
+ }
+}
+
+TEST_P(QuantizedFullyConnectedOpTest,
+ SimpleTestQuantizedInt16OutputShuffled4x16Int8Weights) {
+ // The shuffled weights block shape is 4x16. The shape of the weights matrix
+ // is: rows = output_depth, cols = input_depth. It must be a multiple of 4x16.
+ // This means that output_depth must be a multiple of 4, and input_deth must
+ // be a multiple of 16.
+ for (int input_depth_numblocks : {1, 3}) {
+ for (int output_depth_numblocks : {1, 3}) {
+ int input_depth = 16 * input_depth_numblocks;
+ int output_depth = 4 * output_depth_numblocks;
+ // The fast shuffled path is currently supporting only batch sizes of 1
+ // and 4. The idea is that the whole point of that path is to go as fast
+ // as possible for small batch size, which requires fully specializing
+ // it for each batch size, and for larger batch sizes the generic
+ // gemmlowp-based implementation is fast enough.
+ for (int batch : {1, 4}) {
+ SimpleTestQuantizedInt16OutputCase(
+ GetRegistration(), input_depth, output_depth, batch,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8);
+ }
+ }
+ }
}
TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
@@ -396,11 +574,11 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantized) {
/*max_abs_error=*/1.3f)));
}
-TEST(FloatFullyConnectedOpTest, SimpleTest4DInput) {
+TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) {
// Note that it is not required that the first dimension be the number of
// batches. All we care is that the input can be evenly distributed in
// batches. In this case, we need the input to have multiples of '2'.
- FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
+ FloatFullyConnectedOpModel m(GetRegistration(),
/*units=*/3, /*batches=*/2,
/*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}});
m.SetWeights({
@@ -444,11 +622,13 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantized) {
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({
- 24, 25, 26, //
- 58, 59, 60, //
- })));
- EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187));
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear({
+ 24, 25, 26, //
+ 58, 59, 60, //
+ })));
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAre(151, 152, 153, 185, 186, 187));
}
INSTANTIATE_TEST_CASE_P(
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 6a2341461f..2b2a9e6620 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -40,10 +40,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// Only INT32 positions are supported.
TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32);
- // Check that input and output types match.
- TF_LITE_ENSURE_EQ(context, input->type, output->type);
- // TODO(mgubin): only 0D or 1D positions are currently supported.
- TF_LITE_ENSURE(context, NumDimensions(positions) <= 1);
+ // Assign to output the input type.
+ output->type = input->type;
// TODO(mgubin): Only default axis == 0 is supported.
TF_LITE_ENSURE_EQ(context, params->axis, 0);
// Check conditions for different types.
@@ -102,6 +100,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_GATHER(int32_t, int32_t);
break;
case kTfLiteString: {
+ // TODO(mgubin): Currently support only for 1D output tensors.
DynamicBuffer buffer;
const int32* indexes = positions->data.i32;
const int num_strings = GetStringCount(input);
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index cdadbeda18..1d4292955c 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -96,6 +96,15 @@ TEST(GatherOpTest, Test0DIndexWith0DResult) {
EXPECT_TRUE(m.GetOutputShape().empty());
}
+TEST(GatherOpTest, Test2DIndexWith2DResult) {
+ GatherOpModel m({3}, TensorType_FLOAT32, {1, 2});
+ m.SetInputFloat({1.0, 2.0, 3.0});
+ m.SetPositions({1, 0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0, 1.0})));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+}
+
TEST(FloatGatherOpTest, Duplicate) {
GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2});
m.SetInputFloat({-2.0, 0.2, 0.7, 0.8});
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index 95f45ea768..ed334af2da 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -14,57 +14,70 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
namespace gemm_support {
+namespace {
-struct RefCountedGemmContext {
- gemmlowp::GemmContext* gemm_context_ = nullptr;
- int num_references_ = 0;
+struct RefCountedGemmContext : public TfLiteExternalContext {
+ std::unique_ptr<gemmlowp::GemmContext> gemm_context;
+ int num_references = 0;
};
+RefCountedGemmContext* GetGemmLowpContext(TfLiteContext* context) {
+ return reinterpret_cast<RefCountedGemmContext*>(
+ context->GetExternalContext(context, kTfLiteGemmLowpContext));
+}
+
+TfLiteStatus Refresh(TfLiteContext* context) {
+ auto* ptr = GetGemmLowpContext(context);
+ if (ptr != nullptr) {
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
void IncrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
ptr = new RefCountedGemmContext;
- ptr->gemm_context_ = new gemmlowp::GemmContext();
+ ptr->type = kTfLiteGemmLowpContext;
+ ptr->Refresh = Refresh;
+ ptr->gemm_context.reset(new gemmlowp::GemmContext());
if (context->recommended_num_threads != -1) {
- ptr->gemm_context_->set_max_num_threads(context->recommended_num_threads);
+ ptr->gemm_context->set_max_num_threads(context->recommended_num_threads);
}
- ptr->num_references_ = 0;
- context->gemm_context = ptr;
+ ptr->num_references = 0;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, ptr);
}
- ptr->num_references_++;
+ ptr->num_references++;
}
void DecrementUsageCounter(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to DecrementUsageCounter() not preceded by "
"IncrementUsageCounter()");
}
- if (--ptr->num_references_ == 0) {
- delete ptr->gemm_context_;
+ if (--ptr->num_references == 0) {
delete ptr;
- context->gemm_context = nullptr;
+ context->SetExternalContext(context, kTfLiteGemmLowpContext, nullptr);
}
}
gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
- auto* ptr = reinterpret_cast<RefCountedGemmContext*>(context->gemm_context);
+ auto* ptr = GetGemmLowpContext(context);
if (ptr == nullptr) {
TF_LITE_FATAL(
"Call to GetFromContext() not preceded by IncrementUsageCounter()");
}
- return ptr->gemm_context_;
-}
-
-void SetNumThreads(TfLiteContext* context, int num_threads) {
- IncrementUsageCounter(context);
- GetFromContext(context)->set_max_num_threads(num_threads);
- DecrementUsageCounter(context);
+ return ptr->gemm_context.get();
}
} // namespace gemm_support
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index f033501cb6..37af772c68 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -45,9 +45,6 @@ void IncrementUsageCounter(TfLiteContext* context);
// 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
-// Set the number of threads that can be used by gemmlowp.
-void SetNumThreads(TfLiteContext* context, int num_threads);
-
} // namespace gemm_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 0a5223b235..7962fcbc9d 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -177,6 +177,40 @@ cc_library(
)
cc_library(
+ name = "legacy_optimized_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "optimized/depthwiseconv_float.h",
+ "optimized/depthwiseconv_uint8.h",
+ "optimized/depthwiseconv_uint8_3x3_filter.h",
+ "optimized/legacy_optimized_ops.h",
+ "optimized/optimized_ops.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":quantization_util",
+ ":strided_slice_logic",
+ ":types",
+ ":legacy_reference_base",
+ ":round",
+ "//third_party/eigen3",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
name = "optimized",
hdrs = [
"optimized/cblas_conv.h",
@@ -274,6 +308,37 @@ cc_library(
)
cc_library(
+ name = "legacy_reference_base",
+ srcs = [],
+ hdrs = [
+ "common.h",
+ "reference/depthwiseconv_float.h",
+ "reference/depthwiseconv_uint8.h",
+ "reference/legacy_reference_ops.h",
+ "reference/reference_ops.h",
+ ],
+ deps = [
+ ":quantization_util",
+ ":round",
+ ":strided_slice_logic",
+ ":types",
+ "//third_party/eigen3",
+ "@gemmlowp",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ ] + select({
+ ":haswell": tflite_deps_intel,
+ ":ios_x86_64": tflite_deps_intel,
+ ":k8": tflite_deps_intel,
+ ":x86": tflite_deps_intel,
+ ":x86_64": tflite_deps_intel,
+ ":darwin": tflite_deps_intel,
+ ":darwin_x86_64": tflite_deps_intel,
+ ":freebsd": tflite_deps_intel,
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
name = "reference",
hdrs = ["tensor.h"],
deps = [
@@ -474,8 +539,9 @@ cc_test(
)
cc_test(
- name = "resize_bilinear_float_test",
- srcs = ["resize_bilinear_float_test.cc"],
+ name = "resize_bilinear_test",
+ srcs = ["resize_bilinear_test.cc"],
+ tags = ["tflite_not_portable"],
deps = [
":optimized_base",
":reference_base",
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 67e3810479..a0e382edb6 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -63,6 +63,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
// Quantize input from float to uint8 + quantization params (scaling
// factor).
float unused_min, unused_max;
+ // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function)
+ // whichever is faster.
for (int b = 0; b < batch_size; ++b) {
const int offset = b * input_size;
tensor_utils::SymmetricQuantizeFloats(
@@ -147,6 +149,7 @@ void LstmStep(
input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
input_gate_scratch, /*result_stride=*/1);
}
+
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
forget_gate_scratch, /*result_stride=*/1);
@@ -161,8 +164,7 @@ void LstmStep(
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch,
- /*result_stride=*/1);
+ n_batch, input_gate_scratch, /*result_stride=*/1);
}
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
@@ -253,5 +255,263 @@ void LstmStep(
output_state_ptr);
}
+// TODO(alanchiao): move this to tensor_utils.
+void VectorMultiply(const int8_t* vector, const int v_size, const float scale,
+ float* result) {
+ for (int i = 0; i < v_size; ++i) {
+ *result++ = scale * *vector++;
+ }
+}
+
+void LstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_output, float* input_gate_scratch, float* forget_gate_scratch,
+ float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ VectorMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ VectorMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ VectorMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale, recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index f3f42f0840..2a11b37a60 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -92,6 +92,89 @@ void LstmStep(
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_input_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_input_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_cell_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale,
+// cell_to_forget_weights_scale,
+// cell_to_output_weights_scale,
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_gate_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
+void LstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_output, float* input_gate_scratch, float* forget_gate_scratch,
+ float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch);
+
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index b7531ea2e2..d2f1103e14 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -32,19 +32,21 @@ namespace tflite {
namespace {
void RunLogSoftmaxFloatReference(const uint8* input_data,
- const Dims<4>& dims_common, int32 input_offset,
- const double input_scale, int stride,
- float beta, uint8* reference_output_data) {
- const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
- reference_dequant_data.data(), dims_common);
- optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common,
- reference_output_float_data.data(), dims_common);
+ reference_ops::Dequantize(
+ input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
+ reference_dequant_data.data(), ToRuntimeDims(shape_common));
+ optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data(), shape_common);
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
}
void CheckOutputData(const uint8* test_output, const uint8* reference_output,
- const Dims<4>& dims_common, const string& check_label,
- bool be_exacting) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
// While calculating some metrics in floating point, we work with quantized
// scaling.
std::vector<int> diff(buffer_size);
@@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Runs the LogSoftmax and compares against the float reference implementation
// and the quantized reference implementation.
-void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
- int32 input_offset, const double input_scale,
- int stride, float beta) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+void RunOneLogSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> optimized_logsoftmax_output(buffer_size);
std::vector<uint8> reference_float_logsoftmax_output(buffer_size);
std::vector<uint8> reference_quant_logsoftmax_output(buffer_size);
- RunLogSoftmaxFloatReference(input_data, dims_common, input_offset,
+ RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
input_scale, stride, beta,
reference_float_logsoftmax_output.data());
@@ -116,32 +118,33 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
int32 reverse_scaling_divisor;
int reverse_scaling_right_shift;
static const int kScaledDiffIntegerBits = 5;
- tflite::PreprocessLogSoftmaxScaling(
+ tflite::PreprocessLogSoftmaxScalingExp(
beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
&input_beta_left_shift, &reverse_scaling_divisor,
&reverse_scaling_right_shift);
+ reverse_scaling_right_shift *= -1;
// diff_min has a negative value, and is used to limit the maximum magnitude
// of the diffs, which are <= 0.
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier,
+ optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, reverse_scaling_divisor,
reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), dims_common);
+ optimized_logsoftmax_output.data(), shape_common);
reference_ops::LogSoftmax(
- input_data, dims_common, input_beta_multiplier, input_beta_left_shift,
+ input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), dims_common);
+ reference_quant_logsoftmax_output.data(), shape_common);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Optimized vs float reference", false);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_quant_logsoftmax_output.data(), dims_common,
+ reference_quant_logsoftmax_output.data(), shape_common,
"Optimized vs quant reference", true);
CheckOutputData(reference_quant_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Quant reference vs float reference", false);
}
@@ -164,13 +167,13 @@ bool TryOneUniformLogSoftmax() {
const int32 input_offset = UniformRandomInt(-256, 0);
static constexpr float beta = 1.0f;
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandom(&input_data);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
@@ -202,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) {
const int middle_min = UniformRandomInt(0, 255);
const int sides_max = UniformRandomInt(0, middle_min);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
sides_max);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 8cd72239e9..0ce64f8c70 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -26,7 +26,7 @@ namespace optimized_ops {
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
-
+#include <stddef.h>
// clang-format gets confused with this file and ends up formatting lines to
// be larger than 80 characters. Turn off here and back on at the end of the
// file.
@@ -42,6 +42,7 @@ struct DepthwiseConvParams {
int64_t input_row_size;
int64_t output_depth;
int64_t output_row_size;
+ int64_t filter_row_size;
int32 input_offset;
int32 output_offset;
int32 filter_offset;
@@ -51,6 +52,8 @@ struct DepthwiseConvParams {
int32 output_shift;
int32 input_width;
int32 input_height;
+ int32 stride_width;
+ int32 stride_height;
int32 output_width;
int32 output_height;
};
@@ -65,17 +68,20 @@ struct DepthwiseConvParams {
#define OFFSET_INPUT_ROW_SIZE 8
#define OFFSET_OUTPUT_DEPTH 16
#define OFFSET_OUTPUT_ROW_SIZE 24
-#define OFFSET_INPUT_OFFSET 32
-#define OFFSET_OUTPUT_OFFSET 36
-#define OFFSET_FILTER_OFFSET 40
-#define OFFSET_OUTPUT_MULTIPLIER 44
-#define OFFSET_OUTPUT_ACTIVATION_MIN 48
-#define OFFSET_OUTPUT_ACTIVATION_MAX 52
-#define OFFSET_OUTPUT_SHIFT 56
-#define OFFSET_INPUT_WIDTH 60
-#define OFFSET_INPUT_HEIGHT 64
-#define OFFSET_OUTPUT_WIDTH 68
-#define OFFSET_OUTPUT_HEIGHT 72
+#define OFFSET_FILTER_ROW_SIZE 32
+#define OFFSET_INPUT_OFFSET 40
+#define OFFSET_OUTPUT_OFFSET 44
+#define OFFSET_FILTER_OFFSET 48
+#define OFFSET_OUTPUT_MULTIPLIER 52
+#define OFFSET_OUTPUT_ACTIVATION_MIN 56
+#define OFFSET_OUTPUT_ACTIVATION_MAX 60
+#define OFFSET_OUTPUT_SHIFT 64
+#define OFFSET_INPUT_WIDTH 68
+#define OFFSET_INPUT_HEIGHT 72
+#define OFFSET_STRIDE_WIDTH 76
+#define OFFSET_STRIDE_HEIGHT 80
+#define OFFSET_OUTPUT_WIDTH 84
+#define OFFSET_OUTPUT_HEIGHT 88
static_assert(offsetof(DepthwiseConvParams, input_depth) ==
OFFSET_INPUT_DEPTH, "");
@@ -85,6 +91,8 @@ static_assert(offsetof(DepthwiseConvParams, output_depth) ==
OFFSET_OUTPUT_DEPTH, "");
static_assert(offsetof(DepthwiseConvParams, output_row_size) ==
OFFSET_OUTPUT_ROW_SIZE, "");
+static_assert(offsetof(DepthwiseConvParams, filter_row_size) ==
+ OFFSET_FILTER_ROW_SIZE, "");
static_assert(offsetof(DepthwiseConvParams, input_offset) ==
OFFSET_INPUT_OFFSET, "");
static_assert(offsetof(DepthwiseConvParams, output_offset) ==
@@ -103,6 +111,10 @@ static_assert(offsetof(DepthwiseConvParams, input_width) ==
OFFSET_INPUT_WIDTH, "");
static_assert(offsetof(DepthwiseConvParams, input_height) ==
OFFSET_INPUT_HEIGHT, "");
+static_assert(offsetof(DepthwiseConvParams, stride_width) ==
+ OFFSET_STRIDE_WIDTH, "");
+static_assert(offsetof(DepthwiseConvParams, stride_height) ==
+ OFFSET_STRIDE_HEIGHT, "");
static_assert(offsetof(DepthwiseConvParams, output_width) ==
OFFSET_OUTPUT_WIDTH, "");
static_assert(offsetof(DepthwiseConvParams, output_height) ==
@@ -114,7 +126,7 @@ struct DepthwiseConvWindow {};
template <>
struct DepthwiseConvWindow<8, 1, 1> {
public:
- static void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
int64_t input_row_size, int32 output_window_height,
int32 output_window_width,
@@ -1097,7 +1109,7 @@ struct DepthwiseConvWindow<8, 1, 1> {
template <>
struct DepthwiseConvWindow<8, 2, 2> {
- static void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
const int32* bias_ptr, uint8* output_ptr, int64_t input_depth,
int64_t input_row_size, int32 output_window_height,
int32 output_window_width,
@@ -2179,6 +2191,715 @@ struct DepthwiseConvWindow<8, 2, 2> {
}
};
+enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter };
+
+template <EdgeType kEdgeType, int kPadWidth, int kPadHeight>
+struct DepthwiseConvPartial {};
+
+template <>
+struct DepthwiseConvPartial<EdgeType::kCenter, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 1x1 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the 1x1 input and filter values.
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w10\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "cmp x11, #16\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "dup v28.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w10, w10\n"
+ "dup v29.4s, w10\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w9\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w10\n"
+ "dup v25.8h, w9\n"
+
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "subs x11, x11, #8\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "cmp x11, #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28",
+ "v29", "v30", "v31",
+ // We use these general-purpose registers.
+ "x9", "x10", "x11");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
+ }
+};
+
+template <>
+struct DepthwiseConvPartial<EdgeType::kCorner, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 2x2 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the beginning of the 2x2 input and
+ // filter values.
+
+ // Load input and filter values.
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "cmp x15, #16\n"
+ "add x12, %[input_ptr], x15\n"
+ "add x13, %[input_ptr], x9\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "add x14, x13, x15\n"
+ "ld1 {v9.8b}, [x12], #8\n"
+ "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+
+ "add x9, %[filter_ptr], x15\n"
+ "ld1 {v10.8b}, [x13], #8\n"
+ "add x10, %[filter_ptr], x6\n"
+ "ld1 {v11.8b}, [x14], #8\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "add x11, x10, x15\n"
+ "ld1 {v1.8b}, [x9], #8\n"
+ "ld1 {v2.8b}, [x10], #8\n"
+ "ld1 {v3.8b}, [x11], #8\n"
+
+ // Load constants.
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w7\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "dup v28.4s, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w7, w7\n"
+ "dup v29.4s, w7\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w6\n"
+ "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w7\n"
+ "dup v25.8h, w6\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "subs x15, x15, #8\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [%[input_ptr]], #8\n"
+ "cmp x15, #16\n"
+ "ld1 {v0.8b}, [%[filter_ptr]], #8\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], #8\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "ld1 {v1.8b}, [x9], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], #8\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v2.8b}, [x10], #8\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x14], #8\n"
+ "ld1 {v3.8b}, [x11], #8\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18",
+ "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31",
+ // We use these general-purpose registers.
+ "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
+ }
+};
+
+template <>
+struct DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 2x3 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the beginning of the 2x3 input and
+ // filter values.
+
+ // Load input and filter values.
+ "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n"
+ "mov x12, %[input_ptr]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "mov x9, %[filter_ptr]\n"
+ "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+ "add x13, x12, x11\n"
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+
+ "ld1 {v8.8b}, [x12], x7\n"
+ "add x10, x9, x14\n"
+ "ld1 {v9.8b}, [x12], x7\n"
+ "cmp x15, #16\n"
+ "ld1 {v10.8b}, [x12]\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+ "ld1 {v11.8b}, [x13], x7\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "ld1 {v12.8b}, [x13], x7\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "ld1 {v0.8b}, [x9], x7\n"
+ "ld1 {v1.8b}, [x9], x7\n"
+ "ld1 {v2.8b}, [x9]\n"
+ "ld1 {v3.8b}, [x10], x7\n"
+ "ld1 {v4.8b}, [x10], x7\n"
+ "ld1 {v5.8b}, [x10]\n"
+
+ // Load constants.
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "dup v28.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w13, w13\n"
+ "dup v29.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w13\n"
+ "dup v25.8h, w12\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "mov x12, %[input_ptr]\n"
+ "subs x15, x15, #8\n"
+ "add x13, x12, x11\n"
+ "cmp x15, #16\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "mov x9, %[filter_ptr]\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [x12], x7\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "add x10, x9, x14\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "ld1 {v9.8b}, [x12], x7\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x12]\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v0.8b}, [x9], x7\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13], x7\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "ld1 {v1.8b}, [x9], x7\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x13], x7\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "ld1 {v2.8b}, [x9]\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x13]\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "ld1 {v3.8b}, [x10], x7\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "ld1 {v4.8b}, [x10], x7\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "ld1 {v5.8b}, [x10]\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
+ }
+};
+
+template <>
+struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
+ static inline void Run(const uint8* input_ptr, const uint8* filter_ptr,
+ const int32* bias_ptr, uint8* output_ptr,
+ const DepthwiseConvParams* params_ptr) {
+#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1"
+#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2"
+ asm volatile(
+ // Performs depthwise convolutions for an input window of size 3x2 and
+ // padding of 1 across the full depth. Expects |input_ptr| and
+ // |filter_ptr| to be pointing to the beginning of the 3x2 input and
+ // filter values.
+
+ // Load input and filter values.
+ "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n"
+ "mov x12, %[input_ptr]\n"
+ "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n"
+ "mov x7, %[filter_ptr]\n"
+ "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n"
+ "add x13, x12, x11\n"
+ "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n"
+ "add x14, x13, x11\n"
+
+ "ld1 {v8.8b}, [x12], x6\n"
+ "add x9, x7, x5\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "cmp x15, #16\n"
+ "add x10, x9, x5\n"
+ "ld1 {v10.8b}, [x13], x6\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+ "ld1 {v11.8b}, [x13]\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "ld1 {v12.8b}, [x14], x6\n"
+ "ld1 {v13.8b}, [x14]\n"
+
+ "ld1 {v0.8b}, [x7], x6\n"
+ "ld1 {v1.8b}, [x7]\n"
+ "ld1 {v2.8b}, [x9], x6\n"
+ "ld1 {v3.8b}, [x9]\n"
+ "ld1 {v4.8b}, [x10], x6\n"
+ "ld1 {v5.8b}, [x10]\n"
+
+ // Load constants.
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
+ "dup v26.8h, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
+ "dup v27.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "dup v28.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
+ "neg w13, w13\n"
+ "dup v29.4s, w13\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n"
+ "dup v30.4s, w12\n"
+ "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n"
+ "dup v31.4s, w13\n"
+ "dup v25.8h, w12\n"
+
+ // Add input and filter offsets.
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n"
+
+ //"loop_%=:\n"
+ DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n"
+ "mov x12, %[input_ptr]\n"
+ "subs x15, x15, #8\n"
+ "add x13, x12, x11\n"
+ "cmp x15, #16\n"
+ "add x14, x13, x11\n"
+ "add %[input_ptr], %[input_ptr], #8\n"
+
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "mov x7, %[filter_ptr]\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "ld1 {v8.8b}, [x12], x6\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "add x9, x7, x5\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "add x10, x9, x5\n"
+ "ld1 {v9.8b}, [x12]\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "add %[filter_ptr], %[filter_ptr], #8\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "ld1 {v10.8b}, [x13], x6\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "ld1 {v0.8b}, [x7], x6\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "ld1 {v11.8b}, [x13]\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "ld1 {v1.8b}, [x7]\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "ld1 {v12.8b}, [x14], x6\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "ld1 {v2.8b}, [x9], x6\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+ "ld1 {v13.8b}, [x14]\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "ld1 {v3.8b}, [x9]\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "ld1 {v4.8b}, [x10], x6\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "ld1 {v5.8b}, [x10]\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "uaddw v8.8h, v26.8h, v8.8b\n"
+ "st1 {v16.8b}, [%[output_ptr]], #8\n"
+ "uaddw v9.8h, v26.8h, v9.8b\n"
+ "uaddw v10.8h, v26.8h, v10.8b\n"
+ "uaddw v11.8h, v26.8h, v11.8b\n"
+ "uaddw v12.8h, v26.8h, v12.8b\n"
+ "uaddw v13.8h, v26.8h, v13.8b\n"
+
+ "uaddw v0.8h, v25.8h, v0.8b\n"
+ "uaddw v1.8h, v25.8h, v1.8b\n"
+ "uaddw v2.8h, v25.8h, v2.8b\n"
+ "ld1 {v16.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v3.8h, v25.8h, v3.8b\n"
+ "ld1 {v17.4s}, [%[bias_ptr]], #16\n"
+ "uaddw v4.8h, v25.8h, v4.8b\n"
+ "uaddw v5.8h, v25.8h, v5.8b\n"
+
+ "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n"
+
+ DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n"
+ "smlal v16.4s, v0.4h, v8.4h\n"
+ "smlal2 v17.4s, v0.8h, v8.8h\n"
+ "smlal v16.4s, v1.4h, v9.4h\n"
+ "smlal2 v17.4s, v1.8h, v9.8h\n"
+ "smlal v16.4s, v2.4h, v10.4h\n"
+ "smlal2 v17.4s, v2.8h, v10.8h\n"
+ "smlal v16.4s, v3.4h, v11.4h\n"
+ "smlal2 v17.4s, v3.8h, v11.8h\n"
+ "smlal v16.4s, v4.4h, v12.4h\n"
+ "smlal2 v17.4s, v4.8h, v12.8h\n"
+ "smlal v16.4s, v5.4h, v13.4h\n"
+ "smlal2 v17.4s, v5.8h, v13.8h\n"
+
+ "sqrdmulh v16.4s, v16.4s, v27.4s\n"
+ "sqrdmulh v17.4s, v17.4s, v27.4s\n"
+ "and v18.16b, v16.16b, v29.16b\n"
+ "and v19.16b, v17.16b, v29.16b\n"
+ "sshr v18.4s, v18.4s, #31\n"
+ "sshr v19.4s, v19.4s, #31\n"
+ "sqadd v16.4s, v16.4s, v18.4s\n"
+ "sqadd v17.4s, v17.4s, v19.4s\n"
+ "srshl v16.4s, v16.4s, v29.4s\n"
+ "srshl v17.4s, v17.4s, v29.4s\n"
+ "add v16.4s, v16.4s, v28.4s\n"
+ "add v17.4s, v17.4s, v28.4s\n"
+ "smax v16.4s, v16.4s, v30.4s\n"
+ "smax v17.4s, v17.4s, v30.4s\n"
+ "smin v16.4s, v16.4s, v31.4s\n"
+ "smin v17.4s, v17.4s, v31.4s\n"
+ "sqxtn v16.4h, v16.4s\n"
+ "sqxtn2 v16.8h, v17.4s\n"
+ "sqxtun v16.8b, v16.8h\n"
+ "st1 {v16.8b}, [%[output_ptr]]\n"
+ :
+ // Outputs.
+ [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr),
+ [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr)
+ :
+ // Inputs.
+ [params_ptr] "r"(params_ptr)
+ :
+ // Clobbers.
+ "cc", "memory",
+ // We use these NEON registers.
+ "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12",
+ "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29",
+ "v30", "v31",
+ // We use these general-purpose registers.
+ "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15");
+#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP
+#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP
+ }
+};
+
#undef OFFSET_INPUT_DEPTH
#undef OFFSET_INPUT_ROW_SIZE
#undef OFFSET_OUTPUT_DEPTH
@@ -2266,7 +2987,7 @@ template <int32 kStrideWidth, int32 kStrideHeight>
struct DepthwiseConvMultiRow {
using ConvKernel = DepthwiseConvThroughDepth<kStrideWidth, kStrideHeight>;
- static inline void Run(const uint8* input_data, int32 start_x, int32 start_y,
+ static inline void Run(const uint8* input_data, int32 start_x, int32 end_x,
const uint8* filter_data, const int32* bias_data,
uint8* output_data, const DepthwiseConvParams& params,
const ShuffleParams& shuffle_params,
@@ -2286,7 +3007,7 @@ struct DepthwiseConvMultiRow {
// preshuffle the input data to maximize locality.
if (params.output_depth > 64 ||
(params.output_depth <= 64 && params.input_width > 150)) {
- for (; out_x <= (params.output_width - shuffle_params.output_width);
+ for (; out_x <= (end_x - shuffle_params.output_width);
out_x += shuffle_params.output_width) {
const uint8* input_ptr = input_data;
const int32* bias_ptr = bias_data;
@@ -2344,7 +3065,7 @@ struct DepthwiseConvMultiRow {
}
}
- const int32 output_leftover_width = params.output_width - out_x;
+ const int32 output_leftover_width = end_x - out_x;
if (output_leftover_width > 0) {
ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0,
params.output_depth, params.input_depth,
@@ -2354,6 +3075,105 @@ struct DepthwiseConvMultiRow {
}
};
+// Processes the borders of the input for pad_width and pad_height = 1.
+// Calls 4 asm kernels:
+// * 1x1 input shape.
+// * Corner edges.
+// * Horizontal edges.
+// * Vertical edges.
+inline void DepthwiseConvHandlePadding(const uint8* input_data,
+ const uint8* filter_data, const int32* bias_data, uint8* output_data,
+ const DepthwiseConvParams& params) {
+ if (params.input_width == 1 && params.input_height == 1) {
+ const uint8* filter_ptr = filter_data + params.filter_row_size
+ + params.output_depth;
+ DepthwiseConvPartial<EdgeType::kCenter, 1, 1>::Run(input_data, filter_ptr,
+ bias_data, output_data, &params);
+ return;
+ }
+
+ const int32 out_x_start_corner = 0;
+ const int32 out_x_end_corner = params.output_width - 1;
+ const int32 out_y_start_corner = 0;
+ const int32 out_y_end_corner = params.output_height - 1;
+
+ // Handle top row.
+ const uint8* input_ptr = input_data;
+ const uint8* filter_ptr = filter_data + params.filter_row_size
+ + params.output_depth;
+ uint8* output_ptr = output_data;
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+
+ input_ptr += (params.stride_width - 1) * params.input_depth;
+ filter_ptr = filter_data + params.filter_row_size;
+ output_ptr += params.output_depth;
+
+ for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner;
+ out_x++) {
+ DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_depth;
+ output_ptr += params.output_depth;
+ }
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+
+ // Handle left side.
+ input_ptr = input_data + (params.stride_width - 1) * params.input_row_size;
+ filter_ptr = filter_data + params.input_depth;
+ output_ptr = output_data + params.output_row_size;
+
+ for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner;
+ out_y++) {
+ DepthwiseConvPartial<EdgeType::kVertical, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_row_size;
+ output_ptr += params.output_row_size;
+ }
+
+ // Handle right side.
+ input_ptr = input_data + (params.input_width - 2) * params.input_depth
+ + (params.stride_width - 1) * params.input_row_size;
+ filter_ptr = filter_data;
+ output_ptr = output_data + params.output_row_size +
+ (params.output_width - 1) * params.output_depth;
+
+ for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner;
+ out_y++) {
+ DepthwiseConvPartial<EdgeType::kVertical, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_row_size;
+ output_ptr += params.output_row_size;
+ }
+
+ // Handle bottom row.
+ input_ptr = input_data + (params.input_height - 2) * params.input_row_size;
+ filter_ptr = filter_data + params.output_depth;
+ output_ptr = output_data +
+ (params.output_height - 1) * params.output_row_size;
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+
+ input_ptr += (params.stride_width == 1) ? 0 : params.input_depth;
+ filter_ptr = filter_data;
+ output_ptr += params.output_depth;
+
+ for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner;
+ out_x++) {
+ DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1>::Run(
+ input_ptr, filter_ptr, bias_data, output_ptr, &params);
+ input_ptr += params.stride_width * params.input_depth;
+ output_ptr += params.output_depth;
+ }
+
+ DepthwiseConvPartial<EdgeType::kCorner, 1, 1>::Run(input_ptr, filter_ptr,
+ bias_data, output_ptr, &params);
+}
+
inline bool Fast3x3FilterKernelSupported(
const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width,
int32 stride_height, int32 pad_width, int32 pad_height,
@@ -2370,7 +3190,8 @@ inline bool Fast3x3FilterKernelSupported(
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
(stride_width == 1 || stride_width == 2) &&
(stride_height == 1 || stride_height == 2) &&
- (stride_width == stride_height) && pad_width == 0 && pad_height == 0 &&
+ (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
+ (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
(input_depth % 8) == 0 && (output_shift > 0);
if (!supported) {
@@ -2390,8 +3211,26 @@ inline bool Fast3x3FilterKernelSupported(
const int32 in_y_end = in_y_origin + filter_height;
// Supported only if filter on the right and bottom boundary lies completely
- // within the input.
- return in_x_end <= input_width && in_y_end <= input_height;
+ // within the input if padding is zero.
+ if (pad_width == 0 && pad_height == 0) {
+ return in_x_end <= input_width && in_y_end <= input_height;
+ }
+
+ // Else if padding is 1, supported if bottom right filter lies +1 past input
+ // width and height.
+ supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1);
+
+ if (!supported) {
+ return false;
+ }
+
+ // Shapes with width 1 and height > 1, and vice versa are not supported yet.
+ if (input_width == 1) {
+ supported = (input_width == input_height);
+ } else if (input_height == 1) {
+ supported = (input_width == input_height);
+ }
+ return supported;
}
inline void DepthwiseConv3x3Filter(
@@ -2403,12 +3242,15 @@ inline void DepthwiseConv3x3Filter(
int32 output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
DepthwiseConvParams params;
params.input_depth = ArraySize(input_dims, 0);
params.input_width = ArraySize(input_dims, 1);
params.input_height = ArraySize(input_dims, 2);
params.input_row_size = params.input_depth * params.input_width;
params.input_offset = input_offset;
+ params.stride_width = stride_width;
+ params.stride_height = stride_height;
params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
params.output_width = ArraySize(output_dims, 1);
params.output_height = ArraySize(output_dims, 2);
@@ -2422,6 +3264,7 @@ inline void DepthwiseConv3x3Filter(
const int32 filter_height = ArraySize(filter_dims, 2);
const int32 filter_width = ArraySize(filter_dims, 1);
+ params.filter_row_size = params.output_depth * filter_width;
// Algorithm assumes below constraints. It is optimized for depth
// multiplier of 1, 3x3 filter, no padding and strides 1 and 2.
@@ -2432,8 +3275,9 @@ inline void DepthwiseConv3x3Filter(
TFLITE_DCHECK(stride_height == 1 || stride_height == 2);
TFLITE_DCHECK(stride_width == 1 || stride_width == 2);
TFLITE_DCHECK(stride_width == stride_height);
- TFLITE_DCHECK(pad_height == 0);
- TFLITE_DCHECK(pad_width == 0);
+ TFLITE_DCHECK(pad_height == 0 || pad_height == 1);
+ TFLITE_DCHECK(pad_width == 0 || pad_width == 1);
+ TFLITE_DCHECK(pad_width == pad_height);
const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
const int64_t input_batch_size = params.input_row_size * params.input_height;
@@ -2471,7 +3315,26 @@ inline void DepthwiseConv3x3Filter(
const uint8* input_ptr = input_data + b * input_batch_size;
uint8* output_ptr = output_data + b * output_batch_size;
+ int32 out_x = 0;
int32 out_y = 0;
+ int32 end_x = params.output_width;
+ int32 end_y = params.output_height;
+
+ if (pad_width == 1 && pad_height == 1) {
+ DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr,
+ params);
+
+ // Update extents now that the edges have been handled.
+ out_x = 1;
+ end_x = params.output_width - 1;
+ out_y = 1;
+ end_y = params.output_height - 1;
+ const int in_x = (out_x * stride_width) - pad_width;
+ const int in_y = (out_y * stride_height) - pad_height;
+ input_ptr += in_y * params.input_row_size + in_x * params.input_depth;
+ output_ptr += out_y * params.output_row_size
+ + out_x * params.output_depth;
+ }
// Shuffling shapes that maximize width over the shuffle workspace size
// perform better since the inputs are closer together, minimizing
@@ -2486,8 +3349,8 @@ inline void DepthwiseConv3x3Filter(
// Handle 8 rows at a time.
if (params.input_width < four_row_shuffle_params.input_width) {
- for (; out_y <= params.output_height - 8; out_y += 8) {
- conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ for (; out_y <= end_y - 8; out_y += 8) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
output_ptr, params, eight_row_shuffle_params,
shuffle_workspace);
input_ptr += 8 * stride_height * params.input_row_size;
@@ -2497,8 +3360,8 @@ inline void DepthwiseConv3x3Filter(
// Handle 4 rows at a time.
if (params.input_width < two_row_shuffle_params.input_width) {
- for (; out_y <= params.output_height - 4; out_y += 4) {
- conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ for (; out_y <= end_y - 4; out_y += 4) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
output_ptr, params, four_row_shuffle_params,
shuffle_workspace);
input_ptr += 4 * stride_height * params.input_row_size;
@@ -2507,8 +3370,8 @@ inline void DepthwiseConv3x3Filter(
}
// Handle 2 rows at a time.
- for (; out_y <= params.output_height - 2; out_y += 2) {
- conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ for (; out_y <= end_y - 2; out_y += 2) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
output_ptr, params, two_row_shuffle_params,
shuffle_workspace);
input_ptr += 2 * stride_height * params.input_row_size;
@@ -2516,8 +3379,8 @@ inline void DepthwiseConv3x3Filter(
}
// Handle one row at a time.
- for (; out_y < params.output_height; out_y++) {
- conv_multirow_func(input_ptr, 0, out_y, filter_data, bias_data,
+ for (; out_y < end_y; out_y++) {
+ conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data,
output_ptr, params, one_row_shuffle_params,
shuffle_workspace);
input_ptr += stride_height * params.input_row_size;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
new file mode 100644
index 0000000000..6db41d7961
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -0,0 +1,361 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace optimized_ops {
+
+// Unoptimized reference ops:
+using reference_ops::Relu1;
+using reference_ops::Relu6;
+
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+} // namespace optimized_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 27d9224512..4a3545d47a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -35,35 +35,6 @@ limitations under the License.
namespace tflite {
namespace multithreaded_ops {
-class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
- public:
- explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {}
- ~EigenThreadPoolWrapper() override {}
-
- void Schedule(std::function<void()> fn) override {
- pool_->Schedule(std::move(fn));
- }
- int NumThreads() const override { return pool_->NumThreads(); }
- int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
-
- private:
- Eigen::ThreadPool* pool_ = nullptr;
-};
-
-// We have a single global threadpool for all convolution operations. This means
-// that inferences started from different threads may block each other, but
-// since the underlying resource of CPU cores should be consumed by the
-// operations anyway, it shouldn't affect overall performance.
-const Eigen::ThreadPoolDevice& GetThreadPoolDevice() {
- const int thread_count = 4;
- static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count);
- static EigenThreadPoolWrapper* thread_pool_wrapper =
- new EigenThreadPoolWrapper(tp);
- static Eigen::ThreadPoolDevice* device =
- new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count);
- return *device;
-}
-
// Shorthands for the types we need when interfacing with the EigenTensor
// library.
typedef Eigen::TensorMap<
@@ -113,14 +84,13 @@ class EigenTensorConvFunctor {
}
public:
- void operator()(const T* input_data, T* im2col_buffer, int input_batches,
- int input_height, int input_width, int input_depth,
- const T* filter_data, int filter_height, int filter_width,
- int filter_count, int stride_rows, int stride_cols,
- int pad_width, int pad_height, TfLitePadding padding,
- T* output_data, int output_height, int output_width) {
- const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice();
-
+ void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data,
+ T* im2col_buffer, int input_batches, int input_height,
+ int input_width, int input_depth, const T* filter_data,
+ int filter_height, int filter_width, int filter_count,
+ int stride_rows, int stride_cols, int pad_width,
+ int pad_height, TfLitePadding padding, T* output_data,
+ int output_height, int output_width) {
const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 &&
stride_rows == 1 && stride_cols == 1);
if (is_1x1_kernel) {
@@ -162,11 +132,11 @@ class EigenTensorConvFunctor {
}
};
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, TfLitePadding padding,
+inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
+ const Dims<4>& input_dims, const float* filter_data,
+ const Dims<4>& filter_dims, const float* bias_data,
+ const Dims<4>& bias_dims, int stride_width, int stride_height,
+ int pad_width, int pad_height, TfLitePadding padding,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims,
float* im2col_data, const Dims<4>& im2col_dims) {
@@ -180,10 +150,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
EigenTensorConvFunctor<float> conv_functor;
- conv_functor(input_data, im2col_data, batches, input_height, input_width,
- input_depth, filter_data, filter_height, filter_width,
- output_depth, stride_height, stride_width, pad_height, pad_width,
- padding, output_data, output_height, output_width);
+ conv_functor(device, input_data, im2col_data, batches, input_height,
+ input_width, input_depth, filter_data, filter_height,
+ filter_width, output_depth, stride_height, stride_width,
+ pad_height, pad_width, padding, output_data, output_height,
+ output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
bias_data, bias_dims, output_data, output_dims, output_activation_min,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 38ad32c734..c19f8e8a81 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -62,72 +62,35 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows,
sizeof(float32x4_t), (postamble_start >> 2) * sizeof(float32x4_t),
&aligned_vector_cache_free));
- const int kUnrollSize = 2;
for (int b = 0; b < n_batch; b++) {
float* result_in_batch = result + b * m_rows * result_stride;
const float* vector_in_batch = vector + b * m_cols;
-
- const float* matrix_ptr0 = matrix;
- // If there is only 1 row, we don't want to assign an illegal pointer.
- const float* matrix_ptr1 = nullptr;
- if (m_rows > 1) {
- matrix_ptr1 = matrix + m_cols;
- }
+ const float* matrix_row = matrix;
// Cache the vector.
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c);
}
- // Main matrix by vector multiplication loop, which handles two rows of
- // matrix by vector multiplication.
- for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
- float32x4_t acc1_32x4 = vmovq_n_f32(0.0);
- for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
- float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
- float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c);
- // Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
- acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp);
- }
- // Add the 4 intermediate sum values to get the final dot-prod value for
- // this column.
- *result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
- *(result_in_batch + result_stride) +=
- (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) +
- vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3));
- for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
- *(result_in_batch + result_stride) +=
- matrix_ptr1[c] * vector_in_batch[c];
- }
- matrix_ptr0 += kUnrollSize * m_cols;
- matrix_ptr1 += kUnrollSize * m_cols;
- result_in_batch += kUnrollSize * result_stride;
- }
- for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) {
- float32x4_t acc0_32x4 = vmovq_n_f32(0.0);
+ // Main matrix by vector multiplication loop
+ for (int r = 0; r < m_rows; r++) {
+ float32x4_t acc_32x4 = vmovq_n_f32(0.0);
for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) {
float32x4_t temp = vector_cache_float32x4[c >> 2];
- // Load 4 float values from vector1 and vector2 and accumulator.
- float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c);
+ // Load 4 float values from vector and accumulator.
+ float32x4_t v_f32x4 = vld1q_f32(matrix_row + c);
// Vector multiply-accumulate 4 float
- acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp);
+ acc_32x4 = vmlaq_f32(acc_32x4, v_f32x4, temp);
}
// Add the 4 intermediate sum values to get the final dot-prod value for
// this column.
*result_in_batch +=
- (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) +
- vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3));
+ (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) +
+ vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3));
for (int c = postamble_start; c < m_cols; c++) {
- *result_in_batch += matrix_ptr0[c] * vector_in_batch[c];
+ *result_in_batch += matrix_row[c] * vector_in_batch[c];
}
- matrix_ptr0 += m_cols;
+ matrix_row += m_cols;
result_in_batch += result_stride;
}
}
@@ -162,7 +125,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Copy the vector data to an aligned vector.
memcpy(aligned_vec, vectors + batch * m_cols, sizeof(int8) * m_cols);
// Compute dot-product for every column.
@@ -232,7 +195,7 @@ void NeonMatrixBatchVectorMultiplyAccumulate(
int32 neon_sum =
vgetq_lane_s64(pairwiseAdded, 0) + vgetq_lane_s64(pairwiseAdded, 1);
- *result += ((neon_sum + postable_sum) * batch_scaling_factor_inv);
+ *result += ((neon_sum + postable_sum) * batch_scaling_factor);
} // for row
} // for batch
@@ -418,13 +381,14 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
// Vectorized constants.
- const float32x4_t q_factor_f32x4 = vmovq_n_f32(*scaling_factor);
+ const float32x4_t q_factor_f32x4 = vmovq_n_f32(scaling_factor_inv);
const float32x4_t point5_f32x4 = vmovq_n_f32(0.5);
const float32x4_t zero_f32x4 = vmovq_n_f32(0.0);
const int32x4_t scale_i32x4 = vmovq_n_s32(kScale);
@@ -476,7 +440,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
for (int i = postamble_start; i < size; ++i) {
const int32 quantized_value =
- static_cast<int32>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32>(TfLiteRound(scaling_factor_inv * values[i]));
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index f7011b28fd..ebd3b116e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -40,16 +40,31 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
+using reference_ops::ArgMax;
+using reference_ops::ArgMinMax;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
using reference_ops::BroadcastLess;
using reference_ops::BroadcastLessEqual;
+using reference_ops::Concatenation;
+using reference_ops::DepthConcatenation;
+using reference_ops::Dequantize;
+using reference_ops::Div;
+using reference_ops::FakeQuant;
+using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::Mean;
using reference_ops::RankOneSelect;
+using reference_ops::Relu1;
+using reference_ops::Relu6;
+using reference_ops::ReluX;
using reference_ops::Select;
+using reference_ops::SpaceToBatchND;
+using reference_ops::StridedSlice;
+using reference_ops::Transpose;
// TODO(b/80247582) Remove this constant.
// This will be phased out as the shifts are revised with more thought. Use of a
@@ -72,6 +87,12 @@ using VectorMap = typename std::conditional<
Eigen::Dynamic, 1>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, 1>>>::type;
+template <typename Scalar>
+VectorMap<Scalar> MapAsVector(Scalar* data, const RuntimeShape& shape) {
+ const int size = shape.FlatSize();
+ return VectorMap<Scalar>(data, size, 1);
+}
+
template <typename Scalar, int N>
VectorMap<Scalar> MapAsVector(Scalar* data, const Dims<N>& dims) {
const int size = FlatSize(dims);
@@ -88,6 +109,23 @@ using MatrixMap = typename std::conditional<
Eigen::Dynamic, Eigen::Dynamic>>,
Eigen::Map<Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic>>>::type;
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithLastDimAsRows(Scalar* data,
+ const RuntimeShape& shape) {
+ const int dims_count = shape.DimensionsCount();
+ const int rows = shape.Dims(dims_count - 1);
+ const int cols = FlatSizeSkipDim(shape, dims_count - 1);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithFirstDimAsCols(Scalar* data,
+ const RuntimeShape& shape) {
+ const int cols = shape.Dims(0);
+ const int rows = FlatSizeSkipDim(shape, 0);
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithFirstDimAsRows(Scalar* data,
const Dims<N>& dims) {
@@ -134,16 +172,9 @@ template <typename Scalar, int N>
MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
const Dims<N>& dims,
int rows) {
- int cols = 1;
- bool matched_rows = false;
- for (int d = 0; d < N; d++) {
- cols *= dims.sizes[d];
- if (cols == rows) {
- matched_rows = true;
- cols = 1;
- }
- }
- TFLITE_DCHECK(matched_rows);
+ const int flatsize = FlatSize(dims);
+ TFLITE_DCHECK((flatsize % rows) == 0);
+ const int cols = flatsize / rows;
return MatrixMap<Scalar>(data, rows, cols);
}
@@ -1082,10 +1113,10 @@ struct GemmlowpOutputPipeline {
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8>
Pipeline;
- static Pipeline Make(const int32* bias_data, int output_rows,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max) {
+ static Pipeline MakeExp(const int32* bias_data, int output_rows,
+ int32 output_offset, int32 output_multiplier,
+ int output_left_shift, int32 output_activation_min,
+ int32 output_activation_max) {
ColVectorMap bias_vector(bias_data, output_rows);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
@@ -1093,7 +1124,7 @@ struct GemmlowpOutputPipeline {
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_offset;
quantize_down_stage.result_fixedpoint_multiplier = output_multiplier;
- quantize_down_stage.result_shift = output_shift;
+ quantize_down_stage.result_shift = -output_left_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1146,8 +1177,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
input_data, filter_cols, batches, filter_cols);
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
- bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -1256,11 +1287,11 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
}
// Internal function doing the actual arithmetic work for
-// ExperimentalShuffledFullyConnected.
+// ShuffledFullyConnected.
// May be called either directly by it (single-threaded case) or may be used
// as the 'task' for worker threads to run (multi-threaded case, see
-// ExperimentalShuffledFullyConnectedWorkerTask below).
-inline void ExperimentalShuffledFullyConnectedWorkerImpl(
+// ShuffledFullyConnectedWorkerTask below).
+inline void ShuffledFullyConnectedWorkerImpl(
const uint8* shuffled_input_workspace_data,
const int8* shuffled_weights_data, int batches, int output_depth,
int output_stride, int accum_depth, const int32* bias_data,
@@ -1534,14 +1565,16 @@ inline void ExperimentalShuffledFullyConnectedWorkerImpl(
#endif
}
-// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class
+// Wraps ShuffledFullyConnectedWorkerImpl into a Task class
// to allow using gemmlowp's threadpool.
-struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
- ExperimentalShuffledFullyConnectedWorkerTask(
- const uint8* input_data, const int8* shuffled_weights_data, int batches,
- int output_depth, int output_stride, int accum_depth,
- const int32* bias_data, int32 output_multiplier, int output_shift,
- int16* output_data)
+struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
+ ShuffledFullyConnectedWorkerTask(const uint8* input_data,
+ const int8* shuffled_weights_data,
+ int batches, int output_depth,
+ int output_stride, int accum_depth,
+ const int32* bias_data,
+ int32 output_multiplier, int output_shift,
+ int16* output_data)
: input_data_(input_data),
shuffled_weights_data_(shuffled_weights_data),
batches_(batches),
@@ -1554,7 +1587,7 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
output_data_(output_data) {}
void Run() override {
- ExperimentalShuffledFullyConnectedWorkerImpl(
+ ShuffledFullyConnectedWorkerImpl(
input_data_, shuffled_weights_data_, batches_, output_depth_,
output_stride_, accum_depth_, bias_data_, output_multiplier_,
output_shift_, output_data_);
@@ -1572,15 +1605,14 @@ struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
int16* output_data_;
};
-inline void ExperimentalShuffledFullyConnected(
+inline void ShuffledFullyConnected(
const uint8* input_data, const Dims<4>& input_dims,
const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
int output_shift, int32 output_activation_min, int32 output_activation_max,
int16* output_data, const Dims<4>& output_dims,
uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- gemmlowp::ScopedProfilingLabel label(
- "ExperimentalShuffledFullyConnected/8bit");
+ gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
(void)gemm_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767);
@@ -1664,7 +1696,7 @@ inline void ExperimentalShuffledFullyConnected(
if (thread_count == 1) {
// Single-thread case: do the computation on the current thread, don't
// use a threadpool
- ExperimentalShuffledFullyConnectedWorkerImpl(
+ ShuffledFullyConnectedWorkerImpl(
shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
output_depth, output_depth, accum_depth, bias_data, output_multiplier,
output_shift, output_data);
@@ -1679,7 +1711,7 @@ inline void ExperimentalShuffledFullyConnected(
int row_start = 0;
for (int i = 0; i < thread_count; i++) {
int row_end = std::min(output_depth, row_start + kRowsPerWorker);
- tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask(
+ tasks[i] = new ShuffledFullyConnectedWorkerTask(
shuffled_input_workspace_data,
int8_shuffled_weights_data + row_start * accum_depth, batches,
row_end - row_start, output_depth, accum_depth, bias_data + row_start,
@@ -1777,6 +1809,100 @@ inline void ExtractPatchIntoBufferColumn(
}
template <typename T>
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 byte_zero,
+ T* im2col_data) {
+ // For dilated convolution, the input pixels are not contiguous therefore we
+ // can't use the same opitimizations as Im2Col(). Though note this code would
+ // work fine for the non-dilated case too (though likely a bit slower).
+ gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
+ TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK(im2col_data);
+ const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int input_height = ArraySize(input_dims, 2);
+ const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int filter_height = ArraySize(filter_dims, 2);
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int output_height = ArraySize(output_dims, 2);
+ const int output_width = ArraySize(output_dims, 1);
+ MatchingArraySize(output_dims, 0, filter_dims, 3);
+
+ // Construct the MxN sized im2col matrix.
+ // The rows M, are sub-ordered B x H x W
+ Dims<4> row_dims;
+ row_dims.sizes[0] = output_width;
+ row_dims.sizes[1] = output_height;
+ row_dims.sizes[2] = batches;
+ row_dims.sizes[3] = 1;
+ ComputeStrides(&row_dims);
+
+ // The columns, N, are sub-ordered Kh x Kw x Din
+ Dims<4> col_dims;
+ col_dims.sizes[0] = input_depth;
+ col_dims.sizes[1] = filter_width;
+ col_dims.sizes[2] = filter_height;
+ col_dims.sizes[3] = 1;
+ ComputeStrides(&col_dims);
+
+ // Use dimensions M and N to construct dims for indexing directly into im2col
+ Dims<4> im2col_dims;
+ im2col_dims.sizes[0] = FlatSize(col_dims);
+ im2col_dims.sizes[1] = FlatSize(row_dims);
+ im2col_dims.sizes[2] = 1;
+ im2col_dims.sizes[3] = 1;
+ ComputeStrides(&im2col_dims);
+
+ // Loop through the output rows (B x H x W)
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ // Each im2col row is an output pixel. Arrange the input data in this
+ // row in an order we can conveniently multiply with the filter data.
+ int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ // Loop through all the pixels of the filter (Kh x Kw)
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ if ((in_y >= 0) && (in_y < input_height)) {
+ // Filter row is within the input data.
+ // Loop through all the filter pixels in this row.
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0);
+ T* dst = im2col_data +
+ Offset(im2col_dims, col_offset, row_offset, 0, 0);
+ if ((in_x >= 0) && (in_x < input_width)) {
+ // Filter pixel is within the input, copy the input data.
+ T const* src =
+ input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ memcpy(dst, src, input_depth * sizeof(T));
+ } else {
+ // Filter pixel is outside the input, zero it out.
+ memset(dst, byte_zero, input_depth * sizeof(T));
+ }
+ }
+ } else {
+ // Filter row is outside the input, zero out the entire filter row.
+ int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
+ T* dst =
+ im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
+ memset(dst, byte_zero, filter_width * input_depth * sizeof(T));
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
int stride_height, int pad_width, int pad_height, int kheight,
int kwidth, uint8 byte_zero, T* output_data,
@@ -1816,74 +1942,6 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
kwidth, byte_zero, output_data, output_dims);
}
-inline void DilatedConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height,
- int dilation_width_factor, int dilation_height_factor,
- int pad_width, int pad_height,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
- gemmlowp::ScopedProfilingLabel label("DilatedConv");
- // This is a copy of the reference Conv implementation. We do not currently
- // have an optimized path for dilation.
- (void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- if (bias_data) {
- TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
- }
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- for (int batch = 0; batch < batches; ++batch) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
- float total = 0.f;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
- for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- const int in_x = in_x_origin + dilation_width_factor * filter_x;
- const int in_y =
- in_y_origin + dilation_height_factor * filter_y;
- // If the location is outside the bounds of the input image,
- // use zero as a default value.
- if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
- (in_y < input_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
- float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
- total += (input_value * filter_value);
- }
- }
- }
- }
- float bias_value = 0.0f;
- if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
- }
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(total + bias_value,
- output_activation_min,
- output_activation_max);
- }
- }
- }
- }
-}
-
inline void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
@@ -1892,29 +1950,32 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims,
float* im2col_data, const Dims<4>& im2col_dims) {
- if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) {
- return DilatedConv(input_data, input_dims, filter_data, filter_dims,
- bias_data, bias_dims, stride_width, stride_height,
- dilation_width_factor, dilation_height_factor, pad_width,
- pad_height, output_activation_min, output_activation_max,
- output_data, output_dims, im2col_data, im2col_dims);
- }
-
(void)im2col_data;
(void)im2col_dims;
gemmlowp::ScopedProfilingLabel label("Conv");
+ // NB: static_cast<float>(0x00000000h) == 0.0f
+ const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
const Dims<4>* gemm_input_dims = nullptr;
const int filter_width = ArraySize(filter_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
+ const bool need_dilated_im2col =
+ dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
- if (need_im2col) {
+ if (need_dilated_im2col) {
+ DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
+ stride_height, dilation_width_factor, dilation_height_factor,
+ pad_width, pad_height, output_dims, float_zero_byte,
+ im2col_data);
+ gemm_input_data = im2col_data;
+ gemm_input_dims = &im2col_dims;
+ } else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, 0, im2col_data,
- im2col_dims);
+ pad_height, filter_height, filter_width, float_zero_byte,
+ im2col_data, im2col_dims);
gemm_input_data = im2col_data;
gemm_input_dims = &im2col_dims;
} else {
@@ -2055,8 +2116,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
gemm_input_data, gemm_input_rows, gemm_input_cols);
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
- bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2213,8 +2274,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
input_data, filter_cols, output_cols, filter_cols);
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols, output_rows);
- const auto& output_pipeline = GemmlowpOutputPipeline::Make(
- bias_data, output_rows, output_offset, output_multiplier, output_shift,
+ const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
+ bias_data, output_rows, output_offset, output_multiplier, -output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2301,48 +2362,25 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
- const auto input = MapAsVector(input_data, input_dims);
- auto output = MapAsVector(output_data, output_dims);
+ const auto input = MapAsVector(input_data, input_shape);
+ auto output = MapAsVector(output_data, output_shape);
output = input.cwiseMax(0.0f);
}
-inline void Relu1(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float upper = 1;
- const float lower = -1;
- const float clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
-inline void Relu6(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float val = input_data[i];
- const float upper = 6;
- const float lower = 0;
- const float clamped = val > upper ? upper : val < lower ? lower : val;
- output_data[i] = clamped;
- }
-}
-
template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("L2Normalization");
static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -2358,8 +2396,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
- int* output_shift) {
+inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
+ int32* output_inv_sqrt,
+ int* output_shift) {
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
@@ -2401,31 +2440,35 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ *output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
+ // Note that input_data advances by depth in the second pass below.
int32 diff = input_data[c] - input_zero_point;
square_l2_norm += diff * diff;
}
int32 inv_l2norm_multiplier;
int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
+ GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
for (int c = 0; c < depth; c++) {
int32 diff = *input_data - input_zero_point;
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
- 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift);
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
int32 unclamped_output_val = 128 + rescaled_diff;
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
*output_data = static_cast<uint8>(output_val);
@@ -2634,25 +2677,13 @@ inline void Add(int left_shift, const uint8* input1_data,
output_activation_max, output_data);
}
-template <FusedActivationFunctionType Ac>
inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
int input1_shift, const int16* input2_data,
const Dims<4>& input2_dims, int input2_shift,
int16 output_activation_min, int16 output_activation_max,
int16* output_data, const Dims<4>& output_dims) {
gemmlowp::ScopedProfilingLabel label("Add/Int16");
- // This is a copy of the reference implementation. We do not currently have a
- // properly optimized version.
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
@@ -2678,6 +2709,42 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
+inline void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(
+ input1_data[i] + input2_data[i], output_activation_min,
+ output_activation_max);
+ }
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
+ input2_shift, output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
template <FusedActivationFunctionType Ac>
void Add(const int32* input1_data, const Dims<4>& input1_dims,
const int32* input2_data, const Dims<4>& input2_dims,
@@ -3178,19 +3245,6 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
output_data, output_dims);
}
-// TODO(aselle): This is not actually optimized yet.
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
- for (int i = 0; i < flat_size; i++) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] / input2_data[i], output_activation_min,
- output_activation_max);
- }
-}
-
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -3356,105 +3410,6 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Concatenation");
- int concat_size = 0;
- for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
- }
- }
- concat_size += ArraySize(*input_dims[i], concat_dim);
- }
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- // for now we dont have a model with a Concatenation
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
- Scalar* output_ptr = output_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
- }
- }
-}
-
-// TODO(prabhumk): This is the same as the reference implementation.
-// TODO(prabhumk): The quantized implementation of concatentation isn't fully
-// quantized as it takes scale as a floating point value. This should be fixed
-// when optimizng this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
- // The arguments input_zeropoint and input_scale are expected to be an array
- // that have the quantization parameters for all the inputs to the concat
- // operator.
- gemmlowp::ScopedProfilingLabel label("Concatenation");
- TFLITE_DCHECK_GT(inputs_count, 1);
- int concat_size = 0;
- for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
- }
- }
- concat_size += ArraySize(*input_dims[i], concat_dim);
- }
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
- const float inverse_output_scale = 1.f / output_scale;
- uint8* output_ptr = output_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
- const uint8* input_ptr = input_data[i] + k * copy_size;
- if (input_zeropoint[i] == output_zeropoint &&
- input_scale[i] == output_scale) {
- memcpy(output_ptr, input_ptr, copy_size);
- } else {
- const float scale = input_scale[i] * inverse_output_scale;
- const float bias = -input_zeropoint[i] * scale;
- for (int j = 0; j < copy_size; ++j) {
- const int32_t value =
- static_cast<int32_t>(round(input_ptr[j] * scale + bias)) +
- output_zeropoint;
- output_ptr[j] =
- static_cast<uint8_t>(std::max(std::min(255, value), 0));
- }
- }
- output_ptr += copy_size;
- }
- }
-}
-
-template <FusedActivationFunctionType Ac, typename Scalar>
-void DepthConcatenation(const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- Concatenation<Ac, Scalar>(0, input_data, input_dims, inputs_count,
- output_data, output_dims);
-}
-
inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
const float* prev_activ_data,
const Dims<4>& prev_activ_dims, const float* weights_data,
@@ -3817,23 +3772,24 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
// TODO(benoitjacob) make this a proper reference impl without Eigen!
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// TODO(benoitjacob) get rid of the dynamic memory allocation here!
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
@@ -3844,12 +3800,15 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
// compute elementwise sum
for (int ph = h_start; ph < h_end; ++ph) {
@@ -3867,69 +3826,44 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK_GT(out_count.minCoeff(), 0);
out_mat.array().rowwise() /= out_count.transpose().array();
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("AveragePool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
const int filter_count =
(filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
// 1280 required by Inception v3
@@ -3938,11 +3872,12 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
uint16 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -3973,21 +3908,21 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
-#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
- if (filter_count == FILTER_COUNT) { \
- for (; channel <= depth - 8; channel += 8) { \
- uint16 buf[8]; \
- for (int i = 0; i < 8; i++) { \
- buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
- } \
- uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
- vst1_u8(output_ptr + channel, buf8); \
- } \
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16 buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
}
AVGPOOL_DIVIDING_BY(9)
AVGPOOL_DIVIDING_BY(15)
@@ -3998,15 +3933,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
}
uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
- buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
- buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ buf8 = vmin_u8(buf8, vdup_n_u8(params.quantized_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, buf8);
}
#endif
for (; channel < depth; ++channel) {
uint16 a = (acc[channel] + filter_count / 2) / filter_count;
- a = std::max<uint16>(a, output_activation_min);
- a = std::min<uint16>(a, output_activation_max);
+ a = std::max<uint16>(a, params.quantized_activation_min);
+ a = std::min<uint16>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -4014,54 +3949,22 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int kwidth, int kheight,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
+
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Prefill the output to minimum representable float value
out_mat.setConstant(std::numeric_limits<float>::lowest());
for (int b = 0; b < batches; ++b) {
@@ -4069,12 +3972,15 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- int hpad = h + pad_height;
- int wpad = w + pad_width;
- int h_start =
- (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1;
+ int hpad = h + params.padding_values.height;
+ int wpad = w + params.padding_values.width;
+ int h_start = (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
int h_end = std::min(hpad / stride_height + 1, output_height);
- int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1;
+ int w_start = (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
int w_end = std::min(wpad / stride_width + 1, output_width);
// compute elementwise sum
for (int ph = h_start; ph < h_end; ++ph) {
@@ -4089,78 +3995,55 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
}
}
}
-
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < output_height; ++y) {
- for (int x = 0; x < output_width; ++x) {
- for (int c = 0; c < depth; ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- output_data[Offset(output_dims, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int kwidth, int kheight, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("MaxPool/8bit");
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
// 2048 required by Inception v3
static constexpr int kAccBufferMaxSize = 2048;
TFLITE_DCHECK_LE(depth, kAccBufferMaxSize);
uint8 acc[kAccBufferMaxSize];
memset(acc, 0, depth * sizeof(acc[0]));
const uint8* input_ptr =
- input_data + input_dims.strides[1] * in_x_origin +
- input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch;
+ input_data +
+ depth * (in_x_origin +
+ input_width * (in_y_origin + input_height * batch));
for (int fy = filter_y_start; fy < filter_y_end; fy++) {
- const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] +
- filter_x_start * input_dims.strides[1];
+ const uint8* input_row_ptr =
+ input_ptr + depth * (fy * input_width + filter_x_start);
for (int fx = filter_x_start; fx < filter_x_end; fx++) {
int channel = 0;
#ifdef USE_NEON
@@ -4186,26 +4069,26 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
uint8* output_ptr =
- output_data + Offset(output_dims, 0, out_x, out_y, batch);
+ output_data + Offset(output_shape, batch, out_y, out_x, 0);
int channel = 0;
#ifdef USE_NEON
for (; channel <= depth - 16; channel += 16) {
uint8x16_t a = vld1q_u8(acc + channel);
- a = vminq_u8(a, vdupq_n_u8(output_activation_max));
- a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ a = vminq_u8(a, vdupq_n_u8(params.quantized_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(params.quantized_activation_min));
vst1q_u8(output_ptr + channel, a);
}
for (; channel <= depth - 8; channel += 8) {
uint8x8_t a = vld1_u8(acc + channel);
- a = vmin_u8(a, vdup_n_u8(output_activation_max));
- a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ a = vmin_u8(a, vdup_n_u8(params.quantized_activation_max));
+ a = vmax_u8(a, vdup_n_u8(params.quantized_activation_min));
vst1_u8(output_ptr + channel, a);
}
#endif
for (; channel < depth; ++channel) {
uint8 a = acc[channel];
- a = std::max<uint8>(a, output_activation_min);
- a = std::min<uint8>(a, output_activation_max);
+ a = std::max<uint8>(a, params.quantized_activation_min);
+ a = std::min<uint8>(a, params.quantized_activation_max);
output_ptr[channel] = static_cast<uint8>(a);
}
}
@@ -4213,53 +4096,23 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Pool");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
// Actually carry out L2 Pool. Code is written in forward mode: we go through
// the input values once, and write to all the pooled regions that it maps to.
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Eigen::VectorXf in_square(in_mat.rows());
Eigen::VectorXf out_count(out_mat.cols());
out_count.setZero();
@@ -4270,15 +4123,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
for (int w = 0; w < input_width; ++w) {
// (h_start, h_end) * (w_start, w_end) is the range that the input
// vector projects to.
- const int hpad = h + pad_height;
- const int wpad = w + pad_width;
- const int h_start = (hpad < filter_height)
- ? 0
- : (hpad - filter_height) / stride_height + 1;
+ const int hpad = h + params.padding_values.height;
+ const int wpad = w + params.padding_values.width;
+ const int h_start =
+ (hpad < params.filter_height)
+ ? 0
+ : (hpad - params.filter_height) / stride_height + 1;
const int h_end = std::min(hpad / stride_height + 1, output_height);
- const int w_start = (wpad < filter_width)
- ? 0
- : (wpad - filter_width) / stride_width + 1;
+ const int w_start =
+ (wpad < params.filter_width)
+ ? 0
+ : (wpad - params.filter_width) / stride_width + 1;
const int w_end = std::min(wpad / stride_width + 1, output_width);
// pre-compute square
const int in_offset = w + input_width * (h + input_height * b);
@@ -4299,28 +4154,13 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
out_count = out_count.array().inverse();
out_mat =
(out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt();
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
+ const int flat_size = output_shape.FlatSize();
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = ActivationFunctionWithMinMax(output_data[i],
+ params.float_activation_min,
+ params.float_activation_max);
+ }
}
inline void LocalResponseNormalization(const float* input_data,
@@ -4368,14 +4208,14 @@ inline void LocalResponseNormalization(const float* input_data,
}
}
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Softmax");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Compute the exponential first, removing the max coefficient for numerical
// stability.
out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
@@ -4387,10 +4227,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
out_mat.array().rowwise() *= scale;
}
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4404,8 +4244,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int b = 0; b < outer_size; ++b) {
const uint8* input_data_ptr = input_data + b * depth;
@@ -4595,11 +4438,14 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const float* block_input_data = input_data + i * depth;
@@ -4740,11 +4586,11 @@ log_x_for_x_greater_than_or_equal_to_1(
}
// Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
@@ -4759,8 +4605,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
const uint8* block_input_data = input_data + i * depth;
@@ -4824,21 +4673,21 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() =
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
#ifdef USE_NEON
@@ -4970,10 +4819,10 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
}
@@ -5030,21 +4879,21 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Tanh");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
// Note that this is almost the exact same code as in Logistic().
gemmlowp::ScopedProfilingLabel label("Tanh");
- const int size = MatchingFlatSize(input_dims, output_dims);
+ const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
int32_t output_zero_point = 128;
@@ -5185,16 +5034,16 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
const int16* input_data_ptr = input_data;
@@ -5285,49 +5134,6 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Dequantize");
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
- for (int i = 0; i < flat_size; ++i) {
- int32 val = input_data[i];
- float result = static_cast<float>(scale * (val - zero_point));
- output_data[i] = result;
- }
-}
-
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("FakeQuant");
-
- // 0 should always be a representable value. Let's assume that the initial
- // min,max range contains 0.
- TFLITE_DCHECK_LE(rmin, 0.0f);
- TFLITE_DCHECK_GE(rmax, 0.0f);
- TFLITE_DCHECK_LT(rmin, rmax);
-
- // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
- int quant_min = 0;
- int quant_max = (1 << num_bits) - 1;
- float nudged_min, nudged_max, nudged_scale;
- NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
- &nudged_max, &nudged_scale);
- const float inv_nudged_scale = 1.0f / nudged_scale;
-
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
- for (int i = 0; i < flat_size; ++i) {
- const float src_val = input_data[i];
- const float clamped = std::min(nudged_max, std::max(nudged_min, src_val));
- const float clamped_shifted = clamped - nudged_min;
- const float dst_val =
- TfLiteRound(clamped_shifted * inv_nudged_scale) * nudged_scale +
- nudged_min;
- output_data[i] = dst_val;
- }
-}
-
template <typename SrcT, typename DstT>
inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
DstT* output_data, const Dims<4>& output_dims) {
@@ -5345,26 +5151,6 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims,
output_map.array() = Eigen::floor(input_map.array());
}
-template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Gather");
-
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
- T* out = output_data;
-
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
- TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
- const T* in = input_data + coords_data[i] * stride;
- memcpy(out, in, sizeof(T) * stride);
- out += stride;
- }
-}
-
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
float scale, float* output_ptr) {
@@ -5693,6 +5479,46 @@ inline void ResizeBilinearGeneric(const float* input_data,
}
}
+template <typename T>
+inline void ResizeBilinearGenericSmallChannel(
+ const T* input_data, const Dims<4>& input_dims, T* output_data,
+ const Dims<4>& output_dims, int32 batches, int32 input_height,
+ int32 input_width, int32 depth, int32 output_height, int32 output_width,
+ float height_scale, float width_scale) {
+ memset(output_data, 0,
+ batches * output_height * output_width * depth * sizeof(T));
+
+ T* output_ptr = &output_data[0];
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < output_height; ++y) {
+ float input_y = y * height_scale;
+ int32 y0 = static_cast<int32>(std::floor(input_y));
+ int32 y1 = std::min(y0 + 1, input_height - 1);
+ for (int x = 0; x < output_width; ++x) {
+ float input_x = x * width_scale;
+ int32 x0 = static_cast<int32>(input_x);
+ int32 x1 = std::min(x0 + 1, input_width - 1);
+
+ int32 input_offset[4] = {
+ Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b),
+ Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)};
+ float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
+ (1 - (input_y - y0)) * (input_x - x0),
+ (input_y - y0) * (1 - (input_x - x0)),
+ (input_y - y0) * (input_x - x0)};
+
+ for (int d = 0; d < depth; d++) {
+ const T* input_ptr = &input_data[d];
+ *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
+ input_ptr[input_offset[1]] * scale[1] +
+ input_ptr[input_offset[2]] * scale[2] +
+ input_ptr[input_offset[3]] * scale[3]);
+ }
+ }
+ }
+ }
+}
+
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, float* output_data,
@@ -5733,6 +5559,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
+// or int16 arithmetic.
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
+ int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ int32 input_height = ArraySize(input_dims, 2);
+ int32 input_width = ArraySize(input_dims, 1);
+ int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
+ TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
+ int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+
+ float height_scale =
+ (align_corners && output_height > 1)
+ ? (static_cast<float>(input_height - 1) / (output_height - 1))
+ : (static_cast<float>(input_height) / output_height);
+
+ float width_scale =
+ (align_corners && output_width > 1)
+ ? (static_cast<float>(input_width - 1) / (output_width - 1))
+ : (static_cast<float>(input_width) / output_width);
+
+ ResizeBilinearGenericSmallChannel<uint8>(
+ input_data, input_dims, output_data, output_dims, batches, input_height,
+ input_width, depth, output_height, output_width, height_scale,
+ width_scale);
+}
+
// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
@@ -5742,53 +5603,13 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims, /*align_corners=*/false);
}
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
const Dims<4>& output_dims) {
- // Unoptimized - Straight copy from reference ops.
- gemmlowp::ScopedProfilingLabel label("SpaceToBatchND");
-
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
- const int block_shape_height = block_shape_data[0];
- const int block_shape_width = block_shape_data[1];
- const int padding_top = paddings_data[0];
- const int padding_left = paddings_data[2];
-
- for (int out_b = 0; out_b < output_batch_size; ++out_b) {
- int input_batch = out_b % input_batch_size;
- int shift_w = (out_b / input_batch_size) % block_shape_width;
- int shift_h = (out_b / input_batch_size) / block_shape_width;
- for (int out_h = 0; out_h < output_height; ++out_h) {
- for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
- if (out_h * block_shape_height + shift_h < padding_top ||
- out_h * block_shape_height + shift_h >=
- padding_top + input_height ||
- out_w * block_shape_width + shift_w < padding_left ||
- out_w * block_shape_width + shift_w >= padding_left + input_width) {
- memset(out, 0, depth * sizeof(T));
- } else {
- const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
- (out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
- memcpy(out, in, depth * sizeof(T));
- }
- }
- }
- }
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
}
// Helper methods for BatchToSpaceND.
@@ -5993,54 +5814,6 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
output_dims, 0);
}
-// UNOPTIMIZED COPY of StridedSlice from reference_ops.h.
-template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
- const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 3);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
- const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 2);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
- const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 1);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
- const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 0);
-
- T* out_ptr = output_data;
- for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
- for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
- for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
- }
- }
- }
- }
-}
-
template <typename T>
inline void Slice(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& begin, const std::vector<int>& size,
@@ -6076,41 +5849,6 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mean");
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
-
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
-
- // The current implementation only supports simultaneous reduction over
- // width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
- TFLITE_DCHECK_EQ(output_height, 1);
- TFLITE_DCHECK_EQ(output_width, 1);
-
- for (int out_b = 0; out_b < output_batch; ++out_b) {
- for (int out_d = 0; out_d < output_depth; ++out_d) {
- float value = 0;
- for (int in_h = 0; in_h < input_height; ++in_h) {
- for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
- }
- }
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
- value / (input_width * input_height);
- }
- }
-}
-
-template <typename T>
void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
T* output_data, const Dims<4>& output_dims) {
@@ -6189,130 +5927,84 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
output_map.array() = input1_map.array().max(max_value);
}
-template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("ArgMax");
-
- // The current ArgMax implemention can only determine the index of the maximum
- // value in the last dimension. So the axis argument is ignored.
-
- // For ArgMax, the number of output dimensions = (number of input dimensions -
- // 1). For the sake of simplicity, the output dimensions are equal to the
- // input dimensions here. We enforce the constraint that the last dimension
- // must always be 1.
- TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = ArraySize(input_dims, 0);
- for (int i = 0; i < outer_size; ++i) {
- auto max_value = *input_data;
- ++input_data;
- int max_index = 0;
- for (int d = 1; d < depth; ++d) {
- const auto& curr_value = *input_data;
- if (curr_value > max_value) {
- max_value = curr_value;
- max_index = d;
- }
- ++input_data;
- }
- *output_data = max_index;
- ++output_data;
- }
-}
-
template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, const int* permuted_axes) {
- int out_sizes[4];
- // Compute the inverse permutation array so we can do an output centered
- // transpose. Also, check to make sure output_dims is matching input_dims.
- for (int k = 0; k < 4; k++) {
- out_sizes[k] =
- MatchingArraySize(input_dims, permuted_axes[k], output_dims, k);
- }
-
- // Naive transpose loop (iterate on output index and compute input index).
- int o[4]; // loop index (on output).
- int i[4];
- for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
- i[permuted_axes[3]] = o[3];
- for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
- i[permuted_axes[2]] = o[2];
- for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
- i[permuted_axes[1]] = o[1];
- for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
- i[permuted_axes[0]] = o[0];
- output[Offset(output_dims, o)] = input[Offset(input_dims, i)];
- }
- }
- }
- }
-}
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
+ TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ TFLITE_DCHECK(im2col_data);
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("TransposeConv");
- // THIS FUNCTION IS A COPY FROM reference_ops.h.
- // To optimize, start by using the conv code with transposed weights for the
- // case of stride_height = stride_width = 1.
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
const int input_height = ArraySize(input_dims, 2);
const int input_width = ArraySize(input_dims, 1);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
const int filter_height = ArraySize(filter_dims, 2);
const int filter_width = ArraySize(filter_dims, 1);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
-
- // Although transpose convolution simplifies to convolution with transposed
- // weights for strides of 1, non-unitary striding complicates matters. To
- // keep this reference implementation as clear as possible, we use a "scatter"
- // access pattern, where we loop through all the input elements, computing
- // their influence on the output, rather than looping through the output
- // elements in the typical "gather" access pattern of a conv. We therefore
- // must initialize the output array to zero.
- for (int batch = 0; batch < batches; ++batch) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
- 0.0f;
- }
- }
- }
- }
-
- // Loop through input elements one at a time.
+ MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth
+
+ // Construct the MxN sized im2col matrix.
+ // The rows M, are sub-ordered B x H x W
+ Dims<4> row_dims;
+ row_dims.sizes[0] = output_width;
+ row_dims.sizes[1] = output_height;
+ row_dims.sizes[2] = batches;
+ row_dims.sizes[3] = 1;
+ ComputeStrides(&row_dims);
+
+ // The columns, N, are sub-ordered Kh x Kw x Din
+ Dims<4> col_dims;
+ col_dims.sizes[0] = input_depth;
+ col_dims.sizes[1] = filter_width;
+ col_dims.sizes[2] = filter_height;
+ col_dims.sizes[3] = 1;
+ ComputeStrides(&col_dims);
+
+ // Use dimensions M and N to construct dims for indexing directly into im2col
+ Dims<4> im2col_dims;
+ im2col_dims.sizes[0] = FlatSize(col_dims);
+ im2col_dims.sizes[1] = FlatSize(row_dims);
+ im2col_dims.sizes[2] = 1;
+ im2col_dims.sizes[3] = 1;
+ ComputeStrides(&im2col_dims);
+
+ // Build the im2col matrix by looping through all the input pixels,
+ // computing their influence on the output, rather than looping through all
+ // the output pixels. We therefore must initialize the im2col array to zero.
+ // This is potentially inefficient because we subsequently overwrite bytes
+ // set here. However, in practice memset is very fast and costs negligible.
+ memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+
+ // Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
+ // Loop through input pixels one at a time.
for (int in_y = 0; in_y < input_height; ++in_y) {
for (int in_x = 0; in_x < input_width; ++in_x) {
- for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
- // Loop through the output elements it will influence
- const int out_x_origin = (in_x * stride_width) - pad_width;
- const int out_y_origin = (in_y * stride_height) - pad_height;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ // Loop through the output pixels it will influence
+ const int out_x_origin = (in_x * stride_width) - pad_width;
+ const int out_y_origin = (in_y * stride_height) - pad_height;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int out_y = out_y_origin + filter_y;
+ // Is output pixel within height bounds?
+ if ((out_y >= 0) && (out_y < output_height)) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- for (int out_channel = 0; out_channel < output_depth;
- ++out_channel) {
- // Compute output element location
- const int out_x = out_x_origin + filter_x;
- const int out_y = out_y_origin + filter_y;
- // We cannot accumulate out of bounds
- if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
- (out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
- float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
- }
+ const int out_x = out_x_origin + filter_x;
+ // Is output pixel within width bounds?
+ if ((out_x >= 0) && (out_x < output_width)) {
+ // Copy the input elements of this pixel
+ T const* src =
+ input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ T* dst = im2col_data +
+ Offset(im2col_dims,
+ Offset(col_dims, 0, filter_x, filter_y, 0),
+ Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+ memcpy(dst, src, input_depth * sizeof(T));
}
}
}
@@ -6322,6 +6014,31 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ gemmlowp::ScopedProfilingLabel label("TransposeConv");
+
+ // Note we could use transposed weights with forward conv for unstrided
+ // cases. But we are already getting good performance with this code as-is.
+ TFLITE_DCHECK(im2col_data);
+ TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
+ stride_height, pad_width, pad_height, output_dims, 0,
+ im2col_data);
+
+ const auto im2col_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
+ const auto filter_matrix_map =
+ MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ auto output_matrix_map =
+ MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+
+ Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+}
+
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index b0951aac8c..e224980493 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+
#include <algorithm>
#include <cmath>
#include <limits>
@@ -48,15 +49,15 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
TFLITE_CHECK_GE(*left_shift, 0);
}
-void QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int* right_shift) {
+void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
TFLITE_CHECK_LT(double_multiplier, 1.);
TFLITE_CHECK_GT(double_multiplier, 0.);
int shift;
QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
TFLITE_CHECK_LE(shift, 0);
- *right_shift = -shift;
+ *left_shift = shift;
}
void PreprocessSoftmaxScaling(double beta, double input_scale,
@@ -78,20 +79,21 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
quantized_multiplier, left_shift);
}
-void PreprocessLogSoftmaxScaling(double beta, double input_scale,
- int input_integer_bits,
- int32_t* quantized_multiplier, int* left_shift,
- int32_t* reverse_scaling_divisor,
- int* reverse_scaling_right_shift) {
+void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier,
+ int* left_shift,
+ int32_t* reverse_scaling_divisor,
+ int* reverse_scaling_left_shift) {
PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits,
quantized_multiplier, left_shift);
// Also calculate what amounts to the inverse scaling factor for the input.
const double real_reverse_scaling_divisor =
(1 << (31 - *left_shift)) / static_cast<double>(*quantized_multiplier);
- tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor,
- reverse_scaling_divisor,
- reverse_scaling_right_shift);
+ tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor,
+ reverse_scaling_divisor,
+ reverse_scaling_left_shift);
}
int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
@@ -125,4 +127,16 @@ void NudgeQuantizationRange(const float min, const float max,
*nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}
+bool CheckedLog2(const float x, int* log2_result) {
+ // Using TfLiteRound instead of std::round and std::log instead of
+ // std::log2 to work around these fuctions being missing in a toolchain
+ // used in some TensorFlow tests as of May 2018.
+ const float x_log2 = std::log(x) * (1.0f / std::log(2.0f));
+ const float x_log2_rounded = TfLiteRound(x_log2);
+ const float x_log2_fracpart = x_log2 - x_log2_rounded;
+
+ *log2_result = static_cast<int>(x_log2_rounded);
+ return std::abs(x_log2_fracpart) < 1e-3;
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 4a217515f1..525857a2e6 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -167,9 +167,9 @@ IntOut SafeCast(FloatIn x) {
// this is intended as a RIGHT-shift.
//
// Restricted to the case where the multiplier < 1 (and non-negative).
-void QuantizeMultiplierSmallerThanOne(double double_multiplier,
- int32_t* quantized_multiplier,
- int* right_shift);
+void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift);
// Decompose a double multiplier into a Q0.31 int32 representation of its
// significand, and shift representation of its exponent.
@@ -197,11 +197,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
int input_integer_bits,
int32_t* quantized_multiplier, int* left_shift);
// Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated.
-void PreprocessLogSoftmaxScaling(double beta, double input_scale,
- int input_integer_bits,
- int32_t* quantized_multiplier, int* left_shift,
- int32_t* reverse_scaling_divisor,
- int* reverse_scaling_right_shift);
+void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier,
+ int* left_shift,
+ int32_t* reverse_scaling_divisor,
+ int* reverse_scaling_left_shift);
// Calculate the largest input that will result in a within-bounds intermediate
// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words,
// it must not overflow before we reduce the value by multiplication by the
@@ -217,6 +218,11 @@ void NudgeQuantizationRange(const float min, const float max,
const int quant_min, const int quant_max,
float* nudged_min, float* nudged_max, float* scale);
+// If x is approximately a power of two (with any positive or negative
+// exponent), stores that exponent (i.e. log2(x)) in *log2_result, otherwise
+// returns false.
+bool CheckedLog2(const float x, int* log2_result);
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 2d74b3d384..94773b47d3 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -196,21 +196,21 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
}
-TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) {
+TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) {
auto quantize = [](double d) {
int32_t q;
int s;
- QuantizeMultiplierSmallerThanOne(d, &q, &s);
+ QuantizeMultiplierSmallerThanOneExp(d, &q, &s);
return std::pair<int32_t, int>{q, s};
};
EXPECT_DEATH(quantize(-0.1), "");
EXPECT_DEATH(quantize(0.0), "");
- EXPECT_THAT(quantize(0.25), Pair(1073741824, 1));
+ EXPECT_THAT(quantize(0.25), Pair(1073741824, -1));
// Around 0.5 we can see the change in exponent and how we try hard to
// void hitting max int32.
- EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1));
+ EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1));
EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0));
EXPECT_THAT(quantize(0.50), Pair(1073741824, 0));
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
new file mode 100644
index 0000000000..f715d34bc1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -0,0 +1,369 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
+
+#include <stdint.h>
+#include <sys/types.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+namespace reference_ops {
+
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, uint8* output_data,
+ const Dims<4>& output_dims) {
+ L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
+ output_data, DimsToShape(output_dims));
+}
+
+inline void Relu(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Relu1(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu1(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Relu6(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Relu6(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width,
+ int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int kwidth, int kheight,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int kwidth, int kheight, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, kwidth, kheight, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int filter_width, int filter_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int filter_width, int filter_height, float* output_data,
+ const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_width, filter_height, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
+ int pad_width, int pad_height, int filter_width, int filter_height,
+ float* output_data, const Dims<4>& output_dims) {
+ L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
+ filter_width, filter_height, output_data, output_dims);
+}
+
+inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+ float beta, float* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), beta, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const Dims<4>& output_dims) {
+ Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
+ input_beta_left_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const Dims<4>& output_dims) {
+ LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
+ input_left_shift, reverse_scaling_divisor,
+ reverse_scaling_right_shift, diff_min, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ Logistic(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_zero_point,
+ input_range_radius, input_multiplier, input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+ int input_left_shift, int16* output_data,
+ const Dims<4>& output_dims) {
+ Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
+ DimsToShape(output_dims));
+}
+
+} // namespace reference_ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index f8c6f341f7..ccf112c990 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -51,10 +51,11 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
*scaling_factor = 1;
return;
}
- *scaling_factor = kScale / range;
+ *scaling_factor = range / kScale;
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
- static_cast<int32_t>(TfLiteRound(*scaling_factor * values[i]));
+ static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
// Clamp: just in case some odd numeric offset.
quantized_values[i] = std::min(kScale, std::max(-kScale, quantized_value));
}
@@ -85,7 +86,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
float* __restrict__ result, int result_stride) {
int batch, row, col;
for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) {
- const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch];
+ const float batch_scaling_factor = scaling_factors[batch];
// Get the address of the first row.
const int8_t* row_ptr = matrix;
for (row = 0; row < m_rows; ++row, result += result_stride) {
@@ -98,7 +99,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate(
for (col = 0; col < m_cols; ++col, ++row_ptr) {
dotprod += (*row_ptr) * (vectors[col]);
} // for col
- *result += (dotprod * batch_scaling_factor_inv);
+ *result += (dotprod * batch_scaling_factor);
} // for row
} // for batch
}
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index bebc97309e..912e455a2e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -697,7 +697,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void ExperimentalShuffledFullyConnected(
+inline void ShuffledFullyConnected(
const uint8* input_data, const Dims<4>& input_dims,
const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
@@ -914,9 +914,9 @@ void GlobalBatchNormalization(const float* input_data,
}
}
-inline void Relu(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float lower = 0;
@@ -925,9 +925,10 @@ inline void Relu(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu1(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu1(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 1;
@@ -937,9 +938,10 @@ inline void Relu1(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Relu6(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input_dims, output_dims);
+inline void Relu6(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
const float val = input_data[i];
const float upper = 6;
@@ -949,12 +951,28 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
}
}
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ const uint8 val = input_data[i];
+ const uint8 clamped =
+ val > max_value ? max_value : val < min_value ? min_value : val;
+ output_data[i] = clamped;
+ }
+}
+
template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
static_assert(Ac == FusedActivationFunctionType::kNone, "");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
float squared_l2_norm = 0;
for (int c = 0; c < depth; ++c) {
@@ -968,8 +986,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
- int* output_shift) {
+inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
+ int32* output_inv_sqrt,
+ int* output_shift) {
*output_shift = 11;
while (input >= (1 << 29)) {
input /= 4;
@@ -1011,42 +1030,45 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ *output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
int32 input_zero_point, uint8* output_data,
- const Dims<4>& output_dims) {
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
+ const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
- int32 diff =
- input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ int32 diff = input_data[depth * i + c] - input_zero_point;
square_l2_norm += diff * diff;
}
int32 inv_l2norm_multiplier;
int inv_l2norm_shift;
- GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier,
- &inv_l2norm_shift);
+ GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier,
+ &inv_l2norm_shift);
for (int c = 0; c < depth; c++) {
- int32 diff =
- input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point;
+ int32 diff = input_data[depth * i + c] - input_zero_point;
int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
- 128 * diff, inv_l2norm_multiplier, kReverseShift * inv_l2norm_shift);
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
int32 unclamped_output_val = 128 + rescaled_diff;
int32 output_val = std::min(255, std::max(0, unclamped_output_val));
- output_data[Offset(output_dims, c, i, 0, 0)] =
- static_cast<uint8>(output_val);
+ output_data[depth * i + c] = static_cast<uint8>(output_val);
}
}
}
-inline void Add(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+template <typename T>
+inline void Add(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
@@ -1128,22 +1150,12 @@ inline void Add(int left_shift, const uint8* input1_data,
}
}
-template <FusedActivationFunctionType Ac>
inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
int input1_shift, const int16* input2_data,
const Dims<4>& input2_dims, int input2_shift,
int16 output_activation_min, int16 output_activation_max,
int16* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, -32768);
- TFLITE_DCHECK_EQ(output_activation_max, 32767);
- }
const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
@@ -1169,6 +1181,28 @@ inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
}
}
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+ Add(input1_data, input1_dims, input1_shift, input2_data, input2_dims,
+ input2_shift, output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -1749,7 +1783,6 @@ template <FusedActivationFunctionType Ac, typename Scalar>
void Concatenation(int concat_dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, int inputs_count,
Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_GT(inputs_count, 1);
int concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
for (int j = 0; j < 4; j++) {
@@ -1760,7 +1793,9 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
concat_size += ArraySize(*input_dims[i], concat_dim);
}
TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
int outer_size = 1;
for (int i = concat_dim + 1; i < 4; i++) {
outer_size *= output_dims.sizes[i];
@@ -2238,32 +2273,36 @@ inline int NodeOffset(int b, int h, int w, int height, int width) {
return (b * height + h) * width + w;
}
-inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float total = 0.f;
float filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2273,70 +2312,52 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
total +=
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
const float average = total / filter_count;
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(average, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(average, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void AveragePool(const PoolParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
int32 acc = 0;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2345,14 +2366,15 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
++filter_x) {
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
- acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ acc +=
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
filter_count++;
}
}
acc = (acc + filter_count / 2) / filter_count;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ acc = std::max(acc, params.quantized_activation_min);
+ acc = std::min(acc, params.quantized_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(acc);
}
}
@@ -2360,64 +2382,35 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- AveragePool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width,
- int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float sum_squares = 0.f;
int filter_count = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
@@ -2427,69 +2420,51 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
const int in_x = in_x_origin + filter_x;
const int in_y = in_y_origin + filter_y;
const float val =
- input_data[Offset(input_dims, channel, in_x, in_y, batch)];
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)];
sum_squares += val * val;
filter_count++;
}
}
const float l2pool_result = std::sqrt(sum_squares / filter_count);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(l2pool_result, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(l2pool_result,
+ params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
float max = std::numeric_limits<float>::lowest();
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2499,68 +2474,51 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
- ActivationFunctionWithMinMax(max, output_activation_min,
- output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
+ ActivationFunctionWithMinMax(max, params.float_activation_min,
+ params.float_activation_max);
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, float* output_data,
- const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- float* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_data, output_dims);
-}
-
-inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_GE(output_activation_min, 0);
- TFLITE_DCHECK_LE(output_activation_max, 255);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ TFLITE_DCHECK_LE(params.quantized_activation_min,
+ params.quantized_activation_max);
+ TFLITE_DCHECK_GE(params.quantized_activation_min, 0);
+ TFLITE_DCHECK_LE(params.quantized_activation_max, 255);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int depth = MatchingDim(input_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int stride_height = params.stride_height;
+ const int stride_width = params.stride_width;
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
for (int channel = 0; channel < depth; ++channel) {
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int in_x_origin =
+ (out_x * stride_width) - params.padding_values.width;
+ const int in_y_origin =
+ (out_y * stride_height) - params.padding_values.height;
// Compute the boundaries of the filter region clamped so as to
// ensure that the filter window fits in the input array.
const int filter_x_start = std::max(0, -in_x_origin);
const int filter_x_end =
- std::min(filter_width, input_width - in_x_origin);
+ std::min(params.filter_width, input_width - in_x_origin);
const int filter_y_start = std::max(0, -in_y_origin);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(params.filter_height, input_height - in_y_origin);
uint8 max = 0;
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
@@ -2570,12 +2528,12 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
const int in_y = in_y_origin + filter_y;
max = std::max(
max,
- input_data[Offset(input_dims, channel, in_x, in_y, batch)]);
+ input_data[Offset(input_shape, batch, in_y, in_x, channel)]);
}
}
- max = std::max<uint8>(max, output_activation_min);
- max = std::min<uint8>(max, output_activation_max);
- output_data[Offset(output_dims, channel, out_x, out_y, batch)] =
+ max = std::max<uint8>(max, params.quantized_activation_min);
+ max = std::min<uint8>(max, params.quantized_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x, channel)] =
static_cast<uint8>(max);
}
}
@@ -2583,38 +2541,6 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int filter_width, int filter_height, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
- int pad_width, int pad_height, int filter_width, int filter_height,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims) {
- MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
- filter_width, filter_height, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
inline void LocalResponseNormalization(const float* input_data,
const Dims<4>& input_dims, int range,
float bias, float alpha, float beta,
@@ -2638,11 +2564,14 @@ inline void LocalResponseNormalization(const float* input_data,
}
}
-inline void Softmax(const float* input_data, const Dims<4>& input_dims,
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
float beta, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
@@ -2667,10 +2596,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_beta_multiplier, int32 input_beta_left_shift,
int diff_min, uint8* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -2683,8 +2612,11 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8 max_in_row = 0;
@@ -2745,10 +2677,13 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
// Find max element value which we'll use to ensure numerical stability
@@ -2888,11 +2823,11 @@ log_x_for_x_greater_than_or_equal_to_1(
input_val);
}
-inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_multiplier, int32 input_left_shift,
int32 reverse_scaling_divisor,
int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -2906,8 +2841,11 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
uint8 max_in_row = 0;
@@ -2971,9 +2909,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -2982,11 +2920,11 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ uint8* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -3020,9 +2958,9 @@ inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
- int16* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
+ int16* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -3038,9 +2976,9 @@ inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Tanh(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
float val = input_data[i];
@@ -3049,12 +2987,12 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
int32 input_zero_point, int32 input_range_radius,
int32 input_multiplier, int input_left_shift,
- uint8* output_data, const Dims<4>& output_dims) {
+ uint8* output_data, const RuntimeShape& output_shape) {
const int32 output_zero_point = 128;
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
const uint8 input_val_u8 = input_data[i];
@@ -3089,15 +3027,15 @@ inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
int input_left_shift, int16* output_data,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
TFLITE_DCHECK_LE(input_left_shift, 1);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
// F0 uses 0 integer bits, range [-1, 1].
// This is the return type of math functions such as tanh, logistic,
@@ -3202,9 +3140,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims,
}
}
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+template <typename T>
+inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_size_dims, T* output_data,
const Dims<4>& output_dims, bool align_corners) {
int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
int32 input_height = ArraySize(input_dims, 2);
@@ -3236,15 +3175,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
int32 x0 = static_cast<int32>(std::floor(input_x));
int32 x1 = std::min(x0 + 1, input_width - 1);
for (int c = 0; c < depth; ++c) {
- float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
- (1 - (input_y - y0)) *
- (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x0, y1, b)] *
- (input_y - y0) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x1, y0, b)] *
- (1 - (input_y - y0)) * (input_x - x0) +
- input_data[Offset(input_dims, c, x1, y1, b)] *
- (input_y - y0) * (input_x - x0);
+ T interpolation =
+ static_cast<T>(input_data[Offset(input_dims, c, x0, y0, b)] *
+ (1 - (input_y - y0)) * (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x0, y1, b)] *
+ (input_y - y0) * (1 - (input_x - x0)) +
+ input_data[Offset(input_dims, c, x1, y0, b)] *
+ (1 - (input_y - y0)) * (input_x - x0) +
+ input_data[Offset(input_dims, c, x1, y1, b)] *
+ (input_y - y0) * (input_x - x0));
output_data[Offset(output_dims, c, x, y, b)] = interpolation;
}
}
@@ -3257,8 +3196,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
const int32* output_size_data,
const Dims<4>& output_size_dims, float* output_data,
const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
+ ResizeBilinear<float>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
}
template <typename T>
@@ -3418,7 +3367,7 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
template <typename T>
inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask,
+ int begin_mask, int end_mask, int shrink_axis_mask,
const std::vector<int>& start_indices,
const std::vector<int>& stop_indices,
const std::vector<int>& strides, T* output_data,
@@ -3430,20 +3379,24 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK_EQ(strides.size(), 4);
const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
strides, input_dims.sizes, 3);
- const int stop_b = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 3);
+ const int stop_b =
+ strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
+ strides, input_dims.sizes, 3, start_b);
const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
strides, input_dims.sizes, 2);
- const int stop_h = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 2);
+ const int stop_h =
+ strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
+ strides, input_dims.sizes, 2, start_h);
const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
strides, input_dims.sizes, 1);
- const int stop_w = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 1);
+ const int stop_w =
+ strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
+ strides, input_dims.sizes, 1, start_w);
const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
strides, input_dims.sizes, 0);
- const int stop_d = strided_slice::StopForAxis(end_mask, stop_indices, strides,
- input_dims.sizes, 0);
+ const int stop_d =
+ strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
+ strides, input_dims.sizes, 0, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
@@ -3505,63 +3458,152 @@ inline void Exp(const T* input_data, const size_t num_elements,
}
}
+// A generic reduce method that can be used for reduce_sum, reduce_mean, etc.
+// This method iterates through input data and reduce elements along the
+// dimensions given in axis.
+template <typename In, typename Out>
+inline bool Reduce(const In* input_data, const int* input_dims,
+ const int* output_dims, const int input_num_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis, int* input_iter,
+ Out reducer(Out current, const In in), Out* output_data) {
+ // Reset input iterator.
+ TFLITE_DCHECK(input_num_dims > 0);
+ for (int idx = 0; idx < input_num_dims; ++idx) {
+ input_iter[idx] = 0;
+ }
+ // Iterate through input_data.
+ do {
+ size_t input_offset =
+ ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr);
+ size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims,
+ input_iter, num_axis, axis);
+ output_data[output_offset] =
+ reducer(output_data[output_offset], input_data[input_offset]);
+ } while (NextIndex(input_num_dims, input_dims, input_iter));
+ return true;
+}
+
+inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis,
+ int* out_axis, int* out_num_axis) {
+ *out_num_axis = 0; // Just in case.
+ // o(n^2) is fine since out_num_axis should be really small, mostly <= 4
+ for (int idx = 0; idx < num_axis; ++idx) {
+ // Handle negative index.
+ int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
+ TFLITE_DCHECK(current >= 0 && current < num_dims);
+ bool is_dup = false;
+ for (int j = 0; j < *out_num_axis; ++j) {
+ if (out_axis[j] == current) {
+ is_dup = true;
+ break;
+ }
+ }
+ if (!is_dup) {
+ out_axis[*out_num_axis] = current;
+ *out_num_axis += 1;
+ }
+ }
+ return true;
+}
+
+// This method expects that output_data has been initialized.
+template <typename In, typename Out>
+inline bool ReduceSumImpl(const In* input_data, const int* input_dims,
+ const int* output_dims, const int input_num_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis, int* input_iter,
+ Out* output_data) {
+ auto reducer = [](Out current, const In in) -> Out {
+ const Out actual_in = static_cast<Out>(in);
+ return current + actual_in;
+ };
+ return Reduce<In, Out>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, axis, num_axis, input_iter, reducer,
+ output_data);
+}
+
+// Computes the sum of elements across dimensions given in axis.
+template <typename T>
+inline bool Sum(const T* input_data, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis) {
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
+ }
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] = T();
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ return ReduceSumImpl<T, T>(input_data, input_dims, output_dims,
+ input_num_dims, output_num_dims, resolved_axis,
+ num_resolved_axis, temp_index, output_data);
+}
+
+// Computes the mean of elements across dimensions given in axis.
+// It does so in two stages, first calculates the sum of elements along the axis
+// then divides it by the number of element in axis.
template <typename T, typename U>
inline bool Mean(const T* input_data, const int* input_dims,
const int input_num_dims, T* output_data,
const int* output_dims, const int output_num_dims,
const int* axis, const int num_axis_dimensions, bool keep_dims,
int* temp_index, int* resolved_axis, U* temp_sum) {
- // resets output data.
+ // Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
- num_outputs *= static_cast<size_t>(output_dims[idx]);
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
}
for (size_t idx = 0; idx < num_outputs; ++idx) {
output_data[idx] = T();
temp_sum[idx] = U();
}
- // resets temp index.
- for (int idx = 0; idx < input_num_dims; ++idx) {
- temp_index[idx] = 0;
- }
- // resolves axis.
+
+ // Resolve axis.
int num_resolved_axis = 0;
- for (int idx = 0; idx < num_axis_dimensions; ++idx) {
- int current = axis[idx];
- TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0);
- if (current < 0) {
- current += input_num_dims;
- }
- bool is_dup = false;
- for (int j = 0; j < num_resolved_axis; ++j) {
- if (resolved_axis[j] == current) {
- is_dup = true;
- break;
- }
- }
- if (!is_dup) {
- resolved_axis[num_resolved_axis++] = current;
- }
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
}
- // iterates through input_data.
- for (bool has_next = true; has_next;
- has_next = NextIndex(input_num_dims, input_dims, temp_index)) {
- size_t input_offset =
- ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
- size_t output_offset =
- ReducedOutputOffset(input_num_dims, input_dims, temp_index,
- num_resolved_axis, resolved_axis);
- temp_sum[output_offset] += static_cast<U>(input_data[input_offset]);
- }
- // takes average by num of elements added to get mean.
- size_t num_elements_in_axis = 1;
+
+ if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, temp_sum)) {
+ return false;
+ }
+
+ // Calculate mean by dividing output_data by num of aggregated element.
+ U num_elements_in_axis = 1;
for (int idx = 0; idx < num_resolved_axis; ++idx) {
size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ // Overflow prevention.
if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
return false;
}
num_elements_in_axis *= current;
}
+
if (num_elements_in_axis > 0) {
for (size_t idx = 0; idx < num_outputs; ++idx) {
output_data[idx] =
@@ -3686,9 +3728,9 @@ void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
}
}
-template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims) {
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
// The current ArgMax implemention can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -3701,22 +3743,31 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
const int depth = ArraySize(input_dims, 0);
for (int i = 0; i < outer_size; ++i) {
- auto max_value = input_data[i * depth];
- int max_index = 0;
+ auto min_max_value = input_data[i * depth];
+ int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
const auto& curr_value = input_data[i * depth + d];
- if (curr_value > max_value) {
- max_value = curr_value;
- max_index = d;
+ if (cmp(curr_value, min_max_value)) {
+ min_max_value = curr_value;
+ min_max_index = d;
}
}
- output_data[i] = max_index;
+ output_data[i] = min_max_index;
}
}
+// TODO(renjieliu): Remove this one.
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
+ std::greater<T1>());
+}
+
template <typename T>
void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, int* permuted_axes) {
+ const Dims<4>& output_dims, const int* permuted_axes) {
int out_sizes[4];
// Compute the inverse permutation array so we can do an output centered
// transpose. Also, check to make sure output_dims is matching input_dims.
@@ -3747,10 +3798,11 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, float* output_data,
- const Dims<4>& output_dims) {
+ const Dims<4>& output_dims, float* /*im2col_data*/,
+ const Dims<4>& /*im2col_dims*/) {
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
+ const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
+ const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
const int input_height = ArraySize(input_dims, 2);
const int input_width = ArraySize(input_dims, 1);
const int filter_height = ArraySize(filter_dims, 2);
@@ -3765,7 +3817,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// computing their influence on the output, rather than looping through the
// output elements in the typical "gather" access pattern of a conv. We
// therefore must initialize the output array to zero.
- for (int i = 0; i < FlatSize(output_dims); i++) {
+ const int num_elements = FlatSize(output_dims);
+ for (int i = 0; i < num_elements; i++) {
output_data[i] = 0.0f;
}
@@ -3790,8 +3843,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
float input_value = input_data[Offset(input_dims, in_channel,
in_x, in_y, batch)];
float filter_value =
- filter_data[Offset(filter_dims, out_channel, filter_x,
- filter_y, in_channel)];
+ filter_data[Offset(filter_dims, in_channel, filter_x,
+ filter_y, out_channel)];
output_data[Offset(output_dims, out_channel, out_x, out_y,
batch)] += input_value * filter_value;
}
@@ -3805,6 +3858,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
+inline bool EqualFn(T lhs, T rhs) {
+ return lhs == rhs;
+}
+
+template <typename T>
+inline bool NotEqualFn(T lhs, T rhs) {
+ return lhs != rhs;
+}
+
+template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
return lhs > rhs;
}
@@ -3967,6 +4030,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
input2_offset, input2_multiplier, \
input2_shift, output_data, output_dims); \
}
+TFLITE_COMPARISON_OP(Equal);
+TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
@@ -4042,6 +4107,36 @@ inline void SparseToDense(const std::vector<std::vector<I>>& indices,
}
}
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = std::pow(input1_data[i], input2_data[i]);
+ }
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
index c1c50dff4d..3d8765f11b 100644
--- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_float_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -24,9 +24,10 @@ limitations under the License.
namespace tflite {
namespace {
+template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
int input_height, int output_width,
- int output_height) {
+ int output_height, float error_threshold) {
Dims<4> input_dims_inference =
MakeDimsForInference(depth, input_width, input_height, batch);
Dims<4> output_dims_inference =
@@ -36,14 +37,15 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
const int output_buffer_size =
RequiredBufferSizeForDims(output_dims_inference);
- std::vector<float> input_data(input_buffer_size, 0);
- std::vector<float> reference_output_data(output_buffer_size, 0);
+ std::vector<T> input_data(input_buffer_size, 0);
+ std::vector<T> reference_output_data(output_buffer_size, 0);
// Initialize the output data with something other than zero, so we can catch
// issue with kernels failing to initialize the output.
- std::vector<float> output_data(output_buffer_size, 3.1415);
+ std::vector<T> output_data(output_buffer_size, 3);
- const float input_amplitude = 1.f;
- FillRandom(&input_data, -input_amplitude, input_amplitude);
+ const T min_amplitude = static_cast<T>(0);
+ const T max_amplitude = static_cast<T>(255);
+ FillRandom(&input_data, min_amplitude, max_amplitude);
Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
std::vector<int32> output_size_data = {output_height, output_width};
@@ -58,14 +60,46 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
double sum_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
- sum_diff += std::abs(output_data[i] - reference_output_data[i]);
- max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
+ sum_diff += std::abs(static_cast<float>(output_data[i]) -
+ static_cast<float>(reference_output_data[i]));
+ max_abs_val = std::max(
+ max_abs_val, std::abs(static_cast<float>(reference_output_data[i])));
}
if (sum_diff != 0.f) {
const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
const float relative_error = std::abs(mean_diff) / max_abs_val;
- ASSERT_LT(relative_error, 1e-5f);
+ ASSERT_LT(relative_error, error_threshold);
+ }
+}
+
+TEST(ResizeBilinear, TestResizeBilinear8Bit) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+
+ TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
+ output_width, output_height, 0.025);
+ }
+}
+
+TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) {
+ const int kTestsToRun = 100 * 1000;
+ for (int i = 0; i < kTestsToRun; i++) {
+ const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
+ const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
+ const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ const int output_width = input_width * 2;
+ const int output_height = input_height * 2;
+
+ TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
}
}
@@ -79,8 +113,8 @@ TEST(ResizeBilinear, TestResizeBilinear) {
const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
- TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
- output_height);
+ TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
}
}
@@ -94,8 +128,8 @@ TEST(ResizeBilinear2x2, TestResizeBilinear) {
const int output_width = input_width * 2;
const int output_height = input_height * 2;
- TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
- output_height);
+ TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
+ output_width, output_height, 1e-5);
}
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
index d781a7b642..a7dad3c14e 100644
--- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -32,19 +32,21 @@ namespace tflite {
namespace {
void RunSoftmaxFloatReference(const uint8* input_data,
- const Dims<4>& dims_common, int32 input_offset,
- const double input_scale, int stride, float beta,
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
uint8* reference_output_data) {
- const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float Softmax.
- reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
- reference_dequant_data.data(), dims_common);
- optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta,
- reference_output_float_data.data(), dims_common);
+ reference_ops::Dequantize(
+ input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
+ reference_dequant_data.data(), ToRuntimeDims(shape_common));
+ optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta,
+ reference_output_float_data.data(), shape_common);
// Work with quantized scaling for Softmax, under which 256 represents 1, but
// we limit this to 255.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -55,9 +57,9 @@ void RunSoftmaxFloatReference(const uint8* input_data,
}
void CheckOutputData(const uint8* test_output, const uint8* reference_output,
- const Dims<4>& dims_common, const string& check_label,
- bool be_exacting) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
// While calculating some metrics in floating point, we work with quantized
// scaling.
std::vector<int> diff(buffer_size);
@@ -91,15 +93,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Runs the Softmax and compares against the float reference implementation and
// the quantized reference implementation.
-void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
- int32 input_offset, const double input_scale, int stride,
- float beta) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+void RunOneSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> optimized_softmax_output(buffer_size);
std::vector<uint8> reference_float_softmax_output(buffer_size);
std::vector<uint8> reference_quant_softmax_output(buffer_size);
- RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale,
+ RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale,
stride, beta, reference_float_softmax_output.data());
int32 input_beta_multiplier;
@@ -113,21 +115,21 @@ void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, diff_min,
- optimized_softmax_output.data(), dims_common);
- reference_ops::Softmax(input_data, dims_common, input_beta_multiplier,
+ optimized_softmax_output.data(), shape_common);
+ reference_ops::Softmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, diff_min,
- reference_quant_softmax_output.data(), dims_common);
+ reference_quant_softmax_output.data(), shape_common);
CheckOutputData(optimized_softmax_output.data(),
- reference_float_softmax_output.data(), dims_common,
+ reference_float_softmax_output.data(), shape_common,
"Optimized vs float reference", false);
CheckOutputData(optimized_softmax_output.data(),
- reference_quant_softmax_output.data(), dims_common,
+ reference_quant_softmax_output.data(), shape_common,
"Optimized vs quant reference", true);
CheckOutputData(reference_quant_softmax_output.data(),
- reference_float_softmax_output.data(), dims_common,
+ reference_float_softmax_output.data(), shape_common,
"Quant reference vs float reference", false);
}
@@ -150,13 +152,13 @@ bool TryOneUniformSoftmax() {
const int32 input_offset = UniformRandomInt(-256, 0);
const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandom(&input_data);
- RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
stride, beta);
return true;
}
@@ -188,14 +190,14 @@ bool TryOneSkyscraperSoftmax(bool small_depth) {
const int middle_min = UniformRandomInt(0, 255);
const int sides_max = UniformRandomInt(0, middle_min);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
sides_max);
- RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale,
+ RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale,
stride, beta);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index ef77371bf6..5994fad5c7 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -74,12 +74,22 @@ inline int StartForAxis(int begin_mask,
// size 4, this function would return 4 as the stop, because it is one past the
// "real" indices of 0, 1, 2 & 3.
template <typename IntType>
-inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
+inline int StopForAxis(int end_mask, int shrink_axis_mask,
+ std::vector<IntType> const& stop_indices,
std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
+ int const* input_shape, int axis, int start_for_axis) {
// Begin with the specified index
+ const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
+ // When shrinking an axis, the end position does not matter (and can be
+ // incorrect when negative indexing is used, see Issue #19260). Always use
+ // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
+ // already been adjusted for negative indices.
+ if (shrink_axis) {
+ stop = start_for_axis + 1;
+ }
+
// end_mask override
if (end_mask & (1 << axis)) {
if (strides[axis] > 0) {
@@ -93,7 +103,7 @@ inline int StopForAxis(int end_mask, std::vector<IntType> const& stop_indices,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ const int axis_size = input_shape[axis];
if (stop < 0) {
stop += axis_size;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ce887cea8b..ee2af5b460 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_
+#include <complex>
#include <vector>
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -35,6 +36,11 @@ inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
}
template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
inline int32_t* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.i32 : nullptr;
}
@@ -49,6 +55,13 @@ inline bool* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
+template <>
+inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<std::complex<float>*>(tensor->data.c64)
+ : nullptr;
+}
+
template <typename T>
inline const T* GetTensorData(const TfLiteTensor* tensor);
@@ -63,6 +76,11 @@ inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
}
template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.i32 : nullptr;
}
@@ -77,6 +95,13 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
+template <>
+inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr
+ ? reinterpret_cast<const std::complex<float>*>(tensor->data.c64)
+ : nullptr;
+}
+
inline int RemapDim(int max_dimensions, int d) {
return max_dimensions - d - 1;
}
@@ -114,6 +139,19 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
return GetTensorDims(dims->data, dims->size);
}
+inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
+ return RuntimeShape(data.size(), data.data());
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ auto* dims = tensor->dims;
+ return RuntimeShape(dims->size, dims->data);
+}
+
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index 14ee528394..aa0d49ae4d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -63,7 +63,8 @@ TEST(uKernels, SymmetricQuantizeFloatsTest) {
EXPECT_EQ(min, -640);
EXPECT_EQ(max, 1000);
- EXPECT_NEAR(scaling_factor, 0.127, 1e-6); // EQ won't work due to fpoint.
+ // EQ won't work due to fpoint.
+ EXPECT_NEAR(scaling_factor, 1000 / 127.0, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-81, -81, -80, 1, 0, -1, -1, 0, 127}));
}
@@ -95,7 +96,7 @@ TEST(uKernels, SymmetricQuantizeFloatsAllAlmostZeroTest) {
EXPECT_NEAR(min, -9e-05, 1e-6);
EXPECT_NEAR(max, 0.0002, 1e-6);
- EXPECT_EQ(scaling_factor, 635000);
+ EXPECT_NEAR(scaling_factor, 1.57e-6, 1e-6);
EXPECT_THAT(output,
testing::ElementsAreArray({-6, 19, -4, -57, 1, 25, 6, 127, 0}));
}
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index fc8ed753c5..737cfb69c9 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -23,7 +23,73 @@ limitations under the License.
namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
-enum class PaddingType { kNone, kSame, kValid };
+enum class PaddingType : uint8 { kNone, kSame, kValid };
+
+struct PaddingValues {
+ int8 width;
+ int8 height;
+};
+
+// This enumeration allows for non-default formats for the weights array
+// of a fully-connected operator, allowing the use of special optimized
+// runtime paths.
+enum class FullyConnectedWeightsFormat : uint8 {
+ // Default format (flat 2D layout, the inner contiguous dimension
+ // is input_depth, the outer non-contiguous dimension is output_depth)
+ kDefault,
+ // Summary: optimized layout for fast CPU runtime implementation,
+ // aimed specifically at ARM CPUs at the moment, and specialized for
+ // 8-bit quantized layers.
+ //
+ // The use case we're concerned with here is: 8-bit quantization,
+ // large weights matrix that doesn't fit in cache (e.g. 4096x2048 in
+ // a key application that drove this), very small batch size (e.g. 1 -- 4).
+ //
+ // Even with 8-bit quantization of weights, the performance of memory
+ // accesses to the weights can become the dominant issue when
+ // the batch size is small, so each weight value is used in only a few
+ // arithmetic ops, i.e. the fully-connected node has a low arithmetic
+ // intensity. The specific issues that arise are of three kinds:
+ // (1) One may, ideally, max out DRAM bandwidth, i.e. be truly memory
+ // bound. That's the "good" issue to run into.
+ // (2) One may run into sub-optimal pre-fetching: the data hasn't been
+ // prefetched into the cache by the time we need it.
+ // (3) One may run into cache aliasing: multiple values that are
+ // pre-fetched, alias each other in the L1 cache (which typically
+ // has only 4-way set associativity in ARM CPUs) and thus evict
+ // each other before we get to using them.
+ //
+ // The point of this shuffling is to avoid issues (2) and (3) so that
+ // we get as fast as possible given only the hard constraint (1).
+ // This is achieved by turning the difficulty into a solution: the
+ // difficulty, that each value loaded from memory is used only in
+ // one kernel iteration, making this operation memory-intensive, hints at
+ // the solution, of shuffling the weights so that they are stored in the
+ // exact order as the kernel needs to load them, so that the memory
+ // accesses made by the kernel are trivial. This solves (2) because the
+ // trivial memory access pattern allows the CPU's automatic prefetching
+ // to perform very well (no need even for preload instructions), and this
+ // solves (3) because the values being loaded concurrently are now
+ // contiguous in the address space, thus don't alias each other in the cache.
+ //
+ // On ARM, we typically want our kernel to process a 4x16 block of weights
+ // at a time, because:
+ // - 16 is the number of bytes in a NEON register.
+ // - 4 is how many rows we need to handle concurrently in the kernel in
+ // order to have sufficient mutual independence of instructions to
+ // maximize arithmetic throughput.
+ //
+ // Finally, the 'Int8' part in the name refers to the fact that this
+ // weights format has each weights value encoded as a signed int8 value,
+ // even if the data type of the weights buffer is uint8. This is intended
+ // to save runtime kernels the effort to have to XOR the top bit of these
+ // bytes before using them in signed arithmetic, see this file for more
+ // explanations on the 'signed int8 trick' in matrix multiplication kernels:
+ //
+ // tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+ //
+ kShuffled4x16Int8,
+};
// Quantization parameters, determining the mapping of quantized values
// to real values (i.e. determining how quantized values are mathematically
@@ -65,6 +131,10 @@ class RuntimeShape {
ReplaceWith(dimensions_count, dims_data);
}
+ RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+ BuildFrom(init_list);
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
@@ -121,6 +191,10 @@ class RuntimeShape {
}
}
+ inline void BuildFrom(const std::initializer_list<int> init_list) {
+ BuildFrom<const std::initializer_list<int>>(init_list);
+ }
+
// Returns the total count of elements, that is the size when flattened into a
// vector.
inline int FlatSize() const {
@@ -142,6 +216,22 @@ class RuntimeShape {
};
};
+// Converts inference-style shape to legacy tflite::Dims<4>.
+inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
+ tflite::Dims<4> result;
+ const int dimensions_count = array_shape.DimensionsCount();
+ TFLITE_CHECK_LE(dimensions_count, 4);
+ int cum_prod = 1;
+ for (int i = 0; i < 4; i++) {
+ const int new_dim =
+ (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
+ result.sizes[i] = new_dim;
+ result.strides[i] = cum_prod;
+ cum_prod *= new_dim;
+ }
+ return result;
+}
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
TFLITE_DCHECK_GT(num_dims, 0);
@@ -194,6 +284,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
return offset;
}
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
+ TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
+ TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
+ TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
+ const int* dims_data = shape.DimsData();
+ return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
+
inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -208,6 +307,9 @@ inline int Offset(const Dims<4>& dims, int* index) {
}
// Get array size, DCHECKing that the dim index is in range.
+//
+// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
+// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
TFLITE_DCHECK(index >= 0 && index < N);
@@ -229,6 +331,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
return MatchingArraySize(array1, index1, args...);
}
+// Get common shape dim, DCHECKing that they all agree.
+inline int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return shape1.Dims(index1);
+}
+
+template <typename... Args>
+int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return MatchingDim(shape1, index1, args...);
+}
+
+// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
int flat_size = 1;
@@ -245,6 +362,50 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
// Flat size calculation, checking that dimensions match with one or more other
// arrays.
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return shape.FlatSize();
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSize(const RuntimeShape& shape,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ return MatchingFlatSize(shape, check_shape_1, check_shape_2, check_shape_3);
+}
+
+// Flat size calculation, checking that dimensions match with one or more other
+// arrays.
template <int N>
inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0) {
for (int i = 0; i < N; ++i) {
@@ -269,7 +430,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2);
}
template <int N>
@@ -280,7 +441,7 @@ inline int MatchingFlatSize(const Dims<N>& dims, const Dims<N>& check_dims_0,
for (int i = 0; i < N; ++i) {
TFLITE_DCHECK_EQ(ArraySize(dims, i), ArraySize(check_dims_0, i));
}
- return FlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
+ return MatchingFlatSize(dims, check_dims_1, check_dims_2, check_dims_3);
}
// Data is required to be contiguous, and so many operators can use either the
@@ -348,6 +509,72 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
check_dims_3);
}
+// Data is required to be contiguous, and so many operators can use either the
+// full array flat size or the flat size with one dimension skipped (commonly
+// the depth).
+inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
+ const int dims_count = shape.DimensionsCount();
+ TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
+ const auto* dims_data = shape.DimsData();
+ int flat_size = 1;
+ for (int i = 0; i < dims_count; ++i) {
+ flat_size *= (i == skip_dim) ? 1 : dims_data[i];
+ }
+ return flat_size;
+}
+
+// A combination of MatchingFlatSize() and FlatSizeSkipDim().
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return FlatSizeSkipDim(shape, skip_dim);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
+ check_shape_3);
+}
+
template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
int expected_stride = 1;
@@ -358,6 +585,30 @@ bool IsPackedWithoutStrides(const Dims<N>& dims) {
return true;
}
+template <int N>
+void ComputeStrides(Dims<N>* dims) {
+ dims->strides[0] = 1;
+ for (int d = 1; d < N; d++) {
+ dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1];
+ }
+}
+
+struct PoolParams {
+ FusedActivationFunctionType activation;
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int stride_height;
+ int stride_width;
+ int filter_height;
+ int filter_width;
+ // uint8, etc, inference params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float inference params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 184028427f..08f942c933 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -43,12 +43,11 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
return kTfLiteOk;
}
-void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
- TfLiteTensor* output, int32_t* act_min,
- int32_t* act_max) {
- const int32_t qmin = std::numeric_limits<uint8_t>::min();
- const int32_t qmax = std::numeric_limits<uint8_t>::max();
-
+namespace {
+void CalculateActivationRangeQuantizedImpl(TfLiteFusedActivation activation,
+ int32_t qmin, int32_t qmax,
+ TfLiteTensor* output,
+ int32_t* act_min, int32_t* act_max) {
const auto scale = output->params.scale;
const auto zero_point = output->params.zero_point;
@@ -70,23 +69,38 @@ void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
*act_max = qmax;
}
}
-
-void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
- float* activation_min,
- float* activation_max) {
- if (activation == kTfLiteActRelu) {
- *activation_min = 0.f;
- *activation_max = std::numeric_limits<float>::max();
- } else if (activation == kTfLiteActRelu6) {
- *activation_min = 0.f;
- *activation_max = 6.f;
- } else if (activation == kTfLiteActRelu1) {
- *activation_min = -1.f;
- *activation_max = 1.f;
+} // namespace
+
+TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
+ TfLiteFusedActivation activation,
+ TfLiteTensor* output,
+ int32_t* act_min,
+ int32_t* act_max) {
+ int32_t qmin = 0;
+ int32_t qmax = 0;
+ if (output->type == kTfLiteUInt8) {
+ qmin = std::numeric_limits<uint8_t>::min();
+ qmax = std::numeric_limits<uint8_t>::max();
+ } else if (output->type == kTfLiteInt16) {
+ qmin = std::numeric_limits<int16_t>::min();
+ qmax = std::numeric_limits<int16_t>::max();
} else {
- *activation_min = std::numeric_limits<float>::lowest();
- *activation_max = std::numeric_limits<float>::max();
+ TF_LITE_ENSURE(context, false);
}
+
+ CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min,
+ act_max);
+ return kTfLiteOk;
+}
+
+void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
+ TfLiteTensor* output, int32_t* act_min,
+ int32_t* act_max) {
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ CalculateActivationRangeQuantizedImpl(activation, qmin, qmax, output, act_min,
+ act_max);
}
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index 82cded36f2..c8ce3c917d 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -15,6 +15,8 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
+#include <algorithm>
+
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
@@ -86,14 +88,35 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context,
TfLiteTensor* output,
double* multiplier);
-// Calculates the useful range of an activation layer given its activation
-// tensor.
+// Calculates the useful quantized range of an activation layer given its
+// activation tensor.
+TfLiteStatus CalculateActivationRangeQuantized(TfLiteContext* context,
+ TfLiteFusedActivation activation,
+ TfLiteTensor* output,
+ int32_t* act_min,
+ int32_t* act_max);
void CalculateActivationRangeUint8(TfLiteFusedActivation activation,
TfLiteTensor* output, int32_t* act_min,
int32_t* act_max);
-void CalculateActivationRangeFloat(TfLiteFusedActivation activation,
- float* activation_min,
- float* activation_max);
+// Calculates the useful range of an activation layer given its activation
+// tensor.a
+template <typename T>
+void CalculateActivationRange(TfLiteFusedActivation activation,
+ T* activation_min, T* activation_max) {
+ if (activation == kTfLiteActRelu) {
+ *activation_min = 0;
+ *activation_max = std::numeric_limits<T>::max();
+ } else if (activation == kTfLiteActRelu6) {
+ *activation_min = 0;
+ *activation_max = 6;
+ } else if (activation == kTfLiteActRelu1) {
+ *activation_min = -1;
+ *activation_max = 1;
+ } else {
+ *activation_min = std::numeric_limits<T>::lowest();
+ *activation_max = std::numeric_limits<T>::max();
+ }
+}
// Return true if the given tensors have the same shape.
bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2);
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index 3205c1cc52..a7b54c6b84 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -70,8 +70,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
#define TF_LITE_L2NORM(type) \
type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(output), GetTensorDims(output))
+ GetTensorData<float>(input), GetTensorShape(input), \
+ GetTensorData<float>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +81,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorDims(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorDims(output))
+#define TF_LITE_L2NORM(type) \
+ type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
+ input->params.zero_point, \
+ GetTensorData<uint8>(output), GetTensorShape(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 62820a2f51..9a8d35e82c 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -90,10 +90,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::LogSoftmax(input_buffer, input_dims,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::LogSoftmax(input_buffer, input_shape,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 990b3da055..3577ae6caa 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -24,7 +24,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -34,6 +37,20 @@ namespace ops {
namespace builtin {
namespace lstm {
+struct OpData {
+ // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // (5 inputs).
+ TfLiteLSTMKernelType kernel_type;
+
+ // These fields are only used by full kernel.
+ int activation_state_tensor_index;
+ int cell_state_tensor_index;
+ int scratch_tensor_index;
+};
+
+// For full inputs kernel (18 or 20 inputs).
+namespace full {
+
// Input Tensors of size {n_batch, n_input}
constexpr int kInputTensor = 0;
@@ -65,26 +82,33 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// If the node has 20 inputs, the following 2 tensors are used as state tensors.
+// These are defined as variable tensors, and will be modified by this op.
+constexpr int kInputActivationStateTensor = 18;
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
+// * If the node has 18 inputs, these 2 tensors are used as state tensors.
+// * If the node has 20 inputs, these 2 tensors are ignored.
+// TODO(ycling): Make the 2 output state tensors optional, and propagate the
+// state to output tensors when the 2 tensors present.
constexpr int kOutputStateTensor = 0;
constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
- return scratch_tensor_index;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
+ auto* op_data = new OpData;
+ op_data->kernel_type = kTfLiteLSTMFullKernel;
+ context->AddTensors(context, /*tensors_to_add=*/7,
+ &op_data->scratch_tensor_index);
+ return op_data;
}
// Check that input tensor dimensions matches with each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -94,7 +118,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- if (input_to_input_weights) {
+ if (input_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
@@ -114,7 +138,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* recurrent_to_input_weights =
GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- if (recurrent_to_input_weights) {
+ if (recurrent_to_input_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
n_cell);
@@ -204,7 +228,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_weights =
GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- if (projection_weights) {
+ if (projection_weights != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
@@ -212,7 +236,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
const TfLiteTensor* projection_bias =
GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- if (projection_bias) {
+ if (projection_bias != nullptr) {
TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
}
@@ -233,15 +257,37 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- // Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ // True if the node is using input variable state tensors. It means:
+ // * The state tensors are defined as inputs. In this case it would be the
+ // 19th and 20th input tensors.
+ // * Otherwise, the output tensors are used to store states.
+ bool use_input_variable_states;
+ if (node->inputs->size == 20) {
+ use_input_variable_states = true;
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index =
+ node->inputs->data[kInputCellStateTensor];
+ } else if (node->inputs->size == 18) {
+ use_input_variable_states = false;
+ op_data->activation_state_tensor_index =
+ node->outputs->data[kOutputStateTensor];
+ op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
+ } else {
+ context->ReportError(
+ context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
+ node->inputs->size);
+ return kTfLiteError;
+ }
+
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
@@ -262,110 +308,185 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check that input tensor dimensions matches with each other.
CheckInputTensorDimensions(context, node, n_input, n_output, n_cell);
- // Get the pointer to output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
+ if (use_input_variable_states) {
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
+ n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ } else {
+ // If the state tensors are outputs, this function takes the
+ // responsibility to resize the state tensors.
+ TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
+ activation_state_size->data[0] = n_batch;
+ activation_state_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
+ activation_state_size));
+
+ TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+ cell_size->data[0] = n_batch;
+ cell_size->data[1] = n_cell;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state, cell_size));
+ // Mark state tensors as persistent tensors.
+ activation_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+ }
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
output_size->data[0] = n_batch;
output_size->data[1] = n_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
// Create a scratch buffer tensor.
- TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
// The LSTM Op engine.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
const int n_batch = input->dims->data[0];
const int n_input = input->dims->data[1];
// n_cell and n_output will be the same size when there is no projection.
@@ -377,9 +498,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
-
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -428,7 +546,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
float* output_ptr_batch = output->data.f;
@@ -441,12 +559,493 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, output_ptr_batch);
+ activation_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+ TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ kernel_utils::LstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[op_data->activation_state_tensor_index];
+ TfLiteTensor* cell_state =
+ &context->tensors[op_data->cell_state_tensor_index];
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ // TODO(mirkov): add a check that weights are all uint8s or all floats.
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params,
+ scratch_buffer, activation_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params, scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace full
+
+// For basic kernel (5-inputs).
+namespace basic {
+
+enum InputTensor {
+ kInputData = 0,
+ kInputPrevActivation = 1,
+ kInputWeights = 2,
+ kInputBiases = 3,
+ kInputPrevState = 4,
+ kInputNum = 5,
+};
+
+enum OutputTensor {
+ kOutputActivation = 0,
+ kOutputState = 1,
+ kOutputConcatTemp = 2,
+ kOutputActivationTemp = 3,
+ kOutputNum = 4,
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ op_data->kernel_type = kTfLiteLSTMBasicKernel;
+ // `scratch_tensor_index` is unused in this kernel.
+ op_data->scratch_tensor_index = -1;
+ return op_data;
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
+ TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
+ const int num_batches = input->dims->data[0];
+ const int input_depth = input->dims->data[1];
+
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
+ const int activation_depth = prev_activation->dims->data[1];
+ const int total_depth = input_depth + activation_depth;
+
+ TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
+ TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);
+
+ TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);
+
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
+ TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(
+ context, activation_out,
+ TfLiteIntArrayCopy(prev_activation->dims)));
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, state_out,
+ TfLiteIntArrayCopy(prev_state->dims)));
+ TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
+ concat_temp_size->data[0] = num_batches;
+ concat_temp_size->data[1] = total_depth;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, concat_temp, concat_temp_size));
+ TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
+ activation_temp_size->data[0] = num_batches;
+ activation_temp_size->data[1] = 4 * activation_depth;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
+ activation_temp_size));
+
+ // Set the state tensors as persistent.
+ for (auto index : {kInputPrevActivation, kInputPrevState}) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ tensor->allocation_type = kTfLiteArenaRwPersistent;
+ }
return kTfLiteOk;
}
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputData);
+ const TfLiteTensor* prev_activation =
+ GetInput(context, node, kInputPrevActivation);
+ const TfLiteTensor* weights = GetInput(context, node, kInputWeights);
+ const TfLiteTensor* bias = GetInput(context, node, kInputBiases);
+ const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState);
+
+ TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation);
+ TfLiteTensor* state_out = GetOutput(context, node, kOutputState);
+ TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp);
+ TfLiteTensor* activation_temp =
+ GetOutput(context, node, kOutputActivationTemp);
+
+ if (input->type == kTfLiteFloat32 &&
+ prev_activation->type == kTfLiteFloat32 &&
+ weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
+ prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 &&
+ activation_out->type == kTfLiteFloat32 &&
+ concat_temp->type == kTfLiteFloat32 &&
+ activation_temp->type == kTfLiteFloat32) {
+ optimized_ops::LstmCell(
+ // Inputs.
+ GetTensorData<float>(input), GetTensorDims(input),
+ GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
+ GetTensorData<float>(weights), GetTensorDims(weights),
+ GetTensorData<float>(bias), GetTensorDims(bias),
+ GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ // Outputs.
+ GetTensorData<float>(state_out), GetTensorDims(state_out),
+ GetTensorData<float>(activation_out), GetTensorDims(activation_out),
+ GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
+ GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+ } else if (input->type == kTfLiteUInt8 &&
+ prev_activation->type == kTfLiteUInt8 &&
+ weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
+ prev_state->type == kTfLiteInt16 &&
+ state_out->type == kTfLiteInt16 &&
+ activation_out->type == kTfLiteUInt8 &&
+ concat_temp->type == kTfLiteUInt8 &&
+ activation_temp->type == kTfLiteInt16) {
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);
+ int state_scale_log2_rounded;
+ if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
+ context->ReportError(
+ context,
+ "The internal state of a LSTM cell must have a power-of-two scale.");
+ return kTfLiteError;
+ }
+ const int state_integer_bits = 15 + state_scale_log2_rounded;
+ if (state_integer_bits != 4) {
+ context->ReportError(context,
+ "The only case of quantized LstmCell currently "
+ "supported is with StateIntegerBits==4");
+ return kTfLiteError;
+ }
+
+ double real_accum_multiplier = 4096 * bias->params.scale;
+ int32 accum_multiplier;
+ int accum_shift;
+ tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
+ &accum_shift);
+ optimized_ops::LstmCell<4>(
+ // Inputs.
+ GetTensorData<uint8_t>(input), GetTensorDims(input),
+ GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation),
+ GetTensorData<uint8_t>(weights), GetTensorDims(weights),
+ GetTensorData<int32_t>(bias), GetTensorDims(bias),
+ GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state),
+ // Outputs.
+ GetTensorData<int16_t>(state_out), GetTensorDims(state_out),
+ GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out),
+ GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp),
+ GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp),
+ weights->params.zero_point, accum_multiplier, accum_shift,
+ gemm_context);
+ } else {
+ context->ReportError(context,
+ "Unsupported combination of data types for LstmCell");
+ return kTfLiteError;
+ }
+
+ // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs
+ // LSTM kernel.
+ memcpy(prev_activation->data.raw, activation_out->data.raw,
+ activation_out->bytes);
+ memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);
+
+ return kTfLiteOk;
+}
+
+} // namespace basic
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ gemm_support::IncrementUsageCounter(context);
+
+ const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
+ switch (params->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Init(context, buffer, length);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Init(context, buffer, length);
+ }
+}
+void Free(TfLiteContext* context, void* buffer) {
+ gemm_support::DecrementUsageCounter(context);
+
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Prepare(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Prepare(context, node);
+ }
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* op_data = reinterpret_cast<const OpData*>(node->user_data);
+ switch (op_data->kernel_type) {
+ case kTfLiteLSTMFullKernel:
+ return full::Eval(context, node);
+ case kTfLiteLSTMBasicKernel:
+ return basic::Eval(context, node);
+ }
+}
+
} // namespace lstm
TfLiteRegistration* Register_LSTM() {
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index d81220d8d3..0b7c56133e 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite LSTM op.
-#include <iomanip>
#include <memory>
#include <vector>
@@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel {
LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg,
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel {
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
- input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_input_weights_ = AddInput(weight_type);
}
- input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
- recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_input_weights_ = AddInput(weight_type);
}
- recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
- cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_input_weights_ = AddInput(weight_type);
}
- cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
@@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel {
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
- projection_weights_ = AddInput(TensorType_FLOAT32);
+ projection_weights_ = AddInput(weight_type);
if (use_projection_bias) {
projection_bias_ = AddInput(TensorType_FLOAT32);
} else {
@@ -97,6 +97,12 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -192,8 +198,9 @@ class LSTMOpModel : public SingleOpModel {
zero_buffer.get() + zero_buffer_size);
}
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -203,7 +210,7 @@ class LSTMOpModel : public SingleOpModel {
int num_cells() { return n_cell_; }
int num_batches() { return n_batch_; }
- private:
+ protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
@@ -226,6 +233,8 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
int output_state_;
@@ -237,7 +246,174 @@ class LSTMOpModel : public SingleOpModel {
int n_output_;
};
-TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+class HybridLSTMOpModel : public LSTMOpModel {
+ public:
+ HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole,
+ use_projection_weights, use_projection_bias, cell_clip,
+ proj_clip, input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> lstm_input_;
+ // LSTM output is stored as num_batch x num_outputs vector.
+ std::vector<std::vector<float>> lstm_golden_output_;
+
+ // Compares output up to tolerance to the result of the lstm given the input.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LSTMOpModel* lstm, float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end);
+ }
+
+ lstm->Invoke();
+
+ const int num_outputs = lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524};
+ input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113, -0.29909778};
+ input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212};
+ input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077,
+ -0.1556896, 0.19487578};
+ input_gate_bias_ = {0., 0., 0., 0.};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_input_weights_ = {
+ -0.0063535, -0.2042388, 0.31454784, -0.35746509,
+ 0.28902304, 0.08183324, -0.16555229, 0.02286911,
+ -0.13566875, 0.03034258, 0.48091322, -0.12528998,
+ 0.24077177, -0.51332325, -0.33502164, 0.10629296};
+
+ recurrent_to_cell_weights_ = {
+ -0.3407414, 0.24443203, -0.2078532, 0.26320225,
+ 0.05695659, -0.00123841, -0.4744786, -0.35869038,
+ -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064};
+
+ recurrent_to_forget_weights_ = {
+ -0.48684245, -0.06655136, 0.42224967, 0.2112639,
+ 0.27654213, 0.20864892, -0.07646349, 0.45877004,
+ 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004};
+
+ recurrent_to_output_weights_ = {
+ 0.43385774, -0.17194885, 0.2718237, 0.09215671,
+ 0.24107647, -0.39835793, 0.18212086, 0.01301402,
+ 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
+ }
+};
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -257,10 +433,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
+ {n_cell, n_output}, // recurrent_to_input_weight_tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight_tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight_tensor
+ {n_cell, n_output}, // recurrent_to_output_weight_tensor
{0}, // cell_to_input_weight tensor
{0}, // cell_to_forget_weight tensor
@@ -275,79 +451,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
-
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
-
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
-
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
-
- lstm.SetInputGateBias({0., 0., 0., 0.});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetCellBias({0., 0., 0., 0.});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetOutputGateBias({0., 0., 0., 0.});
-
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
-
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
+ /*tolerance=*/0.0157651);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+ 0.05100781, 0.04717243, 0.48944736,
+ -0.38535351, -0.17212132};
- lstm.Invoke();
+ input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698,
+ 0.24407166, 0.33826375};
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_cell_weights_ = {
+ 0.54066205, -0.32668582, -0.43562764, -0.56094903,
+ 0.42957711, 0.01841056, -0.32764608, -0.33027974,
+ -0.10826075, 0.20675004, 0.19069612, -0.03026325,
+ -0.54532051, 0.33003211, 0.44901288, 0.21193194};
+
+ recurrent_to_forget_weights_ = {
+ -0.13832897, -0.0515101, -0.2359007, -0.16661474,
+ -0.14340827, 0.36986142, 0.23414481, 0.55899,
+ 0.10798943, -0.41174671, 0.17751795, -0.34484994,
+ -0.35874045, -0.11352962, 0.27268326, 0.54058349};
+
+ recurrent_to_output_weights_ = {
+ 0.41613156, 0.42610586, -0.16495961, -0.5663873,
+ 0.30579174, -0.05115908, -0.33941799, 0.23364776,
+ 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987};
+
+ cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+ 0.31544167};
+ cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+ -0.77109635};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302}};
}
-}
+};
-TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -385,74 +619,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
-
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
-
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
-
- lstm.SetCellBias({0., 0., 0., 0.});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- const int input_sequence_size =
- sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- lstm.Invoke();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
+}
- float* golden_start = lstm_golden_output + i * lstm.num_outputs();
- float* golden_end = golden_start + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {
+ 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
+
+ input_to_forget_weights_ = {
+ -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
+ 0.028273236, -0.016726194, -0.05249759, -0.10204261,
+ 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567,
+ -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465,
+ 0.045630626, 0.06843887, -0.13492945, -0.012480007,
+ -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997,
+ 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294,
+ -0.06277783, -0.08833636, -0.040120605, -0.011405586,
+ -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839,
+ 0.13396506, -0.08402166, -0.01901462, -0.044678304,
+ -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491,
+ 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179,
+ -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
+ -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
+
+ input_to_cell_weights_ = {
+ -0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
+
+ input_to_output_weights_ = {
+ -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
+
+ input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
+ 0.053110216, -0.06928846, -0.13942584, -0.11816189,
+ 0.19483899, 0.03652339, -0.10250295, 0.036714908,
+ -0.18426876, 0.036065217, 0.21810818, 0.02383196,
+ -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
+
+ forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739};
+
+ cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027};
+
+ output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
+ 0.027195795, 0.35373217, -0.018957434, 0.008907322,
+ -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
+ 0.040952638, 0.3147856, 0.08225149, -0.057416286,
+ -0.14995944, -0.008040261, 0.13208859, 0.029760877};
+
+ recurrent_to_input_weights_ = {
+ -0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447};
+
+ recurrent_to_cell_weights_ = {
+ -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
+
+ recurrent_to_forget_weights_ = {
+ -0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027};
+
+ recurrent_to_output_weights_ = {
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812,
+ -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
+ -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
+ -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
+ -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307,
+ -0.08829125, -0.005139627, -0.08989442, -0.0555066,
+ 0.13596267, -0.025062224, -0.048351806, -0.03850004,
+ 0.07266485, -0.022414139, 0.05940088, 0.075114764,
+ 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916,
+ 0.014416728, 0.043229222, 0.034178585, -0.07530371,
+ 0.035837382, -0.085607, -0.007721233, -0.03287832,
+ -0.043848954, -0.06404588, -0.06632928, -0.073643476,
+ 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091,
+ -0.030063879, 0.008801774, -0.023021035, -0.019558564,
+ 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
+ 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
+ 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699,
+ -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
+ 0.10522024, -0.032441203, 0.008176899, -0.04454919,
+ 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
+ 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724,
+ 0.036879618, 0.043357447, 0.028362012, -0.05908629,
+ 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
+ -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
+ -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454,
+ -0.009660886, 0.019076364, 0.018299393, -0.046004917,
+ 0.08891175, 0.0431396, -0.026327137, -0.051502608,
+ 0.08979574, -0.051670972, 0.04940282, -0.07491107,
+ -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988,
+ -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
+ -0.021745546, -0.043124277, -0.06471268, -0.07053354,
+ -0.029321948, -0.05330136, 0.016933719, -0.053782392,
+ 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434,
+ -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
+ -0.037325785, -0.07251102, -0.033633437, -0.08677009,
+ 0.091591336, -0.14165086, 0.021752775, 0.019683983,
+ 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
+ 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
+ -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
+ 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
+ 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798,
+ -0.010036754, 0.02576849, -0.08307328, 0.010112348,
+ 0.042521734, -0.05869831, -0.071689695, 0.03876447,
+ -0.13275425, -0.0352966, -0.023077697, 0.10285965,
+ 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767,
+ -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
+ 0.042672627, 0.013456989, -0.0437609, -0.022309763,
+ 0.11576483, 0.04108048, 0.061026827, -0.0190714,
+ -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893,
+ -0.023771819, -0.01965048, 0.007955471, -0.043740474,
+ 0.03346837, -0.10549954, 0.090567775, 0.042013682,
+ -0.03176985, 0.12569028, -0.02421228, -0.029526481,
+ 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367,
+ -0.06861939, -0.021256343, -0.041093912, -0.06669611,
+ 0.035498552, 0.021757556, -0.09302526, -0.015403468,
+ -0.06614931, -0.051798206, -0.013874718, 0.03630673,
+ 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424,
+ -0.020674974, -0.03944324, -0.008110165, -0.11113267,
+ 0.08484226, 0.043586485, 0.040582247, 0.0968012,
+ -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
+ 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844,
+ -0.11768019, 0.085926116, -0.08251791, -0.045081906,
+ 0.0948852, 0.068401024, 0.024856757, 0.06978981,
+ -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
+ -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ };
+
+ cell_to_input_weights_ = {
+ 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
+
+ cell_to_forget_weights_ = {
+ -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
+
+ cell_to_output_weights_ = {
+ 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
+
+ projection_weights_ = {
+ -0.009802181, 0.09401916, 0.0717386, -0.13895074,
+ 0.09641832, 0.060420845, 0.08539281, 0.054285463,
+ 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018,
+ -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731,
+ 0.08633481, -0.28869787, 0.08682067, 0.17240396,
+ 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434,
+ 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664,
+ 0.13221376, 0.009441198, -0.16715929, 0.15859416,
+ -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
+ 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
+ -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461,
+ -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946,
+ -0.12035478, -0.08361797, -0.050666027, -0.1248618,
+ -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444,
+ 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581,
+ 0.13708383, 0.09134148, -0.12617786, -0.06428341,
+ 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063,
+ 0.022913318, -0.042050496, 0.16842307, -0.060597885,
+ 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757,
+ 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501,
+ -0.02328883, 0.009533515, -0.03606021, -0.07421458,
+ -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874,
+ 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737,
+ 0.1776665, -0.07409566, -0.0477268, 0.29323658,
+ 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
+ 0.10034707, 0.045594677, 0.0635285, -0.0715442,
+ -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807,
+ 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
+ 0.09663576, -0.057112187, -0.10100678, 0.0628376,
+ 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151,
+ -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982,
+ 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
+ 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895,
+ -0.07468996, -0.0855457, 0.099339016, -0.07580735,
+ -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874,
+ 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532,
+ 0.09793732, 0.07512415, -0.11319543, -0.032502122,
+ 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725,
+ -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008,
+ 0.11156489, -0.07483202, 0.06693364, -0.26151478,
+ 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075,
+ 0.030635102, 0.010969227, 0.11109743, 0.010919218,
+ 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
+ -0.07000971, 0.026893595, -0.038844477, 0.14543656};
+
+ lstm_input_ = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
+ 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
+ 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
+ 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
+ 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
+ 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
+ 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
+ };
+
+ lstm_golden_output_ = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
}
-}
+};
-TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -489,588 +1338,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- const int input_sequence_size =
- sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs());
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end);
+ HybridLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
- lstm.Invoke();
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
index 0752aa1804..fd4d5367c5 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum_test.cc
@@ -126,10 +126,10 @@ TEST(MaximumOpTest, FloatWithBroadcastTest) {
TEST(MaximumOpTest, Int32WithBroadcastTest) {
std::initializer_list<int32_t> data1 = {1, 0, -1, -2, 3, 11};
std::initializer_list<int32_t> data2 = {2};
- TestModel<int32>(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}},
+ TestModel<int32_t>(BuiltinOperator_MAXIMUM, {TensorType_INT32, {3, 1, 2}},
{TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}},
data1, data2, {2, 2, 2, 2, 3, 11});
- TestModel<int32>(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}},
+ TestModel<int32_t>(BuiltinOperator_MINIMUM, {TensorType_INT32, {3, 1, 2}},
{TensorType_INT32, {1}}, {TensorType_INT32, {3, 1, 2}},
data1, data2, {1, 0, -1, -2, 2, 2});
}
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 62f4e94a38..1f72f3a3c7 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+ // Parameters used in the quantized paths where the output is 8bit
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+ // Parameters used in all quantized paths
+ int32_t output_multiplier;
+ int output_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
- output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
@@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
+ }
+
return context->ResizeTensor(context, output, output_size);
}
@@ -83,8 +105,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
#define TF_LITE_MUL(type, opname) \
type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
GetTensorData<float>(input2), GetTensorDims(input2), \
@@ -107,41 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
-
- int32_t output_multiplier;
- int output_shift;
-
- double real_multiplier =
- input1->params.scale * input2->params.scale / output->params.scale;
- QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier,
- &output_shift);
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+ if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), -input2->params.zero_point, \
+ output->params.zero_point, data->output_multiplier, \
+ data->output_shift, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
GetTensorDims(output));
- // The quantized version of Mul doesn't support activations, so we
- // always use BroadcastMul.
- if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteInt16) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ GetTensorData<int16_t>(output), GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ output->params.zero_point, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
} else {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ context->ReportError(
+ context, "Unsupported combination of input and output types in Mul.");
+ return kTfLiteError;
}
-#undef TF_LITE_MUL
+ return kTfLiteOk;
}
template <KernelType kernel_type>
@@ -155,12 +196,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
- } else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(
+ context, EvalQuantized<kernel_type>(context, node, params, data, input1,
+ input2, output));
} else {
context->ReportError(
- context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.",
+ context,
+ "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.",
output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc
index f1a30f8263..43d56e50d2 100644
--- a/tensorflow/contrib/lite/kernels/mul_test.cc
+++ b/tensorflow/contrib/lite/kernels/mul_test.cc
@@ -58,6 +58,9 @@ class FloatMulOpModel : public BaseMulOpModel {
const float kQuantizedStep = 2.0 / 255.0;
const float kQuantizedTolerance =
2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep;
+const float kQuantizedStepInt16 = 2.0 / 32767.0;
+const float kQuantizedToleranceInt16 =
+ 2.0 * kQuantizedStepInt16 + kQuantizedStepInt16 * kQuantizedStepInt16;
class QuantizedMulOpModel : public BaseMulOpModel {
public:
@@ -67,6 +70,11 @@ class QuantizedMulOpModel : public BaseMulOpModel {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
GetScale(output_), GetZeroPoint(output_));
}
+
+ std::vector<float> GetDequantizedOutputInt16() {
+ return Dequantize<int16_t>(ExtractVector<int16_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
};
TEST(FloatMulOpTest, NoActivation) {
@@ -138,6 +146,38 @@ TEST(QuantizedMulOpTest, NoActivation) {
kQuantizedTolerance)));
}
+TEST(QuantizedMulOpTest, NoActivationInt16) {
+ const float kMin = -1.f;
+ const float kMax = 32767.f / 32768.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {1, 2, 2, 1}, kMin, kMax},
+ {TensorType_INT16, {}, kMin, kMax},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutputInt16(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedToleranceInt16)));
+}
+
+TEST(QuantizedMulOpTest, NoActivationInt16WithUint8Output) {
+ const float kMinInt16 = -1.f;
+ const float kMaxInt16 = 32767.f / 32768.f;
+ const float kMinUint8 = -1.f;
+ const float kMaxUint8 = 127.f / 128.f;
+ QuantizedMulOpModel m({TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_INT16, {1, 2, 2, 1}, kMinInt16, kMaxInt16},
+ {TensorType_UINT8, {}, kMinUint8, kMaxUint8},
+ ActivationFunctionType_NONE);
+ m.QuantizeAndPopulate<int16_t>(m.input1(), {-0.8, 0.2, 0.9, 0.7});
+ m.QuantizeAndPopulate<int16_t>(m.input2(), {0.6, 0.4, 0.9, 0.8});
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56},
+ kQuantizedTolerance)));
+}
+
// for quantized Mul, the error shouldn't exceed 2*step
float GetTolerance(int min, int max) {
float kQuantizedStep = (max - min) / 255.0;
diff --git a/tensorflow/contrib/lite/kernels/neg_test.cc b/tensorflow/contrib/lite/kernels/neg_test.cc
index 3c95ac8cc2..3d3594c60b 100644
--- a/tensorflow/contrib/lite/kernels/neg_test.cc
+++ b/tensorflow/contrib/lite/kernels/neg_test.cc
@@ -58,9 +58,9 @@ TEST(NegOpModel, NegFloat) {
TEST(NegOpModel, NegInt32) {
NegOpModel m({TensorType_INT32, {2, 3}}, {TensorType_INT32, {2, 3}});
- m.SetInput<int32>({-2, -1, 0, 1, 2, 3});
+ m.SetInput<int32_t>({-2, -1, 0, 1, 2, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput<int32>(), ElementsAreArray({2, 1, 0, -1, -2, -3}));
+ EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({2, 1, 0, -1, -2, -3}));
}
TEST(NegOpModel, NegInt64) {
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index bcad58406a..1c728a4733 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -95,6 +95,12 @@ class LSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
output_state_ = AddOutput(TensorType_FLOAT32);
cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -228,6 +234,8 @@ class LSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
int output_;
int output_state_;
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 83668cb4ca..4be8c243c1 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -128,7 +128,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(nupurgarg): Change kernel implementation to use padding arrays in
// forward order (depth, width, height, batch).
// Build paddings in order of int[] = {batch, height, width, depth} to match
- // kernel implementation of Pad in referenced_ops.h and optimized_ops.h.
+ // kernel implementation of Pad in reference_ops.h and optimized_ops.h.
for (int idx = op_context.dims - 1; idx >= 0; --idx) {
before_padding.push_back(paddings_data[idx * 2]);
after_padding.push_back(paddings_data[idx * 2 + 1]);
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 311e9b8399..9b0487ae16 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -80,24 +80,24 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto computeOutSize = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size,
+ int stride) -> int {
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - filter_size + stride) / stride
: 0;
};
- int outWidth =
- computeOutSize(width, params->filter_width, params->stride_width);
- int outHeight =
- computeOutSize(height, params->filter_height, params->stride_height);
+ int out_width =
+ compute_out_size(width, params->filter_width, params->stride_width);
+ int out_height =
+ compute_out_size(height, params->filter_height, params->stride_height);
data->padding.height = ComputePadding(params->stride_height, 1, height,
- params->filter_height, outHeight);
+ params->filter_height, out_height);
data->padding.width = ComputePadding(params->stride_width, 1, width,
- params->filter_width, outWidth);
+ params->filter_width, out_width);
if (input->type == kTfLiteUInt8) {
if (pool_type == kAverage || pool_type == kMax) {
@@ -111,12 +111,12 @@ TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
}
}
- TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
- outputSize->data[0] = batches;
- outputSize->data[1] = outHeight;
- outputSize->data[2] = outWidth;
- outputSize->data[3] = channels_out;
- return context->ResizeTensor(context, output, outputSize);
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+ output_size->data[0] = batches;
+ output_size->data[1] = out_height;
+ output_size->data[2] = out_width;
+ output_size->data[3] = channels_out;
+ return context->ResizeTensor(context, output, output_size);
}
template <KernelType kernel_type>
@@ -124,14 +124,21 @@ void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -148,13 +155,19 @@ void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_AVERAGE_POOL(type) \
- type::AveragePool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, \
- activation_min, activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output))
+#define TF_LITE_AVERAGE_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::AveragePool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_AVERAGE_POOL(reference_ops);
} else {
@@ -168,14 +181,20 @@ void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
#define TF_LITE_MAX_POOL(type) \
- type::MaxPool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -192,13 +211,19 @@ void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t activation_max;
CalculateActivationRangeUint8(params->activation, output, &activation_min,
&activation_max);
-#define TF_LITE_MAX_POOL(type) \
- type::MaxPool(GetTensorData<uint8_t>(input), GetTensorDims(input), \
- params->stride_width, params->stride_height, \
- data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output))
+#define TF_LITE_MAX_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.quantized_activation_min = activation_min; \
+ op_params.quantized_activation_max = activation_max; \
+ type::MaxPool(op_params, GetTensorShape(input), \
+ GetTensorData<uint8_t>(input), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MAX_POOL(reference_ops);
} else {
@@ -212,14 +237,20 @@ void L2EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLitePoolParams* params, OpData* data,
const TfLiteTensor* input, TfLiteTensor* output) {
float activation_min, activation_max;
- CalculateActivationRangeFloat(params->activation, &activation_min,
- &activation_max);
-#define TF_LITE_L2_POOL(type) \
- type::L2Pool( \
- GetTensorData<float>(input), GetTensorDims(input), params->stride_width, \
- params->stride_height, data->padding.width, data->padding.height, \
- params->filter_width, params->filter_height, activation_min, \
- activation_max, GetTensorData<float>(output), GetTensorDims(output))
+ CalculateActivationRange(params->activation, &activation_min,
+ &activation_max);
+#define TF_LITE_L2_POOL(type) \
+ tflite::PoolParams op_params; \
+ op_params.stride_height = params->stride_height; \
+ op_params.stride_width = params->stride_width; \
+ op_params.filter_height = params->filter_height; \
+ op_params.filter_width = params->filter_width; \
+ op_params.padding_values.height = data->padding.height; \
+ op_params.padding_values.width = data->padding.width; \
+ op_params.float_activation_min = activation_min; \
+ op_params.float_activation_max = activation_max; \
+ type::L2Pool(op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2_POOL(reference_ops);
} else {
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
new file mode 100644
index 0000000000..4a539c47a8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -0,0 +1,143 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pow {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for pow op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
+ context->ReportError(context, "Unsupported data type %d.", type);
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output, bool requires_broadcast) {
+ if (requires_broadcast) {
+ reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
+ GetTensorData<T>(input2), GetTensorDims(input2),
+ GetTensorData<T>(output),
+ GetTensorDims(output));
+ } else {
+ reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
+ GetTensorData<T>(input2), GetTensorDims(input2),
+ GetTensorData<T>(output), GetTensorDims(output));
+ }
+}
+
+TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
+ const int64_t num_elements = NumElements(input);
+ const int32_t* data = GetTensorData<int32_t>(input);
+ for (int i = 0; i < num_elements; ++i) {
+ if (data[i] < 0) {
+ context->ReportError(context,
+ "POW does not support negative value for int32.");
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (output->type) {
+ case kTfLiteInt32: {
+ // TensorFlow does not support negative for int32.
+ TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
+ PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ case kTfLiteFloat32: {
+ PowImpl<float>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ default: {
+ context->ReportError(context, "Unsupported data type: %d", output->type);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace pow
+
+TfLiteRegistration* Register_POW() {
+ static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc
new file mode 100644
index 0000000000..474d323bc3
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class PowOpModel : public SingleOpModel {
+ public:
+ PowOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions,
+ CreatePowOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(PowOpModel, Simple) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8));
+}
+
+TEST(PowOpModel, NegativeAndZeroValue) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1));
+}
+
+TEST(PowOpModel, Float) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, 2.7, 3.1, 3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3)));
+}
+
+TEST(PowOpModel, NegativeFloatTest) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, -2.7, 3.1, -3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3)));
+}
+
+TEST(PowOpModel, BroadcastTest) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32>(model.input2(), {4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 03e5db24de..31c331a8c6 100644
--- a/tensorflow/contrib/lite/kernels/mean.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -25,21 +25,21 @@ limitations under the License.
namespace tflite {
namespace ops {
namespace builtin {
-namespace mean {
+namespace reduce {
-// This file has reference implementation of Mean.
+// This file has reference implementation of reduce_* operators.
enum KernelType {
kReference,
};
-struct MeanContext {
- MeanContext(TfLiteContext* context, TfLiteNode* node) {
- params = reinterpret_cast<TfLiteMeanParams*>(node->builtin_data);
+struct OpContext {
+ OpContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
input = GetInput(context, node, 0);
axis = GetInput(context, node, 1);
output = GetOutput(context, node, 0);
}
- TfLiteMeanParams* params;
+ TfLiteReducerParams* params;
const TfLiteTensor* input;
const TfLiteTensor* axis;
TfLiteTensor* output;
@@ -58,7 +58,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
// Resizes the temp tensor that stores resolved axis.
-TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
+TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context,
TfLiteTensor* resolved_axis) {
TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
@@ -66,7 +66,7 @@ TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context,
}
// Resizes the temp tensor that stores temp sum of reduced elements.
-TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context,
+TfLiteStatus ResizeTempSum(TfLiteContext* context, OpContext* op_context,
TfLiteTensor* temp_sum) {
TfLiteIntArray* size = TfLiteIntArrayCreate(1);
size->data[0] = static_cast<int>(NumElements(op_context->output));
@@ -74,8 +74,7 @@ TfLiteStatus ResizeTempSum(TfLiteContext* context, MeanContext* op_context,
}
// Resizes output array based on the input size and resolved axis.
-TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
- MeanContext* op_context) {
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
size_t num_axis = NumElements(op_context->axis);
const TfLiteIntArray* input_dims = op_context->input->dims;
int input_num_dims = NumDimensions(op_context->input);
@@ -140,7 +139,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
// Initializes temp tensors to store index and resolved axis.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
- MeanContext* op_context) {
+ OpContext* op_context) {
// Creates a temp index to iterate through input data.
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
TfLiteIntArrayFree(node->temporaries);
@@ -180,33 +179,44 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
return kTfLiteOk;
}
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- MeanContext op_context(context, node);
+ OpContext op_context(context, node);
TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
// Leaves work to Eval if axis is not constant; else resizes output.
if (!IsConstantTensor(op_context.axis)) {
SetTensorToDynamic(op_context.output);
SetTensorToDynamic(resolved_axis);
- SetTensorToDynamic(temp_sum);
return kTfLiteOk;
}
resolved_axis->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context,
ResizeTempAxis(context, &op_context, resolved_axis));
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ return kTfLiteOk;
+}
+
+TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
+
+ // reduce_mean requires a buffer to store intermediate sum result.
+ OpContext op_context(context, node);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ if (!IsConstantTensor(op_context.axis)) {
+ SetTensorToDynamic(temp_sum);
+ return kTfLiteOk;
+ }
temp_sum->allocation_type = kTfLiteArenaRw;
return ResizeTempSum(context, &op_context, temp_sum);
}
template <KernelType kernel_type>
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- MeanContext op_context(context, node);
+TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
int num_axis = static_cast<int>(NumElements(op_context.axis));
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
@@ -255,16 +265,75 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
#undef TF_LITE_MEAN
return kTfLiteOk;
}
-} // namespace mean
+
+template <KernelType kernel_type>
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ }
+
+#define TF_LITE_SUM(kernel_type, data_type) \
+ kernel_type::Sum<>( \
+ GetTensorData<data_type>(op_context.input), \
+ op_context.input->dims->data, op_context.input->dims->size, \
+ GetTensorData<data_type>(op_context.output), \
+ op_context.output->dims->data, op_context.output->dims->size, \
+ GetTensorData<int>(op_context.axis), num_axis, \
+ op_context.params->keep_dims, GetTensorData<int>(temp_index), \
+ GetTensorData<int>(resolved_axis))
+
+ if (kernel_type == kReference) {
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
+ break;
+ case kTfLiteInt32:
+ TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
+ break;
+ case kTfLiteInt64:
+ TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
+ break;
+ case kTfLiteUInt8:
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
+ op_context.output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
+ op_context.output->params.zero_point);
+ TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
+ break;
+ default:
+ return kTfLiteError;
+ }
+ }
+#undef TF_LITE_SUM
+ return kTfLiteOk;
+}
+
+} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
- static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare,
- mean::Eval<mean::kReference>};
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMean,
+ reduce::EvalMean<reduce::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_SUM_REF() {
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareSimple,
+ reduce::EvalSum<reduce::kReference>};
return &r;
}
// TODO(kanlig): add optimized implementation of Mean.
TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); }
+TfLiteRegistration* Register_SUM() { return Register_SUM_REF(); }
} // namespace builtin
} // namespace ops
diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 79c9957f76..9e946822c6 100644
--- a/tensorflow/contrib/lite/kernels/mean_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -23,7 +23,7 @@ namespace {
using ::testing::ElementsAreArray;
-class BaseMeanOpModel : public SingleOpModel {
+class BaseOpModel : public SingleOpModel {
public:
void SetAxis(std::initializer_list<int> data) { PopulateTensor(axis_, data); }
@@ -53,7 +53,7 @@ class BaseMeanOpModel : public SingleOpModel {
};
// Model for the tests case where axis is a const tensor.
-class MeanOpConstModel : public BaseMeanOpModel {
+class MeanOpConstModel : public BaseOpModel {
public:
MeanOpConstModel(const TensorData& input, const TensorData& output,
std::initializer_list<int> axis_shape,
@@ -61,26 +61,59 @@ class MeanOpConstModel : public BaseMeanOpModel {
input_ = AddInput(input);
axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
output_ = AddOutput(output);
- SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
- CreateMeanOptions(builder_, keep_dims).Union());
+ SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
BuildInterpreter({GetShape(input_)});
}
};
// Model for the tests case where axis is a dynamic tensor.
-class MeanOpDynamicModel : public BaseMeanOpModel {
+class MeanOpDynamicModel : public BaseOpModel {
public:
MeanOpDynamicModel(const TensorData& input, const TensorData& output,
const TensorData& axis, bool keep_dims) {
input_ = AddInput(input);
axis_ = AddInput(axis);
output_ = AddOutput(output);
- SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions,
- CreateMeanOptions(builder_, keep_dims).Union());
+ SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
BuildInterpreter({GetShape(input_)});
}
};
+// Model for the tests case where axis is a const tensor.
+class SumOpConstModel : public BaseOpModel {
+ public:
+ SumOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the tests case where axis is a dynamic tensor.
+class SumOpDynamicModel : public BaseOpModel {
+ public:
+ SumOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_SUM, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// for quantized Add, the error shouldn't exceed step
+float GetTolerance(int min, int max) { return (max - min) / 255.0; }
+
+// Tests for reduce_mean
TEST(ConstFloatMeanOpTest, NotKeepDims) {
std::initializer_list<float> data = {
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
@@ -149,8 +182,6 @@ TEST(DynamicFloatMeanOpTest, Scale) {
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
}
-// for quantized Add, the error shouldn't exceed step
-float GetTolerance(int min, int max) { return (max - min) / 255.0; }
TEST(ConstUint8MeanOpTest, NotKeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
@@ -209,6 +240,135 @@ TEST(DynamicUint8MeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
}
+// Tests for reduce_sum
+
+TEST(ConstFloatSumOpTest, NotKeepDims) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}},
+ {4}, {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({144, 156})));
+}
+
+TEST(ConstFloatSumOpTest, KeepDims) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}},
+ {2}, {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({84, 100, 116})));
+}
+
+TEST(DynamicFloatSumOpTest, NotKeepDims) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}},
+ false);
+ std::initializer_list<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({144, 156})));
+}
+
+TEST(DynamicFloatSumOpTest, KeepDims) {
+ std::initializer_list<float> data = {
+ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
+ 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+ SumOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}},
+ {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, true);
+ std::initializer_list<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({84, 100, 116})));
+}
+
+TEST(DynamicFloatSumOpTest, Scale) {
+ std::initializer_list<float> data = {9.527};
+ SumOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::initializer_list<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({9.527})));
+}
+
+TEST(ConstUint8SumOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {2}, -1.0, 1.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8SumOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
+ std::initializer_list<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {3, 2}, -1.0, 1.0},
+ {TensorType_UINT8, {3}, -1.0, 1.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({-0.407843, -0.313726, 0.0941177},
+ kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8SumOpTest, NotKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 2.0);
+ std::initializer_list<float> data = {1.3, -4.8, -3.6, 0.24};
+ SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -5.0, 2.0},
+ {TensorType_UINT8, {2}, -5.0, 2.0},
+ {TensorType_INT32, {1}}, false);
+ std::initializer_list<int> axis = {1};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({1.48235, 1.64706}, kQuantizedTolerance)));
+}
+
+TEST(DynamicUint8SumOpTest, KeepDims) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::initializer_list<float> data = {11.14, -0.14, 7.423, 0.879};
+ SumOpDynamicModel m({TensorType_UINT8, {2, 2}, -10.0, 12.0},
+ {TensorType_UINT8, {2}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::initializer_list<int> axis = {0};
+ m.SetAxis(axis);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({6.47059, 10.698}, kQuantizedTolerance)));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 4eea9921b2..1994e85ce3 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -22,6 +22,7 @@ namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_MFCC();
+TfLiteRegistration* Register_DETECTION_POSTPROCESS();
} // namespace custom
@@ -73,6 +74,7 @@ TfLiteRegistration* Register_SQUEEZE();
TfLiteRegistration* Register_STRIDED_SLICE();
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
+TfLiteRegistration* Register_LOG();
TfLiteRegistration* Register_LOG_SOFTMAX();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_DEQUANTIZE();
@@ -80,17 +82,28 @@ TfLiteRegistration* Register_PRELU();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_ARG_MIN();
TfLiteRegistration* Register_GREATER();
TfLiteRegistration* Register_GREATER_EQUAL();
TfLiteRegistration* Register_LESS();
TfLiteRegistration* Register_LESS_EQUAL();
TfLiteRegistration* Register_FLOOR();
+TfLiteRegistration* Register_TILE();
TfLiteRegistration* Register_NEG();
+TfLiteRegistration* Register_SUM();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
TfLiteRegistration* Register_TRANSPOSE_CONV();
+TfLiteRegistration* Register_EXPAND_DIMS();
TfLiteRegistration* Register_SPARSE_TO_DENSE();
+TfLiteRegistration* Register_EQUAL();
+TfLiteRegistration* Register_NOT_EQUAL();
+TfLiteRegistration* Register_SQRT();
+TfLiteRegistration* Register_RSQRT();
+TfLiteRegistration* Register_SHAPE();
+TfLiteRegistration* Register_POW();
+TfLiteRegistration* Register_FAKE_QUANT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -112,7 +125,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
Register_EMBEDDING_LOOKUP_SPARSE());
- AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
@@ -124,7 +139,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION());
AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
Register_LOCAL_RESPONSE_NORMALIZATION());
- AddBuiltin(BuiltinOperator_LSTM, Register_LSTM());
+ AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
Register_BIDIRECTIONAL_SEQUENCE_LSTM());
AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -145,6 +161,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
AddBuiltin(BuiltinOperator_EXP, Register_EXP());
AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
+ AddBuiltin(BuiltinOperator_LOG, Register_LOG());
AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX());
AddBuiltin(BuiltinOperator_CAST, Register_CAST());
AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE());
@@ -152,6 +169,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
@@ -162,13 +180,25 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SLICE, Register_SLICE());
AddBuiltin(BuiltinOperator_SIN, Register_SIN());
AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV());
+ AddBuiltin(BuiltinOperator_TILE, Register_TILE());
+ AddBuiltin(BuiltinOperator_SUM, Register_SUM());
+ AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
+ AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
+ AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
+ AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
+ AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
+ AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
+ AddBuiltin(BuiltinOperator_POW, Register_POW());
+ AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+ AddCustom("TFLite_Detection_PostProcess",
+ tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
} // namespace builtin
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index b928f1b302..940718d67e 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -32,4 +32,4 @@ class BuiltinOpResolver : public MutableOpResolver {
} // namespace ops
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index f2092eaa36..86c4cd3ee8 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
- // TODO(ahentz): Our current implementations only support float32.
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
// ResizeBilinear creates a float tensor even when the input is made of
// integers.
- output->type = kTfLiteFloat32;
+ output->type = input->type;
if (!IsConstantTensor(size)) {
SetTensorToDynamic(output);
@@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type) \
- type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<float>(output), GetTensorDims(output), \
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
+ GetTensorData<int32>(size), GetTensorDims(size), \
+ GetTensorData<datatype>(output), GetTensorDims(output), \
params->align_corners)
if (kernel_type == kReference) {
- TF_LITE_RESIZE_BILINEAR(reference_ops);
+ TF_LITE_RESIZE_BILINEAR(reference_ops, float);
}
if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
- TF_LITE_RESIZE_BILINEAR(optimized_ops);
+ TF_LITE_RESIZE_BILINEAR(optimized_ops, float);
+ }
+ } else if (output->type == kTfLiteUInt8) {
+ if (kernel_type == kReference) {
+ TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t);
+ }
+ if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
+ TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t);
}
#undef TF_LITE_RESIZE_BILINEAR
} else {
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
index 4e03f3820a..10caffea03 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc
@@ -22,6 +22,7 @@ namespace tflite {
namespace {
using ::testing::ElementsAreArray;
+using uint8 = std::uint8_t;
class ResizeBilinearOpModel : public SingleOpModel {
public:
@@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
} else {
size_ = AddInput({TensorType_INT32, {2}});
}
- output_ = AddOutput(TensorType_FLOAT32); // Always float.
+ output_ = AddOutput(input.type);
SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
BuiltinOptions_ResizeBilinearOptions,
CreateResizeBilinearOptions(builder_).Union());
@@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel {
}
}
- void SetInput(std::initializer_list<float> data) {
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
PopulateTensor(input_, data);
}
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
private:
int input_;
@@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel {
TEST(ResizeBilinearOpTest, HorizontalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
- m.SetInput({3, 6});
+ m.SetInput<float>({3, 6});
m.SetSize({1, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
- const_m.SetInput({3, 6});
+ const_m.SetInput<float>({3, 6});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+}
+
+TEST(ResizeBilinearOpTest, HorizontalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}});
+ m.SetInput<uint8>({3, 6});
+ m.SetSize({1, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
+ const_m.SetInput<uint8>({3, 6});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+ EXPECT_THAT(const_m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}
TEST(ResizeBilinearOpTest, VerticalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
- m.SetInput({3, 9});
+ m.SetInput<float>({3, 9});
m.SetSize({3, 1});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+ EXPECT_THAT(m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
- const_m.SetInput({3, 9});
+ const_m.SetInput<float>({3, 9});
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<float>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+}
+
+TEST(ResizeBilinearOpTest, VerticalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}});
+ m.SetInput<uint8>({3, 9});
+ m.SetSize({3, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
+ const_m.SetInput<uint8>({3, 9});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+ EXPECT_THAT(const_m.GetOutput<uint8>(),
+ ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
- m.SetInput({
+ m.SetInput<float>({
3, 6, //
9, 12 //
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
3, 6, //
9, 12 //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- })));
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}});
+ m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12 //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
+ const_m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ })));
}
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
- m.SetInput({
+ m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
@@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- 4, 8, 10, //
- 8, 12, 14, //
- 10, 14, 16, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
3, 6, //
9, 12, //
4, 10, //
10, 16 //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 5, 6, //
- 7, 9, 10, //
- 9, 11, 12, //
- 4, 8, 10, //
- 8, 12, 14, //
- 10, 14, 16, //
- })));
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 14, 16, //
+ })));
}
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
- m.SetInput({
+ m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
m.SetSize({3, 3});
m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 14, 12, 16, //
- })));
+ EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
- const_m.SetInput({
+ const_m.SetInput<float>({
3, 4, 6, 10, //
9, 10, 12, 16, //
});
const_m.Invoke();
- EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
- 3, 4, 5, 8, 6, 10, //
- 7, 8, 9, 12, 10, 14, //
- 9, 10, 11, 14, 12, 16, //
- })));
+ EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 14, 12, 16, //
+ })));
+}
+
+TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}});
+ m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 10, 16 //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 13, 16, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
+ const_m.SetInput<uint8>({
+ 3, 6, //
+ 9, 12, //
+ 4, 10, //
+ 10, 16 //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 5, 6, //
+ 7, 9, 10, //
+ 9, 11, 12, //
+ 4, 8, 10, //
+ 8, 12, 14, //
+ 10, 13, 16, //
+ })));
}
+TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) {
+ ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}});
+ m.SetInput<uint8>({
+ 3, 4, 6, 10, //
+ 9, 10, 12, 16, //
+ });
+ m.SetSize({3, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 13, 12, 16, //
+ })));
+
+ ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
+ const_m.SetInput<uint8>({
+ 3, 4, 6, 10, //
+ 9, 10, 12, 16, //
+ });
+ const_m.Invoke();
+ EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
+ 3, 4, 5, 8, 6, 10, //
+ 7, 8, 9, 12, 10, 14, //
+ 9, 10, 11, 13, 12, 16, //
+ })));
+}
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 9b6cee3cb5..3cdb5db209 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -89,6 +89,9 @@ TfLiteStatus SelectEval(TfLiteContext* context, TfLiteNode* node) {
case kTfLiteUInt8: \
TF_LITE_SELECT(uint8_t, op); \
break; \
+ case kTfLiteInt16: \
+ TF_LITE_SELECT(int16_t, op); \
+ break; \
case kTfLiteInt32: \
TF_LITE_SELECT(int32_t, op); \
break; \
diff --git a/tensorflow/contrib/lite/kernels/select_test.cc b/tensorflow/contrib/lite/kernels/select_test.cc
index cfe24a5fc9..5b2e61cd29 100644
--- a/tensorflow/contrib/lite/kernels/select_test.cc
+++ b/tensorflow/contrib/lite/kernels/select_test.cc
@@ -88,11 +88,24 @@ TEST(SelectOpTest, SelectUInt8) {
TensorType_UINT8);
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
- model.PopulateTensor<uint8>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<uint8>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<uint8_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<uint8_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<uint8>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(SelectOpTest, SelectInt16) {
+ SelectOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, {1, 1, 1, 4},
+ TensorType_INT16);
+
+ model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
+ model.PopulateTensor<int16_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int16_t>(model.input3(), {5, 6, 7, 8});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput<int16_t>(), ElementsAreArray({5, 2, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
@@ -101,11 +114,11 @@ TEST(SelectOpTest, SelectInt32) {
TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false, true, false, false});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 2, 7, 8}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 2, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
}
@@ -113,11 +126,11 @@ TEST(SelectOpTest, RankOneSelectInt32) {
SelectOpModel model({2}, {2, 1, 2, 1}, {2, 1, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false, true});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 3, 4}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 3, 4}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 2, 1}));
}
@@ -125,11 +138,11 @@ TEST(SelectOpTest, RankZeroSelectInt32) {
SelectOpModel model({1}, {1, 2, 2, 1}, {1, 2, 2, 1}, TensorType_INT32);
model.PopulateTensor<bool>(model.input1(), {false});
- model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 4});
- model.PopulateTensor<int32>(model.input3(), {5, 6, 7, 8});
+ model.PopulateTensor<int32_t>(model.input2(), {1, 2, 3, 4});
+ model.PopulateTensor<int32_t>(model.input3(), {5, 6, 7, 8});
model.Invoke();
- EXPECT_THAT(model.GetOutput<int32>(), ElementsAreArray({5, 6, 7, 8}));
+ EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray({5, 6, 7, 8}));
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2, 1}));
}
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
new file mode 100644
index 0000000000..dbcd2ef004
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace shape {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+template <typename OutType>
+void ExtractShape(const TfLiteTensor* input, OutType* output_data) {
+ for (int i = 0; i < NumDimensions(input); ++i) {
+ output_data[i] = SizeOfDimension(input, i);
+ }
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ auto* params = reinterpret_cast<TfLiteShapeParams*>(node->builtin_data);
+ switch (params->out_type) {
+ case kTfLiteInt32:
+ output->type = kTfLiteInt32;
+ break;
+ case kTfLiteInt64:
+ output->type = kTfLiteInt64;
+ break;
+ default:
+ context->ReportError(context, "Unknown shape output data type: %d",
+ params->out_type);
+ return kTfLiteError;
+ }
+
+ // Shape always produces a 1-dimensional output tensor, where each output
+ // element is the length of the corresponding input tensor's dimension.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(1);
+ output_size->data[0] = NumDimensions(input);
+ return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TFLITE_DCHECK_EQ(NumDimensions(output), 1);
+ TFLITE_DCHECK_EQ(SizeOfDimension(output, 0), NumDimensions(input));
+
+ switch (output->type) {
+ case kTfLiteInt32:
+ ExtractShape(input, GetTensorData<int32_t>(output));
+ break;
+ case kTfLiteInt64:
+ ExtractShape(input, GetTensorData<int64_t>(output));
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace shape
+
+TfLiteRegistration* Register_SHAPE() {
+ static TfLiteRegistration r = {nullptr, nullptr, shape::Prepare, shape::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/shape_test.cc b/tensorflow/contrib/lite/kernels/shape_test.cc
new file mode 100644
index 0000000000..27b48f4e99
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/shape_test.cc
@@ -0,0 +1,95 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class ShapeOpModel : public SingleOpModel {
+ public:
+ ShapeOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType output_type) {
+ input_ = AddInput(input_type);
+ output_ = AddOutput(output_type);
+ SetBuiltinOp(BuiltinOperator_SHAPE, BuiltinOptions_ShapeOptions,
+ CreateShapeOptions(builder_, output_type).Union());
+ BuildInterpreter({input_shape});
+ }
+
+ TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
+
+ int input() { return input_; }
+
+ int32_t GetOutputSize() { return GetTensorSize(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(ShapeOpTest, OutTypeInt) {
+ ShapeOpModel<int32_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
+ TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(ShapeOpTest, OutTypeInt64) {
+ ShapeOpModel<int64_t> model({1, 3, 1, 3, 5}, TensorType_FLOAT32,
+ TensorType_INT64);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 3, 1, 3, 5}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({5}));
+}
+
+TEST(ShapeOpTest, ScalarTensor) {
+ ShapeOpModel<int32_t> model({}, TensorType_FLOAT32, TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_EQ(model.GetOutputSize(), 0);
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({0}));
+}
+
+TEST(ShapeOpTest, EmptyTensor) {
+ ShapeOpModel<int32_t> model({1, 0}, TensorType_FLOAT32, TensorType_INT32);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 6c5338ff0f..727822f6be 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -92,10 +92,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,10 +119,9 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
m.Invoke();
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
- static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size},
- {1, 0, 0, input_size}};
- tflite::reference_ops::Softmax(input_buffer, input_dims, beta,
- output_buffer.get(), input_dims);
+ auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
+ tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
+ output_buffer.get(), input_shape);
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index 43387df9ce..b144486041 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -76,8 +76,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
auto input_type = op_context.input->type;
- TF_LITE_ENSURE(context,
- input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
+ input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt16);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
}
@@ -137,9 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_SPLIT(uint8_t);
break;
}
+ case kTfLiteInt16: {
+ TF_LITE_SPLIT(int16_t);
+ break;
+ }
default:
context->ReportError(
- context, "Only float32 and uint8 are currently supported, got %d.",
+ context,
+ "Only float32, uint8 and int16 are currently supported, got %d.",
op_context.input->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 725dd8105a..bed2117f9a 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -121,10 +121,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
int32_t begin = GetBeginValueAtIndex(op_context, idx);
int32_t end = GetEndValueAtIndex(op_context, idx);
+ // When shrinking an axis, the end position does not matter (and can be
+ // incorrect when negative indexing is used, see Issue #19260). Always use
+ // begin + 1 to generate a length 1 slice, since begin has
+ // already been adjusted for negative indices by GetBeginValueAtIndex.
+ const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
+ if (shrink_axis) {
+ end = begin + 1;
+ }
+
// This is valid for both positive and negative strides
int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
dim_shape = dim_shape < 0 ? 0 : dim_shape;
- if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
+ if (!shrink_axis) {
output_shape_vector.push_back(dim_shape);
}
}
@@ -204,13 +213,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice(GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, \
- end_mask, starts, stops, strides, \
- GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ int shrink_axis_mask =
+ ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
+ starts, stops, strides, GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
index cc39179bc7..c5d4f9affb 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice_test.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
namespace tflite {
namespace {
-using ::int32;
using ::testing::ElementsAreArray;
template <typename input_type = float,
@@ -50,14 +49,14 @@ class StridedSliceOpModel : public SingleOpModel {
void SetInput(std::initializer_list<input_type> data) {
PopulateTensor<input_type>(input_, data);
}
- void SetBegin(std::initializer_list<int32> data) {
- PopulateTensor<int32>(begin_, data);
+ void SetBegin(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(begin_, data);
}
- void SetEnd(std::initializer_list<int32> data) {
- PopulateTensor<int32>(end_, data);
+ void SetEnd(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(end_, data);
}
- void SetStrides(std::initializer_list<int32> data) {
- PopulateTensor<int32>(strides_, data);
+ void SetStrides(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(strides_, data);
}
std::vector<input_type> GetOutput() {
@@ -384,6 +383,45 @@ TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
}
+TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) {
+ // This is equivalent to tf.range(4)[-1].
+ StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({-1});
+ m.SetEnd({0});
+ m.SetStrides({1});
+
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) {
+ // This is equivalent to tf.range(4)[:, tf.newaxis][-2, -1].
+ StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 0, 0, 0, 0, 3);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({-2, -1});
+ m.SetEnd({-1, 0});
+ m.SetStrides({1, 1});
+
+ m.Invoke();
+ EXPECT_TRUE(m.GetOutputShape().empty());
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({2}));
+}
+
+TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) {
+ // This is equivalent to tf.range(4)[:, tf.newaxis][:, -1].
+ StridedSliceOpModel<> m({4, 1}, {2}, {2}, {2}, 1, 1, 0, 0, 2);
+ m.SetInput({0, 1, 2, 3});
+ m.SetBegin({0, -1});
+ m.SetEnd({0, 0});
+ m.SetStrides({1, 1});
+
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3}));
+}
+
TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4});
@@ -395,17 +433,6 @@ TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({1}));
}
-TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) {
- StridedSliceOpModel<> m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1);
- m.SetInput({1, 2, 3, 4});
- m.SetBegin({-2});
- m.SetEnd({-3});
- m.SetStrides({-1});
- m.Invoke();
- EXPECT_TRUE(m.GetOutputShape().empty());
- EXPECT_THAT(m.GetOutput(), ElementsAreArray({3}));
-}
-
TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) {
StridedSliceOpModel<> m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6});
@@ -538,7 +565,7 @@ TEST(StridedSliceOpTest, RunTwice) {
}
TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) {
- StridedSliceOpModel<uint8, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
+ StridedSliceOpModel<uint8_t, TensorType_UINT8> m({2, 3, 2}, {3}, {3}, {3}, 0, 0,
0, 0, 1);
m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
m.SetBegin({0, 0, 0});
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index d788159a8d..1247525d41 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -83,8 +83,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output) {
float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
#define TF_LITE_SUB(type, opname) \
type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
GetTensorData<float>(input2), GetTensorDims(input2), \
@@ -126,16 +126,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32 input1_multiplier;
int input1_shift;
- QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
- &input1_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier,
+ &input1_multiplier, &input1_shift);
+ input1_shift *= -1;
int32 input2_multiplier;
int input2_shift;
- QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
- &input2_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier,
+ &input2_multiplier, &input2_shift);
+ input2_shift *= -1;
int32 output_multiplier;
int output_shift;
- QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
- &output_shift);
+ QuantizeMultiplierSmallerThanOneExp(real_output_multiplier,
+ &output_multiplier, &output_shift);
+ output_shift *= -1;
int32 output_activation_min, output_activation_max;
CalculateActivationRangeUint8(params->activation, output,
@@ -175,7 +178,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output);
} else {
context->ReportError(
- context, "output type %d is not support, requires float|uint8 types.",
+ context, "output type %d is not supported, requires float|uint8 types.",
output->type);
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 308860c299..22eebdd4ce 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+
+// SVDF op that compresses a fully connected op via low-rank matrix
+// factorization. See https://research.google.com/pubs/archive/43813.pdf for
+// details.
#include <unistd.h>
#include <cassert>
#include <cmath>
@@ -32,6 +36,67 @@ namespace ops {
namespace builtin {
namespace svdf {
+namespace {
+
+struct OpData {
+ int scratch_tensor_index;
+ bool float_weights_time_initialized;
+};
+
+static inline void ApplyTimeWeightsBiasAndActivation(
+ int batch_size, int memory_size, int num_filters, int num_units, int rank,
+ const TfLiteTensor* weights_time, const TfLiteTensor* bias,
+ TfLiteFusedActivation activation, TfLiteTensor* state,
+ TfLiteTensor* scratch, TfLiteTensor* output) {
+ // Compute matmul(state, weights_time).
+ // The right most column is used to save temporary output (with the size of
+ // num_filters). This is achieved by starting at state->data.f and having the
+ // stride equal to memory_size.
+ for (int b = 0; b < batch_size; ++b) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::BatchVectorBatchVectorDotProduct(
+ weights_time->data.f, state_ptr_batch, memory_size, num_filters,
+ scratch_ptr_batch, /*result_stride=*/1);
+ }
+
+ // Initialize output with bias if provided.
+ if (bias) {
+ tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
+ output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
+ }
+
+ // Reduction sum.
+ for (int b = 0; b < batch_size; ++b) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ float* scratch_ptr_batch = scratch->data.f + b * num_filters;
+ tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
+ num_units, rank);
+ }
+
+ // Apply activation.
+ for (int b = 0; b < batch_size; ++b) {
+ float* output_ptr_batch = output->data.f + b * num_units;
+ tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
+ activation, output_ptr_batch);
+ }
+
+ // Left shift the state to make room for next cycle's activation.
+ // TODO(alanchiao): explore collapsing this into a single loop.
+ for (int b = 0; b < batch_size; ++b) {
+ float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
+ for (int f = 0; f < num_filters; ++f) {
+ tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
+ /*shift_value=*/0.0);
+ state_ptr_batch += memory_size;
+ }
+ }
+}
+
+} // namespace
+
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
@@ -40,29 +105,34 @@ constexpr int kStateTensor = 0;
constexpr int kOutputTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
- return scratch_tensor_index;
+ auto* op_data = new OpData;
+ op_data->float_weights_time_initialized = false;
+ context->AddTensors(context, /*tensors_to_add=*/4,
+ &op_data->scratch_tensor_index);
+ return op_data;
}
void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<int*>(buffer);
+ delete reinterpret_cast<OpData*>(buffer);
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
- int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
GetInput(context, node, kWeightsFeatureTensor);
const TfLiteTensor* weights_time =
GetInput(context, node, kWeightsTimeTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+
// Check all the parameters of tensor match within themselves and match the
// input configuration.
const int rank = params->rank;
@@ -103,10 +173,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size_array));
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op =
+ (input->type == kTfLiteFloat32 && weights_feature->type == kTfLiteUInt8);
+
// Resize scratch.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
- node->temporaries->data[0] = *scratch_tensor_index;
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(4);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = scratch_tensor_index;
TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
scratch_size_array->data[0] = batch_size;
@@ -118,24 +196,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
scratch_size_array));
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const TfLiteTensor* weights_feature =
- GetInput(context, node, kWeightsFeatureTensor);
- const TfLiteTensor* weights_time =
- GetInput(context, node, kWeightsTimeTensor);
+ if (is_hybrid_op) {
+ // Tell interpreter to allocate temporary tensors to store quantized values
+ // of input tensors.
+ node->temporaries->data[1] = scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
- TfLiteTensor* state = GetOutput(context, node, kStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+ // Tell interpreter to allocate temporary tensors to store scaling factors.
+ node->temporaries->data[2] = scratch_tensor_index + 2;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
- const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ // Used to store dequantized weights_time matrix for hybrid computation
+ // of matmul(state, weights_time), which occurs in floating point.
+ node->temporaries->data[3] = scratch_tensor_index + 3;
+ TfLiteTensor* float_weights_time = GetTemporary(context, node, /*index=*/3);
+ float_weights_time->type = kTfLiteFloat32;
+ // Persistent so that we can compute the dequantized weights only once.
+ float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
+ if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) {
+ TfLiteIntArray* float_weights_time_size =
+ TfLiteIntArrayCopy(weights_time->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, float_weights_time,
+ float_weights_time_size));
+ }
+ }
+ return kTfLiteOk;
+}
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input,
+ const TfLiteTensor* weights_feature,
+ const TfLiteTensor* weights_time,
+ const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+ TfLiteTensor* scratch, TfLiteTensor* state,
+ TfLiteTensor* output) {
const int rank = params->rank;
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
@@ -146,67 +256,151 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Clear the activation (state left most column).
// TODO(ghodrat): Add a test which initialize state with invalid values in
// left most column and make sure it passes.
- for (int b = 0; b < batch_size; b++) {
+ for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- for (int c = 0; c < num_filters; c++) {
+ for (int c = 0; c < num_filters; ++c) {
float* state_ptr = state_ptr_batch + c * memory_size;
state_ptr[memory_size - 1] = 0.0;
}
}
// Compute conv1d(inputs, weights_feature).
- // The state left most column is used to save current cycle activation. This
+ // The state right most column is used to save current cycle activation. This
// is achieved by starting at state->data.f[memory_size - 1] and having the
// stride equal to memory_size.
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
weights_feature->data.f, num_filters, input_size, input->data.f,
batch_size, &state->data.f[memory_size - 1], memory_size);
- // Compute matmul(state, weights_time).
- // The right most column is used to save temporary output (with the size of
- // num_filters). This is achieved by starting at state->data.f and having the
- // stride equal to memory_size.
- for (int b = 0; b < batch_size; b++) {
+ ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+ num_units, rank, weights_time, bias,
+ params->activation, state, scratch, output);
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ TfLiteContext* context, TfLiteNode* node, const TfLiteTensor* input,
+ const TfLiteTensor* weights_feature, const TfLiteTensor* weights_time,
+ const TfLiteTensor* bias, const TfLiteSVDFParams* params,
+ TfLiteTensor* scratch, TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* state, TfLiteTensor* output) {
+ const int rank = params->rank;
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int num_filters = weights_feature->dims->data[0];
+ const int num_units = num_filters / rank;
+ const int memory_size = weights_time->dims->data[1];
+
+ // Initialize the pointer to input.
+ const float* input_ptr_batch = input->data.f;
+
+ // Initialize the pointer to storage for quantized values and
+ // scaling factors.
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ // Other initializations.
+ const int8_t* weights_feature_ptr =
+ reinterpret_cast<int8_t*>(weights_feature->data.uint8);
+ const float weights_feature_scale = weights_feature->params.scale;
+
+ // Clear the activation (state left most column).
+ // TODO(ghodrat): Add a test which initialize state with invalid values in
+ // left most column and make sure it passes.
+ for (int b = 0; b < batch_size; ++b) {
float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- float* scratch_ptr_batch = scratch->data.f + b * num_filters;
- tensor_utils::BatchVectorBatchVectorDotProduct(
- weights_time->data.f, state_ptr_batch, memory_size, num_filters,
- scratch_ptr_batch, /*result_stride=*/1);
+ for (int c = 0; c < num_filters; ++c) {
+ float* state_ptr = state_ptr_batch + c * memory_size;
+ state_ptr[memory_size - 1] = 0.0;
+ }
}
- // Initialize output with bias if provided.
- if (bias) {
- tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size,
- output->data.f);
- } else {
- tensor_utils::ZeroVector(output->data.f, batch_size * num_units);
- }
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) {
+ // Quantize input from float to int8.
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, input_size,
+ quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= weights_feature_scale;
+ }
- // Reduction sum
- for (int b = 0; b < batch_size; b++) {
- float* output_ptr_batch = output->data.f + b * num_units;
- float* scratch_ptr_batch = scratch->data.f + b * num_filters;
- tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch,
- num_units, rank);
+ // Compute conv1d(inputs, weights_feature).
+ // The state right most column is used to save current cycle activation.
+ // This is achieved by starting at state->data.f[memory_size - 1] and having
+ // the stride equal to memory_size.
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_feature_ptr, num_filters, input_size, quantized_input_ptr_batch,
+ scaling_factors_ptr, batch_size, &state->data.f[memory_size - 1],
+ memory_size);
}
- // Apply activation.
- for (int b = 0; b < batch_size; b++) {
- float* output_ptr_batch = output->data.f + b * num_units;
- tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units,
- params->activation, output_ptr_batch);
- }
+ // TODO(alanchiao): can optimize hybrid case ~5% by unrolling loop in applying
+ // time weights so that the inner loop multiplies eight elements at a time.
+ ApplyTimeWeightsBiasAndActivation(batch_size, memory_size, num_filters,
+ num_units, rank, weights_time, bias,
+ params->activation, state, scratch, output);
+ return kTfLiteOk;
+}
- // Right shift the state.
- for (int b = 0; b < batch_size; b++) {
- float* state_ptr_batch = state->data.f + b * memory_size * num_filters;
- for (int f = 0; f < num_filters; f++) {
- tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size,
- /*shift_value=*/0.0);
- state_ptr_batch += memory_size;
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* weights_feature =
+ GetInput(context, node, kWeightsFeatureTensor);
+ const TfLiteTensor* weights_time =
+ GetInput(context, node, kWeightsTimeTensor);
+ const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+
+ TfLiteTensor* scratch = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* state = GetOutput(context, node, kStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (weights_feature->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(context, node, input, weights_feature, weights_time,
+ bias, params, scratch, state, output);
+ break;
}
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* float_weights_time =
+ GetTemporary(context, node, /*index=*/3);
+
+ // Dequantize weights time.
+ // TODO(alanchiao): this dequantization initialization only needs to
+ // happen once per model and should theoretically be placed in either Init
+ // or Prepare. However, TFLite doesn't allocate float_weights_time until
+ // the Eval function.
+ // TODO(alanchiao): refactor logic out into dequantize function.
+ if (!op_data->float_weights_time_initialized) {
+ const float dequantization_scale = weights_time->params.scale;
+ const int8_t* weights_time_ptr =
+ reinterpret_cast<int8_t*>(weights_time->data.uint8);
+ for (int i = 0; i < NumElements(float_weights_time); ++i) {
+ float_weights_time->data.f[i] =
+ weights_time_ptr[i] * dequantization_scale;
+ }
+ op_data->float_weights_time_initialized = true;
+ }
+ return EvalHybrid(context, node, input, weights_feature,
+ float_weights_time, bias, params, scratch,
+ scaling_factors, input_quantized, state, output);
+ break;
+ }
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ weights_feature->type);
+ return kTfLiteError;
}
- return kTfLiteOk;
}
} // namespace svdf
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index 0f166dc69b..5af3ff8500 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -126,17 +126,20 @@ static float svdf_golden_output_rank_2[] = {
};
// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
-class SVDFOpModel : public SingleOpModel {
+class BaseSVDFOpModel : public SingleOpModel {
public:
- SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank)
+ BaseSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank,
+ TensorType weights_feature_type = TensorType_FLOAT32,
+ TensorType weights_time_type = TensorType_FLOAT32)
: batches_(batches),
units_(units),
input_size_(input_size),
memory_size_(memory_size),
rank_(rank) {
input_ = AddInput(TensorType_FLOAT32);
- weights_feature_ = AddInput(TensorType_FLOAT32);
- weights_time_ = AddInput(TensorType_FLOAT32);
+ weights_feature_ = AddInput(weights_feature_type);
+ weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -182,7 +185,7 @@ class SVDFOpModel : public SingleOpModel {
int num_units() { return units_; }
int num_batches() { return batches_; }
- private:
+ protected:
int input_;
int weights_feature_;
int weights_time_;
@@ -197,7 +200,61 @@ class SVDFOpModel : public SingleOpModel {
int rank_;
};
-TEST(SVDFOpTest, BlackBoxTestRank1) {
+class SVDFOpModel : public BaseSVDFOpModel {
+ public:
+ using BaseSVDFOpModel::BaseSVDFOpModel;
+};
+
+class HybridSVDFOpModel : public BaseSVDFOpModel {
+ public:
+ HybridSVDFOpModel(int batches, int units, int input_size, int memory_size,
+ int rank)
+ : BaseSVDFOpModel(batches, units, input_size, memory_size, rank,
+ TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeightsFeature(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_feature_, f);
+ }
+
+ void SetWeightsTime(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_time_, f);
+ }
+};
+
+class SVDFOpTest : public ::testing::Test {
+ protected:
+ void VerifyGoldens(float golden_input[], float golden_output[],
+ int golden_size, BaseSVDFOpModel* svdf,
+ float tolerance = 1e-5) {
+ const int svdf_num_batches = svdf->num_batches();
+ const int svdf_input_size = svdf->input_size();
+ const int svdf_num_units = svdf->num_units();
+ const int input_sequence_size =
+ golden_size / sizeof(float) / (svdf_input_size * svdf_num_batches);
+ // Going over each input batch, setting the input tensor, invoking the SVDF
+ // op and checking the output with the expected golden values.
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start =
+ golden_input + i * svdf_input_size * svdf_num_batches;
+ float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
+ svdf->SetInput(0, batch_start, batch_end);
+
+ svdf->Invoke();
+
+ const float* golden_start =
+ golden_output + i * svdf_num_units * svdf_num_batches;
+ const float* golden_end =
+ golden_start + svdf_num_units * svdf_num_batches;
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(svdf->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+TEST_F(SVDFOpTest, BlackBoxTestRank1) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/1);
svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
@@ -218,31 +275,11 @@ TEST(SVDFOpTest, BlackBoxTestRank1) {
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf);
}
-TEST(SVDFOpTest, BlackBoxTestRank2) {
+TEST_F(SVDFOpTest, BlackBoxTestRank2) {
SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
/*memory_size=*/10, /*rank=*/2);
svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
@@ -278,28 +315,75 @@ TEST(SVDFOpTest, BlackBoxTestRank2) {
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
svdf.ResetState();
- const int svdf_num_batches = svdf.num_batches();
- const int svdf_input_size = svdf.input_size();
- const int svdf_num_units = svdf.num_units();
- const int input_sequence_size =
- sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches);
- // Going over each input batch, setting the input tensor, invoking the SVDF op
- // and checking the output with the expected golden values.
- for (int i = 0; i < input_sequence_size; i++) {
- float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches;
- float* batch_end = batch_start + svdf_input_size * svdf_num_batches;
- svdf.SetInput(0, batch_start, batch_end);
-
- svdf.Invoke();
-
- float* golden_start =
- svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches;
- float* golden_end = golden_start + svdf_num_units * svdf_num_batches;
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
-
- EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank1) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/1);
+ svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347,
+ 0.22197971, 0.12416199, 0.27901134, 0.27557442,
+ 0.3905206, -0.36137494, -0.06634006, -0.10640851});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
+
+ svdf.ResetState();
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.002945);
+}
+
+TEST_F(SVDFOpTest, BlackBoxTestHybridRank2) {
+ HybridSVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3,
+ /*memory_size=*/10, /*rank=*/2);
+ svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347,
+ 0.12416199, 0.15785322, 0.27901134, 0.3905206,
+ 0.21931258, -0.36137494, -0.10640851, 0.31053296,
+ -0.36118156, -0.0976817, -0.36916667, 0.22197971,
+ 0.15294972, 0.38031587, 0.27557442, 0.39635518,
+ -0.21580373, -0.06634006, -0.02702999, 0.27072677});
+
+ svdf.SetWeightsTime(
+ {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156,
+ 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199,
+
+ 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518,
+ -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296,
+
+ -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236,
+ 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846,
+
+ -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
+ -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657,
+
+ -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486,
+ 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187,
+
+ -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589,
+ 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836,
+
+ -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277,
+ -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214,
+
+ 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
+ 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
+
+ svdf.ResetState();
+ VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input),
+ &svdf,
+ /*tolerance=*/0.00625109);
}
} // namespace
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 1a01ee0936..9156917140 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -32,8 +32,8 @@ std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
return matchers;
}
-int SingleOpModel::AddInput(const TensorData& t) {
- int id = AddTensor<float>(t, {});
+int SingleOpModel::AddInput(const TensorData& t, bool is_variable) {
+ int id = AddTensor<float>(t, {}, is_variable);
inputs_.push_back(id);
return id;
}
@@ -112,8 +112,15 @@ void SingleOpModel::BuildInterpreter(
if (shape.empty()) continue;
CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
}
+
+ // Modify delegate with function.
+ if (apply_delegate_fn_) {
+ apply_delegate_fn_(interpreter_.get());
+ }
+
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
+ interpreter_->ResetVariableTensorsToZero();
}
void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); }
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 55edc97d19..bedbe93ae6 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -114,13 +114,22 @@ class SingleOpModel {
SingleOpModel() {}
~SingleOpModel() {}
+ // Set a function callback that is run right after graph is prepared
+ // that allows applying external delegates. This is useful for testing
+ // other runtimes like NN API or GPU.
+ void SetApplyDelegate(std::function<void(Interpreter*)> apply_delegate_fn) {
+ apply_delegate_fn_ = apply_delegate_fn;
+ }
+
// Copying or assignment is disallowed to simplify ownership semantics.
SingleOpModel(const SingleOpModel&) = delete;
SingleOpModel& operator=(const SingleOpModel&) = delete;
// Add a TensorType input tensor and return its index.
- int AddInput(TensorType type) { return AddInput(TensorData{type}); }
- int AddInput(const TensorData& t);
+ int AddInput(TensorType type, bool is_variable = false) {
+ return AddInput(TensorData{type}, is_variable);
+ }
+ int AddInput(const TensorData& t, bool is_variable = false);
// Templated version of AddConstInput().
template <typename T>
@@ -139,20 +148,18 @@ class SingleOpModel {
int AddOutput(const TensorData& t);
template <typename T>
- void QuantizeAndPopulate(int index, std::initializer_list<float> data) {
+ void QuantizeAndPopulate(int index, const std::vector<float>& data) {
TfLiteTensor* t = interpreter_->tensor(index);
auto q = Quantize<T>(data, t->params.scale, t->params.zero_point);
PopulateTensor(index, 0, q.data(), q.data() + q.size());
}
- void SymmetricQuantizeAndPopulate(int index,
- std::initializer_list<float> data) {
+ void SymmetricQuantizeAndPopulate(int index, const std::vector<float>& data) {
TfLiteTensor* t = interpreter_->tensor(index);
- std::vector<float> values(data);
- const int length = values.size();
+ const int length = data.size();
std::vector<int8_t> q(length);
float min, max, scaling_factor;
- tensor_utils::SymmetricQuantizeFloats(values.data(), length, q.data(), &min,
+ tensor_utils::SymmetricQuantizeFloats(data.data(), length, q.data(), &min,
&max, &scaling_factor);
// Update quantization params.
t->params.scale = scaling_factor;
@@ -189,8 +196,22 @@ class SingleOpModel {
}
// Populate the tensor given its index.
+ // TODO(b/110696148) clean up and merge with vector-taking variant below.
+ template <typename T>
+ void PopulateTensor(int index, const std::initializer_list<T>& data) {
+ T* v = interpreter_->typed_tensor<T>(index);
+ CHECK(v) << "No tensor with index '" << index << "'.";
+ for (T f : data) {
+ *v = f;
+ ++v;
+ }
+ }
+
+ // Populate the tensor given its index.
+ // TODO(b/110696148) clean up and merge with initializer_list-taking variant
+ // above.
template <typename T>
- void PopulateTensor(int index, std::initializer_list<T> data) {
+ void PopulateTensor(int index, const std::vector<T>& data) {
T* v = interpreter_->typed_tensor<T>(index);
CHECK(v) << "No tensor with index '" << index << "'.";
for (T f : data) {
@@ -253,7 +274,8 @@ class SingleOpModel {
}
template <typename T>
- int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int AddTensor(TensorData t, std::initializer_list<T> data,
+ bool is_variable = false) {
int id = tensors_.size();
// This is slightly different depending on whether we are adding a
@@ -270,6 +292,9 @@ class SingleOpModel {
} else if (t.type == TensorType_INT32) {
std::tie(t.scale, t.zero_point) =
QuantizationParams<int32_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT16) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int16_t>(t.min, t.max);
} else {
LOG(FATAL) << "No support for the requested quantized type";
}
@@ -302,7 +327,7 @@ class SingleOpModel {
tensors_.push_back(CreateTensor(builder_,
builder_.CreateVector<int>(t.shape), t.type,
/*buffer=*/buffer_id,
- /*name=*/0, q_params));
+ /*name=*/0, q_params, is_variable));
tensor_data_[id] = t;
@@ -317,6 +342,9 @@ class SingleOpModel {
std::vector<flatbuffers::Offset<Operator>> operators_;
std::vector<flatbuffers::Offset<Buffer>> buffers_;
std::map<string, std::function<TfLiteRegistration*()>> custom_registrations_;
+ // A function pointer that gets called after the interpreter is created but
+ // before evaluation happens. This is useful for applying a delegate.
+ std::function<void(Interpreter*)> apply_delegate_fn_;
};
// Base class for single op unit tests.
diff --git a/tensorflow/contrib/lite/kernels/test_util_test.cc b/tensorflow/contrib/lite/kernels/test_util_test.cc
index 1e10e89061..2365803472 100644
--- a/tensorflow/contrib/lite/kernels/test_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/test_util_test.cc
@@ -22,22 +22,22 @@ using ::testing::ElementsAreArray;
TEST(TestUtilTest, QuantizeVector) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/1.0, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 1, 1, 255};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/1.0, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 1, 1, 255};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
TEST(TestUtilTest, QuantizeVectorScalingDown) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/10.0, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 0, 0, 100};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/10.0, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 0, 0, 100};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
TEST(TestUtilTest, QuantizeVectorScalingUp) {
std::vector<float> data = {-1.0, -0.5, 0.0, 0.5, 1.0, 1000.0};
- auto q_data = Quantize<uint8>(data, /*scale=*/0.1, /*zero_point=*/0);
- std::vector<uint8> expected = {0, 0, 0, 5, 10, 255};
+ auto q_data = Quantize<uint8_t>(data, /*scale=*/0.1, /*zero_point=*/0);
+ std::vector<uint8_t> expected = {0, 0, 0, 5, 10, 255};
EXPECT_THAT(q_data, ElementsAreArray(expected));
}
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
new file mode 100644
index 0000000000..af77f07474
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -0,0 +1,194 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace tile {
+
+constexpr int kInputTensor = 0;
+constexpr int kInputMultipliers = 1;
+constexpr int kOutputTensor = 0;
+
+namespace {
+template <typename T>
+TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape,
+ const TfLiteTensor* multipliers,
+ int num_dimensions) {
+ const T* multipliers_v = GetTensorData<T>(multipliers);
+
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
+ for (int i = 0; i < num_dimensions; ++i) {
+ output_shape->data[i] = shape.data[i] * multipliers_v[i];
+ }
+ return output_shape;
+}
+
+TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+
+ const int num_dimensions = NumDimensions(input);
+ const int num_multipliers = NumElements(multipliers);
+ TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers);
+ switch (multipliers->type) {
+ case kTfLiteInt32:
+ return context->ResizeTensor(
+ context, output,
+ MultiplyShapeDims<int32_t>(*input->dims, multipliers,
+ num_dimensions));
+ case kTfLiteInt64:
+ return context->ResizeTensor(
+ context, output,
+ MultiplyShapeDims<int64_t>(*input->dims, multipliers,
+ num_dimensions));
+ default:
+ context->ReportError(context, "Tile not supported multiply tensor type.");
+ return kTfLiteError;
+ }
+}
+
+template <typename T>
+void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier,
+ T* out_data) {
+ for (int i = 0; i < multiplier; ++i) {
+ const T* in_end = in_data + in_size;
+ T* new_out_data = std::copy(in_data, in_end, out_data);
+ in_data = out_data;
+ out_data = new_out_data;
+ }
+}
+
+template <typename T, typename M>
+std::pair<int, int> TileOneDimension(const TfLiteIntArray& in_dimensions,
+ const T* in_data, const M* multipliers,
+ T* out_data, int dimension) {
+ const int dimension_size = in_dimensions.data[dimension];
+ if (dimension == in_dimensions.size - 1) {
+ CopyMultipleTimes(in_data, dimension_size, multipliers[dimension],
+ out_data);
+ return std::make_pair(dimension_size,
+ dimension_size * multipliers[dimension]);
+ }
+ int total_stride_size = 0, total_tiled_stride_size = 0;
+ const T* copy_from_data = in_data;
+ T* copy_to_data = out_data;
+ for (int i = 0; i < dimension_size; ++i) {
+ int stride_size = 0, tiled_stride_size = 0;
+ std::tie(stride_size, tiled_stride_size) =
+ TileOneDimension(in_dimensions, copy_from_data, multipliers,
+ copy_to_data, dimension + 1);
+ copy_from_data += stride_size;
+ copy_to_data += tiled_stride_size;
+ total_stride_size += stride_size;
+ total_tiled_stride_size += tiled_stride_size;
+ }
+ CopyMultipleTimes(out_data, total_tiled_stride_size,
+ multipliers[dimension] - 1,
+ out_data + total_tiled_stride_size);
+ return std::make_pair(total_stride_size,
+ total_tiled_stride_size * multipliers[dimension]);
+}
+
+template <typename T>
+void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data,
+ const TfLiteTensor* multipliers, TfLiteTensor* out_data) {
+ // Doing recursively tiling from top to down dimension.
+ switch (multipliers->type) {
+ case kTfLiteInt32:
+ TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
+ GetTensorData<int32_t>(multipliers),
+ GetTensorData<T>(out_data), 0);
+ break;
+ case kTfLiteInt64:
+ TileOneDimension(in_dimensions, GetTensorData<T>(in_data),
+ GetTensorData<int64_t>(multipliers),
+ GetTensorData<T>(out_data), 0);
+ break;
+ default:
+ break;
+ }
+}
+} // namespace
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+ // Only int32 and int64 multipliers type is supported.
+ TF_LITE_ENSURE_MSG(context,
+ (multipliers->type == kTfLiteInt32) ||
+ (multipliers->type == kTfLiteInt64),
+ "Tile only supports int32 and int64 mutlipliers.");
+
+ if (IsConstantTensor(multipliers)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
+ } else {
+ SetTensorToDynamic(output);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers);
+
+ if (IsDynamicTensor(output)) {
+ TF_LITE_ENSURE_OK(context, ResizeOutput(context, node));
+ }
+
+ switch (output->type) {
+ case kTfLiteFloat32:
+ Tile<float>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteUInt8:
+ Tile<uint8_t>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteInt32:
+ Tile<int32_t>(*(input->dims), input, multipliers, output);
+ break;
+ case kTfLiteInt64:
+ Tile<int64_t>(*(input->dims), input, multipliers, output);
+ break;
+ default:
+ context->ReportError(context, "Type is currently not supported by Tile.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace tile
+TfLiteRegistration* Register_TILE() {
+ static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval};
+ return &r;
+}
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
new file mode 100644
index 0000000000..4f78c224e5
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -0,0 +1,256 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+class TileOpModel : public SingleOpModel {
+ public:
+ TileOpModel(std::initializer_list<int> input_shape, TensorType input_type,
+ TensorType multiply_type) {
+ input_ = AddInput(input_type);
+ multipliers_ = AddInput(TensorType_INT32);
+ output_ = AddOutput(input_type);
+ SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0);
+ BuildInterpreter({input_shape, {static_cast<int>(input_shape.size())}});
+ }
+
+ void SetInputFloat(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ void SetInputUInt8(std::initializer_list<uint8_t> data) {
+ PopulateTensor<uint8_t>(input_, data);
+ }
+
+ void SetInputInt32(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(input_, data);
+ }
+
+ void SetInputInt64(std::initializer_list<int64_t> data) {
+ PopulateTensor<int64_t>(input_, data);
+ }
+
+ void SetMultipliers(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(multipliers_, data);
+ }
+
+ std::vector<float> GetOutputFloat() { return ExtractVector<float>(output_); }
+
+ std::vector<uint8_t> GetOutputUInt8() { return ExtractVector<uint8_t>(output_); }
+
+ std::vector<int32_t> GetOutputInt32() { return ExtractVector<int32_t>(output_); }
+
+ std::vector<int64_t> GetOutputInt64() {
+ return ExtractVector<int64_t>(output_);
+ }
+
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+ int input_;
+ int multipliers_;
+ int output_;
+};
+
+TEST(TileTest, Float32Vector) {
+ TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({1.f, 2.f, 3.f});
+ m.SetMultipliers({2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(),
+ ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f}));
+}
+
+TEST(TileTest, Float32Matrix) {
+ TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Float32HighDimension) {
+ TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32);
+ m.SetInputFloat({
+ 11.f,
+ 12.f,
+ 13.f,
+ 21.f,
+ 22.f,
+ 23.f,
+ });
+ m.SetMultipliers({2, 3, 1});
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetOutputFloat(),
+ ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
+ 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f,
+ 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f,
+ 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3}));
+}
+
+TEST(TileTest, Uint8Matrix) {
+ TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32);
+ m.SetInputUInt8({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int32Matrix) {
+ TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32);
+ m.SetInputInt32({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int64Matrix) {
+ TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32);
+ m.SetInputInt64({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+
+TEST(TileTest, Int64Matrix64Multipliers) {
+ TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64);
+ m.SetInputInt64({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ });
+ m.SetMultipliers({2, 1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ 11,
+ 12,
+ 13,
+ 21,
+ 22,
+ 23,
+ }));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3}));
+}
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index fb0e49c90c..2dd760bbfe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -56,11 +56,13 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) {
output_values_shape->data[num_dimensions - 1] = k;
TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes);
TfLiteTensor* output_values = GetOutput(context, node, kOutputValues);
+ // Force output types.
+ output_indexes->type = kTfLiteInt32;
+ output_values->type = input->type;
auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size,
TfLiteIntArray* delete_on_error) {
TfLiteStatus status = context->ResizeTensor(context, tensor, new_size);
if (status != kTfLiteOk) {
- TfLiteIntArrayFree(new_size);
if (delete_on_error != nullptr) {
TfLiteIntArrayFree(delete_on_error);
}
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 212f8acc76..2abb89b617 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -42,32 +42,32 @@ class TopKV2OpModel : public SingleOpModel {
PopulateTensor<float>(input_, data);
}
- void SetInputUInt8(std::initializer_list<uint8> data) {
- PopulateTensor<uint8>(input_, data);
+ void SetInputUInt8(std::initializer_list<uint8_t> data) {
+ PopulateTensor<uint8_t>(input_, data);
}
- void SetInputInt32(std::initializer_list<int32> data) {
- PopulateTensor<int32>(input_, data);
+ void SetInputInt32(std::initializer_list<int32_t> data) {
+ PopulateTensor<int32_t>(input_, data);
}
void SetInputInt64(std::initializer_list<int64_t> data) {
PopulateTensor<int64_t>(input_, data);
}
- std::vector<int32> GetIndexes() {
- return ExtractVector<int32>(output_indexes_);
+ std::vector<int32_t> GetIndexes() {
+ return ExtractVector<int32_t>(output_indexes_);
}
std::vector<float> GetValuesFloat() {
return ExtractVector<float>(output_values_);
}
- std::vector<uint8> GetValuesUInt8() {
- return ExtractVector<uint8>(output_values_);
+ std::vector<uint8_t> GetValuesUInt8() {
+ return ExtractVector<uint8_t>(output_values_);
}
- std::vector<int32> GetValuesInt32() {
- return ExtractVector<int32>(output_values_);
+ std::vector<int32_t> GetValuesInt32() {
+ return ExtractVector<int32_t>(output_values_);
}
std::vector<int64_t> GetValuesInt64() {
@@ -119,7 +119,7 @@ TEST(TopKV2OpTest, VectorFloat) {
EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2})));
}
-// Check that uint8 works.
+// Check that uint8_t works.
TEST(TopKV2OpTest, TypeUint8) {
TopKV2OpModel m({2, 3}, TensorType_UINT8, 2);
m.SetInputUInt8({1, 2, 3, 251, 250, 249});
@@ -128,7 +128,7 @@ TEST(TopKV2OpTest, TypeUint8) {
EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250}));
}
-// Check that int32 works.
+// Check that int32_t works.
TEST(TopKV2OpTest, TypeInt32) {
TopKV2OpModel m({2, 3}, TensorType_INT32, 2);
m.SetInputInt32({1, 2, 3, 10251, 10250, 10249});
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 3c99661029..7182374a6f 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -38,9 +39,35 @@ constexpr int kWeightsTensor = 1;
constexpr int kDataInputTensor = 2;
constexpr int kOutputTensor = 0;
-TfLiteStatus ResizeOutputShape(TfLiteContext* context,
- const TfLiteTensor* output_shape,
- TfLiteTensor* output) {
+const int kTensorNotAllocated = -1;
+
+struct OpData {
+ // IDs are the arbitrary identifiers used by TF Lite to identify and access
+ // memory buffers.
+ int im2col_id = kTensorNotAllocated;
+
+ // im2col is the only temporary currently tracked, therefore always index 0.
+ // If more temporaries are added, they should be properly tracked.
+ int32_t im2col_index = 0;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ // This is a builtin op, so we don't use the contents in 'buffer', if any.
+ // Instead, we allocate a new object to use as scratch space for im2col, and
+ // to carry information from Prepare() to Eval().
+ auto* data = new OpData;
+ eigen_support::IncrementUsageCounter(context);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ eigen_support::DecrementUsageCounter(context);
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ TfLiteTensor* output) {
// Currently only support int32 for output shape.
if (output_shape->type != kTfLiteInt32) {
context->ReportError(context, "Output shape is %d, not int32.",
@@ -56,15 +83,60 @@ TfLiteStatus ResizeOutputShape(TfLiteContext* context,
return context->ResizeTensor(context, output, output_shape_array);
}
+// Allocate temporary im2col tensor.
+static TfLiteStatus AllocateIm2colTensor(TfLiteContext* context,
+ TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+ if (data->im2col_id == kTensorNotAllocated) {
+ context->AddTensors(context, 1, &data->im2col_id);
+ }
+
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(1);
+ node->temporaries->data[data->im2col_index] = data->im2col_id;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus ResizeIm2ColTensor(TfLiteContext* context,
+ const TfLiteTensor* output_shape,
+ const TfLiteTensor* weights,
+ const TfLiteTensor* input,
+ TfLiteTensor* im2col) {
+ if (output_shape->type != kTfLiteInt32) {
+ context->ReportError(context, "im2col shape is %d, not int32.",
+ output_shape->type);
+ return kTfLiteError;
+ }
+ TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 4);
+ TfLiteIntArray* im2col_shape_array = TfLiteIntArrayCreate(4);
+ im2col_shape_array->data[0] = output_shape->data.i32[0];
+ im2col_shape_array->data[1] = output_shape->data.i32[1];
+ im2col_shape_array->data[2] = output_shape->data.i32[2];
+ const int input_depth = SizeOfDimension(input, 3);
+ const int filter_width = SizeOfDimension(weights, 1);
+ const int filter_height = SizeOfDimension(weights, 2);
+ im2col_shape_array->data[3] = input_depth * filter_height * filter_width;
+
+ im2col->type = input->type;
+ im2col->allocation_type = kTfLiteArenaRw;
+ return context->ResizeTensor(context, im2col, im2col_shape_array);
+}
+
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ TF_LITE_ENSURE_STATUS(AllocateIm2colTensor(context, node));
+
const TfLiteTensor* output_shape =
GetInput(context, node, kOutputShapeTensor);
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[user_data->im2col_index]];
TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1);
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
@@ -79,13 +151,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Ensure that weights and inputs have the same channel dimension.
// Note: TOCO will reorder weights in the following format: OHWI.
TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3),
- SizeOfDimension(weights, 0));
-
- if (!IsConstantTensor(output_shape)) {
+ SizeOfDimension(weights, 3));
+
+ if (IsConstantTensor(output_shape)) {
+ TF_LITE_ENSURE_STATUS(ResizeOutputTensor(context, output_shape, output));
+ TF_LITE_ENSURE_STATUS(
+ ResizeIm2ColTensor(context, output_shape, weights, input, im2col));
+ } else {
+ // Defer resizing until Eval().
SetTensorToDynamic(output);
- return kTfLiteOk;
}
- return ResizeOutputShape(context, output_shape, output);
+ return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -94,13 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* input = GetInput(context, node, kDataInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+ OpData* user_data = reinterpret_cast<OpData*>(node->user_data);
+ TfLiteTensor* im2col =
+ &context->tensors[node->temporaries->data[user_data->im2col_index]];
const auto* params =
reinterpret_cast<TfLiteTransposeConvParams*>(node->builtin_data);
if (IsDynamicTensor(output)) {
TF_LITE_ENSURE_OK(context,
- ResizeOutputShape(context, output_shape, output));
+ ResizeOutputTensor(context, output_shape, output));
+ }
+ if (IsDynamicTensor(im2col)) {
+ TF_LITE_ENSURE_OK(context, ResizeIm2ColTensor(context, output_shape,
+ weights, input, im2col));
}
// Get height and width of the output image.
@@ -123,7 +205,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<float>(input), GetTensorDims(input),
GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
stride_height, padding_size.width, padding_size.height,
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorData<float>(output), GetTensorDims(output),
+ GetTensorData<float>(im2col), GetTensorDims(im2col));
break;
default:
context->ReportError(context, "Type %d, not currently supported.",
@@ -136,8 +219,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
} // namespace transpose_conv
TfLiteRegistration* Register_TRANSPOSE_CONV() {
- static TfLiteRegistration r = {nullptr, nullptr, transpose_conv::Prepare,
- transpose_conv::Eval};
+ static TfLiteRegistration r = {transpose_conv::Init, transpose_conv::Free,
+ transpose_conv::Prepare, transpose_conv::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
index 52be089349..c741df19de 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -24,9 +25,49 @@ namespace {
using ::testing::ElementsAreArray;
+class ConstTransposeConvOpModel : public SingleOpModel {
+ // Just to be extra confusing, transpose_conv has an _input_ named
+ // "output_shape". This input sets the shape of the output tensor of the op.
+ // In this version of the test class, "output_shape" is a constant that must
+ // be specified in the constructor.
+ public:
+ ConstTransposeConvOpModel(TfLiteRegistration* registration,
+ std::initializer_list<int> input_shape,
+ std::initializer_list<int> filter_shape,
+ std::initializer_list<int> output_shape_data,
+ Padding padding, int stride_w, int stride_h) {
+ output_shape_ = AddConstInput(TensorType_INT32, output_shape_data,
+ {static_cast<int>(output_shape_data.size())});
+ filter_ = AddInput(TensorType_FLOAT32);
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
+ CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
+ .Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_TRANSPOSE_CONV, registration);
+ BuildInterpreter({{4}, filter_shape, input_shape});
+ }
+
+ int output_shape() { return output_shape_; }
+ int filter() { return filter_; }
+ int input() { return input_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int output_shape_;
+ int filter_;
+ int input_;
+ int output_;
+};
+
class TransposeConvOpModel : public SingleOpModel {
public:
- TransposeConvOpModel(std::initializer_list<int> input_shape,
+ TransposeConvOpModel(TfLiteRegistration* registration,
+ std::initializer_list<int> input_shape,
std::initializer_list<int> filter_shape, Padding padding,
int stride_w, int stride_h) {
output_shape_ = AddInput(TensorType_INT32);
@@ -37,6 +78,8 @@ class TransposeConvOpModel : public SingleOpModel {
BuiltinOperator_TRANSPOSE_CONV, BuiltinOptions_TransposeConvOptions,
CreateTransposeConvOptions(builder_, padding, stride_w, stride_h)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_TRANSPOSE_CONV, registration);
BuildInterpreter({{4}, filter_shape, input_shape});
}
@@ -54,6 +97,15 @@ class TransposeConvOpModel : public SingleOpModel {
int output_;
};
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({});
+
+class TransposeConvOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// Test case:
// output = tf.nn.conv2d_backprop_input(
// tf.constant([ 1, 4, 4, 1 ]),
@@ -61,8 +113,9 @@ class TransposeConvOpModel : public SingleOpModel {
// tf.constant(np.arange(1, 17), shape=[ 1, 4, 4, 1 ], dtype=tf.float32),
// [1, 1, 1, 1 ],
// "SAME")
-TEST(TransposeConvOpModelTest, SimpleTest) {
- TransposeConvOpModel m({1, 4, 4, 1}, {1, 3, 3, 1}, Padding_SAME, 1, 1);
+TEST_P(TransposeConvOpTest, SimpleTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 3, 3, 1},
+ Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(
@@ -75,6 +128,21 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+// Test case: Same as above, but with a const "output_shape"
+TEST_P(TransposeConvOpTest, ConstSimpleTest) {
+ ConstTransposeConvOpModel m(GetRegistration(), {1, 4, 4, 1}, {1, 4, 4, 1},
+ {1, 3, 3, 1}, Padding_SAME, 1, 1);
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ m.PopulateTensor<float>(
+ m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({29, 62, 83, 75, 99, 192, 237, 198, 207, 372,
+ 417, 330, 263, 446, 485, 365}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
+}
+
// Test case:
// filter = tf.constant(np.arange(1, 19),
// shape=[ 3, 3, 1, 2 ],
@@ -87,11 +155,12 @@ TEST(TransposeConvOpModelTest, SimpleTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1])
-TEST(TransposeConvOpModelTest, TwoFiltersTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1);
+TEST_P(TransposeConvOpTest, TwoFiltersTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
+ Padding_SAME, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 4, 4, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -116,11 +185,12 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) {
// "VALID")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18])
-TEST(TransposeConvOpModelTest, PaddingValidTest) {
- TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1);
+TEST_P(TransposeConvOpTest, PaddingValidTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 4, 4, 2}, {1, 3, 3, 2},
+ Padding_VALID, 1, 1);
m.PopulateTensor<int>(m.output_shape(), {1, 6, 6, 1});
- m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
- 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18});
m.PopulateTensor<float>(
m.input(),
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
@@ -146,8 +216,9 @@ TEST(TransposeConvOpModelTest, PaddingValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST(TransposeConvOpModelTest, StrideValidTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 1}, Padding_VALID, 2, 2);
+TEST_P(TransposeConvOpTest, StrideValidTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {1, 3, 3, 1},
+ Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 1});
m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
@@ -170,11 +241,30 @@ TEST(TransposeConvOpModelTest, StrideValidTest) {
// tf.constant(np.arange(1, 5), shape=[ 1, 2, 2, 1 ], dtype=tf.float32),
// [1, 2, 2, 1 ],
// "VALID")
-TEST(TransposeConvOpModelTest, MultiChannelTest) {
- TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2);
+TEST_P(TransposeConvOpTest, MultiChannelTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
+ Padding_VALID, 2, 2);
m.PopulateTensor<int>(m.output_shape(), {1, 5, 5, 2});
- m.PopulateTensor<float>(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
- 13, 14, 15, 16, 17, 18});
+ m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
+ 8, 10, 12, 14, 16, 18});
+ m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
+ m.Invoke();
+
+ EXPECT_THAT(
+ m.GetOutput(),
+ ElementsAreArray({1, 2, 3, 4, 7, 10, 6, 8, 10, 12, 7, 8, 9,
+ 10, 25, 28, 18, 20, 22, 24, 16, 20, 24, 28, 62, 72,
+ 42, 48, 54, 60, 21, 24, 27, 30, 61, 68, 36, 40, 44,
+ 48, 39, 42, 45, 48, 103, 110, 60, 64, 68, 72}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 5, 5, 2}));
+}
+
+// Test case: Same as above, but with a const "output_shape"
+TEST_P(TransposeConvOpTest, ConstMultiChannelTest) {
+ ConstTransposeConvOpModel m(GetRegistration(), {1, 2, 2, 1}, {2, 3, 3, 1},
+ {1, 5, 5, 2}, Padding_VALID, 2, 2);
+ m.PopulateTensor<float>(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6,
+ 8, 10, 12, 14, 16, 18});
m.PopulateTensor<float>(m.input(), {1, 2, 3, 4});
m.Invoke();
@@ -199,8 +289,9 @@ TEST(TransposeConvOpModelTest, MultiChannelTest) {
// "SAME")
// And filter value is derived by:
// filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[-1])
-TEST(TransposeConvOpModelTest, AccuracyTest) {
- TransposeConvOpModel m({1, 1, 2, 1}, {1, 3, 3, 1}, Padding_SAME, 3, 3);
+TEST_P(TransposeConvOpTest, AccuracyTest) {
+ TransposeConvOpModel m(GetRegistration(), {1, 1, 2, 1}, {1, 3, 3, 1},
+ Padding_SAME, 3, 3);
m.PopulateTensor<int>(m.output_shape(), {1, 3, 4, 1});
m.PopulateTensor<float>(m.filter(), {9, 5, 6, 9, 8, 5, 3, 1, 4});
m.PopulateTensor<float>(m.input(), {323, 521});
@@ -212,6 +303,10 @@ TEST(TransposeConvOpModelTest, AccuracyTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 4, 1}));
}
+INSTANTIATE_TEST_CASE_P(
+ TransposeConvOpTest, TransposeConvOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 1c28123a24..32daf2bb02 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -70,9 +70,21 @@ constexpr int kOutputStateTensor = 0;
constexpr int kCellStateTensor = 1;
constexpr int kOutputTensor = 2;
+// Temporary tensors
+enum TemporaryTensor {
+ kScratchBuffer = 0,
+ kInputQuantized = 1,
+ kOutputStateQuantized = 2,
+ kCellStateQuantized = 3,
+ kScalingFactors = 4,
+ kProductScalingFactors = 5,
+ kRecoveredCellWeights = 6,
+ kNumTemporaryTensors = 7
+};
+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 1, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -84,7 +96,7 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteNode* node, int n_input,
int n_output, int n_cell) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -242,6 +254,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context, input->dims->size > 1);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -288,86 +301,156 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, cell_state, cell_size));
- // Create a scratch buffer tensor.
+ // Mark state tensors as persistent tensors.
+ output_state->allocation_type = kTfLiteArenaRwPersistent;
+ cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+ // The weights are of consistent type, so it suffices to check one.
+ // TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(1);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, kScratchBuffer);
scratch_buffer->type = input->type;
scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* input_to_input_weights =
GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
if (use_cifg) {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 3;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
} else {
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
// Reserving space for Input, Cell, Forget, Output gates
scratch_buffer_size->data[1] = n_cell * 4;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // output_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[kOutputStateQuantized] =
+ *scratch_tensor_index + kOutputStateQuantized;
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, kOutputStateQuantized);
+ output_state_quantized->type = kTfLiteUInt8;
+ output_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(output_state_quantized->dims,
+ output_state->dims)) {
+ TfLiteIntArray* output_state_quantized_size =
+ TfLiteIntArrayCopy(output_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output_state_quantized,
+ output_state_quantized_size));
+ }
+ node->temporaries->data[kCellStateQuantized] =
+ *scratch_tensor_index + kCellStateQuantized;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, kCellStateQuantized);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
}
return kTfLiteOk;
}
// The LSTM Op engine.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -380,8 +463,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const bool use_cifg = (input_to_input_weights == nullptr);
const bool use_peephole = (cell_to_output_weights != nullptr);
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
@@ -432,6 +513,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
float* output_state_ptr = output_state->data.f;
float* cell_state_ptr = cell_state->data.f;
+ // Feed the sequence into the LSTM step-by-step.
for (int t = 0; t < max_time; t++) {
const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
float* output_ptr_batch = output->data.f + t * n_batch * n_output;
@@ -452,6 +534,262 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
+ TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
+ TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ // Feed the sequence into the LSTM step-by-step.
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
+ float* output_ptr_batch = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStep(
+ input_ptr_batch, input_to_input_weights_ptr,
+ input_to_input_weights_scale, input_to_forget_weights_ptr,
+ input_to_forget_weights_scale, input_to_cell_weights_ptr,
+ input_to_cell_weights_scale, input_to_output_weights_ptr,
+ input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors_ptr,
+ prod_scaling_factors_ptr, recovered_cell_weights_ptr,
+ quantized_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr_batch);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+ TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params,
+ scratch_buffer, output_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* output_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, params, scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, output_state_quantized, cell_state_quantized,
+ output_state, cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
} // namespace unidirectional_sequence_lstm
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index 5881ced7c7..de38bdef6f 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite Sequential LSTM op.
-#include <iomanip>
#include <memory>
#include <vector>
@@ -37,7 +36,8 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip,
float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weights_type = TensorType_FLOAT32)
: n_batch_(n_batch),
n_input_(n_input),
n_cell_(n_cell),
@@ -48,31 +48,31 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
if (use_cifg) {
input_to_input_weights_ = AddNullInput();
} else {
- input_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_input_weights_ = AddInput(weights_type);
}
- input_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- input_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- input_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ input_to_forget_weights_ = AddInput(weights_type);
+ input_to_cell_weights_ = AddInput(weights_type);
+ input_to_output_weights_ = AddInput(weights_type);
if (use_cifg) {
recurrent_to_input_weights_ = AddNullInput();
} else {
- recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_input_weights_ = AddInput(weights_type);
}
- recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32);
- recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ recurrent_to_forget_weights_ = AddInput(weights_type);
+ recurrent_to_cell_weights_ = AddInput(weights_type);
+ recurrent_to_output_weights_ = AddInput(weights_type);
if (use_peephole) {
if (use_cifg) {
cell_to_input_weights_ = AddNullInput();
} else {
- cell_to_input_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_input_weights_ = AddInput(weights_type);
}
- cell_to_forget_weights_ = AddInput(TensorType_FLOAT32);
- cell_to_output_weights_ = AddInput(TensorType_FLOAT32);
+ cell_to_forget_weights_ = AddInput(weights_type);
+ cell_to_output_weights_ = AddInput(weights_type);
} else {
cell_to_input_weights_ = AddNullInput();
cell_to_forget_weights_ = AddNullInput();
@@ -89,7 +89,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_gate_bias_ = AddInput(TensorType_FLOAT32);
if (use_projection_weights) {
- projection_weights_ = AddInput(TensorType_FLOAT32);
+ projection_weights_ = AddInput(weights_type);
if (use_projection_bias) {
projection_bias_ = AddInput(TensorType_FLOAT32);
} else {
@@ -196,8 +196,9 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
zero_buffer.get() + zero_buffer_size);
}
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -208,7 +209,7 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int num_batches() { return n_batch_; }
int sequence_length() { return sequence_length_; }
- private:
+ protected:
int input_;
int input_to_input_weights_;
int input_to_forget_weights_;
@@ -243,7 +244,183 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int sequence_length_;
};
-TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
+// The hybrid model has quantized weights.
+class HybridUnidirectionalLSTMOpModel : public UnidirectionalLSTMOpModel {
+ public:
+ HybridUnidirectionalLSTMOpModel(
+ int n_batch, int n_input, int n_cell, int n_output, int sequence_length,
+ bool use_cifg, bool use_peephole, bool use_projection_weights,
+ bool use_projection_bias, float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : UnidirectionalLSTMOpModel(
+ n_batch, n_input, n_cell, n_output, sequence_length, use_cifg,
+ use_peephole, use_projection_weights, use_projection_bias,
+ cell_clip, proj_clip, input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> lstm_input_;
+ // LSTM output is stored as num_batch x num_outputs vector.
+ std::vector<std::vector<float>> lstm_golden_output_;
+
+ // Compares output up to tolerance to the result of the lstm given the input.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ UnidirectionalLSTMOpModel* lstm, float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ // Feed the whole sequence as input.
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ lstm->SetInput(((i * num_batches) + b) * lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+ }
+
+ lstm->Invoke();
+
+ const int num_outputs = lstm->num_outputs();
+ EXPECT_GT(num_outputs, 0);
+ std::vector<float> expected;
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ }
+
+ EXPECT_THAT(lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+};
+
+class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524};
+ input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113, -0.29909778};
+ input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212};
+ input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077,
+ -0.1556896, 0.19487578};
+ input_gate_bias_ = {0., 0., 0., 0.};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_input_weights_ = {
+ -0.0063535, -0.2042388, 0.31454784, -0.35746509,
+ 0.28902304, 0.08183324, -0.16555229, 0.02286911,
+ -0.13566875, 0.03034258, 0.48091322, -0.12528998,
+ 0.24077177, -0.51332325, -0.33502164, 0.10629296};
+
+ recurrent_to_cell_weights_ = {
+ -0.3407414, 0.24443203, -0.2078532, 0.26320225,
+ 0.05695659, -0.00123841, -0.4744786, -0.35869038,
+ -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064};
+
+ recurrent_to_forget_weights_ = {
+ -0.48684245, -0.06655136, 0.42224967, 0.2112639,
+ 0.27654213, 0.20864892, -0.07646349, 0.45877004,
+ 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004};
+
+ recurrent_to_output_weights_ = {
+ 0.43385774, -0.17194885, 0.2718237, 0.09215671,
+ 0.24107647, -0.39835793, 0.18212086, 0.01301402,
+ 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792}};
+ }
+};
+
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -252,9 +429,11 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -281,77 +460,138 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
- -0.34550029, 0.04266912, -0.15680569,
- -0.34856534, 0.43890524});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
- -0.20583314, 0.44344562, 0.22077113,
- -0.29909778});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
- -0.31343272, -0.40032279, 0.44781327,
- 0.01387155, -0.35593212});
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
- 0.40525138, 0.44272184, 0.03897077, -0.1556896,
- 0.19487578});
+TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
- lstm.SetInputGateBias({0., 0., 0., 0.});
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/false,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.SetCellBias({0., 0., 0., 0.});
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
- lstm.SetRecurrentToInputWeights(
- {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
- -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
- -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
- lstm.SetRecurrentToCellWeights(
- {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
- -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
- -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
- lstm.SetRecurrentToForgetWeights(
- {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
- -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
- 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToOutputWeights(
- {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
- 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
- -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- // Input should have n_input * sequence_length many values.
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126,
- -0.15358765, -0.03716109, 0.12507336,
- 0.41193449, -0.20860538, -0.15053082,
- 0.09120187, 0.24278517, -0.12222792};
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
+ /*tolerance=*/0.0157651);
+}
- lstm.SetInput(0, batch0_start, batch0_end);
+class CifgPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726,
+ 0.05100781, 0.04717243, 0.48944736,
+ -0.38535351, -0.17212132};
- lstm.Invoke();
+ input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698,
+ 0.24407166, 0.33826375};
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
-}
+ input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634};
+ cell_gate_bias_ = {0., 0., 0., 0.};
+ forget_gate_bias_ = {1., 1., 1., 1.};
+ output_gate_bias_ = {0., 0., 0., 0.};
+
+ recurrent_to_cell_weights_ = {
+ 0.54066205, -0.32668582, -0.43562764, -0.56094903,
+ 0.42957711, 0.01841056, -0.32764608, -0.33027974,
+ -0.10826075, 0.20675004, 0.19069612, -0.03026325,
+ -0.54532051, 0.33003211, 0.44901288, 0.21193194};
+
+ recurrent_to_forget_weights_ = {
+ -0.13832897, -0.0515101, -0.2359007, -0.16661474,
+ -0.14340827, 0.36986142, 0.23414481, 0.55899,
+ 0.10798943, -0.41174671, 0.17751795, -0.34484994,
+ -0.35874045, -0.11352962, 0.27268326, 0.54058349};
+
+ recurrent_to_output_weights_ = {
+ 0.41613156, 0.42610586, -0.16495961, -0.5663873,
+ 0.30579174, -0.05115908, -0.33941799, 0.23364776,
+ 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987};
+
+ cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408,
+ 0.31544167};
+ cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703,
+ -0.77109635};
+
+ lstm_input_ = {{2., 3., 3., 4., 1., 1.}};
+ lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302}};
+ }
+};
-TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 1;
const int n_input = 2;
// n_cell and n_output have the same size when there is no projection.
@@ -360,9 +600,11 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
const int sequence_length = 3;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
- /*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -389,71 +631,690 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
- 0.04717243, 0.48944736, -0.38535351,
- -0.17212132});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
- -0.3633365, -0.22755712, 0.28253698, 0.24407166,
- 0.33826375});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
- -0.09426838, -0.44257352, 0.54939759,
- 0.01533556, 0.42751634});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
+
+TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/true, /*use_peephole=*/true,
+ /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- lstm.SetCellBias({0., 0., 0., 0.});
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
- lstm.SetForgetGateBias({1., 1., 1., 1.});
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
- lstm.SetOutputGateBias({0., 0., 0., 0.});
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
- lstm.SetRecurrentToCellWeights(
- {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
- 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
- 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
- 0.21193194});
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
- lstm.SetRecurrentToForgetWeights(
- {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
- 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
- -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
- lstm.SetRecurrentToOutputWeights(
- {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
- -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
- 0.50248802, 0.26114327, -0.43736315, 0.33149987});
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
- lstm.SetCellToForgetWeights(
- {0.47485286, -0.51955009, -0.24458408, 0.31544167});
- lstm.SetCellToOutputWeights(
- {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
- static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585,
- -0.05163646, -0.42312205, -0.01218222,
- 0.24201041, -0.08124574, -0.358325,
- -0.04621704, 0.21641694, -0.06471302};
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- float* batch0_start = lstm_input;
- float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
-
- lstm.SetInput(0, batch0_start, batch0_end);
-
- lstm.Invoke();
-
- float* golden_start = lstm_golden_output;
- float* golden_end =
- golden_start + lstm.num_outputs() * lstm.sequence_length();
- std::vector<float> expected;
- expected.insert(expected.end(), golden_start, golden_end);
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
+class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {
+ 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
+ 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
+ -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
+ -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
+ -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
+ -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
+ -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
+ 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
+ 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
+ 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
+ -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
+ 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
+ -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
+ -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
+ -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
+ 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
+ -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
+ -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
+ -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
+ -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677};
+
+ input_to_forget_weights_ = {
+ -0.0018401089, -0.004852237, 0.03698424, 0.014181704,
+ 0.028273236, -0.016726194, -0.05249759, -0.10204261,
+ 0.00861066, -0.040979505, -0.009899187, 0.01923892,
+ -0.028177269, -0.08535103, -0.14585495, 0.10662567,
+ -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
+ 0.0030784295, 0.076784775, 0.07463696, 0.094531395,
+ 0.0814421, -0.12257899, -0.033945758, -0.031303465,
+ 0.045630626, 0.06843887, -0.13492945, -0.012480007,
+ -0.0811829, -0.07224499, -0.09628791, 0.045100946,
+ 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
+ 0.06958324, 0.034257296, 0.0482646, 0.06267997,
+ 0.052625068, 0.12784666, 0.07077897, 0.025725935,
+ 0.04165009, 0.07241905, 0.018668644, -0.037377294,
+ -0.06277783, -0.08833636, -0.040120605, -0.011405586,
+ -0.007808335, -0.010301386, -0.005102167, 0.027717464,
+ 0.05483423, 0.11449111, 0.11289652, 0.10939839,
+ 0.13396506, -0.08402166, -0.01901462, -0.044678304,
+ -0.07720565, 0.014350063, -0.11757958, -0.0652038,
+ -0.08185733, -0.076754324, -0.092614375, 0.10405491,
+ 0.052960336, 0.035755895, 0.035839386, -0.012540553,
+ 0.036881298, 0.02913376, 0.03420159, 0.05448447,
+ -0.054523353, 0.02582715, 0.02327355, -0.011857179,
+ -0.0011980024, -0.034641717, -0.026125094, -0.17582615,
+ -0.15923657, -0.27486774, -0.0006143371, 0.0001771948,
+ -8.470171e-05, 0.02651807, 0.045790765, 0.06956496};
+
+ input_to_cell_weights_ = {
+ -0.04580283, -0.09549462, -0.032418985, -0.06454633,
+ -0.043528453, 0.043018587, -0.049152344, -0.12418144,
+ -0.078985475, -0.07596889, 0.019484362, -0.11434962,
+ -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
+ -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
+ 0.10665918, -0.032036792, -0.08505916, -0.10843358,
+ -0.13002433, -0.036816437, -0.02130134, -0.016518239,
+ 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
+ -0.10652836, -0.1037554, -0.13056071, -0.03266643,
+ -0.033702414, -0.006473424, -0.04611692, 0.014419339,
+ -0.025174323, 0.0396852, 0.081777506, 0.06157468,
+ 0.10210095, -0.009658194, 0.046511717, 0.03603906,
+ 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
+ 0.053568836, 0.06408714, 0.12835667, -0.008714329,
+ -0.20211966, -0.12093674, 0.029450472, 0.2849013,
+ -0.029227901, 0.1164364, -0.08560263, 0.09941786,
+ -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
+ -0.09720865, -0.11193351, -0.029155117, -0.017936034,
+ -0.009768936, -0.04223324, -0.036159635, 0.06505112,
+ -0.021742892, -0.023377212, -0.07221364, -0.06430552,
+ 0.05453865, 0.091149814, 0.06387331, 0.007518393,
+ 0.055960953, 0.069779344, 0.046411168, 0.10509911,
+ 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
+ 0.056955688, 0.06555285, 0.050801456, -0.009862683,
+ 0.00826772, -0.026555609, -0.0073611983, -0.0014897042};
+
+ input_to_output_weights_ = {
+ -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
+ -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
+ 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
+ -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
+ -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
+ 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
+ -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
+ -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
+ -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
+ -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
+ 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
+ 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
+ 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
+ -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
+ 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
+ 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
+ -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
+ 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
+ -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
+ -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956};
+
+ input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666,
+ 0.053110216, -0.06928846, -0.13942584, -0.11816189,
+ 0.19483899, 0.03652339, -0.10250295, 0.036714908,
+ -0.18426876, 0.036065217, 0.21810818, 0.02383196,
+ -0.043370757, 0.08690144, -0.04444982, 0.00030581196};
+
+ forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696,
+ 0.11098921, 0.15378423, 0.09263801, 0.09790885,
+ 0.09508917, 0.061199076, 0.07665568, -0.015443159,
+ -0.03499149, 0.046190713, 0.08895977, 0.10899629,
+ 0.40694186, 0.06030037, 0.012413437, -0.06108739};
+
+ cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
+ -0.1483596, -0.10639995, -0.091433935, 0.058573797,
+ -0.06809782, -0.07889636, -0.043246906, -0.09829136,
+ -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
+ 0.016178843, 0.1749513, 0.13975595, 0.92058027};
+
+ output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113,
+ 0.027195795, 0.35373217, -0.018957434, 0.008907322,
+ -0.0762701, 0.12018895, 0.04216877, 0.0022856654,
+ 0.040952638, 0.3147856, 0.08225149, -0.057416286,
+ -0.14995944, -0.008040261, 0.13208859, 0.029760877};
+
+ recurrent_to_input_weights_ = {
+ -0.001374326, -0.078856036, 0.10672688, 0.029162422,
+ -0.11585556, 0.02557986, -0.13446963, -0.035785314,
+ -0.01244275, 0.025961924, -0.02337298, -0.044228926,
+ -0.055839065, -0.046598054, -0.010546039, -0.06900766,
+ 0.027239809, 0.022582639, -0.013296484, -0.05459212,
+ 0.08981, -0.045407712, 0.08682226, -0.06867011,
+ -0.14390695, -0.02916037, 0.000996957, 0.091420636,
+ 0.14283475, -0.07390571, -0.06402044, 0.062524505,
+ -0.093129106, 0.04860203, -0.08364217, -0.08119002,
+ 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
+ -0.13732095, 0.012405723, -0.07551853, 0.06343048,
+ 0.12162708, -0.031923793, -0.014335606, 0.01790974,
+ -0.10650317, -0.0724401, 0.08554849, -0.05727212,
+ 0.06556731, -0.042729504, -0.043227166, 0.011683251,
+ -0.013082158, -0.029302018, -0.010899579, -0.062036745,
+ -0.022509435, -0.00964907, -0.01567329, 0.04260106,
+ -0.07787477, -0.11576462, 0.017356863, 0.048673786,
+ -0.017577527, -0.05527947, -0.082487635, -0.040137455,
+ -0.10820036, -0.04666372, 0.022746278, -0.07851417,
+ 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
+ 0.08944216, -0.0685835, 0.010513544, 0.07228705,
+ 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
+ 0.040414046, -0.1380399, 0.094208956, -0.05722982,
+ 0.012092817, -0.04989123, -0.086576, -0.003399834,
+ -0.04696032, -0.045747425, 0.10091314, 0.048676282,
+ -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
+ 0.09504992, 0.041799378, -0.049185462, -0.031518843,
+ -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
+ -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
+ -0.10167381, 0.042500053, -0.01447153, 0.06464186,
+ -0.017142897, 0.03312627, 0.009205989, 0.024138335,
+ -0.011337001, 0.035530265, -0.010912711, 0.0706555,
+ -0.005894094, 0.051841937, -0.1401738, -0.02351249,
+ 0.0365468, 0.07590991, 0.08838724, 0.021681072,
+ -0.10086113, 0.019608743, -0.06195883, 0.077335775,
+ 0.023646897, -0.095322326, 0.02233014, 0.09756986,
+ -0.048691444, -0.009579111, 0.07595467, 0.11480546,
+ -0.09801813, 0.019894179, 0.08502348, 0.004032281,
+ 0.037211012, 0.068537936, -0.048005626, -0.091520436,
+ -0.028379958, -0.01556313, 0.06554592, -0.045599163,
+ -0.01672207, -0.020169014, -0.011877351, -0.20212261,
+ 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
+ -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
+ 0.015963363, 0.00871737, 0.060130805, 0.028611384,
+ 0.10109069, -0.015060172, -0.07894427, 0.06401885,
+ 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
+ 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
+ 0.019899689, 0.006106124, -0.027092824, 0.0786356,
+ 0.05052217, -0.058925, -0.011402121, -0.024987547,
+ -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
+ -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
+ -0.033664223, -0.07978348, -0.025200296, -0.017207067,
+ -0.058403496, -0.055697463, 0.005798788, 0.12965427,
+ -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
+ 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
+ 0.013806405, -0.017858358, -0.01008298, -0.07700066,
+ -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
+ 0.062634714, -0.02338735, -0.039547626, -0.02050681,
+ 0.03385117, -0.083611414, 0.002862572, -0.09421313,
+ 0.058618143, -0.08598433, 0.00972939, 0.023867095,
+ -0.053934585, -0.023203006, 0.07452513, -0.048767887,
+ -0.07314807, -0.056307215, -0.10433547, -0.06440842,
+ 0.04328182, 0.04389765, -0.020006588, -0.09076438,
+ -0.11652589, -0.021705797, 0.03345259, -0.010329105,
+ -0.025767034, 0.013057034, -0.07316461, -0.10145612,
+ 0.06358255, 0.18531723, 0.07759293, 0.12006465,
+ 0.1305557, 0.058638252, -0.03393652, 0.09622831,
+ -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
+ -0.005644518, 0.06857898, -0.12598175, -0.035084512,
+ 0.03156317, -0.12794146, -0.031963028, 0.04692781,
+ 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
+ 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
+ 0.08184801, -0.019164372, 0.06791302, 0.034257166,
+ -0.10307039, 0.021943003, 0.046745934, 0.0790918,
+ -0.0265588, -0.007824208, 0.042546265, -0.00977924,
+ -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
+ -0.014512694, -0.08251313, 0.08861942, 0.13589665,
+ 0.026351685, 0.012641483, 0.07466548, 0.044301085,
+ -0.045414884, -0.051112458, 0.03444247, -0.08502782,
+ -0.04106223, -0.028126027, 0.028473156, 0.10467447};
+
+ recurrent_to_cell_weights_ = {
+ -0.037322544, 0.018592842, 0.0056175636, -0.06253426,
+ 0.055647098, -0.05713207, -0.05626563, 0.005559383,
+ 0.03375411, -0.025757805, -0.088049285, 0.06017052,
+ -0.06570978, 0.007384076, 0.035123326, -0.07920549,
+ 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
+ 0.08089997, 0.05143358, 0.038261272, 0.03339287,
+ -0.027673481, 0.044746667, 0.028349208, 0.020090483,
+ -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
+ -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
+ -0.10893326, 0.076739706, -0.08509834, -0.027997585,
+ 0.037871376, 0.01449768, -0.09002357, -0.06111149,
+ -0.046195522, 0.0422062, -0.005683705, -0.1253618,
+ -0.012925729, -0.04890792, 0.06985068, 0.037654128,
+ 0.03398274, -0.004781977, 0.007032333, -0.031787455,
+ 0.010868644, -0.031489216, 0.09525667, 0.013939797,
+ 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
+ -0.048885044, -0.12722108, 0.035304096, 0.06554885,
+ 0.00972396, -0.039238118, -0.05159735, -0.11329045,
+ 0.1613692, -0.03750952, 0.06529313, -0.071974665,
+ -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
+ 0.02786344, -0.014179351, 0.005264273, 0.14376344,
+ 0.015983658, 0.03406988, -0.06939408, 0.040699873,
+ 0.02111075, 0.09669095, 0.041345075, -0.08316494,
+ -0.07684199, -0.045768797, 0.032298047, -0.041805092,
+ 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
+ -0.024950314, 0.11574242, 0.04508852, -0.04335324,
+ 0.06760663, -0.027437469, 0.07216407, 0.06977076,
+ -0.05438599, 0.034033038, -0.028602652, 0.05346137,
+ 0.043184172, -0.037189785, 0.10420091, 0.00882477,
+ -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
+ 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
+ 0.04361412, -0.007001822, 0.09631092, -0.06702025,
+ -0.042049985, -0.035070654, -0.04103342, -0.10273396,
+ 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
+ -0.008264958, 0.042035464, 0.05891794, 0.029673764,
+ 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
+ -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
+ -0.04043371, -0.017094059, 0.07229206, -0.023670016,
+ -0.052195564, -0.025616996, -0.01520939, 0.045104615,
+ -0.007376126, 0.003533447, 0.006570588, 0.056037236,
+ 0.12436656, 0.051817212, 0.028532185, -0.08686856,
+ 0.11868599, 0.07663395, -0.07323171, 0.03463402,
+ -0.050708205, -0.04458982, -0.11590894, 0.021273347,
+ 0.1251325, -0.15313013, -0.12224372, 0.17228661,
+ 0.023029093, 0.086124025, 0.006445803, -0.03496501,
+ 0.028332196, 0.04449512, -0.042436164, -0.026587414,
+ -0.006041347, -0.09292539, -0.05678812, 0.03897832,
+ 0.09465633, 0.008115513, -0.02171956, 0.08304309,
+ 0.071401566, 0.019622514, 0.032163795, -0.004167056,
+ 0.02295182, 0.030739572, 0.056506045, 0.004612461,
+ 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
+ -0.1335546, -0.030136576, 0.11584653, -0.014678886,
+ 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
+ -0.0329582, 0.07922767, 0.029322514, 0.026405897,
+ 0.04207835, -0.07073373, 0.063781224, 0.0859677,
+ -0.10925287, -0.07011058, 0.048005477, 0.03438226,
+ -0.09606514, -0.006669445, -0.043381985, 0.04240257,
+ -0.06955775, -0.06769346, 0.043903265, -0.026784198,
+ -0.017840602, 0.024307009, -0.040079936, -0.019946516,
+ 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
+ 0.15978073, 0.10185836, 0.10298046, -0.015476589,
+ -0.039390966, -0.072174534, 0.0739445, -0.1211869,
+ -0.0347889, -0.07943156, 0.014809798, -0.12412325,
+ -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
+ -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
+ -0.01514876, -0.056505352, -0.012800942, -0.06994386,
+ 0.012962922, -0.031234352, 0.07029052, 0.016418684,
+ 0.03618972, 0.055686004, -0.08663945, -0.017404709,
+ -0.054761406, 0.029065743, 0.052404847, 0.020238016,
+ 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
+ 0.06262858, 0.009184685, 0.020785125, -0.043904778,
+ -0.0270329, -0.03299152, -0.060088247, -0.015162964,
+ -0.001828936, 0.12642565, -0.056757294, 0.013586685,
+ 0.09232601, -0.035886683, 0.06000002, 0.05229691,
+ -0.052580316, -0.082029596, -0.010794592, 0.012947712,
+ -0.036429964, -0.085508935, -0.13127148, -0.017744139,
+ 0.031502828, 0.036232427, -0.031581745, 0.023051167,
+ -0.05325106, -0.03421577, 0.028793324, -0.034633752,
+ -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
+ -0.008799762, 0.056595087, 0.0022273948, 0.055752404};
+
+ recurrent_to_forget_weights_ = {
+ -0.057784554, -0.026057621, -0.068447545, -0.022581743,
+ 0.14811787, 0.10826372, 0.09471067, 0.03987225,
+ -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
+ 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
+ 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
+ -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
+ -0.06193199, 0.055729095, 0.03736828, 0.020123724,
+ 0.061878487, -0.04729229, 0.034919553, -0.07585433,
+ -0.04421272, -0.044019096, 0.085488975, 0.04058006,
+ -0.06890133, -0.030951202, -0.024628663, -0.07672815,
+ 0.034293607, 0.08556707, -0.05293577, -0.033561368,
+ -0.04899627, 0.0241671, 0.015736353, -0.095442444,
+ -0.029564252, 0.016493602, -0.035026584, 0.022337519,
+ -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
+ 0.016435321, -0.03263031, -0.09543275, -0.047392778,
+ 0.013454138, 0.028934088, 0.01685226, -0.086110644,
+ -0.046250615, -0.01847454, 0.047608484, 0.07339695,
+ 0.034546845, -0.04881143, 0.009128804, -0.08802852,
+ 0.03761666, 0.008096139, -0.014454086, 0.014361001,
+ -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
+ -0.06509276, -0.006021153, -0.08570962, -0.1451793,
+ 0.060212336, 0.055259194, 0.06974018, 0.049454916,
+ -0.027794661, -0.08077226, -0.016179763, 0.1169753,
+ 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
+ -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
+ 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
+ -0.05695512, 0.047233116, 0.038937137, -0.06542224,
+ 0.014429736, -0.09719407, 0.13908425, -0.05379757,
+ 0.012321099, 0.082840554, -0.029899208, 0.044217527,
+ 0.059855383, 0.07711018, -0.045319796, 0.0948846,
+ -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
+ -0.13873616, 0.040668588, 0.034832682, -0.015319203,
+ -0.018715994, 0.046002675, 0.0599172, -0.043107376,
+ 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
+ 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
+ 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
+ 0.052958444, 0.07558703, 0.04817258, 0.044462286,
+ -0.015213451, -0.08783778, -0.0561384, -0.003008196,
+ 0.047060397, -0.002058388, 0.03429439, -0.018839769,
+ 0.024734668, 0.024614193, -0.042046934, 0.09597743,
+ -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
+ -0.02558259, -0.022822596, -0.023273505, -0.02464396,
+ -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
+ 0.04383914, -0.046476185, 0.028658995, 0.060410924,
+ 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
+ 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
+ 0.015898481, 0.021362653, -0.030262267, 0.016587038,
+ -0.011442813, 0.041154444, -0.007631438, -0.03423484,
+ -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
+ 0.02318443, -0.041350313, 0.021485701, -0.10906167,
+ -0.028218046, -0.00954771, 0.020531068, -0.11995105,
+ -0.03672871, 0.024019798, 0.014255957, -0.05221243,
+ -0.00661567, -0.04630967, 0.033188973, 0.10107534,
+ -0.014027541, 0.030796422, -0.10270911, -0.035999842,
+ 0.15443139, 0.07684145, 0.036571592, -0.035900835,
+ -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
+ -0.03858649, 0.01849943, 0.13872518, 0.01503974,
+ 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
+ -0.047401894, 0.03100163, -0.041533746, -0.10430945,
+ 0.044574402, -0.01425562, -0.024290353, 0.034563623,
+ 0.05866852, 0.023947537, -0.09445152, 0.035450947,
+ 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
+ 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
+ 0.03532124, -0.016341697, 0.09685456, -0.016764693,
+ 0.051808182, 0.05875331, -0.04536488, 0.001626336,
+ -0.028892258, -0.01048663, -0.009793449, -0.017093895,
+ 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
+ -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
+ -0.01769146, 0.040995963, 0.02235177, -0.060430344,
+ 0.11475477, -0.023854522, 0.10071741, 0.0686208,
+ -0.014250481, 0.034261297, 0.047418304, 0.08562733,
+ -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
+ 0.04096551, 0.032249358, -0.08355519, -0.026823482,
+ 0.056386515, -0.010401743, -0.028396193, 0.08507674,
+ 0.014410365, 0.020995233, 0.17040324, 0.11511526,
+ 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
+ -0.081302024, 0.017264642, -0.009585969, 0.09491168,
+ -0.051313367, 0.054532815, -0.014298593, 0.10657464,
+ 0.007076659, 0.10964551, 0.0409152, 0.008275321,
+ -0.07283536, 0.07937492, 0.04192024, -0.1075027};
+
+ recurrent_to_output_weights_ = {
+ 0.025825322, -0.05813119, 0.09495884, -0.045984812,
+ -0.01255415, -0.0026479573, -0.08196161, -0.054914974,
+ -0.0046604523, -0.029587349, -0.044576716, -0.07480124,
+ -0.082868785, 0.023254942, 0.027502948, -0.0039728214,
+ -0.08683098, -0.08116779, -0.014675607, -0.037924774,
+ -0.023314456, -0.007401714, -0.09255757, 0.029460307,
+ -0.08829125, -0.005139627, -0.08989442, -0.0555066,
+ 0.13596267, -0.025062224, -0.048351806, -0.03850004,
+ 0.07266485, -0.022414139, 0.05940088, 0.075114764,
+ 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
+ -0.025980417, 0.072999895, 0.11091378, -0.081685916,
+ 0.014416728, 0.043229222, 0.034178585, -0.07530371,
+ 0.035837382, -0.085607, -0.007721233, -0.03287832,
+ -0.043848954, -0.06404588, -0.06632928, -0.073643476,
+ 0.008214239, -0.045984086, 0.039764922, 0.03474462,
+ 0.060612556, -0.080590084, 0.049127717, 0.04151091,
+ -0.030063879, 0.008801774, -0.023021035, -0.019558564,
+ 0.05158114, -0.010947698, -0.011825728, 0.0075720972,
+ 0.0699727, -0.0039981045, 0.069350146, 0.08799282,
+ 0.016156472, 0.035502106, 0.11695009, 0.006217345,
+ 0.13392477, -0.037875112, 0.025745004, 0.08940699,
+ -0.00924166, 0.0046702605, -0.036598757, -0.08811812,
+ 0.10522024, -0.032441203, 0.008176899, -0.04454919,
+ 0.07058152, 0.0067963637, 0.039206743, 0.03259838,
+ 0.03725492, -0.09515802, 0.013326398, -0.052055415,
+ -0.025676316, 0.03198509, -0.015951829, -0.058556724,
+ 0.036879618, 0.043357447, 0.028362012, -0.05908629,
+ 0.0059240665, -0.04995891, -0.019187413, 0.0276265,
+ -0.01628143, 0.0025863599, 0.08800015, 0.035250366,
+ -0.022165963, -0.07328642, -0.009415526, -0.07455109,
+ 0.11690406, 0.0363299, 0.07411125, 0.042103454,
+ -0.009660886, 0.019076364, 0.018299393, -0.046004917,
+ 0.08891175, 0.0431396, -0.026327137, -0.051502608,
+ 0.08979574, -0.051670972, 0.04940282, -0.07491107,
+ -0.021240504, 0.022596184, -0.034280192, 0.060163025,
+ -0.058211457, -0.051837247, -0.01349775, -0.04639988,
+ -0.035936575, -0.011681591, 0.064818054, 0.0073146066,
+ -0.021745546, -0.043124277, -0.06471268, -0.07053354,
+ -0.029321948, -0.05330136, 0.016933719, -0.053782392,
+ 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
+ 0.05693899, -0.053219706, 0.063698, 0.07977434,
+ -0.07924483, 0.06936997, 0.0034815092, -0.007305279,
+ -0.037325785, -0.07251102, -0.033633437, -0.08677009,
+ 0.091591336, -0.14165086, 0.021752775, 0.019683983,
+ 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
+ -0.0024567875, -0.14345716, 0.010955264, -0.10234828,
+ 0.1183656, -0.0010731248, -0.023590032, -0.072285876,
+ -0.0724771, -0.026382286, -0.0014920527, 0.042667855,
+ 0.0018776858, 0.02986552, 0.009814309, 0.0733756,
+ 0.12289186, 0.018043943, -0.0458958, 0.049412545,
+ 0.033632483, 0.05495232, 0.036686596, -0.013781798,
+ -0.010036754, 0.02576849, -0.08307328, 0.010112348,
+ 0.042521734, -0.05869831, -0.071689695, 0.03876447,
+ -0.13275425, -0.0352966, -0.023077697, 0.10285965,
+ 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
+ -0.10292561, -0.032401145, 0.10053256, -0.026142767,
+ -0.08271222, -0.0030240538, -0.016368777, 0.1070414,
+ 0.042672627, 0.013456989, -0.0437609, -0.022309763,
+ 0.11576483, 0.04108048, 0.061026827, -0.0190714,
+ -0.0869359, 0.037901703, 0.0610107, 0.07202949,
+ 0.01675338, 0.086139716, -0.08795751, -0.014898893,
+ -0.023771819, -0.01965048, 0.007955471, -0.043740474,
+ 0.03346837, -0.10549954, 0.090567775, 0.042013682,
+ -0.03176985, 0.12569028, -0.02421228, -0.029526481,
+ 0.023851605, 0.031539805, 0.05292009, -0.02344001,
+ -0.07811758, -0.08834428, 0.10094801, 0.16594367,
+ -0.06861939, -0.021256343, -0.041093912, -0.06669611,
+ 0.035498552, 0.021757556, -0.09302526, -0.015403468,
+ -0.06614931, -0.051798206, -0.013874718, 0.03630673,
+ 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
+ 0.03541868, -0.094149634, -0.034814864, 0.003128424,
+ -0.020674974, -0.03944324, -0.008110165, -0.11113267,
+ 0.08484226, 0.043586485, 0.040582247, 0.0968012,
+ -0.065249965, -0.028036479, 0.0050708856, 0.0017462453,
+ 0.0326779, 0.041296225, 0.09164146, -0.047743853,
+ -0.015952192, -0.034451712, 0.084197424, -0.05347844,
+ -0.11768019, 0.085926116, -0.08251791, -0.045081906,
+ 0.0948852, 0.068401024, 0.024856757, 0.06978981,
+ -0.057309967, -0.012775832, -0.0032452994, 0.01977615,
+ -0.041040014, -0.024264973, 0.063464895, 0.05431621,
+ };
+
+ cell_to_input_weights_ = {
+ 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
+ -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
+ -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
+ 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175};
+
+ cell_to_forget_weights_ = {
+ -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
+ -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
+ -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
+ 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355};
+
+ cell_to_output_weights_ = {
+ 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
+ -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
+ -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
+ 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733};
+
+ projection_weights_ = {
+ -0.009802181, 0.09401916, 0.0717386, -0.13895074,
+ 0.09641832, 0.060420845, 0.08539281, 0.054285463,
+ 0.061395317, 0.034448683, -0.042991187, 0.019801661,
+ -0.16840284, -0.015726732, -0.23041931, -0.024478018,
+ -0.10959692, -0.013875541, 0.18600968, -0.061274476,
+ 0.0138165, -0.08160894, -0.07661644, 0.032372914,
+ 0.16169067, 0.22465782, -0.03993472, -0.004017731,
+ 0.08633481, -0.28869787, 0.08682067, 0.17240396,
+ 0.014975425, 0.056431185, 0.031037588, 0.16702051,
+ 0.0077946745, 0.15140012, 0.29405436, 0.120285,
+ -0.188994, -0.027265169, 0.043389652, -0.022061434,
+ 0.014777949, -0.20203483, 0.094781205, 0.19100232,
+ 0.13987629, -0.036132768, -0.06426278, -0.05108664,
+ 0.13221376, 0.009441198, -0.16715929, 0.15859416,
+ -0.040437475, 0.050779544, -0.022187516, 0.012166504,
+ 0.027685808, -0.07675938, -0.0055694645, -0.09444123,
+ 0.0046453946, 0.050794356, 0.10770313, -0.20790008,
+ -0.07149004, -0.11425117, 0.008225835, -0.035802525,
+ 0.14374903, 0.15262283, 0.048710253, 0.1847461,
+ -0.007487823, 0.11000021, -0.09542012, 0.22619456,
+ -0.029149994, 0.08527916, 0.009043713, 0.0042746216,
+ 0.016261552, 0.022461696, 0.12689082, -0.043589946,
+ -0.12035478, -0.08361797, -0.050666027, -0.1248618,
+ -0.1275799, -0.071875185, 0.07377272, 0.09944291,
+ -0.18897448, -0.1593054, -0.06526116, -0.040107165,
+ -0.004618631, -0.067624845, -0.007576253, 0.10727444,
+ 0.041546922, -0.20424393, 0.06907816, 0.050412357,
+ 0.00724631, 0.039827548, 0.12449835, 0.10747581,
+ 0.13708383, 0.09134148, -0.12617786, -0.06428341,
+ 0.09956831, 0.1208086, -0.14676677, -0.0727722,
+ 0.1126304, 0.010139365, 0.015571211, -0.038128063,
+ 0.022913318, -0.042050496, 0.16842307, -0.060597885,
+ 0.10531834, -0.06411776, -0.07451711, -0.03410368,
+ -0.13393489, 0.06534304, 0.003620307, 0.04490757,
+ 0.05970546, 0.05197996, 0.02839995, 0.10434969,
+ -0.013699693, -0.028353551, -0.07260381, 0.047201227,
+ -0.024575593, -0.036445823, 0.07155557, 0.009672501,
+ -0.02328883, 0.009533515, -0.03606021, -0.07421458,
+ -0.028082801, -0.2678904, -0.13221288, 0.18419984,
+ -0.13012612, -0.014588381, -0.035059117, -0.04824723,
+ 0.07830115, -0.056184657, 0.03277091, 0.025466874,
+ 0.14494097, -0.12522776, -0.098633975, -0.10766018,
+ -0.08317623, 0.08594209, 0.07749552, 0.039474737,
+ 0.1776665, -0.07409566, -0.0477268, 0.29323658,
+ 0.10801441, 0.1154011, 0.013952499, 0.10739139,
+ 0.10708251, -0.051456142, 0.0074137426, -0.10430189,
+ 0.10034707, 0.045594677, 0.0635285, -0.0715442,
+ -0.089667566, -0.10811871, 0.00026344223, 0.08298446,
+ -0.009525053, 0.006585689, -0.24567553, -0.09450807,
+ 0.09648481, 0.026996298, -0.06419476, -0.04752702,
+ -0.11063944, -0.23441927, -0.17608605, -0.052156363,
+ 0.067035615, 0.19271925, -0.0032889997, -0.043264326,
+ 0.09663576, -0.057112187, -0.10100678, 0.0628376,
+ 0.04447668, 0.017961001, -0.10094388, -0.10190601,
+ 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
+ 0.10539724, -0.04383912, -0.042349473, 0.08438151,
+ -0.1947263, 0.02251204, 0.11216432, -0.10307853,
+ 0.17351969, -0.039091777, 0.08066188, -0.00561982,
+ 0.12633002, 0.11335965, -0.0088127935, -0.019777594,
+ 0.06864014, -0.059751723, 0.016233567, -0.06894641,
+ -0.28651384, -0.004228674, 0.019708522, -0.16305895,
+ -0.07468996, -0.0855457, 0.099339016, -0.07580735,
+ -0.13775392, 0.08434318, 0.08330512, -0.12131499,
+ 0.031935584, 0.09180414, -0.08876437, -0.08049874,
+ 0.008753825, 0.03498998, 0.030215185, 0.03907079,
+ 0.089751154, 0.029194152, -0.03337423, -0.019092513,
+ 0.04331237, 0.04299654, -0.036394123, -0.12915532,
+ 0.09793732, 0.07512415, -0.11319543, -0.032502122,
+ 0.15661901, 0.07671967, -0.005491124, -0.19379048,
+ -0.218606, 0.21448623, 0.017840758, 0.1416943,
+ -0.07051762, 0.19488361, 0.02664691, -0.18104725,
+ -0.09334311, 0.15026465, -0.15493552, -0.057762887,
+ -0.11604192, -0.262013, -0.01391798, 0.012185008,
+ 0.11156489, -0.07483202, 0.06693364, -0.26151478,
+ 0.046425626, 0.036540434, -0.16435726, 0.17338543,
+ -0.21401681, -0.11385144, -0.08283257, -0.069031075,
+ 0.030635102, 0.010969227, 0.11109743, 0.010919218,
+ 0.027526086, 0.13519906, 0.01891392, -0.046839405,
+ -0.040167913, 0.017953383, -0.09700955, 0.0061885654,
+ -0.07000971, 0.026893595, -0.038844477, 0.14543656};
+
+ lstm_input_ = {
+ {// Batch0: 4 (input_sequence_size) * 5 (n_input)
+ 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0
+ 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1
+ 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2
+ 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3
+
+ {// Batch1: 4 (input_sequence_size) * 5 (n_input)
+ 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0
+ 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1
+ 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2
+ 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3
+ };
+
+ lstm_golden_output_ = {
+ {// Batch0: 4 (input_sequence_size) * 16 (n_output)
+ -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
+ -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
+ -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
+ 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
+ -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
+ -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
+ 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
+ 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
+ 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
+ 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
+ -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
+ -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
+ 0.0286833, 0.00824207, 0.0264887, 0.0305169},
+ {// Batch1: 4 (input_sequence_size) * 16 (n_output)
+ -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
+ -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
+ 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
+ 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
+ -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
+ -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
+ 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
+ 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
+ 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
+ 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
+ -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
+ -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
+ 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -461,8 +1322,9 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
const int sequence_length = 4;
UnidirectionalLSTMOpModel lstm(
- n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
- /*use_peephole=*/true, /*use_projection_weights=*/true,
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
@@ -491,588 +1353,99 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{0}, // projection_bias tensor
});
- lstm.SetInputToInputWeights(
- {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
- 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048,
- -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385,
- -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282,
- -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627,
- -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226,
- -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059,
- 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698,
- 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206,
- 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585,
- -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063,
- 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603,
- -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682,
- -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988,
- -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764,
- 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476,
- -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012,
- -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604,
- -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654,
- -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677});
-
- lstm.SetInputToForgetWeights(
- {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236,
- -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505,
- -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495,
- 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323,
- 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421,
- -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887,
- -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791,
- 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059,
- 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068,
- 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905,
- 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605,
- -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464,
- 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506,
- -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063,
- -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375,
- 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553,
- 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353,
- 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717,
- -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371,
- 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496});
-
- lstm.SetInputToCellWeights(
- {-0.04580283, -0.09549462, -0.032418985, -0.06454633,
- -0.043528453, 0.043018587, -0.049152344, -0.12418144,
- -0.078985475, -0.07596889, 0.019484362, -0.11434962,
- -0.0074034138, -0.06314844, -0.092981495, 0.0062155537,
- -0.025034338, -0.0028890965, 0.048929527, 0.06235075,
- 0.10665918, -0.032036792, -0.08505916, -0.10843358,
- -0.13002433, -0.036816437, -0.02130134, -0.016518239,
- 0.0047691227, -0.0025825808, 0.066017866, 0.029991534,
- -0.10652836, -0.1037554, -0.13056071, -0.03266643,
- -0.033702414, -0.006473424, -0.04611692, 0.014419339,
- -0.025174323, 0.0396852, 0.081777506, 0.06157468,
- 0.10210095, -0.009658194, 0.046511717, 0.03603906,
- 0.0069369148, 0.015960095, -0.06507666, 0.09551598,
- 0.053568836, 0.06408714, 0.12835667, -0.008714329,
- -0.20211966, -0.12093674, 0.029450472, 0.2849013,
- -0.029227901, 0.1164364, -0.08560263, 0.09941786,
- -0.036999565, -0.028842626, -0.0033637602, -0.017012902,
- -0.09720865, -0.11193351, -0.029155117, -0.017936034,
- -0.009768936, -0.04223324, -0.036159635, 0.06505112,
- -0.021742892, -0.023377212, -0.07221364, -0.06430552,
- 0.05453865, 0.091149814, 0.06387331, 0.007518393,
- 0.055960953, 0.069779344, 0.046411168, 0.10509911,
- 0.07463894, 0.0075130584, 0.012850982, 0.04555431,
- 0.056955688, 0.06555285, 0.050801456, -0.009862683,
- 0.00826772, -0.026555609, -0.0073611983, -0.0014897042});
-
- lstm.SetInputToOutputWeights(
- {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918,
- -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534,
- 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722,
- -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761,
- -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394,
- 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154,
- -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135,
- -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564,
- -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047,
- -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304,
- 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946,
- 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646,
- 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813,
- -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403,
- 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415,
- 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495,
- -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158,
- 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295,
- -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739,
- -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956});
-
- lstm.SetInputGateBias(
- {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216,
- -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339,
- -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818,
- 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196});
-
- lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696,
- 0.11098921, 0.15378423, 0.09263801, 0.09790885,
- 0.09508917, 0.061199076, 0.07665568, -0.015443159,
- -0.03499149, 0.046190713, 0.08895977, 0.10899629,
- 0.40694186, 0.06030037, 0.012413437, -0.06108739});
-
- lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873,
- -0.1483596, -0.10639995, -0.091433935, 0.058573797,
- -0.06809782, -0.07889636, -0.043246906, -0.09829136,
- -0.4279842, 0.034901652, 0.18797937, 0.0075234566,
- 0.016178843, 0.1749513, 0.13975595, 0.92058027});
-
- lstm.SetOutputGateBias(
- {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795,
- 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895,
- 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149,
- -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877});
-
- lstm.SetRecurrentToInputWeights(
- {-0.001374326, -0.078856036, 0.10672688, 0.029162422,
- -0.11585556, 0.02557986, -0.13446963, -0.035785314,
- -0.01244275, 0.025961924, -0.02337298, -0.044228926,
- -0.055839065, -0.046598054, -0.010546039, -0.06900766,
- 0.027239809, 0.022582639, -0.013296484, -0.05459212,
- 0.08981, -0.045407712, 0.08682226, -0.06867011,
- -0.14390695, -0.02916037, 0.000996957, 0.091420636,
- 0.14283475, -0.07390571, -0.06402044, 0.062524505,
- -0.093129106, 0.04860203, -0.08364217, -0.08119002,
- 0.009352075, 0.22920375, 0.0016303885, 0.11583097,
- -0.13732095, 0.012405723, -0.07551853, 0.06343048,
- 0.12162708, -0.031923793, -0.014335606, 0.01790974,
- -0.10650317, -0.0724401, 0.08554849, -0.05727212,
- 0.06556731, -0.042729504, -0.043227166, 0.011683251,
- -0.013082158, -0.029302018, -0.010899579, -0.062036745,
- -0.022509435, -0.00964907, -0.01567329, 0.04260106,
- -0.07787477, -0.11576462, 0.017356863, 0.048673786,
- -0.017577527, -0.05527947, -0.082487635, -0.040137455,
- -0.10820036, -0.04666372, 0.022746278, -0.07851417,
- 0.01068115, 0.032956902, 0.022433773, 0.0026891115,
- 0.08944216, -0.0685835, 0.010513544, 0.07228705,
- 0.02032331, -0.059686817, -0.0005566496, -0.086984694,
- 0.040414046, -0.1380399, 0.094208956, -0.05722982,
- 0.012092817, -0.04989123, -0.086576, -0.003399834,
- -0.04696032, -0.045747425, 0.10091314, 0.048676282,
- -0.029037097, 0.031399418, -0.0040285117, 0.047237843,
- 0.09504992, 0.041799378, -0.049185462, -0.031518843,
- -0.10516937, 0.026374253, 0.10058866, -0.0033195973,
- -0.041975245, 0.0073591834, 0.0033782164, -0.004325073,
- -0.10167381, 0.042500053, -0.01447153, 0.06464186,
- -0.017142897, 0.03312627, 0.009205989, 0.024138335,
- -0.011337001, 0.035530265, -0.010912711, 0.0706555,
- -0.005894094, 0.051841937, -0.1401738, -0.02351249,
- 0.0365468, 0.07590991, 0.08838724, 0.021681072,
- -0.10086113, 0.019608743, -0.06195883, 0.077335775,
- 0.023646897, -0.095322326, 0.02233014, 0.09756986,
- -0.048691444, -0.009579111, 0.07595467, 0.11480546,
- -0.09801813, 0.019894179, 0.08502348, 0.004032281,
- 0.037211012, 0.068537936, -0.048005626, -0.091520436,
- -0.028379958, -0.01556313, 0.06554592, -0.045599163,
- -0.01672207, -0.020169014, -0.011877351, -0.20212261,
- 0.010889619, 0.0047078193, 0.038385306, 0.08540671,
- -0.017140968, -0.0035865551, 0.016678626, 0.005633034,
- 0.015963363, 0.00871737, 0.060130805, 0.028611384,
- 0.10109069, -0.015060172, -0.07894427, 0.06401885,
- 0.011584063, -0.024466386, 0.0047652307, -0.09041358,
- 0.030737216, -0.0046374933, 0.14215417, -0.11823516,
- 0.019899689, 0.006106124, -0.027092824, 0.0786356,
- 0.05052217, -0.058925, -0.011402121, -0.024987547,
- -0.0013661642, -0.06832946, -0.015667673, -0.1083353,
- -0.00096863037, -0.06988685, -0.053350925, -0.027275559,
- -0.033664223, -0.07978348, -0.025200296, -0.017207067,
- -0.058403496, -0.055697463, 0.005798788, 0.12965427,
- -0.062582195, 0.0013350133, -0.10482091, 0.0379771,
- 0.072521195, -0.0029455067, -0.13797039, -0.03628521,
- 0.013806405, -0.017858358, -0.01008298, -0.07700066,
- -0.017081132, 0.019358726, 0.0027079724, 0.004635139,
- 0.062634714, -0.02338735, -0.039547626, -0.02050681,
- 0.03385117, -0.083611414, 0.002862572, -0.09421313,
- 0.058618143, -0.08598433, 0.00972939, 0.023867095,
- -0.053934585, -0.023203006, 0.07452513, -0.048767887,
- -0.07314807, -0.056307215, -0.10433547, -0.06440842,
- 0.04328182, 0.04389765, -0.020006588, -0.09076438,
- -0.11652589, -0.021705797, 0.03345259, -0.010329105,
- -0.025767034, 0.013057034, -0.07316461, -0.10145612,
- 0.06358255, 0.18531723, 0.07759293, 0.12006465,
- 0.1305557, 0.058638252, -0.03393652, 0.09622831,
- -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845,
- -0.005644518, 0.06857898, -0.12598175, -0.035084512,
- 0.03156317, -0.12794146, -0.031963028, 0.04692781,
- 0.030070418, 0.0071660685, -0.095516115, -0.004643372,
- 0.040170413, -0.062104587, -0.0037324072, 0.0554317,
- 0.08184801, -0.019164372, 0.06791302, 0.034257166,
- -0.10307039, 0.021943003, 0.046745934, 0.0790918,
- -0.0265588, -0.007824208, 0.042546265, -0.00977924,
- -0.0002440307, -0.017384544, -0.017990116, 0.12252321,
- -0.014512694, -0.08251313, 0.08861942, 0.13589665,
- 0.026351685, 0.012641483, 0.07466548, 0.044301085,
- -0.045414884, -0.051112458, 0.03444247, -0.08502782,
- -0.04106223, -0.028126027, 0.028473156, 0.10467447});
-
- lstm.SetRecurrentToForgetWeights(
- {-0.057784554, -0.026057621, -0.068447545, -0.022581743,
- 0.14811787, 0.10826372, 0.09471067, 0.03987225,
- -0.0039523416, 0.00030638507, 0.053185795, 0.10572994,
- 0.08414449, -0.022036452, -0.00066928595, -0.09203576,
- 0.032950465, -0.10985798, -0.023809856, 0.0021431844,
- -0.02196096, -0.00326074, 0.00058621005, -0.074678116,
- -0.06193199, 0.055729095, 0.03736828, 0.020123724,
- 0.061878487, -0.04729229, 0.034919553, -0.07585433,
- -0.04421272, -0.044019096, 0.085488975, 0.04058006,
- -0.06890133, -0.030951202, -0.024628663, -0.07672815,
- 0.034293607, 0.08556707, -0.05293577, -0.033561368,
- -0.04899627, 0.0241671, 0.015736353, -0.095442444,
- -0.029564252, 0.016493602, -0.035026584, 0.022337519,
- -0.026871363, 0.004780428, 0.0077918363, -0.03601621,
- 0.016435321, -0.03263031, -0.09543275, -0.047392778,
- 0.013454138, 0.028934088, 0.01685226, -0.086110644,
- -0.046250615, -0.01847454, 0.047608484, 0.07339695,
- 0.034546845, -0.04881143, 0.009128804, -0.08802852,
- 0.03761666, 0.008096139, -0.014454086, 0.014361001,
- -0.023502491, -0.0011840804, -0.07607001, 0.001856849,
- -0.06509276, -0.006021153, -0.08570962, -0.1451793,
- 0.060212336, 0.055259194, 0.06974018, 0.049454916,
- -0.027794661, -0.08077226, -0.016179763, 0.1169753,
- 0.17213494, -0.0056326236, -0.053934924, -0.0124349,
- -0.11520337, 0.05409887, 0.088759385, 0.0019655675,
- 0.0042065294, 0.03881498, 0.019844765, 0.041858196,
- -0.05695512, 0.047233116, 0.038937137, -0.06542224,
- 0.014429736, -0.09719407, 0.13908425, -0.05379757,
- 0.012321099, 0.082840554, -0.029899208, 0.044217527,
- 0.059855383, 0.07711018, -0.045319796, 0.0948846,
- -0.011724666, -0.0033288454, -0.033542685, -0.04764985,
- -0.13873616, 0.040668588, 0.034832682, -0.015319203,
- -0.018715994, 0.046002675, 0.0599172, -0.043107376,
- 0.0294216, -0.002314414, -0.022424703, 0.0030315618,
- 0.0014641669, 0.0029166266, -0.11878115, 0.013738511,
- 0.12375372, -0.0006038222, 0.029104086, 0.087442465,
- 0.052958444, 0.07558703, 0.04817258, 0.044462286,
- -0.015213451, -0.08783778, -0.0561384, -0.003008196,
- 0.047060397, -0.002058388, 0.03429439, -0.018839769,
- 0.024734668, 0.024614193, -0.042046934, 0.09597743,
- -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786,
- -0.02558259, -0.022822596, -0.023273505, -0.02464396,
- -0.10991725, -0.006240552, 0.0074488563, 0.024044557,
- 0.04383914, -0.046476185, 0.028658995, 0.060410924,
- 0.050786525, 0.009452605, -0.0073054377, -0.024810238,
- 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517,
- 0.015898481, 0.021362653, -0.030262267, 0.016587038,
- -0.011442813, 0.041154444, -0.007631438, -0.03423484,
- -0.010977775, 0.036152758, 0.0066366293, 0.11915515,
- 0.02318443, -0.041350313, 0.021485701, -0.10906167,
- -0.028218046, -0.00954771, 0.020531068, -0.11995105,
- -0.03672871, 0.024019798, 0.014255957, -0.05221243,
- -0.00661567, -0.04630967, 0.033188973, 0.10107534,
- -0.014027541, 0.030796422, -0.10270911, -0.035999842,
- 0.15443139, 0.07684145, 0.036571592, -0.035900835,
- -0.0034699554, 0.06209149, 0.015920248, -0.031122351,
- -0.03858649, 0.01849943, 0.13872518, 0.01503974,
- 0.069941424, -0.06948533, -0.0088794185, 0.061282158,
- -0.047401894, 0.03100163, -0.041533746, -0.10430945,
- 0.044574402, -0.01425562, -0.024290353, 0.034563623,
- 0.05866852, 0.023947537, -0.09445152, 0.035450947,
- 0.02247216, -0.0042998926, 0.061146557, -0.10250651,
- 0.020881841, -0.06747029, 0.10062043, -0.0023941975,
- 0.03532124, -0.016341697, 0.09685456, -0.016764693,
- 0.051808182, 0.05875331, -0.04536488, 0.001626336,
- -0.028892258, -0.01048663, -0.009793449, -0.017093895,
- 0.010987891, 0.02357273, -0.00010856845, 0.0099760275,
- -0.001845119, -0.03551521, 0.0018358806, 0.05763657,
- -0.01769146, 0.040995963, 0.02235177, -0.060430344,
- 0.11475477, -0.023854522, 0.10071741, 0.0686208,
- -0.014250481, 0.034261297, 0.047418304, 0.08562733,
- -0.030519066, 0.0060542435, 0.014653856, -0.038836084,
- 0.04096551, 0.032249358, -0.08355519, -0.026823482,
- 0.056386515, -0.010401743, -0.028396193, 0.08507674,
- 0.014410365, 0.020995233, 0.17040324, 0.11511526,
- 0.02459721, 0.0066619175, 0.025853224, -0.023133837,
- -0.081302024, 0.017264642, -0.009585969, 0.09491168,
- -0.051313367, 0.054532815, -0.014298593, 0.10657464,
- 0.007076659, 0.10964551, 0.0409152, 0.008275321,
- -0.07283536, 0.07937492, 0.04192024, -0.1075027});
-
- lstm.SetRecurrentToCellWeights(
- {-0.037322544, 0.018592842, 0.0056175636, -0.06253426,
- 0.055647098, -0.05713207, -0.05626563, 0.005559383,
- 0.03375411, -0.025757805, -0.088049285, 0.06017052,
- -0.06570978, 0.007384076, 0.035123326, -0.07920549,
- 0.053676967, 0.044480428, -0.07663568, 0.0071805613,
- 0.08089997, 0.05143358, 0.038261272, 0.03339287,
- -0.027673481, 0.044746667, 0.028349208, 0.020090483,
- -0.019443132, -0.030755889, -0.0040000007, 0.04465846,
- -0.021585021, 0.0031670958, 0.0053199246, -0.056117613,
- -0.10893326, 0.076739706, -0.08509834, -0.027997585,
- 0.037871376, 0.01449768, -0.09002357, -0.06111149,
- -0.046195522, 0.0422062, -0.005683705, -0.1253618,
- -0.012925729, -0.04890792, 0.06985068, 0.037654128,
- 0.03398274, -0.004781977, 0.007032333, -0.031787455,
- 0.010868644, -0.031489216, 0.09525667, 0.013939797,
- 0.0058680447, 0.0167067, 0.02668468, -0.04797466,
- -0.048885044, -0.12722108, 0.035304096, 0.06554885,
- 0.00972396, -0.039238118, -0.05159735, -0.11329045,
- 0.1613692, -0.03750952, 0.06529313, -0.071974665,
- -0.11769596, 0.015524369, -0.0013754242, -0.12446318,
- 0.02786344, -0.014179351, 0.005264273, 0.14376344,
- 0.015983658, 0.03406988, -0.06939408, 0.040699873,
- 0.02111075, 0.09669095, 0.041345075, -0.08316494,
- -0.07684199, -0.045768797, 0.032298047, -0.041805092,
- 0.0119405, 0.0061010392, 0.12652606, 0.0064572375,
- -0.024950314, 0.11574242, 0.04508852, -0.04335324,
- 0.06760663, -0.027437469, 0.07216407, 0.06977076,
- -0.05438599, 0.034033038, -0.028602652, 0.05346137,
- 0.043184172, -0.037189785, 0.10420091, 0.00882477,
- -0.054019816, -0.074273005, -0.030617684, -0.0028467078,
- 0.024302477, -0.0038869337, 0.005332455, 0.0013399826,
- 0.04361412, -0.007001822, 0.09631092, -0.06702025,
- -0.042049985, -0.035070654, -0.04103342, -0.10273396,
- 0.0544271, 0.037184782, -0.13150354, -0.0058036847,
- -0.008264958, 0.042035464, 0.05891794, 0.029673764,
- 0.0063542654, 0.044788733, 0.054816857, 0.062257513,
- -0.00093483756, 0.048938446, -0.004952862, -0.007730018,
- -0.04043371, -0.017094059, 0.07229206, -0.023670016,
- -0.052195564, -0.025616996, -0.01520939, 0.045104615,
- -0.007376126, 0.003533447, 0.006570588, 0.056037236,
- 0.12436656, 0.051817212, 0.028532185, -0.08686856,
- 0.11868599, 0.07663395, -0.07323171, 0.03463402,
- -0.050708205, -0.04458982, -0.11590894, 0.021273347,
- 0.1251325, -0.15313013, -0.12224372, 0.17228661,
- 0.023029093, 0.086124025, 0.006445803, -0.03496501,
- 0.028332196, 0.04449512, -0.042436164, -0.026587414,
- -0.006041347, -0.09292539, -0.05678812, 0.03897832,
- 0.09465633, 0.008115513, -0.02171956, 0.08304309,
- 0.071401566, 0.019622514, 0.032163795, -0.004167056,
- 0.02295182, 0.030739572, 0.056506045, 0.004612461,
- 0.06524936, 0.059999723, 0.046395954, -0.0045512207,
- -0.1335546, -0.030136576, 0.11584653, -0.014678886,
- 0.0020118146, -0.09688814, -0.0790206, 0.039770417,
- -0.0329582, 0.07922767, 0.029322514, 0.026405897,
- 0.04207835, -0.07073373, 0.063781224, 0.0859677,
- -0.10925287, -0.07011058, 0.048005477, 0.03438226,
- -0.09606514, -0.006669445, -0.043381985, 0.04240257,
- -0.06955775, -0.06769346, 0.043903265, -0.026784198,
- -0.017840602, 0.024307009, -0.040079936, -0.019946516,
- 0.045318738, -0.12233574, 0.026170589, 0.0074471775,
- 0.15978073, 0.10185836, 0.10298046, -0.015476589,
- -0.039390966, -0.072174534, 0.0739445, -0.1211869,
- -0.0347889, -0.07943156, 0.014809798, -0.12412325,
- -0.0030663363, 0.039695457, 0.0647603, -0.08291318,
- -0.018529687, -0.004423833, 0.0037507233, 0.084633216,
- -0.01514876, -0.056505352, -0.012800942, -0.06994386,
- 0.012962922, -0.031234352, 0.07029052, 0.016418684,
- 0.03618972, 0.055686004, -0.08663945, -0.017404709,
- -0.054761406, 0.029065743, 0.052404847, 0.020238016,
- 0.0048197987, -0.0214882, 0.07078733, 0.013016777,
- 0.06262858, 0.009184685, 0.020785125, -0.043904778,
- -0.0270329, -0.03299152, -0.060088247, -0.015162964,
- -0.001828936, 0.12642565, -0.056757294, 0.013586685,
- 0.09232601, -0.035886683, 0.06000002, 0.05229691,
- -0.052580316, -0.082029596, -0.010794592, 0.012947712,
- -0.036429964, -0.085508935, -0.13127148, -0.017744139,
- 0.031502828, 0.036232427, -0.031581745, 0.023051167,
- -0.05325106, -0.03421577, 0.028793324, -0.034633752,
- -0.009881397, -0.043551125, -0.018609839, 0.0019097115,
- -0.008799762, 0.056595087, 0.0022273948, 0.055752404});
-
- lstm.SetRecurrentToOutputWeights({
- 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415,
- -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349,
- -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948,
- -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774,
- -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125,
- -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224,
- -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088,
- 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867,
- -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728,
- 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607,
- -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928,
- -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462,
- 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879,
- 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698,
- -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146,
- 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345,
- 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166,
- 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203,
- 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743,
- 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415,
- -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618,
- 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891,
- -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015,
- 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109,
- 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886,
- 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396,
- -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282,
- -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025,
- -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575,
- -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277,
- -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719,
- -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215,
- 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483,
- 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102,
- -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775,
- 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841,
- -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656,
- -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286,
- -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309,
- 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545,
- 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754,
- 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831,
- -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697,
- 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453,
- -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222,
- -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989,
- -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827,
- -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949,
- 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819,
- -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954,
- 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228,
- -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001,
- -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939,
- -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556,
- -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718,
- 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893,
- 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974,
- -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485,
- 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856,
- 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853,
- -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019,
- 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024,
- 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994,
- 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621,
- });
-
- lstm.SetCellToInputWeights(
- {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458,
- -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174,
- -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047,
- 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175});
-
- lstm.SetCellToForgetWeights(
- {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276,
- -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766,
- -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774,
- 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355});
-
- lstm.SetCellToOutputWeights(
- {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764,
- -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544,
- -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817,
- 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733});
-
- lstm.SetProjectionWeights(
- {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832,
- 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683,
- -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931,
- -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476,
- 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067,
- 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787,
- 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588,
- 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285,
- -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949,
- -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768,
- -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929,
- 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504,
- 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946,
- 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117,
- 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253,
- 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456,
- -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552,
- 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797,
- -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272,
- 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165,
- -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922,
- -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548,
- 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786,
- -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722,
- 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318,
- -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776,
- -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307,
- 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969,
- -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593,
- -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515,
- -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288,
- 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723,
- 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097,
- -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209,
- 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268,
- 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139,
- 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707,
- 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871,
- 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553,
- -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702,
- -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615,
- 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187,
- -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388,
- -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709,
- 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263,
- 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777,
- 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935,
- -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641,
- -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996,
- -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318,
- 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437,
- -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079,
- 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237,
- 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415,
- -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124,
- -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943,
- -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311,
- 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013,
- -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364,
- -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543,
- -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102,
- 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906,
- 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955,
- 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656});
-
- static float lstm_input[][20] = {
- {// Batch0: 4 (input_sequence_size) * 5 (n_input)
- 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386,
- 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199,
- 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339},
-
- {// Batch1: 4 (input_sequence_size) * 5 (n_input)
- 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260,
- 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485,
- 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}};
-
- static float lstm_golden_output[][64] = {
- {// Batch0: 4 (input_sequence_size) * 16 (n_output)
- -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576,
- -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004,
- -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147,
- 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363,
- -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322,
- -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308,
- 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794,
- 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474,
- 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827,
- 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512,
- -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407,
- -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193,
- 0.0286833, 0.00824207, 0.0264887, 0.0305169},
- {// Batch1: 4 (input_sequence_size) * 16 (n_output)
- -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926,
- -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232,
- 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954,
- 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507,
- -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039,
- -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233,
- 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378,
- 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034,
- 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789,
- 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855,
- -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679,
- -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181,
- 0.0412031, 0.0118723, 0.0239643, 0.0394009}};
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
- float* batch0_end = batch0_start + lstm.num_inputs();
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
+}
- lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end);
+TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 20;
+ const int n_output = 16;
+ const int sequence_length = 4;
- float* batch1_start = lstm_input[1] + i * lstm.num_inputs();
- float* batch1_end = batch1_start + lstm.num_inputs();
- lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end);
- }
+ HybridUnidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false,
+ /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
- lstm.Invoke();
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
- std::vector<float> expected;
- for (int i = 0; i < lstm.sequence_length(); i++) {
- float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs();
- float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs();
- float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs();
- float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs();
- expected.insert(expected.end(), golden_start_batch0, golden_end_batch0);
- expected.insert(expected.end(), golden_start_batch1, golden_end_batch1);
- }
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ lstm.SetInputToInputWeights(input_to_input_weights_);
+ lstm.SetInputToCellWeights(input_to_cell_weights_);
+ lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ lstm.SetInputGateBias(input_gate_bias_);
+ lstm.SetCellBias(cell_gate_bias_);
+ lstm.SetForgetGateBias(forget_gate_bias_);
+ lstm.SetOutputGateBias(output_gate_bias_);
+
+ lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ lstm.SetCellToInputWeights(cell_to_input_weights_);
+ lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ lstm.SetProjectionWeights(projection_weights_);
+
+ // Resetting cell_state and output_state
+ lstm.ResetCellState();
+ lstm.ResetOutputState();
+
+ VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
} // namespace
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 6ac41a94bd..93b3df98f3 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -45,6 +45,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_FLOAT32:
*type = kTfLiteFloat32;
break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
case TensorType_INT32:
*type = kTfLiteInt32;
break;
@@ -60,6 +63,9 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
case TensorType_BOOL:
*type = kTfLiteBool;
break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
default:
error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
EnumNameTensorType(tensor_type), tensor_type);
@@ -180,6 +186,8 @@ InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model,
op_resolver_(op_resolver),
error_reporter_(ValidateErrorReporter(error_reporter)) {}
+InterpreterBuilder::~InterpreterBuilder() {}
+
TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
TfLiteStatus status = kTfLiteOk;
auto opcodes = model_->operator_codes();
@@ -198,8 +206,9 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
} else if (builtin_code != BuiltinOperator_CUSTOM) {
registration = op_resolver_.FindOp(builtin_code, version);
if (registration == nullptr) {
- error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
- EnumNameBuiltinOperator(builtin_code));
+ error_reporter_->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
status = kTfLiteError;
}
} else if (!opcode->custom_code()) {
@@ -322,12 +331,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = nullptr;
switch (op_type) {
- case BuiltinOperator_CALL:
- // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
- // ok for now, since there is no call implementation either.
- break;
- case BuiltinOperator_CUSTOM:
- break;
case BuiltinOperator_CONV_2D: {
TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
@@ -343,21 +346,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_TANH:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU_N1_TO_1:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_CONCAT_EMBEDDINGS:
- case BuiltinOperator_EXP:
- case BuiltinOperator_TOPK_V2:
- case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_DEQUANTIZE:
- case BuiltinOperator_PRELU:
- case BuiltinOperator_FLOOR:
- case BuiltinOperator_NEG:
- case BuiltinOperator_SIN:
- break;
case BuiltinOperator_CAST: {
TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
@@ -445,9 +433,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_EMBEDDING_LOOKUP:
- // no-op.
- break;
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
TfLiteEmbeddingLookupSparseParams* params =
MallocPOD<TfLiteEmbeddingLookupSparseParams>();
@@ -465,6 +450,18 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
op->builtin_options_as_FullyConnectedOptions()) {
params->activation = parse_activation(
fully_connected_params->fused_activation_function());
+ switch (fully_connected_params->weights_format()) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ params->weights_format =
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+ break;
+ default:
+ error_reporter->Report("Unhandled fully-connected weights format.");
+ return kTfLiteError;
+ }
}
*builtin_data = reinterpret_cast<void*>(params);
break;
@@ -558,6 +555,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
parse_activation(lstm_params->fused_activation_function());
params->cell_clip = lstm_params->cell_clip();
params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
}
*builtin_data = reinterpret_cast<void*>(params);
break;
@@ -571,12 +576,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_PAD: {
- break;
- }
- case BuiltinOperator_PADV2: {
- break;
- }
case BuiltinOperator_RESHAPE: {
auto* params = MallocPOD<TfLiteReshapeParams>();
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
@@ -616,18 +615,10 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_SPACE_TO_BATCH_ND: {
- break;
- }
- case BuiltinOperator_BATCH_TO_SPACE_ND: {
- break;
- }
- case BuiltinOperator_TRANSPOSE: {
- break;
- }
- case BuiltinOperator_MEAN: {
- auto* params = MallocPOD<TfLiteMeanParams>();
- if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_SUM: {
+ auto* params = MallocPOD<TfLiteReducerParams>();
+ if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
params->keep_dims = schema_params->keep_dims();
}
*builtin_data = reinterpret_cast<void*>(params);
@@ -664,10 +655,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_MINIMUM: {
- break;
- }
case BuiltinOperator_ARG_MAX: {
auto* params = MallocPOD<TfLiteArgMaxParams>();
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
@@ -677,14 +664,13 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_GREATER:
- case BuiltinOperator_GREATER_EQUAL:
- case BuiltinOperator_LESS:
- case BuiltinOperator_LESS_EQUAL:
- case BuiltinOperator_SELECT: {
- break;
- }
- case BuiltinOperator_SLICE: {
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_TRANSPOSE_CONV: {
@@ -709,11 +695,73 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_SHAPE: {
+ auto* params = MallocPOD<TfLiteShapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+ ConvertTensorType(schema_params->out_type(), &params->out_type,
+ error_reporter);
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_DELEGATE: {
// TODO(ycling): Revisit when supporting saving delegated models.
error_reporter->Report("DELEGATE op shouldn't exist in model.");
return kTfLiteError;
}
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+
+ // Below are the ops with no builtin_data strcture.
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_EXP:
+ case BuiltinOperator_EXPAND_DIMS:
+ case BuiltinOperator_FLOOR:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_LOG:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_NEG:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_PRELU:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RSQRT:
+ case BuiltinOperator_SELECT:
+ case BuiltinOperator_SIN:
+ case BuiltinOperator_SLICE:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SQRT:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_TILE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
+ break;
}
return kTfLiteOk;
}
@@ -735,7 +783,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
}
const TfLiteRegistration* registration =
- flatbuffer_op_index_to_registration_[op->opcode_index()];
+ flatbuffer_op_index_to_registration_[index];
if (registration == nullptr) {
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
status = kTfLiteError;
@@ -854,7 +902,16 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
const char* buffer_ptr;
TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));
+ bool is_variable = tensor->is_variable();
if (buffer_ptr) {
+ if (is_variable) {
+ error_reporter_->Report(
+ "Tensor %d is a variable tensor with buffer. "
+ "It's not supported now.\n",
+ i);
+ status = kTfLiteError;
+ }
+
if (interpreter->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
@@ -863,8 +920,9 @@ TfLiteStatus InterpreterBuilder::ParseTensors(
status = kTfLiteError;
}
} else {
- if (interpreter->SetTensorParametersReadWrite(
- i, type, get_name(tensor), dims, quantization) != kTfLiteOk) {
+ if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
+ dims, quantization,
+ is_variable) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
@@ -948,6 +1006,15 @@ TfLiteStatus InterpreterBuilder::operator()(
if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
+ std::vector<int> variables;
+ for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
+ auto* tensor = (*interpreter)->tensor(i);
+ if (tensor->is_variable) {
+ variables.push_back(i);
+ }
+ }
+ (**interpreter).SetVariables(std::move(variables));
+
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 3946b49041..8bc9ecd7ce 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -156,6 +156,7 @@ class InterpreterBuilder {
InterpreterBuilder(const ::tflite::Model* model,
const OpResolver& op_resolver,
ErrorReporter* error_reporter = DefaultErrorReporter());
+ ~InterpreterBuilder();
InterpreterBuilder(const InterpreterBuilder&) = delete;
InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
index f8767b443a..f18a2ca07a 100644
--- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -1,3 +1,5 @@
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
index e6c8d966f1..c7e08814fd 100644
--- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
+++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc
@@ -35,8 +35,8 @@ const char kModelName[] = "smartreply_ondevice_model.bin";
const char kSamples[] = "smartreply_samples.tsv";
string TestDataPath() {
- return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
- "contrib/lite/models/testdata/"));
+ return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/",
+ "contrib/lite/models/testdata/"));
}
MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") {
@@ -55,7 +55,7 @@ class PredictorTest : public ::testing::Test {
protected:
PredictorTest() {
model_ = tflite::FlatBufferModel::BuildFromFile(
- StrCat(TestDataPath(), "/", kModelName).c_str());
+ absl::StrCat(TestDataPath(), "/", kModelName).c_str());
CHECK(model_);
}
~PredictorTest() override {}
@@ -121,7 +121,7 @@ TEST_F(PredictorTest, BatchTest) {
int total_triggers = 0;
string line;
- std::ifstream fin(StrCat(TestDataPath(), "/", kSamples));
+ std::ifstream fin(absl::StrCat(TestDataPath(), "/", kSamples));
while (std::getline(fin, line)) {
const std::vector<string> fields = absl::StrSplit(line, '\t');
if (fields.empty()) {
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index fad08bbfe6..cc668485a4 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -29,27 +29,46 @@ limitations under the License.
namespace tflite {
-// TODO(aselle): FATAL leaves resources hanging.
-void FATAL(const char* format, ...) {
+void logError(const char* format, ...) {
+ // TODO(mikie): use android logging, stderr is not captured for Java
+ // applications
va_list args;
va_start(args, format);
vfprintf(stderr, format, args);
va_end(args);
+ fprintf(stderr, "\n");
fflush(stderr);
- exit(1);
}
+#define FATAL(...) \
+ logError(__VA_ARGS__); \
+ exit(1);
+
// TODO(aselle): Change the error model to use status codes.
-#define CHECK_TFLITE_SUCCESS(x) \
- if (x != kTfLiteOk) { \
- FATAL("Aborting since tflite returned failure."); \
+#define CHECK_TFLITE_SUCCESS(x) \
+ if (x != kTfLiteOk) { \
+ FATAL("Aborting since tflite returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ }
+
+#define CHECK_NN(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ FATAL("Aborting since NNAPI returned failure nnapi_delegate.cc:%d", \
+ __LINE__); \
}
-#define CHECK_NN(x) \
- if (x != ANEURALNETWORKS_NO_ERROR) { \
- FATAL("Aborting since tflite returned failure."); \
+#define RETURN_ERROR_IF_NN_FAILED(x) \
+ if (x != ANEURALNETWORKS_NO_ERROR) { \
+ logError( \
+ "Returning error since NNAPI returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ return kTfLiteError; \
}
+// Tracking of NNAPI operand ids
+static const int64_t kOperandIdNotSet = -1;
+static const int64_t kOperandNotNeeded = -2;
+
namespace {
int32_t GetAndroidSdkVersion() {
@@ -104,21 +123,16 @@ NNAPIDelegate::~NNAPIDelegate() {
}
// Adds the tensors of the interpreter to the NN API model.
-// Returns the number of operands added.
-uint32_t addTensorOperands(tflite::Interpreter* interpreter,
- ANeuralNetworksModel* nn_model,
- const std::vector<uint32_t>& skip_list) {
+TfLiteStatus addTensorOperands(tflite::Interpreter* interpreter,
+ ANeuralNetworksModel* nn_model,
+ uint32_t* no_of_operands_added,
+ std::vector<int64_t>* nnapi_ids) {
uint32_t next_id = 0;
for (size_t i = 0; i < interpreter->tensors_size(); i++) {
- // skip temporaries tensors.
- bool shouldSkip = false;
- for (auto skip_idx : skip_list) {
- if (i == skip_idx) {
- shouldSkip = true;
- break;
- }
- }
- if (shouldSkip) continue;
+ // Skip temporaries and RNN back-edges.
+ if ((*nnapi_ids)[i] == kOperandNotNeeded) continue;
+
+ (*nnapi_ids)[i] = int64_t(next_id);
int32_t nn_type = 0;
// NNAPI requires 32-bit float scale to be zero, tflite doesn't care
@@ -144,7 +158,18 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
zeroPoint = tensor->params.zero_point;
break;
default:
- FATAL("Unsupported type.");
+ logError("Unsupported tensor type %d", tensor->type);
+ return kTfLiteError;
+ }
+ if (tensor->dims->size == 0) {
+ logError("NNAPI doesn't support tensors with rank 0 (index %d name %s)",
+ i, tensor->name);
+ return kTfLiteError;
+ }
+ if (tensor->dims->size > 4) {
+ logError("NNAPI doesn't support tensors with rank > 4 (index %d name %s)",
+ i, tensor->name);
+ return kTfLiteError;
}
// TODO(aselle): Note, many of these are intermediate results. Do I need
// to ever specify these sizes. I am currently below doing setValue
@@ -154,36 +179,53 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter,
ANeuralNetworksOperandType operand_type{
nn_type, static_cast<uint32_t>(tensor->dims->size),
reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
- CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type));
+ RETURN_ERROR_IF_NN_FAILED(
+ ANeuralNetworksModel_addOperand(nn_model, &operand_type));
// TODO(aselle): Based on Michael's suggestion, limiting this to read
// only memory
if (tensor->allocation_type == kTfLiteMmapRo) {
if (const NNAPIAllocation* alloc = dynamic_cast<const NNAPIAllocation*>(
static_cast<const Allocation*>(tensor->allocation))) {
- CHECK_NN(ANeuralNetworksModel_setOperandValueFromMemory(
- nn_model, next_id, alloc->memory(), alloc->offset(tensor->data.raw),
- tensor->bytes));
+ RETURN_ERROR_IF_NN_FAILED(
+ ANeuralNetworksModel_setOperandValueFromMemory(
+ nn_model, next_id, alloc->memory(),
+ alloc->offset(tensor->data.raw), tensor->bytes));
} else {
- CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_setOperandValue(
nn_model, next_id, tensor->data.raw, tensor->bytes));
}
} else if (tensor->bytes == 0) {
// These size 0 tensors are optional tensors reserved.
- CHECK_NN(
+ RETURN_ERROR_IF_NN_FAILED(
ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0));
}
++next_id;
}
- return next_id;
+ *no_of_operands_added = next_id;
+ return kTfLiteOk;
+}
+
+void MapAndAddTensorIds(const int* from_ids_buf, size_t from_ids_count,
+ std::vector<uint32_t>* into,
+ const std::vector<int64_t>& map) {
+ for (size_t i = 0; i < from_ids_count; i++) {
+ int from_id = from_ids_buf[i];
+ if (from_id == kOptionalTensor) {
+ into->push_back(from_id);
+ } else {
+ into->push_back(map[from_id]);
+ }
+ }
}
// Adds the operations and their parameters to the NN API model.
// 'next-id' is the operand ID of the next operand of the model.
-void AddOpsAndParams(tflite::Interpreter* interpreter,
- ANeuralNetworksModel* nn_model, uint32_t next_id,
- std::vector<int>* model_state_inputs,
- std::vector<int>* model_state_outputs) {
+TfLiteStatus AddOpsAndParams(
+ tflite::Interpreter* interpreter, ANeuralNetworksModel* nn_model,
+ uint32_t next_id, std::vector<int>* model_state_inputs,
+ std::vector<int>* model_state_outputs,
+ const std::vector<int64_t>& tensor_id_to_nnapi_id) {
for (size_t i = 0; i < interpreter->nodes_size(); i++) {
const auto* node_and_registration = interpreter->node_and_registration(i);
const TfLiteNode& node = node_and_registration->first;
@@ -192,10 +234,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
static_cast<tflite::BuiltinOperator>(registration.builtin_code);
// Add the parameters.
- std::vector<uint32_t> augmented_inputs(
- node.inputs->data, node.inputs->data + node.inputs->size);
- std::vector<uint32_t> augmented_outputs(
- node.outputs->data, node.outputs->data + node.outputs->size);
+ std::vector<uint32_t> augmented_inputs, augmented_outputs;
+ MapAndAddTensorIds(node.inputs->data, node.inputs->size, &augmented_inputs,
+ tensor_id_to_nnapi_id);
+ MapAndAddTensorIds(node.outputs->data, node.outputs->size,
+ &augmented_outputs, tensor_id_to_nnapi_id);
auto add_scalar_int32 = [&nn_model, &augmented_inputs,
&next_id](int value) {
@@ -215,6 +258,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
augmented_inputs.push_back(next_id++);
};
+ auto add_vector_int32 = [&](const int* values, uint32_t num_values) {
+ ANeuralNetworksOperandType operand_type{
+ .type = ANEURALNETWORKS_TENSOR_INT32,
+ .dimensionCount = 1,
+ .dimensions = &num_values};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, next_id, values, sizeof(int32_t) * num_values));
+ augmented_inputs.push_back(next_id++);
+ };
+
// Handle state tensors of RNN, LSTM, SVDF.
// For each state_out tensor, a corresponding state_in operand needs to be
// created for NNAPI.
@@ -233,39 +287,54 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
model_state_outputs->push_back(tensor_id);
next_id++;
};
+ auto check_and_add_activation = [&add_scalar_int32](int activation) {
+ if (activation > kTfLiteActRelu6) {
+ FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ }
+ add_scalar_int32(activation);
+ };
- auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
+ auto add_add_params = [&add_scalar_int32](void* data) {
+ auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
+ if (builtin->activation > kTfLiteActRelu6) {
+ FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ }
+ add_scalar_int32(builtin->activation);
+ };
- auto add_pooling_params = [&add_scalar_int32](void* data) {
+ auto add_pooling_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->filter_width);
add_scalar_int32(builtin->filter_height);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_convolution_params = [&add_scalar_int32](void* data) {
+ auto add_convolution_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteConvParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_depthwise_conv_params = [&add_scalar_int32](void* data) {
+ auto add_depthwise_conv_params = [&add_scalar_int32,
+ &check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteDepthwiseConvParams*>(data);
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->depth_multiplier);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
- auto add_fully_connected_params = [&add_scalar_int32](void* data) {
+ auto add_fully_connected_params = [&check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
- add_scalar_int32(builtin->activation);
+ check_and_add_activation(builtin->activation);
};
auto add_concatenation_params = [&add_scalar_int32](void* data) {
@@ -297,6 +366,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
// LSTM in NNAPI requires scratch tensor as an output operand.
auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model,
&next_id, &augmented_outputs]() {
+ if (node.temporaries->size == 0) return;
int scratch_buffer_index = node.temporaries->data[0];
const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index);
ANeuralNetworksOperandType operand_type{
@@ -309,7 +379,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
};
auto add_mean_params = [&add_scalar_int32](void* data) {
- auto builtin = reinterpret_cast<TfLiteMeanParams*>(data);
+ auto builtin = reinterpret_cast<TfLiteReducerParams*>(data);
add_scalar_int32(builtin->keep_dims);
};
@@ -324,6 +394,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_scalar_int32(builtin->activation);
};
+ auto add_squeeze_params = [&](void* data) {
+ const auto* builtin = reinterpret_cast<TfLiteSqueezeParams*>(data);
+ // Note that we add the squeeze dimensions even if the dimensions were
+ // unspecified (empty), as NNAPI requires the operand.
+ add_vector_int32(builtin->squeeze_dims,
+ static_cast<uint32_t>(builtin->num_squeeze_dims));
+ };
+
// Handle optional input tensors.
auto add_optional_tensors = [&nn_model, &augmented_inputs,
&next_id](int nn_type) {
@@ -345,11 +423,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
- add_add_params();
+ add_add_params(node.builtin_data);
break;
case tflite::BuiltinOperator_MUL:
nn_op_type = ANEURALNETWORKS_MUL;
- add_add_params();
+ add_add_params(node.builtin_data);
break;
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
add_pooling_params(node.builtin_data);
@@ -363,7 +441,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_pooling_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
break;
- case tflite::BuiltinOperator_CONV_2D:
+ case tflite::BuiltinOperator_CONV_2D: {
+ auto builtin = reinterpret_cast<TfLiteConvParams*>(node.builtin_data);
+ if (builtin->dilation_width_factor != 1 ||
+ builtin->dilation_height_factor != 1 || node.inputs->size != 3) {
+ logError("NNAPI does not support dilated Conv2D.");
+ return kTfLiteError;
+ }
+ }
add_convolution_params(node.builtin_data);
nn_op_type = ANEURALNETWORKS_CONV_2D;
break;
@@ -407,6 +492,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
nn_op_type = ANEURALNETWORKS_SPACE_TO_DEPTH;
break;
case tflite::BuiltinOperator_LSTM: {
+ if (node.inputs->size + /* no of params */ 3 != 21) {
+ logError("NNAPI only supports 21-input LSTMs");
+ return kTfLiteError;
+ }
duplicate_state_tensor_float32(
node.outputs->data[/*kOutputStateTensor*/ 0]);
duplicate_state_tensor_float32(
@@ -445,10 +534,31 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_DIV:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_DIV;
+ check_and_add_activation(
+ reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation);
break;
case tflite::BuiltinOperator_SUB:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_SUB;
+ check_and_add_activation(
+ reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation);
+ break;
+ case tflite::BuiltinOperator_SQUEEZE:
+ nnapi_version = 11; // requires NNAPI 1.1
+ add_squeeze_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_SQUEEZE;
+ break;
+ case tflite::BuiltinOperator_TRANSPOSE:
+ // The permutation input tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((node.inputs->size > 1) &&
+ (interpreter->tensor(node.inputs->data[1])->allocation_type !=
+ kTfLiteMmapRo)) {
+ logError("NNAPI does not yet support dynamic tensors.");
+ return kTfLiteError;
+ }
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_TRANSPOSE;
break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
@@ -469,9 +579,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
case tflite::BuiltinOperator_TOPK_V2:
- case tflite::BuiltinOperator_TRANSPOSE:
case tflite::BuiltinOperator_SPLIT:
- case tflite::BuiltinOperator_SQUEEZE:
case tflite::BuiltinOperator_STRIDED_SLICE:
case tflite::BuiltinOperator_EXP:
case tflite::BuiltinOperator_LOG_SOFTMAX:
@@ -482,6 +590,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_MAXIMUM:
case tflite::BuiltinOperator_MINIMUM:
case tflite::BuiltinOperator_ARG_MAX:
+ case tflite::BuiltinOperator_ARG_MIN:
case tflite::BuiltinOperator_GREATER:
case tflite::BuiltinOperator_GREATER_EQUAL:
case tflite::BuiltinOperator_LESS:
@@ -490,14 +599,25 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SELECT:
case tflite::BuiltinOperator_SLICE:
case tflite::BuiltinOperator_SIN:
+ case tflite::BuiltinOperator_LOG:
case tflite::BuiltinOperator_TRANSPOSE_CONV:
+ case tflite::BuiltinOperator_TILE:
+ case tflite::BuiltinOperator_EXPAND_DIMS:
case tflite::BuiltinOperator_SPARSE_TO_DENSE:
- FATAL("Op code %d is currently not delegated to NNAPI", builtin);
- nn_op_type = -1; // set to invalid
+ case tflite::BuiltinOperator_EQUAL:
+ case tflite::BuiltinOperator_NOT_EQUAL:
+ case tflite::BuiltinOperator_SUM:
+ case tflite::BuiltinOperator_SQRT:
+ case tflite::BuiltinOperator_RSQRT:
+ case tflite::BuiltinOperator_SHAPE:
+ case tflite::BuiltinOperator_POW:
+ case tflite::BuiltinOperator_FAKE_QUANT:
+ logError("Op code %d is currently not delegated to NNAPI", builtin);
+ return kTfLiteError;
break;
case tflite::BuiltinOperator_CUSTOM:
- FATAL("Custom operations are not supported when using NNAPI.");
- nn_op_type = -1; // set to invalid
+ logError("Custom operations are not supported when using NNAPI.");
+ return kTfLiteError;
break;
}
@@ -506,47 +626,70 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
}
// Add the operation.
- CHECK_NN(ANeuralNetworksModel_addOperation(
+ RETURN_ERROR_IF_NN_FAILED(ANeuralNetworksModel_addOperation(
nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),
augmented_inputs.data(),
static_cast<uint32_t>(augmented_outputs.size()),
reinterpret_cast<uint32_t*>(augmented_outputs.data())));
}
+ return kTfLiteOk;
}
TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
- // TODO(aselle): This is not correct. need to handle resize invalidation.
- if (nn_model_ && nn_compiled_model_) return kTfLiteOk;
+ if (nn_model_ && nn_compiled_model_) return model_status_;
+ // TODO(aselle): This is not correct. need to handle resize invalidation.
if (!nn_model_) {
CHECK_NN(ANeuralNetworksModel_create(&nn_model_));
- // Find all the temporary tensors and put them in a skip_list.
- std::vector<uint32_t> skip_list;
+ // Find which tensors should be added to NNAPI. TFLite has temporaries
+ // and RNN back-edges which are are not valid for NNAPI. We look through all
+ // inputs and outputs and mark the mapping in tensor_id_to_nnapi_id with
+ // kOperandIdNotSet. addTensorOperands will replace those with the
+ // corresponding NNAPI operand ids and skip kOperandNotNeeded entries.
+ std::vector<int64_t> tensor_id_to_nnapi_id(interpreter->tensors_size(),
+ kOperandNotNeeded);
+ auto set_ids_to_not_set = [&tensor_id_to_nnapi_id](const int* buf,
+ size_t count) {
+ for (int j = 0; j < count; j++) {
+ auto tensor_id = buf[j];
+ if (tensor_id != kOptionalTensor) {
+ tensor_id_to_nnapi_id[tensor_id] = kOperandIdNotSet;
+ }
+ }
+ };
for (size_t i = 0; i < interpreter->nodes_size(); i++) {
const auto* node_and_registration = interpreter->node_and_registration(i);
const TfLiteNode& node = node_and_registration->first;
- if (node.temporaries != nullptr) {
- for (int j = 0; j < node.temporaries->size; j++) {
- skip_list.push_back(static_cast<uint32_t>(node.temporaries->data[j]));
- }
- }
+ set_ids_to_not_set(node.inputs->data, node.inputs->size);
+ set_ids_to_not_set(node.outputs->data, node.outputs->size);
}
-
- uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list);
- AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
- &model_states_outputs_);
-
- std::vector<int> augmented_inputs = interpreter->inputs();
- std::vector<int> augmented_outputs = interpreter->outputs();
-
- // All state tensors input/output need to be treated as model input/output.
+ set_ids_to_not_set(interpreter->inputs().data(),
+ interpreter->inputs().size());
+ set_ids_to_not_set(interpreter->outputs().data(),
+ interpreter->outputs().size());
+
+ uint32_t next_id = 0;
+ RETURN_ERROR_IF_NN_FAILED(addTensorOperands(
+ interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id));
+ RETURN_ERROR_IF_NN_FAILED(
+ AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
+ &model_states_outputs_, tensor_id_to_nnapi_id));
+
+ std::vector<uint32_t> augmented_inputs;
+ MapAndAddTensorIds(interpreter->inputs().data(),
+ interpreter->inputs().size(), &augmented_inputs,
+ tensor_id_to_nnapi_id);
augmented_inputs.insert(augmented_inputs.end(),
model_states_inputs_.begin(),
model_states_inputs_.end());
- augmented_outputs.insert(augmented_outputs.end(),
- model_states_outputs_.begin(),
- model_states_outputs_.end());
+ std::vector<uint32_t> augmented_outputs;
+ MapAndAddTensorIds(interpreter->outputs().data(),
+ interpreter->outputs().size(), &augmented_outputs,
+ tensor_id_to_nnapi_id);
+ MapAndAddTensorIds(model_states_outputs_.data(),
+ model_states_outputs_.size(), &augmented_outputs,
+ tensor_id_to_nnapi_id);
CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_, static_cast<uint32_t>(augmented_inputs.size()),
@@ -564,7 +707,13 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) {
if (!nn_model_) {
- TF_LITE_ENSURE_STATUS(BuildGraph(interpreter));
+ model_status_ = BuildGraph(interpreter);
+ if (model_status_ != kTfLiteOk) {
+ logError("Failed to build graph for NNAPI");
+ }
+ }
+ if (model_status_ != kTfLiteOk) {
+ return model_status_;
}
ANeuralNetworksExecution* execution = nullptr;
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 94dea4f9b2..8dc7d38a30 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -59,14 +59,16 @@ class NNAPIDelegate {
ANeuralNetworksModel* nn_model_ = nullptr;
// The NN API compilation handle
ANeuralNetworksCompilation* nn_compiled_model_ = nullptr;
+ // Model status
+ TfLiteStatus model_status_ = kTfLiteOk;
// List of state tensors for LSTM, RNN, SVDF.
// NN API does not allow ops to maintain states across multiple
// invocations. We need to manually create state input tensors from
// corresponding state output tensors of TFLite operations, and map them
// correctly.
- std::vector<int> model_states_inputs_;
- std::vector<int> model_states_outputs_;
+ std::vector<int> model_states_inputs_; // holds NNAPI operand ids
+ std::vector<int> model_states_outputs_; // holds TFLite tensor ids
};
} // namespace tflite
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index dfdd80ea8a..f1f025f777 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -50,6 +50,10 @@ const char* TensorTypeName(TfLiteType type) {
return "kTfLiteString";
case kTfLiteBool:
return "kTfLiteBool";
+ case kTfLiteInt16:
+ return "kTfLiteInt16";
+ case kTfLiteComplex64:
+ return "kTfLiteComplex64";
}
return "(invalid)";
}
@@ -82,13 +86,13 @@ void PrintInterpreterState(Interpreter* interpreter) {
for (int tensor_index = 0; tensor_index < interpreter->tensors_size();
tensor_index++) {
TfLiteTensor* tensor = interpreter->tensor(tensor_index);
- printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
- TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type),
- tensor->bytes, float(tensor->bytes) / float(1 << 20));
+ printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
+ tensor->name, TensorTypeName(tensor->type),
+ AllocTypeName(tensor->allocation_type), tensor->bytes,
+ (static_cast<float>(tensor->bytes) / (1 << 20)));
PrintTfLiteIntVector(tensor->dims);
- printf("\n");
}
-
+ printf("\n");
for (int node_index = 0; node_index < interpreter->nodes_size();
node_index++) {
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
@@ -104,7 +108,4 @@ void PrintInterpreterState(Interpreter* interpreter) {
}
}
-// Prints a dump of what tensors and what nodes are in the interpreter.
-TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
-
} // namespace tflite
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
index 1b6998cda3..7fb4b8d8b7 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.h
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -24,9 +24,6 @@ namespace tflite {
// Prints a dump of what tensors and what nodes are in the interpreter.
void PrintInterpreterState(Interpreter* interpreter);
-// Prints a dump of what tensors and what nodes are in the interpreter.
-TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
-
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD
index c31189f2b1..a162b87b8f 100644
--- a/tensorflow/contrib/lite/profiling/BUILD
+++ b/tensorflow/contrib/lite/profiling/BUILD
@@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
common_copts = [
"-Wall",
-]
+] + tflite_copts()
cc_library(
name = "profiler",
@@ -36,12 +38,14 @@ cc_library(
name = "time",
srcs = ["time.cc"],
hdrs = ["time.h"],
+ copts = common_copts,
)
cc_library(
name = "profile_summarizer",
srcs = ["profile_summarizer.cc"],
hdrs = ["profile_summarizer.h"],
+ copts = common_copts,
deps = [
":profiler",
"//tensorflow/contrib/lite:framework",
@@ -53,6 +57,7 @@ cc_library(
cc_test(
name = "profile_summarizer_test",
srcs = ["profile_summarizer_test.cc"],
+ copts = common_copts,
deps = [
":profile_summarizer",
"//tensorflow/contrib/lite:framework",
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
index 788f6922d2..c37a096588 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc
@@ -26,21 +26,22 @@ namespace {
using Detail = tensorflow::StatsCalculator::Detail;
struct OperatorDetails {
- string name;
- std::vector<string> inputs;
- std::vector<string> outputs;
+ std::string name;
+ std::vector<std::string> inputs;
+ std::vector<std::string> outputs;
};
-string GetTensorName(const tflite::Interpreter& interpreter, int tensor_index) {
+std::string GetTensorName(const tflite::Interpreter& interpreter,
+ int tensor_index) {
const auto tensor = interpreter.tensor(tensor_index);
if (tensor == nullptr || tensor->name == nullptr) {
return "Unknown";
}
return tensor->name;
}
-std::vector<string> GetTensorNames(const tflite::Interpreter& interpreter,
- const TfLiteIntArray* tensor_indices) {
- std::vector<string> tensors;
+std::vector<std::string> GetTensorNames(const tflite::Interpreter& interpreter,
+ const TfLiteIntArray* tensor_indices) {
+ std::vector<std::string> tensors;
tensors.reserve(tensor_indices->size);
for (int i = 0; i < tensor_indices->size; i++) {
tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i]));
@@ -48,7 +49,7 @@ std::vector<string> GetTensorNames(const tflite::Interpreter& interpreter,
return tensors;
}
-string ToString(const std::vector<string>& str_vector) {
+std::string ToString(const std::vector<std::string>& str_vector) {
std::stringstream stream;
stream << "[";
bool first = true;
@@ -77,18 +78,30 @@ OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter,
} else {
op_name = tflite::EnumNamesBuiltinOperator()[code];
}
+ const char* profiling_string =
+ interpreter.OpProfilingString(node_reg->second, &node_reg->first);
OperatorDetails details;
details.name = op_name;
+ if (profiling_string) {
+ details.name += ":" + string(profiling_string);
+ }
details.inputs = GetTensorNames(interpreter, inputs);
details.outputs = GetTensorNames(interpreter, outputs);
return details;
}
+tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() {
+ auto options = tensorflow::StatSummarizerOptions();
+ options.show_summary = true;
+ options.show_memory = false;
+ return options;
+}
+
} // namespace
ProfileSummarizer::ProfileSummarizer()
- : stats_calculator_(new ::tensorflow::StatsCalculator(
- tensorflow::StatSummarizerOptions())) {}
+ : stats_calculator_(
+ new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {}
void ProfileSummarizer::ProcessProfiles(
const std::vector<const ProfileEvent*>& profile_stats,
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h
index 6fe6ca04f5..a529ff8742 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer.h
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h
@@ -45,9 +45,6 @@ class ProfileSummarizer {
return stats_calculator_->GetShortSummary();
}
- // Prints the string returned by GetOutputString().
- void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
-
private:
std::unique_ptr<tensorflow::StatsCalculator> stats_calculator_;
};
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
index 35cf780713..67a5eecfa0 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
@@ -31,6 +31,7 @@ namespace profiling {
namespace {
+#ifdef TFLITE_PROFILING_ENABLED
TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1);
@@ -42,20 +43,35 @@ TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+const char* SimpleOpProfilingString(const TfLiteContext* context,
+ const TfLiteNode* node) {
+ return "Profile";
+}
+
TfLiteRegistration* RegisterSimpleOp() {
+ static TfLiteRegistration registration = {
+ nullptr, nullptr, nullptr,
+ SimpleOpEval, nullptr, tflite::BuiltinOperator_CUSTOM,
+ "SimpleOpEval", 1};
+ return &registration;
+}
+
+TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
static TfLiteRegistration registration = {nullptr,
nullptr,
nullptr,
SimpleOpEval,
+ SimpleOpProfilingString,
tflite::BuiltinOperator_CUSTOM,
"SimpleOpEval",
1};
return &registration;
}
+#endif
class SimpleOpModel : public SingleOpModel {
public:
- void Init();
+ void Init(const std::function<TfLiteRegistration*()>& registration);
tflite::Interpreter* GetInterpreter() { return interpreter_.get(); }
void SetInputs(int32_t x, int32_t y) {
PopulateTensor(inputs_[0], {x});
@@ -68,11 +84,12 @@ class SimpleOpModel : public SingleOpModel {
int output_;
};
-void SimpleOpModel::Init() {
+void SimpleOpModel::Init(
+ const std::function<TfLiteRegistration*()>& registration) {
inputs_[0] = AddInput({TensorType_INT32, {1}});
inputs_[1] = AddInput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT32, {}});
- SetCustomOp("SimpleAdd", {}, RegisterSimpleOp);
+ SetCustomOp("SimpleAdd", {}, registration);
BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
}
@@ -86,7 +103,28 @@ TEST(ProfileSummarizerTest, Empty) {
TEST(ProfileSummarizerTest, Interpreter) {
Profiler profiler;
SimpleOpModel m;
- m.Init();
+ m.Init(RegisterSimpleOp);
+ auto interpreter = m.GetInterpreter();
+ interpreter->SetProfiler(&profiler);
+ profiler.StartProfiling();
+ m.SetInputs(1, 2);
+ m.Invoke();
+ // 3 = 1 + 2
+ EXPECT_EQ(m.GetOutput(), 3);
+ profiler.StopProfiling();
+ ProfileSummarizer summarizer;
+ auto events = profiler.GetProfileEvents();
+ EXPECT_EQ(1, events.size());
+ summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
+ auto output = summarizer.GetOutputString();
+ // TODO(shashishekhar): Add a better test here.
+ ASSERT_TRUE(output.find("SimpleOpEval") != std::string::npos) << output;
+}
+
+TEST(ProfileSummarizerTest, InterpreterPlusProfilingDetails) {
+ Profiler profiler;
+ SimpleOpModel m;
+ m.Init(RegisterSimpleOpWithProfilingDetails);
auto interpreter = m.GetInterpreter();
interpreter->SetProfiler(&profiler);
profiler.StartProfiling();
@@ -101,8 +139,10 @@ TEST(ProfileSummarizerTest, Interpreter) {
summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter);
auto output = summarizer.GetOutputString();
// TODO(shashishekhar): Add a better test here.
- ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output;
+ ASSERT_TRUE(output.find("SimpleOpEval:Profile") != std::string::npos)
+ << output;
}
+
#endif
} // namespace
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 7e6ff6c0a8..8c9608db04 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -19,6 +19,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
+ "//tensorflow/python:util",
],
)
@@ -30,9 +31,10 @@ py_test(
tags = ["no_oss"],
deps = [
":interpreter",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//third_party/py/numpy",
],
)
@@ -57,8 +59,9 @@ py_library(
":interpreter",
":lite_constants",
":op_hint",
- "//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model:loader",
"//tensorflow/python/tools:freeze_graph_lib",
],
)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 0819475240..0ea2630f71 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -25,7 +25,6 @@ import tempfile as _tempfile
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
-from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -111,42 +110,42 @@ def tensor_name(x):
return x.name.split(":")[0]
-def toco_convert(input_data,
- input_tensors,
- output_tensors,
- inference_type=lite_constants.FLOAT,
- inference_input_type=None,
- input_format=lite_constants.TENSORFLOW_GRAPHDEF,
- output_format=lite_constants.TFLITE,
- quantized_input_stats=None,
- default_ranges_stats=None,
- drop_control_dependency=True,
- reorder_across_fake_quant=False,
- allow_custom_ops=False,
- change_concat_input_ranges=False):
- """Convert a model using TOCO from `input_format` to `output_format`.
+def build_toco_convert_protos(input_tensors,
+ output_tensors,
+ inference_type=lite_constants.FLOAT,
+ inference_input_type=None,
+ input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ output_format=lite_constants.TFLITE,
+ quantized_input_stats=None,
+ default_ranges_stats=None,
+ drop_control_dependency=True,
+ reorder_across_fake_quant=False,
+ allow_custom_ops=False,
+ change_concat_input_ranges=False,
+ quantize_weights=False,
+ dump_graphviz_dir=None,
+ dump_graphviz_video=False):
+ """Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
case the default `input_format` and `output_format` are sufficient.
Args:
- input_data: Input data (i.e. often `sess.graph_def`).
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- inference_type: Target data type of arrays in the output file. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
- inference_input_type: Target data type of input arrays. Allows for a
- different type for input arrays in the case of quantization. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ inference_type: Target data type of real-number arrays in the output file.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of real-number input arrays. Allows
+ for a different type for input arrays in the case of quantization.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
input_format: Type of data to read Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
- of the training data (e.g., {"foo" : (0., 1.)}). Only need if
- `inference_type` is `QUANTIZED_UINT8`. (default None)
+ quantized_input_stats: List of tuples of integers representing the mean and
+ standard deviation. Each tuple maps to the corresponding input tensor.
+ Only need if `inference_type` is `QUANTIZED_UINT8`. (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -158,18 +157,28 @@ def toco_convert(input_data,
nodes is preventing graph transformations necessary to convert the graph.
Results in a graph that differs from the quantized training graph,
potentially causing differing arithmetic behavior. (default False)
- change_concat_input_ranges: Boolean to change behavior of min/max ranges for
- inputs and outputs of the concat operator for quantized models. Changes
- the ranges of concat operator overlap when true. (default False)
allow_custom_ops: Boolean indicating whether to allow custom operations.
When false any unknown operation is an error. When true, custom ops are
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
+ change_concat_input_ranges: Boolean to change behavior of min/max ranges for
+ inputs and outputs of the concat operator for quantized models. Changes
+ the ranges of concat operator overlap when true. (default False)
+ quantize_weights: Boolean indicating whether to store weights as quantized
+ weights followed by dequantize operations. Computation is still done in
+ float, but reduces model size (at the cost of accuracy and latency).
+ (default False)
+ dump_graphviz_dir: Full filepath of folder to dump the graphs at various
+ stages of processing GraphViz .dot files. Preferred over
+ --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
+ output file. (default None)
+ dump_graphviz_video: Boolean indicating whether to dump the graph after
+ every graph transformation. (default False)
Returns:
- The converted data. For example if TFLite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
+ model_flags, toco_flags: two protocol buffers describing the conversion
+ process.
Raises:
ValueError: If the input tensor type is unknown
@@ -185,42 +194,54 @@ def toco_convert(input_data,
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
+ toco.quantize_weights = quantize_weights
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
+ if dump_graphviz_dir:
+ toco.dump_graphviz_dir = dump_graphviz_dir
+ toco.dump_graphviz_include_video = dump_graphviz_video
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
- if input_tensor.dtype == _dtypes.float32:
- tflite_input_type = lite_constants.FLOAT
- elif input_tensor.dtype == _dtypes.int32:
- tflite_input_type = lite_constants.INT32
- elif input_tensor.dtype == _dtypes.int64:
- tflite_input_type = lite_constants.INT64
- elif input_tensor.dtype == _dtypes.uint8:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
- # TODO(aselle): Insert strings when they are available
- else:
- raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
- input_tensor.dtype))
-
input_array = model.input_arrays.add()
-
if inference_type == lite_constants.QUANTIZED_UINT8:
- if tflite_input_type == lite_constants.FLOAT:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
-
input_array.name = tensor_name(input_tensor)
input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
for output_tensor in output_tensors:
model.output_arrays.append(tensor_name(output_tensor))
+ return model, toco
+
- # TODO(aselle): Consider handling the case of allowing quantized
- # inputs to be converted to float (via the toco.inference_input_type field).
- data = toco_convert_protos(model.SerializeToString(),
- toco.SerializeToString(),
+def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ Typically this function is used to convert from TensorFlow GraphDef to TFLite.
+ Conversion can be customized by providing arguments that are forwarded to
+ `build_toco_convert_protos` (see documentation for details).
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`),
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ *args: See `build_toco_convert_protos`,
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+ The converted data. For example if TFLite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ model_flags, toco_flags = build_toco_convert_protos(input_tensors,
+ output_tensors,
+ *args, **kwargs)
+ data = toco_convert_protos(model_flags.SerializeToString(),
+ toco_flags.SerializeToString(),
input_data.SerializeToString())
return data
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 5dad49f1ed..1553464b9f 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -19,13 +19,12 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.lite.python.convert import tensor_name
-from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
@@ -58,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
Raises:
ValueError: No valid MetaGraphDef for given tag_set.
"""
- saved_model = reader.read_saved_model(saved_model_dir)
- tag_sets = []
- result_meta_graph_def = None
- for meta_graph_def in saved_model.meta_graphs:
- meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags)
- tag_sets.append(meta_graph_tag_set)
- if meta_graph_tag_set == tag_set:
- result_meta_graph_def = meta_graph_def
- logging.info("The given saved_model contains the following tags: %s",
- tag_sets)
- if result_meta_graph_def is not None:
- return result_meta_graph_def
- else:
- raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible "
- "values are '{}'. ".format(tag_set, tag_sets))
+ with session.Session(graph=ops.Graph()) as sess:
+ return loader.load(sess, tag_set, saved_model_dir)
def _get_signature_def(meta_graph, signature_key):
@@ -97,9 +83,7 @@ def _get_signature_def(meta_graph, signature_key):
raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
"values are '{}'.".format(signature_key,
",".join(signature_def_keys)))
- signature_def = signature_def_utils.get_signature_def_by_key(
- meta_graph, signature_key)
- return signature_def
+ return signature_def_map[signature_key]
def _get_inputs_outputs(signature_def):
@@ -247,6 +231,7 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
ValueError:
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
+ assets/ directory is in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
"""
@@ -255,9 +240,13 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
signature_def = _get_signature_def(meta_graph, signature_key)
inputs, outputs = _get_inputs_outputs(signature_def)
+ # Check SavedModel for assets directory.
+ collection_def = meta_graph.collection_def
+ if constants.ASSETS_KEY in collection_def:
+ raise ValueError("SavedModels with assets/ directory are not supported.")
+
graph = ops.Graph()
with session.Session(graph=graph) as sess:
- # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
# Gets input and output tensors.
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index 1e570d2c89..92c4ebb246 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -78,6 +78,7 @@ class TensorFunctionsTest(test_util.TensorFlowTestCase):
def testSetTensorShapeNoneValid(self):
tensor = array_ops.placeholder(dtype=dtypes.float32)
+ self.assertEqual(None, tensor.shape)
convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
self.assertEqual([1, 3, 5], tensor.shape.as_list())
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 779bda4c9d..e1981ceae2 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
from tensorflow.python.util.lazy_loader import LazyLoader
# Lazy load since some of the performance benchmark skylark rules
@@ -55,17 +56,42 @@ class Interpreter(object):
self._interpreter = (
_interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer(
model_content))
- if not self._interpreter:
- raise ValueError(
- 'Failed to create model from {} bytes'.format(len(model_content)))
elif not model_path and not model_path:
raise ValueError('`model_path` or `model_content` must be specified.')
else:
raise ValueError('Can\'t both provide `model_path` and `model_content`')
def allocate_tensors(self):
- if not self._interpreter.AllocateTensors():
- raise ValueError('Failed to allocate tensors')
+ self._ensure_safe()
+ return self._interpreter.AllocateTensors()
+
+ def _safe_to_run(self):
+ """Returns true if there exist no numpy array buffers.
+
+ This means it is safe to run tflite calls that may destroy internally
+ allocated memory. This works, because in the wrapper.cc we have made
+ the numpy base be the self._interpreter.
+ """
+ # NOTE, our tensor() call in cpp will use _interpreter as a base pointer.
+ # If this environment is the only _interpreter, then the ref count should be
+ # 2 (1 in self and 1 in temporary of sys.getrefcount).
+ return sys.getrefcount(self._interpreter) == 2
+
+ def _ensure_safe(self):
+ """Makes sure no numpy arrays pointing to internal buffers are active.
+
+ This should be called from any function that will call a function on
+ _interpreter that may reallocate memory e.g. invoke(), ...
+
+ Raises:
+ RuntimeError: If there exist numpy objects pointing to internal memory
+ then we throw.
+ """
+ if not self._safe_to_run():
+ raise RuntimeError("""There is at least 1 reference to internal data
+ in the interpreter in the form of a numpy array or slice. Be sure to
+ only hold the function returned from tensor() if you are using raw
+ data access.""")
def _get_tensor_details(self, tensor_index):
"""Gets tensor details.
@@ -109,7 +135,10 @@ class Interpreter(object):
]
def set_tensor(self, tensor_index, value):
- """Sets the value of the input.
+ """Sets the value of the input tensor. Note this copies data in `value`.
+
+ If you want to avoid copying, you can use the `tensor()` function to get a
+ numpy buffer pointing to the input buffer in the tflite interpreter.
Args:
tensor_index: Tensor index of tensor to set. This value can be gotten from
@@ -119,8 +148,7 @@ class Interpreter(object):
Raises:
ValueError: If the interpreter could not set the tensor.
"""
- if not self._interpreter.SetTensor(tensor_index, value):
- raise ValueError('Failed to set tensor')
+ self._interpreter.SetTensor(tensor_index, value)
def resize_tensor_input(self, input_index, tensor_size):
"""Resizes an input tensor.
@@ -133,8 +161,8 @@ class Interpreter(object):
Raises:
ValueError: If the interpreter could not resize the input tensor.
"""
- if not self._interpreter.ResizeInputTensor(input_index, tensor_size):
- raise ValueError('Failed to resize input')
+ self._ensure_safe()
+ self._interpreter.ResizeInputTensor(input_index, tensor_size)
def get_output_details(self):
"""Gets model output details.
@@ -147,7 +175,9 @@ class Interpreter(object):
]
def get_tensor(self, tensor_index):
- """Sets the value of the input.
+ """Gets the value of the input tensor (get a copy).
+
+ If you wish to avoid the copy, use `tensor()`.
Args:
tensor_index: Tensor index of tensor to get. This value can be gotten from
@@ -158,6 +188,62 @@ class Interpreter(object):
"""
return self._interpreter.GetTensor(tensor_index)
+ def tensor(self, tensor_index):
+ """Returns function that gives a numpy view of the current tensor buffer.
+
+ This allows reading and writing to this tensors w/o copies. This more
+ closely mirrors the C++ Interpreter class interface's tensor() member, hence
+ the name. Be careful to not hold these output references through calls
+ to `allocate_tensors()` and `invoke()`.
+
+ Usage:
+
+ interpreter.allocate_tensors()
+ input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
+ output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
+ for i in range(10):
+ input().fill(3.)
+ interpreter.invoke()
+ print("inference %s" % output)
+
+ Notice how this function avoids making a numpy array directly. This is
+ because it is important to not hold actual numpy views to the data longer
+ than necessary. If you do, then the interpreter can no longer be invoked,
+ because it is possible the interpreter would resize and invalidate the
+ referenced tensors. The NumPy API doesn't allow any mutability of the
+ the underlying buffers.
+
+ WRONG:
+
+ input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
+ output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
+ interpreter.allocate_tensors() # This will throw RuntimeError
+ for i in range(10):
+ input.fill(3.)
+ interpreter.invoke() # this will throw RuntimeError since input,output
+
+ Args:
+ tensor_index: Tensor index of tensor to get. This value can be gotten from
+ the 'index' field in get_output_details.
+
+ Returns:
+ A function that can return a new numpy array pointing to the internal
+ TFLite tensor state at any point. It is safe to hold the function forever,
+ but it is not safe to hold the numpy array forever.
+ """
+ return lambda: self._interpreter.tensor(self._interpreter, tensor_index)
+
def invoke(self):
- if not self._interpreter.Invoke():
- raise ValueError('Failed to invoke TFLite model')
+ """Invoke the interpreter.
+
+ Be sure to set the input sizes, allocate tensors and fill values before
+ calling this.
+
+ Raises:
+ ValueError: When the underlying interpreter fails raise ValueError.
+ """
+ self._ensure_safe()
+ self._interpreter.Invoke()
+
+ def reset_all_variables_to_zero(self):
+ return self._interpreter.ResetVariableTensorsToZero()
diff --git a/tensorflow/contrib/lite/python/interpreter_test.py b/tensorflow/contrib/lite/python/interpreter_test.py
index f802edf020..95fa4b8584 100644
--- a/tensorflow/contrib/lite/python/interpreter_test.py
+++ b/tensorflow/contrib/lite/python/interpreter_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import io
import numpy as np
+import six
from tensorflow.contrib.lite.python import interpreter as interpreter_wrapper
from tensorflow.python.framework import test_util
@@ -91,5 +92,83 @@ class InterpreterTest(test_util.TensorFlowTestCase):
self.assertTrue((expected_output == output_data).all())
+class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
+
+ def testInvalidModelContent(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Model provided has model identifier \''):
+ interpreter_wrapper.Interpreter(model_content=six.b('garbage'))
+
+ def testInvalidModelFile(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Could not open \'totally_invalid_file_name\''):
+ interpreter_wrapper.Interpreter(
+ model_path='totally_invalid_file_name')
+
+ def testInvokeBeforeReady(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
+ with self.assertRaisesRegexp(RuntimeError,
+ 'Invoke called on model that is not ready'):
+ interpreter.invoke()
+
+
+class InterpreterTensorAccessorTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.interpreter = interpreter_wrapper.Interpreter(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/permute_float.tflite'))
+ self.interpreter.allocate_tensors()
+ self.input0 = self.interpreter.get_input_details()[0]['index']
+ self.initial_data = np.array([[-1., -2., -3., -4.]], np.float32)
+
+ def testTensorAccessor(self):
+ """Check that tensor returns a reference."""
+ array_ref = self.interpreter.tensor(self.input0)
+ np.copyto(array_ref(), self.initial_data)
+ self.assertAllEqual(array_ref(), self.initial_data)
+ self.assertAllEqual(
+ self.interpreter.get_tensor(self.input0), self.initial_data)
+
+ def testGetTensorAccessor(self):
+ """Check that get_tensor returns a copy."""
+ self.interpreter.set_tensor(self.input0, self.initial_data)
+ array_initial_copy = self.interpreter.get_tensor(self.input0)
+ new_value = np.add(1., array_initial_copy)
+ self.interpreter.set_tensor(self.input0, new_value)
+ self.assertAllEqual(array_initial_copy, self.initial_data)
+ self.assertAllEqual(self.interpreter.get_tensor(self.input0), new_value)
+
+ def testBase(self):
+ self.assertTrue(self.interpreter._safe_to_run())
+ _ = self.interpreter.tensor(self.input0)
+ self.assertTrue(self.interpreter._safe_to_run())
+ in0 = self.interpreter.tensor(self.input0)()
+ self.assertFalse(self.interpreter._safe_to_run())
+ in0b = self.interpreter.tensor(self.input0)()
+ self.assertFalse(self.interpreter._safe_to_run())
+ # Now get rid of the buffers so that we can evaluate.
+ del in0
+ del in0b
+ self.assertTrue(self.interpreter._safe_to_run())
+
+ def testBaseProtectsFunctions(self):
+ in0 = self.interpreter.tensor(self.input0)()
+ # Make sure we get an exception if we try to run an unsafe operation
+ with self.assertRaisesRegexp(
+ RuntimeError, 'There is at least 1 reference'):
+ _ = self.interpreter.allocate_tensors()
+ # Make sure we get an exception if we try to run an unsafe operation
+ with self.assertRaisesRegexp(
+ RuntimeError, 'There is at least 1 reference'):
+ _ = self.interpreter.invoke()
+ # Now test that we can run
+ del in0 # this is our only buffer reference, so now it is safe to change
+ in0safe = self.interpreter.tensor(self.input0)
+ _ = self.interpreter.allocate_tensors()
+ del in0safe # make sure in0Safe is held but lint doesn't complain
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
index 12ab38847d..69ee95c320 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD
@@ -13,8 +13,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/core:lib",
- "//tensorflow/python:numpy_lib",
+ "//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/memory",
],
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 5f304ad45d..c38b692dcd 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -14,14 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+#include <sstream>
#include <string>
#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/python/lib/core/numpy.h"
+
+// Disallow Numpy 1.7 deprecated symbols.
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
#if PY_MAJOR_VERSION >= 3
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
@@ -31,10 +38,66 @@ limitations under the License.
#define CPP_TO_PYSTRING PyString_FromStringAndSize
#endif
+#define TFLITE_PY_CHECK(x) \
+ if ((x) != kTfLiteOk) { \
+ return error_reporter_->exception(); \
+ }
+
+#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i) \
+ if (i >= interpreter_->tensors_size() || i < 0) { \
+ PyErr_Format(PyExc_ValueError, \
+ "Invalid tensor index %d exceeds max tensor index %lu", i, \
+ interpreter_->tensors_size()); \
+ return nullptr; \
+ }
+
+#define TFLITE_PY_ENSURE_VALID_INTERPRETER() \
+ if (!interpreter_) { \
+ PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
+ return nullptr; \
+ }
+
namespace tflite {
namespace interpreter_wrapper {
+class PythonErrorReporter : public tflite::ErrorReporter {
+ public:
+ PythonErrorReporter() {}
+
+ // Report an error message
+ int Report(const char* format, va_list args) override {
+ char buf[1024];
+ int formatted = vsnprintf(buf, sizeof(buf), format, args);
+ buffer_ << buf;
+ return formatted;
+ }
+
+ // Set's a Python runtime exception with the last error.
+ PyObject* exception() {
+ std::string last_message = message();
+ PyErr_SetString(PyExc_RuntimeError, last_message.c_str());
+ return nullptr;
+ }
+
+ // Gets the last error message and clears the buffer.
+ std::string message() {
+ std::string value = buffer_.str();
+ buffer_.clear();
+ return value;
+ }
+
+ private:
+ std::stringstream buffer_;
+};
+
namespace {
+
+// Calls PyArray's initialization to initialize all the API pointers. Note that
+// this usage implies only this translation unit can use the pointers. See
+// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend
+// this further.
+void ImportNumpy() { import_array1(); }
+
std::unique_ptr<tflite::Interpreter> CreateInterpreter(
const tflite::FlatBufferModel* model,
const tflite::ops::builtin::BuiltinOpResolver& resolver) {
@@ -42,23 +105,10 @@ std::unique_ptr<tflite::Interpreter> CreateInterpreter(
return nullptr;
}
- tensorflow::ImportNumpy();
+ ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
- if (interpreter) {
- for (const int input_index : interpreter->inputs()) {
- const TfLiteTensor* tensor = interpreter->tensor(input_index);
- CHECK(tensor);
- const TfLiteIntArray* dims = tensor->dims;
- if (!dims) {
- continue;
- }
-
- std::vector<int> input_dims(dims->data, dims->data + dims->size);
- interpreter->ResizeInputTensor(input_index, input_dims);
- }
- }
return interpreter;
}
@@ -68,6 +118,8 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_FLOAT32;
case kTfLiteInt32:
return NPY_INT32;
+ case kTfLiteInt16:
+ return NPY_INT16;
case kTfLiteUInt8:
return NPY_UINT8;
case kTfLiteInt64:
@@ -76,11 +128,13 @@ int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
return NPY_OBJECT;
case kTfLiteBool:
return NPY_BOOL;
+ case kTfLiteComplex64:
+ return NPY_COMPLEX64;
case kTfLiteNoType:
- return -1;
+ return NPY_NOTYPE;
+ // Avoid default so compiler errors created when new types are made.
}
- LOG(ERROR) << "Unknown TfLiteType " << tf_lite_type;
- return -1;
+ return NPY_NOTYPE;
}
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
@@ -90,6 +144,8 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
return kTfLiteFloat32;
case NPY_INT32:
return kTfLiteInt32;
+ case NPY_INT16:
+ return kTfLiteInt16;
case NPY_UINT8:
return kTfLiteUInt8;
case NPY_INT64:
@@ -100,8 +156,10 @@ TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
case NPY_STRING:
case NPY_UNICODE:
return kTfLiteString;
+ case NPY_COMPLEX64:
+ return kTfLiteComplex64;
+ // Avoid default so compiler errors created when new types are made.
}
- LOG(ERROR) << "Unknown PyArray dtype " << pyarray_type;
return kTfLiteNoType;
}
@@ -125,32 +183,29 @@ PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
} // namespace
InterpreterWrapper::InterpreterWrapper(
- std::unique_ptr<tflite::FlatBufferModel> model)
+ std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter)
: model_(std::move(model)),
+ error_reporter_(std::move(error_reporter)),
resolver_(absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()),
interpreter_(CreateInterpreter(model_.get(), *resolver_)) {}
InterpreterWrapper::~InterpreterWrapper() {}
-bool InterpreterWrapper::AllocateTensors() {
- if (!interpreter_) {
- LOG(ERROR) << "Cannot allocate tensors: invalid interpreter.";
- return false;
- }
-
- if (interpreter_->AllocateTensors() != kTfLiteOk) {
- LOG(ERROR) << "Unable to allocate tensors.";
- return false;
- }
-
- return true;
+PyObject* InterpreterWrapper::AllocateTensors() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->AllocateTensors());
+ Py_RETURN_NONE;
}
-bool InterpreterWrapper::Invoke() {
- return interpreter_ ? (interpreter_->Invoke() == kTfLiteOk) : false;
+PyObject* InterpreterWrapper::Invoke() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->Invoke());
+ Py_RETURN_NONE;
}
PyObject* InterpreterWrapper::InputIndices() const {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
interpreter_->inputs().size());
@@ -164,35 +219,36 @@ PyObject* InterpreterWrapper::OutputIndices() const {
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
-bool InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
+PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert numpy value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
if (PyArray_NDIM(array) != 1) {
- LOG(ERROR) << "Expected 1-D defining input shape.";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
+ PyArray_NDIM(array));
+ return nullptr;
}
if (PyArray_TYPE(array) != NPY_INT32) {
- LOG(ERROR) << "Shape must be an int32 array";
- return false;
+ PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
+ PyArray_TYPE(array));
+ return nullptr;
}
std::vector<int> dims(PyArray_SHAPE(array)[0]);
memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
- return (interpreter_->ResizeInputTensor(i, dims) == kTfLiteOk);
+ TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
+ Py_RETURN_NONE;
}
std::string InterpreterWrapper::TensorName(int i) const {
@@ -205,21 +261,21 @@ std::string InterpreterWrapper::TensorName(int i) const {
}
PyObject* InterpreterWrapper::TensorType(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- return nullptr;
- }
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
- int typenum = TfLiteTypeToPyArrayType(tensor->type);
- return PyArray_TypeObjectFromType(typenum);
+ int code = TfLiteTypeToPyArrayType(tensor->type);
+ if (code == -1) {
+ PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
+ return nullptr;
+ }
+ return PyArray_TypeObjectFromType(code);
}
PyObject* InterpreterWrapper::TensorSize(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
@@ -228,120 +284,167 @@ PyObject* InterpreterWrapper::TensorSize(int i) const {
}
PyObject* InterpreterWrapper::TensorQuantization(int i) const {
- if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
- Py_INCREF(Py_None);
- return Py_None;
- }
-
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
return PyTupleFromQuantizationParam(tensor->params);
}
-bool InterpreterWrapper::SetTensor(int i, PyObject* value) {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- return false;
- }
-
- if (i >= interpreter_->tensors_size()) {
- LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
- << interpreter_->tensors_size();
- return false;
- }
+PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
if (!array_safe) {
- LOG(ERROR) << "Failed to convert value into readable tensor.";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Failed to convert value into readable tensor.");
+ return nullptr;
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
const TfLiteTensor* tensor = interpreter_->tensor(i);
if (TfLiteTypeFromPyArray(array) != tensor->type) {
- LOG(ERROR) << "Cannot set tensor:"
- << " Got tensor of type " << TfLiteTypeFromPyArray(array)
- << " but expected type " << tensor->type << " for input " << i;
- return false;
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor:"
+ " Got tensor of type %d"
+ " but expected type %d for input %d ",
+ TfLiteTypeFromPyArray(array), tensor->type, i);
+ return nullptr;
}
if (PyArray_NDIM(array) != tensor->dims->size) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
- LOG(ERROR) << "Cannot set tensor: Dimension mismatch";
- return false;
+ PyErr_SetString(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch");
+ return nullptr;
}
}
size_t size = PyArray_NBYTES(array);
- DCHECK_EQ(size, tensor->bytes);
+ if (size != tensor->bytes) {
+ PyErr_Format(PyExc_ValueError,
+ "numpy array had %zu bytes but expected %zu bytes.", size,
+ tensor->bytes);
+ return nullptr;
+ }
memcpy(tensor->data.raw, PyArray_DATA(array), size);
- return true;
+ Py_RETURN_NONE;
}
-PyObject* InterpreterWrapper::GetTensor(int i) const {
- if (!interpreter_) {
- LOG(ERROR) << "Invalid interpreter.";
- Py_INCREF(Py_None);
- return Py_None;
- }
+namespace {
+
+PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
+ TfLiteTensor** tensor, int* type_num) {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);
- if (i >= interpreter_->tensors_size()) {
- LOG(ERROR) << "Invalid tensor index: " << i << " exceeds max tensor index "
- << interpreter_->inputs().size();
- Py_INCREF(Py_None);
- return Py_None;
+ *tensor = interpreter_->tensor(tensor_index);
+ if ((*tensor)->bytes == 0) {
+ PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
+ return nullptr;
}
- const TfLiteTensor* output_tensor = interpreter_->tensor(i);
- const int tensor_size = output_tensor->bytes;
- if (tensor_size <= 0) {
- LOG(ERROR) << "Invalid tensor size";
- Py_INCREF(Py_None);
- return Py_None;
+ *type_num = TfLiteTypeToPyArrayType((*tensor)->type);
+ if (*type_num == -1) {
+ PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
+ return nullptr;
}
- int type_num = TfLiteTypeToPyArrayType(output_tensor->type);
- if (type_num == -1) {
- LOG(ERROR) << "Unknown tensor type " << output_tensor->type;
- Py_INCREF(Py_None);
- return Py_None;
+ if (!(*tensor)->data.raw) {
+ PyErr_SetString(PyExc_ValueError, "Tensor data is null.");
+ return nullptr;
}
- void* data = malloc(tensor_size);
- memcpy(data, output_tensor->data.raw, tensor_size);
+ return nullptr;
+}
+
+} // namespace
- const TfLiteIntArray* output_dims = output_tensor->dims;
- std::vector<npy_intp> dims(output_dims->data,
- output_dims->data + output_dims->size);
+PyObject* InterpreterWrapper::GetTensor(int i) const {
+ // Sanity check accessor
+ TfLiteTensor* tensor = nullptr;
+ int type_num = 0;
+ if (PyObject* pynone_or_nullptr =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
+ return pynone_or_nullptr;
+ }
+ std::vector<npy_intp> dims(tensor->dims->data,
+ tensor->dims->data + tensor->dims->size);
+ // Make a buffer copy but we must tell Numpy It owns that data or else
+ // it will leak.
+ void* data = malloc(tensor->bytes);
+ if (!data) {
+ PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
+ return nullptr;
+ }
+ memcpy(data, tensor->data.raw, tensor->bytes);
PyObject* np_array =
PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
-
+ PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
+ NPY_ARRAY_OWNDATA);
return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
+PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
+ // Sanity check accessor
+ TfLiteTensor* tensor = nullptr;
+ int type_num = 0;
+ if (PyObject* pynone_or_nullptr =
+ CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num)) {
+ return pynone_or_nullptr;
+ }
+
+ std::vector<npy_intp> dims(tensor->dims->data,
+ tensor->dims->data + tensor->dims->size);
+ PyArrayObject* np_array =
+ reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
+ dims.size(), dims.data(), type_num, tensor->data.raw));
+ Py_INCREF(base_object); // SetBaseObject steals, so we need to add.
+ PyArray_SetBaseObject(np_array, base_object);
+ return PyArray_Return(np_array);
+}
+
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
- const char* model_path) {
+ const char* model_path, std::string* error_msg) {
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromFile(model_path, error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
}
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
- PyObject* data) {
+ PyObject* data, std::string* error_msg) {
char * buf = nullptr;
Py_ssize_t length;
+ std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
if (PY_TO_CPPSTRING(data, &buf, &length) == -1) {
return nullptr;
}
std::unique_ptr<tflite::FlatBufferModel> model =
- tflite::FlatBufferModel::BuildFromBuffer(buf, length);
- return model ? new InterpreterWrapper(std::move(model)) : nullptr;
+ tflite::FlatBufferModel::BuildFromBuffer(buf, length,
+ error_reporter.get());
+ if (!model) {
+ *error_msg = error_reporter->message();
+ return nullptr;
+ }
+ return new InterpreterWrapper(std::move(model), std::move(error_reporter));
+}
+
+PyObject* InterpreterWrapper::ResetVariableTensorsToZero() {
+ TFLITE_PY_ENSURE_VALID_INTERPRETER();
+ TFLITE_PY_CHECK(interpreter_->ResetVariableTensorsToZero());
+ Py_RETURN_NONE;
}
} // namespace interpreter_wrapper
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index 01320af7a9..febfd2dc56 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -19,7 +19,9 @@ limitations under the License.
#include <string>
#include <vector>
+// Place `<locale>` before <Python.h> to avoid build failures in macOS.
#include <Python.h>
+#include <locale>
// We forward declare TFLite classes here to avoid exposing them to SWIG.
namespace tflite {
@@ -34,31 +36,41 @@ class Interpreter;
namespace interpreter_wrapper {
+class PythonErrorReporter;
+
class InterpreterWrapper {
public:
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path);
+ static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path,
+ std::string* error_msg);
// SWIG caller takes ownership of pointer.
- static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data);
+ static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data,
+ std::string* error_msg);
~InterpreterWrapper();
- bool AllocateTensors();
- bool Invoke();
+ PyObject* AllocateTensors();
+ PyObject* Invoke();
PyObject* InputIndices() const;
PyObject* OutputIndices() const;
- bool ResizeInputTensor(int i, PyObject* value);
+ PyObject* ResizeInputTensor(int i, PyObject* value);
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
PyObject* TensorQuantization(int i) const;
- bool SetTensor(int i, PyObject* value);
+ PyObject* SetTensor(int i, PyObject* value);
PyObject* GetTensor(int i) const;
+ PyObject* ResetVariableTensorsToZero();
+
+ // Returns a reference to tensor index i as a numpy array. The base_object
+ // should be the interpreter object providing the memory.
+ PyObject* tensor(PyObject* base_object, int i);
private:
- InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model);
+ InterpreterWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
+ std::unique_ptr<PythonErrorReporter> error_reporter);
// InterpreterWrapper is not copyable or assignable. We avoid the use of
// InterpreterWrapper() = delete here for SWIG compatibility.
@@ -66,6 +78,7 @@ class InterpreterWrapper {
InterpreterWrapper(const InterpreterWrapper& rhs);
const std::unique_ptr<tflite::FlatBufferModel> model_;
+ const std::unique_ptr<PythonErrorReporter> error_reporter_;
const std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver_;
const std::unique_ptr<tflite::Interpreter> interpreter_;
};
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
index 7f51f9f00d..afb2092eac 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.i
@@ -18,8 +18,51 @@ limitations under the License.
%{
#define SWIG_FILE_WITH_INIT
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
%}
%include "tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h"
+
+namespace tflite {
+namespace interpreter_wrapper {
+%extend InterpreterWrapper {
+
+ // Version of the constructor that handles producing Python exceptions
+ // that propagate strings.
+ static PyObject* CreateWrapperCPPFromFile(const char* model_path) {
+ std::string error;
+ if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromFile(
+ model_path, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+
+ // Version of the constructor that handles producing Python exceptions
+ // that propagate strings.
+ static PyObject* CreateWrapperCPPFromBuffer(
+ PyObject* data) {
+ std::string error;
+ if(tflite::interpreter_wrapper::InterpreterWrapper* ptr =
+ tflite::interpreter_wrapper::InterpreterWrapper
+ ::CreateWrapperCPPFromBuffer(
+ data, &error)) {
+ return SWIG_NewPointerObj(
+ ptr, SWIGTYPE_p_tflite__interpreter_wrapper__InterpreterWrapper, 1);
+ } else {
+ PyErr_SetString(PyExc_ValueError, error.c_str());
+ return nullptr;
+ }
+ }
+}
+
+} // namespace interpreter_wrapper
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index d595415b63..29a1487c1f 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -22,6 +22,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@Interpreter
@@OpHint
@@convert_op_hints_to_stubs
+@@build_toco_convert_protos
@@FLOAT
@@QUANTIZED_UINT8
@@ -38,6 +39,7 @@ from six import PY3
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
+from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
@@ -48,12 +50,14 @@ from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: di
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
+# from tensorflow.python.util.all_util import remove_undocumented
class TocoConverter(object):
@@ -64,11 +68,11 @@ class TocoConverter(object):
Attributes:
- inference_type: Target data type of arrays in the output file. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
- inference_input_type: Target data type of input arrays. Allows for a
- different type for input arrays in the case of quantization. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ inference_type: Target data type of real-number arrays in the output file.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of real-number input arrays. Allows
+ for a different type for input arrays in the case of quantization.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
@@ -94,6 +98,16 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
+ quantize_weights: Boolean indicating whether to store weights as quantized
+ weights followed by dequantize operations. Computation is still done in
+ float, but reduces model size (at the cost of accuracy and latency).
+ (default False)
+ dump_graphviz_dir: Full filepath of folder to dump the graphs at various
+ stages of processing GraphViz .dot files. Preferred over
+ --output_format=GRAPHVIZ_DOT in order to keep the requirements of the
+ output file. (default None)
+ dump_graphviz_video: Boolean indicating whether to dump the graph after
+ every graph transformation. (default False)
Example usage:
@@ -118,7 +132,7 @@ class TocoConverter(object):
Args:
- graph_def: TensorFlow GraphDef.
+ graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
@@ -135,6 +149,9 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
+ self.quantize_weights = False
+ self.dump_graphviz_dir = None
+ self.dump_graphviz_video = False
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
@@ -161,7 +178,7 @@ class TocoConverter(object):
"""Creates a TocoConverter class from a file containing a frozen GraphDef.
Args:
- graph_def_file: Full filepath of file containing TensorFlow GraphDef.
+ graph_def_file: Full filepath of file containing frozen GraphDef.
input_arrays: List of input tensors to freeze graph with.
output_arrays: List of output tensors to freeze graph with.
input_shapes: Dict of strings representing input tensor names to list of
@@ -210,7 +227,7 @@ class TocoConverter(object):
# Check if graph is frozen.
if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py")
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
# Create TocoConverter class.
return cls(sess.graph_def, input_tensors, output_tensors)
@@ -253,6 +270,48 @@ class TocoConverter(object):
return cls(
graph_def=result[0], input_tensors=result[1], output_tensors=result[2])
+ @classmethod
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file.
+
+ Args:
+ model_file: Full filepath of HDF5 file containing the tf.keras model.
+ input_arrays: List of input tensors to freeze graph with. Uses input
+ arrays from SignatureDef when none are provided. (default None)
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" :
+ None}). (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+
+ Returns:
+ TocoConverter class.
+ """
+ _keras.backend.clear_session()
+ _keras.backend.set_learning_phase(False)
+ keras_model = _keras.models.load_model(model_file)
+ sess = _keras.backend.get_session()
+
+ # Get input and output tensors.
+ if input_arrays:
+ input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays)
+ else:
+ input_tensors = keras_model.inputs
+
+ if output_arrays:
+ output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays)
+ else:
+ output_tensors = keras_model.outputs
+ set_tensor_shapes(input_tensors, input_shapes)
+
+ graph_def = _freeze_graph(sess, output_tensors)
+ return cls(graph_def, input_tensors, output_tensors)
+
def convert(self):
"""Converts a TensorFlow GraphDef based on instance variables.
@@ -262,15 +321,19 @@ class TocoConverter(object):
Raises:
ValueError:
+ Input shape is not specified.
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
for tensor in self._input_tensors:
+ if not tensor.get_shape():
+ raise ValueError("Provide an input shape for input array '{0}'.".format(
+ tensor_name(tensor)))
shape = tensor.get_shape().as_list()
if None in shape[1:]:
raise ValueError(
"None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(tensor.name, shape))
+ "invalid shape '{1}'.".format(tensor_name(tensor), shape))
elif shape[0] is None:
self._set_batch_size(batch_size=1)
@@ -306,9 +369,20 @@ class TocoConverter(object):
drop_control_dependency=self.drop_control_dependency,
reorder_across_fake_quant=self.reorder_across_fake_quant,
change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops)
+ allow_custom_ops=self.allow_custom_ops,
+ quantize_weights=self.quantize_weights,
+ dump_graphviz_dir=self.dump_graphviz_dir,
+ dump_graphviz_video=self.dump_graphviz_video)
return result
+ def get_input_arrays(self):
+ """Returns a list of the names of the input tensors.
+
+ Returns:
+ List of strings.
+ """
+ return [tensor_name(tensor) for tensor in self._input_tensors]
+
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -335,7 +409,7 @@ def _is_frozen_graph(sess):
Bool.
"""
for op in sess.graph.get_operations():
- if op.type.startswith("Variable"):
+ if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
return False
return True
@@ -360,3 +434,5 @@ def _freeze_graph(sess, output_tensors):
output_arrays)
else:
return sess.graph_def
+
+# remove_undocumented(__name__)
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 53d1878293..ca2af5aaed 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -19,15 +19,19 @@ from __future__ import division
from __future__ import print_function
import os
+import tempfile
import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python import keras
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -131,21 +135,31 @@ class FromSessionTest(test_util.TensorFlowTestCase):
'Quantization input stats are not available for input tensors '
'\'inputB\'.', str(error.exception))
- def testBatchSizeInvalid(self):
- in_tensor = array_ops.placeholder(
- shape=[None, 16, 16, 3], dtype=dtypes.float32)
+ def testSizeNoneInvalid(self):
+ in_tensor = array_ops.placeholder(dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Test invalid shape. None after 1st dimension.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ with self.assertRaises(ValueError) as error:
+ converter.convert()
+ self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
+ str(error.exception))
+
+ def testBatchSizeInvalid(self):
in_tensor = array_ops.placeholder(
shape=[1, None, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Test invalid shape. None after 1st dimension.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual(
'None is only supported in the 1st dimension. Tensor '
- '\'Placeholder_1:0\' has invalid shape \'[1, None, 16, 3]\'.',
+ '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
str(error.exception))
def testBatchSizeValid(self):
@@ -208,6 +222,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
def testGraphviz(self):
in_tensor = array_ops.placeholder(
shape=[1, 16, 16, 3], dtype=dtypes.float32)
@@ -220,8 +235,42 @@ class FromSessionTest(test_util.TensorFlowTestCase):
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
+ # TODO(nupurgarg): Verify value of contents in GraphViz.
+ def testDumpGraphviz(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure interpreter is able to allocate and check graphviz data.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ num_items_graphviz = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz)
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ graphviz_dir = self.get_temp_dir()
+ converter.dump_graphviz_dir = graphviz_dir
+ converter.dump_graphviz_video = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure graphviz folder has more data after using video flag.
+ num_items_graphviz_video = len(os.listdir(graphviz_dir))
+ self.assertTrue(num_items_graphviz_video > num_items_graphviz)
+
def testInferenceInputType(self):
- in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8)
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
@@ -240,14 +289,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertEqual('Placeholder', input_details[0]['name'])
self.assertEqual(np.uint8, input_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
- self.assertEqual((0., 0.), input_details[0]['quantization'])
+ self.assertEqual((1., 0.), input_details[0]['quantization'])
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('add', output_details[0]['name'])
- self.assertEqual(np.uint8, output_details[0]['dtype'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
- self.assertEqual((0., 0.), input_details[0]['quantization'])
def testDefaultRangesStats(self):
in_tensor = array_ops.placeholder(
@@ -281,8 +329,38 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
+ def testQuantizeWeights(self):
+ np.random.seed(0)
+ # We need the tensor to have more than 1024 elements for quantize_weights
+ # to kick in. Thus, the [33, 33] shape.
+ in_tensor_1 = array_ops.placeholder(
+ shape=[33, 33], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = constant_op.constant(
+ np.random.uniform(low=-10., high=10., size=(33, 33)),
+ shape=[33, 33],
+ dtype=dtypes.float32,
+ name='inputB')
+ out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+ sess = session.Session()
+
+ # Convert float model.
+ float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
+ float_tflite = float_converter.convert()
+ self.assertTrue(float_tflite)
+
+ # Convert quantized weights model.
+ quantized_weights_converter = lite.TocoConverter.from_session(
+ sess, [in_tensor_1], [out_tensor])
+ quantized_weights_converter.quantize_weights = True
+ quantized_weights_tflite = quantized_weights_converter.convert()
+ self.assertTrue(quantized_weights_tflite)
+
+ # Ensure that the quantized weights tflite model is smaller.
+ self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
-class FromFlatbufferFile(test_util.TensorFlowTestCase):
+
+class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def testFloat(self):
in_tensor = array_ops.placeholder(
@@ -359,7 +437,7 @@ class FromFlatbufferFile(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
- self.assertEqual('Please freeze the graph using freeze_graph.py',
+ self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
def testPbtxt(self):
@@ -542,5 +620,279 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
+class FromKerasFile(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ keras.backend.clear_session()
+
+ def _getSequentialModel(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testSequentialModel(self):
+ """Test a Sequential tf.keras model with default inputs."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testSequentialModelInputArray(self):
+ """Test a Sequential tf.keras model testing input arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid input array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['invalid-input'])
+ self.assertEqual("Invalid tensors 'invalid-input' were found.",
+ str(error.exception))
+
+ # Valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_arrays=['dense_input'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testSequentialModelInputShape(self):
+ """Test a Sequential tf.keras model testing input shapes argument."""
+ keras_file = self._getSequentialModel()
+
+ # Passing in shape of invalid input array has no impact as long as all input
+ # arrays have a shape.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'invalid-input': [2, 3]})
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Passing in shape of valid input array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, input_shapes={'dense_input': [2, 3]})
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ # Check input shape from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertTrue(([2, 3] == input_details[0]['shape']).all())
+
+ def testSequentialModelOutputArray(self):
+ """Test a Sequential tf.keras model testing output arrays argument."""
+ keras_file = self._getSequentialModel()
+
+ # Invalid output array raises error.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['invalid-output'])
+ self.assertEqual("Invalid tensors 'invalid-output' were found.",
+ str(error.exception))
+
+ # Valid output array.
+ converter = lite.TocoConverter.from_keras_model_file(
+ keras_file, output_arrays=['time_distributed/Reshape_1'])
+ tflite_model = converter.convert()
+ os.remove(keras_file)
+ self.assertTrue(tflite_model)
+
+ def testFunctionalModel(self):
+ """Test a Functional tf.keras model with default inputs."""
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ def testFunctionalModelMultipleInputs(self):
+ """Test a Functional tf.keras model with multiple inputs and outputs."""
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('input_a', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ self.assertEqual('input_b', input_details[1]['name'])
+ self.assertEqual(np.float32, input_details[1]['dtype'])
+ self.assertTrue(([1, 3] == input_details[1]['shape']).all())
+ self.assertEqual((0., 0.), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(2, len(output_details))
+ self.assertEqual('dense_1/BiasAdd', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('dropout/Identity', output_details[1]['name'])
+ self.assertEqual(np.float32, output_details[1]['dtype'])
+ self.assertTrue(([1, 4] == output_details[1]['shape']).all())
+ self.assertEqual((0., 0.), output_details[1]['quantization'])
+
+ def testFunctionalSequentialModel(self):
+ """Test a Functional tf.keras model containing a Sequential model."""
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+
+ # Convert to TFLite model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ os.close(fd)
+ os.remove(keras_file)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('dense_input', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('time_distributed/Reshape_1', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 337f05785e..9bd1f4f76e 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -23,19 +23,15 @@ import os
import sys
from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.platform import app
-def _parse_array(values):
+def _parse_array(values, type_fn=str):
if values:
- return values.split(",")
-
-
-def _parse_int_array(values):
- if values:
- return [int(val) for val in values.split(",")]
+ return [type_fn(val) for val in values.split(",") if val]
def _parse_set(values):
@@ -57,7 +53,8 @@ def _get_toco_converter(flags):
input_shapes = None
if flags.input_shapes:
input_shapes_list = [
- _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+ _parse_array(shape, type_fn=int)
+ for shape in flags.input_shapes.split(":")
]
input_shapes = dict(zip(input_arrays, input_shapes_list))
output_arrays = _parse_array(flags.output_arrays)
@@ -77,6 +74,9 @@ def _get_toco_converter(flags):
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
+ elif flags.keras_model_file:
+ converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_kwargs["model_file"] = flags.keras_model_file
return converter_fn(**converter_kwargs)
@@ -86,6 +86,9 @@ def _convert_model(flags):
Args:
flags: argparse.Namespace object.
+
+ Raises:
+ ValueError: Invalid flags.
"""
# Create converter.
converter = _get_toco_converter(flags)
@@ -99,12 +102,22 @@ def _convert_model(flags):
flags.output_format)
if flags.mean_values and flags.std_dev_values:
- input_arrays = _parse_array(flags.input_arrays)
- std_dev_values = _parse_int_array(flags.std_dev_values)
- mean_values = _parse_int_array(flags.mean_values)
- quant_stats = zip(mean_values, std_dev_values)
+ input_arrays = converter.get_input_arrays()
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
+ quant_stats = list(zip(mean_values, std_dev_values))
+ if ((not flags.input_arrays and len(input_arrays) > 1) or
+ (len(input_arrays) != len(quant_stats))):
+ raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
+ "--mean_values. The flags must have the same number of "
+ "items. The current input arrays are '{0}'. "
+ "--input_arrays must be present when specifying "
+ "--std_dev_values and --mean_values with multiple input "
+ "tensors in order to map between names and "
+ "values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
- if flags.default_ranges_min and flags.default_ranges_max:
+ if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
+ not None):
converter.default_ranges_stats = (flags.default_ranges_min,
flags.default_ranges_max)
@@ -116,6 +129,15 @@ def _convert_model(flags):
converter.change_concat_input_ranges = flags.change_concat_input_ranges
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
+ if flags.quantize_weights:
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ raise ValueError("--quantized_weights is not supported with "
+ "--inference_type=QUANTIZED_UINT8")
+ converter.quantize_weights = flags.quantize_weights
+ if flags.dump_graphviz_dir:
+ converter.dump_graphviz_dir = flags.dump_graphviz_dir
+ if flags.dump_graphviz_video:
+ converter.dump_graphviz_vode = flags.dump_graphviz_video
# Convert model.
output_data = converter.convert()
@@ -147,9 +169,14 @@ def _check_flags(flags, unparsed):
output = ""
for flag in unparsed:
output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
+ output += _get_message_unparsed(flag, "--savedmodel_directory",
+ "--saved_model_dir")
output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
- raise ValueError(output)
+ output += _get_message_unparsed(flag, "--dump_graphviz",
+ "--dump_graphviz_dir")
+ if output:
+ raise ValueError(output)
# Check that flags are valid.
if flags.graph_def_file and (not flags.input_arrays or
@@ -168,18 +195,17 @@ def _check_flags(flags, unparsed):
if bool(flags.std_dev_values) != bool(flags.mean_values):
raise ValueError("--std_dev_values and --mean_values must be used "
"together")
- if not flags.input_arrays:
- raise ValueError("--std_dev_values and --mean_values must be used with "
- "--input_arrays")
- if (flags.std_dev_values.count(",") != flags.mean_values.count(",") or
- flags.std_dev_values.count(",") != flags.input_arrays.count(",")):
- raise ValueError("--std_dev_values, --mean_values, and --input_arrays "
- "must have the same number of items")
-
- if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
+ if flags.std_dev_values.count(",") != flags.mean_values.count(","):
+ raise ValueError("--std_dev_values, --mean_values must have the same "
+ "number of items")
+
+ if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
raise ValueError("--default_ranges_min and --default_ranges_max must be "
"used together")
+ if flags.dump_graphviz_video and not flags.dump_graphviz:
+ raise ValueError("--dump_graphviz_video must be used with --dump_graphviz")
+
def run_main(_):
"""Main in toco_convert.py."""
@@ -199,29 +225,33 @@ def run_main(_):
input_file_group.add_argument(
"--graph_def_file",
type=str,
- help="Full filepath of file containing TensorFlow GraphDef.")
+ help="Full filepath of file containing frozen TensorFlow GraphDef.")
input_file_group.add_argument(
"--saved_model_dir",
type=str,
help="Full filepath of directory containing the SavedModel.")
+ input_file_group.add_argument(
+ "--keras_model_file",
+ type=str,
+ help="Full filepath of HDF5 file containing tf.Keras model.")
# Model format flags.
parser.add_argument(
"--output_format",
- type=str,
+ type=str.upper,
choices=["TFLITE", "GRAPHVIZ_DOT"],
help="Output file format.")
parser.add_argument(
"--inference_type",
- type=str,
+ type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
- help="Target data type of arrays in the output file.")
+ help="Target data type of real-number arrays in the output file.")
parser.add_argument(
"--inference_input_type",
- type=str,
+ type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
- help=("Target data type of input arrays. Allows for a different type for "
- "input arrays in the case of quantization."))
+ help=("Target data type of real-number input arrays. Allows for a "
+ "different type for input arrays in the case of quantization."))
# Input and output arrays flags.
parser.add_argument(
@@ -255,12 +285,12 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated. Used for quantization. (default None)"))
+ "comma-separated integers. Used for quantization. (default None)"))
parser.add_argument(
"--mean_values",
type=str,
- help=("Mean of training data for each input tensor, comma-separated. "
- "Used for quantization. (default None)"))
+ help=("Mean of training data for each input tensor, comma-separated "
+ "integers. Used for quantization. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,
@@ -273,17 +303,23 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ parser.add_argument(
+ "--quantize_weights",
+ type=bool,
+ help=("Store float weights as quantized weights followed by dequantize "
+ "operations. Inference is still done in FLOAT, but reduces model "
+ "size (at the cost of accuracy and latency)."))
# Graph manipulation flags.
parser.add_argument(
"--drop_control_dependency",
- type=bool,
+ action="store_true",
help=("Boolean indicating whether to drop control dependencies silently. "
"This is due to TensorFlow not supporting control dependencies. "
"(default True)"))
parser.add_argument(
"--reorder_across_fake_quant",
- type=bool,
+ action="store_true",
help=("Boolean indicating whether to reorder FakeQuant nodes in "
"unexpected locations. Used when the location of the FakeQuant "
"nodes is preventing graph transformations necessary to convert "
@@ -292,19 +328,33 @@ def run_main(_):
"behavior. (default False)"))
parser.add_argument(
"--change_concat_input_ranges",
- type=bool,
+ action="store_true",
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
parser.add_argument(
"--allow_custom_ops",
- type=bool,
+ action="store_true",
help=("Boolean indicating whether to allow custom operations. When false "
"any unknown operation is an error. When true, custom ops are "
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
+ # Logging flags.
+ parser.add_argument(
+ "--dump_graphviz_dir",
+ type=str,
+ help=("Full filepath of folder to dump the graphs at various stages of "
+ "processing GraphViz .dot files. Preferred over --output_format="
+ "GRAPHVIZ_DOT in order to keep the requirements of the output "
+ "file."))
+ parser.add_argument(
+ "--dump_graphviz_video",
+ action="store_true",
+ help=("Boolean indicating whether to dump the graph after every graph "
+ "transformation"))
+
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
try:
_check_flags(tflite_flags, unparsed)
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 9717a4a1a4..f095151cae 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -65,6 +65,7 @@ cc_test(
],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
"//tensorflow/core:lib_platform",
diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
index 64ab0a9fe2..9dc8daa227 100644
--- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
+++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc
@@ -39,7 +39,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
// DO NOT EDIT MANUALLY: This file is automatically generated by
-// `schema_builtin_ops_header_generator.py`.
+// `schema/builtin_ops_header/generator.cc`.
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 522eac25b3..17ea26052d 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -34,6 +34,8 @@ enum TensorType : byte {
INT64 = 4,
STRING = 5,
BOOL = 6,
+ INT16 = 7,
+ COMPLEX64 = 8,
}
// Parameters for converting a quantized tensor back to float. Given a
@@ -42,7 +44,7 @@ enum TensorType : byte {
table QuantizationParameters {
min:[float]; // For importing back into tensorflow.
max:[float]; // For importing back into tensorflow.
- scale:[float];
+ scale:[float]; // For dequantizing the tensor's values.
zero_point:[long];
}
@@ -63,6 +65,8 @@ table Tensor {
buffer:uint;
name:string; // For debugging and importing back into tensorflow.
quantization:QuantizationParameters; // Optional.
+
+ is_variable:bool = false;
}
// A list of builtin operators. Builtin operators are slightly faster than custom
@@ -146,6 +150,18 @@ enum BuiltinOperator : byte {
SIN = 66,
TRANSPOSE_CONV = 67,
SPARSE_TO_DENSE = 68,
+ TILE = 69,
+ EXPAND_DIMS = 70,
+ EQUAL = 71,
+ NOT_EQUAL = 72,
+ LOG = 73,
+ SUM=74,
+ SQRT = 75,
+ RSQRT = 76,
+ SHAPE = 77,
+ POW = 78,
+ ARG_MIN = 79,
+ FAKE_QUANT = 80,
}
// Options for the builtin operators.
@@ -176,7 +192,7 @@ union BuiltinOptions {
BatchToSpaceNDOptions,
SpaceToBatchNDOptions,
TransposeOptions,
- MeanOptions,
+ ReducerOptions,
SubOptions,
DivOptions,
SqueezeOptions,
@@ -200,6 +216,14 @@ union BuiltinOptions {
SliceOptions,
TransposeConvOptions,
SparseToDenseOptions,
+ TileOptions,
+ ExpandDimsOptions,
+ EqualOptions,
+ NotEqualOptions,
+ ShapeOptions,
+ PowOptions,
+ ArgMinOptions,
+ FakeQuantOptions,
}
enum Padding : byte { SAME, VALID }
@@ -277,9 +301,18 @@ table BidirectionalSequenceRNNOptions {
fused_activation_function:ActivationFunctionType;
}
+enum FullyConnectedOptionsWeightsFormat: byte {
+ DEFAULT = 0,
+ SHUFFLED4x16INT8 = 1,
+}
+
// An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
table FullyConnectedOptions {
+ // Parameters for FullyConnected version 1 or above.
fused_activation_function:ActivationFunctionType;
+
+ // Parameters for FullyConnected version 2 or above.
+ weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
}
table SoftmaxOptions {
@@ -311,11 +344,23 @@ table LocalResponseNormalizationOptions {
beta:float;
}
+enum LSTMKernelType : byte {
+ // Full LSTM kernel which supports peephole and projection.
+ FULL = 0,
+ // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+ BASIC = 1,
+}
+
// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
table LSTMOptions {
+ // Parameters for LSTM version 1 or above.
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
proj_clip: float; // Optional, 0.0 means no clipping
+
+ // Parameters for LSTM version 2 or above.
+ // Basic kernel is only supported in version 2 or above.
+ kernel_type: LSTMKernelType = FULL;
}
table ResizeBilinearOptions {
@@ -387,7 +432,7 @@ table TransposeOptions {
table ExpOptions {
}
-table MeanOptions {
+table ReducerOptions {
keep_dims: bool;
}
@@ -421,10 +466,17 @@ table DequantizeOptions {
table MaximumMinimumOptions {
}
+table TileOptions {
+}
+
table ArgMaxOptions {
output_type : TensorType;
}
+table ArgMinOptions {
+ output_type : TensorType;
+}
+
table GreaterOptions {
}
@@ -452,10 +504,33 @@ table TransposeConvOptions {
stride_h:int;
}
+table ExpandDimsOptions {
+}
+
table SparseToDenseOptions {
validate_indices:bool;
}
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
+table ShapeOptions {
+ // Optional output type of the operation (int32 or int64). Defaults to int32.
+ out_type : TensorType;
+}
+
+table PowOptions {
+}
+
+table FakeQuantOptions {
+ min:float;
+ max:float;
+ num_bits:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
@@ -487,6 +562,16 @@ table Operator {
builtin_options:BuiltinOptions;
custom_options:[ubyte];
custom_options_format:CustomOptionsFormat;
+
+ // A list of booleans indicating the input tensors which are being mutated by
+ // this operator.(e.g. used by RNN and LSTM).
+ // For example, if the "inputs" array refers to 5 tensors and the second and
+ // fifth are mutable variables, then this list will contain
+ // [false, true, false, false, true].
+ //
+ // If the list is empty, no variable is mutated in this operator.
+ // The list either has the same length as `inputs`, or is empty.
+ mutating_variable_inputs:[bool];
}
// The root type, defining a subgraph, which typically represents an entire
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 746dd26796..37489ebc68 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -127,8 +127,8 @@ struct TransposeOptionsT;
struct ExpOptions;
struct ExpOptionsT;
-struct MeanOptions;
-struct MeanOptionsT;
+struct ReducerOptions;
+struct ReducerOptionsT;
struct SqueezeOptions;
struct SqueezeOptionsT;
@@ -151,9 +151,15 @@ struct DequantizeOptionsT;
struct MaximumMinimumOptions;
struct MaximumMinimumOptionsT;
+struct TileOptions;
+struct TileOptionsT;
+
struct ArgMaxOptions;
struct ArgMaxOptionsT;
+struct ArgMinOptions;
+struct ArgMinOptionsT;
+
struct GreaterOptions;
struct GreaterOptionsT;
@@ -178,9 +184,27 @@ struct SliceOptionsT;
struct TransposeConvOptions;
struct TransposeConvOptionsT;
+struct ExpandDimsOptions;
+struct ExpandDimsOptionsT;
+
struct SparseToDenseOptions;
struct SparseToDenseOptionsT;
+struct EqualOptions;
+struct EqualOptionsT;
+
+struct NotEqualOptions;
+struct NotEqualOptionsT;
+
+struct ShapeOptions;
+struct ShapeOptionsT;
+
+struct PowOptions;
+struct PowOptionsT;
+
+struct FakeQuantOptions;
+struct FakeQuantOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -204,11 +228,13 @@ enum TensorType {
TensorType_INT64 = 4,
TensorType_STRING = 5,
TensorType_BOOL = 6,
+ TensorType_INT16 = 7,
+ TensorType_COMPLEX64 = 8,
TensorType_MIN = TensorType_FLOAT32,
- TensorType_MAX = TensorType_BOOL
+ TensorType_MAX = TensorType_COMPLEX64
};
-inline TensorType (&EnumValuesTensorType())[7] {
+inline TensorType (&EnumValuesTensorType())[9] {
static TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
@@ -216,7 +242,9 @@ inline TensorType (&EnumValuesTensorType())[7] {
TensorType_UINT8,
TensorType_INT64,
TensorType_STRING,
- TensorType_BOOL
+ TensorType_BOOL,
+ TensorType_INT16,
+ TensorType_COMPLEX64
};
return values;
}
@@ -230,6 +258,8 @@ inline const char **EnumNamesTensorType() {
"INT64",
"STRING",
"BOOL",
+ "INT16",
+ "COMPLEX64",
nullptr
};
return names;
@@ -309,11 +339,23 @@ enum BuiltinOperator {
BuiltinOperator_SIN = 66,
BuiltinOperator_TRANSPOSE_CONV = 67,
BuiltinOperator_SPARSE_TO_DENSE = 68,
+ BuiltinOperator_TILE = 69,
+ BuiltinOperator_EXPAND_DIMS = 70,
+ BuiltinOperator_EQUAL = 71,
+ BuiltinOperator_NOT_EQUAL = 72,
+ BuiltinOperator_LOG = 73,
+ BuiltinOperator_SUM = 74,
+ BuiltinOperator_SQRT = 75,
+ BuiltinOperator_RSQRT = 76,
+ BuiltinOperator_SHAPE = 77,
+ BuiltinOperator_POW = 78,
+ BuiltinOperator_ARG_MIN = 79,
+ BuiltinOperator_FAKE_QUANT = 80,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SPARSE_TO_DENSE
+ BuiltinOperator_MAX = BuiltinOperator_FAKE_QUANT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[80] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -382,7 +424,19 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] {
BuiltinOperator_SLICE,
BuiltinOperator_SIN,
BuiltinOperator_TRANSPOSE_CONV,
- BuiltinOperator_SPARSE_TO_DENSE
+ BuiltinOperator_SPARSE_TO_DENSE,
+ BuiltinOperator_TILE,
+ BuiltinOperator_EXPAND_DIMS,
+ BuiltinOperator_EQUAL,
+ BuiltinOperator_NOT_EQUAL,
+ BuiltinOperator_LOG,
+ BuiltinOperator_SUM,
+ BuiltinOperator_SQRT,
+ BuiltinOperator_RSQRT,
+ BuiltinOperator_SHAPE,
+ BuiltinOperator_POW,
+ BuiltinOperator_ARG_MIN,
+ BuiltinOperator_FAKE_QUANT
};
return values;
}
@@ -458,6 +512,18 @@ inline const char **EnumNamesBuiltinOperator() {
"SIN",
"TRANSPOSE_CONV",
"SPARSE_TO_DENSE",
+ "TILE",
+ "EXPAND_DIMS",
+ "EQUAL",
+ "NOT_EQUAL",
+ "LOG",
+ "SUM",
+ "SQRT",
+ "RSQRT",
+ "SHAPE",
+ "POW",
+ "ARG_MIN",
+ "FAKE_QUANT",
nullptr
};
return names;
@@ -496,7 +562,7 @@ enum BuiltinOptions {
BuiltinOptions_BatchToSpaceNDOptions = 24,
BuiltinOptions_SpaceToBatchNDOptions = 25,
BuiltinOptions_TransposeOptions = 26,
- BuiltinOptions_MeanOptions = 27,
+ BuiltinOptions_ReducerOptions = 27,
BuiltinOptions_SubOptions = 28,
BuiltinOptions_DivOptions = 29,
BuiltinOptions_SqueezeOptions = 30,
@@ -520,11 +586,19 @@ enum BuiltinOptions {
BuiltinOptions_SliceOptions = 48,
BuiltinOptions_TransposeConvOptions = 49,
BuiltinOptions_SparseToDenseOptions = 50,
+ BuiltinOptions_TileOptions = 51,
+ BuiltinOptions_ExpandDimsOptions = 52,
+ BuiltinOptions_EqualOptions = 53,
+ BuiltinOptions_NotEqualOptions = 54,
+ BuiltinOptions_ShapeOptions = 55,
+ BuiltinOptions_PowOptions = 56,
+ BuiltinOptions_ArgMinOptions = 57,
+ BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SparseToDenseOptions
+ BuiltinOptions_MAX = BuiltinOptions_FakeQuantOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[59] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -553,7 +627,7 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] {
BuiltinOptions_BatchToSpaceNDOptions,
BuiltinOptions_SpaceToBatchNDOptions,
BuiltinOptions_TransposeOptions,
- BuiltinOptions_MeanOptions,
+ BuiltinOptions_ReducerOptions,
BuiltinOptions_SubOptions,
BuiltinOptions_DivOptions,
BuiltinOptions_SqueezeOptions,
@@ -576,7 +650,15 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] {
BuiltinOptions_SelectOptions,
BuiltinOptions_SliceOptions,
BuiltinOptions_TransposeConvOptions,
- BuiltinOptions_SparseToDenseOptions
+ BuiltinOptions_SparseToDenseOptions,
+ BuiltinOptions_TileOptions,
+ BuiltinOptions_ExpandDimsOptions,
+ BuiltinOptions_EqualOptions,
+ BuiltinOptions_NotEqualOptions,
+ BuiltinOptions_ShapeOptions,
+ BuiltinOptions_PowOptions,
+ BuiltinOptions_ArgMinOptions,
+ BuiltinOptions_FakeQuantOptions
};
return values;
}
@@ -610,7 +692,7 @@ inline const char **EnumNamesBuiltinOptions() {
"BatchToSpaceNDOptions",
"SpaceToBatchNDOptions",
"TransposeOptions",
- "MeanOptions",
+ "ReducerOptions",
"SubOptions",
"DivOptions",
"SqueezeOptions",
@@ -634,6 +716,14 @@ inline const char **EnumNamesBuiltinOptions() {
"SliceOptions",
"TransposeConvOptions",
"SparseToDenseOptions",
+ "TileOptions",
+ "ExpandDimsOptions",
+ "EqualOptions",
+ "NotEqualOptions",
+ "ShapeOptions",
+ "PowOptions",
+ "ArgMinOptions",
+ "FakeQuantOptions",
nullptr
};
return names;
@@ -752,8 +842,8 @@ template<> struct BuiltinOptionsTraits<TransposeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions;
};
-template<> struct BuiltinOptionsTraits<MeanOptions> {
- static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions;
+template<> struct BuiltinOptionsTraits<ReducerOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ReducerOptions;
};
template<> struct BuiltinOptionsTraits<SubOptions> {
@@ -848,6 +938,38 @@ template<> struct BuiltinOptionsTraits<SparseToDenseOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions;
};
+template<> struct BuiltinOptionsTraits<TileOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_TileOptions;
+};
+
+template<> struct BuiltinOptionsTraits<ExpandDimsOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions;
+};
+
+template<> struct BuiltinOptionsTraits<EqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<NotEqualOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions;
+};
+
+template<> struct BuiltinOptionsTraits<ShapeOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions;
+};
+
+template<> struct BuiltinOptionsTraits<PowOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
+};
+
+template<> struct BuiltinOptionsTraits<ArgMinOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ArgMinOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FakeQuantOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FakeQuantOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1087,13 +1209,13 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_TransposeOptions ?
reinterpret_cast<const TransposeOptionsT *>(value) : nullptr;
}
- MeanOptionsT *AsMeanOptions() {
- return type == BuiltinOptions_MeanOptions ?
- reinterpret_cast<MeanOptionsT *>(value) : nullptr;
+ ReducerOptionsT *AsReducerOptions() {
+ return type == BuiltinOptions_ReducerOptions ?
+ reinterpret_cast<ReducerOptionsT *>(value) : nullptr;
}
- const MeanOptionsT *AsMeanOptions() const {
- return type == BuiltinOptions_MeanOptions ?
- reinterpret_cast<const MeanOptionsT *>(value) : nullptr;
+ const ReducerOptionsT *AsReducerOptions() const {
+ return type == BuiltinOptions_ReducerOptions ?
+ reinterpret_cast<const ReducerOptionsT *>(value) : nullptr;
}
SubOptionsT *AsSubOptions() {
return type == BuiltinOptions_SubOptions ?
@@ -1279,6 +1401,70 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_SparseToDenseOptions ?
reinterpret_cast<const SparseToDenseOptionsT *>(value) : nullptr;
}
+ TileOptionsT *AsTileOptions() {
+ return type == BuiltinOptions_TileOptions ?
+ reinterpret_cast<TileOptionsT *>(value) : nullptr;
+ }
+ const TileOptionsT *AsTileOptions() const {
+ return type == BuiltinOptions_TileOptions ?
+ reinterpret_cast<const TileOptionsT *>(value) : nullptr;
+ }
+ ExpandDimsOptionsT *AsExpandDimsOptions() {
+ return type == BuiltinOptions_ExpandDimsOptions ?
+ reinterpret_cast<ExpandDimsOptionsT *>(value) : nullptr;
+ }
+ const ExpandDimsOptionsT *AsExpandDimsOptions() const {
+ return type == BuiltinOptions_ExpandDimsOptions ?
+ reinterpret_cast<const ExpandDimsOptionsT *>(value) : nullptr;
+ }
+ EqualOptionsT *AsEqualOptions() {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<EqualOptionsT *>(value) : nullptr;
+ }
+ const EqualOptionsT *AsEqualOptions() const {
+ return type == BuiltinOptions_EqualOptions ?
+ reinterpret_cast<const EqualOptionsT *>(value) : nullptr;
+ }
+ NotEqualOptionsT *AsNotEqualOptions() {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<NotEqualOptionsT *>(value) : nullptr;
+ }
+ const NotEqualOptionsT *AsNotEqualOptions() const {
+ return type == BuiltinOptions_NotEqualOptions ?
+ reinterpret_cast<const NotEqualOptionsT *>(value) : nullptr;
+ }
+ ShapeOptionsT *AsShapeOptions() {
+ return type == BuiltinOptions_ShapeOptions ?
+ reinterpret_cast<ShapeOptionsT *>(value) : nullptr;
+ }
+ const ShapeOptionsT *AsShapeOptions() const {
+ return type == BuiltinOptions_ShapeOptions ?
+ reinterpret_cast<const ShapeOptionsT *>(value) : nullptr;
+ }
+ PowOptionsT *AsPowOptions() {
+ return type == BuiltinOptions_PowOptions ?
+ reinterpret_cast<PowOptionsT *>(value) : nullptr;
+ }
+ const PowOptionsT *AsPowOptions() const {
+ return type == BuiltinOptions_PowOptions ?
+ reinterpret_cast<const PowOptionsT *>(value) : nullptr;
+ }
+ ArgMinOptionsT *AsArgMinOptions() {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<ArgMinOptionsT *>(value) : nullptr;
+ }
+ const ArgMinOptionsT *AsArgMinOptions() const {
+ return type == BuiltinOptions_ArgMinOptions ?
+ reinterpret_cast<const ArgMinOptionsT *>(value) : nullptr;
+ }
+ FakeQuantOptionsT *AsFakeQuantOptions() {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<FakeQuantOptionsT *>(value) : nullptr;
+ }
+ const FakeQuantOptionsT *AsFakeQuantOptions() const {
+ return type == BuiltinOptions_FakeQuantOptions ?
+ reinterpret_cast<const FakeQuantOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -1386,6 +1572,64 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) {
return EnumNamesLSHProjectionType()[index];
}
+enum FullyConnectedOptionsWeightsFormat {
+ FullyConnectedOptionsWeightsFormat_DEFAULT = 0,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8 = 1,
+ FullyConnectedOptionsWeightsFormat_MIN = FullyConnectedOptionsWeightsFormat_DEFAULT,
+ FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
+};
+
+inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
+ static FullyConnectedOptionsWeightsFormat values[] = {
+ FullyConnectedOptionsWeightsFormat_DEFAULT,
+ FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
+ };
+ return values;
+}
+
+inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() {
+ static const char *names[] = {
+ "DEFAULT",
+ "SHUFFLED4x16INT8",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameFullyConnectedOptionsWeightsFormat(FullyConnectedOptionsWeightsFormat e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesFullyConnectedOptionsWeightsFormat()[index];
+}
+
+enum LSTMKernelType {
+ LSTMKernelType_FULL = 0,
+ LSTMKernelType_BASIC = 1,
+ LSTMKernelType_MIN = LSTMKernelType_FULL,
+ LSTMKernelType_MAX = LSTMKernelType_BASIC
+};
+
+inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static LSTMKernelType values[] = {
+ LSTMKernelType_FULL,
+ LSTMKernelType_BASIC
+ };
+ return values;
+}
+
+inline const char **EnumNamesLSTMKernelType() {
+ static const char *names[] = {
+ "FULL",
+ "BASIC",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameLSTMKernelType(LSTMKernelType e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesLSTMKernelType()[index];
+}
+
enum CombinerType {
CombinerType_SUM = 0,
CombinerType_MEAN = 1,
@@ -1555,9 +1799,11 @@ struct TensorT : public flatbuffers::NativeTable {
uint32_t buffer;
std::string name;
std::unique_ptr<QuantizationParametersT> quantization;
+ bool is_variable;
TensorT()
: type(TensorType_FLOAT32),
- buffer(0) {
+ buffer(0),
+ is_variable(false) {
}
};
@@ -1568,7 +1814,8 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_TYPE = 6,
VT_BUFFER = 8,
VT_NAME = 10,
- VT_QUANTIZATION = 12
+ VT_QUANTIZATION = 12,
+ VT_IS_VARIABLE = 14
};
const flatbuffers::Vector<int32_t> *shape() const {
return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_SHAPE);
@@ -1585,6 +1832,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const QuantizationParameters *quantization() const {
return GetPointer<const QuantizationParameters *>(VT_QUANTIZATION);
}
+ bool is_variable() const {
+ return GetField<uint8_t>(VT_IS_VARIABLE, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
@@ -1595,6 +1845,7 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.Verify(name()) &&
VerifyOffset(verifier, VT_QUANTIZATION) &&
verifier.VerifyTable(quantization()) &&
+ VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
verifier.EndTable();
}
TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1620,6 +1871,9 @@ struct TensorBuilder {
void add_quantization(flatbuffers::Offset<QuantizationParameters> quantization) {
fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization);
}
+ void add_is_variable(bool is_variable) {
+ fbb_.AddElement<uint8_t>(Tensor::VT_IS_VARIABLE, static_cast<uint8_t>(is_variable), 0);
+ }
explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -1638,12 +1892,14 @@ inline flatbuffers::Offset<Tensor> CreateTensor(
TensorType type = TensorType_FLOAT32,
uint32_t buffer = 0,
flatbuffers::Offset<flatbuffers::String> name = 0,
- flatbuffers::Offset<QuantizationParameters> quantization = 0) {
+ flatbuffers::Offset<QuantizationParameters> quantization = 0,
+ bool is_variable = false) {
TensorBuilder builder_(_fbb);
builder_.add_quantization(quantization);
builder_.add_name(name);
builder_.add_buffer(buffer);
builder_.add_shape(shape);
+ builder_.add_is_variable(is_variable);
builder_.add_type(type);
return builder_.Finish();
}
@@ -1654,14 +1910,16 @@ inline flatbuffers::Offset<Tensor> CreateTensorDirect(
TensorType type = TensorType_FLOAT32,
uint32_t buffer = 0,
const char *name = nullptr,
- flatbuffers::Offset<QuantizationParameters> quantization = 0) {
+ flatbuffers::Offset<QuantizationParameters> quantization = 0,
+ bool is_variable = false) {
return tflite::CreateTensor(
_fbb,
shape ? _fbb.CreateVector<int32_t>(*shape) : 0,
type,
buffer,
name ? _fbb.CreateString(name) : 0,
- quantization);
+ quantization,
+ is_variable);
}
flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -2395,22 +2653,29 @@ flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequence
struct FullyConnectedOptionsT : public flatbuffers::NativeTable {
typedef FullyConnectedOptions TableType;
ActivationFunctionType fused_activation_function;
+ FullyConnectedOptionsWeightsFormat weights_format;
FullyConnectedOptionsT()
- : fused_activation_function(ActivationFunctionType_NONE) {
+ : fused_activation_function(ActivationFunctionType_NONE),
+ weights_format(FullyConnectedOptionsWeightsFormat_DEFAULT) {
}
};
struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef FullyConnectedOptionsT NativeTableType;
enum {
- VT_FUSED_ACTIVATION_FUNCTION = 4
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_WEIGHTS_FORMAT = 6
};
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ FullyConnectedOptionsWeightsFormat weights_format() const {
+ return static_cast<FullyConnectedOptionsWeightsFormat>(GetField<int8_t>(VT_WEIGHTS_FORMAT, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<int8_t>(verifier, VT_WEIGHTS_FORMAT) &&
verifier.EndTable();
}
FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2424,6 +2689,9 @@ struct FullyConnectedOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_weights_format(FullyConnectedOptionsWeightsFormat weights_format) {
+ fbb_.AddElement<int8_t>(FullyConnectedOptions::VT_WEIGHTS_FORMAT, static_cast<int8_t>(weights_format), 0);
+ }
explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2438,8 +2706,10 @@ struct FullyConnectedOptionsBuilder {
inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(
flatbuffers::FlatBufferBuilder &_fbb,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT) {
FullyConnectedOptionsBuilder builder_(_fbb);
+ builder_.add_weights_format(weights_format);
builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish();
}
@@ -2823,10 +3093,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable {
ActivationFunctionType fused_activation_function;
float cell_clip;
float proj_clip;
+ LSTMKernelType kernel_type;
LSTMOptionsT()
: fused_activation_function(ActivationFunctionType_NONE),
cell_clip(0.0f),
- proj_clip(0.0f) {
+ proj_clip(0.0f),
+ kernel_type(LSTMKernelType_FULL) {
}
};
@@ -2835,7 +3107,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
enum {
VT_FUSED_ACTIVATION_FUNCTION = 4,
VT_CELL_CLIP = 6,
- VT_PROJ_CLIP = 8
+ VT_PROJ_CLIP = 8,
+ VT_KERNEL_TYPE = 10
};
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
@@ -2846,11 +3119,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
float proj_clip() const {
return GetField<float>(VT_PROJ_CLIP, 0.0f);
}
+ LSTMKernelType kernel_type() const {
+ return static_cast<LSTMKernelType>(GetField<int8_t>(VT_KERNEL_TYPE, 0));
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
VerifyField<float>(verifier, VT_CELL_CLIP) &&
VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ VerifyField<int8_t>(verifier, VT_KERNEL_TYPE) &&
verifier.EndTable();
}
LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2870,6 +3147,9 @@ struct LSTMOptionsBuilder {
void add_proj_clip(float proj_clip) {
fbb_.AddElement<float>(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
}
+ void add_kernel_type(LSTMKernelType kernel_type) {
+ fbb_.AddElement<int8_t>(LSTMOptions::VT_KERNEL_TYPE, static_cast<int8_t>(kernel_type), 0);
+ }
explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2886,10 +3166,12 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::FlatBufferBuilder &_fbb,
ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
float cell_clip = 0.0f,
- float proj_clip = 0.0f) {
+ float proj_clip = 0.0f,
+ LSTMKernelType kernel_type = LSTMKernelType_FULL) {
LSTMOptionsBuilder builder_(_fbb);
builder_.add_proj_clip(proj_clip);
builder_.add_cell_clip(cell_clip);
+ builder_.add_kernel_type(kernel_type);
builder_.add_fused_activation_function(fused_activation_function);
return builder_.Finish();
}
@@ -3694,16 +3976,16 @@ inline flatbuffers::Offset<ExpOptions> CreateExpOptions(
flatbuffers::Offset<ExpOptions> CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
-struct MeanOptionsT : public flatbuffers::NativeTable {
- typedef MeanOptions TableType;
+struct ReducerOptionsT : public flatbuffers::NativeTable {
+ typedef ReducerOptions TableType;
bool keep_dims;
- MeanOptionsT()
+ ReducerOptionsT()
: keep_dims(false) {
}
};
-struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
- typedef MeanOptionsT NativeTableType;
+struct ReducerOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ReducerOptionsT NativeTableType;
enum {
VT_KEEP_DIMS = 4
};
@@ -3715,38 +3997,38 @@ struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyField<uint8_t>(verifier, VT_KEEP_DIMS) &&
verifier.EndTable();
}
- MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<MeanOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+ ReducerOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ReducerOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
-struct MeanOptionsBuilder {
+struct ReducerOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
void add_keep_dims(bool keep_dims) {
- fbb_.AddElement<uint8_t>(MeanOptions::VT_KEEP_DIMS, static_cast<uint8_t>(keep_dims), 0);
+ fbb_.AddElement<uint8_t>(ReducerOptions::VT_KEEP_DIMS, static_cast<uint8_t>(keep_dims), 0);
}
- explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ explicit ReducerOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
}
- MeanOptionsBuilder &operator=(const MeanOptionsBuilder &);
- flatbuffers::Offset<MeanOptions> Finish() {
+ ReducerOptionsBuilder &operator=(const ReducerOptionsBuilder &);
+ flatbuffers::Offset<ReducerOptions> Finish() {
const auto end = fbb_.EndTable(start_);
- auto o = flatbuffers::Offset<MeanOptions>(end);
+ auto o = flatbuffers::Offset<ReducerOptions>(end);
return o;
}
};
-inline flatbuffers::Offset<MeanOptions> CreateMeanOptions(
+inline flatbuffers::Offset<ReducerOptions> CreateReducerOptions(
flatbuffers::FlatBufferBuilder &_fbb,
bool keep_dims = false) {
- MeanOptionsBuilder builder_(_fbb);
+ ReducerOptionsBuilder builder_(_fbb);
builder_.add_keep_dims(keep_dims);
return builder_.Finish();
}
-flatbuffers::Offset<MeanOptions> CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<ReducerOptions> CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct SqueezeOptionsT : public flatbuffers::NativeTable {
typedef SqueezeOptions TableType;
@@ -4152,6 +4434,46 @@ inline flatbuffers::Offset<MaximumMinimumOptions> CreateMaximumMinimumOptions(
flatbuffers::Offset<MaximumMinimumOptions> CreateMaximumMinimumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct TileOptionsT : public flatbuffers::NativeTable {
+ typedef TileOptions TableType;
+ TileOptionsT() {
+ }
+};
+
+struct TileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TileOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ TileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<TileOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct TileOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit TileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TileOptionsBuilder &operator=(const TileOptionsBuilder &);
+ flatbuffers::Offset<TileOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TileOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TileOptions> CreateTileOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ TileOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<TileOptions> CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ArgMaxOptionsT : public flatbuffers::NativeTable {
typedef ArgMaxOptions TableType;
TensorType output_type;
@@ -4206,6 +4528,60 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ArgMinOptionsT : public flatbuffers::NativeTable {
+ typedef ArgMinOptions TableType;
+ TensorType output_type;
+ ArgMinOptionsT()
+ : output_type(TensorType_FLOAT32) {
+ }
+};
+
+struct ArgMinOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ArgMinOptionsT NativeTableType;
+ enum {
+ VT_OUTPUT_TYPE = 4
+ };
+ TensorType output_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUTPUT_TYPE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_OUTPUT_TYPE) &&
+ verifier.EndTable();
+ }
+ ArgMinOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ArgMinOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ArgMinOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_output_type(TensorType output_type) {
+ fbb_.AddElement<int8_t>(ArgMinOptions::VT_OUTPUT_TYPE, static_cast<int8_t>(output_type), 0);
+ }
+ explicit ArgMinOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ArgMinOptionsBuilder &operator=(const ArgMinOptionsBuilder &);
+ flatbuffers::Offset<ArgMinOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ArgMinOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType output_type = TensorType_FLOAT32) {
+ ArgMinOptionsBuilder builder_(_fbb);
+ builder_.add_output_type(output_type);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct GreaterOptionsT : public flatbuffers::NativeTable {
typedef GreaterOptions TableType;
GreaterOptionsT() {
@@ -4564,6 +4940,46 @@ inline flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(
flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct ExpandDimsOptionsT : public flatbuffers::NativeTable {
+ typedef ExpandDimsOptions TableType;
+ ExpandDimsOptionsT() {
+ }
+};
+
+struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ExpandDimsOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ ExpandDimsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ExpandDimsOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ExpandDimsOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit ExpandDimsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ExpandDimsOptionsBuilder &operator=(const ExpandDimsOptionsBuilder &);
+ flatbuffers::Offset<ExpandDimsOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ExpandDimsOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ExpandDimsOptions> CreateExpandDimsOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ ExpandDimsOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ExpandDimsOptions> CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct SparseToDenseOptionsT : public flatbuffers::NativeTable {
typedef SparseToDenseOptions TableType;
bool validate_indices;
@@ -4618,6 +5034,258 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(
flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct EqualOptionsT : public flatbuffers::NativeTable {
+ typedef EqualOptions TableType;
+ EqualOptionsT() {
+ }
+};
+
+struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef EqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<EqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct EqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ EqualOptionsBuilder &operator=(const EqualOptionsBuilder &);
+ flatbuffers::Offset<EqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<EqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ EqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct NotEqualOptionsT : public flatbuffers::NativeTable {
+ typedef NotEqualOptions TableType;
+ NotEqualOptionsT() {
+ }
+};
+
+struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef NotEqualOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<NotEqualOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct NotEqualOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &);
+ flatbuffers::Offset<NotEqualOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<NotEqualOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ NotEqualOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ShapeOptionsT : public flatbuffers::NativeTable {
+ typedef ShapeOptions TableType;
+ TensorType out_type;
+ ShapeOptionsT()
+ : out_type(TensorType_FLOAT32) {
+ }
+};
+
+struct ShapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ShapeOptionsT NativeTableType;
+ enum {
+ VT_OUT_TYPE = 4
+ };
+ TensorType out_type() const {
+ return static_cast<TensorType>(GetField<int8_t>(VT_OUT_TYPE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_OUT_TYPE) &&
+ verifier.EndTable();
+ }
+ ShapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ShapeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ShapeOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_out_type(TensorType out_type) {
+ fbb_.AddElement<int8_t>(ShapeOptions::VT_OUT_TYPE, static_cast<int8_t>(out_type), 0);
+ }
+ explicit ShapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ShapeOptionsBuilder &operator=(const ShapeOptionsBuilder &);
+ flatbuffers::Offset<ShapeOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ShapeOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ TensorType out_type = TensorType_FLOAT32) {
+ ShapeOptionsBuilder builder_(_fbb);
+ builder_.add_out_type(out_type);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct PowOptionsT : public flatbuffers::NativeTable {
+ typedef PowOptions TableType;
+ PowOptionsT() {
+ }
+};
+
+struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PowOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PowOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PowOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PowOptionsBuilder &operator=(const PowOptionsBuilder &);
+ flatbuffers::Offset<PowOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PowOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PowOptions> CreatePowOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PowOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FakeQuantOptionsT : public flatbuffers::NativeTable {
+ typedef FakeQuantOptions TableType;
+ float min;
+ float max;
+ int32_t num_bits;
+ FakeQuantOptionsT()
+ : min(0.0f),
+ max(0.0f),
+ num_bits(0) {
+ }
+};
+
+struct FakeQuantOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FakeQuantOptionsT NativeTableType;
+ enum {
+ VT_MIN = 4,
+ VT_MAX = 6,
+ VT_NUM_BITS = 8
+ };
+ float min() const {
+ return GetField<float>(VT_MIN, 0.0f);
+ }
+ float max() const {
+ return GetField<float>(VT_MAX, 0.0f);
+ }
+ int32_t num_bits() const {
+ return GetField<int32_t>(VT_NUM_BITS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<float>(verifier, VT_MIN) &&
+ VerifyField<float>(verifier, VT_MAX) &&
+ VerifyField<int32_t>(verifier, VT_NUM_BITS) &&
+ verifier.EndTable();
+ }
+ FakeQuantOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FakeQuantOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FakeQuantOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_min(float min) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MIN, min, 0.0f);
+ }
+ void add_max(float max) {
+ fbb_.AddElement<float>(FakeQuantOptions::VT_MAX, max, 0.0f);
+ }
+ void add_num_bits(int32_t num_bits) {
+ fbb_.AddElement<int32_t>(FakeQuantOptions::VT_NUM_BITS, num_bits, 0);
+ }
+ explicit FakeQuantOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FakeQuantOptionsBuilder &operator=(const FakeQuantOptionsBuilder &);
+ flatbuffers::Offset<FakeQuantOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FakeQuantOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ float min = 0.0f,
+ float max = 0.0f,
+ int32_t num_bits = 0) {
+ FakeQuantOptionsBuilder builder_(_fbb);
+ builder_.add_num_bits(num_bits);
+ builder_.add_max(max);
+ builder_.add_min(min);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -4716,6 +5384,7 @@ struct OperatorT : public flatbuffers::NativeTable {
BuiltinOptionsUnion builtin_options;
std::vector<uint8_t> custom_options;
CustomOptionsFormat custom_options_format;
+ std::vector<bool> mutating_variable_inputs;
OperatorT()
: opcode_index(0),
custom_options_format(CustomOptionsFormat_FLEXBUFFERS) {
@@ -4731,7 +5400,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_BUILTIN_OPTIONS_TYPE = 10,
VT_BUILTIN_OPTIONS = 12,
VT_CUSTOM_OPTIONS = 14,
- VT_CUSTOM_OPTIONS_FORMAT = 16
+ VT_CUSTOM_OPTIONS_FORMAT = 16,
+ VT_MUTATING_VARIABLE_INPUTS = 18
};
uint32_t opcode_index() const {
return GetField<uint32_t>(VT_OPCODE_INDEX, 0);
@@ -4827,8 +5497,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const TransposeOptions *builtin_options_as_TransposeOptions() const {
return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast<const TransposeOptions *>(builtin_options()) : nullptr;
}
- const MeanOptions *builtin_options_as_MeanOptions() const {
- return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast<const MeanOptions *>(builtin_options()) : nullptr;
+ const ReducerOptions *builtin_options_as_ReducerOptions() const {
+ return builtin_options_type() == BuiltinOptions_ReducerOptions ? static_cast<const ReducerOptions *>(builtin_options()) : nullptr;
}
const SubOptions *builtin_options_as_SubOptions() const {
return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast<const SubOptions *>(builtin_options()) : nullptr;
@@ -4899,12 +5569,39 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const {
return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast<const SparseToDenseOptions *>(builtin_options()) : nullptr;
}
+ const TileOptions *builtin_options_as_TileOptions() const {
+ return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast<const TileOptions *>(builtin_options()) : nullptr;
+ }
+ const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const {
+ return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast<const ExpandDimsOptions *>(builtin_options()) : nullptr;
+ }
+ const EqualOptions *builtin_options_as_EqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast<const EqualOptions *>(builtin_options()) : nullptr;
+ }
+ const NotEqualOptions *builtin_options_as_NotEqualOptions() const {
+ return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast<const NotEqualOptions *>(builtin_options()) : nullptr;
+ }
+ const ShapeOptions *builtin_options_as_ShapeOptions() const {
+ return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast<const ShapeOptions *>(builtin_options()) : nullptr;
+ }
+ const PowOptions *builtin_options_as_PowOptions() const {
+ return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr;
+ }
+ const ArgMinOptions *builtin_options_as_ArgMinOptions() const {
+ return builtin_options_type() == BuiltinOptions_ArgMinOptions ? static_cast<const ArgMinOptions *>(builtin_options()) : nullptr;
+ }
+ const FakeQuantOptions *builtin_options_as_FakeQuantOptions() const {
+ return builtin_options_type() == BuiltinOptions_FakeQuantOptions ? static_cast<const FakeQuantOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
CustomOptionsFormat custom_options_format() const {
return static_cast<CustomOptionsFormat>(GetField<int8_t>(VT_CUSTOM_OPTIONS_FORMAT, 0));
}
+ const flatbuffers::Vector<uint8_t> *mutating_variable_inputs() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_MUTATING_VARIABLE_INPUTS);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX) &&
@@ -4918,6 +5615,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VerifyOffset(verifier, VT_CUSTOM_OPTIONS) &&
verifier.Verify(custom_options()) &&
VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT) &&
+ VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
+ verifier.Verify(mutating_variable_inputs()) &&
verifier.EndTable();
}
OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -5029,8 +5728,8 @@ template<> inline const TransposeOptions *Operator::builtin_options_as<Transpose
return builtin_options_as_TransposeOptions();
}
-template<> inline const MeanOptions *Operator::builtin_options_as<MeanOptions>() const {
- return builtin_options_as_MeanOptions();
+template<> inline const ReducerOptions *Operator::builtin_options_as<ReducerOptions>() const {
+ return builtin_options_as_ReducerOptions();
}
template<> inline const SubOptions *Operator::builtin_options_as<SubOptions>() const {
@@ -5125,6 +5824,38 @@ template<> inline const SparseToDenseOptions *Operator::builtin_options_as<Spars
return builtin_options_as_SparseToDenseOptions();
}
+template<> inline const TileOptions *Operator::builtin_options_as<TileOptions>() const {
+ return builtin_options_as_TileOptions();
+}
+
+template<> inline const ExpandDimsOptions *Operator::builtin_options_as<ExpandDimsOptions>() const {
+ return builtin_options_as_ExpandDimsOptions();
+}
+
+template<> inline const EqualOptions *Operator::builtin_options_as<EqualOptions>() const {
+ return builtin_options_as_EqualOptions();
+}
+
+template<> inline const NotEqualOptions *Operator::builtin_options_as<NotEqualOptions>() const {
+ return builtin_options_as_NotEqualOptions();
+}
+
+template<> inline const ShapeOptions *Operator::builtin_options_as<ShapeOptions>() const {
+ return builtin_options_as_ShapeOptions();
+}
+
+template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() const {
+ return builtin_options_as_PowOptions();
+}
+
+template<> inline const ArgMinOptions *Operator::builtin_options_as<ArgMinOptions>() const {
+ return builtin_options_as_ArgMinOptions();
+}
+
+template<> inline const FakeQuantOptions *Operator::builtin_options_as<FakeQuantOptions>() const {
+ return builtin_options_as_FakeQuantOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -5149,6 +5880,9 @@ struct OperatorBuilder {
void add_custom_options_format(CustomOptionsFormat custom_options_format) {
fbb_.AddElement<int8_t>(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast<int8_t>(custom_options_format), 0);
}
+ void add_mutating_variable_inputs(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> mutating_variable_inputs) {
+ fbb_.AddOffset(Operator::VT_MUTATING_VARIABLE_INPUTS, mutating_variable_inputs);
+ }
explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -5169,8 +5903,10 @@ inline flatbuffers::Offset<Operator> CreateOperator(
BuiltinOptions builtin_options_type = BuiltinOptions_NONE,
flatbuffers::Offset<void> builtin_options = 0,
flatbuffers::Offset<flatbuffers::Vector<uint8_t>> custom_options = 0,
- CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) {
+ CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> mutating_variable_inputs = 0) {
OperatorBuilder builder_(_fbb);
+ builder_.add_mutating_variable_inputs(mutating_variable_inputs);
builder_.add_custom_options(custom_options);
builder_.add_builtin_options(builtin_options);
builder_.add_outputs(outputs);
@@ -5189,7 +5925,8 @@ inline flatbuffers::Offset<Operator> CreateOperatorDirect(
BuiltinOptions builtin_options_type = BuiltinOptions_NONE,
flatbuffers::Offset<void> builtin_options = 0,
const std::vector<uint8_t> *custom_options = nullptr,
- CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) {
+ CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS,
+ const std::vector<uint8_t> *mutating_variable_inputs = nullptr) {
return tflite::CreateOperator(
_fbb,
opcode_index,
@@ -5198,7 +5935,8 @@ inline flatbuffers::Offset<Operator> CreateOperatorDirect(
builtin_options_type,
builtin_options,
custom_options ? _fbb.CreateVector<uint8_t>(*custom_options) : 0,
- custom_options_format);
+ custom_options_format,
+ mutating_variable_inputs ? _fbb.CreateVector<uint8_t>(*mutating_variable_inputs) : 0);
}
flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
@@ -5569,6 +6307,7 @@ inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t
{ auto _e = buffer(); _o->buffer = _e; };
{ auto _e = name(); if (_e) _o->name = _e->str(); };
{ auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr<QuantizationParametersT>(_e->UnPack(_resolver)); };
+ { auto _e = is_variable(); _o->is_variable = _e; };
}
inline flatbuffers::Offset<Tensor> Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -5584,13 +6323,15 @@ inline flatbuffers::Offset<Tensor> CreateTensor(flatbuffers::FlatBufferBuilder &
auto _buffer = _o->buffer;
auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name);
auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0;
+ auto _is_variable = _o->is_variable;
return tflite::CreateTensor(
_fbb,
_shape,
_type,
_buffer,
_name,
- _quantization);
+ _quantization,
+ _is_variable);
}
inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -5894,6 +6635,7 @@ inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const fl
(void)_o;
(void)_resolver;
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = weights_format(); _o->weights_format = _e; };
}
inline flatbuffers::Offset<FullyConnectedOptions> FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -5905,9 +6647,11 @@ inline flatbuffers::Offset<FullyConnectedOptions> CreateFullyConnectedOptions(fl
(void)_o;
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _weights_format = _o->weights_format;
return tflite::CreateFullyConnectedOptions(
_fbb,
- _fused_activation_function);
+ _fused_activation_function,
+ _weights_format);
}
inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -6090,6 +6834,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
{ auto _e = cell_clip(); _o->cell_clip = _e; };
{ auto _e = proj_clip(); _o->proj_clip = _e; };
+ { auto _e = kernel_type(); _o->kernel_type = _e; };
}
inline flatbuffers::Offset<LSTMOptions> LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -6103,11 +6848,13 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
auto _fused_activation_function = _o->fused_activation_function;
auto _cell_clip = _o->cell_clip;
auto _proj_clip = _o->proj_clip;
+ auto _kernel_type = _o->kernel_type;
return tflite::CreateLSTMOptions(
_fbb,
_fused_activation_function,
_cell_clip,
- _proj_clip);
+ _proj_clip,
+ _kernel_type);
}
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -6511,28 +7258,28 @@ inline flatbuffers::Offset<ExpOptions> CreateExpOptions(flatbuffers::FlatBufferB
_fbb);
}
-inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
- auto _o = new MeanOptionsT();
+inline ReducerOptionsT *ReducerOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ReducerOptionsT();
UnPackTo(_o, _resolver);
return _o;
}
-inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+inline void ReducerOptions::UnPackTo(ReducerOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
(void)_o;
(void)_resolver;
{ auto _e = keep_dims(); _o->keep_dims = _e; };
}
-inline flatbuffers::Offset<MeanOptions> MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
- return CreateMeanOptions(_fbb, _o, _rehasher);
+inline flatbuffers::Offset<ReducerOptions> ReducerOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateReducerOptions(_fbb, _o, _rehasher);
}
-inline flatbuffers::Offset<MeanOptions> CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+inline flatbuffers::Offset<ReducerOptions> CreateReducerOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReducerOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
(void)_rehasher;
(void)_o;
- struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReducerOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _keep_dims = _o->keep_dims;
- return tflite::CreateMeanOptions(
+ return tflite::CreateReducerOptions(
_fbb,
_keep_dims);
}
@@ -6725,6 +7472,29 @@ inline flatbuffers::Offset<MaximumMinimumOptions> CreateMaximumMinimumOptions(fl
_fbb);
}
+inline TileOptionsT *TileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new TileOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void TileOptions::UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<TileOptions> TileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateTileOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TileOptions> CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateTileOptions(
+ _fbb);
+}
+
inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ArgMaxOptionsT();
UnPackTo(_o, _resolver);
@@ -6751,6 +7521,32 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
_output_type);
}
+inline ArgMinOptionsT *ArgMinOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ArgMinOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ArgMinOptions::UnPackTo(ArgMinOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = output_type(); _o->output_type = _e; };
+}
+
+inline flatbuffers::Offset<ArgMinOptions> ArgMinOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateArgMinOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ArgMinOptions> CreateArgMinOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMinOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ArgMinOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _output_type = _o->output_type;
+ return tflite::CreateArgMinOptions(
+ _fbb,
+ _output_type);
+}
+
inline GreaterOptionsT *GreaterOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new GreaterOptionsT();
UnPackTo(_o, _resolver);
@@ -6944,6 +7740,29 @@ inline flatbuffers::Offset<TransposeConvOptions> CreateTransposeConvOptions(flat
_stride_h);
}
+inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ExpandDimsOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<ExpandDimsOptions> ExpandDimsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateExpandDimsOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ExpandDimsOptions> CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateExpandDimsOptions(
+ _fbb);
+}
+
inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new SparseToDenseOptionsT();
UnPackTo(_o, _resolver);
@@ -6970,6 +7789,133 @@ inline flatbuffers::Offset<SparseToDenseOptions> CreateSparseToDenseOptions(flat
_validate_indices);
}
+inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new EqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<EqualOptions> EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<EqualOptions> CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateEqualOptions(
+ _fbb);
+}
+
+inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new NotEqualOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<NotEqualOptions> NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateNotEqualOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<NotEqualOptions> CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateNotEqualOptions(
+ _fbb);
+}
+
+inline ShapeOptionsT *ShapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ShapeOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ShapeOptions::UnPackTo(ShapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = out_type(); _o->out_type = _e; };
+}
+
+inline flatbuffers::Offset<ShapeOptions> ShapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateShapeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ShapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _out_type = _o->out_type;
+ return tflite::CreateShapeOptions(
+ _fbb,
+ _out_type);
+}
+
+inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new PowOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PowOptions> PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePowOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreatePowOptions(
+ _fbb);
+}
+
+inline FakeQuantOptionsT *FakeQuantOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FakeQuantOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FakeQuantOptions::UnPackTo(FakeQuantOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = min(); _o->min = _e; };
+ { auto _e = max(); _o->max = _e; };
+ { auto _e = num_bits(); _o->num_bits = _e; };
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> FakeQuantOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFakeQuantOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FakeQuantOptions> CreateFakeQuantOptions(flatbuffers::FlatBufferBuilder &_fbb, const FakeQuantOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FakeQuantOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _min = _o->min;
+ auto _max = _o->max;
+ auto _num_bits = _o->num_bits;
+ return tflite::CreateFakeQuantOptions(
+ _fbb,
+ _min,
+ _max,
+ _num_bits);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -7018,6 +7964,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_functi
{ auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); };
{ auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } };
{ auto _e = custom_options_format(); _o->custom_options_format = _e; };
+ { auto _e = mutating_variable_inputs(); if (_e) { _o->mutating_variable_inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->mutating_variable_inputs[_i] = _e->Get(_i) != 0; } } };
}
inline flatbuffers::Offset<Operator> Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7035,6 +7982,7 @@ inline flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuild
auto _builtin_options = _o->builtin_options.Pack(_fbb);
auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0;
auto _custom_options_format = _o->custom_options_format;
+ auto _mutating_variable_inputs = _o->mutating_variable_inputs.size() ? _fbb.CreateVector(_o->mutating_variable_inputs) : 0;
return tflite::CreateOperator(
_fbb,
_opcode_index,
@@ -7043,7 +7991,8 @@ inline flatbuffers::Offset<Operator> CreateOperator(flatbuffers::FlatBufferBuild
_builtin_options_type,
_builtin_options,
_custom_options,
- _custom_options_format);
+ _custom_options_format,
+ _mutating_variable_inputs);
}
inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -7260,8 +8209,8 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
- case BuiltinOptions_MeanOptions: {
- auto ptr = reinterpret_cast<const MeanOptions *>(obj);
+ case BuiltinOptions_ReducerOptions: {
+ auto ptr = reinterpret_cast<const ReducerOptions *>(obj);
return verifier.VerifyTable(ptr);
}
case BuiltinOptions_SubOptions: {
@@ -7356,6 +8305,38 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const SparseToDenseOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_TileOptions: {
+ auto ptr = reinterpret_cast<const TileOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ExpandDimsOptions: {
+ auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ShapeOptions: {
+ auto ptr = reinterpret_cast<const ShapeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -7478,8 +8459,8 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
return ptr->UnPack(resolver);
}
- case BuiltinOptions_MeanOptions: {
- auto ptr = reinterpret_cast<const MeanOptions *>(obj);
+ case BuiltinOptions_ReducerOptions: {
+ auto ptr = reinterpret_cast<const ReducerOptions *>(obj);
return ptr->UnPack(resolver);
}
case BuiltinOptions_SubOptions: {
@@ -7574,6 +8555,38 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const SparseToDenseOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_TileOptions: {
+ auto ptr = reinterpret_cast<const TileOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_ExpandDimsOptions: {
+ auto ptr = reinterpret_cast<const ExpandDimsOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_ShapeOptions: {
+ auto ptr = reinterpret_cast<const ShapeOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -7684,9 +8697,9 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const TransposeOptionsT *>(value);
return CreateTransposeOptions(_fbb, ptr, _rehasher).Union();
}
- case BuiltinOptions_MeanOptions: {
- auto ptr = reinterpret_cast<const MeanOptionsT *>(value);
- return CreateMeanOptions(_fbb, ptr, _rehasher).Union();
+ case BuiltinOptions_ReducerOptions: {
+ auto ptr = reinterpret_cast<const ReducerOptionsT *>(value);
+ return CreateReducerOptions(_fbb, ptr, _rehasher).Union();
}
case BuiltinOptions_SubOptions: {
auto ptr = reinterpret_cast<const SubOptionsT *>(value);
@@ -7780,6 +8793,38 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const SparseToDenseOptionsT *>(value);
return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_TileOptions: {
+ auto ptr = reinterpret_cast<const TileOptionsT *>(value);
+ return CreateTileOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_ExpandDimsOptions: {
+ auto ptr = reinterpret_cast<const ExpandDimsOptionsT *>(value);
+ return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<const EqualOptionsT *>(value);
+ return CreateEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<const NotEqualOptionsT *>(value);
+ return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_ShapeOptions: {
+ auto ptr = reinterpret_cast<const ShapeOptionsT *>(value);
+ return CreateShapeOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptionsT *>(value);
+ return CreatePowOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<const ArgMinOptionsT *>(value);
+ return CreateArgMinOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<const FakeQuantOptionsT *>(value);
+ return CreateFakeQuantOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -7890,8 +8935,8 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new TransposeOptionsT(*reinterpret_cast<TransposeOptionsT *>(u.value));
break;
}
- case BuiltinOptions_MeanOptions: {
- value = new MeanOptionsT(*reinterpret_cast<MeanOptionsT *>(u.value));
+ case BuiltinOptions_ReducerOptions: {
+ value = new ReducerOptionsT(*reinterpret_cast<ReducerOptionsT *>(u.value));
break;
}
case BuiltinOptions_SubOptions: {
@@ -7986,6 +9031,38 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new SparseToDenseOptionsT(*reinterpret_cast<SparseToDenseOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_TileOptions: {
+ value = new TileOptionsT(*reinterpret_cast<TileOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_ExpandDimsOptions: {
+ value = new ExpandDimsOptionsT(*reinterpret_cast<ExpandDimsOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_EqualOptions: {
+ value = new EqualOptionsT(*reinterpret_cast<EqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ value = new NotEqualOptionsT(*reinterpret_cast<NotEqualOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_ShapeOptions: {
+ value = new ShapeOptionsT(*reinterpret_cast<ShapeOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_PowOptions: {
+ value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_ArgMinOptions: {
+ value = new ArgMinOptionsT(*reinterpret_cast<ArgMinOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ value = new FakeQuantOptionsT(*reinterpret_cast<FakeQuantOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -8123,8 +9200,8 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
- case BuiltinOptions_MeanOptions: {
- auto ptr = reinterpret_cast<MeanOptionsT *>(value);
+ case BuiltinOptions_ReducerOptions: {
+ auto ptr = reinterpret_cast<ReducerOptionsT *>(value);
delete ptr;
break;
}
@@ -8243,6 +9320,46 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_TileOptions: {
+ auto ptr = reinterpret_cast<TileOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_ExpandDimsOptions: {
+ auto ptr = reinterpret_cast<ExpandDimsOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_EqualOptions: {
+ auto ptr = reinterpret_cast<EqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_NotEqualOptions: {
+ auto ptr = reinterpret_cast<NotEqualOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_ShapeOptions: {
+ auto ptr = reinterpret_cast<ShapeOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<PowOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_ArgMinOptions: {
+ auto ptr = reinterpret_cast<ArgMinOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FakeQuantOptions: {
+ auto ptr = reinterpret_cast<FakeQuantOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc
index 2f2004f56b..4eaf6f1bfe 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.cc
+++ b/tensorflow/contrib/lite/simple_memory_arena.cc
@@ -36,6 +36,12 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context,
ArenaAlloc* new_alloc) {
TF_LITE_ENSURE(context, alignment < arena_alignment_);
+ if (size == 0) {
+ new_alloc->offset = 0;
+ new_alloc->size = 0;
+ return kTfLiteOk;
+ }
+
size_t current_top = 0;
if (!allocs_.empty()) {
@@ -75,6 +81,10 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context,
TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context,
const ArenaAlloc& alloc) {
+ if (alloc.size == 0) {
+ return kTfLiteOk;
+ }
+
int erased_allocs_count = 0;
auto it = allocs_.begin();
while (it != allocs_.end()) {
@@ -122,7 +132,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context,
char** output_ptr) {
TF_LITE_ENSURE(context, committed_);
TF_LITE_ENSURE(context, output_ptr != nullptr);
- *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
+ if (alloc.size == 0) {
+ *output_ptr = nullptr;
+ } else {
+ *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset;
+ }
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index 5faf78b59e..f738315cf2 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -39,7 +39,8 @@ struct ArenaAlloc {
// This small class is responsible for allocating, deallocating and reusing
// dynamic memory from a common underlying buffer. The arena can be used in
// scenarios when the pattern of memory allocations and deallocations is
-// repetitive, e.g. running NN inference in multiple iterations.
+// repetitive, e.g. running NN inference in multiple iterations. Note that
+// zero-sized allocations are explicitly allowed, and will resolve to null.
class SimpleMemoryArena {
public:
explicit SimpleMemoryArena(size_t arena_alignment)
diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc
index 4444f642eb..60d4d5e768 100644
--- a/tensorflow/contrib/lite/simple_memory_arena_test.cc
+++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc
@@ -43,6 +43,47 @@ TEST(SimpleMemoryArenaTest, BasicArenaOperations) {
EXPECT_EQ(allocs[5].offset, 1024);
}
+TEST(SimpleMemoryArenaTest, BasicZeroAlloc) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc alloc;
+
+ // Zero-sized allocs should have a 0 offset and size.
+ ASSERT_EQ(arena.Allocate(&context, 32, 0, &alloc), kTfLiteOk);
+ EXPECT_EQ(alloc.offset, 0);
+ EXPECT_EQ(alloc.size, 0);
+
+ // Deallocation of zero-sized allocs should always succeed (even redundantly).
+ ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk);
+ ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk);
+
+ // The zero-sized alloc should resolve to null.
+ char* resolved_ptr = nullptr;
+ ASSERT_EQ(arena.Commit(&context), kTfLiteOk);
+ ASSERT_EQ(arena.ResolveAlloc(&context, alloc, &resolved_ptr), kTfLiteOk);
+ EXPECT_EQ(resolved_ptr, nullptr);
+}
+
+TEST(SimpleMemoryArenaTest, InterleavedZeroAlloc) {
+ TfLiteContext context;
+ SimpleMemoryArena arena(64);
+ ArenaAlloc allocs[4];
+
+ // Interleave some zero and non-zero-sized allocations and deallocations.
+ ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[0]), kTfLiteOk);
+ ASSERT_EQ(arena.Allocate(&context, 32, 0, &allocs[1]), kTfLiteOk);
+ ASSERT_EQ(arena.Allocate(&context, 32, 1023, &allocs[2]), kTfLiteOk);
+ ASSERT_EQ(arena.Deallocate(&context, allocs[1]), kTfLiteOk);
+ ASSERT_EQ(arena.Deallocate(&context, allocs[2]), kTfLiteOk);
+ ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[3]), kTfLiteOk);
+
+ // Deallocation of a zero-sized alloc should not impact the allocator offsets.
+ EXPECT_EQ(allocs[0].offset, 0);
+ EXPECT_EQ(allocs[1].offset, 0);
+ EXPECT_EQ(allocs[2].offset, 2048);
+ EXPECT_EQ(allocs[3].offset, 2048);
+}
+
TEST(SimpleMemoryArenaTest, TestAfterClear) {
TfLiteContext context;
SimpleMemoryArena arena(64);
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a89776b29f..a316a40b62 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -105,7 +105,7 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
dims->data[0] = offset_.size() - 1; // Store number of strings.
TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
- tensor);
+ tensor->is_variable, tensor);
}
int GetStringCount(const char* raw_buffer) {
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 74fc32a12b..789bc695f8 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -20,11 +20,15 @@ load(
size = "large",
srcs = ["generated_examples_zip_test.cc"],
args = [
- "--zip_file_path=$(location :zip_%s)" % test_name,
- # TODO(angerson) We may be able to add an external unzip binary instead
- # of relying on an existing one for OSS builds.
- "--unzip_binary_path=/usr/bin/unzip",
- ],
+ ] + select({
+ "//tensorflow:android": [],
+ "//conditions:default": [
+ "--zip_file_path=$(location :zip_%s)" % test_name,
+ # TODO(angerson) We may be able to add an external unzip binary instead
+ # of relying on an existing one for OSS builds.
+ "--unzip_binary_path=/usr/bin/unzip",
+ ],
+ }),
data = [
":zip_%s" % test_name,
],
@@ -155,6 +159,7 @@ cc_library(
deps = [
":split",
":test_runner",
+ "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
@@ -167,6 +172,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
":tflite_driver",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index ae66bd858b..1093bd2cbe 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -58,10 +58,11 @@ from tensorflow.python.ops import rnn
parser = argparse.ArgumentParser(description="Script to generate TFLite tests.")
parser.add_argument("output_path",
help="Directory where the outputs will be go.")
-parser.add_argument("--zip_to_output",
- type=str,
- help="Particular zip to output.",
- required=False)
+parser.add_argument(
+ "--zip_to_output",
+ type=str,
+ help="Particular zip to output.",
+ required=True)
parser.add_argument("--toco",
type=str,
help="Path to toco tool.",
@@ -93,12 +94,10 @@ KNOWN_BUGS = {
r"sigmoid.*input_shape=\[\]": "67645668",
# Concat doesn't work with a single input tensor
r"concat.*num_tensors=1": "67378344",
- # Transposition in MatMul is not supported.
- r"fully_connected.*transpose_.=True": "67586970",
+ # Transposition in MatMul is not fully supported.
+ "fully_connected.*transpose_a=True": "67586970",
# Softmax graphs are too complex.
r"softmax.*dim=0": "67749831",
- # SpaceToDepth only supports float32.
- r"space_to_depth.*(float16|int32|uint8|int64)": "68018134",
# BatchToSpaceND only supports 4D tensors.
r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733",
# Div will use floordiv.
@@ -118,6 +117,8 @@ class ExtraTocoOptions(object):
self.allow_custom_ops = False
# Rnn states that are used to support rnn / lstm cells.
self.rnn_states = None
+ # Split the LSTM inputs from 5 inoputs to 18 inputs for TFLite.
+ self.split_tflite_lstm_inputs = None
def toco_options(data_types,
@@ -136,7 +137,7 @@ def toco_options(data_types,
Returns:
the options in a string.
"""
- shape_str = ":".join([",".join(str(y) for y in x) for x in shapes])
+ shape_str = ":".join([",".join(str(y) for y in x) for x in shapes if x])
inference_type = "FLOAT"
# TODO(ahentz): if we get multi-input quantization to work we need this
# to change
@@ -155,6 +156,11 @@ def toco_options(data_types,
s += " --allow_custom_ops"
if extra_toco_options.rnn_states:
s += (" --rnn_states='" + extra_toco_options.rnn_states + "'")
+ if extra_toco_options.split_tflite_lstm_inputs is not None:
+ if extra_toco_options.split_tflite_lstm_inputs:
+ s += " --split_tflite_lstm_inputs=true"
+ else:
+ s += " --split_tflite_lstm_inputs=false"
return s
@@ -461,6 +467,11 @@ def make_zip_of_tests(zip_path,
sess,
tf.global_variables() + inputs +
outputs) if use_frozen_graph else sess.graph_def
+
+ if "split_tflite_lstm_inputs" in param_dict_real:
+ extra_toco_options.split_tflite_lstm_inputs = param_dict_real[
+ "split_tflite_lstm_inputs"]
+
tflite_model_binary, toco_log = toco_convert(
graph_def.SerializeToString(), input_tensors, output_tensors,
extra_toco_options)
@@ -667,6 +678,55 @@ def make_relu6_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_prelu_tests(zip_path):
+ """Make a set of tests to do PReLU."""
+
+ test_parameters = [{
+ # The canonical case for image processing is having a 4D `input` (NHWC)
+ # and `shared_axes`=[1, 2], so the alpha parameter is per channel.
+ "input_shape": [[1, 10, 10, 3], [3, 3, 3, 3]],
+ "shared_axes": [[1, 2], [1]],
+ }]
+
+ def build_graph(parameters):
+ """Build the graph for the test case."""
+
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name="input", shape=parameters["input_shape"])
+ prelu = tf.keras.layers.PReLU(shared_axes=parameters["shared_axes"])
+ out = prelu(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Build the inputs for the test case."""
+
+ input_shape = parameters["input_shape"]
+ input_values = create_tensor_data(
+ np.float32, input_shape, min_value=-10, max_value=10)
+ shared_axes = parameters["shared_axes"]
+
+ alpha_shape = []
+ for dim in range(1, len(input_shape)):
+ alpha_shape.append(1 if dim in shared_axes else input_shape[dim])
+
+ alpha_values = create_tensor_data(np.float32, alpha_shape)
+
+ # There should be only 1 trainable variable tensor.
+ variables = tf.all_variables()
+ assert len(variables) == 1
+ sess.run(variables[0].assign(alpha_values))
+
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(
+ zip_path,
+ test_parameters,
+ build_graph,
+ build_inputs,
+ use_frozen_graph=True)
+
+
# This function tests various TensorFLow functions that generates Const op,
# including `tf.ones`, `tf.zeros` and random functions.
def make_constant_tests(zip_path):
@@ -694,7 +754,7 @@ def make_constant_tests(zip_path):
def make_binary_op_tests(zip_path, binary_operator):
- """Make a set of tests to do add with and without broadcast."""
+ """Make a set of tests to do binary ops with and without broadcast."""
# These parameters are split because we don't support broadcasting.
test_parameters = [{
@@ -744,65 +804,89 @@ def make_binary_op_tests(zip_path, binary_operator):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_mean_tests(zip_path):
- """Make a set of tests to do mean."""
+def make_reduce_tests(reduce_op):
+ """Make a set of tests to do reduce operation.
- test_parameters = [{
- "input_dtype": [tf.float32, tf.int32, tf.int64],
- "input_shape": [[3, 2, 4]],
- "axis": [
- None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
- [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0],
- [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
- ],
- "const_axis": [True, False],
- "keepdims": [True, False],
- }, {
- "input_dtype": [tf.float32],
- "input_shape": [[1, 8, 8, 3]],
- "axis": [
- None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
- [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2,
- -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
- [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
- ],
- "const_axis": [True, False],
- "keepdims": [True, False],
- }]
+ Args:
+ reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`.
- def build_graph(parameters):
- """Build the mean op testing graph."""
- input_tensor = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input",
- shape=parameters["input_shape"])
+ Returns:
+ a function representing the true generator with `reduce_op_in` curried.
+ """
- # Get axis as either a placeholder or constants.
- if parameters["const_axis"]:
- axis = parameters["axis"]
- input_tensors = [input_tensor]
- else:
- if isinstance(parameters["axis"], list):
- shape = [len(parameters["axis"])]
+ def f(zip_path):
+ """Actual function that generates examples."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape": [[3, 2, 4]],
+ "axis": [
+ None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
+ [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0],
+ [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
+ ],
+ "const_axis": [True, False],
+ "keepdims": [True, False],
+ }, {
+ "input_dtype": [tf.float32],
+ "input_shape": [[1, 8, 8, 3]],
+ "axis": [
+ None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
+ [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2,
+ -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
+ [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
+ ],
+ "const_axis": [True, False],
+ "keepdims": [True, False],
+ }]
+
+ def build_graph(parameters):
+ """Build the mean op testing graph."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+
+ # Get axis as either a placeholder or constants.
+ if parameters["const_axis"]:
+ axis = parameters["axis"]
+ input_tensors = [input_tensor]
else:
- shape = [0] # shape for None or integers.
- axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape)
- input_tensors = [input_tensor, axis]
+ if isinstance(parameters["axis"], list):
+ shape = [len(parameters["axis"])]
+ else:
+ shape = [0] # shape for None or integers.
+ axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape)
+ input_tensors = [input_tensor, axis]
- out = tf.reduce_mean(
- input_tensor, axis=axis, keepdims=parameters["keepdims"])
- return input_tensors, [out]
+ out = reduce_op(
+ input_tensor, axis=axis, keepdims=parameters["keepdims"])
+ return input_tensors, [out]
- def build_inputs(parameters, sess, inputs, outputs):
- values = [
- create_tensor_data(parameters["input_dtype"], parameters["input_shape"])
- ]
- if not parameters["const_axis"]:
- if parameters["axis"]:
- values.append(np.array(parameters["axis"]))
- return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+ def build_inputs(parameters, sess, inputs, outputs):
+ values = [
+ create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])]
+ if not parameters["const_axis"]:
+ if parameters["axis"]:
+ values.append(np.array(parameters["axis"]))
+ return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
- make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+ return f
+
+
+def make_mean_tests(zip_path):
+ """Make a set of tests to do mean."""
+
+ return make_reduce_tests(tf.reduce_mean)(zip_path)
+
+
+def make_sum_tests(zip_path):
+ """Make a set of tests to do sum."""
+
+ return make_reduce_tests(tf.reduce_sum)(zip_path)
def make_exp_tests(zip_path):
@@ -955,6 +1039,10 @@ def make_mul_tests(zip_path):
make_binary_op_tests(zip_path, tf.multiply)
+def make_pow_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.pow)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
@@ -1286,6 +1374,12 @@ def make_fully_connected_tests(zip_path):
"transpose_a": [False],
"transpose_b": [False],
"constant_filter": [True, False],
+ }, {
+ "shape1": [[40, 37]],
+ "shape2": [[40, 37]],
+ "transpose_a": [False],
+ "transpose_b": [True],
+ "constant_filter": [True, False],
}]
def build_graph(parameters):
@@ -1510,6 +1604,32 @@ def make_reshape_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_shape_tests(zip_path):
+ """Make a set of tests to do shape."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32],
+ "input_shape": [[], [0], [1, 1, 1, 3], [2, 3, 4, 5], [5, 5], [10]],
+ "out_type": [tf.int32, tf.int64],
+ }]
+
+ def build_graph(parameters):
+ """Build the topk op testing graph."""
+ # Note that we intentionally leave out the shape from the input placeholder
+ # to prevent the Shape operation from being optimized out during conversion.
+ input_value = tf.placeholder(dtype=parameters["input_dtype"], name="input")
+ out = tf.shape(input_value, out_type=parameters["out_type"])
+ return [input_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_resize_bilinear_tests(zip_path):
"""Make a set of tests to do resize_bilinear."""
@@ -1591,7 +1711,7 @@ def make_space_to_depth_tests(zip_path):
"""Make a set of tests to do space_to_depth."""
test_parameters = [{
- "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64],
+ "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64],
"input_shape": [[2, 12, 24, 1]],
"block_size": [2, 3, 4],
}]
@@ -2001,6 +2121,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
+ "split_tflite_lstm_inputs": [True, False],
},
]
@@ -2103,7 +2224,7 @@ def make_topk_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_arg_max_tests(zip_path):
+def make_arg_min_max_tests(zip_path):
"""Make a set of tests to do arg_max."""
test_parameters = [{
@@ -2111,6 +2232,7 @@ def make_arg_max_tests(zip_path):
"input_shape": [[1, 1, 1, 3], [2, 3, 4, 5], [2, 3, 3], [5, 5], [10]],
"output_type": [tf.int32, tf.int64],
"axis_is_last_dim": [True, False],
+ "is_arg_max": [True],
}]
def build_graph(parameters):
@@ -2123,7 +2245,10 @@ def make_arg_max_tests(zip_path):
axis = len(parameters["input_shape"]) - 1
else:
axis = random.randint(0, max(len(parameters["input_shape"]) - 2, 0))
- out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ if parameters["is_arg_max"]:
+ out = tf.arg_max(input_value, axis, output_type=parameters["output_type"])
+ else:
+ out = tf.arg_min(input_value, axis, output_type=parameters["output_type"])
return [input_value], [out]
def build_inputs(parameters, sess, inputs, outputs):
@@ -2135,6 +2260,74 @@ def make_arg_max_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_equal_tests(zip_path):
+ """Make a set of tests to do equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the equal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_not_equal_tests(zip_path):
+ """Make a set of tests to do not equal."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+ ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+ ([5, 5], [1]), ([10], [2, 4, 10])],
+ }]
+
+ def build_graph(parameters):
+ """Build the not euqal op testing graph."""
+ input_value1 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape_pair"][0])
+ input_value2 = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input2",
+ shape=parameters["input_shape_pair"][1])
+ out = tf.not_equal(input_value1, input_value2)
+ return [input_value1, input_value2], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value1 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][0])
+ input_value2 = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape_pair"][1])
+ return [input_value1, input_value2], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_greater_tests(zip_path):
"""Make a set of tests to do greater."""
@@ -2322,30 +2515,54 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def _make_elementwise_tests(op):
+ """Make a set of tests to do element-wise operations."""
+
+ def f(zip_path):
+ """Actual function that generates examples."""
+ test_parameters = [{
+ "input_dtype": [tf.float32],
+ "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ """Build the unary op testing graph."""
+ input_value = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input1",
+ shape=parameters["input_shape"])
+ out = op(input_value)
+ return [input_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict={inputs[0]: input_value})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+ return f
+
+
def make_sin_tests(zip_path):
"""Make a set of tests to do sin."""
+ return _make_elementwise_tests(tf.sin)(zip_path)
- test_parameters = [{
- "input_dtype": [tf.float32],
- "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
- }]
- def build_graph(parameters):
- """Build the sin op testing graph."""
- input_value = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input1",
- shape=parameters["input_shape"])
- out = tf.sin(input_value)
- return [input_value], [out]
+def make_log_tests(zip_path):
+ """Make a set of tests to do log."""
+ return _make_elementwise_tests(tf.log)(zip_path)
- def build_inputs(parameters, sess, inputs, outputs):
- input_value = create_tensor_data(parameters["input_dtype"],
- parameters["input_shape"])
- return [input_value], sess.run(
- outputs, feed_dict={inputs[0]: input_value})
- make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_sqrt_tests(zip_path):
+ """Make a set of tests to do sqrt."""
+ return _make_elementwise_tests(tf.sqrt)(zip_path)
+
+
+def make_rsqrt_tests(zip_path):
+ """Make a set of tests to do 1/sqrt."""
+ return _make_elementwise_tests(tf.rsqrt)(zip_path)
def make_where_tests(zip_path):
@@ -2499,6 +2716,72 @@ def make_transpose_conv_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_tile_tests(zip_path):
+ """Make a set of tests to do tile."""
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32],
+ "input_shape": [[3, 2, 1], [2, 2, 2]],
+ "multiplier_dtype": [tf.int32, tf.int64],
+ "multiplier_shape": [[3]]
+ }]
+
+ def build_graph(parameters):
+ """Build the tile op testing graph."""
+ input_value = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ shape=parameters["input_shape"],
+ name="input")
+ multiplier_value = tf.placeholder(
+ dtype=parameters["multiplier_dtype"],
+ shape=parameters["multiplier_shape"],
+ name="multiplier")
+ out = tf.tile(input_value, multiplier_value)
+ return [input_value, multiplier_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ multipliers_value = create_tensor_data(parameters["multiplier_dtype"],
+ parameters["multiplier_shape"])
+ return [input_value, multipliers_value], sess.run(
+ outputs,
+ feed_dict={
+ inputs[0]: input_value,
+ inputs[1]: multipliers_value
+ })
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
+def make_expand_dims_tests(zip_path):
+ """Make a set of tests to do expand_dims."""
+
+ test_parameters = [{
+ "input_type": [tf.float32, tf.int32],
+ "input_shape": [[3, 4], [10, 10, 3]],
+ "axis_value": [0, 1, 2, -1, -2],
+ }]
+
+ def build_graph(parameters):
+ """Build the where op testing graph."""
+ input_value = tf.placeholder(
+ dtype=parameters["input_type"],
+ name="input",
+ shape=parameters["input_shape"])
+ axis_value = tf.placeholder(dtype=tf.int32, name="axis", shape=[1])
+ out = tf.expand_dims(input_value, axis=axis_value)
+ return [input_value, axis_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(parameters["input_type"],
+ parameters["input_shape"])
+ axis_value = np.array([parameters["axis_value"]], dtype=np.int32)
+ return [input_value, axis_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value, axis_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_sparse_to_dense_tests(zip_path):
"""Make a set of tests to do sparse to dense."""
@@ -2560,6 +2843,7 @@ def make_sparse_to_dense_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc
index c0c861ff6d..c1092e4d25 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.cc
+++ b/tensorflow/contrib/lite/testing/generate_testspec.cc
@@ -25,7 +25,7 @@ namespace testing {
template <typename T>
void GenerateCsv(const std::vector<int>& shape, float min, float max,
string* out) {
- auto random_float = [](int min, int max) {
+ auto random_float = [](float min, float max) {
static unsigned int seed;
return min + (max - min) * static_cast<float>(rand_r(&seed)) / RAND_MAX;
};
@@ -37,16 +37,10 @@ void GenerateCsv(const std::vector<int>& shape, float min, float max,
*out = Join(data.data(), data.size(), ",");
}
-bool GenerateTestSpecFromTensorflowModel(
- std::iostream& stream, const string& tensorflow_model_path,
- const string& tflite_model_path, const std::vector<string>& input_layer,
+std::vector<string> GenerateInputValues(
+ const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
- const std::vector<string>& input_layer_shape,
- const std::vector<string>& output_layer) {
- CHECK_EQ(input_layer.size(), input_layer_type.size());
- CHECK_EQ(input_layer.size(), input_layer_shape.size());
-
- // Generate inputs.
+ const std::vector<string>& input_layer_shape) {
std::vector<string> input_values;
input_values.resize(input_layer.size());
for (int i = 0; i < input_layer.size(); i++) {
@@ -73,9 +67,22 @@ bool GenerateTestSpecFromTensorflowModel(
default:
fprintf(stderr, "Unsupported type %d (%s) when generating testspec.\n",
type, input_layer_type[i].c_str());
- return false;
+ input_values.clear();
+ return input_values;
}
}
+ return input_values;
+}
+
+bool GenerateTestSpecFromTensorflowModel(
+ std::iostream& stream, const string& tensorflow_model_path,
+ const string& tflite_model_path, int num_invocations,
+ const std::vector<string>& input_layer,
+ const std::vector<string>& input_layer_type,
+ const std::vector<string>& input_layer_shape,
+ const std::vector<string>& output_layer) {
+ CHECK_EQ(input_layer.size(), input_layer_type.size());
+ CHECK_EQ(input_layer.size(), input_layer_shape.size());
// Invoke tensorflow model.
TfDriver runner(input_layer, input_layer_type, input_layer_shape,
@@ -91,39 +98,51 @@ bool GenerateTestSpecFromTensorflowModel(
return false;
}
- for (int i = 0; i < input_values.size(); i++) {
- runner.SetInput(i, input_values[i]);
- if (!runner.IsValid()) {
- cerr << runner.GetErrorMessage() << endl;
- return false;
- }
- }
-
- runner.Invoke();
- if (!runner.IsValid()) {
- cerr << runner.GetErrorMessage() << endl;
- return false;
- }
-
- // Write test spec.
+ // Write first part of test spec, defining model and input shapes.
stream << "load_model: " << tflite_model_path << "\n";
stream << "reshape {\n";
for (const auto& shape : input_layer_shape) {
stream << " input: \"" << shape << "\"\n";
}
stream << "}\n";
- stream << "invoke {\n";
- for (const auto& value : input_values) {
- stream << " input: \"" << value << "\"\n";
- }
- for (int i = 0; i < output_layer.size(); i++) {
- stream << " output: \"" << runner.ReadOutput(i) << "\"\n";
+
+ // Generate inputs.
+ for (int i = 0; i < num_invocations; ++i) {
+ // Note that the input values are random, so each invocation will have a
+ // different set.
+ std::vector<string> input_values =
+ GenerateInputValues(input_layer, input_layer_type, input_layer_shape);
+ if (input_values.empty()) return false;
+
+ // Run TensorFlow.
+ for (int j = 0; j < input_values.size(); j++) {
+ runner.SetInput(j, input_values[j]);
+ if (!runner.IsValid()) {
+ cerr << runner.GetErrorMessage() << endl;
+ return false;
+ }
+ }
+
+ runner.Invoke();
if (!runner.IsValid()) {
cerr << runner.GetErrorMessage() << endl;
return false;
}
+
+ // Write second part of test spec, with inputs and outputs.
+ stream << "invoke {\n";
+ for (const auto& value : input_values) {
+ stream << " input: \"" << value << "\"\n";
+ }
+ for (int j = 0; j < output_layer.size(); j++) {
+ stream << " output: \"" << runner.ReadOutput(j) << "\"\n";
+ if (!runner.IsValid()) {
+ cerr << runner.GetErrorMessage() << endl;
+ return false;
+ }
+ }
+ stream << "}\n";
}
- stream << "}\n";
return true;
}
diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h
index 6e31a853c3..bfaf5e7ec8 100644
--- a/tensorflow/contrib/lite/testing/generate_testspec.h
+++ b/tensorflow/contrib/lite/testing/generate_testspec.h
@@ -30,13 +30,15 @@ namespace testing {
// stream: mutable iostream that contains the contents of test spec.
// tensorflow_model_path: path to TensorFlow model.
// tflite_model_path: path to tflite_model_path that the test spec runs
+// num_invocations: how many pairs of inputs and outputs will be generated.
// against. input_layer: names of input tensors. Example: input1
// input_layer_type: datatypes of input tensors. Example: float
// input_layer_shape: shapes of input tensors, separated by comma. example:
// 1,3,4 output_layer: names of output tensors. Example: output
bool GenerateTestSpecFromTensorflowModel(
std::iostream& stream, const string& tensorflow_model_path,
- const string& tflite_model_path, const std::vector<string>& input_layer,
+ const string& tflite_model_path, int num_invocations,
+ const std::vector<string>& input_layer,
const std::vector<string>& input_layer_type,
const std::vector<string>& input_layer_shape,
const std::vector<string>& output_layer);
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 2f069ff8e7..5bc6b53416 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -36,7 +36,13 @@ bool FLAGS_ignore_known_bugs = true;
// TODO(b/71769302) zip_files_dir should have a more accurate default, if
// possible
string* FLAGS_zip_file_path = new string("./");
+#ifndef __ANDROID__
string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip");
+#else
+string* FLAGS_unzip_binary_path = new string("/system/bin/unzip");
+#endif
+bool FLAGS_use_nnapi = false;
+bool FLAGS_ignore_unsupported_nnapi = false;
} // namespace
// TensorFlow system environment for file system called.
@@ -47,9 +53,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- // Add only supports float32. (and "constant" tests use Add)
- {R"(^\/adda.*int32)", "68808744"},
- {R"(^\/constant.*int32)", "68808744"},
{R"(^\/mul.*int32)", "68808744"},
{R"(^\/div.*int32)", "68808744"},
{R"(^\/sub.*int32)", "68808744"},
@@ -61,25 +64,25 @@ std::map<string, string> kBrokenTests = {
"70527055"},
// L2Norm only supports tensors with 4D or fewer.
- {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
+ {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
// SpaceToBatchND only supports 4D tensors.
{R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"},
// L2Norm only works for dim=-1.
- {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
- {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
- {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
- {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
- {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])",
+ {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"},
+ {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"},
+ {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"},
+ {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])",
"67963812"},
- {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
+ {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"},
// ResizeBilinear looks completely incompatible with Tensorflow
{R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
@@ -94,11 +97,12 @@ std::map<string, string> kBrokenTests = {
{R"(^\/gather.*axis=1)", "76910444"},
// No support for arbitrary dimensions in ArgMax.
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.,.\])",
+ "77546240"},
+ {R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
- {R"(^\/arg_max.*axis_is_last_dim=False.*input_shape=\[.,.\])", "77546240"},
};
// Allows test data to be unzipped into a temporary directory and makes
@@ -212,7 +216,7 @@ TEST_P(OpsTest, RunZipTests) {
std::ifstream tflite_stream(tflite_test_case);
ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case;
- tflite::testing::TfLiteDriver test_driver(/*use_nnapi=*/true);
+ tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi);
test_driver.SetModelBaseDir(tflite_dir);
string bug_number;
@@ -223,16 +227,21 @@ TEST_P(OpsTest, RunZipTests) {
}
bool result = tflite::testing::ParseAndRunTests(&tflite_stream, &test_driver);
+ string message = test_driver.GetErrorMessage();
if (bug_number.empty()) {
- EXPECT_TRUE(result) << test_driver.GetErrorMessage();
+ if (FLAGS_use_nnapi && FLAGS_ignore_unsupported_nnapi && !result) {
+ EXPECT_EQ(message, string("Failed to invoke interpreter")) << message;
+ } else {
+ EXPECT_TRUE(result) << message;
+ }
} else {
if (FLAGS_ignore_known_bugs) {
EXPECT_FALSE(result) << "Test was expected to fail but is now passing; "
"you can mark http://b/"
<< bug_number << " as fixed! Yay!";
} else {
- EXPECT_TRUE(result) << test_driver.GetErrorMessage()
- << ": Possibly due to http://b/" << bug_number;
+ EXPECT_TRUE(result) << message << ": Possibly due to http://b/"
+ << bug_number;
}
}
}
@@ -273,7 +282,13 @@ int main(int argc, char** argv) {
"Required: Location of the test zip file."),
tensorflow::Flag("unzip_binary_path",
tflite::testing::FLAGS_unzip_binary_path,
- "Required: Location of a suitable unzip binary.")};
+ "Required: Location of a suitable unzip binary."),
+ tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi,
+ "Whether to enable the NNAPI delegate"),
+ tensorflow::Flag("ignore_unsupported_nnapi",
+ &tflite::testing::FLAGS_ignore_unsupported_nnapi,
+ "Don't fail tests just because delegation to NNAPI "
+ "is not possible")};
bool success = tensorflow::Flags::Parse(&argc, argv, flags);
if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) {
fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str());
@@ -281,6 +296,8 @@ int main(int argc, char** argv) {
}
::tflite::LogToStderr();
+ // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags
+ // parser removes them?
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
index 5afa0f800c..f2c49fe389 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc
@@ -20,12 +20,29 @@ int main(int argc, char** argv) {
::tflite::testing::DiffOptions options =
::tflite::testing::ParseTfliteDiffFlags(&argc, argv);
if (options.tensorflow_model.empty()) return 1;
+
int failure_count = 0;
- for (int i = 0; i < 100; i++) {
- if (!tflite::testing::RunDiffTest(options)) {
+ for (int i = 0; i < options.num_runs_per_pass; i++) {
+ if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/1)) {
++failure_count;
}
}
- fprintf(stderr, "Num errors: %d\n", failure_count);
+ int failures_in_first_pass = failure_count;
+
+ if (failure_count == 0) {
+ // Let's try again with num_invocations > 1 to make sure we can do multiple
+ // invocations without resetting the interpreter.
+ for (int i = 0; i < options.num_runs_per_pass; i++) {
+ if (!tflite::testing::RunDiffTest(options, /*num_invocations=*/2)) {
+ ++failure_count;
+ }
+ }
+ }
+
+ fprintf(stderr, "Num errors in single-inference pass: %d\n",
+ failures_in_first_pass);
+ fprintf(stderr, "Num errors in multi-inference pass : %d\n",
+ failure_count - failures_in_first_pass);
+
return failure_count != 0 ? 1 : 0;
}
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
index 706108ed73..7a57e8d3fb 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_flags.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h
@@ -30,6 +30,7 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
string input_layer_type;
string input_layer_shape;
string output_layer;
+ int32_t num_runs_per_pass = 100;
} values;
std::vector<tensorflow::Flag> flags = {
@@ -49,6 +50,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
tensorflow::Flag("output_layer", &values.output_layer,
"Names of output tensors, separated by comma. Example "
"output_1,output_2"),
+ tensorflow::Flag("num_runs_per_pass", &values.num_runs_per_pass,
+ "Number of full runs in each pass."),
};
bool no_inputs = *argc == 1;
@@ -63,7 +66,8 @@ DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) {
Split<string>(values.input_layer, ","),
Split<string>(values.input_layer_type, ","),
Split<string>(values.input_layer_shape, ":"),
- Split<string>(values.output_layer, ",")};
+ Split<string>(values.output_layer, ","),
+ values.num_runs_per_pass};
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
index f601d3752d..19f34c0a51 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.cc
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc
@@ -25,13 +25,14 @@ limitations under the License.
namespace tflite {
namespace testing {
-bool RunDiffTest(const DiffOptions& options) {
+bool RunDiffTest(const DiffOptions& options, int num_invocations) {
std::stringstream tflite_stream;
if (!GenerateTestSpecFromTensorflowModel(
tflite_stream, options.tensorflow_model, options.tflite_model,
- options.input_layer, options.input_layer_type,
- options.input_layer_shape, options.output_layer))
+ num_invocations, options.input_layer, options.input_layer_type,
+ options.input_layer_shape, options.output_layer)) {
return false;
+ }
TfLiteDriver tflite_driver(/*use_nnapi=*/true);
tflite_driver.LoadModel(options.tflite_model);
return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver);
diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h
index 326fa6c3e2..4ab2f230fd 100644
--- a/tensorflow/contrib/lite/testing/tflite_diff_util.h
+++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h
@@ -40,10 +40,14 @@ struct DiffOptions {
// Names of output tensors.
// Example output_1,output_2
std::vector<string> output_layer;
+ // Number of full runs (from building interpreter to checking outputs) in
+ // each of the passes. The first pass has a single inference, while the
+ // second pass does multiple inferences back to back.
+ int num_runs_per_pass;
};
// Run a single TensorFLow Lite diff test with a given options.
-bool RunDiffTest(const DiffOptions& options);
+bool RunDiffTest(const DiffOptions& options, int num_invocations);
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 8cab6cd8cd..4d08fb5458 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <iostream>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/testing/split.h"
namespace tflite {
@@ -162,6 +163,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) {
Invalidate("Failed build interpreter");
return;
}
+ interpreter_->UseNNAPI(use_nnapi_);
must_allocate_tensors_ = true;
}
@@ -283,19 +285,26 @@ bool TfLiteDriver::CheckResults() {
}
void TfLiteDriver::ResetLSTMStateTensors() {
- // This is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Refactoring and find a better way to initialize state
- // tensors. Maybe write the reset instructions into the test data.
+ interpreter_->ResetVariableTensorsToZero();
+
+ // Below is a workaround for initializing state tensors for LSTM.
+ // TODO(ycling): Remove the code below after nobody is using the 18-inputs
+ // definition.
for (auto node_index : interpreter_->execution_plan()) {
const auto& node_and_reg = interpreter_->node_and_registration(node_index);
const auto& node = node_and_reg->first;
const auto& registration = node_and_reg->second;
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
- node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
+
+ if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
+ const auto* params =
+ reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
+ if (params->kernel_type == kTfLiteLSTMFullKernel &&
+ node.inputs->size == 18 && node.outputs->size >= 2) {
+ // The first 2 outputs of LSTM are state tensors.
+ for (int i = 0; i < 2; ++i) {
+ int node_index = node.outputs->data[i];
+ ResetTensor(node_index);
+ }
}
}
}
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index b8acc9a8e0..209dce56cb 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -143,7 +143,6 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
- "//tensorflow/cc/saved_model:tag_constants",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
@@ -170,41 +169,6 @@ cc_library(
)
cc_library(
- name = "toco_saved_model",
- srcs = [
- "toco_saved_model.cc",
- ],
- hdrs = [
- "toco_saved_model.h",
- ],
- visibility = ["//visibility:public"],
- deps = [
- ":model_cmdline_flags",
- ":model_flags_proto_cc",
- ":toco_flags_proto_cc",
- ":types_proto_cc",
- "//tensorflow/cc/tools:freeze_saved_model",
- "//tensorflow/core:protos_all_cc",
- "@com_google_absl//absl/strings",
- ],
-)
-
-tf_cc_test(
- name = "toco_saved_model_test",
- srcs = ["toco_saved_model_test.cc"],
- deps = [
- ":model_cmdline_flags",
- ":toco_cmdline_flags",
- ":toco_saved_model",
- "//tensorflow/cc:cc_ops",
- "//tensorflow/cc:scope",
- "//tensorflow/core:test",
- "@com_google_absl//absl/strings",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-cc_library(
name = "graph_transformations",
srcs = [
"graph_transformations/convert_expanddims_to_reshape.cc",
@@ -213,6 +177,7 @@ cc_library(
"graph_transformations/convert_squeeze_to_reshape.cc",
"graph_transformations/convert_trivial_addn_to_add.cc",
"graph_transformations/convert_trivial_stack_to_reshape.cc",
+ "graph_transformations/convert_trivial_tile_to_concat.cc",
"graph_transformations/convert_trivial_transpose_to_reshape.cc",
"graph_transformations/create_im2col_arrays.cc",
"graph_transformations/dequantize.cc",
@@ -220,10 +185,10 @@ cc_library(
"graph_transformations/drop_im2col_arrays.cc",
"graph_transformations/ensure_bias_vectors.cc",
"graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc",
- "graph_transformations/experimental_shuffle_fc_weights.cc",
"graph_transformations/fuse_activation_functions.cc",
"graph_transformations/fuse_binary_into_following_affine.cc",
"graph_transformations/fuse_binary_into_preceding_affine.cc",
+ "graph_transformations/fuse_broadcast_into_following_binary.cc",
"graph_transformations/graph_transformations.cc",
"graph_transformations/hardcode_min_max.cc",
"graph_transformations/identify_dilated_conv.cc",
@@ -237,6 +202,7 @@ cc_library(
"graph_transformations/lstm_utils.cc",
"graph_transformations/make_initial_dequantize_operator.cc",
"graph_transformations/merge_reshape_into_preceding_transpose.cc",
+ "graph_transformations/move_binary_operator_before_reshape.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
"graph_transformations/propagate_default_min_max.cc",
@@ -245,6 +211,7 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
+ "graph_transformations/quantize_weights.cc",
"graph_transformations/read_fake_quant_min_max.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
@@ -292,8 +259,8 @@ cc_library(
"graph_transformations/resolve_tensorflow_matmul.cc",
"graph_transformations/resolve_tensorflow_merge.cc",
"graph_transformations/resolve_tensorflow_switch.cc",
- "graph_transformations/resolve_tensorflow_tile.cc",
"graph_transformations/resolve_transpose_attributes.cc",
+ "graph_transformations/shuffle_fc_weights.cc",
"graph_transformations/unfuse_activation_functions.cc",
"graph_transformations/unpartition_embedding_lookup.cc",
"graph_transformations/unroll_batch_matmul.cc",
@@ -373,6 +340,7 @@ tf_cc_test(
":toco_tooling",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -410,6 +378,7 @@ tf_cc_test(
deps = [
":model",
":tooling_util",
+ "//tensorflow/core:lib",
"@com_google_googletest//:gtest_main",
],
)
@@ -427,7 +396,6 @@ tf_cc_binary(
":toco_cmdline_flags",
":toco_flags_proto_cc",
":toco_port",
- ":toco_saved_model",
":toco_tooling",
":types_proto_cc",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/toco/README.md b/tensorflow/contrib/lite/toco/README.md
index ee83c7a6e3..2db6a627ab 100644
--- a/tensorflow/contrib/lite/toco/README.md
+++ b/tensorflow/contrib/lite/toco/README.md
@@ -17,11 +17,12 @@ Usage information is given in these documents:
Once an application developer has a trained TensorFlow model, TOCO will accept
that model and generate a TensorFlow Lite
[FlatBuffer](https://google.github.io/flatbuffers/) file. TOCO currently supports
-[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
-and frozen graphs (models generated via
-[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)).
-The TensorFlow Lite FlatBuffer file can be shipped to client devices, generally
-mobile devices, where the TensorFlow Lite interpreter handles them on-device.
-This flow is represented in the diagram below.
+[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators),
+frozen graphs (models generated via
+[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)),
+and `tf.Keras` model files. The TensorFlow Lite FlatBuffer file can be shipped
+to client devices, generally mobile devices, where the TensorFlow Lite
+interpreter handles them on-device. This flow is represented in the diagram
+below.
![drawing](g3doc/toco_landscape.svg)
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 6c0311af0a..aef35ad490 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -21,13 +21,13 @@ limitations under the License.
#include <functional>
#include <unordered_map>
#include <vector>
+#include "tensorflow/contrib/lite/toco/toco_port.h"
#if defined(PLATFORM_GOOGLE)
#include "strings/split.h"
+#include "strings/strip.h"
#endif
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
-#include "tensorflow/cc/saved_model/tag_constants.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
namespace toco {
@@ -145,8 +145,10 @@ class Arg<toco::StringMapList> final {
}
string outer_member_copy = outer_member;
absl::StripAsciiWhitespace(&outer_member);
- if (!TryStripPrefixString(outer_member, "{", &outer_member)) return false;
- if (!TryStripSuffixString(outer_member, "}", &outer_member)) return false;
+ if (!strings::TryStripPrefixString(outer_member, "{", &outer_member))
+ return false;
+ if (!strings::TryStripSuffixString(outer_member, "}", &outer_member))
+ return false;
const std::vector<string> inner_fields_vector =
absl::StrSplit(outer_member, ',');
@@ -223,7 +225,7 @@ struct ParsedTocoFlags {
Arg<string> output_file;
Arg<string> input_format = Arg<string>("TENSORFLOW_GRAPHDEF");
Arg<string> output_format = Arg<string>("TFLITE");
- Arg<string> savedmodel_tagset = Arg<string>(tensorflow::kSavedModelTagServe);
+ Arg<string> savedmodel_tagset;
// TODO(aselle): command_line_flags doesn't support doubles
Arg<float> default_ranges_min = Arg<float>(0.);
Arg<float> default_ranges_max = Arg<float>(0.);
@@ -234,6 +236,7 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
+ Arg<bool> quantize_weights = Arg<bool>(false);
// Deprecated flags
Arg<string> input_type;
Arg<string> input_types;
@@ -242,6 +245,7 @@ struct ParsedTocoFlags {
Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
+ Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index 3aeebb14f1..6877fb237c 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -132,6 +132,12 @@ void AppendArrayVal(string* string, Array const& array, int index) {
return;
}
AppendF(string, "%d", data[index]);
+ } else if (array.buffer->type == ArrayDataType::kBool) {
+ const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
+ if (index >= data.size()) {
+ return;
+ }
+ AppendF(string, "%d", data[index]);
}
}
@@ -140,6 +146,7 @@ NodeProperties GetPropertiesForArray(const Model& model,
NodeProperties node_properties;
node_properties.color = GetColorForArray(model, array_name);
node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}});
+ node_properties.log2_buffer_size = 0.0f;
// Append array shape to the label.
auto& array = model.GetArray(array_name);
@@ -159,9 +166,12 @@ NodeProperties GetPropertiesForArray(const Model& model,
}
node_properties.label += "]";
- int buffer_size = RequiredBufferSizeForShape(array.shape());
- node_properties.log2_buffer_size =
- std::log2(static_cast<float>(buffer_size));
+ int buffer_size = 0;
+ if (IsValid(array.shape())) {
+ buffer_size = RequiredBufferSizeForShape(array.shape());
+ node_properties.log2_buffer_size =
+ std::log2(static_cast<float>(buffer_size));
+ }
if (array.buffer) {
const auto& array = model.GetArray(array_name);
@@ -194,8 +204,6 @@ NodeProperties GetPropertiesForArray(const Model& model,
AppendF(&node_properties.label, "}");
}
}
- } else {
- node_properties.log2_buffer_size = 0.0f;
}
if (array.minmax) {
@@ -219,7 +227,7 @@ NodeProperties GetPropertiesForArray(const Model& model,
NodeProperties GetPropertiesForOperator(const Operator& op) {
NodeProperties node_properties;
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
node_properties.label =
static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op;
} else {
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 99f0c81a1b..a08cdbfba6 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -145,7 +145,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -162,7 +162,7 @@ void ConvertFloatTensorConst(const string& name, const Shape& input_shape,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -178,7 +178,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -199,7 +199,7 @@ void ConvertFloatTensorConst(const Model& model, const string& name,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -222,7 +222,7 @@ void ConvertIntTensorConst(const Model& model, const string& name,
}
CHECK(model.HasArray(name));
const auto& array = model.GetArray(name);
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -245,7 +245,7 @@ void CreateIntTensorConst(const string& name, const std::vector<int32>& data,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -268,7 +268,7 @@ void CreateMatrixShapeTensorConst(const string& name, int rows, int cols,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -286,7 +286,7 @@ void CreateDummyConcatDimTensorConst(const string& name, int dim,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -301,7 +301,7 @@ void CreateReshapeShapeTensorConst(const string& name,
if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
return;
}
- auto* const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
const_op->set_op("Const");
const_op->set_name(name);
(*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -341,7 +341,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
conv_output += "/conv";
}
- auto* conv2d_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
conv2d_op->set_op("Conv2D");
conv2d_op->set_name(conv_output);
*conv2d_op->add_input() = src_op.inputs[0];
@@ -377,7 +377,7 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
(*conv2d_op->mutable_attr())["padding"].set_s(padding);
if (has_bias) {
- auto* biasadd_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
biasadd_op->set_op("BiasAdd");
biasadd_op->set_name(src_op.outputs[0]);
biasadd_op->add_input(conv_output);
@@ -409,7 +409,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
conv_output += "/conv";
}
- auto* dc2d_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
dc2d_op->set_op("DepthwiseConv2dNative");
dc2d_op->set_name(conv_output);
*dc2d_op->add_input() = src_op.inputs[0];
@@ -457,7 +457,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
(*dc2d_op->mutable_attr())["padding"].set_s(padding);
if (has_bias) {
- auto* biasadd_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
biasadd_op->set_op("BiasAdd");
biasadd_op->set_name(src_op.outputs[0]);
biasadd_op->add_input(conv_output);
@@ -482,7 +482,7 @@ void ConvertDepthwiseConvOperator(const Model& model,
void ConvertTransposeConvOperator(const Model& model,
const TransposeConvOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* conv2d_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
conv2d_op->set_op("Conv2DBackpropInput");
conv2d_op->set_name(src_op.outputs[0]);
*conv2d_op->add_input() = src_op.inputs[0];
@@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model,
const auto& weights_array = model.GetArray(weights_array_name);
CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
- AxesOrder::kHWIO, tensorflow_graph);
+ AxesOrder::kHWOI, tensorflow_graph);
auto& strides = (*conv2d_op->mutable_attr())["strides"];
strides.mutable_list()->add_i(1);
strides.mutable_list()->add_i(src_op.stride_height);
@@ -514,7 +514,7 @@ void ConvertTransposeConvOperator(const Model& model,
void ConvertDepthToSpaceOperator(const Model& model,
const DepthToSpaceOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* op = tensorflow_graph->add_node();
op->set_op("DepthToSpace");
op->set_name(src_op.outputs[0]);
*op->add_input() = src_op.inputs[0];
@@ -525,7 +525,7 @@ void ConvertDepthToSpaceOperator(const Model& model,
void ConvertSpaceToDepthOperator(const Model& model,
const SpaceToDepthOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* op = tensorflow_graph->add_node();
op->set_op("SpaceToDepth");
op->set_name(src_op.outputs[0]);
*op->add_input() = src_op.inputs[0];
@@ -546,7 +546,7 @@ void ConvertFullyConnectedOperator(const Model& model,
CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
tensorflow_graph);
- auto* reshape_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
reshape_op->set_op("Reshape");
reshape_op->set_name(reshape_output);
reshape_op->add_input(src_op.inputs[0]);
@@ -568,7 +568,7 @@ void ConvertFullyConnectedOperator(const Model& model,
const string transpose_perm =
AvailableArrayName(model, transpose_output + "/perm");
CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
- auto transpose_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
transpose_op->set_op("Transpose");
transpose_op->set_name(transpose_output);
*transpose_op->add_input() = src_op.inputs[1];
@@ -577,7 +577,7 @@ void ConvertFullyConnectedOperator(const Model& model,
GetTensorFlowDataType(model, src_op.inputs[1]));
(*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);
- auto* matmul_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
matmul_op->set_op("MatMul");
matmul_op->set_name(matmul_output);
*matmul_op->add_input() = reshape_output;
@@ -590,7 +590,7 @@ void ConvertFullyConnectedOperator(const Model& model,
// Add the bias, if it exists.
if (has_bias) {
- auto* biasadd_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
biasadd_op->set_op("BiasAdd");
biasadd_op->set_name(src_op.outputs[0]);
biasadd_op->add_input(matmul_output);
@@ -615,7 +615,7 @@ void ConvertFullyConnectedOperator(const Model& model,
void ConvertAddOperator(const Model& model, const AddOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* add_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
add_op->set_op("Add");
add_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -626,7 +626,7 @@ void ConvertAddOperator(const Model& model, const AddOperator& src_op,
void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* add_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
add_op->set_op("AddN");
add_op->set_name(src_op.outputs[0]);
for (const auto& input : src_op.inputs) {
@@ -638,7 +638,7 @@ void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
void ConvertMulOperator(const Model& model, const MulOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* add_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
add_op->set_op("Mul");
add_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -649,7 +649,7 @@ void ConvertMulOperator(const Model& model, const MulOperator& src_op,
void ConvertReluOperator(const ReluOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* relu_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
relu_op->set_op("Relu");
relu_op->set_name(src_op.outputs[0]);
*relu_op->add_input() = src_op.inputs[0];
@@ -662,7 +662,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op,
const string min_bounds = src_op.outputs[0] + "/min_bounds";
const string max_output = src_op.outputs[0] + "/max_output";
- auto* max_bounds_const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
max_bounds_const_op->set_op("Const");
max_bounds_const_op->set_name(max_bounds);
(*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -671,7 +671,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op,
max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
max_bounds_const_op_tensor->add_float_val(-1.0f);
- auto* min_bounds_const_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
min_bounds_const_op->set_op("Const");
min_bounds_const_op->set_name(min_bounds);
(*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
@@ -680,14 +680,14 @@ void ConvertRelu1Operator(const Relu1Operator& src_op,
min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
min_bounds_const_op_tensor->add_float_val(1.0f);
- auto* max_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
max_op->set_op("Maximum");
max_op->set_name(max_output);
*max_op->add_input() = src_op.inputs[0];
*max_op->add_input() = max_bounds;
(*max_op->mutable_attr())["T"].set_type(DT_FLOAT);
- auto* min_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
min_op->set_op("Minimum");
min_op->set_name(src_op.outputs[0]);
*min_op->add_input() = max_output;
@@ -697,7 +697,7 @@ void ConvertRelu1Operator(const Relu1Operator& src_op,
void ConvertRelu6Operator(const Relu6Operator& src_op,
GraphDef* tensorflow_graph) {
- auto* relu_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
relu_op->set_op("Relu6");
relu_op->set_name(src_op.outputs[0]);
*relu_op->add_input() = src_op.inputs[0];
@@ -705,7 +705,7 @@ void ConvertRelu6Operator(const Relu6Operator& src_op,
}
void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
- auto* op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* op = tensorflow_graph->add_node();
op->set_op("Log");
op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -715,7 +715,7 @@ void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
void ConvertLogisticOperator(const LogisticOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* relu_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
relu_op->set_op("Sigmoid");
relu_op->set_name(src_op.outputs[0]);
*relu_op->add_input() = src_op.inputs[0];
@@ -724,7 +724,7 @@ void ConvertLogisticOperator(const LogisticOperator& src_op,
void ConvertTanhOperator(const TanhOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* tanh_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
tanh_op->set_op("Tanh");
tanh_op->set_name(src_op.outputs[0]);
*tanh_op->add_input() = src_op.inputs[0];
@@ -735,8 +735,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
GraphDef* tensorflow_graph) {
string softmax_input;
Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
- if (providing_op != nullptr &&
- providing_op->type == OperatorType::kTensorFlowReshape) {
+ if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
softmax_input = src_op.inputs[0];
} else {
// Insert a reshape operator that reduces the dimensions down to the 2 that
@@ -745,7 +744,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
const string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
softmax_input = reshape_output;
- auto* reshape_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
reshape_op->set_op("Reshape");
reshape_op->set_name(reshape_output);
*reshape_op->add_input() = src_op.inputs[0];
@@ -762,7 +761,7 @@ void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
}
- auto* softmax_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
softmax_op->set_op("Softmax");
softmax_op->set_name(src_op.outputs[0]);
*softmax_op->add_input() = softmax_input;
@@ -776,8 +775,7 @@ void ConvertLogSoftmaxOperator(const Model& model,
GraphDef* tensorflow_graph) {
string softmax_input;
Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
- if (providing_op != nullptr &&
- providing_op->type == OperatorType::kTensorFlowReshape) {
+ if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
softmax_input = src_op.inputs[0];
} else {
// Insert a reshape operator that reduces the dimensions down to the 2 that
@@ -787,7 +785,7 @@ void ConvertLogSoftmaxOperator(const Model& model,
const string softmax_size = src_op.outputs[0] + "/log_softmax_insert_size";
softmax_input = reshape_output;
- auto* reshape_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
reshape_op->set_op("Reshape");
reshape_op->set_name(reshape_output);
*reshape_op->add_input() = src_op.inputs[0];
@@ -804,7 +802,7 @@ void ConvertLogSoftmaxOperator(const Model& model,
CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
}
- auto* log_softmax_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
log_softmax_op->set_op("LogSoftmax");
log_softmax_op->set_name(src_op.outputs[0]);
*log_softmax_op->add_input() = softmax_input;
@@ -819,7 +817,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
const string rsqrt_output = src_op.outputs[0] + "/rsqrt";
const string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";
- auto* sum_reduction_indices_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
sum_reduction_indices_op->set_op("Const");
sum_reduction_indices_op->set_name(sum_reduction_indices);
(*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -833,26 +831,26 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
sum_reduction_indices_tensor->add_int_val(0);
sum_reduction_indices_tensor->add_int_val(1);
- auto* square_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
square_op->set_op("Square");
square_op->set_name(square_output);
*square_op->add_input() = src_op.inputs[0];
(*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
- auto* sum_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
sum_op->set_op("Sum");
sum_op->set_name(sum_output);
*sum_op->add_input() = square_output;
*sum_op->add_input() = sum_reduction_indices;
(*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);
- auto* rsqrt_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
rsqrt_op->set_op("Rsqrt");
rsqrt_op->set_name(rsqrt_output);
*rsqrt_op->add_input() = sum_output;
(*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
- auto* mul_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
mul_op->set_op("Mul");
mul_op->set_name(src_op.outputs[0]);
*mul_op->add_input() = src_op.inputs[0];
@@ -863,7 +861,7 @@ void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
void ConvertLocalResponseNormalizationOperator(
const LocalResponseNormalizationOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* lrn_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
lrn_op->set_op("LRN");
lrn_op->set_name(src_op.outputs[0]);
*lrn_op->add_input() = src_op.inputs[0];
@@ -875,7 +873,7 @@ void ConvertLocalResponseNormalizationOperator(
void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* fakequant_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
fakequant_op->set_op("FakeQuantWithMinMaxArgs");
fakequant_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -890,7 +888,7 @@ void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* maxpool_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
maxpool_op->set_op("MaxPool");
maxpool_op->set_name(src_op.outputs[0]);
*maxpool_op->add_input() = src_op.inputs[0];
@@ -918,7 +916,7 @@ void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* avgpool_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
avgpool_op->set_op("AvgPool");
avgpool_op->set_name(src_op.outputs[0]);
*avgpool_op->add_input() = src_op.inputs[0];
@@ -947,7 +945,7 @@ void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
void ConvertConcatenationOperator(const Model& model,
const ConcatenationOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* dc_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
dc_op->set_op("ConcatV2");
dc_op->set_name(src_op.outputs[0]);
const string dummy_axis = src_op.outputs[0] + "/axis";
@@ -965,7 +963,7 @@ void ConvertConcatenationOperator(const Model& model,
void ConvertTensorFlowReshapeOperator(const Model& model,
const TensorFlowReshapeOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* reshape_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
reshape_op->set_op("Reshape");
reshape_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -987,7 +985,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op,
const string square_output = src_op.outputs[0] + "/square";
const string avgpool_output = src_op.outputs[0] + "/avgpool";
- auto* square_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
square_op->set_op("Square");
square_op->set_name(square_output);
*square_op->add_input() = src_op.inputs[0];
@@ -1002,7 +1000,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
- auto* avgpool_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
avgpool_op->set_op("AvgPool");
avgpool_op->set_name(avgpool_output);
*avgpool_op->add_input() = square_output;
@@ -1020,7 +1018,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op,
ksize.mutable_list()->add_i(src_op.kwidth);
ksize.mutable_list()->add_i(1);
- auto* sqrt_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
sqrt_op->set_op("Sqrt");
sqrt_op->set_name(src_op.outputs[0]);
*sqrt_op->add_input() = avgpool_output;
@@ -1029,7 +1027,7 @@ void ConvertL2PoolOperator(const L2PoolOperator& src_op,
void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* square_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
square_op->set_op("Square");
square_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1039,7 +1037,7 @@ void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* sqrt_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
sqrt_op->set_op("Sqrt");
sqrt_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1047,10 +1045,23 @@ void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
(*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
+void ConvertRsqrtOperator(const Model& model,
+ const TensorFlowRsqrtOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
+ rsqrt_op->set_op("Rsqrt");
+ rsqrt_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 1);
+ *rsqrt_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertSplitOperator(const Model& model,
const TensorFlowSplitOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* split_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
split_op->set_op("Split");
split_op->set_name(src_op.outputs[0]);
for (const auto& input : src_op.inputs) {
@@ -1071,7 +1082,7 @@ void ConvertSplitOperator(const Model& model,
void ConvertCastOperator(const Model& model, const CastOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* cast_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
cast_op->set_op("Cast");
cast_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1085,7 +1096,7 @@ void ConvertCastOperator(const Model& model, const CastOperator& src_op,
void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* floor_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
floor_op->set_op("Floor");
floor_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1095,7 +1106,7 @@ void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* gather_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
gather_op->set_op("Gather");
gather_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1103,13 +1114,14 @@ void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
*gather_op->add_input() = src_op.inputs[1];
(*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*gather_op->mutable_attr())["Tparams"].set_type(params_type);
}
void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* argmax_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
argmax_op->set_op("ArgMax");
argmax_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1123,10 +1135,26 @@ void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
GetTensorFlowDataType(model, src_op.outputs[0]));
}
+void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
+ argmin_op->set_op("ArgMin");
+ argmin_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *argmin_op->add_input() = src_op.inputs[0];
+ *argmin_op->add_input() = src_op.inputs[1];
+ (*argmin_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[0]));
+ (*argmin_op->mutable_attr())["Tidx"].set_type(
+ GetTensorFlowDataType(model, src_op.inputs[1]));
+ (*argmin_op->mutable_attr())["output_type"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+}
+
void ConvertTransposeOperator(const Model& model,
const TransposeOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* transpose_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
transpose_op->set_op("Transpose");
transpose_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1141,7 +1169,7 @@ void ConvertTransposeOperator(const Model& model,
void ConvertTensorFlowShapeOperator(const Model& model,
const TensorFlowShapeOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* shape_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
shape_op->set_op("Shape");
shape_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1154,7 +1182,7 @@ void ConvertTensorFlowShapeOperator(const Model& model,
void ConvertRankOperator(const Model& model, const RankOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* rank_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
rank_op->set_op("Rank");
rank_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
@@ -1165,7 +1193,7 @@ void ConvertRankOperator(const Model& model, const RankOperator& src_op,
void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* range_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
range_op->set_op("Range");
range_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
@@ -1178,7 +1206,7 @@ void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
void ConvertStackOperator(const Model& model, const StackOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* stack_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* stack_op = tensorflow_graph->add_node();
stack_op->set_op("Stack");
stack_op->set_name(src_op.outputs[0]);
for (const auto& input : src_op.inputs) {
@@ -1191,7 +1219,7 @@ void ConvertStackOperator(const Model& model, const StackOperator& src_op,
void ConvertFillOperator(const Model& model, const FillOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* fill_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
fill_op->set_op("Fill");
fill_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1205,7 +1233,7 @@ void ConvertFillOperator(const Model& model, const FillOperator& src_op,
void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* floor_div_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
floor_div_op->set_op("FloorDiv");
floor_div_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1218,7 +1246,7 @@ void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
void ConvertExpandDimsOperator(const Model& model,
const ExpandDimsOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* expand_dims_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
expand_dims_op->set_op("ExpandDims");
expand_dims_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1233,7 +1261,7 @@ void ConvertExpandDimsOperator(const Model& model,
void ConvertResizeBilinearOperator(const Model& model,
const ResizeBilinearOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* resize_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
resize_op->set_op("ResizeBilinear");
resize_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1283,7 +1311,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// works the same since the tensor has the same underlying data layout.
const string axis_output = concat_output + "/axis";
CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
- auto* concat_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
concat_op->set_op("ConcatV2");
concat_op->set_name(concat_output);
*concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
@@ -1311,7 +1339,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Fully connected matrix multiply
const string matmul_output = base + "MatMul";
- auto* matmul_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
matmul_op->set_op("MatMul");
matmul_op->set_name(matmul_output);
*matmul_op->add_input() = concat_output;
@@ -1340,7 +1368,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Add biases
string biasadd_output = base + "BiasAdd";
- auto* biasadd_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
biasadd_op->set_op("BiasAdd");
biasadd_op->set_name(biasadd_output);
biasadd_op->add_input(matmul_output);
@@ -1353,7 +1381,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// The dimension is the same as the concatenation dimension
CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
string split_output = base + "split";
- auto* split_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
split_op->set_op("Split");
split_op->set_name(split_output);
*split_op->add_input() = split_dim_output;
@@ -1363,21 +1391,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
// Activation functions and memory computations
const string tanh_0_output = base + "Tanh";
- auto* tanh_0_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
tanh_0_op->set_op("Tanh");
tanh_0_op->set_name(tanh_0_output);
*tanh_0_op->add_input() = split_output + ":1";
(*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string sigmoid_1_output = base + "Sigmoid_1";
- auto* logistic_1_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
logistic_1_op->set_op("Sigmoid");
logistic_1_op->set_name(sigmoid_1_output);
*logistic_1_op->add_input() = split_output;
(*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string mul_1_output = base + "mul_1";
- auto* mul_1_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
mul_1_op->set_op("Mul");
mul_1_op->set_name(mul_1_output);
*mul_1_op->add_input() = sigmoid_1_output;
@@ -1385,21 +1413,21 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
(*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string sigmoid_0_output = base + "Sigmoid";
- auto* logistic_2_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
logistic_2_op->set_op("Sigmoid");
logistic_2_op->set_name(sigmoid_0_output);
*logistic_2_op->add_input() = split_output + ":2";
(*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string sigmoid_2_output = base + "Sigmoid_2";
- auto* logistic_3_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
logistic_3_op->set_op("Sigmoid");
logistic_3_op->set_name(sigmoid_2_output);
*logistic_3_op->add_input() = split_output + ":3";
(*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string mul_0_output = base + "mul";
- auto* mul_0_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
mul_0_op->set_op("Mul");
mul_0_op->set_name(mul_0_output);
*mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
@@ -1407,7 +1435,7 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
(*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string add_1_output = src_op.outputs[LstmCellOperator::STATE_OUTPUT];
- auto* add_1_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
add_1_op->set_op("Add");
add_1_op->set_name(add_1_output);
*add_1_op->add_input() = mul_0_output;
@@ -1415,14 +1443,14 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
(*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string tanh_1_output = base + "Tanh_1";
- auto* tanh_1_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
tanh_1_op->set_op("Tanh");
tanh_1_op->set_name(tanh_1_output);
*tanh_1_op->add_input() = add_1_output;
(*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);
const string mul_2_output = src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
- auto* mul_2_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
mul_2_op->set_op("Mul");
mul_2_op->set_name(mul_2_output);
*mul_2_op->add_input() = tanh_1_output;
@@ -1433,14 +1461,15 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
void ConvertSpaceToBatchNDOperator(const Model& model,
const SpaceToBatchNDOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("SpaceToBatchND");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
*new_op->add_input() = src_op.inputs[2];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
(*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
(*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
@@ -1449,14 +1478,15 @@ void ConvertSpaceToBatchNDOperator(const Model& model,
void ConvertBatchToSpaceNDOperator(const Model& model,
const BatchToSpaceNDOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("BatchToSpaceND");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
*new_op->add_input() = src_op.inputs[2];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
(*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
(*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
@@ -1464,18 +1494,19 @@ void ConvertBatchToSpaceNDOperator(const Model& model,
void ConvertPadOperator(const Model& model, const PadOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("Pad");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
// Create the params tensor.
- auto* params_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
params_op->set_op("Const");
params_op->set_name(src_op.inputs[1]);
(*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -1494,7 +1525,7 @@ void ConvertPadOperator(const Model& model, const PadOperator& src_op,
void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("PadV2");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1502,11 +1533,12 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
*new_op->add_input() = src_op.inputs[1];
*new_op->add_input() = src_op.inputs[2];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
// Create the params tensor.
- auto* params_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
params_op->set_op("Const");
params_op->set_name(src_op.inputs[1]);
(*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -1525,7 +1557,7 @@ void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
void CreateSliceInput(const string& input_name, const std::vector<int>& values,
GraphDef* tensorflow_graph) {
- auto* params_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
params_op->set_op("Const");
params_op->set_name(input_name);
(*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -1542,7 +1574,7 @@ void CreateSliceInput(const string& input_name, const std::vector<int>& values,
void ConvertStridedSliceOperator(const Model& model,
const StridedSliceOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("StridedSlice");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 4);
@@ -1551,7 +1583,8 @@ void ConvertStridedSliceOperator(const Model& model,
*new_op->add_input() = src_op.inputs[2];
*new_op->add_input() = src_op.inputs[3];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
(*new_op->mutable_attr())["Index"].set_type(DT_INT32);
@@ -1569,7 +1602,7 @@ void ConvertStridedSliceOperator(const Model& model,
void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("Slice");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
@@ -1577,7 +1610,8 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
*new_op->add_input() = src_op.inputs[1];
*new_op->add_input() = src_op.inputs[2];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
(*new_op->mutable_attr())["Index"].set_type(DT_INT32);
@@ -1588,14 +1622,15 @@ void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("Mean");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
if (src_op.keep_dims) {
@@ -1603,7 +1638,7 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
}
// Create the params tensor.
- auto* params_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
params_op->set_op("Const");
params_op->set_name(src_op.inputs[1]);
(*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
@@ -1619,13 +1654,14 @@ void ConvertMeanOperator(const Model& model, const MeanOperator& src_op,
void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("Squeeze");
new_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 1);
*new_op->add_input() = src_op.inputs[0];
- const auto params_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(params_type);
if (!src_op.squeeze_dims.empty()) {
@@ -1638,58 +1674,79 @@ void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
void ConvertSubOperator(const Model& model, const SubOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* sub_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
sub_op->set_op("Sub");
sub_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*sub_op->add_input() = src_op.inputs[0];
*sub_op->add_input() = src_op.inputs[1];
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTensorFlowMinimumOperator(const Model& model,
const TensorFlowMinimumOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* sub_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
sub_op->set_op("Minimum");
sub_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*sub_op->add_input() = src_op.inputs[0];
*sub_op->add_input() = src_op.inputs[1];
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertTensorFlowMaximumOperator(const Model& model,
const TensorFlowMaximumOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* sub_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
sub_op->set_op("Maximum");
sub_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*sub_op->add_input() = src_op.inputs[0];
*sub_op->add_input() = src_op.inputs[1];
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
GraphDef* tensorflow_graph) {
- auto* sub_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
sub_op->set_op("Select");
sub_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 3);
*sub_op->add_input() = src_op.inputs[0];
*sub_op->add_input() = src_op.inputs[1];
*sub_op->add_input() = src_op.inputs[2];
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[1]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
(*sub_op->mutable_attr())["T"].set_type(data_type);
}
+void ConvertTileOperator(const Model& model,
+ const TensorFlowTileOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
+ tile_op->set_op("Tile");
+ tile_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *tile_op->add_input() = src_op.inputs[0];
+ *tile_op->add_input() = src_op.inputs[1];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*tile_op->mutable_attr())["T"].set_type(data_type);
+ const tensorflow::DataType multiples_data_type =
+ GetTensorFlowDataType(model, src_op.inputs[1]);
+ (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
+}
+
void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
GraphDef* tensorflow_graph) {
- auto* topk_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
topk_op->set_op("TOPKV2");
topk_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
@@ -1702,12 +1759,13 @@ void ConvertRandomUniformOperator(const Model& model,
const RandomUniformOperator& src_op,
GraphDef* tensorflow_graph) {
CHECK(tensorflow_graph != nullptr);
- auto* new_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
new_op->set_op("RandomUniform");
CHECK_EQ(src_op.inputs.size(), 1);
new_op->set_name(src_op.outputs[0]);
*new_op->add_input() = src_op.inputs[0];
- const auto shape_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType shape_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(shape_type);
(*new_op->mutable_attr())["dtype"].set_type(
GetTensorFlowDataType(src_op.dtype));
@@ -1718,13 +1776,14 @@ void ConvertRandomUniformOperator(const Model& model,
void ConvertComparisonOperator(const Model& model, const Operator& src_op,
const char* op_name,
GraphDef* tensorflow_graph) {
- auto* comparison_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
comparison_op->set_op(op_name);
comparison_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 2);
*comparison_op->add_input() = src_op.inputs[0];
*comparison_op->add_input() = src_op.inputs[1];
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*comparison_op->mutable_attr())["T"].set_type(data_type);
}
@@ -1732,21 +1791,37 @@ void ConvertSparseToDenseOperator(const Model& model,
const SparseToDenseOperator& src_op,
const char* op_name,
GraphDef* tensorflow_graph) {
- auto* sparse_to_dense_op = tensorflow_graph->add_node();
+ tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
sparse_to_dense_op->set_op(op_name);
sparse_to_dense_op->set_name(src_op.outputs[0]);
CHECK_EQ(src_op.inputs.size(), 4);
for (int i = 0; i < 4; ++i) {
*sparse_to_dense_op->add_input() = src_op.inputs[i];
}
- const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]);
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[3]);
(*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
- const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]);
+ const tensorflow::DataType index_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
(*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
(*sparse_to_dense_op->mutable_attr())["Tindices"].set_b(
src_op.validate_indices);
}
+void ConvertPowOperator(const Model& model, const PowOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
+ pow_op->set_op(op_name);
+ pow_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *pow_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*pow_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1827,20 +1902,24 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertConcatenationOperator(
model, static_cast<const ConcatenationOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowReshape) {
+ } else if (src_op.type == OperatorType::kReshape) {
ConvertTensorFlowReshapeOperator(
model, static_cast<const TensorFlowReshapeOperator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kL2Pool) {
ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSquare) {
+ } else if (src_op.type == OperatorType::kSquare) {
ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSqrt) {
+ } else if (src_op.type == OperatorType::kSqrt) {
ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowSplit) {
+ } else if (src_op.type == OperatorType::kRsqrt) {
+ ConvertRsqrtOperator(model,
+ static_cast<const TensorFlowRsqrtOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kSplit) {
ConvertSplitOperator(model,
static_cast<const TensorFlowSplitOperator&>(src_op),
tensorflow_graph);
@@ -1884,11 +1963,11 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kSub) {
ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowMinimum) {
+ } else if (src_op.type == OperatorType::kMinimum) {
ConvertTensorFlowMinimumOperator(
model, static_cast<const TensorFlowMinimumOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowMaximum) {
+ } else if (src_op.type == OperatorType::kMaximum) {
ConvertTensorFlowMaximumOperator(
model, static_cast<const TensorFlowMaximumOperator&>(src_op),
tensorflow_graph);
@@ -1901,13 +1980,16 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kArgMax) {
ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kArgMin) {
+ ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
+ tensorflow_graph);
} else if (src_op.type == OperatorType::kTopK_V2) {
ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
tensorflow_graph);
} else if (src_op.type == OperatorType::kTranspose) {
ConvertTransposeOperator(
model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowShape) {
+ } else if (src_op.type == OperatorType::kShape) {
ConvertTensorFlowShapeOperator(
model, static_cast<const TensorFlowShapeOperator&>(src_op),
tensorflow_graph);
@@ -1938,17 +2020,28 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertRandomUniformOperator(
model, static_cast<const RandomUniformOperator&>(src_op),
tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowGreater) {
+ } else if (src_op.type == OperatorType::kEqual) {
+ ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kNotEqual) {
+ ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kGreater) {
ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) {
+ } else if (src_op.type == OperatorType::kGreaterEqual) {
ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowLess) {
+ } else if (src_op.type == OperatorType::kLess) {
ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
- } else if (src_op.type == OperatorType::kTensorFlowLessEqual) {
+ } else if (src_op.type == OperatorType::kLessEqual) {
ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
} else if (src_op.type == OperatorType::kSelect) {
ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kTile) {
+ ConvertTileOperator(model,
+ static_cast<const TensorFlowTileOperator&>(src_op),
+ tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPow) {
+ ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
@@ -1956,7 +2049,7 @@ void ConvertOperator(const Model& model, const Operator& src_op,
void AddPlaceholder(const string& name, ArrayDataType type,
GraphDef* tensorflow_graph) {
- auto* placeholder = tensorflow_graph->add_node();
+ tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
placeholder->set_op("Placeholder");
switch (type) {
case ArrayDataType::kBool:
@@ -1985,7 +2078,7 @@ void AddPlaceholder(const string& name, ArrayDataType type,
void AddPlaceholderForRNNState(const Model& model, const string& name, int size,
GraphDef* tensorflow_graph) {
- auto* placeholder = tensorflow_graph->add_node();
+ tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
placeholder->set_op("Placeholder");
placeholder->set_name(name);
(*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 7680cdd344..18b7848db8 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -9,59 +9,56 @@ complemented by the following documents:
Table of contents:
-* [Convert a TensorFlow SavedModel to TensorFlow Lite](#savedmodel)
-* [Convert a TensorFlow GraphDef to TensorFlow Lite for float
- inference](#graphdef-float)
+* [Command-line tools](#tools)
+ * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9)
+* [Basic examples](#basic)
+ * [Convert a TensorFlow GraphDef](#graphdef)
+ * [Convert a TensorFlow SavedModel](#savedmodel)
+ * [Convert a tf.keras model](#keras)
* [Quantization](#quantization)
- * [Convert a TensorFlow GraphDef to TensorFlow Lite for quantized
- inference](#graphdef-quant)
+ * [Convert a TensorFlow GraphDef for quantized inference](#graphdef-quant)
* [Use "dummy-quantization" to try out quantized inference on a float
graph](#dummy-quant)
* [Specifying input and output arrays](#specifying-input-and-output-arrays)
- * [Multiple output arrays](#multiple-output-arrays)
* [Multiple input arrays](#multiple-input-arrays)
+ * [Multiple output arrays](#multiple-output-arrays)
* [Specifying subgraphs](#specifying-subgraphs)
-* [Other conversions supported by TOCO](#other-conversions)
- * [Optimize a TensorFlow GraphDef](#optimize-graphdef)
- * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef
- format](#to-graphdef)
-* [Logging](#logging)
- * [Standard logging](#standard-logging)
- * [Verbose logging](#verbose-logging)
- * [Graph "video" logging](#graph-video-logging)
* [Graph visualizations](#graph-visualizations)
* [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot)
* [Using --dump_graphviz](#using-dump-graphviz)
+ * [Graph "video" logging](#graph-video-logging)
* [Legend for the graph visualizations](#graphviz-legend)
-## Convert a TensorFlow SavedModel to TensorFlow Lite <a name="savedmodel"></a>
+## Command-line tools <a name="tools"></a>
-The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite
-FlatBuffer to perform floating-point inference.
+There are two approaches to running TOCO via command line.
-```
-bazel run --config=opt \
- third_party/tensorflow/contrib/lite/toco:toco -- \
- --savedmodel_directory=/tmp/saved_model \
- --output_file=/tmp/foo.tflite
-```
+* `tflite_convert`: Starting from TensorFlow 1.9, the command-line tool
+ `tflite_convert` will be installed as part of the Python package. All of the
+ examples below use `tflite_convert` for simplicity.
+ * Example: `tflite --output_file=...`
+* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow
+ repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository)
+ and use `bazel`. This is the recommended approach for converting models that
+ utilize new features that were not supported by TOCO in TensorFlow 1.9.
+ * Example: `bazel run
+ //tensorflow/contrib/lite/python:tflite_convert --
+ --output_file=...`
-[SavedModel](https://www.tensorflow.org/programmers_guide/saved_model#using_savedmodel_with_estimators)
-has fewer required flags than frozen graphs (described [below](#graphdef-float))
-due to access to additional data contained within the SavedModel. The values for
-`--input_arrays` and `--output_arrays` are an aggregated, alphabetized list of
-the inputs and outputs in the
-[SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within the
-[MetaGraphDef](https://www.tensorflow.org/programmers_guide/saved_model#apis_to_build_and_load_a_savedmodel)
-specified by `--savedmodel_tagset`. The value for `input_shapes` is
-automatically determined from the MetaGraphDef whenever possible. The default
-value for `--inference_type` for SavedModels is `FLOAT`.
+### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
-There is currently no support for MetaGraphDefs without a SignatureDef or for
-MetaGraphDefs that use the [`assets/`
-directory](https://www.tensorflow.org/programmers_guide/saved_model#structure_of_a_savedmodel_directory).
+The recommended approach for using TOCO prior to TensorFlow 1.9 is the [Python
+API](python_api.md#pre-tensorflow-1.9). If a command line tool is desired, the
+`toco` command line tool was available in TensorFlow 1.7. Enter `toco --help` in
+Terminal for additional details on the command-line flags available. There were
+no command line tools in TensorFlow 1.8.
+
+## Basic examples <a name="basic"></a>
-## Convert a TensorFlow GraphDef to TensorFlow Lite for float inference <a name="graphdef-float"></a>
+The following section shows examples of how to convert a basic float-point model
+from each of the supported data formats into a TensorFlow Lite FlatBuffers.
+
+### Convert a TensorFlow GraphDef <a name="graphdef"></a>
The follow example converts a basic TensorFlow GraphDef (frozen by
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py))
@@ -71,19 +68,54 @@ graphs contain the variables stored in Checkpoint files as Const ops.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+tflite_convert \
--output_file=/tmp/foo.tflite \
- --inference_type=FLOAT \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1
+ --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+ --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1
+```
+
+The value for `input_shapes` is automatically determined whenever possible.
+
+### Convert a TensorFlow SavedModel <a name="savedmodel"></a>
+
+The follow example converts a basic TensorFlow SavedModel into a Tensorflow Lite
+FlatBuffer to perform floating-point inference.
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --saved_model_dir=/tmp/saved_model
+```
+
+[SavedModel](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators)
+has fewer required flags than frozen graphs due to access to additional data
+contained within the SavedModel. The values for `--input_arrays` and
+`--output_arrays` are an aggregated, alphabetized list of the inputs and outputs
+in the [SignatureDefs](https://www.tensorflow.org/serving/signature_defs) within
+the
+[MetaGraphDef](https://www.tensorflow.org/guide/saved_model#apis_to_build_and_load_a_savedmodel)
+specified by `--saved_model_tag_set`. As with the GraphDef, the value for
+`input_shapes` is automatically determined whenever possible.
+
+There is currently no support for MetaGraphDefs without a SignatureDef or for
+MetaGraphDefs that use the [`assets/`
+directory](https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory).
+
+### Convert a tf.Keras model <a name="keras"></a>
+
+The following example converts a `tf.keras` model into a TensorFlow Lite
+Flatbuffer. The `tf.keras` file must contain both the model and the weights.
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --keras_model_file=/tmp/keras_model.h5
```
## Quantization
-### Convert a TensorFlow GraphDef to TensorFlow Lite for quantized inference <a name="graphdef-quant"></a>
+### Convert a TensorFlow GraphDef for quantized inference <a name="graphdef-quant"></a>
TOCO is compatible with fixed point quantization models described
[here](https://www.tensorflow.org/performance/quantization). These are float
@@ -97,18 +129,14 @@ The following command generates a quantized TensorFlow Lite FlatBuffer from a
"quantized" TensorFlow GraphDef.
```
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/some_quantized_graph.pb \
+tflite_convert \
--output_file=/tmp/foo.tflite \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
+ --graph_def_file=/tmp/some_quantized_graph.pb \
--inference_type=QUANTIZED_UINT8 \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1 \
- --mean_value=128 \
- --std_value=127
+ --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1 \
+ --mean_values=128 \
+ --std_dev_values=127
```
### Use \"dummy-quantization\" to try out quantized inference on a float graph <a name="dummy-quant"></a>
@@ -126,45 +154,20 @@ a reasonable guess is that most activation ranges should be contained in [0, 6].
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+tflite_convert \
--output_file=/tmp/foo.cc \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
+ --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
--inference_type=QUANTIZED_UINT8 \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1 \
+ --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1 \
--default_ranges_min=0 \
--default_ranges_max=6 \
- --mean_value=127.5 \
- --std_value=127.5
+ --mean_values=128 \
+ --std_dev_values=127
```
## Specifying input and output arrays
-### Multiple output arrays
-
-The flag `output_arrays` takes in a comma-separated list of output arrays as
-seen in the example below. This is useful for models or subgraphs with multiple
-outputs.
-
-```
-curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
- | tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
- --output_file=/tmp/foo.tflite \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --inference_type=FLOAT \
- --input_shape=1,224,224,3 \
- --input_array=input \
- --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu
-```
-
### Multiple input arrays
The flag `input_arrays` takes in a comma-separated list of input arrays as seen
@@ -174,21 +177,33 @@ inputs.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+tflite_convert \
+ --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
--output_file=/tmp/foo.tflite \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --inference_type=FLOAT \
--input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
--input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
- --output_array=InceptionV1/Logits/Predictions/Reshape_1
+ --output_arrays=InceptionV1/Logits/Predictions/Reshape_1
```
Note that `input_shapes` is provided as a colon-separated list. Each input shape
corresponds to the input array at the same position in the respective list.
+### Multiple output arrays
+
+The flag `output_arrays` takes in a comma-separated list of output arrays as
+seen in the example below. This is useful for models or subgraphs with multiple
+outputs.
+
+```
+curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
+ | tar xzv -C /tmp
+tflite_convert \
+ --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+ --output_file=/tmp/foo.tflite \
+ --input_arrays=input \
+ --output_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu
+```
+
### Specifying subgraphs
Any array in the input file can be specified as an input or output array in
@@ -203,158 +218,57 @@ GraphDef.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/inception_v1_2016_08_28_frozen.pb \
+tflite_convert \
+ --graph_def_file=/tmp/inception_v1_2016_08_28_frozen.pb \
--output_file=/tmp/foo.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TENSORFLOW_GRAPHDEF \
--input_shapes=1,28,28,96:1,28,28,16:1,28,28,192:1,28,28,64 \
--input_arrays=InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_2/Conv2d_0a_1x1/Relu,InceptionV1/InceptionV1/Mixed_3b/Branch_3/MaxPool_0a_3x3/MaxPool,InceptionV1/InceptionV1/Mixed_3b/Branch_0/Conv2d_0a_1x1/Relu \
- --output_array=InceptionV1/InceptionV1/Mixed_3b/concat_v2
+ --output_arrays=InceptionV1/InceptionV1/Mixed_3b/concat_v2
```
-Note that the final representation of an on-device inference workload (say, in
-TensorFlow Lite FlatBuffers format) tends to have coarser granularity than the
-very fine granularity of the TensorFlow GraphDef representation. For example,
-while a fully-connected layer is typically represented as at least four separate
-ops in TensorFlow GraphDef (Reshape, MatMul, BiasAdd, Relu...), it is typically
-represented as a single "fused" op (FullyConnected) in the converter's optimized
-representation and in the final on-device representation (e.g. in TensorFlow
-Lite FlatBuffer format). As the level of granularity gets coarser, some
+Note that the final representation in TensorFlow Lite FlatBuffers tends to have
+coarser granularity than the very fine granularity of the TensorFlow GraphDef
+representation. For example, while a fully-connected layer is typically
+represented as at least four separate ops in TensorFlow GraphDef (Reshape,
+MatMul, BiasAdd, Relu...), it is typically represented as a single "fused" op
+(FullyConnected) in the converter's optimized representation and in the final
+on-device representation. As the level of granularity gets coarser, some
intermediate arrays (say, the array between the MatMul and the BiasAdd in the
-TensorFlow GraphDef) are dropped. When specifying intermediate arrays as
-`--input_arrays` / `--output_arrays`, it is desirable (and often required) to
-specify arrays that are meant to survive in the final form of the graph, after
-fusing. These are typically the outputs of activation functions (since
-everything in each layer until the activation function tends to get fused).
-
-## Other conversions supported by TOCO <a name="other-conversions"></a>
+TensorFlow GraphDef) are dropped.
-The converter accepts both TENSORFLOW_GRAPHDEF and TFLITE file formats as both
-`--input_format` and `--output_format`. This means that conversion to and from
-any supported format is possible.
-
-### Optimize a TensorFlow GraphDef <a name="optimize-graphdef"></a>
-
-Same-format "conversions" can be used to optimize and simplify a graph or be
-used to [get a subgraph](#specifying-subgraphs) of a graph. The flag
-`--inference_type` is not required because TensorFlow graphs, including those
-containing the
-[`FakeQuant*`](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization)
-ops are always float graphs.
-
-```
-curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
- | tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
- --output_file=/tmp/foo.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TENSORFLOW_GRAPHDEF \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1
-```
-
-### Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format <a name="to-graphdef"></a>
-
-The converter supports file format conversions from TensorFlow Lite, back into
-TensorFlow GraphDef format.
-
-```
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/foo.tflite \
- --output_file=/tmp/foo.pb \
- --input_format=TFLITE \
- --output_format=TENSORFLOW_GRAPHDEF \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1
-```
+When specifying intermediate arrays as `--input_arrays` and `--output_arrays`,
+it is desirable (and often required) to specify arrays that are meant to survive
+in the final form of the graph, after fusing. These are typically the outputs of
+activation functions (since everything in each layer until the activation
+function tends to get fused).
## Logging
-### Standard logging
-
-The converter generates some informative log messages during processing. The
-easiest way to view them is to add `--logtostderr` to command lines as seen in
-the following example.
-
-```
-curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
- | tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
- --output_file=/tmp/foo.tflite \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --inference_type=FLOAT \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1 \
- --logtostderr
-```
-
-After some initialization messages, we get the following informative messages:
-
-```
-I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized)
-I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized)
-I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized)
-I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes.
-I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops).
-```
-
-### Verbose logging
-
-For debugging purposes, the converter supports two levels of verbose logging,
-which can be set by passing a `--v=` flag:
-
-* For `--v=1`, the converter generates text dumps of the graph at various
- points during processing as well as log messages about every graph
- transformation that took place.
-* For `--v=2`, the converter additionally generates log messages about graph
- transformations that were considered but not performed.
-
-### Graph "video" logging
-
-When `--dump_graphviz=` is used (see the section on [graph
-visualizations](#graph-visualizations)), one may additionally pass
-`--dump_graphviz_video`, which causes a graph visualization to be dumped after
-each individual graph transformation. This results in thousands of files.
-Typically, one would then bisect into these files to understand when a given
-change was introduced in the graph.
## Graph visualizations
TOCO can export a graph to the GraphViz Dot format for easy visualization via
-either the `--output_format` flag or the `--dump_graphviz` flag. The subsections
-below outline the use cases for each.
+either the `--output_format` flag or the `--dump_graphviz_dir` flag. The
+subsections below outline the use cases for each.
### Using `--output_format=GRAPHVIZ_DOT`
The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into
`--output_format`. This results in a plausible visualization of the graph. This
-reduces the requirements that normally exist during conversion between other
-input and output formats. For example, this may be useful if conversion from
-TENSORFLOW_GRAPHDEF to TFLITE is failing.
+reduces the requirements that exist during conversion between other input and
+output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to
+TFLITE is failing.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+tflite_convert \
+ --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
--output_file=/tmp/foo.dot \
- --input_format=TENSORFLOW_GRAPHDEF \
--output_format=GRAPHVIZ_DOT \
--input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1
+ --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1
```
The resulting `.dot` file can be rendered into a PDF as follows:
@@ -375,49 +289,35 @@ Example PDF files are viewable online in the next section.
### Using `--dump_graphviz`
-The second way to get a graphviz rendering is to pass the `--dump_graphviz=`
+The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir`
flag, specifying a destination directory to dump GraphViz rendering to. Unlike
-the previous approach, this one allows you to keep your real command-line (with
-your real `--output_format` and other flags) unchanged, just appending a
-`--dump_graphviz=` flag to it. This provides a visualization of the actual graph
-during a specific conversion process.
+the previous approach, this one retains the original output format. This
+provides a visualization of the actual graph resulting from a specific
+conversion process.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
| tar xzv -C /tmp
-bazel run --config=opt \
- //tensorflow/contrib/lite/toco:toco -- \
- --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
+tflite_convert \
+ --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
--output_file=/tmp/foo.tflite \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --inference_type=FLOAT \
- --input_shape=1,128,128,3 \
- --input_array=input \
- --output_array=MobilenetV1/Predictions/Reshape_1 \
- --dump_graphviz=/tmp
+ --input_arrays=input \
+ --output_arrays=MobilenetV1/Predictions/Reshape_1 \
+ --dump_graphviz_dir=/tmp
```
-This generates a few files in the destination directory, here `/tmp`. The two
-most important files are:
-
-```
-/tmp/toco_AT_IMPORT.dot
-/tmp/toco_AFTER_TRANSFORMATIONS.dot
-```
-
-`toco_AT_IMPORT.dot` represents the graph as it was imported from
-`--input_file`, before any transformation was applied to it (besides some
-transformations that are applied immediately while importing). This tends to be
-a complex visualization with limited information, but is useful especially in
-situations where a conversion command fails (this file is generated even if the
-conversion subsequently fails).
+This generates a few files in the destination directory. The two most important
+files are `toco_AT_IMPORT.dot` and `/tmp/toco_AFTER_TRANSFORMATIONS.dot`.
+`toco_AT_IMPORT.dot` represents the original graph containing only the
+transformations done at import time. This tends to be a complex visualization
+with limited information about each node. It is useful in situations where a
+conversion command fails.
`toco_AFTER_TRANSFORMATIONS.dot` represents the graph after all transformations
-were applied to it, just before it was exported to the `--output_file`.
-Typically, this is a much smaller graph with more information about each node.
+were applied to it, just before it is exported. Typically, this is a much
+smaller graph with more information about each node.
-Again, these can be rendered to PDFs:
+As before, these can be rendered to PDFs:
```
dot -Tpdf -O /tmp/toco_*.dot
@@ -428,6 +328,14 @@ Sample output files can be seen here:
* [toco_AT_IMPORT.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AT_IMPORT.dot.pdf)
* [toco_AFTER_TRANSFORMATIONS.dot.pdf](https://storage.googleapis.com/download.tensorflow.org/example_images/toco_AFTER_TRANSFORMATIONS.dot.pdf).
+### Graph "video" logging
+
+When `--dump_graphviz_dir` is used, one may additionally pass
+`--dump_graphviz_video`. This causes a graph visualization to be dumped after
+each individual graph transformation, resulting in thousands of files.
+Typically, one would then bisect into these files to understand when a given
+change was introduced in the graph.
+
### Legend for the graph visualizations <a name="graphviz-legend"></a>
* Operators are red square boxes with the following hues of red:
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index 9e99287f82..decc8a45a4 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -1,7 +1,8 @@
# TensorFlow Lite Optimizing Converter command-line glossary
-This page is complete reference of command-line flags. It is complemented by the
-following other documents:
+This page is complete reference of command-line flags used by TOCO's command
+line starting from TensorFlow 1.9 up until the most recent build of TensorFlow.
+It is complemented by the following other documents:
* [README](../README.md)
* [Command-line examples](cmdline_examples.md)
@@ -16,116 +17,81 @@ Table of contents:
## High-level flags
-The following high level flags specify the location of the input and output
+The following high level flags specify the details of the input and output
files. The flag `--output_file` is always required. Additionally, either
-`--input_file` or `--savedmodel_directory` is required.
-
-* `--savedmodel_directory`. Type: string. Specifies the full path to the
- directory containing the SavedModel.
-* `--savedmodel_tagset`. Type: string. Default:
+`--graph_def_file`, `--saved_model_dir` or `--keras_model_file` is required.
+
+* `--output_file`. Type: string. Specifies the full path of the output file.
+* `--graph_def_file`. Type: string. Specifies the full path of the input
+ GraphDef file frozen using
+ [freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py).
+* `--saved_model_dir`. Type: string. Specifies the full path to the directory
+ containing the SavedModel.
+* `--keras_model_file`. Type: string. Specifies the full path of the HDF5 file
+ containing the tf.keras model.
+* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of
+ the output file. Allowed values:
+ * `TFLITE`: TensorFlow Lite FlatBuffer format.
+ * `GRAPHVIZ_DOT`: GraphViz `.dot` format containg a visualization of the
+ graph after graph transformations.
+ * Note that passing `GRAPHVIZ_DOT` to `--output_format` leads to loss
+ of TFLite specific transformations. Therefore, the resulting
+ visualization may not reflect the final set of graph
+ transformations. To get a final visualization with all graph
+ transformations use `--dump_graphviz` instead.
+
+The following flags specify optional parameters when using SavedModels.
+
+* `--saved_model_tag_set`. Type: string. Default:
[kSavedModelTagServe](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/tag_constants.h).
Specifies a comma-separated set of tags identifying the MetaGraphDef within
the SavedModel to analyze. All tags in the tag set must be specified.
-* `--input_file`. Type: string. Specifies the path of the input file. This may
- be either an absolute or a relative path.
-* `--output_file`. Type: string. Specifies the path of the output file.
-
-The following high level flags specify the types of the input and output files:
-
-* `--input_format`. Type: string. Default: `TENSORFLOW_GRAPHDEF`. Specifies
- the format of the input file. Allowed values:
- * `TENSORFLOW_GRAPHDEF` &mdash; The TensorFlow GraphDef format. Both
- binary and text proto formats are allowed.
- * `TFLITE` &mdash; The TensorFlow Lite FlatBuffers format.
-* `--output_format`. Type: string. Default: `TFLITE`. Specifies the format of
- the output file. Allowed values:
- * `TENSORFLOW_GRAPHDEF` &mdash; The TensorFlow GraphDef format. Always
- produces a file in binary (not text) proto format.
- * `TFLITE` &mdash; The TensorFlow Lite FlatBuffers format.
- * Whether a float or quantized TensorFlow Lite file will be produced
- depends on the `--inference_type` flag.
- * `GRAPHVIZ_DOT` &mdash; The GraphViz `.dot` format. This asks the
- converter to generate a reasonable graphical representation of the graph
- after simplification by a generic set of transformation.
- * A typical `dot` command line to view the resulting graph might look
- like: `dot -Tpdf -O file.dot`.
- * Note that since passing this `--output_format` means losing the
- information of which output format you actually care about, and
- since the converter's transformations depend on the specific output
- format, the resulting visualization may not fully reflect what you
- would get on the actual output format that you are using. To avoid
- that concern, and generally to get a visualization of exactly what
- you get in your actual output format as opposed to just a merely
- plausible visualization of a model, consider using `--dump_graphviz`
- instead and keeping your true `--output_format`.
+* `--saved_model_signature_key`. Type: string. Default:
+ [DEFAULT_SERVING_SIGNATURE_DEF_KEY](https://www.tensorflow.org/api_docs/python/tf/saved_model/signature_constants).
+ Specifies the key identifying the SignatureDef containing inputs and
+ outputs.
## Model flags
*Model flags* provide additional information about the model stored in the input
file.
-* `--output_array`. Type: string. Specifies a single array as the output
- activations. Incompatible with `--output_arrays`.
-* `--output_arrays`. Type: comma-separated list of strings. Specifies a list
- of arrays as the output activations, for models with multiple outputs.
- Incompatible with `--output_array`.
-* `--input_array`. Type: string. Specifies a single array as the input
- activations. Incompatible with `--input_arrays`.
-* `--input_arrays`. Type: comma-separated list of strings. Specifies a list of
- arrays as the input activations, for models with multiple inputs.
- Incompatible with `--input_array`.
-* `--batch_size`. Type: integer. Default: 1. Specifies the batch size for the
- model. Replaces the first dimension of an input size array if undefined. Use
- only with SavedModels when neither `--input_shape` nor `input_shapes` flags
- are specified. Incompatible with GraphDefs.
-
-When `--input_array` is used, the following flags are available to provide
-additional information about the single input array:
-
-* `--input_shape`. Type: comma-separated list of integers. Specifies the shape
- of the input array, in TensorFlow convention: starting with the outer-most
- dimension (the dimension corresponding to the largest offset stride in the
- array layout), ending with the inner-most dimension (the dimension along
- which array entries are typically laid out contiguously in memory).
- * For example, a typical vision model might pass
- `--input_shape=1,60,80,3`, meaning a batch size of 1 (no batching), an
- input image height of 60, an input image width of 80, and an input image
- depth of 3, for the typical case where the input image is a RGB bitmap
- (3 channels, depth=3) stored by horizontal scanlines (so 'width' is the
- next innermost dimension after 'depth').
-* `--mean_value` and `--std_value`. Type: floating-point. The decimal point
- character is always the dot (`.`) regardless of the locale. These specify
- the (de-)quantization parameters of the input array, when it is quantized.
- * The meaning of mean_value and std_value is as follows: each quantized
- value in the quantized input array will be interpreted as a mathematical
- real number (i.e. as an input activation value) according to the
- following formula:
+* `--input_arrays`. Type: comma-separated list of strings. Specifies the list
+ of names of input activation tensors.
+* `--output_arrays`. Type: comma-separated list of strings. Specifies the list
+ of names of output activation tensors.
+
+The following flags define properties of the input tensors. Each item in the
+`--input_arrays` flag should correspond to each item in the following flags
+based on index.
+
+* `--input_shapes`. Type: colon-separated list of comma-separated lists of
+ integers. Each comma-separated list of integers gives the shape of one of
+ the input arrays specified in [TensorFlow
+ convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
+ * Example: `--input_shapes=1,60,80,3` for a typical vision model means a
+ batch size of 1, an input image height of 60, an input image width of
+ 80, and an input image depth of 3 (representing RGB channels).
+ * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo"
+ has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
+* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers.
+ These specify the (de-)quantization parameters of the input array, when it
+ is quantized.
+ * The meaning of `mean_values` and `std_dev_values` is as follows: each
+ quantized value in the quantized input array will be interpreted as a
+ mathematical real number (i.e. as an input activation value) according
+ to the following formula:
* `real_value = (quantized_input_value - mean_value) / std_value`.
* When performing float inference (`--inference_type=FLOAT`) on a
quantized input, the quantized input would be immediately dequantized by
the inference code according to the above formula, before proceeding
with float inference.
* When performing quantized inference
- (`--inference_type=QUANTIZED_UINT8`), no dequantization is ever to be
- performed by the inference code; however, the quantization parameters of
- all arrays, including those of the input arrays as specified by
- mean_value and std_value, all participate in the determination of the
- fixed-point multipliers used in the quantized inference code.
-
-When `--input_arrays` is used, the following flags are available to provide
-additional information about the multiple input arrays:
-
-* `--input_shapes`. Type: colon-separated list of comma-separated lists of
- integers. Each comma-separated list of integer gives the shape of one of the
- input arrays specified in `--input_arrays`, in the same order. See
- `--input_shape` for details.
- * Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means that
- there are two input arrays. The first one, "foo", has shape [2,3]. The
- second one, "bar", has shape [4,5,6].
-* `--mean_values`, `--std_values`. Type: comma-separated lists of
- floating-point numbers. Each number gives the corresponding value for one of
- the input arrays specified in `--input_arrays`, in the same order. See
- `--mean_value`, `--std_value` for details.
+ (`--inference_type=QUANTIZED_UINT8`), no dequantization is performed by
+ the inference code. However, the quantization parameters of all arrays,
+ including those of the input arrays as specified by `mean_value` and
+ `std_dev_value`, determine the fixed-point multipliers used in the
+ quantized inference code.
## Transformation flags
@@ -133,21 +99,13 @@ additional information about the multiple input arrays:
the graph, i.e. they specify requested properties that the output file should
have.
-* `--inference_type`. Type: string. Sets the type of real-number arrays in the
- output file, that is, controls the representation (quantization) of real
- numbers in the output file, except for input arrays, which are controlled by
- `--inference_input_type`.
+* `--inference_type`. Type: string. Default: `FLOAT`. Data type of all
+ real-number arrays in the output file except for input arrays (defined by
+ `--inference_input_type`). Must be `{FLOAT, QUANTIZED_UINT8}`.
- This flag only impacts real-number arrays. By "real-number" we mean float
- arrays, and quantized arrays. This excludes plain integer arrays, strings
- arrays, and every other data type.
-
- For real-number arrays, the impact of this flag is to allow the output file
- to choose a different real-numbers representation (quantization) from what
- the input file used. For any other types of arrays, changing the data type
- would not make sense.
-
- Specifically:
+ This flag only impacts real-number arrays including float and quantized
+ arrays. This excludes all other data types including plain integer arrays
+ and string arrays. Specifically:
* If `FLOAT`, then real-numbers arrays will be of type float in the output
file. If they were quantized in the input file, then they get
@@ -155,72 +113,54 @@ have.
* If `QUANTIZED_UINT8`, then real-numbers arrays will be quantized as
uint8 in the output file. If they were float in the input file, then
they get quantized.
- * If not set, then all real-numbers arrays retain the same type in the
- output file as they have in the input file.
-
-* `--inference_input_type`. Type: string. Similar to inference_type, but
- allows to control specifically the quantization of input arrays, separately
- from other arrays.
-
- If not set, then the value of `--inference_type` is implicitly used, i.e. by
- default input arrays are quantized like other arrays.
-
- Like `--inference_type`, this only affects real-number arrays. By
- "real-number" we mean float arrays, and quantized arrays. This excludes
- plain integer arrays, strings arrays, and every other data type.
-
- The typical use for this flag is for vision models taking a bitmap as input,
- typically with uint8 channels, yet still requiring floating-point inference.
- For such image models, the uint8 input is quantized, i.e. the uint8 values
- are interpreted as real numbers, and the quantization parameters used for
- such input arrays are their `mean_value`, `std_value` parameters.
-
-* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point. The
- decimal point character is always the dot (`.`) regardless of the locale.
- These flags enable what is called "dummy quantization". If defined, their
- effect is to define fallback (min, max) range values for all arrays that do
- not have a properly specified (min, max) range in the input file, thus
- allowing to proceed with quantization of non-quantized or
- incorrectly-quantized input files. This enables easy performance prototyping
- ("how fast would my model run if I quantized it?") but should never be used
- in production as the resulting quantized arithmetic is inaccurate.
-
-* `--drop_fake_quant`. Type: boolean. Default: false. Causes fake-quantization
- nodes to be dropped from the graph. This may be used to recover a plain
- float graph from a fake-quantized graph.
-
-* `--reorder_across_fake_quant`. Type: boolean. Default: false. Normally,
- fake-quantization nodes must be strict boundaries for graph transformations,
- in order to ensure that quantized inference has the exact same arithmetic
- behavior as quantized training --- which is the whole point of quantized
- training and of FakeQuant nodes in the first place. However, that entails
- subtle requirements on where exactly FakeQuant nodes must be placed in the
- graph. Some quantized graphs have FakeQuant nodes at unexpected locations,
- that prevent graph transformations that are necessary in order to generate a
- well-formed quantized representation of these graphs. Such graphs should be
- fixed, but as a temporary work-around, setting this
- reorder_across_fake_quant flag allows the converter to perform necessary
- graph transformations on them, at the cost of no longer faithfully matching
- inference and training arithmetic.
+
+* `--inference_input_type`. Type: string. Data type of a real-number input
+ array in the output file. By default the `--inference_type` is used as type
+ of all of the input arrays. Flag is primarily intended for generating a
+ float-point graph with a quantized input array. A Dequantized operator is
+ added immediately after the input array. Must be `{FLOAT, QUANTIZED_UINT8}`.
+
+ The flag is typically used for vision models taking a bitmap as input but
+ requiring floating-point inference. For such image models, the uint8 input
+ is quantized and the quantization parameters used for such input arrays are
+ their `mean_value` and `std_dev_value` parameters.
+
+* `--default_ranges_min`, `--default_ranges_max`. Type: floating-point.
+ Default value for the (min, max) range values used for all arrays without a
+ specified range. Allows user to proceed with quantization of non-quantized
+ or incorrectly-quantized input files. These flags produce models with low
+ accuracy. They are intended for easy experimentation with quantization via
+ "dummy quantization".
+
+* `--drop_control_dependency`. Type: boolean. Default: True. Indicates whether
+ to drop control dependencies silently. This is due to TensorFlow Lite not
+ supporting control dependencies.
+
+* `--reorder_across_fake_quant`. Type: boolean. Default: False. Indicates
+ whether to reorder FakeQuant nodes in unexpected locations. Used when the
+ location of the FakeQuant nodes is preventing graph transformations
+ necessary to convert the graph. Results in a graph that differs from the
+ quantized training graph, potentially causing differing arithmetic behavior.
+
+* `--allow_custom_ops`. Type: string. Default: False. Indicates whether to
+ allow custom operations. When false, any unknown operation is an error. When
+ true, custom ops are created for any op that is unknown. The developer will
+ need to provide these to the TensorFlow Lite runtime with a custom resolver.
+
+* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to
+ store weights as quantized weights followed by dequantize operations.
+ Computation is still done in float, but reduces model size (at the cost of
+ accuracy and latency).
## Logging flags
-The following are standard Google logging flags:
-
-* `--logtostderr` redirects Google logging to standard error, typically making
- it visible in a terminal.
-* `--v` sets verbose logging levels (for debugging purposes). Defined levels:
- * `--v=1`: log all graph transformations that did make a change on the
- graph.
- * `--v=2`: log all graph transformations that did *not* make a change on
- the graph.
-
-The following flags allow to generate graph visualizations of the actual graph
-at various points during transformations:
-
-* `--dump_graphviz=/path` enables dumping of the graphs at various stages of
- processing as GraphViz `.dot` files. Generally preferred over
- `--output_format=GRAPHVIZ_DOT` as this allows you to keep your actually
- relevant `--output_format`.
-* `--dump_graphviz_video` enables dumping of the graph after every single
- graph transformation (for debugging purposes).
+The following flags generate graph visualizations of the graph as
+[GraphViz](https://www.graphviz.org/) `.dot` files at various points during
+graph transformations:
+
+* `--dump_graphviz_dir`. Type: string. Specifies the full path of the
+ directory to output GraphViz `.dot` files. Outputs the graph immediately
+ after reading in the graph and after all of the transformations have been
+ completed.
+* `--dump_graphviz_video`. Type: boolean. Outputs GraphViz after every graph
+ transformation. Requires `--dump_graphviz_dir` to be specified.
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 5071361bfd..3799eac0a1 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -15,11 +15,15 @@ Table of contents:
* [Exporting a GraphDef from tf.Session](#basic-graphdef-sess)
* [Exporting a GraphDef from file](#basic-graphdef-file)
* [Exporting a SavedModel](#basic-savedmodel)
+ * [Exporting a tf.keras File](#basic-keras-file)
* [Complex examples](#complex)
* [Exporting a quantized GraphDef](#complex-quant)
* [TensorFlow Lite Python interpreter](#interpreter)
* [Using the interpreter from a model file](#interpreter-file)
* [Using the interpreter from model data](#interpreter-data)
+* [Additional instructions](#additional-instructions)
+ * [Build from source code](#latest-package)
+ * [Converting models prior to TensorFlow 1.9.](#pre-tensorflow-1.9)
## High-level overview
@@ -31,15 +35,17 @@ designing a model that can be targeted to devices with mobile.
## API
-The API for converting TensorFlow models to TensorFlow Lite is
-`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is
+The API for converting TensorFlow models to TensorFlow Lite as of TensorFlow 1.9
+is `tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is
`tf.contrib.lite.Interpreter`.
`TocoConverter` provides class methods based on the original format of the
model. `TocoConverter.from_session()` is available for GraphDefs.
-`TocoConverter.from_saved_model()` is available for SavedModels. Example usages
-for simple float-point models are shown in [Basic Examples](#basic). Examples
-usages for more complex models is shown in [Complex Examples](#complex).
+`TocoConverter.from_saved_model()` is available for SavedModels.
+`TocoConverter.from_keras_model_file()` is available for `tf.Keras` files.
+Example usages for simple float-point models are shown in [Basic
+Examples](#basic). Examples usages for more complex models is shown in [Complex
+Examples](#complex).
**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python
interpreter when the conversion fails. This will be remedied as soon as
@@ -111,6 +117,51 @@ For more complex SavedModels, the optional parameters that can be passed into
`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are
available by running `help(tf.contrib.lite.TocoConverter)`.
+### Exporting a tf.keras File <a name="basic-keras-file"></a>
+
+The following example shows how to convert a `tf.keras` model into a TensorFlow
+Lite FlatBuffer.
+
+```python
+import tensorflow as tf
+
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file("keras_model.h5")
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
+The `tf.keras` file must contain both the model and the weights. A comprehensive
+example including model construction can be seen below.
+
+```python
+import numpy as np
+import tensorflow as tf
+
+# Generate tf.keras model.
+model = tf.keras.models.Sequential()
+model.add(tf.keras.layers.Dense(2, input_shape=(3,)))
+model.add(tf.keras.layers.RepeatVector(3))
+model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(3)))
+model.compile(loss=tf.keras.losses.MSE,
+ optimizer=tf.keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[tf.keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+
+x = np.random.random((1, 3))
+y = np.random.random((1, 3, 3))
+model.train_on_batch(x, y)
+model.predict(x)
+
+# Save tf.keras model in HDF5 format.
+keras_file = "keras_model.h5"
+tf.keras.models.save_model(model, keras_file)
+
+# Convert to TensorFlow Lite model.
+converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_file)
+tflite_model = converter.convert()
+open("converted_model.tflite", "wb").write(tflite_model)
+```
+
## Complex examples <a name="complex"></a>
For models where the default value of the attributes is not sufficient, the
@@ -138,7 +189,8 @@ out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output")
with tf.Session() as sess:
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8
- converter.quantized_input_stats = {"img" : (0., 1.)} # mean, std_dev
+ input_arrays = converter.get_input_arrays()
+ converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
```
@@ -199,3 +251,18 @@ with tf.Session() as sess:
interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
```
+
+## Additional instructions
+
+### Build from source code <a name="latest-package"></a>
+
+In order to run the latest version of the TOCO Python API, clone the TensorFlow
+repository, configure the installation, and build and install the pip package.
+Detailed instructions are available
+[here](https://www.tensorflow.org/install/install_sources).
+
+### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
+
+To use TOCO in TensorFlow 1.7 and TensorFlow 1.8, use the `toco_convert`
+function. Run `help(tf.contrib.lite.toco_convert)` to get details about accepted
+parameters.
diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
index a47c088991..262e13a591 100644
--- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
+++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
@@ -1 +1 @@
-<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m154.36745 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m154.36745 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m184.89111 339.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m241.86351 334.89435l42.267715 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m241.86351 334.89435l38.840652 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.70413 334.89435l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m78.872284 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m78.872284 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m93.328064 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.63894 87.62236q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m127.74803 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m127.74803 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m147.45874 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m233.1085 268.03217l-66.74016 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m233.10852 268.03217l-63.313095 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m169.79543 268.03217l1.124588 -1.1246033l-3.0897675 1.1246033l3.0897675 1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 19.652092l46.992126 0l0 133.54475" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 19.652084l46.992126 0l0 130.11768" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m122.614174 249.1182l-1.124588 -1.124588l1.124588 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m171.49606 99.34974l0 19.650558l-48.88189 0l0 133.5463" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m171.49606 99.34974l0 19.650558l-48.88189 0l0 130.1192" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m122.614174 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m122.620316 283.52823l0 14.9730835l75.49606 0l0 20.90091" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m122.620316 283.52823l0 14.9730835l75.49608 0l0 17.473846" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m198.1164 315.97516l-1.124588 -1.1246033l1.124588 3.0897827l1.1245728 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 14.9730835l-78.74016 0l0 20.90091" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85654 283.52823l0 14.9730835l-78.74014 0l0 17.473846" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m198.1164 315.97516l-1.124588 -1.1246033l1.124588 3.0897827l1.1245728 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.20152 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m75.62294 283.52823l0 17.950958l100.62993 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62295 283.52823l0 17.950928l100.62992 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.25287 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 17.950958l-100.62991 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85654 283.52823l0 17.950928l-100.62991 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.22662 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#d9ead3" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m146.9475 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
index 0fffab574d..1ea83abf8e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -38,6 +38,16 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
// Depthwise conv does not support dilation
return false;
}
+ auto& input_array = model->GetArray(conv_op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Shapes not propagated yet
+ return false;
+ }
+ if (input_array.shape().dims(3) != 1) {
+ // Not a pure convolution: Conv does accumulation across the depth
+ // dimension.
+ return false;
+ }
auto& weights_array = model->GetArray(conv_op->inputs[1]);
if (!weights_array.buffer) {
// Yield until the weights are resolved as a constant array.
@@ -46,11 +56,6 @@ bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
if (weights_array.data_type != ArrayDataType::kFloat) {
return false;
}
- if (weights_array.shape().dims(3) != 1) {
- // Not a pure convolution: Conv does accumulation across the depth
- // dimension.
- return false;
- }
// At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
AddMessageF(
"%s is purely convolutional (input/weights depth is 1), replacing it by "
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
new file mode 100644
index 0000000000..b689be0792
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertTrivialTileToConcat::Run(Model* model, std::size_t op_index) {
+ auto tile_it = model->operators.begin() + op_index;
+ if (tile_it->get()->type != OperatorType::kTile) {
+ return false;
+ }
+ auto* tile_op = static_cast<TransposeOperator*>(tile_it->get());
+
+ const auto& input_array = model->GetArray(tile_op->inputs[0]);
+ const auto& multiples_array = model->GetArray(tile_op->inputs[1]);
+ const auto& output_array = model->GetArray(tile_op->outputs[0]);
+ if (!input_array.has_shape() || !multiples_array.has_shape() ||
+ !output_array.has_shape()) {
+ // Yield until PropagateFixedSizes has been run on this op.
+ return false;
+ }
+ // Note: We can assume we have error checked inputs in PropagateFixedSizes.
+
+ if (!multiples_array.buffer) {
+ // Yield until the multiples is constant.
+ return false;
+ }
+ std::vector<int32> const& multiples =
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // We can simplify the tile if only a single dimension is being multiplied.
+ // It then just becomes a concat along that dimension.
+ int non_one_dims = 0;
+ int concat_axis = 0;
+ for (int i = 0; i < multiples.size(); ++i) {
+ if (multiples[i] != 1) {
+ ++non_one_dims;
+ concat_axis = i;
+ }
+ }
+ if (non_one_dims != 1) {
+ // The tile is non-trivial. Good luck.
+ AddMessageF("Tile %s is non-trivial (has more than one multiply dimension)",
+ LogName(*tile_op));
+ return false;
+ }
+
+ // The tile is like a concat.
+ AddMessageF("Simplifying %s to a Concat along a single axis %d",
+ LogName(*tile_op), concat_axis);
+
+ auto* concat_op = new ConcatenationOperator;
+
+ // Copy input and output.
+ // Note that we multiply out the input by the number of times requested.
+ for (int i = 0; i < multiples[concat_axis]; ++i) {
+ concat_op->inputs.push_back(tile_op->inputs[0]);
+ }
+ concat_op->axis = concat_axis;
+ concat_op->outputs = tile_op->outputs;
+
+ // Delete multiples array if unused.
+ if (IsDiscardableArray(*model, tile_op->inputs[1]) &&
+ CountOpsWithInput(*model, tile_op->inputs[1]) == 1) {
+ model->EraseArray(tile_op->inputs[1]);
+ }
+
+ // Replace the operator in the graph.
+ const auto concat_it = model->operators.emplace(tile_it, concat_op);
+ tile_it = concat_it + 1;
+ CHECK_EQ(tile_it->get(), tile_op);
+ model->operators.erase(tile_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
index 076415ece8..1e68cd678b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -25,17 +25,12 @@ limitations under the License.
namespace toco {
-bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
- auto conv_it = model->operators.begin() + op_index;
- if (conv_it->get()->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
- if (conv_op->outputs.size() == 2) {
+bool ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (op->outputs.size() == 2) {
// We already have an im2col array
return false;
}
- const auto& weights_array = model->GetArray(conv_op->inputs[1]);
+ const auto& weights_array = model->GetArray(op->inputs[1]);
if (!weights_array.has_shape()) {
// We need to yield until weights dims have been resolved, because
// from the weights dims we determine whether an im2col array is
@@ -45,25 +40,52 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
const auto& weights_shape = weights_array.shape();
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
- if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
- conv_op->stride_height == 1) {
- // 1x1 unstrided conv does not need an im2col array.
+ if (kwidth == 1 && kheight == 1 && op->stride_width == 1 &&
+ op->stride_height == 1 && op->dilation_width_factor == 1 &&
+ op->dilation_height_factor == 1) {
+ // 1x1 unstrided undilated conv does not need an im2col array.
return false;
}
// Create the im2col array.
- CHECK_EQ(conv_op->outputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
const string& im2col_array_name =
- AvailableArrayName(*model, conv_op->inputs[0] + "_im2col");
+ AvailableArrayName(*model, op->inputs[0] + "_im2col");
model->GetOrCreateArray(im2col_array_name);
- conv_op->outputs.push_back(im2col_array_name);
- AddMessageF(
- "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, "
- "stride_height=%d",
- LogName(*conv_op), kwidth, kheight, conv_op->stride_width,
- conv_op->stride_height);
+ op->outputs.push_back(im2col_array_name);
return true;
}
+bool ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
+ if (op->outputs.size() == 2) {
+ // We already have an im2col array
+ return false;
+ }
+
+ // Always create an im2col array for transpose_conv.
+ CHECK_EQ(op->outputs.size(), 1);
+ const string& im2col_array_name = AvailableArrayName(
+ *model, op->inputs[TransposeConvOperator::DATA_INPUT] + "_im2col");
+ model->GetOrCreateArray(im2col_array_name);
+ op->outputs.push_back(im2col_array_name);
+
+ return true;
+}
+
+bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ switch (op->type) {
+ case OperatorType::kConv:
+ return ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ case OperatorType::kTransposeConv:
+ return ProcessTransposeConvOperator(
+ model, static_cast<TransposeConvOperator*>(op));
+ default:
+ return false;
+ }
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
index 498c864bde..2c7ffe4884 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -111,7 +111,7 @@ bool DequantizeArray(const string& array_name,
auto* op_outputting_array = GetOpWithOutput(*model, array_name);
if (op_outputting_array) {
- if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ if (op_outputting_array->type == OperatorType::kReshape) {
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
index 708ecf6e0a..e80ed036b3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -26,17 +26,38 @@ namespace toco {
namespace {
+int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
+ const string& weights_name = op.inputs[1];
+ const auto& weights_shape = model.GetArray(weights_name).shape();
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kFullyConnected) {
+ return weights_shape.dims(0);
+ }
+ if (op.type == OperatorType::kDepthwiseConv) {
+ return weights_shape.dims(3);
+ }
+ LOG(FATAL) << "Unhandled operator type";
+ return 0;
+}
+
bool ProcessLinearOperator(Model* model, Operator* op) {
if (op->inputs.size() >= 3) {
return false;
}
const string& output_name = op->outputs[0];
+ const string& weights_name = op->inputs[1];
+ if (!model->GetArray(weights_name).has_shape()) {
+ return false;
+ }
+ const int depth = GetOutputDepthFromWeights(*model, *op);
const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
op->inputs.push_back(bias_name);
DCHECK_EQ(op->inputs.size(), 3);
auto& bias_array = model->GetOrCreateArray(bias_name);
bias_array.data_type = ArrayDataType::kFloat;
-
+ bias_array.mutable_shape()->mutable_dims()->push_back(depth);
+ auto& bias_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ bias_buffer.data.resize(depth, 0.f);
return true;
}
} // namespace
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
index 394fa349e2..75642bbc37 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc
@@ -122,7 +122,7 @@ bool EnsureUint8WeightsSafeForFastInt8Kernels::Run(Model* model,
case OperatorType::kFullyConnected: {
weights_index = 1;
const auto& fc_op = static_cast<const toco::FullyConnectedOperator&>(op);
- CHECK(!fc_op.experimental_shuffled_weights)
+ CHECK(fc_op.weights_format == FullyConnectedWeightsFormat::kDefault)
<< "This graph transformation expects to run before FC weights get "
"shuffled.";
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
new file mode 100644
index 0000000000..874d8def57
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Returns true if the given op is strictly a broadcasting operation.
+// This is commonly seen as a Concat of the same input multiple times, and is
+// often generated from Tile ops that were converted via the
+// convert_trivial_tile_to_concat transformation.
+bool IsBroadcastingOp(const Model& model, Operator* op) {
+ // Concatenation of identical inputs is usually a broadcast.
+ if (op->type == OperatorType::kConcatenation) {
+ // Verify that all inputs are the same.
+ for (int i = 1; i < op->inputs.size(); ++i) {
+ if (op->inputs[i] != op->inputs[0]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // There are other things we could look for (Stack/etc) when needed.
+ return false;
+}
+
+} // namespace
+
+// Finds an operation that looks like a broadcast (concat of the same sources
+// along the last dimension) and drops it by relying on the ability of certain
+// binary ops to perform an implicit broadcast.
+bool FuseBroadcastIntoFollowingBinary::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->inputs.size() != 2) {
+ return false;
+ }
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ // NOTE: either of these ops may be nullptr if the input array is constant.
+ Operator* const op[2] = {
+ GetOpWithOutput(*model, binary_op->inputs[0]),
+ GetOpWithOutput(*model, binary_op->inputs[1]),
+ };
+
+ // Check whether either input is a broadcast-like concat.
+ bool is_op_0_broadcast = op[0] && IsBroadcastingOp(*model, op[0]);
+ bool is_op_1_broadcast = op[1] && IsBroadcastingOp(*model, op[1]);
+ if (!is_op_0_broadcast && !is_op_1_broadcast) {
+ // Neither input is a broadcast-looking thing.
+ AddMessageF("Neither input looks broadcasty");
+ return false;
+ } else if (is_op_0_broadcast && is_op_1_broadcast) {
+ AddMessageF(
+ "Unable to fuse broadcast into %s as both inputs (%s, %s) are "
+ "broadcasts",
+ LogName(*binary_op), op[0] ? LogName(*op[0]) : "(?)",
+ op[1] ? LogName(*op[1]) : "(?)");
+ return false;
+ }
+ int broadcast_index = is_op_0_broadcast ? 0 : 1;
+
+ // Just pull out the input of the broadcast op and pass it directly to the
+ // binary op.
+ AddMessageF("Fusing broadcast op %s into the following binary %s",
+ LogName(*op[broadcast_index]), LogName(*binary_op));
+ binary_op->inputs[broadcast_index] = op[broadcast_index]->inputs[0];
+
+ // We leave the broadcast op in; it'll get cleaned up if it's not used later.
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 8da242aa9c..8cd1298bca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -117,12 +117,14 @@ DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialAddNToAdd)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialStackToReshape)
+DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTileToConcat)
DECLARE_GRAPH_TRANSFORMATION(ConvertTrivialTransposeToReshape)
DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
+DECLARE_GRAPH_TRANSFORMATION(FuseBroadcastIntoFollowingBinary)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
@@ -133,12 +135,14 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
+DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
+DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
@@ -164,7 +168,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
-DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
@@ -190,7 +193,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantGather)
DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
-DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights)
+DECLARE_GRAPH_TRANSFORMATION(ShuffleFCWeights)
class PropagateDefaultMinMax : public GraphTransformation {
public:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index d63ee7c951..2f1bb8f0ad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -133,24 +133,20 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
}
bool HardcodeMinMaxForSplit(Model* model, Operator* op) {
- for (const auto& output : op->outputs) {
- if (model->GetArray(output).minmax) {
- LOG(WARNING) << "Skipping min-max setting for " << LogName(*op)
- << " because output " << output << " already has min-max.";
- return false;
- }
- }
// Data is in second input.
auto& input_array = model->GetArray(op->inputs[1]);
if (!input_array.minmax) {
return false;
- } else {
- for (const auto& output : op->outputs) {
- auto& array = model->GetArray(output);
+ }
+ bool changed = false;
+ for (const auto& output : op->outputs) {
+ auto& array = model->GetArray(output);
+ if (!array.minmax || !(array.GetMinMax() == input_array.GetMinMax())) {
+ changed = true;
array.GetOrCreateMinMax() = *input_array.minmax;
}
- return true;
}
+ return changed;
}
// The output of average or max pooling is within the same range as its input.
@@ -232,6 +228,14 @@ bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
return true;
}
+bool MinMaxApproximatelyEqual(const MinMax& minmax1, const MinMax& minmax2) {
+ const double magnitude =
+ std::min(minmax1.max - minmax1.min, minmax2.max - minmax2.min);
+ const double tolerated = 1e-6 * magnitude;
+ return std::abs(minmax1.min - minmax2.min) < tolerated &&
+ std::abs(minmax1.max - minmax2.max) < tolerated;
+}
+
// Propagates MinMax from any of the listed arrays, to all others.
// If multiple of these arrays have MinMax, then these are required
// to agree with each other.
@@ -254,7 +258,7 @@ bool PropagateMinMaxAmongArrays(Model* model,
for (const string& array_name : array_names) {
auto& array = model->GetArray(array_name);
if (array.minmax) {
- CHECK(*array.minmax == *reference_minmax)
+ CHECK(MinMaxApproximatelyEqual(*array.minmax, *reference_minmax))
<< "Both the following arrays have minmax, and they disagree: "
<< reference_array_name << " (" << reference_minmax->min << ","
<< reference_minmax->max << ") and " << array_name << " ("
@@ -353,7 +357,7 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForConcatenation(model, op);
break;
- case OperatorType::kTensorFlowSplit:
+ case OperatorType::kSplit:
changed = HardcodeMinMaxForSplit(model, op);
break;
@@ -362,9 +366,11 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
break;
+ case OperatorType::kResizeBilinear:
+ case OperatorType::kSlice:
case OperatorType::kStridedSlice:
case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kPad:
case OperatorType::kGather:
case OperatorType::kTranspose:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index ae3301f467..d49857cfc2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -90,12 +90,13 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
}
// Conv Op
- ConvOperator* conv_op = dynamic_cast<ConvOperator*>(
- has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0])
- : GetOpWithInput(*model, stb_op->outputs[0]));
- if (!conv_op || conv_op->type != OperatorType::kConv) {
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ if (conv_base_op->type != OperatorType::kConv) {
return false;
}
+ auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index 419a0776a6..b78efd7fc3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
if (div_or_mul_op->type == OperatorType::kDiv) {
- expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt;
} else if (div_or_mul_op->type == OperatorType::kMul) {
- expected_op_type_producing_div_or_mul_input =
- OperatorType::kTensorFlowRsqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
return false;
}
@@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* add_op = nullptr;
Operator* op_producing_add_input = nullptr;
if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
- op_producing_sqrt_or_rsqrt_input->type ==
- OperatorType::kTensorFlowMaximum) {
+ op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) {
add_op = op_producing_sqrt_or_rsqrt_input;
bool add_can_be_removed = false;
CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
@@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* sum_op =
add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
- if (sum_op->type != OperatorType::kTensorFlowSum) {
+ if (sum_op->type != OperatorType::kSum) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
@@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
- if (square_op->type != OperatorType::kTensorFlowSquare) {
+ if (square_op->type != OperatorType::kSquare) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
index e4d52476c6..705e73779b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -41,7 +41,7 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
const auto sqrt_it = model->operators.begin() + op_index;
const auto* sqrt_op = sqrt_it->get();
- if (sqrt_op->type != OperatorType::kTensorFlowSqrt) {
+ if (sqrt_op->type != OperatorType::kSqrt) {
return false;
}
@@ -52,6 +52,13 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
const Operator* square_op;
Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]);
+ if (prev_to_sqrt_op == nullptr) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected AveragePool op, but Sqrt op has no preceding op");
+ return false;
+ }
+
if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
AddMessageF(
"Giving up trying to identify L2Pool subgraph: "
@@ -65,7 +72,7 @@ bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
square_op = GetOpWithOutput(*model, avpool_op->inputs[0]);
CHECK_EQ(square_op->inputs.size(), 1);
- if (square_op->type != OperatorType::kTensorFlowSquare) {
+ if (square_op->type != OperatorType::kSquare) {
AddMessageF(
"Giving up trying to identify L2Pool subgraph: "
"expected Square op, got %s",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
index e9842524c8..c0b014b45e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -35,19 +35,24 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
return it;
}
-bool GetStateArrayForBackEdge(const Model& model,
- const string& back_edge_source_array,
- string* state_array = nullptr) {
- for (const auto& rnn_state : model.flags.rnn_states()) {
- if (back_edge_source_array == rnn_state.back_edge_source_array()) {
- // Found LSTM cell output
- if (state_array) {
- *state_array = rnn_state.state_array();
- }
- return true;
+bool ValidateSourceOp(const Model& model, const string& array_name,
+ OperatorType op_type, Operator** source_op) {
+ if (op_type == OperatorType::kNone) {
+ CHECK(!source_op);
+ } else {
+ CHECK(source_op);
+ *source_op = GetOpWithOutput(model, array_name);
+ if (*source_op == nullptr) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((*source_op)->type != op_type) {
+ return false;
}
}
- return false;
+
+ return true;
}
// Returns true if the given operator has exactly 1 input, and is connected to
@@ -62,24 +67,10 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[0], op_type, connected_op)) {
return false;
}
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (connected_op) {
- *connected_op = x;
- }
-
return true;
}
@@ -96,40 +87,15 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
- return false;
- }
-
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
return true;
}
@@ -147,57 +113,20 @@ bool MatchOperatorInputs(const Operator& op, const Model& model,
}
// Check if first input is disconnected/connected to an operator
- Operator* x = GetOpWithOutput(model, op.inputs[0]);
- if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
- return false;
- }
- if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
- return false;
- }
-
- // Check that first operator, if connected, is of correct type
- if ((x != nullptr) && (x->type != a_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[0], a_op_type, a_op)) {
return false;
}
// Check if second input is disconnected/connected to an operator
- Operator* y = GetOpWithOutput(model, op.inputs[1]);
- if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
- return false;
- }
- if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
- return false;
- }
-
- // Check that second operator, if connected, is of correct type
- if ((y != nullptr) && (y->type != b_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[1], b_op_type, b_op)) {
return false;
}
// Check if third input is disconnected/connected to an operator
- Operator* z = GetOpWithOutput(model, op.inputs[2]);
- if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
- return false;
- }
- if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
- return false;
- }
-
- // Check that third operator, if connected, is of correct type
- if ((z != nullptr) && (z->type != c_op_type)) {
+ if (!ValidateSourceOp(model, op.inputs[2], c_op_type, c_op)) {
return false;
}
- // Successfully matched. Optionally return matching input operators.
- if (a_op != nullptr) {
- *a_op = x;
- }
- if (b_op != nullptr) {
- *b_op = y;
- }
- if (c_op != nullptr) {
- *c_op = z;
- }
return true;
}
@@ -231,11 +160,6 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
&state_combine_add)) {
return false;
}
- string prev_state;
- if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0],
- &prev_state)) {
- return false;
- }
// State forget & remember addition
Operator *state_forget_mul, *state_remember_mul;
@@ -244,9 +168,7 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
&state_remember_mul)) {
return false;
}
- if (state_forget_mul->inputs[0] != prev_state) {
- return false;
- }
+ const string prev_state = state_forget_mul->inputs[0];
// State forget gate
Operator* state_forget_sig;
@@ -266,26 +188,26 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
// State remember "information" activation function
Operator* fc_output_split;
- if (!MatchOperatorInputs(*state_info_tanh, *model,
- OperatorType::kTensorFlowSplit, &fc_output_split)) {
+ if (!MatchOperatorInputs(*state_info_tanh, *model, OperatorType::kSplit,
+ &fc_output_split)) {
return false;
}
// State remember gate activation function
Operator* tmp;
- if (!MatchOperatorInputs(*state_remember_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*state_remember_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
// State forget gate activation function
- if (!MatchOperatorInputs(*state_forget_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*state_forget_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
// Fully connected output activation function
- if (!MatchOperatorInputs(*fc_output_sig, *model,
- OperatorType::kTensorFlowSplit, &tmp) ||
+ if (!MatchOperatorInputs(*fc_output_sig, *model, OperatorType::kSplit,
+ &tmp) ||
(tmp != fc_output_split)) {
return false;
}
@@ -306,8 +228,8 @@ bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
return false;
}
- if (static_cast<FullyConnectedOperator*>(fully_connected)
- ->experimental_shuffled_weights) {
+ if (static_cast<FullyConnectedOperator*>(fully_connected)->weights_format !=
+ FullyConnectedWeightsFormat::kDefault) {
// Not yet implemented: experimental shuffled weights in fused LSTM cell.
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
index 3f768bfee1..5b6a984ee1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc
@@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs,
- // do not need to merge cell inputs.
- if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) {
+ // Already a compact LstmCell. Do not need to merge cell inputs.
+ const auto* src_lstm_op = static_cast<LstmCellOperator*>(src_op);
+ if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL ||
+ src_lstm_op->inputs.size() != kExtendedLstmInputCount) {
return false;
}
@@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LSTM cell operator (use basic 5 inputs kernel).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC;
// Compact LstmCell's 5 inputs.
lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
index 8e66323bd7..46d1fce50e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc
@@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
return false;
}
- // Already an extended LstmCell with kExtendedLstmInputCount of inputs,
- // do not need to split cell inputs.
- if (curr_op->inputs.size() == kExtendedLstmInputCount) {
+ const auto* curr_lstm_op = static_cast<LstmCellOperator*>(curr_op);
+ // Already an extended LstmCell. Do not need to split cell inputs.
+ if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC ||
+ curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) {
return false;
}
@@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc).
auto lstm_cell_op = absl::make_unique<LstmCellOperator>();
+ lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL;
lstm_cell_op->inputs.resize(kExtendedLstmInputCount);
int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT])
.shape()
@@ -72,6 +74,12 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
lstm_cell_op->inputs[kInputTensor] =
curr_op->inputs[LstmCellOperator::ACTIV_OUTPUT];
+ // Previous states.
+ lstm_cell_op->inputs[kInputActivationStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT];
+ lstm_cell_op->inputs[kInputCellStateTensor] =
+ curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT];
+
// Get original weight tensor and decompose 1 tensor to 8 sub tensors.
Array& kernel =
model->GetArray(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
@@ -158,10 +166,6 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) {
// Erase curr lstm op being replaced.
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::WEIGHTS_INPUT], model);
DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::BIASES_INPUT], model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT],
- model);
- DeleteArrayIfUnused(curr_op->inputs[LstmCellOperator::PREV_STATE_INPUT],
- model);
model->operators.erase(FindOp(*model, curr_op));
return true;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
index 30be4ac0aa..b90a156a0d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -74,14 +74,30 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
const auto* relu_neg_input_op = GetOpWithOutput(*model, mul_op->inputs[1]);
if (relu_neg_input_op == nullptr ||
- relu_neg_input_op->type != OperatorType::kNeg ||
- relu_neg_input_op->fused_activation_function !=
- FusedActivationFunctionType::kRelu ||
relu_neg_input_op->inputs.size() != 1) {
return false;
}
- if (relu_input_op->inputs[0] != relu_neg_input_op->inputs[0]) {
+ const Operator* final_input_op;
+ if (relu_neg_input_op->type == OperatorType::kNeg &&
+ relu_neg_input_op->fused_activation_function ==
+ FusedActivationFunctionType::kRelu) {
+ // This detects a Neg op with fused Relu activation function.
+ final_input_op = relu_neg_input_op;
+ } else {
+ // This detects a Neg op followed by a separated Relu op.
+ const auto* neg_input_op =
+ GetOpWithOutput(*model, relu_neg_input_op->inputs[0]);
+ if (neg_input_op == nullptr || neg_input_op->inputs.size() != 1 ||
+ relu_neg_input_op->type != OperatorType::kRelu ||
+ relu_neg_input_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ return false;
+ }
+ final_input_op = neg_input_op;
+ }
+
+ if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
return false;
}
@@ -112,7 +128,6 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
// intermediate tensors aren't used by other ops, those will be removed by
// other graph transformation rules.
model->operators.erase(FindOp(*model, add_op));
-
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index bddb563206..94820a0166 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -60,24 +60,22 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
- if (op_0->type != OperatorType::kTensorFlowMinimum &&
- op_0->type != OperatorType::kTensorFlowMaximum) {
+ if (op_0->type != OperatorType::kMinimum &&
+ op_0->type != OperatorType::kMaximum) {
return false;
}
// Get the paired op and ensure it's the counter to the first.
const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]);
if (!op_1 ||
- (op_1->type != OperatorType::kTensorFlowMinimum &&
- op_1->type != OperatorType::kTensorFlowMaximum) ||
+ (op_1->type != OperatorType::kMinimum &&
+ op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
return false;
}
- const auto* min_op =
- op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1;
- const auto* max_op =
- op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
+ const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
+ const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;
if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
index 1c32a78169..6d8603a113 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h
@@ -47,10 +47,14 @@ enum ExtendedLstmCellInputs {
kOutputGateBiasTensor = 15,
kProjectionWeightsTensor = 16, // Optional
kProjectionBiasTensor = 17, // Optional
- kExtendedLstmInputCount = 18
+ kInputActivationStateTensor = 18,
+ // The op can handle 18 inputs or 20 inputs.
+ kInputCellStateTensor = 19,
+ kExtendedLstmInputCount = 20,
};
enum ExtendedLstmCellOutputs {
+ // TODO(ycling): Make the 2 output state tensors optional.
kOutputStateTensor = 0,
kCellStateTensor = 1,
kOutputTensor = 2,
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
index 5065004093..95bc7f7d4b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -106,7 +106,7 @@ bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
std::size_t op_index) {
auto it = model->operators.begin() + op_index;
auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
- it->get(), OperatorType::kTensorFlowReshape);
+ it->get(), OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
new file mode 100644
index 0000000000..7f44c65285
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -0,0 +1,178 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#include <algorithm>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+bool IsTailOfShape(const Shape& tail, const Shape& shape) {
+ // Return true if 'tail' dimensions are the same as the ending dimensions of
+ // 'shape'.
+
+ int shape_end = shape.dimensions_count() - 1;
+ int tail_end = tail.dimensions_count() - 1;
+
+ if (tail_end > shape_end) {
+ // tail cannot be longer than shape.
+ return false;
+ }
+
+ // Walk dimensions back to front and compare
+ for (int i = 0; i <= tail_end; i++) {
+ if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+// If a binary operator is doing a broadcast operation from a constant array,
+// and the constant array shape is the tail of both the other input shape, and a
+// subsequent reshape op's output shape, we can swap their order. Since we
+// prefer to have reshape ops after mathematic ops, this can allow for the
+// collapsing of some reshapes. The WaveNet model in particular benefits from
+// this transformation.
+//
+// Note we are testing for one particular case of a broader set of possible
+// binary-reshape op transformations. This transformation could be generalized.
+bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ Operator* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kFloorDiv &&
+ binary_op->type != OperatorType::kFloorMod &&
+ binary_op->type != OperatorType::kMinimum &&
+ binary_op->type != OperatorType::kMaximum &&
+ binary_op->type != OperatorType::kLess &&
+ binary_op->type != OperatorType::kLessEqual &&
+ binary_op->type != OperatorType::kGreater &&
+ binary_op->type != OperatorType::kGreaterEqual) {
+ return false;
+ }
+
+ // BINARY OP INPUT CHECKS
+ CHECK_EQ(binary_op->inputs.size(), 2);
+ const bool input_is_const[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!input_is_const[0] && !input_is_const[1]) {
+ // To limit our scope, we require one constant input. Though there's no
+ // reason this transformation wouldn't work with all variable inputs.
+ return false;
+ }
+ if (input_is_const[0] && input_is_const[1]) {
+ // Both inputs are constants. Leave this for constants propagation.
+ return false;
+ }
+ const int constant_input_idx = input_is_const[0] ? 0 : 1;
+ const int variable_input_idx = input_is_const[0] ? 1 : 0;
+ CHECK(input_is_const[constant_input_idx]);
+ CHECK(!input_is_const[variable_input_idx]);
+
+ const auto& variable_input_array =
+ model->GetArray(binary_op->inputs[variable_input_idx]);
+ if (!variable_input_array.has_shape()) {
+ AddMessageF(
+ "Not moving %s because it's non-constant input shape is not resolved.",
+ LogName(*binary_op));
+ return false;
+ }
+ if (!IsTailOfShape(
+ model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
+ model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
+ // Constant array shape must be the latter part of the variable shape.
+ return false;
+ }
+
+ // RESHAPE OP CHECKS
+ auto reshape_it =
+ FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]);
+ if (reshape_it == model->operators.end()) {
+ AddMessageF("Not moving %s because it's variable input is not connected.",
+ LogName(*binary_op));
+ return false;
+ }
+ Operator* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kReshape) {
+ AddMessageF("Not moving %s because the preceding %s is not a reshape op",
+ LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+ const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
+ if (!reshape_input_array.has_shape()) {
+ AddMessageF(
+ "Not moving %s because it's non-constant input shape is not resolved "
+ "yet",
+ LogName(*binary_op));
+ return false;
+ }
+ if (!IsTailOfShape(
+ model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
+ model->GetArray(reshape_op->outputs[0]).shape())) {
+ // Constant array shape must be the latter part of the binary op output
+ // shape.
+ return false;
+ }
+
+ // EXTRA CHECKS ON CONNECTING ARRAY
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (binary_op->inputs[variable_input_idx] == output_array) {
+ AddMessageF(
+ "Not moving %s because the output of reshape op %s is an output op.",
+ LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+ }
+ int count_ops_consuming_output =
+ CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not moving %s because the output of reshape op %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+
+ // SWAP ORDER OF BINARY AND RESHAPE OPS
+ AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op),
+ LogName(*reshape_op));
+
+ // Swap op input and outputs
+ std::iter_swap(reshape_op->inputs.begin(),
+ binary_op->inputs.begin() + variable_input_idx);
+ std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin());
+
+ // Swap operator ordering
+ std::iter_swap(binary_it, reshape_it);
+
+ // Clear binary output shape so it will be re-propagated
+ model->GetArray(binary_op->outputs[0]).clear_shape();
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 64096fb069..670bcf64e7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -56,20 +56,22 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
// These operators unconditionally produce float outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
break;
- case OperatorType::kTensorFlowLess:
- case OperatorType::kTensorFlowLessEqual:
- case OperatorType::kTensorFlowGreater:
- case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kLess:
+ case OperatorType::kLessEqual:
+ case OperatorType::kGreater:
+ case OperatorType::kGreaterEqual:
+ case OperatorType::kEqual:
+ case OperatorType::kNotEqual:
// These operators unconditionally produce bool outputs
SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
break;
case OperatorType::kRank:
- case OperatorType::kTensorFlowShape:
+ case OperatorType::kShape:
// These operators only produce int32 outputs.
SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
break;
- case OperatorType::kTensorFlowSplit:
- case OperatorType::kTensorFlowConcat:
+ case OperatorType::kSplit:
+ case OperatorType::kConcat:
case OperatorType::kFill: {
// These operators produce an output with the same type as their 2nd input
CHECK_GE(op->inputs.size(), 2);
@@ -98,6 +100,13 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[0]).data_type = argmax_op->output_data_type;
break;
}
+ case OperatorType::kArgMin: {
+ // Data type of the ArgMin op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* argmin_op = static_cast<ArgMinOperator*>(op);
+ model->GetArray(op->outputs[0]).data_type = argmin_op->output_data_type;
+ break;
+ }
case OperatorType::kRange: {
auto* range_op = static_cast<RangeOperator*>(op);
// Output type of the Range op can be set via an attribute
@@ -133,7 +142,7 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
model->GetArray(op->outputs[1]).data_type = ArrayDataType ::kInt32;
break;
}
- case OperatorType::kTensorFlowUnsupported: {
+ case OperatorType::kUnsupported: {
auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
// Some output tensors from the op could be eliminated by optimization.
// This can make unsupported_op->output_data_types have more elements than
@@ -173,6 +182,14 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
+ case OperatorType::kPow: {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK(model->GetArray(op->inputs[0]).data_type ==
+ model->GetArray(op->inputs[1]).data_type);
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
index 50b90e7c2b..cd078ef189 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
@@ -25,6 +25,14 @@ limitations under the License.
namespace toco {
+namespace {
+
+bool SupportsMinMax(const Array& array) {
+ return array.data_type == ArrayDataType::kFloat;
+}
+
+} // namespace
+
// Propagates default min/max values to any operator input/output array that
// is missing them.
//
@@ -39,14 +47,16 @@ bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
for (const auto& input : op->inputs) {
auto& input_array = model->GetArray(input);
- if (!input_array.minmax && !input_array.buffer) {
+ if (!input_array.minmax && !input_array.buffer &&
+ SupportsMinMax(input_array)) {
did_change |= SetArrayMinMax(input, &input_array);
}
}
for (const auto& output : op->outputs) {
auto& output_array = model->GetArray(output);
- if (!output_array.minmax && !output_array.buffer) {
+ if (!output_array.minmax && !output_array.buffer &&
+ SupportsMinMax(output_array)) {
did_change |= SetArrayMinMax(output, &output_array);
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 6d51fc8c31..53fc87da7b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -27,11 +27,15 @@ namespace toco {
namespace {
-void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
+bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
ArrayDataType new_data_type,
const MinMax* new_minmax) {
// Ensure the array ends up in the new type (if it hasn't yet been quantized).
- array->final_data_type = new_data_type;
+ bool changed = false;
+ if (array->final_data_type != new_data_type) {
+ array->final_data_type = new_data_type;
+ changed = true;
+ }
if (array->minmax && array->quantization_params) {
// The array is already quantized and has min/max info.
@@ -65,15 +69,27 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
array_minmax.min = min;
array_minmax.max = max;
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- array_minmax, array->quantization_params.get());
+ switch (new_data_type) {
+ case ArrayDataType::kUint8:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ array_minmax, array->quantization_params.get());
+ break;
+ case ArrayDataType::kInt16:
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+ array_minmax, array->quantization_params.get());
+ break;
+ default:
+ CHECK(false) << "Unsupported quantized data type: "
+ << ArrayDataTypeName(new_data_type);
+ return false;
+ }
// Directly change the type as the array was already quantized.
array->data_type = new_data_type;
- } else {
+ changed = true;
+ } else if (!array->quantization_params) {
// Array has not yet been quantized so we can just set the final data type
// and assign the new min/max value (if provided).
- CHECK(!array->quantization_params);
if (!array->minmax && new_minmax) {
transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
@@ -82,16 +98,19 @@ void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
auto& array_minmax = array->GetOrCreateMinMax();
array_minmax.min = new_minmax->min;
array_minmax.max = new_minmax->max;
+ changed = true;
}
}
+
+ return changed;
}
// Returns true if the op blocks our backward recursive data type propagation.
bool DoesOpBlockBackwardPropagation(const Operator& op) {
switch (op.type) {
case OperatorType::kConcatenation:
- case OperatorType::kTensorFlowConcat:
- case OperatorType::kTensorFlowConcatV2:
+ case OperatorType::kConcat:
+ case OperatorType::kConcatV2:
// Concat shouldn't block propagation, but we do expect that all inputs
// have the same range.
return false;
@@ -100,9 +119,10 @@ bool DoesOpBlockBackwardPropagation(const Operator& op) {
// FakeQuant so make sure we move across them.
case OperatorType::kGather:
// Gathers need their parameters changed to the appropriate data type.
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
case OperatorType::kSelect:
+ case OperatorType::kTile:
// Reshapes and transposes don't change values.
return false;
default:
@@ -120,10 +140,13 @@ bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
// Ignore gather indices.
return input_index != 0;
break;
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
// Ignore reshape/transpose shapes/dimensions.
return input_index != 0;
+ case OperatorType::kTile:
+ // Ignore tile multiples.
+ return input_index != 0;
default:
return false;
}
@@ -155,9 +178,8 @@ bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
"Adjusting input final data type of array %s from %s to %s", input,
ArrayDataTypeName(input_array.final_data_type),
ArrayDataTypeName(new_data_type));
- did_change = true;
- ChangeArrayDataType(transformation, &input_array, new_data_type,
- &new_minmax);
+ did_change |= ChangeArrayDataType(transformation, &input_array,
+ new_data_type, &new_minmax);
// Walk up into all ops producing the inputs to this op.
for (auto& producing_op : model->operators) {
@@ -208,9 +230,8 @@ bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
"Adjusting output final data type of array %s from %s to %s", output,
ArrayDataTypeName(output_array.final_data_type),
ArrayDataTypeName(new_data_type));
- did_change = true;
- ChangeArrayDataType(transformation, &output_array, new_data_type,
- nullptr);
+ did_change |= ChangeArrayDataType(transformation, &output_array,
+ new_data_type, nullptr);
// Walk down into all ops consuming the output of this op.
for (auto& consuming_op : model->operators) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index adb241da32..4f95c57451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -120,49 +120,7 @@ void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
CHECK(output_array->has_shape());
}
-int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
- const string& weights_name = op.inputs[1];
- const auto& weights_shape = model.GetArray(weights_name).shape();
- if (op.type == OperatorType::kConv ||
- op.type == OperatorType::kFullyConnected) {
- return weights_shape.dims(0);
- } else if (op.type == OperatorType::kDepthwiseConv) {
- return weights_shape.dims(3);
- } else {
- LOG(FATAL) << "Unhandled operator type";
- }
-}
-
-bool EnsureBiasVectorShape(Model* model, Operator* op) {
- const string& weights_name = op->inputs[1];
- const auto& weights_array = model->GetArray(weights_name);
- // Yield until weights shape has been resolved.
- if (!weights_array.has_shape()) {
- return false;
- }
-
- if (op->inputs.size() < 3) {
- return false;
- }
- auto& bias_array = model->GetArray(op->inputs[2]);
- if (bias_array.has_shape()) {
- return true;
- }
-
- const int output_depth = GetOutputDepthFromWeights(*model, *op);
- bias_array.copy_shape(Shape({output_depth}));
-
- auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
- float_buffer.data.resize(output_depth, 0);
-
- return true;
-}
-
void ProcessConvOperator(Model* model, ConvOperator* op) {
- if (!EnsureBiasVectorShape(model, op)) {
- return;
- }
-
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
@@ -211,12 +169,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
// might as well calculate the output shape and ensure it matches the
// specified one
- // Check if we have already run.
- auto& output_array = model->GetArray(op->outputs[0]);
- if (output_array.has_shape()) {
- return;
- }
-
// SPECIFIED OUTPUT SHAPE
// The below is the specified, or prescribed output shape, _given_ to the
// operator as an input.
@@ -278,20 +230,26 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< "TransposeConv input shape must have 4 dimensions. Input \""
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
- CHECK_EQ(input_shape.dims(3), weights_shape.dims(0))
+ CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
<< "Input shape depth and weight depth do not agree";
// Set the output shape according to the specified output shape.
std::vector<int32> const& specified_output_shape =
specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_array = model->GetArray(op->outputs[0]);
*(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
-}
-void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
- if (!EnsureBiasVectorShape(model, op)) {
- return;
+ // Set im2col array dimensions if there is one.
+ if (op->outputs.size() == 2) {
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ im2col_array.copy_shape(
+ Shape{specified_output_shape[0], specified_output_shape[1],
+ specified_output_shape[2], input_depth * kheight * kwidth});
}
+}
+void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
@@ -321,7 +279,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
if (!op->depth_multiplier) {
op->depth_multiplier = output_depth / input_depth;
}
- QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
+ CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
<< "input/output depths and depth_multiplier don't match";
const int kheight = weights_shape.dims(1);
@@ -406,10 +364,6 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) {
}
void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
- if (!EnsureBiasVectorShape(model, op)) {
- return;
- }
-
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
@@ -568,11 +522,11 @@ void ProcessAddNOperator(Model* model, Operator* op) {
bool KeepDims(const Operator& op) {
switch (op.type) {
- case OperatorType::kTensorFlowMin:
+ case OperatorType::kMin: // Reduction Min
return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
- case OperatorType::kTensorFlowMax:
+ case OperatorType::kMax: // Reduction Max
return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
- case OperatorType::kTensorFlowSum:
+ case OperatorType::kSum:
return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
case OperatorType::kMean:
return static_cast<const MeanOperator&>(op).keep_dims;
@@ -1085,9 +1039,6 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) {
QCHECK_GE(input_shape.dimensions_count(), 1);
op->input_rank = input_shape.dimensions_count();
- // We only support 1-D indices.
- QCHECK_EQ(indices_shape.dimensions_count(), 1);
-
// Copy the input dimensions to the output except for dimension 0,
// where the dimension of indices_shape is used.
// TODO(mgubin): if axis != 0 this is not true, change when it's supported.
@@ -1337,8 +1288,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
op->begin_mask, op->start_indices, op->strides,
input_array.shape().dims().data(), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
+ input_array.shape().dims().data(), axis, start_index);
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
@@ -1453,7 +1404,8 @@ void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
}
}
-void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
+template <typename Op>
+void ProcessArgMinMaxOperator(Model* model, Op* op) {
CHECK_EQ(op->inputs.size(), 2);
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1505,6 +1457,48 @@ void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
}
}
+void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // We have already run.
+ return;
+ }
+
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.has_shape()) {
+ // Yield until input dims have been resolved.
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ auto& multiples_array = model->GetArray(op->inputs[1]);
+ if (!multiples_array.has_shape()) {
+ // Yield until multiples shape been resolved.
+ return;
+ }
+ if (!multiples_array.buffer) {
+ // Yield until the multiples is constant.
+ return;
+ }
+ CHECK(multiples_array.data_type == ArrayDataType::kInt32)
+ << "Tile multiples input must be int32";
+
+ std::vector<int32> const& multiples =
+ multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(multiples.size(), input_shape.dimensions_count())
+ << "Tile multiples input " << op->inputs[1]
+ << " must be same length as input dimensions";
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->resize(multiples.size());
+ for (int i = 0; i < mutable_dims->size(); ++i) {
+ (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1531,14 +1525,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kLogistic:
case OperatorType::kTanh:
case OperatorType::kLocalResponseNormalization:
- case OperatorType::kTensorFlowIdentity:
+ case OperatorType::kIdentity:
case OperatorType::kFakeQuant:
case OperatorType::kNeg:
- case OperatorType::kTensorFlowRsqrt:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
- case OperatorType::kTensorFlowAll:
- case OperatorType::kTensorFlowAssert:
+ case OperatorType::kRsqrt:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
+ case OperatorType::kAll:
+ case OperatorType::kAssert:
case OperatorType::kCast:
case OperatorType::kFloor:
case OperatorType::kExp:
@@ -1557,12 +1551,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kDiv:
case OperatorType::kFloorDiv:
case OperatorType::kFloorMod:
- case OperatorType::kTensorFlowLess:
- case OperatorType::kTensorFlowLessEqual:
- case OperatorType::kTensorFlowGreater:
- case OperatorType::kTensorFlowMaximum:
- case OperatorType::kTensorFlowMinimum:
- case OperatorType::kTensorFlowGreaterEqual:
+ case OperatorType::kLess:
+ case OperatorType::kLessEqual:
+ case OperatorType::kGreater:
+ case OperatorType::kMaximum: // Element-wise Maximum
+ case OperatorType::kMinimum: // Element-wise Minimum
+ case OperatorType::kGreaterEqual:
+ case OperatorType::kEqual:
+ case OperatorType::kNotEqual:
+ case OperatorType::kPow:
ProcessSimpleBinaryOperator(model, op);
break;
case OperatorType::kAddN:
@@ -1595,7 +1592,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessFullyConnectedOperator(model,
static_cast<FullyConnectedOperator*>(op));
break;
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
ProcessTensorFlowReshapeOperator(
model, static_cast<TensorFlowReshapeOperator*>(op));
break;
@@ -1608,9 +1605,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kL2Pool:
ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
break;
- case OperatorType::kTensorFlowMin:
- case OperatorType::kTensorFlowMax:
- case OperatorType::kTensorFlowSum:
+ case OperatorType::kMin: // Reduction Min
+ case OperatorType::kMax: // Reduction Max
+ case OperatorType::kSum:
case OperatorType::kMean:
ProcessTensorFlowReductionOperator(model, op);
break;
@@ -1621,34 +1618,26 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
break;
- case OperatorType::kTensorFlowTile:
- // We don't currently implement the propagation of fixed sizes through
- // a TensorFlow Tile.
- //
- // Fortunately, we don't need to: so far, we have only dealt with Tile
- // or Slice ops in subgraphs that are identified as L2Normalization.
- // See IdentifyL2Normalization.
- break;
- case OperatorType::kTensorFlowSwitch:
+ case OperatorType::kSwitch:
// We can't know the sizes of the outputs until we have resolved the
// predicate, and once we have resolved the predicate, the whole
// Switch node will get resolved away.
// See ResolveTensorFlowSwitch.
break;
- case OperatorType::kTensorFlowMerge:
+ case OperatorType::kMerge:
// No need to bother resolving TensorFlow Merge ops: other graph
// transformations will remove them anyway.
// See ResolveTensorFlowMerge.
break;
- case OperatorType::kTensorFlowSplit:
+ case OperatorType::kSplit:
ProcessTensorFlowSplitOperator(model,
static_cast<TensorFlowSplitOperator*>(op));
break;
case OperatorType::kSqueeze:
ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
break;
- case OperatorType::kTensorFlowConcat:
- case OperatorType::kTensorFlowConcatV2:
+ case OperatorType::kConcat:
+ case OperatorType::kConcatV2:
// Unimplemented, hopefully another graph transformation will
// drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
// will resolve this node to a DepthConcatenation, or else we have
@@ -1664,7 +1653,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kRank:
ProcessRankOperator(model, static_cast<RankOperator*>(op));
break;
- case OperatorType::kTensorFlowShape:
+ case OperatorType::kShape:
ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
break;
case OperatorType::kStack:
@@ -1685,7 +1674,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;
case OperatorType::kBatchMatMul:
- case OperatorType::kTensorFlowMatMul:
+ case OperatorType::kMatMul:
// MatMul operators are converted to FullyConnected, after which their
// shapes are propagated.
break;
@@ -1708,9 +1697,14 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
static_cast<StridedSliceOperator*>(op));
break;
case OperatorType::kArgMax:
- ProcessArgMaxOperator(model, static_cast<ArgMaxOperator*>(op));
+ ProcessArgMinMaxOperator<ArgMaxOperator>(
+ model, static_cast<ArgMaxOperator*>(op));
break;
- case OperatorType::kTensorFlowUnsupported:
+ case OperatorType::kArgMin:
+ ProcessArgMinMaxOperator<ArgMinOperator>(
+ model, static_cast<ArgMinOperator*>(op));
+ break;
+ case OperatorType::kUnsupported:
break;
case OperatorType::kSvdf:
ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
@@ -1732,6 +1726,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
ProcessSparseToDenseOperator(model,
static_cast<SparseToDenseOperator*>(op));
break;
+ case OperatorType::kTile:
+ ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 142841fcc4..58885b4950 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -33,7 +33,7 @@ namespace {
bool SupportsQuantization(const Operator& op) {
auto type = op.type;
- if (type == OperatorType::kTensorFlowUnsupported) {
+ if (type == OperatorType::kUnsupported) {
auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
return unsupported->quantized;
}
@@ -42,25 +42,25 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kConcatenation ||
type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
- type == OperatorType::kTensorFlowMinimum ||
- type == OperatorType::kTensorFlowMaximum ||
+ type == OperatorType::kMinimum || type == OperatorType::kMaximum ||
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
- type == OperatorType::kLogSoftmax ||
- type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
+ type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
+ type == OperatorType::kResizeBilinear ||
+ type == OperatorType::kSplit || type == OperatorType::kSub ||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
- type == OperatorType::kPadV2 ||
- type == OperatorType::kTensorFlowReshape ||
+ type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
+ type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace ||
type == OperatorType::kLstmCell || type == OperatorType::kGather ||
type == OperatorType::kTranspose || type == OperatorType::kMean ||
- type == OperatorType::kTensorFlowGreater ||
- type == OperatorType::kTensorFlowGreaterEqual ||
- type == OperatorType::kTensorFlowLess ||
- type == OperatorType::kTensorFlowLessEqual ||
- type == OperatorType::kSelect;
+ type == OperatorType::kGreater ||
+ type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
+ type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
+ type == OperatorType::kArgMax || type == OperatorType::kRelu ||
+ type == OperatorType::kRelu1 || type == OperatorType::kRelu6;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -326,14 +326,15 @@ bool ChooseQuantizationForOperatorOutput(
output, OperatorTypeName(op.type));
return true;
}
- if ((op.type == OperatorType::kDepthToSpace) ||
- (op.type == OperatorType::kSpaceToDepth) ||
- (op.type == OperatorType::kTensorFlowReshape) ||
- (op.type == OperatorType::kTensorFlowSplit) ||
- (op.type == OperatorType::kConcatenation &&
- model->flags.change_concat_input_ranges())) {
+ if ((op.type == OperatorType::kConcatenation &&
+ model->flags.change_concat_input_ranges()) ||
+ op.type == OperatorType::kDepthToSpace ||
+ op.type == OperatorType::kSpaceToDepth ||
+ op.type == OperatorType::kReshape || op.type == OperatorType::kSplit ||
+ op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 ||
+ op.type == OperatorType::kRelu6) {
int data_input_index = 0;
- if (op.type == OperatorType::kTensorFlowSplit) {
+ if (op.type == OperatorType::kSplit) {
data_input_index = 1;
}
// Copying and rearrangement ops should preserve the quantization parameters
@@ -506,36 +507,47 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
// Check if the output of that Dequantize op was not used by any
// other operator. We will then erase that Dequantize op.
if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
- // If any of the model's output_arrays was pointing to the
- // Dequantize op's output, let it point to the Dequantize op's
- // input instead.
- for (int i = 0; i < model->flags.output_arrays_size(); i++) {
- if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
- // TODO(b/78013785): never rename output arrays.
- if (IsInputArray(*model, dequantize_op->inputs[0])) {
- // The op input is an input array and the output is an output
- // array and we can't have an array be both. Insert a copy
- // op to ensure the two arrays stay separate.
- AddMessageF(
- "Tried to rename output array %d while removing dequant "
- "op %s but array is also an input; inserting copy %s "
- "-> %s",
- i, LogName(*dequantize_op), model->flags.output_arrays(i),
- dequantize_op->inputs[0]);
- InsertCopyOperator(model, dequantize_op->inputs[0],
- dequantize_op->outputs[0]);
- } else {
- // Op output is strictly used as an output array, so we can
- // just rename the array and directly bypass the op.
- AddMessageF(
- "Renaming output array %d after removing dequant op %s: "
- "%s -> %s",
- i, LogName(*dequantize_op), model->flags.output_arrays(i),
- dequantize_op->inputs[0]);
- model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
- model->EraseArray(dequantize_op->outputs[0]);
+ if (IsDiscardableArray(*model, dequantize_op->outputs[0])) {
+ // Usual case: we can just discard the dequantize output.
+ model->EraseArray(dequantize_op->outputs[0]);
+ } else {
+ // The dequantize output is not discardable. Special care needed.
+ // If any of the model's output_arrays was pointing to the
+ // Dequantize op's output, let it point to the Dequantize op's
+ // input instead.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) ==
+ dequantize_op->outputs[0]) {
+ // TODO(b/78013785): never rename output arrays.
+ if (IsInputArray(*model, dequantize_op->inputs[0])) {
+ // The op input is an input array and the output is an
+ // output array and we can't have an array be both. Insert a
+ // copy op to ensure the two arrays stay separate.
+ AddMessageF(
+ "Tried to rename output array %d while removing "
+ "dequant "
+ "op %s but array is also an input; inserting copy %s "
+ "-> %s",
+ i, LogName(*dequantize_op),
+ model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ InsertCopyOperator(model, dequantize_op->inputs[0],
+ dequantize_op->outputs[0]);
+ } else {
+ // Op output is strictly used as an output array, so we can
+ // just rename the array and directly bypass the op.
+ AddMessageF(
+ "Renaming output array %d after removing dequant op "
+ "%s: "
+ "%s -> %s",
+ i, LogName(*dequantize_op),
+ model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
+ }
+ break;
}
- break;
}
}
model->operators.erase(dequantize_it);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
new file mode 100644
index 0000000000..88ea0945e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+// The minimum number of elements a weights array must have to be quantized
+// by this transformation.
+// TODO(suharshs): Make this minimum size configurable.
+const int kWeightsMinSize = 1024;
+
+// Gets the quantization params from the float array.
+void GetQuantizationParamsFromArray(const Array& array,
+ QuantizationParams* params) {
+ const std::vector<float>& float_vals =
+ array.GetBuffer<ArrayDataType::kFloat>().data;
+ auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
+ MinMax toco_minmax;
+ toco_minmax.min = *minmax.first;
+ toco_minmax.max = *minmax.second;
+ GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params);
+}
+
+} // namespace
+
+bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ Operator* op = op_it->get();
+
+ // Get the weights tensor, if the current operator has one.
+ int weights_index;
+ if (op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected) {
+ weights_index = 1;
+ } else if (op->type == OperatorType::kLstmCell) {
+ weights_index = LstmCellOperator::WEIGHTS_INPUT;
+ } else {
+ return false;
+ }
+
+ // Return early if the array isn't a constant param, this can happen in early
+ // transformation passes until transpose operations following the weight array
+ // are resolved.
+ const string weights = op->inputs[weights_index];
+ if (!IsConstantParameterArray(*model, weights)) {
+ return false;
+ }
+
+ // Return early if the weight tensor is not type float.
+ Array& weights_array = model->GetArray(weights);
+ if (weights_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+
+ // Return early if the tensor is too small. Small tensors don't take up too
+ // much space and can result in bad quantization results.
+ if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() <
+ kWeightsMinSize) {
+ return false;
+ }
+
+ // Quantize the weight tensor to type kUint8.
+ QuantizationParams params;
+ GetQuantizationParamsFromArray(weights_array, &params);
+ QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
+
+ // Insert a Dequantize operation after the quantized weights tensor.
+ auto* dequantize_op = new DequantizeOperator;
+ model->operators.emplace(op_it, dequantize_op);
+
+ // Create a new intermediate tensor to connect the Dequantize op to the
+ // original op.
+ const string dequantized_output =
+ AvailableArrayName(*model, weights + "_dequantized");
+ Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output);
+ dequantized_output_array.data_type = ArrayDataType::kFloat;
+
+ // Connect up the new Dequantize op with the weights and original op.
+ op->inputs[weights_index] = dequantized_output;
+ dequantize_op->inputs = {weights};
+ dequantize_op->outputs = {dequantized_output};
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
index 35a0c46532..73ad326299 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -26,7 +26,7 @@ namespace toco {
bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
const auto assert_it = model->operators.begin() + op_index;
const auto* assert_op = assert_it->get();
- if (assert_op->type != OperatorType::kTensorFlowAssert) {
+ if (assert_op->type != OperatorType::kAssert) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
index 404269bbfd..7ec7752f25 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -28,7 +28,7 @@ namespace toco {
bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
const auto passthru_it = model->operators.begin() + op_index;
const auto* passthru_op = passthru_it->get();
- if (passthru_op->type != OperatorType::kTensorFlowIdentity) {
+ if (passthru_op->type != OperatorType::kIdentity) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index a950fe6442..9f5d8b9450 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -97,7 +97,7 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
"Cannot remove %s, neither its main input nor its output may be "
"discarded",
LogName(*passthru_op));
- if (passthru_op->type != OperatorType::kTensorFlowReshape &&
+ if (passthru_op->type != OperatorType::kReshape &&
model->GetArray(main_input_name).has_shape()) {
// We can't remove either array but we can remove the op. Converting it to
// a reshape gives us some hope of later on fixing that (either in the
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
index eaee1c662b..142c876b15 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_min_max.cc
@@ -47,11 +47,11 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
double clamp_min;
double clamp_max;
switch (op_type) {
- case OperatorType::kTensorFlowMinimum:
+ case OperatorType::kMinimum: // Element-wise Minimum
clamp_min = -std::numeric_limits<double>::infinity();
clamp_max = clamp_value;
break;
- case OperatorType::kTensorFlowMaximum:
+ case OperatorType::kMaximum: // Element-wise Maximum
clamp_min = clamp_value;
clamp_max = std::numeric_limits<double>::infinity();
break;
@@ -72,8 +72,8 @@ bool IsTrivialMinMax(GraphTransformation* transformation, const Model& model,
bool RemoveTrivialQuantizedMinMax::Run(Model* model, std::size_t op_index) {
const auto it = model->operators.begin() + op_index;
auto* op = it->get();
- if ((op->type != OperatorType::kTensorFlowMinimum &&
- op->type != OperatorType::kTensorFlowMaximum) ||
+ if ((op->type != OperatorType::kMinimum &&
+ op->type != OperatorType::kMaximum) ||
op->inputs.size() != 2) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
index e28d8cf01e..404f27e067 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -30,7 +30,7 @@ namespace {
bool IsReshapeTrivial(const Model& model, const Operator& op,
RemoveTrivialReshape* transformation) {
- CHECK(op.type == OperatorType::kTensorFlowReshape);
+ CHECK(op.type == OperatorType::kReshape);
// One way in which a reshape can be trivial is if its
// output shape is == its input shape
@@ -58,7 +58,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
// is only consumed by another reshape.
if (CountOpsWithInput(model, op.outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(model, op.outputs[0]);
- if (next_op->type == OperatorType::kTensorFlowReshape) {
+ if (next_op->type == OperatorType::kReshape) {
transformation->AddMessageF(
"%s is trivial because its output is only consumed by another "
"Reshape op %s",
@@ -75,7 +75,7 @@ bool IsReshapeTrivial(const Model& model, const Operator& op,
bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
- if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ if (reshape_op->type != OperatorType::kReshape) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
index 1956ab2d20..dde91234a8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -48,7 +48,7 @@ bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
for (const auto& rnn_state : model->flags.rnn_states()) {
if (output == rnn_state.state_array()) {
CHECK(op->type == OperatorType::kFill ||
- op->type == OperatorType::kTensorFlowIdentity);
+ op->type == OperatorType::kIdentity);
found_output_as_rnn_state_array = true;
break;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
index 9f5b7920cb..550de83018 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_elementwise_unary.cc
@@ -37,8 +37,8 @@ bool IsElementwiseOperator(OperatorType optype) {
case OperatorType::kRelu1:
case OperatorType::kRelu6:
case OperatorType::kTanh:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
return true;
default:
return false;
@@ -51,7 +51,7 @@ bool IsMoveOperator(OperatorType optype) {
case OperatorType::kExpandDims:
case OperatorType::kSpaceToDepth:
case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kReshape:
case OperatorType::kTranspose:
return true;
default:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
index 9e7fe1b1cc..c907a597cb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_reshape_transpose.cc
@@ -123,8 +123,8 @@ bool ReorderReshapeTranspose::Run(Model* model, std::size_t op_index) {
}
TensorFlowReshapeOperator* reshape_op =
- ConvertOperator<TensorFlowReshapeOperator*>(
- reshape_it->get(), OperatorType::kTensorFlowReshape);
+ ConvertOperator<TensorFlowReshapeOperator*>(reshape_it->get(),
+ OperatorType::kReshape);
if (reshape_op == nullptr) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
index a06919e228..b8b35161d7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_to_space_nd_attributes.cc
@@ -50,7 +50,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
// will delete this op.
return false;
}
- std::vector<int> crops_buffer =
+ const std::vector<int>& crops_buffer =
crops_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < crops_dims[0]; ++i) {
op->before_crops.push_back(crops_buffer[i * 2]);
@@ -62,7 +62,7 @@ bool ResolveBatchToSpaceNDAttributes::Run(Model* model, std::size_t op_index) {
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
- std::vector<int> block_shape_buffer =
+ const std::vector<int>& block_shape_buffer =
block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < block_shape_dims[0]; ++i) {
op->block_shape.push_back(block_shape_buffer[i]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index 6e78653fad..f7e5aa6609 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -145,17 +145,17 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
outval = floor(val0 / val1);
} else if (binary_op->type == OperatorType::kFloorMod) {
outval = val0 - (floor(val0 / val1) * val1);
- } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+ } else if (binary_op->type == OperatorType::kMinimum) {
outval = std::min(val0, val1);
- } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+ } else if (binary_op->type == OperatorType::kMaximum) {
outval = std::max(val0, val1);
- } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+ } else if (binary_op->type == OperatorType::kLess) {
outval = val0 < val1;
- } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+ } else if (binary_op->type == OperatorType::kLessEqual) {
outval = val0 <= val1;
- } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+ } else if (binary_op->type == OperatorType::kGreater) {
outval = val0 > val1;
- } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+ } else if (binary_op->type == OperatorType::kGreaterEqual) {
outval = val0 >= val1;
} else {
LOG(FATAL) << "should not get here";
@@ -198,12 +198,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kDiv &&
binary_op->type != OperatorType::kFloorDiv &&
binary_op->type != OperatorType::kFloorMod &&
- binary_op->type != OperatorType::kTensorFlowMinimum &&
- binary_op->type != OperatorType::kTensorFlowMaximum &&
- binary_op->type != OperatorType::kTensorFlowLess &&
- binary_op->type != OperatorType::kTensorFlowLessEqual &&
- binary_op->type != OperatorType::kTensorFlowGreater &&
- binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
+ binary_op->type != OperatorType::kMinimum &&
+ binary_op->type != OperatorType::kMaximum &&
+ binary_op->type != OperatorType::kLess &&
+ binary_op->type != OperatorType::kLessEqual &&
+ binary_op->type != OperatorType::kGreater &&
+ binary_op->type != OperatorType::kGreaterEqual) {
return false;
}
CHECK_EQ(binary_op->inputs.size(), 2);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index 7e7ad383e7..41562ab393 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -25,7 +25,7 @@ namespace toco {
bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
- if (base_op->type != OperatorType::kTensorFlowReshape) {
+ if (base_op->type != OperatorType::kReshape) {
return false;
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 9ea01acd05..8a0e3e8995 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -22,8 +22,7 @@ namespace toco {
bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
const auto it = model->operators.begin() + op_index;
const auto* op = it->get();
- if (!(op->type == OperatorType::kTensorFlowShape ||
- op->type == OperatorType::kRank)) {
+ if (!(op->type == OperatorType::kShape || op->type == OperatorType::kRank)) {
return false;
}
@@ -48,7 +47,7 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
// Compute the output
CHECK(!output_array.buffer);
auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
- if (op->type == OperatorType::kTensorFlowShape) {
+ if (op->type == OperatorType::kShape) {
// Copy the input shape into the output buffer.
output_buffer.data = input_array.shape().dims();
} else if (op->type == OperatorType::kRank) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
index 69db1942cd..a4d5f1923a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_stack.cc
@@ -41,7 +41,7 @@ void Stack(Model* model, StackOperator const& op) {
const auto& input_array = model->GetArray(op.inputs[i]);
int input_size = RequiredBufferSizeForShape(input_array.shape());
memcpy(&output_data[dst_offset], &input_array.GetBuffer<Type>().data[0],
- input_size * sizeof(Type));
+ input_size * ElementSize(Type));
dst_offset += input_size;
}
CHECK_EQ(dst_offset, output_data.size());
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 1dd52e9069..9d8bd4fc39 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -38,6 +38,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
CHECK_EQ(op.new_axis_mask, 0);
int num_input_axes = op.start_indices.size();
+ CHECK_EQ(num_input_axes, op.start_indices.size());
CHECK_EQ(num_input_axes, op.stop_indices.size());
CHECK_EQ(num_input_axes, op.strides.size());
@@ -49,11 +50,16 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
// Initialize source coordinate
Shape const& input_shape = input_array.shape();
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
- std::vector<int> src_coord(op.start_indices.size());
+ std::vector<int> src_coord(num_input_axes);
+ std::vector<int> stop_for_axis(num_input_axes);
for (int axis = 0; axis < num_input_axes; axis++) {
- src_coord[axis] = tflite::strided_slice::StartForAxis(
+ int start = tflite::strided_slice::StartForAxis(
op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
axis);
+ src_coord[axis] = start;
+ stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
+ op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
+ input_shape.dims().data(), axis, start);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -76,9 +82,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
}
// Check if we've overflowed.
- int stop = tflite::strided_slice::StopForAxis(
- op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(),
- axis);
+ int stop = stop_for_axis[axis];
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
@@ -155,14 +159,7 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
break;
}
- // Erase input array if no longer used
- if (IsDiscardableArray(*model, op->inputs[0]) &&
- CountOpsWithInput(*model, op->inputs[0]) == 1) {
- model->EraseArray(op->inputs[0]);
- }
-
- // Erase the operator
- model->operators.erase(it);
+ DeleteOpAndArraysIfUnused(model, it->get());
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index f6c8f79d8d..f89ef85fdb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -53,13 +53,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kCast:
case OperatorType::kLog:
case OperatorType::kNeg:
- case OperatorType::kTensorFlowRsqrt:
- case OperatorType::kTensorFlowSqrt:
- case OperatorType::kTensorFlowSquare:
- case OperatorType::kTensorFlowSum:
- case OperatorType::kTensorFlowMin:
- case OperatorType::kTensorFlowMax:
- case OperatorType::kTensorFlowReshape:
+ case OperatorType::kRsqrt:
+ case OperatorType::kSqrt:
+ case OperatorType::kSquare:
+ case OperatorType::kSum:
+ case OperatorType::kMin: // Reduction Min
+ case OperatorType::kMax: // Reduction Max
+ case OperatorType::kReshape:
case OperatorType::kRelu6:
case OperatorType::kRelu1:
case OperatorType::kRelu:
@@ -103,7 +103,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// The min-max is only copied for ops that copy data without arithmetic.
// In future trivial transpose, etc, can be handled here.
- if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ if (unary_op->type == OperatorType::kReshape) {
CopyMinMaxFromFirstInput(*unary_op, model);
}
@@ -164,10 +164,10 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = outval;
}
- } else if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ } else if (unary_op->type == OperatorType::kReshape) {
CHECK(input_buffer_size == output_buffer_size);
output_float_data = *input_float_data;
- } else if (unary_op->type == OperatorType::kTensorFlowSum) {
+ } else if (unary_op->type == OperatorType::kSum) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
@@ -196,7 +196,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
output_float_data[i] = sum;
}
- } else if (unary_op->type == OperatorType::kTensorFlowMin) {
+ } else if (unary_op->type == OperatorType::kMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
for (int i = 0; i < output_dims_count; i++) {
@@ -207,7 +207,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
min = std::min(min, (*input_float_data)[i]);
}
output_float_data[0] = min;
- } else if (unary_op->type == OperatorType::kTensorFlowMax) {
+ } else if (unary_op->type == OperatorType::kMax) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
for (int i = 0; i < output_dims_count; i++) {
@@ -220,9 +220,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
output_float_data[0] = max;
} else if (unary_op->type == OperatorType::kNeg ||
unary_op->type == OperatorType::kLog ||
- unary_op->type == OperatorType::kTensorFlowRsqrt ||
- unary_op->type == OperatorType::kTensorFlowSqrt ||
- unary_op->type == OperatorType::kTensorFlowSquare) {
+ unary_op->type == OperatorType::kRsqrt ||
+ unary_op->type == OperatorType::kSqrt ||
+ unary_op->type == OperatorType::kSquare) {
// Element-wise ops. Should have perfectly matching sizes here.
for (int i = 0; i < output_dims_count; i++) {
CHECK_EQ(output_shape.dims(i), input_shape.dims(i));
@@ -235,11 +235,11 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
outval = -val;
} else if (unary_op->type == OperatorType::kLog) {
outval = std::log(val);
- } else if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
+ } else if (unary_op->type == OperatorType::kRsqrt) {
outval = 1.0f / std::sqrt(val);
- } else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
+ } else if (unary_op->type == OperatorType::kSqrt) {
outval = std::sqrt(val);
- } else if (unary_op->type == OperatorType::kTensorFlowSquare) {
+ } else if (unary_op->type == OperatorType::kSquare) {
outval = val * val;
} else {
LOG(FATAL) << "should not get here.";
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index bc70db0bd8..8266e2c205 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -51,11 +51,12 @@ void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
- auto reorder_it = model->operators.begin() + op_index;
- auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
- if (reorder_op->type != OperatorType::kReorderAxes) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->type != OperatorType::kReorderAxes) {
return false;
}
+ auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
const auto& input_array_name = reorder_op->inputs[0];
const auto& output_array_name = reorder_op->outputs[0];
auto& input_array = model->GetArray(input_array_name);
@@ -95,7 +96,7 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
// Remove the op and output array.
model->EraseArray(output_array_name);
- model->operators.erase(reorder_it);
+ model->operators.erase(it);
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
index 2e063e3554..b615c9a545 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -28,7 +28,7 @@ namespace toco {
bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
const auto reshape_it = model->operators.begin() + op_index;
auto* reshape_op = reshape_it->get();
- if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ if (reshape_op->type != OperatorType::kReshape) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
index dad6aceccf..fab50bec1f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_space_to_batch_nd_attributes.cc
@@ -53,7 +53,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
// will delete this op.
return false;
}
- std::vector<int> paddings_buffer =
+ const std::vector<int>& paddings_buffer =
paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < paddings_dims[0]; ++i) {
op->before_paddings.push_back(paddings_buffer[i * 2]);
@@ -66,7 +66,7 @@ bool ResolveSpaceToBatchNDAttributes::Run(Model* model, std::size_t op_index) {
if (!block_shape_array.has_shape()) return false;
const std::vector<int>& block_shape_dims = block_shape_array.shape().dims();
CHECK_EQ(block_shape_dims.size(), 1);
- std::vector<int> block_shape_buffer =
+ const std::vector<int>& block_shape_buffer =
block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
for (int i = 0; i < block_shape_dims[0]; ++i) {
op->block_shape.push_back(block_shape_buffer[i]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
index dd3e73635a..e8bb85704e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_squeeze_attributes.cc
@@ -36,7 +36,7 @@ bool ResolveSqueezeAttributes::Run(Model* model, std::size_t op_index) {
// If the output is consumed by a reshape op, it's a trivial squeeze.
if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
- if (next_op->type == OperatorType::kTensorFlowReshape) {
+ if (next_op->type == OperatorType::kReshape) {
AddMessageF(
"%s is trivial because its output is only consumed by a "
"Reshape op",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
index 5c0c1e3478..fa5ee89933 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -28,8 +28,8 @@ namespace toco {
bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
auto concat_it = model->operators.begin() + op_index;
const auto* tf_concat_op = concat_it->get();
- if (tf_concat_op->type != OperatorType::kTensorFlowConcat &&
- tf_concat_op->type != OperatorType::kTensorFlowConcatV2) {
+ if (tf_concat_op->type != OperatorType::kConcat &&
+ tf_concat_op->type != OperatorType::kConcatV2) {
return false;
}
@@ -38,7 +38,7 @@ bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
// of inputs: in Concat,the axis is the first input, while in
// ConcatV2, it is the last input.
std::size_t axis_pos = 0;
- if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) {
+ if (tf_concat_op->type == OperatorType::kConcatV2) {
axis_pos = tf_concat_op->inputs.size() - 1;
}
const string axis_name = tf_concat_op->inputs[axis_pos];
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index 2a236d3f98..fcf30bd347 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -26,27 +26,40 @@ namespace toco {
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
- if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) {
+ if (matmul_it->get()->type != OperatorType::kMatMul) {
return false;
}
const auto* matmul_op =
static_cast<const TensorFlowMatMulOperator*>(matmul_it->get());
+ // Handling transposition of the first input here isn't very simple because
+ // we need to know the actual shape in order to produce a proper
+ // TransposeOperator. However, the second input is supposed to be 2D, so we
+ // can actually handle transposition of that matrix, which happens to be more
+ // common anyway.
+ CHECK(!matmul_op->transpose_a);
+
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
// operator. We'll transpose the second input to be in column-major order now
// and let constant propagation optimize things (if possible).
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(
- model,
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
+ string input_lhs = matmul_op->inputs[0];
+ string input_rhs = matmul_op->inputs[1];
+ if (!matmul_op->transpose_b) {
+ auto* transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ matmul_op->inputs[1],
+ CreateInt32Array(model,
+ AvailableArrayName(
+ *model, matmul_op->inputs[1] + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+
+ input_rhs = transpose_op->outputs[0];
+ }
// Refresh iterator.
matmul_it = model->operators.begin();
@@ -57,9 +70,6 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
}
DCHECK_EQ(matmul_it->get(), matmul_op);
- string input_lhs = matmul_op->inputs[0];
- string input_rhs = transpose_op->outputs[0];
-
// Construct the new FullyConnectedOperator.
auto* fc_op = new FullyConnectedOperator;
fc_op->outputs = matmul_op->outputs;
@@ -97,7 +107,7 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops if
// the input doesn't need reshaping, so we can't just match (Reshape, MatMul)
// pairs.
- if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) {
+ if (previous_op && previous_op->type == OperatorType::kReshape) {
AddMessageF("Combining %s and %s into %s", LogName(*previous_op),
LogName(*matmul_op), LogName(*fc_op));
const auto& previous_op_output = previous_op->outputs[0];
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
index 38e0005890..4edffe3d48 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -27,7 +27,7 @@ namespace toco {
bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
const auto merge_it = model->operators.begin() + op_index;
const auto* merge_op = merge_it->get();
- if (merge_op->type != OperatorType::kTensorFlowMerge) {
+ if (merge_op->type != OperatorType::kMerge) {
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
index a418073441..da8e7a2d1c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -27,7 +27,7 @@ namespace toco {
bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
const auto switch_it = model->operators.begin() + op_index;
const auto* switch_op = switch_it->get();
- if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+ if (switch_op->type != OperatorType::kSwitch) {
return false;
}
@@ -92,7 +92,7 @@ bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
if (*input_it == switch_op->outputs[nonselected_output_index]) {
// Let us guard our assumption that only Merge nodes consume the outputs
// of Switch nodes:
- CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+ CHECK(other_op->type == OperatorType::kMerge);
input_it = other_op->inputs.erase(input_it);
} else {
++input_it;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
deleted file mode 100644
index 1ddf54c778..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-
-namespace {
-
-void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
- int operand_index) {
- CHECK(tile_op->type == OperatorType::kTensorFlowTile);
- CHECK_EQ(binary_op->inputs.size(), 2);
- CHECK_EQ(tile_op->inputs.size(), 2);
- const string tile_multiplier_array = tile_op->inputs[1];
- const string tile_output_array = tile_op->outputs[0];
- binary_op->inputs[operand_index] = tile_op->inputs[0];
- auto tile_it = model->operators.begin();
- for (; tile_it != model->operators.end(); ++tile_it) {
- if (tile_it->get() == tile_op) {
- break;
- }
- }
- CHECK(tile_it != model->operators.end());
- CHECK(tile_it->get() == tile_op);
- model->operators.erase(tile_it);
- if (!CountOpsWithInput(*model, tile_multiplier_array) &&
- !GetOpWithOutput(*model, tile_multiplier_array)) {
- model->EraseArray(tile_multiplier_array);
- }
- if (!CountOpsWithInput(*model, tile_output_array)) {
- model->EraseArray(tile_output_array);
- }
-}
-} // namespace
-
-bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) {
- const auto binary_it = model->operators.begin() + op_index;
- auto* binary_op = binary_it->get();
- // Test for binary ops of types that we know how to resolve
- if (binary_op->inputs.size() != 2) {
- return false;
- }
- if (binary_op->type != OperatorType::kAdd &&
- binary_op->type != OperatorType::kMul &&
- binary_op->type != OperatorType::kSub &&
- binary_op->type != OperatorType::kDiv) {
- return false;
- }
-
- Operator* const op[2] = {
- GetOpWithOutput(*model, binary_op->inputs[0]),
- GetOpWithOutput(*model, binary_op->inputs[1]),
- };
-
- // In the unlikely case where both operands are Tile, we can't infer the
- // output
- // size without the Tile nodes, so we have to bail out.
- if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] &&
- op[1]->type == OperatorType::kTensorFlowTile) {
- return false;
- }
-
- for (int i = 0; i < 2; i++) {
- if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) {
- // We can only remove a Tile operator is no other op than the present
- // binary op was consuming its tiled output.
- if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) {
- AddMessageF("Removing %s", LogName(*op[i]));
- RemoveTileOperator(model, op[i], binary_op, i);
- return true;
- }
- }
- }
- return false;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
index c00cdcb944..22c258cec5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/shuffle_fc_weights.cc
@@ -24,14 +24,14 @@ limitations under the License.
namespace toco {
-bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) {
+bool ShuffleFCWeights::Run(Model* model, std::size_t op_index) {
Operator* op = model->operators[op_index].get();
if (op->type != OperatorType::kFullyConnected) {
return false;
}
FullyConnectedOperator* fc_op = static_cast<FullyConnectedOperator*>(op);
// Exit if this FC op already has shuffled weights
- if (fc_op->experimental_shuffled_weights) {
+ if (fc_op->weights_format != FullyConnectedWeightsFormat::kDefault) {
return false;
}
const Array& input_array = model->GetArray(fc_op->inputs[0]);
@@ -135,7 +135,7 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) {
CHECK_EQ(shuffled_data_ptr, shuffled_data.data() + rows * cols);
// Switch this FC op to using the shuffled weights.
weights_data = std::move(shuffled_data);
- fc_op->experimental_shuffled_weights = true;
+ fc_op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
AddMessageF("Applied experimental shuffling to the weights of %s",
LogName(*op));
// Add a second output array to this FC op, serving as a workspace to perform
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index 8dcd4adc90..95e8433be2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -8,8 +8,8 @@ load(
)
tf_cc_test(
- name = "resolve_constant_concatenation_test",
- srcs = ["resolve_constant_concatenation_test.cc"],
+ name = "lstm_utils_test",
+ srcs = ["lstm_utils_test.cc"],
deps = [
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
@@ -19,8 +19,20 @@ tf_cc_test(
)
tf_cc_test(
- name = "lstm_utils_test",
- srcs = ["lstm_utils_test.cc"],
+ name = "quantize_weights_test",
+ srcs = ["quantize_weights_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "resolve_constant_concatenation_test",
+ srcs = ["resolve_constant_concatenation_test.cc"],
deps = [
"//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
new file mode 100644
index 0000000000..c05eb0929f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <math.h>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+ QuantizeWeightsTest() {}
+
+ // The name of the weights input array.
+ const string kWeightsName = "weights";
+ // The zero_point of the values in the input array.
+ const int kZeroPoint = 128;
+
+ // Prepare a hypothetical TOCO model of a quantizable fully connected float
+ // layer.
+ void PrepareModel(Model* model, int elements_per_dim) {
+ std::vector<string> fc_input_names = {"inputs", kWeightsName};
+
+ const int kDim = 4;
+ const int buf_size = std::pow(elements_per_dim, static_cast<double>(kDim));
+ auto in_buf = absl::make_unique<float[]>(buf_size);
+ // Initialize the array with values from -128.0 to 127.0, since these values
+ // should be exactly representable by quantization.
+ for (int i = 0; i < buf_size; i++) {
+ in_buf[i] = static_cast<float>(i % 256 - kZeroPoint);
+ }
+
+ for (const string& fc_input_name : fc_input_names) {
+ Array& in_array = model->GetOrCreateArray(fc_input_name);
+ in_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* in_array_shape = in_array.mutable_shape();
+ std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
+ in_array_shape_dim->resize(kDim, elements_per_dim);
+ auto& in_array_buffer =
+ in_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ in_array_buffer.data.resize(buf_size);
+ float* buf_ptr =
+ in_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr);
+ }
+
+ auto* fc_op = new FullyConnectedOperator;
+ fc_op->inputs = fc_input_names;
+ fc_op->outputs = {"fc_op_outputs"};
+ Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
+ out_array.data_type = ArrayDataType::kFloat;
+ Shape* out_array_shape = out_array.mutable_shape();
+ std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
+ out_array_shape_dim->resize(kDim, elements_per_dim);
+ model->operators.push_back(std::unique_ptr<Operator>(fc_op));
+ }
+};
+
+TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
+ // Test that weight arrays that are large enough are quantized.
+ Model model;
+ // 6 elements per dim gives us 1296 elements, which is sufficient to be
+ // quantized.
+ PrepareModel(&model, 6);
+
+ // Check the state of the graph before the transformation.
+ const auto& float_array_map = model.GetArrayMap();
+ EXPECT_EQ(float_array_map.size(), 3);
+ // Before the transformation, all arrays should be type float.
+ for (const auto& element : float_array_map) {
+ EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
+ }
+ const std::vector<float> float_weight_vals =
+ model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+
+ // Invoke the transformation.
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::QuantizeWeights);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+
+ // Check the state of the graph after the transformation.
+ const auto& quantized_array_map = model.GetArrayMap();
+ EXPECT_EQ(quantized_array_map.size(), 4);
+ // After the transformation, three arrays should be type float and one array
+ // should be uint8.
+ int num_float = 0;
+ int num_uint8 = 0;
+ for (const auto& element : quantized_array_map) {
+ if (element.second->data_type == ArrayDataType::kFloat) {
+ num_float++;
+ } else if (element.second->data_type == ArrayDataType::kUint8) {
+ num_uint8++;
+ } else {
+ FAIL() << "Unexpected array type.";
+ }
+ }
+ EXPECT_EQ(num_float, 3);
+ EXPECT_EQ(num_uint8, 1);
+ // Ensure that the values were quantized correctly.
+ const std::vector<uint8>& quantized_weight_vals =
+ model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
+ for (int i = 0; i < quantized_weight_vals.size(); i++) {
+ EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
+ }
+
+ // Ensure that a Dequantize operator has been inserted before the
+ // FullyConnectedLayer.
+ EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize);
+}
+
+TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
+ // Test that weight arrays that are too small are left untouched.
+ Model model;
+ // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
+ // quantized.
+ PrepareModel(&model, 5);
+
+ // Check the state of the graph before the transformation.
+ const auto& float_array_map = model.GetArrayMap();
+ EXPECT_EQ(float_array_map.size(), 3);
+ // Before the transformation, all arrays should be type float.
+ for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) {
+ EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
+ }
+ std::vector<float> float_weight_vals =
+ model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+
+ // Invoke the transformation.
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::QuantizeWeights);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+
+ // Check the state of the graph after the transformation.
+ const auto& post_array_map = model.GetArrayMap();
+ EXPECT_EQ(post_array_map.size(), 3);
+ for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) {
+ EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
+ }
+ // Ensure that the values remain unchanged.
+ std::vector<float> const& quantized_weight_vals =
+ model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
+ for (int i = 0; i < quantized_weight_vals.size(); i++) {
+ EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]);
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
index 3a1d175b98..66cfed4ac2 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <memory>
#include <string>
-#include <unordered_map>
#include <vector>
#include <gmock/gmock.h>
@@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test {
Array& in_array = model->GetOrCreateArray(concat_input_name);
in_array.data_type = ArrayDataType::kFloat;
- // Initialize shape for the input array.
+ // Initialize shape for the input array.
Shape* in_array_shape = in_array.mutable_shape();
std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
for (int i = 0; i < kDim; i++) {
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 83dce66df1..bc439a2feb 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
#include "tensorflow/contrib/lite/toco/tensorflow_util.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -44,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
@@ -63,8 +63,6 @@ using tensorflow::TensorShapeProto;
namespace toco {
-using port::Status;
-
namespace {
bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
@@ -130,6 +128,42 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node,
return attr.list();
}
+tensorflow::Status CheckOptionalAttr(const NodeDef& node,
+ const string& attr_name,
+ const string& expected_value) {
+ if (HasAttr(node, attr_name)) {
+ const string& value = GetStringAttr(node, attr_name);
+ if (value != expected_value) {
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ expected_value + "'");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status CheckOptionalAttr(
+ const NodeDef& node, const string& attr_name,
+ const tensorflow::DataType& expected_value) {
+ if (HasAttr(node, attr_name)) {
+ const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
+ if (value != expected_value) {
+ return tensorflow::errors::InvalidArgument(
+ "Unexpected value for attribute '" + attr_name + "'. Expected '" +
+ tensorflow::DataType_Name(expected_value) + "'");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+template <typename T1, typename T2>
+tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
+ const string& description) {
+ if (v1 == v2) return tensorflow::Status::OK();
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Unexpected ", description, ": got ", v1, ", expected ", v2));
+}
+
ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
if (dtype == DT_UINT8)
return ArrayDataType::kUint8;
@@ -148,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
return ArrayDataType::kNone;
}
-Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
- tensorflow::TensorShapeProto_Dim>& input_dims,
- int* input_flat_size, Shape* shape) {
+tensorflow::Status ImportShape(
+ const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
+ input_dims,
+ int* input_flat_size, Shape* shape) {
std::vector<int> input_dims_only_sizes;
for (auto& d : input_dims) {
if (d.size() == 0) {
@@ -160,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField<
// For now, tweaking this to record a 0-D shape instead.
shape->mutable_dims()->clear();
if (input_flat_size != nullptr) *input_flat_size = 0;
- return Status::OK();
+ return tensorflow::Status::OK();
}
// TensorFlow's shapes use int64s, while TOCO uses ints.
if (d.size() > std::numeric_limits<int>::max()) {
- return Status(false, "Shape element overflows");
+ return tensorflow::errors::InvalidArgument("Shape element overflows");
}
input_dims_only_sizes.push_back(d.size());
}
*shape->mutable_dims() = input_dims_only_sizes;
- if (input_flat_size == nullptr) return Status::OK();
+ if (input_flat_size == nullptr) return tensorflow::Status::OK();
return NumElements(input_dims_only_sizes, input_flat_size);
}
-Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -203,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_float_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(float),
") nor float_val (", input_tensor.float_val_size(),
") have the right dimensions (", input_flat_size,
") for this float tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -227,7 +263,11 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int_val_size()) {
+ if (input_tensor.int_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int_val(0);
+ }
+ } else if (input_tensor.int_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
}
@@ -236,18 +276,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(uint8_t),
") nor int_val (", input_tensor.int_val_size(),
") have the right dimensions (", input_flat_size,
") for this uint8 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT32);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -260,7 +300,11 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int_val_size()) {
+ if (input_tensor.int_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int_val(0);
+ }
+ } else if (input_tensor.int_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.int_val_size(); i++) {
output_int_data[i] = input_tensor.int_val(i);
}
@@ -269,18 +313,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
- absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size() / sizeof(int32),
- ") nor int_val (", input_tensor.int_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this int32 tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (",
+ input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (",
+ input_tensor.int_val_size(), ") have the right dimensions (",
+ input_flat_size, ") for this int32 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_INT64);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -293,8 +336,12 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
CHECK_GE(output_int_data.size(), input_flat_size);
- if (input_tensor.int64_val_size()) {
- for (int i = 0; i < input_tensor.int64_val_size(); i++) {
+ if (input_tensor.int64_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_int_data[i] = input_tensor.int64_val(0);
+ }
+ } else if (input_tensor.int64_val_size() == input_flat_size) {
+ for (int i = 0; i < input_tensor.float_val_size(); i++) {
output_int_data[i] = input_tensor.int64_val(i);
}
} else if (input_tensor.tensor_content().size() ==
@@ -302,18 +349,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) {
toco::port::CopyToBuffer(input_tensor.tensor_content(),
reinterpret_cast<char*>(output_int_data.data()));
} else {
- return Status(
- false,
+ return tensorflow::errors::InvalidArgument(
absl::StrCat("Neither input_content (",
input_tensor.tensor_content().size() / sizeof(int64),
") nor int64_val (", input_tensor.int64_val_size(),
") have the right dimensions (", input_flat_size,
") for this int64 tensor"));
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_BOOL);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -327,7 +374,11 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
false);
CHECK_GE(output_bool_data.size(), input_flat_size);
- if (input_tensor.bool_val_size()) {
+ if (input_tensor.bool_val_size() == 1) {
+ for (int i = 0; i < input_flat_size; i++) {
+ output_bool_data[i] = input_tensor.bool_val(0);
+ }
+ } else if (input_tensor.bool_val_size() == input_flat_size) {
for (int i = 0; i < input_tensor.bool_val_size(); i++) {
output_bool_data[i] = input_tensor.bool_val(i);
}
@@ -343,19 +394,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) {
// So far only encountered that in an array with 1 entry, let's
// require that until we encounter a graph where that's not the case.
if (output_bool_data.size() != 1) {
- return Status(
- false, absl::StrCat("Neither input_content (",
- input_tensor.tensor_content().size(),
- ") nor bool_val (", input_tensor.bool_val_size(),
- ") have the right dimensions (", input_flat_size,
- ") for this bool tensor"));
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Neither input_content (", input_tensor.tensor_content().size(),
+ ") nor bool_val (", input_tensor.bool_val_size(),
+ ") have the right dimensions (", input_flat_size,
+ ") for this bool tensor"));
}
output_bool_data[0] = false;
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
-Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
+tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
+ Array* output_array) {
CHECK_EQ(input_tensor.dtype(), DT_STRING);
const auto& input_shape = input_tensor.tensor_shape();
CHECK_LE(input_shape.dim_size(), 4);
@@ -365,9 +416,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
if (!status.ok()) return status;
if (input_flat_size != input_tensor.string_val_size()) {
- return Status(false,
- "Input_content string_val doesn't have the right dimensions "
- "for this string tensor");
+ return tensorflow::errors::InvalidArgument(
+ "Input_content string_val doesn't have the right dimensions "
+ "for this string tensor");
}
auto& output_string_data =
@@ -377,7 +428,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) {
for (int i = 0; i < input_flat_size; ++i) {
output_string_data[i] = input_tensor.string_val(i);
}
- return Status::OK();
+ return tensorflow::Status::OK();
}
// Count the number of inputs of a given node. If
@@ -391,18 +442,19 @@ int GetInputsCount(const NodeDef& node,
return i;
}
}
- return node.input_size();
- } else {
- return node.input_size();
}
+ return node.input_size();
}
-void CheckInputsCount(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- int expected_input_count) {
- QCHECK_EQ(GetInputsCount(node, tf_import_flags), expected_input_count)
- << node.op() << " node expects " << expected_input_count
- << " input(s) other than control dependencies: " << node.DebugString();
+tensorflow::Status CheckInputsCount(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ int expected_input_count) {
+ if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
+ return tensorflow::errors::FailedPrecondition(
+ node.op(), " node expects ", expected_input_count,
+ " input(s) other than control dependencies: ", node.DebugString());
+ }
+ return tensorflow::Status::OK();
}
template <ArrayDataType T>
@@ -417,14 +469,14 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
-Status ConvertConstOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConstOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Const");
const auto& tensor = GetTensorAttr(node, "value");
const auto dtype = GetDataTypeAttr(node, "dtype");
- Status status = Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
auto& array = model->GetOrCreateArray(node.name());
switch (dtype) {
@@ -460,24 +512,21 @@ Status ConvertConstOperator(const NodeDef& node,
array.GetMutableBuffer<ArrayDataType::kNone>();
break;
}
- if (!status.ok()) {
- status.AppendMessage(" (while processing node '" + node.name() + "')");
- }
- return status;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ status, " (while processing node '" + node.name() + "')");
+ return tensorflow::Status::OK();
}
-void ConvertConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2D");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
- if (HasAttr(node, "data_format")) {
- CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
- }
- CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
+ TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
const auto& input_name = node.input(0);
const auto& weights_name = node.input(1);
@@ -502,27 +551,26 @@ void ConvertConvOperator(const NodeDef& node,
auto* conv = new ConvOperator;
conv->inputs = {input_name, reordered_weights_name};
conv->outputs = {node.name()};
+ if (!HasAttr(node, "strides")) {
+ return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
+ }
const auto& strides = GetListAttr(node, "strides");
- CHECK_EQ(strides.i_size(), 4);
- CHECK_EQ(strides.i(0), 1);
- CHECK_EQ(strides.i(3), 1);
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
+ TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
- CHECK_EQ(dilations.i_size(), 4);
- CHECK_EQ(dilations.i(0), 1)
- << "Can only import Conv ops with dilation along the height (1st) or "
- "width (2nd) axis. TensorFlow op \""
- << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
- << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
- << "].";
- CHECK_EQ(dilations.i(3), 1)
- << "Can only import Conv ops with dilation along the height (1st) or "
- "width (2nd) axis. TensorFlow op \""
- << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
- << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
- << "].";
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
} else {
@@ -535,16 +583,19 @@ void ConvertConvOperator(const NodeDef& node,
} else if (padding == "VALID") {
conv->padding.type = PaddingType::kValid;
} else {
- LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
+ return tensorflow::errors::InvalidArgument(
+ "Bad padding (only SAME and VALID are supported)");
}
model->operators.emplace_back(conv);
+
+ return tensorflow::Status::OK();
}
-void ConvertDepthwiseConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDepthwiseConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "DepthwiseConv2dNative");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
@@ -591,13 +642,14 @@ void ConvertDepthwiseConvOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(conv);
+ return tensorflow::Status::OK();
}
-void ConvertDepthToSpaceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDepthToSpaceOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "DepthToSpace");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
auto* op = new DepthToSpaceOperator;
@@ -606,28 +658,37 @@ void ConvertDepthToSpaceOperator(const NodeDef& node,
op->block_size = GetIntAttr(node, "block_size");
QCHECK_GE(op->block_size, 2);
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSpaceToDepthOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSpaceToDepthOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SpaceToDepth");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
- CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
+ tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
+ if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
+ dtype != DT_INT64) {
+ const auto* enum_descriptor = tensorflow::DataType_descriptor();
+ LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
+ << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
+ << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}.";
+ }
auto* op = new SpaceToDepthOperator;
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
op->block_size = GetIntAttr(node, "block_size");
QCHECK_GE(op->block_size, 2);
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBiasAddOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertBiasAddOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "BiasAdd");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto& input_name = node.input(0);
const auto& bias_name = node.input(1);
@@ -637,13 +698,14 @@ void ConvertBiasAddOperator(const NodeDef& node,
biasadd->inputs.push_back(bias_name);
biasadd->outputs.push_back(node.name());
model->operators.emplace_back(biasadd);
+ return tensorflow::Status::OK();
}
-void ConvertRandomUniform(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertRandomUniform(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "RandomUniform");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
auto op = absl::make_unique<RandomUniformOperator>();
@@ -654,86 +716,12 @@ void ConvertRandomUniform(const NodeDef& node,
op->seed2 = GetIntAttr(node, "seed2");
CHECK(model != nullptr);
model->operators.emplace_back(std::move(op));
+ return tensorflow::Status::OK();
}
-void ConvertReluOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Relu");
- CheckInputsCount(node, tf_import_flags, 1);
- const auto& input_name = node.input(0);
- auto* relu = new ReluOperator;
- relu->inputs.push_back(input_name);
- relu->outputs.push_back(node.name());
- model->operators.emplace_back(relu);
-}
-
-void ConvertRelu6Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Relu6");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new Relu6Operator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertLogOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Log");
- CheckInputsCount(node, tf_import_flags, 1);
-
- auto op = absl::make_unique<LogOperator>();
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(std::move(op));
-}
-
-void ConvertLogisticOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sigmoid");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new LogisticOperator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertTanhOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Tanh");
- CheckInputsCount(node, tf_import_flags, 1);
-
- const auto& input_name = node.input(0);
- auto* op = new TanhOperator;
- op->inputs.push_back(input_name);
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertDivOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK(node.op() == "Div" || node.op() == "RealDiv");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new DivOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertIdentityOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertIdentityOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient");
auto* op = new TensorFlowIdentityOperator;
@@ -750,13 +738,14 @@ void ConvertIdentityOperator(const NodeDef& node,
op->inputs.push_back(input_name);
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFakeQuantWithMinMaxArgs(
+tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new FakeQuantOperator;
op->inputs.push_back(node.input(0));
op->minmax.reset(new MinMax);
@@ -767,9 +756,10 @@ void ConvertFakeQuantWithMinMaxArgs(
// tf.fake_quant_with_min_max_args num_bits defaults to 8.
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFakeQuantWithMinMaxVars(
+tensorflow::Status ConvertFakeQuantWithMinMaxVars(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
@@ -785,46 +775,14 @@ void ConvertFakeQuantWithMinMaxVars(
op->outputs.push_back(node.name());
op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertNegOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Neg");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new NegOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertRsqrtOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Rsqrt");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowRsqrtOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSqrtOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sqrt");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowSqrtOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSqueezeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSqueezeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Squeeze");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
auto* op = new SqueezeOperator;
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
@@ -838,73 +796,14 @@ void ConvertSqueezeOperator(const NodeDef& node,
}
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSquareOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Square");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowSquareOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAddOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Add");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new AddOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAddNOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "AddN");
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- auto* op = new AddNOperator;
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Mul");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new MulOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSubOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sub");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new SubOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSumOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Sum");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSumOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -913,74 +812,14 @@ void ConvertSumOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertTileOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Tile");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowTileOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSliceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Slice");
- CheckInputsCount(node, tf_import_flags, 3);
- auto* op = new SliceOperator;
- for (int i = 0; i < 3; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertPadOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Pad");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new PadOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertPadV2Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "PadV2");
- CheckInputsCount(node, tf_import_flags, 3);
- auto* op = new PadV2Operator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->inputs.push_back(node.input(2));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertShapeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Shape");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new TensorFlowShapeOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSplitOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSplitOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Split");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSplitOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -991,25 +830,14 @@ void ConvertSplitOperator(const NodeDef& node,
}
op->num_split = num_split;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertMergeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Merge");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMergeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSwitchOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSwitchOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Switch");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowSwitchOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1017,13 +845,14 @@ void ConvertSwitchOperator(const NodeDef& node,
// Switch operators have two outputs: "name" and "name:1".
op->outputs.push_back(node.name() + ":1");
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSoftmaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSoftmaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Softmax");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
auto* softmax = new SoftmaxOperator;
softmax->inputs.push_back(input_name);
@@ -1032,25 +861,14 @@ void ConvertSoftmaxOperator(const NodeDef& node,
CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
softmax->beta = 1.f;
model->operators.emplace_back(softmax);
+ return tensorflow::Status::OK();
}
-void ConvertLogSoftmaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "LogSoftmax");
- CheckInputsCount(node, tf_import_flags, 1);
- const auto& input_name = node.input(0);
- auto* log_softmax = new LogSoftmaxOperator;
- log_softmax->inputs.push_back(input_name);
- log_softmax->outputs.push_back(node.name());
- model->operators.emplace_back(log_softmax);
-}
-
-void ConvertLRNOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertLRNOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "LRN");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
auto* lrn = new LocalResponseNormalizationOperator;
lrn->inputs.push_back(input_name);
@@ -1060,13 +878,14 @@ void ConvertLRNOperator(const NodeDef& node,
lrn->alpha = GetFloatAttr(node, "alpha");
lrn->beta = GetFloatAttr(node, "beta");
model->operators.emplace_back(lrn);
+ return tensorflow::Status::OK();
}
-void ConvertMaxPoolOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMaxPoolOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "MaxPool");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
@@ -1102,13 +921,14 @@ void ConvertMaxPoolOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(maxpool);
+ return tensorflow::Status::OK();
}
-void ConvertAvgPoolOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertAvgPoolOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "AvgPool");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto& input_name = node.input(0);
// We only support NHWC, which is the default data_format.
// So if data_format is not defined, we're all good.
@@ -1140,24 +960,13 @@ void ConvertAvgPoolOperator(const NodeDef& node,
LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
}
model->operators.emplace_back(avgpool);
+ return tensorflow::Status::OK();
}
-void ConvertReshapeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Reshape");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowReshapeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertBatchMatMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CheckInputsCount(node, tf_import_flags, 2);
+tensorflow::Status ConvertBatchMatMulOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
// https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
CHECK(!HasAttr(node, "adj_a") || (GetBoolAttr(node, "adj_a") == false));
@@ -1167,33 +976,36 @@ void ConvertBatchMatMulOperator(const NodeDef& node,
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
model->operators.emplace_back(batch_matmul);
+ return tensorflow::Status::OK();
}
-void ConvertMatMulOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CheckInputsCount(node, tf_import_flags, 2);
+tensorflow::Status ConvertMatMulOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- // Transpose flags should be easy to support, but we don't have a
- // GraphDef with them to test on at the moment.
- CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"),
- false);
- CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"),
- false);
CHECK(!HasAttr(node, "adjoint_a") ||
(GetBoolAttr(node, "adjoint_a") == false));
CHECK(!HasAttr(node, "adjoint_b") ||
(GetBoolAttr(node, "adjoint_b") == false));
auto* matmul = new TensorFlowMatMulOperator;
+ if (HasAttr(node, "transpose_a")) {
+ matmul->transpose_a = GetBoolAttr(node, "transpose_a");
+ }
+ if (HasAttr(node, "transpose_b")) {
+ matmul->transpose_b = GetBoolAttr(node, "transpose_b");
+ }
+
matmul->inputs = {node.input(0), node.input(1)};
matmul->outputs = {node.name()};
model->operators.emplace_back(matmul);
+ return tensorflow::Status::OK();
}
-void ConvertConcatOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertConcatOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
Operator* op = nullptr;
if (node.op() == "Concat") {
op = new TensorFlowConcatOperator;
@@ -1213,104 +1025,38 @@ void ConvertConcatOperator(const NodeDef& node,
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertAllOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "All");
- auto* op = new TensorFlowAllOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertAssertOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Assert");
- auto* op = new TensorFlowAssertOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertLessOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Less");
- auto* op = new TensorFlowLessOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertLessEqualOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "LessEqual");
- auto* op = new TensorFlowLessEqualOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertSinOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Sin");
- auto* op = new SinOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertGreaterOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Greater");
- auto* op = new TensorFlowGreaterOperator;
+// This method supports simple operators without additional attributes.
+template <typename Op>
+tensorflow::Status ConvertSimpleOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ auto* op = new Op;
const int num_inputs = GetInputsCount(node, tf_import_flags);
for (int i = 0; i < num_inputs; ++i) {
op->inputs.push_back(node.input(i));
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertGreaterEqualOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "GreaterEqual");
- auto* op = new TensorFlowGreaterEqualOperator;
- const int num_inputs = GetInputsCount(node, tf_import_flags);
- for (int i = 0; i < num_inputs; ++i) {
- op->inputs.push_back(node.input(i));
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
+// This method supports simple operators without additional attributes.
+template <typename Op, unsigned int NumInputs>
+tensorflow::Status ConvertSimpleOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
+ return ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
-void ConvertMaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Max");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowMaxOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1319,13 +1065,14 @@ void ConvertMaxOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertMinOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMinOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Min");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new TensorFlowMinOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1334,35 +1081,12 @@ void ConvertMinOperator(const NodeDef& node,
if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertMaximumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Maximum");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMaximumOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertMinimumOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Minimum");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TensorFlowMinimumOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertUnsupportedOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertUnsupportedOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
LOG(INFO) << "Converting unsupported operation: " << node.op();
auto* op = new TensorFlowUnsupportedOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1385,28 +1109,16 @@ void ConvertUnsupportedOperator(const NodeDef& node,
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
}
+ return tensorflow::Status::OK();
}
-void ConvertSelectOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CheckInputsCount(node, tf_import_flags, 3);
-
- auto* op = new SelectOperator;
- for (const auto& input : node.input()) {
- op->inputs.push_back(input);
- }
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertStridedSliceOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertStridedSliceOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "StridedSlice");
// TODO(soroosh): The 4th input (strides) should be e optional, to be
// consistent with TF.
- CheckInputsCount(node, tf_import_flags, 4);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
auto* op = new StridedSliceOperator;
for (const auto& input : node.input()) {
@@ -1426,14 +1138,15 @@ void ConvertStridedSliceOperator(const NodeDef& node,
: 0;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertPlaceholderOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertPlaceholderOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
if (node.op() == "Placeholder") {
- CheckInputsCount(node, tf_import_flags, 0);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
}
auto& array = model->GetOrCreateArray(node.name());
if (node.attr().count("dtype")) {
@@ -1458,17 +1171,20 @@ void ConvertPlaceholderOperator(const NodeDef& node,
}
}
}
+ return tensorflow::Status::OK();
}
-void ConvertNoOpOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {}
+tensorflow::Status ConvertNoOpOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ return tensorflow::Status::OK();
+}
-void ConvertCastOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertCastOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Cast");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
auto* op = new CastOperator;
@@ -1477,27 +1193,31 @@ void ConvertCastOperator(const NodeDef& node,
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFloorOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertFloorOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Floor");
- CheckInputsCount(node, tf_import_flags, 1);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
const auto data_type = GetDataTypeAttr(node, "T");
CHECK(data_type == DT_FLOAT);
auto* op = new FloorOperator;
op->inputs.push_back(node.input(0));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertGatherOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertGatherOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK(node.op() == "Gather" || node.op() == "GatherV2");
- if (node.op() == "Gather") CheckInputsCount(node, tf_import_flags, 2);
- if (node.op() == "GatherV2") CheckInputsCount(node, tf_import_flags, 3);
+ if (node.op() == "Gather")
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+ if (node.op() == "GatherV2")
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
auto* op = new GatherOperator;
@@ -1507,13 +1227,15 @@ void ConvertGatherOperator(const NodeDef& node,
// should read it an pass it on to the TF Lite Interpreter.
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertArgMaxOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "ArgMax");
- CheckInputsCount(node, tf_import_flags, 2);
+template <typename Op, const char* op_name>
+tensorflow::Status ConvertArgMinMaxOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), op_name);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
const auto axis_data_type =
HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
const auto output_type = HasAttr(node, "output_type")
@@ -1521,19 +1243,20 @@ void ConvertArgMaxOperator(const NodeDef& node,
: DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
- auto* op = new ArgMaxOperator;
+ auto* op = new Op;
op->output_data_type = ConvertDataType(output_type);
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertResizeBilinearOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertResizeBilinearOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "ResizeBilinear");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new ResizeBilinearOperator;
op->align_corners = false;
@@ -1545,13 +1268,14 @@ void ConvertResizeBilinearOperator(const NodeDef& node,
op->inputs.push_back(node.input(1));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBatchNormWithGlobalNormalizationOperator(
+tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
- CheckInputsCount(node, tf_import_flags, 5);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
// TODO(ahentz): to really match tensorflow we need to add variance_epsilon
// to the input, before feeding it into TensorFlowRsqrtOperator.
@@ -1594,13 +1318,14 @@ void ConvertBatchNormWithGlobalNormalizationOperator(
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertFusedBatchNormOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertFusedBatchNormOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "FusedBatchNorm");
- CheckInputsCount(node, tf_import_flags, 5);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
// Declare shortcuts for the inputs.
const string& gamma_input = node.input(1);
@@ -1646,13 +1371,14 @@ void ConvertFusedBatchNormOperator(const NodeDef& node,
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertSpaceToBatchNDOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSpaceToBatchNDOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SpaceToBatchND");
- CheckInputsCount(node, tf_import_flags, 3);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
auto* op = new SpaceToBatchNDOperator;
@@ -1661,13 +1387,14 @@ void ConvertSpaceToBatchNDOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertBatchToSpaceNDOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertBatchToSpaceNDOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "BatchToSpaceND");
- CheckInputsCount(node, tf_import_flags, 3);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
auto* op = new BatchToSpaceNDOperator;
@@ -1676,24 +1403,14 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertExpOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Exp");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new ExpOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertMeanOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertMeanOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Mean");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
auto* op = new MeanOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1704,11 +1421,12 @@ void ConvertMeanOperator(const NodeDef& node,
} else if (HasAttr(node, "keep_dims")) {
op->keep_dims = GetBoolAttr(node, "keep_dims");
}
+ return tensorflow::Status::OK();
}
-void ConvertSvdfOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSvdfOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Svdf");
const int input_size = GetInputsCount(node, tf_import_flags);
QCHECK(input_size == 3 || input_size == 4)
@@ -1731,14 +1449,15 @@ void ConvertSvdfOperator(const NodeDef& node,
}
op->rank = node.attr().at("Rank").i();
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
// This is just bare bones support to get the shapes to propagate.
-void ConvertTransposeConvOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertTransposeConvOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Conv2DBackpropInput");
- CheckInputsCount(node, tf_import_flags, 3);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new TransposeConvOperator;
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
@@ -1779,11 +1498,13 @@ void ConvertTransposeConvOperator(const NodeDef& node,
if (existing_transpose) {
CHECK(existing_transpose->type == OperatorType::kTranspose);
} else {
- // Transpose weights from HWIO order to OHWI order, which is more efficient
- // for computation
+ // Transpose weights from HWOI order to OHWI order, which is more efficient
+ // for computation. (Note that TensorFlow considers the order as HWIO
+ // because they consider this a backward conv, inverting the sense of
+ // input/output.)
TransposeOperator* transpose = new TransposeOperator;
string perm_array = CreateConstArray<ArrayDataType::kInt32>(
- model, node.name() + "_transpose_perm", {3, 0, 1, 2});
+ model, node.name() + "_transpose_perm", {2, 0, 1, 3});
transpose->inputs = {weights_name, perm_array};
transpose->outputs = {transposed_weights_name};
model->operators.emplace_back(transpose);
@@ -1800,61 +1521,14 @@ void ConvertTransposeConvOperator(const NodeDef& node,
"Conv2DBackpropInput nodes.";
}
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertExpandDimsOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "ExpandDims");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new ExpandDimsOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFillOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Fill");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FillOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFloorDivOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "FloorDiv");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FloorDivOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertFloorModOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "FloorMod");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new FloorModOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertRangeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertRangeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "Range");
- CheckInputsCount(node, tf_import_flags, 3);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
auto* op = new RangeOperator;
if (HasAttr(node, "Tidx")) {
const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
@@ -1867,22 +1541,12 @@ void ConvertRangeOperator(const NodeDef& node,
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
-void ConvertRankOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Rank");
- CheckInputsCount(node, tf_import_flags, 1);
- auto* op = new RankOperator;
- op->inputs.push_back(node.input(0));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
-}
-
-void ConvertStackOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertStackOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK((node.op() == "Stack") || (node.op() == "Pack"));
auto* op = new StackOperator;
const int num_inputs = GetInputsCount(node, tf_import_flags);
@@ -1898,18 +1562,7 @@ void ConvertStackOperator(const NodeDef& node,
op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
-}
-
-void ConvertTransposeOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Transpose");
- CheckInputsCount(node, tf_import_flags, 2);
- auto* op = new TransposeOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
// Some TensorFlow ops only occur in graph cycles, representing
@@ -1922,7 +1575,7 @@ void ConvertTransposeOperator(const NodeDef& node,
// such ops as RNN back-edges, which is technically incorrect (does not
// allow representing the op's semantics) but good enough to get a
// graph visualization.
-void ConvertOperatorSpecialCasedAsRNNBackEdge(
+tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
// At the moment, the only type of operator special-cased in this way is
@@ -1935,6 +1588,23 @@ void ConvertOperatorSpecialCasedAsRNNBackEdge(
rnn_state->set_discardable(true);
rnn_state->set_state_array(node.name());
rnn_state->set_back_edge_source_array(node.input(0));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ConvertShapeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Shape");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
+ const auto out_type =
+ HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
+ CHECK(out_type == DT_INT64 || out_type == DT_INT32);
+ auto op = absl::make_unique<TensorFlowShapeOperator>();
+ op->output_data_type = ConvertDataType(out_type);
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.push_back(std::move(op));
+ return tensorflow::Status::OK();
}
void StripCaretFromArrayNames(Model* model) {
@@ -2077,9 +1747,9 @@ bool InlineAllFunctions(GraphDef* graphdef) {
return graph_modified;
}
-void ConvertTopKV2Operator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertTopKV2Operator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
auto op = absl::make_unique<TopKV2Operator>();
op->inputs.push_back(node.input(0));
@@ -2089,22 +1759,23 @@ void ConvertTopKV2Operator(const NodeDef& node,
model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
op->inputs.push_back(k_array);
} else {
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
op->inputs.push_back(node.input(1));
}
// The op has two outputs.
op->outputs.push_back(node.name());
op->outputs.push_back(node.name() + ":1");
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertDynamicPartitionOperator(
+tensorflow::Status ConvertDynamicPartitionOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
auto op = absl::make_unique<DynamicPartitionOperator>();
CHECK(HasAttr(node, "num_partitions"));
op->num_partitions = GetIntAttr(node, "num_partitions");
- CheckInputsCount(node, tf_import_flags, 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
op->inputs.push_back(node.input(0));
op->inputs.push_back(node.input(1));
CHECK_GT(op->num_partitions, 1);
@@ -2113,11 +1784,12 @@ void ConvertDynamicPartitionOperator(
op->outputs.push_back(node.name() + ":" + std::to_string(i));
}
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertDynamicStitchOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertDynamicStitchOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
// The parallel and non-parallel variants are the same besides whether they
// have a parallel loop; there are no behavioral differences.
CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
@@ -2125,19 +1797,20 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
CHECK(HasAttr(node, "N"));
op->num_partitions = GetIntAttr(node, "N");
// Expect all ID partitions + all value partitions.
- CheckInputsCount(node, tf_import_flags, op->num_partitions * 2);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
for (int i = 0; i < op->num_partitions * 2; ++i) {
op->inputs.push_back(node.input(i));
}
op->outputs.push_back(node.name());
model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
}
-void ConvertSparseToDenseOperator(const NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
+tensorflow::Status ConvertSparseToDenseOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
CHECK_EQ(node.op(), "SparseToDense");
- CheckInputsCount(node, tf_import_flags, 4);
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
auto* op = new SparseToDenseOperator;
for (const string& input : node.input()) {
@@ -2149,195 +1822,137 @@ void ConvertSparseToDenseOperator(const NodeDef& node,
? GetBoolAttr(node, "validate_indices")
: true;
model->operators.emplace_back(op);
+ return tensorflow::Status::OK();
}
} // namespace
namespace internal {
-Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
- const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- // TODO(ahentz): Historically these functions all CHECK-fail on error. We've
- // been slowly converting them to return Status.
- if (node.op() == "Const") {
- return ConvertConstOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2D") {
- ConvertConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "Conv2DBackpropInput") {
- ConvertTransposeConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthwiseConv2dNative") {
- ConvertDepthwiseConvOperator(node, tf_import_flags, model);
- } else if (node.op() == "DepthToSpace") {
- ConvertDepthToSpaceOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToDepth") {
- ConvertSpaceToDepthOperator(node, tf_import_flags, model);
- } else if (node.op() == "BiasAdd") {
- ConvertBiasAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu") {
- ConvertReluOperator(node, tf_import_flags, model);
- } else if (node.op() == "Relu6") {
- ConvertRelu6Operator(node, tf_import_flags, model);
- } else if (node.op() == "Sigmoid") {
- ConvertLogisticOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tanh") {
- ConvertTanhOperator(node, tf_import_flags, model);
- } else if (node.op() == "MaxPool") {
- ConvertMaxPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "AvgPool") {
- ConvertAvgPoolOperator(node, tf_import_flags, model);
- } else if (node.op() == "Reshape") {
- ConvertReshapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchMatMul") {
- ConvertBatchMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "MatMul") {
- ConvertMatMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Div" || node.op() == "RealDiv") {
- ConvertDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "Identity" || node.op() == "CheckNumerics" ||
- node.op() == "StopGradient") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxVars") {
- ConvertFakeQuantWithMinMaxVars(node, tf_import_flags, model);
- } else if (node.op() == "FakeQuantWithMinMaxArgs") {
- ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model);
- } else if (node.op() == "Neg") {
- ConvertNegOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rsqrt") {
- ConvertRsqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Squeeze") {
- ConvertSqueezeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sqrt") {
- ConvertSqrtOperator(node, tf_import_flags, model);
- } else if (node.op() == "Square") {
- ConvertSquareOperator(node, tf_import_flags, model);
- } else if (node.op() == "Add") {
- ConvertAddOperator(node, tf_import_flags, model);
- } else if (node.op() == "AddN") {
- ConvertAddNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mul") {
- ConvertMulOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sub") {
- ConvertSubOperator(node, tf_import_flags, model);
- } else if (node.op() == "Sum") {
- ConvertSumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Tile") {
- ConvertTileOperator(node, tf_import_flags, model);
- } else if (node.op() == "Concat" || node.op() == "ConcatV2") {
- ConvertConcatOperator(node, tf_import_flags, model);
- } else if (node.op() == "LRN") {
- ConvertLRNOperator(node, tf_import_flags, model);
- } else if (node.op() == "Softmax") {
- ConvertSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Log") {
- ConvertLogOperator(node, tf_import_flags, model);
- } else if (node.op() == "LogSoftmax") {
- ConvertLogSoftmaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "All") {
- ConvertAllOperator(node, tf_import_flags, model);
- } else if (node.op() == "Assert") {
- ConvertAssertOperator(node, tf_import_flags, model);
- } else if (node.op() == "Less") {
- ConvertLessOperator(node, tf_import_flags, model);
- } else if (node.op() == "LessEqual") {
- ConvertLessEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Greater") {
- ConvertGreaterOperator(node, tf_import_flags, model);
- } else if (node.op() == "GreaterEqual") {
- ConvertGreaterEqualOperator(node, tf_import_flags, model);
- } else if (node.op() == "Max") {
- ConvertMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Min") {
- ConvertMinOperator(node, tf_import_flags, model);
- } else if (node.op() == "Maximum") {
- ConvertMaximumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Minimum") {
- ConvertMinimumOperator(node, tf_import_flags, model);
- } else if (node.op() == "Merge") {
- ConvertMergeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Pad") {
- ConvertPadOperator(node, tf_import_flags, model);
- } else if (node.op() == "PadV2") {
- ConvertPadV2Operator(node, tf_import_flags, model);
- } else if (node.op() == "StridedSlice") {
- ConvertStridedSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Shape") {
- ConvertShapeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Slice") {
- ConvertSliceOperator(node, tf_import_flags, model);
- } else if (node.op() == "Split") {
- ConvertSplitOperator(node, tf_import_flags, model);
- } else if (node.op() == "Switch") {
- ConvertSwitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "Placeholder") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "PlaceholderWithDefault") {
- ConvertIdentityOperator(node, tf_import_flags, model);
- } else if (node.op() == "LegacyFedInput") {
- ConvertPlaceholderOperator(node, tf_import_flags, model);
- } else if (node.op() == "NoOp") {
- ConvertNoOpOperator(node, tf_import_flags, model);
- } else if (node.op() == "Cast") {
- ConvertCastOperator(node, tf_import_flags, model);
- } else if (node.op() == "Floor") {
- ConvertFloorOperator(node, tf_import_flags, model);
- } else if (node.op() == "Gather" || node.op() == "GatherV2") {
- ConvertGatherOperator(node, tf_import_flags, model);
- } else if (node.op() == "ResizeBilinear") {
- ConvertResizeBilinearOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchNormWithGlobalNormalization") {
- ConvertBatchNormWithGlobalNormalizationOperator(node, tf_import_flags,
- model);
- } else if (node.op() == "FusedBatchNorm") {
- ConvertFusedBatchNormOperator(node, tf_import_flags, model);
- } else if (node.op() == "SpaceToBatchND") {
- ConvertSpaceToBatchNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "BatchToSpaceND") {
- ConvertBatchToSpaceNDOperator(node, tf_import_flags, model);
- } else if (node.op() == "Mean") {
- ConvertMeanOperator(node, tf_import_flags, model);
- } else if (node.op() == "Svdf") {
- ConvertSvdfOperator(node, tf_import_flags, model);
- } else if (node.op() == "NextIteration") {
- ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model);
- } else if (node.op() == "ExpandDims") {
- ConvertExpandDimsOperator(node, tf_import_flags, model);
- } else if (node.op() == "Fill") {
- ConvertFillOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorDiv") {
- ConvertFloorDivOperator(node, tf_import_flags, model);
- } else if (node.op() == "FloorMod") {
- ConvertFloorModOperator(node, tf_import_flags, model);
- } else if (node.op() == "Range") {
- ConvertRangeOperator(node, tf_import_flags, model);
- } else if (node.op() == "Rank") {
- ConvertRankOperator(node, tf_import_flags, model);
- } else if (node.op() == "Stack" || node.op() == "Pack") {
- ConvertStackOperator(node, tf_import_flags, model);
- } else if (node.op() == "Transpose") {
- ConvertTransposeOperator(node, tf_import_flags, model);
- } else if (node.op() == "ArgMax") {
- ConvertArgMaxOperator(node, tf_import_flags, model);
- } else if (node.op() == "Exp") {
- ConvertExpOperator(node, tf_import_flags, model);
- } else if (node.op() == "TopK" || node.op() == "TopKV2") {
- ConvertTopKV2Operator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicPartition") {
- ConvertDynamicPartitionOperator(node, tf_import_flags, model);
- } else if (node.op() == "DynamicStitch" ||
- node.op() == "ParallelDynamicStitch") {
- ConvertDynamicStitchOperator(node, tf_import_flags, model);
- } else if (node.op() == "RandomUniform") {
- ConvertRandomUniform(node, tf_import_flags, model);
- } else if (node.op() == "Sin") {
- ConvertSinOperator(node, tf_import_flags, model);
- } else if (node.op() == "Select") {
- ConvertSelectOperator(node, tf_import_flags, model);
- } else if (node.op() == "SparseToDense") {
- ConvertSparseToDenseOperator(node, tf_import_flags, model);
+
+using ConverterType = tensorflow::Status (*)(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model);
+using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+
+constexpr char kArgMax[] = "ArgMax";
+constexpr char kArgMin[] = "ArgMin";
+
+ConverterMapType GetTensorFlowNodeConverterMap() {
+ return std::unordered_map<std::string, ConverterType>({
+ {"Add", ConvertSimpleOperator<AddOperator, 2>},
+ {"AddN", ConvertSimpleOperator<AddNOperator>},
+ {"All", ConvertSimpleOperator<TensorFlowAllOperator>},
+ {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
+ {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
+ {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
+ {"AvgPool", ConvertAvgPoolOperator},
+ {"BatchMatMul", ConvertBatchMatMulOperator},
+ {"BatchNormWithGlobalNormalization",
+ ConvertBatchNormWithGlobalNormalizationOperator},
+ {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
+ {"BiasAdd", ConvertBiasAddOperator},
+ {"Cast", ConvertCastOperator},
+ {"CheckNumerics", ConvertIdentityOperator},
+ {"Concat", ConvertConcatOperator},
+ {"ConcatV2", ConvertConcatOperator},
+ {"Const", ConvertConstOperator},
+ {"Conv2D", ConvertConvOperator},
+ {"Conv2DBackpropInput", ConvertTransposeConvOperator},
+ {"DepthToSpace", ConvertDepthToSpaceOperator},
+ {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
+ {"Div", ConvertSimpleOperator<DivOperator, 2>},
+ {"DynamicPartition", ConvertDynamicPartitionOperator},
+ {"DynamicStitch", ConvertDynamicStitchOperator},
+ {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2>},
+ {"Exp", ConvertSimpleOperator<ExpOperator, 1>},
+ {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2>},
+ {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
+ {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
+ {"Fill", ConvertSimpleOperator<FillOperator, 2>},
+ {"Floor", ConvertFloorOperator},
+ {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2>},
+ {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2>},
+ {"FusedBatchNorm", ConvertFusedBatchNormOperator},
+ {"Gather", ConvertGatherOperator},
+ {"GatherV2", ConvertGatherOperator},
+ {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2>},
+ {"GreaterEqual",
+ ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2>},
+ {"Identity", ConvertIdentityOperator},
+ {"LRN", ConvertLRNOperator},
+ {"LegacyFedInput", ConvertPlaceholderOperator},
+ {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2>},
+ {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
+ {"Log", ConvertSimpleOperator<LogOperator, 1>},
+ {"Log", ConvertSimpleOperator<LogOperator, 1>},
+ {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
+ {"MatMul", ConvertMatMulOperator},
+ {"Max", ConvertMaxOperator},
+ {"MaxPool", ConvertMaxPoolOperator},
+ {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
+ {"Mean", ConvertMeanOperator},
+ {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
+ {"Min", ConvertMinOperator},
+ {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
+ {"Mul", ConvertSimpleOperator<MulOperator, 2>},
+ {"Neg", ConvertSimpleOperator<NegOperator, 1>},
+ {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
+ {"NoOp", ConvertNoOpOperator},
+ {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
+ {"Pack", ConvertStackOperator},
+ {"Pad", ConvertSimpleOperator<PadOperator, 2>},
+ {"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
+ {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
+ {"Placeholder", ConvertPlaceholderOperator},
+ {"PlaceholderWithDefault", ConvertIdentityOperator},
+ {"Pow", ConvertSimpleOperator<PowOperator, 2>},
+ {"RandomUniform", ConvertRandomUniform},
+ {"Range", ConvertRangeOperator},
+ {"Rank", ConvertSimpleOperator<RankOperator, 1>},
+ {"RealDiv", ConvertSimpleOperator<DivOperator, 2>},
+ {"Relu", ConvertSimpleOperator<ReluOperator, 1>},
+ {"Relu6", ConvertSimpleOperator<Relu6Operator, 1>},
+ {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2>},
+ {"ResizeBilinear", ConvertResizeBilinearOperator},
+ {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>},
+ {"Select", ConvertSimpleOperator<SelectOperator, 3>},
+ {"Shape", ConvertShapeOperator},
+ {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>},
+ {"Sin", ConvertSimpleOperator<SinOperator, 1>},
+ {"Slice", ConvertSimpleOperator<SliceOperator, 3>},
+ {"Softmax", ConvertSoftmaxOperator},
+ {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
+ {"SpaceToDepth", ConvertSpaceToDepthOperator},
+ {"SparseToDense", ConvertSparseToDenseOperator},
+ {"Split", ConvertSplitOperator},
+ {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1>},
+ {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1>},
+ {"Squeeze", ConvertSqueezeOperator},
+ {"Stack", ConvertStackOperator},
+ {"StopGradient", ConvertIdentityOperator},
+ {"StridedSlice", ConvertStridedSliceOperator},
+ {"Sub", ConvertSimpleOperator<SubOperator, 2>},
+ {"Sum", ConvertSumOperator},
+ {"Svdf", ConvertSvdfOperator},
+ {"Switch", ConvertSwitchOperator},
+ {"Tanh", ConvertSimpleOperator<TanhOperator, 1>},
+ {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2>},
+ {"TopK", ConvertTopKV2Operator},
+ {"TopKV2", ConvertTopKV2Operator},
+ {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
+ });
+}
+
+tensorflow::Status ImportTensorFlowNode(
+ const tensorflow::NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags, Model* model,
+ const ConverterMapType& converter_map) {
+ auto converter = converter_map.find(node.op());
+ if (converter == converter_map.end()) {
+ return ConvertUnsupportedOperator(node, tf_import_flags, model);
} else {
- ConvertUnsupportedOperator(node, tf_import_flags, model);
+ return converter->second(node, tf_import_flags, model);
}
- return Status::OK();
}
} // namespace internal
@@ -2363,10 +1978,13 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
+ const internal::ConverterMapType& converter_map =
+ internal::GetTensorFlowNodeConverterMap();
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
- auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model);
+ auto status = internal::ImportTensorFlowNode(node, tf_import_flags, model,
+ converter_map);
CHECK(status.ok()) << status.error_message();
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 835676662b..90e6f698ef 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -21,10 +21,10 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status.h"
namespace toco {
-using port::Status;
using tensorflow::AttrValue;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
@@ -33,10 +33,17 @@ using tensorflow::DT_INT64;
using tensorflow::DT_QUINT8;
using tensorflow::DT_STRING;
using tensorflow::NodeDef;
+using tensorflow::Status;
namespace internal {
+using ConverterType = tensorflow::Status (*)(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model);
+using ConverterMapType = std::unordered_map<std::string, ConverterType>;
+
+ConverterMapType GetTensorFlowNodeConverterMap();
Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
- Model*);
+ Model*, const ConverterMapType&);
} // namespace internal
namespace {
@@ -104,8 +111,9 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
Status ImportNode(const NodeDef& node) {
Model model;
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(),
- &model);
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
+ converter);
}
};
@@ -117,9 +125,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) {
NodeDef node;
BuildConstNode({1, -2, 10}, GetParam(), 0, &node);
auto status = ImportNode(node);
- EXPECT_EQ(status.error_message(),
- "Tensor shape should not include negative values (while processing "
- "node 'Node1')");
+ EXPECT_EQ(
+ status.error_message(),
+ "Tensor shape should not include negative values\n\t (while processing "
+ "node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -129,7 +138,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) {
BuildConstNode({3000000000}, GetParam(), 0, &node);
auto status = ImportNode(node);
EXPECT_EQ(status.error_message(),
- "Shape element overflows (while processing node 'Node1')");
+ "Shape element overflows\n\t (while processing node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -139,7 +148,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) {
BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node);
auto status = ImportNode(node);
EXPECT_EQ(status.error_message(),
- "Tensor shape is too large (while processing node 'Node1')");
+ "Tensor shape is too large\n\t (while processing node 'Node1')");
}
INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
@@ -148,11 +157,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
NodeDef node;
BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node);
auto status = ImportNode(node);
- EXPECT_THAT(
- status.error_message(),
- ::testing::MatchesRegex(
- "Neither input_content .0. nor .*_val .0. have the right "
- "dimensions .8. for this .* tensor .while processing node 'Node1'."));
+ EXPECT_THAT(status.error_message(),
+ ::testing::MatchesRegex(
+ "Neither input_content .0. nor .*_val .0. have the right "
+ "dimensions .8. for this .* tensor\n\t .while processing "
+ "node 'Node1'."));
}
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 9062c03c73..8660464fdb 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
+#include <complex>
#include <functional>
#include <initializer_list>
#include <memory>
@@ -32,7 +33,7 @@ namespace toco {
using tflite::QuantizationParams;
-enum class OperatorType {
+enum class OperatorType : uint8 {
kNone,
// General-purpose neural network operators.
kAdd,
@@ -96,38 +97,38 @@ enum class OperatorType {
// Special operators used for importing TensorFlow nodes.
// The general intent is to have some graph transformation either
// drop them or rewrite them as general-purpose operators.
- kTensorFlowAll,
- kTensorFlowAssert,
- kTensorFlowConcat,
- kTensorFlowConcatV2,
- kTensorFlowGreater,
- kTensorFlowGreaterEqual,
- kTensorFlowIdentity,
- kTensorFlowLess,
- kTensorFlowLessEqual,
- kTensorFlowMax,
- kTensorFlowMaximum,
- kTensorFlowMin,
- kTensorFlowMinimum,
- kTensorFlowMatMul,
- kTensorFlowMerge,
+ kAll,
+ kAssert,
+ kConcat,
+ kConcatV2,
+ kGreater,
+ kGreaterEqual,
+ kIdentity,
+ kLess,
+ kLessEqual,
+ kMax, // Reduction Max
+ kMaximum, // Element-wise Maximum
+ kMin, // Reduction Min
+ kMinimum, // Element-wise Minimum
+ kMatMul,
+ kMerge,
kNeg,
- kTensorFlowReshape,
- kTensorFlowRsqrt,
- kTensorFlowShape,
- kTensorFlowSplit,
- kTensorFlowSqrt,
- kTensorFlowSquare,
- kTensorFlowSum,
- kTensorFlowSwitch,
- kTensorFlowTile,
+ kReshape,
+ kRsqrt,
+ kShape,
+ kSplit,
+ kSqrt,
+ kSquare,
+ kSum,
+ kSwitch,
+ kTile,
kTranspose,
kTopK_V2,
kDynamicPartition,
kDynamicStitch,
// An unsupported TF operation. It's only needed to be able to represent TF
// graph internally and is expected to be dropped by graph transformations.
- kTensorFlowUnsupported,
+ kUnsupported,
// Finally, TensorFlow uses different conventions for axes ordering,
// see AxesOrder, and this cannot always be resolved at the time of importing
// nodes, as TensorFlow parameters may be constant-expression subgraphs
@@ -136,6 +137,10 @@ enum class OperatorType {
kReorderAxes,
kSelect,
kSparseToDense,
+ kEqual,
+ kNotEqual,
+ kPow,
+ kArgMin,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -153,25 +158,27 @@ enum class AxesOrder {
k1HWO, // Our standard for DepthwiseConv weights
kHWIM, // TensorFlow DepthwiseConv weights
kNHWC, // TensorFlow activations
+ kHWOI, // TensorFlow back-prop conv weights
};
// The type of the scalars in an array.
// Note that the type does not by itself tell whether the values in the array
-// are real (are literally interpreted as real numbers) or quantized (only
-// acquire a meaning as real numbers in conjunction with QuantizationParams).
+// are non-quantized (can be accessed directly) or quantized (must be
+// interpreted in conjunction with QuantizationParams).
//
// In practice though:
-// float values are always real
+// float values are never quantized
// uint8 values are always quantized
-// int32 values are either real or quantized (depending on whether
+// int32 values are sometimes quantized (depending on whether
// QuantizationParams are present).
-// other types are unused at the moment.
+// complex values are never quantized
+// other types are never quantized at the moment.
//
// kNone means that we don't know the data type yet, or that we don't care
// because we'll be dropping the array anyway (e.g. some exotic array types
// may be involved only in debug-only subgraphs that we may not be interested
// in actually supporting).
-enum class ArrayDataType {
+enum class ArrayDataType : uint8 {
kNone, // 0
kBool,
kFloat,
@@ -183,7 +190,8 @@ enum class ArrayDataType {
kUint32,
kInt64,
kUint64, // 10
- kString
+ kString,
+ kComplex64,
};
// Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
@@ -237,6 +245,10 @@ template <>
struct DataTypeImpl<ArrayDataType::kString> {
typedef string Type;
};
+template <>
+struct DataTypeImpl<ArrayDataType::kComplex64> {
+ typedef std::complex<float> Type;
+};
template <ArrayDataType A>
using DataType = typename DataTypeImpl<A>::Type;
@@ -430,7 +442,8 @@ struct SpaceToDepthOperator : Operator {
// input activations as a matrix, followed by a MatMul node.
struct FullyConnectedOperator : Operator {
FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
- bool experimental_shuffled_weights = false;
+ FullyConnectedWeightsFormat weights_format =
+ FullyConnectedWeightsFormat::kDefault;
};
// Dequantization operator, converting a quantized array of integers with
@@ -527,7 +540,15 @@ struct LstmCellOperator : Operator {
ACTIV_TEMP = 3,
NUM_OUTPUTS = 4
};
- LstmCellOperator() : Operator(OperatorType::kLstmCell) {}
+ enum KernelType {
+ KERNEL_BASIC = 0,
+ KERNEL_FULL = 1,
+ };
+
+ LstmCellOperator()
+ : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {}
+
+ KernelType kernel_type;
};
// Element-wise multiplication operator.
@@ -790,7 +811,7 @@ struct DivOperator : Operator {
//
// TensorFlow equivalent: Identity
struct TensorFlowIdentityOperator : Operator {
- TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {}
+ TensorFlowIdentityOperator() : Operator(OperatorType::kIdentity) {}
};
// Batch matrix multiplication operator. This comes from the (deprecated)
@@ -816,7 +837,9 @@ struct BatchMatMulOperator : Operator {
//
// TensorFlow equivalent: MatMul
struct TensorFlowMatMulOperator : Operator {
- TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {}
+ TensorFlowMatMulOperator() : Operator(OperatorType::kMatMul) {}
+ bool transpose_a = false;
+ bool transpose_b = false;
};
// Padding operator. Pads a tensor with zeros.
@@ -950,7 +973,7 @@ struct StridedSliceOperator : Operator {
// TensorFlow equivalent: Reshape --- except that we only support a special case
// here, where the output shape is a matrix (2D) shape.
struct TensorFlowReshapeOperator : Operator {
- TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {}
+ TensorFlowReshapeOperator() : Operator(OperatorType::kReshape) {}
std::vector<int> shape;
};
@@ -1120,7 +1143,7 @@ struct SelectOperator : Operator {
//
// TensorFlow equivalent: Rsqrt
struct TensorFlowRsqrtOperator : Operator {
- TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {}
+ TensorFlowRsqrtOperator() : Operator(OperatorType::kRsqrt) {}
};
// Stacks a list of rank-R tensors into one rank-(R+1) tensor.
@@ -1146,10 +1169,10 @@ struct StackOperator : Operator {
// This operation outputs a 1-D integer tensor representing the shape of
// the input.
//
-// TensorFlow equivalent: Shape. We currently assume that the output is int32
-// and not int64. The output type could be stored herein.
+// TensorFlow equivalent: Shape.
struct TensorFlowShapeOperator : Operator {
- TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {}
+ TensorFlowShapeOperator() : Operator(OperatorType::kShape) {}
+ ArrayDataType output_data_type = ArrayDataType::kInt32;
};
// Element-wise square-root (x^0.5) operator.
@@ -1159,7 +1182,7 @@ struct TensorFlowShapeOperator : Operator {
//
// TensorFlow equivalent: Sqrt
struct TensorFlowSqrtOperator : Operator {
- TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {}
+ TensorFlowSqrtOperator() : Operator(OperatorType::kSqrt) {}
};
// Element-wise square (x*x) operator.
@@ -1169,7 +1192,7 @@ struct TensorFlowSqrtOperator : Operator {
//
// TensorFlow equivalent: Square
struct TensorFlowSquareOperator : Operator {
- TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {}
+ TensorFlowSquareOperator() : Operator(OperatorType::kSquare) {}
};
// Transposes a tensor.
@@ -1197,24 +1220,24 @@ struct SubOperator : Operator {
SubOperator() : Operator(OperatorType::kSub) {}
};
-// Global sum reduction: computes the sum of all of entries in the input array.
-// Thus the output is "0-dimensional": it consists of a single scalar value.
+// Sum reduction: computes the sum of all of entries across the axes.
//
// Inputs:
// inputs[0]: required: the input array
//
-// TensorFlow equivalent: Sum --- except that we only support the special case
-// of global reduction across all dimensions.
+// TensorFlow equivalent: Sum
struct TensorFlowSumOperator : Operator {
- TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {}
+ TensorFlowSumOperator() : Operator(OperatorType::kSum) {}
bool keep_dims = false;
};
// TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
-// Not fully supported, just a placeholder to handle TensorFlow graphs and
-// support graph transformations to other operator types by matching sub-graphs.
+//
+// Inputs:
+// inputs[0]: required: the input array
+// inputs[1]: required: int array with length of rank(input[0])
struct TensorFlowTileOperator : Operator {
- TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
+ TensorFlowTileOperator() : Operator(OperatorType::kTile) {}
};
// TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
@@ -1229,7 +1252,7 @@ struct SliceOperator : Operator {
// Not fully supported, just a placeholder to handle TensorFlow graphs and
// support graph transformations to other operator types by matching sub-graphs.
struct TensorFlowSplitOperator : Operator {
- TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {}
+ TensorFlowSplitOperator() : Operator(OperatorType::kSplit) {}
int num_split = 0;
};
@@ -1240,7 +1263,7 @@ struct TensorFlowSplitOperator : Operator {
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatOperator : Operator {
- TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {}
+ TensorFlowConcatOperator() : Operator(OperatorType::kConcat) {}
};
// TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
@@ -1251,7 +1274,7 @@ struct TensorFlowConcatOperator : Operator {
// dimension then we can change this op into a DepthConcatenation op.
// Otherwise, we hope for some other graph transformation to drop this node.
struct TensorFlowConcatV2Operator : Operator {
- TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {}
+ TensorFlowConcatV2Operator() : Operator(OperatorType::kConcatV2) {}
};
// TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
@@ -1267,7 +1290,7 @@ struct TensorFlowConcatV2Operator : Operator {
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowMergeOperator : Operator {
- TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {}
+ TensorFlowMergeOperator() : Operator(OperatorType::kMerge) {}
};
// TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
@@ -1290,7 +1313,7 @@ struct TensorFlowMergeOperator : Operator {
// control flow that can be resolved at tooling time (independently of input
// activations).
struct TensorFlowSwitchOperator : Operator {
- TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {}
+ TensorFlowSwitchOperator() : Operator(OperatorType::kSwitch) {}
};
// TensorFlow All equivalent. Refer to TensorFlow documentation for details.
@@ -1299,7 +1322,7 @@ struct TensorFlowSwitchOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowAllOperator : Operator {
- TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {}
+ TensorFlowAllOperator() : Operator(OperatorType::kAll) {}
};
// TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
@@ -1307,7 +1330,7 @@ struct TensorFlowAllOperator : Operator {
// support graph transformations to other operator types by matching sub-graphs.
// Typically, we just drop Assert nodes.
struct TensorFlowAssertOperator : Operator {
- TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {}
+ TensorFlowAssertOperator() : Operator(OperatorType::kAssert) {}
};
// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
@@ -1316,7 +1339,7 @@ struct TensorFlowAssertOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessOperator : Operator {
- TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {}
+ TensorFlowLessOperator() : Operator(OperatorType::kLess) {}
};
// TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
@@ -1326,8 +1349,7 @@ struct TensorFlowLessOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowLessEqualOperator : Operator {
- TensorFlowLessEqualOperator()
- : Operator(OperatorType::kTensorFlowLessEqual) {}
+ TensorFlowLessEqualOperator() : Operator(OperatorType::kLessEqual) {}
};
// TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
@@ -1336,7 +1358,7 @@ struct TensorFlowLessEqualOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterOperator : Operator {
- TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {}
+ TensorFlowGreaterOperator() : Operator(OperatorType::kGreater) {}
};
// TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
@@ -1346,8 +1368,23 @@ struct TensorFlowGreaterOperator : Operator {
// Typically, this is only used as an input to an Assert node, so can be
// removed as an unused node as we drop Assert nodes.
struct TensorFlowGreaterEqualOperator : Operator {
- TensorFlowGreaterEqualOperator()
- : Operator(OperatorType::kTensorFlowGreaterEqual) {}
+ TensorFlowGreaterEqualOperator() : Operator(OperatorType::kGreaterEqual) {}
+};
+
+// TensorFlow Equal equivalent. Refer to TensorFlow documentation for
+// details.
+// Not fully supported, just a placeholder to handle TensorFlow graphs and
+// support graph transformations to other operator types by matching sub-graphs.
+// Typically, this is only used as an input to an Assert node, so can be
+// removed as an unused node as we drop Assert nodes.
+struct TensorFlowEqualOperator : Operator {
+ TensorFlowEqualOperator() : Operator(OperatorType::kEqual) {}
+};
+
+// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for
+// details.
+struct TensorFlowNotEqualOperator : Operator {
+ TensorFlowNotEqualOperator() : Operator(OperatorType::kNotEqual) {}
};
// Global max reduction: computes the max of all of entries in the input array.
@@ -1359,7 +1396,7 @@ struct TensorFlowGreaterEqualOperator : Operator {
// TensorFlow equivalent: Max --- except that we only support the special case
// of global reduction across all dimensions.
struct TensorFlowMaxOperator : Operator {
- TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {}
+ TensorFlowMaxOperator() : Operator(OperatorType::kMax) {}
bool keep_dims = false;
};
@@ -1372,7 +1409,7 @@ struct TensorFlowMaxOperator : Operator {
// TensorFlow equivalent: Min --- except that we only support the special case
// of global reduction across all dimensions.
struct TensorFlowMinOperator : Operator {
- TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {}
+ TensorFlowMinOperator() : Operator(OperatorType::kMin) {}
bool keep_dims = false;
};
@@ -1385,7 +1422,7 @@ struct TensorFlowMinOperator : Operator {
//
// TensorFlow equivalent: Maximum
struct TensorFlowMaximumOperator : Operator {
- TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {}
+ TensorFlowMaximumOperator() : Operator(OperatorType::kMaximum) {}
};
// Element-wise minimum operator. Currently it only supports scalar as
@@ -1397,14 +1434,13 @@ struct TensorFlowMaximumOperator : Operator {
//
// TensorFlow equivalent: Minimum
struct TensorFlowMinimumOperator : Operator {
- TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {}
+ TensorFlowMinimumOperator() : Operator(OperatorType::kMinimum) {}
};
// General TF operation, unsupported by tf.mini. Expected to be dropped by
// graph transformations.
struct TensorFlowUnsupportedOperator : Operator {
- TensorFlowUnsupportedOperator()
- : Operator(OperatorType::kTensorFlowUnsupported) {}
+ TensorFlowUnsupportedOperator() : Operator(OperatorType::kUnsupported) {}
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
@@ -1493,6 +1529,17 @@ struct ArgMaxOperator : Operator {
ArrayDataType output_data_type = ArrayDataType::kInt64;
};
+// ArgMin operator. It returns the index of the minimum value along axis.
+//
+// Inputs:
+// inputs[0]: required: the input tensor
+//
+// TensorFlow equivalent: ArgMin
+struct ArgMinOperator : Operator {
+ ArgMinOperator() : Operator(OperatorType::kArgMin) {}
+ ArrayDataType output_data_type = ArrayDataType::kInt64;
+};
+
// ResizeBilinear operator. It resizes input images with bilinear interpolation.
// It does not support align_corners at the moment.
//
@@ -1612,13 +1659,24 @@ struct SparseToDenseOperator : Operator {
bool validate_indices;
};
+// Pow operator:
+//
+// Inputs:
+// Inputs[0]: required: A tensor.
+// Inputs[1]: required: A tensor.
+//
+// TensorFlow equivalent: Pow.
+struct PowOperator : Operator {
+ PowOperator() : Operator(OperatorType::kPow) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
// offsets from the start of the workspace buffer, expressed in bytes.
struct Alloc {
- int start = 0;
- int end = 0;
+ int64 start = 0;
+ int64 end = 0;
};
inline bool operator<(const Alloc& a, const Alloc& b) {
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 0f104d5e2d..06072d1fcb 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags(
"that information from the input file."),
Flag("input_arrays", parsed_flags.input_arrays.bind(),
parsed_flags.input_arrays.default_value(),
- "Names of the output arrays, comma-separated. If not specified, "
+ "Names of the input arrays, comma-separated. If not specified, "
"will try to read that information from the input file."),
Flag("output_array", parsed_flags.output_array.bind(),
parsed_flags.output_array.default_value(),
@@ -74,10 +74,10 @@ bool ParseModelFlagsFromCommandLineFlags(
"height, input array width, input array depth."),
Flag("batch_size", parsed_flags.batch_size.bind(),
parsed_flags.batch_size.default_value(),
- "Batch size for the model. Replaces the first dimension of an "
- "input size array if undefined. Use only with SavedModels when "
- "--input_shapes flag is not specified. Always use --input_shapes "
- "flag with frozen graphs."),
+ "Deprecated. Batch size for the model. Replaces the first dimension "
+ "of an input size array if undefined. Use only with SavedModels "
+ "when --input_shapes flag is not specified. Always use "
+ "--input_shapes flag with frozen graphs."),
Flag("input_data_type", parsed_flags.input_data_type.bind(),
parsed_flags.input_data_type.default_value(),
"Deprecated: use --input_data_types instead. Input array type, if "
diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD
index a954f1d6ba..93fe756a55 100644
--- a/tensorflow/contrib/lite/toco/python/BUILD
+++ b/tensorflow/contrib/lite/toco/python/BUILD
@@ -12,6 +12,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite/toco:model_flags_proto_cc",
"//tensorflow/contrib/lite/toco:toco_flags_proto_cc",
+ "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options",
"//tensorflow/contrib/lite/toco:toco_port",
"//tensorflow/contrib/lite/toco:toco_tooling",
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
index 5b1db852b4..d93e104038 100644
--- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc
+++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/python/toco_python_api.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_tooling.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
@@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error);
if (error) return nullptr;
- // Use toco to produce new outputs
+ // Use TOCO to produce new outputs.
toco::ModelFlags model_flags;
if (!model_flags.ParseFromString(model_flags_proto_txt)) {
LOG(FATAL) << "Model proto failed to parse." << std::endl;
@@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw,
if (!toco_flags.ParseFromString(toco_flags_proto_txt)) {
LOG(FATAL) << "Toco proto failed to parse." << std::endl;
}
+
+ auto& dump_options = *GraphVizDumpOptions::singleton();
+ if (toco_flags.has_dump_graphviz_dir()) {
+ dump_options.dump_graphviz = toco_flags.dump_graphviz_dir();
+ }
+ if (toco_flags.has_dump_graphviz_include_video()) {
+ dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video();
+ }
+
+ // Convert model.
std::unique_ptr<toco::Model> model =
toco::Import(toco_flags, model_flags, input_contents_txt);
toco::Transform(toco_flags, model.get());
diff --git a/tensorflow/contrib/lite/toco/runtime/types.h b/tensorflow/contrib/lite/toco/runtime/types.h
index f5de5a5781..207f2c1706 100644
--- a/tensorflow/contrib/lite/toco/runtime/types.h
+++ b/tensorflow/contrib/lite/toco/runtime/types.h
@@ -24,6 +24,7 @@ namespace toco {
// TODO(ahentz): These are just stopgaps for now, untils we move all
// the code over to tflite.
using tflite::Dims;
+using tflite::FullyConnectedWeightsFormat;
using tflite::FusedActivationFunctionType;
using tflite::RequiredBufferSizeForDims;
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index e1025c6664..a02f90988b 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -24,6 +24,7 @@ cc_library(
deps = [
":types",
"//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5daa703c80..5ad307af14 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -49,7 +49,7 @@ details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
string custom_code;
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
custom_code = unsupported_op.tensorflow_op;
@@ -99,7 +99,8 @@ void LoadOperatorsMap(
Offset<Vector<Offset<Tensor>>> ExportTensors(
const Model& model, const details::TensorsMap& tensors_map,
- FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write) {
+ FlatBufferBuilder* builder, std::vector<const Array*>* buffers_to_write,
+ const std::set<int32_t>& variable_tensor_indices) {
// In the end we will need to produce a vector sorted by the indices of the
// tensors in the tensors_map.
std::map<int, Offset<Tensor>> ordered_tensors;
@@ -139,9 +140,11 @@ Offset<Vector<Offset<Tensor>>> ExportTensors(
scale, zero_point);
int index = tensors_map.at(tensor_name);
+ bool is_variable =
+ variable_tensor_indices.find(index) != variable_tensor_indices.end();
ordered_tensors[index] =
CreateTensor(*builder, builder->CreateVector(shape), type, buffer_index,
- builder->CreateString(tensor_name), q_param);
+ builder->CreateString(tensor_name), q_param, is_variable);
}
std::vector<Offset<Tensor>> tensor_vector;
@@ -208,7 +211,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
ordered_opcodes[op_index] =
CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
} else {
- // This could be a kTensorFlowUnsupported, in which case we should be
+ // This could be a kUnsupported, in which case we should be
// able to retrieve the original Tensorflow name from the OperatorKey, or
// this could be a proper TOCO operator that is completely unknown to TF
// Lite.
@@ -239,7 +242,10 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
- const details::TensorsMap& tensors_map, FlatBufferBuilder* builder) {
+ const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
+ std::set<int32_t>* variable_tensor_indices) {
+ variable_tensor_indices->clear();
+
// The operators are in execution order, so we just follow tf.mini order.
std::vector<Offset<Operator>> op_vector;
for (const auto& op : model.operators) {
@@ -256,18 +262,36 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
- // This is a custom op unless we can find it in ops_by_type, and even then
- // it could be a custom op (such as kTensorFlowUnsupported).
+ auto tflite_op_it = ops_by_type.find(op->type);
+ BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
+ ? nullptr
+ : tflite_op_it->second.get();
+ // This is a custom op unless we can find it in ops_by_type, and even then
+ // it could be a custom op (such as kUnsupported).
auto options = Options::Custom(0);
- if (ops_by_type.count(op->type) != 0) {
- options = ops_by_type.at(op->type)->Serialize(*op, builder);
+
+ std::vector<bool> mutating_input_variables;
+ if (tflite_op) {
+ options = tflite_op->Serialize(*op, builder);
+ mutating_input_variables = tflite_op->GetMutatingInputVariables(*op);
+
+ if (!mutating_input_variables.empty()) {
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ if (!mutating_input_variables[i]) {
+ continue;
+ }
+ int32_t variable_tensor_index = tensors_map.at(op->inputs[i]);
+ variable_tensor_indices->insert(variable_tensor_index);
+ }
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
*builder, op_index, builder->CreateVector(inputs),
builder->CreateVector(outputs), options.type, options.builtin,
- options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS));
+ options.custom, ::tflite::CustomOptionsFormat_FLEXBUFFERS,
+ builder->CreateVector(mutating_input_variables)));
}
return builder->CreateVector(op_vector);
@@ -308,25 +332,34 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write);
- auto inputs = ExportInputTensors(model, tensors_map, &builder);
- auto outputs = ExportOutputTensors(model, tensors_map, &builder);
-
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
&builder, &error_summary);
- const string fake_quant_operation_name = "FAKE_QUANT";
- if (error_summary.count(fake_quant_operation_name) != 0) {
- LOG(ERROR)
- << fake_quant_operation_name
- << " operation was not converted. If running quantized make sure you "
- "are passing --inference_type=QUANTIZED_UINT8 and values for "
- "--std_values and --mean_values.";
- // Remove the fake quant operation from the errors, since it shouldn't
- // be provided a custom implementation.
- error_summary.erase(fake_quant_operation_name);
+
+ for (const auto& op : model.operators) {
+ if (op->type == OperatorType::kFakeQuant) {
+ LOG(WARNING) << "FAKE_QUANT operation " << LogName(*op)
+ << " was not converted. If running quantized make sure you "
+ "are passing --inference_type=QUANTIZED_UINT8 and values "
+ "for --std_values and --mean_values.";
+ }
}
if (!allow_custom_ops && !error_summary.empty()) {
+ // Remove ExpandDims and ReorderAxes from unimplemented list unless they
+ // compose the list. Both ops are removed during graph transformations.
+ // However, if an op is unimplemented earlier in the model, the graph
+ // transformation is unable to run because the output shape is not defined.
+ // This causes unnecessary confusion during model conversion time.
+ std::set<string> error_summary_final;
+ for (const auto& op_type : error_summary) {
+ if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
+ error_summary_final.insert(op_type);
+ }
+ }
+ if (error_summary_final.empty()) {
+ error_summary_final = error_summary;
+ }
+
LOG(QFATAL)
<< "Some of the operators in the model are not supported by "
"the standard TensorFlow Lite runtime. If you have a custom "
@@ -334,14 +367,21 @@ void Export(
"--allow_custom_ops, or by setting allow_custom_ops=True "
"when calling tf.contrib.lite.toco_convert(). Here is a list "
"of operators for which you will need custom implementations: "
- << absl::StrJoin(error_summary, ", ") << ".";
+ << absl::StrJoin(error_summary_final, ", ") << ".";
}
- auto ops =
- ExportOperators(model, ops_by_type, operators_map, tensors_map, &builder);
+ std::set<int32_t> variable_tensor_indices;
+ auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
+ &builder, &variable_tensor_indices);
+
+ auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
+ variable_tensor_indices);
+ auto inputs = ExportInputTensors(model, tensors_map, &builder);
+ auto outputs = ExportOutputTensors(model, tensors_map, &builder);
// TODO(aselle): add support to toco for multiple subgraphs.
- auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops);
+ auto subgraph = CreateSubGraph(builder, tensors, inputs, outputs, ops,
+ /* name */ 0);
std::vector<flatbuffers::Offset<SubGraph>> subgraphs = {subgraph};
auto buffers = ExportBuffers(model, buffers_to_write, &builder);
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 098d2163e6..58ea5c725c 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -45,7 +45,7 @@ namespace details {
using TensorsMap = std::unordered_map<string, int>;
// A key to identify an operator.
-// Only when `type` is `kTensorFlowUnsupported`, `custom_code` is filled to
+// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
OperatorKey(OperatorType type, const std::string& custom_code, int version)
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 409e7d72a5..d1fdbcb8e9 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -73,8 +73,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
- EXPECT_EQ(3, operators[details::OperatorKey(
- OperatorType::kTensorFlowUnsupported, "MyCrazyOp", 1)]);
+ EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported,
+ "MyCrazyOp", 1)]);
}
TEST_F(ExportTest, Export) {
diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc
index c0e7ab2ef5..1dd4915b31 100644
--- a/tensorflow/contrib/lite/toco/tflite/import.cc
+++ b/tensorflow/contrib/lite/toco/tflite/import.cc
@@ -113,15 +113,35 @@ void ImportOperators(
<< operators_table.size();
}
string opname = operators_table.at(index);
+
+ // Find and use the appropriate operator deserialization factory.
+ std::unique_ptr<Operator> new_op = nullptr;
if (ops_by_name.count(opname) == 0) {
- LOG(FATAL) << "Op '" << opname << "' not supported";
+ string effective_opname = "TENSORFLOW_UNSUPPORTED";
+ if (ops_by_name.count(effective_opname) == 0) {
+ LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found.";
+ }
+ new_op = ops_by_name.at(effective_opname)
+ ->Deserialize(input_op->builtin_options(),
+ input_op->custom_options());
+ if (new_op->type == OperatorType::kUnsupported) {
+ auto* unsupported_op =
+ static_cast<TensorFlowUnsupportedOperator*>(new_op.get());
+ unsupported_op->tensorflow_op = opname;
+ // TODO(b/109932940): Remove this when quantized is removed.
+ // For now, we assume all ops are quantized.
+ unsupported_op->quantized = true;
+ } else {
+ LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator";
+ }
+ } else {
+ new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(),
+ input_op->custom_options());
}
-
- auto new_op = ops_by_name.at(opname)->Deserialize(
- input_op->builtin_options(), input_op->custom_options());
model->operators.emplace_back(new_op.release());
auto* op = model->operators.back().get();
+ // Make sure all the inputs and outputs are hooked up.
auto inputs = input_op->inputs();
for (int i = 0; i < inputs->Length(); i++) {
auto input_index = inputs->Get(i);
@@ -201,6 +221,8 @@ std::unique_ptr<Model> Import(const ModelFlags& model_flags,
model.get());
ImportIOTensors(*input_model, tensors_table, model.get());
+ UndoWeightsShuffling(model.get());
+
return model;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 8f0f2e24db..8377ba6a03 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+// TODO(ycling): Consider refactoring to extract the LSTM definition out of
+// graph_transformation module.
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
@@ -279,22 +282,24 @@ class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
int GetVersion(const Operator& op) const override { return 1; }
};
-class FakeQuant : public CustomOperator<FakeQuantOperator> {
+class FakeQuant
+ : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
+ ::tflite::BuiltinOptions_FakeQuantOptions> {
public:
- using CustomOperator::CustomOperator;
- void WriteOptions(const TocoOperator& op,
- flexbuffers::Builder* fbb) const override {
- fbb->Float("min", op.minmax->min);
- fbb->Float("max", op.minmax->max);
- fbb->Int("num_bits", op.num_bits);
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateFakeQuantOptions(*builder, op.minmax->min,
+ op.minmax->max, op.num_bits);
}
- void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
auto* minmax = new MinMax;
- minmax->min = m["min"].AsFloat();
- minmax->max = m["max"].AsFloat();
+ minmax->min = options.min();
+ minmax->max = options.max();
op->minmax.reset(minmax);
- const auto& num_bits = m["num_bits"];
- op->num_bits = num_bits.IsInt() ? num_bits.AsInt32() : 8;
+ op->num_bits = options.num_bits();
}
int GetVersion(const Operator& op) const override { return 1; }
@@ -311,16 +316,47 @@ class FullyConnected
flatbuffers::FlatBufferBuilder* builder) const override {
auto activation_function =
ActivationFunction::Serialize(op.fused_activation_function);
- return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
+ ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
+ switch (op.weights_format) {
+ case FullyConnectedWeightsFormat::kDefault:
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ break;
+ case FullyConnectedWeightsFormat::kShuffled4x16Int8:
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ break;
+ default:
+ LOG(ERROR) << "Unhandled FC weights format";
+ tflite_weights_format =
+ ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ }
+ return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
+ tflite_weights_format);
}
void ReadOptions(const TfLiteOptions& options,
TocoOperator* op) const override {
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ switch (options.weights_format()) {
+ case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
+ op->weights_format = FullyConnectedWeightsFormat::kDefault;
+ break;
+ case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
+ break;
+ default:
+ LOG(ERROR) << "Unhandled FC weights format";
+ op->weights_format = FullyConnectedWeightsFormat::kDefault;
+ }
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& fc_op = static_cast<const FullyConnectedOperator&>(op);
+ return fc_op.weights_format == FullyConnectedWeightsFormat::kDefault ? 1
+ : 2;
+ }
};
class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
@@ -507,6 +543,22 @@ class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class Tile
+ : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
+ ::tflite::BuiltinOptions_TileOptions> {
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateTileOptions(*builder);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {}
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
::tflite::BuiltinOptions_PadV2Options> {
public:
@@ -610,11 +662,21 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
+ ::tflite::LSTMKernelType kernel_type;
+ switch (op.kernel_type) {
+ case LstmCellOperator::KERNEL_BASIC:
+ kernel_type = ::tflite::LSTMKernelType_BASIC;
+ break;
+ case LstmCellOperator::KERNEL_FULL:
+ kernel_type = ::tflite::LSTMKernelType_FULL;
+ break;
+ }
+
// Current toco converter only supports tanh, no clip.
return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
::tflite::ActivationFunctionType_TANH,
/*cell_clip=*/0.0,
- /*proj_clip=*/0.0);
+ /*proj_clip=*/0.0, kernel_type);
}
void ReadOptions(const TfLiteOptions& options,
@@ -622,19 +684,75 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
// Only support tanh activation, so check that tflite type is tanh.
CHECK(options.fused_activation_function() ==
::tflite::ActivationFunctionType_TANH);
+
+ switch (options.kernel_type()) {
+ case ::tflite::LSTMKernelType_BASIC:
+ op->kernel_type = LstmCellOperator::KERNEL_BASIC;
+ break;
+ case ::tflite::LSTMKernelType_FULL:
+ op->kernel_type = LstmCellOperator::KERNEL_FULL;
+ break;
+ }
+ }
+
+ int GetVersion(const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL:
+ return 1;
+ case LstmCellOperator::KERNEL_BASIC:
+ return 2;
+ }
+ }
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL: {
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ break;
+ }
+ case LstmCellOperator::KERNEL_BASIC: {
+ mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
+ mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
+ break;
+ }
+ }
+ return mutating_input_variables;
+ }
+};
+
+class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
}
int GetVersion(const Operator& op) const override { return 1; }
};
-class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
- ::tflite::BuiltinOptions_MeanOptions> {
+class Sum
+ : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
flatbuffers::Offset<TfLiteOptions> WriteOptions(
const TocoOperator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
- return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
}
void ReadOptions(const TfLiteOptions& options,
@@ -769,6 +887,25 @@ class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
+ ::tflite::BuiltinOptions_ArgMinOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateArgMinOptions(
+ *builder, DataType::Serialize(op.output_data_type));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->output_data_type = DataType::Deserialize(options.output_type());
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TransposeConv
: public BuiltinOperator<TransposeConvOperator,
::tflite::TransposeConvOptions,
@@ -815,6 +952,44 @@ class SparseToDense
int GetVersion(const Operator& op) const override { return 1; }
};
+class ExpandDims
+ : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
+ ::tflite::BuiltinOptions_ExpandDimsOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateExpandDimsOptions(*builder);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {}
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class Shape
+ : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
+ ::tflite::BuiltinOptions_ShapeOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateShapeOptions(
+ *builder, DataType::Serialize(op.output_data_type));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->output_data_type = DataType::Deserialize(options.out_type());
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -875,6 +1050,20 @@ class TensorFlowUnsupported : public BaseOperator {
fbb->Bool(key, attr.b());
has_valid_attr = true;
break;
+ case tensorflow::AttrValue::kList:
+ if (attr.list().i_size() > 0) {
+ auto start = fbb->StartVector(key);
+ for (const int64_t v : attr.list().i()) {
+ fbb->Add(v);
+ }
+ fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
+ has_valid_attr = true;
+ } else {
+ LOG(WARNING)
+ << "Ignoring unsupported type in list attribute with key '"
+ << key << "'";
+ }
+ break;
default:
LOG(WARNING) << "Ignoring unsupported attribute type with key '"
<< key << "'";
@@ -911,6 +1100,14 @@ class TensorFlowUnsupported : public BaseOperator {
case flexbuffers::TYPE_BOOL:
(*attr)[key].set_b(value.AsBool());
break;
+ case flexbuffers::TYPE_VECTOR_INT: {
+ auto* list = (*attr)[key].mutable_list();
+ const auto& vector = value.AsTypedVector();
+ for (size_t i = 0; i < vector.size(); i++) {
+ list->add_i(vector[i].AsInt64());
+ }
+ break;
+ }
default:
LOG(WARNING) << "Ignoring unsupported attribute type with key '"
<< key << "'";
@@ -969,8 +1166,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
ops.emplace_back(
new PadV2(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
- ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
- OperatorType::kTensorFlowReshape));
+ ops.emplace_back(
+ new Reshape(::tflite::BuiltinOperator_RESHAPE, OperatorType::kReshape));
ops.emplace_back(
new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
@@ -981,12 +1178,13 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kTranspose));
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
+ ops.emplace_back(new Sum(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
- ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
- OperatorType::kTensorFlowSplit));
+ ops.emplace_back(
+ new Split(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
OperatorType::kStridedSlice));
ops.emplace_back(
@@ -997,24 +1195,31 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
ops.emplace_back(
new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax));
+ ops.emplace_back(
+ new ArgMin(::tflite::BuiltinOperator_ARG_MIN, OperatorType::kArgMin));
+ ops.emplace_back(
+ new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
+ ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS,
+ OperatorType::kExpandDims));
ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV,
OperatorType::kTransposeConv));
ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE,
OperatorType::kSparseToDense));
+ ops.emplace_back(
+ new Shape(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
+ ops.emplace_back(new FakeQuant(::tflite::BuiltinOperator_FAKE_QUANT,
+ OperatorType::kFakeQuant));
// Custom Operators.
ops.emplace_back(
new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
- ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
- ops.emplace_back(new TensorFlowUnsupported(
- "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
+ ops.emplace_back(new TensorFlowUnsupported("TENSORFLOW_UNSUPPORTED",
+ OperatorType::kUnsupported));
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
ops.emplace_back(
new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
- ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
- "RSQRT", OperatorType::kTensorFlowRsqrt));
// Simple Operators.
ops.emplace_back(new SimpleOperator<DequantizeOperator>(
"DEQUANTIZE", OperatorType::kDequantize));
@@ -1036,23 +1241,34 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
ops.emplace_back(new SimpleOperator<LogSoftmaxOperator>(
"LOG_SOFTMAX", OperatorType::kLogSoftmax));
ops.emplace_back(new SimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum));
+ "MAXIMUM", OperatorType::kMaximum)); // Element-wise Maximum
ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
- "MINIMUM", OperatorType::kTensorFlowMinimum));
+ "MINIMUM", OperatorType::kMinimum)); // Element-wise Minimum
ops.emplace_back(new SimpleOperator<TensorFlowGreaterOperator>(
- "GREATER", OperatorType::kTensorFlowGreater));
+ "GREATER", OperatorType::kGreater));
ops.emplace_back(new SimpleOperator<TensorFlowGreaterEqualOperator>(
- "GREATER_EQUAL", OperatorType::kTensorFlowGreaterEqual));
- ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
- "LESS", OperatorType::kTensorFlowLess));
+ "GREATER_EQUAL", OperatorType::kGreaterEqual));
+ ops.emplace_back(
+ new SimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess));
ops.emplace_back(new SimpleOperator<TensorFlowLessEqualOperator>(
- "LESS_EQUAL", OperatorType::kTensorFlowLessEqual));
+ "LESS_EQUAL", OperatorType::kLessEqual));
+ ops.emplace_back(new SimpleOperator<TensorFlowEqualOperator>(
+ "EQUAL", OperatorType::kEqual));
+ ops.emplace_back(new SimpleOperator<TensorFlowNotEqualOperator>(
+ "NOT_EQUAL", OperatorType::kNotEqual));
ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
ops.emplace_back(
new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
+ ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
+ // Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
+ ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
+ ops.emplace_back(
+ new SimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt));
+ ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
+ "RSQRT", OperatorType::kRsqrt));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 5e9c20e40d..d9ea23edf2 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -87,6 +87,17 @@ class BaseOperator {
// overridden. (See example in `operator_test.cc`)
virtual int GetVersion(const Operator& op) const = 0;
+ // Given a Toco `Operator`, return a list of booleans indicating the op
+ // mutates which input variables.
+ // * If the op mutates any input variables, it should return a list of bool
+ // with the same length as inputs.
+ // * Otherwise, it will return an empty list.
+ virtual std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const {
+ // Most ops don't have variable tensors. This function can be overridden.
+ return std::vector<bool>();
+ }
+
private:
string name_;
OperatorType type_;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index d63c99a5f9..ff2d35b1f5 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -74,8 +74,10 @@ class OperatorTest : public ::testing::Test {
auto new_toco_op = op.Deserialize(output_options->builtin_options(),
output_options->custom_options());
- CHECK(dynamic_cast<T*>(new_toco_op.get()))
- << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to "
+ CHECK(new_toco_op->type == toco_op.type)
+ << "The type of the serialized and deserialized"
+ << HelpfulOperatorTypeName(*new_toco_op)
+ << " does not match the type of the original "
<< HelpfulOperatorTypeName(toco_op);
return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
@@ -110,15 +112,21 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
OperatorType::kLogSoftmax);
CheckSimpleOperator<TensorFlowMaximumOperator>(
- "MAXIMUM", OperatorType::kTensorFlowMaximum);
+ "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum
CheckSimpleOperator<TensorFlowMinimumOperator>(
- "MINIMUM", OperatorType::kTensorFlowMinimum);
- CheckSimpleOperator<TensorFlowLessOperator>("LESS",
- OperatorType::kTensorFlowLess);
+ "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum
+ CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
+ CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
+ CheckSimpleOperator<TensorFlowNotEqualOperator>("NOT_EQUAL",
+ OperatorType::kNotEqual);
+ CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
+ CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
+ CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
+ CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
}
TEST_F(OperatorTest, BuiltinAdd) {
@@ -247,7 +255,7 @@ TEST_F(OperatorTest, BuiltinReshape) {
TensorFlowReshapeOperator op;
op.shape = {1, 2, 4, 5, 8};
auto output_toco_op = SerializeAndDeserialize(
- GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
+ GetOperator("RESHAPE", OperatorType::kReshape), op);
EXPECT_EQ(op.shape, output_toco_op->shape);
}
@@ -270,8 +278,8 @@ TEST_F(OperatorTest, BuiltinSpaceToDepth) {
TEST_F(OperatorTest, CustomSplit) {
TensorFlowSplitOperator op;
op.num_split = 123;
- auto output_toco_op = SerializeAndDeserialize(
- GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op);
EXPECT_EQ(op.num_split, output_toco_op->num_split);
}
@@ -408,6 +416,13 @@ TEST_F(OperatorTest, BuiltinArgMax) {
EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
}
+TEST_F(OperatorTest, BuiltinArgMin) {
+ ArgMinOperator op;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ARG_MIN", OperatorType::kArgMin), op);
+ EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
+}
+
TEST_F(OperatorTest, BuiltinTransposeConv) {
TransposeConvOperator op;
op.stride_width = 123;
@@ -420,6 +435,14 @@ TEST_F(OperatorTest, BuiltinTransposeConv) {
EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
}
+TEST_F(OperatorTest, BuiltinShape) {
+ TensorFlowShapeOperator op;
+ op.output_data_type = ArrayDataType::kInt64;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op);
+ EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
+}
+
TEST_F(OperatorTest, BuiltinSparseToDense) {
SparseToDenseOperator op;
op.validate_indices = false;
@@ -439,12 +462,17 @@ TEST_F(OperatorTest, TensorFlowUnsupported) {
(*attr)["str_attr"].set_s("Hello World");
(*attr)["int_attr"].set_i(17);
(*attr)["bool_attr"].set_b(true);
+ {
+ auto* list = (*attr)["list_int_attr"].mutable_list();
+ list->add_i(1);
+ list->add_i(20);
+ list->add_i(1LL << 40);
+ list->add_i(-(1LL << 40));
+ }
node_def.SerializeToString(&op.tensorflow_node_def);
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
- OperatorType::kTensorFlowUnsupported),
- op);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
::tensorflow::NodeDef output_node_def;
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
@@ -453,15 +481,22 @@ TEST_F(OperatorTest, TensorFlowUnsupported) {
EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
EXPECT_EQ(17, output_attr.at("int_attr").i());
EXPECT_EQ(true, output_attr.at("bool_attr").b());
+
+ {
+ const auto& list = output_attr.at("list_int_attr").list();
+ ASSERT_EQ(4, list.i_size());
+ EXPECT_EQ(1, list.i(0));
+ EXPECT_EQ(20, list.i(1));
+ EXPECT_EQ(1LL << 40, list.i(2));
+ EXPECT_EQ(-(1LL << 40), list.i(3));
+ }
}
TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
- OperatorType::kTensorFlowUnsupported),
- op);
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
::tensorflow::NodeDef output_node_def;
output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
index 4867c3a62e..754f0b4b8c 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -88,6 +88,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
switch (array_data_type) {
case ArrayDataType::kFloat:
return ::tflite::TensorType_FLOAT32;
+ case ArrayDataType::kInt16:
+ return ::tflite::TensorType_INT16;
case ArrayDataType::kInt32:
return ::tflite::TensorType_INT32;
case ArrayDataType::kInt64:
@@ -98,6 +100,8 @@ void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
return ::tflite::TensorType_STRING;
case ArrayDataType::kBool:
return ::tflite::TensorType_BOOL;
+ case ArrayDataType::kComplex64:
+ return ::tflite::TensorType_COMPLEX64;
default:
// FLOAT32 is filled for unknown data types.
// TODO(ycling): Implement type inference in TF Lite interpreter.
@@ -109,6 +113,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
switch (::tflite::TensorType(tensor_type)) {
case ::tflite::TensorType_FLOAT32:
return ArrayDataType::kFloat;
+ case ::tflite::TensorType_INT16:
+ return ArrayDataType::kInt16;
case ::tflite::TensorType_INT32:
return ArrayDataType::kInt32;
case ::tflite::TensorType_INT64:
@@ -119,6 +125,8 @@ ArrayDataType DataType::Deserialize(int tensor_type) {
return ArrayDataType::kUint8;
case ::tflite::TensorType_BOOL:
return ArrayDataType::kBool;
+ case ::tflite::TensorType_COMPLEX64:
+ return ArrayDataType::kComplex64;
default:
LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
}
@@ -131,6 +139,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
switch (array.data_type) {
case ArrayDataType::kFloat:
return CopyBuffer<ArrayDataType::kFloat>(array, builder);
+ case ArrayDataType::kInt16:
+ return CopyBuffer<ArrayDataType::kInt16>(array, builder);
case ArrayDataType::kInt32:
return CopyBuffer<ArrayDataType::kInt32>(array, builder);
case ArrayDataType::kInt64:
@@ -141,6 +151,8 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
case ArrayDataType::kBool:
return CopyBoolToBuffer(array, builder);
+ case ArrayDataType::kComplex64:
+ return CopyBuffer<ArrayDataType::kComplex64>(array, builder);
default:
LOG(FATAL) << "Unhandled array data type.";
}
@@ -154,6 +166,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
switch (tensor.type()) {
case ::tflite::TensorType_FLOAT32:
return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
+ case ::tflite::TensorType_INT16:
+ return CopyBuffer<ArrayDataType::kInt16>(buffer, array);
case ::tflite::TensorType_INT32:
return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
case ::tflite::TensorType_INT64:
@@ -164,6 +178,8 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
case ::tflite::TensorType_BOOL:
return CopyBuffer<ArrayDataType::kBool>(buffer, array);
+ case ::tflite::TensorType_COMPLEX64:
+ return CopyBuffer<ArrayDataType::kComplex64>(buffer, array);
default:
LOG(FATAL) << "Unhandled tensor type.";
}
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
index 564f303b9b..8e9f30ba3a 100644
--- a/tensorflow/contrib/lite/toco/tflite/types_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include <complex>
+
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -71,7 +73,8 @@ TEST(DataType, SupportedTypes) {
{ArrayDataType::kInt32, ::tflite::TensorType_INT32},
{ArrayDataType::kInt64, ::tflite::TensorType_INT64},
{ArrayDataType::kFloat, ::tflite::TensorType_FLOAT32},
- {ArrayDataType::kBool, ::tflite::TensorType_BOOL}};
+ {ArrayDataType::kBool, ::tflite::TensorType_BOOL},
+ {ArrayDataType::kComplex64, ::tflite::TensorType_COMPLEX64}};
for (auto x : testdata) {
EXPECT_EQ(x.second, DataType::Serialize(x.first));
EXPECT_EQ(x.first, DataType::Deserialize(x.second));
@@ -151,6 +154,12 @@ TEST(DataBuffer, Int32) {
::testing::ElementsAre(1, 1 << 30));
}
+TEST(DataBuffer, Int16) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kInt16>({1, 1 << 14});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kInt16>().data,
+ ::testing::ElementsAre(1, 1 << 14));
+}
+
TEST(DataBuffer, String) {
Array recovered = ToFlatBufferAndBack<ArrayDataType::kString>(
{"AA", "BBB", "Best. String. Ever."});
@@ -165,6 +174,14 @@ TEST(DataBuffer, Bool) {
::testing::ElementsAre(true, false, true));
}
+TEST(DataBuffer, Complex64) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kComplex64>(
+ {std::complex<float>(1.0f, 2.0f), std::complex<float>(3.0f, 4.0f)});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kComplex64>().data,
+ ::testing::ElementsAre(std::complex<float>(1.0f, 2.0f),
+ std::complex<float>(3.0f, 4.0f)));
+}
+
TEST(Padding, All) {
EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
diff --git a/tensorflow/contrib/lite/toco/toco.cc b/tensorflow/contrib/lite/toco/toco.cc
index 8041aa9e7f..0b460bd178 100644
--- a/tensorflow/contrib/lite/toco/toco.cc
+++ b/tensorflow/contrib/lite/toco/toco.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_port.h"
-#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
#include "tensorflow/contrib/lite/toco/toco_tooling.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
#include "tensorflow/core/platform/logging.h"
@@ -49,17 +48,6 @@ void CheckFrozenModelPermissions(const Arg<string>& input_file) {
<< input_file.value() << ".\n";
}
-// Checks the permissions of the SavedModel directory.
-void CheckSavedModelPermissions(const Arg<string>& savedmodel_directory) {
- QCHECK(savedmodel_directory.specified())
- << "Missing required flag --savedmodel_directory.\n";
- QCHECK(
- port::file::Exists(savedmodel_directory.value(), port::file::Defaults())
- .ok())
- << "Specified savedmodel_directory does not exist: "
- << savedmodel_directory.value() << ".\n";
-}
-
// Reads the contents of the GraphDef from either the frozen graph file or the
// SavedModel directory. If it reads the SavedModel directory, it updates the
// ModelFlags and TocoFlags accordingly.
@@ -69,24 +57,16 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
string* graph_def_contents) {
port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
- bool has_input_file = parsed_toco_flags.input_file.specified();
- bool has_savedmodel_dir = parsed_toco_flags.savedmodel_directory.specified();
-
- // Ensure either input_file or savedmodel_directory flag has been set.
- QCHECK_NE(has_input_file, has_savedmodel_dir)
- << "Specify either input_file or savedmodel_directory flag.\n";
+ // Ensure savedmodel_directory is not set.
+ QCHECK(!parsed_toco_flags.savedmodel_directory.specified())
+ << "Use `tensorflow/contrib/lite/python/tflite_convert` script with "
+ << "SavedModel directories.\n";
// Checks the input file permissions and reads the contents.
- if (has_input_file) {
- CheckFrozenModelPermissions(parsed_toco_flags.input_file);
- CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
- graph_def_contents, port::file::Defaults())
- .ok());
- } else {
- CheckSavedModelPermissions(parsed_toco_flags.savedmodel_directory);
- GetSavedModelContents(parsed_toco_flags, parsed_model_flags, toco_flags,
- model_flags, graph_def_contents);
- }
+ CheckFrozenModelPermissions(parsed_toco_flags.input_file);
+ CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ graph_def_contents, port::file::Defaults())
+ .ok());
}
void ToolMain(const ParsedTocoFlags& parsed_toco_flags,
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index 7786a4ada3..c6d0a03452 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -41,7 +41,7 @@ bool ParseTocoFlagsFromCommandLineFlags(
"extension."),
Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
parsed_flags.savedmodel_directory.default_value(),
- "Full path to the directory containing the SavedModel."),
+ "Deprecated. Full path to the directory containing the SavedModel."),
Flag("output_file", parsed_flags.output_file.bind(),
parsed_flags.output_file.default_value(),
"Output file. "
@@ -55,9 +55,9 @@ bool ParseTocoFlagsFromCommandLineFlags(
"One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
parsed_flags.savedmodel_tagset.default_value(),
- "Comma-separated set of tags identifying the MetaGraphDef within "
- "the SavedModel to analyze. All tags in the tag set must be "
- "specified."),
+ "Deprecated. Comma-separated set of tags identifying the "
+ "MetaGraphDef within the SavedModel to analyze. All tags in the tag "
+ "set must be specified."),
Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
parsed_flags.default_ranges_min.default_value(),
"If defined, will be used as the default value for the min bound "
@@ -153,6 +153,16 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.dedupe_array_min_size_bytes.default_value(),
"Minimum size of constant arrays to deduplicate; arrays smaller "
"will not be deduplicated."),
+ Flag("split_tflite_lstm_inputs",
+ parsed_flags.split_tflite_lstm_inputs.bind(),
+ parsed_flags.split_tflite_lstm_inputs.default_value(),
+ "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
+ "Ignored if the output format is not TFLite."),
+ Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
+ parsed_flags.quantize_weights.default_value(),
+ "Store weights as quantized weights followed by dequantize "
+ "operations. Computation is still done in float, but reduces model "
+ "size (at the cost of accuracy and latency)."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -245,6 +255,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
FlagRequirement::kNone);
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
+ READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
+ READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -278,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
toco_flags->set_inference_input_type(input_type);
}
+ if (parsed_toco_flags.quantize_weights.value()) {
+ QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
+ << "quantize_weights is not supported with inference_type "
+ "QUANTIZED_UINT8.";
+ }
#undef READ_TOCO_FLAG
#undef PARSE_TOCO_FLAG
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 8589ca361d..b4a9870d58 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 19.
+// Next ID to use: 26.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -165,4 +165,22 @@ message TocoFlags {
// Minimum size of constant arrays to deduplicate; arrays smaller will not be
// deduplicated.
optional int64 dedupe_array_min_size_bytes = 18 [default = 64];
+
+ // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite.
+ // Ignored if the output format is not TFLite.
+ optional bool split_tflite_lstm_inputs = 19 [default = true];
+
+ // Store weights as quantized weights followed by dequantize operations.
+ // Computation is still done in float, but reduces model size (at the cost of
+ // accuracy and latency).
+ optional bool quantize_weights = 20 [default = false];
+
+ // Full filepath of folder to dump the graphs at various stages of processing
+ // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order
+ // to keep the requirements of the output file.
+ optional string dump_graphviz_dir = 24;
+
+ // Boolean indicating whether to dump the graph after every graph
+ // transformation.
+ optional bool dump_graphviz_include_video = 25;
}
diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc
index a1c8696cd0..de76fd4032 100644
--- a/tensorflow/contrib/lite/toco/toco_port.cc
+++ b/tensorflow/contrib/lite/toco/toco_port.cc
@@ -16,8 +16,16 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
+#if defined(__ANDROID__) && defined(__ARM_ARCH_7A__)
+namespace std {
+double round(double x) { return ::round(x); }
+} // namespace std
+#endif
+
namespace toco {
namespace port {
void CopyToBuffer(const string& src, char* dest) {
@@ -55,8 +63,12 @@ void CheckInitGoogleIsDone(const char* message) {
namespace file {
// Conversion to our wrapper Status.
-Status ToStatus(const ::util::Status& uts) {
- return Status(uts.ok(), uts.error_message());
+tensorflow::Status ToStatus(const ::util::Status& uts) {
+ if (!uts.ok()) {
+ return tensorflow::Status(tensorflow::errors::Code(uts.error_code()),
+ uts.error_message());
+ }
+ return tensorflow::Status::OK();
}
// Conversion to our wrapper Options.
@@ -65,7 +77,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) {
return Options();
}
-Status Writable(const string& filename) {
+tensorflow::Status Writable(const string& filename) {
File* f = nullptr;
const auto status = ::file::Open(filename, "w", &f, ::file::Defaults());
if (f) {
@@ -74,22 +86,24 @@ Status Writable(const string& filename) {
return ToStatus(status);
}
-Status Readable(const string& filename, const file::Options& options) {
+tensorflow::Status Readable(const string& filename,
+ const file::Options& options) {
return ToStatus(::file::Readable(filename, ::file::Defaults()));
}
-Status Exists(const string& filename, const file::Options& options) {
+tensorflow::Status Exists(const string& filename,
+ const file::Options& options) {
auto status = ::file::Exists(filename, ::file::Defaults());
return ToStatus(status);
}
-Status GetContents(const string& filename, string* contents,
- const file::Options& options) {
+tensorflow::Status GetContents(const string& filename, string* contents,
+ const file::Options& options) {
return ToStatus(::file::GetContents(filename, contents, ::file::Defaults()));
}
-Status SetContents(const string& filename, const string& contents,
- const file::Options& options) {
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
return ToStatus(::file::SetContents(filename, contents, ::file::Defaults()));
}
@@ -133,37 +147,42 @@ void CheckInitGoogleIsDone(const char* message) {
namespace file {
-Status Writable(const string& filename) {
+tensorflow::Status Writable(const string& filename) {
FILE* f = fopen(filename.c_str(), "w");
if (f) {
fclose(f);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
- return Status(false, "not writable");
+ return tensorflow::errors::NotFound("not writable");
}
-Status Readable(const string& filename, const file::Options& options) {
+tensorflow::Status Readable(const string& filename,
+ const file::Options& options) {
FILE* f = fopen(filename.c_str(), "r");
if (f) {
fclose(f);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
- return Status(false, "not readable");
+ return tensorflow::errors::NotFound("not readable");
}
-Status Exists(const string& filename, const file::Options& options) {
+tensorflow::Status Exists(const string& filename,
+ const file::Options& options) {
struct stat statbuf;
int ret = stat(filename.c_str(), &statbuf);
- return Status(ret != -1, "");
+ if (ret == -1) {
+ return tensorflow::errors::NotFound("file doesn't exist");
+ }
+ return tensorflow::Status::OK();
}
-Status GetContents(const string& path, string* output,
- const file::Options& options) {
+tensorflow::Status GetContents(const string& path, string* output,
+ const file::Options& options) {
output->clear();
int fd = open(path.c_str(), O_RDONLY);
if (fd == -1) {
- return Status(false, "can't open() for read");
+ return tensorflow::errors::NotFound("can't open() for read");
}
// Direct read, for speed.
@@ -174,25 +193,25 @@ Status GetContents(const string& path, string* output,
if (size == 0) {
// Done.
close(fd);
- return Status(true, "");
+ return tensorflow::Status::OK();
} else if (size == -1) {
// Error.
close(fd);
- return Status(false, "error during read()");
+ return tensorflow::errors::Internal("error during read()");
} else {
output->append(buffer, size);
}
}
CHECK(0);
- return Status(false, "internal error");
+ return tensorflow::errors::Internal("internal error");
}
-Status SetContents(const string& filename, const string& contents,
- const file::Options& options) {
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const file::Options& options) {
int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664);
if (fd == -1) {
- return Status(false, "can't open() for write");
+ return tensorflow::errors::Internal("can't open() for write");
}
size_t i = 0;
@@ -201,13 +220,13 @@ Status SetContents(const string& filename, const string& contents,
ssize_t written = write(fd, &contents[i], to_write);
if (written == -1) {
close(fd);
- return Status(false, "write() error");
+ return tensorflow::errors::Internal("write() error");
}
i += written;
}
close(fd);
- return Status(true, "");
+ return tensorflow::Status::OK();
}
string JoinPath(const string& base, const string& filename) {
diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h
index 906792ef56..17f82b9dd7 100644
--- a/tensorflow/contrib/lite/toco/toco_port.h
+++ b/tensorflow/contrib/lite/toco/toco_port.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "google/protobuf/text_format.h"
#include "tensorflow/contrib/lite/toco/format_port.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
@@ -33,28 +34,26 @@ limitations under the License.
#define TFLITE_PROTO_NS google::protobuf
#endif
-namespace toco {
-namespace port {
-
-class Status {
- public:
- static Status OK() { return Status(true, ""); }
-
- // Create a failed status with no message.
- Status() {}
-
- Status(bool ok, const string& message) : ok_(ok), message_(message) {}
-
- void AppendMessage(const string& message) { message_ += message; }
+#ifdef __ANDROID__
+#include <sstream>
+namespace std {
- bool ok() const { return ok_; }
+template <typename T>
+std::string to_string(T value)
+{
+ std::ostringstream os ;
+ os << value ;
+ return os.str() ;
+}
- const string error_message() const { return message_; }
+#ifdef __ARM_ARCH_7A__
+double round(double x);
+#endif
+}
+#endif
- private:
- bool ok_ = false;
- string message_;
-};
+namespace toco {
+namespace port {
void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags);
void CheckInitGoogleIsDone(const char* message);
@@ -65,14 +64,14 @@ inline Options Defaults() {
Options o;
return o;
}
-Status GetContents(const string& filename, string* contents,
- const Options& options);
-Status SetContents(const string& filename, const string& contents,
- const Options& options);
+tensorflow::Status GetContents(const string& filename, string* contents,
+ const Options& options);
+tensorflow::Status SetContents(const string& filename, const string& contents,
+ const Options& options);
string JoinPath(const string& base, const string& filename);
-Status Writable(const string& filename);
-Status Readable(const string& filename, const Options& options);
-Status Exists(const string& filename, const Options& options);
+tensorflow::Status Writable(const string& filename);
+tensorflow::Status Readable(const string& filename, const Options& options);
+tensorflow::Status Exists(const string& filename, const Options& options);
} // namespace file
// Copy `src` string to `dest`. User must ensure `dest` has enough space.
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.cc b/tensorflow/contrib/lite/toco/toco_saved_model.cc
deleted file mode 100644
index 26f55a66c7..0000000000
--- a/tensorflow/contrib/lite/toco/toco_saved_model.cc
+++ /dev/null
@@ -1,189 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <string>
-#include <vector>
-
-#include "absl/strings/numbers.h"
-#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
-#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
-#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/framework/tensor_shape.pb.h"
-
-namespace toco {
-namespace {
-
-// Loads a SavedModel from the directory specified in parsed_toco_flags.
-// Returns a SavedModelBundle with the requested MetaGraphDef.
-const tensorflow::SavedModelBundle* LoadSavedModel(
- const ParsedTocoFlags& parsed_toco_flags) {
- const string model_path = parsed_toco_flags.savedmodel_directory.value();
- QCHECK(tensorflow::MaybeSavedModelDirectory(model_path))
- << "Model is not saved in the supported SavedModel format.\n";
-
- // Gets the tags identifying the MetaGraphDef from the command line arguments.
- string tags_str;
- if (parsed_toco_flags.savedmodel_tagset.specified()) {
- tags_str = parsed_toco_flags.savedmodel_tagset.value();
- } else {
- tags_str = parsed_toco_flags.savedmodel_tagset.default_value();
- }
- auto tags = absl::StrSplit(tags_str, ',');
-
- // Loads MetaGraphDef.
- auto* bundle = new tensorflow::SavedModelBundle;
- TF_CHECK_OK(tensorflow::LoadSavedModel(tensorflow::SessionOptions(),
- tensorflow::RunOptions(), model_path,
- tags, bundle))
- << "Failed to load exported model from " << model_path
- << ". Ensure the model contains the required tags '" << tags_str
- << "'.\n";
- return bundle;
-}
-
-// Returns the array name without the postfix.
-//
-// e.g. reduces "input:0" to "input".
-string GetArrayName(const string& name) {
- const std::vector<string>& names = absl::StrSplit(name, ':');
- return names[0];
-}
-
-// Returns the list of array names without the postfix sorted alphabetically.
-std::set<string> GetSortedNames(const std::unordered_set<string>& names) {
- std::vector<string> final_names;
- final_names.reserve(names.size());
- for (const auto& name : names) {
- final_names.push_back(GetArrayName(name));
- }
- return std::set<string>(final_names.begin(), final_names.end());
-}
-
-// Gets the final shape after replacing the first dimension with batch size, if
-// it is undefined (containing the value -1). Returns whether the shape is
-// valid.
-bool ReplaceShapeBatchSize(const tensorflow::TensorShapeProto& shape,
- int batch_size,
- tensorflow::TensorShapeProto* final_shape) {
- for (int idx = 0; idx < shape.dim().size(); ++idx) {
- int64 final_dim = shape.dim()[idx].size();
- if (final_dim == -1) {
- if (idx > 0) return false;
- final_dim = batch_size;
- }
- final_shape->add_dim()->set_size(final_dim);
- }
- return true;
-}
-
-// Updates the input arrays in ModelFlags to contain the shape of the array.
-void ProcessInputShapes(const tensorflow::GraphDef& graph_def, int batch_size,
- ModelFlags* model_flags) {
- // Build map of input array names to input arrays.
- std::unordered_map<string, InputArray*> input_data_map;
- for (auto& input : *model_flags->mutable_input_arrays()) {
- input_data_map[input.name()] = &input;
- }
-
- // Adds shapes to the input arrays if the shape is valid.
- for (const tensorflow::NodeDef& node_def : graph_def.node()) {
- if (input_data_map.find(node_def.name()) != input_data_map.end()) {
- const auto shape_it = node_def.attr().find("shape");
- if (shape_it != node_def.attr().end()) {
- tensorflow::TensorShapeProto final_shape;
- bool is_valid = ReplaceShapeBatchSize(shape_it->second.shape(),
- batch_size, &final_shape);
-
- if (is_valid) {
- auto* shape = input_data_map.at(node_def.name())->mutable_shape();
- QCHECK_EQ(shape->dims_size(), 0)
- << "The shape for the input '" << node_def.name()
- << "' was previously defined. For clarity please define inputs "
- << "via --input_arrays and input_shapes flags.\n";
- for (const auto& dim : final_shape.dim()) {
- shape->add_dims(dim.size());
- }
- }
- }
- }
- }
-
- // Checks all input arrays have a shape.
- for (auto const& input : model_flags->input_arrays()) {
- QCHECK(input.shape().dims_size() > 0)
- << "A valid input shape was not found for input '" << input.name()
- << "'. Please define via --input_arrays and --input_shapes flags.\n";
- }
-}
-
-} // namespace
-
-void ParseMetaData(const tensorflow::GraphDef& graph_def,
- const std::unordered_set<string>& inputs,
- const std::unordered_set<string>& outputs,
- const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- TocoFlags* toco_flags, ModelFlags* model_flags) {
- if (!parsed_model_flags.input_arrays.specified()) {
- const std::set<string> sorted_inputs = GetSortedNames(inputs);
- for (const auto& input_name : sorted_inputs) {
- model_flags->add_input_arrays()->set_name(input_name);
- }
- }
-
- if (!parsed_model_flags.output_arrays.specified()) {
- const std::set<string> sorted_outputs = GetSortedNames(outputs);
- for (const auto& output_name : sorted_outputs) {
- model_flags->add_output_arrays(GetArrayName(output_name));
- }
- }
-
- if (!parsed_model_flags.input_shapes.specified()) {
- int batch_size = parsed_model_flags.batch_size.value();
- ProcessInputShapes(graph_def, batch_size, model_flags);
- }
-
- if (!parsed_toco_flags.inference_type.specified()) {
- toco_flags->set_inference_type(IODataType::FLOAT);
- }
-}
-
-// TODO(nupurgarg): Add top level tests.
-void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- TocoFlags* toco_flags, ModelFlags* model_flags,
- string* graph_def_contents) {
- // Loads the MetaGraphDef within a SavedModelBundle.
- auto bundle = LoadSavedModel(parsed_toco_flags);
-
- // Converts the MetaGraphDef to frozen GraphDef.
- tensorflow::GraphDef frozen_graph_def;
- std::unordered_set<string> inputs;
- std::unordered_set<string> outputs;
- TF_CHECK_OK(tensorflow::FreezeSavedModel(*bundle, &frozen_graph_def, &inputs,
- &outputs));
-
- // Reads the frozen GraphDef into a string.
- QCHECK(frozen_graph_def.SerializeToString(graph_def_contents))
- << "Unable to generate serialized GraphDef.\n";
-
- // Process inputs and outputs and metadata within GraphDef.
- const tensorflow::GraphDef graph_def = bundle->meta_graph_def.graph_def();
- ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags,
- parsed_model_flags, toco_flags, model_flags);
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model.h b/tensorflow/contrib/lite/toco/toco_saved_model.h
deleted file mode 100644
index 7a0fabd82d..0000000000
--- a/tensorflow/contrib/lite/toco/toco_saved_model.h
+++ /dev/null
@@ -1,53 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
-#define TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
-
-#include <string>
-#include <vector>
-
-#include "tensorflow/cc/tools/freeze_saved_model.h"
-#include "tensorflow/contrib/lite/toco/args.h"
-#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
-#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
-#include "tensorflow/contrib/lite/toco/types.pb.h"
-
-namespace toco {
-
-// Parses metadata into `toco_flags` and `model_flags`.
-//
-// Stores `inputs` as input_arrays and `outputs` as output_arrays in
-// `model_flags`. Infers input_shapes from the GraphDef and stores it in
-// `model_flags` as part of the input_arrays. Assumes inference_type is FLOAT
-// and stores it in `toco_flags`.
-void ParseMetaData(const tensorflow::GraphDef& graph_def,
- const std::unordered_set<string>& inputs,
- const std::unordered_set<string>& outputs,
- const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- TocoFlags* toco_flags, ModelFlags* model_flags);
-
-// Generates a frozen graph from the SavedModel in the directory specified in
-// `toco_flags`. Reads frozen graph contents into `graph_def_contents`. Parses
-// metadata relating to the GraphDef into `toco_flags` and `model_flags`.
-void GetSavedModelContents(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- TocoFlags* toco_flags, ModelFlags* model_flags,
- string* graph_def_contents);
-
-} // namespace toco
-
-#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOCO_SAVED_MODEL_H_
diff --git a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc b/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
deleted file mode 100644
index 5e122afe65..0000000000
--- a/tensorflow/contrib/lite/toco/toco_saved_model_test.cc
+++ /dev/null
@@ -1,274 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/lite/toco/toco_saved_model.h"
-#include "absl/strings/str_join.h"
-#include "tensorflow/cc/framework/scope.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
-#include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-
-namespace toco {
-namespace {
-
-using tensorflow::ops::Add;
-using tensorflow::ops::Const;
-using tensorflow::ops::FakeQuantWithMinMaxArgs;
-using tensorflow::ops::Placeholder;
-
-class TocoSavedModelTest : public ::testing::Test {
- protected:
- // Calls functions to process cmdline arguments and calls ParseMetaData.
- // ParseMetaData parses input_arrays, output_arrays, and gets metadata from
- // SavedModel it is not defined in the cmdline arguments.
- void ProcessGraphDefMetadata(const std::unordered_set<string>& inputs,
- const std::unordered_set<string>& outputs,
- const tensorflow::GraphDef& graph_def) {
- ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags_, &toco_flags_);
- ReadModelFlagsFromCommandLineFlags(parsed_model_flags_, &model_flags_);
- ParseMetaData(graph_def, inputs, outputs, parsed_toco_flags_,
- parsed_model_flags_, &toco_flags_, &model_flags_);
- }
-
- // Gets the GraphDef from the SavedModelBundle and processes metadata.
- void ProcessSavedModelMetadata(const std::unordered_set<string>& inputs,
- const std::unordered_set<string>& outputs) {
- const tensorflow::GraphDef graph_def = bundle_.meta_graph_def.graph_def();
- ProcessGraphDefMetadata(inputs, outputs, graph_def);
- }
-
- // Returns a GraphDef representing a simple float model with a single input.
- tensorflow::GraphDef GetFloatGraphDef(const std::vector<int64>& shape) {
- tensorflow::GraphDef graph_def;
- tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
- tensorflow::Output input =
- Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
- Placeholder::Shape(tensorflow::PartialTensorShape(shape)));
- tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
- tensorflow::Output add = Add(scope.WithOpName("add"), input, zero);
-
- TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
- return graph_def;
- }
-
- // Returns a GraphDef representing a simple float model with two inputs.
- tensorflow::GraphDef GetComplexFloatGraphDef() {
- tensorflow::GraphDef graph_def;
- tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
- tensorflow::Output inputA =
- Placeholder(scope.WithOpName("inputA"), tensorflow::DT_FLOAT,
- Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
- tensorflow::Output inputB =
- Placeholder(scope.WithOpName("inputB"), tensorflow::DT_FLOAT,
- Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
- tensorflow::Output add = Add(scope.WithOpName("add"), inputB, inputA);
-
- TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
- return graph_def;
- }
-
- // Returns a GraphDef representing a simple quantized model.
- tensorflow::GraphDef GetQuantizedGraphDef() {
- tensorflow::GraphDef graph_def;
- tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
-
- tensorflow::Output input =
- Placeholder(scope.WithOpName("input"), tensorflow::DT_FLOAT,
- Placeholder::Shape(tensorflow::TensorShape({1, 3, 3, 1})));
- tensorflow::Output zero = Const(scope.WithOpName("zero"), 0.0f, {});
- tensorflow::Output fake_quant =
- FakeQuantWithMinMaxArgs(scope.WithOpName("quant"), zero);
- tensorflow::Output add = Add(scope.WithOpName("add"), input, fake_quant);
-
- TF_EXPECT_OK(scope.ToGraphDef(&graph_def));
- return graph_def;
- }
-
- // Gets the values in the input_arrays flag.
- std::vector<string> GetInputArrays() {
- std::vector<string> actual;
- for (const auto& input : model_flags_.input_arrays()) {
- actual.push_back(input.name());
- }
- return actual;
- }
-
- // Gets the values in the output_arrays flag.
- std::vector<string> GetOutputArrays() {
- std::vector<string> actual(model_flags_.output_arrays().begin(),
- model_flags_.output_arrays().end());
- return actual;
- }
-
- // Gets the shape of the given input array.
- string GetInputShape(const string& input_array) {
- for (const auto& input : model_flags_.input_arrays()) {
- if (input.name() == input_array) {
- std::vector<string> dims;
- for (int idx = 0; idx < input.shape().dims_size(); ++idx) {
- dims.push_back(std::to_string(input.shape().dims(idx)));
- }
- return absl::StrJoin(dims, ",");
- }
- }
- return "";
- }
-
- tensorflow::SavedModelBundle bundle_;
- ParsedTocoFlags parsed_toco_flags_;
- ParsedModelFlags parsed_model_flags_;
- TocoFlags toco_flags_;
- ModelFlags model_flags_;
-};
-
-// Tests if input_arrays, output_arrays, inference_type, and output_arrays are
-// added to ModelFlags if they are not specified in cmdline arguments.
-// Tests if the default batch size replaces a -1 in the first dimension.
-TEST_F(TocoSavedModelTest, NoCmdLine) {
- tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
-
- ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
- EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Tests if the order of input_arrays and output_arrays is deterministic when
-// they are taken from the SavedModel.
-TEST_F(TocoSavedModelTest, NoCmdLineMultipleArrays) {
- tensorflow::GraphDef graph_def = GetComplexFloatGraphDef();
-
- // Note: The model does not have two outputs. However, the function does not
- // need an accurate output_array list. This is only meant to test order.
- ProcessGraphDefMetadata({"inputB", "inputA"}, {"add", "invalid"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add", "invalid"}));
- EXPECT_EQ(GetInputShape("inputA"), "1,3,3,1");
- EXPECT_EQ(GetInputShape("inputB"), "1,3,3,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Tests if input_shapes is inferred when input_arrays is passed in via cmdline
-// arguments.
-TEST_F(TocoSavedModelTest, InputNameWithoutInputShape) {
- parsed_model_flags_.input_arrays.bind()("input");
- tensorflow::GraphDef graph_def = GetFloatGraphDef({2, 3, 3, 1});
-
- ProcessGraphDefMetadata({"not_used_input"}, {"add"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
- EXPECT_EQ(GetInputShape("input"), "2,3,3,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Ensures a failure occurs when input_shapes is defined without input_arrays.
-TEST_F(TocoSavedModelTest, InputShapeWithoutInputName) {
- parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
- tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
-
- EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
- "failed: input_shapes.size\\(\\) == "
- "model_flags->input_arrays_size\\(\\)");
-}
-
-// Tests if the cmdline values of input_arrays, input_shapes are used when
-// specified with an empty GraphDef.
-TEST_F(TocoSavedModelTest, InputArraysCmdLine) {
- parsed_model_flags_.input_arrays.bind()("inputA,inputB");
- parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
-
- ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"output0", "output1"}));
- EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
- EXPECT_EQ(GetInputShape("inputB"), "9,12");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Tests if the cmdline values of input_arrays, input_shapes are used when
-// specified even if values exist within the GraphDef.
-TEST_F(TocoSavedModelTest, InputArraysCmdLineWithGraphDef) {
- parsed_model_flags_.input_arrays.bind()("inputA");
- parsed_model_flags_.input_shapes.bind()("1,224,224,1");
- tensorflow::GraphDef graph_def = GetFloatGraphDef({1, 3, 3, 1});
-
- ProcessGraphDefMetadata({"inputA"}, {"add"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
- EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Tests if the cmdline values of input_arrays, input_shapes, inference_type,
-// and output_arrays are used when specified with an empty GraphDef.
-TEST_F(TocoSavedModelTest, AllParamsCmdLine) {
- parsed_model_flags_.input_arrays.bind()("inputA,inputB");
- parsed_model_flags_.output_arrays.bind()("outputA,outputB");
- parsed_model_flags_.input_shapes.bind()("1,224,224,1:9,12");
- parsed_toco_flags_.inference_type.bind()("FLOAT");
-
- ProcessSavedModelMetadata({"input0", "input1"}, {"output0", "output1"});
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"inputA", "inputB"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"outputA", "outputB"}));
- EXPECT_EQ(GetInputShape("inputA"), "1,224,224,1");
- EXPECT_EQ(GetInputShape("inputB"), "9,12");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Tests if a quantized graph gives the correct values assuming type is passed
-// in via command line.
-TEST_F(TocoSavedModelTest, QuantizedNoCmdLine) {
- parsed_toco_flags_.inference_type.bind()("QUANTIZED_UINT8");
- tensorflow::GraphDef graph_def = GetQuantizedGraphDef();
-
- ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
- EXPECT_EQ(GetInputShape("input"), "1,3,3,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::QUANTIZED_UINT8);
-}
-
-// Tests if the provided batch size replaces a -1 in the first dimension of
-// input shape.
-TEST_F(TocoSavedModelTest, MissingShapeParameterValid) {
- parsed_model_flags_.batch_size.bind()(3);
- tensorflow::GraphDef graph_def = GetFloatGraphDef({-1, 3, 3, 1});
-
- ProcessGraphDefMetadata({"input"}, {"add"}, graph_def);
- EXPECT_EQ(GetInputArrays(), std::vector<string>({"input"}));
- EXPECT_EQ(GetOutputArrays(), std::vector<string>({"add"}));
- EXPECT_EQ(GetInputShape("input"), "3,3,3,1");
- EXPECT_EQ(toco_flags_.inference_type(), IODataType::FLOAT);
-}
-
-// Ensures a failure occurs if there is a -1 in a dimension aside from the first
-// position of input shape.
-TEST_F(TocoSavedModelTest, MissingShapeParameterInvalid) {
- parsed_model_flags_.batch_size.bind()(3);
- tensorflow::GraphDef graph_def = GetFloatGraphDef({1, -1, 3, 1});
-
- EXPECT_DEATH(ProcessGraphDefMetadata({"input"}, {"add"}, graph_def),
- "A valid input shape was not found for input 'input'.");
-}
-
-} // namespace
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index b5531ca2f4..3ca36338eb 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -34,11 +34,11 @@ limitations under the License.
namespace toco {
namespace {
-// CHECK-fails if the model contains a kTensorFlowUnsupported operation.
+// CHECK-fails if the model contains a kUnsupported operation.
void CheckUnsupportedOperations(const Model& model) {
std::set<string> unsupported_ops;
for (auto& op : model.operators) {
- if (op->type == OperatorType::kTensorFlowUnsupported) {
+ if (op->type == OperatorType::kUnsupported) {
unsupported_ops.insert(
static_cast<const TensorFlowUnsupportedOperator*>(op.get())
->tensorflow_op);
@@ -56,6 +56,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ConvertSqueezeToReshape);
transformations->Add(new ConvertTrivialAddNToAdd);
transformations->Add(new ConvertTrivialStackToReshape);
+ transformations->Add(new ConvertTrivialTileToConcat);
transformations->Add(new ConvertTrivialTransposeToReshape);
transformations->Add(new ConvertReorderAxes);
transformations->Add(new ResolveReshapeAttributes);
@@ -76,7 +77,9 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowMatMul);
transformations->Add(new FuseBinaryIntoPrecedingAffine);
transformations->Add(new FuseBinaryIntoFollowingAffine);
+ transformations->Add(new FuseBroadcastIntoFollowingBinary);
transformations->Add(new MergeReshapeIntoPrecedingTranspose);
+ transformations->Add(new MoveBinaryOperatorBeforeReshape);
transformations->Add(new ReorderElementwiseUnary);
transformations->Add(new ReorderReshapeTranspose);
transformations->Add(new ResolveBatchNormalization);
@@ -94,7 +97,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowMerge);
transformations->Add(new ResolveSqueezeAttributes);
transformations->Add(new ResolveTensorFlowSwitch);
- transformations->Add(new ResolveTensorFlowTile);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
transformations->Add(new IdentifyDilatedConv);
@@ -133,6 +135,8 @@ bool SupportsPreallocatedWorkspace(FileFormat format) {
return (format == TFLITE);
}
+bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }
+
bool IsRealValued(toco::ArrayDataType type) {
// TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used
// for quantized real-number values, and no other integer type is ever used
@@ -263,7 +267,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
transformations.Add(new IdentifyLstmCell);
}
- if (output_format == TFLITE) {
+ if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
transformations.Add(new toco::SplitLstmCellInputs);
} else {
transformations.Add(new toco::MergeLstmCellInputs);
@@ -273,6 +277,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
+ if (toco_flags.quantize_weights()) {
+ // Run the quantize weights transformation after batchnorms have been
+ // folded into the weights.
+ RunGraphTransformations(model, "quantize weights transformation",
+ {new QuantizeWeights});
+ }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -331,6 +341,10 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
new RemoveFinalDequantizeOp,
ensure_safe_for_int8_kernels,
});
+ if (SupportsShuffledFCWeights(output_format)) {
+ RunGraphTransformations(model, "shuffling of FC weights",
+ {new ShuffleFCWeights});
+ }
} else {
GraphTransformationsSet dequantization_transformations{new Dequantize};
// Dequantize creates FakeQuant nodes. We may want to discard
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index fe7bed885d..4ec74e351f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace toco {
@@ -338,23 +338,23 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Div)
HANDLE_OPERATORTYPENAME_CASE(Tanh)
HANDLE_OPERATORTYPENAME_CASE(Sin)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowAll)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowAssert)
+ HANDLE_OPERATORTYPENAME_CASE(All)
+ HANDLE_OPERATORTYPENAME_CASE(Assert)
HANDLE_OPERATORTYPENAME_CASE(ExpandDims)
HANDLE_OPERATORTYPENAME_CASE(Fill)
HANDLE_OPERATORTYPENAME_CASE(FloorMod)
HANDLE_OPERATORTYPENAME_CASE(FloorDiv)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreater)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowGreaterEqual)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowIdentity)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowLess)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowLessEqual)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMatMul)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMax)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMaximum)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMerge)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMin)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowMinimum)
+ HANDLE_OPERATORTYPENAME_CASE(Greater)
+ HANDLE_OPERATORTYPENAME_CASE(GreaterEqual)
+ HANDLE_OPERATORTYPENAME_CASE(Identity)
+ HANDLE_OPERATORTYPENAME_CASE(Less)
+ HANDLE_OPERATORTYPENAME_CASE(LessEqual)
+ HANDLE_OPERATORTYPENAME_CASE(MatMul)
+ HANDLE_OPERATORTYPENAME_CASE(Max) // Reduction Max
+ HANDLE_OPERATORTYPENAME_CASE(Maximum) // Element-wise Maximum
+ HANDLE_OPERATORTYPENAME_CASE(Merge)
+ HANDLE_OPERATORTYPENAME_CASE(Min) // Reduction Min
+ HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)
@@ -362,22 +362,22 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Stack)
HANDLE_OPERATORTYPENAME_CASE(Range)
HANDLE_OPERATORTYPENAME_CASE(Rank)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowReshape)
+ HANDLE_OPERATORTYPENAME_CASE(Reshape)
HANDLE_OPERATORTYPENAME_CASE(Squeeze)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowRsqrt)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowShape)
+ HANDLE_OPERATORTYPENAME_CASE(Rsqrt)
+ HANDLE_OPERATORTYPENAME_CASE(Shape)
HANDLE_OPERATORTYPENAME_CASE(Slice)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSplit)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSqrt)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSquare)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSwitch)
+ HANDLE_OPERATORTYPENAME_CASE(Split)
+ HANDLE_OPERATORTYPENAME_CASE(Sqrt)
+ HANDLE_OPERATORTYPENAME_CASE(Square)
+ HANDLE_OPERATORTYPENAME_CASE(Switch)
HANDLE_OPERATORTYPENAME_CASE(Sub)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowSum)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowTile)
+ HANDLE_OPERATORTYPENAME_CASE(Sum)
+ HANDLE_OPERATORTYPENAME_CASE(Tile)
HANDLE_OPERATORTYPENAME_CASE(Transpose)
HANDLE_OPERATORTYPENAME_CASE(TransposeConv)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcat)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowConcatV2)
+ HANDLE_OPERATORTYPENAME_CASE(Concat)
+ HANDLE_OPERATORTYPENAME_CASE(ConcatV2)
HANDLE_OPERATORTYPENAME_CASE(Cast)
HANDLE_OPERATORTYPENAME_CASE(Floor)
HANDLE_OPERATORTYPENAME_CASE(Gather)
@@ -387,13 +387,17 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(Mean)
HANDLE_OPERATORTYPENAME_CASE(Svdf)
HANDLE_OPERATORTYPENAME_CASE(ArgMax)
+ HANDLE_OPERATORTYPENAME_CASE(ArgMin)
HANDLE_OPERATORTYPENAME_CASE(TopK_V2)
- HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
+ HANDLE_OPERATORTYPENAME_CASE(Unsupported)
HANDLE_OPERATORTYPENAME_CASE(Exp)
HANDLE_OPERATORTYPENAME_CASE(DynamicPartition)
HANDLE_OPERATORTYPENAME_CASE(DynamicStitch)
HANDLE_OPERATORTYPENAME_CASE(Select)
HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
+ HANDLE_OPERATORTYPENAME_CASE(Equal)
+ HANDLE_OPERATORTYPENAME_CASE(NotEqual)
+ HANDLE_OPERATORTYPENAME_CASE(Pow)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -401,7 +405,7 @@ const char* OperatorTypeName(OperatorType type) {
}
string HelpfulOperatorTypeName(const Operator& op) {
- if (op.type == OperatorType::kTensorFlowUnsupported) {
+ if (op.type == OperatorType::kUnsupported) {
return toco::port::StringF(
"(Unsupported TensorFlow op: %s)",
static_cast<const TensorFlowUnsupportedOperator&>(op).tensorflow_op);
@@ -411,16 +415,20 @@ string HelpfulOperatorTypeName(const Operator& op) {
bool OperatorSupportsFusedActivation(OperatorType type) {
switch (type) {
- case OperatorType::kConcatenation:
- case OperatorType::kFakeQuant:
- case OperatorType::kGather:
- case OperatorType::kSlice:
- case OperatorType::kSqueeze:
- case OperatorType::kTensorFlowReshape:
- case OperatorType::kTensorFlowSplit:
- return false;
- default:
+ case OperatorType::kAdd:
+ case OperatorType::kAveragePool:
+ case OperatorType::kBatchNormalization:
+ case OperatorType::kConv:
+ case OperatorType::kDepthwiseConv:
+ case OperatorType::kDiv:
+ case OperatorType::kFullyConnected:
+ case OperatorType::kL2Pool:
+ case OperatorType::kMaxPool:
+ case OperatorType::kMul:
+ case OperatorType::kSub:
return true;
+ default:
+ return false;
}
}
@@ -440,8 +448,12 @@ void LogSummary(int log_level, const Model& model) {
}
void LogArray(int log_level, const Model& model, const string& name) {
- const auto& array = model.GetArray(name);
VLOG(log_level) << "Array: " << name;
+ if (!model.HasArray(name)) {
+ VLOG(log_level) << " DOES NOT EXIST";
+ return;
+ }
+ const auto& array = model.GetArray(name);
VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type);
VLOG(log_level) << " Final type: "
<< ArrayDataTypeName(array.final_data_type);
@@ -583,6 +595,13 @@ void UnextendShape(Shape* shape, int new_shape_size) {
shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction);
}
+bool IsValid(const Shape& shape) {
+ for (int i = 0; i < shape.dimensions_count(); ++i) {
+ if (shape.dims(i) < 1) return false;
+ }
+ return true;
+}
+
void CheckShapeDimensions(const Shape& shape) {
for (int i = 0; i < shape.dimensions_count(); ++i) {
CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i
@@ -1247,8 +1266,13 @@ void InsertCopyOperator(Model* model, const string& source_array_name,
auto* copy_op = new TensorFlowReshapeOperator;
copy_op->inputs = {
source_array_name,
- CreateInt32Array(model, target_array_name + "_copy_shape", shape)};
+ CreateInt32Array(
+ model, AvailableArrayName(*model, target_array_name + "_copy_shape"),
+ shape)};
copy_op->outputs = {target_array_name};
+ if (target_array.has_shape()) {
+ copy_op->shape = target_array.shape().dims();
+ }
model->operators.emplace_back(copy_op);
}
@@ -1863,18 +1887,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order,
output_axes_order == AxesOrder::kHWIO) {
// 3210 <- 3210
// HWIO <- OHWI
- (*shuffle)[0] = 1;
- (*shuffle)[1] = 2;
- (*shuffle)[2] = 3;
- (*shuffle)[3] = 0;
+ *shuffle = {1, 2, 3, 0};
} else if (input_axes_order == AxesOrder::kHWIO &&
output_axes_order == AxesOrder::kOHWI) {
// 3210 <- 3210
// OHWI <- HWIO
- (*shuffle)[0] = 3;
- (*shuffle)[1] = 0;
- (*shuffle)[2] = 1;
- (*shuffle)[3] = 2;
+ *shuffle = {3, 0, 1, 2};
+ } else if (input_axes_order == AxesOrder::kOHWI &&
+ output_axes_order == AxesOrder::kHWOI) {
+ *shuffle = {1, 2, 0, 3};
} else {
LOG(FATAL) << "Bad shuffle";
}
@@ -2020,6 +2041,8 @@ int AxesCount(AxesOrder axes_order) {
return 4;
case AxesOrder::kNHWC:
return 4;
+ case AxesOrder::kHWOI:
+ return 4;
default:
LOG(FATAL) << "Bad AxesOrder";
return 0;
@@ -2188,4 +2211,51 @@ void UseArraysExtraInfo(Model* model, bool quantize_output) {
}
}
+void UndoWeightsShuffling(Model* model) {
+ for (const auto& op : model->operators) {
+ if (op->type != toco::OperatorType::kFullyConnected) {
+ continue;
+ }
+ const auto& fc_op = static_cast<toco::FullyConnectedOperator&>(*op);
+ if (fc_op.weights_format == FullyConnectedWeightsFormat::kDefault) {
+ continue;
+ }
+ const string& weights_name = fc_op.inputs[1];
+ QCHECK_EQ(CountOpsWithInput(*model, weights_name), 1);
+ auto& weights_array = model->GetArray(weights_name);
+ QCHECK(weights_array.data_type == ArrayDataType::kUint8);
+ auto& weights_data =
+ weights_array.GetMutableBuffer<toco::ArrayDataType::kUint8>().data;
+ const auto& weights_shape = weights_array.shape();
+ QCHECK_EQ(weights_shape.dimensions_count(), 2);
+ const int rows = weights_shape.dims(0);
+ const int cols = weights_shape.dims(1);
+ QCHECK_EQ(rows % 4, 0);
+ QCHECK_EQ(cols % 16, 0);
+ CHECK_EQ(rows * cols, weights_data.size());
+ // Compute the de-shuffled weights
+ std::vector<uint8> deshuffled_data(weights_data.size());
+ uint8* shuffled_data_ptr = weights_data.data();
+ for (int r = 0; r < rows; r += 4) {
+ for (int c = 0; c < cols; c += 16) {
+ for (int i = 0; i < 4; i++) {
+ uint8* deshuffled_data_ptr =
+ deshuffled_data.data() + (r + i) * cols + c;
+ for (int j = 0; j < 16; j++) {
+ uint8 shuffled_val = *shuffled_data_ptr++;
+ // Deshuffling isn't only about deshuffling the storage layout,
+ // it's also about undoing the flipping of the sign bit, which is
+ // performed on the shuffled weights.
+ uint8 deshuffled_val = shuffled_val ^ 0x80;
+ *deshuffled_data_ptr++ = deshuffled_val;
+ }
+ }
+ }
+ }
+ CHECK_EQ(shuffled_data_ptr, weights_data.data() + rows * cols);
+ // Switch this FC op to using the deshuffled weights.
+ weights_data = std::move(deshuffled_data);
+ }
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 1f596ca8e5..5dbfa54fa0 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -26,14 +26,15 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#if TOCO_SUPPORT_PORTABLE_PROTOS
-#include "third_party/protobuf/src/google/protobuf/text_format.h"
+#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
#include "tensorflow/contrib/lite/toco/toco_flags.pb.h"
-#include "tensorflow/contrib/lite/toco/toco_port.h"
#include "tensorflow/contrib/lite/toco/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
// TODO(aselle): Replace with using a container specific hash override instead.
namespace std {
@@ -100,6 +101,8 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOp(Model& model,
const char* OperatorTypeName(OperatorType type);
string HelpfulOperatorTypeName(const Operator& op);
+// Whether the operator can be fused with an activation function. Note that this
+// will return false by default for new operators; fusing support is opt-in.
bool OperatorSupportsFusedActivation(OperatorType type);
void DumpGraphvizVideoFrame(const Model& model);
@@ -112,7 +115,9 @@ void ExtendShape(Shape* shape, int new_shape_size);
// TODO(b/36075966): Clean up when dims superseded by array shape.
void UnextendShape(Shape* shape, int new_shape_size);
-// Checks (using CHECK) that all dimensions of 'shape' are at least 1.
+// Checks that all dimensions of 'shape' are at least 1.
+bool IsValid(const Shape& shape);
+// Same as above, but reports error using CHECK.
void CheckShapeDimensions(const Shape& shape);
// Given two shapes with potentially different dimensionality and dimension
@@ -315,7 +320,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output);
// doesn't have enough range to represent the sum of elements, an error is
// returned.
template <typename T, typename U>
-port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
+tensorflow::Status NumElements(const std::vector<T>& shape, U* num_elements) {
static_assert(
std::numeric_limits<T>::max() <= std::numeric_limits<uint64_t>::max(),
"vector type exceed capabilities of NumElements");
@@ -326,19 +331,24 @@ port::Status NumElements(const std::vector<T>& shape, U* num_elements) {
// TensorFlow's shapes sometimes include -1 to represent an "unknown"
// size but TOCO isn't able to create arrays of unknown sizes and will
// crash in RequiredBufferSizeForShape().
- return port::Status(false,
- "Tensor shape should not include negative values");
+ return tensorflow::errors::InvalidArgument(
+ "Tensor shape should not include negative values");
}
if (static_cast<uint64_t>(dim) >
std::numeric_limits<U>::max() / *num_elements) {
*num_elements = 0;
- return port::Status(false, "Tensor shape is too large");
+ return tensorflow::errors::InvalidArgument("Tensor shape is too large");
}
*num_elements *= dim;
}
- return port::Status::OK();
+ return tensorflow::Status::OK();
}
+// A model file may have shuffled FC weights.
+// When that happens, we want to de-shuffle them immediately on import,
+// so that the rest of toco doesn't need to know about shuffled weights.
+void UndoWeightsShuffling(Model* model);
+
} // namespace toco
#endif // TENSORFLOW_CONTRIB_LITE_TOCO_TOOLING_UTIL_H_
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index 87fd30db2c..8609e5bedd 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/lib/core/status.h"
namespace toco {
@@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large";
TEST(NumElementsTest, Int) {
int count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int>{1024, 1024, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) {
TEST(NumElementsTest, Int32) {
int32_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int32_t>{1024, 1024, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) {
TEST(NumElementsTest, Int64) {
int64_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<int64_t>{16777216, 16777216, 32767}, &count);
EXPECT_TRUE(status.ok());
@@ -144,7 +145,7 @@ TEST(NumElementsTest, Int64) {
TEST(NumElementsTest, UnsignedInt32) {
uint32_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status = NumElements(std::vector<uint32_t>{1024, 2048, 2047}, &count);
EXPECT_TRUE(status.ok());
@@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) {
TEST(NumElementsTest, UnsignedInt64) {
uint64_t count;
- port::Status status = port::Status::OK();
+ tensorflow::Status status = tensorflow::Status::OK();
status =
NumElements(std::vector<uint64_t>{16777216, 16777216, 65535}, &count);
@@ -174,4 +175,10 @@ TEST(NumElementsTest, UnsignedInt64) {
EXPECT_EQ(status.error_message(), kLargeTensorMessage);
}
+TEST(FusedActivationTest, DefaultsToUnfused) {
+ EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd));
+ EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone));
+ EXPECT_FALSE(OperatorSupportsFusedActivation(static_cast<OperatorType>(255)));
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD
index 7fb7517600..d070018e83 100644
--- a/tensorflow/contrib/lite/tools/BUILD
+++ b/tensorflow/contrib/lite/tools/BUILD
@@ -14,6 +14,7 @@ py_binary(
srcs = ["visualize.py"],
data = [
"//tensorflow/contrib/lite/schema:schema.fbs",
+ "//tensorflow/python:platform",
"@flatbuffers//:flatc",
],
srcs_version = "PY2AND3",
@@ -30,87 +31,6 @@ tf_cc_binary(
],
)
-tf_cc_binary(
- name = "benchmark_model",
- srcs = [
- "benchmark_main.cc",
- "logging.h",
- ],
- copts = common_copts,
- linkopts = select({
- "//tensorflow:android": [
- "-pie",
- "-landroid",
- "-lm",
- "-z defs",
- "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
- ],
- "//conditions:default": [],
- }),
- deps = [
- ":benchmark_tflite_model_lib",
- "//tensorflow/core:stats_calculator_portable",
- ],
-)
-
-cc_library(
- name = "command_line_flags",
- srcs = ["command_line_flags.cc"],
- hdrs = ["command_line_flags.h"],
- copts = common_copts,
- visibility = ["//visibility:private"],
-)
-
-cc_test(
- name = "command_line_flags_test",
- srcs = ["command_line_flags_test.cc"],
- copts = common_copts,
- visibility = ["//visibility:private"],
- deps = [
- ":command_line_flags",
- "//tensorflow/contrib/lite/testing:util",
- "@com_google_googletest//:gtest",
- ],
-)
-
-cc_library(
- name = "benchmark_tflite_model_lib",
- srcs = [
- "benchmark_tflite_model.cc",
- "logging.h",
- ],
- hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts,
- deps = [
- ":benchmark_model_lib",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/profiling:profile_summarizer",
- "//tensorflow/contrib/lite/profiling:profiler",
- ],
-)
-
-cc_library(
- name = "benchmark_model_lib",
- srcs = [
- "benchmark_model.cc",
- "logging.h",
- ],
- hdrs = ["benchmark_model.h"],
- copts = common_copts,
- deps = [
- ":command_line_flags",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/profiling:profile_summarizer",
- "//tensorflow/contrib/lite/profiling:profiler",
- "//tensorflow/contrib/lite/profiling:time",
- "//tensorflow/core:stats_calculator_portable",
- ],
-)
-
cc_library(
name = "gen_op_registration",
srcs = ["gen_op_registration.cc"],
@@ -134,6 +54,7 @@ cc_test(
],
tags = [
"tflite_not_portable_android",
+ "tflite_not_portable_ios",
],
deps = [
":gen_op_registration",
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
new file mode 100644
index 0000000000..183a545295
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -0,0 +1,100 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+common_copts = ["-Wall"] + tflite_copts()
+
+cc_binary(
+ name = "benchmark_model",
+ srcs = [
+ "benchmark_main.cc",
+ "logging.h",
+ ],
+ copts = common_copts,
+ linkopts = tflite_linkopts() + select({
+ "//tensorflow:android": [
+ "-pie", # Android 5.0 and later supports only PIE
+ "-lm", # some builtin ops, e.g., tanh, need -lm
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ ":benchmark_tflite_model_lib",
+ ],
+)
+
+cc_library(
+ name = "command_line_flags",
+ srcs = ["command_line_flags.cc"],
+ hdrs = ["command_line_flags.h"],
+ copts = common_copts,
+)
+
+cc_test(
+ name = "command_line_flags_test",
+ srcs = ["command_line_flags_test.cc"],
+ copts = common_copts,
+ visibility = ["//visibility:private"],
+ deps = [
+ ":command_line_flags",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
+ name = "benchmark_tflite_model_lib",
+ srcs = [
+ "benchmark_tflite_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_tflite_model.h"],
+ copts = common_copts,
+ deps = [
+ ":benchmark_model_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
+ "//tensorflow/contrib/lite/profiling:profiler",
+ ],
+)
+
+cc_library(
+ name = "benchmark_params",
+ srcs = [
+ "benchmark_params.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_params.h"],
+ copts = common_copts,
+)
+
+cc_library(
+ name = "benchmark_model_lib",
+ srcs = [
+ "benchmark_model.cc",
+ "logging.h",
+ ],
+ hdrs = ["benchmark_model.h"],
+ copts = common_copts,
+ deps = [
+ ":benchmark_params",
+ ":command_line_flags",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/profiling:profile_summarizer",
+ "//tensorflow/contrib/lite/profiling:profiler",
+ "//tensorflow/contrib/lite/profiling:time",
+ "//tensorflow/core:stats_calculator_portable",
+ ],
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
new file mode 100644
index 0000000000..93769305bd
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -0,0 +1,209 @@
+# TFLite Model Benchmark Tool
+
+## Description
+
+A simple C++ binary to benchmark a TFLite model and its individual operators,
+both on desktop machines and on Android. The binary takes a TFLite model,
+generates random inputs and then repeatedly runs the model for specified number
+of runs. Aggregrate latency statistics are reported after running the benchmark.
+
+The instructions below are for running the binary on Desktop and Android,
+for iOS please use the
+[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+
+## Parameters
+
+The binary takes the following required parameters:
+
+* `graph`: `string` \
+ The path to the TFLite model file.
+* `input_layer`: `string` \
+ The name of the input layer, this is typically the first layer of the model.
+* `input_layer_shape`: `string` \
+ The shape of the input layer. This is a comma separated string of the shape
+ of tensor of input layer.
+
+and the following optional parameters:
+
+* `num_threads`: `int` (default=1) \
+ The number of threads to use for running TFLite interpreter.
+* `warmup_runs`: `int` (default=1) \
+ The number of warmup runs to do before starting the benchmark.
+* `run_delay`: `float` (default=-1.0) \
+ The delay in seconds between subsequent benchmark runs. Non-positive values
+ mean use no delay.
+* `use_nnapi`: `bool` (default=false) \
+ Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+ This API is available on recent Android devices.
+
+## To build/install/run
+
+### On Android:
+
+(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK.
+
+(1) Build for your specific platform, e.g.:
+
+```
+bazel build -c opt \
+ --config=android_arm \
+ --cxxopt='--std=c++11' \
+ tensorflow/contrib/lite/tools/benchmark:benchmark_model
+```
+
+(2) Connect your phone. Push the binary to your phone with adb push
+ (make the directory if required):
+
+```
+adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp
+```
+
+(3) Make the binary executable.
+
+```
+adb shell chmod +x /data/local/tmp/benchmark_model
+```
+
+(4) Push the compute graph that you need to test. For example:
+
+```
+adb push mobilenet_quant_v1_224.tflite /data/local/tmp
+```
+
+(5) Run the benchmark. For example:
+
+```
+adb shell /data/local/tmp/benchmark_model \
+ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
+ --input_layer="input" \
+ --input_layer_shape="1,224,224,3" \
+ --num_threads=4
+```
+
+### On desktop:
+(1) build the binary
+
+```
+bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model
+```
+
+(2) Run on your compute graph, similar to the Android case but without the need of adb shell.
+For example:
+
+```
+bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
+ --graph=mobilenet_quant_v1_224.tflite \
+ --input_layer="Placeholder" \
+ --input_layer_shape="1,224,224,3" \
+ --num_threads=4
+```
+
+The MobileNet graph used as an example here may be downloaded from
+https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+
+
+## Reducing variance between runs on Android.
+
+Most modern Android phones use [ARM big.LITTLE](https://en.wikipedia.org/wiki/ARM_big.LITTLE)
+architecture where some cores are more power hungry but faster than other cores.
+When running benchmarks on these phones there can be significant variance
+between different runs of the benchmark. One way to reduce variance between runs
+is to set the [CPU affinity](https://en.wikipedia.org/wiki/Processor_affinity)
+before running the benchmark. On Android this can be done using the `taskset`
+command.
+E.g. for running the benchmark on big cores on Pixel 2 with a single thread one
+can use the following command:
+
+```
+adb shell tasket f0 /data/local/tmp/benchmark_model \
+ --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
+ --input_layer="input" \
+ --input_layer_shape="1,224,224,3" \
+ --num_threads=1
+```
+
+where `f0` is the affinity mask for big cores on Pixel 2.
+Note: The affinity mask varies with the device.
+
+## Profiling model operators
+The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this,
+compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED**
+to compile benchmark with profiling support.
+For example, to compile with profiling support on Android, add this flag to the previous command:
+
+```
+bazel build -c opt \
+ --config=android_arm \
+ --cxxopt='--std=c++11' \
+ --copt=-DTFLITE_PROFILING_ENABLED \
+ tensorflow/contrib/lite/tools/benchmark:benchmark_model
+```
+This compiles TFLite with profiling enabled, now you can run the benchmark binary like before. The binary will produce detailed statistics for each operation similar to those shown below:
+
+```
+
+============================== Run Order ==============================
+ [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
+ CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
+ DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6]
+ CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6]
+ CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6]
+ CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6]
+ CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6]
+ CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6]
+ CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6]
+ CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6]
+ CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6]
+ CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6]
+ CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6]
+ CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6]
+ CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6]
+ DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6]
+ CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
+ AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool]
+ CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd]
+ RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape]
+ SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax]
+
+============================== Top by Computation Time ==============================
+ [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name]
+ CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6]
+ CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6]
+ CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6]
+ CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6]
+ CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6]
+ CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6]
+ CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6]
+ CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6]
+ CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6]
+ CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6]
+
+Number of nodes executed: 31
+============================== Summary by node type ==============================
+ [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called]
+ CONV_2D 15 1.406 89.270% 89.270% 0.000 0
+ DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0
+ SOFTMAX 1 0.000 0.000% 100.000% 0.000 0
+ RESHAPE 1 0.000 0.000% 100.000% 0.000 0
+ AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0
+
+Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929
+Memory (bytes): count=0
+31 nodes observed
+
+
+Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
+```
+
+
diff --git a/tensorflow/contrib/lite/tools/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc
index 1325385e32..372d31e838 100644
--- a/tensorflow/contrib/lite/tools/benchmark_main.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
-#include "tensorflow/contrib/lite/tools/logging.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
namespace tflite {
namespace benchmark {
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
index 550994c662..19b9a9c7ba 100644
--- a/tensorflow/contrib/lite/tools/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/tools/benchmark_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
#include <time.h>
@@ -21,7 +21,7 @@ limitations under the License.
#include <sstream>
#include "tensorflow/contrib/lite/profiling/time.h"
-#include "tensorflow/contrib/lite/tools/logging.h"
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
namespace {
void SleepForSeconds(double sleep_seconds) {
@@ -48,6 +48,19 @@ namespace tflite {
namespace benchmark {
using tensorflow::Stat;
+BenchmarkParams BenchmarkModel::DefaultParams() {
+ BenchmarkParams params;
+ params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
+ params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
+ params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
+ params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
+ params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
+ params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
+ return params;
+}
+
+BenchmarkModel::BenchmarkModel() : params_(DefaultParams()) {}
+
void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
auto inference_us = results.inference_time_us();
auto init_us = results.startup_latency_us();
@@ -60,30 +73,38 @@ void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) {
std::vector<Flag> BenchmarkModel::GetFlags() {
return {
- Flag("num_runs", &params_.num_runs, "number of runs"),
- Flag("run_delay", &params_.run_delay, "delay between runs in seconds"),
- Flag("num_threads", &params_.num_threads, "number of threads"),
- Flag("benchmark_name", &params_.benchmark_name, "benchmark name"),
- Flag("output_prefix", &params_.output_prefix, "benchmark output prefix"),
- Flag("warmup_runs", &params_.warmup_runs,
- "how many runs to initialize model"),
+ CreateFlag<int32_t>("num_runs", &params_, "number of runs"),
+ CreateFlag<float>("run_delay", &params_, "delay between runs in seconds"),
+ CreateFlag<int32_t>("num_threads", &params_, "number of threads"),
+ CreateFlag<std::string>("benchmark_name", &params_, "benchmark name"),
+ CreateFlag<std::string>("output_prefix", &params_,
+ "benchmark output prefix"),
+ CreateFlag<int32_t>("warmup_runs", &params_,
+ "how many runs to initialize model"),
};
}
void BenchmarkModel::LogFlags() {
- TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]";
- TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay
+ TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]";
+ TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
+ << params_.Get<float>("run_delay") << "]";
+ TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
+ << "]";
+ TFLITE_LOG(INFO) << "Benchmark name: ["
+ << params_.Get<std::string>("benchmark_name") << "]";
+ TFLITE_LOG(INFO) << "Output prefix: ["
+ << params_.Get<std::string>("output_prefix") << "]";
+ TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs")
<< "]";
- TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]";
- TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]";
- TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]";
- TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]";
}
+void BenchmarkModel::PrepareInputsAndOutputs() {}
+
Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
Stat<int64_t> run_stats;
TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
for (int run = 0; run < num_times; run++) {
+ PrepareInputsAndOutputs();
listeners_.OnSingleRunStart(run_type);
int64_t start_us = profiling::time::NowMicros();
RunImpl();
@@ -91,7 +112,7 @@ Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
listeners_.OnSingleRunEnd();
run_stats.UpdateStat(end_us - start_us);
- SleepForSeconds(params_.run_delay);
+ SleepForSeconds(params_.Get<float>("run_delay"));
}
std::stringstream stream;
@@ -117,8 +138,10 @@ void BenchmarkModel::Run(int argc, char **argv) {
<< "ms";
uint64_t input_bytes = ComputeInputBytes();
- Stat<int64_t> warmup_time_us = Run(params_.warmup_runs, WARMUP);
- Stat<int64_t> inference_time_us = Run(params_.num_runs, REGULAR);
+ Stat<int64_t> warmup_time_us =
+ Run(params_.Get<int32_t>("warmup_runs"), WARMUP);
+ Stat<int64_t> inference_time_us =
+ Run(params_.Get<int32_t>("num_runs"), REGULAR);
listeners_.OnBenchmarkEnd(
{startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
}
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
index ef8d6a7d1e..3c7063b2d4 100644
--- a/tensorflow/contrib/lite/tools/benchmark_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h
@@ -23,7 +23,8 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/contrib/lite/tools//command_line_flags.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h"
+#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
#include "tensorflow/core/util/stats_calculator.h"
namespace tflite {
@@ -63,17 +64,6 @@ class BenchmarkResults {
tensorflow::Stat<int64_t> inference_time_us_;
};
-struct BenchmarkParams {
- BenchmarkParams()
- : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {}
- int num_runs;
- int warmup_runs;
- float run_delay;
- int num_threads;
- std::string benchmark_name;
- std::string output_prefix;
-};
-
class BenchmarkListener {
public:
virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
@@ -130,12 +120,22 @@ class BenchmarkLoggingListener : public BenchmarkListener {
void OnBenchmarkEnd(const BenchmarkResults& results) override;
};
+template <typename T>
+Flag CreateFlag(const char* name, BenchmarkParams* params,
+ const std::string& usage) {
+ return Flag(name, [params, name](const T& val) { params->Set<T>(name, val); },
+ params->Get<T>(name), usage);
+}
+
// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
public:
+ static BenchmarkParams DefaultParams();
+ BenchmarkModel();
+ BenchmarkModel(BenchmarkParams params) : params_(std::move(params)) {}
virtual ~BenchmarkModel() {}
bool ParseFlags(int argc, char** argv);
virtual void Init() = 0;
@@ -150,6 +150,7 @@ class BenchmarkModel {
virtual std::vector<Flag> GetFlags();
virtual uint64_t ComputeInputBytes() = 0;
virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
+ virtual void PrepareInputsAndOutputs();
virtual void RunImpl() = 0;
BenchmarkParams params_;
BenchmarkListeners listeners_;
@@ -158,4 +159,4 @@ class BenchmarkModel {
} // namespace benchmark
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc
new file mode 100644
index 0000000000..1dcf580a9d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_params.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+void BenchmarkParam::AssertHasSameType(BenchmarkParam::ParamType a,
+ BenchmarkParam::ParamType b) {
+ TFLITE_BENCHMARK_CHECK(a == b) << "Type mismatch while accessing parameter.";
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<int32_t>() {
+ return BenchmarkParam::ParamType::TYPE_INT32;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<bool>() {
+ return BenchmarkParam::ParamType::TYPE_BOOL;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<float>() {
+ return BenchmarkParam::ParamType::TYPE_FLOAT;
+}
+
+template <>
+BenchmarkParam::ParamType BenchmarkParam::GetValueType<std::string>() {
+ return BenchmarkParam::ParamType::TYPE_STRING;
+}
+
+void BenchmarkParams::AssertParamExists(const std::string& name) const {
+ TFLITE_BENCHMARK_CHECK(HasParam(name)) << name << " was not found.";
+}
+
+} // namespace benchmark
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h
new file mode 100644
index 0000000000..33448dd162
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_params.h
@@ -0,0 +1,101 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+namespace tflite {
+namespace benchmark {
+
+template <typename T>
+class TypedBenchmarkParam;
+
+class BenchmarkParam {
+ protected:
+ enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };
+
+ public:
+ template <typename T>
+ static std::unique_ptr<BenchmarkParam> Create(const T& default_value) {
+ return std::unique_ptr<BenchmarkParam>(
+ new TypedBenchmarkParam<T>(default_value));
+ }
+
+ template <typename T>
+ TypedBenchmarkParam<T>* AsTyped() {
+ AssertHasSameType(GetValueType<T>(), type_);
+ return static_cast<TypedBenchmarkParam<T>*>(this);
+ }
+ virtual ~BenchmarkParam() {}
+ BenchmarkParam(ParamType type) : type_(type) {}
+
+ private:
+ static void AssertHasSameType(ParamType a, ParamType b);
+ template <typename T>
+ static ParamType GetValueType();
+
+ const ParamType type_;
+};
+
+template <typename T>
+class TypedBenchmarkParam : public BenchmarkParam {
+ public:
+ TypedBenchmarkParam(const T& value)
+ : BenchmarkParam(GetValueType<T>()), value_(value) {}
+ void Set(const T& value) { value_ = value; }
+
+ T Get() { return value_; }
+
+ private:
+ T value_;
+};
+
+class BenchmarkParams {
+ public:
+ void AddParam(const std::string& name,
+ std::unique_ptr<BenchmarkParam> value) {
+ params_[name] = std::move(value);
+ }
+
+ bool HasParam(const std::string& name) const {
+ return params_.find(name) != params_.end();
+ }
+
+ template <typename T>
+ void Set(const std::string& name, const T& value) {
+ AssertParamExists(name);
+ params_.at(name)->AsTyped<T>()->Set(value);
+ }
+
+ template <typename T>
+ T Get(const std::string& name) const {
+ AssertParamExists(name);
+ return params_.at(name)->AsTyped<T>()->Get();
+ }
+
+ private:
+ void AssertParamExists(const std::string& name) const;
+ std::unordered_map<std::string, std::unique_ptr<BenchmarkParam>> params_;
+};
+
+} // namespace benchmark
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_PARAMS_H_
diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index be8f46f599..73affc26b0 100644
--- a/tensorflow/contrib/lite/tools/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/tools/benchmark_tflite_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
#include <cstdarg>
#include <cstdlib>
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
#include "tensorflow/contrib/lite/string_util.h"
-#include "tensorflow/contrib/lite/tools/logging.h"
+#include "tensorflow/contrib/lite/tools/benchmark/logging.h"
#ifdef TFLITE_CUSTOM_OPS_HEADER
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
@@ -123,29 +123,11 @@ void FillRandomString(tflite::DynamicBuffer* buffer,
}
}
-TfLiteType TfLiteTypeFromString(const string& input_layer_type) {
- if (input_layer_type == "string")
- return kTfLiteString;
- else if (input_layer_type == "float")
- return kTfLiteFloat32;
- else if (input_layer_type == "uint8")
- return kTfLiteUInt8;
- else if (input_layer_type == "int32")
- return kTfLiteInt32;
- else if (input_layer_type == "int64")
- return kTfLiteInt64;
- else
- return kTfLiteNoType;
-}
-
bool PopulateInputLayerInfo(
const string& names_string, const string& shapes_string,
- const string& types_string, const string& values_string,
std::vector<BenchmarkTfLiteModel::InputLayerInfo>* info) {
std::vector<std::string> names = Split(names_string, ',');
std::vector<std::string> shapes = Split(shapes_string, ':');
- std::vector<std::string> types = Split(types_string, ',');
- std::vector<std::string> values = Split(values_string, ':');
if (names.size() != shapes.size()) {
TFLITE_LOG(ERROR) << "The number of items in"
@@ -158,17 +140,6 @@ bool PopulateInputLayerInfo(
<< " --input_layer_shape=1,224,224,4:1,20";
return false;
}
- if (names.size() != types.size()) {
- TFLITE_LOG(ERROR) << "The number of items in"
- << " --input_layer_type (" << types_string << ", with "
- << types.size() << " items)"
- << " must match the number of items in"
- << " --input_layer (" << names_string << ", with "
- << names.size() << " items)."
- << " For example --input_layer=input1,input2"
- << " --input_layer_type=float,int";
- return false;
- }
for (int i = 0; i < names.size(); ++i) {
info->push_back(BenchmarkTfLiteModel::InputLayerInfo());
@@ -176,10 +147,6 @@ bool PopulateInputLayerInfo(
input.name = names[i];
- input.data_type = TfLiteTypeFromString(types[i]);
- TFLITE_BENCHMARK_CHECK(input.data_type != kTfLiteNoType)
- << types[i] << " was an invalid type";
-
TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape))
<< "Incorrect size string specified: " << shapes[i];
for (int dim : input.shape) {
@@ -190,30 +157,42 @@ bool PopulateInputLayerInfo(
return false;
}
}
-
- if (i < values.size()) {
- TFLITE_BENCHMARK_CHECK(
- SplitAndParse(values[i], ',', &input.initialization_values))
- << "Incorrect initialization values string specified: " << values[i];
- }
}
return true;
}
+BenchmarkParams GetDefaultParams() {
+ BenchmarkParams default_params = BenchmarkModel::DefaultParams();
+ default_params.AddParam("graph", BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("input_layer",
+ BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("input_layer_shape",
+ BenchmarkParam::Create<std::string>(""));
+ default_params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
+ return default_params;
+}
+
} // namespace
+BenchmarkTfLiteModel::BenchmarkTfLiteModel()
+ : BenchmarkModel(GetDefaultParams()) {
+ AddListener(&profiling_listener_);
+}
+
+BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
+ : BenchmarkModel(std::move(params)) {
+ AddListener(&profiling_listener_);
+}
+
std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
std::vector<Flag> flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags();
std::vector<Flag> specific_flags = {
- Flag("graph", &graph, "graph file name"),
- Flag("input_layer", &input_layer_string, "input layer names"),
- Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"),
- Flag("input_layer_type", &input_layer_type_string, "input layer type"),
- Flag("input_layer_values", &input_layer_values_string,
- "values to initialize the inputs with"),
- Flag("output_layer", &output_layer_string, "output layer name"),
- Flag("use_nnapi", &use_nnapi, "use nnapi api")};
+ CreateFlag<std::string>("graph", &params_, "graph file name"),
+ CreateFlag<std::string>("input_layer", &params_, "input layer names"),
+ CreateFlag<std::string>("input_layer_shape", &params_,
+ "input layer shape"),
+ CreateFlag<bool>("use_nnapi", &params_, "use nnapi api")};
flags.insert(flags.end(), specific_flags.begin(), specific_flags.end());
return flags;
@@ -221,23 +200,23 @@ std::vector<Flag> BenchmarkTfLiteModel::GetFlags() {
void BenchmarkTfLiteModel::LogFlags() {
BenchmarkModel::LogFlags();
- TFLITE_LOG(INFO) << "Graph: [" << graph << "]";
- TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]";
- TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]";
- TFLITE_LOG(INFO) << "Input types: [" << input_layer_type_string << "]";
- TFLITE_LOG(INFO) << "Output layers: [" << output_layer_string << "]";
- TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]";
+ TFLITE_LOG(INFO) << "Graph: [" << params_.Get<std::string>("graph") << "]";
+ TFLITE_LOG(INFO) << "Input layers: ["
+ << params_.Get<std::string>("input_layer") << "]";
+ TFLITE_LOG(INFO) << "Input shapes: ["
+ << params_.Get<std::string>("input_layer_shape") << "]";
+ TFLITE_LOG(INFO) << "Use nnapi : [" << params_.Get<bool>("use_nnapi") << "]";
}
bool BenchmarkTfLiteModel::ValidateFlags() {
- if (graph.empty()) {
+ if (params_.Get<std::string>("graph").empty()) {
TFLITE_LOG(ERROR)
<< "Please specify the name of your TF Lite input file with --graph";
return false;
}
- return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string,
- input_layer_type_string,
- input_layer_values_string, &inputs);
+ return PopulateInputLayerInfo(params_.Get<std::string>("input_layer"),
+ params_.Get<std::string>("input_layer_shape"),
+ &inputs);
}
uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
@@ -251,6 +230,7 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
}
void BenchmarkTfLiteModel::Init() {
+ std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
if (!model) {
TFLITE_LOG(FATAL) << "Failed to mmap model " << graph;
@@ -272,10 +252,14 @@ void BenchmarkTfLiteModel::Init() {
}
profiling_listener_.SetInterpreter(interpreter.get());
- if (params_.num_threads != -1) {
- interpreter->SetNumThreads(params_.num_threads);
+ const int32_t num_threads = params_.Get<int32_t>("num_threads");
+
+ if (num_threads != -1) {
+ interpreter->SetNumThreads(num_threads);
}
+ bool use_nnapi = params_.Get<bool>("use_nnapi");
+
interpreter->UseNNAPI(use_nnapi);
auto interpreter_inputs = interpreter->inputs();
@@ -293,8 +277,6 @@ void BenchmarkTfLiteModel::Init() {
TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name)
<< "Tensor # " << i << " is named " << t->name << " but flags call it "
<< input.name;
- TFLITE_BENCHMARK_CHECK_EQ(t->type, input.data_type)
- << "Could not match the type of input tensor " << t->name;
}
// Resize all non-string tensors.
diff --git a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index e6d03d5211..50cc3f24b3 100644
--- a/tensorflow/contrib/lite/tools/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
-#include "tensorflow/contrib/lite/tools/benchmark_model.h"
+#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
namespace tflite {
namespace benchmark {
@@ -50,9 +50,8 @@ class ProfilingListener : public BenchmarkListener {
// Benchmarks a TFLite model by running tflite interpreter.
class BenchmarkTfLiteModel : public BenchmarkModel {
public:
- BenchmarkTfLiteModel() : use_nnapi(false) {
- AddListener(&profiling_listener_);
- }
+ BenchmarkTfLiteModel();
+ BenchmarkTfLiteModel(BenchmarkParams params);
std::vector<Flag> GetFlags() override;
void LogFlags() override;
@@ -64,27 +63,17 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
struct InputLayerInfo {
std::string name;
- TfLiteType data_type;
std::vector<int> shape;
- // Note that initialization_values is currently unused.
- std::vector<float> initialization_values;
};
private:
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
- std::string graph;
- std::string input_layer_string;
- std::string input_layer_type_string;
- std::string input_layer_shape_string;
- std::string input_layer_values_string;
- std::string output_layer_string;
std::vector<InputLayerInfo> inputs;
- bool use_nnapi;
ProfilingListener profiling_listener_;
};
} // namespace benchmark
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_
diff --git a/tensorflow/contrib/lite/tools/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc
index ba72f40689..ff818b9dcb 100644
--- a/tensorflow/contrib/lite/tools/command_line_flags.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc
@@ -10,15 +10,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/tools/command_line_flags.h"
+#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
+#include <cstring>
#include <sstream>
#include <string>
+#include <utility>
#include <vector>
namespace tflite {
namespace {
+template <typename T>
+std::string ToString(T val) {
+ std::ostringstream stream;
+ stream << val;
+ return stream.str();
+}
+
bool ParseFlag(const std::string& arg, const std::string& flag,
const std::function<bool(const std::string&)>& parse_func,
bool* value_parsing_ok) {
@@ -27,7 +36,7 @@ bool ParseFlag(const std::string& arg, const std::string& flag,
if (arg.find(flag_prefix) != 0) {
return false;
}
- bool has_value = (arg.size() >= flag_prefix.size() + 1);
+ bool has_value = arg.size() >= flag_prefix.size();
*value_parsing_ok = has_value;
if (has_value) {
*value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
@@ -35,80 +44,80 @@ bool ParseFlag(const std::string& arg, const std::string& flag,
return true;
}
-bool ParseInt32Flag(const std::string& flag_value, int32_t* value) {
- char extra;
- return sscanf(flag_value.data(), "%d%c", value, &extra) == 1;
-}
-
-bool ParseInt64Flag(const std::string& flag_value, int64_t* value) {
- char extra;
- return sscanf(flag_value.data(), "%ld%c", value, &extra) == 1;
-}
-
-bool ParseBoolFlag(const std::string& flag_value, bool* value) {
- if (flag_value != "true" && flag_value != "false") {
+template <typename T>
+bool ParseFlag(const std::string& flag_value,
+ const std::function<void(const T&)>& hook) {
+ std::istringstream stream(flag_value);
+ T read_value;
+ stream >> read_value;
+ if (!stream.eof() && !stream.good()) {
return false;
}
-
- *value = (flag_value == "true");
+ hook(read_value);
return true;
}
-bool ParseFloatFlag(const std::string& flag_value, float* value) {
- char extra;
- return sscanf(flag_value.data(), "%f%c", value, &extra) == 1;
-}
+bool ParseBoolFlag(const std::string& flag_value,
+ const std::function<void(const bool&)>& hook) {
+ if (flag_value != "true" && flag_value != "false") {
+ return false;
+ }
-bool ParseStringFlag(const std::string& flag_value, std::string* value) {
- *value = flag_value;
+ hook(flag_value == "true");
return true;
}
-
} // namespace
-Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
+ int32_t default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_INT32),
- value_hook_([dst](const std::string& flag_value) {
- return ParseInt32Flag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<int32_t>(flag_value, hook);
}),
- default_for_display_(std::to_string(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
+ int64_t default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_INT64),
- value_hook_([dst](const std::string& flag_value) {
- return ParseInt64Flag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<int64_t>(flag_value, hook);
}),
- default_for_display_(std::to_string(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, float* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
+ float default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_FLOAT),
- value_hook_([dst](const std::string& flag_value) {
- return ParseFloatFlag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseFlag<float>(flag_value, hook);
}),
- default_for_display_(std::to_string(*dst)),
+ default_for_display_(ToString(default_value)),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, bool* dst, const std::string& usage_text)
+Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
+ bool default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_BOOL),
- value_hook_([dst](const std::string& flag_value) {
- return ParseBoolFlag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ return ParseBoolFlag(flag_value, hook);
}),
- default_for_display_((*dst) ? "true" : "false"),
+ default_for_display_(default_value ? "true" : "false"),
usage_text_(usage_text) {}
-Flag::Flag(const char* name, std::string* dst, const std::string& usage_text)
+Flag::Flag(const char* name,
+ const std::function<void(const std::string&)>& hook,
+ const std::string& default_value, const std::string& usage_text)
: name_(name),
type_(TYPE_STRING),
- value_hook_([dst](const std::string& flag_value) {
- return ParseStringFlag(flag_value, dst);
+ value_hook_([hook](const std::string& flag_value) {
+ hook(flag_value);
+ return true;
}),
- default_for_display_(*dst),
+ default_for_display_(default_value),
usage_text_(usage_text) {}
bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
@@ -166,7 +175,7 @@ std::string Flag::GetTypeName() const {
}
argv[dst++] = nullptr;
*argc = unknown_flags.size() + 1;
- return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
+ return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0);
}
/*static*/ std::string Flags::Usage(const std::string& cmdline,
diff --git a/tensorflow/contrib/lite/tools/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
index 0605d3c9d4..2e514ae3ea 100644
--- a/tensorflow/contrib/lite/tools/command_line_flags.h
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h
@@ -33,10 +33,11 @@ namespace tflite {
// int some_int = 10;
// bool some_switch = false;
// std::string some_name = "something";
+//
// std::vector<tensorFlow::Flag> flag_list = {
-// Flag("some_int", &some_int, "an integer that affects X"),
-// Flag("some_switch", &some_switch, "a bool that affects Y"),
-// Flag("some_name", &some_name, "a std::string that affects Z")
+// Flag::CreateFlag("some_int", &some_int, "an integer that affects X"),
+// Flag::CreateFlag("some_switch", &some_switch, "a bool that affects Y"),
+// Flag::CreateFlag("some_name", &some_name, "a string that affects Z")
// };
// // Get usage message before ParseFlags() to capture default values.
// std::string usage = Flag::Usage(argv[0], flag_list);
@@ -63,11 +64,21 @@ namespace tflite {
// text, and a pointer to the corresponding variable.
class Flag {
public:
- Flag(const char* name, int32_t* dst, const std::string& usage_text);
- Flag(const char* name, int64_t* dst, const std::string& usage_text);
- Flag(const char* name, bool* dst, const std::string& usage_text);
- Flag(const char* name, std::string* dst, const std::string& usage_text);
- Flag(const char* name, float* dst, const std::string& usage_text);
+ template <typename T>
+ static Flag CreateFlag(const char* name, T* val, const char* usage) {
+ return Flag(name, [val](const T& v) { *val = v; }, *val, usage);
+ }
+
+ Flag(const char* name, const std::function<void(const int32_t&)>& hook,
+ int32_t default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const int64_t&)>& hook,
+ int64_t default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const float&)>& hook,
+ float default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const bool&)>& hook,
+ bool default_value, const std::string& usage_text);
+ Flag(const char* name, const std::function<void(const std::string&)>& hook,
+ const std::string& default_value, const std::string& usage_text);
private:
friend class Flags;
@@ -109,4 +120,4 @@ class Flags {
} // namespace tflite
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_
diff --git a/tensorflow/contrib/lite/tools/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
index 463647bec9..03da805109 100644
--- a/tensorflow/contrib/lite/tools/command_line_flags_test.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/tools/command_line_flags.h"
+#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/testing/util.h"
@@ -34,15 +34,15 @@ TEST(CommandLineFlagsTest, BasicUsage) {
"--some_name=somethingelse",
"--some_float=42.0"};
int argc = 6;
- bool parsed_ok =
- Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {
- Flag("some_int32", &some_int32, "some int32"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name"),
- Flag("some_float", &some_float, "some float"),
- });
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {
+ Flag::CreateFlag("some_int32", &some_int32, "some int32"),
+ Flag::CreateFlag("some_int64", &some_int64, "some int64"),
+ Flag::CreateFlag("some_switch", &some_switch, "some switch"),
+ Flag::CreateFlag("some_name", &some_name, "some name"),
+ Flag::CreateFlag("some_float", &some_float, "some float"),
+ });
EXPECT_EQ(true, parsed_ok);
EXPECT_EQ(20, some_int32);
@@ -53,13 +53,26 @@ TEST(CommandLineFlagsTest, BasicUsage) {
EXPECT_EQ(argc, 1);
}
+TEST(CommandLineFlagsTest, EmptyStringFlag) {
+ int argc = 2;
+ std::string some_string = "invalid";
+ const char* argv_strings[] = {"program_name", "--some_string="};
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag::CreateFlag("some_string", &some_string, "some string")});
+
+ EXPECT_EQ(true, parsed_ok);
+ EXPECT_EQ(some_string, "");
+ EXPECT_EQ(argc, 1);
+}
+
TEST(CommandLineFlagsTest, BadIntValue) {
int some_int = 10;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_int=notanumber"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_int", &some_int, "some int")});
+ {Flag::CreateFlag("some_int", &some_int, "some int")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(10, some_int);
@@ -70,9 +83,9 @@ TEST(CommandLineFlagsTest, BadBoolValue) {
bool some_switch = false;
int argc = 2;
const char* argv_strings[] = {"program_name", "--some_switch=notabool"};
- bool parsed_ok =
- Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_switch", &some_switch, "some switch")});
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {Flag::CreateFlag("some_switch", &some_switch, "some switch")});
EXPECT_EQ(false, parsed_ok);
EXPECT_EQ(false, some_switch);
@@ -85,7 +98,7 @@ TEST(CommandLineFlagsTest, BadFloatValue) {
const char* argv_strings[] = {"program_name", "--some_float=notanumber"};
bool parsed_ok =
Flags::Parse(&argc, reinterpret_cast<const char**>(argv_strings),
- {Flag("some_float", &some_float, "some float")});
+ {Flag::CreateFlag("some_float", &some_float, "some float")});
EXPECT_EQ(false, parsed_ok);
EXPECT_NEAR(-23.23f, some_float, 1e-5f);
@@ -121,12 +134,13 @@ TEST(CommandLineFlagsTest, UsageString) {
std::string some_name = "something";
// Don't test float in this case, because precision is hard to predict and
// match against, and we don't want a flakey test.
- const string tool_name = "some_tool_name";
- string usage = Flags::Usage(tool_name + " <flags>",
- {Flag("some_int", &some_int, "some int"),
- Flag("some_int64", &some_int64, "some int64"),
- Flag("some_switch", &some_switch, "some switch"),
- Flag("some_name", &some_name, "some name")});
+ const std::string tool_name = "some_tool_name";
+ std::string usage = Flags::Usage(
+ tool_name + " <flags>",
+ {Flag::CreateFlag("some_int", &some_int, "some int"),
+ Flag::CreateFlag("some_int64", &some_int64, "some int64"),
+ Flag::CreateFlag("some_switch", &some_switch, "some switch"),
+ Flag::CreateFlag("some_name", &some_name, "some name")});
// Match the usage message, being sloppy about whitespace.
const char* expected_usage =
" usage: some_tool_name <flags>\n"
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
new file mode 100644
index 0000000000..c8d3307e29
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
@@ -0,0 +1,43 @@
+# TFLite iOS benchmark app.
+
+## Description
+
+An iOS app to benchmark TFLite models.
+
+The app reads benchmark parameters from a JSON file named `benchmark_params.json`
+in its `benchmark_data` directory. Any downloaded models for benchmarking should
+also be placed in `benchmark_data` directory.
+
+The JSON file specifies the name of the model file and other benchmarking
+parameters like inputs to the model, type of inputs, number of iterations,
+number of threads. The default values in the JSON file are for the
+Mobilenet_1.0_224 model
+([paper](https://arxiv.org/pdf/1704.04861.pdf),
+[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz))
+
+## To build/install/run
+
+- Follow instructions at [iOS build for TFLite]
+(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
+to build TFLite.
+
+Running
+
+```bash
+tensorflow/contrib/lite/build_ios_universal_lib.sh
+```
+will also build `tensorflow/contrib/lite/gen/lib/benchmark-lib.a` .
+
+- Now copy the downloaded model file to `benchmark_data` directory.
+
+- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape`
+and other benchmark parameters.
+
+- Change `Build Phases -> Copy Bundle Resources` and add the model file to the
+resources that need to be copied.
+
+- Ensure that `Build Phases -> Link Binary With Library` contains the
+`Accelerate framework` and `tensorflow/contrib/lite/gen/lib/benchmark-lib.a`.
+
+- Now try running the app. The app has a single button that runs the benchmark
+ on the model and displays results in a text view below.
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj
new file mode 100644
index 0000000000..b908f733d4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark.xcodeproj/project.pbxproj
@@ -0,0 +1,381 @@
+// !$*UTF8*$!
+{
+ archiveVersion = 1;
+ classes = {
+ };
+ objectVersion = 50;
+ objects = {
+
+/* Begin PBXBuildFile section */
+ 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */ = {isa = PBXBuildFile; fileRef = 6FE7579920D59CE500F01636 /* benchmark_params.json */; };
+ 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */; };
+ 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 6FE7579E20D5A6A700F01636 /* Accelerate.framework */; };
+ 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */; };
+ 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */; };
+ 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */; };
+ 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400120D592D8008C9FE4 /* Main.storyboard */; };
+ 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 6FE9400420D592DA008C9FE4 /* Assets.xcassets */; };
+ 6FE9400B20D592DA008C9FE4 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 6FE9400A20D592DA008C9FE4 /* main.m */; };
+/* End PBXBuildFile section */
+
+/* Begin PBXFileReference section */
+ 6FE7579920D59CE500F01636 /* benchmark_params.json */ = {isa = PBXFileReference; lastKnownFileType = text.json; path = benchmark_params.json; sourceTree = "<group>"; };
+ 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "benchmark-lib.a"; path = "$SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib/benchmark-lib.a"; sourceTree = "<group>"; };
+ 6FE7579E20D5A6A700F01636 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
+ 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = "<group>"; };
+ 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TFLiteBenchmark.app; sourceTree = BUILT_PRODUCTS_DIR; };
+ 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
+ 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = "<group>"; };
+ 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; path = BenchmarkViewController.h; sourceTree = "<group>"; };
+ 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = BenchmarkViewController.mm; sourceTree = "<group>"; };
+ 6FE9400220D592D8008C9FE4 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = "<group>"; };
+ 6FE9400420D592DA008C9FE4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
+ 6FE9400920D592DA008C9FE4 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
+ 6FE9400A20D592DA008C9FE4 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = "<group>"; };
+/* End PBXFileReference section */
+
+/* Begin PBXFrameworksBuildPhase section */
+ 6FE93FF520D592D8008C9FE4 /* Frameworks */ = {
+ isa = PBXFrameworksBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 6FE7579F20D5A6A700F01636 /* Accelerate.framework in Frameworks */,
+ 6FE7579D20D5A5E000F01636 /* benchmark-lib.a in Frameworks */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXFrameworksBuildPhase section */
+
+/* Begin PBXGroup section */
+ 6FE7579820D59C8B00F01636 /* benchmark_data */ = {
+ isa = PBXGroup;
+ children = (
+ 6FE757A020D5AB8000F01636 /* mobilenet_v1_1.0_224.tflite */,
+ 6FE7579920D59CE500F01636 /* benchmark_params.json */,
+ );
+ path = benchmark_data;
+ sourceTree = "<group>";
+ };
+ 6FE7579B20D5A5E000F01636 /* Frameworks */ = {
+ isa = PBXGroup;
+ children = (
+ 6FE7579E20D5A6A700F01636 /* Accelerate.framework */,
+ 6FE7579C20D5A5E000F01636 /* benchmark-lib.a */,
+ );
+ name = Frameworks;
+ sourceTree = "<group>";
+ };
+ 6FE93FEF20D592D8008C9FE4 = {
+ isa = PBXGroup;
+ children = (
+ 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */,
+ 6FE93FF920D592D8008C9FE4 /* Products */,
+ 6FE7579B20D5A5E000F01636 /* Frameworks */,
+ );
+ sourceTree = "<group>";
+ };
+ 6FE93FF920D592D8008C9FE4 /* Products */ = {
+ isa = PBXGroup;
+ children = (
+ 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */,
+ );
+ name = Products;
+ sourceTree = "<group>";
+ };
+ 6FE93FFA20D592D8008C9FE4 /* TFLiteBenchmark */ = {
+ isa = PBXGroup;
+ children = (
+ 6FE7579820D59C8B00F01636 /* benchmark_data */,
+ 6FE93FFB20D592D8008C9FE4 /* AppDelegate.h */,
+ 6FE93FFC20D592D8008C9FE4 /* AppDelegate.m */,
+ 6FE93FFE20D592D8008C9FE4 /* BenchmarkViewController.h */,
+ 6FE93FFF20D592D8008C9FE4 /* BenchmarkViewController.mm */,
+ 6FE9400120D592D8008C9FE4 /* Main.storyboard */,
+ 6FE9400420D592DA008C9FE4 /* Assets.xcassets */,
+ 6FE9400920D592DA008C9FE4 /* Info.plist */,
+ 6FE9400A20D592DA008C9FE4 /* main.m */,
+ );
+ path = TFLiteBenchmark;
+ sourceTree = "<group>";
+ };
+/* End PBXGroup section */
+
+/* Begin PBXNativeTarget section */
+ 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */ = {
+ isa = PBXNativeTarget;
+ buildConfigurationList = 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */;
+ buildPhases = (
+ 6FE93FF420D592D8008C9FE4 /* Sources */,
+ 6FE93FF520D592D8008C9FE4 /* Frameworks */,
+ 6FE93FF620D592D8008C9FE4 /* Resources */,
+ );
+ buildRules = (
+ );
+ dependencies = (
+ );
+ name = TFLiteBenchmark;
+ productName = TFLiteBenchmark;
+ productReference = 6FE93FF820D592D8008C9FE4 /* TFLiteBenchmark.app */;
+ productType = "com.apple.product-type.application";
+ };
+/* End PBXNativeTarget section */
+
+/* Begin PBXProject section */
+ 6FE93FF020D592D8008C9FE4 /* Project object */ = {
+ isa = PBXProject;
+ attributes = {
+ LastUpgradeCheck = 1000;
+ ORGANIZATIONNAME = Example;
+ TargetAttributes = {
+ 6FE93FF720D592D8008C9FE4 = {
+ CreatedOnToolsVersion = 10.0;
+ };
+ };
+ };
+ buildConfigurationList = 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */;
+ compatibilityVersion = "Xcode 9.3";
+ developmentRegion = en;
+ hasScannedForEncodings = 0;
+ knownRegions = (
+ en,
+ Base,
+ );
+ mainGroup = 6FE93FEF20D592D8008C9FE4;
+ productRefGroup = 6FE93FF920D592D8008C9FE4 /* Products */;
+ projectDirPath = "";
+ projectRoot = "";
+ targets = (
+ 6FE93FF720D592D8008C9FE4 /* TFLiteBenchmark */,
+ );
+ };
+/* End PBXProject section */
+
+/* Begin PBXResourcesBuildPhase section */
+ 6FE93FF620D592D8008C9FE4 /* Resources */ = {
+ isa = PBXResourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 6FE757A120D5AB8100F01636 /* mobilenet_v1_1.0_224.tflite in Resources */,
+ 6FE9400520D592DA008C9FE4 /* Assets.xcassets in Resources */,
+ 6FE9400320D592D8008C9FE4 /* Main.storyboard in Resources */,
+ 6FE7579A20D59CE500F01636 /* benchmark_params.json in Resources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXResourcesBuildPhase section */
+
+/* Begin PBXSourcesBuildPhase section */
+ 6FE93FF420D592D8008C9FE4 /* Sources */ = {
+ isa = PBXSourcesBuildPhase;
+ buildActionMask = 2147483647;
+ files = (
+ 6FE9400020D592D8008C9FE4 /* BenchmarkViewController.mm in Sources */,
+ 6FE9400B20D592DA008C9FE4 /* main.m in Sources */,
+ 6FE93FFD20D592D8008C9FE4 /* AppDelegate.m in Sources */,
+ );
+ runOnlyForDeploymentPostprocessing = 0;
+ };
+/* End PBXSourcesBuildPhase section */
+
+/* Begin PBXVariantGroup section */
+ 6FE9400120D592D8008C9FE4 /* Main.storyboard */ = {
+ isa = PBXVariantGroup;
+ children = (
+ 6FE9400220D592D8008C9FE4 /* Base */,
+ );
+ name = Main.storyboard;
+ sourceTree = "<group>";
+ };
+/* End PBXVariantGroup section */
+
+/* Begin XCBuildConfiguration section */
+ 6FE9400C20D592DA008C9FE4 /* Debug */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_ENABLE_OBJC_WEAK = YES;
+ CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_COMMA = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
+ CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
+ CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
+ CLANG_WARN_STRICT_PROTOTYPES = YES;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = dwarf;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ ENABLE_TESTABILITY = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu11;
+ GCC_DYNAMIC_NO_PIC = NO;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_OPTIMIZATION_LEVEL = 0;
+ GCC_PREPROCESSOR_DEFINITIONS = (
+ "DEBUG=1",
+ "$(inherited)",
+ );
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 11.0;
+ MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
+ ONLY_ACTIVE_ARCH = YES;
+ OTHER_CFLAGS = "";
+ OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
+ SDKROOT = iphoneos;
+ };
+ name = Debug;
+ };
+ 6FE9400D20D592DA008C9FE4 /* Release */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ALWAYS_SEARCH_USER_PATHS = NO;
+ CLANG_ANALYZER_NONNULL = YES;
+ CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
+ CLANG_CXX_LANGUAGE_STANDARD = "gnu++14";
+ CLANG_CXX_LIBRARY = "libc++";
+ CLANG_ENABLE_MODULES = YES;
+ CLANG_ENABLE_OBJC_ARC = YES;
+ CLANG_ENABLE_OBJC_WEAK = YES;
+ CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
+ CLANG_WARN_BOOL_CONVERSION = YES;
+ CLANG_WARN_COMMA = YES;
+ CLANG_WARN_CONSTANT_CONVERSION = YES;
+ CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
+ CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
+ CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
+ CLANG_WARN_EMPTY_BODY = YES;
+ CLANG_WARN_ENUM_CONVERSION = YES;
+ CLANG_WARN_INFINITE_RECURSION = YES;
+ CLANG_WARN_INT_CONVERSION = YES;
+ CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
+ CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
+ CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
+ CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
+ CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
+ CLANG_WARN_STRICT_PROTOTYPES = YES;
+ CLANG_WARN_SUSPICIOUS_MOVE = YES;
+ CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
+ CLANG_WARN_UNREACHABLE_CODE = YES;
+ CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
+ CODE_SIGN_IDENTITY = "iPhone Developer";
+ COPY_PHASE_STRIP = NO;
+ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
+ ENABLE_NS_ASSERTIONS = NO;
+ ENABLE_STRICT_OBJC_MSGSEND = YES;
+ GCC_C_LANGUAGE_STANDARD = gnu11;
+ GCC_NO_COMMON_BLOCKS = YES;
+ GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
+ GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
+ GCC_WARN_UNDECLARED_SELECTOR = YES;
+ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
+ GCC_WARN_UNUSED_FUNCTION = YES;
+ GCC_WARN_UNUSED_VARIABLE = YES;
+ IPHONEOS_DEPLOYMENT_TARGET = 11.0;
+ MTL_ENABLE_DEBUG_INFO = NO;
+ OTHER_CFLAGS = "";
+ OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
+ SDKROOT = iphoneos;
+ VALIDATE_PRODUCT = YES;
+ };
+ name = Release;
+ };
+ 6FE9400F20D592DA008C9FE4 /* Debug */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_STYLE = Automatic;
+ "HEADER_SEARCH_PATHS[arch=*]" = (
+ $SRCROOT/../../../../../../../,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include,
+ );
+ INFOPLIST_FILE = TFLiteBenchmark/Info.plist;
+ LD_RUNPATH_SEARCH_PATHS = (
+ "$(inherited)",
+ "@executable_path/Frameworks",
+ );
+ "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib;
+ PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark;
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ TARGETED_DEVICE_FAMILY = "1,2";
+ "USER_HEADER_SEARCH_PATHS[arch=*]" = "";
+ };
+ name = Debug;
+ };
+ 6FE9401020D592DA008C9FE4 /* Release */ = {
+ isa = XCBuildConfiguration;
+ buildSettings = {
+ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
+ CODE_SIGN_STYLE = Automatic;
+ "HEADER_SEARCH_PATHS[arch=*]" = (
+ $SRCROOT/../../../../../../../,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/eigen,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/gemmlowp,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/neon_2_sse,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/farmhash/src,
+ $SRCROOT/../../../../../../../tensorflow/contrib/lite/downloads/flatbuffers/include,
+ );
+ INFOPLIST_FILE = TFLiteBenchmark/Info.plist;
+ LD_RUNPATH_SEARCH_PATHS = (
+ "$(inherited)",
+ "@executable_path/Frameworks",
+ );
+ "LIBRARY_SEARCH_PATHS[arch=*]" = $SRCROOT/../../../../../../../tensorflow/contrib/lite/gen/lib;
+ PRODUCT_BUNDLE_IDENTIFIER = example.TFLiteBenchmark;
+ PRODUCT_NAME = "$(TARGET_NAME)";
+ TARGETED_DEVICE_FAMILY = "1,2";
+ };
+ name = Release;
+ };
+/* End XCBuildConfiguration section */
+
+/* Begin XCConfigurationList section */
+ 6FE93FF320D592D8008C9FE4 /* Build configuration list for PBXProject "TFLiteBenchmark" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 6FE9400C20D592DA008C9FE4 /* Debug */,
+ 6FE9400D20D592DA008C9FE4 /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+ 6FE9400E20D592DA008C9FE4 /* Build configuration list for PBXNativeTarget "TFLiteBenchmark" */ = {
+ isa = XCConfigurationList;
+ buildConfigurations = (
+ 6FE9400F20D592DA008C9FE4 /* Debug */,
+ 6FE9401020D592DA008C9FE4 /* Release */,
+ );
+ defaultConfigurationIsVisible = 0;
+ defaultConfigurationName = Release;
+ };
+/* End XCConfigurationList section */
+ };
+ rootObject = 6FE93FF020D592D8008C9FE4 /* Project object */;
+}
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h
new file mode 100644
index 0000000000..a55c03e00b
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.h
@@ -0,0 +1,22 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <UIKit/UIKit.h>
+
+@interface AppDelegate : UIResponder <UIApplicationDelegate>
+
+@property(strong, nonatomic) UIWindow *window;
+
+@end
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m
new file mode 100644
index 0000000000..b1165940e9
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/AppDelegate.m
@@ -0,0 +1,27 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "AppDelegate.h"
+
+@interface AppDelegate ()
+
+@end
+
+@implementation AppDelegate
+- (BOOL)application:(UIApplication *)application
+ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
+ return YES;
+}
+@end
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json
new file mode 100644
index 0000000000..d8db8d65fd
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/AppIcon.appiconset/Contents.json
@@ -0,0 +1,98 @@
+{
+ "images" : [
+ {
+ "idiom" : "iphone",
+ "size" : "20x20",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "20x20",
+ "scale" : "3x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "29x29",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "29x29",
+ "scale" : "3x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "40x40",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "40x40",
+ "scale" : "3x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "60x60",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "iphone",
+ "size" : "60x60",
+ "scale" : "3x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "20x20",
+ "scale" : "1x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "20x20",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "29x29",
+ "scale" : "1x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "29x29",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "40x40",
+ "scale" : "1x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "40x40",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "76x76",
+ "scale" : "1x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "76x76",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "ipad",
+ "size" : "83.5x83.5",
+ "scale" : "2x"
+ },
+ {
+ "idiom" : "ios-marketing",
+ "size" : "1024x1024",
+ "scale" : "1x"
+ }
+ ],
+ "info" : {
+ "version" : 1,
+ "author" : "xcode"
+ }
+} \ No newline at end of file
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json
new file mode 100644
index 0000000000..da4a164c91
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Assets.xcassets/Contents.json
@@ -0,0 +1,6 @@
+{
+ "info" : {
+ "version" : 1,
+ "author" : "xcode"
+ }
+} \ No newline at end of file
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard
new file mode 100644
index 0000000000..bfa3612941
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/LaunchScreen.storyboard
@@ -0,0 +1,25 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="13122.16" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" launchScreen="YES" useTraitCollections="YES" useSafeAreas="YES" colorMatched="YES" initialViewController="01J-lp-oVM">
+ <dependencies>
+ <plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="13104.12"/>
+ <capability name="Safe area layout guides" minToolsVersion="9.0"/>
+ <capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
+ </dependencies>
+ <scenes>
+ <!--View Controller-->
+ <scene sceneID="EHf-IW-A2E">
+ <objects>
+ <viewController id="01J-lp-oVM" sceneMemberID="viewController">
+ <view key="view" contentMode="scaleToFill" id="Ze5-6b-2t3">
+ <rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
+ <autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
+ <color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
+ <viewLayoutGuide key="safeArea" id="6Tk-OE-BBY"/>
+ </view>
+ </viewController>
+ <placeholder placeholderIdentifier="IBFirstResponder" id="iYj-Kq-Ea1" userLabel="First Responder" sceneMemberID="firstResponder"/>
+ </objects>
+ <point key="canvasLocation" x="53" y="375"/>
+ </scene>
+ </scenes>
+</document>
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard
new file mode 100644
index 0000000000..adcfe1ef4e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Base.lproj/Main.storyboard
@@ -0,0 +1,60 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="14269.12" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES" useTraitCollections="YES" useSafeAreas="YES" colorMatched="YES" initialViewController="BYZ-38-t0r">
+ <device id="retina4_7" orientation="portrait">
+ <adaptation id="fullscreen"/>
+ </device>
+ <dependencies>
+ <deployment identifier="iOS"/>
+ <plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="14252.5"/>
+ <capability name="Safe area layout guides" minToolsVersion="9.0"/>
+ <capability name="documents saved in the Xcode 8 format" minToolsVersion="8.0"/>
+ </dependencies>
+ <scenes>
+ <!--Benchmark View Controller-->
+ <scene sceneID="tne-QT-ifu">
+ <objects>
+ <viewController id="BYZ-38-t0r" customClass="BenchmarkViewController" sceneMemberID="viewController">
+ <view key="view" contentMode="scaleToFill" id="8bC-Xf-vdC">
+ <rect key="frame" x="0.0" y="0.0" width="375" height="667"/>
+ <autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
+ <subviews>
+ <button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="j0O-Lq-1tJ">
+ <rect key="frame" x="64" y="20" width="247" height="63"/>
+ <constraints>
+ <constraint firstAttribute="height" constant="63" id="8VO-Ln-L2h"/>
+ </constraints>
+ <fontDescription key="fontDescription" type="system" pointSize="24"/>
+ <state key="normal" title="Benchmark model"/>
+ <connections>
+ <action selector="onBenchmarkModel:" destination="BYZ-38-t0r" eventType="touchUpInside" id="Rb1-hs-Mub"/>
+ </connections>
+ </button>
+ <textView clipsSubviews="YES" multipleTouchEnabled="YES" contentMode="scaleToFill" textAlignment="natural" translatesAutoresizingMaskIntoConstraints="NO" id="Vd4-Gf-qKO">
+ <rect key="frame" x="26" y="101" width="333" height="556"/>
+ <color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="genericGamma22GrayColorSpace"/>
+ <fontDescription key="fontDescription" type="system" pointSize="14"/>
+ <textInputTraits key="textInputTraits" autocapitalizationType="sentences"/>
+ </textView>
+ </subviews>
+ <color key="backgroundColor" red="1" green="1" blue="1" alpha="1" colorSpace="custom" customColorSpace="sRGB"/>
+ <constraints>
+ <constraint firstItem="Vd4-Gf-qKO" firstAttribute="top" secondItem="j0O-Lq-1tJ" secondAttribute="bottom" constant="18" id="Kd3-pP-C1k"/>
+ <constraint firstItem="j0O-Lq-1tJ" firstAttribute="centerX" secondItem="8bC-Xf-vdC" secondAttribute="centerX" id="QJU-cq-L87"/>
+ <constraint firstItem="Vd4-Gf-qKO" firstAttribute="trailing" secondItem="8bC-Xf-vdC" secondAttribute="trailingMargin" id="Tew-W4-Vq5"/>
+ <constraint firstItem="j0O-Lq-1tJ" firstAttribute="top" secondItem="6Tk-OE-BBY" secondAttribute="top" id="Uce-n7-kZI"/>
+ <constraint firstItem="j0O-Lq-1tJ" firstAttribute="leading" secondItem="6Tk-OE-BBY" secondAttribute="leading" constant="64" id="Uhq-Rw-NKT"/>
+ <constraint firstItem="Vd4-Gf-qKO" firstAttribute="leading" secondItem="6Tk-OE-BBY" secondAttribute="leading" constant="26" id="aXc-6M-kyL"/>
+ <constraint firstItem="6Tk-OE-BBY" firstAttribute="bottom" secondItem="Vd4-Gf-qKO" secondAttribute="bottom" constant="10" id="tz5-wP-LZs"/>
+ </constraints>
+ <viewLayoutGuide key="safeArea" id="6Tk-OE-BBY"/>
+ </view>
+ <connections>
+ <outlet property="resultsView" destination="Vd4-Gf-qKO" id="dBT-f6-SYw"/>
+ </connections>
+ </viewController>
+ <placeholder placeholderIdentifier="IBFirstResponder" id="dkx-z0-nzr" sceneMemberID="firstResponder"/>
+ </objects>
+ <point key="canvasLocation" x="140" y="122.78860569715144"/>
+ </scene>
+ </scenes>
+</document>
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h
new file mode 100644
index 0000000000..ec6dea0546
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.h
@@ -0,0 +1,21 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <UIKit/UIKit.h>
+
+@interface BenchmarkViewController : UIViewController
+@property(weak, nonatomic) IBOutlet UITextView *resultsView;
+
+@end
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm
new file mode 100644
index 0000000000..356d5b0e17
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/BenchmarkViewController.mm
@@ -0,0 +1,125 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import "BenchmarkViewController.h"
+#import <algorithm>
+#import <sstream>
+#import <string>
+#import <vector>
+#import "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h"
+#import "tensorflow/contrib/lite/tools/benchmark/logging.h"
+
+namespace {
+NSString* FilePathForResourceName(NSString* filename) {
+ NSString* name = [filename stringByDeletingPathExtension];
+ NSString* extension = [filename pathExtension];
+ NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
+ if (file_path == NULL) {
+ TFLITE_LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String]
+ << "' in bundle.";
+ }
+ return file_path;
+}
+
+NSDictionary* ParseJson() {
+ NSString* params_json_path = FilePathForResourceName(@"benchmark_params.json");
+ NSData* data = [NSData dataWithContentsOfFile:params_json_path];
+ return [NSJSONSerialization JSONObjectWithData:data options:kNilOptions error:nil];
+}
+
+std::string FormatCommandLineParam(NSString* key, NSString* value) {
+ std::ostringstream stream;
+ stream << "--" << [key UTF8String] << "=" << [value UTF8String];
+ return stream.str();
+}
+
+// Reads the |benchmark_params.json| to read command line parameters and returns them as a vector of
+// strings.
+void ReadCommandLineParameters(std::vector<std::string>* params) {
+ NSDictionary* param_dict = ParseJson();
+ for (NSString* key in param_dict) {
+ NSString* value = param_dict[key];
+ if ([key isEqualToString:@"graph"]) {
+ value = FilePathForResourceName(value);
+ }
+ params->push_back(FormatCommandLineParam(key, value));
+ }
+}
+std::vector<char*> StringVecToCharPtrVec(const std::vector<std::string>& str_vec) {
+ std::vector<char*> charptr_vec;
+ std::transform(str_vec.begin(), str_vec.end(), std::back_inserter(charptr_vec),
+ [](const std::string& s) -> char* { return const_cast<char*>(s.c_str()); });
+ return charptr_vec;
+}
+
+class ResultsListener : public tflite::benchmark::BenchmarkListener {
+ public:
+ void OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) override;
+ std::string Results() { return results_; }
+
+ private:
+ std::string results_;
+};
+
+void OutputMicrosecondsStatToStream(const tensorflow::Stat<int64_t>& time_us,
+ const std::string& prefix, std::ostringstream* stream) {
+ *stream << prefix << "Num runs: " << time_us.count() << "\n";
+
+ *stream << prefix << "Average: " << time_us.avg() / 1e3 << " ms\n";
+ *stream << prefix << "Min: " << time_us.min() / 1e3 << " ms \n";
+ *stream << prefix << "Max: " << time_us.max() / 1e3 << " ms \n";
+ *stream << prefix << "Std deviation: " << time_us.std_deviation() / 1e3 << " ms\n";
+}
+
+void ResultsListener::OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults& results) {
+ std::ostringstream stream;
+ const std::string prefix = " - ";
+ stream << "Startup latency: ";
+ stream << results.startup_latency_us() / 1e3 << " ms\n";
+ stream << "\nInference:\n";
+ OutputMicrosecondsStatToStream(results.inference_time_us(), prefix, &stream);
+ stream << "\nWarmup:\n";
+ OutputMicrosecondsStatToStream(results.warmup_time_us(), prefix, &stream);
+
+ results_ = stream.str();
+}
+
+std::string RunBenchmark() {
+ ResultsListener listener;
+ tflite::benchmark::BenchmarkTfLiteModel benchmark;
+ benchmark.AddListener(&listener);
+ // TODO(shashishekhar): Passing arguments like this is brittle, refactor the BenchmarkParams
+ // so that it contains arguments for BenchmarkTfLiteModel and set parameters using BenchmarkParams
+ std::vector<std::string> command_line_params;
+ // Benchmark model expects first arg to be program name.
+ // push a string for name of program.
+ command_line_params.push_back("benchmark_tflite_model");
+ ReadCommandLineParameters(&command_line_params);
+ std::vector<char*> argv = StringVecToCharPtrVec(command_line_params);
+ int argc = static_cast<int>(argv.size());
+ benchmark.Run(argc, argv.data());
+ return listener.Results();
+}
+} // namespace
+
+@interface BenchmarkViewController ()
+@end
+
+@implementation BenchmarkViewController
+- (IBAction)onBenchmarkModel:(UIButton*)sender {
+ std::string results = RunBenchmark();
+ [_resultsView setText:[NSString stringWithUTF8String:results.c_str()]];
+}
+@end
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist
new file mode 100644
index 0000000000..96051cf08f
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/Info.plist
@@ -0,0 +1,43 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+ <key>UILaunchStoryboardName</key>
+ <string>Main</string>
+ <key>CFBundleDevelopmentRegion</key>
+ <string>$(DEVELOPMENT_LANGUAGE)</string>
+ <key>CFBundleExecutable</key>
+ <string>$(EXECUTABLE_NAME)</string>
+ <key>CFBundleIdentifier</key>
+ <string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
+ <key>CFBundleInfoDictionaryVersion</key>
+ <string>6.0</string>
+ <key>CFBundleName</key>
+ <string>$(PRODUCT_NAME)</string>
+ <key>CFBundlePackageType</key>
+ <string>APPL</string>
+ <key>CFBundleShortVersionString</key>
+ <string>1.0</string>
+ <key>CFBundleVersion</key>
+ <string>1</string>
+ <key>LSRequiresIPhoneOS</key>
+ <true/>
+ <key>UIMainStoryboardFile</key>
+ <string>Main</string>
+ <key>UIRequiredDeviceCapabilities</key>
+ <array>
+ <string>armv7</string>
+ </array>
+ <key>UISupportedInterfaceOrientations</key>
+ <array>
+ <string>UIInterfaceOrientationPortrait</string>
+ </array>
+ <key>UISupportedInterfaceOrientations~ipad</key>
+ <array>
+ <string>UIInterfaceOrientationPortrait</string>
+ <string>UIInterfaceOrientationPortraitUpsideDown</string>
+ <string>UIInterfaceOrientationLandscapeLeft</string>
+ <string>UIInterfaceOrientationLandscapeRight</string>
+ </array>
+</dict>
+</plist>
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json
new file mode 100644
index 0000000000..d344a7a5ef
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/benchmark_data/benchmark_params.json
@@ -0,0 +1,10 @@
+{
+ "benchmark_name" : "mobile_net_benchmark",
+ "num_threads" : "4",
+ "num_runs" : "20",
+ "warmup_runs" : "1",
+ "graph" : "mobilenet_v1_1.0_224.tflite",
+ "input_layer" : "input",
+ "input_layer_shape" : "1,224,224,3",
+ "run_delay" : "-1"
+}
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m
new file mode 100644
index 0000000000..1e70b9cd1d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/TFLiteBenchmark/TFLiteBenchmark/main.m
@@ -0,0 +1,23 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#import <UIKit/UIKit.h>
+#import "AppDelegate.h"
+
+int main(int argc, char* argv[]) {
+ @autoreleasepool {
+ return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class]));
+ }
+}
diff --git a/tensorflow/contrib/lite/tools/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h
index aa1fa5b827..9e9292e2fe 100644
--- a/tensorflow/contrib/lite/tools/logging.h
+++ b/tensorflow/contrib/lite/tools/benchmark/logging.h
@@ -18,6 +18,7 @@ limitations under the License.
// LOG and CHECK macros for benchmarks.
+#include <cstdlib>
#include <iostream>
#include <sstream>
@@ -72,4 +73,4 @@ class LoggingWrapper {
#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b)
-#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_
diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc
index ce8a7857d2..ad7d59ecb4 100644
--- a/tensorflow/contrib/lite/tools/verifier_test.cc
+++ b/tensorflow/contrib/lite/tools/verifier_test.cc
@@ -41,7 +41,7 @@ class TfLiteFlatbufferModelBuilder {
}
TfLiteFlatbufferModelBuilder(const std::vector<BuiltinOperator>& builtin_ops,
- const std::vector<string>& custom_ops) {
+ const std::vector<std::string>& custom_ops) {
buffers_.push_back(
CreateBuffer(builder_, builder_.CreateVector(std::vector<uint8_t>{})));
@@ -194,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) {
/*operators=*/0, builder.CreateString("Main"))});
auto buffers = builder.CreateVector(std::vector<Offset<Buffer>>{
- CreateBuffer(builder,
- builder.CreateVector(std::vector<uint8>{1, 2, 3, 4, 5, 6})),
+ CreateBuffer(builder, builder.CreateVector(
+ std::vector<uint8_t>{1, 2, 3, 4, 5, 6})),
});
auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0,
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index f571dd59da..e07f899e4d 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -28,11 +28,24 @@ import json
import os
import sys
+from tensorflow.python.platform import resource_loader
+
# Schema to use for flatbuffers
_SCHEMA = "third_party/tensorflow/contrib/lite/schema/schema.fbs"
-# Where the binary will be once built in for the flatc converter
-_BINARY = "third_party/flatbuffers/flatc"
+# TODO(angerson): fix later when rules are simplified..
+_SCHEMA = resource_loader.get_path_to_datafile("../schema/schema.fbs")
+_BINARY = resource_loader.get_path_to_datafile("../../../../flatbuffers/flatc")
+# Account for different package positioning internal vs. external.
+if not os.path.exists(_BINARY):
+ _BINARY = resource_loader.get_path_to_datafile(
+ "../../../../../flatbuffers/flatc")
+
+if not os.path.exists(_SCHEMA):
+ raise RuntimeError("Sorry, schema file cannot be found at %r" % _SCHEMA)
+if not os.path.exists(_BINARY):
+ raise RuntimeError("Sorry, flatc is not available at %r" % _BINARY)
+
# A CSS description for making the visualizer
_CSS = """
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 5d4682ec9f..889accdd5a 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib import lookup
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes
def test_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
+
+ if not context.executing_eagerly():
+ with self.assertRaises(errors_impl.OpError):
+ self.evaluate(table.lookup(
+ constant_op.constant(("salad", "surgery", "tarkus"))))
+ else:
+ # Reinitializing a table in eager should work.
table = lookup.index_table_from_tensor(
mapping=("brain", "salad", "surgery"), num_oov_buckets=1)
- ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
-
- self.assertRaises(errors_impl.OpError, ids.eval)
- lookup_ops.tables_initializer().run()
- self.assertAllEqual((1, 2, 3), ids.eval())
+ self.evaluate(lookup_ops.tables_initializer())
+ ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
+ self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
with self.test_session():
@@ -1662,7 +1670,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
f.write("\n".join(values) + "\n")
return vocabulary_file
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
default_value = -1
diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh
index fc88f59e09..fb9e77ae1b 100755
--- a/tensorflow/contrib/makefile/build_all_android.sh
+++ b/tensorflow/contrib/makefile/build_all_android.sh
@@ -30,6 +30,14 @@ arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64 tegra)"
exit 1
}
+echo "********************************************************************"
+echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference."
+echo "You are currently using an older version. Please switch over to TensorFlow Lite."
+echo ""
+echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite"
+echo "********************************************************************"
+echo ""
+
if [[ -z "${NDK_ROOT}" ]]; then
echo "NDK_ROOT should be set as an environment variable" 1>&2
exit 1
diff --git a/tensorflow/contrib/makefile/build_all_ios.sh b/tensorflow/contrib/makefile/build_all_ios.sh
index 0a458a27b3..1d4677ef4b 100755
--- a/tensorflow/contrib/makefile/build_all_ios.sh
+++ b/tensorflow/contrib/makefile/build_all_ios.sh
@@ -31,6 +31,14 @@ usage() {
exit 1
}
+echo "********************************************************************"
+echo "TensorFlow Lite is the recommended library for mobile and embedded machine learning inference."
+echo "You are currently using an older version. Please switch over to TensorFlow Lite."
+echo ""
+echo "Link to the code: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite"
+echo "********************************************************************"
+echo ""
+
DEFAULT_ARCH="i386 x86_64 armv7 armv7s arm64"
while getopts "a:g:T" opt_name; do
case "$opt_name" in
diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh
index eff9081e35..48953e2e38 100755
--- a/tensorflow/contrib/makefile/download_dependencies.sh
+++ b/tensorflow/contrib/makefile/download_dependencies.sh
@@ -27,9 +27,7 @@ if [ ! -f $BZL_FILE_PATH ]; then
fi
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)"
-# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once
-# the archive has been propagated in mirror.bazel.build.
-GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
+GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 89db9ee279..6e7423f85e 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 4f2c82ca23..21cd34f73f 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -31,6 +31,7 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:confusion_matrix",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:histogram_ops",
"//tensorflow/python:init_ops",
@@ -77,7 +78,31 @@ py_test(
py_test(
name = "metric_ops_test",
srcs = ["python/ops/metric_ops_test.py"],
- shard_count = 16,
+ shard_count = 30,
+ srcs_version = "PY2AND3",
+ tags = ["noasan"], # times out b/63678675
+ deps = [
+ ":metrics_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "metric_ops_large_test",
+ size = "large",
+ srcs = ["python/ops/metric_ops_large_test.py"],
srcs_version = "PY2AND3",
tags = ["noasan"], # times out b/63678675
deps = [
diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 5effea3596..88798d61b7 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -63,6 +63,7 @@ See the @{$python/contrib.metrics} guide.
@@aggregate_metrics
@@aggregate_metric_map
@@confusion_matrix
+@@f1_score
@@set_difference
@@set_intersection
@@set_size
diff --git a/tensorflow/contrib/metrics/python/metrics/classification.py b/tensorflow/contrib/metrics/python/metrics/classification.py
index 26aba1cc51..e553612269 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification.py
@@ -22,6 +22,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics_impl
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import distribute as distribute_lib
# TODO(nsilberman): move into metrics/python/ops/
@@ -62,3 +65,121 @@ def accuracy(predictions, labels, weights=None, name=None):
return math_ops.div(math_ops.reduce_sum(is_correct),
math_ops.reduce_sum(num_values))
return math_ops.reduce_mean(is_correct)
+
+
+def f1_score(labels, predictions, weights=None, num_thresholds=200,
+ metrics_collections=None, updates_collections=None, name=None):
+ """Computes the approximately best F1-score across different thresholds.
+
+ The f1_score function applies a range of thresholds to the predictions to
+ convert them from [0, 1] to bool. Precision and recall are computed by
+ comparing them to the labels. The F1-Score is then defined as
+ 2 * precision * recall / (precision + recall). The best one across the
+ thresholds is returned.
+
+ Disclaimer: In practice it may be desirable to choose the best threshold on
+ the validation set and evaluate the F1 score with this threshold on a
+ separate test set. Or it may be desirable to use a fixed threshold (e.g. 0.5).
+
+ This function internally creates four local variables, `true_positives`,
+ `true_negatives`, `false_positives` and `false_negatives` that are used to
+ compute the pairs of recall and precision values for a linearly spaced set of
+ thresholds from which the best f1-score is derived.
+
+ This value is ultimately returned as `f1-score`, an idempotent operation that
+ computes the F1-score (computed using the aforementioned variables). The
+ `num_thresholds` variable controls the degree of discretization with larger
+ numbers of thresholds more closely approximating the true best F1-score.
+
+ For estimation of the metric over a stream of data, the function creates an
+ `update_op` operation that updates these variables and returns the F1-score.
+
+ Example usage with a custom estimator:
+ def model_fn(features, labels, mode):
+ predictions = make_predictions(features)
+ loss = make_loss(predictions, labels)
+ train_op = tf.contrib.training.create_train_op(
+ total_loss=loss,
+ optimizer='Adam')
+ eval_metric_ops = {'f1': f1_score(labels, predictions)}
+ return tf.estimator.EstimatorSpec(
+ mode=mode,
+ predictions=predictions,
+ loss=loss,
+ train_op=train_op,
+ eval_metric_ops=eval_metric_ops,
+ export_outputs=export_outputs)
+ estimator = tf.estimator.Estimator(model_fn=model_fn)
+
+ If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
+
+ Args:
+ labels: A `Tensor` whose shape matches `predictions`. Will be cast to
+ `bool`.
+ predictions: A floating point `Tensor` of arbitrary shape and whose values
+ are in the range `[0, 1]`.
+ weights: Optional `Tensor` whose rank is either 0, or the same rank as
+ `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
+ be either `1`, or the same as the corresponding `labels` dimension).
+ num_thresholds: The number of thresholds to use when discretizing the roc
+ curve.
+ metrics_collections: An optional list of collections that `f1_score` should
+ be added to.
+ updates_collections: An optional list of collections that `update_op` should
+ be added to.
+ name: An optional variable_scope name.
+
+ Returns:
+ f1_score: A scalar `Tensor` representing the current best f1-score across
+ different thresholds.
+ update_op: An operation that increments the `true_positives`,
+ `true_negatives`, `false_positives` and `false_negatives` variables
+ appropriately and whose value matches the `f1_score`.
+
+ Raises:
+ ValueError: If `predictions` and `labels` have mismatched shapes, or if
+ `weights` is not `None` and its shape doesn't match `predictions`, or if
+ either `metrics_collections` or `updates_collections` are not a list or
+ tuple.
+ """
+ with variable_scope.variable_scope(
+ name, 'f1', (labels, predictions, weights)):
+ predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access
+ predictions=predictions, labels=labels, weights=weights)
+ # To account for floating point imprecisions / avoid division by zero.
+ epsilon = 1e-7
+ thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+ for i in range(num_thresholds - 2)]
+ thresholds = [0.0 - epsilon] + thresholds + [1.0 + epsilon]
+
+ # Confusion matrix.
+ values, update_ops = metrics_impl._confusion_matrix_at_thresholds( # pylint: disable=protected-access
+ labels, predictions, thresholds, weights, includes=('tp', 'fp', 'fn'))
+
+ # Compute precision and recall at various thresholds.
+ def compute_best_f1_score(tp, fp, fn, name):
+ precision_at_t = math_ops.div(tp, epsilon + tp + fp,
+ name='precision_' + name)
+ recall_at_t = math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
+ # Compute F1 score.
+ f1_at_thresholds = (
+ 2.0 * precision_at_t * recall_at_t /
+ (precision_at_t + recall_at_t + epsilon))
+ return math_ops.reduce_max(f1_at_thresholds)
+
+ def f1_across_towers(_, values):
+ best_f1 = compute_best_f1_score(tp=values['tp'], fp=values['fp'],
+ fn=values['fn'], name='value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, best_f1)
+ return best_f1
+
+ best_f1 = distribute_lib.get_tower_context().merge_call(
+ f1_across_towers, values)
+
+ update_op = compute_best_f1_score(tp=update_ops['tp'], fp=update_ops['fp'],
+ fn=update_ops['fn'], name='update')
+ if updates_collections:
+ ops.add_to_collections(updates_collections, update_op)
+
+ return best_f1, update_op
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index fa0f12d029..3d0b81c1be 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -18,9 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.metrics.python.metrics import classification
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -108,5 +115,200 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
+class F1ScoreTest(test.TestCase):
+
+ def setUp(self):
+ super(F1ScoreTest, self).setUp()
+ np.random.seed(1)
+
+ def testVars(self):
+ classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3)
+ expected = {'f1/true_positives:0', 'f1/false_positives:0',
+ 'f1/false_negatives:0'}
+ self.assertEquals(
+ expected, set(v.name for v in variables.local_variables()))
+ self.assertEquals(
+ set(expected), set(v.name for v in variables.local_variables()))
+ self.assertEquals(
+ set(expected),
+ set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))
+
+ def testMetricsCollection(self):
+ my_collection_name = '__metrics__'
+ f1, _ = classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3,
+ metrics_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [f1])
+
+ def testUpdatesCollection(self):
+ my_collection_name = '__updates__'
+ _, f1_op = classification.f1_score(
+ predictions=array_ops.ones((10, 1)),
+ labels=array_ops.ones((10, 1)),
+ num_thresholds=3,
+ updates_collections=[my_collection_name])
+ self.assertListEqual(ops.get_collection(my_collection_name), [f1_op])
+
+ def testValueTensorIsIdempotent(self):
+ predictions = random_ops.random_uniform(
+ (10, 3), maxval=1, dtype=dtypes.float32, seed=1)
+ labels = random_ops.random_uniform(
+ (10, 3), maxval=2, dtype=dtypes.int64, seed=2)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+
+ # Run several updates.
+ for _ in range(10):
+ sess.run([f1_op])
+
+ # Then verify idempotency.
+ initial_f1 = f1.eval()
+ for _ in range(10):
+ self.assertAllClose(initial_f1, f1.eval())
+
+ def testAllCorrect(self):
+ inputs = np.random.randint(0, 2, size=(100, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+ labels = constant_op.constant(inputs)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertEqual(1, f1.eval())
+
+ def testSomeCorrect(self):
+ predictions = constant_op.constant(
+ [1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
+ labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+ # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
+ # score of 2 * 0.5 * 1 / (1 + 0.5).
+ self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval())
+
+ def testAllIncorrect(self):
+ inputs = np.random.randint(0, 2, size=(10000, 1))
+
+ with self.test_session() as sess:
+ predictions = constant_op.constant(inputs, dtype=dtypes.float32)
+ labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ # Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
+ # score of 2 * 0.5 * 1 / (1 + 0.5).
+ self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
+
+ def testWeights1d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0], [1]], shape=(2, 1), dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, weights,
+ num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+ def testWeights2d(self):
+ with self.test_session() as sess:
+ predictions = constant_op.constant(
+ [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
+ labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ weights = constant_op.constant(
+ [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes.float32)
+ f1, f1_op = classification.f1_score(predictions, labels, weights,
+ num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(1.0, f1.eval(), places=5)
+
+ def testZeroLabelsPredictions(self):
+ with self.test_session() as sess:
+ predictions = array_ops.zeros([4], dtype=dtypes.float32)
+ labels = array_ops.zeros([4])
+ f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
+ sess.run(variables.local_variables_initializer())
+ sess.run([f1_op])
+
+ self.assertAlmostEqual(0.0, f1.eval(), places=5)
+
+ def testWithMultipleUpdates(self):
+ num_samples = 1000
+ batch_size = 10
+ num_batches = int(num_samples / batch_size)
+
+ # Create the labels and data.
+ labels = np.random.randint(0, 2, size=(num_samples, 1))
+ noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
+ predictions = 0.4 + 0.2 * labels + noise
+ predictions[predictions > 1] = 1
+ predictions[predictions < 0] = 0
+ thresholds = [-0.01, 0.5, 1.01]
+
+ expected_max_f1 = -1.0
+ for threshold in thresholds:
+ tp = 0
+ fp = 0
+ fn = 0
+ tn = 0
+ for i in range(num_samples):
+ if predictions[i] >= threshold:
+ if labels[i] == 1:
+ tp += 1
+ else:
+ fp += 1
+ else:
+ if labels[i] == 1:
+ fn += 1
+ else:
+ tn += 1
+ epsilon = 1e-7
+ expected_prec = tp / (epsilon + tp + fp)
+ expected_rec = tp / (epsilon + tp + fn)
+ expected_f1 = (2 * expected_prec * expected_rec /
+ (epsilon + expected_prec + expected_rec))
+ if expected_f1 > expected_max_f1:
+ expected_max_f1 = expected_f1
+
+ labels = labels.astype(np.float32)
+ predictions = predictions.astype(np.float32)
+ tf_predictions, tf_labels = (dataset_ops.Dataset
+ .from_tensor_slices((predictions, labels))
+ .repeat()
+ .batch(batch_size)
+ .make_one_shot_iterator()
+ .get_next())
+ f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
+ num_thresholds=3)
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in range(num_batches):
+ sess.run([f1_op])
+ # Since this is only approximate, we can't expect a 6 digits match.
+ # Although with higher number of samples/thresholds we should see the
+ # accuracy improving
+ self.assertAlmostEqual(expected_max_f1, f1.eval(), 2)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 00a933e5e0..b14202ff9e 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1064,7 +1064,7 @@ def streaming_auc(predictions,
name=name)
-def _compute_dynamic_auc(labels, predictions, curve='ROC'):
+def _compute_dynamic_auc(labels, predictions, curve='ROC', weights=None):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
Computes the area under the ROC or PR curve using each prediction as a
@@ -1077,13 +1077,22 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
predictions: A 1-D `Tensor` of predictions whose values are `float64`.
curve: The name of the curve to be computed, 'ROC' for the Receiving
Operating Characteristic or 'PR' for the Precision-Recall curve.
+ weights: A 1-D `Tensor` of weights whose values are `float64`.
Returns:
A scalar `Tensor` containing the area-under-curve value for the input.
"""
- # Count the total number of positive and negative labels in the input.
+ # Compute the total weight and the total positive weight.
size = array_ops.size(predictions)
- total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
+ if weights is None:
+ weights = array_ops.ones_like(labels, dtype=dtypes.float64)
+ labels, predictions, weights = metrics_impl._remove_squeezable_dimensions(
+ labels, predictions, weights)
+ total_weight = math_ops.reduce_sum(weights)
+ total_positive = math_ops.reduce_sum(
+ array_ops.where(
+ math_ops.greater(labels, 0), weights,
+ array_ops.zeros_like(labels, dtype=dtypes.float64)))
def continue_computing_dynamic_auc():
"""Continues dynamic auc computation, entered if labels are not all equal.
@@ -1091,9 +1100,11 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
Returns:
A scalar `Tensor` containing the area-under-curve value.
"""
- # Sort the predictions descending, and the corresponding labels as well.
+ # Sort the predictions descending, keeping the same order for the
+ # corresponding labels and weights.
ordered_predictions, indices = nn.top_k(predictions, k=size)
ordered_labels = array_ops.gather(labels, indices)
+ ordered_weights = array_ops.gather(weights, indices)
# Get the counts of the unique ordered predictions.
_, _, counts = array_ops.unique_with_counts(ordered_predictions)
@@ -1103,23 +1114,39 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
# Count the positives to the left of the split indices.
- positives = math_ops.cast(
- array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
- dtypes.int32)
- true_positives = array_ops.gather(positives, splits)
+ true_positives = array_ops.gather(
+ array_ops.pad(
+ math_ops.cumsum(
+ array_ops.where(
+ math_ops.greater(ordered_labels, 0), ordered_weights,
+ array_ops.zeros_like(ordered_labels,
+ dtype=dtypes.float64))),
+ paddings=[[1, 0]]), splits)
if curve == 'ROC':
- # Count the negatives to the left of every split point and the total
- # number of negatives for computing the FPR.
- false_positives = math_ops.subtract(splits, true_positives)
- total_negative = size - total_positive
+ # Compute the weight of the negatives to the left of every split point and
+ # the total weight of the negatives number of negatives for computing the
+ # FPR.
+ false_positives = array_ops.gather(
+ array_ops.pad(
+ math_ops.cumsum(
+ array_ops.where(
+ math_ops.less(ordered_labels, 1), ordered_weights,
+ array_ops.zeros_like(
+ ordered_labels, dtype=dtypes.float64))),
+ paddings=[[1, 0]]), splits)
+ total_negative = total_weight - total_positive
x_axis_values = math_ops.truediv(false_positives, total_negative)
y_axis_values = math_ops.truediv(true_positives, total_positive)
elif curve == 'PR':
x_axis_values = math_ops.truediv(true_positives, total_positive)
# For conformance, set precision to 1 when the number of positive
# classifications is 0.
+ positives = array_ops.gather(
+ array_ops.pad(math_ops.cumsum(ordered_weights), paddings=[[1, 0]]),
+ splits)
y_axis_values = array_ops.where(
- math_ops.greater(splits, 0), math_ops.truediv(true_positives, splits),
+ math_ops.greater(splits, 0),
+ math_ops.truediv(true_positives, positives),
array_ops.ones_like(true_positives, dtype=dtypes.float64))
# Calculate trapezoid areas.
@@ -1133,7 +1160,7 @@ def _compute_dynamic_auc(labels, predictions, curve='ROC'):
return control_flow_ops.cond(
math_ops.logical_or(
math_ops.equal(total_positive, 0), math_ops.equal(
- total_positive, size)),
+ total_positive, total_weight)),
true_fn=lambda: array_ops.constant(0, dtypes.float64),
false_fn=continue_computing_dynamic_auc)
@@ -1143,7 +1170,8 @@ def streaming_dynamic_auc(labels,
curve='ROC',
metrics_collections=(),
updates_collections=(),
- name=None):
+ name=None,
+ weights=None):
"""Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
USAGE NOTE: this approach requires storing all of the predictions and labels
@@ -1168,6 +1196,8 @@ def streaming_dynamic_auc(labels,
should be added to.
name: An optional name for the variable_scope that contains the metric
variables.
+ weights: A 'Tensor' of non-negative weights whose values are castable to
+ `float64`. Will be flattened into a 1-D `Tensor`.
Returns:
auc: A scalar `Tensor` containing the current area-under-curve value.
@@ -1195,14 +1225,24 @@ def streaming_dynamic_auc(labels,
check_ops.assert_less_equal(
labels,
array_ops.ones_like(labels, dtypes.int64),
- message='labels must be 0 or 1, at least one is >1')
+ message='labels must be 0 or 1, at least one is >1'),
]):
preds_accum, update_preds = streaming_concat(
predictions, name='concat_preds')
labels_accum, update_labels = streaming_concat(
labels, name='concat_labels')
- update_op = control_flow_ops.group(update_labels, update_preds)
- auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
+ if weights is not None:
+ weights = array_ops.reshape(
+ math_ops.cast(weights, dtypes.float64), [-1])
+ weights_accum, update_weights = streaming_concat(
+ weights, name='concat_weights')
+ update_op = control_flow_ops.group(update_labels, update_preds,
+ update_weights)
+ else:
+ weights_accum = None
+ update_op = control_flow_ops.group(update_labels, update_preds)
+ auc = _compute_dynamic_auc(
+ labels_accum, preds_accum, curve=curve, weights=weights_accum)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
if metrics_collections:
@@ -1544,7 +1584,7 @@ def precision_recall_at_equal_thresholds(labels,
result: A named tuple (See PrecisionRecallData within the implementation of
this function) with properties that are variables of shape
`[num_thresholds]`. The names of the properties are tp, fp, tn, fn,
- precision, recall, thresholds.
+ precision, recall, thresholds. Types are same as that of predictions.
update_op: An op that accumulates values.
Raises:
@@ -1570,7 +1610,6 @@ def precision_recall_at_equal_thresholds(labels,
check_ops.assert_type(labels, dtypes.bool)
- dtype = predictions.dtype
with variable_scope.variable_scope(name,
'precision_recall_at_equal_thresholds',
(labels, predictions, weights)):
@@ -1592,11 +1631,16 @@ def precision_recall_at_equal_thresholds(labels,
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
- # We cast to float to ensure we have 0.0 or 1.0.
- f_labels = math_ops.cast(labels, dtype)
+ # It's important we aggregate using float64 since we're accumulating a lot
+ # of 1.0's for the true/false labels, and accumulating to float32 will
+ # be quite inaccurate even with just a modest amount of values (~20M).
+ # We use float64 instead of integer primarily since GPU scatter kernel
+ # only support floats.
+ agg_dtype = dtypes.float64
- # Get weighted true/false labels.
- true_labels = f_labels * weights
+ f_labels = math_ops.cast(labels, agg_dtype)
+ weights = math_ops.cast(weights, agg_dtype)
+ true_labels = f_labels * weights
false_labels = (1.0 - f_labels) * weights
# Flatten predictions and labels.
@@ -1638,9 +1682,9 @@ def precision_recall_at_equal_thresholds(labels,
with ops.name_scope('variables'):
tp_buckets_v = metrics_impl.metric_variable(
- [num_thresholds], dtype, name='tp_buckets')
+ [num_thresholds], agg_dtype, name='tp_buckets')
fp_buckets_v = metrics_impl.metric_variable(
- [num_thresholds], dtype, name='fp_buckets')
+ [num_thresholds], agg_dtype, name='fp_buckets')
with ops.name_scope('update_op'):
update_tp = state_ops.scatter_add(
@@ -1660,18 +1704,21 @@ def precision_recall_at_equal_thresholds(labels,
fn = tp[0] - tp
# We use a minimum to prevent division by 0.
- epsilon = 1e-7
+ epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype)
precision = tp / math_ops.maximum(epsilon, tp + fp)
recall = tp / math_ops.maximum(epsilon, tp + fn)
+ # Convert all tensors back to predictions' dtype (as per function contract).
+ out_dtype = predictions.dtype
+ _convert = lambda tensor: math_ops.cast(tensor, out_dtype)
result = PrecisionRecallData(
- tp=tp,
- fp=fp,
- tn=tn,
- fn=fn,
- precision=precision,
- recall=recall,
- thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds))
+ tp=_convert(tp),
+ fp=_convert(fp),
+ tn=_convert(tn),
+ fn=_convert(fn),
+ precision=_convert(precision),
+ recall=_convert(recall),
+ thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds)))
update_op = control_flow_ops.group(update_tp, update_fp)
return result, update_op
@@ -2496,7 +2543,7 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name):
name: An optional variable_scope name.
Returns:
- The recall at a the given `precision`.
+ The recall at a given `precision`.
"""
precisions = math_ops.div(tp, tp + fp + _EPSILON)
tf_index = math_ops.argmin(
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
new file mode 100644
index 0000000000..7acfc383eb
--- /dev/null
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Large tests for metric_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.contrib.metrics.python.ops import metric_ops
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
+
+ def setUp(self):
+ np.random.seed(1)
+ ops.reset_default_graph()
+
+ def testLargeCase(self):
+ shape = [32, 512, 256, 1]
+ predictions = random_ops.random_uniform(
+ shape, 0.0, 1.0, dtype=dtypes_lib.float32)
+ labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5)
+
+ result, update_op = metric_ops.precision_recall_at_equal_thresholds(
+ labels=labels, predictions=predictions, num_thresholds=201)
+ # Run many updates, enough to cause highly inaccurate values if the
+ # code used float32 for accumulation.
+ num_updates = 71
+
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_updates):
+ sess.run(update_op)
+
+ prdata = sess.run(result)
+
+ # Since we use random values, we won't know the tp/fp/tn/fn values, but
+ # tp and fp at threshold 0 should be the total number of positive and
+ # negative labels, hence their sum should be total number of pixels.
+ expected_value = 1.0 * np.product(shape) * num_updates
+ got_value = prdata.tp[0] + prdata.fp[0]
+ # They should be at least within 1.
+ self.assertNear(got_value, expected_value, 1.0)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index e6f75fcbd7..a09fc4abd4 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -2127,6 +2127,44 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+ def testWithWeights(self):
+ batch_size = 10
+ num_batches = 100
+ labels = np.array([])
+ predictions = np.array([])
+ weights = np.array([])
+ tf_labels = variables.Variable(
+ array_ops.ones(batch_size, dtypes_lib.int32),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.int32)
+ tf_predictions = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ tf_weights = variables.Variable(
+ array_ops.ones(batch_size),
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ dtype=dtypes_lib.float32)
+ auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
+ tf_predictions,
+ weights=tf_weights)
+ with self.test_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ for _ in xrange(num_batches):
+ new_labels = np.random.randint(0, 2, size=batch_size)
+ noise = np.random.uniform(-0.2, 0.2, size=batch_size)
+ new_predictions = 0.4 + 0.2 * new_labels + noise
+ new_weights = np.random.uniform(0.0, 3.0, size=batch_size)
+ labels = np.concatenate([labels, new_labels])
+ predictions = np.concatenate([predictions, new_predictions])
+ weights = np.concatenate([weights, new_weights])
+ sess.run([tf_labels.assign(new_labels),
+ tf_predictions.assign(new_predictions),
+ tf_weights.assign(new_weights)])
+ sess.run(update_op)
+ expected_auc = _np_auc(predictions, labels, weights)
+ self.assertAlmostEqual(expected_auc, auc.eval())
+
class AucWithConfidenceIntervalsTest(test.TestCase):
@@ -2333,47 +2371,24 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
np.random.seed(1)
ops.reset_default_graph()
- def _testResultsEqual(self, expected_dict, gotten_result):
+ def _testResultsEqual(self, expected_dict, gotten_result, eps=None):
"""Tests that 2 results (dicts) represent the same data.
Args:
expected_dict: A dictionary with keys that are the names of properties
of PrecisionRecallData and whose values are lists of floats.
gotten_result: A PrecisionRecallData object.
+ eps: Epsilon value to use for testing output values. If unspecified, use
+ default from assertAllClose.
"""
gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()}
self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys()))
for key, expected_values in expected_dict.items():
- self.assertAllClose(expected_values, gotten_dict[key])
-
- def _testCase(self, predictions, labels, expected_result, weights=None):
- """Performs a test given a certain scenario of labels, predictions, weights.
-
- Args:
- predictions: The predictions tensor. Of type float32.
- labels: The labels tensor. Of type bool.
- expected_result: The expected result (dict) that maps to tensors.
- weights: Optional weights tensor.
- """
- with self.test_session() as sess:
- predictions_tensor = constant_op.constant(
- predictions, dtype=dtypes_lib.float32)
- labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
- weights_tensor = None
- if weights:
- weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32)
- gotten_result, update_op = (
- metric_ops.precision_recall_at_equal_thresholds(
- labels=labels_tensor,
- predictions=predictions_tensor,
- weights=weights_tensor,
- num_thresholds=3))
-
- sess.run(variables.local_variables_initializer())
- sess.run(update_op)
-
- self._testResultsEqual(expected_result, gotten_result)
+ if eps is not None:
+ self.assertAllClose(expected_values, gotten_dict[key], atol=eps)
+ else:
+ self.assertAllClose(expected_values, gotten_dict[key])
def testVars(self):
metric_ops.precision_recall_at_equal_thresholds(
@@ -2414,6 +2429,50 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
for _ in range(3):
self._testResultsEqual(initial_result, result)
+ def _testCase(self,
+ predictions,
+ labels,
+ expected_result,
+ dtype=dtypes_lib.float32,
+ eps=None,
+ weights=None):
+ """Performs a test given a certain scenario of labels, predictions, weights.
+
+ Args:
+ predictions: The predictions tensor. Of type dtype.
+ labels: The labels tensor. Of type bool.
+ expected_result: The expected result (dict) that maps to tensors.
+ dtype: Data type to use for predictions and weights tensor. Default
+ is float32.
+ eps: Epsilon value to use for testing output values. If unspecified, use
+ default from assertAllClose.
+ weights: Optional weights tensor.
+ """
+ with self.test_session() as sess:
+ predictions_tensor = constant_op.constant(predictions, dtype=dtype)
+ labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
+ weights_tensor = None
+ if weights:
+ weights_tensor = constant_op.constant(weights, dtype=dtype)
+ gotten_result, update_op = (
+ metric_ops.precision_recall_at_equal_thresholds(
+ labels=labels_tensor,
+ predictions=predictions_tensor,
+ weights=weights_tensor,
+ num_thresholds=3))
+ self.assertEqual(gotten_result.tp.dtype, dtype)
+ self.assertEqual(gotten_result.fp.dtype, dtype)
+ self.assertEqual(gotten_result.tn.dtype, dtype)
+ self.assertEqual(gotten_result.fn.dtype, dtype)
+ self.assertEqual(gotten_result.precision.dtype, dtype)
+ self.assertEqual(gotten_result.recall.dtype, dtype)
+ self.assertEqual(gotten_result.thresholds.dtype, dtype)
+
+ sess.run(variables.local_variables_initializer())
+ sess.run(update_op)
+
+ self._testResultsEqual(expected_result, gotten_result, eps=eps)
+
def testAllTruePositives(self):
self._testCase(
[[1]], [[True]], {
@@ -2489,6 +2548,35 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
},
weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]])
+ def testFloat64(self):
+ self._testCase(
+ [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+ [[True, False, False, True, True, True]], {
+ 'tp': [4, 3, 0],
+ 'fp': [2, 0, 0],
+ 'tn': [0, 2, 2],
+ 'fn': [0, 1, 4],
+ 'precision': [2.0 / 3.0, 1.0, 0.0],
+ 'recall': [1.0, 0.75, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ },
+ dtype=dtypes_lib.float64)
+
+ def testFloat16(self):
+ self._testCase(
+ [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]],
+ [[True, False, False, True, True, True]], {
+ 'tp': [4, 3, 0],
+ 'fp': [2, 0, 0],
+ 'tn': [0, 2, 2],
+ 'fn': [0, 1, 4],
+ 'precision': [2.0 / 3.0, 1.0, 0.0],
+ 'recall': [1.0, 0.75, 0.0],
+ 'thresholds': [0.0, 0.5, 1.0],
+ },
+ dtype=dtypes_lib.float16,
+ eps=1e-3)
+
class StreamingSpecificityAtSensitivityTest(test.TestCase):
@@ -4649,199 +4737,204 @@ class StreamingSparseRecallTest(test.TestCase):
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2)
- def test_one_label_at_k1_weighted(self):
+ def _test_one_label_at_k1_weighted(self, labels):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
- sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
- [0, 0, 1, 0]])
- dense_labels = np.array([[3], [2]], dtype=np.int64)
- for labels in (sparse_labels, dense_labels):
- # Class 3: 1 label, 2 predictions, 1 correct.
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0,))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(2.0,))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(2.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=NAN,
- class_id=3,
- weights=(0.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=NAN,
- class_id=3,
- weights=(0.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=NAN,
- class_id=3,
- weights=(0.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=NAN,
- class_id=3,
- weights=(0.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=1.0 / 1,
- class_id=3,
- weights=(1.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=2.0 / 2,
- class_id=3,
- weights=(2.0, 3.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=2.0 / 2,
- class_id=3,
- weights=(2.0, 3.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=3.0 / 3,
- class_id=3,
- weights=(3.0, 2.0))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=3.0 / 3,
- class_id=3,
- weights=(3.0, 2.0))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=0.3 / 0.3,
- class_id=3,
- weights=(0.3, 0.6))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=0.3 / 0.3,
- class_id=3,
- weights=(0.3, 0.6))
- self._test_streaming_sparse_recall_at_k(
- predictions,
- labels,
- k=1,
- expected=0.6 / 0.6,
- class_id=3,
- weights=(0.6, 0.3))
- self._test_sparse_recall_at_top_k(
- labels,
- top_k_predictions,
- expected=0.6 / 0.6,
- class_id=3,
- weights=(0.6, 0.3))
+ # Class 3: 1 label, 2 predictions, 1 correct.
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(2.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=NAN,
+ class_id=3,
+ weights=(0.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=1.0 / 1,
+ class_id=3,
+ weights=(1.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=2.0 / 2,
+ class_id=3,
+ weights=(2.0, 3.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=3.0 / 3,
+ class_id=3,
+ weights=(3.0, 2.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.3 / 0.3,
+ class_id=3,
+ weights=(0.3, 0.6))
+ self._test_streaming_sparse_recall_at_k(
+ predictions,
+ labels,
+ k=1,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels,
+ top_k_predictions,
+ expected=0.6 / 0.6,
+ class_id=3,
+ weights=(0.6, 0.3))
- # All classes: 2 labels, 2 predictions, 1 correct.
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=NAN, weights=(0.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=NAN, weights=(0.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
+ # All classes: 2 labels, 2 predictions, 1 correct.
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=NAN, weights=(0.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=NAN, weights=(0.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
- self._test_streaming_sparse_recall_at_k(
- predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
- self._test_sparse_recall_at_top_k(
- labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_streaming_sparse_recall_at_k(
+ predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
+ self._test_sparse_recall_at_top_k(
+ labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
+
+ def test_one_label_at_k1_weighted_sparse_labels(self):
+ sparse_labels = _binary_2d_label_to_sparse_value([[0, 0, 0, 1],
+ [0, 0, 1, 0]])
+ self._test_one_label_at_k1_weighted(sparse_labels)
+
+ def test_one_label_at_k1_weighted_dense_labels(self):
+ dense_labels = np.array([[3], [2]], dtype=np.int64)
+ self._test_one_label_at_k1_weighted(dense_labels)
def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
index 480f5f6eaf..1b0383d24c 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_manager_test.py
@@ -34,7 +34,7 @@ def _GetExampleIter(inputs):
class FixedLossScaleManagerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_basic(self):
itr = _GetExampleIter([True] * 10 + [False] * 10)
@@ -84,13 +84,13 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
actual_outputs.append(self.evaluate(lsm.get_loss_scale()))
self.assertEqual(actual_outputs, expected_outputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_increase_every_n_steps(self):
inputs = [True] * 6
expected_outputs = [1, 2, 2, 4, 4, 8]
self._test_helper(inputs, expected_outputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_keep_increasing_until_capped(self):
init_loss_scale = np.finfo(np.float32).max / 4 + 10
max_float = np.finfo(np.float32).max
@@ -104,7 +104,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_decrease_every_n_steps(self):
inputs = [False] * 6
init_loss_scale = 1024
@@ -112,7 +112,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_keep_decreasing_until_one(self):
inputs = [False] * 10
init_loss_scale = 16
@@ -120,19 +120,19 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_incr_bad_step_clear_good_step(self):
inputs = [True, True, True, False, True]
expected_outputs = [1, 2, 2, 2, 2]
self._test_helper(inputs, expected_outputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_incr_good_step_does_not_clear_bad_step(self):
inputs = [True, True, True, False, True, False]
expected_outputs = [1, 2, 2, 2, 2, 1]
self._test_helper(inputs, expected_outputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_trigger_loss_scale_update_each_step(self):
"""Test when incr_every_n_step and decr_every_n_nan_or_inf is 1."""
init_loss_scale = 1
@@ -145,7 +145,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale,
incr_every_n_step, decr_every_n_nan_or_inf)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_alternating_good_and_bad_gradients_trigger_each_step(self):
init_loss_scale = 1
incr_every_n_step = 1
@@ -156,7 +156,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale,
incr_every_n_step, decr_every_n_nan_or_inf)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_alternating_good_and_bad_gradients_trigger_incr_every_2steps(self):
init_loss_scale = 32
incr_every_n_step = 2
@@ -167,7 +167,7 @@ class ExponentialUpdateLossScaleManagerTest(test.TestCase):
self._test_helper(inputs, expected_outputs, init_loss_scale,
incr_every_n_step, decr_every_n_nan_or_inf)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_random_mix_good_and_bad_gradients(self):
init_loss_scale = 4
inputs = [
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index e4e5ccc334..93050a3ae3 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer
class LossScaleOptimizer(optimizer.Optimizer):
+ # TODO(jamesqin): move mixed precision training explanation to __init__
+ # docstring.
"""An optimizer that applies loss scaling in backprop.
- This class is useful for mixed precision training on GPUs (or other potential
- accelerators), which is an approach to improve compute throughput without loss
- of model quality.
-
- The commmon configuration of mixed precision models is the following:
- * variables are kept in high precision (e.g. float32).
- * computations are done in lower precision (e.g. float16). variables are
- casted to lower precision before they're used.
- * (in training), final gradients are casted back to variable precision and get
- applied.
-
- Because computations happen in lower precision, gradients in the backprop pass
- might underflow in the smaller dynamic range, causing a model to converge at a
- suboptimal level. This optimizer multiplies the loss by a factor before
- backprop starts to prevent underflow. Before gradients are applied, they are
- casted to higher precision and down-scaled by the same factor, so
- mathematically the variable updates are no different from regular
- same-precision training.
+ This class is useful for "mixed precision training" on GPUs (or other
+ potential accelerators), an approach to improve compute throughput without
+ compromising model quality.
+
+ The canonical way to perform mixed precision training is the following:
+ * Model variables are kept in high precision (e.g. float32).
+ * Computations are done in lower precision (e.g. float16), which enjoys
+ performance speedup by virtue of hardware support. Variables are casted to
+ lower precision before they're used.
+ * Final gradients are casted back to high precision dtype, then used to update
+ variables.
+
+ The side-effect of performing computation in lower precision, is that it comes
+ with smaller numerical range. During backproping, small gradients might
+ underflow in the reduced numerical range, causing a model to converge at
+ suboptimal level.
+
+ To prevent underflow, this optimizer multiplies the loss by a factor before
+ backprop starts. Consequently, the gradients are linearly scaled up by the
+ same factor, thus not falling into the underflow zone. After that, to perserve
+ the correctness of backprop, the gradients are down-scaled by the same factor,
+ casted to the (higher) variable precision, then applied on the variables.
See [Nvidia's manual on mixed precision training](
https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
@@ -71,7 +77,7 @@ class LossScaleOptimizer(optimizer.Optimizer):
If gradients clipping is applied, one can call
`optimizer.compute_gradients()` and `optimizer.apply_gradients()`
- seperately.
+ separately.
Notice the following way of using LossScaleOptimizer is not intended. Always
use `loss_scale_optimizer.compute_gradients()` to compute gradients instead of
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
index dded61ccd5..9009df0eef 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer_test.py
@@ -54,7 +54,7 @@ class LossScaleOptimizerTest(test.TestCase):
opt = loss_scale_opt_fn(opt)
return x, loss, opt
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_float16_underflow_without_loss_scale(self):
lr = 1
init_val = 1.
@@ -73,7 +73,7 @@ class LossScaleOptimizerTest(test.TestCase):
rtol=0,
atol=min(symbolic_update, 1e-6))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_float16_with_loss_scale(self):
lr = 1.
init_val = 1.
@@ -95,7 +95,7 @@ class LossScaleOptimizerTest(test.TestCase):
rtol=0,
atol=min(expected_update, 1e-6))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_compute_gradients_with_loss_scale(self):
lr = 1
init_val = 1.
@@ -115,7 +115,7 @@ class LossScaleOptimizerTest(test.TestCase):
# Gradients aren't applied.
self.assertAllClose(init_val, self.evaluate(x), rtol=0, atol=1e-6)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_compute_gradients_without_loss_scale(self):
lr = 1
init_val = 1.
@@ -127,7 +127,7 @@ class LossScaleOptimizerTest(test.TestCase):
g_v = self.evaluate(grads_and_vars[0][0])
self.assertAllClose(g_v, 0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_apply_gradients(self):
x = variable_scope.get_variable("x", initializer=1., dtype=dtypes.float32)
@@ -155,7 +155,7 @@ class LossScaleOptimizerTest(test.TestCase):
actual_output.append(self.evaluate(x))
self.assertAllClose(expected_output, actual_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_apply_gradients_loss_scale_is_updated(self):
class SimpleLossScaleManager(lsm_lib.LossScaleManager):
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
index a7be92a35e..ecac06354d 100644
--- a/tensorflow/contrib/mpi_collectives/BUILD
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -52,6 +52,7 @@ tf_custom_op_library(
deps = [
":mpi_defines",
":mpi_message_proto_cc",
+ "//tensorflow/stream_executor:stream_executor_headers_lib",
"//third_party/mpi",
],
)
diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
index ed22ee667f..e4b0c2c654 100644
--- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
+++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
@@ -73,7 +73,7 @@ limitations under the License.
*/
template <class T>
-using StatusOr = se::port::StatusOr<T>;
+using StatusOr = stream_executor::port::StatusOr<T>;
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py
new file mode 100644
index 0000000000..bd7096d9ce
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.py
@@ -0,0 +1,163 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Inter-process communication using MPI."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.platform import tf_logging as logging
+
+
+def _load_library(name, op_list=None):
+ """Loads a .so file containing the specified operators.
+
+ Args:
+ name: The name of the .so file to load.
+ op_list: A list of names of operators that the library should have. If None
+ then the .so file's contents will not be verified.
+
+ Raises:
+ NameError if one of the required ops is missing.
+ """
+ try:
+ filename = resource_loader.get_path_to_datafile(name)
+ library = load_library.load_op_library(filename)
+ for expected_op in (op_list or []):
+ for lib_op in library.OP_LIST.op:
+ if lib_op.name == expected_op:
+ break
+ else:
+ raise NameError('Could not find operator %s in dynamic library %s' %
+ (expected_op, name))
+ return library
+ except errors.NotFoundError:
+ logging.warning('%s file could not be loaded.', name)
+
+
+MPI_LIB = _load_library(
+ 'mpi_collectives.so',
+ ['MPISize', 'MPIRank', 'MPILocalRank', 'MPIAllgather', 'MPIAllreduce'])
+
+
+def size(name=None):
+ """An op which returns the number of MPI processes.
+
+ This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the
+ size of the global communicator.
+
+ Returns:
+ An integer scalar containing the number of MPI processes.
+ """
+ return MPI_LIB.mpi_size(name=name)
+
+
+ops.NotDifferentiable('MPISize')
+
+
+def rank(name=None):
+ """An op which returns the MPI rank of the calling process.
+
+ This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the
+ rank of the current process in the global communicator.
+
+ Returns:
+ An integer scalar with the MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_rank(name=name)
+
+
+ops.NotDifferentiable('MPIRank')
+
+
+def init(name=None):
+ """An op which initializes MPI on the device on which it is run.
+
+ All future MPI ops must be run on the same device that the `init` op was run
+ on.
+ """
+ return MPI_LIB.mpi_init(name=name)
+
+
+ops.NotDifferentiable('MPIInit')
+
+
+def local_rank(name=None):
+ """An op which returns the local MPI rank of the calling process, within the
+ node that it is running on. For example, if there are seven processes running
+ on a node, their local ranks will be zero through six, inclusive.
+
+ This is equivalent to running `MPI_Comm_rank(...)` on a new communicator
+ which only includes processes on the same node.
+
+ Returns:
+ An integer scalar with the local MPI rank of the calling process.
+ """
+ return MPI_LIB.mpi_local_rank(name=name)
+
+
+ops.NotDifferentiable('MPILocalRank')
+
+
+def _allreduce(tensor, name=None):
+ """An op which sums an input tensor over all the MPI processes.
+
+ The reduction operation is keyed by the name of the op. The tensor type and
+ shape must be the same on all MPI processes for a given name. The reduction
+ will not start until all processes are ready to send and receive the tensor.
+
+ Returns:
+ A tensor of the same shape and type as `tensor`, summed across all
+ processes.
+ """
+ return MPI_LIB.mpi_allreduce(tensor, name=name)
+
+
+ops.NotDifferentiable('MPIAllreduce')
+
+
+def allgather(tensor, name=None):
+ """An op which concatenates the input tensor with the same input tensor on
+ all other MPI processes.
+
+ The concatenation is done on the first dimension, so the input tensors on the
+ different processes must have the same rank and shape, except for the first
+ dimension, which is allowed to be different.
+
+ Returns:
+ A tensor of the same type as `tensor`, concatenated on dimension zero
+ across all processes. The shape is identical to the input shape, except for
+ the first dimension, which may be greater and is the sum of all first
+ dimensions of the tensors in different MPI processes.
+ """
+ # Specify that first allgather is to collect the tensor gather sizes,
+ # indicated by passing in a scalar (0-D tensor) of value 0
+ sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const')
+ my_size = tf.slice(
+ tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice')
+ if name is None:
+ name = 'allgather'
+ sizing_name = '{}_sizing'.format(name)
+ sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name)
+ return MPI_LIB.mpi_allgather(tensor, sizes, name=name)
+
+
+ops.NotDifferentiable('MPIAllgather')
diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc
new file mode 100644
index 0000000000..d93233eb21
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+extern template MPI_Datatype MPIType<float>();
+extern template MPI_Datatype MPIType<int>();
+extern template MPI_Datatype MPIType<long long>();
+extern template DataType TensorFlowDataType<float>();
+extern template DataType TensorFlowDataType<int>();
+extern template DataType TensorFlowDataType<long long>();
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<CPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<CPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Copy data on a CPU using a straight-forward memcpy.
+template <>
+void CopyTensorData<CPUDevice>(void* dst, void* src, size_t size) {
+ std::memcpy(dst, src, size);
+};
+
+// Accumulate values on a CPU.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<CPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ for (unsigned int i = 0; i < size; i++) { \
+ dst[i] += src[i]; \
+ } \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc
new file mode 100644
index 0000000000..2f3eef366a
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.cu.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+template <>
+MPI_Datatype MPIType<float>() {
+ return MPI_FLOAT;
+};
+template <>
+MPI_Datatype MPIType<int>() {
+ return MPI_INT;
+};
+template <>
+MPI_Datatype MPIType<long long>() {
+ return MPI_LONG_LONG;
+};
+
+template <>
+DataType TensorFlowDataType<float>() {
+ return DT_FLOAT;
+};
+template <>
+DataType TensorFlowDataType<int>() {
+ return DT_INT32;
+};
+template <>
+DataType TensorFlowDataType<long long>() {
+ return DT_INT64;
+};
+
+// Generate all necessary specializations for RingAllreduce.
+template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ Tensor*, Tensor*);
+
+// Generate all necessary specializations for RingAllgather.
+template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+
+// Synchronously copy data on the GPU, using a different stream than the default
+// and than TensorFlow to avoid synchronizing on operations unrelated to the
+// allreduce.
+template <>
+void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
+ auto stream = CudaStreamForMPI();
+ cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
+ cudaStreamSynchronize(stream);
+};
+
+// Elementwise accumulation kernel for GPU.
+template <typename T>
+__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+ i += blockDim.x * gridDim.x) {
+ out[i] += in[i];
+ }
+}
+
+// Synchronously accumulate tensors on the GPU, using a different stream than
+// the default and than TensorFlow to avoid synchronizing on operations
+// unrelated to the allreduce.
+#define GENERATE_ACCUMULATE(type) \
+ template <> \
+ void AccumulateTensorData<GPUDevice, type>(type * dst, type * src, \
+ size_t size) { \
+ auto stream = CudaStreamForMPI(); \
+ elemwise_accum<type><<<32, 256, 0, stream>>>(dst, src, size); \
+ cudaStreamSynchronize(stream); \
+ };
+GENERATE_ACCUMULATE(int);
+GENERATE_ACCUMULATE(long long);
+GENERATE_ACCUMULATE(float);
+#undef GENERATE_ACCUMULATE
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h
new file mode 100644
index 0000000000..cae57ce60e
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ring.h
@@ -0,0 +1,327 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_MPI_H_
+#define TENSORFLOW_CONTRIB_MPI_H_
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#if GOOGLE_CUDA
+#include "cuda_runtime.h"
+#endif
+
+// Needed to avoid header issues with C++-supporting MPI implementations
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+
+#define TAG_TENSOR 12
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+// Convert from templated types to values we can pass to MPI.
+template <typename T>
+MPI_Datatype MPIType();
+
+// Convert from templated types to TensorFlow data types.
+template <typename T>
+DataType TensorFlowDataType();
+
+#define MPI_REQUIRES_OK(MPI_STATUS) \
+ if ((MPI_STATUS) != MPI_SUCCESS) { \
+ return errors::Unknown("MPI operation failed unexpectedly."); \
+ }
+
+// Copy data from one tensor to another tensor.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device>
+void CopyTensorData(void* destination, void* source, size_t size);
+
+// Add a tensor into another tensor, accumulating in place.
+// This uses a custom CUDA stream on GPU, which is necessary to overlay the
+// backpropagation computations with the allreduce.
+template <typename Device, typename T>
+void AccumulateTensorData(T* destination, T* source, size_t size);
+
+// We need to get the right stream for doing CUDA memory transfers and
+// operations, which is possibly different from the standard TensorFlow stream.
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI();
+#endif
+
+/* Perform a ring allreduce on the data. Allocate the necessary output tensor
+ * and store it in the output parameter.
+ *
+ * Assumes that all MPI processes are doing an allreduce of the same tensor,
+ * with the same dimensions.
+ *
+ * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the
+ * allreduce, the nodes involved are arranged in a ring:
+ *
+ * .--0--.
+ * / \
+ * 3 1
+ * \ /
+ * *--2--*
+ *
+ * Each node always sends to the next clockwise node in the ring, and receives
+ * from the previous one.
+ *
+ * The allreduce is done in two parts: a scatter-reduce and an allgather. In
+ * the scatter reduce, a reduction is done, so that each node ends up with a
+ * chunk of the final output tensor which has contributions from all other
+ * nodes. In the allgather, those chunks are distributed among all the nodes,
+ * so that all nodes have the entire output tensor.
+ *
+ * Both of these operations are done by dividing the input tensor into N
+ * evenly sized chunks (where N is the number of nodes in the ring).
+ *
+ * The scatter-reduce is done in N-1 steps. In the ith step, node j will send
+ * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to
+ * its existing data for that chunk. For example, in the first iteration with
+ * the ring depicted above, you will have the following transfers:
+ *
+ * Segment 0: Node 0 --> Node 1
+ * Segment 1: Node 1 --> Node 2
+ * Segment 2: Node 2 --> Node 3
+ * Segment 3: Node 3 --> Node 0
+ *
+ * In the second iteration, you'll have the following transfers:
+ *
+ * Segment 0: Node 1 --> Node 2
+ * Segment 1: Node 2 --> Node 3
+ * Segment 2: Node 3 --> Node 0
+ * Segment 3: Node 0 --> Node 1
+ *
+ * After this iteration, Node 2 has 3 of the four contributions to Segment 0.
+ * The last iteration has the following transfers:
+ *
+ * Segment 0: Node 2 --> Node 3
+ * Segment 1: Node 3 --> Node 0
+ * Segment 2: Node 0 --> Node 1
+ * Segment 3: Node 1 --> Node 2
+ *
+ * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0
+ * has the fully accumulated Segment 1; and so on. The scatter-reduce is
+ * complete.
+ *
+ * Next, the allgather distributes these fully accumululated chunks across all
+ * nodes. Communication proceeds in the same ring, once again in N-1 steps. At
+ * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i).
+ * For example, at the first iteration, the following transfers will occur:
+ *
+ * Segment 0: Node 3 --> Node 0
+ * Segment 1: Node 0 --> Node 1
+ * Segment 2: Node 1 --> Node 2
+ * Segment 3: Node 2 --> Node 3
+ *
+ * After the first iteration, Node 0 will have a fully accumulated Segment 0
+ * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its
+ * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3.
+ * After this has continued for N - 1 iterations, all nodes will have a the
+ * fully accumulated tensor.
+ *
+ * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the
+ * allgather. Each send will contain K / N bytes, if there are K bytes in the
+ * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N
+ * bytes of data, and the performance of the allreduce (assuming no latency in
+ * connections) is constrained by the slowest interconnect between the nodes.
+ *
+ */
+template <typename Device, typename T>
+Status RingAllreduce(OpKernelContext* context, const Tensor* input,
+ Tensor* temp, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ T* buffer = (T*)output->tensor_data().data();
+
+ CopyTensorData<Device>((void*)buffer, (void*)input->tensor_data().data(),
+ output->tensor_data().size());
+
+ // Calculate segment sizes and segment ends
+ const size_t elements_to_reduce = input->NumElements();
+ const size_t segment_size = elements_to_reduce / n;
+ std::vector<size_t> segment_sizes(n, segment_size);
+
+ const size_t residual = elements_to_reduce % n;
+ for (size_t i = 0; i < residual; ++i) {
+ segment_sizes[i]++;
+ }
+
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (size_t i = 1; i < segment_starts.size(); ++i) {
+ segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1];
+ }
+
+ assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce);
+
+ T* segment_recv = (T*)temp->tensor_data().data();
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ MPI_Status recv_status;
+ MPI_Request recv_req;
+
+ // Now start ring. At every step, for every rank, we iterate through
+ // segments with wraparound and send and recv from our neighbors and reduce
+ // locally. At the i'th iteration, rank r, sends segment (r-i) and receives
+ // segment (r-i-1).
+ for (int i = 0; i < n - 1; i++) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id],
+ MPIType<T>(), recv_from, TAG_TENSOR,
+ MPI_COMM_WORLD, &recv_req));
+
+ MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id],
+ MPIType<T>(), send_to, TAG_TENSOR,
+ MPI_COMM_WORLD));
+
+ T* segment_update = &(buffer[segment_starts[recv_seg_id]]);
+
+ // Wait for recv to complete before reduction
+ MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status));
+
+ const size_t recv_seg_size = segment_sizes[recv_seg_id];
+ AccumulateTensorData<Device, T>(segment_update, segment_recv,
+ recv_seg_size);
+ }
+
+ // Now start pipelined ring allgather. At every step, for every rank, we
+ // iterate through segments with wraparound and send and recv from our
+ // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and
+ // receives segment (r-i).
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i + 1) + n) % n;
+ const size_t recv_seg_id = ((r - i) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i+1)
+ T* segment_send = &(buffer[segment_starts[send_seg_id]]);
+
+ // Segment to recv - at every iteration we receive segment (r-i)
+ T* segment_recv = &(buffer[segment_starts[recv_seg_id]]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, segment_sizes[send_seg_id], MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+// Perform a ring allgather on a Tensor. Other ranks may allgather with a
+// tensor which differs in the first dimension only; all other dimensions must
+// be the same.
+//
+// For more information on the ring allgather, read the documentation for the
+// ring allreduce, which includes a ring allgather.
+template <typename Device, typename T>
+Status RingAllgather(OpKernelContext* context, const Tensor* input,
+ const std::vector<size_t>& sizes, Tensor* output) {
+ // Acquire MPI size and rank
+ int n, r;
+ MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n));
+ MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r));
+
+ assert(sizes.size() == n);
+ assert(input->dim_size(0) == sizes[r]);
+
+ // Compute number of elements in every "row". We can't compute number of
+ // elements in every chunks, because those chunks are variable length.
+ size_t elements_per_row = 1;
+ for (int i = 1; i < input->shape().dims(); i++) {
+ elements_per_row *= input->dim_size(i);
+ }
+
+ // Copy data from input tensor to correct place in output tensor.
+ std::vector<size_t> segment_starts(n);
+ segment_starts[0] = 0;
+ for (int i = 1; i < n; i++) {
+ segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1];
+ }
+ size_t offset = segment_starts[r];
+
+ // Copy data to the right offset for this rank.
+ T* buffer = (T*)output->tensor_data().data();
+ CopyTensorData<Device>((void*)(buffer + offset),
+ (void*)input->tensor_data().data(),
+ elements_per_row * sizes[r] * sizeof(T));
+
+ // Receive from your left neighbor with wrap-around
+ const size_t recv_from = ((r - 1) + n) % n;
+
+ // Send to your right neighbor with wrap-around
+ const size_t send_to = (r + 1) % n;
+
+ // Perform a ring allgather. At every step, for every rank, we iterate
+ // through segments with wraparound and send and recv from our neighbors.
+ // At the i'th iteration, rank r, sends segment (r-i) and receives segment
+ // (r-1-i).
+ MPI_Status recv_status;
+ for (size_t i = 0; i < n - 1; ++i) {
+ const size_t send_seg_id = ((r - i) + n) % n;
+ const size_t recv_seg_id = ((r - i - 1) + n) % n;
+
+ // Segment to send - at every iteration we send segment (r-i)
+ size_t offset_send = segment_starts[send_seg_id];
+ size_t rows_send = sizes[send_seg_id];
+ T* segment_send = &(buffer[offset_send]);
+
+ // Segment to recv - at every iteration we receive segment (r-1-i)
+ size_t offset_recv = segment_starts[recv_seg_id];
+ size_t rows_recv = sizes[recv_seg_id];
+ T* segment_recv = &(buffer[offset_recv]);
+
+ MPI_REQUIRES_OK(MPI_Sendrecv(
+ segment_send, elements_per_row * rows_send, MPIType<T>(), send_to,
+ TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType<T>(),
+ recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status));
+ }
+
+ return Status::OK();
+}
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
+
+#undef TENSORFLOW_CONTRIB_MPI_H_
+#endif // TENSORFLOW_CONTRIB_MPI_H_
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 334e70318d..62996d1fd8 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -19,17 +19,18 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "if_not_windows_cuda")
tf_custom_op_library(
name = "python/ops/_nccl_ops.so",
srcs = [
"ops/nccl_ops.cc",
],
- gpu_srcs = [
+ gpu_srcs = if_not_windows_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
- ],
+ ]),
deps = if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
@@ -97,18 +98,19 @@ tf_gen_op_wrapper_py(
deps = [":nccl_ops_op_lib"],
)
+# Test only nccl ops lib without dso to test behavior when NCCL lib is not
+# installed. See nccl_dependency_test for more details.
+#
+# Users should use the public nccl_py lib that also adds the dso.
tf_custom_op_py_library(
- name = "nccl_py",
+ name = "nccl_ops_lib_without_dso",
srcs = [
"__init__.py",
"python/ops/nccl_ops.py",
],
- dso = [":python/ops/_nccl_ops.so"],
kernels = if_cuda([":nccl_kernels"]) + [
":nccl_ops_op_lib",
],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
deps = [
":nccl_ops",
"//tensorflow/contrib/util:util_py",
@@ -120,6 +122,15 @@ tf_custom_op_py_library(
],
)
+tf_custom_op_py_library(
+ name = "nccl_py",
+ dso = [":python/ops/_nccl_ops.so"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":nccl_ops_lib_without_dso",
+ ],
+)
+
cuda_py_test(
name = "nccl_ops_test",
size = "small",
@@ -141,3 +152,25 @@ cuda_py_test(
"notap",
],
)
+
+cuda_py_test(
+ name = "nccl_dependency_test",
+ size = "small",
+ srcs = ["python/ops/nccl_dependency_test.py"],
+ additional_deps = [
+ ":nccl_ops_lib_without_dso",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform_test",
+ ],
+ # Disable this test internally as static linking is used internally and only
+ # run for OSS to verify that NCCL is an optional dynamic dependency.
+ tags = [
+ "manual",
+ "noguitar",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py
new file mode 100644
index 0000000000..c766080dbe
--- /dev/null
+++ b/tensorflow/contrib/nccl/python/ops/nccl_dependency_test.py
@@ -0,0 +1,59 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Dependency test for nccl to test behavior when NCCL is not installed."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib import nccl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
+class NcclDependencyTest(test.TestCase):
+ """Verifies that importing nccl ops lib does not fail even if NCCL is not
+ installed but nccl ops throws an exception on use if NCCL is not installed.
+ """
+
+ def test_nccl_ops(self):
+ """Tests behavior of nccl ops when NCCL is not installed."""
+
+ public_methods = [
+ m[0]
+ for m in tf_inspect.getmembers(nccl, tf_inspect.isfunction)
+ if not m[0].startswith('_')
+ ]
+ for method_name in public_methods:
+ with ops.device('/device:CPU:0'):
+ tensor = constant_op.constant(1)
+
+ if method_name == 'broadcast':
+ arg = tensor
+ else:
+ arg = [tensor]
+
+ nccl_op = getattr(nccl, method_name)
+ with ops.device('/device:CPU:0'):
+ with self.assertRaisesRegexp(errors_impl.NotFoundError,
+ r'cannot open shared object file'):
+ nccl_op(arg)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
index 794372a1f4..fa597cf3ef 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py
@@ -26,8 +26,10 @@ from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
-_nccl_ops_so = loader.load_op_library(
- resource_loader.get_path_to_datafile('_nccl_ops.so'))
+
+_nccl_ops_so = None
+_module_lock = threading.Lock()
+_shared_name_counter = 0
def all_sum(tensors):
@@ -61,12 +63,12 @@ def _all_sum_grad(op, grad):
Raises:
LookupError: If `reduction` is not `sum`.
"""
- if op.get_attr('reduction') != 'sum':
+ if op.get_attr('reduction') != b'sum':
raise LookupError('No gradient defined for NcclAllReduce except sum.')
_check_device(grad, expected=op.device)
num_devices = op.get_attr('num_devices')
- shared_name = op.get_attr('shared_name') + '_grad'
+ shared_name = op.get_attr('shared_name') + b'_grad'
with ops.device(op.device):
return gen_nccl_ops.nccl_all_reduce(
@@ -160,7 +162,7 @@ def _reduce_sum_grad(op, grad):
Raises:
LookupError: If the reduction attribute of op is not `sum`.
"""
- if op.get_attr('reduction') != 'sum':
+ if op.get_attr('reduction') != b'sum':
raise LookupError('No gradient defined for NcclReduce except sum.')
_check_device(grad, expected=op.device)
@@ -180,7 +182,7 @@ def broadcast(tensor):
A tensor with the value of `src_tensor`, which can be used as input to
ops on other GPU devices.
"""
- _check_graph_mode()
+ _validate_and_load_nccl_so()
_check_device(tensor)
with ops.device(tensor.device):
@@ -212,7 +214,7 @@ def _apply_all_reduce(reduction, tensors):
"""Helper function for all_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to all reduce operations')
- _check_graph_mode()
+ _validate_and_load_nccl_so()
shared_name = _get_shared_name()
res = []
@@ -234,7 +236,7 @@ def _apply_reduce(reduction, tensors):
"""Helper function for reduce_* functions."""
if not tensors:
raise ValueError('Must pass >0 tensors to reduce operations')
- _check_graph_mode()
+ _validate_and_load_nccl_so()
for t in tensors:
_check_device(t)
@@ -246,14 +248,10 @@ def _apply_reduce(reduction, tensors):
return result
-_lock = threading.Lock()
-_shared_name_counter = 0
-
-
def _get_shared_name():
global _shared_name_counter
- with _lock:
+ with _module_lock:
val = _shared_name_counter
_shared_name_counter += 1
return 'c%s' % val
@@ -266,6 +264,25 @@ def _check_device(tensor, expected=None):
raise ValueError('Expected device %s, got %s' % (expected, tensor.device))
-def _check_graph_mode():
+def _maybe_load_nccl_ops_so():
+ """Loads nccl ops so if it hasn't been loaded already."""
+
+ with _module_lock:
+ global _nccl_ops_so
+ if not _nccl_ops_so:
+ _nccl_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile('_nccl_ops.so'))
+
+
+def _validate_and_load_nccl_so():
+ """Validates calling context and loads nccl ops so file.
+
+ Raises:
+ ValueError: Ops are not supported.
+ errors_impl.NotFoundError: nccl library is not installed.
+ """
+
if context.executing_eagerly():
raise ValueError('Nccl ops are not supported in eager mode')
+
+ _maybe_load_nccl_ops_so()
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 13aa1d7e7a..bbdf962d04 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -19,6 +19,7 @@ py_library(
"python/training/drop_stale_gradient_optimizer.py",
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
+ "python/training/ggt.py",
"python/training/lazy_adam_optimizer.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
@@ -28,15 +29,19 @@ py_library(
"python/training/reg_adagrad_optimizer.py",
"python/training/sign_decay.py",
"python/training/variable_clipping_optimizer.py",
+ "python/training/weight_decay_optimizers.py",
],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/optimizer_v2:optimizer_v2_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
@@ -194,6 +199,25 @@ py_test(
],
)
+py_test(
+ name = "weight_decay_optimizers_test",
+ srcs = ["python/training/weight_decay_optimizers_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
tf_py_test(
name = "drop_stale_gradient_optimizer_test",
srcs = ["python/training/drop_stale_gradient_optimizer_test.py"],
@@ -302,3 +326,21 @@ py_test(
"//third_party/py/numpy",
],
)
+
+py_test(
+ name = "ggt_test",
+ srcs = ["python/training/ggt_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 4c13c8e247..3e63e99030 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -22,15 +22,18 @@ from __future__ import print_function
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
+from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
+from tensorflow.contrib.opt.python.training.ggt import *
from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import *
+from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
+from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import *
-from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
-from tensorflow.contrib.opt.python.training.model_average_optimizer import *
+from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -46,6 +49,10 @@ _allowed_symbols = [
'LazyAdamOptimizer',
'NadamOptimizer',
'MovingAverageOptimizer',
+ 'MomentumWOptimizer',
+ 'AdamWOptimizer',
+ 'DecoupledWeightDecayExtension',
+ 'extend_with_decoupled_weight_decay',
'ScipyOptimizerInterface',
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
@@ -53,7 +60,8 @@ _allowed_symbols = [
'ElasticAverageOptimizer',
'ElasticAverageCustomGetter',
'ModelAverageOptimizer',
- 'ModelAverageCustomGetter'
+ 'ModelAverageCustomGetter',
+ 'GGTOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/addsign_test.py b/tensorflow/contrib/opt/python/training/addsign_test.py
index 08d45ed73f..628a735e72 100644
--- a/tensorflow/contrib/opt/python/training/addsign_test.py
+++ b/tensorflow/contrib/opt/python/training/addsign_test.py
@@ -214,7 +214,7 @@ class AddSignTest(test.TestCase):
# Run 7 steps of AddSign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
- for t in range(1, 4):
+ for t in range(1, 8):
if t < 5:
update.run()
else:
@@ -222,7 +222,7 @@ class AddSignTest(test.TestCase):
var0_np, m0 = addsign_update_numpy(
var0_np,
- grads0_np,
+ grads0_np if t < 5 else -grads0_np,
m0,
learning_rate,
alpha=alpha,
@@ -232,7 +232,7 @@ class AddSignTest(test.TestCase):
)
var1_np, m1 = addsign_update_numpy(
var1_np,
- grads1_np,
+ grads1_np if t < 5 else -grads1_np,
m1,
learning_rate,
alpha=alpha,
diff --git a/tensorflow/contrib/opt/python/training/ggt.py b/tensorflow/contrib/opt/python/training/ggt.py
new file mode 100644
index 0000000000..cae952d8f5
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt.py
@@ -0,0 +1,312 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GGT for Tensorflow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import numpy as np
+from tensorflow.contrib.optimizer_v2 import optimizer_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+
+
+class GGTOptimizer(optimizer_v2.OptimizerV2):
+ """Optimizer that implements the GGT algorithm.
+
+ GGT has an advantage over sgd and adam on large models with poor conditioning,
+ for example language models and CNNs,
+ see [[ABCHSZZ 2018]](https://arxiv.org/pdf/1806.02958.pdf).
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta1=0.9,
+ use_locking=False,
+ name="GGT",
+ window=10,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Construct a new GGT optimizer.
+
+ Initialization:
+
+ ```
+ t <- 0 (Initialize timestep)
+ grad_buffer <- 0 (Initialize buffer for keeping past gradients)
+ flat_grad <- 0 (Initialize flattened gradient that contains gradients of all
+ variables)
+ m_0 <- 0 (Initialize 1st moment vector)
+ ```
+
+ Suppose all variables and their gradients are concatenated into vectors
+ `flat_vars` and `flat_grad`. The update rule for `flat_vars`
+ uses an optimization described at the beginning of section 2 of the paper:
+
+ ```
+ t <- t + 1
+
+ m_t <- beta1 * m_{t-1} + (1 - beta1) * flat_grad
+ grad_buffer[(t-1) % window, :] <- m_t
+
+ M <- grad_buffer^T / sqrt(min(t, window))
+ U, sigma, _ <- SVD(M^TM + I * svd_eps)
+
+ sigma_sqrt_inv <- (sqrt(sigma) + sigma_eps)^(-3)
+ sigma_sqrt_min <- min(sqrt(sigma))
+
+ if sigma_sqrt_min > eps:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t +
+ (m_t - M U diag(1/sigma) U^T M^T m_t) / sigma_sqrt_min
+ else:
+ new_step <- M U diag(sigma_sqrt_inv) U^T M^T m_t
+
+ flat_vars <- flat_vars - learning_rate * new_step
+ ```
+
+ GGT provides the power of full-matrix adaptive regularization at a cost not
+ much larger than SGD. As a result it is suited for large models where the
+ gradient covariance matrix has a poor condition number that slows down first
+ order methods.
+ GGT uses the preconditioner from full-matrix AdaGrad, with gradient history
+ attenuated exponentially as in Adam, and truncated to a window parameter.
+ It has provable guarantees even for non-convex optimization that is never
+ significantly worse than SGD and in some cases better.
+
+ Args:
+ learning_rate: A float hyperparameter. The learning rate.
+ beta1: A float hyperparameter. The exponential decay rate for the 1st
+ moment estimates.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "GGT".
+ window: An integer hyperparameter. The number of first moments to keep in
+ computing the adaptive preconditioner.
+ eps: A float hyperparameter. Used to truncate small eigenvalues of the
+ gradient covariance matrix.
+ svd_eps: A float hyperparameter. Used to stabilize SVD.
+ sigma_eps: A float hyperparameter. Used to regularize matrix inversion.
+ """
+ super(GGTOptimizer, self).__init__(use_locking, name)
+ self._set_hyper("lr", learning_rate)
+ self._set_hyper("beta1", beta1)
+ self._set_hyper("window", window)
+ self._set_hyper("eps", eps)
+ self._set_hyper("svd_eps", svd_eps)
+ self._set_hyper("sigma_eps", sigma_eps)
+
+ self.index_dict = {}
+ self.shape_dict = {}
+
+ def _create_vars(self, var_list, state):
+ # Construct ordered dictionary for variable dimensions, sorted by name.
+ shape_dict = {}
+ for v in var_list:
+ shape_dict[v.name] = np.prod(v.get_shape()).value
+ self.shape_dict = collections.OrderedDict(
+ sorted(shape_dict.items(), key=lambda t: t[0]))
+
+ # Assign each variable its location in flat_grad. The locations are based on
+ # the order of sorted names.
+ idx = 0
+ for v_name, v_dim in self.shape_dict.items():
+ self.index_dict[v_name] = idx
+ idx += v_dim
+
+ state.create_non_slot(
+ initial_value=math_ops.cast(0., dtype=var_list[0].dtype.base_dtype),
+ name="global_step")
+
+ # Buffer for keeping past gradients.
+ window = state.get_hyper("window")
+ grad_buffer_init = array_ops.zeros(
+ [window, idx], dtype=var_list[0].dtype.base_dtype)
+ state.create_non_slot(initial_value=grad_buffer_init, name="grad_buffer")
+
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="moment1")
+
+ # Flattened gradient that contains gradients for all variables in the model.
+ state.create_non_slot(
+ initial_value=array_ops.zeros(
+ (idx,), dtype=var_list[0].dtype.base_dtype),
+ name="flat_grad")
+
+ def _get_global_step(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("global_step")
+
+ def _get_moment1(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("moment1")
+
+ def _get_grad_buffer(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("grad_buffer")
+
+ def _get_flat_grad(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return state.get_non_slot("flat_grad")
+
+ def _apply_sparse(self, grad, var):
+ raise NotImplementedError("Sparse gradient updates are not supported.")
+
+ def _prepare(self, state):
+ self._variables = []
+
+ def _apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _resource_apply_dense(self, grad, var, state):
+ self._variables.append(var)
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+
+ # Update flat_gradient at the index associated with the variable.
+ flat_grad = self._get_flat_grad(state)
+ new_flat_grad = array_ops.reshape(grad, [-1])
+ flat_grad_updated = state_ops.scatter_update(
+ flat_grad, math_ops.range(start_index, end_index), new_flat_grad)
+
+ return flat_grad_updated
+
+ def _finish(self, state):
+ var_dtype = self._variables[0].dtype.base_dtype
+ # Update global step.
+ global_step = self._get_global_step(state)
+ update_global_step = state_ops.assign_add(global_step, 1.)
+
+ # Update the first moment estimate.
+ beta1 = state.get_hyper("beta1", dtype=var_dtype)
+ moment1 = self._get_moment1(state)
+ flat_grad = self._get_flat_grad(state)
+ # moment1_t := beta1 * moment1_{t-1} + (1 - beta1) * flat_grad_t
+ update_moment1 = moment1.assign(beta1 * moment1 + (1. - beta1) * flat_grad)
+
+ # Update the gradient buffer.
+ window = state.get_hyper("window")
+ grad_buffer = self._get_grad_buffer(state)
+ next_grad_index = math_ops.floormod(
+ math_ops.to_int32(update_global_step - 1.), window)
+ # grad_buffer[(t-1) % window] := moment1_t
+ update_grad_buffer = state_ops.scatter_update(grad_buffer, next_grad_index,
+ update_moment1)
+
+ # Compute the update step.
+ eps = state.get_hyper("eps", dtype=var_dtype)
+ svd_eps = state.get_hyper("svd_eps", dtype=var_dtype)
+ sigma_eps = state.get_hyper("sigma_eps", dtype=var_dtype)
+ lr = state.get_hyper("lr", dtype=var_dtype)
+ denom = math_ops.sqrt(
+ math_ops.minimum(
+ ops.convert_to_tensor(update_global_step),
+ ops.convert_to_tensor(math_ops.cast(window, dtype=var_dtype))))
+ moment1_2d = array_ops.expand_dims(update_moment1, -1)
+
+ # m = grad_buffer^T / sqrt(min(t, window))
+ # m has shape [model dimension, window], where model dimension is the sum
+ # of the dimensions of the flattened variables.
+ m = array_ops.transpose(math_ops.divide(update_grad_buffer, denom))
+
+ # sigma, u, _ = SVD(m^Tm + I * svd_eps)
+ mm = math_ops.matmul(m, m, transpose_a=True)
+ damping = math_ops.cast(linalg_ops.eye(window), dtype=var_dtype) * svd_eps
+ sigma, u, _ = linalg_ops.svd(mm + damping)
+ sigma_sqrt = math_ops.sqrt(sigma)
+ sigma_sqrt_min = math_ops.reduce_min(sigma_sqrt)
+
+ # sigma_sqrt_inv = 1 / (\sqrt{sigma} + sigma_eps) ^ 3
+ # We add sigma_eps to alleviate numerical instability.
+ # Note that (m^Tm)^(-3/2) = u diag(sigma_sqrt_inv) u^T.
+ sigma_sqrt_inv = math_ops.divide(
+ math_ops.cast(1.0, dtype=var_dtype),
+ math_ops.pow(sigma_sqrt + sigma_eps, 3))
+
+ # In full matrix AdaGrad, the update step computes (mm^T)^(-1/2)g, where the
+ # inversion of a model dimension by model dimension matrix is needed. To
+ # speed up this computation we calculate the following instead:
+ # m(m^Tm)^(-3/2)m^T moment1 = m u diag(sigma_sqrt_inv) u^T m^T moment1.
+ new_step = array_ops.expand_dims(
+ array_ops.zeros(flat_grad.get_shape(), dtype=var_dtype), -1)
+ head = math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(sigma_sqrt_inv),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+
+ # When inverting (mm^t)^(1/2), we also add epsilon * I regularization for
+ # degenerate cases. We expand ((mm^t)^(1/2) + epsilon * I)^(-1) using
+ # Woodbury's identity.
+ # For full derivation please see paper at
+ # https://arxiv.org/pdf/1806.02958.pdf
+ tail = moment1_2d - math_ops.matmul(
+ m,
+ math_ops.matmul(
+ u,
+ math_ops.matmul(
+ array_ops.diag(
+ math_ops.divide(math_ops.cast(1.0, dtype=var_dtype),
+ sigma)),
+ math_ops.matmul(
+ u,
+ math_ops.matmul(m, moment1_2d, transpose_a=True),
+ transpose_a=True))))
+ scaled_tail = math_ops.divide(tail, sigma_sqrt_min)
+
+ update_new_step = control_flow_ops.cond(
+ sigma_sqrt_min > eps, lambda: math_ops.add(head, scaled_tail),
+ lambda: math_ops.add(new_step, head))
+
+ # Update each variable.
+ update_step = []
+ for var in self._variables:
+ dim = self.shape_dict[var.name]
+ start_index = self.index_dict[var.name]
+ end_index = start_index + dim
+ var_update_correct_shape = array_ops.reshape(
+ update_new_step[start_index:end_index], var.get_shape())
+ var_updated = state_ops.assign_sub(var, lr * var_update_correct_shape)
+ update_step.append(var_updated)
+
+ return control_flow_ops.group(update_step)
diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py
new file mode 100644
index 0000000000..42162960b0
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/ggt_test.py
@@ -0,0 +1,183 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GGTOptimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from tensorflow.contrib.opt.python.training.ggt import GGTOptimizer
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def ggt_update_numpy(param,
+ g_t,
+ lr,
+ grad_buffer,
+ m,
+ window,
+ t,
+ beta1=0.9,
+ eps=1e-4,
+ svd_eps=1e-6,
+ sigma_eps=1e-2):
+ """Tests the correctness of one step of GGT."""
+ m_t = m * beta1 + (1 - beta1) * g_t
+ grad_buffer[((t - 1) % window), :] = m_t
+ m_matrix = np.transpose(grad_buffer / np.sqrt(np.minimum(t, window)))
+ mm = np.dot(np.transpose(m_matrix), m_matrix)
+ damping = np.eye(window) * svd_eps
+ u, sigma, _ = np.linalg.svd(mm + damping)
+
+ sigma_sqrt_inv = np.power(np.sqrt(sigma) + sigma_eps, -3)
+ new_step = np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(sigma_sqrt_inv),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])
+
+ sigma_sqrt_min = np.sqrt(sigma).min()
+
+ if sigma_sqrt_min > eps:
+ new_step += (m_t - np.linalg.multi_dot([
+ m_matrix, u,
+ np.diag(1.0 / sigma),
+ np.transpose(u),
+ np.transpose(m_matrix), m_t
+ ])) * (1.0 / sigma_sqrt_min)
+
+ param_t = param - lr * new_step
+ return param_t, m_t, grad_buffer
+
+
+class GGTOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ # SVD does not support float16
+ for i, dtype in enumerate([dtypes.float32, dtypes.float64]):
+ with self.test_session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0 = 0.0
+ window = 3
+ grad_buffer = np.zeros((window, 4), dtype=dtype.as_numpy_dtype)
+ lr = 0.001
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np, name="var0")
+ var1 = variables.Variable(var1_np, name="var1")
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = GGTOptimizer(learning_rate=lr, window=window)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+ self.assertTrue(m_t is not None)
+ self.assertTrue(grad_buffer_t is not None)
+ self.assertTrue(g_t is not None)
+ self.assertIn(m_t, opt_variables)
+ self.assertIn(grad_buffer_t, opt_variables)
+ self.assertIn(g_t, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ m_t = opt._get_moment1()
+ grad_buffer_t = opt._get_grad_buffer()
+ g_t = opt._get_flat_grad()
+
+ # Run 3 steps of GGT
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if t == 1:
+ self.assertAllCloseAccordingToType(
+ np.array([0.01, 0.01, 0.001, 0.001]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001], [0., 0., 0., 0.],
+ [0., 0., 0., 0.]]), self.evaluate(grad_buffer_t))
+ elif t == 2:
+ self.assertAllCloseAccordingToType(
+ np.array([0.019, 0.019, 0.0019, 0.0019]), self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001, 0.001],
+ [0.019, 0.019, 0.0019, 0.0019], [0., 0., 0., 0.]]),
+ self.evaluate(grad_buffer_t))
+ else:
+ self.assertAllCloseAccordingToType(
+ np.array([0.0271, 0.0271, 0.00271, 0.00271]),
+ self.evaluate(m_t))
+ self.assertAllCloseAccordingToType(
+ np.array([[0.01, 0.01, 0.001,
+ 0.001], [0.019, 0.019, 0.0019, 0.0019],
+ [0.0271, 0.0271, 0.00271, 0.00271]]),
+ self.evaluate(grad_buffer_t))
+
+ self.assertAllCloseAccordingToType([0.1, 0.1, 0.01, 0.01],
+ self.evaluate(g_t))
+
+ var_np = np.append(var0_np, var1_np)
+ grads_np = np.append(grads0_np, grads1_np)
+ var_np, m0, grad_buffer = ggt_update_numpy(var_np, grads_np, lr,
+ grad_buffer, m0, window, t)
+
+ var0_np = var_np[:2]
+ var1_np = var_np[2:]
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
+ def testBasic(self):
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/powersign_test.py b/tensorflow/contrib/opt/python/training/powersign_test.py
index 5214082dd6..0bcf5d230a 100644
--- a/tensorflow/contrib/opt/python/training/powersign_test.py
+++ b/tensorflow/contrib/opt/python/training/powersign_test.py
@@ -216,7 +216,7 @@ class PowerSignTest(test.TestCase):
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of powersign
+ # Run 7 steps of powersign
# first 4 steps with positive gradient
# last 3 steps with negative gradient (sign(gm) should be -1)
for t in range(1, 8):
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
new file mode 100644
index 0000000000..b9cf40eb7b
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -0,0 +1,362 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Base class to make optimizers weight decay ready."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import adam
+from tensorflow.python.training import momentum as momentum_opt
+from tensorflow.python.training import optimizer
+from tensorflow.python.util.tf_export import tf_export
+
+
+class DecoupledWeightDecayExtension(object):
+ """This class allows to extend optimizers with decoupled weight decay.
+
+ It implements the decoupled weight decay described by Loshchilov & Hutter
+ (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
+ decoupled from the optimization steps w.r.t. to the loss function.
+ For SGD variants, this simplifies hyperparameter search since it decouples
+ the settings of weight decay and learning rate.
+ For adaptive gradient algorithms, it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield better
+ training loss and generalization error in the paper above.
+
+ This class alone is not an optimizer but rather extends existing
+ optimizers with decoupled weight decay. We explicitly define the two examples
+ used in the above paper (SGDW and AdamW), but in general this can extend
+ any OptimizerX by using
+ `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`.
+ In order for it to work, it must be the first class the Optimizer with
+ weight decay inherits from, e.g.
+
+ ```python
+ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+ def __init__(self, weight_decay, *args, **kwargs):
+ super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs).
+ ```
+
+ Note that this extension decays weights BEFORE applying the update based
+ on the gradient, i.e. this extension only has the desired behaviour for
+ optimizers which do not depend on the value of'var' in the update step!
+ """
+
+ def __init__(self, weight_decay, **kwargs):
+ """Construct the extension class that adds weight decay to an optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value, the factor by which
+ a variable is decayed in the update step.
+ **kwargs: Optional list or tuple or set of `Variable` objects to
+ decay.
+ """
+ self._decay_var_list = None # is set in minimize or apply_gradients
+ self._weight_decay = weight_decay
+ # The tensors are initialized in call to _prepare
+ self._weight_decay_tensor = None
+ super(DecoupledWeightDecayExtension, self).__init__(**kwargs)
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=optimizer.Optimizer.GATE_OP,
+ aggregation_method=None, colocate_gradients_with_ops=False,
+ name=None, grad_loss=None, decay_var_list=None):
+ """Add operations to minimize `loss` by updating `var_list` with decay.
+
+ This function is the same as Optimizer.minimize except that it allows to
+ specify the variables that should be decayed using decay_var_list.
+ If decay_var_list is None, all variables in var_list are decayed.
+
+ For more information see the documentation of Optimizer.minimize.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ """
+ self._decay_var_list = set(decay_var_list) if decay_var_list else False
+ return super(DecoupledWeightDecayExtension, self).minimize(
+ loss, global_step=global_step, var_list=var_list,
+ gate_gradients=gate_gradients, aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops, name=name,
+ grad_loss=grad_loss)
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None,
+ decay_var_list=None):
+ """Apply gradients to variables and decay the variables.
+
+ This function is the same as Optimizer.apply_gradients except that it
+ allows to specify the variables that should be decayed using
+ decay_var_list. If decay_var_list is None, all variables in var_list
+ are decayed.
+
+ For more information see the documentation of Optimizer.apply_gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+ decay_var_list: Optional list of decay variables.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+ """
+ self._decay_var_list = set(decay_var_list) if decay_var_list else False
+ return super(DecoupledWeightDecayExtension, self).apply_gradients(
+ grads_and_vars, global_step=global_step, name=name)
+
+ def _prepare(self):
+ weight_decay = self._weight_decay
+ if callable(weight_decay):
+ weight_decay = weight_decay()
+ self._weight_decay_tensor = ops.convert_to_tensor(
+ weight_decay, name="weight_decay")
+ # Call the optimizers _prepare function.
+ super(DecoupledWeightDecayExtension, self)._prepare()
+
+ def _decay_weights_op(self, var):
+ if not self._decay_var_list or var in self._decay_var_list:
+ return var.assign_sub(self._weight_decay * var, self._use_locking)
+ return control_flow_ops.no_op()
+
+ def _decay_weights_sparse_op(self, var, indices, scatter_add):
+ if not self._decay_var_list or var in self._decay_var_list:
+ return scatter_add(var, indices, -self._weight_decay * var,
+ self._use_locking)
+ return control_flow_ops.no_op()
+
+ # Here, we overwrite the apply functions that the base optimizer calls.
+ # super().apply_x resolves to the apply_x function of the BaseOptimizer.
+ def _apply_dense(self, grad, var):
+ with ops.control_dependencies([self._decay_weights_op(var)]):
+ return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var)
+
+ def _resource_apply_dense(self, grad, var):
+ with ops.control_dependencies([self._decay_weights_op(var)]):
+ return super(DecoupledWeightDecayExtension, self)._resource_apply_dense(
+ grad, var)
+
+ def _apply_sparse(self, grad, var):
+ scatter_add = state_ops.scatter_add
+ decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add)
+ with ops.control_dependencies([decay_op]):
+ return super(DecoupledWeightDecayExtension, self)._apply_sparse(
+ grad, var)
+
+ def _resource_scatter_add(self, x, i, v, _=None):
+ # last argument allows for one overflow argument, to have the same function
+ # signature as state_ops.scatter_add
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ scatter_add = self._resource_scatter_add
+ decay_op = self._decay_weights_sparse_op(var, indices, scatter_add)
+ with ops.control_dependencies([decay_op]):
+ return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse(
+ grad, var, indices)
+
+
+def extend_with_decoupled_weight_decay(base_optimizer):
+ """Factory function returning an optimizer class with decoupled weight decay.
+
+ Returns an optimizer class. An instance of the returned class computes the
+ update step of `base_optimizer` and additionally decays the weights.
+ E.g., the class returned by
+ `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to
+ `tf.contrib.opt.AdamWOptimizer`.
+
+ The API of the new optimizer class slightly differs from the API of the
+ base optimizer:
+ - The first argument to the constructor is the weight decay rate.
+ - `minimize` and `apply_gradients` accept the optional keyword argument
+ `decay_var_list`, which specifies the variables that should be decayed.
+ If `None`, all variables that are optimized are decayed.
+
+ Usage example:
+ ```python
+ # MyAdamW is a new class
+ MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)
+ # Create a MyAdamW object
+ optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001)
+ sess.run(optimizer.minimize(loss, decay_variables=[var1, var2]))
+
+ Note that this extension decays weights BEFORE applying the update based
+ on the gradient, i.e. this extension only has the desired behaviour for
+ optimizers which do not depend on the value of'var' in the update step!
+ ```
+
+ Args:
+ base_optimizer: An optimizer class that inherits from tf.train.Optimizer.
+
+ Returns:
+ A new optimizer class that inherits from DecoupledWeightDecayExtension
+ and base_optimizer.
+ """
+
+ class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension,
+ base_optimizer):
+ """Base_optimizer with decoupled weight decay.
+
+ This class computes the update step of `base_optimizer` and
+ additionally decays the variable with the weight decay being decoupled from
+ the optimization steps w.r.t. to the loss function, as described by
+ Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf).
+ For SGD variants, this simplifies hyperparameter search since
+ it decouples the settings of weight decay and learning rate.
+ For adaptive gradient algorithms, it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield
+ better training loss and generalization error in the paper above.
+ """
+
+ def __init__(self, weight_decay, *args, **kwargs):
+ # super delegation is necessary here
+ # pylint: disable=useless-super-delegation
+ super(OptimizerWithDecoupledWeightDecay, self).__init__(
+ weight_decay, *args, **kwargs)
+ # pylint: enable=useless-super-delegation
+
+ return OptimizerWithDecoupledWeightDecay
+
+
+@tf_export("contrib.opt.MomentumWOptimizer")
+class MomentumWOptimizer(DecoupledWeightDecayExtension,
+ momentum_opt.MomentumOptimizer):
+ """Optimizer that implements the Momentum algorithm with weight_decay.
+
+ This is an implementation of the SGDW optimizer described in "Fixing
+ Weight Decay Regularization in Adam" by Loshchilov & Hutter
+ (https://arxiv.org/abs/1711.05101)
+ ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
+ It computes the update step of `train.MomentumOptimizer` and additionally
+ decays the variable. Note that this is different from adding
+ L2 regularization on the variables to the loss. Decoupling the weight decay
+ from other hyperparameters (in particular the learning rate) simplifies
+ hyperparameter search.
+
+ For further information see the documentation of the Momentum Optimizer.
+
+ Note that this optimizer can also be instantiated as
+ ```python
+ extend_with_weight_decay(tf.train.MomentumOptimizer,
+ weight_decay=weight_decay)
+ ```
+ """
+
+ def __init__(self, weight_decay, learning_rate, momentum,
+ use_locking=False, name="MomentumW", use_nesterov=False):
+ """Construct a new MomentumW optimizer.
+
+ For further information see the documentation of the Momentum Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ momentum: A `Tensor` or a floating point value. The momentum.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Momentum".
+ use_nesterov: If `True` use Nesterov Momentum.
+ See [Sutskever et al., 2013](
+ http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
+ This implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate, weight_decay and momentum
+ can each be a callable that takes no arguments and returns the actual value
+ to use. This can be useful for changing these values across different
+ invocations of optimizer functions.
+ @end_compatibility
+ """
+ super(MomentumWOptimizer, self).__init__(
+ weight_decay, learning_rate=learning_rate, momentum=momentum,
+ use_locking=use_locking, name=name, use_nesterov=use_nesterov)
+
+
+@tf_export("contrib.opt.AdamWOptimizer")
+class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
+ """Optimizer that implements the Adam algorithm with weight decay.
+
+ This is an implementation of the AdamW optimizer described in "Fixing
+ Weight Decay Regularization in Adam" by Loshchilov & Hutter
+ (https://arxiv.org/abs/1711.05101)
+ ([pdf])(https://arxiv.org/pdf/1711.05101.pdf).
+
+ It computes the update step of `train.AdamOptimizer` and additionally decays
+ the variable. Note that this is different from adding L2 regularization on
+ the variables to the loss: it regularizes variables with large
+ gradients more than L2 regularization would, which was shown to yield better
+ training loss and generalization error in the paper above.
+
+ For further information see the documentation of the Adam Optimizer.
+
+ Note that this optimizer can also be instantiated as
+ ```python
+ extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay)
+ ```
+ """
+
+ def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999,
+ epsilon=1e-8, use_locking=False, name="AdamW"):
+ """Construct a new AdamW optimizer.
+
+ For further information see the documentation of the Adam Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta1: A float value or a constant float tensor.
+ The exponential decay rate for the 1st moment estimates.
+ beta2: A float value or a constant float tensor.
+ The exponential decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability. This epsilon is
+ "epsilon hat" in the Kingma and Ba paper (in the formula just before
+ Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ use_locking: If True use locks for update operations.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+ super(AdamWOptimizer, self).__init__(
+ weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
+ epsilon=epsilon, use_locking=use_locking, name=name)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
new file mode 100644
index 0000000000..76d8a5697a
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py
@@ -0,0 +1,188 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for optimizers with weight decay."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import weight_decay_optimizers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+
+WEIGHT_DECAY = 0.01
+
+
+def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9,
+ beta2=0.999, epsilon=1e-8):
+ lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) -
+ (param * WEIGHT_DECAY))
+ return param_t, m_t, v_t
+
+
+def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_):
+ # v, t are not needed for momentum optimizer
+ m = momentum * m + g_t
+ param_t = param - lr * m - param * WEIGHT_DECAY
+ return param_t, m, None
+
+
+class WeightDecayOptimizerTest(test.TestCase):
+
+ def doTest(self, optimizer, update_fn, optimizer_name, slot_name,
+ use_resource=False, do_sparse=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.test_session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+
+ if do_sparse:
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices),
+ constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices),
+ constant_op.constant([2]))
+ else:
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = optimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 3 steps of the optimizer
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0)
+ var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/%s:0" % (i, optimizer_name),
+ opt.get_slot(var=var0, name=slot_name).name)
+
+
+class AdamWOptimizerTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY)
+
+ def testSparse(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=False, do_sparse=True)
+
+ def testResourceSparse(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=True, do_sparse=True)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m",
+ use_resource=True)
+
+
+class MomentumWOptimizerTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9)
+
+ def testSparse(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=False, do_sparse=True)
+
+ def testResourceSparse(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=True, do_sparse=True)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW",
+ "momentum", use_resource=True)
+
+
+class ExtendWithWeightDecayTest(WeightDecayOptimizerTest):
+
+ @staticmethod
+ def get_optimizer():
+ adamw = weight_decay_optimizers.extend_with_decoupled_weight_decay(
+ adam.AdamOptimizer)
+ return adamw(WEIGHT_DECAY)
+
+ def testBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+ use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m",
+ use_resource=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index d538ad0fb0..631d4f44df 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
# Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=state.get_hyper("beta1"),
+ state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
name="beta1_power")
- state.create_non_slot(initial_value=state.get_hyper("beta2"),
+ state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
name="beta2_power")
# Create slots for the first and second moments.
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 64b95786b5..06ab58188a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -43,15 +43,15 @@ from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
-class NonLayerCheckpointable(checkpointable.Checkpointable):
+class NonLayerCheckpointable(tracking.Checkpointable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
- self.a_variable = checkpointable_utils.add_variable(
+ self.a_variable = util.add_variable(
self, name="a_variable", shape=[])
@@ -88,29 +88,6 @@ class _MirroringSaveable(
self._mirrored_variable.assign(tensor))
-class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
- """A Checkpointable object which returns a more complex SaveableObject."""
-
- def __init__(self):
- self.non_dep_variable = variable_scope.get_variable(
- name="non_dep_variable", initializer=6., use_resource=True)
- self.mirrored = variable_scope.get_variable(
- name="mirrored", initializer=15., use_resource=True)
-
- def _gather_saveables_for_checkpoint(self):
- def _saveable_factory(name=self.non_dep_variable.name):
- return _MirroringSaveable(
- primary_variable=self.non_dep_variable,
- mirrored_variable=self.mirrored,
- name=name)
- return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
-
- # The Saver sorts by name before parsing, so we need a name property.
- @property
- def name(self):
- return self.non_dep_variable.name
-
-
class CheckpointingTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -122,7 +99,7 @@ class CheckpointingTests(test.TestCase):
other_model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
if context.executing_eagerly():
optimizer.minimize(
@@ -137,11 +114,11 @@ class CheckpointingTests(test.TestCase):
optimizer.minimize(
other_model(input_value),
global_step=optimizer_step)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
named_variables, serialized_graph, _ = (
- checkpointable_utils._serialize_object_graph(
+ util._serialize_object_graph(
root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
@@ -226,11 +203,11 @@ class CheckpointingTests(test.TestCase):
optimizer_node.slot_variables[0]
.slot_variable_node_id].attributes[0].checkpoint_key)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveRestore(self):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model)
input_value = constant_op.constant([[3.]])
if context.executing_eagerly():
@@ -240,7 +217,7 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(model(input_value))
# TODO(allenl): Make initialization more pleasant when graph building.
root_checkpointable.save_counter # pylint: disable=pointless-statement
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
@@ -266,7 +243,7 @@ class CheckpointingTests(test.TestCase):
# Preserve beta1_power and beta2_power when appying gradients so we can
# test that they've been restored correctly.
beta1=1.0, beta2=1.0)
- on_create_root = checkpointable_utils.Checkpoint(
+ on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
@@ -298,7 +275,7 @@ class CheckpointingTests(test.TestCase):
for training_continuation in range(3):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(core_saver.latest_checkpoint(checkpoint_directory))
@@ -322,7 +299,7 @@ class CheckpointingTests(test.TestCase):
with ops.Graph().as_default():
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
input_value = constant_op.constant([[3.]])
@@ -347,7 +324,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(training_continuation + 1,
session.run(root.save_counter))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAgnosticUsage(self):
"""Graph/eager agnostic usage."""
# Does create garbage when executing eagerly due to ops.Graph() creation.
@@ -359,7 +336,7 @@ class CheckpointingTests(test.TestCase):
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -381,7 +358,7 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
# pylint: disable=cell-var-from-loop
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWithDefun(self):
num_training_steps = 2
checkpoint_directory = self.get_temp_dir()
@@ -392,7 +369,7 @@ class CheckpointingTests(test.TestCase):
model = MyModel()
# Don't actually train so we can test variable values
optimizer = adam.AdamOptimizer(0.)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -442,7 +419,7 @@ class CheckpointingTests(test.TestCase):
optimizer = adam.AdamOptimizer(learning_rate=0.05)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- checkpoint = checkpointable_utils.Checkpoint(
+ checkpoint = util.Checkpoint(
model=model, optimizer=optimizer)
for _ in range(2):
checkpoint.save(checkpoint_prefix)
@@ -453,12 +430,12 @@ class CheckpointingTests(test.TestCase):
optimizer.apply_gradients(
[(g, v) for g, v in zip(grad, model.vars)])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = checkpointable.Checkpointable()
- root.var = checkpointable_utils.add_variable(
+ root = tracking.Checkpointable()
+ root.var = util.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
if context.executing_eagerly():
@@ -468,28 +445,28 @@ class CheckpointingTests(test.TestCase):
# Note that `optimizer` has not been added as a dependency of
# `root`. Create a one-off grouping so that slot variables for `root.var`
# get initialized too.
- self.evaluate(checkpointable_utils.gather_initializers(
- checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
+ self.evaluate(util.gather_initializers(
+ util.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
self.evaluate(state_ops.assign(root.var, 12.))
- no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+ no_slots_path = util.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "no_slots"))
root.optimizer = optimizer
self.evaluate(state_ops.assign(root.var, 13.))
self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
14.))
- slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+ slots_path = util.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "with_slots"))
- new_root = checkpointable.Checkpointable()
+ new_root = tracking.Checkpointable()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
- slot_status = checkpointable_utils.CheckpointableSaver(
+ slot_status = util.CheckpointableSaver(
new_root).restore(slots_path)
- no_slot_status = checkpointable_utils.CheckpointableSaver(
+ no_slot_status = util.CheckpointableSaver(
new_root).restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
- new_root.var = checkpointable_utils.add_variable(
+ new_root.var = util.add_variable(
new_root, name="var", shape=[])
no_slot_status.assert_consumed()
no_slot_status.run_restore_ops()
@@ -525,12 +502,12 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
saver.save(checkpoint_prefix)
before_ops = graph.get_operations()
saver.save(checkpoint_prefix)
@@ -543,12 +520,12 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
save_path = saver.save(checkpoint_prefix)
saver.restore(save_path)
before_ops = graph.get_operations()
@@ -565,10 +542,10 @@ class CheckpointingTests(test.TestCase):
first_session = session_lib.Session(graph=first_graph)
with first_graph.as_default(), first_session.as_default():
first_variable = resource_variable_ops.ResourceVariable([1.])
- first_root_checkpointable = checkpointable_utils.Checkpoint(
+ first_root_checkpointable = util.Checkpoint(
optimizer=optimizer, variable=first_variable)
train_op = optimizer.minimize(first_variable.read_value)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
first_root_checkpointable))
self.evaluate(train_op)
self.evaluate(first_variable.assign([1.]))
@@ -581,7 +558,7 @@ class CheckpointingTests(test.TestCase):
second_graph = ops.Graph()
with second_graph.as_default(), session_lib.Session(graph=second_graph):
second_variable = resource_variable_ops.ResourceVariable([1.])
- second_root_checkpointable = checkpointable_utils.Checkpoint(
+ second_root_checkpointable = util.Checkpoint(
optimizer=optimizer, variable=second_variable)
train_op = optimizer.minimize(second_variable.read_value)
second_root_checkpointable.restore(None).initialize_or_restore()
@@ -616,7 +593,7 @@ class CheckpointingTests(test.TestCase):
class TemplateTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_checkpointable_save_restore(self):
def _templated():
@@ -631,7 +608,7 @@ class TemplateTests(test.TestCase):
save_template = template.make_template("s1", _templated)
v1_save, _, v2_save = save_template()
optimizer = adam.AdamOptimizer(0.0)
- save_root = checkpointable_utils.Checkpoint(
+ save_root = util.Checkpoint(
my_template=save_template, optimizer=optimizer)
optimizer.minimize(v1_save.read_value)
self.evaluate([v.initializer for v in optimizer.variables()])
@@ -643,7 +620,7 @@ class TemplateTests(test.TestCase):
load_template = template.make_template("s2", _templated)
load_optimizer = adam.AdamOptimizer(0.0)
- load_root = checkpointable_utils.Checkpoint(
+ load_root = util.Checkpoint(
my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
var, var_plus_one, var2 = load_template()
@@ -664,12 +641,12 @@ class CheckpointCompatibilityTests(test.TestCase):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
train_op = optimizer.minimize(
functools.partial(model, input_value),
global_step=optimizer_step)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
# A regular variable, a slot variable, and a non-slot Optimizer variable
@@ -712,7 +689,7 @@ class CheckpointCompatibilityTests(test.TestCase):
sess=session, save_path=checkpoint_prefix,
global_step=root.optimizer_step)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLoadFromNameBasedSaver(self):
"""Save a name-based checkpoint, load it using the object-based API."""
with test_util.device(use_gpu=True):
@@ -721,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase):
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = util.CheckpointableSaver(root)
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f537318b32..8c11d8bcfd 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -162,12 +162,12 @@ def _get_processor(v):
def _var_key_v2(var):
"""Key for representing a primary variable, for looking up slots."""
# pylint: disable=protected-access
- if hasattr(var, "_mirrored_container"):
- mirrored_container = var._mirrored_container()
- assert mirrored_container is not None
+ if hasattr(var, "_distributed_container"):
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
if context.executing_eagerly():
- return mirrored_container._unique_id
- return mirrored_container._shared_name
+ return distributed_container._unique_id
+ return distributed_container._shared_name
if context.executing_eagerly():
return var._unique_id
return var.op.name
@@ -211,8 +211,9 @@ class _OptimizerV2State(object):
# This dict starts with a single item with key "None" with the hyper
# parameter value converted to a Tensor. Other items have dtype keys
# with that Tensor cast to that dtype.
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in hyper.items() if not dynamic}
+ with ops.init_scope():
+ self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+ for name, (dynamic, value) in hyper.items() if not dynamic}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
@@ -765,7 +766,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# *after* loss() is evaluated, so we know what loss reduction it uses.
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -783,7 +785,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Scale loss for number of towers (non-callable-loss case).
if scale_loss_by_num_towers is None:
scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() == "mean")
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
if scale_loss_by_num_towers:
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
@@ -895,7 +898,8 @@ class OptimizerV2(optimizer_v1.Optimizer):
def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
index 8599af32f6..ec033c4a01 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
@@ -35,7 +35,7 @@ from tensorflow.python.platform import test
class OptimizerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBasic(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -113,7 +113,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
var1.eval())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoVariables(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
# pylint: disable=cell-var-from-loop
@@ -128,7 +128,7 @@ class OptimizerTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'No.*variables'):
sgd_op.minimize(loss)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -146,7 +146,7 @@ class OptimizerTest(test.TestCase):
# var1 has no gradient
sgd_op.minimize(loss, var_list=[var1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_Minimize(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -162,7 +162,7 @@ class OptimizerTest(test.TestCase):
'No gradients provided for any variable'):
sgd_op.minimize(loss, var_list=[var0, var1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_ApplyGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -176,7 +176,7 @@ class OptimizerTest(test.TestCase):
'No gradients provided for any variable'):
sgd_op.apply_gradients([(None, var0), (None, var1)])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientsAsVariables(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -216,7 +216,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([-14., -13.], self.evaluate(var0))
self.assertAllClose([-6., -5.], self.evaluate(var1))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testComputeGradientsWithTensors(self):
x = ops.convert_to_tensor(1.0)
def f():
diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD
index 6ca7fe8b6e..f2171efc95 100644
--- a/tensorflow/contrib/periodic_resample/BUILD
+++ b/tensorflow/contrib/periodic_resample/BUILD
@@ -6,12 +6,13 @@ exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
- "py_test",
+ "tf_cc_test",
"tf_gen_op_libs",
"tf_custom_op_library",
"tf_custom_op_py_library",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "py_test")
cc_library(
name = "all_ops",
@@ -84,6 +85,22 @@ py_test(
":init_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradient_checker",
+ ],
+)
+
+tf_cc_test(
+ name = "periodic_resample_op_cc_test",
+ size = "small",
+ srcs = [
+ "ops/array_ops_test.cc",
+ ],
+ deps = [
+ ":all_ops",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
],
)
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
index e18923c8aa..514689cf45 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc
@@ -22,4 +22,9 @@ namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU),
PeriodicResampleOp);
+
+REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad")
+ .Device(DEVICE_CPU),
+ PeriodicResampleOpGrad);
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
index 3ab588c458..42fba81a5c 100644
--- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
+++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h
@@ -25,92 +25,202 @@
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace {
-template <class IndexVecT, class IndexT>
-IndexT compute_input_index(
- IndexVecT* target_dimensions, const IndexT& output_index,
- const IndexVecT& original_dimensions, const int& adjustable_dimension,
- const std::vector<tensorflow::int64>& dimension_ceiling,
- const std::vector<tensorflow::int64>& cumulative_dimensions, IndexT* result,
- std::vector<IndexT>* output_indices, const int& rank) {
- *result = 0;
- output_indices->clear();
+// Computes input tensor index for given output index during forward
+// propagation through periodic_resample operation.
+class InputIndexer {
+ public:
+ InputIndexer(const std::vector<tensorflow::int64>& output_dimensions,
+ const tensorflow::TensorShape& input_shape,
+ int adjustable_dimension)
+ : output_dimensions_(output_dimensions),
+ adjustable_dimension_(adjustable_dimension),
+ rank_(input_shape.dims()),
+ linear_output_index_(0),
+ linear_input_index_(0),
+ adjustable_dimension_carriage_sum_(0) {
+ auto input_dimensions = TensorShapeToVector(input_shape);
+ // factors by which input_dimensions increases/decreases w.r.t.
+ // output_dimensions
+ dimension_ceiling_ =
+ ComputeDimensionCeiling(output_dimensions, input_dimensions);
+ cumulative_dimensions_ = ComputeCumulativeDimensions();
+
+ output_indices_.resize(output_dimensions_.size());
+ input_indices_.resize(output_dimensions_.size());
+
+ // Compute index_factors
+ index_factors_.resize(rank_);
+ tensorflow::int64 last_index_factor = 1;
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ index_factors_[r] = last_index_factor;
+ last_index_factor *= input_dimensions[r];
+ }
+ }
+
+ tensorflow::int64 linear_input_index() const { return linear_input_index_; }
+
+ void MoveToOutputIndex(tensorflow::int64 output_index);
+ void IncrementOutputIndex();
+
+ private:
+ void RecomputeInputAdjustableDimensionIndex() {
+ tensorflow::int64 index = adjustable_dimension_carriage_sum_;
+ index *= output_dimensions_[adjustable_dimension_];
+ index += output_indices_[adjustable_dimension_];
+ input_indices_[adjustable_dimension_] = index;
+ }
+
+ std::vector<tensorflow::int64> TensorShapeToVector(
+ const tensorflow::TensorShape& tensor_shape);
+
+ std::vector<tensorflow::int64> ComputeDimensionCeiling(
+ const std::vector<tensorflow::int64>& output_dimensions,
+ const std::vector<tensorflow::int64>& input_dimensions);
+
+ std::vector<tensorflow::int64> ComputeCumulativeDimensions();
+
+ const std::vector<tensorflow::int64> output_dimensions_;
+ std::vector<tensorflow::int64> dimension_ceiling_;
+ std::vector<tensorflow::int64> index_factors_;
+ std::vector<tensorflow::int64> cumulative_dimensions_;
+ std::vector<tensorflow::int64> output_indices_;
+ std::vector<tensorflow::int64> input_indices_;
+
+ const int adjustable_dimension_;
+ const int rank_;
+ tensorflow::int64 linear_output_index_;
+ tensorflow::int64 linear_input_index_;
+ tensorflow::int64 adjustable_dimension_carriage_sum_;
+};
+
+void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) {
+ linear_output_index_ = output_index;
+ linear_input_index_ = 0;
// un-rasterize the output index
auto last_reduced_i = output_index;
- for (auto r = rank - 1; r >= 0; --r) {
- (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r];
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ output_indices_[r] = last_reduced_i % output_dimensions_[r];
last_reduced_i =
- (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r];
+ (last_reduced_i - output_indices_[r]) / output_dimensions_[r];
}
+ tensorflow::int64 carriage_sum = 0;
+ for (int qi = 0; qi < rank_; ++qi) {
+ if (qi == adjustable_dimension_) continue;
+ carriage_sum += cumulative_dimensions_[qi] *
+ (output_indices_[qi] % dimension_ceiling_[qi]);
+ }
+ adjustable_dimension_carriage_sum_ = carriage_sum;
+
// rasterize the input index
- IndexT last_index_factor = 1;
- for (auto r = rank - 1; r >= 0; --r) {
- IndexT index = 0;
- if (r != adjustable_dimension)
- index = (*output_indices)[r] / dimension_ceiling[r];
- else {
- for (int qi = 0; qi < rank; ++qi) {
- if (qi == adjustable_dimension) continue;
- index += cumulative_dimensions[qi] *
- ((*output_indices)[qi] % dimension_ceiling[qi]);
- }
- index *= (*target_dimensions)[adjustable_dimension];
- index += (*output_indices)[r];
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ if (r != adjustable_dimension_) {
+ input_indices_[r] = output_indices_[r] / dimension_ceiling_[r];
+ } else {
+ RecomputeInputAdjustableDimensionIndex();
}
- *result += last_index_factor * index;
- last_index_factor *= original_dimensions[r];
}
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ linear_input_index_ += index_factors_[r] * input_indices_[r];
+ }
+}
+
+void InputIndexer::IncrementOutputIndex() {
+ linear_output_index_++;
+ for (auto r = rank_ - 1; r >= 0; --r) {
+ auto old_carriage_sum_increment =
+ cumulative_dimensions_[r] *
+ (output_indices_[r] % dimension_ceiling_[r]);
+ output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r];
+ if (r != adjustable_dimension_) {
+ auto new_input_index = output_indices_[r] / dimension_ceiling_[r];
+ linear_input_index_ +=
+ (new_input_index - input_indices_[r]) * index_factors_[r];
+
+ input_indices_[r] = new_input_index;
+
+ auto new_carriage_sum_increment =
+ cumulative_dimensions_[r] *
+ (output_indices_[r] % dimension_ceiling_[r]);
- return *result;
+ adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ -
+ old_carriage_sum_increment +
+ new_carriage_sum_increment;
+ }
+
+ if (output_indices_[r] != 0) {
+ // No more carries to higher indices.
+ break;
+ }
+ }
+ auto old_adjustable_dimension_input_index =
+ input_indices_[adjustable_dimension_];
+ RecomputeInputAdjustableDimensionIndex();
+ linear_input_index_ += (input_indices_[adjustable_dimension_] -
+ old_adjustable_dimension_input_index) *
+ index_factors_[adjustable_dimension_];
}
-template <class InputDataT,
- class IndexVecT> // both types are needed here b/c IndexVecT and
- // InputDataT are not related
- void
- fill_periodic_tensor(
- tensorflow::OpKernelContext* context,
- const IndexVecT& desired_shape,
- const tensorflow::Tensor& input_tensor) {
- // input is a strided array (last index is fastest, C-ordered)
- auto input = input_tensor.flat<InputDataT>();
- const int rank = input_tensor.dims();
- // original and target dimensions
- std::vector<tensorflow::int64> original_dimensions(rank),
- target_dimensions(rank);
- tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1);
- // factors by which original_dimensions increases/decreases w.r.t.
- // target_dimensions
- std::vector<tensorflow::int64> dimension_ceiling(rank),
- cumulative_dimensions(rank);
- // index of adjustable dimension
- int adjustable_dimension;
- tensorflow::TensorShape output_shape;
+std::vector<tensorflow::int64> InputIndexer::TensorShapeToVector(
+ const tensorflow::TensorShape& tensor_shape) {
+ std::vector<tensorflow::int64> result(tensor_shape.dims());
+ int count = 0;
+ for (const auto dim_info : tensor_shape) {
+ result[count] = dim_info.size;
+ ++count;
+ }
+ return result;
+}
- // requires that the rank of the input tensor and length of the desired shape
- // are equal
- OP_REQUIRES(context, rank == desired_shape.size(),
- tensorflow::errors::InvalidArgument(
- "periodic_resample expects the rank of the input tensor, ",
- rank, ", to be the same as the length of the desired shape, ",
- desired_shape.size(), "."));
+std::vector<tensorflow::int64> InputIndexer::ComputeDimensionCeiling(
+ const std::vector<tensorflow::int64>& output_dimensions,
+ const std::vector<tensorflow::int64>& input_dimensions) {
+ std::vector<tensorflow::int64> dimension_ceiling(input_dimensions.size());
+ for (size_t i = 0; i < input_dimensions.size(); ++i) {
+ dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) /
+ input_dimensions[i];
+ }
+ return dimension_ceiling;
+}
- bool found = false;
- const auto& input_tensor_shape = input_tensor.shape();
+std::vector<tensorflow::int64> InputIndexer::ComputeCumulativeDimensions() {
+ std::vector<tensorflow::int64> cumulative_dimensions(rank_);
+ int count = 0;
+ for (int i = 0; i < rank_; ++i) {
+ if (count == 0) {
+ cumulative_dimensions[count] = 1;
+ } else {
+ cumulative_dimensions[count] =
+ cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1];
+ }
+ ++count;
+ }
+ return cumulative_dimensions;
+}
+template <typename IndexVecT>
+void process_desired_shape(tensorflow::OpKernelContext* context,
+ const tensorflow::TensorShape& input_tensor_shape,
+ const IndexVecT& desired_shape,
+ int* adjustable_dimension,
+ std::vector<tensorflow::int64>* target_dimensions,
+ tensorflow::int64* output_size) {
+ tensorflow::int64 new_sliced_size = 1;
+ bool found = false;
+ const int rank = input_tensor_shape.dims();
for (int i = 0; i < rank; ++i) {
- // if (desired_shape(i) < 1) {
if (desired_shape[i] < 1) {
// only one index can be adjustable
OP_REQUIRES(context, !found,
tensorflow::errors::InvalidArgument(
"periodic_resample expects only "
"one index to be marked as adjustable."));
- adjustable_dimension = i;
+ *adjustable_dimension = i;
found = true;
} else {
OP_REQUIRES(
@@ -122,9 +232,8 @@ template <class InputDataT,
i, " input tensor has size ", input_tensor_shape.dim_size(i),
", desired shape has size ", desired_shape[i], "."));
- // target_dimensions[i] = desired_shape(i);
- target_dimensions[i] = desired_shape[i];
- new_sliced_size *= target_dimensions[i];
+ (*target_dimensions)[i] = desired_shape[i];
+ new_sliced_size *= (*target_dimensions)[i];
}
}
// at least one index needs to be adjustable
@@ -132,26 +241,50 @@ template <class InputDataT,
tensorflow::errors::InvalidArgument(
"periodic_resample expects at least "
"one index to be marked as adjustable."));
+ (*target_dimensions)[*adjustable_dimension] =
+ input_tensor_shape.num_elements() / new_sliced_size;
- int count = 0;
- for (const auto dim_info : input_tensor.shape()) {
- original_dimensions[count] = dim_info.size;
- ++count;
- }
+ *output_size = new_sliced_size * (*target_dimensions)[*adjustable_dimension];
+}
- target_dimensions[adjustable_dimension] = total_size / new_sliced_size;
+// Heuristic number based on measurements on
+// Intel(R) Core(TM) i7-4930K CPU @ 3.40GHz
+const tensorflow::int64 costPerFillIndex = 35;
- count = 0;
- for (int i = 0; i < input_tensor.shape().dims(); ++i) {
- dimension_ceiling[count] = tensorflow::int64(std::ceil(
- float(target_dimensions[count]) / float(original_dimensions[count])));
- if (count == 0)
- cumulative_dimensions[count] = 1;
- else
- cumulative_dimensions[count] =
- cumulative_dimensions[count - 1] * dimension_ceiling[count - 1];
- ++count;
- }
+enum class Mode {
+ kForward,
+ kGradient
+};
+
+// Computes either periodic_resample operation output or gradients for it,
+// depending on |mode|.
+// |original_shape| is always shape of input to periodic_resample operation.
+// |source_tensor| is either source for periodic_resample (for forward mode)
+// or gradients tensor.
+// |desired_shape| is always shape, provided by user, to which forward
+// propagation attempts resample input tensor.
+template <class InputDataT, Mode mode>
+void
+do_periodic_resample_op(tensorflow::OpKernelContext* context,
+ const tensorflow::TensorShape& original_shape,
+ const tensorflow::PartialTensorShape& desired_shape,
+ const tensorflow::Tensor& source_tensor) {
+ const int rank = source_tensor.dims();
+
+ // requires that the rank of the input tensor and length of the desired shape
+ // are equal
+ OP_REQUIRES(context, rank == desired_shape.dims(),
+ tensorflow::errors::InvalidArgument(
+ "periodic_resample expects the rank of the input tensor, ",
+ rank, ", to be the same as the length of the desired shape, ",
+ desired_shape.dims(), "."));
+
+ std::vector<tensorflow::int64> target_dimensions(rank);
+ tensorflow::int64 new_size = 0;
+ // index of adjustable dimension
+ int adjustable_dimension = 0;
+ process_desired_shape(context, original_shape, desired_shape.dim_sizes(),
+ &adjustable_dimension, &target_dimensions, &new_size);
// ensure that the new dimension is greater than zero
OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0,
@@ -160,11 +293,14 @@ template <class InputDataT,
"adjustable dimension, ",
adjustable_dimension, ", isn't greater than zero, ",
target_dimensions[adjustable_dimension], "."));
- for (int i = 0; i < rank; ++i) {
- output_shape.AddDim(target_dimensions[i]);
+ tensorflow::TensorShape output_shape;
+ if (mode == Mode::kForward) {
+ for (int i = 0; i < rank; ++i) {
+ output_shape.AddDim(target_dimensions[i]);
+ }
+ } else {
+ output_shape = original_shape;
}
- const auto new_size =
- new_sliced_size * target_dimensions[adjustable_dimension];
// Create an output tensor and attach it to the current context
tensorflow::Tensor* output_tensor = nullptr;
@@ -172,47 +308,73 @@ template <class InputDataT,
context->allocate_output(0, output_shape, &output_tensor));
auto output = output_tensor->flat<InputDataT>();
- // memory is allocated for these variables outside the inner loop for
- // efficiency (although, I could create a separate class scope for
- // this purpose instead)
- tensorflow::int64 result = 0;
- std::vector<tensorflow::int64> output_indices(target_dimensions.size());
+ // input is a strided array (last index is fastest, C-ordered)
+ auto input = source_tensor.flat<InputDataT>();
// Fill output tensor with periodically resampled input tensor values
- for (tensorflow::int64 output_index = 0; output_index < new_size;
- ++output_index) {
- output(output_index) = input(compute_input_index(
- &target_dimensions, output_index, original_dimensions,
- adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result,
- &output_indices, rank));
- }
+ InputIndexer input_indexer(target_dimensions, original_shape,
+ adjustable_dimension);
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ auto fill_output_tensor = [&input_indexer, &output, &input](
+ tensorflow::int64 start, tensorflow::int64 limit) {
+ InputIndexer local_indexer(input_indexer);
+ local_indexer.MoveToOutputIndex(start);
+ for (tensorflow::int64 output_index = start; output_index < limit;
+ ++output_index) {
+ if (mode == Mode::kForward) {
+ output(output_index) = input(local_indexer.linear_input_index());
+ } else {
+ output(local_indexer.linear_input_index()) = input(output_index);
+ }
+ local_indexer.IncrementOutputIndex();
+ }
+ };
+ ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers,
+ new_size, costPerFillIndex, fill_output_tensor);
}
+#define DATA_TYPE_SWITCH(data_type, context, CASE) \
+ switch (data_type) { \
+ CASE(float) \
+ CASE(double) \
+ CASE(tensorflow::int32) \
+ CASE(tensorflow::int64) \
+ default: \
+ context->CtxFailure(__FILE__, __LINE__, \
+ tensorflow::errors::InvalidArgument( \
+ "Unsuppored tensor elements type")); \
+ break; \
+ }
+
void create_output_tensor(
tensorflow::OpKernelContext* context,
const tensorflow::Tensor& input_tensor,
const tensorflow::DataType& input_tensor_type,
- const tensorflow::PartialTensorShape& desired_shape_tensor) {
- auto desired_shape = desired_shape_tensor.dim_sizes();
-
- // obligatory type switch
- switch (input_tensor_type) {
- case tensorflow::DataTypeToEnum<float>::value:
- fill_periodic_tensor<float>(context, desired_shape, input_tensor);
+ const tensorflow::PartialTensorShape& desired_shape) {
+#define CASE(type) \
+ case tensorflow::DataTypeToEnum<type>::value: \
+ do_periodic_resample_op<type, Mode::kForward>( \
+ context, input_tensor.shape(), desired_shape, input_tensor); \
break;
- case tensorflow::DataTypeToEnum<double>::value:
- fill_periodic_tensor<double>(context, desired_shape, input_tensor);
- break;
- case tensorflow::DataTypeToEnum<tensorflow::int32>::value:
- fill_periodic_tensor<tensorflow::int32>(context, desired_shape,
- input_tensor);
- break;
- case tensorflow::DataTypeToEnum<tensorflow::int64>::value:
- fill_periodic_tensor<tensorflow::int64>(context, desired_shape,
- input_tensor);
+
+ DATA_TYPE_SWITCH(input_tensor_type, context, CASE);
+#undef CASE
+}
+
+void create_grad_tensor(tensorflow::OpKernelContext* context,
+ const tensorflow::Tensor& grad_tensor,
+ const tensorflow::DataType& grad_tensor_type,
+ const tensorflow::TensorShape& original_shape,
+ const tensorflow::PartialTensorShape& desired_shape) {
+#define CASE(type) \
+ case tensorflow::DataTypeToEnum<type>::value: \
+ do_periodic_resample_op<type, Mode::kGradient>( \
+ context, original_shape, desired_shape, grad_tensor); \
break;
- default:;
- }
+
+ DATA_TYPE_SWITCH(grad_tensor_type, context, CASE);
+#undef CASE
}
} // namespace
@@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel {
tensorflow::PartialTensorShape desired_shape;
};
+class PeriodicResampleOpGrad : public tensorflow::OpKernel {
+ public:
+ explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context)
+ : tensorflow::OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("original_shape", &original_shape));
+ OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape));
+ }
+
+ void Compute(tensorflow::OpKernelContext* context) override {
+ const tensorflow::Tensor& grad_tensor = context->input(0);
+ const tensorflow::DataType grad_tensor_type = context->input_dtype(0);
+ create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape,
+ desired_shape);
+ }
+
+ private:
+ tensorflow::TensorShape original_shape;
+ tensorflow::PartialTensorShape desired_shape;
+};
+
#endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
index 82bd796956..fd38cd09b4 100644
--- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc
@@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample")
.Input("values: T")
.Attr("shape: shape")
.Output("output: T")
- .SetShapeFn(shape_inference::ExplicitShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+ shape_inference::ShapeHandle input_tensor_shape = c->input(0);
+ shape_inference::DimensionHandle num_input_elements =
+ c->NumElements(input_tensor_shape);
+ shape_inference::ShapeHandle result_shape_handle;
+ if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) {
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &result_shape_handle));
+ } else {
+ const int rank = c->Rank(input_tensor_shape);
+ std::vector<tensorflow::int64> target_dimensions(rank);
+ tensorflow::int64 new_sliced_size = 1;
+ int adjustable_dimension = 0;
+ for (int i = 0; i < rank; ++i) {
+ if (desired_shape.dim_size(i) < 1) {
+ adjustable_dimension = i;
+ } else {
+ target_dimensions[i] = desired_shape.dim_size(i);
+ new_sliced_size *= target_dimensions[i];
+ }
+ }
+ target_dimensions[adjustable_dimension] =
+ shape_inference::InferenceContext::Value(
+ num_input_elements) / new_sliced_size;
+ tensorflow::TensorShape result_shape;
+ for (int i = 0; i < rank; ++i) {
+ result_shape.AddDim(target_dimensions[i]);
+ }
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(
+ result_shape, &result_shape_handle));
+ }
+ c->set_output(0, result_shape_handle);
+ return Status::OK();
+ })
.Doc(R"doc(
Periodically resample elements of a tensor to conform to `shape`.
@@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in
)doc");
+
+REGISTER_OP("PeriodicResampleOpGrad")
+ .Attr("T: numbertype")
+ .Input("grad: T")
+ .Attr("original_shape: shape")
+ .Attr("desired_shape: shape")
+ .Output("grad_values: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ tensorflow::TensorShape original_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s));
+ c->set_output(0, s);
+ return Status::OK();
+});
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
new file mode 100644
index 0000000000..43b7c1799f
--- /dev/null
+++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ArrayOpsTest, PeriodicResample_ShapeFn) {
+ ShapeInferenceTestOp op("PeriodicResample");
+ // Case 1: output shape can be fully inferreed.
+ PartialTensorShape shape({4, 4, -1});
+ TensorShapeProto shape_proto;
+ shape.AsProto(&shape_proto);
+
+ TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample")
+ .Input({"values", 0, DT_INT32})
+ .Attr("shape", shape_proto)
+ .Finalize(&op.node_def));
+ INFER_OK(op, "[2,2,4]", "[4,4,1]");
+ // Case 2: output shape can not be inferred - report desired shape.
+ INFER_OK(op, "[2,2,?]", "[4,4,?]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index a25de55e18..31a6fe1d94 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import numpy
from tensorflow.contrib.periodic_resample import periodic_resample
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
with self.test_session():
- variables.global_variables_initializer().run()
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
'4, to be the same as the length of the desired shape, 3'):
periodic_resample(input_tensor, [None, 4, 4]).eval()
+ def testPeriodicResampleGradient(self):
+ desired_shape = numpy.array([4, 4, None])
+ result_shape = (4, 4, 1)
+ input_shape = (2, 2, 4)
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=input_shape)
+ output = periodic_resample(x, desired_shape)
+ error = gradient_checker.compute_gradient_error(
+ x, input_shape, output, result_shape)
+ self.assertLess(error, 1e-4)
+
+ def testPeriodicResampleShapeInference(self):
+ with self.test_session() as sess:
+ # Case 1: output shape can be fully inferreed.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertEqual(output.shape, [4, 4, 1])
+ # Case 2: output shape can not be inferred - report desired shape.
+ x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None))
+ output = periodic_resample(x, [4, 4, None])
+ self.assertTrue(output.shape.is_compatible_with([4, 4, None]))
+ self.assertEqual(output.shape[2].value, None)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
index 348623d8f8..470e300ccb 100644
--- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
+++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py
@@ -21,11 +21,17 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op
-from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample
+from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad
from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
# pylint: enable=unused-import
_periodic_resample_op = loader.load_op_library(
resource_loader.get_path_to_datafile('_periodic_resample_op.so'))
+
+@ops.RegisterGradient("PeriodicResample")
+def _periodic_resample_grad_cc(op, grad):
+ return periodic_resample_op_grad(
+ grad, op.inputs[0].shape, op.get_attr('shape'))
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
index b7a98c68e2..af3b2ad1b5 100644
--- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `ContribEstimatorPredictor`.
Args:
@@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
multi-headed models.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_filename_with_path=checkpoint_path))
input_alternative_key = (
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py
index d78d94c269..a725072e72 100644
--- a/tensorflow/contrib/predictor/core_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/core_estimator_predictor.py
@@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor):
estimator,
serving_input_receiver_fn,
output_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `CoreEstimatorPredictor`.
Args:
@@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
`None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
checkpoint_dir = estimator.model_dir
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_dir=checkpoint_dir))
feed_tensor_info = signature_def.inputs
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
index 6e77e934fe..f275bc15ad 100644
--- a/tensorflow/contrib/predictor/predictor_factories.py
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -30,7 +30,8 @@ def from_contrib_estimator(estimator,
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `tf.contrib.learn.Estimator`.
Args:
@@ -44,6 +45,7 @@ def from_contrib_estimator(estimator,
multi-headed models.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -62,13 +64,15 @@ def from_contrib_estimator(estimator,
prediction_input_fn,
input_alternative_key=input_alternative_key,
output_alternative_key=output_alternative_key,
- graph=graph)
+ graph=graph,
+ config=config)
def from_estimator(estimator,
serving_input_receiver_fn,
output_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `tf.python.estimator.Estimator`.
Args:
@@ -79,6 +83,7 @@ def from_estimator(estimator,
`None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -93,14 +98,19 @@ def from_estimator(estimator,
'tf.contrib.learn.Estimator. You likely want to call '
'from_contrib_estimator.')
return core_estimator_predictor.CoreEstimatorPredictor(
- estimator, serving_input_receiver_fn, output_key=output_key, graph=graph)
+ estimator,
+ serving_input_receiver_fn,
+ output_key=output_key,
+ graph=graph,
+ config=config)
def from_saved_model(export_dir,
signature_def_key=None,
signature_def=None,
tags=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `SavedModel` on disk.
Args:
@@ -115,6 +125,7 @@ def from_saved_model(export_dir,
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -128,4 +139,5 @@ def from_saved_model(export_dir,
signature_def_key=signature_def_key,
signature_def=signature_def,
tags=tags,
- graph=graph)
+ graph=graph,
+ config=config)
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
index 578d9424b2..a2ef1dc3af 100644
--- a/tensorflow/contrib/predictor/predictor_factories_test.py
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.predictor import predictor_factories
from tensorflow.contrib.predictor import testing_common
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import test
MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
@@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase):
"""Test loading from_saved_model with tags."""
predictor_factories.from_saved_model(self._export_dir, tags='serve')
+ def testFromSavedModelWithSessionConfig(self):
+ """Test loading from_saved_model with session config."""
+ predictor_factories.from_saved_model(
+ self._export_dir, config=config_pb2.ConfigProto())
+
def testFromSavedModelWithBadTags(self):
"""Test that loading fails for bad tags."""
bad_tags_regex = ('.*? could not be found in SavedModel')
@@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase):
predictor_factories.from_contrib_estimator(
estimator, input_fn, output_alternative_key='sum')
+ def testFromContribEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=False)
+ input_fn = testing_common.get_arithmetic_input_fn(core=False)
+ predictor_factories.from_contrib_estimator(
+ estimator, input_fn, output_alternative_key='sum',
+ config=config_pb2.ConfigProto())
+
def testFromContribEstimatorWithCoreEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=True)
input_fn = testing_common.get_arithmetic_input_fn(core=True)
@@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase):
input_fn = testing_common.get_arithmetic_input_fn(core=True)
predictor_factories.from_estimator(estimator, input_fn)
+ def testFromCoreEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=True)
+ input_fn = testing_common.get_arithmetic_input_fn(core=True)
+ predictor_factories.from_estimator(
+ estimator, input_fn, config=config_pb2.ConfigProto())
+
def testFromCoreEstimatorWithContribEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=False)
input_fn = testing_common.get_arithmetic_input_fn(core=False)
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index 0dbca0f813..95da6d04ed 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -121,7 +121,8 @@ class SavedModelPredictor(predictor.Predictor):
input_names=None,
output_names=None,
tags=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `CoreEstimatorPredictor`.
Args:
@@ -142,6 +143,7 @@ class SavedModelPredictor(predictor.Predictor):
the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Raises:
ValueError: If more than one of signature_def_key OR signature_def OR
(input_names AND output_names) is specified.
@@ -152,7 +154,7 @@ class SavedModelPredictor(predictor.Predictor):
self._graph = graph or ops.Graph()
with self._graph.as_default():
- self._session = session.Session()
+ self._session = session.Session(config=config)
loader.load(self._session, tags.split(','), export_dir)
if input_names is None:
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
index 3e9b1a0b8d..d45622174f 100644
--- a/tensorflow/contrib/proto/BUILD
+++ b/tensorflow/contrib/proto/BUILD
@@ -19,9 +19,7 @@ py_library(
py_library(
name = "proto_pip",
- data = [
- "//tensorflow/contrib/proto/python/kernel_tests:test_messages",
- ] + if_static(
+ data = if_static(
[],
otherwise = ["//tensorflow/contrib/proto/python/kernel_tests:libtestexample.so"],
),
diff --git a/tensorflow/contrib/proto/python/kernel_tests/BUILD b/tensorflow/contrib/proto/python/kernel_tests/BUILD
index a380a131f8..3f53ef1707 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/proto/python/kernel_tests/BUILD
@@ -4,33 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-# Much of the work in this BUILD file actually happens in the corresponding
-# build_defs.bzl, which creates an individual testcase for each example .pbtxt
-# file in this directory.
-#
-load(":build_defs.bzl", "decode_proto_test_suite")
-load(":build_defs.bzl", "encode_proto_test_suite")
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :decode_proto_op_tests.
-decode_proto_test_suite(
- name = "decode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# This expands to a tf_py_test for each test file.
-# It defines the test_suite :encode_proto_op_tests.
-encode_proto_test_suite(
- name = "encode_proto_tests",
- examples = glob(["*.pbtxt"]),
-)
-
-# Below here are tests that are not tied to an example text proto.
-filegroup(
- name = "test_messages",
- srcs = glob(["*.pbtxt"]),
-)
-
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
@@ -56,16 +29,62 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "decode_proto_op_test",
+ size = "small",
+ srcs = ["decode_proto_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
+tf_py_test(
+ name = "encode_proto_op_test",
+ size = "small",
+ srcs = ["encode_proto_op_test.py"],
+ additional_deps = [
+ ":py_test_deps",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/proto:proto",
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+ data = if_static(
+ [],
+ otherwise = [":libtestexample.so"],
+ ),
+ tags = [
+ "no_pip", # TODO(b/78026780)
+ "no_windows", # TODO(b/78028010)
+ ],
+)
+
py_library(
- name = "test_case",
- srcs = ["test_case.py"],
- deps = ["//tensorflow/python:client_testlib"],
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
+ ":test_example_proto_py",
+ "//tensorflow/python:client_testlib",
+ ],
)
py_library(
name = "py_test_deps",
deps = [
- ":test_case",
+ ":test_base",
":test_example_proto_py",
],
)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl b/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
deleted file mode 100644
index f425601691..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/build_defs.bzl
+++ /dev/null
@@ -1,89 +0,0 @@
-"""BUILD rules for generating file-driven proto test cases.
-
-The decode_proto_test_suite() and encode_proto_test_suite() rules take a list
-of text protos and generates a tf_py_test() for each one.
-"""
-
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
-load("//tensorflow:tensorflow.bzl", "register_extension_info")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
-
-def _test_name(test, path):
- return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
-
-def decode_proto_test_suite(name, examples):
- """Build the decode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("decode_proto", test_filename),
- srcs = ["decode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "decode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("decode_proto", test_filename)
- for test_filename in examples],
- )
-
-def encode_proto_test_suite(name, examples):
- """Build the encode_proto py_test for each test filename."""
- for test_filename in examples:
- tf_py_test(
- name = _test_name("encode_proto", test_filename),
- srcs = ["encode_proto_op_test.py"],
- size = "small",
- data = [test_filename] + if_static(
- [],
- otherwise = [":libtestexample.so"],
- ),
- main = "encode_proto_op_test.py",
- args = [
- "--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
- ],
- additional_deps = [
- ":py_test_deps",
- "//third_party/py/numpy",
- "//tensorflow/contrib/proto:proto",
- "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
- "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
- ],
- tags = [
- "no_pip", # TODO(b/78026780)
- "no_windows", # TODO(b/78028010)
- ],
- )
- native.test_suite(
- name = name,
- tests = [":" + _test_name("encode_proto", test_filename)
- for test_filename in examples],
- )
-
-register_extension_info(
- extension_name = "decode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:decode_example_.*",
- })
-
-register_extension_info(
- extension_name = "encode_proto_test_suite",
- label_regex_map = {
- "deps": "deps:encode_example_.*",
- })
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
index 5298342ee7..3b982864bc 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_fail_test.py
@@ -21,14 +21,14 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class DecodeProtoFailTest(test_case.ProtoOpTestCase):
+class DecodeProtoFailTest(test_base.ProtoOpTestBase):
"""Test failure cases for DecodeToProto."""
def _TestCorruptProtobuf(self, sanitize):
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
index d1c13c82bc..2a07794499 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test.py
@@ -23,24 +23,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
+
from google.protobuf import text_format
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import flags
from tensorflow.python.platform import test
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string('message_text_file', None,
- 'A file containing a text serialized TestCase protobuf.')
-
-class DecodeProtoOpTest(test_case.ProtoOpTestCase):
+class DecodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase):
def _compareValues(self, fd, vs, evs):
"""Compare lists/arrays of field values."""
@@ -203,10 +199,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
self._compareRepeatedPrimitiveValue(batch_shape, sizes, fields,
field_dict)
- def testBinary(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinary(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
@@ -217,10 +211,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'binary',
sanitize=False)
- def testBinaryDisordered(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testBinaryDisordered(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
@@ -232,10 +224,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
sanitize=False,
force_disordered=True)
- def testPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testPacked(self, case):
# Now try with the packed serialization.
# We test the packed representations by loading the same test cases
# using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
@@ -261,10 +251,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'binary',
sanitize=False)
- def testText(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testText(self, case):
# Note: float_format='.17g' is necessary to ensure preservation of
# doubles and floats in text format.
text_batch = [
@@ -281,10 +269,8 @@ class DecodeProtoOpTest(test_case.ProtoOpTestCase):
'text',
sanitize=False)
- def testSanitizerGood(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testSanitizerGood(self, case):
batch = [primitive.SerializeToString() for primitive in case.primitive]
self._runDecodeProtoTests(
case.field,
diff --git a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
deleted file mode 100644
index 4e31681907..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/defaut_values.TestCase.pbtxt
+++ /dev/null
@@ -1,94 +0,0 @@
-primitive {
- # No fields specified, so we get all defaults
-}
-shape: 1
-sizes: 0
-field {
- name: "double_default"
- dtype: DT_DOUBLE
- expected { double_value: 1.0 }
-}
-sizes: 0
-field {
- name: "float_default"
- dtype: DT_DOUBLE # Try casting the float field to double.
- expected { double_value: 2.0 }
-}
-sizes: 0
-field {
- name: "int64_default"
- dtype: DT_INT64
- expected { int64_value: 3 }
-}
-sizes: 0
-field {
- name: "uint64_default"
- dtype: DT_INT64
- expected { int64_value: 4 }
-}
-sizes: 0
-field {
- name: "int32_default"
- dtype: DT_INT32
- expected { int32_value: 5 }
-}
-sizes: 0
-field {
- name: "fixed64_default"
- dtype: DT_INT64
- expected { int64_value: 6 }
-}
-sizes: 0
-field {
- name: "fixed32_default"
- dtype: DT_INT32
- expected { int32_value: 7 }
-}
-sizes: 0
-field {
- name: "bool_default"
- dtype: DT_BOOL
- expected { bool_value: true }
-}
-sizes: 0
-field {
- name: "string_default"
- dtype: DT_STRING
- expected { string_value: "a" }
-}
-sizes: 0
-field {
- name: "bytes_default"
- dtype: DT_STRING
- expected { string_value: "a longer default string" }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT32
- expected { int32_value: -1 }
-}
-sizes: 0
-field {
- name: "sfixed32_default"
- dtype: DT_INT32
- expected { int32_value: 10 }
-}
-sizes: 0
-field {
- name: "sfixed64_default"
- dtype: DT_INT64
- expected { int64_value: 11 }
-}
-sizes: 0
-field {
- name: "sint32_default"
- dtype: DT_INT32
- expected { int32_value: 12 }
-}
-sizes: 0
-field {
- name: "sint64_default"
- dtype: DT_INT64
- expected { int64_value: 13 }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
index 30e58e6336..fb33660554 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test.py
@@ -26,11 +26,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
-from tensorflow.contrib.proto.python.kernel_tests import test_case
+from tensorflow.contrib.proto.python.kernel_tests import test_base
from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
from tensorflow.contrib.proto.python.ops import decode_proto_op
from tensorflow.contrib.proto.python.ops import encode_proto_op
@@ -45,7 +46,7 @@ flags.DEFINE_string('message_text_file', None,
'A file containing a text serialized TestCase protobuf.')
-class EncodeProtoOpTest(test_case.ProtoOpTestCase):
+class EncodeProtoOpTest(test_base.ProtoOpTestBase, parameterized.TestCase):
def testBadInputs(self):
# Invalid field name
@@ -139,10 +140,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
# loss of packing in the encoding).
self.assertEqual(in_buf, out_buf)
- def testRoundtrip(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtrip(self, case):
in_bufs = [primitive.SerializeToString() for primitive in case.primitive]
# np.array silently truncates strings if you don't specify dtype=object.
@@ -150,10 +149,8 @@ class EncodeProtoOpTest(test_case.ProtoOpTestCase):
return self._testRoundtrip(
in_bufs, 'tensorflow.contrib.proto.RepeatedPrimitiveValue', case.field)
- def testRoundtripPacked(self):
- with open(FLAGS.message_text_file, 'r') as fp:
- case = text_format.Parse(fp.read(), test_example_pb2.TestCase())
-
+ @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ def testRoundtripPacked(self, case):
# Now try with the packed serialization.
# We test the packed representations by loading the same test cases
# using PackedPrimitiveValue instead of RepeatedPrimitiveValue.
diff --git a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
deleted file mode 100644
index b170f89c0f..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/minmax.TestCase.pbtxt
+++ /dev/null
@@ -1,161 +0,0 @@
-primitive {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- uint64_value: 0
- uint64_value: 18446744073709551615
- int32_value: -2147483648
- int32_value: 2147483647
- fixed64_value: 0
- fixed64_value: 18446744073709551615
- fixed32_value: 0
- fixed32_value: 4294967295
- bool_value: false
- bool_value: true
- string_value: ""
- string_value: "I refer to the infinite."
- uint32_value: 0
- uint32_value: 4294967295
- sfixed32_value: -2147483648
- sfixed32_value: 2147483647
- sfixed64_value: -9223372036854775808
- sfixed64_value: 9223372036854775807
- sint32_value: -2147483648
- sint32_value: 2147483647
- sint64_value: -9223372036854775808
- sint64_value: 9223372036854775807
-}
-shape: 1
-sizes: 3
-sizes: 3
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-sizes: 2
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: -1.7976931348623158e+308
- double_value: 2.2250738585072014e-308
- double_value: 1.7976931348623158e+308
- }
-}
-field {
- name: "float_value"
- dtype: DT_FLOAT
- expected {
- float_value: -3.402823466e+38
- float_value: 1.175494351e-38
- float_value: 3.402823466e+38
- }
-}
-field {
- name: "int64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "uint64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1
- }
-}
-field {
- name: "int32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "fixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: 0
- int64_value: -1 # unsigned is 18446744073709551615
- }
-}
-field {
- name: "fixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: false
- bool_value: true
- }
-}
-field {
- name: "string_value"
- dtype: DT_STRING
- expected {
- string_value: ""
- string_value: "I refer to the infinite."
- }
-}
-field {
- name: "uint32_value"
- dtype: DT_INT32
- expected {
- int32_value: 0
- int32_value: -1 # unsigned is 4294967295
- }
-}
-field {
- name: "sfixed32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sfixed64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
-field {
- name: "sint32_value"
- dtype: DT_INT32
- expected {
- int32_value: -2147483648
- int32_value: 2147483647
- }
-}
-field {
- name: "sint64_value"
- dtype: DT_INT64
- expected {
- int64_value: -9223372036854775808
- int64_value: 9223372036854775807
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
deleted file mode 100644
index c664e52851..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/nested.TestCase.pbtxt
+++ /dev/null
@@ -1,16 +0,0 @@
-primitive {
- message_value {
- double_value: 23.5
- }
-}
-shape: 1
-sizes: 1
-field {
- name: "message_value"
- dtype: DT_STRING
- expected {
- message_value {
- double_value: 23.5
- }
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
deleted file mode 100644
index 125651d7ea..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/optional.TestCase.pbtxt
+++ /dev/null
@@ -1,20 +0,0 @@
-primitive {
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 0
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 0.0
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
deleted file mode 100644
index bc07efc8f3..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/promote_unsigned.TestCase.pbtxt
+++ /dev/null
@@ -1,29 +0,0 @@
-primitive {
- fixed32_value: 4294967295
- uint32_value: 4294967295
-}
-shape: 1
-sizes: 1
-field {
- name: "fixed32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 1
-field {
- name: "uint32_value"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295
- }
-}
-sizes: 0
-field {
- name: "uint32_default"
- dtype: DT_INT64
- expected {
- int64_value: 4294967295 # Comes from an explicitly-specified default
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
deleted file mode 100644
index 61c7ac53f7..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/ragged.TestCase.pbtxt
+++ /dev/null
@@ -1,32 +0,0 @@
-primitive {
- double_value: 23.5
- double_value: 123.0
- bool_value: true
-}
-primitive {
- double_value: 3.1
- bool_value: false
-}
-shape: 2
-sizes: 2
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 123.0
- double_value: 3.1
- double_value: 0.0
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
deleted file mode 100644
index f4828076d5..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/shaped_batch.TestCase.pbtxt
+++ /dev/null
@@ -1,62 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-primitive {
- double_value: 44.0
- bool_value: false
-}
-primitive {
- double_value: 3.14159
- bool_value: true
-}
-primitive {
- double_value: 1.414
- bool_value: true
-}
-primitive {
- double_value: -32.2
- bool_value: false
-}
-primitive {
- double_value: 0.0001
- bool_value: true
-}
-shape: 3
-shape: 2
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- double_value: 44.0
- double_value: 3.14159
- double_value: 1.414
- double_value: -32.2
- double_value: 0.0001
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- bool_value: false
- bool_value: true
- bool_value: true
- bool_value: false
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt b/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
deleted file mode 100644
index dc20ac147b..0000000000
--- a/tensorflow/contrib/proto/python/kernel_tests/simple.TestCase.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-primitive {
- double_value: 23.5
- bool_value: true
-}
-shape: 1
-sizes: 1
-sizes: 1
-field {
- name: "double_value"
- dtype: DT_DOUBLE
- expected {
- double_value: 23.5
- }
-}
-field {
- name: "bool_value"
- dtype: DT_BOOL
- expected {
- bool_value: true
- }
-}
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_base.py b/tensorflow/contrib/proto/python/kernel_tests/test_base.py
new file mode 100644
index 0000000000..1fc8c16786
--- /dev/null
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_base.py
@@ -0,0 +1,407 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Test case base for testing proto operations."""
+
+# Python3 preparedness imports.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes as ct
+import os
+
+from tensorflow.contrib.proto.python.kernel_tests import test_example_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.platform import test
+
+
+class ProtoOpTestBase(test.TestCase):
+ """Base class for testing proto decoding and encoding ops."""
+
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(ProtoOpTestBase, self).__init__(methodName)
+ lib = os.path.join(os.path.dirname(__file__), "libtestexample.so")
+ if os.path.isfile(lib):
+ ct.cdll.LoadLibrary(lib)
+
+ @staticmethod
+ def named_parameters():
+ return (
+ ("defaults", ProtoOpTestBase.defaults_test_case()),
+ ("minmax", ProtoOpTestBase.minmax_test_case()),
+ ("nested", ProtoOpTestBase.nested_test_case()),
+ ("optional", ProtoOpTestBase.optional_test_case()),
+ ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
+ ("ragged", ProtoOpTestBase.ragged_test_case()),
+ ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
+ ("simple", ProtoOpTestBase.simple_test_case()),
+ )
+
+ @staticmethod
+ def defaults_test_case():
+ test_case = test_example_pb2.TestCase()
+ test_case.primitive.add() # No fields specified, so we get all defaults.
+ test_case.shape.append(1)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "double_default"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(1.0)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "float_default"
+ field.dtype = types_pb2.DT_FLOAT
+ field.expected.float_value.append(2.0)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "int64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(3)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sfixed64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(11)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sint64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(13)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "fixed64_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(6)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "int32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(5)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sfixed32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(10)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "sint32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(12)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "fixed32_default"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(7)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "bool_default"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "string_default"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("a")
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "bytes_default"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("a longer default string")
+ return test_case
+
+ @staticmethod
+ def minmax_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(-1.7976931348623158e+308)
+ primitive.double_value.append(2.2250738585072014e-308)
+ primitive.double_value.append(1.7976931348623158e+308)
+ primitive.float_value.append(-3.402823466e+38)
+ primitive.float_value.append(1.175494351e-38)
+ primitive.float_value.append(3.402823466e+38)
+ primitive.int64_value.append(-9223372036854775808)
+ primitive.int64_value.append(9223372036854775807)
+ primitive.sfixed64_value.append(-9223372036854775808)
+ primitive.sfixed64_value.append(9223372036854775807)
+ primitive.sint64_value.append(-9223372036854775808)
+ primitive.sint64_value.append(9223372036854775807)
+ primitive.uint64_value.append(0)
+ primitive.uint64_value.append(18446744073709551615)
+ primitive.fixed64_value.append(0)
+ primitive.fixed64_value.append(18446744073709551615)
+ primitive.int32_value.append(-2147483648)
+ primitive.int32_value.append(2147483647)
+ primitive.sfixed32_value.append(-2147483648)
+ primitive.sfixed32_value.append(2147483647)
+ primitive.sint32_value.append(-2147483648)
+ primitive.sint32_value.append(2147483647)
+ primitive.uint32_value.append(0)
+ primitive.uint32_value.append(4294967295)
+ primitive.fixed32_value.append(0)
+ primitive.fixed32_value.append(4294967295)
+ primitive.bool_value.append(False)
+ primitive.bool_value.append(True)
+ primitive.string_value.append("")
+ primitive.string_value.append("I refer to the infinite.")
+ test_case.shape.append(1)
+ test_case.sizes.append(3)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(-1.7976931348623158e+308)
+ field.expected.double_value.append(2.2250738585072014e-308)
+ field.expected.double_value.append(1.7976931348623158e+308)
+ test_case.sizes.append(3)
+ field = test_case.field.add()
+ field.name = "float_value"
+ field.dtype = types_pb2.DT_FLOAT
+ field.expected.float_value.append(-3.402823466e+38)
+ field.expected.float_value.append(1.175494351e-38)
+ field.expected.float_value.append(3.402823466e+38)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "int64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sfixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(-9223372036854775808)
+ field.expected.int64_value.append(9223372036854775807)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "uint64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(0)
+ field.expected.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "fixed64_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(0)
+ field.expected.int64_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "int32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sfixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "sint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(-2147483648)
+ field.expected.int32_value.append(2147483647)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(0)
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT32
+ field.expected.int32_value.append(0)
+ field.expected.int32_value.append(-1)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(2)
+ field = test_case.field.add()
+ field.name = "string_value"
+ field.dtype = types_pb2.DT_STRING
+ field.expected.string_value.append("")
+ field.expected.string_value.append("I refer to the infinite.")
+ return test_case
+
+ @staticmethod
+ def nested_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ message_value = primitive.message_value.add()
+ message_value.double_value = 23.5
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "message_value"
+ field.dtype = types_pb2.DT_STRING
+ message_value = field.expected.message_value.add()
+ message_value.double_value = 23.5
+ return test_case
+
+ @staticmethod
+ def optional_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.bool_value.append(True)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(0.0)
+ return test_case
+
+ @staticmethod
+ def promote_unsigned_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.fixed32_value.append(4294967295)
+ primitive.uint32_value.append(4294967295)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ # Comes from an explicitly-specified default
+ test_case.sizes.append(0)
+ field = test_case.field.add()
+ field.name = "uint32_default"
+ field.dtype = types_pb2.DT_INT64
+ field.expected.int64_value.append(4294967295)
+ return test_case
+
+ @staticmethod
+ def ragged_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.double_value.append(123.0)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(3.1)
+ primitive.bool_value.append(False)
+ test_case.shape.append(2)
+ test_case.sizes.append(2)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ field.expected.double_value.append(123.0)
+ field.expected.double_value.append(3.1)
+ field.expected.double_value.append(0.0)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ return test_case
+
+ @staticmethod
+ def shaped_batch_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(44.0)
+ primitive.bool_value.append(False)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(3.14159)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(1.414)
+ primitive.bool_value.append(True)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(-32.2)
+ primitive.bool_value.append(False)
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(0.0001)
+ primitive.bool_value.append(True)
+ test_case.shape.append(3)
+ test_case.shape.append(2)
+ for _ in range(12):
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ field.expected.double_value.append(44.0)
+ field.expected.double_value.append(3.14159)
+ field.expected.double_value.append(1.414)
+ field.expected.double_value.append(-32.2)
+ field.expected.double_value.append(0.0001)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(True)
+ field.expected.bool_value.append(False)
+ field.expected.bool_value.append(True)
+ return test_case
+
+ @staticmethod
+ def simple_test_case():
+ test_case = test_example_pb2.TestCase()
+ primitive = test_case.primitive.add()
+ primitive.double_value.append(23.5)
+ primitive.bool_value.append(True)
+ test_case.shape.append(1)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "double_value"
+ field.dtype = types_pb2.DT_DOUBLE
+ field.expected.double_value.append(23.5)
+ test_case.sizes.append(1)
+ field = test_case.field.add()
+ field.name = "bool_value"
+ field.dtype = types_pb2.DT_BOOL
+ field.expected.bool_value.append(True)
+ return test_case
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index c83623ec94..27a933c0f9 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -6,7 +6,7 @@ inference. The details of the transformation implemented in this package is
described here [1].
This is done using the
-[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization).
+[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization).
Literature has shown that fixed point networks provide comparable performance to
floating point networks [2]. This is achieved by modeling the quantization
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 55479bf5f7..e3c4899830 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -121,7 +121,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
scaled_weight_tensor = math_ops.multiply(
weights, multiplier_tensor, name='mul_fold')
new_layer_tensor = _CloneWithNewOperands(
- match.layer_op, match.input_tensor, scaled_weight_tensor)
+ match.layer_op, match.input_tensor, scaled_weight_tensor,
+ match.batch_to_space_op)
if correction_recip is not None:
new_layer_tensor = math_ops.multiply(
@@ -149,6 +150,8 @@ def _FindFusedBatchNorms(graph):
_FusedBatchNormMatches.
"""
input_pattern = graph_matcher.OpTypePattern('*')
+ # In practice, the weight pattern can match a Variable or a SpaceToBatchND
+ # operation that follows a variable for atrous convolutions.
weight_pattern = graph_matcher.OpTypePattern('*')
gamma_pattern = graph_matcher.OpTypePattern('*')
beta_pattern = graph_matcher.OpTypePattern('*')
@@ -160,16 +163,27 @@ def _FindFusedBatchNorms(graph):
layer_pattern = graph_matcher.OpTypePattern(
'Conv2D|DepthwiseConv2dNative|MatMul',
inputs=[input_pattern, weight_pattern])
+ batch_to_space_pattern = graph_matcher.OpTypePattern(
+ 'BatchToSpaceND',
+ inputs=[
+ layer_pattern,
+ graph_matcher.OpTypePattern('*'),
+ graph_matcher.OpTypePattern('*')
+ ])
+ layer_output_pattern = graph_matcher.OneofPattern(
+ [layer_pattern, batch_to_space_pattern])
# MatMul has a Reshape between it and FusedBatchNorm.
matmul_reshape_pattern = graph_matcher.OpTypePattern(
- 'Reshape', inputs=[layer_pattern,
- graph_matcher.OpTypePattern('*')])
+ 'Reshape',
+ inputs=[layer_output_pattern,
+ graph_matcher.OpTypePattern('*')])
batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
- graph_matcher.OneofPattern([matmul_reshape_pattern, layer_pattern]),
- gamma_pattern, beta_pattern, mean_pattern, variance_pattern
+ graph_matcher.OneofPattern(
+ [matmul_reshape_pattern, layer_output_pattern]), gamma_pattern,
+ beta_pattern, mean_pattern, variance_pattern
])
matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[batch_norm_pattern,
@@ -192,6 +206,7 @@ def _FindFusedBatchNorms(graph):
moving_variance_tensor = None
bn_decay_mean_tensor = None
bn_decay_var_tensor = None
+ batch_to_space_op = None
layer_op = match_result.get_op(layer_pattern)
layer_tensor = match_result.get_tensor(layer_pattern)
bn_op = match_result.get_op(batch_norm_pattern)
@@ -213,6 +228,7 @@ def _FindFusedBatchNorms(graph):
if not output_tensor.consumers():
continue
+ batch_to_space_op = match_result.get_op(batch_to_space_pattern)
input_tensor = match_result.get_tensor(input_pattern)
weight_tensor = match_result.get_tensor(weight_pattern)
gamma_tensor = match_result.get_tensor(gamma_pattern)
@@ -276,7 +292,8 @@ def _FindFusedBatchNorms(graph):
moving_variance_tensor=moving_variance_tensor,
bn_decay_mean_tensor=bn_decay_mean_tensor,
bn_decay_var_tensor=bn_decay_var_tensor,
- batch_epsilon=batch_epsilon)
+ batch_epsilon=batch_epsilon,
+ batch_to_space_op=batch_to_space_op)
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
@@ -380,7 +397,8 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
return correction_scale, correction_recip, correction_offset
-def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
+def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor,
+ batch_to_space_op):
"""Clones layer_op with input_tensor and weight_tensor as new inputs."""
new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
if layer_op.type == 'Conv2D':
@@ -400,12 +418,25 @@ def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
transpose_b=layer_op.get_attr('transpose_b'),
name=new_layer_name)
elif layer_op.type == 'DepthwiseConv2dNative':
- return nn.depthwise_conv2d(
+ conv = nn.depthwise_conv2d(
input_tensor,
weight_tensor,
+ rate=layer_op.get_attr('dilations'),
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
name=new_layer_name)
+ # Copy the batch to space operation if we have a atrous convolution.
+ if batch_to_space_op:
+ batch_to_space_op = layer_op.outputs[0].consumers()[0]
+ # TODO(suharshs): It's hard to make this name match with the unfused name.
+ # Restructure this code to not rely on scope at all.
+ new_batch_to_space_name = batch_to_space_op.name.split('/')[-1] + '_Fold'
+ conv = array_ops.batch_to_space_nd(
+ conv,
+ batch_to_space_op.inputs[1],
+ batch_to_space_op.inputs[2],
+ name=new_batch_to_space_name)
+ return conv
else:
raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
@@ -617,7 +648,8 @@ def _GetBatchNormParams(graph, context, has_scaling):
moving_variance_tensor=moving_variance_tensor,
bn_decay_mean_tensor=bn_decay_mean_tensor,
bn_decay_var_tensor=bn_decay_var_tensor,
- batch_epsilon=batch_epsilon)
+ batch_epsilon=batch_epsilon,
+ batch_to_space_op=None)
def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
@@ -651,6 +683,11 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
'/BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
+ # Skip over the BatchToSpace operation in the case of atrous convolutions.
+ batch_to_space_op = None
+ if op_below.type == 'BatchToSpaceND':
+ batch_to_space_op = op_below
+ op_below = op_below.inputs[0].op
weights = op_below.inputs[1]
match = _GetBatchNormParams(
graph=graph, context=context, has_scaling=has_scaling)
@@ -691,7 +728,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
context + '/correction_mult')
mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
else:
- raise ValueError('Cannot handle operation of type: %s' % op_below.op)
+ raise ValueError('Cannot handle operation of type: %s' % op_below.type)
_AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
@@ -701,6 +738,13 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
context + '/BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
+ # Copy the batch to space operation if we have a atrous convolution.
+ if batch_to_space_op:
+ corrected_output = array_ops.batch_to_space_nd(
+ corrected_output,
+ batch_to_space_op.inputs[1],
+ batch_to_space_op.inputs[2],
+ name=batch_to_space_op.name + '_Fold')
if correction_offset is not None:
with ops.device(conv_or_fc_folded.device):
corrected_output = math_ops.multiply(correction_recip, corrected_output,
@@ -898,7 +942,8 @@ class _BatchNormMatch(object):
def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor, moving_mean_tensor, moving_variance_tensor,
- bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon):
+ bn_decay_mean_tensor, bn_decay_var_tensor, batch_epsilon,
+ batch_to_space_op):
self._layer_op = layer_op
self._bn_op = bn_op
self._output_tensor = output_tensor
@@ -913,6 +958,7 @@ class _BatchNormMatch(object):
self._bn_decay_mean_tensor = bn_decay_mean_tensor
self._bn_decay_var_tensor = bn_decay_var_tensor
self._batch_epsilon = batch_epsilon
+ self._batch_to_space_op = batch_to_space_op
@property
def layer_op(self):
@@ -969,3 +1015,7 @@ class _BatchNormMatch(object):
@property
def bn_decay_var_tensor(self):
return self._bn_decay_var_tensor
+
+ @property
+ def batch_to_space_op(self):
+ return self._batch_to_space_op
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index bfa9d3bf70..7c907ffd92 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -438,6 +438,90 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def testFoldDepthwiseConv2d(self):
self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)
+ def _TestFoldAtrousConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
+ fused_batch_norm, freeze_batch_norm_delay):
+ """Tests folding: inputs -> AtrousConv2d with batch norm -> Relu*.
+
+ Args:
+ relu: Callable that returns an Operation, a factory method for the Relu*.
+ relu_op_name: String, name of the Relu* operation.
+ with_bypass: Bool, when true there is an extra connection added from
+ inputs to just before Relu*.
+ has_scaling: Bool, when true the batch norm has scaling.
+ fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+ switches to using frozen mean and variance
+ """
+ g = ops.Graph()
+ with g.as_default():
+ batch_size, height, width = 5, 128, 128
+ inputs = array_ops.zeros((batch_size, height, width, 3))
+ dilation_rate = 2
+ activation_fn = None if with_bypass else relu
+ scope = 'test/test2' if with_bypass else 'test'
+ node = separable_conv2d(
+ inputs,
+ None, [3, 3],
+ rate=dilation_rate,
+ depth_multiplier=1.0,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(
+ scale=has_scaling, fused=fused_batch_norm),
+ scope=scope)
+ if with_bypass:
+ node = math_ops.add(inputs, node, name='test/Add')
+ relu(node, name='test/' + relu_op_name)
+
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
+
+ folded_mul = g.get_operation_by_name(scope + '/mul_fold')
+ self.assertEqual(folded_mul.type, 'Mul')
+ if fused_batch_norm:
+ scale_reshape_op_name = scope + '/BatchNorm_Fold/scale_reshape'
+ else:
+ scale_reshape_op_name = scope + '/scale_reshape'
+ self._AssertInputOpsAre(folded_mul,
+ [scope + '/correction_mult', scale_reshape_op_name])
+ self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold'])
+
+ scale_reshape = g.get_operation_by_name(scale_reshape_op_name)
+ self.assertEqual(scale_reshape.type, 'Reshape')
+ self._AssertInputOpsAre(scale_reshape, [
+ self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm),
+ scale_reshape_op_name + '/shape'
+ ])
+ self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])
+
+ folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
+ self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
+ self._AssertInputOpsAre(
+ folded_conv, [scope + '/mul_fold', scope + '/depthwise/SpaceToBatchND'])
+ if fused_batch_norm:
+ self._AssertOutputGoesToOps(folded_conv, g,
+ [scope + '/BatchToSpaceND_Fold'])
+ else:
+ self._AssertOutputGoesToOps(folded_conv, g,
+ [scope + '/depthwise/BatchToSpaceND_Fold'])
+
+ folded_add = g.get_operation_by_name(scope + '/add_fold')
+ self.assertEqual(folded_add.type, 'Add')
+ self._AssertInputOpsAre(folded_add, [
+ scope + '/correction_add',
+ self._BathNormBiasName(scope, fused_batch_norm)
+ ])
+ output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
+ self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
+ def testFoldAtrousConv2d(self):
+ self._RunTestOverParameters(self._TestFoldAtrousConv2d)
+
def _TestCompareFoldAndUnfolded(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm,
freeze_batch_norm_delay):
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index cbba72643f..4fc315d901 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -194,6 +194,8 @@ def _FindLayersToQuantize(graph):
/
conv|fc
|
+ [batch_to_space_nd]
+ |
[post_conv_correction]
|
biasadd|folded_bias
@@ -247,9 +249,21 @@ def _FindLayersToQuantize(graph):
],
ordered_inputs=False)
+ # For atrous convolutions a BatchToSpaceND will occur after the depthwise
+ # convolution.
+ batch_to_space_pattern = graph_matcher.OpTypePattern(
+ 'BatchToSpaceND',
+ inputs=[
+ layer_pattern,
+ graph_matcher.OpTypePattern('*'),
+ graph_matcher.OpTypePattern('*')
+ ])
+
+ layer_output_pattern = graph_matcher.OneofPattern(
+ [batch_to_space_pattern, layer_pattern])
folded_bias_mul_pattern = graph_matcher.OpTypePattern(
'Mul',
- inputs=[graph_matcher.OpTypePattern('*'), layer_pattern],
+ inputs=[graph_matcher.OpTypePattern('*'), layer_output_pattern],
ordered_inputs=False)
post_layer_op_correction_pattern = graph_matcher.OpTypePattern(
'Add',
@@ -264,28 +278,37 @@ def _FindLayersToQuantize(graph):
],
ordered_inputs=False)
+ # batch_norms with forced updates have an Identity operation at the end.
+ # TODO(suharshs): Find a way to easily skip extra Identity operations. The
+ # current issue is that doing so can often match patterns across many layers
+ # incorrectly.
+ batch_norm_identity = graph_matcher.OpTypePattern(
+ 'Identity', inputs=[folded_bias_add_pattern])
+
bias_add_pattern = graph_matcher.OpTypePattern(
- 'Add|BiasAdd', inputs=[layer_pattern, '*'], ordered_inputs=False)
+ 'Add|BiasAdd', inputs=[layer_output_pattern, '*'], ordered_inputs=False)
# The bias can come from the bias add or the folded bias add.
bypass_pattern = graph_matcher.OpTypePattern(
'Add',
inputs=[
graph_matcher.OneofPattern(
- [bias_add_pattern, folded_bias_add_pattern]), '*'
+ [bias_add_pattern, folded_bias_add_pattern, batch_norm_identity]),
+ '*'
],
ordered_inputs=False)
# The input to the activation can come from bias add, fold bias add, the
# bypasses.
# TODO(suharshs): We should ideally skip Identity operations instead of
- # treating them as an activation.
+ # treating them as activations.
activation_pattern = graph_matcher.OpTypePattern(
'|'.join(_ACTIVATION_TYPES) + '|Identity',
inputs=[
graph_matcher.OneofPattern([
bias_add_pattern,
folded_bias_add_pattern,
+ batch_norm_identity,
bypass_pattern,
])
])
@@ -373,14 +396,6 @@ def _FindLayersToQuantize(graph):
return layer_matches
-def _HasPostActivationBypass(activation_op):
- for activation_tensor in activation_op.outputs:
- for output_op in activation_tensor.consumers():
- if output_op.type == 'Add':
- return True
- return False
-
-
class _LayerMatch(object):
"""Contains all information related to a matched Layer."""
diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py
index 11d052d7f4..2944f964c7 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph.py
@@ -191,6 +191,7 @@ def experimental_create_training_graph(input_graph=None,
def experimental_create_eval_graph(input_graph=None,
weight_bits=8,
activation_bits=8,
+ quant_delay=None,
scope=None):
"""Rewrites an eval input_graph in place for simulated quantization.
@@ -209,6 +210,8 @@ def experimental_create_eval_graph(input_graph=None,
default graph.
weight_bits: Number of bits to use for quantizing weights.
activation_bits: Number of bits to use for quantizing activations.
+ quant_delay: Number of steps after which weights and activations are
+ quantized during eval.
scope: The scope to be transformed. If it's not None, only the ops which
are in this scope will be transformed.
@@ -221,4 +224,5 @@ def experimental_create_eval_graph(input_graph=None,
is_training=False,
weight_bits=weight_bits,
activation_bits=activation_bits,
+ quant_delay=quant_delay,
scope=scope)
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index db745aa562..31a2955ddb 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -276,6 +276,52 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
delay, use_resource)
+ def testQuantize_AtrousConvWithoutBatchNorm(self):
+ self._RunWithoutBatchNormTestOverParameters(
+ self._TestQuantize_AtrousConvWithoutBatchNorm)
+
+ def _TestQuantize_AtrousConvWithoutBatchNorm(
+ self, activation, activation_op_name, with_bypass, delay, use_resource):
+ """Tests quantization: inputs -> atrous conv no batch norm -> Activation.
+
+ Args:
+ activation: Callable that returns an Operation, a factory method for the
+ Activation.
+ activation_op_name: String, name of the Activation operation.
+ with_bypass: Bool, when true there is an extra connection added from
+ inputs to just before Activation.
+ delay: Int (optional), delay in number of steps until quantization starts.
+ use_resource: Bool, when true uses resource variables.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ variable_scope.get_variable_scope().set_use_resource(use_resource)
+ batch_size, height, width, depth = 5, 128, 128, 3
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ dilation_rate = 2
+ activation_fn = None if with_bypass else activation
+ scope = 'test/test2' if with_bypass else 'test'
+ node = separable_conv2d(
+ inputs,
+ None, [3, 3],
+ rate=dilation_rate,
+ depth_multiplier=1.0,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ scope=scope)
+ if with_bypass:
+ node = math_ops.add(inputs, node, name='test/Add')
+ node = activation(node, name='test/' + activation_op_name)
+ update_barrier = control_flow_ops.no_op(name='update_barrier')
+ with ops.control_dependencies([update_barrier]):
+ array_ops.identity(node, name='control_dependency')
+ quantize.Quantize(graph, True, quant_delay=delay)
+
+ self._AssertCorrectQuantizedGraphWithoutBatchNorm(
+ graph, scope, 'DepthwiseConv2dNative', activation_op_name, with_bypass,
+ delay, use_resource)
+
def _RunBatchNormTestOverParameters(self, test_fn):
# TODO(suharshs): Use parameterized test once OSS TF supports it.
parameters_list = [
@@ -543,6 +589,61 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph, scope, 'DepthwiseConv2dNative', activation_op_name,
with_bypass, delay, use_resource)
+ def testQuantize_AtrousConvWithBatchNorm(self):
+ self._RunBatchNormTestOverParameters(
+ self._TestQuantize_AtrousConvWithBatchNorm)
+
+ def _TestQuantize_AtrousConvWithBatchNorm(
+ self, activation, activation_op_name, with_bypass, delay,
+ fused_batch_norm, use_resource):
+ """Tests quantization: inputs -> atrous conv with batch norm -> Activation.
+
+ Args:
+ activation: Callable that returns an Operation, a factory method for the
+ Activation.
+ activation_op_name: String, name of the Activation operation.
+ with_bypass: Bool, when true there is an extra connection added from
+ inputs to just before Activation.
+ delay: Int (optional), delay in number of steps until quantization starts.
+ fused_batch_norm: Bool, when true use FusedBatchNorm.
+ use_resource: Bool, when true uses resource variables.
+ """
+ graph = ops.Graph()
+ with graph.as_default():
+ variable_scope.get_variable_scope().set_use_resource(use_resource)
+ batch_size, height, width, depth = 5, 128, 128, 3
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ dilation_rate = 2
+ scope = 'test/test2' if with_bypass else 'test'
+ node = separable_conv2d(
+ inputs,
+ None, [3, 3],
+ rate=dilation_rate,
+ depth_multiplier=1.0,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(fused_batch_norm),
+ scope=scope)
+
+ # Manually add a bypass (optional) and an activation.
+ if with_bypass:
+ node = math_ops.add(inputs, node, name='test/Add')
+
+ node = activation(node, name='test/' + activation_op_name)
+
+ update_barrier = control_flow_ops.no_op(name='update_barrier')
+ with ops.control_dependencies([update_barrier]):
+ array_ops.identity(node, name='control_dependency')
+
+ fold_batch_norms.FoldBatchNorms(graph, is_training=True)
+ quantize.Quantize(graph, True, quant_delay=delay)
+
+ self._AssertCorrectQuantizedGraphWithBatchNorm(
+ graph, scope, 'DepthwiseConv2dNative', activation_op_name,
+ with_bypass, delay, use_resource)
+
def _AssertIdempotent(self, graph):
# Ensure that calling the rewrite again doesn't change the graph.
graph_def_before = str(graph.as_graph_def())
@@ -553,8 +654,80 @@ class QuantizeTest(test_util.TensorFlowTestCase):
graph_def_after = str(graph.as_graph_def())
self.assertEqual(graph_def_before, graph_def_after)
- def _BatchNormParams(self, fused=False):
- return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
+ def testBatchNormForcedUpdates(self):
+ parameter_list = [
+ # (activation, activation_op_name, fused_batch_norm)
+ (nn_ops.relu6, 'Relu6', False),
+ (nn_ops.relu, 'Relu', False),
+ (array_ops.identity, 'Identity', False),
+ (nn_ops.relu6, 'Relu6', True),
+ (nn_ops.relu, 'Relu', True),
+ (array_ops.identity, 'Identity', True),
+ ]
+ for params in parameter_list:
+ self._TestBatchNormForcedUpdates(params[0], params[1], params[2], False)
+ self._TestBatchNormForcedUpdates(params[0], params[1], params[2], True)
+
+ def _TestBatchNormForcedUpdates(self, activation, activation_op_name,
+ fused_batch_norm, use_resource):
+ """post_activation bypass quantization should happen with forced updates."""
+ graph = ops.Graph()
+ with graph.as_default():
+ variable_scope.get_variable_scope().set_use_resource(use_resource)
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros((batch_size, height, width, depth))
+ input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32))
+ # Setting updates_collections to None forces updates adding an extra
+ # identity operation following batch norms.
+ bn_params = self._BatchNormParams(
+ fused=fused_batch_norm, force_updates=True)
+ conv = conv2d(
+ input1,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation,
+ normalizer_fn=batch_norm,
+ normalizer_params=bn_params,
+ scope='test/test')
+ bypass_tensor = math_ops.add(conv, input2, name='test/add')
+ # The output of the post_activation bypass will be another layer.
+ _ = conv2d(
+ bypass_tensor,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ normalizer_fn=batch_norm,
+ normalizer_params=bn_params,
+ activation_fn=activation,
+ scope='test/unused')
+
+ fold_batch_norms.FoldBatchNorms(graph, is_training=True)
+ quantize.Quantize(graph, is_training=True)
+
+ # Ensure that the bypass node is preceded by and followed by a
+ # FakeQuantWithMinMaxVar operation, since the output of the Add isn't an
+ # activation.
+ self.assertTrue('FakeQuantWithMinMaxVars' in
+ [c.type for c in bypass_tensor.consumers()])
+ self.assertTrue('FakeQuantWithMinMaxVars' in
+ [i.op.type for i in bypass_tensor.op.inputs])
+
+ with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
+ f.write(str(graph.as_graph_def()))
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ if force_updates:
+ params['updates_collections'] = None
+ return params
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md
index 3ff85faf61..79b015a916 100644
--- a/tensorflow/contrib/receptive_field/README.md
+++ b/tensorflow/contrib/receptive_field/README.md
@@ -6,6 +6,32 @@ region your output features depend on. Better yet, using the parameters computed
by the library, you can easily find the exact image region which is used to
compute each convnet feature.
+This library can be used to compute receptive field parameters of popular
+convnets:
+
+<center>
+
+convnet model | receptive field | effective stride | effective padding
+:-----------------: | :-------------: | :--------------: | :---------------:
+alexnet_v2 | 195 | 32 | 64
+vgg_16 | 212 | 32 | 90
+inception_v2 | 699 | 32 | 318
+inception_v3 | 1311 | 32 | 618
+inception_v4 | 2071 | 32 | 998
+inception_resnet_v2 | 3039 | 32 | 1482
+mobilenet_v1 | 315 | 32 | 126
+mobilenet_v1_075 | 315 | 32 | 126
+resnet_v1_50 | 483 | 32 | 241
+resnet_v1_101 | 1027 | 32 | 513
+resnet_v1_152 | 1507 | 32 | 753
+resnet_v1_200 | 1763 | 32 | 881
+
+</center>
+
+A comprehensive table with pre-computed receptive field parameters for different
+end-points, input resolutions, and other variants of these networks can be found
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md).
+
## Basic usage
The main function to be called is `compute_receptive_field_from_graph_def`,
@@ -96,9 +122,9 @@ The script will write to stdout the receptive field parameters for many variants
of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They
are also written to the file `/tmp/rf_benchmark_results.csv`.
-TODO: include here a plot for receptive field sizes of different convnets.
-
-TODO: include table/link to pre-computed RF parameters.
+A comprehensive table with pre-computed receptive field parameters for different
+networks can be found
+[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md).
## Compute RF parameters from a graph pbtxt
diff --git a/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md
new file mode 100644
index 0000000000..736fbef6e7
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md
@@ -0,0 +1,629 @@
+# Pre-computed receptive field parameters
+
+## Table with results
+
+The table below presents the receptive field parameters for several popular
+convolutional neural networks. These are computed using the models from the
+[TF-Slim
+repository](https://github.com/tensorflow/models/tree/master/research/slim),
+by using the [rf_benchmark
+script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py).
+
+Questions? See the [FAQ](#faq).
+
+CNN | resolution | end-point | RF | effective stride | effective padding
+:----------------------------: | :--------: | :------------------: | :--: | :--------------: | :---------------:
+alexnet_v2 | None | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | None | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | None | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | None | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | None | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | None | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | None | alexnet_v2/pool5 | 195 | 32 | 64
+alexnet_v2 | 224 | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | 224 | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | 224 | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | 224 | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | 224 | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | 224 | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | 224 | alexnet_v2/pool5 | 195 | 32 | 64
+alexnet_v2 | 321 | alexnet_v2/conv1 | 11 | 4 | 0
+alexnet_v2 | 321 | alexnet_v2/pool1 | 19 | 8 | 0
+alexnet_v2 | 321 | alexnet_v2/conv2 | 51 | 8 | 16
+alexnet_v2 | 321 | alexnet_v2/conv3 | 99 | 16 | 32
+alexnet_v2 | 321 | alexnet_v2/conv4 | 131 | 16 | 48
+alexnet_v2 | 321 | alexnet_v2/conv5 | 163 | 16 | 64
+alexnet_v2 | 321 | alexnet_v2/pool5 | 195 | 32 | 64
+vgg_a | None | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | None | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | None | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | None | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | None | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | None | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | None | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | None | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | None | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | None | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | None | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | None | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | None | vgg_a/pool5 | 150 | 32 | 59
+vgg_a | 224 | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | 224 | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | 224 | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | 224 | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | 224 | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | 224 | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | 224 | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | 224 | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | 224 | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | 224 | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | 224 | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | 224 | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | 224 | vgg_a/pool5 | 150 | 32 | 59
+vgg_a | 321 | vgg_a/conv1/conv1_1 | 3 | 1 | 1
+vgg_a | 321 | vgg_a/pool1 | 4 | 2 | 1
+vgg_a | 321 | vgg_a/conv2/conv2_1 | 8 | 2 | 3
+vgg_a | 321 | vgg_a/pool2 | 10 | 4 | 3
+vgg_a | 321 | vgg_a/conv3/conv3_1 | 18 | 4 | 7
+vgg_a | 321 | vgg_a/conv3/conv3_2 | 26 | 4 | 11
+vgg_a | 321 | vgg_a/pool3 | 30 | 8 | 11
+vgg_a | 321 | vgg_a/conv4/conv4_1 | 46 | 8 | 19
+vgg_a | 321 | vgg_a/conv4/conv4_2 | 62 | 8 | 27
+vgg_a | 321 | vgg_a/pool4 | 70 | 16 | 27
+vgg_a | 321 | vgg_a/conv5/conv5_1 | 102 | 16 | 43
+vgg_a | 321 | vgg_a/conv5/conv5_2 | 134 | 16 | 59
+vgg_a | 321 | vgg_a/pool5 | 150 | 32 | 59
+vgg_16 | None | vgg_16/conv1/conv1_1 | 3 | 1 | 1
+vgg_16 | None | vgg_16/pool1 | 6 | 2 | 2
+vgg_16 | None | vgg_16/conv2/conv2_1 | 10 | 2 | 4
+vgg_16 | None | vgg_16/pool2 | 16 | 4 | 6
+vgg_16 | None | vgg_16/conv3/conv3_1 | 24 | 4 | 10
+vgg_16 | None | vgg_16/conv3/conv3_2 | 32 | 4 | 14
+vgg_16 | None | vgg_16/pool3 | 44 | 8 | 18
+vgg_16 | None | vgg_16/conv4/conv4_1 | 60 | 8 | 26
+vgg_16 | None | vgg_16/conv4/conv4_2 | 76 | 8 | 34
+vgg_16 | None | vgg_16/pool4 | 100 | 16 | 42
+vgg_16 | None | vgg_16/conv5/conv5_1 | 132 | 16 | 58
+vgg_16 | None | vgg_16/conv5/conv5_2 | 164 | 16 | 74
+vgg_16 | None | vgg_16/pool5 | 212 | 32 | 90
+vgg_16 | 224 | vgg_16/conv1/conv1_1 | 3 | 1 | 1
+vgg_16 | 224 | vgg_16/pool1 | 6 | 2 | 2
+vgg_16 | 224 | vgg_16/conv2/conv2_1 | 10 | 2 | 4
+vgg_16 | 224 | vgg_16/pool2 | 16 | 4 | 6
+vgg_16 | 224 | vgg_16/conv3/conv3_1 | 24 | 4 | 10
+vgg_16 | 224 | vgg_16/conv3/conv3_2 | 32 | 4 | 14
+vgg_16 | 224 | vgg_16/pool3 | 44 | 8 | 18
+vgg_16 | 224 | vgg_16/conv4/conv4_1 | 60 | 8 | 26
+vgg_16 | 224 | vgg_16/conv4/conv4_2 | 76 | 8 | 34
+vgg_16 | 224 | vgg_16/pool4 | 100 | 16 | 42
+vgg_16 | 224 | vgg_16/conv5/conv5_1 | 132 | 16 | 58
+vgg_16 | 224 | vgg_16/conv5/conv5_2 | 164 | 16 | 74
+vgg_16 | 224 | vgg_16/pool5 | 212 | 32 | 90
+vgg_16 | 321 | vgg_16/conv1/conv1_1 | 3 | 1 | 1
+vgg_16 | 321 | vgg_16/pool1 | 6 | 2 | 2
+vgg_16 | 321 | vgg_16/conv2/conv2_1 | 10 | 2 | 4
+vgg_16 | 321 | vgg_16/pool2 | 16 | 4 | 6
+vgg_16 | 321 | vgg_16/conv3/conv3_1 | 24 | 4 | 10
+vgg_16 | 321 | vgg_16/conv3/conv3_2 | 32 | 4 | 14
+vgg_16 | 321 | vgg_16/pool3 | 44 | 8 | 18
+vgg_16 | 321 | vgg_16/conv4/conv4_1 | 60 | 8 | 26
+vgg_16 | 321 | vgg_16/conv4/conv4_2 | 76 | 8 | 34
+vgg_16 | 321 | vgg_16/pool4 | 100 | 16 | 42
+vgg_16 | 321 | vgg_16/conv5/conv5_1 | 132 | 16 | 58
+vgg_16 | 321 | vgg_16/conv5/conv5_2 | 164 | 16 | 74
+vgg_16 | 321 | vgg_16/pool5 | 212 | 32 | 90
+inception_v2 | None | Conv2d_1a_7x7 | 7 | 2 | None
+inception_v2 | None | MaxPool_2a_3x3 | 11 | 4 | None
+inception_v2 | None | Conv2d_2b_1x1 | 11 | 4 | None
+inception_v2 | None | Conv2d_2c_3x3 | 19 | 4 | None
+inception_v2 | None | MaxPool_3a_3x3 | 27 | 8 | None
+inception_v2 | None | Mixed_3b | 59 | 8 | None
+inception_v2 | None | Mixed_3c | 91 | 8 | None
+inception_v2 | None | Mixed_4a | 123 | 16 | None
+inception_v2 | None | Mixed_4b | 187 | 16 | None
+inception_v2 | None | Mixed_4c | 251 | 16 | None
+inception_v2 | None | Mixed_4d | 315 | 16 | None
+inception_v2 | None | Mixed_4e | 379 | 16 | None
+inception_v2 | None | Mixed_5a | 443 | 32 | None
+inception_v2 | None | Mixed_5b | 571 | 32 | None
+inception_v2 | None | Mixed_5c | 699 | 32 | None
+inception_v2 | 224 | Conv2d_1a_7x7 | 7 | 2 | 2
+inception_v2 | 224 | MaxPool_2a_3x3 | 11 | 4 | 2
+inception_v2 | 224 | Conv2d_2b_1x1 | 11 | 4 | 2
+inception_v2 | 224 | Conv2d_2c_3x3 | 19 | 4 | 6
+inception_v2 | 224 | MaxPool_3a_3x3 | 27 | 8 | 6
+inception_v2 | 224 | Mixed_3b | 59 | 8 | 22
+inception_v2 | 224 | Mixed_3c | 91 | 8 | 38
+inception_v2 | 224 | Mixed_4a | 123 | 16 | 46
+inception_v2 | 224 | Mixed_4b | 187 | 16 | 78
+inception_v2 | 224 | Mixed_4c | 251 | 16 | 110
+inception_v2 | 224 | Mixed_4d | 315 | 16 | 142
+inception_v2 | 224 | Mixed_4e | 379 | 16 | 174
+inception_v2 | 224 | Mixed_5a | 443 | 32 | 190
+inception_v2 | 224 | Mixed_5b | 571 | 32 | 254
+inception_v2 | 224 | Mixed_5c | 699 | 32 | 318
+inception_v2 | 321 | Conv2d_1a_7x7 | 7 | 2 | 3
+inception_v2 | 321 | MaxPool_2a_3x3 | 11 | 4 | 5
+inception_v2 | 321 | Conv2d_2b_1x1 | 11 | 4 | 5
+inception_v2 | 321 | Conv2d_2c_3x3 | 19 | 4 | 9
+inception_v2 | 321 | MaxPool_3a_3x3 | 27 | 8 | 13
+inception_v2 | 321 | Mixed_3b | 59 | 8 | 29
+inception_v2 | 321 | Mixed_3c | 91 | 8 | 45
+inception_v2 | 321 | Mixed_4a | 123 | 16 | 61
+inception_v2 | 321 | Mixed_4b | 187 | 16 | 93
+inception_v2 | 321 | Mixed_4c | 251 | 16 | 125
+inception_v2 | 321 | Mixed_4d | 315 | 16 | 157
+inception_v2 | 321 | Mixed_4e | 379 | 16 | 189
+inception_v2 | 321 | Mixed_5a | 443 | 32 | 221
+inception_v2 | 321 | Mixed_5b | 571 | 32 | 285
+inception_v2 | 321 | Mixed_5c | 699 | 32 | 349
+inception_v2-no-separable-conv | None | Conv2d_1a_7x7 | 7 | 2 | None
+inception_v2-no-separable-conv | None | MaxPool_2a_3x3 | 11 | 4 | None
+inception_v2-no-separable-conv | None | Conv2d_2b_1x1 | 11 | 4 | None
+inception_v2-no-separable-conv | None | Conv2d_2c_3x3 | 19 | 4 | None
+inception_v2-no-separable-conv | None | MaxPool_3a_3x3 | 27 | 8 | None
+inception_v2-no-separable-conv | None | Mixed_3b | 59 | 8 | None
+inception_v2-no-separable-conv | None | Mixed_3c | 91 | 8 | None
+inception_v2-no-separable-conv | None | Mixed_4a | 123 | 16 | None
+inception_v2-no-separable-conv | None | Mixed_4b | 187 | 16 | None
+inception_v2-no-separable-conv | None | Mixed_4c | 251 | 16 | None
+inception_v2-no-separable-conv | None | Mixed_4d | 315 | 16 | None
+inception_v2-no-separable-conv | None | Mixed_4e | 379 | 16 | None
+inception_v2-no-separable-conv | None | Mixed_5a | 443 | 32 | None
+inception_v2-no-separable-conv | None | Mixed_5b | 571 | 32 | None
+inception_v2-no-separable-conv | None | Mixed_5c | 699 | 32 | None
+inception_v2-no-separable-conv | 224 | Conv2d_1a_7x7 | 7 | 2 | 2
+inception_v2-no-separable-conv | 224 | MaxPool_2a_3x3 | 11 | 4 | 2
+inception_v2-no-separable-conv | 224 | Conv2d_2b_1x1 | 11 | 4 | 2
+inception_v2-no-separable-conv | 224 | Conv2d_2c_3x3 | 19 | 4 | 6
+inception_v2-no-separable-conv | 224 | MaxPool_3a_3x3 | 27 | 8 | 6
+inception_v2-no-separable-conv | 224 | Mixed_3b | 59 | 8 | 22
+inception_v2-no-separable-conv | 224 | Mixed_3c | 91 | 8 | 38
+inception_v2-no-separable-conv | 224 | Mixed_4a | 123 | 16 | 46
+inception_v2-no-separable-conv | 224 | Mixed_4b | 187 | 16 | 78
+inception_v2-no-separable-conv | 224 | Mixed_4c | 251 | 16 | 110
+inception_v2-no-separable-conv | 224 | Mixed_4d | 315 | 16 | 142
+inception_v2-no-separable-conv | 224 | Mixed_4e | 379 | 16 | 174
+inception_v2-no-separable-conv | 224 | Mixed_5a | 443 | 32 | 190
+inception_v2-no-separable-conv | 224 | Mixed_5b | 571 | 32 | 254
+inception_v2-no-separable-conv | 224 | Mixed_5c | 699 | 32 | 318
+inception_v2-no-separable-conv | 321 | Conv2d_1a_7x7 | 7 | 2 | 3
+inception_v2-no-separable-conv | 321 | MaxPool_2a_3x3 | 11 | 4 | 5
+inception_v2-no-separable-conv | 321 | Conv2d_2b_1x1 | 11 | 4 | 5
+inception_v2-no-separable-conv | 321 | Conv2d_2c_3x3 | 19 | 4 | 9
+inception_v2-no-separable-conv | 321 | MaxPool_3a_3x3 | 27 | 8 | 13
+inception_v2-no-separable-conv | 321 | Mixed_3b | 59 | 8 | 29
+inception_v2-no-separable-conv | 321 | Mixed_3c | 91 | 8 | 45
+inception_v2-no-separable-conv | 321 | Mixed_4a | 123 | 16 | 61
+inception_v2-no-separable-conv | 321 | Mixed_4b | 187 | 16 | 93
+inception_v2-no-separable-conv | 321 | Mixed_4c | 251 | 16 | 125
+inception_v2-no-separable-conv | 321 | Mixed_4d | 315 | 16 | 157
+inception_v2-no-separable-conv | 321 | Mixed_4e | 379 | 16 | 189
+inception_v2-no-separable-conv | 321 | Mixed_5a | 443 | 32 | 221
+inception_v2-no-separable-conv | 321 | Mixed_5b | 571 | 32 | 285
+inception_v2-no-separable-conv | 321 | Mixed_5c | 699 | 32 | 349
+inception_v3 | None | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v3 | None | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v3 | None | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v3 | None | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_v3 | None | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_v3 | None | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_v3 | None | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_v3 | None | Mixed_5b | 63 | 8 | 18
+inception_v3 | None | Mixed_5c | 95 | 8 | 34
+inception_v3 | None | Mixed_5d | 127 | 8 | 50
+inception_v3 | None | Mixed_6a | 159 | 16 | 58
+inception_v3 | None | Mixed_6b | 351 | 16 | 154
+inception_v3 | None | Mixed_6c | 543 | 16 | 250
+inception_v3 | None | Mixed_6d | 735 | 16 | 346
+inception_v3 | None | Mixed_6e | 927 | 16 | 442
+inception_v3 | None | Mixed_7a | 1055 | 32 | 490
+inception_v3 | None | Mixed_7b | 1183 | 32 | 554
+inception_v3 | None | Mixed_7c | 1311 | 32 | 618
+inception_v3 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v3 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v3 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v3 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_v3 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_v3 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_v3 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_v3 | 224 | Mixed_5b | 63 | 8 | 18
+inception_v3 | 224 | Mixed_5c | 95 | 8 | 34
+inception_v3 | 224 | Mixed_5d | 127 | 8 | 50
+inception_v3 | 224 | Mixed_6a | 159 | 16 | 58
+inception_v3 | 224 | Mixed_6b | 351 | 16 | 154
+inception_v3 | 224 | Mixed_6c | 543 | 16 | 250
+inception_v3 | 224 | Mixed_6d | 735 | 16 | 346
+inception_v3 | 224 | Mixed_6e | 927 | 16 | 442
+inception_v3 | 224 | Mixed_7a | 1055 | 32 | 490
+inception_v3 | 224 | Mixed_7b | 1183 | 32 | 554
+inception_v3 | 224 | Mixed_7c | 1311 | 32 | 618
+inception_v3 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v3 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v3 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v3 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_v3 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_v3 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_v3 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_v3 | 321 | Mixed_5b | 63 | 8 | 18
+inception_v3 | 321 | Mixed_5c | 95 | 8 | 34
+inception_v3 | 321 | Mixed_5d | 127 | 8 | 50
+inception_v3 | 321 | Mixed_6a | 159 | 16 | 58
+inception_v3 | 321 | Mixed_6b | 351 | 16 | 154
+inception_v3 | 321 | Mixed_6c | 543 | 16 | 250
+inception_v3 | 321 | Mixed_6d | 735 | 16 | 346
+inception_v3 | 321 | Mixed_6e | 927 | 16 | 442
+inception_v3 | 321 | Mixed_7a | 1055 | 32 | 490
+inception_v3 | 321 | Mixed_7b | 1183 | 32 | 554
+inception_v3 | 321 | Mixed_7c | 1311 | 32 | 618
+inception_v4 | None | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v4 | None | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v4 | None | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v4 | None | Mixed_3a | 15 | 4 | 2
+inception_v4 | None | Mixed_4a | 47 | 4 | 14
+inception_v4 | None | Mixed_5a | 55 | 8 | 14
+inception_v4 | None | Mixed_5b | 87 | 8 | 30
+inception_v4 | None | Mixed_5c | 119 | 8 | 46
+inception_v4 | None | Mixed_5d | 151 | 8 | 62
+inception_v4 | None | Mixed_5e | 183 | 8 | 78
+inception_v4 | None | Mixed_6a | 215 | 16 | 86
+inception_v4 | None | Mixed_6b | 407 | 16 | 182
+inception_v4 | None | Mixed_6c | 599 | 16 | 278
+inception_v4 | None | Mixed_6d | 791 | 16 | 374
+inception_v4 | None | Mixed_6e | 983 | 16 | 470
+inception_v4 | None | Mixed_6f | 1175 | 16 | 566
+inception_v4 | None | Mixed_6g | 1367 | 16 | 662
+inception_v4 | None | Mixed_6h | 1559 | 16 | 758
+inception_v4 | None | Mixed_7a | 1687 | 32 | 806
+inception_v4 | None | Mixed_7b | 1815 | 32 | 870
+inception_v4 | None | Mixed_7c | 1943 | 32 | 934
+inception_v4 | None | Mixed_7d | 2071 | 32 | 998
+inception_v4 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v4 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v4 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v4 | 224 | Mixed_3a | 15 | 4 | 2
+inception_v4 | 224 | Mixed_4a | 47 | 4 | 14
+inception_v4 | 224 | Mixed_5a | 55 | 8 | 14
+inception_v4 | 224 | Mixed_5b | 87 | 8 | 30
+inception_v4 | 224 | Mixed_5c | 119 | 8 | 46
+inception_v4 | 224 | Mixed_5d | 151 | 8 | 62
+inception_v4 | 224 | Mixed_5e | 183 | 8 | 78
+inception_v4 | 224 | Mixed_6a | 215 | 16 | 86
+inception_v4 | 224 | Mixed_6b | 407 | 16 | 182
+inception_v4 | 224 | Mixed_6c | 599 | 16 | 278
+inception_v4 | 224 | Mixed_6d | 791 | 16 | 374
+inception_v4 | 224 | Mixed_6e | 983 | 16 | 470
+inception_v4 | 224 | Mixed_6f | 1175 | 16 | 566
+inception_v4 | 224 | Mixed_6g | 1367 | 16 | 662
+inception_v4 | 224 | Mixed_6h | 1559 | 16 | 758
+inception_v4 | 224 | Mixed_7a | 1687 | 32 | 806
+inception_v4 | 224 | Mixed_7b | 1815 | 32 | 870
+inception_v4 | 224 | Mixed_7c | 1943 | 32 | 934
+inception_v4 | 224 | Mixed_7d | 2071 | 32 | 998
+inception_v4 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_v4 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_v4 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_v4 | 321 | Mixed_3a | 15 | 4 | 2
+inception_v4 | 321 | Mixed_4a | 47 | 4 | 14
+inception_v4 | 321 | Mixed_5a | 55 | 8 | 14
+inception_v4 | 321 | Mixed_5b | 87 | 8 | 30
+inception_v4 | 321 | Mixed_5c | 119 | 8 | 46
+inception_v4 | 321 | Mixed_5d | 151 | 8 | 62
+inception_v4 | 321 | Mixed_5e | 183 | 8 | 78
+inception_v4 | 321 | Mixed_6a | 215 | 16 | 86
+inception_v4 | 321 | Mixed_6b | 407 | 16 | 182
+inception_v4 | 321 | Mixed_6c | 599 | 16 | 278
+inception_v4 | 321 | Mixed_6d | 791 | 16 | 374
+inception_v4 | 321 | Mixed_6e | 983 | 16 | 470
+inception_v4 | 321 | Mixed_6f | 1175 | 16 | 566
+inception_v4 | 321 | Mixed_6g | 1367 | 16 | 662
+inception_v4 | 321 | Mixed_6h | 1559 | 16 | 758
+inception_v4 | 321 | Mixed_7a | 1687 | 32 | 806
+inception_v4 | 321 | Mixed_7b | 1815 | 32 | 870
+inception_v4 | 321 | Mixed_7c | 1943 | 32 | 934
+inception_v4 | 321 | Mixed_7d | 2071 | 32 | 998
+inception_resnet_v2 | None | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_resnet_v2 | None | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_resnet_v2 | None | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_resnet_v2 | None | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_resnet_v2 | None | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_resnet_v2 | None | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_resnet_v2 | None | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_resnet_v2 | None | Mixed_5b | 63 | 8 | 18
+inception_resnet_v2 | None | Mixed_6a | 415 | 16 | 186
+inception_resnet_v2 | None | PreAuxLogits | 2335 | 16 | 1146
+inception_resnet_v2 | None | Mixed_7a | 2399 | 32 | 1162
+inception_resnet_v2 | None | Conv2d_7b_1x1 | 3039 | 32 | 1482
+inception_resnet_v2 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_resnet_v2 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_resnet_v2 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_resnet_v2 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_resnet_v2 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_resnet_v2 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_resnet_v2 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_resnet_v2 | 224 | Mixed_5b | 63 | 8 | 18
+inception_resnet_v2 | 224 | Mixed_6a | 415 | 16 | 186
+inception_resnet_v2 | 224 | PreAuxLogits | 2335 | 16 | 1146
+inception_resnet_v2 | 224 | Mixed_7a | 2399 | 32 | 1162
+inception_resnet_v2 | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1482
+inception_resnet_v2 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_resnet_v2 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0
+inception_resnet_v2 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2
+inception_resnet_v2 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2
+inception_resnet_v2 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2
+inception_resnet_v2 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2
+inception_resnet_v2 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2
+inception_resnet_v2 | 321 | Mixed_5b | 63 | 8 | 18
+inception_resnet_v2 | 321 | Mixed_6a | 415 | 16 | 186
+inception_resnet_v2 | 321 | PreAuxLogits | 2335 | 16 | 1146
+inception_resnet_v2 | 321 | Mixed_7a | 2399 | 32 | 1162
+inception_resnet_v2 | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1482
+inception_resnet_v2-same | None | Conv2d_1a_3x3 | 3 | 2 | None
+inception_resnet_v2-same | None | Conv2d_2a_3x3 | 7 | 2 | None
+inception_resnet_v2-same | None | Conv2d_2b_3x3 | 11 | 2 | None
+inception_resnet_v2-same | None | MaxPool_3a_3x3 | 15 | 4 | None
+inception_resnet_v2-same | None | Conv2d_3b_1x1 | 15 | 4 | None
+inception_resnet_v2-same | None | Conv2d_4a_3x3 | 23 | 4 | None
+inception_resnet_v2-same | None | MaxPool_5a_3x3 | 31 | 8 | None
+inception_resnet_v2-same | None | Mixed_5b | 63 | 8 | None
+inception_resnet_v2-same | None | Mixed_6a | 415 | 16 | None
+inception_resnet_v2-same | None | PreAuxLogits | 2335 | 16 | None
+inception_resnet_v2-same | None | Mixed_7a | 2399 | 32 | None
+inception_resnet_v2-same | None | Conv2d_7b_1x1 | 3039 | 32 | None
+inception_resnet_v2-same | 224 | Conv2d_1a_3x3 | 3 | 2 | 0
+inception_resnet_v2-same | 224 | Conv2d_2a_3x3 | 7 | 2 | 2
+inception_resnet_v2-same | 224 | Conv2d_2b_3x3 | 11 | 2 | 4
+inception_resnet_v2-same | 224 | MaxPool_3a_3x3 | 15 | 4 | 4
+inception_resnet_v2-same | 224 | Conv2d_3b_1x1 | 15 | 4 | 4
+inception_resnet_v2-same | 224 | Conv2d_4a_3x3 | 23 | 4 | 8
+inception_resnet_v2-same | 224 | MaxPool_5a_3x3 | 31 | 8 | 8
+inception_resnet_v2-same | 224 | Mixed_5b | 63 | 8 | 24
+inception_resnet_v2-same | 224 | Mixed_6a | 415 | 16 | 192
+inception_resnet_v2-same | 224 | PreAuxLogits | 2335 | 16 | 1152
+inception_resnet_v2-same | 224 | Mixed_7a | 2399 | 32 | 1168
+inception_resnet_v2-same | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1488
+inception_resnet_v2-same | 321 | Conv2d_1a_3x3 | 3 | 2 | 1
+inception_resnet_v2-same | 321 | Conv2d_2a_3x3 | 7 | 2 | 3
+inception_resnet_v2-same | 321 | Conv2d_2b_3x3 | 11 | 2 | 5
+inception_resnet_v2-same | 321 | MaxPool_3a_3x3 | 15 | 4 | 7
+inception_resnet_v2-same | 321 | Conv2d_3b_1x1 | 15 | 4 | 7
+inception_resnet_v2-same | 321 | Conv2d_4a_3x3 | 23 | 4 | 11
+inception_resnet_v2-same | 321 | MaxPool_5a_3x3 | 31 | 8 | 15
+inception_resnet_v2-same | 321 | Mixed_5b | 63 | 8 | 31
+inception_resnet_v2-same | 321 | Mixed_6a | 415 | 16 | 207
+inception_resnet_v2-same | 321 | PreAuxLogits | 2335 | 16 | 1167
+inception_resnet_v2-same | 321 | Mixed_7a | 2399 | 32 | 1199
+inception_resnet_v2-same | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1519
+mobilenet_v1 | None | Conv2d_0 | 3 | 2 | None
+mobilenet_v1 | None | Conv2d_1_pointwise | 7 | 2 | None
+mobilenet_v1 | None | Conv2d_2_pointwise | 11 | 4 | None
+mobilenet_v1 | None | Conv2d_3_pointwise | 19 | 4 | None
+mobilenet_v1 | None | Conv2d_4_pointwise | 27 | 8 | None
+mobilenet_v1 | None | Conv2d_5_pointwise | 43 | 8 | None
+mobilenet_v1 | None | Conv2d_6_pointwise | 59 | 16 | None
+mobilenet_v1 | None | Conv2d_7_pointwise | 91 | 16 | None
+mobilenet_v1 | None | Conv2d_8_pointwise | 123 | 16 | None
+mobilenet_v1 | None | Conv2d_9_pointwise | 155 | 16 | None
+mobilenet_v1 | None | Conv2d_10_pointwise | 187 | 16 | None
+mobilenet_v1 | None | Conv2d_11_pointwise | 219 | 16 | None
+mobilenet_v1 | None | Conv2d_12_pointwise | 251 | 32 | None
+mobilenet_v1 | None | Conv2d_13_pointwise | 315 | 32 | None
+mobilenet_v1 | 224 | Conv2d_0 | 3 | 2 | 0
+mobilenet_v1 | 224 | Conv2d_1_pointwise | 7 | 2 | 2
+mobilenet_v1 | 224 | Conv2d_2_pointwise | 11 | 4 | 2
+mobilenet_v1 | 224 | Conv2d_3_pointwise | 19 | 4 | 6
+mobilenet_v1 | 224 | Conv2d_4_pointwise | 27 | 8 | 6
+mobilenet_v1 | 224 | Conv2d_5_pointwise | 43 | 8 | 14
+mobilenet_v1 | 224 | Conv2d_6_pointwise | 59 | 16 | 14
+mobilenet_v1 | 224 | Conv2d_7_pointwise | 91 | 16 | 30
+mobilenet_v1 | 224 | Conv2d_8_pointwise | 123 | 16 | 46
+mobilenet_v1 | 224 | Conv2d_9_pointwise | 155 | 16 | 62
+mobilenet_v1 | 224 | Conv2d_10_pointwise | 187 | 16 | 78
+mobilenet_v1 | 224 | Conv2d_11_pointwise | 219 | 16 | 94
+mobilenet_v1 | 224 | Conv2d_12_pointwise | 251 | 32 | 94
+mobilenet_v1 | 224 | Conv2d_13_pointwise | 315 | 32 | 126
+mobilenet_v1 | 321 | Conv2d_0 | 3 | 2 | 1
+mobilenet_v1 | 321 | Conv2d_1_pointwise | 7 | 2 | 3
+mobilenet_v1 | 321 | Conv2d_2_pointwise | 11 | 4 | 5
+mobilenet_v1 | 321 | Conv2d_3_pointwise | 19 | 4 | 9
+mobilenet_v1 | 321 | Conv2d_4_pointwise | 27 | 8 | 13
+mobilenet_v1 | 321 | Conv2d_5_pointwise | 43 | 8 | 21
+mobilenet_v1 | 321 | Conv2d_6_pointwise | 59 | 16 | 29
+mobilenet_v1 | 321 | Conv2d_7_pointwise | 91 | 16 | 45
+mobilenet_v1 | 321 | Conv2d_8_pointwise | 123 | 16 | 61
+mobilenet_v1 | 321 | Conv2d_9_pointwise | 155 | 16 | 77
+mobilenet_v1 | 321 | Conv2d_10_pointwise | 187 | 16 | 93
+mobilenet_v1 | 321 | Conv2d_11_pointwise | 219 | 16 | 109
+mobilenet_v1 | 321 | Conv2d_12_pointwise | 251 | 32 | 125
+mobilenet_v1 | 321 | Conv2d_13_pointwise | 315 | 32 | 157
+mobilenet_v1_075 | None | Conv2d_0 | 3 | 2 | None
+mobilenet_v1_075 | None | Conv2d_1_pointwise | 7 | 2 | None
+mobilenet_v1_075 | None | Conv2d_2_pointwise | 11 | 4 | None
+mobilenet_v1_075 | None | Conv2d_3_pointwise | 19 | 4 | None
+mobilenet_v1_075 | None | Conv2d_4_pointwise | 27 | 8 | None
+mobilenet_v1_075 | None | Conv2d_5_pointwise | 43 | 8 | None
+mobilenet_v1_075 | None | Conv2d_6_pointwise | 59 | 16 | None
+mobilenet_v1_075 | None | Conv2d_7_pointwise | 91 | 16 | None
+mobilenet_v1_075 | None | Conv2d_8_pointwise | 123 | 16 | None
+mobilenet_v1_075 | None | Conv2d_9_pointwise | 155 | 16 | None
+mobilenet_v1_075 | None | Conv2d_10_pointwise | 187 | 16 | None
+mobilenet_v1_075 | None | Conv2d_11_pointwise | 219 | 16 | None
+mobilenet_v1_075 | None | Conv2d_12_pointwise | 251 | 32 | None
+mobilenet_v1_075 | None | Conv2d_13_pointwise | 315 | 32 | None
+mobilenet_v1_075 | 224 | Conv2d_0 | 3 | 2 | 0
+mobilenet_v1_075 | 224 | Conv2d_1_pointwise | 7 | 2 | 2
+mobilenet_v1_075 | 224 | Conv2d_2_pointwise | 11 | 4 | 2
+mobilenet_v1_075 | 224 | Conv2d_3_pointwise | 19 | 4 | 6
+mobilenet_v1_075 | 224 | Conv2d_4_pointwise | 27 | 8 | 6
+mobilenet_v1_075 | 224 | Conv2d_5_pointwise | 43 | 8 | 14
+mobilenet_v1_075 | 224 | Conv2d_6_pointwise | 59 | 16 | 14
+mobilenet_v1_075 | 224 | Conv2d_7_pointwise | 91 | 16 | 30
+mobilenet_v1_075 | 224 | Conv2d_8_pointwise | 123 | 16 | 46
+mobilenet_v1_075 | 224 | Conv2d_9_pointwise | 155 | 16 | 62
+mobilenet_v1_075 | 224 | Conv2d_10_pointwise | 187 | 16 | 78
+mobilenet_v1_075 | 224 | Conv2d_11_pointwise | 219 | 16 | 94
+mobilenet_v1_075 | 224 | Conv2d_12_pointwise | 251 | 32 | 94
+mobilenet_v1_075 | 224 | Conv2d_13_pointwise | 315 | 32 | 126
+mobilenet_v1_075 | 321 | Conv2d_0 | 3 | 2 | 1
+mobilenet_v1_075 | 321 | Conv2d_1_pointwise | 7 | 2 | 3
+mobilenet_v1_075 | 321 | Conv2d_2_pointwise | 11 | 4 | 5
+mobilenet_v1_075 | 321 | Conv2d_3_pointwise | 19 | 4 | 9
+mobilenet_v1_075 | 321 | Conv2d_4_pointwise | 27 | 8 | 13
+mobilenet_v1_075 | 321 | Conv2d_5_pointwise | 43 | 8 | 21
+mobilenet_v1_075 | 321 | Conv2d_6_pointwise | 59 | 16 | 29
+mobilenet_v1_075 | 321 | Conv2d_7_pointwise | 91 | 16 | 45
+mobilenet_v1_075 | 321 | Conv2d_8_pointwise | 123 | 16 | 61
+mobilenet_v1_075 | 321 | Conv2d_9_pointwise | 155 | 16 | 77
+mobilenet_v1_075 | 321 | Conv2d_10_pointwise | 187 | 16 | 93
+mobilenet_v1_075 | 321 | Conv2d_11_pointwise | 219 | 16 | 109
+mobilenet_v1_075 | 321 | Conv2d_12_pointwise | 251 | 32 | 125
+mobilenet_v1_075 | 321 | Conv2d_13_pointwise | 315 | 32 | 157
+resnet_v1_50 | None | resnet_v1_50/block1 | 35 | 8 | None
+resnet_v1_50 | None | resnet_v1_50/block2 | 99 | 16 | None
+resnet_v1_50 | None | resnet_v1_50/block3 | 291 | 32 | None
+resnet_v1_50 | None | resnet_v1_50/block4 | 483 | 32 | None
+resnet_v1_50 | 224 | resnet_v1_50/block1 | 35 | 8 | 15
+resnet_v1_50 | 224 | resnet_v1_50/block2 | 99 | 16 | 47
+resnet_v1_50 | 224 | resnet_v1_50/block3 | 291 | 32 | 143
+resnet_v1_50 | 224 | resnet_v1_50/block4 | 483 | 32 | 239
+resnet_v1_50 | 321 | resnet_v1_50/block1 | 35 | 8 | 17
+resnet_v1_50 | 321 | resnet_v1_50/block2 | 99 | 16 | 49
+resnet_v1_50 | 321 | resnet_v1_50/block3 | 291 | 32 | 145
+resnet_v1_50 | 321 | resnet_v1_50/block4 | 483 | 32 | 241
+resnet_v1_101 | None | resnet_v1_101/block1 | 35 | 8 | None
+resnet_v1_101 | None | resnet_v1_101/block2 | 99 | 16 | None
+resnet_v1_101 | None | resnet_v1_101/block3 | 835 | 32 | None
+resnet_v1_101 | None | resnet_v1_101/block4 | 1027 | 32 | None
+resnet_v1_101 | 224 | resnet_v1_101/block1 | 35 | 8 | 15
+resnet_v1_101 | 224 | resnet_v1_101/block2 | 99 | 16 | 47
+resnet_v1_101 | 224 | resnet_v1_101/block3 | 835 | 32 | 415
+resnet_v1_101 | 224 | resnet_v1_101/block4 | 1027 | 32 | 511
+resnet_v1_101 | 321 | resnet_v1_101/block1 | 35 | 8 | 17
+resnet_v1_101 | 321 | resnet_v1_101/block2 | 99 | 16 | 49
+resnet_v1_101 | 321 | resnet_v1_101/block3 | 835 | 32 | 417
+resnet_v1_101 | 321 | resnet_v1_101/block4 | 1027 | 32 | 513
+resnet_v1_152 | None | resnet_v1_152/block1 | 35 | 8 | None
+resnet_v1_152 | None | resnet_v1_152/block2 | 163 | 16 | None
+resnet_v1_152 | None | resnet_v1_152/block3 | 1315 | 32 | None
+resnet_v1_152 | None | resnet_v1_152/block4 | 1507 | 32 | None
+resnet_v1_152 | 224 | resnet_v1_152/block1 | 35 | 8 | 15
+resnet_v1_152 | 224 | resnet_v1_152/block2 | 163 | 16 | 79
+resnet_v1_152 | 224 | resnet_v1_152/block3 | 1315 | 32 | 655
+resnet_v1_152 | 224 | resnet_v1_152/block4 | 1507 | 32 | 751
+resnet_v1_152 | 321 | resnet_v1_152/block1 | 35 | 8 | 17
+resnet_v1_152 | 321 | resnet_v1_152/block2 | 163 | 16 | 81
+resnet_v1_152 | 321 | resnet_v1_152/block3 | 1315 | 32 | 657
+resnet_v1_152 | 321 | resnet_v1_152/block4 | 1507 | 32 | 753
+resnet_v1_200 | None | resnet_v1_200/block1 | 35 | 8 | None
+resnet_v1_200 | None | resnet_v1_200/block2 | 419 | 16 | None
+resnet_v1_200 | None | resnet_v1_200/block3 | 1571 | 32 | None
+resnet_v1_200 | None | resnet_v1_200/block4 | 1763 | 32 | None
+resnet_v1_200 | 224 | resnet_v1_200/block1 | 35 | 8 | 15
+resnet_v1_200 | 224 | resnet_v1_200/block2 | 419 | 16 | 207
+resnet_v1_200 | 224 | resnet_v1_200/block3 | 1571 | 32 | 783
+resnet_v1_200 | 224 | resnet_v1_200/block4 | 1763 | 32 | 879
+resnet_v1_200 | 321 | resnet_v1_200/block1 | 35 | 8 | 17
+resnet_v1_200 | 321 | resnet_v1_200/block2 | 419 | 16 | 209
+resnet_v1_200 | 321 | resnet_v1_200/block3 | 1571 | 32 | 785
+resnet_v1_200 | 321 | resnet_v1_200/block4 | 1763 | 32 | 881
+resnet_v2_50 | None | resnet_v2_50/block1 | 35 | 8 | None
+resnet_v2_50 | None | resnet_v2_50/block2 | 99 | 16 | None
+resnet_v2_50 | None | resnet_v2_50/block3 | 291 | 32 | None
+resnet_v2_50 | None | resnet_v2_50/block4 | 483 | 32 | None
+resnet_v2_50 | 224 | resnet_v2_50/block1 | 35 | 8 | 15
+resnet_v2_50 | 224 | resnet_v2_50/block2 | 99 | 16 | 47
+resnet_v2_50 | 224 | resnet_v2_50/block3 | 291 | 32 | 143
+resnet_v2_50 | 224 | resnet_v2_50/block4 | 483 | 32 | 239
+resnet_v2_50 | 321 | resnet_v2_50/block1 | 35 | 8 | 17
+resnet_v2_50 | 321 | resnet_v2_50/block2 | 99 | 16 | 49
+resnet_v2_50 | 321 | resnet_v2_50/block3 | 291 | 32 | 145
+resnet_v2_50 | 321 | resnet_v2_50/block4 | 483 | 32 | 241
+resnet_v2_101 | None | resnet_v2_101/block1 | 35 | 8 | None
+resnet_v2_101 | None | resnet_v2_101/block2 | 99 | 16 | None
+resnet_v2_101 | None | resnet_v2_101/block3 | 835 | 32 | None
+resnet_v2_101 | None | resnet_v2_101/block4 | 1027 | 32 | None
+resnet_v2_101 | 224 | resnet_v2_101/block1 | 35 | 8 | 15
+resnet_v2_101 | 224 | resnet_v2_101/block2 | 99 | 16 | 47
+resnet_v2_101 | 224 | resnet_v2_101/block3 | 835 | 32 | 415
+resnet_v2_101 | 224 | resnet_v2_101/block4 | 1027 | 32 | 511
+resnet_v2_101 | 321 | resnet_v2_101/block1 | 35 | 8 | 17
+resnet_v2_101 | 321 | resnet_v2_101/block2 | 99 | 16 | 49
+resnet_v2_101 | 321 | resnet_v2_101/block3 | 835 | 32 | 417
+resnet_v2_101 | 321 | resnet_v2_101/block4 | 1027 | 32 | 513
+resnet_v2_152 | None | resnet_v2_152/block1 | 35 | 8 | None
+resnet_v2_152 | None | resnet_v2_152/block2 | 163 | 16 | None
+resnet_v2_152 | None | resnet_v2_152/block3 | 1315 | 32 | None
+resnet_v2_152 | None | resnet_v2_152/block4 | 1507 | 32 | None
+resnet_v2_152 | 224 | resnet_v2_152/block1 | 35 | 8 | 15
+resnet_v2_152 | 224 | resnet_v2_152/block2 | 163 | 16 | 79
+resnet_v2_152 | 224 | resnet_v2_152/block3 | 1315 | 32 | 655
+resnet_v2_152 | 224 | resnet_v2_152/block4 | 1507 | 32 | 751
+resnet_v2_152 | 321 | resnet_v2_152/block1 | 35 | 8 | 17
+resnet_v2_152 | 321 | resnet_v2_152/block2 | 163 | 16 | 81
+resnet_v2_152 | 321 | resnet_v2_152/block3 | 1315 | 32 | 657
+resnet_v2_152 | 321 | resnet_v2_152/block4 | 1507 | 32 | 753
+resnet_v2_200 | None | resnet_v2_200/block1 | 35 | 8 | None
+resnet_v2_200 | None | resnet_v2_200/block2 | 419 | 16 | None
+resnet_v2_200 | None | resnet_v2_200/block3 | 1571 | 32 | None
+resnet_v2_200 | None | resnet_v2_200/block4 | 1763 | 32 | None
+resnet_v2_200 | 224 | resnet_v2_200/block1 | 35 | 8 | 15
+resnet_v2_200 | 224 | resnet_v2_200/block2 | 419 | 16 | 207
+resnet_v2_200 | 224 | resnet_v2_200/block3 | 1571 | 32 | 783
+resnet_v2_200 | 224 | resnet_v2_200/block4 | 1763 | 32 | 879
+resnet_v2_200 | 321 | resnet_v2_200/block1 | 35 | 8 | 17
+resnet_v2_200 | 321 | resnet_v2_200/block2 | 419 | 16 | 209
+resnet_v2_200 | 321 | resnet_v2_200/block3 | 1571 | 32 | 785
+resnet_v2_200 | 321 | resnet_v2_200/block4 | 1763 | 32 | 881
+
+## FAQ
+
+### What does a resolution of 'None' mean?
+
+In this case, the input resolution is undefined. For most models, the receptive
+field parameters can be computed even without knowing the input resolution.
+
+### For some networks, effective_padding shows as 'None' (eg, for Inception_v2 or Mobilenet_v1 when input size is not specified). Why is that?
+
+This means that the padding for these networks depends on the input size. So,
+unless we know exactly the input image dimensionality to be used, it is not
+possible to determine the padding applied at the different layers. Look at the
+other entries where the input size is fixed; for those cases, effective_padding
+is not None.
+
+This happens due to Tensorflow's implementation of the 'SAME' padding mode,
+which may depend on the input feature map size to a given layer. For background
+on this, see [these notes from the TF
+documentation](https://www.tensorflow.org/versions/master/api_guides/python/nn#Notes_on_SAME_Convolution_Padding).
+
+Also, note that in this case the program is not able to check if the network is
+aligned (ie, it could be that the different paths from input to output have
+receptive fields which are not consistently centered at the same position in the
+input image).
+
+So you should be aware that such networks might not be aligned -- the program
+has no way of checking it when the padding cannot be determined.
+
+### The receptive field parameters for network X seem different from what I expected... maybe your calculation is incorrect?
+
+First, note that the results presented here are based on the tensorflow
+implementations from the [TF-Slim model
+library](https://github.com/tensorflow/models/tree/master/research/slim).
+
+So, it is possible that due to some implementation details the RF parameters are
+different.
+
+One common case of confusion is the TF-Slim Resnet implementation, which applies
+stride in the last residual unit of each block, instead of at the input
+activations in the first residual unit of each block (which is what is described
+in the Resnet paper) -- see [this
+comment](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py#L30).
+This makes the stride with respect to each convolution block potentially
+different. In this case, though, note that a
+[flag](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py#L150)
+may be used to recover the original striding convention.
+
+Second, it could be that we have a bug somewhere. While we include [many
+tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py)
+in our library, it is always possible that we missed something. If you suspect
+this is happening, please file a GitHub issue
+[here](https://github.com/tensorflow/tensorflow/issues).
diff --git a/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py
new file mode 100644
index 0000000000..4495d74bbf
--- /dev/null
+++ b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Simple script to convert CSV output from rf_benchmark to Markdown format.
+
+The input CSV should have the following fields:
+- CNN
+- input resolution
+- end_point
+- RF size hor
+- RF size ver
+- effective stride hor
+- effective stride ver
+- effective padding hor
+- effective padding ver
+
+Since usually in all cases the parameters in the horizontal and vertical
+directions are the same, this is assumed by this script, which only prints one
+of them to the Markdown file.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import csv
+import sys
+
+from tensorflow.python.platform import app
+
+cmd_args = None
+
+
+def main(unused_argv):
+ with open(cmd_args.markdown_path, 'w') as f:
+ # Write table header and field size.
+ f.write('CNN | resolution | end-point | RF | effective stride | '
+ 'effective padding|\n')
+ f.write(
+ ':--------------------: | :----------: | :---------------: | :-----: |'
+ ' :----: | :----:|\n')
+ with open(cmd_args.csv_path) as csvfile:
+ reader = csv.DictReader(csvfile)
+ for row in reader:
+ # Make sure horizontal and parameters are the same.
+ assert row['RF size hor'] == row['RF size ver']
+ assert row['effective stride hor'] == row['effective stride ver']
+ assert row['effective padding hor'] == row['effective padding ver']
+
+ f.write('%s|%s|%s|%s|%s|%s\n' %
+ (row['CNN'], row['input resolution'], row['end_point'],
+ row['RF size hor'], row['effective stride hor'],
+ row['effective padding hor']))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--csv_path',
+ type=str,
+ default='/tmp/rf.csv',
+ help='Path where CSV output of rf_benchmark was saved.')
+ parser.add_argument(
+ '--markdown_path',
+ type=str,
+ default='/tmp/rf.md',
+ help='Path where Markdown output will be saved.')
+ cmd_args, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py
index bc383a8034..0e3c46f17d 100644
--- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py
+++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import tf_logging as logging
_UNCHANGED_RF_LAYER_OPS = [
"Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor",
"FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu",
- "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2"
+ "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN"
]
# Different ways in which padding modes may be spelled.
diff --git a/tensorflow/contrib/recurrent/BUILD b/tensorflow/contrib/recurrent/BUILD
index b3cb04ce26..f9827f766d 100644
--- a/tensorflow/contrib/recurrent/BUILD
+++ b/tensorflow/contrib/recurrent/BUILD
@@ -102,5 +102,8 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = ["nopip"],
+ tags = [
+ "nopip",
+ "optonly",
+ ],
)
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 4eb5c920b3..2a84629080 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -118,7 +118,6 @@ cuda_py_tests(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
"//tensorflow/python:rnn",
"//tensorflow/python:rnn_cell",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 67f31785b5..cb437f2a2f 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -58,6 +58,10 @@ See @{$python/contrib.rnn} guide.
@@Conv3DLSTMCell
@@HighwayWrapper
@@GLSTMCell
+@@SRUCell
+@@IndRNNCell
+@@IndyGRUCell
+@@IndyLSTMCell
<!--RNNCell wrappers-->
@@AttentionCellWrapper
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index b8840a8f24..85f0f8ced9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import numpy as np
@@ -35,7 +34,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
@@ -117,6 +115,27 @@ class RNNCellTest(test.TestCase):
})
self.assertEqual(res[0].shape, (1, 2))
+ def testIndRNNCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.IndRNNCell(2)
+ g, _ = cell(x, m)
+ self.assertEqual([
+ "root/ind_rnn_cell/%s_w:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s_u:0" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/ind_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
+ ], [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ self.assertEqual(res[0].shape, (1, 2))
+
def testGRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -145,6 +164,34 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testIndyGRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.185265, 0.17704]])
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyGRUCell with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.IndyGRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.155127, 0.157328]])
+
def testSRUCell(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -345,6 +392,72 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_mem0)
self.assertAllClose(res[2], expected_mem1)
+ def testIndyLSTMCell(self):
+ for dtype in [dtypes.float16, dtypes.float32]:
+ np_dtype = dtype.as_numpy_dtype
+ with self.test_session(graph=ops.Graph()) as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2], dtype=dtype)
+ state_0 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ state_1 = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ cell = rnn_cell_impl.MultiRNNCell(
+ [contrib_rnn_cell.IndyLSTMCell(2) for _ in range(2)])
+ self.assertEqual(cell.dtype, None)
+ self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name)
+ self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name)
+ cell.get_config() # Should not throw an error
+ g, (out_state_0, out_state_1) = cell(x, (state_0, state_1))
+ # Layer infers the input type.
+ self.assertEqual(cell.dtype, dtype.name)
+ expected_variable_names = [
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_0/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_w:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s_u:0" %
+ rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ "root/multi_rnn_cell/cell_1/indy_lstm_cell/%s:0" %
+ rnn_cell_impl._BIAS_VARIABLE_NAME
+ ]
+ self.assertEqual(expected_variable_names,
+ [v.name for v in cell.trainable_variables])
+ self.assertFalse(cell.non_trainable_variables)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state_0, out_state_1], {
+ x.name: np.array([[1., 1.]]),
+ state_0[0].name: 0.1 * np.ones([1, 2]),
+ state_0[1].name: 0.1 * np.ones([1, 2]),
+ state_1[0].name: 0.1 * np.ones([1, 2]),
+ state_1[1].name: 0.1 * np.ones([1, 2]),
+ })
+ self.assertEqual(len(res), 3)
+ variables = variables_lib.global_variables()
+ self.assertEqual(expected_variable_names, [v.name for v in variables])
+ # Only check the range of outputs as this is just a smoke test.
+ self.assertAllInRange(res[0], -1.0, 1.0)
+ self.assertAllInRange(res[1], -1.0, 1.0)
+ self.assertAllInRange(res[2], -1.0, 1.0)
+ with variable_scope.variable_scope(
+ "other", initializer=init_ops.constant_initializer(0.5)):
+ # Test IndyLSTMCell with input_size != num_units.
+ x = array_ops.zeros([1, 3], dtype=dtype)
+ state = (array_ops.zeros([1, 2], dtype=dtype),) * 2
+ g, out_state = contrib_rnn_cell.IndyLSTMCell(2)(x, state)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g, out_state], {
+ x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+ state[0].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ state[1].name: 0.1 * np.ones([1, 2], dtype=np_dtype),
+ })
+ self.assertEqual(len(res), 2)
+
def testLSTMCell(self):
with self.test_session() as sess:
num_units = 8
@@ -443,7 +556,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWrapperCheckpointing(self):
for wrapper_type in [
rnn_cell_impl.DropoutWrapper,
@@ -935,50 +1048,6 @@ class DropoutWrapperTest(test.TestCase):
self.assertAllClose(res0[1].h, res1[1].h)
-class SlimRNNCellTest(test.TestCase):
-
- def testBasicRNNCell(self):
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 2])
- m = array_ops.zeros([1, 2])
- my_cell = functools.partial(basic_rnn_cell, num_units=2)
- # pylint: disable=protected-access
- g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
- # pylint: enable=protected-access
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g], {
- x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])
- })
- self.assertEqual(res[0].shape, (1, 2))
-
- def testBasicRNNCellMatch(self):
- batch_size = 32
- input_size = 100
- num_units = 10
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- inputs = random_ops.random_uniform((batch_size, input_size))
- _, initial_state = basic_rnn_cell(inputs, None, num_units)
- rnn_cell = rnn_cell_impl.BasicRNNCell(num_units)
- outputs, state = rnn_cell(inputs, initial_state)
- variable_scope.get_variable_scope().reuse_variables()
- my_cell = functools.partial(basic_rnn_cell, num_units=num_units)
- # pylint: disable=protected-access
- slim_cell = rnn_cell_impl._SlimRNNCell(my_cell)
- # pylint: enable=protected-access
- slim_outputs, slim_state = slim_cell(inputs, initial_state)
- self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
- self.assertEqual(slim_state.get_shape(), state.get_shape())
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([slim_outputs, slim_state, outputs, state])
- self.assertAllClose(res[0], res[2])
- self.assertAllClose(res[1], res[3])
-
-
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
if inputs is not None:
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index be99a5d67a..1c20d88fe4 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -921,7 +921,7 @@ class LSTMTest(test.TestCase):
# Smoke test, this should not raise an error
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDynamicRNNWithTupleStates(self):
num_units = 3
input_size = 5
@@ -997,7 +997,7 @@ class LSTMTest(test.TestCase):
self.assertAllEqual(array_ops.stack(outputs_static), outputs_dynamic)
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDynamicRNNWithNestedTupleStates(self):
num_units = 3
input_size = 5
@@ -1285,7 +1285,7 @@ class LSTMTest(test.TestCase):
"Comparing individual variable gradients iteration %d" % i)
self.assertAllEqual(a, b)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDynamicEquivalentToStaticRNN(self):
self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)
self._testDynamicEquivalentToStaticRNN(use_sequence_length=False)
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index b12e2cd5ed..1816b469ee 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -23,6 +23,7 @@ import math
from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
@@ -30,6 +31,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl # pylint: disable=unused-import
@@ -3050,3 +3052,343 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
+
+
+class IndRNNCell(rnn_cell_impl.LayerRNNCell):
+ """Independently Recurrent Neural Network (IndRNN) cell
+ (cf. https://arxiv.org/abs/1803.04831).
+
+ Args:
+ num_units: int, The number of units in the RNN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units])
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=init_ops.zeros_initializer(dtype=self.dtype))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """IndRNN: output = new_state = act(W * input + u * state + B)."""
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w) + (
+ state * self._kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+ output = self._activation(gate_inputs)
+ return output, output
+
+
+class IndyGRUCell(rnn_cell_impl.LayerRNNCell):
+ r"""Independently Gated Recurrent Unit cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to GRUCell,
+ yet with the \(U_r\), \(U_z\), and \(U\) matrices in equations 5, 6, and
+ 8 of http://arxiv.org/abs/1406.1078 respectively replaced by diagonal
+ matrices, i.e. a Hadamard product with a single vector:
+
+ $$r_j = \sigma\left([\mathbf W_r\mathbf x]_j +
+ [\mathbf u_r\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$z_j = \sigma\left([\mathbf W_z\mathbf x]_j +
+ [\mathbf u_z\circ \mathbf h_{(t-1)}]_j\right)$$
+ $$\tilde{h}^{(t)}_j = \phi\left([\mathbf W \mathbf x]_j +
+ [\mathbf u \circ \mathbf r \circ \mathbf h_{(t-1)}]_j\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyGRU
+ node sees only its own state, as opposed to seeing all states in the same
+ layer.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+
+ Args:
+ num_units: int, The number of units in the GRU cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ kernel_initializer: (optional) The initializer to use for the weight
+ matrices applied to the input.
+ bias_initializer: (optional) The initializer to use for the bias.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+
+ def __init__(self,
+ num_units,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None,
+ dtype=None):
+ super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._gate_kernel_w = self.add_variable(
+ "gates/%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 2 * self._num_units],
+ initializer=self._kernel_initializer)
+ self._gate_kernel_u = self.add_variable(
+ "gates/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 2 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._gate_bias = self.add_variable(
+ "gates/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[2 * self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.constant_initializer(1.0, dtype=self.dtype)))
+ self._candidate_kernel_w = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, self._num_units],
+ initializer=self._kernel_initializer)
+ self._candidate_kernel_u = self.add_variable(
+ "candidate/%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._candidate_bias = self.add_variable(
+ "candidate/%s" % rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.zeros_initializer(dtype=self.dtype)))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Gated recurrent unit (GRU) with nunits cells."""
+
+ gate_inputs = math_ops.matmul(inputs, self._gate_kernel_w) + (
+ gen_array_ops.tile(state, [1, 2]) * self._gate_kernel_u)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+ value = math_ops.sigmoid(gate_inputs)
+ r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+ r_state = r * state
+
+ candidate = math_ops.matmul(inputs, self._candidate_kernel_w) + (
+ r_state * self._candidate_kernel_u)
+ candidate = nn_ops.bias_add(candidate, self._candidate_bias)
+
+ c = self._activation(candidate)
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
+
+class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
+ r"""Basic IndyLSTM recurrent network cell.
+
+ Based on IndRNNs (https://arxiv.org/abs/1803.04831) and similar to
+ BasicLSTMCell, yet with the \(U_f\), \(U_i\), \(U_o\) and \(U_c\)
+ matrices in
+ https://en.wikipedia.org/wiki/Long_short-term_memory#LSTM_with_a_forget_gate
+ replaced by diagonal matrices, i.e. a Hadamard product with a single vector:
+
+ $$f_t = \sigma_g\left(W_f x_t + u_f \circ h_{t-1} + b_f\right)$$
+ $$i_t = \sigma_g\left(W_i x_t + u_i \circ h_{t-1} + b_i\right)$$
+ $$o_t = \sigma_g\left(W_o x_t + u_o \circ h_{t-1} + b_o\right)$$
+ $$c_t = f_t \circ c_{t-1} +
+ i_t \circ \sigma_c\left(W_c x_t + u_c \circ h_{t-1} + b_c\right)$$
+
+ where \(\circ\) denotes the Hadamard operator. This means that each IndyLSTM
+ node sees only its own state \(h\) and \(c\), as opposed to seeing all
+ states in the same layer.
+
+ We add forget_bias (default: 1) to the biases of the forget gate in order to
+ reduce the scale of forgetting in the beginning of the training.
+
+ It does not allow cell clipping, a projection layer, and does not
+ use peep-hole connections: it is the basic baseline.
+
+ For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
+ that follows.
+
+ TODO(gonnet): Write a paper describing this and add a reference here.
+ """
+
+ def __init__(self,
+ num_units,
+ forget_bias=1.0,
+ activation=None,
+ reuse=None,
+ kernel_initializer=None,
+ bias_initializer=None,
+ name=None,
+ dtype=None):
+ """Initialize the IndyLSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ forget_bias: float, The bias added to forget gates (see above).
+ Must set to `0.0` manually when restoring from CudnnLSTM-trained
+ checkpoints.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ kernel_initializer: (optional) The initializer to use for the weight
+ matrix applied to the inputs.
+ bias_initializer: (optional) The initializer to use for the bias.
+ name: String, the name of the layer. Layers with the same name will
+ share weights, but to avoid mistakes we require reuse=True in such
+ cases.
+ dtype: Default dtype of the layer (default of `None` means use the type
+ of the first input). Required when `build` is called before `call`.
+ """
+ super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+ self._activation = activation or math_ops.tanh
+ self._kernel_initializer = kernel_initializer
+ self._bias_initializer = bias_initializer
+
+ @property
+ def state_size(self):
+ return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def build(self, inputs_shape):
+ if inputs_shape[1].value is None:
+ raise ValueError(
+ "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ # pylint: disable=protected-access
+ self._kernel_w = self.add_variable(
+ "%s_w" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_depth, 4 * self._num_units],
+ initializer=self._kernel_initializer)
+ self._kernel_u = self.add_variable(
+ "%s_u" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[1, 4 * self._num_units],
+ initializer=init_ops.random_uniform_initializer(
+ minval=-1, maxval=1, dtype=self.dtype))
+ self._bias = self.add_variable(
+ rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[4 * self._num_units],
+ initializer=(self._bias_initializer
+ if self._bias_initializer is not None else
+ init_ops.zeros_initializer(dtype=self.dtype)))
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Independent Long short-term memory cell (IndyLSTM).
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size, input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size, num_units]`.
+
+ Returns:
+ A pair containing the new hidden state, and the new state (a
+ `LSTMStateTuple`).
+ """
+ sigmoid = math_ops.sigmoid
+ one = constant_op.constant(1, dtype=dtypes.int32)
+ c, h = state
+
+ gate_inputs = math_ops.matmul(inputs, self._kernel_w)
+ gate_inputs += gen_array_ops.tile(h, [1, 4]) * self._kernel_u
+ gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = array_ops.split(
+ value=gate_inputs, num_or_size_splits=4, axis=one)
+
+ forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
+ # Note that using `add` and `multiply` instead of `+` and `*` gives a
+ # performance improvement. So using those at the cost of readability.
+ add = math_ops.add
+ multiply = math_ops.multiply
+ new_c = add(
+ multiply(c, sigmoid(add(f, forget_bias_tensor))),
+ multiply(sigmoid(i), self._activation(j)))
+ new_h = multiply(self._activation(new_c), sigmoid(o))
+
+ new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+ return new_h, new_state
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 184144f64a..c7fbeea310 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -250,7 +250,7 @@ class BeamSearchDecoder(decoder.Decoder):
```
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
encoder_outputs, multiplier=beam_width)
- tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch(
+ tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
sequence_length, multiplier=beam_width)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index e69725ff8a..f58268eff5 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import abc
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -182,19 +183,20 @@ def dynamic_decode(decoder,
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
- def _is_xla_tensor(tensor):
- try:
- op = tensor.op
- except AttributeError:
- return False
- if control_flow_util.IsInXLAContext(op):
- return True
- return False
-
with variable_scope.variable_scope(scope, "decoder") as varscope:
- # Properly cache variable values inside the while_loop
- if varscope.caching_device is None:
- varscope.set_caching_device(lambda op: op.device)
+ # Determine context types.
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
+ in_while_loop = (
+ control_flow_util.GetContainingWhileContext(ctxt) is not None)
+ # Properly cache variable values inside the while_loop.
+ # Don't set a caching device when running in a loop, since it is possible
+ # that train steps could be wrapped in a tf.while_loop. In that scenario
+ # caching prevents forward computations in loop iterations from re-reading
+ # the updated weights.
+ if not context.executing_eagerly() and not in_while_loop:
+ if varscope.caching_device is None:
+ varscope.set_caching_device(lambda op: op.device)
if maximum_iterations is not None:
maximum_iterations = ops.convert_to_tensor(
@@ -208,9 +210,6 @@ def dynamic_decode(decoder,
decoder.output_dtype,
decoder.batch_size)
- is_xla = False
- if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
- is_xla = True
if is_xla and maximum_iterations is None:
raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
index 03d6da7765..f10d78259a 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py
@@ -147,7 +147,7 @@ class SpectralOpsTest(test.TestCase):
inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8,
fft_length=16, frame_step=8)
expected_length = (stft.shape[0] - 1) * 8 + 8
- self.assertAllEqual([None], inverse_stft.shape.as_list())
+ self.assertAllEqual([256], inverse_stft.shape.as_list())
self.assertAllEqual([expected_length], inverse_stft.eval().shape)
def test_stft_and_inverse_stft(self):
diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py
index 9a3603b6a9..7d6289532a 100644
--- a/tensorflow/contrib/signal/python/kernel_tests/test_util.py
+++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py
@@ -39,6 +39,7 @@ def grappler_optimize(graph, fetches=None, rewriter_config=None):
"""
if rewriter_config is None:
rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.min_graph_nodes = -1
if fetches is not None:
for fetch in fetches:
graph.add_to_collection('train_op', fetch)
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index 94fc12ca81..2c97834523 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -26,7 +26,6 @@ import time
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as variables_lib
-from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
from tensorflow.core.protobuf import saver_pb2
@@ -34,9 +33,9 @@ from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
@@ -89,8 +88,8 @@ class EvaluationTest(test.TestCase):
self._predictions, self._scale = TestModel(self._inputs)
def testFinalOpsOnEvaluationLoop(self):
- value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
- self._labels)
+ value_op, update_op = metrics.accuracy(
+ labels=self._labels, predictions=self._predictions)
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
# Create checkpoint and log directories:
@@ -136,9 +135,10 @@ class EvaluationTest(test.TestCase):
self.assertTrue(obj.hook_was_run)
def _create_names_to_metrics(self, predictions, labels):
- accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels)
- accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1,
- labels)
+ accuracy0, update_op0 = metrics.accuracy(
+ labels=labels, predictions=predictions)
+ accuracy1, update_op1 = metrics.accuracy(
+ labels=labels, predictions=predictions + 1)
names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1}
names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
@@ -198,8 +198,8 @@ class EvaluationTest(test.TestCase):
predictions_limited = input.limit_epochs(self._predictions, num_epochs=1)
labels_limited = input.limit_epochs(self._labels, num_epochs=1)
- value_op, update_op = metric_ops.streaming_accuracy(
- predictions_limited, labels_limited)
+ value_op, update_op = metrics.accuracy(
+ labels=labels_limited, predictions=predictions_limited)
init_op = control_flow_ops.group(variables.global_variables_initializer(),
variables.local_variables_initializer())
@@ -241,7 +241,7 @@ class SingleEvaluationTest(test.TestCase):
checkpoint_path = os.path.join(self.get_temp_dir(),
'this_file_doesnt_exist')
log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
- with self.assertRaises(errors.NotFoundError):
+ with self.assertRaises(ValueError):
evaluation.evaluate_once('', checkpoint_path, log_dir)
def _prepareCheckpoint(self, checkpoint_path):
@@ -260,8 +260,8 @@ class SingleEvaluationTest(test.TestCase):
self._prepareCheckpoint(checkpoint_path)
# Next, determine the metric to evaluate:
- value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
- self._labels)
+ value_op, update_op = metrics.accuracy(
+ labels=self._labels, predictions=self._predictions)
# Run the evaluation and verify the results:
accuracy_value = evaluation.evaluate_once(
@@ -276,8 +276,8 @@ class SingleEvaluationTest(test.TestCase):
self._prepareCheckpoint(checkpoint_path)
# Next, determine the metric to evaluate:
- value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
- self._labels)
+ value_op, update_op = metrics.accuracy(
+ labels=self._labels, predictions=self._predictions)
dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py
index 9305c6a11c..85918bf850 100644
--- a/tensorflow/contrib/solvers/python/ops/linear_equations.py
+++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py
@@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import linalg_ops
def conjugate_gradient(operator,
diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD
index 30be14c10c..0b8fc0cdc6 100644
--- a/tensorflow/contrib/stat_summarizer/BUILD
+++ b/tensorflow/contrib/stat_summarizer/BUILD
@@ -31,5 +31,8 @@ tf_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
],
- tags = ["no_windows"],
+ tags = [
+ "no_windows",
+ "notap", # TODO(b/80546574): test is flaky
+ ],
)
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index f1ef218e74..3e41e3d0b4 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -81,6 +81,19 @@ class EagerFileTest(test_util.TensorFlowTestCase):
# test here that we're calling them correctly.
self.assertTrue(gfile.Exists(logdir))
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ training_util.get_or_create_global_step()
+ logdir = self.get_temp_dir()
+ with summary_ops.create_file_writer(
+ logdir, max_queue=0,
+ name='t0').as_default(), summary_ops.always_record_summaries():
+ summary_ops.generic('tensor', 1, '')
+ summary_ops.scalar('scalar', 2.0)
+ summary_ops.histogram('histogram', [1.0])
+ summary_ops.image('image', [[[[1.0]]]])
+ summary_ops.audio('audio', [[1.0]], 1.0, 1)
+
def testDefunSummarys(self):
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
index e893e1d1c8..d8236a0a6f 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py
@@ -21,10 +21,10 @@ import numpy as np
from tensorflow.contrib import losses
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
-from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES
@@ -38,12 +38,13 @@ def _top_k_generator(k):
targets = math_ops.to_int32(targets)
if targets.get_shape().ndims > 1:
targets = array_ops.squeeze(targets, axis=[1])
- return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k))
+ return metrics.mean(nn.in_top_k(probabilities, targets, k))
return _top_k
def _accuracy(predictions, targets, weights=None):
- return metric_ops.streaming_accuracy(predictions, targets, weights=weights)
+ return metrics.accuracy(
+ labels=targets, predictions=predictions, weights=weights)
def _r2(probabilities, targets, weights=None):
@@ -53,7 +54,7 @@ def _r2(probabilities, targets, weights=None):
squares_residuals = math_ops.reduce_sum(
math_ops.square(targets - probabilities), 0)
score = 1 - math_ops.reduce_sum(squares_residuals / squares_total)
- return metric_ops.streaming_mean(score, weights=weights)
+ return metrics.mean(score, weights=weights)
def _squeeze_and_onehot(targets, depth):
@@ -62,7 +63,7 @@ def _squeeze_and_onehot(targets, depth):
def _sigmoid_entropy(probabilities, targets, weights=None):
- return metric_ops.streaming_mean(
+ return metrics.mean(
losses.sigmoid_cross_entropy(probabilities,
_squeeze_and_onehot(
targets,
@@ -71,7 +72,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None):
def _softmax_entropy(probabilities, targets, weights=None):
- return metric_ops.streaming_mean(
+ return metrics.mean(
losses.sparse_softmax_cross_entropy(probabilities,
math_ops.to_int32(targets)),
weights=weights)
@@ -82,7 +83,7 @@ def _predictions(predictions, unused_targets, **unused_kwargs):
def _class_log_loss(probabilities, targets, weights=None):
- return metric_ops.streaming_mean(
+ return metrics.mean(
losses.log_loss(probabilities,
_squeeze_and_onehot(targets,
array_ops.shape(probabilities)[1])),
@@ -90,34 +91,36 @@ def _class_log_loss(probabilities, targets, weights=None):
def _precision(predictions, targets, weights=None):
- return metric_ops.streaming_precision(predictions, targets, weights=weights)
+ return metrics.precision(
+ labels=targets, predictions=predictions, weights=weights)
def _precision_at_thresholds(predictions, targets, weights=None):
- return metric_ops.streaming_precision_at_thresholds(
- array_ops.slice(predictions, [0, 1], [-1, 1]),
- targets,
- np.arange(
- 0, 1, 0.01, dtype=np.float32),
+ return metrics.precision_at_thresholds(
+ labels=targets,
+ predictions=array_ops.slice(predictions, [0, 1], [-1, 1]),
+ thresholds=np.arange(0, 1, 0.01, dtype=np.float32),
weights=weights)
def _recall(predictions, targets, weights=None):
- return metric_ops.streaming_recall(predictions, targets, weights=weights)
+ return metrics.recall(
+ labels=targets, predictions=predictions, weights=weights)
def _recall_at_thresholds(predictions, targets, weights=None):
- return metric_ops.streaming_recall_at_thresholds(
- array_ops.slice(predictions, [0, 1], [-1, 1]),
- targets,
- np.arange(
- 0, 1, 0.01, dtype=np.float32),
+ return metrics.recall_at_thresholds(
+ labels=targets,
+ predictions=array_ops.slice(predictions, [0, 1], [-1, 1]),
+ thresholds=np.arange(0, 1, 0.01, dtype=np.float32),
weights=weights)
def _auc(probs, targets, weights=None):
- return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]),
- targets, weights=weights)
+ return metrics.auc(
+ labels=targets,
+ predictions=array_ops.slice(probs, [0, 1], [-1, 1]),
+ weights=weights)
_EVAL_METRICS = {
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index 3f6b4cdc9a..6507546ee9 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -106,6 +106,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:png_internal",
"//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index a5d8b061b6..adda0b758b 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -49,7 +49,6 @@ tf_cuda_cc_test(
tf_custom_op_library(
name = "python/ops/_trt_engine_op.so",
srcs = [
- "ops/trt_calib_op.cc",
"ops/trt_engine_op.cc",
],
deps = [
@@ -76,11 +75,9 @@ tf_cuda_library(
cc_library(
name = "trt_engine_op_kernel",
srcs = [
- "kernels/trt_calib_op.cc",
"kernels/trt_engine_op.cc",
],
hdrs = [
- "kernels/trt_calib_op.h",
"kernels/trt_engine_op.h",
],
copts = tf_copts(),
@@ -89,20 +86,22 @@ cc_library(
":trt_logging",
":trt_plugins",
":trt_resources",
+ ":trt_conversion",
+ ":utils",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:stream_executor_headers_lib",
+ "//tensorflow/core/grappler/costs:graph_properties",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
- # TODO(laigd)
+ # TODO(laigd): fix this by merging header file in cc file.
alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
)
tf_gen_op_libs(
op_lib_names = [
"trt_engine_op",
- "trt_calib_op",
],
)
@@ -122,7 +121,6 @@ tf_gen_op_wrapper_py(
name = "trt_engine_op",
gen_locally = True,
deps = [
- ":trt_calib_op_op_lib",
":trt_engine_op_op_lib",
":trt_logging",
":trt_shape_function",
@@ -140,7 +138,6 @@ tf_custom_op_py_library(
kernels = [
":trt_engine_op_kernel",
":trt_engine_op_op_lib",
- ":trt_calib_op_op_lib",
":trt_shape_function",
],
srcs_version = "PY2AND3",
@@ -191,7 +188,6 @@ tf_py_wrap_cc(
deps = [
":trt_conversion",
":trt_engine_op_kernel",
- "//tensorflow/core:framework_lite",
"//third_party/python_runtime:headers",
],
)
@@ -211,6 +207,7 @@ tf_cuda_library(
],
deps = [
":trt_logging",
+ ":utils",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_proto_parsing",
@@ -237,12 +234,12 @@ tf_cuda_library(
":trt_plugins",
":trt_logging",
":trt_resources",
+ ":utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core:framework",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:framework_lite",
"//tensorflow/core:graph",
@@ -343,3 +340,8 @@ py_test(
"//tensorflow/python:framework_test_lib",
],
)
+
+cc_library(
+ name = "utils",
+ hdrs = ["convert/utils.h"],
+)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index da4dd5a14c..189944f29b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+#include <fstream>
#include <list>
#include <map>
#include <set>
@@ -24,10 +24,17 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -39,17 +46,39 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT
+#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT
+#include "tensorflow/core/util/device_name_utils.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"
-
namespace tensorflow {
namespace tensorrt {
namespace convert {
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+// Returns compiled TRT version information {Maj, Min, Patch}
+std::vector<int> GetLinkedTensorRTVersion() {
+ return {NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH};
+}
+
+// Returns loaded TRT library version {Maj, Min, Patch}
+std::vector<int> GetLoadedTensorRTVersion() {
+ int ver = getInferLibVersion();
+ int ver_major = ver / 1000;
+ ver = ver - ver_major * 1000;
+ int ver_minor = ver / 100;
+ int ver_patch = ver - ver_minor * 100;
+ return {ver_major, ver_minor, ver_patch};
+}
+
namespace {
bool IsTensorRTCandidate(const tensorflow::Node* node) {
@@ -82,229 +111,6 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
}
-void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
- const std::set<int>& subgraph_node_ids,
- tensorflow::EdgeSet* incoming_edges) {
- for (int node_id : subgraph_node_ids) {
- const tensorflow::Node* node = graph.FindNodeId(node_id);
- for (const tensorflow::Edge* edge : node->in_edges()) {
- if (!subgraph_node_ids.count(edge->src()->id()) &&
- !edge->src()->IsSource() && !edge->IsControlEdge()) {
- incoming_edges->insert(edge);
- VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
- << " Y, ";
- } else {
- VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name()
- << " N, ";
- }
- }
- }
-}
-
-void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
- const std::set<int>& subgraph_node_ids,
- tensorflow::EdgeSet* outgoing_edges) {
- for (int node_id : subgraph_node_ids) {
- const tensorflow::Node* node = graph.FindNodeId(node_id);
- for (const tensorflow::Edge* edge : node->out_edges()) {
- if (!subgraph_node_ids.count(edge->dst()->id()) &&
- !edge->dst()->IsSink() && !edge->IsControlEdge()) {
- VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
- << " Y, ";
- outgoing_edges->insert(edge);
- } else {
- VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name()
- << " N, ";
- }
- }
- }
-}
-
-std::pair<string, int> ParseTensorName(const string& name,
- int default_idx = 0) {
- string name_no_idx = name;
- int idx = default_idx;
- const size_t sep = name_no_idx.find_last_of(':');
- if (sep != string::npos) {
- name_no_idx = name_no_idx.substr(0, sep);
- idx = std::stoi(name.substr(sep + 1));
- }
- return std::make_pair(name_no_idx, idx);
-}
-
-std::unordered_map<string, std::vector<int>> BuildTensorNameMap(
- const std::vector<string>& tensor_names) {
- std::unordered_map<string, std::vector<int>> result;
- for (const string& tensor_name : tensor_names) {
- string node_name;
- int index;
- std::tie(node_name, index) = ParseTensorName(tensor_name);
- result[node_name].push_back(index);
- }
- return result;
-}
-
-// TODO(sami): convert references to pointers
-struct ConvertGraphParams {
- ConvertGraphParams(
- tensorflow::Graph& inp_graph,
- const std::vector<string>& output_node_names,
- const std::set<int>& subgraph_node_id_numbers,
- size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& current_graph_properties,
- std::unordered_map<string, std::pair<int, string>>* output_edges,
- int engine_precision_mode, const string& device_name,
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator, int cuda_gpu_id)
- : graph(inp_graph),
- output_names(output_node_names),
- subgraph_node_ids(subgraph_node_id_numbers),
- max_batch_size(max_supported_batch_size),
- max_workspace_size_bytes(max_consumed_workspace_size_bytes),
- graph_properties(current_graph_properties),
- output_edge_map(output_edges),
- precision_mode(engine_precision_mode),
- device_name_(device_name),
- allocator_(allocator),
- cuda_gpu_id_(cuda_gpu_id) {}
- tensorflow::Graph& graph;
- const std::vector<string>& output_names;
- const std::set<int>& subgraph_node_ids;
- size_t max_batch_size;
- size_t max_workspace_size_bytes;
- const tensorflow::grappler::GraphProperties& graph_properties;
- std::unordered_map<string, std::pair<int, string>>* output_edge_map;
- int precision_mode;
- string device_name_;
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
- int cuda_gpu_id_;
- std::vector<std::pair<int, int>> subgraph_inputs;
- std::vector<std::pair<int, int>> subgraph_outputs;
- tensorflow::EdgeSet subgraph_incoming_edges;
- tensorflow::EdgeSet subgraph_outgoing_edges;
-};
-
-static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) {
- GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids,
- &p->subgraph_incoming_edges);
-
- std::set<std::pair<int, int>> unique_tensors;
- // Add only unique input source nodes. If output of an outside node is shared
- // between multiple nodes inside the engine, only one edge should be created
- for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) {
- unique_tensors.insert({edge->src()->id(), edge->src_output()});
- }
- p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(),
- unique_tensors.end());
- GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids,
- &p->subgraph_outgoing_edges);
- unique_tensors.clear();
- // Similar to above, if multiple ouside nodes are sharing the output of an
- // internal node only one output port should be created and shared between
- // outputs
- for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) {
- unique_tensors.insert({edge->src()->id(), edge->src_output()});
- }
- p->subgraph_outputs.reserve(unique_tensors.size());
- p->subgraph_outputs.insert(p->subgraph_outputs.begin(),
- unique_tensors.begin(), unique_tensors.end());
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status GetCalibNode(ConvertGraphParams* params) {
- TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
- tensorflow::NodeDef trt_node_def;
- SubGraphParams s(params->graph, params->subgraph_node_ids,
- params->subgraph_inputs, params->subgraph_outputs,
- params->max_batch_size, params->max_workspace_size_bytes,
- params->graph_properties, params->output_edge_map,
- &trt_node_def, params->precision_mode, params->device_name_,
- params->allocator_, params->cuda_gpu_id_);
- TF_RETURN_IF_ERROR(InjectCalibrationNode(s));
- tensorflow::Status status;
- tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
-
- TF_RETURN_IF_ERROR(status);
-
- for (auto in_edge :
- params->subgraph_incoming_edges) { // loop over incoming edges and
- // attach them to calib node
- auto src_output = in_edge->src_output();
- auto dst_node = in_edge->dst();
- auto dst_input = in_edge->dst_input();
- VLOG(1) << " update edge " << trt_node->name() << ":" << src_output
- << " -> " << dst_node->name() << ":" << dst_input;
- TF_RETURN_IF_ERROR(
- params->graph.UpdateEdge(trt_node, src_output, dst_node, dst_input));
- }
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) {
- TF_RETURN_IF_ERROR(FillSubGraphEdgeSets(params));
- tensorflow::NodeDef trt_node_def;
-
- SubGraphParams s(params->graph, params->subgraph_node_ids,
- params->subgraph_inputs, params->subgraph_outputs,
- params->max_batch_size, params->max_workspace_size_bytes,
- params->graph_properties, params->output_edge_map,
- &trt_node_def, params->precision_mode, params->device_name_,
- params->allocator_, params->cuda_gpu_id_);
- TF_RETURN_IF_ERROR(ConvertSubGraphToTensorRTNodeDef(s));
- tensorflow::Status status;
- tensorflow::Node* trt_node = params->graph.AddNode(trt_node_def, &status);
-
- // AddNode does not wire edges.
- // Re-map incoming edges to use the new TRT node instead of the orig subgraph
- std::map<std::pair<int, int>, int> subgraph_edge_to_input_map;
- for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) {
- subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i});
- }
- std::set<std::pair<int, int>> unique_tensors;
- for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) {
- std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
- if (unique_tensors.count(old_src)) continue;
- unique_tensors.insert(old_src);
- int new_src_output = subgraph_edge_to_input_map.at(old_src);
- params->graph.AddEdge(edge->src(), edge->src_output(), trt_node,
- new_src_output);
- VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output()
- << " -> " << trt_node->name() << ":" << new_src_output;
- params->graph.RemoveEdge(edge);
- }
- if (VLOG_IS_ON(2)) {
- VLOG(2) << "new edge count: " << trt_node->in_edges().size();
- for (const tensorflow::Edge* edge : trt_node->in_edges()) {
- VLOG(2) << edge->src()->name() << " port: " << edge->src_output();
- }
- }
- TF_RETURN_IF_ERROR(status);
-
- // Re-map outgoing edges to use the new TRT node instead of the orig subgraph
- std::map<std::pair<int, int>, int> subgraph_edge_to_output_map;
- for (size_t i = 0; i < params->subgraph_outputs.size(); ++i) {
- subgraph_edge_to_output_map.insert({params->subgraph_outputs.at(i), i});
- }
- TF_RETURN_IF_ERROR(status);
- for (const tensorflow::Edge* edge : params->subgraph_outgoing_edges) {
- std::pair<int, int> old_src = {edge->src()->id(), edge->src_output()};
- int new_src_output = subgraph_edge_to_output_map.at(old_src);
- TF_RETURN_IF_ERROR(params->graph.UpdateEdge(
- trt_node, new_src_output, edge->dst(), edge->dst_input()));
- VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> "
- << edge->dst()->name() << ":" << edge->dst_input();
- }
- // Remove the original subgraph
- for (int node_id : params->subgraph_node_ids) {
- tensorflow::Node* node = params->graph.FindNodeId(node_id);
- // Don't remove the input placeholders
- if (node->type_string() == "Placeholder") {
- continue;
- }
- params->graph.RemoveNode(node);
- }
- return tensorflow::Status::OK();
-}
-
tensorflow::Status BuildNodeMap(
const tensorflow::Graph& graph,
std::unordered_map<string, tensorflow::Node*>* node_map) {
@@ -318,51 +124,78 @@ tensorflow::Status BuildNodeMap(
}
} // namespace
+
+// Function to get calibration from ResourceMgr and put them into nodedef.
tensorflow::Status ConvertCalibGraphToInferGraph(
- const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph) {
+ const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph,
+ bool is_dyn_op) {
VLOG(0) << "Starting Calib Conversion";
- tensorflow::Graph graph(tensorflow::OpRegistry::Global());
- TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
- tensorflow::GraphConstructorOptions(), graph_def, &graph));
- // get calib nodes
- std::vector<tensorflow::Node*> calib_nodes;
- std::vector<tensorflow::Node*> topo_order;
- tensorflow::GetPostOrder(graph, &topo_order);
- for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
- auto node = *rit;
- if (node->type_string() == "TRTCalibOp") {
- VLOG(1) << "Found Calib Node " << node->name();
- calib_nodes.push_back(node);
- }
+ infer_graph->CopyFrom(graph_def);
+ auto trt_rm = TRTResourceManager::instance();
+ auto calib_rm = trt_rm->getManager("TRTCalibration");
+ int num_nodes = infer_graph->node_size();
+ if (!is_dyn_op) {
+ LOG(WARNING) << "Construction of static int8 engine is not implemented "
+ "yet!. Dynamic engine will be constructed";
}
- VLOG(0) << "Num Calib nodes in graph= " << calib_nodes.size();
- if (calib_nodes.size() == 0)
- return tensorflow::errors::FailedPrecondition(
- "Graph doesn't contain any calibration nodes!."
- " Please generate calibration graph and run calibration first");
- for (auto n : calib_nodes) {
- TF_RETURN_IF_ERROR(
- tensorrt::convert::ConvertCalibrationNodeToEngineNode(graph, n));
+ for (int i = 0; i < num_nodes; ++i) {
+ auto n = infer_graph->mutable_node(i);
+ if (n->op() == "TRTEngineOp") {
+ VLOG(1) << "Processing " << n->name();
+ const string& container_name = n->attr().at("segment_funcdef_name").s();
+ TRTCalibrationResource* cres = nullptr;
+ auto status = calib_rm->Lookup(container_name, "Calibrator", &cres);
+ if (!status.ok()) {
+ LOG(ERROR) << "Could not get Calibration information. Did you run with "
+ "calibration data?";
+ return tensorflow::errors::FailedPrecondition(
+ "Need to run graph with calibration data first!");
+ }
+ if (cres->calibrator_) {
+ cres->calibrator_->waitAndSetDone();
+ cres->thr_->join();
+ const auto& calibration_table =
+ cres->calibrator_->getCalibrationTableAsString();
+ if (!calibration_table.size()) {
+ LOG(ERROR) << "Calibration table is empty";
+ return tensorflow::errors::Unknown(
+ "Calibration table is missing. This shouldn't have happened!");
+ }
+ n->mutable_attr()->at("calibration_data").set_s(calibration_table);
+ } else {
+ LOG(ERROR) << "Can't get TRTCalibrator from resource manager!";
+ return tensorflow::errors::Unknown(
+ "Can't get TRTCalibrator from resource manager!");
+ }
+ cres->Unref();
+ calib_rm->Cleanup(container_name);
+ }
}
- graph.ToGraphDef(infer_graph);
return tensorflow::Status::OK();
}
+// Entry function from Python.
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
- int precision_mode = FP32MODE, int minimum_segment_size = 3) {
+ int precision_mode, int minimum_segment_size, bool is_dyn_op,
+ int max_cached_engines, std::vector<int> cached_engine_batches) {
// optimization pass
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
item.graph = graph_def;
-
+ // grappler requires a virtual cluster with a proper GPU device
+ // in order to calculate flops>0 or fails with FATAL
+ // We add numbers from a Pascal card here to have flops>0
tensorflow::DeviceProperties device_properties;
device_properties.set_type("GPU");
device_properties.mutable_environment()->insert({"architecture", "6"});
- tensorflow::grappler::Cluster* cluster =
- new tensorflow::grappler::VirtualCluster({{"/GPU:0", device_properties}});
+ device_properties.set_num_cores(3584);
+ device_properties.set_frequency(1531);
+ std::unique_ptr<tensorflow::grappler::Cluster> cluster(
+ new tensorflow::grappler::VirtualCluster(
+ {{"/GPU:0", device_properties}}));
// single machine
int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores();
@@ -370,134 +203,633 @@ tensorflow::Status ConvertGraphDefToTensorRT(
VLOG(2) << "cpu_cores: " << num_cpu_cores;
VLOG(2) << "gpus: " << num_gpus;
tensorflow::RewriterConfig rw_cfg;
+ // use only const folding and layout for the time being since new optimizers
+ // break the graph for us
+ rw_cfg.add_optimizers("constfold");
+ rw_cfg.add_optimizers("layout");
+ rw_cfg.set_meta_optimizer_iterations(tensorflow::RewriterConfig::ONE);
tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg);
tensorflow::GraphDef gdef;
- TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster, item, &gdef));
+ TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef));
item.graph = gdef;
// AJ refactoring shape inference through grappler/GraphProperties.
tensorflow::grappler::GraphProperties static_graph_properties(item);
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
// Build full graph
-
- return ConvertAfterShapes(gdef, output_names, max_batch_size,
- max_workspace_size_bytes, new_graph_def,
- precision_mode, minimum_segment_size,
- static_graph_properties, nullptr);
+ ConversionParams cp;
+ cp.input_graph_def = &gdef;
+ cp.output_names = &output_names;
+ cp.max_batch_size = max_batch_size;
+ cp.output_graph_def = new_graph_def;
+ cp.precision_mode = precision_mode;
+ cp.is_dyn_op = is_dyn_op;
+ cp.max_cached_engines = max_cached_engines;
+ cp.cached_engine_batches = cached_engine_batches;
+ cp.minimum_segment_size = minimum_segment_size;
+ cp.graph_properties = &static_graph_properties;
+ cp.max_workspace_size_bytes = max_workspace_size_bytes;
+ if (VLOG_IS_ON(5)) {
+ std::fstream f;
+ f.open("TRTConversionInput.pb",
+ std::fstream::out | std::fstream::binary | std::fstream::trunc);
+ f << gdef.SerializeAsString();
+ f.close();
+ }
+ return ConvertAfterShapes(cp);
}
-tensorflow::Status ConvertAfterShapes(
- const tensorflow::GraphDef& gdef, const std::vector<string>& output_names,
- size_t max_batch_size, size_t max_workspace_size_bytes,
- tensorflow::GraphDef* new_graph_def, int precision_mode,
- int minimum_segment_size,
+// Function to get subsegment information structure.
+tensorflow::Status GetEngineInfo(
+ const tensorflow::Graph* g,
const tensorflow::grappler::GraphProperties& graph_properties,
- const tensorflow::grappler::Cluster* cluster) {
- // Segment the graph into subgraphs that can be converted to TensorRT
- tensorflow::tensorrt::segment::SegmentOptions segment_options;
+ const std::set<string>& segment_nodes,
+ const std::unordered_map<string, tensorflow::Node*>& node_map,
+ const std::vector<tensorflow::Node*>& reverse_topo_order,
+ EngineInfo* info) {
+ std::vector<int> subgraph_node_ids;
+ std::set<string> segment_devices;
+ int input_port = 0;
+ int output_port = 0;
+
+ // Map from src_node_name+port to the unique port numbers of the TRT op, where
+ // the src_node_name is the name of the source node of the input/output
+ // edge, thus there must not be any duplicates since source nodes of
+ // input/output edges must be in different split of the graph.
+ // TODO(aaroey): consider using node id and port instead.
+ std::unordered_map<string, int> created_edges;
+ for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
+ ++it) {
+ const auto& node_name = (*it)->name();
+
+ if (segment_nodes.count(node_name) == 0) continue;
+ auto node = node_map.at(node_name);
+ auto node_device = node->requested_device();
+ if (!node_device.empty()) {
+ segment_devices.insert(node_device);
+ } else {
+ if (node->has_assigned_device_name()) {
+ segment_devices.insert(node->assigned_device_name());
+ } else {
+ VLOG(2) << "Node " << node->name()
+ << " neither have requested device nor assigned device";
+ }
+ }
+ int node_id = node->id();
+ subgraph_node_ids.push_back(node_id);
+ for (const auto edge : node->in_edges()) {
+ auto input_node = edge->src();
+ if (segment_nodes.count(input_node->name()) == 0) {
+ // Add constant input node into the segment. We don't care if it has
+ // other output edges going into other engines or TF nodes. Since we add
+ // it only to the subsegment node list, not the subsegment itself, it
+ // won't be removed from the graph. If it doesn't have any edges, TF
+ // will prune it out.
+ if (input_node->type_string() == "Const") {
+ subgraph_node_ids.push_back(input_node->id());
+ } else if (!edge->IsControlEdge() && !input_node->IsSource()) {
+ string s(input_node->name());
+ StrAppend(&s, ":", edge->src_output());
+ VLOG(1) << "Input edge = " << s;
+ int port = input_port;
+ if (created_edges.count(s)) {
+ port = created_edges.at(s);
+ } else {
+ created_edges.insert({s, port});
+ input_port++;
+ }
+ info->connections.emplace_back(input_node->name(), input_node->id(),
+ edge->src_output(), node_name, node_id,
+ edge->dst_input(), true, port);
+ }
+ }
+ }
+ for (const auto edge : node->out_edges()) {
+ auto output_node = edge->dst();
+ if (segment_nodes.count(output_node->name()) == 0 &&
+ !edge->IsControlEdge() && !output_node->IsSink()) {
+ string s(node_name);
+ StrAppend(&s, ":", edge->src_output());
+ VLOG(1) << "Output edge = " << s;
+ int port = output_port;
+ if (created_edges.count(s)) {
+ port = created_edges.at(s);
+ } else {
+ created_edges.insert({s, port});
+ output_port++;
+ }
+ info->connections.emplace_back(output_node->name(), output_node->id(),
+ edge->dst_input(), node_name, node_id,
+ edge->src_output(), false, port);
+ }
+ }
+ }
+
+ TF_RETURN_IF_ERROR(ConvertSegmentToGraphDef(
+ g, graph_properties, subgraph_node_ids, &info->connections,
+ &info->segment_graph_def, &info->engine_name));
+ // TODO(sami): This should not happen once segmenter is updated.
+ if (segment_devices.size() == 1) {
+ info->device = *segment_devices.begin();
+ } else if (segment_devices.size() > 1) {
+ LOG(WARNING) << "Detected multiple(" << segment_devices.size()
+ << ") devices for the segment. Picking first one to continue "
+ << "but this shouldn't have happened";
+ info->device = *segment_devices.begin();
+ } else {
+ VLOG(1) << "Segment devices size is 0";
+ }
+ return Status::OK();
+}
+
+// Function to insert a TRT node into the graph. The graph is not modified if
+// the returned status is not ok.
+// 'alloc' is only used for creating static engine.
+tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
+ const std::vector<EngineInfo>& infos, int pos,
+ nvinfer1::IGpuAllocator* alloc,
+ int max_batch_size) {
+ const auto& info = infos.at(pos);
+ std::vector<tensorflow::TensorShapeProto> out_shapes;
+ std::vector<tensorflow::TensorShapeProto> input_shapes;
+ std::vector<tensorflow::PartialTensorShape> shapes;
+ std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
+ std::vector<tensorflow::DataType> out_types;
+ VLOG(1) << "Processing " << info.engine_name;
+
+ // Update the shape and data types of input/output nodes, and find all unique
+ // inputs.
+ for (const auto& conn : info.connections) {
+ if (!conn.is_input_edge) {
+ // Set the shapes and data types of output edge.
+ tensorflow::TensorShapeProto out_shape;
+ // shape of the output node inside segment
+ conn.inside_shape.AsProto(&out_shape);
+ if (out_shapes.size() <= conn.port_number) {
+ out_shapes.resize(conn.port_number + 1);
+ out_types.resize(conn.port_number + 1);
+ }
+ out_shapes.at(conn.port_number) = out_shape;
+ out_types.at(conn.port_number) = conn.connection_type;
+ continue;
+ }
+
+ // Set the shapes and data types of input edge.
+ tensorflow::TensorShapeProto in_shape;
+ conn.outside_shape.AsProto(&in_shape);
+ if (input_shapes.size() <= conn.port_number) {
+ input_shapes.resize(conn.port_number + 1);
+ shapes.resize(conn.port_number + 1);
+ }
+ input_shapes.at(conn.port_number) = in_shape;
+ shapes.at(conn.port_number) = conn.outside_shape;
+
+ string input_node = conn.outside_node_name;
+ int input_port = conn.outside_port;
+ bool found_engine = false;
+ // Rewire the inputs to other engines if they contain original input node.
+ // Note that we use the information of the engine here, not the information
+ // of the created TRT nodes, so we're able to find all the connections to
+ // any other engines beforehand.
+ for (size_t t = 0; t < infos.size(); ++t) {
+ if (t == pos) continue;
+ auto& engine_info = infos.at(t);
+ for (const auto& eng_conn : engine_info.connections) {
+ if (eng_conn.is_input_edge) continue;
+ if (eng_conn.inside_node_name == input_node) {
+ input_node = engine_info.engine_name;
+ if (eng_conn.inside_port == input_port) {
+ input_port = eng_conn.port_number;
+ found_engine = true;
+ break;
+ }
+ }
+ }
+ if (found_engine) break;
+ }
+ VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
+ << info.engine_name << ":" << inputs.size();
+ // Skip duplicate inputs.
+ bool new_input = true;
+ for (const auto& inp : inputs) {
+ if (inp.node == input_node && inp.index == input_port) {
+ new_input = false;
+ break;
+ }
+ }
+ if (new_input) {
+ inputs.emplace_back(input_node, input_port, conn.connection_type);
+ }
+ }
+
+ // Build the engine and get its serialized representation.
+ string segment_string;
+ if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
+ info.precision_mode == INT8MODE) {
+ // Create static engine for fp32/fp16 mode, and test validity of the engine
+ // for int8 mode. We don't want engine to fail at the calibration time.
+ // So we are constructing a FP32 engine here to check its validity, and if
+ // it is a valid engine then we put the serialized graphdef to the op.
+ // Otherwise we skip node creation for this engine.
+ Logger trt_logger;
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
+ // TODO(sami): What happens if 1st dim is not batch?
+ TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
+ info.segment_graph_def,
+ info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
+ max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
+ alloc, /*calibrator=*/nullptr, &engine,
+ /*convert_successfully=*/nullptr));
+ TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
+ segment_string =
+ string((const char*)engine_data->data(), engine_data->size());
+ if (info.precision_mode == INT8MODE) {
+ // See above comment about why not putting this inside the 'else' branch.
+ segment_string = info.segment_graph_def.SerializeAsString();
+ }
+ } else {
+ segment_string = info.segment_graph_def.SerializeAsString();
+ }
+
+ // TODO(aaroey): use enum instead, and add a helper method to do the
+ // conversion.
+ string prec_string;
+ switch (info.precision_mode) {
+ case FP32MODE:
+ prec_string = "FP32";
+ break;
+ case FP16MODE:
+ prec_string = "FP16";
+ break;
+ case INT8MODE:
+ prec_string = "INT8";
+ if (!TRTResourceManager::instance()->getManager("TRTCalibration")) {
+ LOG(ERROR) << "Failed to construct calibration storage";
+ }
+ break;
+ default:
+ return tensorflow::errors::OutOfRange("Unknown precision mode");
+ }
+ tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
+ if (!info.device.empty()) node_builder.Device(info.device);
+ if (VLOG_IS_ON(1)) {
+ string ins = StrCat(info.engine_name, " inputs= ");
+ for (const auto& ii : inputs) {
+ StrAppend(&ins, ii.node, ":", ii.index, " ");
+ }
+ VLOG(1) << ins;
+ }
+ node_builder.Input(inputs);
+ if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
+ info.cached_engine_batches.size()) {
+ LOG(WARNING) << "Cached engine batches are ignored for static engines";
+ }
+ tensorflow::NodeDef trt_node;
+ tensorflow::Status status =
+ node_builder.Attr("input_shapes", input_shapes)
+ .Attr("output_shapes", out_shapes)
+ .Attr("static_engine",
+ info.engine_type == EngineInfo::EngineType::TRTStatic)
+ .Attr("segment_funcdef_name",
+ StrCat(info.engine_name, "_native_segment"))
+ .Attr("serialized_segment", segment_string)
+ .Attr("calibration_data", "")
+ .Attr("max_cached_engines_count", info.maximum_cached_engines)
+ .Attr("cached_engine_batches", {max_batch_size})
+ .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
+ .Attr("precision_mode", prec_string)
+ .Attr("OutT", out_types)
+ .Finalize(&trt_node);
+ if (!status.ok()) {
+ LOG(ERROR) << "Node construction failed with" << status;
+ return status;
+ }
+ VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";
+
+ // Up until this point, graph is not modified. If we return !status.ok() from
+ // here, this segment will be skipped
+ tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
+ if (!status.ok()) {
+ LOG(ERROR) << "Adding node failed " << status;
+ return status;
+ }
+ // Updates the inputs of output edges destination nodes, and point them to the
+ // engine node.
+ for (auto& conn : info.connections) {
+ if (conn.is_input_edge) continue;
+ VLOG(1) << " Updating DBG " << engine_node->name() << " out_port "
+ << conn.port_number << " out_id " << conn.outside_id
+ << " name=" << conn.outside_node_name;
+ auto dst_node = graph->FindNodeId(conn.outside_id);
+ // dst_node can only be removed if it is an input node of another engine.
+ // In this case, other engines input edge is updated in nodedef to point to
+ // this engine. Even though edge doesn't exists in the graph, when it is
+ // deserialized again, correct edges will be constructed. This is a problem
+ // of graph->AddNode().
+ if (!dst_node) continue;
+ VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
+ << " to " << dst_node->name() << ":" << conn.outside_port;
+ auto new_edge = graph->AddEdge(engine_node, conn.port_number, dst_node,
+ conn.outside_port);
+ CHECK(new_edge) << "Adding a new edge failed " << engine_node->name() << ":"
+ << conn.port_number << " -> " << dst_node->name() << ":"
+ << conn.outside_port;
+ }
+ return status;
+}
+
+// Function to construct a funcdef from the segment and add it to the graph.
+tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
+ tensorflow::Graph* graph, const tensorflow::GraphDef& segment,
+ const string& name) {
+ tensorflow::Graph sgraph(graph->flib_def());
+ tensorflow::GraphConstructorOptions gcopts;
+ TF_RETURN_IF_ERROR(
+ tensorflow::ConvertGraphDefToGraph(gcopts, segment, &sgraph));
+ std::map<string, tensorflow::Node*> io_nodes;
+ int num_inputs = 0;
+ for (auto n : sgraph.op_nodes()) {
+ if (tensorflow::str_util::StartsWith(n->name(), kInputPHName)) {
+ num_inputs++;
+ io_nodes.insert({n->name(), n});
+ } else if (tensorflow::str_util::StartsWith(n->name(), kOutputPHName)) {
+ io_nodes.insert({n->name(), n});
+ }
+ }
+
+ for (int i = 0; i < num_inputs; ++i) {
+ auto name = StrCat(kInputPHName, i);
+ auto node = io_nodes[name];
+ tensorflow::NodeDef nd;
+ tensorflow::NodeDefBuilder node_builder(
+ StrCat(name, "_Arg"), tensorflow::FunctionLibraryDefinition::kArgOp);
+ VLOG(1) << "Adding " << StrCat(name, "_Arg");
+ TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
+ .Attr("index", i)
+ .Finalize(&nd));
+ tensorflow::Status s;
+ auto node_arg = sgraph.AddNode(nd, &s);
+ if (!s.ok()) {
+ LOG(ERROR) << "Couldn't add _Arg node for " << name;
+ }
+ for (auto edge : node->out_edges()) {
+ sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input());
+ VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0
+ << " - > " << edge->dst()->name() << ":" << edge->dst_input();
+ if (!s.ok()) {
+ LOG(ERROR) << "Failed to update edge from " << node_arg->name()
+ << " to " << edge->dst()->name() << ":" << edge->dst_input();
+ }
+ }
+ sgraph.RemoveNode(node);
+ }
+
+ for (int i = 0; i < io_nodes.size() - num_inputs; ++i) {
+ auto name = StrCat(kOutputPHName, i);
+ auto node = io_nodes[name];
+ tensorflow::NodeDef nd;
+ tensorflow::NodeDefBuilder node_builder(
+ StrCat(name, "_Ret"), tensorflow::FunctionLibraryDefinition::kRetOp);
+ auto edge = *(node->in_edges().begin());
+ tensorflow::NodeDefBuilder::NodeOut nout(
+ edge->src()->name(), edge->src_output(),
+ edge->src()->output_type(edge->src_output()));
+ VLOG(1) << " input " << nout.node << ":" << nout.index
+ << " dtype=" << tensorflow::DataTypeString(nout.data_type);
+ node_builder.Input({nout});
+ TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0))
+ .Attr("index", i)
+ .Finalize(&nd));
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << nd.DebugString();
+ }
+ tensorflow::Status s;
+ auto node_ret = sgraph.AddNode(nd, &s);
+ if (!s.ok()) {
+ LOG(ERROR) << "Couldn't add _Ret node for " << name;
+ }
+ VLOG(1) << "Update edge from " << edge->src()->name() << ":"
+ << edge->src_output() << " - > " << node_ret->name() << ":" << 0;
+ sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0);
+ s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0);
+ if (!s.ok()) {
+ LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":"
+ << edge->src_output() << " - > " << node_ret->name() << ":"
+ << 0;
+ }
+ sgraph.RemoveNode(node);
+ }
+ tensorflow::FunctionDefLibrary fdeflib;
+ auto native_segment = fdeflib.add_function();
+ TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
+ sgraph, StrCat(name, "_native_segment"), native_segment));
+ if (VLOG_IS_ON(7)) {
+ VLOG(7) << name << " Function_Def ";
+ VLOG(7) << native_segment->DebugString();
+ }
+ VLOG(1) << "Adding funcdef to graphlib";
+ TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib));
+ return tensorflow::Status::OK();
+}
+
+std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
+ ConversionParams& params, EngineInfo& engine) {
+ int cuda_device_id = -1;
+ auto check_device_id = [](int tfid) -> int {
+ tensorflow::TfGpuId tf_gpu_id(tfid);
+ CudaGpuId cuda_gpu_id;
+ Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ if (s.ok()) {
+ VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
+ << cuda_gpu_id.value();
+ return cuda_gpu_id.value();
+ }
+ VLOG(2) << "TF GPU with id " << tfid << " do not exist " << s;
+ return -1;
+ };
+ tensorflow::Allocator* dev_allocator = nullptr;
+ // we need to us PM here since in python path there is no way to get
+ // to allocators.
+ // TODO(sami): when grappler devices become available else path will not be
+ // necessary
+ auto pm = tensorflow::GPUProcessState::singleton();
+ if (params.cluster) { // get allocator
+ tensorflow::Device* device = nullptr;
+ if (params.cluster->GetDeviceSet()) {
+ device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device);
+ }
+ if (device) {
+ tensorflow::AllocatorAttributes alloc_attr;
+ dev_allocator = device->GetAllocator(alloc_attr);
+ VLOG(1) << "Using allocator " << dev_allocator->Name();
+ } else {
+ LOG(WARNING) << "Cluster is set but device '" << engine.device
+ << "' is not found in the cluster";
+ }
+ } else { // cluster not found, possibly a python call
+ VLOG(1) << "Cluster is not set, probably called from python";
+ int found_device = 0;
+ bool try_gpu_ids = true;
+ // if device is set, try to find the device. Might be a problem for multi
+ // host case but TensorRT do not support multi host setups yet.
+ if (!engine.device.empty()) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) {
+ cuda_device_id = parsed_name.has_id ? parsed_name.id : -1;
+ }
+ try_gpu_ids = !parsed_name.has_id;
+ }
+ if (try_gpu_ids) {
+ while (found_device < 100) {
+ cuda_device_id = check_device_id(found_device);
+ if (cuda_device_id >= 0) break;
+ found_device++;
+ }
+ }
+ if (found_device == 100) {
+ LOG(ERROR) << " Can't find a GPU device to work with. Please "
+ "instantiate a session to initialize devices";
+ return std::make_pair(cuda_device_id, dev_allocator);
+ }
+ LOG(WARNING)
+ << "Can't determine the device, constructing an allocator at device "
+ << found_device;
+ tensorflow::GPUOptions gpuoptions;
+ // this will be a noop if device is already initialized
+ gpuoptions.set_allow_growth(true);
+ tensorflow::TfGpuId tf_gpu_id(found_device);
+ dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
+ }
+ return std::make_pair(cuda_device_id, dev_allocator);
+}
+
+// Entry function from optimization pass.
+tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
+ // Convert graphdef to graph.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
- gdef.library());
+ params.input_graph_def->library());
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
- tensorflow::GraphConstructorOptions(), gdef, &graph));
+ tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph));
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorflow::tensorrt::segment::SegmentOptions segment_options;
// TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
- for (auto node : output_names) {
+ for (auto node : *(params.output_names)) {
segment_options.exclude_node_list.insert(node);
}
-
- // TODO(sami): this should be passed as a knob!!!!
- segment_options.minimum_segment_size = minimum_segment_size;
- tensorflow::tensorrt::segment::SegmentNodesVector segments;
+ segment_options.minimum_segment_size = params.minimum_segment_size;
+ tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
- &graph, IsTensorRTCandidate, segment_options, &segments));
- if (segments.size() > 1) {
- VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
+ &graph, IsTensorRTCandidate, segment_options, &initial_segments));
+ if (initial_segments.size() > 1) {
+ VLOG(0) << "MULTIPLE tensorrt candidate conversion: "
+ << initial_segments.size();
}
+
+ // Get the EngineInfo for each segment.
std::unordered_map<string, tensorflow::Node*> node_map;
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
- std::unordered_map<string, std::pair<int, string>> output_edge_map;
- int count = 0;
float total_num_nodes_in_segments = 0.;
- for (auto s : segments) {
- total_num_nodes_in_segments += s.first.size();
- }
- // We create the map here since cluster may not be available in all cases.
- std::map<string, tensorflow::Device*> name_to_device_map;
- if (cluster) {
- // TODO(aaroey): consider using DeviceSet::FindDeviceByName(), as in a
- // distributed environment, devices from different workers can have same
- // short name.
- for (const auto dm : cluster->GetDeviceSet()->devices()) {
- name_to_device_map[dm->name()] = dm;
+ std::vector<EngineInfo> engine_segments;
+ engine_segments.reserve(initial_segments.size());
+ std::vector<tensorflow::Node*> reverse_topo_order;
+ tensorflow::GetPostOrder(graph, &reverse_topo_order);
+ size_t total_engine_bytes_size = 0;
+ std::vector<size_t> engine_bytes_size;
+ tensorflow::tensorrt::segment::SegmentNodesVector converted_segments;
+ converted_segments.reserve(initial_segments.size());
+ for (size_t t = 0; t < initial_segments.size(); t++) {
+ auto& curr_segment = initial_segments.at(t);
+ EngineInfo curr_engine;
+ Status status =
+ GetEngineInfo(&graph, *params.graph_properties, curr_segment.first,
+ node_map, reverse_topo_order, &curr_engine);
+ if (!status.ok()) {
+ LOG(WARNING) << "Failed to get engine info for segment " << t << ": "
+ << status;
+ continue;
}
- }
- for (const auto& segment_nodes_and_device : segments) {
- const std::set<string>& subgraph_node_names =
- segment_nodes_and_device.first;
- std::set<int> subgraph_node_ids;
- size_t max_mem_per_engine =
- max_workspace_size_bytes *
- ((float)subgraph_node_names.size() / total_num_nodes_in_segments);
- std::stringstream oss;
- for (const string& node_name : subgraph_node_names) {
- oss << " " << node_name;
- subgraph_node_ids.insert(node_map.at(node_name)->id());
+ curr_engine.precision_mode = params.precision_mode;
+ curr_engine.engine_type =
+ (params.is_dyn_op || params.precision_mode == INT8MODE
+ ? EngineInfo::EngineType::TRTDynamic
+ : EngineInfo::EngineType::TRTStatic);
+ curr_engine.cached_engine_batches = params.cached_engine_batches;
+ curr_engine.maximum_cached_engines = params.max_cached_engines;
+ StrAppend(&curr_engine.engine_name, "my_trt_op_", t);
+ status = RegisterSegmentFunctionToFunctionLibrary(
+ &graph, curr_engine.segment_graph_def, curr_engine.engine_name);
+ if (!status.ok()) {
+ LOG(WARNING) << "Failed to register segment graphdef as a function " << t
+ << ": " << status;
+ continue;
}
- VLOG(1) << "Subgraph nodes at device " << segment_nodes_and_device.second
- << " : " << oss.str();
- auto target_device =
- name_to_device_map.find(segment_nodes_and_device.second);
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator(0);
+ engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
+ total_engine_bytes_size += engine_bytes_size.back();
+ total_num_nodes_in_segments += curr_segment.first.size();
+ engine_segments.push_back(std::move(curr_engine));
+ converted_segments.push_back(std::move(curr_segment));
+
+ if (VLOG_IS_ON(8)) {
+ string fname = curr_engine.engine_name;
+ StrAppend(&fname, ".pb");
+ std::fstream f;
+ f.open(fname.c_str(), std::fstream::out | std::fstream::binary);
+ f << engine_segments.at(t).segment_graph_def.SerializeAsString();
+ f.close();
+ }
+ }
+
+ // Create a TRT node for each segment using its EngineInfo.
+ int old_cuda_device = 0;
+ auto err = cudaGetDevice(&old_cuda_device);
+ if (err != cudaSuccess) {
+ LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
+ }
+ VLOG(1) << "Current cuda device is " << old_cuda_device;
+ for (int i = 0; i < engine_segments.size(); ++i) {
+ auto& engine = engine_segments.at(i);
+ // Partition the workspace size by the average of node ratio and segment
+ // graphdef size
+ engine.max_workspace_size_bytes =
+ params.max_workspace_size_bytes *
+ (engine_bytes_size.at(i) / total_engine_bytes_size +
+ converted_segments.at(i).first.size() / total_num_nodes_in_segments) /
+ 2.0;
+ // The allocator is used to build the engine. The build and the built engine
+ // will be destroyed after we get the serialized engine string, so it's fine
+ // to use unique_ptr here.
+ std::unique_ptr<nvinfer1::IGpuAllocator> alloc;
+ auto device_alloc = GetDeviceAndAllocator(params, engine);
int cuda_device_id = 0;
- if (target_device != name_to_device_map.end()) {
- tensorflow::TfGpuId tf_gpu_id(target_device->second->parsed_name().id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
- if (!s.ok()) {
- LOG(ERROR)
- << "Cuda device identification failed, using device 0. Error= "
- << s;
- } else {
- cuda_device_id = cuda_gpu_id.value();
- }
- tensorflow::GPUOptions gpuoptions;
- // we need to us PM here since in python path there is no way to get to
- // allocators
- auto pm = tensorflow::ProcessState::singleton();
- // this should be instantiated by now
- auto dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
- VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value()
- << " cuda device= " << cuda_device_id << " at " << dev_allocator;
- allocator = std::make_shared<TRTDeviceAllocator>(dev_allocator);
- } else { // device unknown or not available
- allocator = std::make_shared<TRTCudaAllocator>();
+ if (device_alloc.first >= 0) {
+ cuda_device_id = device_alloc.first;
+ alloc.reset(new TRTDeviceAllocator(device_alloc.second));
+ } else {
+ // Setting allocator as nullptr should get revert to the cudamalloc
+ LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
}
- ConvertGraphParams p(graph, output_names, subgraph_node_ids, max_batch_size,
- max_mem_per_engine, graph_properties, &output_edge_map,
- precision_mode, segment_nodes_and_device.second,
- allocator, cuda_device_id);
- if (precision_mode == INT8MODE) {
- tensorflow::Status status = GetCalibNode(&p);
- if (status != tensorflow::Status::OK()) {
- LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
- << " due to: \"" << status.ToString()
- << "\" SKIPPING......( " << subgraph_node_names.size()
- << " nodes)";
+ cudaSetDevice(cuda_device_id);
+ auto status = CreateTRTNode(&graph, engine_segments, i, alloc.get(),
+ params.max_batch_size);
+ // If status is ok, we successfully added the node to the graph and can
+ // remove segment ops. Otherwise graph is not modified.
+ if (status.ok()) {
+ for (auto node_name : converted_segments.at(i).first) {
+ graph.RemoveNode(node_map.at(node_name));
}
} else {
- tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
- if (status != tensorflow::Status::OK()) {
- LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
- << " due to: \"" << status.ToString()
- << "\" SKIPPING......( " << subgraph_node_names.size()
- << " nodes)";
- }
+ // Graph is not modified.
+ LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
+ << converted_segments.at(i).first.size()
+ << " nodes failed: " << status << ". Skipping...";
}
- count++;
}
- graph.ToGraphDef(new_graph_def);
+ cudaSetDevice(old_cuda_device);
+ graph.ToGraphDef(params.output_graph_def);
+ VLOG(1) << "Returning from conversion";
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 65a67d7e73..9d986e4890 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -30,29 +30,60 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
-// This method converts an already generated calibration graph which was used in
-// calibration runs to an inference graph
+struct ConversionParams {
+ ConversionParams()
+ : input_graph_def(nullptr),
+ max_batch_size(1),
+ max_workspace_size_bytes(1 << 30),
+ output_graph_def(nullptr),
+ precision_mode(1),
+ minimum_segment_size(3),
+ graph_properties(nullptr),
+ cluster(nullptr),
+ is_dyn_op(false),
+ fixed_input_size(true),
+ max_cached_engines(1) {}
+ const tensorflow::GraphDef* input_graph_def;
+ const std::vector<string>* output_names;
+ size_t max_batch_size;
+ size_t max_workspace_size_bytes;
+ tensorflow::GraphDef* output_graph_def;
+ int precision_mode;
+ int minimum_segment_size;
+ const tensorflow::grappler::GraphProperties* graph_properties;
+ const tensorflow::grappler::Cluster* cluster;
+ bool is_dyn_op; // Whether to create engine on conversion or execution time
+ bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed
+ int max_cached_engines; // maximum number of cached engines
+ std::vector<int> cached_engine_batches; // list of cached engines
+};
+
+// This method extracts calibration information from the resource managers
+// and puts them in to engine nodedefs.
tensorflow::Status ConvertCalibGraphToInferGraph(
- const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def);
+ const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
+ bool is_dyn_op);
-// max_batch_size: maximum batch size which can be used for inference for
-// optimization targets inference run with max batch size.
-// max_workspace_size_bytes: The upper bound of memory allowance for
-// engine building.
+// - max_batch_size: maximum batch size which can be used for inference for
+// optimization targets inference run with max batch size.
+// - max_workspace_size_bytes: The upper bound of memory allowance for engine
+// building.
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
- int precision_mode, int minimum_segment_size);
+ int precision_mode = 1, int minimum_segment_size = 3,
+ bool is_dyn_op = false, int max_cached_engines = 1,
+ std::vector<int> cached_engine_batches = {});
// Method to call from optimization pass
-tensorflow::Status ConvertAfterShapes(
- const tensorflow::GraphDef& graph, const std::vector<string>& output_names,
- size_t max_batch_size, size_t max_workspace_size_bytes,
- tensorflow::GraphDef* new_graph_def, int precision_mode,
- int minimum_segment_size,
- const tensorflow::grappler::GraphProperties& graph_properties,
- const tensorflow::grappler::Cluster* cluster);
+tensorflow::Status ConvertAfterShapes(ConversionParams& params);
+
+// Return compile time TensorRT library version information.
+std::vector<int> GetLinkedTensorRTVersion();
+
+// Return runtime time TensorRT library version information.
+std::vector<int> GetLoadedTensorRTVersion();
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 4e4d295538..146b9c7344 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <algorithm>
#include <list>
@@ -25,7 +24,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
@@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -54,8 +56,11 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
namespace convert {
+using ::tensorflow::str_util::Split;
+
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
+
namespace {
inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
@@ -121,12 +126,10 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
size_t last_scope_separator = 0;
- for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) {
- if (op_name_a[i] != op_name_b[i]) {
- break;
- } else if (op_name_a[i] == '/') {
- last_scope_separator = i + 1;
- }
+ const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
+ for (size_t i = 0; i < min_size; ++i) {
+ if (op_name_a[i] != op_name_b[i]) break;
+ if (op_name_a[i] == '/') last_scope_separator = i + 1;
}
return op_name_a.substr(0, last_scope_separator);
}
@@ -417,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
}
}
-struct InferDeleter {
- template <typename T>
- void operator()(T* obj) const {
- if (obj) {
- obj->destroy();
- }
- }
-};
-
-template <typename T>
-inline std::shared_ptr<T> infer_object(T* obj) {
- return std::shared_ptr<T>(obj, InferDeleter());
-}
-
class Converter;
using OpConverter =
@@ -444,7 +433,7 @@ class Converter {
OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
- tensorflow::tensorrt::TRTWeightStore* weight_store_;
+ TRTWeightStore* weight_store_;
bool fp16_;
void register_op_converters();
tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
@@ -486,11 +475,11 @@ class Converter {
public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
- tensorflow::tensorrt::TRTWeightStore* ws, bool fp16)
+ TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters();
}
- tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; }
+ TRTWeightStore* weight_store() { return weight_store_; }
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape);
@@ -2140,559 +2129,265 @@ void Converter::register_op_converters() {
} // namespace
-tensorflow::Status ConvertCalibrationNodeToEngineNode(
- tensorflow::Graph& graph, tensorflow::Node* c_node) {
- const auto ndef = c_node->def();
-
- TFAttrs attrs(ndef);
- std::vector<string> segment_nodes(
- attrs.get<std::vector<string>>("segment_nodes"));
- std::vector<string> output_nodes(
- attrs.get<std::vector<string>>("segment_output_names"));
- std::vector<string> input_names(
- attrs.get<std::vector<string>>("input_names"));
- string res_name = attrs.get<string>("resource_name");
- VLOG(1) << "Node name " << c_node->name() << " res_name " << res_name;
- string engine_name = "my_trt_op";
- {
- const auto node_id = tensorflow::str_util::Split(res_name, "_");
- engine_name += node_id.back();
- }
- std::map<string, tensorflow::Node*> node_maps;
-
- for (auto n : graph.op_nodes()) {
- node_maps.insert({n->name(), n});
- }
- std::set<int> subgraph_ids;
- for (const auto internal_node : segment_nodes) {
- subgraph_ids.insert(node_maps.at(internal_node)->id());
- }
- if (VLOG_IS_ON(2)) {
- string node_names = StrCat(c_node->name(), " segment nodes= ");
-
- for (const auto& node_name : segment_nodes) {
- StrAppend(&node_names, node_name, ", ");
- }
- VLOG(2) << node_names;
+tensorflow::Status ConvertGraphDefToEngine(
+ const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
+ size_t max_workspace_size_bytes,
+ const std::vector<tensorflow::PartialTensorShape>& input_shapes,
+ Logger* logger, nvinfer1::IGpuAllocator* allocator,
+ TRTInt8Calibrator* calibrator,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ bool* convert_successfully) {
+ engine->reset();
+ if (convert_successfully) *convert_successfully = false;
+
+ // Create the builder.
+ TrtUniquePtrType<nvinfer1::IBuilder> builder(
+ nvinfer1::createInferBuilder(*logger));
+ builder->setMaxBatchSize(max_batch_size);
+ // TODO(aaroey): use the allocator to allocate the TRT workspace.
+ builder->setMaxWorkspaceSize(max_workspace_size_bytes);
+#if NV_TENSORRT_MAJOR > 3
+ builder->setGpuAllocator(allocator);
+#endif
+ if (precision_mode == FP16MODE) {
+ builder->setHalf2Mode(true);
+ } else if (precision_mode == INT8MODE) {
+ builder->setInt8Mode(true);
+ builder->setInt8Calibrator(calibrator);
}
- VLOG(1) << "Output Nodes:";
- std::vector<tensorflow::DataType> out_types;
- std::vector<const tensorflow::Edge*> out_edges;
+ // Create the network.
+ auto trt_network =
+ TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
+ if (!trt_network) {
+ return tensorflow::errors::Internal(
+ "Failed to create TensorRT network object");
+ }
+ auto ws = std::unique_ptr<TRTWeightStore>(new TRTWeightStore());
- for (auto& i : output_nodes) {
- auto node_port = tensorflow::str_util::Split(i, ":");
- VLOG(1) << " " << i << " in graph " << node_maps.count(i);
- auto out_node_name = node_port.at(0);
- if (node_port.size() > 1) {
- VLOG(1) << "Multi port output" << node_port.at(0) << " "
- << node_port.at(1) << " size=" << node_port.size();
- }
- auto node_it = node_maps.find(out_node_name);
- if (node_it != node_maps.end()) {
- tensorflow::Node* out_node = node_it->second;
- int port = 0;
- if (node_port.size() == 2) {
- port = std::strtoul(node_port.at(1).c_str(), nullptr, 10);
- out_types.push_back(out_node->output_type(port));
- } else {
- out_types.push_back(out_node->output_type(0));
+ // Build the network
+ VLOG(1) << "Starting engine conversion ";
+ Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE);
+ std::vector<std::pair<string, string>> output_tensors;
+ // Graph nodes are already topologically sorted during construction
+ for (const auto& node_def : gdef.node()) {
+ string node_name = node_def.name();
+ VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
+ if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
+ (node_def.op() == "Placeholder")) {
+ nvinfer1::DimsCHW input_dim_pseudo_chw;
+ for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
+ nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
+ auto type_status =
+ ConvertDType(node_def.attr().at("dtype").type(), &dtype);
+ if (type_status != tensorflow::Status::OK()) {
+ LOG(WARNING) << "Type conversion failed for " << node_name;
+ return type_status;
}
- for (auto out_edge : out_node->out_edges()) {
- if (subgraph_ids.count(out_edge->dst()->id()))
- continue; // skip internal edges;
- if (out_edge->src_output() == port) {
- out_edges.push_back(out_edge);
- VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":"
- << out_edge->src_output() << " -> " << out_edge->dst()->name()
- << ":" << out_edge->dst_input();
+ int32 slot_number = -1;
+ if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
+ &slot_number)) {
+ LOG(ERROR) << "Failed to parse slot number from " << node_name
+ << " +8= " << node_name.c_str() + 8;
+ }
+ auto shape = input_shapes.at(slot_number);
+ if (shape.dims() > 8) {
+ LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
+ << " at input slot " << slot_number;
+ return tensorflow::errors::OutOfRange(
+ "Input tensor rank is greater than 8");
+ }
+ if (VLOG_IS_ON(1)) {
+ string dim_str("dims=");
+ StrAppend(&dim_str, "[ ", shape.dim_size(0));
+ for (int i = 1; i < shape.dims(); i++) {
+ StrAppend(&dim_str, ", ", shape.dim_size(i));
}
+ StrAppend(&dim_str, " ]");
+ VLOG(1) << dim_str;
+ }
+ for (int i = 1; i < shape.dims(); i++) {
+ input_dim_pseudo_chw.d[i - 1] = shape.dim_size(i);
}
- } else {
- LOG(WARNING) << " couldn't find output node " << out_node_name;
- }
- }
- if (VLOG_IS_ON(1)) {
- VLOG(1) << c_node->name() << " Input Nodes:";
- for (auto& i : input_names) {
- VLOG(1) << " Input " << i << " in graph " << node_maps.count(i);
- }
- }
- auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
- auto resmgr = trt_rm->getManager("TRTCalibOps");
- tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
- auto status = resmgr->Lookup(res_name, res_name, &calib_res);
- if (!status.ok() || !calib_res->calibrator_) {
- return tensorflow::errors::FailedPrecondition(
- "You must run calibration"
- " and inference conversion in the same process");
- }
-
- calib_res->calibrator_->setDone();
- calib_res->thr_->join();
- delete calib_res->thr_;
- if (!calib_res->engine_) {
- LOG(ERROR) << "Calibration failed!, engine does not exist. Did you run "
- "calibration graph?";
- return tensorflow::errors::FailedPrecondition(
- "Calibration graph needs to be executed on"
- " calibration data before convertsion to inference graph");
- }
- auto weight_rmgr = trt_rm->getManager("WeightStore");
- TF_CHECK_OK(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
- res_name, res_name));
- auto engine_plan = calib_res->engine_->serialize();
- calib_res->engine_->destroy();
- calib_res->network_->destroy();
- calib_res->builder_->destroy();
- calib_res->thr_ = nullptr;
- calib_res->engine_ = nullptr;
- calib_res->builder_ = nullptr;
- tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
- std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
- income_edges.resize(c_node->num_inputs());
- for (const auto in_edge : c_node->in_edges()) {
- auto src = in_edge->src();
- int dest_port = in_edge->dst_input();
- VLOG(1) << "Incoming connection " << src->name() << ":"
- << in_edge->src_output() << " -> " << c_node->name() << ":"
- << dest_port;
- income_edges.at(dest_port) = {src->name(), in_edge->src_output(),
- c_node->input_type(dest_port)};
- }
- tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
- income_edges);
- if (VLOG_IS_ON(2)) {
- for (const auto& inp : input_list) {
- VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " "
- << tensorflow::DataTypeString(inp.data_type);
- }
- }
- op_builder.Input(input_list);
- tensorflow::NodeDef engine_node;
- const char* engine_plan_data = static_cast<const char*>(engine_plan->data());
- string engine_plan_string(engine_plan_data,
- engine_plan_data + engine_plan->size());
- status = op_builder.Attr("serialized_engine", engine_plan_string)
- .Attr("input_nodes", input_names)
- .Attr("output_nodes", output_nodes)
- .Attr("OutT", out_types)
- .Finalize(&engine_node);
- if (!status.ok()) {
- LOG(ERROR) << "Engine Node creation failed";
- return status;
- }
- auto trt_engine_node = graph.AddNode(engine_node, &status);
- TF_RETURN_IF_ERROR(status);
- std::map<string, int> port_map;
- for (size_t t = 0; t < output_nodes.size(); t++) {
- port_map.insert({output_nodes.at(t), t});
- }
- for (auto& i : out_edges) {
- string s(i->src()->name());
- if (i->src_output()) StrAppend(&s, ":", i->src_output());
- int out_port = port_map.at(s);
- VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port
- << " -> " << i->dst()->name() << ":" << i->dst_input();
- TF_RETURN_IF_ERROR(
- graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input()));
- }
- for (const auto ed : trt_engine_node->in_edges()) {
- VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output()
- << " -> " << ed->dst()->name() << ":" << ed->dst_input();
- }
- for (const auto ed : trt_engine_node->out_edges()) {
- VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output()
- << " -> " << ed->dst()->name() << ":" << ed->dst_input();
- }
- VLOG(1) << "Segment nodes:";
- for (auto& i : segment_nodes) {
- VLOG(1) << " " << i << " in graph " << node_maps.count(i);
- auto it = node_maps.find(i);
- if (it != node_maps.end()) {
- graph.RemoveNode(it->second);
- }
- }
- graph.RemoveNode(c_node);
- return tensorflow::Status::OK();
-}
-tensorflow::Status ReverseTopologicalSort(
- const tensorrt::convert::SubGraphParams& s,
- std::list<tensorflow::Node*>* order) {
- std::vector<tensorflow::Node*> order_vec;
- tensorflow::GetPostOrder(s.graph, &order_vec);
- // Select just the subgraph
- for (tensorflow::Node* node : order_vec) {
- if (s.subgraph_node_ids.count(node->id())) {
- // We want topological order to contstruct the
- // network layer by layer
- order->push_front(node);
+ input_dim_pseudo_chw.nbDims = shape.dims() - 1;
+ nvinfer1::ITensor* input_tensor = converter.network()->addInput(
+ node_name.c_str(), dtype, input_dim_pseudo_chw);
+ if (!input_tensor) {
+ return tensorflow::errors::InvalidArgument(
+ "Failed to create Input layer tensor ", node_name,
+ " rank=", shape.dims() - 1);
+ }
+ VLOG(1) << "Input tensor name :" << node_name;
+ if (!converter.insert_input_tensor(node_name, input_tensor)) {
+ return tensorflow::errors::AlreadyExists(
+ "Output tensor already exists for op: " + node_name);
+ }
+ } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
+ (node_def.op() == "Identity")) {
+ int32 slot_number = -1;
+ if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
+ &slot_number)) {
+ LOG(ERROR) << "Failed to parse slot number from " << node_name
+ << " +9=" << node_name.c_str() + 9;
+ }
+ if (output_tensors.size() <= slot_number) {
+ output_tensors.resize(slot_number + 1);
+ }
+ output_tensors.at(slot_number) = {node_def.input(0), node_name};
+ } else {
+ VLOG(2) << "Converting node: " << node_def.name() << " , "
+ << node_def.op();
+ TF_RETURN_IF_ERROR(converter.convert_node(node_def));
}
}
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status SetInputList(
- const tensorrt::convert::SubGraphParams& s,
- tensorflow::NodeDefBuilder* op_builder,
- const std::vector<string>* input_names,
- std::vector<tensorflow::DataType>* input_dtypes) {
- std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
- VLOG(2) << "input edge size: " << input_names->size();
- for (size_t i = 0; i < input_names->size(); ++i) {
- VLOG(2) << "input edges: " << i << " " << input_names->at(i);
- int output_idx = s.input_inds.at(i).second;
- // we wired up the input here already, it is redundant to do it again in
- // ConvertSubGraphToTensorRT(convert_graph.cc)
- auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
- input_names->at(i), output_idx, input_dtypes->at(i));
- income_edges.push_back(incoming_edge);
- }
- tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
- income_edges);
- op_builder->Input(input_list);
- return tensorflow::Status::OK();
-}
-
-string SubgraphNameScopeGenerator(const std::list<tensorflow::Node*>* order) {
- string subgraph_name_scope;
- if (!order->empty()) {
- subgraph_name_scope = order->front()->name();
- }
- for (const tensorflow::Node* node : *order) {
- subgraph_name_scope = GetCommonNameScope(subgraph_name_scope, node->name());
- }
- // TODO(sami,ben,jie): proper naming!
- return subgraph_name_scope;
-}
-
-tensorflow::Status ConvertSubgraph(
- Converter& converter, tensorrt::convert::SubGraphParams& s,
- std::list<tensorflow::Node*>* order, std::vector<string>* input_names,
- std::vector<tensorflow::DataType>* input_dtypes,
- std::vector<string>* output_names,
- std::vector<tensorflow::DataType>* output_dtypes,
- const string& engine_name) {
- std::set<string> added_tensors;
- for (const std::pair<int, int>& input : s.input_inds) {
- VLOG(2) << "parsing input. Node id= " << input.first;
- int node_id = input.first;
- int output_idx = input.second;
- tensorflow::Node* node = s.graph.FindNodeId(node_id);
- auto node_name = node->name();
- // input_names should use the node name in the graph
- // here it should be the input tensor name -> matching the binding
- // insert original node name without port
- auto tensor_name = node_name;
- if (output_idx != 0) {
- tensor_name = StrCat(tensor_name, ":", output_idx);
- }
-
- VLOG(2) << "input name: " << node_name << " tensor_name: " << tensor_name
- << " idx: " << output_idx;
-
- auto shape_inference_node_name = node_name;
- auto shape_inference_output_idx = output_idx;
- // rewire the shape inference to original node in the graph
- if (s.output_edge_map->count(tensor_name)) {
- shape_inference_node_name = s.output_edge_map->at(tensor_name).second;
- shape_inference_output_idx = s.output_edge_map->at(tensor_name).first;
- }
- if (shape_inference_output_idx < 0) continue;
- VLOG(2) << "shapeinference name: " << shape_inference_node_name
- << " idx: " << shape_inference_output_idx;
-
- if (!s.graph_properties.HasOutputProperties(shape_inference_node_name))
- return tensorflow::errors::Internal("failed to find input node: " +
- shape_inference_node_name);
-
- auto op_info_vec =
- s.graph_properties.GetOutputProperties(shape_inference_node_name);
- if (static_cast<int>(op_info_vec.size()) <= shape_inference_output_idx)
- return tensorflow::errors::Internal(
- "accessing output index of: ", shape_inference_output_idx,
- ", at node: ", shape_inference_node_name,
- " with output entry from shape_map: ", op_info_vec.size());
-
- auto op_info = op_info_vec.at(shape_inference_output_idx);
- tensorflow::DataType tf_dtype = op_info.dtype();
-
- nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
- auto type_status = ConvertDType(tf_dtype, &dtype);
- if (type_status != tensorflow::Status::OK()) {
- LOG(WARNING) << "Type conversion failed for " << node_name;
- return type_status;
- }
-
- VLOG(2) << "Accessing output index of: " << output_idx
- << ", at node: " << node_name
- << " with output entry from shape_map: " << op_info_vec.size();
- // TODO(ben,jie): update TRT input format/dimension
- nvinfer1::DimsCHW input_dim_pseudo_chw;
- for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;
-
- // TODO(jie): TRT 3.x only support 4 dimensional input tensor.
- // update the code once TRT 4.0 comes out.
- if (op_info.shape().dim_size() != 4) {
- string err_str = "Require 4 dimensional input.";
- StrAppend(&err_str, " Got ", op_info.shape().dim_size(), " ",
- shape_inference_node_name);
- return tensorflow::errors::Unimplemented(err_str);
- }
-
- for (int i = 1; i < op_info.shape().dim_size(); i++) {
- VLOG(2) << "dimension: " << i
- << " , size: " << op_info.shape().dim(i).size();
- input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
- }
-
- // TODO(ben,jie): proper way to restore input tensor name?
- auto input_tensor_name = node_name;
- if (output_idx != 0) {
- input_tensor_name = StrCat(node_name, ":", output_idx);
- }
- if (added_tensors.count(input_tensor_name)) continue;
- added_tensors.insert(input_tensor_name);
- input_names->push_back(input_tensor_name);
- input_dtypes->push_back(tf_dtype);
- nvinfer1::ITensor* input_tensor = converter.network()->addInput(
- input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);
-
- if (!input_tensor)
- return tensorflow::errors::InvalidArgument(
- "Failed to create Input layer");
- VLOG(2) << "Input tensor name :" << input_tensor_name;
-
- if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
- return tensorflow::errors::AlreadyExists(
- "Output tensor already exists for op: " + input_tensor_name);
- }
-
- for (const tensorflow::Node* node : *order) {
- const tensorflow::NodeDef& node_def = node->def();
- VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
- TF_RETURN_IF_ERROR(converter.convert_node(node_def));
- }
-
- VLOG(2) << "Finished conversion";
-
- // Gather output metadata
- int trt_engine_op_output_idx = 0;
- added_tensors.clear();
- for (const std::pair<int, int>& output : s.output_inds) {
- int node_id = output.first;
- int output_idx = output.second;
- tensorflow::Node* node = s.graph.FindNodeId(node_id);
- string op_name = node->name();
- string tensor_name = op_name;
-
- s.output_edge_map->insert(
- {trt_engine_op_output_idx == 0
- ? engine_name
- : StrCat(engine_name, ":", trt_engine_op_output_idx),
- {output_idx, tensor_name}});
- trt_engine_op_output_idx++;
- if (output_idx != 0)
- tensorflow::strings::StrAppend(&tensor_name, ":", output_idx);
- VLOG(2) << "Output tensor name: " << tensor_name;
- if (added_tensors.count(tensor_name)) continue;
- added_tensors.insert(tensor_name);
- output_names->push_back(tensor_name);
- auto tensor_or_weights = converter.get_tensor(tensor_name);
+ for (const auto& output : output_tensors) {
+ auto tensor_or_weights = converter.get_tensor(output.first);
if (!tensor_or_weights.is_tensor()) {
- return tensorflow::errors::InvalidArgument("Output node '" + tensor_name +
- "' is weights not tensor");
+ return tensorflow::errors::InvalidArgument(
+ "Output node '" + output.first + "' is weights not tensor");
}
nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
+ tensor->setName(output.second.c_str());
if (!tensor) {
return tensorflow::errors::NotFound("Output tensor not found: " +
- tensor_name);
+ output.first);
}
+ VLOG(1) << "Marking output tensor " << output.first << ", as output tensor "
+ << output.second;
+
converter.network()->markOutput(*tensor);
- tensorflow::DataType tf_dtype = node->output_type(output_idx);
- output_dtypes->push_back(tf_dtype);
- nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
- TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
- tensor->setType(trt_dtype);
}
+ if (convert_successfully) *convert_successfully = true;
- return tensorflow::Status::OK();
-}
-
-tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) {
- // Visit nodes in reverse topological order and construct the TRT network.
- // Toposort
- std::list<tensorflow::Node*> order;
- TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
-
- static int static_id = 0;
- string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
- // TODO(sami,ben,jie): proper naming!
- string calib_op_name =
- StrCat(subgraph_name_scope, "my_trt_calib_op_", static_id);
- string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id);
- static_id++;
-
- auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
- auto op_rmgr = trt_rmgr->getManager("TRTCalibOps");
- auto op_res = new tensorflow::tensorrt::TRTCalibrationResource();
- TF_CHECK_OK(op_rmgr->Create(calib_op_name, calib_op_name, op_res));
- op_res->logger_ = new tensorflow::tensorrt::Logger();
- cudaSetDevice(s.cuda_gpu_id_);
- op_res->builder_ = nvinfer1::createInferBuilder(*(op_res->logger_));
- op_res->allocator_ = s.allocator_;
-#if NV_TENSORRT_MAJOR > 3
- op_res->builder_->setGpuAllocator(s.allocator_.get());
-#endif
- if (!op_res->builder_) {
- return tensorflow::errors::Internal(
- "failed to create TensorRT builder object");
+ // Build the engine.
+ VLOG(1) << "Starting engine creation";
+ engine->reset(builder->buildCudaEngine(*converter.network()));
+ if (engine->get() == nullptr) {
+ return tensorflow::errors::Internal("Failed to build TensorRT engine");
}
-
- op_res->network_ = op_res->builder_->createNetwork();
- if (!op_res->network_) {
- return tensorflow::errors::Internal(
- "failed to create TensorRT network object");
- }
-
- // Build the network
- auto weight_rmgr = trt_rmgr->getManager("WeightStore");
- auto ws = new tensorflow::tensorrt::TRTWeightStore();
- TF_CHECK_OK(weight_rmgr->Create(calib_op_name, calib_op_name, ws));
- Converter converter(op_res->network_, ws, s.precision_mode == FP16MODE);
-
- std::vector<string> input_names;
- std::vector<tensorflow::DataType> input_dtypes;
- std::vector<string> output_names;
- std::vector<tensorflow::DataType> output_dtypes;
- TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names,
- &input_dtypes, &output_names,
- &output_dtypes, engine_name));
-
- VLOG(2) << "Finished processing outputs";
-
- // Build the engine
- op_res->builder_->setMaxBatchSize(s.max_batch_size);
- op_res->builder_->setMaxWorkspaceSize(s.max_workspace_size_bytes);
- VLOG(0) << "Max batch size= " << s.max_batch_size
- << " max workspace size= " << s.max_workspace_size_bytes;
-
- // Build the TRT op
- // TODO(sami,ben,jie): proper naming!
- tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp");
- TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
-
- std::vector<string> segment_names;
- segment_names.reserve(s.subgraph_node_ids.size());
- for (int i : s.subgraph_node_ids) {
- auto node = s.graph.FindNodeId(i);
- segment_names.push_back(node->name());
- }
- LOG(INFO) << "finished op preparation";
-
- auto status = op_builder.Attr("segment_nodes", segment_names)
- .Attr("input_names", input_names)
- .Attr("segment_output_names", output_names)
- .Attr("resource_name", calib_op_name)
- .Finalize(s.trt_node);
-
- LOG(INFO) << status.ToString();
- LOG(INFO) << "finished op building";
-
+ VLOG(1) << "Finished conversion";
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
- tensorrt::convert::SubGraphParams& s) {
- // Visit nodes in reverse topological order and construct the TRT network.
- std::list<tensorflow::Node*> order;
- TF_RETURN_IF_ERROR(ReverseTopologicalSort(s, &order));
-
- static int static_id = 0;
- string subgraph_name_scope = SubgraphNameScopeGenerator(&order);
- string engine_name = StrCat(subgraph_name_scope, "my_trt_op", static_id++);
-
- tensorflow::tensorrt::Logger trt_logger;
- cudaSetDevice(s.cuda_gpu_id_);
- auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
- if (!trt_builder) {
- return tensorflow::errors::Internal(
- "Failed to create TensorRT builder object");
- }
-#if NV_TENSORRT_MAJOR > 3
- trt_builder->setGpuAllocator(s.allocator_.get());
-#endif
- auto trt_network = infer_object(trt_builder->createNetwork());
- if (!trt_network) {
- return tensorflow::errors::Internal(
- "Failed to create TensorRT network object");
- }
-
- auto trt_rmgr = tensorflow::tensorrt::TRTResourceManager::instance();
- auto weight_rmgr = trt_rmgr->getManager("WeightStore");
- auto ws = new tensorflow::tensorrt::TRTWeightStore();
- TF_CHECK_OK(weight_rmgr->Create(engine_name, engine_name, ws));
-
- // Build the network
- Converter converter(trt_network.get(), ws, s.precision_mode == FP16MODE);
-
- std::vector<string> input_names;
- std::vector<tensorflow::DataType> input_dtypes;
- std::vector<string> output_names;
- std::vector<tensorflow::DataType> output_dtypes;
- TF_RETURN_IF_ERROR(ConvertSubgraph(converter, s, &order, &input_names,
- &input_dtypes, &output_names,
- &output_dtypes, engine_name));
-
- VLOG(2) << "Finished output";
-
- // Build the engine
- trt_builder->setMaxBatchSize(s.max_batch_size);
- trt_builder->setMaxWorkspaceSize(s.max_workspace_size_bytes);
- VLOG(0) << "Max batch size= " << s.max_batch_size
- << " max workspace size= " << s.max_workspace_size_bytes;
- if (s.precision_mode == FP16MODE) {
- trt_builder->setHalf2Mode(true);
- VLOG(0) << "Using FP16 precision mode";
- }
- LOG(INFO) << "starting build engine";
- string engine_plan_string;
- {
- auto trt_engine =
- infer_object(trt_builder->buildCudaEngine(*converter.network()));
- VLOG(0) << "Built network";
- if (trt_engine.get() == nullptr) {
- return tensorflow::errors::Internal("Engine building failure");
+tensorflow::Status ConvertSegmentToGraphDef(
+ const tensorflow::Graph* graph,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::vector<int>& subgraph_node_ids, // In topological order
+ std::vector<EngineConnection>* connections,
+ tensorflow::GraphDef* segment_def, string* common_scope) {
+ std::set<string> marker_nodes;
+ // Update connection shapes/data types and add corresponding input/output
+ // nodes in the segment graphdef.
+ for (size_t i = 0; i < connections->size(); ++i) {
+ auto& connection = connections->at(i);
+ auto outside_node = graph->FindNodeId(connection.outside_id);
+ if (!outside_node) {
+ // This should never happen, unless the original graph is problematic.
+ return tensorflow::errors::NotFound(
+ "Cannot find node with id ", connection.outside_id, " in the graph.");
+ }
+ // Updates the shape and data types of input/output connections.
+ tensorflow::DataType input_type = tensorflow::DT_FLOAT;
+ tensorflow::PartialTensorShape partial_shape;
+ if (connection.is_input_edge) {
+ if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
+ auto output_params =
+ graph_properties.GetOutputProperties(connection.outside_node_name);
+ auto out_shape = output_params.at(connection.outside_port);
+ input_type = out_shape.dtype();
+ std::vector<tensorflow::int64> dims;
+ partial_shape = out_shape.shape();
+ connection.outside_shape = partial_shape;
+ } else {
+ VLOG(0) << "Unknown output shape" << outside_node->name();
+ input_type = graph->FindNodeId(connection.outside_id)
+ ->output_type(connection.outside_port);
+ }
+ connection.connection_type = input_type;
+
+ } else { // output edge
+ if (graph_properties.HasInputProperties(connection.outside_node_name)) {
+ auto input_params =
+ graph_properties.GetInputProperties(connection.outside_node_name);
+ auto in_shape = input_params.at(connection.outside_port);
+ input_type = in_shape.dtype();
+ partial_shape = in_shape.shape();
+ connection.inside_shape = partial_shape;
+ } else {
+ input_type = graph->FindNodeId(connection.inside_id)
+ ->output_type(connection.outside_port);
+ }
+ connection.connection_type = input_type;
}
- auto engine_plan = infer_object(trt_engine->serialize());
- VLOG(0) << "Serialized engine";
- const char* engine_plan_data =
- static_cast<const char*>(engine_plan->data());
- engine_plan_string =
- string(engine_plan_data, engine_plan_data + engine_plan->size());
- }
- TF_RETURN_IF_ERROR(weight_rmgr->Delete<tensorflow::tensorrt::TRTWeightStore>(
- engine_name, engine_name));
- LOG(INFO) << "finished engine " << engine_name << " containing "
- << s.subgraph_node_ids.size() << " nodes";
-
- // Build the TRT op
- tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp");
- TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes));
-
- VLOG(0) << "Finished op preparation";
-
- auto status = op_builder.Attr("serialized_engine", engine_plan_string)
- .Attr("input_nodes", input_names)
- .Attr("output_nodes", output_names)
- .Attr("OutT", output_dtypes)
- .Device(s.device_name_)
- .Finalize(s.trt_node);
-
- VLOG(0) << status.ToString() << " finished op building for " << engine_name
- << " on device " << s.device_name_;
+ // Add dummy input/output nodes to the segment graphdef.
+ if (connection.is_input_edge) {
+ const string node_name = StrCat(kInputPHName, connection.port_number);
+ if (marker_nodes.count(node_name)) {
+ VLOG(1) << "Reusing input " << node_name << " for the edge "
+ << connection.outside_node_name << ":"
+ << connection.outside_port << " -> "
+ << connection.inside_node_name << ":" << connection.inside_port;
+ continue;
+ }
+ marker_nodes.insert(node_name);
+ auto seg_node = segment_def->add_node();
+ tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
+ auto status = builder.Attr("shape", partial_shape)
+ .Attr("dtype", input_type)
+ .Finalize(seg_node);
+ VLOG(1) << "Constructing input " << node_name << " for the edge "
+ << connection.outside_node_name << ":" << connection.outside_port
+ << " -> " << connection.inside_node_name << ":"
+ << connection.inside_port;
+ } else {
+ const string node_name = StrCat(kOutputPHName, connection.port_number);
+ if (marker_nodes.count(node_name)) {
+ VLOG(1) << "Reusing output " << node_name << " for the edge "
+ << connection.inside_node_name << ":" << connection.inside_port
+ << " -> " << connection.outside_node_name << ":"
+ << connection.outside_port;
+ continue;
+ }
+ marker_nodes.insert(node_name);
+ auto seg_node = segment_def->add_node();
+ tensorflow::NodeDefBuilder builder(node_name, "Identity");
+ auto status = builder.Input(connection.inside_node_name, 0, input_type)
+ .Finalize(seg_node);
+ VLOG(1) << "Constructing output " << node_name << " for the edge "
+ << connection.inside_node_name << ":" << connection.inside_port
+ << " -> " << connection.outside_node_name << ":"
+ << connection.outside_port;
+ }
+ } // for each connection.
+
+ std::unordered_map<int, int> old_to_new_id_map;
+ // Copy internal nodes to new graphdef
+ string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name();
+ for (const auto node_id : subgraph_node_ids) {
+ const auto node = graph->FindNodeId(node_id);
+ local_scope = GetCommonNameScope(local_scope, node->name());
+ old_to_new_id_map[node_id] = segment_def->node_size();
+ auto snode = segment_def->add_node();
+ snode->CopyFrom(node->def());
+ VLOG(1) << "Copying " << snode->name() << " to subgraph";
+ }
+ // Update the inputs of the new input nodes to point to placeholder nodes.
+ for (int i = 0; i < connections->size(); ++i) {
+ auto& connection = connections->at(i);
+ if (!connection.is_input_edge) continue;
+ auto snode =
+ segment_def->mutable_node(old_to_new_id_map[connection.inside_id]);
+ const string placeholder_name =
+ StrCat(kInputPHName, connection.port_number);
+ VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
+ << " from " << snode->input(connection.inside_port) << " to "
+ << placeholder_name;
+ snode->set_input(connection.inside_port, placeholder_name);
+ }
+ *common_scope = local_scope;
+ VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 3f6592cd25..1a4c0e755d 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -22,69 +22,112 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
+
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
namespace tensorflow {
namespace tensorrt {
+static const char* kInputPHName = "InputPH_";
+static const char* kOutputPHName = "OutputPH_";
namespace convert {
+// TODO(aaroey): use an enum instead.
const int FP32MODE = 0;
const int FP16MODE = 1;
const int INT8MODE = 2;
-struct SubGraphParams {
- SubGraphParams(
- tensorflow::Graph& inp_graph,
- const std::set<int>& subgraph_node_id_numbers,
- const std::vector<std::pair<int, int>>& input_indices,
- const std::vector<std::pair<int, int>>& output_indices,
- size_t max_supported_batch_size, size_t max_consumed_workspace_size_bytes,
- const tensorflow::grappler::GraphProperties& current_graph_properties,
- std::unordered_map<string, std::pair<int, string>>* output_edges,
- tensorflow::NodeDef* constructed_trt_node,
- int engine_precision_mode = FP32MODE, const string& device_name = "",
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator = nullptr,
- int cuda_gpu_id = 0)
- : graph(inp_graph),
- subgraph_node_ids(subgraph_node_id_numbers),
- input_inds(input_indices),
- output_inds(output_indices),
- max_batch_size(max_supported_batch_size),
- max_workspace_size_bytes(max_consumed_workspace_size_bytes),
- graph_properties(current_graph_properties),
- output_edge_map(output_edges),
- trt_node(constructed_trt_node),
- precision_mode(engine_precision_mode),
- device_name_(device_name),
- allocator_(allocator),
- cuda_gpu_id_(cuda_gpu_id) {}
-
- tensorflow::Graph& graph;
- const std::set<int>& subgraph_node_ids;
- const std::vector<std::pair<int, int>>& input_inds; // {node_id, output_idx}
- const std::vector<std::pair<int, int>>& output_inds; // {node_id, output_idx}
- size_t max_batch_size;
- size_t max_workspace_size_bytes;
- const tensorflow::grappler::GraphProperties& graph_properties;
- std::unordered_map<string, std::pair<int, string>>* output_edge_map;
- tensorflow::NodeDef* trt_node;
- const int precision_mode;
- const string device_name_;
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
- const int cuda_gpu_id_;
+struct EngineConnection {
+ EngineConnection(const string& outside, int out_id, int out_port,
+ const string& inside, int in_id, int in_port,
+ bool input_edge, int port)
+ : outside_node_name(outside),
+ outside_id(out_id),
+ outside_port(out_port),
+ inside_node_name(inside),
+ inside_id(in_id),
+ inside_port(in_port),
+ is_input_edge(input_edge),
+ port_number(port) {}
+
+ const string outside_node_name;
+ const int outside_id;
+ const int outside_port;
+ tensorflow::PartialTensorShape outside_shape;
+
+ const string inside_node_name;
+ const int inside_id;
+ const int inside_port;
+ tensorflow::PartialTensorShape inside_shape;
+
+ tensorflow::DataType connection_type;
+ bool is_input_edge;
+
+ // The port number of the TRT node connecting to this edge.
+ int port_number;
+};
+
+struct EngineInfo {
+ EngineInfo()
+ : engine_type(EngineType::TRTStatic),
+ max_workspace_size_bytes(0),
+ precision_mode(FP32MODE) {}
+
+ string engine_name;
+ string device;
+ tensorflow::GraphDef segment_graph_def;
+
+ // The segment nodes that are on one side of the edges are topological sorted.
+ std::vector<EngineConnection> connections;
+
+ enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
+ EngineType engine_type;
+ int64 max_workspace_size_bytes;
+ int maximum_cached_engines;
+ std::vector<int> cached_engine_batches;
+ int precision_mode;
};
-// TODO(sami): Replace references with const reference or pointers
-tensorflow::Status ConvertSubGraphToTensorRTNodeDef(SubGraphParams& params);
-tensorflow::Status InjectCalibrationNode(SubGraphParams& params);
-tensorflow::Status ConvertCalibrationNodeToEngineNode(tensorflow::Graph& graph,
- tensorflow::Node* c_node);
+// Constructs a graphdef from the segment in the given graph. Adds placeholder
+// nodes for input edges (InputPH_*) and identity nodes for output edges
+// (OutputPH_*). This function needs to be called before TensorRT nodes
+// inserted in order to correctly get sizes from the original graph.
+//
+// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
+// topological order.
+// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
+// sorted in topological order.
+tensorflow::Status ConvertSegmentToGraphDef(
+ const tensorflow::Graph* graph,
+ const tensorflow::grappler::GraphProperties& graph_properties,
+ const std::vector<int>& subgraph_node_ids,
+ std::vector<EngineConnection>* connections,
+ tensorflow::GraphDef* segment_def, string* common_scope);
+
+// Converts given subgraph to a TRT engine saved in 'engine'. Returns ok iff
+// 'builder' successfully build the engine. If the result is not ok, 'engine'
+// will be set to nullptr
+// Once returned, 'builder' is not needed any more and can be safely detroyed.
+//
+// - convert_successfully: indicates whether the converson to TensorRT network
+// is successful. This is different than successfully building the engine:
+// building can still fail afterwards.
+tensorflow::Status ConvertGraphDefToEngine(
+ const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
+ size_t max_workspace_size_bytes,
+ const std::vector<tensorflow::PartialTensorShape>& input_shapes,
+ Logger* logger, nvinfer1::IGpuAllocator* allocator,
+ TRTInt8Calibrator* calibrator,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ bool* convert_successfully);
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 8f634b1f74..ec9dbfa13b 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -45,8 +45,24 @@ tensorflow::Status TRTOptimizationPass::Init(
if (params.count("max_batch_size")) {
maximum_batch_size_ = params.at("max_batch_size").i();
}
- if (params.count("max_workspace_size_bytes"))
+ is_dynamic_op_ = false;
+ if (params.count("is_dynamic_op")) {
+ is_dynamic_op_ = params.at("is_dynamic_op").b();
+ }
+ if (params.count("cached_engine_batches")) {
+ auto batch_vec = params.at("cached_engine_batches").list();
+ batches_.reserve(batch_vec.i_size());
+ for (const auto i : batch_vec.i()) {
+ batches_.push_back(i);
+ }
+ }
+ max_cached_batches_ = 1;
+ if (params.count("maximum_cached_engines")) {
+ max_cached_batches_ = params.at("maximum_cached_engines").i();
+ }
+ if (params.count("max_workspace_size_bytes")) {
maximum_workspace_size_ = params.at("max_workspace_size_bytes").i();
+ }
if (params.count("precision_mode")) {
string pm = Uppercase(params.at("precision_mode").s());
if (pm == "FP32") {
@@ -175,6 +191,17 @@ tensorflow::Status TRTOptimizationPass::Optimize(
if (VLOG_IS_ON(1)) {
PrintDebugInfo(cluster, item);
}
+ // This is a hack to workaround optimizer issue. MetaOptimizer calls
+ // optimization passes on function objects as well, we should not modify
+ // generated funcdefs! This is fragile but we don't have any other option
+ // until framework fixes it.
+ if (item.id != "tf_graph") {
+ LOG(WARNING) << name_
+ << " is probably called on funcdef! This optimizer must *NOT* "
+ "be called on function objects.";
+ *optimized_graph = item.graph;
+ return tensorflow::Status::OK();
+ }
int max_dim = -1;
if (item.feed.size()) {
for (const auto& f : item.feed) {
@@ -204,11 +231,22 @@ tensorflow::Status TRTOptimizationPass::Optimize(
}
tensorflow::grappler::GraphProperties static_graph_properties(item);
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
- auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(
- item.graph, item.fetch, maximum_batch_size_, maximum_workspace_size_,
- optimized_graph, precision_mode_, minimum_segment_size_,
- static_graph_properties, cluster);
+ tensorflow::tensorrt::convert::ConversionParams cp;
+ cp.input_graph_def = &item.graph;
+ cp.output_names = &item.fetch;
+ cp.max_batch_size = maximum_batch_size_;
+ cp.max_workspace_size_bytes = maximum_workspace_size_;
+ cp.output_graph_def = optimized_graph;
+ cp.precision_mode = precision_mode_;
+ cp.minimum_segment_size = minimum_segment_size_;
+ cp.graph_properties = &static_graph_properties;
+ cp.cluster = cluster;
+ cp.is_dyn_op = is_dynamic_op_;
+ cp.cached_engine_batches = batches_;
+ cp.max_cached_engines = max_cached_batches_;
+ auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
VLOG(2) << optimized_graph->DebugString();
+ VLOG(1) << "Returning from " << name_;
return status;
}
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
index d8ecead23e..463ed3883e 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
@@ -61,6 +61,9 @@ class TRTOptimizationPass : public tensorflow::grappler::CustomGraphOptimizer {
int minimum_segment_size_;
int precision_mode_;
int maximum_batch_size_;
+ bool is_dynamic_op_;
+ std::vector<int> batches_;
+ int max_cached_batches_;
int64_t maximum_workspace_size_;
};
diff --git a/tensorflow/contrib/tensorrt/convert/utils.h b/tensorflow/contrib/tensorrt/convert/utils.h
new file mode 100644
index 0000000000..f601c06701
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/utils.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
+#define TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
+
+#include <memory>
+
+namespace tensorflow {
+namespace tensorrt {
+
+template <typename T>
+struct TrtDestroyer {
+ void operator()(T* t) {
+ if (t) t->destroy();
+ }
+};
+
+template <typename T>
+using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
+
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_TENSORRT_CONVERT_UTILS_H_
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
deleted file mode 100644
index aea44fd8a2..0000000000
--- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h"
-#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
-#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
-#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/tensor_types.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/platform/stream_executor.h"
-
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
-#include "cuda/include/cuda_runtime_api.h"
-#include "tensorrt/include/NvInfer.h"
-
-namespace tensorflow {
-namespace tensorrt {
-
-TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_));
- OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_));
- OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_));
-};
-
-#define TYPECASE(dt, X, Y) \
- case dt: { \
- return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
- }
-
-void* GetTensorAddress(const Tensor* tensor_ptr) {
- auto tensor_type = tensor_ptr->dtype();
- switch (tensor_type) {
- TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr);
- TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr);
- TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr);
- default: {
- LOG(FATAL) << "Unsupported Data type "
- << tensorflow::DataTypeString(tensor_type);
- return nullptr;
- }
- }
-}
-
-void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) {
- // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR.
- auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
- auto res_mgr = trt_rm->getManager("TRTCalibOps");
- tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
- auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res);
-
- if (!status.ok()) {
- ctx->SetStatus(status);
- return;
- }
- int num_inputs = ctx->num_inputs();
- // first run instantiate calibrator
- if (calib_res->calibrator_ == nullptr) {
- dev_tensors_.resize(num_inputs);
- int batch_size = ctx->input(0).dim_size(0);
- VLOG(1) << " Constructing calibrator";
- for (int i = 0; i < num_inputs; i++) {
- // allocate workspace on device for inputs
- const tensorflow::Tensor& t = ctx->input(i);
- OP_REQUIRES_OK(ctx,
- ctx->allocate_persistent(t.dtype(), t.shape(),
- &dev_tensors_.at(i), nullptr));
- const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
- CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
- void* device_address = GetTensorAddress(device_tensor);
- device_buffers_.emplace(input_names_.at(i),
- std::pair<void*, size_t>(
- device_address, device_tensor->TotalBytes()));
- }
-
- calib_res->calibrator_ =
- new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_);
- string label(resource_name_);
- calib_res->thr_ = new std::thread([calib_res, label]() {
- VLOG(1) << "Starting calibration thread, Calibration Resource @ "
- << calib_res;
- calib_res->builder_->setInt8Calibrator(calib_res->calibrator_);
- calib_res->builder_->setInt8Mode(true);
- calib_res->engine_ = calib_res->builder_->buildCudaEngine(
- *calib_res->network_); // will loop until we terminate calibrator
- VLOG(1) << "Calibration loop terminated " << label;
- });
- VLOG(1) << "initialized calibrator resource";
- } // calibrator initialized
-
- // Pass input data to calibrator
- std::unordered_map<string, void*> input_data;
- for (int i = 0; i < num_inputs; i++) {
- const Tensor& t = ctx->input(i);
- void* data_address = GetTensorAddress(&t);
- const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
- CHECK_EQ(t.TotalBytes(),
- device_tensor->TotalBytes()); // use the tensor so FW keeps it
- input_data.emplace(input_names_.at(i), data_address);
- ctx->set_output(i, t);
- }
- VLOG(2) << "Filled map for sending";
- // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
- const cudaStream_t* stream = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
- ->stream()
- ->implementation()
- ->CudaStreamMemberHack()));
- calib_res->calibrator_->setBatch(input_data, *stream);
- VLOG(2) << "Passed calibration data";
- // TODO(aaroey): make sure we wait for the completion of calibration on the
- // last batch in future PR.
-};
-
-#undef TYPECASE
-
-REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp);
-
-} // namespace tensorrt
-} // namespace tensorflow
-#endif
-#endif
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h
deleted file mode 100644
index 23df9db32f..0000000000
--- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H
-#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/platform/types.h"
-
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
-namespace tensorflow {
-namespace tensorrt {
-// TODO(sami): Convert this to async kernel!
-class TRTCalibOp : public OpKernel {
- public:
- explicit TRTCalibOp(OpKernelConstruction* context);
-
- void Compute(OpKernelContext* context) override;
-
- private:
- string resource_name_;
- std::vector<string> segment_nodes_;
- std::vector<string> input_names_;
- std::vector<tensorflow::TensorShape> shapes_;
- std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
- std::vector<tensorflow::PersistentTensor> dev_tensors_;
-};
-} // namespace tensorrt
-} // namespace tensorflow
-#endif
-#endif
-#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 9ac8047944..8a17eb02f1 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -14,8 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
+#include <algorithm>
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -25,144 +33,556 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
-static ::tensorflow::tensorrt::Logger logger;
-using IRuntime = nvinfer1::IRuntime;
-using Dims = nvinfer1::Dims;
-
namespace tensorrt {
+static Logger logger;
+using ::nvinfer1::IRuntime;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+// A helper class to call done() when destructed for asynchronous execution.
+// Helps simultaneous execution of native and TRT engines.
+class AsyncHelper : public tensorflow::core::RefCounted {
+ public:
+ AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
+ ~AsyncHelper() override { done_(); }
-TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) : OpKernel(context) {
+ private:
+ tensorflow::AsyncOpKernel::DoneCallback done_;
+};
+
+#define TYPECASE(dt, X, Y) \
+ case dt: { \
+ return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
+ }
+
+void* GetTensorAddress(const Tensor* tensor_ptr) {
+ auto tensor_type = tensor_ptr->dtype();
+ switch (tensor_type) {
+ TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr);
+ TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr);
+ TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr);
+ default: {
+ LOG(ERROR) << "Unsupported Data type "
+ << tensorflow::DataTypeString(tensor_type);
+ return nullptr;
+ }
+ }
+}
+
+tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
+ VLOG(1) << "Constructing function handle";
+ auto lib = ctx->function_library();
+ if (lib == nullptr) {
+ return tensorflow::errors::Internal("Context function library is null");
+ }
+ auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
+ if (fdef == nullptr) {
+ return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_,
+ " can't be found in function library");
+ }
+ tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
+ inst_ops.overlay_lib = nullptr;
+ inst_ops.state_handle = "";
+ inst_ops.target = ctx->device()->name();
+ native_func_ = 0;
+ auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
+ inst_ops, &native_func_);
+ if (!status.ok()) {
+ LOG(ERROR) << " Instantiating native function " << funcdef_name_
+ << " failed!";
+ }
+ return status;
+}
+
+TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {
// read serialized_engine
OP_REQUIRES_OK(context,
- context->GetAttr("serialized_engine", &serialized_engine_));
+ context->GetAttr("serialized_segment", &serialized_segment_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("workspace_size_bytes", &workspace_size_));
+ OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
+ if (!static_engine_) {
+ if (!segment_graph_.ParseFromString(serialized_segment_)) {
+ LOG(ERROR) << "Parsing segment graph failed!";
+ context->SetStatus(tensorflow::errors::InvalidArgument(
+ "Failed to parse segment graphdef!"));
+ return;
+ }
+ serialized_segment_.resize(0);
+ }
+ VLOG(1) << "Constructing " << name();
+ string precision_string;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("precision_mode", &precision_string));
+ string calibration_data;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("calibration_data", &calibration_data));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("segment_funcdef_name", &funcdef_name_));
+ if (precision_string == "FP32") {
+ precision_mode_ = convert::FP32MODE;
+ } else if (precision_string == "FP16") {
+ precision_mode_ = convert::FP16MODE;
+ } else if (precision_string == "INT8") {
+ precision_mode_ = convert::INT8MODE;
+ }
+ calibration_mode_ =
+ (precision_mode_ == convert::INT8MODE && calibration_data.size() == 0);
+ if (calibration_data.size()) {
+ calibrator_.reset(new TRTInt8Calibrator(calibration_data));
+ calibration_data.resize(0);
+ }
+ native_func_ = tensorflow::kInvalidHandle;
+ OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
+ &max_cached_engines_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("fixed_input_size", &fixed_input_size_));
+ OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
+ &cached_engine_batches_));
+ std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
+ if (VLOG_IS_ON(1)) {
+ string s("Engine Batches= ");
+ for (auto i : cached_engine_batches_) {
+ StrAppend(&s, i, " ");
+ }
+ VLOG(1) << s;
+ }
+}
- // register input output node name in trt_sub_graph
- OP_REQUIRES_OK(context, context->GetAttr("input_nodes", &input_nodes_));
- OP_REQUIRES_OK(context, context->GetAttr("output_nodes", &output_nodes_));
+void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
+ AsyncHelper* helper) {
+ if (!calibration_mode_) {
+ VLOG(1) << "Executing native engine";
+ }
+ std::vector<Tensor> inputs;
+ std::vector<Tensor>* outputs = new std::vector<Tensor>();
+ if (native_func_ == tensorflow::kInvalidHandle) {
+ auto status = ConstructFunctionHandle(ctx);
+ if (!status.ok()) {
+ LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
+ ctx->SetStatus(status);
+ return;
+ }
+ }
+ auto lib = ctx->function_library();
+ tensorflow::FunctionLibraryRuntime::Options opts;
+ opts.step_id = ctx->step_id();
+ opts.rendezvous = ctx->rendezvous();
+ opts.cancellation_manager = ctx->cancellation_manager();
+ opts.runner = ctx->runner();
+ for (int i = 0; i < ctx->num_inputs(); i++) {
+ inputs.push_back(ctx->input(i));
+ }
+ helper->Ref(); // Increment count for calculating native graph
+ VLOG(1) << "Executing native segment " << name();
+ lib->Run(opts, native_func_, inputs, outputs,
+ [ctx, outputs, helper](const tensorflow::Status& s) {
+ tensorflow::core::ScopedUnref sc(helper);
+ VLOG(1) << "Native Segment completed";
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ return;
+ }
+ for (size_t t = 0; t < outputs->size(); ++t) {
+ ctx->set_output(t, outputs->at(t));
+ }
+ delete outputs;
+ });
}
-void TRTEngineOp::Compute(OpKernelContext* context) {
- // TODO(samikama) runtime should be taken from a resourcemanager as well.
- // Only engine should be in the op and context and runtime should be taken
- // from resourcemanager
+void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+ AsyncHelper* helper) {
+ helper->Ref();
+ tensorflow::core::ScopedUnref sc(helper);
+ // TODO(aaroey): remove the ResourceMgr singleton.
+ auto trt_rm = TRTResourceManager::instance();
+ auto res_mgr = trt_rm->getManager("TRTCalibration");
+ TRTCalibrationResource* calib_res = nullptr;
+ auto status = res_mgr->LookupOrCreate(
+ funcdef_name_, "Calibrator", &calib_res,
+ {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status {
+ return this->AllocateCalibrationResources(ctx, cr);
+ }});
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ int num_inputs = ctx->num_inputs();
+ // Pass input data to calibrator
+ std::unordered_map<string, void*> input_data;
+ for (int i = 0; i < num_inputs; i++) {
+ const Tensor& t = ctx->input(i);
+ void* data_address = GetTensorAddress(&t);
+ if (data_address == nullptr) {
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "Unsupported data type encountered in input ", i));
+ return;
+ }
+ // Check the allocated buffer is sufficient for input
+ const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
+ CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
+ input_data.emplace(StrCat(kInputPHName, i), data_address);
+ }
+ VLOG(2) << "Filled map for sending";
+ // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ calib_res->calibrator_->setBatch(input_data, *stream);
+ VLOG(2) << "Passed calibration data";
+ ExecuteNativeSegment(ctx, helper);
+}
- if (!trt_execution_context_ptr_) {
- IRuntime* infer = nvinfer1::createInferRuntime(logger);
-#if NV_TENSORRT_MAJOR > 3
- auto device = context->device();
- auto dev_allocator =
- device->GetAllocator(tensorflow::AllocatorAttributes());
- if (!dev_allocator) {
- LOG(FATAL) << "Can't find device allocator for gpu device "
- << device->name();
- }
- allocator_ = std::make_shared<TRTDeviceAllocator>(dev_allocator);
- infer->setGpuAllocator(allocator_.get());
-#endif
- trt_engine_ptr_.reset(infer->deserializeCudaEngine(
- serialized_engine_.c_str(), serialized_engine_.size(),
- PluginFactoryTensorRT::GetInstance()));
- trt_execution_context_ptr_.reset(trt_engine_ptr_->createExecutionContext());
- // Runtime is safe to delete after engine creation
- infer->destroy();
- serialized_engine_.clear();
+int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
+ int num_batch = ctx->input(0).shape().dim_size(0);
+ int smallest_engine = 0;
+ for (const auto i : cached_engine_batches_) {
+ if (i >= num_batch) {
+ smallest_engine = i;
+ break;
+ }
}
- int num_binding = context->num_inputs() + context->num_outputs();
- std::vector<void*> buffers(num_binding);
+ // TODO(sami): Need an LRU here
+ if (smallest_engine == 0) {
+ if (max_cached_engines_ > cached_engine_batches_.size()) {
+ smallest_engine = num_batch;
+ cached_engine_batches_.push_back(num_batch);
+ VLOG(1) << "Running with batch size " << num_batch;
+ } else {
+ string s("Engine buffer is full. buffer limit= ");
+ StrAppend(&s, max_cached_engines_, ", current entries= ");
+ for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
+ StrAppend(&s, "Requested batch= ", num_batch);
+ LOG(ERROR) << s;
+ ctx->SetStatus(tensorflow::errors::ResourceExhausted(
+ "Requested batch size is not available and engine cache is full"));
+ return -1;
+ }
+ }
+ return smallest_engine;
+}
- size_t binding_index;
- int num_batch = 0;
- for (int i = 0; i < context->num_inputs(); i++) {
- // Grab the input tensor
- binding_index = trt_engine_ptr_->getBindingIndex(input_nodes_[i].c_str());
+void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
+ tensorflow::AsyncOpKernel::DoneCallback done) {
+ auto helper = new AsyncHelper(done);
+ tensorflow::core::ScopedUnref sc(helper);
+ if (calibration_mode_) {
+ ExecuteCalibration(ctx, helper);
+ return;
+ }
+ const int smallest_engine = GetEngineBatch(ctx);
+ if (smallest_engine < 0) return; // GetEngineBatch already set the status.
+
+ const int num_batch = ctx->input(0).shape().dim_size(0);
+ auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
+ auto& trt_engine_ptr = engine_ctx_pair.first;
+ if (!trt_engine_ptr) {
+ LOG(WARNING) << "Engine retrieval for batch size " << num_batch
+ << " failed Running native segment";
+ ExecuteNativeSegment(ctx, helper);
+ return;
+ }
- const Tensor& input_tensor = context->input(i);
+ const int num_binding = ctx->num_inputs() + ctx->num_outputs();
+ std::vector<void*> buffers(num_binding);
+ for (int i = 0; i < ctx->num_inputs(); i++) {
+ const string inp_name = StrCat(kInputPHName, i);
+ const size_t binding_index =
+ trt_engine_ptr->getBindingIndex(inp_name.c_str());
+
+ const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
- if (i == 0) {
- num_batch = input_shape.dim_size(0);
- if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
- LOG(FATAL) << "input tensor batch larger than max_batch_size: "
- << trt_engine_ptr_->getMaxBatchSize();
- }
- } else if (num_batch != input_shape.dim_size(0)) {
- LOG(FATAL) << "input data inconsistent batch size";
- break;
+ if (num_batch != input_shape.dim_size(0)) {
+ LOG(ERROR) << "input data inconsistent batch size";
+ ctx->SetStatus(tensorflow::errors::FailedPrecondition(
+ "Different batch sizes between input tensors"));
+ return;
}
- auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+ auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] = (void*)(input_tensor.flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
- LOG(FATAL) << "half size is not supported yet!";
- break;
+ LOG(ERROR) << "FP16 inputs are not supported yet!";
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "FP16 inputs are not supported!"));
+ return;
case nvinfer1::DataType::kINT8:
- LOG(FATAL) << "int8 is not supported yet!";
- break;
+ LOG(ERROR) << "INT8 inputs are not supported yet!";
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "INT8 inputs are not supported!"));
+ return;
default:
- LOG(FATAL) << "Unknown data type: " << int(dtype);
- break;
+ LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "Unknown output TRT data type! ", static_cast<int>(dtype)));
+ return;
}
}
- for (int i = 0; i < static_cast<int>(output_nodes_.size()); i++) {
- // This is bad that we have to reallocate output buffer every run.
+ for (int i = 0; i < ctx->num_outputs(); i++) {
// Create an output tensor
- binding_index = trt_engine_ptr_->getBindingIndex(output_nodes_[i].c_str());
+ const string output_name = StrCat(kOutputPHName, i);
+ const size_t binding_index =
+ trt_engine_ptr->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr;
TensorShape output_shape;
if (binding_index != -1) {
- auto dims = trt_engine_ptr_->getBindingDimensions(binding_index);
+ auto dims = trt_engine_ptr->getBindingDimensions(binding_index);
std::vector<int> trt_shape(dims.nbDims + 1);
trt_shape[0] = num_batch;
for (int j = 0; j < dims.nbDims; j++) trt_shape[j + 1] = dims.d[j];
- OP_REQUIRES_OK(context,
- TensorShapeUtils::MakeShape(
- trt_shape.data(), trt_shape.size(), &output_shape));
+ OP_REQUIRES_OK(
+ ctx, TensorShapeUtils::MakeShape(trt_shape.data(), trt_shape.size(),
+ &output_shape));
} else {
- LOG(FATAL) << "output node not found, at " << output_nodes_[i];
- break;
+ LOG(ERROR) << "output node not found, at " << output_name;
+ ctx->SetStatus(tensorflow::errors::Internal("output ", output_name,
+ " couldn't be found!"));
+ return;
}
-
- OP_REQUIRES_OK(context,
- context->allocate_output(i, output_shape, &output_tensor));
- auto dtype = trt_engine_ptr_->getBindingDataType(binding_index);
+ auto status = ctx->allocate_output(i, output_shape, &output_tensor);
+ if (!status.ok()) {
+ LOG(ERROR) << "Allocating output failed with " << status;
+ ctx->SetStatus(status);
+ return;
+ }
+ auto dtype = trt_engine_ptr->getBindingDataType(binding_index);
switch (dtype) {
case nvinfer1::DataType::kFLOAT:
buffers[binding_index] =
reinterpret_cast<void*>(output_tensor->flat<float>().data());
break;
case nvinfer1::DataType::kHALF:
- LOG(FATAL) << "half size is not supported yet!";
- break;
+ LOG(ERROR) << "half size is not supported yet!";
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "Half outputs are not supported!"));
+ return;
case nvinfer1::DataType::kINT8:
- LOG(FATAL) << "int8 is not supported yet!";
- break;
+ LOG(ERROR) << "int8 is not supported yet!";
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "INT8 outputs are not supported!"));
+ return;
default:
- LOG(FATAL) << "Unknown data type: " << int(dtype);
- break;
+ LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype);
+ ctx->SetStatus(tensorflow::errors::InvalidArgument(
+ "Unsupported output data type! ", static_cast<int>(dtype)));
+ return;
}
}
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
const cudaStream_t* stream = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(context->op_device_context()
+ reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
->CudaStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
- auto ret = trt_execution_context_ptr_->enqueue(num_batch, &buffers[0],
- *stream, nullptr);
- VLOG(2) << "enqueue returns: " << ret;
+ auto& trt_execution_context_ptr = engine_ctx_pair.second;
+ auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
+ nullptr);
+ if (!ret) {
+ LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
+ ctx->SetStatus(tensorflow::errors::Internal(
+ "Failed to enqueue batch for TRT engine: ", name()));
+ }
// sync should be done by TF.
}
+
TRTEngineOp::~TRTEngineOp() {
- // Order matters!
- trt_execution_context_ptr_.reset();
- trt_engine_ptr_.reset();
+ // We need to manually destroy the engine and execution context before
+ // the allocator is destructed.
+ for (auto& eng : engine_map_) {
+ eng.second.first.reset();
+ eng.second.second.reset();
+ }
allocator_.reset();
}
+
+nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
+ if (allocator_) return allocator_.get();
+ auto device = ctx->device();
+ auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes());
+ if (!alloc) {
+ LOG(ERROR) << "Can't find device allocator for gpu device "
+ << device->name();
+ ctx->SetStatus(tensorflow::errors::Internal(
+ "Can't get device allocator for device ", device->name()));
+ return nullptr;
+ }
+ allocator_.reset(new TRTDeviceAllocator(alloc));
+ return allocator_.get();
+}
+
+TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
+ OpKernelContext* ctx) {
+ static EngineCtxPair null_pair = {
+ TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
+ // TODO(sami): This method needs to be re-written to use resource manager and
+ // with LRU mechanism option.
+ tensorflow::mutex_lock lock(engine_mutex_);
+
+ if (static_engine_) {
+ if (engine_map_.size()) {
+ if (engine_map_.begin()->first >= batch_size) {
+ return engine_map_.begin()->second;
+ }
+ return null_pair;
+ }
+ TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
+#if NV_TENSORRT_MAJOR > 3
+ auto allocator = GetAllocator(ctx);
+ if (allocator == nullptr) {
+ // GetAllocator already set the Status.
+ return null_pair;
+ }
+ infer->setGpuAllocator(allocator);
+#endif
+ TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
+ infer->deserializeCudaEngine(serialized_segment_.c_str(),
+ serialized_segment_.size(), nullptr));
+ auto raw_static_engine = static_engine.get();
+ const auto max_batch_size = raw_static_engine->getMaxBatchSize();
+ engine_map_[max_batch_size] = {
+ std::move(static_engine),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(
+ raw_static_engine->createExecutionContext())};
+ // Runtime is safe to delete after engine creation
+ serialized_segment_.clear();
+ if (max_batch_size < batch_size) return null_pair;
+ return engine_map_.at(max_batch_size);
+ } // static_engine_
+
+ // Handle the dynamic engine case.
+ auto engine_it = engine_map_.find(batch_size);
+ if (engine_it == engine_map_.end() &&
+ engine_map_.size() < (size_t)max_cached_engines_) {
+ nvinfer1::IGpuAllocator* allocator = nullptr;
+#if NV_TENSORRT_MAJOR > 3
+ allocator = GetAllocator(ctx);
+ if (allocator == nullptr) {
+ // GetAllocator already set the Status.
+ return null_pair;
+ }
+#endif
+ std::vector<tensorflow::PartialTensorShape> shapes;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ shapes.emplace_back(ctx->input(i).shape());
+ }
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
+ bool convert_successfully = false;
+ VLOG(0) << name() << " Constructing a new engine with batch size "
+ << batch_size;
+ // Up to this point, calibrator_ can never be empty, since otherwise it
+ // means calibration_mode_ is true and this path won't get executed.
+ auto status = convert::ConvertGraphDefToEngine(
+ segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
+ &logger, allocator, calibrator_.get(), &engine, &convert_successfully);
+ if (!status.ok()) {
+ if (convert_successfully) {
+ // This means it fail to build the engine even when the network is built
+ // successfully, probably due to internal issues. In this case we don't
+ // retry in the future.
+ engine_map_[batch_size] = {nullptr, nullptr};
+ }
+ LOG(ERROR) << "Engine creation for batch size " << batch_size
+ << " failed " << status;
+ ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
+ return null_pair;
+ }
+ VLOG(1) << "Conversion is done";
+ TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
+ engine->createExecutionContext());
+ engine_map_[batch_size] = {std::move(engine), std::move(exec_context)};
+ }
+ return engine_map_.at(batch_size);
+}
+
+tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
+ tensorflow::OpKernelContext* ctx, TRTCalibrationResource** cr) {
+ auto cres = new TRTCalibrationResource();
+ *cr = cres;
+ // Get the allocator.
+ auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
+ if (!alloc) {
+ LOG(WARNING) << "Can't get device allocator will not be able to "
+ "allocate memory from TensorFlow memory pool";
+ cres->allocator_.reset(new TRTCudaAllocator);
+ } else {
+ cres->allocator_.reset(new TRTDeviceAllocator(alloc));
+ }
+ // Get the input shapes.
+ const int batch_size = ctx->input(0).dim_size(0);
+ const int num_inputs = ctx->num_inputs();
+ std::vector<tensorflow::PartialTensorShape> shapes;
+ dev_tensors_.resize(num_inputs);
+ VLOG(1) << " Constructing calibrator";
+ for (int i = 0; i < num_inputs; i++) {
+ // allocate workspace on device for inputs
+ const tensorflow::Tensor& t = ctx->input(i);
+ shapes.emplace_back(t.shape());
+ Tensor* device_tensor;
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ t.dtype(), t.shape(), &dev_tensors_.at(i), &device_tensor));
+ CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
+ void* device_address = GetTensorAddress(device_tensor);
+ if (device_address == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "Unsupported data type encountered in input ", i);
+ }
+ device_buffers_.emplace(
+ StrCat(kInputPHName, i),
+ std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
+ }
+ cres->calibrator_.reset(
+ new TRTInt8Calibrator(device_buffers_, batch_size, name()));
+ const string label(name());
+ auto segment_graph = &segment_graph_;
+ const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+ if (cuda_gpu_id < 0) {
+ LOG(ERROR) << "Can't get gpu_device_info from context->device()";
+ return tensorflow::errors::InvalidArgument(
+ "Context->device doesn't contain device info!");
+ }
+ const int64 workspace_size_bytes = workspace_size_;
+ cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
+ cuda_gpu_id, workspace_size_bytes]() {
+ VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
+ << ", Calibration Resource @ " << cres;
+ auto err = cudaSetDevice(cuda_gpu_id);
+ if (err != cudaSuccess) {
+ // TODO(aaroey): should return error here.
+ LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+ << " in calibration thread";
+ }
+ // ConvertGraphDefToEngine() will try to build the engine. This thread
+ // will loop inside buildCudaEngine() consuming the calibration data
+ // that is set by the TF op, and drive the builder until calibrator returns
+ // false. Engine is discarded after calibration table is generated
+ //
+ // TODO(aaroey): maybe setting the max batch size using the python
+ // calibration wrapper class.
+ auto s = convert::ConvertGraphDefToEngine(
+ *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
+ workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
+ cres->calibrator_.get(), &cres->engine_,
+ /*convert_successfully=*/nullptr);
+ if (!s.ok()) {
+ LOG(ERROR) << "Calibration failed: " << s;
+ cres->calibrator_->setDone(); // Ignore further pushes
+ }
+ VLOG(1) << "Calibration loop terminated " << label;
+ }));
+ VLOG(1) << "initialized calibrator resource";
+ return tensorflow::Status::OK();
+}
+
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index e613a71422..6fe318be6a 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -19,9 +19,14 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/mutex.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -30,32 +35,95 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-class Logger;
-
+class TRTInt8Calibrator;
+class TRTCalibrationResource;
+class AsyncHelper;
// TODO(Sami): Remove this file?
-class TRTEngineOp : public OpKernel {
+
+// This OP can construct TRTEngine on the fly and if construction of engine
+// fails, executes equivalent subgraph as a TensorFlow function.
+class TRTEngineOp : public AsyncOpKernel {
public:
explicit TRTEngineOp(OpKernelConstruction* context);
- void Compute(OpKernelContext* context) override;
+ void ComputeAsync(OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override;
~TRTEngineOp();
private:
- template <typename T>
- struct Destroyer {
- void operator()(T* d) { d->destroy(); }
- };
-
- template <typename T>
- using destroyed_ptr = std::unique_ptr<T, Destroyer<T>>;
- destroyed_ptr<nvinfer1::ICudaEngine> trt_engine_ptr_;
+ // Execute calibration
+ void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper);
+
+ // Construct a function handle for executing native funcdef graph
+ Status ConstructFunctionHandle(OpKernelContext* ctx);
+
+ // Execute replaced native segment as function Op.
+ void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper);
+
+ // Allocate necessary resources for calibration
+ Status AllocateCalibrationResources(OpKernelContext* ctx,
+ TRTCalibrationResource** cr);
+
// TODO(samikama): context should go to a resource manager!
- destroyed_ptr<nvinfer1::IExecutionContext> trt_execution_context_ptr_;
+ typedef std::pair<TrtUniquePtrType<nvinfer1::ICudaEngine>,
+ TrtUniquePtrType<nvinfer1::IExecutionContext>>
+ EngineCtxPair;
+ EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx);
+ // Return engine batch closest to input batch.
+ int GetEngineBatch(OpKernelContext* ctx);
+
+ nvinfer1::IGpuAllocator* GetAllocator(OpKernelContext* ctx);
+
+ // map to keep engines and their execution context for given batch size.
+ std::unordered_map<int, EngineCtxPair> engine_map_;
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
- string serialized_engine_;
+
+ // keep device allocator for TRT.
+ std::unique_ptr<TRTDeviceAllocator> allocator_;
+
+ // serialized protobuf segment or trt engine depending on static_engine_ flag.
+ string serialized_segment_;
+
+ // Name of the function for TF native execution of the segment.
+ string funcdef_name_;
+
+ // GraphDef representation of the segment.
+ GraphDef segment_graph_;
+
+ // Lookup table for temporary staging areas of input tensors for calibration.
+ std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
+
+ // Temporary staging areas for calibration inputs.
+ std::vector<PersistentTensor> dev_tensors_;
+
+ // Engine Precision mode.
+ int precision_mode_;
+
+ // Whether engine is constructed during the conversion or needs to be
+ // constructed from protobuf segment.
+ bool static_engine_;
+
+ // Whether to calibrate INT8 engine.
+ bool calibration_mode_;
+
+ // Whether non-batch ranks of the inputs are assumed to be fixed or not for
+ // engine construction.
+ bool fixed_input_size_;
+
+ // Batches of the cached engines
+ std::vector<int> cached_engine_batches_;
+
+ // Maximum number of cached engines
+ int max_cached_engines_;
+
+ int64 workspace_size_;
+ mutex engine_mutex_;
+ FunctionLibraryRuntime::Handle native_func_;
+
+ // The finalized calibrator for inference.
+ std::unique_ptr<TRTInt8Calibrator> calibrator_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc
deleted file mode 100644
index 4835e50650..0000000000
--- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc
+++ /dev/null
@@ -1,37 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-namespace tensorflow {
-
-REGISTER_OP("TRTCalibOp")
- .Attr("segment_nodes: list(string)") // names of the ops in segment
- .Attr("segment_output_names: list(string)") // names of the output ops in
- // segment
- .Attr("input_names: list(string)") // names of the inputs for
- // passing into tensorrt
- .Attr("resource_name: string")
- .Attr("InT: list({int8, float16, float32})")
- .Input("in_tensor: InT")
- .Output("out_tensor: InT")
- .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_inputs(); i++) {
- c->set_output(i, c->input(i));
- }
- return Status::OK();
- });
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
index 079d73f7be..383635f428 100644
--- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -28,11 +28,19 @@ extern Status TRTEngineOpShapeInference(InferenceContext* c);
}
REGISTER_OP("TRTEngineOp")
- .Attr("serialized_engine: string")
- .Attr("input_nodes: list(string)")
- .Attr("output_nodes: list(string)")
- .Attr("InT: list({float32})")
- .Attr("OutT: list({float32})")
+ .Attr("serialized_segment: string")
+ .Attr("input_shapes: list(shape)")
+ .Attr("output_shapes: list(shape)")
+ .Attr("segment_funcdef_name: string")
+ .Attr("InT: list({int8,float16,float32})")
+ .Attr("OutT: list({int8,float16,float32})")
+ .Attr("static_engine: bool = true")
+ .Attr("fixed_input_size: bool = true")
+ .Attr("cached_engine_batches: list(int) = []")
+ .Attr("max_cached_engines_count: int = 1")
+ .Attr("workspace_size_bytes: int")
+ .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
+ .Attr("calibration_data: string = ''")
.Input("in_tensor: InT")
.Output("out_tensor: OutT")
.SetShapeFn(shape_inference::TRTEngineOpShapeInference);
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 338475d90e..79f512dbcf 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -21,6 +21,8 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long
import six as _six
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
+from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
+from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import trt_convert
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
@@ -29,7 +31,9 @@ from tensorflow.python.framework import errors_impl as _impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.platform import tf_logging
from tensorflow.python.util import compat
+
# pylint: enable=unused-import,line-too-long
@@ -40,7 +44,10 @@ def create_inference_graph(input_graph_def,
max_batch_size=1,
max_workspace_size_bytes=2 << 20,
precision_mode="FP32",
- minimum_segment_size=3):
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=[]):
"""Python wrapper for the TRT transformation.
Args:
@@ -51,6 +58,10 @@ def create_inference_graph(input_graph_def,
precision_mode: one of 'FP32', 'FP16' and 'INT8'
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ cached_engine_batches: batch sizes used to pre-create cached engines.
Returns:
New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
@@ -65,6 +76,30 @@ def create_inference_graph(input_graph_def,
"It should be one of {}").format(
precision_mode, "{'FP32', 'FP16', 'INT8'}"))
mode = supported_precision_modes[precision_mode.upper()]
+ compiled_version = get_linked_tensorrt_version()
+ loaded_version = get_loaded_tensorrt_version()
+ version_mismatch = False
+ if loaded_version[0] < compiled_version[0]:
+ tf_logging.error(
+ "TensorRT version mismatch. Tensorflow was compiled against " +
+ "TensorRT %s but library loaded from environment is TensorRT %s" %
+ (".".join([str(x) for x in compiled_version]),
+ ".".join([str(x) for x in loaded_version])) +
+ ". Please make sure that correct version of TensorRT " +
+ "is available in the system and added to ldconfig or LD_LIBRARY_PATH"
+ )
+ raise RuntimeError("Incompatible TensorRT library version")
+ for i in zip(loaded_version, compiled_version):
+ if i[0] != i[1]:
+ tf_logging.warn("TensorRT mismatch. Compiled against version " +
+ "%s, but loaded %s. Things may not work" %
+ (".".join([str(x) for x in compiled_version]),
+ ".".join([str(x) for x in loaded_version])))
+ version_mismatch = True
+ break
+ if not version_mismatch:
+ tf_logging.info("Running against TensorRT version %s" % ".".join(
+ [str(x) for x in loaded_version]))
def py2bytes(inp):
return inp
@@ -100,7 +135,9 @@ def create_inference_graph(input_graph_def,
# pair or strings where first one is encoded status and the second
# one is the transformed graphs protobuf string.
out = trt_convert(input_graph_def_str, out_names, max_batch_size,
- max_workspace_size_bytes, mode, minimum_segment_size)
+ max_workspace_size_bytes, mode, minimum_segment_size,
+ is_dynamic_op, maximum_cached_engines,
+ cached_engine_batches)
status = to_string(out[0])
output_graph_def_string = out[1]
del input_graph_def_str # Save some memory
@@ -120,11 +157,12 @@ def create_inference_graph(input_graph_def,
return output_graph_def
-def calib_graph_to_infer_graph(calibration_graph_def):
+def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
"""Convert an existing calibration graph to inference graph.
Args:
calibration_graph_def: the calibration GraphDef object with calibration data
+ is_dynamic_op: whether to create dynamic static engines from calibration
Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises:
@@ -141,9 +179,16 @@ def calib_graph_to_infer_graph(calibration_graph_def):
to_string = py2string
else:
to_string = py3string
-
+ is_calib_graph = False
+ for n in calibration_graph_def.node:
+ if n.op == "TRTEngineOp":
+ is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
+ if not is_calib_graph:
+ tf_logging.error(
+ "Not a calib graph. Doesn't seem to contain any calibration nodes.")
+ return None
graph_str = calibration_graph_def.SerializeToString()
- out = calib_convert(graph_str)
+ out = calib_convert(graph_str, is_dynamic_op)
status = to_string(out[0])
output_graph_def_string = out[1]
del graph_str # Save some memory
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
index 0f0508331c..9f115990c3 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -50,7 +50,7 @@ TRTDeviceAllocator::TRTDeviceAllocator(tensorflow::Allocator* allocator)
}
void TRTDeviceAllocator::free(void* memory) {
- VLOG(2) << "Deallocating " << memory;
+ VLOG(2) << "Deallocating @ " << memory;
allocator_->DeallocateRaw(memory);
}
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
index a0c2540a76..c5d2cec730 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_ALLOCATOR_H_
-
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/core/framework/allocator.h"
@@ -52,7 +51,9 @@ class TRTDeviceAllocator : public nvinfer1::IGpuAllocator {
// Allocator implementation wrapping TF device allocators.
public:
TRTDeviceAllocator(tensorflow::Allocator* allocator);
- virtual ~TRTDeviceAllocator() {}
+ virtual ~TRTDeviceAllocator() {
+ VLOG(1) << "Destroying allocator attached to " << allocator_->Name();
+ }
void* allocate(uint64_t size, uint64_t alignment, uint32_t flags) override;
void free(void* memory) override;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index dc7c93f869..dab1dd9343 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include <atomic>
-#include <chrono>
#include <unordered_map>
#include "tensorflow/core/platform/logging.h"
@@ -37,20 +36,29 @@ TRTInt8Calibrator::TRTInt8Calibrator(
: batch_size_(batch_size),
done_(false),
dev_buffers_(dev_buffers),
- calib_running_(false),
+ // Make sure setBatch() waits until getBatch() is called (the first time).
+ calib_running_(true),
batch_is_set_(false),
engine_name_(engine_name) {}
+TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
+ : batch_size_(0),
+ done_(true),
+ calib_running_(false),
+ batch_is_set_(false),
+ calibration_table_(calib_data) {}
+
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- while ((calib_running_ || batch_is_set_) &&
- !done_) { // wait while calibration is running
- cond_.wait(lock);
- }
+
+ // Wait while the queue is full or calibration is running.
+ while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
if (done_) return false;
CHECK(!calib_running_ && !batch_is_set_);
VLOG(1) << "Set Batch Waiting finished";
+
+ // Sets the batch.
for (const auto it : data) {
auto devptr = dev_buffers_.find(it.first);
if (devptr == dev_buffers_.end()) {
@@ -59,8 +67,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}
const auto& d = devptr->second;
- // TODO(aaroey): we should not use sync copy on default stream. Make sure
- // stream->ThenMemcpy() is used in future PRs.
// TODO(sami,aaroey): Need to figure out a way to ensure synchronization
// between stream, perhaps using a tensor?
auto status = cudaMemcpyAsync(d.first, it.second, d.second,
@@ -72,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}
// TODO(Sami, aaorey): Find an alternative way!
- cudaStreamSynchronize(
- stream); // we have to wait for the stream before returning!
+ // we have to wait for the stream before returning!
+ cudaStreamSynchronize(stream);
batch_is_set_ = true;
cond_.notify_all();
return true;
@@ -82,23 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
tensorflow::mutex_lock lock(cond_mtx_);
+ // Notify finish of last round of calibration.
calib_running_ = false;
cond_.notify_all();
- while ((!batch_is_set_ && !done_)) { // wait until new batch arrives
- cond_.wait(lock);
- }
- if (done_) {
- return false;
- }
+ // Wait until new batch arrives
+ while ((!batch_is_set_ && !done_)) cond_.wait(lock);
+ if (done_) return false;
+ // Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
if (it == dev_buffers_.end()) {
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}
-
bindings[i] = it->second.first;
}
batch_is_set_ = false;
@@ -106,8 +110,21 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
return true;
}
+void TRTInt8Calibrator::waitAndSetDone() {
+ tensorflow::mutex_lock lock(cond_mtx_);
+ // Wait while the queue is full or calibration is running, so we don't miss
+ // the last batch.
+ while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
+ if (!done_) {
+ done_ = true;
+ cond_.notify_all();
+ }
+}
+
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
- return nullptr;
+ if (calibration_table_.empty()) return nullptr;
+ length = calibration_table_.size();
+ return calibration_table_.data();
}
void TRTInt8Calibrator::setDone() {
@@ -117,7 +134,11 @@ void TRTInt8Calibrator::setDone() {
}
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
- std::size_t length) {}
+ std::size_t length) {
+ calibration_table_ = string((const char*)ptr, length);
+ VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
+ << " length=" << length;
+}
TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(1) << "Destroying calibrator for " << engine_name_;
}
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index d77aa2c5ab..65466c9741 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -36,32 +36,59 @@ namespace tensorrt {
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
public:
+ // Construct a calibrator for future calibration.
TRTInt8Calibrator(
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
int batch_size, string engine_name);
+
+ // Construct a finalized calibrator where we don't need to run calibration any
+ // more, as the calibration data is provided.
+ TRTInt8Calibrator(const string& calibration_data);
+
+ ~TRTInt8Calibrator();
+
int getBatchSize() const override;
+
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
+
bool setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream);
+
+ // Wait until the last batch is consumed by the calibrator and set done.
+ void waitAndSetDone();
+
+ // Notify that calibration is done and future batches provided by setBatch()
+ // will be ignored.
void setDone();
+
+ // If not null, calibration is skipped.
const void* readCalibrationCache(std::size_t& length) override;
+
void writeCalibrationCache(const void* ptr, std::size_t length) override;
- ~TRTInt8Calibrator();
+
+ const string& getCalibrationTableAsString() { return calibration_table_; }
private:
const int batch_size_;
- tensorflow::mutex cond_mtx_; // mutex for condition_variable
- tensorflow::condition_variable cond_; // condition variable to implement
- // producer-consumer queue for
- // calibration
+
+ // mutex for condition_variable
+ tensorflow::mutex cond_mtx_;
+
+ // condition variable to implement producer-consumer queue for calibration
+ tensorflow::condition_variable cond_;
+
+ // Is calibration finished?
bool done_;
- const std::unordered_map<string, std::pair<void*, size_t>>
- dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with
- // buffer names
+
+ // Map to keep tensorrt input buffers and sizes keyed with buffer names
+ const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
+
bool calib_running_;
bool batch_is_set_;
+
string engine_name_;
+ string calibration_table_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index e3469124ac..b7d5ffd674 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <thread>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
@@ -34,50 +35,48 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
+
class TRTCalibrationResource : public tensorflow::ResourceBase {
public:
- TRTCalibrationResource()
- : calibrator_(nullptr),
- builder_(nullptr),
- network_(nullptr),
- engine_(nullptr),
- logger_(nullptr),
- thr_(nullptr) {}
-
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
+ builder_.reset();
+ engine_.reset();
+ // We need to manually destroy the builder and engine before the allocator
+ // is destroyed.
+ allocator_.reset();
}
string DebugString() override {
std::stringstream oss;
- oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl
- << " Builder = " << std::hex << builder_ << std::dec << std::endl
- << " Network = " << std::hex << network_ << std::dec << std::endl
- << " Engine = " << std::hex << engine_ << std::dec << std::endl
- << " Logger = " << std::hex << logger_ << std::dec << std::endl
- << " Allocator = " << std::hex << allocator_.get() << std::dec
- << std::endl
- << " Thread = " << std::hex << thr_ << std::dec << std::endl;
+ using std::dec;
+ using std::endl;
+ using std::hex;
+ oss << " Calibrator = " << hex << calibrator_.get() << dec << endl
+ << " Builder = " << hex << builder_.get() << dec << endl
+ << " Engine = " << hex << engine_.get() << dec << endl
+ << " Logger = " << hex << &logger_ << dec << endl
+ << " Allocator = " << hex << allocator_.get() << dec << endl
+ << " Thread = " << hex << thr_.get() << dec << endl;
return oss.str();
}
- TRTInt8Calibrator* calibrator_;
- nvinfer1::IBuilder* builder_;
- nvinfer1::INetworkDefinition* network_;
- nvinfer1::ICudaEngine* engine_;
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
- tensorflow::tensorrt::Logger* logger_;
+ std::unique_ptr<TRTInt8Calibrator> calibrator_;
+ TrtUniquePtrType<nvinfer1::IBuilder> builder_;
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
+ std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
+ tensorflow::tensorrt::Logger logger_;
// TODO(sami): Use threadpool threads!
- std::thread* thr_;
+ std::unique_ptr<std::thread> thr_;
};
-class TRTWeightStore : public tensorflow::ResourceBase {
+class TRTWeightStore {
public:
TRTWeightStore() {}
virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); }
- string DebugString() override {
+ string DebugString() {
std::stringstream oss;
size_t len_bytes = 0;
for (const auto& v : store_) {
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 1568dd9153..81b4bfe49f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,8 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-// vector of segments, each entry contains a device name and a set of nodes in
-// segment
+// Vector of segments, each entry contains a set of node names and a device name
+// in the segment.
+// TODO(aaroey): use node pointer instead of node name.
using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
@@ -48,6 +49,8 @@ struct SegmentOptions {
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
+//
+// TODO(aaroey): remove this method.
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index f36495f6b6..227ac120dd 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -29,61 +29,35 @@ namespace tensorflow {
namespace shape_inference {
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
- tensorflow::tensorrt::Logger logger;
- string serialized_engine;
- TF_RETURN_IF_ERROR(context->GetAttr("serialized_engine", &serialized_engine));
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
- nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(),
- tensorrt::PluginFactoryTensorRT::GetInstance());
-
- int num_batch = -1;
- std::vector<::tensorflow::DataType> input_type;
- TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
- for (size_t i = 0; i < context->num_inputs(); i++) {
- // Check if input shape is legit
- auto input_shape = context->input(i);
- for (int j = 0; j < context->Rank(input_shape); j++) {
- auto dim_handler = context->Dim(input_shape, j);
- if (j == 0) {
- if (i == 0) {
- num_batch = context->Value(dim_handler);
- } else if (num_batch != context->Value(dim_handler)) {
- // TODO(jie): TensorRT engine requires consistent batch between inputs
- // tensors. Segmenter should be aware of this.
- LOG(FATAL) << "TensorRT engine requires consistent batch size";
- }
- }
- }
+ std::vector<tensorflow::TensorShape> shapes;
+ for (int i = 0; i < context->num_outputs(); ++i) {
+ context->set_output(i, context->UnknownShape());
}
-
- // Arrange input here
- std::vector<string> input_nodes;
- TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes));
-
- // Arrange output here
- std::vector<string> output_nodes;
- TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes));
- for (size_t i = 0; i < output_nodes.size(); i++) {
- int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
- ShapeHandle output_shape;
- std::vector<DimensionHandle> dim_vec;
- dim_vec.emplace_back(context->MakeDim(num_batch));
- if (binding_index != -1) {
- auto dims = trt_engine->getBindingDimensions(binding_index);
- for (int j = 0; j < dims.nbDims; j++) {
- dim_vec.emplace_back(context->MakeDim(dims.d[j]));
- }
- } else {
- LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
- }
- output_shape = context->MakeShape(dim_vec);
- context->set_output(i, output_shape);
+ auto status = context->GetAttr("input_shapes", &shapes);
+ // it is ok to not to have shapes
+ if (!status.ok()) return Status::OK();
+ if ((int)shapes.size() != context->num_inputs()) return Status::OK();
+ bool different_input = false;
+ for (int i = 0; i < context->num_inputs(); ++i) {
+ if (shapes.at(i) != context->input_tensor(i)->shape())
+ different_input = true;
+ }
+ if (different_input) return Status::OK();
+ shapes.resize(0);
+ status = context->GetAttr("output_shapes", &shapes);
+ if (!status.ok()) return Status::OK();
+ if ((int)shapes.size() != context->num_outputs()) return Status::OK();
+ std::vector<ShapeHandle> shape_handles(shapes.size());
+ for (size_t i = 0; i < shapes.size(); ++i) {
+ status =
+ context->MakeShapeFromTensorShape(shapes.at(i), &shape_handles.at(i));
+ if (!status.ok()) return Status::OK();
+ }
+ for (int i = 0; i < context->num_outputs(); ++i) {
+ context->set_output(i, shape_handles.at(i));
}
-
return Status::OK();
}
-
} // namespace shape_inference
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 175ccd8006..090aa8bdb0 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import argparse
import numpy as np
+import six as _six
# normally we should do import tensorflow as tf and then
# tf.placeholder, tf.constant, tf.nn.conv2d etc but
@@ -35,10 +36,75 @@ from tensorflow.python.framework import dtypes as dtypes
from tensorflow.python.framework import importer as importer
from tensorflow.python.framework import ops as ops
from tensorflow.python.ops import array_ops as aops
+from tensorflow.python.ops import math_ops as mops
from tensorflow.python.ops import nn as nn
from tensorflow.python.ops import nn_ops as nn_ops
+def py2bytes(inp):
+ return inp
+
+
+def py3bytes(inp):
+ return inp.encode("utf-8", errors="surrogateescape")
+
+
+def py2string(inp):
+ return inp
+
+
+def py3string(inp):
+ return inp.decode("utf-8")
+
+
+if _six.PY2:
+ to_bytes = py2bytes
+ to_string = py2string
+else:
+ to_bytes = py3bytes
+ to_string = py3string
+
+
+def get_multi_engine_graph_def(mode="FP32"):
+ """Create a simple graph and return its graph_def."""
+ dtype = dtypes.float32
+ if mode.upper() == "FP16":
+ dtype = dtypes.float16
+ else:
+ pass
+
+ g = ops.Graph()
+ with g.as_default():
+ x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype)
+ with g.name_scope("Global_scope"):
+ with g.name_scope("first_scope"):
+ e = cop.constant(
+ np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype)
+ conv = nn.conv2d(
+ input=x,
+ filter=e,
+ data_format="NCHW",
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv")
+ b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype)
+ t = conv * b
+
+ b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype)
+ q = conv / b
+ edge = mops.sin(q)
+ edge1 = mops.cos(conv)
+ with g.name_scope("test_scope"):
+ de = edge + edge1
+ t -= edge1
+ q *= edge
+ t += q
+ t -= de
+ k = aops.squeeze(t, name="output")
+ print(k.dtype)
+ return g.as_graph_def()
+
+
def get_simple_graph_def():
"""Create a simple graph and return its graph_def."""
g = ops.Graph()
@@ -65,7 +131,9 @@ def get_simple_graph_def():
def execute_graph(gdef, dumm_inp):
"""Run given graphdef once."""
print("executing")
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ gpu_options = None
+ if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
ops.reset_default_graph()
g = ops.Graph()
@@ -83,7 +151,9 @@ def execute_graph(gdef, dumm_inp):
# for calibration. For this test script it is random data.
def execute_calibration(gdef, dumm_inp):
"""Run given calibration graph multiple times."""
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ gpu_options = None
+ if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
ops.reset_default_graph()
g = ops.Graph()
with g.as_default():
@@ -100,12 +170,17 @@ def execute_calibration(gdef, dumm_inp):
return val
-def user(run_graph=execute_graph, run_calibration=execute_calibration):
+def user(multi_engine,
+ run_graph=execute_graph,
+ run_calibration=execute_calibration):
"""Example function that converts a graph to TFTRT graph."""
-
- inp_dims = (100, 24, 24, 2)
+ if multi_engine:
+ inp_dims = (2, 3, 7, 5)
+ orig_graph = get_multi_engine_graph_def()
+ else:
+ inp_dims = (100, 24, 24, 2)
+ orig_graph = get_simple_graph_def() # use a frozen graph for inference
dummy_input = np.random.random_sample(inp_dims)
- orig_graph = get_simple_graph_def() # use a frozen graph for inference
# Get optimized graph
trt_graph = trt.create_inference_graph(
input_graph_def=orig_graph,
@@ -113,8 +188,10 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25,
precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2 # minimum number of nodes in an engine
- )
+ minimum_segment_size=2, # minimum number of nodes in an engine
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=[])
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input)
@@ -126,40 +203,51 @@ def user(run_graph=execute_graph, run_calibration=execute_calibration):
max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25,
precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2 # minimum number of nodes in an engine
- )
+ minimum_segment_size=2, # minimum number of nodes in an engine
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=[])
int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph,
outputs=["output"],
max_batch_size=inp_dims[0],
max_workspace_size_bytes=1 << 25,
precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2 # minimum number of nodes in an engine
- )
+ minimum_segment_size=2, # minimum number of nodes in an engine
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=[])
o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
o5 = run_graph(int8_graph, dummy_input)
- assert np.allclose(o1, o4)
- assert np.allclose(o1, o5)
+ print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4))
+ print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5))
print("Pass")
-def auto():
+def auto(multi_engine):
"""Run the conversion as an optimization pass."""
- inp_dims = (100, 24, 24, 2)
+ if multi_engine:
+ inp_dims = (2, 3, 7, 5)
+ orig_graph = get_multi_engine_graph_def()
+ else:
+ inp_dims = (100, 24, 24, 2)
+ orig_graph = get_simple_graph_def() # use a frozen graph for inference
dummy_input = np.random.random_sample(inp_dims)
- orig_graph = get_simple_graph_def()
opt_config = rwpb2.RewriterConfig()
+ opt_config.meta_optimizer_iterations = opt_config.ONE
opt_config.optimizers.extend(["constfold", "layout"])
custom_op = opt_config.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
custom_op.parameter_map["minimum_segment_size"].i = 3
- custom_op.parameter_map["precision_mode"].s = "FP32"
+ custom_op.parameter_map["precision_mode"].s = to_bytes("FP32")
custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
print(custom_op)
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ gpu_options = None
+ if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
+ gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
graph_options = cpb2.GraphOptions(rewrite_options=opt_config)
sessconfig = cpb2.ConfigProto(
gpu_options=gpu_options, graph_options=graph_options)
@@ -168,7 +256,7 @@ def auto():
ops.reset_default_graph()
with g.as_default():
inp, out = importer.import_graph_def(
- graph_def=orig_graph, return_elements=["input", "output"])
+ graph_def=orig_graph, return_elements=["input", "output"], name="")
inp = inp.outputs[0]
out = out.outputs[0]
with csess.Session(config=sessconfig, graph=g) as sess:
@@ -186,8 +274,14 @@ if "__main__" in __name__:
action="store_true",
help="Do TRT conversion automatically",
default=False)
+ P.add_argument(
+ "--multi-engine",
+ "-m",
+ action="store_true",
+ help="Use a graph that will result in 2 engines",
+ default=False)
flags, unparsed = P.parse_known_args()
if flags.automatic:
- auto()
+ auto(flags.multi_engine)
else:
- user()
+ user(flags.multi_engine)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
index 0403b652d7..d9c41f90d0 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test.py
@@ -18,131 +18,330 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from collections import namedtuple
+import itertools
import warnings
import numpy as np
+import six
from tensorflow.contrib import tensorrt as trt
-from tensorflow.core.protobuf import config_pb2 as cpb2
-from tensorflow.python.framework import constant_op as cop
-from tensorflow.python.framework import dtypes as dtypes
-from tensorflow.python.framework import importer as importer
-from tensorflow.python.framework import ops as ops
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops as aops
-from tensorflow.python.ops import nn as nn
-from tensorflow.python.ops import nn_ops as nn_ops
-from tensorflow.python.platform import googletest
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+INPUT_NAME = "input"
+OUTPUT_NAME = "output"
+INPUT_DIMS = [100, 24, 24, 2]
+MODE_FP32 = "FP32"
+MODE_FP16 = "FP16"
+MODE_INT8 = "INT8"
-class IntegrationTest(test_util.TensorFlowTestCase):
+if six.PY2:
+ to_bytes = lambda s: s
+ to_string = lambda s: s
+else:
+ to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+ to_string = lambda s: s.decode("utf-8")
+
+
+# TODO(aaroey): test graph with different dtypes.
+def GetSingleEngineGraphDef(dtype=dtypes.float32):
+ """Create a graph containing single segment."""
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
+ with g.device("/GPU:0"):
+ conv_filter = constant_op.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=conv_filter,
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ name="conv")
+ bias = constant_op.constant(
+ [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
+ added = nn.bias_add(conv, bias, name="bias_add")
+ relu = nn.relu(added, "relu")
+ identity = array_ops.identity(relu, "identity")
+ pool = nn_ops.max_pool(
+ identity, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+ array_ops.squeeze(pool, name=OUTPUT_NAME)
+ return g.as_graph_def()
+
+
+# TODO(aaroey): test graph with different dtypes.
+def GetMultiEngineGraphDef(dtype=dtypes.float32):
+ """Create a graph containing multiple segment."""
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
+ with g.device("/GPU:0"):
+ conv_filter = constant_op.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=dtype)
+ conv = nn.conv2d(
+ input=inp,
+ filter=conv_filter,
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ name="conv")
+ c1 = constant_op.constant(
+ np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
+ p = conv * c1
+ c2 = constant_op.constant(
+ np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
+ q = conv / c2
+
+ edge = math_ops.sin(q)
+ edge /= edge
+ r = edge + edge
+
+ p -= edge
+ q *= edge
+ s = p + q
+ s -= r
+ array_ops.squeeze(s, name=OUTPUT_NAME)
+ return g.as_graph_def()
+
+
+TestGraph = namedtuple("TestGraph",
+ ["gdef", "num_expected_engines", "expected_output_dims"])
+
+TEST_GRAPHS = {
+ "SingleEngineGraph":
+ TestGraph(
+ gdef=GetSingleEngineGraphDef(),
+ num_expected_engines=1,
+ expected_output_dims=(100, 6, 6, 6)),
+ "MultiEngineGraph":
+ TestGraph(
+ gdef=GetMultiEngineGraphDef(),
+ num_expected_engines=2,
+ expected_output_dims=(100, 12, 12, 6)),
+ # TODO(aaroey): add a large complex graph to test.
+}
+
+
+class TfTrtIntegrationTest(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""
def setUp(self):
"""Setup method."""
- super(IntegrationTest, self).setUp()
+ super(TfTrtIntegrationTest, self).setUp()
warnings.simplefilter("always")
- inp_dims = (100, 24, 24, 2)
- self._input = np.random.random_sample(inp_dims)
- self._original_graph = self.get_simple_graph_def()
- self._gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
- self._config = cpb2.ConfigProto(gpu_options=self._gpu_options)
- self._reference = self.run_graph(self._original_graph, self._input)
-
- def get_simple_graph_def(self):
- """Create a simple graph and return its graph_def."""
- g = ops.Graph()
- with g.as_default():
- a = aops.placeholder(
- dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
- e = cop.constant(
- [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
- name="weights",
- dtype=dtypes.float32)
- conv = nn.conv2d(
- input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
- b = cop.constant(
- [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
- t = nn.bias_add(conv, b, name="biasAdd")
- relu = nn.relu(t, "relu")
- idty = aops.identity(relu, "ID")
- v = nn_ops.max_pool(
- idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- aops.squeeze(v, name="output")
- return g.as_graph_def()
-
- def run_graph(self, gdef, dumm_inp):
- """Run given graphdef once."""
- ops.reset_default_graph()
+ self._input = np.random.random_sample(INPUT_DIMS)
+
+ def _GetConfigProto(self,
+ use_optimizer,
+ precision_mode=None,
+ is_dynamic_op=None):
+ if use_optimizer:
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ custom_op = rewriter_cfg.custom_optimizers.add()
+ custom_op.name = "TensorRTOptimizer"
+ custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["max_batch_size"].i = self._input.shape[0]
+ custom_op.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
+ custom_op.parameter_map["precision_mode"].s = to_bytes(precision_mode)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
+ else:
+ graph_options = config_pb2.GraphOptions()
+
+ gpu_options = config_pb2.GPUOptions()
+ if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
+ gpu_options.per_process_gpu_memory_fraction = 0.50
+
+ config = config_pb2.ConfigProto(
+ gpu_options=gpu_options, graph_options=graph_options)
+ return config
+
+ def _RunGraph(self, graph_key, gdef, input_data, config, num_runs=2):
+ """Run given graphdef multiple times."""
g = ops.Graph()
with g.as_default():
inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=["input", "output"])
+ graph_def=gdef, return_elements=[INPUT_NAME, OUTPUT_NAME], name="")
inp = inp.outputs[0]
out = out.outputs[0]
with self.test_session(
- graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess:
- val = sess.run(out, {inp: dumm_inp})
+ graph=g, config=config, use_gpu=True, force_gpu=True) as sess:
+ val = None
+ # Defaults to 2 runs to verify result across multiple runs is same.
+ for _ in range(num_runs):
+ new_val = sess.run(out, {inp: input_data})
+ self.assertEquals(TEST_GRAPHS[graph_key].expected_output_dims,
+ new_val.shape)
+ if val is not None:
+ self.assertAllEqual(new_val, val)
+ val = new_val
return val
# Use real data that is representative of the inference dataset
# for calibration. For this test script it is random data.
- def run_calibration(self, gdef, dumm_inp):
- """Run given calibration graph multiple times."""
- ops.reset_default_graph()
- g = ops.Graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=["input", "output"])
- inp = inp.outputs[0]
- out = out.outputs[0]
- # run over real calibration data here, we are mimicking a calibration
- # set of 30 different batches. Use as much calibration data as you want
- with self.test_session(
- graph=g, config=self._config, use_gpu=True, force_gpu=True) as sess:
- for _ in range(30):
- val = sess.run(out, {inp: dumm_inp})
- return val
+ def _RunCalibration(self, graph_key, gdef, input_data, config):
+ """Run calibration on given graph."""
+ return self._RunGraph(graph_key, gdef, input_data, config, 30)
- def get_trt_graph(self, mode):
+ def _GetTrtGraph(self, gdef, precision_mode, is_dynamic_op):
"""Return trt converted graph."""
- if mode in ["FP32", "FP16", "INT8"]:
- return trt.create_inference_graph(
- input_graph_def=self._original_graph,
- outputs=["output"],
- max_batch_size=self._input.shape[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode=mode, # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2 # minimum number of nodes in an engine
- )
- return None
-
- def testFP32(self):
- """Test FP32 conversion. Results should be identical to native case."""
- trt_graph = self.get_trt_graph("FP32")
- result = self.run_graph(trt_graph, self._input)
- self.assertAllEqual(self._reference, result)
- result1 = self.run_graph(trt_graph, self._input)
- self.assertAllEqual(result1, result)
-
- def testFP16(self):
- """Test FP16 conversion. Results may be different from native case."""
- trt_graph = self.get_trt_graph("FP16")
- result = self.run_graph(trt_graph, self._input)
- self.assertAllClose(self._reference, result, rtol=1.e-03)
- result1 = self.run_graph(trt_graph, self._input)
- self.assertAllEqual(result1, result)
-
- def testINT8(self):
- """Test INT8 conversion. Results may be different from native case."""
- calib_graph = self.get_trt_graph("INT8")
- result = self.run_calibration(calib_graph, self._input)
- self.assertAllEqual(self._reference, result)
- int8_graph = trt.calib_graph_to_infer_graph(calib_graph)
- result = self.run_graph(int8_graph, self._input)
- self.assertAllClose(self._reference, result, rtol=1.e-03)
- result1 = self.run_graph(int8_graph, self._input)
- self.assertAllEqual(result1, result)
+ return trt.create_inference_graph(
+ input_graph_def=gdef,
+ outputs=[OUTPUT_NAME],
+ max_batch_size=self._input.shape[0],
+ max_workspace_size_bytes=1 << 25,
+ precision_mode=precision_mode,
+ minimum_segment_size=2,
+ is_dynamic_op=is_dynamic_op)
+
+ def _VerifyGraphDef(self,
+ graph_key,
+ gdef,
+ precision_mode=None,
+ is_calibrated=None,
+ dynamic_engine=None):
+ num_engines = 0
+ for n in gdef.node:
+ if n.op == "TRTEngineOp":
+ num_engines += 1
+ self.assertNotEqual("", n.attr["serialized_segment"].s)
+ self.assertNotEqual("", n.attr["segment_funcdef_name"].s)
+ self.assertEquals(n.attr["precision_mode"].s, precision_mode)
+ self.assertEquals(n.attr["static_engine"].b, not dynamic_engine)
+ if precision_mode == MODE_INT8 and is_calibrated:
+ self.assertNotEqual("", n.attr["calibration_data"].s)
+ else:
+ self.assertEquals("", n.attr["calibration_data"].s)
+ if precision_mode is None:
+ self.assertEquals(num_engines, 0)
+ else:
+ self.assertEquals(num_engines,
+ TEST_GRAPHS[graph_key].num_expected_engines)
+
+ def _RunTest(self, graph_key, use_optimizer, precision_mode,
+ dynamic_infer_engine, dynamic_calib_engine):
+ assert precision_mode in [MODE_FP32, MODE_FP16, MODE_INT8]
+ input_gdef = TEST_GRAPHS[graph_key].gdef
+ self._VerifyGraphDef(graph_key, input_gdef)
+
+ # Get reference result without running trt.
+ config_no_trt = self._GetConfigProto(False)
+ print("Running original graph w/o trt, config:\n%s" % str(config_no_trt))
+ ref_result = self._RunGraph(graph_key, input_gdef, self._input,
+ config_no_trt)
+
+ # Run calibration if necessary.
+ if precision_mode == MODE_INT8:
+
+ calib_config = self._GetConfigProto(use_optimizer, precision_mode,
+ dynamic_calib_engine)
+ print("Running calibration graph, config:\n%s" % str(calib_config))
+ if use_optimizer:
+ self.assertTrue(False)
+ # TODO(aaroey): uncomment this and get infer_gdef when this mode is
+ # supported.
+ # result = self._RunCalibration(graph_key, input_gdef, self._input,
+ # calib_config)
+ else:
+ calib_gdef = self._GetTrtGraph(input_gdef, precision_mode,
+ dynamic_calib_engine)
+ self._VerifyGraphDef(graph_key, calib_gdef, precision_mode, False,
+ dynamic_calib_engine)
+ result = self._RunCalibration(graph_key, calib_gdef, self._input,
+ calib_config)
+ infer_gdef = trt.calib_graph_to_infer_graph(calib_gdef)
+ self._VerifyGraphDef(graph_key, infer_gdef, precision_mode, True,
+ dynamic_calib_engine)
+ self.assertAllClose(ref_result, result, rtol=1.e-03)
+ else:
+ infer_gdef = input_gdef
+
+ # Run inference.
+ infer_config = self._GetConfigProto(use_optimizer, precision_mode,
+ dynamic_infer_engine)
+ print("Running final inference graph, config:\n%s" % str(infer_config))
+ if use_optimizer:
+ result = self._RunGraph(graph_key, infer_gdef, self._input, infer_config)
+ else:
+ trt_infer_gdef = self._GetTrtGraph(infer_gdef, precision_mode,
+ dynamic_infer_engine)
+ self._VerifyGraphDef(graph_key, trt_infer_gdef, precision_mode, True,
+ dynamic_infer_engine)
+ result = self._RunGraph(graph_key, trt_infer_gdef, self._input,
+ infer_config)
+ self.assertAllClose(ref_result, result, rtol=1.e-03)
+
+ def testIdempotence(self):
+ # Test that applying tensorrt optimizer or offline conversion tools multiple
+ # times to the same graph will result in same graph.
+ # TODO(aaroey): implement this.
+ pass
+
+
+def GetTests():
+
+ def _GetTest(g, u, p, i, c):
+
+ def _Test(self):
+ print("Running test with parameters: graph_key=%s, use_optimizer=%s, "
+ "precision_mode=%s, dynamic_infer_engine=%s, "
+ "dynamic_calib_engine=%s" % (g, u, p, i, c))
+ self._RunTest(g, u, p, i, c)
+
+ return _Test
+
+ use_optimizer_options = [False, True]
+ precision_mode_options = [MODE_FP32, MODE_FP16, MODE_INT8]
+ dynamic_infer_engine_options = [False, True]
+ dynamic_calib_engine_options = [False, True]
+ for (graph_key, use_optimizer, precision_mode,
+ dynamic_infer_engine, dynamic_calib_engine) in itertools.product(
+ TEST_GRAPHS, use_optimizer_options, precision_mode_options,
+ dynamic_infer_engine_options, dynamic_calib_engine_options):
+ if precision_mode == MODE_INT8:
+ if not dynamic_calib_engine and dynamic_infer_engine:
+ # TODO(aaroey): test this case, the conversion from static calibration
+ # engine to dynamic inference engine should be a noop.
+ continue
+ if use_optimizer:
+ # TODO(aaroey): if use_optimizer is True we need to get the inference
+ # graphdef using custom python wrapper class, which is not currently
+ # supported yet.
+ continue
+ if not dynamic_calib_engine:
+ # TODO(aaroey): construction of static calibration engine is not
+ # supported yet.
+ continue
+ if dynamic_calib_engine and not dynamic_infer_engine:
+ # TODO(aaroey): construction of static inference engine using dynamic
+ # calibration engine is not supported yet.
+ continue
+ else: # In non int8 mode.
+ if dynamic_calib_engine:
+ # dynamic_calib_engine doesn't affect non-int8 modes, so just let
+ # related tests run once on dynamic_calib_engine=False.
+ continue
+ yield _GetTest(graph_key, use_optimizer, precision_mode,
+ dynamic_infer_engine, dynamic_calib_engine)
if __name__ == "__main__":
- googletest.main()
+ for index, t in enumerate(GetTests()):
+ setattr(TfTrtIntegrationTest, "testTfTRT_" + str(index), t)
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 46480e99a1..d6628cd1eb 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -48,12 +48,53 @@ PyObject* pair_helper(std::pair<string, string>* in) {
}
return tuple;
}
+
+struct version_struct{
+ int vmajor;
+ int vminor;
+ int vpatch;
+};
+
+PyObject* version_helper(version_struct* in) {
+ PyObject *tuple(nullptr);
+ tuple = Py_BuildValue("(iii)", in->vmajor, in->vminor, in->vpatch);
+ if (!tuple) {
+ if (!PyErr_Occurred()) {
+ PyErr_SetString(PyExc_TypeError,
+ "Tuple creation from version structure failed!");
+ }
+ return NULL;
+ }
+ return tuple;
+}
+/* Define converters for vector<int> */
+template<>
+bool _PyObjAs(PyObject *pyobj, int* dest) {
+ *dest = PyLong_AsLong(pyobj);
+ return true;
+}
+
+template<>
+PyObject *_PyObjFrom(const int& src) {
+ return PyLong_FromLong(src);
+}
+
%}
+
+_LIST_OUTPUT_TYPEMAP(int, PyLong_FromLong);
+
%typemap(out) std::pair<string, string> {
PyObject *tuple = pair_helper(&$1);
if (!tuple) SWIG_fail;
$result = tuple;
}
+
+%typemap(out) version_struct {
+ PyObject *tuple = version_helper(&$1);
+ if (!tuple) SWIG_fail;
+ $result = tuple;
+}
+
%{
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -65,6 +106,8 @@ PyObject* pair_helper(std::pair<string, string>* in) {
%unignore tensorflow;
%unignore trt_convert;
%unignore calib_convert;
+%unignore get_linked_tensorrt_version;
+%unignore get_loaded_tensorrt_version;
%{
@@ -74,7 +117,10 @@ std::pair<string, string> trt_convert(
size_t max_batch_size,
size_t max_workspace_size_bytes,
int precision_mode,
- int minimum_segment_size
+ int minimum_segment_size,
+ bool is_dyn_op,
+ int max_cached_engines,
+ std::vector<int> cached_engine_batches
// Unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included
@@ -102,11 +148,12 @@ std::pair<string, string> trt_convert(
out_status = "InvalidArgument;Size of the output_names vector is 0";
return std::pair<string, string>{out_status, ""};
}
- tensorflow::GraphDef outGraph;
+ tensorflow::GraphDef out_graph;
tensorflow::Status conversion_status =
tensorflow::tensorrt::convert::ConvertGraphDefToTensorRT(
graph_def, output_names, max_batch_size, max_workspace_size_bytes,
- &outGraph, precision_mode, minimum_segment_size);
+ &out_graph, precision_mode, minimum_segment_size,
+ is_dyn_op, max_cached_engines, cached_engine_batches);
if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code();
char buff[2000];
@@ -116,7 +163,7 @@ std::pair<string, string> trt_convert(
return std::pair<string, string>{out_status, ""};
}
string result;
- if (!outGraph.SerializeToString(&result)) {
+ if (!out_graph.SerializeToString(&result)) {
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
return std::pair<string, string>{out_status, ""};
}
@@ -128,7 +175,8 @@ std::pair<string, string> trt_convert(
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
}
-std::pair<string, string> calib_convert(string graph_def_string // const tensorflow::GraphDef&
+std::pair<string, string> calib_convert(
+ string graph_def_string, bool is_dyn_op
// unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included
@@ -147,11 +195,11 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
out_status = "InvalidArgument;Couldn't interpret input as a GraphDef";
return std::pair<string, string>{out_status, ""};
}
-
- tensorflow::GraphDef outGraph;
+ graph_def_string.resize(0);
+ tensorflow::GraphDef out_graph;
tensorflow::Status conversion_status =
- tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(graph_def,
- &outGraph);
+ tensorflow::tensorrt::convert::ConvertCalibGraphToInferGraph(
+ graph_def, &out_graph, is_dyn_op);
if (!conversion_status.ok()) {
auto retCode = (int)conversion_status.code();
char buff[2000];
@@ -161,7 +209,7 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
return std::pair<string, string>{out_status, ""};
}
string result;
- if (!outGraph.SerializeToString(&result)) {
+ if (!out_graph.SerializeToString(&result)) {
out_status = "InvalidArgument;Couldn't serialize output as a GraphDef";
return std::pair<string, string>{out_status, ""};
}
@@ -172,15 +220,43 @@ std::pair<string, string> calib_convert(string graph_def_string // const tenso
return std::pair<string, string>{"9;TensorRT is not enabled!", ""};
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
}
+
+version_struct get_linked_tensorrt_version() {
+ // Return the version at the link time.
+ version_struct s;
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ const auto &lv = tensorflow::tensorrt::convert::GetLinkedTensorRTVersion();
+ s.vmajor = lv[0];
+ s.vminor = lv[1];
+ s.vpatch = lv[2];
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+ return s;
+}
+version_struct get_loaded_tensorrt_version(){
+ // Return the version from the loaded library.
+ version_struct s;
+#if GOOGLE_CUDA && GOOGLE_TENSORRT
+ const auto &lv = tensorflow::tensorrt::convert::GetLoadedTensorRTVersion();
+ s.vmajor = lv[0];
+ s.vminor = lv[1];
+ s.vpatch = lv[2];
+#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
+ return s;
+}
+
%}
-std::pair<string, string> calib_convert(string graph_def_string);
+std::pair<string, string> calib_convert(string graph_def_string, bool is_dyn_op);
std::pair<string, string> trt_convert(string graph_def_string,
std::vector<string> output_names,
size_t max_batch_size,
size_t max_workspace_size_bytes,
- int precision_mode, int minimum_segment_size);
-
+ int precision_mode, int minimum_segment_size,
+ bool is_dyn_op,
+ int max_cached_engines,
+ std::vector<int> cached_engine_batches);
+version_struct get_linked_tensorrt_version();
+version_struct get_loaded_tensorrt_version();
%unignoreall
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index e4963596d3..7020989d68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -157,6 +157,7 @@ py_library(
py_test(
name = "head_test",
+ size = "large",
srcs = [
"head_test.py",
],
@@ -184,6 +185,7 @@ py_test(
"//tensorflow/python/saved_model:loader",
"//tensorflow/python/saved_model:tag_constants",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 4ec8d26116..769183f40a 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -288,7 +288,7 @@ class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
def __init__(self, model, state_manager=None, optimizer=None, model_dir=None,
- config=None):
+ config=None, head_type=ts_head_lib.TimeSeriesRegressionHead):
"""See TimeSeriesRegressor. Uses the ChainingStateManager by default."""
if not isinstance(model, state_space_model.StateSpaceModel):
raise ValueError(
@@ -301,7 +301,8 @@ class StateSpaceRegressor(TimeSeriesRegressor):
state_manager=state_manager,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
class StructuralEnsembleRegressor(StateSpaceRegressor):
@@ -344,7 +345,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
anomaly_prior_probability=None,
optimizer=None,
model_dir=None,
- config=None):
+ config=None,
+ head_type=ts_head_lib.TimeSeriesRegressionHead):
"""Initialize the Estimator.
Args:
@@ -401,6 +403,8 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
from tf.train.Optimizer. Defaults to Adam with step size 0.02.
model_dir: See `Estimator`.
config: See `Estimator`.
+ head_type: The kind of head to use for the model (inheriting from
+ `TimeSeriesRegressionHead`).
"""
if anomaly_prior_probability is not None:
filtering_postprocessor = StateInterpolatingAnomalyDetector(
@@ -424,4 +428,5 @@ class StructuralEnsembleRegressor(StateSpaceRegressor):
model=model,
optimizer=optimizer,
model_dir=model_dir,
- config=config)
+ config=config,
+ head_type=head_type)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py
index a28a5872b8..8686a803e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head.py
@@ -19,11 +19,7 @@ from __future__ import print_function
import re
-from tensorflow.python.training import training_util
-from tensorflow.contrib.layers.python.layers import optimizers
-
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
-
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
@@ -35,8 +31,9 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
from tensorflow.python.summary import summary
+from tensorflow.python.training import training_util
+from tensorflow.python.util import nest
class _NoStatePredictOutput(export_lib.PredictOutput):
@@ -102,12 +99,9 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
use_resource=True):
model_outputs = self.create_loss(features, mode)
- train_op = optimizers.optimize_loss(
+ train_op = self.optimizer.minimize(
model_outputs.loss,
- global_step=training_util.get_global_step(),
- optimizer=self.optimizer,
- # Learning rate is set in the Optimizer object
- learning_rate=None)
+ global_step=training_util.get_global_step())
return estimator_lib.EstimatorSpec(
loss=model_outputs.loss,
mode=mode,
@@ -132,7 +126,8 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
loss=model_outputs.loss,
mode=mode,
eval_metric_ops=metrics,
- predictions={})
+ # needed for custom metrics.
+ predictions=model_outputs.predictions)
def _predict_ops(self, features):
"""Add ops for prediction to the graph."""
@@ -210,12 +205,12 @@ class TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acce
def create_estimator_spec(self, features, mode, labels=None):
"""Performs basic error checking and returns an EstimatorSpec."""
with ops.name_scope(self._name, "head"):
- if labels:
+ if labels is not None and labels != {}: # for better error messages.
raise ValueError(
- "The model received a `labels` dictionary, which is "
- "not supported. Pass '{}' and '{}' as "
- "features.".format(feature_keys.TrainEvalFeatures.TIMES,
- feature_keys.TrainEvalFeatures.VALUES))
+ "The model received a `labels`, which is not supported. "
+ "Pass '{}' and '{}' as features.".format(
+ feature_keys.TrainEvalFeatures.TIMES,
+ feature_keys.TrainEvalFeatures.VALUES))
del labels
features = {
name: self._convert_feature_to_tensor(name=name, value=value)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index c606db76a6..78c2cec21c 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -18,9 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from absl.testing import parameterized
import numpy
import six
+from tensorflow.contrib.estimator.python.estimator import extenders
from tensorflow.contrib.timeseries.examples import lstm as lstm_example
from tensorflow.contrib.timeseries.python.timeseries import estimators as ts_estimators
from tensorflow.contrib.timeseries.python.timeseries import feature_keys
@@ -35,6 +39,7 @@ from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables
@@ -53,9 +58,12 @@ class HeadTest(test.TestCase):
model_fn = _stub_model_fn()
for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,
estimator_lib.ModeKeys.PREDICT]:
- with self.assertRaisesRegexp(ValueError, "labels"):
+ with self.assertRaisesRegexp(ValueError, "received a `labels`"):
model_fn(features={}, labels={"a": "b"}, mode=mode)
+ with self.assertRaisesRegexp(ValueError, "received a `labels`"):
+ model_fn(features={}, labels=array_ops.zeros([]), mode=mode)
+
def test_unknown_mode(self):
model_fn = _stub_model_fn()
with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"):
@@ -128,6 +136,44 @@ class EvaluationMetricsTests(test.TestCase):
coordinator.request_stop()
coordinator.join()
+ def test_custom_metrics(self):
+ """Tests that the custom metrics can be applied to the estimator."""
+ model_dir = self.get_temp_dir()
+ estimator = ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(num_features=1, num_units=4),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ model_dir=model_dir)
+
+ def input_fn():
+ return {
+ feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]],
+ feature_keys.TrainEvalFeatures.VALUES:
+ numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]])
+ }
+
+ def metrics_fn(predictions, features):
+ # checking that the inputs are properly passed.
+ predict = predictions["mean"]
+ target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]
+ return {
+ "plain_boring_metric386":
+ (math_ops.reduce_mean(math_ops.abs(predict - target)),
+ control_flow_ops.no_op()),
+ "fun_metric101": (math_ops.reduce_sum(predict + target),
+ control_flow_ops.no_op()),
+ }
+
+ # Evaluation without training is enough for testing custom metrics.
+ estimator = extenders.add_metrics(estimator, metrics_fn)
+ evaluation = estimator.evaluate(input_fn, steps=1)
+ self.assertIn("plain_boring_metric386", evaluation)
+ self.assertIn("fun_metric101", evaluation)
+ # The values are deterministic because of fixed tf_random_seed.
+ # However if they become flaky, remove such exacts comparisons.
+ self.assertAllClose(evaluation["plain_boring_metric386"], 1.130380)
+ self.assertAllClose(evaluation["fun_metric101"], 10.435442)
+
class _StubModel(object):
num_features = 3
@@ -274,10 +320,38 @@ class PredictFeatureCheckingTests(test.TestCase):
mode=estimator_lib.ModeKeys.PREDICT)
-class OneShotTests(test.TestCase):
-
- def test_one_shot_prediction_head_export(self):
- model_dir = self.get_temp_dir()
+def _custom_time_series_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.TimeSeriesRegressor(
+ model=lstm_example._LSTMModel(
+ num_features=5, num_units=128,
+ exogenous_feature_columns=exogenous_feature_columns),
+ optimizer=adam.AdamOptimizer(0.001),
+ config=estimator_lib.RunConfig(tf_random_seed=4),
+ state_manager=state_management.ChainingStateManager(),
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+def _structural_ensemble_regressor(
+ model_dir, head_type, exogenous_feature_columns):
+ return ts_estimators.StructuralEnsembleRegressor(
+ periodicities=None,
+ num_features=5,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=head_type,
+ model_dir=model_dir)
+
+
+class OneShotTests(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {"testcase_name": "custom_time_series_regressor",
+ "estimator_factory": _custom_time_series_regressor},
+ {"testcase_name": "structural_ensemble_regressor",
+ "estimator_factory": _structural_ensemble_regressor})
+ def test_one_shot_prediction_head_export(self, estimator_factory):
+ model_dir = os.path.join(test.get_temp_dir(), str(ops.uid()))
categorical_column = feature_column.categorical_column_with_hash_bucket(
key="categorical_exogenous_feature", hash_bucket_size=16)
exogenous_feature_columns = [
@@ -285,15 +359,10 @@ class OneShotTests(test.TestCase):
"2d_exogenous_feature", shape=(2,)),
feature_column.embedding_column(
categorical_column=categorical_column, dimension=10)]
- estimator = ts_estimators.TimeSeriesRegressor(
- model=lstm_example._LSTMModel(
- num_features=5, num_units=128,
- exogenous_feature_columns=exogenous_feature_columns),
- optimizer=adam.AdamOptimizer(0.001),
- config=estimator_lib.RunConfig(tf_random_seed=4),
- state_manager=state_management.ChainingStateManager(),
- head_type=ts_head_lib.OneShotPredictionHead,
- model_dir=model_dir)
+ estimator = estimator_factory(
+ model_dir=model_dir,
+ exogenous_feature_columns=exogenous_feature_columns,
+ head_type=ts_head_lib.OneShotPredictionHead)
train_features = {
feature_keys.TrainEvalFeatures.TIMES: numpy.arange(
20, dtype=numpy.int64),
@@ -308,7 +377,7 @@ class OneShotTests(test.TestCase):
num_threads=1, batch_size=16, window_size=16)
estimator.train(input_fn=train_input_fn, steps=5)
input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()
- export_location = estimator.export_savedmodel(self.get_temp_dir(),
+ export_location = estimator.export_savedmodel(test.get_temp_dir(),
input_receiver_fn)
graph = ops.Graph()
with graph.as_default():
@@ -342,7 +411,7 @@ class OneShotTests(test.TestCase):
for output_key, output_value
in predict_signature.outputs.items()}
output = session.run(fetches, feed_dict=feeds)
- self.assertAllEqual((2, 15, 5), output["mean"].shape)
+ self.assertEqual((2, 15, 5), output["mean"].shape)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index f84ff1bfe9..0d1c7fc75a 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -161,12 +161,40 @@ py_library(
)
py_library(
+ name = "keras_support",
+ srcs = [
+ "python/tpu/keras_support.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tpu_lib",
+ ":tpu_py",
+ "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tensor_spec",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/keras:backend",
+ "//tensorflow/python/keras:engine",
+ "//tensorflow/python/keras:layers",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "tpu_lib",
srcs = [
"python/tpu/__init__.py",
"python/tpu/bfloat16.py",
"python/tpu/device_assignment.py",
- "python/tpu/keras_support.py",
"python/tpu/session_support.py",
"python/tpu/topology.py",
"python/tpu/tpu.py",
@@ -181,6 +209,7 @@ py_library(
":datasets",
":profiler",
":tpu_py",
+ "//tensorflow/contrib/cluster_resolver:tpu_cluster_resolver_py",
"//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/contrib/tpu/proto:topology_proto_py",
"//tensorflow/core:protos_all_py",
@@ -306,3 +335,13 @@ tf_py_test(
"//tensorflow/python:framework_test_lib",
],
)
+
+tf_py_test(
+ name = "topology_test",
+ size = "small",
+ srcs = ["python/tpu/topology_test.py"],
+ additional_deps = [
+ ":tpu",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index d389050e67..06553929dc 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -23,15 +23,23 @@ REGISTER_OP("CrossReplicaSum")
.Input("input: T")
.Output("output: T")
.Attr("T: {bfloat16, float}")
+ .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input, and the output of each is the sum of
-all the inputs.
+instance supplies its own input. If group_assignment is empty, the output of
+each is the sum of all the inputs, otherwise the output of each is the sum of
+the inputs belonging to the same group.
+
+For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
+Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
input: The local input to the sum.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
+group_assignment: The list of group ids. `group_assignment[i]` represents the
+ group id of replica i.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index ab2a7a0d4b..15a2bb17a9 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -44,6 +44,27 @@ REGISTER_OP("TPUReplicatedInput")
" with other shapes.");
}
c->set_output(0, cur);
+
+ // If this is a resource, unify the resource shapes.
+ DataType dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
+ if (dtype == DT_RESOURCE) {
+ const std::vector<shape_inference::ShapeAndType>* shapes_and_types =
+ nullptr;
+ for (int i = c->num_inputs() - 1; i >= 0; --i) {
+ if (shapes_and_types) {
+ // The return value of MergeInputHandleShapesAndTypes indicates
+ // the shape was refined, not that there was an error.
+ // TODO(phawkins): there seems to be no way to discover errors.
+ (void)c->MergeInputHandleShapesAndTypes(i, *shapes_and_types);
+ } else {
+ shapes_and_types = c->input_handle_shapes_and_types(i);
+ }
+ }
+ if (shapes_and_types) {
+ c->set_output_handle_shapes_and_types(0, *shapes_and_types);
+ }
+ }
return Status::OK();
})
.Doc(
diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD
index dbf1ab6bbf..38d1c3049e 100644
--- a/tensorflow/contrib/tpu/profiler/BUILD
+++ b/tensorflow/contrib/tpu/profiler/BUILD
@@ -49,11 +49,11 @@ tf_cc_binary(
":tpu_profiler_analysis_proto_cc",
":tpu_profiler_proto_cc",
":version",
+ "//tensorflow:grpc++",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/platform/cloud:gcs_file_system",
- "@grpc//:grpc++_unsecure",
],
)
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 99485322c6..f80f5652af 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -18,7 +18,7 @@ limitations under the License.
// Initiates a TPU profiling on the TPUProfiler service at service_addr,
// receives and dumps the profile data to a tensorboard log directory.
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include <cstdio>
#include <ctime>
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 508c7a842f..7a5d01cca4 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -17,12 +17,11 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from absl import flags
-
import os
import subprocess
import sys
-
+from absl import flags
+from distutils.version import LooseVersion
import tensorflow as tf
# Cloud TPU Cluster Resolvers
@@ -35,26 +34,26 @@ flags.DEFINE_string(
None,
help='GCE zone where the Cloud TPU is located in. If not specified, we '
'will attempt to automatically detect the GCE project from metadata.')
-flags.DEFINE_string('tpu_name', None,
- 'Name of the Cloud TPU for Cluster Resolvers. You must '
- 'specify either this flag or --service_addr.')
+flags.DEFINE_string(
+ 'tpu', None, 'Name of the Cloud TPU for Cluster Resolvers. You must '
+ 'specify either this flag or --service_addr.')
# Tool specific parameters
flags.DEFINE_string(
'service_addr', None, 'Address of TPU profiler service e.g. '
- 'localhost:8466, you must specify either this flag or --tpu_name.')
+ 'localhost:8466, you must specify either this flag or --tpu.')
flags.DEFINE_string(
'workers_list', None, 'The list of worker TPUs that we are about to profile'
- ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu_name or '
+ ' e.g. 10.0.1.2, 10.0.1.3. You can specify this flag with --tpu or '
'--service_addr to profile a subset of tpu nodes. You can also use only'
- '--tpu_name and leave this flag unspecified to profile all the tpus.')
-flags.DEFINE_string('logdir', None,
- 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
- 'gs://tb_bucket')
+ '--tpu and leave this flag unspecified to profile all the tpus.')
+flags.DEFINE_string(
+ 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, '
+ 'gs://tb_bucket')
flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.')
-flags.DEFINE_integer('num_tracing_attempts', 3,
- 'Automatically retry N times when no trace '
- 'event is collected.')
+flags.DEFINE_integer(
+ 'num_tracing_attempts', 3, 'Automatically retry N times when no trace '
+ 'event is collected.')
flags.DEFINE_boolean('include_dataset_ops', True,
'Set to false to profile longer TPU '
'device traces.')
@@ -63,42 +62,50 @@ FLAGS = flags.FLAGS
EXECUTABLE = 'data/capture_tpu_profile'
JOB_NAME = 'worker'
+
def get_workers_list(cluster_resolver):
cluster_spec = cluster_resolver.cluster_spec()
task_indices = cluster_spec.task_indices(JOB_NAME)
- workers_list = [cluster_spec.task_address(JOB_NAME, i).split(':')[0]
- for i in task_indices]
+ workers_list = [
+ cluster_spec.task_address(JOB_NAME, i).split(':')[0] for i in task_indices
+ ]
return ','.join(workers_list)
+
def run_main():
tf.app.run(main)
+
def main(unused_argv=None):
tf.logging.set_verbosity(tf.logging.INFO)
+ tf_version = tf.__version__
+ print('TensorFlow version %s detected' % tf_version)
- if FLAGS.service_addr is None and FLAGS.tpu_name is None:
- sys.exit('You must specify either --service_addr or --tpu_name.')
+ if FLAGS.service_addr is None and FLAGS.tpu is None:
+ sys.exit('You must specify either --service_addr or --tpu.')
tpu_cluster_resolver = None
if FLAGS.service_addr is not None:
- if FLAGS.tpu_name is not None:
- tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring '
- '--tpu_name and using --service_addr.')
+ if FLAGS.tpu is not None:
+ tf.logging.warn('Both --service_addr and --tpu are set. Ignoring '
+ '--tpu and using --service_addr.')
service_addr = FLAGS.service_addr
else:
tpu_cluster_resolver = (
tf.contrib.cluster_resolver.TPUClusterResolver(
- [FLAGS.tpu_name],
- zone=FLAGS.tpu_zone,
- project=FLAGS.gcp_project))
+ [FLAGS.tpu], zone=FLAGS.tpu_zone, project=FLAGS.gcp_project))
service_addr = tpu_cluster_resolver.get_master()
service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')
- workers_list = ""
- if FLAGS.workers_list is not None:
- workers_list = FLAGS.workers_list
- elif tpu_cluster_resolver is not None:
- workers_list = get_workers_list(tpu_cluster_resolver)
+ workers_list = ''
+ if LooseVersion(tf_version) < LooseVersion('1.9'):
+ tf.logging.warn('Attempt to profile with legacy support under TensorFlow '
+ 'version %s' % tf_version)
+ else:
+ if FLAGS.workers_list is not None:
+ workers_list = FLAGS.workers_list
+ elif tpu_cluster_resolver is not None:
+ workers_list = get_workers_list(tpu_cluster_resolver)
if not FLAGS.logdir:
sys.exit('logdir must be provided.')
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index ebd478fd02..19f088f8b8 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.6.0'
+_VERSION = '1.9.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
@@ -46,7 +46,7 @@ setup(
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
- 'Development Status :: 4 - Beta',
+ 'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index 618479e1a6..1bf49966d1 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.6.0"
+#define TPU_PROFILER_VERSION "1.9.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 7ecb36852c..26016f47df 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -2,7 +2,12 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_additional_all_protos",
+ "tf_proto_library",
+ "tf_proto_library_py",
+)
tf_proto_library(
name = "tpu_embedding_config_proto",
@@ -22,12 +27,14 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
-tf_proto_library(
+tf_proto_library_py(
name = "compilation_result_proto",
srcs = [
"compilation_result.proto",
],
- cc_api_version = 2,
- protodeps = ["//tensorflow/core:protos_all"],
+ protodeps = tf_additional_all_protos() + [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_proto",
+ ],
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/contrib/tpu/proto/compilation_result.proto
index cf52897de3..88585a5bd1 100644
--- a/tensorflow/contrib/tpu/proto/compilation_result.proto
+++ b/tensorflow/contrib/tpu/proto/compilation_result.proto
@@ -3,6 +3,7 @@ syntax = "proto3";
option cc_enable_arenas = true;
package tensorflow.tpu;
+import "tensorflow/compiler/xla/service/hlo.proto";
import "tensorflow/core/lib/core/error_codes.proto";
// Describes the result of a TPU compilation.
@@ -10,4 +11,7 @@ message CompilationResultProto {
// The error message, if any, returned during compilation.
error.Code status_code = 1;
string status_error_message = 2;
+
+ // HLO proto.
+ repeated xla.HloProto hlo_protos = 3;
}
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 14c63a7976..bf442d9116 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -38,9 +38,8 @@ if platform.system() != "Windows":
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
- del op # Unused
# The gradient of a cross replica sum is also a cross-replica sum.
- return gen_tpu_ops.cross_replica_sum(grad)
+ return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index f1a11fa654..6e9c607f2e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -19,15 +19,16 @@ To use, wrap your model with the `keras_support.tpu_model` function.
Example usage:
```
-# Must activate before building TPU models
-keras_support.setup_tpu_session(master_address)
-
image = tf.keras.layers.Input(shape=(28, 28, 3), name='image')
c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image)
flattened = tf.keras.layers.Flatten()(c1)
logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
model = tf.keras.Model(inputs=[image], outputs=[logits])
-model = keras_support.tpu_model(model)
+
+strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+model = keras_support.tpu_model(model,
+ strategy=strategy,
+ tpu_name_or_address=tpu_name)
# Only TF optimizers are currently supported.
model.compile(optimizer=tf.train.AdamOptimizer(), ...)
@@ -35,9 +36,6 @@ model.compile(optimizer=tf.train.AdamOptimizer(), ...)
# `images` and `labels` should be Numpy arrays. Support for tensor input
# (e.g. datasets) is planned.
model.fit(images, labels)
-
-# Invoke before shutting down
-keras_support.shutdown_tpu_session()
```
"""
@@ -48,13 +46,20 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import re
+import sys
import time
+import numpy as np
+
+from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
@@ -62,14 +67,17 @@ from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
+from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+TPUDistributionStrategy = tpu_strategy.TPUStrategy # pylint: disable=invalid-name
+
class TPUEmbedding(embeddings.Embedding):
"""TPU compatible embedding layer.
@@ -92,11 +100,49 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
+ """An optimizer that averages gradients across TPU shards."""
+
+ def __init__(self, opt, name='KerasCrossShardOptimizer'):
+ """Construct a new cross-shard optimizer.
+
+ Args:
+ opt: An existing `Optimizer` to encapsulate.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "KerasCrossShardOptimizer".
+
+ Raises:
+ ValueError: If reduction is not a valid cross-shard reduction.
+ """
+ super(KerasCrossShardOptimizer, self).__init__()
+ self._name = name
+ self._opt = opt
+
+ def get_updates(self, loss, params):
+ logging.info('Get updates: %s', loss)
+ self._opt.get_gradients = self.get_gradients
+ return self._opt.get_updates(loss, params)
+
+ def get_gradients(self, loss, params):
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
+ return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
+
+ def set_weights(self, weights):
+ self._opt.set_weights()
+
+ def get_weights(self):
+ return self._opt.get_weights()
+
+ @property
+ def lr(self):
+ return self._opt.lr
+
+
class TPUModelOp(
- collections.namedtuple(
- 'TPUModelOp',
- ['compile_op', 'execute_op', 'infeed_tensors', 'infeed_op',
- 'outfeed_op'])):
+ collections.namedtuple('TPUModelOp', [
+ 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'
+ ])):
pass
@@ -105,13 +151,74 @@ def _valid_name(tensor_name):
return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
-def _replicated_optimizer(opt, num_replicas):
+def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- if num_replicas == 1:
+ if tpu_function.get_tpu_context().number_of_shards == 1:
return opt
- return keras_optimizers.TFOptimizer(
- optimizer=tpu_optimizer.CrossShardOptimizer(opt.optimizer)
- )
+
+ if isinstance(opt, keras_optimizers.TFOptimizer):
+ return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
+ else:
+ return KerasCrossShardOptimizer(opt)
+
+
+class TPURewriteContext(object):
+ """Prepare the environment for a Keras model during `tpu.rewrite`.
+
+ This overrides the default placeholder behaviour to instead refer to a preset
+ input mapping. Placeholders are unsupported in TPU compiled code, and must
+ be replaced with explicit inputs or values from the infeed queue.
+
+ Instead of explicitly threading inputs all the way through the Keras codebase,
+ we override the behavior of the placeholder while compiling and inject the
+ Tensors from the infeed in place of the placeholder.
+
+ Similarly, as we compile a new sub-graph for each unique shape and execution
+ mode, we need to override the behavior of an embedded `name_scope` call in
+ the base Keras layer code. This allows us to re-use the same weights across
+ many compiles and share a single session/graph.
+ """
+
+ def __init__(self, input_map):
+ self._input_map = input_map
+ self._default_placeholder = None
+ self._default_name_scope = None
+
+ def __enter__(self):
+
+ def _placeholder(dtype, shape=None, name=None): # pylint: disable=unused-argument
+ logging.info('Remapping placeholder for %s', name)
+ if name in self._input_map:
+ return self._input_map[name]
+ else:
+ logging.info('Default: %s', name)
+ return self._default_placeholder(dtype, shape, name)
+
+ def _name_scope(name, default_name=None, values=None):
+ caller_frame = sys._getframe().f_back
+ caller_obj = caller_frame.f_locals.get('self')
+ if (caller_obj is not None and
+ isinstance(caller_obj, base_layer.Layer) and name is not None):
+ logging.info('Intercepted name_scope: %s', caller_obj)
+ return variable_scope.variable_scope(
+ name, default_name, values, reuse=variable_scope.AUTO_REUSE)
+
+ return self._default_name_scope(name, default_name, values)
+
+ self._default_placeholder = array_ops.placeholder
+ self._default_name_scope = ops.name_scope
+ self._default_make_variable = base_layer.make_variable
+
+ array_ops.placeholder = _placeholder
+ ops.name_scope = _name_scope
+ base_layer.make_variable = variable_scope.get_variable
+ logging.info('Overriding default placeholder.')
+ return
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ array_ops.placeholder = self._default_placeholder
+ ops.name_scope = self._default_name_scope
+ base_layer.make_variable = self._default_make_variable
class TPUFunction(object):
@@ -126,19 +233,24 @@ class TPUFunction(object):
instead of being injected as `feed_dict` items or fetches.
"""
- def __init__(self, model, execution_mode, num_replicas=1):
+ def __init__(self, model, execution_mode, strategy):
self.model = model
self.execution_mode = execution_mode
+ self._strategy = strategy
self._compilation_cache = {}
- self.num_replicas = num_replicas
+ self._cloned_model = None
+
+ # Copy optimizer configuration. This is done prior to `_specialize_model`
+ # as the configuration may require evaluating variables in the CPU session.
+ self._optimizer_config = None
+ if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ self._optimizer_config = self.model.optimizer.get_config()
def _specialize_model(self, input_specs):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
# Re-create our input and output layers inside our subgraph. They will be
# attached to the true computation when we clone our model in `tpu_fn`.
- K.set_learning_phase(
- self.execution_mode == model_fn_lib.ModeKeys.TRAIN
- )
+ K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)
# functools.partial and callable objects are not supported by tpu.rewrite
def _model_fn():
@@ -164,23 +276,32 @@ class TPUFunction(object):
infeed_tensors))
tpu_targets = []
- tpu_inputs = []
+ tpu_input_map = {}
# Sort infeed outputs into inputs and labels for calling our Keras model.
for tensor, layer in zip(infeed_tensors, infeed_layers):
if layer in self.model._input_layers:
- tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor))
+ tpu_input_map[layer.name] = tensor
if layer in self.model._output_layers:
tpu_targets.append(tensor)
- # Call our model with our infeed inputs (re-using the weights).
- model_outputs = self.model(tpu_inputs)
- child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
+ # Clone our CPU model, running within the TPU device context.
+ with TPURewriteContext(tpu_input_map):
+ self._cloned_model = models.clone_model(self.model)
+
+ # Create a copy of the optimizer for this graph.
+ if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
+ cloned_optimizer = keras_optimizers.TFOptimizer(
+ self.model.optimizer.optimizer)
+ else:
+ logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
+ self._optimizer_config)
+ cloned_optimizer = self.model.optimizer.__class__.from_config(
+ self._optimizer_config)
if is_training or is_test:
- child_model.compile(
- optimizer=_replicated_optimizer(self.model.optimizer,
- self.num_replicas),
+ self._cloned_model.compile(
+ optimizer=_replicated_optimizer(cloned_optimizer),
loss=self.model.loss,
loss_weights=self.model.loss_weights,
metrics=self.model.metrics,
@@ -190,37 +311,37 @@ class TPUFunction(object):
# Compute our outfeed depending on the execution mode
if is_training:
- child_model._make_train_function()
+ self._cloned_model._make_train_function()
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
- for tensor in child_model.train_function.outputs
+ for tensor in self._cloned_model.train_function.outputs
]
return [
- child_model.train_function.updates_op,
+ self._cloned_model.train_function.updates_op,
tpu_ops.outfeed_enqueue_tuple(
- child_model.train_function.outputs,
+ self._cloned_model.train_function.outputs,
name='outfeed-enqueue-train')
]
elif is_test:
- child_model._make_test_function()
+ self._cloned_model._make_test_function()
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
- for tensor in child_model.test_function.outputs
+ for tensor in self._cloned_model.test_function.outputs
]
return [
tpu_ops.outfeed_enqueue_tuple(
- child_model.test_function.outputs,
+ self._cloned_model.test_function.outputs,
name='outfeed-enqueue-test')
]
elif is_predict:
- child_model._make_predict_function()
+ self._cloned_model._make_predict_function()
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
- for tensor in child_model.predict_function.outputs
+ for tensor in self._cloned_model.predict_function.outputs
]
return [
tpu_ops.outfeed_enqueue_tuple(
- child_model.predict_function.outputs,
+ self._cloned_model.predict_function.outputs,
name='outfeed-enqueue-predict',
)
]
@@ -235,7 +356,7 @@ class TPUFunction(object):
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
# running on a different logical core.
compile_op, execute_op = tpu.split_compile_and_replicate(
- _model_fn, inputs=[[]] * self.num_replicas)
+ _model_fn, inputs=[[]] * self._strategy.num_towers)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
@@ -243,7 +364,7 @@ class TPUFunction(object):
outfeed_op = []
shard_infeed_tensors = []
- for shard_id in range(self.num_replicas):
+ for shard_id in range(self._strategy.num_towers):
with ops.device('/device:TPU:%d' % shard_id):
infeed_tensors = []
for spec in input_specs:
@@ -254,32 +375,35 @@ class TPUFunction(object):
name='infeed-enqueue-%s-%d' % (spec.name, shard_id)))
shard_infeed_tensors.append(infeed_tensors)
- infeed_op.append(tpu_ops.infeed_enqueue_tuple(
- infeed_tensors, [spec.shape for spec in input_specs],
- name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))
+ infeed_op.append(
+ tpu_ops.infeed_enqueue_tuple(
+ infeed_tensors, [spec.shape for spec in input_specs],
+ name='infeed-enqueue-%s-%d' % (self.execution_mode, shard_id)))
- outfeed_op.extend(tpu_ops.outfeed_dequeue_tuple(
- dtypes=[spec.dtype for spec in self._outfeed_spec],
- shapes=[spec.shape for spec in self._outfeed_spec],
- name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))
+ outfeed_op.extend(
+ tpu_ops.outfeed_dequeue_tuple(
+ dtypes=[spec.dtype for spec in self._outfeed_spec],
+ shapes=[spec.shape for spec in self._outfeed_spec],
+ name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id)))
return TPUModelOp(
- compile_op, execute_op, infeed_tensors=shard_infeed_tensors,
- infeed_op=infeed_op, outfeed_op=outfeed_op)
+ compile_op,
+ execute_op,
+ infeed_tensors=shard_infeed_tensors,
+ infeed_op=infeed_op,
+ outfeed_op=outfeed_op)
def _test_model_compiles(self, tpu_model_ops):
"""Verifies that the given TPUModelOp can be compiled via XLA."""
- session = K.get_session()
-
logging.info('Started compiling')
start_time = time.clock()
- result = session.run(tpu_model_ops.compile_op)
+ result = K.get_session().run(tpu_model_ops.compile_op)
proto = tpu_compilation_result.CompilationResultProto()
proto.ParseFromString(result)
if proto.status_error_message:
- raise RuntimeError(
- 'Compilation failed: {}'.format(proto.status_error_message))
+ raise RuntimeError('Compilation failed: {}'.format(
+ proto.status_error_message))
end_time = time.clock()
logging.info('Finished compiling. Time elapsed: %s secs',
@@ -296,17 +420,19 @@ class TPUFunction(object):
Returns:
List of lists containing the input to feed to each TPU shard.
"""
- if self.num_replicas == 1:
+ if self._strategy.num_towers == 1:
return [inputs]
batch_size = inputs[0].shape[0]
- assert batch_size % self.num_replicas == 0, (
- 'batch_size must be divisible by num_replicas')
- shard_size = batch_size // self.num_replicas
+ assert batch_size % self._strategy.num_towers == 0, (
+ 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
+ (batch_size, self._strategy.num_towers))
+ shard_size = batch_size // self._strategy.num_towers
input_list = []
- for index in range(self.num_replicas):
- shard_inputs = [x[index * shard_size:(index + 1) * shard_size]
- for x in inputs]
+ for index in range(self._strategy.num_towers):
+ shard_inputs = [
+ x[index * shard_size:(index + 1) * shard_size] for x in inputs
+ ]
input_list.append(shard_inputs)
return input_list
@@ -343,12 +469,15 @@ class TPUFunction(object):
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
- logging.info('New input shapes; (re-)compiling: mode=%s, %s',
- self.execution_mode, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs)
- self._compilation_cache[shape_key] = new_tpu_model_ops
- self._test_model_compiles(new_tpu_model_ops)
-
+ with self.model.tpu_session():
+ logging.info('New input shapes; (re-)compiling: mode=%s, %s',
+ self.execution_mode, input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs)
+ self._compilation_cache[shape_key] = new_tpu_model_ops
+ self._test_model_compiles(new_tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
tpu_model_ops = self._compilation_cache[shape_key]
infeed_dict = {}
@@ -357,58 +486,82 @@ class TPUFunction(object):
for tensor, value in zip(infeed_tensors, inputs):
infeed_dict[tensor] = value
- session = K.get_session()
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
# TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
- return outfeed_outputs[:len(outfeed_outputs) // self.num_replicas]
-
-
-@experimental
-def setup_tpu_session(master):
- """Initializes and returns a Keras/TF session connected the TPU `master`."""
- session = tf_session.Session(
- target=master, config=config_pb2.ConfigProto(isolate_session_state=True))
- K.set_session(session)
- K.get_session().run(tpu.initialize_system())
- return session
+ if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
+ outputs = [[]] * len(self._outfeed_spec)
+ outputs_per_replica = len(self._outfeed_spec)
+ for i in range(self._strategy.num_towers):
+ output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
+ outputs_per_replica]
+ for j in range(outputs_per_replica):
+ outputs[j].append(output_group[j])
-@experimental
-def shutdown_tpu_session(session=None):
- """Shutdown the TPU attached to session.
-
- This should be called to cleanly shut down the TPU system before the client
- exits.
-
- Args:
- session: Session to shutdown, or None to use the default session.
-
- Returns:
-
- """
- if session is None:
- session = K.get_session()
-
- session.run(tpu.shutdown_system())
+ return [np.concatenate(group) for group in outputs]
+ else:
+ return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers]
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
- def __init__(self, inputs, outputs, name, replicas=1):
+ def __init__(self, cpu_model, tpu_name_or_address, strategy):
super(models.Model, self).__init__( # pylint: disable=bad-super-call
- inputs=inputs,
- outputs=outputs,
- name=name,
+ inputs=cpu_model.inputs,
+ outputs=cpu_model.outputs,
+ name=cpu_model.name,
)
+
self.predict_function = None
self.test_function = None
self.train_function = None
- self.replicas = replicas
+ self._strategy = strategy
+
+ self._tpu_name_or_address = tpu_name_or_address
+ self._cpu_model = cpu_model
+ self._tpu_model = None
+ self._tpu_weights_initialized = False
+ self._graph = ops.Graph()
+
+ cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
+ tpu_name_or_address)
+ cluster_spec = cluster_resolver.cluster_spec()
+ self._session = tf_session.Session(
+ graph=self._graph,
+ target=cluster_resolver.master(),
+ config=config_pb2.ConfigProto(isolate_session_state=True))
+
+ if cluster_spec:
+ self._session.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
+ with self._graph.as_default():
+ self._session.run(tpu.initialize_system())
+
+ # If the input CPU model has already been compiled, compile our TPU model
+ # immediately.
+ if self._cpu_model.optimizer:
+ self.compile(
+ self._cpu_model.optimizer,
+ self._cpu_model.loss,
+ self._cpu_model.metrics,
+ self._cpu_model.loss_weights,
+ self._cpu_model.sample_weight_mode,
+ self._cpu_model.weighted_metrics,
+ self._cpu_model.target_tensors,
+ )
+
+ def get_config(self):
+ return {
+ 'cpu_model': self._cpu_model,
+ 'tpu_name_or_address': self._tpu_name_or_address,
+ 'strategy': self._strategy,
+ }
def compile(self,
optimizer,
@@ -430,44 +583,97 @@ class KerasTPUModel(models.Model):
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- # Keras optimizers are not compatible with TPU rewrite
- if not isinstance(self.optimizer, keras_optimizers.TFOptimizer):
- raise ValueError(
- 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer)
+ if not self._cpu_model.optimizer:
+ self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
+ sample_weight_mode, weighted_metrics,
+ target_tensors, **kwargs)
def _make_train_function(self):
if not self.train_function:
- self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN,
- num_replicas=self.replicas)
+ self.train_function = TPUFunction(
+ self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
return self.train_function
def _make_test_function(self):
if not self.test_function:
- self.test_function = TPUFunction(self, model_fn_lib.ModeKeys.EVAL)
+ self.test_function = TPUFunction(
+ self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
return self.test_function
def _make_predict_function(self):
if not self.predict_function:
- self.predict_function = TPUFunction(self, model_fn_lib.ModeKeys.PREDICT)
+ self.predict_function = TPUFunction(
+ self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
return self.predict_function
- def cpu_model(self):
- cpu_model = models.Model(
- inputs=self.inputs,
- outputs=self.outputs,
- name=self.name,
- )
+ def _initialize_weights(self, cloned_model):
+ """Initialize TPU weights.
- if self.optimizer:
- cpu_model.compile(
- optimizer=self.optimizer,
- loss=self.loss,
- metrics=self.metrics,
- loss_weights=self.loss_weights,
- )
+ This is called on the first compile of the TPU model (first call to
+ fit/predict/evaluate).
- return cpu_model
+ Args:
+ cloned_model: `keras.Model`, TPU model to initialize.
+ """
+ if self._tpu_weights_initialized:
+ return
+
+ self._tpu_model = cloned_model
+ self._tpu_weights_initialized = True
+
+ weights = self._cpu_model.get_weights()
+ with self.tpu_session():
+ logging.info('Setting weights on TPU model.')
+ cloned_model.set_weights(weights)
+
+ def sync_to_cpu(self):
+ """Copy weights from the CPU, returning a synchronized CPU model."""
+ if self._tpu_weights_initialized:
+ with self.tpu_session():
+ logging.info('Copying TPU weights to the CPU')
+ tpu_weights = self._tpu_model.get_weights()
+
+ self._cpu_model.set_weights(tpu_weights)
+
+ return self._cpu_model
+
+ def get_weights(self):
+ return self.sync_to_cpu().get_weights()
+
+ def save_weights(self, *args, **kw):
+ return self.sync_to_cpu().save_weights(*args, **kw)
+
+ def save(self, *args, **kw):
+ return self.sync_to_cpu().save(*args, **kw)
+
+ def set_weights(self, weights):
+ # We may not have a TPU model available if we haven't run fit/predict, so
+ # we can't directly set the TPU weights here.
+ # Instead, reset CPU model weights and force TPU re-initialization at the
+ # next call.
+ self._cpu_model.set_weights(weights)
+ self._tpu_weights_initialized = False
+
+ @contextlib.contextmanager
+ def tpu_session(self):
+ """Yields a TPU session and sets it as the default Keras session."""
+ with self._graph.as_default():
+ default_session = K.get_session()
+ # N.B. We have to call `K.set_session()` AND set our session as the
+ # TF default. `K.get_session()` surprisingly does not return the value
+ # supplied by K.set_session otherwise.
+ K.set_session(self._session)
+ with self._session.as_default():
+ yield self._session
+ K.set_session(default_session)
+
+ def shutdown(self):
+ logging.info('Shutting down TPU session.')
+ with self.tpu_session() as session:
+ session.run(tpu.shutdown_system())
+
+ self._session.close()
def _validate_shapes(model):
@@ -504,26 +710,8 @@ Output shape: %(output_shape)s
@experimental
-def tpu_model(model, replicas=None):
- """Runs a model on TPU(s).
-
- Usage:
- ```
- a = Input(shape=(32,))
- b = Dense(32)(a)
- model = Model(inputs=a, outputs=b)
-
- model = keras_support.tpu_model(model)
- model.compile(
- optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
- ...)
- ```
-
- If `replicas` is set, replicates the model computation on all TPU cores. The
- model computation is replicated `num_replicas` times; each shard will run on a
- different TPU core.
-
- Limitation: Currently, replication is only supported for training.
+def tpu_model(model, tpu_name_or_address=None, strategy=None):
+ """Copy `model` along with weights to the TPU. Returns a TPU model.
Usage:
```
@@ -531,26 +719,39 @@ def tpu_model(model, replicas=None):
b = Dense(32)(a)
model = Model(inputs=a, outputs=b)
- model = keras_support.tpu_model(model, replicas=2)
+ # If `num_cores_per_host` is greater than one, batch parallelism will be used
+ # to run on multiple TPU cores.
+ strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
+ model = keras_support.tpu_model(model, strategy)
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
...)
+ model.shutdown()
```
Args:
model: A `KerasTPUModel`.
- replicas: (Optional) Int, number of TPU cores which to create model
- replicas. If `None`, the model runs on single core only, i.e., no
- replication.
+ tpu_name_or_address: A string that is either the name of the Cloud TPU,
+ the grpc address of the Cloud TPU, or (Googlers only) the BNS name of the
+ Cloud TPU. If tpu_name_or_address is None, the TPUClusterResolver will
+ examine the environment to determine a potential Cloud TPU to use.
+ strategy: `TPUDistributionStrategy`. The strategy to use for replicating
+ model across multiple TPU cores.
Returns:
A new `KerasTPUModel` instance.
"""
+ # Force initialization of the CPU model.
+ model.get_weights()
+ model.reset_states()
+
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
# TODO(xiejw): Adds reduction option.
- replicas = 1 if replicas is None else replicas
+ if strategy is None:
+ strategy = TPUDistributionStrategy(num_cores_per_host=1)
return KerasTPUModel(
- inputs=model.inputs, outputs=model.outputs, name=model.name,
- replicas=replicas)
+ cpu_model=model,
+ tpu_name_or_address=tpu_name_or_address,
+ strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index cda9a63f20..1fb26e701a 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -55,8 +55,9 @@ class Topology(object):
rank 3 numpy int32 array that describes a valid coordinate mapping.
"""
+ self._serialized = serialized
+
if serialized:
- self._serialized = serialized
self._parse_topology(serialized)
else:
self._mesh_shape = np.asarray(mesh_shape, dtype=np.int32)
@@ -131,7 +132,7 @@ class Topology(object):
proto.mesh_shape[:] = list(self._mesh_shape)
proto.num_tasks = self._device_coordinates.shape[0]
proto.num_tpu_devices_per_task = self._device_coordinates.shape[1]
- proto.device_coordinates = list(self._device_coordinates.flatten())
+ proto.device_coordinates.extend(list(self._device_coordinates.flatten()))
self._serialized = proto.SerializeToString()
return self._serialized
diff --git a/tensorflow/contrib/tpu/python/tpu/topology_test.py b/tensorflow/contrib/tpu/python/tpu/topology_test.py
new file mode 100644
index 0000000000..e67fdb263a
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/topology_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Tests for topology.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tpu.python.tpu import topology
+
+from tensorflow.python.platform import test
+
+
+class TopologyTest(test.TestCase):
+
+ def testSerialization(self):
+ """Test if the class is able to generate serialzied string."""
+ original_topology = topology.Topology(
+ mesh_shape=[1, 1, 2],
+ device_coordinates=[[[0, 0, 0], [0, 0, 1]]],
+ )
+ serialized_str = original_topology.serialized()
+ new_topology = topology.Topology(serialized=serialized_str)
+
+ # Make sure the topology recovered from serialized str is same as the
+ # original topology.
+ self.assertAllEqual(
+ original_topology.mesh_shape, new_topology.mesh_shape)
+ self.assertAllEqual(
+ original_topology.device_coordinates, new_topology.device_coordinates)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 612cd0114b..6a64893d9a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -126,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
outside the replicated computation.
"""
- def __init__(self, name, num_replicas):
+ def __init__(self, name, num_replicas, pivot):
+ """Builds a new TPUReplicateContext.
+
+ Args:
+ name: a unique name for the context, used to populate the `_tpu_replicate`
+ attribute.
+ num_replicas: an integer that gives the number of replicas for the
+ computation.
+ pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any
+ inputs will have a control dependency on the pivot node. This ensures
+ that nodes are correctly included in any enclosing control flow
+ contexts.
+ """
super(TPUReplicateContext, self).__init__()
self._num_replicas = num_replicas
self._outer_device_function_stack = None
@@ -138,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._host_compute_core = []
self._name = name
self._unsupported_ops = []
+ self._pivot = pivot
def report_unsupported_operations(self):
if self._unsupported_ops:
@@ -214,7 +227,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
class FakeOp(object):
"""A helper class to determine the current device.
- Supports only the device set/get methods needed to run the
+ Supports only the type and device set/get methods needed to run the
graph's _apply_device_function method.
"""
@@ -222,11 +235,18 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._device = ""
@property
+ def type(self):
+ return "FakeOp"
+
+ @property
def device(self):
return self._device
def _set_device(self, device):
- self._device = device.to_string()
+ if isinstance(device, pydev.DeviceSpec):
+ self._device = device.to_string()
+ else:
+ self._device = device
if self._outside_compilation_cluster:
raise NotImplementedError("Cannot nest outside_compilation clusters")
@@ -262,9 +282,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access
super(TPUReplicateContext, self).Enter()
- def Exit(self):
- super(TPUReplicateContext, self).Exit()
-
def HostComputeCore(self):
return self._host_compute_core
@@ -300,10 +317,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
op.graph.prevent_feeding(op)
op.graph.prevent_fetching(op)
+ # Remove any control edges from outer control flow contexts. These may cause
+ # mismatched frame errors.
+ control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+
+ if not op.inputs:
+ # Add a control edge from the control pivot to this op.
+ if not control_inputs:
+ # pylint: disable=protected-access
+ op._add_control_input(self.GetControlPivot())
+ # pylint: enable=protected-access
+ else:
+ for index in xrange(len(op.inputs)):
+ x = op.inputs[index]
+ real_x = self.AddValue(x)
+ if real_x != x:
+ op._update_input(index, real_x) # pylint: disable=protected-access
+
+ if external_inputs:
+ # Use an identity to pull control inputs as data inputs. Note that we
+ # ignore ops which don't have outputs. TODO(phawkins): fix that.
+ with ops.control_dependencies(None):
+ self.Enter()
+ external_inputs = [
+ array_ops.identity(x.outputs[0]).op
+ for x in external_inputs
+ if x.outputs
+ ]
+ self.Exit()
+ # pylint: disable=protected-access
+ op._add_control_inputs(external_inputs)
+ # pylint: enable=protected-access
+
+ # Mark op's outputs as seen by this context and any outer contexts.
+ output_names = [x.name for x in op.outputs]
+ context = self
+ while context is not None:
+ # pylint: disable=protected-access
+ context._values.update(output_names)
+ context = context._outer_context
+ # pylint: enable=protected-access
+
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
def AddValue(self, val):
+ if val.name in self._values:
+ # Use the real value if it comes from outer context.
+ result = self._external_values.get(val.name)
+ return val if result is None else result
+
result = val
+ self._values.add(val.name)
if self._outer_context:
result = self._outer_context.AddValue(val)
+ self._values.add(result.name)
+
+ self._external_values[val.name] = result
+
return result
def AddInnerOp(self, op):
@@ -319,6 +390,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# grad_state should be as if this is the top-level gradient state.
return None
+ @property
+ def back_prop(self):
+ """Forwards to the enclosing while context, if any."""
+ if self.GetWhileContext():
+ return self.GetWhileContext().back_prop
+ return False
+
+ def GetControlPivot(self):
+ return self._pivot
+
def outside_compilation(computation, *args, **kwargs):
"""Builds part of a computation outside any current TPU replicate scope.
@@ -505,7 +586,9 @@ def split_compile_and_replicate(computation,
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
cluster_name = graph.unique_name("cluster")
- context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas)
+ pivot = control_flow_ops.no_op(name=cluster_name + "/pivot")
+ context = TPUReplicateContext(
+ name=cluster_name, num_replicas=num_replicas, pivot=pivot)
try:
context.Enter()
@@ -515,16 +598,22 @@ def split_compile_and_replicate(computation,
with tpu_function.tpu_shard_context(
num_replicas), ops.control_dependencies([metadata]):
- # The EncapsulateTPUComputations rewrite needs to identify the
- # replicated arguments inside each computation. Adds identity operators
- # tagged with an attribute _tpu_replicated_input to identify the
- # replicated inputs.
+ # For backward compatibility reasons, we tag replicated inputs with the
+ # _tpu_replicated_input attribute. This does nothing and exists only for
+ # backward compatibility.
+ # TODO(phawkins): delete the attr_scope after 6/28/2018.
# pylint: disable=protected-access
- with graph._attr_scope({"_tpu_replicated_input":
- attr_value_pb2.AttrValue(b=True)}):
+ with graph._attr_scope({
+ "_tpu_replicated_input": attr_value_pb2.AttrValue(b=True)
+ }):
+ # Add identity ops so even unused inputs are "consumed" by the
+ # computation. This is to avoid orphaned TPUReplicatedInput nodes.
+ # TODO(phawkins): consider instead pruning unused TPUReplicatedInput
+ # and eliding trivial TPUReplicatedInput/TPUReplicatedOutput pairs.
computation_inputs = [
array_ops.identity(x, name="replicated_input_{}".format(i))
- for i, x in enumerate(computation_inputs)]
+ for i, x in enumerate(computation_inputs)
+ ]
# pylint: enable=protected-access
# If there is an infeed queue, adds the dequeued values to the
@@ -547,10 +636,16 @@ def split_compile_and_replicate(computation,
vscope.set_use_resource(saved_use_resource)
+ # If the computation returns `None`, make it an empty tuple.
+ if outputs is None:
+ outputs = tuple()
# If the computation only returned one value, makes it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
+ # Append `no_op` here so that fetching any return value of this function
+ # will trigger TPUExecute node.
+ outputs += (control_flow_ops.no_op(),)
try:
with ops.device(core(0)):
outputs = [
@@ -582,6 +677,7 @@ def split_compile_and_replicate(computation,
with ops.device(t.device if t.device else core(0)):
new_output_tensors.append(array_ops.identity(t))
output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
finally:
context.report_unsupported_operations()
context.Exit()
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 6d7331e3c7..2f2e97b3cd 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -23,8 +23,6 @@ import collections
import json
import os
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import util as util_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
@@ -50,7 +48,7 @@ class TPUConfig(
collections.namedtuple('TPUConfig', [
'iterations_per_loop',
'num_shards',
- 'computation_shape',
+ 'num_cores_per_replica',
'per_host_input_for_training',
'tpu_job_name',
'initial_infeed_sleep_secs',
@@ -67,13 +65,11 @@ class TPUConfig(
case, this number equals the total number of TPU cores. For
model-parallelism, the total number of TPU cores equals
product(computation_shape) * num_shards.
- computation_shape: Defaults to `None`, which disables model parallelism. A
- list of size 3 which describes the shape of a model replica's block of
- cores. This is required by model-parallelism which enables partitioning
- the model to multiple cores. For example, [2, 2, 1] means the model is
- partitioned across 4 cores which span two cores in both x and y
- coordinates. Please refer to @{tf.contrib.tpu.Topology} for the
- geometry of a TPU mesh.
+ num_cores_per_replica: Defaults to `None`, which disables model parallelism.
+ An integer which describes the number of TPU cores per model replica. This
+ is required by model-parallelism which enables partitioning
+ the model to multiple cores. Currently num_cores_per_replica must be
+ 1, 2, 4, or 8.
per_host_input_for_training: If `True`, `PER_HOST_V1`, or `PER_HOST_V2`,
`input_fn` is invoked per-host rather than per-core. With per-host input
pipeline configuration, `input_fn` is invoked once on each host. With the
@@ -99,7 +95,7 @@ class TPUConfig(
def __new__(cls,
iterations_per_loop=2,
num_shards=None,
- computation_shape=None,
+ num_cores_per_replica=None,
per_host_input_for_training=True,
tpu_job_name=None,
initial_infeed_sleep_secs=None):
@@ -112,19 +108,12 @@ class TPUConfig(
if num_shards is not None:
util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')
- # Check computation_shape
- if computation_shape is not None and len(computation_shape) != 3:
- raise ValueError(
- 'computation_shape must be a list with length 3 or None; got {}'.
- format(str(computation_shape)))
-
- if computation_shape is not None:
- computation_shape_array = np.asarray(computation_shape, dtype=np.int32)
- # This prevents any computation being replicated across multiple hosts, so
- # that each host feeds the same number of computations.
- if any(computation_shape_array < 1) or any(computation_shape_array > 2):
- raise ValueError('computation_shape elements can only be 1 or 2; got '
- 'computation_shape={}'.format(computation_shape))
+ # Parse computation_shape
+ if num_cores_per_replica is not None:
+ if num_cores_per_replica not in [1, 2, 4, 8]:
+ raise ValueError(
+ 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
# Map legacy values (True, False) to numeric values.
@@ -144,7 +133,7 @@ class TPUConfig(
cls,
iterations_per_loop=iterations_per_loop,
num_shards=num_shards,
- computation_shape=computation_shape,
+ num_cores_per_replica=num_cores_per_replica,
per_host_input_for_training=per_host_input_for_training,
tpu_job_name=tpu_job_name,
initial_infeed_sleep_secs=initial_infeed_sleep_secs)
@@ -214,6 +203,12 @@ class RunConfig(run_config_lib.RunConfig):
self._session_config.cluster_def.CopyFrom(
self._cluster_spec.as_cluster_def())
+ def _maybe_overwrite_session_config_for_distributed_training(self):
+ # Overrides the parent class session_config overwrite for between-graph. TPU
+ # runs with in-graph, which should not have device filter. Doing nothing
+ # ("pass") basically disables it.
+ pass
+
@property
def evaluation_master(self):
return self._evaluation_master
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 37ef3dbe1e..2326fe97a8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
from tensorflow.contrib.tpu.python.tpu import tpu_config as tpu_config_lib
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
@@ -33,6 +34,46 @@ def _set_tf_config_env_variable(tf_config):
class TPURunConfigTest(test.TestCase):
+ def test_no_session_config_set_in_local_case(self):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertIsNone(run_config.session_config)
+
+ def test_no_session_config_overwrite_in_local_case(self):
+ session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ run_config = tpu_config_lib.RunConfig(session_config=session_config)
+ self.assertEqual(session_config, run_config.session_config)
+
+ def test_no_session_config_set_with_cluster_spec(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host3:3'],
+ run_config_lib.TaskType.WORKER: ['host3:4']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ with _set_tf_config_env_variable(tf_config):
+ run_config = tpu_config_lib.RunConfig()
+ self.assertIsNone(run_config.session_config)
+
+ def test_no_session_config_overwrite_with_cluster_spec(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host3:3'],
+ run_config_lib.TaskType.WORKER: ['host3:4']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ with _set_tf_config_env_variable(tf_config):
+ session_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ run_config = tpu_config_lib.RunConfig(session_config=session_config)
+ self.assertEqual(session_config, run_config.session_config)
+
def test_fail_with_invalid_num_shards(self):
with self.assertRaisesRegexp(ValueError, 'must be positive'):
tpu_config_lib.RunConfig(
@@ -43,15 +84,11 @@ class TPURunConfigTest(test.TestCase):
tpu_config_lib.RunConfig(
tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=0))
- def test_fail_with_invalid_computation_shape(self):
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape must be a list with length'
- ' 3 or None'):
- tpu_config_lib.TPUConfig(computation_shape=[2, 1])
-
- with self.assertRaisesRegexp(ValueError,
- 'computation_shape elements can only be'):
- tpu_config_lib.TPUConfig(computation_shape=[1, 3, 1])
+ def test_fail_with_invalid_num_cores_per_replica(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ' got 7'):
+ tpu_config_lib.TPUConfig(num_cores_per_replica=7)
class TPURunConfigMasterTest(test.TestCase):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 5b9aeaa879..0efbe45dbf 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -21,8 +21,6 @@ from __future__ import print_function
from contextlib import contextmanager
import copy
-import numpy as np
-
from tensorflow.contrib.tpu.python.tpu import device_assignment as tpu_device_assignment
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
@@ -33,6 +31,12 @@ from tensorflow.python.platform import tf_logging as logging
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
+_NUM_CORES_TO_COMPUTATION_SHAPE = {
+ 1: [1, 1, 1],
+ 2: [1, 1, 2],
+ 4: [1, 2, 2],
+ 8: [2, 2, 2]
+}
class TPUContext(object):
@@ -92,6 +96,19 @@ class TPUContext(object):
"""
return self._internal_ctx.num_replicas
+ @property
+ def num_hosts(self):
+ """The number of hosts for the TPU system."""
+ return self._internal_ctx.num_hosts
+
+ @property
+ def num_of_replicas_per_host(self):
+ """The number of replicas for each host."""
+ if self._internal_ctx.model_parallelism_enabled:
+ raise ValueError(
+ 'num_of_replicas_per_host is not supported for model_parallelism')
+ return self._internal_ctx.num_of_replicas_per_host
+
def device_for_replica(self, replica_id):
"""Returns the tuple of (CPU device and device ordinal) for replica.
@@ -108,8 +125,8 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
# If the precise replica_id to device mapping is required, please
- # set the computation_shape as [1,1,1] in TPUConfig to enable
- # the model parallelism.
+ # set the num_cores_per_replica to 1 in TPUConfig to enable the
+ # model parallelism.
if self._internal_ctx.model_parallelism_enabled:
return RuntimeError(
'device_for_replica is not yet implemented for model parallelism. '
@@ -162,9 +179,14 @@ class _InternalTPUContext(object):
self._eval_on_tpu = eval_on_tpu
self._model_parallelism_enabled = (
- use_tpu and config.tpu_config.computation_shape)
+ use_tpu and config.tpu_config.num_cores_per_replica)
self._mode = None
-
+ num_cores_per_replica = config.tpu_config.num_cores_per_replica
+ if num_cores_per_replica:
+ self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[
+ num_cores_per_replica]
+ else:
+ self._computation_shape = None
self._lazy_tpu_system_metadata_dict = {} # key by master address
self._lazy_device_assignment_dict = {} # key by master address
self._lazy_validation_dict = {} # key by ModeKeys
@@ -225,11 +247,12 @@ class _InternalTPUContext(object):
device_assignment = tpu_device_assignment.device_assignment(
tpu_system_metadata.topology,
- computation_shape=self._config.tpu_config.computation_shape,
+ computation_shape=self._computation_shape,
num_replicas=self.num_replicas)
- logging.info('computation_shape: %s',
- str(self._config.tpu_config.computation_shape))
+ logging.info('num_cores_per_replica: %s',
+ str(self._config.tpu_config.num_cores_per_replica))
+ logging.info('computation_shape: %s', str(self._computation_shape))
logging.info('num_replicas: %d', self.num_replicas)
logging.info('device_assignment.topology.device_coordinates: %s',
str(device_assignment.topology.device_coordinates))
@@ -270,23 +293,20 @@ class _InternalTPUContext(object):
num_cores_in_system = self.num_cores
if self.model_parallelism_enabled:
- computation_shape_array = np.asarray(
- self._config.tpu_config.computation_shape, dtype=np.int32)
- num_cores_per_replica = np.prod(computation_shape_array)
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
if num_cores_per_replica > num_cores_in_system:
raise ValueError(
'The num of cores required by the model parallelism, specified by '
- 'TPUConfig.computation_shape, is larger than the total num of '
- 'TPU cores in the system. computation_shape: {}, num cores '
- 'in the system: {}'.format(
- self._config.tpu_config.computation_shape,
- num_cores_in_system))
+ 'TPUConfig.num_cores_per_replica, is larger than the total num of '
+ 'TPU cores in the system. num_cores_per_replica: {}, num cores '
+ 'in the system: {}'.format(num_cores_per_replica,
+ num_cores_in_system))
if num_cores_in_system % num_cores_per_replica != 0:
raise RuntimeError(
'The num of cores in the system ({}) is not divisible by the num '
'of cores ({}) required by the model parallelism, specified by '
- 'TPUConfig.computation_shape. This should never happen!'.format(
+ 'TPUConfig.num_cores_per_replica. This should never happen!'.format(
num_cores_in_system, num_cores_per_replica))
return num_cores_in_system // num_cores_per_replica
@@ -384,9 +404,7 @@ class _InternalTPUContext(object):
# On TPU
if self.is_input_sharded_per_core() or (
self.is_input_per_host_with_iterators()):
- # We prohibit per core input sharding for the model parallelism case,
- # therefore it is safe to use num_cores here.
- return global_batch_size // self.num_cores
+ return global_batch_size // self.num_replicas
else:
return global_batch_size // self.num_hosts
@@ -484,25 +502,27 @@ class _InternalTPUContext(object):
return _placement_function
- @property
- def tpu_ordinal_function(self):
+ def tpu_ordinal_function(self, host_id):
"""Returns the TPU ordinal fn."""
- def _tpu_ordinal_function(index):
+ def _tpu_ordinal_function(shard_index_in_host):
"""Return the TPU ordinal associated with a shard.
Required because the enqueue ops are placed on CPU.
Args:
- index: the shard index
+ shard_index_in_host: the shard index
Returns:
The ordinal of the TPU device the shard's infeed should be placed on.
"""
if self.model_parallelism_enabled:
- return self.device_assignment.tpu_ordinal(replica=index)
+ # We put both enqueue/dequeue ops at tpu.core(0) in each replica.
+ replica = self.device_assignment.lookup_replicas(
+ host_id, (0, 0, 0))[shard_index_in_host]
+ return self.device_assignment.tpu_ordinal(replica=replica)
else:
- return index % self.num_of_cores_per_host
+ return shard_index_in_host % self.num_of_cores_per_host
return _tpu_ordinal_function
@@ -533,7 +553,7 @@ class _InternalTPUContext(object):
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
- 'product(computation_shape) * num_replicas. Please set it '
+ 'num_cores_per_replica * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
@@ -612,7 +632,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
"""Returns an instance of `_InternalTPUContext`."""
if (config.tpu_config.num_shards == 1 and
- config.tpu_config.computation_shape is None):
+ config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
'Please fix as soon as possible (leaving num_shards as None.')
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index c57acd0a2d..8a137005b6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -81,12 +81,17 @@ _TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
+_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
+# Ideally _USE_TPU_KEY should be reserved as well. However there are already
+# models that make use of this key, thus it can not be reserved now to prevent
+# breakage. In the long run, we would like to mitigate this by migrating models
+# off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
@@ -122,6 +127,33 @@ def _create_global_step(graph):
def _create_or_get_iterations_per_loop():
+ """Creates or gets the iterations_per_loop variable.
+
+ In TPUEstimator, the user provided computation, the model_fn, is wrapped
+ inside a tf.while_loop for peak performance. The iterations of the loop are
+ specified by this variable, which adjusts its value on the CPU after each TPU
+ program execution and before the next TPU execution.
+
+ The purpose of using a variable, rather then a constant, is to allow
+ TPUEstimator adapt the TPU training iterations according to the final steps
+ specified by users. For example, if the user sets the iterations_per_loop as 4
+ in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop
+ variable will have the following value before each TPU training.
+
+ - 1-th TPU execution: iterations_per_loop = 4
+ - 2-th TPU execution: iterations_per_loop = 4
+ - 3-th TPU execution: iterations_per_loop = 2
+
+ As model_fn increases the global step once per train_op invocation, the global
+ step is 10 after all TPU executions, matching the steps=10 inputs passed in by
+ users.
+
+ Returns:
+ A TF non-trainable resource variable.
+
+ Raises:
+ RuntimeError: If multi iterations_per_loop variables were found.
+ """
graph = ops.get_default_graph()
collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
iter_vars = graph.get_collection(collection_name)
@@ -184,8 +216,8 @@ class _SIGNAL(object):
class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
- See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
- 'export_outputs`.
+ See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
+ `export_outputs`.
For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
@@ -199,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
size is the first dimension. Once all tensors are available at CPU host from
all shards, they are concatenated (on CPU) and passed as positional arguments
to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
- dict. `metric_fn` takes the `tensors` and returns a dict from metric string
+ a dict. `metric_fn` takes the `tensors` and returns a dict from metric string
name to the result of calling a metric function, namely a `(metric_tensor,
update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the
`eval_metrics`.
@@ -388,20 +420,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
return
def _cancel_session():
- # Close the session to avoid the main thread from hanging. If input
- # pipeline triggers any error, the infeed thread dies but the main thread
- # for TPU computation waits for the infeed enqueue forever. Close the
- # Session to cancel the main thread Session.run execution.
- #
- # We sleep for a few seconds before closing to give some time
- # for the TPU compilation error, if any, propagating, from TPU to CPU
- # host. Compilation errors should be reported by the main thread so that
- # the program can be interrupted and users can take action. Due to a race
- # condition, the infeed thread might see an error first. Closing the
- # session here immediately would result in a session cancellation
- # exception in the main thread, instead of the expected compile error.
- # User code that depends on having the proper exception type will
- # therefore be confused.
+ """Close the session to avoid the main thread from hanging.
+
+ If input pipeline triggers any error, the infeed thread dies but the main
+ thread for TPU computation waits for the infeed enqueue forever. Close the
+ Session to cancel the main thread Session.run execution.
+
+ We sleep for a few seconds before closing to give some time for the TPU
+ compilation error, if any, propagating, from TPU to CPU host. Compilation
+ errors should be reported by the main thread so that the program can be
+ interrupted and users can take action. Due to a race condition, the
+ infeed thread might see an error first. Closing the session here
+ immediately would result in a session cancellation exception in the main
+ thread, instead of the expected compile error. User code that depends on
+ having the proper exception type will therefore be confused.
+ """
time.sleep(5)
# If the main session is still running, the infeed/outfeed errors are
@@ -636,6 +669,7 @@ def generate_per_core_enqueue_ops_fn_for_host(
ctx, input_fn, inputs_structure_recorder, host_device, host_id):
"""Generates infeed enqueue ops for per-core input_fn on a single host."""
captured_infeed_queue = _CapturedObject()
+ tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
def enqueue_ops_fn():
"""A fn returns enqueue_ops."""
@@ -671,7 +705,7 @@ def generate_per_core_enqueue_ops_fn_for_host(
per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
- per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+ per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
return per_host_enqueue_ops
return enqueue_ops_fn, captured_infeed_queue
@@ -706,21 +740,18 @@ def generate_per_host_enqueue_ops_fn_for_host(
if is_dataset:
hooks.append(inputs.dataset_initializer_hook())
- # TODO(ylc): Refactoring the code to merge the tpu ordinal logic here and the
- # _InternalTPUContext.tpu_ordinal_function. We should either introduce another
- # abstraction or a different helper method.
- def _tpu_ordinal_function_impl(shard_index_in_host):
- # We put both enqueue/dequeue op at tpu.core(0) in each replica.
- replica = ctx.device_assignment.lookup_replicas(
- host_id, (0, 0, 0))[shard_index_in_host]
- return ctx.device_assignment.tpu_ordinal(replica=replica)
-
- if ctx.model_parallelism_enabled:
- tpu_ordinal_function = _tpu_ordinal_function_impl
- else:
- tpu_ordinal_function = None
+ tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
def enqueue_ops_fn():
+ """A Fn returning the TPU infeed enqueue ops.
+
+ By providing as a Fn, it can be invoked inside the tf.while_loop such that
+ the input pipeline for multiple iterations can be executed by one
+ Session.run call.
+
+ Returns:
+ list of dict of ops.
+ """
with ops.device(device):
num_of_replicas_per_host = ctx.num_of_replicas_per_host
# Convert user input to features and labels. If the user returns a
@@ -745,7 +776,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
infeed_queue.split_inputs_and_generate_enqueue_ops(
unsharded_tensor_list,
placement_function=lambda x: device,
- tpu_ordinal_function=tpu_ordinal_function))
+ tpu_ordinal_function=tpu_ordinal_function_impl))
if signals is None:
return per_host_enqueue_ops
else:
@@ -779,6 +810,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
raise TypeError('Most PREDICT not yet supported in PER_HOST_V2 mode.')
hooks.append(inputs.dataset_initializer_hook())
+ tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)
def enqueue_ops_fn():
"""Generates the per_host enqueue ops."""
@@ -809,7 +841,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
- per_host_sharded_inputs, tpu_ordinal_function=ctx.tpu_ordinal_function)
+ per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
return per_host_enqueue_ops
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
@@ -1095,10 +1127,16 @@ class _InputPipeline(object):
return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator
def _validate_input_pipeline(self):
- # Perform some sanity checks to log user friendly information. We should
- # error out to give users better error message. But, if
- # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
- # user code, so, log a warning.
+ """Validates the input pipeline.
+
+ Perform some sanity checks to log user friendly information. We should
+ error out to give users better error message. But, if
+ _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
+ user code, so, log a warning.
+
+ Raises:
+ RuntimeError: If the validation failed.
+ """
if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
err_msg = ('Input pipeline contains one or more QueueRunners. '
'It could be slow and not scalable. Please consider '
@@ -1300,8 +1338,55 @@ class _ModelFnWrapper(object):
key, tensor))
return predictions
+ def _validate_model_features_and_labels(self,
+ features,
+ labels,
+ is_export_mode):
+ """Validates that the features and labels for the model function are valid.
+
+ A valid features/labels object is the one with:
+ - Type: Tensor or a dictionary of Tensors
+ - Static shape if is_export_mode is False.
+
+ Args:
+ features: the features that would be input to the model function.
+ labels: the labels that would be input to the model function.
+ is_export_mode: boolean value specifying if in export mode.
+
+ Raises:
+ TypeError: If features/labels are not of the correct type.
+ ValueError: If features/labels have dynamic shape.
+ """
+
+ def validate(obj, obj_name):
+ """Helper validate function."""
+ if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict):
+ raise TypeError(
+ 'The {} to the model returned by input_fn must be either a Tensor '
+ 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name,
+ obj))
+ if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
+ return
+ if isinstance(obj, ops.Tensor):
+ if not obj.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static shape.'
+ ' Tensor: {}'.format(obj_name, obj))
+ else:
+ for (key, tensor) in obj.items():
+ if not tensor.get_shape().is_fully_defined():
+ raise ValueError(
+ 'The {} to the model returned by input_fn must have static '
+ 'shape. Key: \'{}\', Tensor: {}'.format(
+ obj_name, key, tensor))
+
+ validate(features, 'features')
+ if labels is not None:
+ validate(labels, 'labels')
+
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
+ self._validate_model_features_and_labels(features, labels, is_export_mode)
model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
@@ -1334,8 +1419,11 @@ class _ModelFnWrapper(object):
if batch_size_for_model_fn is not None:
_add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
+ running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
+ _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
+
estimator_spec = self._model_fn(features=features, **kwargs)
- if (self._ctx.is_running_on_cpu(is_export_mode) and
+ if (running_on_cpu and
isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
# The estimator_spec will be passed to `Estimator` directly, which expects
# type `EstimatorSpec`.
@@ -1807,16 +1895,11 @@ class TPUEstimator(estimator_lib.Estimator):
export_outputs['classes'] =
export_output_lib.ClassificationOutput(classes=classes)
- tpu.outside_compilation(host_call, [logits])
+ tpu.outside_compilation(host_call, logits)
...
```
- Current limitations:
- --------------------
-
- 1. Outside compilation does not work yet (b/79991729).
-
"""
def __init__(self,
@@ -1837,7 +1920,8 @@ class TPUEstimator(estimator_lib.Estimator):
Args:
model_fn: Model function as required by `Estimator`. For training, the
returned `EstimatorSpec` cannot have hooks as it is not supported in
- `TPUEstimator`.
+ `TPUEstimator`. Instead, the user can pass the training hooks as
+ an argument to `TPUEstimator.train()`.
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
continue training a previously saved model. If `None`, the model_dir in
@@ -1902,7 +1986,7 @@ class TPUEstimator(estimator_lib.Estimator):
if (config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1 and
- config.tpu_config.computation_shape):
+ config.tpu_config.num_cores_per_replica):
raise ValueError(
'Model parallelism only supports per host input for training. '
'Please adjust TPURunconfig.per_host_input_for_training.')
@@ -1957,24 +2041,29 @@ class TPUEstimator(estimator_lib.Estimator):
strip_default_attrs,
save_variables=True,
mode=model_fn_lib.ModeKeys.PREDICT,
- export_tags=None):
+ export_tags=None,
+ check_variables=True):
if mode != model_fn_lib.ModeKeys.PREDICT:
raise NotImplementedError(
'TPUEstimator only handles mode PREDICT for export_savedmodel(); '
'got {}.'.format(mode))
- super(TPUEstimator, self)._add_meta_graph_for_mode(builder,
- input_receiver_fn_map,
- checkpoint_path,
- strip_default_attrs,
- save_variables,
- mode=mode)
+ (super(TPUEstimator, self).
+ _add_meta_graph_for_mode(builder,
+ input_receiver_fn_map,
+ checkpoint_path,
+ strip_default_attrs,
+ save_variables,
+ mode=mode,
+ export_tags=export_tags,
+ check_variables=check_variables))
if self._export_to_tpu:
input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
input_receiver_fn_map[mode]}
export_tags = [tag_constants.SERVING, tag_constants.TPU]
mode = _REWRITE_FOR_INFERENCE_MODE
+ # See b/110052256 for why `check_variables` is `False`.
(super(TPUEstimator, self).
_add_meta_graph_for_mode(builder,
input_receiver_fn_map,
@@ -1982,7 +2071,8 @@ class TPUEstimator(estimator_lib.Estimator):
strip_default_attrs,
save_variables=False,
mode=mode,
- export_tags=export_tags))
+ export_tags=export_tags,
+ check_variables=False))
def _call_model_fn(self, features, labels, mode, config):
if mode == _REWRITE_FOR_INFERENCE_MODE:
@@ -2034,10 +2124,21 @@ class TPUEstimator(estimator_lib.Estimator):
# Reconstruct `tensors`, but with `tpu_tensors` replaced with
# `tpu_tensors_on_cpu`.
- new_tensors = [
- tpu_tensors_on_cpu.pop(0) if _is_tpu_tensor(t) else t
- for t in tensors
- ]
+ new_tensors = []
+ for t in tensors:
+ if _is_tpu_tensor(t):
+ new_tensors.append(tpu_tensors_on_cpu.pop(0))
+ elif t is None:
+ new_tensors.append(None)
+ else:
+ # Only fetching `tpu_tensors_on_cpu` does not trigger
+ # TPU computation and blocks, so we add the control dependency here.
+ control_inputs = (tpu_tensors_on_cpu
+ if isinstance(tpu_tensors_on_cpu, (list, tuple))
+ else (tpu_tensors_on_cpu,))
+ with ops.control_dependencies(control_inputs):
+ new_tensors.append(array_ops.identity(t))
+
# Reconstruct `tensors_dict`.
new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
# Reconstruct `export_outputs`.
@@ -2197,10 +2298,20 @@ class TPUEstimator(estimator_lib.Estimator):
# Clear the bit.
self._is_input_fn_invoked = None
+ # examples_hook is added to training_hooks for both CPU and TPU
+ # execution.
+ examples_hook = ExamplesPerSecondHook(
+ ctx.global_batch_size,
+ output_dir=self.model_dir,
+ every_n_steps=self._log_every_n_steps)
+
if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
logging.info('Running %s on CPU', mode)
- return model_fn_wrapper.call_without_tpu(
+ estimator_spec = model_fn_wrapper.call_without_tpu(
features, labels, is_export_mode=is_export_mode)
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks + (examples_hook,))
+ return estimator_spec
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
# TPUEstimator._call_input_fn passes `input_fn` as features to here.
@@ -2268,10 +2379,6 @@ class TPUEstimator(estimator_lib.Estimator):
},
every_n_iter=logging_hook_frequency)
])
- examples_hook = ExamplesPerSecondHook(
- ctx.global_batch_size,
- output_dir=self.model_dir,
- every_n_steps=self._log_every_n_steps)
examples_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)
@@ -2641,7 +2748,7 @@ class _CapturedObject(object):
def capture(self, o):
if self._captured:
raise RuntimeError(
- 'InternalError: Object can be captured only. Please file bug .')
+ 'InternalError: Object can capture only once. Please file bug.')
self._captured = True
self._object = o
@@ -2650,7 +2757,7 @@ class _CapturedObject(object):
if not self._captured:
raise RuntimeError(
'InternalError: Object is not captured properly before `get`. '
- 'Please file bug .')
+ 'Please file bug.')
return self._object
@@ -2898,6 +3005,7 @@ class _StopSignals(object):
@staticmethod
def should_stop(scalar_stopping_signal):
+ """Detects whether scalar_stopping_signal indicates stopping."""
if isinstance(scalar_stopping_signal, ops.Tensor):
# STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
# way to express the bool check whether scalar_stopping_signal is True.
@@ -3017,7 +3125,7 @@ class _SignalsHelper(object):
def __init__(self, signals):
self._signal_keys = []
- for key in sorted(signals.iterkeys()):
+ for key in sorted(iter(signals.keys())):
self._signal_keys.append(key)
@property
@@ -3029,7 +3137,7 @@ class _SignalsHelper(object):
@staticmethod
def as_tensor_list(signals):
- return [signals[key] for key in sorted(signals.iterkeys())]
+ return [signals[key] for key in sorted(iter(signals.keys()))]
def _verify_cross_hosts_transfer_size(tensor_dict, message):
@@ -3055,7 +3163,7 @@ def _add_item_to_params(params, key, value):
if isinstance(params, hparam.HParams):
# For HParams, we need to use special API.
if key in params:
- params.key = value
+ params.set_hparam(key, value)
else:
params.add_hparam(key, value)
else:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index e76cf83e4d..53d33f4077 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,8 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer
@@ -32,7 +35,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
def __init__(self,
opt,
reduction=losses.Reduction.MEAN,
- name="CrossShardOptimizer"):
+ name="CrossShardOptimizer",
+ group_assignment=None):
"""Construct a new cross-shard optimizer.
Args:
@@ -40,6 +44,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
+ group_assignment: Optional list of group ids for applying the optimizer
+ to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
@@ -50,6 +56,35 @@ class CrossShardOptimizer(optimizer.Optimizer):
super(CrossShardOptimizer, self).__init__(False, name)
self._opt = opt
self._reduction = reduction
+ self._group_assignment = group_assignment
+
+ def _verify_and_get_subgroup_size(self, group_assignment, num_shards):
+ """Verify group_assignment and get the subgroup size".
+
+ Args:
+ group_assignment: list of group ids for applying the optimizer
+ to subgroups.
+ num_shards: The number of TPU shards.
+
+ Returns:
+ The size of one subgroup in group_assignment.
+
+ Raises:
+ ValueError: If group_assignment is invalid.
+ """
+ if not group_assignment:
+ return None
+ if len(group_assignment) != num_shards:
+ raise ValueError("The size of group_assignment does not equal to "
+ "num_shard({0}). Got group_assignment={1}".format(
+ num_shards, self._group_assignment))
+ subgroup_size_list = dict(collections.Counter(group_assignment)).values()
+ if all(subgroup_size_list[0] == size for size in subgroup_size_list):
+ return subgroup_size_list[0]
+ else:
+ raise ValueError("The size of each subgroup in group_assignment must "
+ "be equal. Got group_assignment={}".format(
+ self._group_assignment))
def compute_gradients(self, loss, var_list=None, **kwargs):
"""Compute gradients of "loss" for the variables in "var_list".
@@ -71,7 +106,8 @@ class CrossShardOptimizer(optimizer.Optimizer):
A list of (gradient, variable) pairs.
Raises:
- ValueError: If not within a tpu_shard_context.
+ ValueError: If not within a tpu_shard_context or group_assignment is
+ invalid.
"""
num_shards = tpu_function.get_tpu_context().number_of_shards
if num_shards is None:
@@ -79,9 +115,17 @@ class CrossShardOptimizer(optimizer.Optimizer):
"CrossShardOptimizer should be used within a tpu_shard_context, but "
"got unset number_of_shards. Assuming 1.")
num_shards = 1
+
+ subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
+ num_shards)
+
if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
- scale = 1.0 / num_shards
+ if self._group_assignment:
+ scale = 1.0 / subgroup_size
+ else:
+ scale = 1.0 / num_shards
loss *= scale
+
return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
@@ -110,7 +154,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
if grad is None:
summed_grads_and_vars.append((grad, var))
else:
- summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
+ with ops.colocate_with(grad):
+ summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
+ grad, self._group_assignment), var))
return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
def get_slot(self, *args, **kwargs):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
index c3882b8a27..6bdaa528f9 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
@@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase):
def testIsInContext(self):
"""Test that control_flow_util can check that we're in a TPU context."""
z1 = array_ops.identity(1)
- context = tpu.TPUReplicateContext(b"context", 1)
+ pivot = control_flow_ops.no_op()
+ context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot)
context.Enter()
z2 = array_ops.identity(1)
context.Exit()
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index 5de55b5f7f..76927e62e8 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -295,7 +295,7 @@ py_test(
tags = ["notsan"],
deps = [
":training_py",
- "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
+ "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py
new file mode 100644
index 0000000000..ed0f398e30
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay.py
@@ -0,0 +1,187 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""SGDR learning rate decay function."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops, control_flow_ops
+
+
+def sgdr_decay(learning_rate, global_step, initial_period_steps,
+ t_mul=2.0, m_mul=1.0, name=None):
+ """Implements Stochastic Gradient Descent with Warm Restarts (SGDR).
+
+ As described in "SGDR: Stochastic Gradient Descent
+ with Warm Restarts" by Ilya Loshchilov & Frank Hutter, Proceedings of
+ ICLR'2017, available at https://arxiv.org/pdf/1608.03983.pdf
+
+ The learning rate decreases according to cosine annealing:
+
+ ```python
+ learning_rate * 0.5 * (1 + cos(x_val * pi)) # for x_val defined in [0, 1]
+ ```
+
+ Thus, at the beginning (when the restart index i = 0),
+ the learning rate decreases for `initial_period_steps` steps from the initial
+ learning rate `learning_rate` (when `x_val=0`, we get `cos(0)=1`) to
+ 0 (when `x_val=1`, we get `cos(pi)=-1`).
+
+ The decrease within the i-th period takes `t_i` steps,
+ where `t_0` = `initial_period_steps` is the user-defined number of batch
+ iterations (not epochs as in the paper) to be performed before the first
+ restart is launched.
+
+ Then, we perform the first restart (i=1) by setting the learning rate to
+ `learning_rate*(m_mul^i)`, where `m_mul in [0,1]` (set to 1 by default).
+ The i-th restart runs for `t_i=t_0*(t_mul^i)` steps, i.e., every new
+ restart runs `t_mul` times longer than the previous one.
+
+ Importantly, when one has no access to a validation set, SGDR suggests
+ to report the best expected / recommended solution in the following way:
+ When we are within our initial run (i=0), every new solution represents
+ SGDR's recommended solution. Instead, when i>0, the recommended solution is
+ the one obtained at the end of each restart.
+
+ Note that the minimum learning rate is set to 0 for simplicity,
+ you can adjust the code to deal with any positive minimum learning rate
+ as defined in the paper.
+
+ `initial_period_steps` is the duration of the first period measured in terms
+ of number of minibatch updates. If one wants to use epochs, one should compute
+ the number of updates required for an epoch.
+
+ For example, assume the following parameters and intention:
+ Minibatch size: 100
+ Training dataset size: 10000
+ If the user wants the first decay period to span across 5 epochs, then
+ `initial_period_steps` = 5 * 10000/100 = 500
+
+ Train for 10000 batch iterations with the initial learning rate set to
+ 0.1, then restart to run 2 times longer, i.e, for 20000 batch iterations
+ and with the initial learning rate 0.05, then restart again and again,
+ doubling the runtime of each new period and with two times smaller
+ initial learning rate.
+
+ To accomplish the above, one would write:
+
+ ```python
+ ...
+ global_step = tf.Variable(0, trainable=False)
+ starter_learning_rate = 0.1
+ learning_rate = sgdr_decay(starter_learning_rate, global_step,
+ initial_period_steps=10000, t_mul=2, m_mul=0.5)
+ # Passing global_step to minimize() will increment it at each step.
+ learning_step = (
+ tf.train.GradientDescentOptimizer(learning_rate)
+ .minimize(...my loss..., global_step=global_step)
+ )
+
+ # Step | 0 | 1000 | 5000 | 9000 | 9999 | 10000 | 11000 |
+ # LR | 0.1 | 0.097 | 0.05 | 0.002 | 0.00 | 0.05 | 0.0496 |
+
+ # Step | 20000 | 29000 | 29999 | 30000 |
+ # LR | 0.025 | 0.0003 | 0.00 | 0.025 |
+ ```
+
+ Args:
+ learning_rate: A scalar `float32` or `float64` `Tensor` or a
+ Python number. The initial learning rate.
+ global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+ Global step to use for the decay computation. Must not be negative.
+ initial_period_steps: Duration of the first period measured as the number
+ of minibatch updates, if one wants to use epochs, one should compute
+ the number of updates required for an epoch.
+ t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Must be positive.
+ Used to derive the number of iterations in the i-th period:
+ `initial_period_steps * (t_mul^i)`. Defaults to 2.0.
+ m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
+ Must be positive.
+ Used to derive the initial learning rate of the i-th period:
+ `learning_rate * (m_mul^i)`. Defaults to 1.0
+
+ Returns:
+ A scalar `Tensor` of the same type as `learning_rate`.
+ The learning rate for a provided global_step.
+ Raises:
+ ValueError: if `global_step` is not supplied.
+ """
+
+ if global_step is None:
+ raise ValueError("global_step is required for sgdr_decay.")
+ with ops.name_scope(name, "SGDRDecay",
+ [learning_rate, global_step,
+ initial_period_steps, t_mul, m_mul]) as name:
+ learning_rate = ops.convert_to_tensor(learning_rate,
+ name="initial_learning_rate")
+ dtype = learning_rate.dtype
+ global_step = math_ops.cast(global_step, dtype)
+ t_0 = math_ops.cast(initial_period_steps, dtype)
+ t_mul = math_ops.cast(t_mul, dtype)
+ m_mul = math_ops.cast(m_mul, dtype)
+
+ c_one = math_ops.cast(constant_op.constant(1.0), dtype)
+ c_half = math_ops.cast(constant_op.constant(0.5), dtype)
+ c_pi = math_ops.cast(constant_op.constant(math.pi), dtype)
+
+ # Find normalized value of the current step
+ x_val = math_ops.div(global_step, t_0)
+
+ def compute_step(x_val, geometric=False):
+ if geometric:
+ # Consider geometric series where t_mul != 1
+ # 1 + t_mul + t_mul^2 ... = (1 - t_mul^i_restart) / (1 - t_mul)
+
+ # First find how many restarts were performed for a given x_val
+ # Find maximal integer i_restart value for which this equation holds
+ # x_val >= (1 - t_mul^i_restart) / (1 - t_mul)
+ # x_val * (1 - t_mul) <= (1 - t_mul^i_restart)
+ # t_mul^i_restart <= (1 - x_val * (1 - t_mul))
+
+ # tensorflow allows only log with base e
+ # i_restart <= log(1 - x_val * (1 - t_mul) / log(t_mul)
+ # Find how many restarts were performed
+
+ i_restart = math_ops.floor(
+ math_ops.log(c_one - x_val * (c_one - t_mul)) / math_ops.log(t_mul))
+ # Compute the sum of all restarts before the current one
+ sum_r = (c_one - t_mul ** i_restart) / (c_one - t_mul)
+ # Compute our position within the current restart
+ x_val = (x_val - sum_r) / t_mul ** i_restart
+
+ else:
+ # Find how many restarts were performed
+ i_restart = math_ops.floor(x_val)
+ # Compute our position within the current restart
+ x_val = x_val - i_restart
+ return i_restart, x_val
+
+ i_restart, x_val = control_flow_ops.cond(
+ math_ops.equal(t_mul, c_one),
+ lambda: compute_step(x_val, geometric=False),
+ lambda: compute_step(x_val, geometric=True))
+
+ # If m_mul < 1, then the initial learning rate of every new restart will be
+ # smaller, i.e., by a factor of m_mul ** i_restart at i_restart-th restart
+ m_fac = learning_rate * (m_mul ** i_restart)
+
+ return math_ops.multiply(c_half * m_fac,
+ (math_ops.cos(x_val * c_pi) + c_one), name=name)
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
new file mode 100644
index 0000000000..4a46e9a49e
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -0,0 +1,145 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for sgdr learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from sgdr_learning_rate_decay import sgdr_decay
+from tensorflow.python.platform import googletest
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import dtypes
+from tensorflow import placeholder
+
+
+class SGDRDecayTest(test_util.TensorFlowTestCase):
+ """Unit tests for SGDR learning rate decay."""
+
+ def get_original_values(self, lr, t_e, mult_factor, iter_per_epoch, epochs):
+ """Get an array with learning rate values from the consecutive steps using
+ the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ t0 = math.pi / 2.0
+ tt = 0
+ te_next = t_e
+
+ lr_values = []
+ sh_lr = lr
+ for epoch in range(epochs):
+ for _ in range(iter_per_epoch):
+ # In the original approach training function is executed here
+ lr_values.append(sh_lr)
+ dt = 2.0 * math.pi / float(2.0 * t_e)
+ tt = tt + float(dt) / iter_per_epoch
+ if tt >= math.pi:
+ tt = tt - math.pi
+ cur_t = t0 + tt
+ new_lr = lr * (1.0 + math.sin(cur_t)) / 2.0 # lr_min = 0, lr_max = lr
+ sh_lr = new_lr
+ if (epoch + 1) == te_next: # time to restart
+ sh_lr = lr
+ tt = 0 # by setting to 0 we set lr to lr_max, see above
+ t_e = t_e * mult_factor # change the period of restarts
+ te_next = te_next + t_e # note the next restart's epoch
+
+ return lr_values
+
+ def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
+ """Get an array with learning rate values from the consecutive steps
+ using current tensorflow implementation."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
+ lr_values = []
+ for i in range(iters):
+ lr_values.append(decay.eval(feed_dict={step: i}))
+
+ return lr_values
+
+ def testCompareToOriginal(self):
+ """Compare values generated by tensorflow implementation to the values
+ generated by the original implementation
+ (https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
+ with self.test_session():
+ lr = 10.0
+ init_steps = 2
+ t_mul = 3
+ iters = 10
+ epochs = 50
+
+ org_lr = self.get_original_values(lr, init_steps, t_mul, iters, epochs)
+ sgdr_lr = self.get_sgdr_values(lr, init_steps*iters, t_mul, iters*epochs)
+
+ for org, sgdr in zip(org_lr, sgdr_lr):
+ self.assertAllClose(org, sgdr)
+
+ def testMDecay(self):
+ """Test m_mul argument. Check values for learning rate at the beginning
+ of the first, second, third and fourth period. """
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+
+ lr = 0.1
+ t_e = 10
+ t_mul = 3
+ m_mul = 0.9
+
+ decay = sgdr_decay(lr, step, t_e, t_mul, m_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul)
+
+ test_step = t_e + t_e*t_mul
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * m_mul**2)
+
+ test_step = t_e + t_e*t_mul + t_e * (t_mul**2)
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}),
+ lr * (m_mul**3))
+
+ def testCos(self):
+ """Check learning rate values at the beginning, in the middle
+ and at the end of the period."""
+ with self.test_session():
+ step = placeholder(dtypes.int32)
+ lr = 0.2
+ t_e = 1000
+ t_mul = 1
+
+ decay = sgdr_decay(lr, step, t_e, t_mul)
+
+ test_step = 0
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+ test_step = t_e
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr)
+
+ test_step = t_e*3//2
+ self.assertAllClose(decay.eval(feed_dict={step: test_step}), lr/2)
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index 409aba817c..a2444934bc 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
- # pylint: disable=protected-access
if padded_shapes is None:
self._padded_shapes = nest.map_structure(
- dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ convert.partial_shape_to_tensor, input_dataset.output_shapes)
else:
self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ input_dataset.output_shapes, convert.partial_shape_to_tensor,
padded_shapes)
+ # pylint: disable=protected-access
padding_values = (
padding_values if padding_values is not None else
dataset_ops._default_padding(input_dataset))
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index 0338f409a2..df0a186f4f 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
index 9720fd6e86..19cb8983b6 100644
--- a/tensorflow/contrib/verbs/BUILD
+++ b/tensorflow/contrib/verbs/BUILD
@@ -53,12 +53,12 @@ cc_library(
":grpc_verbs_service_impl",
":rdma_mgr",
":verbs_service_proto_cc",
+ "//tensorflow:grpc++",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
- "@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
@@ -69,7 +69,7 @@ cc_library(
hdrs = ["grpc_verbs_service_impl.h"],
deps = [
":verbs_service_proto_cc",
- "@grpc//:grpc++_unsecure",
+ "//tensorflow:grpc++",
],
)
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
index 742f946c95..af29abd91f 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service.cc
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -15,9 +15,9 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
-#include "grpc++/alarm.h"
-#include "grpc++/grpc++.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
index 991f9a9d8b..4da7b59c69 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
index 1f0f10517e..abe5e08b07 100644
--- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
#define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
index 86350a08e5..f7c979e863 100644
--- a/tensorflow/contrib/verbs/rdma.cc
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -24,8 +24,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
@@ -1084,7 +1084,7 @@ void RdmaTensorResponse::RecvHandler(Rendezvous::ParsedKey parsed,
// The tensor must be copied from GPU to CPU, because either:
// 1. The tensor is located on a non GDR compatible GPU.
// 2. The tensor's meta-data has changed.
- Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
copy = Tensor(alloc, in.dtype(), in.shape());
CountCopies(rm_.name_, (void*)DMAHelper::base(&in),
(void*)DMAHelper::base(&copy), in.TotalBytes(), true);
@@ -1541,7 +1541,7 @@ bool RdmaTensorRequest::AllocateTensors() {
if (mr_ == nullptr) {
// Can't RDMA directly to result. Use a proxy.
proxy_tensor_ =
- new Tensor(ProcessState::singleton()->GetCUDAHostAllocator(0),
+ new Tensor(GPUProcessState::singleton()->GetCUDAHostAllocator(0),
result_tensor_->dtype(), result_tensor_->shape());
rdma_addr_ = DMAHelper::base(proxy_tensor_);
mr_ =
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 369bd986df..9cb3d1fbbf 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
@@ -282,7 +283,7 @@ void RdmaMgr::InitAllocators() {
Allocator* allocators[] = {
#if GOOGLE_CUDA
- ProcessState::singleton()->GetCUDAHostAllocator(0),
+ GPUProcessState::singleton()->GetCUDAHostAllocator(0),
ProcessState::singleton()->GetCPUAllocator(0),
#endif // GOOGLE_CUDA
cpu_allocator(),
@@ -323,7 +324,8 @@ void RdmaMgr::InitAllocators() {
std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
&RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
- ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
}
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index c976079350..dbe87a6dbb 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -72,24 +72,24 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
+ "cc_header_only_library",
"full_path",
"if_android",
- "if_not_android_mips_and_mips64",
"if_ios",
"if_linux_x86_64",
"if_mobile",
"if_not_mobile",
- "if_windows",
"if_not_windows",
- "tf_copts",
+ "if_windows",
"tf_cc_test",
"tf_cc_tests",
+ "tf_copts",
"tf_cuda_library",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
- "cc_header_only_library",
+ "tf_features_nomodules_if_android",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -113,11 +113,11 @@ load(
"tf_additional_human_readable_json_deps",
"tf_additional_lib_defines",
"tf_additional_lib_deps",
+ "tf_additional_lib_hdrs",
+ "tf_additional_lib_srcs",
"tf_additional_libdevice_data",
"tf_additional_libdevice_deps",
"tf_additional_libdevice_srcs",
- "tf_additional_lib_hdrs",
- "tf_additional_lib_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
"tf_additional_proto_hdrs",
@@ -141,8 +141,8 @@ load(
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
- "tf_cuda_tests_tags",
"if_static",
+ "tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
@@ -150,7 +150,6 @@ load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
)
-load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
exports_files(["ops/ops.pbtxt"])
@@ -234,7 +233,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
protodeps = [
@@ -335,6 +333,7 @@ filegroup(
"platform/init_main.h",
"platform/mem.h",
"platform/mutex.h",
+ "platform/numa.h",
"platform/thread_annotations.h",
],
visibility = ["//visibility:private"],
@@ -793,6 +792,7 @@ tf_cuda_library(
"framework/graph_def_util.h",
"framework/graph_to_functiondef.h",
"framework/kernel_def_builder.h",
+ "framework/kernel_def_util.h",
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
@@ -879,7 +879,7 @@ cc_library(
hdrs = [
"util/stats_calculator.h",
],
- deps = [":platform_base"],
+ copts = tf_copts(),
)
cc_library(
@@ -892,11 +892,26 @@ cc_library(
)
cc_library(
+ name = "exec_on_stall",
+ hdrs = ["util/exec_on_stall.h"],
+ deps = [":framework_lite"],
+)
+
+cc_library(
name = "ptr_util",
hdrs = ["util/ptr_util.h"],
)
cc_library(
+ name = "status_util",
+ hdrs = ["util/status_util.h"],
+ deps = [
+ ":graph",
+ ":lib",
+ ],
+)
+
+cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],
hdrs = ["framework/reader_base.h"],
@@ -993,6 +1008,7 @@ tf_gen_op_libs(
"nn_ops",
"no_op",
"parsing_ops",
+ "random_grad",
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
@@ -1191,6 +1207,7 @@ tf_cuda_library(
hdrs = [
"common_runtime/device.h",
"common_runtime/device_factory.h",
+ "common_runtime/function.h",
"common_runtime/optimization_registry.h",
"common_runtime/shape_refiner.h",
"graph/algorithm.h",
@@ -1245,6 +1262,7 @@ cc_library(
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:grappler",
"//tensorflow/core/kernels:histogram_op",
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
@@ -1443,6 +1461,7 @@ filegroup(
"lib/png/**/*",
"lib/gif/**/*",
"util/events_writer.*",
+ "util/stats_calculator.*",
"util/reporter.*",
"platform/**/cuda_libdevice_path.*",
"platform/default/test_benchmark.*",
@@ -1526,6 +1545,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":protos_all_cc_impl",
+ ":stats_calculator_portable",
"//third_party/eigen3",
"@double_conversion//:double-conversion",
"@nsync//:nsync_cpp",
@@ -1566,6 +1586,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":protos_all_cc_impl",
+ ":stats_calculator_portable",
"//third_party/eigen3",
"@double_conversion//:double-conversion",
"@nsync//:nsync_cpp",
@@ -1931,8 +1952,10 @@ LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob(
"**/*test*",
"lib/gif/**/*",
"lib/jpeg/**/*",
+ "lib/png/**/*",
"platform/gif.h",
"platform/jpeg.h",
+ "platform/png.h",
"platform/**/cuda.h",
"platform/**/stream_executor.h",
],
@@ -2027,6 +2050,7 @@ cc_library(
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
+ "lib/png/**/*",
"platform/**/env_time.cc",
"platform/**/cuda_libdevice_path.cc",
"platform/**/device_tracer.cc",
@@ -2123,6 +2147,39 @@ cc_library(
)
cc_library(
+ name = "png_internal",
+ srcs = ["lib/png/png_io.cc"],
+ hdrs = [
+ "lib/bfloat16/bfloat16.h",
+ "lib/core/casts.h",
+ "lib/core/stringpiece.h",
+ "lib/png/png_io.h",
+ "platform/byte_order.h",
+ "platform/cpu_info.h",
+ "platform/default/integral_types.h",
+ "platform/default/logging.h",
+ "platform/logging.h",
+ "platform/macros.h",
+ "platform/platform.h",
+ "platform/png.h",
+ "platform/types.h",
+ ],
+ copts = tf_copts(),
+ linkopts = select({
+ "//tensorflow:freebsd": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:windows_msvc": [],
+ "//conditions:default": ["-ldl"],
+ }),
+ deps = [
+ ":lib",
+ ":lib_internal",
+ "//tensorflow/core/platform/default/build_config:png",
+ "@zlib_archive//:zlib",
+ ],
+)
+
+cc_library(
name = "tflite_portable_logging",
srcs = [],
hdrs = [
@@ -2230,7 +2287,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
@@ -2252,7 +2308,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
protodeps = [
@@ -2331,6 +2386,7 @@ FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [
FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
+ "framework/resource_var.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -2626,6 +2682,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/dma_helper.h",
"common_runtime/eigen_thread_pool.h",
"common_runtime/executor.h",
+ "common_runtime/executor_factory.h",
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
@@ -2648,6 +2705,8 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
"common_runtime/visitable_allocator.h",
+ "common_runtime/process_state.h",
+ "common_runtime/pool_allocator.h",
"graph/gradients.h",
"graph/quantize_training.h",
] + if_mkl(["graph/mkl_graph_util.h"])
@@ -2675,6 +2734,7 @@ tf_cuda_library(
"common_runtime/device_resolver_local.cc",
"common_runtime/device_set.cc",
"common_runtime/executor.cc",
+ "common_runtime/executor_factory.cc",
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
@@ -2685,7 +2745,9 @@ tf_cuda_library(
"common_runtime/optimization_registry.cc",
"common_runtime/parallel_concat_optimizer.cc",
"common_runtime/placer.cc",
+ "common_runtime/pool_allocator.cc",
"common_runtime/process_function_library_runtime.cc",
+ "common_runtime/process_state.cc",
"common_runtime/process_util.cc",
"common_runtime/renamed_device.cc",
"common_runtime/rendezvous_mgr.cc",
@@ -2872,6 +2934,7 @@ cc_library(
)
GPU_RUNTIME_HEADERS = [
+ "common_runtime/gpu/cuda_host_allocator.h",
"common_runtime/gpu/gpu_bfc_allocator.h",
"common_runtime/gpu/gpu_cudamalloc_allocator.h",
"common_runtime/gpu/gpu_debug_allocator.h",
@@ -2881,10 +2944,9 @@ GPU_RUNTIME_HEADERS = [
"common_runtime/gpu/gpu_id_utils.h",
"common_runtime/gpu/gpu_init.h",
"common_runtime/gpu/gpu_managed_allocator.h",
+ "common_runtime/gpu/gpu_process_state.h",
"common_runtime/gpu/gpu_stream_util.h",
"common_runtime/gpu/gpu_util.h",
- "common_runtime/gpu/pool_allocator.h",
- "common_runtime/gpu/process_state.h",
"common_runtime/gpu_device_context.h",
]
@@ -2897,11 +2959,10 @@ tf_cuda_library(
"common_runtime/gpu/gpu_device.cc",
"common_runtime/gpu/gpu_device_factory.cc",
"common_runtime/gpu/gpu_managed_allocator.cc",
+ "common_runtime/gpu/gpu_process_state.cc",
"common_runtime/gpu/gpu_stream_util.cc",
"common_runtime/gpu/gpu_util.cc",
"common_runtime/gpu/gpu_util_platform_specific.cc",
- "common_runtime/gpu/pool_allocator.cc",
- "common_runtime/gpu/process_state.cc",
],
hdrs = GPU_RUNTIME_HEADERS,
copts = tf_copts(),
@@ -3212,6 +3273,28 @@ tf_cc_test(
)
tf_cc_test(
+ name = "platform_numa_test",
+ size = "small",
+ srcs = ["platform/numa_test.cc"],
+ tags = [
+ # This test will not pass unless it has access to all NUMA nodes
+ # on the executing machine.
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":framework",
+ ":lib",
+ ":lib_internal",
+ ":lib_test_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
name = "platform_setround_test",
size = "small",
srcs = ["platform/setround_test.cc"],
@@ -3257,6 +3340,18 @@ tf_cc_test(
)
tf_cc_test(
+ name = "exec_on_stall_test",
+ size = "small",
+ srcs = ["util/exec_on_stall_test.cc"],
+ deps = [
+ ":exec_on_stall",
+ ":framework_lite",
+ ":test",
+ ":test_main",
+ ],
+)
+
+tf_cc_test(
name = "lib_jpeg_jpeg_mem_unittest",
srcs = ["lib/jpeg/jpeg_mem_unittest.cc"],
data = glob(["lib/jpeg/testdata/*.jpg"]),
@@ -3347,10 +3442,12 @@ tf_cc_tests(
"framework/bfloat16_test.cc",
"framework/cancellation_test.cc",
"framework/common_shape_fns_test.cc",
+ "framework/device_base_test.cc",
"framework/function_test.cc",
"framework/graph_def_util_test.cc",
"framework/graph_to_functiondef_test.cc",
"framework/kernel_def_builder_test.cc",
+ "framework/kernel_def_util_test.cc",
"framework/memory_types_test.cc",
"framework/node_def_builder_test.cc",
"framework/node_def_util_test.cc",
@@ -3375,6 +3472,7 @@ tf_cc_tests(
"framework/variant_op_registry_test.cc",
"framework/variant_test.cc",
"graph/algorithm_test.cc",
+ "graph/control_flow_test.cc",
"graph/edgeset_test.cc",
"graph/graph_def_builder_test.cc",
"graph/graph_partition_test.cc",
@@ -3399,6 +3497,7 @@ tf_cc_tests(
"util/semver_test.cc",
"util/sparse/sparse_tensor_test.cc",
"util/stat_summarizer_test.cc",
+ "util/status_util_test.cc",
"util/tensor_format_test.cc",
"util/tensor_slice_reader_test.cc",
"util/tensor_slice_set_test.cc",
@@ -3423,6 +3522,7 @@ tf_cc_tests(
":ops",
":protos_all_cc",
":protos_test_cc",
+ ":status_util",
":test",
":test_main",
":testlib",
@@ -3558,6 +3658,7 @@ tf_cc_test_mkl(
deps = [
":core",
":core_cpu",
+ ":core_cpu_internal",
":framework",
":framework_internal",
":test",
@@ -3881,13 +3982,13 @@ tf_cc_test(
],
)
-tf_cc_test(
+tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
srcs = ["common_runtime/direct_session_test.cc"],
+ args = [] + if_cuda(["--heap_check=local"]), # The GPU tracer leaks memory
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
- ":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
@@ -3900,6 +4001,7 @@ tf_cc_test(
":test",
":test_main",
":testlib",
+ "//third_party/eigen3",
"//tensorflow/cc:cc_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
@@ -3913,8 +4015,7 @@ tf_cc_test(
"//tensorflow/core/kernels:queue_ops",
"//tensorflow/core/kernels:session_ops",
"//tensorflow/core/kernels:variable_ops",
- "//third_party/eigen3",
- ],
+ ] + if_cuda([":cuda"]),
)
# This is identical to :common_runtime_direct_session_test with the addition of
diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index 19d6438809..06b797e32e 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -4,6 +4,7 @@
# The following targets can be used to access ApiDefs:
# :base_api_def
# :python_api_def
+# :java_api_def
package(
default_visibility = ["//visibility:private"],
@@ -29,6 +30,12 @@ filegroup(
visibility = ["//tensorflow:internal"],
)
+filegroup(
+ name = "java_api_def",
+ srcs = glob(["java_api/*"]),
+ visibility = ["//tensorflow:internal"],
+)
+
cc_library(
name = "excluded_ops_lib",
srcs = ["excluded_ops.cc"],
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index 477a0b670e..ae03a61ae6 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -149,6 +149,33 @@ void TestAllApiDefAttributeNamesAreValid(
}
}
}
+
+void TestDeprecatedAttributesSetCorrectly(
+ const std::unordered_map<string, ApiDef>& api_defs_map) {
+ for (const auto& name_and_api_def : api_defs_map) {
+ int num_deprecated_endpoints = 0;
+ const auto& api_def = name_and_api_def.second;
+ for (const auto& endpoint : api_def.endpoint()) {
+ if (endpoint.deprecated()) {
+ ++num_deprecated_endpoints;
+ }
+ }
+
+ const auto& name = name_and_api_def.first;
+ ASSERT_TRUE(api_def.deprecation_message().empty() ||
+ num_deprecated_endpoints == 0)
+ << "Endpoints are set to 'deprecated' for deprecated op " << name
+ << ". If an op is deprecated (i.e. deprecation_message is set), "
+ << "all the endpoints are deprecated implicitly and 'deprecated' "
+ << "field should not be set.";
+ if (num_deprecated_endpoints > 0) {
+ ASSERT_NE(num_deprecated_endpoints, api_def.endpoint_size())
+ << "All " << name << " endpoints are deprecated. Please, set "
+ << "deprecation_message in api_def_" << name << ".pbtxt instead. "
+ << "to indicate that the op is deprecated.";
+ }
+ }
+}
} // namespace
class BaseApiTest : public ::testing::Test {
@@ -171,7 +198,7 @@ TEST_F(BaseApiTest, AllOpsAreInApiDef) {
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
continue;
}
- ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end())
+ EXPECT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end())
<< op.name() << " op does not have api_def_*.pbtxt file. "
<< "Please add api_def_" << op.name() << ".pbtxt file "
<< "under tensorflow/core/api_def/base_api/ directory.";
@@ -236,6 +263,11 @@ TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) {
TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_);
}
+// Checks that deprecation is set correctly.
+TEST_F(BaseApiTest, DeprecationSetCorrectly) {
+ TestDeprecatedAttributesSetCorrectly(api_defs_map_);
+}
+
class PythonApiTest : public ::testing::Test {
protected:
PythonApiTest() {
@@ -272,4 +304,9 @@ TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) {
TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_);
}
+// Checks that deprecation is set correctly.
+TEST_F(PythonApiTest, DeprecationSetCorrectly) {
+ TestDeprecatedAttributesSetCorrectly(api_defs_map_);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt
new file mode 100644
index 0000000000..0c5b1eb45a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BatchDatasetV2.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "BatchDatasetV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a batch.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ summary: "Creates a dataset that batches `batch_size` elements from `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt
new file mode 100644
index 0000000000..09eff6177b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BatchFunction.pbtxt
@@ -0,0 +1,128 @@
+op {
+ graph_op_name: "BatchFunction"
+ in_arg {
+ name: "in_tensors"
+ description: <<END
+The tensors to be batched.
+END
+ }
+ in_arg {
+ name: "captured_tensors"
+ description: <<END
+The tensors which are captured in the function, and don't need
+to be batched.
+END
+ }
+ out_arg {
+ name: "out_tensors"
+ description: <<END
+The output tensors.
+END
+ }
+ attr {
+ name: "num_batch_threads"
+ description: <<END
+Number of scheduling threads for processing batches of work.
+Determines the number of batches processed in parallel.
+END
+ }
+ attr {
+ name: "max_batch_size"
+ description: <<END
+Batch sizes will never be bigger than this.
+END
+ }
+ attr {
+ name: "batch_timeout_micros"
+ description: <<END
+Maximum number of microseconds to wait before outputting
+an incomplete batch.
+END
+ }
+ attr {
+ name: "max_enqueued_batches"
+ description: <<END
+Maximum number of batches enqueued. Default: 10.
+END
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ description: <<END
+Optional list of allowed batch sizes. If left empty, does
+nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
+batches up to one of those sizes. The entries must increase monotonically, and
+the final entry must equal max_batch_size.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+Controls the scope of sharing of this batch.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+Concurrently running instances of batch in the same device with the
+same container and shared_name will batch their elements together. If left
+empty, the op name will be used as the shared name.
+END
+ }
+ attr {
+ name: "Tin"
+ description: <<END
+the types of tensors to be batched.
+END
+ }
+ attr {
+ name: "Tcaptured"
+ description: <<END
+the types of the captured tensors.
+END
+ }
+ attr {
+ name: "Tout"
+ description: <<END
+the types of the output tensors.
+END
+ }
+ summary: "Batches all the inputs tensors to the computation done by the function."
+ description: <<END
+So, for example, in the following code
+
+ ```python
+
+ # This input will be captured.
+ y = tf.placeholder_with_default(1.0, shape=[])
+
+ @tf.Defun(tf.float32)
+ def computation(a):
+ return tf.matmul(a, a) + y
+
+ b = gen_batch_ops.batch_function(
+ f=computation
+ in_tensors=[a],
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg],
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ allowed_batch_sizes=[3, 10],
+ batching_queue="")
+
+If more than one session.run call is simultaneously trying to compute `b`
+the values of `a` will be gathered, non-deterministically concatenated
+along the first axis, and only one thread will run the computation.
+
+Assumes that all arguments of the function are Tensors which will be batched
+along their first dimension.
+
+Arguments that are captured, are not batched. The session.run call which does
+the concatenation, will use the values of the captured tensors available to it.
+Therefore, typical uses of captured tensors should involve values which remain
+unchanged across session.run calls. Inference is a good example of this.
+
+SparseTensor is not supported. The return value of the decorated function
+must be a Tensor or a list/tuple of Tensors.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt b/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt
new file mode 100644
index 0000000000..08313cebb9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "BesselI0e"
+ summary: "Computes the Bessel i0e function of `x` element-wise."
+ description: <<END
+Exponentially scaled modified Bessel function of order 0 defined as
+`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
+
+This function is faster and numerically stabler than `bessel_i0(x)`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt b/tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt
new file mode 100644
index 0000000000..3e46a9506f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "BesselI1e"
+ summary: "Computes the Bessel i1e function of `x` element-wise."
+ description: <<END
+Exponentially scaled modified Bessel function of order 0 defined as
+`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
+
+This function is faster and numerically stabler than `bessel_i1(x)`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt
new file mode 100644
index 0000000000..b58b974eb4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCenterBias.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "BoostedTreesCenterBias"
+ visibility: HIDDEN
+ in_arg {
+ name: "tree_ensemble_handle"
+ description: <<END
+Handle to the tree ensemble.
+END
+ }
+ in_arg {
+ name: "mean_gradients"
+ description: <<END
+A tensor with shape=[logits_dimension] with mean of gradients for a first node.
+END
+ }
+ in_arg {
+ name: "mean_hessians"
+ description: <<END
+A tensor with shape=[logits_dimension] mean of hessians for a first node.
+END
+ }
+in_arg {
+ name: "l1"
+ description: <<END
+l1 regularization factor on leaf weights, per instance based.
+END
+ }
+ in_arg {
+ name: "l2"
+ description: <<END
+l2 regularization factor on leaf weights, per instance based.
+END
+ }
+ out_arg {
+ name: "continue_centering"
+ description: <<END
+Bool, whether to continue bias centering.
+END
+ }
+ summary: "Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering."
+} \ No newline at end of file
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt
new file mode 100644
index 0000000000..206fa3cc98
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesExampleDebugOutputs.pbtxt
@@ -0,0 +1,36 @@
+op {
+ graph_op_name: "BoostedTreesExampleDebugOutputs"
+ visibility: HIDDEN
+ in_arg {
+ name: "bucketized_features"
+ description: <<END
+A list of rank 1 Tensors containing bucket id for each
+feature.
+END
+ }
+ out_arg {
+ name: "examples_debug_outputs_serialized"
+ description: <<END
+Output rank 1 Tensor containing a proto serialized as a string for each example.
+END
+ }
+ attr {
+ name: "num_bucketized_features"
+ description: <<END
+Inferred.
+END
+ }
+ attr {
+ name: "logits_dimension"
+ description: <<END
+scalar, dimension of the logits, to be used for constructing the protos in
+examples_debug_outputs_serialized.
+END
+ }
+ summary: "Debugging/model interpretability outputs for each example."
+ description: <<END
+It traverses all the trees and computes debug metrics for individual examples,
+such as getting split feature ids and logits after each split along the decision
+path used to compute directional feature contributions.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt
new file mode 100644
index 0000000000..55dd6179dd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DatasetToGraph.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "DatasetToGraph"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the dataset to return the graph representation for.
+END
+ }
+ out_arg {
+ name: "graph"
+ description: <<END
+The graph representation of the dataset (as serialized GraphDef).
+END
+ }
+ summary: "Returns a serialized GraphDef representing `input_dataset`."
+ description: <<END
+Returns a graph representation for `input_dataset`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt b/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt
new file mode 100644
index 0000000000..d110aba42b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_FakeParam.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "FakeParam"
+ visibility: SKIP
+ out_arg {
+ name: "output"
+ description: <<END
+ \"Fake\" output value. This should not be consumed by another op.
+END
+ }
+ attr { name: "dtype" description: "The type of the output." }
+ attr {
+ name: "shape"
+ description: <<END
+ The purported shape of the output. This is only used for shape inference;
+ the output will not necessarily have this shape. Can be a partial shape.
+END
+ }
+ summary: <<END
+ This op is used as a placeholder in If branch functions. It doesn't provide a
+ valid output when run, so must either be removed (e.g. replaced with a
+ function input) or guaranteed not to be used (e.g. if mirroring an
+ intermediate output needed for the gradient computation of the other branch).
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
new file mode 100644
index 0000000000..ffd01ba5cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
@@ -0,0 +1,3 @@
+op {
+ graph_op_name: "FeatureStatsDataset"
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
index 6cd76ff340..342a1f6b05 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
@@ -25,7 +25,7 @@ END
(K-1)-dimensional tensor of indices into `params`, where each element defines a
slice of `params`:
- output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]]
+ output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
Whereas in @{tf.gather} `indices` defines slices into the first
dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
diff --git a/tensorflow/core/api_def/base_api/api_def_GcsConfigureBlockCache.pbtxt b/tensorflow/core/api_def/base_api/api_def_GcsConfigureBlockCache.pbtxt
deleted file mode 100644
index 9d32940c64..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_GcsConfigureBlockCache.pbtxt
+++ /dev/null
@@ -1,9 +0,0 @@
-op {
- graph_op_name: "GcsConfigureBlockCache"
- summary: "Re-configures the GCS block cache with the new configuration values."
- description: <<END
-If the values are the same as already configured values, this op is a no-op. If
-they are different, the current contents of the block cache is dropped, and a
-new block cache is created fresh.
-END
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_GcsConfigureCredentials.pbtxt b/tensorflow/core/api_def/base_api/api_def_GcsConfigureCredentials.pbtxt
deleted file mode 100644
index 786022ae64..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_GcsConfigureCredentials.pbtxt
+++ /dev/null
@@ -1,33 +0,0 @@
-op {
- graph_op_name: "GcsConfigureCredentials"
- summary: "Configures the credentials used by the GCS client of the local TF runtime."
- description: <<END0
-The json input can be of the format:
-
-1. Refresh Token:
-{
- "client_id": "<redacted>",
- "client_secret": "<redacted>",
- "refresh_token: "<redacted>",
- "type": "authorized_user",
-}
-
-2. Service Account:
-{
- "type": "service_account",
- "project_id": "<redacted>",
- "private_key_id": "<redacted>",
- "private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n",
- "client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
- "client_id": "<REDACTED>",
- # Some additional fields elided
-}
-
-Note the credentials established through this method are shared across all
-sessions run on this runtime.
-
-Note be sure to feed the inputs to this op to ensure the credentials are not
-stored in a constant op within the graph that might accidentally be checkpointed
-or in other ways be persisted or exfiltrated.
-END0
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt
new file mode 100644
index 0000000000..747a8badfd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IgammaGradA.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "IgammaGradA"
+ visibility: HIDDEN
+ summary: "Computes the gradient of `igamma(a, x)` wrt `a`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt
new file mode 100644
index 0000000000..9d464b2aea
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorFromStringHandleV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorFromStringHandleV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt
new file mode 100644
index 0000000000..becc729016
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IteratorV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "IteratorV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt b/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt
index 94a4ef574d..f706810662 100644
--- a/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_LinSpace.pbtxt
@@ -3,19 +3,19 @@ op {
in_arg {
name: "start"
description: <<END
-First entry in the range.
+0-D tensor. First entry in the range.
END
}
in_arg {
name: "stop"
description: <<END
-Last entry in the range.
+0-D tensor. Last entry in the range.
END
}
in_arg {
name: "num"
description: <<END
-Number of values to generate.
+0-D tensor. Number of values to generate.
END
}
out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
index 0d680f6531..d7b56aec87 100644
--- a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
@@ -18,7 +18,7 @@ END
}
summary: "Computes the matrix exponential of one or more square matrices:"
description: <<END
-exp(A) = \sum_{n=0}^\infty A^n/n!
+\\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
The exponential is computed using a combination of the scaling and squaring
method and the Pade approximation. Details can be founds in:
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixLogarithm.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixLogarithm.pbtxt
index a6c4d0d400..9e80064d15 100644
--- a/tensorflow/core/api_def/base_api/api_def_MatrixLogarithm.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixLogarithm.pbtxt
@@ -20,7 +20,7 @@ END
summary: "Computes the matrix logarithm of one or more square matrices:"
description: <<END
-log(exp(A)) = A
+\\(log(exp(A)) = A\\)
This op is only defined for complex matrices. If A is positive-definite and
real, then casting to a complex matrix, taking the logarithm and casting back
diff --git a/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
new file mode 100644
index 0000000000..180edb15a4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
@@ -0,0 +1,62 @@
+op {
+ graph_op_name: "NonMaxSuppressionWithOverlaps"
+ in_arg {
+ name: "overlaps"
+ description: <<END
+A 2-D float tensor of shape `[num_boxes, num_boxes]` representing
+the n-by-n box overlap values.
+END
+ }
+ in_arg {
+ name: "scores"
+ description: <<END
+A 1-D float tensor of shape `[num_boxes]` representing a single
+score corresponding to each box (each row of boxes).
+END
+ }
+ in_arg {
+ name: "max_output_size"
+ description: <<END
+A scalar integer tensor representing the maximum number of
+boxes to be selected by non max suppression.
+END
+ }
+ in_arg {
+ name: "overlap_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding whether
+boxes overlap too.
+END
+ }
+ in_arg {
+ name: "score_threshold"
+ description: <<END
+A 0-D float tensor representing the threshold for deciding when to remove
+boxes based on score.
+END
+ }
+ out_arg {
+ name: "selected_indices"
+ description: <<END
+A 1-D integer tensor of shape `[M]` representing the selected
+indices from the boxes tensor, where `M <= max_output_size`.
+END
+ }
+ summary: "Greedily selects a subset of bounding boxes in descending order of score,"
+ description: <<END
+pruning away boxes that have high overlaps
+with previously selected boxes. Bounding boxes with score less than
+`score_threshold` are removed. N-by-n overlap values are supplied as square matrix,
+which allows for defining a custom overlap criterium (eg. intersection over union,
+intersection over area, etc.).
+
+The output of this operation is a set of integers indexing into the input
+collection of bounding boxes representing the selected boxes. The bounding
+box coordinates corresponding to the selected indices can then be obtained
+using the `tf.gather operation`. For example:
+
+ selected_indices = tf.image.non_max_suppression_with_overlaps(
+ overlaps, scores, max_output_size, overlap_threshold, score_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt
new file mode 100644
index 0000000000..f26eb6e3c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_OptimizeDataset.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "OptimizeDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "optimizations"
+ description: <<END
+A `tf.string` vector `tf.Tensor` identifying optimizations to use.
+END
+ }
+ summary: "Creates a dataset by applying optimizations to `input_dataset`."
+ description: <<END
+Creates a dataset by applying optimizations to `input_dataset`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt
new file mode 100644
index 0000000000..9fefc0c418
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PaddedBatchDatasetV2.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "PaddedBatchDatasetV2"
+ visibility: HIDDEN
+ in_arg {
+ name: "batch_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a
+batch.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether the last batch should be dropped in case its size
+is smaller than desired.
+END
+ }
+ in_arg {
+ name: "padded_shapes"
+ description: <<END
+A list of int64 tensors representing the desired padded shapes
+of the corresponding output components. These shapes may be partially
+specified, using `-1` to indicate that a particular dimension should be
+padded to the maximum size of all batch elements.
+END
+ }
+ in_arg {
+ name: "padding_values"
+ description: <<END
+A list of scalars containing the padding value to use for
+each of the outputs.
+END
+ }
+ summary: "Creates a dataset that batches and pads `batch_size` elements from the input."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
index 41a9cfaa27..9b500d0b58 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
@@ -44,6 +44,7 @@ END
summary: "Quantizes then dequantizes a tensor."
description: <<END
This op simulates the precision loss from the quantized forward pass by:
+
1. Quantizing the tensor to fixed point numbers, which should match the target
quantization method when it is used in inference.
2. Dequantizing it back to floating point numbers for the following ops, most
@@ -85,9 +86,9 @@ e.g.
10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it
would update input_min to be 128.0 / 12.7 = -10.07874
* if the output is unsigned, input_min is forced to be 0, and only the
- specifide input_max is used.
+ specified input_max is used.
-After determining the scale_factor and updating the input tange, it applies the
+After determining the scale_factor and updating the input range, it applies the
following to each value in the 'input' tensor.
output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor.
diff --git a/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt
new file mode 100644
index 0000000000..d2bd76f8b9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RandomGammaGrad.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "RandomGammaGrad"
+ visibility: HIDDEN
+ summary: "Computes the derivative of a Gamma random sample w.r.t. `alpha`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
index d13866ddaa..b447d09377 100644
--- a/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceJoin.pbtxt
@@ -36,7 +36,7 @@ END
summary: "Joins a string Tensor across the given dimensions."
description: <<END
Computes the string join across dimensions in the given string Tensor of shape
-`[d_0, d_1, ..., d_n-1]`. Returns a new Tensor created by joining the input
+`[\\(d_0, d_1, ..., d_{n-1}\\)]`. Returns a new Tensor created by joining the input
strings with the given separator (default: empty string). Negative indices are
counted backwards from the end, with `-1` being equivalent to `n - 1`. If
indices are not specified, joins across all dimensions beginning from `n - 1`
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..3b3a274df5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdAdd.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ResourceScatterNdAdd"
+ in_arg {
+ name: "ref"
+ description: <<END
+A resource handle. Must be from a VarHandleOp.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A Tensor. Must be one of the following types: int32, int64.
+A tensor of indices into ref.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A Tensor. Must have the same type as ref. A tensor of
+values to add to ref.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+An optional bool. Defaults to True. If True, the assignment will
+be protected by a lock; otherwise the behavior is undefined,
+but may exhibit less contention.
+END
+ }
+ summary: "Adds sparse `updates` to individual values or slices within a given"
+ description: <<END
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be integer tensor, containing indices into `ref`.
+It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements to a rank-1 tensor to
+8 elements. In Python, that update would look like this:
+
+```python
+ ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_add(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(update)
+```
+
+The resulting update to ref would look like this:
+
+ [1, 12, 3, 14, 14, 6, 7, 20]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
index 6f1121dd37..5ab5917bd3 100644
--- a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBox.pbtxt
@@ -68,7 +68,7 @@ END
name: "area_range"
description: <<END
The cropped area of the image must contain a fraction of the
-supplied image within in this range.
+supplied image within this range.
END
}
attr {
diff --git a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
index 473aec50aa..663fc582d4 100644
--- a/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SampleDistortedBoundingBoxV2.pbtxt
@@ -68,7 +68,7 @@ END
name: "area_range"
description: <<END
The cropped area of the image must contain a fraction of the
-supplied image within in this range.
+supplied image within this range.
END
}
attr {
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdAdd.pbtxt
index b0665ebf0e..a9a7646314 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdAdd.pbtxt
@@ -42,7 +42,7 @@ within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
-It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+It must be shape `\\([d_0, ..., d_{Q-2}, K]\\)` where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
@@ -50,9 +50,7 @@ dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
-```
-[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-```
+$$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$
For example, say we want to add 4 scattered elements to a rank-1 tensor to 8
elements. In Python, that addition would look like this:
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdNonAliasingAdd.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdNonAliasingAdd.pbtxt
index e5c64c2b90..35116e5f6a 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdNonAliasingAdd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdNonAliasingAdd.pbtxt
@@ -37,7 +37,7 @@ respect to both `input` and `updates`.
`input` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `input`.
-It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or `(P-K)`-dimensional slices
@@ -45,9 +45,7 @@ indices into elements (if `K = P`) or `(P-K)`-dimensional slices
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
-```
-[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].
-```
+$$[d_0, ..., d_{Q-2}, input.shape[K], ..., input.shape[P-1]].$$
For example, say we want to add 4 scattered elements to a rank-1 tensor to 8
elements. In Python, that addition would look like this:
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdSub.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdSub.pbtxt
index 333db017f5..99e5c4908b 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdSub.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdSub.pbtxt
@@ -42,7 +42,7 @@ within a given variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
-It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
@@ -50,9 +50,7 @@ dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
-```
-[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-```
+$$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$
For example, say we want to subtract 4 scattered elements from a rank-1 tensor
with 8 elements. In Python, that subtraction would look like this:
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
index 33d98262d5..cb57c171b9 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
@@ -42,7 +42,7 @@ variable according to `indices`.
`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
`indices` must be integer tensor, containing indices into `ref`.
-It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+It must be shape \\([d_0, ..., d_{Q-2}, K]\\) where `0 < K <= P`.
The innermost dimension of `indices` (with length `K`) corresponds to
indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
@@ -50,9 +50,7 @@ dimension of `ref`.
`updates` is `Tensor` of rank `Q-1+P-K` with shape:
-```
-[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
-```
+$$[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].$$
For example, say we want to update 4 scattered elements to a rank-1 tensor to
8 elements. In Python, that update would look like this:
diff --git a/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt b/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt
index cbe76de415..985f09312f 100644
--- a/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Selu.pbtxt
@@ -4,6 +4,10 @@ op {
description: <<END
if < 0, `scale * features` otherwise.
+To be used together with
+`initializer = tf.variance_scaling_initializer(factor=1.0, mode='FAN_IN')`.
+For correct dropout, use `tf.contrib.nn.alpha_dropout`.
+
See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt
new file mode 100644
index 0000000000..b5758ddbfb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SinkDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "SinkDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "A placeholder for input pipeline graph optimizations."
+ description: <<END
+A placeholder for input pipeline graph optimizations.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SlideDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_SlideDataset.pbtxt
index 9fabe7863e..c80ee77f73 100644
--- a/tensorflow/core/api_def/base_api/api_def_SlideDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SlideDataset.pbtxt
@@ -11,7 +11,7 @@ END
name: "stride"
description: <<END
A scalar representing the steps moving the sliding window
-forward in one iteration. It must be in `[1, window_size)`.
+forward in one iteration. It must be positive.
END
}
summary: "Creates a dataset that passes a sliding window over `input_dataset`."
diff --git a/tensorflow/core/api_def/base_api/api_def_Softmax.pbtxt b/tensorflow/core/api_def/base_api/api_def_Softmax.pbtxt
index 43884824c9..b51b468c3d 100644
--- a/tensorflow/core/api_def/base_api/api_def_Softmax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Softmax.pbtxt
@@ -16,6 +16,6 @@ END
description: <<END
For each batch `i` and class `j` we have
- softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
+ $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagrad.pbtxt
index 1698e2def0..06409d8db2 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyAdagrad.pbtxt
@@ -47,7 +47,7 @@ END
summary: "Update relevant entries in \'*var\' and \'*accum\' according to the adagrad scheme."
description: <<END
That is for rows we have grad for, we update var and accum as follows:
-accum += grad * grad
-var -= lr * grad * (1 / sqrt(accum))
+$$accum += grad * grad$$
+$$var -= lr * grad * (1 / sqrt(accum))$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyCenteredRMSProp.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyCenteredRMSProp.pbtxt
index 2c6a36bf45..b3f2d3ea62 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyCenteredRMSProp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyCenteredRMSProp.pbtxt
@@ -83,8 +83,8 @@ mean_square = decay * mean_square + (1-decay) * gradient ** 2
mean_grad = decay * mean_grad + (1-decay) * gradient
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
-ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-var <- var - mom
+$$ms <- rho * ms_{t-1} + (1-rho) * grad * grad$$
+$$mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)$$
+$$var <- var - mom$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyFtrl.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyFtrl.pbtxt
index 524b5c5a47..9a6b6bca5f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyFtrl.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyFtrl.pbtxt
@@ -71,10 +71,10 @@ END
summary: "Update relevant entries in \'*var\' according to the Ftrl-proximal scheme."
description: <<END
That is for rows we have grad for, we update var, accum and linear as follows:
-accum_new = accum + grad * grad
-linear += grad + (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-accum = accum_new
+$$accum_new = accum + grad * grad$$
+$$linear += grad + (accum_{new}^{-lr_{power}} - accum^{-lr_{power}} / lr * var$$
+$$quadratic = 1.0 / (accum_{new}^{lr_{power}} * lr) + 2 * l2$$
+$$var = (sign(linear) * l1 - linear) / quadratic\ if\ |linear| > l1\ else\ 0.0$$
+$$accum = accum_{new}$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyMomentum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyMomentum.pbtxt
index 8d9ac9ea3f..17dbb488de 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyMomentum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyMomentum.pbtxt
@@ -64,7 +64,7 @@ Set use_nesterov = True if you want to use Nesterov momentum.
That is for rows we have grad for, we update var and accum as follows:
-accum = accum * momentum + grad
-var -= lr * accum
+$$accum = accum * momentum + grad$$
+$$var -= lr * accum$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalAdagrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalAdagrad.pbtxt
index 80541b91c7..0b24f2ddd1 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalAdagrad.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalAdagrad.pbtxt
@@ -58,9 +58,9 @@ END
summary: "Sparse update entries in \'*var\' and \'*accum\' according to FOBOS algorithm."
description: <<END
That is for rows we have grad for, we update var and accum as follows:
-accum += grad * grad
-prox_v = var
-prox_v -= lr * grad * (1 / sqrt(accum))
-var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
+$$accum += grad * grad$$
+$$prox_v = var$$
+$$prox_v -= lr * grad * (1 / sqrt(accum))$$
+$$var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalGradientDescent.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalGradientDescent.pbtxt
index 5200e5516d..9dc53860e5 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalGradientDescent.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyProximalGradientDescent.pbtxt
@@ -52,7 +52,7 @@ END
summary: "Sparse update \'*var\' as FOBOS algorithm with fixed learning rate."
description: <<END
That is for rows we have grad for, we update var as follows:
-prox_v = var - alpha * grad
-var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}
+$$prox_v = var - alpha * grad$$
+$$var = sign(prox_v)/(1+alpha*l2) * max{|prox_v|-alpha*l1,0}$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseApplyRMSProp.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseApplyRMSProp.pbtxt
index a4dbd608b8..ee9f57fa9d 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseApplyRMSProp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseApplyRMSProp.pbtxt
@@ -71,8 +71,8 @@ and mom will not update in iterations during which the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-var <- var - mom
+$$ms <- rho * ms_{t-1} + (1-rho) * grad * grad$$
+$$mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)$$
+$$var <- var - mom$$
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseMatMul.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseMatMul.pbtxt
index 58f2ede629..fe568df388 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseMatMul.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseMatMul.pbtxt
@@ -3,9 +3,11 @@ op {
summary: "Multiply matrix \"a\" by matrix \"b\"."
description: <<END
The inputs must be two-dimensional matrices and the inner dimension of "a" must
-match the outer dimension of "b". This op is optimized for the case where at
-least one of "a" or "b" is sparse. The breakeven for using this versus a dense
-matrix multiply on one platform was 30% zero values in the sparse matrix.
+match the outer dimension of "b". Both "a" and "b" must be `Tensor`s not
+`SparseTensor`s. This op is optimized for the case where at least one of "a" or
+"b" is sparse, in the sense that they have a large proportion of zero values.
+The breakeven for using this versus a dense matrix multiply on one platform was
+30% zero values in the sparse matrix.
The gradient computation of this operation will only take advantage of sparsity
in the input gradient when that gradient comes from a Relu.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSliceGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSliceGrad.pbtxt
new file mode 100644
index 0000000000..51af6adcf1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSliceGrad.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "SparseSliceGrad"
+ in_arg {
+ name: "backprop_val_grad"
+ description: <<END
+1-D. The gradient with respect to
+the non-empty values of the sliced `SparseTensor`.
+END
+ }
+ in_arg {
+ name: "input_indices"
+ description: <<END
+2-D. The `indices` of the input `SparseTensor`.
+END
+ }
+ in_arg {
+ name: "input_start"
+ description: <<END
+1-D. tensor represents the start of the slice.
+END
+ }
+ in_arg {
+ name: "output_indices"
+ description: <<END
+2-D. The `indices` of the sliced `SparseTensor`.
+END
+ }
+ out_arg {
+ name: "val_grad"
+ description: <<END
+1-D. The gradient with respect to the non-empty values of input `SparseTensor`.
+END
+ }
+ summary: "The gradient operator for the SparseSlice op."
+ description: <<END
+This op takes in the upstream gradient w.r.t. non-empty values of
+the sliced `SparseTensor`, and outputs the gradients w.r.t.
+the non-empty values of input `SparseTensor`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StatefulPartitionedCall.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatefulPartitionedCall.pbtxt
new file mode 100644
index 0000000000..c4cb4e362a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatefulPartitionedCall.pbtxt
@@ -0,0 +1,25 @@
+
+op {
+ graph_op_name: "StatefulPartitionedCall"
+ in_arg {
+ name: "args"
+ description: "A list of input tensors."
+ }
+ out_arg {
+ name: "output"
+ description: "A list of return values."
+ }
+ attr { name: "Tin" description: "A list of input types." }
+ attr { name: "Tout" description: "A list of output types." }
+ attr {
+ name: "f"
+ description: <<END
+ A function that takes 'args', a list of tensors, and returns 'output',
+ another list of tensors. Input and output types are specified by 'Tin'
+ and 'Tout'. The function body of f will be placed and partitioned across
+ devices, setting this op apart from the regular Call op. This op is
+ stateful.
+END
+ }
+ summary: "returns `f(inputs)`, where `f`'s body is placed and partitioned."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt
new file mode 100644
index 0000000000..6e13d0d049
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringSplitV2.pbtxt
@@ -0,0 +1,48 @@
+op {
+ graph_op_name: "StringSplitV2"
+ in_arg {
+ name: "input"
+ description: <<END
+`1-D` string `Tensor`, the strings to split.
+END
+ }
+ in_arg {
+ name: "sep"
+ description: <<END
+`0-D` string `Tensor`, the delimiter character.
+END
+ }
+ attr {
+ name: "maxsplit"
+ description: <<END
+An `int`. If `maxsplit > 0`, limit of the split of the result.
+END
+ }
+ summary: "Split elements of `source` based on `sep` into a `SparseTensor`."
+ description: <<END
+Let N be the size of source (typically N will be the batch size). Split each
+element of `source` based on `sep` and return a `SparseTensor`
+containing the split tokens. Empty tokens are ignored.
+
+For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+then the output will be
+```
+st.indices = [0, 0;
+ 0, 1;
+ 1, 0;
+ 1, 1;
+ 1, 2]
+st.shape = [2, 3]
+st.values = ['hello', 'world', 'a', 'b', 'c']
+```
+
+If `sep` is given, consecutive delimiters are not grouped together and are
+deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+string, consecutive whitespace are regarded as a single separator, and the
+result will contain no empty strings at the startor end if the string has
+leading or trailing whitespace.
+
+Note that the above mentioned behavior matches python's str.split.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt
new file mode 100644
index 0000000000..dd37b94ffa
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorArrayGradWithShape.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "TensorArrayGradWithShape"
+ endpoint {
+ name: "TensorArrayGradWithShape"
+ }
+ in_arg {
+ name: "handle"
+ description: <<END
+The handle to the forward TensorArray.
+END
+ }
+ in_arg {
+ name: "flow_in"
+ description: <<END
+A float scalar that enforces proper chaining of operations.
+END
+ }
+ in_arg {
+ name: "shape_to_prepend"
+ description: <<END
+An int32 vector representing a shape. Elements in the gradient accumulator will
+have shape which is this shape_to_prepend value concatenated with shape of the
+elements in the TensorArray corresponding to the input handle.
+END
+ }
+ attr {
+ name: "source"
+ description: <<END
+The gradient source string, used to decide which gradient TensorArray
+to return.
+END
+ }
+ summary: "Creates a TensorArray for storing multiple gradients of values in the given handle."
+ description: <<END
+Similar to TensorArrayGradV3. However it creates an accumulator with an
+expanded shape compared to the input TensorArray whose gradient is being
+computed. This enables multiple gradients for the same TensorArray to be
+calculated using the same accumulator.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index eb5d0d1247..9aeabd030d 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -20,7 +20,7 @@ Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
segments.
Computes a tensor such that
-`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such
+\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
need not be sorted and need not cover all values in the full
range of valid values.
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
new file mode 100644
index 0000000000..1bc3660479
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -0,0 +1,11 @@
+op {
+ visibility: HIDDEN
+ graph_op_name: "WindowDataset"
+ in_arg {
+ name: "window_size"
+ description: <<END
+A scalar representing the number of elements to accumulate in a window.
+END
+ }
+ summary: "A dataset that creates window datasets from the input dataset."
+}
diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc
index 07ac974ff9..931c943dbc 100644
--- a/tensorflow/core/api_def/excluded_ops.cc
+++ b/tensorflow/core/api_def/excluded_ops.cc
@@ -20,7 +20,8 @@ namespace tensorflow {
const std::unordered_set<std::string>* GetExcludedOps() {
static std::unordered_set<std::string>* excluded_ops =
new std::unordered_set<std::string>(
- {"BigQueryReader", "GenerateBigQueryReaderPartitions"});
+ {"BigQueryReader", "GenerateBigQueryReaderPartitions",
+ "GcsConfigureBlockCache", "GcsConfigureCredentials"});
return excluded_ops;
}
} // namespace tensorflow
diff --git a/tensorflow/core/api_def/java_api/api_def_Assert.pbtxt b/tensorflow/core/api_def/java_api/api_def_Assert.pbtxt
new file mode 100644
index 0000000000..b1f868897d
--- /dev/null
+++ b/tensorflow/core/api_def/java_api/api_def_Assert.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Assert" #TODO(karllessard) escape that reserved name
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/java_api/api_def_Const.pbtxt b/tensorflow/core/api_def/java_api/api_def_Const.pbtxt
new file mode 100644
index 0000000000..2dbdca34e0
--- /dev/null
+++ b/tensorflow/core/api_def/java_api/api_def_Const.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Const" #TODO(karllessard) escape that reserved name
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/java_api/api_def_Switch.pbtxt b/tensorflow/core/api_def/java_api/api_def_Switch.pbtxt
new file mode 100644
index 0000000000..0d3362a91e
--- /dev/null
+++ b/tensorflow/core/api_def/java_api/api_def_Switch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Switch" #TODO(karllessard) escape that reserved name
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
new file mode 100644
index 0000000000..1fd8baf05f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Acos.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Acos"
+ endpoint {
+ name: "math.acos"
+ }
+ endpoint {
+ name: "acos"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
new file mode 100644
index 0000000000..f7946652ef
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Acosh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Acosh"
+ endpoint {
+ name: "math.acosh"
+ }
+ endpoint {
+ name: "acosh"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Add.pbtxt b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
new file mode 100644
index 0000000000..fb505a91ac
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Add.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Add"
+ endpoint {
+ name: "math.add"
+ }
+ endpoint {
+ name: "add"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
new file mode 100644
index 0000000000..ea65543a76
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_AsString.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "AsString"
+ endpoint {
+ name: "dtypes.as_string"
+ }
+ endpoint {
+ name: "as_string"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
new file mode 100644
index 0000000000..eedf4553c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Asin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Asin"
+ endpoint {
+ name: "math.asin"
+ }
+ endpoint {
+ name: "asin"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
new file mode 100644
index 0000000000..10c2fb356e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Asinh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Asinh"
+ endpoint {
+ name: "math.asinh"
+ }
+ endpoint {
+ name: "asinh"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
new file mode 100644
index 0000000000..03dd5dc848
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atan"
+ endpoint {
+ name: "math.atan"
+ }
+ endpoint {
+ name: "atan"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
new file mode 100644
index 0000000000..85b27bd881
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atan2.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atan2"
+ endpoint {
+ name: "math.atan2"
+ }
+ endpoint {
+ name: "atan2"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
new file mode 100644
index 0000000000..ee7c0600d6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Atanh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Atanh"
+ endpoint {
+ name: "math.atanh"
+ }
+ endpoint {
+ name: "atanh"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
new file mode 100644
index 0000000000..9552fc92e3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "BatchToSpaceND"
+ endpoint {
+ name: "manip.batch_to_space_nd"
+ }
+ endpoint {
+ name: "batch_to_space_nd"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BesselI0e.pbtxt b/tensorflow/core/api_def/python_api/api_def_BesselI0e.pbtxt
new file mode 100644
index 0000000000..7965af4916
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_BesselI0e.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "BesselI0e"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BesselI1e.pbtxt b/tensorflow/core/api_def/python_api/api_def_BesselI1e.pbtxt
new file mode 100644
index 0000000000..dffd296f6d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_BesselI1e.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "BesselI1e"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
new file mode 100644
index 0000000000..7ad7cbcba9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Betainc.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Betainc"
+ endpoint {
+ name: "math.betainc"
+ }
+ endpoint {
+ name: "betainc"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt b/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt
deleted file mode 100644
index 083eeced81..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_BroadcastTo.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "BroadcastTo"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
new file mode 100644
index 0000000000..f2265bad56
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Ceil.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Ceil"
+ endpoint {
+ name: "math.ceil"
+ }
+ endpoint {
+ name: "ceil"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
new file mode 100644
index 0000000000..541b09a591
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CheckNumerics.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "CheckNumerics"
+ endpoint {
+ name: "debugging.check_numerics"
+ }
+ endpoint {
+ name: "check_numerics"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
index 2676c92bfb..942f4e6ed8 100644
--- a/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Cholesky.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "Cholesky"
endpoint {
- name: "cholesky"
+ name: "linalg.cholesky"
}
endpoint {
- name: "linalg.cholesky"
+ name: "cholesky"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
new file mode 100644
index 0000000000..1af8c0c2c9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cos.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cos"
+ endpoint {
+ name: "math.cos"
+ }
+ endpoint {
+ name: "cos"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
new file mode 100644
index 0000000000..2de87df40d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cosh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cosh"
+ endpoint {
+ name: "math.cosh"
+ }
+ endpoint {
+ name: "cosh"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
new file mode 100644
index 0000000000..e8a871cae6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Cross.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Cross"
+ endpoint {
+ name: "linalg.cross"
+ }
+ endpoint {
+ name: "cross"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
new file mode 100644
index 0000000000..8b96eee631
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeBase64.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeBase64"
+ endpoint {
+ name: "io.decode_base64"
+ }
+ endpoint {
+ name: "decode_base64"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
new file mode 100644
index 0000000000..829608fc8f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeCompressed.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeCompressed"
+ endpoint {
+ name: "io.decode_compressed"
+ }
+ endpoint {
+ name: "decode_compressed"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
new file mode 100644
index 0000000000..9f28bc5f59
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeJSONExample.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeJSONExample"
+ endpoint {
+ name: "io.decode_json_example"
+ }
+ endpoint {
+ name: "decode_json_example"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
new file mode 100644
index 0000000000..0010a59ca4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeRaw.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DecodeRaw"
+ endpoint {
+ name: "io.decode_raw"
+ }
+ endpoint {
+ name: "decode_raw"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
new file mode 100644
index 0000000000..5edd0c216b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Dequantize.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Dequantize"
+ endpoint {
+ name: "quantization.dequantize"
+ }
+ endpoint {
+ name: "dequantize"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
new file mode 100644
index 0000000000..cba30e63e8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Diag.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Diag"
+ endpoint {
+ name: "linalg.tensor_diag"
+ }
+ endpoint {
+ name: "diag"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
new file mode 100644
index 0000000000..54e1f34e82
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DiagPart.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "DiagPart"
+ endpoint {
+ name: "linalg.tensor_diag_part"
+ }
+ endpoint {
+ name: "diag_part"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
new file mode 100644
index 0000000000..91b4dfead7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Digamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Digamma"
+ endpoint {
+ name: "math.digamma"
+ }
+ endpoint {
+ name: "digamma"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
new file mode 100644
index 0000000000..71bb73cfb2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EncodeBase64.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "EncodeBase64"
+ endpoint {
+ name: "io.encode_base64"
+ }
+ endpoint {
+ name: "encode_base64"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
new file mode 100644
index 0000000000..78aa1b3bc5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Equal.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Equal"
+ endpoint {
+ name: "math.equal"
+ }
+ endpoint {
+ name: "equal"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
new file mode 100644
index 0000000000..e96df0c596
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Erfc.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Erfc"
+ endpoint {
+ name: "math.erfc"
+ }
+ endpoint {
+ name: "erfc"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
new file mode 100644
index 0000000000..70323fe5b4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Exp.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Exp"
+ endpoint {
+ name: "math.exp"
+ }
+ endpoint {
+ name: "exp"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
new file mode 100644
index 0000000000..8ddf9d4d70
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Expm1.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Expm1"
+ endpoint {
+ name: "math.expm1"
+ }
+ endpoint {
+ name: "expm1"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
new file mode 100644
index 0000000000..f008b1222d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ExtractImagePatches"
+ endpoint {
+ name: "image.extract_image_patches"
+ }
+ endpoint {
+ name: "extract_image_patches"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
index 3bcab99415..d79e936b71 100644
--- a/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_FFT.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "FFT"
endpoint {
- name: "fft"
+ name: "spectral.fft"
}
endpoint {
- name: "spectral.fft"
+ name: "fft"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeParam.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeParam.pbtxt
new file mode 100644
index 0000000000..57fa8ff5b9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeParam.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "FakeParam"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
new file mode 100644
index 0000000000..d8db83331f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgs.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxArgs"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_args"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_args"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
new file mode 100644
index 0000000000..74f01d1a0c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxArgsGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxArgsGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_args_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_args_gradient"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
new file mode 100644
index 0000000000..e14fb6d118
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVars.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVars"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
new file mode 100644
index 0000000000..4611ebdfb8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_gradient"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
new file mode 100644
index 0000000000..0936e513c3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannel.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsPerChannel"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_per_channel"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_per_channel"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
new file mode 100644
index 0000000000..0d9968248c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FakeQuantWithMinMaxVarsPerChannelGradient.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "FakeQuantWithMinMaxVarsPerChannelGradient"
+ endpoint {
+ name: "quantization.fake_quant_with_min_max_vars_per_channel_gradient"
+ }
+ endpoint {
+ name: "fake_quant_with_min_max_vars_per_channel_gradient"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
new file mode 100644
index 0000000000..7f721f4fb7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "FeatureStatsDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
new file mode 100644
index 0000000000..9b93caa0b1
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Floor.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Floor"
+ endpoint {
+ name: "math.floor"
+ }
+ endpoint {
+ name: "floor"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
new file mode 100644
index 0000000000..71257c8855
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "GatherNd"
+ endpoint {
+ name: "manip.gather_nd"
+ }
+ endpoint {
+ name: "gather_nd"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
new file mode 100644
index 0000000000..7de60d44c4
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Greater.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Greater"
+ endpoint {
+ name: "math.greater"
+ }
+ endpoint {
+ name: "greater"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
new file mode 100644
index 0000000000..9c8975c2a9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_GreaterEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "GreaterEqual"
+ endpoint {
+ name: "math.greater_equal"
+ }
+ endpoint {
+ name: "greater_equal"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
index 6bbc4ed720..17fbd8ace4 100644
--- a/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_IFFT.pbtxt
@@ -1,9 +1,10 @@
op {
graph_op_name: "IFFT"
endpoint {
- name: "ifft"
+ name: "spectral.ifft"
}
endpoint {
- name: "spectral.ifft"
+ name: "ifft"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
new file mode 100644
index 0000000000..8c4815c26e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Igamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Igamma"
+ endpoint {
+ name: "math.igamma"
+ }
+ endpoint {
+ name: "igamma"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
new file mode 100644
index 0000000000..b43b54391b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Igammac.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Igammac"
+ endpoint {
+ name: "math.igammac"
+ }
+ endpoint {
+ name: "igammac"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
new file mode 100644
index 0000000000..d75fcd63e3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_InvertPermutation.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "InvertPermutation"
+ endpoint {
+ name: "math.invert_permutation"
+ }
+ endpoint {
+ name: "invert_permutation"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
new file mode 100644
index 0000000000..27142644bf
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsFinite.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsFinite"
+ endpoint {
+ name: "debugging.is_finite"
+ }
+ endpoint {
+ name: "is_finite"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
new file mode 100644
index 0000000000..4cd92f1cb7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsInf.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsInf"
+ endpoint {
+ name: "debugging.is_inf"
+ }
+ endpoint {
+ name: "is_inf"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
new file mode 100644
index 0000000000..07d49f9436
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_IsNan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "IsNan"
+ endpoint {
+ name: "debugging.is_nan"
+ }
+ endpoint {
+ name: "is_nan"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Less.pbtxt b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
new file mode 100644
index 0000000000..055df2922a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Less.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Less"
+ endpoint {
+ name: "math.less"
+ }
+ endpoint {
+ name: "less"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
new file mode 100644
index 0000000000..d2803ddb69
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LessEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LessEqual"
+ endpoint {
+ name: "math.less_equal"
+ }
+ endpoint {
+ name: "less_equal"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
new file mode 100644
index 0000000000..0262b838ca
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Lgamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Lgamma"
+ endpoint {
+ name: "math.lgamma"
+ }
+ endpoint {
+ name: "lgamma"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
new file mode 100644
index 0000000000..26d2473b9c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Log.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Log"
+ endpoint {
+ name: "math.log"
+ }
+ endpoint {
+ name: "log"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
new file mode 100644
index 0000000000..d85b6dccec
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Log1p.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Log1p"
+ endpoint {
+ name: "math.log1p"
+ }
+ endpoint {
+ name: "log1p"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
new file mode 100644
index 0000000000..80bd98b740
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalAnd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalAnd"
+ endpoint {
+ name: "math.logical_and"
+ }
+ endpoint {
+ name: "logical_and"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
new file mode 100644
index 0000000000..b2244c44b1
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalNot.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalNot"
+ endpoint {
+ name: "math.logical_not"
+ }
+ endpoint {
+ name: "logical_not"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
new file mode 100644
index 0000000000..cf78b52e07
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_LogicalOr.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "LogicalOr"
+ endpoint {
+ name: "math.logical_or"
+ }
+ endpoint {
+ name: "logical_or"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
new file mode 100644
index 0000000000..74145670a8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_MatchingFiles.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "MatchingFiles"
+ endpoint {
+ name: "io.matching_files"
+ }
+ endpoint {
+ name: "matching_files"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
index 89b1c1f5a9..1122c52ab4 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixBandPart.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_band_part"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
index 4d289f542f..9563bf0354 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDeterminant.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_determinant"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
index fd9d34635e..8ab0bf75eb 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiag.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_diag"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
index fa5d1f10af..82ce67853c 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPart.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_diag_part"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
index c0ddd73704..85862f6eb5 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixInverse.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_inverse"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
index 01f4f0e89d..6325e4f0e6 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiag.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_set_diag"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
index cef763e4e9..6325dff407 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSolve.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_solve"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
index a0d576aa31..7f865e23b2 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixTriangularSolve.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "matrix_triangular_solve"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
new file mode 100644
index 0000000000..bcff379b71
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Maximum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Maximum"
+ endpoint {
+ name: "math.maximum"
+ }
+ endpoint {
+ name: "maximum"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
new file mode 100644
index 0000000000..9aae74226a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Minimum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Minimum"
+ endpoint {
+ name: "math.minimum"
+ }
+ endpoint {
+ name: "minimum"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
new file mode 100644
index 0000000000..0d358dff98
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NonMaxSuppressionWithOverlaps.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "NonMaxSuppressionWithOverlaps"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
new file mode 100644
index 0000000000..f37317854f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_NotEqual.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "NotEqual"
+ endpoint {
+ name: "math.not_equal"
+ }
+ endpoint {
+ name: "not_equal"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
new file mode 100644
index 0000000000..10b3aab0c7
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseTensor.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ParseTensor"
+ endpoint {
+ name: "io.parse_tensor"
+ }
+ endpoint {
+ name: "parse_tensor"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
new file mode 100644
index 0000000000..9df81402d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Polygamma.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Polygamma"
+ endpoint {
+ name: "math.polygamma"
+ }
+ endpoint {
+ name: "polygamma"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
index b19da0d817..0260eecc91 100644
--- a/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Qr.pbtxt
@@ -5,5 +5,6 @@ op {
}
endpoint {
name: "qr"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
new file mode 100644
index 0000000000..69404b9472
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_QuantizedConcat.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "QuantizedConcat"
+ endpoint {
+ name: "quantization.quantized_concat"
+ }
+ endpoint {
+ name: "quantized_concat"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
new file mode 100644
index 0000000000..9d479be45f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ReadFile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ReadFile"
+ endpoint {
+ name: "io.read_file"
+ }
+ endpoint {
+ name: "read_file"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
new file mode 100644
index 0000000000..c4d4c27722
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Reciprocal.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Reciprocal"
+ endpoint {
+ name: "math.reciprocal"
+ }
+ endpoint {
+ name: "reciprocal"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
new file mode 100644
index 0000000000..b17806b338
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "RegexReplace"
+ endpoint {
+ name: "strings.regex_replace"
+ }
+ endpoint {
+ name: "regex_replace"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
new file mode 100644
index 0000000000..c469665b66
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Reshape"
+ endpoint {
+ name: "manip.reshape"
+ }
+ endpoint {
+ name: "reshape"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..ffef3ab522
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ResourceScatterNdAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ResourceScatterNdAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 8307a3c2dd..77f595927b 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,6 +1,14 @@
op {
graph_op_name: "ReverseV2"
endpoint {
+ name: "manip.reverse"
+ }
+ endpoint {
+ name: "reverse"
+ deprecated: true
+ }
+ endpoint {
name: "reverse_v2"
+ deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
new file mode 100644
index 0000000000..ec37a23127
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Rint.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Rint"
+ endpoint {
+ name: "math.rint"
+ }
+ endpoint {
+ name: "rint"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
new file mode 100644
index 0000000000..4fc2b81421
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Rsqrt.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Rsqrt"
+ endpoint {
+ name: "math.rsqrt"
+ }
+ endpoint {
+ name: "rsqrt"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
new file mode 100644
index 0000000000..a65a19b542
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "ScatterNd"
+ endpoint {
+ name: "manip.scatter_nd"
+ }
+ endpoint {
+ name: "scatter_nd"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt
new file mode 100644
index 0000000000..f6c8af5c33
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNdAdd.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterNdAdd"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
new file mode 100644
index 0000000000..2e22c375c0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMax.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMax"
+ endpoint {
+ name: "math.segment_max"
+ }
+ endpoint {
+ name: "segment_max"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
new file mode 100644
index 0000000000..646348072f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMean.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMean"
+ endpoint {
+ name: "math.segment_mean"
+ }
+ endpoint {
+ name: "segment_mean"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
new file mode 100644
index 0000000000..1a77019a2d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentMin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentMin"
+ endpoint {
+ name: "math.segment_min"
+ }
+ endpoint {
+ name: "segment_min"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
new file mode 100644
index 0000000000..cf4d6f0237
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentProd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentProd"
+ endpoint {
+ name: "math.segment_prod"
+ }
+ endpoint {
+ name: "segment_prod"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
new file mode 100644
index 0000000000..c6d7999455
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SegmentSum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SegmentSum"
+ endpoint {
+ name: "math.segment_sum"
+ }
+ endpoint {
+ name: "segment_sum"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
new file mode 100644
index 0000000000..9c19a1a177
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Sin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Sin"
+ endpoint {
+ name: "math.sin"
+ }
+ endpoint {
+ name: "sin"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
new file mode 100644
index 0000000000..155e58e6d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Sinh.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Sinh"
+ endpoint {
+ name: "math.sinh"
+ }
+ endpoint {
+ name: "sinh"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt b/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
index 2de56c27be..c4da47241b 100644
--- a/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Softplus.pbtxt
@@ -1,6 +1,9 @@
op {
graph_op_name: "Softplus"
endpoint {
+ name: "math.softplus"
+ }
+ endpoint {
name: "nn.softplus"
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt b/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
index b47412d135..852d205024 100644
--- a/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Softsign.pbtxt
@@ -3,4 +3,7 @@ op {
endpoint {
name: "nn.softsign"
}
+ endpoint {
+ name: "math.softsign"
+ }
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
new file mode 100644
index 0000000000..af323a6cf3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SpaceToBatchND"
+ endpoint {
+ name: "manip.space_to_batch_nd"
+ }
+ endpoint {
+ name: "space_to_batch_nd"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SparseSliceGrad.pbtxt b/tensorflow/core/api_def/python_api/api_def_SparseSliceGrad.pbtxt
new file mode 100644
index 0000000000..6ea8df46ec
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SparseSliceGrad.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "SparseSliceGrad"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
new file mode 100644
index 0000000000..4bab8cf00c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_SquaredDifference.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "SquaredDifference"
+ endpoint {
+ name: "math.squared_difference"
+ }
+ endpoint {
+ name: "squared_difference"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatefulPartitionedCall.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatefulPartitionedCall.pbtxt
new file mode 100644
index 0000000000..eb8e3ae902
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatefulPartitionedCall.pbtxt
@@ -0,0 +1 @@
+op { graph_op_name: "StatefulPartitionedCall" visibility: HIDDEN }
diff --git a/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
new file mode 100644
index 0000000000..46a7c0361e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringJoin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringJoin"
+ endpoint {
+ name: "strings.join"
+ }
+ endpoint {
+ name: "string_join"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt
new file mode 100644
index 0000000000..0e8576fb01
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringSplitV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringSplitV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
new file mode 100644
index 0000000000..fbcdeaad6d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringStrip.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringStrip"
+ endpoint {
+ name: "strings.strip"
+ }
+ endpoint {
+ name: "string_strip"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
new file mode 100644
index 0000000000..d122e79b39
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucket.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucket"
+ endpoint {
+ name: "strings.to_hash_bucket"
+ }
+ endpoint {
+ name: "string_to_hash_bucket"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
new file mode 100644
index 0000000000..aef9dffefe
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketFast.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucketFast"
+ endpoint {
+ name: "strings.to_hash_bucket_fast"
+ }
+ endpoint {
+ name: "string_to_hash_bucket_fast"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
new file mode 100644
index 0000000000..385b9fd02a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToHashBucketStrong.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToHashBucketStrong"
+ endpoint {
+ name: "strings.to_hash_bucket_strong"
+ }
+ endpoint {
+ name: "string_to_hash_bucket_strong"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
new file mode 100644
index 0000000000..f740b9849d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringToNumber.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "StringToNumber"
+ endpoint {
+ name: "strings.to_number"
+ }
+ endpoint {
+ name: "string_to_number"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
new file mode 100644
index 0000000000..4778d7927c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Substr"
+ endpoint {
+ name: "strings.substr"
+ }
+ endpoint {
+ name: "substr"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
new file mode 100644
index 0000000000..ffa92f5580
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Tan.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Tan"
+ endpoint {
+ name: "math.tan"
+ }
+ endpoint {
+ name: "tan"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt
new file mode 100644
index 0000000000..5d76c112a0
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorArrayGradWithShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorArrayGradWithShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
new file mode 100644
index 0000000000..c34061c941
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Tile"
+ endpoint {
+ name: "manip.tile"
+ }
+ endpoint {
+ name: "tile"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
new file mode 100644
index 0000000000..cf81843241
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMax.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentMax"
+ endpoint {
+ name: "math.unsorted_segment_max"
+ }
+ endpoint {
+ name: "unsorted_segment_max"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
new file mode 100644
index 0000000000..475361c85a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentMin.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentMin"
+ endpoint {
+ name: "math.unsorted_segment_min"
+ }
+ endpoint {
+ name: "unsorted_segment_min"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
new file mode 100644
index 0000000000..a9d741bbc3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentProd.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentProd"
+ endpoint {
+ name: "math.unsorted_segment_prod"
+ }
+ endpoint {
+ name: "unsorted_segment_prod"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
new file mode 100644
index 0000000000..337678dcff
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnsortedSegmentSum.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "UnsortedSegmentSum"
+ endpoint {
+ name: "math.unsorted_segment_sum"
+ }
+ endpoint {
+ name: "unsorted_segment_sum"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
new file mode 100644
index 0000000000..1a58ae19e5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_WriteFile.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "WriteFile"
+ endpoint {
+ name: "io.write_file"
+ }
+ endpoint {
+ name: "write_file"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
new file mode 100644
index 0000000000..4684a9d624
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Zeta.pbtxt
@@ -0,0 +1,10 @@
+op {
+ graph_op_name: "Zeta"
+ endpoint {
+ name: "math.zeta"
+ }
+ endpoint {
+ name: "zeta"
+ deprecated: true
+ }
+}
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 462d6b7533..3af9286264 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -108,11 +108,11 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
bool peer_is_local, const string& key, Device* to_device,
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality,
+ const DeviceLocality& client_locality, int stream_index,
const StatusCallback& done) override {
- remote_access_->RecvFromPeer(peer_device, peer_task, peer_is_local, key,
- to_device, to_device_ctx, to_alloc_attr,
- to_tensor, client_locality, done);
+ remote_access_->RecvFromPeer(
+ peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
+ to_alloc_attr, to_tensor, client_locality, stream_index, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
index 9646a0856e..46142d5923 100644
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ b/tensorflow/core/common_runtime/broadcaster.cc
@@ -187,7 +187,7 @@ void Broadcaster::RunTree() {
DeviceContext* op_dev_ctx = ctx_->op_device_context();
CollectiveRemoteAccessLocal::MemCpyAsync(
op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
- ctx_->output_alloc_attr(0), input, output_,
+ ctx_->output_alloc_attr(0), input, output_, 0 /*steam_index*/,
[this, &mu, &pending_count, &all_done](const Status& s) {
mutex_lock l(mu);
status_.Update(s);
@@ -239,7 +239,7 @@ void Broadcaster::DispatchRecv(int src_rank, Tensor* dst_tensor,
col_params_.task.is_local[src_idx], recv_buf_key,
device_, ctx_->op_device_context(),
ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, done);
+ device_locality_, 0 /*stream_index*/, done);
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/broadcaster_test.cc
index 959b93d56e..6a163a0db0 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/broadcaster_test.cc
@@ -161,12 +161,12 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
bool peer_is_local, const string& key, Device* to_device,
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality,
+ const DeviceLocality& client_locality, int stream_index,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
- to_alloc_attr, to_tensor, client_locality, done);
+ to_alloc_attr, to_tensor, client_locality, stream_index, done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
diff --git a/tensorflow/core/common_runtime/build_graph_options.cc b/tensorflow/core/common_runtime/build_graph_options.cc
index a9dc6ca6cd..00f7a8e645 100644
--- a/tensorflow/core/common_runtime/build_graph_options.cc
+++ b/tensorflow/core/common_runtime/build_graph_options.cc
@@ -32,6 +32,9 @@ string BuildGraphOptions::DebugString() const {
for (auto& s : callable_options.target()) {
strings::StrAppend(&rv, s, ", ");
}
+ if (collective_graph_key != kNoCollectiveGraphKey) {
+ strings::StrAppend(&rv, "\ncollective_graph_key: ", collective_graph_key);
+ }
return rv;
}
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 5ca170e922..3d0f242ea5 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -31,6 +31,9 @@ struct BuildGraphOptions {
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
+ static const int64 kNoCollectiveGraphKey = 0;
+ int64 collective_graph_key = kNoCollectiveGraphKey;
+
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.cc b/tensorflow/core/common_runtime/collective_executor_mgr.cc
index e07829b286..4f03a5e13a 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.cc
@@ -25,11 +25,11 @@ namespace tensorflow {
CollectiveExecutorMgr::CollectiveExecutorMgr(
const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver)
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver)
: dev_mgr_(dev_mgr),
- dev_resolver_(dev_resolver),
- param_resolver_(param_resolver) {}
+ dev_resolver_(std::move(dev_resolver)),
+ param_resolver_(std::move(param_resolver)) {}
CollectiveExecutorMgr::~CollectiveExecutorMgr() {
for (auto iter : executor_table_) {
@@ -45,9 +45,7 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
if (it != executor_table_.end()) {
ce = it->second;
} else {
- CollectiveRemoteAccessLocal* rma = new CollectiveRemoteAccessLocal(
- dev_mgr_, dev_resolver_.get(), step_id);
- ce = new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+ ce = Create(step_id);
executor_table_[step_id] = ce;
}
ce->Ref();
@@ -55,6 +53,12 @@ CollectiveExecutor* CollectiveExecutorMgr::FindOrCreate(int64 step_id) {
return ce;
}
+CollectiveExecutor* CollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessLocal* rma =
+ new CollectiveRemoteAccessLocal(dev_mgr_, dev_resolver_.get(), step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
void CollectiveExecutorMgr::Cleanup(int64 step_id) {
CollectiveExecutor* ce = nullptr;
{
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 4b42e2b4d1..9de6ab8968 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -25,8 +25,8 @@ class DeviceMgr;
class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
public:
CollectiveExecutorMgr(const ConfigProto& config, const DeviceMgr* dev_mgr,
- DeviceResolverInterface* dev_resolver,
- ParamResolverInterface* param_resolver);
+ std::unique_ptr<DeviceResolverInterface> dev_resolver,
+ std::unique_ptr<ParamResolverInterface> param_resolver);
virtual ~CollectiveExecutorMgr();
@@ -56,11 +56,16 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
void RetireStepId(int64 graph_key, int64 step_id) override {}
protected:
+ // Called by FindOrCreate when table entry does not yet exist.
+ virtual CollectiveExecutor* Create(int64 step_id);
+
const DeviceMgr* dev_mgr_;
std::unique_ptr<DeviceResolverInterface> dev_resolver_;
std::unique_ptr<ParamResolverInterface> param_resolver_;
CollectiveRemoteAccess* remote_access_;
string task_name_;
+
+ private:
mutex exec_mu_;
// Map from step_id to CollectiveExecutor
gtl::FlatMap<int64, CollectiveExecutor*> executor_table_ GUARDED_BY(exec_mu_);
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index 34c9163d6a..91994c5731 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -40,10 +40,13 @@ class CollectiveExecutorMgrTest : public ::testing::Test {
device_count->insert({"CPU", NUM_DEVS});
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
device_mgr_.reset(new DeviceMgr(devices_));
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
- cme_.reset(new CollectiveExecutorMgr(
- cp, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl, task_name)));
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> prl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ task_name));
+ cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
+ std::move(prl)));
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 8b2e0d1e0a..236f999228 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -18,6 +18,10 @@ limitations under the License.
namespace tensorflow {
+void CollectiveParamResolverLocal::InstanceRec::WaitForOutMu(mutex_lock& lock) {
+ while (!out_mu_available) out_cv.wait(lock);
+}
+
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
const string& task_name)
@@ -313,11 +317,14 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
VLOG(1) << "Modified device_names on " << cp;
SetDevPerTask(cp);
}
+} // namespace
// Establish the requested number of subdivision permutations based on the
// ring order implicit in the device order.
-void GenerateSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp) {
+/*static*/
+void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
+ int source_rank,
+ CollectiveParams* cp) {
// Each subdiv permutation is a ring formed by rotating each
// single-task subsequence of devices by an offset. This makes most
// sense when each task has the same number of devices but we can't
@@ -356,15 +363,27 @@ void GenerateSubdivPerms(const string& device, int source_rank,
std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
CHECK_EQ(perm.size(), 0);
int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- int prior_dev_count = 0;
+ // A negative subdivision offset is interpreted as follows:
+ // 1. Reverse the local device ordering.
+ // 2. Begin the subdivision at abs(offset) in the reversed ordering.
+ bool reverse = false;
+ if (offset < 0) {
+ offset = abs(offset);
+ reverse = true;
+ }
+ int prior_dev_count = 0; // sum over prior worker device counts
for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
for (int di = 0; di < dev_per_task[ti]; ++di) {
- int offset_di = (di + offset) % dev_per_task[ti];
+ int di_offset = (di + offset) % dev_per_task[ti];
+ int offset_di =
+ reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
+ // Device index in global subdivision permutation.
int permuted_di = prior_dev_count + offset_di;
+ int rank = static_cast<int>(perm.size());
perm.push_back(permuted_di);
- if (cp->instance.device_names[prior_dev_count + di] == device) {
- CHECK_EQ(prior_dev_count + di, cp->default_rank);
- cp->subdiv_rank[sdi] = permuted_di;
+ if (cp->instance.device_names[permuted_di] == device) {
+ CHECK_EQ(permuted_di, cp->default_rank);
+ cp->subdiv_rank[sdi] = rank;
}
}
prior_dev_count += dev_per_task[ti];
@@ -411,8 +430,6 @@ void GenerateSubdivPerms(const string& device, int source_rank,
}
}
-} // namespace
-
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
@@ -460,11 +477,24 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
// called by a derived class, some of the devices may be non-local and
// GetDeviceLocalitiesAsync will use those fields to launch RPCs.
CompleteTaskIsLocal(task_name_, &ir->shared);
+
+ // Because the callback may execute in a different thread, we release
+ // ir->out_mu here. Before releasing, we mark it as unavailable for other
+ // threads.
+ ir->out_mu_available = false;
+ ir->out_mu.unlock();
std::vector<DeviceLocality>* localities = new std::vector<DeviceLocality>;
dev_resolver_->GetDeviceLocalitiesAsync(
ir->shared.instance, localities,
[this, gr, cp, ir, localities, done](const Status& s)
- EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) {
+ EXCLUSIVE_LOCK_FUNCTION(ir->out_mu) {
+ // Then we recover the lock in the callback thread that will hold it
+ // through the rest of the call chain. Signal the cv now, any
+ // waiting threads will wake only when out_mu is released later.
+ ir->out_mu.lock();
+ DCHECK(!ir->out_mu_available);
+ ir->out_mu_available = true;
+ ir->out_cv.notify_all();
if (s.ok()) {
CompleteDefaultRanking(gr, cp, ir, *localities);
done(Status::OK());
@@ -512,6 +542,7 @@ void CollectiveParamResolverLocal::CallbackWithStatus(
Status s;
{
mutex_lock l(irec->out_mu);
+ irec->WaitForOutMu(l);
s = irec->status;
}
done(s, irec);
@@ -559,21 +590,29 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
// static analysis, so we turn off analysis only within this
// function body.
//
- // A lock on ir->out_mu must be held throughout the _bodies_ of the
+ // A lock on ir->out_mu must be held* throughout the _bodies_ of the
// chain of function calls initiated here, each of which calls
// another as its last action, but it will be dropped within the
// callback defined below, which means that the lock can be dropped
// before all the function stack frames pop. The static analysis will
// not allow that.
+ //
+ // *the lock is dropped just before calling GetDeviceLocalitiesAsync, because
+ // there is no guarantee that the thread that executes the callback is the
+ // same as the one that locked ir->out_mu. To prevent other threads from
+ // grabbing ir->out_mu, we mark ir->out_mu_available as false. Hence, in
+ // principle, the lock is held throughout.
ir->out_mu.lock();
+ DCHECK(ir->out_mu_available);
ir->known.resize(cp->group.group_size, false);
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
DCHECK(!ir->out_mu.try_lock());
+ DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
- // Prepare to invoke any waiters that accumlated during
+ // Prepare to invoke any waiters that accumulated during
// initialization.
std::vector<IRConsumer> init_waiters;
{
@@ -650,6 +689,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// Populate the fields common across instance.
{
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
// custom operator= does a deep copy.
cp->instance = ir->shared.instance;
}
@@ -665,8 +705,9 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
int source_rank;
{
mutex_lock l(irec->out_mu);
+ irec->WaitForOutMu(l);
s = irec->status;
- source_rank = ir->source_rank;
+ source_rank = irec->source_rank;
}
if (s.ok()) {
GenerateSubdivPerms(device, source_rank, cp);
@@ -687,6 +728,7 @@ void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
std::vector<IRConsumer> ready_waiters;
{
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
CHECK_EQ(cp->group.group_size, ir->known.size());
CHECK_GE(cp->default_rank, 0);
if (!ir->known[cp->default_rank]) {
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 3a871f962d..01bdeca7d1 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -88,7 +88,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// permit mutex locks to be taken in more than one order.
//
// out_mu guards access to most of the fields.
- // in_mu guards access to a queue of comsumer callbacks wanting to
+ // in_mu guards access to a queue of consumer callbacks wanting to
// read the fields guarded by out_mu.
//
// The in_mu should be locked only while holding instance_mu_; the
@@ -109,8 +109,12 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
bool is_init GUARDED_BY(in_mu);
std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
- // Values to be shared by all instances, constant after initialization.
+ // A thread that wishes to acquire out_mu must ensure that it is available
+ // by invoking WaitForOutMu().
mutex out_mu;
+ condition_variable out_cv;
+ bool out_mu_available GUARDED_BY(out_mu);
+ // Values to be shared by all instances, constant after initialization.
CollectiveParams shared GUARDED_BY(out_mu);
// If an error occurs during initialization this structure stays in
// the table with a non-OK status. Purging the table and restarting
@@ -124,7 +128,15 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
std::vector<bool> known GUARDED_BY(out_mu);
std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
- InstanceRec() : is_init(false), source_rank(-1), known_count(0) {}
+ InstanceRec()
+ : is_init(false),
+ out_mu_available(true),
+ source_rank(-1),
+ known_count(0) {}
+
+ // If out_mu is unavailable during distributed device locality
+ // initialization, wait on out_cv until it is available again.
+ void WaitForOutMu(mutex_lock& lock) EXCLUSIVE_LOCKS_REQUIRED(out_mu);
};
// Find the InstanceRec with the same instance_key as cp. If it doesn't
@@ -147,7 +159,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// cp is populated with all DeviceLocalities
void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
InstanceRec* ir, const StatusCallback& done)
- EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
+ UNLOCK_FUNCTION(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
void CallInitInstanceSharedParams(const GroupRec* gr,
const CollectiveParams* cp, InstanceRec* ir,
@@ -200,8 +212,12 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);
+ friend class CollectiveParamResolverLocalTest;
+ static void GenerateSubdivPerms(const string& device, int source_rank,
+ CollectiveParams* cp);
+
const DeviceMgr* dev_mgr_;
- DeviceResolverInterface* dev_resolver_;
+ DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
mutex group_mu_;
gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 4e33c4779a..d5be8f927e 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
-namespace {
#define NUM_DEVS 3
@@ -45,6 +44,11 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
task_name));
}
+ void GenSubdivPerms(const string& device, int source_rank,
+ CollectiveParams* cp) {
+ CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
+ }
+
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@@ -147,7 +151,69 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
}
}
-// TEST_F(CollectiveParamResolverLocalTest,
+TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
+ static const int kNumDevsPerTask = 8;
+ static const int kNumTasks = 3;
+ static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp;
+ std::vector<string> device_names;
+ std::vector<string> task_names;
+ cp.group.group_key = 1;
+ cp.group.group_size = kNumDevs;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = kNumTasks;
+ cp.instance.instance_key = 3;
+ cp.instance.type = REDUCTION_COLLECTIVE;
+ cp.instance.data_type = DataType(DT_FLOAT);
+ cp.instance.shape = TensorShape({5});
+ cp.instance.impl_details.subdiv_offsets.push_back(0);
+ cp.is_source = false;
+ for (int i = 0; i < kNumDevs; ++i) {
+ int task_id = i / kNumDevsPerTask;
+ int dev_id = i % kNumDevsPerTask;
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
+ task_names.push_back(task_name);
+ string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
+ device_names.push_back(device_name);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(device_name);
+ }
+
+ int test_rank = 0;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, 4};
+ GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
+ std::vector<int> expected_0 = {0, 1, 2, 3, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23};
+ std::vector<int> expected_1 = {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
+ 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19};
+ for (int i = 0; i < kNumDevs; ++i) {
+ EXPECT_EQ(expected_0[i],
+ cp.instance.impl_details.subdiv_permutations[0][i]);
+ EXPECT_EQ(expected_1[i],
+ cp.instance.impl_details.subdiv_permutations[1][i]);
+ }
+ EXPECT_EQ(0, cp.subdiv_rank[0]);
+ EXPECT_EQ(4, cp.subdiv_rank[1]);
+
+ test_rank = 3;
+ cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {3, -3};
+ cp.instance.impl_details.subdiv_permutations.clear();
+ GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
+ expected_0 = {3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
+ 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18};
+ expected_1 = {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
+ 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21};
+ for (int i = 0; i < kNumDevs; ++i) {
+ EXPECT_EQ(expected_0[i],
+ cp.instance.impl_details.subdiv_permutations[0][i]);
+ EXPECT_EQ(expected_1[i],
+ cp.instance.impl_details.subdiv_permutations[1][i]);
+ }
+ EXPECT_EQ(0, cp.subdiv_rank[0]);
+ EXPECT_EQ(1, cp.subdiv_rank[1]);
+}
-} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc
index 69f1a9f24c..288ae9d794 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local.cc
@@ -27,7 +27,8 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
const string& peer_device, const string& peer_task, bool peer_is_local,
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality, const StatusCallback& done) {
+ const DeviceLocality& client_locality, int dev_to_dev_stream_index,
+ const StatusCallback& done) {
VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
<< key;
if (!peer_is_local) {
@@ -37,8 +38,9 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
return;
}
buf_rendezvous_.ConsumeBuf(
- key, [this, to_tensor, to_device_ctx, to_device, to_alloc_attr, done](
- const Status& s, BufRendezvous::Hook* hook) {
+ key, [this, to_tensor, to_device_ctx, to_device, to_alloc_attr,
+ dev_to_dev_stream_index,
+ done](const Status& s, BufRendezvous::Hook* hook) {
if (!s.ok()) {
done(s);
delete hook;
@@ -53,7 +55,7 @@ void CollectiveRemoteAccessLocal::RecvFromPeer(
to_alloc_attr, // dst AllocatorAttributes
hook->prod_value, // src Tensor*
to_tensor, // dst Tensor*
- [hook, done](const Status& s) {
+ dev_to_dev_stream_index, [hook, done](const Status& s) {
// This callback may be executing in the GPUEventMgr
// pool in which case it must be very short duration
// and non-blocking (except e.g. for queue insertion).
@@ -82,7 +84,7 @@ void CollectiveRemoteAccessLocal::MemCpyAsync(
DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
Device* dst_dev, const AllocatorAttributes& src_attr,
const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
- const StatusCallback& done) {
+ int dev_to_dev_stream_index, const StatusCallback& done) {
// We want a real copy to happen, i.e. the bytes inside of src should be
// transferred to the buffer backing dst. If src and dst are on different
// devices then CopyTensor::ViaDMA will do just that. But if they're both
@@ -115,7 +117,7 @@ void CollectiveRemoteAccessLocal::MemCpyAsync(
if (non_cpu_src || non_cpu_dst) {
CopyTensor::ViaDMA("", // edge name (non-existent)
src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
- dst_attr, src, dst, done);
+ dst_attr, src, dst, dev_to_dev_stream_index, done);
} else {
int64 bytes = src->TotalBytes();
DCHECK_EQ(dst->TotalBytes(), bytes);
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index 716e23bfa1..dbb2e67c7d 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -41,6 +41,7 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) override;
void PostToPeer(const string& peer_device, const string& peer_task,
@@ -77,6 +78,7 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
Device* dst_dev, const AllocatorAttributes& src_attr,
const AllocatorAttributes& dst_attr,
const Tensor* src, Tensor* dst,
+ int dev_to_dev_stream_index,
const StatusCallback& done);
protected:
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index dcd4272d96..a931fe64bd 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -69,6 +69,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) {
rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/,
"key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ 0 /*stream_index*/,
[this, &recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
@@ -111,6 +112,7 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) {
rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/,
"key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
+ 0 /*stream_index*/,
[this, &recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index 08d120c7a5..630b3702c8 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -170,7 +170,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
Device* dst, const AllocatorAttributes src_alloc_attr,
const AllocatorAttributes dst_alloc_attr,
const Tensor* input, Tensor* output,
- StatusCallback done) {
+ int dev_to_dev_stream_index, StatusCallback done) {
if (input->dtype() == DT_VARIANT) {
Tensor copy(cpu_allocator, DT_VARIANT, input->shape());
auto* status_cb = new ReffedStatusCallback(std::move(done));
@@ -182,10 +182,10 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
};
auto copier = std::bind(
[copy_function, src, dst, src_alloc_attr, dst_alloc_attr,
- recv_dev_context, send_dev_context, out_allocator,
- status_cb](StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
+ recv_dev_context, send_dev_context, out_allocator, status_cb,
+ dev_to_dev_stream_index](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
if (!DMAHelper::CanUseDMA(&from)) {
Status err = errors::InvalidArgument(
"During Variant Device->Device Copy: "
@@ -199,7 +199,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
*to = Tensor(out_allocator, from.dtype(), from.shape());
copy_function(send_dev_context, recv_dev_context, src, dst,
src_alloc_attr, dst_alloc_attr, &from, to,
- std::move(wrapped_done_));
+ dev_to_dev_stream_index, std::move(wrapped_done_));
return Status::OK();
} else {
return status_cb->status();
@@ -224,7 +224,8 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function,
}
} else {
copy_function(send_dev_context, recv_dev_context, src, dst, src_alloc_attr,
- dst_alloc_attr, input, output, std::move(done));
+ dst_alloc_attr, input, output, dev_to_dev_stream_index,
+ std::move(done));
}
}
@@ -236,7 +237,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
Device* dst, const AllocatorAttributes src_alloc_attr,
const AllocatorAttributes dst_alloc_attr,
const Tensor* input, Tensor* output,
- StatusCallback done) {
+ int dev_to_dev_stream_index, StatusCallback done) {
tracing::ScopedAnnotation annotation(edge_name);
VLOG(1) << "Copy " << edge_name;
@@ -266,7 +267,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
CopyDeviceToDevice(ri.copy_function, cpu_allocator, out_allocator,
send_dev_context, recv_dev_context, src, dst,
src_alloc_attr, dst_alloc_attr, input, output,
- std::move(done));
+ dev_to_dev_stream_index, std::move(done));
return;
}
}
diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h
index a9d684bf11..9cd5ac2a37 100644
--- a/tensorflow/core/common_runtime/copy_tensor.h
+++ b/tensorflow/core/common_runtime/copy_tensor.h
@@ -28,13 +28,11 @@ namespace tensorflow {
class CopyTensor {
public:
- typedef void (*CopyFunction)(DeviceContext* send_dev_context,
- DeviceContext* recv_dev_context, Device* src,
- Device* dst,
- const AllocatorAttributes src_alloc_attr,
- const AllocatorAttributes dst_alloc_attr,
- const Tensor* input, Tensor* output,
- StatusCallback done);
+ typedef void (*CopyFunction)(
+ DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
+ Device* src, Device* dst, const AllocatorAttributes src_alloc_attr,
+ const AllocatorAttributes dst_alloc_attr, const Tensor* input,
+ Tensor* output, int dev_to_dev_stream_index, StatusCallback done);
// Copies "input" to "output" between devices accessible to the
// local process via some DMA-like method. "edge_name" is the name
@@ -46,7 +44,8 @@ class CopyTensor {
DeviceContext* recv_dev_context, Device* src, Device* dst,
const AllocatorAttributes src_alloc_attr,
const AllocatorAttributes dst_alloc_attr,
- const Tensor* input, Tensor* output, StatusCallback done);
+ const Tensor* input, Tensor* output,
+ int dev_to_dev_stream_index, StatusCallback done);
// Object used to call Register() at static-initialization time.
// Note: This should only ever be used as a global-static object; no stack
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 07c1eafedc..f903faf1bd 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -447,18 +447,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
// Create a run state and start execution.
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
+#ifndef __ANDROID__
// Set up for collectives if the RunOption declares a key.
if (run_options.experimental().collective_graph_key() > 0) {
if (!collective_executor_mgr_) {
- DeviceResolverLocal* drl = new DeviceResolverLocal(device_mgr_.get());
+ std::unique_ptr<DeviceResolverInterface> drl(
+ new DeviceResolverLocal(device_mgr_.get()));
+ std::unique_ptr<ParamResolverInterface> cprl(
+ new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
- options_.config, device_mgr_.get(), drl,
- new CollectiveParamResolverLocal(device_mgr_.get(), drl,
- "/job:localhost/replica:0/task:0")));
+ options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
}
run_state.collective_executor.reset(new CollectiveExecutor::Handle(
collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
}
+#endif
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
@@ -1622,15 +1626,6 @@ Status DirectSession::MakeCallable(const CallableOptions& callable_options,
TF_RETURN_IF_ERROR(CheckNotClosed());
TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
- if (!callable_options.run_options()
- .debug_options()
- .debug_tensor_watch_opts()
- .empty()) {
- return errors::Unimplemented(
- "Debug options are not currently supported via the C++ MakeCallable "
- "interface.");
- }
-
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
RunStateArgs run_state_args(callable_options.run_options().debug_options());
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 8ddc9958b2..142d613129 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -47,6 +48,11 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
+#ifdef GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#endif // GOOGLE_CUDA
+
namespace tensorflow {
namespace {
@@ -1233,36 +1239,23 @@ TEST(DirectSessionTest, TimeoutSession) {
device: '/device:CPU:0'
attr {
key: 'capacity'
- value {
- i: 10
- }
+ value { i: 10 }
}
attr {
key: 'component_types'
- value {
- list {
- type: DT_FLOAT
- }
- }
+ value { list { type: DT_FLOAT } }
}
attr {
key: 'container'
- value {
- s: ''
- }
+ value { s: '' }
}
attr {
key: 'shapes'
- value {
- list {
- }
- }
+ value { list {} }
}
attr {
key: 'shared_name'
- value {
- s: ''
- }
+ value { s: '' }
}
}
node {
@@ -1272,24 +1265,15 @@ TEST(DirectSessionTest, TimeoutSession) {
device: '/device:CPU:0'
attr {
key: 'component_types'
- value {
- list {
- type: DT_FLOAT
- }
- }
+ value { list { type: DT_FLOAT } }
}
attr {
key: 'timeout_ms'
- value {
- i: -1
- }
+ value { i: -1 }
}
}
- versions {
- producer: 9
- }
- )proto",
- &graph);
+ versions { producer: 9 }
+ )proto", &graph);
{
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
@@ -1352,11 +1336,8 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) {
op: 'CancellationMgrPollingOp'
device: '/device:CPU:0'
}
- versions {
- producer: 9
- }
- )proto",
- &graph);
+ versions { producer: 9 }
+ )proto", &graph);
// Creates a session with operation_timeout_in_ms set to 100 milliseconds.
SessionOptions options;
@@ -1730,6 +1711,292 @@ TEST(DirectSessionTest, LocalDeviceManager) {
EXPECT_GT(mgr->ListDevices().size(), 0);
}
+// y = tf.square(x)
+GraphDef CreateGraphForYEqualsXSquared() {
+ GraphDef graph_def;
+ const char* text_proto = R"EOF(
+node {
+ name: "x"
+ op: "Placeholder"
+ attr { key: "dtype" value { type: DT_FLOAT } }
+ attr { key: "shape" value { shape { unknown_rank: true } } }
+}
+node {
+ name: "y"
+ op: "Square"
+ input: "x"
+ attr { key: "T" value { type: DT_FLOAT } }
+}
+versions {
+ producer: 26
+}
+ )EOF";
+
+ QCHECK(protobuf::TextFormat::ParseFromString(text_proto, &graph_def));
+ return graph_def;
+}
+
+// A graph that consumes and produces string tensors
+// (which are not GPU-compatible, i.e., there are no
+// GPU kernels for these operations).
+bool IsCUDATensor(const Tensor& t) {
+#ifdef GOOGLE_CUDA
+ cudaPointerAttributes attributes;
+ cudaError_t err =
+ cudaPointerGetAttributes(&attributes, t.tensor_data().data());
+ if (err == cudaErrorInvalidValue) return false;
+ CHECK_EQ(cudaSuccess, err) << cudaGetErrorString(err);
+ return (attributes.memoryType == cudaMemoryTypeDevice);
+#else
+ return false;
+#endif
+}
+
+string GPUDeviceName(Session* session) {
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ for (const DeviceAttributes& d : devices) {
+ if (d.device_type() == "GPU" || d.device_type() == "gpu") {
+ return d.name();
+ }
+ }
+ return "";
+}
+
+TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory) {
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ const string gpu_device_name = GPUDeviceName(session.get());
+ if (gpu_device_name.empty()) {
+ LOG(INFO) << "Skipping test since no GPU is available";
+ return;
+ }
+
+ TF_ASSERT_OK(session->Create(CreateGraphForYEqualsXSquared()));
+
+ CallableOptions opts;
+ opts.add_feed("x:0");
+ opts.add_fetch("y:0");
+
+ Tensor gpu_tensor;
+
+ {
+ Session::CallableHandle feed_cpu_fetch_gpu;
+ opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+ opts.set_fetch_skip_sync(true);
+ TF_ASSERT_OK(session->MakeCallable(opts, &feed_cpu_fetch_gpu));
+ Tensor input(DT_FLOAT, {});
+ input.scalar<float>()() = 2.0f;
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(
+ session->RunCallable(feed_cpu_fetch_gpu, {input}, &outputs, nullptr));
+ TF_ASSERT_OK(session->ReleaseCallable(feed_cpu_fetch_gpu));
+ ASSERT_EQ(1, outputs.size());
+ gpu_tensor = outputs[0];
+ ASSERT_TRUE(IsCUDATensor(gpu_tensor));
+ }
+
+ {
+ Session::CallableHandle feed_gpu_fetch_cpu;
+ opts.clear_fetch_devices();
+ opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+ TF_ASSERT_OK(session->MakeCallable(opts, &feed_gpu_fetch_cpu));
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(feed_gpu_fetch_cpu, {gpu_tensor},
+ &outputs, nullptr));
+ TF_ASSERT_OK(session->ReleaseCallable(feed_gpu_fetch_cpu));
+ ASSERT_EQ(1, outputs.size());
+ // The output is in CPU/host memory, so it can be dereferenced.
+ ASSERT_EQ(16.0, outputs[0].scalar<float>()());
+ }
+}
+
+GraphDef CreateIdentityGraphDef(DataType dtype) {
+ GraphDef def;
+
+ AttrValue dtype_attr;
+ dtype_attr.set_type(dtype);
+
+ AttrValue shape_attr;
+ shape_attr.mutable_shape()->set_unknown_rank(true);
+
+ auto* placeholder = def.add_node();
+ placeholder->set_name("x");
+ placeholder->set_op("Placeholder");
+ placeholder->mutable_attr()->insert({"dtype", dtype_attr});
+ placeholder->mutable_attr()->insert({"shape", shape_attr});
+
+ auto* identity = def.add_node();
+ identity->set_name("y");
+ identity->set_op("Identity");
+ identity->add_input("x");
+ identity->mutable_attr()->insert({"T", dtype_attr});
+
+ return def;
+}
+
+void TestFeedAndFetchTensorsInDeviceMemory(
+ const SessionOptions& session_options, DataType dtype) {
+ std::unique_ptr<Session> session(NewSession(session_options));
+ const string gpu_device_name = GPUDeviceName(session.get());
+ if (gpu_device_name.empty()) {
+ LOG(INFO) << "Skipping test since no GPU is available";
+ return;
+ }
+
+ TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
+ << DataType_Name(dtype);
+
+ CallableOptions opts;
+ opts.add_feed("x:0");
+ opts.add_fetch("y:0");
+
+ Tensor gpu_tensor;
+ Tensor host_tensor(dtype, {3});
+ {
+ // Ask for the fetched tensor to be backed by device memory.
+ // Even though the kernel that created the tensor produced it in host
+ // memory.
+ opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+ opts.set_fetch_skip_sync(true);
+ Session::CallableHandle handle;
+ TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(handle, {host_tensor}, &outputs, nullptr))
+ << DataType_Name(dtype);
+ TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
+ ASSERT_EQ(1, outputs.size()) << DataType_Name(dtype);
+ gpu_tensor = outputs[0];
+ ASSERT_TRUE(IsCUDATensor(gpu_tensor)) << DataType_Name(dtype);
+ }
+
+ {
+ // Feed a tensor backed by device memory, even though the operations in the
+ // graph expect it in host memory.
+ opts.clear_fetch_devices();
+ opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+ Session::CallableHandle handle;
+ TF_ASSERT_OK(session->MakeCallable(opts, &handle)) << DataType_Name(dtype);
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(session->RunCallable(handle, {gpu_tensor}, &outputs, nullptr))
+ << DataType_Name(dtype);
+ TF_ASSERT_OK(session->ReleaseCallable(handle)) << DataType_Name(dtype);
+ ASSERT_EQ(1, outputs.size());
+ const StringPiece actual_data = outputs[0].tensor_data();
+ const StringPiece expected_data = host_tensor.tensor_data();
+ EXPECT_EQ(expected_data.size(), actual_data.size()) << DataType_Name(dtype);
+ EXPECT_EQ(0, memcmp(expected_data.data(), actual_data.data(),
+ std::min(expected_data.size(), actual_data.size())))
+ << DataType_Name(dtype);
+ }
+}
+
+void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(
+ const SessionOptions& session_options, DataType dtype) {
+ std::unique_ptr<Session> session(NewSession(session_options));
+ const string gpu_device_name = GPUDeviceName(session.get());
+ if (gpu_device_name.empty()) {
+ LOG(INFO) << "Skipping test since no GPU is available";
+ return;
+ }
+
+ TF_ASSERT_OK(session->Create(CreateIdentityGraphDef(dtype)))
+ << DataType_Name(dtype);
+
+ CallableOptions opts;
+ opts.add_feed("x:0");
+ opts.add_fetch("y:0");
+
+ // Fail when asking to fetch into GPU memory.
+ {
+ opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name});
+ opts.set_fetch_skip_sync(true);
+ Session::CallableHandle handle;
+ Status status = session->MakeCallable(opts, &handle);
+ EXPECT_FALSE(status.ok()) << DataType_Name(dtype);
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(),
+ strings::StrCat(
+ "Cannot feed or fetch tensor 'y:0' from device ", gpu_device_name,
+ " as feeding/fetching from GPU devices is not yet supported for ",
+ DataTypeString(dtype), " tensors")))
+ << DataType_Name(dtype) << ", Status: " << status;
+ }
+
+ // Fail when feeding from GPU memory.
+ {
+ opts.clear_feed_devices();
+ opts.mutable_feed_devices()->insert({"x:0", gpu_device_name});
+ Session::CallableHandle handle;
+ Status status = session->MakeCallable(opts, &handle);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(
+ status.error_message(),
+ strings::StrCat(
+ "Cannot feed or fetch tensor 'x:0' from device ", gpu_device_name,
+ " as feeding/fetching from GPU devices is not yet supported for ",
+ DataTypeString(dtype), " tensors")))
+ << DataType_Name(dtype) << ", Status: " << status;
+ }
+}
+
+void TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(
+ const SessionOptions& opts) {
+ // Feeding/fetching on device does not work for all DataTypes as it
+ // relies on the implementation of the _Arg and _Retval kernels which
+ // are not registered for some types or consume/produce inputs/outputs
+ // in host memory for some types.
+ //
+ // Run through all datatypes to validate that either:
+ // (a) MakeCallable fails (because the given type cannot be fed/fetched
+ // in device memory),
+ // OR
+ // (b) Succeeds: RunCallable should gladly accept inputs in device memory
+ // and produce output tensors in device memory.
+ for (int i = DataType_MIN; i <= DataType_MAX; ++i) {
+ if (!DataType_IsValid(i)) continue;
+ const DataType dtype = static_cast<DataType>(i);
+ switch (dtype) {
+ case DT_INVALID:
+ break;
+ case DT_BFLOAT16:
+ case DT_BOOL:
+ case DT_COMPLEX128:
+ case DT_COMPLEX64:
+ case DT_DOUBLE:
+ case DT_FLOAT:
+ case DT_HALF:
+ case DT_INT16:
+ case DT_INT64:
+ case DT_INT8:
+ case DT_UINT16:
+ case DT_UINT8:
+ TestFeedAndFetchTensorsInDeviceMemory(opts, dtype);
+ break;
+ default:
+ // Ignore all REF types since Tensors of this type aren't intended to
+ // be fed (and attempting to create one via the Tensor constructor
+ // will result in a LOG(FATAL)).
+ if (!IsRefType(dtype)) {
+ TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable(opts, dtype);
+ }
+ break;
+ }
+ }
+}
+
+TEST(DirectSessionTest, FeedAndFetchTensorsInDeviceMemory_AllDataTypes) {
+ SessionOptions opts;
+ opts.config.set_allow_soft_placement(false);
+ TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(opts);
+}
+
+TEST(DirectSessionTest,
+ FeedAndFetchTensorsInDeviceMemory_AllDataTypes_SoftPlacement) {
+ SessionOptions opts;
+ opts.config.set_allow_soft_placement(true);
+ TestFeedAndFetchTensorsInDeviceMemoryForAllDataTypes(opts);
+}
+
// A simple benchmark for the overhead of `DirectSession::Run()` calls
// with varying numbers of feeds/fetches.
void FeedFetchBenchmarkHelper(int iters, int num_feeds,
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 9028e6298c..0b096a14a3 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -74,6 +74,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_constant_folding(RewriterConfig::OFF);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_min_graph_nodes(-1);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
@@ -103,24 +106,24 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
#ifdef INTEL_MKL
- // if MKL is used, it goes through various additional
- // graph rewrite pass. In TF, everytime a graph pass
+ // if MKL is used, it goes through various additional
+ // graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
// and deallocated. Each allocation calls the
// (FindChunkPtr of BFCAllocator),
- // which increments the value of AllocationId.
- // Thus AllocationId becomes more than 3 and 4 if
- // MKL is used. Now they are 9 and 10 for MKL.
- EXPECT_EQ(19, cm->AllocationId(node, 0));
+ // which increments the value of AllocationId.
+ // Thus AllocationId becomes more than TF if MKL
+ // is used. Now IDs for MKL are 8 more than TF.
+ EXPECT_EQ(29, cm->AllocationId(node, 0));
#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
+#endif
} else {
#ifdef INTEL_MKL
- EXPECT_EQ(20, cm->AllocationId(node, 0));
+ EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#endif
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index b5120f2872..7f28f3b793 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -22,14 +22,19 @@ tf_cuda_library(
"eager_executor.h",
],
visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ deps = select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cuda_library(
@@ -44,17 +49,23 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- "//tensorflow/core/distributed_runtime:worker_session",
- "//tensorflow/core/distributed_runtime/eager:eager_client",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:worker_session",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ ],
+ }),
)
tf_cuda_library(
@@ -86,14 +97,20 @@ tf_cuda_library(
":context",
":eager_executor",
":kernel_and_device",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+ }),
)
tf_cuda_library(
@@ -106,14 +123,19 @@ tf_cuda_library(
":context",
":eager_executor",
":tensor_handle",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_options",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ ],
+ }),
)
tf_cuda_library(
@@ -125,14 +147,20 @@ tf_cuda_library(
"kernel_and_device.h",
],
visibility = ["//tensorflow:internal"],
- deps = [
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- ],
+ deps = select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ "//util/hash:farmhash_fingerprint",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+ }),
)
tf_cc_test(
@@ -168,14 +196,20 @@ cc_library(
":eager_operation",
":kernel_and_device",
":tensor_handle",
- "//tensorflow/core:core_cpu_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/distributed_runtime/eager:eager_client",
- "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
- ],
+ ] + select({
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib_lite",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/distributed_runtime/eager:eager_client",
+ "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
+ ],
+ }),
)
tf_cuda_library(
@@ -183,13 +217,15 @@ tf_cuda_library(
srcs = ["attr_builder.cc"],
hdrs = ["attr_builder.h"],
visibility = ["//tensorflow:internal"],
- deps = select({
+ deps = [
+ ":kernel_and_device",
+ "//tensorflow/c:c_api",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
+ "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
- ":kernel_and_device",
- "//tensorflow/c:c_api",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 8381cb58d2..70208fb6d1 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -38,10 +38,11 @@ EagerContext::EagerContext(const SessionOptions& opts,
InitDeviceMapAndAsync();
}
+#ifndef __ANDROID__
EagerContext::EagerContext(
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<GrpcServer> server,
+ std::unique_ptr<ServerInterface> server,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts)
@@ -55,12 +56,13 @@ EagerContext::EagerContext(
&func_lib_def_, {}, thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
async_default_(async),
+ remote_device_manager_(std::move(remote_device_manager)),
server_(std::move(server)),
remote_eager_workers_(std::move(remote_eager_workers)),
- remote_device_manager_(std::move(remote_device_manager)),
remote_contexts_(remote_contexts) {
InitDeviceMapAndAsync();
}
+#endif
void EagerContext::InitDeviceMapAndAsync() {
if (async_default_) {
@@ -125,10 +127,11 @@ ContextDevicePlacementPolicy EagerContext::GetDevicePlacementPolicy() {
}
EagerContext::~EagerContext() {
+#ifndef __ANDROID__
if (server_) {
// TODO(nareshmodi): Fix this.
LOG(WARNING) << "Unable to destroy server_ object, so releasing instead. "
- "GrpcServer doesn't support clean shutdown.";
+ "Servers don't support clean shutdown.";
server_.release();
}
@@ -158,6 +161,7 @@ EagerContext::~EagerContext() {
}
counter.Wait();
+#endif
executor_.WaitForAllPendingNodes().IgnoreError();
ClearCaches();
@@ -189,9 +193,46 @@ Status EagerContext::FindDeviceByName(const string& name, Device** result) {
return Status::OK();
}
+Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
+ if (remote_device_manager_ == nullptr) return Status::OK();
+
+ BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
+
+ std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
+ std::vector<eager::RegisterFunctionResponse> responses(
+ remote_contexts_.size());
+ std::vector<Status> statuses(remote_contexts_.size());
+
+ int i = 0;
+ for (const auto& target_and_context_id : remote_contexts_) {
+ requests[i].set_context_id(target_and_context_id.second);
+ *requests[i].mutable_function_def() = fdef;
+
+ auto* eager_client =
+ remote_eager_workers_->GetClient(target_and_context_id.first);
+
+ eager_client->RegisterFunctionAsync(
+ &requests[i], &responses[i],
+ [i, &statuses, &blocking_counter](const Status& status) {
+ statuses[i] = status;
+ blocking_counter.DecrementCount();
+ });
+
+ i++;
+ }
+ blocking_counter.Wait();
+
+ for (int i = 0; i < remote_contexts_.size(); i++) {
+ TF_RETURN_IF_ERROR(statuses[i]);
+ }
+ return Status::OK();
+}
+
Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
mutex_lock l(functions_mu_);
- return func_lib_def_.AddFunctionDef(fdef);
+ TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
+
+ return MaybeRegisterFunctionRemotely(fdef);
}
KernelAndDevice* EagerContext::GetCachedKernel(Fprint128 cache_key) {
@@ -224,6 +265,7 @@ Status GetTaskName(Device* d, string* task_name) {
}
} // namespace
+#ifndef __ANDROID__
Status EagerContext::GetClientAndContextID(Device* device,
eager::EagerClient** client,
uint64* context_id) {
@@ -253,5 +295,6 @@ Status EagerContext::GetClientAndContextID(Device* device,
return Status::OK();
}
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 096ed3112e..864f514a19 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -29,8 +29,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#endif
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -75,21 +77,22 @@ class EagerContext {
// workers.
//
// Additional remote-specific args are:
- // - server: A GrpcServer that exports the tensorflow.WorkerService. Note
- // that this class expects the server to already have been started.
+ // - server: A ServerInterface that exports the tensorflow.WorkerService.
+ // Note that this class expects the server to already have been started.
// - remote_eager_workers: A cache from which we can get "EagerClient"s to
// communicate with remote eager services.
// - remote_device_mgr: A DeviceMgr* which contains all remote devices
// (should contain no local devices).
// - remote_contexts: A map containing task name to remote context ID.
+#ifndef __ANDROID__
explicit EagerContext(
const SessionOptions& opts, ContextDevicePlacementPolicy default_policy,
bool async, DeviceMgr* local_device_mgr, Rendezvous* rendezvous,
- std::unique_ptr<GrpcServer> server,
+ std::unique_ptr<ServerInterface> server,
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
std::unique_ptr<DeviceMgr> remote_device_manager,
const gtl::FlatMap<string, uint64>& remote_contexts);
-
+#endif
~EagerContext();
// Returns the function library runtime for the given device.
@@ -174,11 +177,13 @@ class EagerContext {
FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
+#ifndef __ANDROID__
Status GetClientAndContextID(Device* device, eager::EagerClient** client,
uint64* context_id);
-
+#endif
private:
void InitDeviceMapAndAsync();
+ Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
const ContextDevicePlacementPolicy policy_;
@@ -228,16 +233,19 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const std::unique_ptr<DeviceMgr> remote_device_manager_;
+
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
// be).
- std::unique_ptr<GrpcServer> server_;
+#ifndef __ANDROID__
+ std::unique_ptr<ServerInterface> server_;
const std::unique_ptr<eager::EagerClientCache> remote_eager_workers_;
- const std::unique_ptr<DeviceMgr> remote_device_manager_;
const gtl::FlatMap<string, uint64> remote_contexts_;
gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
device_to_client_cache_;
+#endif
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index ce989f4b4e..7a2b477845 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute_node.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
+#endif
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -34,11 +36,17 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
+// Copy of the definition in third_party/tensorflow/compiler/jit/defs.h
+// Copied here because we don't currently compile XLA on windows. So, can't
+// depend on it directly.
+const char* const kXlaCompileAttr = "_XlaCompile";
+
// Initializes the step stats if needed.
void MaybeInitializeStepStats(StepStats* step_stats, EagerContext* ctx) {
// Lazily initialize the RunMetadata with information about all devices if
@@ -66,6 +74,88 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
return 0;
}
+// This function expects *handle to point to an existing tensor handle. The
+// function will (maybe) update the *handle to be pointed to the newly copied
+// tensor handle.
+//
+// The passed in *handle will be Unreffed if it is replaced.
+Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
+ const Device* expected_device,
+ RunMetadata* run_metadata,
+ TensorHandle** handle) {
+ EagerContext* ctx = op->EagerContext();
+ Device* handle_device = nullptr;
+ TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
+ const Device* actual_device =
+ handle_device == nullptr ? ctx->HostCPU() : handle_device;
+
+ if (expected_device != actual_device) {
+ switch (ctx->GetDevicePlacementPolicy()) {
+ case DEVICE_PLACEMENT_SILENT_FOR_INT32:
+ // TODO(xpan): See if we could bubble python related error up
+ // to python level.
+ if ((*handle)->dtype == DT_INT32) {
+ // Note: enabling silent copies of int32 tensors to match behavior
+ // of graph mode.
+ break;
+ }
+ TF_FALLTHROUGH_INTENDED;
+ case DEVICE_PLACEMENT_EXPLICIT:
+ return errors::InvalidArgument(
+ "Tensors on conflicting devices:"
+ " cannot compute ",
+ op->Name(), " as input #", i, " was expected to be on ",
+ expected_device->name(), " but is actually on ",
+ actual_device->name(), " (operation running on ",
+ op->Device()->name(), ")",
+ " Tensors can be copied explicitly using .gpu() or .cpu() "
+ "methods,"
+ " or transparently copied by using tf.enable_eager_execution("
+ "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
+ "between devices"
+ " may slow down your model");
+ case DEVICE_PLACEMENT_WARN:
+ LOG(WARNING) << "before computing " << op->Name() << " input #" << i
+ << " was expected to be on " << expected_device->name()
+ << " but is actually on " << actual_device->name()
+ << " (operation running on " << op->Device()->name()
+ << "). This triggers a copy which can be a performance "
+ "bottleneck.";
+ break;
+ case DEVICE_PLACEMENT_SILENT: // Do nothing.
+ break;
+ }
+ // We are only here if the policy is warn or silent copies, so we should
+ // trigger a copy.
+ auto pre_time = Env::Default()->NowMicros();
+ TensorHandle* result_handle = nullptr;
+ Status status = EagerCopyToDevice(
+ *handle, ctx, expected_device->name().c_str(), &result_handle);
+ if (run_metadata != nullptr) {
+ auto* step_stats = run_metadata->mutable_step_stats();
+ MaybeInitializeStepStats(step_stats, ctx);
+ // Record the sending on the source device for now.
+ int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
+ auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
+ auto* node_stats = dev_stats->add_node_stats();
+ node_stats->set_node_name("_Send");
+ node_stats->set_all_start_micros(pre_time);
+ node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() - pre_time);
+ }
+ if (!status.ok()) {
+ if (result_handle != nullptr) result_handle->Unref();
+ return errors::Internal("Failed copying input tensor from ",
+ actual_device->name(), " to ",
+ expected_device->name(), " in order to run ",
+ op->Name(), ": ", status.error_message());
+ }
+
+ (*handle)->Unref();
+ *handle = result_handle;
+ }
+ return Status::OK();
+}
+
Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
EagerOperation* op, const OpKernel* kernel,
RunMetadata* run_metadata) {
@@ -78,76 +168,9 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
for (int i = 0; i < op->Inputs().size(); ++i) {
const Device* expected_device =
memtypes[i] == HOST_MEMORY ? host_device : op_device;
- TensorHandle* handle = op->Inputs()[i];
- Device* handle_device = nullptr;
- TF_RETURN_IF_ERROR(handle->Device(&handle_device));
- const Device* actual_device =
- handle_device == nullptr ? host_device : handle_device;
- if (expected_device != actual_device) {
- switch (ctx->GetDevicePlacementPolicy()) {
- case DEVICE_PLACEMENT_SILENT_FOR_INT32:
- // TODO(xpan): See if we could bubble python related error up
- // to python level.
- if (handle->dtype == DT_INT32) {
- // Note: enabling silent copies of int32 tensors to match behavior
- // of graph mode.
- break;
- }
- TF_FALLTHROUGH_INTENDED;
- case DEVICE_PLACEMENT_EXPLICIT:
- return errors::InvalidArgument(
- "Tensors on conflicting devices:"
- " cannot compute ",
- op->Name(), " as input #", i, " was expected to be on ",
- expected_device->name(), " but is actually on ",
- actual_device->name(), " (operation running on ",
- op_device->name(), ")",
- " Tensors can be copied explicitly using .gpu() or .cpu() "
- "methods,"
- " or transparently copied by using tf.enable_eager_execution("
- "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
- "between devices"
- " may slow down your model");
- case DEVICE_PLACEMENT_WARN:
- LOG(WARNING) << "before computing " << op->Name() << " input #" << i
- << " was expected to be on " << expected_device->name()
- << " but is actually on " << actual_device->name()
- << " (operation running on " << op_device->name()
- << "). This triggers a copy which can be a performance "
- "bottleneck.";
- break;
- case DEVICE_PLACEMENT_SILENT: // Do nothing.
- break;
- }
- // We are only here if the policy is warn or silent copies, so we should
- // trigger a copy.
- auto pre_time = Env::Default()->NowMicros();
- TensorHandle* copied_tensor = nullptr;
- Status status = EagerCopyToDevice(
- handle, ctx, expected_device->name().c_str(), &copied_tensor);
- if (run_metadata != nullptr) {
- auto* step_stats = run_metadata->mutable_step_stats();
- MaybeInitializeStepStats(step_stats, ctx);
- // Record the sending on the source device for now.
- int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
- auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
- auto* node_stats = dev_stats->add_node_stats();
- node_stats->set_node_name("_Send");
- node_stats->set_all_start_micros(pre_time);
- node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() -
- pre_time);
- }
- if (!status.ok()) {
- if (copied_tensor != nullptr) copied_tensor->Unref();
- return errors::Internal("Failed copying input tensor from ",
- actual_device->name(), " to ",
- expected_device->name(), " in order to run ",
- op->Name(), ": ", status.error_message());
- }
- handle->Unref();
- handle = copied_tensor;
- (*op->MutableInputs())[i] = copied_tensor;
- }
+ TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
+ op, i, expected_device, run_metadata, &((*op->MutableInputs())[i])));
+ tensorflow::TensorHandle* handle = op->Inputs()[i];
if (handle->dtype != kernel->input_type(i)) {
return errors::InvalidArgument(
"cannot compute ", op->Name(), " as input #", i,
@@ -192,8 +215,8 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
-// `op_input_to_func_input` based on the reordering results, that the caller can
-// use them to build an XlaLaunch. On error, it returns NULL, and sets
+// `op_input_to_func_input` based on the reordering results, that the caller
+// can use them to build an XlaLaunch. On error, it returns NULL, and sets
// `status` accordingly.
const FunctionDef* OpToFunction(TFE_Op* op,
std::vector<TF_DataType>* const_input_types,
@@ -221,8 +244,8 @@ const FunctionDef* OpToFunction(TFE_Op* op,
const std::unordered_set<string> const_inputs(
*XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));
- // First add place holders for the input args, so that we can refer to them by
- // position in the next loop. Also tally up the resource inputs.
+ // First add place holders for the input args, so that we can refer to them
+ // by position in the next loop. Also tally up the resource inputs.
int num_resource_inputs = 0;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
if (op_def.input_arg(i).type() == DT_RESOURCE) {
@@ -336,8 +359,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
&op_input_to_func_input, status);
if (!status.ok()) return nullptr;
} else {
- // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
- // functions, so we need to find another way to handle constant inputs.
+ // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
+ // for functions, so we need to find another way to handle constant
+ // inputs.
for (int i = const_input_types.size();
i < fdef->signature().input_arg_size(); ++i) {
VLOG(1) << "Adding Targs from input arg " << i;
@@ -348,8 +372,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
DCHECK(fdef != nullptr);
// Copy inputs and their devices.
- // Since input param reordering may have occurred between `op` and `launch_op`
- // via `op_input_to_func_input`, adjust the actual inputs accordingly.
+ // Since input param reordering may have occurred between `op` and
+ // `launch_op` via `op_input_to_func_input`, adjust the actual inputs
+ // accordingly.
*launch_op->operation.MutableInputs() = op->operation.Inputs();
for (TensorHandle* h : launch_op->operation.Inputs()) {
h->Ref();
@@ -399,7 +424,13 @@ Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
+ const FunctionDef* function_def =
+ op->EagerContext()->FuncLibDef()->Find(op->Name());
+ if (function_def != nullptr) {
+ op_def = &(function_def->signature());
+ } else {
+ TF_RETURN_IF_ERROR(OpDefForOp(op->Name().c_str(), &op_def));
+ }
TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, *op_def, output_dtypes));
@@ -455,6 +486,15 @@ Status EagerLocalExecute(EagerOperation* op,
device == nullptr ? "unspecified" : device->name());
KernelAndDevice* kernel = ctx->GetCachedKernel(cache_key);
if (kernel == nullptr) {
+ // If we are running a function on explicitly requested TPU,
+ // compile it with XLA.
+ // Note that it is not ideal, but currently ok, to set this
+ // attribute after computing the kernel cache key above.
+ if (op->is_function() && device != nullptr &&
+ device->device_type() == "TPU") {
+ op->MutableAttrs()->Set(kXlaCompileAttr, true);
+ }
+
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
if (device == nullptr) {
status = SelectDevice(ndef, ctx, &device);
@@ -542,28 +582,38 @@ Status EagerLocalExecute(EagerOperation* op,
return status;
}
-Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
- uint64 context_id, TensorHandle** retvals,
+Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) {
- // All tensors must be on the same device.
- // TODO(nareshmodi): handle silent copies
+#ifdef __ANDROID__
+ return errors::Unimplemented(
+ "Eager's remote execution is not available on Android devices.");
+#else
+ EagerContext* ctx = op->EagerContext();
+
+ eager::EagerClient* eager_client;
+ uint64 context_id;
+ TF_RETURN_IF_ERROR(
+ ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
+
eager::EnqueueRequest request;
eager::EnqueueResponse response;
auto* remote_op = request.add_queue()->mutable_operation();
- for (auto* input : op->Inputs()) {
+ for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
- TF_RETURN_IF_ERROR(input->Device(&input_device));
+ TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
if (op->Device() != input_device) {
- return tensorflow::errors::InvalidArgument(
- "Ops and inputs are not on the same device. Use "
- "TFE_TensorHandleCopyToDevice to get ops on the same "
- "device. Expected device: ",
- op->Device()->name(), ", Actual device: ", input_device->name());
+ // TODO(b/110044833): It's possible the same tensor gets copied to the
+ // remote device repeatedly.
+ TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
+ op, i, op->Device(), /* run_metadata= */ nullptr,
+ &(*op->MutableInputs())[i]));
}
- tensorflow::uint64 op_id;
+ tensorflow::TensorHandle* input = op->Inputs()[i];
+
+ tensorflow::int64 op_id;
int32 output_num;
TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));
@@ -580,22 +630,6 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
request.set_context_id(context_id);
- if (op->EagerContext()->Async()) {
- tensorflow::uint64 id = op->EagerContext()->NextId();
- auto* node = new eager::RemoteExecuteNode(id, request, eager_client);
- op->EagerContext()->ExecutorAdd(node);
- } else {
- Notification n;
- Status status;
- eager_client->EnqueueAsync(&request, &response,
- [&n, &status](const Status& s) {
- status = s;
- n.Notify();
- });
- n.WaitForNotification();
- if (!status.ok()) return status;
- }
-
DataTypeVector output_dtypes;
TF_RETURN_IF_ERROR(GetOutputDTypes(op, &output_dtypes));
@@ -605,7 +639,13 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
}
tensorflow::Device* op_device = op->Device();
- EagerContext* ctx = op->EagerContext();
+
+ bool is_async = op->EagerContext()->Async();
+ uint64 remote_node_id = 0;
+
+ if (is_async) {
+ remote_node_id = op->EagerContext()->NextId();
+ }
const tensorflow::uint64 id = remote_op->id();
for (int i = 0; i < *num_retvals; i++) {
@@ -634,12 +674,56 @@ Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
return tensorflow::Status::OK();
};
- retvals[i] = new TensorHandle(remote_op->id(), i, output_dtypes[i],
- std::move(callback), op_device, op_device,
- op->EagerContext());
+
+ retvals[i] = new TensorHandle(remote_op->id(), i, remote_node_id,
+ output_dtypes[i], std::move(callback),
+ op_device, op_device, op->EagerContext());
+ }
+
+ if (is_async) {
+ // Copy the output handles, since the container for them might get
+ // destroyed.
+ gtl::InlinedVector<TensorHandle*, 2> retvals_copy;
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals_copy.push_back(retvals[i]);
+ retvals_copy[i]->Ref();
+ }
+ // Unable to capture via std::move, so bind instead.
+ auto* node = new eager::RemoteExecuteNode(
+ remote_node_id, request, eager_client, op->Inputs(),
+ std::bind(
+ [](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
+ const Status& status, const eager::EnqueueResponse& response) {
+ if (!status.ok()) return;
+ for (int i = 0; i < retvals.size(); i++) {
+ retvals[i]->SetRemoteShape(MakeUnique<TensorShape>(
+ response.queue_response(0).shape(i)));
+ retvals[i]->Unref();
+ }
+ },
+ std::move(retvals_copy), std::placeholders::_1,
+ std::placeholders::_2));
+ op->EagerContext()->ExecutorAdd(node);
+ } else {
+ Notification n;
+ Status status;
+ eager_client->EnqueueAsync(&request, &response,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+
+ if (!status.ok()) return status;
+
+ for (int i = 0; i < *num_retvals; i++) {
+ retvals[i]->SetRemoteShape(
+ MakeUnique<TensorShape>(response.queue_response(0).shape(i)));
+ }
}
return Status::OK();
+#endif
}
} // namespace
@@ -652,15 +736,7 @@ Status EagerExecute(EagerOperation* op,
return EagerLocalExecute(op, retvals, num_retvals);
}
- auto* ctx = op->EagerContext();
-
- tensorflow::eager::EagerClient* eager_client;
- tensorflow::uint64 context_id;
- TF_RETURN_IF_ERROR(
- ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
-
- return EagerRemoteExecute(op, eager_client, context_id, retvals->data(),
- num_retvals);
+ return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
Status EagerExecute(EagerContext* ctx, Device* device,
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 2a43a31c02..b410ea175b 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -79,6 +79,7 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
params.function_library = flib_;
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
+ params.cancellation_manager = &cm_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index f78d197fd5..c41a0972b1 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
@@ -76,6 +77,11 @@ class KernelAndDevice {
const DataTypeVector& output_dtypes() { return output_dtypes_; }
private:
+ // TODO(apassos) Consider a shared cancellation manager. Note that this
+ // cancellation manager is not useful to actually cancel anything, and is
+ // provided here only for the few kernels which can't handle one being
+ // missing.
+ CancellationManager cm_;
std::unique_ptr<OpKernel> kernel_;
Device* device_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 1a811aa8df..f9b9abcc99 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -45,7 +45,7 @@ limitations under the License.
namespace tensorflow {
bool TensorHandle::IsReady() {
- if (node_id == 0) return true;
+ if (node_id_ == 0) return true;
mutex_lock l(ctx_mutex_);
return is_ready_;
}
@@ -54,17 +54,19 @@ bool TensorHandle::IsRemote() {
return remote_op_id_ >= 0 && remote_output_num_ >= 0;
}
-Status TensorHandle::WaitReady() {
+Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) {
if (node_id == 0) return Status::OK();
EagerExecutor* executor = nullptr;
{
mutex_lock l(ctx_mutex_);
- if (is_ready_) return Status::OK();
+ if (return_if_is_ready && is_ready_) return Status::OK();
executor = ctx_->Executor();
}
return executor->WaitFor(node_id);
}
+Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); }
+
Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
if (IsRemote()) {
return errors::Unavailable(
@@ -107,7 +109,38 @@ Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
return Status::OK();
}
-Status TensorHandle::RemoteAddress(uint64* op_id, int32* output_num) {
+Status TensorHandle::NumDims(int* num_dims) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ CHECK(remote_shape_ != nullptr);
+ *num_dims = remote_shape_->dims();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_dims != nullptr);
+
+ *num_dims = tensor_.dims();
+ }
+
+ return Status::OK();
+}
+
+Status TensorHandle::Dim(int dim_index, int64* dim) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *dim = remote_shape_->dim_size(dim_index);
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(dim != nullptr);
+
+ *dim = tensor_.dim_size(dim_index);
+ }
+
+ return Status::OK();
+}
+
+Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
"This TensorHandle refers to a local tensor handle");
@@ -122,7 +155,7 @@ void TensorHandle::SetTensorAndDevice(const tensorflow::Tensor& tensor,
tensorflow::Device* device,
tensorflow::Device* op_device) {
mutex_lock l(ctx_mutex_);
- DCHECK(node_id > 0 && !is_ready_)
+ DCHECK(node_id_ > 0 && !is_ready_)
<< "SetTensorAndDevice should be only called "
<< "on non-ready handles.";
is_ready_ = true;
@@ -189,6 +222,7 @@ Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
srcd, dstd, tensorflow::AllocatorAttributes(),
tensorflow::AllocatorAttributes(), src, &dst,
+ 0 /*dev_to_dev_stream_index*/,
[&status, &n](const tensorflow::Status& s) {
status = s;
n.Notify();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index a3b7dd862e..46bc94f875 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -51,38 +51,41 @@ class TensorHandle : public core::RefCounted {
public:
TensorHandle(const Tensor& t, Device* d, Device* op_device, EagerContext* ctx)
: dtype(t.dtype()),
- node_id(0),
+ node_id_(0),
tensor_(t),
device_(d),
op_device_(op_device),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(true) {}
TensorHandle(uint64 node_id, DataType dtype, EagerContext* ctx)
: dtype(dtype),
- node_id(node_id),
+ node_id_(node_id),
tensor_(dtype),
device_(nullptr),
op_device_(nullptr),
remote_op_id_(-1),
remote_output_num_(-1),
+ remote_shape_node_id_(-1),
ctx_(ctx),
is_ready_(ctx == nullptr) {
- DCHECK_GT(node_id, 0);
+ DCHECK_GT(node_id_, 0);
}
// Remote tensor handle constructor.
- TensorHandle(uint64 op_id, int32 output_num, DataType dtype,
- std::function<void()> call_on_destroy, Device* d,
+ TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id,
+ DataType dtype, std::function<void()> call_on_destroy, Device* d,
Device* op_device, EagerContext* ctx)
: dtype(dtype),
- node_id(0),
+ node_id_(0),
device_(d),
op_device_(op_device),
remote_op_id_(op_id),
remote_output_num_(output_num),
+ remote_shape_node_id_(remote_shape_node_id),
call_on_destroy_(std::move(call_on_destroy)),
ctx_(ctx),
is_ready_(true) {
@@ -106,8 +109,11 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device** device,
tensorflow::Device** op_device);
+ Status NumDims(int* num_dims);
+ Status Dim(int dim_index, int64* dim);
+
// Return the op_id and output num if the handle refers to a remote tensor.
- Status RemoteAddress(uint64* op_id, int32* output_num);
+ Status RemoteAddress(int64* op_id, int32* output_num);
// Note that this can be called at most once, and only on non-ready handles,
// and makes them ready.
@@ -128,11 +134,16 @@ class TensorHandle : public core::RefCounted {
// ready.
const DataType dtype;
+ void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) {
+ remote_shape_ = std::move(remote_shape);
+ }
+
private:
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that compuatation is
// done and the handle is "ready".
Status WaitReady();
+ Status WaitForNode(uint64 node_id, bool return_if_is_ready);
bool IsReady();
@@ -140,7 +151,7 @@ class TensorHandle : public core::RefCounted {
// Id for the EagerNode that will compute the value pointed to by this handle.
// If the value is 0, the handle is already ready, but not vice-versa.
- const uint64 node_id;
+ const uint64 node_id_;
tensorflow::Tensor tensor_;
@@ -159,8 +170,10 @@ class TensorHandle : public core::RefCounted {
tensorflow::Device* op_device_;
// IDs required when this class is representing a remote tensor handle.
- const uint64 remote_op_id_;
+ const int64 remote_op_id_;
const int32 remote_output_num_;
+ std::unique_ptr<TensorShape> remote_shape_;
+ const uint64 remote_shape_node_id_;
// A callback that is executed when the class is destroyed.
//
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 585d777e81..f7f2cdc14f 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
@@ -2764,4 +2765,30 @@ Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
+namespace {
+
+class DefaultExecutorRegistrar {
+ public:
+ DefaultExecutorRegistrar() {
+ Factory* factory = new Factory;
+ ExecutorFactory::Register("", factory);
+ ExecutorFactory::Register("DEFAULT", factory);
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret = nullptr;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static DefaultExecutorRegistrar registrar;
+
+} // namespace
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.cc b/tensorflow/core/common_runtime/executor_factory.cc
new file mode 100644
index 0000000000..ee7c7c3a73
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.cc
@@ -0,0 +1,85 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/executor_factory.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+static mutex executor_factory_lock(LINKER_INITIALIZED);
+
+typedef std::unordered_map<string, ExecutorFactory*> ExecutorFactories;
+ExecutorFactories* executor_factories() {
+ static ExecutorFactories* factories = new ExecutorFactories;
+ return factories;
+}
+
+} // namespace
+
+void ExecutorFactory::Register(const string& executor_type,
+ ExecutorFactory* factory) {
+ mutex_lock l(executor_factory_lock);
+ if (!executor_factories()->insert({executor_type, factory}).second) {
+ LOG(FATAL) << "Two executor factories are being registered "
+ << "under" << executor_type;
+ }
+}
+
+namespace {
+const string RegisteredFactoriesErrorMessageLocked()
+ SHARED_LOCKS_REQUIRED(executor_factory_lock) {
+ std::vector<string> factory_types;
+ for (const auto& executor_factory : *executor_factories()) {
+ factory_types.push_back(executor_factory.first);
+ }
+ return strings::StrCat("Registered factories are {",
+ str_util::Join(factory_types, ", "), "}.");
+}
+} // namespace
+
+Status ExecutorFactory::GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory) {
+ tf_shared_lock l(executor_factory_lock);
+
+ auto iter = executor_factories()->find(executor_type);
+ if (iter == executor_factories()->end()) {
+ return errors::NotFound(
+ "No executor factory registered for the given executor type: ",
+ executor_type, " ", RegisteredFactoriesErrorMessageLocked());
+ }
+
+ *out_factory = iter->second;
+ return Status::OK();
+}
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) {
+ ExecutorFactory* factory = nullptr;
+ TF_RETURN_IF_ERROR(ExecutorFactory::GetFactory(executor_type, &factory));
+ return factory->NewExecutor(params, std::move(graph), out_executor);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/executor_factory.h b/tensorflow/core/common_runtime/executor_factory.h
new file mode 100644
index 0000000000..f81bb080eb
--- /dev/null
+++ b/tensorflow/core/common_runtime/executor_factory.h
@@ -0,0 +1,51 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
+
+#include <string>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class Executor;
+class Graph;
+struct LocalExecutorParams;
+
+class ExecutorFactory {
+ public:
+ virtual Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) = 0;
+ virtual ~ExecutorFactory() {}
+
+ static void Register(const string& executor_type, ExecutorFactory* factory);
+ static Status GetFactory(const string& executor_type,
+ ExecutorFactory** out_factory);
+};
+
+Status NewExecutor(const string& executor_type,
+ const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index 8cb1567852..7697103faf 100644
--- a/tensorflow/core/common_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
@@ -464,12 +464,12 @@ BENCHMARK(BM_executor)->ArgPair(1024, 1024);
static void BM_FeedInputFetchOutput(int iters) {
Graph* g = new Graph(OpRegistry::Global());
// z = x + y: x and y are provided as benchmark inputs. z is the
- // output of the benchmark. Conceptually, the caller is "a", the
- // benchmark is "b".
- Node* x = test::graph::Recv(g, "x", "float", "a", 1, "b");
- Node* y = test::graph::Recv(g, "y", "float", "a", 1, "b");
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
Node* sum = test::graph::Add(g, x, y);
- Node* z = test::graph::Send(g, sum, "z", "b", 1, "a");
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
Tensor val(DT_FLOAT, TensorShape({}));
val.scalar<float>()() = 3.14;
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 5d9be70522..a93cfa2ec5 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
@@ -215,6 +216,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
Executor* exec = nullptr;
+ string executor_type;
~Item() {
delete this->func_graph;
@@ -397,12 +399,11 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(
// types.
MemoryTypeVector input_memory_types;
for (const auto& t : fbody->arg_types) {
- input_memory_types.push_back(
- (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY);
+ input_memory_types.push_back(MTypeFromDType(t));
}
MemoryTypeVector output_memory_types;
for (const auto& t : fbody->ret_types) {
- output_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY);
+ output_memory_types.push_back(MTypeFromDType(t));
}
// Constructs a CallOp kernel for running the instantiated function.
@@ -549,6 +550,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
item->func_graph = fbody;
item->overlay_lib = options.overlay_lib;
item->instantiation_counter = 1;
+ item->executor_type = options.executor_type;
items_.emplace(next_handle_, std::unique_ptr<Item>(item));
next_handle_++;
}
@@ -623,10 +625,12 @@ void PruneFunctionBody(Graph* g) {
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionBody* fbody;
const FunctionLibraryDefinition* lib_def;
+ string executor_type;
{
mutex_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
+ executor_type = (*item)->executor_type;
}
if (!lib_def) {
lib_def = base_lib_def_;
@@ -656,17 +660,14 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
DeleteNonCachedKernel(kernel);
};
Graph* graph = g.get();
- Executor* exec;
- TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec));
-
+ std::unique_ptr<Executor> exec;
+ TF_RETURN_IF_ERROR(NewExecutor(executor_type, params, std::move(g), &exec));
{
// Guard item since it is already inserted in items_.
mutex_lock l(mu_);
- if ((*item)->exec) {
- delete exec;
- } else {
+ if ((*item)->exec == nullptr) {
(*item)->graph = graph;
- (*item)->exec = exec;
+ (*item)->exec = exec.release();
}
}
return Status::OK();
@@ -726,6 +727,25 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
+ std::vector<AllocatorAttributes> args_alloc_attrs, rets_alloc_attrs;
+ args_alloc_attrs.reserve(fbody->arg_types.size());
+ rets_alloc_attrs.reserve(fbody->ret_types.size());
+ // Note: Functions assume that int32's are always on host memory.
+ for (const auto& arg_type : fbody->arg_types) {
+ AllocatorAttributes arg_alloc_attrs;
+ if (MTypeFromDType(arg_type) == HOST_MEMORY) {
+ arg_alloc_attrs.set_on_host(true);
+ }
+ args_alloc_attrs.push_back(arg_alloc_attrs);
+ }
+ for (const auto& ret_type : fbody->ret_types) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (MTypeFromDType(ret_type) == HOST_MEMORY) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
+
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
@@ -733,10 +753,10 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", src_incarnation, args.size(),
- device_context, {}, rendezvous, remote_args,
+ device_context, args_alloc_attrs, rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device,
- target_incarnation, rendezvous, device_context, rets, done,
- exec_args](const Status& status) {
+ target_incarnation, rendezvous, device_context, rets, done, exec_args,
+ rets_alloc_attrs](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->SetArgs(*remote_args);
@@ -749,9 +769,10 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
return;
}
item->exec->RunAsync(
- *exec_args, [frame, rets, done, source_device, target_device,
- target_incarnation, rendezvous, device_context,
- remote_args, exec_args](const Status& status) {
+ *exec_args,
+ [frame, rets, done, source_device, target_device,
+ target_incarnation, rendezvous, device_context, remote_args,
+ exec_args, rets_alloc_attrs](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -765,7 +786,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
}
s = ProcessFunctionLibraryRuntime::SendTensors(
target_device, source_device, "ret_", target_incarnation,
- *rets, device_context, {}, rendezvous);
+ *rets, device_context, rets_alloc_attrs, rendezvous);
delete remote_args;
delete exec_args;
done(s);
@@ -1186,11 +1207,13 @@ static bool ValidateInlining(const Node* node, const FunctionBody* fbody) {
return true;
}
-// Given a "caller" in "graph", which is a function call of a function
+// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
-// edges properly.
+// edges properly. "override_device" specifies whether inlining should replace
+// explicitly specified devices inside fbody with the callee's device.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
- Node* caller, const FunctionBody* fbody) {
+ Node* caller, const FunctionBody* fbody,
+ bool override_device) {
if (!ValidateInlining(caller, fbody)) {
LOG(WARNING) << "Inlining mismatch: " << caller->DebugString() << " vs. "
<< DebugString(fbody->graph);
@@ -1225,7 +1248,9 @@ void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
for (Node* n : fbody->graph->op_nodes()) {
NodeDef ndef = n->def();
ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
- ndef.set_device(caller->def().device());
+ if (override_device || ndef.device().empty()) {
+ ndef.set_device(caller->def().device());
+ }
Node* clone = g->AddNode(ndef, &s);
TF_CHECK_OK(s);
node_map[n->id()] = clone;
@@ -1579,6 +1604,12 @@ FunctionBody* SymbolicGradientHelper::Compute() {
g->RemoveNode(n);
}
gbody_->ret_types = fbody_->arg_types;
+ // TODO(apassos): use the right dtype for gradients of resource variables
+ for (int i = 0; i < gbody_->ret_types.size(); ++i) {
+ if (gbody_->ret_types[i] == DT_RESOURCE) {
+ gbody_->ret_types[i] = DT_FLOAT;
+ }
+ }
gbody_->ret_nodes.clear();
// Add new return nodes to the function gradient body for each node
// in 'x_grad_nodes'.
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index a0f9fcae0a..a274f1ef51 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -155,9 +155,11 @@ FunctionBody* SymbolicGradient(const FunctionBody& f);
// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
-// edges properly.
+// edges properly. "override_device" specifies whether inlining should replace
+// explicitly specified devices inside fbody with the callee's device.
void InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
- Node* caller, const FunctionBody* fbody);
+ Node* caller, const FunctionBody* fbody,
+ bool override_device = true);
// Instantiates FunctionDef into a graph. Set *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index f4f5198396..1e837e9a7e 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
@@ -531,6 +532,69 @@ TEST_F(FunctionLibraryRuntimeTest, StateHandle) {
}
}
+namespace {
+class DummyExecutorRegistrar {
+ public:
+ DummyExecutorRegistrar() {
+ ExecutorFactory::Register("DUMMY", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ return errors::Internal("This is a dummy.");
+ }
+ };
+};
+static DummyExecutorRegistrar registrar;
+} // namespace
+
+TEST_F(FunctionLibraryRuntimeTest, ExecutorFactory) {
+ Init({test::function::XTimesTwo()});
+
+ auto x = test::AsTensor<float>({1, 2, 3, 4});
+ Tensor y;
+
+ // Test that the default executor works.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test the explicit registration for the default executor.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DEFAULT";
+ TF_CHECK_OK(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}},
+ options, {x}, {&y}));
+ test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
+ }
+
+ // Test that a non-default executor factory can be invoked.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "DUMMY";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Internal: This is a dummy.");
+ }
+
+ // Test that non-existent exector types trigger an error.
+ {
+ FunctionLibraryRuntime::InstantiateOptions options;
+ options.executor_type = "UNKNOWN_EXECUTOR";
+ HasError(InstantiateAndRun(flr0_, "XTimesTwo", {{"T", DT_FLOAT}}, options,
+ {x}, {&y}),
+ "Not found: No executor factory registered for the given executor "
+ "type: UNKNOWN_EXECUTOR");
+ }
+}
+
TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctions) {
Init({test::function::XTimesTwo(), test::function::XTimesFour(),
test::function::XTimes16()});
@@ -803,7 +867,7 @@ TEST_F(FunctionLibraryRuntimeTest, OptimizeGraph) {
Scope s = Scope::NewRootScope();
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto x4_x2_scale = ops::Const<float>(
- s.WithOpName("x4/x2/scale/_12__cf__6")
+ s.WithOpName("x4/x2/scale/_12__cf__10")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto x4_x2_y = ops::Mul(s.WithOpName("x4/x2/y"), x, x4_x2_scale);
@@ -913,7 +977,7 @@ TEST_F(FunctionLibraryRuntimeTest, Error_NotFound) {
"Not found: Function Foo is not defined.");
}
-TEST_F(FunctionLibraryRuntimeTest, Error_InstantiaionError) {
+TEST_F(FunctionLibraryRuntimeTest, Error_InstantiationError) {
auto bad_x_times_two = FDH::Define(
// Name
"XTimesTwo",
@@ -1009,13 +1073,13 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
auto x = ops::_Arg(s.WithOpName("x"), DT_FLOAT, 0);
auto func0 = ops::_Arg(s.WithOpName("Func/_0"), DT_FLOAT, 1);
auto scale = ops::Const(
- s.WithOpName("scale/_6__cf__11")
+ s.WithOpName("scale/_6__cf__15")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
2.0f);
auto func1_gx = ops::Mul(s.WithOpName("Func/_1/gx"), func0, scale);
auto func1_sx = ops::Shape(s.WithOpName("Func/_1/sx"), x);
auto const0 = ops::Const(
- s.WithOpName("Func/_1/sy/_5__cf__10")
+ s.WithOpName("Func/_1/sy/_5__cf__14")
.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"),
0, {0});
auto func1_rx = ops::internal::BroadcastGradientArgs(
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
new file mode 100644
index 0000000000..636cd43575
--- /dev/null
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+// Allocator for pinned CPU RAM that is made known to CUDA for the
+// purpose of efficient DMA with a GPU.
+class CUDAHostAllocator : public SubAllocator {
+ public:
+ // Note: stream_exec cannot be null.
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
+ : stream_exec_(stream_exec) {
+ CHECK(stream_exec_ != nullptr);
+ }
+ ~CUDAHostAllocator() override {}
+
+ void* Alloc(size_t alignment, size_t num_bytes) override {
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = stream_exec_->HostMemoryAllocate(num_bytes);
+ if (ptr == nullptr) {
+ LOG(WARNING) << "could not allocate pinned host memory of size: "
+ << num_bytes;
+ }
+ }
+ return ptr;
+ }
+
+ void Free(void* ptr, size_t num_bytes) override {
+ if (ptr != nullptr) {
+ stream_exec_->HostMemoryDeallocate(ptr);
+ }
+ }
+
+ private:
+ se::StreamExecutor* stream_exec_; // not owned, non-null
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_CUDA_HOST_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index bee5627636..3cb51b0dbc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -36,9 +36,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_stream_util.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -201,7 +201,8 @@ class BaseGPUDevice::StreamGroupFactory {
// This function is thread safe.
BaseGPUDevice::StreamGroup* GetOrCreate(TfGpuId tf_gpu_id,
int stream_group_within_gpu,
- se::StreamExecutor* executor) {
+ se::StreamExecutor* executor,
+ const GPUOptions& options) {
mutex_lock guard(lock_);
StreamGroup* group =
&streams_[key_type(tf_gpu_id.value(), stream_group_within_gpu)];
@@ -221,10 +222,21 @@ class BaseGPUDevice::StreamGroupFactory {
VLOG(2) << "Created device_to_host_stream[" << stream_group_within_gpu
<< "] = " << group->device_to_host;
- group->device_to_device = new se::Stream(executor);
- group->device_to_device->Init();
- VLOG(2) << "Created device_to_device_stream[" << stream_group_within_gpu
- << "] = " << group->device_to_host;
+ int num_d2d_streams =
+ options.experimental().num_dev_to_dev_copy_streams();
+ if (num_d2d_streams < 1 || num_d2d_streams > 4) {
+ LOG(ERROR)
+ << "Illegal GPUOptions.experimental.num_dev_to_dev_copy_streams="
+ << num_d2d_streams << " set to 1 instead.";
+ num_d2d_streams = 1;
+ }
+ for (int i = 0; i < num_d2d_streams; ++i) {
+ se::Stream* stream = new se::Stream(executor);
+ stream->Init();
+ group->device_to_device.push_back(stream);
+ VLOG(2) << "Created device_to_device_stream[" << stream_group_within_gpu
+ << "] = " << group->device_to_device.back();
+ }
}
return group;
}
@@ -262,7 +274,7 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
tf_gpu_id_(tf_gpu_id),
sync_every_op_(sync_every_op),
max_streams_(max_streams) {
- ProcessState::singleton()->EnableGPUDevice();
+ GPUProcessState::singleton()->EnableGPUDevice();
}
BaseGPUDevice::~BaseGPUDevice() {
@@ -287,8 +299,8 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
// Create the specified number of GPU streams
for (int i = 0; i < max_streams_; i++) {
- streams_.push_back(
- StreamGroupFactory::Global().GetOrCreate(tf_gpu_id_, i, executor_));
+ streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
+ tf_gpu_id_, i, executor_, options.config.gpu_options()));
size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
void* scratch_buffer = gpu_allocator_->AllocateRaw(
@@ -1060,7 +1072,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
se::StreamExecutor* se =
GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
- ProcessState* process_state = ProcessState::singleton();
+ GPUProcessState* process_state = GPUProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
options.config.gpu_options(), tf_gpu_id, memory_limit);
if (gpu_allocator == nullptr) {
@@ -1080,7 +1092,7 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
BaseGPUDevice* gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
- process_state->GetCPUAllocator(numa_node));
+ ProcessState::singleton()->GetCPUAllocator(numa_node));
LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
<< GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 737a3515b6..56d03d7a8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -119,7 +120,7 @@ class BaseGPUDevice : public LocalDevice {
se::Stream* compute = nullptr;
se::Stream* host_to_device = nullptr;
se::Stream* device_to_host = nullptr;
- se::Stream* device_to_device = nullptr;
+ gtl::InlinedVector<se::Stream*, 4> device_to_device;
};
class StreamGroupFactory;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index 9a000749c6..e1aaf95df6 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
namespace tensorflow {
@@ -40,9 +40,10 @@ class GPUDevice : public BaseGPUDevice {
}
Allocator* GetAllocator(AllocatorAttributes attr) override {
+ CHECK(cpu_allocator_) << "bad place 1";
if (attr.on_host()) {
if (attr.gpu_compatible() || force_gpu_compatible_) {
- ProcessState* ps = ProcessState::singleton();
+ GPUProcessState* ps = GPUProcessState::singleton();
return ps->GetCUDAHostAllocator(0);
} else {
return cpu_allocator_;
@@ -90,7 +91,7 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
~GPUCompatibleCPUDevice() override {}
Allocator* GetAllocator(AllocatorAttributes attr) override {
- ProcessState* ps = ProcessState::singleton();
+ GPUProcessState* ps = GPUProcessState::singleton();
if (attr.gpu_compatible() || force_gpu_compatible_) {
return ps->GetCUDAHostAllocator(0);
} else {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 5c6cb43eff..daf59f0560 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -58,7 +58,7 @@ void ExpectErrorMessageSubstr(const Status& s, StringPiece substr) {
class GPUDeviceTest : public ::testing::Test {
public:
- void TearDown() override { ProcessState::singleton()->TestOnlyReset(); }
+ void TearDown() override { GPUProcessState::singleton()->TestOnlyReset(); }
protected:
static SessionOptions MakeSessionOptions(
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index 2b442071e2..b18688174d 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include <cstring>
#include <vector>
+#include "tensorflow/core/common_runtime/gpu/cuda_host_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h"
@@ -25,7 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
-#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
@@ -37,19 +38,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
-// If these flags need to be runtime configurable, consider adding
-// options to ConfigProto.
-
-// If true, register CPU RAM used to copy to/from GPU RAM with the
-// CUDA driver.
-const bool FLAGS_brain_mem_reg_cuda_dma = true;
-
-// If true, record attributes of memory allocations and
-// dynamically check for appropriate use of registered memory.
-// Should only be true for debugging or diagnosis of
-// performance issues.
-const bool FLAGS_brain_gpu_record_mem_types = false;
-
namespace tensorflow {
namespace {
@@ -67,46 +55,37 @@ bool useCudaMemoryGuardAllocator() {
} // namespace
-ProcessState* ProcessState::instance_ = nullptr;
+GPUProcessState* GPUProcessState::instance_ = nullptr;
-/*static*/ ProcessState* ProcessState::singleton() {
+/*static*/ GPUProcessState* GPUProcessState::singleton() {
if (instance_ == nullptr) {
- instance_ = new ProcessState;
+ instance_ = new GPUProcessState;
}
+ CHECK(instance_->process_state_);
return instance_;
}
-ProcessState::ProcessState() : gpu_device_enabled_(false) {
+GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
CHECK(instance_ == nullptr);
instance_ = this;
+ process_state_ = ProcessState::singleton();
}
-ProcessState::~ProcessState() {
+// Normally the GPUProcessState singleton is never explicitly deleted.
+// This function is defined for debugging problems with the allocators.
+GPUProcessState::~GPUProcessState() {
+ CHECK_EQ(this, instance_);
for (auto p : gpu_allocators_) {
delete p;
}
instance_ = nullptr;
}
-string ProcessState::MemDesc::DebugString() {
- return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index,
- ", dma: ", gpu_registered, ", nic: ", nic_registered);
-}
-
-ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
- if (FLAGS_brain_gpu_record_mem_types) {
- auto iter = mem_desc_map_.find(ptr);
- if (iter != mem_desc_map_.end()) {
- return iter->second;
- }
- }
- return MemDesc();
-}
-
-Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
- TfGpuId tf_gpu_id,
- size_t total_bytes) {
+Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
+ TfGpuId tf_gpu_id,
+ size_t total_bytes) {
+ CHECK(process_state_);
#if GOOGLE_CUDA
const string& allocator_type = options.allocator_type();
mutex_lock lock(mu_);
@@ -114,7 +93,8 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
gpu_allocators_.resize(tf_gpu_id.value() + 1);
- if (FLAGS_brain_gpu_record_mem_types) gpu_al_.resize(tf_gpu_id.value() + 1);
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
+ gpu_al_.resize(tf_gpu_id.value() + 1);
}
if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
@@ -155,9 +135,9 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
gpu_allocator->AddAllocVisitor(v);
}
}
- if (FLAGS_brain_gpu_record_mem_types) {
- MemDesc md;
- md.loc = MemDesc::GPU;
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ ProcessState::MemDesc md;
+ md.loc = ProcessState::MemDesc::GPU;
md.dev_index = cuda_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
@@ -165,10 +145,11 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
gpu_al_.resize(tf_gpu_id.value() + 1);
}
gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
- &mem_desc_map_, gpu_allocator, md, &mu_);
+ &process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
}
- if (FLAGS_brain_gpu_record_mem_types) return gpu_al_[tf_gpu_id.value()];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
+ return gpu_al_[tf_gpu_id.value()];
return gpu_allocators_[tf_gpu_id.value()];
#else
LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
@@ -176,64 +157,13 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
#endif // GOOGLE_CUDA
}
-Allocator* ProcessState::GetCPUAllocator(int numa_node) {
- // Although we're temporarily ignoring numa_node, check for legality.
- CHECK_GE(numa_node, 0);
- // TODO(tucker): actually maintain separate CPUAllocators for
- // different numa_nodes. For now, just one.
- numa_node = 0;
- mutex_lock lock(mu_);
- while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
- bool use_bfc_allocator = false;
- // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
- // efficient.
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
- &use_bfc_allocator);
- if (!status.ok()) {
- LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
- }
- VisitableAllocator* allocator;
- if (use_bfc_allocator) {
- // TODO(reedwm): evaluate whether 64GB by default is the best choice.
- int64 cpu_mem_limit_in_mb = -1;
- Status status = ReadInt64FromEnvVar("TF_CPU_BFC_MEM_LIMIT_IN_MB",
- 1LL << 16 /*64GB max by default*/,
- &cpu_mem_limit_in_mb);
- if (!status.ok()) {
- LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
- }
- int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
- allocator = new BFCAllocator(new BasicCPUAllocator(), cpu_mem_limit,
- true /*allow_growth*/,
- "bfc_cpu_allocator_for_gpu" /*name*/);
- VLOG(2) << "Using BFCAllocator with memory limit of "
- << cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
- } else {
- allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(), new NoopRounder, "cpu_pool");
- VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator";
- }
- if (LogMemory::IsEnabled()) {
- // Wrap the allocator to track allocation ids for better logging
- // at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
- }
- cpu_allocators_.push_back(allocator);
+Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
+ CHECK(process_state_);
+ if (!HasGPUDevice() ||
+ !process_state_->ProcessState::FLAGS_brain_mem_reg_cuda_dma) {
+ return process_state_->GetCPUAllocator(numa_node);
}
- return cpu_allocators_[0];
-}
-
-Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
- if (!HasGPUDevice() || !FLAGS_brain_mem_reg_cuda_dma) {
- return cpu_allocator();
- }
- // Although we're temporarily ignoring numa_node, check for legality.
CHECK_GE(numa_node, 0);
- // TODO(tucker): actually maintain separate CPUAllocators for
- // different numa_nodes. For now, just one.
- numa_node = 0;
-
{
// Here we optimize the most common use case where cuda_host_allocators_
// and cuda_al_ have already been populated and since we're only reading
@@ -241,7 +171,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
// we take a unique lock and populate these vectors.
tf_shared_lock lock(mu_);
- if (FLAGS_brain_gpu_record_mem_types &&
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
static_cast<int>(cuda_al_.size()) > 0) {
return cuda_al_[0];
}
@@ -288,21 +218,25 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
allocator = new TrackingVisitableAllocator(allocator, true);
}
cuda_host_allocators_.push_back(allocator);
- if (FLAGS_brain_gpu_record_mem_types) {
- MemDesc md;
- md.loc = MemDesc::CPU;
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ ProcessState::MemDesc md;
+ md.loc = ProcessState::MemDesc::CPU;
md.dev_index = 0;
md.gpu_registered = true;
md.nic_registered = false;
cuda_al_.push_back(new internal::RecordingAllocator(
- &mem_desc_map_, cuda_host_allocators_.back(), md, &mu_));
+ &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
+ &mu_));
}
}
- if (FLAGS_brain_gpu_record_mem_types) return cuda_al_[0];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
+ return cuda_al_[0];
return cuda_host_allocators_[0];
}
-void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) {
+void GPUProcessState::AddGPUAllocVisitor(int bus_id,
+ const AllocVisitor& visitor) {
+ CHECK(process_state_);
#if GOOGLE_CUDA
mutex_lock lock(mu_);
for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
@@ -320,17 +254,17 @@ void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) {
#endif // GOOGLE_CUDA
}
-void ProcessState::TestOnlyReset() {
- mutex_lock lock(mu_);
- gpu_device_enabled_ = false;
- gpu_visitors_.clear();
- mem_desc_map_.clear();
- gtl::STLDeleteElements(&cpu_allocators_);
- gtl::STLDeleteElements(&gpu_allocators_);
- gtl::STLDeleteElements(&cuda_host_allocators_);
- gtl::STLDeleteElements(&cpu_al_);
- gtl::STLDeleteElements(&gpu_al_);
- gtl::STLDeleteElements(&cuda_al_);
+void GPUProcessState::TestOnlyReset() {
+ process_state_->ProcessState::TestOnlyReset();
+ {
+ mutex_lock lock(mu_);
+ gpu_device_enabled_ = false;
+ gpu_visitors_.clear();
+ gtl::STLDeleteElements(&gpu_allocators_);
+ gtl::STLDeleteElements(&cuda_host_allocators_);
+ gtl::STLDeleteElements(&gpu_al_);
+ gtl::STLDeleteElements(&cuda_al_);
+ }
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index bc2c4182d7..cb41c3c6bd 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_
#include <functional>
#include <map>
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/process_state.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -34,27 +35,10 @@ class Allocator;
class VisitableAllocator;
class PoolAllocator;
-// Singleton that manages per-process state, e.g. allocation
-// of shared resources.
-class ProcessState {
+// Singleton that manages per-process state when GPUs are present.
+class GPUProcessState {
public:
- static ProcessState* singleton();
-
- // Descriptor for memory allocation attributes, used by optional
- // runtime correctness analysis logic.
- struct MemDesc {
- enum MemLoc { CPU, GPU };
- MemLoc loc;
- int dev_index;
- bool gpu_registered;
- bool nic_registered;
- MemDesc()
- : loc(CPU),
- dev_index(0),
- gpu_registered(false),
- nic_registered(false) {}
- string DebugString();
- };
+ static GPUProcessState* singleton();
// Query whether any GPU device has been created so far.
// Disable thread safety analysis since a race is benign here.
@@ -68,14 +52,6 @@ class ProcessState {
gpu_device_enabled_ = true;
}
- // Returns what we know about the memory at ptr.
- // If we know nothing, it's called CPU 0 with no other attributes.
- MemDesc PtrType(const void* ptr);
-
- // Returns the one CPUAllocator used for the given numa_node.
- // TEMPORARY: ignores numa_node.
- Allocator* GetCPUAllocator(int numa_node);
-
// Returns the one GPU allocator used for the indexed GPU.
// Note that this is a system GPU index, not (necessarily) a brain
// device index.
@@ -107,69 +83,39 @@ class ProcessState {
// the index of one of the PCIe buses. If the bus_id is invalid,
// results are undefined.
typedef std::function<void(void*, size_t)> AllocVisitor;
- virtual void AddGPUAllocVisitor(int bus_id, AllocVisitor visitor);
-
- typedef std::unordered_map<const void*, MemDesc> MDMap;
+ virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
protected:
- ProcessState();
+ GPUProcessState();
// Helper method for unit tests to reset the ProcessState singleton by
// cleaning up everything. Never use in production.
virtual void TestOnlyReset();
- static ProcessState* instance_;
+ ProcessState::MDMap* mem_desc_map() {
+ if (process_state_) return &process_state_->mem_desc_map_;
+ return nullptr;
+ }
+
+ static GPUProcessState* instance_;
+ ProcessState* process_state_; // Not owned.
bool gpu_device_enabled_;
mutex mu_;
- std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
- virtual ~ProcessState();
+ virtual ~GPUProcessState();
// Optional RecordingAllocators that wrap the corresponding
// Allocators for runtime attribute use analysis.
- MDMap mem_desc_map_;
- std::vector<Allocator*> cpu_al_ GUARDED_BY(mu_);
std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
friend class GPUDeviceTest;
};
-namespace internal {
-class RecordingAllocator : public Allocator {
- public:
- RecordingAllocator(ProcessState::MDMap* mm, Allocator* a,
- ProcessState::MemDesc md, mutex* mu)
- : mm_(mm), a_(a), md_(md), mu_(mu) {}
-
- string Name() override { return a_->Name(); }
- void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- void* p = a_->AllocateRaw(alignment, num_bytes);
- mutex_lock l(*mu_);
- (*mm_)[p] = md_;
- return p;
- }
- void DeallocateRaw(void* p) override {
- mutex_lock l(*mu_);
- auto iter = mm_->find(p);
- mm_->erase(iter);
- a_->DeallocateRaw(p);
- }
- bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); }
- size_t RequestedSize(const void* p) override { return a_->RequestedSize(p); }
- size_t AllocatedSize(const void* p) override { return a_->AllocatedSize(p); }
- void GetStats(AllocatorStats* stats) override { a_->GetStats(stats); }
- void ClearStats() override { a_->ClearStats(); }
- ProcessState::MDMap* mm_; // not owned
- Allocator* a_; // not owned
- ProcessState::MemDesc md_;
- mutex* mu_;
-};
-} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_PROCESS_STATE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_PROCESS_STATE_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index d38413d79c..5851360cab 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
-#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -150,7 +150,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
const int64 total_bytes = is_dead ? 0 : tensor.TotalBytes();
if (total_bytes > 0) {
tracing::ScopedAnnotation annotation("SetProtoFromGPU");
- alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
+ alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
buf = alloc->Allocate<char>(total_bytes);
if (LogMemory::IsEnabled()) {
LogMemory::RecordRawAllocation("SetProtoFromGPU",
@@ -185,13 +185,11 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
}
// static
-void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
- DeviceContext* recv_dev_context, Device* src,
- Device* dst,
- AllocatorAttributes src_alloc_attr,
- AllocatorAttributes dst_alloc_attr,
- const Tensor* input, Tensor* output,
- StatusCallback done) {
+void GPUUtil::DeviceToDeviceCopy(
+ DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
+ Device* src, Device* dst, AllocatorAttributes src_alloc_attr,
+ AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output,
+ int dev_to_dev_stream_index, StatusCallback done) {
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
se::Stream* send_stream = nullptr;
Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info,
@@ -202,7 +200,7 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
}
auto send_device_to_device_stream =
static_cast<const GPUDeviceContext*>(send_dev_context)
- ->device_to_device_stream();
+ ->device_to_device_stream(dev_to_dev_stream_index);
if (send_device_to_device_stream == nullptr) {
done(errors::Internal("No send gpu copy-out-stream is available."));
return;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
index 237b0044da..57687a8364 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -90,13 +90,11 @@ class GPUUtil {
Device* gpu_device, Tensor* gpu_tensor,
StatusCallback done);
- static void DeviceToDeviceCopy(DeviceContext* send_dev_context,
- DeviceContext* recv_dev_context, Device* src,
- Device* dst,
- AllocatorAttributes src_alloc_attr,
- AllocatorAttributes dst_alloc_attr,
- const Tensor* input, Tensor* output,
- StatusCallback done);
+ static void DeviceToDeviceCopy(
+ DeviceContext* send_dev_context, DeviceContext* recv_dev_context,
+ Device* src, Device* dst, AllocatorAttributes src_alloc_attr,
+ AllocatorAttributes dst_alloc_attr, const Tensor* input, Tensor* output,
+ int dev_to_dev_stream_index, StatusCallback done);
// Deep-copying of GPU tensor on the same device.
// 'src_gpu_tensor''s and 'dst_gpu_tensor''s backing memory must be on
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index a4c8d5fe86..583bff2c07 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -15,8 +15,9 @@ limitations under the License.
#if GOOGLE_CUDA
-#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
+#include "tensorflow/core/common_runtime/gpu/cuda_host_allocator.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
@@ -96,7 +97,8 @@ TEST(PoolAllocatorTest, Alignment) {
TEST(PoolAllocatorTest, AutoResize) {
PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator, new NoopRounder, "pool");
+ new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
+ "pool");
// Alloc/dealloc 10 sizes just a few times, confirming pool size
// stays at 2.
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index c92c5d1af3..d697d878dc 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace stream_executor {
class Stream;
@@ -31,7 +32,7 @@ class GPUDeviceContext : public DeviceContext {
GPUDeviceContext(int stream_id, se::Stream* stream,
se::Stream* host_to_device_stream,
se::Stream* device_to_host_stream,
- se::Stream* device_to_device_stream)
+ gtl::InlinedVector<se::Stream*, 4> device_to_device_stream)
: stream_id_(stream_id),
stream_(stream),
host_to_device_stream_(host_to_device_stream),
@@ -43,8 +44,8 @@ class GPUDeviceContext : public DeviceContext {
se::Stream* stream() const override { return stream_; }
se::Stream* host_to_device_stream() const { return host_to_device_stream_; }
se::Stream* device_to_host_stream() const { return device_to_host_stream_; }
- se::Stream* device_to_device_stream() const {
- return device_to_device_stream_;
+ se::Stream* device_to_device_stream(int index) const {
+ return device_to_device_stream_[index % device_to_device_stream_.size()];
}
int stream_id() const { return stream_id_; }
@@ -64,12 +65,12 @@ class GPUDeviceContext : public DeviceContext {
// The default primary stream to use for this context.
// All the memory belongs to this stream.
se::Stream* stream_;
- // The stream to use for copy data from host into GPU.
+ // The stream to use for copying data from host into GPU.
se::Stream* host_to_device_stream_;
- // The stream to use for copy data from GPU to host.
+ // The stream to use for copying data from GPU to host.
se::Stream* device_to_host_stream_;
- // The stream to use for copy data between GPU.
- se::Stream* device_to_device_stream_;
+ // Streams to use for copying data between GPUs.
+ gtl::InlinedVector<se::Stream*, 4> device_to_device_stream_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index eb710bdbc5..9c9eacb5b5 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -43,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/util/util.h"
#ifndef IS_MOBILE_PLATFORM
-#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
@@ -281,6 +280,118 @@ class TensorConnectionPruneRewrite : public subgraph::PruneRewrite {
NodeBuilder::NodeOut from_tensor_;
};
+template <class Map>
+Status LookupDevice(const DeviceSet& device_set, const string& tensor_name,
+ const Map& tensor2device,
+ const tensorflow::DeviceAttributes** out_device_attrs) {
+ *out_device_attrs = nullptr;
+ if (tensor2device.empty()) {
+ *out_device_attrs = &device_set.client_device()->attributes();
+ return Status::OK();
+ }
+ const auto it = tensor2device.find(tensor_name);
+ if (it == tensor2device.end()) {
+ *out_device_attrs = &device_set.client_device()->attributes();
+ return Status::OK();
+ }
+ DeviceNameUtils::ParsedName parsed_name;
+ if (!DeviceNameUtils::ParseFullName(it->second, &parsed_name)) {
+ return errors::InvalidArgument("Invalid device name ('", it->second,
+ "') provided for the tensor '", tensor_name,
+ "' in CallableOptions");
+ }
+ Device* device = device_set.FindDeviceByName(
+ DeviceNameUtils::ParsedNameToString(parsed_name));
+ if (device == nullptr) {
+ return errors::InvalidArgument("Device '", it->second,
+ "' specified for tensor '", tensor_name,
+ "' in CallableOptions does not exist");
+ }
+ *out_device_attrs = &device->attributes();
+ return Status::OK();
+}
+
+struct TensorAndDevice {
+ // WARNING: backing memory for the 'tensor' field is NOT owend.
+ const TensorId tensor;
+ // WARNING: device pointer is not owned, so must outlive TensorAndDevice.
+ const DeviceAttributes* device;
+};
+
+// Tensors of some DataTypes cannot placed in device memory as feeds or
+// fetches. Validate against a whitelist of those known to work.
+bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
+ // The mechanism for supporting feeds of device-backed Tensors requires
+ // the _Arg kernel to be registered for the corresponding type (and that
+ // the input to the kernel be in device and not host memory).
+ //
+ // The mechanism for supporting fetches of device-backed Tensors requires
+ // the _Retval kernel to be registered for the corresponding type (and
+ // that the output is produced in device and not host memory).
+ //
+ // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
+ // the device. False negatives are okay, false positives would be bad.
+ //
+ // TODO(ashankar): Instead of a whitelist here, perhaps we could query
+ // the kernel registry for _Arg and _Retval kernels instead.
+ if (device_type == DEVICE_CPU) return true;
+ if (device_type != DEVICE_GPU) return false;
+ switch (dtype) {
+ case DT_BFLOAT16:
+ case DT_BOOL:
+ case DT_COMPLEX128:
+ case DT_COMPLEX64:
+ case DT_DOUBLE:
+ case DT_FLOAT:
+ case DT_HALF:
+ case DT_INT16:
+ case DT_INT64:
+ case DT_INT8:
+ case DT_UINT16:
+ case DT_UINT8:
+ return true;
+ default:
+ return false;
+ }
+}
+
+Status ValidateFeedAndFetchDevices(
+ const Graph& graph,
+ const std::vector<TensorAndDevice>& tensors_and_devices) {
+ if (tensors_and_devices.empty()) return Status::OK();
+ std::vector<bool> found(tensors_and_devices.size(), false);
+ for (const Node* node : graph.nodes()) {
+ // Linearly looping through all nodes and then all feed+fetch tensors isn't
+ // quite efficient. At the time of this writing, the expectation was that
+ // tensors_and_devices.size() is really small in practice, so this won't be
+ // problematic.
+ // Revist and make a more efficient lookup possible if needed (e.g., perhaps
+ // Graph can maintain a map from node name to Node*).
+ for (int i = 0; i < tensors_and_devices.size(); ++i) {
+ const TensorAndDevice& td = tensors_and_devices[i];
+ if (td.tensor.first != node->name()) continue;
+ found[i] = true;
+ TF_RETURN_IF_ERROR(graph.IsValidOutputTensor(node, td.tensor.second));
+ const DataType dtype = node->output_type(td.tensor.second);
+ if (!IsFeedAndFetchSupported(dtype, td.device->device_type())) {
+ return errors::Unimplemented(
+ "Cannot feed or fetch tensor '", td.tensor.ToString(),
+ "' from device ", td.device->name(), " as feeding/fetching from ",
+ td.device->device_type(), " devices is not yet supported for ",
+ DataTypeString(dtype), " tensors");
+ }
+ }
+ }
+ for (int i = 0; i < found.size(); ++i) {
+ if (!found[i]) {
+ return errors::InvalidArgument(
+ "Tensor ", tensors_and_devices[i].tensor.ToString(),
+ ", specified in either feed_devices or fetch_devices was not found "
+ "in the Graph");
+ }
+ }
+ return Status::OK();
+}
} // namespace
Status GraphExecutionState::PruneGraph(
@@ -290,18 +401,52 @@ Status GraphExecutionState::PruneGraph(
feed_rewrites.reserve(options.callable_options.feed_size());
std::vector<std::unique_ptr<subgraph::PruneRewrite>> fetch_rewrites;
fetch_rewrites.reserve(options.callable_options.fetch_size());
- const DeviceAttributes* device_info =
- &device_set_->client_device()->attributes();
if (options.use_function_convention) {
+ std::vector<TensorAndDevice> tensors_and_devices;
for (int i = 0; i < options.callable_options.feed_size(); ++i) {
- feed_rewrites.emplace_back(new subgraph::ArgFeedRewrite(
- &options.callable_options.feed(i), device_info, i));
+ // WARNING: feed MUST be a reference, since ArgFeedRewrite and
+ // tensors_and_devices holds on to its address.
+ const string& feed = options.callable_options.feed(i);
+ const DeviceAttributes* device_info;
+ TF_RETURN_IF_ERROR(LookupDevice(*device_set_, feed,
+ options.callable_options.feed_devices(),
+ &device_info));
+ feed_rewrites.emplace_back(
+ new subgraph::ArgFeedRewrite(&feed, device_info, i));
+ tensors_and_devices.push_back({ParseTensorName(feed), device_info});
+ }
+ if (!options.callable_options.fetch_devices().empty() &&
+ !options.callable_options.fetch_skip_sync()) {
+ return errors::Unimplemented(
+ "CallableOptions.fetch_skip_sync = false is not yet implemented. You "
+ "can set it to true instead, but MUST ensure that Device::Sync() is "
+ "invoked on the Device corresponding to the fetched tensor before "
+ "dereferencing the Tensor's memory.");
}
for (int i = 0; i < options.callable_options.fetch_size(); ++i) {
- fetch_rewrites.emplace_back(new subgraph::RetvalFetchRewrite(
- &options.callable_options.fetch(i), device_info, i));
+ // WARNING: fetch MUST be a reference, since RetvalFetchRewrite and
+ // tensors_and_devices holds on to its address.
+ const string& fetch = options.callable_options.fetch(i);
+ const DeviceAttributes* device_info;
+ TF_RETURN_IF_ERROR(LookupDevice(*device_set_, fetch,
+ options.callable_options.fetch_devices(),
+ &device_info));
+ fetch_rewrites.emplace_back(
+ new subgraph::RetvalFetchRewrite(&fetch, device_info, i));
+ tensors_and_devices.push_back({ParseTensorName(fetch), device_info});
}
+ TF_RETURN_IF_ERROR(
+ ValidateFeedAndFetchDevices(*graph, tensors_and_devices));
} else {
+ if (!options.callable_options.feed_devices().empty() ||
+ !options.callable_options.fetch_devices().empty()) {
+ return errors::Unimplemented(
+ "CallableOptions::feed_devices and CallableOptions::fetch_devices "
+ "to configure feeding/fetching tensors to/from device memory is not "
+ "yet supported when using a remote session.");
+ }
+ const DeviceAttributes* device_info =
+ &device_set_->client_device()->attributes();
for (const string& feed : options.callable_options.feed()) {
feed_rewrites.emplace_back(
new subgraph::RecvFeedRewrite(&feed, device_info));
@@ -456,11 +601,11 @@ Status GraphExecutionState::OptimizeGraph(
return errors::InvalidArgument("Missing node shape or type");
}
TensorShapeProto shape_proto(node.attr().at("shape").shape());
- // If the shape of the placeholder value is only partially known, we're
- // free to use any dimension we want to feed the placeholder. We choose
- // 1 to minimize the memory impact. Note that this only matters if an
- // optimizer choose to run the graph to build its cost model, which
- // doesn't happen (yet)
+ // If the shape of the placeholder value is only partially known,
+ // we're free to use any dimension we want to feed the placeholder. We
+ // choose 1 to minimize the memory impact. Note that this only matters
+ // if an optimizer choose to run the graph to build its cost model,
+ // which doesn't happen (yet)
if (shape_proto.unknown_rank()) {
shape_proto.set_unknown_rank(false);
}
@@ -476,21 +621,15 @@ Status GraphExecutionState::OptimizeGraph(
}
}
- std::unordered_map<string, DeviceProperties> device_map;
Device* cpu_device = nullptr;
for (const auto& device : device_set_->devices()) {
- DeviceProperties props = grappler::GetDeviceInfo(device->parsed_name());
- if (props.type() == "UNKNOWN") {
- continue;
- }
- device_map[device->name()] = props;
if (device->parsed_name().id == 0 &&
StringPiece(device->parsed_name().type) == "CPU" &&
device->GetAllocator(AllocatorAttributes()) != nullptr) {
cpu_device = device;
}
}
- grappler::VirtualCluster cluster(device_map, device_set_);
+ grappler::VirtualCluster cluster(device_set_);
GraphDef new_graph;
TF_RETURN_IF_ERROR(grappler::RunMetaOptimizer(
item, rewrite_options, cpu_device, &cluster, &new_graph));
@@ -520,10 +659,10 @@ Status GraphExecutionState::OptimizeGraph(
opts.allow_internal_ops = true;
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, new_graph, optimized_graph->get()));
- // The graph conversion sets the requested device names but not the assigned
- // device names. However, since at this point the graph is placed TF expects
- // an assigned device name for every node. Therefore we copy the requested
- // device into the assigned device field.
+ // The graph conversion sets the requested device names but not the
+ // assigned device names. However, since at this point the graph is placed
+ // TF expects an assigned device name for every node. Therefore we copy
+ // the requested device into the assigned device field.
for (Node* node : optimized_graph->get()->nodes()) {
node->set_assigned_device_name(node->requested_device());
}
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 7de1b80e2d..1f585a8c24 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -43,7 +44,7 @@ namespace test {
// TODO(hongm): Convert `g` and `init` to using std::unique_ptr.
Benchmark::Benchmark(const string& device, Graph* g,
const SessionOptions* options, Graph* init,
- Rendezvous* rendez) {
+ Rendezvous* rendez, const char* executor_type) {
SessionOptions default_options;
if (!options) {
options = &default_options;
@@ -86,23 +87,26 @@ Benchmark::Benchmark(const string& device, Graph* g,
};
if (init) {
- Executor* init_exec;
- TF_CHECK_OK(
- NewLocalExecutor(params, std::unique_ptr<Graph>(init), &init_exec));
+ std::unique_ptr<Executor> init_exec;
+ TF_CHECK_OK(NewExecutor(executor_type, params, std::unique_ptr<Graph>(init),
+ &init_exec));
Executor::Args args;
args.rendezvous = rendez_;
args.runner = runner;
TF_CHECK_OK(init_exec->Run(args));
- delete init_exec;
}
- TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr<Graph>(g), &exec_));
+ TF_CHECK_OK(
+ NewExecutor(executor_type, params, std::unique_ptr<Graph>(g), &exec_));
}
Benchmark::~Benchmark() {
if (device_) {
rendez_->Unref();
- delete exec_;
+ // We delete `exec_` before `device_` because the `exec_` destructor may
+ // run kernel destructors that may attempt to access state borrowed from
+ // `device_`, such as the resource manager.
+ exec_.reset();
delete device_;
delete pool_;
}
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 3a7b3a5ace..995a15a299 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -39,7 +39,7 @@ class Benchmark {
// "init", and one reference on "rendez" (if not null).
Benchmark(const string& device, Graph* g,
const SessionOptions* options = nullptr, Graph* init = nullptr,
- Rendezvous* rendez = nullptr);
+ Rendezvous* rendez = nullptr, const char* executor_type = "");
~Benchmark();
// Executes the graph for "iters" times.
@@ -57,7 +57,7 @@ class Benchmark {
thread::ThreadPool* pool_ = nullptr;
Device* device_ = nullptr;
Rendezvous* rendez_ = nullptr;
- Executor* exec_ = nullptr;
+ std::unique_ptr<Executor> exec_;
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
};
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index b5fee36ff4..dfce7c23e7 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -187,8 +187,7 @@ Status CondBuilder::AddOutputs() {
} else {
// Feed the outputs directly from the merge nodes so that downstream ops
// can start before all the outputs have been computed.
- graph_->AddEdge(merges[e->src_output()], e->src_output(), e->dst(),
- e->dst_input());
+ graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
}
}
return Status::OK();
@@ -207,7 +206,7 @@ Status InlineCallInGraph(Node* n, Graph* g) {
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
// explicit.
- InlineFunctionBody(g->flib_def(), g, n, fbody);
+ InlineFunctionBody(g->flib_def(), g, n, fbody, false);
delete fbody;
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator.cc
index 43a909466e..4ec85457ad 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.cc
@@ -17,6 +17,13 @@ limitations under the License.
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
+#ifdef _WIN32
+// Declare function to avoid unresolved symbol in VS
+i_malloc_t i_malloc;
+i_calloc_t i_calloc;
+i_realloc_t i_realloc;
+i_free_t i_free;
+#endif
namespace tensorflow {
constexpr const char* MklCPUAllocator::kMaxLimitStr;
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 245320c896..29f702699f 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -29,7 +29,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#ifndef DO_NOT_USE_ML
#include "i_malloc.h"
+#endif
#ifdef _WIN32
typedef unsigned int uint;
@@ -97,14 +99,14 @@ class MklCPUAllocator : public VisitableAllocator {
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
kAllowGrowth, kName);
-
+#ifndef DO_NOT_USE_ML
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
i_malloc = MallocHook;
i_calloc = CallocHook;
i_realloc = ReallocHook;
i_free = FreeHook;
-
+#endif
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 86851c2c07..1f0773d387 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -628,6 +628,40 @@ class ColocationGraph {
return parent;
}
+ // Ensures that the devices of 'dst's resource and reference match the device
+ // specified for 'src', which is an input of 'dst' with a partially or fully
+ // specified device.
+ Status VerifyResourceAndRefInputsCanBeColocated(
+ const Node* dst, const Node* src,
+ const DeviceNameUtils::ParsedName& src_parsed_name) {
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(dst->input_edges(&edges));
+ for (const Edge* edge : edges) {
+ DataType input_type = dst->input_type(edge->dst_input());
+ if (input_type == DT_RESOURCE || IsRefType(input_type)) {
+ const Node* input_node = edge->src();
+ if (input_node == src) {
+ continue;
+ }
+ const auto& input_root = members_[FindRoot(input_node->id())];
+ const auto& input_parsed_name = input_root.device_name;
+ if (DeviceNameUtils::HasSomeDetails(input_parsed_name) &&
+ !DeviceNameUtils::AreCompatibleDevNames(input_parsed_name,
+ src_parsed_name)) {
+ return AttachDef(
+ errors::InvalidArgument(
+ "Could not colocate node with its "
+ "resource and reference inputs; devices ",
+ DeviceNameUtils::ParsedNameToString(input_parsed_name),
+ " and ", DeviceNameUtils::ParsedNameToString(src_parsed_name),
+ " are not compatible."),
+ *dst);
+ }
+ }
+ }
+ return Status::OK();
+ }
+
Graph* const graph_; // Not owned.
std::vector<Member> members_;
const DeviceSet* device_set_; // Not owned.
@@ -646,6 +680,15 @@ bool IsGeneratorNode(const Node* node) {
!IsRefType(node->output_type(0));
}
+bool IsExemptFromResourceInputColocation(const Node* node) {
+ // Note: Partitioned function calls, which place and partition their
+ // function bodies, are exempt from this check: they forward resource and
+ // ref inputs to operations that are appropriately placed, instead of
+ // dereferencing them.
+ const string& op_type = node->op_def().name();
+ return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
+}
+
} // namespace
Placer::Placer(Graph* graph, const DeviceSet* devices,
@@ -680,8 +723,8 @@ Status Placer::Run() {
// 2. Enumerate the constraint edges, and use them to update the disjoint
// node set.
- // If `node` has an input edge with reference type, add an
- // edge from the source of that edge to `node`.
+ // If `node` has an input edge with reference type, add an edge from the
+ // source of that edge to `node`.
for (const Edge* edge : graph_->edges()) {
if (edge->IsControlEdge()) {
continue;
@@ -689,7 +732,10 @@ Status Placer::Run() {
Node* src = edge->src();
Node* dst = edge->dst();
DataType input_type = dst->input_type(edge->dst_input());
- if (input_type == DT_RESOURCE || IsRefType(input_type)) {
+ if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
+ !IsExemptFromResourceInputColocation(dst)) {
+ // Colocate `src` and `dst` to maintain the invariant that nodes connected
+ // by reference edges are colocated.
int src_root_id = colocation_graph.FindRoot(src->id());
int dst_root_id = colocation_graph.FindRoot(dst->id());
auto& src_root = colocation_graph.members_[src_root_id];
@@ -706,6 +752,9 @@ Status Placer::Run() {
// incompatible.
if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
dest_parsed_name)) {
+ TF_RETURN_IF_ERROR(
+ colocation_graph.VerifyResourceAndRefInputsCanBeColocated(
+ dst, src, source_parsed_name));
if (log_device_placement_) {
LOG(INFO) << "Ignoring device specification "
<< DeviceNameUtils::ParsedNameToString(dest_parsed_name)
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 5ad251c892..07a7724f16 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -575,6 +575,10 @@ REGISTER_KERNEL_BUILDER(Name("HandleAssignCPU").Device("FakeCPU"), DummyOp);
REGISTER_OP("HandleAssignGPU").Input("i: resource").Input("v: float");
REGISTER_KERNEL_BUILDER(Name("HandleAssignGPU").Device("FakeGPU"), DummyOp);
+REGISTER_OP("TestTwoHandlesIn").Input("i: resource").Input("j: resource");
+REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestTwoHandlesIn").Device("FakeGPU"), DummyOp);
+
// Tests all combinations of resource handles and ops using them.
TEST_F(PlacerTest, TestResourceHandle) {
auto handle_test = [this](const string& var_op_name,
@@ -609,6 +613,42 @@ TEST_F(PlacerTest, TestResourceHandle) {
handle_test("HandleVariableCPU", "HandleAssignGPU", "FakeCPU").ok());
}
+TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
+ auto handle_test = [this](bool allow_soft_placement) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* var_cpu =
+ ops::SourceOp("TestHandleVariable", b.opts().WithName("var_cpu"));
+ Node* var_gpu =
+ ops::SourceOp("TestHandleVariable", b.opts().WithName("var_gpu"));
+ ops::BinaryOp("TestTwoHandlesIn", var_cpu, var_gpu,
+ b.opts().WithName("two_handles_in"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+
+ GetNodeByName(g, "var_cpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakecpu:0");
+ GetNodeByName(g, "var_gpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakegpu:0");
+ }
+
+ SessionOptions options;
+ options.config.set_allow_soft_placement(allow_soft_placement);
+ options.config.set_log_device_placement(true);
+ Status s = Place(&g, &options);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Could not colocate node with its resource and reference inputs"));
+ return Status::OK();
+ };
+
+ TF_EXPECT_OK(handle_test(false));
+ TF_EXPECT_OK(handle_test(true));
+}
+
// Test that an assignment of an operator to the wrong device
// is ignored when it could never be satisfied (due to reference
// edges, for example).
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 66fff16e8f..10a24ed14c 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/gpu/pool_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include <errno.h>
#ifndef _MSC_VER
@@ -284,4 +284,12 @@ void PoolAllocator::AddFreeVisitor(Visitor visitor) {
free_visitors_.push_back(visitor);
}
+void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
+ return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+}
+
+void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
+ port::AlignedFree(ptr);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 310158aba1..607734445b 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface. GPU memory is managed
-// by GPURegionAllocator.
+// implement the VisitableAllocator interface.
#include <atomic>
#include <map>
@@ -28,9 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -168,48 +165,18 @@ class Pow2Rounder : public RoundUpInterface {
class BasicCPUAllocator : public SubAllocator {
public:
+ // Argument numa_node is currently ignored.
+ explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+
~BasicCPUAllocator() override {}
- void* Alloc(size_t alignment, size_t num_bytes) override {
- return port::AlignedMalloc(num_bytes, alignment);
- }
- void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
-};
+ void* Alloc(size_t alignment, size_t num_bytes) override;
-// Allocator for pinned CPU RAM that is made known to CUDA for the
-// purpose of efficient DMA with a GPU.
-class CUDAHostAllocator : public SubAllocator {
- public:
- // Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
- CHECK(stream_exec_ != nullptr);
- }
- ~CUDAHostAllocator() override {}
-
- void* Alloc(size_t alignment, size_t num_bytes) override {
- void* ptr = nullptr;
- if (num_bytes > 0) {
- ptr = stream_exec_->HostMemoryAllocate(num_bytes);
- if (ptr == nullptr) {
- LOG(WARNING) << "could not allocate pinned host memory of size: "
- << num_bytes;
- }
- }
- return ptr;
- }
-
- void Free(void* ptr, size_t num_bytes) override {
- if (ptr != nullptr) {
- stream_exec_->HostMemoryDeallocate(ptr);
- }
- }
+ void Free(void* ptr, size_t num_bytes) override;
private:
- se::StreamExecutor* stream_exec_; // not owned, non-null
-
- TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
+ int numa_node_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_POOL_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
new file mode 100644
index 0000000000..4d83b25ce6
--- /dev/null
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -0,0 +1,129 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/process_state.h"
+
+#include <cstring>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/bfc_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/log_memory.h"
+#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/env_var.h"
+
+namespace tensorflow {
+
+ProcessState* ProcessState::instance_ = nullptr;
+
+/*static*/ ProcessState* ProcessState::singleton() {
+ if (instance_ == nullptr) {
+ instance_ = new ProcessState;
+ }
+
+ return instance_;
+}
+
+ProcessState::ProcessState() : numa_enabled_(false) {
+ CHECK(instance_ == nullptr);
+}
+
+// Normally the ProcessState singleton is never explicitly deleted.
+// This function is defined for debugging problems with the allocators.
+ProcessState::~ProcessState() {
+ CHECK_EQ(this, instance_);
+ instance_ = nullptr;
+ for (Allocator* a : cpu_allocators_) {
+ delete a;
+ }
+}
+
+string ProcessState::MemDesc::DebugString() {
+ return strings::StrCat((loc == CPU ? "CPU " : "GPU "), dev_index,
+ ", dma: ", gpu_registered, ", nic: ", nic_registered);
+}
+
+ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
+ if (FLAGS_brain_gpu_record_mem_types) {
+ auto iter = mem_desc_map_.find(ptr);
+ if (iter != mem_desc_map_.end()) {
+ return iter->second;
+ }
+ }
+ return MemDesc();
+}
+
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
+ CHECK_GE(numa_node, 0);
+ if (!numa_enabled_) numa_node = 0;
+ mutex_lock lock(mu_);
+ while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+ bool use_bfc_allocator = false;
+ // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
+ // efficient.
+ Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
+ &use_bfc_allocator);
+ if (!status.ok()) {
+ LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
+ }
+ VisitableAllocator* allocator;
+ if (use_bfc_allocator) {
+ // TODO(reedwm): evaluate whether 64GB by default is the best choice.
+ int64 cpu_mem_limit_in_mb = -1;
+ Status status = ReadInt64FromEnvVar("TF_CPU_BFC_MEM_LIMIT_IN_MB",
+ 1LL << 16 /*64GB max by default*/,
+ &cpu_mem_limit_in_mb);
+ if (!status.ok()) {
+ LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
+ }
+ int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
+ allocator = new BFCAllocator(
+ new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
+ true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+ VLOG(2) << "Using BFCAllocator with memory limit of "
+ << cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
+ } else {
+ allocator = new PoolAllocator(
+ 100 /*pool_size_limit*/, true /*auto_resize*/,
+ new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
+ new NoopRounder, "cpu_pool");
+ VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
+ << "numa_enabled_=" << numa_enabled_
+ << " numa_node=" << numa_node;
+ }
+ if (LogMemory::IsEnabled()) {
+ // Wrap the allocator to track allocation ids for better logging
+ // at the cost of performance.
+ allocator = new TrackingVisitableAllocator(allocator, true);
+ }
+ cpu_allocators_.push_back(allocator);
+ }
+ return cpu_allocators_[numa_node];
+}
+
+void ProcessState::TestOnlyReset() {
+ mutex_lock lock(mu_);
+ mem_desc_map_.clear();
+ gtl::STLDeleteElements(&cpu_allocators_);
+ gtl::STLDeleteElements(&cpu_al_);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
new file mode 100644
index 0000000000..0f4ae230bb
--- /dev/null
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -0,0 +1,132 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_
+
+#include <functional>
+#include <map>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class Allocator;
+class VisitableAllocator;
+class PoolAllocator;
+
+// Singleton that manages per-process state, e.g. allocation of
+// shared resources.
+class ProcessState {
+ public:
+ static ProcessState* singleton();
+
+ // Descriptor for memory allocation attributes, used by optional
+ // runtime correctness analysis logic.
+ struct MemDesc {
+ enum MemLoc { CPU, GPU };
+ MemLoc loc;
+ int dev_index;
+ bool gpu_registered;
+ bool nic_registered;
+ MemDesc()
+ : loc(CPU),
+ dev_index(0),
+ gpu_registered(false),
+ nic_registered(false) {}
+ string DebugString();
+ };
+
+ // If NUMA Allocators are desired, call this before calling any
+ // Allocator accessor.
+ void EnableNUMA() { numa_enabled_ = true; }
+
+ // Returns what we know about the memory at ptr.
+ // If we know nothing, it's called CPU 0 with no other attributes.
+ MemDesc PtrType(const void* ptr);
+
+ // Returns the one CPUAllocator used for the given numa_node.
+ // TEMPORARY: ignores numa_node.
+ Allocator* GetCPUAllocator(int numa_node);
+
+ typedef std::unordered_map<const void*, MemDesc> MDMap;
+
+ protected:
+ ProcessState();
+ friend class GPUProcessState;
+
+ // If these flags need to be runtime configurable consider adding
+ // them to ConfigProto.
+ static const bool FLAGS_brain_mem_reg_cuda_dma = true;
+ static const bool FLAGS_brain_gpu_record_mem_types = false;
+
+ // Helper method for unit tests to reset the ProcessState singleton by
+ // cleaning up everything. Never use in production.
+ virtual void TestOnlyReset();
+
+ static ProcessState* instance_;
+ bool numa_enabled_;
+
+ mutex mu_;
+
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+
+ virtual ~ProcessState();
+
+ // Optional RecordingAllocators that wrap the corresponding
+ // Allocators for runtime attribute use analysis.
+ MDMap mem_desc_map_;
+ std::vector<Allocator*> cpu_al_ GUARDED_BY(mu_);
+};
+
+namespace internal {
+class RecordingAllocator : public Allocator {
+ public:
+ RecordingAllocator(ProcessState::MDMap* mm, Allocator* a,
+ ProcessState::MemDesc md, mutex* mu)
+ : mm_(mm), a_(a), md_(md), mu_(mu) {}
+
+ string Name() override { return a_->Name(); }
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* p = a_->AllocateRaw(alignment, num_bytes);
+ mutex_lock l(*mu_);
+ (*mm_)[p] = md_;
+ return p;
+ }
+ void DeallocateRaw(void* p) override {
+ mutex_lock l(*mu_);
+ auto iter = mm_->find(p);
+ mm_->erase(iter);
+ a_->DeallocateRaw(p);
+ }
+ bool TracksAllocationSizes() override { return a_->TracksAllocationSizes(); }
+ size_t RequestedSize(const void* p) override { return a_->RequestedSize(p); }
+ size_t AllocatedSize(const void* p) override { return a_->AllocatedSize(p); }
+ void GetStats(AllocatorStats* stats) override { a_->GetStats(stats); }
+ void ClearStats() override { a_->ClearStats(); }
+ ProcessState::MDMap* mm_; // not owned
+ Allocator* a_; // not owned
+ ProcessState::MemDesc md_;
+ mutex* mu_;
+};
+} // namespace internal
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROCESS_STATE_H_
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 93f24a3217..6d247975ed 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -110,7 +110,7 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
- std::move(done));
+ 0 /*dev_to_dev_stream_index*/, std::move(done));
}
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index f8428f2fde..c1e514d5ad 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -163,7 +163,8 @@ void RingReducer::Run(StatusCallback done) {
CollectiveRemoteAccessLocal::MemCpyAsync(
ctx_->input_device_context(0), ctx_->op_device_context(), device_,
device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
- output_, [this, &note, &status](const Status& s) {
+ output_, 0 /*dev_to_dev_stream_index*/,
+ [this, &note, &status](const Status& s) {
status.Update(s);
note.Notify();
});
@@ -387,7 +388,7 @@ void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) {
col_params_.task.is_local[rf->recv_dev_idx],
recv_buf_key, device_, ctx_->op_device_context(),
ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, done);
+ device_locality_, rf->subdiv_idx, done);
}
string RingReducer::FieldState() {
@@ -446,10 +447,11 @@ bool RingReducer::RunAsyncParts() {
if (rf->do_recv) {
rf->action = RF_RECV;
auto requeue = [this, rf, &ready_queue, &aborted](Status s) {
- const bool bad_status = !s.ok();
- if (bad_status) aborted = true;
+ if (!s.ok()) {
+ aborted = true;
+ StartAbort(s);
+ }
ready_queue.Enqueue(rf);
- if (bad_status) StartAbort(s);
};
DispatchRecv(rf, requeue);
dispatched = true;
@@ -494,10 +496,11 @@ bool RingReducer::RunAsyncParts() {
if (rf->do_send) {
rf->action = RF_SEND;
auto send_complete = [this, rf, &ready_queue, &aborted](Status s) {
- const bool bad_status = !s.ok();
- if (bad_status) aborted = true;
+ if (!s.ok()) {
+ aborted = true;
+ StartAbort(s);
+ }
ready_queue.Enqueue(rf);
- if (bad_status) StartAbort(s);
};
DispatchSend(rf, send_complete);
dispatched = true;
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index e4387a074a..fcdf9deff8 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -68,11 +68,13 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) override {
if (MaybeFail(done)) return;
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
- to_alloc_attr, to_tensor, client_locality, done);
+ to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
+ done);
}
void PostToPeer(const string& peer_device, const string& peer_task,
diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
index d0d4f24b11..80205830a2 100644
--- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h
@@ -32,7 +32,8 @@ class TestCollectiveExecutor : public CollectiveExecutor {
bool peer_is_local, const string& key, Device* to_device,
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality, //???
+ const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) override {
done(errors::Internal("Unimplemented"));
}
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 1528c7f130..36e9b3455a 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -42,7 +42,7 @@ load(
# Check that tensorflow/core:tensorflow does not depend on grpc.
check_deps(
name = "core_tensorflow_check_deps",
- disallowed_deps = ["@grpc//:grpc++_unsecure"],
+ disallowed_deps = ["@grpc//:grpc++"],
deps = ["//tensorflow/core:tensorflow"],
)
@@ -143,6 +143,7 @@ tf_cuda_library(
":debug_node_key",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
@@ -150,7 +151,6 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
- "@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
@@ -166,11 +166,11 @@ tf_cuda_library(
":debug_io_utils",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
+ "//tensorflow:grpc++",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 58361bf78f..8d3c9ff575 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <atomic>
#include <unordered_set>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 03a011f79e..9e8002d490 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#ifndef PLATFORM_WINDOWS
-#include "grpc++/create_channel.h"
+#include "grpcpp/create_channel.h"
#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib, "Ws2_32.lib")
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index ead698d787..2059b1ce0d 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -145,9 +145,11 @@ tf_cc_test(
deps = [
":session_mgr",
":worker_env",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@@ -226,6 +228,17 @@ tf_cc_test(
],
)
+cc_library(
+ name = "cancellable_call",
+ hdrs = ["cancellable_call.h"],
+ deps = [
+ ":call_options",
+ ":worker_cache",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
tf_cc_test(
name = "tensor_coding_test",
size = "small",
@@ -392,6 +405,7 @@ cc_library(
hdrs = ["master_env.h"],
deps = [
":worker_cache",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
],
@@ -453,10 +467,47 @@ cc_library(
)
cc_library(
+ name = "rpc_collective_executor_mgr",
+ srcs = ["rpc_collective_executor_mgr.cc"],
+ hdrs = ["rpc_collective_executor_mgr.h"],
+ deps = [
+ ":base_rendezvous_mgr",
+ ":collective_param_resolver_distributed",
+ ":collective_rma_distributed",
+ ":device_resolver_distributed",
+ ":worker_cache",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_collective_executor_mgr_test",
+ srcs = ["rpc_collective_executor_mgr_test.cc"],
+ deps = [
+ ":collective_param_resolver_distributed",
+ ":device_resolver_distributed",
+ ":rpc_collective_executor_mgr",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:session_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
name = "collective_rma_distributed",
srcs = ["collective_rma_distributed.cc"],
hdrs = ["collective_rma_distributed.h"],
deps = [
+ ":cancellable_call",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -492,6 +543,7 @@ cc_library(
hdrs = ["collective_param_resolver_distributed.h"],
deps = [
":call_options",
+ ":cancellable_call",
":device_resolver_distributed",
":worker_cache",
"//tensorflow/core:core_cpu_internal",
@@ -578,6 +630,7 @@ tf_cuda_cc_test(
":master",
":remote_device",
":worker_interface",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -599,7 +652,6 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:dense_update_ops",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:variable_ops",
- "@grpc//:grpc++_unsecure",
],
)
@@ -617,6 +669,7 @@ tf_cuda_cc_test(
":master",
":remote_device",
":worker_interface",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -632,7 +685,6 @@ tf_cuda_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
- "@grpc//:grpc++_unsecure",
],
)
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
index 5f6931e008..de6e4b4a7c 100644
--- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -281,7 +281,7 @@ void BaseRemoteRendezvous::SameWorkerRecvDone(
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
- std::move(done));
+ 0 /*dev_to_dev_stream_index*/, std::move(done));
}
bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
diff --git a/tensorflow/core/distributed_runtime/cancellable_call.h b/tensorflow/core/distributed_runtime/cancellable_call.h
new file mode 100644
index 0000000000..05089c7d15
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/cancellable_call.h
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
+
+#include <string>
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Supports client side cancellation of WorkerInterface calls via
+// registration with a CancellationManager.
+class CancellableCall {
+ public:
+ CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
+ WorkerCacheInterface* wc)
+ : cancel_mgr_(cancel_mgr),
+ remote_worker_(remote_worker),
+ wc_(wc),
+ wi_(wc_->CreateWorker(remote_worker_)) {}
+
+ virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
+
+ virtual void IssueCall(const StatusCallback& done) = 0;
+
+ void Start(const StatusCallback& done) {
+ CancellationToken token = cancel_mgr_->get_cancellation_token();
+ const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
+ token, [this, token]() { opts_.StartCancel(); });
+ if (not_yet_cancelled) {
+ IssueCall([this, token, done](const Status& s) {
+ cancel_mgr_->DeregisterCallback(token);
+ done(s);
+ });
+ } else {
+ done(errors::Cancelled("RPC Request was cancelled"));
+ }
+ }
+
+ protected:
+ mutable mutex mu_;
+ CancellationManager* const cancel_mgr_; // Not owned
+ const string remote_worker_;
+ WorkerCacheInterface* const wc_; // Not owned
+ WorkerInterface* const wi_; // Owned by wc_, must be released.
+ CallOptions opts_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CANCELLABLE_CALL_H_
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 7a93b54eae..1dd10d309b 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -14,55 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
-#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/protobuf/config.pb.h"
-// TODO(tucker): When we're ready to enable collectives this const will
-// transition to a settable config member.
-static const char FLAGS_collective_group_leader[] =
- "/job:worker/replica:0/task:0";
-
namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager. Note that ParamResolverInterface
-// calls are done on behalf of an Op execution which needs to abort if the
-// step in which it executes is cancelled.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
class CompleteGroupCall : public CancellableCall {
public:
@@ -126,9 +84,9 @@ CollectiveParamResolverDistributed::CollectiveParamResolverDistributed(
const string& task_name)
: CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name),
worker_cache_(worker_cache),
- group_leader_(task_name == FLAGS_collective_group_leader
+ group_leader_(task_name == config.experimental().collective_group_leader()
? ""
- : FLAGS_collective_group_leader) {}
+ : config.experimental().collective_group_leader()) {}
void CollectiveParamResolverDistributed::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
@@ -192,21 +150,23 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
for (int32 offset : request->subdiv_offset()) {
cp->instance.impl_details.subdiv_offsets.push_back(offset);
}
- VLOG(1) << "New cp " << cp << " for device " << request->device() << " : "
+ string* device = new string(request->device());
+ VLOG(1) << "New cp " << cp << " for device " << *device << " : "
<< cp->ToString();
- StatusCallback done_and_cleanup = [this, cp, done](const Status& s) {
+ StatusCallback done_and_cleanup = [this, cp, device, done](const Status& s) {
done(s);
delete cp;
+ delete device;
};
// Start by completing the group.
CompleteGroupDistributed(
- request->device(), cp, cancel_mgr,
- [this, cp, request, response, cancel_mgr, done_and_cleanup](
+ *device, cp, cancel_mgr,
+ [this, cp, device, response, cancel_mgr, done_and_cleanup](
const Status& cg_status, const GroupRec* gr) {
if (cg_status.ok()) {
// Then complete the instance.
CompleteInstanceDistributed(
- request->device(), gr, cp, cancel_mgr,
+ *device, gr, cp, cancel_mgr,
[this, gr, cp, response,
done_and_cleanup](const Status& ci_status) {
if (ci_status.ok()) {
@@ -218,6 +178,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync(
const Status& fi_status, InstanceRec* ir) {
if (fi_status.ok()) {
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
response->set_instance_key(cp->instance.instance_key);
response->set_source_rank(ir->source_rank);
done_and_cleanup(fi_status);
@@ -319,18 +280,21 @@ bool CollectiveParamResolverDistributed::InstanceIsCached(int32 instance_key) {
void CollectiveParamResolverDistributed::UpdateInstanceCache(
const GroupRec* gr, CollectiveParams* cp,
const CompleteInstanceResponse& resp, const StatusCallback& done) {
- Notification note;
- InstanceRec* ir = nullptr;
+ using InstanceRecPointer = InstanceRec*;
+ InstanceRecPointer* irp = new InstanceRecPointer(nullptr);
int32 source_rank = resp.source_rank();
- auto continue_with_ir = [this, cp, &ir, source_rank, done](const Status& s) {
+ auto continue_with_ir = [this, cp, irp, source_rank, done](const Status& s) {
if (!s.ok()) {
done(s);
+ delete irp;
return;
}
Status status;
+ InstanceRec* ir = *irp;
do {
mutex_lock l(ir->out_mu);
+ ir->WaitForOutMu(l);
if (ir->source_rank != source_rank) {
if (ir->source_rank >= 0) {
ir->status = errors::Internal(
@@ -360,11 +324,12 @@ void CollectiveParamResolverDistributed::UpdateInstanceCache(
} while (false);
// Callback outside of lock.
done(status);
+ delete irp;
};
FindInstanceRec(
- gr, cp, [this, &ir, continue_with_ir](const Status s, InstanceRec* irec) {
- ir = irec;
+ gr, cp, [this, irp, continue_with_ir](const Status s, InstanceRec* irec) {
+ *irp = irec;
continue_with_ir(s);
});
}
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 95a010286d..4eed856759 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -147,10 +147,9 @@ class DeviceResDistTest : public ::testing::Test {
ConfigProto config;
for (int w = 0; w < num_workers; ++w) {
string name = strings::StrCat("/job:worker/replica:0/task:", w);
- // TODO(tucker): When config option becomes available, set here.
- // if (w == 0) {
- // config.set_collective_group_leader(name);
- // }
+ if (w == 0) {
+ config.mutable_experimental()->set_collective_group_leader(name);
+ }
DefineWorker(config, name, device_type, num_devices);
}
}
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index c15878bfd3..b9a3502131 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
@@ -28,45 +29,6 @@ namespace tensorflow {
namespace {
-// Supports client side cancellation of WorkerInterface calls via
-// registration with a CancellationManager.
-//
-// TODO(tucker): Maybe unify this with CancellableCall in
-// collective_param_resolver_distributed.cc.
-class CancellableCall {
- public:
- CancellableCall(CancellationManager* cancel_mgr, const string& remote_worker,
- WorkerCacheInterface* wc)
- : cancel_mgr_(cancel_mgr), remote_worker_(remote_worker), wc_(wc) {
- wi_ = wc_->CreateWorker(remote_worker_);
- }
- virtual ~CancellableCall() { wc_->ReleaseWorker(remote_worker_, wi_); }
-
- virtual void IssueCall(const StatusCallback& done) = 0;
-
- void Start(const StatusCallback& done) {
- CancellationToken token = cancel_mgr_->get_cancellation_token();
- const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(
- token, [this, token]() { opts_.StartCancel(); });
- if (not_yet_cancelled) {
- IssueCall([this, token, done](const Status& s) {
- cancel_mgr_->DeregisterCallback(token);
- done(s);
- });
- } else {
- done(errors::Cancelled("RPC Request was cancelled"));
- }
- }
-
- protected:
- mutable mutex mu_;
- CancellationManager* cancel_mgr_; // Not owned
- const string remote_worker_;
- WorkerCacheInterface* wc_; // Not owned
- WorkerInterface* wi_; // Owned by wc_, must be released.
- CallOptions opts_;
-};
-
class RecvBufCall : public CancellableCall {
public:
RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
@@ -103,11 +65,13 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
const string& peer_device, const string& peer_task, bool peer_is_local,
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
- const DeviceLocality& client_locality, const StatusCallback& done) {
+ const DeviceLocality& client_locality, int dev_to_dev_stream_index,
+ const StatusCallback& done) {
if (peer_is_local) {
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
- to_alloc_attr, to_tensor, client_locality, done);
+ to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
+ done);
return;
}
@@ -119,9 +83,10 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
};
State* state = new State;
- // Logic to be executed on the RecvBufferAsync callback.
+ // Logic to be executed on the RecvBufAsync callback.
auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
- to_device_ctx, to_tensor, done](const Status& s) {
+ to_device_ctx, to_tensor, dev_to_dev_stream_index,
+ done](const Status& s) {
if (s.ok()) {
// In this generic implementation the bytes come back in the
// RPC response protobuf rather than via RDMA so we need to copy
@@ -157,7 +122,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer(
CopyTensor::ViaDMA("", // edge name (non-existent)
nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
to_device, cpu_attr, to_alloc_attr, cpu_tensor,
- to_tensor,
+ to_tensor, dev_to_dev_stream_index,
[this, cpu_tensor, done](const Status& s) {
delete cpu_tensor;
// This callback must not block, so execute
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
index cfa9110f47..9434cacbca 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h
@@ -37,6 +37,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal {
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) override;
void StartAbort(const Status& s) override;
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index a552f81f58..bfd312410c 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -280,7 +280,7 @@ TEST_F(CollRMADistTest, ProdFirstOK) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@@ -309,7 +309,7 @@ TEST_F(CollRMADistTest, ConsFirstOK) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
@@ -342,7 +342,7 @@ TEST_F(CollRMADistTest, ConsFirstAbort) {
"/job:worker/replica:0/task:1", // peer_task
false, // peer_is_local
kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
- device_locality_,
+ device_locality_, 0 /*dev_to_dev_stream_index*/,
[this, &consumer_status, &consumer_note](const Status& s) {
consumer_status = s;
consumer_note.Notify();
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index f3922dde74..055e5dfced 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:eager_executor",
+ "//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
@@ -47,6 +48,8 @@ cc_library(
"eager_service_impl.h",
],
deps = [
+ "//tensorflow:grpc",
+ "//tensorflow:grpc++",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:core_cpu_internal",
@@ -60,13 +63,12 @@ cc_library(
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
- "@grpc//:grpc++_unsecure",
- "@grpc//:grpc_unsecure",
],
)
@@ -79,10 +81,12 @@ tf_cc_test(
"//tensorflow/c:c_api_internal",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:tensor_handle",
+ "//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 4bd74b81a7..466e779fab 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
@@ -62,10 +63,10 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
}
*num_retvals += iter->second.i();
} else if (!output_arg.type_list_attr().empty()) {
- auto iter = attrs.find(output_arg.number_attr());
+ auto iter = attrs.find(output_arg.type_list_attr());
if (iter == attrs.end()) {
- return errors::InvalidArgument("Unable to find number_attr ",
- output_arg.number_attr(),
+ return errors::InvalidArgument("Unable to find type_list_attr ",
+ output_arg.type_list_attr(),
" for Op: ", op_name);
}
*num_retvals += iter->second.list().type_size();
@@ -80,8 +81,12 @@ Status GetNumRetvals(tensorflow::EagerContext* context, const string& op_name,
Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
CreateContextResponse* response) {
- tensorflow::RemoteRendezvous* r = env_->rendezvous_mgr->Find(0);
+ //make sure env_ , env_->rendezvous_mgr available
+ if (env_ == nullptr || env_->rendezvous_mgr == nullptr) {
+ return tensorflow::errors::Internal("invalid eager env_ or env_->rendezvous_mgr.");
+ }
std::vector<tensorflow::Device*> devices;
+
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions.
SessionOptions(),
@@ -89,7 +94,6 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
request->server_def().job_name().data(),
request->server_def().task_index()),
&devices));
-
response->mutable_device_attributes()->Reserve(devices.size());
for (auto& d : devices) {
*response->add_device_attributes() = d->attributes();
@@ -97,6 +101,19 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
new tensorflow::DeviceMgr(devices));
+
+ auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
+ auto session_name = strings::StrCat("eager_", request->rendezvous_id());
+ TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
+ session_name, request->server_def(), true));
+
+ std::shared_ptr<WorkerSession> worker_session;
+ TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
+ session_name, &worker_session));
+
+ // Initialize remote tensor communication based on worker session.
+ TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
+
std::unique_ptr<tensorflow::EagerContext> ctx(new tensorflow::EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
@@ -115,8 +132,20 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
return Status::OK();
}
+Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
+ const tensorflow::Tensor* t = nullptr;
+
+ // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
+ TF_RETURN_IF_ERROR(handle->Tensor(&t));
+
+ t->shape().AsProto(proto);
+
+ return Status::OK();
+}
+
Status EagerServiceImpl::ExecuteOp(const Operation& operation,
- ServerContext* server_context) {
+ ServerContext* server_context,
+ QueueResponse* queue_response) {
std::unique_ptr<tensorflow::EagerOperation> op;
const char* name = operation.name().c_str(); // Shorthand
const tensorflow::AttrTypeMap* types;
@@ -159,6 +188,10 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
server_context->AddOperationOutputs(retvals, operation.id());
+ for (auto* handle : retvals) {
+ TF_RETURN_IF_ERROR(TensorHandleShape(handle, queue_response->add_shape()));
+ }
+
return Status::OK();
}
@@ -169,8 +202,9 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
core::ScopedUnref context_unref(context);
for (const auto& item : request->queue()) {
+ auto* queue_response = response->add_queue_response();
if (item.has_operation()) {
- TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context));
+ TF_RETURN_IF_ERROR(ExecuteOp(item.operation(), context, queue_response));
} else {
TF_RETURN_IF_ERROR(context->DeleteTensorHandle(
RemoteTensorHandleInternal(item.handle_to_decref())));
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
index ebd5269a57..b0e4aa84b9 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h
@@ -135,7 +135,8 @@ class EagerServiceImpl {
tensorflow::Status GetServerContext(uint64, ServerContext**);
private:
- Status ExecuteOp(const Operation& operation, ServerContext* server_context);
+ Status ExecuteOp(const Operation& operation, ServerContext* server_context,
+ QueueResponse* queue_response);
const WorkerEnv* const env_; // Not owned.
mutex contexts_mu_;
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index f865ebe1be..b98386ba86 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -20,15 +20,16 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -50,6 +51,39 @@ class TestEagerServiceImpl : public EagerServiceImpl {
}
};
+class EagerServiceImplTest : public ::testing::Test {
+ public:
+ EagerServiceImplTest()
+ : rendezvous_mgr_(&worker_env_),
+ session_mgr_(new SessionMgr(
+ &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
+ std::unique_ptr<WorkerCacheInterface>(),
+ [](const ServerDef& server_def,
+ WorkerCacheInterface** worker_cache) {
+ *worker_cache = nullptr;
+ return Status::OK();
+ })) {
+ worker_env_.env = Env::Default();
+
+ worker_env_.rendezvous_mgr = &rendezvous_mgr_;
+ worker_env_.session_mgr = session_mgr_.get();
+
+ Device* device = DeviceFactory::NewDevice(
+ "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0");
+
+ worker_env_.local_devices = {device};
+
+ device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
+ worker_env_.device_mgr = device_mgr_.get();
+ }
+
+ protected:
+ WorkerEnv worker_env_;
+ tensorflow::RpcRendezvousMgr rendezvous_mgr_;
+ std::unique_ptr<SessionMgr> session_mgr_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
void SetTensorProto(AttrValue* val) {
int64_t dims[] = {2, 2};
float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
@@ -119,17 +153,13 @@ tensorflow::FunctionDef MatMulFunction() {
}
// Test creates a context and attempts to execute some ops.
-TEST(EagerServiceImplTest, BasicTest) {
- WorkerEnv worker_env;
- worker_env.env = Env::Default();
- tensorflow::RpcRendezvousMgr rm(&worker_env);
- worker_env.rendezvous_mgr = &rm;
-
- TestEagerServiceImpl eager_service_impl(&worker_env);
+TEST_F(EagerServiceImplTest, BasicTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
CreateContextRequest request;
request.mutable_server_def()->set_job_name("localhost");
request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
@@ -168,6 +198,11 @@ TEST(EagerServiceImplTest, BasicTest) {
TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
&remote_enqueue_response));
+ auto& matmul_result_shape =
+ remote_enqueue_response.queue_response(1).shape(0);
+ EXPECT_EQ(matmul_result_shape.dim(0).size(), 2);
+ EXPECT_EQ(matmul_result_shape.dim(1).size(), 2);
+
tensorflow::TensorHandle* tensor_handle;
TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
@@ -194,17 +229,13 @@ TEST(EagerServiceImplTest, BasicTest) {
}
// Test creates a context and attempts to execute a function.
-TEST(EagerServiceImplTest, BasicFunctionTest) {
- WorkerEnv worker_env;
- worker_env.env = Env::Default();
- tensorflow::RpcRendezvousMgr rm(&worker_env);
- worker_env.rendezvous_mgr = &rm;
-
- TestEagerServiceImpl eager_service_impl(&worker_env);
+TEST_F(EagerServiceImplTest, BasicFunctionTest) {
+ TestEagerServiceImpl eager_service_impl(&worker_env_);
CreateContextRequest request;
request.mutable_server_def()->set_job_name("localhost");
request.mutable_server_def()->set_task_index(0);
+ request.set_rendezvous_id(random::New64());
CreateContextResponse response;
TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
index c4bd67aaed..28b68c3b88 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
+++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_EXECUTE_NODE_H_
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
@@ -27,6 +28,22 @@ namespace eager {
// via RPC in a remote EagerService.
class RemoteExecuteNode : public tensorflow::EagerNode {
public:
+ RemoteExecuteNode(
+ tensorflow::uint64 id, const tensorflow::eager::EnqueueRequest& request,
+ tensorflow::eager::EagerClient* eager_client,
+ const gtl::InlinedVector<TensorHandle*, 4>& inputs,
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback)
+ : tensorflow::EagerNode(id),
+ request_(std::move(request)),
+ eager_client_(eager_client),
+ inputs_(inputs),
+ done_callback_(std::move(done_callback)) {
+ for (auto* handle : inputs_) {
+ handle->Ref();
+ }
+ }
+
RemoteExecuteNode(tensorflow::uint64 id,
const tensorflow::eager::EnqueueRequest& request,
tensorflow::eager::EagerClient* eager_client)
@@ -34,6 +51,12 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
request_(std::move(request)),
eager_client_(eager_client) {}
+ ~RemoteExecuteNode() {
+ for (auto* handle : inputs_) {
+ handle->Unref();
+ }
+ }
+
tensorflow::Status Run() override {
tensorflow::eager::EnqueueResponse response;
tensorflow::Status status;
@@ -45,6 +68,10 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
});
n.WaitForNotification();
+ if (done_callback_) {
+ done_callback_(status, response);
+ }
+
return status;
}
@@ -52,6 +79,13 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
EnqueueRequest request_;
tensorflow::eager::EagerClient*
eager_client_; // Not owned, and must outlive the RemoteExecuteNode.
+
+ // This is required to ensure that the tensor handles stay alive across the
+ // execution.
+ gtl::InlinedVector<TensorHandle*, 4> inputs_;
+
+ std::function<void(const Status& status, const EnqueueResponse& response)>
+ done_callback_;
};
} // namespace eager
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8447c55bf4..e2f13df19f 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -118,9 +120,11 @@ Status GraphMgr::DecorateAndPublishGraphForDebug(
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
Item* item) {
item->session = session;
+ item->collective_graph_key = collective_graph_key;
item->lib_def.reset(
new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
@@ -280,11 +284,12 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options,
+ int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle) {
Item* item = new Item;
- Status s =
- InitItem(session, gdef, graph_options, debug_options, cluster_flr, item);
+ Status s = InitItem(session, gdef, graph_options, debug_options,
+ collective_graph_key, cluster_flr, item);
if (!s.ok()) {
item->Unref();
return s;
@@ -415,7 +420,12 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
-
+ CollectiveExecutor::Handle* ce_handle =
+ item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
+ ? new CollectiveExecutor::Handle(
+ worker_env_->collective_executor_mgr->FindOrCreate(step_id),
+ true)
+ : nullptr;
// Sends values specified by the caller.
if (s.ok()) {
std::vector<string> keys;
@@ -431,22 +441,25 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
if (!s.ok()) {
done(s);
+ delete ce_handle;
item->Unref();
rendezvous->Unref();
return;
}
- StartParallelExecutors(handle, step_id, item, rendezvous, collector,
- cost_graph, cancellation_manager,
- [item, rendezvous, done](const Status& s) {
+ StartParallelExecutors(handle, step_id, item, rendezvous, ce_handle,
+ collector, cost_graph, cancellation_manager,
+ [item, rendezvous, ce_handle, done](const Status& s) {
done(s);
rendezvous->Unref();
item->Unref();
+ delete ce_handle;
});
}
void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
Item* item, Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -471,6 +484,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
args.step_id = ++next_id_;
}
args.rendezvous = rendezvous;
+ args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
args.stats_collector = collector;
args.step_container = step_container;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index cc35264b8f..5196046c19 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -75,7 +76,7 @@ class GraphMgr {
// reference to cluster_flr to do cross process function calls.
Status Register(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr,
string* handle);
@@ -138,6 +139,8 @@ class GraphMgr {
// Used to deregister a cost model when cost model is required in graph
// manager.
GraphMgr* graph_mgr;
+
+ int64 collective_graph_key;
};
const WorkerEnv* worker_env_; // Not owned.
@@ -161,6 +164,7 @@ class GraphMgr {
void StartParallelExecutors(const string& handle, int64 step_id, Item* item,
Rendezvous* rendezvous,
+ CollectiveExecutor::Handle* ce_handle,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@@ -175,7 +179,7 @@ class GraphMgr {
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
- const DebugOptions& debug_options,
+ const DebugOptions& debug_options, int64 collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, Item* item);
Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h
index cad6babad8..b9c76d0f1d 100644
--- a/tensorflow/core/distributed_runtime/local_master.h
+++ b/tensorflow/core/distributed_runtime/local_master.h
@@ -79,7 +79,7 @@ class LocalMaster : public MasterInterface {
RunCallableResponse* response) override;
Status ReleaseCallable(CallOptions* call_options,
const ReleaseCallableRequest* request,
- ReleaseCallableResponse* response);
+ ReleaseCallableResponse* response) override;
// Registers the mapping from the given `target` to the given `master`.
//
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index 4f9d84d158..a48f734d3e 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -473,7 +473,7 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,
return;
}
- SchedClosure([this, session, req, resp, done]() {
+ SchedClosure([session, req, resp, done]() {
Status s = session->PartialRunSetup(req, resp);
session->Unref();
done(s);
@@ -628,7 +628,7 @@ void Master::MakeCallable(const MakeCallableRequest* req,
}
SchedClosure(std::bind(
- [this, session, req, resp](MyClosure done) {
+ [session, req, resp](MyClosure done) {
Status s = session->MakeCallable(*req, resp);
session->Unref();
done(s);
@@ -645,7 +645,7 @@ void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
}
SchedClosure(std::bind(
- [this, session, opts, req, resp](MyClosure done) {
+ [session, opts, req, resp](MyClosure done) {
Status s = session->RunCallable(opts, *req, resp);
session->Unref();
done(s);
@@ -662,7 +662,7 @@ void Master::ReleaseCallable(const ReleaseCallableRequest* req,
}
SchedClosure(std::bind(
- [this, session, req, resp](MyClosure done) {
+ [session, req, resp](MyClosure done) {
Status s = session->ReleaseCallable(*req, resp);
session->Unref();
done(s);
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index 16f4d93c8b..da26c42aca 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -26,6 +26,7 @@ limitations under the License.
namespace tensorflow {
+class CollectiveExecutorMgrInterface;
class Device;
class DeviceSet;
class Env;
@@ -90,6 +91,10 @@ struct MasterEnv {
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
+
+ // Generates per-step CollectiveExecutors and has access to utilities
+ // supporting collective operations.
+ CollectiveExecutorMgrInterface* collective_executor_mgr = nullptr;
};
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index bd70eca3f6..d34ca53f73 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -69,6 +70,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
bool is_partial, WorkerCacheInterface* worker_cache,
bool should_deregister)
: session_handle_(handle),
+ bg_opts_(bopts),
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
@@ -100,6 +102,8 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const CallableOptions& callable_options() { return callable_opts_; }
+ const BuildGraphOptions& build_graph_options() { return bg_opts_; }
+
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
@@ -156,8 +160,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
LoggingResponse* resp = new LoggingResponse;
p.worker->LoggingAsync(
&req, resp,
- [step_id, ss, resp, &scoped_mu, &waiting_for,
- &all_done](const Status& s) {
+ [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) {
{
mutex_lock l(scoped_mu);
if (s.ok()) {
@@ -226,6 +229,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
private:
const string session_handle_;
+ const BuildGraphOptions bg_opts_;
const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
@@ -445,6 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
+ c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -1066,6 +1071,9 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
*callable_opts->mutable_run_options()->mutable_debug_options() =
req.options().debug_options();
}
+
+ opts->collective_graph_key =
+ req.options().experimental().collective_graph_key();
}
void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
@@ -1103,6 +1111,10 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ h = Hash64Combine(opts.collective_graph_key, h);
+ }
+
return h;
}
@@ -1119,6 +1131,9 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
for (const string& name : opts.callable_options.fetch()) {
strings::StrAppend(&buf, " FeE: ", name);
}
+ if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
+ strings::StrAppend(&buf, "\nGK: ", opts.collective_graph_key);
+ }
strings::StrAppend(&buf, "\n");
return buf;
}
@@ -1207,7 +1222,7 @@ Status MasterSession::CreateWorkerSessions(
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
- auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
+ auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
@@ -1289,7 +1304,7 @@ Status MasterSession::DeleteWorkerSessions() {
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
- auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
+ auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
@@ -1431,11 +1446,35 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
-namespace {
-uint64 MakeStepId() {
- return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+uint64 MasterSession::NewStepId(int64 graph_key) {
+ if (graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // StepId must leave the most-significant 7 bits empty for future use.
+ return random::New64() & (((1uLL << 56) - 1) | (1uLL << 56));
+ } else {
+ uint64 step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ int32 retry_count = 0;
+ while (step_id == CollectiveExecutor::kInvalidId) {
+ Notification note;
+ Status status;
+ env_->collective_executor_mgr->RefreshStepIdSequenceAsync(
+ graph_key, [&status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ if (!status.ok()) {
+ LOG(ERROR) << "Bad status from "
+ "collective_executor_mgr->RefreshStepIdSequence: "
+ << status << ". Retrying.";
+ int64 delay_micros = std::min(60000000LL, 1000000LL * ++retry_count);
+ Env::Default()->SleepForMicroseconds(delay_micros);
+ } else {
+ step_id = env_->collective_executor_mgr->NextStepId(graph_key);
+ }
+ }
+ return step_id;
+ }
}
-} // namespace
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp) {
@@ -1457,15 +1496,13 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
// Prepare.
BuildGraphOptions opts;
BuildBuildGraphOptions(*req, &opts);
- int64 count;
+ int64 count = 0;
TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
- TRACEPRINTF("stepid %llu", step_id);
rcg->Ref();
- RunState* run_state = new RunState(inputs, outputs, rcg, step_id, count);
+ RunState* run_state =
+ new RunState(inputs, outputs, rcg,
+ NewStepId(BuildGraphOptions::kNoCollectiveGraphKey), count);
{
mutex_lock l(mu_);
partial_runs_.emplace(
@@ -1567,6 +1604,13 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
}
run_state = it->second.get();
}
+ // CollectiveOps are not supported in partial runs.
+ if (req.options().experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ return errors::InvalidArgument(
+ "PartialRun does not support Collective ops. collective_graph_key "
+ "must be kNoCollectiveGraphKey.");
+ }
// If this is the first partial run, initialize the PerStepState.
if (!run_state->step_started) {
@@ -1744,7 +1788,11 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
-
+ if (rcg->build_graph_options().collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ env_->collective_executor_mgr->RetireStepId(
+ rcg->build_graph_options().collective_graph_key, step_id);
+ }
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
} else if (errors::IsCancelled(s)) {
@@ -1802,7 +1850,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- const uint64 step_id = MakeStepId();
+ uint64 step_id = NewStepId(bgopts.collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1866,9 +1914,8 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- // Keeps the highest 8 bits 0x01: we reserve some bits of the
- // step_id for future use.
- const uint64 step_id = MakeStepId();
+ const uint64 step_id =
+ NewStepId(rcg->build_graph_options().collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index ec34e20b79..449a6d3e3c 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -141,6 +141,8 @@ class MasterSession : public core::RefCounted {
std::atomic<int64> partial_run_handle_counter_ = {0};
+ uint64 NewStepId(int64 graph_key);
+
mutex mu_;
std::unique_ptr<GraphExecutionState> execution_state_ GUARDED_BY(mu_);
int64 graph_version_;
@@ -175,6 +177,7 @@ class MasterSession : public core::RefCounted {
std::unordered_map<string, bool> pending_outputs; // true if fetched
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
+ int64 collective_graph_key;
int64 count = 0;
PerStepState pss;
std::unique_ptr<ProfileHandler> ph;
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
index 0826a90860..62b18a45b1 100644
--- a/tensorflow/core/distributed_runtime/master_test.cc
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc
index 778060daaf..a04e79328b 100644
--- a/tensorflow/core/distributed_runtime/remote_device_test.cc
+++ b/tensorflow/core/distributed_runtime/remote_device_test.cc
@@ -49,8 +49,9 @@ class RemoteDeviceTest : public ::testing::Test {
TF_CHECK_OK(spec.AddHostPortsJob("localhost", {hostport}));
ChannelCreationFunction channel_func =
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
- worker_cache_.reset(
- NewGrpcWorkerCache(NewGrpcChannelCache(spec, channel_func)));
+ std::shared_ptr<GrpcChannelCache> channel_cache(
+ NewGrpcChannelCache(spec, channel_func));
+ worker_cache_.reset(NewGrpcWorkerCache(channel_cache));
remote_name_ = "/job:localhost/replica:0/task:0";
wi_ = worker_cache_->CreateWorker(remote_name_);
}
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 4b2747f26d..4a10d99a60 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -41,8 +41,8 @@ cc_library(
srcs = ["grpc_util.cc"],
hdrs = ["grpc_util.h"],
deps = [
- "@grpc//:grpc_unsecure",
- "@grpc//:grpc++_unsecure",
+ "//tensorflow:grpc",
+ "//tensorflow:grpc++",
"//tensorflow/core:lib",
# Required to be able to overload TensorResponse parsing.
"//tensorflow/core/distributed_runtime:tensor_coding",
@@ -55,8 +55,8 @@ cc_library(
hdrs = ["grpc_client_cq_tag.h"],
deps = [
":grpc_util",
+ "//tensorflow:grpc++",
"//tensorflow/core:lib",
- "@grpc//:grpc++_unsecure",
],
)
@@ -67,10 +67,10 @@ cc_library(
deps = [
":grpc_client_cq_tag",
":grpc_util",
+ "//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:tensor_coding",
- "@grpc//:grpc++_unsecure",
],
)
@@ -83,6 +83,7 @@ cc_library(
":grpc_state",
":grpc_util",
":grpc_worker_service_impl",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -90,7 +91,6 @@ cc_library(
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_interface",
- "@grpc//:grpc++_unsecure",
],
)
@@ -100,10 +100,10 @@ cc_library(
hdrs = ["grpc_channel.h"],
deps = [
":grpc_util",
+ "//tensorflow:grpc++",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@grpc//:grpc++_unsecure",
],
)
@@ -112,13 +112,13 @@ cc_library(
srcs = ["grpc_tensor_coding.cc"],
hdrs = ["grpc_tensor_coding.h"],
deps = [
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
- "@grpc//:grpc++_unsecure",
],
)
@@ -127,9 +127,9 @@ cc_library(
srcs = [],
hdrs = ["grpc_call.h"],
deps = [
+ "//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
- "@grpc//:grpc++_unsecure",
],
)
@@ -167,6 +167,7 @@ tf_cuda_library(
":grpc_tensor_coding",
":grpc_util",
":grpc_worker_service_impl",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -180,7 +181,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
- "@grpc//:grpc++_unsecure",
],
)
@@ -190,9 +190,9 @@ cc_library(
hdrs = ["grpc_worker_service_impl.h"],
deps = [
":grpc_util",
+ "//tensorflow:grpc++",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:tensor_coding",
- "@grpc//:grpc++_unsecure",
],
)
@@ -221,11 +221,11 @@ cc_library(
":grpc_call",
":grpc_master_service_impl",
":grpc_util",
+ "//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core/distributed_runtime:master",
- "@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
@@ -235,8 +235,8 @@ cc_library(
srcs = ["grpc_master_service_impl.cc"],
hdrs = ["grpc_master_service_impl.h"],
deps = [
+ "//tensorflow:grpc++",
"//tensorflow/core:master_proto_cc",
- "@grpc//:grpc++_unsecure",
],
)
@@ -269,21 +269,26 @@ cc_library(
":grpc_worker_cache",
":grpc_worker_service",
":rpc_rendezvous_mgr",
+ "//tensorflow:grpc",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
+ "//tensorflow/core/distributed_runtime:device_resolver_distributed",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:local_master",
"//tensorflow/core/distributed_runtime:master",
"//tensorflow/core/distributed_runtime:master_env",
"//tensorflow/core/distributed_runtime:master_session",
+ "//tensorflow/core/distributed_runtime:rpc_collective_executor_mgr",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
+ "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
"//tensorflow/core/distributed_runtime:worker_env",
- "@grpc//:grpc++_unsecure",
- "@grpc//:grpc_unsecure",
+ "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
],
alwayslink = 1,
)
@@ -304,13 +309,13 @@ tf_cc_binary(
],
deps = [
":grpc_server_lib",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:data_flow",
- "@grpc//:grpc++_unsecure",
],
)
@@ -322,6 +327,7 @@ tf_cc_binary(
],
deps = [
":grpc_server_lib",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
@@ -335,7 +341,6 @@ tf_cc_binary(
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:variable_ops",
- "@grpc//:grpc++_unsecure",
],
)
@@ -420,6 +425,7 @@ tf_cc_test(
deps = [
":grpc_tensor_coding",
":grpc_testlib",
+ "//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -429,7 +435,6 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core:worker_proto_cc",
- "@grpc//:grpc++_unsecure",
],
)
@@ -439,11 +444,11 @@ tf_cc_test(
srcs = ["grpc_util_test.cc"],
deps = [
":grpc_util",
+ "//tensorflow:grpc",
+ "//tensorflow:grpc++",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:worker_proto_cc",
- "@grpc//:grpc++_unsecure",
- "@grpc//:grpc_unsecure",
],
)
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
index 1a3bd9d6bf..d09a85c6a5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
@@ -11,8 +11,8 @@ cc_library(
srcs = ["grpc_eager_service.cc"],
hdrs = ["grpc_eager_service.h"],
deps = [
+ "//tensorflow:grpc++",
"//tensorflow/core:eager_service_proto_cc",
- "@grpc//:grpc++_unsecure",
],
)
@@ -21,6 +21,7 @@ cc_library(
srcs = ["grpc_eager_client.cc"],
hdrs = ["grpc_eager_client.h"],
deps = [
+ "//tensorflow:grpc++",
"//tensorflow/core:eager_service_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/eager:eager_client",
@@ -29,7 +30,6 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_state",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
- "@grpc//:grpc++_unsecure",
],
)
@@ -39,29 +39,15 @@ cc_library(
hdrs = ["grpc_eager_service_impl.h"],
deps = [
":grpc_eager_service",
+ "//tensorflow:grpc++",
"//tensorflow/core:framework",
"//tensorflow/core:ptr_util",
"//tensorflow/core/distributed_runtime/eager:eager_service_impl",
+ "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
- "@grpc//:grpc++_unsecure",
- ],
-)
-
-cc_library(
- name = "eager_grpc_server_lib",
- hdrs = ["eager_grpc_server_lib.h"],
- deps = [
- ":grpc_eager_service_impl",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
- "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
- "//tensorflow/core/distributed_runtime/eager:eager_service_impl",
- "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
- "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
],
)
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
deleted file mode 100644
index f5dc4c831d..0000000000
--- a/tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
-
-#include "tensorflow/core/common_runtime/device_factory.h"
-#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
-#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
-
-namespace tensorflow {
-namespace eager {
-
-class EagerGrpcServer : public GrpcServer {
- public:
- static Status Create(const ServerDef& server_def,
- std::unique_ptr<EagerGrpcServer>* server) {
- std::unique_ptr<EagerGrpcServer> ret(new EagerGrpcServer(server_def));
-
- TF_RETURN_IF_ERROR(ret->InitEager());
-
- *server = std::move(ret);
-
- return Status::OK();
- }
-
- Status Start() override {
- TF_RETURN_IF_ERROR(GrpcServer::Start());
-
- eager_service_->Start();
-
- return Status::OK();
- }
-
- Status Stop() override {
- TF_RETURN_IF_ERROR(GrpcServer::Stop());
-
- eager_service_->Stop();
-
- return Status::OK();
- }
-
- using GrpcServer::channel_cache;
- using GrpcServer::master_env;
- using GrpcServer::worker_env;
-
- private:
- EagerGrpcServer(const ServerDef& server_def)
- : GrpcServer(server_def, Env::Default()),
- worker_name_(
- strings::StrCat("/job:", server_def.job_name(),
- "/replica:0/task:", server_def.task_index())) {}
-
- Status InitEager() {
- TF_RETURN_IF_ERROR(this->Init(
- [this](const WorkerEnv* worker_env,
- ::grpc::ServerBuilder* server_builder) {
- this->eager_service_.reset(
- new eager::GrpcEagerServiceImpl(worker_env, server_builder));
- },
- nullptr));
-
- worker_session_ = WorkerSession::CreateWithBorrowedDeviceMgr(
- "", worker_name_,
- std::unique_ptr<WorkerCacheInterface>(
- new WorkerCacheWrapper(master_env()->worker_cache)),
- worker_env()->device_mgr, {});
-
- auto* r = worker_env()->rendezvous_mgr->Find(0);
- return r->Initialize(worker_session_.get());
- }
-
- std::unique_ptr<GrpcEagerServiceImpl> eager_service_;
- std::shared_ptr<WorkerSession> worker_session_;
- const string worker_name_;
-}; // namespace eager
-
-} // namespace eager
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_EAGER_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
index 4786c43ee2..b23466037f 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
-#include "grpc++/generic/generic_stub.h"
+#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
index 3fd7deaa86..39ab6856c5 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
namespace eager {
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
index d7b192ac85..66458186ad 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index b36c6dce86..52e06c263d 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -18,10 +18,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -36,7 +34,7 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl(
cq_ = server_builder->AddCompletionQueue();
}
-void GrpcEagerServiceImpl::DriveCQ() {
+void GrpcEagerServiceImpl::HandleRPCsLoop() {
#define ENQUEUE_REQUEST(method) \
do { \
Call<GrpcEagerServiceImpl, \
@@ -74,12 +72,7 @@ void GrpcEagerServiceImpl::DriveCQ() {
}
}
-void GrpcEagerServiceImpl::Start() {
- // TODO(nareshmodi) separate thread for driving CQ
- request_handler_threadpool_->Schedule([this]() { DriveCQ(); });
-}
-
-void GrpcEagerServiceImpl::Stop() {
+void GrpcEagerServiceImpl::Shutdown() {
// This enqueues a special event (with a null tag)
// that causes the completion queue to be shut down on the
// polling thread.
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index 65550caf64..9a94026342 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -16,20 +16,20 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_IMPL_H_
-#include "grpc++/alarm.h"
-#include "grpc++/completion_queue.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/completion_queue.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
-#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
namespace tensorflow {
namespace eager {
// This class is a wrapper that handles communication for gRPC.
-class GrpcEagerServiceImpl {
+class GrpcEagerServiceImpl : public AsyncServiceInterface {
public:
template <class RequestMessage, class ResponseMessage>
using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
@@ -39,8 +39,8 @@ class GrpcEagerServiceImpl {
::grpc::ServerBuilder* server_builder);
virtual ~GrpcEagerServiceImpl() {}
- void Start();
- void Stop();
+ void HandleRPCsLoop() override;
+ void Shutdown() override;
private:
#define HANDLER(method) \
@@ -66,8 +66,6 @@ class GrpcEagerServiceImpl {
EagerServiceImpl local_impl_;
- void DriveCQ();
-
std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
index ecad1274cc..90666def60 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_call.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
@@ -20,9 +20,9 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
-#include "grpc++/grpc++.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/server_builder.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 613188244f..b7eb3c9015 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <map>
#include <unordered_map>
-#include "grpc++/create_channel.h"
+#include "grpcpp/create_channel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
@@ -42,12 +42,12 @@ string MakeAddress(const string& job, int task) {
return strings::StrCat("/job:", job, "/replica:0/task:", task);
}
+// Allows the host to be a raw IP (either v4 or v6).
Status ValidateHostPortPair(const string& host_port) {
uint32 port;
- std::vector<string> parts = str_util::Split(host_port, ':');
- // Must be host:port, port must be a number, host must not contain a '/'.
- if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
- parts[0].find("/") != string::npos) {
+ auto colon_index = host_port.find_last_of(':');
+ if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
+ host_port.substr(0, colon_index).find("/") != string::npos) {
return errors::InvalidArgument("Could not interpret \"", host_port,
"\" as a host-port pair.");
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 48b9d958aa..4861cdb691 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -22,7 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index a17acc85b3..f07a5a0974 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -150,10 +150,15 @@ TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
EXPECT_TRUE(NewHostPortGrpcChannel("127.0.0.1:2222", &mock_ptr).ok());
EXPECT_TRUE(NewHostPortGrpcChannel("example.com:2222", &mock_ptr).ok());
EXPECT_TRUE(NewHostPortGrpcChannel("fqdn.example.com.:2222", &mock_ptr).ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("[2002:a9c:258e::]:2222", &mock_ptr).ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("[::]:2222", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:2222", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("127.0.0.1:2222/", &mock_ptr).ok());
EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]/:2222", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]:2222/", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("[::]:", &mock_ptr).ok());
}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
index d367b83ee7..6e7f5dbd13 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
-#include "grpc++/grpc++.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index e025e555dd..127dea2882 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -30,8 +30,8 @@ limitations under the License.
// RunGraph on workers.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
-#include "grpc++/alarm.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
index 85adfd2c76..770a0fcf14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index 8f1b589698..751f2633e7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -16,14 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
#include "tensorflow/core/protobuf/master.pb.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 1acf1fb4fc..6008462d04 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <utility>
-#include "grpc++/generic/generic_stub.h"
-#include "grpc++/grpc++.h"
+#include "grpcpp/generic/generic_stub.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index e5ffb4ed2f..db14f6473e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -19,26 +19,31 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
#include "grpc/support/alloc.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -78,6 +83,7 @@ GrpcServer::~GrpcServer() {
delete master_service_;
delete worker_service_;
+ delete eager_service_;
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
@@ -106,6 +112,7 @@ GrpcServer::~GrpcServer() {
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory) {
mutex_lock l(mu_);
@@ -145,16 +152,14 @@ Status GrpcServer::Init(
" was not defined in job \"",
server_def_.job_name(), "\"");
}
- const std::vector<string> hostname_port =
- str_util::Split(iter->second, ':');
- if (hostname_port.size() != 2 ||
- !strings::safe_strto32(hostname_port[1], &requested_port)) {
+ auto colon_index = iter->second.find_last_of(':');
+ if (!strings::safe_strto32(iter->second.substr(colon_index + 1),
+ &requested_port)) {
return errors::InvalidArgument(
"Could not parse port for local server from \"", iter->second,
- "\"");
- } else {
- break;
+ "\".");
}
+ break;
}
}
if (requested_port == -1) {
@@ -188,6 +193,8 @@ Status GrpcServer::Init(
worker_func ? worker_func(&worker_env_) : NewGrpcWorker(&worker_env_);
worker_service_ =
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+ eager_service_ = new eager::GrpcEagerServiceImpl(&worker_env_, &builder);
+
// extra service:
if (service_func != nullptr) {
service_func(&worker_env_, &builder);
@@ -204,6 +211,26 @@ Status GrpcServer::Init(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
+ if (collective_mgr_func) {
+ worker_env_.collective_executor_mgr =
+ collective_mgr_func(config, &worker_env_, worker_cache);
+ if (!worker_env_.collective_executor_mgr) {
+ return errors::Internal(
+ "collective_mgr_func did not return CollectiveExecutorMgr");
+ }
+ } else {
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver(
+ new DeviceResolverDistributed(worker_env_.device_mgr, worker_cache,
+ default_worker_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver(
+ new CollectiveParamResolverDistributed(config, worker_env_.device_mgr,
+ dev_resolver.get(), worker_cache,
+ default_worker_name));
+ worker_env_.collective_executor_mgr = new RpcCollectiveExecutorMgr(
+ config, worker_env_.device_mgr, std::move(dev_resolver),
+ std::move(param_resolver), worker_cache, default_worker_name);
+ }
+
// Set up worker environment.
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
@@ -246,18 +273,27 @@ Status GrpcServer::Init(
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func) {
- return Init(std::move(service_func), rendezvous_mgr_func, worker_func,
- CreateNoOpStatsPublisher);
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ worker_func, CreateNoOpStatsPublisher);
+}
+
+Status GrpcServer::Init(
+ ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func) {
+ return Init(std::move(service_func), rendezvous_mgr_func, collective_mgr_func,
+ nullptr);
}
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
- return Init(service_func, rendezvous_mgr_func, nullptr);
+ return Init(std::move(service_func), rendezvous_mgr_func, nullptr, nullptr);
}
-Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr); }
+Status GrpcServer::Init() { return Init(nullptr, nullptr, nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
@@ -305,11 +341,13 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
const string host_port = channel_cache_->TranslateTask(name_prefix);
int requested_port;
- if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
+ auto colon_index = host_port.find_last_of(':');
+ if (!strings::safe_strto32(host_port.substr(colon_index + 1),
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
- channel_cache_->TranslateTask(name_prefix), "\".");
+ host_port, "\".");
}
+
if (requested_port != bound_port_) {
return errors::InvalidArgument("Requested port ", requested_port,
" differs from expected port ", bound_port_);
@@ -330,6 +368,9 @@ Status GrpcServer::Start() {
worker_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_worker_service",
[this] { worker_service_->HandleRPCsLoop(); }));
+ eager_thread_.reset(
+ env_->StartThread(ThreadOptions(), "TF_eager_service",
+ [this] { eager_service_->HandleRPCsLoop(); }));
state_ = STARTED;
LOG(INFO) << "Started server with target: " << target();
return Status::OK();
@@ -372,6 +413,7 @@ Status GrpcServer::Join() {
case STOPPED:
master_thread_.reset();
worker_thread_.reset();
+ eager_thread_.reset();
return Status::OK();
default:
LOG(FATAL);
@@ -403,7 +445,18 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<GrpcServer> ret(
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
ServiceInitFunction service_func = nullptr;
- TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
+ *out_server = std::move(ret);
+ return Status::OK();
+}
+
+/* static */
+Status GrpcServer::Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server) {
+ std::unique_ptr<GrpcServer> ret(
+ new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
+ ServiceInitFunction service_func = nullptr;
+ TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr, nullptr));
*out_server = std::move(ret);
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 0122df178a..3366246afb 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/platform/env.h"
@@ -41,6 +42,11 @@ class Master;
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
+// function that creates a CollectiveExecutorMgr.
+typedef std::function<CollectiveExecutorMgrInterface*(
+ const ConfigProto&, const WorkerEnv*, WorkerCacheInterface*)>
+ CollectiveMgrCreationFunction;
+
// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
@@ -57,6 +63,8 @@ class GrpcServer : public ServerInterface {
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
+ static Status Create(const ServerDef& server_def, Env* env,
+ std::unique_ptr<GrpcServer>* out_server);
// Destruction is only supported in the factory method. Clean
// shutdown is not currently implemented for this server type.
@@ -68,17 +76,28 @@ class GrpcServer : public ServerInterface {
Status Join() override;
const string target() const override;
+ WorkerEnv* worker_env() { return &worker_env_; }
+ MasterEnv* master_env() { return &master_env_; }
+
+ std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
+
protected:
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func,
const StatsPublisherFactory& stats_factory);
Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func,
const WorkerCreationFunction& worker_func);
Status Init(ServiceInitFunction service_func,
+ const RendezvousMgrCreationFunction& rendezvous_mgr_func,
+ const CollectiveMgrCreationFunction& collective_mgr_func);
+
+ Status Init(ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init();
@@ -103,11 +122,6 @@ class GrpcServer : public ServerInterface {
// This method may only be called after `this->Init()` returns successfully.
int bound_port() const { return bound_port_; }
- WorkerEnv* worker_env() { return &worker_env_; }
- MasterEnv* master_env() { return &master_env_; }
-
- std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
-
const ServerDef& server_def() const { return server_def_; }
private:
@@ -146,6 +160,11 @@ class GrpcServer : public ServerInterface {
AsyncServiceInterface* worker_service_ = nullptr;
std::unique_ptr<Thread> worker_thread_ GUARDED_BY(mu_);
+ // TensorFlow Eager implementation, and RPC polling thread.
+ AsyncServiceInterface* eager_service_ = nullptr;
+ std::unique_ptr<Thread> eager_thread_ GUARDED_BY(mu_);
+ std::shared_ptr<WorkerSession> worker_session_;
+
std::unique_ptr<::grpc::Server> server_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 45b15a54a2..fc601991a2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -163,6 +163,39 @@ TEST(GrpcSessionTest, BasicCallable) {
}
}
+TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) {
+ // Specifying feeds/fetch devices for remote sessions is not yet defined.
+ // Ensure that the error is graceful.
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(graph));
+
+ std::vector<DeviceAttributes> devices;
+ TF_CHECK_OK(session->ListDevices(&devices));
+ ASSERT_GT(devices.size(), 0);
+ const string device_name = devices.back().name();
+
+ CallableOptions opts;
+ const string fetch = node_names[2] + ":0";
+ opts.add_fetch(fetch);
+ opts.mutable_fetch_devices()->insert({fetch, device_name});
+
+ Session::CallableHandle handle;
+ Status status = session->MakeCallable(opts, &handle);
+ EXPECT_EQ(error::UNIMPLEMENTED, status.code());
+ TF_CHECK_OK(session->Close());
+}
+
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index 59dbb7ae04..61c5bc285f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <utility>
-#include "grpc++/generic/generic_stub.h"
-#include "grpc++/grpc++.h"
+#include "grpcpp/generic/generic_stub.h"
+#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
index e51894b4c7..159435fd7d 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
-#include "grpc++/support/byte_buffer.h"
-#include "grpc++/support/slice.h"
+#include "grpcpp/support/byte_buffer.h"
+#include "grpcpp/support/slice.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/worker.pb.h"
+// (Omitted internal-only flag)
+
namespace tensorflow {
namespace grpc {
@@ -168,15 +170,20 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
(header.size() +
VarLengthEncodingSize(RecvTensorResponse::kTensorFieldNumber,
overall_tensor_proto_bytesize));
- // If "tensor_data_is_large == false", we copy the tensor data to the
- // end of the buffer we are preparing that holds the rest of the
+ // If "share_tensor_slice_memory == false", we copy the tensor data to
+ // the end of the buffer we are preparing that holds the rest of the
// RecvTensorResponse protocol buffer.
//
- // If "tensor_data_is_large == true", we arrange to share the backing
- // store of the data by creating a slice that also points to the
+ // If "share_tensor_slice_memory == true", we arrange to share the
+ // backing store of the data by creating a slice that also points to the
// backing store, with appropriate reference counts to keep the
// backing store alive as needed.
- bool tensor_data_is_large = (tdata.size() > kLargeTensorBytes);
+ //
+ // We enable this behavior if the tensor is large.
+ bool share_tensor_slice_memory = (tdata.size() > kLargeTensorBytes);
+
+ // (Omitted internal-only conditional)
+
size_t encoder_size = expected_size - tdata.size();
// Encode all but the actual "tdata", but including the tag and
@@ -201,10 +208,11 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
::grpc::Slice slices[2];
int num_slices = 0;
{
- size_t slice_len = e.size() + (tensor_data_is_large ? 0 : tdata.size());
+ size_t slice_len =
+ e.size() + (share_tensor_slice_memory ? 0 : tdata.size());
slices[0] = ::grpc::Slice(slice_len);
memcpy(const_cast<uint8_t*>(slices[0].begin()), e.data(), e.size());
- if (!tensor_data_is_large) {
+ if (!share_tensor_slice_memory) {
// (E)
memcpy(const_cast<uint8_t*>(slices[0].begin()) + e.size(), tdata.data(),
tdata.size());
@@ -212,7 +220,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
num_slices += 1;
}
- if (tensor_data_is_large) {
+ if (share_tensor_slice_memory) {
// (E) Encode tensor data, but by sharing backing store
const TensorBuffer* buf = DMAHelper::buffer(&val);
buf->Ref();
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
index 71f69e9024..7cace573e8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
-#include "grpc++/support/byte_buffer.h"
-#include "grpc++/support/slice.h"
+#include "grpcpp/support/byte_buffer.h"
+#include "grpcpp/support/slice.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
index f247322bc4..e52b257411 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include <iostream>
#include <vector>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
index 89f83f9f24..a8508d2d4f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -50,9 +51,14 @@ Status TestCluster::MakeTestCluster(const SessionOptions& options, int n,
}
for (int i = 0; i < n; ++i) {
+ string server_file =
+ strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/distributed_runtime/rpc/grpc_testlib_server");
+ if (!options.env->FileExists(server_file).ok()) {
+ return errors::Internal("Could not find grpc_testlib_server");
+ }
const std::vector<string> argv(
- {strings::StrCat(testing::TensorFlowSrcRoot(),
- "/core/distributed_runtime/rpc/grpc_testlib_server"),
+ {server_file,
/* see grpc_testlib_server.cc for flags */
tf_jobs, "--tf_job=localhost", strings::StrCat("--tf_task=", i),
strings::StrCat("--num_cpus=", num_cpus),
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
index e718db251c..33cbadda0a 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <vector>
-#include "grpc++/grpc++.h"
-#include "grpc++/security/credentials.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/security/credentials.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
index 4b58781b54..45259aa2ec 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_util.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
-#include "grpc++/grpc++.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/support/byte_buffer.h"
+#include "grpcpp/grpcpp.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 2e7b111963..61f5369617 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include <deque>
-#include "grpc++/alarm.h"
-#include "grpc++/server_builder.h"
+#include "grpcpp/alarm.h"
+#include "grpcpp/server_builder.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -513,8 +513,8 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
rma->buf_rendezvous()->ConsumeBuf(
request->buf_rendezvous_key(),
- [this, opts, request, response, done](const Status& status,
- BufRendezvous::Hook* hook) {
+ [this, request, response, done](const Status& status,
+ BufRendezvous::Hook* hook) {
Status s = status;
if (s.ok()) {
if (!DMAHelper::CanUseDMA(hook->prod_value)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
index 38cc2b81d3..72b5e77f1c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/channel_interface.h"
-#include "grpc++/impl/codegen/client_unary_call.h"
-#include "grpc++/impl/codegen/method_handler_impl.h"
-#include "grpc++/impl/codegen/rpc_service_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/sync_stream.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/channel_interface.h"
+#include "grpcpp/impl/codegen/client_unary_call.h"
+#include "grpcpp/impl/codegen/method_handler_impl.h"
+#include "grpcpp/impl/codegen/rpc_service_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
namespace tensorflow {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
index da270835bd..7915c3aafd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -16,15 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_IMPL_H_
-#include "grpc++/impl/codegen/async_stream.h"
-#include "grpc++/impl/codegen/async_unary_call.h"
-#include "grpc++/impl/codegen/proto_utils.h"
-#include "grpc++/impl/codegen/rpc_method.h"
-#include "grpc++/impl/codegen/service_type.h"
-#include "grpc++/impl/codegen/status.h"
-#include "grpc++/impl/codegen/stub_options.h"
-#include "grpc++/impl/codegen/sync_stream.h"
-#include "grpc++/support/byte_buffer.h"
+#include "grpcpp/impl/codegen/async_stream.h"
+#include "grpcpp/impl/codegen/async_unary_call.h"
+#include "grpcpp/impl/codegen/proto_utils.h"
+#include "grpcpp/impl/codegen/rpc_method.h"
+#include "grpcpp/impl/codegen/service_type.h"
+#include "grpcpp/impl/codegen/status.h"
+#include "grpcpp/impl/codegen/stub_options.h"
+#include "grpcpp/impl/codegen/sync_stream.h"
+#include "grpcpp/support/byte_buffer.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
new file mode 100644
index 0000000000..45b989f6e2
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -0,0 +1,168 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+
+#include "tensorflow/core/common_runtime/base_collective_executor.h"
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name)
+ : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
+ std::move(param_resolver)),
+ worker_cache_(worker_cache),
+ task_name_(task_name) {
+ group_leader_ = (task_name == config.experimental().collective_group_leader())
+ ? ""
+ : config.experimental().collective_group_leader();
+}
+
+RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
+ for (auto it : sequence_table_) {
+ delete it.second;
+ }
+}
+
+CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
+ CollectiveRemoteAccessDistributed* rma =
+ new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
+ worker_cache_, step_id);
+ return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_);
+}
+
+namespace {
+// StepId must leave the most-significant 7 bits empty for future use.
+static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
+
+int64 NewRandomStepId() {
+ int64 step_id = random::New64();
+ // Leave MS 8 bits clear for future use.
+ step_id &= kStepIdMask;
+ return step_id;
+}
+} // namespace
+
+void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
+ int64 graph_key, const StatusCallback& done) {
+ if (group_leader_.empty()) {
+ mutex_lock l(sequence_mu_);
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(graph_key);
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = NewRandomStepId();
+ done(Status::OK());
+ } else {
+ WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
+ GetStepSequenceRequest* req = new GetStepSequenceRequest;
+ GetStepSequenceResponse* resp = new GetStepSequenceResponse;
+ req->add_graph_key(graph_key);
+ wi->GetStepSequenceAsync(
+ req, resp, [this, req, resp, done](const Status& s) {
+ if (!s.ok()) {
+ LOG(ERROR) << "Bad response [" << s
+ << "] from GetStepSequenceAsync call to "
+ << group_leader_;
+ done(s);
+ } else {
+ done(UpdateStepSequences(*resp));
+ }
+ delete req;
+ delete resp;
+ });
+ }
+}
+
+void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
+ const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+ const StatusCallback& done) {
+ if (!group_leader_.empty()) {
+ LOG(ERROR) << "GetStepSequence called at non-group-leader";
+ done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
+ } else {
+ mutex_lock l(sequence_mu_);
+ for (int64 graph_key : request->graph_key()) {
+ auto it = sequence_table_.find(graph_key);
+ GraphKeySequence* gks = nullptr;
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(graph_key);
+ gks->next_step_id_ = NewRandomStepId();
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;
+ }
+ StepSequence* ss = response->add_step_sequence();
+ ss->set_graph_key(graph_key);
+ ss->set_next_step_id(gks->next_step_id_);
+ }
+ done(Status::OK());
+ }
+}
+
+Status RpcCollectiveExecutorMgr::UpdateStepSequences(
+ const GetStepSequenceResponse& resp) {
+ mutex_lock l(sequence_mu_);
+ for (const StepSequence& ss : resp.step_sequence()) {
+ GraphKeySequence* gks = nullptr;
+ auto it = sequence_table_.find(ss.graph_key());
+ if (it == sequence_table_.end()) {
+ gks = new GraphKeySequence(ss.graph_key());
+ sequence_table_[ss.graph_key()] = gks;
+ } else {
+ gks = it->second;
+ }
+ gks->next_step_id_ = ss.next_step_id();
+ }
+ return Status::OK();
+}
+
+int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ return it->second->next_step_id_;
+ }
+ return CollectiveExecutor::kInvalidId;
+}
+
+void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
+ mutex_lock l(sequence_mu_);
+ auto it = sequence_table_.find(graph_key);
+ if (it != sequence_table_.end()) {
+ if (step_id == it->second->next_step_id_) {
+ it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
+ } else {
+ it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
+ }
+ } else {
+ LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
new file mode 100644
index 0000000000..c9581fa00f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
+
+#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+
+namespace tensorflow {
+class CollectiveParamResolverDistributed;
+class ConfigProto;
+class DeviceMgr;
+class DeviceResolverDistributed;
+class WorkerCacheInterface;
+class StepSequenceRequest;
+class StepSequenceResponse;
+
+// An implementation of CollectiveExecutorMgr for a distributed environment
+// that uses WorkerInterface::RecvBufAsync to route data transfers over RPCs.
+//
+// In some execution environments it may be possible to implement a
+// higher-performance solution and use it in place of this class.
+class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
+ public:
+ RpcCollectiveExecutorMgr(
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ std::unique_ptr<DeviceResolverDistributed> dev_resolver,
+ std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
+ WorkerCacheInterface* worker_cache, const string& task_name);
+
+ virtual ~RpcCollectiveExecutorMgr();
+
+ // This function should only be called at the group_leader, by an RPC.
+ // Other needs for StepIds should be satisfied by NextStepId.
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) override;
+
+ void RefreshStepIdSequenceAsync(int64 graph_key,
+ const StatusCallback& done) override;
+
+ int64 NextStepId(int64 graph_key) override;
+
+ void RetireStepId(int64 graph_key, int64 step_id) override;
+
+ protected:
+ CollectiveExecutor* Create(int64 step_id) override;
+
+ WorkerCacheInterface* const worker_cache_; // Not owned.
+ const string task_name_;
+ string group_leader_;
+ friend class RpcCollectiveExecutorMgrTest;
+
+ private:
+ Status UpdateStepSequences(const GetStepSequenceResponse& resp);
+
+ // This class maintains the step_id sequencing for a single
+ // collective_graph_key.
+ struct GraphKeySequence {
+ explicit GraphKeySequence(int64 k)
+ : graph_key_(k), next_step_id_(CollectiveExecutor::kInvalidId) {}
+
+ const int64 graph_key_;
+ int64 next_step_id_;
+ };
+
+ mutex sequence_mu_;
+ gtl::FlatMap<int64, GraphKeySequence*> sequence_table_
+ GUARDED_BY(sequence_mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
new file mode 100644
index 0000000000..0323300fdd
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -0,0 +1,171 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+#define NUM_DEVS 3
+
+class RpcCollectiveExecutorMgrTest : public ::testing::Test {
+ protected:
+ RpcCollectiveExecutorMgrTest() {
+ string task_name = "/job:localhost/replica:0/task:0";
+ SessionOptions options;
+ options.config.mutable_experimental()->set_collective_group_leader(
+ task_name);
+ WorkerCacheInterface* worker_cache = nullptr;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", NUM_DEVS});
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
+ device_mgr_.reset(new DeviceMgr(devices_));
+ std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
+ device_mgr_.get(), worker_cache, task_name));
+ std::unique_ptr<CollectiveParamResolverDistributed> cpr(
+ new CollectiveParamResolverDistributed(options.config,
+ device_mgr_.get(), dr.get(),
+ worker_cache, task_name));
+ // This CME is the group leader.
+ cme_.reset(new RpcCollectiveExecutorMgr(options.config, device_mgr_.get(),
+ std::move(dr), std::move(cpr),
+ worker_cache, task_name));
+ }
+
+ std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
+ std::vector<Device*> devices_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+};
+
+TEST_F(RpcCollectiveExecutorMgrTest, FindOrCreate) {
+ CollectiveExecutor::Handle* h =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_TRUE(h->get());
+ CollectiveExecutor::Handle* h2 =
+ new CollectiveExecutor::Handle(cme_->FindOrCreate(1), true);
+ EXPECT_EQ(h->get(), h2->get());
+ CollectiveExecutor* ce = h->get();
+ delete h;
+ delete h2;
+ CollectiveExecutor* ce2 = cme_->FindOrCreate(1);
+ EXPECT_EQ(ce, ce2);
+ ce2->Unref();
+ cme_->Cleanup(1);
+}
+
+TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) {
+ int64 x = cme_->NextStepId(7);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);
+ // Calling Refresh should generate a valid id.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ EXPECT_TRUE(status.ok());
+ }
+ x = cme_->NextStepId(7);
+ EXPECT_NE(x, CollectiveExecutor::kInvalidId);
+ // Should keep returning same number.
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on a different graph_key should have no effect.
+ cme_->RetireStepId(6, x);
+ EXPECT_EQ(x, cme_->NextStepId(7));
+ // Retire on same graph_key should advance.
+ cme_->RetireStepId(7, x);
+ int64 y = cme_->NextStepId(7);
+ EXPECT_EQ((x + 1) & (((1uLL << 56) - 1) | (1uLL << 56)), y);
+ // Calling refresh should jump to a different point in the random space.
+ {
+ Notification note;
+ Status status;
+ cme_->RefreshStepIdSequenceAsync(7,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ int64 z = cme_->NextStepId(7);
+ // z should not be equal to or a successor of y.
+ EXPECT_NE(y, z);
+ EXPECT_GT(llabs(y - z), 3);
+}
+
+TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) {
+ int64 x = cme_->NextStepId(3);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);
+ int64 y = cme_->NextStepId(4);
+ EXPECT_EQ(y, CollectiveExecutor::kInvalidId);
+ GetStepSequenceRequest request;
+ GetStepSequenceResponse response;
+ request.add_graph_key(3);
+ request.add_graph_key(4);
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());
+ std::unordered_map<int64, int64> values;
+ for (const auto& ss : response.step_sequence()) {
+ values[ss.graph_key()] = ss.next_step_id();
+ }
+ EXPECT_NE(values[3], CollectiveExecutor::kInvalidId);
+ EXPECT_NE(values[4], CollectiveExecutor::kInvalidId);
+ // Re-get, should be same values.
+ response.Clear();
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [this, &status, &note](const Status& s) {
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());
+ for (const auto& ss : response.step_sequence()) {
+ EXPECT_EQ(values[ss.graph_key()], ss.next_step_id());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 4e6500fbc6..1ea19c48f0 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
@@ -72,7 +73,8 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(),
request->graph_options(), request->debug_options(),
- session->cluster_flr.get(), response->mutable_graph_handle());
+ request->collective_graph_key(), session->cluster_flr.get(),
+ response->mutable_graph_handle());
}
done(s);
}
@@ -315,6 +317,12 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
if (env_->collective_executor_mgr) {
env_->collective_executor_mgr->Cleanup(step_id);
}
+ for (Device* d : env_->local_devices) {
+ ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
+ if (sam) {
+ sam->Cleanup(step_id);
+ }
+ }
done(Status::OK());
}
diff --git a/tensorflow/core/framework/api_def.proto b/tensorflow/core/framework/api_def.proto
index 3f8dd272e7..f8553cf5bb 100644
--- a/tensorflow/core/framework/api_def.proto
+++ b/tensorflow/core/framework/api_def.proto
@@ -30,6 +30,10 @@ import "tensorflow/core/framework/attr_value.proto";
message ApiDef {
// Name of the op (in the OpDef) to specify the API for.
string graph_op_name = 1;
+ // If this op is deprecated, set deprecation message to the message
+ // that should be logged when this op is used.
+ // The message should indicate alternative op to use, if any.
+ string deprecation_message = 12;
enum Visibility {
// Normally this is "VISIBLE" unless you are inheriting a
@@ -56,10 +60,10 @@ message ApiDef {
// use a snake_case convention instead of CamelCase.
string name = 1;
- // If this endpoint is deprecated, set deprecation_message to a
- // message that should be logged when the endpoint is used.
- // The message should indicate alternative endpoint to use, if any.
- string deprecation_message = 2;
+ // Set if this endpoint is deprecated. If set to true, a message suggesting
+ // to use a non-deprecated endpoint instead will be printed. If all
+ // endpoints are deprecated, set deprecation_message in ApiDef instead.
+ bool deprecated = 3;
}
repeated Endpoint endpoint = 3;
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index f8d27d3868..c3e6388e28 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -225,6 +225,7 @@ class PeerAccessInterface {
const AllocatorAttributes& to_alloc_attr,
Tensor* to_tensor,
const DeviceLocality& client_locality,
+ int dev_to_dev_stream_index,
const StatusCallback& done) = 0;
virtual void PostToPeer(const string& peer_device, const string& peer_task,
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 6da0da14f0..21c6940b62 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -721,10 +721,15 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
bool is_training;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
int number_inputs = (is_training) ? 3 : 5;
- string data_format;
- TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
- DimensionHandle channel_dim =
- (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
+ string data_format_str;
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+ TensorFormat data_format;
+ if (!FormatFromString(data_format_str, &data_format)) {
+ return errors::InvalidArgument("Invalid data format string: ",
+ data_format_str);
+ }
+ int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+ DimensionHandle channel_dim = c->Dim(x, channel_dim_index);
// covers scale, offset, and if is_training is false, mean, variance
for (int i = 1; i < number_inputs; ++i) {
@@ -734,11 +739,7 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) {
}
ShapeHandle y;
- if (data_format == "NHWC") {
- TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
- } else {
- TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
- }
+ TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y));
c->set_output(0, y);
ShapeHandle vector_shape = c->Vector(channel_dim);
c->set_output(1, vector_shape);
@@ -755,16 +756,18 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
bool is_training;
- string data_format;
TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
- TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
- DimensionHandle channel_dim =
- (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1);
- if (data_format == "NHWC") {
- TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
- } else {
- TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
+ string data_format_str;
+ TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str));
+ TensorFormat data_format;
+ if (!FormatFromString(data_format_str, &data_format)) {
+ return errors::InvalidArgument("Invalid data format string: ",
+ data_format_str);
}
+ int channel_dim_index = GetTensorFeatureDimIndex(4, data_format);
+ DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index);
+ TF_RETURN_IF_ERROR(
+ c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim));
// covers scale, mean (reserve_space_1), variance (reserve_space_2)
for (int i = 2; i < 5; ++i) {
@@ -774,11 +777,8 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
}
ShapeHandle x_backprop;
- if (data_format == "NHWC") {
- TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
- } else {
- TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
- }
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop));
c->set_output(0, x_backprop);
c->set_output(1, c->Vector(channel_dim));
c->set_output(2, c->Vector(channel_dim));
@@ -1231,11 +1231,13 @@ Status ConcatV2Shape(InferenceContext* c) {
c->num_inputs() - 1 /* dim_index */);
}
-Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) {
- ShapeHandle shape_x = c->input(0);
- ShapeHandle shape_y = c->input(1);
+Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
+ ShapeHandle shape_x,
+ ShapeHandle shape_y,
+ ShapeHandle* out) {
+ CHECK_NOTNULL(out);
if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) {
- c->set_output(0, c->UnknownShape());
+ *out = c->UnknownShape();
return Status::OK();
}
const int32 rank_x = c->Rank(shape_x);
@@ -1293,7 +1295,7 @@ Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index) {
}
}
- c->set_output(output_index, c->MakeShape(dims));
+ *out = c->MakeShape(dims);
return Status::OK();
}
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 87bb133d92..2bedce1d6a 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -267,7 +267,22 @@ Status ConcatV2Shape(shape_inference::InferenceContext* c);
// Shape function for binary operators that broadcast their inputs
// and with output to output_index.
-Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, int output_index);
+// Note: out cannot be NULL.
+Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
+ ShapeHandle shape_x,
+ ShapeHandle shape_y,
+ ShapeHandle* out);
+
+// Shape function for binary operators that broadcast their inputs
+// and with output to output_index.
+inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c,
+ int output_index) {
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(
+ BroadcastBinaryOpOutputShapeFnHelper(c, c->input(0), c->input(1), &out));
+ c->set_output(output_index, out);
+ return Status::OK();
+}
// Shape function for binary operators that broadcast their inputs.
// Tested by ops/math_ops_test.cc.
diff --git a/tensorflow/core/framework/cost_graph.proto b/tensorflow/core/framework/cost_graph.proto
index 19d765cd32..cc6bc84d69 100644
--- a/tensorflow/core/framework/cost_graph.proto
+++ b/tensorflow/core/framework/cost_graph.proto
@@ -69,6 +69,9 @@ message CostGraphDef {
// Ids of the control inputs for this node.
repeated int32 control_input = 8;
+
+ // Are the costs inaccurate?
+ bool inaccurate = 17;
}
repeated Node node = 1;
}
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 23dc903caf..d8618f391e 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -459,6 +459,8 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
+
+ friend class DatasetToGraphOp; // For access to graph related members.
};
// Base-class for datasets that are built by ops.
@@ -584,6 +586,23 @@ class DatasetOpKernel : public OpKernel {
*output = argument_t->scalar<T>()();
return Status::OK();
}
+
+ template <typename T>
+ Status ParseVectorArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name,
+ std::vector<T>* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsVector(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a vector");
+ }
+ int size = argument_t->vec<T>().size();
+ output->reserve(size);
+ for (int i = 0; i < size; ++i) {
+ output->push_back(argument_t->vec<T>()(i));
+ }
+ return Status::OK();
+ }
};
// Encapsulates the work required to plug unary Datasets into the core
diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc
index e30ee84cc3..9108c32942 100644
--- a/tensorflow/core/framework/device_base.cc
+++ b/tensorflow/core/framework/device_base.cc
@@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#define EIGEN_USE_THREADS
+
#include "tensorflow/core/framework/device_base.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/util/work_sharder.h"
+
namespace tensorflow {
-DeviceBase::~DeviceBase() {}
+DeviceBase::~DeviceBase() { gtl::STLDeleteElements(&eigen_cpu_devices_); }
const DeviceAttributes& DeviceBase::attributes() const {
LOG(FATAL) << "Device does not implement attributes()";
@@ -27,4 +33,29 @@ const string& DeviceBase::name() const {
LOG(FATAL) << "Device does not implement name()";
}
+void DeviceBase::set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
+ // Eigen::ThreadPoolDevice is a very cheap struct (one pointer and
+ // an int). Therefore, we can afford a pre-allocated array of
+ // Eigen::ThreadPoolDevice. Here, we ensure that
+ // Eigen::ThreadPoolDevices in eigen_cpu_devices_ has increasingly
+ // larger numThreads.
+ for (int i = 1; i <= d->numThreads(); ++i) {
+ eigen_cpu_devices_.push_back(
+ new Eigen::ThreadPoolDevice(d->getPool(), i /* numThreads() */));
+ }
+}
+
+const Eigen::ThreadPoolDevice* DeviceBase::eigen_cpu_device() {
+ // Based on GetPerThreadMaxParallelism(), we return a different
+ // pre-allocated Eigen::ThreadPoolDevice. All these ThreadPoolDevice
+ // use the same underlying threadpool. But they use different
+ // nominal numThreads() hoping that the user of the returned
+ // Eigen::ThreadPoolDevice may not aggressively occupy all the
+ // threads in the underlying threadpool.
+ const int parallelism = std::max<int>(
+ 1,
+ std::min<int>(GetPerThreadMaxParallelism(), eigen_cpu_devices_.size()));
+ return eigen_cpu_devices_[parallelism - 1];
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index ec26d92a61..922d34fac9 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
-#include <unordered_map>
+#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -154,9 +154,7 @@ class DeviceBase {
}
// Does not take ownership.
- void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d) {
- eigen_cpu_device_ = d;
- }
+ void set_eigen_cpu_device(Eigen::ThreadPoolDevice* d);
#ifdef TENSORFLOW_USE_SYCL
void set_eigen_sycl_device(Eigen::SyclDevice* d) { eigen_sycl_device_ = d; }
@@ -186,11 +184,12 @@ class DeviceBase {
virtual ScopedAllocatorMgr* GetScopedAllocatorMgr() const { return nullptr; }
- virtual const Eigen::ThreadPoolDevice* eigen_cpu_device() {
- CHECK(eigen_cpu_device_ != nullptr);
- return eigen_cpu_device_;
+ const bool has_eigen_cpu_device() const {
+ return !eigen_cpu_devices_.empty();
}
+ virtual const Eigen::ThreadPoolDevice* eigen_cpu_device();
+
#ifdef TENSORFLOW_USE_SYCL
virtual const Eigen::SyclDevice* eigen_sycl_device() const {
CHECK(eigen_sycl_device_ != nullptr);
@@ -242,7 +241,7 @@ class DeviceBase {
// Set by GPUs as well as by TPU devices.
GpuDeviceInfo* gpu_device_info_ = nullptr;
thread::ThreadPool* device_thread_pool_ = nullptr;
- Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr;
+ std::vector<Eigen::ThreadPoolDevice*> eigen_cpu_devices_;
#ifdef TENSORFLOW_USE_SYCL
Eigen::SyclDevice* eigen_sycl_device_ = nullptr;
#endif
diff --git a/tensorflow/core/framework/device_base_test.cc b/tensorflow/core/framework/device_base_test.cc
new file mode 100644
index 0000000000..6909559ea2
--- /dev/null
+++ b/tensorflow/core/framework/device_base_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/device_base.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+TEST(DeviceBaseTest, CpuDevice) {
+ DeviceBase dbase(Env::Default());
+ thread::ThreadPool pool(Env::Default(), "test", 16);
+ EigenThreadPoolWrapper wrapper(&pool);
+ Eigen::ThreadPoolDevice eigen_device(&wrapper, pool.NumThreads());
+ ASSERT_FALSE(dbase.has_eigen_cpu_device());
+ dbase.set_eigen_cpu_device(&eigen_device);
+ ASSERT_TRUE(dbase.has_eigen_cpu_device());
+
+ {
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 16);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(4);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 4);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(1);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 1);
+ }
+
+ {
+ ScopedPerThreadMaxParallelism maxp(1000);
+ auto d = dbase.eigen_cpu_device();
+ EXPECT_EQ(d->numThreads(), 16);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 647c66099c..88d9d65f5a 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -815,6 +815,10 @@ string Canonicalize(const string& funcname, AttrSlice attrs,
entries.push_back(
strings::StrCat("_state_handle", "=", options.state_handle));
}
+ if (!options.executor_type.empty()) {
+ entries.push_back(
+ strings::StrCat("_executor_type", "=", options.executor_type));
+ }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 872906756a..8e607b927c 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -450,6 +450,12 @@ class FunctionLibraryRuntime {
// state (in stateful kernels); and two functions with different
// values for `state_handle` will have independent state.
string state_handle;
+
+ // This interface is EXPERIMENTAL and subject to change.
+ //
+ // Instatiates the function using an executor of the given type. If empty,
+ // the default TensorFlow executor will be used.
+ string executor_type;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc
index 4ffa503379..b2bc414c49 100644
--- a/tensorflow/core/framework/graph_to_functiondef.cc
+++ b/tensorflow/core/framework/graph_to_functiondef.cc
@@ -153,7 +153,7 @@ Status GraphToFunctionDef(const Graph& graph, const string& name,
const string normalized = node_names.Normalize(node->name());
argdef->set_name(normalized);
Edge const* edge;
- TF_CHECK_OK(node->input_edge(0, &edge));
+ TF_RETURN_IF_ERROR(node->input_edge(0, &edge));
return_values[normalized] =
strings::StrCat(edge->src()->name(), ":", edge->src_output());
continue;
diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto
index a17b9c8492..e16c2ae73b 100644
--- a/tensorflow/core/framework/kernel_def.proto
+++ b/tensorflow/core/framework/kernel_def.proto
@@ -34,3 +34,8 @@ message KernelDef {
// value matching this.
string label = 5;
}
+
+// A collection of KernelDefs
+message KernelList {
+ repeated KernelDef kernel = 1;
+};
diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc
new file mode 100644
index 0000000000..bbd3dd3e57
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/kernel_def_util.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/kernel_def.pb_text.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+
+namespace {
+// Helper for KernelAttrsMatch().
+bool InTypeList(DataType dt, const AttrValue& type_list) {
+ for (int in_list : type_list.list().type()) {
+ if (dt == in_list) return true;
+ }
+ return false;
+}
+} // namespace
+
+Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
+ bool* match) {
+ *match = false;
+ for (const auto& constraint : kernel_def.constraint()) {
+ if (constraint.allowed_values().list().type_size() == 0) {
+ return errors::Unimplemented(
+ "KernelDef '", ProtoShortDebugString(kernel_def),
+ " has constraint on attr '", constraint.name(),
+ "' with unsupported type: ",
+ SummarizeAttrValue(constraint.allowed_values()));
+ }
+
+ const AttrValue* found = attrs.Find(constraint.name());
+ if (found) {
+ if (found->type() != DT_INVALID) {
+ if (!InTypeList(found->type(), constraint.allowed_values())) {
+ return Status::OK();
+ }
+ } else {
+ if (!AttrValueHasType(*found, "list(type)").ok()) {
+ return errors::InvalidArgument(
+ "KernelDef '", ProtoShortDebugString(kernel_def),
+ "' has constraint on attr '", constraint.name(),
+ "' that has value '", SummarizeAttrValue(*found),
+ "' that does not have type 'type' or 'list(type)' in NodeDef "
+ "'",
+ attrs.SummarizeNode(), "'");
+ }
+
+ for (int t : found->list().type()) {
+ if (!InTypeList(static_cast<DataType>(t),
+ constraint.allowed_values())) {
+ return Status::OK();
+ }
+ }
+ }
+ } else {
+ return errors::InvalidArgument(
+ "OpKernel '", kernel_def.op(), "' has constraint on attr '",
+ constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
+ "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
+ }
+ }
+ *match = true;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/kernel_def_util.h b/tensorflow/core/framework/kernel_def_util.h
new file mode 100644
index 0000000000..b973cefc4f
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+
+namespace tensorflow {
+
+// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
+// an error if attrs in kernel_def are not found, or have a mismatching type.
+Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
+ bool* match);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def_util_test.cc b/tensorflow/core/framework/kernel_def_util_test.cc
new file mode 100644
index 0000000000..a2e4aa82fa
--- /dev/null
+++ b/tensorflow/core/framework/kernel_def_util_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/kernel_def_util.h"
+
+#include "tensorflow/core/framework/kernel_def.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+NodeDef NodeDefFromText(const string& text) {
+ NodeDef node_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def));
+ return node_def;
+}
+
+KernelDef KernelDefFromText(const string& text) {
+ KernelDef kernel_def;
+ EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &kernel_def));
+ return kernel_def;
+}
+
+class AttrsMatchTest : public ::testing::Test {
+ protected:
+ void ExpectStatus(const string& node_def_str, const string& kernel_def_str,
+ error::Code code) {
+ bool match;
+ auto status = KernelAttrsMatch(KernelDefFromText(kernel_def_str),
+ NodeDefFromText(node_def_str), &match);
+ LOG(INFO) << "status: " << status;
+ EXPECT_EQ(code, status.code());
+ if (!status.ok()) {
+ EXPECT_FALSE(match)
+ << "Expect no match between the given NodeDef and KernelDef";
+ }
+ }
+};
+
+TEST_F(AttrsMatchTest, ValidConstraint) {
+ string node_def_str = R"(
+ name: "ValidConstraint-op"
+ op: "ValidConstraint"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "ValidConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::OK);
+}
+
+TEST_F(AttrsMatchTest, BadConstraint) {
+ string node_def_str = R"(
+ name: "BadConstraint-op"
+ op: "BadConstraint"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "BadConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::INVALID_ARGUMENT);
+}
+
+TEST_F(AttrsMatchTest, Unimplemented) {
+ string node_def_str = R"(
+ name: "BadConstraint-op"
+ op: "BadConstraint"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ )";
+ string kernel_def_str = R"(
+ op: "BadConstraint"
+ device_type: "CPU"
+ constraint {
+ name: "T"
+ allowed_values {
+ list {
+ }
+ }
+ }
+ )";
+ ExpectStatus(node_def_str, kernel_def_str, error::UNIMPLEMENTED);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc
index 270118bb67..6dff6fe654 100644
--- a/tensorflow/core/framework/memory_types.cc
+++ b/tensorflow/core/framework/memory_types.cc
@@ -60,13 +60,18 @@ void MemoryTypesHelper(const NameRangeMap& name_map,
host_memory_args->resize(keep);
}
+bool IsFunctionCallOp(const string& op_type) {
+ return op_type == "SymbolicGradient" || op_type == "PartitionedCall" ||
+ op_type == "StatefulPartitionedCall";
+}
+
+} // namespace
+
MemoryType MTypeFromDType(const DataType dtype) {
return (dtype == DT_INT32 || DataTypeAlwaysOnHost(dtype)) ? HOST_MEMORY
: DEVICE_MEMORY;
}
-} // namespace
-
Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
const DeviceType& device_type, const NodeDef& ndef,
MemoryTypeVector* inp_mtypes,
@@ -94,7 +99,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
// TODO(zhifengc,phawkins): We should do type inference over function bodies
// to derive the correct input/output memory types. We should also split
// host-memory and non host-memory arguments into separate type lists.
- if (!status.ok() || ndef.op() == "SymbolicGradient") {
+ if (!status.ok() || IsFunctionCallOp(ndef.op())) {
for (const auto& t : inp_dtypes) inp_mtypes->push_back(MTypeFromDType(t));
for (const auto& t : out_dtypes) out_mtypes->push_back(MTypeFromDType(t));
return Status::OK();
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index a816c15140..e8ea904ebd 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -694,4 +694,17 @@ void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
ADD_ATTR(bool)
#undef ADD_ATTR
+Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
+ NodeDef* node_def) {
+ node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
+ if (node_def->op() == "Enter" || node_def->op() == "RefEnter") {
+ string frame_name;
+ TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
+ AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
+ frame_name = strings::StrCat(prefix, frame_name, suffix);
+ attr.set_s(frame_name);
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index ce7818a31c..64c8b386e8 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -299,6 +299,11 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
Status AttachDef(const Status& status, const NodeDef& node_def);
Status AttachDef(const Status& status, const Node& node);
+// Appends the given prefix and suffix to the original node name in order to
+// make the name unique. If it's an "Enter" node, use the same way to reset
+// attribute "frame_name".
+Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
+ NodeDef* node_def);
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 2a49425dba..35b7b2272b 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -495,5 +495,19 @@ TEST(NameRangesForNodeTest, TypeList) {
EXPECT_FALSE(NameRangesForNode(bad_node_def, op_def, &inputs, &outputs).ok());
}
+TEST(AddPrefixAndSuffixToNode, Enter) {
+ NodeDef node_def;
+ node_def.set_name("enter");
+ node_def.set_op("Enter");
+ AddNodeAttr("frame_name", "test_frame", &node_def);
+ const string prefix = "prefix/";
+ const string suffix = "/suffix";
+ TF_ASSERT_OK(AddPrefixAndSuffixToNode(prefix, suffix, &node_def));
+ EXPECT_EQ("prefix/enter/suffix", node_def.name());
+ string frame_name;
+ TF_ASSERT_OK(GetNodeAttr(node_def, "frame_name", &frame_name));
+ EXPECT_EQ("prefix/test_frame/suffix", frame_name);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b05a9df7c1..58feec90f0 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/graph.pb_text.h"
#include "tensorflow/core/framework/kernel_def.pb_text.h"
+#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -262,11 +263,13 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
outputs_(num_outputs),
temp_memory_allocated_(0),
persistent_memory_allocated_(0) {
- Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
params_->ensure_eigen_gpu_device();
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ if (params_->eigen_gpu_device != nullptr) {
+ Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
+ params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
+ params_->op_device_context,
+ eigen_gpu_allocator);
+ }
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
}
@@ -969,62 +972,6 @@ void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
namespace {
-// Helper for AttrsMatch().
-bool InTypeList(DataType dt, const AttrValue& type_list) {
- for (int in_list : type_list.list().type()) {
- if (dt == in_list) return true;
- }
- return false;
-}
-
-// Returns whether the attrs satisfy the constraints in the kernel_def. Returns
-// an error if attrs in kernel_def are not found, or have a mismatching type.
-Status AttrsMatch(AttrSlice attrs, const KernelDef& kernel_def, bool* match) {
- *match = false;
- for (const auto& constraint : kernel_def.constraint()) {
- if (constraint.allowed_values().list().type_size() == 0) {
- return errors::Unimplemented(
- "KernelDef '", ProtoShortDebugString(kernel_def),
- " has constraint on attr '", constraint.name(),
- "' with unsupported type: ",
- SummarizeAttrValue(constraint.allowed_values()));
- }
-
- const AttrValue* found = attrs.Find(constraint.name());
- if (found) {
- if (found->type() != DT_INVALID) {
- if (!InTypeList(found->type(), constraint.allowed_values())) {
- return Status::OK();
- }
- } else {
- if (!AttrValueHasType(*found, "list(type)").ok()) {
- return errors::InvalidArgument(
- "KernelDef '", ProtoShortDebugString(kernel_def),
- "' has constraint on attr '", constraint.name(),
- "' that has value '", SummarizeAttrValue(*found),
- "' that does not have type 'type' or 'list(type)' in NodeDef "
- "'",
- attrs.SummarizeNode(), "'");
- }
-
- for (int t : found->list().type()) {
- if (!InTypeList(static_cast<DataType>(t),
- constraint.allowed_values())) {
- return Status::OK();
- }
- }
- }
- } else {
- return errors::InvalidArgument(
- "OpKernel '", kernel_def.op(), "' has constraint on attr '",
- constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
- "', KernelDef: '", ProtoShortDebugString(kernel_def), "'");
- }
- }
- *match = true;
- return Status::OK();
-}
-
static const StringPiece kKernelAttr("_kernel");
// TODO(irving): Replace with const Node& version below.
@@ -1043,7 +990,7 @@ Status FindKernelRegistration(const DeviceType& device_type,
// If there is a kernel registered for the op and device_type,
// check that the attrs match.
bool match;
- TF_RETURN_IF_ERROR(AttrsMatch(node_def, iter->second.def, &match));
+ TF_RETURN_IF_ERROR(KernelAttrsMatch(iter->second.def, node_def, &match));
if (match) {
if (*reg != nullptr) {
return errors::InvalidArgument(
@@ -1120,6 +1067,16 @@ void LogAllRegisteredKernels() {
}
}
+KernelList GetAllRegisteredKernels() {
+ const KernelRegistry* const typed_registry = GlobalKernelRegistryTyped();
+ KernelList kernel_list;
+ kernel_list.mutable_kernel()->Reserve(typed_registry->size());
+ for (const auto& p : *typed_registry) {
+ *kernel_list.add_kernel() = p.second.def;
+ }
+ return kernel_list;
+}
+
string KernelsRegisteredForOp(StringPiece op_name) {
string ret;
for (const auto& key_registration : *GlobalKernelRegistryTyped()) {
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index f577664709..6c4c3a2ac1 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
#include <functional>
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h" // TODO(b/62899350): Remove
@@ -1303,6 +1304,9 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
// missing kernel errors.
void LogAllRegisteredKernels();
+// Gets a list of all registered kernels.
+KernelList GetAllRegisteredKernels();
+
namespace kernel_factory {
class OpKernelRegistrar {
@@ -1572,4 +1576,4 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc
index bcd409e5c5..b76a3400a8 100644
--- a/tensorflow/core/framework/op_kernel_test.cc
+++ b/tensorflow/core/framework/op_kernel_test.cc
@@ -964,5 +964,27 @@ void BM_SelectInputRange(int iters) {
BENCHMARK(BM_ConcatInputRange);
BENCHMARK(BM_SelectInputRange);
+TEST(RegisteredKernels, CanCallGetAllRegisteredKernels) {
+ auto all_registered_kernels = GetAllRegisteredKernels().kernel();
+ auto has_name_test1 = [](const KernelDef& k) { return k.op() == "Test1"; };
+
+ // Verify we can find the "Test1" op registered above
+ auto test1_it = std::find_if(all_registered_kernels.begin(),
+ all_registered_kernels.end(), has_name_test1);
+ ASSERT_NE(test1_it, all_registered_kernels.end());
+ EXPECT_EQ(test1_it->device_type(), "CPU");
+
+ // Verify there was just one kernel
+ ++test1_it;
+ EXPECT_EQ(
+ std::find_if(test1_it, all_registered_kernels.end(), has_name_test1),
+ all_registered_kernels.end());
+}
+
+// Simple test just to check we can call LogAllRegisteredKernels
+TEST(RegisteredKernels, CanLogAllRegisteredKernels) {
+ tensorflow::LogAllRegisteredKernels();
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 21fc6c1bd5..0a19861efd 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -60,8 +60,8 @@ namespace internal {
Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) {
if (ctx->device()->attributes().name() != p.device()) {
return errors::InvalidArgument(
- "Trying to access resource located in device ", p.device(),
- " from device ", ctx->device()->attributes().name());
+ "Trying to access resource ", p.name(), " located in device ",
+ p.device(), " from device ", ctx->device()->attributes().name());
}
return Status::OK();
}
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 813ec6eed5..0a8da8b3bf 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->allocate_persistent(DT_STRING, TensorShape({2}),
- &handle_, nullptr));
+ has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+ if (!has_resource_type_) {
+ // The resource variant of the op may be placed on non-CPU devices, but
+ // this allocation is always on the host. Fortunately we don't need it in
+ // the resource case.
+ OP_REQUIRES_OK(context,
+ context->allocate_persistent(DT_STRING, TensorShape({2}),
+ &handle_, nullptr));
+ }
}
// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}
- auto h = handle_.AccessTensor(context)->template flat<string>();
- h(0) = cinfo_.container();
- h(1) = cinfo_.name();
+ if (!has_resource_type_) {
+ auto h = handle_.AccessTensor(context)->template flat<string>();
+ h(0) = cinfo_.container();
+ h(1) = cinfo_.name();
+ }
resource_ = resource;
}
- if (context->expected_output_dtype(0) == DT_RESOURCE) {
+ if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@ class ResourceOpKernel : public OpKernel {
virtual Status VerifyResource(T* resource) { return Status::OK(); }
PersistentTensor handle_ GUARDED_BY(mu_);
+
+ // Is the output of the operator of type DT_RESOURCE?
+ bool has_resource_type_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h
index 872b8f8b30..ff7b3e78a7 100644
--- a/tensorflow/core/framework/resource_var.h
+++ b/tensorflow/core/framework/resource_var.h
@@ -29,6 +29,8 @@ class Var : public ResourceBase {
Var(const Var&) = delete;
Var& operator=(const Var&) = delete;
+ // When locking multiple variables, the locks must be acquired in order of
+ // increasing mu() address.
// TODO(ebrevdo): Use LockSet instead of exposing mu.
mutex* mu() { return &mu_; }
Tensor* tensor() { return &tensor_; }
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index b02bc3adbe..8d597e198d 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -340,6 +340,20 @@ string InferenceContext::DebugString() const {
ProtoDebugString(*node_def_));
}
+string InferenceContext::DebugString(const ShapeAndType& shape_and_type) {
+ return strings::StrCat(DebugString(shape_and_type.shape), ":",
+ DataTypeString(shape_and_type.dtype));
+}
+
+string InferenceContext::DebugString(
+ gtl::ArraySlice<ShapeAndType> shape_and_types) {
+ std::vector<string> pieces;
+ for (const ShapeAndType& s : shape_and_types) {
+ pieces.push_back(DebugString(s));
+ }
+ return strings::StrCat("[", str_util::Join(pieces, ","), "]");
+}
+
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
ShapeHandle* out) {
if (rank > kint32max) {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 3f3729dcf9..81258b55b3 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -381,6 +381,8 @@ class InferenceContext {
string DebugString(ShapeHandle s);
string DebugString(DimensionHandle d);
+ string DebugString(const ShapeAndType& shape_and_type);
+ string DebugString(gtl::ArraySlice<ShapeAndType> shape_and_types);
// Describes the whole context, for debugging purposes.
string DebugString() const;
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 8002d9291c..4a18efc940 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -57,6 +57,10 @@ class StatsAggregator {
// interface. It is possible that not all implementations will support
// encoding their state as a protocol buffer.
virtual void EncodeToProto(Summary* out_summary) = 0;
+
+ // Increment the `label` cell of metrics mapped with `name` by given `value`.
+ virtual void IncrementCounter(const string& name, const string& label,
+ int64 val) = 0;
};
// A `StatsAggregatorResource` wraps a shareable `StatsAggregator` as a resource
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index ded6aa0991..ff7c9855d6 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -470,6 +470,10 @@ inline bool DataTypeIsUnsigned(DataType dt) {
// Returns a 0 on failure
int DataTypeSize(DataType dt);
+// Returns HOST_MEMORY if `dtype` is always on host or is a DT_INT32,
+// DEVICE_MEMORY otherwise.
+MemoryType MTypeFromDType(const DataType dtype);
+
// Types that always sit on host: DT_STRING, DT_STRING_REF, DT_RESOURCE.
// For DT_RESOURCE, the handle always sits on host (even if the underlying
// object has device-allocated resources).
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index 30ff19cd7e..1778e48ef6 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -23,9 +23,66 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace {
+// Information about a loop frame structure.
+struct Frame {
+ string name;
+
+ // Pointer to the parent frame. The root frame has a pointer to itself.
+ Frame* parent = nullptr;
+
+ // The loop condition of the loop. There should be exactly one loop condition
+ // in every loop.
+ const Node* loop_cond = nullptr;
+};
+
+// Verify that the ControlFlowInfo of the graph has valid loop structure.
+Status ValidateControlFlowInfo(const Graph* graph,
+ const std::vector<ControlFlowInfo>& cf_info) {
+ std::unordered_map<string, Frame> frames;
+ for (const Node* node : graph->op_nodes()) {
+ const ControlFlowInfo& cf = cf_info[node->id()];
+ if (!cf.frame || !cf.parent_frame) {
+ // Skip nodes unreachable from the source node. They might be pruned
+ // later.
+ continue;
+ }
-Status BuildControlFlowInfo(const Graph* g,
- std::vector<ControlFlowInfo>* info) {
+ Frame& frame = frames[cf.frame_name];
+ Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
+ if (frame.parent == nullptr) {
+ frame.parent = parent;
+ frame.name = cf.frame_name;
+ } else if (frame.parent != parent) {
+ return errors::InvalidArgument(
+ "Invalid loop structure: Mismatched parent frames for \"",
+ cf.frame_name, "\": \"", parent->name, "\" vs \"", frame.parent->name,
+ "\". This is an internal bug, please file a bug report with "
+ "instructions on how to reproduce the error.");
+ }
+ if (IsLoopCond(node)) {
+ // ForwardLoopCounter runs in the same frame as the forward loop and
+ // BackPropLoopCounter runs in the same frame as the backprop loop. They
+ // are the only cases that multiple loops share the same frame.
+ if (frame.loop_cond &&
+ !str_util::StrContains(frame.loop_cond->name(), "LoopCounter") &&
+ !str_util::StrContains(node->name(), "LoopCounter")) {
+ return errors::InvalidArgument(
+ "Invalid loop structure: Loop \"", cf.frame_name,
+ "\" has more than one LoopCond node: \"", node->name(), "\" and \"",
+ frame.loop_cond->name(),
+ "\". This is an internal bug, please file a bug report with "
+ "instructions on how to reproduce the error.");
+ }
+ frame.loop_cond = node;
+ }
+ }
+ return Status::OK();
+}
+} // namespace
+
+Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes) {
info->clear();
info->resize(g->num_node_ids());
@@ -114,6 +171,14 @@ Status BuildControlFlowInfo(const Graph* g,
}
}
}
+ if (unreachable_nodes) {
+ for (const Node* node : g->op_nodes()) {
+ if (!parent_nodes[node->id()]) {
+ unreachable_nodes->push_back(node->name());
+ }
+ }
+ }
+ TF_RETURN_IF_ERROR(ValidateControlFlowInfo(g, *info));
return Status::OK();
}
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 79e2be0d4b..548820720b 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -31,13 +31,20 @@ struct ControlFlowInfo {
};
// Clear and populate `info` with each node's frame and the level it belongs to.
-// We check the well-formedness of the graph: All inputs to a node must come
-// from the same frame and have the same "static" iteration level.
+// We check the well-formedness of the graph:
+// 1) All inputs to a node must come from the same frame and have the same
+// "static" iteration level.
+// 2) Each frame has at most one LoopCond node.
+// 3) Each frame has a single parent frame.
+// If `unreachable_nodes` is set, return names of nodes unreachable from the
+// source node. We cannot build ControlFlowInfo for such nodes. They might be
+// pruned later.
//
// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0.
// This essentially means there can't be multiple serial Nexts in an iteration,
// which all sane front-ends should satisfy.
-Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info);
+Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
+ std::vector<string>* unreachable_nodes = nullptr);
} // namespace tensorflow
diff --git a/tensorflow/core/graph/control_flow_test.cc b/tensorflow/core/graph/control_flow_test.cc
new file mode 100644
index 0000000000..eb7937400f
--- /dev/null
+++ b/tensorflow/core/graph/control_flow_test.cc
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/graph/control_flow.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+Status LessThanTenCond(const Scope& scope, const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ops::Less(scope, inputs[0], 10);
+ return scope.status();
+}
+
+Status AddOneBody(const Scope& scope, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::AddN(scope, {inputs[0], 1}));
+ return scope.status();
+}
+
+Status NestedLoopBody(const Scope& scope, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs,
+ LessThanTenCond, AddOneBody, "inner_loop",
+ outputs);
+}
+
+TEST(ValidateControlFlowTest, InputsFromDifferentFrames) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("outer"), inputs,
+ LessThanTenCond, NestedLoopBody,
+ "outer_loop", &outputs));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ // {inner/Enter', 'outer/Switch'} --> 'inner/Merge'. 'inner/Enter' is in frame
+ // 'inner_loop'. 'outer/Switch' is in frame 'outer_loop'.
+ std::vector<ControlFlowInfo> info;
+ Status status = BuildControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "has inputs from different frames"))
+ << status.error_message();
+}
+
+TEST(ValidateControlFlowTest, MismatchedParentFrames) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody,
+ "test_loop", &outputs));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ Node* enter_1 = nullptr;
+ for (Node* node : graph->op_nodes()) {
+ if (IsEnter(node)) {
+ enter_1 = node;
+ }
+ }
+ ASSERT_TRUE(enter_1 != nullptr);
+
+ NodeDef enter;
+ enter.set_name("Enter2");
+ enter.set_op("Enter");
+ (*enter.mutable_attr())["T"].set_type(DT_INT32);
+ (*enter.mutable_attr())["frame_name"].set_s("test_loop");
+ *enter.add_input() = "Enter";
+ Status status;
+ Node* enter_2 = graph->AddNode(enter, &status);
+ TF_ASSERT_OK(status);
+ graph->AddControlEdge(enter_1, enter_2);
+
+ // SOURCE("") --> Enter("test_loop") --> Enter2("test_loop")
+ // For node 'Enter', the parent frame of "test_loop" is empty.
+ // For node 'Enter2', the parent frame of "test_loop" is "test_loop".
+ std::vector<ControlFlowInfo> info;
+ status = BuildControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(status.error_message(), "Mismatched parent frames"))
+ << status.error_message();
+}
+
+TEST(ValidateControlFlowTest, TwoLoopCond) {
+ // Test that one frame has at most one LoopCond node. This is necessary for
+ // functionalize control flow.
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ std::vector<Output> inputs;
+ inputs.push_back(ops::Placeholder(scope, DT_INT32));
+ std::vector<Output> outputs;
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope, inputs, LessThanTenCond, AddOneBody,
+ "test_loop", &outputs));
+ outputs.clear();
+ TF_ASSERT_OK(ops::BuildWhileLoop(scope.NewSubScope("sub"), inputs,
+ LessThanTenCond, AddOneBody, "test_loop",
+ &outputs, false));
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+ std::vector<ControlFlowInfo> info;
+ Status status = BuildControlFlowInfo(graph.get(), &info);
+ EXPECT_FALSE(status.ok());
+ EXPECT_TRUE(str_util::StrContains(status.error_message(),
+ "more than one LoopCond node"))
+ << status.error_message();
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index 6b56613470..c1a8a63784 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -106,8 +106,15 @@ static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
AddNodeAttr("Tin", in_types, &ndef);
// The gradient node's outputs have the same types as the node 'n's
- // inputs.
- AddNodeAttr("Tout", n->input_types(), &ndef);
+ // inputs, except for resources.
+ DataTypeVector out_types = n->input_types();
+ for (int i = 0; i < out_types.size(); ++i) {
+ if (out_types[i] == DT_RESOURCE) {
+ // TODO(apassos): figure out how to get the right dtype
+ out_types[i] = DT_FLOAT;
+ }
+ }
+ AddNodeAttr("Tout", out_types, &ndef);
NameAttrList func;
func.set_name(n->type_string());
for (const auto& attr : n->attrs()) {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 0f748515ef..568f0870c0 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -265,6 +266,28 @@ Status Node::input_node(int idx, const Node** const_n) const {
return Status::OK();
}
+// InputTensor
+
+bool InputTensor::operator==(const InputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
+// OutputTensor
+
+bool OutputTensor::operator==(const OutputTensor& other) const {
+ return node == other.node && index == other.index;
+}
+
+uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.index));
+}
+
// Graph
Graph::Graph(const OpRegistryInterface* ops)
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 33fb7cb57a..a147c94689 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -284,6 +284,16 @@ struct InputTensor {
InputTensor(const Node* n, int i) : node(n), index(i) {}
InputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this InputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const InputTensor& other) const;
+
+ // A hash function for InputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(InputTensor const& s) const;
+ };
};
// Represents an output of a node, i.e., the `index`-th output of `node`. Note
@@ -295,6 +305,16 @@ struct OutputTensor {
OutputTensor(const Node* n, int i) : node(n), index(i) {}
OutputTensor() : node(nullptr), index(0) {}
+
+ // Returns true if this OutputTensor is identical to 'other'. Nodes are
+ // compared using pointer equality.
+ bool operator==(const OutputTensor& other) const;
+
+ // A hash function for OutputTensors. Nodes are hashed based on their pointer
+ // value.
+ struct Hash {
+ uint64 operator()(OutputTensor const& s) const;
+ };
};
class Edge {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 0967492d92..add26f3b71 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -79,10 +79,10 @@ class GraphConstructor {
: in.prefix + "/"),
uniquify_names(in.uniquify_names),
uniquify_prefix(in.uniquify_prefix),
- input_map(in.input_map),
+ input_map(in.input_map.begin(), in.input_map.end()),
skip_mapped_nodes(in.skip_mapped_nodes),
control_dependencies(in.control_dependencies),
- return_tensors(in.return_tensors),
+ return_tensors(in.return_tensors.begin(), in.return_tensors.end()),
return_nodes(in.return_nodes),
importing(true),
validate_colocation_constraints(in.validate_colocation_constraints),
@@ -121,7 +121,7 @@ class GraphConstructor {
const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* missing_unused_input_map_keys) {
+ std::vector<SafeTensorId>* missing_unused_input_map_keys) {
if (versions) {
TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
@@ -142,7 +142,7 @@ class GraphConstructor {
ShapeRefiner* refiner,
std::vector<std::pair<Node*, int>>* return_tensors,
std::vector<Node*>* return_nodes,
- std::vector<TensorId>* missing_unused_input_map_keys)
+ std::vector<SafeTensorId>* missing_unused_input_map_keys)
: opts_(opts),
node_defs_(node_defs),
versions_(versions),
@@ -227,6 +227,10 @@ class GraphConstructor {
// already unique in the graph.
string FindUniqueName(StringPiece original_name);
+ // Decrement pending count for users of `processed` and add the ones that now
+ // have all of their pending inputs satisfied to `ready_`.
+ void UpdatePendingCountAndReady(int processed);
+
// From constructor
const Options opts_;
const NodeDefSlice node_defs_;
@@ -247,7 +251,7 @@ class GraphConstructor {
std::vector<Node*>* return_nodes_;
// May be null. Not owned.
- std::vector<TensorId>* missing_unused_input_map_keys_;
+ std::vector<SafeTensorId>* missing_unused_input_map_keys_;
// Intermediate datastructure used to populate
// `missing_unused_input_map_keys_`.
@@ -315,6 +319,25 @@ class GraphConstructor {
std::vector<EdgeInfo> back_edges_;
};
+void GraphConstructor::UpdatePendingCountAndReady(int processed) {
+ // We didn't consider NextIteration->Merge edges when computing
+ // pending_counts_ so we should not have to consider it here either.
+ bool is_next_iteration = IsNextIteration(*node_defs_[processed]);
+ for (size_t i = 0; i < outputs_[processed].size(); ++i) {
+ const int output = outputs_[processed][i];
+ bool is_next_iteration_to_merge_edge =
+ is_next_iteration && IsMerge(*node_defs_[output]);
+ if (!is_next_iteration_to_merge_edge) {
+ int* current_pending_count = &pending_count_[output];
+ CHECK_GT(*current_pending_count, 0);
+ (*current_pending_count)--;
+ if (*current_pending_count == 0) {
+ ready_.insert(output);
+ }
+ }
+ }
+}
+
// This could be expensive but we don't expect to call it often, if at all (only
// if there are multiple nodes in g_ with the same name)
bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
@@ -881,22 +904,6 @@ Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
return Status::OK();
}
-namespace {
-
-void UpdatePendingCountAndReady(
- const std::vector<gtl::InlinedVector<int, 4>>& outputs, int o,
- std::vector<int>* pending_count, std::set<int>* ready) {
- for (size_t i = 0; i < outputs[o].size(); ++i) {
- const int output = outputs[o][i];
- (*pending_count)[output]--;
- if ((*pending_count)[output] == 0) {
- ready->insert(output);
- }
- }
-}
-
-} // anonymous namespace
-
Status GraphConstructor::Convert() {
// Import functions before adding nodes, since imported nodes may refer to
// functions
@@ -938,7 +945,7 @@ Status GraphConstructor::Convert() {
IsNodeFullyMapped(original_node_def, &is_node_mapped));
if (is_node_mapped) {
// Skip this node after updating pending_count_ for outputs
- UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_);
+ UpdatePendingCountAndReady(o);
continue;
}
}
@@ -1031,7 +1038,7 @@ Status GraphConstructor::Convert() {
TF_RETURN_IF_ERROR(ValidateShape(node));
// Update pending_count_ for outputs.
- UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_);
+ UpdatePendingCountAndReady(o);
}
if (processed < node_defs_.size()) {
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index b03d655fe6..889359a68a 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -81,14 +81,14 @@ struct ImportGraphDefOptions {
// corresponding to `input_map` keys will be remapped to the nodes in `g`
// corresponding to the values.
//
- // Keys should not include `prefix`, i.e., a key TensorId's name should be the
- // name as it originally appears in `gdef`.
+ // Keys should not include `prefix`, i.e., a key ID's name should be the name
+ // as it originally appears in `gdef`.
//
// If this is non-empty, ImportGraphDef must be called with the shape refiner
// used to create the existing nodes referenced in `input_map`.
// TODO(skyewm): can we remove this requirement? How do we access the original
// shape refiner?
- std::map<TensorId, TensorId> input_map;
+ std::map<SafeTensorId, SafeTensorId> input_map;
// If true, nodes that will have all output edges removed because of
// overrides in `input_map` will not be imported.
@@ -107,12 +107,12 @@ struct ImportGraphDefOptions {
// caller must pass a results object to `ImportGraphDef()`. The
// `return_tensors` field will be populated with the imported nodes in `g`.
//
- // Entries should not include `prefix`, i.e., each TensorId's name should be
- // the name as it originally appears in `gdef`.
+ // Entries should not include `prefix`, i.e., each ID's name should be the
+ // name as it originally appears in `gdef`.
//
// If this contains a tensor that's also being remapped via `input_map`, the
// corresponding existing tensor in `g` will be returned.
- std::vector<TensorId> return_tensors;
+ std::vector<SafeTensorId> return_tensors;
// The names of nodes in `gdef` that will be returned via the
// ImportGraphDefResults output parameter of `ImportGraphDef()`. If this list
@@ -155,7 +155,7 @@ struct ImportGraphDefResults {
// Keys in ImportGraphDefOptions::input_map that don't appear in `gdef` and
// weren't used as an input to any node in `gdef`. These keys are likely due
// to typos, and callers may wish to treat their existence as an error.
- std::vector<TensorId> missing_unused_input_map_keys;
+ std::vector<SafeTensorId> missing_unused_input_map_keys;
};
// Adds the graph in GraphDef `gdef` into an existing Graph `*g`.
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 6309870190..e338840eeb 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1502,7 +1502,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts, &refiner, &results);
ASSERT_EQ(results.missing_unused_input_map_keys.size(), 1);
- EXPECT_EQ(results.missing_unused_input_map_keys[0], TensorId("new_input", 2));
+ EXPECT_EQ(results.missing_unused_input_map_keys[0],
+ SafeTensorId("new_input", 2));
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithUnboundInput) {
@@ -2748,6 +2749,51 @@ TEST_F(GraphConstructorTest, ImportGraphDef_NestedFunctionDefs) {
EXPECT_EQ(outputs[0].scalar<float>()(), 3.0);
}
+// NOTE(skyewm): the C API depends on this behavior, but it's easier to write
+// the test here.
+TEST_F(GraphConstructorTest, ImportGraphDef_OptionsMemMgmt) {
+ ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, graph_.op_registry());
+
+ // Populate graph with node we'll use in input map
+ ExpectOK("node { name: 'input' op: 'TestInput' }", ImportGraphDefOptions(),
+ &refiner);
+
+ // Add some strings to ImportGraphDefOptions and then rewrite the buffers.
+ char buf1[100];
+ char buf2[100];
+ char buf3[100];
+ snprintf(buf1, sizeof(buf1), "input");
+ snprintf(buf2, sizeof(buf2), "new_input");
+ snprintf(buf3, sizeof(buf3), "t1");
+
+ ImportGraphDefOptions opts;
+ opts.input_map[TensorId(buf2, 0)] = TensorId(buf1, 0);
+ opts.return_tensors.push_back(TensorId(buf3, 0));
+
+ snprintf(buf1, sizeof(buf1), "xxxxxxxxxxxxxxxxxxxx");
+ snprintf(buf2, sizeof(buf2), "xxxxxxxxxxxxxxxxxxxx");
+ snprintf(buf3, sizeof(buf3), "xxxxxxxxxxxxxxxxxxxx");
+
+ // Import some new nodes using opts.
+ ImportGraphDefResults results;
+ ExpectOK(
+ R"EOF(
+ node { name: 'new_input' op: 'TestInput' }
+ node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
+ )EOF",
+ opts, &refiner, &results);
+
+ EXPECT_TRUE(HasNode("input"));
+ EXPECT_TRUE(HasNode("new_input"));
+ EXPECT_TRUE(HasNode("t1"));
+
+ EXPECT_TRUE(HasEdge("input", 0, "t1", 0));
+ EXPECT_TRUE(HasEdge("new_input", 1, "t1", 1));
+
+ ASSERT_EQ(results.return_tensors.size(), 1);
+ EXPECT_EQ(results.return_tensors[0].first->name(), "t1");
+}
+
TEST_F(GraphConstructorTest, CopyGraph) {
const int v = TF_GRAPH_DEF_VERSION;
const int bad = v + 17;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 7645b4a7f0..fc474c0dc8 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -1901,6 +1901,11 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
#else // INTEL_MKL_ML
+// NOTE: Unit tests in this file rely on a topological sorted graph for
+// printing. But since sibling nodes of a node in the topologically sorted graph
+// can be printed in different orders, tests may fail if the order in which
+// sibling nodes are visited is changed.
+
namespace {
const char kCPUDevice[] = "/job:a/replica:0/task:0/device:CPU:0";
@@ -2572,9 +2577,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Zeta)|A->E;A->I;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;"
- "B->E:1;C->F;C:control->DMT/_0:control;C:control->DMT/_1:control;"
- "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+ "A:control->DMT/_0:control;A:control->DMT/_1:control;"
+ "B->E:1;C->F;C:control->DMT/_2:control;C:control->DMT/_3:control;"
+ "D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:3;E->H:1;E:2->H:4;F->H:2;F:2->H:5;G->H;"
"G:control->DMT/_4:control;H->I:1");
}
@@ -2681,9 +2686,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Zeta)|A->E;A->I;"
- "A:control->DMT/_2:control;A:control->DMT/_3:control;B->E:1;C->F;"
- "C:control->DMT/_0:control;C:control->DMT/_1:control;"
- "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+ "A:control->DMT/_0:control;A:control->DMT/_1:control;B->E:1;C->F;"
+ "C:control->DMT/_2:control;C:control->DMT/_3:control;"
+ "D->F:1;DMT/_0->E:2;DMT/_1->E:3;DMT/_2->F:2;DMT/_3->F:3;"
"DMT/_4->H:5;E->H;E:2->H:3;E:control->DMT/_4:control;F->H:1;"
"F:2->H:4;G->H:2;H->I:1");
}
@@ -3060,8 +3065,8 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
"C:control->DMT/_3:control;C:control->DMT/_4:control;"
"C:control->DMT/_5:control;C:control->DMT/_6:control;"
- "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
- "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
+ "D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;DMT/_3->F:3;"
+ "DMT/_4->F:7;DMT/_5->F:4;DMT/_6->F:6;E->G;F->G:1");
}
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 193cf88aed..60337e30aa 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -81,7 +81,9 @@ Status FeedInputs(
// Update name_index
(*name_index)[feed_node->name()] = feed_node;
- g->AddControlEdge(g->source_node(), feed_node);
+ // Duplicate control edges aren't allowed, but feed_node was *just* created
+ // so there's no need to check for a duplicate.
+ g->AddControlEdge(g->source_node(), feed_node, true);
// Look through edges coming out of "n" for edges whose src_output() index
// matches "output_index". If found, replace the edges with a connection
@@ -107,7 +109,9 @@ Status FeedInputs(
g->AddEdge(feed_node, 0, e->dst(), e->dst_input());
} else {
CHECK_EQ(Graph::kControlSlot, e->src_output());
- g->AddControlEdge(feed_node, e->dst());
+ // Duplicate control edges aren't allowed, but feed_node was *just*
+ // created so there's no need to check for a duplicate.
+ g->AddControlEdge(feed_node, e->dst(), true);
}
g->RemoveEdge(e);
}
@@ -160,7 +164,9 @@ Status FetchOutputs(
// Update the index.
(*name_index)[fetch_node->name()] = fetch_node;
- g->AddControlEdge(fetch_node, g->sink_node());
+ // Duplicate control edges aren't allowed, but fetch_node was *just* created
+ // so there's no need to check for a duplicate.
+ g->AddControlEdge(fetch_node, g->sink_node(), true);
out_fetch_nodes->push_back(fetch_node);
out_fetch_types->push_back(BaseType(n->output_type(id.second)));
}
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 8af1936d64..80c76df255 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -22,6 +22,11 @@ limitations under the License.
namespace tensorflow {
+TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
+
+SafeTensorId::SafeTensorId(const TensorId& id)
+ : SafeTensorId(id.first.ToString(), id.second) {}
+
TensorId ParseTensorName(const string& name) {
return ParseTensorName(StringPiece(name.data(), name.size()));
}
diff --git a/tensorflow/core/graph/tensor_id.h b/tensorflow/core/graph/tensor_id.h
index c27120f7e6..0ba3942618 100644
--- a/tensorflow/core/graph/tensor_id.h
+++ b/tensorflow/core/graph/tensor_id.h
@@ -25,6 +25,8 @@ limitations under the License.
namespace tensorflow {
+struct SafeTensorId;
+
// Identifier for a tensor within a step.
// first == operation_name, second == output_index
// Note: does not own backing storage for name.
@@ -34,6 +36,11 @@ struct TensorId : public std::pair<StringPiece, int> {
// Inherit the set of constructors.
using Base::pair;
+ // NOTE(skyewm): this is required on some platforms. I'm not sure why the
+ // using statement above isn't always sufficient.
+ TensorId() : Base() {}
+ TensorId(const SafeTensorId& id);
+
string ToString() const {
if (second == Graph::kControlSlot) return strings::StrCat("^", first);
return strings::StrCat(first, ":", second);
@@ -50,6 +57,30 @@ struct TensorId : public std::pair<StringPiece, int> {
TensorId ParseTensorName(const string& name);
TensorId ParseTensorName(StringPiece name);
+// Same as TensorId, except owns the backing storage for the op name. This makes
+// the memory management simpler at the expense of a copy.
+struct SafeTensorId : public std::pair<string, int> {
+ typedef std::pair<string, int> Base;
+
+ // NOTE(skyewm): this is required on some platforms. I'm not sure why the
+ // using "using Base::pair;" isn't always sufficient.
+ SafeTensorId() : Base() {}
+ SafeTensorId(const string& str, int idx) : Base(str, idx) {}
+ SafeTensorId(const TensorId& id);
+
+ string ToString() const {
+ if (second == Graph::kControlSlot) return strings::StrCat("^", first);
+ return strings::StrCat(first, ":", second);
+ }
+
+ struct Hasher {
+ public:
+ std::size_t operator()(const TensorId& x) const {
+ return Hash32(x.first.data(), x.first.size(), x.second);
+ }
+ };
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_GRAPH_TENSOR_ID_H_
diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc
index bd905651d2..e44eb91d48 100644
--- a/tensorflow/core/graph/validate.cc
+++ b/tensorflow/core/graph/validate.cc
@@ -59,5 +59,59 @@ void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) {
RemoveDescriptionsFromOpList(op_list);
}
+Status ValidateGraphHasNoCycle(const Graph& graph) {
+ // A node is ready when all of its inputs have been visited.
+ std::vector<const Node*> ready;
+ std::vector<int> pending_count(graph.num_node_ids(), 0);
+
+ for (int i = 0; i < graph.num_node_ids(); ++i) {
+ const Node* n = graph.FindNodeId(i);
+ if (n == nullptr) continue;
+ pending_count[i] = n->in_edges().size();
+ if (n->IsMerge()) {
+ // While-loop cycles are legal cycles so we manually adjust the
+ // pending_count to make sure that the loop is visited.
+ for (const Edge* e : n->in_edges()) {
+ if (!e->IsControlEdge() && e->src()->IsNextIteration()) {
+ pending_count[i]--;
+ }
+ }
+ }
+ if (pending_count[i] == 0) {
+ ready.push_back(n);
+ }
+ }
+
+ int processed = 0;
+ while (!ready.empty()) {
+ const Node* node = ready.back();
+ ready.pop_back();
+ ++processed;
+
+ for (const Edge* out : node->out_edges()) {
+ const int output_id = out->dst()->id();
+ pending_count[output_id]--;
+ if (pending_count[output_id] == 0) {
+ ready.push_back(out->dst());
+ }
+ }
+ }
+
+ if (processed < graph.num_nodes()) {
+ std::vector<string> nodes_in_cycle;
+ for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
+ ++i) {
+ if (pending_count[i] != 0) {
+ nodes_in_cycle.push_back(graph.FindNodeId(i)->name());
+ }
+ }
+ return errors::InvalidArgument(
+ "Graph is invalid, contains a cycle with ",
+ graph.num_nodes() - processed,
+ " nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
+ }
+ return Status::OK();
+}
+
} // namespace graph
} // namespace tensorflow
diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h
index cda93fe1de..08879dca60 100644
--- a/tensorflow/core/graph/validate.h
+++ b/tensorflow/core/graph/validate.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -50,6 +51,14 @@ Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def,
void GetOpListForValidation(
OpList* op_list, const OpRegistry& op_registry = *OpRegistry::Global());
+// Validate that the graph has no cycle except for legal while loop cycles.
+// This traverses the specified nodes in topological order to verify there are
+// no cycles. Starting with inputless nodes, it visits nodes whose inputs have
+// all been visited, and counts the total number of visited nodes. If there is a
+// cycle, nodes in the cycle will never be visited, and the visited count will
+// be less than the total node count.
+Status ValidateGraphHasNoCycle(const Graph& graph);
+
} // namespace graph
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index 30c6126fbb..ab8f4bebb3 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -20,6 +20,9 @@ tf_cuda_library(
name = "utils",
srcs = ["utils.cc"],
hdrs = ["utils.h"],
+ cuda_deps = [
+ "@local_config_cuda//cuda:cudnn_header",
+ ],
visibility = ["//visibility:public"],
deps = [
"//third_party/eigen3",
@@ -74,6 +77,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cluster",
+ ":utils",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index d33aaa7e4c..06db36b3aa 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -95,7 +95,7 @@ class Cluster {
// The DeviceSet is not always available, but when it is it contains a
// superset of the devices listed in GetDevices/GetDeviceNames().
- const DeviceSet* GetDeviceSet() const { return device_set_; }
+ virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
// Enables collecting the allocator stats. Call with enable=true must be made
// before Provision().
@@ -124,7 +124,6 @@ class Cluster {
protected:
std::unordered_map<string, DeviceProperties> devices_;
- const DeviceSet* device_set_ = nullptr; // Not owned
const int timeout_s_;
SessionOptions options_;
RunOptions run_options_;
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 313ef90d81..b97603c890 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -368,6 +368,15 @@ Status SingleMachine::ResetSession() {
}
coordinator_.reset(new Coordinator());
+ // Build the DeviceSet.
+ device_set_.reset(new DeviceSet);
+ const DeviceMgr* device_mgr;
+ TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
+ for (auto d : device_mgr->ListDevices()) {
+ device_set_->AddDevice(d);
+ // We currently don't care about the client device.
+ }
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
index 0ae188e0d6..c0421dd4de 100644
--- a/tensorflow/core/grappler/clusters/single_machine.h
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -43,6 +43,8 @@ class SingleMachine : public Cluster {
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_.get(); }
+
Status EnablePeakMemoryStats(bool enable) override;
// It requires EnableAllocatorStats(true) be called before Provision().
@@ -73,6 +75,7 @@ class SingleMachine : public Cluster {
int64 expected_init_time_s_;
std::unique_ptr<Coordinator> coordinator_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<DeviceSet> device_set_;
RunMetadata init_metadata_;
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index 352f08fede..31b19cfcfd 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -546,7 +546,7 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_before));
EXPECT_EQ(device_peak_memory_before.size(), 1);
// There might be a bit memory used before session's running anything.
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
RunMetadata metadata;
TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
@@ -567,8 +567,8 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
// Check memory used by resources are released after cluster destruction.
EXPECT_EQ(device_peak_memory_before.size(), 1);
EXPECT_EQ(device_peak_memory_after.size(), 1);
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
- EXPECT_LT(device_peak_memory_after.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
+ EXPECT_LT(device_peak_memory_after.begin()->second, 400);
}
TEST_F(SingleMachineTest, PeakMemory) {
@@ -597,7 +597,7 @@ TEST_F(SingleMachineTest, PeakMemory) {
device_peak_memory.end());
cpu_memory =
device_peak_memory["/job:localhost/replica:0/task:0/device:CPU:0"];
- EXPECT_LT(cpu_memory, 100);
+ EXPECT_LT(cpu_memory, 200);
}
TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) {
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 5c9b2320b5..12e3e46f65 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
@@ -38,11 +39,14 @@ VirtualCluster::VirtualCluster(
devices_ = devices;
}
-VirtualCluster::VirtualCluster(
- const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set)
- : VirtualCluster(devices) {
+VirtualCluster::VirtualCluster(const DeviceSet* device_set)
+ : VirtualCluster(std::unordered_map<string, DeviceProperties>()) {
device_set_ = device_set;
+ for (const auto& device : device_set_->devices()) {
+ DeviceProperties props = GetDeviceInfo(device->parsed_name());
+ if (props.type() == "UNKNOWN") continue;
+ devices_[device->name()] = props;
+ }
}
VirtualCluster::~VirtualCluster() {}
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
index eebac68e1b..6adb0b99bc 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.h
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -36,8 +36,7 @@ class VirtualCluster : public Cluster {
VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
OpLevelCostEstimator* node_estimator,
ReadyNodeManager* node_manager);
- VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set);
+ VirtualCluster(const DeviceSet* device_set);
~VirtualCluster() override;
@@ -48,10 +47,12 @@ class VirtualCluster : public Cluster {
Status Run(const GraphDef& item,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_; }
private:
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
std::unique_ptr<ReadyNodeManager> node_manager_;
+ const DeviceSet* device_set_ = nullptr; // Not owned
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 35f11eac29..f3dc2c2091 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -41,6 +41,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":utils",
+ "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:op_types",
@@ -129,6 +130,9 @@ tf_cuda_library(
name = "utils",
srcs = ["utils.cc"],
hdrs = ["utils.h"],
+ cuda_deps = [
+ "@local_config_cuda//cuda:cudnn_header",
+ ],
visibility = ["//visibility:public"],
deps = [
"//third_party/eigen3",
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index c8ba4dfbda..a60e3c7a9f 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -98,6 +98,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
node_costs.compute_time.asMicroSeconds().count());
cost_node->set_memory_time(
node_costs.memory_time.asMicroSeconds().count());
+ cost_node->set_inaccurate(node_costs.inaccurate);
for (const auto& output : op_context.op_info.outputs()) {
auto output_info = cost_node->add_output_info();
output_info->set_dtype(output.dtype());
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d9a08d42db..83a8326e79 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -353,12 +354,12 @@ void VerboseLogUnknownDimensionSources(
class TopoQueue {
public:
explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
- : queue_(CompareNodes(topo_order)) {}
- void push(const NodeDef* n) { queue_.insert(n); }
+ : topo_order_(topo_order) {}
+ void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
- const NodeDef* n = *it;
+ const NodeDef* n = it->first;
queue_.erase(it);
return n;
}
@@ -367,20 +368,16 @@ class TopoQueue {
std::size_t size() const { return queue_.size(); }
private:
+ using NodeAndId = std::pair<const NodeDef*, int>;
// Graph nodes are created in (roughly) topological order. Therefore we can
// use their id to ensure they're sorted topologically.
- struct CompareNodes {
- explicit CompareNodes(
- const std::unordered_map<const NodeDef*, int>& topo_ordering)
- : topo_order(topo_ordering) {}
- bool operator()(const NodeDef* lhs, const NodeDef* rhs) const {
- return topo_order.at(lhs) < topo_order.at(rhs);
+ struct OrderByIdAscending {
+ bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
+ return lhs.second < rhs.second;
}
-
- private:
- const std::unordered_map<const NodeDef*, int>& topo_order;
};
- std::set<const NodeDef*, CompareNodes> queue_;
+ const std::unordered_map<const NodeDef*, int>& topo_order_;
+ std::set<NodeAndId, OrderByIdAscending> queue_;
};
// Processes symbolic shapes.
@@ -426,11 +423,108 @@ class SymbolicShapeRefiner {
return it->second.inference_context.get();
}
- // Forward the shapes from the function's fanin to the function body,
- // then call PropagateShapes.
- // Returns an error if 'node' is not a function node.
- Status UpdateFunction(const NodeDef* node, bool* refined) {
- return UpdateNode(node, refined);
+ // Forward the shapes from the function input nodes to
+ // the argument nodes (which are Placeholder nodes), then
+ // perform shape inference on the function body.
+ //
+ // Propagate shape information of final function body node
+ // to function node `node`.
+ //
+ // In the event of an error, UpdateNode will simply set `node`'s
+ // output shape to be Unknown.
+ Status UpdateFunction(const NodeDef* node) {
+ auto it = fun_to_grappler_function_item_.find(node->op());
+ if (it == fun_to_grappler_function_item_.end()) {
+ return errors::InvalidArgument(
+ node->op(), " was not previously added to SymbolicShapeRefiner.");
+ }
+
+ GrapplerFunctionItem& grappler_function_item = it->second;
+ GraphView gv(&grappler_function_item.graph);
+
+ // Forward shapes from function input nodes to argument nodes.
+ for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
+ auto& fun_input = grappler_function_item.input(i);
+ if (fun_input.placeholders.size() > 1) {
+ // TODO(jmdecker): Handle case with multiple input placeholders
+ return errors::Unimplemented(
+ "Input arguments with multiple placeholders are not yet "
+ "supported.");
+ }
+ NodeDef* fun_node = gv.GetNode(fun_input.input_name);
+ const string& input = node->input(i);
+ const string& node_name = NodeName(input);
+
+ if (IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Function inputs should not contain control nodes.");
+ }
+
+ NodeDef* input_node = graph_.GetNode(node_name);
+ if (input_node == nullptr) {
+ return errors::FailedPrecondition(node_name,
+ " was not found in the graph.");
+ }
+
+ InferenceContext* input_inference_context = GetContext(input_node);
+ if (input_inference_context == nullptr) {
+ return errors::FailedPrecondition(
+ "Inference context has not been created for ", node_name);
+ }
+
+ int output_port_num = NodePosition(input);
+ AttrValue attr_output_shape;
+ TensorShapeProto proto;
+ const auto& handle = input_inference_context->output(output_port_num);
+ input_inference_context->ShapeHandleToProto(handle, &proto);
+ *attr_output_shape.mutable_shape() = proto;
+ (*fun_node->mutable_attr())["shape"] = attr_output_shape;
+ }
+
+ // Perform inference on function body.
+ GraphProperties gp(grappler_function_item);
+ TF_RETURN_IF_ERROR(gp.InferStatically(true));
+
+ // Add return nodes for output shapes.
+ auto ic = GetContext(node);
+ int output = 0;
+ for (auto const& out_arg : grappler_function_item.outputs()) {
+ if (out_arg.output_tensors.size() > 1) {
+ // TODO(jmdecker): Handle case of multiple output tensors
+ return errors::Unimplemented(
+ "Output arguments with multiple output tensors are not yet "
+ "supported.");
+ }
+
+ string out_tensor = out_arg.output_tensors[0];
+ auto out_tensor_pieces = str_util::Split(out_tensor, ",");
+ string node_name = out_tensor_pieces[0];
+ int port_id;
+
+ // Check if port_id was included in out_tensor
+ if (out_tensor_pieces.size() <= 1) {
+ port_id = 0;
+ } else if (!strings::safe_strto32(out_tensor_pieces[1], &port_id)) {
+ return errors::FailedPrecondition(
+ "Failed string to integer conversion for ", out_tensor_pieces[1]);
+ }
+
+ const NodeDef* retnode = gv.GetNode(node_name);
+ if (retnode == nullptr) {
+ return errors::FailedPrecondition("Unable to find return node ",
+ node_name, " for ", node->name());
+ }
+
+ auto output_properties = gp.GetOutputProperties(retnode->name());
+ auto const& outprop = output_properties[port_id];
+ const TensorShapeProto& shape = outprop.shape();
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
+ ic->set_output(output, out);
+ output++;
+ }
+
+ return Status::OK();
}
Status UpdateNode(const NodeDef* node, bool* refined) {
@@ -440,6 +534,7 @@ class SymbolicShapeRefiner {
node_context = CHECK_NOTNULL(GetNodeContext(node));
*refined = true;
}
+
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have, update the node input shapes.
InferenceContext* inference_context = node_context->inference_context.get();
@@ -459,7 +554,8 @@ class SymbolicShapeRefiner {
if (c == nullptr) {
return errors::FailedPrecondition(
"Input ", dst_input, " ('", input->name(), "') for '",
- node->name(), "' was not previously added to ShapeRefiner.");
+ node->name(),
+ "' was not previously added to SymbolicShapeRefiner.");
}
if (IsConstant(*input)) {
@@ -569,6 +665,21 @@ class SymbolicShapeRefiner {
node_context->inference_context->set_input_tensors_as_shapes(
input_tensors_as_shapes);
+ // Properly handle function nodes.
+ if (node_context->op_data && node_context->op_data->is_function_op) {
+ // TODO(jmdecker): Detect if the input shapes have changed for this
+ // function. Note that when we hit a function call node, refined will be
+ // true, as the updates to the call node will have changed, even if it's
+ // the same function being called twice with the same input shapes.
+ // Example: simple_function.pbtxt
+ if (UpdateFunction(node).ok()) {
+ return Status::OK();
+ } else {
+ VLOG(1) << "UpdateFunction failed for " << node->op()
+ << ". Defaulting to ShapeUnknown.";
+ }
+ }
+
// Update the shapes of the outputs.
return InferShapes(*node, node_context);
}
@@ -685,7 +796,39 @@ class SymbolicShapeRefiner {
return true;
}
- Status AddFunction(const NodeDef* node) { return Status::OK(); }
+ Status AddFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
+ if (it != fun_to_grappler_function_item_.end()) {
+ return Status::OK();
+ }
+
+ const FunctionDef* function_def =
+ CHECK_NOTNULL(function_library_.Find(function_node->op()));
+
+ GrapplerFunctionItem grappler_function_item;
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ *function_def, function_library_, &grappler_function_item));
+
+ if (grappler_function_item.inputs().size() > function_node->input_size()) {
+ return errors::FailedPrecondition(
+ "Function input size should be smaller than node input size.");
+ }
+
+ for (int i = grappler_function_item.inputs().size();
+ i < function_node->input_size(); ++i) {
+ const string& input = function_node->input(i);
+ if (!IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Found regular input (", input,
+ ") instead of control nodes for node ", function_node->name());
+ }
+ }
+
+ fun_to_grappler_function_item_[function_def->signature().name()] =
+ grappler_function_item;
+
+ return Status::OK();
+ }
Status AddNode(const NodeDef* node) {
NodeContext& node_ctx = node_to_context_[node];
@@ -915,6 +1058,8 @@ class SymbolicShapeRefiner {
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
+ std::unordered_map<string, GrapplerFunctionItem>
+ fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
};
@@ -1082,14 +1227,13 @@ Status GraphProperties::UpdateShapes(
// itself.
TF_RETURN_IF_ERROR(
UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
+ } else if (IsQueue(*n)) {
+ // Set shapes and types of Queue ops, if needed.
+ TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
- auto c = shape_refiner->GetNodeContext(n);
- if (c && c->op_data && c->op_data->is_function_op) {
- TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes));
- } else {
- // Rely on regular TF shape refinement for all the other nodes.
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
- }
+ // Rely on regular TF shape refinement for all the other nodes.
+ // UpdateNode calls UpdateFunction if a function node is detected.
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}
@@ -1147,6 +1291,53 @@ Status GraphProperties::PropagateShapes(
return Status::OK();
}
+Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes) {
+ auto ctx = shape_refiner->GetNodeContext(queue_node);
+ if (!ctx) {
+ TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
+ ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
+ }
+ auto* ic = ctx->inference_context.get();
+
+ auto* outputs = ic->output_handle_shapes_and_types(0);
+ if (outputs) {
+ // Shapes and types are already set, presumably by Enqueue ops.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ if (queue_node->attr().count("shapes") <= 0 ||
+ queue_node->attr().count("component_types") <= 0 ||
+ queue_node->attr().at("shapes").list().shape_size() !=
+ queue_node->attr().at("component_types").list().type_size()) {
+ // Errors in shapes and component_types attr.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ // Extract types and shapes from Queue attr.
+ const auto& shapes = queue_node->attr().at("shapes").list().shape();
+ const auto& types = queue_node->attr().at("component_types").list().type();
+ std::vector<ShapeAndType> shapes_and_types;
+ for (int i = 0; i < types.size(); i++) {
+ const auto& shape = shapes[i];
+ ShapeHandle shape_handle;
+ TF_RETURN_IF_ERROR(
+ ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
+ DataType data_type =
+ queue_node->attr().at("component_types").list().type(i);
+ ShapeAndType shape_and_type(shape_handle, data_type);
+ shapes_and_types.push_back(shape_and_type);
+ }
+ ic->set_output_handle_shapes_and_types(0, shapes_and_types);
+
+ // Queue node is updated with output_handle_shapes_and_types, so set
+ // new_shapes and ignore it from UpdateNoe().
+ *new_shapes = true;
+ bool dummy_new_shapes = false;
+ return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
+}
+
Status GraphProperties::UpdateEnqueue(
const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 8703613a12..f716cd72c9 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -91,6 +91,11 @@ class GraphProperties {
resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
+ // Update the shapes and types of the Queue node, if not set by Enqueue node.
+ static Status UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes);
+
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3e44b222fd..1be19d291a 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -262,6 +262,59 @@ TEST_F(GraphPropertiesTest, VarHandles) {
EXPECT_EQ(7, prop.shape().dim(1).size());
}
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: ?", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,1]", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,-1]", PropToString(props1[0]));
+}
+
TEST_F(GraphPropertiesTest, Queues) {
// Create a graph with known input shapes, and propagate the shapes through a
// couple of queues.
@@ -730,7 +783,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
-TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
+TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
// Test graph produced in python using:
/*
@function.Defun(*[tf.float32] * 2, noinline=True)
@@ -743,7 +796,6 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
z = MyAdd(x, y)
z = MyAdd(x, z)
*/
- // Check that the shape inference code infers what it can.
GrapplerItem item;
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
"simple_function.pbtxt");
@@ -753,15 +805,258 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_55e046a8");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_TRUE(out_prop.shape().unknown_rank());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
const auto in_props = properties.GetInputProperties("MyAdd_55e046a8");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "large_function_graph.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto out_props = properties.GetOutputProperties("y0");
+ EXPECT_EQ(2, out_props.size());
+
+ const OpInfo::TensorProperties& out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_EQ(4, out_prop0.shape().dim_size());
+ EXPECT_EQ(128, out_prop0.shape().dim(0).size());
+ EXPECT_EQ(112, out_prop0.shape().dim(1).size());
+ EXPECT_EQ(112, out_prop0.shape().dim(2).size());
+ EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& out_prop1 = out_props[1];
+ EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
+ EXPECT_EQ(128, out_prop1.shape().dim(0).size());
+ EXPECT_EQ(112, out_prop1.shape().dim(1).size());
+ EXPECT_EQ(112, out_prop1.shape().dim(2).size());
+ EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+
+ const auto in_props = properties.GetInputProperties("y0");
+ EXPECT_EQ(4, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop0 = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
+ EXPECT_EQ(1, in_prop0.shape().dim_size());
+ EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_EQ(4, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(1, in_prop1.shape().dim(1).size());
+ EXPECT_EQ(24, in_prop1.shape().dim(2).size());
+ EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& in_prop2 = in_props[2];
+ EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
+ EXPECT_EQ(4, in_prop2.shape().dim_size());
+ EXPECT_EQ(128, in_prop2.shape().dim(0).size());
+ EXPECT_EQ(224, in_prop2.shape().dim(1).size());
+ EXPECT_EQ(224, in_prop2.shape().dim(2).size());
+ EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+
+ const OpInfo::TensorProperties& in_prop3 = in_props[3];
+ EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
+ EXPECT_EQ(4, in_prop3.shape().dim_size());
+ EXPECT_EQ(7, in_prop3.shape().dim(0).size());
+ EXPECT_EQ(7, in_prop3.shape().dim(1).size());
+ EXPECT_EQ(3, in_prop3.shape().dim(2).size());
+ EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_error.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto out_props = properties.GetOutputProperties("MyAdd_yabA4wXEdM4");
+ EXPECT_EQ(1, out_props.size());
+
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_yabA4wXEdM4");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ return tf.add(x, y)
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(0, 1), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
+ EXPECT_EQ(2, in_props.size());
+
const OpInfo::TensorProperties& in_prop = in_props[0];
EXPECT_EQ(DT_FLOAT, in_prop.dtype());
EXPECT_FALSE(in_prop.shape().unknown_rank());
EXPECT_EQ(2, in_prop.shape().dim_size());
EXPECT_EQ(1, in_prop.shape().dim(0).size());
EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ return tf.add(x, y)
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch_2.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+}
+
+TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
+ // Test graph produced in python using:
+ /*
+ @function.Defun(*[tf.float32] * 2, noinline=True)
+ def MyAdd(x, y):
+ a = tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ b = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
+ c = tf.add(x, a)
+ d = tf.add(y, b)
+ return c
+
+ with tf.Graph().as_default():
+ x = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ y = lambda: tf.constant(2.0, shape=[1, 2], dtype=tf.float32)
+ z = tf.constant(2.0, shape=[1, 3], dtype=tf.float32)
+ z2 = MyAdd(tf.case([(tf.less(1, 0), x)], default=y), z)
+ */
+ GrapplerItem item;
+ string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
+ "function_switch_shapes.pbtxt");
+ TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop.dtype());
+ EXPECT_FALSE(out_prop.shape().unknown_rank());
+ EXPECT_EQ(2, out_prop.shape().dim_size());
+ EXPECT_EQ(1, out_prop.shape().dim(0).size());
+ EXPECT_EQ(2, out_prop.shape().dim(1).size());
+
+ const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
+ EXPECT_EQ(2, in_props.size());
+
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(1, in_prop.shape().dim(0).size());
+ EXPECT_EQ(2, in_prop.shape().dim(1).size());
+
+ const OpInfo::TensorProperties& in_prop1 = in_props[1];
+ EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
+ EXPECT_FALSE(in_prop1.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop1.shape().dim_size());
+ EXPECT_EQ(1, in_prop1.shape().dim(0).size());
+ EXPECT_EQ(3, in_prop1.shape().dim(1).size());
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt
new file mode 100644
index 0000000000..c3f0a6c95d
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_error.pbtxt
@@ -0,0 +1,117 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "MyAdd_yabA4wXEdM4"
+ op: "MyAdd_yabA4wXEdM4"
+ input: "Const"
+ input: "Const_1"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_yabA4wXEdM4"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "add_1"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "Add:z:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "Add_1"
+ op: "Add"
+ input: "Add:z:0"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "add_1"
+ value: "Add_1:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt
new file mode 100644
index 0000000000..d6d856ce41
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch.pbtxt
@@ -0,0 +1,251 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_MPaeanipb7o"
+ op: "MyAdd_MPaeanipb7o"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_MPaeanipb7o"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt
new file mode 100644
index 0000000000..e57d9d7076
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_2.pbtxt
@@ -0,0 +1,251 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_MPaeanipb7o"
+ op: "MyAdd_MPaeanipb7o"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_MPaeanipb7o"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt
new file mode 100644
index 0000000000..e9afa91886
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_switch_shapes.pbtxt
@@ -0,0 +1,317 @@
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "Less/x"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "Less/y"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "Less"
+ op: "Less"
+ input: "Less/x"
+ input: "Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "case/cond/Switch"
+ op: "Switch"
+ input: "Less"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_t"
+ op: "Identity"
+ input: "case/cond/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/switch_f"
+ op: "Identity"
+ input: "case/cond/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/pred_id"
+ op: "Identity"
+ input: "Less"
+ attr {
+ key: "T"
+ value {
+ type: DT_BOOL
+ }
+ }
+}
+node {
+ name: "case/cond/Const"
+ op: "Const"
+ input: "^case/cond/switch_t"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Const_1"
+ op: "Const"
+ input: "^case/cond/switch_f"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+}
+node {
+ name: "case/cond/Merge"
+ op: "Merge"
+ input: "case/cond/Const_1"
+ input: "case/cond/Const"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "MyAdd_lEKAAnIwI5I"
+ op: "MyAdd_lEKAAnIwI5I"
+ input: "case/cond/Merge"
+ input: "Const"
+}
+library {
+ function {
+ signature {
+ name: "MyAdd_lEKAAnIwI5I"
+ input_arg {
+ name: "x"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "y"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "Add"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Const_1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 2.0
+ }
+ }
+ }
+ }
+ node_def {
+ name: "Add"
+ op: "Add"
+ input: "x"
+ input: "Const:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "Add_1"
+ op: "Add"
+ input: "y"
+ input: "Const_1:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ ret {
+ key: "Add"
+ value: "Add:z:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt
new file mode 100644
index 0000000000..415c347a1d
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/large_function_graph.pbtxt
@@ -0,0 +1,597 @@
+node {
+ name: "Const/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 64
+ }
+ }
+ }
+}
+node {
+ name: "input_0_0"
+ op: "RandomUniform"
+ input: "Const/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_1/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\001\000\000\000\001\000\000\000\030\000\000\000@\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_1_0"
+ op: "RandomUniform"
+ input: "Const_1/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_2/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\200\000\000\000\340\000\000\000\340\000\000\000\003\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_2_0"
+ op: "RandomUniform"
+ input: "Const_2/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "Const_3/Const"
+ op: "Const"
+ device: "/cpu:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 4
+ }
+ }
+ tensor_content: "\007\000\000\000\007\000\000\000\003\000\000\000\010\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "input_3_0"
+ op: "RandomUniform"
+ input: "Const_3/Const"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "y0"
+ op: "BiasAddx1_Conv2Dx1_DepthwiseConv2dNativex1_Relux1_95"
+ input: "input_0_0"
+ input: "input_1_0"
+ input: "input_2_0"
+ input: "input_3_0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+}
+node {
+ name: "shape"
+ op: "Shape"
+ input: "y0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "zeros"
+ op: "ZerosLike"
+ input: "shape"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "ones"
+ op: "OnesLike"
+ input: "shape"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "slice_0"
+ op: "Slice"
+ input: "y0"
+ input: "zeros"
+ input: "ones"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "identity_0"
+ op: "Identity"
+ input: "slice_0"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "shape_1"
+ op: "Shape"
+ input: "y0:1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "out_type"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "zeros_1"
+ op: "ZerosLike"
+ input: "shape_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "ones_1"
+ op: "OnesLike"
+ input: "shape_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "slice_1"
+ op: "Slice"
+ input: "y0:1"
+ input: "zeros_1"
+ input: "ones_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "Index"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "identity_1"
+ op: "Identity"
+ input: "slice_1"
+ input: "^input_0_0"
+ input: "^input_1_0"
+ input: "^input_2_0"
+ input: "^input_3_0"
+ device: "/cpu:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+library {
+ function {
+ signature {
+ name: "BiasAddx1_Conv2Dx1_DepthwiseConv2dNativex1_Relux1_95"
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/biases/read"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/pointwise_weights/read"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "random_uniform"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "InceptionV2/Conv2d_1a_7x7/depthwise_weights/read"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ type: DT_FLOAT
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/BiasAdd"
+ op: "BiasAdd"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d:output:0"
+ input: "InceptionV2/Conv2d_1a_7x7/biases/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ op: "Relu"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/BiasAdd:output:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d"
+ op: "Conv2D"
+ input: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise:output:0"
+ input: "InceptionV2/Conv2d_1a_7x7/pointwise_weights/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "VALID"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "use_cudnn_on_gpu"
+ value {
+ b: true
+ }
+ }
+ }
+ node_def {
+ name: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ op: "DepthwiseConv2dNative"
+ input: "random_uniform"
+ input: "InceptionV2/Conv2d_1a_7x7/depthwise_weights/read"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ key: "dilations"
+ value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+ attr {
+ key: "padding"
+ value {
+ s: "SAME"
+ }
+ }
+ attr {
+ key: "strides"
+ value {
+ list {
+ i: 1
+ i: 2
+ i: 2
+ i: 1
+ }
+ }
+ }
+ }
+ ret {
+ key: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu"
+ value: "InceptionV2/InceptionV2/Conv2d_1a_7x7/Relu:activations:0"
+ }
+ ret {
+ key: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise"
+ value: "InceptionV2/InceptionV2/Conv2d_1a_7x7/separable_conv2d/depthwise:output:0"
+ }
+ attr {
+ key: "_noinline"
+ value {
+ b: true
+ }
+ }
+ }
+}
+versions {
+ producer: 26
+ min_consumer: 12
+}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b8e337582c..d34eecd009 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -45,6 +45,7 @@ constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
+constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
@@ -77,6 +78,14 @@ string GetDataFormat(const OpInfo& op_features) {
return data_format;
}
+string GetFilterFormat(const OpInfo& op_features) {
+ string filter_format = "HWIO"; // Default format.
+ if (op_features.attr().find("filter_format") != op_features.attr().end()) {
+ filter_format = op_features.attr().at("filter_format").s();
+ }
+ return filter_format;
+}
+
Padding GetPadding(const OpInfo& op_features) {
if (op_features.attr().find("padding") != op_features.attr().end() &&
op_features.attr().at("padding").s() == "VALID") {
@@ -232,6 +241,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kSqueeze, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -511,29 +521,44 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
y_index = 3;
channel_index = 1;
} else {
+ // Use NHWC.
x_index = 1;
y_index = 2;
channel_index = 3;
}
+ const string& filter_format = GetFilterFormat(op_features);
+ int filter_x_index, filter_y_index, in_channel_index, out_channel_index;
+ if (filter_format == "HWIO") {
+ filter_x_index = 0;
+ filter_y_index = 1;
+ in_channel_index = 2;
+ out_channel_index = 3;
+ } else {
+ // Use OIHW
+ filter_x_index = 2;
+ filter_y_index = 3;
+ in_channel_index = 1;
+ out_channel_index = 0;
+ }
int64 batch = image_shape.dim(0).size();
int64 ix = image_shape.dim(x_index).size();
int64 iy = image_shape.dim(y_index).size();
int64 iz = image_shape.dim(channel_index).size();
- int64 kx = filter_shape.dim(0).size();
- int64 ky = filter_shape.dim(1).size();
+ int64 kx = filter_shape.dim(filter_x_index).size();
+ int64 ky = filter_shape.dim(filter_y_index).size();
std::vector<int64> strides = GetStrides(op_features);
const auto padding = GetPadding(op_features);
int64 sx = strides[x_index];
int64 sy = strides[y_index];
int64 ox = GetOutputSize(ix, kx, sx, padding);
int64 oy = GetOutputSize(iy, ky, sy, padding);
- int64 oz = filter_shape.dim(3).size();
+ int64 oz = filter_shape.dim(out_channel_index).size();
// Only check equality when both sizes are known (in other words, when
// neither is set to a minimum dimension size of 1).
- if (iz != 1 && filter_shape.dim(2).size() != 1) {
- CHECK_EQ(iz, filter_shape.dim(2).size());
+ if (iz != 1 && filter_shape.dim(in_channel_index).size() != 1) {
+ CHECK_EQ(iz, filter_shape.dim(in_channel_index).size());
} else {
- iz = std::max<int64>(iz, filter_shape.dim(2).size());
+ iz = std::max<int64>(iz, filter_shape.dim(in_channel_index).size());
}
OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
@@ -1052,6 +1077,24 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
//
// For more information, see
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+
+ // TODO(yaozhang): Support other data formats (NCHW_VECT_C, NHWC_VECT_W) and
+ // filter formats (OIHW_VECT_I).
+ string data_format = GetDataFormat(op_context.op_info);
+ if (data_format != "NCHW" && data_format != "NHWC") {
+ LOG(WARNING) << "unsupported data format: " << data_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+ string filter_format = GetFilterFormat(op_context.op_info);
+ if (filter_format != "HWIO" && filter_format != "OIHW") {
+ LOG(WARNING) << "unsupported filter format: " << filter_format;
+ Costs cost = Costs::ZeroCosts();
+ cost.inaccurate = true;
+ return cost;
+ }
+
auto& conv_input = op_context.op_info.inputs(0);
auto& filter = op_context.op_info.inputs(1);
auto& bias = op_context.op_info.inputs(2);
@@ -1067,28 +1110,12 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct the shape of our output tensor from our convolution dimensions
// and format, as it may not be available yet.
- //
// TODO(varomodt): should we centralize the Conv2D input/output shapes?
- bool unknown_conv_format = false;
OpInfo::TensorProperties output;
- switch (GetConvolutionFormat(op_context)) {
- case NCHW:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
- break;
- case NHWC:
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- break;
- default:
- // TODO(b/77722245): support cost estimation for NCHW_VECT_C.
- LOG(WARNING) << "unsupported data format: "
- << GetDataFormat(op_context.op_info)
- << " Defaulting to NHWC.";
- output =
- DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
- unknown_conv_format = true;
- break;
+ if (data_format == "NCHW") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.ox, dims.oy});
+ } else if (data_format == "NHWC") {
+ output = DescribeTensor(DT_FLOAT, {dims.batch, dims.ox, dims.oy, dims.oz});
}
// Add the operations the fused op always computes.
@@ -1113,7 +1140,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
// Construct component operations and run the cost computation.
auto costs = PredictFusedOp(op_context_with_output, component_ops);
- costs.inaccurate |= found_unknown_shapes || unknown_conv_format;
+ costs.inaccurate |= found_unknown_shapes;
return costs;
}
@@ -1566,20 +1593,6 @@ Costs OpLevelCostEstimator::PredictFusedBatchNormGrad(
}
/* static */
-OpLevelCostEstimator::ConvolutionFormat
-OpLevelCostEstimator::GetConvolutionFormat(const OpContext& op_context) {
- auto data_format = GetDataFormat(op_context.op_info);
- if (data_format == "NCHW") {
- return NCHW;
- } else if (data_format == "NHWC") {
- return NHWC;
- } else if (data_format == "NCHW_VECT_C") {
- return NCHW_VECT_C;
- }
-
- return UNKNOWN_CONVOLUTION_FORMAT;
-}
-
void OpLevelCostEstimator::CombineCostsAndUpdateExecutionTime(
Costs* costs) const {
if (compute_memory_overlap_) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index d384f57279..a277dfdf65 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -84,13 +84,6 @@ class OpLevelCostEstimator {
int64 sy; // Stride y.
Padding padding; // SAME or VALID.
};
- enum ConvolutionFormat {
- UNKNOWN_CONVOLUTION_FORMAT,
- NHWC,
- NCHW,
- NCHW_VECT_C,
- NCHW_VECT_W,
- };
int64 CountConv2DOperations(const OpInfo& op_features,
bool* found_unknown_shapes) const;
int64 CountConv2DOperations(const OpInfo& op_features,
@@ -198,9 +191,6 @@ class OpLevelCostEstimator {
static OpInfo::TensorProperties DescribeTensor(
DataType type, const std::vector<int64>& dims);
- // Returns the Conv2D format for this operation.
- static ConvolutionFormat GetConvolutionFormat(const OpContext& op_context);
-
// This method calculates the execution time depending on whether IO can
// overlap with computation. It assumes the memory and the compute times have
// already been calculated.
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index b2c021b73a..77352f6652 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -155,19 +155,38 @@ OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
// Note that this assumes the NHWC data format.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
int iz2, int kx, int ky, int ox,
- int oy, int oz,
- bool has_side_input) {
+ int oy, int oz, bool has_side_input,
+ const string& data_format,
+ const string& filter_format) {
OpContext op_context;
SetCpuDevice(&op_context.op_info);
op_context.op_info.set_op("FusedConv2DBiasActivation");
- DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
- DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ auto* attr_data_format = op_context.op_info.mutable_attr();
+ SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
+ auto* attr_filter_format = op_context.op_info.mutable_attr();
+ SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
+ } else {
+ // Use the NCHW format.
+ DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
+ }
+ if (filter_format == "HWIO") {
+ DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
+ } else {
+ // Use the OIHW format.
+ DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
+ }
DescribeTensor1D(oz, op_context.op_info.add_inputs());
// Add the side_input, if any.
auto side_input = op_context.op_info.add_inputs();
if (has_side_input) {
- DescribeTensor4D(batch, ox, oy, oz, side_input);
+ if (data_format == "NHWC") {
+ DescribeTensor4D(batch, ox, oy, oz, side_input);
+ } else {
+ DescribeTensor4D(batch, oz, ox, oy, side_input);
+ }
}
// Add the scaling tensors.
@@ -549,25 +568,79 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
SetComputeMemoryOverlap(false); // Set it back to default.
}
-TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest,
+ FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true));
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
+ "NCHW", "HWIO"));
+ EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "HWIO"));
EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
-TEST_F(OpLevelCostEstimatorTest,
- FusedConv2DBiasActivationNoSideInputExecutionTime) {
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
- 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false));
- EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
- EXPECT_EQ(Costs::Duration(355321038), cost.compute_time);
- EXPECT_EQ(Costs::Duration(356146383), cost.execution_time);
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
EXPECT_FALSE(cost.inaccurate);
}
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "HWIO"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NHWC", "OIHW"));
+ EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(355616770), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(357033578), cost.execution_time);
+ EXPECT_FALSE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once NCHW_VECT_C is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW_VECT_C", "OIHW"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
+// TODO(yaozhang): Update once OIHW_VECT_I is supported.
+TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
+ auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
+ 16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
+ "NCHW", "OIHW_VECT_I"));
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(0), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(0), cost.execution_time);
+ EXPECT_TRUE(cost.inaccurate);
+}
+
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
@@ -655,8 +728,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
TensorProto tensor_proto;
TensorShapeProto tensor_shape_proto;
- // Dimension larger than max value; should fail while converting to Tensor
- // class.
+ // Dimension larger than max value; should fail while converting to
+ // Tensor class.
tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
EXPECT_FALSE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
@@ -676,8 +749,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
// Check GetTensorShapeProtoFromTensorProto() resturns correct values.
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -685,8 +758,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/false, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -694,8 +767,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
- GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT32, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
@@ -703,8 +776,8 @@ TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
{
std::vector<int64> shape_expected = {40, 20, 90, 40};
- GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true,
- &tensor_proto);
+ GetTensorProto(DT_INT64, {4}, shape_expected,
+ /*tensor_content=*/true, &tensor_proto);
EXPECT_TRUE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
ExpectTensorShape(shape_expected, tensor_shape_proto);
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2a47a4c495..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
+bool IsElementWiseMonotonic(const NodeDef& node) {
+ static const std::unordered_set<string>* element_wise_monotonic_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Relu",
+ "Relu6",
+ "Sigmoid",
+ "Sqrt",
+ "Tanh",
+ }));
+ return element_wise_monotonic_ops->count(node.op()) > 0;
+}
+
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
bool IsEnter(const NodeDef& node) {
@@ -193,6 +205,8 @@ bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
+bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
+
bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
@@ -615,7 +629,8 @@ bool HasOpDef(const NodeDef& node) {
}
bool IsIdempotent(const NodeDef& node) {
- return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node);
+ return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
+ !ModifiesFrameInfo(node);
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index e7f39981c0..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
@@ -74,6 +75,7 @@ bool IsImag(const NodeDef& node);
bool IsInvGrad(const NodeDef& node);
bool IsLess(const NodeDef& node);
bool IsLessEqual(const NodeDef& node);
+bool IsLog(const NodeDef& node);
bool IsLogicalAnd(const NodeDef& node);
bool IsLogicalNot(const NodeDef& node);
bool IsLogicalOr(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index c90667abad..b1d6d48e31 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -171,6 +171,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:grappler_test",
],
@@ -210,8 +211,7 @@ cc_library(
hdrs = ["graph_optimizer_stage.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
@@ -225,6 +225,7 @@ tf_cuda_cc_test(
deps = [
":graph_optimizer_stage",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
@@ -328,11 +329,13 @@ tf_cuda_cc_test(
":model_pruner",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -677,6 +680,7 @@ cc_library(
deps = [
":constant_folding",
":graph_optimizer",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
@@ -778,7 +782,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:scoped_allocator_ops_op_lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index ca3f84a81d..97862d1ed0 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -101,38 +101,6 @@ bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
return false;
}
-template <typename T>
-bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
- const T n = perm.size();
- if (n < 2) {
- return false;
- }
- for (T i = 0; i < n - 2; ++i) {
- if (perm[i] != i) {
- return false;
- }
- }
- return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
-}
-
-bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
- const NodeMap* node_map) {
- if (transpose_node.op() != "Transpose" &&
- transpose_node.op() != "ConjugateTranspose") {
- return false;
- }
- const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
- std::vector<int> perm32;
- if (ValuesFromConstNode(*perm_node, &perm32)) {
- return IsInnerMatrixTranspose(perm32);
- }
- std::vector<int64> perm64;
- if (ValuesFromConstNode(*perm_node, &perm64)) {
- return IsInnerMatrixTranspose(perm64);
- }
- return false;
-}
-
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
GraphDef* graph, NodeMap* node_map) {
bool already_exists = false;
@@ -155,12 +123,6 @@ void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
(*node->mutable_attr())[attr_name].set_type(dtype);
}
-void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
- const bool old_value =
- !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
- (*node->mutable_attr())[attr_name].set_b(!old_value);
-}
-
string SourceDataTypeAttrName(const NodeDef& node) {
if (node.op() == "Bitcast") {
return "T";
@@ -265,6 +227,27 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
ctx().nodes_to_preserve->end();
}
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
+ // TODO(ezhulenev): move to GraphOptimizerStage?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@@ -395,27 +378,6 @@ class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
is_broadcastable);
}
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool IsDrivenByControlDependency(const NodeDef& node) const {
- return std::any_of(node.input().begin(), node.input().end(),
- IsControlInput);
- }
-
- // TODO(ezhulenev): move to GraphOptimizerStage?
- bool DrivesControlDependency(const NodeDef& node) const {
- int position;
- for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
- for (int i = 0; i < output->input_size(); ++i) {
- auto input = output->input(i);
- string name = ParseNodeName(input, &position);
- if (name == node.name() && /*control input*/ position < 0) {
- return true;
- }
- }
- }
- return false;
- }
-
string ShapeSignature(const TensorShapeProto& shape) const {
string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
for (int i = 0; i < shape.dim_size(); ++i)
@@ -1122,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
+ // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
+ if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -1757,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
@@ -1958,6 +1919,901 @@ class ReorderCastAndTranspose : public ArithmeticOptimizerStage {
bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
};
+// Fold a multiply of a scalar into the following convolution. This folding
+// can jump across nodes that merely reorders data (such as reshape and
+// transpose). For example, we can optimize
+//
+//
+// Conv2D Conv2D
+// / \ / \
+// Transpose weights* -> Transpose Mul
+// | | / \
+// Mul | weights scale
+// / \ |
+// input scale** input
+//
+// *) weights must be a const
+// **) scale must be a const scalar
+//
+// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
+// constant-folded, also weights tend to be smaller than the activations.
+//
+// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
+// Conv?DBackpropInput.
+class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
+ ~FoldMultiplyIntoConv() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsConv2D(*node) || IsConv3D(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+#define TF_RETURN_IF_TRUE(...) \
+ if ((__VA_ARGS__)) return Status::OK()
+
+ NodeDef* conv = node;
+
+ NodeDef* weights;
+ TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
+
+ // Fold the multiply to conv only when the weights are constant, so the
+ // multiply can be constant-folded.
+ //
+ // TODO(jingyue): When the weights aren't constant, this should also help
+ // performance a bit and memory usage a lot, since the weights tend to be
+ // smaller than the activations.
+ TF_RETURN_IF_TRUE(!IsConstant(*weights));
+
+ // Verify that this node was not already optimized.
+ const string scaled_weights_node_name =
+ OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
+ strings::StrCat("scaled", "_", conv->name()));
+
+ TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
+
+ // Find the tail of value preserving chain entering the Conv node.
+ NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+
+ NodeDef* source;
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
+
+ // Check that value preserving chain is the only consumer of the Mul output.
+ TF_RETURN_IF_TRUE(!IsMul(*source));
+ TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
+
+ const NodeDef* mul = source;
+
+ // TODO(jingyue): handle the case where `scale` is 0-th operand.
+ NodeDef* scale; // scalar multiplier fot the input tensor
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale));
+ TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input));
+
+ // Check that 'scale * weight' can be const folded.
+ TF_RETURN_IF_TRUE(!IsConstant(*scale));
+ TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
+ weights->attr().at("dtype").type());
+
+ // Check that `scale` is a scalar.
+ const TensorProto& scale_tensor = scale->attr().at("value").tensor();
+ bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
+ scale_tensor.tensor_shape().dim_size() == 0;
+ TF_RETURN_IF_TRUE(!scale_is_a_scalar);
+
+ // At this point all preconditions are met, and we safely do the rewrite.
+ VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
+ << " mul=" << mul->name() << " weights=" << weights->name();
+
+ // Create new node `scaled_weights`.
+ NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
+ scaled_weights->set_op("Mul");
+ scaled_weights->set_device(weights->device());
+ (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
+ AddToOptimizationQueue(scaled_weights);
+
+ // Link in its inputs.
+ scaled_weights->add_input(conv->input(1));
+ ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
+ scaled_weights->add_input(mul->input(1));
+ ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
+ ForwardControlDependencies(scaled_weights, {source});
+
+ // Update `conv`'s weights to `scaled_weights`.
+ conv->set_input(1, scaled_weights->name());
+ ctx().node_map->UpdateInput(conv->name(), weights->name(),
+ scaled_weights->name());
+ AddToOptimizationQueue(conv);
+
+ // Update `tail` node to bypass `mul` because it's folded to the weights.
+ tail->set_input(0, mul->input(0));
+ ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
+ AddToOptimizationQueue(tail);
+ *simplified_node_name = conv->name();
+
+ return Status::OK();
+#undef TF_RETURN_IF_TRUE
+ }
+};
+
+// Fold Transpose into matrix multiplication.
+class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
+ ~FoldTransposeIntoMatMul() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMatMul(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(matmul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ NodeDef* a;
+ NodeDef* b;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
+
+ bool is_complex = false;
+ if (node->op() != "SparseMatMul") {
+ const DataType type = GetDataTypeFromAttr(*node, "T");
+ is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+ }
+
+ const std::set<string> foldable_transpose_ops =
+ !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
+ : (node->op() == "BatchMatMul"
+ ? std::set<string>{"ConjugateTranspose"}
+ : std::set<string>{"Transpose"});
+
+ const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
+ IsInnerMatrixTransposeNode(*a, ctx().node_map);
+ const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
+ IsInnerMatrixTransposeNode(*b, ctx().node_map);
+ if (!a_is_foldable && !b_is_foldable) return Status::OK();
+
+ NodeDef* new_op = AddCopyNode(optimized_node_name, node);
+
+ if (a_is_foldable) {
+ const string attr_a =
+ node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
+ FlipBooleanAttr(attr_a, new_op);
+ new_op->set_input(0, a->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
+ }
+
+ if (b_is_foldable) {
+ const string attr_b =
+ node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
+ FlipBooleanAttr(attr_b, new_op);
+ new_op->set_input(1, b->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
+ }
+
+ std::vector<const NodeDef*> deps_to_forward = {node};
+ if (a_is_foldable) deps_to_forward.push_back(a);
+ if (b_is_foldable) deps_to_forward.push_back(b);
+ ForwardControlDependencies(new_op, deps_to_forward);
+
+ return Status::OK();
+ }
+
+ private:
+ void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
+ const bool old_value =
+ !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
+ (*node->mutable_attr())[attr_name].set_b(!old_value);
+ }
+
+ template <typename T>
+ bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
+ const T n = perm.size();
+ if (n < 2) {
+ return false;
+ }
+ for (T i = 0; i < n - 2; ++i) {
+ if (perm[i] != i) {
+ return false;
+ }
+ }
+ return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
+ }
+
+ bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
+ const NodeMap* node_map) {
+ if (transpose_node.op() != "Transpose" &&
+ transpose_node.op() != "ConjugateTranspose") {
+ return false;
+ }
+ const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
+ std::vector<int> perm32;
+ if (ValuesFromConstNode(*perm_node, &perm32)) {
+ return IsInnerMatrixTranspose(perm32);
+ }
+ std::vector<int64> perm64;
+ if (ValuesFromConstNode(*perm_node, &perm64)) {
+ return IsInnerMatrixTranspose(perm64);
+ }
+ return false;
+ }
+};
+
+// Fold Transpose into matrix multiplication.
+class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
+ public:
+ explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
+ ~FoldConjugateIntoTranspose() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(matmul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+
+ const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
+ const NodeDef* conj_op = node->op() == "Conj" ? node : input;
+
+ if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
+ IsConj(*conj_op)) {
+ NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
+
+ // Flip the type of transpose op to absorb the conjugation.
+ new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
+ : "Transpose");
+ new_op->set_input(0, input->input(0));
+ ctx().node_map->UpdateInput(new_op->name(), node->name(),
+ input->input(0));
+ ForwardControlDependencies(new_op, {node, input});
+ *simplified_node_name = new_op->name();
+ }
+
+ return Status::OK();
+ }
+};
+
+// Replace Mul node with identical inputs with a Square.
+class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
+ public:
+ explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
+ ~ReplaceMulWithSquare() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMul(*node) && node->input(0) == node->input(1);
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
+ const string optimized_node_name = OptimizedNodeName(mul);
+ if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
+
+ const DataType type = GetDataTypeFromAttr(*node, "T");
+ bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
+
+ string task;
+ string device;
+ bool is_on_cpu =
+ DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
+ str_util::StrContains(device, DEVICE_CPU);
+
+ if (!is_complex || is_on_cpu) {
+ NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
+ new_square_node->set_op("Square");
+ for (int i = 1; i < new_square_node->input_size(); ++i) {
+ new_square_node->set_input(i - 1, new_square_node->input(i));
+ }
+ new_square_node->mutable_input()->RemoveLast();
+ for (const string& input : new_square_node->input()) {
+ ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
+ }
+ *simplified_node_name = new_square_node->name();
+ }
+
+ return Status::OK();
+ }
+};
+
+// Simplify aggregation (e.g. AddN) nodes:
+//
+// 1. Discard aggregate nodes with a single input and no control dependencies.
+//
+// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
+// deduping or other rewrites) so we can get rid of the sum entirely.
+//
+// The expression (using AddN as an example of an aggregate op):
+// AddN(x, x, x, ... ,x)
+// <-- N terms -->
+// can be rewritten to:
+// Mul(Const(N), x))
+//
+class SimplifyAggregation : public ArithmeticOptimizerStage {
+ public:
+ explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
+ ~SimplifyAggregation() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsAggregate(*node) && NumNonControlInputs(*node) > 0;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ // 1. Discard aggregate nodes with a single input and no control deps.
+ if (node->input_size() == 1) {
+ *simplified_node_name = node->input(0);
+ return Status::OK();
+ }
+
+ // 2. Rewrite aggregations of N >= 2 identical terms.
+
+ // All non-control inputs must be identical.
+ bool all_equal = true;
+ int num_inputs = 1;
+ for (int i = 1; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) break;
+ ++num_inputs;
+ if (node->input(i) != node->input(0)) {
+ all_equal = false;
+ break;
+ }
+ }
+ if (!all_equal) return Status::OK();
+
+ // And node should not be optimized earlier.
+ const NodeScopeAndName node_scope_and_name =
+ ParseNodeScopeAndName(node->name());
+ const string optimized_const_name =
+ OptimizedNodeName(node_scope_and_name, "Const");
+ const string optimized_mul_name =
+ OptimizedNodeName(node_scope_and_name, "Mul");
+
+ bool is_already_optimized =
+ ctx().node_map->NodeExists(optimized_const_name) ||
+ ctx().node_map->NodeExists(optimized_mul_name);
+
+ if (is_already_optimized) return Status::OK();
+
+ // At this point all preconditions are met, and we safely do the rewrite.
+ VLOG(3) << "Simplify aggregation with identical inputs: node="
+ << node->name() << " num_inputs=" << num_inputs;
+
+ // 1. Create constant node with value N.
+ const auto type = GetDataTypeFromAttr(*node, "T");
+ Tensor t(type, TensorShape({}));
+ Status status = SetTensorValue(type, num_inputs, &t);
+ if (!status.ok()) {
+ return errors::Internal("Failed to create const node: ",
+ status.error_message());
+ }
+
+ TensorValue value(&t);
+ NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
+ status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
+ new_const_node);
+ if (!status.ok()) {
+ return errors::Internal("Failed to create const node: ",
+ status.error_message());
+ }
+ new_const_node->set_device(node->device());
+ MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+ ctx().optimized_graph, ctx().node_map);
+ AddToOptimizationQueue(new_const_node);
+
+ // 2. Replace the aggregate node with Mul(Const(N), x).
+ NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
+ new_mul_node->set_op("Mul");
+ new_mul_node->set_device(node->device());
+ SetDataTypeToAttr(type, "T", new_mul_node);
+ new_mul_node->add_input(new_const_node->name());
+ ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
+ new_mul_node->add_input(node->input(0));
+ ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
+
+ ForwardControlDependencies(new_mul_node, {node});
+ *simplified_node_name = new_mul_node->name();
+
+ return Status::OK();
+ }
+};
+
+class ConvertPowStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertPowStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsPow(*node) &&
+ ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
+ }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < p.shape().dim_size(); ++i) {
+ if (p.shape().dim(i).size() < 0) {
+ // skip if p is is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(p.shape()) && p.has_value()) {
+ Tensor pow(p.dtype(), p.shape());
+ if (!pow.FromProto(p.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ p.value().DebugString());
+ }
+
+ complex128 prev, curr;
+ for (int i = 0; i < pow.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(GetElement(pow, i, &curr));
+ if (i != 0 && curr != prev) {
+ // pow has different values on different elements. Skip.
+ return Status::OK();
+ }
+ prev = curr;
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ if (curr == complex128(2, 0)) {
+ node->set_op("Square");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(1, 0)) {
+ node->set_op("Identity");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0.5, 0)) {
+ node->set_op("Sqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(0, 0)) {
+ const auto& b =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ for (int i = 0; i < b.shape().dim_size(); ++i) {
+ if (b.shape().dim(i).size() < 0) {
+ // skip if b is is not fully defined.
+ return Status::OK();
+ }
+ }
+ if (TensorShape::IsValid(b.shape()) && b.has_value()) {
+ Tensor base(b.dtype(), b.shape());
+ if (!base.FromProto(b.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ b.value().DebugString());
+ }
+ node->set_op("Const");
+ Tensor c(base.dtype(), base.shape());
+ for (int i = 0; i < c.NumElements(); ++i) {
+ TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
+ }
+ (*node->mutable_attr())["dtype"].set_type(base.dtype());
+ c.AsProtoTensorContent(
+ (*node->mutable_attr())["value"].mutable_tensor());
+ node->mutable_attr()->erase("T");
+ node->set_input(0, AsControlDependency(x->name()));
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ }
+ } else if (curr == complex128(-0.5, 0)) {
+ node->set_op("Rsqrt");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ } else if (curr == complex128(-1, 0)) {
+ node->set_op("Reciprocal");
+ node->set_input(1, AsControlDependency(y->name()));
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(y);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_INT32:
+ *element = complex128(t.flat<int32>()(i));
+ return Status::OK();
+ case DT_INT64:
+ *element = complex128(t.flat<int64>()(i));
+ return Status::OK();
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return Status::OK();
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return Status::OK();
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return Status::OK();
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t.dtype());
+ }
+ }
+
+ Status SetElementToOne(int i, Tensor* t) {
+ switch (t->dtype()) {
+ case DT_INT32:
+ t->flat<int32>()(i) = 1;
+ return Status::OK();
+ case DT_INT64:
+ t->flat<int64>()(i) = 1L;
+ return Status::OK();
+ case DT_FLOAT:
+ t->flat<float>()(i) = 1.0f;
+ return Status::OK();
+ case DT_DOUBLE:
+ t->flat<double>()(i) = 1.0;
+ return Status::OK();
+ case DT_COMPLEX64:
+ t->flat<complex64>()(i) = complex64(1);
+ return Status::OK();
+ case DT_COMPLEX128:
+ t->flat<complex128>()(i) = complex128(1);
+ return Status::OK();
+ default:
+ return errors::InvalidArgument("Invalid data type: ", t->dtype());
+ }
+ }
+};
+
+class ConvertLog1pStage : public ArithmeticOptimizerStage {
+ public:
+ explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
+ ~ConvertLog1pStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
+
+ Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
+ NodeDef* input;
+ TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
+ if (!IsAdd(*input)) {
+ return Status::OK();
+ }
+
+ if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
+ return Status::OK();
+ }
+
+ bool modified = false;
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
+ if (!modified) {
+ TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
+ }
+ if (modified) {
+ *simplified_node_name = node->name();
+ }
+ return Status::OK();
+ }
+
+ private:
+ Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
+ bool* modified) {
+ const auto& t =
+ ctx().graph_properties->GetInputProperties(input->name())[i];
+ const auto& c =
+ ctx().graph_properties->GetInputProperties(input->name())[j];
+ for (int k = 0; k < c.shape().dim_size(); ++k) {
+ // Skip if c shape is not fully determined.
+ if (c.shape().dim(k).size() < 0) {
+ return Status::OK();
+ }
+ }
+ TensorShapeProto broadcast_shape;
+ if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
+ return Status::OK();
+ }
+ if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
+ // skip if the non-constant tensor doesn't have the same shape after
+ // broadcast.
+ return Status::OK();
+ }
+ if (TensorShape::IsValid(c.shape()) && c.has_value()) {
+ Tensor constant(c.dtype(), c.shape());
+ if (!constant.FromProto(c.value())) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ c.value().DebugString());
+ }
+ complex128 element;
+ for (int k = 0; k < constant.NumElements(); ++k) {
+ if (!GetElement(constant, k, &element)) {
+ // input data type is not supported by log1p. Skip.
+ return Status::OK();
+ }
+ if (element != complex128(1)) {
+ // current element is not 1. Skip.
+ return Status::OK();
+ }
+ }
+ NodeDef *x, *y;
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
+ TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
+ node->set_op("Log1p");
+ node->set_input(0, input->input(i));
+ node->add_input(AsControlDependency(y->name()));
+ ForwardControlDependencies(node, {input});
+
+ AddToOptimizationQueue(node);
+ AddToOptimizationQueue(input);
+ AddToOptimizationQueue(x);
+ AddToOptimizationQueue(y);
+ *modified = true;
+ }
+ return Status::OK();
+ }
+
+ bool GetElement(const Tensor& t, int i, complex128* element) {
+ switch (t.dtype()) {
+ case DT_BFLOAT16:
+ *element = complex128(t.flat<bfloat16>()(i));
+ return true;
+ case DT_HALF:
+ *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
+ return true;
+ case DT_FLOAT:
+ *element = complex128(t.flat<float>()(i));
+ return true;
+ case DT_DOUBLE:
+ *element = complex128(t.flat<double>()(i));
+ return true;
+ case DT_COMPLEX64:
+ *element = complex128(t.flat<complex64>()(i));
+ return true;
+ case DT_COMPLEX128:
+ *element = t.flat<complex128>()(i);
+ return true;
+ default:
+ return false;
+ }
+ }
+};
+
+// Performs conversions like:
+// Max(Sqrt(x)) => Sqrt(Max(x))
+// Checks for a max/min reduction over element-wise monotonic functions, such
+// as Sqrt, Sigmoid, Tanh, etc.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+ explicit OptimizeMaxOrMinOfMonotonicStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+ ctx_ext) {}
+ ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMax(*node) || IsMin(*node);
+ }
+
+ Status TrySimplify(NodeDef* reduction_node,
+ string* simplified_node_name) override {
+ NodeDef* inner_function;
+ TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+ // Optimize only if:
+ // 1. inner_function's Op is element-wise monotonic
+ // 2. inner_function's output is not being consumed elsewhere.
+ if (IsElementWiseMonotonic(*inner_function) &&
+ (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ // Swap the first inputs of the inner function Op & the reduction Op.
+ NodeDef* inner_input;
+ TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
+ reduction_node->set_input(0, inner_input->name());
+ UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ }
+ return Status::OK();
+ }
+
+ void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name && consumer->name() != new_input) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+};
+
+// Replace a chain of type&shape preserving unary ops with a
+// '_UnaryOpsComposition' node.
+// TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
+// have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
+class UnaryOpsComposition : public ArithmeticOptimizerStage {
+ public:
+ explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
+ // WARN: This should be consistent with unary_ops_composition.cc.
+ // clang-format off
+ supported_ops_ = {// Ops defined via Eigen scalar ops.
+ {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Acos", {DT_FLOAT, DT_DOUBLE}},
+ {"Acosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Asin", {DT_FLOAT, DT_DOUBLE}},
+ {"Asinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Atan", {DT_FLOAT, DT_DOUBLE}},
+ {"Atanh", {DT_FLOAT, DT_DOUBLE}},
+ {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Cosh", {DT_FLOAT, DT_DOUBLE}},
+ {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rint", {DT_FLOAT, DT_DOUBLE}},
+ {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Sinh", {DT_FLOAT, DT_DOUBLE}},
+ {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Tan", {DT_FLOAT, DT_DOUBLE}},
+ {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ // Additional ops that are not part of the Eigen.
+ {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
+ {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
+ // clang-format on
+ }
+ ~UnaryOpsComposition() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return CanOptimize(*node) &&
+ // Check that this node was not already a root of a fused chain. If
+ // graph optimization runs twice without pruning in between,
+ // fused_nodes_ will not have this information.
+ !ctx().node_map->NodeExists(OptimizedNodeName(*node));
+ }
+
+ Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
+ DataType dtype = root->attr().at("T").type();
+
+ // Keep a trace of all supported input nodes that can be fused together.
+ std::vector<string> op_nodes = {root->name()};
+ std::vector<string> op_names = {root->op()};
+
+ // Check if we should follow input(0) while building an op composition.
+ const auto predicate_fn = [&](const NodeDef& input) {
+ if (input.name() == root->name()) return true;
+
+ bool follow_input_node =
+ dtype == GetDataTypeFromAttr(input, "T") &&
+ NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
+ CanOptimize(input);
+
+ if (follow_input_node) {
+ op_nodes.push_back(input.name());
+ op_names.push_back(input.op());
+ }
+
+ return follow_input_node;
+ };
+
+ NodeDef* last_op = GetTailOfChain(
+ *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
+
+ // We were not able to find a chain that can be replaced.
+ if (op_names.size() == 1) return Status::OK();
+
+ // Do not add fused nodes to any other chain.
+ std::for_each(op_nodes.begin(), op_nodes.end(),
+ [this](const string& name) { AddToFusedNodes(name); });
+
+ // Reverse the trace to get correct composition computation order.
+ std::reverse(op_names.begin(), op_names.end());
+
+ VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
+ << str_util::Join(op_names, ", ") << "]";
+
+ NodeDef* composition_node = ctx().optimized_graph->add_node();
+ composition_node->set_name(OptimizedNodeName(*root));
+ composition_node->set_op("_UnaryOpsComposition");
+ composition_node->add_input(last_op->input(0));
+ composition_node->set_device(root->device());
+
+ auto attr = composition_node->mutable_attr();
+ SetAttrValue(dtype, &(*attr)["T"]);
+ SetAttrValue(op_names, &(*attr)["op_names"]);
+
+ ctx().node_map->AddNode(composition_node->name(), composition_node);
+ ctx().node_map->AddOutput(NodeName(last_op->input(0)),
+ composition_node->name());
+
+ *simplified_node_name = composition_node->name();
+
+ return Status::OK();
+ }
+
+ private:
+ bool CanOptimize(const NodeDef& node) const {
+ DataType dtype = GetDataTypeFromAttr(node, "T");
+ if (!IsSupported(node.op(), dtype)) {
+ return false;
+ }
+ if (IsInPreserveSet(node)) {
+ return false;
+ }
+ if (!NodeIsOnCpu(node)) {
+ return false;
+ }
+ if (NodeIsAlreadyFused(node)) {
+ return false;
+ }
+ return !(IsDrivenByControlDependency(node) ||
+ DrivesControlDependency(node));
+ }
+
+ // UnaryOpsComposition is defined only for CPU.
+ bool NodeIsOnCpu(const NodeDef& node) const {
+ using str_util::StartsWith;
+
+ string task;
+ string device;
+
+ return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
+ StartsWith(device, DEVICE_CPU);
+ }
+
+ bool NodeIsAlreadyFused(const NodeDef& node) const {
+ return fused_nodes_.count(node.name()) > 0;
+ }
+
+ string OptimizedNodeName(const NodeDef& node) const {
+ return strings::StrCat(node.name(), "/unary_ops_composition");
+ }
+
+ void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
+
+ // Check if an op is supported by the _UnaryOpsComposition for the given type.
+ bool IsSupported(const string& op_name, DataType dtype) const {
+ const auto it = supported_ops_.find(op_name);
+ return it != supported_ops_.end() && it->second.count(dtype) > 0;
+ }
+
+ std::unordered_map<string, std::set<DataType>> supported_ops_;
+ std::unordered_set<string> fused_nodes_;
+};
+
} // namespace
class UniqueNodes {
@@ -2056,33 +2912,6 @@ bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
return true;
}
-NodeDef* ArithmeticOptimizer::AddNode(const NodeDef& node, StringPiece suffix,
- bool copy_node) {
- return AddNode(OptimizedNodeName(node, suffix), copy_node ? &node : nullptr);
-}
-
-NodeDef* ArithmeticOptimizer::AddNode(const string& name,
- const NodeDef* node_to_copy) {
- NodeDef* new_node = optimized_graph_->add_node();
- node_map_->AddNode(NodeName(name), new_node);
- if (node_to_copy != nullptr) {
- *new_node = *node_to_copy;
- }
- new_node->set_name(name);
- return new_node;
-}
-
-string ArithmeticOptimizer::OptimizedNodeName(const NodeDef& node,
- StringPiece suffix) const {
- return AddPrefixToNodeName(strings::StrCat(node.name(), "_", suffix),
- kArithmeticOptimizer);
-}
-
-bool ArithmeticOptimizer::OptimizedNodeExists(const NodeDef& node,
- StringPiece suffix) const {
- return node_map_->NodeExists(OptimizedNodeName(node, suffix));
-}
-
namespace {
bool FeedsInPlaceOp(const SimpleGraphView& graph_view, const NodeDef& node) {
@@ -2206,263 +3035,6 @@ void ArithmeticOptimizer::ForwardControlDependencies(
DedupControlInputs(target_node);
}
-// TODO(ezhulenev): extract each individual simplify rewrite into separate
-// ArithmeticOptimizerStage
-string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
- const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
- // Fold a multiply of a scalar into the following convolution. This folding
- // can jump across nodes that merely reorders data (such as reshape and
- // transpose). For example, we can optimize
- //
- //
- // Conv2D
- // / \
- // Transpose weights
- // |
- // Mul
- // / \
- // inputs 255.0
- //
- // to
- //
- // Conv2D
- // / \
- // Transpose Mul
- // | / \
- // | weights 255.0
- // |
- // inputs
- //
- // when `weights` are constant. `Mul` in the optimized graph can be
- // constant-folded.
- //
- // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
- // Conv?DBackpropInput.
- if (node->op() == "Conv2D" || node->op() == "Conv3D") {
- NodeDef* conv = const_cast<NodeDef*>(node);
- const NodeDef* weights = node_map_->GetNode(NodeName(conv->input(1)));
- // Fold the multiply to conv only when the weights are constant, so the
- // multiply can be constant-folded. TODO(jingyue): When the weights aren't
- // constant, this should also help performance a bit and memory usage a lot,
- // since the weights tend to be smaller than the activations.
- if (weights->op() == "Const" &&
- !OptimizedNodeExists(*weights, StrCat("scaled_", conv->name()))) {
- const NodeDef* source = node_map_->GetNode(
- GetTailOfValuePreservingChain(*node, *node_map_, nodes_to_preserve_)
- ->input(0));
- if (source->op() == "Mul" &&
- node_map_->GetOutputs(source->name()).size() == 1) {
- const NodeDef* mul = source;
- // `scale` is the scalar multiplier, and `other` is the other operand.
- // TODO(jingyue): handle the case where `scale` is 0-th operand.
- const NodeDef* scale = node_map_->GetNode(mul->input(1));
- const NodeDef* other = node_map_->GetNode(mul->input(0));
- if (scale->op() == "Const" && scale->attr().at("dtype").type() ==
- weights->attr().at("dtype").type()) {
- const TensorProto& scale_tensor = scale->attr().at("value").tensor();
- // Test whether `scale` is a scalar.
- if (scale_tensor.has_tensor_shape() &&
- scale_tensor.tensor_shape().dim_size() == 0) {
- // Create new node `scaled_weights`.
- NodeDef* scaled_weights = AddNode(
- *weights, StrCat("scaled_", conv->name()), /*copy_node=*/false);
- scaled_weights->set_op("Mul");
- scaled_weights->set_device(weights->device());
- (*scaled_weights->mutable_attr())["T"] =
- weights->attr().at("dtype");
- nodes_to_simplify->PushBack(scaled_weights);
-
- // Link in its inputs.
- scaled_weights->add_input(conv->input(1));
- node_map_->AddOutput(weights->name(), scaled_weights->name());
- scaled_weights->add_input(mul->input(1));
- node_map_->AddOutput(scale->name(), scaled_weights->name());
- ForwardControlDependencies(scaled_weights, {source});
-
- // Update `conv`'s weights to `scaled_weights`.
- conv->set_input(1, scaled_weights->name());
- node_map_->UpdateInput(conv->name(), weights->name(),
- scaled_weights->name());
- nodes_to_simplify->PushBack(conv);
-
- // Update `mul`'s consumer to bypass `mul` because it's folded to
- // the weights.
- CHECK_EQ(node_map_->GetOutputs(mul->name()).size(), 1);
- NodeDef* consumer_of_mul =
- *node_map_->GetOutputs(mul->name()).begin();
- consumer_of_mul->set_input(0, mul->input(0));
- node_map_->UpdateInput(consumer_of_mul->name(), mul->name(),
- other->name());
- nodes_to_simplify->PushBack(consumer_of_mul);
- return conv->name();
- }
- }
- }
- }
- }
-
- if (node->op() == "Mul" && node->input(0) == node->input(1) &&
- !OptimizedNodeExists(*node, "square")) {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- string dontcare;
- string device;
- bool is_on_cpu =
- DeviceNameUtils::SplitDeviceName(node->device(), &dontcare, &device) &&
- str_util::StrContains(device, DEVICE_CPU);
- if (!is_complex || is_on_cpu) {
- NodeDef* new_square_node = AddNode(*node, "square", /*copy_node=*/true);
- new_square_node->set_op("Square");
- for (int i = 1; i < new_square_node->input_size(); ++i) {
- new_square_node->set_input(i - 1, new_square_node->input(i));
- }
- new_square_node->mutable_input()->RemoveLast();
- for (const string& input : new_square_node->input()) {
- node_map_->AddOutput(NodeName(input), new_square_node->name());
- }
- return new_square_node->name();
- }
- }
-
- if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
- // Discard aggregate nodes with a single input and no control dependencies.
- if (node->input_size() == 1) {
- return node->input(0);
- }
-
- // Try to rewrite aggregations of N >= 2 identical terms (possibly due
- // to deduping or other rewrites) so we can get rid of the sum entirely.
- // The expression (using AddN as an example of an aggregate op):
- // AddN(x, x, x, ... ,x)
- // <-- N terms -->
- // can be rewritten to
- // Mul(Const(N), x))
- //
- bool all_equal = true;
- int num_inputs = 1;
- for (int i = 1; i < node->input_size(); ++i) {
- if (IsControlInput(node->input(i))) {
- break;
- }
- ++num_inputs;
- if (node->input(i) != node->input(0)) {
- all_equal = false;
- break;
- }
- }
- if (all_equal && !OptimizedNodeExists(*node, "const") &&
- !OptimizedNodeExists(*node, "mul")) {
- // 1. Create constant node with value N.
- const auto type = GetDataTypeFromAttr(*node, "T");
- Tensor t(type, TensorShape({}));
- Status status = SetTensorValue(type, num_inputs, &t);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create const node: "
- << status.error_message();
- return "";
- }
- TensorValue value(&t);
- NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false);
- status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
- new_const_node);
- if (!status.ok()) {
- LOG(WARNING) << "Failed to create const node: "
- << status.error_message();
- return "";
- }
- new_const_node->set_device(node->device());
- MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
- optimized_graph_, node_map_.get());
- nodes_to_simplify->PushBack(new_const_node);
-
- // 2. Replace the aggregate node with Mul(Const(N), x).
- NodeDef* new_mul_node = AddNode(*node, "mul", /*copy_node=*/false);
- new_mul_node->set_op("Mul");
- new_mul_node->set_device(node->device());
- SetDataTypeToAttr(type, "T", new_mul_node);
- new_mul_node->add_input(new_const_node->name());
- node_map_->AddOutput(new_const_node->name(), new_mul_node->name());
- new_mul_node->add_input(node->input(0));
- node_map_->AddOutput(node->input(0), new_mul_node->name());
-
- ForwardControlDependencies(new_mul_node, {node});
- return new_mul_node->name();
- }
- }
-
- // Fold Transpose into matrix multiplication.
- if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
- node->op() == "BatchMatMul") &&
- !OptimizedNodeExists(*node, "fused")) {
- const NodeDef* a = node_map_->GetNode(node->input(0));
- const NodeDef* b = node_map_->GetNode(node->input(1));
- bool is_complex = false;
- if (node->op() != "SparseMatMul") {
- const DataType type = GetDataTypeFromAttr(*node, "T");
- is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
- }
- const std::set<string> foldable_transpose_ops =
- !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
- : (node->op() == "BatchMatMul"
- ? std::set<string>{"ConjugateTranspose"}
- : std::set<string>{"Transpose"});
- const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
- IsInnerMatrixTransposeNode(*a, node_map_.get());
- const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
- IsInnerMatrixTransposeNode(*b, node_map_.get());
- if (a_is_foldable || b_is_foldable) {
- NodeDef* new_op = AddNode(*node, "fused", /*copy_node=*/true);
- if (a_is_foldable) {
- const string attr_a =
- node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
- FlipBooleanAttr(attr_a, new_op);
- new_op->set_input(0, a->input(0));
- node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
- }
- if (b_is_foldable) {
- const string attr_b =
- node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
- FlipBooleanAttr(attr_b, new_op);
- new_op->set_input(1, b->input(0));
- node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
- }
- std::vector<const NodeDef*> deps_to_forward({node});
- if (a_is_foldable) {
- deps_to_forward.push_back(a);
- }
- if (b_is_foldable) {
- deps_to_forward.push_back(b);
- }
- ForwardControlDependencies(new_op, deps_to_forward);
- }
- }
-
- // Fold Conj into Transpose or ConjugateTranspose.
- if ((node->op() == "Conj" || node->op() == "Transpose" ||
- node->op() == "ConjugateTranspose") &&
- !OptimizedNodeExists(*node, "fused")) {
- const NodeDef* input = node_map_->GetNode(node->input(0));
- const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
- const NodeDef* conj_op = node->op() == "Conj" ? node : input;
-
- if ((transpose_op->op() == "Transpose" ||
- transpose_op->op() == "ConjugateTranspose") &&
- conj_op->op() == "Conj") {
- NodeDef* new_op =
- AddNode(OptimizedNodeName(*node, "fused"), transpose_op);
- // Flip the type of transpose op to absorb the conjugation.
- new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
- : "Transpose");
- new_op->set_input(0, input->input(0));
- node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
- ForwardControlDependencies(new_op, {node, input});
- return new_op->name();
- }
- }
-
- return "";
-}
-
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
SetVector<NodeDef*> nodes_to_simplify;
nodes_to_simplify.Reserve(optimized_graph_->node_size());
@@ -2471,7 +3043,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get());
+ graph_properties_.get(), node_map_.get(),
+ opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -2480,6 +3053,12 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.combine_add_to_addn && can_use_shapes)
pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
+ if (options_.fold_conjugate_into_transpose)
+ pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
+ if (options_.fold_multiply_into_conv)
+ pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
+ if (options_.fold_transpose_into_matmul)
+ pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
if (options_.minimize_broadcasts && can_use_shapes)
@@ -2496,16 +3075,27 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+ if (options_.replace_mul_with_square)
+ pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
if (options_.remove_logical_not)
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_and_transpose)
pipeline.AddStage<ReorderCastAndTranspose>(ctx, ctx_ext);
+ if (options_.simplify_aggregation)
+ pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
if (options_.hoist_cwise_unary_chains)
pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
if (options_.convert_sqrt_div_to_rsqrt_mul)
pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
if (options_.remove_idempotent)
pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
+ if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
+ if (options_.convert_log1p)
+ pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
+ if (options_.unary_ops_composition)
+ pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
@@ -2513,19 +3103,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
while (!nodes_to_simplify.Empty()) {
NodeDef* node = nodes_to_simplify.PopBack();
- // TODO(ezhulenev): move all rewrites into separate stages
string simplified_tensor = "";
- if (options_.enable_try_simplify_and_replace) {
- simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
- }
+ bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
- // if it was not simplified try to run it through all configured stages
- if (!stop(simplified_tensor)) {
- bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);
- if (!optimized) {
- continue;
- }
- }
+ // If the node was not optimized by any of the stages, go to the next one.
+ if (!optimized) continue;
// re-wire consumers of an old node to the new one
if (NodeName(simplified_tensor) != node->name()) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 0fce23a40a..00c02d19bd 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -54,16 +54,16 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
- // TODO(ezhulenev): flag do disable TrySimplifyAndReplaceUses in tests.
- // Remove when all optimizers will be migrated to separate stages.
- bool enable_try_simplify_and_replace = true;
-
bool combine_add_to_addn = true;
- bool convert_sqrt_div_to_rsqrt_mul = false;
+ bool convert_sqrt_div_to_rsqrt_mul = true;
bool dedup_computations = true;
+ bool fold_conjugate_into_transpose = true;
+ bool fold_multiply_into_conv = true;
+ bool fold_transpose_into_matmul = true;
bool hoist_common_factor_out_of_aggregation = true;
bool hoist_cwise_unary_chains = false;
bool minimize_broadcasts = true;
+ bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
bool remove_identity_transpose = true;
bool remove_involution = true;
@@ -73,6 +73,11 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_cast = true;
bool remove_redundant_reshape = true;
bool reorder_cast_and_transpose = true;
+ bool replace_mul_with_square = true;
+ bool simplify_aggregation = true;
+ bool convert_pow = true;
+ bool convert_log1p = true;
+ bool unary_ops_composition = true;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
@@ -83,21 +88,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
}
};
- // Returns true is a node with given name and the optimizer prefix already
- // exists.
- string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
- bool OptimizedNodeExists(const NodeDef& node, StringPiece suffix) const;
-
- // Creates a new node in the graph, with name equal to that of node, prefixed
- // with "ArithmeticOptimizer/" and the given suffix. Also updates node_map_,
- // and optionally copies node into the new node if copy_node is true.
- NodeDef* AddNode(const NodeDef& node, StringPiece suffix, bool copy_node);
-
- // Creates a new node in the graph, prefixed with "ArithmeticOptimizer/",
- // updates node_map_, and optionally copies *node_to_copy into the new
- // node, if node_to_copy is not nullptr.
- NodeDef* AddNode(const string& name, const NodeDef* node_to_copy);
-
// Returns true if it is safe to dedup node from the graph.
bool CanDedup(const NodeDef& node) const;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 02f76df025..c387b00303 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -40,21 +40,37 @@ constexpr char kHoistFactorOptimizerMul[] =
constexpr char kHoistFactorOptimizerAdd[] =
"ArithmeticOptimizer/HoistCommonFactor_Add_";
-// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation
+constexpr char kSimplifyAggregationConst[] =
+ "ArithmeticOptimizer/SimplifyAggregation_Const_";
+
+constexpr char kSimplifyAggregationMul[] =
+ "ArithmeticOptimizer/SimplifyAggregation_Mul_";
+
+// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation.
string HoistMulName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
}
-// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation.
string HoistDivName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
}
-// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
+// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation.
string HoistAddName(const string& name) {
return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
}
+// Optimized name of Const node by SimplifyAggregation.
+string AggregationConstName(const string& name) {
+ return AddPrefixToNodeName(name, kSimplifyAggregationConst, "");
+}
+
+// Optimized name of Mul node by SimplifyAggregation.
+string AggregationMulName(const string& name) {
+ return AddPrefixToNodeName(name, kSimplifyAggregationMul, "");
+}
+
string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
@@ -123,9 +139,14 @@ class ArithmeticOptimizerTest : public GrapplerTest {
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
options.dedup_computations = false;
- options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.convert_sqrt_div_to_rsqrt_mul = false;
+ options.convert_pow = false;
+ options.convert_log1p = false;
+ options.optimize_max_or_min_of_monotonic = false;
+ options.fold_conjugate_into_transpose = false;
+ options.fold_multiply_into_conv = false;
+ options.fold_transpose_into_matmul = false;
options.hoist_common_factor_out_of_aggregation = false;
options.hoist_cwise_unary_chains = false;
options.minimize_broadcasts = false;
@@ -138,6 +159,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
options.remove_negation = false;
options.remove_logical_not = false;
options.reorder_cast_and_transpose = false;
+ options.replace_mul_with_square = false;
+ options.simplify_aggregation = false;
+ options.unary_ops_composition = false;
optimizer->options_ = options;
}
@@ -150,6 +174,21 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.combine_add_to_addn = true;
}
+ void EnableOnlyFoldConjugateIntoTranspose(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.fold_conjugate_into_transpose = true;
+ }
+
+ void EnableOnlyFoldMultipleIntoConv(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.fold_multiply_into_conv = true;
+ }
+
+ void EnableOnlyFoldTransposeIntoMatMul(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.fold_transpose_into_matmul = true;
+ }
+
void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_common_factor_out_of_aggregation = true;
@@ -195,6 +234,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.reorder_cast_and_transpose = true;
}
+ void EnableOnlyReplaceMulWithSquare(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.replace_mul_with_square = true;
+ }
+
void EnableOnlyHoistCWiseUnaryChains(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.hoist_cwise_unary_chains = true;
@@ -205,6 +249,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.convert_sqrt_div_to_rsqrt_mul = true;
}
+ void EnableOnlyConvertPow(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_pow = true;
+ }
+
void EnableOnlyRemoveIdempotent(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_idempotent = true;
@@ -214,6 +263,26 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.remove_logical_not = true;
}
+
+ void EnableOnlySimplifyAggregation(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.simplify_aggregation = true;
+ }
+
+ void EnableOnlyLog1p(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.convert_log1p = true;
+ }
+
+ void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.optimize_max_or_min_of_monotonic = true;
+ }
+
+ void EnableOnlyUnaryOpsComposition(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.unary_ops_composition = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -339,33 +408,36 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
-TEST_F(ArithmeticOptimizerTest, MulToSquare) {
+TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
Output id = ops::Identity(s.WithOpName("id"), mul);
+
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
- ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ ArithmeticOptimizer optimizer;
+ EnableOnlyReplaceMulWithSquare(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("id", output.node(3).name());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0));
- EXPECT_EQ("Square", output.node(4).op());
- EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name());
- EXPECT_EQ(2, output.node(4).input_size());
- EXPECT_EQ("c", output.node(4).input(0));
- EXPECT_EQ("^d", output.node(4).input(1));
+ EXPECT_EQ(4, output.node_size());
- auto tensors = EvaluateNodes(output, fetch);
+ NodeMap node_map(&output);
+ const string p = "ArithmeticOptimizer/ReplaceMulWithSquare";
+ const NodeDef* square_node = node_map.GetNode(strings::StrCat(p, "_", "mul"));
+
+ ASSERT_NE(square_node, nullptr);
+ EXPECT_EQ("Square", square_node->op());
+ EXPECT_EQ("c", square_node->input(0));
+ EXPECT_EQ("^d", square_node->input(1));
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -380,12 +452,10 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
auto recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
auto id = ops::Identity(s.WithOpName("id"), recip2);
- std::vector<string> fetch = {"id"};
-
GrapplerItem item;
- item.fetch = fetch;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
GraphDef output;
@@ -398,7 +468,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveInvolution_AdjacentNodes) {
EXPECT_EQ("id", output.node(1).name());
EXPECT_EQ("c", output.node(1).input(0));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -487,10 +557,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
Output id = ops::Identity(s.WithOpName("id"), add);
GrapplerItem item;
+ item.fetch = {"id"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"id"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
@@ -500,22 +570,25 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
EXPECT_EQ(5, output.node_size());
- const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ const string optimized_const_name = AggregationConstName("add");
+ const string optimized_mul_name = AggregationMulName("add");
+
+ const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
- const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
ASSERT_NE(new_mul, nullptr);
- EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ(optimized_const_name, new_mul->input(0));
EXPECT_EQ("x", new_mul->input(1));
const NodeDef* new_id = node_map.GetNode("id");
ASSERT_NE(new_id, nullptr);
- EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+ EXPECT_EQ(optimized_mul_name, new_id->input(0));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -541,21 +614,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
EXPECT_EQ(6, output.node_size());
- const NodeDef* new_const = node_map.GetNode(OptimizedName("add_const"));
+ const string optimized_const_name = AggregationConstName("add");
+ const string optimized_mul_name = AggregationMulName("add");
+
+ const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
EXPECT_EQ(std::string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
- const NodeDef* new_mul = node_map.GetNode(OptimizedName("add_mul"));
+ const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
ASSERT_NE(new_mul, nullptr);
- EXPECT_EQ(OptimizedName("add_const"), new_mul->input(0));
+ EXPECT_EQ(optimized_const_name, new_mul->input(0));
EXPECT_EQ("x", new_mul->input(1));
EXPECT_EQ("^y", new_mul->input(2));
const NodeDef* new_id = node_map.GetNode("id");
ASSERT_NE(new_id, nullptr);
- EXPECT_EQ(OptimizedName("add_mul"), new_id->input(0));
+ EXPECT_EQ(optimized_mul_name, new_id->input(0));
auto tensors = EvaluateNodes(output, fetch);
EXPECT_EQ(1, tensors.size());
@@ -620,24 +696,24 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
ASSERT_NE(add_4_node, nullptr);
EXPECT_EQ("Add", add_4_node->op());
EXPECT_EQ(2, add_4_node->input_size());
- EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
+ EXPECT_EQ(AggregationConstName("Add"), add_4_node->input(0));
+ EXPECT_EQ(AggregationConstName("Add_1"), add_4_node->input(1));
const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
ASSERT_NE(add_5_node, nullptr);
EXPECT_EQ("Add", add_5_node->op());
EXPECT_EQ(2, add_5_node->input_size());
- EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
- EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
+ EXPECT_EQ(AggregationConstName("Add"), add_5_node->input(0));
+ EXPECT_EQ(AggregationConstName("Add_1"), add_5_node->input(1));
- const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
+ const NodeDef* add_const_node = node_map.GetNode(AggregationConstName("Add"));
ASSERT_NE(add_const_node, nullptr);
EXPECT_EQ("Const", add_const_node->op());
EXPECT_EQ(1, add_const_node->input_size());
EXPECT_EQ("^Placeholder", add_const_node->input(0));
const NodeDef* add_1_const_node =
- node_map.GetNode(OptimizedName("Add_1_const"));
+ node_map.GetNode(AggregationConstName("Add_1"));
ASSERT_NE(add_1_const_node, nullptr);
EXPECT_EQ("Const", add_1_const_node->op());
EXPECT_EQ(1, add_1_const_node->input_size());
@@ -804,11 +880,14 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output conj = ops::Conj(s.WithOpName("conj"), z);
Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);
+
GrapplerItem item;
+ item.fetch = {"trans"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"trans"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
+
ArithmeticOptimizer optimizer;
GraphDef output;
OptimizeTwice(&optimizer, &item, &output);
@@ -816,20 +895,23 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
EXPECT_EQ(7, output.node_size());
- const NodeDef* trans_fused_node =
- node_map.GetNode(OptimizedName("trans_fused"));
+ const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
+ const string optimized_name = strings::StrCat(p, "_", "trans");
+
+ const NodeDef* trans_fused_node = node_map.GetNode(optimized_name);
ASSERT_NE(trans_fused_node, nullptr);
EXPECT_EQ("ConjugateTranspose", trans_fused_node->op());
EXPECT_EQ("z", trans_fused_node->input(0));
EXPECT_EQ("perm", trans_fused_node->input(1));
- auto tensors = EvaluateNodes(output, fetch);
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
Output im = ops::Const(s.WithOpName("im"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
Output z = ops::Complex(s.WithOpName("z"), re, im);
@@ -837,10 +919,12 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
Output conj = ops::Conj(s.WithOpName("conj"), z);
Output transp =
ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);
+
GrapplerItem item;
+ item.fetch = {"conjugate_trans"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"conjugate_trans"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
@@ -850,12 +934,16 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
EXPECT_EQ(7, output.node_size());
- const NodeDef* conjugate_trans_fused_node =
- node_map.GetNode(OptimizedName("conjugate_trans_fused"));
+ const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
+ const string optimized_name = strings::StrCat(p, "_", "conjugate_trans");
+
+ const NodeDef* conjugate_trans_fused_node = node_map.GetNode(optimized_name);
+ ASSERT_NE(conjugate_trans_fused_node, nullptr);
EXPECT_EQ("Transpose", conjugate_trans_fused_node->op());
EXPECT_EQ("z", conjugate_trans_fused_node->input(0));
EXPECT_EQ("perm", conjugate_trans_fused_node->input(1));
- auto tensors = EvaluateNodes(output, fetch);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
@@ -868,10 +956,12 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
Output conj = ops::Conj(s.WithOpName("conj"), trans);
+
GrapplerItem item;
+ item.fetch = {"conj"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"conj"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
@@ -881,12 +971,16 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
EXPECT_EQ(7, output.node_size());
- const NodeDef* conj_fused_node =
- node_map.GetNode(OptimizedName("conj_fused"));
+ const string p = "ArithmeticOptimizer/FoldConjugateIntoTranspose";
+ const string optimized_name = strings::StrCat(p, "_", "conj");
+
+ const NodeDef* conj_fused_node = node_map.GetNode(optimized_name);
+ ASSERT_NE(conj_fused_node, nullptr);
EXPECT_EQ("ConjugateTranspose", conj_fused_node->op());
EXPECT_EQ("z", conj_fused_node->input(0));
EXPECT_EQ("perm", conj_fused_node->input(1));
- auto tensors = EvaluateNodes(output, fetch);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorEqual<complex64>(tensors_expected[0], tensors[0]);
}
@@ -894,38 +988,45 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
for (const string matmul_type : {"MatMul", "SparseMatMul", "BatchMatMul"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
+
+ auto matmul_op = s.WithOpName("matmul");
if (matmul_type == "MatMul") {
- Output matmul = ops::MatMul(s.WithOpName("matmul"), trans_a, trans_b);
+ Output matmul = ops::MatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "SparseMatMul") {
- Output matmul =
- ops::SparseMatMul(s.WithOpName("matmul"), trans_a, trans_b);
+ Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "BatchMatMul") {
- Output matmul =
- ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
+ Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
}
+
GrapplerItem item;
+ item.fetch = {"matmul"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"matmul"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
+ EnableOnlyFoldTransposeIntoMatMul(&optimizer);
GraphDef output;
OptimizeTwice(&optimizer, &item, &output);
NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
- const NodeDef* matmul_fused_node =
- node_map.GetNode(OptimizedName("matmul_fused"));
+ const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
+ const string optimized_name = strings::StrCat(p, "_", "matmul");
+
+ const NodeDef* matmul_fused_node = node_map.GetNode(optimized_name);
ASSERT_NE(matmul_fused_node, nullptr);
EXPECT_EQ("a", matmul_fused_node->input(0));
EXPECT_EQ("b", matmul_fused_node->input(1));
+
if (matmul_type == "BatchMatMul") {
EXPECT_TRUE(matmul_fused_node->attr().at("adj_x").b());
EXPECT_TRUE(matmul_fused_node->attr().at("adj_y").b());
@@ -933,7 +1034,8 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_a").b());
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
}
- auto tensors = EvaluateNodes(output, fetch);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -941,6 +1043,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
Output re_a =
ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
Output im_a =
@@ -955,24 +1058,32 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
+
GrapplerItem item;
+ item.fetch = {"matmul"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
- std::vector<string> fetch = {"matmul"};
- auto tensors_expected = EvaluateNodes(item.graph, fetch);
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
EXPECT_EQ(1, tensors_expected.size());
ArithmeticOptimizer optimizer;
GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(11, output.node_size());
- EXPECT_EQ(OptimizedName("matmul_fused"), output.node(10).name());
- EXPECT_EQ("a", output.node(10).input(0));
- EXPECT_EQ("b", output.node(10).input(1));
- EXPECT_TRUE(output.node(10).attr().at("adj_x").b());
- EXPECT_TRUE(output.node(10).attr().at("adj_y").b());
- auto tensors = EvaluateNodes(output, fetch);
+ NodeMap node_map(&output);
+ ASSERT_EQ(11, output.node_size());
+
+ const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
+ const string optimized_name = strings::StrCat(p, "_", "matmul");
+
+ const NodeDef* optimized_matmul = node_map.GetNode(optimized_name);
+ ASSERT_NE(optimized_matmul, nullptr);
+ EXPECT_EQ("a", optimized_matmul->input(0));
+ EXPECT_EQ("b", optimized_matmul->input(1));
+ EXPECT_TRUE(optimized_matmul->attr().at("adj_x").b());
+ EXPECT_TRUE(optimized_matmul->attr().at("adj_y").b());
+
+ auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<complex64>(tensors_expected[0], tensors[0], 1e-6);
}
@@ -1418,7 +1529,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- ArithmeticOptimizer optimizer;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
EnableOnlyRemoveIdentityTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
@@ -1462,18 +1573,24 @@ TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
-
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyFoldMultipleIntoConv(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output);
NodeMap node_map(&output);
+
// `conv` is now a folded convolution with scaled weights.
const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
- CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul");
+ ASSERT_NE(folded_conv, nullptr);
+
+ const NodeDef* folded_conv_weights = node_map.GetNode(folded_conv->input(1));
+ ASSERT_NE(folded_conv_weights, nullptr);
+ EXPECT_EQ("Mul", folded_conv_weights->op());
+
// Its input should be a transpose of `inputs`.
const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
- CHECK_EQ(NodeName(transpose->input(0)), inputs.node()->name());
+ ASSERT_NE(transpose, nullptr);
+ EXPECT_EQ("inputs", transpose->input(0));
}
TEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
@@ -1574,28 +1691,32 @@ TEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- ArithmeticOptimizer optimizer;
+ ArithmeticOptimizer optimizer; // all optimization stages are on
OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
NodeMap node_map(&output);
- // Expected names for the optimized nodes.
+ // Expected names for reordered cast and transpose.
const string p = "ArithmeticOptimizer/ReorderCastAndTranspose_";
const string optimized_cast_name = strings::StrCat(p, "float_Cast");
const string optimized_transpose_name = strings::StrCat(p, "uint8_Transpose");
+ // Expected names for folded multiply and conv.
+ const string optimized_weights =
+ "ArithmeticOptimizer/FoldMultiplyIntoConv_scaled_Conv2D_weights";
+
const NodeDef* inputs_node = node_map.GetNode("Placeholder");
const NodeDef* transpose_node = node_map.GetNode(optimized_transpose_name);
const NodeDef* cast_node = node_map.GetNode(optimized_cast_name);
- const NodeDef* weights_node =
- node_map.GetNode(OptimizedName("weights_scaled_Conv2D"));
+
+ const NodeDef* weights_node = node_map.GetNode(optimized_weights);
const NodeDef* conv_node = node_map.GetNode("Conv2D");
- ASSERT_TRUE(inputs_node != nullptr);
- ASSERT_TRUE(transpose_node != nullptr);
- ASSERT_TRUE(cast_node != nullptr);
- ASSERT_TRUE(weights_node != nullptr);
- ASSERT_TRUE(conv_node != nullptr);
+ ASSERT_NE(inputs_node, nullptr);
+ ASSERT_NE(transpose_node, nullptr);
+ ASSERT_NE(cast_node, nullptr);
+ ASSERT_NE(weights_node, nullptr);
+ ASSERT_NE(conv_node, nullptr);
EXPECT_EQ(output.node_size(), 7);
EXPECT_EQ(transpose_node->input(0), inputs_node->name());
@@ -1627,23 +1748,27 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
+ ArithmeticOptimizer optimizer;
+ EnableOnlyFoldMultipleIntoConv(&optimizer);
+ OptimizeTwiceAndPrune(&optimizer, &item, &output, /*const_folding=*/true);
- item.graph.Swap(&output);
- TF_EXPECT_OK(
- ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
+ NodeMap node_map(&output);
- item.graph.Swap(&output);
- TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
+ using strings::StrCat;
+ const string p = "ArithmeticOptimizer/FoldMultiplyIntoConv_";
+ const string optimized_weights = StrCat(p, "scaled_Conv2D_weights");
+ const string optimized_weights_1 = StrCat(p, "scaled_Conv2D_1_weights_1");
- NodeMap node_map(&output);
- const NodeDef* weights_node =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D")));
- const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
+ const NodeDef* weights_node = node_map.GetNode(optimized_weights);
+ const NodeDef* weights_node_1 = node_map.GetNode(optimized_weights_1);
+ const NodeDef* conv_node = node_map.GetNode("Conv2D");
+ const NodeDef* conv_node_1 = node_map.GetNode("Conv2D_1");
+
+ ASSERT_NE(weights_node, nullptr);
+ ASSERT_NE(weights_node_1, nullptr);
+ ASSERT_NE(conv_node, nullptr);
+ ASSERT_NE(conv_node_1, nullptr);
- const NodeDef* weights_node_1 =
- CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D_1")));
- const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1"));
EXPECT_EQ(conv_node->input(1), weights_node->name());
EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
}
@@ -2328,6 +2453,95 @@ TEST_F(ArithmeticOptimizerTest, ConvertSqrtDivToRsqrtMul) {
}
}
+TEST_F(ArithmeticOptimizerTest, ConvertPow) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ auto y2 = ops::Const(s.WithOpName("y2"), {2.0f, 2.0f}, {1, 2});
+ auto y1 = ops::Const(s.WithOpName("y1"), {1.0f, 1.0f}, {1, 2});
+ auto yPoint5 = ops::Const(s.WithOpName("y.5"), {0.5f, 0.5f}, {1, 2});
+ auto y0 = ops::Const(s.WithOpName("y0"), {0.0f, 0.0f}, {1, 2});
+ auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
+ auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
+ auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
+ Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
+ Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
+ Output out0 = ops::Pow(s.WithOpName("out0"), x, y0);
+ Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
+ Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
+ Output out = ops::Pow(s.WithOpName("out"), x, y);
+
+ GrapplerItem item;
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(7, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyConvertPow(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(7, tensors.size());
+
+ GraphDef want;
+ AddNode("x", "Const", {}, {}, &want);
+ AddNode("y2", "Const", {}, {}, &want);
+ AddNode("y1", "Const", {}, {}, &want);
+ AddNode("y.5", "Const", {}, {}, &want);
+ AddNode("y0", "Const", {}, {}, &want);
+ AddNode("y_.5", "Const", {}, {}, &want);
+ AddNode("y_1", "Const", {}, {}, &want);
+ AddNode("y", "Const", {}, {}, &want);
+ AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
+ AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
+ AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
+ AddNode("out0", "Const",
+ {AsControlDependency("x"), AsControlDependency("y0")}, {}, &want);
+ AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
+ AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
+ AddNode("out", "Pow", {"x", "y"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ArithmeticOptimizerTest, Log1p) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x1 = ops::Const(s.WithOpName("x1"), {1.0f, 1.0f}, {1, 2});
+ auto x2 = ops::Const(s.WithOpName("x2"), {2.0f, 2.0f}, {1, 2});
+ auto x3 = ops::Const(s.WithOpName("x3"), {3.0f, 3.0f}, {1, 2});
+ auto a12 = ops::Add(s.WithOpName("a12").WithControlDependencies(x3), x1, x2);
+ auto a23 = ops::Add(s.WithOpName("a23"), x2, x3);
+ Output out1 = ops::Log(s.WithOpName("out1"), a12);
+ Output out2 = ops::Log(s.WithOpName("out2"), a23);
+
+ GrapplerItem item;
+ item.fetch = {"out1", "out2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef got;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyLog1p(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &got);
+ auto tensors = EvaluateNodes(got, item.fetch);
+ EXPECT_EQ(2, tensors.size());
+
+ GraphDef want;
+ AddNode("x1", "Const", {}, {}, &want);
+ AddNode("x2", "Const", {}, {}, &want);
+ AddNode("x3", "Const", {}, {}, &want);
+ AddNode("a23", "Add", {"x2", "x3"}, {}, &want);
+ AddNode("out1", "Log1p",
+ {"x2", AsControlDependency("x1"), AsControlDependency("x3")}, {},
+ &want);
+ AddNode("out2", "Log", {"a23"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_SimpleSwap) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -2771,12 +2985,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
- Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
- Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
- Output sn1 =
- ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a);
- Output sn2 =
- ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1);
+ Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
+ Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
Output id1 = ops::Identity(s.WithOpName("id1"), a);
Output id2 = ops::Identity(s.WithOpName("id2"), id1);
@@ -2792,32 +3002,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
EnableOnlyRemoveIdempotent(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(11, output.node_size());
+ EXPECT_EQ(7, output.node_size());
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0));
- found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") {
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("a", node.input(0));
- EXPECT_EQ("^ctrl1", node.input(1));
- EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("sn1", node.input(0));
found++;
} else if (node.name() == "out2") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0));
+ EXPECT_EQ("id1", node.input(0));
found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") {
- EXPECT_EQ("Identity", node.op());
+ } else if (node.name() == "sn1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("a", node.input(0));
found++;
}
}
- EXPECT_EQ(4, found);
+ EXPECT_EQ(3, found);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
@@ -2925,5 +3127,103 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
}
}
+TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "sqrt") {
+ EXPECT_EQ("Sqrt", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Max", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
+TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output log = ops::Log(s.WithOpName("log"), sqrt);
+ Output relu = ops::Relu(s.WithOpName("relu"), log);
+ Output final_out = ops::Identity(s.WithOpName("final_out"), relu);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // Place all nodes on CPU.
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ item.graph.mutable_node(i)->set_device("/device:CPU:0");
+ }
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyUnaryOpsComposition(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ EXPECT_EQ(3, output.node_size());
+
+ // Check that Sqrt/Log/Relu were replaced with a single op.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "final_out") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("relu/unary_ops_composition", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "relu/unary_ops_composition") {
+ EXPECT_EQ("_UnaryOpsComposition", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+
+ auto op_names = node.attr().at("op_names").list().s();
+ EXPECT_EQ(3, op_names.size());
+ EXPECT_EQ("Sqrt", op_names[0]);
+ EXPECT_EQ("Log", op_names[1]);
+ EXPECT_EQ("Relu", op_names[2]);
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7f0c2a2116..76c928f995 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -354,12 +354,14 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
if (op == "TensorArraySizeV3") {
- const NodeDef* array = node_map_->GetNode(node->input(0));
- if (array->attr().count("dynamic_size") != 0 &&
- array->attr().at("dynamic_size").b()) {
+ const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
+ if (array->input_size() == 0 ||
+ (array->attr().count("dynamic_size") != 0 &&
+ array->attr().at("dynamic_size").b())) {
continue;
}
- const NodeDef* array_size = node_map_->GetNode(array->input(0));
+ const NodeDef* array_size =
+ CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
if (IsReallyConstant(*array_size)) {
// Don't materialize 0 sizes to avoid triggering incorrect static
// checks. A 0 sized array that can't grow isn't useful anyway.
@@ -374,6 +376,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (value.flat<int32>()(0) == 0) {
continue;
}
+
node->set_op("Const");
*node->mutable_attr() = array_size->attr();
node->set_input(0, AsControlDependency(NodeName(node->input(0))));
@@ -2185,8 +2188,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
node->add_input(axis_node->name());
if (node->input_size() > 2) {
node->mutable_input()->SwapElements(1, node->input_size() - 1);
- return true;
}
+ return true;
}
return false;
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 9f051ca248..b9765b9292 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3000,6 +3000,10 @@ TEST_F(ConstantFoldingTest, Enter) {
TEST_F(ConstantFoldingTest, TensorArraySize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+ Output placeholder =
+ ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
+ ops::Placeholder::Shape(TensorShape({2})));
+ Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
auto dynamic_array =
ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
ops::TensorArray::DynamicSize(true));
@@ -3010,6 +3014,8 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
static_array.handle, static_array.flow);
+ auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
+ placeholder, foo);
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
@@ -3026,11 +3032,13 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("dynamic_sz", output.node(3).name());
- EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
- EXPECT_EQ("static_sz", output.node(4).name());
- EXPECT_EQ("Const", output.node(4).op());
+ EXPECT_EQ(8, output.node_size());
+ EXPECT_EQ("dynamic_sz", output.node(5).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
+ EXPECT_EQ("static_sz", output.node(6).name());
+ EXPECT_EQ("Const", output.node(6).op());
+ EXPECT_EQ("placeholder_sz", output.node(7).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(7).op());
auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
EXPECT_EQ(2, tensors_expected.size());
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
index 3148a5f809..0b8e0b692a 100644
--- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
@@ -50,7 +50,7 @@ class CustomGraphOptimizerRegistrar {
#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \
namespace { \
- static CustomGraphOptimizerRegistrar \
+ static ::tensorflow::grappler::CustomGraphOptimizerRegistrar \
MyCustomGraphOptimizerClass##_registrar( \
[]() { return new MyCustomGraphOptimizerClass; }, (name)); \
} // namespace
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 121de1e089..3cb9d4d61c 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -4,6 +4,39 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
+ name = "function_rename",
+ srcs = ["function_rename.cc"],
+ hdrs = [
+ "function_rename.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_rename_test",
+ srcs = ["function_rename_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_rename",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -68,10 +101,81 @@ tf_cc_test(
)
cc_library(
+ name = "noop_elimination",
+ srcs = ["noop_elimination.cc"],
+ hdrs = [
+ "noop_elimination.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "noop_elimination_test",
+ srcs = ["noop_elimination_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":noop_elimination",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
+ name = "shuffle_and_repeat_fusion",
+ srcs = ["shuffle_and_repeat_fusion.cc"],
+ hdrs = [
+ "shuffle_and_repeat_fusion.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "shuffle_and_repeat_fusion_test",
+ srcs = ["shuffle_and_repeat_fusion_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":shuffle_and_repeat_fusion",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
+ ":function_rename",
":map_and_batch_fusion",
+ ":noop_elimination",
+ ":shuffle_and_repeat_fusion",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc
new file mode 100644
index 0000000000..8cf044d1bd
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename.cc
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
+
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ int n = output->mutable_library()->function_size();
+ for (int i = 0; i < n; ++i) {
+ FunctionDef* fn = output->mutable_library()->mutable_function(i);
+ fn->mutable_signature()->set_name(fn->signature().name() + "world");
+ }
+
+ return Status::OK();
+}
+
+void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/function_rename.h
new file mode 100644
index 0000000000..23ad9470ff
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class FunctionRename : public CustomGraphOptimizer {
+ public:
+ FunctionRename() = default;
+ ~FunctionRename() override = default;
+
+ string name() const override { return "_test_only_function_rename"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
new file mode 100644
index 0000000000..56b8a960a7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(FunctionRenameTest, RenameFunction) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ FunctionDef *fn = graph->mutable_library()->add_function();
+ fn->mutable_signature()->set_name("hello");
+
+ FunctionRename optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_EQ(output.library().function(0).signature().name(), "helloworld");
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index df12de37da..b5b46ccafe 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -28,6 +28,8 @@ namespace grappler {
namespace graph_utils {
namespace {
+constexpr char kConstOpName[] = "Const";
+
int FindNodeWithPredicate(const std::function<bool(const NodeDef&)>& predicate,
const GraphDef& graph) {
for (int i = 0; i < graph.node_size(); ++i) {
@@ -68,9 +70,8 @@ Status AddScalarConstNodeHelper(
DataType dtype, const std::function<void(TensorProto*)>& add_value,
GraphDef* graph, NodeDef** result) {
NodeDef* node = graph->add_node();
- const string& name = strings::StrCat("Const/_", graph->node_size());
- node->set_name(name);
- node->set_op("Const");
+ node->set_op(kConstOpName);
+ SetUniqueName(kConstOpName, graph, node);
(*node->mutable_attr())["dtype"].set_type(dtype);
std::unique_ptr<tensorflow::TensorProto> tensor =
tensorflow::MakeUnique<tensorflow::TensorProto>();
@@ -94,7 +95,7 @@ Status AddNode(const string& name, const string& op,
if (!name.empty()) {
node->set_name(name);
} else {
- node->set_name(strings::StrCat(op, "/_", graph->node_size()));
+ SetUniqueName(op, graph, node);
}
node->set_op(op);
for (const string& input : inputs) {
@@ -212,6 +213,22 @@ int FindNodeWithOp(const string& op, const GraphDef& graph) {
[op](const NodeDef& node) { return node.op() == op; }, graph);
}
+void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node) {
+ int id = graph->node_size();
+ while (ContainsNodeWithName(strings::StrCat(op, "/_", id), *graph)) {
+ ++id;
+ }
+ node->set_name(strings::StrCat(op, "/_", id));
+}
+
+void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
+ GraphView* graph) {
+ GraphView::OutputPort output_port = graph->GetOutputPort(old_input.name(), 0);
+ auto fanout = graph->GetFanout(output_port);
+ for (auto& input_port : fanout)
+ input_port.node->set_input(0, new_input.name());
+}
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b40ca44d78..1cb0f0c81d 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -74,6 +75,14 @@ int FindNodeWithName(const string& name, const GraphDef& graph);
// exists.
int FindNodeWithOp(const string& op, const GraphDef& graph);
+// Sets the node name using the op name as a prefix while guaranteeing the name
+// is unique across the graph.
+void SetUniqueName(const string& op, GraphDef* graph, NodeDef* node);
+
+// Replaces the input for the output nodes of 'old_input' with 'new_input'.
+void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
+ GraphView* graph);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index b34726044e..d723d73b7a 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -23,9 +23,7 @@ namespace grappler {
namespace graph_utils {
namespace {
-class GraphUtilsTest : public ::testing::Test {};
-
-TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
+TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph;
NodeDef* bool_node;
TF_EXPECT_OK(AddScalarConstNode<bool>(true, &graph, &bool_node));
@@ -33,7 +31,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeBool) {
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
+TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
GraphDef graph;
NodeDef* double_node;
TF_EXPECT_OK(AddScalarConstNode<double>(3.14, &graph, &double_node));
@@ -41,7 +39,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeDouble) {
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
+TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
GraphDef graph;
NodeDef* float_node;
TF_EXPECT_OK(AddScalarConstNode<float>(3.14, &graph, &float_node));
@@ -49,7 +47,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeFloat) {
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
+TEST(GraphUtilsTest, AddScalarConstNodeInt) {
GraphDef graph;
NodeDef* int_node;
TF_EXPECT_OK(AddScalarConstNode<int>(42, &graph, &int_node));
@@ -57,7 +55,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt) {
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
+TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
GraphDef graph;
NodeDef* int64_node;
TF_EXPECT_OK(AddScalarConstNode<int64>(42, &graph, &int64_node));
@@ -65,7 +63,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeInt64) {
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
}
-TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
+TEST(GraphUtilsTest, AddScalarConstNodeString) {
GraphDef graph;
NodeDef* string_node;
TF_EXPECT_OK(AddScalarConstNode<StringPiece>("hello", &graph, &string_node));
@@ -73,7 +71,7 @@ TEST_F(GraphUtilsTest, AddScalarConstNodeString) {
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
}
-TEST_F(GraphUtilsTest, Compare) {
+TEST(GraphUtilsTest, Compare) {
GraphDef graphA;
GraphDef graphB;
EXPECT_TRUE(Compare(graphA, graphB));
@@ -88,7 +86,7 @@ TEST_F(GraphUtilsTest, Compare) {
EXPECT_TRUE(Compare(graphA, graphB));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithName) {
+TEST(GraphUtilsTest, ContainsNodeWithName) {
GraphDef graph;
EXPECT_TRUE(!ContainsNodeWithName("A", graph));
@@ -100,7 +98,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithName) {
EXPECT_TRUE(!ContainsNodeWithName("A", graph));
}
-TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
+TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph;
EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
@@ -112,7 +110,7 @@ TEST_F(GraphUtilsTest, ContainsNodeWithOp) {
EXPECT_TRUE(!ContainsNodeWithOp("OpA", graph));
}
-TEST_F(GraphUtilsTest, FindNodeWithName) {
+TEST(GraphUtilsTest, FindNodeWithName) {
GraphDef graph;
EXPECT_EQ(FindNodeWithName("A", graph), -1);
@@ -124,7 +122,7 @@ TEST_F(GraphUtilsTest, FindNodeWithName) {
EXPECT_EQ(FindNodeWithName("A", graph), -1);
}
-TEST_F(GraphUtilsTest, FindNodeWithOp) {
+TEST(GraphUtilsTest, FindNodeWithOp) {
GraphDef graph;
EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
@@ -136,6 +134,41 @@ TEST_F(GraphUtilsTest, FindNodeWithOp) {
EXPECT_EQ(FindNodeWithOp("OpA", graph), -1);
}
+TEST(GraphUtilsTest, SetUniqueName) {
+ GraphDef graph;
+
+ NodeDef* node1;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node1));
+ NodeDef* node2;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node2));
+ EXPECT_NE(node1->name(), node2->name());
+
+ TF_EXPECT_OK(DeleteNodes({node1->name()}, &graph));
+ NodeDef* node3;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node3));
+ EXPECT_NE(node2->name(), node3->name());
+}
+
+TEST(GraphUtilsTest, ReplaceInput) {
+ GraphDef graph;
+
+ NodeDef* node1;
+ TF_EXPECT_OK(AddNode("", "A", {}, {}, &graph, &node1));
+
+ NodeDef* node2;
+ TF_EXPECT_OK(AddNode("", "A", {node1->name()}, {}, &graph, &node2));
+
+ NodeDef* node3;
+ TF_EXPECT_OK(AddNode("", "A", {node2->name()}, {}, &graph, &node3));
+
+ EXPECT_EQ(node3->input(0), node2->name());
+
+ GraphView view(&graph);
+ ReplaceInput(*node2, *node1, &view);
+
+ EXPECT_EQ(node3->input(0), node1->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 290326ab75..eac665bd92 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -28,6 +28,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+namespace {
+
+constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
+
+} // namespace
Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
@@ -35,25 +40,24 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphView graph(output);
std::set<string> nodes_to_delete;
for (const NodeDef& node : item.graph.node()) {
- if (node.op() != "BatchDataset") {
+ if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
continue;
}
- // Use a more descriptive variable name now that we now the node type.
- NodeDef batch_node(node);
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef batch_node(node);
GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
NodeDef* node2 = graph.GetRegularFanin(input_port).node;
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
- // Use a more descriptive variable name now that we now the node type.
- NodeDef* map_node = node2;
- NodeDef* new_node = output->mutable_node()->Add();
- new_node->set_op("MapAndBatchDatasetV2");
- new_node->set_name(
- strings::StrCat("MapAndBatchDatasetV2/_", output->node_size()));
+ NodeDef* new_node = output->add_node();
+ new_node->set_op(kFusedOpName);
+ graph_utils::SetUniqueName(kFusedOpName, output, new_node);
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* map_node = node2;
// Set the `input` input argument.
new_node->add_input(map_node->input(0));
@@ -89,7 +93,9 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
}
// Set the `drop_remainder` input argument.
- {
+ if (batch_node.op() == "BatchDatasetV2") {
+ new_node->add_input(batch_node.input(2));
+ } else {
NodeDef* tmp;
TF_RETURN_IF_ERROR(
graph_utils::AddScalarConstNode<bool>(false, output, &tmp));
@@ -109,15 +115,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
nodes_to_delete.insert(map_node->name());
nodes_to_delete.insert(batch_node.name());
- // Update the input of the outputs of the `Batch` node to use
- // `MapAndBatch`.
- GraphView::OutputPort output_port =
- graph.GetOutputPort(batch_node.name(), 0);
- auto fanout = graph.GetFanout(output_port);
- for (auto it = fanout.begin(); it != fanout.end(); ++it) {
- NodeDef* node = it->node;
- node->set_input(0, new_node->name());
- }
+ graph_utils::ReplaceInput(batch_node, *new_node, &graph);
}
TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
index a5a4d91df6..2c64831105 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.h
@@ -23,13 +23,13 @@ namespace grappler {
class MapAndBatchFusion : public CustomGraphOptimizer {
public:
- MapAndBatchFusion() {}
- ~MapAndBatchFusion() override {}
+ MapAndBatchFusion() = default;
+ ~MapAndBatchFusion() override = default;
string name() const override { return "map_and_batch_fusion"; };
- Status Init(const tensorflow::RewriterConfig_CustomGraphOptimizer* config =
- nullptr) override {
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index 8c7498dc5d..3c1d8d5359 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -112,6 +112,95 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
batch_node->attr().at("output_types")));
}
+TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ std::vector<std::pair<string, AttrValue>> range_attrs;
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ range_attrs, graph, &range_node));
+ NodeDef *captured_input_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<StringPiece>(
+ "hello", graph, &captured_input_node));
+
+ NodeDef *map_node;
+ {
+ std::vector<string> map_inputs(2);
+ map_inputs[0] = range_node->name();
+ map_inputs[1] = captured_input_node->name();
+ std::vector<std::pair<string, AttrValue>> map_attrs(2);
+ AttrValue f_attr;
+ SetAttrValue("f", &f_attr);
+ map_attrs[0] = std::make_pair("f", f_attr);
+ AttrValue args_attr;
+ SetAttrValue("Targuments", &args_attr);
+ map_attrs[1] = std::make_pair("Targuments", args_attr);
+ TF_ASSERT_OK(graph_utils::AddNode("", "MapDataset", map_inputs, map_attrs,
+ graph, &map_node));
+ }
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ NodeDef *drop_remainder_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<bool>(true, graph, &drop_remainder_node));
+ NodeDef *batch_node;
+ {
+ std::vector<string> batch_inputs(3);
+ batch_inputs[0] = map_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ batch_inputs[2] = drop_remainder_node->name();
+ std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ batch_attrs[1] = std::make_pair("output_types", types_attr);
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDatasetV2", batch_inputs,
+ batch_attrs, graph, &batch_node));
+ }
+
+ MapAndBatchFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(map_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(batch_node->name(), output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node =
+ output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ EXPECT_EQ(map_and_batch_node.input_size(), 5);
+ EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
+ EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
+ EXPECT_EQ(map_and_batch_node.input(2), batch_node->input(1));
+ NodeDef num_parallel_calls_node = output.node(
+ graph_utils::FindNodeWithName(map_and_batch_node.input(3), output));
+ EXPECT_EQ(num_parallel_calls_node.attr().at("value").tensor().int64_val(0),
+ 1);
+ EXPECT_EQ(map_and_batch_node.input(4), batch_node->input(2));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("f"),
+ map_node->attr().at("f")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("Targuments"),
+ map_node->attr().at("Targuments")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_shapes"),
+ batch_node->attr().at("output_shapes")));
+ EXPECT_TRUE(AreAttrValuesEqual(map_and_batch_node.attr().at("output_types"),
+ batch_node->attr().at("output_types")));
+}
+
TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
GrapplerItem item;
GraphDef *graph = &item.graph;
@@ -204,10 +293,9 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
}
TEST(MapAndBatchFusionTest, NoChange) {
- std::vector<std::pair<string, AttrValue>> empty_attributes;
-
GrapplerItem item;
GraphDef *graph = &item.graph;
+
NodeDef *start_node;
TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
NodeDef *stop_node;
@@ -219,9 +307,27 @@ TEST(MapAndBatchFusionTest, NoChange) {
range_inputs[0] = start_node->name();
range_inputs[1] = stop_node->name();
range_inputs[2] = step_node->name();
+ std::vector<std::pair<string, AttrValue>> range_attrs;
NodeDef *range_node;
TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
- empty_attributes, graph, &range_node));
+ range_attrs, graph, &range_node));
+
+ NodeDef *batch_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(5, graph, &batch_size_node));
+ std::vector<string> batch_inputs(2);
+ batch_inputs[0] = range_node->name();
+ batch_inputs[1] = batch_size_node->name();
+ std::vector<std::pair<string, AttrValue>> batch_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ batch_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ batch_attrs[1] = std::make_pair("output_types", types_attr);
+ NodeDef *batch_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "BatchDataset", batch_inputs,
+ batch_attrs, graph, &batch_node));
MapAndBatchFusion optimizer;
GraphDef output;
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
new file mode 100644
index 0000000000..5670966367
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/noop_elimination.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
+ if (take_node.op() != "TakeDataset") return false;
+
+ const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ // We are looking only for 'take' with negative count.
+ return count_node.attr().at("value").tensor().int64_val(0) < 0;
+}
+
+bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
+ if (skip_node.op() != "SkipDataset") return false;
+
+ const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
+ // We are looking only for skip(0) nodes.
+ return count_node.attr().at("value").tensor().int64_val(0) == 0;
+}
+
+bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
+ if (repeat_node.op() != "RepeatDataset") return false;
+
+ const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
+ // We are looking only for repeat(1) nodes.
+ return count_node.attr().at("value").tensor().int64_val(0) == 1;
+}
+
+bool IsNoOp(const NodeDef& node, const GraphView& graph) {
+ return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
+ IsRepeatOne(node, graph);
+}
+
+} // namespace
+
+Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ std::set<string> nodes_to_delete;
+ for (const NodeDef& node : item.graph.node()) {
+ if (!IsNoOp(node, graph)) continue;
+
+ GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
+ NodeDef* const parent = graph.GetRegularFanin(input_port).node;
+ graph_utils::ReplaceInput(node, *parent, &graph);
+
+ nodes_to_delete.insert(node.name());
+ }
+ TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+ return Status::OK();
+}
+
+void NoOpElimination::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(NoOpElimination, "noop_elimination");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.h b/tensorflow/core/grappler/optimizers/data/noop_elimination.h
new file mode 100644
index 0000000000..c67cea49d5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This class eliminates tf.data transformations such as `take(n)` (for n < 0),
+// `skip(0)`, or `repeat(1)`
+class NoOpElimination : public CustomGraphOptimizer {
+ public:
+ NoOpElimination() = default;
+ ~NoOpElimination() override = default;
+
+ string name() const override { return "noop_elimination"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_NOOP_ELIMINATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
new file mode 100644
index 0000000000..8628b16ea5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/noop_elimination.h"
+#include <tuple>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+std::vector<std::pair<string, AttrValue>> GetCommonAttributes() {
+ AttrValue shapes_attr, types_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ SetAttrValue("output_types", &types_attr);
+ std::vector<std::pair<string, AttrValue>> commonAttributes = {
+ {"output_shapes", shapes_attr}, {"output_types", types_attr}};
+
+ return commonAttributes;
+}
+
+void MakeUnaryNode(GraphDef *graph, const std::string &node_type, int count,
+ string input_node, NodeDef **return_node) {
+ NodeDef *node_count;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(count, graph, &node_count));
+ TF_ASSERT_OK(graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph, return_node));
+}
+
+void MakeCacheNode(GraphDef *graph, string input_node, NodeDef **return_node) {
+ NodeDef *node_filename;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<StringPiece>("", graph, &node_filename));
+ TF_ASSERT_OK(graph_utils::AddNode(
+ "", "CacheDataset", {std::move(input_node), node_filename->name()},
+ GetCommonAttributes(), graph, return_node));
+}
+
+void MakeRangeNode(GraphDef *graph, NodeDef **range_node) {
+ NodeDef *start_node, *stop_node, *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs = {start_node->name(), stop_node->name(),
+ step_node->name()};
+
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ GetCommonAttributes(), graph, range_node));
+}
+
+struct NoOpLastEliminationTest
+ : ::testing::TestWithParam<std::tuple<std::string, int, bool>> {};
+
+// This test checks whether the no-op elimination correctly handles
+// transformations at the end of the pipeline.
+TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ const std::string &node_type = std::get<0>(GetParam());
+ const int node_count = std::get<1>(GetParam());
+ const bool should_keep_node = std::get<2>(GetParam());
+
+ NodeDef *range_node;
+ MakeRangeNode(graph, &range_node);
+
+ NodeDef *node;
+ MakeUnaryNode(graph, node_type, node_count, range_node->name(), &node);
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::ContainsNodeWithName(node->name(), output),
+ should_keep_node);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ BasicRemovalTest, NoOpLastEliminationTest,
+ ::testing::Values(std::make_tuple("TakeDataset", -3, false),
+ std::make_tuple("TakeDataset", -1, false),
+ std::make_tuple("TakeDataset", 0, true),
+ std::make_tuple("TakeDataset", 3, true),
+ std::make_tuple("SkipDataset", -1, true),
+ std::make_tuple("SkipDataset", 0, false),
+ std::make_tuple("SkipDataset", 3, true),
+ std::make_tuple("RepeatDataset", 1, false),
+ std::make_tuple("RepeatDataset", 2, true)));
+
+struct NoOpMiddleEliminationTest
+ : ::testing::TestWithParam<std::tuple<std::string, int, bool>> {};
+
+// This test checks whether the no-op elimination correctly handles
+// transformations int the middle of the pipeline.
+TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ const std::string &node_type = std::get<0>(GetParam());
+ const int node_count = std::get<1>(GetParam());
+ const bool should_keep_node = std::get<2>(GetParam());
+
+ NodeDef *range_node;
+ MakeRangeNode(graph, &range_node);
+
+ NodeDef *node;
+ MakeUnaryNode(graph, node_type, node_count, range_node->name(), &node);
+
+ NodeDef *cache_node;
+ MakeCacheNode(graph, node->name(), &cache_node);
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::ContainsNodeWithName(node->name(), output),
+ should_keep_node);
+ EXPECT_TRUE(graph_utils::ContainsNodeWithName(cache_node->name(), output));
+
+ NodeDef cache_node_out =
+ output.node(graph_utils::FindNodeWithName(cache_node->name(), output));
+
+ EXPECT_EQ(cache_node_out.input_size(), 2);
+ auto last_node_input = (should_keep_node ? node : range_node)->name();
+ EXPECT_EQ(cache_node_out.input(0), last_node_input);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ BasicRemovalTest, NoOpMiddleEliminationTest,
+ ::testing::Values(std::make_tuple("TakeDataset", -1, false),
+ std::make_tuple("TakeDataset", -3, false),
+ std::make_tuple("TakeDataset", 0, true),
+ std::make_tuple("TakeDataset", 3, true),
+ std::make_tuple("SkipDataset", -1, true),
+ std::make_tuple("SkipDataset", 0, false),
+ std::make_tuple("SkipDataset", 3, true),
+ std::make_tuple("RepeatDataset", 1, false),
+ std::make_tuple("RepeatDataset", 2, true)));
+
+using NodesTypes = std::tuple<std::pair<string, int>, std::pair<string, int>>;
+struct NoOpMultipleEliminationTest : ::testing::TestWithParam<NodesTypes> {};
+
+// This test checks whether the no-op elimination correctly removes
+// multiple noop nodes.
+TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<std::pair<string, int>> noop_nodes = {
+ std::get<0>(GetParam()), std::get<1>(GetParam())};
+
+ NodeDef *range_node;
+ MakeRangeNode(graph, &range_node);
+
+ NodeDef *previous = range_node;
+ std::vector<string> nodes_to_remove;
+ nodes_to_remove.reserve(noop_nodes.size());
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node;
+ MakeUnaryNode(graph, noop_node.first, noop_node.second, previous->name(),
+ &node);
+ nodes_to_remove.push_back(node->name());
+ previous = node;
+ }
+
+ NodeDef *cache_node;
+ MakeCacheNode(graph, previous->name(), &cache_node);
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const auto &noop_node_name : nodes_to_remove)
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(noop_node_name, output));
+
+ EXPECT_TRUE(graph_utils::ContainsNodeWithName(cache_node->name(), output));
+
+ NodeDef cache_node_out =
+ output.node(graph_utils::FindNodeWithName(cache_node->name(), output));
+
+ EXPECT_EQ(cache_node_out.input_size(), 2);
+ EXPECT_EQ(cache_node_out.input(0), range_node->name());
+}
+
+const auto *const kTakeNode = new std::pair<string, int>{"TakeDataset", -1};
+const auto *const kSkipNode = new std::pair<string, int>{"SkipDataset", 0};
+const auto *const kRepeatNode = new std::pair<string, int>{"RepeatDataset", 1};
+
+INSTANTIATE_TEST_CASE_P(
+ BasicRemovalTest, NoOpMultipleEliminationTest,
+ ::testing::Combine(::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode),
+ ::testing::Values(*kTakeNode, *kSkipNode,
+ *kRepeatNode)));
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
new file mode 100644
index 0000000000..8332fb0b1e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char kFusedOpName[] = "ShuffleAndRepeatDataset";
+
+} // namespace
+
+Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ GraphView graph(output);
+ std::set<string> nodes_to_delete;
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() != "RepeatDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we know the node type.
+ const NodeDef repeat_node(node);
+ GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
+ NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ if (node2->op() != "ShuffleDataset") {
+ continue;
+ }
+
+ NodeDef* new_node = output->add_node();
+ new_node->set_op(kFusedOpName);
+ graph_utils::SetUniqueName(kFusedOpName, output, new_node);
+
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* shuffle_node = node2;
+
+ // Set the `input` input argument.
+ new_node->add_input(shuffle_node->input(0));
+
+ // Set the `buffer_size` input argument.
+ new_node->add_input(shuffle_node->input(1));
+
+ // Set the `seed` input argument.
+ new_node->add_input(shuffle_node->input(2));
+
+ // Set the `seed2` input argument.
+ new_node->add_input(shuffle_node->input(3));
+
+ // Set the `count` input argument.
+ new_node->add_input(repeat_node.input(1));
+
+ // Set `output_types` and `output_shapes` attributes.
+ for (auto key : {"output_shapes", "output_types"}) {
+ (*new_node->mutable_attr())[key] = repeat_node.attr().at(key);
+ }
+
+ // Mark the `Shuffle` and `Repeat` nodes for removal.
+ nodes_to_delete.insert(shuffle_node->name());
+ nodes_to_delete.insert(repeat_node.name());
+
+ graph_utils::ReplaceInput(repeat_node, *new_node, &graph);
+ }
+ TF_RETURN_IF_ERROR(graph_utils::DeleteNodes(nodes_to_delete, output));
+ return Status::OK();
+}
+
+void ShuffleAndRepeatFusion::Feedback(Cluster* cluster,
+ const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(ShuffleAndRepeatFusion,
+ "shuffle_and_repeat_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h
new file mode 100644
index 0000000000..c8fa53edce
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class ShuffleAndRepeatFusion : public CustomGraphOptimizer {
+ public:
+ ShuffleAndRepeatFusion() = default;
+ ~ShuffleAndRepeatFusion() override = default;
+
+ string name() const override { return "shuffle_and_repeat_fusion"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_SHUFFLE_AND_REPEAT_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
new file mode 100644
index 0000000000..e89675efb7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ std::vector<std::pair<string, AttrValue>> common_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ common_attrs[1] = std::make_pair("output_types", types_attr);
+
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ common_attrs, graph, &range_node));
+
+ NodeDef *buffer_size_node;
+ TF_ASSERT_OK(
+ graph_utils::AddScalarConstNode<int64>(128, graph, &buffer_size_node));
+ NodeDef *seed_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &seed_node));
+ NodeDef *seed2_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &seed2_node));
+ std::vector<string> shuffle_inputs(4);
+ shuffle_inputs[0] = range_node->name();
+ shuffle_inputs[1] = buffer_size_node->name();
+ shuffle_inputs[2] = seed_node->name();
+ shuffle_inputs[3] = seed2_node->name();
+ NodeDef *shuffle_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "ShuffleDataset", shuffle_inputs,
+ common_attrs, graph, &shuffle_node));
+
+ NodeDef *count_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &count_node));
+ std::vector<string> repeat_inputs(2);
+ repeat_inputs[0] = shuffle_node->name();
+ repeat_inputs[1] = count_node->name();
+ NodeDef *repeat_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RepeatDataset", repeat_inputs,
+ common_attrs, graph, &repeat_node));
+
+ ShuffleAndRepeatFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(shuffle_node->name(), output));
+ EXPECT_FALSE(graph_utils::ContainsNodeWithName(repeat_node->name(), output));
+ EXPECT_TRUE(
+ graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
+ NodeDef shuffle_and_repeat_node = output.node(
+ graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
+ EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
+ EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
+ EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
+ EXPECT_EQ(shuffle_and_repeat_node.input(2), shuffle_node->input(2));
+ EXPECT_EQ(shuffle_and_repeat_node.input(3), shuffle_node->input(3));
+ EXPECT_EQ(shuffle_and_repeat_node.input(4), repeat_node->input(1));
+ EXPECT_TRUE(
+ AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_shapes"),
+ repeat_node->attr().at("output_shapes")));
+ EXPECT_TRUE(
+ AreAttrValuesEqual(shuffle_and_repeat_node.attr().at("output_types"),
+ repeat_node->attr().at("output_types")));
+}
+
+TEST(ShuffleAndRepeatFusionTest, NoChange) {
+ GrapplerItem item;
+ GraphDef *graph = &item.graph;
+
+ std::vector<std::pair<string, AttrValue>> common_attrs(2);
+ AttrValue shapes_attr;
+ SetAttrValue("output_shapes", &shapes_attr);
+ common_attrs[0] = std::make_pair("output_shapes", shapes_attr);
+ AttrValue types_attr;
+ SetAttrValue("output_types", &types_attr);
+ common_attrs[1] = std::make_pair("output_types", types_attr);
+
+ NodeDef *start_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(0, graph, &start_node));
+ NodeDef *stop_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(10, graph, &stop_node));
+ NodeDef *step_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(1, graph, &step_node));
+
+ std::vector<string> range_inputs(3);
+ range_inputs[0] = start_node->name();
+ range_inputs[1] = stop_node->name();
+ range_inputs[2] = step_node->name();
+ NodeDef *range_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RangeDataset", range_inputs,
+ common_attrs, graph, &range_node));
+
+ NodeDef *count_node;
+ TF_ASSERT_OK(graph_utils::AddScalarConstNode<int64>(-1, graph, &count_node));
+ std::vector<string> repeat_inputs(2);
+ repeat_inputs[0] = range_node->name();
+ repeat_inputs[1] = count_node->name();
+ NodeDef *repeat_node;
+ TF_ASSERT_OK(graph_utils::AddNode("", "RepeatDataset", repeat_inputs,
+ common_attrs, graph, &repeat_node));
+
+ ShuffleAndRepeatFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_TRUE(graph_utils::Compare(*graph, output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 200454b522..fdd82b9603 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -65,7 +65,7 @@ void DeleteNodes(const std::set<int>& nodes_to_delete, GraphDef* graph) {
} // namespace
-bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) {
+bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
if (!IsIdentity(node)) {
return true;
}
@@ -108,7 +108,7 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) {
return true;
}
-bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
+bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
if (!fetch_nodes_known_ ||
nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
return false;
@@ -142,6 +142,61 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
return true;
}
+bool DependencyOptimizer::BypassingNodeIsBeneficial(
+ const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
+ const std::vector<NodeDef*>& output_nodes) const {
+ const bool is_identity = IsIdentity(node);
+ const int num_outputs = output_nodes.size();
+ const int num_inputs = node.input_size();
+
+ // Don't increase the number of edges in the graph.
+ if (num_inputs * num_outputs > num_inputs + num_outputs) {
+ return false;
+ }
+
+ // Make sure that we don't increase the number of edges that cross
+ // device boundaries.
+ if ((num_inputs == 1 && num_outputs > 1 &&
+ input_nodes[0]->device() != node.device()) ||
+ (num_inputs > 1 && num_outputs == 1 &&
+ output_nodes[0]->device() != node.device())) {
+ return false;
+ }
+
+ // TODO(rmlarsen): Not all device crossings are equally expensive.
+ // Assign a cost to each based on device affinity and compute a
+ // cost before and after.
+ const string& node_dev = node.device();
+ int num_cross_in = 0;
+ for (NodeDef* input_node : input_nodes) {
+ num_cross_in += static_cast<int>(input_node->device() != node_dev);
+ }
+ int num_cross_out = 0;
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_out += static_cast<int>(output_node->device() != node_dev);
+ }
+ if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
+ // This identity node follows a device crossing, so it might be
+ // following a _Recv node after partioning. Do not remove such nodes,
+ // unless they only have consumers on the same device as themselves.
+ return false;
+ }
+
+ // Make sure we do not increase the number of device crossings.
+ const int num_cross_before = num_cross_in + num_cross_out;
+ int num_cross_after = 0;
+ for (NodeDef* input_node : input_nodes) {
+ for (NodeDef* output_node : output_nodes) {
+ num_cross_after +=
+ static_cast<int>(input_node->device() != output_node->device());
+ }
+ }
+ if (num_cross_after > num_cross_before) {
+ return false;
+ }
+ return true;
+}
+
void DependencyOptimizer::OptimizeNode(int node_idx,
SetVector<int>* nodes_to_simplify,
std::set<int>* nodes_to_delete) {
@@ -205,14 +260,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
}
continue;
}
+ // Replace a normal input with a control input.
const string ctrl_input = ConstantFolding::AddControlDependency(
old_input, optimized_graph_, node_map_.get());
- if (ctrl_inputs.insert(ctrl_input).second) {
- node->set_input(pos, ctrl_input);
- node_map_->UpdateInput(node_name, old_input, ctrl_input);
- const NodeDef* old_input_node = node_map_->GetNode(old_input);
- nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
- }
+ ctrl_inputs.insert(ctrl_input);
+ node->set_input(pos, ctrl_input);
+ node_map_->UpdateInput(node_name, old_input, ctrl_input);
+ const NodeDef* old_input_node = node_map_->GetNode(old_input);
+ nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
++pos;
}
node->set_op("NoOp");
@@ -269,21 +324,11 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
// y --^> | | --^> b /\ +---+
// +----------+ y --^> b
- if (is_noop || is_identity) {
- if (is_identity && !SafeToRemoveIdentity(*node)) {
- return;
- }
-
+ if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) {
const auto& output_node_set = node_map_->GetOutputs(node_name);
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
output_node_set.end());
- const int num_outputs = output_nodes.size();
const int num_inputs = node->input_size();
-
- // Don't increase the number of edges in the graph.
- if (num_inputs * num_outputs > num_inputs + num_outputs) {
- return;
- }
std::vector<NodeDef*> input_nodes;
for (int i = 0; i < num_inputs; ++i) {
NodeDef* input_node = node_map_->GetNode(node->input(i));
@@ -294,44 +339,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
input_nodes.push_back(input_node);
}
- // Make sure that we don't increase the number of edges that cross
- // device boundaries.
- if ((num_inputs == 1 && num_outputs > 1 &&
- input_nodes[0]->device() != node->device()) ||
- (num_inputs > 1 && num_outputs == 1 &&
- output_nodes[0]->device() != node->device())) {
- return;
- }
-
- // TODO(rmlarsen): Not all device crossings are equally expensive.
- // Assign a cost to each based on device affinity and compute a
- // cost before and after.
- const string& node_dev = node->device();
- int num_cross_in = 0;
- for (NodeDef* input_node : input_nodes) {
- num_cross_in += static_cast<int>(input_node->device() != node_dev);
- }
- int num_cross_out = 0;
- for (NodeDef* output_node : output_nodes) {
- num_cross_out += static_cast<int>(output_node->device() != node_dev);
- }
- if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
- // This identity node follows a device crossing, so it might be
- // following a _Recv node after partioning. Do not remove such nodes,
- // unless they only have consumers on the same device as themselves.
- return;
- }
-
- // Make sure we do not increase the number of device crossings.
- const int num_cross_before = num_cross_in + num_cross_out;
- int num_cross_after = 0;
- for (NodeDef* input_node : input_nodes) {
- for (NodeDef* output_node : output_nodes) {
- num_cross_after +=
- static_cast<int>(input_node->device() != output_node->device());
- }
- }
- if (num_cross_after > num_cross_before) {
+ if (!BypassingNodeIsBeneficial(*node, input_nodes, output_nodes)) {
return;
}
@@ -557,6 +565,92 @@ void DependencyOptimizer::BuildNodeToIdx() {
}
}
+// Suppose there are cross-device control inputs to node C from multiple nodes
+// that are located on another device, e.g., we have control edges:
+// A->C, B->C
+// where A and B are on device X and C is on device Y.
+// We can reduce cross-device communication by introducing an intermediate
+// NoOp node C' on device X and rewriting the control edges to:
+// A->C', B->C', C' -> C
+void DependencyOptimizer::GroupCrossDeviceControlEdges() {
+ const int num_nodes = optimized_graph_->node_size();
+ for (int i = 0; i < num_nodes; ++i) {
+ NodeDef* node = optimized_graph_->mutable_node(i);
+ if (node->device().empty()) continue;
+
+ // Creates new noop nodes for devices on which multiple control inputs are
+ // located.
+
+ // Map keyed by device name to the newly introduced Noop node for that
+ // device. A nullptr value means that we have only seen a single node on
+ // that device.
+ std::map<string, NodeDef*> noops;
+ int num_noops = 0;
+ for (int j = 0; j < node->input_size(); ++j) {
+ if (IsControlInput(node->input(j))) {
+ const NodeDef* input = node_map_->GetNode(node->input(j));
+ if (input != nullptr && !input->device().empty() &&
+ input->device() != node->device()) {
+ auto emplace_result = noops.emplace(input->device(), nullptr);
+ if (!emplace_result.second &&
+ emplace_result.first->second == nullptr) {
+ // This is the second cross-device control input from the same
+ // device. Creates an intermediate noop node on that device.
+ string group_name;
+ NodeDef* noop;
+ // Creates a fresh node name; there may be conflicting names from
+ // a previous iteration of the optimizer.
+ do {
+ group_name = AddPrefixToNodeName(
+ node->name(),
+ strings::StrCat("GroupCrossDeviceControlEdges_", num_noops));
+ noop = node_map_->GetNode(group_name);
+ ++num_noops;
+ } while (noop != nullptr);
+ noop = optimized_graph_->add_node();
+ noop->set_name(group_name);
+ noop->set_device(input->device());
+ noop->set_op("NoOp");
+ node_map_->AddNode(noop->name(), noop);
+ emplace_result.first->second = noop;
+ }
+ }
+ }
+ }
+
+ // Reroute existing control edges to go via the newly introduced NoOp nodes.
+ int pos = 0;
+ while (pos < node->input_size()) {
+ const string& input_name = node->input(pos);
+ if (IsControlInput(input_name)) {
+ NodeDef* input = node_map_->GetNode(input_name);
+ if (input == nullptr) {
+ ++pos;
+ } else {
+ auto it = noops.find(input->device());
+ if (it == noops.end() || it->second == nullptr) {
+ ++pos;
+ } else {
+ node->mutable_input()->SwapElements(pos, node->input_size() - 1);
+ node->mutable_input()->RemoveLast();
+ it->second->add_input(AsControlDependency(*input));
+ node_map_->UpdateOutput(input_name, node->name(),
+ it->second->name());
+ }
+ }
+ } else {
+ ++pos;
+ }
+ }
+ for (const auto& entry : noops) {
+ if (entry.second) {
+ node->add_input(AsControlDependency(*entry.second));
+ node_map_->AddOutput(entry.second->name(), node->name());
+ }
+ }
+ }
+}
+
Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
optimized_graph_ = optimized_graph;
@@ -588,6 +682,8 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Dedup control inputs.
CleanControlInputs();
+
+ GroupCrossDeviceControlEdges();
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index b4db98125a..48cfa236af 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -30,7 +30,8 @@ namespace grappler {
class DependencyOptimizer : public GraphOptimizer {
public:
DependencyOptimizer() {}
- explicit DependencyOptimizer(RewriterConfig::Toggle opt_level) {}
+ explicit DependencyOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
~DependencyOptimizer() override {}
string name() const override { return "dependency_optimizer"; };
@@ -42,11 +43,17 @@ class DependencyOptimizer : public GraphOptimizer {
const GraphDef& optimized_graph, double result) override;
private:
+ // Returns true if bypassing node does not increase the number of edges or
+ // number of edges crossing a device boundary.
+ bool BypassingNodeIsBeneficial(
+ const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
+ const std::vector<NodeDef*>& output_nodes) const;
+
// Returns true if node is not an Identity node or if it is an Identity
// that is safe to remove.
- bool SafeToRemoveIdentity(const NodeDef& node);
+ bool SafeToRemoveIdentity(const NodeDef& node) const;
// Returns true if it is safe to convert node to NoOp.
- bool SafeToConvertToNoOp(const NodeDef& node);
+ bool SafeToConvertToNoOp(const NodeDef& node) const;
// Removes all duplicate control dependencies.
void CleanControlInputs();
// Builds a map from the &optimized_graph_->node(i) to i.
@@ -61,7 +68,11 @@ class DependencyOptimizer : public GraphOptimizer {
Status TransitiveReduction();
// Main driver of dependency optimizations.
Status OptimizeDependencies();
+ // Replaces multiple cross-device control edges from the same device with a
+ // single control edge.
+ void GroupCrossDeviceControlEdges();
+ RewriterConfig::Toggle opt_level_;
bool fetch_nodes_known_;
std::unordered_set<string> nodes_to_preserve_;
std::unique_ptr<NodeMap> node_map_;
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 6a297da52d..c0f07562af 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -29,7 +31,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-class DependencyOptimizerTest : public ::testing::Test {};
+class DependencyOptimizerTest : public GrapplerTest {};
void VerifyGraphsEqual(const GraphDef& original_graph,
const GraphDef& optimized_graph, const string& func) {
@@ -122,25 +124,62 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
for (int i = 0; i < item.graph.node_size(); ++i) {
const NodeDef& node = item.graph.node(i);
- if (node.name() == "add") {
- EXPECT_EQ("NoOp", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^y", node.input(1));
- } else if (node.name() == "id1") {
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
} else if (node.name() == "id2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
+ }
+ }
+ EXPECT_EQ(2, found);
+}
+
+TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+ Output add = ops::Add(s.WithOpName("add"), x, x);
+ Output id1 =
+ ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"id1"};
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ const NodeDef& node = item.graph.node(i);
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(1, found);
}
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
@@ -398,6 +437,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id_a", node.name());
EXPECT_NE("id_b", node.name());
@@ -405,30 +445,36 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
if (node.name() == "a_a" || node.name() == "a_b") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("x", node.input(0));
+ ++found;
}
if (node.name() == "a_c" || node.name() == "a_d") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
}
if (node.name() == "b_a") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
EXPECT_EQ("^z", node.input(2));
+ ++found;
}
if (node.name() == "c_a") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
}
if (node.name() == "c_b") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 7);
}
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
@@ -458,17 +504,20 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id0", node.name());
if (node.name() == "or0") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("switch:1", node.input(1));
+ ++found;
}
if (node.name() == "or1") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("y", node.input(1));
+ ++found;
}
if (node.name() == "or2") {
// or1 should be unchanged.
@@ -476,8 +525,10 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("y", node.input(1));
EXPECT_EQ("^id1", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
}
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
@@ -533,6 +584,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
+ bool found = false;
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
// "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1",
@@ -543,8 +595,10 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
+ found = true;
}
}
+ EXPECT_TRUE(found);
}
TEST_F(DependencyOptimizerTest, IdentityInputs) {
@@ -722,6 +776,68 @@ TEST_F(DependencyOptimizerTest, RemoveGreaterEqualWithNoOp) {
EXPECT_EQ(3, count);
}
+TEST_F(DependencyOptimizerTest, GroupCrossDeviceControlDeps) {
+ GrapplerItem item;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
+ {1, 2}, DT_FLOAT);
+ Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
+ {1, 2}, DT_FLOAT);
+ Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
+ {1, 2}, DT_FLOAT);
+ Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
+ {1, 2}, DT_FLOAT);
+ Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
+ {1, 2}, DT_FLOAT);
+ // Node with cross-device dependencies.
+ auto fetch = ops::Identity(
+ s.WithOpName("f")
+ .WithControlDependencies({a.op(), b.op(), c.op(), d.op()})
+ .WithDevice("/GPU:0"),
+ {e});
+
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("f");
+ }
+
+ GraphDef expected;
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::RandomUniform(s.WithOpName("a").WithDevice("/CPU:1"),
+ {1, 2}, DT_FLOAT);
+ Output b = ops::RandomUniform(s.WithOpName("b").WithDevice("/CPU:2"),
+ {1, 2}, DT_FLOAT);
+ Output c = ops::RandomUniform(s.WithOpName("c").WithDevice("/CPU:1"),
+ {1, 2}, DT_FLOAT);
+ Output d = ops::RandomUniform(s.WithOpName("d").WithDevice("/CPU:3"),
+ {1, 2}, DT_FLOAT);
+ Output e = ops::RandomUniform(s.WithOpName("e").WithDevice("/CPU:0"),
+ {1, 2}, DT_FLOAT);
+ auto noop = ops::NoOp(s.WithOpName("GroupCrossDeviceControlEdges_0/f")
+ .WithDevice("/CPU:1")
+ .WithControlDependencies({a.op(), c.op()}));
+ auto fetch =
+ ops::Identity(s.WithOpName("f")
+ .WithControlDependencies({b.op(), d.op(), noop})
+ .WithDevice("/GPU:0"),
+ {e});
+
+ TF_CHECK_OK(s.ToGraphDef(&expected));
+ }
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ CompareGraphs(expected, output);
+
+ // Run the optimizer again to verify idempotence.
+ item.graph.Swap(&output);
+ output.Clear();
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ CompareGraphs(expected, output);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index b0d689c2dd..645e4c2087 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -629,9 +629,12 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
}
}
- // Add the node name as a prefix to avoid collisions after inlining.
- func_body_node.set_name(
- strings::StrCat(func_node.name(), "/", func_body_node.name()));
+ // Add the function node name as a prefix 1) to node name to avoid
+ // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
+ // frame after inlining.
+ const string prefix = strings::StrCat(func_node.name(), "/");
+ TF_RETURN_IF_ERROR(
+ AddPrefixAndSuffixToNode(prefix, "" /* suffix */, &func_body_node));
// Make sure the node is placed.
func_body_node.set_device(func_node.device());
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index d043f6129d..fab3f994c1 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -207,6 +208,12 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) {
// Nodes
{
{{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_FLOAT}}},
+ // "enter" node is used to verify that InlineFunction would update the
+ // frame name accordingly.
+ {{"enter"},
+ "Enter",
+ {"x"},
+ {{"T", DT_FLOAT}, {"frame_name", "frame"}}},
{{"y"}, "Mul", {"x", "two"}, {{"T", DT_FLOAT}}},
});
@@ -263,9 +270,14 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) {
EXPECT_EQ(kDevice, node.device());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("y", node.input(0));
+ } else if (node.name() == "y/enter") {
+ count++;
+ EXPECT_TRUE(IsEnter(node));
+ const string frame_name = node.attr().at("frame_name").s();
+ EXPECT_EQ("y/frame", frame_name);
}
}
- EXPECT_EQ(6, count);
+ EXPECT_EQ(7, count);
Tensor pi = test::AsScalar<float>(3.14f);
item.fetch = {"z"};
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 2fbdd76a77..2afb5df431 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -44,16 +45,19 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
- GraphProperties* graph_properties, NodeMap* node_map)
+ GraphProperties* graph_properties, NodeMap* node_map,
+ RewriterConfig::Toggle opt_level)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
- node_map(node_map) {}
+ node_map(node_map),
+ opt_level(opt_level) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
+ RewriterConfig::Toggle opt_level;
};
Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 3f5ab87a5a..34f28c7c27 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -59,7 +60,8 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ nullptr,
/*graph_properties*/ nullptr,
- /*node_name*/ nullptr);
+ /*node_name*/ nullptr,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,7 +96,8 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -133,7 +136,8 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index e08ab1eb67..3251e7cb10 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -499,6 +499,7 @@ class NodeProcessor : public GraphProcessor {
UpdateAttrDataFormat();
UpdateAttrKSize();
UpdateAttrStrides();
+ UpdateAttrDilations();
UpdateAttrShape();
TF_RETURN_IF_ERROR(AddLayoutTransposeToInputs());
TF_RETURN_IF_ERROR(AddLayoutTransposeToOutputs());
@@ -742,6 +743,13 @@ class NodeProcessor : public GraphProcessor {
}
}
+ void UpdateAttrDilations() {
+ if (node_->attr().find("dilations") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("dilations").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
void UpdateAttrDataFormat() {
if (node_->attr().find("data_format") != node_->attr().end()) {
if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index dad49cd74f..20e47c1b26 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -87,12 +87,13 @@ class LayoutOptimizerTest : public GrapplerTest {
Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
int filter_size, const string& padding) {
- return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true);
+ return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true,
+ true);
}
Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
int filter_size, const string& padding,
- bool const_input_size) {
+ bool const_input_size, bool dilated) {
int batch_size = 128;
int input_height = input_size;
int input_width = input_size;
@@ -123,14 +124,18 @@ class LayoutOptimizerTest : public GrapplerTest {
Output conv_backprop_input;
Output input_sizes_i =
ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
+ ops::Conv2DBackpropInput::Attrs attrs;
+ if (dilated) {
+ attrs = attrs.Dilations({1, 2, 2, 1});
+ }
if (const_input_size) {
conv_backprop_input = ops::Conv2DBackpropInput(
s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
- {1, stride, stride, 1}, padding);
+ {1, stride, stride, 1}, padding, attrs);
} else {
conv_backprop_input = ops::Conv2DBackpropInput(
s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
- {1, stride, stride, 1}, padding);
+ {1, stride, stride, 1}, padding, attrs);
}
return conv_backprop_input;
}
@@ -216,7 +221,7 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false);
+ auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false, false);
Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e6622486eb..c55f479451 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -42,6 +42,7 @@ namespace grappler {
namespace {
constexpr int kDefaultNumberOfIterations = 2;
+constexpr int kDefaultMinGraphNodes = 4;
int64 NumEdges(const GraphDef& graph) {
int64 num_edges = 0;
@@ -90,7 +91,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
MK_OPT("debug_stripper", new DebugStripper());
MK_OPT("scoped_allocator",
- new ScopedAllocatorOptimizer(cfg_.scoped_allocator_opts()));
+ new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
+ cfg_.scoped_allocator_opts()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -149,8 +151,8 @@ Status MetaOptimizer::InitializeOptimizers(
new AutoParallel(cfg_.auto_parallel().num_replicas()));
}
if (cfg_.scoped_allocator_optimization()) {
- optimizers->emplace_back(
- new ScopedAllocatorOptimizer(cfg_.scoped_allocator_opts()));
+ optimizers->emplace_back(new ScopedAllocatorOptimizer(
+ cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
return Status::OK();
}
@@ -194,6 +196,15 @@ Status MetaOptimizer::InitializeOptimizersByName(
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
+ : cfg_.min_graph_nodes();
+ if (item.graph.node_size() < min_graph_nodes) {
+ VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
+ << " nodes.";
+ *optimized_graph = item.graph;
+ return Status::OK();
+ }
+
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
@@ -202,10 +213,11 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
- << " num_optimizers=" << optimizers.size();
+ << " num_optimizers=" << optimizers.size()
+ << ", num nodes = " << item.graph.node_size();
if (optimizers.empty()) {
- VLOG(3) << "Skip graph optimization, no optimizers registered";
+ VLOG(3) << "Skipping graph optimization, no optimizers registered";
*optimized_graph = item.graph;
return Status::OK();
}
@@ -217,61 +229,56 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
bool is_optimized = false;
GraphOptimizationResult optimization_result(item.id);
+ GraphOptimizer* fusion_optimizer = nullptr;
+ GraphOptimizer* sa_optimizer = nullptr;
- // ScopedAllocatorOptimizer must run last, so move it to the
- // end of optimizers and run only on the last iteration.
- {
- int sa_index = 0;
- for (; sa_index < optimizers.size(); ++sa_index) {
- if (optimizers[sa_index]->name() == "scoped_allocator_optimizer") {
- break;
- }
- }
- const int last_index = optimizers.size() - 1;
- if (sa_index < last_index) {
- optimizers[last_index].swap(optimizers[sa_index]);
- }
- }
-
- const int last_iteration = NumIterations(cfg_) - 1;
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
- VLOG(4) << "Starting optimization iteration " << iteration + 1;
+ // Don't bother optimizing further if the graph is already tiny.
+ if (optimized_graph->node_size() < min_graph_nodes) {
+ VLOG(3) << "Stopping after iteration " << iteration
+ << ", graph is tiny (#nodes = " << optimized_graph->node_size()
+ << " < " << min_graph_nodes << ")";
+ break;
+ }
+ VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
- if (optimizer->name() == "scoped_allocator_optimizer" &&
- iteration != last_iteration)
+ if (optimizer->name() == "scoped_allocator_optimizer") {
+ if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
continue;
-
- uint64 start_us = Env::Default()->NowMicros();
- // This swaps the current optimized_graph into optimized item and
- // resets optimized_graph to an empty graph.
- optimized_graph->Swap(&optimized_item.graph);
- *optimized_graph = GraphDef();
- Status status =
- optimizer->Optimize(cluster, optimized_item, optimized_graph);
- uint64 end_us = Env::Default()->NowMicros();
-
- string result;
- if (!status.ok()) {
- optimized_graph->Swap(&optimized_item.graph);
- result = status.ToString();
- } else {
- is_optimized = true;
- float duration_ms = (end_us - start_us) / 1000.0f;
- result = strings::StrCat(
- PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
- ", time = ", duration_ms, "ms.");
}
- VLOG(4) << optimizer->name() << ": " << result;
-
- OptimizerResult optimizer_result{optimizer->name(), result};
- optimization_result.results.push_back(optimizer_result);
+ if (optimizer->name() == "xla-fusion") {
+ if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
+ continue;
+ }
+ Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
}
}
+ // Run fusion optimizer if requested after all other optimizers since: 1) it
+ // doesn't need to be called more than once. 2) we don't want subsequent
+ // optimization passes to break the fusion clusters. We could potentially
+ // encapsulate the fusion clusters right away, but that will prevent a lot of
+ // optimizations from taking place since we don't have shape inference for
+ // functions, and we can't optimize across function boundaries.
+ if (fusion_optimizer != nullptr) {
+ Status status = RunOptimizer(fusion_optimizer, cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
+ }
+
+ // ScopedAllocatorOptimizer must run last.
+ if (sa_optimizer != nullptr) {
+ Status status = RunOptimizer(sa_optimizer, cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
+ }
+
// Record graph optimization result.
optimization_results_.push_back(optimization_result);
@@ -286,6 +293,35 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
return Status::OK();
}
+Status MetaOptimizer::RunOptimizer(
+ GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
+ GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
+ uint64 start_us = Env::Default()->NowMicros();
+ // This swaps the current optimized_graph into optimized item and
+ // resets optimized_graph to an empty graph.
+ optimized_graph->Swap(&optimized_item->graph);
+ *optimized_graph = GraphDef();
+ Status status =
+ optimizer->Optimize(cluster, *optimized_item, optimized_graph);
+ uint64 end_us = Env::Default()->NowMicros();
+
+ string result;
+ if (!status.ok()) {
+ optimized_graph->Swap(&optimized_item->graph);
+ result = status.ToString();
+ } else {
+ float duration_ms = (end_us - start_us) / 1000.0f;
+ result = strings::StrCat(
+ PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
+ ", time = ", duration_ms, "ms.");
+ }
+ VLOG(1) << optimizer->name() << ": " << result;
+
+ OptimizerResult optimizer_result{optimizer->name(), result};
+ optimization_result->results.push_back(optimizer_result);
+ return status;
+}
+
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
optimization_results_.clear();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index e736dd174e..151a54cbdf 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -72,6 +72,10 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<OptimizerResult> results;
};
+ Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster,
+ GrapplerItem* optimized_item, GraphDef* optimized_graph,
+ GraphOptimizationResult* optimization_result);
+
std::vector<GraphOptimizationResult> optimization_results_;
};
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 8247cce339..9a03c7dfef 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -74,6 +74,7 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
TestOptimizer::SetOptimized(false);
RewriterConfig rewriter_config;
rewriter_config.add_optimizers("TestOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -89,6 +90,7 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
RewriterConfig rewriter_config;
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -104,6 +106,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
rewriter_config.set_function_optimization(RewriterConfig::ON);
rewriter_config.add_optimizers("function");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index efd870b118..03e36a7b9c 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace grappler {
@@ -200,8 +201,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
}
}
if (optimizable) {
- std::cout << "Optimizing fused batch norm node " << node.DebugString()
- << std::endl;
+ VLOG(1) << "Optimizing fused batch norm node " << node.DebugString();
AddBatchNormNodes(optimized_graph, node);
continue;
}
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index cceef4098d..275568e464 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -650,7 +650,8 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
};
ScopedAllocatorOptimizer::ScopedAllocatorOptimizer(
- const ScopedAllocatorOptions& opts) {
+ RewriterConfig::Toggle opt_level, const ScopedAllocatorOptions& opts)
+ : opt_level_(opt_level) {
VLOG(1) << "ScopedAllocatorOptimizer::ScopedAllocatorOptimizer";
Rewriter* r = new UnaryElementwiseRewriter();
to_delete_.push_back(r);
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
index ab4d444595..13589f536c 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h
@@ -32,7 +32,8 @@ class ScopedAllocatorOptimizer;
// movement and consolidate some kinds of Ops.
class ScopedAllocatorOptimizer : public GraphOptimizer {
public:
- explicit ScopedAllocatorOptimizer(const ScopedAllocatorOptions& opts);
+ ScopedAllocatorOptimizer(RewriterConfig::Toggle opt_level,
+ const ScopedAllocatorOptions& opts);
~ScopedAllocatorOptimizer() override;
string name() const override { return "scoped_allocator_optimizer"; }
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
index 3a2859dc5f..89847f83d4 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -115,7 +115,7 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
ScopedAllocatorOptions opts;
opts.add_enable_op("Abs");
- ScopedAllocatorOptimizer sao(opts);
+ ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
ScopedAllocatorOptimizer::OpNameSet ons;
ons.insert("Abs");
@@ -199,7 +199,7 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryExecute) {
// b + c == -4, -4, 3, 2
for (int oi = 0; oi < outputs.size(); ++oi) {
for (int i = 0; i < outputs[oi].NumElements(); ++i) {
- VLOG(0) << "output vec " << oi << " index " << i << " = "
+ VLOG(1) << "output vec " << oi << " index " << i << " = "
<< outputs[oi].flat<float>()(i);
}
if (oi == 0) {
diff --git a/tensorflow/core/grappler/utils/scc.cc b/tensorflow/core/grappler/utils/scc.cc
index f2a6507d94..d033e9c522 100644
--- a/tensorflow/core/grappler/utils/scc.cc
+++ b/tensorflow/core/grappler/utils/scc.cc
@@ -142,9 +142,13 @@ void StronglyConnectedComponents(
// Create a list of top-level parents (add them to object queue)
// Also create a mapping from nodes to their children.
+ // Inputs might not be present if called on a subgraph.
for (const NodeDef& node : graph.node()) {
for (const string& input : node.input()) {
- name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]);
+ auto it = name_to_data.find(NodeName(input));
+ if (it != name_to_data.end()) {
+ it->second->children.push_back(node_to_data[&node]);
+ }
}
}
@@ -202,10 +206,12 @@ int IdentifyLoops(const GraphDef& graph,
const std::vector<const NodeDef*>& component_nodes = component.second;
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
GraphDef subgraph;
+ std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
for (const auto& component_node : component_nodes) {
NodeDef* node = subgraph.add_node();
*node = *component_node;
+ subgraph_mapping[node] = component_node;
if (IsNextIteration(*node)) {
CHECK_EQ(1, node->input_size());
next_iter_nodes.emplace_back(node, node->input(0));
@@ -227,13 +233,13 @@ int IdentifyLoops(const GraphDef& graph,
int num_components = 0;
std::unordered_map<const NodeDef*, int> components;
StronglyConnectedComponents(subgraph, &components, &num_components);
- CHECK_EQ(1, num_components);
+ CHECK_GE(num_components, 1);
for (const auto it : components) {
int id = it.second;
if (id < 0) {
continue;
}
- (*loops)[it.first].push_back(loop_id);
+ (*loops)[subgraph_mapping[it.first]].push_back(loop_id);
}
++loop_id;
}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 5948f8d39f..3426ea8aa2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -368,6 +368,7 @@ cc_library(
cc_library(
name = "queue_op",
+ srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -881,7 +882,6 @@ tf_kernel_library(
"tile_functor_gpu.cu.cc",
],
prefix = "tile_ops",
- textual_hdrs = ["tile_ops_gpu_impl.h"],
deps = ARRAY_DEPS,
)
@@ -1885,9 +1885,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":queue_base",
+ ":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -2085,6 +2086,7 @@ IMAGE_DEPS = [
"//tensorflow/core:jpeg_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:png_internal",
"//tensorflow/core:protos_all_cc",
]
@@ -2659,7 +2661,7 @@ tf_kernel_library(
tf_kernel_library(
name = "summary_image_op",
prefix = "summary_image_op",
- deps = LOGGING_DEPS,
+ deps = LOGGING_DEPS + ["//tensorflow/core:png_internal"],
)
tf_kernel_library(
@@ -2812,6 +2814,9 @@ tf_kernel_library(
srcs = [] + if_mkl([
"mkl_batch_matmul_op.cc",
]),
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "batch_matmul_op",
deps = MATH_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
@@ -2879,6 +2884,9 @@ tf_kernel_library(
"mkl_matmul_op.cc",
]),
hdrs = ["matmul_op.h"],
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
defines = select({
":xsmm": [
"TENSORFLOW_USE_LIBXSMM",
@@ -2928,6 +2936,15 @@ tf_kernel_library(
deps = MATH_DEPS,
)
+tf_kernel_library(
+ name = "unary_ops_composition",
+ prefix = "unary_ops_composition",
+ deps = MATH_DEPS + [
+ ":cwise_op",
+ ":relu_op",
+ ],
+)
+
tf_cc_test(
name = "sequence_ops_test",
size = "small",
@@ -3027,6 +3044,28 @@ tf_cuda_cc_test(
)
tf_cuda_cc_test(
+ name = "unary_ops_composition_test",
+ size = "small",
+ srcs = ["unary_ops_composition_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ ":unary_ops_composition",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cuda_cc_test(
name = "matmul_op_test",
size = "small",
srcs = ["matmul_op_test.cc"],
@@ -3248,8 +3287,7 @@ tf_kernel_library(
"//conditions:default": [],
}),
# Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
- # So that it doesn't take 20 minutes to compile conv_grad_ops_3d.cc and conv_ops_3d.cc
- # on Windows. See https://github.com/tensorflow/tensorflow/issues/10521
+ # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
defines = select({
":xsmm_convolutions": [
@@ -3300,7 +3338,7 @@ tf_kernel_library(
"//tensorflow/core:nn_ops_op_lib",
] + if_cuda([
"@cub_archive//:cub",
- "@local_config_cuda//cuda:cudnn",
+ "@local_config_cuda//cuda:cudnn_header",
]),
)
@@ -3319,7 +3357,7 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
] + if_cuda([
- "@local_config_cuda//cuda:cudnn",
+ "@local_config_cuda//cuda:cudnn_header",
]),
)
@@ -3347,6 +3385,14 @@ cc_library(
],
)
+# Kernels for the nodes intented to be added to the graph by the Grappler optimizers.
+cc_library(
+ name = "grappler",
+ deps = [
+ ":unary_ops_composition",
+ ],
+)
+
NN_DEPS = [
":bounds_check",
":conv_2d",
@@ -3376,7 +3422,10 @@ tf_kernel_library(
tf_kernel_library(
name = "bias_op",
prefix = "bias_op",
- deps = NN_DEPS,
+ deps = NN_DEPS + if_cuda([
+ ":reduction_ops",
+ "@cub_archive//:cub",
+ ]),
)
tf_kernel_library(
@@ -3395,6 +3444,9 @@ tf_kernel_library(
tf_kernel_library(
name = "lrn_op",
+ # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
+ # to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
+ copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "lrn_op",
deps = NN_DEPS,
)
@@ -3877,6 +3929,8 @@ tf_cc_test(
cc_library(
name = "sparse",
deps = [
+ ":deserialize_sparse_string_op",
+ ":deserialize_sparse_variant_op",
":serialize_sparse_op",
":sparse_add_grad_op",
":sparse_add_op",
@@ -3887,6 +3941,7 @@ cc_library(
":sparse_reduce_op",
":sparse_reorder_op",
":sparse_reshape_op",
+ ":sparse_slice_grad_op",
":sparse_slice_op",
":sparse_softmax",
":sparse_sparse_binary_op_shared",
@@ -3973,6 +4028,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "sparse_slice_grad_op",
+ prefix = "sparse_slice_grad_op",
+ deps = SPARSE_DEPS,
+)
+
+tf_kernel_library(
name = "sparse_slice_op",
prefix = "sparse_slice_op",
deps = SPARSE_DEPS,
@@ -4024,6 +4085,23 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "deserialize_sparse_string_op",
+ prefix = "deserialize_sparse_string_op",
+ deps = SPARSE_DEPS + [
+ ":reshape_util",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "deserialize_sparse_variant_op",
+ prefix = "deserialize_sparse_variant_op",
+ deps = SPARSE_DEPS + [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "sparse_tensors_map_ops",
prefix = "sparse_tensors_map_ops",
deps = SPARSE_DEPS,
@@ -5034,6 +5112,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
+ "queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",
@@ -6170,7 +6249,7 @@ cc_library(
tf_kernel_library(
name = "dataset_ops",
deps = [
- "//tensorflow/core/kernels/data:dataset_ops",
+ "//tensorflow/core/kernels/data",
],
)
diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc
index 66c4aff3e3..a7757d1361 100644
--- a/tensorflow/core/kernels/as_string_op.cc
+++ b/tensorflow/core/kernels/as_string_op.cc
@@ -73,6 +73,7 @@ class AsStringOp : public OpKernel {
}
switch (dtype) {
case DT_INT8:
+ case DT_INT16:
case DT_INT32:
strings::Appendf(&format_, "d");
break;
@@ -129,6 +130,7 @@ class AsStringOp : public OpKernel {
ENCODE_TYPE(DT_FLOAT, float, format_);
ENCODE_TYPE(DT_DOUBLE, double, format_);
ENCODE_TYPE(DT_INT8, int8, format_);
+ ENCODE_TYPE(DT_INT16, int16, format_);
case (DT_BOOL): {
const auto& input_flat = input_tensor->flat<bool>();
for (int i = 0; i < input_flat.size(); ++i) {
diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc
index 8c99ded0a8..35ddda0ec0 100644
--- a/tensorflow/core/kernels/batch_kernels.cc
+++ b/tensorflow/core/kernels/batch_kernels.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,7 +43,7 @@ typedef Eigen::SyclDevice SYCLDevice;
// ensure proper device placement.
template <typename T>
Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
- int output_index) {
+ Tensor* output) {
const int input_dims = inputs[0].dims();
const TensorShape& input_shape = inputs[0].shape();
@@ -76,9 +78,8 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
TensorShape output_shape(input_shape);
output_shape.set_dim(0, output_dim0);
- Tensor* output = nullptr;
TF_RETURN_IF_ERROR(
- context->allocate_output(output_index, output_shape, &output));
+ context->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
if (output->NumElements() > 0) {
auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
#if GOOGLE_CUDA
@@ -209,6 +210,7 @@ class BatchResource : public ResourceBase {
static Status Create(int32 num_batch_threads, int32 max_batch_size,
int32 batch_timeout_micros, int32 max_enqueued_batches,
const std::vector<int32>& allowed_batch_sizes,
+ FunctionLibraryRuntime::Handle fhandle,
std::unique_ptr<BatchResource>* resource) {
std::unique_ptr<BatchResource> new_resource(new BatchResource);
@@ -225,6 +227,8 @@ class BatchResource : public ResourceBase {
new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
+ new_resource->fhandle_ = fhandle;
+
*resource = std::move(new_resource);
return Status::OK();
}
@@ -254,6 +258,14 @@ class BatchResource : public ResourceBase {
}
batch_components->inputs.push_back(tensor);
}
+ OpInputList captured_tensors;
+ const auto captured_status =
+ context->input_list("captured_tensors", &captured_tensors);
+ if (captured_status.ok()) {
+ for (const Tensor& captured_tensor : captured_tensors) {
+ batch_components->captured_inputs.push_back(captured_tensor);
+ }
+ }
batch_components->context = context;
batch_components->done_callback = std::move(done_callback);
@@ -272,6 +284,7 @@ class BatchResource : public ResourceBase {
int64 guid;
std::vector<Tensor> inputs;
+ std::vector<Tensor> captured_inputs;
OpKernelContext* context;
AsyncOpKernel::DoneCallback done_callback;
@@ -314,50 +327,32 @@ class BatchResource : public ResourceBase {
return batch_size;
}
- // Processes a batch of one or more BatchTask entries.
- void ProcessBatch(std::unique_ptr<Batch> batch) const {
- if (batch->empty()) {
- return;
+ Status ConcatInputTensors(const Batch& batch, OpKernelContext* context,
+ std::vector<Tensor>* concatenated_tensors) const {
+ if (batch.num_tasks() == 0) {
+ return errors::InvalidArgument("Empty batch.");
}
- const int padded_batch_size = RoundToLowestAllowedBatchSize(batch->size());
- const int padding_amount = padded_batch_size - batch->size();
- OpKernelContext* last_task_context =
- batch->task(batch->num_tasks() - 1).context;
- AsyncOpKernel::DoneCallback last_task_callback =
- batch->task(batch->num_tasks() - 1).done_callback;
-
- OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
- last_task_callback);
+ const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
+ const int padding_amount = padded_batch_size - batch.size();
// All tasks should have the same number of input edges.
- const int num_input_edges = batch->task(0).inputs.size();
-
- // Process each input edge one at a time (the typical case has just one).
- for (int i = 0; i < num_input_edges; ++i) {
- // Emit batch->num_tasks() - 1 empty output tensors.
- for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
- const BatchTask& task = batch->task(task_idx);
- TensorShape output_shape(task.inputs.at(i).shape());
- output_shape.set_dim(0, 0);
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- task.context,
- task.context->allocate_output(i, output_shape, &output),
- task.done_callback);
- }
+ const int num_inputs = batch.task(0).inputs.size();
+ concatenated_tensors->reserve(num_inputs);
+ // Process each input one at a time (the typical case has just one).
+ for (int i = 0; i < num_inputs; ++i) {
// Concatenate the tasks ith input tensors into a big output tensor.
std::vector<Tensor> to_concatenate;
- to_concatenate.reserve(batch->num_tasks());
- for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
- to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
+ to_concatenate.reserve(batch.num_tasks());
+ for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
+ to_concatenate.push_back(batch.task(task_idx).inputs.at(i));
}
// Add padding as needed. Use the first row of the first task's tensor as
// the data for padding.
if (padding_amount > 0) {
- const Tensor& padding_source = batch->task(0).inputs.at(i);
+ const Tensor& padding_source = batch.task(0).inputs.at(i);
Tensor padding;
if (padding_source.shape().dim_size(0) == 1) {
padding = padding_source;
@@ -367,10 +362,10 @@ class BatchResource : public ResourceBase {
Status slice_status;
std::vector<Tensor> slices;
switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- slice_status = SplitCPU<type>(last_task_context, padding_source, \
- slice_sizes, &slices); \
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ slice_status = \
+ SplitCPU<type>(context, padding_source, slice_sizes, &slices); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
@@ -379,8 +374,7 @@ class BatchResource : public ResourceBase {
errors::InvalidArgument("Unsupported data type: ", type);
break;
}
- OP_REQUIRES_OK_ASYNC(last_task_context, slice_status,
- last_task_callback);
+ TF_RETURN_IF_ERROR(slice_status);
padding = slices.at(0);
}
for (int i = 0; i < padding_amount; ++i) {
@@ -390,10 +384,12 @@ class BatchResource : public ResourceBase {
const DataType type = to_concatenate[0].dtype();
Status concat_status;
+ Tensor concatenated_tensor;
switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- concat_status = Concat<type>(last_task_context, to_concatenate, i); \
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ concat_status = \
+ Concat<type>(context, to_concatenate, &concatenated_tensor); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
@@ -402,10 +398,197 @@ class BatchResource : public ResourceBase {
errors::InvalidArgument("Unsupported data type: ", type);
break;
}
- OP_REQUIRES_OK_ASYNC(last_task_context, concat_status,
- last_task_callback);
+ TF_RETURN_IF_ERROR(concat_status);
+ concatenated_tensors->push_back(concatenated_tensor);
+ }
+ return Status::OK();
+ }
+
+ Status SplitOutputTensors(const std::vector<Tensor>& combined_outputs,
+ Batch* batch) const {
+ DCHECK_GE(batch->num_tasks(), 1);
+ if (batch->num_tasks() < 1) {
+ return errors::Internal("Batch size expected to be positive; was ",
+ batch->num_tasks());
+ }
+
+ std::vector<int64> task_sizes_plus_optional_padding;
+ task_sizes_plus_optional_padding.reserve(batch->num_tasks());
+ for (int i = 0; i < batch->num_tasks(); ++i) {
+ task_sizes_plus_optional_padding.push_back(batch->task(i).size());
+ }
+ const int padding_size =
+ RoundToLowestAllowedBatchSize(batch->size()) - batch->size();
+ if (padding_size > 0) {
+ task_sizes_plus_optional_padding.push_back(padding_size);
+ }
+
+ // For each output tensor name, a divided-up tensor with one entry per task.
+ std::map<string, std::vector<Tensor>> split_tensors;
+
+ DCHECK_EQ(batch->task(0).context->num_outputs(), combined_outputs.size());
+ if (combined_outputs.size() != batch->task(0).context->num_outputs()) {
+ return errors::Internal("Wrong number of batched output tensors");
+ }
+
+ // Generate 'split_tensors' and populate the context outputs.
+ for (int i = 0; i < combined_outputs.size(); ++i) {
+ const Tensor& output_tensor = combined_outputs[i];
+ if (output_tensor.shape().dims() == 0) {
+ return errors::FailedPrecondition(
+ "Batched output tensor has 0 dimensions");
+ }
+ if (output_tensor.shape().dim_size(0) != batch->size() + padding_size) {
+ return errors::FailedPrecondition(
+ "Batched output tensor's 0th dimension does not equal the sum of "
+ "the 0th dimension sizes of the input tensors");
+ }
+
+ std::vector<Tensor> split_tensor;
+ const Status split_status = tensor::Split(
+ output_tensor, task_sizes_plus_optional_padding, &split_tensor);
+ DCHECK(split_status.ok()) << split_status.ToString();
+ if (!split_status.ok()) {
+ return errors::Internal("Tensor split operation failed: ",
+ split_status.ToString());
+ }
+ DCHECK_EQ(split_tensor.size(), task_sizes_plus_optional_padding.size());
+ if (split_tensor.size() != task_sizes_plus_optional_padding.size()) {
+ return errors::Internal(
+ "Tensor split operation did not work as expected; got ",
+ split_tensor.size(), " splits; expected ",
+ task_sizes_plus_optional_padding.size());
+ }
+
+ for (int j = 0; j < batch->num_tasks(); ++j) {
+ BatchTask& task = *(batch->mutable_task(j));
+ task.context->set_output(i, split_tensor.at(j));
+ } // (Ignore a possible final split_tensors entry containing the
+ // padding.)
+ }
+
+ return Status::OK();
+ }
+
+ void ProcessFuncBatch(std::unique_ptr<Batch> batch) const {
+ if (batch->empty()) {
+ return;
+ }
+
+ OpKernelContext* last_task_context =
+ batch->task(batch->num_tasks() - 1).context;
+
+ // Regardless of the outcome, we need to propagate the status to the
+ // individual tasks and signal that they are done. We use MakeCleanup() to
+ // ensure that this happens no matter how we exit the method below.
+ Status status;
+ bool cleanup_done = false;
+ auto cleanup_fn = [&cleanup_done, &batch](const Status& status) {
+ if (cleanup_done) {
+ return;
+ }
+ for (int i = 0; i < batch->num_tasks(); ++i) {
+ batch->mutable_task(i)->context->SetStatus(status);
+ batch->mutable_task(i)->done_callback();
+ }
+ cleanup_done = true;
+ };
+ auto finally =
+ gtl::MakeCleanup([&cleanup_fn, &status] { cleanup_fn(status); });
+
+ status = ValidateBatch(*batch);
+ if (!status.ok()) {
+ return;
}
+ std::vector<Tensor> concatenated_tensors;
+ status =
+ ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
+ if (!status.ok()) {
+ return;
+ }
+ FunctionLibraryRuntime::Options opts;
+ opts.step_id = last_task_context->step_id();
+ opts.step_container = last_task_context->step_container();
+ opts.cancellation_manager = last_task_context->cancellation_manager();
+ opts.stats_collector = last_task_context->stats_collector();
+ opts.rendezvous = last_task_context->rendezvous();
+ opts.runner = last_task_context->runner();
+
+ auto* flib = last_task_context->function_library();
+ std::vector<Tensor> combined_outputs;
+ Notification done;
+ std::vector<Tensor> args(concatenated_tensors.begin(),
+ concatenated_tensors.end());
+ const auto& captured_inputs =
+ batch->task(batch->num_tasks() - 1).captured_inputs;
+ args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
+
+ // Releases the cleanup method here, because the callback of the function
+ // library runtime will handle it now.
+ finally.release();
+ flib->Run(
+ opts, fhandle_, args, &combined_outputs, [&](const Status& run_status) {
+ Status final_status;
+ auto run_finally = gtl::MakeCleanup([&]() {
+ // We do the cleanup here as an optimization, so that it runs in
+ // the underlying TF inter-op threadpool. Running it in the
+ // threadpool, let's the ensuing ops be scheduled faster,
+ // because the executor will add them to the front of the
+ // threadpool's task queue rather than the end.
+ cleanup_fn(final_status);
+ done.Notify();
+ });
+ final_status = run_status;
+ if (!final_status.ok()) {
+ return;
+ }
+ final_status = SplitOutputTensors(combined_outputs, batch.get());
+ });
+ // By waiting for the notification we are ensuring that this thread isn't
+ // used for processing other batches, which gives the batches time to
+ // coalesce upstream. So overall the number of batches going through the
+ // devices goes down, improving latency and throughput in most cases.
+ done.WaitForNotification();
+ }
+
+ // Processes a batch of one or more BatchTask entries.
+ void ProcessBatch(std::unique_ptr<Batch> batch) const {
+ if (batch->empty()) {
+ return;
+ }
+
+ OpKernelContext* last_task_context =
+ batch->task(batch->num_tasks() - 1).context;
+ AsyncOpKernel::DoneCallback last_task_callback =
+ batch->task(batch->num_tasks() - 1).done_callback;
+
+ OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
+ last_task_callback);
+
+ // All tasks should have the same number of input edges.
+ const int num_input_edges = batch->task(0).inputs.size();
+ std::vector<Tensor> concatenated_tensors;
+ const Status concat_status =
+ ConcatInputTensors(*batch, last_task_context, &concatenated_tensors);
+ OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback);
+
+ // Process each input edge one at a time (the typical case has just one).
+ for (int i = 0; i < num_input_edges; ++i) {
+ last_task_context->set_output(i, concatenated_tensors.at(i));
+
+ // Emit batch->num_tasks() - 1 empty output tensors.
+ for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
+ const BatchTask& task = batch->task(task_idx);
+ TensorShape output_shape(task.inputs.at(i).shape());
+ output_shape.set_dim(0, 0);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ task.context,
+ task.context->allocate_output(i, output_shape, &output),
+ task.done_callback);
+ }
+ }
// Emit batch->num_tasks() - 1 empty index tensors.
for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
const BatchTask& task = batch->task(task_idx);
@@ -463,7 +646,7 @@ class BatchResource : public ResourceBase {
return Status::OK();
}
- // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
+ // Looks up the batcher queue for 'queue_name'. If it did't previously exist,
// creates it.
Status LookupOrCreateBatcherQueue(const string& queue_name,
BatcherQueue** queue) {
@@ -477,7 +660,11 @@ class BatchResource : public ResourceBase {
std::unique_ptr<BatcherQueue> new_queue;
auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
- ProcessBatch(std::move(batch));
+ if (fhandle_ == kInvalidHandle) {
+ ProcessBatch(std::move(batch));
+ } else {
+ ProcessFuncBatch(std::move(batch));
+ }
};
TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
process_batch_callback, &new_queue));
@@ -498,8 +685,99 @@ class BatchResource : public ResourceBase {
GUARDED_BY(batcher_queues_mu_);
std::vector<int32> allowed_batch_sizes_;
+ FunctionLibraryRuntime::Handle fhandle_;
};
+class BatchFunctionKernel : public AsyncOpKernel {
+ public:
+ explicit BatchFunctionKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
+ OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
+ // If shared_name is not supplied, use name instead (prevent collisions by
+ // default).
+ if (shared_name_.empty()) {
+ shared_name_ = name();
+ }
+ OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
+ OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
+ OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
+ OP_REQUIRES_OK(c,
+ c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
+ OP_REQUIRES_OK(c,
+ c->GetAttr("max_enqueued_batches", &max_enqueued_batches_));
+ OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
+ OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
+
+ auto lib = c->function_library();
+ OP_REQUIRES(c, lib != nullptr, errors::Internal("No function library"));
+ NameAttrList func;
+ OP_REQUIRES_OK(c, c->GetAttr("f", &func));
+ OP_REQUIRES_OK(
+ c, lib->Instantiate(func.name(), AttrSlice(&func.attr()), &fhandle_));
+ }
+
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
+ BatchResource* br;
+ std::function<Status(BatchResource * *r)> creator = [this,
+ c](BatchResource** r) {
+ std::unique_ptr<BatchResource> new_resource;
+ TF_RETURN_IF_ERROR(
+ BatchResource::Create(num_batch_threads_, max_batch_size_,
+ batch_timeout_micros_, max_enqueued_batches_,
+ allowed_batch_sizes_, fhandle_, &new_resource));
+ *r = new_resource.release();
+ return Status::OK();
+ };
+ OP_REQUIRES_OK_ASYNC(c,
+ c->resource_manager()->LookupOrCreate(
+ container_, shared_name_, &br, creator),
+ done);
+ const Status status =
+ br->RegisterInput(random::New64(), c, batcher_queue_, done);
+ br->Unref();
+ OP_REQUIRES_OK_ASYNC(c, status, done);
+ // Assume br calls done, so nothing to do here.
+ }
+
+ // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
+ // and the last one must equal 'max_batch_size_'.
+ Status ValidateAllowedBatchSizes() const {
+ if (allowed_batch_sizes_.empty()) {
+ return Status::OK();
+ }
+ int32 last_size = 0;
+ for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
+ const int32 size = allowed_batch_sizes_.at(i);
+ if (i > 0 && size <= last_size) {
+ return errors::InvalidArgument(
+ "allowed_batch_sizes entries must be monotonically increasing");
+ }
+ if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
+ return errors::InvalidArgument(
+ "final entry in allowed_batch_sizes must equal max_batch_size");
+ }
+ last_size = size;
+ }
+ return Status::OK();
+ }
+
+ private:
+ string container_;
+ string shared_name_;
+ string batcher_queue_;
+ int32 num_batch_threads_;
+ int32 max_batch_size_;
+ int32 batch_timeout_micros_;
+ int32 max_enqueued_batches_;
+ std::vector<int32> allowed_batch_sizes_;
+ FunctionLibraryRuntime::Handle fhandle_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BatchFunction").Device(DEVICE_CPU),
+ BatchFunctionKernel);
+
class BatchKernel : public AsyncOpKernel {
public:
explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
@@ -528,7 +806,8 @@ class BatchKernel : public AsyncOpKernel {
std::unique_ptr<BatchResource> new_resource;
TF_RETURN_IF_ERROR(BatchResource::Create(
num_batch_threads_, max_batch_size_, batch_timeout_micros_,
- max_enqueued_batches_, allowed_batch_sizes_, &new_resource));
+ max_enqueued_batches_, allowed_batch_sizes_, kInvalidHandle,
+ &new_resource));
*r = new_resource.release();
return Status::OK();
};
@@ -539,9 +818,7 @@ class BatchKernel : public AsyncOpKernel {
const Status status =
br->RegisterInput(random::New64(), c, batcher_queue_, done);
br->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
+ OP_REQUIRES_OK_ASYNC(c, status, done);
// Assume br calls done, so nothing to do here.
}
@@ -800,9 +1077,7 @@ class UnbatchKernel : public AsyncOpKernel {
done);
auto status = ubr->Compute(c, done);
ubr->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
+ OP_REQUIRES_OK_ASYNC(c, status, done);
// Assume ubr calls done, so nothing to do here.
}
@@ -840,10 +1115,12 @@ class UnbatchGradResource : public ResourceBase {
}
const DataType type = tensors[0].dtype();
+ Tensor concatenated_tensor;
switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- TF_RETURN_IF_ERROR(Concat<type>(context, tensors, 0)); \
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ TF_RETURN_IF_ERROR(Concat<type>(context, tensors, &concatenated_tensor)); \
+ context->set_output(0, concatenated_tensor); \
break;
TF_CALL_ALL_TYPES(CASE);
#undef CASE
@@ -986,9 +1263,7 @@ class UnbatchGradKernel : public AsyncOpKernel {
done);
Status status = ubr->Compute(c, done);
ubr->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
+ OP_REQUIRES_OK_ASYNC(c, status, done);
// Assume ubr calls done, so nothing to do here.
}
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 96216764fd..b77c80c01f 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,7 +17,7 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL)
+#if !defined(INTEL_MKL) || defined(DO_NOT_USE_ML)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
#endif
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 87a0795f2f..fe259c1634 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL)
+#if !defined(INTEL_MKL) || defined(DO_NOT_USE_ML)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
#endif
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index e292ff200a..792eb74e31 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -138,6 +138,9 @@ cc_library(
tf_cc_test(
name = "serial_device_batch_scheduler_test",
srcs = ["serial_device_batch_scheduler_test.cc"],
+ tags = [
+ "notap", # b/110374108
+ ],
deps = [
":fake_clock_env",
":serial_device_batch_scheduler",
diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
index a2f8f9a03e..a91356c095 100644
--- a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
+++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc
@@ -145,8 +145,6 @@ TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) {
std::shared_ptr<SerialDeviceBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
SerialDeviceBatchScheduler<FakeTask>::Create(options, &scheduler));
- // Make sure batch processing thread has gone to sleep.
- Env::Default()->SleepForMicroseconds(1000);
int processed_batches = 0;
Notification start_processing;
auto queue_callback = [&mu, &processed_batches, &start_processing, &pending,
@@ -163,26 +161,18 @@ TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) {
start_processing.WaitForNotification();
{
mutex_lock l(mu);
- pending = 2;
- }
- break;
- case 2:
- // No batches initially --> low traffic --> no adjustment.
- CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
- {
- mutex_lock l(mu);
pending = 3;
}
break;
- case 3:
- // Pending at target --> no adjustment.
+ case 2:
+ // Either low traffic or pending at target --> no adjustment.
CHECK_EQ(scheduler->in_flight_batches_limit(), 1);
{
mutex_lock l(mu);
pending = 1;
}
break;
- case 4:
+ case 3:
// Small pending --> 2 additional threads added.
CHECK_EQ(scheduler->in_flight_batches_limit(), 3);
{
@@ -196,8 +186,8 @@ TEST(SerialDeviceBatchSchedulerTest, PendingOnSerialDevice) {
};
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
- // Create 4 batches.
- for (int i = 0; i < 4; i++) {
+ // Create 3 batches.
+ for (int i = 0; i < 3; i++) {
TF_ASSERT_OK(ScheduleTask(800, queue.get()));
}
start_processing.Notify();
@@ -295,12 +285,22 @@ TEST(SerialDeviceBatchSchedulerTest, DeleteQueue) {
TF_ASSERT_OK(ScheduleTask(800, queue.get()));
}
std::unique_ptr<Thread> queue_deleter(Env::Default()->StartThread(
- {}, "QueueDeleterThread", [&queue, &mu, &processed_batches] {
+ {}, "QueueDeleterThread",
+ [&queue, &mu, &processed_batches, scheduler]() mutable {
// Delete queue, should be kept alive until empty.
queue.reset();
+ {
+ mutex_lock l(mu);
+ // queue may be destroyed before 2nd batch finishes processing.
+ EXPECT_GT(processed_batches, 0);
+ }
+ // Delete scheduler, should be kept alive until all batches processed.
+ scheduler.reset();
mutex_lock l(mu);
EXPECT_EQ(processed_batches, 2);
}));
+ // Release reference to scheduler, queue and callback above should keep alive.
+ scheduler.reset();
// Give queue_deleter thread time to delete queue.
Env::Default()->SleepForMicroseconds(1000);
finish_processing.Notify();
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 9fda7169a8..7b28c8e91f 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -29,6 +29,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -363,6 +364,93 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
+struct BiasGradAutotuneGroup {
+ static string name() { return "BiasGrad"; }
+};
+
+class BiasAddGradGPUConfig {
+ public:
+ BiasAddGradGPUConfig() : mode_(BiasAddGradGPUMode::kReduction) {}
+ string ToString() const {
+ if (mode_ == BiasAddGradGPUMode::kNative) {
+ return "native CUDA kernel.";
+ }
+ if (mode_ == BiasAddGradGPUMode::kReduction) {
+ return "cub reduction kernel.";
+ }
+ return "unknown kernel.";
+ }
+ BiasAddGradGPUMode get_mode() const { return mode_; }
+ void set_mode(BiasAddGradGPUMode val) { mode_ = val; }
+
+ bool operator==(const BiasAddGradGPUConfig& other) const {
+ return this->mode_ == other.get_mode();
+ }
+
+ bool operator!=(const BiasAddGradGPUConfig& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ BiasAddGradGPUMode mode_;
+};
+
+// Encapsulate all the shape information that is used in bias add grad
+// operations.
+class BiasAddParams {
+ public:
+ // We use a list to maintain both the shape value and the order (data format).
+ using SpatialArray = gtl::InlinedVector<int64, 4>;
+ BiasAddParams(const SpatialArray& in_shape, TensorFormat data_format,
+ DataType dtype, int device_id)
+ : in_shape_(in_shape),
+ data_format_(data_format),
+ dtype_(dtype),
+ device_id_(device_id) {
+ for (int64 val : in_shape_) {
+ hash_code_ = Hash64Combine(hash_code_, val);
+ }
+ hash_code_ = Hash64Combine(hash_code_, data_format);
+ hash_code_ = Hash64Combine(hash_code_, dtype);
+ hash_code_ = Hash64Combine(hash_code_, device_id);
+ }
+ bool operator==(const BiasAddParams& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const BiasAddParams& other) const {
+ return !(*this == other);
+ }
+ uint64 hash() const { return hash_code_; }
+
+ string ToString() const {
+ // clang-format off
+ return strings::StrCat(
+ "(", str_util::Join(in_shape_, ", "), "), ",
+ data_format_, ", ", dtype_, ", ", device_id_);
+ // clang-format on
+ }
+
+ protected:
+ using ParamsDataType = std::tuple<SpatialArray, TensorFormat, DataType, int>;
+
+ ParamsDataType get_data_as_tuple() const {
+ return std::make_tuple(in_shape_, data_format_, dtype_, device_id_);
+ }
+
+ uint64 hash_code_ = 0;
+
+ private:
+ SpatialArray in_shape_;
+ TensorFormat data_format_;
+ DataType dtype_;
+ int device_id_;
+};
+
+typedef AutoTuneSingleton<BiasGradAutotuneGroup, BiasAddParams,
+ BiasAddGradGPUConfig>
+ AutotuneBiasGrad;
+
template <typename T>
class BiasGradOp<GPUDevice, T> : public OpKernel {
public:
@@ -377,6 +465,49 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
}
}
+ void ComputeWithCustomKernel(OpKernelContext* context,
+ const Tensor& output_backprop, int32 batch,
+ int32 width, int32 height, int32 channel,
+ Tensor* output) {
+ BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
+ output_backprop.template flat<T>().data(),
+ output->flat<T>().data(), batch, width, height,
+ channel, data_format_);
+ }
+
+ void ComputeWithReduceSum(OpKernelContext* context,
+ const Tensor& output_backprop, int32 batch,
+ int32 width, int32 height, int32 channel,
+ Tensor* output) {
+ if (data_format_ == FORMAT_NCHW) {
+ int32 row_count = batch * channel;
+ int32 col_count = height * width;
+ Tensor temp_grad_outputs;
+ // For 'NCHW' format, we perform reduction twice: first HW, then N.
+ TensorShape temp_grad_output_shape{row_count, col_count};
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value,
+ temp_grad_output_shape,
+ &temp_grad_outputs));
+ BiasGradGPU<T>::DoRowReduction(
+ context, temp_grad_outputs.flat<T>().data(),
+ output_backprop.template flat<T>().data(), row_count, col_count);
+
+ row_count = batch;
+ col_count = channel;
+ BiasGradGPU<T>::DoColReduction(context, output->flat<T>().data(),
+ temp_grad_outputs.flat<T>().data(),
+ row_count, col_count);
+ } else {
+ // For 'NHWC', we simply apply reduction once on NHW.
+ int32 row_count = batch * height * width;
+ int32 col_count = channel;
+ BiasGradGPU<T>::DoColReduction(
+ context, const_cast<T*>(output->flat<T>().data()),
+ reinterpret_cast<const T*>(output_backprop.template flat<T>().data()),
+ row_count, col_count);
+ }
+ }
+
void Compute(OpKernelContext* context) override {
const Tensor& output_backprop = context->input(0);
@@ -396,11 +527,65 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
output->NumElements() * sizeof(T));
stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
- if (output_backprop.NumElements() > 0) {
- BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
- output_backprop.template flat<T>().data(),
- output->flat<T>().data(), batch, width, height,
- channel, data_format_);
+ if (output_backprop.NumElements() <= 0) return;
+
+ int device_id = stream->parent()->device_ordinal();
+ DataType dtype = output_backprop.dtype();
+ BiasAddParams bias_parameters = {
+ {batch, height * width, channel},
+ data_format_,
+ dtype,
+ device_id,
+ };
+
+ // Autotune two algorithm: customized
+ BiasAddGradGPUConfig algo_config;
+ if (!AutotuneBiasGrad::GetInstance()->Find(bias_parameters, &algo_config)) {
+ BiasGradGPUProfileResult best_result;
+ // Initialize the timer.
+ perftools::gputools::Timer timer(stream->parent());
+ stream->InitTimer(&timer);
+ stream->ThenStartTimer(&timer);
+ ComputeWithCustomKernel(context, output_backprop, batch, width, height,
+ channel, output);
+ stream->ThenStopTimer(&timer);
+ uint64 elapsed_microseconds = timer.Microseconds();
+ VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
+ << " Native algo latency: " << elapsed_microseconds;
+ if (elapsed_microseconds < best_result.elapsed_time()) {
+ best_result.set_algorithm(BiasAddGradGPUMode::kNative);
+ best_result.set_elapsed_time(elapsed_microseconds);
+ }
+
+ // Try reduction and profile.
+ stream->ThenStartTimer(&timer);
+ ComputeWithReduceSum(context, output_backprop, batch, width, height,
+ channel, output);
+ stream->ThenStopTimer(&timer);
+
+ elapsed_microseconds = timer.Microseconds();
+ VLOG(1) << "BiasAddGrad " << bias_parameters.ToString()
+ << " Reduction algo latency: " << elapsed_microseconds;
+ if (elapsed_microseconds < best_result.elapsed_time()) {
+ best_result.set_algorithm(BiasAddGradGPUMode::kReduction);
+ best_result.set_elapsed_time(elapsed_microseconds);
+ }
+
+ algo_config.set_mode(best_result.algorithm());
+ AutotuneBiasGrad::GetInstance()->Insert(bias_parameters, algo_config);
+
+ // Results are already available during autotune, so no need to continue.
+ return;
+ }
+
+ // Choose the best algorithm based on autotune results.
+ if (algo_config.get_mode() == BiasAddGradGPUMode::kReduction) {
+ ComputeWithReduceSum(context, output_backprop, batch, width, height,
+ channel, output);
+ } else {
+ // Default to the customized kernel.
+ ComputeWithCustomKernel(context, output_backprop, batch, width, height,
+ channel, output);
}
}
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 754b93b073..1a7211a7cb 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -24,6 +24,14 @@ limitations under the License.
#include "tensorflow/core/kernels/bias_op_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
@@ -239,6 +247,26 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
}
}
+template <typename T>
+void BiasGradGPU<T>::DoRowReduction(OpKernelContext* context, T* output,
+ const T* input, int rows, int cols) {
+ typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
+ Constants<GPUDevice> constants;
+ cub::Sum op;
+ functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
+ context, output, input, 2, rows, cols, 1, 1, constants.kOne, op);
+}
+
+template <typename T>
+void BiasGradGPU<T>::DoColReduction(OpKernelContext* context, T* output,
+ const T* input, int rows, int cols) {
+ typedef const Eigen::array<TTypes<float>::Tensor::Index, 1>& ReductionAxes;
+ Constants<GPUDevice> constants;
+ cub::Sum op;
+ functor::ReduceImpl<T, cub::Sum, T*, const T*, ReductionAxes>(
+ context, output, input, 2, rows, cols, 1, 1, constants.kZero, op);
+}
+
#define DEFINE_GPU_SPECS(T) \
template struct BiasGPU<T>; \
template struct BiasGradGPU<T>;
diff --git a/tensorflow/core/kernels/bias_op_gpu.h b/tensorflow/core/kernels/bias_op_gpu.h
index 9f14cc296f..c1051f43c9 100644
--- a/tensorflow/core/kernels/bias_op_gpu.h
+++ b/tensorflow/core/kernels/bias_op_gpu.h
@@ -19,7 +19,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -38,6 +40,39 @@ struct BiasGradGPU {
static void compute(const GPUDevice& device, const T* output_backprop,
T* bias_backprop, int32 batch, int32 height, int32 width,
int32 channel, TensorFormat data_format);
+
+ static void DoRowReduction(OpKernelContext* context, T* output,
+ const T* input, int rows, int cols);
+
+ static void DoColReduction(OpKernelContext* context, T* output,
+ const T* input, int rows, int cols);
+};
+
+enum class BiasAddGradGPUMode {
+ kInvalid = 0,
+ kNative = 1,
+ kReduction = 2,
+};
+
+// Describe the BiasGradGPU result from a perf experiment.
+//
+// Arguments:
+// algorithm: returns the method to use for bias add grad.
+// elapsed_time; returns the measured elapsed time in microseconds.
+class BiasGradGPUProfileResult {
+ public:
+ bool is_valid() const {
+ return (algorithm_ != BiasAddGradGPUMode::kInvalid &&
+ elapsed_time_ != std::numeric_limits<float>::max());
+ }
+ BiasAddGradGPUMode algorithm() const { return algorithm_; }
+ void set_algorithm(BiasAddGradGPUMode val) { algorithm_ = val; }
+ uint64 elapsed_time() const { return elapsed_time_; }
+ void set_elapsed_time(uint64 val) { elapsed_time_ = val; }
+
+ private:
+ BiasAddGradGPUMode algorithm_ = BiasAddGradGPUMode::kInvalid;
+ uint64 elapsed_time_ = std::numeric_limits<uint64>::max();
};
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 62327dfe1d..4910021c63 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -30,6 +30,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_cc",
],
)
@@ -44,6 +45,11 @@ cc_library(
],
)
+cc_library(
+ name = "tree_helper",
+ hdrs = ["tree_helper.h"],
+)
+
tf_kernel_library(
name = "resource_ops",
srcs = ["resource_ops.cc"],
@@ -60,6 +66,7 @@ tf_kernel_library(
name = "stats_ops",
srcs = ["stats_ops.cc"],
deps = [
+ ":tree_helper",
"//tensorflow/core:boosted_trees_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -71,6 +78,7 @@ tf_kernel_library(
srcs = ["training_ops.cc"],
deps = [
":resources",
+ ":tree_helper",
"//tensorflow/core:boosted_trees_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index 55599de731..c9664f0c1c 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -115,3 +115,20 @@ message TreeEnsemble {
// Metadata that is used during the training.
GrowingMetadata growing_metadata = 4;
}
+
+// DebugOutput contains outputs useful for debugging/model interpretation, at
+// the individual example-level. Debug outputs that are available to the user
+// are: 1) Directional feature contributions (DFCs) 2) Node IDs for ensemble
+// prediction path 3) Leaf node IDs.
+message DebugOutput {
+ // Return the logits and associated feature splits across prediction paths for
+ // each tree, for every example, at predict time. We will use these values to
+ // compute DFCs in Python, by subtracting each child prediction from its
+ // parent prediction and associating this change with its respective feature
+ // id.
+ repeated int32 feature_ids = 1;
+ repeated float logits_path = 2;
+
+ // TODO(crawles): return 2) Node IDs for ensemble prediction path 3) Leaf node
+ // IDs.
+}
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index 20359f28d3..b2efa06941 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -103,8 +104,8 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
const int32 latest_tree = resource->num_trees() - 1;
if (latest_tree < 0) {
- // Ensemble was empty. Nothing changes.
- output_node_ids = cached_node_ids;
+ // Ensemble was empty. Output the very first node.
+ output_node_ids.setZero();
output_tree_ids = cached_tree_ids;
// All the predictions are zeros.
output_partial_logits.setZero();
@@ -119,16 +120,20 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
int32 node_id = cached_node_ids(i);
float partial_tree_logit = 0.0;
- // If the tree was pruned, returns the node id into which the
- // current_node_id was pruned, as well the correction of the cached
- // logit prediction.
- resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
- &partial_tree_logit);
-
- // Logic in the loop adds the cached node value again if it is a leaf.
- // If it is not a leaf anymore we need to subtract the old node's
- // value. The following logic handles both of these cases.
- partial_tree_logit -= resource->node_value(tree_id, node_id);
+ if (node_id >= 0) {
+ // If the tree was pruned, returns the node id into which the
+ // current_node_id was pruned, as well the correction of the cached
+ // logit prediction.
+ resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
+ &partial_tree_logit);
+ // Logic in the loop adds the cached node value again if it is a
+ // leaf. If it is not a leaf anymore we need to subtract the old
+ // node's value. The following logic handles both of these cases.
+ partial_tree_logit -= resource->node_value(tree_id, node_id);
+ } else {
+ // No cache exists, start from the very first node.
+ node_id = 0;
+ }
float partial_all_logit = 0.0;
while (true) {
if (resource->is_leaf(tree_id, node_id)) {
@@ -219,10 +224,10 @@ class BoostedTreesPredictOp : public OpKernel {
return;
}
- const int32 latest_tree = resource->num_trees() - 1;
+ const int32 last_tree = resource->num_trees() - 1;
auto do_work = [&resource, &batch_bucketized_features, &output_logits,
- batch_size, latest_tree](int32 start, int32 end) {
+ batch_size, last_tree](int32 start, int32 end) {
for (int32 i = start; i < end; ++i) {
float tree_logit = 0.0;
int32 tree_id = 0;
@@ -232,8 +237,8 @@ class BoostedTreesPredictOp : public OpKernel {
tree_logit += resource->GetTreeWeight(tree_id) *
resource->node_value(tree_id, node_id);
- // Stop if it was the latest tree.
- if (tree_id == latest_tree) {
+ // Stop if it was the last tree.
+ if (tree_id == last_tree) {
break;
}
// Move onto other trees.
@@ -250,7 +255,7 @@ class BoostedTreesPredictOp : public OpKernel {
// 10 is the magic number. The actual number might depend on (the number of
// layers in the trees) and (cpu cycles spent on each layer), but this
// value would work for many cases. May be tuned later.
- const int64 cost = (latest_tree + 1) * 10;
+ const int64 cost = (last_tree + 1) * 10;
thread::ThreadPool* const worker_threads =
context->device()->tensorflow_cpu_worker_threads()->workers;
Shard(worker_threads->NumThreads(), worker_threads, batch_size,
@@ -266,4 +271,118 @@ class BoostedTreesPredictOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BoostedTreesPredict").Device(DEVICE_CPU),
BoostedTreesPredictOp);
+// The Op that returns debugging/model interpretability outputs for each
+// example. Currently it outputs the split feature ids and logits after each
+// split along the decision path for each example. This will be used to compute
+// directional feature contributions at predict time for an arbitrary activation
+// function.
+// TODO(crawles): return in proto 1) Node IDs for ensemble prediction path
+// 2) Leaf node IDs.
+class BoostedTreesExampleDebugOutputsOp : public OpKernel {
+ public:
+ explicit BoostedTreesExampleDebugOutputsOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_bucketized_features",
+ &num_bucketized_features_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("logits_dimension", &logits_dimension_));
+ OP_REQUIRES(context, logits_dimension_ == 1,
+ errors::InvalidArgument(
+ "Currently only one dimensional outputs are supported."));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ BoostedTreesEnsembleResource* resource;
+ // Get the resource.
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &resource));
+ // Release the reference to the resource once we're done using it.
+ core::ScopedUnref unref_me(resource);
+
+ // Get the inputs.
+ OpInputList bucketized_features_list;
+ OP_REQUIRES_OK(context, context->input_list("bucketized_features",
+ &bucketized_features_list));
+ std::vector<tensorflow::TTypes<int32>::ConstVec> batch_bucketized_features;
+ batch_bucketized_features.reserve(bucketized_features_list.size());
+ for (const Tensor& tensor : bucketized_features_list) {
+ batch_bucketized_features.emplace_back(tensor.vec<int32>());
+ }
+ const int batch_size = batch_bucketized_features[0].size();
+
+ // We need to get the feature ids used for splitting and the logits after
+ // each split. We will use these to calulate the changes in the prediction
+ // (contributions) for an arbitrary activation function (done in Python) and
+ // attribute them to the associated feature ids. We will store these in
+ // a proto below.
+ Tensor* output_debug_info_t = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("examples_debug_outputs_serialized",
+ {batch_size}, &output_debug_info_t));
+ // Will contain serialized protos, per example.
+ auto output_debug_info = output_debug_info_t->flat<string>();
+ const int32 last_tree = resource->num_trees() - 1;
+
+ // For each given example, traverse through all trees keeping track of the
+ // features used to split and the associated logits at each point along the
+ // path. Note: feature_ids has one less value than logits_path because the
+ // first value of each logit path will be the bias.
+ auto do_work = [&resource, &batch_bucketized_features, &output_debug_info,
+ batch_size, last_tree](int32 start, int32 end) {
+ for (int32 i = start; i < end; ++i) {
+ // Proto to store debug outputs, per example.
+ boosted_trees::DebugOutput example_debug_info;
+ // Initial bias prediction. E.g., prediction based off training mean.
+ example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
+ resource->node_value(0, 0));
+ int32 node_id = 0;
+ int32 tree_id = 0;
+ int32 feature_id;
+ float tree_logit;
+ float past_trees_logit = 0; // Sum of leaf logits from prior trees.
+ // Populate proto.
+ while (tree_id <= last_tree) {
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
+ if (resource->is_leaf(tree_id, node_id)) {
+ // Move onto other trees.
+ past_trees_logit += tree_logit;
+ ++tree_id;
+ node_id = 0;
+ }
+ }
+ // Set output as serialized proto containing debug info.
+ string serialized = example_debug_info.SerializeAsString();
+ output_debug_info(i) = serialized;
+ }
+ };
+
+ // 10 is the magic number. The actual number might depend on (the number of
+ // layers in the trees) and (cpu cycles spent on each layer), but this
+ // value would work for many cases. May be tuned later.
+ const int64 cost = (last_tree + 1) * 10;
+ thread::ThreadPool* const worker_threads =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ Shard(worker_threads->NumThreads(), worker_threads, batch_size,
+ /*cost_per_unit=*/cost, do_work);
+ }
+
+ private:
+ int32 logits_dimension_; // Indicates dimension of logits in the tree nodes.
+ int32 num_bucketized_features_; // Indicates the number of features.
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesExampleDebugOutputs").Device(DEVICE_CPU),
+ BoostedTreesExampleDebugOutputsOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index c410748c27..cc90bb2f45 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -21,6 +21,10 @@ limitations under the License.
namespace tensorflow {
+namespace {
+constexpr float kLayerByLayerTreeWeight = 1.0;
+} // namespace
+
// Constructor.
BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
: tree_ensemble_(
@@ -78,6 +82,16 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
}
}
+void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
+ const int32 node_id,
+ const float logits) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
+ DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
+ node->mutable_leaf()->set_scalar(logits);
+}
+
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
const int32 tree_id) const {
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
@@ -204,9 +218,14 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
// Add a tree to the ensemble and returns a new tree_id.
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
+ return AddNewTreeWithLogits(weight, 0.0);
+}
+
+int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
+ const float logits) {
const int32 new_tree_id = tree_ensemble_->trees_size();
auto* node = tree_ensemble_->add_trees()->add_nodes();
- node->mutable_leaf()->set_scalar(0.0);
+ node->mutable_leaf()->set_scalar(logits);
tree_ensemble_->add_tree_weights(weight);
tree_ensemble_->add_tree_metadata();
@@ -225,7 +244,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
*right_node_id = *left_node_id + 1;
auto* left_node = tree->add_nodes();
auto* right_node = tree->add_nodes();
- if (node_id != 0) {
+ if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) {
// Save previous leaf value if it is not the first leaf in the tree.
node->mutable_metadata()->mutable_original_leaf()->Swap(
node->mutable_leaf());
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
index df78d3f275..f961ed3814 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.h
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -70,6 +70,9 @@ class BoostedTreesEnsembleResource : public StampedResource {
float node_value(const int32 tree_id, const int32 node_id) const;
+ void set_node_value(const int32 tree_id, const int32 node_id,
+ const float logits);
+
int32 GetNumLayersGrown(const int32 tree_id) const;
void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
@@ -99,6 +102,9 @@ class BoostedTreesEnsembleResource : public StampedResource {
// Add a tree to the ensemble and returns a new tree_id.
int32 AddNewTree(const float weight);
+ // Adds new tree with one node to the ensemble and sets node's value to logits
+ int32 AddNewTreeWithLogits(const float weight, const float logits);
+
// Grows the tree by adding a split and leaves.
void AddBucketizedSplitNode(const int32 tree_id, const int32 node_id,
const int32 feature_id, const int32 threshold,
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 53bdd482cb..64ec1caa9c 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -17,13 +17,10 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
namespace tensorflow {
-namespace {
-const float kEps = 1e-15;
-} // namespace
-
class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
public:
explicit BoostedTreesCalculateBestGainsPerFeatureOp(
@@ -139,7 +136,7 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
total_hess - cum_hess_bucket, l1, l2,
&contrib_for_right, &gain_for_right);
- if (gain_for_left + gain_for_right > best_gain) {
+ if (GainIsLarger(gain_for_left + gain_for_right, best_gain)) {
best_gain = gain_for_left + gain_for_right;
best_bucket = bucket;
best_contrib_for_left = contrib_for_left;
@@ -200,40 +197,6 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
}
private:
- void CalculateWeightsAndGains(const float g, const float h, const float l1,
- const float l2, float* weight, float* gain) {
- //
- // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
- // (g+l1*sgn(w))^2/(h+l2).
- // This is because for each leaf we optimize
- // 1/2(h+l2)*w^2+g*w+l1*abs(w)
- float g_with_l1 = g;
- // Apply L1 regularization.
- // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
- // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
- // For g from (-l1, l1), thus there is no solution => set to 0.
- if (l1 > 0) {
- if (g > l1) {
- g_with_l1 -= l1;
- } else if (g < -l1) {
- g_with_l1 += l1;
- } else {
- *weight = 0.0;
- *gain = 0.0;
- return;
- }
- }
- // Apply L2 regularization.
- if (h + l2 <= kEps) {
- // Avoid division by 0 or infinitesimal.
- *weight = 0;
- *gain = 0;
- } else {
- *weight = -g_with_l1 / (h + l2);
- *gain = -g_with_l1 * (*weight);
- }
- }
-
int max_splits_;
int num_features_;
};
@@ -255,7 +218,7 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
// node_ids
const Tensor* node_ids_t;
OP_REQUIRES_OK(context, context->input("node_ids", &node_ids_t));
- const auto node_ids = node_ids_t->flat<int32>();
+ const auto node_ids = node_ids_t->vec<int32>();
// gradients
const Tensor* gradients_t;
OP_REQUIRES_OK(context, context->input("gradients", &gradients_t));
@@ -270,46 +233,34 @@ class BoostedTreesMakeStatsSummaryOp : public OpKernel {
&bucketized_features_list));
// Infer batch size.
const int64 batch_size = node_ids_t->dim_size(0);
- // Allocate output stats tensor (Rank 4).
- Tensor* output_stats_summary_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "stats_summary",
- {num_features_, max_splits_, num_buckets_, 2},
- &output_stats_summary_t));
- auto output_stats_summary = output_stats_summary_t->flat<float>();
- EIGEN_STATIC_ASSERT(
- (static_cast<int>(decltype(output_stats_summary)::Layout) ==
- static_cast<int>(Eigen::RowMajor)),
- THIS_METHOD_IS_ONLY_FOR_ROW_MAJOR_MATRICES);
- const int shift_per_node = num_buckets_ * 2;
- const int shift_per_feature = shift_per_node * max_splits_;
- const int32 max_index = num_features_ * shift_per_feature;
- // We use double to sum the gradients and hessians, due to possible
- // precision loss when summing small float values.
- std::vector<double> res(max_index, 0);
+ // Allocate temporary stats tensor (Rank 4).
+ Tensor temp_stats_double_t;
+ OP_REQUIRES_OK(context, context->allocate_temp(
+ DT_DOUBLE,
+ {num_features_, max_splits_, num_buckets_, 2},
+ &temp_stats_double_t));
+ auto temp_stats_double = temp_stats_double_t.tensor<double, 4>();
+ temp_stats_double.setZero();
// Partition by node, and then bucketize.
- int feature_idx = 0;
- int feature_shift = 0;
- for (const Tensor& tensor : bucketized_features_list) {
- const auto& features = tensor.flat<int32>();
+ for (int feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
+ const auto& features = bucketized_features_list[feature_idx].vec<int32>();
for (int i = 0; i < batch_size; ++i) {
const int32 node = node_ids(i);
const int32 bucket = features(i);
- // Calculate the index in the flattened vector for
- // [feature_idx][node][bucket][0].
- const int index = feature_shift + node * shift_per_node + bucket * 2;
- res[index] += gradients(i, 0);
- res[index + 1] += hessians(i, 0);
+ temp_stats_double(feature_idx, node, bucket, 0) += gradients(i, 0);
+ temp_stats_double(feature_idx, node, bucket, 1) += hessians(i, 0);
}
- ++feature_idx;
- feature_shift += shift_per_feature;
- }
- // Copy over the results.
- for (int i = 0; i < max_index; ++i) {
- output_stats_summary(i) = res[i];
}
+
+ // Copy temp tensor over to output tensor.
+ Tensor* output_stats_summary_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "stats_summary", temp_stats_double_t.shape(),
+ &output_stats_summary_t));
+ output_stats_summary_t->tensor<float, 4>() =
+ temp_stats_double.template cast<float>();
}
private:
diff --git a/tensorflow/core/kernels/boosted_trees/training_ops.cc b/tensorflow/core/kernels/boosted_trees/training_ops.cc
index a14fd4a133..973cdec13a 100644
--- a/tensorflow/core/kernels/boosted_trees/training_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/training_ops.cc
@@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/boosted_trees/resources.h"
+#include "tensorflow/core/kernels/boosted_trees/tree_helper.h"
namespace tensorflow {
namespace {
constexpr float kLayerByLayerTreeWeight = 1.0;
+constexpr float kMinDeltaForCenterBias = 0.01;
// TODO(nponomareva, youngheek): consider using vector.
struct SplitCandidate {
@@ -89,7 +91,8 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
// Find best splits for each active node.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerNode(context, node_ids_list, gains_list, &best_splits);
+ FindBestSplitsPerNode(context, node_ids_list, gains_list, feature_ids,
+ &best_splits);
int32 current_tree =
UpdateGlobalAttemptsAndRetrieveGrowableTree(ensemble_resource);
@@ -193,6 +196,7 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
void FindBestSplitsPerNode(
OpKernelContext* const context, const OpInputList& node_ids_list,
const OpInputList& gains_list,
+ const TTypes<const int32>::Vec& feature_ids,
std::map<int32, SplitCandidate>* best_split_per_node) {
// Find best split per node going through every feature candidate.
for (int64 feature_idx = 0; feature_idx < num_features_; ++feature_idx) {
@@ -211,8 +215,18 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
candidate.candidate_idx = candidate_idx;
candidate.gain = gain;
- if (best_split_it == best_split_per_node->end() ||
- gain > best_split_it->second.gain) {
+ if (TF_PREDICT_FALSE(best_split_it != best_split_per_node->end() &&
+ GainsAreEqual(gain, best_split_it->second.gain))) {
+ const auto best_candidate = (*best_split_per_node)[node_id];
+ const int32 best_feature_id = feature_ids(best_candidate.feature_idx);
+ const int32 feature_id = feature_ids(candidate.feature_idx);
+ VLOG(2) << "Breaking ties on feature ids and buckets";
+ // Breaking ties deterministically.
+ if (feature_id < best_feature_id) {
+ (*best_split_per_node)[node_id] = candidate;
+ }
+ } else if (best_split_it == best_split_per_node->end() ||
+ GainIsLarger(gain, best_split_it->second.gain)) {
(*best_split_per_node)[node_id] = candidate;
}
}
@@ -227,4 +241,69 @@ class BoostedTreesUpdateEnsembleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BoostedTreesUpdateEnsemble").Device(DEVICE_CPU),
BoostedTreesUpdateEnsembleOp);
+class BoostedTreesCenterBiasOp : public OpKernel {
+ public:
+ explicit BoostedTreesCenterBiasOp(OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* const context) override {
+ // Get decision tree ensemble.
+ BoostedTreesEnsembleResource* ensemble_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &ensemble_resource));
+ core::ScopedUnref unref_me(ensemble_resource);
+ mutex_lock l(*ensemble_resource->get_mutex());
+ // Increase the ensemble stamp.
+ ensemble_resource->set_stamp(ensemble_resource->stamp() + 1);
+
+ // Read means of hessians and gradients
+ const Tensor* mean_gradients_t;
+ OP_REQUIRES_OK(context,
+ context->input("mean_gradients", &mean_gradients_t));
+
+ const Tensor* mean_hessians_t;
+ OP_REQUIRES_OK(context, context->input("mean_hessians", &mean_hessians_t));
+
+ // Get the regularization options.
+ const Tensor* l1_t;
+ OP_REQUIRES_OK(context, context->input("l1", &l1_t));
+ const auto l1 = l1_t->scalar<float>()();
+ const Tensor* l2_t;
+ OP_REQUIRES_OK(context, context->input("l2", &l2_t));
+ const auto l2 = l2_t->scalar<float>()();
+
+ // For now, assume 1-dimensional weight on leaves.
+ float logits;
+ float unused_gain;
+
+ // TODO(nponomareva): change this when supporting multiclass.
+ const float gradients_mean = mean_gradients_t->flat<float>()(0);
+ const float hessians_mean = mean_hessians_t->flat<float>()(0);
+ CalculateWeightsAndGains(gradients_mean, hessians_mean, l1, l2, &logits,
+ &unused_gain);
+
+ float current_bias = 0.0;
+ bool continue_centering = true;
+ if (ensemble_resource->num_trees() == 0) {
+ ensemble_resource->AddNewTreeWithLogits(kLayerByLayerTreeWeight, logits);
+ current_bias = logits;
+ } else {
+ current_bias = ensemble_resource->node_value(0, 0);
+ continue_centering =
+ std::abs(logits / current_bias) > kMinDeltaForCenterBias;
+ current_bias += logits;
+ ensemble_resource->set_node_value(0, 0, current_bias);
+ }
+
+ Tensor* continue_centering_t = nullptr;
+ OP_REQUIRES_OK(
+ context, context->allocate_output("continue_centering", TensorShape({}),
+ &continue_centering_t));
+ // Check if we need to continue centering bias.
+ continue_centering_t->scalar<bool>()() = continue_centering;
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesCenterBias").Device(DEVICE_CPU),
+ BoostedTreesCenterBiasOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/tree_helper.h b/tensorflow/core/kernels/boosted_trees/tree_helper.h
new file mode 100644
index 0000000000..8b18d9e5f8
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/tree_helper.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
+#include <cmath>
+
+namespace tensorflow {
+
+static bool GainsAreEqual(const float g1, const float g2) {
+ const float kTolerance = 1e-15;
+ return std::abs(g1 - g2) < kTolerance;
+}
+
+static bool GainIsLarger(const float g1, const float g2) {
+ const float kTolerance = 1e-15;
+ return g1 - g2 >= kTolerance;
+}
+
+static void CalculateWeightsAndGains(const float g, const float h,
+ const float l1, const float l2,
+ float* weight, float* gain) {
+ const float kEps = 1e-15;
+ // The formula for weight is -(g+l1*sgn(w))/(H+l2), for gain it is
+ // (g+l1*sgn(w))^2/(h+l2).
+ // This is because for each leaf we optimize
+ // 1/2(h+l2)*w^2+g*w+l1*abs(w)
+ float g_with_l1 = g;
+ // Apply L1 regularization.
+ // 1) Assume w>0 => w=-(g+l1)/(h+l2)=> g+l1 < 0 => g < -l1
+ // 2) Assume w<0 => w=-(g-l1)/(h+l2)=> g-l1 > 0 => g > l1
+ // For g from (-l1, l1), thus there is no solution => set to 0.
+ if (l1 > 0) {
+ if (g > l1) {
+ g_with_l1 -= l1;
+ } else if (g < -l1) {
+ g_with_l1 += l1;
+ } else {
+ *weight = 0.0;
+ *gain = 0.0;
+ return;
+ }
+ }
+ // Apply L2 regularization.
+ if (h + l2 <= kEps) {
+ // Avoid division by 0 or infinitesimal.
+ *weight = 0;
+ *gain = 0;
+ } else {
+ *weight = -g_with_l1 / (h + l2);
+ *gain = -g_with_l1 * (*weight);
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_TREE_HELPER_H_
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index a87b63f913..902327aaea 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -113,7 +113,7 @@ class ConcatBaseOp : public OpKernel {
int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
- const auto in = values[i];
+ const auto& in = values[i];
const bool in_is_scalar = IsLegacyScalar(in.shape());
OP_REQUIRES(
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index fe1a1ba5a3..a888422d49 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -297,7 +297,8 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
+ Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
+ TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 7d5d54e5be..fd3a0ad422 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -108,6 +108,7 @@ REGISTER_GPU_HOST_KERNEL(bool);
REGISTER_GPU_HOST_REF_KERNEL(bool);
REGISTER_GPU_HOST_KERNEL(string);
REGISTER_GPU_HOST_REF_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
#undef REGISTER_GPU_HOST_KERNEL
#undef REGISTER_GPU_HOST_REF_KERNEL
@@ -587,24 +588,14 @@ REGISTER_SYCL_HOST_KERNEL(string);
#undef REGISTER_SYCL_HOST_KERNEL
#endif // TENSORFLOW_USE_SYCL
-// A LoopCond op has one input and one output. The input is a boolean
-// scalar representing the taken branches of the "pivot" Switch that
-// determines loop termination. As a contract, any high-level front-end
-// should always use port '0' of the "pivot" switches for loop exit.
-class LoopCondOp : public OpKernel {
- public:
- explicit LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- context->set_output(0, context->input(0));
- }
-
- bool IsExpensive() override { return false; }
+LoopCondOp::LoopCondOp(OpKernelConstruction* context) : OpKernel(context) {}
+LoopCondOp::~LoopCondOp() = default;
- ~LoopCondOp() override {}
+void LoopCondOp::Compute(OpKernelContext* context) {
+ context->set_output(0, context->input(0));
+}
- TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
-};
+bool LoopCondOp::IsExpensive() { return false; }
REGISTER_KERNEL_BUILDER(Name("LoopCond").Device(DEVICE_CPU), LoopCondOp);
REGISTER_KERNEL_BUILDER(Name("LoopCond")
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
index 4838f2e2bf..8edbcc9077 100644
--- a/tensorflow/core/kernels/control_flow_ops.h
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -97,6 +97,22 @@ class NextIterationOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(NextIterationOp);
};
+// A LoopCond op has one input and one output. The input is a boolean
+// scalar representing the taken branches of the "pivot" Switch that
+// determines loop termination. As a contract, any high-level front-end
+// should always use port '0' of the "pivot" switches for loop exit.
+class LoopCondOp : public OpKernel {
+ public:
+ explicit LoopCondOp(OpKernelConstruction* context);
+ ~LoopCondOp() override;
+
+ void Compute(OpKernelContext* context) override;
+
+ bool IsExpensive() override;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LoopCondOp);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 6949e5b5fd..6b7544fd4c 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -159,7 +159,7 @@ struct TransformFilter {
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS; ++i) { // spatial dimensions
+ for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
expanded_dims[i + 2] = in.dimension(i);
}
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index bdd08222d4..aca75176a5 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -404,9 +404,10 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
// image ('work_unit_size').
// TODO(andydavis)
+ // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
// *) Consider reducing 'target_working_set_size' if L3 is shared by
// other concurrently running tensorflow ops.
- const size_t target_working_set_size = Eigen::l3CacheSize() / sizeof(T);
+ const size_t target_working_set_size = (30LL << 20) / sizeof(T);
const size_t size_A = output_image_size * filter_total_size;
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 95301b170f..63a775afa8 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -420,8 +420,9 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
const int output_image_size =
dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;
- const size_t l2_cache_size = Eigen::l2CacheSize();
- const size_t l3_cache_size = Eigen::l3CacheSize();
+ // TODO(andydavis) Get L2/L3 cache sizes from device.
+ const size_t l2_cache_size = 256LL << 10;
+ const size_t l3_cache_size = 30LL << 20;
// Use L3 cache size as target working set size.
const size_t target_working_set_size = l3_cache_size / sizeof(T);
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc
index 1b40ad81f4..972100ba77 100644
--- a/tensorflow/core/kernels/conv_ops_fused.cc
+++ b/tensorflow/core/kernels/conv_ops_fused.cc
@@ -195,7 +195,7 @@ EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
const int64 bottom_y_index =
std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
// Lerp is used for bilinear filtering when that's needed.
- result.y_lerp = in_y - top_y_index;
+ result.y_lerp = static_cast<T1>(in_y - top_y_index);
// Which rows of the original input image to pull the values from.
result.input_top_row_start =
input_batch_start + (top_y_index * input_width * input_depth);
@@ -245,7 +245,7 @@ CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
result.right_x_index =
std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
// This x_lerp is used to blend pixels in bilinear filtering.
- result.x_lerp = in_x - result.left_x_index;
+ result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
return result;
}
@@ -465,8 +465,8 @@ class FusedResizeAndPadConvFunctor {
// for that operation are always present.
// Work out the parameters that remain constant across the
// row we're calculating.
- PerCacheLineParameters<float> line_params(
- CalculatePerCacheLineParameters<float>(
+ PerCacheLineParameters<T1> line_params(
+ CalculatePerCacheLineParameters<T1>(
task_params.cache_height, cache_y,
task_params.resize_cache,
task_params.cache_line_width, task_params.input_width,
@@ -881,7 +881,9 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
BILINEAR>, \
true>);
+TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
+TF_CALL_double(REGISTER_FUSED);
#define REGISTER_PAD_ONLY_FUSED(T) \
REGISTER_KERNEL_BUILDER( \
@@ -892,6 +894,8 @@ TF_CALL_float(REGISTER_FUSED);
NEAREST>, \
false>);
+TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
+TF_CALL_double(REGISTER_PAD_ONLY_FUSED);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a2e7342b04..a5fa48f85e 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -247,7 +247,13 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(
constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
constexpr int WriteRowPerPass = NumThreads / TileSizeI;
// One extra line in the inner dimension to avoid share memory bank conflict.
- __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+ // This is to mimic the following, but no constructor of T can be invoked.
+ // __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
+ __shared__ __align__(
+ alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
+ typedef T(*SharedMemoryTile)[TileSizeJ + 1];
+ SharedMemoryTile shared_memory_tile =
+ reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
int x = threadIdx.x;
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index 8afe6a2cbd..4f9a96ce17 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -88,14 +88,15 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) {
class FusedResizePadConvOpTest : public OpsTestBase {
protected:
- void HandwrittenConv() {
+ template <typename T>
+ void HandwrittenConv(DataType dtype) {
const int stride = 1;
TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
- .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(dtype))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_INT32))
- .Input(FakeInput(DT_FLOAT))
- .Attr("T", DT_FLOAT)
+ .Input(FakeInput(dtype))
+ .Attr("T", dtype)
.Attr("resize_align_corners", false)
.Attr("mode", "REFLECT")
.Attr("strides", {1, stride, stride, 1})
@@ -110,9 +111,8 @@ class FusedResizePadConvOpTest : public OpsTestBase {
// | 1 | 2 | 3 | 4 |
// | 5 | 6 | 7 | 8 |
// | 9 | 10 | 11 | 12 |
- Tensor image(DT_FLOAT,
- {image_batch_count, image_height, image_width, depth});
- test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
+ test::FillValues<T>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
// The filter matrix is:
// | 1 | 4 | 7 |
@@ -120,8 +120,8 @@ class FusedResizePadConvOpTest : public OpsTestBase {
// | 3 | 6 | 9 |
const int filter_size = 3;
const int filter_count = 1;
- Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
- test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
+ Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
+ test::FillValues<T>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
const int resized_width = image_width;
const int resized_height = image_height;
@@ -131,12 +131,12 @@ class FusedResizePadConvOpTest : public OpsTestBase {
const int left_padding = 0;
const int right_padding = 0;
- AddInputFromArray<float>(image.shape(), image.flat<float>());
+ AddInputFromArray<T>(image.shape(), image.flat<T>());
AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
AddInputFromArray<int32>(
TensorShape({4, 2}),
{0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
- AddInputFromArray<float>(filter.shape(), filter.flat<float>());
+ AddInputFromArray<T>(filter.shape(), filter.flat<T>());
TF_ASSERT_OK(RunOpKernel());
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
@@ -160,21 +160,22 @@ class FusedResizePadConvOpTest : public OpsTestBase {
// | 187 | 234 | 261 | 121 |
const int expected_width = image_width;
const int expected_height = image_height * filter_count;
- Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
- expected_width, filter_count}));
- test::FillValues<float>(
+ Tensor expected(dtype, TensorShape({image_batch_count, expected_height,
+ expected_width, filter_count}));
+ test::FillValues<T>(
&expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
const Tensor& output = *GetOutput(0);
- test::ExpectTensorNear<float>(expected, output, 1e-5);
+ test::ExpectTensorNear<T>(expected, output, 1e-5);
}
+ template <typename T>
void CompareFusedAndSeparate(int input_width, int input_height,
int input_depth, int resize_width,
int resize_height, int y_padding, int x_padding,
int filter_size, int filter_count,
bool resize_align_corners,
const string& pad_mode, int stride,
- const string& padding) {
+ const string& padding, DataType dtype) {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -183,29 +184,34 @@ class FusedResizePadConvOpTest : public OpsTestBase {
test::FillIota<float>(&input_data, 1.0f);
Output input =
Const(root.WithOpName("input"), Input::Initializer(input_data));
+ Output casted_input = Cast(root.WithOpName("casted_input"), input, dtype);
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
input_depth, filter_count}));
test::FillIota<float>(&filter_data, 1.0f);
Output filter =
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
+ Output casted_filter =
+ Cast(root.WithOpName("casted_filter"), filter, dtype);
Output resize_size =
Const(root.WithOpName("resize_size"), {resize_height, resize_width});
Output resize =
ResizeBilinear(root.WithOpName("resize"), input, resize_size,
ResizeBilinear::AlignCorners(resize_align_corners));
+ // Bilinear resize only output float, cast it to dtype to match the input.
+ Output casted_resize = Cast(root.WithOpName("cast"), resize, dtype);
Output paddings =
Const(root.WithOpName("paddings"),
{{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
- Output mirror_pad =
- MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
- Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
+ Output mirror_pad = MirrorPad(root.WithOpName("mirror_pad"), casted_resize,
+ paddings, pad_mode);
+ Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, casted_filter,
{1, stride, stride, 1}, padding);
Output fused_conv = FusedResizeAndPadConv2D(
- root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
- pad_mode, {1, stride, stride, 1}, padding,
+ root.WithOpName("fused_conv"), casted_input, resize_size, paddings,
+ casted_filter, pad_mode, {1, stride, stride, 1}, padding,
FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
tensorflow::GraphDef graph;
@@ -221,14 +227,16 @@ class FusedResizePadConvOpTest : public OpsTestBase {
std::vector<Tensor> fused_tensors;
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
- test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5);
}
+ template <typename T>
void CompareFusedPadOnlyAndSeparate(int input_width, int input_height,
int input_depth, int y_padding,
int x_padding, int filter_size,
int filter_count, const string& pad_mode,
- int stride, const string& padding) {
+ int stride, const string& padding,
+ DataType dtype) {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -237,24 +245,27 @@ class FusedResizePadConvOpTest : public OpsTestBase {
test::FillIota<float>(&input_data, 1.0f);
Output input =
Const(root.WithOpName("input"), Input::Initializer(input_data));
+ Output casted_input = Cast(root.WithOpName("casted_input"), input, dtype);
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
input_depth, filter_count}));
test::FillIota<float>(&filter_data, 1.0f);
Output filter =
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
+ Output casted_filter =
+ Cast(root.WithOpName("casted_filter"), filter, dtype);
Output paddings =
Const(root.WithOpName("paddings"),
{{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
- Output mirror_pad =
- MirrorPad(root.WithOpName("mirror_pad"), input, paddings, pad_mode);
- Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
+ Output mirror_pad = MirrorPad(root.WithOpName("mirror_pad"), casted_input,
+ paddings, pad_mode);
+ Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, casted_filter,
{1, stride, stride, 1}, padding);
- Output fused_conv =
- FusedPadConv2D(root.WithOpName("fused_conv"), input, paddings, filter,
- pad_mode, {1, stride, stride, 1}, padding);
+ Output fused_conv = FusedPadConv2D(
+ root.WithOpName("fused_conv"), casted_input, paddings, casted_filter,
+ pad_mode, {1, stride, stride, 1}, padding);
tensorflow::GraphDef graph;
TF_ASSERT_OK(root.ToGraphDef(&graph));
@@ -269,95 +280,130 @@ class FusedResizePadConvOpTest : public OpsTestBase {
std::vector<Tensor> fused_tensors;
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
- test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5);
}
};
-TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }
+TEST_F(FusedResizePadConvOpTest, HandwrittenConvHalf) {
+ HandwrittenConv<Eigen::half>(DT_HALF);
+}
-TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
- CompareFusedAndSeparate(10, 10, 1, 10, 10, 0, 0, 1, 1, false, "REFLECT", 1,
- "SAME");
+TEST_F(FusedResizePadConvOpTest, HandwrittenConvFloat) {
+ HandwrittenConv<float>(DT_FLOAT);
+}
+
+TEST_F(FusedResizePadConvOpTest, HandwrittenConvDouble) {
+ HandwrittenConv<double>(DT_DOUBLE);
+}
+
+TEST_F(FusedResizePadConvOpTest, IdentityComparativeHalf) {
+ CompareFusedAndSeparate<Eigen::half>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
+ "REFLECT", 1, "SAME", DT_HALF);
+}
+
+TEST_F(FusedResizePadConvOpTest, IdentityComparativeFloat) {
+ CompareFusedAndSeparate<float>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
+ "REFLECT", 1, "SAME", DT_FLOAT);
+}
+
+TEST_F(FusedResizePadConvOpTest, IdentityComparativeDouble) {
+ CompareFusedAndSeparate<double>(10, 10, 1, 10, 10, 0, 0, 1, 1, false,
+ "REFLECT", 1, "SAME", DT_DOUBLE);
}
TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
- CompareFusedAndSeparate(10, 10, 3, 10, 10, 0, 0, 4, 4, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(10, 10, 3, 10, 10, 0, 0, 4, 4, false,
+ "REFLECT", 1, "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
- CompareFusedAndSeparate(10, 10, 1, 20, 20, 0, 0, 1, 1, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(10, 10, 1, 20, 20, 0, 0, 1, 1, false,
+ "REFLECT", 1, "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
- CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
- CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
- CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
- "SAME");
+ CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
- CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
- "VALID");
+ CompareFusedAndSeparate<float>(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+ "VALID", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
- CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
- CompareFusedAndSeparate(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
- CompareFusedAndSeparate(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
- CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC",
+ 1, "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
- CompareFusedAndSeparate(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC", 1,
- "SAME");
+ CompareFusedAndSeparate<float>(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC",
+ 1, "SAME", DT_FLOAT);
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparativeLarge) {
+ CompareFusedAndSeparate<float>(1000, 1000, 3, 1006, 1006, 2, 2, 1, 1, false,
+ "SYMMETRIC", 1, "SAME", DT_FLOAT);
}
-TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparative) {
- CompareFusedPadOnlyAndSeparate(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1, "SAME");
+TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeHalf) {
+ CompareFusedPadOnlyAndSeparate<Eigen::half>(10, 10, 1, 0, 0, 1, 1, "REFLECT",
+ 1, "SAME", DT_HALF);
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeFloat) {
+ CompareFusedPadOnlyAndSeparate<float>(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1,
+ "SAME", DT_FLOAT);
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparativeDouble) {
+ CompareFusedPadOnlyAndSeparate<double>(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1,
+ "SAME", DT_DOUBLE);
}
TEST_F(FusedResizePadConvOpTest, NoResizeConvOnlyComparative) {
- CompareFusedPadOnlyAndSeparate(10, 10, 3, 0, 0, 4, 4, "REFLECT", 1, "SAME");
+ CompareFusedPadOnlyAndSeparate<float>(10, 10, 3, 0, 0, 4, 4, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyComparative) {
- CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "REFLECT", 1, "SAME");
+ CompareFusedPadOnlyAndSeparate<float>(4, 4, 1, 2, 2, 1, 1, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyWithChannelsComparative) {
- CompareFusedPadOnlyAndSeparate(4, 4, 3, 2, 2, 1, 1, "REFLECT", 1, "SAME");
+ CompareFusedPadOnlyAndSeparate<float>(4, 4, 3, 2, 2, 1, 1, "REFLECT", 1,
+ "SAME", DT_FLOAT);
}
TEST_F(FusedResizePadConvOpTest, NoResizePadOnlySymmetricComparative) {
- CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "SYMMETRIC", 1, "SAME");
-}
-
-TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparativeLarge) {
- CompareFusedAndSeparate(1000, 1000, 3, 1006, 1006, 2, 2, 1, 1, false,
- "SYMMETRIC", 1, "SAME");
+ CompareFusedPadOnlyAndSeparate<float>(4, 4, 1, 2, 2, 1, 1, "SYMMETRIC", 1,
+ "SAME", DT_FLOAT);
}
class ConvOpTest : public OpsTestBase {
diff --git a/tensorflow/core/kernels/cwise_op_bessel.cc b/tensorflow/core/kernels/cwise_op_bessel.cc
new file mode 100644
index 0000000000..4372f56408
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_bessel.cc
@@ -0,0 +1,29 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
+ double);
+REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
+ double);
+#if GOOGLE_CUDA
+REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
+ double);
+REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
+ double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_bessel.cu.cc b/tensorflow/core/kernels/cwise_op_bessel.cu.cc
new file mode 100644
index 0000000000..30de8b1fdc
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_bessel.cu.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_UNARY3(bessel_i0e, Eigen::half, float, double);
+DEFINE_UNARY3(bessel_i1e, Eigen::half, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
index ea10ebe9a0..931f59014b 100644
--- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
- uint8, int8, int16);
+REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
+ uint8, int8, int16, bfloat16);
REGISTER_KERNEL_BUILDER(
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
ApproximateEqualOp<CPUDevice, float>);
diff --git a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
index 5a529bd8ca..508a47deda 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_igammas.cu.cc
@@ -16,10 +16,12 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+#include "tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h"
namespace tensorflow {
namespace functor {
DEFINE_BINARY2(igamma, float, double);
+DEFINE_BINARY2(igamma_grad_a, float, double);
DEFINE_BINARY2(igammac, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc
new file mode 100644
index 0000000000..fd0a95ecc5
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_random_grad.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY2(random_gamma_grad, float, double);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc
index a4ea408836..b385e9e545 100644
--- a/tensorflow/core/kernels/cwise_op_greater.cc
+++ b/tensorflow/core/kernels/cwise_op_greater.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
- double, int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "Greater", functor::greater, float, Eigen::half,
+ double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Greater", functor::greater, float, Eigen::half,
double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc
index 3f34d6269e..8bfc018052 100644
--- a/tensorflow/core/kernels/cwise_op_greater_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER8(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
- Eigen::half, double, int32, int64, uint8, int8, int16);
+REGISTER9(BinaryOp, CPU, "GreaterEqual", functor::greater_equal, float,
+ Eigen::half, double, int32, int64, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "GreaterEqual", functor::greater_equal, float,
Eigen::half, double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_igammas.cc b/tensorflow/core/kernels/cwise_op_igammas.cc
index 4b5f888bc1..cadda3b723 100644
--- a/tensorflow/core/kernels/cwise_op_igammas.cc
+++ b/tensorflow/core/kernels/cwise_op_igammas.cc
@@ -14,12 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/cwise_ops_gradients.h"
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Igamma", functor::igamma, float, double);
+REGISTER2(BinaryOp, CPU, "IgammaGradA", functor::igamma_grad_a, float, double);
REGISTER2(BinaryOp, CPU, "Igammac", functor::igammac, float, double);
#if GOOGLE_CUDA
REGISTER2(BinaryOp, GPU, "Igamma", functor::igamma, float, double);
+REGISTER2(BinaryOp, GPU, "IgammaGradA", functor::igamma_grad_a, float, double);
REGISTER2(BinaryOp, GPU, "Igammac", functor::igammac, float, double);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc
index 575968126f..e369fdcf8a 100644
--- a/tensorflow/core/kernels/cwise_op_less.cc
+++ b/tensorflow/core/kernels/cwise_op_less.cc
@@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
- bfloat16, int32, int64, uint8, int8, int16);
+REGISTER5(BinaryOp, CPU, "Less", functor::less, float, Eigen::half, double,
+ bfloat16, int32);
+REGISTER5(BinaryOp, CPU, "Less", functor::less, int64, uint8, int8, int16,
+ bfloat16);
+
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "Less", functor::less, float, Eigen::half, double,
int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc
index 499200d054..3353e117cd 100644
--- a/tensorflow/core/kernels/cwise_op_less_equal.cc
+++ b/tensorflow/core/kernels/cwise_op_less_equal.cc
@@ -16,8 +16,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER9(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
- bfloat16, double, int32, int64, uint8, int8, int16);
+REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, float, Eigen::half,
+ bfloat16, double, int32);
+REGISTER5(BinaryOp, CPU, "LessEqual", functor::less_equal, int64, uint8, int8,
+ int16, bfloat16);
+
#if GOOGLE_CUDA
REGISTER7(BinaryOp, GPU, "LessEqual", functor::less_equal, float, Eigen::half,
double, int64, uint8, int8, int16);
diff --git a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
index 935619711c..9f1e575805 100644
--- a/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+++ b/tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER6(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
- double, uint8, int8, int16);
+REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
+ double, uint8, int8, int16, bfloat16);
#if GOOGLE_CUDA
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
double, uint8);
diff --git a/tensorflow/core/kernels/cwise_op_random_grad.cc b/tensorflow/core/kernels/cwise_op_random_grad.cc
new file mode 100644
index 0000000000..8e388ead9e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_random_grad.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER2(BinaryOp, CPU, "RandomGammaGrad", functor::random_gamma_grad, float,
+ double);
+#if GOOGLE_CUDA
+REGISTER2(BinaryOp, GPU, "RandomGammaGrad", functor::random_gamma_grad, float,
+ double);
+#endif
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index a80905d145..1b1a704d42 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -616,6 +616,12 @@ struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
+template <typename T>
+struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};
+
+template <typename T>
+struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};
+
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
};
@@ -765,6 +771,10 @@ template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};
template <typename T>
+struct random_gamma_grad
+ : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};
+
+template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};
template <typename T>
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 82cdae9a34..7a6f14babc 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -202,6 +202,9 @@ struct sqrt_grad : base<T, Eigen::internal::scalar_sqrt_gradient_op<T>> {};
template <typename T>
struct rsqrt_grad : base<T, Eigen::internal::scalar_rsqrt_gradient_op<T>> {};
+template <typename T>
+struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
+
} // end namespace functor
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index d35aad980d..e04fa20414 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -85,6 +85,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "window_dataset_op",
+ srcs = ["window_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ ":window_dataset",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "slide_dataset_op",
srcs = ["slide_dataset_op.cc"],
deps = [
@@ -358,6 +371,7 @@ tf_kernel_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
],
)
@@ -549,11 +563,47 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "optimize_dataset_op",
+ srcs = ["optimize_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:grappler_item_builder",
+ "//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/optimizers:meta_optimizer",
+ "//tensorflow/core/grappler/optimizers/data",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
+ srcs = ["dataset_ops.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "data",
deps = [
":batch_dataset_op",
":cache_dataset_ops",
":concatenate_dataset_op",
+ ":dataset",
+ ":dataset_ops",
":dense_to_sparse_batch_dataset_op",
":filter_dataset_op",
":flat_map_dataset_op",
@@ -564,6 +614,7 @@ tf_kernel_library(
":iterator_ops",
":map_and_batch_dataset_op",
":map_dataset_op",
+ ":optimize_dataset_op",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
@@ -586,6 +637,7 @@ tf_kernel_library(
":tensor_queue_dataset_op",
":tensor_slice_dataset_op",
":unbatch_dataset_op",
+ ":window_dataset_op",
":writer_ops",
":zip_dataset_op",
],
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 9a83c16f33..58b86f2a08 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -27,7 +27,8 @@ namespace {
class BatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit BatchDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
+ : UnaryDatasetOpKernel(ctx),
+ op_version_(ctx->def().op() == "BatchDataset" ? 1 : 2) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
@@ -38,14 +39,24 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
- *output = new Dataset(ctx, batch_size, input);
+ bool drop_remainder = false;
+ if (op_version_ > 1) {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "drop_remainder",
+ &drop_remainder));
+ }
+
+ *output = new Dataset(ctx, batch_size, drop_remainder, input);
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 batch_size, const DatasetBase* input)
- : GraphDatasetBase(ctx), batch_size_(batch_size), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
+ const DatasetBase* input)
+ : GraphDatasetBase(ctx),
+ batch_size_(batch_size),
+ drop_remainder_(drop_remainder),
+ input_(input) {
input_->Ref();
// NOTE(mrry): Currently we implement "batch up to" semantics. If
@@ -54,8 +65,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
for (const auto& input_shape : input_shapes) {
- output_shapes_.emplace_back(
- PartialTensorShape({-1}).Concatenate(input_shape));
+ if (drop_remainder_) {
+ output_shapes_.emplace_back(
+ PartialTensorShape({batch_size_}).Concatenate(input_shape));
+ } else {
+ output_shapes_.emplace_back(
+ PartialTensorShape({-1}).Concatenate(input_shape));
+ }
}
}
@@ -86,8 +102,10 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, batch_size}, output));
+ Node* drop_remainder = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, batch_size, drop_remainder}, output));
return Status::OK();
}
@@ -133,6 +151,12 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ if (dataset()->drop_remainder_ &&
+ batch_elements.size() < dataset()->batch_size_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
// Copy the retrieved batch elements into one output tensor
// per tuple component.
// NOTE(mrry): If the input or output sizes are statically
@@ -201,14 +225,20 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
+ const bool drop_remainder_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
+
+ const int op_version_;
};
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
BatchDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
+ BatchDatasetOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 3673df6fa3..ed4932bf32 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -41,15 +41,17 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
if (filename.empty()) {
*output = new MemoryDataset(input);
} else {
- *output = new FileDataset(input, filename, ctx->env());
+ *output = new FileDataset(ctx, input, filename, ctx->env());
}
}
private:
- class FileDataset : public DatasetBase {
+ class FileDataset : public GraphDatasetBase {
public:
- explicit FileDataset(const DatasetBase* input, string filename, Env* env)
- : input_(input),
+ explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
+ string filename, Env* env)
+ : GraphDatasetBase(ctx),
+ input_(input),
filename_(std::move(filename)),
env_(env),
num_tensors_(input->output_dtypes().size()),
@@ -66,13 +68,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) {
- return std::unique_ptr<IteratorBase>(new FileReaderIterator(
- {this, strings::StrCat(prefix, "::FileReader")}));
- } else {
- return std::unique_ptr<IteratorBase>(new FileWriterIterator(
- {this, strings::StrCat(prefix, "::FileWriter")}));
- }
+ return std::unique_ptr<IteratorBase>(new FileCacheIterator(
+ {this, strings::StrCat(prefix, "::FileCacheIterator")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -87,6 +84,17 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return "CacheDatasetOp::FileDataset";
}
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph));
+ Node* filename = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph, filename}, output));
+ return Status::OK();
+ }
+
private:
static size_t StringPaddingSize(size_t num_tensors) {
return strings::Printf("%zu", num_tensors - 1).size();
@@ -97,163 +105,428 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
tensor_index);
}
- // FileWriterIterator passes through and caches items from the input
- // FileDataset.
- //
- // This iterator is used when the cache directory is not found on disk. It
- // creates the cache directory, and passes on the underlying iterator's
- // elements.
- class FileWriterIterator : public DatasetIterator<FileDataset> {
+ class FileCacheIterator : public DatasetIterator<FileDataset> {
public:
- explicit FileWriterIterator(const Params& params)
- : DatasetIterator<FileDataset>(params),
- cur_index_(0),
- writer_(params.dataset->env_, params.dataset->filename_),
- lockfile_(strings::StrCat(params.dataset->filename_, ".lockfile")),
- lockfile_created_(false),
- iteration_completed_(false) {}
+ explicit FileCacheIterator(const Params& params)
+ : DatasetIterator<FileDataset>(params) {
+ if (params.dataset->env_
+ ->FileExists(MetaFilename(params.dataset->filename_))
+ .ok()) {
+ mode_ = Mode::read;
+ } else {
+ mode_ = Mode::write;
+ }
+ InitializeIterator();
+ }
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ mutex_lock l(mu_);
+ return iterator_->Initialize(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(EnsureLockFileExists());
- TF_RETURN_IF_ERROR(writer_.status());
- if (cur_index_ >= kMaxItems) {
- // As a courtesy, close the [truncated] cache file.
- Status s = Finish();
- if (!s.ok()) {
- LOG(ERROR) << s;
- }
- return errors::InvalidArgument(
- "Upstream iterator is producing more than ", kMaxItems,
- " items, which is more than the cache limit.");
- }
+ return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (*end_of_sequence && out_tensors->empty()) {
- TF_RETURN_IF_ERROR(Finish());
- cur_index_++;
- return Status::OK();
- }
- if (out_tensors->size() != dataset()->num_tensors_) {
- return errors::Internal(
- "Upstream iterator returned invalid number of tensors. Expected ",
- dataset()->num_tensors_, " got: ", out_tensors->size());
- }
- size_t tensor_index = 0;
- for (const Tensor& t : *out_tensors) {
- DCHECK_LT(tensor_index, dataset()->num_tensors_);
- string key = dataset()->FormatName(cur_index_, tensor_index++);
- TF_RETURN_IF_ERROR(writer_.Add(key, t));
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
+ return SaveParent(writer, iterator_);
+ }
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
+ mode_ = static_cast<Mode>(temp);
}
- if (*end_of_sequence) {
- TF_RETURN_IF_ERROR(Finish());
+ if (mode_ == Mode::write &&
+ dataset()
+ ->env_->FileExists(MetaFilename(dataset()->filename_))
+ .ok()) {
+ // This could happen if the cache was completely written after the
+ // checkpoint was saved.
+ LOG(WARNING)
+ << "It looks like the cache was already completely written("
+ << MetaFilename(dataset()->filename_)
+ << ") after the last checkpoint was saved. "
+ << "Attempting to read the cache instead of continuing to "
+ << "write. If this is a mistake, please remove the above file "
+ << "and try running again.";
+ mode_ = Mode::read;
}
- cur_index_++;
- return Status::OK();
+ InitializeIterator();
+ TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+ return RestoreParent(ctx, reader, iterator_);
}
private:
- Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (iteration_completed_)
- return errors::OutOfRange(
- "Attempting to call get_next after iteration should have "
- "finished.");
- if (lockfile_created_ && !iteration_completed_) return Status::OK();
- // Perform rudimentary locking to help catch concurrent writes to the
- // same cache files.
- if (dataset()->env_->FileExists(lockfile_).ok()) {
- // Attempt to read the contents of the lockfile.
- char contents_scratch[151] = {0}; // Initialize all to 0.
- StringPiece contents;
- std::unique_ptr<RandomAccessFile> file;
- if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
- file->Read(0, 150, &contents, contents_scratch).IgnoreError();
+ // FileWriterIterator passes through and caches items from the input
+ // FileDataset.
+ //
+ // This iterator is used when the cache directory is not found on disk. It
+ // creates the cache directory, and passes on the underlying iterator's
+ // elements.
+ //
+ // Caching is performed by writing the input tensors to disk using the
+ // `BundleWriter`. Note that the cache gets fully flushed to disk only
+ // after the input iterator has been fully exhausted. If the program
+ // exits, before completion of an epoch, the cached state would be lost.
+ // To ensure that the partial cache persists across sessions, one should
+ // checkpoint the input pipeline. On each call to `SaveInternal` the
+ // partial cache gets flushed to disk in files with prefix
+ // <filename>_<shard_id> where shard_id is unique for each checkpoint.
+ // When all elements have been produced, these shards get coalesced.
+ class FileWriterIterator : public DatasetIterator<FileDataset> {
+ public:
+ explicit FileWriterIterator(const Params& params)
+ : DatasetIterator<FileDataset>(params),
+ cur_index_(0),
+ shard_id_(0),
+ filename_(
+ strings::StrCat(params.dataset->filename_, "_", shard_id_)),
+ lockfile_(strings::StrCat(filename_, ".lockfile")),
+ lockfile_created_(false),
+ iteration_completed_(false) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureLockFileExists());
+ TF_RETURN_IF_ERROR(writer_->status());
+ if (cur_index_ >= kMaxItems) {
+ // As a courtesy, close the [truncated] cache file.
+ Status s = Finish();
+ if (!s.ok()) {
+ LOG(ERROR) << s;
+ }
+ return errors::InvalidArgument(
+ "Upstream iterator is producing more than ", kMaxItems,
+ " items, which is more than the cache limit.");
}
- return errors::AlreadyExists(
- "There appears to be a concurrent caching iterator running - "
- "cache lockfile already exists ('",
- lockfile_,
- "'). If you are sure no other running TF computations are using "
- "this cache prefix, delete the lockfile and re-initialize the "
- "iterator. Lockfile contents: ",
- contents);
- } else {
- // Create the file, and write some basic contents.
- std::unique_ptr<WritableFile> lockfile;
+
TF_RETURN_IF_ERROR(
- dataset()->env_->NewWritableFile(lockfile_, &lockfile));
- TF_RETURN_IF_ERROR(lockfile->Append(
- strings::StrCat("Created at: ", dataset()->env_->NowSeconds())));
- lockfile_created_ = true;
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence && out_tensors->empty()) {
+ TF_RETURN_IF_ERROR(Finish());
+ cur_index_++;
+ return Status::OK();
+ }
+ if (out_tensors->size() != dataset()->num_tensors_) {
+ return errors::Internal(
+ "Upstream iterator returned invalid number of tensors. "
+ "Expected ",
+ dataset()->num_tensors_, " got: ", out_tensors->size());
+ }
+ size_t tensor_index = 0;
+ for (const Tensor& t : *out_tensors) {
+ DCHECK_LT(tensor_index, dataset()->num_tensors_);
+ string key = dataset()->FormatName(cur_index_, tensor_index++);
+ TF_RETURN_IF_ERROR(writer_->Add(key, t));
+ }
+ if (*end_of_sequence) {
+ TF_RETURN_IF_ERROR(Finish());
+ }
+ cur_index_++;
return Status::OK();
}
- }
- Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- iteration_completed_ = true;
- TF_RETURN_IF_ERROR(writer_.Finish());
- TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(lockfile_));
- return Status::OK();
- }
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (iteration_completed_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("iteration_completed"), ""));
+ return Status::OK();
+ }
- mutex mu_;
- size_t cur_index_ GUARDED_BY(mu_);
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- BundleWriter writer_ GUARDED_BY(mu_);
- const string lockfile_;
- bool lockfile_created_ GUARDED_BY(mu_);
- bool iteration_completed_ GUARDED_BY(mu_);
- }; // FileWriterIterator
+ // lockfile is created on the first call to GetNextInternal. The
+ // absence of a lockfile means that GetNextInternal was not called
+ // and hence nothing was written to cache. So we don't need to worry
+ // about flushing the current shard. This ensures that we never write
+ // empty shards.
+ if (lockfile_created_) {
+ // Flush the current bundle.
+ TF_RETURN_IF_ERROR(writer_->Finish());
+
+ // Note: We do not delete the lockfile here. We keep lockfiles of
+ // all shards around until the entire cache has been written to
+ // prevent concurrent iterators from corrupting any of the shards.
+
+ // Start caching to a new shard.
+ shard_id_++;
+ filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
+ lockfile_ = strings::StrCat(filename_, ".lockfile");
+ lockfile_created_ = false;
+ }
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cur_index"), cur_index_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("shard_id"), shard_id_));
+ return Status::OK();
+ }
- class FileReaderIterator : public DatasetIterator<FileDataset> {
- public:
- explicit FileReaderIterator(const Params& params)
- : DatasetIterator<FileDataset>(params),
- cur_index_(0),
- reader_(dataset()->env_, dataset()->filename_) {}
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("iteration_completed"))) {
+ iteration_completed_ = true;
+ return Status::OK();
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- *end_of_sequence = false;
- TF_RETURN_IF_ERROR(reader_.status());
- if (!reader_.Valid()) {
- return errors::Internal(
- "Cache iterator is in an invalid state. (Perhaps GetNext called "
- "after end_of_sequence?)");
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ int64 temp;
+ // TODO(b/78048575): Update this when saving size_t tensors directly
+ // is supported.
+ {
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cur_index"), &temp));
+ cur_index_ = static_cast<size_t>(temp);
+ if (cur_index_ != temp) {
+ return errors::Internal("Invalid value for cur_index ", temp);
+ }
+ }
+ // TODO(b/78048575): Update this when saving size_t tensors directly
+ // is supported.
+ {
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("shard_id"), &temp));
+ shard_id_ = static_cast<size_t>(temp);
+ if (shard_id_ != temp) {
+ return errors::Internal("Invalid value for shard_id ", temp);
+ }
+ }
+ filename_ = strings::StrCat(dataset()->filename_, "_", shard_id_);
+ lockfile_ = strings::StrCat(filename_, ".lockfile");
+ writer_.reset(new BundleWriter(dataset()->env_, filename_));
+ return Status::OK();
}
- out_tensors->clear();
- out_tensors->resize(dataset()->num_tensors_);
- for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
- reader_.Next(); // The first entry in the table is a header entry.
- if (!reader_.Valid()) {
- out_tensors->clear();
- *end_of_sequence = true;
+ private:
+ Status EnsureLockFileExists() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (iteration_completed_)
+ return errors::OutOfRange(
+ "Attempting to call get_next after iteration should have "
+ "finished.");
+ if (lockfile_created_ && !iteration_completed_) return Status::OK();
+
+ // Perform rudimentary locking to help catch concurrent writes to the
+ // same cache files.
+
+ // 1. Check that a checkpoint for the shard has not already been
+ // written.
+ if (dataset()->env_->FileExists(MetaFilename(filename_)).ok()) {
+ return errors::AlreadyExists("Existing cache files found: \n",
+ MetaFilename(filename_), "\n",
+ DataFilename(filename_, 0, 1), "\n",
+ "To continue delete the above files.");
+ }
+
+ // 2. Check that there isn't a concurrent iterator that is writing
+ // to cache.
+ if (dataset()->env_->FileExists(lockfile_).ok()) {
+ // Attempt to read the contents of the lockfile.
+ char contents_scratch[151] = {0}; // Initialize all to 0.
+ StringPiece contents;
+ std::unique_ptr<RandomAccessFile> file;
+ if (dataset()->env_->NewRandomAccessFile(lockfile_, &file).ok()) {
+ file->Read(0, 150, &contents, contents_scratch).IgnoreError();
+ }
+ return errors::AlreadyExists(
+ "There appears to be a concurrent caching iterator running - "
+ "cache lockfile already exists ('",
+ lockfile_,
+ "'). If you are sure no other running TF computations are "
+ "using "
+ "this cache prefix, delete the lockfile and re-initialize the "
+ "iterator. Lockfile contents: ",
+ contents);
+ } else {
+ // Create the file, and write some basic contents.
+ std::unique_ptr<WritableFile> lockfile;
+ TF_RETURN_IF_ERROR(
+ dataset()->env_->NewWritableFile(lockfile_, &lockfile));
+ TF_RETURN_IF_ERROR(lockfile->Append(strings::StrCat(
+ "Created at: ", dataset()->env_->NowSeconds())));
+
+ // At this point we know that
+ // 1. There is no conflicting checkpoint with prefix `filename_`.
+ // 2. There is no concurrent session that is trying to write a ckpt
+ // to filename.
+ // So it is safe to create a BundleWriter here. Note that it is
+ // unsafe to initialize the BundleWriter anywhere the above
+ // conditions are not met since BundleWriter's constructor creates
+ // new temp files which can delete the temp files created by a
+ // BundleWriter in another Session.
+ writer_.reset(new BundleWriter(dataset()->env_, filename_));
+ lockfile_created_ = true;
return Status::OK();
}
- StringPiece key = reader_.key();
- DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
- TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
+ }
+
+ Status Finish() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ iteration_completed_ = true;
+ // Flush the current bundle.
+ TF_RETURN_IF_ERROR(writer_->Finish());
+ // Merge all the bundles.
+ // Currently there are `shard_id_ + 1` bundles, one for each
+ // checkpoint. Each bundle has prefix <filename>_<id> where `id` is an
+ // integer starting at 0 an incremented by 1 for each new checkpoint.
+ // We merge all these bundles into a bundle with prefix <filename> so
+ // that the next call to `MakeIterator` can build a
+ // `FileReaderIterator`.
+ {
+ std::vector<string> prefixes;
+ prefixes.reserve(shard_id_ + 1);
+ for (size_t i = 0; i <= shard_id_; ++i) {
+ prefixes.emplace_back(
+ strings::StrCat(dataset()->filename_, "_", i));
+ }
+ TF_RETURN_IF_ERROR(
+ MergeBundles(dataset()->env_, prefixes, dataset()->filename_));
+ }
+ // Delete all lockfiles.
+ for (size_t i = 0; i <= shard_id_; ++i) {
+ TF_RETURN_IF_ERROR(dataset()->env_->DeleteFile(
+ strings::StrCat(dataset()->filename_, "_", i, ".lockfile")));
+ }
+ return Status::OK();
+ }
+
+ mutex mu_;
+ size_t cur_index_ GUARDED_BY(mu_);
+ // Index of the current shard. This gets incremented whenever a new
+ // cache shard is saved.
+ size_t shard_id_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ // The current prefix for the cache file. This is equal to
+ // `StrCat(dataset()->filename_, "_", shard_id_)`.
+ string filename_;
+ std::unique_ptr<BundleWriter> writer_ GUARDED_BY(mu_);
+ string lockfile_ GUARDED_BY(mu_);
+ bool lockfile_created_ GUARDED_BY(mu_);
+ bool iteration_completed_ GUARDED_BY(mu_);
+ }; // FileWriterIterator
+
+ class FileReaderIterator : public DatasetIterator<FileDataset> {
+ public:
+ explicit FileReaderIterator(const Params& params)
+ : DatasetIterator<FileDataset>(params),
+ cur_index_(0),
+ reader_(dataset()->env_, dataset()->filename_),
+ iterator_restored_(false) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ *end_of_sequence = false;
TF_RETURN_IF_ERROR(reader_.status());
+ if (!reader_.Valid()) {
+ return errors::Internal(
+ "Cache iterator is in an invalid state. (Perhaps GetNext "
+ "called "
+ "after end_of_sequence?)");
+ }
+ out_tensors->clear();
+ out_tensors->resize(dataset()->num_tensors_);
+
+ for (size_t i = 0; i < dataset()->num_tensors_; ++i) {
+ // When the iterator is restored from the checkpoint, `reader_` is
+ // already pointing at `key` so we do not need to skip the header
+ // entry.
+ if (!iterator_restored_) {
+ reader_
+ .Next(); // The first entry in the table is a header entry.
+ } else {
+ iterator_restored_ = false;
+ }
+ if (!reader_.Valid()) {
+ out_tensors->clear();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ StringPiece key = reader_.key();
+ DCHECK_EQ(key, dataset()->FormatName(cur_index_, i));
+ TF_RETURN_IF_ERROR(reader_.ReadCurrent(&(*out_tensors)[i]));
+ TF_RETURN_IF_ERROR(reader_.status());
+ }
+ cur_index_++;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cur_index"), cur_index_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(
+ IteratorContext* ctx,
+ IteratorStateReader* iterator_state_reader) override {
+ mutex_lock l(mu_);
+ {
+ // TODO(b/78048575): Update this when saving size_t tensors directly
+ // is supported.
+ int64 temp;
+ TF_RETURN_IF_ERROR(iterator_state_reader->ReadScalar(
+ full_name("cur_index"), &temp));
+ cur_index_ = static_cast<size_t>(temp);
+ if (cur_index_ != temp) {
+ return errors::Internal("Invalid value for cur_index ", temp);
+ }
+ }
+ if (!reader_.Valid()) {
+ return errors::Internal("Error initializing BundleReader.");
+ }
+ reader_.Seek(dataset()->FormatName(cur_index_, 0));
+ iterator_restored_ = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ size_t cur_index_ GUARDED_BY(mu_);
+ BundleReader reader_ GUARDED_BY(mu_);
+ bool iterator_restored_ GUARDED_BY(mu_);
+ }; // FileReaderIterator
+
+ void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // We intentionally use the same prefix for both `FileReaderIterator`
+ // and `FileWriterIterator`. Since at any time there will be at most
+ // one of them alive, there should be no conflicts. This allows both
+ // iterators to use a common key for `cur_index`. We leverage this
+ // in the corner case when this iterator is restored from an old
+ // checkpoint in `write` mode and the cache has been completely
+ // flushed to disk since then. In that case we simply build a
+ // `FileReaderIterator` and seek to the `cur_index`.
+ switch (mode_) {
+ case Mode::read:
+ iterator_.reset(new FileReaderIterator({dataset(), prefix()}));
+ break;
+ case Mode::write:
+ iterator_.reset(new FileWriterIterator({dataset(), prefix()}));
}
- cur_index_++;
- return Status::OK();
}
- private:
mutex mu_;
- size_t cur_index_ GUARDED_BY(mu_);
- BundleReader reader_ GUARDED_BY(mu_);
- }; // FileReaderIterator
+ enum Mode { read, write };
+ Mode mode_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
+ }; // FileCacheIterator
const DatasetBase* const input_;
const string filename_;
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ee58341cfd..82da385405 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -214,6 +214,9 @@ Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -248,6 +251,9 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
});
f_opts.step_container = &step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -304,6 +310,9 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
});
f_opts.step_container = &step_container;
f_opts.runner = runner;
+ if (lib->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
@@ -351,6 +360,9 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
});
f_opts.step_container = step_container;
f_opts.runner = ctx->runner();
+ if (ctx->lib()->device()->device_type() != DEVICE_CPU) {
+ f_opts.create_rendezvous = true;
+ }
// TODO(mrry): Add cancellation manager support to IteratorContext
// so that we can cancel running map functions. The local
// cancellation manager here is created so that we can run kernels
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
new file mode 100644
index 0000000000..01989a3bd9
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class DatasetToGraphOp : public OpKernel {
+ public:
+ explicit DatasetToGraphOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ GraphDefBuilder b;
+ DatasetBase::DatasetGraphDefBuilder db(&b);
+ Node* input_node = nullptr;
+ OP_REQUIRES_OK(ctx, db.AddParentDataset(ctx, dataset, &input_node));
+ GraphDef graph_def;
+ OP_REQUIRES_OK(ctx, b.ToGraphDef(&graph_def));
+ Tensor* result;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
+ result->scalar<string>()() = graph_def.SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
+ DatasetToGraphOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 91b9279427..da4b14c8b9 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -101,8 +101,8 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
}
const DataTypeVector& output_dtypes() const override {
- static DataTypeVector* output_dtypes_ = new DataTypeVector({DT_VARIANT});
- return *output_dtypes_;
+ static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
+ return *output_dtypes;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index aae62ad2fe..0981e42ba1 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -197,6 +197,9 @@ class GeneratorDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
+ GeneratorDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index 03abae79d2..7206be8c0d 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -254,6 +254,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(
dataset()->captured_finalize_func_->RunWithBorrowedArgs(
ctx, states_[keys_[keys_index_++]], out_tensors));
+ *end_of_sequence = false;
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 9d9e74adba..2a94a54f3d 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -207,12 +207,6 @@ class IteratorResource : public ResourceBase {
return Status::OK();
}
-
- std::shared_ptr<StatsAggregator> stats_aggregator() {
- tf_shared_lock l(mu_);
- return stats_aggregator_;
- }
-
string DebugString() override { return "Iterator resource"; }
const DataTypeVector& output_dtypes() const { return output_dtypes_; }
@@ -231,7 +225,6 @@ class IteratorResource : public ResourceBase {
FunctionLibraryRuntime* lib_ = nullptr; // not owned.
std::shared_ptr<IteratorBase> iterator_;
mutex mu_;
- std::shared_ptr<StatsAggregator> stats_aggregator_ GUARDED_BY(mu_);
std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
const DataTypeVector output_dtypes_;
const std::vector<PartialTensorShape> output_shapes_;
@@ -693,30 +686,45 @@ class ToSingleElementOp : public AsyncOpKernel {
ctx,
dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([ctx, raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
- bool end_of_sequence;
-
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
- OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
- errors::InvalidArgument("Dataset was empty."), done);
+ bool end_of_sequence = false;
+ Status s =
+ raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ return;
+ }
+ if (end_of_sequence) {
+ ctx->SetStatus(errors::InvalidArgument("Dataset was empty."));
+ return;
+ }
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
components.clear();
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
- OP_REQUIRES_ASYNC(
- ctx, end_of_sequence,
- errors::InvalidArgument("Dataset had more than one element."), done);
-
- done();
+ Status s2 =
+ raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ if (!s2.ok()) {
+ ctx->SetStatus(s2);
+ return;
+ }
+ if (!end_of_sequence) {
+ ctx->SetStatus(
+ errors::InvalidArgument("Dataset had more than one element."));
+ return;
+ }
});
}
@@ -782,11 +790,11 @@ class OneShotIteratorOp : public AsyncOpKernel {
return;
}
}
- ProduceOutput(ctx, std::move(done));
+ ProduceOutput(ctx, done);
}
private:
- void Init(OpKernelContext* ctx, DoneCallback done) {
+ void Init(OpKernelContext* ctx, const DoneCallback& done) {
IteratorResource* iterator = nullptr;
ContainerInfo cinfo;
Status s = TryInit(ctx, &iterator, &cinfo);
@@ -803,9 +811,9 @@ class OneShotIteratorOp : public AsyncOpKernel {
}
for (auto&& ctx_done : callbacks_to_run) {
- ProduceOutput(ctx_done.first, std::move(ctx_done.second));
+ ProduceOutput(ctx_done.first, ctx_done.second);
}
- ProduceOutput(ctx, std::move(done));
+ ProduceOutput(ctx, done);
}
Status TryInit(OpKernelContext* ctx, IteratorResource** iterator,
@@ -944,9 +952,6 @@ class IteratorGetNextOp : public AsyncOpKernel {
IteratorContext::Params params;
params.env = ctx->env();
- params.stats_aggregator_getter = [iterator]() {
- return iterator->stats_aggregator();
- };
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
DeviceBase* device = ctx->function_library()->device();
@@ -995,9 +1000,6 @@ class IteratorGetNextSyncOp : public OpKernel {
IteratorContext::Params params;
params.env = ctx->env();
- params.stats_aggregator_getter = [iterator]() {
- return iterator->stats_aggregator();
- };
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
DeviceBase* device = ctx->function_library()->device();
@@ -1148,22 +1150,45 @@ class DeserializeIteratorOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_CPU),
+ IteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE_GPU),
+ IteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
MakeIteratorOp);
+REGISTER_KERNEL_BUILDER(
+ Name("MakeIterator").Device(DEVICE_GPU).HostMemory("dataset"),
+ MakeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_CPU),
AnonymousIteratorHandleOp);
+REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
+ AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
IteratorGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_GPU),
+ IteratorGetNextOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
+ IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2").Device(DEVICE_CPU),
+ IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")
+ .Device(DEVICE_GPU)
+ .HostMemory("string_handle"),
+ IteratorFromStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
SerializeIteratorOp);
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 703ef194a1..004f153af6 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -189,14 +189,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- batch_results_((params.dataset->num_parallel_calls_ +
- params.dataset->batch_size_ - 1) /
- params.dataset->batch_size_) {
- for (int i = 0; i < batch_results_.size(); ++i) {
- batch_results_[i].Initialize(params.dataset->batch_size_);
- }
- }
+ : DatasetIterator<Dataset>(params) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -216,17 +209,23 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock external_l(external_mu_);
- mutex_lock l(mu_);
- EnsureRunnerThreadStarted(ctx);
- BatchResult* result = &batch_results_[ComputeIndex(input_batch_)];
- WaitForBatch(result, &l);
- return ProcessBatch(ctx, result, out_tensors, end_of_sequence);
+ std::shared_ptr<BatchResult> result;
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (batch_results_.empty() ||
+ batch_results_.front()->num_calls > 0) {
+ cond_var_.wait(l);
+ }
+ std::swap(result, batch_results_.front());
+ batch_results_.pop_front();
+ }
+ cond_var_.notify_all();
+ return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
@@ -236,10 +235,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("call_counter"), call_counter_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_batch"), input_batch_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("output_batch"), output_batch_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_results_size"),
batch_results_.size()));
for (size_t i = 0; i < batch_results_.size(); ++i) {
@@ -250,19 +245,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock external_l(external_mu_);
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("input_batch"), &input_batch_));
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("output_batch"), &output_batch_));
int64 batch_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_results_size"),
&batch_results_size));
- CHECK_EQ(batch_results_.size(), batch_results_size);
for (int i = 0; i < batch_results_size; ++i) {
TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
}
@@ -271,21 +260,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
struct BatchResult {
- mutex mu;
- bool end_of_input GUARDED_BY(mu);
- int64 num_elements GUARDED_BY(mu);
- std::vector<Tensor> output;
- bool output_allocated GUARDED_BY(mu);
- Status status GUARDED_BY(mu);
- // Used for coordination between the main thread and the callback
- // threads. In particular, the main thread will wait for the value
- // of `num_calls` to reach zero before processing the batch result.
- condition_variable cond_var; // access guarded by owner's mutex
- // Counts the number of outstanding calls for this batch.
- int64 num_calls; // access guarded by owner's mutex
-
- void Initialize(int64 batch_size) {
- mutex_lock l(mu);
+ explicit BatchResult(int64 batch_size) {
end_of_input = false;
num_calls = batch_size;
num_elements = 0;
@@ -297,12 +272,21 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu);
status.Update(s);
}
+
+ mutex mu;
+ bool end_of_input GUARDED_BY(mu);
+ int64 num_elements GUARDED_BY(mu);
+ std::vector<Tensor> output;
+ bool output_allocated GUARDED_BY(mu);
+ Status status GUARDED_BY(mu);
+ // Counts the number of outstanding calls for this batch.
+ int64 num_calls; // access guarded by owner's mutex
};
void Callback(const std::shared_ptr<IteratorContext>& ctx,
- BatchResult* result, std::vector<Tensor>* return_values,
- int64 offset, const Status& status) {
- std::unique_ptr<std::vector<Tensor>> cleanup_retvals(return_values);
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values,
+ int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -329,40 +313,42 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
break;
}
}
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
}
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
+ CallCompleted(result);
+ }
+
+ void CallCompleted(const std::shared_ptr<BatchResult>& result)
+ LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- CallCompleted(result);
+ num_calls_--;
+ result->num_calls--;
}
- }
-
- void CallCompleted(BatchResult* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- num_calls_--;
cond_var_.notify_all();
- result->num_calls--;
- result->cond_var.notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
- BatchResult* result, int64 offset) {
+ const std::shared_ptr<BatchResult>& result,
+ int64 offset) LOCKS_EXCLUDED(mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
Status status =
input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
+ bool return_early;
{
- mutex_lock l(mu_);
- mutex_lock l2(result->mu);
+ mutex_lock l(result->mu);
result->end_of_input = result->end_of_input || end_of_input;
result->status.Update(status);
- if (result->end_of_input || !result->status.ok()) {
- CallCompleted(result);
- return;
- }
+ return_early = result->end_of_input || !result->status.ok();
+ }
+ if (return_early) {
+ CallCompleted(result);
+ return;
}
// Call `captured_func_(input_element)`, using `Callback` to store the
@@ -370,9 +356,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
(*ctx->runner())(std::bind(
[this, result, offset](std::shared_ptr<IteratorContext> ctx,
std::vector<Tensor> input_element) {
- std::vector<Tensor>* return_values = new std::vector<Tensor>();
+ std::shared_ptr<std::vector<Tensor>> return_values(
+ new std::vector<Tensor>());
dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values,
+ ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
});
@@ -380,14 +367,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx, std::move(input_element)));
}
- int64 ComputeIndex(int64 n) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return n % batch_results_.size();
- }
-
Status CopyPartialBatch(Tensor* output, const Tensor& value,
int64 num_elements) {
switch (value.dtype()) {
-#define CASE(type) \
+#define HANDLE_TYPE(type) \
case DataTypeToEnum<type>::value: { \
auto output_t = output->flat_outer_dims<type>(); \
auto value_t = value.flat_outer_dims<type>(); \
@@ -396,10 +379,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
} \
return Status::OK(); \
}
- TF_CALL_NUMBER_TYPES(CASE);
- TF_CALL_string(CASE);
- TF_CALL_variant(CASE);
-#undef CASE
+ TF_CALL_DATASET_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
default:
return errors::InvalidArgument("Unsupported data type: ",
value.dtype());
@@ -417,9 +398,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
}
- void EnsureOutputAllocated(const std::shared_ptr<IteratorContext>& ctx,
- BatchResult* result,
- const std::vector<Tensor>* return_values) {
+ void EnsureOutputAllocated(
+ const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<BatchResult>& result,
+ const std::shared_ptr<std::vector<Tensor>>& return_values) {
mutex_lock l(result->mu);
if (result->output_allocated) {
return;
@@ -437,93 +419,100 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- Status ProcessBatch(IteratorContext* ctx, BatchResult* result,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- auto cleanup =
- gtl::MakeCleanup([this, result]() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- result->Initialize(dataset()->batch_size_);
- input_batch_++;
- cond_var_.notify_all();
- });
+ int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ }
+
+ Status ProcessResult(IteratorContext* ctx,
+ const std::shared_ptr<BatchResult>& result,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
mutex_lock l(result->mu);
if (result->num_elements == 0) {
*end_of_sequence = true;
return Status::OK();
}
-
- if (!result->status.ok()) {
+ // `f` may deliberately raise `errors::OutOfRange` to indicate that we
+ // should terminate the iteration early.
+ if (!result->status.ok() && !errors::IsOutOfRange(result->status)) {
// Deallocate tensors allocated for the output.
result->output.clear();
- } else {
- if (result->num_elements < dataset()->batch_size_) {
- if (dataset()->drop_remainder_) {
- // Deallocate tensors allocated for the output.
- result->output.clear();
- *end_of_sequence = true;
- return Status::OK();
- }
- const std::vector<Tensor>& output = result->output;
- for (size_t i = 0; i < output.size(); ++i) {
- TensorShape component_shape(result->output[i].shape());
- component_shape.set_dim(0, result->num_elements);
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- Tensor component(ctx->allocator(attr), output[i].dtype(),
- component_shape);
- TF_RETURN_IF_ERROR(CopyPartialBatch(&component, output[i],
- result->num_elements));
- out_tensors->emplace_back(std::move(component));
- }
+ *end_of_sequence = false;
+ return result->status;
+ }
+ if (result->num_elements < dataset()->batch_size_) {
+ if (dataset()->drop_remainder_) {
// Deallocate tensors allocated for the output.
result->output.clear();
- } else {
- *out_tensors = std::move(result->output);
+ *end_of_sequence = true;
+ return Status::OK();
}
- *end_of_sequence = false;
+ const std::vector<Tensor>& output = result->output;
+ for (size_t i = 0; i < output.size(); ++i) {
+ TensorShape component_shape(result->output[i].shape());
+ component_shape.set_dim(0, result->num_elements);
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ Tensor component(ctx->allocator(attr), output[i].dtype(),
+ component_shape);
+ TF_RETURN_IF_ERROR(
+ CopyPartialBatch(&component, output[i], result->num_elements));
+ out_tensors->emplace_back(std::move(component));
+ }
+ // Deallocate tensors allocated for the output.
+ result->output.clear();
+ } else {
+ *out_tensors = std::move(result->output);
}
- return result->status;
+ *end_of_sequence = result->num_elements == 0;
+ return Status::OK();
}
- void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- mutex_lock l(mu_);
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
+ LOCKS_EXCLUDED(mu_) {
+ std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
+ new_calls.reserve(dataset()->num_parallel_calls_);
while (true) {
- while (!cancelled_ &&
- (num_calls_ == dataset()->num_parallel_calls_ ||
- (output_batch_ - input_batch_ == batch_results_.size()))) {
- cond_var_.wait(l);
- }
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ (num_calls_ >= dataset()->num_parallel_calls_ ||
+ batch_results_.size() > MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ == 0))) {
+ cond_var_.wait(l);
+ }
- if (cancelled_) {
- return;
- }
+ if (cancelled_) {
+ return;
+ }
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (output_batch_ - input_batch_ < batch_results_.size())) {
- BatchResult* result = &batch_results_[ComputeIndex(output_batch_)];
- int64 offset = call_counter_++ % dataset()->batch_size_;
- num_calls_++;
- mu_.unlock();
- CallFunction(ctx, result, offset);
- mu_.lock();
- if (offset + 1 == dataset()->batch_size_) {
- // Done scheduling calls for the current batch.
- output_batch_++;
+ while (num_calls_ < dataset()->num_parallel_calls_ &&
+ (batch_results_.size() < MaxBatchResults() ||
+ (batch_results_.size() == MaxBatchResults() &&
+ call_counter_ % dataset()->batch_size_ != 0))) {
+ if (call_counter_ % dataset()->batch_size_ == 0) {
+ batch_results_.emplace_back(
+ new BatchResult(dataset()->batch_size_));
+ }
+ int64 offset = call_counter_++ % dataset()->batch_size_;
+ new_calls.emplace_back(batch_results_.back(), offset);
+ num_calls_++;
}
}
- }
- }
- void WaitForBatch(BatchResult* result, mutex_lock* l)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- while (result->num_calls > 0) {
- result->cond_var.wait(*l);
+ for (const auto& call : new_calls) {
+ CallFunction(ctx, call.first, call.second);
+ }
+ new_calls.clear();
}
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- BatchResult* result = &batch_results_[index];
+ batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
result->end_of_input = reader->Contains(
@@ -585,7 +574,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- BatchResult* result = &batch_results_[index];
+ std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
if (result->end_of_input) {
@@ -646,21 +635,13 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
- // Used for serializing external parallelism.
- mutex external_mu_ ACQUIRED_BEFORE(mu_);
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
int64 call_counter_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
- // Identifies the next batch to be read by the caller.
- int64 input_batch_ GUARDED_BY(mu_) = 0;
- // Identifies the next batch to create.
- int64 output_batch_ GUARDED_BY(mu_) = 0;
- // Circular buffer for storing the (intermediate) batch results. When
- // using `input_batch_` and `output_batch_` to index into the buffer,
- // their value should be interpreted modulo the size of the buffer.
- std::vector<BatchResult> batch_results_ GUARDED_BY(mu_);
+ // Buffer for storing the (intermediate) batch results.
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
};
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
new file mode 100644
index 0000000000..81be69105e
--- /dev/null
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -0,0 +1,232 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class OptimizeDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> optimizations;
+ OP_REQUIRES_OK(
+ ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
+ Dataset* dataset =
+ new Dataset(ctx, optimizations, output_types_, output_shapes_);
+ OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, input));
+ *output = dataset;
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& optimizations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ optimizations_(optimizations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {}
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
+ }
+
+ Status Optimize(OpKernelContext* ctx, const DatasetBase* input) {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* input_node = nullptr;
+ TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input, &input_node));
+ string output_node = input_node->name();
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
+ flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
+ graph_def.library()));
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+ std::vector<Tensor> outputs;
+ GraphRunner graph_runner(ctx->function_library()->device());
+ TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
+ {output_node}, &outputs));
+ TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &input_));
+ input_->Ref();
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* optimizations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, optimizations_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = ctx->lib();
+ params.function_library = dataset()->flib_def_;
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext iter_ctx(params);
+ return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
+ string* output_node) {
+ // Add a fake sink node to allow rewriting the actual sink node.
+ NodeDef* node = graph_def->mutable_node()->Add();
+ node->set_name("FakeSink");
+ node->set_op("SinkDataset");
+ node->add_input(*output_node);
+
+ // Create metagraph.
+ MetaGraphDef meta_graph_def;
+ (*meta_graph_def.mutable_graph_def()) = *graph_def;
+
+ // Grappler determines fetch ops from collection 'train_op'.
+ CollectionDef collection_def;
+ auto node_list = collection_def.mutable_node_list();
+ node_list->add_value("FakeSink");
+ (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
+
+ // Create Grappler item.
+ tensorflow::RewriterConfig rewriter_config;
+ for (const string& optimization : optimizations_) {
+ rewriter_config.add_optimizers(optimization);
+ }
+ // If no optimizations were specified, supply a non-existent
+ // optimization to prevent Grappler from applying the default set of
+ // optimizations as some of them do not work out of the box at the
+ // moment (e.g. because we have no cost model for dataset ops).
+ if (optimizations_.empty()) {
+ rewriter_config.add_optimizers("non-existent");
+ }
+ tensorflow::grappler::ItemConfig item_config;
+ item_config.apply_optimizations = true;
+ std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
+ tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+ "graph", meta_graph_def, item_config);
+ std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+ tensorflow::grappler::VirtualCluster cluster(device_map);
+
+ // Run optimizer.
+ TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
+ *grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
+
+ // Set `output_node` to the input of the fake sink node.
+ {
+ grappler::GraphView graph(graph_def);
+ grappler::GraphView::InputPort input_port =
+ graph.GetInputPort("FakeSink", 0);
+ *output_node = graph.GetRegularFanin(input_port).node->name();
+ }
+
+ return Status::OK();
+ }
+
+ DatasetBase* input_;
+ std::shared_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::vector<string> optimizations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
+ OptimizeDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index d9e43ace39..59cbdb655d 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -28,7 +28,8 @@ namespace {
class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit PaddedBatchDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
+ : UnaryDatasetOpKernel(ctx),
+ op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
@@ -39,6 +40,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
ctx, batch_size > 0,
errors::InvalidArgument("Batch size must be greater than zero."));
+ bool drop_remainder = false;
+ if (op_version_ > 1) {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "drop_remainder",
+ &drop_remainder));
+ }
+
OpInputList padded_shape_tensors;
OP_REQUIRES_OK(ctx,
ctx->input_list("padded_shapes", &padded_shape_tensors));
@@ -85,18 +92,20 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.push_back(tensor::DeepCopy(padding_value_t));
}
- *output = new Dataset(ctx, batch_size, std::move(padded_shapes),
- std::move(padding_values), input);
+ *output =
+ new Dataset(ctx, batch_size, drop_remainder, std::move(padded_shapes),
+ std::move(padding_values), input);
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 batch_size,
+ Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input)
: GraphDatasetBase(ctx),
batch_size_(batch_size),
+ drop_remainder_(drop_remainder),
padded_shapes_(std::move(padded_shapes)),
padding_values_(std::move(padding_values)),
input_(input) {
@@ -112,8 +121,13 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
const auto& input_shapes = input_->output_shapes();
output_shapes_.reserve(input_shapes.size());
for (size_t i = 0; i < input_shapes.size(); ++i) {
- output_shapes_.push_back(
- PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
+ if (drop_remainder_) {
+ output_shapes_.push_back(
+ PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
+ } else {
+ output_shapes_.push_back(
+ PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
+ }
}
}
@@ -166,16 +180,19 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padding_values.emplace_back(node);
}
+ Node* drop_remainder = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
+
AttrValue output_types;
b->BuildAttrValue(output_dtypes(), &output_types);
AttrValue N;
b->BuildAttrValue<int64>(padded_shapes_.size(), &N);
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {{0, input_graph_node}, {1, batch_size}},
- {{2, padded_shapes}, {3, padding_values}},
- {{"Toutput_types", output_types}, {"N", N}}, output));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
+ {{2, padded_shapes}, {3, padding_values}},
+ {{"Toutput_types", output_types}, {"N", N}}, output));
return Status::OK();
}
@@ -226,6 +243,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ if (dataset()->drop_remainder_ &&
+ batch_elements.size() < dataset()->batch_size_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
// Copy the retrieved batch elements into one output tensor
// per tuple component.
// NOTE(mrry): If the input or output sizes are statically
@@ -341,16 +364,22 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
};
const int64 batch_size_;
+ const bool drop_remainder_;
const std::vector<PartialTensorShape> padded_shapes_;
const std::vector<Tensor> padding_values_;
const DatasetBase* const input_;
std::vector<PartialTensorShape> output_shapes_;
};
+
+ const int op_version_;
};
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
PaddedBatchDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
+ PaddedBatchDatasetOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 3fa6b0d3a9..15f3dc3b1d 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -151,8 +151,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- invocation_results_(params.dataset->num_parallel_calls_) {}
+ : DatasetIterator<Dataset>(params) {}
~Iterator() override {
// TODO(mrry): Replace this cancellation logic with a
@@ -160,13 +159,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
// but it would be possible to thread a cancellation manager
// through the IteratorContext to upstream,
// potentially-blocking iterators, when we add these.
- {
- mutex_lock l(mu_);
- for (size_t i = 0; i < dataset()->num_parallel_calls_; ++i) {
- if (invocation_results_[i].notification) {
- invocation_results_[i].notification->WaitForNotification();
- }
- }
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
}
@@ -177,173 +176,191 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock l(mu_);
-
- // Ensure that there are `dataset()->num_parallel_calls_`
- // invocations of `func_` outstanding at once.
- while (input_impl_ && (num_inputs_consumed_ - num_outputs_consumed_ <
- dataset()->num_parallel_calls_)) {
- InvokeFunctionLocked(ctx);
- }
-
- if (!input_impl_ && num_inputs_consumed_ == num_outputs_consumed_) {
- *end_of_sequence = true;
- return Status::OK();
- }
-
- // Read the next result out of `invocation_results_`, which
- // acts as a circular buffer.
- const size_t result_index =
- num_outputs_consumed_ % dataset()->num_parallel_calls_;
- InvocationResult* result = &invocation_results_[result_index];
- *end_of_sequence = false;
- if (result->notification) {
- result->notification->WaitForNotification();
- if (result->status.ok()) {
- std::swap(*out_tensors, result->return_values);
+ std::shared_ptr<InvocationResult> result;
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty()) {
+ cond_var_.wait(l);
}
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
}
- ++num_outputs_consumed_;
- if (errors::IsOutOfRange(result->status)) {
- // `f` may deliberately raise `errors::OutOfRange` to indicate
- // that we should terminate the iteration early.
- *end_of_sequence = true;
- return Status::OK();
- } else {
- return result->status;
- }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ return ProcessResult(result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (input_impl_) {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- } else {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("end_of_input"), ""));
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
}
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_inputs_consumed"),
- num_inputs_consumed_));
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name("num_outputs_consumed"), num_outputs_consumed_));
-
- for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) {
- if (invocation_results_[i].notification) {
- invocation_results_[i].notification->WaitForNotification();
- TF_RETURN_IF_ERROR(
- WriteStatusLocked(writer, i, invocation_results_[i].status));
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i, "].size")),
- invocation_results_[i].return_values.size()));
- for (size_t j = 0; j < invocation_results_[i].return_values.size();
- j++) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(
- strings::StrCat("invocation_results[", i, "][", j, "]")),
- invocation_results_[i].return_values[j]));
- }
- } else {
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i, "]_empty")),
+ full_name(strings::StrCat("invocation_results[", i,
+ "].end_of_input")),
""));
}
}
-
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
- if (reader->Contains(full_name("end_of_input"))) {
- input_impl_.reset();
- } else {
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- }
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_inputs_consumed"),
- &num_inputs_consumed_));
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_outputs_consumed"),
- &num_outputs_consumed_));
- for (size_t i = 0; i < dataset()->num_parallel_calls_; i++) {
- InvocationResult* result = &invocation_results_[i];
- *result = InvocationResult();
- if (!reader->Contains(full_name(
- strings::StrCat("invocation_results[", i, "]_empty")))) {
- result->notification.reset(new Notification);
- result->notification->Notify();
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
- size_t num_return_values;
- {
- int64 size;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name(strings::StrCat(
- "invocation_results[", i, "].size")),
- &size));
- num_return_values = static_cast<size_t>(size);
- if (num_return_values != size) {
- return errors::InvalidArgument(strings::StrCat(
- full_name(
- strings::StrCat("invocation_results[", i, "].size")),
- ": ", size, " is not a valid value of type size_t."));
- }
- }
- result->return_values.reserve(num_return_values);
- for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(reader->ReadTensor(
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
full_name(
- strings::StrCat("invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
}
}
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->end_of_input = reader->Contains(full_name(
+ strings::StrCat("invocation_results[", i, "].end_of_input")));
+ result->notification.Notify();
}
return Status::OK();
}
private:
struct InvocationResult {
+ Notification notification;
Status status;
- std::unique_ptr<Notification> notification;
std::vector<Tensor> return_values;
+ bool end_of_input;
};
- void InvokeFunctionLocked(IteratorContext* ctx)
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- DCHECK(input_impl_);
- DCHECK(num_inputs_consumed_ - num_outputs_consumed_ <
- dataset()->num_parallel_calls_);
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&Iterator::RunnerThread, this, ctx_copy)));
+ }
+ }
- // The result of invoking the function will be written into the next
- // slot in `invocation_results_`, which acts as a circular buffer.
- const size_t result_index =
- num_inputs_consumed_ % dataset()->num_parallel_calls_;
- InvocationResult* result = &invocation_results_[result_index];
- *result = InvocationResult();
+ void CallCompleted(const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ num_calls_--;
+ }
+ result->notification.Notify();
+ cond_var_.notify_all();
+ }
+ void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
- bool end_of_input = false;
- result->status =
- input_impl_->GetNext(ctx, &input_element, &end_of_input);
- if (end_of_input) {
- input_impl_.reset();
- result->status = errors::OutOfRange("");
- } else {
- ++num_inputs_consumed_;
+ result->status = input_impl_->GetNext(ctx.get(), &input_element,
+ &result->end_of_input);
+ if (result->end_of_input || !result->status.ok()) {
+ CallCompleted(result);
+ return;
}
- if (result->status.ok()) {
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification`
- // to unblock a consumer.
- result->notification.reset(new Notification);
- dataset()->captured_func_->RunAsync(
- ctx, std::move(input_element), &result->return_values,
- [result, result_index](Status ret_status) {
- result->status.Update(ret_status);
- result->notification->Notify();
- });
+ // Call `func_(input_element)`, store the result in
+ // `result->return_values`, and notify `result->notification` to unblock
+ // a consumer.
+ auto done = [this, result](Status status) {
+ result->status.Update(status);
+ CallCompleted(result);
+ };
+ dataset()->captured_func_->RunAsync(ctx.get(), std::move(input_element),
+ &result->return_values, done);
+ }
+
+ int64 MaxInvocationResults() { return dataset()->num_parallel_calls_; }
+
+ Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ if (!result->end_of_input && result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (errors::IsOutOfRange(result->status)) {
+ // `f` may deliberately raise `errors::OutOfRange` to indicate that we
+ // should terminate the iteration early.
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ *end_of_sequence = result->end_of_input;
+ return result->status;
+ }
+
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ std::vector<std::shared_ptr<InvocationResult>> new_calls;
+ new_calls.reserve(dataset()->num_parallel_calls_);
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ (num_calls_ >= dataset()->num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+ if (cancelled_) {
+ return;
+ }
+ while (num_calls_ < dataset()->num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ new_calls.push_back(invocation_results_.back());
+ num_calls_++;
+ }
+ }
+ cond_var_.notify_all();
+ for (const auto& call : new_calls) {
+ CallFunction(ctx, call);
+ }
+ new_calls.clear();
}
}
@@ -386,11 +403,22 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
strings::StrCat("invocation_results[", index, "].error_message"));
}
+ // Used for coordination between the main thread and the runner thread.
mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::vector<InvocationResult> invocation_results_ GUARDED_BY(mu_);
- int64 num_inputs_consumed_ GUARDED_BY(mu_) = 0;
- int64 num_outputs_consumed_ GUARDED_BY(mu_) = 0;
+ // Used for coordination between the main thread and the runner thread. In
+ // particular, the runner thread should only schedule new calls when the
+ // number of in-flight calls is less than the user specified level of
+ // parallelism and there are slots available in the `invocation_results_`
+ // buffer.
+ condition_variable cond_var_;
+ // Counts the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<IteratorBase> input_impl_;
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index e2b6aa590e..cc16108dce 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -39,8 +39,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(ctx,
- buffer_size > 0 || buffer_size == PrefetchAutotuner::kAutoTune,
- errors::InvalidArgument("buffer_size must be > 0"));
+ buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune,
+ errors::InvalidArgument("buffer_size must be >= 0"));
*output = new Dataset(ctx, input, buffer_size);
}
@@ -112,13 +112,13 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
-
- while (true) {
+ {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
- while (!cancelled_ && !prefetch_thread_finished_ && buffer_.empty()) {
+ while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
+ auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
cond_var_.wait(l);
}
@@ -129,29 +129,20 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
}
if (!buffer_.empty()) {
- // A new element is available. Forward the status from
- // computing it, and (if we successfully got an element)
- // the output values.
- Status s = buffer_.front().status;
- if (s.ok()) {
- *out_tensors = std::move(buffer_.front().value);
- }
- auto_tuner_.RecordConsumption(buffer_.size());
- buffer_.pop_front();
- *end_of_sequence = false;
-
- // Wake the prefetch thread, in case it has been waiting
- // for space in the buffer.
- // Also wake up threads from other calls to GetNext.
- // TODO(mrry): Consider using different condition variables
- // for GetNext and Prefetch.
- cond_var_.notify_all();
- return s;
- } else if (prefetch_thread_finished_) {
+ return Consume(out_tensors, end_of_sequence);
+ }
+
+ if (prefetch_thread_finished_) {
*end_of_sequence = true;
return Status::OK();
}
+
+ DCHECK_EQ(auto_tuner_.buffer_limit(), 0);
}
+
+ mutex_lock parent_l(parent_mu_);
+ mutex_lock l(mu_);
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
protected:
@@ -227,6 +218,26 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> value;
};
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // A new element is available. Forward the status from computing it, and
+ // (if we successfully got an element) the output values.
+ Status s = buffer_.front().status;
+ if (s.ok()) {
+ *out_tensors = std::move(buffer_.front().value);
+ }
+ buffer_.pop_front();
+ *end_of_sequence = false;
+
+ // Wake the prefetch thread, in case it has been waiting for space
+ // in the buffer. Also wake up threads from other calls to GetNext.
+ //
+ // TODO(mrry): Consider using different condition variables for
+ // GetNext and Prefetch.
+ cond_var_.notify_all();
+ return s;
+ }
+
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
@@ -251,7 +262,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
{
mutex_lock l(mu_);
while (!cancelled_ &&
- buffer_.size() == auto_tuner_.buffer_limit()) {
+ buffer_.size() >= auto_tuner_.buffer_limit()) {
cond_var_.wait(l);
}
@@ -346,7 +357,12 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
-
+REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
+ .Device(DEVICE_GPU)
+ .HostMemory("buffer_size")
+ .HostMemory("input_dataset")
+ .HostMemory("handle"),
+ PrefetchDatasetOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 3438199ebd..b859295fa4 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -61,10 +61,12 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
}
protected:
- class Iterator : public DatasetIterator<ShuffleDatasetBase> {
+ template <class T>
+ class Iterator : public DatasetIterator<T> {
public:
- explicit Iterator(const Params& params, int64 seed, int64 seed2)
- : DatasetIterator<ShuffleDatasetBase>(params),
+ explicit Iterator(const typename DatasetIterator<T>::Params& params,
+ int64 seed, int64 seed2)
+ : DatasetIterator<T>(params),
input_impl_(nullptr),
seed_(seed),
seed2_(seed2),
@@ -85,26 +87,28 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
bool first_call = false;
if (!input_impl_ && epoch_ == 0) {
first_call = true;
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
}
- while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
+ while (input_impl_ && num_elements_ < this->dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
num_log_entries++;
LOG(INFO) << "Filling up shuffle buffer (this may take a while): "
- << num_elements_ << " of " << dataset()->buffer_size_;
+ << num_elements_ << " of "
+ << this->dataset()->buffer_size_;
}
std::vector<Tensor> input_element;
bool end_of_input_sequence = false;
- while (dataset()->count_ == -1 || epoch_ < dataset()->count_) {
+ while (this->dataset()->count_ == -1 ||
+ epoch_ < this->dataset()->count_) {
TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
&end_of_input_sequence));
if (!end_of_input_sequence) {
first_call = false;
break;
}
- if (first_call && dataset()->count_ == -1) {
+ if (first_call && this->dataset()->count_ == -1) {
// If the first call to GetNext() fails because the end
// of sequence has been reached, we terminate the
// iteration immediately. (Otherwise, this iterator
@@ -115,11 +119,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
epoch_++;
int64 n = slices_.back()->end;
slices_.emplace_back(new Slice{n, n});
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
}
if (!end_of_input_sequence) {
- buffer_[slices_.back()->end % dataset()->buffer_size_] =
+ buffer_[slices_.back()->end % this->dataset()->buffer_size_] =
std::move(input_element);
num_elements_++;
slices_.back()->end++;
@@ -144,10 +148,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
int64 offset =
Random() % (slices_.front()->end - slices_.front()->start);
int64 index =
- (slices_.front()->start + offset) % dataset()->buffer_size_;
+ (slices_.front()->start + offset) % this->dataset()->buffer_size_;
*out_tensors = std::move(buffer_[index]);
- std::swap(buffer_[index],
- buffer_[slices_.front()->start % dataset()->buffer_size_]);
+ std::swap(
+ buffer_[index],
+ buffer_[slices_.front()->start % this->dataset()->buffer_size_]);
slices_.front()->start++;
num_elements_--;
} else {
@@ -160,40 +165,44 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
-
// Save state needed to restore the random number generators.
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
- num_random_samples_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ this->full_name("num_random_samples"), num_random_samples_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("seed"), seed_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(this->full_name("seed2"), seed2_));
// Save input iterator if it hasn't been exhausted else write
// "end_of_input_sequence".
if (!input_impl_) {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("end_of_input_sequence"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ this->full_name("end_of_input_sequence"), ""));
} else {
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ TF_RETURN_IF_ERROR(this->SaveParent(writer, input_impl_));
}
// Save the epoch counter, buffer, and buffer slices.
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("epoch"), epoch_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("num_elements"), num_elements_));
TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("slices_size"), slices_.size()));
+ writer->WriteScalar(this->full_name("epoch"), epoch_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("num_elements"),
+ num_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name("slices_size"),
+ slices_.size()));
for (size_t i = 0; i < slices_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("slices_start_", i)),
+ this->full_name(strings::StrCat("slices_start_", i)),
slices_[i]->start));
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("slices_end_", i)), slices_[i]->end));
+ this->full_name(strings::StrCat("slices_end_", i)),
+ slices_[i]->end));
for (size_t j = slices_[i]->start; j < slices_[i]->end; ++j) {
- size_t index = j % dataset()->buffer_size_;
+ size_t index = j % this->dataset()->buffer_size_;
TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("buffer_", index, "_size")),
+ this->full_name(strings::StrCat("buffer_", index, "_size")),
buffer_[index].size()));
for (size_t k = 0; k < buffer_[index].size(); ++k) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat("buffer_", index, "_", k)),
+ this->full_name(strings::StrCat("buffer_", index, "_", k)),
buffer_[index][k]));
}
}
@@ -205,51 +214,54 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
-
// Restore the random number generators.
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
- &num_random_samples_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ this->full_name("num_random_samples"), &num_random_samples_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("seed"), &seed_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(this->full_name("seed2"), &seed2_));
ResetRngs();
// Restore the input iterator if it wasn't already exhausted.
- if (!reader->Contains(full_name("end_of_input_sequence"))) {
- TF_RETURN_IF_ERROR(
- dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ if (!reader->Contains(this->full_name("end_of_input_sequence"))) {
+ TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
+ ctx, this->prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(this->RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
}
// Restore the epoch counter, buffer, and buffer slices.
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("epoch"), &epoch_));
TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("num_elements"), &num_elements_));
+ reader->ReadScalar(this->full_name("epoch"), &epoch_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(this->full_name("num_elements"),
+ &num_elements_));
size_t slices_size;
{
int64 temp;
TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("slices_size"), &temp));
+ reader->ReadScalar(this->full_name("slices_size"), &temp));
slices_size = static_cast<size_t>(temp);
}
- buffer_.reset(new std::vector<Tensor>[dataset()->buffer_size_]);
+ buffer_.reset(new std::vector<Tensor>[this->dataset()->buffer_size_]);
for (size_t i = 0; i < slices_size; ++i) {
int64 start;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("slices_start_", i)), &start));
+ this->full_name(strings::StrCat("slices_start_", i)), &start));
int64 end;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("slices_end_", i)), &end));
+ this->full_name(strings::StrCat("slices_end_", i)), &end));
slices_.emplace_back(new Slice{start, end});
for (size_t j = start; j < end; ++j) {
- size_t index = j % dataset()->buffer_size_;
+ size_t index = j % this->dataset()->buffer_size_;
int64 list_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("buffer_", index, "_size")),
+ this->full_name(strings::StrCat("buffer_", index, "_size")),
&list_size));
buffer_[index] = std::vector<Tensor>(list_size);
for (int k = 0; k < list_size; ++k) {
TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat("buffer_", index, "_", k)),
+ this->full_name(strings::StrCat("buffer_", index, "_", k)),
&buffer_[index][k]));
}
}
@@ -289,8 +301,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
mutex mu_;
std::unique_ptr<std::vector<Tensor>[]> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- const int64 seed_ GUARDED_BY(mu_);
- const int64 seed2_ GUARDED_BY(mu_);
+ int64 seed_ GUARDED_BY(mu_);
+ int64 seed2_ GUARDED_BY(mu_);
int64 epoch_ GUARDED_BY(mu_);
int64 num_elements_ GUARDED_BY(mu_);
std::deque<std::unique_ptr<Slice>> slices_ GUARDED_BY(mu_);
@@ -360,6 +372,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
generator_(&parent_generator_) {}
string DebugString() const override {
+ mutex_lock l(mu_);
return strings::StrCat("ShuffleDatasetOp(", buffer_size_, ", ", seed_,
", ", seed2_, ")::ReshufflingDataset");
}
@@ -370,38 +383,96 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
int64 iterator_seed2;
{
mutex_lock l(mu_);
- iterator_seed = generator_();
- iterator_seed2 = generator_();
+ iterator_seed = Random();
+ iterator_seed2 = Random();
}
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::Shuffle")}, iterator_seed,
- iterator_seed2));
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Shuffle")},
+ iterator_seed, iterator_seed2));
}
protected:
+ class Iterator : public ShuffleDatasetBase::Iterator<ReshufflingDataset> {
+ public:
+ explicit Iterator(const Params& params, int64 seed, int64 seed2)
+ : ShuffleDatasetBase::Iterator<ReshufflingDataset>(params, seed,
+ seed2) {}
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(dataset()->mu_);
+
+ // Save RNG state of Dataset.
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("ds_num_random_samples"),
+ dataset()->num_random_samples_));
+
+ // Save the Iterator.
+ return ShuffleDatasetBase::Iterator<ReshufflingDataset>::SaveInternal(
+ writer);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(dataset()->mu_);
+
+ // Restore RNG state of Dataset.
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("ds_num_random_samples"),
+ &dataset()->num_random_samples_));
+ dataset()->ResetRngs();
+
+ // Restore the Iterator.
+ return ShuffleDatasetBase::Iterator<
+ ReshufflingDataset>::RestoreInternal(ctx, reader);
+ }
+ };
+
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented(
- "Checkpointing ShufflingDataset with reshuffle_each_iteration=true "
- "is not supported.\n"
- "If you have a ds.shuffle(buffer_size).repeat(count) in your input "
- "pipeline, replace it with "
- "ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count)).\n"
- "If you iterate over your dataset once, change shuffle(buffer_size) "
- "to shuffle(buffer_size, reshuffle_each_iteration=False).\n"
- "If you are using Dataset.list_files(pattern), change it to "
- "Dataset.list_files(pattern, shuffle=False) and manually shuffle "
- "the list of files using shuffle_and_repeat as above or using "
- "ds.shuffle with reshuffle_each_iteration=False.");
+ mutex_lock l(mu_);
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* buffer_size = nullptr;
+ Node* seed = nullptr;
+ Node* seed2 = nullptr;
+ AttrValue reshuffle_each_iteration;
+
+ TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+ b->BuildAttrValue(true, &reshuffle_each_iteration);
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, buffer_size, seed, seed2}, // Inputs
+ {std::make_pair("reshuffle_each_iteration",
+ reshuffle_each_iteration)}, // Attrs
+ output));
+ return Status::OK();
}
private:
- const int64 seed_;
- const int64 seed2_;
+ random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random() const
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ num_random_samples_++;
+ auto out = generator_();
+ return out;
+ }
+
+ void ResetRngs() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // Reset the generators based on the current seeds.
+ parent_generator_ = random::PhiloxRandom(seed_, seed2_);
+ generator_ =
+ random::SingleSampleAdapter<random::PhiloxRandom>(&parent_generator_);
+ generator_.Skip(num_random_samples_);
+ }
+
+ mutable int64 seed_ GUARDED_BY(mu_);
+ mutable int64 seed2_ GUARDED_BY(mu_);
mutable mutex mu_;
mutable random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
mutable random::SingleSampleAdapter<random::PhiloxRandom> generator_
GUARDED_BY(mu_);
+ mutable int64 num_random_samples_ GUARDED_BY(mu_) = 0;
};
// A dataset that uses the same fixed seed for all iterators created from it.
@@ -421,8 +492,9 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
+ return std::unique_ptr<IteratorBase>(
+ new ShuffleDatasetBase::Iterator<ShuffleDatasetBase>(
+ {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
}
protected:
@@ -504,9 +576,10 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
- {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
- seed2_));
+ return std::unique_ptr<IteratorBase>(
+ new ShuffleDatasetBase::Iterator<ShuffleDatasetBase>(
+ {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
+ seed2_));
}
protected:
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index 48776cbf61..07cc91f9d5 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
@@ -32,16 +33,24 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- int64 stride = 1;
+ int64 stride = 0;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- OP_REQUIRES(
- ctx, stride > 0 && stride < window_size,
- errors::InvalidArgument("Stride must be in [1, window_size)."));
+ OP_REQUIRES(ctx, stride > 0,
+ errors::InvalidArgument("Stride must be greater than zero."));
+ if (stride == window_size) {
+ LOG(WARNING) << "stride: " << stride
+ << " is equal to window_size: " << window_size
+ << ", to use `batch` instead.";
+ } else if (stride > window_size) {
+ LOG(WARNING) << "stride: " << stride
+ << " is greater than window_size: " << window_size
+ << ", you will lose some data.";
+ }
*output = new Dataset(ctx, window_size, stride, input);
}
@@ -124,12 +133,15 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
batch_elements.reserve(window_size);
- const bool first_call = cache_.empty();
- if (first_call) {
- cache_.reserve(window_size);
- } else {
- // Reuse cache in the previous iteration.
- cache_.swap(batch_elements);
+ // Use cache if stride < window_size.
+ if (stride < window_size) {
+ const bool first_call = cache_.empty();
+ if (first_call) {
+ cache_.reserve(window_size);
+ } else {
+ // Reuse cache in the previous iteration.
+ cache_.swap(batch_elements);
+ }
}
// Fill up with new elements.
*end_of_sequence = false;
@@ -149,9 +161,22 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
DCHECK(*end_of_sequence);
return Status::OK();
}
- // Cache the data used for the next iteration.
- for (size_t i = stride; i < window_size; ++i) {
- cache_.emplace_back(batch_elements[i]);
+
+ if (stride < window_size) {
+ // Cache the data used for the next iteration.
+ for (size_t i = stride; i < window_size; ++i) {
+ cache_.emplace_back(batch_elements[i]);
+ }
+ } else if (stride > window_size) {
+ // Drop the data before the next iteration.
+ std::vector<Tensor> batch_element_tuple;
+ for (size_t i = window_size; i < stride && !*end_of_sequence; ++i) {
+ TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
+ end_of_sequence));
+ if (*end_of_sequence) {
+ input_impl_.reset();
+ }
+ }
}
}
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index 33a56b2eb5..b133cfab54 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -20,11 +20,25 @@ limitations under the License.
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/lib/monitoring/counter.h"
+#include "tensorflow/core/lib/monitoring/gauge.h"
+#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace {
+static mutex* get_counters_map_lock() {
+ static mutex counters_map_lock(LINKER_INITIALIZED);
+ return &counters_map_lock;
+}
+
+static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
+ static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
+ new std::unordered_map<string, monitoring::Counter<1>*>;
+ return counters_map;
+}
+
class StatsAggregatorImpl : public StatsAggregator {
public:
StatsAggregatorImpl() {}
@@ -61,6 +75,21 @@ class StatsAggregatorImpl : public StatsAggregator {
}
}
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ mutex_lock l(*get_counters_map_lock());
+ auto counters_map = get_counters_map();
+ if (counters_map->find(name) == counters_map->end()) {
+ counters_map->emplace(
+ name, monitoring::Counter<1>::New(
+ /*streamz name*/ "/tensorflow/" + name,
+ /*streamz description*/
+ name + " generated or consumed by the component.",
+ /*streamz label name*/ "component_descriptor"));
+ }
+ counters_map->at(name)->GetCell(label)->IncrementBy(val);
+ }
+
private:
mutex mu_;
std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 7370a24b38..a537e7e68f 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
@@ -234,6 +236,200 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
+class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit FeatureStatsDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ OP_REQUIRES(ctx, input->output_dtypes()[0] == DT_STRING,
+ errors::InvalidArgument("FeatureStatsDataset only supports "
+ "input with a single `tf.string` "
+ "component."));
+ *output = new Dataset(ctx, input, std::move(tag));
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
+ : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::FeatureStatsDataset")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "FeatureStatsDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* tag_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ tf_shared_lock l(mu_);
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ auto stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator && s.ok() && !*end_of_sequence) {
+ for (const Tensor& t : *out_tensors) {
+ auto record_t = t.flat<string>();
+ Example example;
+ // TODO(shivaniagrawal): redundant parsing here, potential solutions
+ // to improve performance is to a) have a potential
+ // ParseExampleDataset and collect stats from there and b) make
+ // changes to parse_example() where it returns stats as well.
+ for (int i = 0; i < record_t.size(); ++i) {
+ if (example.ParseFromString(record_t(i))) {
+ stats_aggregator->IncrementCounter("examples_count", "trainer",
+ 1);
+ AddStatsFeatures(example, stats_aggregator);
+ } else {
+ SequenceExample sequence_example;
+ if (sequence_example.ParseFromString(record_t(i))) {
+ stats_aggregator->IncrementCounter("sequence_examples_count",
+ "trainer", 1);
+ AddStatsFeatures(sequence_example, stats_aggregator);
+ }
+ }
+ }
+ }
+ }
+ return s;
+ }
+
+ // TODO(shivaniagrawal): Add features/feature-values to streamz metrics.
+ int AddStatsFeatureValues(const Feature& feature) {
+ int feature_values_list_size = 0;
+ switch (feature.kind_case()) {
+ case Feature::kBytesList: {
+ feature_values_list_size = feature.bytes_list().value().size();
+ break;
+ }
+ case Feature::kFloatList: {
+ feature_values_list_size = feature.float_list().value().size();
+ break;
+ }
+ case Feature::kInt64List: {
+ feature_values_list_size = feature.int64_list().value().size();
+ break;
+ }
+ case Feature::KIND_NOT_SET:
+ break;
+ }
+ return feature_values_list_size;
+ }
+
+ void AddStatsFeatures(
+ const Example& example,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":features"),
+ {static_cast<double>(example.features().feature().size())});
+
+ int feature_values_list_size_sum = 0;
+ for (const auto& feature : example.features().feature()) {
+ stats_aggregator->IncrementCounter("features_count", "trainer", 1);
+ feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
+ }
+ stats_aggregator->IncrementCounter("feature_values_count", "trainer",
+ feature_values_list_size_sum);
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":feature-values"),
+ {static_cast<double>(feature_values_list_size_sum)});
+ }
+
+ void AddStatsFeatures(
+ const SequenceExample& example,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":features"),
+ {static_cast<double>(
+ example.context().feature().size() +
+ example.feature_lists().feature_list().size())});
+
+ int feature_values_list_size_sum = 0;
+ for (const auto& feature : example.context().feature()) {
+ stats_aggregator->IncrementCounter("features_count", "trainer", 1);
+ feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
+ }
+
+ for (const auto& feature_list :
+ example.feature_lists().feature_list()) {
+ stats_aggregator->IncrementCounter("feature_lists_count", "reainer",
+ 1);
+ for (const auto& feature : feature_list.second.feature()) {
+ feature_values_list_size_sum += AddStatsFeatureValues(feature);
+ }
+ }
+ stats_aggregator->IncrementCounter("feature_values_count", "trainer",
+ feature_values_list_size_sum);
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(dataset()->tag_, ":feature-values"),
+ {static_cast<double>(feature_values_list_size_sum)});
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ const string tag_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("FeatureStatsDataset").Device(DEVICE_CPU),
+ FeatureStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 668b461374..17551bccd9 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -17,6 +17,7 @@ limitations under the License.
namespace tensorflow {
namespace {
+// TODO(b/110981596): Support checkpointing.
class WindowDataset : public DatasetBase {
public:
WindowDataset(std::vector<std::vector<Tensor>> elements,
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 97c31668ac..7bd31a0bc7 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -31,7 +31,7 @@ namespace tensorflow {
//
// This dataset is constructed internally for use in datasets that
// build nested dataset expressions (e.g. the reducer function for
-// GroupByBatchDataset). It efficiently supports multiple iterators on
+// GroupByWindowDataset). It efficiently supports multiple iterators on
// the same window without recomputation.
//
// REQUIRES: `output_types` must match the types of the respective
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
new file mode 100644
index 0000000000..0283e5697b
--- /dev/null
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/window_dataset.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class WindowDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit WindowDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 window_size = 0;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES(
+ ctx, window_size > 0,
+ errors::InvalidArgument("Window size must be greater than zero."));
+
+ *output = new Dataset(ctx, window_size, input);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ : GraphDatasetBase(ctx), window_size_(window_size), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ Iterator::Params{this, strings::StrCat(prefix, "::Window")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* output_dtypes = new DataTypeVector({DT_VARIANT});
+ return *output_dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* output_shapes =
+ new std::vector<PartialTensorShape>({TensorShape({})});
+ return *output_shapes;
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* window_size = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, window_size}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ // Each row of `window_elements` is a tuple of tensors from the
+ // input iterator.
+ std::vector<std::vector<Tensor>> window_elements;
+ {
+ mutex_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ window_elements.reserve(dataset()->window_size_);
+ *end_of_sequence = false;
+ for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
+ ++i) {
+ std::vector<Tensor> window_element_tuple;
+ TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
+ end_of_sequence));
+ if (!*end_of_sequence) {
+ window_elements.emplace_back(std::move(window_element_tuple));
+ } else {
+ input_impl_.reset();
+ }
+ }
+ }
+
+ if (window_elements.empty()) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ const size_t num_tuple_components = window_elements[0].size();
+ const int64 num_window_elements = window_elements.size();
+ for (size_t idx = 0; idx < num_tuple_components; ++idx) {
+ DatasetBase* window_dataset;
+ std::vector<std::vector<Tensor>> window_component_elements;
+ window_component_elements.reserve(num_window_elements);
+ // Build the output tuple component by copying one slice
+ // from each input element in the window.
+ for (size_t i = 0; i < num_window_elements; ++i) {
+ std::vector<Tensor> component_element;
+ component_element.push_back(std::move(window_elements[i][idx]));
+ window_component_elements.push_back(component_element);
+ }
+ DataTypeVector output_types(
+ {dataset()->input_->output_dtypes()[idx]});
+ std::vector<PartialTensorShape> output_shapes(
+ {dataset()->input_->output_shapes()[idx]});
+ TF_RETURN_IF_ERROR(NewWindowDataset(window_component_elements,
+ output_types, output_shapes,
+ &window_dataset));
+ out_tensors->emplace_back(DT_VARIANT, TensorShape({}));
+ TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
+ &out_tensors->back()));
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (!input_impl_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ } else {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const int64 window_size_;
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
+ WindowDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc
index 85a9702ae7..1aa8c72d66 100644
--- a/tensorflow/core/kernels/deep_conv2d.cc
+++ b/tensorflow/core/kernels/deep_conv2d.cc
@@ -393,8 +393,9 @@ struct TransformFilters {
// Calculate filter transform batch based on cache/filter sizes.
- // Cache budget (based on L2 cache size).
- const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
+ // Cache budget (based on L2 cache size = 256KB).
+ // TODO(andydavis) Read cache size from system.
+ const int64 cache_size = (256LL << 10) / sizeof(T);
// Fixed cost.
const int64 filter_transform_matrix_size =
@@ -1017,8 +1018,9 @@ struct DeepConv2D<CPUDevice, T> {
const int64 filter_shard_size = filter_shards_row * filter_shards_col;
const int64 out_tile_spatial_size = out_tile_rows * out_tile_cols;
- // Cache budget (based on L2 cache size).
- const int64 cache_size = Eigen::l2CacheSize() / sizeof(T);
+ // Cache budget (based on L2 cache size = 256KB).
+ // TODO(andydavis) Read cache size from the system.
+ const int64 cache_size = (256LL << 10) / sizeof(T);
// Fixed costs.
const int64 tile_transform_matrix_size =
diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc
index 0de97de205..f942b1a8a9 100644
--- a/tensorflow/core/kernels/dense_update_ops.cc
+++ b/tensorflow/core/kernels/dense_update_ops.cc
@@ -98,6 +98,8 @@ typedef Eigen::SyclDevice SYCLDevice;
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS);
+// quint16 not included in QUANTZIED_TYPES
+TF_CALL_quint16(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
new file mode 100644
index 0000000000..6fb07c11e9
--- /dev/null
+++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc
@@ -0,0 +1,293 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/kernels/reshape_util.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+namespace {
+
+using sparse::SparseTensor;
+
+class DeserializeSparseOp : public OpKernel {
+ public:
+ explicit DeserializeSparseOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& serialized_sparse = context->input(0);
+ const int ndims = serialized_sparse.shape().dims();
+
+ OP_REQUIRES(
+ context, ndims > 0,
+ errors::InvalidArgument("Serialized sparse should have non-zero rank ",
+ serialized_sparse.shape().DebugString()));
+
+ OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
+ errors::InvalidArgument(
+ "Serialized sparse should have 3 as the last dimension ",
+ serialized_sparse.shape().DebugString()));
+
+ int num_sparse_tensors = 1;
+ for (int i = 0; i < ndims - 1; ++i) {
+ num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
+ }
+
+ OP_REQUIRES(
+ context, num_sparse_tensors > 0,
+ errors::InvalidArgument(
+ "Serialized sparse should have at least 1 serialized tensor, "
+ "but has a zero dimension ",
+ serialized_sparse.shape().DebugString()));
+
+ if (num_sparse_tensors == 1 && ndims == 1) {
+ // Special case with a single sparse tensor. We can avoid data
+ // motion in the Concat and Reshape.
+ const auto& serialized_sparse_t = serialized_sparse.vec<string>();
+
+ Tensor output_indices;
+ Tensor output_values;
+ Tensor output_shape;
+ OP_REQUIRES_OK(context,
+ this->GetAndValidateSparseTensor(
+ serialized_sparse_t(0), serialized_sparse_t(1),
+ serialized_sparse_t(2), dtype_, 0 /* index */,
+ &output_indices, &output_values, &output_shape));
+ context->set_output(0, output_indices);
+ context->set_output(1, output_values);
+ context->set_output(2, output_shape);
+ return;
+ }
+
+ std::vector<Tensor> indices;
+ std::vector<Tensor> values;
+ TensorShape shape;
+ indices.reserve(num_sparse_tensors);
+ values.reserve(num_sparse_tensors);
+
+ const auto& serialized_sparse_t =
+ serialized_sparse.flat_inner_dims<string, 2>();
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ Tensor output_indices;
+ Tensor output_values;
+ Tensor output_shape;
+ OP_REQUIRES_OK(context,
+ this->GetAndValidateSparseTensor(
+ serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
+ serialized_sparse_t(i, 2), dtype_, i, &output_indices,
+ &output_values, &output_shape));
+ int64 num_entries = output_indices.dim_size(0);
+ int rank = output_indices.dim_size(1);
+
+ // Now we expand each SparseTensors' indices and shape by
+ // prefixing a dimension
+ Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
+ const auto& output_indices_t = output_indices.matrix<int64>();
+ auto expanded_indices_t = expanded_indices.matrix<int64>();
+ expanded_indices_t.chip<1>(0).setZero();
+ if (rank > 0) {
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
+ Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
+ expanded_indices_t.slice(indices_start, indices_sizes) =
+ output_indices_t;
+ }
+ Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
+ const auto& output_shape_t = output_shape.vec<int64>();
+ auto expanded_shape_t = expanded_shape.vec<int64>();
+ expanded_shape_t(0) = 1;
+ std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
+
+ TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
+
+ indices.push_back(expanded_indices);
+ values.push_back(output_values);
+ if (i == 0) {
+ shape = expanded_tensor_shape;
+ } else {
+ OP_REQUIRES(
+ context, shape.dims() == expanded_tensor_shape.dims(),
+ errors::InvalidArgument(
+ "Inconsistent shape across SparseTensors: rank prior to "
+ "SparseTensor[",
+ i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
+ "] is: ", expanded_tensor_shape.dims() - 1));
+ for (int j = 1; j < shape.dims(); ++j) {
+ // NOTE(mrry): For compatibility with the implementations of
+ // DeserializeManySparse, and many ops that generate
+ // SparseTensors to batch that do not have a fixed
+ // dense_shape (e.g. `tf.parse_single_example()`), we
+ // compute the maximum in each dimension to find the
+ // smallest dense_shape that bounds all of the input
+ // SparseTensors.
+ shape.set_dim(j, std::max(shape.dim_size(j),
+ expanded_tensor_shape.dim_size(j)));
+ }
+ }
+ }
+
+ // Dimension 0 is the primary dimension.
+ int rank = shape.dims();
+ gtl::InlinedVector<int64, 8> std_order(rank);
+ std::iota(std_order.begin(), std_order.end(), 0);
+
+ std::vector<SparseTensor> tensors;
+ tensors.reserve(num_sparse_tensors);
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ tensors.emplace_back(indices[i], values[i], shape, std_order);
+ }
+
+ gtl::optional<SparseTensor> maybe_output;
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ maybe_output = SparseTensor::Concat<T>(tensors); \
+ break; \
+ }
+
+ switch (dtype_) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+ TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ OP_REQUIRES(context, false,
+ errors::Unimplemented(
+ "DeserializeSparse Unhandled data type: ", dtype_));
+ }
+ DCHECK(maybe_output);
+ SparseTensor& output = maybe_output.value();
+
+ // Compute the input shape for the reshape operation.
+ Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
+ std::copy_n(output.shape().data(), output.dims(),
+ input_shape.vec<int64>().data());
+
+ // Compute the target shape for the reshape operation.
+ Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
+ for (int i = 0; i < ndims - 1; ++i) {
+ target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
+ }
+ for (int i = 0; i < output.dims() - 1; ++i) {
+ target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
+ }
+
+ Tensor output_indices;
+ Tensor output_shape;
+ Reshape(context, output.indices(), input_shape, target_shape,
+ 0 /* output indices index */, 2 /* output shape index */);
+ context->set_output(1, output.values());
+ }
+
+ private:
+ Status Deserialize(const string& serialized, Tensor* result) {
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, serialized)) {
+ return errors::InvalidArgument("Could not parse serialized proto");
+ }
+ Tensor tensor;
+ if (!tensor.FromProto(proto)) {
+ return errors::InvalidArgument("Could not construct tensor from proto");
+ }
+ *result = tensor;
+ return Status::OK();
+ }
+
+ Status GetAndValidateSparseTensor(
+ const string& serialized_indices, const string& serialized_values,
+ const string& serialized_shape, DataType values_dtype, int index,
+ Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
+ // Deserialize and validate the indices.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
+ if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to represent an index matrix but received shape ",
+ output_indices->shape().DebugString());
+ }
+ int64 num_entries = output_indices->dim_size(0);
+ int rank = output_indices->dim_size(1);
+
+ // Deserialize and validate the values.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
+ if (!TensorShapeUtils::IsVector(output_values->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to represent a values vector but received shape ",
+ output_values->shape().DebugString());
+ }
+ if (values_dtype != output_values->dtype()) {
+ return errors::InvalidArgument(
+ "Requested SparseTensor of type ", DataTypeString(values_dtype),
+ " but SparseTensor[", index,
+ "].values.dtype() == ", DataTypeString(output_values->dtype()));
+ }
+ if (num_entries != output_values->dim_size(0)) {
+ return errors::InvalidArgument(
+ "Expected row counts of SparseTensor[", index,
+ "].indices and SparseTensor[", index,
+ "].values to match but they do not: ", num_entries, " vs. ",
+ output_values->dim_size(0));
+ }
+
+ // Deserialize and validate the shape.
+ TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
+ if (!TensorShapeUtils::IsVector(output_shape->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to be a shape vector but its shape is ",
+ output_shape->shape().DebugString());
+ }
+ if (rank != output_shape->dim_size(0)) {
+ return errors::InvalidArgument("Expected column counts of SparseTensor[",
+ index,
+ "].indices to match size of SparseTensor[",
+ index, "].shape but they do not: ", rank,
+ " vs. ", output_shape->dim_size(0));
+ }
+ return Status::OK();
+ }
+
+ DataType dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<string>("Tserialized"),
+ DeserializeSparseOp)
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
+ DeserializeSparseOp)
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc
new file mode 100644
index 0000000000..fce3029e4e
--- /dev/null
+++ b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc
@@ -0,0 +1,372 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+namespace {
+
+class DeserializeSparseOp : public OpKernel {
+ public:
+ explicit DeserializeSparseOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ OP_REQUIRES(
+ context, input.dims() > 0,
+ errors::InvalidArgument("Serialized sparse should have non-zero rank ",
+ input.shape().DebugString()));
+ OP_REQUIRES(context, input.shape().dim_size(input.dims() - 1) == 3,
+ errors::InvalidArgument(
+ "Serialized sparse should have 3 as the last dimension ",
+ input.shape().DebugString()));
+
+ // `input_dims_to_stack` is the number of dimensions that will be added to
+ // each of the elements before they are concatenated into the output.
+ const int64 input_dims_to_stack = input.dims() - 1;
+ int num_sparse_tensors = 1;
+ for (int i = 0; i < input_dims_to_stack; ++i) {
+ num_sparse_tensors *= input.shape().dim_size(i);
+ }
+
+ if (num_sparse_tensors == 1 && input_dims_to_stack == 0) {
+ // Special case with a single sparse tensor, and no dimensions to add
+ // to the output indices. We can return the boxed tensors directly (after
+ // validating them).
+ const Tensor* output_indices;
+ const Tensor* output_values;
+ const Tensor* output_shape;
+ const auto& input_as_vec = input.vec<Variant>();
+ int64 total_non_zeros;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_vec(1), input_as_vec(2), 0,
+ &output_shape, &total_non_zeros));
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorIndicesAndValues(
+ input_as_vec(0), input_as_vec(1), 0,
+ output_shape->NumElements(), &output_indices,
+ &output_values));
+ context->set_output(0, *output_indices);
+ context->set_output(1, *output_values);
+ context->set_output(2, *output_shape);
+ return;
+ }
+
+ OP_REQUIRES(
+ context, num_sparse_tensors > 0,
+ errors::InvalidArgument(
+ "Serialized sparse should have at least 1 serialized tensor, "
+ "but has a zero dimension ",
+ input.shape().DebugString()));
+
+ const auto& input_as_matrix = input.flat_inner_dims<Variant, 2>();
+
+ // Compute the output "dense shape" of and number of non-zero elements in
+ // the stacked sparse tensors. Given an input of shape (S_0, ...,
+ // S_{input_dims_to_stack-1}, 3), and an element of dense shape (E_0, ...
+ // E_n), the output dense shape will be (S_0, ...,
+ // S_{input_dims_to_stack-1}, E_0, ..., E_n).
+ Tensor* output_shape;
+ int64 total_non_zeros = 0;
+
+ // Allocate and build the initial output shape based on the element shape of
+ // the 0th sparse tensor in the input.
+ //
+ // NOTE(mrry): We define `element_shape` as a `const Tensor*` rather than a
+ // `Tensor` to avoid the overhead of allocating and deallocating a `Tensor`
+ // on the stack. While the per-`Tensor` cost is small, this op can unbox a
+ // large number of tensors (3 per batch element) and these fixed overheads
+ // dominate when the number of non-zeros per element is small.
+ const Tensor* element_shape;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_matrix(0, 1), input_as_matrix(0, 2), 0,
+ &element_shape, &total_non_zeros));
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 2, {input_dims_to_stack + element_shape->NumElements()},
+ &output_shape));
+ const auto element_shape_vec = element_shape->vec<int64>();
+ auto output_shape_vec = output_shape->vec<int64>();
+ output_shape_vec(0) = num_sparse_tensors;
+ for (int64 j = 0; j < input_dims_to_stack; ++j) {
+ output_shape_vec(j) = input.dim_size(j);
+ }
+ for (int64 j = 0; j < element_shape->NumElements(); ++j) {
+ output_shape_vec(j + input_dims_to_stack) = element_shape_vec(j);
+ }
+
+ // Accumulate the number of non-zero elements from the remaining sparse
+ // tensors, and validate that they have compatible dense shapes.
+ //
+ // NOTE(mrry): For compatibility with the implementations of
+ // DeserializeManySparse, and many ops that generate SparseTensors to batch
+ // that do not have a fixed dense_shape (e.g. `tf.parse_single_example()`),
+ // we compute the maximum in each dimension to find the smallest dense_shape
+ // that bounds all of the input SparseTensors.
+ for (int i = 1; i < num_sparse_tensors; ++i) {
+ int64 num_non_zeros;
+ OP_REQUIRES_OK(context, GetAndValidateSparseTensorShape(
+ input_as_matrix(i, 1), input_as_matrix(i, 2),
+ i, &element_shape, &num_non_zeros));
+ total_non_zeros += num_non_zeros;
+ OP_REQUIRES(
+ context,
+ output_shape->NumElements() - input_dims_to_stack ==
+ element_shape->NumElements(),
+ errors::InvalidArgument(
+ "Inconsistent shape across SparseTensors: rank prior to "
+ "SparseTensor[",
+ i, "] was: ", output_shape->NumElements() - input_dims_to_stack,
+ " but rank of SparseTensor[", i,
+ "] is: ", element_shape->NumElements()));
+ const auto element_shape_vec = element_shape->vec<int64>();
+ for (int j = 0; j < element_shape->NumElements(); ++j) {
+ output_shape_vec(j + input_dims_to_stack) = std::max(
+ output_shape_vec(j + input_dims_to_stack), element_shape_vec(j));
+ }
+ }
+
+ // Compute the output "indices" matrix and "values" vector.
+ Tensor* output_indices;
+ Tensor* output_values;
+
+ const int output_rank = output_shape->NumElements();
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, {static_cast<int64>(total_non_zeros), output_rank},
+ &output_indices));
+ OP_REQUIRES_OK(
+ context, context->allocate_output(
+ 1, {static_cast<int64>(total_non_zeros)}, &output_values));
+
+ // The bulk of the work in this method involves building the output indices
+ // in a tight loop. For cache friendliness, we generate the indices in the
+ // order that they will be laid out in memory. We use raw pointers instead
+ // of Eigen element/slice indexing methods, to access the underlying index
+ // buffer to minimize the amount of work in that tight loop.
+ int64* output_indices_data = output_indices->matrix<int64>().data();
+ size_t current_row = 0;
+
+ for (int i = 0; i < num_sparse_tensors; ++i) {
+ const Tensor* element_indices;
+ const Tensor* element_values;
+ OP_REQUIRES_OK(context, this->GetAndValidateSparseTensorIndicesAndValues(
+ input_as_matrix(i, 0), input_as_matrix(i, 1),
+ i, output_rank - input_dims_to_stack,
+ &element_indices, &element_values));
+
+ const size_t num_index_rows = element_values->NumElements();
+
+ // An empty sparse tensor in the input will generate no data
+ // in the output. We short-circuit the rest of the iteration to avoid
+ // triggering assertions in the Eigen when manipulating empty tensors (or
+ // slices of tensors).
+ if (num_index_rows == 0) continue;
+
+ const size_t start_row = current_row;
+ const size_t next_start_row = current_row + num_index_rows;
+
+ // NOTE(mrry): If the element is a scalar SparseTensor,
+ // `element_indices` will be an empty tensor, and this pointer will not
+ // be valid. However, we will not dereference the pointer in that case,
+ // because `input_dims_to_stack == output_rank`.
+ const int64* element_indices_data =
+ element_indices->matrix<int64>().data();
+
+ // Build the submatrix of `output_indices` for the i^th sparse tensor
+ // in the input.
+ //
+ // Each row of `output_indices` comprises `input_dims_to_stack` indices
+ // based on the position of the i^th sparse tensor in the input tensor,
+ // followed by the indices from the corresponding row in
+ // `element_indices`.
+ if (input_dims_to_stack == 1 && output_rank == 2) {
+ // We specialize this case because the compiler can generate
+ // more efficient code when the number of indices for each element is
+ // known statically. Since the most common use of this op is to
+ // serialize batches of SparseTensors, and the most common source of
+ // SparseTensors is the `tf.parse_single_example()` op, which generates
+ // 1-D SparseTensors, we statically unroll the loop for the rank 2
+ // output case.
+ for (; current_row < next_start_row; ++current_row) {
+ *output_indices_data++ = i;
+ *output_indices_data++ = *element_indices_data++;
+ }
+ } else {
+ // `sparse_tensor_index` is the tuple of indices that correspond to
+ // mapping the flat element index (`i`) back onto the stacked
+ // coordinates implied by the position of the i^th sparse tensor in the
+ // input tensor.
+ //
+ // We build `sparse_tensor_index` in reverse (innermost/minor dimension
+ // to outermost/major dimension). The `cumulative_product` represents
+ // the size of the inner subtensor for which `sparse_tensor_index` has
+ // already been built.
+ gtl::InlinedVector<int64, 4> sparse_tensor_index(input_dims_to_stack);
+ int cumulative_product = 1;
+ for (size_t j = 0; j < sparse_tensor_index.size(); ++j) {
+ size_t reverse_index = sparse_tensor_index.size() - j - 1;
+ sparse_tensor_index[reverse_index] =
+ (i / cumulative_product) % input.dim_size(reverse_index);
+ cumulative_product *= input.dim_size(reverse_index);
+ }
+ for (; current_row < next_start_row; ++current_row) {
+ for (int64 sparse_tensor_index_component : sparse_tensor_index) {
+ *output_indices_data++ = sparse_tensor_index_component;
+ }
+ for (size_t k = input_dims_to_stack; k < output_rank; ++k) {
+ *output_indices_data++ = *element_indices_data++;
+ }
+ }
+ }
+
+ // Build the subvector of `output_values` for the i^th sparse tensor
+ // in the input.
+ //
+ // NOTE(mrry): There is a potential optimization here where we use a T*
+ // to represent the current position in `output_values`, but it would
+ // require some rejigging of the template parameters.
+ // NOTE(mrry): Another potential optimization: if we know that this
+ // operation consumes its input, we could std::move non-primitive elements
+ // into the output and avoid a copy.
+ Eigen::DSizes<Eigen::DenseIndex, 1> values_start(start_row);
+ Eigen::DSizes<Eigen::DenseIndex, 1> values_sizes(num_index_rows);
+
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ output_values->vec<T>().slice(values_start, values_sizes) = \
+ element_values->vec<T>(); \
+ break; \
+ }
+ switch (dtype_) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+ TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ OP_REQUIRES_OK(
+ context, errors::Unimplemented(
+ "DeserializeSparse Unhandled data type: ", dtype_));
+ }
+ }
+ }
+
+ private:
+ Status GetAndValidateSparseTensorShape(const Variant& serialized_values,
+ const Variant& serialized_shape,
+ int index, const Tensor** output_shape,
+ int64* output_num_non_zeros) {
+ // Deserialize and validate the shape.
+ *output_shape = serialized_shape.get<Tensor>();
+ if (*output_shape == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 2]");
+ }
+ if ((*output_shape)->dtype() != DT_INT64) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 2] to be a vector of DT_INT64 but received dtype ",
+ DataTypeString((*output_shape)->dtype()));
+ }
+ if (!TensorShapeUtils::IsVector((*output_shape)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 2] to be a shape vector but its shape is ",
+ (*output_shape)->shape().DebugString());
+ }
+ *output_num_non_zeros = serialized_values.get<Tensor>()->NumElements();
+ return Status::OK();
+ }
+
+ Status GetAndValidateSparseTensorIndicesAndValues(
+ const Variant& serialized_indices, const Variant& serialized_values,
+ int index, int expected_rank, const Tensor** output_indices,
+ const Tensor** output_values) {
+ // Deserialize and validate the indices.
+ *output_indices = serialized_indices.get<Tensor>();
+ if (*output_indices == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 0]");
+ }
+ if ((*output_indices)->dtype() != DT_INT64) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to be a matrix of DT_INT64 but received dtype ",
+ DataTypeString((*output_indices)->dtype()));
+ }
+ if (!TensorShapeUtils::IsMatrix((*output_indices)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 0] to represent an index matrix but received shape ",
+ (*output_indices)->shape().DebugString());
+ }
+ int64 num_entries = (*output_indices)->dim_size(0);
+ int rank = (*output_indices)->dim_size(1);
+ if (rank != expected_rank) {
+ return errors::InvalidArgument(
+ "Expected column counts of SparseTensor[", index,
+ "].indices to match size of SparseTensor[", index,
+ "].shape but they do not: ", rank, " vs. ", expected_rank);
+ }
+
+ // Deserialize and validate the values.
+ *output_values = serialized_values.get<Tensor>();
+ if (*output_values == nullptr) {
+ return errors::InvalidArgument(
+ "Could not get a tensor from serialized_sparse[", index, ", 1]");
+ }
+ if (!TensorShapeUtils::IsVector((*output_values)->shape())) {
+ return errors::InvalidArgument(
+ "Expected serialized_sparse[", index,
+ ", 1] to represent a values vector but received shape ",
+ (*output_values)->shape().DebugString());
+ }
+ if (dtype_ != (*output_values)->dtype()) {
+ return errors::InvalidArgument(
+ "Requested SparseTensor of type ", DataTypeString(dtype_),
+ " but SparseTensor[", index,
+ "].values.dtype() == ", DataTypeString((*output_values)->dtype()));
+ }
+ if (num_entries != (*output_values)->dim_size(0)) {
+ return errors::InvalidArgument(
+ "Expected row counts of SparseTensor[", index,
+ "].indices and SparseTensor[", index,
+ "].values to match but they do not: ", num_entries, " vs. ",
+ (*output_values)->dim_size(0));
+ }
+
+ return Status::OK();
+ }
+
+ DataType dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<Variant>("Tserialized"),
+ DeserializeSparseOp)
+
+} // namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h
index 2f83780525..56de6b1d43 100644
--- a/tensorflow/core/kernels/eigen_pooling.h
+++ b/tensorflow/core/kernels/eigen_pooling.h
@@ -372,16 +372,23 @@ struct reducer_traits<AvgPoolMeanReducer<float>, Device> {
Cost = 1,
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
// We only support packet access for floats.
- PacketAccess = true
+ PacketAccess = true,
#else
- PacketAccess = false
+ PacketAccess = false,
#endif
+ IsStateful = true,
+ IsExactlyAssociative = false
};
};
template <>
struct reducer_traits<AvgPoolMeanReducer<float>, GpuDevice> {
- enum { Cost = 1, PacketAccess = false };
+ enum {
+ Cost = 1,
+ PacketAccess = false,
+ IsStateful = true,
+ IsExactlyAssociative = false
+ };
};
} // namespace internal
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index a23478af5b..d6e859f1aa 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
+ : TypedQueueOp(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
+}
+
+Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
+ FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
+ component_shapes_, cinfo_.name());
+ return CreateTypedQueue(queue, ret);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index f01d70924d..697ee81c39 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/typed_queue.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -69,6 +70,22 @@ class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueue);
};
+// Defines a FIFOQueueOp, which produces a Queue (specifically, one
+// backed by FIFOQueue) that persists across different graph
+// executions, and sessions. Running this op produces a single-element
+// tensor of handles to Queues in the corresponding device.
+class FIFOQueueOp : public TypedQueueOp {
+ public:
+ explicit FIFOQueueOp(OpKernelConstruction* context);
+
+ private:
+ Status CreateResource(QueueInterface** ret) override
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ std::vector<TensorShape> component_shapes_;
+ TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FIFO_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/fifo_queue_op.cc b/tensorflow/core/kernels/fifo_queue_op.cc
index b35bdbb2f0..80869768f1 100644
--- a/tensorflow/core/kernels/fifo_queue_op.cc
+++ b/tensorflow/core/kernels/fifo_queue_op.cc
@@ -13,50 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
-#include <deque>
-#include <vector>
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fifo_queue.h"
-#include "tensorflow/core/kernels/queue_base.h"
-#include "tensorflow/core/kernels/queue_op.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-// Defines a FIFOQueueOp, which produces a Queue (specifically, one
-// backed by FIFOQueue) that persists across different graph
-// executions, and sessions. Running this op produces a single-element
-// tensor of handles to Queues in the corresponding device.
-class FIFOQueueOp : public TypedQueueOp {
- public:
- explicit FIFOQueueOp(OpKernelConstruction* context) : TypedQueueOp(context) {
- OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
- }
-
- private:
- Status CreateResource(QueueInterface** ret) override
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
- component_shapes_, cinfo_.name());
- return CreateTypedQueue(queue, ret);
- }
-
- std::vector<TensorShape> component_shapes_;
- TF_DISALLOW_COPY_AND_ASSIGN(FIFOQueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("FIFOQueue").Device(DEVICE_CPU), FIFOQueueOp);
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index f2724735bf..d5c33c0188 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -135,6 +135,12 @@ REGISTER_KERNEL_BUILDER(Name(kArgOp)
.TypeConstraint<ResourceHandle>("T"),
ArgOp);
+REGISTER_KERNEL_BUILDER(Name(kArgOp)
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<string>("T"),
+ ArgOp);
+
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
@@ -149,6 +155,12 @@ REGISTER_KERNEL_BUILDER(Name(kRetOp)
.TypeConstraint<ResourceHandle>("T")
.HostMemory("input"),
RetvalOp);
+
+REGISTER_KERNEL_BUILDER(Name(kRetOp)
+ .Device(DEVICE_GPU)
+ .TypeConstraint<string>("T")
+ .HostMemory("input"),
+ RetvalOp);
#undef REGISTER
class PassOn : public OpKernel {
@@ -297,20 +309,28 @@ class RemoteCallOp : public AsyncOpKernel {
explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
OP_REQUIRES_OK(ctx,
ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}
~RemoteCallOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- const Tensor* target;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- const string& target_device =
- DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
-
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
+
+ const string& source_device = lib->device()->name();
+ const Tensor* target;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+ string target_device;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
+ source_device, &target_device),
+ done);
+
AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
instantiate_opts.target = target_device;
@@ -345,7 +365,7 @@ class RemoteCallOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.runner = ctx->runner();
- opts.source_device = lib->device()->name();
+ opts.source_device = source_device;
if (opts.source_device != target_device) {
opts.remote_execution = true;
}
@@ -355,6 +375,20 @@ class RemoteCallOp : public AsyncOpKernel {
for (const Tensor& argument : arguments) {
args.push_back(argument);
}
+ for (const auto& dtype : input_dtypes_) {
+ AllocatorAttributes arg_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ arg_alloc_attrs.set_on_host(true);
+ }
+ opts.args_alloc_attrs.push_back(arg_alloc_attrs);
+ }
+ for (const auto& dtype : output_dtypes_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
auto* rets = new std::vector<Tensor>;
auto* activity = new tracing::ScopedActivity(strings::StrCat(
"RemoteCall: Run: ", func_.name(), " on ", target_device));
@@ -377,6 +411,8 @@ class RemoteCallOp : public AsyncOpKernel {
private:
NameAttrList func_;
+ DataTypeVector input_dtypes_;
+ DataTypeVector output_dtypes_;
mutex mu_;
typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget;
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 9ae04a1062..519c475332 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -152,7 +152,7 @@ class IfOp : public AsyncOpKernel {
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
- done_(done),
+ done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
for (int i = 1; i < ctx_->num_inputs(); ++i) {
@@ -174,9 +174,9 @@ class IfOp : public AsyncOpKernel {
s = SetOutputs(kernel_, ctx_, rets_);
}
ctx_->SetStatus(s);
- auto done = done_;
+ DoneCallback captured_done(std::move(done_));
delete this;
- done();
+ captured_done();
});
}
@@ -184,7 +184,7 @@ class IfOp : public AsyncOpKernel {
IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
- const DoneCallback done_;
+ DoneCallback done_;
FunctionLibraryRuntime* const lib_;
FunctionLibraryRuntime::Options opts_;
TensorVec args_;
@@ -257,7 +257,7 @@ class WhileOp : public AsyncOpKernel {
ctx_(ctx),
cond_handle_(cond_handle),
body_handle_(body_handle),
- done_(done),
+ done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, false /* always_collect_stats */);
for (int i = 0; i < ctx_->num_inputs(); ++i) {
@@ -518,5 +518,24 @@ REGISTER_KERNEL_BUILDER(Name("For")
.HostMemory("delta"),
ForOp);
+class FakeParamOp : public OpKernel {
+ public:
+ explicit FakeParamOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // We must produce something (only Switch and Recvs are allowed to output
+ // dead tensors). This output is not expected to be consumed by anything.
+ Tensor output_tensor(dtype_, TensorShape({}));
+ context->set_output(0, output_tensor);
+ }
+
+ private:
+ DataType dtype_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FakeParam").Device(DEVICE_CPU), FakeParamOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index 990cbceac2..b4f81d9a70 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -51,7 +51,7 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
- Status ExportValues(OpKernelContext* context) {
+ Status ExportValues(OpKernelContext* context) override {
return errors::Unimplemented(
"ExportValues not supported by InitializableLookupTable "
"implementations");
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index 8f51cc3819..8ddf3c38e8 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -50,7 +50,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
#define CASE(type) \
case DataTypeToEnum<type>::value: \
return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
- TF_CALL_NUMBER_TYPES(CASE);
+ TF_CALL_POD_TYPES(CASE);
TF_CALL_string(CASE);
TF_CALL_variant(CASE);
#undef CASE
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index f9c15ce6d7..b596dbc782 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -488,8 +488,31 @@ class MatMulOp : public OpKernel {
return;
}
- LaunchMatMul<Device, T, USE_CUBLAS>::launch(
- ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
+ if (std::is_same<T, bfloat16>::value) {
+ bool is_cpu = std::is_same<Device, CPUDevice>::value;
+ OP_REQUIRES(ctx, is_cpu,
+ errors::Internal("bfloat16 matmul is not supported by GPU"));
+ Tensor a_float, b_float, out_float;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, a.shape(), &a_float));
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, b.shape(), &b_float));
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DT_FLOAT, out->shape(), &out_float));
+
+ // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
+ BFloat16ToFloat(a.flat<bfloat16>().data(), a_float.flat<float>().data(),
+ a.NumElements());
+ BFloat16ToFloat(b.flat<bfloat16>().data(), b_float.flat<float>().data(),
+ b.NumElements());
+
+ LaunchMatMul<Device, float, USE_CUBLAS>::launch(
+ ctx, a_float, b_float, dim_pair, &algorithms_, use_autotune_,
+ &out_float);
+ FloatToBFloat16(out_float.flat<float>().data(),
+ out->flat<bfloat16>().data(), out->NumElements());
+ } else {
+ LaunchMatMul<Device, T, USE_CUBLAS>::launch(
+ ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
+ }
}
private:
@@ -551,13 +574,15 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+
// MKL does not support half and int32 types for matrix-multiplication, so
// register the kernel to use default Eigen based implementations for these
// types. Registration for NO-LABEL version is in mkl_matmul_op.cc
TF_CALL_float(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
@@ -566,6 +591,7 @@ TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index b539b00009..3d04aeeb3e 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -24,15 +24,16 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::sum;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -333,7 +334,7 @@ class MklAddNOp : public OpKernel {
if (!input1_in_mkl_format && src1_dims_size == 0) {
Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
+ MklDnnShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
@@ -347,7 +348,7 @@ class MklAddNOp : public OpKernel {
if (!input1_in_mkl_format && !input2_in_mkl_format) {
if (src1_tensor.shape().num_elements() == 0) {
Tensor* dst_tensor = nullptr;
- MklShape mkl_shape_dst;
+ MklDnnShape mkl_shape_dst;
mkl_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(ctx, output_idx, &dst_tensor,
src1_tensor.shape(), mkl_shape_dst);
@@ -444,11 +445,10 @@ class MklAddNOp : public OpKernel {
// atleast one input is in MKL format, we choose output descriptor for
// reorder.
std::vector<primitive::at> inputs;
- std::vector<primitive> net;
// Check if actual input format of the tensor is different than common_pd
// we told MKLDNN. In that case, we will need reorder.
- src1.CheckReorderToOpMem(srcs_pd[0], &net);
- src2.CheckReorderToOpMem(srcs_pd[1], &net);
+ src1.CheckReorderToOpMem(srcs_pd[0]);
+ src2.CheckReorderToOpMem(srcs_pd[1]);
inputs.push_back(src1.GetOpMem());
inputs.push_back(src2.GetOpMem());
@@ -481,6 +481,7 @@ class MklAddNOp : public OpKernel {
dst.SetUsrMemDataHandle(dst_tensor);
// Create Sum op, and submit net for execution.
+ std::vector<primitive> net;
net.push_back(sum(sum_pd, inputs, dst.GetOpMem()));
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 723b445a75..45328b03d6 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index a9b952095d..6f490cdc23 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -27,16 +27,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::concat;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -703,14 +704,14 @@ class MklConcatOp : public OpKernel {
if (input_tensors[k].NumElements() == 0)
continue;
- auto src_dims = TFShapeToMklDnnDims(
- mkl_input_shapes[k].GetTfShape());
auto src_md = mkl_input_shapes[k].GetMklLayout();
srcs[k].SetUsrMem(src_md, &input_tensors[k]);
- if (src_md.data.format != mkl_common_format)
+ if (src_md.data.format != mkl_common_format) {
+ memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]);
src_md = memory::desc(src_dims, MklDnnType<T>(),
mkl_common_format);
+ }
srcs_pd.push_back(memory::primitive_desc(src_md, cpu_engine));
}
@@ -755,11 +756,10 @@ class MklConcatOp : public OpKernel {
}
std::vector<primitive::at> inputs;
- std::vector<primitive> net;
if (isMklReorderNeeded) {
for (int k = 0; k < input_tensors.size(); k++) {
if (input_tensors[k].NumElements() > 0) {
- srcs[k].CheckReorderToOpMem(srcs_pd[k], &net);
+ srcs[k].CheckReorderToOpMem(srcs_pd[k]);
}
}
}
@@ -805,6 +805,7 @@ class MklConcatOp : public OpKernel {
dst.SetUsrMem(dst_md, dst_tensor);
auto concat_op = concat(concat_pd, inputs, dst.GetOpMem());
+ std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
} catch (mkldnn::error& e) {
diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
index a6698a1a07..f857be6c32 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
@@ -39,8 +39,10 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index e0706568b1..4e80f5acce 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -38,9 +38,6 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
@@ -49,12 +46,319 @@ using mkldnn::convolution_backward_weights;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
-namespace tensorflow {
+#include "tensorflow/core/util/mkl_util.h"
+namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_ML
+
+struct MklConvBwdFilterParams {
+ memory::dims src_dims;
+ memory::dims diff_filter_dims;
+ memory::dims diff_bias_dims;
+ memory::dims diff_dst_dims;
+ memory::dims strides;
+ memory::dims dilations;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ padding_kind padding;
+
+ MklConvBwdFilterParams(memory::dims src_dims,
+ memory::dims diff_filter_dims, memory::dims diff_bias_dims,
+ memory::dims diff_dst_dims, memory::dims strides,
+ memory::dims dilations, memory::dims padding_left,
+ memory::dims padding_right, padding_kind padding) :
+ src_dims(src_dims), diff_filter_dims(diff_filter_dims),
+ diff_bias_dims(diff_bias_dims), diff_dst_dims(diff_dst_dims),
+ strides(strides), dilations(dilations),
+ padding_left(padding_left), padding_right(padding_right),
+ padding(padding) {
+ }
+};
+
+template <typename T>
+class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+ public:
+ explicit MklConv2DBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims) :
+ cpu_engine_(engine::cpu, 0) {
+ context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
+ // create conv primitive
+ if (context_.conv_bwd_filter == nullptr) {
+ Setup(convBwdFilterDims);
+ }
+ }
+
+ ~MklConv2DBwdFilterPrimitive() {}
+
+ // Convolution backward weights with bias
+ // src_data: input data buffer of src
+ // diff_filter_data: output data buffer of diff_filter
+ // diff_bias_data: output data buffer of diff_bias
+ // diff_dst_data: input data buffer of diff_dst
+ void Execute(const T* src_data, const T* diff_filter_data,
+ const T* diff_bias_data, const T* diff_dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_filter_data)));
+ context_.diff_bias_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_bias_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_filter_mem->set_data_handle(DummyData);
+ context_.diff_bias_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ // Convolution backward weights without bias
+ // src_data: input data buffer of src
+ // diff_filter_data: output data buffer of diff_filter
+ // diff_dst_data: input data buffer of diff_dst
+ void Execute(const T* src_data,
+ const T* diff_filter_data, const T* diff_dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_filter_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_filter_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ memory::format GetSrcMemoryFormat() const {
+ return context_.src_fmt;
+ }
+
+ memory::format GetDiffDstMemoryFormat() const {
+ return context_.diff_dst_fmt;
+ }
+
+ memory::format GetDiffFilterMemoryFormat() const {
+ return context_.diff_filter_fmt;
+ }
+
+ // convolution primitive
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+ GetPrimitiveDesc() const {
+ return context_.bwd_filter_pd;
+ }
+
+ private:
+ // Primitive reuse context for Conv2D bwd filter op
+ struct ConvBwdFilterContext {
+ // expected memory format for this primitive instance
+ memory::format src_fmt;
+ memory::format diff_dst_fmt;
+ memory::format diff_filter_fmt;
+
+ // convolution bwd input primitive
+ std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
+ bwd_filter_pd;
+ std::shared_ptr<mkldnn::primitive> conv_bwd_filter;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> diff_filter_mem;
+ std::shared_ptr<mkldnn::memory> diff_bias_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+ // desc & prmitive desc
+ std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc;
+ std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
+
+ // memory desc: forward & backward can share same memory desc
+ std::shared_ptr<mkldnn::memory::desc> src_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_filter_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_bias_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
+
+ // MKL pipeline
+ std::shared_ptr<mkldnn::stream> bwd_filter_stream;
+ std::vector<mkldnn::primitive> bwd_filter_primitives;
+
+ ConvBwdFilterContext() :
+ src_fmt(memory::format::any),
+ diff_dst_fmt(memory::format::any),
+ diff_filter_fmt(memory::format::any),
+ src_mem(nullptr), diff_filter_mem(nullptr),
+ diff_bias_mem(nullptr), diff_dst_mem(nullptr),
+ bwd_filter_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr),
+ src_md(nullptr), diff_filter_md(nullptr),
+ diff_bias_md(nullptr), diff_dst_md(nullptr),
+ bwd_filter_stream(nullptr) {
+ }
+ };
+
+ // Setup Conv2d backward filter (weights) primitives.
+ void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
+ // create memory descriptors for convolution data w/ no specified format
+ context_.src_md.reset(new memory::desc({convBwdFilterDims.src_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ context_.diff_dst_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_dst_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ context_.diff_filter_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_filter_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ if (!convBwdFilterDims.diff_bias_dims.empty())
+ context_.diff_bias_md.reset(new memory::desc(
+ {convBwdFilterDims.diff_bias_dims},
+ MklDnnType<T>(), memory::format::x));
+
+ // create a convolution
+ if (!convBwdFilterDims.diff_bias_dims.empty()) {
+ context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
+ convolution_direct, *context_.src_md, *context_.diff_filter_md,
+ *context_.diff_bias_md, *context_.diff_dst_md,
+ convBwdFilterDims.strides, convBwdFilterDims.dilations,
+ convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
+ convBwdFilterDims.padding));
+ } else {
+ context_.bwd_filter_desc.reset(
+ new convolution_backward_weights::desc(
+ convolution_direct, *context_.src_md, *context_.diff_filter_md,
+ *context_.diff_dst_md, convBwdFilterDims.strides,
+ convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
+ convBwdFilterDims.padding_right, convBwdFilterDims.padding));
+ }
+
+ // create fwd primitive_desc
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct,
+ *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md,
+ convBwdFilterDims.strides,
+ convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
+ convBwdFilterDims.padding_right, convBwdFilterDims.padding));
+ context_.fwd_pd.reset(new convolution_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create backward conv primitive_desc
+ context_.bwd_filter_pd.reset(
+ new convolution_backward_weights::primitive_desc(
+ *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
+
+ // store the expected memory format
+ auto bwd_filter_pd = context_.bwd_filter_pd.get();
+ context_.src_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->src_primitive_desc().desc().data.format);
+ context_.diff_filter_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->diff_weights_primitive_desc().desc().data.format);
+ context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
+ bwd_filter_pd->diff_dst_primitive_desc().desc().data.format);
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(
+ bwd_filter_pd->src_primitive_desc(), DummyData));
+ context_.diff_filter_mem.reset(new memory(
+ bwd_filter_pd->diff_weights_primitive_desc(), DummyData));
+ context_.diff_dst_mem.reset(new memory(
+ bwd_filter_pd->diff_dst_primitive_desc(), DummyData));
+
+ // create convolution primitive and add it to net
+ if (!convBwdFilterDims.diff_bias_dims.empty()) {
+ context_.diff_bias_mem.reset(new memory(
+ {{{convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
+ memory::format::x}, cpu_engine_}, DummyData));
+ context_.conv_bwd_filter.reset(new convolution_backward_weights(
+ *context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
+ *context_.diff_filter_mem, *context_.diff_bias_mem));
+ } else {
+ context_.conv_bwd_filter.reset(new convolution_backward_weights(
+ *context_.bwd_filter_pd, *context_.src_mem,
+ *context_.diff_dst_mem, *context_.diff_filter_mem));
+ }
+
+ context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
+ }
+
+ struct ConvBwdFilterContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklConv2DBwdFilterPrimitive<T>* Get(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
+
+ // look into the pool for reusable primitive
+ conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
+ MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
+ convBwdFilterDims));
+
+ if (conv2d_bwd_filter == nullptr) {
+ conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
+ convBwdFilterDims);
+ MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
+ convBwdFilterDims, conv2d_bwd_filter);
+ }
+ return conv2d_bwd_filter;
+ }
+
+
+ private:
+ MklConv2DBwdFilterPrimitiveFactory() {}
+ ~MklConv2DBwdFilterPrimitiveFactory() {}
+
+ static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConv2DBwdFilterPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ static std::string CreateKey(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ std::string prefix = "conv2d_bwd_filter";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(convBwdFilterDims.src_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_filter_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_bias_dims);
+ key_creator.AddAsKey(convBwdFilterDims.diff_dst_dims);
+ key_creator.AddAsKey(convBwdFilterDims.strides);
+ key_creator.AddAsKey(convBwdFilterDims.dilations);
+ key_creator.AddAsKey(convBwdFilterDims.padding_left);
+ key_creator.AddAsKey(convBwdFilterDims.padding_right);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetConv2dBwdFilter(
+ const MklConvBwdFilterParams& convBwdFilterDims) {
+ std::string key = CreateKey(convBwdFilterDims);
+ return this->GetOp(key);
+ }
+
+ void SetConv2dBwdFilter(
+ const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ std::string key = CreateKey(convBwdFilterDims);
+ this->SetOp(key, op);
+ }
+};
+
+#endif
+
#ifdef INTEL_MKL_ML
template <typename Device, class T>
@@ -440,11 +744,213 @@ class MklConv2DCustomBackpropFilterOp
: public MklConv2DBackpropCommonOp<Device, T> {
public:
explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {}
+ : MklConv2DBackpropCommonOp<Device, T>(context) {
+ }
+
~MklConv2DCustomBackpropFilterOp() {}
+ void Compute(OpKernelContext* context) {
+ try {
+ MklDnnData<T> src(&cpu_engine_);
+ MklDnnData<T> diff_dst(&cpu_engine_);
+ MklDnnData<T> diff_filter(&cpu_engine_); // output
+
+ // Input tensors
+ const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
+ const Tensor& src_tensor = MklGetInput(context, kInputIdx);
+ const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
+ const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
+
+ MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
+ GetMklShape(context, kInputIdx, &src_mkl_shape);
+ GetMklShape(context, kFilterIdx, &filter_mkl_shape);
+ GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+ // Allow operator-specific sanity checking of shapes.
+ ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
+
+ // Allow operator-specific generation of shapes.
+ // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // tensor containing shape of filter. So filter.shape() is not
+ // a correct way to get filter shape. These operator-specific calls
+ // allow this class to handle this case.
+ TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
+ TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
+ TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+
+ // Corner cases: output with 0 elements and 0 batch size.
+ Tensor* diff_filter_tensor = nullptr;
+ if (src_tf_shape.num_elements() == 0 ||
+ filter_tf_shape.num_elements() == 0 ||
+ diff_dst_tf_shape.num_elements() == 0) {
+ MklDnnShape diff_filter_mkl_shape;
+ diff_filter_mkl_shape.SetMklTensor(false);
+ TensorShape diff_filter_tf_shape = GetOutputTfShape(
+ src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
+ const int kOutputIdx = 0;
+ AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ CHECK_NOTNULL(diff_filter_tensor);
+
+ // if output tensor has more than 0 elements, we need to 0 them out.
+ auto diff_filter_data = diff_filter_tensor->flat<T>().data();
+ for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
+ diff_filter_data[i] = 0;
+ }
+ return;
+ }
+
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with prefix tf_order.
+ memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
+ memory::dims padding_left, padding_right, dilations,
+ strides, fwd_dst_dims;
+ memory::dims fwd_dst_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
+ this->data_format_, this->dilations_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
+ &strides, &dilations, &fwd_dst_dims_tf_order,
+ &fwd_dst_dims, &padding_left, &padding_right);
+ if (!context->status().ok()) return;
+
+ auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto fwd_src_md =
+ src_mkl_shape.IsMklTensor()
+ ? src_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt);
+
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ if (!context->status().ok()) return;
+
+ auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor()
+ ? diff_dst_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims,
+ MklDnnType<T>(), tf_fmt);
+
+ memory::dims diff_bias_dims = {};
+ int64 depth = 0;
+ if (biasEnabled) {
+ TensorShape obp_tf_shape = GetTfShape(context, 2);
+ depth = (this->data_format_ == FORMAT_NCHW)
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(3);
+ diff_bias_dims = {static_cast<int>(depth)};
+ }
+
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+
+ MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
+ diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
+ padding_right, TFPaddingToMklDnnPadding(this->padding_));
+ conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims);
+ auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+
+ // allocate output tensors: diff_fitler and diff_bias (w bias)
+ auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
+
+ // diff_filter
+ MklDnnShape diff_filter_mkl_shape;
+ diff_filter_mkl_shape.SetMklTensor(false);
+ // output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ Tensor* diff_bias_tensor = nullptr;
+ if (biasEnabled) {
+ TensorShape diff_bias_shape({depth});
+ AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
+ }
+
+ // check if src and diff_dst need reorder
+ std::vector<primitive> net;
+ T *src_data = nullptr;
+ if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ src.SetUsrMem(fwd_src_md, &src_tensor);
+ src.CheckReorderToOpMem(
+ bwd_filter_pd->src_primitive_desc(), &net);
+ src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
+ } else {
+ src_data = static_cast<T*>(const_cast<T*>(
+ src_tensor.flat<T>().data()));
+ }
+
+ T *diff_dst_data = nullptr;
+ if (diff_dst_md.data.format !=
+ conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ bwd_filter_pd->diff_dst_primitive_desc(), &net);
+ diff_dst_data = static_cast<T*>(
+ diff_dst.GetOpMem().get_data_handle());
+ } else {
+ diff_dst_data = static_cast<T*>(const_cast<T*>(
+ diff_dst_tensor.flat<T>().data()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+
+ // For backward filter, convert diff_filter back to Tensorflow layout
+ // Here we prepare to reorder op memory back to user memory
+ bool diff_filter_reorder_required = false;
+ T *diff_filter_data = nullptr;
+ if (GetOutputFormat(tf_fmt) !=
+ conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ // Allocate diff filter tensor as Tensorflow layout
+ diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
+ diff_filter_tensor);
+ diff_filter_reorder_required = true;
+ diff_filter.PrepareReorderToUserMemIfReq(
+ bwd_filter_pd->diff_weights_primitive_desc());
+ diff_filter_data = static_cast<T*>(
+ diff_filter.GetOpMem().get_data_handle());
+ } else {
+ diff_filter_data = static_cast<T*>(const_cast<T*>(
+ diff_filter_tensor->flat<T>().data()));
+ }
+
+ // Execute convolution filter bwd
+ if (biasEnabled) {
+ T* diff_bias_data = static_cast<T*>(const_cast<T*>(
+ diff_bias_tensor->flat<T>().data()));
+ conv2d_bwd_filter->Execute(src_data, diff_filter_data,
+ diff_bias_data, diff_dst_data);
+ } else {
+ conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ }
+
+ // Reorder diff_filter back to Tensorflow layout if necessary
+ if (diff_filter_reorder_required) {
+ std::vector<primitive> net;
+ diff_filter.InsertReorderToUserMem(&net);
+ stream(stream::kind::eager).submit(net).wait();
+ }
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
private:
+ const int kInputIndex_Filter = 1;
+ const int kInputIndex_InputSizes = 0;
const int kDilationH = 0, kDilationW = 1;
+ engine cpu_engine_ = engine(engine::cpu, 0);
+
+ // Validate input shapes.
+ // Function asserts that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -452,141 +958,44 @@ class MklConv2DCustomBackpropFilterOp
<< "Conv2DBackpropFilter: filter should not be in MKL Layout";
}
- size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ }
-
+ // Get TensorFlow shape of input tensor.
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
size_t input_idx = 0;
return GetTfShape(context, input_idx);
}
+ // Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
TensorShape filter_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
- &filter_tf_shape)
- .ok(),
- true);
+ &filter_tf_shape).ok(), true);
return filter_tf_shape;
}
+ // Get Tensorflow shape of output tensor (diff_filter),
+ // which is same as shape of filter.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
- // Shape of output of Conv2DBackpropFilter is same as shape of filter.
return filter_shape;
}
+ // Get the shape of output (diff_filter) in MKL-DNN order.
+ // Computes shape of output from input shape (fwd_input_dims)
+ // and filter shape (fwd_filter_dims).
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
- // Shape of output of Conv2DBackpropFilter is same as shape of filter.
return fwd_filter_dims;
}
+ // Output layout is Tensorflow's filter layout (HWIO).
memory::format GetOutputFormat(const memory::format data_format) {
- // Output layout is Tensorflow's filter layout (HWIO).
return memory::format::hwio;
}
- void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter,
- MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor,
- const memory::dims& strides,
- const memory::dims& dilations,
- const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) {
- CHECK_NOTNULL(context);
- CHECK_NOTNULL(input);
- CHECK_NOTNULL(filter);
- CHECK_NOTNULL(outbackprop);
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(output_tensor);
-
- MklDnnData<T>* bias_grad = nullptr;
- int depth = 0;
- if (biasEnabled) {
- // Data structure for bias_grad
- bias_grad = new MklDnnData<T>(&cpu_engine);
- TensorShape obp_tf_shape = GetTfShape(context, 2);
- depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat() ==
- FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
- memory::dims bias_grad_dims = {depth};
- bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x);
- }
-
- if (biasEnabled && (bias_grad != nullptr)) {
- // Create convolution backward weights with bias primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- bias_grad->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding) :
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- bias_grad->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- conv_fwd_pd);
-
- // Allocate output tensor.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
- bwd_output_format, output_tensor);
-
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
-
- // Allocate bias_grad tensor
- TensorShape bias_grad_shape({depth});
- Tensor* bias_grad_tensor = nullptr;
- AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor);
- memory::dims bias_grad_dims = {depth};
- // Since Bias is 1D, we use format::x from MKLDNN to represent it.
- auto bias_grad_md =
- memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x);
- bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor);
- bias_grad->SetUsrMemDataHandle(bias_grad_tensor);
-
- PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output,
- bias_grad);
- } else {
- // Create convolution backward weights primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding) :
- convolution_backward_weights::desc(convolution_direct,
- input->GetOpMemDesc(), output->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- conv_fwd_pd);
-
- // Allocate output tensor.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
- bwd_output_format, output_tensor);
-
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
- PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output);
- }
- }
-
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
@@ -621,40 +1030,8 @@ class MklConv2DCustomBackpropFilterOp
MklDnnShape bias_grad_mkl_shape;
bias_grad_mkl_shape.SetMklTensor(false);
- AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
- bias_grad_mkl_shape);
- }
-
- // Prepare and execute net - checks for input and output reorders.
- void PrepareAndExecutePrimitive(
- const convolution_backward_weights::primitive_desc& conv_pd,
- MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output,
- MklDnnData<T>* bias_grad = nullptr) {
- // Create reorders between user layout and MKL layout if it is needed and
- // add it to the net before convolution.
- std::vector<primitive> net;
- input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
- obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
-
- // For BackpropFilter, we convert the output tensor back in Tensorflow
- // layout.
- bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
- conv_pd.diff_weights_primitive_desc());
-
- if (biasEnabled && (bias_grad != nullptr)) {
- net.push_back(convolution_backward_weights(
- conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(),
- bias_grad->GetOpMem()));
- } else {
- net.push_back(convolution_backward_weights(
- conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
- }
-
- if (output_reorder_required) {
- output->InsertReorderToUserMem(&net);
- }
-
- stream(stream::kind::eager).submit(net).wait();
+ AllocateOutputSetMklShape(context, 1, bias_grad_tensor,
+ bias_grad_shape, bias_grad_mkl_shape);
}
};
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index d203c04934..0af4568b47 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -23,8 +23,10 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -53,9 +55,246 @@ using mkldnn::stream;
#endif
namespace tensorflow {
-
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifndef INTEL_MKL_ML
+
+/// utility classes enabling primitive reuse for backward conv2d ops.
+struct MklConvBwdInputParams {
+ memory::dims diff_src_dims;
+ memory::dims filter_dims;
+ memory::dims diff_dst_dims;
+ memory::dims strides;
+ memory::dims dilations;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ padding_kind padding;
+
+ MklConvBwdInputParams(memory::dims diff_src_dims,
+ memory::dims filter_dims, memory::dims diff_dst_dims,
+ memory::dims strides, memory::dims dilations,
+ memory::dims padding_left, memory::dims padding_right,
+ padding_kind padding) :
+ diff_src_dims(diff_src_dims), filter_dims(filter_dims),
+ diff_dst_dims(diff_dst_dims), strides(strides),
+ dilations(dilations), padding_left(padding_left),
+ padding_right(padding_right), padding(padding) {
+ }
+};
+
+template <typename T>
+class MklConv2DBwdInputPrimitive : public MklPrimitive {
+ public:
+ explicit MklConv2DBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims) :
+ cpu_engine_(engine::cpu, 0) {
+ context_.bwd_input_stream.reset(new stream(stream::kind::eager));
+
+ // create conv primitive
+ if (context_.conv_bwd_input == nullptr) {
+ Setup(convBwdInputDims);
+ }
+ }
+ ~MklConv2DBwdInputPrimitive() {}
+
+ // Convolution backward filter (weights)
+ // diff_src_data: output data buffer of diff_src
+ // filter_data: input data buffer of filter (weights)
+ // diff_dst_data: input data buffer of dst
+ // Bias does not matter here
+ void Execute(const T* diff_src_data,
+ const T* filter_data, const T* diff_dst_data) {
+ context_.diff_src_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(diff_src_data)));
+ context_.filter_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(filter_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<T*>(const_cast<T*>(diff_dst_data)));
+
+ context_.bwd_input_stream->submit(context_.bwd_input_primitives);
+
+ // set back data handle
+ context_.diff_src_mem->set_data_handle(DummyData);
+ context_.filter_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ return;
+ }
+
+ memory::format GetFilterMemoryFormat() const {
+ return context_.filter_fmt;
+ }
+
+ memory::format GetDiffDstMemoryFormat() const {
+ return context_.diff_dst_fmt;
+ }
+
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+ GetPrimitiveDesc() const {
+ return context_.bwd_input_pd;
+ }
+
+ private:
+ // Primitive reuse context for Conv2D Bwd Input op
+ struct ConvBwdInputContext {
+ // expected memory format for this primitive instance
+ memory::format filter_fmt;
+ memory::format diff_dst_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+ std::shared_ptr<mkldnn::memory> filter_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+ // convolution primitive
+ std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
+ bwd_input_pd;
+ std::shared_ptr<mkldnn::primitive> conv_bwd_input;
+
+ // desc & prmitive desc
+ std::shared_ptr<mkldnn::convolution_backward_data::desc> bwd_input_desc;
+ std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
+
+ // memory desc: forward & backward can share same memory::desc
+ std::shared_ptr<memory::desc> diff_src_md;
+ std::shared_ptr<memory::desc> filter_md;
+ std::shared_ptr<memory::desc> diff_dst_md;
+
+ // MKL pipeline
+ std::shared_ptr<mkldnn::stream> bwd_input_stream;
+ std::vector<mkldnn::primitive> bwd_input_primitives;
+
+ ConvBwdInputContext() :
+ filter_fmt(memory::format::any), diff_dst_fmt(memory::format::any),
+ diff_src_mem(nullptr), filter_mem(nullptr), diff_dst_mem(nullptr),
+ bwd_input_pd(nullptr), conv_bwd_input(nullptr),
+ bwd_input_desc(nullptr), fwd_desc(nullptr), fwd_pd(nullptr),
+ diff_src_md(nullptr), filter_md(nullptr), diff_dst_md(nullptr),
+ bwd_input_stream(nullptr) {
+ }
+ };
+
+
+ void Setup(const MklConvBwdInputParams& convBwdInputDims) {
+ // create memory descriptors for convolution data w/ no specified format
+ context_.diff_src_md.reset(new memory::desc(
+ {convBwdInputDims.diff_src_dims},
+ MklDnnType<T>(), memory::format::any));
+ context_.filter_md.reset(new memory::desc(
+ {convBwdInputDims.filter_dims},
+ MklDnnType<T>(), memory::format::any));
+ context_.diff_dst_md.reset(new memory::desc(
+ {convBwdInputDims.diff_dst_dims},
+ MklDnnType<T>(), memory::format::any));
+
+ // create convolution primitives
+ context_.bwd_input_desc.reset(new convolution_backward_data::desc(
+ convolution_direct, *context_.diff_src_md, *context_.filter_md,
+ *context_.diff_dst_md, convBwdInputDims.strides,
+ convBwdInputDims.dilations, convBwdInputDims.padding_left,
+ convBwdInputDims.padding_right, convBwdInputDims.padding));
+
+ context_.fwd_desc.reset(new convolution_forward::desc(prop_kind::forward,
+ convolution_direct, *context_.diff_src_md, *context_.filter_md,
+ *context_.diff_dst_md, convBwdInputDims.strides,
+ convBwdInputDims.dilations, convBwdInputDims.padding_left,
+ convBwdInputDims.padding_right, convBwdInputDims.padding));
+
+ context_.fwd_pd.reset(new convolution_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create backward conv prim desc
+ context_.bwd_input_pd.reset(
+ new convolution_backward_data::primitive_desc(
+ *context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd));
+
+ // create memory primitive based on dummy data
+ context_.diff_src_mem.reset(new memory(
+ context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData));
+ context_.filter_mem.reset(new memory(
+ context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData));
+ context_.diff_dst_mem.reset(new memory(
+ context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData));
+
+ // store the expected memory format
+ context_.filter_fmt = static_cast<memory::format>(
+ context_.bwd_input_pd.get()->weights_primitive_desc().desc().data.format);
+ context_.diff_dst_fmt = static_cast<memory::format>(
+ context_.bwd_input_pd.get()->diff_dst_primitive_desc().desc().data.format);
+
+ // create convolution primitive and add it to net
+ context_.conv_bwd_input.reset(new convolution_backward_data(
+ *context_.bwd_input_pd, *context_.diff_dst_mem,
+ *context_.filter_mem, *context_.diff_src_mem));
+
+ context_.bwd_input_primitives.push_back(*context_.conv_bwd_input);
+ }
+
+ struct ConvBwdInputContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+ private:
+ MklConv2DBwdInputPrimitiveFactory() {}
+ ~MklConv2DBwdInputPrimitiveFactory() {}
+
+ public:
+ static MklConv2DBwdInputPrimitive<T>* Get(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
+
+ // look into the pool for reusable primitive
+ conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
+ MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
+ convBwdInputDims));
+
+ if (conv2d_bwd_input == nullptr) {
+ conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
+ convBwdInputDims);
+ MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
+ convBwdInputDims, conv2d_bwd_input);
+ }
+ return conv2d_bwd_input;
+ }
+
+ private:
+ static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
+ static MklConv2DBwdInputPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ static std::string CreateKey(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ std::string prefix = "conv2d_bwd_input";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
+ key_creator.AddAsKey(convBwdInputDims.filter_dims);
+ key_creator.AddAsKey(convBwdInputDims.diff_dst_dims);
+ key_creator.AddAsKey(convBwdInputDims.strides);
+ key_creator.AddAsKey(convBwdInputDims.dilations);
+ key_creator.AddAsKey(convBwdInputDims.padding_left);
+ key_creator.AddAsKey(convBwdInputDims.padding_right);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetConv2dBwdInput(
+ const MklConvBwdInputParams& convBwdInputDims) {
+ std::string key = CreateKey(convBwdInputDims);
+ return this->GetOp(key);
+ }
+
+ void SetConv2dBwdInput(
+ const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ std::string key = CreateKey(convBwdInputDims);
+ this->SetOp(key, op);
+ }
+};
+
+#endif
+
#ifdef INTEL_MKL_ML
template <typename Device, class T>
@@ -363,13 +602,173 @@ class MklConv2DCustomBackpropInputOp
: public MklConv2DBackpropCommonOp<Device, T> {
public:
explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {}
+ : MklConv2DBackpropCommonOp<Device, T>(context) {
+ }
+
~MklConv2DCustomBackpropInputOp() {}
+ void Compute(OpKernelContext* context) {
+ try {
+ MklDnnData<T> filter(&cpu_engine);
+ MklDnnData<T> diff_dst(&cpu_engine);
+
+ // Input tensors
+ const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
+ const Tensor& src_tensor = MklGetInput(context, kInputIdx);
+ const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
+ const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
+
+ MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
+ GetMklShape(context, kInputIdx, &src_mkl_shape);
+ GetMklShape(context, kFilterIdx, &filter_mkl_shape);
+ GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape);
+ // Allow operator-specific sanity checking of shapes.
+ ValidateMklShapes(src_mkl_shape, filter_mkl_shape,
+ diff_dst_mkl_shape);
+
+ // Allow operator-specific generation of shapes.
+ // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // tensor containing shape of filter. So filter.shape() is not
+ // a correct way to get filter shape. These operator-specific calls
+ // allow this class to handle this case.
+ TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
+ TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
+ TensorShape diff_dst_tf_shape = GetTfShape(context, kOutbpropIdx);
+
+ // Corner cases: output with 0 elements and 0 batch size.
+ Tensor* diff_src_tensor = nullptr;
+ if (src_tf_shape.num_elements() == 0 ||
+ filter_tf_shape.num_elements() == 0 ||
+ diff_dst_tf_shape.num_elements() == 0) {
+ MklDnnShape diff_src_mkl_shape;
+ diff_src_mkl_shape.SetMklTensor(false);
+ TensorShape diff_src_tf_shape = GetOutputTfShape(
+ src_tf_shape, filter_tf_shape, diff_dst_tf_shape);
+ const int kOutputIdx = 0;
+ AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
+ diff_src_tf_shape, diff_src_mkl_shape);
+ CHECK_NOTNULL(diff_src_tensor);
+
+ // if output tensor has more than 0 elements, we need to 0 them out.
+ auto diff_src_data = diff_src_tensor->flat<T>().data();
+ for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) {
+ diff_src_data[i] = 0;
+ }
+ return;
+ }
+ // By default, all dims are in MKL order. Only dims in TF order
+ // are those with postfix tf_order.
+ memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
+ memory::dims padding_left, padding_right, dilations, strides;
+ memory::dims fwd_output_dims, fwd_output_dims_tf_order;
+
+ // Get forward convolution parameters.
+ MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
+ this->data_format_, this->dilations_);
+ conv_utl.GetConvFwdSizesInMklOrder(
+ src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
+ &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
+ &padding_left, &padding_right);
+ if (!context->status().ok()) return;
+
+ // Create Convolution forward descriptor since Convolution backward
+ // API needs it. For that, we first need to create input, filter
+ // and output memory descriptors.
+ auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+
+ // If filter is in MKL layout, then simply grab filter layout;
+ // otherwise, construct filter in TF layout.
+ // For TF layout, filter is in HWIO format.
+ auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ memory::format::hwio);
+
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ if (!context->status().ok()) return;
+ auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor()
+ ? diff_dst_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims,
+ MklDnnType<T>(), tf_fmt);
+
+ dilations[kDilationH] -= 1;
+ dilations[kDilationW] -= 1;
+
+ MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
+ conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
+ diff_dst_dims, strides, dilations, padding_left, padding_right,
+ TFPaddingToMklDnnPadding(this->padding_));
+ conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
+ convBwdInputDims);
+ auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+ // allocate output tensor
+ auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
+ auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
+ auto bwd_diff_src_format = GetOutputFormat(tf_fmt);
+ MklDnnShape diff_src_mkl_shape;
+ diff_src_mkl_shape.SetMklTensor(true);
+ diff_src_mkl_shape.SetMklLayout(&diff_src_pd);
+ diff_src_mkl_shape.SetElemType(MklDnnType<T>());
+ diff_src_mkl_shape.SetTfLayout(bwd_diff_src_dims.size(),
+ bwd_diff_src_dims, bwd_diff_src_format);
+ TensorShape diff_src_tf_shape;
+ diff_src_tf_shape.AddDim(diff_src_pd.get_size() / sizeof(T));
+ AllocateOutputSetMklShape(context, 0, &diff_src_tensor,
+ diff_src_tf_shape, diff_src_mkl_shape);
+
+ T *diff_src_data = static_cast<T*>(const_cast<T*>(
+ diff_src_tensor->flat<T>().data()));
+
+ // check if filter and diff_dst need reorder
+ std::vector<primitive> net;
+ T* filter_data = nullptr;
+ if (fwd_filter_md.data.format !=
+ conv2d_bwd_input->GetFilterMemoryFormat()) {
+ filter.SetUsrMem(fwd_filter_md, &filter_tensor);
+ filter.CheckReorderToOpMem(
+ bwd_input_pd->weights_primitive_desc(),
+ &net);
+ filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
+ } else {
+ filter_data = static_cast<T*>(const_cast<T*>(
+ filter_tensor.flat<T>().data()));
+ }
+
+ T* diff_dst_data = nullptr;
+ if (diff_dst_md.data.format !=
+ conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ bwd_input_pd->diff_dst_primitive_desc(), &net);
+ diff_dst_data = static_cast<T*>(
+ diff_dst.GetOpMem().get_data_handle());
+ } else {
+ diff_dst_data = static_cast<T*>(const_cast<T*>(
+ diff_dst_tensor.flat<T>().data()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+
+ // execute convolution input bwd
+ conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) + ", in file " +
+ string(__FILE__) + ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
private:
- const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0,
- kInputIndex_OutBackProp = 2;
+ const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0;
const int kDilationH = 0, kDilationW = 1;
+ engine cpu_engine = engine(engine::cpu, 0);
+
+ // Validate input shapes.
+ // Function asserts that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -380,8 +779,7 @@ class MklConv2DCustomBackpropInputOp
<< "Conv2DBackpropInput: input should not be in MKL Layout";
}
- size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; }
-
+ // Get TensorFlow shape of input tensor.
TensorShape MakeInputTfShape(OpKernelContext* context,
const Tensor& input_tensor) {
TensorShape input_tf_shape;
@@ -393,72 +791,32 @@ class MklConv2DCustomBackpropInputOp
return input_tf_shape;
}
+ // Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
return GetTfShape(context, kInputIndex_Filter);
}
+ // Get the Tensorflow shape of Output (diff_src),
+ // which is same as shape of Conv2D 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
- // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
return input_shape;
}
+ // Get the Tensorflow shape of Output (diff_src),
+ // which is same as shape of Conv2D 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
- // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
return fwd_input_dims;
}
+ // Output layout is Tensorflow's layout in data format order.
memory::format GetOutputFormat(const memory::format data_format) {
- // Output layout is Tensorflow's layout in data format order.
return data_format;
}
- void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter,
- MklDnnData<T>* outbackprop, MklDnnData<T>* output,
- Tensor** output_tensor,
- const memory::dims& strides,
- const memory::dims& dilations,
- const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) {
- CHECK_NOTNULL(context);
- CHECK_NOTNULL(input);
- CHECK_NOTNULL(filter);
- CHECK_NOTNULL(outbackprop);
- CHECK_NOTNULL(output);
- CHECK_NOTNULL(output_tensor);
-
- // Create convolution backward data primitive.
- // Use dilated convolution in case dilate rates are greater than zero.
- auto bwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0) ?
- convolution_backward_data::desc(convolution_direct,
- output->GetOpMemDesc(), filter->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(), strides,
- dilations, padding_l, padding_r, padding):
- convolution_backward_data::desc(convolution_direct,
- output->GetOpMemDesc(), filter->GetOpMemDesc(),
- outbackprop->GetOpMemDesc(),
- strides, padding_l, padding_r, padding);
-
- auto bwd_pd = convolution_backward_data::primitive_desc(
- bwd_desc, cpu_engine, conv_fwd_pd);
-
- // Allocate output tensor in TensorFlow and MKL layout.
- AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
- output_tensor);
- CHECK_NOTNULL(*output_tensor);
- // Set buffer handle using allocated output tensor.
- output->SetUsrMemDataHandle(*output_tensor);
-
- PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output);
- }
-
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
@@ -485,22 +843,6 @@ class MklConv2DCustomBackpropInputOp
AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
output_mkl_shape);
}
-
- // Prepare and execute net - checks for input and output reorders.
- void PrepareAndExecutePrimitive(
- const convolution_backward_data::primitive_desc& conv_pd,
- MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
- // Create reorders between user layout and MKL layout if it is needed and
- // add it to the net before convolution.
- std::vector<primitive> net;
- filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
- obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
-
- net.push_back(convolution_backward_data(
- conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
-
- stream(stream::kind::eager).submit(net).wait();
- }
};
#endif // INTEL_MKL_ML
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index f2b14f1278..b568973220 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -59,7 +59,8 @@ namespace tensorflow {
#ifndef INTEL_MKL_ML
-struct ConvFwdDimensions {
+// This structure aggregates multiple inputs to Conv2DFwd* methods.
+struct MklConvFwdParams {
memory::dims src_dims;
memory::dims filter_dims;
memory::dims bias_dims;
@@ -69,48 +70,56 @@ struct ConvFwdDimensions {
memory::dims padding_left;
memory::dims padding_right;
- ConvFwdDimensions(memory::dims src_dims,
- memory::dims filter_dims, memory::dims bias_dims,
- memory::dims dst_dims, memory::dims strides,
- memory::dims dilations, memory::dims padding_left,
- memory::dims padding_right) :
- src_dims(src_dims), filter_dims(filter_dims),
- bias_dims(bias_dims), dst_dims(dst_dims),
- strides(strides), dilations(dilations),
- padding_left(padding_left), padding_right(padding_right) {
- }
+ MklConvFwdParams(memory::dims src_dims, memory::dims filter_dims,
+ memory::dims bias_dims, memory::dims dst_dims,
+ memory::dims strides, memory::dims dilations,
+ memory::dims padding_left, memory::dims padding_right)
+ : src_dims(src_dims),
+ filter_dims(filter_dims),
+ bias_dims(bias_dims),
+ dst_dims(dst_dims),
+ strides(strides),
+ dilations(dilations),
+ padding_left(padding_left),
+ padding_right(padding_right) {}
};
template <typename T>
-class Conv2DFwd : public DnnOp {
+class MklConv2DFwdPrimitive : public MklPrimitive {
public:
- explicit Conv2DFwd(const ConvFwdDimensions& convFwdDims) {
- fwd_stream_.reset(new stream(stream::kind::eager));
+ explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
- if (conv_fwd_ == nullptr) {
+ if (context_.conv_fwd == nullptr) {
Setup(convFwdDims);
}
}
- ~Conv2DFwd() {}
+ ~MklConv2DFwdPrimitive() {}
// Convolution forward execute with bias
// src_data: input data buffer of src
// filter_data: input data buffer of filter (weights)
// bias_data: input data buffer of bias
// dst_data: output data buffer of dst
- void Execute(T* src_data, T* filter_data, T* bias_data, T* dst_data) {
- src_mem_->set_data_handle(static_cast<void*>(src_data));
- filter_mem_->set_data_handle(static_cast<void*>(filter_data));
- bias_mem_->set_data_handle(static_cast<void*>(bias_data));
- dst_mem_->set_data_handle(static_cast<void*>(dst_data));
- fwd_stream_->submit(fwd_primitives_);
+ void Execute(const T* src_data, const T* filter_data, const T* bias_data,
+ const T* dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(filter_data)));
+ context_.bias_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(bias_data)));
+ context_.dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(dst_data)));
+ context_.fwd_stream->submit(context_.fwd_primitives);
// after exec, set data handle back
- src_mem_->set_data_handle(DummyData);
- filter_mem_->set_data_handle(DummyData);
- bias_mem_->set_data_handle(DummyData);
- dst_mem_->set_data_handle(DummyData);
+ context_.src_mem->set_data_handle(DummyData);
+ context_.filter_mem->set_data_handle(DummyData);
+ context_.bias_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
return;
}
@@ -119,139 +128,177 @@ class Conv2DFwd : public DnnOp {
// src_data: input data buffer of src
// filter_data: input data buffer of filter (weights)
// dst_data: output data buffer of dst
- void Execute(T* src_data, T* filter_data, T* dst_data) {
- src_mem_->set_data_handle(static_cast<void*>(src_data));
- filter_mem_->set_data_handle(static_cast<void*>(filter_data));
- dst_mem_->set_data_handle(static_cast<void*>(dst_data));
- fwd_stream_->submit(fwd_primitives_);
-
- // after exec, set data handle back
- src_mem_->set_data_handle(DummyData);
- filter_mem_->set_data_handle(DummyData);
- dst_mem_->set_data_handle(DummyData);
-
- return;
+ void Execute(const T* src_data, const T* filter_data, const T* dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.filter_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(filter_data)));
+ context_.dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(dst_data)));
+ context_.fwd_stream->submit(context_.fwd_primitives);
+
+ // after execution, set data handle back
+ context_.src_mem->set_data_handle(DummyData);
+ context_.filter_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
}
- // expected memory format for this primitive instance
- memory::format src_fmt_;
- memory::format filter_fmt_;
+ memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
- // convolution primitive
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd_;
- std::shared_ptr<mkldnn::primitive> conv_fwd_;
+ memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
+
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
+ GetPrimitiveDesc() const {
+ return context_.fwd_pd;
+ }
private:
- void Setup(const ConvFwdDimensions& convFwdDims) {
+ // Primitive reuse context for Conv2D Fwd op
+ struct ConvFwdContext {
+ // expected memory format for this primitive instance
+ memory::format src_fmt;
+ memory::format filter_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> filter_mem;
+ std::shared_ptr<mkldnn::memory> bias_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+
+ // desc & prmitive desc
+ std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
+
+ // memory desc
+ std::shared_ptr<mkldnn::memory::desc> src_md;
+ std::shared_ptr<mkldnn::memory::desc> filter_md;
+ std::shared_ptr<mkldnn::memory::desc> bias_md;
+ std::shared_ptr<mkldnn::memory::desc> dst_md;
+
+ // convolution primitive
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
+ std::shared_ptr<mkldnn::primitive> conv_fwd;
+
+ std::shared_ptr<mkldnn::stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ ConvFwdContext()
+ : src_fmt(memory::format::any),
+ filter_fmt(memory::format::any),
+ src_mem(nullptr),
+ filter_mem(nullptr),
+ bias_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ src_md(nullptr),
+ filter_md(nullptr),
+ bias_md(nullptr),
+ fwd_pd(nullptr),
+ conv_fwd(nullptr),
+ fwd_stream(nullptr) {}
+ };
+
+ void Setup(const MklConvFwdParams& convFwdDims) {
// create memory descriptors for convolution data w/ no specified format
- src_md_.reset(new memory::desc({convFwdDims.src_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.src_md.reset(new memory::desc(
+ {convFwdDims.src_dims}, MklDnnType<T>(), memory::format::any));
- filter_md_.reset(new memory::desc({convFwdDims.filter_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.filter_md.reset(new memory::desc(
+ {convFwdDims.filter_dims}, MklDnnType<T>(), memory::format::any));
- dst_md_.reset(new memory::desc({convFwdDims.dst_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.dst_md.reset(new memory::desc(
+ {convFwdDims.dst_dims}, MklDnnType<T>(), memory::format::any));
if (!convFwdDims.bias_dims.empty())
- bias_md_.reset(new memory::desc({convFwdDims.bias_dims},
- MklDnnType<T>(), memory::format::any));
+ context_.bias_md.reset(new memory::desc(
+ {convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::any));
// create a convolution
if (!convFwdDims.bias_dims.empty()) {
- fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *src_md_, *filter_md_, *bias_md_, *dst_md_,
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
} else {
- fwd_desc_.reset(new convolution_forward::desc(prop_kind::forward,
- convolution_direct, *src_md_, *filter_md_, *dst_md_,
- convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
+ context_.fwd_desc.reset(new convolution_forward::desc(
+ prop_kind::forward, convolution_direct, *context_.src_md,
+ *context_.filter_md, *context_.dst_md, convFwdDims.strides,
+ convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
}
- fwd_pd_.reset(new convolution_forward::primitive_desc(
- *fwd_desc_, cpu_engine_));
+ context_.fwd_pd.reset(new convolution_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
// store the expected memory format
- src_fmt_ = static_cast<mkldnn::memory::format>(
- fwd_pd_.get()->src_primitive_desc().desc().data.format);
+ context_.src_fmt = static_cast<mkldnn::memory::format>(
+ context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
- filter_fmt_ = static_cast<mkldnn::memory::format>(
- fwd_pd_.get()->weights_primitive_desc().desc().data.format);
+ context_.filter_fmt = static_cast<mkldnn::memory::format>(
+ context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
// create memory primitive based on dummy data
- src_mem_.reset(new memory(fwd_pd_.get()->src_primitive_desc(), DummyData));
- filter_mem_.reset(new memory(fwd_pd_.get()->weights_primitive_desc(),
- DummyData));
- dst_mem_.reset(new memory(fwd_pd_.get()->dst_primitive_desc(), DummyData));
+ context_.src_mem.reset(
+ new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
+ context_.filter_mem.reset(
+ new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
- bias_mem_.reset(new memory({{{convFwdDims.bias_dims}, MklDnnType<T>(),
- memory::format::x}, cpu_engine_}, DummyData));
- conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_,
- *filter_mem_, *bias_mem_, *dst_mem_));
+ context_.bias_mem.reset(new memory(
+ {{{convFwdDims.bias_dims}, MklDnnType<T>(), memory::format::x},
+ cpu_engine_},
+ DummyData));
+ context_.conv_fwd.reset(new convolution_forward(
+ *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
+ *context_.bias_mem, *context_.dst_mem));
} else {
- conv_fwd_.reset(new convolution_forward(*fwd_pd_, *src_mem_,
- *filter_mem_, *dst_mem_));
+ context_.conv_fwd.reset(
+ new convolution_forward(*context_.fwd_pd, *context_.src_mem,
+ *context_.filter_mem, *context_.dst_mem));
}
- fwd_primitives_.push_back(*conv_fwd_);
+ context_.fwd_primitives.push_back(*context_.conv_fwd);
return;
}
- // MKLDNN memory
- std::shared_ptr<mkldnn::memory> src_mem_;
- std::shared_ptr<mkldnn::memory> filter_mem_;
- std::shared_ptr<mkldnn::memory> bias_mem_;
- std::shared_ptr<mkldnn::memory> dst_mem_;
-
- std::shared_ptr<mkldnn::stream> fwd_stream_;
- std::vector<mkldnn::primitive> fwd_primitives_;
-
- // desc & prmitive desc
- std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc_;
-
- // memory desc
- std::shared_ptr<mkldnn::memory::desc> src_md_;
- std::shared_ptr<mkldnn::memory::desc> filter_md_;
- std::shared_ptr<mkldnn::memory::desc> bias_md_;
- std::shared_ptr<mkldnn::memory::desc> dst_md_;
-
- engine cpu_engine_ = engine(engine::cpu, 0);
+ struct ConvFwdContext context_;
+ engine cpu_engine_;
};
template <typename T>
-class Conv2DFwdFactory : public DnnOpFactory<T> {
+class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static Conv2DFwd<T>* Get(const ConvFwdDimensions& convFwdDims) {
- Conv2DFwd<T>* conv2d_fwd = nullptr;
-
- // try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<Conv2DFwd<T>*> (
- Conv2DFwdFactory<T>::GetInstance().GetConv2DFwd(convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new Conv2DFwd<T>(convFwdDims);
- Conv2DFwdFactory<T>::GetInstance().SetConv2DFwd(
- convFwdDims, conv2d_fwd);
- }
- return conv2d_fwd;
+ static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+
+ // try to find a suitable one in pool
+ conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
+ convFwdDims));
+
+ if (conv2d_fwd == nullptr) {
+ conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
+ MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
+ conv2d_fwd);
+ }
+ return conv2d_fwd;
}
private:
- Conv2DFwdFactory() {}
- ~Conv2DFwdFactory() {}
+ MklConv2DFwdPrimitiveFactory() {}
+ ~MklConv2DFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1;
- static Conv2DFwdFactory& GetInstance() {
- static Conv2DFwdFactory instance_;
+ static MklConv2DFwdPrimitiveFactory& GetInstance() {
+ static MklConv2DFwdPrimitiveFactory instance_;
return instance_;
}
- static std::string CreateKey(const ConvFwdDimensions& convFwdDims) {
+ static std::string CreateKey(const MklConvFwdParams& convFwdDims) {
std::string prefix = "conv2d_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@@ -266,12 +313,12 @@ class Conv2DFwdFactory : public DnnOpFactory<T> {
return key_creator.GetKey();
}
- DnnOp* GetConv2DFwd(const ConvFwdDimensions& convFwdDims) {
+ MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
std::string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
- void SetConv2DFwd(const ConvFwdDimensions& convFwdDims, DnnOp *op) {
+ void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
std::string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -762,7 +809,6 @@ class MklConv2DOp : public OpKernel {
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> filter(&cpu_engine);
- MklDnnData<T> dst(&cpu_engine); // output
memory::dims src_dims, filter_dims, padding_left, padding_right,
dilations, strides;
@@ -812,7 +858,6 @@ class MklConv2DOp : public OpKernel {
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
- src.SetUsrMem(src_md, &src_tensor);
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO).
@@ -820,29 +865,30 @@ class MklConv2DOp : public OpKernel {
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(),
memory::format::hwio);
- filter.SetUsrMem(filter_md, &filter_tensor);
// MKLDNN dilation starts from 0.
dilations[kDilationH] -= 1;
dilations[kDilationW] -= 1;
// get a conv2d fwd from primitive pool
- Conv2DFwd<T> *conv2d_fwd = nullptr;
+ MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
- ConvFwdDimensions convFwdDims(src_dims, filter_dims, bias_dims,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
- conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims);
+ MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
+ conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
} else {
- ConvFwdDimensions convFwdDims(src_dims, filter_dims, NONE_DIMS,
- dst_dims_mkl_order, strides, dilations, padding_left, padding_right);
- conv2d_fwd = Conv2DFwdFactory<T>::Get(convFwdDims);
+ MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
+ dst_dims_mkl_order, strides, dilations,
+ padding_left, padding_right);
+ conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
}
// allocate output tensors output_tensor and filter_out_tensor
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc>
- conv_fwd_pd = conv2d_fwd->fwd_pd_;
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
+ conv2d_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -853,21 +899,25 @@ class MklConv2DOp : public OpKernel {
T* dst_data = static_cast<T*>(dst_tensor->flat<T>().data());
// check whether src/filter need reorder
- std::vector<primitive> net;
- if (src_md.data.format != conv2d_fwd->src_fmt_)
- src.CheckReorderToOpMem(
- conv_fwd_pd.get()->src_primitive_desc(), &net);
-
- if (filter_md.data.format != conv2d_fwd->filter_fmt_)
- filter.CheckReorderToOpMem(
- conv_fwd_pd.get()->weights_primitive_desc(),
- filter.GetTensorBuffer(filter_out_tensor), &net);
- stream(stream::kind::eager).submit(net).wait();
-
- T* src_data = static_cast<T*>(
- src.GetOpMem().get_data_handle());
- T* filter_data = static_cast<T*>(
- filter.GetOpMem().get_data_handle());
+ T *src_data = nullptr;
+ if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
+ src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
+ } else {
+ src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
+ }
+ T* filter_data = nullptr;
+ if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
+ filter.SetUsrMem(filter_md, &filter_tensor);
+ filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
+ filter.GetTensorBuffer(filter_out_tensor));
+ filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
+ } else {
+ filter_data =
+ static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
+ }
+
// execute convolution
if (biasEnabled) {
@@ -962,16 +1012,15 @@ class MklConv2DOp : public OpKernel {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer.
- std::vector<primitive> net;
- src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
+ src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());
// rather than re-order to a temp buffer, reorder directly to the
// filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
- filter->GetTensorBuffer(filter_out_tensor),
- &net);
+ filter->GetTensorBuffer(filter_out_tensor));
// Create convolution primitive and add it to net.
+ std::vector<primitive> net;
if (bias) {
CHECK_EQ(biasEnabled, true);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 8333a09316..5e1a5001dc 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <limits>
#include <string>
#include <vector>
+#include <memory>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -349,6 +350,7 @@ class MklDnnConvUtil {
}
};
+
/////////////////////////////////////////////////////////////////////
/// Common class that implements Conv2DBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
@@ -388,227 +390,17 @@ class MklConv2DBackpropCommonOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
- void Compute(OpKernelContext* context) override {
- try {
- auto cpu_engine = engine(engine::cpu, 0);
-
- // Prepare common tensors for Conv2DBackpropInput and
- // Conv2DBackpropFilter.
- MklDnnData<T> input(&cpu_engine);
- MklDnnData<T> filter(&cpu_engine);
- MklDnnData<T> outbackprop(&cpu_engine);
- MklDnnData<T> output(&cpu_engine);
-
- // Input tensors
- const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
- const Tensor& input_tensor = MklGetInput(context, kInputIdx);
- const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
- const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx);
-
- MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape;
- GetMklShape(context, kInputIdx, &input_mkl_shape);
- GetMklShape(context, kFilterIdx, &filter_mkl_shape);
- GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape);
- // Allow operator-specific sanity checking of shapes.
- ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape);
-
- // Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
- // tensor containing shape of filter. So filter.shape() is not
- // a correct way to get filter shape. These operator-specific calls
- // allow this class to handle this case.
- TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor);
- TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
- TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx);
-
- // Corner cases: output with 0 elements and 0 batch size.
- Tensor* output_tensor = nullptr;
- if (input_tf_shape.num_elements() == 0 ||
- filter_tf_shape.num_elements() == 0 ||
- outbprop_tf_shape.num_elements() == 0) {
- MklDnnShape output_mkl_shape;
- output_mkl_shape.SetMklTensor(false);
- TensorShape output_tf_shape = GetOutputTfShape(
- input_tf_shape, filter_tf_shape, outbprop_tf_shape);
- const int kOutputIdx = 0;
- AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor,
- output_tf_shape, output_mkl_shape);
- CHECK_NOTNULL(output_tensor);
-
- // if output tensor has more than 0 elements, we need to 0 them out.
- for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) {
- output_tensor->flat<T>().data()[i] = 0;
- }
-
- return;
- }
-
- // By default, all dims are in MKL order. Only dims in TF order
- // are those with prefix tf_order.
- memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
- memory::dims padding_l, padding_r, dilations, strides, fwd_output_dims;
- memory::dims fwd_output_dims_tf_order;
-
- // Get forward convolution parameters.
- MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
- dilations_);
- conv_utl.GetConvFwdSizesInMklOrder(
- input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims,
- &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
- &padding_l, &padding_r);
- if (!context->status().ok()) return;
-
- // Create Convolution forward descriptor since Convolution backward
- // API needs it. For that, we first need to create input, filter
- // and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
- // If input is in MKL layout, then simply grab input layout; otherwise,
- // construct input TF layout. For TF layout, although input shape
- // required is in MKL-DNN order, the layout is Tensorflow's layout
- // (NHWC or NCHW depending on data format).
- auto fwd_input_md =
- input_mkl_shape.IsMklTensor()
- ? input_mkl_shape.GetMklLayout()
- : memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt);
- // If filter is in MKL layout, then simply grab filter layout; otherwise
- // construct filter in TF layout. For TF layout, filter is in HWIO format.
- auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
- // Tensorflow Output of Conv2D is in data_format order.
- auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt);
-
- const int kDilationH = 0, kDilationW = 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
- auto fwd_desc = (dilations[kDilationH] > 0 || dilations[kDilationW] > 0)?
- convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_input_md,
- fwd_filter_md, fwd_out_md,
- strides, dilations, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_)) :
- convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_input_md,
- fwd_filter_md, fwd_out_md,
- strides, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_));
- auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
-
- // Create memory for user data. Describe how the inputs and outputs of
- // Convolution look like. Also specify buffers containing actual input
- // and output data.
-
- // Since this is a common class for both Conv2DBackpropFilter and
- // Conv2DBackpropInput, we skip SetUsrMem call for input tensor (for
- // Conv2DBackpropInput) and for filter tensor (for
- // conv2DBackpropFilter) depending on which tensor is int32 type.
- size_t input_with_sizes = GetInputTensorIndexWithSizes();
- if (input_with_sizes != kInputIdx) {
- // Shape of Conv2DBackpropFilter's input is same as Conv2D input.
- input.SetUsrMem(fwd_input_md, &input_tensor);
- } else if (input_with_sizes != kFilterIdx) {
- // Shape of Conv2DBackpropInput's filter is same as Conv2D filter.
- filter.SetUsrMem(fwd_filter_md, &filter_tensor);
- }
-
- conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims);
- if (!context->status().ok()) return;
- if (outbprop_mkl_shape.IsMklTensor()) {
- // If outbackprop is in Mkl layout, then simply grab it.
- auto outbprop_md = outbprop_mkl_shape.GetMklLayout();
- outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor);
- } else {
- // If outbackprop is in TensorFlow layout, then we need to create memory
- // descriptor for it. Outbackprop shape is data format order.
- outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor);
- }
-
- // Operator specific call to get output shape and data_format.
- auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims);
- auto bwd_output_format = GetOutputFormat(tf_fmt);
- output.SetUsrMem(bwd_output_dims, bwd_output_format);
-
- // Create memory descriptors for convolution data w/ no specified format.
- input.SetOpMemDesc(fwd_input_dims, memory::format::any);
- filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
- outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any);
- output.SetOpMemDesc(bwd_output_dims, memory::format::any);
-
- // Operator-specific call to create and execute primitive.
- CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
- &outbackprop, &output, &output_tensor,
- strides, dilations, padding_l, padding_r,
- TFPaddingToMklDnnPadding(padding_),
- bwd_output_dims, bwd_output_format);
- } catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
- }
- }
-
- /// Pure virtual function to allow operator to check for validity of input
- /// shapes. Function asserts that input shapes are valid.
- virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
- const MklDnnShape& filter_mkl_shape,
- const MklDnnShape& outbprop_mkl_shape) = 0;
-
- /// Operator-specific function that returns index of input that is
- /// representing input sizes. For Conv2DBackpropFilter it returns 1 since
- /// filter for this operator is filter shape. For Conv2DBackpropInput it
- /// returns 0 (for input).
- virtual size_t GetInputTensorIndexWithSizes() = 0;
-
- /// Get TensorFlow shape of input tensor.
- virtual TensorShape MakeInputTfShape(OpKernelContext* context,
- const Tensor& input_tensor) = 0;
-
- /// Get TensorFlow shape of filter tensor.
- virtual TensorShape MakeFilterTfShape(OpKernelContext* context,
- const Tensor& filter_tensor) = 0;
-
- /// Get the TensorFlow shape of output tensor.
- virtual TensorShape GetOutputTfShape(const TensorShape& input_shape,
- const TensorShape& filter_shape,
- const TensorShape& outbprop_shape) = 0;
-
- /// Get shape of output in MKL-DNN order. Computes shape of output from
- /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims).
- virtual const memory::dims& GetOutputDims(
- const memory::dims& fwd_input_dims,
- const memory::dims& fwd_filter_dims) = 0;
-
- /// Get data_format of output in MKL-DNN order. If output data format is
- /// same as input data format, then it simply returns value of data_format
- /// parameter as it is.
- virtual memory::format GetOutputFormat(const memory::format data_format) = 0;
-
- /// Create and execute the primitive storing output in the output_tensor.
- virtual void CreatePrimitive(OpKernelContext* context,
- const engine& cpu_engine,
- const convolution_forward::primitive_desc& conv_fwd_pd,
- MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
- MklDnnData<T>* output, Tensor** output_tensor, const memory::dims& strides,
- const memory::dims& dilations, const memory::dims& padding_l,
- const memory::dims& padding_r, padding_kind padding,
- const memory::dims& bwd_output_dims,
- memory::format bwd_output_format) = 0;
-
- // Get the data_format {NCHW, NHWC}
- TensorFormat GetTFDataFormat() { return data_format_; }
-
- private:
+ protected:
+ // data members accessible to derived classes.
std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
- TensorFormat data_format_;
+ TensorFormat data_format_; // NCHW or NHWC
};
+
#endif // INTEL_MKL_ML
+
/////////////////////////////////////////////////////////////////////
/// Dummy Mkl op that is just used for operators that are intermediate
/// output of node fusion in the graph
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 62aafa7930..3fe660cf96 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -21,21 +21,21 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
-
using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
using mkldnn::use_global_stats;
using mkldnn::use_scale_shift;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
// TODO(inteltf) Address comments from PR 8968.
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc
index 6c027f8e72..b02cc5384c 100644
--- a/tensorflow/core/kernels/mkl_identity_op.cc
+++ b/tensorflow/core/kernels/mkl_identity_op.cc
@@ -24,8 +24,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index cda1402b03..dc4da33a06 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -369,8 +369,8 @@ class MklInputConversionOp : public OpKernel {
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
op_data_type, has_avx512f_,
kInputIndex_1);
- SetDummyMklShapeOutput(context, kInputIndex_0);
- SetDummyMklShapeOutput(context, kInputIndex_1);
+ SetDummyMklDnnShapeOutput(context, kInputIndex_0);
+ SetDummyMklDnnShapeOutput(context, kInputIndex_1);
return;
}
@@ -439,11 +439,11 @@ class MklInputConversionOp : public OpKernel {
tensor_out, &net);
if(!reordered) {
// This is the case that the TF tensor has the same shape and format of
- // mkl tensor. However, tf_tensor can not be simply forwarded to the output
- // tensor since mkl data tensor is always one dimensional tensor.
- // Tensor::CopyFrom shares the buffer of the other tensor while set its shape
- // to the other tensor.
- tensor_out->CopyFrom(*tf_tensor, tensor_out->shape());
+ // mkl tensor. However, tf_tensor can not be simply forwarded to the
+ // output tensor since mkl data tensor is always one dimensional tensor.
+ // Tensor::CopyFrom shares the buffer of the other tensor while set its
+ // shape to the other tensor.
+ CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
}
else
stream(stream::kind::eager).submit(net).wait();
@@ -458,7 +458,7 @@ class MklInputConversionOp : public OpKernel {
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
op_data_type, has_avx512f_,
mkl_tensor_index);
- SetDummyMklShapeOutput(context, mkl_tensor_index);
+ SetDummyMklDnnShapeOutput(context, mkl_tensor_index);
// The tensor in TF format passes through
ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index eef254cdad..7966c271d5 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -22,8 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -31,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#if !defined(IS_MOBILE_PLATFORM)
@@ -45,8 +42,13 @@ using mkldnn::lrn_backward;
using mkldnn::lrn_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
+
namespace tensorflow {
namespace {
@@ -845,12 +847,12 @@ class MklLRNOp : public OpKernel {
MklDnnData<T>* src_dnn_data,
MklDnnData<T>* dst_dnn_data,
MklDnnData<uint8>* wksp_dnn_data = nullptr) {
- std::vector<primitive> net;
// Check for input reorder
- src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
// Create pooling primitive and add it to net
+ std::vector<primitive> net;
if (wksp_dnn_data != nullptr) {
net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
wksp_dnn_data->GetOpMem(),
@@ -1158,15 +1160,15 @@ class MklLRNGradOp : public OpKernel {
MklDnnData<T>* output_diff_src,
const memory::primitive_desc& target_diff_dst_pd,
const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
- std::vector<primitive> net;
// Check for input reordering on the diff dst input
input_gradient_diff_dst->CheckReorderToOpMem(
- lrn_bkwd_desc.diff_dst_primitive_desc(), &net);
+ lrn_bkwd_desc.diff_dst_primitive_desc());
// Check for input reordering on the original input
- src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
// Create pooling primitive and add it to net
+ std::vector<primitive> net;
if (nullptr == workspace_dnn_data) {
net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
input_gradient_diff_dst->GetOpMem(),
@@ -1236,7 +1238,7 @@ class MklLRNGradOp : public OpKernel {
auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
Tensor* output_dnn_data;
- MklShape mkl_output_mkl_shape;
+ MklDnnShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
mkl_output_mkl_shape.SetDimensions(4);
AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index dfa6cecc9b..62c0404891 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
// and when it is undefined at build time, this file becomes an empty
// compilation unit
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#include "mkl_cblas.h"
#include "tensorflow/core/framework/op.h"
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 1ed43834dd..78abbdb730 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -23,9 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
@@ -38,7 +35,11 @@ using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 2cfde1f6fd..02ea9fc068 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -24,15 +24,17 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
-#include "mkl_dnn.h"
-#include "mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::stream;
+#else
+#include "mkl_dnn.h"
+#include "mkl_dnn_types.h"
#endif
+#include "tensorflow/core/util/mkl_util.h"
+
namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
template <typename Device, typename T>
@@ -250,7 +252,7 @@ class MklReshapeOp : public OpKernel {
memory::primitive_desc(output_tf_md, cpu_engine);
Tensor* output_tensor = nullptr;
- MklShape mkl_shape_output;
+ MklDnnShape mkl_shape_output;
mkl_shape_output.SetMklTensor(false);
// We allocate output tensor in the shape expected by Reshape.
AllocateOutputSetMklShape(context, kOutputSlotIdx, &output_tensor,
@@ -261,10 +263,7 @@ class MklReshapeOp : public OpKernel {
// shape_from != shape_to), then we just copy input tensor to
// output tensor with target shape (we cannot forward Mkl layout
// in such case because shape has changed.)
- std::vector<primitive> net;
- if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor,
- &net)) {
- stream(stream::kind::eager).submit(net).wait();
+ if (dnn_data_input.CheckReorderToOpMem(output_tf_pd, output_tensor)) {
} else {
OP_REQUIRES(
context, output_tensor->CopyFrom(input_tensor, shape_to),
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index f79e18cff2..638392954e 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -25,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/tensor_format.h"
-#include "mkldnn.h"
-#include "mkldnn_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "mkldnn.hpp"
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index 4120f013ac..f4f0035f26 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -32,8 +32,10 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
+#endif
#include "tensorflow/core/util/mkl_util.h"
#ifndef INTEL_MKL_ML
@@ -109,10 +111,8 @@ class MklToTfOp : public OpKernel {
// Do we need to reorder Mkl layout into TensorFlow layout?
if (input.IsReorderNeeded(output_tf_pd)) {
// Insert reorder between Mkl layout and TensorFlow layout.
- std::vector<primitive> net;
- CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, output_tensor, &net),
+ CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, output_tensor),
true);
- stream(stream::kind::eager).submit(net).wait();
} else {
// If not, just forward input tensor to output tensor.
CHECK(output_tensor->CopyFrom(input_tensor, output_shape));
diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc
index 3f07b317c4..b180c2ff20 100644
--- a/tensorflow/core/kernels/mkl_transpose_op.cc
+++ b/tensorflow/core/kernels/mkl_transpose_op.cc
@@ -15,7 +15,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#define EIGEN_USE_THREADS
#include "mkl_trans.h"
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 23fdfe944a..f59843a07a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/non_max_suppression_op.h"
+#include <functional>
#include <queue>
#include <vector>
@@ -38,9 +39,32 @@ namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
+static inline void CheckScoreSizes(OpKernelContext* context, int num_boxes,
+ const Tensor& scores) {
+ // The shape of 'scores' is [num_boxes]
+ OP_REQUIRES(context, scores.dims() == 1,
+ errors::InvalidArgument("scores must be 1-D",
+ scores.shape().DebugString()));
+ OP_REQUIRES(context, scores.dim_size(0) == num_boxes,
+ errors::InvalidArgument("scores has incompatible shape"));
+}
+
+static inline void ParseAndCheckOverlapSizes(OpKernelContext* context,
+ const Tensor& overlaps,
+ int* num_boxes) {
+ // the shape of 'overlaps' is [num_boxes, num_boxes]
+ OP_REQUIRES(context, overlaps.dims() == 2,
+ errors::InvalidArgument("overlaps must be 2-D",
+ overlaps.shape().DebugString()));
+
+ *num_boxes = overlaps.dim_size(0);
+ OP_REQUIRES(context, overlaps.dim_size(1) == *num_boxes,
+ errors::InvalidArgument("overlaps must be square",
+ overlaps.shape().DebugString()));
+}
+
static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
- const Tensor& boxes,
- const Tensor& scores, int* num_boxes) {
+ const Tensor& boxes, int* num_boxes) {
// The shape of 'boxes' is [num_boxes, 4]
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be 2-D",
@@ -48,18 +72,12 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
*num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4,
errors::InvalidArgument("boxes must have 4 columns"));
-
- // The shape of 'scores' is [num_boxes]
- OP_REQUIRES(context, scores.dims() == 1,
- errors::InvalidArgument("scores must be 1-D",
- scores.shape().DebugString()));
- OP_REQUIRES(context, scores.dim_size(0) == *num_boxes,
- errors::InvalidArgument("scores has incompatible shape"));
}
// Return intersection-over-union overlap between boxes i and j
-static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
- int j) {
+static inline float IOUGreaterThanThreshold(
+ typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
+ float iou_threshold) {
const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
@@ -78,24 +96,36 @@ static inline float IOU(typename TTypes<float, 2>::ConstTensor boxes, int i,
const float intersection_area =
std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- return intersection_area / (area_i + area_j - intersection_area);
+ const float iou = intersection_area / (area_i + area_j - intersection_area);
+ return iou > iou_threshold;
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
- const Tensor& scores, const Tensor& max_output_size,
- const float iou_threshold,
- const float score_threshold) {
- OP_REQUIRES(context, iou_threshold >= 0 && iou_threshold <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, scores, &num_boxes);
- if (!context->status().ok()) {
- return;
- }
+static inline bool OverlapsGreaterThanThreshold(
+ typename TTypes<float, 2>::ConstTensor overlaps, int i, int j,
+ float overlap_threshold) {
+ return overlaps(i, j) > overlap_threshold;
+}
+
+static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
+ const Tensor& boxes, float threshold) {
+ typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
+ return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
+ std::placeholders::_2, threshold);
+}
+
+static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
+ const Tensor& overlaps, float threshold) {
+ typename TTypes<float, 2>::ConstTensor overlaps_data =
+ overlaps.tensor<float, 2>();
+ return std::bind(&OverlapsGreaterThanThreshold, overlaps_data,
+ std::placeholders::_1, std::placeholders::_2, threshold);
+}
+void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
+ int num_boxes, const Tensor& max_output_size,
+ const float score_threshold,
+ std::function<bool(int, int)> suppress_check_fn) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
- TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
std::vector<float> scores_data(num_boxes);
std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
@@ -120,11 +150,9 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
std::vector<int> selected;
std::vector<float> selected_scores;
Candidate next_candidate;
- float iou, original_score;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
next_candidate = candidate_priority_queue.top();
- original_score = next_candidate.score;
candidate_priority_queue.pop();
// Overlapping boxes are likely to have similar scores,
@@ -132,9 +160,10 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes,
// in order to see if `next_candidate` should be suppressed.
bool should_select = true;
for (int j = selected.size() - 1; j >= 0; --j) {
- iou = IOU(boxes_data, next_candidate.box_index, selected[j]);
- if (iou == 0.0) continue;
- if (iou > iou_threshold) should_select = false;
+ if (suppress_check_fn(next_candidate.box_index, selected[j])) {
+ should_select = false;
+ break;
+ }
}
if (should_select) {
@@ -174,9 +203,19 @@ class NonMaxSuppressionOp : public OpKernel {
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
+ OP_REQUIRES(context, iou_threshold_ >= 0 && iou_threshold_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_, score_threshold_val);
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
@@ -207,9 +246,19 @@ class NonMaxSuppressionV2Op : public OpKernel {
iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val, score_threshold_val);
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -245,8 +294,65 @@ class NonMaxSuppressionV3Op : public OpKernel {
score_threshold.shape().DebugString()));
const float score_threshold_val = score_threshold.scalar<float>()();
- DoNonMaxSuppressionOp(context, boxes, scores, max_output_size,
- iou_threshold_val, score_threshold_val);
+ OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+ int num_boxes = 0;
+ ParseAndCheckBoxSizes(context, boxes, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
+ }
+};
+
+template <typename Device>
+class NonMaxSuppressionWithOverlapsOp : public OpKernel {
+ public:
+ explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ // overlaps: [num_boxes, num_boxes]
+ const Tensor& overlaps = context->input(0);
+ // scores: [num_boxes]
+ const Tensor& scores = context->input(1);
+ // max_output_size: scalar
+ const Tensor& max_output_size = context->input(2);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+ max_output_size.shape().DebugString()));
+ // overlap_threshold: scalar
+ const Tensor& overlap_threshold = context->input(3);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(overlap_threshold.shape()),
+ errors::InvalidArgument("overlap_threshold must be 0-D, got shape ",
+ overlap_threshold.shape().DebugString()));
+ const float overlap_threshold_val = overlap_threshold.scalar<float>()();
+
+ // score_threshold: scalar
+ const Tensor& score_threshold = context->input(4);
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(score_threshold.shape()),
+ errors::InvalidArgument("score_threshold must be 0-D, got shape ",
+ score_threshold.shape().DebugString()));
+ const float score_threshold_val = score_threshold.scalar<float>()();
+
+ int num_boxes = 0;
+ ParseAndCheckOverlapSizes(context, overlaps, &num_boxes);
+ CheckScoreSizes(context, num_boxes, scores);
+ if (!context->status().ok()) {
+ return;
+ }
+ auto suppress_check_fn =
+ CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
+
+ DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -259,4 +365,8 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
+ NonMaxSuppressionWithOverlapsOp<CPUDevice>);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index ed7db313bd..055161a35f 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -569,4 +569,241 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}
+//
+// NonMaxSuppressionWithOverlapsOp Tests
+//
+
+class NonMaxSuppressionWithOverlapsOpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op",
+ "NonMaxSuppressionWithOverlaps")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+
+ void AddIoUInput(const std::vector<float>& boxes) {
+ ASSERT_EQ((boxes.size() % 4), 0);
+ size_t num_boxes = boxes.size() / 4;
+ std::vector<float> iou_overlaps(num_boxes * num_boxes);
+
+ // compute the pairwise IoU overlaps
+ auto corner_access = [&boxes](size_t box_idx, size_t corner_idx) {
+ return boxes[box_idx * 4 + corner_idx];
+ };
+ for (size_t i = 0; i < num_boxes; ++i) {
+ for (size_t j = 0; j < num_boxes; ++j) {
+ const float ymin_i =
+ std::min<float>(corner_access(i, 0), corner_access(i, 2));
+ const float xmin_i =
+ std::min<float>(corner_access(i, 1), corner_access(i, 3));
+ const float ymax_i =
+ std::max<float>(corner_access(i, 0), corner_access(i, 2));
+ const float xmax_i =
+ std::max<float>(corner_access(i, 1), corner_access(i, 3));
+ const float ymin_j =
+ std::min<float>(corner_access(j, 0), corner_access(j, 2));
+ const float xmin_j =
+ std::min<float>(corner_access(j, 1), corner_access(j, 3));
+ const float ymax_j =
+ std::max<float>(corner_access(j, 0), corner_access(j, 2));
+ const float xmax_j =
+ std::max<float>(corner_access(j, 1), corner_access(j, 3));
+ const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+
+ float iou;
+ if (area_i <= 0 || area_j <= 0) {
+ iou = 0.0;
+ } else {
+ const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
+ const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
+ const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
+ const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
+ const float intersection_area =
+ std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
+ std::max<float>(intersection_xmax - intersection_xmin, 0.0);
+ iou = intersection_area / (area_i + area_j - intersection_area);
+ }
+ iou_overlaps[i * num_boxes + j] = iou;
+ }
+ }
+
+ AddInputFromArray<float>(TensorShape({static_cast<signed>(num_boxes),
+ static_cast<signed>(num_boxes)}),
+ iou_overlaps);
+ }
+};
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectFromThreeClustersFlippedCoordinates) {
+ MakeOp();
+ AddIoUInput({1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f,
+ 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectAtMostTwoBoxesFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {2});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({2}));
+ test::FillValues<int>(&expected, {3, 0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest,
+ TestSelectAtMostThirtyBoxesFromThreeClusters) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({3}));
+ test::FillValues<int>(&expected, {3, 0, 5});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectSingleBox) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestSelectFromTenIdenticalBoxes) {
+ MakeOp();
+
+ int num_boxes = 10;
+ std::vector<float> corners(num_boxes * 4);
+ std::vector<float> scores(num_boxes);
+ for (int i = 0; i < num_boxes; ++i) {
+ corners[i * 4 + 0] = 0;
+ corners[i * 4 + 1] = 0;
+ corners[i * 4 + 2] = 1;
+ corners[i * 4 + 3] = 1;
+ scores[i] = .9;
+ }
+ AddIoUInput(corners);
+ AddInputFromArray<float>(TensorShape({num_boxes}), scores);
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({1}));
+ test::FillValues<int>(&expected, {0});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInconsistentBoxAndScoreShapes) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(
+ str_util::StrContains(s.ToString(), "scores has incompatible shape"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInvalidOverlapsShape) {
+ MakeOp();
+ AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
+ AddInputFromArray<float>(TensorShape({2}), {0.5f, 0.5f});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {0.f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ Status s = RunOpKernel();
+
+ ASSERT_FALSE(s.ok());
+ EXPECT_TRUE(str_util::StrContains(s.ToString(), "overlaps must be square"))
+ << s;
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestThresholdGreaterOne) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {1.2f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestThresholdSmallerZero) {
+ MakeOp();
+ AddIoUInput({0, 0, 1, 1});
+ AddInputFromArray<float>(TensorShape({1}), {.9f});
+ AddInputFromArray<int>(TensorShape({}), {3});
+ AddInputFromArray<float>(TensorShape({}), {-0.2f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) {
+ MakeOp();
+ AddIoUInput({});
+ AddInputFromArray<float>(TensorShape({0}), {});
+ AddInputFromArray<int>(TensorShape({}), {30});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected(allocator(), DT_INT32, TensorShape({0}));
+ test::FillValues<int>(&expected, {});
+ test::ExpectTensorEqual<int>(expected, *GetOutput(0));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc
index 41494f56c5..3b9133ed7e 100644
--- a/tensorflow/core/kernels/pad_op.cc
+++ b/tensorflow/core/kernels/pad_op.cc
@@ -320,7 +320,7 @@ namespace functor {
DECLARE_GPU_SPEC(T, 5); \
DECLARE_GPU_SPEC(T, 6);
-TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPECS);
TF_CALL_int8(DECLARE_GPU_SPECS);
} // namespace functor
@@ -353,7 +353,7 @@ TF_CALL_int8(DECLARE_GPU_SPECS);
.HostMemory("constant_values"), \
PadOp<GPUDevice, T, int64>)
-TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
+TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNEL);
TF_CALL_int8(REGISTER_GPU_KERNEL);
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/pad_op_gpu.cu.cc b/tensorflow/core/kernels/pad_op_gpu.cu.cc
index 8e13e19e2e..00ec44adc2 100644
--- a/tensorflow/core/kernels/pad_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/pad_op_gpu.cu.cc
@@ -39,7 +39,7 @@ typedef Eigen::GpuDevice GPUDevice;
DEFINE_GPU_PAD_SPECS(T, int32) \
DEFINE_GPU_PAD_SPECS(T, int64)
-TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
+TF_CALL_GPU_ALL_TYPES(DEFINE_GPU_SPECS);
TF_CALL_int8(DEFINE_GPU_SPECS);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index d66b1ba663..b5c6ba1da3 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -19,9 +19,11 @@ limitations under the License.
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
#if GOOGLE_CUDA
@@ -53,6 +55,9 @@ class PartitionedCallOp : public AsyncOpKernel {
errors::Internal("No function library is provided."),
done);
+ OpInputList args;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &args), done);
+
// The function body's graph is placed and partitioned the first time
// `ComputeAsync` is invoked; every subsequent invocation calls each
// of the function shards yielded by partitioning.
@@ -65,16 +70,21 @@ class PartitionedCallOp : public AsyncOpKernel {
// via, e.g., virtual device annotations and a list of device names supplied
// through an attribute.
//
- // TODO(akshayka): Lift the constraint pinning inputs and outputs to the
- // local device.
- //
// TODO(akshayka): Add a fastpath for functions that execute on a single
// device.
{
mutex_lock l(mu_);
- if (!partitioned_) {
- // Instantiate the function to obtain its underlying graph, complete
- // with nodes for arguments and return values.
+ if (function_handles_.find(lib) == function_handles_.end()) {
+ if (local_device_name_.empty()) {
+ // The full local device name isn't known at kernel construction
+ // time, hence the need to set it here.
+ local_device_name_ = lib->device()->name();
+ }
+
+ // TODO(b/37549631): Because this kernel may correspond to a stateful
+ // op, it may be shared by multiple subgraphs, which in turn may have
+ // different `FunctionLibraryRuntime` objects and therefore different
+ // `FHandle` namespaces. As such, we partition on a per-FLR basis.
FunctionLibraryRuntime::InstantiateOptions opts;
FHandle handle;
OP_REQUIRES_OK_ASYNC(
@@ -82,83 +92,38 @@ class PartitionedCallOp : public AsyncOpKernel {
lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
&handle),
done);
- Graph* graph = lib->GetFunctionBody(handle)->graph;
-
- // Pin the inputs and outputs to the local device to simplify the
- // function-dispatching logic.
- local_device_name_ = lib->device()->name();
- for (Node* node : graph->op_nodes()) {
- string node_type = node->type_string();
- if (node_type == FunctionLibraryDefinition::kArgOp ||
- node_type == FunctionLibraryDefinition::kRetOp) {
- node->set_assigned_device_name(local_device_name_);
- }
- }
+ const FunctionBody* fbody = lib->GetFunctionBody(handle);
+ OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
+ errors::Internal("Could not find handle ", handle),
+ done);
+ auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ CopyGraph(*fbody->graph, graph.get());
+ OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
- // Place the graph, i.e,. assign a device to every node in it.
DeviceSet device_set;
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph, &device_set);
+ Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
- // Partition the graph into subgraphs: exactly one subgraph per device.
- //
- // TODO(akshayka): Let devices rewrite their graphs.
- PartitionOptions partition_options;
- partition_options.node_to_loc = [](const Node* node) {
- // TODO(akshayka): To better support the distributed case, first split
- // the graph by worker (e.g,. using the master session's
- // `SplitByWorker` policy), and then recursively partition the
- // per-worker shards at the remote worker(s).
- return node->assigned_device_name();
- };
- int64 edge_name_counter = 0;
- partition_options.new_name =
- [&edge_name_counter](const string& prefix) {
- return strings::StrCat(prefix, "/_", ++edge_name_counter);
- };
- partition_options.get_incarnation =
- [&device_set](const string& name) -> int64 {
- const Device* d = device_set.FindDeviceByName(name);
- if (d == nullptr) {
- return PartitionOptions::kIllegalIncarnation;
- } else {
- return d->attributes().incarnation();
- }
- };
- partition_options.control_flow_added = false;
- std::unordered_map<string, GraphDef> partitions;
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
OP_REQUIRES_OK_ASYNC(
- ctx, Partition(partition_options, graph, &partitions), done);
-
- VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
- << partitions.size() << " shards.";
-
- // `subgraphs` is a map from devices to their corresponding subgraphs.
- gtl::FlatMap<string, std::unique_ptr<Graph>> subgraphs;
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
- for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
- GraphConstructorOptions opts;
- opts.allow_internal_ops = true;
- opts.expect_device_spec = true;
- const string& device = partition.first;
- const GraphDef& graph_def = partition.second;
- OP_REQUIRES_OK_ASYNC(
- ctx, ConvertGraphDefToGraph(opts, graph_def, subgraph.get()),
- done);
- subgraphs.emplace(device, std::move(subgraph));
- }
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
// The FunctionLibraryRuntime's library cannot be mutated from within
- // an OpKernel, so the functions are instantiated in an overlay library.
+ // an OpKernel, so functions are instantiated in an overlay library.
overlay_lib_.reset(new FunctionLibraryDefinition(
*lib->GetFunctionLibraryDefinition()));
+ auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
+ // TODO(akshayka): Fail gracefully if the set of devices corresponds
+ // to more than one address space.
const string& target = pair.first;
- Graph* subgraph = pair.second.get();
+ const auto& subgraph = pair.second;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, UpdateArgAndRetMetadata(target, subgraph.get()), done);
FunctionDef shard;
string unique_name = UniquifyFunctionName(func_.name());
OP_REQUIRES_OK_ASYNC(
@@ -173,11 +138,188 @@ class PartitionedCallOp : public AsyncOpKernel {
lib->Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
&handle),
done);
- device_handle_map_.emplace(target, handle);
+ handles->emplace(target, handle);
+ }
+
+ function_handles_.emplace(lib, std::move(handles));
+ }
+ }
+ ExecuteFunctions(lib, ctx, args, std::move(done));
+ }
+
+ private:
+ typedef std::pair<string, FHandle> DeviceAndFHandle;
+ typedef std::pair<std::vector<int>, std::vector<int>> ArgAndRetIndices;
+ typedef std::pair<std::vector<AllocatorAttributes>,
+ std::vector<AllocatorAttributes>>
+ ArgAndRetAllocAttrs;
+
+ // Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
+ // corresponding resource lives. This ensures that the Placer assigns ops that
+ // access these resources to the appropriate devices.
+ Status PinResourceArgs(Graph* graph, const OpInputList& args) {
+ for (Node* node : graph->op_nodes()) {
+ string node_type = node->type_string();
+ if (node_type == FunctionLibraryDefinition::kArgOp) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
+ DataType dtype = attr_value->type();
+ if (dtype == DT_RESOURCE) {
+ ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ node->set_assigned_device_name(handle.device());
}
- partitioned_ = true;
}
}
+ return Status::OK();
+ }
+
+ // Partitions `graph` and populates `subgraphs` with the partitions.
+ Status PartitionHelper(
+ const DeviceSet& device_set, std::unique_ptr<Graph> graph,
+ std::unordered_map<string, std::unique_ptr<Graph>>* subgraphs) {
+ PartitionOptions partition_options;
+ partition_options.node_to_loc = [](const Node* node) {
+ // TODO(akshayka): To better support the distributed case, first split
+ // the graph by worker (e.g,. using the master session's
+ // `SplitByWorker` policy), and then recursively partition the
+ // per-worker shards at the remote worker(s).
+ return node->assigned_device_name();
+ };
+ int64 edge_name_counter = 0;
+ partition_options.new_name = [&edge_name_counter](const string& prefix) {
+ return strings::StrCat(prefix, "/_", ++edge_name_counter);
+ };
+ partition_options.get_incarnation =
+ [&device_set](const string& name) -> int64 {
+ const Device* d = device_set.FindDeviceByName(name);
+ if (d == nullptr) {
+ return PartitionOptions::kIllegalIncarnation;
+ } else {
+ return d->attributes().incarnation();
+ }
+ };
+ partition_options.control_flow_added = false;
+ std::unordered_map<string, GraphDef> partitions;
+ TF_RETURN_IF_ERROR(Partition(partition_options, graph.get(), &partitions));
+
+ VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
+ << partitions.size() << " shards.";
+
+ const FunctionLibraryDefinition* flib_def = &graph->flib_def();
+ for (const auto& partition : partitions) {
+ std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ GraphConstructorOptions opts;
+ opts.allow_internal_ops = true;
+ opts.expect_device_spec = true;
+ const string& device = partition.first;
+ const GraphDef& graph_def = partition.second;
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(opts, graph_def, subgraph.get()));
+ subgraphs->emplace(device, std::move(subgraph));
+ }
+
+ return Status::OK();
+ }
+
+ // Each subgraph produced by partitioning the function body contains a subset
+ // of the original `Arg` and `Retval` nodes. This function performs
+ // bookkeeping to track which `Arg` and `Retval` nodes were placed on a
+ // particular device / subgraph.
+ //
+ // More specifically, this function
+ // (1) rewrites the indices of the `Arg` and `Retval` nodes placed on a
+ // particular device,
+ // (2) records the subsets of `Arg` and `Retval` nodes assigned to the
+ // device, and
+ // (3) records which `Arg` and `Retval` nodes live in host memory.
+ Status UpdateArgAndRetMetadata(const string& device, Graph* subgraph) {
+ if (arg_and_ret_indices_.find(device) != arg_and_ret_indices_.end()) {
+ // This function has already been partitioned, albeit for a different
+ // function library.
+ return Status::OK();
+ }
+
+ ArgAndRetIndices indices;
+ std::vector<int>* arg_indices = &indices.first;
+ std::vector<int>* ret_indices = &indices.second;
+ std::vector<std::pair<Node*, int>> arg_nodes;
+ std::vector<std::pair<Node*, int>> ret_nodes;
+ const AttrValue* attr_value;
+
+ for (Node* node : subgraph->op_nodes()) {
+ string node_type = node->type_string();
+ if (node_type == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ arg_indices->push_back(index);
+ arg_nodes.push_back(std::make_pair(node, index));
+ } else if (node_type == FunctionLibraryDefinition::kRetOp) {
+ TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
+ int index = attr_value->i();
+ ret_indices->push_back(index);
+ ret_nodes.push_back(std::make_pair(node, index));
+ }
+ }
+
+ auto sort_by_index = [](std::pair<Node*, int> one,
+ std::pair<Node*, int> two) -> bool {
+ return one.second < two.second;
+ };
+ std::sort(arg_nodes.begin(), arg_nodes.end(), sort_by_index);
+ std::sort(ret_nodes.begin(), ret_nodes.end(), sort_by_index);
+ for (int i = 0; i < arg_nodes.size(); ++i) {
+ Node* arg = arg_nodes[i].first;
+ arg->AddAttr("index", i);
+ TF_RETURN_IF_ERROR(arg->attrs().Find("T", &attr_value));
+ AllocatorAttributes alloc_attr;
+ DataType type = attr_value->type();
+ if (MTypeFromDType(type) == HOST_MEMORY) {
+ alloc_attr.set_on_host(true);
+ }
+ arg_and_ret_alloc_attrs_[device].first.push_back(alloc_attr);
+ }
+ for (int i = 0; i < ret_nodes.size(); ++i) {
+ Node* ret = ret_nodes[i].first;
+ ret->AddAttr("index", i);
+ TF_RETURN_IF_ERROR(ret->attrs().Find("T", &attr_value));
+ AllocatorAttributes alloc_attr;
+ DataType type = attr_value->type();
+ if (MTypeFromDType(type) == HOST_MEMORY) {
+ alloc_attr.set_on_host(true);
+ }
+ arg_and_ret_alloc_attrs_[device].second.push_back(alloc_attr);
+ }
+
+ arg_and_ret_indices_.emplace(device, indices);
+ return Status::OK();
+ }
+
+ std::vector<Tensor> GetArgsForIndices(const std::vector<int>& indices,
+ const OpInputList& arguments) {
+ std::vector<Tensor> args;
+ args.reserve(indices.size());
+ for (int i : indices) {
+ args.push_back(arguments[i]);
+ }
+ return args;
+ }
+
+ void ExecuteFunctions(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
+ const OpInputList& op_args, DoneCallback done)
+ LOCKS_EXCLUDED(mu_) {
+ const gtl::FlatMap<string, FHandle>* handles;
+ {
+ mutex_lock l(mu_);
+ handles = function_handles_[lib].get();
+ }
+ if (handles->empty()) {
+ // Trivial case where the function body is empty.
+ ctx->SetStatus(Status::OK());
+ done();
+ return;
+ }
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
@@ -193,11 +335,6 @@ class PartitionedCallOp : public AsyncOpKernel {
Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
opts.rendezvous = rendez;
- OpInputList arguments;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
- // Dummy args vector for the remote shards, which do not have inputs.
- std::vector<Tensor> dummy_args;
-
StatusCallback callback = std::bind(
[](Rendezvous* rendez, DoneCallback& done, const Status& status) {
rendez->Unref();
@@ -205,50 +342,62 @@ class PartitionedCallOp : public AsyncOpKernel {
},
rendez, std::move(done), std::placeholders::_1);
auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
- for (int i = 1; i < device_handle_map_.size(); ++i) {
+ for (int i = 1; i < handles->size(); ++i) {
refcounted_done->Ref();
}
- for (const auto& pair : device_handle_map_) {
- const string& target_device = pair.first;
+ for (const auto& pair : *handles) {
+ const string& target = pair.first;
FHandle handle = pair.second;
- VLOG(3) << "Running function shard on device " << target_device;
- if (target_device == local_device_name_) {
+ VLOG(3) << "Running function shard on device " << target;
+ ArgAndRetIndices indices = arg_and_ret_indices_[target];
+ ArgAndRetAllocAttrs alloc_attrs = arg_and_ret_alloc_attrs_[target];
+ const std::vector<int>& arg_indices = indices.first;
+ const std::vector<int>& ret_indices = indices.second;
+ opts.args_alloc_attrs = alloc_attrs.first;
+ opts.rets_alloc_attrs = alloc_attrs.second;
+ if (target == local_device_name_) {
opts.remote_execution = false;
- std::vector<Tensor> args;
- args.reserve(arguments.size());
- for (const Tensor& argument : arguments) {
- args.push_back(argument);
- }
- auto* rets = new std::vector<Tensor>;
- lib->Run(opts, handle, args, rets,
- [rets, refcounted_done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else {
- for (int i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- refcounted_done->Unref();
- });
+ std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(
+ opts, handle, args, rets,
+ [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ VLOG(3) << "Local execution failed: " << status;
+ ctx->SetStatus(status);
+ } else {
+ for (int i = 0; i < rets->size(); ++i) {
+ ctx->set_output(ret_indices[i], (*rets)[i]);
+ }
+ }
+ delete rets;
+ VLOG(3) << "Finished local execution.";
+ refcounted_done->Unref();
+ });
} else {
opts.remote_execution = true;
- std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
- lib->Run(opts, handle, dummy_args, dummy_rets,
- [dummy_rets, refcounted_done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- }
- delete dummy_rets;
- refcounted_done->Unref();
- });
+ std::vector<Tensor> args = GetArgsForIndices(arg_indices, op_args);
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(
+ opts, handle, args, rets,
+ [rets, ret_indices, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ VLOG(3) << "Remote execution failed: " << status;
+ ctx->SetStatus(status);
+ } else {
+ for (int i = 0; i < rets->size(); ++i) {
+ ctx->set_output(ret_indices[i], (*rets)[i]);
+ }
+ }
+ delete rets;
+ VLOG(3) << "Finished remote execution.";
+ refcounted_done->Unref();
+ });
}
}
}
- private:
string UniquifyFunctionName(const string& name) {
for (;; ++suffix_) {
const string candidate = strings::StrCat(name, "_", suffix_);
@@ -258,22 +407,42 @@ class PartitionedCallOp : public AsyncOpKernel {
}
}
- // `func_` encapsulates the original, unsharded function.
NameAttrList func_;
string local_device_name_;
// Function shards are added to `overlay_lib_`.
std::unique_ptr<FunctionLibraryDefinition> overlay_lib_;
- // A map from device names to handles of function shards.
- gtl::FlatMap<string, FHandle> device_handle_map_;
+ // Contains maps from device names to handles of function shards, keyed by
+ // FunctionLibraryRuntime pointers. (Because this kernel may be instantiated
+ // for a stateful op, different invocations of it may use different FLRs.)
+ gtl::FlatMap<FunctionLibraryRuntime*,
+ std::unique_ptr<gtl::FlatMap<string, FHandle>>>
+ function_handles_ GUARDED_BY(mu_);
+ // Map from device name to the indices of the arguments and return values
+ // placed on that device. Read-only after the first invocation.
+ gtl::FlatMap<string, ArgAndRetIndices> arg_and_ret_indices_;
+ // Map from device name to alloc attrs for arguments and return values of the
+ // function placed on that device. Read-only after the first invocation.
+ gtl::FlatMap<string, ArgAndRetAllocAttrs> arg_and_ret_alloc_attrs_;
mutex mu_;
- bool partitioned_ GUARDED_BY(mu_) = false;
// Used to uniquify function names in `overlay_lib_`.
uint32 suffix_ = 0;
};
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_CPU),
+ PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_GPU),
+ PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_GPU),
+ PartitionedCallOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_SYCL),
+ PartitionedCallOp);
+REGISTER_KERNEL_BUILDER(Name("StatefulPartitionedCall").Device(DEVICE_SYCL),
+ PartitionedCallOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h
index 906d507c8a..782263e4e9 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.h
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -47,9 +47,13 @@ struct QuantizeAndDequantizeOneScaleImpl {
if (!range_given) {
input_min.device(d) = input.minimum();
input_max.device(d) = input.maximum();
+ d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
+ d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
+ } else {
+ // Copy the range values from their respective tensors on the host.
+ min_range = input_min_tensor->scalar<T>()();
+ max_range = input_max_tensor->scalar<T>()();
}
- d.memcpyDeviceToHost(&min_range, input_min.data(), sizeof(T));
- d.memcpyDeviceToHost(&max_range, input_max.data(), sizeof(T));
// Calculate the range for the simulated integer quantization:
// e.g. [-128,127] for signed = true, num_bits = 8,
diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc
new file mode 100644
index 0000000000..53f431ef3c
--- /dev/null
+++ b/tensorflow/core/kernels/queue_op.cc
@@ -0,0 +1,367 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/queue_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+QueueOp::QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
+ if (capacity_ < 0) {
+ capacity_ = QueueBase::kUnbounded;
+ }
+ OP_REQUIRES_OK(context,
+ context->GetAttr("component_types", &component_types_));
+}
+
+void QueueOp::Compute(OpKernelContext* context) {
+ ResourceOpKernel<QueueInterface>::Compute(context);
+ mutex_lock l(mu_);
+ if (resource_ && context->track_allocations()) {
+ context->record_persistent_memory_allocation(resource_->MemoryUsed());
+ }
+}
+
+Status QueueOp::VerifyResource(QueueInterface* queue) {
+ return queue->MatchesNodeDef(def());
+}
+
+
+QueueOpKernel::QueueOpKernel(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+void QueueOpKernel::ComputeAsync(OpKernelContext* ctx, DoneCallback callback) {
+ QueueInterface* queue;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
+ callback);
+ }
+ ComputeAsync(ctx, queue, [callback, queue]() {
+ queue->Unref();
+ callback();
+ });
+}
+
+QueueAccessOpKernel::QueueAccessOpKernel(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
+ // TODO(keveman): Enable timeout.
+ OP_REQUIRES(context, timeout_ == -1,
+ errors::InvalidArgument("Timeout not supported yet."));
+}
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+EnqueueOp::EnqueueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
+ queue->TryEnqueue(tuple, ctx, callback);
+}
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+EnqueueManyOp::EnqueueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void EnqueueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ DataTypeVector expected_inputs;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ expected_inputs.push_back(DT_RESOURCE);
+ } else {
+ expected_inputs.push_back(DT_STRING_REF);
+ }
+ for (DataType dt : queue->component_dtypes()) {
+ expected_inputs.push_back(dt);
+ }
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}), callback);
+
+ QueueInterface::Tuple tuple;
+ OpInputList components;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
+ callback);
+ for (const Tensor& Tcomponent : components) {
+ tuple.push_back(Tcomponent);
+ }
+
+ OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
+ queue->TryEnqueueMany(tuple, ctx, callback);
+}
+
+EnqueueManyOp::~EnqueueManyOp() = default;
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+DequeueOp::DequeueOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueOp::~DequeueOp() = default;
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+DequeueManyOp::DequeueManyOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueManyOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueManyOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, false /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueManyOp::~DequeueManyOp() = default;
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return if there
+// an error when there are less than num_elements elements left in the
+// closed queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+DequeueUpToOp::DequeueUpToOp(OpKernelConstruction* context)
+ : QueueAccessOpKernel(context) {}
+
+void DequeueUpToOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ const Tensor& Tnum_elements = ctx->input(1);
+ int32 num_elements = Tnum_elements.flat<int32>()(0);
+
+ OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
+ errors::InvalidArgument("DequeueUpToOp requested ",
+ num_elements, " < 0 elements"),
+ callback);
+
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ ctx->MatchSignature({DT_RESOURCE, DT_INT32}, queue->component_dtypes()),
+ callback);
+ } else {
+ OP_REQUIRES_OK_ASYNC(ctx,
+ ctx->MatchSignature({DT_STRING_REF, DT_INT32},
+ queue->component_dtypes()),
+ callback);
+ }
+
+ queue->TryDequeueMany(
+ num_elements, ctx, true /* allow_small_batch */,
+ [ctx, callback](const QueueInterface::Tuple& tuple) {
+ if (!ctx->status().ok()) {
+ callback();
+ return;
+ }
+ OpOutputList output_components;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->output_list("components", &output_components), callback);
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ output_components.set(i, tuple[i]);
+ }
+ callback();
+ });
+}
+
+DequeueUpToOp::~DequeueUpToOp() = default;
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+QueueCloseOp::QueueCloseOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
+ &cancel_pending_enqueues_));
+}
+
+void QueueCloseOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ queue->Close(ctx, cancel_pending_enqueues_, callback);
+}
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+QueueSizeOp::QueueSizeOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueSizeOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_size = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
+ Tqueue_size->flat<int32>().setConstant(queue->size());
+ callback();
+}
+
+QueueIsClosedOp::QueueIsClosedOp(OpKernelConstruction* context)
+ : QueueOpKernel(context) {}
+
+void QueueIsClosedOp::ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) {
+ Tensor* Tqueue_is_closed = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
+ Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
+ callback();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h
index 6c19f9841c..2efd838a5f 100644
--- a/tensorflow/core/kernels/queue_op.h
+++ b/tensorflow/core/kernels/queue_op.h
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_QUEUE_OP_H_
-#define TENSORFLOW_KERNELS_QUEUE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
#include <deque>
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -32,22 +33,9 @@ namespace tensorflow {
// Defines a QueueOp, an abstract class for Queue construction ops.
class QueueOp : public ResourceOpKernel<QueueInterface> {
public:
- QueueOp(OpKernelConstruction* context) : ResourceOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("capacity", &capacity_));
- if (capacity_ < 0) {
- capacity_ = QueueBase::kUnbounded;
- }
- OP_REQUIRES_OK(context,
- context->GetAttr("component_types", &component_types_));
- }
+ QueueOp(OpKernelConstruction* context);
- void Compute(OpKernelContext* context) override {
- ResourceOpKernel<QueueInterface>::Compute(context);
- mutex_lock l(mu_);
- if (resource_ && context->track_allocations()) {
- context->record_persistent_memory_allocation(resource_->MemoryUsed());
- }
- }
+ void Compute(OpKernelContext* context) override;
protected:
// Variables accessible by subclasses
@@ -55,9 +43,7 @@ class QueueOp : public ResourceOpKernel<QueueInterface> {
DataTypeVector component_types_;
private:
- Status VerifyResource(QueueInterface* queue) override {
- return queue->MatchesNodeDef(def());
- }
+ Status VerifyResource(QueueInterface* queue) override;
};
class TypedQueueOp : public QueueOp {
@@ -75,6 +61,211 @@ class TypedQueueOp : public QueueOp {
}
};
+// Queue manipulator kernels
+
+class QueueOpKernel : public AsyncOpKernel {
+ public:
+ explicit QueueOpKernel(OpKernelConstruction* context);
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final;
+
+ protected:
+ virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) = 0;
+};
+
+class QueueAccessOpKernel : public QueueOpKernel {
+ public:
+ explicit QueueAccessOpKernel(OpKernelConstruction* context);
+
+ protected:
+ int64 timeout_;
+};
+
+// Defines an EnqueueOp, the execution of which enqueues a tuple of
+// tensors in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+class EnqueueOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
+};
+
+// Defines an EnqueueManyOp, the execution of which slices each
+// component of a tuple of tensors along the 0th dimension, and
+// enqueues tuples of slices in the given Queue.
+//
+// The op has 1 + k inputs, where k is the number of components in the
+// tuples stored in the given Queue:
+// - Input 0: queue handle.
+// - Input 1: 0th element of the tuple.
+// - ...
+// - Input (1+k): kth element of the tuple.
+//
+// N.B. All tuple components must have the same size in the 0th
+// dimension.
+class EnqueueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit EnqueueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~EnqueueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
+};
+
+// Defines a DequeueOp, the execution of which dequeues a tuple of
+// tensors from the given Queue.
+//
+// The op has one input, which is the handle of the appropriate
+// Queue. The op has k outputs, where k is the number of components in
+// the tuples stored in the given Queue, and output i is the ith
+// component of the dequeued tuple.
+class DequeueOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
+};
+
+// Defines a DequeueManyOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+class DequeueManyOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueManyOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueManyOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
+};
+
+// Defines a DequeueUpToOp, the execution of which concatenates the
+// requested number of elements from the given Queue along the 0th
+// dimension, and emits the result as a single tuple of tensors.
+//
+// The difference between this op and DequeueMany is the handling when
+// the Queue is closed. While the DequeueMany op will return if there
+// an error when there are less than num_elements elements left in the
+// closed queue, this op will return between 1 and
+// min(num_elements, elements_remaining_in_queue), and will not block.
+// If there are no elements left, then the standard DequeueMany error
+// is returned.
+//
+// This op only works if the underlying Queue implementation accepts
+// the allow_small_batch = true parameter to TryDequeueMany.
+// If it does not, an errors::Unimplemented exception is returned.
+//
+// The op has two inputs:
+// - Input 0: the handle to a queue.
+// - Input 1: the number of elements to dequeue.
+//
+// The op has k outputs, where k is the number of components in the
+// tuples stored in the given Queue, and output i is the ith component
+// of the dequeued tuple.
+//
+// The op has one attribute: allow_small_batch. If the Queue supports
+// it, setting this to true causes the queue to return smaller
+// (possibly zero length) batches when it is closed, up to however
+// many elements are available when the op executes. In this case,
+// the Queue does not block when closed.
+class DequeueUpToOp : public QueueAccessOpKernel {
+ public:
+ explicit DequeueUpToOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ ~DequeueUpToOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
+};
+
+// Defines a QueueCloseOp, which closes the given Queue. Closing a
+// Queue signals that no more elements will be enqueued in it.
+//
+// The op has one input, which is the handle of the appropriate Queue.
+class QueueCloseOp : public QueueOpKernel {
+ public:
+ explicit QueueCloseOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ bool cancel_pending_enqueues_;
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
+};
+
+// Defines a QueueSizeOp, which computes the number of elements in the
+// given Queue, and emits it as an output tensor.
+//
+// The op has one input, which is the handle of the appropriate Queue;
+// and one output, which is a single-element tensor containing the current
+// size of that Queue.
+class QueueSizeOp : public QueueOpKernel {
+ public:
+ explicit QueueSizeOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
+};
+
+class QueueIsClosedOp : public QueueOpKernel {
+ public:
+ explicit QueueIsClosedOp(OpKernelConstruction* context);
+
+ protected:
+ void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
+ DoneCallback callback) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_QUEUE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_QUEUE_OP_H_
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index 46a02854d7..c4d404259b 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -13,437 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// See docs in ../ops/data_flow_ops.cc.
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-class QueueOpKernel : public AsyncOpKernel {
- public:
- explicit QueueOpKernel(OpKernelConstruction* context)
- : AsyncOpKernel(context) {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
- QueueInterface* queue;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &queue), callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &queue),
- callback);
- }
- ComputeAsync(ctx, queue, [callback, queue]() {
- queue->Unref();
- callback();
- });
- }
-
- protected:
- virtual void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) = 0;
-};
-
-class QueueAccessOpKernel : public QueueOpKernel {
- public:
- explicit QueueAccessOpKernel(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
- // TODO(keveman): Enable timeout.
- OP_REQUIRES(context, timeout_ == -1,
- errors::InvalidArgument("Timeout not supported yet."));
- }
-
- protected:
- int64 timeout_;
-};
-
-// Defines an EnqueueOp, the execution of which enqueues a tuple of
-// tensors in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-class EnqueueOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateTuple(tuple), callback);
- queue->TryEnqueue(tuple, ctx, callback);
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueue").Device(DEVICE_CPU), EnqueueOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueV2").Device(DEVICE_CPU), EnqueueOp);
-// Defines an EnqueueManyOp, the execution of which slices each
-// component of a tuple of tensors along the 0th dimension, and
-// enqueues tuples of slices in the given Queue.
-//
-// The op has 1 + k inputs, where k is the number of components in the
-// tuples stored in the given Queue:
-// - Input 0: queue handle.
-// - Input 1: 0th element of the tuple.
-// - ...
-// - Input (1+k): kth element of the tuple.
-//
-// N.B. All tuple components must have the same size in the 0th
-// dimension.
-class EnqueueManyOp : public QueueAccessOpKernel {
- public:
- explicit EnqueueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- DataTypeVector expected_inputs;
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- expected_inputs.push_back(DT_RESOURCE);
- } else {
- expected_inputs.push_back(DT_STRING_REF);
- }
- for (DataType dt : queue->component_dtypes()) {
- expected_inputs.push_back(dt);
- }
- OP_REQUIRES_OK_ASYNC(ctx, ctx->MatchSignature(expected_inputs, {}),
- callback);
-
- QueueInterface::Tuple tuple;
- OpInputList components;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("components", &components),
- callback);
- for (const Tensor& Tcomponent : components) {
- tuple.push_back(Tcomponent);
- }
-
- OP_REQUIRES_OK_ASYNC(ctx, queue->ValidateManyTuple(tuple), callback);
- queue->TryEnqueueMany(tuple, ctx, callback);
- }
-
- ~EnqueueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(EnqueueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueMany").Device(DEVICE_CPU),
EnqueueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU),
EnqueueManyOp);
-// Defines a DequeueOp, the execution of which dequeues a tuple of
-// tensors from the given Queue.
-//
-// The op has one input, which is the handle of the appropriate
-// Queue. The op has k outputs, where k is the number of components in
-// the tuples stored in the given Queue, and output i is the ith
-// component of the dequeued tuple.
-class DequeueOp : public QueueAccessOpKernel {
- public:
- explicit DequeueOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_RESOURCE}, queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->MatchSignature({DT_STRING_REF}, queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeue(ctx, [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components), callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeue").Device(DEVICE_CPU), DequeueOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueV2").Device(DEVICE_CPU), DequeueOp);
-// Defines a DequeueManyOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-class DequeueManyOp : public QueueAccessOpKernel {
- public:
- explicit DequeueManyOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueManyOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, false /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueManyOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueManyOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueMany").Device(DEVICE_CPU),
DequeueManyOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU),
DequeueManyOp);
-// Defines a DequeueUpToOp, the execution of which concatenates the
-// requested number of elements from the given Queue along the 0th
-// dimension, and emits the result as a single tuple of tensors.
-//
-// The difference between this op and DequeueMany is the handling when
-// the Queue is closed. While the DequeueMany op will return if there
-// an error when there are less than num_elements elements left in the
-// closed queue, this op will return between 1 and
-// min(num_elements, elements_remaining_in_queue), and will not block.
-// If there are no elements left, then the standard DequeueMany error
-// is returned.
-//
-// This op only works if the underlying Queue implementation accepts
-// the allow_small_batch = true parameter to TryDequeueMany.
-// If it does not, an errors::Unimplemented exception is returned.
-//
-// The op has two inputs:
-// - Input 0: the handle to a queue.
-// - Input 1: the number of elements to dequeue.
-//
-// The op has k outputs, where k is the number of components in the
-// tuples stored in the given Queue, and output i is the ith component
-// of the dequeued tuple.
-//
-// The op has one attribute: allow_small_batch. If the Queue supports
-// it, setting this to true causes the queue to return smaller
-// (possibly zero length) batches when it is closed, up to however
-// many elements are available when the op executes. In this case,
-// the Queue does not block when closed.
-class DequeueUpToOp : public QueueAccessOpKernel {
- public:
- explicit DequeueUpToOp(OpKernelConstruction* context)
- : QueueAccessOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- const Tensor& Tnum_elements = ctx->input(1);
- int32 num_elements = Tnum_elements.flat<int32>()(0);
-
- OP_REQUIRES_ASYNC(ctx, num_elements >= 0,
- errors::InvalidArgument("DequeueUpToOp requested ",
- num_elements, " < 0 elements"),
- callback);
-
- if (ctx->input_dtype(0) == DT_RESOURCE) {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_RESOURCE, DT_INT32},
- queue->component_dtypes()),
- callback);
- } else {
- OP_REQUIRES_OK_ASYNC(ctx,
- ctx->MatchSignature({DT_STRING_REF, DT_INT32},
- queue->component_dtypes()),
- callback);
- }
-
- queue->TryDequeueMany(
- num_elements, ctx, true /* allow_small_batch */,
- [ctx, callback](const QueueInterface::Tuple& tuple) {
- if (!ctx->status().ok()) {
- callback();
- return;
- }
- OpOutputList output_components;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->output_list("components", &output_components),
- callback);
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- output_components.set(i, tuple[i]);
- }
- callback();
- });
- }
-
- ~DequeueUpToOp() override {}
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(DequeueUpToOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpTo").Device(DEVICE_CPU),
DequeueUpToOp);
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU),
DequeueUpToOp);
-// Defines a QueueCloseOp, which closes the given Queue. Closing a
-// Queue signals that no more elements will be enqueued in it.
-//
-// The op has one input, which is the handle of the appropriate Queue.
-class QueueCloseOp : public QueueOpKernel {
- public:
- explicit QueueCloseOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
- &cancel_pending_enqueues_));
- }
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- queue->Close(ctx, cancel_pending_enqueues_, callback);
- }
-
- private:
- bool cancel_pending_enqueues_;
- TF_DISALLOW_COPY_AND_ASSIGN(QueueCloseOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueClose").Device(DEVICE_CPU), QueueCloseOp);
REGISTER_KERNEL_BUILDER(Name("QueueCloseV2").Device(DEVICE_CPU), QueueCloseOp);
-// Defines a QueueSizeOp, which computes the number of elements in the
-// given Queue, and emits it as an output tensor.
-//
-// The op has one input, which is the handle of the appropriate Queue;
-// and one output, which is a single-element tensor containing the current
-// size of that Queue.
-class QueueSizeOp : public QueueOpKernel {
- public:
- explicit QueueSizeOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_size = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &Tqueue_size));
- Tqueue_size->flat<int32>().setConstant(queue->size());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueSizeOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueSize").Device(DEVICE_CPU), QueueSizeOp);
REGISTER_KERNEL_BUILDER(Name("QueueSizeV2").Device(DEVICE_CPU), QueueSizeOp);
-class QueueIsClosedOp : public QueueOpKernel {
- public:
- explicit QueueIsClosedOp(OpKernelConstruction* context)
- : QueueOpKernel(context) {}
-
- protected:
- void ComputeAsync(OpKernelContext* ctx, QueueInterface* queue,
- DoneCallback callback) override {
- Tensor* Tqueue_is_closed = nullptr;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &Tqueue_is_closed));
- Tqueue_is_closed->flat<bool>().setConstant(queue->is_closed());
- callback();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(QueueIsClosedOp);
-};
-
REGISTER_KERNEL_BUILDER(Name("QueueIsClosed").Device(DEVICE_CPU),
QueueIsClosedOp);
REGISTER_KERNEL_BUILDER(Name("QueueIsClosedV2").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 0de2ebb590..9af4cc23b6 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -295,7 +295,11 @@ __global__ void ColumnReduceMax16ColumnsKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ storage_type<value_type> partial_sums[32 * 33];
+ // This is to mimic the following, but without any constructors:
+ // __shared__ storage_type<value_type> partial_sums[32 * 33];
+ __shared__ __align__(
+ alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
+ value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
row += rows_per_warp * gridDim.y * blockDim.y;
for (; row < num_rows; row += rows_per_warp * gridDim.y * blockDim.y) {
@@ -344,7 +348,11 @@ __global__ void ColumnReduceKernel(
// 1D array necessary due to bug in CUDA 9 compiler.
// TODO(nluehr) revert to 2D array when compiler is ready.
- __shared__ storage_type<value_type> partial_sums[32 * 33];
+ // This is to mimic the following, but without constructors:
+ // __shared__ storage_type<value_type> partial_sums[32 * 33];
+ __shared__ __align__(
+ alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
+ value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);
row += gridDim.y * blockDim.y;
diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc
index 4188ad233e..ac301f3342 100644
--- a/tensorflow/core/kernels/reshape_util.cc
+++ b/tensorflow/core/kernels/reshape_util.cc
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/reshape_util.h"
+
#include <algorithm>
#include <numeric>
#include <unordered_map>
@@ -107,15 +108,19 @@ void Reshape(OpKernelContext *context, const Tensor &input_indices_in,
}
gtl::InlinedVector<int64, 8> input_strides(input_rank);
- input_strides[input_rank - 1] = 1;
- for (int d = input_rank - 2; d >= 0; --d) {
- input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+ if (input_rank > 0) {
+ input_strides[input_rank - 1] = 1;
+ for (int d = input_rank - 2; d >= 0; --d) {
+ input_strides[d] = input_strides[d + 1] * input_shape.dim_size(d + 1);
+ }
}
gtl::InlinedVector<int64, 8> output_strides(output_rank);
- output_strides[output_rank - 1] = 1;
- for (int d = output_rank - 2; d >= 0; --d) {
- output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
+ if (output_rank > 0) {
+ output_strides[output_rank - 1] = 1;
+ for (int d = output_rank - 2; d >= 0; --d) {
+ output_strides[d] = output_strides[d + 1] * output_shape.dim_size(d + 1);
+ }
}
Tensor *result_indices = nullptr;
diff --git a/tensorflow/core/kernels/resize_area_op_test.cc b/tensorflow/core/kernels/resize_area_op_test.cc
index a7e06ef15a..84ff090b54 100644
--- a/tensorflow/core/kernels/resize_area_op_test.cc
+++ b/tensorflow/core/kernels/resize_area_op_test.cc
@@ -124,7 +124,8 @@ class ResizeAreaOpTest : public OpsTestBase {
? (j + 1 > in_x1 ? width_scale : j + 1 - in_x)
: (j + 1 > in_x1 ? in_x1 - j : 1.0);
for (int64 c = 0; c < channels; ++c) {
-#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
+#define BOUND(val, limit) \
+ std::min(((limit)-int64{1}), (std::max(int64{0}, (val))))
sum_data(c) +=
static_cast<float>(input_data(b, BOUND(i, in_height),
BOUND(j, in_width), c)) *
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index ff38026ac7..e1fc2ea128 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -143,14 +143,10 @@ class ScatterNdUpdateOp : public OpKernel {
void Compute(OpKernelContext* c) override {
if (dtype_ == DT_RESOURCE) {
- if (use_exclusive_lock_) {
- Var* v;
- OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
- mutex_lock m(*v->mu());
- DoCompute(c);
- } else {
- DoCompute(c);
- }
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ mutex_lock m(*v->mu());
+ DoCompute(c);
} else if (use_exclusive_lock_) {
// If we're here, it means the input type is a ref.
DCHECK(IsRefType(c->input_dtype(0)));
@@ -176,13 +172,7 @@ class ScatterNdUpdateOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
Tensor* t = v->tensor();
- if (!use_exclusive_lock_) {
- // We're not holding the lock in the outer scope so need it here.
- mutex_lock m(*v->mu());
- OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
- } else {
- OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
- }
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
params = *t;
params_shape = params.shape();
} else if (IsRefType(c->input_dtype(0))) {
@@ -260,7 +250,9 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdNonAliasingAdd", \
scatter_nd_op::UpdateOp::ADD); \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdSub", \
- scatter_nd_op::UpdateOp::SUB);
+ scatter_nd_op::UpdateOp::SUB); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
+ type, dev, "ResourceScatterNdAdd", scatter_nd_op::UpdateOp::ADD);
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index d0703d7576..d28e35157b 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
-#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
@@ -138,4 +144,4 @@ struct Highest {
} // namespace functor
} // namespace tensorflow
-#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index 9e041d98f7..852cef29c7 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -36,6 +36,8 @@ limitations under the License.
namespace tensorflow {
+namespace {
+
using sparse::SparseTensor;
template <typename T>
@@ -306,267 +308,6 @@ Status SerializeManySparseOpBase<Variant>::Serialize(const Tensor& input,
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-template <typename T>
-class DeserializeSparseOp : public OpKernel {
- public:
- explicit DeserializeSparseOp(OpKernelConstruction* context)
- : OpKernel(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
- }
-
- void Compute(OpKernelContext* context) override {
- const Tensor& serialized_sparse = context->input(0);
- const int ndims = serialized_sparse.shape().dims();
-
- OP_REQUIRES(
- context, ndims > 0,
- errors::InvalidArgument("Serialized sparse should have non-zero rank ",
- serialized_sparse.shape().DebugString()));
-
- OP_REQUIRES(context, serialized_sparse.shape().dim_size(ndims - 1) == 3,
- errors::InvalidArgument(
- "Serialized sparse should have 3 as the last dimension ",
- serialized_sparse.shape().DebugString()));
-
- int num_sparse_tensors = 1;
- for (int i = 0; i < ndims - 1; ++i) {
- num_sparse_tensors *= serialized_sparse.shape().dim_size(i);
- }
-
- OP_REQUIRES(
- context, num_sparse_tensors > 0,
- errors::InvalidArgument(
- "Serialized sparse should have at least 1 serialized tensor, "
- "but has a zero dimension ",
- serialized_sparse.shape().DebugString()));
-
- if (num_sparse_tensors == 1 && serialized_sparse.shape().dims() == 0) {
- // Special case with a single sparse tensor. We can avoid data
- // motion in the Concat and Reshape.
- const auto& serialized_sparse_t = serialized_sparse.vec<T>();
-
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(0), serialized_sparse_t(1),
- serialized_sparse_t(2), dtype_, 0 /* index */,
- &output_indices, &output_values, &output_shape));
- context->set_output(0, output_indices);
- context->set_output(1, output_values);
- context->set_output(2, output_shape);
- return;
- }
-
- std::vector<Tensor> indices;
- std::vector<Tensor> values;
- TensorShape shape;
- indices.reserve(num_sparse_tensors);
- values.reserve(num_sparse_tensors);
-
- const auto& serialized_sparse_t = serialized_sparse.flat_inner_dims<T, 2>();
- for (int i = 0; i < num_sparse_tensors; ++i) {
- Tensor output_indices;
- Tensor output_values;
- Tensor output_shape;
- OP_REQUIRES_OK(context,
- this->GetAndValidateSparseTensor(
- serialized_sparse_t(i, 0), serialized_sparse_t(i, 1),
- serialized_sparse_t(i, 2), dtype_, i, &output_indices,
- &output_values, &output_shape));
- int64 num_entries = output_indices.dim_size(0);
- int rank = output_indices.dim_size(1);
-
- // Now we expand each SparseTensors' indices and shape by
- // prefixing a dimension
- Tensor expanded_indices(DT_INT64, TensorShape({num_entries, 1 + rank}));
- const auto& output_indices_t = output_indices.matrix<int64>();
- auto expanded_indices_t = expanded_indices.matrix<int64>();
- expanded_indices_t.chip<1>(0).setZero();
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
- expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
-
- Tensor expanded_shape(DT_INT64, TensorShape({1 + rank}));
- const auto& output_shape_t = output_shape.vec<int64>();
- auto expanded_shape_t = expanded_shape.vec<int64>();
- expanded_shape_t(0) = 1;
- std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
-
- TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
-
- indices.push_back(expanded_indices);
- values.push_back(output_values);
- if (i == 0) {
- shape = expanded_tensor_shape;
- } else {
- OP_REQUIRES(
- context, shape.dims() == expanded_tensor_shape.dims(),
- errors::InvalidArgument(
- "Inconsistent shape across SparseTensors: rank prior to "
- "SparseTensor[",
- i, "] was: ", shape.dims() - 1, " but rank of SparseTensor[", i,
- "] is: ", expanded_tensor_shape.dims() - 1));
- for (int j = 1; j < shape.dims(); ++j) {
- // NOTE(mrry): For compatibility with the implementations of
- // DeserializeManySparse, and many ops that generate
- // SparseTensors to batch that do not have a fixed
- // dense_shape (e.g. `tf.parse_single_example()`), we
- // compute the maximum in each dimension to find the
- // smallest dense_shape that bounds all of the input
- // SparseTensors.
- shape.set_dim(j, std::max(shape.dim_size(j),
- expanded_tensor_shape.dim_size(j)));
- }
- }
- }
-
- // Dimension 0 is the primary dimension.
- int rank = shape.dims();
- gtl::InlinedVector<int64, 8> std_order(rank);
- std::iota(std_order.begin(), std_order.end(), 0);
-
- std::vector<SparseTensor> tensors;
- tensors.reserve(num_sparse_tensors);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- tensors.emplace_back(indices[i], values[i], shape, std_order);
- }
-
- gtl::optional<SparseTensor> maybe_output;
-#define HANDLE_TYPE(T) \
- case DataTypeToEnum<T>::value: { \
- maybe_output = SparseTensor::Concat<T>(tensors); \
- break; \
- }
-
- switch (dtype_) {
- TF_CALL_ALL_TYPES(HANDLE_TYPE);
- TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
-#undef HANDLE_TYPE
- default:
- OP_REQUIRES(context, false,
- errors::Unimplemented(
- "DeserializeSparse Unhandled data type: ", dtype_));
- }
- DCHECK(maybe_output);
- SparseTensor& output = maybe_output.value();
-
- // Compute the input shape for the reshape operation.
- Tensor input_shape(DT_INT64, TensorShape({output.dims()}));
- std::copy_n(output.shape().data(), output.dims(),
- input_shape.vec<int64>().data());
-
- // Compute the target shape for the reshape operation.
- Tensor target_shape(DT_INT64, TensorShape({ndims + output.dims() - 2}));
- for (int i = 0; i < ndims - 1; ++i) {
- target_shape.vec<int64>()(i) = serialized_sparse.shape().dim_size(i);
- }
- for (int i = 0; i < output.dims() - 1; ++i) {
- target_shape.vec<int64>()(i + ndims - 1) = output.shape().data()[i + 1];
- }
-
- Tensor output_indices;
- Tensor output_shape;
- Reshape(context, output.indices(), input_shape, target_shape,
- 0 /* output indices index */, 2 /* output shape index */);
- context->set_output(1, output.values());
- }
-
- protected:
- Status Deserialize(const T& serialized, Tensor* result);
-
- Status GetAndValidateSparseTensor(
- const T& serialized_indices, const T& serialized_values,
- const T& serialized_shape, DataType values_dtype, int index,
- Tensor* output_indices, Tensor* output_values, Tensor* output_shape) {
- // Deserialize and validate the indices.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_indices, output_indices));
- if (!TensorShapeUtils::IsMatrix(output_indices->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 0] to represent an index matrix but received shape ",
- output_indices->shape().DebugString());
- }
- int64 num_entries = output_indices->dim_size(0);
- int rank = output_indices->dim_size(1);
-
- // Deserialize and validate the values.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_values, output_values));
- if (!TensorShapeUtils::IsVector(output_values->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to represent a values vector but received shape ",
- output_values->shape().DebugString());
- }
- if (values_dtype != output_values->dtype()) {
- return errors::InvalidArgument(
- "Requested SparseTensor of type ", DataTypeString(values_dtype),
- " but SparseTensor[", index,
- "].values.dtype() == ", DataTypeString(output_values->dtype()));
- }
- if (num_entries != output_values->dim_size(0)) {
- return errors::InvalidArgument(
- "Expected row counts of SparseTensor[", index,
- "].indices and SparseTensor[", index,
- "].values to match but they do not: ", num_entries, " vs. ",
- output_values->dim_size(0));
- }
-
- // Deserialize and validate the shape.
- TF_RETURN_IF_ERROR(this->Deserialize(serialized_shape, output_shape));
- if (!TensorShapeUtils::IsVector(output_shape->shape())) {
- return errors::InvalidArgument(
- "Expected serialized_sparse[", index,
- ", 1] to be a shape vector but its shape is ",
- output_shape->shape().DebugString());
- }
- if (rank != output_shape->dim_size(0)) {
- return errors::InvalidArgument("Expected column counts of SparseTensor[",
- index,
- "].indices to match size of SparseTensor[",
- index, "].shape but they do not: ", rank,
- " vs. ", output_shape->dim_size(0));
- }
- return Status::OK();
- }
-
- DataType dtype_;
-};
-
-template <>
-Status DeserializeSparseOp<string>::Deserialize(const string& serialized,
- Tensor* result) {
- TensorProto proto;
- if (!ParseProtoUnlimited(&proto, serialized)) {
- return errors::InvalidArgument("Could not parse serialized proto");
- }
- Tensor tensor;
- if (!tensor.FromProto(proto)) {
- return errors::InvalidArgument("Could not construct tensor from proto");
- }
- *result = tensor;
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<string>("Tserialized"),
- DeserializeSparseOp<string>)
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse").Device(DEVICE_CPU),
- DeserializeSparseOp<string>)
-
-template <>
-Status DeserializeSparseOp<Variant>::Deserialize(const Variant& serialized,
- Tensor* result) {
- *result = *serialized.get<Tensor>();
- return Status::OK();
-}
-
-REGISTER_KERNEL_BUILDER(Name("DeserializeSparse")
- .Device(DEVICE_CPU)
- .TypeConstraint<Variant>("Tserialized"),
- DeserializeSparseOp<Variant>)
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sparse_slice_grad_op.cc b/tensorflow/core/kernels/sparse_slice_grad_op.cc
new file mode 100644
index 0000000000..90a39ed818
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_slice_grad_op.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SparseSliceGradOp : public OpKernel {
+ public:
+ explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext *ctx) override {
+ const Tensor *backprop_val_grad, *input_indices, *output_indices, *input_start;
+ OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad));
+ OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices));
+ OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start));
+ OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices));
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsMatrix(input_indices->shape()) &&
+ TensorShapeUtils::IsMatrix(output_indices->shape()),
+ errors::InvalidArgument(
+ "Input and output indices should be matrices "
+ "but received shapes: ",
+ input_indices->shape().DebugString(), " and ",
+ output_indices->shape().DebugString()));
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()),
+ errors::InvalidArgument(
+ "Input backprop_val_grad should be a vector but received shape: ",
+ backprop_val_grad->shape().DebugString()));
+ OP_REQUIRES(
+ ctx,
+ input_indices->dim_size(1) == output_indices->dim_size(1),
+ errors::InvalidArgument("The input and output should have the same "
+ "ndims: got: ", input_indices->dim_size(1), " and ",
+ output_indices->dim_size(1)));
+ OP_REQUIRES(
+ ctx, output_indices->dim_size(0) <= input_indices->dim_size(0),
+ errors::InvalidArgument("# rows of output_indices should be not greater "
+ "than of input_indices, got ",
+ output_indices->dim_size(0), " and ",
+ input_indices->dim_size(0)));
+ OP_REQUIRES(
+ ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0),
+ errors::InvalidArgument("# elements of backprop_val_grad and # rows of "
+ "output_indices should match (#nnz of sum): got ",
+ backprop_val_grad->NumElements(), " and ",
+ output_indices->dim_size(0)));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()),
+ errors::InvalidArgument(
+ "The input_start should be a vector but received shape ",
+ input_start->shape().DebugString()));
+
+ const int num_dims = input_indices->dim_size(1);
+ OP_REQUIRES(ctx, num_dims == input_start->NumElements(),
+ errors::InvalidArgument(
+ "Expected input_start to be a vector of length ", num_dims,
+ " but got length ", input_start->NumElements()));
+
+ const int64 input_nnz = input_indices->dim_size(0);
+
+ Tensor *val_grad;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad));
+
+ T *val_grad_flat = val_grad->flat<T>().data();
+ const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data();
+ memset(val_grad_flat, 0, sizeof(T) * input_nnz);
+
+ // Fill gradients for position where indices of input and output are same.
+ const auto input_indices_mat = input_indices->matrix<int64>();
+ const auto output_indices_mat = output_indices->matrix<int64>();
+ const auto input_start_flat = input_start->flat<int64>();
+ int64 j = 0;
+ for (int64 i = 0; i < input_nnz && j < backprop_val_grad->NumElements();
+ ++i) {
+ bool is_same = true;
+ for (int d = 0; d < num_dims; ++d) {
+ const int64 a = input_indices_mat(i, d);
+ const int64 b = output_indices_mat(j, d);
+ const int64 offset = input_start_flat(d);
+ if (a != b + offset) {
+ is_same = false;
+ break;
+ }
+ }
+ if (is_same) {
+ val_grad_flat[i] = backprop_val_grad_flat[j];
+ ++j;
+ }
+ }
+ OP_REQUIRES(
+ ctx, backprop_val_grad->NumElements() == j,
+ errors::Internal("Elements of backprop_val_grad aren't all propagated. "
+ "Num elements:", backprop_val_grad->NumElements(),
+ ", used: ", j));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseSliceGradOp<type>)
+
+TF_CALL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 4c2b312c34..26ab72f12e 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -43,6 +44,63 @@ std::vector<string> Split(const string& str, const string& delimiter,
return char_vector;
}
+std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+ // This SplitV2 method matches the behavior of python's str.split:
+ // If sep is given, consecutive delimiters are not grouped together
+ // and are deemed to delimit empty strings (for example, '1,,2'.split(',')
+ // returns ['1', '', '2']). The sep argument may consist of multiple
+ // characters (for example, '1<>2<>3'.split('<>') returns ['1', '2', '3']).
+ // Splitting an empty string with a specified separator returns [''].
+ //
+ // If sep is not specified or is None, a different splitting algorithm is
+ // applied: runs of consecutive whitespace are regarded as a single
+ // separator, and the result will contain no empty strings at the start or
+ // end if the string has leading or trailing whitespace. Consequently,
+ // splitting an empty string or a string consisting of just whitespace
+ // with a None separator returns [].
+
+ std::vector<string> result;
+
+ StringPiece text(str);
+ if (maxsplit == 0) {
+ result.emplace_back(std::string(text));
+ return result;
+ }
+
+ if (sep.empty()) {
+ StringPiece token;
+ // Remove leading whitespaces.
+ str_util::RemoveLeadingWhitespace(&text);
+ int split = 0;
+ while (str_util::ConsumeNonWhitespace(&text, &token)) {
+ result.emplace_back(std::string(token));
+ str_util::RemoveLeadingWhitespace(&text);
+ ++split;
+ if (maxsplit > 0 && split == maxsplit) {
+ result.emplace_back(std::string(text));
+ return result;
+ }
+ }
+ return result;
+ }
+ auto p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
+ int split = 0;
+ while (p != text.end()) {
+ StringPiece token = text.substr(0, p - text.begin());
+ result.emplace_back(std::string(token));
+ text.remove_prefix(token.size());
+ text.remove_prefix(sep.size());
+ ++split;
+ if (maxsplit > 0 && split == maxsplit) {
+ result.emplace_back(std::string(text));
+ return result;
+ }
+ p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
+ }
+ result.emplace_back(std::string(text));
+ return result;
+}
+
} // namespace
class StringSplitOp : public OpKernel {
@@ -122,6 +180,78 @@ class StringSplitOp : public OpKernel {
bool skip_empty_;
};
+class StringSplitV2Op : public OpKernel {
+ public:
+ explicit StringSplitV2Op(OpKernelConstruction* context)
+ : OpKernel(context), maxsplit_(-1) {
+ OP_REQUIRES_OK(context, context->GetAttr("maxsplit", &maxsplit_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_tensor->shape()),
+ errors::InvalidArgument("input must be a vector, got shape: ",
+ input_tensor->shape().DebugString()));
+
+ const auto input_vec = input_tensor->vec<string>();
+ const int64 batch_size = input_vec.dimension(0);
+
+ const Tensor* sep_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("sep", &sep_tensor));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(sep_tensor->shape()),
+ errors::InvalidArgument("sep must be a scalar, got shape: ",
+ sep_tensor->shape().DebugString()));
+ const auto sep_vec = sep_tensor->flat<string>();
+ StringPiece sep(sep_vec(0));
+ std::vector<string> tokens;
+ // Guess that we'll be unpacking a handful of tokens per example.
+ static constexpr int kReserveSize = 4;
+ tokens.reserve(batch_size * kReserveSize);
+
+ int64 output_size = 0;
+ int64 max_num_entries = 0;
+ std::vector<int64> num_indices(batch_size);
+ for (int64 i = 0; i < batch_size; ++i) {
+ std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+ int64 n_entries = parts.size();
+ num_indices[i] = n_entries;
+ output_size += n_entries;
+ max_num_entries = std::max(max_num_entries, n_entries);
+ tokens.insert(tokens.end(), parts.begin(), parts.end());
+ }
+
+ Tensor* sp_indices_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({output_size, 2}),
+ &sp_indices_t));
+ Tensor* sp_tokens_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(1, TensorShape({output_size}), &sp_tokens_t));
+ Tensor* sp_shape_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({2}), &sp_shape_t));
+
+ auto sp_indices = sp_indices_t->matrix<int64>();
+ auto sp_tokens = sp_tokens_t->vec<string>();
+ auto sp_shape = sp_shape_t->vec<int64>();
+ sp_shape(0) = batch_size;
+ sp_shape(1) = max_num_entries;
+ size_t c = 0;
+ for (size_t i = 0; i < batch_size; ++i) {
+ for (size_t j = 0; j < num_indices[i]; ++j) {
+ sp_indices(c, 0) = i;
+ sp_indices(c, 1) = j;
+ sp_tokens(c) = tokens[c];
+ ++c;
+ }
+ }
+ }
+
+ private:
+ int maxsplit_;
+};
+
REGISTER_KERNEL_BUILDER(Name("StringSplit").Device(DEVICE_CPU), StringSplitOp);
+REGISTER_KERNEL_BUILDER(Name("StringSplitV2").Device(DEVICE_CPU),
+ StringSplitV2Op);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 7b85ff2ea4..765467bc1e 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -81,7 +81,8 @@ TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
std::atomic<int64> TensorArray::tensor_array_counter{0};
-Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
+Status TensorArray::CopyShapesFrom(TensorArray* rhs,
+ const TensorShape* shape_to_prepend) {
mutex_lock l(mu_);
mutex_lock l_rhs(rhs->mu_);
TF_RETURN_IF_ERROR(LockedReturnIfClosed());
@@ -97,7 +98,12 @@ Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
if (!rhs->tensors_[i].written) continue;
// Copy the shape over.
- tensors_[i].shape = rhs->tensors_[i].shape;
+ if (shape_to_prepend) {
+ tensors_[i].shape = *shape_to_prepend;
+ tensors_[i].shape.AppendShape(rhs->tensors_[i].shape);
+ } else {
+ tensors_[i].shape = rhs->tensors_[i].shape;
+ }
// Mark as written. Reads will know that if written is true and
// read is false, and cleared is false, to return zeros of the
// appropriate shape. Future aggregating writes will only use the shape
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 90b71e370c..68fab85770 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -325,13 +325,15 @@ class TensorArray : public ResourceBase {
bool HasIdenticalElementShapes() const { return identical_element_shapes_; }
// Copy the TensorShapes from another TensorArray into this one.
+ // If `shapes_to_prepend` is set, expands the rank of the copied shape by
+ // prepending the passed in shape prefix to the shape values in `rhs`.
// The sizes of the two TensorArrays must match and this one
// may not have any entries filled in. This performs a "soft copy",
// essentially filling the current TensorArray with virtual
// zero-tensors, which will be replaced by future aggregate writes,
// or instantiated by future reads. Requires a non-const pointer
// to the rhs to access its mutex.
- Status CopyShapesFrom(TensorArray* rhs);
+ Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);
// Clear the TensorArray, including any Tensor references, and mark as closed.
void ClearAndMarkClosed() {
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index ef9748b1aa..5aa5d20b1a 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -264,7 +264,10 @@ REGISTER_GPU(bfloat16);
#endif // GOOGLE_CUDA
// GRADIENT *******************************************************************
-
+// Note that this op may have an optional third input. If present, it represents
+// a shape value. It indicates that element shape of this gradient array is that
+// shape value concatenated with the element shape of the original tensor array.
+// See TensorArrayGradWithShape.
class TensorArrayGradOp : public TensorArrayCreationOp {
public:
explicit TensorArrayGradOp(OpKernelConstruction* context)
@@ -325,18 +328,38 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
"previous write? Gradient calculation is impossible when multiple "
"writes are performed to the same index.");
}
+ TensorShape shape_to_prepend;
+ auto element_shape = PartialTensorShape();
+ if (ctx->num_inputs() > 2) {
+ TF_RETURN_IF_ERROR(
+ ctx->op_kernel().MakeShape(ctx->input(2), &shape_to_prepend));
+ auto ta_element_shape = tensor_array->ElemShape();
+ if (!ta_element_shape.unknown_rank()) {
+ std::vector<int64> dims;
+ for (auto dim : shape_to_prepend) {
+ dims.push_back(dim.size);
+ }
+ for (auto dim : ta_element_shape) {
+ dims.push_back(dim.size);
+ }
+ TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(
+ gtl::ArraySlice<int64>(dims), &element_shape));
+ }
+ } else {
+ element_shape = tensor_array->ElemShape();
+ }
const auto key = strings::StrCat(output_handle(0), output_handle(1));
auto creator = [this, key, tensor_array, array_size, marked_size,
- tensor_array_output_handle,
+ element_shape, shape_to_prepend, tensor_array_output_handle,
output_handle](TensorArray** ret) -> Status {
*ret = new TensorArray(
key, tensor_array->ElemType(), *tensor_array_output_handle,
- array_size, tensor_array->ElemShape(),
- tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */,
- true /* multiple_writes_aggregate */, true /* is_grad */,
- marked_size /* marked_size */, true /* close_after_read */);
- return (*ret)->CopyShapesFrom(tensor_array);
+ array_size, element_shape, tensor_array->HasIdenticalElementShapes(),
+ false /* dynamic_size */, true /* multiple_writes_aggregate */,
+ true /* is_grad */, marked_size /* marked_size */,
+ true /* close_after_read */);
+ return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend);
};
Status s = rm->LookupOrCreate<TensorArray>(
@@ -361,7 +384,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU),
TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU),
TensorArrayGradOp);
-
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape").Device(DEVICE_CPU),
+ TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad")
.Device(DEVICE_GPU)
.HostMemory("handle")
@@ -377,6 +401,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3")
.HostMemory("handle")
.HostMemory("grad_handle"),
TensorArrayGradOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradWithShape")
+ .Device(DEVICE_GPU)
+ .HostMemory("handle")
+ .HostMemory("shape_to_prepend")
+ .HostMemory("grad_handle"),
+ TensorArrayGradOp);
// WRITE **********************************************************************
@@ -705,6 +735,7 @@ class TensorArrayPackOrGatherOp : public OpKernel {
TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>);
TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK);
+TF_CALL_variant(REGISTER_GATHER_AND_PACK);
REGISTER_GATHER_AND_PACK(quint8);
REGISTER_GATHER_AND_PACK(qint8);
REGISTER_GATHER_AND_PACK(qint32);
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 7177ad7888..886b3e7492 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h
index ae67592d04..709b0a92e9 100644
--- a/tensorflow/core/kernels/transpose_op.h
+++ b/tensorflow/core/kernels/transpose_op.h
@@ -42,7 +42,7 @@ class TransposeCpuOp : public TransposeOp {
gtl::ArraySlice<int32> perm, Tensor* out) override;
};
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
class MklTransposeCpuOp : public TransposeOp {
public:
explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
@@ -85,7 +85,7 @@ class ConjugateTransposeCpuOp : public TransposeOp {
bool IsConjugate() const override { return true; }
};
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
class MklConjugateTransposeCpuOp : public TransposeOp {
public:
explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/core/kernels/unary_ops_composition.cc b/tensorflow/core/kernels/unary_ops_composition.cc
new file mode 100644
index 0000000000..0c2cb1b39f
--- /dev/null
+++ b/tensorflow/core/kernels/unary_ops_composition.cc
@@ -0,0 +1,432 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/math_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+#include "tensorflow/core/kernels/relu_op_functor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class UnaryOpsComposition; // forward declare kernel
+
+template <typename T>
+struct UnaryOpsCompositionSupport;
+
+template <typename T>
+struct UnaryOpsCompositionBase {
+ using InputBuffer = typename TTypes<T>::ConstFlat;
+ using OutputBuffer = typename TTypes<T>::Flat;
+
+ using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*);
+
+ struct ComputeFnRegistration {
+ ComputeFn compute_fn;
+ int cost;
+ };
+
+ bool HasComputeFn(const string& name) {
+ return compute_fns.find(name) != compute_fns.end();
+ }
+
+ protected:
+ void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) {
+ VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost;
+ compute_fns[name] = {compute_fn, cost};
+ }
+
+ private:
+ friend class UnaryOpsComposition<T>;
+
+ Status ExportComputeFns(const std::vector<string>& op_names,
+ std::vector<ComputeFn>* fns, int* cost) {
+ for (const string& op_name : op_names) {
+ auto it = compute_fns.find(op_name);
+ if (it == compute_fns.end())
+ return errors::InvalidArgument(
+ "Do not have a compute function registered for op: ", op_name);
+
+ const ComputeFnRegistration& reg = it->second;
+ fns->push_back(reg.compute_fn);
+ *cost += reg.cost;
+ }
+
+ return Status::OK();
+ }
+
+ std::unordered_map<string, ComputeFnRegistration> compute_fns;
+};
+
+template <typename T>
+class UnaryOpsComposition : public OpKernel {
+ public:
+ using Kernel = UnaryOpsComposition<T>;
+
+ using Scalar = T;
+ using Packet = typename Eigen::internal::packet_traits<T>::type;
+
+ using Support = UnaryOpsCompositionSupport<T>;
+
+ using InputBuffer = typename Support::InputBuffer;
+ using OutputBuffer = typename Support::OutputBuffer;
+ using ComputeFn = typename Support::ComputeFn;
+
+ explicit UnaryOpsComposition(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_));
+
+ OP_REQUIRES(context, !op_names_.empty(),
+ errors::InvalidArgument(
+ "Unary op composition must have at least one op"));
+
+ OP_REQUIRES_OK(context,
+ support_.ExportComputeFns(op_names_, &fns_, &cost_));
+
+ VLOG(2) << "Composed unary op: [" << str_util::Join(op_names_, ", ")
+ << "]; cost=" << cost_;
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& in = ctx->input(0);
+ Tensor* out = nullptr;
+ OP_REQUIRES_OK(
+ ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out));
+
+ InputBuffer in_flat = in.flat<T>();
+ OutputBuffer out_flat = out->flat<T>();
+
+ const std::size_t num_fns = fns_.size();
+ auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64 begin,
+ int64 end) {
+ int64 len = end - begin;
+ const InputBuffer in_slice(in_flat.data() + begin, len);
+ const InputBuffer scratch_slice(out_flat.data() + begin, len);
+ OutputBuffer out_slice(out_flat.data() + begin, len);
+
+ fns_[0](in_slice, &out_slice);
+ for (int i = 1; i < num_fns; ++i) {
+ fns_[i](scratch_slice, &out_slice);
+ }
+ };
+
+ const CPUDevice& device = ctx->eigen_device<CPUDevice>();
+ const int kOverheadCycles = static_cast<int>(num_fns) * 10;
+ Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns,
+ /*bytes_stored=*/sizeof(T) * num_fns,
+ kOverheadCycles + cost_);
+ device.parallelFor(in.NumElements(), cost, AlignBlockSize,
+ std::move(compute_fn));
+ }
+
+ private:
+ static const int kPacketSize = Eigen::internal::unpacket_traits<Packet>::size;
+
+ static inline int64 AlignBlockSize(int64 block_size) {
+ // Align block size to packet size and account for unrolling in run above.
+ if (block_size >= 16 * kPacketSize) {
+ return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1);
+ }
+ // Aligning to 4 * PacketSize would increase block size by more than 25%.
+ return (block_size + kPacketSize - 1) & ~(kPacketSize - 1);
+ }
+
+ Support support_;
+
+ std::vector<string> op_names_;
+ std::vector<ComputeFn> fns_;
+ int cost_ = 0;
+};
+
+// Register compute functions for UnaryOp functors.
+#define REGISTER_COMPUTE_FN_HELPER(name, functor) \
+ static_assert(std::is_same<functor::in_type, functor::out_type>::value, \
+ "Functor must have same input and output types"); \
+ \
+ static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \
+ *out = in.unaryExpr(functor::func()); \
+ } \
+ static inline int Cost##name() { \
+ return Eigen::internal::functor_traits<functor::func>::Cost; \
+ }
+
+// Register compute function for the Relu/Relu6/Elu/Selu.
+#define REGISTER_RELU_HELPER() \
+ template <typename T> \
+ using functor_traits = Eigen::internal::functor_traits<T>; \
+ \
+ static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) { \
+ auto relu = functor::Relu<Eigen::DefaultDevice, T>(); \
+ relu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostRelu() { \
+ return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost; \
+ } \
+ \
+ static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
+ auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>(); \
+ relu6(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostRelu6() { \
+ return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost + \
+ functor_traits<Eigen::internal::scalar_min_op<T>>::Cost; \
+ } \
+ static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) { \
+ auto elu = functor::Elu<Eigen::DefaultDevice, T>(); \
+ elu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostElu() { \
+ return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \
+ Eigen::NumTraits<T>::MulCost; \
+ } \
+ static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) { \
+ auto selu = functor::Selu<Eigen::DefaultDevice, T>(); \
+ selu(Eigen::DefaultDevice(), in, *out); \
+ } \
+ \
+ static inline int CostSelu() { \
+ return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost + \
+ Eigen::NumTraits<T>::MulCost); \
+ }
+
+#define REGISTER_COMPUTE_FN(func) \
+ RegisterComputeFn(#func, Compute##func, Cost##func());
+
+template <>
+struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
+ using T = float;
+
+ UnaryOpsCompositionSupport() {
+ // UnaryOp functors.
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Acos);
+ REGISTER_COMPUTE_FN(Acosh);
+ REGISTER_COMPUTE_FN(Asin);
+ REGISTER_COMPUTE_FN(Asinh);
+ REGISTER_COMPUTE_FN(Atan);
+ REGISTER_COMPUTE_FN(Atanh);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Cosh);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Rint);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sinh);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tan);
+ REGISTER_COMPUTE_FN(Tanh);
+
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+template <>
+struct UnaryOpsCompositionSupport<Eigen::half>
+ : UnaryOpsCompositionBase<Eigen::half> {
+ using T = Eigen::half;
+
+ UnaryOpsCompositionSupport() {
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tanh);
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+template <>
+struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
+ using T = double;
+
+ UnaryOpsCompositionSupport() {
+ REGISTER_COMPUTE_FN(Abs);
+ REGISTER_COMPUTE_FN(Acos);
+ REGISTER_COMPUTE_FN(Acosh);
+ REGISTER_COMPUTE_FN(Asin);
+ REGISTER_COMPUTE_FN(Asinh);
+ REGISTER_COMPUTE_FN(Atan);
+ REGISTER_COMPUTE_FN(Atanh);
+ REGISTER_COMPUTE_FN(Ceil);
+ REGISTER_COMPUTE_FN(Cos);
+ REGISTER_COMPUTE_FN(Cosh);
+ REGISTER_COMPUTE_FN(Expm1);
+ REGISTER_COMPUTE_FN(Exp);
+ REGISTER_COMPUTE_FN(Floor);
+ REGISTER_COMPUTE_FN(Inv);
+ REGISTER_COMPUTE_FN(Log);
+ REGISTER_COMPUTE_FN(Log1p);
+ REGISTER_COMPUTE_FN(Neg);
+ REGISTER_COMPUTE_FN(Reciprocal);
+ REGISTER_COMPUTE_FN(Rint);
+ REGISTER_COMPUTE_FN(Round);
+ REGISTER_COMPUTE_FN(Rsqrt);
+ REGISTER_COMPUTE_FN(Sigmoid);
+ REGISTER_COMPUTE_FN(Sin);
+ REGISTER_COMPUTE_FN(Sinh);
+ REGISTER_COMPUTE_FN(Sqrt);
+ REGISTER_COMPUTE_FN(Square);
+ REGISTER_COMPUTE_FN(Tan);
+ REGISTER_COMPUTE_FN(Tanh);
+ // Additional compute functions not defined via UnaryOp functors.
+ REGISTER_COMPUTE_FN(Elu);
+ REGISTER_COMPUTE_FN(Relu);
+ REGISTER_COMPUTE_FN(Relu6);
+ REGISTER_COMPUTE_FN(Selu);
+ }
+
+ REGISTER_RELU_HELPER();
+
+ // clang-format off
+ REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
+ REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
+ REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
+ REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
+ REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
+ REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
+ REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
+ REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
+ REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
+ REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
+ REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
+ REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
+ REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
+ // clang-format on
+};
+
+// Register the CPU kernels.
+#define REGISTER_CPU(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ UnaryOpsComposition<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(Eigen::half);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/unary_ops_composition_test.cc b/tensorflow/core/kernels/unary_ops_composition_test.cc
new file mode 100644
index 0000000000..4be3555609
--- /dev/null
+++ b/tensorflow/core/kernels/unary_ops_composition_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+class UnaryOpsCompositionTest : public OpsTestBase {
+ protected:
+ template <typename T>
+ void RunComposedOp(const std::vector<string> op_names, T input, T expected) {
+ TF_ASSERT_OK(NodeDefBuilder("unary_op_composition", "_UnaryOpsComposition")
+ .Input(FakeInput(DataTypeToEnum<T>::v()))
+ .Attr("T", DataTypeToEnum<T>::v())
+ .Attr("op_names", op_names)
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+
+ TensorShape shape({});
+ AddInputFromArray<T>(shape, {input});
+
+ TF_ASSERT_OK(RunOpKernel());
+
+ Tensor expected_tensor(allocator(), DataTypeToEnum<T>::value, shape);
+ test::FillValues<T>(&expected_tensor, {expected});
+ test::ExpectClose(expected_tensor, *GetOutput(0));
+ }
+};
+
+TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {
+ RunComposedOp<float>({"Sqrt", "Sqrt"}, 81.0, 3.0);
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_D) {
+ RunComposedOp<double>({"Sqrt", "Sqrt"}, 81.0, 3.0);
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sin_F) {
+ RunComposedOp<float>({"Sqrt", "Sin"}, 81.0, std::sin(9.0f));
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Cos_Acos_F) {
+ RunComposedOp<float>({"Cos", "Acos"}, 0.5, std::acos(std::cos(0.5f)));
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu_F) {
+ RunComposedOp<float>({"Tanh", "Relu"}, 0.5, std::max(0.0f, std::tanh(0.5f)));
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu_D) {
+ RunComposedOp<double>({"Tanh", "Relu"}, 0.5, std::max(0.0, std::tanh(0.5)));
+}
+
+TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu6_F) {
+ RunComposedOp<float>({"Relu6"}, 11.0f, 6.0f);
+}
+
+// Performance benchmarks below.
+
+string Function(int i) {
+ std::vector<string> ops = {"Tanh", "Relu", "Sigmoid", "Sqrt", "Log", "Exp"};
+ return ops[i % ops.size()];
+}
+
+// Unary ops chained together as a separate graph nodes.
+static Graph* UnaryOpsChain(int tensor_size, int repeat_graph,
+ int num_functions) {
+ Graph* g = new Graph(OpRegistry::Global());
+
+ Tensor t(DT_FLOAT, TensorShape({tensor_size}));
+ t.flat<float>() = t.flat<float>().setRandom();
+
+ for (int i = 0; i < repeat_graph; ++i) {
+ Node* node = test::graph::Constant(g, t);
+ for (int j = 0; j < num_functions; ++j) {
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), Function(j))
+ .Input(node)
+ .Attr("T", DT_FLOAT)
+ .Finalize(g, &node));
+ }
+ }
+
+ return g;
+}
+
+#define BM_UnaryOpsChain(N, R, F, type) \
+ static void BM_UnaryOpsChain##_##type##_##N##_##R##_##F(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * N * R * F); \
+ test::Benchmark(#type, UnaryOpsChain(N, R, F)).Run(iters); \
+ } \
+ BENCHMARK(BM_UnaryOpsChain##_##type##_##N##_##R##_##F);
+
+// Unary ops fused together.
+static Graph* UnaryOpsCompo(int tensor_size, int repeat_graph,
+ int num_functions) {
+ Graph* g = new Graph(OpRegistry::Global());
+
+ Tensor t(DT_FLOAT, TensorShape({tensor_size}));
+ t.flat<float>() = t.flat<float>().setRandom();
+
+ std::vector<string> functions;
+ for (int j = 0; j < num_functions; ++j) {
+ functions.push_back(Function(j));
+ }
+
+ for (int i = 0; i < repeat_graph; ++i) {
+ Node* node = test::graph::Constant(g, t);
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_UnaryOpsComposition")
+ .Input(node)
+ .Attr("T", DT_FLOAT)
+ .Attr("op_names", functions)
+ .Finalize(g, &node));
+ }
+
+ return g;
+}
+
+#define BM_UnaryOpsCompo(N, R, F, type) \
+ static void BM_UnaryOpsCompo##_##type##_##N##_##R##_##F(int iters) { \
+ testing::ItemsProcessed(static_cast<int64>(iters) * N * R * F); \
+ test::Benchmark(#type, UnaryOpsCompo(N, R, F)).Run(iters); \
+ } \
+ BENCHMARK(BM_UnaryOpsCompo##_##type##_##N##_##R##_##F);
+
+// BenchmarkName(tensor_size, repeat_graph, num_ops, type)
+
+BM_UnaryOpsChain(1000, 25, 2, cpu);
+BM_UnaryOpsCompo(1000, 25, 2, cpu);
+
+BM_UnaryOpsChain(1000, 25, 5, cpu);
+BM_UnaryOpsCompo(1000, 25, 5, cpu);
+
+BM_UnaryOpsChain(1000, 25, 10, cpu);
+BM_UnaryOpsCompo(1000, 25, 10, cpu);
+
+BM_UnaryOpsChain(100000, 25, 2, cpu);
+BM_UnaryOpsCompo(100000, 25, 2, cpu);
+
+BM_UnaryOpsChain(100000, 25, 5, cpu);
+BM_UnaryOpsCompo(100000, 25, 5, cpu);
+
+BM_UnaryOpsChain(100000, 25, 10, cpu);
+BM_UnaryOpsCompo(100000, 25, 10, cpu);
+
+BM_UnaryOpsChain(1000000, 25, 2, cpu);
+BM_UnaryOpsCompo(1000000, 25, 2, cpu);
+
+BM_UnaryOpsChain(1000000, 25, 5, cpu);
+BM_UnaryOpsCompo(1000000, 25, 5, cpu);
+
+BM_UnaryOpsChain(1000000, 25, 10, cpu);
+BM_UnaryOpsCompo(1000000, 25, 10, cpu);
+
+} // namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc
index 7fd5809ca4..eadea18f76 100644
--- a/tensorflow/core/kernels/variable_ops.cc
+++ b/tensorflow/core/kernels/variable_ops.cc
@@ -73,9 +73,6 @@ void VariableOp::Compute(OpKernelContext* ctx) {
// here is valid because it owns a ref on var.
ctx->set_output_ref(0, var->mu(), var->tensor());
if (ctx->track_allocations() && var->tensor()->IsInitialized()) {
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
ctx->record_persistent_memory_allocation(var->tensor()->AllocatedBytes());
}
var->Unref();
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 2c0576ff10..1c130ba300 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -354,6 +354,18 @@ struct bfloat16 {
return x;
}
+ static bfloat16 highest() {
+ bfloat16 x;
+ x.value = 0x7F7F; // 0x1.FEp127
+ return x;
+ }
+
+ static bfloat16 lowest() {
+ bfloat16 x;
+ x.value = 0xFF7F; // -0x1.FEp127
+ return x;
+ }
+
uint16_t value;
// A value that represents "not a number".
diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc
index 1e88323d01..1590055960 100644
--- a/tensorflow/core/lib/db/sqlite_test.cc
+++ b/tensorflow/core/lib/db/sqlite_test.cc
@@ -73,6 +73,21 @@ TEST_F(SqliteTest, InsertAndSelectDouble) {
EXPECT_EQ(1, stmt.ColumnInt(1));
}
+#ifdef DSQLITE_ENABLE_JSON1
+TEST_F(SqliteTest, Json1Extension) {
+ string s1 = "{\"key\": 42}";
+ string s2 = "{\"key\": \"value\"}";
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+ stmt.BindText(1, s1);
+ stmt.BindText(2, s2);
+ TF_ASSERT_OK(stmt.StepAndReset());
+ stmt = db_->PrepareOrDie("SELECT json_extract(a, '$.key'), json_extract(b, '$.key') FROM T");
+ TF_ASSERT_OK(stmt.Step(&is_done_));
+ EXPECT_EQ(42, stmt.ColumnInt(0));
+ EXPECT_EQ("value", stmt.ColumnString(1));
+}
+#endif //DSQLITE_ENABLE_JSON1
+
TEST_F(SqliteTest, NulCharsInString) {
string s; // XXX: Want to write {2, '\0'} but not sure why not.
s.append(static_cast<size_t>(2), '\0');
diff --git a/tensorflow/core/lib/gtl/manual_constructor_test.cc b/tensorflow/core/lib/gtl/manual_constructor_test.cc
index 4e832ce8d8..35cbc78b66 100644
--- a/tensorflow/core/lib/gtl/manual_constructor_test.cc
+++ b/tensorflow/core/lib/gtl/manual_constructor_test.cc
@@ -95,9 +95,6 @@ TEST(ManualConstructorTest, Alignment) {
#ifdef ARCH_K8
EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 16, 0);
#endif
-#ifdef ARCH_PIII
- EXPECT_EQ(reinterpret_cast<intptr_t>(test2.b.get()) % 4, 0);
-#endif
}
TEST(ManualConstructorTest, DefaultInitialize) {
diff --git a/tensorflow/core/lib/io/random_inputstream.cc b/tensorflow/core/lib/io/random_inputstream.cc
index 09336e79cd..e85367df9c 100644
--- a/tensorflow/core/lib/io/random_inputstream.cc
+++ b/tensorflow/core/lib/io/random_inputstream.cc
@@ -45,16 +45,8 @@ Status RandomAccessInputStream::ReadNBytes(int64 bytes_to_read,
result->resize(data.size());
if (s.ok() || errors::IsOutOfRange(s)) {
pos_ += data.size();
- } else {
- return s;
}
- // If the amount of data we read is less than what we wanted, we return an
- // out of range error. We need to catch this explicitly since file_->Read()
- // would not do so if at least 1 byte is read (b/30839063).
- if (data.size() < bytes_to_read) {
- return errors::OutOfRange("reached end of file");
- }
- return Status::OK();
+ return s;
}
// To limit memory usage, the default implementation of SkipNBytes() only reads
diff --git a/tensorflow/core/lib/strings/numbers.cc b/tensorflow/core/lib/strings/numbers.cc
index 987e4fe733..87aa5915ff 100644
--- a/tensorflow/core/lib/strings/numbers.cc
+++ b/tensorflow/core/lib/strings/numbers.cc
@@ -331,31 +331,29 @@ bool safe_strtou32(StringPiece str, uint32* value) {
return true;
}
-bool safe_strtof(const char* str, float* value) {
+bool safe_strtof(StringPiece str, float* value) {
int processed_characters_count = -1;
- auto len = str_util::Strnlen(str, kFastToBufferSize);
+ auto len = str.size();
- // If there is no zero-termination in str, fail.
- if (len == kFastToBufferSize) return false;
- // If string length exceeds int max, fail.
+ // If string length exceeds buffer size or int max, fail.
+ if (len >= kFastToBufferSize) return false;
if (len > std::numeric_limits<int>::max()) return false;
- *value = StringToFloatConverter().StringToFloat(str, static_cast<int>(len),
- &processed_characters_count);
+ *value = StringToFloatConverter().StringToFloat(
+ str.data(), static_cast<int>(len), &processed_characters_count);
return processed_characters_count > 0;
}
-bool safe_strtod(const char* str, double* value) {
+bool safe_strtod(StringPiece str, double* value) {
int processed_characters_count = -1;
- auto len = str_util::Strnlen(str, kFastToBufferSize);
+ auto len = str.size();
- // If there is no zero-termination in str, fail.
- if (len == kFastToBufferSize) return false;
- // If string length exceeds int max, fail.
+ // If string length exceeds buffer size or int max, fail.
+ if (len >= kFastToBufferSize) return false;
if (len > std::numeric_limits<int>::max()) return false;
- *value = StringToFloatConverter().StringToDouble(str, static_cast<int>(len),
- &processed_characters_count);
+ *value = StringToFloatConverter().StringToDouble(
+ str.data(), static_cast<int>(len), &processed_characters_count);
return processed_characters_count > 0;
}
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 9cb56415cb..1d5bacac93 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -115,13 +115,13 @@ bool safe_strtou64(StringPiece str, uint64* value);
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
-bool safe_strtof(const char* str, float* value);
+bool safe_strtof(StringPiece str, float* value);
// Convert strings to double precision floating point values.
// Leading and trailing spaces are allowed.
// Values may be rounded on over- and underflow.
// Returns false on invalid input or if `strlen(value) >= kFastToBufferSize`.
-bool safe_strtod(const char* str, double* value);
+bool safe_strtod(StringPiece str, double* value);
inline bool ProtoParseNumeric(StringPiece s, int32* value) {
return safe_strto32(s, value);
diff --git a/tensorflow/core/lib/strings/numbers_test.cc b/tensorflow/core/lib/strings/numbers_test.cc
index 0f22dac262..5b595f9847 100644
--- a/tensorflow/core/lib/strings/numbers_test.cc
+++ b/tensorflow/core/lib/strings/numbers_test.cc
@@ -289,12 +289,9 @@ TEST(safe_strtof, Float) {
EXPECT_FALSE(safe_strtof("-infinity is awesome", &result));
- // Make sure we exit cleanly if the string is not terminated
+ // Make sure we exit cleanly if the string is too long
char test_str[2 * kFastToBufferSize];
for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
- EXPECT_FALSE(safe_strtof(test_str, &result));
-
- // Make sure we exit cleanly if the string is too long
test_str[kFastToBufferSize + 1] = '\0';
EXPECT_FALSE(safe_strtof(test_str, &result));
@@ -330,12 +327,9 @@ TEST(safe_strtod, Double) {
EXPECT_EQ(0.1234567890123, result);
EXPECT_FALSE(safe_strtod("0.1234567890123abc", &result));
- // Make sure we exit cleanly if the string is not terminated
+ // Make sure we exit cleanly if the string is too long
char test_str[2 * kFastToBufferSize];
for (int i = 0; i < 2 * kFastToBufferSize; ++i) test_str[i] = 'a';
- EXPECT_FALSE(safe_strtod(test_str, &result));
-
- // Make sure we exit cleanly if the string is too long
test_str[kFastToBufferSize + 1] = '\0';
EXPECT_FALSE(safe_strtod(test_str, &result));
diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc
index 0a62965eed..ba7faeb5e8 100644
--- a/tensorflow/core/ops/batch_ops.cc
+++ b/tensorflow/core/ops/batch_ops.cc
@@ -19,6 +19,26 @@ limitations under the License.
namespace tensorflow {
+REGISTER_OP("BatchFunction")
+ .Input("in_tensors: Tin")
+ .Input("captured_tensors: Tcaptured")
+ .Output("out_tensors: Tout")
+ .Attr("f: func")
+ .Attr("num_batch_threads: int")
+ .Attr("max_batch_size: int")
+ .Attr("batch_timeout_micros: int")
+ .Attr("max_enqueued_batches: int = 10")
+ .Attr("allowed_batch_sizes: list(int) = []")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''")
+ .Attr("batching_queue: string = ''")
+ .Attr("Tin: list(type)")
+ .Attr("Tcaptured: list(type) >= 0")
+ .Attr("Tout: list(type)")
+ // TODO(apassos): Fix this shape inference function. It requires shape
+ // inference of function calls.
+ .SetShapeFn(shape_inference::UnknownShape);
+
REGISTER_OP("Batch")
.Input("in_tensors: T")
.Output("batched_tensors: T")
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 88d6eaf819..01452b3e85 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -203,6 +203,30 @@ REGISTER_OP("BoostedTreesPredict")
return Status::OK();
});
+REGISTER_OP("BoostedTreesExampleDebugOutputs")
+ .Input("tree_ensemble_handle: resource")
+ .Input("bucketized_features: num_bucketized_features * int32")
+ .Attr("num_bucketized_features: int >= 1") // Inferred.
+ .Attr("logits_dimension: int")
+ .Output("examples_debug_outputs_serialized: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle feature_shape;
+ int num_bucketized_features;
+ TF_RETURN_IF_ERROR(
+ c->GetAttr("num_bucketized_features", &num_bucketized_features));
+ shape_inference::ShapeHandle unused_input;
+ for (int i = 0; i < num_bucketized_features; ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 1), 1, &feature_shape));
+ // Check that the shapes of all bucketized features are the same.
+ TF_RETURN_IF_ERROR(c->Merge(c->input(1), feature_shape, &unused_input));
+ }
+
+ // Multi-class will be supported by modifying the proto.
+ auto batch_size = c->MakeShape({c->Dim(feature_shape, 0)});
+ c->set_output(0, batch_size);
+ return Status::OK();
+ });
+
REGISTER_OP("BoostedTreesSerializeEnsemble")
.Input("tree_ensemble_handle: resource")
.Output("stamp_token: int64")
@@ -307,4 +331,27 @@ REGISTER_OP("BoostedTreesUpdateEnsemble")
return Status::OK();
});
+REGISTER_OP("BoostedTreesCenterBias")
+ .Input("tree_ensemble_handle: resource")
+ .Input("mean_gradients: float")
+ .Input("mean_hessians: float")
+ // Regularization-related.
+ .Input("l1: float")
+ .Input("l2: float")
+ .Output("continue_centering: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle gradients_shape;
+ shape_inference::ShapeHandle hessians_shape;
+ shape_inference::ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(gradients_shape, hessians_shape, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 1920d0a592..6cdd03e6a0 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -6426,6 +6426,68 @@ op {
}
}
op {
+ name: "AsString"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_BOOL
+ }
+ }
+ }
+ attr {
+ name: "precision"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+ attr {
+ name: "scientific"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "shortest"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "width"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+ attr {
+ name: "fill"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+}
+op {
name: "Asin"
input_arg {
name: "x"
@@ -7619,6 +7681,66 @@ op {
}
}
op {
+ name: "AvgPool"
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "AvgPool3D"
input_arg {
name: "input"
@@ -8308,6 +8430,70 @@ op {
}
}
op {
+ name: "AvgPoolGrad"
+ input_arg {
+ name: "orig_input_shape"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Barrier"
output_arg {
name: "handle"
@@ -8721,6 +8907,37 @@ op {
}
}
op {
+ name: "BatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchFFT"
input_arg {
name: "input"
@@ -8763,6 +8980,90 @@ op {
}
}
op {
+ name: "BatchFunction"
+ input_arg {
+ name: "in_tensors"
+ type_list_attr: "Tin"
+ }
+ input_arg {
+ name: "captured_tensors"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "out_tensors"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "num_batch_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_batch_size"
+ type: "int"
+ }
+ attr {
+ name: "batch_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "max_enqueued_batches"
+ type: "int"
+ default_value {
+ i: 10
+ }
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "batching_queue"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchIFFT"
input_arg {
name: "input"
@@ -9971,6 +10272,52 @@ op {
}
}
op {
+ name: "BesselI0e"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "BesselI1e"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Betainc"
input_arg {
name: "a"
@@ -10208,6 +10555,61 @@ op {
}
}
op {
+ name: "BiasAdd"
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "bias"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+}
+op {
name: "BiasAddGrad"
input_arg {
name: "out_backprop"
@@ -10400,6 +10802,57 @@ op {
}
}
op {
+ name: "BiasAddGrad"
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+}
+op {
name: "BiasAddV1"
input_arg {
name: "value"
@@ -11074,6 +11527,34 @@ op {
}
}
op {
+ name: "BoostedTreesCenterBias"
+ input_arg {
+ name: "tree_ensemble_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "mean_gradients"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "mean_hessians"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "continue_centering"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesCreateEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -11128,6 +11609,33 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesExampleDebugOutputs"
+ input_arg {
+ name: "tree_ensemble_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "bucketized_features"
+ type: DT_INT32
+ number_attr: "num_bucketized_features"
+ }
+ output_arg {
+ name: "examples_debug_outputs_serialized"
+ type: DT_STRING
+ }
+ attr {
+ name: "num_bucketized_features"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "logits_dimension"
+ type: "int"
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesGetEnsembleStates"
input_arg {
name: "tree_ensemble_handle"
@@ -12949,6 +13457,81 @@ op {
}
}
op {
+ name: "Conv2D"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "use_cudnn_on_gpu"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+}
+op {
name: "Conv2DBackpropFilter"
input_arg {
name: "input"
@@ -13165,6 +13748,148 @@ op {
}
}
op {
+ name: "Conv2DBackpropFilter"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "filter_sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "use_cudnn_on_gpu"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+}
+op {
+ name: "Conv2DBackpropInput"
+ input_arg {
+ name: "input_sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "use_cudnn_on_gpu"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+}
+op {
name: "Conv2DBackpropInput"
input_arg {
name: "input_sizes"
@@ -13188,6 +13913,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
}
}
@@ -13226,6 +13952,18 @@ op {
}
}
}
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
}
op {
name: "Conv2DBackpropInput"
@@ -13253,6 +13991,7 @@ op {
type: DT_HALF
type: DT_BFLOAT16
type: DT_FLOAT
+ type: DT_DOUBLE
}
}
}
@@ -13364,6 +14103,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -16583,6 +17324,17 @@ op {
}
}
op {
+ name: "DatasetToGraph"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "graph"
+ type: DT_STRING
+ }
+}
+op {
name: "DatasetToSingleElement"
input_arg {
name: "dataset"
@@ -18100,6 +18852,117 @@ op {
}
}
op {
+ name: "DepthwiseConv2dNative"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+}
+op {
+ name: "DepthwiseConv2dNativeBackpropFilter"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "filter_sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "DepthwiseConv2dNativeBackpropFilter"
input_arg {
name: "input"
@@ -18141,6 +19004,19 @@ op {
}
}
}
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
}
op {
name: "DepthwiseConv2dNativeBackpropFilter"
@@ -18165,6 +19041,7 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -18197,6 +19074,18 @@ op {
}
}
}
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
}
op {
name: "DepthwiseConv2dNativeBackpropFilter"
@@ -18221,6 +19110,7 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
@@ -18321,6 +19211,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -18576,6 +19468,78 @@ op {
}
}
op {
+ name: "DepthwiseConv2dNativeBackpropInput"
+ input_arg {
+ name: "input_sizes"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "out_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "dilations"
+ type: "list(int)"
+ default_value {
+ list {
+ i: 1
+ i: 1
+ i: 1
+ i: 1
+ }
+ }
+ }
+}
+op {
name: "Dequantize"
input_arg {
name: "input"
@@ -21402,6 +22366,21 @@ op {
}
}
op {
+ name: "FakeParam"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+}
+op {
name: "FakeQuantWithMinMaxArgs"
input_arg {
name: "inputs"
@@ -22003,6 +22982,33 @@ op {
is_stateful: true
}
op {
+ name: "FeatureStatsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Fill"
input_arg {
name: "dims"
@@ -23772,6 +24778,60 @@ op {
}
}
op {
+ name: "FusedPadConv2D"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "paddings"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "mode"
+ type: "string"
+ allowed_values {
+ list {
+ s: "REFLECT"
+ s: "SYMMETRIC"
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FusedResizeAndPadConv2D"
input_arg {
name: "input"
@@ -23835,6 +24895,71 @@ op {
}
}
op {
+ name: "FusedResizeAndPadConv2D"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "paddings"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "filter"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "resize_align_corners"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "mode"
+ type: "string"
+ allowed_values {
+ list {
+ s: "REFLECT"
+ s: "SYMMETRIC"
+ }
+ }
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "Gather"
input_arg {
name: "params"
@@ -25278,6 +26403,81 @@ op {
}
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+}
+op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -25303,6 +26503,31 @@ op {
}
}
op {
+ name: "IgammaGradA"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Igammac"
input_arg {
name: "a"
@@ -26621,6 +27846,36 @@ op {
is_stateful: true
}
op {
+ name: "IteratorFromStringHandleV2"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNext"
input_arg {
name: "iterator"
@@ -26681,6 +27936,34 @@ op {
is_stateful: true
}
op {
+ name: "IteratorV2"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "L2Loss"
input_arg {
name: "t"
@@ -30506,6 +31789,80 @@ op {
}
}
op {
+ name: "MaxPool3DGradGrad"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NDHWC"
+ }
+ allowed_values {
+ list {
+ s: "NDHWC"
+ s: "NCDHW"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGrad"
input_arg {
name: "orig_input"
@@ -30877,6 +32234,85 @@ op {
}
}
op {
+ name: "MaxPoolGrad"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGradGrad"
input_arg {
name: "orig_input"
@@ -31169,6 +32605,82 @@ op {
}
}
op {
+ name: "MaxPoolGradGrad"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGradGradV2"
input_arg {
name: "orig_input"
@@ -31445,6 +32957,78 @@ op {
}
}
op {
+ name: "MaxPoolGradGradV2"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "ksize"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "strides"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGradGradWithArgmax"
input_arg {
name: "input"
@@ -32013,6 +33597,81 @@ op {
}
}
op {
+ name: "MaxPoolGradV2"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "ksize"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "strides"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGradWithArgmax"
input_arg {
name: "input"
@@ -34697,6 +36356,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionWithOverlaps"
+ input_arg {
+ name: "overlaps"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "overlap_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
@@ -35115,6 +36801,33 @@ op {
}
}
op {
+ name: "OptimizeDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "optimizations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "OrderedMapClear"
attr {
name: "capacity"
@@ -35631,6 +37344,52 @@ op {
}
}
op {
+ name: "PaddedBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PaddingFIFOQueue"
output_arg {
name: "handle"
@@ -40906,6 +42665,31 @@ op {
is_stateful: true
}
op {
+ name: "RandomGammaGrad"
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "sample"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "RandomPoisson"
input_arg {
name: "shape"
@@ -47974,6 +49758,43 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterNdAdd"
+ input_arg {
+ name: "ref"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -57126,6 +58947,17 @@ op {
}
}
op {
+ name: "SinkDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+}
+op {
name: "Size"
input_arg {
name: "input"
@@ -65145,6 +66977,54 @@ op {
}
}
op {
+ name: "SparseSliceGrad"
+ input_arg {
+ name: "backprop_val_grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_indices"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "input_start"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "output_indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "val_grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "SparseSoftmax"
input_arg {
name: "sp_indices"
@@ -67046,6 +68926,32 @@ op {
is_stateful: true
}
op {
+ name: "StatefulPartitionedCall"
+ input_arg {
+ name: "args"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ is_stateful: true
+}
+op {
name: "StatelessMultinomial"
input_arg {
name: "logits"
@@ -67755,6 +69661,36 @@ op {
}
}
op {
+ name: "StringSplitV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "sep"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "values"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "shape"
+ type: DT_INT64
+ }
+ attr {
+ name: "maxsplit"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+}
+op {
name: "StringStrip"
input_arg {
name: "input"
@@ -69298,6 +71234,34 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayGradWithShape"
+ input_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "shape_to_prepend"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "grad_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "source"
+ type: "string"
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayPack"
input_arg {
name: "handle"
@@ -72258,6 +74222,73 @@ op {
}
}
op {
+ name: "UnsortedSegmentProd"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type_attr: "Tnumsegments"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tnumsegments"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "UnsortedSegmentSum"
input_arg {
name: "data"
@@ -72897,6 +74928,33 @@ op {
is_stateful: true
}
op {
+ name: "WindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "window_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "WriteAudioSummary"
input_arg {
name: "writer"
diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc
index 81e9fcfa95..b8028291b4 100644
--- a/tensorflow/core/ops/control_flow_ops.cc
+++ b/tensorflow/core/ops/control_flow_ops.cc
@@ -145,13 +145,12 @@ REGISTER_OP("Enter")
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data != nullptr) {
c->set_output_handle_shapes_and_types(0, *handle_data);
- } else {
- // Otherwise, propagate shape if output is a constant.
- bool is_constant;
- TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
- if (is_constant) {
- c->set_output(0, c->input(0));
- }
+ }
+ // Propagate shape if output is a constant.
+ bool is_constant;
+ TF_RETURN_IF_ERROR(c->GetAttr("is_constant", &is_constant));
+ if (is_constant) {
+ c->set_output(0, c->input(0));
}
return Status::OK();
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 3112f35da4..eed0bce174 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -608,6 +608,50 @@ REGISTER_OP("TensorArrayGradV3")
return Status::OK();
});
+REGISTER_OP("TensorArrayGradWithShape")
+ .Input("handle: resource")
+ .Input("flow_in: float")
+ .Input("shape_to_prepend: int32")
+ .Output("grad_handle: resource")
+ .Output("flow_out: float")
+ .Attr("source: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
+ c->set_output(0, c->Vector(2));
+ c->set_output(1, c->Scalar());
+ auto* shape_and_type = c->input_handle_shapes_and_types(0);
+ if (shape_and_type) {
+ auto input_shape = (*shape_and_type)[0].shape;
+ auto dtype = (*shape_and_type)[0].dtype;
+ // Note that shape_to_preped is a rank 1 Tensor representing a shape.
+ // The size of dimension 0 is the number of dimensions we need to add to
+ // output shape.
+ int64 prepend_rank = c->Value(c->Dim(c->input(2), 0));
+ if (c->RankKnown(input_shape) &&
+ prepend_rank != InferenceContext::kUnknownDim) {
+ int32 input_rank = c->Rank(input_shape);
+ std::vector<DimensionHandle> dims;
+ dims.reserve(prepend_rank + input_rank);
+ for (int i = 0; i < prepend_rank; ++i) {
+ dims.push_back(c->UnknownDim());
+ }
+ for (int i = 0; i < input_rank; ++i) {
+ dims.push_back(c->Dim(input_shape, i));
+ }
+ c->set_output_handle_shapes_and_types(0,
+ {{c->MakeShape(dims), dtype}});
+ } else {
+ c->set_output_handle_shapes_and_types(0,
+ {{c->UnknownShape(), dtype}});
+ }
+ }
+ return Status::OK();
+ });
+
REGISTER_OP("TensorArrayWriteV3")
.Input("handle: resource")
.Input("index: int32")
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 046049b678..c8bc11155a 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -166,6 +166,18 @@ REGISTER_OP("LatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("FeatureStatsDataset")
+ .Input("input_dataset: variant")
+ .Input("tag: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle tag_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
@@ -350,6 +362,19 @@ REGISTER_OP("FilterDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("WindowDataset")
+ .Input("input_dataset: variant")
+ .Input("window_size: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("BatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
@@ -363,6 +388,22 @@ REGISTER_OP("BatchDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("BatchDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // drop_remainder should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
// TODO(mrry): move SlideDataset to contrib in the future.
REGISTER_OP("SlideDataset")
.Input("input_dataset: variant")
@@ -379,6 +420,10 @@ REGISTER_OP("SlideDataset")
return shape_inference::ScalarShape(c);
});
+// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
+// `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
+// possible to tell statically) compatible with `padded_shapes`, and that
+// `padding_values` are all scalars.
REGISTER_OP("PaddedBatchDataset")
.Input("input_dataset: variant")
.Input("batch_size: int64")
@@ -388,17 +433,32 @@ REGISTER_OP("PaddedBatchDataset")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that
- // `padded_shapes` are all
- // vectors, the lengths of
- // `output_types` and
- // `output_shapes` are `N`,
- // the `output_shapes` are (as
- // far as possible to tell
- // statically) compatible with
- // `padded_shapes`, and
- // that `padding_values` are
- // all scalars.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("PaddedBatchDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("padded_shapes: N * int64")
+ .Input("padding_values: Toutput_types")
+ .Input("drop_remainder: bool")
+ .Output("handle: variant")
+ .Attr("Toutput_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // batch_size should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // drop_remainder should be a scalar.
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
REGISTER_OP("DenseToSparseBatchDataset")
.Input("input_dataset: variant")
@@ -584,6 +644,14 @@ REGISTER_OP("Iterator")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("IteratorV2")
+ .Output("handle: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("AnonymousIterator")
.Output("handle: resource")
.Attr("output_types: list(type) >= 1")
@@ -661,6 +729,13 @@ REGISTER_OP("IteratorFromStringHandle")
.Attr("output_shapes: list(shape) >= 0 = []")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("IteratorFromStringHandleV2")
+ .Input("string_handle: string")
+ .Output("resource_handle: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("SerializeIterator")
.Input("resource_handle: resource")
.Output("serialized: variant")
@@ -718,4 +793,22 @@ REGISTER_OP("DatasetToTFRecord")
.Input("compression_type: string")
.SetShapeFn(shape_inference::NoOutputs);
+REGISTER_OP("DatasetToGraph")
+ .Input("input_dataset: variant")
+ .Output("graph: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("SinkDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("OptimizeDataset")
+ .Input("input_dataset: variant")
+ .Input("optimizations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 4d4a370478..5f262db2ce 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -31,11 +31,23 @@ REGISTER_OP("SymbolicGradient")
if (c->num_inputs() < c->num_outputs()) {
return errors::InvalidArgument("len(inputs) < len(outputs)");
}
+ std::vector<DataType> types;
+ TF_RETURN_IF_ERROR(c->GetAttr("Tin", &types));
// Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
// (x, y, z, du, dv) -> (dx, dy, dz). Therefore, shapes of its
// outputs (dx, dy, dz) are the same as (x, y, z).
for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ if (types[i] == DT_RESOURCE) {
+ const std::vector<shape_inference::ShapeAndType>* handle_type =
+ c->input_handle_shapes_and_types(i);
+ if (handle_type != nullptr) {
+ c->set_output(i, handle_type->at(0).shape);
+ } else {
+ c->set_output(i, c->UnknownShape());
+ }
+ } else {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});
@@ -82,8 +94,8 @@ REGISTER_OP("If")
.Input("input: Tin")
.Output("output: Tout")
.Attr("Tcond: type")
- .Attr("Tin: list(type)")
- .Attr("Tout: list(type)")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
.SetShapeFn(shape_inference::UnknownShape);
@@ -145,7 +157,6 @@ REGISTER_OP("For")
.Attr("body: func")
.SetShapeFn(shape_inference::UnknownShape);
-// TODO(b/73826847, b/37549631) Mark as stateful.
REGISTER_OP("PartitionedCall")
.Input("args: Tin")
.Output("output: Tout")
@@ -154,4 +165,30 @@ REGISTER_OP("PartitionedCall")
.Attr("f: func")
.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("StatefulPartitionedCall")
+ .Input("args: Tin")
+ .Output("output: Tout")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
+ .Attr("f: func")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::UnknownShape);
+
+// This op is used as a placeholder in If branch functions. It doesn't provide a
+// valid output when run, so must either be removed (e.g. replaced with a
+// function input) or guaranteed not to be used (e.g. if mirroring an
+// intermediate output needed for the gradient computation of the other branch).
+REGISTER_OP("FakeParam")
+ .Output("output: dtype")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index d949e70c66..50ced1ff73 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -454,7 +454,9 @@ REGISTER_OP("DrawBoundingBoxes")
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
- return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
+ // The rank of the input image (rank = 4) has already been restricted
+ // above, and the output is of the same shape as the input.
+ return shape_inference::UnchangedShape(c);
});
// --------------------------------------------------------------------------
@@ -707,4 +709,36 @@ REGISTER_OP("NonMaxSuppressionV3")
return Status::OK();
});
+REGISTER_OP("NonMaxSuppressionWithOverlaps")
+ .Input("overlaps: float")
+ .Input("scores: float")
+ .Input("max_output_size: int32")
+ .Input("overlap_threshold: float")
+ .Input("score_threshold: float")
+ .Output("selected_indices: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle overlaps;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle overlap_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // The boxes[0] and scores[0] are both num_boxes.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
+ // The boxes[1] is 4.
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 8c0b073ce4..c229bd5a41 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -239,6 +239,21 @@ REGISTER_OP("Acos").UNARY();
REGISTER_OP("Atan").UNARY();
+REGISTER_OP("BesselI0e").UNARY_REAL();
+
+REGISTER_OP("BesselI1e").UNARY_REAL();
+
+REGISTER_OP("_UnaryOpsComposition")
+ .Input("x: T")
+ .Output("y: T")
+ .Attr("T: {float, half, double}")
+ .Attr("op_names: list(string)")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to create these operators.
+)doc");
+
#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX
@@ -485,6 +500,13 @@ REGISTER_OP("Igamma")
.Attr("T: {float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+REGISTER_OP("IgammaGradA")
+ .Input("a: T")
+ .Input("x: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("Zeta")
.Input("x: T")
.Input("q: T")
@@ -592,7 +614,13 @@ REGISTER_OP("ApproximateEqual")
.SetIsCommutative()
.Attr("T: numbertype")
.Attr("tolerance: float = 0.00001")
- .SetShapeFn(shape_inference::UnchangedShape);
+ .SetShapeFn([](InferenceContext* c) {
+ // The inputs 'x' and 'y' must have the same shape.
+ ShapeHandle data_x = c->input(0);
+ ShapeHandle data_y = c->input(1);
+ TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
+ return shape_inference::UnchangedShape(c);
+ });
// --------------------------------------------------------------------------
@@ -1080,7 +1108,7 @@ REGISTER_OP("UnsortedSegmentProd")
.Input("segment_ids: Tindices")
.Input("num_segments: Tnumsegments")
.Output("output: T")
- .Attr("T: realnumbertype")
+ .Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
.Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn);
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 41efa49ce3..f947d4c30d 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -432,7 +432,7 @@ REGISTER_OP("FusedResizeAndPadConv2D")
.Input("paddings: int32")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {float}")
+ .Attr("T: {half, float, double}")
.Attr("resize_align_corners: bool = false")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
@@ -446,7 +446,7 @@ REGISTER_OP("FusedPadConv2D")
.Input("paddings: int32")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {float}")
+ .Attr("T: {half, float, double}")
.Attr(GetMirrorPadModeAttrString())
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
@@ -648,7 +648,7 @@ REGISTER_OP("MaxPool3DGradGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float}")
+ .Attr("T: realnumbertype")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
ShapeHandle unused;
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d929a5fc87..9a9f10f01f 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1977,13 +1977,14 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_INT8
+ type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_FLOAT
type: DT_DOUBLE
type: DT_BOOL
- type: DT_INT8
}
}
}
@@ -2489,6 +2490,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -2671,6 +2674,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -3005,6 +3010,37 @@ op {
}
}
op {
+ name: "BatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchFFT"
input_arg {
name: "input"
@@ -3050,6 +3086,90 @@ op {
}
}
op {
+ name: "BatchFunction"
+ input_arg {
+ name: "in_tensors"
+ type_list_attr: "Tin"
+ }
+ input_arg {
+ name: "captured_tensors"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "out_tensors"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "num_batch_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_batch_size"
+ type: "int"
+ }
+ attr {
+ name: "batch_timeout_micros"
+ type: "int"
+ }
+ attr {
+ name: "max_enqueued_batches"
+ type: "int"
+ default_value {
+ i: 10
+ }
+ }
+ attr {
+ name: "allowed_batch_sizes"
+ type: "list(int)"
+ default_value {
+ list {
+ }
+ }
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "batching_queue"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "BatchIFFT"
input_arg {
name: "input"
@@ -3746,6 +3866,52 @@ op {
}
}
op {
+ name: "BesselI0e"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "BesselI1e"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Betainc"
input_arg {
name: "a"
@@ -3823,6 +3989,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -3872,6 +4040,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -4170,6 +4340,34 @@ op {
}
}
op {
+ name: "BoostedTreesCenterBias"
+ input_arg {
+ name: "tree_ensemble_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "mean_gradients"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "mean_hessians"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l1"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "l2"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "continue_centering"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesCreateEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -4224,6 +4422,33 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesExampleDebugOutputs"
+ input_arg {
+ name: "tree_ensemble_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "bucketized_features"
+ type: DT_INT32
+ number_attr: "num_bucketized_features"
+ }
+ output_arg {
+ name: "examples_debug_outputs_serialized"
+ type: DT_STRING
+ }
+ attr {
+ name: "num_bucketized_features"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "logits_dimension"
+ type: "int"
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesGetEnsembleStates"
input_arg {
name: "tree_ensemble_handle"
@@ -5505,6 +5730,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -5582,6 +5809,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -5659,6 +5888,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -7471,6 +7702,17 @@ op {
}
}
op {
+ name: "DatasetToGraph"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "graph"
+ type: DT_STRING
+ }
+}
+op {
name: "DatasetToSingleElement"
input_arg {
name: "dataset"
@@ -8350,6 +8592,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -8420,6 +8664,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -8490,6 +8736,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -9909,6 +10157,21 @@ op {
}
}
op {
+ name: "FakeParam"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+}
+op {
name: "FakeQuantWithMinMaxArgs"
input_arg {
name: "inputs"
@@ -10160,6 +10423,33 @@ op {
is_stateful: true
}
op {
+ name: "FeatureStatsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Fill"
input_arg {
name: "dims"
@@ -11195,7 +11485,9 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
type: DT_FLOAT
+ type: DT_DOUBLE
}
}
}
@@ -11251,7 +11543,9 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
type: DT_FLOAT
+ type: DT_DOUBLE
}
}
}
@@ -12167,13 +12461,11 @@ op {
name: "Tin"
type: "list(type)"
has_minimum: true
- minimum: 1
}
attr {
name: "Tout"
type: "list(type)"
has_minimum: true
- minimum: 1
}
attr {
name: "then_branch"
@@ -12210,6 +12502,31 @@ op {
}
}
op {
+ name: "IgammaGradA"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Igammac"
input_arg {
name: "a"
@@ -12928,6 +13245,36 @@ op {
is_stateful: true
}
op {
+ name: "IteratorFromStringHandleV2"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
name: "IteratorGetNext"
input_arg {
name: "iterator"
@@ -12988,6 +13335,34 @@ op {
is_stateful: true
}
op {
+ name: "IteratorV2"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "L2Loss"
input_arg {
name: "t"
@@ -15069,6 +15444,17 @@ op {
allowed_values {
list {
type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
}
}
}
@@ -15123,6 +15509,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -15200,6 +15588,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -15270,6 +15660,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -15411,6 +15803,8 @@ op {
list {
s: "NHWC"
s: "NCHW"
+ s: "HWNC"
+ s: "HWCN"
}
}
}
@@ -16630,6 +17024,33 @@ op {
}
}
op {
+ name: "NonMaxSuppressionWithOverlaps"
+ input_arg {
+ name: "overlaps"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "scores"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "overlap_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+}
+op {
name: "NotEqual"
input_arg {
name: "x"
@@ -16830,6 +17251,33 @@ op {
}
}
op {
+ name: "OptimizeDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "optimizations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "OrderedMapClear"
attr {
name: "capacity"
@@ -17303,6 +17751,52 @@ op {
}
}
op {
+ name: "PaddedBatchDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "padded_shapes"
+ type: DT_INT64
+ number_attr: "N"
+ }
+ input_arg {
+ name: "padding_values"
+ type_list_attr: "Toutput_types"
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "Toutput_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "PaddingFIFOQueue"
output_arg {
name: "handle"
@@ -20398,6 +20892,31 @@ op {
is_stateful: true
}
op {
+ name: "RandomGammaGrad"
+ input_arg {
+ name: "alpha"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "sample"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "RandomPoisson"
input_arg {
name: "shape"
@@ -23472,6 +23991,43 @@ op {
is_stateful: true
}
op {
+ name: "ResourceScatterNdAdd"
+ input_arg {
+ name: "ref"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "updates"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceScatterNdUpdate"
input_arg {
name: "ref"
@@ -26876,6 +27432,17 @@ op {
}
}
op {
+ name: "SinkDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+}
+op {
name: "Size"
input_arg {
name: "input"
@@ -29675,6 +30242,54 @@ op {
}
}
op {
+ name: "SparseSliceGrad"
+ input_arg {
+ name: "backprop_val_grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_indices"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "input_start"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "output_indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "val_grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+}
+op {
name: "SparseSoftmax"
input_arg {
name: "sp_indices"
@@ -30707,6 +31322,32 @@ op {
is_stateful: true
}
op {
+ name: "StatefulPartitionedCall"
+ input_arg {
+ name: "args"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ is_stateful: true
+}
+op {
name: "StatelessMultinomial"
input_arg {
name: "logits"
@@ -31267,6 +31908,36 @@ op {
}
}
op {
+ name: "StringSplitV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "sep"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "indices"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "values"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "shape"
+ type: DT_INT64
+ }
+ attr {
+ name: "maxsplit"
+ type: "int"
+ default_value {
+ i: -1
+ }
+ }
+}
+op {
name: "StringStrip"
input_arg {
name: "input"
@@ -32216,6 +32887,34 @@ op {
is_stateful: true
}
op {
+ name: "TensorArrayGradWithShape"
+ input_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "flow_in"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "shape_to_prepend"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "grad_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "flow_out"
+ type: DT_FLOAT
+ }
+ attr {
+ name: "source"
+ type: "string"
+ }
+ is_stateful: true
+}
+op {
name: "TensorArrayPack"
input_arg {
name: "handle"
@@ -34160,9 +34859,14 @@ op {
type: DT_UINT8
type: DT_INT16
type: DT_INT8
+ type: DT_COMPLEX64
type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
type: DT_BFLOAT16
type: DT_UINT16
+ type: DT_COMPLEX128
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
@@ -34544,6 +35248,33 @@ op {
is_stateful: true
}
op {
+ name: "WindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "window_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "WriteAudioSummary"
input_arg {
name: "writer"
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 80ffae5796..a76248e05f 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -138,6 +138,13 @@ REGISTER_OP("RandomGamma")
return Status::OK();
});
+REGISTER_OP("RandomGammaGrad")
+ .Input("alpha: T")
+ .Input("sample: T")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("RandomPoisson")
.SetIsStateful()
.Input("shape: S")
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 3d0a6c2157..26499540f1 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -14,6 +14,7 @@
// ============================================================================
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -84,6 +85,22 @@ REGISTER_OP("ReadVariableOp")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FunctionDefHelper::Define(
+ // Arg defs
+ {"x: resource", "dy: float"},
+ // Ret val defs
+ {"dy: float"},
+ // Attr defs
+ {},
+ // Nodes
+ {});
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("ReadVariableOp", ReadGrad);
+
REGISTER_OP("DestroyResourceOp")
.Input("resource: resource")
.Attr("ignore_lookup_error: bool = true")
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index acc8c782ef..bc0cb2095d 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -302,6 +302,20 @@ REGISTER_OP("SparseSplit")
return Status::OK();
});
+REGISTER_OP("SparseSliceGrad")
+ .Input("backprop_val_grad: T")
+ .Input("input_indices: int64")
+ .Input("input_start: int64")
+ .Input("output_indices: int64")
+ .Output("val_grad: T")
+ .Attr("T: numbertype")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
+ c->set_output(0, c->Vector(c->Dim(indices, 0)));
+ return Status::OK();
+ });
+
REGISTER_OP("SparseSlice")
.Input("indices: int64")
.Input("values: T")
diff --git a/tensorflow/core/ops/sparse_ops_test.cc b/tensorflow/core/ops/sparse_ops_test.cc
index 0df3320484..6a9b5ce4d3 100644
--- a/tensorflow/core/ops/sparse_ops_test.cc
+++ b/tensorflow/core/ops/sparse_ops_test.cc
@@ -52,6 +52,18 @@ TEST(SparseOpsTest, SparseAddGrad_ShapeFn) {
INFER_OK(op, "?;[?,?];[?,?];?", "[d1_0];[d2_0]");
}
+TEST(SparseOpsTest, SparseSliceGrad_ShapeFn) {
+ ShapeInferenceTestOp op("SparseSliceGrad");
+
+ // Rank checks.
+ INFER_ERROR("must be rank 2", op, "?;[1];?;?");
+
+ INFER_OK(op, "?;?;?;?", "[?]");
+
+ // input[1].dim(0) determine output.
+ INFER_OK(op, "?;[?,?];?;?", "[d1_0]");
+}
+
TEST(SparseOpsTest, SparseReorder_ShapeFn) {
ShapeInferenceTestOp op("SparseReorder");
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index 664f52452e..aa975cb77b 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -222,6 +222,15 @@ REGISTER_OP("ResourceScatterNdUpdate")
.Attr("use_locking: bool = true")
.SetShapeFn(shape_inference::ScatterNdUpdateShape);
+REGISTER_OP("ResourceScatterNdAdd")
+ .Input("ref: resource")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape);
+
REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 1d5c743a56..4423062362 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -78,7 +78,7 @@ REGISTER_OP("ReduceJoin")
REGISTER_OP("AsString")
.Input("input: T")
.Output("output: string")
- .Attr("T: {int32, int64, complex64, float, double, bool, int8}")
+ .Attr("T: {int8, int16, int32, int64, complex64, float, double, bool}")
.Attr("precision: int = -1")
.Attr("scientific: bool = false")
.Attr("shortest: bool = false")
@@ -134,6 +134,24 @@ REGISTER_OP("StringSplit")
return Status::OK();
});
+REGISTER_OP("StringSplitV2")
+ .Input("input: string")
+ .Input("sep: string")
+ .Output("indices: int64")
+ .Output("values: string")
+ .Output("shape: int64")
+ .Attr("maxsplit: int = -1")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
+ c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
+ c->set_output(2, c->Vector(2));
+ return Status::OK();
+ });
+
REGISTER_OP("StringStrip")
.Input("input: string")
.Output("output: string")
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 803b08f1a3..aa35e8a116 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -631,6 +631,9 @@ GcsFileSystem::GcsFileSystem()
// Setting either to 0 disables the cache; set both for good measure.
block_size = max_bytes = 0;
}
+ VLOG(1) << "GCS cache max size = " << max_bytes << " ; "
+ << "block size = " << block_size << " ; "
+ << "max staleness = " << max_staleness;
file_block_cache_ = MakeFileBlockCache(block_size, max_bytes, max_staleness);
// Apply overrides for the stat cache max age and max entries, if provided.
uint64 stat_cache_max_age = kStatCacheDefaultMaxAge;
@@ -804,7 +807,7 @@ void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
mutex_lock l(block_cache_lock_);
file_block_cache_ =
MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs);
- if (stats_) {
+ if (stats_ != nullptr) {
stats_->Configure(this, &throttle_, file_block_cache_.get());
}
}
@@ -1557,6 +1560,7 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
return Status::OK();
}
-REGISTER_FILE_SYSTEM("gs", RetryingGcsFileSystem);
-
} // namespace tensorflow
+
+// Initialize gcs_file_system
+REGISTER_FILE_SYSTEM("gs", ::tensorflow::RetryingGcsFileSystem);
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index e64653a67a..ee6ba7b041 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -137,8 +137,8 @@ Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
const auto expiration_timestamp_sec =
request_timestamp_sec + kRequestedTokenLifetimeSec;
- root["iat"] = request_timestamp_sec;
- root["exp"] = expiration_timestamp_sec;
+ root["iat"] = Json::Value::UInt64(request_timestamp_sec);
+ root["exp"] = Json::Value::UInt64(expiration_timestamp_sec);
// Step 2: represent the JSON as a string.
string claim = root.toStyledString();
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index f12732b434..28891320c4 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -202,7 +202,10 @@ def cc_proto_library(
)
if use_grpc_plugin:
- cc_libs += ["//external:grpc_lib"]
+ cc_libs += select({
+ "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
+ "//conditions:default": ["//external:grpc_lib"],
+ })
if default_header:
header_only_name = name
@@ -306,6 +309,7 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
cc_grpc_version = None,
j2objc_api_version = 1,
cc_api_version = 2,
+ dart_api_version = 2,
java_api_version = 2, py_api_version = 2,
js_api_version = 2, js_codegen = "jspb",
default_header = False):
@@ -411,7 +415,7 @@ def tf_proto_library(name, srcs = [], has_services = None,
visibility = [], testonly = 0,
cc_libs = [],
cc_api_version = 2, cc_grpc_version = None,
- j2objc_api_version = 1,
+ dart_api_version = 2, j2objc_api_version = 1,
java_api_version = 2, py_api_version = 2,
js_api_version = 2, js_codegen = "jspb",
provide_cc_alias = False,
@@ -616,10 +620,10 @@ def tf_additional_core_deps():
],
"//conditions:default": [],
}) + select({
- "//tensorflow:with_s3_support_windows_override": [],
- "//tensorflow:with_s3_support_android_override": [],
- "//tensorflow:with_s3_support_ios_override": [],
- "//tensorflow:with_s3_support": [
+ "//tensorflow:with_aws_support_windows_override": [],
+ "//tensorflow:with_aws_support_android_override": [],
+ "//tensorflow:with_aws_support_ios_override": [],
+ "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
],
"//conditions:default": [],
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index c17e4810d5..da1f66dc67 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -146,7 +146,6 @@ cc_library(
"@farmhash_archive//:farmhash",
"@fft2d",
"@highwayhash//:sip_hash",
- "@png_archive//:png",
],
)
@@ -161,7 +160,7 @@ cc_library(
"@farmhash_archive//:farmhash",
"@fft2d",
"@highwayhash//:sip_hash",
- "@png_archive//:png",
+ "@zlib_archive//:zlib",
],
)
@@ -187,6 +186,15 @@ cc_library(
)
cc_library(
+ name = "png",
+ copts = tf_copts(),
+ deps = [
+ "@png_archive//:png",
+ "@zlib_archive//:zlib",
+ ],
+)
+
+cc_library(
name = "protos_cc_impl",
copts = tf_copts(),
deps = [
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 9192f7ba10..e17ecc8c52 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -450,6 +450,6 @@ struct Register {
::tensorflow::register_file_system::Register<factory>(env, scheme)
#define REGISTER_FILE_SYSTEM(scheme, factory) \
- REGISTER_FILE_SYSTEM_ENV(Env::Default(), scheme, factory);
+ REGISTER_FILE_SYSTEM_ENV(::tensorflow::Env::Default(), scheme, factory);
#endif // TENSORFLOW_CORE_PLATFORM_ENV_H_
diff --git a/tensorflow/core/platform/fingerprint.h b/tensorflow/core/platform/fingerprint.h
index b47dcdedd7..720dc4c3d6 100644
--- a/tensorflow/core/platform/fingerprint.h
+++ b/tensorflow/core/platform/fingerprint.h
@@ -74,7 +74,7 @@ inline uint64 FingerprintCat64(const uint64 fp1, const uint64 fp2) {
} // namespace tensorflow
-#if defined(PLATFORM_GOOGLE)
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID)
#include "tensorflow/core/platform/google/fingerprint.h"
#else
#include "tensorflow/core/platform/default/fingerprint.h"
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 72c12318ca..ff4b4436bb 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -115,18 +115,17 @@ class LibHDFS {
const char* kLibHdfsDso = "libhdfs.so";
#endif
char* hdfs_home = getenv("HADOOP_HDFS_HOME");
- if (hdfs_home == nullptr) {
- status_ = errors::FailedPrecondition(
- "Environment variable HADOOP_HDFS_HOME not set");
- return;
- }
- string path = io::JoinPath(hdfs_home, "lib", "native", kLibHdfsDso);
- status_ = TryLoadAndBind(path.c_str(), &handle_);
- if (!status_.ok()) {
- // try load libhdfs.so using dynamic loader's search path in case
- // libhdfs.so is installed in non-standard location
- status_ = TryLoadAndBind(kLibHdfsDso, &handle_);
+ if (hdfs_home != nullptr) {
+ string path = io::JoinPath(hdfs_home, "lib", "native", kLibHdfsDso);
+ status_ = TryLoadAndBind(path.c_str(), &handle_);
+ if (status_.ok()) {
+ return;
+ }
}
+
+ // Try to load the library dynamically in case it has been installed
+ // to a in non-standard location.
+ status_ = TryLoadAndBind(kLibHdfsDso, &handle_);
}
Status status_;
diff --git a/tensorflow/core/platform/numa.h b/tensorflow/core/platform/numa.h
new file mode 100644
index 0000000000..b1f08e4c4c
--- /dev/null
+++ b/tensorflow/core/platform/numa.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_NUMA_H_
+#define TENSORFLOW_CORE_PLATFORM_NUMA_H_
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace port {
+
+// Returns true iff NUMA functions are supported.
+bool NUMAEnabled();
+
+// Returns the number of NUMA nodes present with respect to CPU operations.
+// Typically this will be the number of sockets where some RAM has greater
+// affinity with one socket than another.
+int NUMANumNodes();
+
+static const int kNUMANoAffinity = -1;
+
+// If possible sets affinity of the current thread to the specified NUMA node.
+// If node == kNUMANoAffinity removes affinity to any particular node.
+void NUMASetThreadNodeAffinity(int node);
+
+// Returns NUMA node affinity of the current thread, kNUMANoAffinity if none.
+int NUMAGetThreadNodeAffinity();
+
+// Like AlignedMalloc, but allocates memory with affinity to the specified NUMA
+// node.
+//
+// Notes:
+// 1. node must be >= 0 and < NUMANumNodes.
+// 1. minimum_alignment must a factor of system page size, the memory
+// returned will be page-aligned.
+// 2. This function is likely significantly slower than AlignedMalloc
+// and should not be used for lots of small allocations. It makes more
+// sense as a backing allocator for BFCAllocator, PoolAllocator, or similar.
+void* NUMAMalloc(int node, size_t size, int minimum_alignment);
+
+// Memory allocated by NUMAMalloc must be freed via NUMAFree.
+void NUMAFree(void* ptr, size_t size);
+
+// Returns NUMA node affinity of memory address, kNUMANoAffinity if none.
+int NUMAGetMemAffinity(const void* ptr);
+
+} // namespace port
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_PLATFORM_NUMA_H_
diff --git a/tensorflow/core/platform/numa_test.cc b/tensorflow/core/platform/numa_test.cc
new file mode 100644
index 0000000000..8b39ecd59c
--- /dev/null
+++ b/tensorflow/core/platform/numa_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/numa.h"
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace internal {
+
+TEST(Numa, NumNodes) {
+ if (port::NUMAEnabled()) {
+ EXPECT_GE(port::NUMANumNodes(), 1);
+ }
+}
+
+TEST(Numa, Malloc) {
+ if (port::NUMAEnabled()) {
+ int num_nodes = port::NUMANumNodes();
+ for (int request_node = 0; request_node < num_nodes; ++request_node) {
+ void* ptr = port::NUMAMalloc(request_node, 8, 0);
+ EXPECT_NE(ptr, nullptr);
+ // Affinity cannot be tested until page is touched, so save a value.
+ *(reinterpret_cast<int*>(ptr)) = 0;
+ int affinity_node = port::NUMAGetMemAffinity(ptr);
+ EXPECT_EQ(affinity_node, request_node);
+ port::NUMAFree(ptr, 8);
+ }
+ }
+}
+
+TEST(Numa, SetNodeAffinity) {
+ // NOTE(tucker): This test is not reliable when executed under tap because
+ // the virtual machine may not have access to all of the availble NUMA
+ // nodes. Not sure what to do about that.
+ EXPECT_EQ(-1, port::NUMAGetThreadNodeAffinity());
+ if (port::NUMAEnabled()) {
+ int num_nodes = port::NUMANumNodes();
+ for (int request_node = 0; request_node < num_nodes; ++request_node) {
+ port::NUMASetThreadNodeAffinity(request_node);
+ int affinity_node = port::NUMAGetThreadNodeAffinity();
+ EXPECT_EQ(affinity_node, request_node);
+ }
+ }
+}
+
+} // namespace internal
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 708f32ba80..1939cf72fb 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/types.h"
@@ -79,6 +80,19 @@ int NumHyperthreadsPerCore() {
return (ht_per_core > 0) ? ht_per_core : 1;
}
+bool NUMAEnabled() {
+ // Not yet implemented: coming soon.
+ return false;
+}
+
+int NUMANumNodes() { return 1; }
+
+void NUMASetThreadNodeAffinity(int node) {}
+
+int NUMAGetThreadNodeAffinity() {
+ return kNUMANoAffinity;
+}
+
void* AlignedMalloc(size_t size, int minimum_alignment) {
#if defined(__ANDROID__)
return memalign(minimum_alignment, size);
@@ -128,6 +142,16 @@ void Free(void* ptr) {
#endif
}
+void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
+ return AlignedMalloc(size, minimum_alignment);
+}
+
+void NUMAFree(void* ptr, size_t size) { Free(ptr); }
+
+int NUMAGetMemAffinity(const void* addr) {
+ return kNUMANoAffinity;
+}
+
void MallocExtension_ReleaseToSystem(std::size_t num_bytes) {
// No-op.
}
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.cc b/tensorflow/core/platform/profile_utils/cpu_utils.cc
index 02de7d1362..b0136b52f4 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.cc
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
+#include <fstream>
#include <limits>
#include <mutex>
@@ -67,22 +68,32 @@ static ICpuUtilsHelper* cpu_utils_helper_instance_ = nullptr;
#if defined(__ANDROID__)
return GetCpuUtilsHelperSingletonInstance().CalculateCpuFrequency();
#elif defined(__linux__)
- double bogomips;
- FILE* fp = popen("grep '^bogomips' /proc/cpuinfo | head -1", "r");
- if (fp == nullptr) {
- return INVALID_FREQUENCY;
- }
- const int retval_of_bogomips = fscanf(fp, "bogomips : %lf", &bogomips);
- if (retval_of_bogomips <= 0) {
+ // Read the contents of /proc/cpuinfo.
+ std::ifstream cpuinfo("/proc/cpuinfo");
+ if (!cpuinfo) {
+ LOG(WARNING) << "Failed to open /proc/cpuinfo";
return INVALID_FREQUENCY;
}
- pclose(fp);
- const double freq_ghz = bogomips / 1000.0 / 2.0;
- if (retval_of_bogomips != 1 || freq_ghz < 0.01) {
- LOG(WARNING) << "Failed to get CPU frequency: " << freq_ghz << " Hz";
- return INVALID_FREQUENCY;
+ string line;
+ while (std::getline(cpuinfo, line)) {
+ double bogomips;
+ const int retval_of_bogomips =
+ sscanf(line.c_str(), "bogomips : %lf", &bogomips);
+ if (retval_of_bogomips > 0) {
+ const double freq_ghz = bogomips / 1000.0 / 2.0;
+ if (retval_of_bogomips != 1 || freq_ghz < 0.01) {
+ LOG(WARNING) << "Failed to get CPU frequency: " << freq_ghz << " Hz";
+ return INVALID_FREQUENCY;
+ }
+ const int64 freq_n =
+ static_cast<int64>(freq_ghz * 1000.0 * 1000.0 * 1000.0);
+ LOG(INFO) << "CPU Frequency: " << freq_n << " Hz";
+ return freq_n;
+ }
}
- return static_cast<int64>(freq_ghz * 1000.0 * 1000.0 * 1000.0);
+ LOG(WARNING) << "Failed to find bogomips in /proc/cpuinfo; cannot determine "
+ "CPU frequency";
+ return INVALID_FREQUENCY;
#elif defined(__APPLE__)
int64 freq_hz;
FILE* fp =
diff --git a/tensorflow/core/platform/s3/BUILD b/tensorflow/core/platform/s3/BUILD
index 21038cfeb1..41184b6fd9 100644
--- a/tensorflow/core/platform/s3/BUILD
+++ b/tensorflow/core/platform/s3/BUILD
@@ -16,10 +16,10 @@ load(
tf_cc_binary(
name = "s3_file_system.so",
srcs = [
+ "aws_crypto.cc",
+ "aws_crypto.h",
"aws_logging.cc",
"aws_logging.h",
- "s3_crypto.cc",
- "s3_crypto.h",
"s3_file_system.cc",
"s3_file_system.h",
],
@@ -40,16 +40,14 @@ tf_cc_binary(
)
cc_library(
- name = "s3_crypto",
+ name = "aws_crypto",
srcs = [
- "s3_crypto.cc",
+ "aws_crypto.cc",
],
hdrs = [
- "s3_crypto.h",
+ "aws_crypto.h",
],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
"@aws",
"@boringssl//:crypto",
],
@@ -81,8 +79,8 @@ cc_library(
"s3_file_system.h",
],
deps = [
+ ":aws_crypto",
":aws_logging",
- ":s3_crypto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@aws",
diff --git a/tensorflow/core/platform/s3/aws_crypto.cc b/tensorflow/core/platform/s3/aws_crypto.cc
new file mode 100644
index 0000000000..90e46d6c1d
--- /dev/null
+++ b/tensorflow/core/platform/s3/aws_crypto.cc
@@ -0,0 +1,113 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/platform/s3/aws_crypto.h"
+#include <openssl/hmac.h>
+#include <openssl/sha.h>
+
+#include <aws/core/utils/crypto/HashResult.h>
+#include <aws/s3/S3Client.h>
+
+namespace tensorflow {
+
+class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
+ public:
+ AWSSha256HMACOpenSSLImpl() {}
+
+ virtual ~AWSSha256HMACOpenSSLImpl() = default;
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ const Aws::Utils::ByteBuffer& toSign,
+ const Aws::Utils::ByteBuffer& secret) override {
+ unsigned int length = SHA256_DIGEST_LENGTH;
+ Aws::Utils::ByteBuffer digest(length);
+ memset(digest.GetUnderlyingData(), 0, length);
+
+ HMAC_CTX ctx;
+ HMAC_CTX_init(&ctx);
+
+ HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
+ static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
+ HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
+ HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
+ HMAC_CTX_cleanup(&ctx);
+
+ return Aws::Utils::Crypto::HashResult(std::move(digest));
+ }
+};
+
+class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
+ public:
+ AWSSha256OpenSSLImpl() {}
+
+ virtual ~AWSSha256OpenSSLImpl() = default;
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ const Aws::String& str) override {
+ SHA256_CTX sha256;
+ SHA256_Init(&sha256);
+ SHA256_Update(&sha256, str.data(), str.size());
+
+ Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
+ SHA256_Final(hash.GetUnderlyingData(), &sha256);
+
+ return Aws::Utils::Crypto::HashResult(std::move(hash));
+ }
+
+ virtual Aws::Utils::Crypto::HashResult Calculate(
+ Aws::IStream& stream) override {
+ SHA256_CTX sha256;
+ SHA256_Init(&sha256);
+
+ auto currentPos = stream.tellg();
+ if (currentPos == std::streampos(std::streamoff(-1))) {
+ currentPos = 0;
+ stream.clear();
+ }
+
+ stream.seekg(0, stream.beg);
+
+ char streamBuffer
+ [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
+ while (stream.good()) {
+ stream.read(streamBuffer,
+ Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
+ auto bytesRead = stream.gcount();
+
+ if (bytesRead > 0) {
+ SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
+ }
+ }
+
+ stream.clear();
+ stream.seekg(currentPos, stream.beg);
+
+ Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
+ SHA256_Final(hash.GetUnderlyingData(), &sha256);
+
+ return Aws::Utils::Crypto::HashResult(std::move(hash));
+ }
+};
+
+std::shared_ptr<Aws::Utils::Crypto::Hash>
+AWSSHA256Factory::CreateImplementation() const {
+ return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
+}
+
+std::shared_ptr<Aws::Utils::Crypto::HMAC>
+AWSSHA256HmacFactory::CreateImplementation() const {
+ return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/aws_crypto.h b/tensorflow/core/platform/s3/aws_crypto.h
new file mode 100644
index 0000000000..f05771b904
--- /dev/null
+++ b/tensorflow/core/platform/s3/aws_crypto.h
@@ -0,0 +1,35 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <aws/core/Aws.h>
+#include <aws/core/utils/crypto/Factories.h>
+#include <aws/core/utils/crypto/HMAC.h>
+#include <aws/core/utils/crypto/Hash.h>
+
+namespace tensorflow {
+static const char* AWSCryptoAllocationTag = "AWSCryptoAllocation";
+
+class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
+ public:
+ std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
+ const override;
+};
+
+class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
+ public:
+ std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
+ const override;
+};
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 6da679dc75..bdc8f808df 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -17,8 +17,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/s3/aws_crypto.h"
#include "tensorflow/core/platform/s3/aws_logging.h"
-#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <aws/core/Aws.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
@@ -300,10 +300,10 @@ std::shared_ptr<Aws::S3::S3Client> S3FileSystem::GetS3Client() {
Aws::SDKOptions options;
options.cryptoOptions.sha256Factory_create_fn = []() {
- return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
+ return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
};
options.cryptoOptions.sha256HMACFactory_create_fn = []() {
- return Aws::MakeShared<S3SHA256HmacFactory>(S3CryptoAllocationTag);
+ return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
};
Aws::InitAPI(options);
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc
index a693c4695f..0f9e75bf9c 100644
--- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc
+++ b/tensorflow/core/platform/vmodule_benchmark_test.cc
@@ -13,20 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+namespace tensorflow {
-namespace xla {
-
-string VersionedComputationHandle::ToString() const {
- return tensorflow::strings::StrCat(handle.handle(), ":v", version);
-}
-
-std::ostream& operator<<(std::ostream& out,
- const VersionedComputationHandle& versioned_handle) {
- out << versioned_handle.ToString();
- return out;
+static void BM_DisabledVlog(int iters) {
+ for (int i = 0; i < iters; ++i) {
+ VLOG(1) << "Testing VLOG(1)!";
+ }
}
+BENCHMARK(BM_DisabledVlog);
-} // namespace xla
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/vmodule_test.cc b/tensorflow/core/platform/vmodule_test.cc
new file mode 100644
index 0000000000..47b4b2e0e7
--- /dev/null
+++ b/tensorflow/core/platform/vmodule_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Test that popens a child process with the VLOG-ing environment variable set
+// for the logging framework, and observes VLOG_IS_ON and VLOG macro output.
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/test.h"
+
+#include <string.h>
+
+namespace tensorflow {
+namespace {
+
+int RealMain(const char* argv0, bool do_vlog) {
+ if (do_vlog) {
+#if !defined(PLATFORM_GOOGLE)
+ // Note, we only test this when !defined(PLATFORM_GOOGLE) because
+ // VmoduleActivated doesn't exist in that implementation.
+ //
+ // Also, we call this internal API to simulate what would happen if
+ // differently-named translation units attempted to VLOG, so we don't need
+ // to create dummy translation unit files.
+ bool ok = internal::LogMessage::VmoduleActivated("vmodule_test.cc", 7) &&
+ internal::LogMessage::VmoduleActivated("shoobadooba.h", 3);
+ if (!ok) {
+ fprintf(stderr, "vmodule activated levels not as expected.\n");
+ return EXIT_FAILURE;
+ }
+#endif
+
+ // Print info on which VLOG levels are activated.
+ fprintf(stderr, "VLOG_IS_ON(8)? %d\n", VLOG_IS_ON(8));
+ fprintf(stderr, "VLOG_IS_ON(7)? %d\n", VLOG_IS_ON(7));
+ fprintf(stderr, "VLOG_IS_ON(6)? %d\n", VLOG_IS_ON(6));
+ // Do some VLOG-ing.
+ VLOG(8) << "VLOG(8)";
+ VLOG(7) << "VLOG(7)";
+ VLOG(6) << "VLOG(6)";
+ LOG(INFO) << "INFO";
+ return EXIT_SUCCESS;
+ }
+
+ // Popen the child process.
+ std::string command = std::string(argv0);
+#if defined(PLATFORM_GOOGLE)
+ command = command + " do_vlog --vmodule=vmodule_test=7 --alsologtostderr";
+#else
+ command =
+ "TF_CPP_VMODULE=vmodule_test=7,shoobadooba=3 " + command + " do_vlog";
+#endif
+ command += " 2>&1";
+ fprintf(stderr, "Running: \"%s\"\n", command.c_str());
+ FILE* f = popen(command.c_str(), "r");
+ if (f == nullptr) {
+ fprintf(stderr, "Failed to popen child: %s\n", strerror(errno));
+ return EXIT_FAILURE;
+ }
+
+ // Read data from the child's stdout.
+ constexpr int kBufferSizeBytes = 4096;
+ char buffer[kBufferSizeBytes];
+ size_t result = fread(buffer, sizeof(buffer[0]), kBufferSizeBytes - 1, f);
+ if (result == 0) {
+ fprintf(stderr, "Failed to read from child stdout: %zu %s\n", result,
+ strerror(errno));
+ return EXIT_FAILURE;
+ }
+ buffer[result] = '\0';
+ int status = pclose(f);
+ if (status == -1) {
+ fprintf(stderr, "Failed to close popen child: %s\n", strerror(errno));
+ return EXIT_FAILURE;
+ }
+
+ // Check output is as expected.
+ const char kExpected[] =
+ "VLOG_IS_ON(8)? 0\nVLOG_IS_ON(7)? 1\nVLOG_IS_ON(6)? 1\n";
+ if (strstr(buffer, kExpected) == nullptr) {
+ fprintf(stderr, "error: unexpected output from child: \"%.*s\"\n",
+ kBufferSizeBytes, buffer);
+ return EXIT_FAILURE;
+ }
+ bool ok = strstr(buffer, "VLOG(7)\n") != nullptr &&
+ strstr(buffer, "VLOG(6)\n") != nullptr &&
+ strstr(buffer, "VLOG(8)\n") == nullptr;
+ if (!ok) {
+ fprintf(stderr, "error: VLOG output not as expected: \"%.*s\"\n",
+ kBufferSizeBytes, buffer);
+ return EXIT_FAILURE;
+ }
+
+ // Success!
+ return EXIT_SUCCESS;
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ bool do_vlog = argc >= 2 && strcmp(argv[1], "do_vlog") == 0;
+ return tensorflow::RealMain(argv[0], do_vlog);
+}
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 174f41a993..f2aaf13bec 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -171,5 +171,10 @@ int64 AvailableRam() {
return INT64_MAX;
}
+int NumHyperthreadsPerCore() {
+ static const int ht_per_core = tensorflow::port::CPUIDNumSMT();
+ return (ht_per_core > 0) ? ht_per_core : 1;
+}
+
} // namespace port
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.cc b/tensorflow/core/profiler/internal/tfprof_timeline.cc
index b0dd8ce5e0..979b437914 100644
--- a/tensorflow/core/profiler/internal/tfprof_timeline.cc
+++ b/tensorflow/core/profiler/internal/tfprof_timeline.cc
@@ -47,9 +47,9 @@ Json::Value ChromeTraceFormatter::CreateEvent(const string& ph,
event["ph"] = Json::Value(ph);
event["cat"] = Json::Value(category);
event["name"] = Json::Value(name);
- event["pid"] = Json::Value(pid);
- event["tid"] = Json::Value(tid);
- event["ts"] = Json::Value(ts);
+ event["pid"] = Json::Int64(pid);
+ event["tid"] = Json::Int64(tid);
+ event["ts"] = Json::Int64(ts);
return event;
}
@@ -57,7 +57,7 @@ void ChromeTraceFormatter::EmitPID(const string& name, int64 pid) {
Json::Value event(Json::objectValue);
event["name"] = Json::Value("process_name");
event["ph"] = Json::Value("M");
- event["pid"] = Json::Value(pid);
+ event["pid"] = Json::Int64(pid);
Json::Value args(Json::objectValue);
args["name"] = Json::Value(name);
event["args"] = args;
@@ -68,7 +68,7 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
int64 tid, const string& category,
const string& name, Json::Value args) {
Json::Value event = CreateEvent("X", category, name, pid, tid, ts);
- event["dur"] = Json::Value(duration);
+ event["dur"] = Json::Int64(duration);
event["args"] = std::move(args);
metadata_.push_back(event);
}
@@ -76,14 +76,14 @@ void ChromeTraceFormatter::EmitRegion(int64 ts, int64 duration, int64 pid,
void ChromeTraceFormatter::EmitFlowStart(const string& name, int64 ts,
int64 pid, int64 tid, int64 flow_id) {
Json::Value event = CreateEvent("s", "DataFlow", name, pid, tid, ts);
- event["id"] = flow_id;
+ event["id"] = Json::Int64(flow_id);
events_.push_back(event);
}
void ChromeTraceFormatter::EmitFlowEnd(const string& name, int64 ts, int64 pid,
int64 tid, int64 flow_id) {
Json::Value event = CreateEvent("t", "DataFlow", name, pid, tid, ts);
- event["id"] = flow_id;
+ event["id"] = Json::Int64(flow_id);
events_.push_back(event);
}
@@ -93,7 +93,7 @@ void ChromeTraceFormatter::EmitCounter(
const std::map<int64, std::vector<string>>& tensor_mem) {
Json::Value event = CreateEvent("C", category, "Allocated Bytes", pid, 0, ts);
Json::Value args(Json::objectValue);
- args["Allocator Bytes in Use"] = Json::Value(bytes);
+ args["Allocator Bytes in Use"] = Json::Int64(bytes);
event["args"] = args;
events_.push_back(event);
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index d83215d5c2..5b6aa47b93 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -143,6 +143,10 @@ message GPUOptions {
// multiple processes are sharing a single GPU while individually using less
// than 1.0 per process memory fraction.
bool use_unified_memory = 2;
+
+ // If > 1, the number of device-to-device copy streams to create
+ // for each GPUDevice.
+ int32 num_dev_to_dev_copy_streams = 3;
}
// Everything inside experimental is subject to change and is not subject
@@ -490,5 +494,67 @@ message CallableOptions {
// in the callable.
repeated TensorConnection tensor_connection = 5;
- // Next: 6
+ // The Tensor objects fed in the callable and fetched from the callable
+ // are expected to be backed by host (CPU) memory by default.
+ //
+ // The options below allow changing that - feeding tensors backed by
+ // device memory, or returning tensors that are backed by device memory.
+ //
+ // The maps below map the name of a feed/fetch tensor (which appears in
+ // 'feed' or 'fetch' fields above), to the fully qualified name of the device
+ // owning the memory backing the contents of the tensor.
+ //
+ // For example, creating a callable with the following options:
+ //
+ // CallableOptions {
+ // feed: "a:0"
+ // feed: "b:0"
+ //
+ // fetch: "x:0"
+ // fetch: "y:0"
+ //
+ // feed_devices: {
+ // "a:0": "/job:localhost/replica:0/task:0/device:GPU:0"
+ // }
+ //
+ // fetch_devices: {
+ // "y:0": "/job:localhost/replica:0/task:0/device:GPU:0"
+ // }
+ // }
+ //
+ // means that the Callable expects:
+ // - The first argument ("a:0") is a Tensor backed by GPU memory.
+ // - The second argument ("b:0") is a Tensor backed by host memory.
+ // and of its return values:
+ // - The first output ("x:0") will be backed by host memory.
+ // - The second output ("y:0") will be backed by GPU memory.
+ //
+ // FEEDS:
+ // It is the responsibility of the caller to ensure that the memory of the fed
+ // tensors will be correctly initialized and synchronized before it is
+ // accessed by operations executed during the call to Session::RunCallable().
+ //
+ // This is typically ensured by using the TensorFlow memory allocators
+ // (Device::GetAllocator()) to create the Tensor to be fed.
+ //
+ // Alternatively, for CUDA-enabled GPU devices, this typically means that the
+ // operation that produced the contents of the tensor has completed, i.e., the
+ // CUDA stream has been synchronized (e.g., via cuCtxSynchronize() or
+ // cuStreamSynchronize()).
+ map<string, string> feed_devices = 6;
+ map<string, string> fetch_devices = 7;
+
+ // By default, RunCallable() will synchronize the GPU stream before returning
+ // fetched tensors on a GPU device, to ensure that the values in those tensors
+ // have been produced. This simplifies interacting with the tensors, but
+ // potentially incurs a performance hit.
+ //
+ // If this options is set to true, the caller is responsible for ensuring
+ // that the values in the fetched tensors have been produced before they are
+ // used. The caller can do this by invoking `Device::Sync()` on the underlying
+ // device(s), or by feeding the tensors back to the same Session using
+ // `feed_devices` with the same corresponding device name.
+ bool fetch_skip_sync = 8;
+
+ // Next: 9
}
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 9a7d0edb35..5b05a1b3ee 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -7,6 +7,7 @@ import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/function.proto";
import "tensorflow/core/framework/versions.proto";
import "tensorflow/core/protobuf/tensorflow_server.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
message RemoteTensorHandle {
// The ID of the operation that produced this tensor.
@@ -45,6 +46,10 @@ message QueueItem {
}
}
+message QueueResponse {
+ repeated TensorShapeProto shape = 1;
+}
+
message CreateContextRequest {
// Identifies the full cluster, and this particular worker's position within.
ServerDef server_def = 1;
@@ -60,6 +65,11 @@ message CreateContextRequest {
// This is the version for all the ops that will be enqueued by the client.
VersionDef version_def = 4;
+
+ // This ID will be used for all future communications. It is essential that
+ // both ends use this ID for selecting a rendezvous to get everything to
+ // match.
+ int64 rendezvous_id = 5;
}
message CreateContextResponse {
@@ -79,6 +89,8 @@ message EnqueueRequest {
}
message EnqueueResponse {
+ // A single operation response for every item in the request.
+ repeated QueueResponse queue_response = 1;
}
message WaitQueueDoneRequest {
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bbb25d6f3f..07f984ceea 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -80,6 +80,12 @@ message RewriterConfig {
// is once).
NumIterationsType meta_optimizer_iterations = 12;
+ // The minimum number of nodes in a graph to optimizer. For smaller graphs,
+ // optimization is skipped.
+ // 0 means the system picks an appropriate number.
+ // < 0 means do not skip optimization.
+ int32 min_graph_nodes = 17;
+
enum MemOptType {
// The default setting (SCHEDULING and SWAPPING HEURISTICS only)
DEFAULT_MEM_OPT = 0;
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 90c3fed2e8..8c24076aa9 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -184,16 +184,65 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
return true;
}
+namespace {
+
+void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
+ DeviceNameUtils::ParsedName* parsed_name) {
+ if (!parsed_name->has_job) {
+ parsed_name->job = parsed_basename.job;
+ parsed_name->has_job = true;
+ }
+ if (!parsed_name->has_replica) {
+ parsed_name->replica = parsed_basename.replica;
+ parsed_name->has_replica = true;
+ }
+ if (!parsed_name->has_task) {
+ parsed_name->task = parsed_basename.task;
+ parsed_name->has_task = true;
+ }
+ if (!parsed_name->has_type) {
+ parsed_name->type = parsed_basename.type;
+ parsed_name->has_type = true;
+ }
+ if (!parsed_name->has_id) {
+ parsed_name->id = parsed_basename.id;
+ parsed_name->has_id = true;
+ }
+}
+
+} // namespace
+
/* static */
-string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) {
+Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
+ StringPiece basename,
+ string* canonical_name) {
+ *canonical_name = "";
+ ParsedName parsed_basename;
+ if (!ParseFullName(basename, &parsed_basename)) {
+ return errors::InvalidArgument("Could not parse basename: ", basename,
+ " into a device specification.");
+ }
+ if (!(parsed_basename.has_job && parsed_basename.has_replica &&
+ parsed_basename.has_task && parsed_basename.has_type &&
+ parsed_basename.has_id)) {
+ return errors::InvalidArgument("Basename: ", basename,
+ " should be fully "
+ "specified.");
+ }
ParsedName parsed_name;
if (ParseLocalName(fullname, &parsed_name)) {
- return ParsedNameToString(parsed_name);
+ CompleteName(parsed_basename, &parsed_name);
+ *canonical_name = ParsedNameToString(parsed_name);
+ return Status::OK();
}
if (ParseFullName(fullname, &parsed_name)) {
- return ParsedNameToString(parsed_name);
+ CompleteName(parsed_basename, &parsed_name);
+ *canonical_name = ParsedNameToString(parsed_name);
+ return Status::OK();
}
- return "";
+ return errors::InvalidArgument("Could not parse ", fullname,
+ " into a device "
+ "specification.");
}
/* static */
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 0ae28df997..4071a70836 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -88,10 +88,14 @@ class DeviceNameUtils {
// Parses "fullname" into "*parsed". Returns true iff succeeds.
static bool ParseFullName(StringPiece fullname, ParsedName* parsed);
- // Canonicalizes "fullname". Accepts both legacy, newer and local versions of
- // the device spec. Returns the newer version of the device spec. If we were
- // unable to interpret / parse "fullname" returns "".
- static string CanonicalizeDeviceName(StringPiece fullname);
+ // Canonicalizes "fullname" into "*canonical_name". Uses a fully specified
+ // basename to fill in fields that are missing. Accepts both legacy, newer
+ // and local versions of the device spec. Returns the newer version of the
+ // device spec. If we were unable to interpret / parse "fullname" returns
+ // an error and *canonical_name is set to "".
+ static Status CanonicalizeDeviceName(StringPiece fullname,
+ StringPiece basename,
+ string* canonical_name);
// Returns true if "name" specifies any non-trivial constraint on the device.
static bool HasSomeDetails(const ParsedName& name) {
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index ff9c108f10..dafb3b20b9 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -467,18 +467,41 @@ TEST(DeviceNameUtilsTest, GetNamesForDeviceMappings) {
}
TEST(DeviceNameUtilsTest, CanonicalizeDeviceName) {
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/replica:10/task:0/device:CPU:1"));
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica:10/device:CPU:1"));
- EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1",
- DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica:10/cpu:1"));
- EXPECT_EQ("/device:CPU:0", DeviceNameUtils::CanonicalizeDeviceName("CPU:0"));
- EXPECT_EQ("", DeviceNameUtils::CanonicalizeDeviceName(
- "/job:foo/task:0/replica/cpu:1"));
+ string canonical_name;
+ {
+ // Good basename.
+ string basename = "/job:foo/replica:10/task:0/device:CPU:0";
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/replica:10/task:0/device:CPU:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica:10/device:CPU:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica:10/cpu:1", basename, &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:1", canonical_name);
+ TF_EXPECT_OK(DeviceNameUtils::CanonicalizeDeviceName("CPU:0", basename,
+ &canonical_name));
+ EXPECT_EQ("/job:foo/replica:10/task:0/device:CPU:0", canonical_name);
+ Status s = DeviceNameUtils::CanonicalizeDeviceName(
+ "/job:foo/task:0/replica/cpu:1", basename, &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ }
+
+ {
+ // Try out malformed basenames.
+ string fullname = "/device:CPU:0";
+
+ Status s = DeviceNameUtils::CanonicalizeDeviceName(
+ fullname, "/device:CPU:0", &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ s = DeviceNameUtils::CanonicalizeDeviceName(
+ fullname, "/job:foo/task:0/replica/cpu:1", &canonical_name);
+ EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
+ EXPECT_EQ("", canonical_name);
+ }
}
static void BM_ParseFullName(int iters) {
diff --git a/tensorflow/core/util/exec_on_stall.h b/tensorflow/core/util/exec_on_stall.h
new file mode 100644
index 0000000000..5c8f9d2324
--- /dev/null
+++ b/tensorflow/core/util/exec_on_stall.h
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_
+#define TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_
+
+#include <functional>
+
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// An object that executes a particular function only if it
+// is not deleted within the allotted number of seconds.
+//
+// This can be useful in diagnosing deadlocks, stalls and memory leaks
+// without logging too agressively.
+class ExecuteOnStall {
+ public:
+ // delay_secs: If the object still exists after this many seconds,
+ // execute f.
+ // f: The function to be executed, for example a detailed log of the
+ // the state of an object to which this is attached.
+ // poll_microseconds: The spawned thread will wake and test whether
+ // the destructor has been invoked this frequently.
+ ExecuteOnStall(int delay_secs, std::function<void()> f,
+ int32 poll_microseconds = 100)
+ : disabled_(false),
+ joined_(false),
+ env_(Env::Default()),
+ f_(f),
+ poll_microseconds_(poll_microseconds) {
+ deadline_ = env_->NowMicros() + 1000000 * delay_secs;
+ env_->SchedClosure([this]() {
+ while (env_->NowMicros() < deadline_) {
+ {
+ mutex_lock l(mu_);
+ if (disabled_) {
+ break;
+ }
+ }
+ env_->SleepForMicroseconds(poll_microseconds_);
+ }
+ {
+ mutex_lock l(mu_);
+ if (!disabled_) {
+ f_();
+ }
+ joined_ = true;
+ cond_var_.notify_all();
+ }
+ });
+ }
+
+ ~ExecuteOnStall() {
+ // Wait for spawned thread to terminate.
+ mutex_lock l(mu_);
+ disabled_ = true;
+ if (!joined_) {
+ cond_var_.wait(l);
+ }
+ }
+
+ private:
+ mutex mu_;
+ condition_variable cond_var_;
+ bool disabled_ GUARDED_BY(mu_);
+ bool joined_ GUARDED_BY(mu_);
+ Env* env_;
+ std::function<void()> f_;
+ int64 deadline_;
+ int32 poll_microseconds_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_UTIL_EXEC_ON_STALL_H_
diff --git a/tensorflow/core/util/exec_on_stall_test.cc b/tensorflow/core/util/exec_on_stall_test.cc
new file mode 100644
index 0000000000..42e66a7e84
--- /dev/null
+++ b/tensorflow/core/util/exec_on_stall_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/exec_on_stall.h"
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+struct Chunk {
+ std::unique_ptr<ExecuteOnStall> stall_closure;
+};
+
+Chunk* NewChunk(int stall_seconds, std::function<void()> f) {
+ Chunk* c = new Chunk;
+ c->stall_closure.reset(new ExecuteOnStall(stall_seconds, std::move(f)));
+ return c;
+}
+
+TEST(ExecuteOnStallTest, BothWays) {
+ mutex mu;
+ bool a_triggered(false);
+ bool b_triggered(false);
+ Chunk* a = NewChunk(1, [&mu, &a_triggered]() {
+ mutex_lock l(mu);
+ a_triggered = true;
+ });
+ Chunk* b = NewChunk(1, [&mu, &b_triggered]() {
+ mutex_lock l(mu);
+ b_triggered = true;
+ });
+ delete a;
+ Env::Default()->SleepForMicroseconds(2000000);
+ {
+ mutex_lock l(mu);
+ EXPECT_FALSE(a_triggered);
+ EXPECT_TRUE(b_triggered);
+ }
+ delete b;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 230b4278ca..bb447e0393 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -22,10 +22,13 @@ limitations under the License.
#include <unordered_map>
#include <utility>
+#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
#include "mkl_trans.h"
+#endif
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -39,6 +42,7 @@ limitations under the License.
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
+#include "tensorflow/core/lib/core/stringpiece.h"
using mkldnn::engine;
using mkldnn::memory;
@@ -51,11 +55,12 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-// The file contains a number of utility classes and functions used by MKL
-// enabled kernels
namespace tensorflow {
+// The file contains a number of utility classes and functions used by MKL
+// enabled kernels
+
// This class encapsulates all the meta data that is associated with an MKL
// tensor. A tensor is an MKL tensor if it was created as the result of an
// MKL operation, and did not go through a conversion to a standard
@@ -71,6 +76,7 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+#ifdef INTEL_MKL_ML
class MklShape {
public:
MklShape() {}
@@ -331,7 +337,7 @@ class MklShape {
nullptr; // TF dimension corresponding to this MKL dimension
};
-#ifndef INTEL_MKL_ML
+#else
// Forward decl
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
@@ -664,12 +670,14 @@ class MklDnnShape {
// List of MklShape objects. Used in Concat/Split layers.
-typedef std::vector<MklShape> MklShapeList;
#ifndef INTEL_MKL_ML
typedef std::vector<MklDnnShape> MklDnnShapeList;
+#else
+typedef std::vector<MklShape> MklShapeList;
#endif
+#ifdef INTEL_MKL_ML
// Check if all tensors specified by MklShapes are MKL tensors.
inline bool AreAllMklTensors(const MklShapeList& shapes) {
for (auto& s : shapes) {
@@ -680,7 +688,6 @@ inline bool AreAllMklTensors(const MklShapeList& shapes) {
return true;
}
-#ifdef INTEL_MKL_ML
template <typename T>
inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
const MklShape& mkl_shape) {
@@ -753,6 +760,7 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
#endif
// Get the MKL shape from the second string tensor
+#ifdef INTEL_MKL_ML
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
@@ -763,8 +771,7 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
.size() *
sizeof(uint8));
}
-
-#ifndef INTEL_MKL_ML
+#else
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
mklshape->DeSerializeMklDnnShape(
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
@@ -838,6 +845,7 @@ inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
}
#endif
+#ifdef INTEL_MKL_ML
// Allocate the second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -853,7 +861,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
-#ifndef INTEL_MKL_ML
+#else
// Allocate the second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -870,6 +878,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
}
#endif
+#ifdef INTEL_MKL_ML
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -890,7 +899,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
-#ifndef INTEL_MKL_ML
+#else
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -925,8 +934,7 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
tf_shape, tensor_out));
*buf_out = static_cast<void*>(tensor_out->flat<T>().data());
}
-#endif
-
+#else
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
dnnLayout_t lt_buff, void** buf_out) {
TensorShape tf_shape;
@@ -940,6 +948,7 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
*buf_out = static_cast<void*>(tensor_out->flat<float>().data());
}
+#endif
template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
TensorShape tf_shape) {
@@ -963,6 +972,7 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
}
}
+#ifdef INTEL_MKL_ML
inline void MklSizesToTFSizes(OpKernelContext* context,
TensorFormat data_format_,
const MklShape& mkl_shape,
@@ -988,6 +998,7 @@ inline void MklSizesToTFSizes(OpKernelContext* context,
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
}
+#endif
inline int32 GetMklTensorDimIndex(char dimension) {
switch (dimension) {
@@ -1005,12 +1016,14 @@ inline int32 GetMklTensorDimIndex(char dimension) {
}
}
+#ifdef INTEL_MKL_ML
inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
int index = GetMklTensorDimIndex(dimension);
CHECK(index >= 0 && index < mkl_shape.GetDimension())
<< "Invalid index from the dimension: " << index << ", " << dimension;
return mkl_shape.dim_size(index);
}
+#endif
inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
int idx_out) {
@@ -1130,6 +1143,14 @@ inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
}
#ifndef INTEL_MKL_ML
+// Set a dummy MKLDNN shape (called when the output is in TF format)
+inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
+ uint32 idx_data_out) {
+ MklDnnShape mkl_shape_output;
+ mkl_shape_output.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
+}
+
inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
int idx_in, int idx_out,
const MklDnnShape& mkl_shape) {
@@ -1165,6 +1186,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
}
}
+#ifdef INTEL_MKL_ML
// Set a dummy MKL shape (called when the output is in TF format)
inline void SetDummyMklShapeOutput(OpKernelContext* context,
uint32 idx_data_out) {
@@ -1172,8 +1194,6 @@ inline void SetDummyMklShapeOutput(OpKernelContext* context,
mkl_shape_output.SetMklTensor(false);
AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
}
-
-#ifdef INTEL_MKL_ML
// We don't need these functions in MKLDNN. We have defined equality operator
// on MklDnnShape class directly.
@@ -1243,7 +1263,6 @@ inline bool MklCompareShapes(const TensorShape* input_shape_0,
return true;
}
-#endif
// These functions do not compile with MKL-DNN since mkl.h is missing.
// We may need to remove them later.
@@ -1281,6 +1300,7 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
}
}
+#endif
// -------------------------------------------------------------------
#ifndef INTEL_MKL_ML
@@ -1467,6 +1487,8 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
return memory::desc(md);
}
+template <typename T>
+inline primitive FindOrCreateReorder(const memory* from, const memory* to);
/*
* Class to represent all the resources corresponding to a tensor in TensorFlow
* that are required to execute an operation (such as Convolution).
@@ -1713,6 +1735,24 @@ class MklDnnData {
return false;
}
+ /// TODO: this is a faster path with reorder primitive cache compared with
+ /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
+ /// slow path in the future
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
+ CHECK_NOTNULL(user_memory_);
+ if (IsReorderNeeded(op_pd)) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ // primitive reuse don't allow two same reorder prim in
+ // one stream, so submit it immediately
+ reorder_memory_ = new memory(op_pd);
+ std::vector<primitive> net;
+ net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ return true;
+ }
+ return false;
+ }
+
/// Overloaded version of above function that accepts memory buffer
/// where output of reorder needs to be stored.
///
@@ -1738,6 +1778,26 @@ class MklDnnData {
return false;
}
+ /// TODO: this is a faster path with reorder primitive cache compared with
+ /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
+ /// slow path in the future
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ void* reorder_data_handle) {
+ CHECK_NOTNULL(reorder_data_handle);
+ CHECK_NOTNULL(user_memory_);
+ if (IsReorderNeeded(op_pd)) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ // primitive reuse don't allow two same reorder prim in
+ // one stream, so submit it immediately
+ std::vector<primitive> net;
+ reorder_memory_ = new memory(op_pd, reorder_data_handle);
+ net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ return true;
+ }
+ return false;
+ }
+
/// Another overloaded version of CheckReorderToOpMem that accepts Tensor
/// where output of reorder needs to be stored.
///
@@ -1756,6 +1816,15 @@ class MklDnnData {
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
}
+ /// TODO: this is a faster path with reorder primitive cache compared with
+ /// CheckReorderToOpMem(..., std::vector<primitive>* net), will remove
+ /// slow path in the future
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ Tensor* reorder_tensor) {
+ CHECK_NOTNULL(reorder_tensor);
+ return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
+ }
+
/// Function to handle output reorder
///
/// This function performs very similar functionality as input reordering
@@ -1792,13 +1861,27 @@ class MklDnnData {
CHECK_NOTNULL(reorder_memory_);
net->push_back(CreateReorder(reorder_memory_, user_memory_));
}
+
+ /// TODO: this is a faster path with reorder primitive cache compared with
+ /// InsertReorderToUserMem(std::vector<primitive>* net), will remove
+ /// slow path in the future
+ inline void InsertReorderToUserMem() {
+ CHECK_NOTNULL(user_memory_);
+ CHECK_NOTNULL(reorder_memory_);
+ // primitive reuse don't allow two same reorder prim in
+ // one stream, so submit it immediately
+ std::vector<primitive> net;
+ net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
};
-/// Base class for operations with reuse of DNN primitives
+/// Base class for operations with reuse of primitives
///
-class DnnOp {
+class MklPrimitive {
public:
- virtual ~DnnOp() {}
+ virtual ~MklPrimitive() {}
// Dummy data. Its size, hard-coded as 256 here, does
// not matter since MKL should never operate on this buffer.
@@ -1806,33 +1889,33 @@ class DnnOp {
};
const mkldnn::memory::dims NONE_DIMS = {};
-// This constant is used to declare dummy buffer (size), for MKL primitives
+
template <typename T>
-class DnnOpFactory {
+class MklPrimitiveFactory {
public:
- DnnOpFactory() {}
- ~DnnOpFactory() {}
+ MklPrimitiveFactory() {}
+ ~MklPrimitiveFactory() {}
- DnnOp* GetOp(const std::string& key) {
- auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key);
- if (stream_iter == DnnOpFactory<T>::GetHashMap().end()) {
+ MklPrimitive* GetOp(const std::string& key) {
+ auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
+ if (stream_iter == MklPrimitiveFactory<T>::GetHashMap().end()) {
return nullptr;
} else {
return stream_iter->second;
}
}
- void SetOp(const std::string& key, DnnOp* op) {
- auto stream_iter = DnnOpFactory<T>::GetHashMap().find(key);
+ void SetOp(const std::string& key, MklPrimitive* op) {
+ auto stream_iter = MklPrimitiveFactory<T>::GetHashMap().find(key);
- CHECK(stream_iter == DnnOpFactory<T>::GetHashMap().end());
+ CHECK(stream_iter == MklPrimitiveFactory<T>::GetHashMap().end());
- DnnOpFactory<T>::GetHashMap()[key] = op;
+ MklPrimitiveFactory<T>::GetHashMap()[key] = op;
}
private:
- static inline std::unordered_map<std::string, DnnOp*> &GetHashMap() {
- static thread_local std::unordered_map<std::string, DnnOp*> map_;
+ static inline std::unordered_map<std::string, MklPrimitive*>& GetHashMap() {
+ static thread_local std::unordered_map<std::string, MklPrimitive*> map_;
return map_;
}
};
@@ -1846,10 +1929,7 @@ class FactoryKeyCreator {
~FactoryKeyCreator() {}
- void AddAsKey(const string &str) {
- auto buffer = reinterpret_cast<const char *>(str.c_str());
- Append(buffer, str.length());
- }
+ void AddAsKey(const string& str) { Append(str); }
void AddAsKey(const mkldnn::memory::dims &dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
@@ -1860,7 +1940,7 @@ class FactoryKeyCreator {
template <typename T>
void AddAsKey(const T data) {
auto buffer = reinterpret_cast<const char *>(&data);
- Append(buffer, sizeof(T));
+ Append(StringPiece(buffer, sizeof(T)));
}
std::string GetKey() {
@@ -1871,12 +1951,115 @@ class FactoryKeyCreator {
string key_;
const char delimiter = 'x';
const int kMaxKeyLength = 256;
- void Append(const char* data, int len) {
- key_.append(data, len);
+ void Append(StringPiece s) {
+ key_.append(s.ToString());
key_.append(1, delimiter);
}
};
+class MklReorderPrimitive : public MklPrimitive {
+ public:
+ explicit MklReorderPrimitive(const memory* from, const memory* to) {
+ Setup(from, to);
+ }
+ ~MklReorderPrimitive() {}
+
+ std::shared_ptr<primitive> GetPrimitive() {
+ return context_.reorder_prim;
+ }
+
+ void SetMemory(const memory* from, const memory* to) {
+ context_.src_mem->set_data_handle(from->get_data_handle());
+ context_.dst_mem->set_data_handle(to->get_data_handle());
+ }
+
+ private:
+ struct ReorderContext {
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+ std::shared_ptr<primitive> reorder_prim;
+ ReorderContext():
+ src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {
+ }
+ } context_;
+
+ engine cpu_engine_ = engine(engine::cpu, 0);
+
+ void Setup(const memory* from, const memory* to) {
+ context_.src_mem.reset(new memory(
+ {from->get_primitive_desc().desc(), cpu_engine_}, DummyData));
+ context_.dst_mem.reset(new memory(
+ {to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
+ context_.reorder_prim = std::make_shared<mkldnn::reorder>(
+ reorder(*context_.src_mem, *context_.dst_mem));
+ }
+};
+
+template <typename T>
+class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklReorderPrimitive* Get(const memory* from,
+ const memory* to) {
+ auto reorderPrim = static_cast<MklReorderPrimitive*>(
+ MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
+ if (reorderPrim == nullptr) {
+ reorderPrim = new MklReorderPrimitive(from, to);
+ MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(
+ from, to, reorderPrim);
+ }
+ reorderPrim->SetMemory(from, to);
+ return reorderPrim;
+ }
+
+ static MklReorderPrimitiveFactory & GetInstance() {
+ static MklReorderPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklReorderPrimitiveFactory() {};
+ ~MklReorderPrimitiveFactory() {};
+
+ static std::string CreateKey(const memory* from, const memory* to) {
+ std::string prefix = "reorder";
+ FactoryKeyCreator key_creator;
+ auto const &from_desc = from->get_primitive_desc().desc().data;
+ auto const &to_desc = to->get_primitive_desc().desc().data;
+ memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
+ memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(static_cast<int>(from_desc.format));
+ key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
+ key_creator.AddAsKey(from_dims);
+ key_creator.AddAsKey(static_cast<int>(to_desc.format));
+ key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
+ key_creator.AddAsKey(to_dims);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetReorder(const memory* from, const memory* to) {
+ std::string key = CreateKey(from, to);
+ return this->GetOp(key);
+ }
+
+ void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
+ std::string key = CreateKey(from, to);
+ this->SetOp(key, op);
+ }
+};
+
+ /// Fuction to find(or create) a reorder from memory pointed by from to memory pointed
+ /// by to, it will created primitive or get primitive from pool if it is cached.
+ /// Returns the primitive.
+ template <typename T>
+ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
+ CHECK_NOTNULL(from);
+ CHECK_NOTNULL(to);
+ MklReorderPrimitive *reorder_prim =
+ MklReorderPrimitiveFactory<T>::Get(from, to);
+ return *reorder_prim->GetPrimitive();
+ }
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index ee43945a39..90672a10a8 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -123,6 +123,7 @@ TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
#undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
#undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index c0fce207e7..fb70318078 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -78,7 +78,10 @@ class GroupIterable {
typedef gtl::ArraySlice<int64> VarDimArray;
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
- : ix_(ix), vals_(vals), dims_(dims), group_dims_(group_dims) {}
+ : ix_(ix),
+ vals_(vals),
+ dims_(dims),
+ group_dims_(group_dims.begin(), group_dims.end()) {}
class IteratorStep;
@@ -127,7 +130,7 @@ class GroupIterable {
Tensor ix_;
Tensor vals_;
const int dims_;
- const VarDimArray group_dims_;
+ const gtl::InlinedVector<int64, 8> group_dims_;
};
// Implementation of Group::values<T>()
diff --git a/tensorflow/core/util/stat_summarizer.cc b/tensorflow/core/util/stat_summarizer.cc
index 42a4801dcb..a5c1fda102 100644
--- a/tensorflow/core/util/stat_summarizer.cc
+++ b/tensorflow/core/util/stat_summarizer.cc
@@ -78,6 +78,14 @@ void StatSummarizer::Validate(const std::vector<TensorDescription>* outputs,
}
}
+void StatSummarizer::PrintStepStats() const {
+ string output = GetOutputString();
+ std::istringstream iss(output);
+ for (std::string line; std::getline(iss, line);) {
+ LOG(INFO) << line;
+ }
+}
+
namespace {
std::string OpType(const DeviceStepStats& ds, const NodeExecStats& ns) {
// There is no published specification of how DeviceStats and NodeStats
diff --git a/tensorflow/core/util/stat_summarizer.h b/tensorflow/core/util/stat_summarizer.h
index 173ed5cebc..7e6d6f6372 100644
--- a/tensorflow/core/util/stat_summarizer.h
+++ b/tensorflow/core/util/stat_summarizer.h
@@ -68,7 +68,7 @@ class StatSummarizer {
}
// Prints the string returned by GetOutputString().
- void PrintStepStats() const { stats_calculator_->PrintStepStats(); }
+ void PrintStepStats() const;
// Prints the output tensor sizes and types for each node.
void PrintOutputs() const;
diff --git a/tensorflow/core/util/stats_calculator.cc b/tensorflow/core/util/stats_calculator.cc
index 20353ec76e..c4befbdb84 100644
--- a/tensorflow/core/util/stats_calculator.cc
+++ b/tensorflow/core/util/stats_calculator.cc
@@ -21,8 +21,6 @@ limitations under the License.
#include <sstream>
#include <string>
-#include "tensorflow/core/platform/logging.h"
-
namespace tensorflow {
StatsCalculator::StatsCalculator(const StatSummarizerOptions& options)
@@ -93,7 +91,7 @@ std::string StatsCalculator::ColumnString(const Detail& detail,
void StatsCalculator::OrderNodesByMetric(
SortingMetric metric, std::vector<const Detail*>* details) const {
- std::priority_queue<std::pair<string, const Detail*>> sorted_list;
+ std::priority_queue<std::pair<std::string, const Detail*>> sorted_list;
const int num_nodes = details_.size();
for (const auto& det : details_) {
@@ -142,7 +140,7 @@ void StatsCalculator::ComputeStatsByType(
int64_t run_count = run_total_us_.count();
for (const auto& det : details_) {
- const string node_name = det.first;
+ const std::string node_name = det.first;
const Detail& detail = det.second;
int64_t curr_time_val =
@@ -151,7 +149,7 @@ void StatsCalculator::ComputeStatsByType(
int64_t curr_memory_val = detail.mem_used.newest();
- const string& node_type = detail.type;
+ const std::string& node_type = detail.type;
(*node_type_map_count)[node_type] += 1;
(*node_type_map_time)[node_type] += curr_time_val;
@@ -163,12 +161,12 @@ void StatsCalculator::ComputeStatsByType(
std::string StatsCalculator::GetStatsByNodeType() const {
std::stringstream stream;
+ stream << "Number of nodes executed: " << details_.size() << std::endl;
+
stream << "============================== Summary by node type "
"=============================="
<< std::endl;
- LOG(INFO) << "Number of nodes executed: " << details_.size();
-
std::map<std::string, int64_t> node_type_map_count;
std::map<std::string, int64_t> node_type_map_time;
std::map<std::string, int64_t> node_type_map_memory;
@@ -180,11 +178,12 @@ std::string StatsCalculator::GetStatsByNodeType() const {
&accumulated_us);
// Sort them.
- std::priority_queue<std::pair<int64_t, std::pair<string, int64_t>>> timings;
+ std::priority_queue<std::pair<int64_t, std::pair<std::string, int64_t>>>
+ timings;
for (const auto& node_type : node_type_map_time) {
const int64_t mem_used = node_type_map_memory[node_type.first];
timings.emplace(node_type.second,
- std::pair<string, int64_t>(node_type.first, mem_used));
+ std::pair<std::string, int64_t>(node_type.first, mem_used));
}
InitField(stream, 24) << "[Node type]";
@@ -201,7 +200,7 @@ std::string StatsCalculator::GetStatsByNodeType() const {
auto entry = timings.top();
timings.pop();
- const string node_type = entry.second.first;
+ const std::string node_type = entry.second.first;
const float memory = entry.second.second / 1000.0f;
const int64_t node_type_total_us = entry.first;
@@ -273,14 +272,6 @@ std::string StatsCalculator::GetOutputString() const {
return stream.str();
}
-void StatsCalculator::PrintStepStats() const {
- string output = GetOutputString();
- std::istringstream iss(output);
- for (std::string line; std::getline(iss, line);) {
- LOG(INFO) << line;
- }
-}
-
void StatsCalculator::UpdateDetails(
const std::map<std::string, Detail>& details) {
details_.insert(details.begin(), details.end());
diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h
index a1033465fb..39cef816f1 100644
--- a/tensorflow/core/util/stats_calculator.h
+++ b/tensorflow/core/util/stats_calculator.h
@@ -127,9 +127,6 @@ class StatsCalculator {
std::string GetShortSummary() const;
- // Prints the string returned by GetOutputString().
- void PrintStepStats() const;
-
void ComputeStatsByType(
std::map<std::string, int64_t>* node_type_map_count,
std::map<std::string, int64_t>* node_type_map_time,
diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/util/status_util.h
new file mode 100644
index 0000000000..ea92f61dce
--- /dev/null
+++ b/tensorflow/core/util/status_util.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+// Creates a tag to be used in an exception error message. This can be parsed by
+// the Python layer and replaced with information about the node.
+//
+// For example, error_format_tag(node, "${file}") returns
+// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as
+// e.g. "file/where/node/was/created.py".
+inline string error_format_tag(const Node& node, const string& format) {
+ return strings::StrCat("^^node:", node.name(), ":", format, "^^");
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/util/status_util_test.cc
new file mode 100644
index 0000000000..1f06004db2
--- /dev/null
+++ b/tensorflow/core/util/status_util_test.cc
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/status_util.h"
+
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(TestStatusUtil, ErrorFormatTagForNode) {
+ Graph graph(OpRegistry::Global());
+ Node* node;
+ TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node));
+ EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^");
+ EXPECT_EQ(error_format_tag(*node, "${file}:${line}"),
+ "^^node:Foo:${file}:${line}^^");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index d4311d1ab0..33ab87aa78 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -18,7 +18,7 @@ limitations under the License.
namespace tensorflow {
string GetConvnetDataFormatAttrString() {
- return "data_format: { 'NHWC', 'NCHW' } = 'NHWC' ";
+ return "data_format: { 'NHWC', 'NCHW', 'HWNC', 'HWCN' } = 'NHWC' ";
}
string GetConvnet3dDataFormatAttrString() {
@@ -43,6 +43,10 @@ string ToString(TensorFormat format) {
return "NCHW_VECT_C";
case FORMAT_NHWC_VECT_W:
return "NHWC_VECT_W";
+ case FORMAT_HWNC:
+ return "HWNC";
+ case FORMAT_HWCN:
+ return "HWCN";
default:
LOG(FATAL) << "Invalid Format: " << static_cast<int32>(format);
return "INVALID_FORMAT";
@@ -80,6 +84,14 @@ bool FormatFromString(const string& format_str, TensorFormat* format) {
*format = FORMAT_NHWC_VECT_W;
return true;
}
+ if (format_str == "HWNC") {
+ *format = FORMAT_HWNC;
+ return true;
+ }
+ if (format_str == "HWCN") {
+ *format = FORMAT_HWCN;
+ return true;
+ }
return false;
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index d3d5602f92..918835e1fb 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -59,6 +59,12 @@ enum TensorFormat {
// In the future we may change the meaning of these enums to include vectors
// of other types such as int16x2, with op implementations automatically
// determining which format is implied based on the datatype.
+
+ // FORMAT_HWNC is for TPUs.
+ FORMAT_HWNC = 4,
+
+ // FORMAT_HWCN is for TPUs.
+ FORMAT_HWCN = 5,
};
// Tensor format for convolutional filters.
@@ -105,11 +111,11 @@ string ToString(FilterTensorFormat format);
inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
- return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW_VECT_C:
- return num_dims - 3; // Exclude N,C,VectDim.
case FORMAT_NHWC_VECT_W:
// Note: the VECT_W is not counted as an independent spatial dim here,
// since it just a component of the width dimension.
@@ -132,6 +138,8 @@ inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
switch (format) {
case FORMAT_NHWC:
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_spatial_dims + 2; // Include N,C.
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
@@ -158,6 +166,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
return 0;
+ case FORMAT_HWNC:
+ return num_dims - 2;
+ case FORMAT_HWCN:
+ return num_dims - 1;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -170,8 +182,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
+ case FORMAT_HWNC:
return num_dims - 1;
case FORMAT_NHWC_VECT_W:
+ case FORMAT_HWCN:
return num_dims - 2;
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
@@ -210,6 +224,9 @@ inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
return spatial_dim + 2;
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
+ return spatial_dim;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -310,6 +327,32 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
LOG(FATAL) << "Invalid dimension: " << dimension;
return -1; // Avoid compiler warning about missing return value
}
+ } else if (format == FORMAT_HWNC) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'N': return NUM_SPATIAL_DIMS;
+ case 'C': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ } else if (format == FORMAT_HWCN) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'C': return NUM_SPATIAL_DIMS;
+ case 'N': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
} else {
LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
return -1; // Avoid compiler warning about missing return value
diff --git a/tensorflow/core/util/tensor_format_test.cc b/tensorflow/core/util/tensor_format_test.cc
index 93902290eb..07cdce998a 100644
--- a/tensorflow/core/util/tensor_format_test.cc
+++ b/tensorflow/core/util/tensor_format_test.cc
@@ -26,10 +26,9 @@ namespace tensorflow {
{ val, #val }
std::pair<TensorFormat, const char*> test_data_formats[] = {
- EnumStringPair(FORMAT_NHWC),
- EnumStringPair(FORMAT_NCHW),
- EnumStringPair(FORMAT_NCHW_VECT_C),
- EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW),
+ EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN),
};
std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
@@ -85,6 +84,16 @@ struct DimMaps {
{ 0, 2, 3, 1, { 2, 3, -1 } },
{ 0, 3, 4, 1, { 2, 3, 4 } }
};
+ StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
+ { 1, -1, 0, 2, { 0, -1, -1 } },
+ { 2, 0, 1, 3, { 0, 1, -1 } },
+ { 3, 1, 2, 4, { 0, 1, 2 } }
+ };
+ StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
+ { 2, -1, 0, 1, { 0, -1, -1 } },
+ { 3, 0, 1, 2, { 0, 1, -1 } },
+ { 4, 1, 2, 3, { 0, 1, 2 } }
+ };
#undef StaCoExTensorDm
#define StaCoExFilterDm static constexpr FilterDimMap
// 'H', 'W', 'I', 'O' 0 1 2
@@ -108,8 +117,10 @@ GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
(format == FORMAT_NHWC ||
format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
(format == FORMAT_NCHW ||
- format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims]
- : DimMaps::kTdmInvalid;
+ format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
+ (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
+ (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
+ : DimMaps::kTdmInvalid;
}
inline constexpr const FilterDimMap&
@@ -126,6 +137,8 @@ GetFilterDimMap(const int num_spatial_dims,
constexpr TensorDimMap DimMaps::kTdmInvalid;
constexpr TensorDimMap DimMaps::kTdmNHWC[4];
constexpr TensorDimMap DimMaps::kTdmNCHW[4];
+constexpr TensorDimMap DimMaps::kTdmHWNC[4];
+constexpr TensorDimMap DimMaps::kTdmHWCN[4];
constexpr FilterDimMap DimMaps::kFdmInvalid;
constexpr FilterDimMap DimMaps::kFdmHWIO[4];
constexpr FilterDimMap DimMaps::kFdmOIHW[4];
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index 337af07b50..f4bd2950e9 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -20,12 +20,22 @@ limitations under the License.
namespace tensorflow {
+/* ABSL_CONST_INIT */ thread_local int per_thread_max_parallism = 1000000;
+
+void SetPerThreadMaxParallelism(int max_parallelism) {
+ CHECK_LE(0, max_parallelism);
+ per_thread_max_parallism = max_parallelism;
+}
+
+int GetPerThreadMaxParallelism() { return per_thread_max_parallism; }
+
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
int64 cost_per_unit, std::function<void(int64, int64)> work) {
CHECK_GE(total, 0);
if (total == 0) {
return;
}
+ max_parallelism = std::min(max_parallelism, GetPerThreadMaxParallelism());
if (max_parallelism <= 1) {
// Just inline the whole work since we only have 1 thread (core).
work(0, total);
@@ -35,6 +45,13 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
workers->ParallelFor(total, cost_per_unit, work);
return;
}
+ Sharder::Do(total, cost_per_unit, work,
+ [&workers](Sharder::Closure c) { workers->Schedule(c); },
+ max_parallelism);
+}
+
+void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
+ const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
// We shard [0, total) into "num_shards" shards.
// 1 <= num_shards <= num worker threads
@@ -63,7 +80,7 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
BlockingCounter counter(num_shards_used - 1);
for (int64 start = block_size; start < total; start += block_size) {
auto limit = std::min(start + block_size, total);
- workers->Schedule([&work, &counter, start, limit]() {
+ runner([&work, &counter, start, limit]() {
work(start, limit); // Compute the shard.
counter.DecrementCount(); // The shard is done.
});
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 451da98b6b..72ce493c1b 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -41,6 +41,12 @@ namespace tensorflow {
// work(start, limit) computes the work units from [start,
// limit), i.e., [start, limit) is a shard.
//
+// Too much parallelism can also cause excessive thread switches,
+// therefore, Shard() often limits the maximum parallelism. Each
+// caller can provide the 1st argument max_parallelism. A thread can
+// call SetMaxParallelism() so that all Shard() calls later limits the
+// thread parallelism.
+//
// REQUIRES: max_parallelism >= 0
// REQUIRES: workers != nullptr
// REQUIRES: total >= 0
@@ -48,6 +54,45 @@ namespace tensorflow {
void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
int64 cost_per_unit, std::function<void(int64, int64)> work);
+// Each thread has an associated option to express the desired maximum
+// parallelism. Its default is a very large quantity.
+//
+// Within TF runtime, per-thread max parallelism affects Shard() and
+// intra-op parallelism. E.g., if SetPerThreadMaxParallelism(1) is
+// arranged to be called by a tf_compute thread, Shard() calls and
+// eigen device assignment happens in that thread afterwards becomes
+// single-threaded.
+void SetPerThreadMaxParallelism(int max_parallelism);
+int GetPerThreadMaxParallelism();
+
+// Helper to set and unset per-thread max parallelism.
+class ScopedPerThreadMaxParallelism {
+ public:
+ ScopedPerThreadMaxParallelism(int max_parallelism)
+ : previous_(GetPerThreadMaxParallelism()) {
+ SetPerThreadMaxParallelism(max_parallelism);
+ }
+
+ ~ScopedPerThreadMaxParallelism() { SetPerThreadMaxParallelism(previous_); }
+
+ private:
+ int previous_ = -1;
+};
+
+// Implementation details for Shard().
+class Sharder {
+ public:
+ typedef std::function<void()> Closure;
+ typedef std::function<void(Closure)> Runner;
+ typedef std::function<void(int64, int64)> Work;
+
+ // Refers to Shard()'s comment for the meaning of total,
+ // cost_per_unit, work, max_parallelism. runner is an interface to
+ // schedule a closure. Shard() uses thread::ThreadPool instead.
+ static void Do(int64 total, int64 cost_per_unit, const Work& work,
+ const Runner& runner, int max_parallelism);
+};
+
} // end namespace tensorflow
#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index 0694566ad9..bc5a1d221f 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -28,6 +28,7 @@ namespace tensorflow {
namespace {
void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
+ int64 per_thread_max_parallelism,
thread::ThreadPool* threads) {
mutex mu;
int64 num_shards = 0;
@@ -46,9 +47,18 @@ void RunSharding(int64 num_workers, int64 total, int64 cost_per_unit,
work[start] = true;
}
});
- EXPECT_EQ(num_done_work, total);
LOG(INFO) << num_workers << " " << total << " " << cost_per_unit << " "
<< num_shards;
+ EXPECT_EQ(num_done_work, total);
+ if (std::min(num_workers, per_thread_max_parallelism) <
+ threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, we should expect that Shard() implementation do
+ // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+ // tends to over-shard.
+ EXPECT_LE(num_shards, 1 + per_thread_max_parallelism);
+ }
}
TEST(Shard, Basic) {
@@ -56,7 +66,10 @@ TEST(Shard, Basic) {
for (auto workers : {0, 1, 2, 3, 5, 7, 10, 11, 15, 100, 1000}) {
for (auto total : {0, 1, 7, 10, 64, 100, 256, 1000, 9999}) {
for (auto cost_per_unit : {0, 1, 11, 102, 1003, 10005, 1000007}) {
- RunSharding(workers, total, cost_per_unit, &threads);
+ for (auto maxp : {1, 2, 4, 8, 100}) {
+ ScopedPerThreadMaxParallelism s(maxp);
+ RunSharding(workers, total, cost_per_unit, maxp, &threads);
+ }
}
}
}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md b/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md
deleted file mode 100644
index 74fe4a323a..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.bayesflow.monte_carlo.md
+++ /dev/null
@@ -1,50 +0,0 @@
-# BayesFlow Monte Carlo (contrib)
-[TOC]
-
-Monte Carlo integration and helpers.
-
-## Background
-
-Monte Carlo integration refers to the practice of estimating an expectation with
-a sample mean. For example, given random variable Z in \\(R^k\\) with density `p`,
-the expectation of function `f` can be approximated like:
-
-$$E_p[f(Z)] = \int f(z) p(z) dz$$
-$$ ~ S_n
- := n^{-1} \sum_{i=1}^n f(z_i), z_i\ iid\ samples\ from\ p.$$
-
-If \\(E_p[|f(Z)|] < infinity\\), then \\(S_n\\) --> \\(E_p[f(Z)]\\) by the strong law of large
-numbers. If \\(E_p[f(Z)^2] < infinity\\), then \\(S_n\\) is asymptotically normal with
-variance \\(Var[f(Z)] / n\\).
-
-Practitioners of Bayesian statistics often find themselves wanting to estimate
-\\(E_p[f(Z)]\\) when the distribution `p` is known only up to a constant. For
-example, the joint distribution `p(z, x)` may be known, but the evidence
-\\(p(x) = \int p(z, x) dz\\) may be intractable. In that case, a parameterized
-distribution family \\(q_\lambda(z)\\) may be chosen, and the optimal \\(\lambda\\) is the
-one minimizing the KL divergence between \\(q_\lambda(z)\\) and
-\\(p(z | x)\\). We only know `p(z, x)`, but that is sufficient to find \\(\lambda\\).
-
-
-## Log-space evaluation and subtracting the maximum
-
-Care must be taken when the random variable lives in a high dimensional space.
-For example, the naive importance sample estimate \\(E_q[f(Z) p(Z) / q(Z)]\\)
-involves the ratio of two terms \\(p(Z) / q(Z)\\), each of which must have tails
-dropping off faster than \\(O(|z|^{-(k + 1)})\\) in order to have finite integral.
-This ratio would often be zero or infinity up to numerical precision.
-
-For that reason, we write
-
-$$Log E_q[ f(Z) p(Z) / q(Z) ]$$
-$$ = Log E_q[ \exp\{Log[f(Z)] + Log[p(Z)] - Log[q(Z)] - C\} ] + C,$$ where
-$$C := Max[ Log[f(Z)] + Log[p(Z)] - Log[q(Z)] ].$$
-
-The maximum value of the exponentiated term will be 0.0, and the expectation
-can be evaluated in a stable manner.
-
-## Ops
-
-* @{tf.contrib.bayesflow.monte_carlo.expectation}
-* @{tf.contrib.bayesflow.monte_carlo.expectation_importance_sampler}
-* @{tf.contrib.bayesflow.monte_carlo.expectation_importance_sampler_logspace}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
deleted file mode 100644
index e169897f31..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.bijectors.md
+++ /dev/null
@@ -1,32 +0,0 @@
-# Random variable transformations (contrib)
-[TOC]
-
-Bijector Ops.
-
-An API for invertible, differentiable transformations of random variables.
-
-## Background
-
-Differentiable, bijective transformations of continuous random variables alter
-the calculations made in the cumulative/probability distribution functions and
-sample function. This module provides a standard interface for making these
-manipulations.
-
-For more details and examples, see the `Bijector` docstring.
-
-To apply a `Bijector`, use `distributions.TransformedDistribution`.
-
-## Bijectors
-
-* @{tf.contrib.distributions.bijectors.Affine}
-* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
-* @{tf.contrib.distributions.bijectors.Bijector}
-* @{tf.contrib.distributions.bijectors.Chain}
-* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
-* @{tf.contrib.distributions.bijectors.Exp}
-* @{tf.contrib.distributions.bijectors.Identity}
-* @{tf.contrib.distributions.bijectors.Inline}
-* @{tf.contrib.distributions.bijectors.Invert}
-* @{tf.contrib.distributions.bijectors.PowerTransform}
-* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
-* @{tf.contrib.distributions.bijectors.Softplus}
diff --git a/tensorflow/docs_src/api_guides/python/contrib.distributions.md b/tensorflow/docs_src/api_guides/python/contrib.distributions.md
deleted file mode 100644
index 533d7dac13..0000000000
--- a/tensorflow/docs_src/api_guides/python/contrib.distributions.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Statistical Distributions (contrib)
-[TOC]
-
-Classes representing statistical distributions and ops for working with them.
-
-## Classes for statistical distributions
-
-Classes that represent batches of statistical distributions. Each class is
-initialized with parameters that define the distributions.
-
-## Base classes
-
-* @{tf.contrib.distributions.ReparameterizationType}
-* @{tf.contrib.distributions.Distribution}
-
-## Univariate (scalar) distributions
-
-* @{tf.contrib.distributions.Binomial}
-* @{tf.contrib.distributions.Bernoulli}
-* @{tf.contrib.distributions.Beta}
-* @{tf.contrib.distributions.Categorical}
-* @{tf.contrib.distributions.Chi2}
-* @{tf.contrib.distributions.Chi2WithAbsDf}
-* @{tf.contrib.distributions.Exponential}
-* @{tf.contrib.distributions.Gamma}
-* @{tf.contrib.distributions.InverseGamma}
-* @{tf.contrib.distributions.Laplace}
-* @{tf.contrib.distributions.LaplaceWithSoftplusScale}
-* @{tf.contrib.distributions.Normal}
-* @{tf.contrib.distributions.NormalWithSoftplusScale}
-* @{tf.contrib.distributions.Poisson}
-* @{tf.contrib.distributions.StudentT}
-* @{tf.contrib.distributions.StudentTWithAbsDfSoftplusScale}
-* @{tf.contrib.distributions.Uniform}
-
-## Multivariate distributions
-
-### Multivariate normal
-
-* @{tf.contrib.distributions.MultivariateNormalDiag}
-* @{tf.contrib.distributions.MultivariateNormalTriL}
-* @{tf.contrib.distributions.MultivariateNormalDiagPlusLowRank}
-* @{tf.contrib.distributions.MultivariateNormalDiagWithSoftplusScale}
-
-### Other multivariate distributions
-
-* @{tf.contrib.distributions.Dirichlet}
-* @{tf.contrib.distributions.DirichletMultinomial}
-* @{tf.contrib.distributions.Multinomial}
-* @{tf.contrib.distributions.WishartCholesky}
-* @{tf.contrib.distributions.WishartFull}
-
-### Multivariate Utilities
-
-* @{tf.contrib.distributions.matrix_diag_transform}
-
-## Transformed distributions
-
-* @{tf.contrib.distributions.TransformedDistribution}
-* @{tf.contrib.distributions.QuantizedDistribution}
-
-## Mixture Models
-
-* @{tf.contrib.distributions.Mixture}
-
-## Posterior inference with conjugate priors
-
-Functions that transform conjugate prior/likelihood pairs to distributions
-representing the posterior or posterior predictive.
-
-## Normal likelihood with conjugate prior
-
-* @{tf.contrib.distributions.normal_conjugates_known_scale_posterior}
-* @{tf.contrib.distributions.normal_conjugates_known_scale_predictive}
-
-## Kullback-Leibler Divergence
-
-* @{tf.contrib.distributions.kl_divergence}
-* @{tf.contrib.distributions.RegisterKL}
-
-## Utilities
-
-* @{tf.contrib.distributions.softplus_inverse}
diff --git a/tensorflow/docs_src/api_guides/python/spectral_ops.md b/tensorflow/docs_src/api_guides/python/spectral_ops.md
index 022c471ef1..dd13802f00 100644
--- a/tensorflow/docs_src/api_guides/python/spectral_ops.md
+++ b/tensorflow/docs_src/api_guides/python/spectral_ops.md
@@ -23,3 +23,4 @@ that you can use to transform Tensors of real and complex signals.
## Discrete Cosine Transforms
* @{tf.spectral.dct}
+* @{tf.spectral.idct}
diff --git a/tensorflow/docs_src/deploy/s3.md b/tensorflow/docs_src/deploy/s3.md
index 9ef9674338..7028249e94 100644
--- a/tensorflow/docs_src/deploy/s3.md
+++ b/tensorflow/docs_src/deploy/s3.md
@@ -90,4 +90,4 @@ S3 was invented by Amazon, but the S3 API has spread in popularity and has sever
* [Amazon S3](https://aws.amazon.com/s3/)
* [Google Storage](https://cloud.google.com/storage/docs/interoperability)
-* [Minio](https://www.minio.io/kubernetes.html)(Standalone mode only)
+* [Minio](https://www.minio.io/kubernetes.html)
diff --git a/tensorflow/docs_src/extend/index.md b/tensorflow/docs_src/extend/index.md
index 1ab0340ad9..d48340a777 100644
--- a/tensorflow/docs_src/extend/index.md
+++ b/tensorflow/docs_src/extend/index.md
@@ -17,7 +17,8 @@ TensorFlow:
Python is currently the only language supported by TensorFlow's API stability
promises. However, TensorFlow also provides functionality in C++, Go, Java and
-[JavaScript](https://js.tensorflow.org),
+[JavaScript](https://js.tensorflow.org) (incuding
+[Node.js](https://github.com/tensorflow/tfjs-node)),
plus community support for [Haskell](https://github.com/tensorflow/haskell) and
[Rust](https://github.com/tensorflow/rust). If you'd like to create or
develop TensorFlow features in a language other than these languages, read the
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index 003ca265fe..e98206eef9 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -421,7 +421,7 @@ class Model(tf.keras.Model):
super(Model, self).__init__()
self.W = tfe.Variable(5., name='weight')
self.B = tfe.Variable(10., name='bias')
- def predict(self, inputs):
+ def call(self, inputs):
return inputs * self.W + self.B
# A toy dataset of points around 3 * x + 2
@@ -432,7 +432,7 @@ training_outputs = training_inputs * 3 + 2 + noise
# The loss function to be optimized
def loss(model, inputs, targets):
- error = model.predict(inputs) - targets
+ error = model(inputs) - targets
return tf.reduce_mean(tf.square(error))
def grad(model, inputs, targets):
diff --git a/tensorflow/docs_src/guide/feature_columns.md b/tensorflow/docs_src/guide/feature_columns.md
index 1013ec910c..41080e050b 100644
--- a/tensorflow/docs_src/guide/feature_columns.md
+++ b/tensorflow/docs_src/guide/feature_columns.md
@@ -561,9 +561,9 @@ For more examples on feature columns, view the following:
* The @{$low_level_intro#feature_columns$Low Level Introduction} demonstrates how
experiment directly with `feature_columns` using TensorFlow's low level APIs.
-* The @{$wide$wide} and @{$wide_and_deep$Wide & Deep} Tutorials solve a
- binary classification problem using `feature_columns` on a variety of input
- data types.
+* The [Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
+ solves a binary classification problem using `feature_columns` on a variety of
+ input data types.
To learn more about embeddings, see the following:
diff --git a/tensorflow/docs_src/guide/graph_viz.md b/tensorflow/docs_src/guide/graph_viz.md
index f581ae56da..a8876da5a5 100644
--- a/tensorflow/docs_src/guide/graph_viz.md
+++ b/tensorflow/docs_src/guide/graph_viz.md
@@ -248,7 +248,8 @@ The images below show the CIFAR-10 model with tensor shape information:
Often it is useful to collect runtime metadata for a run, such as total memory
usage, total compute time, and tensor shapes for nodes. The code example below
is a snippet from the train and test section of a modification of the
-@{$layers$simple MNIST tutorial}, in which we have recorded summaries and
+[Estimators MNIST tutorial](../tutorials/estimators/cnn.md), in which we have
+recorded summaries and
runtime statistics. See the
@{$summaries_and_tensorboard#serializing-the-data$Summaries Tutorial}
for details on how to record summaries.
diff --git a/tensorflow/docs_src/guide/version_compat.md b/tensorflow/docs_src/guide/version_compat.md
index 72e427c5f8..d2e5e41190 100644
--- a/tensorflow/docs_src/guide/version_compat.md
+++ b/tensorflow/docs_src/guide/version_compat.md
@@ -301,8 +301,10 @@ existing producer scripts will not suddenly use the new functionality.
#### Change an op's functionality
1. Add a new similar op named `SomethingV2` or similar and go through the
- process of adding it and switching existing Python wrappers to use it, which
- may take three weeks if forward compatibility is desired.
+ process of adding it and switching existing Python wrappers to use it.
+ To ensure forward compatibility use the checks suggested in
+ [compat.py](https://www.tensorflow.org/code/tensorflow/python/compat/compat.py)
+ when changing the Python wrappers.
2. Remove the old op (Can only take place with a major version change due to
backward compatibility).
3. Increase `min_consumer` to rule out consumers with the old op, add back the
diff --git a/tensorflow/docs_src/javascript/index.md b/tensorflow/docs_src/javascript/index.md
deleted file mode 100644
index ad63eeb255..0000000000
--- a/tensorflow/docs_src/javascript/index.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# JavaScript
-
-You may develop TensorFlow programs in JavaScript, training and deploying
-models right in your browser. For details, see
-[js.tensorflow.org](https://js.tensorflow.org).
diff --git a/tensorflow/docs_src/javascript/leftnav_files b/tensorflow/docs_src/javascript/leftnav_files
deleted file mode 100644
index fc0ab8a543..0000000000
--- a/tensorflow/docs_src/javascript/leftnav_files
+++ /dev/null
@@ -1 +0,0 @@
-index.md
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 4c4f3f3934..68c427a316 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -2015,30 +2015,37 @@ two-operand version.
<b>`Sort(operand)`</b>
-Arguments | Type | Semantics
---------- | ------- | --------------------
-`operand` | `XlaOp` | The operand to sort.
-
-Sorts the elements in the operand in ascending order. The operand must be rank-1.
-If the operand's elements have floating point type, and the operand contains
-NaN elements, the order of elements in the output is implementation-defined.
+Arguments | Type | Semantics
+----------- | ------- | --------------------
+`operand` | `XlaOp` | The operand to sort.
+`dimension` | `int64` | The dimension along which to sort.
+
+Sorts the elements in the operand in ascending order along the provided
+dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0
+will sort each column independently, and a `dimension` value of 1 will sort each
+row independently. If the operand's elements have floating point type, and the
+operand contains NaN elements, the order of elements in the output is
+implementation-defined.
<b>`Sort(key, value)`</b>
Sorts both the key and the value operands. The keys are sorted as in the
single-operand version. The values are sorted according to the order of their
corresponding keys. For example, if the inputs are `keys = [3, 1]` and
-`values = [42, 50]`, then the output of the sort is the tuple `{[1, 3], [50, 42]}`.
+`values = [42, 50]`, then the output of the sort is the tuple
+`{[1, 3], [50, 42]}`.
+
The sort is not guaranteed to be stable, that is, if the keys array contains
duplicates, the order of their corresponding values may not be preserved.
-Arguments | Type | Semantics
---------- | ------- | -------------------
-`keys` | `XlaOp` | The sort keys.
-`values` | `XlaOp` | The values to sort.
+Arguments | Type | Semantics
+----------- | ------- | -------------------
+`keys` | `XlaOp` | The sort keys.
+`values` | `XlaOp` | The values to sort.
+`dimension` | `int64` | The dimension along which to sort.
-The `keys` and `values` operand must both be rank-1, and must have the same
-dimensions, but may have different element types.
+The `keys` and `values` must have the same dimensions, but may have different
+element types.
## Transpose
diff --git a/tensorflow/docs_src/tutorials/_index.yaml b/tensorflow/docs_src/tutorials/_index.yaml
index 6fc8155669..07d561b8a2 100644
--- a/tensorflow/docs_src/tutorials/_index.yaml
+++ b/tensorflow/docs_src/tutorials/_index.yaml
@@ -170,15 +170,16 @@ landing_page:
<div class="devsite-landing-row-item-description-content">
<p>
Estimators can train large models on multiple machines in a
- production environment. Read the
- <a href="/guide/estimators">Estimators guide</a> for details.
+ production environment. TensorFlow provides a collection of
+ pre-made Estimators to implement common ML algorithms. See the
+ <a href="/guide/estimators">Estimators guide</a>.
</p>
<ol style="padding-left: 20px;">
- <li><a href="/tutorials/images/layers">Build a Convolutional Neural Network using Estimators</a></li>
+ <li><a href="/guide/premade_estimators">Premade Estimators guide</a></li>
+ <li><a href="https://github.com/tensorflow/models/tree/master/official/wide_deep" class="external">Wide and deep learning with Estimators</a></li>
+ <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees" class="external">Boosted trees</a></li>
<li><a href="/hub/tutorials/text_classification_with_tf_hub">How to build a simple text classifier with TF-Hub</a></li>
- <li><a href="https://github.com/tensorflow/models/tree/master/official/boosted_trees">Classifying Higgs boson processes</a></li>
- <li><a href="/tutorials/representation/wide_and_deep">Wide and deep learning using Estimators</a></li>
- <li><a href="/tutorials/representation/linear">Large-scale linear models</a></li>
+ <li><a href="/tutorials/estimators/cnn">Build a Convolutional Neural Network using Estimators</a></li>
</ol>
</div>
<div class="devsite-landing-row-item-buttons">
diff --git a/tensorflow/docs_src/tutorials/_toc.yaml b/tensorflow/docs_src/tutorials/_toc.yaml
index d46d570a93..4db97e35fc 100644
--- a/tensorflow/docs_src/tutorials/_toc.yaml
+++ b/tensorflow/docs_src/tutorials/_toc.yaml
@@ -24,7 +24,7 @@ toc:
- title: Overview
path: /tutorials/eager/
- title: Eager execution
- path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_intro.ipynb
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
status: external
- title: Automatic differentiation
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -37,15 +37,27 @@ toc:
status: external
- title: "Custom training: walkthrough"
path: /tutorials/eager/custom_training_walkthrough
- - title: Neural machine translation
+ - title: Translation with attention
path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
status: external
-- title: Images
+- title: ML at production scale
style: accordion
section:
+ - title: Wide and deep learning
+ path: https://github.com/tensorflow/models/tree/master/official/wide_deep
+ status: external
+ - title: Boosted trees
+ path: https://github.com/tensorflow/models/tree/master/official/boosted_trees
+ status: external
+ - title: Text classifier with TF-Hub
+ path: /hub/tutorials/text_classification_with_tf_hub
- title: Build a CNN using Estimators
- path: /tutorials/images/layers
+ path: /tutorials/estimators/cnn
+
+- title: Images
+ style: accordion
+ section:
- title: Image recognition
path: /tutorials/images/image_recognition
- title: Image retraining
@@ -69,10 +81,6 @@ toc:
- title: Data representation
style: accordion
section:
- - title: Linear models
- path: /tutorials/representation/wide
- - title: Wide and deep learning
- path: /tutorials/representation/wide_and_deep
- title: Vector representations of words
path: /tutorials/representation/word2vec
- title: Kernel methods
diff --git a/tensorflow/docs_src/tutorials/images/layers.md b/tensorflow/docs_src/tutorials/estimators/cnn.md
index 12a215b50c..12a215b50c 100644
--- a/tensorflow/docs_src/tutorials/images/layers.md
+++ b/tensorflow/docs_src/tutorials/estimators/cnn.md
diff --git a/tensorflow/docs_src/tutorials/images/deep_cnn.md b/tensorflow/docs_src/tutorials/images/deep_cnn.md
index 1590f15eb9..27963575f5 100644
--- a/tensorflow/docs_src/tutorials/images/deep_cnn.md
+++ b/tensorflow/docs_src/tutorials/images/deep_cnn.md
@@ -80,21 +80,21 @@ for details. It consists of 1,068,298 learnable parameters and requires about
## Code Organization
The code for this tutorial resides in
-[`models/tutorials/image/cifar10/`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/).
+[`models/tutorials/image/cifar10/`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/).
File | Purpose
--- | ---
-[`cifar10_input.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format.
-[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model.
-[`cifar10_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU.
-[`cifar10_multi_gpu_train.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs.
-[`cifar10_eval.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
+[`cifar10_input.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_input.py) | Reads the native CIFAR-10 binary file format.
+[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py) | Builds the CIFAR-10 model.
+[`cifar10_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_train.py) | Trains a CIFAR-10 model on a CPU or GPU.
+[`cifar10_multi_gpu_train.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py) | Trains a CIFAR-10 model on multiple GPUs.
+[`cifar10_eval.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10_eval.py) | Evaluates the predictive performance of a CIFAR-10 model.
## CIFAR-10 Model
The CIFAR-10 network is largely contained in
-[`cifar10.py`](https://www.tensorflow.org/code/tensorflow_models/tutorials/image/cifar10/cifar10.py).
+[`cifar10.py`](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10/cifar10.py).
The complete training
graph contains roughly 765 operations. We find that we can make the code most
reusable by constructing the graph with the following modules:
diff --git a/tensorflow/docs_src/tutorials/images/image_recognition.md b/tensorflow/docs_src/tutorials/images/image_recognition.md
index 432d470d0c..d545de73df 100644
--- a/tensorflow/docs_src/tutorials/images/image_recognition.md
+++ b/tensorflow/docs_src/tutorials/images/image_recognition.md
@@ -449,7 +449,7 @@ covering them.
To find out more about implementing convolutional neural networks, you can jump
to the TensorFlow @{$deep_cnn$deep convolutional networks tutorial},
-or start a bit more gently with our @{$layers$MNIST starter tutorial}.
+or start a bit more gently with our [Estimator MNIST tutorial](../estimators/cnn.md).
Finally, if you want to get up to speed on research in this area, you can
read the recent work of all the papers referenced in this tutorial.
diff --git a/tensorflow/docs_src/tutorials/representation/linear.md b/tensorflow/docs_src/tutorials/representation/linear.md
index 3f247ade26..1b418cf065 100644
--- a/tensorflow/docs_src/tutorials/representation/linear.md
+++ b/tensorflow/docs_src/tutorials/representation/linear.md
@@ -11,8 +11,9 @@ those tools. It explains:
deep learning to get the advantages of both.
Read this overview to decide whether the Estimator's linear model tools might
-be useful to you. Then do the @{$wide$Linear Models tutorial} to
-give it a try. This overview uses code samples from the tutorial, but the
+be useful to you. Then work through the
+[Estimator wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
+to give it a try. This overview uses code samples from the tutorial, but the
tutorial walks through the code in greater detail.
To understand this overview it will help to have some familiarity
@@ -176,7 +177,7 @@ the name of a `FeatureColumn`. Each key's value is a tensor containing the
values of that feature for all data instances. See
@{$premade_estimators#input_fn} for a
more comprehensive look at input functions, and `input_fn` in the
-[linear models tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py)
+[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep)
for an example implementation of an input function.
The input function is passed to the `train()` and `evaluate()` calls that
@@ -234,4 +235,5 @@ e = tf.estimator.DNNLinearCombinedClassifier(
dnn_feature_columns=deep_columns,
dnn_hidden_units=[100, 50])
```
-For more information, see the @{$wide_and_deep$Wide and Deep Learning tutorial}.
+For more information, see the
+[wide and deep learning tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
diff --git a/tensorflow/docs_src/tutorials/representation/wide.md b/tensorflow/docs_src/tutorials/representation/wide.md
deleted file mode 100644
index 27ce75a30d..0000000000
--- a/tensorflow/docs_src/tutorials/representation/wide.md
+++ /dev/null
@@ -1,461 +0,0 @@
-# TensorFlow Linear Model Tutorial
-
-In this tutorial, we will use the tf.estimator API in TensorFlow to solve a
-binary classification problem: Given census data about a person such as age,
-education, marital status, and occupation (the features), we will try to predict
-whether or not the person earns more than 50,000 dollars a year (the target
-label). We will train a **logistic regression** model, and given an individual's
-information our model will output a number between 0 and 1, which can be
-interpreted as the probability that the individual has an annual income of over
-50,000 dollars.
-
-## Setup
-
-To try the code for this tutorial:
-
-1. @{$install$Install TensorFlow} if you haven't already.
-
-2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/).
-
-3. Execute the data download script we provide to you:
-
- $ python data_download.py
-
-4. Execute the tutorial code with the following command to train the linear
-model described in this tutorial:
-
- $ python wide_deep.py --model_type=wide
-
-Read on to find out how this code builds its linear model.
-
-## Reading The Census Data
-
-The dataset we'll be using is the
-[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income).
-We have provided
-[data_download.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/data_download.py)
-which downloads the code and performs some additional cleanup.
-
-Since the task is a binary classification problem, we'll construct a label
-column named "label" whose value is 1 if the income is over 50K, and 0
-otherwise. For reference, see `input_fn` in
-[wide_deep.py](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-Next, let's take a look at the dataframe and see which columns we can use to
-predict the target label. The columns can be grouped into two types—categorical
-and continuous columns:
-
-* A column is called **categorical** if its value can only be one of the
- categories in a finite set. For example, the relationship status of a person
- (wife, husband, unmarried, etc.) or the education level (high school,
- college, etc.) are categorical columns.
-* A column is called **continuous** if its value can be any numerical value in
- a continuous range. For example, the capital gain of a person (e.g. $14,084)
- is a continuous column.
-
-Here's a list of columns available in the Census Income dataset:
-
-| Column Name | Type | Description |
-| -------------- | ----------- | --------------------------------- |
-| age | Continuous | The age of the individual |
-| workclass | Categorical | The type of employer the |
-: : : individual has (government, :
-: : : military, private, etc.). :
-| fnlwgt | Continuous | The number of people the census |
-: : : takers believe that observation :
-: : : represents (sample weight). Final :
-: : : weight will not be used. :
-| education | Categorical | The highest level of education |
-: : : achieved for that individual. :
-| education_num | Continuous | The highest level of education in |
-: : : numerical form. :
-| marital_status | Categorical | Marital status of the individual. |
-| occupation | Categorical | The occupation of the individual. |
-| relationship | Categorical | Wife, Own-child, Husband, |
-: : : Not-in-family, Other-relative, :
-: : : Unmarried. :
-| race | Categorical | Amer-Indian-Eskimo, Asian-Pac- |
-: : : Islander, Black, White, Other. :
-| gender | Categorical | Female, Male. |
-| capital_gain | Continuous | Capital gains recorded. |
-| capital_loss | Continuous | Capital Losses recorded. |
-| hours_per_week | Continuous | Hours worked per week. |
-| native_country | Categorical | Country of origin of the |
-: : : individual. :
-| income_bracket | Categorical | ">50K" or "<=50K", meaning |
-: : : whether the person makes more :
-: : : than $50,000 annually. :
-
-## Converting Data into Tensors
-
-When building a tf.estimator model, the input data is specified by means of an
-Input Builder function. This builder function will not be called until it is
-later passed to tf.estimator.Estimator methods such as `train` and `evaluate`.
-The purpose of this function is to construct the input data, which is
-represented in the form of @{tf.Tensor}s or @{tf.SparseTensor}s.
-In more detail, the input builder function returns the following as a pair:
-
-1. `features`: A dict from feature column names to `Tensors` or
- `SparseTensors`.
-2. `labels`: A `Tensor` containing the label column.
-
-The keys of the `features` will be used to construct columns in the next
-section. Because we want to call the `train` and `evaluate` methods with
-different data, we define a method that returns an input function based on the
-given data. Note that the returned input function will be called while
-constructing the TensorFlow graph, not while running the graph. What it is
-returning is a representation of the input data as the fundamental unit of
-TensorFlow computations, a `Tensor` (or `SparseTensor`).
-
-Each continuous column in the train or test data will be converted into a
-`Tensor`, which in general is a good format to represent dense data. For
-categorical data, we must represent the data as a `SparseTensor`. This data
-format is good for representing sparse data. Our `input_fn` uses the `tf.data`
-API, which makes it easy to apply transformations to our dataset:
-
-```python
-def input_fn(data_file, num_epochs, shuffle, batch_size):
- """Generate an input function for the Estimator."""
- assert tf.gfile.Exists(data_file), (
- '%s not found. Please make sure you have either run data_download.py or '
- 'set both arguments --train_data and --test_data.' % data_file)
-
- def parse_csv(value):
- print('Parsing', data_file)
- columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
- features = dict(zip(_CSV_COLUMNS, columns))
- labels = features.pop('income_bracket')
- return features, tf.equal(labels, '>50K')
-
- # Extract lines from input files using the Dataset API.
- dataset = tf.data.TextLineDataset(data_file)
-
- if shuffle:
- dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
-
- dataset = dataset.map(parse_csv, num_parallel_calls=5)
-
- # We call repeat after shuffling, rather than before, to prevent separate
- # epochs from blending together.
- dataset = dataset.repeat(num_epochs)
- dataset = dataset.batch(batch_size)
-
- iterator = dataset.make_one_shot_iterator()
- features, labels = iterator.get_next()
- return features, labels
-```
-
-## Selecting and Engineering Features for the Model
-
-Selecting and crafting the right set of feature columns is key to learning an
-effective model. A **feature column** can be either one of the raw columns in
-the original dataframe (let's call them **base feature columns**), or any new
-columns created based on some transformations defined over one or multiple base
-columns (let's call them **derived feature columns**). Basically, "feature
-column" is an abstract concept of any raw or derived variable that can be used
-to predict the target label.
-
-### Base Categorical Feature Columns
-
-To define a feature column for a categorical feature, we can create a
-`CategoricalColumn` using the tf.feature_column API. If you know the set of all
-possible feature values of a column and there are only a few of them, you can
-use `categorical_column_with_vocabulary_list`. Each key in the list will get
-assigned an auto-incremental ID starting from 0. For example, for the
-`relationship` column we can assign the feature string "Husband" to an integer
-ID of 0 and "Not-in-family" to 1, etc., by doing:
-
-```python
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-```
-
-What if we don't know the set of possible values in advance? Not a problem. We
-can use `categorical_column_with_hash_bucket` instead:
-
-```python
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-```
-
-What will happen is that each possible value in the feature column `occupation`
-will be hashed to an integer ID as we encounter them in training. See an example
-illustration below:
-
-ID | Feature
---- | -------------
-... |
-9 | `"Machine-op-inspct"`
-... |
-103 | `"Farming-fishing"`
-... |
-375 | `"Protective-serv"`
-... |
-
-No matter which way we choose to define a `SparseColumn`, each feature string
-will be mapped into an integer ID by looking up a fixed mapping or by hashing.
-Note that hashing collisions are possible, but may not significantly impact the
-model quality. Under the hood, the `LinearModel` class is responsible for
-managing the mapping and creating `tf.Variable` to store the model parameters
-(also known as model weights) for each feature ID. The model parameters will be
-learned through the model training process we'll go through later.
-
-We'll do the similar trick to define the other categorical features:
-
-```python
-education = tf.feature_column.categorical_column_with_vocabulary_list(
- 'education', [
- 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
- 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
- '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
-
-marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
- 'marital_status', [
- 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
- 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
-
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-
-workclass = tf.feature_column.categorical_column_with_vocabulary_list(
- 'workclass', [
- 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
- 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])
-
-# To show an example of hashing:
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-```
-
-### Base Continuous Feature Columns
-
-Similarly, we can define a `NumericColumn` for each continuous feature column
-that we want to use in the model:
-
-```python
-age = tf.feature_column.numeric_column('age')
-education_num = tf.feature_column.numeric_column('education_num')
-capital_gain = tf.feature_column.numeric_column('capital_gain')
-capital_loss = tf.feature_column.numeric_column('capital_loss')
-hours_per_week = tf.feature_column.numeric_column('hours_per_week')
-```
-
-### Making Continuous Features Categorical through Bucketization
-
-Sometimes the relationship between a continuous feature and the label is not
-linear. As a hypothetical example, a person's income may grow with age in the
-early stage of one's career, then the growth may slow at some point, and finally
-the income decreases after retirement. In this scenario, using the raw `age` as
-a real-valued feature column might not be a good choice because the model can
-only learn one of the three cases:
-
-1. Income always increases at some rate as age grows (positive correlation),
-1. Income always decreases at some rate as age grows (negative correlation), or
-1. Income stays the same no matter at what age (no correlation)
-
-If we want to learn the fine-grained correlation between income and each age
-group separately, we can leverage **bucketization**. Bucketization is a process
-of dividing the entire range of a continuous feature into a set of consecutive
-bins/buckets, and then converting the original numerical feature into a bucket
-ID (as a categorical feature) depending on which bucket that value falls into.
-So, we can define a `bucketized_column` over `age` as:
-
-```python
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-```
-
-where the `boundaries` is a list of bucket boundaries. In this case, there are
-10 boundaries, resulting in 11 age group buckets (from age 17 and below, 18-24,
-25-29, ..., to 65 and over).
-
-### Intersecting Multiple Columns with CrossedColumn
-
-Using each base feature column separately may not be enough to explain the data.
-For example, the correlation between education and the label (earning > 50,000
-dollars) may be different for different occupations. Therefore, if we only learn
-a single model weight for `education="Bachelors"` and `education="Masters"`, we
-won't be able to capture every single education-occupation combination (e.g.
-distinguishing between `education="Bachelors" AND occupation="Exec-managerial"`
-and `education="Bachelors" AND occupation="Craft-repair"`). To learn the
-differences between different feature combinations, we can add **crossed feature
-columns** to the model.
-
-```python
-education_x_occupation = tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000)
-```
-
-We can also create a `CrossedColumn` over more than two columns. Each
-constituent column can be either a base feature column that is categorical
-(`SparseColumn`), a bucketized real-valued feature column (`BucketizedColumn`),
-or even another `CrossColumn`. Here's an example:
-
-```python
-age_buckets_x_education_x_occupation = tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000)
-```
-
-## Defining The Logistic Regression Model
-
-After processing the input data and defining all the feature columns, we're now
-ready to put them all together and build a Logistic Regression model. In the
-previous section we've seen several types of base and derived feature columns,
-including:
-
-* `CategoricalColumn`
-* `NumericColumn`
-* `BucketizedColumn`
-* `CrossedColumn`
-
-All of these are subclasses of the abstract `FeatureColumn` class, and can be
-added to the `feature_columns` field of a model:
-
-```python
-base_columns = [
- education, marital_status, relationship, workclass, occupation,
- age_buckets,
-]
-crossed_columns = [
- tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
-]
-
-model_dir = tempfile.mkdtemp()
-model = tf.estimator.LinearClassifier(
- model_dir=model_dir, feature_columns=base_columns + crossed_columns)
-```
-
-The model also automatically learns a bias term, which controls the prediction
-one would make without observing any features (see the section "How Logistic
-Regression Works" for more explanations). The learned model files will be stored
-in `model_dir`.
-
-## Training and Evaluating Our Model
-
-After adding all the features to the model, now let's look at how to actually
-train the model. Training a model is just a single command using the
-tf.estimator API:
-
-```python
-model.train(input_fn=lambda: input_fn(train_data, num_epochs, True, batch_size))
-```
-
-After the model is trained, we can evaluate how good our model is at predicting
-the labels of the holdout data:
-
-```python
-results = model.evaluate(input_fn=lambda: input_fn(
- test_data, 1, False, batch_size))
-for key in sorted(results):
- print('%s: %s' % (key, results[key]))
-```
-
-The first line of the final output should be something like
-`accuracy: 0.83557522`, which means the accuracy is 83.6%. Feel free to try more
-features and transformations and see if you can do even better!
-
-After the model is evaluated, we can use the model to predict whether an individual has an annual income of over
-50,000 dollars given an individual's information input.
-```python
- pred_iter = model.predict(input_fn=lambda: input_fn(FLAGS.test_data, 1, False, 1))
- for pred in pred_iter:
- print(pred['classes'])
-```
-
-The model prediction output would be like `[b'1']` or `[b'0']` which means whether corresponding individual has an annual income of over 50,000 dollars or not.
-
-If you'd like to see a working end-to-end example, you can download our
-[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py)
-and set the `model_type` flag to `wide`.
-
-## Adding Regularization to Prevent Overfitting
-
-Regularization is a technique used to avoid **overfitting**. Overfitting happens
-when your model does well on the data it is trained on, but worse on test data
-that the model has not seen before, such as live traffic. Overfitting generally
-occurs when a model is excessively complex, such as having too many parameters
-relative to the number of observed training data. Regularization allows for you
-to control your model's complexity and makes the model more generalizable to
-unseen data.
-
-In the Linear Model library, you can add L1 and L2 regularizations to the model
-as:
-
-```
-model = tf.estimator.LinearClassifier(
- model_dir=model_dir, feature_columns=base_columns + crossed_columns,
- optimizer=tf.train.FtrlOptimizer(
- learning_rate=0.1,
- l1_regularization_strength=1.0,
- l2_regularization_strength=1.0))
-```
-
-One important difference between L1 and L2 regularization is that L1
-regularization tends to make model weights stay at zero, creating sparser
-models, whereas L2 regularization also tries to make the model weights closer to
-zero but not necessarily zero. Therefore, if you increase the strength of L1
-regularization, you will have a smaller model size because many of the model
-weights will be zero. This is often desirable when the feature space is very
-large but sparse, and when there are resource constraints that prevent you from
-serving a model that is too large.
-
-In practice, you should try various combinations of L1, L2 regularization
-strengths and find the best parameters that best control overfitting and give
-you a desirable model size.
-
-## How Logistic Regression Works
-
-Finally, let's take a minute to talk about what the Logistic Regression model
-actually looks like in case you're not already familiar with it. We'll denote
-the label as \\(Y\\), and the set of observed features as a feature vector
-\\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). We define \\(Y=1\\) if an individual
-earned > 50,000 dollars and \\(Y=0\\) otherwise. In Logistic Regression, the
-probability of the label being positive (\\(Y=1\\)) given the features
-\\(\mathbf{x}\\) is given as:
-
-$$ P(Y=1|\mathbf{x}) = \frac{1}{1+\exp(-(\mathbf{w}^T\mathbf{x}+b))}$$
-
-where \\(\mathbf{w}=[w_1, w_2, ..., w_d]\\) are the model weights for the
-features \\(\mathbf{x}=[x_1, x_2, ..., x_d]\\). \\(b\\) is a constant that is
-often called the **bias** of the model. The equation consists of two parts—A
-linear model and a logistic function:
-
-* **Linear Model**: First, we can see that \\(\mathbf{w}^T\mathbf{x}+b = b +
- w_1x_1 + ... +w_dx_d\\) is a linear model where the output is a linear
- function of the input features \\(\mathbf{x}\\). The bias \\(b\\) is the
- prediction one would make without observing any features. The model weight
- \\(w_i\\) reflects how the feature \\(x_i\\) is correlated with the positive
- label. If \\(x_i\\) is positively correlated with the positive label, the
- weight \\(w_i\\) increases, and the probability \\(P(Y=1|\mathbf{x})\\) will
- be closer to 1. On the other hand, if \\(x_i\\) is negatively correlated
- with the positive label, then the weight \\(w_i\\) decreases and the
- probability \\(P(Y=1|\mathbf{x})\\) will be closer to 0.
-
-* **Logistic Function**: Second, we can see that there's a logistic function
- (also known as the sigmoid function) \\(S(t) = 1/(1+\exp(-t))\\) being
- applied to the linear model. The logistic function is used to convert the
- output of the linear model \\(\mathbf{w}^T\mathbf{x}+b\\) from any real
- number into the range of \\([0, 1]\\), which can be interpreted as a
- probability.
-
-Model training is an optimization problem: The goal is to find a set of model
-weights (i.e. model parameters) to minimize a **loss function** defined over the
-training data, such as logistic loss for Logistic Regression models. The loss
-function measures the discrepancy between the ground-truth label and the model's
-prediction. If the prediction is very close to the ground-truth label, the loss
-value will be low; if the prediction is very far from the label, then the loss
-value would be high.
-
-## Learn Deeper
-
-If you're interested in learning more, check out our
-@{$wide_and_deep$Wide & Deep Learning Tutorial} where we'll show you how to
-combine the strengths of linear models and deep neural networks by jointly
-training them using the tf.estimator API.
diff --git a/tensorflow/docs_src/tutorials/representation/wide_and_deep.md b/tensorflow/docs_src/tutorials/representation/wide_and_deep.md
deleted file mode 100644
index 44677a810b..0000000000
--- a/tensorflow/docs_src/tutorials/representation/wide_and_deep.md
+++ /dev/null
@@ -1,243 +0,0 @@
-# TensorFlow Wide & Deep Learning Tutorial
-
-In the previous @{$wide$TensorFlow Linear Model Tutorial}, we trained a logistic
-regression model to predict the probability that the individual has an annual
-income of over 50,000 dollars using the
-[Census Income Dataset](https://archive.ics.uci.edu/ml/datasets/Census+Income).
-TensorFlow is great for training deep neural networks too, and you might be
-thinking which one you should choose—well, why not both? Would it be possible to
-combine the strengths of both in one model?
-
-In this tutorial, we'll introduce how to use the tf.estimator API to jointly
-train a wide linear model and a deep feed-forward neural network. This approach
-combines the strengths of memorization and generalization. It's useful for
-generic large-scale regression and classification problems with sparse input
-features (e.g., categorical features with a large number of possible feature
-values). If you're interested in learning more about how Wide & Deep Learning
-works, please check out our [research paper](https://arxiv.org/abs/1606.07792).
-
-![Wide & Deep Spectrum of Models](https://www.tensorflow.org/images/wide_n_deep.svg "Wide & Deep")
-
-The figure above shows a comparison of a wide model (logistic regression with
-sparse features and transformations), a deep model (feed-forward neural network
-with an embedding layer and several hidden layers), and a Wide & Deep model
-(joint training of both). At a high level, there are only 3 steps to configure a
-wide, deep, or Wide & Deep model using the tf.estimator API:
-
-1. Select features for the wide part: Choose the sparse base columns and
- crossed columns you want to use.
-1. Select features for the deep part: Choose the continuous columns, the
- embedding dimension for each categorical column, and the hidden layer sizes.
-1. Put them all together in a Wide & Deep model
- (`DNNLinearCombinedClassifier`).
-
-And that's it! Let's go through a simple example.
-
-## Setup
-
-To try the code for this tutorial:
-
-1. @{$install$Install TensorFlow} if you haven't already.
-
-2. Download [the tutorial code](https://github.com/tensorflow/models/tree/master/official/wide_deep/).
-
-3. Execute the data download script we provide to you:
-
- $ python data_download.py
-
-4. Execute the tutorial code with the following command to train the wide and
-deep model described in this tutorial:
-
- $ python wide_deep.py
-
-Read on to find out how this code builds its model.
-
-
-## Define Base Feature Columns
-
-First, let's define the base categorical and continuous feature columns that
-we'll use. These base columns will be the building blocks used by both the wide
-part and the deep part of the model.
-
-```python
-import tensorflow as tf
-
-# Continuous columns
-age = tf.feature_column.numeric_column('age')
-education_num = tf.feature_column.numeric_column('education_num')
-capital_gain = tf.feature_column.numeric_column('capital_gain')
-capital_loss = tf.feature_column.numeric_column('capital_loss')
-hours_per_week = tf.feature_column.numeric_column('hours_per_week')
-
-education = tf.feature_column.categorical_column_with_vocabulary_list(
- 'education', [
- 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
- 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
- '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])
-
-marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
- 'marital_status', [
- 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
- 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])
-
-relationship = tf.feature_column.categorical_column_with_vocabulary_list(
- 'relationship', [
- 'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
- 'Other-relative'])
-
-workclass = tf.feature_column.categorical_column_with_vocabulary_list(
- 'workclass', [
- 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
- 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])
-
-# To show an example of hashing:
-occupation = tf.feature_column.categorical_column_with_hash_bucket(
- 'occupation', hash_bucket_size=1000)
-
-# Transformations.
-age_buckets = tf.feature_column.bucketized_column(
- age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
-```
-
-## The Wide Model: Linear Model with Crossed Feature Columns
-
-The wide model is a linear model with a wide set of sparse and crossed feature
-columns:
-
-```python
-base_columns = [
- education, marital_status, relationship, workclass, occupation,
- age_buckets,
-]
-
-crossed_columns = [
- tf.feature_column.crossed_column(
- ['education', 'occupation'], hash_bucket_size=1000),
- tf.feature_column.crossed_column(
- [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
-]
-```
-
-You can also see the @{$wide$TensorFlow Linear Model Tutorial} for more details.
-
-Wide models with crossed feature columns can memorize sparse interactions
-between features effectively. That being said, one limitation of crossed feature
-columns is that they do not generalize to feature combinations that have not
-appeared in the training data. Let's add a deep model with embeddings to fix
-that.
-
-## The Deep Model: Neural Network with Embeddings
-
-The deep model is a feed-forward neural network, as shown in the previous
-figure. Each of the sparse, high-dimensional categorical features are first
-converted into a low-dimensional and dense real-valued vector, often referred to
-as an embedding vector. These low-dimensional dense embedding vectors are
-concatenated with the continuous features, and then fed into the hidden layers
-of a neural network in the forward pass. The embedding values are initialized
-randomly, and are trained along with all other model parameters to minimize the
-training loss. If you're interested in learning more about embeddings, check out
-the TensorFlow tutorial on @{$word2vec$Vector Representations of Words} or
-[Word embedding](https://en.wikipedia.org/wiki/Word_embedding) on Wikipedia.
-
-Another way to represent categorical columns to feed into a neural network is
-via a one-hot or multi-hot representation. This is often appropriate for
-categorical columns with only a few possible values. As an example of a one-hot
-representation, for the relationship column, `"Husband"` can be represented as
-[1, 0, 0, 0, 0, 0], and `"Not-in-family"` as [0, 1, 0, 0, 0, 0], etc. This is a
-fixed representation, whereas embeddings are more flexible and calculated at
-training time.
-
-We'll configure the embeddings for the categorical columns using
-`embedding_column`, and concatenate them with the continuous columns.
-We also use `indicator_column` to create multi-hot representations of some
-categorical columns.
-
-```python
-deep_columns = [
- age,
- education_num,
- capital_gain,
- capital_loss,
- hours_per_week,
- tf.feature_column.indicator_column(workclass),
- tf.feature_column.indicator_column(education),
- tf.feature_column.indicator_column(marital_status),
- tf.feature_column.indicator_column(relationship),
- # To show an example of embedding
- tf.feature_column.embedding_column(occupation, dimension=8),
-]
-```
-
-The higher the `dimension` of the embedding is, the more degrees of freedom the
-model will have to learn the representations of the features. For simplicity, we
-set the dimension to 8 for all feature columns here. Empirically, a more
-informed decision for the number of dimensions is to start with a value on the
-order of \\(\log_2(n)\\) or \\(k\sqrt[4]n\\), where \\(n\\) is the number of
-unique features in a feature column and \\(k\\) is a small constant (usually
-smaller than 10).
-
-Through dense embeddings, deep models can generalize better and make predictions
-on feature pairs that were previously unseen in the training data. However, it
-is difficult to learn effective low-dimensional representations for feature
-columns when the underlying interaction matrix between two feature columns is
-sparse and high-rank. In such cases, the interaction between most feature pairs
-should be zero except a few, but dense embeddings will lead to nonzero
-predictions for all feature pairs, and thus can over-generalize. On the other
-hand, linear models with crossed features can memorize these “exception rules”
-effectively with fewer model parameters.
-
-Now, let's see how to jointly train wide and deep models and allow them to
-complement each other’s strengths and weaknesses.
-
-## Combining Wide and Deep Models into One
-
-The wide models and deep models are combined by summing up their final output
-log odds as the prediction, then feeding the prediction to a logistic loss
-function. All the graph definition and variable allocations have already been
-handled for you under the hood, so you simply need to create a
-`DNNLinearCombinedClassifier`:
-
-```python
-model = tf.estimator.DNNLinearCombinedClassifier(
- model_dir='/tmp/census_model',
- linear_feature_columns=base_columns + crossed_columns,
- dnn_feature_columns=deep_columns,
- dnn_hidden_units=[100, 50])
-```
-
-## Training and Evaluating The Model
-
-Before we train the model, let's read in the Census dataset as we did in the
-@{$wide$TensorFlow Linear Model tutorial}. See `data_download.py` as well as
-`input_fn` within
-[`wide_deep.py`](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-After reading in the data, you can train and evaluate the model:
-
-```python
-# Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
-for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
- model.train(input_fn=lambda: input_fn(
- FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
-
- results = model.evaluate(input_fn=lambda: input_fn(
- FLAGS.test_data, 1, False, FLAGS.batch_size))
-
- # Display evaluation metrics
- print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
- print('-' * 30)
-
- for key in sorted(results):
- print('%s: %s' % (key, results[key]))
-```
-
-The final output accuracy should be somewhere around 85.5%. If you'd like to
-see a working end-to-end example, you can download our
-[example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/wide_deep.py).
-
-Note that this tutorial is just a quick example on a small dataset to get you
-familiar with the API. Wide & Deep Learning will be even more powerful if you
-try it on a large dataset with many sparse feature columns that have a large
-number of possible feature values. Again, feel free to take a look at our
-[research paper](https://arxiv.org/abs/1606.07792) for more ideas about how to
-apply Wide & Deep Learning in real-world large-scale machine learning problems.
diff --git a/tensorflow/docs_src/tutorials/representation/word2vec.md b/tensorflow/docs_src/tutorials/representation/word2vec.md
index 3fe7352bd2..0a1c41c84a 100644
--- a/tensorflow/docs_src/tutorials/representation/word2vec.md
+++ b/tensorflow/docs_src/tutorials/representation/word2vec.md
@@ -23,7 +23,7 @@ straight in, feel free to look at the minimalistic implementation in
This basic example contains the code needed to download some data, train on it a
bit and visualize the result. Once you get comfortable with reading and running
the basic version, you can graduate to
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py)
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py)
which is a more serious implementation that showcases some more advanced
TensorFlow principles about how to efficiently use threads to move data into a
text model, how to checkpoint during training, etc.
@@ -341,7 +341,7 @@ t-SNE.
Et voila! As expected, words that are similar end up clustering nearby each
other. For a more heavyweight implementation of word2vec that showcases more of
the advanced features of TensorFlow, see the implementation in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
## Evaluating Embeddings: Analogical Reasoning
@@ -357,7 +357,7 @@ Download the dataset for this task from
To see how we do this evaluation, have a look at the `build_eval_graph()` and
`eval()` functions in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
The choice of hyperparameters can strongly influence the accuracy on this task.
To achieve state-of-the-art performance on this task requires training over a
@@ -385,13 +385,13 @@ your model is seriously bottlenecked on input data, you may want to implement a
custom data reader for your problem, as described in
@{$new_data_formats$New Data Formats}. For the case of Skip-Gram
modeling, we've actually already done this for you as an example in
-[models/tutorials/embedding/word2vec.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec.py).
+[models/tutorials/embedding/word2vec.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec.py).
If your model is no longer I/O bound but you want still more performance, you
can take things further by writing your own TensorFlow Ops, as described in
@{$adding_an_op$Adding a New Op}. Again we've provided an
example of this for the Skip-Gram case
-[models/tutorials/embedding/word2vec_optimized.py](https://www.tensorflow.org/code/tensorflow_models/tutorials/embedding/word2vec_optimized.py).
+[models/tutorials/embedding/word2vec_optimized.py](https://github.com/tensorflow/models/tree/master/tutorials/embedding/word2vec_optimized.py).
Feel free to benchmark these against each other to measure performance
improvements at each stage.
diff --git a/tensorflow/examples/android/BUILD b/tensorflow/examples/android/BUILD
index 07f096418f..f327b645f5 100644
--- a/tensorflow/examples/android/BUILD
+++ b/tensorflow/examples/android/BUILD
@@ -1,6 +1,8 @@
# Description:
# TensorFlow camera demo app for Android.
+load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
+
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py
index 03e60972aa..86f5204ec3 100644
--- a/tensorflow/examples/learn/iris.py
+++ b/tensorflow/examples/learn/iris.py
@@ -21,7 +21,8 @@ from __future__ import division
from __future__ import print_function
import os
-import urllib
+
+from six.moves.urllib.request import urlretrieve
import tensorflow as tf
@@ -38,9 +39,7 @@ FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
def maybe_download_iris_data(file_name, download_url):
"""Downloads the file and returns the number of data."""
if not os.path.exists(file_name):
- raw = urllib.urlopen(download_url).read()
- with open(file_name, 'w') as f:
- f.write(raw)
+ urlretrieve(download_url, file_name)
# The first line is a comma-separated string. The first one is the number of
# total data in the file.
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d7bc6a5a7d..d4070fdd1e 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -97,7 +97,7 @@ py_binary(
py_test(
name = "fully_connected_feed_test",
- size = "small",
+ size = "medium",
srcs = [
"fully_connected_feed.py",
],
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go
new file mode 100644
index 0000000000..f86c5737bc
--- /dev/null
+++ b/tensorflow/go/attrs.go
@@ -0,0 +1,245 @@
+/*
+Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tensorflow
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+import (
+ "fmt"
+ "unsafe"
+)
+
+// makeCShape converts a shape specified in C.int64_t into a Shape.
+func makeCShape(shape []C.int64_t) Shape {
+ s := Shape{dims: make([]int64, len(shape))}
+ for i, n := range shape {
+ s.dims[i] = int64(n)
+ }
+ return s
+}
+
+// Attr returns the value of an attribute on op. It returns an error if the
+// attribute does not exist.
+func (op *Operation) Attr(name string) (interface{}, error) {
+ cname := C.CString(name)
+ defer C.free(unsafe.Pointer(cname))
+
+ status := newStatus()
+ meta := C.TF_OperationGetAttrMetadata(op.c, cname, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ if meta.is_list == 1 {
+ return listAttribute(op, cname, meta)
+ }
+ return scalarAttribute(op, cname, meta)
+}
+
+func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
+ status := newStatus()
+
+ switch meta._type {
+ case C.TF_ATTR_STRING:
+ if meta.list_size == 0 {
+ return []string(nil), nil
+ }
+ values := make([]unsafe.Pointer, meta.list_size)
+ lengths := make([]C.size_t, meta.list_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.char, meta.total_size+1)
+ C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ list := make([]string, meta.list_size)
+ for i, val := range values {
+ length := lengths[i]
+ list[i] = C.GoStringN((*C.char)(val), C.int(length))
+ }
+ return list, nil
+
+ case C.TF_ATTR_INT:
+ if meta.list_size == 0 {
+ return []int64(nil), nil
+ }
+ list := make([]C.int64_t, meta.list_size)
+ C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]int64, meta.list_size)
+ for i, val := range list {
+ vals[i] = int64(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_FLOAT:
+ if meta.list_size == 0 {
+ return []float32(nil), nil
+ }
+ list := make([]C.float, meta.list_size)
+ C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]float32, meta.list_size)
+ for i, val := range list {
+ vals[i] = float32(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_BOOL:
+ if meta.list_size == 0 {
+ return []bool(nil), nil
+ }
+ list := make([]C.uchar, meta.list_size)
+ C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]bool, meta.list_size)
+ for i, val := range list {
+ vals[i] = val == 1
+ }
+ return vals, nil
+
+ case C.TF_ATTR_TYPE:
+ if meta.list_size == 0 {
+ return []DataType(nil), nil
+ }
+ list := make([]C.TF_DataType, meta.list_size)
+ C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]DataType, meta.list_size)
+ for i, val := range list {
+ vals[i] = DataType(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_TENSOR:
+ if meta.list_size == 0 {
+ return []*Tensor(nil), nil
+ }
+ list := make([]*C.TF_Tensor, meta.list_size)
+ C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]*Tensor, meta.list_size)
+ for i, t := range list {
+ vals[i] = newTensorFromC(t)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_SHAPE:
+ if meta.list_size == 0 {
+ return []Shape(nil), nil
+ }
+ dims := make([]*C.int64_t, meta.list_size)
+ numDims := make([]C.int, meta.list_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.int64_t, meta.total_size+1)
+ C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ list := make([]Shape, meta.list_size)
+ for i, dim := range dims {
+ numDim := numDims[i]
+ // If the number of dimensions is unknown, default to empty shape.
+ if numDim < 0 {
+ continue
+ }
+ // A []C.int64_t slice backed by C memory.
+ // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
+ slice := (*[1 << 30]C.int64_t)(unsafe.Pointer(dim))[:numDim:numDim]
+ list[i] = makeCShape(slice)
+ }
+ return list, nil
+
+ default:
+ return nil, fmt.Errorf("list type %v not supported", meta._type)
+ }
+}
+
+func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
+ status := newStatus()
+
+ switch meta._type {
+ case C.TF_ATTR_STRING:
+ if meta.total_size == 0 {
+ return "", nil
+ }
+ v := make([]C.char, meta.total_size)
+ C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return C.GoStringN(&v[0], C.int(meta.total_size)), nil
+
+ case C.TF_ATTR_INT:
+ var v C.int64_t
+ C.TF_OperationGetAttrInt(op.c, cname, &v, status.c)
+ return int64(v), status.Err()
+
+ case C.TF_ATTR_FLOAT:
+ var v C.float
+ C.TF_OperationGetAttrFloat(op.c, cname, &v, status.c)
+ return float32(v), status.Err()
+
+ case C.TF_ATTR_BOOL:
+ var v C.uchar
+ C.TF_OperationGetAttrBool(op.c, cname, &v, status.c)
+ return v == 1, status.Err()
+
+ case C.TF_ATTR_TYPE:
+ var v C.TF_DataType
+ C.TF_OperationGetAttrType(op.c, cname, &v, status.c)
+ return DataType(v), status.Err()
+
+ case C.TF_ATTR_TENSOR:
+ var v *C.TF_Tensor
+ C.TF_OperationGetAttrTensor(op.c, cname, &v, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return newTensorFromC(v), nil
+
+ case C.TF_ATTR_SHAPE:
+ numDims := meta.total_size
+ // If number of dims is unknown return empty shape to indicate that.
+ if numDims < 0 {
+ return Shape{}, nil
+ }
+ if numDims == 0 {
+ return ScalarShape(), nil
+ }
+ dims := make([]C.int64_t, numDims)
+ C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return makeCShape(dims), nil
+
+ default:
+ return nil, fmt.Errorf("type %v not supported", meta._type)
+ }
+}
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go
new file mode 100644
index 0000000000..ea8af221ae
--- /dev/null
+++ b/tensorflow/go/attrs_test.go
@@ -0,0 +1,193 @@
+/*
+Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tensorflow
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestOperationAttrs(t *testing.T) {
+ g := NewGraph()
+
+ i := 0
+ makeConst := func(v interface{}) Output {
+ op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
+ i++
+ if err != nil {
+ t.Fatal(err)
+ }
+ return op
+ }
+
+ makeTensor := func(v interface{}) *Tensor {
+ tensor, err := NewTensor(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return tensor
+ }
+
+ cases := []OpSpec{
+ {
+ Name: "type",
+ Type: "Placeholder",
+ Attrs: map[string]interface{}{
+ "dtype": Float,
+ },
+ },
+ {
+ Name: "list(float)",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{1, 2, 3, 4}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32{0, 1, 2, 3, 4, 5},
+ },
+ },
+ {
+ Name: "list(float) empty",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32(nil),
+ },
+ },
+ /* TODO(ashankar): debug this issue and add it back later.
+ {
+ Name: "list(type),list(shape)",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst(float32(1)),
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Float, Int32},
+ "shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
+ },
+ },
+ {
+ Name: "list(type),list(shape) empty",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Int32},
+ "shapes": []Shape(nil),
+ },
+ },
+ {
+ Name: "list(type) empty,string empty,int",
+ Type: "_XlaSendFromHost",
+ Input: []Input{
+ OutputList([]Output{}),
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "Tinputs": []DataType(nil),
+ "key": "",
+ "device_ordinal": int64(0),
+ },
+ },
+ */
+ {
+ Name: "list(int),int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": []int64{1, 2},
+ },
+ },
+ {
+ Name: "list(int) empty,int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": ([]int64)(nil),
+ },
+ },
+ {
+ Name: "list(string),type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": []string{"foo", "bar"},
+ },
+ },
+ {
+ Name: "list(string) empty,type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": ([]string)(nil),
+ },
+ },
+ {
+ Name: "tensor",
+ Type: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": String,
+ "value": makeTensor("foo"),
+ },
+ },
+ }
+
+ for i, spec := range cases {
+ op, err := g.AddOperation(spec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for key, want := range spec.Attrs {
+ out, err := op.Attr(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(out, want) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
+ }
+ wantT, ok := want.(*Tensor)
+ if ok {
+ wantVal := wantT.Value()
+ outVal := out.(*Tensor).Value()
+ if !reflect.DeepEqual(outVal, wantVal) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
+ }
+ }
+ }
+ }
+}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 0dd3726948..f49e1cecaf 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2724,64 +2724,497 @@ func MatrixDiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
return op.Output(0)
}
-// Creates a sequence of numbers.
+// Returns a batched diagonal tensor with a given batched diagonal values.
//
-// This operation creates a sequence of numbers that begins at `start` and
-// extends by increments of `delta` up to but not including `limit`.
+// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+// everything else padded with zeros. The diagonal is computed as follows:
+//
+// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a
+// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
+//
+// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`.
//
// For example:
//
// ```
-// # 'start' is 3
-// # 'limit' is 18
-// # 'delta' is 3
-// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]]
+//
+// and diagonal.shape = (2, 4)
+//
+// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0]
+// [0, 2, 0, 0]
+// [0, 0, 3, 0]
+// [0, 0, 0, 4]],
+// [[5, 0, 0, 0]
+// [0, 6, 0, 0]
+// [0, 0, 7, 0]
+// [0, 0, 0, 8]]]
+//
+// which has shape (2, 4, 4)
// ```
//
// Arguments:
-// start: 0-D (scalar). First entry in the sequence.
-// limit: 0-D (scalar). Upper limit of sequence, exclusive.
-// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+// diagonal: Rank `k`, where `k >= 1`.
//
-// Returns 1-D.
-func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
+// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`.
+func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Range",
+ Type: "MatrixDiag",
Input: []tf.Input{
- start, limit, delta,
+ diagonal,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes gradients for SparseSegmentSqrtN.
+// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm.
+type QuantizedInstanceNormAttr func(optionalAttr)
+
+// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value.
//
-// Returns tensor "output" with same shape as grad, except for dimension 0 whose
-// value is output_dim0.
+// value: If True, `given_y_min` and `given_y_min`
+// and `given_y_max` are used as the output range. Otherwise,
+// the implementation computes the output range.
+// If not specified, defaults to false
+func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr {
+ return func(m optionalAttr) {
+ m["output_range_given"] = value
+ }
+}
+
+// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value.
+//
+// value: Output in `y_min` if `output_range_given` is True.
+// If not specified, defaults to 0
+func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr {
+ return func(m optionalAttr) {
+ m["given_y_min"] = value
+ }
+}
+
+// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value.
+//
+// value: Output in `y_max` if `output_range_given` is True.
+// If not specified, defaults to 0
+func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr {
+ return func(m optionalAttr) {
+ m["given_y_max"] = value
+ }
+}
+
+// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value.
+//
+// value: A small float number to avoid dividing by 0.
+// If not specified, defaults to 1e-05
+func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr {
+ return func(m optionalAttr) {
+ m["variance_epsilon"] = value
+ }
+}
+
+// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value.
+//
+// value: Minimum value of `y_max - y_min`
+// If not specified, defaults to 0.001
+func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr {
+ return func(m optionalAttr) {
+ m["min_separation"] = value
+ }
+}
+
+// Quantized Instance normalization.
//
// Arguments:
-// grad: gradient propagated to the SparseSegmentSqrtN op.
-// indices: indices passed to the corresponding SparseSegmentSqrtN op.
-// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op.
-// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op.
-func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
+// x: A 4D input Tensor.
+// x_min: The value represented by the lowest quantized input.
+// x_max: The value represented by the highest quantized input.
+//
+// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output.
+func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentSqrtNGrad",
+ Type: "QuantizedInstanceNorm",
Input: []tf.Input{
- grad, indices, segment_ids, output_dim0,
+ x, x_min, x_max,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Returns the diagonal part of the tensor.
+//
+// This operation returns a tensor with the `diagonal` part
+// of the `input`. The `diagonal` part is computed as follows:
+//
+// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
+// tensor of rank `k` with dimensions `[D1,..., Dk]` where:
+//
+// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.
+//
+// For example:
+//
+// ```
+// # 'input' is [[1, 0, 0, 0]
+// [0, 2, 0, 0]
+// [0, 0, 3, 0]
+// [0, 0, 0, 4]]
+//
+// tf.diag_part(input) ==> [1, 2, 3, 4]
+// ```
+//
+// Arguments:
+// input: Rank k tensor where k is even and not zero.
+//
+// Returns The extracted diagonal.
+func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DiagPart",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Gives a guarantee to the TF runtime that the input tensor is a constant.
+//
+// The runtime is then free to make optimizations based on this.
+//
+// Only accepts value typed tensors as inputs and rejects resource variable handles
+// as input.
+//
+// Returns the input tensor without modification.
+func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "GuaranteeConst",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Splits a tensor into `num_split` tensors along one dimension.
+//
+// Arguments:
+// value: The tensor to split.
+// size_splits: list containing the sizes of each output tensor along the split
+// dimension. Must sum to the dimension of value along split_dim.
+// Can contain one -1 indicating that dimension is to be inferred.
+// axis: 0-D. The dimension along which to split. Must be in the range
+// `[-rank(value), rank(value))`.
+//
+//
+// Returns Tensors whose shape matches that of `value`
+// except along `axis`, where their sizes are
+// `size_splits[i]`.
+func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "SplitV",
+ Input: []tf.Input{
+ value, size_splits, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("SplitV", err)
+ return
+ }
+ return output
+}
+
+// Splits a tensor into `num_split` tensors along one dimension.
+//
+// Arguments:
+// axis: 0-D. The dimension along which to split. Must be in the range
+// `[-rank(value), rank(value))`.
+// value: The tensor to split.
+// num_split: The number of ways to split. Must evenly divide
+// `value.shape[split_dim]`.
+//
+// Returns They are identically shaped tensors, whose shape matches that of `value`
+// except along `axis`, where their sizes are
+// `values.shape[split_dim] / num_split`.
+func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "Split",
+ Input: []tf.Input{
+ axis, value,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ scope.UpdateErr("Split", err)
+ return
+ }
+ return output
+}
+
+// Concatenates tensors along one dimension.
+//
+// Arguments:
+// concat_dim: 0-D. The dimension along which to concatenate. Must be in the
+// range [0, rank(values)).
+// values: The `N` Tensors to concatenate. Their ranks and types must match,
+// and their sizes must match in all dimensions except `concat_dim`.
+//
+// Returns A `Tensor` with the concatenation of values stacked along the
+// `concat_dim` dimension. This tensor's shape matches that of `values` except
+// in `concat_dim` where it has the sum of the sizes.
+func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Concat",
+ Input: []tf.Input{
+ concat_dim, tf.OutputList(values),
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
+// Converts a flat index or array of flat indices into a tuple of
+//
+// coordinate arrays.
+//
+// @compatibility(numpy)
+// Equivalent to np.unravel_index
+// @end_compatibility
+//
+// Arguments:
+// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the
+// flattened version of an array of dimensions dims.
+// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling
+// indices.
+//
+// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the
+// same shape as the indices array.
+func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "UnravelIndex",
+ Input: []tf.Input{
+ indices, dims,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Subtracts `v` into specified rows of `x`.
+//
+// Computes y = x; y[i, :] -= v; return y.
+//
+// Arguments:
+// x: A `Tensor` of type T.
+// i: A vector. Indices into the left-most dimension of `x`.
+// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
+//
+// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
+func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InplaceSub",
+ Input: []tf.Input{
+ x, i, v,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Updates specified rows with values in `v`.
+//
+// Computes `x[i, :] = v; return x`.
+//
+// Arguments:
+// x: A tensor of type `T`.
+// i: A vector. Indices into the left-most dimension of `x`.
+// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
+//
+// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
+func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "InplaceUpdate",
+ Input: []tf.Input{
+ x, i, v,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Makes a copy of `x`.
+//
+// Arguments:
+// x: The source tensor of type `T`.
+//
+// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
+// is not an alias of `x`.
+func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DeepCopy",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// PackAttr is an optional argument to Pack.
+type PackAttr func(optionalAttr)
+
+// PackAxis sets the optional axis attribute to value.
+//
+// value: Dimension along which to pack. Negative values wrap around, so the
+// valid range is `[-(R+1), R+1)`.
+// If not specified, defaults to 0
+func PackAxis(value int64) PackAttr {
+ return func(m optionalAttr) {
+ m["axis"] = value
+ }
+}
+
+// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
+//
+// Packs the `N` tensors in `values` into a tensor with rank one higher than each
+// tensor in `values`, by packing them along the `axis` dimension.
+// Given a list of tensors of shape `(A, B, C)`;
+//
+// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
+// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
+// Etc.
+//
+// For example:
+//
+// ```
+// # 'x' is [1, 4]
+// # 'y' is [2, 5]
+// # 'z' is [3, 6]
+// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
+// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
+// ```
+//
+// This is the opposite of `unpack`.
+//
+// Arguments:
+// values: Must be of same shape and type.
+//
+// Returns The packed tensor.
+func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Pack",
+ Input: []tf.Input{
+ tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Concatenates a list of `N` tensors along the first dimension.
+//
+// The input tensors are all required to have size 1 in the first dimension.
+//
+// For example:
+//
+// ```
+// # 'x' is [[1, 4]]
+// # 'y' is [[2, 5]]
+// # 'z' is [[3, 6]]
+// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
+// ```
+//
+// The difference between concat and parallel_concat is that concat requires all
+// of the inputs be computed before the operation will begin but doesn't require
+// that the input shapes be known during graph construction. Parallel concat
+// will copy pieces of the input into the output as they become available, in
+// some situations this can provide a performance benefit.
+//
+// Arguments:
+// values: Tensors to be concatenated. All must have size 1 in the first dimension
+// and same shape.
+// shape: the final shape of the result; should be equal to the shapes of any input
+// but with the number of input values in the first dimension.
+//
+// Returns The concatenated tensor.
+func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "ParallelConcat",
+ Input: []tf.Input{
+ tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the mean along sparse segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
@@ -3708,24 +4141,6 @@ func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Returns x + y element-wise.
-//
-// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Add",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// NthElementAttr is an optional argument to NthElement.
type NthElementAttr func(optionalAttr)
@@ -3995,69 +4410,6 @@ func Digamma(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Shuffle dimensions of x according to a permutation.
-//
-// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
-// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
-func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Transpose",
- Input: []tf.Input{
- x, perm,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MinAttr is an optional argument to Min.
-type MinAttr func(optionalAttr)
-
-// MinKeepDims sets the optional keep_dims attribute to value.
-//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func MinKeepDims(value bool) MinAttr {
- return func(m optionalAttr) {
- m["keep_dims"] = value
- }
-}
-
-// Computes the minimum of elements across dimensions of a tensor.
-//
-// Reduces `input` along the dimensions given in `axis`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `axis`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
-//
-// Arguments:
-// input: The tensor to reduce.
-// axis: The dimensions to reduce. Must be in the range
-// `[-rank(input), rank(input))`.
-//
-// Returns The reduced tensor.
-func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Min",
- Input: []tf.Input{
- input, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter.
type Conv2DBackpropFilterAttr func(optionalAttr)
@@ -4532,6 +4884,24 @@ func MatrixInverse(scope *Scope, input tf.Output, optional ...MatrixInverseAttr)
return op.Output(0)
}
+// Returns x + y element-wise.
+//
+// *NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Add",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes square of x element-wise.
//
// I.e., \\(y = x * x = x^2\\).
@@ -5198,53 +5568,6 @@ func FloorDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Returns a batched diagonal tensor with a given batched diagonal values.
-//
-// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
-// everything else padded with zeros. The diagonal is computed as follows:
-//
-// Assume `diagonal` has `k` dimensions `[I, J, K, ..., N]`, then the output is a
-// tensor of rank `k+1` with dimensions [I, J, K, ..., N, N]` where:
-//
-// `output[i, j, k, ..., m, n] = 1{m=n} * diagonal[i, j, k, ..., n]`.
-//
-// For example:
-//
-// ```
-// # 'diagonal' is [[1, 2, 3, 4], [5, 6, 7, 8]]
-//
-// and diagonal.shape = (2, 4)
-//
-// tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0]
-// [0, 2, 0, 0]
-// [0, 0, 3, 0]
-// [0, 0, 0, 4]],
-// [[5, 0, 0, 0]
-// [0, 6, 0, 0]
-// [0, 0, 7, 0]
-// [0, 0, 0, 8]]]
-//
-// which has shape (2, 4, 4)
-// ```
-//
-// Arguments:
-// diagonal: Rank `k`, where `k >= 1`.
-//
-// Returns Rank `k+1`, with `output.shape = diagonal.shape + [diagonal.shape[-1]]`.
-func MatrixDiag(scope *Scope, diagonal tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatrixDiag",
- Input: []tf.Input{
- diagonal,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the inverse permutation of a tensor.
//
// This operation computes the inverse of an index permutation. It takes a 1-D
@@ -5945,68 +6268,92 @@ func MutexV2(scope *Scope, optional ...MutexV2Attr) (resource tf.Output) {
return op.Output(0)
}
-// AvgPool3DAttr is an optional argument to AvgPool3D.
-type AvgPool3DAttr func(optionalAttr)
-
-// AvgPool3DDataFormat sets the optional data_format attribute to value.
+// Returns element-wise remainder of division. This emulates C semantics in that
//
-// value: The data format of the input and output data. With the
-// default format "NDHWC", the data is stored in the order of:
-// [batch, in_depth, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCDHW", the data storage order is:
-// [batch, in_channels, in_depth, in_height, in_width].
-// If not specified, defaults to "NDHWC"
-func AvgPool3DDataFormat(value string) AvgPool3DAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
+// the result here is consistent with a truncating divide. E.g.
+// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`.
+//
+// *NOTE*: `Mod` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
}
+ opspec := tf.OpSpec{
+ Type: "Mod",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Performs 3D average pooling on the input.
+// Computes offsets of concat inputs within its output.
+//
+// For example:
+//
+// ```
+// # 'x' is [2, 2, 7]
+// # 'y' is [2, 3, 7]
+// # 'z' is [2, 5, 7]
+// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
+// ```
+//
+// This is typically used by gradient computations for a concat operation.
//
// Arguments:
-// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
-// ksize: 1-D tensor of length 5. The size of the window for each dimension of
-// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
-// strides: 1-D tensor of length 5. The stride of the sliding window for each
-// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
-// padding: The type of padding algorithm to use.
+// concat_dim: The dimension along which to concatenate.
+// shape: The `N` int32 vectors representing shape of tensors being concatenated.
//
-// Returns The average pooled output tensor.
-func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) {
+// Returns The `N` int32 vectors representing the starting offset
+// of input tensors within the concatenated output.
+func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "AvgPool3D",
+ Type: "ConcatOffset",
Input: []tf.Input{
- input,
+ concat_dim, tf.OutputList(shape),
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
+ scope.UpdateErr("ConcatOffset", err)
+ return
+ }
+ return offset
}
-// Returns element-wise remainder of division. This emulates C semantics in that
+// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
//
-// the result here is consistent with a truncating divide. E.g.
-// `tf.truncatediv(x, y) * y + truncate_mod(x, y) = x`.
+// The lower regularized incomplete Gamma function is defined as:
//
-// *NOTE*: `Mod` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+//
+// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
+//
+// where
+//
+// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+//
+// is the lower incomplete Gamma function.
+//
+// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
+// Gamma function.
+func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Mod",
+ Type: "Igamma",
Input: []tf.Input{
- x, y,
+ a, x,
},
}
op := scope.AddOperation(opspec)
@@ -6832,55 +7179,51 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output)
return op.Output(0)
}
-// Transforms a Tensor into a serialized TensorProto proto.
-//
-// Arguments:
-// tensor: A Tensor of type `T`.
+// Shuffle dimensions of x according to a permutation.
//
-// Returns A serialized TensorProto proto of the input tensor.
-func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) {
+// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
+// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
+func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "SerializeTensor",
+ Type: "Transpose",
Input: []tf.Input{
- tensor,
+ x, perm,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// MatrixSolveAttr is an optional argument to MatrixSolve.
-type MatrixSolveAttr func(optionalAttr)
+// MinAttr is an optional argument to Min.
+type MinAttr func(optionalAttr)
-// MatrixSolveAdjoint sets the optional adjoint attribute to value.
+// MinKeepDims sets the optional keep_dims attribute to value.
//
-// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
-// adjoint.
+// value: If true, retain reduced dimensions with length 1.
// If not specified, defaults to false
-func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
+func MinKeepDims(value bool) MinAttr {
return func(m optionalAttr) {
- m["adjoint"] = value
+ m["keep_dims"] = value
}
}
-// Solves systems of linear equations.
+// Computes the minimum of elements across dimensions of a tensor.
//
-// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
-// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
-// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
-// If `adjoint` is `True` then each output matrix satisfies
-// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
+// Reduces `input` along the dimensions given in `axis`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `axis`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
//
// Arguments:
-// matrix: Shape is `[..., M, M]`.
-// rhs: Shape is `[..., M, K]`.
+// input: The tensor to reduce.
+// axis: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
//
-// Returns Shape is `[..., M, K]`.
-func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
+// Returns The reduced tensor.
+func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -6889,9 +7232,9 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr
a(attrs)
}
opspec := tf.OpSpec{
- Type: "MatrixSolve",
+ Type: "Min",
Input: []tf.Input{
- matrix, rhs,
+ input, axis,
},
Attrs: attrs,
}
@@ -6899,6 +7242,26 @@ func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...Matr
return op.Output(0)
}
+// Transforms a Tensor into a serialized TensorProto proto.
+//
+// Arguments:
+// tensor: A Tensor of type `T`.
+//
+// Returns A serialized TensorProto proto of the input tensor.
+func SerializeTensor(scope *Scope, tensor tf.Output) (serialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SerializeTensor",
+ Input: []tf.Input{
+ tensor,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes acos of x element-wise.
func Acos(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -7434,6 +7797,154 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
+// Returns the element-wise sum of a list of tensors.
+//
+// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
+// wait for all of its inputs to be ready before beginning to sum. This can
+// save memory if inputs are ready at different times, since minimum temporary
+// storage is proportional to the output size rather than the inputs size.
+//
+// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+//
+// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+//
+// Arguments:
+// inputs: A list of `Tensor` objects, each with same shape and type.
+// shape: Shape of elements of `inputs`.
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "AccumulateNV2",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// RandomShuffleAttr is an optional argument to RandomShuffle.
+type RandomShuffleAttr func(optionalAttr)
+
+// RandomShuffleSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomShuffleSeed(value int64) RandomShuffleAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomShuffleSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomShuffleSeed2(value int64) RandomShuffleAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Randomly shuffles a tensor along its first dimension.
+//
+// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
+// to one and only one `output[i]`. For example, a mapping that might occur for a
+// 3x2 tensor is:
+//
+// ```
+// [[1, 2], [[5, 6],
+// [3, 4], ==> [1, 2],
+// [5, 6]] [3, 4]]
+// ```
+//
+// Arguments:
+// value: The tensor to be shuffled.
+//
+// Returns A tensor of same shape and type as `value`, shuffled along its first
+// dimension.
+func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomShuffle",
+ Input: []tf.Input{
+ value,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize.
+type OrderedMapIncompleteSizeAttr func(optionalAttr)
+
+// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr {
+ return func(m optionalAttr) {
+ m["memory_limit"] = value
+ }
+}
+
+// OrderedMapIncompleteSizeContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Op returns the number of incomplete elements in the underlying container.
+func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "OrderedMapIncompleteSize",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
@@ -7504,103 +8015,51 @@ func DepthwiseConv2dNativeBackpropFilter(scope *Scope, input tf.Output, filter_s
return op.Output(0)
}
-// LRNGradAttr is an optional argument to LRNGrad.
-type LRNGradAttr func(optionalAttr)
-
-// LRNGradDepthRadius sets the optional depth_radius attribute to value.
-//
-// value: A depth radius.
-// If not specified, defaults to 5
-func LRNGradDepthRadius(value int64) LRNGradAttr {
- return func(m optionalAttr) {
- m["depth_radius"] = value
- }
-}
-
-// LRNGradBias sets the optional bias attribute to value.
-//
-// value: An offset (usually > 0 to avoid dividing by 0).
-// If not specified, defaults to 1
-func LRNGradBias(value float32) LRNGradAttr {
- return func(m optionalAttr) {
- m["bias"] = value
- }
-}
-
-// LRNGradAlpha sets the optional alpha attribute to value.
-//
-// value: A scale factor, usually positive.
-// If not specified, defaults to 1
-func LRNGradAlpha(value float32) LRNGradAttr {
- return func(m optionalAttr) {
- m["alpha"] = value
- }
-}
-
-// LRNGradBeta sets the optional beta attribute to value.
+// Returns immutable tensor from memory region.
//
-// value: An exponent.
-// If not specified, defaults to 0.5
-func LRNGradBeta(value float32) LRNGradAttr {
- return func(m optionalAttr) {
- m["beta"] = value
- }
-}
-
-// Gradients for Local Response Normalization.
+// The current implementation memmaps the tensor from a file.
//
// Arguments:
-// input_grads: 4-D with shape `[batch, height, width, channels]`.
-// input_image: 4-D with shape `[batch, height, width, channels]`.
-// output_image: 4-D with shape `[batch, height, width, channels]`.
-//
-// Returns The gradients for LRN.
-func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) {
+// dtype: Type of the returned tensor.
+// shape: Shape of the returned tensor.
+// memory_region_name: Name of readonly memory region used by the tensor, see
+// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
+func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
+ attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
opspec := tf.OpSpec{
- Type: "LRNGrad",
- Input: []tf.Input{
- input_grads, input_image, output_image,
- },
+ Type: "ImmutableConst",
+
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// AnyAttr is an optional argument to Any.
-type AnyAttr func(optionalAttr)
+// StringJoinAttr is an optional argument to StringJoin.
+type StringJoinAttr func(optionalAttr)
-// AnyKeepDims sets the optional keep_dims attribute to value.
+// StringJoinSeparator sets the optional separator attribute to value.
//
-// value: If true, retain reduced dimensions with length 1.
-// If not specified, defaults to false
-func AnyKeepDims(value bool) AnyAttr {
+// value: string, an optional join separator.
+// If not specified, defaults to ""
+func StringJoinSeparator(value string) StringJoinAttr {
return func(m optionalAttr) {
- m["keep_dims"] = value
+ m["separator"] = value
}
}
-// Computes the "logical or" of elements across dimensions of a tensor.
+// Joins the strings in the given list of string tensors into one tensor;
//
-// Reduces `input` along the dimensions given in `axis`. Unless
-// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
-// `axis`. If `keep_dims` is true, the reduced dimensions are
-// retained with length 1.
+// with the given separator (default is an empty separator).
//
// Arguments:
-// input: The tensor to reduce.
-// axis: The dimensions to reduce. Must be in the range
-// `[-rank(input), rank(input))`.
-//
-// Returns The reduced tensor.
-func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) {
+// inputs: A list of string tensors. The tensors must all have the same shape,
+// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
+// of non-scalar inputs.
+func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -7609,9 +8068,9 @@ func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (ou
a(attrs)
}
opspec := tf.OpSpec{
- Type: "Any",
+ Type: "StringJoin",
Input: []tf.Input{
- input, axis,
+ tf.OutputList(inputs),
},
Attrs: attrs,
}
@@ -7805,27 +8264,6 @@ func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_
return op.Output(0)
}
-// Makes a copy of `x`.
-//
-// Arguments:
-// x: The source tensor of type `T`.
-//
-// Returns y: A `Tensor` of type `T`. A copy of `x`. Guaranteed that `y`
-// is not an alias of `x`.
-func DeepCopy(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DeepCopy",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -7999,6 +8437,83 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe
return op.Output(0)
}
+// Converts each string in the input Tensor to its hash mod by a number of buckets.
+//
+// The hash function is deterministic on the content of the string within the
+// process.
+//
+// Note that the hash function may change from time to time.
+// This functionality will be deprecated and it's recommended to use
+// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`.
+//
+// Arguments:
+//
+// num_buckets: The number of buckets.
+//
+// Returns A Tensor of the same shape as the input `string_tensor`.
+func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_buckets": num_buckets}
+ opspec := tf.OpSpec{
+ Type: "StringToHashBucket",
+ Input: []tf.Input{
+ string_tensor,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes gradients for the exponential linear (Elu) operation.
+//
+// Arguments:
+// gradients: The backpropagated gradients to the corresponding Elu operation.
+// outputs: The outputs of the corresponding Elu operation.
+//
+// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0,
+// `gradients` otherwise.
+func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "EluGrad",
+ Input: []tf.Input{
+ gradients, outputs,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that contains `count` elements from the `input_dataset`.
+//
+// Arguments:
+//
+// count: A scalar representing the number of elements from the `input_dataset`
+// that should be taken. A value of `-1` indicates that all of `input_dataset`
+// is taken.
+//
+//
+func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "TakeDataset",
+ Input: []tf.Input{
+ input_dataset, count,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Reads the value of a variable.
//
// The tensor returned by this operation is immutable.
@@ -8082,6 +8597,248 @@ func BoostedTreesUpdateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, fe
return scope.AddOperation(opspec)
}
+// EncodeJpegAttr is an optional argument to EncodeJpeg.
+type EncodeJpegAttr func(optionalAttr)
+
+// EncodeJpegFormat sets the optional format attribute to value.
+//
+// value: Per pixel image format.
+// If not specified, defaults to ""
+func EncodeJpegFormat(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["format"] = value
+ }
+}
+
+// EncodeJpegQuality sets the optional quality attribute to value.
+//
+// value: Quality of the compression from 0 to 100 (higher is better and slower).
+// If not specified, defaults to 95
+func EncodeJpegQuality(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["quality"] = value
+ }
+}
+
+// EncodeJpegProgressive sets the optional progressive attribute to value.
+//
+// value: If True, create a JPEG that loads progressively (coarse to fine).
+// If not specified, defaults to false
+func EncodeJpegProgressive(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["progressive"] = value
+ }
+}
+
+// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value.
+//
+// value: If True, spend CPU/RAM to reduce size with no quality change.
+// If not specified, defaults to false
+func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["optimize_size"] = value
+ }
+}
+
+// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value.
+//
+// value: See http://en.wikipedia.org/wiki/Chroma_subsampling.
+// If not specified, defaults to true
+func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["chroma_downsampling"] = value
+ }
+}
+
+// EncodeJpegDensityUnit sets the optional density_unit attribute to value.
+//
+// value: Unit used to specify `x_density` and `y_density`:
+// pixels per inch (`'in'`) or centimeter (`'cm'`).
+// If not specified, defaults to "in"
+func EncodeJpegDensityUnit(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["density_unit"] = value
+ }
+}
+
+// EncodeJpegXDensity sets the optional x_density attribute to value.
+//
+// value: Horizontal pixels per density unit.
+// If not specified, defaults to 300
+func EncodeJpegXDensity(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["x_density"] = value
+ }
+}
+
+// EncodeJpegYDensity sets the optional y_density attribute to value.
+//
+// value: Vertical pixels per density unit.
+// If not specified, defaults to 300
+func EncodeJpegYDensity(value int64) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["y_density"] = value
+ }
+}
+
+// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value.
+//
+// value: If not empty, embed this XMP metadata in the image header.
+// If not specified, defaults to ""
+func EncodeJpegXmpMetadata(value string) EncodeJpegAttr {
+ return func(m optionalAttr) {
+ m["xmp_metadata"] = value
+ }
+}
+
+// JPEG-encode an image.
+//
+// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
+//
+// The attr `format` can be used to override the color format of the encoded
+// output. Values can be:
+//
+// * `''`: Use a default format based on the number of channels in the image.
+// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension
+// of `image` must be 1.
+// * `rgb`: Output an RGB JPEG image. The `channels` dimension
+// of `image` must be 3.
+//
+// If `format` is not specified or is the empty string, a default format is picked
+// in function of the number of channels in `image`:
+//
+// * 1: Output a grayscale image.
+// * 3: Output an RGB image.
+//
+// Arguments:
+// image: 3-D with shape `[height, width, channels]`.
+//
+// Returns 0-D. JPEG-encoded image.
+func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeJpeg",
+ Input: []tf.Input{
+ image,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MultinomialAttr is an optional argument to Multinomial.
+type MultinomialAttr func(optionalAttr)
+
+// MultinomialSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 is set to be non-zero, the internal random number
+// generator is seeded by the given seed. Otherwise, a random seed is used.
+// If not specified, defaults to 0
+func MultinomialSeed(value int64) MultinomialAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// MultinomialSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func MultinomialSeed2(value int64) MultinomialAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// MultinomialOutputDtype sets the optional output_dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func MultinomialOutputDtype(value tf.DataType) MultinomialAttr {
+ return func(m optionalAttr) {
+ m["output_dtype"] = value
+ }
+}
+
+// Draws samples from a multinomial distribution.
+//
+// Arguments:
+// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]`
+// represents the unnormalized log probabilities for all classes.
+// num_samples: 0-D. Number of independent samples to draw for each row slice.
+//
+// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]`
+// contains the drawn class labels with range `[0, num_classes)`.
+func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Multinomial",
+ Input: []tf.Input{
+ logits, num_samples,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA.
+type ResourceSparseApplyAdagradDAAttr func(optionalAttr)
+
+// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value.
+//
+// value: If True, updating of the var and accum tensors will be protected by
+// a lock; otherwise the behavior is undefined, but may exhibit less contention.
+// If not specified, defaults to false
+func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update entries in '*var' and '*accum' according to the proximal adagrad scheme.
+//
+// Arguments:
+// var_: Should be from a Variable().
+// gradient_accumulator: Should be from a Variable().
+// gradient_squared_accumulator: Should be from a Variable().
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var and accum.
+// lr: Learning rate. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// global_step: Training step number. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyAdagradDA",
+ Input: []tf.Input{
+ var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// ResourceSparseApplyFtrlAttr is an optional argument to ResourceSparseApplyFtrl.
type ResourceSparseApplyFtrlAttr func(optionalAttr)
@@ -8303,95 +9060,6 @@ func RestoreSlice(scope *Scope, file_pattern tf.Output, tensor_name tf.Output, s
return op.Output(0)
}
-// ImagAttr is an optional argument to Imag.
-type ImagAttr func(optionalAttr)
-
-// ImagTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_FLOAT
-func ImagTout(value tf.DataType) ImagAttr {
- return func(m optionalAttr) {
- m["Tout"] = value
- }
-}
-
-// Returns the imaginary part of a complex number.
-//
-// Given a tensor `input` of complex numbers, this operation returns a tensor of
-// type `float` that is the imaginary part of each element in `input`. All
-// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
-// is the real part and *b* is the imaginary part returned by this operation.
-//
-// For example:
-//
-// ```
-// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-// tf.imag(input) ==> [4.75, 5.75]
-// ```
-func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Imag",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ComplexAttr is an optional argument to Complex.
-type ComplexAttr func(optionalAttr)
-
-// ComplexTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_COMPLEX64
-func ComplexTout(value tf.DataType) ComplexAttr {
- return func(m optionalAttr) {
- m["Tout"] = value
- }
-}
-
-// Converts two real numbers to a complex number.
-//
-// Given a tensor `real` representing the real part of a complex number, and a
-// tensor `imag` representing the imaginary part of a complex number, this
-// operation returns complex numbers elementwise of the form \\(a + bj\\), where
-// *a* represents the `real` part and *b* represents the `imag` part.
-//
-// The input tensors `real` and `imag` must have the same shape.
-//
-// For example:
-//
-// ```
-// # tensor 'real' is [2.25, 3.25]
-// # tensor `imag` is [4.75, 5.75]
-// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
-// ```
-func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Complex",
- Input: []tf.Input{
- real, imag,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Divides sparse updates into the variable referenced by `resource`.
//
// This operation computes
@@ -8433,6 +9101,23 @@ func ResourceScatterDiv(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
+ opspec := tf.OpSpec{
+ Type: "CollectiveReduce",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StatelessRandomNormalAttr is an optional argument to StatelessRandomNormal.
type StatelessRandomNormalAttr func(optionalAttr)
@@ -8476,6 +9161,186 @@ func StatelessRandomNormal(scope *Scope, shape tf.Output, seed tf.Output, option
return op.Output(0)
}
+// MaxPoolAttr is an optional argument to MaxPool.
+type MaxPoolAttr func(optionalAttr)
+
+// MaxPoolDataFormat sets the optional data_format attribute to value.
+//
+// value: Specify the data format of the input and output data. With the
+// default format "NHWC", the data is stored in the order of:
+// [batch, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCHW", the data storage order of:
+// [batch, in_channels, in_height, in_width].
+// If not specified, defaults to "NHWC"
+func MaxPoolDataFormat(value string) MaxPoolAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Performs max pooling on the input.
+//
+// Arguments:
+// input: 4-D input to pool over.
+// ksize: The size of the window for each dimension of the input tensor.
+// strides: The stride of the sliding window for each dimension of the
+// input tensor.
+// padding: The type of padding algorithm to use.
+//
+// Returns The max pooled output tensor.
+func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MaxPool",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// SparseMatMulAttr is an optional argument to SparseMatMul.
+type SparseMatMulAttr func(optionalAttr)
+
+// SparseMatMulTransposeA sets the optional transpose_a attribute to value.
+// If not specified, defaults to false
+func SparseMatMulTransposeA(value bool) SparseMatMulAttr {
+ return func(m optionalAttr) {
+ m["transpose_a"] = value
+ }
+}
+
+// SparseMatMulTransposeB sets the optional transpose_b attribute to value.
+// If not specified, defaults to false
+func SparseMatMulTransposeB(value bool) SparseMatMulAttr {
+ return func(m optionalAttr) {
+ m["transpose_b"] = value
+ }
+}
+
+// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value.
+// If not specified, defaults to false
+func SparseMatMulAIsSparse(value bool) SparseMatMulAttr {
+ return func(m optionalAttr) {
+ m["a_is_sparse"] = value
+ }
+}
+
+// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value.
+// If not specified, defaults to false
+func SparseMatMulBIsSparse(value bool) SparseMatMulAttr {
+ return func(m optionalAttr) {
+ m["b_is_sparse"] = value
+ }
+}
+
+// Multiply matrix "a" by matrix "b".
+//
+// The inputs must be two-dimensional matrices and the inner dimension of "a" must
+// match the outer dimension of "b". This op is optimized for the case where at
+// least one of "a" or "b" is sparse. The breakeven for using this versus a dense
+// matrix multiply on one platform was 30% zero values in the sparse matrix.
+//
+// The gradient computation of this operation will only take advantage of sparsity
+// in the input gradient when that gradient comes from a Relu.
+func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseMatMul",
+ Input: []tf.Input{
+ a, b,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Concatenates quantized tensors along one dimension.
+//
+// Arguments:
+// concat_dim: 0-D. The dimension along which to concatenate. Must be in the
+// range [0, rank(values)).
+// values: The `N` Tensors to concatenate. Their ranks and types must match,
+// and their sizes must match in all dimensions except `concat_dim`.
+// input_mins: The minimum scalar values for each of the input tensors.
+// input_maxes: The maximum scalar values for each of the input tensors.
+//
+// Returns A `Tensor` with the concatenation of values stacked along the
+// `concat_dim` dimension. This tensor's shape matches that of `values` except
+// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents.
+func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "QuantizedConcat",
+ Input: []tf.Input{
+ concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Slice a `SparseTensor` based on the `start` and `size`.
+//
+// For example, if the input is
+//
+// input_tensor = shape = [2, 7]
+// [ a d e ]
+// [b c ]
+//
+// Graphically the output tensors are:
+//
+// sparse_slice([0, 0], [2, 4]) = shape = [2, 4]
+// [ a ]
+// [b c ]
+//
+// sparse_slice([0, 4], [2, 3]) = shape = [2, 3]
+// [ d e ]
+// [ ]
+//
+// Arguments:
+// indices: 2-D tensor represents the indices of the sparse tensor.
+// values: 1-D tensor represents the values of the sparse tensor.
+// shape: 1-D. tensor represents the shape of the sparse tensor.
+// start: 1-D. tensor represents the start of the slice.
+// size: 1-D. tensor represents the size of the slice.
+// output indices: A list of 1-D tensors represents the indices of the output
+// sparse tensors.
+//
+// Returns A list of 1-D tensors represents the values of the output sparse
+// tensors.A list of 1-D tensors represents the shape of the output sparse
+// tensors.
+func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSlice",
+ Input: []tf.Input{
+ indices, values, shape, start, size,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// Reduces sparse updates into the variable referenced by `resource` using the `min` operation.
//
// This operation computes
@@ -9258,83 +10123,6 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
return op.Output(0)
}
-// Converts each string in the input Tensor to its hash mod by a number of buckets.
-//
-// The hash function is deterministic on the content of the string within the
-// process.
-//
-// Note that the hash function may change from time to time.
-// This functionality will be deprecated and it's recommended to use
-// `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`.
-//
-// Arguments:
-//
-// num_buckets: The number of buckets.
-//
-// Returns A Tensor of the same shape as the input `string_tensor`.
-func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "StringToHashBucket",
- Input: []tf.Input{
- string_tensor,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes gradients for the exponential linear (Elu) operation.
-//
-// Arguments:
-// gradients: The backpropagated gradients to the corresponding Elu operation.
-// outputs: The outputs of the corresponding Elu operation.
-//
-// Returns The gradients: `gradients * (outputs + 1)` if outputs < 0,
-// `gradients` otherwise.
-func EluGrad(scope *Scope, gradients tf.Output, outputs tf.Output) (backprops tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "EluGrad",
- Input: []tf.Input{
- gradients, outputs,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Creates a dataset that contains `count` elements from the `input_dataset`.
-//
-// Arguments:
-//
-// count: A scalar representing the number of elements from the `input_dataset`
-// that should be taken. A value of `-1` indicates that all of `input_dataset`
-// is taken.
-//
-//
-func TakeDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
- opspec := tf.OpSpec{
- Type: "TakeDataset",
- Input: []tf.Input{
- input_dataset, count,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// The gradient operator for the SparseAdd op.
//
// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
@@ -10541,79 +11329,6 @@ func IFFT(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
-// Generates values in an interval.
-//
-// A sequence of `num` evenly-spaced values are generated beginning at `start`.
-// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`,
-// so that the last one is exactly `stop`.
-//
-// For example:
-//
-// ```
-// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
-// ```
-//
-// Arguments:
-// start: First entry in the range.
-// stop: Last entry in the range.
-// num: Number of values to generate.
-//
-// Returns 1-D. The generated values.
-func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LinSpace",
- Input: []tf.Input{
- start, stop, num,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
-type DestroyResourceOpAttr func(optionalAttr)
-
-// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value.
-//
-// value: whether to ignore the error when the resource
-// doesn't exist.
-// If not specified, defaults to true
-func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr {
- return func(m optionalAttr) {
- m["ignore_lookup_error"] = value
- }
-}
-
-// Deletes the resource specified by the handle.
-//
-// All subsequent operations using the resource will result in a NotFound
-// error status.
-//
-// Arguments:
-// resource: handle to the resource to delete.
-//
-// Returns the created operation.
-func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DestroyResourceOp",
- Input: []tf.Input{
- resource,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
type ResourceSparseApplyRMSPropAttr func(optionalAttr)
@@ -10742,7 +11457,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
-// supplied image within this range.
+// supplied image within in this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
@@ -11151,63 +11866,6 @@ func DeserializeSparse(scope *Scope, serialized_sparse tf.Output, dtype tf.DataT
return op.Output(0), op.Output(1), op.Output(2)
}
-// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp.
-type ResourceApplyRMSPropAttr func(optionalAttr)
-
-// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the RMSProp algorithm.
-//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
-//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
-//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
-//
-// Arguments:
-// var_: Should be from a Variable().
-// ms: Should be from a Variable().
-// mom: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// rho: Decay rate. Must be a scalar.
-//
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-//
-// Returns the created operation.
-func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyRMSProp",
- Input: []tf.Input{
- var_, ms, mom, lr, rho, momentum, epsilon, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
// ResourceScatterNdUpdateAttr is an optional argument to ResourceScatterNdUpdate.
type ResourceScatterNdUpdateAttr func(optionalAttr)
@@ -11507,60 +12165,6 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// ResizeAreaAttr is an optional argument to ResizeArea.
-type ResizeAreaAttr func(optionalAttr)
-
-// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
-//
-// value: If true, the centers of the 4 corner pixels of the input and output tensors are
-// aligned, preserving the values at the corner pixels. Defaults to false.
-// If not specified, defaults to false
-func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
- return func(m optionalAttr) {
- m["align_corners"] = value
- }
-}
-
-// Resize `images` to `size` using area interpolation.
-//
-// Input images can be of different types but output images are always float.
-//
-// The range of pixel values for the output image might be slightly different
-// from the range for the input image because of limited numerical precision.
-// To guarantee an output range, for example `[0.0, 1.0]`, apply
-// `tf.clip_by_value` to the output.
-//
-// Each output pixel is computed by first transforming the pixel's footprint into
-// the input tensor and then averaging the pixels that intersect the footprint. An
-// input pixel's contribution to the average is weighted by the fraction of its
-// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
-//
-// Arguments:
-// images: 4-D with shape `[batch, height, width, channels]`.
-// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
-//
-// Returns 4-D with shape
-// `[batch, new_height, new_width, channels]`.
-func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResizeArea",
- Input: []tf.Input{
- images, size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// 2D real-valued fast Fourier transform.
//
// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
@@ -11736,23 +12340,6 @@ func TensorArrayGatherV3(scope *Scope, handle tf.Output, indices tf.Output, flow
return op.Output(0)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
- opspec := tf.OpSpec{
- Type: "CollectiveReduce",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// This op consumes a lock created by `MutexLock`.
//
// This op exists to consume a tensor created by `MutexLock` (other than
@@ -11854,81 +12441,6 @@ func RestoreV2(scope *Scope, prefix tf.Output, tensor_names tf.Output, shape_and
return tensors
}
-// Creates a dataset that skips `count` elements from the `input_dataset`.
-//
-// Arguments:
-//
-// count: A scalar representing the number of elements from the `input_dataset`
-// that should be skipped. If count is -1, skips everything.
-//
-//
-func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
- opspec := tf.OpSpec{
- Type: "SkipDataset",
- Input: []tf.Input{
- input_dataset, count,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the maximum along segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
-//
-// Computes a tensor such that
-// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the max is empty for a given segment ID `i`, `output[i] = 0`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SegmentMax",
- Input: []tf.Input{
- data, segment_ids,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes hyperbolic tangent of `x` element-wise.
-func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tanh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Receives a tensor value broadcast from another device.
func CollectiveBcastRecv(scope *Scope, T tf.DataType, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
if scope.Err() != nil {
@@ -12386,248 +12898,6 @@ func MaxPoolWithArgmax(scope *Scope, input tf.Output, ksize []int64, strides []i
return op.Output(0), op.Output(1)
}
-// ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA.
-type ResourceSparseApplyAdagradDAAttr func(optionalAttr)
-
-// ResourceSparseApplyAdagradDAUseLocking sets the optional use_locking attribute to value.
-//
-// value: If True, updating of the var and accum tensors will be protected by
-// a lock; otherwise the behavior is undefined, but may exhibit less contention.
-// If not specified, defaults to false
-func ResourceSparseApplyAdagradDAUseLocking(value bool) ResourceSparseApplyAdagradDAAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update entries in '*var' and '*accum' according to the proximal adagrad scheme.
-//
-// Arguments:
-// var_: Should be from a Variable().
-// gradient_accumulator: Should be from a Variable().
-// gradient_squared_accumulator: Should be from a Variable().
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var and accum.
-// lr: Learning rate. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 regularization. Must be a scalar.
-// global_step: Training step number. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceSparseApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator tf.Output, gradient_squared_accumulator tf.Output, grad tf.Output, indices tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, global_step tf.Output, optional ...ResourceSparseApplyAdagradDAAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceSparseApplyAdagradDA",
- Input: []tf.Input{
- var_, gradient_accumulator, gradient_squared_accumulator, grad, indices, lr, l1, l2, global_step,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// EncodeJpegAttr is an optional argument to EncodeJpeg.
-type EncodeJpegAttr func(optionalAttr)
-
-// EncodeJpegFormat sets the optional format attribute to value.
-//
-// value: Per pixel image format.
-// If not specified, defaults to ""
-func EncodeJpegFormat(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["format"] = value
- }
-}
-
-// EncodeJpegQuality sets the optional quality attribute to value.
-//
-// value: Quality of the compression from 0 to 100 (higher is better and slower).
-// If not specified, defaults to 95
-func EncodeJpegQuality(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["quality"] = value
- }
-}
-
-// EncodeJpegProgressive sets the optional progressive attribute to value.
-//
-// value: If True, create a JPEG that loads progressively (coarse to fine).
-// If not specified, defaults to false
-func EncodeJpegProgressive(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["progressive"] = value
- }
-}
-
-// EncodeJpegOptimizeSize sets the optional optimize_size attribute to value.
-//
-// value: If True, spend CPU/RAM to reduce size with no quality change.
-// If not specified, defaults to false
-func EncodeJpegOptimizeSize(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["optimize_size"] = value
- }
-}
-
-// EncodeJpegChromaDownsampling sets the optional chroma_downsampling attribute to value.
-//
-// value: See http://en.wikipedia.org/wiki/Chroma_subsampling.
-// If not specified, defaults to true
-func EncodeJpegChromaDownsampling(value bool) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["chroma_downsampling"] = value
- }
-}
-
-// EncodeJpegDensityUnit sets the optional density_unit attribute to value.
-//
-// value: Unit used to specify `x_density` and `y_density`:
-// pixels per inch (`'in'`) or centimeter (`'cm'`).
-// If not specified, defaults to "in"
-func EncodeJpegDensityUnit(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["density_unit"] = value
- }
-}
-
-// EncodeJpegXDensity sets the optional x_density attribute to value.
-//
-// value: Horizontal pixels per density unit.
-// If not specified, defaults to 300
-func EncodeJpegXDensity(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["x_density"] = value
- }
-}
-
-// EncodeJpegYDensity sets the optional y_density attribute to value.
-//
-// value: Vertical pixels per density unit.
-// If not specified, defaults to 300
-func EncodeJpegYDensity(value int64) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["y_density"] = value
- }
-}
-
-// EncodeJpegXmpMetadata sets the optional xmp_metadata attribute to value.
-//
-// value: If not empty, embed this XMP metadata in the image header.
-// If not specified, defaults to ""
-func EncodeJpegXmpMetadata(value string) EncodeJpegAttr {
- return func(m optionalAttr) {
- m["xmp_metadata"] = value
- }
-}
-
-// JPEG-encode an image.
-//
-// `image` is a 3-D uint8 Tensor of shape `[height, width, channels]`.
-//
-// The attr `format` can be used to override the color format of the encoded
-// output. Values can be:
-//
-// * `''`: Use a default format based on the number of channels in the image.
-// * `grayscale`: Output a grayscale JPEG image. The `channels` dimension
-// of `image` must be 1.
-// * `rgb`: Output an RGB JPEG image. The `channels` dimension
-// of `image` must be 3.
-//
-// If `format` is not specified or is the empty string, a default format is picked
-// in function of the number of channels in `image`:
-//
-// * 1: Output a grayscale image.
-// * 3: Output an RGB image.
-//
-// Arguments:
-// image: 3-D with shape `[height, width, channels]`.
-//
-// Returns 0-D. JPEG-encoded image.
-func EncodeJpeg(scope *Scope, image tf.Output, optional ...EncodeJpegAttr) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeJpeg",
- Input: []tf.Input{
- image,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// MultinomialAttr is an optional argument to Multinomial.
-type MultinomialAttr func(optionalAttr)
-
-// MultinomialSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 is set to be non-zero, the internal random number
-// generator is seeded by the given seed. Otherwise, a random seed is used.
-// If not specified, defaults to 0
-func MultinomialSeed(value int64) MultinomialAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// MultinomialSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func MultinomialSeed2(value int64) MultinomialAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// MultinomialOutputDtype sets the optional output_dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func MultinomialOutputDtype(value tf.DataType) MultinomialAttr {
- return func(m optionalAttr) {
- m["output_dtype"] = value
- }
-}
-
-// Draws samples from a multinomial distribution.
-//
-// Arguments:
-// logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice `[i, :]`
-// represents the unnormalized log probabilities for all classes.
-// num_samples: 0-D. Number of independent samples to draw for each row slice.
-//
-// Returns 2-D Tensor with shape `[batch_size, num_samples]`. Each slice `[i, :]`
-// contains the drawn class labels with range `[0, num_classes)`.
-func Multinomial(scope *Scope, logits tf.Output, num_samples tf.Output, optional ...MultinomialAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Multinomial",
- Input: []tf.Input{
- logits, num_samples,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the truth value of NOT x element-wise.
func LogicalNot(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -12990,122 +13260,6 @@ func Conv3DBackpropInput(scope *Scope, input tf.Output, filter tf.Output, out_ba
return op.Output(0)
}
-// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad.
-type ResourceApplyProximalAdagradAttr func(optionalAttr)
-
-// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value.
-//
-// value: If True, updating of the var and accum tensors will be protected by
-// a lock; otherwise the behavior is undefined, but may exhibit less contention.
-// If not specified, defaults to false
-func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
-//
-// accum += grad * grad
-// prox_v = var - lr * grad * (1 / sqrt(accum))
-// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 regularization. Must be a scalar.
-// grad: The gradient.
-//
-// Returns the created operation.
-func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyProximalAdagrad",
- Input: []tf.Input{
- var_, accum, lr, l1, l2, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
-type MutableHashTableOfTensorsV2Attr func(optionalAttr)
-
-// MutableHashTableOfTensorsV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this table is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this table is shared under the given name across
-// multiple sessions.
-// If not specified, defaults to ""
-func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
-// If not specified, defaults to false
-func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["use_node_name_sharing"] = value
- }
-}
-
-// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value.
-// If not specified, defaults to <>
-func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr {
- return func(m optionalAttr) {
- m["value_shape"] = value
- }
-}
-
-// Creates an empty hash table.
-//
-// This op creates a mutable hash table, specifying the type of its keys and
-// values. Each value must be a vector. Data can be inserted into the table using
-// the insert operations. It does not support the initialization operation.
-//
-// Arguments:
-// key_dtype: Type of the table keys.
-// value_dtype: Type of the table values.
-//
-// Returns Handle to a table.
-func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MutableHashTableOfTensorsV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Subtracts sparse updates from the variable referenced by `resource`.
//
// This operation computes
@@ -13147,62 +13301,6 @@ func ResourceScatterSub(scope *Scope, resource tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
-// Inverse 2D fast Fourier transform.
-//
-// Computes the inverse 2-dimensional discrete Fourier transform over the
-// inner-most 2 dimensions of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most 2
-// dimensions of `input` are replaced with their inverse 2D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.ifft2
-// @end_compatibility
-func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IFFT2D",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// 2D fast Fourier transform.
-//
-// Computes the 2-dimensional discrete Fourier transform over the inner-most
-// 2 dimensions of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most 2
-// dimensions of `input` are replaced with their 2D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.fft2
-// @end_compatibility
-func FFT2D(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "FFT2D",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceApplyProximalGradientDescentAttr is an optional argument to ResourceApplyProximalGradientDescent.
type ResourceApplyProximalGradientDescentAttr func(optionalAttr)
@@ -13871,31 +13969,101 @@ func ResourceApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Output, ms
return scope.AddOperation(opspec)
}
-// RealAttr is an optional argument to Real.
-type RealAttr func(optionalAttr)
+// Computes the gradient for the inverse of `x` wrt its input.
+//
+// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
+// is the corresponding input gradient.
+func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReciprocalGrad",
+ Input: []tf.Input{
+ y, dy,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
-// RealTout sets the optional Tout attribute to value.
-// If not specified, defaults to DT_FLOAT
-func RealTout(value tf.DataType) RealAttr {
+// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
+//
+// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Minimum",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// MfccAttr is an optional argument to Mfcc.
+type MfccAttr func(optionalAttr)
+
+// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
+//
+// value: The highest frequency to use when calculating the
+// ceptstrum.
+// If not specified, defaults to 4000
+func MfccUpperFrequencyLimit(value float32) MfccAttr {
return func(m optionalAttr) {
- m["Tout"] = value
+ m["upper_frequency_limit"] = value
}
}
-// Returns the real part of a complex number.
+// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
//
-// Given a tensor `input` of complex numbers, this operation returns a tensor of
-// type `float` that is the real part of each element in `input`. All elements in
-// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
-// part returned by this operation and *b* is the imaginary part.
+// value: The lowest frequency to use when calculating the
+// ceptstrum.
+// If not specified, defaults to 20
+func MfccLowerFrequencyLimit(value float32) MfccAttr {
+ return func(m optionalAttr) {
+ m["lower_frequency_limit"] = value
+ }
+}
+
+// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
//
-// For example:
+// value: Resolution of the Mel bank used internally.
+// If not specified, defaults to 40
+func MfccFilterbankChannelCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["filterbank_channel_count"] = value
+ }
+}
+
+// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
//
-// ```
-// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-// tf.real(input) ==> [-2.25, 3.25]
-// ```
-func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
+// value: How many output channels to produce per time slice.
+// If not specified, defaults to 13
+func MfccDctCoefficientCount(value int64) MfccAttr {
+ return func(m optionalAttr) {
+ m["dct_coefficient_count"] = value
+ }
+}
+
+// Transforms a spectrogram into a form that's useful for speech recognition.
+//
+// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+// been effective as an input feature for machine learning. They are created by
+// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+// higher frequencies that are less significant to the human ear. They have a long
+// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+// is a good resource to learn more.
+//
+// Arguments:
+// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+// set to true.
+// sample_rate: How many samples per second the source audio used.
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -13904,9 +14072,9 @@ func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output
a(attrs)
}
opspec := tf.OpSpec{
- Type: "Real",
+ Type: "Mfcc",
Input: []tf.Input{
- input,
+ spectrogram, sample_rate,
},
Attrs: attrs,
}
@@ -14332,65 +14500,6 @@ func TensorArraySplitV2(scope *Scope, handle tf.Output, value tf.Output, lengths
return op.Output(0)
}
-// PackAttr is an optional argument to Pack.
-type PackAttr func(optionalAttr)
-
-// PackAxis sets the optional axis attribute to value.
-//
-// value: Dimension along which to pack. Negative values wrap around, so the
-// valid range is `[-(R+1), R+1)`.
-// If not specified, defaults to 0
-func PackAxis(value int64) PackAttr {
- return func(m optionalAttr) {
- m["axis"] = value
- }
-}
-
-// Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor.
-//
-// Packs the `N` tensors in `values` into a tensor with rank one higher than each
-// tensor in `values`, by packing them along the `axis` dimension.
-// Given a list of tensors of shape `(A, B, C)`;
-//
-// if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
-// if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
-// Etc.
-//
-// For example:
-//
-// ```
-// # 'x' is [1, 4]
-// # 'y' is [2, 5]
-// # 'z' is [3, 6]
-// pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
-// pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
-// ```
-//
-// This is the opposite of `unpack`.
-//
-// Arguments:
-// values: Must be of same shape and type.
-//
-// Returns The packed tensor.
-func Pack(scope *Scope, values []tf.Output, optional ...PackAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Pack",
- Input: []tf.Input{
- tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Reorders a SparseTensor into the canonical, row-major ordering.
//
// Note that by convention, all sparse ops preserve the canonical ordering along
@@ -15050,30 +15159,6 @@ func Sigmoid(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Updates specified rows with values in `v`.
-//
-// Computes `x[i, :] = v; return x`.
-//
-// Arguments:
-// x: A tensor of type `T`.
-// i: A vector. Indices into the left-most dimension of `x`.
-// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
-//
-// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
-func InplaceUpdate(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InplaceUpdate",
- Input: []tf.Input{
- x, i, v,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// FusedBatchNormAttr is an optional argument to FusedBatchNorm.
type FusedBatchNormAttr func(optionalAttr)
@@ -15321,31 +15406,6 @@ func BoostedTreesEnsembleResourceHandleOp(scope *Scope, optional ...BoostedTrees
return op.Output(0)
}
-// Concatenates tensors along one dimension.
-//
-// Arguments:
-// concat_dim: 0-D. The dimension along which to concatenate. Must be in the
-// range [0, rank(values)).
-// values: The `N` Tensors to concatenate. Their ranks and types must match,
-// and their sizes must match in all dimensions except `concat_dim`.
-//
-// Returns A `Tensor` with the concatenation of values stacked along the
-// `concat_dim` dimension. This tensor's shape matches that of `values` except
-// in `concat_dim` where it has the sum of the sizes.
-func Concat(scope *Scope, concat_dim tf.Output, values []tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Concat",
- Input: []tf.Input{
- concat_dim, tf.OutputList(values),
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ResourceApplyMomentumAttr is an optional argument to ResourceApplyMomentum.
type ResourceApplyMomentumAttr func(optionalAttr)
@@ -16264,6 +16324,119 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// 2D fast Fourier transform.
+//
+// Computes the 2-dimensional discrete Fourier transform over the inner-most
+// 2 dimensions of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most 2
+// dimensions of `input` are replaced with their 2D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.fft2
+// @end_compatibility
+func FFT2D(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "FFT2D",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Inverse 2D fast Fourier transform.
+//
+// Computes the inverse 2-dimensional discrete Fourier transform over the
+// inner-most 2 dimensions of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most 2
+// dimensions of `input` are replaced with their inverse 2D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.ifft2
+// @end_compatibility
+func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IFFT2D",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceApplyRMSPropAttr is an optional argument to ResourceApplyRMSProp.
+type ResourceApplyRMSPropAttr func(optionalAttr)
+
+// ResourceApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyRMSPropUseLocking(value bool) ResourceApplyRMSPropAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+// var_: Should be from a Variable().
+// ms: Should be from a Variable().
+// mom: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// rho: Decay rate. Must be a scalar.
+//
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyRMSPropAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyRMSProp",
+ Input: []tf.Input{
+ var_, ms, mom, lr, rho, momentum, epsilon, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// Returns element-wise remainder of division. This emulates C semantics in that
//
// the result here is consistent with a truncating divide. E.g. `truncate(x / y) *
@@ -17537,69 +17710,6 @@ func DeserializeManySparse(scope *Scope, serialized_sparse tf.Output, dtype tf.D
return op.Output(0), op.Output(1), op.Output(2)
}
-// StringJoinAttr is an optional argument to StringJoin.
-type StringJoinAttr func(optionalAttr)
-
-// StringJoinSeparator sets the optional separator attribute to value.
-//
-// value: string, an optional join separator.
-// If not specified, defaults to ""
-func StringJoinSeparator(value string) StringJoinAttr {
- return func(m optionalAttr) {
- m["separator"] = value
- }
-}
-
-// Joins the strings in the given list of string tensors into one tensor;
-//
-// with the given separator (default is an empty separator).
-//
-// Arguments:
-// inputs: A list of string tensors. The tensors must all have the same shape,
-// or be scalars. Scalars may be mixed in; these will be broadcast to the shape
-// of non-scalar inputs.
-func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringJoin",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns immutable tensor from memory region.
-//
-// The current implementation memmaps the tensor from a file.
-//
-// Arguments:
-// dtype: Type of the returned tensor.
-// shape: Shape of the returned tensor.
-// memory_region_name: Name of readonly memory region used by the tensor, see
-// NewReadOnlyMemoryRegionFromFile in tensorflow::Env.
-func ImmutableConst(scope *Scope, dtype tf.DataType, shape tf.Shape, memory_region_name string) (tensor tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype, "shape": shape, "memory_region_name": memory_region_name}
- opspec := tf.OpSpec{
- Type: "ImmutableConst",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Inverse real-valued fast Fourier transform.
//
// Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
@@ -17780,75 +17890,185 @@ func SparseCross(scope *Scope, indices []tf.Output, values []tf.Output, shapes [
return op.Output(0), op.Output(1), op.Output(2)
}
-// Concatenates quantized tensors along one dimension.
+// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad.
+type ResourceApplyProximalAdagradAttr func(optionalAttr)
+
+// ResourceApplyProximalAdagradUseLocking sets the optional use_locking attribute to value.
+//
+// value: If True, updating of the var and accum tensors will be protected by
+// a lock; otherwise the behavior is undefined, but may exhibit less contention.
+// If not specified, defaults to false
+func ResourceApplyProximalAdagradUseLocking(value bool) ResourceApplyProximalAdagradAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' and '*accum' according to FOBOS with Adagrad learning rate.
+//
+// accum += grad * grad
+// prox_v = var - lr * grad * (1 / sqrt(accum))
+// var = sign(prox_v)/(1+lr*l2) * max{|prox_v|-lr*l1,0}
//
// Arguments:
-// concat_dim: 0-D. The dimension along which to concatenate. Must be in the
-// range [0, rank(values)).
-// values: The `N` Tensors to concatenate. Their ranks and types must match,
-// and their sizes must match in all dimensions except `concat_dim`.
-// input_mins: The minimum scalar values for each of the input tensors.
-// input_maxes: The maximum scalar values for each of the input tensors.
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 regularization. Must be a scalar.
+// grad: The gradient.
//
-// Returns A `Tensor` with the concatenation of values stacked along the
-// `concat_dim` dimension. This tensor's shape matches that of `values` except
-// in `concat_dim` where it has the sum of the sizes.The float value that the minimum quantized output value represents.The float value that the maximum quantized output value represents.
-func QuantizedConcat(scope *Scope, concat_dim tf.Output, values []tf.Output, input_mins []tf.Output, input_maxes []tf.Output) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+// Returns the created operation.
+func ResourceApplyProximalAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, grad tf.Output, optional ...ResourceApplyProximalAdagradAttr) (o *tf.Operation) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "QuantizedConcat",
+ Type: "ResourceApplyProximalAdagrad",
Input: []tf.Input{
- concat_dim, tf.OutputList(values), tf.OutputList(input_mins), tf.OutputList(input_maxes),
+ var_, accum, lr, l1, l2, grad,
},
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
+type MutableHashTableOfTensorsV2Attr func(optionalAttr)
+
+// MutableHashTableOfTensorsV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this table is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this table is shared under the given name across
+// multiple sessions.
+// If not specified, defaults to ""
+func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
+// If not specified, defaults to false
+func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["use_node_name_sharing"] = value
+ }
+}
+
+// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value.
+// If not specified, defaults to <>
+func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr {
+ return func(m optionalAttr) {
+ m["value_shape"] = value
+ }
+}
+
+// Creates an empty hash table.
+//
+// This op creates a mutable hash table, specifying the type of its keys and
+// values. Each value must be a vector. Data can be inserted into the table using
+// the insert operations. It does not support the initialization operation.
+//
+// Arguments:
+// key_dtype: Type of the table keys.
+// value_dtype: Type of the table values.
+//
+// Returns Handle to a table.
+func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MutableHashTableOfTensorsV2",
+
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
-// Slice a `SparseTensor` based on the `start` and `size`.
+// Computes the gradient of the sigmoid of `x` wrt its input.
//
-// For example, if the input is
+// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
+// `dy` is the corresponding input gradient.
+func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SigmoidGrad",
+ Input: []tf.Input{
+ y, dy,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Convert one or more images from HSV to RGB.
//
-// input_tensor = shape = [2, 7]
-// [ a d e ]
-// [b c ]
+// Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+// value of the pixels. The output is only well defined if the value in `images`
+// are in `[0,1]`.
//
-// Graphically the output tensors are:
+// See `rgb_to_hsv` for a description of the HSV encoding.
//
-// sparse_slice([0, 0], [2, 4]) = shape = [2, 4]
-// [ a ]
-// [b c ]
+// Arguments:
+// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3.
//
-// sparse_slice([0, 4], [2, 3]) = shape = [2, 3]
-// [ d e ]
-// [ ]
+// Returns `images` converted to RGB.
+func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "HSVToRGB",
+ Input: []tf.Input{
+ images,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
//
// Arguments:
-// indices: 2-D tensor represents the indices of the sparse tensor.
-// values: 1-D tensor represents the values of the sparse tensor.
-// shape: 1-D. tensor represents the shape of the sparse tensor.
-// start: 1-D. tensor represents the start of the slice.
-// size: 1-D. tensor represents the size of the slice.
-// output indices: A list of 1-D tensors represents the indices of the output
-// sparse tensors.
+// tree_ensemble_handle: Handle to the tree ensemble.
//
-// Returns A list of 1-D tensors represents the values of the output sparse
-// tensors.A list of 1-D tensors represents the shape of the output sparse
-// tensors.
-func SparseSlice(scope *Scope, indices tf.Output, values tf.Output, shape tf.Output, start tf.Output, size tf.Output) (output_indices tf.Output, output_values tf.Output, output_shape tf.Output) {
+// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "SparseSlice",
+ Type: "BoostedTreesGetEnsembleStates",
Input: []tf.Input{
- indices, values, shape, start, size,
+ tree_ensemble_handle,
},
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
// Returns the element-wise min of two SparseTensors.
@@ -17981,52 +18201,6 @@ func TakeManySparseFromTensorsMap(scope *Scope, sparse_handles tf.Output, dtype
return op.Output(0), op.Output(1), op.Output(2)
}
-// MaxPoolAttr is an optional argument to MaxPool.
-type MaxPoolAttr func(optionalAttr)
-
-// MaxPoolDataFormat sets the optional data_format attribute to value.
-//
-// value: Specify the data format of the input and output data. With the
-// default format "NHWC", the data is stored in the order of:
-// [batch, in_height, in_width, in_channels].
-// Alternatively, the format could be "NCHW", the data storage order of:
-// [batch, in_channels, in_height, in_width].
-// If not specified, defaults to "NHWC"
-func MaxPoolDataFormat(value string) MaxPoolAttr {
- return func(m optionalAttr) {
- m["data_format"] = value
- }
-}
-
-// Performs max pooling on the input.
-//
-// Arguments:
-// input: 4-D input to pool over.
-// ksize: The size of the window for each dimension of the input tensor.
-// strides: The stride of the sliding window for each dimension of the
-// input tensor.
-// padding: The type of padding algorithm to use.
-//
-// Returns The max pooled output tensor.
-func MaxPool(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MaxPool",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Assigns a new value to a variable.
//
// Any ReadVariableOp with a control dependency on this op is guaranteed to return
@@ -18608,69 +18782,6 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
-// SparseMatMulAttr is an optional argument to SparseMatMul.
-type SparseMatMulAttr func(optionalAttr)
-
-// SparseMatMulTransposeA sets the optional transpose_a attribute to value.
-// If not specified, defaults to false
-func SparseMatMulTransposeA(value bool) SparseMatMulAttr {
- return func(m optionalAttr) {
- m["transpose_a"] = value
- }
-}
-
-// SparseMatMulTransposeB sets the optional transpose_b attribute to value.
-// If not specified, defaults to false
-func SparseMatMulTransposeB(value bool) SparseMatMulAttr {
- return func(m optionalAttr) {
- m["transpose_b"] = value
- }
-}
-
-// SparseMatMulAIsSparse sets the optional a_is_sparse attribute to value.
-// If not specified, defaults to false
-func SparseMatMulAIsSparse(value bool) SparseMatMulAttr {
- return func(m optionalAttr) {
- m["a_is_sparse"] = value
- }
-}
-
-// SparseMatMulBIsSparse sets the optional b_is_sparse attribute to value.
-// If not specified, defaults to false
-func SparseMatMulBIsSparse(value bool) SparseMatMulAttr {
- return func(m optionalAttr) {
- m["b_is_sparse"] = value
- }
-}
-
-// Multiply matrix "a" by matrix "b".
-//
-// The inputs must be two-dimensional matrices and the inner dimension of "a" must
-// match the outer dimension of "b". This op is optimized for the case where at
-// least one of "a" or "b" is sparse. The breakeven for using this versus a dense
-// matrix multiply on one platform was 30% zero values in the sparse matrix.
-//
-// The gradient computation of this operation will only take advantage of sparsity
-// in the input gradient when that gradient comes from a Relu.
-func SparseMatMul(scope *Scope, a tf.Output, b tf.Output, optional ...SparseMatMulAttr) (product tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "SparseMatMul",
- Input: []tf.Input{
- a, b,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// ShapeAttr is an optional argument to Shape.
type ShapeAttr func(optionalAttr)
@@ -19171,88 +19282,58 @@ func SquaredDifference(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Forwards the input to the output.
-//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
-//
-// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
-//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
+// RandomGammaAttr is an optional argument to RandomGamma.
+type RandomGammaAttr func(optionalAttr)
-// Computes the gradient for the inverse of `x` wrt its input.
+// RandomGammaSeed sets the optional seed attribute to value.
//
-// Specifically, `grad = -dy * y*y`, where `y = 1/x`, and `dy`
-// is the corresponding input gradient.
-func ReciprocalGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReciprocalGrad",
- Input: []tf.Input{
- y, dy,
- },
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomGammaSeed(value int64) RandomGammaAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// Returns the min of x and y (i.e. x < y ? x : y) element-wise.
+// RandomGammaSeed2 sets the optional seed2 attribute to value.
//
-// *NOTE*: `Minimum` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Minimum",
- Input: []tf.Input{
- x, y,
- },
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomGammaSeed2(value int64) RandomGammaAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
}
-// Returns the element-wise sum of a list of tensors.
-//
-// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
-// wait for all of its inputs to be ready before beginning to sum. This can
-// save memory if inputs are ready at different times, since minimum temporary
-// storage is proportional to the output size rather than the inputs size.
-//
-// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+// Outputs random values from the Gamma distribution(s) described by alpha.
//
-// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+// This op uses the algorithm by Marsaglia et al. to acquire samples via
+// transformation-rejection from pairs of uniform and normal random variables.
+// See http://dl.acm.org/citation.cfm?id=358414
//
// Arguments:
-// inputs: A list of `Tensor` objects, each with same shape and type.
-// shape: Shape of elements of `inputs`.
-func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in alpha.
+// alpha: A tensor in which each scalar is a "shape" parameter describing the
+// associated gamma distribution.
+//
+// Returns A tensor with shape `shape + shape(alpha)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
+func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"shape": shape}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "AccumulateNV2",
+ Type: "RandomGamma",
Input: []tf.Input{
- tf.OutputList(inputs),
+ shape, alpha,
},
Attrs: attrs,
}
@@ -19308,60 +19389,24 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// RandomGammaAttr is an optional argument to RandomGamma.
-type RandomGammaAttr func(optionalAttr)
-
-// RandomGammaSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomGammaSeed(value int64) RandomGammaAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomGammaSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomGammaSeed2(value int64) RandomGammaAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from the Gamma distribution(s) described by alpha.
+// Forwards the input to the output.
//
-// This op uses the algorithm by Marsaglia et al. to acquire samples via
-// transformation-rejection from pairs of uniform and normal random variables.
-// See http://dl.acm.org/citation.cfm?id=358414
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
//
// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in alpha.
-// alpha: A tensor in which each scalar is a "shape" parameter describing the
-// associated gamma distribution.
+// input: A boolean scalar, representing the branch predicate of the Switch op.
//
-// Returns A tensor with shape `shape + shape(alpha)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `alpha[i0, i1, ...iN]`. The dtype of the output matches the dtype of alpha.
-func RandomGamma(scope *Scope, shape tf.Output, alpha tf.Output, optional ...RandomGammaAttr) (output tf.Output) {
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "RandomGamma",
+ Type: "LoopCond",
Input: []tf.Input{
- shape, alpha,
+ input,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -19464,49 +19509,82 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
-// RandomShuffleAttr is an optional argument to RandomShuffle.
-type RandomShuffleAttr func(optionalAttr)
+// Computes gradients for SparseSegmentSqrtN.
+//
+// Returns tensor "output" with same shape as grad, except for dimension 0 whose
+// value is output_dim0.
+//
+// Arguments:
+// grad: gradient propagated to the SparseSegmentSqrtN op.
+// indices: indices passed to the corresponding SparseSegmentSqrtN op.
+// segment_ids: segment_ids passed to the corresponding SparseSegmentSqrtN op.
+// output_dim0: dimension 0 of "data" passed to SparseSegmentSqrtN op.
+func SparseSegmentSqrtNGrad(scope *Scope, grad tf.Output, indices tf.Output, segment_ids tf.Output, output_dim0 tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentSqrtNGrad",
+ Input: []tf.Input{
+ grad, indices, segment_ids, output_dim0,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
-// RandomShuffleSeed sets the optional seed attribute to value.
+// LRNGradAttr is an optional argument to LRNGrad.
+type LRNGradAttr func(optionalAttr)
+
+// LRNGradDepthRadius sets the optional depth_radius attribute to value.
//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomShuffleSeed(value int64) RandomShuffleAttr {
+// value: A depth radius.
+// If not specified, defaults to 5
+func LRNGradDepthRadius(value int64) LRNGradAttr {
return func(m optionalAttr) {
- m["seed"] = value
+ m["depth_radius"] = value
}
}
-// RandomShuffleSeed2 sets the optional seed2 attribute to value.
+// LRNGradBias sets the optional bias attribute to value.
//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomShuffleSeed2(value int64) RandomShuffleAttr {
+// value: An offset (usually > 0 to avoid dividing by 0).
+// If not specified, defaults to 1
+func LRNGradBias(value float32) LRNGradAttr {
return func(m optionalAttr) {
- m["seed2"] = value
+ m["bias"] = value
}
}
-// Randomly shuffles a tensor along its first dimension.
+// LRNGradAlpha sets the optional alpha attribute to value.
//
-// The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
-// to one and only one `output[i]`. For example, a mapping that might occur for a
-// 3x2 tensor is:
+// value: A scale factor, usually positive.
+// If not specified, defaults to 1
+func LRNGradAlpha(value float32) LRNGradAttr {
+ return func(m optionalAttr) {
+ m["alpha"] = value
+ }
+}
+
+// LRNGradBeta sets the optional beta attribute to value.
//
-// ```
-// [[1, 2], [[5, 6],
-// [3, 4], ==> [1, 2],
-// [5, 6]] [3, 4]]
-// ```
+// value: An exponent.
+// If not specified, defaults to 0.5
+func LRNGradBeta(value float32) LRNGradAttr {
+ return func(m optionalAttr) {
+ m["beta"] = value
+ }
+}
+
+// Gradients for Local Response Normalization.
//
// Arguments:
-// value: The tensor to be shuffled.
+// input_grads: 4-D with shape `[batch, height, width, channels]`.
+// input_image: 4-D with shape `[batch, height, width, channels]`.
+// output_image: 4-D with shape `[batch, height, width, channels]`.
//
-// Returns A tensor of same shape and type as `value`, shuffled along its first
-// dimension.
-func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr) (output tf.Output) {
+// Returns The gradients for LRN.
+func LRNGrad(scope *Scope, input_grads tf.Output, input_image tf.Output, output_image tf.Output, optional ...LRNGradAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -19515,9 +19593,9 @@ func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr)
a(attrs)
}
opspec := tf.OpSpec{
- Type: "RandomShuffle",
+ Type: "LRNGrad",
Input: []tf.Input{
- value,
+ input_grads, input_image, output_image,
},
Attrs: attrs,
}
@@ -19525,57 +19603,413 @@ func RandomShuffle(scope *Scope, value tf.Output, optional ...RandomShuffleAttr)
return op.Output(0)
}
-// OrderedMapIncompleteSizeAttr is an optional argument to OrderedMapIncompleteSize.
-type OrderedMapIncompleteSizeAttr func(optionalAttr)
+// AnyAttr is an optional argument to Any.
+type AnyAttr func(optionalAttr)
-// OrderedMapIncompleteSizeCapacity sets the optional capacity attribute to value.
-// If not specified, defaults to 0
+// AnyKeepDims sets the optional keep_dims attribute to value.
//
-// REQUIRES: value >= 0
-func OrderedMapIncompleteSizeCapacity(value int64) OrderedMapIncompleteSizeAttr {
+// value: If true, retain reduced dimensions with length 1.
+// If not specified, defaults to false
+func AnyKeepDims(value bool) AnyAttr {
return func(m optionalAttr) {
- m["capacity"] = value
+ m["keep_dims"] = value
}
}
-// OrderedMapIncompleteSizeMemoryLimit sets the optional memory_limit attribute to value.
-// If not specified, defaults to 0
+// Computes the "logical or" of elements across dimensions of a tensor.
//
-// REQUIRES: value >= 0
-func OrderedMapIncompleteSizeMemoryLimit(value int64) OrderedMapIncompleteSizeAttr {
+// Reduces `input` along the dimensions given in `axis`. Unless
+// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
+// `axis`. If `keep_dims` is true, the reduced dimensions are
+// retained with length 1.
+//
+// Arguments:
+// input: The tensor to reduce.
+// axis: The dimensions to reduce. Must be in the range
+// `[-rank(input), rank(input))`.
+//
+// Returns The reduced tensor.
+func Any(scope *Scope, input tf.Output, axis tf.Output, optional ...AnyAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Any",
+ Input: []tf.Input{
+ input, axis,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a sequence of numbers.
+//
+// This operation creates a sequence of numbers that begins at `start` and
+// extends by increments of `delta` up to but not including `limit`.
+//
+// For example:
+//
+// ```
+// # 'start' is 3
+// # 'limit' is 18
+// # 'delta' is 3
+// tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+// ```
+//
+// Arguments:
+// start: 0-D (scalar). First entry in the sequence.
+// limit: 0-D (scalar). Upper limit of sequence, exclusive.
+// delta: 0-D (scalar). Optional. Default is 1. Number that increments `start`.
+//
+// Returns 1-D.
+func Range(scope *Scope, start tf.Output, limit tf.Output, delta tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Range",
+ Input: []tf.Input{
+ start, limit, delta,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// DestroyResourceOpAttr is an optional argument to DestroyResourceOp.
+type DestroyResourceOpAttr func(optionalAttr)
+
+// DestroyResourceOpIgnoreLookupError sets the optional ignore_lookup_error attribute to value.
+//
+// value: whether to ignore the error when the resource
+// doesn't exist.
+// If not specified, defaults to true
+func DestroyResourceOpIgnoreLookupError(value bool) DestroyResourceOpAttr {
return func(m optionalAttr) {
- m["memory_limit"] = value
+ m["ignore_lookup_error"] = value
}
}
-// OrderedMapIncompleteSizeContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func OrderedMapIncompleteSizeContainer(value string) OrderedMapIncompleteSizeAttr {
+// Deletes the resource specified by the handle.
+//
+// All subsequent operations using the resource will result in a NotFound
+// error status.
+//
+// Arguments:
+// resource: handle to the resource to delete.
+//
+// Returns the created operation.
+func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyResourceOpAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DestroyResourceOp",
+ Input: []tf.Input{
+ resource,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Generates values in an interval.
+//
+// A sequence of `num` evenly-spaced values are generated beginning at `start`.
+// If `num > 1`, the values in the sequence increase by `stop - start / num - 1`,
+// so that the last one is exactly `stop`.
+//
+// For example:
+//
+// ```
+// tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0]
+// ```
+//
+// Arguments:
+// start: First entry in the range.
+// stop: Last entry in the range.
+// num: Number of values to generate.
+//
+// Returns 1-D. The generated values.
+func LinSpace(scope *Scope, start tf.Output, stop tf.Output, num tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LinSpace",
+ Input: []tf.Input{
+ start, stop, num,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ComplexAttr is an optional argument to Complex.
+type ComplexAttr func(optionalAttr)
+
+// ComplexTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_COMPLEX64
+func ComplexTout(value tf.DataType) ComplexAttr {
return func(m optionalAttr) {
- m["container"] = value
+ m["Tout"] = value
}
}
-// OrderedMapIncompleteSizeSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func OrderedMapIncompleteSizeSharedName(value string) OrderedMapIncompleteSizeAttr {
+// Converts two real numbers to a complex number.
+//
+// Given a tensor `real` representing the real part of a complex number, and a
+// tensor `imag` representing the imaginary part of a complex number, this
+// operation returns complex numbers elementwise of the form \\(a + bj\\), where
+// *a* represents the `real` part and *b* represents the `imag` part.
+//
+// The input tensors `real` and `imag` must have the same shape.
+//
+// For example:
+//
+// ```
+// # tensor 'real' is [2.25, 3.25]
+// # tensor `imag` is [4.75, 5.75]
+// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
+// ```
+func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Complex",
+ Input: []tf.Input{
+ real, imag,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ImagAttr is an optional argument to Imag.
+type ImagAttr func(optionalAttr)
+
+// ImagTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func ImagTout(value tf.DataType) ImagAttr {
return func(m optionalAttr) {
- m["shared_name"] = value
+ m["Tout"] = value
}
}
-// Op returns the number of incomplete elements in the underlying container.
-func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...OrderedMapIncompleteSizeAttr) (size tf.Output) {
+// Returns the imaginary part of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the imaginary part of each element in `input`. All
+// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
+// is the real part and *b* is the imaginary part returned by this operation.
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.imag(input) ==> [4.75, 5.75]
+// ```
+func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"dtypes": dtypes}
+ attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
- Type: "OrderedMapIncompleteSize",
+ Type: "Imag",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the maximum along segments of a tensor.
+//
+// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
+// segments.
+//
+// Computes a tensor such that
+// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
+// that `segment_ids[j] == i`.
+//
+// If the max is empty for a given segment ID `i`, `output[i] = 0`.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// first dimension. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SegmentMax",
+ Input: []tf.Input{
+ data, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes hyperbolic tangent of `x` element-wise.
+func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tanh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a dataset that skips `count` elements from the `input_dataset`.
+//
+// Arguments:
+//
+// count: A scalar representing the number of elements from the `input_dataset`
+// that should be skipped. If count is -1, skips everything.
+//
+//
+func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "SkipDataset",
+ Input: []tf.Input{
+ input_dataset, count,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// RealAttr is an optional argument to Real.
+type RealAttr func(optionalAttr)
+
+// RealTout sets the optional Tout attribute to value.
+// If not specified, defaults to DT_FLOAT
+func RealTout(value tf.DataType) RealAttr {
+ return func(m optionalAttr) {
+ m["Tout"] = value
+ }
+}
+// Returns the real part of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// type `float` that is the real part of each element in `input`. All elements in
+// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
+// part returned by this operation and *b* is the imaginary part.
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.real(input) ==> [-2.25, 3.25]
+// ```
+func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Real",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResizeAreaAttr is an optional argument to ResizeArea.
+type ResizeAreaAttr func(optionalAttr)
+
+// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
+//
+// value: If true, the centers of the 4 corner pixels of the input and output tensors are
+// aligned, preserving the values at the corner pixels. Defaults to false.
+// If not specified, defaults to false
+func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
+ return func(m optionalAttr) {
+ m["align_corners"] = value
+ }
+}
+
+// Resize `images` to `size` using area interpolation.
+//
+// Input images can be of different types but output images are always float.
+//
+// The range of pixel values for the output image might be slightly different
+// from the range for the input image because of limited numerical precision.
+// To guarantee an output range, for example `[0.0, 1.0]`, apply
+// `tf.clip_by_value` to the output.
+//
+// Each output pixel is computed by first transforming the pixel's footprint into
+// the input tensor and then averaging the pixels that intersect the footprint. An
+// input pixel's contribution to the average is weighted by the fraction of its
+// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
+//
+// Arguments:
+// images: 4-D with shape `[batch, height, width, channels]`.
+// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
+//
+// Returns 4-D with shape
+// `[batch, new_height, new_width, channels]`.
+func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResizeArea",
+ Input: []tf.Input{
+ images, size,
+ },
Attrs: attrs,
}
op := scope.AddOperation(opspec)
@@ -20134,83 +20568,6 @@ func QuantizedAdd(scope *Scope, x tf.Output, y tf.Output, min_x tf.Output, max_x
return op.Output(0), op.Output(1), op.Output(2)
}
-// MfccAttr is an optional argument to Mfcc.
-type MfccAttr func(optionalAttr)
-
-// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
-//
-// value: The highest frequency to use when calculating the
-// ceptstrum.
-// If not specified, defaults to 4000
-func MfccUpperFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["upper_frequency_limit"] = value
- }
-}
-
-// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
-//
-// value: The lowest frequency to use when calculating the
-// ceptstrum.
-// If not specified, defaults to 20
-func MfccLowerFrequencyLimit(value float32) MfccAttr {
- return func(m optionalAttr) {
- m["lower_frequency_limit"] = value
- }
-}
-
-// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
-//
-// value: Resolution of the Mel bank used internally.
-// If not specified, defaults to 40
-func MfccFilterbankChannelCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["filterbank_channel_count"] = value
- }
-}
-
-// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
-//
-// value: How many output channels to produce per time slice.
-// If not specified, defaults to 13
-func MfccDctCoefficientCount(value int64) MfccAttr {
- return func(m optionalAttr) {
- m["dct_coefficient_count"] = value
- }
-}
-
-// Transforms a spectrogram into a form that's useful for speech recognition.
-//
-// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
-// been effective as an input feature for machine learning. They are created by
-// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
-// higher frequencies that are less significant to the human ear. They have a long
-// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
-// is a good resource to learn more.
-//
-// Arguments:
-// spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
-// set to true.
-// sample_rate: How many samples per second the source audio used.
-func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Mfcc",
- Input: []tf.Input{
- spectrogram, sample_rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Given a quantized tensor described by (input, input_min, input_max), outputs a
//
// range that covers the actual values present in that tensor. This op is
@@ -21433,7 +21790,7 @@ func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
//
// The `bad_color` argument is the color to use in the generated images for
-// non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
+// non-finite input values. It is a `unit8` 1-D tensor of length `channels`.
// Each element must be in the range `[0, 255]` (It represents the value of a
// pixel in the output image). Non-finite values in the input tensor are
// replaced by this tensor in the output image. The default value is the color
@@ -21773,54 +22130,130 @@ func TensorListGetItem(scope *Scope, input_handle tf.Output, index tf.Output, el
return op.Output(0)
}
-// Computes the matrix exponential of one or more square matrices:
+// Returns a diagonal tensor with a given diagonal values.
//
-// exp(A) = \sum_{n=0}^\infty A^n/n!
+// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
+// everything else padded with zeros. The diagonal is computed as follows:
//
-// The exponential is computed using a combination of the scaling and squaring
-// method and the Pade approximation. Details can be founds in:
-// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
+// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of
+// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:
//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the exponential for all input submatrices `[..., :, :]`.
+// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.
+//
+// For example:
+//
+// ```
+// # 'diagonal' is [1, 2, 3, 4]
+// tf.diag(diagonal) ==> [[1, 0, 0, 0]
+// [0, 2, 0, 0]
+// [0, 0, 3, 0]
+// [0, 0, 0, 4]]
+// ```
//
// Arguments:
-// input: Shape is `[..., M, M]`.
+// diagonal: Rank k tensor where k is at most 1.
+func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Diag",
+ Input: []tf.Input{
+ diagonal,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal.
+type ParameterizedTruncatedNormalAttr func(optionalAttr)
+
+// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value.
//
-// Returns Shape is `[..., M, M]`.
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value.
//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.expm
-// @end_compatibility
-func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random values from a normal distribution. The parameters may each be a
+//
+// scalar which applies to the entire output, or a vector of length shape[0] which
+// stores the parameters for each batch.
+//
+// Arguments:
+// shape: The shape of the output tensor. Batches are indexed by the 0th dimension.
+// means: The mean parameter of each batch.
+// stdevs: The standard deviation parameter of each batch. Must be greater than 0.
+// minvals: The minimum cutoff. May be -infinity.
+// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval
+// for each batch.
+//
+// Returns A matrix of shape num_batches x samples_per_batch, filled with random
+// truncated normal values using the parameters for each row.
+func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "MatrixExponential",
+ Type: "ParameterizedTruncatedNormal",
Input: []tf.Input{
- input,
+ shape, means, stdevs, minvals, maxvals,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes the matrix logarithm of one or more square matrices:
+// Sets the index-th position of the list to contain the given tensor.
//
+// input_handle: the list
+// index: the position in the list to which the tensor will be assigned
+// item: the element to be assigned to that position
+// output_handle: the new list, with the element in the proper position
//
-// log(exp(A)) = A
+func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorListSetItem",
+ Input: []tf.Input{
+ input_handle, index, item,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the matrix exponential of one or more square matrices:
//
-// This op is only defined for complex matrices. If A is positive-definite and
-// real, then casting to a complex matrix, taking the logarithm and casting back
-// to a real matrix will give the correct result.
+// exp(A) = \sum_{n=0}^\infty A^n/n!
//
-// This function computes the matrix logarithm using the Schur-Parlett algorithm.
-// Details of the algorithm can be found in Section 11.6.2 of:
-// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008.
-// ISBN 978-0-898716-46-7.
+// The exponential is computed using a combination of the scaling and squaring
+// method and the Pade approximation. Details can be founds in:
+// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
+// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
//
// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
// form square matrices. The output is a tensor of the same shape as the input
@@ -21832,14 +22265,14 @@ func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
// Returns Shape is `[..., M, M]`.
//
// @compatibility(scipy)
-// Equivalent to scipy.linalg.logm
+// Equivalent to scipy.linalg.expm
// @end_compatibility
-func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) {
+func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "MatrixLogarithm",
+ Type: "MatrixExponential",
Input: []tf.Input{
input,
},
@@ -22148,6 +22581,53 @@ func AdjustSaturation(scope *Scope, images tf.Output, scale tf.Output) (output t
return op.Output(0)
}
+// MatrixSolveAttr is an optional argument to MatrixSolve.
+type MatrixSolveAttr func(optionalAttr)
+
+// MatrixSolveAdjoint sets the optional adjoint attribute to value.
+//
+// value: Boolean indicating whether to solve with `matrix` or its (block-wise)
+// adjoint.
+// If not specified, defaults to false
+func MatrixSolveAdjoint(value bool) MatrixSolveAttr {
+ return func(m optionalAttr) {
+ m["adjoint"] = value
+ }
+}
+
+// Solves systems of linear equations.
+//
+// `Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. `Rhs` is a tensor of shape `[..., M, K]`. The `output` is
+// a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
+// satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+// If `adjoint` is `True` then each output matrix satisfies
+// `adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
+//
+// Arguments:
+// matrix: Shape is `[..., M, M]`.
+// rhs: Shape is `[..., M, K]`.
+//
+// Returns Shape is `[..., M, K]`.
+func MatrixSolve(scope *Scope, matrix tf.Output, rhs tf.Output, optional ...MatrixSolveAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixSolve",
+ Input: []tf.Input{
+ matrix, rhs,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// SvdAttr is an optional argument to Svd.
type SvdAttr func(optionalAttr)
@@ -23568,71 +24048,6 @@ func DecodeGif(scope *Scope, contents tf.Output) (image tf.Output) {
return op.Output(0)
}
-// Computes the gradient of the sigmoid of `x` wrt its input.
-//
-// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
-// `dy` is the corresponding input gradient.
-func SigmoidGrad(scope *Scope, y tf.Output, dy tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SigmoidGrad",
- Input: []tf.Input{
- y, dy,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Convert one or more images from HSV to RGB.
-//
-// Outputs a tensor of the same shape as the `images` tensor, containing the RGB
-// value of the pixels. The output is only well defined if the value in `images`
-// are in `[0,1]`.
-//
-// See `rgb_to_hsv` for a description of the HSV encoding.
-//
-// Arguments:
-// images: 1-D or higher rank. HSV data to convert. Last dimension must be size 3.
-//
-// Returns `images` converted to RGB.
-func HSVToRGB(scope *Scope, images tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "HSVToRGB",
- Input: []tf.Input{
- images,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Gets the next output from the given iterator.
//
// This operation is a synchronous version IteratorGetNext. It should only be used
@@ -23703,7 +24118,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
-// supplied image within this range.
+// supplied image within in this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
@@ -24212,6 +24627,46 @@ func NonMaxSuppressionV2(scope *Scope, boxes tf.Output, scores tf.Output, max_ou
return op.Output(0)
}
+// Computes the matrix logarithm of one or more square matrices:
+//
+//
+// log(exp(A)) = A
+//
+// This op is only defined for complex matrices. If A is positive-definite and
+// real, then casting to a complex matrix, taking the logarithm and casting back
+// to a real matrix will give the correct result.
+//
+// This function computes the matrix logarithm using the Schur-Parlett algorithm.
+// Details of the algorithm can be found in Section 11.6.2 of:
+// Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008.
+// ISBN 978-0-898716-46-7.
+//
+// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+// form square matrices. The output is a tensor of the same shape as the input
+// containing the exponential for all input submatrices `[..., :, :]`.
+//
+// Arguments:
+// input: Shape is `[..., M, M]`.
+//
+// Returns Shape is `[..., M, M]`.
+//
+// @compatibility(scipy)
+// Equivalent to scipy.linalg.logm
+// @end_compatibility
+func MatrixLogarithm(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MatrixLogarithm",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// EncodeProtoAttr is an optional argument to EncodeProto.
type EncodeProtoAttr func(optionalAttr)
@@ -24845,6 +25300,41 @@ func LatencyStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Output, o
return op.Output(0)
}
+// Runs multiple additive regression ensemble predictors on input instances and
+//
+// computes the update to cached logits. It is designed to be used during training.
+// It traverses the trees starting from cached tree id and cached node id and
+// calculates the updates to be pushed to the cache.
+//
+// Arguments:
+//
+// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
+// tree of prediction.
+// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
+// node of prediction.
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for partial logits
+// shape.
+//
+// Returns Rank 2 Tensor containing logits update (with respect to cached
+// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
+func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesTrainingPredict",
+ Input: []tf.Input{
+ tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// MapSizeAttr is an optional argument to MapSize.
type MapSizeAttr func(optionalAttr)
@@ -25972,6 +26462,53 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
return op.Output(0)
}
+// AvgPool3DAttr is an optional argument to AvgPool3D.
+type AvgPool3DAttr func(optionalAttr)
+
+// AvgPool3DDataFormat sets the optional data_format attribute to value.
+//
+// value: The data format of the input and output data. With the
+// default format "NDHWC", the data is stored in the order of:
+// [batch, in_depth, in_height, in_width, in_channels].
+// Alternatively, the format could be "NCDHW", the data storage order is:
+// [batch, in_channels, in_depth, in_height, in_width].
+// If not specified, defaults to "NDHWC"
+func AvgPool3DDataFormat(value string) AvgPool3DAttr {
+ return func(m optionalAttr) {
+ m["data_format"] = value
+ }
+}
+
+// Performs 3D average pooling on the input.
+//
+// Arguments:
+// input: Shape `[batch, depth, rows, cols, channels]` tensor to pool over.
+// ksize: 1-D tensor of length 5. The size of the window for each dimension of
+// the input tensor. Must have `ksize[0] = ksize[4] = 1`.
+// strides: 1-D tensor of length 5. The stride of the sliding window for each
+// dimension of `input`. Must have `strides[0] = strides[4] = 1`.
+// padding: The type of padding algorithm to use.
+//
+// Returns The average pooled output tensor.
+func AvgPool3D(scope *Scope, input tf.Output, ksize []int64, strides []int64, padding string, optional ...AvgPool3DAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AvgPool3D",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -27245,122 +27782,6 @@ func TensorArrayConcatV3(scope *Scope, handle tf.Output, flow_in tf.Output, dtyp
return op.Output(0), op.Output(1)
}
-// ParameterizedTruncatedNormalAttr is an optional argument to ParameterizedTruncatedNormal.
-type ParameterizedTruncatedNormalAttr func(optionalAttr)
-
-// ParameterizedTruncatedNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func ParameterizedTruncatedNormalSeed(value int64) ParameterizedTruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// ParameterizedTruncatedNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func ParameterizedTruncatedNormalSeed2(value int64) ParameterizedTruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from a normal distribution. The parameters may each be a
-//
-// scalar which applies to the entire output, or a vector of length shape[0] which
-// stores the parameters for each batch.
-//
-// Arguments:
-// shape: The shape of the output tensor. Batches are indexed by the 0th dimension.
-// means: The mean parameter of each batch.
-// stdevs: The standard deviation parameter of each batch. Must be greater than 0.
-// minvals: The minimum cutoff. May be -infinity.
-// maxvals: The maximum cutoff. May be +infinity, and must be more than the minval
-// for each batch.
-//
-// Returns A matrix of shape num_batches x samples_per_batch, filled with random
-// truncated normal values using the parameters for each row.
-func ParameterizedTruncatedNormal(scope *Scope, shape tf.Output, means tf.Output, stdevs tf.Output, minvals tf.Output, maxvals tf.Output, optional ...ParameterizedTruncatedNormalAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ParameterizedTruncatedNormal",
- Input: []tf.Input{
- shape, means, stdevs, minvals, maxvals,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Sets the index-th position of the list to contain the given tensor.
-//
-// input_handle: the list
-// index: the position in the list to which the tensor will be assigned
-// item: the element to be assigned to that position
-// output_handle: the new list, with the element in the proper position
-//
-func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, item tf.Output) (output_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TensorListSetItem",
- Input: []tf.Input{
- input_handle, index, item,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a diagonal tensor with a given diagonal values.
-//
-// Given a `diagonal`, this operation returns a tensor with the `diagonal` and
-// everything else padded with zeros. The diagonal is computed as follows:
-//
-// Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of
-// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:
-//
-// `output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else.
-//
-// For example:
-//
-// ```
-// # 'diagonal' is [1, 2, 3, 4]
-// tf.diag(diagonal) ==> [[1, 0, 0, 0]
-// [0, 2, 0, 0]
-// [0, 0, 3, 0]
-// [0, 0, 0, 4]]
-// ```
-//
-// Arguments:
-// diagonal: Rank k tensor where k is at most 1.
-func Diag(scope *Scope, diagonal tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Diag",
- Input: []tf.Input{
- diagonal,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split the data from the input value into TensorArray elements.
//
// Assuming that `lengths` takes on values
@@ -29304,6 +29725,26 @@ func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Returns a tensor of zeros with the same shape and type as x.
+//
+// Arguments:
+// x: a tensor of type T.
+//
+// Returns a tensor of the same shape and type as x but filled with zeros.
+func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ZerosLike",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// AbortAttr is an optional argument to Abort.
type AbortAttr func(optionalAttr)
@@ -29649,41 +30090,6 @@ func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Outpu
return scope.AddOperation(opspec)
}
-// Runs multiple additive regression ensemble predictors on input instances and
-//
-// computes the update to cached logits. It is designed to be used during training.
-// It traverses the trees starting from cached tree id and cached node id and
-// calculates the updates to be pushed to the cache.
-//
-// Arguments:
-//
-// cached_tree_ids: Rank 1 Tensor containing cached tree ids which is the starting
-// tree of prediction.
-// cached_node_ids: Rank 1 Tensor containing cached node id which is the starting
-// node of prediction.
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for partial logits
-// shape.
-//
-// Returns Rank 2 Tensor containing logits update (with respect to cached
-// values stored) for each example.Rank 1 Tensor containing new tree ids for each example.Rank 1 Tensor containing new node ids in the new tree_ids.
-func BoostedTreesTrainingPredict(scope *Scope, tree_ensemble_handle tf.Output, cached_tree_ids tf.Output, cached_node_ids tf.Output, bucketized_features []tf.Output, logits_dimension int64) (partial_logits tf.Output, tree_ids tf.Output, node_ids tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesTrainingPredict",
- Input: []tf.Input{
- tree_ensemble_handle, cached_tree_ids, cached_node_ids, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Elementwise computes the bitwise AND of `x` and `y`.
//
// The result will have those bits set, that are set in both `x` and `y`. The
@@ -30304,409 +30710,3 @@ func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (aud
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Concatenates a list of `N` tensors along the first dimension.
-//
-// The input tensors are all required to have size 1 in the first dimension.
-//
-// For example:
-//
-// ```
-// # 'x' is [[1, 4]]
-// # 'y' is [[2, 5]]
-// # 'z' is [[3, 6]]
-// parallel_concat([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
-// ```
-//
-// The difference between concat and parallel_concat is that concat requires all
-// of the inputs be computed before the operation will begin but doesn't require
-// that the input shapes be known during graph construction. Parallel concat
-// will copy pieces of the input into the output as they become available, in
-// some situations this can provide a performance benefit.
-//
-// Arguments:
-// values: Tensors to be concatenated. All must have size 1 in the first dimension
-// and same shape.
-// shape: the final shape of the result; should be equal to the shapes of any input
-// but with the number of input values in the first dimension.
-//
-// Returns The concatenated tensor.
-func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"shape": shape}
- opspec := tf.OpSpec{
- Type: "ParallelConcat",
- Input: []tf.Input{
- tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Subtracts `v` into specified rows of `x`.
-//
-// Computes y = x; y[i, :] -= v; return y.
-//
-// Arguments:
-// x: A `Tensor` of type T.
-// i: A vector. Indices into the left-most dimension of `x`.
-// v: A `Tensor` of type T. Same dimension sizes as x except the first dimension, which must be the same as i's size.
-//
-// Returns A `Tensor` of type T. An alias of `x`. The content of `y` is undefined if there are duplicates in `i`.
-func InplaceSub(scope *Scope, x tf.Output, i tf.Output, v tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "InplaceSub",
- Input: []tf.Input{
- x, i, v,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Converts a flat index or array of flat indices into a tuple of
-//
-// coordinate arrays.
-//
-// @compatibility(numpy)
-// Equivalent to np.unravel_index
-// @end_compatibility
-//
-// Arguments:
-// indices: An 0-D or 1-D `int` Tensor whose elements are indices into the
-// flattened version of an array of dimensions dims.
-// dims: An 1-D `int` Tensor. The shape of the array to use for unraveling
-// indices.
-//
-// Returns An 2-D (or 1-D if indices is 0-D) tensor where each row has the
-// same shape as the indices array.
-func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "UnravelIndex",
- Input: []tf.Input{
- indices, dims,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
-//
-// The lower regularized incomplete Gamma function is defined as:
-//
-//
-// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
-//
-// where
-//
-// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
-//
-// is the lower incomplete Gamma function.
-//
-// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
-// Gamma function.
-func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Igamma",
- Input: []tf.Input{
- a, x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes offsets of concat inputs within its output.
-//
-// For example:
-//
-// ```
-// # 'x' is [2, 2, 7]
-// # 'y' is [2, 3, 7]
-// # 'z' is [2, 5, 7]
-// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
-// ```
-//
-// This is typically used by gradient computations for a concat operation.
-//
-// Arguments:
-// concat_dim: The dimension along which to concatenate.
-// shape: The `N` int32 vectors representing shape of tensors being concatenated.
-//
-// Returns The `N` int32 vectors representing the starting offset
-// of input tensors within the concatenated output.
-func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ConcatOffset",
- Input: []tf.Input{
- concat_dim, tf.OutputList(shape),
- },
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
- scope.UpdateErr("ConcatOffset", err)
- return
- }
- return offset
-}
-
-// Splits a tensor into `num_split` tensors along one dimension.
-//
-// Arguments:
-// axis: 0-D. The dimension along which to split. Must be in the range
-// `[-rank(value), rank(value))`.
-// value: The tensor to split.
-// num_split: The number of ways to split. Must evenly divide
-// `value.shape[split_dim]`.
-//
-// Returns They are identically shaped tensors, whose shape matches that of `value`
-// except along `axis`, where their sizes are
-// `values.shape[split_dim] / num_split`.
-func Split(scope *Scope, axis tf.Output, value tf.Output, num_split int64) (output []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_split": num_split}
- opspec := tf.OpSpec{
- Type: "Split",
- Input: []tf.Input{
- axis, value,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("Split", err)
- return
- }
- return output
-}
-
-// Splits a tensor into `num_split` tensors along one dimension.
-//
-// Arguments:
-// value: The tensor to split.
-// size_splits: list containing the sizes of each output tensor along the split
-// dimension. Must sum to the dimension of value along split_dim.
-// Can contain one -1 indicating that dimension is to be inferred.
-// axis: 0-D. The dimension along which to split. Must be in the range
-// `[-rank(value), rank(value))`.
-//
-//
-// Returns Tensors whose shape matches that of `value`
-// except along `axis`, where their sizes are
-// `size_splits[i]`.
-func SplitV(scope *Scope, value tf.Output, size_splits tf.Output, axis tf.Output, num_split int64) (output []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_split": num_split}
- opspec := tf.OpSpec{
- Type: "SplitV",
- Input: []tf.Input{
- value, size_splits, axis,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- scope.UpdateErr("SplitV", err)
- return
- }
- return output
-}
-
-// Gives a guarantee to the TF runtime that the input tensor is a constant.
-//
-// The runtime is then free to make optimizations based on this.
-//
-// Only accepts value typed tensors as inputs and rejects resource variable handles
-// as input.
-//
-// Returns the input tensor without modification.
-func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "GuaranteeConst",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a tensor of zeros with the same shape and type as x.
-//
-// Arguments:
-// x: a tensor of type T.
-//
-// Returns a tensor of the same shape and type as x but filled with zeros.
-func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ZerosLike",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// QuantizedInstanceNormAttr is an optional argument to QuantizedInstanceNorm.
-type QuantizedInstanceNormAttr func(optionalAttr)
-
-// QuantizedInstanceNormOutputRangeGiven sets the optional output_range_given attribute to value.
-//
-// value: If True, `given_y_min` and `given_y_min`
-// and `given_y_max` are used as the output range. Otherwise,
-// the implementation computes the output range.
-// If not specified, defaults to false
-func QuantizedInstanceNormOutputRangeGiven(value bool) QuantizedInstanceNormAttr {
- return func(m optionalAttr) {
- m["output_range_given"] = value
- }
-}
-
-// QuantizedInstanceNormGivenYMin sets the optional given_y_min attribute to value.
-//
-// value: Output in `y_min` if `output_range_given` is True.
-// If not specified, defaults to 0
-func QuantizedInstanceNormGivenYMin(value float32) QuantizedInstanceNormAttr {
- return func(m optionalAttr) {
- m["given_y_min"] = value
- }
-}
-
-// QuantizedInstanceNormGivenYMax sets the optional given_y_max attribute to value.
-//
-// value: Output in `y_max` if `output_range_given` is True.
-// If not specified, defaults to 0
-func QuantizedInstanceNormGivenYMax(value float32) QuantizedInstanceNormAttr {
- return func(m optionalAttr) {
- m["given_y_max"] = value
- }
-}
-
-// QuantizedInstanceNormVarianceEpsilon sets the optional variance_epsilon attribute to value.
-//
-// value: A small float number to avoid dividing by 0.
-// If not specified, defaults to 1e-05
-func QuantizedInstanceNormVarianceEpsilon(value float32) QuantizedInstanceNormAttr {
- return func(m optionalAttr) {
- m["variance_epsilon"] = value
- }
-}
-
-// QuantizedInstanceNormMinSeparation sets the optional min_separation attribute to value.
-//
-// value: Minimum value of `y_max - y_min`
-// If not specified, defaults to 0.001
-func QuantizedInstanceNormMinSeparation(value float32) QuantizedInstanceNormAttr {
- return func(m optionalAttr) {
- m["min_separation"] = value
- }
-}
-
-// Quantized Instance normalization.
-//
-// Arguments:
-// x: A 4D input Tensor.
-// x_min: The value represented by the lowest quantized input.
-// x_max: The value represented by the highest quantized input.
-//
-// Returns A 4D Tensor.The value represented by the lowest quantized output.The value represented by the highest quantized output.
-func QuantizedInstanceNorm(scope *Scope, x tf.Output, x_min tf.Output, x_max tf.Output, optional ...QuantizedInstanceNormAttr) (y tf.Output, y_min tf.Output, y_max tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "QuantizedInstanceNorm",
- Input: []tf.Input{
- x, x_min, x_max,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// Returns the diagonal part of the tensor.
-//
-// This operation returns a tensor with the `diagonal` part
-// of the `input`. The `diagonal` part is computed as follows:
-//
-// Assume `input` has dimensions `[D1,..., Dk, D1,..., Dk]`, then the output is a
-// tensor of rank `k` with dimensions `[D1,..., Dk]` where:
-//
-// `diagonal[i1,..., ik] = input[i1, ..., ik, i1,..., ik]`.
-//
-// For example:
-//
-// ```
-// # 'input' is [[1, 0, 0, 0]
-// [0, 2, 0, 0]
-// [0, 0, 3, 0]
-// [0, 0, 0, 4]]
-//
-// tf.diag_part(input) ==> [1, 2, 3, 4]
-// ```
-//
-// Arguments:
-// input: Rank k tensor where k is even and not zero.
-//
-// Returns The extracted diagonal.
-func DiagPart(scope *Scope, input tf.Output) (diagonal tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DiagPart",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 8fcad61f4c..25ec718703 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output {
return Output{op, i}
}
+// NumInputs returns the number of inputs of op.
+func (op *Operation) NumInputs() int {
+ return int(C.TF_OperationNumInputs(op.c))
+}
+
// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@@ -123,6 +128,67 @@ func (p Output) c() C.TF_Output {
func (p Output) canBeAnInput() {}
+// Consumers returns the inputs that consume this output.
+func (p Output) Consumers() []Consumer {
+ max := int(C.TF_OperationOutputNumConsumers(p.c()))
+ if max == 0 {
+ return nil
+ }
+ inputs := make([]C.TF_Input, max)
+ n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
+ inputs = inputs[:int(n)]
+
+ var consumers []Consumer
+ for _, consumer := range inputs {
+ consumers = append(consumers, Consumer{
+ Index: int(consumer.index),
+ Op: &Operation{
+ c: consumer.oper,
+ g: p.Op.g,
+ },
+ })
+ }
+
+ return consumers
+}
+
+// Consumer identifies a specific input of an operation that consumes the output
+// of another operation.
+type Consumer struct {
+ // Op is the Operation that is consuming the output of another operation.
+ Op *Operation
+
+ // Index is the index of the input within Op that the output of another
+ // operation is connected to.
+ Index int
+}
+
+func (p Consumer) c() C.TF_Input {
+ if p.Op == nil {
+ // Attempt to provide a more useful panic message than "nil
+ // pointer dereference".
+ panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
+ }
+ return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
+}
+
+// DataType returns the type of the input.
+func (p Consumer) DataType() DataType {
+ return DataType(C.TF_OperationInputType(p.c()))
+}
+
+// Producer returns the Output that is connected to this Consumer.
+func (p Consumer) Producer() Output {
+ output := C.TF_OperationInput(p.c())
+ return Output{
+ Op: &Operation{
+ c: output.oper,
+ g: p.Op.g,
+ },
+ Index: int(output.index),
+ }
+}
+
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 40c951ab8c..06b65bdfb7 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -166,6 +166,68 @@ func TestOutputDataTypeAndShape(t *testing.T) {
}
}
+func TestOperationInputs(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ y, err := Placeholder(g, "y", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ add, err := Add(g, "add", x, y)
+ if err != nil {
+ t.Fatal(err)
+ }
+ addOp := add.Op
+
+ if out := addOp.NumInputs(); out != 2 {
+ t.Fatalf("Got %d inputs, wanted 2", out)
+ }
+}
+
+func TestOperationConsumers(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ a, err := Neg(g, "a", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := Neg(g, "b", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ consumers := []*Operation{a.Op, b.Op}
+
+ xConsumers := x.Consumers()
+ if out := len(xConsumers); out != 2 {
+ t.Fatalf("Got %d consumers, wanted 2", out)
+ }
+
+ for i, consumer := range xConsumers {
+ got := consumer.Op.Name()
+ want := consumers[i].Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+
+ got = consumer.Producer().Op.Name()
+ want = x.Op.Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+ }
+
+ if len(b.Consumers()) != 0 {
+ t.Fatalf("expected %+v to have no consumers", b)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 2d25c04dc9..f3338f6595 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -131,13 +131,9 @@ func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error)
}
runtime.SetFinalizer(t, (*Tensor).finalize)
raw := tensorData(t.c)
- n, err := r.Read(raw)
- if err != nil {
+ if _, err := io.ReadFull(r, raw); err != nil {
return nil, err
}
- if uintptr(n) != nbytes {
- return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
- }
return t, nil
}
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 793c36dd4d..dc533cd3e1 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -18,6 +18,7 @@ package tensorflow
import (
"bytes"
+ "io"
"reflect"
"testing"
)
@@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) {
}
}
+func TestReadTensorReadAll(t *testing.T) {
+ // Get the bytes of a tensor.
+ a := []float32{1.1, 1.2, 1.3}
+ ats, err := NewTensor(a)
+ if err != nil {
+ t.Fatal(err)
+ }
+ abuf := new(bytes.Buffer)
+ if _, err := ats.WriteContentsTo(abuf); err != nil {
+ t.Fatal(err)
+ }
+
+ // Get the bytes of another tensor.
+ b := []float32{1.1, 1.2, 1.3}
+ bts, err := NewTensor(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bbuf := new(bytes.Buffer)
+ if _, err := bts.WriteContentsTo(bbuf); err != nil {
+ t.Fatal(err)
+ }
+
+ // Check that ReadTensor reads all bytes of both tensors, when the situation
+ // requires one than reads.
+ abbuf := io.MultiReader(abuf, bbuf)
+ abts, err := ReadTensor(Float, []int64{2, 3}, abbuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ abtsf32 := abts.Value().([][]float32)
+ expected := [][]float32{a, b}
+
+ if len(abtsf32) != 2 {
+ t.Fatalf("first dimension %d is not 2", len(abtsf32))
+ }
+ for i := 0; i < 2; i++ {
+ if len(abtsf32[i]) != 3 {
+ t.Fatalf("second dimension %d is not 3", len(abtsf32[i]))
+ }
+ for j := 0; j < 3; j++ {
+ if abtsf32[i][j] != expected[i][j] {
+ t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j])
+ }
+ }
+ }
+}
+
func benchmarkNewTensor(b *testing.B, v interface{}) {
for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil {
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 19d2133a55..73e210fae0 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -56,6 +56,10 @@ java_library(
srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
javacopts = JAVACOPTS,
resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
+ deps = [
+ "@com_google_guava",
+ "@com_squareup_javapoet",
+ ],
)
filegroup(
@@ -70,6 +74,7 @@ tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
api_def_srcs = [
"//tensorflow/core/api_def:base_api_def",
+ "//tensorflow/core/api_def:java_api_def",
],
base_package = "org.tensorflow.op",
gen_tool = ":java_op_gen_tool",
diff --git a/tensorflow/java/maven/.gitignore b/tensorflow/java/maven/.gitignore
index ff080515d5..657e2a60bc 100644
--- a/tensorflow/java/maven/.gitignore
+++ b/tensorflow/java/maven/.gitignore
@@ -11,4 +11,10 @@ tensorflow/src
tensorflow/target
proto/src
proto/target
+hadoop/src
+hadoop/target
+spark-connector/src
+spark-connector/target
+spark-connector/dependency-reduced-pom.xml
+spark-connector/spark-warehouse
pom.xml.versionsBackup
diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md
index c7e8f03806..3e030dcd09 100644
--- a/tensorflow/java/maven/README.md
+++ b/tensorflow/java/maven/README.md
@@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release:
7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings
shared by all of the above.
+8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop.
+ The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop)
+
+9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord
+ using Apache Spark DataFrames. The source code for this package is available
+ in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector)
## Updating the release
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
new file mode 100644
index 0000000000..7391dfb965
--- /dev/null
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -0,0 +1,192 @@
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>hadoop</artifactId>
+ <packaging>jar</packaging>
+ <version>1.9.0</version>
+ <name>tensorflow-hadoop</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <maven.compiler.source>1.6</maven.compiler.source>
+ <maven.compiler.target>1.6</maven.compiler.target>
+ <hadoop.version>2.6.0</hadoop.version>
+ <protobuf.version>3.3.1</protobuf.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <licenses>
+ <license>
+ <name>Apache License Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ </license>
+ </licenses>
+
+ <scm>
+ <url>https://github.com/tensorflow/ecosystem.git</url>
+ <connection>git@github.com:tensorflow/ecosystem.git</connection>
+ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
+ </scm>
+
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.2.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <version>2.9.1</version>
+ <executions>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-core</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+ <version>${hadoop.version}</version>
+ <type>test-jar</type>
+ <optional>true</optional>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ </dependencies>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profiles>
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+</project>
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 08cc860f57..d44bdf8f81 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index fcc7eacc33..e8925c6fb1 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 3d22d86a49..3bf4a2590c 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 0a09a5ea7c..b96dcf2888 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
@@ -32,6 +32,8 @@
<module>libtensorflow_jni_gpu</module>
<module>tensorflow</module>
<module>proto</module>
+ <module>hadoop</module>
+ <module>spark-connector</module>
</modules>
<!-- Two profiles are used:
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 77ec6a0ddb..5581d864d7 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
@@ -16,7 +16,7 @@
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
- <version>3.3.1</version>
+ <version>3.5.1</version>
</dependency>
</dependencies>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 6136ccfdfb..2240d6b7b9 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -19,6 +19,7 @@
RELEASE_URL_PREFIX="https://storage.googleapis.com/tensorflow/libtensorflow"
+TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"
# By default we deploy to both ossrh and bintray. These two
# environment variables can be set to skip either repository.
@@ -31,7 +32,7 @@ if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then
# Bintray does not allow snapshots.
DEPLOY_BINTRAY="false"
fi
-PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.3.0/protoc-3.3.0-linux-x86_64.zip"
+PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip"
if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then
echo "Must deploy to at least one of Bintray or OSSRH" >&2
exit 2
@@ -44,7 +45,9 @@ clean() {
# (though if run inside a clean docker container, there won't be any dirty
# artifacts lying around)
mvn -q clean
- rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target libtensorflow/src libtensorflow/target tensorflow-android/target
+ rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target \
+ libtensorflow/src libtensorflow/target tensorflow-android/target proto/src proto/target \
+ hadoop/src hadoop/target spark-connector/src spark-connector/target
}
update_version_in_pom() {
@@ -183,6 +186,46 @@ generate_java_protos() {
rm -rf "${DIR}/proto/tmp"
}
+
+# Download the TensorFlow ecosystem source from git.
+# The pom files from this repo do not inherit from the parent pom so the maven version
+# is updated for each module.
+download_tf_ecosystem() {
+ ECOSYSTEM_DIR="/tmp/tensorflow-ecosystem"
+ HADOOP_DIR="${DIR}/hadoop"
+ SPARK_DIR="${DIR}/spark-connector"
+
+ # Clean any previous attempts
+ rm -rf "${ECOSYSTEM_DIR}"
+
+ # Clone the TensorFlow ecosystem project
+ mkdir -p "${ECOSYSTEM_DIR}"
+ cd "${ECOSYSTEM_DIR}"
+ git clone "${TF_ECOSYSTEM_URL}"
+ cd ecosystem
+ # TF_VERSION is a semver string (<major>.<minor>.<patch>[-suffix])
+ # but the branch is just (r<major>.<minor>).
+ RELEASE_BRANCH=$(echo "${TF_VERSION}" | sed -e 's/\([0-9]\+\.[0-9]\+\)\.[0-9]\+.*/\1/')
+ git checkout r${RELEASE_BRANCH}
+
+ # Copy the TensorFlow Hadoop source
+ cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
+ cp "${ECOSYSTEM_DIR}/ecosystem/hadoop/pom.xml" "${HADOOP_DIR}"
+ cd "${HADOOP_DIR}"
+ update_version_in_pom
+
+ # Copy the TensorFlow Spark connector source
+ cp -r "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/src" "${SPARK_DIR}"
+ cp "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/pom.xml" "${SPARK_DIR}"
+ cd "${SPARK_DIR}"
+ update_version_in_pom
+
+ # Cleanup
+ rm -rf "${ECOSYSTEM_DIR}"
+
+ cd "${DIR}"
+}
+
# Deploy artifacts using a specific profile.
# Arguments:
# profile - name of selected profile.
@@ -240,7 +283,8 @@ cd "${DIR}"
# Comment lines out appropriately if debugging/tinkering with the release
# process.
# gnupg2 is required for signing
-apt-get -qq update && apt-get -qqq install -y gnupg2
+apt-get -qq update && apt-get -qqq install -y gnupg2 git
+
clean
update_version_in_pom
download_libtensorflow
@@ -248,6 +292,8 @@ download_libtensorflow_jni
download_libtensorflow_jni_gpu
update_tensorflow_android
generate_java_protos
+download_tf_ecosystem
+
# Build the release artifacts
mvn verify
# Push artifacts to repository
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
new file mode 100644
index 0000000000..64956be02c
--- /dev/null
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -0,0 +1,349 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>spark-connector_2.11</artifactId>
+ <packaging>jar</packaging>
+ <version>1.9.0</version>
+ <name>spark-tensorflow-connector</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
+
+ <licenses>
+ <license>
+ <name>The Apache Software License, Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ <distribution>repo</distribution>
+ </license>
+ </licenses>
+
+ <scm>
+ <url>https://github.com/tensorflow/ecosystem.git</url>
+ <connection>git@github.com:tensorflow/ecosystem.git</connection>
+ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
+ </scm>
+
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <scala.maven.version>3.2.2</scala.maven.version>
+ <scala.binary.version>2.11</scala.binary.version>
+ <scalatest.maven.version>1.0</scalatest.maven.version>
+ <scala.test.version>2.2.6</scala.test.version>
+ <maven.compiler.version>3.0</maven.compiler.version>
+ <java.version>1.8</java.version>
+ <spark.version>2.3.0</spark.version>
+ <yarn.api.version>2.7.3</yarn.api.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>${scala.maven.version}</version>
+ <executions>
+ <execution>
+ <id>compile</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>compile</goal>
+ </goals>
+ <configuration>
+ <jvmArgs>
+ <jvmArg>-Xms256m</jvmArg>
+ <jvmArg>-Xmx512m</jvmArg>
+ </jvmArgs>
+ <args>
+ <arg>-g:vars</arg>
+ <arg>-deprecation</arg>
+ <arg>-feature</arg>
+ <arg>-unchecked</arg>
+ <arg>-Xfatal-warnings</arg>
+ <arg>-language:implicitConversions</arg>
+ <arg>-language:existentials</arg>
+ </args>
+ </configuration>
+ </execution>
+ <execution>
+ <id>test</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>testCompile</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>doc-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <recompileMode>incremental</recompileMode>
+ <useZincServer>true</useZincServer>
+ <scalaVersion>${scala.binary.version}</scalaVersion>
+ <checkMultipleScalaVersions>false</checkMultipleScalaVersions>
+ </configuration>
+ </plugin>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <version>${scalatest.maven.version}</version>
+ <executions>
+ <execution>
+ <id>scalaTest</id>
+ <phase>test</phase>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- Shade protobuf dependency. -->
+ <plugin>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>3.1.0</version>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <minimizeJar>true</minimizeJar>
+ <artifactSet>
+ <includes>
+ <include>com.google.protobuf:protobuf-java</include>
+ <include>org.tensorflow:hadoop</include>
+ <include>org.tensorflow:proto</include>
+ </includes>
+ </artifactSet>
+ <filters>
+ <filter>
+ <!-- Remove the source to keep the result smaller. -->
+ <artifact>com.google.protobuf:protobuf-java</artifact>
+ <excludes>
+ <exclude>**/*.java</exclude>
+ </excludes>
+ </filter>
+ </filters>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>
+ org.tensorflow.spark.shaded.com.google.protobuf
+ </shadedPattern>
+ </relocation>
+ </relocations>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- GPG signed components: http://central.sonatype.org/pages/apache-maven.html#gpg-signed-components -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>${maven.compiler.version}</version>
+ <configuration>
+ <source>${java.version}</source>
+ <target>${java.version}</target>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-source-plugin</artifactId>
+ <version>2.2.1</version>
+ <executions>
+ <execution>
+ <id>attach-sources</id>
+ <goals>
+ <goal>jar-no-fork</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ <version>2.9.1</version>
+ <executions>
+ <execution>
+ <id>attach-javadocs</id>
+ <goals>
+ <goal>jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+ <profiles>
+ <profile>
+ <id>test</id>
+ <activation>
+ <activeByDefault>true</activeByDefault>
+ <property>
+ <name>!NEVERSETME</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ <dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <version>${scala.test.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </profile>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>hadoop</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <version>${yarn.api.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 0df1f28149..92e15aa2c7 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0</version>
+ <version>1.9.0</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index f5f54bf4d3..d9d6f8adc8 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
-#include <string>
#include <list>
#include <map>
+#include <string>
#include <utility>
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index debd95fc62..d5bd99bdd9 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
namespace java {
namespace {
-const char* kLicense =
+constexpr const char kLicense[] =
"/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
"\n"
"Licensed under the Apache License, Version 2.0 (the \"License\");\n"
@@ -376,9 +376,6 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
}
}
// op annotations
- op_class.add_annotation(
- Annotation::Create("Generated", "javax.annotation")
- .attributes("value = \"TensorFlow Java Op Generator\""));
if (endpoint.deprecated()) {
op_class.add_annotation(Annotation::Create("Deprecated"));
string explanation;
@@ -394,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
}
if (!op.hidden()) {
// expose the op in the Ops Graph API only if it is visible
- op_class.add_annotation(
- Annotation::Create("Operator", "org.tensorflow.op.annotation")
- .attributes("group = \"" + endpoint.package() + "\""));
+ Annotation oper_annot =
+ Annotation::Create("Operator", "org.tensorflow.op.annotation");
+ if (endpoint.package() != kDefaultEndpointPackage) {
+ oper_annot.attributes("group = \"" + endpoint.package() + "\"");
+ }
+ op_class.add_annotation(oper_annot);
}
// create op class file
const string op_dir_name = io::JoinPath(
@@ -415,8 +415,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
SourceFileWriter writer(op_file.get());
std::list<Type> dependencies;
CollectOpDependencies(op, mode, &dependencies);
- writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL,
- &dependencies, &op_javadoc);
+ writer.Write(kLicense)
+ .EndLine()
+ .Write("// This class has been generated, DO NOT EDIT!")
+ .EndLine()
+ .EndLine()
+ .BeginType(op_class, PUBLIC | FINAL, &dependencies, &op_javadoc);
if (!op.optional_attributes().empty()) {
RenderOptionsClass(op, op_class, &writer);
}
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 759d800ecf..05decd6b54 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
index 4bcfc7fe01..941ab2699c 100644
--- a/tensorflow/java/src/gen/cc/op_specs.cc
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
-#include <vector>
#include <string>
#include <utility>
+#include <vector>
#include "re2/re2.h"
#include "tensorflow/core/framework/op.h"
@@ -50,7 +50,7 @@ class TypeResolver {
// For example, if the argument's datatype is DT_STRING, this method will
// return "java.lang.String", so the argument can become "Operand<String>"
// in the Ops API
- Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
// Returns types of an input attribute
//
@@ -62,7 +62,7 @@ class TypeResolver {
// <java.lang.Float, float>, so the attribute can be used as a "Float" object
// in the Ops API and casted to a "float" when passing through the JNI layer.
std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
- bool *iterable_out);
+ bool* iterable_out);
// Returns true if the type of this attribute has already been resolved
bool IsAttributeVisited(const string& attr_name) {
@@ -89,14 +89,14 @@ class TypeResolver {
}
};
-Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
- bool* iterable_out) {
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = false;
if (!arg_def.number_attr().empty()) {
// when number_attr is set, argument has to be a list of tensors
*iterable_out = true;
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
}
+
Type type = Type::Wildcard();
if (arg_def.type() != DataType::DT_INVALID) {
// resolve type from DataType
@@ -153,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
} else {
LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
return type;
}
std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
- bool* iterable_out) {
+ bool* iterable_out) {
std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
*iterable_out = false;
StringPiece attr_type = attr_def.type();
@@ -184,7 +184,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else if (attr_type == "tensor") {
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
- .add_parameter(Type::Wildcard()));
+ .add_parameter(Type::Wildcard()));
} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
@@ -195,7 +195,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
return types;
@@ -218,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) {
return result;
}
-bool FindAndCut(re2::StringPiece* input, const RE2& expr,
- re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
- re2::StringPiece match;
- if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) {
- return false;
- }
- before_match->set(input->data(), match.begin() - input->begin());
- input->remove_prefix(match.end() - before_match->begin());
- if (ret_match != nullptr) {
- *ret_match = match;
- }
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+ string* ret_match = nullptr) {
+ string match;
+ if (!RE2::PartialMatch(*input, expr, &match)) return false;
+ *before_match = input->substr(0, input->find(match));
+ *input = input->substr(before_match->size() + match.size());
+ if (ret_match != nullptr) *ret_match = match;
return true;
}
-string ParseDocumentation(re2::StringPiece input) {
+string ParseDocumentation(const string& inp) {
std::stringstream javadoc_text;
// TODO(karllessard) This is a very minimalist utility method for converting
// markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
// for alternatives to increase the level of support for markups.
std::vector<string> markups_subexpr;
- markups_subexpr.push_back("\n+\\*\\s+"); // lists
- markups_subexpr.push_back("\n{2,}"); // paragraphs
+ markups_subexpr.push_back("\n+\\*\\s+"); // lists
+ markups_subexpr.push_back("\n{2,}"); // paragraphs
markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks
- markups_subexpr.push_back("`+"); // inlined code and code blocks
+ markups_subexpr.push_back("`+"); // inlined code and code blocks
markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis
- markups_subexpr.push_back("\\["); // hyperlinks
- const RE2 markup_expr(str_util::Join(markups_subexpr, "|"));
+ markups_subexpr.push_back("\\["); // hyperlinks
+ const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
bool in_list = false;
+ string input = inp;
while (true) {
- re2::StringPiece text;
- re2::StringPiece markup;
+ string text, markup;
if (!FindAndCut(&input, markup_expr, &text, &markup)) {
javadoc_text << input;
break; // end of loop
}
javadoc_text << text;
- if (markup.starts_with("\n")) {
+ if (str_util::StartsWith(markup, "\n")) {
javadoc_text << "\n";
- if (markup.contains("*")) {
+ if (str_util::StrContains(markup, "*")) {
// new list item
javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
in_list = true;
@@ -266,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) {
// end of list
javadoc_text << "</li>\n</ul>\n";
in_list = false;
- } else if (!input.starts_with("```")) {
+ } else if (!str_util::StartsWith(input, "```")) {
// new paragraph (not required if a <pre> block follows)
javadoc_text << "<p>\n";
}
- } else if (markup.starts_with("```")) {
+ } else if (str_util::StartsWith(markup, "```")) {
// code blocks
- if (FindAndCut(&input, "```\\s*\n*", &text)) {
+ if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("`")) {
+ } else if (str_util::StartsWith("(" + markup + ")", "`")) {
// inlined code
if (FindAndCut(&input, markup, &text)) {
javadoc_text << "{@code " << text << "}";
@@ -286,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) {
}
} else if (markup == "**") {
// text emphasis (strong)
- if (FindAndCut(&input, "\\b\\*{2}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
} else {
javadoc_text << markup;
}
} else if (markup == "*") {
// text emphasis (normal)
- if (FindAndCut(&input, "\\b\\*{1}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("[")) {
+ } else if (str_util::StartsWith(markup, "[")) {
// hyperlinks
string label;
string link;
- if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) {
+ if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+ &link) &&
+ str_util::StartsWith(input, label + link)) {
+ input = input.substr(label.size() + link.size());
javadoc_text << "<a href=\"" << link << "\">"
- << ParseDocumentation(label)
- << "</a>";
+ << ParseDocumentation(label) << "</a>";
} else {
javadoc_text << markup;
}
@@ -318,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) {
}
ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
- const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Arg& input_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(input_def, &iterable);
- Type var_type = Type::Interface("Operand", "org.tensorflow")
- .add_parameter(type);
+ Type var_type =
+ Type::Interface("Operand", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::IterableOf(var_type);
}
- return ArgumentSpec(input_api_def.name(),
+ return ArgumentSpec(
+ input_api_def.name(),
Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
- type,
- ParseDocumentation(input_api_def.description()),
- iterable);
+ type, ParseDocumentation(input_api_def.description()), iterable);
}
AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
- const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Attr& attr_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
- Type var_type = types.first.kind() == Type::GENERIC ?
- Type::Class("Class").add_parameter(types.first) : types.first;
+ Type var_type = types.first.kind() == Type::GENERIC
+ ? Type::Class("Class").add_parameter(types.first)
+ : types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return AttributeSpec(attr_api_def.name(),
+ return AttributeSpec(
+ attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
- types.first,
- types.second,
- ParseDocumentation(attr_api_def.description()),
- iterable,
- attr_api_def.has_default_value());
+ types.first, types.second, ParseDocumentation(attr_api_def.description()),
+ iterable, attr_api_def.has_default_value());
}
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
- const ApiDef::Arg& output_api, TypeResolver* type_resolver) {
+ const ApiDef::Arg& output_api,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(output_def, &iterable);
- Type var_type = Type::Class("Output", "org.tensorflow")
- .add_parameter(type);
+ Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return ArgumentSpec(output_api.name(),
+ return ArgumentSpec(
+ output_api.name(),
Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
- type,
- ParseDocumentation(output_api.description()),
- iterable);
+ type, ParseDocumentation(output_api.description()), iterable);
}
EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
- const ApiDef_Endpoint& endpoint_def) {
+ const ApiDef_Endpoint& endpoint_def) {
std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
string package;
string name;
@@ -379,24 +376,22 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
package = "core"; // generate unclassified ops in the 'core' package
name = name_tokens.at(0);
}
- return EndpointSpec(package,
- name,
- Javadoc::Create(ParseDocumentation(api_def.summary()))
- .details(ParseDocumentation(api_def.description())));
+ return EndpointSpec(package, name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())));
}
} // namespace
OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
- OpSpec op(api_def.graph_op_name(),
- api_def.visibility() == ApiDef::HIDDEN,
- op_def.deprecation().explanation());
+ OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
TypeResolver type_resolver(op_def);
for (const string& next_input_name : api_def.arg_order()) {
for (int i = 0; i < op_def.input_arg().size(); ++i) {
if (op_def.input_arg(i).name() == next_input_name) {
op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
- &type_resolver));
+ &type_resolver));
break;
}
}
@@ -405,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
// do not parse attributes already visited, they have probably been inferred
// before as an input argument type
if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
- AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
- &type_resolver);
+ AttributeSpec attr =
+ CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
// attributes with a default value are optional
if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
op.optional_attributes_.push_back(attr);
@@ -416,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
}
}
for (int i = 0; i < op_def.output_arg().size(); ++i) {
- op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i),
- &type_resolver));
+ op.outputs_.push_back(
+ CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
}
for (const auto& endpoint_def : api_def.endpoint()) {
op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index 034cf636ed..30ecb8ce53 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -19,14 +19,16 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/java/src/gen/cc/java_defs.h"
namespace tensorflow {
namespace java {
+constexpr const char kDefaultEndpointPackage[] = "core";
+
class EndpointSpec {
public:
// A specification for an operation endpoint
@@ -36,9 +38,8 @@ class EndpointSpec {
// javadoc: the endpoint class documentation
// TODO(annarev): hardcode depcreated to false until deprecated is possible
EndpointSpec(const string& package, const string& name,
- const Javadoc& javadoc)
- : package_(package), name_(name), javadoc_(javadoc),
- deprecated_(false) {}
+ const Javadoc& javadoc)
+ : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {}
const string& package() const { return package_; }
const string& name() const { return name_; }
@@ -61,10 +62,13 @@ class ArgumentSpec {
// type: the tensor type of this argument
// description: a description of this argument, in javadoc
// iterable: true if this argument is a list
- ArgumentSpec(const string& op_def_name, const Variable& var,
- const Type& type, const string& description, bool iterable)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable) {}
+ ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type,
+ const string& description, bool iterable)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -92,11 +96,16 @@ class AttributeSpec {
// iterable: true if this attribute is a list
// has_default_value: true if this attribute has a default value if not set
AttributeSpec(const string& op_def_name, const Variable& var,
- const Type& type, const Type& jni_type, const string& description,
- bool iterable, bool has_default_value)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable),
- jni_type_(jni_type), has_default_value_(has_default_value) {}
+ const Type& type, const Type& jni_type,
+ const string& description, bool iterable,
+ bool has_default_value)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable),
+ jni_type_(jni_type),
+ has_default_value_(has_default_value) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -145,9 +154,10 @@ class OpSpec {
// hidden: true if this op should not be visible through the Graph Ops API
// deprecation_explanation: message to show if all endpoints are deprecated
explicit OpSpec(const string& graph_op_name, bool hidden,
- const string& deprecation_explanation)
- : graph_op_name_(graph_op_name), hidden_(hidden),
- deprecation_explanation_(deprecation_explanation) {}
+ const string& deprecation_explanation)
+ : graph_op_name_(graph_op_name),
+ hidden_(hidden),
+ deprecation_explanation_(deprecation_explanation) {}
const string graph_op_name_;
const bool hidden_;
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 11fda4fc22..796d6a62dc 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -15,19 +15,44 @@ limitations under the License.
package org.tensorflow.processor;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.FieldSpec;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeSpec;
+import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
-import java.io.PrintWriter;
+import java.util.Collection;
import java.util.Collections;
-import java.util.HashSet;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
import javax.annotation.processing.ProcessingEnvironment;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.SourceVersion;
+import javax.lang.model.element.AnnotationMirror;
+import javax.lang.model.element.AnnotationValue;
import javax.lang.model.element.Element;
+import javax.lang.model.element.ExecutableElement;
+import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
+import javax.lang.model.element.TypeParameterElement;
+import javax.lang.model.element.VariableElement;
+import javax.lang.model.type.TypeMirror;
+import javax.lang.model.type.TypeVariable;
+import javax.lang.model.util.ElementFilter;
+import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;
/**
@@ -55,6 +80,7 @@ public final class OperatorProcessor extends AbstractProcessor {
super.init(processingEnv);
messager = processingEnv.getMessager();
filer = processingEnv.getFiler();
+ elements = processingEnv.getElementUtils();
}
@Override
@@ -98,42 +124,77 @@ public final class OperatorProcessor extends AbstractProcessor {
}
// Collect all classes tagged with our annotation.
- Set<TypeElement> opClasses = new HashSet<TypeElement>();
- if (!collectOpClasses(roundEnv, opClasses, annotation)) {
+ Multimap<String, MethodSpec> groupedMethods = HashMultimap.create();
+ if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
return true;
}
// Nothing to do when there are no tagged classes.
- if (opClasses.isEmpty()) {
+ if (groupedMethods.isEmpty()) {
return true;
}
- // TODO:(kbsriram) validate operator classes and generate Op API.
- writeApi();
+ // Validate operator classes and generate Op API.
+ writeApi(groupedMethods);
+
hasRun = true;
return true;
}
@Override
public Set<String> getSupportedAnnotationTypes() {
- return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE));
+ return Collections.singleton("org.tensorflow.op.annotation.Operator");
+ }
+
+ private static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
+ private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
+ private static final TypeName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
+ private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
+ private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
+ private static final TypeName T_STRING = ClassName.get(String.class);
+
+ private Filer filer;
+ private Messager messager;
+ private Elements elements;
+ private boolean hasRun = false;
+
+ private void error(Element e, String message, Object... args) {
+ if (args != null && args.length > 0) {
+ message = String.format(message, args);
+ }
+ messager.printMessage(Kind.ERROR, message, e);
}
- private void writeApi() {
- // Generate an empty class for now and get the build working correctly. This will be changed to
- // generate the actual API once we've done with build-related changes.
- // TODO:(kbsriram)
- try (PrintWriter writer =
- new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) {
- writer.println(String.format("package %s;", OP_PACKAGE));
- writer.println("public class Ops{}");
+ private void write(TypeSpec spec) {
+ try {
+ JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
} catch (IOException e) {
- error(null, "Unexpected failure generating API: %s", e.getMessage());
+ throw new AssertionError(e);
+ }
+ }
+
+ private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
+ Map<String, ClassName> groups = new HashMap<>();
+
+ // Generate a API class for each group collected other than the default one (= empty string)
+ for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
+ if (!entry.getKey().isEmpty()) {
+ TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
+ write(groupClass);
+ groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name));
+ }
}
+ // Generate the top API class, adding any methods added to the default group
+ TypeSpec topClass = buildTopClass(groups, groupedMethods.get(""));
+ write(topClass);
}
- private boolean collectOpClasses(
- RoundEnvironment roundEnv, Set<TypeElement> opClasses, TypeElement annotation) {
+ private boolean collectOpsMethods(
+ RoundEnvironment roundEnv,
+ Multimap<String, MethodSpec> groupedMethods,
+ TypeElement annotation) {
boolean result = true;
for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
// @Operator can only apply to types, so e must be a TypeElement.
@@ -145,20 +206,251 @@ public final class OperatorProcessor extends AbstractProcessor {
result = false;
continue;
}
- opClasses.add((TypeElement) e);
+ TypeElement opClass = (TypeElement) e;
+ // Skip deprecated operations for now, as we do not guarantee API stability yet
+ if (opClass.getAnnotation(Deprecated.class) == null) {
+ collectOpMethods(groupedMethods, opClass, annotation);
+ }
}
return result;
}
- private void error(Element e, String message, Object... args) {
- if (args != null && args.length > 0) {
- message = String.format(message, args);
+ private void collectOpMethods(
+ Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
+ AnnotationMirror am = getAnnotationMirror(opClass, annotation);
+ String groupName = getAnnotationElementValueAsString("group", am);
+ String methodName = getAnnotationElementValueAsString("name", am);
+ ClassName opClassName = ClassName.get(opClass);
+ if (Strings.isNullOrEmpty(methodName)) {
+ methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
+ }
+ // Build a method for each @Operator found in the class path. There should be one method per
+ // operation factory called
+ // "create", which takes in parameter a scope and, optionally, a list of arguments
+ for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
+ if (opMethod.getModifiers().contains(Modifier.STATIC)
+ && opMethod.getSimpleName().contentEquals("create")) {
+ MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
+ groupedMethods.put(groupName, method);
+ }
}
- messager.printMessage(Kind.ERROR, message, e);
}
- private Filer filer;
- private Messager messager;
- private boolean hasRun = false;
- private static final String OP_PACKAGE = "org.tensorflow.op";
+ private MethodSpec buildOpMethod(
+ String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
+ MethodSpec.Builder builder =
+ MethodSpec.methodBuilder(methodName)
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(factoryMethod.getReturnType()))
+ .varargs(factoryMethod.isVarArgs())
+ .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
+
+ for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
+ TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
+ builder.addTypeVariable(tvn);
+ }
+ for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
+ builder.addException(TypeName.get(thrownType));
+ }
+ StringBuilder call = new StringBuilder("return $T.create(scope");
+ boolean first = true;
+ for (VariableElement param : factoryMethod.getParameters()) {
+ ParameterSpec p = ParameterSpec.get(param);
+ if (first) {
+ first = false;
+ continue;
+ }
+ call.append(", ");
+ call.append(p.name);
+ builder.addParameter(p);
+ }
+ call.append(")");
+ builder.addStatement(call.toString(), opClassName);
+ return builder.build();
+ }
+
+ private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
+ StringBuilder javadoc = new StringBuilder();
+ javadoc
+ .append("Adds an {@link ")
+ .append(opClassName.simpleName())
+ .append("} operation to the graph\n\n");
+
+ // Add all javadoc tags found in the operator factory method but the first one, which should be
+ // in all cases the
+ // 'scope' parameter that is implicitly passed by this API
+ Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
+ boolean firstParam = true;
+
+ while (tagMatcher.find()) {
+ String tag = tagMatcher.group();
+ if (tag.startsWith("@param") && firstParam) {
+ firstParam = false;
+ } else {
+ javadoc.append(tag).append('\n');
+ }
+ }
+ javadoc.append("@see {@link ").append(opClassName).append("}\n");
+
+ return javadoc.toString();
+ }
+
+ private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
+ MethodSpec.Builder ctorBuilder =
+ MethodSpec.constructorBuilder()
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope");
+
+ TypeSpec.Builder builder =
+ TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
+ + "@see {@link $T}\n",
+ group,
+ T_GRAPH,
+ T_OPS)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
+
+ builder.addField(
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+
+ return builder.build();
+ }
+
+ private static TypeSpec buildTopClass(
+ Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
+ MethodSpec.Builder ctorBuilder =
+ MethodSpec.constructorBuilder()
+ .addModifiers(Modifier.PRIVATE)
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope", T_SCOPE);
+
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
+ ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
+ }
+
+ TypeSpec.Builder opsBuilder =
+ TypeSpec.classBuilder("Ops")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for building a {@link $T} with operation wrappers\n<p>\n"
+ + "Any operation wrapper found in the classpath properly annotated as an"
+ + "{@link $T @Operator} is exposed\n"
+ + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
+ + "try (Graph g = new Graph()) {\n"
+ + " Ops ops = new Ops(g);\n"
+ + " // Operations are typed classes with convenience\n"
+ + " // builders in Ops.\n"
+ + " Constant three = ops.constant(3);\n"
+ + " // Single-result operations implement the Operand\n"
+ + " // interface, so this works too.\n"
+ + " Operand four = ops.constant(4);\n"
+ + " // Most builders are found within a group, and accept\n"
+ + " // Operand types as operands\n"
+ + " Operand nine = ops.math().add(four, ops.constant(5));\n"
+ + " // Multi-result operations however offer methods to\n"
+ + " // select a particular result for use.\n"
+ + " Operand result = \n"
+ + " ops.math().add(ops.array().unique(s, a).y(), b);\n"
+ + " // Optional attributes\n"
+ + " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+ + " // Naming operators\n"
+ + " ops.withName(“foo”).constant(5); // name “foo”\n"
+ + " // Names can exist in a hierarchy\n"
+ + " Ops sub = ops.withSubScope(“sub”);\n"
+ + " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ + "}\n"
+ + "}</pre>\n",
+ T_GRAPH,
+ T_OPERATOR)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("withSubScope")
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "childScopeName")
+ .returns(T_OPS)
+ .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
+ .addJavadoc(
+ "Returns an API that adds operations to the graph with the provided name prefix.\n"
+ + "\n@see {@link $T#withSubScope(String)}\n",
+ T_SCOPE)
+ .build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("withName")
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "opName")
+ .returns(T_OPS)
+ .addStatement("return new Ops(scope.withName(opName))")
+ .addJavadoc(
+ "Returns an API that uses the provided name for an op.\n\n"
+ + "@see {@link $T#withName(String)}\n",
+ T_SCOPE)
+ .build());
+
+ opsBuilder.addField(
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("scope")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(T_SCOPE)
+ .addStatement("return scope")
+ .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
+ .build());
+
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
+ opsBuilder.addField(
+ FieldSpec.builder(entry.getValue(), entry.getKey())
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder(entry.getKey())
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(entry.getValue())
+ .addStatement("return $L", entry.getKey())
+ .addJavadoc(
+ "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
+ .build());
+ }
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("create")
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addParameter(T_GRAPH, "graph")
+ .returns(T_OPS)
+ .addStatement("return new Ops(new $T(graph))", T_SCOPE)
+ .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
+ .build());
+
+ return opsBuilder.build();
+ }
+
+ private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) {
+ for (AnnotationMirror am : element.getAnnotationMirrors()) {
+ if (am.getAnnotationType().asElement().equals(annotation)) {
+ return am;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + element.getSimpleName());
+ }
+
+ private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
+ if (entry.getKey().getSimpleName().contentEquals(elementName)) {
+ return entry.getValue().getValue().toString();
+ }
+ }
+ return "";
+ }
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index d4fd3db5f7..7d19696749 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable {
}
}
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
+ * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
+ * <p>
+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
+ * shapes in {@code y}.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ Output<?>[] dy = new Output<?>[x.length];
+ final long[] yHandles = new long[y.length];
+ final int[] yIndices = new int[y.length];
+ final long[] xHandles = new long[x.length];
+ final int[] xIndices = new int[x.length];
+ long[] dxHandles = null;
+ int[] dxIndices = null;
+
+ try (Reference ref = ref()) {
+ for (int i = 0; i < y.length; ++i) {
+ yHandles[i] = y[i].op().getUnsafeNativeHandle();
+ yIndices[i] = y[i].index();
+ }
+ for (int i = 0; i < x.length; ++i) {
+ xHandles[i] = x[i].op().getUnsafeNativeHandle();
+ xIndices[i] = x[i].index();
+ }
+ if (dx != null && dx.length > 0) {
+ dxHandles = new long[dx.length];
+ dxIndices = new int[dx.length];
+
+ for (int i = 0; i < dx.length; ++i) {
+ dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+ dxIndices[i] = dx[i].index();
+ }
+ }
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
+ // of the gradient operations while the second holds the index of their output
+ // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+ long[] dyHandlesAndIndices =
+ addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ int ndy = dyHandlesAndIndices.length >> 1;
+ if (ndy != dy.length) {
+ throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
+ + " were expected");
+ }
+ for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+ Operation op = new Operation(this, dyHandlesAndIndices[i]);
+ dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
+ }
+ }
+ return dy;
+ }
+
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code dy/dx_1, dy/dx_2...}
+ * <p>
+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
+ * a single output and {@code dx} is null.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
+ return addGradients(new Output<?>[]{y}, x, null);
+ }
+
private final Object nativeHandleLock = new Object();
private long nativeHandle;
private int refcount = 0;
@@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
+ private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+
static {
TensorFlow.init();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java
new file mode 100644
index 0000000000..13bc463e7d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * Interface implemented by operands of a TensorFlow operation.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * // The "decodeJpeg" operation can be used as input to the "cast" operation
+ * Input decodeJpeg = ops.image().decodeJpeg(...);
+ * ops.math().cast(decodeJpeg, DataType.FLOAT);
+ *
+ * // The output "y" of the "unique" operation can be used as input to the "cast" operation
+ * Output y = ops.array().unique(...).y();
+ * ops.math().cast(y, DataType.FLOAT);
+ *
+ * // The "split" operation can be used as input list to the "concat" operation
+ * Iterable<? extends Input> split = ops.array().split(...);
+ * ops.array().concat(0, split);
+ * }</pre>
+ */
+public interface Input<T> {
+
+ /**
+ * Returns the symbolic handle of a tensor.
+ *
+ * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is
+ * used to obtain a symbolic handle that represents the computation of the input.
+ *
+ * @see OperationBuilder#addInput(Output)
+ */
+ Output<T> asOutput();
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
new file mode 100644
index 0000000000..f4671c8af9
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Operands;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss
+ * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}.
+ * <p>
+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all
+ * shapes in {@code y}.
+ * <p>
+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}.
+ * <p>
+ * Example of usage:
+ * <pre>{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ *
+ * Constant<Float> alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
+ * }</pre>
+ */
+@Operator
+public class Gradients implements Op, Iterable<Operand<?>> {
+
+ /**
+ * Optional attributes for {@link Gradients}
+ */
+ public static class Options {
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return this option builder
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ this.dx = dx;
+ return this;
+ }
+
+ private Iterable<Operand<?>> dx;
+
+ private Options() {
+ }
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * @param scope current graph scope
+ * @param y outputs of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ Output<?>[] dx = null;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.dx != null) {
+ dx = Operands.asOutputs(opts.dx);
+ }
+ }
+ }
+ Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(gradOutputs));
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
+ * a single output.
+ *
+ * @param scope current graph scope
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ return create(scope, (Iterable) Arrays.asList(y), x, options);
+ }
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return builder to add more options to this operation
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ return new Options().dx(dx);
+ }
+
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public Iterator<Operand<?>> iterator() {
+ return (Iterator) dy.iterator();
+ }
+
+ /**
+ * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x}
+ */
+ public List<Output<?>> dy() {
+ return dy;
+ }
+
+ /**
+ * Returns a symbolic handle to one of the gradient operation output
+ * <p>
+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * gradients.<Integer>dy(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param index The index of the output among the gradients added by this operation
+ */
+ @SuppressWarnings("unchecked")
+ public <T> Output<T> dy(int index) {
+ return (Output<T>) dy.get(index);
+ }
+
+ private List<Output<?>> dy;
+
+ private Gradients(List<Output<?>> dy) {
+ this.dy = dy;
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
new file mode 100644
index 0000000000..ab34f6aa12
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a boolean. */
+public class TFBool implements TFType {
+ private TFBool() {}
+ static {
+ Types.typeCodes.put(TFBool.class, DataType.BOOL);
+ }
+ static {
+ Types.scalars.put(TFBool.class, false);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
new file mode 100644
index 0000000000..49e5d9f2f3
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit double precision floating point number. */
+public class TFDouble implements TFType {
+ private TFDouble() {}
+ static {
+ Types.typeCodes.put(TFDouble.class, DataType.DOUBLE);
+ }
+ static {
+ Types.scalars.put(TFDouble.class, 0.0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
new file mode 100644
index 0000000000..8426ee41f0
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit single precision floating point number. */
+public class TFFloat implements TFType {
+ private TFFloat() {}
+ static {
+ Types.typeCodes.put(TFFloat.class, DataType.FLOAT);
+ }
+ static {
+ Types.scalars.put(TFFloat.class, 0f);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
new file mode 100644
index 0000000000..3947b6ad09
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit signed integer. */
+public class TFInt32 implements TFType {
+ private TFInt32() {}
+ static {
+ Types.typeCodes.put(TFInt32.class, DataType.INT32);
+ }
+ static {
+ Types.scalars.put(TFInt32.class, 0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
new file mode 100644
index 0000000000..ccdded8693
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit signed integer. */
+public class TFInt64 implements TFType {
+ private TFInt64() {}
+ static {
+ Types.typeCodes.put(TFInt64.class, DataType.INT64);
+ }
+ static {
+ Types.scalars.put(TFInt64.class, 0L);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
new file mode 100644
index 0000000000..e7327e8c57
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an arbitrary sequence of bytes. */
+public class TFString implements TFType {
+ private TFString() {}
+ static {
+ Types.typeCodes.put(TFString.class, DataType.STRING);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
new file mode 100644
index 0000000000..562953ac9d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
@@ -0,0 +1,20 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+/**
+ * A marker interface for classes representing TensorFlow types.
+ */
+public interface TFType {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
new file mode 100644
index 0000000000..d7305ca5a8
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an 8-bit unsigned integer. */
+public class TFUInt8 implements TFType {
+ private TFUInt8() {}
+ static {
+ Types.typeCodes.put(TFUInt8.class, DataType.UINT8);
+ }
+ static {
+ Types.scalars.put(TFUInt8.class, (byte)0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
new file mode 100644
index 0000000000..976cd9fd34
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.tensorflow.DataType;
+
+/**
+ * Utility class for managing the representation of TensorFlow types as Java
+ * types. For each TensorFlow type (e.g., int32), there is a corresponding Java
+ * type (e.g., TFInt32) that represents it at compile time and a corresponding
+ * class object (e.g., TFInt32.class) that represents it at run time. There is
+ * also an enumeration value in DataType that can be used to represent the
+ * type, though that should rarely be required.
+ */
+public class Types {
+
+ private Types() {} // not instantiable
+
+ static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
+
+ /** Returns the DataType value corresponding to a TensorFlow type class. */
+ public static DataType dataType(Class<? extends TFType> c) {
+ DataType dtype = typeCodes.get(c);
+ if (dtype == null) {
+ throw new IllegalArgumentException("" + c + " is not a TensorFlow type.");
+ }
+ return dtype;
+ }
+
+ static final Map<Class<?>, Object> scalars = new HashMap<>();
+
+ /** Returns the zero value of type described by {@code c}, or null if
+ * the type (e.g., string) is not numeric and therefore has no zero value.
+ */
+ public static Object zeroValue(Class<? extends TFType> c) {
+ return scalars.get(c);
+ }
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 0fef155275..dac6a345e9 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/java/src/main/native/graph_jni.h"
#include <limits>
+#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
@@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
TF_DeleteBuffer(buf);
return ret;
}
+
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
+ jlongArray y_handles, jintArray y_indices,
+ jlongArray x_handles, jintArray x_indices,
+ jlongArray dx_handles, jintArray dx_indices) {
+
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return nullptr;
+
+ const jint ny = env->GetArrayLength(y_handles);
+ const jint nx = env->GetArrayLength(x_handles);
+
+ std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
+ std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
+ std::unique_ptr<TF_Output[]> dx(nullptr);
+ std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
+
+ resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
+ resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
+ if (dx_handles != nullptr) {
+ if (env->GetArrayLength(dx_handles) != ny) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d dx handles", ny,
+ env->GetArrayLength(dx_handles));
+ }
+ dx.reset(new TF_Output[ny]);
+ resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
+ }
+ if (env->ExceptionCheck()) return nullptr;
+
+ TF_Status* status = TF_NewStatus();
+ TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
+
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ TF_DeleteStatus(status);
+
+ // returned array contains both op handles and output indices, in pair
+ jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
+ jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
+ for (int i = 0, j = nx; i < nx; ++i, ++j) {
+ TF_Output dy_output = dy.get()[i];
+ dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
+ dy_elems[j] = static_cast<jlong>(dy_output.index);
+ }
+ env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
+
+ return dy_handles_and_indices;
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index dd2e038332..4f87e8d5a7 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
jclass,
jlong);
+/*
+ * Class: org_tensorflow_Graph
+ * Method: name
+ * Signature: (J[J[I[J[I[J[I)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
+ jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
+ jintArray);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc
index 2cd542d3c9..8b11525785 100644
--- a/tensorflow/java/src/main/native/session_jni.cc
+++ b/tensorflow/java/src/main/native/session_jni.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/session_jni.h"
@@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
}
-void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
- jintArray src_index, TF_Output* dst, jint n) {
- if (env->ExceptionCheck()) return;
- jint len = env->GetArrayLength(src_op);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operations", n, len, type);
- return;
- }
- len = env->GetArrayLength(src_index);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operation output indices", n, len,
- type);
- return;
- }
- jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
- jint* indices = env->GetIntArrayElements(src_index, nullptr);
- for (int i = 0; i < n; ++i) {
- if (op_handles[i] == 0) {
- throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
- i, n);
- break;
- }
- dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
- static_cast<int>(indices[i])};
- }
- env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
- env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
-}
-
void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
if (buf == nullptr) return;
TF_DeleteBuffer(buf);
@@ -116,20 +86,22 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
- const char* ctarget = nullptr;
jbyte* cconfig = nullptr;
- if (target != nullptr) {
- ctarget = env->GetStringUTFChars(target, nullptr);
- }
if (config != nullptr) {
cconfig = env->GetByteArrayElements(config, nullptr);
TF_SetConfig(opts, cconfig,
static_cast<size_t>(env->GetArrayLength(config)), status);
if (!throwExceptionIfNotOK(env, status)) {
env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
+ TF_DeleteSessionOptions(opts);
+ TF_DeleteStatus(status);
return 0;
}
}
+ const char* ctarget = nullptr;
+ if (target != nullptr) {
+ ctarget = env->GetStringUTFChars(target, nullptr);
+ }
TF_Session* session = TF_NewSession(graph, opts, status);
if (config != nullptr) {
env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc
new file mode 100644
index 0000000000..069ac05a1c
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/java/src/main/native/utils_jni.h"
+
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n) {
+ if (env->ExceptionCheck()) return;
+ jint len = env->GetArrayLength(src_op);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operations", n, len, type);
+ return;
+ }
+ len = env->GetArrayLength(src_index);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operation output indices", n, len,
+ type);
+ return;
+ }
+ jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
+ jint* indices = env->GetIntArrayElements(src_index, nullptr);
+ for (int i = 0; i < n; ++i) {
+ if (op_handles[i] == 0) {
+ throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
+ i, n);
+ break;
+ }
+ dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
+ static_cast<int>(indices[i])};
+ }
+ env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
+ env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
+}
+
+
+
+
diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h
new file mode 100644
index 0000000000..352298e7de
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_
+#define TENSORFLOW_JAVA_UTILS_JNI_H_
+
+#include <jni.h>
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c540299bdc..c2e52c22c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -129,4 +130,106 @@ public class GraphTest {
// expected exception.
}
}
+
+ @Test
+ public void addGradientsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x1);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+ Output<Float> y2 = TestUtil.addN(g, y0, x2);
+
+ Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
+ assertNotNull(grads0);
+ assertEquals(1, grads0.length);
+ assertEquals(DataType.FLOAT, grads0[0].dataType());
+
+ Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
+ assertNotNull(grads1);
+ assertEquals(2, grads1.length);
+ assertEquals(DataType.FLOAT, grads1[0].dataType());
+ assertEquals(DataType.FLOAT, grads1[1].dataType());
+
+ try (Tensor<Float> c1 = Tensors.create(3.0f);
+ Tensor<Float> c2 = Tensors.create(2.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ s.runner()
+ .feed(x1, c1)
+ .feed(x2, c2)
+ .fetch(grads0[0])
+ .fetch(grads1[0])
+ .fetch(grads1[1])
+ .run())) {
+
+ assertEquals(3, outputs.size());
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientSumsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ assertNotNull(grad);
+ assertEquals(1, grad.length);
+ assertEquals(DataType.FLOAT, grad[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(114.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientsWithInitialValuesToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
+ assertNotNull(grad0);
+ assertEquals(1, grad0.length);
+ assertEquals(DataType.FLOAT, grad0[0].dataType());
+
+ Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ assertNotNull(grad1);
+ assertEquals(1, grad1.length);
+ assertEquals(DataType.FLOAT, grad1[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad1[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(108.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ private static Output<?>[] toArray(Output<?>... outputs) {
+ return outputs;
+ }
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index e8cc76c2a6..7d5980bcde 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import java.util.ArrayList;
-import java.util.Collection;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -36,8 +34,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -53,8 +51,8 @@ public class SessionTest {
Output<Integer> feed = g.operation("X").output(0);
Output<Integer> fetch = g.operation("Y").output(0);
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -112,7 +110,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -135,8 +133,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
@@ -164,28 +162,6 @@ public class SessionTest {
Session s = new Session(g, singleThreadConfigProto())) {}
}
- private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
- implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
- super(c);
- }
-
- @Override
- public void close() {
- Exception toThrow = null;
- for (AutoCloseable c : this) {
- try {
- c.close();
- } catch (Exception e) {
- toThrow = e;
- }
- }
- if (toThrow != null) {
- throw new RuntimeException(toThrow);
- }
- }
- }
-
private static byte[] fullTraceRunOptions() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c973b5a3d8..4e84886416 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -16,9 +16,34 @@ limitations under the License.
package org.tensorflow;
import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Collection;
/** Static utility functions. */
public class TestUtil {
+
+ public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
+ implements AutoCloseable {
+ AutoCloseableList(Collection<? extends E> c) {
+ super(c);
+ }
+
+ @Override
+ public void close() {
+ Exception toThrow = null;
+ for (AutoCloseable c : this) {
+ try {
+ c.close();
+ } catch (Exception e) {
+ toThrow = e;
+ }
+ }
+ if (toThrow != null) {
+ throw new RuntimeException(toThrow);
+ }
+ }
+ }
+
public static <T> Output<T> constant(Graph g, String name, Object value) {
try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)
@@ -36,7 +61,7 @@ public class TestUtil {
.<T>output(0);
}
- public static Output<?> addN(Graph g, Output<?>... inputs) {
+ public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
@@ -58,6 +83,13 @@ public class TestUtil {
.setAttr("num_split", numSplit)
.build();
}
+
+ public static <T> Output<T> square(Graph g, String name, Output<T> value) {
+ return g.opBuilder("Square", name)
+ .addInput(value)
+ .build()
+ .<T>output(0);
+ }
public static void transpose_A_times_X(Graph g, int[][] a) {
Output<Integer> aa = constant(g, "A", a);
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 569403fa9a..d60d37df50 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4,14 +4,16 @@
# Public targets:
# ":platform" - Low-level and platform-specific Python code.
-package(default_visibility = [
+visibility = [
"//engedu/ml/tf_from_scratch:__pkg__",
"//tensorflow:internal",
"//tensorflow/contrib/lite/toco/python:__pkg__",
"//tensorflow_models:__subpackages__",
# TODO(aselle): to pass open source test.
"//bazel_pip/tensorflow/contrib/lite/toco/python:__pkg__",
-])
+]
+
+package(default_visibility = visibility)
licenses(["notice"]) # Apache 2.0
@@ -55,12 +57,12 @@ py_library(
"//tensorflow/contrib/lite/toco/python:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/debug:__pkg__", # TODO(b/34059704): remove when fixed
"//tensorflow/python/tools:__pkg__", # TODO(b/34059704): remove when fixed
- "//tensorflow/tools/api/generator:__pkg__",
"//tensorflow/tools/quantization:__pkg__", # TODO(b/34059704): remove when fixed
],
deps = [
":no_contrib",
"//tensorflow/contrib:contrib_py",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -125,13 +127,14 @@ py_library(
":util",
":weights_broadcast_ops",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python/compat",
"//tensorflow/python/data",
- "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/linalg",
"//tensorflow/python/ops/losses",
+ "//tensorflow/python/ops/parallel_for",
"//tensorflow/python/profiler",
"//tensorflow/python/saved_model",
"//third_party/py/numpy",
@@ -278,6 +281,9 @@ cc_library(
name = "ndarray_tensor_bridge",
srcs = ["lib/core/ndarray_tensor_bridge.cc"],
hdrs = ["lib/core/ndarray_tensor_bridge.h"],
+ visibility = visibility + [
+ "//learning/deepmind/courier:__subpackages__",
+ ],
deps = [
":bfloat16_lib",
":numpy_lib",
@@ -358,6 +364,9 @@ cc_library(
name = "ndarray_tensor",
srcs = ["lib/core/ndarray_tensor.cc"],
hdrs = ["lib/core/ndarray_tensor.h"],
+ visibility = visibility + [
+ "//learning/deepmind/courier:__subpackages__",
+ ],
deps = [
":bfloat16_lib",
":ndarray_tensor_bridge",
@@ -691,11 +700,21 @@ py_library(
)
py_library(
+ name = "error_interpolation",
+ srcs = [
+ "framework/error_interpolation.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [],
+)
+
+py_library(
name = "function",
srcs = ["framework/function.py"],
srcs_version = "PY2AND3",
deps = [
":array_ops",
+ ":cond_v2_impl",
":dtypes",
":framework_ops",
":graph_to_function_def",
@@ -712,6 +731,7 @@ py_library(
srcs = ["framework/graph_to_function_def.py"],
srcs_version = "PY2AND3",
deps = [
+ ":cond_v2_impl",
":op_def_registry",
"//tensorflow/core:protos_all_py",
],
@@ -991,6 +1011,18 @@ py_test(
)
py_test(
+ name = "framework_error_interpolation_test",
+ size = "small",
+ srcs = ["framework/error_interpolation_test.py"],
+ main = "framework/error_interpolation_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":error_interpolation",
+ ],
+)
+
+py_test(
name = "framework_subscribe_test",
size = "small",
srcs = ["framework/subscribe_test.py"],
@@ -1050,7 +1082,9 @@ py_test(
tf_gen_op_wrapper_private_py(
name = "functional_ops_gen",
- visibility = ["//learning/brain/python/ops:__pkg__"],
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ ],
)
py_library(
@@ -1597,6 +1631,9 @@ tf_gen_op_wrapper_private_py(
tf_gen_op_wrapper_private_py(
name = "resource_variable_ops_gen",
+ visibility = [
+ "//tensorflow/compiler/tf2xla:internal",
+ ],
)
tf_gen_op_wrapper_private_py(
@@ -1824,6 +1861,7 @@ py_library(
"tensor_shape",
":array_ops",
":array_ops_gen",
+ ":cond_v2_impl",
":constant_op",
":control_flow_ops_gen",
":control_flow_util",
@@ -1853,6 +1891,37 @@ py_library(
)
py_library(
+ name = "cond_v2",
+ srcs = [
+ "ops/cond_v2.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":cond_v2_impl",
+ ":function",
+ ":function_def_to_graph",
+ ":gradients",
+ ],
+)
+
+py_library(
+ name = "cond_v2_impl",
+ srcs = [
+ "ops/cond_v2_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":c_api_util",
+ ":framework_ops",
+ ":functional_ops_gen",
+ ":pywrap_tensorflow",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_library(
name = "ctc_ops",
srcs = ["ops/ctc_ops.py"],
srcs_version = "PY2AND3",
@@ -1918,6 +1987,8 @@ py_library(
":math_ops",
":platform",
":resource_variable_ops",
+ ":sparse_ops",
+ ":tensor_shape",
":variables",
],
)
@@ -1934,6 +2005,7 @@ py_library(
":array_grad",
":array_ops",
":bitwise_ops",
+ ":cond_v2_impl",
":control_flow_grad",
":control_flow_ops",
":control_flow_util",
@@ -1950,6 +2022,7 @@ py_library(
":math_grad",
":math_ops",
":platform",
+ ":random_grad",
":resource_variable_ops",
":spectral_grad",
":util",
@@ -2329,6 +2402,19 @@ py_library(
)
py_library(
+ name = "random_grad",
+ srcs = ["ops/random_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":dtypes",
+ ":framework_ops",
+ ":math_ops",
+ ":random_ops_gen",
+ ],
+)
+
+py_library(
name = "random_ops",
srcs = ["ops/random_ops.py"],
srcs_version = "PY2AND3",
@@ -2388,6 +2474,7 @@ py_library(
srcs = ["ops/script_ops.py"],
srcs_version = "PY2AND3",
deps = [
+ ":array_ops",
":framework_for_generated_wrappers",
":script_ops_gen",
"//third_party/py/numpy",
@@ -2527,6 +2614,7 @@ py_library(
":check_ops",
":confusion_matrix",
":control_flow_ops",
+ ":distribute",
":framework",
":framework_for_generated_wrappers",
":math_ops",
@@ -3334,6 +3422,19 @@ py_library(
],
)
+py_test(
+ name = "lock_util_test",
+ size = "small",
+ srcs = ["util/lock_util_test.py"],
+ main = "util/lock_util_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
tf_proto_library(
name = "protos_all",
srcs = glob(
@@ -3521,6 +3622,7 @@ tf_py_wrap_cc(
"util/transform_graph.i",
"util/util.i",
],
+ # add win_def_file
win_def_file = select({
"//tensorflow:windows": ":pywrap_tensorflow_filtered_def_file",
"//conditions:default": None,
@@ -3651,6 +3753,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":c_api_util",
+ ":error_interpolation",
":errors",
":framework",
":framework_for_generated_wrappers",
@@ -3851,7 +3954,7 @@ tf_cuda_library(
tf_py_test(
name = "session_test",
- size = "small",
+ size = "medium",
srcs = ["client/session_test.py"],
additional_deps = [
":array_ops",
@@ -3976,6 +4079,7 @@ cuda_py_test(
":math_ops",
"//tensorflow/core:protos_all_py",
],
+ tags = ["no_windows_gpu"],
)
py_test(
@@ -4033,6 +4137,19 @@ py_test(
],
)
+py_test(
+ name = "tf_record_test",
+ size = "small",
+ srcs = ["lib/io/tf_record_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":errors",
+ ":lib",
+ ":util",
+ ],
+)
+
cuda_py_test(
name = "adam_test",
size = "small",
@@ -4340,7 +4457,7 @@ py_test(
py_test(
name = "warm_starting_util_test",
- size = "small",
+ size = "medium",
srcs = ["training/warm_starting_util_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index cf707fb2c7..a2ab63bb48 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -79,7 +79,6 @@ from tensorflow.python.ops import initializers_ns as initializers
# Bring in subpackages.
from tensorflow.python import data
from tensorflow.python import keras
-from tensorflow.python.estimator import estimator_lib as estimator
from tensorflow.python.feature_column import feature_column_lib as feature_column
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 5507d011bb..e037925961 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -361,7 +361,7 @@ class _ListFetchMapper(_FetchMapper):
for m, vi in zip(self._mappers, self._value_indices):
results.append(m.build_results([values[j] for j in vi]))
# Return a value of the original type of the fetches.
- if self._fetch_type == list:
+ if issubclass(self._fetch_type, list):
return results
elif self._fetch_type == tuple:
return tuple(results)
@@ -619,21 +619,12 @@ class BaseSession(SessionInterface):
self._config = None
self._add_shapes = False
- # pylint: disable=protected-access
- # We cache _USE_C_API's value because some test cases will create a session
- # with _USE_C_API = False but set it back to True before calling close().
- self._created_with_new_api = ops._USE_C_API
- # pylint: enable=protected-access
-
self._session = None
opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
try:
- if self._created_with_new_api:
- # pylint: disable=protected-access
- self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
- # pylint: enable=protected-access
- else:
- self._session = tf_session.TF_NewDeprecatedSession(opts)
+ # pylint: disable=protected-access
+ self._session = tf_session.TF_NewSession(self._graph._c_graph, opts)
+ # pylint: enable=protected-access
finally:
tf_session.TF_DeleteSessionOptions(opts)
@@ -660,11 +651,7 @@ class BaseSession(SessionInterface):
Returns:
A list of devices in the session.
"""
- if self._created_with_new_api:
- raw_device_list = tf_session.TF_SessionListDevices(self._session)
- else:
- raw_device_list = tf_session.TF_DeprecatedSessionListDevices(
- self._session)
+ raw_device_list = tf_session.TF_SessionListDevices(self._session)
device_list = []
size = tf_session.TF_DeviceListCount(raw_device_list)
for i in range(size):
@@ -684,16 +671,9 @@ class BaseSession(SessionInterface):
tf.errors.OpError: Or one of its subclasses if an error occurs while
closing the TensorFlow session.
"""
- if self._created_with_new_api:
- if self._session and not self._closed:
- self._closed = True
- tf_session.TF_CloseSession(self._session)
-
- else:
- with self._extend_lock:
- if self._opened and not self._closed:
- self._closed = True
- tf_session.TF_CloseDeprecatedSession(self._session)
+ if self._session and not self._closed:
+ self._closed = True
+ tf_session.TF_CloseSession(self._session)
def __del__(self):
# cleanly ignore all exceptions
@@ -703,10 +683,7 @@ class BaseSession(SessionInterface):
pass
if self._session is not None:
try:
- if self._created_with_new_api:
- tf_session.TF_DeleteSession(self._session)
- else:
- tf_session.TF_DeleteDeprecatedSession(self._session)
+ tf_session.TF_DeleteSession(self._session)
except AttributeError:
# At shutdown, `c_api_util` or `tf_session` may have been garbage
# collected, causing the above method calls to fail. In this case,
@@ -1005,12 +982,9 @@ class BaseSession(SessionInterface):
try:
subfeed_t = self.graph.as_graph_element(
subfeed, allow_tensor=True, allow_operation=False)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feed_list.append(subfeed_t._as_tf_output())
- # pylint: enable=protected-access
- else:
- feed_list.append(compat.as_bytes(subfeed_t.name))
+ # pylint: disable=protected-access
+ feed_list.append(subfeed_t._as_tf_output())
+ # pylint: enable=protected-access
except Exception as e:
e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
e.args = (e.message,)
@@ -1023,22 +997,13 @@ class BaseSession(SessionInterface):
# Set up a graph with feeds and fetches for partial run.
def _setup_fn(session, feed_list, fetch_list, target_list):
self._extend_graph()
- if self._created_with_new_api:
- return tf_session.TF_SessionPRunSetup_wrapper(
- session, feed_list, fetch_list, target_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRunSetup(session, feed_list, fetch_list,
- target_list, status)
+ return tf_session.TF_SessionPRunSetup_wrapper(
+ session, feed_list, fetch_list, target_list)
- if self._created_with_new_api:
- # pylint: disable=protected-access
- final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
- final_targets = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- final_fetches = _name_list(fetch_handler.fetches())
- final_targets = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
+ final_targets = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
final_targets)
@@ -1196,14 +1161,10 @@ class BaseSession(SessionInterface):
# Create a fetch handler to take care of the structure of fetches.
fetch_handler = _FetchHandler(self._graph, fetches, {})
- if self._created_with_new_api:
- # pylint: disable=protected-access
- fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
- target_list = [op._c_op for op in fetch_handler.targets()]
- # pylint: enable=protected-access
- else:
- fetch_list = _name_list(fetch_handler.fetches())
- target_list = _name_list(fetch_handler.targets())
+ # pylint: disable=protected-access
+ fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
+ target_list = [op._c_op for op in fetch_handler.targets()]
+ # pylint: enable=protected-access
def _callable_template_with_options_and_metadata(fetch_list,
target_list,
@@ -1289,16 +1250,11 @@ class BaseSession(SessionInterface):
Raises:
tf.errors.OpError: Or one of its subclasses on error.
"""
- if self._created_with_new_api:
- # pylint: disable=protected-access
- feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
- fetches = [t._as_tf_output() for t in fetch_list]
- targets = [op._c_op for op in target_list]
- # pylint: enable=protected-access
- else:
- feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
- fetches = _name_list(fetch_list)
- targets = _name_list(target_list)
+ # pylint: disable=protected-access
+ feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
+ fetches = [t._as_tf_output() for t in fetch_list]
+ targets = [op._c_op for op in target_list]
+ # pylint: enable=protected-access
def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
# Ensure any changes to the graph are reflected in the runtime.
@@ -1335,22 +1291,8 @@ class BaseSession(SessionInterface):
raise type(e)(node_def, op, message)
def _extend_graph(self):
- if self._created_with_new_api:
- with self._graph._lock: # pylint: disable=protected-access
- tf_session.ExtendSession(self._session)
- else:
- # Ensure any changes to the graph are reflected in the runtime.
- with self._extend_lock:
- if self._graph.version > self._current_version:
- # pylint: disable=protected-access
- graph_def, self._current_version = self._graph._as_graph_def(
- from_version=self._current_version, add_shapes=self._add_shapes)
- # pylint: enable=protected-access
-
- with errors.raise_exception_on_not_ok_status() as status:
- tf_session.TF_ExtendGraph(self._session,
- graph_def.SerializeToString(), status)
- self._opened = True
+ with self._graph._session_run_lock(): # pylint: disable=protected-access
+ tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors.
_DEAD_HANDLES_THRESHOLD = 10
@@ -1403,24 +1345,13 @@ class BaseSession(SessionInterface):
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
run_metadata):
- if self._created_with_new_api:
- return tf_session.TF_SessionRun_wrapper(
- self._session, options, feed_dict, fetch_list, target_list,
- run_metadata)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_Run(
- self._session, options, feed_dict, fetch_list, target_list,
- status, run_metadata)
+ return tf_session.TF_SessionRun_wrapper(
+ self._session, options, feed_dict, fetch_list, target_list,
+ run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
- if self._created_with_new_api:
- return tf_session.TF_SessionPRun_wrapper(
- self._session, handle, feed_dict, fetch_list)
- else:
- with errors.raise_exception_on_not_ok_status() as status:
- return tf_session.TF_PRun(
- self._session, handle, feed_dict, fetch_list, status)
+ return tf_session.TF_SessionPRun_wrapper(
+ self._session, handle, feed_dict, fetch_list)
# pylint: disable=protected-access
class _Callable(object):
@@ -1433,25 +1364,29 @@ class BaseSession(SessionInterface):
compat.as_bytes(callable_options.SerializeToString()))
try:
with errors.raise_exception_on_not_ok_status() as status:
- if session._created_with_new_api:
- self._handle = tf_session.TF_SessionMakeCallable(
- session._session, options_ptr, status)
- else:
- self._handle = tf_session.TF_DeprecatedSessionMakeCallable(
- session._session, options_ptr, status)
+ self._handle = tf_session.TF_SessionMakeCallable(
+ session._session, options_ptr, status)
finally:
tf_session.TF_DeleteBuffer(options_ptr)
- def __call__(self, *args):
+ def __call__(self, *args, **kwargs):
# TODO(b/74355905): Support argument and return value nested structures,
# and tensor-like objects such as SparseTensors.
- with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- return tf_session.TF_SessionRunCallable(
- self._session._session, self._handle, args, status, None)
- else:
- return tf_session.TF_DeprecatedSessionRunCallable(
- self._session._session, self._handle, args, status, None)
+ run_metadata = kwargs.get('run_metadata', None)
+ try:
+ run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
+ # TODO(mrry): Switch to raising an exception from the SWIG wrapper.
+ with errors.raise_exception_on_not_ok_status() as status:
+ ret = tf_session.TF_SessionRunCallable(
+ self._session._session, self._handle, args, status,
+ run_metadata_ptr)
+ if run_metadata:
+ proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
+ run_metadata.ParseFromString(compat.as_bytes(proto_data))
+ finally:
+ if run_metadata_ptr:
+ tf_session.TF_DeleteBuffer(run_metadata_ptr)
+ return ret
def __del__(self):
# NOTE(mrry): It is possible that `self._session.__del__()` could be
@@ -1459,12 +1394,8 @@ class BaseSession(SessionInterface):
# will be `None`.
if self._handle is not None and self._session._session is not None:
with errors.raise_exception_on_not_ok_status() as status:
- if self._session._created_with_new_api:
- tf_session.TF_SessionReleaseCallable(
- self._session._session, self._handle, status)
- else:
- tf_session.TF_DeprecatedSessionReleaseCallable(
- self._session._session, self._handle, status)
+ tf_session.TF_SessionReleaseCallable(
+ self._session._session, self._handle, status)
# pylint: enable=protected-access
# TODO(b/74355905): Reimplement `Session.make_callable()` using this method
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 482497078c..b72e029d1c 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import random
import os
import sys
import threading
@@ -1040,40 +1041,72 @@ class SessionTest(test_util.TensorFlowTestCase):
for t in threads:
t.join()
- def testParallelRunAndBuild(self):
+ @staticmethod
+ def _build_graph():
+ time.sleep(random.random() * 0.1)
+ # Do some graph construction. Try to exercise non-trivial paths.
+ graph = ops.get_default_graph()
+ gdef = None
+ for _ in range(10):
+ x = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.colocate_with(x):
+ y = array_ops.placeholder(dtype=dtypes.float32)
+ with ops.device('/cpu:0'):
+ z = control_flow_ops.while_loop(
+ lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
+ with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
+ gradients_impl.gradients(z, [x, y])
+ if gdef is None:
+ gdef = graph.as_graph_def()
+ else:
+ importer.import_graph_def(gdef, name='import')
+
+ def testParallelRunAndSingleBuild(self):
with session.Session() as sess:
c = constant_op.constant(5.0)
stop = threading.Event()
def run_loop():
while not stop.is_set():
+ time.sleep(random.random() * 0.1)
self.assertEqual(sess.run(c), 5.0)
- threads = [self.checkedThread(target=run_loop) for _ in range(100)]
+ threads = [self.checkedThread(target=run_loop) for _ in range(10)]
for t in threads:
t.start()
- # Do some graph construction. Try to exercise non-trivial paths.
- graph = ops.get_default_graph()
- gdef = None
- for _ in range(10):
- x = array_ops.placeholder(dtype=dtypes.float32)
- with ops.colocate_with(x):
- y = array_ops.placeholder(dtype=dtypes.float32)
- with ops.device('/cpu:0'):
- z = control_flow_ops.while_loop(
- lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
- with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
- gradients_impl.gradients(z, [x, y])
- if gdef is None:
- gdef = graph.as_graph_def()
- else:
- importer.import_graph_def(gdef, name='import')
+ SessionTest._build_graph()
stop.set()
for t in threads:
t.join()
+ def testParallelRunAndParallelBuild(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ stop = threading.Event()
+
+ def run_loop():
+ while not stop.is_set():
+ time.sleep(random.random() * 0.1)
+ self.assertEqual(sess.run(c), 5.0)
+
+ run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
+ for t in run_threads:
+ t.start()
+
+ build_threads = [self.checkedThread(target=SessionTest._build_graph)
+ for _ in range(10)]
+ for t in build_threads:
+ t.start()
+ for t in build_threads:
+ t.join()
+
+ # Let the run_threads run until the build threads are finished.
+ stop.set()
+ for t in run_threads:
+ t.join()
+
def testRunFeedDict(self):
with session.Session() as s:
x = array_ops.zeros([2])
@@ -1364,6 +1397,20 @@ class SessionTest(test_util.TensorFlowTestCase):
for _ in range(5):
self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
+ def testOptimizedMakeCallableWithRunMetadata(self):
+ with session.Session() as sess:
+ ph = array_ops.placeholder(dtypes.float32)
+ a = math_ops.add(ph, 1.0)
+ callable_opts = config_pb2.CallableOptions()
+ callable_opts.feed.append(ph.name)
+ callable_opts.fetch.append(a.name)
+ callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
+ callable_fn = sess._make_callable_from_options(callable_opts)
+ run_metadata = config_pb2.RunMetadata()
+ self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
+ run_metadata=run_metadata))
+ self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
+
def testFeedError(self):
with session.Session() as sess:
feed_t = array_ops.placeholder(dtype=dtypes.float32)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 1db1432d65..985cb90436 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -135,7 +135,7 @@ tensorflow::ImportNumpy();
// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers
%typemap(out) int64_t {
- $result = PyInt_FromLong($1);
+ $result = PyLong_FromLongLong($1);
}
// We use TF_OperationGetControlInputs_wrapper instead of
@@ -610,7 +610,7 @@ def TF_Reset(target, containers=None, config=None):
}
for (size_t i = 0; i < $1.size(); ++i) {
- PyList_SET_ITEM($result, i, PyInt_FromLong($1[i]));
+ PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
@@ -673,7 +673,7 @@ def TF_Reset(target, containers=None, config=None):
}
for (size_t i = 0; i < $1.size(); ++i) {
- PyList_SET_ITEM($result, i, PyInt_FromLong($1[i]));
+ PyList_SET_ITEM($result, i, PyLong_FromLongLong($1[i]));
}
}
diff --git a/tensorflow/python/compat/BUILD b/tensorflow/python/compat/BUILD
new file mode 100644
index 0000000000..58ceafca06
--- /dev/null
+++ b/tensorflow/python/compat/BUILD
@@ -0,0 +1,22 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+py_library(
+ name = "compat",
+ srcs = ["compat.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+)
+
+tf_py_test(
+ name = "compat_test",
+ size = "small",
+ srcs = ["compat_test.py"],
+ additional_deps = [
+ ":compat",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
new file mode 100644
index 0000000000..68a6421c2c
--- /dev/null
+++ b/tensorflow/python/compat/compat.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for API compatibility between TensorFlow release versions.
+
+See
+@{$guide/version_compat#backward_and_partial_forward_compatibility}
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+from tensorflow.python.util import tf_contextlib
+
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 1)
+
+
+def forward_compatible(year, month, day):
+ """Return true if the forward compatibility window has expired.
+
+ Forward-compatibility refers to scenarios where the producer of a TensorFlow
+ model (a GraphDef or SavedModel) is compiled against a version of the
+ TensorFlow library newer than what the consumer was compiled against. The
+ "producer" is typically a Python program that constructs and trains a model
+ while the "consumer" is typically another program that loads and serves the
+ model.
+
+ TensorFlow has been supporting a 3 week forward-compatibility window for
+ programs compiled from source at HEAD.
+
+ For example, consider the case where a new operation `MyNewAwesomeAdd` is
+ created with the intent of replacing the implementation of an existing Python
+ wrapper - `tf.add`. The Python wrapper implementation should change from
+ something like:
+
+ ```python
+ def add(inputs, name=None):
+ return gen_math_ops.add(inputs, name)
+ ```
+
+ to:
+
+ ```python
+ from tensorflow.python.compat import compat
+
+ def add(inputs, name=None):
+ if compat.forward_compatible(year, month, day):
+ # Can use the awesome new implementation.
+ return gen_math_ops.my_new_awesome_add(inputs, name)
+ # To maintain forward compatibiltiy, use the old implementation.
+ return gen_math_ops.add(inputs, name)
+ ```
+
+ Where `year`, `month`, and `day` specify the date beyond which binaries
+ that consume a model are expected to have been updated to include the
+ new operations. This date is typically at least 3 weeks beyond the date
+ the code that adds the new operation is committed.
+
+ Args:
+ year: A year (e.g., 2018).
+ month: A month (1 <= month <= 12) in year.
+ day: A day (1 <= day <= 31, or 30, or 29, or 28) in month.
+
+ Returns:
+ True if the caller can expect that serialized TensorFlow graphs produced
+ can be consumed by programs that are compiled with the TensorFlow library
+ source code after (year, month, day).
+ """
+ return _FORWARD_COMPATIBILITY_HORIZON > datetime.date(year, month, day)
+
+
+@tf_contextlib.contextmanager
+def forward_compatibility_horizon(year, month, day):
+ """Context manager for testing forward compatibility of generated graphs.
+
+ To ensure forward compatibility of generated graphs (see `forward_compatible`)
+ with older binaries, new features can be gated with:
+
+ ```python
+ if compat.forward_compatible(year=2018, month=08, date=01):
+ generate_graph_with_new_features()
+ else:
+ generate_graph_so_older_binaries_can_consume_it()
+ ```
+
+ However, when adding new features, one may want to unittest it before
+ the forward compatibility window expires. This context manager enables
+ such tests. For example:
+
+ ```python
+ from tensorflow.python.compat import compat
+
+ def testMyNewFeature(self):
+ with compat.forward_compatibility_horizon(2018, 08, 02):
+ # Test that generate_graph_with_new_features() has an effect
+ ```
+
+ Args :
+ year: A year (e.g. 2018).
+ month: A month (1 <= month <= 12) in year.
+ day: A day (1 <= day <= 31, or 30, or 29, or 28) in month.
+
+ Yields:
+ Nothing.
+ """
+ global _FORWARD_COMPATIBILITY_HORIZON
+ try:
+ old_compat_date = _FORWARD_COMPATIBILITY_HORIZON
+ _FORWARD_COMPATIBILITY_HORIZON = datetime.date(year, month, day)
+ yield
+ finally:
+ _FORWARD_COMPATIBILITY_HORIZON = old_compat_date
diff --git a/tensorflow/python/compat/compat_test.py b/tensorflow/python/compat/compat_test.py
new file mode 100644
index 0000000000..946abbb300
--- /dev/null
+++ b/tensorflow/python/compat/compat_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for forward and backwards compatibility utilties."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+from tensorflow.python.compat import compat
+from tensorflow.python.platform import test
+
+
+class CompatTest(test.TestCase):
+
+ def _compatibility_date(self):
+ date = compat._FORWARD_COMPATIBILITY_HORIZON # pylint: disable=protected-access
+ return (date.year, date.month, date.day)
+
+ def _n_days_after(self, n):
+ date = compat._FORWARD_COMPATIBILITY_HORIZON + datetime.timedelta(days=n) # pylint: disable=protected-access
+ return (date.year, date.month, date.day)
+
+ def test_basic(self):
+ compatibility_date = self._compatibility_date()
+ one_day_before = self._n_days_after(-1)
+ self.assertTrue(compat.forward_compatible(*one_day_before))
+ self.assertFalse(compat.forward_compatible(*compatibility_date))
+
+ def test_decorator(self):
+ compatibility_date = self._compatibility_date()
+ one_day_after = self._n_days_after(1)
+ with compat.forward_compatibility_horizon(*one_day_after):
+ self.assertTrue(compat.forward_compatible(*compatibility_date))
+ self.assertFalse(compat.forward_compatible(*one_day_after))
+
+ # After exiting context manager, value should be reset.
+ self.assertFalse(compat.forward_compatible(*compatibility_date))
+
+ def test_decorator_with_failure(self):
+ compatibility_date = self._compatibility_date()
+ one_day_after = self._n_days_after(1)
+
+ class DummyError(Exception):
+ pass
+
+ try:
+ with compat.forward_compatibility_horizon(*one_day_after):
+ raise DummyError()
+ except DummyError:
+ pass # silence DummyError
+
+ # After exiting context manager, value should be reset.
+ self.assertFalse(compat.forward_compatible(*compatibility_date))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index ed0c11e6c1..38505c0a01 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -73,6 +74,17 @@ tf_py_test(
)
tf_py_test(
+ name = "dataset_ops_test",
+ size = "small",
+ srcs = ["dataset_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
name = "filter_dataset_op_test",
size = "small",
srcs = ["filter_dataset_op_test.py"],
@@ -167,6 +179,7 @@ tf_py_test(
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
@@ -336,6 +349,7 @@ tf_py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
],
grpc_enabled = True,
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index bd80b9dbf5..89de55dd4f 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -18,10 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
+import time
+from absl.testing import parameterized
import numpy as np
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,73 +37,83 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase):
+class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('even', 28, 14, False),
+ ('uneven_with_remainder', 28, 15, False),
+ ('uneven_without_remainder', 28, 15, True),
+ ('empty', 0, 14, False),
+ )
+ def testBatchDataset(self, count, batch_size, drop_remainder):
+ """Tests the batch dataset logic for various input configurations.
+
+ Args:
+ count: the number of input elements
+ batch_size: the batch size
+ drop_remainder: whether a smaller batch size should be produced if batch
+ size does not divide number of inputs evenly
+ """
- def testBatchDataset(self):
- """Test an dataset that maps a TF function across its input elements."""
# The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
# RepeatDataset(count) -> BatchDataset(batch_size).
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
np.array(37.0) * np.arange(7))
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
def _map_fn(x, y, z):
return math_ops.square(x), math_ops.square(y), math_ops.square(z)
iterator = (
dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size).make_initializable_iterator())
+ .repeat(count).batch(batch_size,
+ drop_remainder).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ if drop_remainder:
+ dim0 = batch_size
+ else:
+ dim0 = None
+ self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
with self.test_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ batch_size_t: batch_size,
+ drop_remainder_t: drop_remainder
+ })
+ num_full_batches = (count * 7) // batch_size
+ for i in range(num_full_batches):
result = sess.run(get_next)
for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ for j in range(batch_size):
+ self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
+ if not drop_remainder and (count * 7) % batch_size > 0:
result = sess.run(get_next)
for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
+ for j in range((count * 7) % batch_size):
+ self.assertAllEqual(
+ component[(num_full_batches * batch_size + j) % 7]**2,
+ result_component[j])
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ def testBatchDatasetInvalidBatchSize(self):
+ iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
+ get_next = iterator.get_next()
- # Empty batch should be an initialization time error.
+ with self.test_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+ sess.run(get_next)
def assertSparseValuesEqual(self, a, b):
self.assertAllEqual(a.indices, b.indices)
@@ -210,66 +222,108 @@ class BatchDatasetTest(test.TestCase):
r'First element had shape \[3\] and element 2 had shape \[4\].'):
sess.run(next_element)
- def testPaddedBatchDataset(self):
- seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
- padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
+
+def _random_seq_lens(count):
+ return np.random.randint(20, size=(count,)).astype(np.int32)
+
+
+class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('default_padding', _random_seq_lens(32), 4, [-1], False),
+ ('constant_padding', _random_seq_lens(32), 4, [25], False),
+ ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
+ ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
+ )
+ def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
+ drop_remainder):
+ """Tests the padded batch dataset logic for various input configurations.
+
+ Args:
+ seq_lens: the input sequence lengths
+ batch_size: the batch size
+ padded_shapes: the padded shapes to use
+ drop_remainder: whether a smaller batch size should be produced if batch
+ size does not divide number of inputs evenly
+ """
+
+ seq_lens_t = array_ops.placeholder(dtypes.int32, shape=[None])
+ batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ padded_shapes_t = array_ops.placeholder(dtypes.int64, shape=[1])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
iterator = (
- dataset_ops.Dataset.from_tensor_slices(seq_lens)
+ dataset_ops.Dataset.from_tensor_slices(seq_lens_t)
.map(lambda x: array_ops.fill([x], x)).padded_batch(
- 4, padded_shapes=padded_shape).make_initializable_iterator())
+ batch_size=batch_size_t,
+ drop_remainder=drop_remainder_t,
+ padded_shapes=padded_shapes_t).make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
- # Test with random sequence lengths, and max padding.
- random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
sess.run(
- init_op, feed_dict={
- padded_shape: [-1],
- seq_lens: random_seq_lens
+ init_op,
+ feed_dict={
+ seq_lens_t: seq_lens,
+ batch_size_t: batch_size,
+ padded_shapes_t: padded_shapes,
+ drop_remainder_t: drop_remainder,
})
- for i in range(8):
+
+ num_full_batches = len(seq_lens) // batch_size
+
+ for i in range(num_full_batches):
result = sess.run(get_next)
- padded_len = np.max(result)
- self.assertEqual((4, padded_len), result.shape)
- for j in range(4):
- seq_len = random_seq_lens[(i * 4) + j]
+ padded_len = padded_shapes[0]
+ if padded_len is None or padded_len == -1:
+ padded_len = np.max(result) if result.size > 0 else 0
+ self.assertEqual((batch_size, padded_len), result.shape)
+ for j in range(batch_size):
+ seq_len = seq_lens[(i * batch_size) + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.assertAllEqual(result[j, seq_len:],
+ [0] * (padded_len - seq_len))
- # Test with random sequence lengths, and constant padding.
- sess.run(
- init_op, feed_dict={
- padded_shape: [25],
- seq_lens: random_seq_lens
- })
- for i in range(8):
+ if not drop_remainder and len(seq_lens) % batch_size > 0:
result = sess.run(get_next)
- self.assertEqual((4, 25), result.shape)
- for j in range(4):
- seq_len = random_seq_lens[(i * 4) + j]
+ padded_len = np.max(result) if result.size > 0 else 0
+ self.assertEqual((len(seq_lens) % batch_size, padded_len),
+ result.shape)
+ for j in range(len(seq_lens) % batch_size):
+ seq_len = seq_lens[num_full_batches * batch_size + j]
self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
+ self.assertAllEqual(result[j, seq_len:],
+ [0] * (padded_len - seq_len))
+
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Test correct handling of empty tensors.
- sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
+ def testPaddedBatchShortPadding(self):
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices([6, 5, 5, 5, 5])
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaises(errors.DataLossError):
+ sess.run(get_next)
+
+ def testPaddedBatchEmptyTensors(self):
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices([0, 0, 0, 0])
+ .map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
result = sess.run(get_next)
self.assertAllEqual([[], [], [], []], result)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- # Test error handling with constant sequence lengths, and
- # too-short padding.
- sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
- with self.assertRaises(errors.DataLossError):
- result = sess.run(get_next)
-
def testPaddedBatchDatasetNonDefaultPadding(self):
seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
@@ -371,6 +425,94 @@ class BatchDatasetTest(test.TestCase):
with self.assertRaises(TypeError):
_ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
+ def testPaddedBatchShapeError(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(3,\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its shape was \(2, 2\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[[1, 1], [1, 1]])
+
+ with self.assertRaisesRegexp(
+ TypeError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its element type was float32.'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=constant_op.constant([1., 2., 3.]))
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(\?, \?\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
+
+class BatchDatasetBenchmark(test.Benchmark):
+
+ def benchmarkBatchSparse(self):
+ non_zeros_per_row_values = [0, 1, 5, 10, 100]
+ batch_size_values = [1, 32, 64, 128, 1024]
+
+ sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
+ batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
+
+ dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
+ ).batch(batch_size_placeholder)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ for non_zeros_per_row in non_zeros_per_row_values:
+
+ sparse_value = sparse_tensor.SparseTensorValue(
+ indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
+ values=np.arange(non_zeros_per_row, dtype=np.int64),
+ dense_shape=[1000])
+
+ for batch_size in batch_size_values:
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer, feed_dict={
+ sparse_placeholder: sparse_value,
+ batch_size_placeholder: batch_size})
+ # Run five steps to warm up the session caches before taking the
+ # first measurement.
+ for _ in range(5):
+ sess.run(next_element.indices.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.indices.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100.0
+
+ print('Batch sparse dataset non-zeros per row: %d batch_size: %d '
+ 'wall time: %f'
+ % (non_zeros_per_row, batch_size, median_wall_time))
+ self.report_benchmark(
+ iters=10000, wall_time=median_wall_time,
+ name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % (
+ non_zeros_per_row, batch_size))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index 296a76ec88..fb55ae1400 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -259,9 +259,7 @@ class DatasetConstructorTest(test.TestCase):
sess.run(init_op)
self.assertAllEqual([1, 2, 3], sess.run(get_next))
self.assertAllEqual([4, 5, 6], sess.run(get_next))
- # NOTE(mrry): Type name in message differs between Python 2 (`long`) and
- # 3 (`int`).
- with self.assertRaisesOpError(r"invalid literal for"):
+ with self.assertRaisesOpError("The expected type was int64"):
sess.run(get_next)
self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
@@ -290,6 +288,34 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testFromGeneratorStructureError(self):
+ def generator():
+ yield 1, 2
+ yield 3, 4
+ yield 5
+ yield 6, 7, 8
+ yield 9, 10
+
+ iterator = (dataset_ops.Dataset.from_generator(
+ generator, output_types=(dtypes.int64, dtypes.int64))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ self.assertEqual((1, 2), sess.run(get_next))
+ self.assertEqual((3, 4), sess.run(get_next))
+ with self.assertRaisesOpError(
+ r"The expected structure was \(tf\.int64, tf\.int64\)"):
+ sess.run(get_next)
+ with self.assertRaisesOpError(
+ r"The expected structure was \(tf\.int64, tf\.int64\)"):
+ sess.run(get_next)
+ self.assertEqual((9, 10), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testFromGeneratorHeterogeneous(self):
def generator():
yield 1
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
new file mode 100644
index 0000000000..2c4c11e132
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -0,0 +1,37 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the input pipeline ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class DatasetOpsTest(test.TestCase):
+
+ def testAsSerializedGraph(self):
+ dataset = dataset_ops.Dataset.range(10)
+ with self.test_session() as sess:
+ graph = graph_pb2.GraphDef().FromString(
+ sess.run(dataset._as_serialized_graph()))
+ self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
index 820c167b6b..b434fa7334 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.compat import compat as forward_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -415,6 +416,69 @@ class IteratorTest(test.TestCase):
sess.run(
next_element, feed_dict={handle_placeholder: iterator_4_handle})
+ def testIteratorStringHandleFuture(self):
+ with forward_compat.forward_compatibility_horizon(2018, 8, 4):
+ dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
+ dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])
+
+ iterator_3 = dataset_3.make_one_shot_iterator()
+ iterator_4 = dataset_4.make_one_shot_iterator()
+
+ handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+ feedable_iterator = iterator_ops.Iterator.from_string_handle(
+ handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
+ next_element = feedable_iterator.get_next()
+
+ self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
+ self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
+ self.assertEqual([], feedable_iterator.output_shapes)
+
+ with self.test_session() as sess:
+ iterator_3_handle = sess.run(iterator_3.string_handle())
+ iterator_4_handle = sess.run(iterator_4.string_handle())
+
+ self.assertEqual(
+ 10,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 1,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 20,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 2,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 30,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ self.assertEqual(
+ 3,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_3_handle}))
+ self.assertEqual(
+ 40,
+ sess.run(
+ next_element,
+ feed_dict={handle_placeholder: iterator_4_handle}))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(
+ next_element, feed_dict={handle_placeholder: iterator_3_handle})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(
+ next_element, feed_dict={handle_placeholder: iterator_4_handle})
+
def testIteratorStringHandleReuseTensorObject(self):
dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
one_shot_iterator = dataset.make_one_shot_iterator()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 1ad0b9de5e..637bde9ae4 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from collections import namedtuple
import threading
import time
+import warnings
import numpy as np
@@ -638,6 +639,40 @@ class MapDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testWarnOnLookupTable(self):
+ def collecting_function(x):
+ _ = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer([], []), 0.0, name="t1")
+ return x
+
+ warnings.simplefilter("always")
+ with warnings.catch_warnings(record=True) as w:
+ _ = dataset_ops.Dataset.range(10).map(collecting_function)
+ # NOTE(mrry): Python 3 prints other warnings in addition to the one we are
+ # testing, so we search for the expected warning.
+ self.assertGreaterEqual(len(w), 1)
+ found_warning = False
+ for warning in w:
+ if ("Creating lookup tables inside a function passed to Dataset.map() is "
+ "not supported." in str(warning)):
+ found_warning = True
+ break
+ self.assertTrue(found_warning)
+
+ def testNestedDatasetError(self):
+ dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
+ with self.assertRaisesRegexp(
+ NotImplementedError, r"The Dataset.map\(\) transformation does not "
+ "currently support nested datasets as outputs."):
+ _ = dataset.map(dataset_ops.Dataset.from_tensor_slices)
+
+ def testReturnValueError(self):
+ dataset = dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])
+ with self.assertRaisesRegexp(
+ TypeError, r"Unsupported return value from function passed to "
+ r"Dataset.map\(\): None."):
+ _ = dataset.map(lambda x: None)
+
class MapDatasetBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index 646324cb95..63a0830272 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -24,35 +26,33 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase):
+class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
- def testBufferSize(self):
- buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
+ @parameterized.parameters((-1), (0), (5))
+ def testBufferSize(self, buffer_size):
+ buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size_t).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: 5})
+ sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
for m in range(10):
self.assertEqual(m, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testInvalidBufferSize(self):
- buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
+ @parameterized.parameters((-2), (-42))
+ def testInvalidBufferSize(self, buffer_size):
+ buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size_t).make_initializable_iterator()
init_op = iterator.initializer
with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: 0})
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
- with self.test_session() as sess:
- sess.run(init_op, feed_dict={buffer_size: -5})
+ sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
if __name__ == "__main__":
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index fa2e86eab1..f15eb6310f 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -40,6 +40,7 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/compat",
"//tensorflow/python/data/util:convert",
],
)
@@ -54,6 +55,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 0e020d86d0..88de4b588c 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -19,10 +19,12 @@ from __future__ import print_function
import abc
import threading
+import warnings
import numpy as np
import six
+from tensorflow.python.compat import compat
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
@@ -32,6 +34,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@@ -57,6 +60,15 @@ class Dataset(object):
def __init__(self):
pass
+ def _as_serialized_graph(self):
+ """Produces serialized graph representation of the dataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.string` type, representing this dataset as a
+ serialized graph.
+ """
+ return gen_dataset_ops.dataset_to_graph(self._as_variant_tensor())
+
@abc.abstractmethod
def _as_variant_tensor(self):
"""Creates a scalar `tf.Tensor` of `tf.variant` representing this dataset.
@@ -96,13 +108,12 @@ class Dataset(object):
"execution is enabled.")
if shared_name is None:
shared_name = ""
- iterator_resource = gen_dataset_ops.iterator(
- container="",
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ if compat.forward_compatible(2018, 8, 3):
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="", shared_name=shared_name, **flat_structure(self))
+ else:
+ iterator_resource = gen_dataset_ops.iterator(
+ container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
iterator_resource)
@@ -160,13 +171,8 @@ class Dataset(object):
return iterator_ops.Iterator(
gen_dataset_ops.one_shot_iterator(
- dataset_factory=_make_dataset,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes,
- self.output_classes))), None,
- self.output_types, self.output_shapes, self.output_classes)
+ dataset_factory=_make_dataset, **flat_structure(self)),
+ None, self.output_types, self.output_shapes, self.output_classes)
@abc.abstractproperty
def output_classes(self):
@@ -412,13 +418,23 @@ class Dataset(object):
# Use the same _convert function from the py_func() implementation to
# convert the returned values to arrays early, so that we can inspect
# their values.
- # pylint: disable=protected-access
- ret_arrays = [
- script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
- for ret, dtype in zip(
- nest.flatten_up_to(output_types, values), flattened_types)
- ]
- # pylint: enable=protected-access
+ try:
+ flattened_values = nest.flatten_up_to(output_types, values)
+ except (TypeError, ValueError):
+ raise TypeError(
+ "`generator` yielded an element that did not match the expected "
+ "structure. The expected structure was %s, but the yielded "
+ "element was %s." % (output_types, values))
+ ret_arrays = []
+ for ret, dtype in zip(flattened_values, flattened_types):
+ try:
+ ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
+ ret, dtype=dtype.as_numpy_dtype))
+ except (TypeError, ValueError):
+ raise TypeError(
+ "`generator` yielded an element that could not be converted to "
+ "the expected type. The expected type was %s, but the yielded "
+ "element was %s." % (dtype.name, ret))
# Additional type and shape checking to ensure that the components
# of the generated element match the `output_types` and `output_shapes`
@@ -795,35 +811,50 @@ class Dataset(object):
return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
- def batch(self, batch_size):
+ def batch(self, batch_size, drop_remainder=False):
"""Combines consecutive elements of this dataset into batches.
- NOTE: If the number of elements (`N`) in this dataset is not an exact
- multiple of `batch_size`, the final batch contain smaller tensors with
- shape `N % batch_size` in the batch dimension. If your program depends on
- the batches having the same shape, consider using the
- @{tf.contrib.data.batch_and_drop_remainder} transformation instead.
+ The tensors in the resulting element will have an additional outer
+ dimension, which will be `batch_size` (or `N % batch_size` for the last
+ element if `batch_size` does not divide the number of input elements `N`
+ evenly and `drop_remainder` is `False`). If your program depends on the
+ batches having the same outer dimension, you should set the `drop_remainder`
+ argument to `True` to prevent the smaller batch from being produced.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in the case its has fewer than
+ `batch_size` elements; the default behavior is not to drop the smaller
+ batch.
Returns:
Dataset: A `Dataset`.
"""
- return BatchDataset(self, batch_size)
+ return BatchDataset(self, batch_size, drop_remainder)
- def padded_batch(self, batch_size, padded_shapes, padding_values=None):
+ def padded_batch(self,
+ batch_size,
+ padded_shapes,
+ padding_values=None,
+ drop_remainder=False):
"""Combines consecutive elements of this dataset into padded batches.
This transformation combines multiple consecutive elements of the input
- dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors
- in the resulting element have an additional outer dimension, which will be
- `batch_size` for all but the last element, and `N % batch_size` for the
- last element (where `N` is the number of elements in this dataset). Unlike
- @{tf.data.Dataset.batch}, the elements may have different shapes for some
- of their components, and this transformation will pad each component to
- the respective shape in `padding_shapes`. The `padding_shapes` argument
+ dataset into a single element.
+
+ Like @{tf.data.Dataset.batch}, the tensors in the resulting element will
+ have an additional outer dimension, which will be `batch_size` (or
+ `N % batch_size` for the last element if `batch_size` does not divide the
+ number of input elements `N` evenly and `drop_remainder` is `False`). If
+ your program depends on the batches having the same outer dimension, you
+ should set the `drop_remainder` argument to `True` to prevent the smaller
+ batch from being produced.
+
+ Unlike @{tf.data.Dataset.batch}, the input elements to be batched may have
+ different shapes, and this transformation will pad each component to the
+ respective shape in `padding_shapes`. The `padding_shapes` argument
determines the resulting shape for each dimension of each component in an
output element:
@@ -833,12 +864,6 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- NOTE: If the number of elements (`N`) in this dataset is not an exact
- multiple of `batch_size`, the final batch contain smaller tensors with
- shape `N % batch_size` in the batch dimension. If your program depends on
- the batches having the same shape, consider using the
- @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead.
-
See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements
that may have different shapes into a @{tf.SparseTensor}.
@@ -856,14 +881,95 @@ class Dataset(object):
`tf.Tensor`, representing the padding values to use for the
respective components. Defaults are `0` for numeric types and
the empty string for string types.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in the case its has fewer than
+ `batch_size` elements; the default behavior is not to drop the smaller
+ batch.
Returns:
Dataset: A `Dataset`.
"""
- return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
+ return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values,
+ drop_remainder)
def map(self, map_func, num_parallel_calls=None):
- """Maps `map_func` across this dataset.
+ """Maps `map_func` across the elements of this dataset.
+
+ This transformation applies `map_func` to each element of this dataset, and
+ returns a new dataset containing the transformed elements, in the same
+ order as they appeared in the input.
+
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3, 4, 5 }
+
+ a.map(lambda x: x + 1) = { 2, 3, 4, 5, 6 }
+ ```
+
+ The input signature of `map_func` is determined by the structure of each
+ element in this dataset. For example:
+
+ ```python
+ # Each element is a `tf.Tensor` object.
+ a = { 1, 2, 3, 4, 5 }
+ # `map_func` takes a single argument of type `tf.Tensor` with the same
+ # shape and dtype.
+ result = a.map(lambda x: ...)
+
+ # Each element is a tuple containing two `tf.Tensor` objects.
+ b = { (1, "foo"), (2, "bar"), (3, "baz") }
+ # `map_func` takes two arguments of type `tf.Tensor`.
+ result = b.map(lambda x_int, y_str: ...)
+
+ # Each element is a dictionary mapping strings to `tf.Tensor` objects.
+ c = { {"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}, {"a": 3, "b": "baz"} }
+ # `map_func` takes a single argument of type `dict` with the same keys as
+ # the elements.
+ result = c.map(lambda d: ...)
+ ```
+
+ The value or values returned by `map_func` determine the structure of each
+ element in the returned dataset.
+
+ ```python
+ # `map_func` returns a scalar `tf.Tensor` of type `tf.float32`.
+ def f(...):
+ return tf.constant(37.0)
+ result = dataset.map(f)
+ result.output_classes == tf.Tensor
+ result.output_types == tf.float32
+ result.output_shapes == [] # scalar
+
+ # `map_func` returns two `tf.Tensor` objects.
+ def g(...):
+ return tf.constant(37.0), tf.constant(["Foo", "Bar", "Baz"])
+ result = dataset.map(g)
+ result.output_classes == (tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string)
+ result.output_shapes == ([], [3])
+
+ # Python primitives, lists, and NumPy arrays are implicitly converted to
+ # `tf.Tensor`.
+ def h(...):
+ return 37.0, ["Foo", "Bar", "Baz"], np.array([1.0, 2.0] dtype=np.float64)
+ result = dataset.map(h)
+ result.output_classes == (tf.Tensor, tf.Tensor, tf.Tensor)
+ result.output_types == (tf.float32, tf.string, tf.float64)
+ result.output_shapes == ([], [3], [2])
+
+ # `map_func` can return nested structures.
+ def i(...):
+ return {"a": 37.0, "b": [42, 16]}, "foo"
+ result.output_classes == ({"a": tf.Tensor, "b": tf.Tensor}, tf.Tensor)
+ result.output_types == ({"a": tf.float32, "b": tf.int32}, tf.string)
+ result.output_shapes == ({"a": [], "b": [2]}, [])
+ ```
+
+ In addition to `tf.Tensor` objects, `map_func` can accept as arguments and
+ return `tf.SparseTensor` objects.
Args:
map_func: A function mapping a nested structure of tensors (having
@@ -972,7 +1078,8 @@ class Dataset(object):
scalar `tf.bool` tensor.
Returns:
- Dataset: A `Dataset`.
+ Dataset: The `Dataset` containing the elements of this dataset for which
+ `predicate` is `True`.
"""
return FilterDataset(self, predicate)
@@ -1123,6 +1230,313 @@ class SparseTensorSliceDataset(Dataset):
return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
+class _NestedDatasetComponent(object):
+ """The structure of a `Dataset` nested in a component of another `Dataset`.
+
+ A `StructuredFunctionWrapper` around a function that returns a `Dataset` as
+ one of its components will have a `NestedDatasetComponent` in the
+ corresponding position in the `output_classes`, `output_shapes`, and
+ `output_types` properties.
+
+ NOTE(mrry): This class is not currently exposed via the public API. Support
+ for nested datasets can be enabled on a function-by-function basis by setting
+ `experimental_nested_dataset_support=True` in the `StructuredFunctionWrapper`
+ initializer.
+
+ TODO(b/110122868): Add this class, or something equivalent, to the public API.
+ We are considering revising the public API for accessing Dataset structure
+ (`output_classes` etc.) based on experience with nested datasets and other
+ custom component types.
+ """
+
+ def __init__(self,
+ dataset=None,
+ output_shapes=None,
+ output_types=None,
+ output_classes=None):
+ if dataset is None:
+ if (output_classes is None or output_shapes is None or
+ output_types is None):
+ raise ValueError(
+ "Either `dataset`, or all of `output_classes`, "
+ "`output_shapes`, and `output_types` must be specified.")
+ self._output_classes = output_classes
+ self._output_shapes = output_shapes
+ self._output_types = output_types
+ else:
+ if not (output_classes is None and output_shapes is None and
+ output_types is None):
+ raise ValueError(
+ "Either `dataset`, or all of `output_classes`, "
+ "`output_shapes`, and `output_types` must be specified.")
+ self._output_classes = dataset.output_classes
+ self._output_shapes = dataset.output_shapes
+ self._output_types = dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+class _VariantDataset(Dataset):
+ """A Dataset wrapper around a @{tf.variant}-typed function argument."""
+
+ def __init__(self, dataset_variant, structure):
+ super(_VariantDataset, self).__init__()
+ self._dataset_variant = dataset_variant
+ self._structure = structure
+
+ def _as_variant_tensor(self):
+ return self._dataset_variant
+
+ @property
+ def output_classes(self):
+ return self._structure.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._structure.output_shapes
+
+ @property
+ def output_types(self):
+ return self._structure.output_types
+
+
+class StructuredFunctionWrapper(object):
+ """A wrapper for `Defun` that supports structured arguments and return values.
+ """
+
+ def __init__(self, func, transformation_name, dataset=None,
+ input_classes=None, input_shapes=None, input_types=None,
+ add_to_graph=True, experimental_nested_dataset_support=False):
+ """Creates a new `StructuredFunctionWrapper` for the given function.
+
+ Args:
+ func: A function from a nested structure to another nested structure.
+ transformation_name: Human-readable name of the transformation in which
+ this function is being instantiated, for error messages.
+ dataset: (Optional.) A @{tf.data.Dataset}. If given, the structure of this
+ dataset will be assumed as the structure for `func` arguments; otherwise
+ `input_classes`, `input_shapes`, and `input_types` must be defined.
+ input_classes: (Optional.) A nested structure of `type`. If given, this
+ argument defines the Python types for `func` arguments.
+ input_shapes: (Optional.) A nested structure of @{tf.TensorShape}. If
+ given, this argument defines the shapes and structure for `func`
+ arguments.
+ input_types: (Optional.) A nested structure of @{tf.DType}. If given, this
+ argument defines the element types and structure for `func` arguments.
+ add_to_graph: (Optional.) If `True`, the function will be added to the
+ default graph.
+ experimental_nested_dataset_support: (Optional.) If `True`, the function
+ will support @{tf.data.Dataset} objects as arguments and return values.
+
+ Raises:
+ ValueError: If an invalid combination of `dataset`, `input_classes`,
+ `input_shapes`, and `input_types` is passed.
+ """
+ if dataset is None:
+ if input_classes is None or input_shapes is None or input_types is None:
+ raise ValueError("Either `dataset`, or all of `input_classes`, "
+ "`input_shapes`, and `input_types` must be specified.")
+ self._input_shapes = input_shapes
+ self._input_types = input_types
+ self._input_classes = input_classes
+ else:
+ if not (input_classes is None and input_shapes is None and
+ input_types is None):
+ raise ValueError("Either `dataset`, or all of `input_classes`, "
+ "`input_shapes`, and `input_types` must be specified.")
+ self._input_shapes = dataset.output_shapes
+ self._input_types = dataset.output_types
+ self._input_classes = dataset.output_classes
+
+ self._transformation_name = transformation_name
+
+ # TODO(b/110122868): Enable this support for all `tf.data` functions.
+ self._nested_dataset_support = experimental_nested_dataset_support
+
+ @function.Defun(*self._defun_args())
+ def tf_data_structured_function_wrapper(*args):
+ """Wrapper for passing nested structures to and from tf.data functions."""
+ flat_args = []
+ for arg, arg_class, arg_shape, arg_type in zip(
+ args,
+ nest.flatten(self._input_classes),
+ nest.flatten(self._input_shapes),
+ nest.flatten(self._input_types)):
+ # TODO(b/110122868): Add a registration mechanism for new component
+ # types.
+ if arg_class is sparse_tensor_lib.SparseTensor:
+ arg = sparse.deserialize_sparse_tensors(
+ arg, arg_type, arg_shape, arg_class)
+ arg.indices.set_shape([None, arg_shape.ndims])
+ arg.dense_shape.set_shape([arg_shape.ndims])
+ elif isinstance(arg_class, _NestedDatasetComponent):
+ assert self._nested_dataset_support
+ arg = _VariantDataset(arg, arg_class)
+ else:
+ arg.set_shape(arg_shape)
+ flat_args.append(arg)
+ nested_args = nest.pack_sequence_as(self._input_classes, flat_args)
+ if not _should_unpack_args(nested_args):
+ nested_args = (nested_args,)
+
+ ret = func(*nested_args)
+ # If `func` returns a list of tensors, `nest.flatten()` and
+ # `ops.convert_to_tensor()` would conspire to attempt to stack
+ # those tensors into a single tensor, because the customized
+ # version of `nest.flatten()` does not recurse into lists. Since
+ # it is more likely that the list arose from returning the
+ # result of an operation (such as `tf.py_func()`) that returns a
+ # list of not-necessarily-stackable tensors, we treat the
+ # returned value is a `tuple` instead. A user wishing to pack
+ # the return value into a single tensor can use an explicit
+ # `tf.stack()` before returning.
+ if isinstance(ret, list):
+ ret = tuple(ret)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ flat_ret = []
+ flat_classes = []
+ flat_shapes = []
+ flat_types = []
+ for t in nest.flatten(ret):
+ # TODO(b/110122868): Add a registration mechanism for new component
+ # types.
+ if sparse_tensor_lib.is_sparse(t):
+ t = sparse_tensor_lib.SparseTensor.from_value(t)
+ flat_ret.append(sparse.serialize_sparse_tensors(t))
+ flat_classes.append(sparse_tensor_lib.SparseTensor)
+ flat_shapes.append(t.get_shape())
+ flat_types.append(t.dtype)
+ elif isinstance(t, Dataset):
+ if not self._nested_dataset_support:
+ raise NotImplementedError(
+ "The %s transformation does not currently support nested "
+ "datasets as outputs." % self._transformation_name)
+
+ flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access
+ component = _NestedDatasetComponent(t)
+ flat_classes.append(component)
+ flat_shapes.append(component)
+ flat_types.append(component)
+ else:
+ try:
+ t = ops.convert_to_tensor(t)
+ except (ValueError, TypeError):
+ raise TypeError("Unsupported return value from function passed to "
+ "%s: %s." % (transformation_name, t))
+ flat_ret.append(t)
+ flat_classes.append(ops.Tensor)
+ flat_shapes.append(t.get_shape())
+ flat_types.append(t.dtype)
+
+ ret = nest.pack_sequence_as(ret, flat_ret)
+ self._output_classes = nest.pack_sequence_as(ret, flat_classes)
+ self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
+ self._output_types = nest.pack_sequence_as(ret, flat_types)
+
+ _warn_if_collections(transformation_name)
+
+ return flat_ret
+
+ self._function = tf_data_structured_function_wrapper
+ if add_to_graph:
+ self._function.add_to_graph(ops.get_default_graph())
+ else:
+ # Use the private method that will execute
+ # `tf_data_structured_function_wrapper` but delay adding it to the graph
+ # in case (e.g.) we need to rerun the function.
+ self._function._create_definition_if_needed() # pylint: disable=protected-access
+
+ def _defun_args(self):
+ """Returns a flat list of @{tf.DType} for the input element structure."""
+ ret = []
+ for input_type, input_class in zip(nest.flatten(self._input_types),
+ nest.flatten(self._input_classes)):
+ # TODO(b/110122868): Add a registration mechanism for new component types.
+ if input_class is sparse_tensor_lib.SparseTensor:
+ ret.append(dtypes.variant)
+ elif isinstance(input_class, _NestedDatasetComponent):
+ if not self._nested_dataset_support:
+ raise NotImplementedError(
+ "The %s transformation does not currently support nested "
+ "datasets as inputs." % self._transformation_name)
+ ret.append(dtypes.variant)
+ else:
+ assert isinstance(input_type, dtypes.DType)
+ ret.append(input_type)
+ return ret
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def function(self):
+ return self._function
+
+
+def flat_structure(dataset):
+ """Helper for setting `output_shapes` and `output_types` attrs of Dataset ops.
+
+ Most Dataset op constructors expect `output_shapes` and `output_types`
+ arguments that represent the flattened structure of an element. This helper
+ function generates these attrs as a keyword argument dictionary, allowing
+ `Dataset._as_variant_tensor()` implementations to pass
+ `**flat_structure(self)` to the op constructor.
+
+ Args:
+ dataset: A @{tf.data.Dataset}.
+
+ Returns:
+ A dictionary of keyword arguments that can be passed to many Dataset op
+ constructors.
+ """
+ output_classes = []
+ output_shapes = []
+ output_types = []
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_types)):
+ if isinstance(output_class, _NestedDatasetComponent):
+ output_classes.append(output_class.output_classes)
+ output_shapes.append(output_shape.output_shapes)
+ output_types.append(output_type.output_types)
+ else:
+ output_classes.append(output_class)
+ output_shapes.append(output_shape)
+ output_types.append(output_type)
+
+ output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes)
+ output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes)
+ output_types = nest.pack_sequence_as(dataset.output_types, output_types)
+
+ return {
+ "output_shapes":
+ nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)),
+ "output_types":
+ nest.flatten(sparse.as_dense_types(output_types, output_classes)),
+ }
+
+
class _GeneratorDataset(Dataset):
"""A `Dataset` that generates elements by invoking a function."""
@@ -1155,137 +1569,26 @@ class _GeneratorDataset(Dataset):
init_args_types = nest.pack_sequence_as(
init_args, [t.dtype for t in nest.flatten(init_args)])
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(init_args_types, init_args_classes)))
- def tf_init_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- dense_shapes = sparse.as_dense_shapes(init_args_shapes, init_args_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(init_args_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, init_args_types, init_args_shapes, init_args_classes)
- if _should_unpack_args(nested_args):
- ret = init_func(*nested_args)
- else:
- ret = init_func(nested_args)
-
- # If `init_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._state_classes = sparse.get_classes(ret)
- self._state_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._state_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._init_func = tf_init_func
- self._init_func.add_to_graph(ops.get_default_graph())
-
- # These members will be initialized by `tf_next_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes)))
- def tf_next_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(self._state_shapes,
- self._state_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
- if _should_unpack_args(nested_args):
- ret = next_func(*nested_args)
- else:
- ret = next_func(nested_args)
-
- # If `next_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._next_func = tf_next_func
- self._next_func.add_to_graph(ops.get_default_graph())
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(self._state_types, self._state_classes)))
- def tf_finalize_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the state.
- dense_shapes = sparse.as_dense_shapes(self._state_shapes,
- self._state_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(self._state_classes, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, self._state_types, self._state_shapes,
- self._state_classes)
- if _should_unpack_args(nested_args):
- return finalize_func(*nested_args)
- else:
- return finalize_func(nested_args)
-
- self._finalize_func = tf_finalize_func
- self._finalize_func.add_to_graph(ops.get_default_graph())
+ wrapped_init_func = StructuredFunctionWrapper(
+ init_func, "GeneratorDataset", input_classes=init_args_classes,
+ input_shapes=init_args_shapes, input_types=init_args_types)
+ self._state_classes = wrapped_init_func.output_classes
+ self._state_shapes = wrapped_init_func.output_shapes
+ self._state_types = wrapped_init_func.output_types
+ self._init_func = wrapped_init_func.function
+
+ wrapped_next_func = StructuredFunctionWrapper(
+ next_func, "GeneratorDataset", input_classes=self._state_classes,
+ input_shapes=self._state_shapes, input_types=self._state_types)
+ self._output_classes = wrapped_next_func.output_classes
+ self._output_shapes = wrapped_next_func.output_shapes
+ self._output_types = wrapped_next_func.output_types
+ self._next_func = wrapped_next_func.function
+
+ wrapped_finalize_func = StructuredFunctionWrapper(
+ finalize_func, "GeneratorDataset", input_classes=self._state_classes,
+ input_shapes=self._state_shapes, input_types=self._state_types)
+ self._finalize_func = wrapped_finalize_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.generator_dataset(
@@ -1295,10 +1598,7 @@ class _GeneratorDataset(Dataset):
init_func=self._init_func,
next_func=self._next_func,
finalize_func=self._finalize_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1335,16 +1635,7 @@ class ZipDataset(Dataset):
# pylint: disable=protected-access
return gen_dataset_ops.zip_dataset(
[ds._as_variant_tensor() for ds in nest.flatten(self._datasets)],
- output_shapes=[
- s
- for ds in nest.flatten(self._datasets)
- for s in nest.flatten(ds.output_shapes)
- ],
- output_types=[
- t
- for ds in nest.flatten(self._datasets)
- for t in nest.flatten(ds.output_types)
- ])
+ **flat_structure(self))
# pylint: enable=protected-access
@property
@@ -1389,10 +1680,7 @@ class ConcatenateDataset(Dataset):
return gen_dataset_ops.concatenate_dataset(
self._input_dataset._as_variant_tensor(),
self._dataset_to_concatenate._as_variant_tensor(),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
# pylint: enable=protected-access
@property
@@ -1430,10 +1718,7 @@ class RepeatDataset(Dataset):
return gen_dataset_ops.repeat_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
count=self._count,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1457,6 +1742,7 @@ class RangeDataset(Dataset):
self._parse_args(*args)
def _parse_args(self, *args):
+ """Parse arguments according to the same rules as the `range()` builtin."""
if len(args) == 1:
self._start = self._build_tensor(0, "start")
self._stop = self._build_tensor(args[0], "stop")
@@ -1480,10 +1766,7 @@ class RangeDataset(Dataset):
start=self._start,
stop=self._stop,
step=self._step,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1512,10 +1795,7 @@ class CacheDataset(Dataset):
return gen_dataset_ops.cache_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
filename=self._filename,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1575,10 +1855,7 @@ class ShuffleDataset(Dataset):
seed=self._seed,
seed2=self._seed2,
reshuffle_each_iteration=self._reshuffle_each_iteration,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1606,10 +1883,7 @@ class TakeDataset(Dataset):
return gen_dataset_ops.take_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
count=self._count,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1637,10 +1911,7 @@ class SkipDataset(Dataset):
return gen_dataset_ops.skip_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
count=self._count,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1658,21 +1929,28 @@ class SkipDataset(Dataset):
class BatchDataset(Dataset):
"""A `Dataset` that batches contiguous elements from its input."""
- def __init__(self, input_dataset, batch_size):
+ def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
super(BatchDataset, self).__init__()
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
def _as_variant_tensor(self):
- return gen_dataset_ops.batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- batch_size=self._batch_size,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
+ if smart_cond.smart_constant_value(self._drop_remainder) is False:
+ return gen_dataset_ops.batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ **flat_structure(self))
+ else:
+ return gen_dataset_ops.batch_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ drop_remainder=self._drop_remainder,
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1682,7 +1960,9 @@ class BatchDataset(Dataset):
def output_shapes(self):
input_shapes = self._input_dataset.output_shapes
return nest.pack_sequence_as(input_shapes, [
- tensor_shape.vector(None).concatenate(s)
+ tensor_shape.vector(
+ tensor_util.constant_value(self._batch_size) if smart_cond.
+ smart_constant_value(self._drop_remainder) else None).concatenate(s)
for s in nest.flatten(self._input_dataset.output_shapes)
])
@@ -1691,20 +1971,77 @@ class BatchDataset(Dataset):
return self._input_dataset.output_types
-def _partial_shape_to_tensor(shape_like):
+def _is_padded_shape_compatible_with(padded_shape, input_component_shape):
+ """Returns `True` if `input_component_shape` can be padded to `padded_shape`.
+
+ Args:
+ padded_shape: A `tf.TensorShape`.
+ input_component_shape: A `tf.TensorShape`.
+
+ Returns:
+ `True` if `input_component_shape` can be padded to `padded_shape`, otherwise
+ `False`.
+ """
+
+ if padded_shape.dims is None or input_component_shape.dims is None:
+ return True
+ if len(padded_shape.dims) != len(input_component_shape.dims):
+ return False
+ for padded_dim, input_dim in zip(
+ padded_shape.dims, input_component_shape.dims):
+ if (padded_dim.value is not None and input_dim.value is not None
+ and padded_dim.value < input_dim.value):
+ return False
+ return True
+
+
+def _padded_shape_to_tensor(padded_shape, input_component_shape):
+ """Converts `padded_shape` to a `tf.Tensor` representing that shape.
+
+ Args:
+ padded_shape: A shape-like object, which may be a `tf.TensorShape`, a Python
+ sequence, or a 1-D `tf.Tensor` of `tf.int64` elements.
+ input_component_shape: A `tf.TensorShape`, with which `padded_shape` must
+ be compatible.
+
+ Returns:
+ A 1-D `tf.Tensor` of `tf.int64` elements, representing `padded_shape`.
+
+ Raises:
+ ValueError: If `padded_shape` is not a shape or not compatible with
+ `input_component_shape`.
+ TypeError: If `padded_shape` is not convertible to a `tf.int64` tensor.
+ """
try:
- # First attempt to convert the input to a shape, and return the
- # "canonical" tensor representation, which uses `-1` in place of
- # `None`.
- shape_like = tensor_shape.as_shape(shape_like)
- return ops.convert_to_tensor(
- [dim if dim is not None else -1 for dim in shape_like.as_list()],
- dtype=dtypes.int64)
+ # Try to convert the `padded_shape` to a `tf.TensorShape`
+ padded_shape_as_shape = tensor_shape.as_shape(padded_shape)
+ # We will return the "canonical" tensor representation, which uses
+ # `-1` in place of `None`.
+ ret = ops.convert_to_tensor(
+ [dim if dim is not None else -1
+ for dim in padded_shape_as_shape.as_list()], dtype=dtypes.int64)
except (TypeError, ValueError):
# The argument was not trivially convertible to a
# `tf.TensorShape`, so fall back on the conversion to tensor
# machinery.
- return ops.convert_to_tensor(shape_like, dtype=dtypes.int64)
+ ret = ops.convert_to_tensor(padded_shape, preferred_dtype=dtypes.int64)
+ if ret.shape.dims is not None and len(ret.shape.dims) != 1:
+ raise ValueError(
+ "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
+ "shape was %s." % (padded_shape, ret.shape))
+ if ret.dtype != dtypes.int64:
+ raise TypeError(
+ "Padded shape %s must be a 1-D tensor of tf.int64 values, but its "
+ "element type was %s." % (padded_shape, ret.dtype.name))
+ padded_shape_as_shape = tensor_util.constant_value_as_shape(ret)
+
+ if not _is_padded_shape_compatible_with(padded_shape_as_shape,
+ input_component_shape):
+ raise ValueError("The padded shape %s is not compatible with the "
+ "corresponding input component shape %s."
+ % (padded_shape_as_shape, input_component_shape))
+
+ return ret
def _padding_value_to_tensor(value, output_type):
@@ -1731,7 +2068,7 @@ def _padding_value_to_tensor(value, output_type):
def _default_padding(input_dataset):
-
+ """Returns default padding tensors in a structure matching `input_dataset`."""
def make_zero(t):
if t.base_dtype == dtypes.string:
return ""
@@ -1746,7 +2083,8 @@ def _default_padding(input_dataset):
class PaddedBatchDataset(Dataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
- def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
+ def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
+ drop_remainder):
"""See `Dataset.batch()` for details."""
super(PaddedBatchDataset, self).__init__()
if sparse.any_sparse(input_dataset.output_classes):
@@ -1759,23 +2097,51 @@ class PaddedBatchDataset(Dataset):
padding_values = (
padding_values
if padding_values is not None else _default_padding(input_dataset))
- self._padded_shapes = nest.map_structure_up_to(
- input_dataset.output_shapes, _partial_shape_to_tensor, padded_shapes)
+
+ flat_padded_shapes = nest.flatten_up_to(input_dataset.output_shapes,
+ padded_shapes)
+
+ flat_padded_shapes_as_tensors = []
+
+ for input_component_shape, padded_shape in zip(
+ nest.flatten(input_dataset.output_shapes), flat_padded_shapes):
+ flat_padded_shapes_as_tensors.append(
+ _padded_shape_to_tensor(padded_shape, input_component_shape))
+
+ self._padded_shapes = nest.pack_sequence_as(input_dataset.output_shapes,
+ flat_padded_shapes_as_tensors)
+
self._padding_values = nest.map_structure_up_to(
input_dataset.output_shapes, _padding_value_to_tensor, padding_values,
input_dataset.output_types)
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
def _as_variant_tensor(self):
- return gen_dataset_ops.padded_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- batch_size=self._batch_size,
- padded_shapes=[
- ops.convert_to_tensor(s, dtype=dtypes.int64)
- for s in nest.flatten(self._padded_shapes)
- ],
- padding_values=nest.flatten(self._padding_values),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ # TODO(jsimsa): Switch to using v2 only any time after 6/30/2018.
+ if smart_cond.smart_constant_value(self._drop_remainder) is False:
+ return gen_dataset_ops.padded_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ else:
+ return gen_dataset_ops.padded_batch_dataset_v2(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ drop_remainder=self._drop_remainder,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
@property
def output_classes(self):
@@ -1785,8 +2151,10 @@ class PaddedBatchDataset(Dataset):
def output_shapes(self):
def _padded_shape_to_batch_shape(s):
- return tensor_shape.vector(None).concatenate(
- tensor_util.constant_value_as_shape(s))
+ return tensor_shape.vector(
+ tensor_util.constant_value(self._batch_size) if smart_cond.
+ smart_constant_value(self._drop_remainder) else None).concatenate(
+ tensor_util.constant_value_as_shape(s))
return nest.map_structure(_padded_shape_to_batch_shape, self._padded_shapes)
@@ -1800,6 +2168,24 @@ def _should_unpack_args(args):
return type(args) is tuple # pylint: disable=unidiomatic-typecheck
+def _warn_if_collections(transformation_name):
+ """Prints warning message if the current graph uses common graph collections.
+
+ NOTE(mrry): Currently a warning is only generated for lookup tables. Any
+ variables created will be automatically hoisted out to the outermost scope
+ using `init_scope()`. Some collections (such as for control-flow contexts)
+ are benign and should not generate a warning.
+
+ Args:
+ transformation_name: A human-readable name for the transformation.
+ """
+ if ops.get_default_graph().get_collection(ops.GraphKeys.TABLE_INITIALIZERS):
+ warnings.warn("Creating lookup tables inside a function passed to %s is not"
+ " supported. Create each table outside the function, and "
+ "capture it inside the function to use it."
+ % transformation_name)
+
+
class MapDataset(Dataset):
"""A `Dataset` that maps a function over elements in its input."""
@@ -1808,64 +2194,12 @@ class MapDataset(Dataset):
super(MapDataset, self).__init__()
self._input_dataset = input_dataset
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_map_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- ret = map_func(*nested_args)
- else:
- ret = map_func(nested_args)
-
- # If `map_func` returns a list of tensors, `nest.flatten()` and
- # `ops.convert_to_tensor()` would conspire to attempt to stack
- # those tensors into a single tensor, because the customized
- # version of `nest.flatten()` does not recurse into lists. Since
- # it is more likely that the list arose from returning the
- # result of an operation (such as `tf.py_func()`) that returns a
- # list of not-necessarily-stackable tensors, we treat the
- # returned value is a `tuple` instead. A user wishing to pack
- # the return value into a single tensor can use an explicit
- # `tf.stack()` before returning.
- if isinstance(ret, list):
- ret = tuple(ret)
-
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- ret = nest.pack_sequence_as(ret, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
- for t in nest.flatten(ret)
- ])
-
- self._output_classes = sparse.get_classes(ret)
- self._output_shapes = nest.pack_sequence_as(
- ret, [t.get_shape() for t in nest.flatten(ret)])
- self._output_types = nest.pack_sequence_as(
- ret, [t.dtype for t in nest.flatten(ret)])
-
- # Serialize any sparse tensors.
- ret = nest.pack_sequence_as(
- ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
- return nest.flatten(ret)
-
- self._map_func = tf_map_func
- self._map_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ map_func, "Dataset.map()", input_dataset)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
def _as_variant_tensor(self):
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
@@ -1873,10 +2207,7 @@ class MapDataset(Dataset):
input_t,
self._map_func.captured_inputs,
f=self._map_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1909,10 +2240,7 @@ class ParallelMapDataset(MapDataset):
self._map_func.captured_inputs,
f=self._map_func,
num_parallel_calls=self._num_parallel_calls,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
# pylint: enable=protected-access
@@ -1924,47 +2252,22 @@ class FlatMapDataset(Dataset):
super(FlatMapDataset, self).__init__()
self._input_dataset = input_dataset
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_map_func(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- dataset = map_func(*nested_args)
- else:
- dataset = map_func(nested_args)
-
- if not isinstance(dataset, Dataset):
- raise TypeError("`map_func` must return a `Dataset` object.")
-
- self._output_classes = dataset.output_classes
- self._output_types = dataset.output_types
- self._output_shapes = dataset.output_shapes
-
- return dataset._as_variant_tensor() # pylint: disable=protected-access
-
- self._map_func = tf_map_func
- self._map_func.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ map_func, self._transformation_name(), input_dataset,
+ experimental_nested_dataset_support=True)
+ if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent):
+ raise TypeError("`map_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._map_func = wrapped_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.flat_map_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._map_func.captured_inputs,
f=self._map_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -1978,6 +2281,9 @@ class FlatMapDataset(Dataset):
def output_types(self):
return self._output_types
+ def _transformation_name(self):
+ return "Dataset.flat_map()"
+
class InterleaveDataset(FlatMapDataset):
"""A `Dataset` that maps a function over its input and interleaves the result.
@@ -1998,10 +2304,10 @@ class InterleaveDataset(FlatMapDataset):
self._cycle_length,
self._block_length,
f=self._map_func, # pylint: disable=protected-access
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
+
+ def _transformation_name(self):
+ return "Dataset.interleave()"
class FilterDataset(Dataset):
@@ -2011,46 +2317,20 @@ class FilterDataset(Dataset):
"""See `Dataset.filter()` for details."""
super(FilterDataset, self).__init__()
self._input_dataset = input_dataset
-
- @function.Defun(*nest.flatten(
- sparse.as_dense_types(input_dataset.output_types,
- input_dataset.output_classes)))
- def tf_predicate(*args):
- """A wrapper for Defun that facilitates shape inference."""
- # Pass in shape information from the input_dataset.
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
- input_dataset.output_classes)
- for arg, shape in zip(args, nest.flatten(dense_shapes)):
- arg.set_shape(shape)
-
- nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
- nested_args = sparse.deserialize_sparse_tensors(
- nested_args, input_dataset.output_types, input_dataset.output_shapes,
- input_dataset.output_classes)
- if _should_unpack_args(nested_args):
- ret = predicate(*nested_args)
- else:
- ret = predicate(nested_args)
-
- ret = ops.convert_to_tensor(ret, dtype=dtypes.bool)
- if not (ret.dtype == dtypes.bool and
- ret.shape.is_compatible_with(tensor_shape.scalar())):
- raise ValueError("`predicate` must return a scalar boolean tensor.")
-
- return ret
-
- self._predicate = tf_predicate
- self._predicate.add_to_graph(ops.get_default_graph())
+ wrapped_func = StructuredFunctionWrapper(
+ predicate, "Dataset.filter()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.bool and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError("`predicate` must return a scalar boolean tensor.")
+ self._predicate = wrapped_func.function
def _as_variant_tensor(self):
return gen_dataset_ops.filter_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
other_arguments=self._predicate.captured_inputs,
predicate=self._predicate,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
@@ -2081,10 +2361,7 @@ class PrefetchDataset(Dataset):
return gen_dataset_ops.prefetch_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
buffer_size=self._buffer_size,
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)))
+ **flat_structure(self))
@property
def output_classes(self):
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index b6dba4e3ca..35de2f2841 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import threading
import warnings
+from tensorflow.python.compat import compat
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
@@ -172,13 +173,32 @@ class Iterator(object):
nest.assert_same_structure(output_types, output_shapes)
if shared_name is None:
shared_name = ""
- iterator_resource = gen_dataset_ops.iterator(
- container="",
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(output_types, output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(output_shapes, output_classes)))
+ if compat.forward_compatible(2018, 8, 3):
+ if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ with ops.device("/cpu:0"):
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_v2(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator(
+ container="",
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
return Iterator(iterator_resource, None, output_types, output_shapes,
output_classes)
@@ -242,12 +262,29 @@ class Iterator(object):
output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
nest.assert_same_structure(output_types, output_shapes)
string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
- iterator_resource = gen_dataset_ops.iterator_from_string_handle(
- string_handle,
- output_types=nest.flatten(
- sparse.as_dense_types(output_types, output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(output_shapes, output_classes)))
+ if compat.forward_compatible(2018, 8, 3):
+ if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access
+ with ops.device("/cpu:0"):
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
+ else:
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle(
+ string_handle,
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)))
return Iterator(iterator_resource, None, output_types, output_shapes,
output_classes)
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index a73a8b5cdc..066e09969c 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -19,8 +19,6 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import convert
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -150,12 +148,12 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
self._buffer_output_elements,
self._prefetch_input_elements,
f=self._map_func,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ **dataset_ops.flat_structure(self))
# pylint: enable=protected-access
+ def _transformation_name(self):
+ return "tf.contrib.data.parallel_interleave()"
+
@tf_export("data.TFRecordDataset")
class TFRecordDataset(dataset_ops.Dataset):
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 0fc32d51b9..5fcc62b60b 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -70,6 +70,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
],
)
diff --git a/tensorflow/python/data/util/convert.py b/tensorflow/python/data/util/convert.py
index eeb1d700f3..746b3d66de 100644
--- a/tensorflow/python/data/util/convert.py
+++ b/tensorflow/python/data/util/convert.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
def optional_param_to_tensor(argument_name,
@@ -32,3 +33,40 @@ def optional_param_to_tensor(argument_name,
else:
return constant_op.constant(
argument_default, dtype=argument_dtype, name=argument_name)
+
+
+def partial_shape_to_tensor(shape_like):
+ """Returns a @{tf.Tensor} that represents the given shape.
+
+ Args:
+ shape_like: A value that can be converted to a @{tf.TensorShape} or a
+ @{tf.Tensor}.
+
+ Returns:
+ A 1-D `tf.Tensor` of `tf.int64` elements representing the given shape, where
+ `-1` is substituted for any unknown dimensions.
+ """
+ try:
+ # First attempt to convert the input to a shape, and return the
+ # "canonical" tensor representation, which uses `-1` in place of
+ # `None`.
+ shape_like = tensor_shape.as_shape(shape_like)
+ return ops.convert_to_tensor(
+ [dim if dim is not None else -1 for dim in shape_like.as_list()],
+ dtype=dtypes.int64)
+ except (TypeError, ValueError):
+ # The argument was not trivially convertible to a
+ # `tf.TensorShape`, so fall back on the conversion to tensor
+ # machinery.
+ ret = ops.convert_to_tensor(shape_like, preferred_dtype=dtypes.int64)
+ if ret.shape.dims is not None and len(ret.shape.dims) != 1:
+ raise ValueError("The given shape %s must be a 1-D tensor of tf.int64 "
+ "values, but the shape was %s."
+ % (shape_like, ret.shape))
+ if ret.dtype != dtypes.int64:
+ raise TypeError("The given shape %s must be a 1-D tensor of tf.int64 "
+ "values, but the element type was %s."
+ % (shape_like, ret.dtype.name))
+
+ return ret
+
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 2cb6488070..6a67093e48 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -19,7 +19,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.util import convert
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -48,6 +50,77 @@ class ConvertTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(compat.as_bytes("value"), sess.run(resp))
+ def testPartialShapeToTensorKnownDimension(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([1]))))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1])))
+ self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([1], dtype=dtypes.int64))))
+
+ def testPartialShapeToTensorUnknownDimension(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None]))))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ (None,))))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ [None])))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ [-1])))
+ self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([-1], dtype=dtypes.int64))))
+
+ with self.assertRaisesRegexp(
+ ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
+ r"values, but the shape was \(2, 2\)."):
+ convert.partial_shape_to_tensor(constant_op.constant(
+ [[1, 1], [1, 1]], dtype=dtypes.int64))
+
+ with self.assertRaisesRegexp(
+ TypeError, r"The given shape .* must be a 1-D tensor of tf.int64 "
+ r"values, but the element type was float32."):
+ convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
+
+ def testPartialShapeToTensorMultipleDimensions(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, 6]))))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ (3, 6))))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ [3, 6])))
+ self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([3, 6], dtype=dtypes.int64))))
+
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, None]))))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ (3, None))))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ [3, None])))
+ self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([3, -1], dtype=dtypes.int64))))
+
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None, None]))))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ (None, None))))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ [None, None])))
+ self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([-1, -1], dtype=dtypes.int64))))
+
+ def testPartialShapeToTensorScalar(self):
+ with self.test_session() as sess:
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([]))))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([])))
+ self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
+ constant_op.constant([], dtype=dtypes.int64))))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/util/random_seed_test.py b/tensorflow/python/data/util/random_seed_test.py
index 33227e82af..a809151e6e 100644
--- a/tensorflow/python/data/util/random_seed_test.py
+++ b/tensorflow/python/data/util/random_seed_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class RandomSeedTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRandomSeed(self):
zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero')
one_t = constant_op.constant(1, dtype=dtypes.int64, name='one')
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 2d261f9be7..27b8ebd362 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -167,6 +167,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:platform",
+ "//third_party/py/numpy",
"@six_archive//:six",
],
)
@@ -403,6 +404,7 @@ py_library(
deps = [
":debug_errors",
":debug_fibonacci",
+ ":debug_keras",
":debug_mnist",
":debug_tflearn_iris",
],
@@ -453,6 +455,17 @@ py_binary(
],
)
+py_binary(
+ name = "debug_keras",
+ srcs = ["examples/debug_keras.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_py",
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
py_test(
name = "common_test",
size = "small",
@@ -790,6 +803,7 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
+ tags = ["no_windows_gpu"],
)
py_test(
@@ -802,6 +816,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
@@ -1084,6 +1099,7 @@ py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//third_party/py/numpy",
],
)
@@ -1094,6 +1110,7 @@ sh_test(
data = [
":debug_errors",
":debug_fibonacci",
+ ":debug_keras",
":debug_mnist",
":debug_tflearn_iris",
":offline_analyzer",
diff --git a/tensorflow/python/debug/cli/cli_shared.py b/tensorflow/python/debug/cli/cli_shared.py
index dea019fef5..6a368682de 100644
--- a/tensorflow/python/debug/cli/cli_shared.py
+++ b/tensorflow/python/debug/cli/cli_shared.py
@@ -451,42 +451,48 @@ def get_error_intro(tf_error):
sample commands for debugging.
"""
- op_name = tf_error.op.name
+ if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"):
+ op_name = tf_error.op.name
+ else:
+ op_name = None
intro_lines = [
"--------------------------------------",
RL("!!! An error occurred during the run !!!", "blink"),
"",
- "You may use the following commands to debug:",
]
out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)
- out.extend(
- _recommend_command("ni -a -d -t %s" % op_name,
- "Inspect information about the failing op.",
- create_link=True))
- out.extend(
- _recommend_command("li -r %s" % op_name,
- "List inputs to the failing op, recursively.",
- create_link=True))
-
- out.extend(
- _recommend_command(
- "lt",
- "List all tensors dumped during the failing run() call.",
- create_link=True))
+ if op_name is not None:
+ out.extend(debugger_cli_common.RichTextLines(
+ ["You may use the following commands to debug:"]))
+ out.extend(
+ _recommend_command("ni -a -d -t %s" % op_name,
+ "Inspect information about the failing op.",
+ create_link=True))
+ out.extend(
+ _recommend_command("li -r %s" % op_name,
+ "List inputs to the failing op, recursively.",
+ create_link=True))
+
+ out.extend(
+ _recommend_command(
+ "lt",
+ "List all tensors dumped during the failing run() call.",
+ create_link=True))
+ else:
+ out.extend(debugger_cli_common.RichTextLines([
+ "WARNING: Cannot determine the name of the op that caused the error."]))
more_lines = [
"",
- "Op name: " + op_name,
+ "Op name: %s" % op_name,
"Error type: " + str(type(tf_error)),
"",
"Details:",
str(tf_error),
"",
- "WARNING: Using client GraphDef due to the error, instead of "
- "executor GraphDefs.",
"--------------------------------------",
"",
]
diff --git a/tensorflow/python/debug/cli/cli_shared_test.py b/tensorflow/python/debug/cli/cli_shared_test.py
index 3d7939490d..07b364db9f 100644
--- a/tensorflow/python/debug/cli/cli_shared_test.py
+++ b/tensorflow/python/debug/cli/cli_shared_test.py
@@ -372,6 +372,11 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
self.assertEqual("Details:", error_intro.lines[14])
self.assertStartsWith(error_intro.lines[15], "foo description")
+ def testGetErrorIntroForNoOpName(self):
+ tf_error = errors.OpError(None, None, "Fake OpError", -1)
+ error_intro = cli_shared.get_error_intro(tf_error)
+ self.assertIn("Cannot determine the name of the op", error_intro.lines[3])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/cli/debugger_cli_common.py b/tensorflow/python/debug/cli/debugger_cli_common.py
index 12e79ab07a..02563fde84 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common.py
@@ -23,9 +23,11 @@ import re
import sre_constants
import traceback
+import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.platform import gfile
HELP_INDENT = " "
@@ -131,6 +133,25 @@ def rich_text_lines_from_rich_line_list(rich_text_list, annotations=None):
return RichTextLines(lines, font_attr_segs, annotations=annotations)
+def get_tensorflow_version_lines(include_dependency_versions=False):
+ """Generate RichTextLines with TensorFlow version info.
+
+ Args:
+ include_dependency_versions: Include the version of TensorFlow's key
+ dependencies, such as numpy.
+
+ Returns:
+ A formatted, multi-line `RichTextLines` object.
+ """
+ lines = ["TensorFlow version: %s" % pywrap_tensorflow_internal.__version__]
+ lines.append("")
+ if include_dependency_versions:
+ lines.append("Dependency version(s):")
+ lines.append(" numpy: %s" % np.__version__)
+ lines.append("")
+ return RichTextLines(lines)
+
+
class RichTextLines(object):
"""Rich multi-line text.
@@ -538,6 +559,8 @@ class CommandHandlerRegistry(object):
HELP_COMMAND = "help"
HELP_COMMAND_ALIASES = ["h"]
+ VERSION_COMMAND = "version"
+ VERSION_COMMAND_ALIASES = ["ver"]
def __init__(self):
# A dictionary from command prefix to handler.
@@ -562,6 +585,13 @@ class CommandHandlerRegistry(object):
"Print this help message.",
prefix_aliases=self.HELP_COMMAND_ALIASES)
+ # Register a default handler for the command "version".
+ self.register_command_handler(
+ self.VERSION_COMMAND,
+ self._version_handler,
+ "Print the versions of TensorFlow and its key dependencies.",
+ prefix_aliases=self.VERSION_COMMAND_ALIASES)
+
def register_command_handler(self,
prefix,
handler,
@@ -763,6 +793,11 @@ class CommandHandlerRegistry(object):
else:
return RichTextLines(["ERROR: help takes only 0 or 1 input argument."])
+ def _version_handler(self, args, screen_info=None):
+ del args # Unused currently.
+ del screen_info # Unused currently.
+ return get_tensorflow_version_lines(include_dependency_versions=True)
+
def _resolve_prefix(self, token):
"""Resolve command prefix from the prefix itself or its alias.
diff --git a/tensorflow/python/debug/cli/debugger_cli_common_test.py b/tensorflow/python/debug/cli/debugger_cli_common_test.py
index 1b7a5962fe..aba95e5820 100644
--- a/tensorflow/python/debug/cli/debugger_cli_common_test.py
+++ b/tensorflow/python/debug/cli/debugger_cli_common_test.py
@@ -21,6 +21,9 @@ import os
import stat
import tempfile
+import numpy as np
+
+from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile
@@ -547,7 +550,10 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
" Show screen width in number of columns.", "", "",
"help", " Aliases: h", "", " Print this help message.",
"", "", "noop", " Aliases: n, NOOP", "",
- " No operation.", " I.e., do nothing.", "", ""],
+ " No operation.", " I.e., do nothing.", "", "",
+ "version", " Aliases: ver", "",
+ " Print the versions of TensorFlow and its key "
+ "dependencies.", "", ""],
output.lines)
# Get help for one specific command prefix.
@@ -575,7 +581,9 @@ class CommandHandlerRegistryTest(test_util.TensorFlowTestCase):
self.assertEqual(help_intro.lines + [
"help", " Aliases: h", "", " Print this help message.", "", "",
"noop", " Aliases: n, NOOP", "", " No operation.",
- " I.e., do nothing.", "", ""
+ " I.e., do nothing.", "", "",
+ "version", " Aliases: ver", "",
+ " Print the versions of TensorFlow and its key dependencies.", "", ""
], output.lines)
@@ -1147,5 +1155,22 @@ class MenuTest(test_util.TensorFlowTestCase):
self.assertEqual((40, 50, ["bold"]), output.font_attr_segs[0][2])
+class GetTensorFlowVersionLinesTest(test_util.TensorFlowTestCase):
+
+ def testGetVersionWithoutDependencies(self):
+ out = debugger_cli_common.get_tensorflow_version_lines()
+ self.assertEqual(2, len(out.lines))
+ self.assertEqual(
+ "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
+ out.lines[0])
+
+ def testGetVersionWithDependencies(self):
+ out = debugger_cli_common.get_tensorflow_version_lines(True)
+ self.assertIn(
+ "TensorFlow version: %s" % pywrap_tensorflow_internal.__version__,
+ out.lines)
+ self.assertIn(" numpy: %s" % np.__version__, out.lines)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/debug/examples/debug_keras.py b/tensorflow/python/debug/examples/debug_keras.py
new file mode 100644
index 0000000000..3272d85ade
--- /dev/null
+++ b/tensorflow/python/debug/examples/debug_keras.py
@@ -0,0 +1,89 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tfdbg example: debugging tf.keras models training on tf.data.Dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python import debug as tf_debug
+
+
+def main(_):
+ # Create a dummy dataset.
+ num_examples = 8
+ steps_per_epoch = 2
+ input_dims = 3
+ output_dims = 1
+ xs = np.zeros([num_examples, input_dims])
+ ys = np.zeros([num_examples, output_dims])
+ dataset = tf.data.Dataset.from_tensor_slices(
+ (xs, ys)).repeat(num_examples).batch(int(num_examples / steps_per_epoch))
+
+ sess = tf.Session()
+ if FLAGS.debug:
+ # Use the command-line interface (CLI) of tfdbg.
+ sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type=FLAGS.ui_type)
+ elif FLAGS.tensorboard_debug_address:
+ # Use the TensorBoard Debugger Plugin (GUI of tfdbg).
+ sess = tf_debug.TensorBoardDebugWrapperSession(
+ sess, FLAGS.tensorboard_debug_address)
+ tf.keras.backend.set_session(sess)
+
+ # Create a dummy model.
+ model = tf.keras.Sequential([
+ tf.keras.layers.Dense(1, input_shape=[input_dims])])
+ model.compile(loss="mse", optimizer="sgd")
+
+ # Train the model using the dummy dataset created above.
+ model.fit(dataset, epochs=FLAGS.epochs, steps_per_epoch=steps_per_epoch)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.register("type", "bool", lambda v: v.lower() == "true")
+ parser.add_argument(
+ "--debug",
+ type="bool",
+ nargs="?",
+ const=True,
+ default=False,
+ help="Use debugger to track down bad values during training. "
+ "Mutually exclusive with the --tensorboard_debug_address flag.")
+ parser.add_argument(
+ "--ui_type",
+ type=str,
+ default="curses",
+ help="Command-line user interface type (curses | readline).")
+ parser.add_argument(
+ "--tensorboard_debug_address",
+ type=str,
+ default=None,
+ help="Connect to the TensorBoard Debugger Plugin backend specified by "
+ "the gRPC address (e.g., localhost:1234). Mutually exclusive with the "
+ "--debug flag.")
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=2,
+ help="Number of epochs to train the model for.")
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index 2df6c0b6a2..2d35b2d8bb 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -48,12 +48,14 @@ if [[ -z "${PYTHON_BIN_PATH}" ]]; then
DEBUG_ERRORS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_errors"
DEBUG_MNIST_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_tflearn_iris"
+ DEBUG_KERAS_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/debug_keras"
OFFLINE_ANALYZER_BIN="$TEST_SRCDIR/org_tensorflow/tensorflow/python/debug/offline_analyzer"
else
DEBUG_FIBONACCI_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_fibonacci"
DEBUG_ERRORS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_errors"
DEBUG_MNIST_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_mnist"
DEBUG_TFLEARN_IRIS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_tflearn_iris"
+ DEBUG_KERAS_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.examples.debug_keras"
OFFLINE_ANALYZER_BIN="${PYTHON_BIN_PATH} -m tensorflow.python.debug.cli.offline_analyzer"
fi
@@ -69,6 +71,12 @@ run
exit
EOF
+cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline
+run
+ni -a -d -t v/read
+exit
+EOF
+
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
run -t 1
run --node_name_filter hidden --op_type_filter MatMul
@@ -90,6 +98,11 @@ if [[ -d "${CUSTOM_DUMP_ROOT}" ]]; then
exit 1
fi
+# Test debugging of tf.keras.
+cat << EOF | "${DEBUG_KERAS_BIN}" --debug --ui_type=readline
+run -f has_inf_or_nan
+EOF
+
# Test offline_analyzer.
echo
echo "Testing offline_analyzer"
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index 8a65ad087b..7c96c2878c 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -748,7 +748,7 @@ class DebugDumpDir(object):
return sum(len(self._dump_tensor_data[device_name])
for device_name in self._dump_tensor_data)
- def _load_partition_graphs(self, partition_graphs, validate):
+ def _load_partition_graphs(self, client_partition_graphs, validate):
"""Load and process partition graphs.
Load the graphs; parse the input and control input structure; obtain the
@@ -757,8 +757,10 @@ class DebugDumpDir(object):
tensor dumps.
Args:
- partition_graphs: A repeated field of GraphDefs representing the
- partition graphs executed by the TensorFlow runtime.
+ client_partition_graphs: A repeated field of GraphDefs representing the
+ partition graphs executed by the TensorFlow runtime, from the Python
+ client. These partition graphs are used only if partition graphs
+ cannot be loaded from the dump directory on the file system.
validate: (`bool`) Whether the dump files are to be validated against the
partition graphs.
@@ -769,24 +771,23 @@ class DebugDumpDir(object):
self._debug_graphs = {}
self._node_devices = {}
- if partition_graphs:
- partition_graphs_and_device_names = [
- (partition_graph, None) for partition_graph in partition_graphs]
- else:
- partition_graphs_and_device_names = []
- for device_name in self._device_names:
- partition_graph = None
- if device_name in self._dump_graph_file_paths:
- partition_graph = _load_graph_def_from_event_file(
- self._dump_graph_file_paths[device_name])
- else:
- partition_graph = self._find_partition_graph(partition_graphs,
- device_name)
- if partition_graph:
- partition_graphs_and_device_names.append((partition_graph,
- device_name))
- else:
- logging.warn("Failed to load partition graphs from disk.")
+ partition_graphs_and_device_names = []
+ for device_name in self._device_names:
+ partition_graph = None
+ if device_name in self._dump_graph_file_paths:
+ partition_graph = _load_graph_def_from_event_file(
+ self._dump_graph_file_paths[device_name])
+ else:
+ logging.warn(
+ "Failed to load partition graphs for device %s from disk. "
+ "As a fallback, the client graphs will be used. This "
+ "may cause mismatches in device names." % device_name)
+ partition_graph = self._find_partition_graph(client_partition_graphs,
+ device_name)
+
+ if partition_graph:
+ partition_graphs_and_device_names.append((partition_graph,
+ device_name))
for partition_graph, maybe_device_name in partition_graphs_and_device_names:
debug_graph = debug_graphs.DebugGraph(partition_graph,
diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
index bd00f73861..676097fde9 100644
--- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
+++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
@@ -44,7 +44,8 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index c530204bbf..b9524ce649 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
self._default_session_context_manager = None
+ # A cache for callables created from CallableOptions.
+ self._cached_callables_from_options = dict()
+
@property
def graph(self):
return self._sess.graph
@@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
"""Wrapper around Session.run() that inserts tensor watch options.
Args:
@@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
callable_runner: A `callable` returned by `Session.make_callable()`.
If not `None`, `fetches` and `feed_dict` must both be `None`.
- callable_runner_args: An optional list of arguments to `callable_runner`.
+ Mutually exclusive with `callable_options`.
+ callable_runner_args: An optional list of arguments to `callable_runner`
+ or for `callable_options`.
+ callable_options: An instance of `config_pb2.CallableOptions`, to be
+ used with `Session._make_callable_from_options()`. Mutually exclusive
+ with `callable_runner`.
Returns:
Simply forwards the output of the wrapped `Session.run()` call.
@@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface):
ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
is not `None` and either or both of `fetches` and `feed_dict` is `None`.
"""
- if not callable_runner:
+ if callable_runner and callable_options:
+ raise ValueError(
+ "callable_runner and callable_options are mutually exclusive, but "
+ "are both specified in this call to BaseDebugWrapperSession.run().")
+
+ if not (callable_runner or callable_options):
self.increment_run_call_count()
- else:
- if fetches or feed_dict:
- raise ValueError(
- "callable_runner and fetches/feed_dict are mutually exclusive, but "
- "are used simultaneously.")
+ elif callable_runner and (fetches or feed_dict):
+ raise ValueError(
+ "callable_runner and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
@@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
+ elif callable_options:
+ # pylint:disable=protected-access
+ return self._sess._make_callable_from_options(
+ callable_options)(*callable_runner_args)
+ # pylint:enable=protected-access
else:
return self._sess.run(fetches,
feed_dict=feed_dict,
@@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface):
if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
# Decorate RunOption to fill in debugger tensor watch specifications.
- decorated_run_options = options or config_pb2.RunOptions()
+ decorated_run_options = None
+ if callable_options:
+ callable_options_id = id(callable_options)
+ if callable_options_id not in self._cached_callables_from_options:
+ # Make a copy of callable_options to avoid mutating it.
+ new_callable_options = config_pb2.CallableOptions()
+ new_callable_options.CopyFrom(callable_options)
+ decorated_run_options = new_callable_options.run_options
+ else:
+ decorated_run_options = options or config_pb2.RunOptions()
+
run_metadata = run_metadata or config_pb2.RunMetadata()
- self._decorate_run_options_for_debug(
- decorated_run_options,
- run_start_resp.debug_urls,
- debug_ops=run_start_resp.debug_ops,
- node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
- op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
- tensor_dtype_regex_whitelist=(
- run_start_resp.tensor_dtype_regex_whitelist),
- tolerate_debug_op_creation_failures=(
- run_start_resp.tolerate_debug_op_creation_failures))
+ if decorated_run_options:
+ self._decorate_run_options_for_debug(
+ decorated_run_options,
+ run_start_resp.debug_urls,
+ debug_ops=run_start_resp.debug_ops,
+ node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
+ op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=(
+ run_start_resp.tensor_dtype_regex_whitelist),
+ tolerate_debug_op_creation_failures=(
+ run_start_resp.tolerate_debug_op_creation_failures))
# Invoke the run() method of the wrapped Session. Catch any TensorFlow
# runtime errors.
@@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface):
retvals = callable_runner(*callable_runner_args,
options=decorated_run_options,
run_metadata=run_metadata)
+ elif callable_options:
+ # pylint:disable=protected-access
+ if callable_options_id in self._cached_callables_from_options:
+ callable_object = self._cached_callables_from_options[
+ callable_options_id]
+ else:
+ callable_object = self._sess._make_callable_from_options(
+ new_callable_options)
+ self._cached_callables_from_options[
+ callable_options_id] = callable_object
+ # pylint:enable=protected-access
+ retvals = callable_object(
+ *callable_runner_args, run_metadata=run_metadata)
else:
retvals = self._sess.run(fetches,
feed_dict=feed_dict,
@@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata=kwargs.get("run_metadata", None),
callable_runner=runner,
callable_runner_args=runner_args)
+ return wrapped_runner
+ def _make_callable_from_options(self, callable_options):
+ def wrapped_runner(*feed_values, **kwargs):
+ return self.run(None,
+ run_metadata=kwargs.get("run_metadata", None),
+ callable_options=callable_options,
+ callable_runner_args=feed_values)
return wrapped_runner
@property
diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py
index 1f9c8fa5a9..85944fa611 100644
--- a/tensorflow/python/debug/wrappers/grpc_wrapper.py
+++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py
@@ -215,7 +215,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
if self._send_traceback_and_source_code:
self._sent_graph_version = publish_traceback(
self._grpc_debug_server_urls, self.graph, feed_dict, fetches,
@@ -226,4 +227,5 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession):
options=options,
run_metadata=run_metadata,
callable_runner=callable_runner,
- callable_runner_args=callable_runner_args)
+ callable_runner_args=callable_runner_args,
+ callable_options=callable_options)
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
index c8625655e5..668ffb57f1 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py
@@ -290,6 +290,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
if self._run_call_count == 1:
# Show logo at the onset of the first run.
help_intro.extend(cli_shared.get_tfdbg_logo())
+ help_intro.extend(debugger_cli_common.get_tensorflow_version_lines())
help_intro.extend(debugger_cli_common.RichTextLines("Upcoming run:"))
help_intro.extend(self._run_info)
@@ -466,6 +467,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
if self._run_call_count == 1:
output.extend(cli_shared.get_tfdbg_logo())
+ output.extend(debugger_cli_common.get_tensorflow_version_lines())
output.extend(self._run_info)
if (not self._is_run_start and
@@ -594,7 +596,7 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession):
# Register tab completion for the filter names.
curses_cli.register_tab_comp_context(["run", "r"],
list(self._tensor_filters.keys()))
- if self._feed_dict:
+ if self._feed_dict and hasattr(self._feed_dict, "keys"):
# Register tab completion for feed_dict keys.
feed_keys = [common.get_graph_element_name(key)
for key in self._feed_dict.keys()]
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index b06fa26a93..05c9eaa4d2 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -21,7 +21,10 @@ import os
import shutil
import tempfile
+import numpy as np
+
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
@@ -149,7 +152,13 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
- self.sess = session.Session()
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ config_proto = config_pb2.ConfigProto(graph_options=graph_options)
+ self.sess = session.Session(config=config_proto)
# Initialize variable.
self.sess.run(variables.global_variables_initializer())
@@ -393,6 +402,113 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertAllClose(42.0, tensor_runner(41.0, 1.0))
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
+ def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ for _ in range(2):
+ callable_output = sess_callable()
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithOneFeedWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ a = math_ops.add(ph1, ph1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array([10.5, -10.5], dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value)
+ self.assertAllClose(
+ np.array([42.0, -42.0], dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
+ ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
+ ph2 = array_ops.placeholder(dtypes.float32, name="callable_ph2")
+ a = math_ops.add(ph1, ph2, "callable_a")
+ math_ops.add(a, a, "callable_b")
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"]] * 3, self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.feed.append("callable_ph1")
+ callable_options.feed.append("callable_ph2")
+ callable_options.fetch.append("callable_b")
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ ph1_value = np.array(5.0, dtype=np.float32)
+ ph2_value = np.array(16.0, dtype=np.float32)
+
+ for _ in range(2):
+ callable_output = sess_callable(ph1_value, ph2_value)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(2, len(debug_dumps))
+ for debug_dump in debug_dumps:
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(["callable_a", "callable_b"], node_names)
+
+ def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
+ variable_1 = variables.Variable(
+ 10.5, dtype=dtypes.float32, name="variable_1")
+ a = math_ops.add(variable_1, variable_1, "callable_a")
+ math_ops.add(a, a, "callable_b")
+ self.sess.run(variable_1.initializer)
+
+ wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
+ [["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
+ callable_options = config_pb2.CallableOptions()
+ callable_options.fetch.append("callable_b")
+ callable_options.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
+
+ sess_callable = wrapped_sess._make_callable_from_options(callable_options)
+
+ run_metadata = config_pb2.RunMetadata()
+ # Call the callable with a custom run_metadata.
+ callable_output = sess_callable(run_metadata=run_metadata)
+ # Verify that step_stats is populated in the custom run_metadata.
+ self.assertTrue(run_metadata.step_stats)
+ self.assertAllClose(np.array(42.0, dtype=np.float32), callable_output[0])
+
+ debug_dumps = wrapped_sess.observers["debug_dumps"]
+ self.assertEqual(1, len(debug_dumps))
+ debug_dump = debug_dumps[0]
+ node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
+ self.assertItemsEqual(
+ ["callable_a", "callable_b", "variable_1", "variable_1/read"],
+ node_names)
+
def testRuntimeErrorShouldBeCaught(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index dee86966f1..6ede8e4f4d 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -32,6 +32,7 @@ cc_library(
"//tensorflow/python:numpy_lib",
"//tensorflow/python:py_seq_tensor",
"//tensorflow/python:safe_ptr",
+ "//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
)
@@ -391,3 +392,20 @@ py_library(
srcs = ["imperative_grad.py"],
srcs_version = "PY2AND3",
)
+
+cuda_py_test(
+ name = "memory_test",
+ size = "medium",
+ srcs = ["memory_test.py"],
+ additional_deps = [
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/keras",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "optonly", # The test is too slow in non-opt mode
+ ],
+)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index b2e6c60021..9e0bbce4a1 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -196,11 +196,11 @@ def implicit_val_and_grad(f):
# TODO(cais): Remove calls to tf.constant() once the gradients functions
# accept lists and np.ndarrays.
- def grad_fn(*args):
+ def grad_fn(*args, **kwds):
"""Computes the gradient of the wrapped function."""
this_tape = tape.push_new_tape()
try:
- end_node = f(*args)
+ end_node = f(*args, **kwds)
if end_node is None:
raise ValueError("Cannot differentiate a function that returns None; "
"did you forget to return a value from {}?".format(
@@ -605,7 +605,9 @@ def _zeros(shape, dtype):
# TODO(apassos): need to save enough information about variant tensors to do
# a zeros
return None
- cache_key = shape, dtype, device
+ # pylint: disable=protected-access
+ cache_key = shape, dtype, device, context.context()._eager_context.mode
+ # pylint: enable=protected-access
cached = _zeros_cache.get(cache_key)
if cached is None:
cached = _fast_fill(0, shape, dtype)
@@ -711,10 +713,15 @@ class GradientTape(object):
if self._recording:
self._pop_tape()
- def _push_tape(self):
+ def _push_tape(self, existing_tape=False):
if self._recording:
raise ValueError("Tape is already recording.")
- self._tape = tape.push_new_tape(persistent=self._persistent)
+ if existing_tape:
+ if self._tape is None:
+ raise ValueError("There is no existing tape.")
+ tape.push_tape(self._tape)
+ else:
+ self._tape = tape.push_new_tape(persistent=self._persistent)
self._recording = True
def _pop_tape(self):
@@ -762,7 +769,7 @@ class GradientTape(object):
try:
yield
finally:
- self._push_tape()
+ self._push_tape(existing_tape=True)
def reset(self):
"""Clears all information stored in this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 826c6683b9..bdda200ff6 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -46,7 +46,7 @@ from tensorflow.python.training import training
class BackpropTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAggregateGradients(self):
def fn(x):
@@ -223,11 +223,23 @@ class BackpropTest(test.TestCase):
def testTapeStopRecording(self):
with backprop.GradientTape() as t:
- x = constant_op.constant(1.0)
+ x = resource_variable_ops.ResourceVariable(1.0)
with t.stop_recording():
y = x * x
self.assertEqual(t.gradient(y, x), None)
+ def testTapeStopStartRecording(self):
+ with backprop.GradientTape(persistent=True) as t:
+ x = resource_variable_ops.ResourceVariable(1.0)
+ x2 = x * 2 # This should be differentiated through.
+ with t.stop_recording():
+ y = x2 * x2
+ z = x2 * x2
+ self.assertEqual(t.gradient(y, x2), None)
+
+ # If the x*2 was not differentiated through, this would be 2.0, not 4.0
+ self.assertEqual(t.gradient(z, x2).numpy(), 4.0)
+
def testTapeReset(self):
with backprop.GradientTape() as t:
v = resource_variable_ops.ResourceVariable(1.0)
@@ -251,7 +263,7 @@ class BackpropTest(test.TestCase):
g, = backprop.gradients_function(loss, [0])(logits, labels)
self.assertAllEqual(g.numpy(), [[-0.5, 0.5]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientWithinTapeBlock(self):
v1 = resource_variable_ops.ResourceVariable(1.)
self.evaluate(v1.initializer)
@@ -265,7 +277,7 @@ class BackpropTest(test.TestCase):
grad = t.gradient(loss, v1)
self.assertAllEqual(self.evaluate(grad), 2.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestedSelfContexts(self):
v1 = resource_variable_ops.ResourceVariable(1.)
self.evaluate(v1.initializer)
@@ -435,7 +447,7 @@ class BackpropTest(test.TestCase):
self.assertEqual(backprop.implicit_grad(f)()[0][0], None)
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=False) as g:
x = constant_op.constant(3.0)
@@ -445,7 +457,7 @@ class BackpropTest(test.TestCase):
self.assertEqual(self.evaluate(grad), [2.0, 2.0])
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPersistentGradientTapeRepeatedSource(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
@@ -459,7 +471,7 @@ class BackpropTest(test.TestCase):
self.assertEqual(self.evaluate(grad), [3.0, 11.0])
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTapeStructure(self):
with backprop.GradientTape(persistent=True) as g:
# Using different constant values because constant tensors are
@@ -482,7 +494,7 @@ class BackpropTest(test.TestCase):
[1.0, {'x2': 2.0, 'x3': 3.0}])
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTape(self):
with backprop.GradientTape() as g:
x = constant_op.constant(3.0)
@@ -497,7 +509,7 @@ class BackpropTest(test.TestCase):
grad = g.gradient(y, [x])[0]
self.assertEqual(self.evaluate(grad), 6.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTapeWithCond(self):
x = constant_op.constant(3.0)
@@ -518,7 +530,7 @@ class BackpropTest(test.TestCase):
dy = g.gradient(y, [x])[0]
self.assertEqual(self.evaluate(dy), 6.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTapeWithWhileLoop(self):
i = constant_op.constant(1)
x = constant_op.constant(2.)
@@ -553,7 +565,7 @@ class BackpropTest(test.TestCase):
g.gradient(y, [x])
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPersistentTape(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
@@ -567,7 +579,7 @@ class BackpropTest(test.TestCase):
del g
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testHigherOrderGradient(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
@@ -584,7 +596,7 @@ class BackpropTest(test.TestCase):
del g
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPersistentNestedTape(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant(3.0)
@@ -605,7 +617,7 @@ class BackpropTest(test.TestCase):
del g
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientTapeVariable(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
self.evaluate(v.initializer)
@@ -615,7 +627,7 @@ class BackpropTest(test.TestCase):
self.assertAllEqual(self.evaluate(grad), 2.0)
@test_util.assert_no_new_tensors
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNestedGradients(self):
x = constant_op.constant(3.0)
with backprop.GradientTape() as g:
@@ -900,6 +912,33 @@ class BackpropTest(test.TestCase):
'did you forget to return a value from fn?'):
val_and_grads_fn(x, y)
+ def testZerosCacheDoesntLeakAcrossModes(self):
+ with ops.Graph().as_default():
+ t = random_ops.random_normal(shape=[100, 2])
+ x = random_ops.random_normal(shape=[100, 4])
+ dy = random_ops.random_normal(shape=[100, 4])
+ with backprop.GradientTape() as gradient_tape:
+ gradient_tape.watch(x)
+ x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+ y1 = x1 ** 2.
+ y = array_ops.concat([y1, t], axis=1)
+
+ dx = gradient_tape.gradient(y, x, output_gradients=dy)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(dx)
+
+ t = random_ops.random_normal(shape=[100, 2])
+ x = random_ops.random_normal(shape=[100, 4])
+ dy = random_ops.random_normal(shape=[100, 4])
+ with backprop.GradientTape() as gradient_tape:
+ gradient_tape.watch(x)
+ x1, _ = array_ops.split(x, num_or_size_splits=2, axis=1)
+ y1 = x1 ** 2.
+ y = array_ops.concat([y1, t], axis=1)
+
+ dx = gradient_tape.gradient(y, x, output_gradients=dy)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 9e146f021e..85b9491903 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -143,7 +143,11 @@ class Context(object):
# TODO(agarwal): create and link in some documentation for `execution_mode`.
# pylint: disable=redefined-outer-name
- def __init__(self, config=None, device_policy=None, execution_mode=None):
+ def __init__(self,
+ config=None,
+ device_policy=None,
+ execution_mode=None,
+ server_def=None):
"""Creates a new Context.
Args:
@@ -192,6 +196,7 @@ class Context(object):
if execution_mode is None:
execution_mode = SYNC
self._execution_mode = execution_mode
+ self._server_def = server_def
# pylint: enable=redefined-outer-name
@@ -231,6 +236,9 @@ class Context(object):
opts, self._device_policy)
if self._execution_mode == ASYNC:
pywrap_tensorflow.TFE_ContextOptionsSetAsync(opts, True)
+ if self._server_def is not None:
+ server_def_str = self._server_def.SerializeToString()
+ pywrap_tensorflow.TFE_ContextOptionsSetServerDef(opts, server_def_str)
self._context_handle = pywrap_tensorflow.TFE_NewContext(opts)
finally:
pywrap_tensorflow.TFE_DeleteContextOptions(opts)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 559063d6ae..df83d673ad 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
import numpy as np
@@ -35,6 +36,7 @@ from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
@@ -46,8 +48,11 @@ def capture_value(tensor_map, value, dtype, name):
"""Capture a value from outside the function, to pass in as an extra arg."""
captured_value = tensor_map.get(ops.tensor_id(value), None)
if captured_value is None:
- captured_value = graph_placeholder(
- dtype=dtype or value.dtype, shape=value.shape, name=name)
+ # Note: setting ops.control_dependencies(None) ensures we always put
+ # capturing placeholders outside of any control flow context.
+ with ops.control_dependencies(None):
+ captured_value = graph_placeholder(
+ dtype=dtype or value.dtype, shape=value.shape, name=name)
if captured_value.dtype == dtypes_module.resource:
if ops._USE_C_SHAPES: # pylint: disable=protected-access
if isinstance(value, ops.EagerTensor):
@@ -222,11 +227,25 @@ def _inference_name(n):
return "__inference_%s_%s" % (n, ops.uid())
+def _register(fn):
+ """Registers the function `fn`."""
+ context.context().add_function(fn)
+
+
+_xla_compile_attr = "_XlaCompile"
+
+
# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
# so it doesn't have the definition-generating logic and is just a container for
# an already-defined function.
class _EagerDefinedFunction(object):
- """Function object with the interface of tf _DefinedFunction."""
+ """Callable with the interface of `framework.function._DefinedFunction.`
+
+ `_EagerDefinedFunction` encapsulates a function definition and its properties,
+ and it provides a method for calling the encapsulated function. Some Ops
+ take functions as attributes, which have type `func`; an instance of this
+ class may be provided as the value of these `func` attributes.
+ """
def __init__(self, name, graph, operations, inputs, outputs, attrs):
"""Initializes an eager defined function.
@@ -257,6 +276,7 @@ class _EagerDefinedFunction(object):
# It might be worth creating a convenient way to re-use status.
pywrap_tensorflow.TF_FunctionSetAttrValueProto(
fn, compat.as_str(name), serialized)
+ self._xla_compile = _xla_compile_attr in attrs
# TODO(apassos) avoid creating a FunctionDef (specially to grab the
# signature, but also in general it's nice not to depend on it.
@@ -268,12 +288,92 @@ class _EagerDefinedFunction(object):
if context.executing_eagerly():
_register(fn)
self.definition = function_def
- self.name = function_def.signature.name
+ self.name = compat.as_bytes(function_def.signature.name)
self.signature = function_def.signature
+ self._num_outputs = len(self.signature.output_arg)
+ self._output_types = [o.type for o in self.signature.output_arg]
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
self._grad_func = None
+ self._graph = graph
+ self._stateful_ops = tuple(op for op in operations if op.op_def.is_stateful)
+
+ def add_to_graph(self, g):
+ # pylint: disable=protected-access
+ if self.name not in g._functions:
+ g._add_function(self)
+ for f in self._graph._functions.values():
+ if f.name not in g._functions:
+ g._add_function(f)
+ # pylint: enable=protected-access
+
+ @property
+ def stateful_ops(self):
+ return self._stateful_ops
+
+ def call(self, ctx, args, output_shapes):
+ """Calls this function with `args` as inputs.
+
+ Function execution respects device annotations only if the function won't
+ be compiled with xla.
+
+ Args:
+ ctx: a Context object
+ args: a list of arguments to supply this function with.
+ output_shapes: shapes to which outputs should be set; ignored when
+ executing eagerly.
+
+ Returns:
+ The outputs of the function call.
+ """
+
+ executing_eagerly = ctx.executing_eagerly()
+
+ xla_compile = self._xla_compile or (executing_eagerly and
+ ctx.device_spec.device_type == "TPU")
+
+ if xla_compile:
+ # XLA compilation relies upon a custom kernel creator to run functions.
+ signature = self.signature
+ if executing_eagerly:
+ outputs = execute.execute(
+ str(signature.name),
+ num_outputs=self._num_outputs,
+ inputs=args,
+ attrs=None,
+ ctx=ctx)
+ else:
+ g = ops.get_default_graph()
+ self.add_to_graph(g)
+ op = g.create_op(
+ signature.name,
+ [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
+ op_def=signature,
+ name="FunctionCall",
+ compute_shapes=False)
+ outputs = op.outputs
+ if not outputs:
+ return op
+ outputs = [outputs] if isinstance(
+ outputs, (ops.Tensor, type(None))) else list(outputs)
+ else:
+ # TODO(akshayka): Either remove this if the FunctionLibraryRuntime
+ # creates `PartitionedCallOp` kernels by default, or remove the previous
+ # branch if a TPU kernel is registered for `PartitionedCall`.
+ outputs = functional_ops.partitioned_call(
+ args=args,
+ f=self,
+ tout=self._output_types,
+ executing_eagerly=executing_eagerly)
+
+ if executing_eagerly:
+ return outputs
+ else:
+ for i, shape in enumerate(output_shapes):
+ outputs[i].set_shape(shape)
+ return outputs
def _map_sequence_obj_to_idx(sequence):
@@ -297,8 +397,12 @@ def _flatten(sequence):
return outputs
+# TODO(akshayka): Perhaps rename to something more appropriate.
class GraphModeFunction(object):
- """Callable object representing a graph-mode function.
+ """Callable object encapsulating a function definition and its gradient.
+
+ `GraphModeFunction` is a callable that encapsulates a function definition and
+ is differentiable under `tf.GradientTape` objects.
"""
def __init__(self,
@@ -308,7 +412,7 @@ class GraphModeFunction(object):
graph,
operations,
outputs,
- func_outputs,
+ python_func_outputs,
output_shapes,
variables=None,
attrs=None):
@@ -327,9 +431,10 @@ class GraphModeFunction(object):
definition.
outputs: a flat list of the Tensors in the graph used as outputs to the
function
- func_outputs: a possibly nested python object which will be returned by
- this function. The Tensors in this structure will be replaced by their
- corresponding values in outputs.
+ python_func_outputs: a possibly nested python object which will be
+ returned by this function. The Tensors in this structure will be
+ replaced by their corresponding values in outputs. Note that this
+ structure might contain Python `None`s.
output_shapes: List of shapes of all tensors in outputs
variables: (optional) List of variables to watch during function
execution.
@@ -351,9 +456,10 @@ class GraphModeFunction(object):
self._function_def = defined_function
self._num_outputs = len(defined_function.signature.output_arg)
self._ops = operations
- self._func_outputs = func_outputs
- self._returns = [func_outputs] if isinstance(
- func_outputs, (ops.Tensor, type(None))) else _flatten(func_outputs)
+ self._python_func_outputs = python_func_outputs
+ self._python_returns = [python_func_outputs] if isinstance(
+ python_func_outputs,
+ (ops.Tensor, type(None))) else _flatten(python_func_outputs)
self._output_shapes = output_shapes
self._variables = variables if variables is not None else []
@@ -368,7 +474,7 @@ class GraphModeFunction(object):
c_captured_tensors = set()
existing_op_len = len(self._graph.get_operations())
- filtered_outputs = [x for x in self._returns if x is not None]
+ filtered_outputs = [x for x in self._python_returns if x is not None]
self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
in_gradients = gradients_impl.gradients(
@@ -377,7 +483,7 @@ class GraphModeFunction(object):
grad_ys=self._out_grad_placeholders)
for op in self._graph.get_operations()[existing_op_len:]:
if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
- raise ValueError("tfe.defun cannot capture variables created without "
+ raise ValueError("defun cannot capture variables created without "
"using tf.get_variable. Op: %s" % op)
c_known_ops.add(op)
for i in op.inputs:
@@ -409,40 +515,32 @@ class GraphModeFunction(object):
backward_outputs, in_gradients, output_shapes, attrs=self._attrs)
def _backprop_call(self, args):
- """Calls the wrapped function and records the result on a tape."""
+ """Calls the wrapped function and records the result on a tape.
+
+ (Only records results on a tape if the function has outputs)
+
+ Args:
+ args: The tensor inputs to the function.
+ Returns:
+ The call output.
+ """
all_args = args + self._extra_inputs
- signature = self._forward_fdef.signature
ctx = context.context()
- if ctx.executing_eagerly():
- outputs = execute.execute(
- str(signature.name),
- num_outputs=len(signature.output_arg),
- inputs=all_args,
- attrs=None,
- ctx=ctx)
- else:
- g = ops.get_default_graph()
- g._add_function(self._forward_fdef) # pylint: disable=protected-access
- op = g.create_op(
- signature.name,
- [ops.internal_convert_to_tensor(x, ctx=ctx) for x in all_args],
- tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
- op_def=signature,
- name="FunctionCall",
- compute_shapes=False)
- outputs = op.outputs
- outputs = [outputs] if isinstance(
- outputs, (ops.Tensor, type(None))) else list(outputs)
- for i, s in enumerate(self._output_shapes):
- outputs[i].set_shape(s)
- real_outputs = outputs[:len(self._returns)]
- side_outputs = outputs[len(self._returns):]
+ outputs = self._forward_fdef.call(ctx, all_args, self._output_shapes)
+ if isinstance(outputs, ops.Operation) or outputs is None:
+ return outputs
+
+ # `real_outputs` are the actual outputs of the inference graph function;
+ # `side_outputs` are the intermediate Tensors that were added as outputs to
+ # the forward graph function so that we can compute its gradient.
+ real_outputs = outputs[:self._num_outputs]
+ side_outputs = outputs[self._num_outputs:]
def backward_function(*args):
return self._backward_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
tape.record_operation(
- signature.name,
+ self._forward_fdef.signature.name,
real_outputs,
(args + self._extra_inputs),
backward_function)
@@ -453,8 +551,8 @@ class GraphModeFunction(object):
def output_shapes(self):
"""The function's output shapes."""
# TODO(ebrevdo): Should we only keep the output shapes associated
- # with len(self._returns) outputs?
- outputs_list = nest.flatten(self._func_outputs)
+ # with len(self._python_returns) outputs?
+ outputs_list = nest.flatten(self._python_func_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -468,12 +566,12 @@ class GraphModeFunction(object):
else:
outputs_list[i] = self._output_shapes[j]
j += 1
- return nest.pack_sequence_as(self._func_outputs, outputs_list)
+ return nest.pack_sequence_as(self._python_func_outputs, outputs_list)
@property
def output_dtypes(self):
return nest.map_structure(
- lambda x: x.dtype if x is not None else None, self._func_outputs)
+ lambda x: x.dtype if x is not None else None, self._python_func_outputs)
@property
def captured_inputs(self):
@@ -484,13 +582,6 @@ class GraphModeFunction(object):
"""Returns the name of the function in Eager-compatible format."""
return self._function_def.name.encode("utf-8")
- def add_to_graph(self, g):
- if self._function_def.name not in g._functions: # pylint: disable=protected-access
- g._add_function(self._function_def) # pylint: disable=protected-access
- for f in self._graph._functions.values(): # pylint: disable=protected-access
- if f.name not in g._functions: # pylint: disable=protected-access
- g._add_function(f) # pylint: disable=protected-access
-
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
@@ -505,32 +596,9 @@ class GraphModeFunction(object):
return self._backprop_call(tensor_inputs)
ctx = context.context()
- if ctx.executing_eagerly():
- result = execute.execute(
- str(self._func_name),
- num_outputs=self._num_outputs,
- inputs=tensor_inputs + self._extra_inputs,
- attrs=None,
- ctx=ctx)
- else:
- g = ops.get_default_graph()
- self.add_to_graph(g)
- signature = self._function_def.definition.signature
- args = list(tensor_inputs) + self._extra_inputs
- op = g.create_op(
- signature.name,
- [ops.internal_convert_to_tensor(x, ctx=ctx) for x in args],
- tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
- op_def=signature,
- name="FunctionCall",
- compute_shapes=False)
- result = op.outputs
- if not result:
- return op
- for i, s in enumerate(self._output_shapes):
- result[i].set_shape(s)
-
- return self._build_call_outputs(result)
+ args = tensor_inputs + self._extra_inputs
+ outputs = self._function_def.call(ctx, args, self._output_shapes)
+ return self._build_call_outputs(outputs)
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -540,11 +608,12 @@ class GraphModeFunction(object):
Returns:
The actual call output.
"""
- if self._func_outputs is None:
- return None
+ if self._python_func_outputs is None:
+ return result
+
# Use `nest.flatten` instead of `_flatten` in order to preserve any
- # IndexedSlices in `self._func_outputs`.
- outputs_list = nest.flatten(self._func_outputs)
+ # IndexedSlices in `self._python_func_outputs`.
+ outputs_list = nest.flatten(self._python_func_outputs)
j = 0
for i, o in enumerate(outputs_list):
if o is not None:
@@ -564,7 +633,7 @@ class GraphModeFunction(object):
else:
outputs_list[i] = result[j]
j += 1
- ret = nest.pack_sequence_as(self._func_outputs, outputs_list)
+ ret = nest.pack_sequence_as(self._python_func_outputs, outputs_list)
return ret
@@ -580,7 +649,11 @@ def _get_defun_inputs(args):
return nest.pack_sequence_as(args, ret)
-def _defun_internal(name, func, compiled, args, kwds):
+def _deterministic_dict_values(kwds):
+ return tuple(kwds[key] for key in sorted(kwds))
+
+
+def _trace_and_define_function(name, func, compiled, args, kwds):
"""Defines and returns graph-mode version of func."""
graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
with context.graph_mode():
@@ -597,7 +670,8 @@ def _defun_internal(name, func, compiled, args, kwds):
tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
collection)
with tmp_graph.as_default(), AutomaticControlDependencies() as a:
- func_inputs = _get_defun_inputs(args)
+ func_args = _get_defun_inputs(args)
+ func_kwds = _get_defun_inputs(kwds)
def convert(x):
if x is None:
@@ -608,7 +682,7 @@ def _defun_internal(name, func, compiled, args, kwds):
this_tape = tape.push_new_tape()
try:
- func_outputs = func(*func_inputs, **kwds)
+ func_outputs = func(*func_args, **func_kwds)
func_outputs = nest.map_structure(convert, func_outputs)
finally:
tape.pop_tape(this_tape)
@@ -630,10 +704,13 @@ def _defun_internal(name, func, compiled, args, kwds):
extra_placeholders = []
output_shapes = tuple(
x.shape if isinstance(x, ops.Tensor) else None
- for x in outputs_list)
+ for x in func_def_outputs)
- flat_inputs = [x for x in nest.flatten(func_inputs)
- if isinstance(x, ops.Tensor)]
+ func_kwds_values = _deterministic_dict_values(func_kwds)
+ flat_inputs = [
+ x for x in nest.flatten(func_args) + nest.flatten(func_kwds_values)
+ if isinstance(x, ops.Tensor)
+ ]
all_inputs = flat_inputs + list(extra_placeholders)
all_ignored_ops = frozenset(x.op for x in all_inputs)
fname = _inference_name(name)
@@ -648,7 +725,7 @@ def _defun_internal(name, func, compiled, args, kwds):
attrs = {}
if compiled:
- attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=True)
+ attrs[_xla_compile_attr] = attr_value_pb2.AttrValue(b=True)
return GraphModeFunction(
fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
@@ -688,42 +765,89 @@ def _cache_key(x):
return x
-def _register(fn):
- """Registers the function `fn`."""
- context.context().add_function(fn)
+class _PolymorphicFunction(object):
+ """Wrapper class for the graph functions defined for a Python function.
+ See the documentation for `defun` for more information on the semantics of
+ defined functions.
+ """
-# TODO(apassos): better error messages for non-hashable arguments.
-def named_defun(func, name, compiled=False):
- """Defines a function with a given name.
+ def __init__(self, python_function, name, compiled=False):
+ """Initializes a polymorphic function.
- See the documentation for `defun` for more information on the semantics of
- this function.
+ Args:
+ python_function: the function to be wrapped.
+ name: the name given to it.
+ compiled: if True, the framework will attempt to compile func with XLA.
+ """
- Args:
- func: the function to be wrapped.
- name: the name given to it.
- compiled: if true, the framework will attempt to compile func with XLA.
+ self._python_function = python_function
+ self._name = name
+ self._compiled = compiled
+ self._arguments_to_functions = {}
+ self._variables = []
+
+ def __get__(self, instance, owner):
+ """Makes it possible to defun instance methods."""
+ del owner
+ # `instance` here is the instance that this `_PolymorphicFunction` was
+ # accessed through; e.g., for
+ #
+ # class Foo(object):
+ #
+ # @function.defun
+ # def bar(self):
+ # ...
+ #
+ # foo = Foo()
+ # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance
+ #
+ # then `instance` will be `foo` (and `owner` will be `Foo`).
+ return functools.partial(self.__call__, instance)
- Returns:
- the wrapped function.
- """
- arguments_to_functions = {}
+ def _maybe_define_function(self, *args, **kwds):
+ """Gets a function for these inputs, defining it if necessary.
- def decorated(*args, **kwds):
- """Decorated version of func."""
- # Macroexpand on non-Tensor arguments
- cache_key = tuple(_cache_key(x) for x in args)
- if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
- raise ValueError("Tensor keyword arguments are not supported.")
- cache_key = (cache_key, tuple(kwds.items()))
+ Args:
+ *args: args for the Python function; used to compute the signature
+ **kwds: kwds for the Python function; used to compute the signature
- if cache_key not in arguments_to_functions:
- arguments_to_functions[cache_key] = _defun_internal(
- name, func, compiled, args, kwds)
- return arguments_to_functions[cache_key](*args)
+ Returns:
+ A graph function corresponding to the input signature implied by args and
+ kwds, as well as the inputs that the object should be called with.
+ """
- return decorated
+ # TODO(apassos): Better error messages for non-hashable arguments.
+ kwd_values = _deterministic_dict_values(kwds)
+ inputs = args + kwd_values
+ signature = tuple(_cache_key(x) for x in inputs)
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # signature so we don't improperly capture tensors such as variables.
+ signature += tuple([context.executing_eagerly() or ops.get_default_graph()])
+
+ if signature not in self._arguments_to_functions:
+ graph_function = _trace_and_define_function(
+ self._name, self._python_function, self._compiled, args, kwds)
+ self._arguments_to_functions[signature] = graph_function
+ self._variables.extend(
+ [v for v in graph_function.variables if v not in self._variables])
+ return graph_function, inputs
+ else:
+ return self._arguments_to_functions[signature], inputs
+
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized for this input signature."""
+ graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ return graph_function(*inputs)
+
+ def call_python_function(self, *args, **kwargs):
+ """Directly calls the wrapped python function."""
+ return self._python_function(*args, **kwargs)
+
+ @property
+ def variables(self):
+ """Returns a list of variables used in any of the defined functions."""
+ return self._variables
# TODO(akshayka): Remove the `compiled` flag and create a separate
@@ -734,22 +858,33 @@ def defun(func=None, compiled=False):
`defun` (short for "define function") trace-compiles a Python function
composed of TensorFlow operations into a callable that executes a @{tf.Graph}
- containing those operations. When eager execution is enabled, the ability to
- create graphs from Python functions makes it possible to incrementally trade
- off debugability and interactivity for performance. Functions compiled with
- `defun` cannot be inspected with `pdb` and `print` statements; however,
- executing a graph generated by `defun` sometimes takes less time and memory
- than eagerly executing the corresponding Python function, since specifying
- computations as graphs allows for optimizations like automatic buffer reuse
- and parallelization among ops. Note that executing a `defun`-compiled function
+ containing those operations. The callable produced by `defun` contains only
+ the subgraph of TensorFlow operations that were executed when the Python
+ function was called with a particular input signature, defined as a list
+ of the shapes and dtypes of the Python function's Tensor-valued arguments and
+ the values of its non-Tensor Python objects. In particular, `defun` is _not_ a
+ compiler for arbitrary Python code.
+
+ When eager execution is enabled, the ability to create graphs from Python
+ functions makes it possible to incrementally trade off debugability and
+ interactivity for performance. Functions compiled with `defun` cannot be
+ inspected with `pdb` and `print` statements; however, executing a graph
+ generated by `defun` sometimes takes less time and memory than eagerly
+ executing the corresponding Python function, since specifying computations as
+ graphs allows for optimizations like automatic buffer reuse and
+ parallelization among ops. Note that executing a `defun`-compiled function
incurs a small constant overhead, so eagerly executing sufficiently small
Python functions might take less time than executing their corresponding
`defun`-generated graphs.
- For a Python function to be compatible with `defun`, the values of its keyword
- arguments cannot be Tensors and all of its arguments, including its keyword
- arguments, must be hashable Python objects or lists thereof. Additionally, it
- must return zero or more @{tf.Tensor} objects.
+ For a Python function to be compatible with `defun`, all of its arguments must
+ be hashable Python objects or lists thereof. Additionally, it must return zero
+ or more @{tf.Tensor} objects.
+
+ Executing a graph generated by `defun` respects device annotations (i.e.,
+ all `with tf.device` directives present in a Python function will also be
+ present in its corresponding graph), but it is not yet possible to execute the
+ generated graphs across multiple machines.
_Example Usage_
@@ -822,20 +957,23 @@ def defun(func=None, compiled=False):
_Tracing and Input Signatures_.
The signature of inputs supplied to `F` is defined to be a tuple of the shapes
- and dtypes of Tensor-typed arguments and the values of non-Tensor arguments
- and keyword arguments. Every time `F` is invoked, the signature of its inputs
- are inferred. The first time `F(*args, **kwargs)` is invoked with a particular
- signature, `f(*args, **kwargs)` is executed and all the TensorFlow operations
- that `f` executes, along with the Tensors that flow between them, are recorded
- in a TensorFlow graph. `F` caches this graph and binds it to the inputs'
- signature; every subsequent invocation of `F` with inputs conforming to this
- signature will immediately retrieve the cached graph and pass it to the
- TensorFlow runtime for execution.
-
- Be aware that because `F` only logs TensorFlow operations, all non-TensorFlow
- operations that `f` executes will only shape the _construction_ of the graphs
- that `F` executes: They won't be executed when the graphs themselves are
- executed. For example, whereas the Python function
+ and dtypes of Tensor-typed arguments and the values of non-Tensor arguments,
+ where "arguments" includes both args and kwargs. Every time `F` is invoked,
+ the signature of its inputs are inferred. The first time `F(*args, **kwargs)`
+ is invoked with a particular signature, `f(*args, **kwargs)` is executed and
+ all the TensorFlow operations that `f` executes, along with the Tensors that
+ flow between them, are recorded in a TensorFlow graph. `F` caches this graph
+ and binds it to the inputs' signature; every subsequent invocation of `F` with
+ inputs conforming to this signature will immediately retrieve the cached graph
+ and pass it to the TensorFlow runtime for execution.
+
+ Be aware that because `F` only logs TensorFlow operations, all the other
+ Python code that `f` executes will only shape the _construction_ of the graphs
+ that `F` executes: the Python code won't be executed when the graphs
+ themselves are executed, though it will be executed every time the Python
+ function is traced (and a given Python function might be traced multiple
+ times, once for each input signature it is invoked with). For example, whereas
+ the Python function
```python
import tensorflow as tf
@@ -843,17 +981,23 @@ def defun(func=None, compiled=False):
tf.enable_eager_execution()
- matrix = tf.eye(5)
- # `matrix` is assumed to be a Tensor
def add_noise():
- return matrix + np.random.randn(matrix.shape[0], matrix.shape[1])
+ return tf.eye(5) + np.random.randn(5, 5)
```
will return a different output everytime it is invoked, the compiled function
`compiled = tf.contrib.eager.defun(add_noise)` will return the same value
every time it is called, since a particular random offset generated by NumPy
will be inserted into the graph as a TensorFlow constant. The solution is to
- replace the call to `np.random.randn` with `tf.random_normal(matrix.shape)`.
+ replace the call to `np.random.randn` with `tf.random_normal((5, 5))`.
+
+ _Python Side-Effects_
+ A corollary of the previous discussion on tracing is the following: If a
+ Python function `f` has Python side-effects, then executing `f` multiple times
+ will not necessarily be semantically equivalent to executing `F =
+ tf.contrib.eager.defun(f)` multiple times; this difference is due to the fact
+ that `defun` only captures the subgraph of TensorFlow operations that is
+ constructed when `f` is called in a graph-building context.
_Python Control Flow_.
The structure of many machine learning computations depend upon whether one is
@@ -980,7 +1124,7 @@ def defun(func=None, compiled=False):
except AttributeError:
name = "function"
return tf_decorator.make_decorator(
- function, named_defun(function, name, compiled=compiled))
+ function, _PolymorphicFunction(function, name, compiled=compiled))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1037,15 +1181,8 @@ def make_defun_op(func, *args, **kwds):
A wrapper object which can be queried for its output properties,
and which can be called directly the way a `@defun` wrapped function
can.
-
- Raises:
- ValueError: if any of the keyword arguments to `func` are `EagerTensor`
- objects (not yet supported).
"""
- name = func.__name__
- if any(isinstance(x, ops.EagerTensor) for x in kwds.values()):
- raise ValueError("Tensor keyword arguments are not supported.")
- return _defun_internal(name, func, False, args, kwds)
+ return _trace_and_define_function(func.__name__, func, False, args, kwds)
class AutomaticControlDependencies(object):
@@ -1159,7 +1296,7 @@ class AutomaticControlDependencies(object):
# Ensures the merge always runs
ops_which_must_run.add(new_merge[0].op)
if inp in last_op_using_resource_tensor:
- # Ensures the switch exectutes after the previous op using the resource.
+ # Ensures the switch executes after the previous op using the resource.
switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access
# Ensure the next op outside the cond happens after the merge.
last_op_using_resource_tensor[inp] = new_merge[0].op
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index f53d6c2608..a3e63c3153 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -19,13 +19,15 @@ from __future__ import print_function
import collections
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import tape
-from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function as tf_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -34,12 +36,16 @@ from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import gradient_descent
+from tensorflow.python.platform import test
+from tensorflow.python.training import momentum
+from tensorflow.python.training import training_ops
+from tensorflow.python.util import compat
@test_util.with_c_shapes
@@ -90,6 +96,32 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(step(), 2.0)
+ def testGraphGradientVariable(self):
+ with ops.Graph().as_default(), self.test_session():
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def f():
+ return 2.0 * v
+
+ node = f()
+ grads, = gradients_impl.gradients(node, v)
+ v.initializer.run()
+ self.assertAllEqual(grads.eval(), 2.0)
+ self.assertEqual(grads.shape, v.shape)
+
+ def testGraphEagerIsolation(self):
+
+ @function.defun
+ def f():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ return v.read_value()
+
+ self.assertAllEqual(f(), 1.0)
+
+ with ops.Graph().as_default():
+ self.assertEqual(f().shape, ())
+
def testBasicDefunOpGraphMode(self):
matmul = function.defun(math_ops.matmul)
@@ -166,6 +198,15 @@ class FunctionTest(test.TestCase):
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
+ def testDefunCapturedInt32(self):
+ x = constant_op.constant(1, dtype=dtypes.int32)
+
+ @function.defun
+ def add_int32s():
+ return x + x
+
+ self.assertEqual(2, int(add_int32s()))
+
def testDefunReadVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -177,13 +218,14 @@ class FunctionTest(test.TestCase):
def testDefunAssignAddVariable(self):
v = resource_variable_ops.ResourceVariable(1.0)
+ x = constant_op.constant(2.0)
@function.defun
- def f():
- v.assign_add(2.0)
+ def test_assign_add():
+ v.assign_add(x)
return v.read_value()
- self.assertEqual(3.0, float(f()))
+ self.assertEqual(3.0, float(test_assign_add()))
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -196,6 +238,21 @@ class FunctionTest(test.TestCase):
compiled = function.defun(f)
compiled()
+ def testVariableInLoopInFunction(self):
+
+ @function.defun
+ def test_function():
+
+ def loop_test(_):
+ return False
+
+ def loop_body(_):
+ return variable_scope.get_variable('a', shape=())
+
+ return control_flow_ops.while_loop(loop_test, loop_body, [0.0])
+
+ self.assertEqual(test_function().shape, [])
+
def testDefunShapeInferenceWithCapturedResourceVariableInGraphMode(self):
with context.graph_mode():
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -349,6 +406,23 @@ class FunctionTest(test.TestCase):
g(constant_op.constant(1.0))
+ def testNestedDefunWithNoOutputAndTapedInput(self):
+ three = resource_variable_ops.ResourceVariable(3.0, name='v')
+
+ @function.defun
+ def f(x):
+ # This function intentionally takes a taped variable as input,
+ # but does not return any values
+ math_ops.add(x, three)
+
+ @function.defun
+ def g(x):
+ tape.watch_variable(x)
+ y = math_ops.add(x, three)
+ f(y)
+
+ g(three)
+
def testGradientTensorConversionWithDefun(self):
three = resource_variable_ops.ResourceVariable(3.0, name='v')
@@ -381,24 +455,33 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(f(constant_op.constant(1.0)), 2.0)
- def testGradientOfGatherWithDefun(self):
+ def testGatherResourceWithDefun(self):
with ops.device('cpu:0'):
v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
- def sum_gather():
- return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
+ def sum_gather():
+ return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
- grad_fn = backprop.implicit_grad(sum_gather)
- gradient = grad_fn()
- defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather))
- defun_gradient = defun_grad_fn()
- self.assertEqual(len(gradient), len(defun_gradient))
+ defined = function.defun(sum_gather)
+ self.assertAllEqual(sum_gather(), defined())
- gradient = gradient[0][0]
- defun_gradient = defun_gradient[0][0]
- self.assertAllEqual(gradient.values, defun_gradient.values)
- self.assertAllEqual(gradient.indices, defun_gradient.indices)
- self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape)
+ def testGradientOfGatherWithDefun(self):
+ v = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
+
+ def sum_gather():
+ return math_ops.reduce_sum(array_ops.gather(v, [1, 2]))
+
+ grad_fn = backprop.implicit_grad(sum_gather)
+ gradient = grad_fn()
+ defun_grad_fn = backprop.implicit_grad(function.defun(sum_gather))
+ defun_gradient = defun_grad_fn()
+ self.assertEqual(len(gradient), len(defun_gradient))
+
+ gradient = gradient[0][0]
+ defun_gradient = defun_gradient[0][0]
+ self.assertAllEqual(gradient.values, defun_gradient.values)
+ self.assertAllEqual(gradient.indices, defun_gradient.indices)
+ self.assertAllEqual(gradient.dense_shape, defun_gradient.dense_shape)
def testReturningIndexedSlicesWithDefun(self):
@@ -462,6 +545,66 @@ class FunctionTest(test.TestCase):
y = f(x, x).cpu()
self.assertAllEqual(y, [2.])
+ @test_util.run_in_graph_and_eager_modes
+ def testFunctionWithResourcesOnDifferentDevices(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('/cpu:0'):
+ v_cpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
+
+ with ops.device('/gpu:0'):
+ v_gpu = resource_variable_ops.ResourceVariable([0.0, 1.0, 2.0])
+
+ def sum_gather():
+ cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu, [1, 2]))
+ gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
+ return cpu_result, gpu_result
+
+ defined = function.defun(sum_gather)
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ expected = self.evaluate(sum_gather())
+ self.assertAllEqual(expected, self.evaluate(defined()))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testOpInFunctionWithConflictingResourceInputs(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('/cpu:0'):
+ v_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='cpu')
+ v_also_cpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='also_cpu')
+
+ with ops.device('/gpu:0'):
+ v_gpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name='gpu')
+
+ @function.defun
+ def resource_apply_adam():
+ training_ops.resource_apply_adam(
+ v_cpu.handle,
+ v_gpu.handle,
+ v_also_cpu.handle,
+ 1.0, # beta1_power
+ 1.0, # beta2_power
+ 1.0, # learning_rate
+ 1.0, # beta1
+ 1.0, # beta2
+ 1.0, # epsilon,
+ [1.0, 1.0, 1.0], # grad
+ False) # use_locking
+ return None
+
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError, 'Could not colocate node with its '
+ 'resource and reference inputs.*'):
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(resource_apply_adam())
+
def testFunctionHandlesInputsOnDifferentDevices(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
@@ -495,6 +638,60 @@ class FunctionTest(test.TestCase):
g = backprop.gradients_function(wrapper, [0])(constant_op.constant(0.0))
self.assertAllEqual(g[0], 1.)
+ @function.defun
+ def foo(a):
+ return None, a * a
+
+ x = constant_op.constant(5.0)
+ with backprop.GradientTape() as tp:
+ tp.watch(x)
+ none, r = foo(x)
+ g = tp.gradient(r, x)
+
+ self.assertIs(none, None)
+ self.assertAllEqual(r, 25.0)
+ self.assertAllEqual(g, 2 * 5.0)
+
+ def testNestedDifferentiableFunction(self):
+ @function.defun
+ def inner_fn(a, b):
+ return a * math_ops.add(a, b)
+
+ @function.defun
+ def outer_fn(x):
+ return inner_fn(x, 1.0)
+
+ x = constant_op.constant(5.0)
+ with backprop.GradientTape() as tp:
+ tp.watch(x)
+ result = outer_fn(x)
+ grad = tp.gradient(result, x)
+
+ self.assertAllEqual(grad, 2 * 5.0 + 1.0)
+
+ def testNestedDifferentiableFunctionNoneOutputs(self):
+ @function.defun
+ def foo(a, b):
+ return None, a * math_ops.add(a, b), None, 2*a
+
+ @function.defun
+ def bar(x):
+ return foo(x, 1.0)
+
+ x = constant_op.constant(5.0)
+ with backprop.GradientTape(persistent=True) as tp:
+ tp.watch(x)
+ none1, r1, none2, r2 = bar(x)
+ g1 = tp.gradient(r1, x)
+ g2 = tp.gradient(r2, x)
+
+ self.assertAllEqual(r1, 30.0)
+ self.assertAllEqual(r2, 10.0)
+ self.assertIs(none1, None)
+ self.assertIs(none2, None)
+ self.assertAllEqual(g1, 2 * 5.0 + 1.0)
+ self.assertAllEqual(g2, 2.0)
+
def testNoneOutput(self):
@function.defun
@@ -517,15 +714,15 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(3, add_one(constant_op.constant(2)))
def testVariableCaptureInNestedFunctions(self):
- v = resource_variable_ops.ResourceVariable(1)
+ v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int32)
@function.defun
- def read():
+ def inner_read():
return v.read_value()
@function.defun
def outer():
- return read()
+ return inner_read()
self.assertEqual(1, int(outer()))
@@ -616,6 +813,146 @@ class FunctionTest(test.TestCase):
y = model(x)
self.assertAllEqual([[[[4.0]]]], y.numpy())
+ @test_util.run_in_graph_and_eager_modes(
+ config=config_pb2.ConfigProto(device_count={'CPU': 3}))
+ def testDeviceAnnotationsRespected(self):
+ @function.defun
+ def multi_device_fn():
+ with ops.device('/cpu:0'):
+ s1 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ with ops.device('/cpu:1'):
+ s2 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ with ops.device('/cpu:2'):
+ s3 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ return s1, s2, s3
+
+ outputs = multi_device_fn()
+ self.assertTrue(compat.as_bytes('CPU:0') in self.evaluate(outputs[0]))
+ self.assertTrue(compat.as_bytes('CPU:1') in self.evaluate(outputs[1]))
+ self.assertTrue(compat.as_bytes('CPU:2') in self.evaluate(outputs[2]))
+
+ def testVariablesAreTracked(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ def foo(x):
+ return v * x
+
+ defined = function.defun(foo)
+
+ x = constant_op.constant([1.0])
+ self.assertAllEqual(defined.variables, [])
+ _ = defined(x)
+ self.assertAllEqual(defined.variables, [v])
+
+ x = constant_op.constant([1.0, 2.0])
+ _ = defined(x) # ensure the variables list remains the same
+ self.assertAllEqual(defined.variables, [v])
+
+ def testTensorKeywordArguments(self):
+
+ def foo(a, b):
+ del a
+ return b
+
+ defined = function.defun(foo)
+ a = constant_op.constant(2.0)
+ b = constant_op.constant([1.0, 2.0])
+ one = defined(a, b)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ two = defined(a=a, b=b)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ three = defined(b=b, a=a)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ four = defined(a, b=b)
+ self.assertEqual(len(defined._arguments_to_functions), 1)
+
+ # The next call corresponds to a new input signature, hence
+ # we expect another function to be defined.
+ five = defined(b, a)
+ self.assertEqual(len(defined._arguments_to_functions), 2)
+
+ six = defined(a=b, b=a)
+ self.assertEqual(len(defined._arguments_to_functions), 2)
+
+ seven = defined(b=a, a=b)
+ self.assertEqual(len(defined._arguments_to_functions), 2)
+
+ self.assertAllEqual(one, [1.0, 2.0])
+ self.assertAllEqual(two, [1.0, 2.0])
+ self.assertAllEqual(three, [1.0, 2.0])
+ self.assertAllEqual(four, [1.0, 2.0])
+ self.assertAllEqual(five, 2.0)
+ self.assertAllEqual(six, 2.0)
+ self.assertAllEqual(seven, 2.0)
+
+ def testGradientWithKeywordArguments(self):
+ matmul = function.defun(math_ops.matmul)
+
+ def sq(x):
+ return matmul(a=x, b=x, transpose_a=True)
+
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ grad_t, = backprop.gradients_function(sq, [0])(t)
+ self.assertAllEqual(grad_t, [[6, 6], [14, 14]])
+
+ with backprop.GradientTape(persistent=True) as gtape:
+ gtape.watch(t)
+ one = matmul(t, b=t, transpose_a=True)
+ two = matmul(b=t, a=t, transpose_a=True)
+ three = matmul(a=t, b=t, transpose_a=True)
+
+ for output in [one, two, three]:
+ self.assertAllEqual(gtape.gradient(output, t), [[6, 6], [14, 14]])
+
+ def testGradientInFunctionWithKeywordArguments(self):
+
+ @function.defun
+ def f(x):
+ return backprop.gradients_function(lambda y: y * y, [0])(x)[0]
+
+ self.assertAllEqual(f(x=constant_op.constant(1.0)), 2.0)
+
+ def testDecoratingInstanceMethod(self):
+
+ class Foo(object):
+
+ def one(self, tensor):
+ return tensor
+
+ @function.defun
+ def two(self, tensor):
+ return self.one(tensor)
+
+ foo = Foo()
+ t = constant_op.constant(1.0)
+ out = foo.two(t)
+ self.assertEqual(float(out), 1.0)
+
+ def testPythonCallWithSideEffects(self):
+ state = []
+
+ @function.defun
+ def side_effecting_function():
+ state.append(0)
+
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # The second invocation should call the graph function, which shouldn't
+ # trigger the list append.
+ side_effecting_function()
+ self.assertAllEqual(state, [0])
+
+ # Whereas calling the python function directly should create a side-effect.
+ side_effecting_function.call_python_function()
+ self.assertAllEqual(state, [0, 0])
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
@@ -803,7 +1140,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
def loss(v):
return v**2
- optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+ optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0)
@function.defun
def train():
@@ -820,7 +1157,7 @@ class AutomaticControlDependenciesTest(test.TestCase):
def loss():
return v**2
- optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
+ optimizer = momentum.MomentumOptimizer(learning_rate=1.0, momentum=1.0)
@function.defun
def train():
@@ -832,4 +1169,6 @@ class AutomaticControlDependenciesTest(test.TestCase):
if __name__ == '__main__':
+ ops.enable_eager_execution(
+ config=config_pb2.ConfigProto(device_count={'CPU': 3}))
test.main()
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 760a148552..2c6f04d8ad 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -110,13 +110,25 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, partitioner, validate_shape, use_resource, dtype
- del collections, initializer, trainable, reuse, caching_device, shape,
+ del collections, initializer, trainable, reuse, caching_device, shape
+ del aggregation, synchronization
assert name in self.variables
v = self.variables[name]
return v.variable
@@ -136,13 +148,24 @@ class _VariableCapturingScope(object):
"""
# TODO(apassos) ignoring the regularizer and partitioner here; figure out
# how to deal with these.
- def _custom_getter(getter=None, name=None, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None, # pylint: disable=redefined-outer-name
- partitioner=None, validate_shape=True,
- use_resource=None):
+ def _custom_getter( # pylint: disable=missing-docstring
+ getter=None,
+ name=None,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None, # pylint: disable=redefined-outer-name
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ aggregation=variable_scope.VariableAggregation.NONE,
+ synchronization=variable_scope.VariableSynchronization.AUTO):
del getter, regularizer, collections, caching_device, partitioner
- del use_resource, validate_shape
+ del use_resource, validate_shape, aggregation, synchronization
if name in self.tf_variables:
if reuse:
return self.tf_variables[name].initialized_value()
diff --git a/tensorflow/python/eager/memory_test.py b/tensorflow/python/eager/memory_test.py
new file mode 100644
index 0000000000..74c6cbdd31
--- /dev/null
+++ b/tensorflow/python/eager/memory_test.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for memory leaks in eager execution.
+
+It is possible that this test suite will eventually become flaky due to taking
+too long to run (since the tests iterate many times), but for now they are
+helpful for finding memory leaks since not all PyObject leaks are found by
+introspection (test_util decorators). Please be careful adding new tests here.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python import keras
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+
+# memory_profiler might not be available in the OSS version of TensorFlow.
+try:
+ import memory_profiler # pylint:disable=g-import-not-at-top
+except ImportError:
+ memory_profiler = None
+
+
+class SingleLayerNet(keras.Model):
+ """Simple keras model used to ensure that there are no leaks."""
+
+ def __init__(self):
+ super(SingleLayerNet, self).__init__()
+ self.fc1 = keras.layers.Dense(5)
+
+ def call(self, x):
+ return self.fc1(x)
+
+
+class MemoryTest(test.TestCase):
+
+ def assertNotIncreasingMemory(self,
+ f,
+ num_iters=100000,
+ increase_threshold_absolute_mb=10):
+ """Assert memory usage doesn't increase beyond given threshold for f."""
+
+ with context.eager_mode():
+ # Warm up.
+ f()
+
+ initial = memory_profiler.memory_usage(-1)[0]
+
+ for _ in xrange(num_iters):
+ f()
+
+ increase = memory_profiler.memory_usage(-1)[0] - initial
+
+ assert increase < increase_threshold_absolute_mb, (
+ "Increase is too high. Initial memory usage: %f MB. Increase: %f MB. "
+ "Maximum allowed increase: %f") % (initial, increase,
+ increase_threshold_absolute_mb)
+
+ def testMemoryLeakInSimpleModelForwardOnly(self):
+ if memory_profiler is None:
+ self.skipTest("memory_profiler required to run this test")
+
+ inputs = array_ops.zeros([32, 100], dtypes.float32)
+ net = SingleLayerNet()
+
+ def f():
+ with backprop.GradientTape():
+ net(inputs)
+
+ self.assertNotIncreasingMemory(f)
+
+ def testMemoryLeakInSimpleModelForwardAndBackward(self):
+ if memory_profiler is None:
+ self.skipTest("memory_profiler required to run this test")
+
+ inputs = array_ops.zeros([32, 100], dtypes.float32)
+ net = SingleLayerNet()
+
+ def f():
+ with backprop.GradientTape() as tape:
+ result = net(inputs)
+
+ tape.gradient(result, net.variables)
+
+ del tape
+
+ self.assertNotIncreasingMemory(f)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 6c9481c3af..ec7e2371e9 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -205,14 +205,20 @@ bool ParseDimensionValue(const string& key, PyObject* py_value,
}
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
- const char** value) {
+ tensorflow::StringPiece* value) {
if (PyBytes_Check(py_value)) {
- *value = PyBytes_AsString(py_value);
+ Py_ssize_t size = 0;
+ char* buf = nullptr;
+ if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
+ *value = tensorflow::StringPiece(buf, size);
return true;
}
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(py_value)) {
- *value = PyUnicode_AsUTF8(py_value);
+ Py_ssize_t size = 0;
+ char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
+ if (buf == nullptr) return false;
+ *value = tensorflow::StringPiece(buf, size);
return true;
}
#endif
@@ -275,8 +281,16 @@ bool SetOpAttrList(
}
if (type == TF_ATTR_STRING) {
- PARSE_LIST(const char*, ParseStringValue);
- TFE_OpSetAttrStringList(op, key, values.get(), num_values);
+ std::unique_ptr<const void*[]> values(new const void*[num_values]);
+ std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
+ for (int i = 0; i < num_values; ++i) {
+ tensorflow::StringPiece value;
+ tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
+ if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
+ values[i] = value.data();
+ lengths[i] = value.size();
+ }
+ TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
} else if (type == TF_ATTR_INT) {
PARSE_LIST(int64_t, ParseInt64Value);
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
@@ -379,12 +393,15 @@ void SetOpAttrListDefault(
TF_Status* status) {
if (type == TF_ATTR_STRING) {
int num_values = attr.default_value().list().s_size();
- std::unique_ptr<const char*[]> values(new const char*[num_values]);
+ std::unique_ptr<const void*[]> values(new const void*[num_values]);
+ std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
- values[i] = attr.default_value().list().s(i).data();
+ const string& v = attr.default_value().list().s(i);
+ values[i] = v.data();
+ lengths[i] = v.size();
}
- TFE_OpSetAttrStringList(op, key, values.get(), num_values);
+ TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
} else if (type == TF_ATTR_INT) {
int num_values = attr.default_value().list().i_size();
std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
@@ -470,9 +487,9 @@ bool SetOpAttrScalar(
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
if (type == TF_ATTR_STRING) {
- const char* value;
+ tensorflow::StringPiece value;
if (!ParseStringValue(key, py_value, status, &value)) return false;
- TFE_OpSetAttrString(op, key, value);
+ TFE_OpSetAttrString(op, key, value.data(), value.size());
} else if (type == TF_ATTR_INT) {
int64_t value;
if (!ParseInt64Value(key, py_value, status, &value)) return false;
@@ -533,7 +550,7 @@ bool SetOpAttrScalar(
// (which is what the various "defun" or "Defun" decorators do).
// And in the future also allow an object that can encapsulate
// the function name and its attribute values.
- const char* func_name = nullptr;
+ tensorflow::StringPiece func_name;
if (!ParseStringValue(key, py_value, status, &func_name)) {
PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
if (name_attr == nullptr ||
@@ -549,7 +566,8 @@ bool SetOpAttrScalar(
return false;
}
}
- TFE_Op* func = TFE_NewOp(ctx, func_name, status);
+ TFE_Op* func = TFE_NewOp(
+ ctx, string(func_name.data(), func_name.size()).c_str(), status);
if (TF_GetCode(status) != TF_OK) return false;
TFE_OpSetAttrFunction(op, key, func);
TFE_DeleteOp(func);
@@ -930,7 +948,7 @@ class GradientTape
: id(id), variable(variable) {}
};
struct CompareById {
- bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) {
+ bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
return lhs.id < rhs.id;
}
};
@@ -1880,14 +1898,39 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "trainable"));
+ DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "trainable"));
+ PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
+bool CastTensor(const FastPathOpExecInfo& op_exec_info,
+ const TF_DataType& desired_dtype,
+ tensorflow::Safe_TFE_TensorHandlePtr* handle,
+ TF_Status* status) {
+ TF_DataType input_dtype = TFE_TensorHandleDataType(handle->get());
+ TF_DataType output_dtype = input_dtype;
+
+ if (desired_dtype >= 0 && desired_dtype != input_dtype) {
+ *handle = tensorflow::make_safe(
+ tensorflow::EagerCast(op_exec_info.ctx, handle->get(), input_dtype,
+ static_cast<TF_DataType>(desired_dtype), status));
+ if (!status->status.ok()) return false;
+ output_dtype = desired_dtype;
+ }
+
+ if (output_dtype != TF_INT32) {
+ // Note that this is a shallow copy and will share the underlying buffer
+ // if copying to the same device.
+ *handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
+ handle->get(), op_exec_info.ctx, op_exec_info.device_name, status));
+ if (!status->status.ok()) return false;
+ }
+ return true;
+}
+
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
@@ -1920,9 +1963,31 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
TFE_Execute(op, &output_handle, &num_retvals, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
- // Always create the py object (and correctly DECREF it) from the returned
- // value, else the data will leak.
- output->reset(EagerTensorFromHandle(output_handle));
+ if (!PyObject_HasAttrString(input, "_read_dtype")) {
+ // Always create the py object (and correctly DECREF it) from the returned
+ // value, else the data will leak.
+ output->reset(EagerTensorFromHandle(output_handle));
+ } else {
+ // This is a _MixedPrecisionVariable which potentially does casting when
+ // being read.
+ tensorflow::Safe_PyObjectPtr read_dtype(
+ PyObject_GetAttrString(input, "_read_dtype"));
+ int desired_dtype = -1;
+ if (!ParseTypeValue("_read_dtype", read_dtype.get(), status,
+ &desired_dtype)) {
+ return false;
+ }
+
+ auto safe_output_handle = tensorflow::make_safe(output_handle);
+ // Retires output_handle in the future.
+ output_handle = nullptr;
+ if (!CastTensor(parent_op_exec_info,
+ static_cast<TF_DataType>(desired_dtype),
+ &safe_output_handle, status)) {
+ return false;
+ }
+ output->reset(EagerTensorFromHandle(safe_output_handle.release()));
+ }
// TODO(nareshmodi): Should we run post exec callbacks here?
if (parent_op_exec_info.run_gradient_callback) {
@@ -1992,27 +2057,13 @@ bool ConvertToTensor(
}
}
- TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
- if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
- handle = tensorflow::make_safe(
- tensorflow::EagerCast(op_exec_info.ctx, handle.get(), handle_dtype,
- static_cast<TF_DataType>(desired_dtype), status));
- if (!status->status.ok()) return false;
-
- handle_dtype = TFE_TensorHandleDataType(handle.get());
- }
-
- if (handle_dtype != TF_INT32) {
- // Note that this is a shallow copy and will share the underlying buffer
- // if copying to the same device.
- handle = tensorflow::make_safe(TFE_TensorHandleCopyToDevice(
- handle.get(), op_exec_info.ctx, op_exec_info.device_name, status));
- if (!status->status.ok()) return false;
+ if (!CastTensor(op_exec_info, static_cast<TF_DataType>(desired_dtype),
+ &handle, status)) {
+ return false;
}
-
+ TF_DataType output_dtype = TFE_TensorHandleDataType(handle.get());
output_handle->reset(EagerTensorFromHandle(handle.release()));
-
- dtype_setter(handle_dtype);
+ dtype_setter(output_dtype);
return true;
}
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index faaae40b3f..fd8ab695b8 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -71,6 +72,25 @@ class Tests(test.TestCase):
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
+ def testFastpathExecute_MixedPrecisionVariableMatMulCorrectResponse(self):
+ ctx = context.context()
+ a_2_by_2 = constant_op.constant(1.0, shape=[2, 2])
+ a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
+ m = resource_variable_ops.ResourceVariable(a_2_by_2)
+ m = resource_variable_ops._MixedPrecisionVariable(
+ m, read_dtype=dtypes.float16)
+ x = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, m, m, "transpose_a",
+ False, "transpose_b", False)
+ y = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16,
+ a_2_by_2_fp16, "transpose_a", False, "transpose_b", False)
+
+ self.assertEqual(x.dtype, dtypes.float16)
+ self.assertAllEqual(x, y)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
def testFastpathExecute_TapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
@@ -98,6 +118,29 @@ class Tests(test.TestCase):
self.assertAllEqual(dz_dy.numpy(),
constant_op.constant(4.0, shape=[2, 2]).numpy())
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
+ ctx = context.context()
+ with backprop.GradientTape(persistent=True) as tape:
+ a_2_by_2 = constant_op.constant(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
+ m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
+ m2 = resource_variable_ops._MixedPrecisionVariable(
+ m1, read_dtype=dtypes.float16)
+ tape.watch(m2)
+ z = pywrap_tensorflow.TFE_Py_FastPathExecute(
+ ctx._handle, ctx.device_name, "MatMul", None, None, a_2_by_2_fp16, m2,
+ "transpose_a", False, "transpose_b", False)
+ dz_dy = tape.gradient(z, [m2])[0]
+ self.assertEqual(dz_dy.dtype, dtypes.float16)
+
+ expected_grads = math_ops.matmul(
+ array_ops.transpose(a_2_by_2_fp16),
+ constant_op.constant(1., shape=[2, 2], dtype=dtypes.float16)).numpy()
+ self.assertAllEqual(dz_dy.numpy(), expected_grads)
+
# Tests homogeneous list op
@test_util.assert_no_new_tensors
@test_util.assert_no_garbage_created
diff --git a/tensorflow/python/eager/test.py b/tensorflow/python/eager/test.py
index f6a46e7eb3..33ee797678 100644
--- a/tensorflow/python/eager/test.py
+++ b/tensorflow/python/eager/test.py
@@ -23,6 +23,7 @@ from tensorflow.python.platform import test as _test
from tensorflow.python.platform.test import * # pylint: disable=wildcard-import
+# TODO(akshayka): Do away with this file.
def main(argv=None):
_ops.enable_eager_execution()
_test.main(argv)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 9c4d58b177..8ee38d35cc 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -1,8 +1,4 @@
-package(
- default_visibility = [
- "//tensorflow:internal",
- ],
-)
+package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
@@ -10,8 +6,15 @@ load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "estimator_py",
- srcs = ["estimator_lib.py"],
+ srcs = [
+ "__init__.py",
+ "estimator_lib.py",
+ ],
srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":baseline",
":boosted_trees",
@@ -27,7 +30,7 @@ py_library(
":parsing_utils",
":run_config",
":training",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -37,10 +40,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gc",
- "//tensorflow/python:errors",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:util",
],
@@ -54,10 +54,7 @@ py_test(
deps = [
":estimator",
":exporter",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -66,8 +63,7 @@ py_library(
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -78,10 +74,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":gc",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -91,12 +84,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":export_output",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -109,12 +97,7 @@ py_test(
deps = [
":export_output",
":model_fn",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -126,11 +109,7 @@ py_library(
":estimator",
":exporter",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -149,13 +128,7 @@ py_test(
":inputs",
":run_config",
":training",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -164,7 +137,7 @@ py_library(
srcs = ["run_config.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/core:protos_all_py",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -176,8 +149,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -190,14 +162,7 @@ py_library(
":head",
":model_fn",
":optimizers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -221,26 +186,7 @@ py_test(
":numpy_io",
":pandas_io",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -253,20 +199,7 @@ py_library(
":estimator",
":head",
":model_fn",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:boosted_trees_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -275,21 +208,13 @@ py_test(
size = "medium",
srcs = ["canned/boosted_trees_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
deps = [
":boosted_trees",
- "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:resources",
- "//tensorflow/python:training",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/feature_column",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -302,14 +227,7 @@ py_library(
":head",
":model_fn",
":optimizers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:summary",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -326,22 +244,7 @@ py_library(
":model_fn",
":numpy_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -364,16 +267,7 @@ py_test(
":numpy_io",
":pandas_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -389,19 +283,7 @@ py_library(
":linear",
":model_fn",
":optimizers",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -424,17 +306,7 @@ py_test(
":numpy_io",
":pandas_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -446,10 +318,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/data",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -460,10 +329,7 @@ py_test(
tags = ["notsan"], # b/67510291
deps = [
":util",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:training",
- "//tensorflow/python/data",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -480,21 +346,7 @@ py_library(
":model_fn",
":run_config",
":util",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python/data",
- "//tensorflow/python/saved_model:builder",
- "//tensorflow/python/saved_model:constants",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -513,29 +365,7 @@ py_test(
":model_fn",
":numpy_io",
":run_config",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:lib",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:saver_test_utils",
- "//tensorflow/python:session",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- "//tensorflow/python/data",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -548,9 +378,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -561,10 +389,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":parsing_utils",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -573,9 +398,7 @@ py_library(
srcs = ["export/export_output.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -587,13 +410,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":export_output",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -606,7 +423,7 @@ py_library(
deps = [
":export_export",
":export_output",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -618,13 +435,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":util",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -637,17 +448,8 @@ py_test(
deps = [
":export_export",
":export_output",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:signature_def_utils",
+ ":util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -660,24 +462,7 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:nn",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:weights_broadcast_ops",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -696,22 +481,7 @@ py_test(
":model_fn",
":numpy_io",
":prediction_keys",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -724,7 +494,7 @@ py_library(
deps = [
":numpy_io",
":pandas_io",
- "//tensorflow/python:util",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -736,11 +506,7 @@ py_library(
":estimator",
":head",
":optimizers",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/ops/losses",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -758,25 +524,7 @@ py_library(
":numpy_io",
":pandas_io",
":run_config",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -794,7 +542,7 @@ py_test(
deps = [
":linear",
":linear_testing_utils",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -823,9 +571,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":numpy_io",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -834,7 +580,7 @@ py_library(
srcs = ["canned/optimizers.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -846,8 +592,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":optimizers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -865,9 +610,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":pandas_io",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -887,15 +630,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"@six_archive//:six",
],
)
@@ -909,7 +644,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":inputs_queues",
- "//tensorflow/python:client_testlib",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -920,10 +655,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":inputs_queues",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -936,32 +668,7 @@ py_library(
":export_export",
":model_fn",
":run_config",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:layers",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:nn",
- "//tensorflow/python:partitioned_variables",
- "//tensorflow/python:platform",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:summary",
- "//tensorflow/python:tensor_util",
- "//tensorflow/python:training",
- "//tensorflow/python:training_util",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/feature_column",
- "//tensorflow/python/keras:backend",
- "//tensorflow/python/keras:engine",
- "//tensorflow/python/keras:layers",
- "//tensorflow/python/ops/losses",
- "//tensorflow/python/saved_model",
- "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow:tensorflow_py_no_contrib",
],
)
@@ -970,21 +677,47 @@ py_test(
size = "large",
srcs = ["keras_test.py"],
srcs_version = "PY2AND3",
- tags = ["notsan"],
+ tags = [
+ "no_windows",
+ "notsan",
+ ],
deps = [
":keras",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:numpy_io",
"//tensorflow/python/estimator:run_config",
- "//tensorflow/python/keras",
- "//tensorflow/python/keras:backend",
- "//tensorflow/python/keras:engine",
"//third_party/py/numpy",
],
)
+
+py_library(
+ name = "expect_numpy_installed",
+ # This is a dummy rule used as a numpy dependency in open-source.
+ # We expect numpy to already be installed on the system, e.g. via
+ # `pip install numpy`
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "expect_pandas_installed",
+ # This is a dummy rule used as a numpy dependency in open-source.
+ # We expect pandas to already be installed on the system, e.g. via
+ # `pip install pandas`
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "expect_six_installed",
+ # This is a dummy rule used as a numpy dependency in open-source.
+ # We expect six to already be installed on the system, e.g. via
+ # `pip install six`
+ visibility = ["//visibility:public"],
+)
+
+py_library(
+ name = "expect_tensorflow_installed",
+ # This is a dummy rule used as a numpy dependency in open-source.
+ # We expect tensorflow to already be installed on the system, e.g. via
+ # `pip install tensorflow` or `pip install tensorflow_gpu`
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/python/estimator/__init__.py b/tensorflow/python/estimator/__init__.py
index e69de29bb2..8cf8df567f 100644
--- a/tensorflow/python/estimator/__init__.py
+++ b/tensorflow/python/estimator/__init__.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Import Estimator APIs.
+
+Note: This file is imported by the create_estimator_api genrule. It must
+transitively import all Estimator modules/packages for their @estimator_export
+annotations to generate the public Estimator python API.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.python.estimator.estimator_lib
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
new file mode 100644
index 0000000000..ceb9baef4d
--- /dev/null
+++ b/tensorflow/python/estimator/api/BUILD
@@ -0,0 +1,19 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+
+gen_api_init_files(
+ name = "estimator_python_api_gen",
+ api_name = "estimator",
+ output_files = ESTIMATOR_API_INIT_FILES,
+ output_package = "tensorflow.python.estimator.api",
+ package = "tensorflow.python.estimator",
+ package_dep = "//tensorflow/python/estimator:estimator_py",
+)
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
index 980c057372..20c7a69b7c 100644
--- a/tensorflow/python/estimator/canned/baseline.py
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -24,10 +24,10 @@ Example:
classifier = BaselineClassifier(n_classes=3)
# Input builders
-def input_fn_train: # returns x, y (where y represents label's class index).
+def input_fn_train(): # returns x, y (where y represents label's class index).
pass
-def input_fn_eval: # returns x, y (where y represents label's class index).
+def input_fn_eval(): # returns x, y (where y represents label's class index).
pass
# Fit model.
@@ -59,7 +59,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.3 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -174,7 +174,7 @@ def _baseline_model_fn(features, labels, mode, head, optimizer,
train_op_fn=train_op_fn)
-@tf_export('estimator.BaselineClassifier')
+@estimator_export('estimator.BaselineClassifier')
class BaselineClassifier(estimator.Estimator):
"""A classifier that can establish a simple baseline.
@@ -215,6 +215,13 @@ class BaselineClassifier(estimator.Estimator):
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
"""
def __init__(self,
@@ -277,7 +284,7 @@ class BaselineClassifier(estimator.Estimator):
config=config)
-@tf_export('estimator.BaselineRegressor')
+@estimator_export('estimator.BaselineRegressor')
class BaselineRegressor(estimator.Estimator):
"""A regressor that can establish a simple baseline.
@@ -313,6 +320,13 @@ class BaselineRegressor(estimator.Estimator):
* if `weight_column` is not `None`, a feature with
`key=weight_column` whose value is a `Tensor`.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
"""
def __init__(self,
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
index 7bf2e62da9..e46a3a156d 100644
--- a/tensorflow/python/estimator/canned/baseline_test.py
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -154,6 +154,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 9.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -176,6 +178,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -204,6 +208,8 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -229,7 +235,9 @@ class BaselineRegressorEvaluationTest(test.TestCase):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is bias which is [46, 58]
self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 4e6010a162..3c832c7569 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import functools
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -39,17 +40,18 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
- 'min_node_weight'
+ 'min_node_weight', 'center_bias'
])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
_HOLD_FOR_MULTI_DIM_SUPPORT = object()
_DUMMY_NUM_BUCKETS = -1
+_DUMMY_NODE_ID = -1
def _get_transformed_features(features, sorted_feature_columns):
@@ -168,9 +170,10 @@ def _group_features_by_num_buckets(sorted_feature_columns):
# pylint:enable=protected-access
# Replace the dummy key with the real max num of buckets for all bucketized
# columns.
- bucket_size_to_feature_ids_dict[
- max_buckets_for_bucketized] = bucket_size_to_feature_ids_dict[
- _DUMMY_NUM_BUCKETS]
+ if max_buckets_for_bucketized not in bucket_size_to_feature_ids_dict:
+ bucket_size_to_feature_ids_dict[max_buckets_for_bucketized] = []
+ bucket_size_to_feature_ids_dict[max_buckets_for_bucketized].extend(
+ bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS])
del bucket_size_to_feature_ids_dict[_DUMMY_NUM_BUCKETS]
feature_ids_list = list(bucket_size_to_feature_ids_dict.values())
@@ -278,7 +281,9 @@ class _CacheTrainingStatesUsingHashTable(object):
"""Returns cached_tree_ids, cached_node_ids, cached_logits."""
cached_tree_ids, cached_node_ids, cached_logits = array_ops.split(
lookup_ops.lookup_table_find_v2(
- self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]),
+ self._table_ref,
+ self._example_ids,
+ default_value=[0.0, _DUMMY_NODE_ID, 0.0]),
[1, 1, self._logits_dimension],
axis=1)
cached_tree_ids = array_ops.squeeze(
@@ -329,7 +334,7 @@ class _CacheTrainingStatesUsingVariables(object):
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -424,8 +429,8 @@ def _bt_model_fn(
ValueError: mode or params are invalid, or features has the wrong type.
"""
is_single_machine = (config.num_worker_replicas <= 1)
-
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
+ center_bias = tree_hparams.center_bias
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
@@ -468,6 +473,9 @@ def _bt_model_fn(
# Create Ensemble resources.
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ # Variable that determines whether bias centering is needed.
+ center_bias_var = variable_scope.variable(
+ initial_value=center_bias, name='center_bias_needed', trainable=False)
# Create logits.
if mode != model_fn.ModeKeys.TRAIN:
logits = boosted_trees_ops.predict(
@@ -488,6 +496,7 @@ def _bt_model_fn(
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
+
if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -496,9 +505,10 @@ def _bt_model_fn(
batch_size = array_ops.shape(labels)[0]
cached_tree_ids, cached_node_ids, cached_logits = (
array_ops.zeros([batch_size], dtype=dtypes.int32),
- array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))
+
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
@@ -512,13 +522,20 @@ def _bt_model_fn(
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
+
logits = cached_logits + partial_logits
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
if training_state_cache:
- train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
+ # Cache logits only after center_bias is complete, if it's in progress.
+ train_op.append(
+ control_flow_ops.cond(
+ center_bias_var, control_flow_ops.no_op,
+ lambda: training_state_cache.insert(tree_ids, node_ids, logits))
+ )
+
if closed_form_grad_and_hess_fn:
gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
else:
@@ -542,8 +559,7 @@ def _bt_model_fn(
]
stats_summaries_list.append(summaries)
- accumulators = []
-
+ # ========= Helper methods for both in and not in memory. ==============
def grow_tree_from_stats_summaries(stats_summaries_list,
feature_ids_list):
"""Updates ensemble based on the best gains from stats summaries."""
@@ -590,55 +606,126 @@ def _bt_model_fn(
pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
return grow_op
+ def _center_bias_fn(mean_gradients, mean_hessians):
+ """Updates the ensembles and cache (if needed) with logits prior."""
+ continue_centering = boosted_trees_ops.center_bias(
+ tree_ensemble.resource_handle,
+ mean_gradients=mean_gradients,
+ mean_hessians=mean_hessians,
+ l1=tree_hparams.l1,
+ l2=tree_hparams.l2
+ )
+ return center_bias_var.assign(continue_centering)
+
+ # ========= End of helper methods. ==============
+
if train_in_memory and is_single_machine:
train_op.append(distribute_lib.increment_var(global_step))
+
+ mean_gradients = array_ops.expand_dims(
+ math_ops.reduce_mean(gradients, 0), 0)
+ mean_heassians = array_ops.expand_dims(
+ math_ops.reduce_mean(hessians, 0), 0)
+
train_op.append(
- grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list))
+ control_flow_ops.cond(
+ center_bias_var,
+ lambda: _center_bias_fn(mean_gradients, mean_heassians),
+ functools.partial(grow_tree_from_stats_summaries,
+ stats_summaries_list, feature_ids_list)))
else:
- dependencies = []
- for i, feature_ids in enumerate(feature_ids_list):
- stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
+ def center_bias_not_in_mem():
+ """Accumulates the data and updates the logits bias, when ready."""
+ bias_dependencies = []
+
+ bias_accumulator = data_flow_ops.ConditionalAccumulator(
dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
- array_ops.stack(stats_summaries, axis=0), stamp_token)
- dependencies.append(apply_grad)
-
- def grow_tree_from_accumulated_summaries_fn():
- """Updates the tree with the best layer from accumulated summaries."""
- # Take out the accumulated summaries from the accumulator and grow.
- stats_summaries_list = []
-
- stats_summaries_list = [
- array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
- ]
-
- grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list)
- return grow_op
-
- with ops.control_dependencies(dependencies):
- train_op.append(distribute_lib.increment_var(global_step))
- if config.is_chief:
- min_accumulated = math_ops.reduce_min(
- array_ops.stack(
- [acc.num_accumulated() for acc in accumulators]))
-
- train_op.append(
- control_flow_ops.cond(
- math_ops.greater_equal(min_accumulated,
- n_batches_per_layer),
- grow_tree_from_accumulated_summaries_fn,
- control_flow_ops.no_op,
- name='wait_until_n_batches_accumulated'))
+ # The stats consist of grads and hessians means only.
+ # TODO(nponomareva): this will change for a multiclass
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+
+ grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
+ grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
+
+ apply_grad = bias_accumulator.apply_grad(grads_and_hess, stamp_token)
+ bias_dependencies.append(apply_grad)
+
+ def center_bias_from_accumulator():
+ accumulated = array_ops.unstack(
+ bias_accumulator.take_grad(1), axis=0)
+ return _center_bias_fn(
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+
+ with ops.control_dependencies(bias_dependencies):
+ if config.is_chief:
+ center_bias_op = control_flow_ops.cond(
+ math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ n_batches_per_layer),
+ center_bias_from_accumulator,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_for_bias_accumulated')
+
+ return center_bias_op
+ else:
+ return control_flow_ops.no_op()
+
+ def grow_not_in_mem():
+ """Accumulates the data and grows a layer when ready."""
+
+ accumulators = []
+ dependencies = []
+ for i, feature_ids in enumerate(feature_ids_list):
+ stats_summaries = stats_summaries_list[i]
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ accumulators.append(accumulator)
+
+ apply_grad = accumulator.apply_grad(
+ array_ops.stack(stats_summaries, axis=0), stamp_token)
+ dependencies.append(apply_grad)
+
+ def grow_tree_from_accumulated_summaries_fn():
+ """Updates tree with the best layer from accumulated summaries."""
+ # Take out the accumulated summaries from the accumulator and grow.
+ stats_summaries_list = []
+
+ stats_summaries_list = [
+ array_ops.unstack(accumulator.take_grad(1), axis=0)
+ for accumulator in accumulators
+ ]
+
+ grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
+ feature_ids_list)
+ return grow_op
+
+ with ops.control_dependencies(dependencies):
+ if config.is_chief:
+ min_accumulated = math_ops.reduce_min(
+ array_ops.stack(
+ [acc.num_accumulated() for acc in accumulators]))
+
+ grow_model = control_flow_ops.cond(
+ math_ops.greater_equal(min_accumulated, n_batches_per_layer),
+ grow_tree_from_accumulated_summaries_fn,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_accumulated')
+
+ return grow_model
+ else:
+ return control_flow_ops.no_op()
+
+ update_model = control_flow_ops.cond(
+ center_bias_var, center_bias_not_in_mem, grow_not_in_mem)
+ train_op.append(update_model)
+ with ops.control_dependencies([update_model]):
+ increment_global = distribute_lib.increment_var(global_step)
+ train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
@@ -712,9 +799,17 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
-@tf_export('estimator.BoostedTreesClassifier')
+@estimator_export('estimator.BoostedTreesClassifier')
class BoostedTreesClassifier(estimator.Estimator):
- """A Classifier for Tensorflow Boosted Trees models."""
+ """A Classifier for Tensorflow Boosted Trees models.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
+ """
def __init__(self,
feature_columns,
@@ -730,7 +825,8 @@ class BoostedTreesClassifier(estimator.Estimator):
l2_regularization=0.,
tree_complexity=0.,
min_node_weight=0.,
- config=None):
+ config=None,
+ center_bias=False):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -798,6 +894,13 @@ class BoostedTreesClassifier(estimator.Estimator):
split to be considered. The value will be compared with
sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -812,7 +915,7 @@ class BoostedTreesClassifier(estimator.Estimator):
# HParams for the model.
tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -830,9 +933,17 @@ class BoostedTreesClassifier(estimator.Estimator):
model_fn=_model_fn, model_dir=model_dir, config=config)
-@tf_export('estimator.BoostedTreesRegressor')
+@estimator_export('estimator.BoostedTreesRegressor')
class BoostedTreesRegressor(estimator.Estimator):
- """A Regressor for Tensorflow Boosted Trees models."""
+ """A Regressor for Tensorflow Boosted Trees models.
+
+ @compatibility(eager)
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
+ @end_compatibility
+ """
def __init__(self,
feature_columns,
@@ -847,7 +958,8 @@ class BoostedTreesRegressor(estimator.Estimator):
l2_regularization=0.,
tree_complexity=0.,
min_node_weight=0.,
- config=None):
+ config=None,
+ center_bias=False):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -908,6 +1020,12 @@ class BoostedTreesRegressor(estimator.Estimator):
split to be considered. The value will be compared with
sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -921,7 +1039,7 @@ class BoostedTreesRegressor(estimator.Estimator):
# HParams for the model.
tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 9ea4f48474..f807641057 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -500,6 +500,50 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(2, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testTrainEvaluateAndPredictWithOnlyIndicatorColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+
+ labels = np.array([[0.], [5.7], [5.7], [0.], [0.]], dtype=np.float32)
+ # Our categorical feature defines the labels perfectly
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'categorical': np.array(['bad', 'good', 'good', 'ok', 'bad']),
+ },
+ y=labels,
+ batch_size=5,
+ shuffle=False)
+
+ # Train depth 1 tree.
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[feature_indicator],
+ n_batches_per_layer=1,
+ n_trees=1,
+ learning_rate=1.0,
+ max_depth=1)
+
+ num_steps = 1
+ est.train(input_fn, steps=num_steps)
+ ensemble = self._assert_checkpoint_and_return_model(
+ est.model_dir, global_step=1, finalized_trees=1, attempted_layers=1)
+
+ # We learnt perfectly.
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['loss'], 0)
+
+ predictions = list(est.predict(input_fn))
+ self.assertAllClose(
+ labels,
+ [pred['predictions'] for pred in predictions])
+
+ self.assertEqual(3, len(ensemble.trees[0].nodes))
+
+ # Check that the split happened on 'good' value, which will be encoded as
+ # feature with index 1 (0 - 'bad', 2 - 'ok')
+ self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
+ self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
@@ -510,14 +554,6 @@ class ModelFnTests(test_util.TensorFlowTestCase):
feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
}
- self._tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access
- n_trees=2,
- max_depth=2,
- learning_rate=0.1,
- l1=0.,
- l2=0.01,
- tree_complexity=0.,
- min_node_weight=0.)
def _get_expected_ensembles_for_classification(self):
first_round = """
@@ -746,6 +782,245 @@ class ModelFnTests(test_util.TensorFlowTestCase):
"""
return (first_round, second_round, third_round)
+ def _get_expected_ensembles_for_classification_with_bias(self):
+ first_round = """
+ trees {
+ nodes {
+ leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.407711
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.407711
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf {
+ scalar: -0.556054
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.09876
+ original_leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.698072
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.106016
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.27349
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ last_layer_node_end: 1
+ }
+ """
+ forth_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.4077113
+ original_leaf {
+ scalar: -0.405086
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 3
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf {
+ scalar: -0.556054
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.09876
+ original_leaf {
+ scalar: -0.301233
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.698072
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.556054
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.106016
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.27349
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.289927
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.134588
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.083838
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ return (first_round, second_round, third_round, forth_round)
+
def _get_expected_ensembles_for_regression(self):
first_round = """
trees {
@@ -973,17 +1248,275 @@ class ModelFnTests(test_util.TensorFlowTestCase):
"""
return (first_round, second_round, third_round)
- def _get_train_op_and_ensemble(self, head, config, is_classification,
- train_in_memory):
+ def _get_expected_ensembles_for_regression_with_bias(self):
+ first_round = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ }
+ """
+ second_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.862786
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 1
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ third_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.683594
+ original_leaf {
+ scalar: 1.862786
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 0
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.322693
+ original_leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 2.024487
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.710319
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.559208
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.686037
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ last_layer_node_start: 0
+ last_layer_node_end: 1
+ }
+ """
+ forth_round = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ threshold: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.190442
+ original_leaf {
+ scalar: 1.799974
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ threshold: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.683594
+ original_leaf {
+ scalar: 1.8627863
+ }
+ }
+ }
+ nodes {
+ bucketized_split {
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 0.322693
+ original_leaf {
+ scalar: 1.706149
+ }
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 2.024487
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.710319
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.5592078
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.686037
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 0.972589
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.137592
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.034926
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 2
+ is_finalized: true
+ }
+ tree_metadata {
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 2
+ num_layers_attempted: 3
+ last_layer_node_start: 1
+ last_layer_node_end: 3
+ }
+ """
+ return (first_round, second_round, third_round, forth_round)
+
+ def _get_train_op_and_ensemble(self,
+ head,
+ config,
+ is_classification,
+ train_in_memory,
+ center_bias=False):
"""Calls bt_model_fn() and returns the train_op and ensemble_serialzed."""
features, labels = _make_train_input_fn(is_classification)()
+
+ tree_hparams = boosted_trees._TreeHParams( # pylint:disable=protected-access
+ n_trees=2,
+ max_depth=2,
+ learning_rate=0.1,
+ l1=0.,
+ l2=0.01,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ center_bias=center_bias)
+
estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access
features=features,
labels=labels,
mode=model_fn.ModeKeys.TRAIN,
head=head,
feature_columns=self._feature_columns,
- tree_hparams=self._tree_hparams,
+ tree_hparams=tree_hparams,
example_id_column_name=EXAMPLE_ID_COLUMN,
n_batches_per_layer=1,
config=config,
@@ -1032,6 +1565,49 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainClassifierWithCenterBiasInMemory(self):
+ ops.reset_default_graph()
+
+ # When bias centering is on, we expect the very first node to have the
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_classification_with_bias())
+
+ with self.test_session() as sess:
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=True,
+ center_bias=True)
+
+ # 4 iterations to center bias.
+ for _ in range(4):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Forth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainClassifierNonInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1062,6 +1638,47 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainClassifierWithCenterBiasNonInMemory(self):
+ ops.reset_default_graph()
+
+ # When bias centering is on, we expect the very first node to have the
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_classification_with_bias())
+
+ with self.test_session() as sess:
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_classification_head(n_classes=2),
+ run_config.RunConfig(),
+ is_classification=True,
+ train_in_memory=False,
+ center_bias=True)
+ # 4 iterations to center bias.
+ for _ in range(4):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Forth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainRegressorInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1092,6 +1709,46 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainRegressorInMemoryWithCenterBias(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_regression_with_bias())
+ with self.test_session() as sess:
+ # Train with train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=True,
+ center_bias=True)
+ # 3 iterations to center bias.
+ for _ in range(3):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Forth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
def testTrainRegressorNonInMemory(self):
ops.reset_default_graph()
expected_first, expected_second, expected_third = (
@@ -1122,6 +1779,46 @@ class ModelFnTests(test_util.TensorFlowTestCase):
ensemble_proto.ParseFromString(serialized)
self.assertProtoEquals(expected_third, ensemble_proto)
+ def testTrainRegressorNotInMemoryWithCenterBias(self):
+ ops.reset_default_graph()
+ expected_first, expected_second, expected_third, expected_forth = (
+ self._get_expected_ensembles_for_regression_with_bias())
+ with self.test_session() as sess:
+ # Train with train_in_memory mode.
+ with sess.graph.as_default():
+ train_op, ensemble_serialized = self._get_train_op_and_ensemble(
+ boosted_trees._create_regression_head(label_dimension=1),
+ run_config.RunConfig(),
+ is_classification=False,
+ train_in_memory=False,
+ center_bias=True)
+ # 3 iterations to center the bias (because we are using regularization).
+ for _ in range(3):
+ _, serialized = sess.run([train_op, ensemble_serialized])
+
+ # Validate the trained ensemble.
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_first, ensemble_proto)
+
+ # Run one more time and validate the trained ensemble.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_second, ensemble_proto)
+
+ # Third round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_third, ensemble_proto)
+
+ # Forth round training and validation.
+ _, serialized = sess.run([train_op, ensemble_serialized])
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertProtoEquals(expected_forth, ensemble_proto)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1feac36f35..c08cf61220 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -26,13 +26,14 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.layers import core as core_layers
+from tensorflow.python.layers import normalization
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.05 is a historical artifact of the initial
# implementation, but seems a reasonable choice.
@@ -45,7 +46,7 @@ def _add_hidden_layer_summary(value, tag):
def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
- dropout, input_layer_partitioner):
+ dropout, input_layer_partitioner, batch_norm):
"""Function builder for a dnn logit_fn.
Args:
@@ -58,6 +59,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
dropout: When not `None`, the probability we will drop out a given
coordinate.
input_layer_partitioner: Partitioner for input layer.
+ batch_norm: Whether to use batch normalization after each hidden layer.
Returns:
A logit_fn (see below).
@@ -83,6 +85,7 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
A `Tensor` representing the logits, or a list of `Tensor`'s representing
multiple logits in the MultiHead case.
"""
+ is_training = mode == model_fn.ModeKeys.TRAIN
with variable_scope.variable_scope(
'input_from_feature_columns',
values=tuple(six.itervalues(features)),
@@ -98,8 +101,20 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
activation=activation_fn,
kernel_initializer=init_ops.glorot_uniform_initializer(),
name=hidden_layer_scope)
- if dropout is not None and mode == model_fn.ModeKeys.TRAIN:
+ if dropout is not None and is_training:
net = core_layers.dropout(net, rate=dropout, training=True)
+ if batch_norm:
+ # TODO(hjm): In future, if this becomes popular, we can enable
+ # customization of the batch normalization params by accepting a
+ # list of `BatchNormalization` instances as `batch_norm`.
+ net = normalization.batch_normalization(
+ net,
+ # The default momentum 0.99 actually crashes on certain
+ # problem, so here we use 0.999, which is the default of
+ # tf.contrib.layers.batch_norm.
+ momentum=0.999,
+ training=is_training,
+ name='batchnorm_%d' % layer_id)
_add_hidden_layer_summary(net, hidden_layer_scope.name)
with variable_scope.variable_scope('logits', values=(net,)) as logits_scope:
@@ -127,7 +142,8 @@ def _dnn_model_fn(features,
dropout=None,
input_layer_partitioner=None,
config=None,
- tpu_estimator_spec=False):
+ tpu_estimator_spec=False,
+ batch_norm=False):
"""Deep Neural Net model_fn.
Args:
@@ -150,6 +166,7 @@ def _dnn_model_fn(features,
config: `RunConfig` object to configure the runtime settings.
tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or
or `model_fn.EstimatorSpec` instance.
+ batch_norm: Whether to use batch normalization after each hidden layer.
Returns:
An `EstimatorSpec` instance.
@@ -182,7 +199,8 @@ def _dnn_model_fn(features,
feature_columns=feature_columns,
activation_fn=activation_fn,
dropout=dropout,
- input_layer_partitioner=input_layer_partitioner)
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=batch_norm)
logits = logit_fn(features=features, mode=mode)
if tpu_estimator_spec:
@@ -201,7 +219,7 @@ def _dnn_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNClassifier')
+@estimator_export('estimator.DNNClassifier')
class DNNClassifier(estimator.Estimator):
"""A classifier for TensorFlow DNN models.
@@ -230,6 +248,17 @@ class DNNClassifier(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ optimizer=lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
@@ -266,7 +295,10 @@ class DNNClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -285,6 +317,7 @@ class DNNClassifier(estimator.Estimator):
config=None,
warm_start_from=None,
loss_reduction=losses.Reduction.SUM,
+ batch_norm=False,
):
"""Initializes a `DNNClassifier` instance.
@@ -314,8 +347,9 @@ class DNNClassifier(estimator.Estimator):
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
Also there will be errors if vocabulary is not provided and labels are
string.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
@@ -330,6 +364,7 @@ class DNNClassifier(estimator.Estimator):
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ batch_norm: Whether to use batch normalization after each hidden layer.
"""
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction)
@@ -346,14 +381,15 @@ class DNNClassifier(estimator.Estimator):
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm)
super(DNNClassifier, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNRegressor')
+@estimator_export('estimator.DNNRegressor')
class DNNRegressor(estimator.Estimator):
"""A regressor for TensorFlow DNN models.
@@ -382,6 +418,17 @@ class DNNRegressor(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = DNNRegressor(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256],
+ optimizer=lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = DNNRegressor(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
@@ -418,7 +465,10 @@ class DNNRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -436,6 +486,7 @@ class DNNRegressor(estimator.Estimator):
config=None,
warm_start_from=None,
loss_reduction=losses.Reduction.SUM,
+ batch_norm=False,
):
"""Initializes a `DNNRegressor` instance.
@@ -459,8 +510,9 @@ class DNNRegressor(estimator.Estimator):
used as a key to fetch weight tensor from the `features`. If it is a
`_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
then weight_column.normalizer_fn is applied on it to get weight tensor.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to Adagrad optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to Adagrad optimizer.
activation_fn: Activation function applied to each layer. If `None`, will
use `tf.nn.relu`.
dropout: When not `None`, the probability we will drop out a given
@@ -475,6 +527,7 @@ class DNNRegressor(estimator.Estimator):
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ batch_norm: Whether to use batch normalization after each hidden layer.
"""
def _model_fn(features, labels, mode, config):
@@ -492,7 +545,8 @@ class DNNRegressor(estimator.Estimator):
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm)
super(DNNRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 95efc0a028..efa7812452 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -37,7 +37,7 @@ from tensorflow.python.summary import summary
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rates are a historical artifact of the initial
# implementation.
@@ -88,7 +88,9 @@ def _dnn_linear_combined_model_fn(features,
dnn_activation_fn=nn.relu,
dnn_dropout=None,
input_layer_partitioner=None,
- config=None):
+ config=None,
+ batch_norm=False,
+ linear_sparse_combiner='sum'):
"""Deep Neural Net and Linear combined model_fn.
Args:
@@ -115,7 +117,10 @@ def _dnn_linear_combined_model_fn(features,
coordinate.
input_layer_partitioner: Partitioner for input layer.
config: `RunConfig` object to configure the runtime settings.
-
+ batch_norm: Whether to use batch normalization after each hidden layer.
+ linear_sparse_combiner: A string specifying how to reduce the linear model
+ if a categorical column is multivalent. One of "mean", "sqrtn", and
+ "sum".
Returns:
An `EstimatorSpec` instance.
@@ -164,7 +169,8 @@ def _dnn_linear_combined_model_fn(features,
feature_columns=dnn_feature_columns,
activation_fn=dnn_activation_fn,
dropout=dnn_dropout,
- input_layer_partitioner=input_layer_partitioner)
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=batch_norm)
dnn_logits = dnn_logit_fn(features=features, mode=mode)
linear_parent_scope = 'linear'
@@ -182,7 +188,8 @@ def _dnn_linear_combined_model_fn(features,
partitioner=input_layer_partitioner) as scope:
logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
- feature_columns=linear_feature_columns)
+ feature_columns=linear_feature_columns,
+ sparse_combiner=linear_sparse_combiner)
linear_logits = logit_fn(features=features)
_add_layer_summary(linear_logits, scope.name)
@@ -225,7 +232,7 @@ def _dnn_linear_combined_model_fn(features,
logits=logits)
-@tf_export('estimator.DNNLinearCombinedClassifier')
+@estimator_export('estimator.DNNLinearCombinedClassifier')
class DNNLinearCombinedClassifier(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined classification models.
@@ -257,12 +264,19 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
# warm-start settings
warm_start_from="/path/to/checkpoint/dir")
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+ lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96)
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -292,7 +306,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -311,7 +328,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
input_layer_partitioner=None,
config=None,
warm_start_from=None,
- loss_reduction=losses.Reduction.SUM):
+ loss_reduction=losses.Reduction.SUM,
+ batch_norm=False,
+ linear_sparse_combiner='sum'):
"""Initializes a DNNLinearCombinedClassifier instance.
Args:
@@ -322,12 +341,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
@@ -360,6 +383,12 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ batch_norm: Whether to use batch normalization after each hidden layer.
+ linear_sparse_combiner: A string specifying how to reduce the linear model
+ if a categorical column is multivalent. One of "mean", "sqrtn", and
+ "sum" -- these are effectively different ways to do example-level
+ normalization, which can be useful for bag-of-words features. For more
+ details, see @{tf.feature_column.linear_model$linear_model}.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -399,14 +428,16 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm,
+ linear_sparse_combiner=linear_sparse_combiner)
super(DNNLinearCombinedClassifier, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
warm_start_from=warm_start_from)
-@tf_export('estimator.DNNLinearCombinedRegressor')
+@estimator_export('estimator.DNNLinearCombinedRegressor')
class DNNLinearCombinedRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined models for regression.
@@ -438,12 +469,19 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
# warm-start settings
warm_start_from="/path/to/checkpoint/dir")
- # To apply L1 and L2 regularization, you can set optimizers as follows:
+ # To apply L1 and L2 regularization, you can set dnn_optimizer to:
tf.train.ProximalAdagradOptimizer(
learning_rate=0.1,
l1_regularization_strength=0.001,
l2_regularization_strength=0.001)
- # It is same for FtrlOptimizer.
+ # To apply learning rate decay, you can set dnn_optimizer to a callable:
+ lambda: tf.AdamOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96)
+ # It is the same for linear_optimizer.
# Input builders
def input_fn_train: # returns x, y
@@ -473,7 +511,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -491,7 +532,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
input_layer_partitioner=None,
config=None,
warm_start_from=None,
- loss_reduction=losses.Reduction.SUM):
+ loss_reduction=losses.Reduction.SUM,
+ batch_norm=False,
+ linear_sparse_combiner='sum'):
"""Initializes a DNNLinearCombinedRegressor instance.
Args:
@@ -502,12 +545,16 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
used by linear part of the model. All items in the set must be
instances of classes derived from `FeatureColumn`.
linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the linear part of the model. Defaults to FTRL optimizer.
+ the linear part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to FTRL
+ optimizer.
dnn_feature_columns: An iterable containing all the feature columns used
by deep part of the model. All items in the set must be instances of
classes derived from `FeatureColumn`.
dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to
- the deep part of the model. Defaults to Adagrad optimizer.
+ the deep part of the model. Can also be a string (one of 'Adagrad',
+ 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to Adagrad
+ optimizer.
dnn_hidden_units: List of hidden units per layer. All layers are fully
connected.
dnn_activation_fn: Activation function applied to each layer. If None,
@@ -534,6 +581,12 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ batch_norm: Whether to use batch normalization after each hidden layer.
+ linear_sparse_combiner: A string specifying how to reduce the linear model
+ if a categorical column is multivalent. One of "mean", "sqrtn", and
+ "sum" -- these are effectively different ways to do example-level
+ normalization, which can be useful for bag-of-words features. For more
+ details, see @{tf.feature_column.linear_model$linear_model}.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -564,7 +617,9 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
dnn_activation_fn=dnn_activation_fn,
dnn_dropout=dnn_dropout,
input_layer_partitioner=input_layer_partitioner,
- config=config)
+ config=config,
+ batch_norm=batch_norm,
+ linear_sparse_combiner=linear_sparse_combiner)
super(DNNLinearCombinedRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index d275695eb3..d16318659b 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -100,7 +100,8 @@ def _linear_regressor_fn(feature_columns,
weight_column=None,
optimizer='Ftrl',
config=None,
- partitioner=None):
+ partitioner=None,
+ sparse_combiner='sum'):
return dnn_linear_combined.DNNLinearCombinedRegressor(
model_dir=model_dir,
linear_feature_columns=feature_columns,
@@ -108,7 +109,8 @@ def _linear_regressor_fn(feature_columns,
label_dimension=label_dimension,
weight_column=weight_column,
input_layer_partitioner=partitioner,
- config=config)
+ config=config,
+ linear_sparse_combiner=sparse_combiner)
class LinearOnlyRegressorPartitionerTest(
@@ -163,7 +165,8 @@ def _linear_classifier_fn(feature_columns,
label_vocabulary=None,
optimizer='Ftrl',
config=None,
- partitioner=None):
+ partitioner=None,
+ sparse_combiner='sum'):
return dnn_linear_combined.DNNLinearCombinedClassifier(
model_dir=model_dir,
linear_feature_columns=feature_columns,
@@ -172,7 +175,8 @@ def _linear_classifier_fn(feature_columns,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
input_layer_partitioner=partitioner,
- config=config)
+ config=config,
+ linear_sparse_combiner=sparse_combiner)
class LinearOnlyClassifierTrainingTest(
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 06a648777f..de226ed0ef 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -65,6 +65,11 @@ from tensorflow.python.training import training_util
LEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'
HIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'
HIDDEN_BIASES_NAME_PATTERN = 'dnn/hiddenlayer_%d/bias'
+BATCH_NORM_BETA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/beta'
+BATCH_NORM_GAMMA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/gamma'
+BATCH_NORM_MEAN_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/moving_mean'
+BATCH_NORM_VARIANCE_NAME_PATTERN = (
+ 'dnn/hiddenlayer_%d/batchnorm_%d/moving_variance')
LOGITS_WEIGHTS_NAME = 'dnn/logits/kernel'
LOGITS_BIASES_NAME = 'dnn/logits/bias'
OCCUPATION_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/'
@@ -89,7 +94,10 @@ def assert_close(expected, actual, rtol=1e-04, message='', name='assert_close'):
name=scope)
-def create_checkpoint(weights_and_biases, global_step, model_dir):
+def create_checkpoint(weights_and_biases,
+ global_step,
+ model_dir,
+ batch_norm_vars=None):
"""Create checkpoint file with provided model weights.
Args:
@@ -98,12 +106,20 @@ def create_checkpoint(weights_and_biases, global_step, model_dir):
model_dir: Directory into which checkpoint is saved.
"""
weights, biases = zip(*weights_and_biases)
+ if batch_norm_vars:
+ assert len(batch_norm_vars) == len(weights_and_biases) - 1
+ (bn_betas, bn_gammas, bn_means, bn_variances) = zip(*batch_norm_vars)
model_weights = {}
# Hidden layer weights.
for i in range(0, len(weights) - 1):
model_weights[HIDDEN_WEIGHTS_NAME_PATTERN % i] = weights[i]
model_weights[HIDDEN_BIASES_NAME_PATTERN % i] = biases[i]
+ if batch_norm_vars:
+ model_weights[BATCH_NORM_BETA_NAME_PATTERN % (i, i)] = bn_betas[i]
+ model_weights[BATCH_NORM_GAMMA_NAME_PATTERN % (i, i)] = bn_gammas[i]
+ model_weights[BATCH_NORM_MEAN_NAME_PATTERN % (i, i)] = bn_means[i]
+ model_weights[BATCH_NORM_VARIANCE_NAME_PATTERN % (i, i)] = bn_variances[i]
# Output layer weights.
model_weights[LOGITS_WEIGHTS_NAME] = weights[-1]
@@ -503,8 +519,13 @@ class BaseDNNLogitFnTest(object):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_logits(self, mode, hidden_units, logits_dimension, inputs,
- expected_logits):
+ def _test_logits(self,
+ mode,
+ hidden_units,
+ logits_dimension,
+ inputs,
+ expected_logits,
+ batch_norm=False):
"""Tests that the expected logits are calculated."""
with ops.Graph().as_default():
# Global step needed for MonitoredSession, which is in turn used to
@@ -525,7 +546,8 @@ class BaseDNNLogitFnTest(object):
],
activation_fn=nn.relu,
dropout=None,
- input_layer_partitioner=input_layer_partitioner)
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=batch_norm)
logits = logit_fn(
features={'age': constant_op.constant(inputs)}, mode=mode)
with monitored_session.MonitoredTrainingSession(
@@ -556,6 +578,69 @@ class BaseDNNLogitFnTest(object):
inputs=[[10.]],
expected_logits=[[-2.08]])
+ def test_one_dim_logits_with_batch_norm(self):
+ """Tests one-dimensional logits.
+
+ input_layer = [[10]]
+ hidden_layer_0 = [[relu(0.6*10 +1), relu(0.5*10 -1)]] = [[7, 4]]
+ hidden_layer_0 = [[relu(0.6*20 +1), relu(0.5*20 -1)]] = [[13, 9]]
+
+ batch_norm_0, training (epsilon = 0.001):
+ mean1 = 1/2*(7+13) = 10,
+ variance1 = 1/2*(3^2+3^2) = 9
+ x11 = (7-10)/sqrt(9+0.001) = -0.999944449,
+ x21 = (13-10)/sqrt(9+0.001) = 0.999944449,
+
+ mean2 = 1/2*(4+9) = 6.5,
+ variance2 = 1/2*(2.5^2+.2.5^2) = 6.25
+ x12 = (4-6.5)/sqrt(6.25+0.001) = -0.99992001,
+ x22 = (9-6.5)/sqrt(6.25+0.001) = 0.99992001,
+
+ logits = [[-1*(-0.999944449) + 2*(-0.99992001) + 0.3],
+ [-1*0.999944449 + 2*0.99992001 + 0.3]]
+ = [[-0.699895571],[1.299895571]]
+
+ batch_norm_0, not training (epsilon = 0.001):
+ moving_mean1 = 0, moving_variance1 = 1
+ x11 = (7-0)/sqrt(1+0.001) = 6.996502623,
+ x21 = (13-0)/sqrt(1+0.001) = 12.993504871,
+ moving_mean2 = 0, moving_variance2 = 1
+ x12 = (4-0)/sqrt(1+0.001) = 3.998001499,
+ x22 = (9-0)/sqrt(1+0.001) = 8.995503372,
+
+ logits = [[-1*6.996502623 + 2*3.998001499 + 0.3],
+ [-1*12.993504871 + 2*8.995503372 + 0.3]]
+ = [[1.299500375],[5.297501873]]
+ """
+ base_global_step = 100
+ create_checkpoint(
+ (
+ ([[.6, .5]], [1., -1.]),
+ ([[-1.], [2.]], [.3]),
+ ),
+ base_global_step,
+ self._model_dir,
+ batch_norm_vars=([[0, 0], # beta.
+ [1, 1], # gamma.
+ [0, 0], # moving mean.
+ [1, 1], # moving variance.
+ ],))
+ self._test_logits(
+ model_fn.ModeKeys.TRAIN,
+ hidden_units=[2],
+ logits_dimension=1,
+ inputs=[[10.], [20.]],
+ expected_logits=[[-0.699895571], [1.299895571]],
+ batch_norm=True)
+ for mode in [model_fn.ModeKeys.EVAL, model_fn.ModeKeys.PREDICT]:
+ self._test_logits(
+ mode,
+ hidden_units=[2],
+ logits_dimension=1,
+ inputs=[[10.], [20.]],
+ expected_logits=[[1.299500375], [5.297501873]],
+ batch_norm=True)
+
def test_multi_dim_logits(self):
"""Tests multi-dimensional logits.
@@ -706,7 +791,8 @@ class BaseDNNLogitFnTest(object):
],
activation_fn=nn.relu,
dropout=None,
- input_layer_partitioner=input_layer_partitioner)
+ input_layer_partitioner=input_layer_partitioner,
+ batch_norm=False)
logits = logit_fn(
features={
'age': constant_op.constant(inputs[0]),
@@ -1185,6 +1271,8 @@ class BaseDNNRegressorEvaluateTest(object):
self.assertAllClose({
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
+ metric_keys.MetricKeys.PREDICTION_MEAN: -2.08,
+ metric_keys.MetricKeys.LABEL_MEAN: 1.0,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
@@ -1215,6 +1303,8 @@ class BaseDNNRegressorEvaluateTest(object):
self.assertAllClose({
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 0.39 / 3.0,
+ metric_keys.MetricKeys.LABEL_MEAN: 0.5 / 3.0,
ops.GraphKeys.GLOBAL_STEP: global_step
}, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 04fe4d97e4..da9a64c2bc 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -873,6 +873,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
train_op = train_op_fn(regularized_training_loss)
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
+ train_op = _append_update_ops(train_op)
# Only summarize mean_loss for SUM reduction to preserve backwards
# compatibility. Otherwise skip it to avoid unnecessary computation.
if self._loss_reduction == losses.Reduction.SUM:
@@ -1244,6 +1245,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
train_op = train_op_fn(regularized_training_loss)
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
+ train_op = _append_update_ops(train_op)
# Only summarize mean_loss for SUM reduction to preserve backwards
# compatibility. Otherwise skip it to avoid unnecessary computation.
if self._loss_reduction == losses.Reduction.SUM:
@@ -1396,15 +1398,21 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
weights=weights,
processed_labels=labels)
- def _eval_metric_ops(self, weights, unreduced_loss, regularization_loss):
+ def _eval_metric_ops(self, predicted_value, labels, weights, unreduced_loss,
+ regularization_loss):
"""Returns the Eval metric ops."""
keys = metric_keys.MetricKeys
# Estimator already adds a metric for loss.
eval_metric_ops = {
_summary_key(self._name, keys.LOSS_MEAN):
- metrics_lib.mean(
- values=unreduced_loss,
- weights=weights)
+ metrics_lib.mean(values=unreduced_loss, weights=weights),
+ _summary_key(self._name, keys.PREDICTION_MEAN):
+ _predictions_mean(
+ predictions=predicted_value,
+ weights=weights,
+ name=keys.PREDICTION_MEAN),
+ _summary_key(self._name, keys.LABEL_MEAN):
+ metrics_lib.mean(values=labels, weights=weights)
}
if regularization_loss is not None:
regularization_loss_key = _summary_key(
@@ -1487,13 +1495,13 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
predictions=predictions,
loss=regularized_training_loss,
eval_metrics=_create_eval_metrics_tuple(
- self._eval_metric_ops,
- {
+ self._eval_metric_ops, {
+ 'predicted_value': predicted_value,
+ 'labels': labels,
'weights': weights,
'unreduced_loss': unreduced_loss,
'regularization_loss': regularization_loss,
- }
- ))
+ }))
# Train.
if optimizer is not None:
@@ -1506,6 +1514,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
train_op = train_op_fn(regularized_training_loss)
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
+ train_op = _append_update_ops(train_op)
# Only summarize mean_loss for SUM reduction to preserve backwards
# compatibility. Otherwise skip it to avoid unnecessary computation.
if self._loss_reduction == losses.Reduction.SUM:
@@ -1533,6 +1542,14 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
train_op=train_op)
+def _append_update_ops(train_op):
+ """Returns `train_op` appending `UPDATE_OPS` collection if present."""
+ update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
+ if update_ops:
+ return control_flow_ops.group(train_op, *update_ops)
+ return train_op
+
+
def _assert_range(labels, n_classes, message=None):
with ops.name_scope(None, 'assert_range', (labels,)):
assert_less = check_ops.assert_less_equal(
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index ecca3e8b0d..bd2e0ae943 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -969,6 +970,35 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
train_result)
+ def test_train_with_update_ops(self):
+ n_classes = 3
+ head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)
+
+ with ops.Graph().as_default():
+ w = variables.Variable(1)
+ update_op = w.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
+
+ t = variables.Variable('')
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return t.assign(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32),
+ labels=np.array(((1,), (1,)), dtype=np.int64),
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ sess.run(spec.train_op)
+ w_value, t_value = sess.run([w, t])
+ self.assertEqual(2, w_value)
+ self.assertEqual(expected_train_result, t_value)
+
def test_train_summaries_with_head_name(self):
n_classes = 3
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
@@ -2102,6 +2132,34 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertAllClose(expected_loss, loss)
self.assertEqual(expected_train_result, train_result)
+ def test_train_with_update_ops(self):
+ head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()
+
+ with ops.Graph().as_default():
+ w = variables.Variable(1)
+ update_op = w.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
+
+ t = variables.Variable('')
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return t.assign(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=np.array(((45,), (-41,),), dtype=np.float32),
+ labels=np.array(((1,), (1,),), dtype=np.float64),
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ sess.run(spec.train_op)
+ w_value, t_value = sess.run([w, t])
+ self.assertEqual(2, w_value)
+ self.assertEqual(expected_train_result, t_value)
+
def test_train_summaries_with_head_name(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
name='some_binary_head')
@@ -3045,8 +3103,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3082,6 +3142,9 @@ class RegressionHead(test.TestCase):
expected_metric_keys = [
'{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN),
+ '{}/some_regression_head'.format(
+ metric_keys.MetricKeys.PREDICTION_MEAN),
+ '{}/some_regression_head'.format(metric_keys.MetricKeys.LABEL_MEAN),
]
self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())
@@ -3112,6 +3175,8 @@ class RegressionHead(test.TestCase):
expected_metrics = {
keys.LOSS_MEAN: expected_unregularized_loss,
keys.LOSS_REGULARIZATION: expected_regularization_loss,
+ keys.PREDICTION_MEAN: (45 + 41) / 2.0,
+ keys.LABEL_MEAN: (43 + 44) / 2.0,
}
# Assert predictions, loss, and metrics.
@@ -3278,6 +3343,34 @@ class RegressionHead(test.TestCase):
self.assertAllClose(expected_loss, loss)
self.assertEqual(expected_train_result, train_result)
+ def test_train_with_update_ops(self):
+ head = head_lib._regression_head()
+
+ with ops.Graph().as_default():
+ w = variables.Variable(1)
+ update_op = w.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
+
+ t = variables.Variable('')
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return t.assign(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=np.array(((45,), (41,),), dtype=np.float32),
+ labels=np.array(((43.,), (44.,),), dtype=np.float64),
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ sess.run(spec.train_op)
+ w_value, t_value = sess.run([w, t])
+ self.assertEqual(2, w_value)
+ self.assertEqual(expected_train_result, t_value)
+
def test_train_summaries_with_head_name(self):
head = head_lib._regression_head(name='some_regression_head')
self.assertEqual(1, head.logits_dimension)
@@ -3385,8 +3478,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3614,8 +3709,10 @@ class RegressionHead(test.TestCase):
self.assertItemsEqual((prediction_key,), spec.predictions.keys())
self.assertEqual(dtypes.float32, spec.predictions[prediction_key].dtype)
self.assertEqual(dtypes.float32, spec.loss.dtype)
- self.assertItemsEqual(
- (metric_keys.MetricKeys.LOSS_MEAN,), spec.eval_metric_ops.keys())
+ self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN),
+ spec.eval_metric_ops.keys())
self.assertIsNone(spec.train_op)
self.assertIsNone(spec.export_outputs)
_assert_no_hooks(self, spec)
@@ -3746,7 +3843,13 @@ class RegressionHead(test.TestCase):
# losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]
# loss = sum(losses) = 100+.1+1.5 = 101.6
# loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923
- expected_metrics = {metric_keys.MetricKeys.LOSS_MEAN: 39.076923}
+ expected_metrics = {
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 39.076923,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ (45 + 41 * 0.1 + 44 * 1.5) / 2.6,
+ metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) / 2.6,
+ }
# Assert spec contains expected tensors.
self.assertEqual(dtypes.float32, spec.loss.dtype)
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 81657f0c01..58a7160348 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -33,7 +33,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -66,13 +66,15 @@ def _compute_fraction_of_zero(cols_to_vars):
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
-def _linear_logit_fn_builder(units, feature_columns):
+def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
"""Function builder for a linear logit_fn.
Args:
units: An int indicating the dimension of the logit layer.
feature_columns: An iterable containing all the feature columns used by
the model.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. One of "mean", "sqrtn", and "sum".
Returns:
A logit_fn (see below).
@@ -95,6 +97,7 @@ def _linear_logit_fn_builder(units, feature_columns):
features=features,
feature_columns=feature_columns,
units=units,
+ sparse_combiner=sparse_combiner,
cols_to_vars=cols_to_vars)
bias = cols_to_vars.pop('bias')
if units > 1:
@@ -111,7 +114,7 @@ def _linear_logit_fn_builder(units, feature_columns):
def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
- partitioner, config):
+ partitioner, config, sparse_combiner='sum'):
"""A model_fn for linear models that use a gradient-based optimizer.
Args:
@@ -126,6 +129,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
optimizer to use for training. If `None`, will use a FTRL optimizer.
partitioner: Partitioner for variables.
config: `RunConfig` object to configure the runtime settings.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. One of "mean", "sqrtn", and "sum".
Returns:
An `EstimatorSpec` instance.
@@ -153,7 +158,8 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
partitioner=partitioner):
logit_fn = _linear_logit_fn_builder(
- units=head.logits_dimension, feature_columns=feature_columns)
+ units=head.logits_dimension, feature_columns=feature_columns,
+ sparse_combiner=sparse_combiner)
logits = logit_fn(features=features)
return head.create_estimator_spec(
@@ -164,7 +170,7 @@ def _linear_model_fn(features, labels, mode, head, feature_columns, optimizer,
logits=logits)
-@tf_export('estimator.LinearClassifier')
+@estimator_export('estimator.LinearClassifier')
class LinearClassifier(estimator.Estimator):
"""Linear classifier model.
@@ -193,6 +199,17 @@ class LinearClassifier(estimator.Estimator):
l1_regularization_strength=0.001
))
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearClassifier(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=lambda: tf.train.FtrlOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearClassifier(
feature_columns=[categorical_column_a,
@@ -227,7 +244,10 @@ class LinearClassifier(estimator.Estimator):
Loss is calculated by using softmax cross entropy.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -241,7 +261,8 @@ class LinearClassifier(estimator.Estimator):
config=None,
partitioner=None,
warm_start_from=None,
- loss_reduction=losses.Reduction.SUM):
+ loss_reduction=losses.Reduction.SUM,
+ sparse_combiner='sum'):
"""Construct a `LinearClassifier` estimator object.
Args:
@@ -269,8 +290,9 @@ class LinearClassifier(estimator.Estimator):
encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 .
Also there will be errors if vocabulary is not provided and labels are
string.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
@@ -280,6 +302,11 @@ class LinearClassifier(estimator.Estimator):
and Tensor names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. One of "mean", "sqrtn", and "sum" -- these are
+ effectively different ways to do example-level normalization, which can
+ be useful for bag-of-words features. for more details, see
+ @{tf.feature_column.linear_model$linear_model}.
Returns:
A `LinearClassifier` estimator.
@@ -308,7 +335,8 @@ class LinearClassifier(estimator.Estimator):
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
- config=config)
+ config=config,
+ sparse_combiner=sparse_combiner)
super(LinearClassifier, self).__init__(
model_fn=_model_fn,
@@ -317,7 +345,7 @@ class LinearClassifier(estimator.Estimator):
warm_start_from=warm_start_from)
-@tf_export('estimator.LinearRegressor')
+@estimator_export('estimator.LinearRegressor')
class LinearRegressor(estimator.Estimator):
"""An estimator for TensorFlow Linear regression problems.
@@ -332,10 +360,31 @@ class LinearRegressor(estimator.Estimator):
categorical_feature_a_x_categorical_feature_b = crossed_column(...)
+ # Estimator using the default optimizer.
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
categorical_feature_a_x_categorical_feature_b])
+ # Or estimator using the FTRL optimizer with regularization.
+ estimator = LinearRegressor(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=tf.train.FtrlOptimizer(
+ learning_rate=0.1,
+ l1_regularization_strength=0.001
+ ))
+
+ # Or estimator using an optimizer with a learning rate decay.
+ estimator = LinearRegressor(
+ feature_columns=[categorical_column_a,
+ categorical_feature_a_x_categorical_feature_b],
+ optimizer=lambda: tf.train.FtrlOptimizer(
+ learning_rate=tf.exponential_decay(
+ learning_rate=0.1,
+ global_step=tf.get_global_step(),
+ decay_steps=10000,
+ decay_rate=0.96))
+
# Or estimator with warm-starting from a previous checkpoint.
estimator = LinearRegressor(
feature_columns=[categorical_column_a,
@@ -370,7 +419,10 @@ class LinearRegressor(estimator.Estimator):
Loss is calculated by using mean squared error.
@compatibility(eager)
- Estimators are not compatible with eager execution.
+ Estimators can be used while eager execution is enabled. Note that `input_fn`
+ and all hooks are executed inside a graph context, so they have to be written
+ to be compatible with graph mode. Note that `input_fn` code using `tf.data`
+ generally works in both graph and eager modes.
@end_compatibility
"""
@@ -383,7 +435,8 @@ class LinearRegressor(estimator.Estimator):
config=None,
partitioner=None,
warm_start_from=None,
- loss_reduction=losses.Reduction.SUM):
+ loss_reduction=losses.Reduction.SUM,
+ sparse_combiner='sum'):
"""Initializes a `LinearRegressor` instance.
Args:
@@ -403,8 +456,9 @@ class LinearRegressor(estimator.Estimator):
used as a key to fetch weight tensor from the `features`. If it is a
`_NumericColumn`, raw tensor is fetched by key `weight_column.key`,
then weight_column.normalizer_fn is applied on it to get weight tensor.
- optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
- to FTRL optimizer.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Can also
+ be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or
+ callable. Defaults to FTRL optimizer.
config: `RunConfig` object to configure the runtime settings.
partitioner: Optional. Partitioner for input layer.
warm_start_from: A string filepath to a checkpoint to warm-start from, or
@@ -414,6 +468,11 @@ class LinearRegressor(estimator.Estimator):
and Tensor names are unchanged.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
to reduce training loss over batch. Defaults to `SUM`.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. One of "mean", "sqrtn", and "sum" -- these are
+ effectively different ways to do example-level normalization, which can
+ be useful for bag-of-words features. for more details, see
+ @{tf.feature_column.linear_model$linear_model}.
"""
head = head_lib._regression_head( # pylint: disable=protected-access
label_dimension=label_dimension, weight_column=weight_column,
@@ -429,7 +488,8 @@ class LinearRegressor(estimator.Estimator):
feature_columns=tuple(feature_columns or []),
optimizer=optimizer,
partitioner=partitioner,
- config=config)
+ config=config,
+ sparse_combiner=sparse_combiner)
super(LinearRegressor, self).__init__(
model_fn=_model_fn,
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 0e6436b421..c3934c7a80 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -29,6 +29,7 @@ import six
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator.canned import linear
@@ -260,6 +261,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 9.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -285,6 +288,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 18.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -315,6 +320,8 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertDictEqual({
metric_keys.MetricKeys.LOSS: 27.,
metric_keys.MetricKeys.LOSS_MEAN: 9.,
+ metric_keys.MetricKeys.PREDICTION_MEAN: 13.,
+ metric_keys.MetricKeys.LABEL_MEAN: 10.,
ops.GraphKeys.GLOBAL_STEP: 100
}, eval_metrics)
@@ -345,7 +352,9 @@ class BaseLinearRegressorEvaluationTest(object):
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is
# [2., 4., 5.] * [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0]
@@ -382,7 +391,9 @@ class BaseLinearRegressorEvaluationTest(object):
eval_metrics = est.evaluate(input_fn=input_fn, steps=1)
self.assertItemsEqual(
(metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
- ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+ metric_keys.MetricKeys.PREDICTION_MEAN,
+ metric_keys.MetricKeys.LABEL_MEAN, ops.GraphKeys.GLOBAL_STEP),
+ eval_metrics.keys())
# Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =
# [213.0, 421.0], while label is [213., 421.]. Loss = 0.
@@ -484,6 +495,69 @@ class BaseLinearRegressorPredictTest(object):
# x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2
self.assertAllClose([[80.2]], predicted_scores)
+ def testSparseCombiner(self):
+ w_a = 2.0
+ w_b = 3.0
+ w_c = 5.0
+ bias = 5.0
+ with ops.Graph().as_default():
+ variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)
+ variables_lib.Variable([bias], name=BIAS_NAME)
+ variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors({
+ 'language': sparse_tensor.SparseTensor(
+ values=['a', 'c', 'b', 'c'],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ })
+
+ feature_columns = (
+ feature_column_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
+
+ # Check prediction for each sparse_combiner.
+ # With sparse_combiner = 'sum', we have
+ # logits_1 = w_a + w_c + bias
+ # = 2.0 + 5.0 + 5.0 = 12.0
+ # logits_2 = w_b + w_c + bias
+ # = 3.0 + 5.0 + 5.0 = 13.0
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[12.0], [13.0]], predicted_scores)
+
+ # With sparse_combiner = 'mean', we have
+ # logits_1 = 1/2 * (w_a + w_c) + bias
+ # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5
+ # logits_2 = 1/2 * (w_b + w_c) + bias
+ # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='mean')
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[8.5], [9.0]], predicted_scores)
+
+ # With sparse_combiner = 'sqrtn', we have
+ # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias
+ # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974
+ # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias
+ # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685
+ linear_regressor = self._linear_regressor_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='sqrtn')
+ predictions = linear_regressor.predict(input_fn=_input_fn)
+ predicted_scores = list([x['predictions'] for x in predictions])
+ self.assertAllClose([[9.94974], [10.65685]], predicted_scores)
+
class BaseLinearRegressorIntegrationTest(object):
@@ -1636,6 +1710,69 @@ class BaseLinearClassifierPredictTest(object):
for i in range(n_classes)],
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+ def testSparseCombiner(self):
+ w_a = 2.0
+ w_b = 3.0
+ w_c = 5.0
+ bias = 5.0
+ with ops.Graph().as_default():
+ variables_lib.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)
+ variables_lib.Variable([bias], name=BIAS_NAME)
+ variables_lib.Variable(1, name=ops.GraphKeys.GLOBAL_STEP,
+ dtype=dtypes.int64)
+ save_variables_to_ckpt(self._model_dir)
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors({
+ 'language': sparse_tensor.SparseTensor(
+ values=['a', 'c', 'b', 'c'],
+ indices=[[0, 0], [0, 1], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ })
+
+ feature_columns = (
+ feature_column_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
+
+ # Check prediction for each sparse_combiner.
+ # With sparse_combiner = 'sum', we have
+ # logits_1 = w_a + w_c + bias
+ # = 2.0 + 5.0 + 5.0 = 12.0
+ # logits_2 = w_b + w_c + bias
+ # = 3.0 + 5.0 + 5.0 = 13.0
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[12.0], [13.0]], predicted_scores)
+
+ # With sparse_combiner = 'mean', we have
+ # logits_1 = 1/2 * (w_a + w_c) + bias
+ # = 1/2 * (2.0 + 5.0) + 5.0 = 8.5
+ # logits_2 = 1/2 * (w_b + w_c) + bias
+ # = 1/2 * (3.0 + 5.0) + 5.0 = 9.0
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='mean')
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[8.5], [9.0]], predicted_scores)
+
+ # With sparse_combiner = 'sqrtn', we have
+ # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias
+ # = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974
+ # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias
+ # = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685
+ linear_classifier = self._linear_classifier_fn(
+ feature_columns=feature_columns,
+ model_dir=self._model_dir,
+ sparse_combiner='sqrtn')
+ predictions = linear_classifier.predict(input_fn=_input_fn)
+ predicted_scores = list([x['logits'] for x in predictions])
+ self.assertAllClose([[9.94974], [10.65685]], predicted_scores)
+
class BaseLinearClassifierIntegrationTest(object):
diff --git a/tensorflow/python/estimator/canned/optimizers.py b/tensorflow/python/estimator/canned/optimizers.py
index f72c5ca5cb..8f51cc3a80 100644
--- a/tensorflow/python/estimator/canned/optimizers.py
+++ b/tensorflow/python/estimator/canned/optimizers.py
@@ -72,6 +72,8 @@ def get_optimizer_instance(opt, learning_rate=None):
raise ValueError(
'Unsupported optimizer name: {}. Supported names are: {}'.format(
opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES)))))
+ if callable(opt):
+ opt = opt()
if not isinstance(opt, optimizer_lib.Optimizer):
raise ValueError(
'The given object is not an Optimizer instance. Given: {}'.format(opt))
diff --git a/tensorflow/python/estimator/canned/optimizers_test.py b/tensorflow/python/estimator/canned/optimizers_test.py
index ee28756155..eadabdbc49 100644
--- a/tensorflow/python/estimator/canned/optimizers_test.py
+++ b/tensorflow/python/estimator/canned/optimizers_test.py
@@ -28,6 +28,13 @@ from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import rmsprop
+class _TestOptimizer(optimizer_lib.Optimizer):
+
+ def __init__(self):
+ super(_TestOptimizer, self).__init__(
+ use_locking=False, name='TestOptimizer')
+
+
class GetOptimizerInstance(test.TestCase):
def test_unsupported_name(self):
@@ -66,12 +73,6 @@ class GetOptimizerInstance(test.TestCase):
self.assertAlmostEqual(0.1, opt._learning_rate)
def test_object(self):
- class _TestOptimizer(optimizer_lib.Optimizer):
-
- def __init__(self):
- super(_TestOptimizer, self).__init__(
- use_locking=False, name='TestOptimizer')
-
opt = optimizers.get_optimizer_instance(_TestOptimizer())
self.assertIsInstance(opt, _TestOptimizer)
@@ -80,6 +81,23 @@ class GetOptimizerInstance(test.TestCase):
ValueError, 'The given object is not an Optimizer instance'):
optimizers.get_optimizer_instance((1, 2, 3))
+ def test_callable(self):
+ def _optimizer_fn():
+ return _TestOptimizer()
+ opt = optimizers.get_optimizer_instance(_optimizer_fn)
+ self.assertIsInstance(opt, _TestOptimizer)
+
+ def test_lambda(self):
+ opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer()) # pylint: disable=unnecessary-lambda
+ self.assertIsInstance(opt, _TestOptimizer)
+
+ def test_callable_returns_invalid(self):
+ def _optimizer_fn():
+ return (1, 2, 3)
+ with self.assertRaisesRegexp(
+ ValueError, 'The given object is not an Optimizer instance'):
+ optimizers.get_optimizer_instance(_optimizer_fn)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/canned/parsing_utils.py b/tensorflow/python/estimator/canned/parsing_utils.py
index 74e5e5a1be..1ae0f1e9f7 100644
--- a/tensorflow/python/estimator/canned/parsing_utils.py
+++ b/tensorflow/python/estimator/canned/parsing_utils.py
@@ -23,10 +23,10 @@ import six
from tensorflow.python.feature_column import feature_column as fc
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.classifier_parse_example_spec')
+@estimator_export('estimator.classifier_parse_example_spec')
def classifier_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.int64,
@@ -166,7 +166,7 @@ def classifier_parse_example_spec(feature_columns,
return parsing_spec
-@tf_export('estimator.regressor_parse_example_spec')
+@estimator_export('estimator.regressor_parse_example_spec')
def regressor_parse_example_spec(feature_columns,
label_key,
label_dtype=dtypes.float32,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 4f57a4ef79..350a95eea1 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -38,6 +38,7 @@ from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -66,14 +67,14 @@ from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_VALID_MODEL_FN_ARGS = set(
['features', 'labels', 'mode', 'params', 'self', 'config'])
-@tf_export('estimator.Estimator')
+@estimator_export('estimator.Estimator')
class Estimator(object):
"""Estimator class to train and evaluate TensorFlow models.
@@ -103,6 +104,15 @@ class Estimator(object):
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
+
+ @compatbility(eager)
+ Calling methods of `Estimator` will work while eager execution is enabled.
+ However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`
+ will switch to graph model before calling all user-provided functions (incl.
+ hooks), so their code has to be compatible with graph mode execution. Note
+ that `input_fn` code using `tf.data` generally works in both graph and eager
+ modes.
+ @end_compatibility
"""
def __init__(self, model_fn, model_dir=None, config=None, params=None,
@@ -566,7 +576,8 @@ class Estimator(object):
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_validate_features_in_predict_input',
+ '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
@@ -838,7 +849,8 @@ class Estimator(object):
strip_default_attrs,
save_variables=True,
mode=model_fn_lib.ModeKeys.PREDICT,
- export_tags=None):
+ export_tags=None,
+ check_variables=True):
# pylint: disable=line-too-long
"""Loads variables and adds them along with a MetaGraphDef for saving.
@@ -859,6 +871,10 @@ class Estimator(object):
mode: tf.estimator.ModeKeys value indicating which mode will be exported.
export_tags: The set of tags with which to save `MetaGraphDef`. If None,
a default set will be selected to matched the passed mode.
+ check_variables: bool, whether to check the checkpoint has all variables.
+
+ Raises:
+ ValueError: if `save_variables` is `True` and `check_variable` is `False`.
"""
# pylint: enable=line-too-long
if export_tags is None:
@@ -893,19 +909,26 @@ class Estimator(object):
estimator_spec.scaffold.local_init_op or
monitored_session.Scaffold.default_local_init_op())
- saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
- sharded=True)
-
- try:
- saver_for_restore.restore(session, checkpoint_path)
- except errors.NotFoundError as e:
- msg = ('Could not load all requested variables from the checkpoint. '
- 'Please make sure your model_fn does not expect variables '
- 'that were not saved in the checkpoint.\n\n'
- 'Encountered error with mode `{}` while restoring checkpoint '
- 'from: `{}`. Full Traceback:\n\n{}').format(
- mode, checkpoint_path, e)
- raise ValueError(msg)
+ # This saver will be used both for restoring variables now,
+ # and in saving out the metagraph below. This ensures that any
+ # Custom Savers stored with the Scaffold are passed through to the
+ # SavedModel for restore later.
+ graph_saver = estimator_spec.scaffold.saver or saver.Saver(sharded=True)
+
+ if save_variables and not check_variables:
+ raise ValueError('If `save_variables` is `True, `check_variables`'
+ 'must not be `False`.')
+ if check_variables:
+ try:
+ graph_saver.restore(session, checkpoint_path)
+ except errors.NotFoundError as e:
+ msg = ('Could not load all requested variables from checkpoint. '
+ 'Please make sure your model_fn does not expect variables '
+ 'that were not saved in the checkpoint.\n\n'
+ 'Encountered error with mode `{}` while restoring '
+ 'checkpoint from: `{}`. Full Traceback:\n\n{}').format(
+ mode, checkpoint_path, e)
+ raise ValueError(msg)
# We add the train op explicitly for now, so that we don't have to
# change the Builder public interface. Note that this is a no-op
@@ -918,7 +941,8 @@ class Estimator(object):
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
strip_default_attrs=strip_default_attrs,
- legacy_init_op=local_init_op)
+ legacy_init_op=local_init_op,
+ saver=graph_saver)
if save_variables:
builder.add_meta_graph_and_variables(
@@ -1119,6 +1143,18 @@ class Estimator(object):
return self._train_model_default(input_fn, hooks, saving_listeners)
def _train_model_default(self, input_fn, hooks, saving_listeners):
+ """Initiate training with input_fn, without DistributionStrategies.
+
+ Args:
+ input_fn: A function that provides input data for training as minibatches.
+ hooks: List of `SessionRunHook` subclass instances. Used for callbacks
+ inside the training loop.
+ saving_listeners: list of `CheckpointSaverListener` objects. Used for
+ callbacks that run immediately before or after checkpoint savings.
+
+ Returns:
+ Loss from training
+ """
worker_hooks = []
with ops.Graph().as_default() as g, g.device(self._device_fn):
random_seed.set_random_seed(self._config.tf_random_seed)
@@ -1135,29 +1171,86 @@ class Estimator(object):
saving_listeners)
def _train_model_distributed(self, input_fn, hooks, saving_listeners):
+ """Initiate training with input_fn, using DistributionStrategies.
+
+ Args:
+ input_fn: A function that provides input data for training as minibatches.
+ hooks: List of `SessionRunHook` subclass instances. Used for callbacks
+ inside the training loop.
+ saving_listeners: list of `CheckpointSaverListener` objects. Used for
+ callbacks that run immediately before or after checkpoint savings.
+
+ Returns:
+ Loss from training
+ """
self._distribution.configure(self._session_config)
+
+ # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies
+ # to use the new API
+ is_tpu_strategy = self._distribution.__class__.__name__ == 'TPUStrategy'
+
worker_hooks = []
with ops.Graph().as_default() as g:
with self._distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.TRAIN))
- worker_hooks.extend(input_hooks)
- global_step_tensor = self._create_and_assert_global_step(g)
- # The default destination for the global_step_tensor fetch call is the
- # CPU.
- global_step_read_tensor = self._distribution.fetch(global_step_tensor)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
- global_step_read_tensor)
- grouped_estimator_spec = self._distribution.call_for_each_tower(
- self._call_model_fn,
- features,
- labels, # although this will be None it seems
- model_fn_lib.ModeKeys.TRAIN,
- self.config)
+
+ if is_tpu_strategy:
+ # Create the iterator for run_on_dataset function
+ # TODO(sourabhbajaj): refactor this out to call a function on the
+ # strategy
+ dataset = self._distribution.distribute_dataset(
+ lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda
+ model_fn_lib.ModeKeys.TRAIN))
+ iterator = dataset.make_initializable_iterator()
+ worker_hooks.append(
+ estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access
+
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # we want to add to the global collection in the main thread not the
+ # tower threads.
+ ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
+ self._distribution.read_var(global_step_tensor))
+
+ # Create a step_fn from the train_op of grouped_estimator_spec
+ def step_fn(ctx, inputs):
+ """A single step that is passed to run_on_dataset."""
+ features, labels = inputs
+ estimator_spec = self._distribution.call_for_each_tower(
+ self._call_model_fn,
+ features,
+ labels,
+ model_fn_lib.ModeKeys.TRAIN,
+ self.config)
+ ctx.last_step_outputs = estimator_spec.loss
+ ctx.non_tensor_outputs = {'estimator_spec': estimator_spec}
+ with ops.control_dependencies([estimator_spec.train_op]):
+ return array_ops.identity(estimator_spec.loss)
+
+ # Create new train_op post graph rewrites
+ # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations
+ # work correctly. Currently hardcoded at 2
+ initial_training_loss = constant_op.constant(1e7)
+ distributed_train_op, tpu_result, ctx = \
+ self._distribution._run_steps_on_dataset( # pylint: disable=protected-access
+ step_fn, iterator, iterations=2,
+ initial_loop_values=initial_training_loss)
+ grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
+ else:
+ features, labels, input_hooks = (
+ self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN))
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # we want to add to the global collection in the main thread not the
+ # tower threads.
+ ops.add_to_collection(training_util.GLOBAL_STEP_READ_KEY,
+ self._distribution.read_var(global_step_tensor))
+ grouped_estimator_spec = self._distribution.call_for_each_tower(
+ self._call_model_fn,
+ features,
+ labels, # although this will be None it seems
+ model_fn_lib.ModeKeys.TRAIN,
+ self.config)
# TODO(anjalisridhar): Figure out how to resolve the following scaffold
# parameters: init_feed_dict, init_fn.
@@ -1185,10 +1278,16 @@ class Estimator(object):
else:
init_op = None
+ def _unwrap_and_concat(value):
+ value = nest.flatten(self._distribution.unwrap(value))
+ if len(value) != 1:
+ return array_ops.concat(value)
+ return value[0]
+
ready_op = self._distribution.call_for_each_tower(
create_per_tower_ready_op, grouped_estimator_spec.scaffold)
if ready_op is not None:
- ready_op = self._distribution.group(ready_op)
+ ready_op = _unwrap_and_concat(ready_op)
else:
ready_op = None
@@ -1196,8 +1295,7 @@ class Estimator(object):
create_per_tower_ready_for_local_init_op,
grouped_estimator_spec.scaffold)
if ready_for_local_init_op is not None:
- ready_for_local_init_op = self._distribution.group(
- ready_for_local_init_op)
+ ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op)
else:
ready_for_local_init_op = None
@@ -1238,18 +1336,33 @@ class Estimator(object):
training_chief_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_chief_hooks)
+ # TODO(sourabhbajaj): Merge the two code paths once we can
+ # handle per device variables correctly in reduce and can output
+ # the loss scaler.
+ if is_tpu_strategy:
+ loss = self._distribution.unwrap(
+ self._distribution.reduce(distribute_lib.get_loss_reduction(),
+ tpu_result)[0])[0]
+ worker_hooks.append(
+ estimator_util.StrategyInitFinalizeHook(
+ self._distribution.get_initialization_ops,
+ self._distribution.get_finalize_ops))
+ else:
+ loss = self._distribution.unwrap(
+ self._distribution.reduce(distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0]
+ distributed_train_op = grouped_estimator_spec.train_op
+
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
- loss=self._distribution.unwrap(
- self._distribution.reduce(distribute_lib.get_loss_reduction(),
- grouped_estimator_spec.loss,
- destinations='/device:CPU:0'))[0],
- train_op=self._distribution.group(grouped_estimator_spec.train_op),
+ loss=loss,
+ train_op=self._distribution.group(distributed_train_op),
training_hooks=training_hooks,
training_chief_hooks=training_chief_hooks,
scaffold=scaffold)
return self._train_with_estimator_spec(estimator_spec, worker_hooks,
- hooks, global_step_read_tensor,
+ hooks, global_step_tensor,
saving_listeners)
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
@@ -1630,11 +1743,12 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+
VocabInfo = warm_starting_util.VocabInfo # pylint: disable=invalid-name
-tf_export('estimator.VocabInfo', allow_multiple_exports=True)(VocabInfo)
+estimator_export('estimator.VocabInfo')(VocabInfo)
-@tf_export('estimator.WarmStartSettings')
+@estimator_export('estimator.WarmStartSettings')
class WarmStartSettings(
collections.namedtuple('WarmStartSettings', [
'ckpt_to_initialize_from',
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 9c0d0f7390..2a0e4e7617 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
@@ -61,6 +62,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
@@ -100,6 +102,11 @@ def check_eventfile_for_keyword(keyword, dir_):
return any(summaries_with_matching_keyword(keyword, dir_))
+def get_mock_saver():
+ real_saver = saver.Saver()
+ return test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
+
+
class EstimatorInheritanceConstraintTest(test.TestCase):
"""Tests that sub classes cannot override methods of Estimator."""
@@ -1290,14 +1297,37 @@ class EstimatorEvaluateTest(test.TestCase):
dummy_input_fn, steps=1, checkpoint_path=est1.latest_checkpoint())
self.assertEqual(5, scores['global_step'])
+ def test_wrong_shape_throws_reasonable_error(self):
+ """Make sure we are helpful when model_fns change. See b/110263146."""
+ def _get_model_fn(val=1):
+ def _model_fn(features, labels, mode):
+ del features, labels # unused
+ variables.Variable(val, name='weight')
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ predictions=constant_op.constant([[1.]]),
+ loss=constant_op.constant(0.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+ return _model_fn
+
+ model_fn_1 = _get_model_fn()
+ model_fn_2 = _get_model_fn(val=[1])
+
+ est1 = estimator.Estimator(model_fn=model_fn_1)
+ est1.train(dummy_input_fn, steps=5)
+ est2 = estimator.Estimator(
+ model_fn=model_fn_2, model_dir=est1.model_dir)
+
+ expected_msg = 'Restoring from checkpoint failed.*a mismatch between'
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_msg):
+ est2.train(dummy_input_fn, steps=1,)
+
def test_scaffold_is_used(self):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -1819,9 +1849,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -2315,8 +2343,8 @@ class EstimatorExportTest(test.TestCase):
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
- # Note that the SavedModel builder replaced the Saver with a new one
- self.assertTrue('save_1/LookupTableImportV2' in graph_ops)
+ # The original saver is used to restore variables
+ self.assertTrue('save/LookupTableImportV2' in graph_ops)
# Clean up.
gfile.DeleteRecursively(tmpdir)
@@ -2481,9 +2509,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+ self.mock_saver = get_mock_saver()
scores = constant_op.constant([3.])
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -2506,19 +2532,24 @@ class EstimatorExportTest(test.TestCase):
est.export_savedmodel(export_dir_base, serving_input_receiver_fn)
self.assertTrue(self.mock_saver.restore.called)
+ self.assertTrue(self.mock_saver.export_meta_graph.called)
+ self.assertTrue(self.mock_saver.save.called)
def test_scaffold_is_used_for_saver_multiple_modes(self):
tmpdir = tempfile.mkdtemp()
+ savers = {'predict_saver': None, 'train_saver': None}
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
variables.Variable(1., name='weight')
- real_saver = saver.Saver()
- self.mock_saver = test.mock.Mock(
- wraps=real_saver, saver_def=real_saver.saver_def)
+
scores = constant_op.constant([3.])
if mode == model_fn_lib.ModeKeys.PREDICT:
- scaffold = training.Scaffold(saver=self.mock_saver)
+ savers['predict_saver'] = get_mock_saver()
+ scaffold = training.Scaffold(saver=savers['predict_saver'])
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ savers['train_saver'] = get_mock_saver()
+ scaffold = training.Scaffold(saver=savers['train_saver'])
else:
scaffold = training.Scaffold()
return model_fn_lib.EstimatorSpec(
@@ -2542,7 +2573,13 @@ class EstimatorExportTest(test.TestCase):
compat.as_bytes(tmpdir), compat.as_bytes('export'))
est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
- self.assertTrue(self.mock_saver.restore.called)
+ self.assertTrue(savers['train_saver'].restore.called)
+ self.assertEqual(savers['train_saver'].export_meta_graph.call_count, 1)
+ self.assertEqual(savers['train_saver'].save.call_count, 1)
+
+ self.assertTrue(savers['predict_saver'].restore.called)
+ self.assertEqual(savers['predict_saver'].export_meta_graph.call_count, 1)
+ self.assertEqual(savers['predict_saver'].save.call_count, 0)
def test_scaffold_is_used_for_local_init(self):
tmpdir = tempfile.mkdtemp()
@@ -2819,6 +2856,45 @@ class EstimatorExportTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_export_savedmodel_no_export_outputs(self):
+ """Ensure that an EstimatorSpec without outputs defined can be exported."""
+
+ def _model_fn(features, labels, mode):
+ _, _ = features, labels
+ variables.Variable(1., name='weight')
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(input_fn=dummy_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('no_export_outputs'))
+ export_dir = est.export_savedmodel(
+ export_dir_base, _get_serving_input_receiver_fn())
+
+ # Check that all the files are in the right places.
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self._validate_exported_files(export_dir)
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('weight' in graph_ops)
+
+ sig_def = meta_graph.signature_def
+ self.assertEqual(len(sig_def), 1)
+ sig_outputs = sig_def[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs
+ self.assertEqual(sig_outputs['output'].name, 'Const:0')
+
class EstimatorHookOrderingTest(test.TestCase):
@@ -2863,7 +2939,7 @@ class EstimatorHookOrderingTest(test.TestCase):
class EstimatorIntegrationTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_complete_flow_with_a_simple_linear_model(self):
def _model_fn(features, labels, mode):
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index ff19a0a7f4..ca26341445 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@@ -93,7 +93,7 @@ def _check_tensor_key(name, error_label='feature'):
raise ValueError('{} keys must be strings: {}.'.format(error_label, name))
-@tf_export('estimator.export.ServingInputReceiver')
+@estimator_export('estimator.export.ServingInputReceiver')
class ServingInputReceiver(
collections.namedtuple(
'ServingInputReceiver',
@@ -161,7 +161,7 @@ class ServingInputReceiver(
receiver_tensors_alternatives=receiver_tensors_alternatives)
-@tf_export('estimator.export.TensorServingInputReceiver')
+@estimator_export('estimator.export.TensorServingInputReceiver')
class TensorServingInputReceiver(
collections.namedtuple(
'TensorServingInputReceiver',
@@ -263,7 +263,7 @@ class SupervisedInputReceiver(
receiver_tensors=receiver_tensors)
-@tf_export('estimator.export.build_parsing_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn')
def build_parsing_serving_input_receiver_fn(feature_spec,
default_batch_size=None):
"""Build a serving_input_receiver_fn expecting fed tf.Examples.
@@ -313,7 +313,7 @@ def _placeholders_from_receiver_tensors_dict(input_vals,
}
-@tf_export('estimator.export.build_raw_serving_input_receiver_fn')
+@estimator_export('estimator.export.build_raw_serving_input_receiver_fn')
def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""Build a serving_input_receiver_fn expecting feature Tensors.
@@ -333,11 +333,7 @@ def build_raw_serving_input_receiver_fn(features, default_batch_size=None):
"""A serving_input_receiver_fn that expects features to be fed directly."""
receiver_tensors = _placeholders_from_receiver_tensors_dict(
features, default_batch_size)
-
- # TODO(b/34885899): remove the unnecessary copy
- # The features provided are simply the placeholders, but we defensively copy
- # the dict because it may be mutated.
- return ServingInputReceiver(receiver_tensors, receiver_tensors.copy())
+ return ServingInputReceiver(receiver_tensors, receiver_tensors)
return serving_input_receiver_fn
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index d387ea2940..6c26d29985 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,10 +26,10 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.export.ExportOutput')
+@estimator_export('estimator.export.ExportOutput')
class ExportOutput(object):
"""Represents an output of a model that can be served.
@@ -100,7 +100,7 @@ class ExportOutput(object):
return output_dict
-@tf_export('estimator.export.ClassificationOutput')
+@estimator_export('estimator.export.ClassificationOutput')
class ClassificationOutput(ExportOutput):
"""Represents the output of a classification head.
@@ -169,7 +169,7 @@ class ClassificationOutput(ExportOutput):
examples, self.classes, self.scores)
-@tf_export('estimator.export.RegressionOutput')
+@estimator_export('estimator.export.RegressionOutput')
class RegressionOutput(ExportOutput):
"""Represents the output of a regression head."""
@@ -202,7 +202,7 @@ class RegressionOutput(ExportOutput):
return signature_def_utils.regression_signature_def(examples, self.value)
-@tf_export('estimator.export.PredictOutput')
+@estimator_export('estimator.export.PredictOutput')
class PredictOutput(ExportOutput):
"""Represents the output of a generic prediction head.
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 766ea23f2a..b18212cfcd 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -28,10 +28,10 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.Exporter')
+@estimator_export('estimator.Exporter')
class Exporter(object):
"""A class representing a type of model export."""
@@ -172,7 +172,7 @@ def _verify_compare_fn_args(compare_fn):
(compare_fn, non_valid_args))
-@tf_export('estimator.BestExporter')
+@estimator_export('estimator.BestExporter')
class BestExporter(Exporter):
"""This class exports the serving graph and checkpoints of the best models.
@@ -367,7 +367,7 @@ class BestExporter(Exporter):
return best_eval_result
-@tf_export('estimator.FinalExporter')
+@estimator_export('estimator.FinalExporter')
class FinalExporter(Exporter):
"""This class exports the serving graph and checkpoints in the end.
@@ -418,7 +418,7 @@ class FinalExporter(Exporter):
is_the_final_export)
-@tf_export('estimator.LatestExporter')
+@estimator_export('estimator.LatestExporter')
class LatestExporter(Exporter):
"""This class regularly exports the serving graph and checkpoints.
diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py
index eefc7c712d..a6cefdece2 100644
--- a/tensorflow/python/estimator/inputs/numpy_io.py
+++ b/tensorflow/python/estimator/inputs/numpy_io.py
@@ -24,7 +24,7 @@ import numpy as np
from six import string_types
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
# Key name to pack the target into dict of `features`. See
# `_get_unique_target_key` for details.
@@ -87,7 +87,7 @@ def _validate_and_convert_features(x):
return ordered_dict_data
-@tf_export('estimator.inputs.numpy_input_fn')
+@estimator_export('estimator.inputs.numpy_input_fn')
def numpy_input_fn(x,
y=None,
batch_size=128,
diff --git a/tensorflow/python/estimator/inputs/pandas_io.py b/tensorflow/python/estimator/inputs/pandas_io.py
index 1ed6ed4d84..616bcb410f 100644
--- a/tensorflow/python/estimator/inputs/pandas_io.py
+++ b/tensorflow/python/estimator/inputs/pandas_io.py
@@ -18,10 +18,12 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+import uuid
import numpy as np
from tensorflow.python.estimator.inputs.queues import feeding_functions
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
try:
# pylint: disable=g-import-not-at-top
@@ -35,7 +37,23 @@ except ImportError:
HAS_PANDAS = False
-@tf_export('estimator.inputs.pandas_input_fn')
+def _get_unique_target_key(features, target_column_name):
+ """Returns a key that does not exist in the input DataFrame `features`.
+
+ Args:
+ features: DataFrame
+ target_column_name: Name of the target column as a `str`
+
+ Returns:
+ A unique key that can be used to insert the target into
+ features.
+ """
+ if target_column_name in features:
+ target_column_name += '_' + str(uuid.uuid4())
+ return target_column_name
+
+
+@estimator_export('estimator.inputs.pandas_input_fn')
def pandas_input_fn(x,
y=None,
batch_size=128,
@@ -50,7 +68,7 @@ def pandas_input_fn(x,
Args:
x: pandas `DataFrame` object.
- y: pandas `Series` object. `None` if absent.
+ y: pandas `Series` object or `DataFrame`. `None` if absent.
batch_size: int, size of batches to return.
num_epochs: int, number of epochs to iterate over data. If not `None`,
read attempts that would exceed this value will raise `OutOfRangeError`.
@@ -60,7 +78,8 @@ def pandas_input_fn(x,
num_threads: Integer, number of threads used for reading and enqueueing. In
order to have predicted and repeatable order of reading and enqueueing,
such as in prediction and evaluation mode, `num_threads` should be 1.
- target_column: str, name to give the target column `y`.
+ target_column: str, name to give the target column `y`. This parameter
+ is not used when `y` is a `DataFrame`.
Returns:
Function, that has signature of ()->(dict of `features`, `target`)
@@ -79,6 +98,9 @@ def pandas_input_fn(x,
'(it is recommended to set it as True for training); '
'got {}'.format(shuffle))
+ if not isinstance(target_column, six.string_types):
+ raise TypeError('target_column must be a string type')
+
x = x.copy()
if y is not None:
if target_column in x:
@@ -88,7 +110,13 @@ def pandas_input_fn(x,
if not np.array_equal(x.index, y.index):
raise ValueError('Index for x and y are mismatched.\nIndex for x: %s\n'
'Index for y: %s\n' % (x.index, y.index))
- x[target_column] = y
+ if isinstance(y, pd.DataFrame):
+ y_columns = [(column, _get_unique_target_key(x, column))
+ for column in list(y)]
+ target_column = [v for _, v in y_columns]
+ x[target_column] = y
+ else:
+ x[target_column] = y
# TODO(mdan): These are memory copies. We probably don't need 4x slack space.
# The sizes below are consistent with what I've seen elsewhere.
@@ -118,7 +146,12 @@ def pandas_input_fn(x,
features = features[1:]
features = dict(zip(list(x.columns), features))
if y is not None:
- target = features.pop(target_column)
+ if isinstance(target_column, list):
+ keys = [k for k, _ in y_columns]
+ values = [features.pop(column) for column in target_column]
+ target = {k: v for k, v in zip(keys, values)}
+ else:
+ target = features.pop(target_column)
return features, target
return features
return input_fn
diff --git a/tensorflow/python/estimator/inputs/pandas_io_test.py b/tensorflow/python/estimator/inputs/pandas_io_test.py
index dcecf6dd61..6f13bc95d2 100644
--- a/tensorflow/python/estimator/inputs/pandas_io_test.py
+++ b/tensorflow/python/estimator/inputs/pandas_io_test.py
@@ -47,6 +47,16 @@ class PandasIoTest(test.TestCase):
y = pd.Series(np.arange(-32, -28), index=index)
return x, y
+ def makeTestDataFrameWithYAsDataFrame(self):
+ index = np.arange(100, 104)
+ a = np.arange(4)
+ b = np.arange(32, 36)
+ a_label = np.arange(10, 14)
+ b_label = np.arange(50, 54)
+ x = pd.DataFrame({'a': a, 'b': b}, index=index)
+ y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)
+ return x, y
+
def callInputFnOnce(self, input_fn, session):
results = input_fn()
coord = coordinator.Coordinator()
@@ -65,6 +75,19 @@ class PandasIoTest(test.TestCase):
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
+ def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):
+ if not HAS_PANDAS:
+ return
+
+ x, y = self.makeTestDataFrame()
+
+ with self.assertRaisesRegexp(TypeError,
+ 'target_column must be a string type'):
+ pandas_io.pandas_input_fn(x, y, batch_size=2,
+ shuffle=False,
+ num_epochs=1,
+ target_column=['one', 'two'])
+
def testPandasInputFn_NonBoolShuffle(self):
if not HAS_PANDAS:
return
@@ -90,6 +113,53 @@ class PandasIoTest(test.TestCase):
self.assertAllEqual(features['b'], [32, 33])
self.assertAllEqual(target, [-32, -31])
+ def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a_target'], [10, 11])
+ self.assertAllEqual(targets['b_target'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['b'], [50, 51])
+
+ def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):
+ if not HAS_PANDAS:
+ return
+ with self.test_session() as session:
+ x, y = self.makeTestDataFrameWithYAsDataFrame()
+ y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})
+ input_fn = pandas_io.pandas_input_fn(
+ x, y, batch_size=2, shuffle=False, num_epochs=1)
+
+ features, targets = self.callInputFnOnce(input_fn, session)
+
+ self.assertAllEqual(features['a'], [0, 1])
+ self.assertAllEqual(features['b'], [32, 33])
+ self.assertAllEqual(targets['a'], [10, 11])
+ self.assertAllEqual(targets['a_n'], [50, 51])
+
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6856b8b5a9..076359b503 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -39,13 +39,13 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
-from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -70,16 +70,22 @@ def _convert_tensor(x):
return x
-def _any_variable_initialized():
- """Check if any variable has been initialized in the Keras model.
+def _any_weight_initialized(keras_model):
+ """Check if any weights has been initialized in the Keras model.
+
+ Args:
+ keras_model: An instance of compiled keras model.
Returns:
- boolean, True if at least one variable has been initialized, else False.
+ boolean, True if at least one weight has been initialized, else False.
+ Currently keras initialize all weights at get_session().
"""
- variables = variables_module.global_variables()
- for v in variables:
- if getattr(v, '_keras_initialized', False):
- return True
+ if keras_model is None:
+ return False
+ for layer in keras_model.layers:
+ for weight in layer.weights:
+ if hasattr(weight, '_keras_initialized'):
+ return True
return False
@@ -123,8 +129,8 @@ def _create_ordered_io(keras_model, estimator_io, is_input=True):
'It needs to match one '
'of the following: %s' % ('input' if is_input else 'output', key,
', '.join(keras_io_names)))
- tensors = [_convert_tensor(estimator_io[io_name])
- for io_name in keras_io_names]
+ tensors = [_convert_tensor(estimator_io[io_name])
+ for io_name in keras_io_names]
return tensors
else:
# Plain array.
@@ -242,8 +248,17 @@ def _in_place_subclassed_model_state_restoration(model):
# Restore layers and build attributes
if (hasattr(model, '_original_attributes_cache') and
model._original_attributes_cache is not None):
- model._layers = []
+ # Models have sticky attribute assignment, so we want to be careful to add
+ # back the previous attributes and track Layers by their original names
+ # without adding dependencies on "utility" attributes which Models exempt
+ # when they're constructed.
+ model._layers = data_structures.NoDependency([])
for name, value in model._original_attributes_cache.items():
+ if not isinstance(value, checkpointable.CheckpointableBase):
+ # If this value is not already checkpointable, it's probably that way
+ # for a reason; we don't want to start tracking data structures that the
+ # original Model didn't.
+ value = data_structures.NoDependency(value)
setattr(model, name, value)
model._original_attributes_cache = None
else:
@@ -446,7 +461,6 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt'))
-@tf_export('keras.estimator.model_to_estimator')
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@@ -511,7 +525,7 @@ def model_to_estimator(keras_model=None,
keras_model_fn, model_dir=model_dir, config=config)
# Check if we need to call get_weights:
- if _any_variable_initialized():
+ if _any_weight_initialized(keras_model):
keras_weights = keras_model.get_weights()
# Warn if config passed to estimator tries to update GPUOptions. If a
# session has already been created, the GPUOptions passed to the first
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 6688a84130..7a4457f5a4 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -31,10 +31,10 @@ from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
-from tensorflow.python.keras import backend as K
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.optimizers import SGD
+from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -146,13 +146,13 @@ def randomize_io_type(array, name):
def multi_inputs_multi_outputs_model():
a = keras.layers.Input(shape=(16,), name='input_a')
b = keras.layers.Input(shape=(16,), name='input_b')
- m = keras.layers.Input(shape=(8,), dtype='bool', name='input_m')
+ m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
dense = keras.layers.Dense(8, name='dense_1')
a_2 = dense(a)
- # Apply a mask
- s_2 = keras.layers.Lambda(lambda k:
- K.switch(k[0], k[1], K.zeros_like(k[1])))([m, a_2])
+ # Read m
+ m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m)
+ s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2])
b_2 = dense(b)
merged = keras.layers.concatenate([s_2, b_2], name='merge')
c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
@@ -204,6 +204,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_tf_optimizer(self):
for model_type in ['sequential', 'functional']:
keras_model, (_, _), (
@@ -231,6 +232,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ @test_util.run_in_graph_and_eager_modes
def test_train_with_subclassed_model(self):
keras_model, (_, _), (
_, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
@@ -372,13 +374,13 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
def train_input_fn():
input_dict = {'input_a': a_train, 'input_b': b_train,
- 'input_m': input_m_train > 0}
+ 'input_m': input_m_train.astype(np.str)}
output_dict = {'dense_2': c_train, 'dense_3': d_train}
return input_dict, output_dict
def eval_input_fn():
input_dict = {'input_a': a_test, 'input_b': b_test,
- 'input_m': input_m_test > 0}
+ 'input_m': input_m_test.astype(np.str)}
output_dict = {'dense_2': c_test, 'dense_3': d_test}
return input_dict, output_dict
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 3edf9fe940..a9fd8f8e1a 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -23,7 +23,7 @@ import collections
import six
-from tensorflow.python.estimator.export.export_output import ExportOutput
+from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -32,10 +32,10 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
-@tf_export('estimator.ModeKeys')
+@estimator_export('estimator.ModeKeys')
class ModeKeys(object):
"""Standard names for model modes.
@@ -62,7 +62,7 @@ EXPORT_TAG_MAP = {
}
-@tf_export('estimator.EstimatorSpec')
+@estimator_export('estimator.EstimatorSpec')
class EstimatorSpec(
collections.namedtuple('EstimatorSpec', [
'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
@@ -99,7 +99,7 @@ class EstimatorSpec(
ignored in eval and infer modes. Example:
```python
- def my_model_fn(mode, features, labels):
+ def my_model_fn(features, labels, mode):
predictions = ...
loss = ...
train_op = ...
@@ -114,7 +114,7 @@ class EstimatorSpec(
given mode. Example:
```python
- def my_model_fn(mode, features, labels):
+ def my_model_fn(features, labels, mode):
if (mode == tf.estimator.ModeKeys.TRAIN or
mode == tf.estimator.ModeKeys.EVAL):
loss = ...
@@ -158,6 +158,8 @@ class EstimatorSpec(
Multi-headed models should specify one entry for each head, one of
which must be named using
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
+ If no entry is provided, a default `PredictOutput` mapping to
+ `predictions` will be created.
training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
run on the chief worker during training.
training_hooks: Iterable of `tf.train.SessionRunHook` objects to run
@@ -232,29 +234,9 @@ class EstimatorSpec(
_check_is_tensor_or_operation(metric_update,
'eval_metric_ops[{}]'.format(key))
- # Validate export_outputs.
- if export_outputs is not None:
- if not isinstance(export_outputs, dict):
- raise TypeError('export_outputs must be dict, given: {}'.format(
- export_outputs))
- for v in six.itervalues(export_outputs):
- if not isinstance(v, ExportOutput):
- raise TypeError(
- 'Values in export_outputs must be ExportOutput objects. '
- 'Given: {}'.format(export_outputs))
- # Note export_outputs is allowed to be empty.
- if len(export_outputs) == 1:
- (key, value), = export_outputs.items()
- if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- export_outputs[
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
- if len(export_outputs) > 1:
- if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- not in export_outputs):
- raise ValueError(
- 'Multiple export_outputs were provided, but none of them is '
- 'specified as the default. Do this by naming one of them with '
- 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.')
+ # Validate the passed export outputs, or generate defaults.
+ if mode == ModeKeys.PREDICT:
+ export_outputs = _get_export_outputs(export_outputs, predictions)
# Validate that all tensors and ops are from the default graph.
default_graph = ops.get_default_graph()
@@ -286,11 +268,11 @@ class EstimatorSpec(
raise ValueError(error_message_template.format('train_op', train_op.name))
for key, value in list(six.iteritems(eval_metric_ops)):
values = nest.flatten(value)
- for value in values:
- if value.graph is not default_graph:
+ for val in values:
+ if val.graph is not default_graph:
raise ValueError(error_message_template.format(
'eval_metric_ops',
- '{0}: {1}'.format(key, value.name)))
+ '{0}: {1}'.format(key, val.name)))
# Validate hooks.
training_chief_hooks = tuple(training_chief_hooks or [])
@@ -334,6 +316,70 @@ class EstimatorSpec(
return EstimatorSpec(*new_fields)
+def _get_export_outputs(export_outputs, predictions):
+ """Validate export_outputs or create default export_outputs.
+
+ Args:
+ export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict or None.
+ predictions: Predictions `Tensor` or dict of `Tensor`.
+
+ Returns:
+ Valid export_outputs dict
+
+ Raises:
+ TypeError: if export_outputs is not a dict or its values are not
+ ExportOutput instances.
+ """
+ if export_outputs is None:
+ default_output = export_output_lib.PredictOutput(predictions)
+ export_outputs = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: default_output}
+
+ if not isinstance(export_outputs, dict):
+ raise TypeError('export_outputs must be dict, given: {}'.format(
+ export_outputs))
+ for v in six.itervalues(export_outputs):
+ if not isinstance(v, export_output_lib.ExportOutput):
+ raise TypeError(
+ 'Values in export_outputs must be ExportOutput objects. '
+ 'Given: {}'.format(export_outputs))
+
+ _maybe_add_default_serving_output(export_outputs)
+
+ return export_outputs
+
+
+def _maybe_add_default_serving_output(export_outputs):
+ """Add a default serving output to the export_outputs if not present.
+
+ Args:
+ export_outputs: Describes the output signatures to be exported to
+ `SavedModel` and used during serving. Should be a dict.
+
+ Returns:
+ export_outputs dict with default serving signature added if necessary
+
+ Raises:
+ ValueError: if multiple export_outputs were provided without a default
+ serving key.
+ """
+ if len(export_outputs) == 1:
+ (key, value), = export_outputs.items()
+ if key != signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_outputs[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = value
+ if len(export_outputs) > 1:
+ if (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ not in export_outputs):
+ raise ValueError(
+ 'Multiple export_outputs were provided, but none of them is '
+ 'specified as the default. Do this by naming one of them with '
+ 'signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.')
+
+ return export_outputs
+
+
class _TPUEstimatorSpec(collections.namedtuple('TPUEstimatorSpec', [
'mode',
'predictions',
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index b7eeeb437c..08e41fd414 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -592,6 +592,27 @@ class EstimatorSpecInferTest(test.TestCase):
predictions=predictions,
export_outputs=export_outputs)
+ def testDefaultExportOutputCreated(self):
+ """Ensure that a default PredictOutput is created for export."""
+ with ops.Graph().as_default(), self.test_session():
+ predictions = constant_op.constant(1.)
+ self._assertDefaultExportOutputForPredictions(predictions)
+
+ def testDefaultExportOutputCreatedDict(self):
+ """Ensure that a default PredictOutput is created for export for dicts."""
+ with ops.Graph().as_default(), self.test_session():
+ predictions = {'loss': constant_op.constant(1.),
+ 'score': constant_op.constant(10.)}
+ self._assertDefaultExportOutputForPredictions(predictions)
+
+ def _assertDefaultExportOutputForPredictions(self, predictions):
+ spec = model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.PREDICT, predictions=predictions)
+
+ expected = export_output.PredictOutput(predictions).outputs
+ serving_output = spec.export_outputs[
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
+ self.assertEqual(serving_output.outputs, expected)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index c7707be839..aa594af2e4 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -25,11 +25,12 @@ import os
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
from tensorflow.python.util import function_utils
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_USE_DEFAULT = object()
@@ -296,7 +297,7 @@ class TaskType(object):
EVALUATOR = 'evaluator'
-@tf_export('estimator.RunConfig')
+@estimator_export('estimator.RunConfig')
class RunConfig(object):
"""This class specifies the configurations for an `Estimator` run."""
@@ -484,6 +485,52 @@ class RunConfig(object):
self._init_distributed_setting_from_environment_var(tf_config)
+ self._maybe_overwrite_session_config_for_distributed_training()
+
+ def _maybe_overwrite_session_config_for_distributed_training(self):
+ """Overwrites the session_config for distributed training.
+
+ The default overwrite is optimized for between-graph training. Subclass
+ should override this method if necessary.
+ """
+ # Get session_config only for between-graph distributed mode (cluster_spec
+ # is present).
+ if not self._session_config and self._cluster_spec:
+ RunConfig._replace(
+ self,
+ allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
+ session_config=self._get_default_session_config())
+
+ def _get_default_session_config(self):
+ """Returns None or tf.ConfigProto instance with default device_filters set.
+
+ Device filters are set such that chief/master and worker communicates with
+ only ps. session_config=None for evaluators or any other TaskType.
+ """
+
+ rewrite_opts = rewriter_config_pb2.RewriterConfig(
+ meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+ graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
+
+ device_filters = None
+ if self._task_type == TaskType.MASTER:
+ device_filters = ['/job:ps', '/job:master']
+ elif self._task_type == TaskType.CHIEF:
+ device_filters = ['/job:ps', '/job:chief']
+ elif self._task_type == TaskType.WORKER:
+ device_filters = ['/job:ps', '/job:worker/task:%d' % self._task_id]
+ elif self._task_type == TaskType.PS:
+ device_filters = ['/job:ps', '/job:worker', '/job:master']
+ else:
+ # If the task_type is `EVALUATOR` or something other than the ones in
+ # TaskType then don't set any device filters.
+ return None
+
+ return config_pb2.ConfigProto(
+ allow_soft_placement=True,
+ graph_options=graph_opts,
+ device_filters=device_filters)
+
def _init_distributed_setting_from_environment_var(self, tf_config):
"""Initialize distributed properties based on `tf_config`."""
diff --git a/tensorflow/python/estimator/run_config_test.py b/tensorflow/python/estimator/run_config_test.py
index c8b12605e1..06df7cb9dd 100644
--- a/tensorflow/python/estimator/run_config_test.py
+++ b/tensorflow/python/estimator/run_config_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.platform import test
@@ -290,6 +291,7 @@ class RunConfigDistributedSettingTest(test.TestCase):
expected_num_worker_replicas=1,
expected_num_ps_replicas=0)
self.assertEqual(0, run_config.global_id_in_cluster)
+ self.assertIsNone(run_config.session_config, None)
def test_session_master_for_local(self):
tf_config = {'session_master': '_my_master'}
@@ -1119,5 +1121,115 @@ class RunConfigModelDirTest(test.TestCase):
_create_run_config_with_cluster_spec(tf_config)
+class RunConfigSessionConfigTest(test.TestCase):
+
+ def _assert_equal_session_config(self, session_config,
+ expected_device_filters):
+
+ rewrite_opts = rewriter_config_pb2.RewriterConfig(
+ meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+ graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
+ expected_session_config = config_pb2.ConfigProto(
+ allow_soft_placement=True,
+ graph_options=graph_opts,
+ device_filters=expected_device_filters)
+ self.assertEqual(session_config, expected_session_config)
+
+ def test_master_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.MASTER,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:master'])
+
+ def test_chief_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.CHIEF,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:chief'])
+
+ def test_worker_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.WORKER,
+ 'index': 1
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:worker/task:1'])
+
+ def test_ps_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.PS,
+ 'index': 1
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self._assert_equal_session_config(run_config.session_config,
+ ['/job:ps', '/job:worker', '/job:master'])
+
+ def test_evaluator_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.CHIEF: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': run_config_lib.TaskType.EVALUATOR,
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self.assertIsNone(run_config.session_config)
+
+ def test_other_type_session_config(self):
+ tf_config = {
+ 'cluster': {
+ run_config_lib.TaskType.MASTER: ['host0:0'],
+ run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
+ 'other_type': ['host3:1', 'host4:2'],
+ run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']
+ },
+ 'task': {
+ 'type': 'other_type',
+ 'index': 0
+ }
+ }
+ run_config = _create_run_config_with_cluster_spec(tf_config)
+ self.assertIsNone(run_config.session_config)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 522662cd32..f5ac79ced2 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -35,7 +35,7 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
-from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.util.tf_export import estimator_export
_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
@@ -115,7 +115,7 @@ def _is_google_env():
return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE
-@tf_export('estimator.TrainSpec')
+@estimator_export('estimator.TrainSpec')
class TrainSpec(
collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])):
"""Configuration for the "train" part for the `train_and_evaluate` call.
@@ -167,7 +167,7 @@ class TrainSpec(
cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
-@tf_export('estimator.EvalSpec')
+@estimator_export('estimator.EvalSpec')
class EvalSpec(
collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
@@ -263,7 +263,7 @@ class EvalSpec(
throttle_secs=throttle_secs)
-@tf_export('estimator.train_and_evaluate')
+@estimator_export('estimator.train_and_evaluate')
def train_and_evaluate(estimator, train_spec, eval_spec):
"""Train and evaluate the `estimator`.
@@ -278,10 +278,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
supported distributed training configuration is between-graph replication.
Overfitting: In order to avoid overfitting, it is recommended to set up the
- training `input_fn` to shuffle the training data properly. It is also
- recommended to train the model a little longer, say multiple epochs, before
- performing evaluation, as the input pipeline starts from scratch for each
- training. It is particularly important for local training and evaluation.
+ training `input_fn` to shuffle the training data properly.
Stop condition: In order to support both distributed and non-distributed
configuration reliably, the only supported stop condition for model
@@ -295,6 +292,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
model will be trained with three epochs of training data instead of one epoch.
Example of local (non-distributed) training:
+
```python
# Set up feature columns.
categorial_feature_a = categorial_column_with_hash_bucket(...)
@@ -314,10 +312,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
# hidden_units=[1024, 512, 256])
# Input pipeline for train and evaluate.
- def train_input_fn: # returns x, y
+ def train_input_fn(): # returns x, y
# please shuffle the data.
pass
- def eval_input_fn_eval: # returns x, y
+ def eval_input_fn(): # returns x, y
pass
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
@@ -339,12 +337,14 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
Setting environment variable depends on the platform. For example, on Linux,
it can be done as follows (`$` is the shell prompt):
+
```
$ TF_CONFIG='<replace_with_real_content>' python train_model.py
```
For the content in `TF_CONFIG`, assume that the training cluster spec looks
like:
+
```
cluster = {"chief": ["host0:2222"],
"worker": ["host1:2222", "host2:2222", "host3:2222"],
@@ -352,6 +352,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
```
Example of `TF_CONFIG` for chief training worker (must have one and only one):
+
```
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
@@ -371,6 +372,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
Example of `TF_CONFIG` for non-chief training worker (optional, could be
multiple):
+
```
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
@@ -387,6 +389,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
for non-chief training workers.
Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):
+
```
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
@@ -405,6 +408,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is
not part of the training cluster. There could be only one. It is used for
model evaluation.
+
```
# This should be a JSON string, which is set as environment variable. Usually
# the cluster manager handles that.
@@ -463,6 +467,61 @@ class _StopAtSecsHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+class _NewCheckpointListenerForEvaluate(
+ basic_session_run_hooks.CheckpointSaverListener):
+ """A saver listener to run evaluate with every checkpoint."""
+
+ def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener):
+ self._evaluator = evaluator
+ self._eval_throttle_secs = eval_throttle_secs
+ self._continuous_eval_listener = continuous_eval_listener
+ self.eval_result, self.export_results = None, None
+
+ def begin(self):
+ self._timer = basic_session_run_hooks.SecondOrStepTimer(
+ every_secs=self._eval_throttle_secs)
+ self._is_first_run = True
+
+ def after_save(self, session, global_step_value):
+ del session # unused; required by signature.
+ # skip first run model is not trained yet.
+ if self._is_first_run:
+ self._is_first_run = False
+ return
+
+ if not self._continuous_eval_listener.before_eval():
+ logging.info('Exiting training and evaluation loop, as requested by '
+ '_ContinuousEvalListener.before_eval.')
+ return True
+ if self._timer.should_trigger_for_step(global_step_value):
+ self._evaluate(global_step_value) # updates self.eval_result
+ if not self._continuous_eval_listener.after_eval(self.eval_result):
+ logging.info('Exiting evaluation, as requested by '
+ '_ContinuousEvalListener.after_eval.')
+ return True
+ else:
+ # TODO(ispir): add remaining time in the log.
+ logging.info('Skip the current checkpoint eval due to throttle secs '
+ '({} secs).'.format(self._eval_throttle_secs))
+
+ def end(self, session, global_step_value):
+ # Evaluate if the last step has not been evaluated, yet.
+ if global_step_value != self._timer.last_triggered_step():
+ if self._continuous_eval_listener.before_eval():
+ self._evaluate(global_step_value)
+ self._continuous_eval_listener.after_eval(self.eval_result)
+
+ def _evaluate(self, global_step_value):
+ self._timer.update_last_triggered_step(global_step_value)
+ self.eval_result, self.export_results = (
+ self._evaluator.evaluate_and_export())
+ if self.eval_result.status != _EvalStatus.EVALUATED:
+ # This is unexpected; should never happen.
+ # Training should always end with a new checkpoint.
+ raise RuntimeError('There was no new checkpoint after the training. '
+ 'Eval status: {}'.format(self.eval_result.status))
+
+
class _TrainingExecutor(object):
"""The executor to run `Estimator` training and evaluation.
@@ -569,28 +628,6 @@ class _TrainingExecutor(object):
def run_master(self):
"""Runs task master."""
-
- class NewCheckpointListener(
- basic_session_run_hooks.CheckpointSaverListener):
-
- def __init__(self, evaluator, eval_throttle_secs):
- self._evaluator = evaluator
- self._eval_throttle_secs = eval_throttle_secs
-
- def begin(self):
- self._timer = basic_session_run_hooks.SecondOrStepTimer(
- every_secs=self._eval_throttle_secs)
-
- def after_save(self, session, global_step_value):
- del session # unused; required by signature.
-
- if self._timer.should_trigger_for_step(global_step_value):
- self._timer.update_last_triggered_step(global_step_value)
- self._evaluator.evaluate_and_export()
- else:
- logging.info('Skip the current checkpoint eval due to throttle secs '
- '({} secs).'.format(self._eval_throttle_secs))
-
_assert_eval_spec(self._eval_spec)
# Final export signal: For any eval result with global_step >= train
@@ -610,16 +647,12 @@ class _TrainingExecutor(object):
# When the underlying `Estimator` object saves a new checkpoint, we would
# like this callback to be called so that evaluation and export can trigger.
saving_listeners = [
- NewCheckpointListener(evaluator, self._eval_spec.throttle_secs)
+ _NewCheckpointListenerForEvaluate(evaluator,
+ self._eval_spec.throttle_secs,
+ _ContinuousEvalListener())
]
self._start_distributed_training(saving_listeners=saving_listeners)
- if not evaluator.is_final_export_triggered:
- logging.info('Training has already ended. But the last eval is skipped '
- 'due to eval throttle_secs. Now evaluating the final '
- 'checkpoint.')
- evaluator.evaluate_and_export()
-
def run_evaluator(self):
"""Runs task evaluator."""
# TODO(xiejw): To allow execution framework to add continuous eval listener.
@@ -633,68 +666,33 @@ class _TrainingExecutor(object):
def run_local(self):
"""Runs training and evaluation locally (non-distributed)."""
-
- def _should_stop_local_train(global_step):
- if self._train_spec.max_steps is None:
- return False
- if global_step >= self._train_spec.max_steps:
- return True
- return False
-
_assert_eval_spec(self._eval_spec)
- if self._eval_spec.throttle_secs <= 0:
- raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
- 'It is used do determine how long each training '
- 'iteration should go when train and evaluate '
- 'locally.'.format(self._eval_spec.throttle_secs))
-
- stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
- train_hooks = (
- list(self._train_spec.hooks) + [stop_hook] + list(self._train_hooks))
+ train_hooks = list(self._train_spec.hooks) + list(self._train_hooks)
logging.info('Start train and evaluate loop. The evaluate will happen '
- 'after {} secs (eval_spec.throttle_secs) or training is '
- 'finished.'.format(self._eval_spec.throttle_secs))
+ 'after every checkpoint. Checkpoint frequency is determined '
+ 'based on RunConfig arguments: save_checkpoints_steps {} or '
+ 'save_checkpoints_secs {}.'.format(
+ self._estimator.config.save_checkpoints_steps,
+ self._estimator.config.save_checkpoints_secs))
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
self._train_spec.max_steps)
- eval_result = _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT)
- export_results = []
-
- while True:
- self._estimator.train(
- input_fn=self._train_spec.input_fn,
- max_steps=self._train_spec.max_steps,
- hooks=train_hooks)
-
- if not self._continuous_eval_listener.before_eval():
- logging.info('Exiting training and evaluation loop, as requested by '
- '_ContinuousEvalListener.before_eval.')
- break
-
- # Final export signal: For any eval result with global_step >= train
- # max_steps, the evaluator will send the final export signal. The
- # _should_stop_local_train will then end the while True as the stopping
- # condition is satisfied (both checks use the same global_step value,
- # i.e., no race condition)
- eval_result, export_results = evaluator.evaluate_and_export()
-
- if eval_result.status != _EvalStatus.EVALUATED:
- # This is unexpected; should never happen.
- # Training should always end with a new checkpoint.
- raise RuntimeError('There was no new checkpoint after the training. '
- 'Eval status: {}'.format(eval_result.status))
-
- if not self._continuous_eval_listener.after_eval(eval_result):
- logging.info('Exiting evaluation, as requested by '
- '_ContinuousEvalListener.after_eval.')
- break
+ listener_for_eval = _NewCheckpointListenerForEvaluate(
+ evaluator, self._eval_spec.throttle_secs,
+ self._continuous_eval_listener)
+ saving_listeners = [listener_for_eval]
+
+ self._estimator.train(
+ input_fn=self._train_spec.input_fn,
+ max_steps=self._train_spec.max_steps,
+ hooks=train_hooks,
+ saving_listeners=saving_listeners)
- if _should_stop_local_train(
- eval_result.metrics[ops.GraphKeys.GLOBAL_STEP]):
- break
- return eval_result.metrics, export_results
+ eval_result = listener_for_eval.eval_result or _EvalResult(
+ status=_EvalStatus.MISSING_CHECKPOINT)
+ return eval_result.metrics, listener_for_eval.export_results
def _start_std_server(self, config):
"""Creates, starts, and returns a server_lib.Server."""
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 2c838db7a4..6bee7cbe83 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -29,17 +29,21 @@ import time
import numpy as np
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export as export_lib
-from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -49,6 +53,7 @@ from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
from tensorflow.python.util import compat
_DEFAULT_EVAL_STEPS = 100
@@ -885,7 +890,8 @@ class TrainingExecutorRunMasterTest(test.TestCase):
# `after_save`.
del args, kwargs
saving_listeners[0].begin()
- saving_listeners[0].after_save(session=None, global_step_value=None)
+ saving_listeners[0].after_save(session=None, global_step_value=0)
+ saving_listeners[0].after_save(session=None, global_step_value=10)
mock_est = test.mock.Mock(
spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
@@ -930,7 +936,10 @@ class TrainingExecutorRunMasterTest(test.TestCase):
del args, kwargs
saving_listeners[0].begin()
- # Call three times.
+ # Call four times.
+ mock_timer.should_trigger_for_step.return_value = True
+ saving_listeners[0].after_save(session=None, global_step_value=None)
+
mock_timer.should_trigger_for_step.return_value = True
saving_listeners[0].after_save(session=None, global_step_value=None)
@@ -979,14 +988,19 @@ class TrainingExecutorRunMasterTest(test.TestCase):
del args, kwargs
saving_listeners[0].begin()
- # Call two times.
+ # Call tree times (one for first saving).
mock_timer.should_trigger_for_step.return_value = True
- saving_listeners[0].after_save(session=None, global_step_value=None)
+ saving_listeners[0].after_save(session=None, global_step_value=0)
+
+ mock_timer.should_trigger_for_step.return_value = True
+ saving_listeners[0].after_save(session=None, global_step_value=125)
- # The final ckpt is skipped by the timer. It will be picked up the final
- # export check in the code.
mock_timer.should_trigger_for_step.return_value = False
- saving_listeners[0].after_save(session=None, global_step_value=None)
+ saving_listeners[0].after_save(session=None, global_step_value=250)
+
+ # At the end evaluate should be called even if throttle secs prevents it.
+ mock_timer.should_trigger_for_step.return_value = False
+ saving_listeners[0].end(session=None, global_step_value=300)
mock_est.train = estimator_train
mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
@@ -1566,28 +1580,31 @@ class StopAtSecsHookTest(test.TestCase):
class TrainingExecutorRunLocalTest(test.TestCase):
"""Tests run_local of _TrainingExecutor."""
+ def _model_fn(self, features, labels, mode):
+ del labels
+ with ops.control_dependencies([features]):
+ train_op = state_ops.assign_add(training_util.get_global_step(), 1)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(0.),
+ train_op=train_op,
+ predictions=constant_op.constant([[10.]]),
+ eval_metric_ops={'mean_of_features': metrics_lib.mean(features)})
+
+ def _input_fn(self, repeat=True):
+ ds = dataset_ops.Dataset.from_tensors([1])
+ if repeat:
+ return ds.repeat()
+ return ds
+
def unique_checkpoint_every_time_fn(self):
return 'checkpoint_path_%s/' % random.random()
- def test_send_stop_at_secs_to_train(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
- train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
- eval_spec = training.EvalSpec(
- input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
- mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
-
- executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
- executor.run_local()
-
- stop_hook = mock_est.train.call_args[1]['hooks'][-1]
- self.assertIsInstance(stop_hook, training._StopAtSecsHook)
- self.assertEqual(eval_spec.throttle_secs, stop_hook._stop_after_secs)
-
- def test_runs_in_a_loop_until_max_steps(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
+ def test_runs_evaluate_with_every_new_checkpoint(self):
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=10))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
mock_est.times_export_was_called = 0
mock_est.times_final_export_was_called = 0
@@ -1604,42 +1621,30 @@ class TrainingExecutorRunLocalTest(test.TestCase):
exporter.name = 'see_how_many_times_export_is_called'
exporter.export = export
- train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=22)
eval_spec = training.EvalSpec(
- input_fn=lambda: 1,
- hooks=[_FakeHook()],
- throttle_secs=100,
+ input_fn=lambda: self._input_fn(repeat=False),
+ throttle_secs=0,
exporters=exporter)
- # should be called 3 times.
- mock_est.evaluate.side_effect = [{
- _GLOBAL_STEP_KEY: train_spec.max_steps - 100
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps - 50
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps
- }]
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
- self.assertEqual(3, mock_est.train.call_count)
+ self.assertEqual(1, mock_est.train.call_count)
self.assertEqual(3, mock_est.evaluate.call_count)
self.assertEqual(3, mock_est.times_export_was_called)
self.assertEqual(1, mock_est.times_final_export_was_called)
def test_runs_with_eval_listener_before_eval(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=10))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
- train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
- eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100)
- # should be called 2 times without the evallistener
- mock_est.evaluate.side_effect = [{
- _GLOBAL_STEP_KEY: train_spec.max_steps - 50
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps
- }]
+ train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12)
+ eval_spec = training.EvalSpec(input_fn=lambda: self._input_fn(repeat=False))
+ mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: train_spec.max_steps}]
class _Listener(training._ContinuousEvalListener):
@@ -1658,67 +1663,61 @@ class TrainingExecutorRunLocalTest(test.TestCase):
self.assertEqual(1, mock_est.train.call_count)
self.assertEqual(0, mock_est.evaluate.call_count)
- self.assertEqual(1, listener.call_count)
def test_runs_with_eval_listener_after_eval(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=10))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
- train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
- eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=100)
- # should be called 2 times without the evallistener
- mock_est.evaluate.side_effect = [{
- _GLOBAL_STEP_KEY: train_spec.max_steps - 50
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps
- }]
+ train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=3000)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
class _Listener(training._ContinuousEvalListener):
- def __init__(self, test_case):
+ def __init__(self):
self.call_count = 0
- self._test_case = test_case
def after_eval(self, eval_result):
self.call_count += 1
- self._test_case.assertEqual(
- train_spec.max_steps - 50, eval_result.metrics[_GLOBAL_STEP_KEY])
return False # Will stop the run_local after first eval.
- listener = _Listener(test_case=self)
+ listener = _Listener()
executor = training._TrainingExecutor(
mock_est, train_spec, eval_spec, continuous_eval_listener=listener)
- executor.run_local()
+ metrics, _ = executor.run_local() # pylint: disable=assignment-from-no-return
self.assertEqual(1, mock_est.train.call_count)
self.assertEqual(1, mock_est.evaluate.call_count)
self.assertEqual(1, listener.call_count)
+ # Should be less than max_steps since listener did early stopping.
+ self.assertLess(metrics[_GLOBAL_STEP_KEY], train_spec.max_steps)
def test_handles_no_new_checkpoint_found(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint.return_value = (
- 'no_new_checkpoints_after_the_first_train_step')
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ # disable saving checkpoint
+ config=run_config_lib.RunConfig(
+ save_checkpoints_steps=None, save_checkpoints_secs=None))
train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, hooks=[_FakeHook()], throttle_secs=100)
- # It was going to be called 3 times.
- mock_est.evaluate.side_effect = [{
- _GLOBAL_STEP_KEY: train_spec.max_steps - 100
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps - 50
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps
- }]
+ input_fn=lambda: self._input_fn(repeat=False),
+ hooks=[_FakeHook()],
+ throttle_secs=100)
- executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
- with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
+ executor = training._TrainingExecutor(est, train_spec, eval_spec)
+ with self.assertRaisesRegexp(ValueError,
+ 'There should be a CheckpointSaverHook'):
executor.run_local()
def test_final_export_is_true_in_the_end(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=10))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
mock_est.times_export_fn_was_called = 0
mock_est.times_the_final_export_was_true = 0
@@ -1734,37 +1733,29 @@ class TrainingExecutorRunLocalTest(test.TestCase):
exporter.export = export
train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
- input_fn=lambda: 1,
- hooks=[_FakeHook()],
- throttle_secs=100,
+ input_fn=lambda: self._input_fn(repeat=False),
+ throttle_secs=0,
exporters=exporter)
- # should be called 3 times.
- mock_est.evaluate.side_effect = [{
- _GLOBAL_STEP_KEY: train_spec.max_steps - 100
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps - 50
- }, {
- _GLOBAL_STEP_KEY: train_spec.max_steps
- }]
-
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
- self.assertEqual(3, mock_est.train.call_count)
- self.assertEqual(3, mock_est.evaluate.call_count)
- self.assertEqual(3, mock_est.times_export_fn_was_called)
+ self.assertEqual(1, mock_est.train.call_count)
+ self.assertEqual(2, mock_est.evaluate.call_count)
+ self.assertEqual(2, mock_est.times_export_fn_was_called)
self.assertEqual(1, mock_est.times_the_final_export_was_true)
def test_train_and_evaluate_args(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
+ est = estimator_lib.Estimator(model_fn=self._model_fn)
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
- mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ input_fn=lambda: self._input_fn(repeat=False),
+ steps=2,
+ hooks=[_FakeHook()],
+ name='local_eval')
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
executor.run_local()
@@ -1773,11 +1764,11 @@ class TrainingExecutorRunLocalTest(test.TestCase):
name=eval_spec.name,
input_fn=eval_spec.input_fn,
steps=eval_spec.steps,
- checkpoint_path='checkpoint_path/',
+ checkpoint_path=est.latest_checkpoint(),
hooks=eval_spec.hooks)
train_args = mock_est.train.call_args[1]
- self.assertEqual(list(train_spec.hooks), list(train_args['hooks'][:-1]))
+ self.assertEqual(list(train_spec.hooks), list(train_args['hooks']))
self.assertEqual(train_spec.input_fn, train_args['input_fn'])
self.assertEqual(train_spec.max_steps, train_args['max_steps'])
@@ -1812,25 +1803,11 @@ class TrainingExecutorRunLocalTest(test.TestCase):
if not isinstance(h, training._StopAtSecsHook)
])
- def test_errors_out_if_throttle_secs_is_zero(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- train_spec = training.TrainSpec(input_fn=lambda: 1)
- eval_spec = training.EvalSpec(input_fn=lambda: 1, throttle_secs=0)
-
- executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
- with self.assertRaisesRegexp(ValueError, 'throttle_secs'):
- executor.run_local()
-
def test_that_export_is_called_with_run_local(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
- mock_train_spec.max_steps = 200
- mock_est.evaluate.return_value = {
- _GLOBAL_STEP_KEY: mock_train_spec.max_steps
- }
- # _validate_hooks would have made sure that train_spec.hooks is [], when
- # None were passed.
- mock_train_spec.hooks = []
+ est = estimator_lib.Estimator(model_fn=self._model_fn)
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
+ train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12)
+ mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
def export(estimator, *args, **kwargs):
del args, kwargs
@@ -1842,13 +1819,13 @@ class TrainingExecutorRunLocalTest(test.TestCase):
exporter.export = export
eval_spec = training.EvalSpec(
- input_fn=lambda: 1,
+ input_fn=lambda: self._input_fn(repeat=False),
steps=2,
start_delay_secs=0,
throttle_secs=213,
exporters=exporter)
- executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
# pylint: disable=assignment-from-no-return
_, export_results = executor.run_local()
# pylint: enable=assignment-from-no-return
@@ -1857,9 +1834,13 @@ class TrainingExecutorRunLocalTest(test.TestCase):
self.assertEqual(export_results, ['path_to_export'])
def test_errors_out_if_evaluate_returns_empty_dict(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- train_spec = training.TrainSpec(input_fn=lambda: 1)
- eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=2))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
+ train_spec = training.TrainSpec(input_fn=self._input_fn)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
mock_est.evaluate.return_value = {}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
@@ -1867,18 +1848,26 @@ class TrainingExecutorRunLocalTest(test.TestCase):
executor.run_local()
def test_errors_out_if_evaluate_returns_non_dict(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- train_spec = training.TrainSpec(input_fn=lambda: 1)
- eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=2))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
+ train_spec = training.TrainSpec(input_fn=self._input_fn)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
mock_est.evaluate.return_value = 123
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
executor.run_local()
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
- train_spec = training.TrainSpec(input_fn=lambda: 1)
- eval_spec = training.EvalSpec(input_fn=(lambda: 1), throttle_secs=123)
+ est = estimator_lib.Estimator(
+ model_fn=self._model_fn,
+ config=run_config_lib.RunConfig(save_checkpoints_steps=2))
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
+ train_spec = training.TrainSpec(input_fn=self._input_fn)
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
mock_est.evaluate.return_value = {'loss': 123}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
@@ -1887,19 +1876,21 @@ class TrainingExecutorRunLocalTest(test.TestCase):
executor.run_local()
def test_train_and_evaluate_return_metrics(self):
- mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
- mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
+ est = estimator_lib.Estimator(model_fn=self._model_fn)
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, wraps=est)
train_spec = training.TrainSpec(
- input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()])
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='local_eval')
- mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}
+ input_fn=lambda: self._input_fn(repeat=False),
+ steps=2,
+ hooks=[_FakeHook()],
+ name='local_eval')
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
# pylint: disable=assignment-from-no-return
metrics, _ = executor.run_local()
# pylint: enable=assignment-from-no-return
- self.assertEqual(metrics['global_step'], 300)
+ self.assertEqual(metrics['global_step'], 12)
class TrainAndEvaluateRunTest(test.TestCase):
@@ -2096,7 +2087,7 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
# max_steps should be larger than save_summary_steps
max_steps = 10
- save_summary_steps = 2
+ save_summary_steps = 9
data = np.linspace(
0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
@@ -2104,24 +2095,20 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
# learn y = x
- train_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data},
- y=y_data,
- batch_size=batch_size,
- num_epochs=None,
- shuffle=True)
-
- eval_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data},
- y=y_data,
- batch_size=batch_size,
- num_epochs=1,
- shuffle=False)
-
- predict_input_fn = numpy_io.numpy_input_fn(
- x={'x': x_data},
- batch_size=batch_size,
- shuffle=False)
+ def train_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices(({
+ 'x': x_data
+ }, y_data)).batch(batch_size).repeat().shuffle(1000)
+
+ def eval_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices(({
+ 'x': x_data
+ }, y_data)).batch(batch_size)
+
+ def predict_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ 'x': x_data
+ }).batch(batch_size)
feature_columns = [
feature_column.numeric_column('x', shape=(input_dimension,))]
@@ -2137,9 +2124,11 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
max_steps=max_steps)
eval_spec = training.EvalSpec(
- name=eval_name, input_fn=eval_input_fn, steps=None,
+ name=eval_name,
+ input_fn=eval_input_fn,
+ steps=None,
exporters=self._get_exporter(exporter_name, feature_columns),
- throttle_secs=2)
+ throttle_secs=0)
training.train_and_evaluate(est, train_spec, eval_spec)
@@ -2148,15 +2137,12 @@ class TrainAndEvaluateIntegrationTest(test.TestCase):
# Examine the training events. Use a range to check global step to avoid
# flakyness due to global step race condition.
- training_loss, training_global_step = self._extract_loss_and_global_step(
- est.model_dir)
+ training_loss, _ = self._extract_loss_and_global_step(est.model_dir)
self.assertIsNotNone(training_loss)
- self.assertTrue(
- max_steps - save_summary_steps < training_global_step <= max_steps)
# Examine the eval events. The global step should be accurate.
eval_loss, eval_global_step = self._extract_loss_and_global_step(
- event_folder=os.path.join(est.model_dir, 'eval_' + eval_name))
+ event_folder=est.eval_dir(eval_name))
self.assertIsNotNone(eval_loss)
self.assertEqual(max_steps, eval_global_step)
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 924ca309ff..d4a75478d5 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,6 +22,7 @@ from __future__ import print_function
import os
import time
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -129,3 +130,24 @@ class _DatasetInitializerHook(training.SessionRunHook):
def after_create_session(self, session, coord):
del coord
session.run(self._initializer)
+
+
+class StrategyInitFinalizeHook(training.SessionRunHook):
+ """Creates a SessionRunHook that initializes and shutsdown devices."""
+
+ def __init__(self, initialization_fn, finalize_fn):
+ self._initialization_fn = initialization_fn
+ self._finalize_fn = finalize_fn
+
+ def begin(self):
+ self._init_ops = self._initialization_fn()
+ self._finalize_ops = self._finalize_fn()
+
+ def after_create_session(self, session, coord):
+ logging.info('Initialize system')
+ session.run(self._init_ops,
+ options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
+
+ def end(self, session):
+ logging.info('Finalize system.')
+ session.run(self._finalize_ops)
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 295d4ca094..80707030e6 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -48,6 +48,39 @@ py_library(
],
)
+py_library(
+ name = "feature_column_v2",
+ srcs = ["feature_column_v2.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:template",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
filegroup(
name = "vocabulary_testdata",
srcs = [
@@ -92,3 +125,38 @@ py_test(
"//tensorflow/python/estimator:numpy_io",
],
)
+
+py_test(
+ name = "feature_column_v2_test",
+ srcs = ["feature_column_v2_test.py"],
+ data = [":vocabulary_testdata"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_pip",
+ ],
+ deps = [
+ ":feature_column_py",
+ ":feature_column_v2",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:partitioned_variables",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:numpy_io",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 7aa46af828..d091d2fe0a 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -172,7 +172,7 @@ def _internal_input_layer(features,
scope=None):
"""See input_layer. `scope` is a name or variable scope to use."""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
for column in feature_columns:
if not isinstance(column, _DenseColumn):
raise ValueError(
@@ -350,10 +350,23 @@ def linear_model(features,
prediction itself for linear regression problems.
Note on supported columns: `linear_model` treats categorical columns as
- `indicator_column`s while `input_layer` explicitly requires wrapping each
- of them with an `embedding_column` or an `indicator_column`.
+ `indicator_column`s. To be specific, assume the input as `SparseTensor` looks
+ like:
- Example:
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+ `linear_model` assigns weights for the presence of "a", "b", "c' implicitly,
+ just like `indicator_column`, while `input_layer` explicitly requires wrapping
+ each of categorical columns with an `embedding_column` or an
+ `indicator_column`.
+
+ Example of usage:
```python
price = numeric_column('price')
@@ -374,13 +387,44 @@ def linear_model(features,
to your model. All items should be instances of classes derived from
`_FeatureColumn`s.
units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a sparse column is
- multivalent. Currently "mean", "sqrtn" and "sum" are supported, with "sum"
- the default. "sqrtn" often achieves good accuracy, in particular with
- bag-of-words columns. It combines each sparse columns independently.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+ categorical column independently. Currently "mean", "sqrtn" and "sum" are
+ supported, with "sum" the default for linear model. "sqrtn" often achieves
+ good accuracy, in particular with bag-of-words columns.
* "sum": do not normalize features in the column
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+ with `sparse_combiner` as "mean", the linear model outputs conceptly are:
+ ```
+ y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
weight_collections: A list of collection names to which the Variable will be
added. Note that, variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
@@ -408,13 +452,15 @@ def linear_model(features,
ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
nor `_CategoricalColumn`.
"""
+ with variable_scope.variable_scope(None, 'linear_model') as vs:
+ model_name = _strip_leading_slashes(vs.name)
linear_model_layer = _LinearModel(
feature_columns=feature_columns,
units=units,
sparse_combiner=sparse_combiner,
weight_collections=weight_collections,
trainable=trainable,
- name='linear_model')
+ name=model_name)
retval = linear_model_layer(features) # pylint: disable=not-callable
if cols_to_vars is not None:
cols_to_vars.update(linear_model_layer.cols_to_vars())
@@ -422,13 +468,25 @@ def linear_model(features,
def _add_to_collections(var, weight_collections):
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collections(weight_collections, constituent_var)
- else:
- ops.add_to_collections(weight_collections, var)
+ """Adds a var to the list of weight_collections provided.
+
+ Handles the case for partitioned and non-partitioned variables.
+
+ Args:
+ var: A variable or Partitioned Variable.
+ weight_collections: List of collections to add variable to.
+ """
+ for weight_collection in weight_collections:
+ # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
+ if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
+ continue
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collection(weight_collection, constituent_var)
+ else:
+ ops.add_to_collection(weight_collection, var)
class _FCLinearWrapper(base.Layer):
@@ -536,8 +594,11 @@ class _LinearModel(training.Model):
name=None,
**kwargs):
super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = _clean_feature_columns(feature_columns)
+ self._feature_columns = _normalize_feature_columns(
+ feature_columns)
self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
@@ -643,7 +704,7 @@ def _transform_features(features, feature_columns):
Returns:
A `dict` mapping `_FeatureColumn` to `Tensor` and `SparseTensor` values.
"""
- feature_columns = _clean_feature_columns(feature_columns)
+ feature_columns = _normalize_feature_columns(feature_columns)
outputs = {}
with ops.name_scope(
None, default_name='transform_features', values=features.values()):
@@ -911,7 +972,8 @@ def shared_embedding_columns(
tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
which to restore the column weights. Required if `ckpt_to_load_from` is
not `None`.
- max_norm: If not `None`, embedding values are l2-normalized to this value.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
trainable: Whether or not the embedding is trainable. Default is True.
Returns:
@@ -925,7 +987,12 @@ def shared_embedding_columns(
ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
is specified.
ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: if eager execution is enabled.
"""
+ if context.executing_eagerly():
+ raise RuntimeError('shared_embedding_columns are not supported when eager '
+ 'execution is enabled.')
+
if (dimension is None) or (dimension < 1):
raise ValueError('Invalid dimension {}.'.format(dimension))
if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
@@ -970,16 +1037,6 @@ def shared_embedding_columns(
shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
shared_embedding_collection_name += '_shared_embedding'
- # Create the state (_SharedEmbeddingColumnLayer) here.
- embedding_shape = num_buckets, dimension
-
- shared_embedding_column_layer = _EmbeddingColumnLayer(
- embedding_shape=embedding_shape,
- initializer=initializer,
- weight_collections=[],
- trainable=trainable,
- name=shared_embedding_collection_name)
-
result = []
for column in categorical_columns:
result.append(
@@ -988,16 +1045,12 @@ def shared_embedding_columns(
initializer=initializer,
dimension=dimension,
combiner=combiner,
- var_scope_name=shared_embedding_collection_name,
+ shared_embedding_collection_name=shared_embedding_collection_name,
ckpt_to_load_from=ckpt_to_load_from,
tensor_name_in_ckpt=tensor_name_in_ckpt,
max_norm=max_norm,
trainable=trainable))
- for single_result in result:
- single_result._set_layer(shared_embedding_column_layer) # pylint: disable=protected-access
- single_result._set_all_columns(result) # pylint: disable=protected-access
-
return result
@@ -1182,12 +1235,13 @@ def categorical_column_with_hash_bucket(key,
Use this when your sparse features are in string or integer format, and you
want to distribute your inputs into a finite number of buckets by hashing.
- output_id = Hash(input_feature_string) % bucket_size
+ output_id = Hash(input_feature_string) % bucket_size for string type input.
+ For int type input, the value is converted to its string representation first
+ and then hashed by the same formula.
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example:
@@ -1249,8 +1303,7 @@ def categorical_column_with_vocabulary_file(key,
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
@@ -1366,8 +1419,7 @@ def categorical_column_with_vocabulary_list(
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
Example with `num_oov_buckets`:
In the following example, each input in `vocabulary_list` is assigned an ID
@@ -1480,8 +1532,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
For input dictionary `features`, `features[key]` is either `Tensor` or
`SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
- and `''` for string. Note that these values are independent of the
- `default_value` argument.
+ and `''` for string, which will be dropped by this feature column.
In the following examples, each input in the range `[0, 1000000)` is assigned
the same value. All other inputs are assigned `default_value` 0. Note that a
@@ -1538,8 +1589,14 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
def indicator_column(categorical_column):
"""Represents multi-hot representation of given categorical column.
- Used to wrap any `categorical_column_*` (e.g., to feed to DNN). Use
- `embedding_column` if the inputs are sparse.
+ - For DNN model, `indicator_column` can be used to wrap any
+ `categorical_column_*` (e.g., to feed to DNN). Consider to Use
+ `embedding_column` if the number of buckets/unique(values) are large.
+
+ - For Wide (aka linear) model, `indicator_column` is the internal
+ representation for categorical column when passing categorical column
+ directly (as any element in feature_columns) to `linear_model`. See
+ `linear_model` for details.
```python
name = indicator_column(categorical_column_with_vocabulary_list(
@@ -1782,9 +1839,7 @@ class _EmbeddingColumnLayer(base.Layer):
Args:
embedding_shape: Shape of the embedding variable used for lookup.
initializer: A variable initializer function to be used in embedding
- variable initialization. If not specified, defaults to
- `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
- `1/sqrt(dimension)`.
+ variable initialization.
weight_collections: A list of collection names to which the Variable will
be added. Note that, variables will also be added to collections
`tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
@@ -1799,6 +1854,15 @@ class _EmbeddingColumnLayer(base.Layer):
self._initializer = initializer
self._weight_collections = weight_collections
+ def set_weight_collections(self, weight_collections):
+ """Sets the weight collections for the layer.
+
+ Args:
+ weight_collections: A list of collection names to which the Variable will
+ be added.
+ """
+ self._weight_collections = weight_collections
+
def build(self, _):
self._embedding_weight_var = self.add_variable(
name='embedding_weights',
@@ -1806,11 +1870,8 @@ class _EmbeddingColumnLayer(base.Layer):
dtype=dtypes.float32,
initializer=self._initializer,
trainable=self.trainable)
- # self.add_variable already appends to GLOBAL_VARIABLES collection.
if self._weight_collections and not context.executing_eagerly():
- for weight_collection in self._weight_collections:
- if weight_collection != ops.GraphKeys.GLOBAL_VARIABLES:
- _add_to_collections(self._embedding_weight_var, [weight_collection])
+ _add_to_collections(self._embedding_weight_var, self._weight_collections)
self.built = True
def call(self, _):
@@ -1949,7 +2010,7 @@ def _create_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Creates a weighted sum for a dense or sparse column for linear_model."""
+ """Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, _CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
@@ -2048,7 +2109,34 @@ def _create_categorical_column_weighted_sum(column,
weight_collections,
trainable,
weight_var=None):
- """Create a weighted sum of a categorical column for linear_model."""
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """Create a weighted sum of a categorical column for linear_model.
+
+ Note to maintainer: As implementation details, the weighted sum is
+ implemented via embedding_lookup_sparse toward efficiency. Mathematically,
+ they are the same.
+
+ To be specific, conceptually, categorical column can be treated as multi-hot
+ vector. Say:
+
+ ```python
+ x = [0 0 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `c` in this case, which is same as `w[2]`.
+
+ Another example is
+
+ ```python
+ x = [0 1 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`.
+
+ For both cases, we can implement weighted sum via embedding_lookup with
+ sparse_combiner = "sum".
+ """
+
sparse_tensors = column._get_sparse_tensors( # pylint: disable=protected-access
builder,
weight_collections=weight_collections,
@@ -2070,7 +2158,7 @@ def _create_categorical_column_weighted_sum(column,
initializer=init_ops.zeros_initializer(),
trainable=trainable,
collections=weight_collections)
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
weight,
id_tensor,
sparse_weights=weight_tensor,
@@ -2242,7 +2330,7 @@ def _shape_offsets(shape):
# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
-def _to_sparse_input(input_tensor, ignore_value=None):
+def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
"""Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
If `input_tensor` is already a `SparseTensor`, just return it.
@@ -2286,8 +2374,22 @@ def _to_sparse_input(input_tensor, ignore_value=None):
input_tensor, out_type=dtypes.int64, name='dense_shape'))
-def _clean_feature_columns(feature_columns):
- """Verifies and normalizes `feature_columns` input."""
+def _normalize_feature_columns(feature_columns):
+ """Normalizes the `feature_columns` input.
+
+ This method converts the `feature_columns` to list type as best as it can. In
+ addition, verifies the type and other parts of feature_columns, required by
+ downstream library.
+
+ Args:
+ feature_columns: The raw feature columns, usually passed by users.
+
+ Returns:
+ The normalized feature column list.
+
+ Raises:
+ ValueError: for any invalid inputs, such as empty, duplicated names, etc.
+ """
if isinstance(feature_columns, _FeatureColumn):
feature_columns = [feature_columns]
@@ -2413,6 +2515,7 @@ class _BucketizedColumn(_DenseColumn, _CategoricalColumn,
def _get_sparse_tensors(self, inputs, weight_collections=None,
trainable=None):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
input_tensor = inputs.get(self)
batch_size = array_ops.shape(input_tensor)[0]
# By construction, source_column is always one-dimensional.
@@ -2491,7 +2594,7 @@ class _EmbeddingColumn(
})
# Return embedding lookup result.
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
@@ -2546,12 +2649,12 @@ def _get_graph_for_variable(var):
class _SharedEmbeddingColumn(
- _DenseColumn,
+ _DenseColumn, _SequenceDenseColumn,
collections.namedtuple(
'_SharedEmbeddingColumn',
('categorical_column', 'dimension', 'combiner', 'initializer',
- 'var_scope_name', 'ckpt_to_load_from', 'tensor_name_in_ckpt',
- 'max_norm', 'trainable'))):
+ 'shared_embedding_collection_name', 'ckpt_to_load_from',
+ 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
"""See `embedding_column`."""
@property
@@ -2562,7 +2665,7 @@ class _SharedEmbeddingColumn(
@property
def _var_scope_name(self):
- return self.var_scope_name
+ return self.shared_embedding_collection_name
@property
def _parse_example_spec(self):
@@ -2571,29 +2674,17 @@ class _SharedEmbeddingColumn(
def _transform_feature(self, inputs):
return inputs.get(self.categorical_column)
- def _set_layer(self, layer):
- self._layer = layer
-
- def _set_all_columns(self, all_columns):
- self._all_columns = all_columns
-
- def _reset_config(self):
- config = self._layer.get_config()
- config['embedding_shape'] = (
- self.categorical_column._num_buckets, # pylint: disable=protected-access
- self.dimension)
- config['initializer'] = self.initializer
- self._layer = self._layer.__class__.from_config(config)
- for column in self._all_columns:
- column._set_layer(self._layer) # pylint: disable=protected-access
-
@property
def _variable_shape(self):
if not hasattr(self, '_shape'):
self._shape = tensor_shape.vector(self.dimension)
return self._shape
- def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ def _get_dense_tensor_internal(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
# which is shared among all shared embeddings. Open a name_scope here, so
# that the ops for different columns have distinct names.
@@ -2604,17 +2695,38 @@ class _SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_weights = self._layer(
- None, scope=variable_scope.get_variable_scope())
- # If we're in graph mode and this is called with a different graph,
- # then we should reset.
- if not context.executing_eagerly() and (
- ops.get_default_graph() !=
- _get_graph_for_variable(embedding_weights)):
- self._reset_config()
- embedding_weights = self._layer(
- None, scope=variable_scope.get_variable_scope())
-
+ embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access
+ shared_embedding_collection = ops.get_collection(
+ self.shared_embedding_collection_name)
+ if shared_embedding_collection:
+ if len(shared_embedding_collection) > 1:
+ raise ValueError(
+ 'Collection {} can only contain one variable. '
+ 'Suggested fix A: Choose a unique name for this collection. '
+ 'Suggested fix B: Do not add any variables to this collection. '
+ 'The feature_column library already adds a variable under the '
+ 'hood.'.format(shared_embedding_collection))
+ embedding_weights = shared_embedding_collection[0]
+ if embedding_weights.get_shape() != embedding_shape:
+ raise ValueError(
+ 'Shared embedding collection {} contains variable {} of '
+ 'unexpected shape {}. Expected shape is {}. '
+ 'Suggested fix A: Choose a unique name for this collection. '
+ 'Suggested fix B: Do not add any variables to this collection. '
+ 'The feature_column library already adds a variable under the '
+ 'hood.'.format(self.shared_embedding_collection_name,
+ embedding_weights.name,
+ embedding_weights.get_shape(), embedding_shape))
+ else:
+ embedding_weights = variable_scope.get_variable(
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer,
+ trainable=self.trainable and trainable,
+ collections=weight_collections)
+ ops.add_to_collection(self.shared_embedding_collection_name,
+ embedding_weights)
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
if isinstance(to_restore, variables.PartitionedVariable):
@@ -2624,7 +2736,7 @@ class _SharedEmbeddingColumn(
})
# Return embedding lookup result.
- return _safe_embedding_lookup_sparse(
+ return embedding_ops.safe_embedding_lookup_sparse(
embedding_weights=embedding_weights,
sparse_ids=sparse_ids,
sparse_weights=sparse_weights,
@@ -2632,6 +2744,44 @@ class _SharedEmbeddingColumn(
name='%s_weights' % self.name,
max_norm=self.max_norm)
+ def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+ if isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type _SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+
+ def _get_sequence_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ if not isinstance(self.categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type _SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
def _create_tuple(shape, value):
"""Returns a tuple with given shape and filled with value."""
@@ -2753,7 +2903,7 @@ class _HashedCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
raise ValueError('SparseColumn input must be a SparseTensor.')
@@ -2804,7 +2954,7 @@ class _VocabularyFileCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2856,7 +3006,7 @@ class _VocabularyListCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(self.dtype)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if self.dtype.is_integer != input_tensor.dtype.is_integer:
raise ValueError(
@@ -2908,7 +3058,7 @@ class _IdentityCategoricalColumn(
return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
def _transform_feature(self, inputs):
- input_tensor = _to_sparse_input(inputs.get(self.key))
+ input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
if not input_tensor.dtype.is_integer:
raise ValueError(
@@ -2990,7 +3140,8 @@ class _WeightedCategoricalColumn(
self.dtype, weight_tensor.dtype))
if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
# The weight tensor can be a regular Tensor. In this case, sparsify it.
- weight_tensor = _to_sparse_input(weight_tensor, ignore_value=0.0)
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
if not weight_tensor.dtype.is_floating:
weight_tensor = math_ops.to_float(weight_tensor)
return (inputs.get(self.categorical_column), weight_tensor)
@@ -3077,161 +3228,6 @@ def _collect_leaf_level_keys(cross):
return leaf_level_keys
-# TODO(zakaria): Move this to embedding_ops and make it public.
-def _safe_embedding_lookup_sparse(embedding_weights,
- sparse_ids,
- sparse_weights=None,
- combiner='mean',
- default_id=None,
- name=None,
- partition_strategy='div',
- max_norm=None):
- """Lookup embedding results, accounting for invalid IDs and empty features.
-
- The partitioned embedding in `embedding_weights` must all be the same shape
- except for the first dimension. The first dimension is allowed to vary as the
- vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
- may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
- partitioner.
-
- Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
- with non-positive weight. For an entry with no features, the embedding vector
- for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
-
- The ids and weights may be multi-dimensional. Embeddings are always aggregated
- along the last dimension.
-
- Args:
- embedding_weights: A list of `P` float `Tensor`s or values representing
- partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
- created by partitioning along dimension 0. The total unpartitioned
- shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
- vocab size and `e_1, ..., e_m` are the embedding dimensions.
- sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
- ids. `d_0` is typically batch size.
- sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
- float weights corresponding to `sparse_ids`, or `None` if all weights
- are be assumed to be 1.0.
- combiner: A string specifying how to combine embedding results for each
- entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
- the default.
- default_id: The id to use for an entry with no features.
- name: A name for this operation (optional).
- partition_strategy: A string specifying the partitioning strategy.
- Currently `"div"` and `"mod"` are supported. Default is `"div"`.
- max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
- combining.
-
-
- Returns:
- Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
-
- Raises:
- ValueError: if `embedding_weights` is empty.
- """
- if embedding_weights is None:
- raise ValueError('Missing embedding_weights %s.' % embedding_weights)
- if isinstance(embedding_weights, variables.PartitionedVariable):
- embedding_weights = list(embedding_weights) # get underlying Variables.
- if not isinstance(embedding_weights, list):
- embedding_weights = [embedding_weights]
- if len(embedding_weights) < 1:
- raise ValueError('Missing embedding_weights %s.' % embedding_weights)
-
- dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
-
- with ops.name_scope(name, 'embedding_lookup',
- embedding_weights + [sparse_ids,
- sparse_weights]) as scope:
- # Reshape higher-rank sparse ids and weights to linear segment ids.
- original_shape = sparse_ids.dense_shape
- original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
- original_rank = (
- array_ops.size(original_shape)
- if original_rank_dim.value is None
- else original_rank_dim.value)
- sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
- math_ops.reduce_prod(
- array_ops.slice(original_shape, [0], [original_rank - 1])),
- array_ops.gather(original_shape, original_rank - 1)])
- if sparse_weights is not None:
- sparse_weights = sparse_tensor_lib.SparseTensor(
- sparse_ids.indices,
- sparse_weights.values, sparse_ids.dense_shape)
-
- # Prune invalid ids and weights.
- sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
- if combiner != 'sum':
- sparse_ids, sparse_weights = _prune_invalid_weights(
- sparse_ids, sparse_weights)
-
- # Fill in dummy values for empty features, if necessary.
- sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
- default_id or
- 0)
- if sparse_weights is not None:
- sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
-
- result = embedding_ops.embedding_lookup_sparse(
- embedding_weights,
- sparse_ids,
- sparse_weights,
- combiner=combiner,
- partition_strategy=partition_strategy,
- name=None if default_id is None else scope,
- max_norm=max_norm)
-
- if default_id is None:
- # Broadcast is_row_empty to the same shape as embedding_lookup_result,
- # for use in Select.
- is_row_empty = array_ops.tile(
- array_ops.reshape(is_row_empty, [-1, 1]),
- array_ops.stack([1, array_ops.shape(result)[1]]))
-
- result = array_ops.where(is_row_empty,
- array_ops.zeros_like(result),
- result,
- name=scope)
-
- # Reshape back from linear ids back into higher-dimensional dense result.
- final_result = array_ops.reshape(
- result,
- array_ops.concat([
- array_ops.slice(
- math_ops.cast(original_shape, dtypes.int32), [0],
- [original_rank - 1]),
- array_ops.slice(array_ops.shape(result), [1], [-1])
- ], 0))
- final_result.set_shape(tensor_shape.unknown_shape(
- (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
- return final_result
-
-
-def _prune_invalid_ids(sparse_ids, sparse_weights):
- """Prune invalid IDs (< 0) from the input ids and weights."""
- is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
- if sparse_weights is not None:
- is_id_valid = math_ops.logical_and(
- is_id_valid,
- array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
- sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
- if sparse_weights is not None:
- sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
- return sparse_ids, sparse_weights
-
-
-def _prune_invalid_weights(sparse_ids, sparse_weights):
- """Prune invalid weights (< 0) from the input ids and weights."""
- if sparse_weights is not None:
- is_weights_valid = math_ops.greater(sparse_weights.values, 0)
- sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
- sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
- return sparse_ids, sparse_weights
-
-
class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_IndicatorColumn',
['categorical_column'])):
@@ -3268,10 +3264,14 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
sp_ids=id_tensor,
sp_values=weight_tensor,
vocab_size=int(self._variable_shape[-1]))
- # Remove (?, -1) index
+ # Remove (?, -1) index.
weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
weighted_column.dense_shape)
- return sparse_ops.sparse_tensor_to_dense(weighted_column)
+ # Use scatter_nd to merge duplicated indices if existed,
+ # instead of sparse_tensor_to_dense.
+ return array_ops.scatter_nd(weighted_column.indices,
+ weighted_column.values,
+ weighted_column.dense_shape)
dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
id_tensor, default_value=-1)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 0af7b9baa9..5bb47bfa47 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1257,14 +1257,14 @@ class CrossedColumnTest(test.TestCase):
}, (crossed,))
-def get_linear_model_bias():
- with variable_scope.variable_scope('linear_model', reuse=True):
+def get_linear_model_bias(name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
return variable_scope.get_variable('bias_weights')
-def get_linear_model_column_var(column):
+def get_linear_model_column_var(column, name='linear_model'):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- 'linear_model/' + column.name)[0]
+ name + '/' + column.name)[0]
def get_keras_linear_model_predictions(features,
@@ -1928,6 +1928,27 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
+ def test_multiple_linear_models(self):
+ price = fc.numeric_column('price')
+ with ops.Graph().as_default():
+ features1 = {'price': [[1.], [5.]]}
+ features2 = {'price': [[2.], [10.]]}
+ predictions1 = fc.linear_model(features1, [price])
+ predictions2 = fc.linear_model(features2, [price])
+ bias1 = get_linear_model_bias(name='linear_model')
+ bias2 = get_linear_model_bias(name='linear_model_1')
+ price_var1 = get_linear_model_column_var(price, name='linear_model')
+ price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias1.eval())
+ sess.run(price_var1.assign([[10.]]))
+ sess.run(bias1.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions1.eval())
+ self.assertAllClose([0.], bias2.eval())
+ sess.run(price_var2.assign([[10.]]))
+ sess.run(bias2.assign([5.]))
+ self.assertAllClose([[25.], [105.]], predictions2.eval())
+
class _LinearModelTest(test.TestCase):
@@ -2586,7 +2607,7 @@ class _LinearModelTest(test.TestCase):
class InputLayerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_retrieving_input(self):
features = {'a': [0.]}
input_layer = InputLayer(fc.numeric_column('a'))
@@ -4559,12 +4580,12 @@ class IndicatorColumnTest(test.TestCase):
weights = fc.weighted_categorical_column(ids, 'weights')
indicator = fc.indicator_column(weights)
features = {
- 'ids': constant_op.constant([['c', 'b', 'a']]),
- 'weights': constant_op.constant([[2., 4., 6.]])
+ 'ids': constant_op.constant([['c', 'b', 'a', 'c']]),
+ 'weights': constant_op.constant([[2., 4., 6., 1.]])
}
indicator_tensor = _transform_features(features, [indicator])[indicator]
with _initialized_session():
- self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+ self.assertAllEqual([[6., 4., 3.]], indicator_tensor.eval())
def test_transform_with_missing_value_in_weighted_column(self):
# Github issue 12583
@@ -5329,9 +5350,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(embedding_column_a.ckpt_to_load_from)
self.assertIsNone(embedding_column_b.ckpt_to_load_from)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('aaa_bbb_shared_embedding',
- embedding_column_b.var_scope_name)
+ embedding_column_b.shared_embedding_collection_name)
self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
self.assertIsNone(embedding_column_a.max_norm)
@@ -5378,9 +5399,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('my_combiner', embedding_column_b.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_b.var_scope_name)
+ embedding_column_b.shared_embedding_collection_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
@@ -5431,7 +5452,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertEqual(embedding_dimension, embedding_column_a.dimension)
self.assertEqual('my_combiner', embedding_column_a.combiner)
self.assertEqual('shared_embedding_collection_name',
- embedding_column_a.var_scope_name)
+ embedding_column_a.shared_embedding_collection_name)
self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
self.assertEqual(42., embedding_column_a.max_norm)
@@ -5615,6 +5636,72 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+ def test_get_dense_tensor_weight_collections(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+ input_features = {'aaa': input_a, 'bbb': input_b}
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups_a = (
+ # example 0:
+ (7., 11.), # ids [2], embedding = [7, 11]
+ # example 1:
+ (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ )
+ expected_lookups_b = (
+ # example 0:
+ (1., 2.), # ids [0], embedding = [1, 2]
+ # example 1:
+ (0., 0.), # ids [], embedding = [0, 0]
+ )
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ fc.input_layer(
+ input_features, [embedding_column_a, embedding_column_b],
+ weight_collections=('my_vars',))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in global_vars))
+ my_vars = ops.get_collection('my_vars')
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in my_vars))
+
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
new file mode 100644
index 0000000000..b4dd23f58d
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -0,0 +1,3600 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This API defines FeatureColumn abstraction.
+
+FeatureColumns provide a high level abstraction for ingesting and representing
+features. FeatureColumns are also the primary way of encoding features for
+canned @{tf.estimator.Estimator}s.
+
+When using FeatureColumns with `Estimators`, the type of feature column you
+should choose depends on (1) the feature type and (2) the model type.
+
+1. Feature type:
+
+ * Continuous features can be represented by `numeric_column`.
+ * Categorical features can be represented by any `categorical_column_with_*`
+ column:
+ - `categorical_column_with_vocabulary_list`
+ - `categorical_column_with_vocabulary_file`
+ - `categorical_column_with_hash_bucket`
+ - `categorical_column_with_identity`
+ - `weighted_categorical_column`
+
+2. Model type:
+
+ * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
+
+ Continuous features can be directly fed into deep neural network models.
+
+ age_column = numeric_column("age")
+
+ To feed sparse features into DNN models, wrap the column with
+ `embedding_column` or `indicator_column`. `indicator_column` is recommended
+ for features with only a few possible values. For features with many
+ possible values, to reduce the size of your model, `embedding_column` is
+ recommended.
+
+ embedded_dept_column = embedding_column(
+ categorical_column_with_vocabulary_list(
+ "department", ["math", "philosophy", ...]), dimension=10)
+
+ * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).
+
+ Sparse features can be fed directly into linear models. They behave like an
+ indicator column but with an efficient implementation.
+
+ dept_column = categorical_column_with_vocabulary_list("department",
+ ["math", "philosophy", "english"])
+
+ It is recommended that continuous features be bucketized before being
+ fed into linear models.
+
+ bucketized_age_column = bucketized_column(
+ source_column=age_column,
+ boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
+
+ Sparse features can be crossed (also known as conjuncted or combined) in
+ order to form non-linearities, and then fed into linear models.
+
+ cross_dept_age_column = crossed_column(
+ columns=["department", bucketized_age_column],
+ hash_bucket_size=1000)
+
+Example of building canned `Estimator`s using FeatureColumns:
+
+ ```python
+ # Define features and transformations
+ deep_feature_columns = [age_column, embedded_dept_column]
+ wide_feature_columns = [dept_column, bucketized_age_column,
+ cross_dept_age_column]
+
+ # Build deep model
+ estimator = DNNClassifier(
+ feature_columns=deep_feature_columns,
+ hidden_units=[500, 250, 50])
+ estimator.train(...)
+
+ # Or build a wide model
+ estimator = LinearClassifier(
+ feature_columns=wide_feature_columns)
+ estimator.train(...)
+
+ # Or build a wide and deep model!
+ estimator = DNNLinearCombinedClassifier(
+ linear_feature_columns=wide_feature_columns,
+ dnn_feature_columns=deep_feature_columns,
+ dnn_hidden_units=[500, 250, 50])
+ estimator.train(...)
+ ```
+
+
+FeatureColumns can also be transformed into a generic input layer for
+custom models using `input_layer`.
+
+Example of building model using FeatureColumns, this can be used in a
+`model_fn` which is given to the {tf.estimator.Estimator}:
+
+ ```python
+ # Building model via layers
+
+ deep_feature_columns = [age_column, embedded_dept_column]
+ columns_to_tensor = parse_feature_columns_from_examples(
+ serialized=my_data,
+ feature_columns=deep_feature_columns)
+ first_layer = input_layer(
+ features=columns_to_tensor,
+ feature_columns=deep_feature_columns)
+ second_layer = fully_connected(first_layer, ...)
+ ```
+
+NOTE: Functions prefixed with "_" indicate experimental or private parts of
+the API subject to change, and should not be relied upon!
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import math
+
+import numpy as np
+import six
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.feature_column import feature_column as fc_old
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine import training
+from tensorflow.python.layers import base
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.util import nest
+
+
+def _internal_input_layer(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None,
+ scope=None):
+ """See input_layer. `scope` is a name or variable scope to use."""
+
+ feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access
+ for column in feature_columns:
+ if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access
+ raise ValueError(
+ 'Items of feature_columns must be a _DenseColumn. '
+ 'You can wrap a categorical column with an '
+ 'embedding_column or indicator_column. Given: {}'.format(column))
+ weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
+ weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
+ weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ # a non-None `scope` can allow for variable reuse, when, e.g., this function
+ # is wrapped by a `make_template`.
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
+ output_tensors = []
+ ordered_columns = []
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name): # pylint: disable=protected-access
+ tensor = column._get_dense_tensor( # pylint: disable=protected-access
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
+ batch_size = array_ops.shape(tensor)[0]
+ output_tensors.append(
+ array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+ if cols_to_vars is not None:
+ # Retrieve any variables created (some _DenseColumn's don't create
+ # variables, in which case an empty list is returned).
+ cols_to_vars[column] = ops.get_collection(
+ ops.GraphKeys.GLOBAL_VARIABLES,
+ scope=variable_scope.get_variable_scope().name)
+ _verify_static_batch_size_equality(output_tensors, ordered_columns)
+ return array_ops.concat(output_tensors, 1)
+
+
+def input_layer(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+ Generally a single example in training data is described with FeatureColumns.
+ At the first layer of the model, this column oriented data should be converted
+ to a single `Tensor`.
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ keywords_embedded = embedding_column(
+ categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
+ columns = [price, keywords_embedded, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ for units in [128, 64, 32]:
+ dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
+ prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values can be a `SparseTensor` or a `Tensor` depends on
+ corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as inputs
+ to your model. All items should be instances of classes derived from
+ `_DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical features,
+ you can wrap them with an `embedding_column` or `indicator_column`.
+ weight_collections: A list of collection names to which the Variable will be
+ added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with a
+ mapping from `_FeatureColumn` to list of `Variable`s. For example, after
+ the call, we might have cols_to_vars =
+ {_EmbeddingColumn(
+ categorical_column=_HashedCategoricalColumn(
+ key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
+ dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
+ <tf.Variable 'some_variable:1' shape=(5, 10)]}
+ If a column creates no variables, its value will be an empty list.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
+ """
+ return _internal_input_layer(features, feature_columns, weight_collections,
+ trainable, cols_to_vars)
+
+
+# TODO(akshayka): InputLayer should be a subclass of Layer, and it
+# should implement the logic in input_layer using Layer's build-and-call
+# paradigm; input_layer should create an instance of InputLayer and
+# return the result of invoking its apply method, just as functional layers do.
+class InputLayer(object):
+ """An object-oriented version of `input_layer` that reuses variables."""
+
+ def __init__(self,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """See `input_layer`."""
+
+ self._feature_columns = feature_columns
+ self._weight_collections = weight_collections
+ self._trainable = trainable
+ self._cols_to_vars = cols_to_vars
+ self._input_layer_template = template.make_template(
+ 'feature_column_input_layer',
+ _internal_input_layer,
+ create_scope_now_=True)
+ self._scope = self._input_layer_template.variable_scope
+
+ def __call__(self, features):
+ return self._input_layer_template(
+ features=features,
+ feature_columns=self._feature_columns,
+ weight_collections=self._weight_collections,
+ trainable=self._trainable,
+ cols_to_vars=None,
+ scope=self._scope)
+
+ @property
+ def non_trainable_variables(self):
+ return self._input_layer_template.non_trainable_variables
+
+ @property
+ def non_trainable_weights(self):
+ return self._input_layer_template.non_trainable_weights
+
+ @property
+ def trainable_variables(self):
+ return self._input_layer_template.trainable_variables
+
+ @property
+ def trainable_weights(self):
+ return self._input_layer_template.trainable_weights
+
+ @property
+ def variables(self):
+ return self._input_layer_template.variables
+
+ @property
+ def weights(self):
+ return self._input_layer_template.weights
+
+
+def linear_model(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ """Returns a linear prediction `Tensor` based on given `feature_columns`.
+
+ This function generates a weighted sum based on output dimension `units`.
+ Weighted sum refers to logits in classification problems. It refers to the
+ prediction itself for linear regression problems.
+
+ Note on supported columns: `linear_model` treats categorical columns as
+ `indicator_column`s. To be specific, assume the input as `SparseTensor` looks
+ like:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+ `linear_model` assigns weights for the presence of "a", "b", "c' implicitly,
+ just like `indicator_column`, while `input_layer` explicitly requires wrapping
+ each of categorical columns with an `embedding_column` or an
+ `indicator_column`.
+
+ Example of usage:
+
+ ```python
+ price = numeric_column('price')
+ price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])
+ keywords = categorical_column_with_hash_bucket("keywords", 10K)
+ keywords_price = crossed_column('keywords', price_buckets, ...)
+ columns = [price_buckets, keywords, keywords_price ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ prediction = linear_model(features, columns)
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values are `Tensor` or `SparseTensor` depending on
+ corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as inputs
+ to your model. All items should be instances of classes derived from
+ `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+ categorical column independently. Currently "mean", "sqrtn" and "sum" are
+ supported, with "sum" the default for linear model. "sqrtn" often achieves
+ good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+ with `sparse_combiner` as "mean", the linear model outputs conceptly are:
+ ```
+ y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ weight_collections: A list of collection names to which the Variable will be
+ added. Note that, variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with a
+ mapping from `_FeatureColumn` to associated list of `Variable`s. For
+ example, after the call, we might have cols_to_vars = {
+ _NumericColumn(
+ key='numeric_feature1', shape=(1,):
+ [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
+ 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
+ _NumericColumn(
+ key='numeric_feature2', shape=(2,)):
+ [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
+ If a column creates no variables, its value will be an empty list. Note
+ that cols_to_vars will also contain a string key 'bias' that maps to a
+ list of Variables.
+
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its shape
+ is (batch_size, units) and its dtype is `float32`.
+
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
+ nor `_CategoricalColumn`.
+ """
+ with variable_scope.variable_scope(None, 'linear_model') as vs:
+ model_name = _strip_leading_slashes(vs.name)
+ linear_model_layer = _LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ name=model_name)
+ retval = linear_model_layer(features) # pylint: disable=not-callable
+ if cols_to_vars is not None:
+ cols_to_vars.update(linear_model_layer.cols_to_vars())
+ return retval
+
+
+def _add_to_collections(var, weight_collections):
+ """Adds a var to the list of weight_collections provided.
+
+ Handles the case for partitioned and non-partitioned variables.
+
+ Args:
+ var: A variable or Partitioned Variable.
+ weight_collections: List of collections to add variable to.
+ """
+ for weight_collection in weight_collections:
+ # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
+ if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
+ continue
+ # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
+ # so that we don't have to do this check.
+ if isinstance(var, variables.PartitionedVariable):
+ for constituent_var in list(var):
+ ops.add_to_collection(weight_collection, constituent_var)
+ else:
+ ops.add_to_collection(weight_collection, var)
+
+
+class _FCLinearWrapper(base.Layer):
+ """Wraps a _FeatureColumn in a layer for use in a linear model.
+
+ See `linear_model` above.
+ """
+
+ def __init__(self,
+ feature_column,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_FCLinearWrapper, self).__init__(
+ trainable=trainable, name=name, **kwargs)
+ self._feature_column = feature_column
+ self._units = units
+ self._sparse_combiner = sparse_combiner
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ else:
+ num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
+ weight = self.add_variable(
+ name='weights',
+ shape=[num_elements, self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ _add_to_collections(weight, self._weight_collections)
+ self._weight_var = weight
+ self.built = True
+
+ def call(self, builder):
+ weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
+ column=self._feature_column,
+ builder=builder,
+ units=self._units,
+ sparse_combiner=self._sparse_combiner,
+ weight_collections=self._weight_collections,
+ trainable=self.trainable,
+ weight_var=self._weight_var)
+ return weighted_sum
+
+
+class _BiasLayer(base.Layer):
+ """A layer for the bias term.
+ """
+
+ def __init__(self,
+ units=1,
+ trainable=True,
+ weight_collections=None,
+ name=None,
+ **kwargs):
+ super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
+ self._units = units
+ self._weight_collections = weight_collections
+
+ def build(self, _):
+ self._bias_variable = self.add_variable(
+ 'bias_weights',
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+ _add_to_collections(self._bias_variable, self._weight_collections)
+ self.built = True
+
+ def call(self, _):
+ return self._bias_variable
+
+
+def _get_expanded_variable_list(variable):
+ if (isinstance(variable, variables.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ return [variable] # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ return list(variable)
+
+
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
+
+class _LinearModel(training.Model):
+ """Creates a linear model using feature columns.
+
+ See `linear_model` for details.
+ """
+
+ def __init__(self,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(_LinearModel, self).__init__(name=name, **kwargs)
+ self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
+ feature_columns)
+ self._weight_collections = list(weight_collections or [])
+ if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
+ if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
+ self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
+
+ column_layers = {}
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
+ # Having the fully expressed variable scope name ends up doubly
+ # expressing the outer scope (scope with which this method was called)
+ # in the name of the variable that would get created.
+ column_name = _strip_leading_slashes(vs.name)
+ column_layer = _FCLinearWrapper(column, units, sparse_combiner,
+ self._weight_collections, trainable,
+ column_name, **kwargs)
+ column_layers[column_name] = column_layer
+ self._column_layers = self._add_layers(column_layers)
+ self._bias_layer = _BiasLayer(
+ units=units,
+ trainable=trainable,
+ weight_collections=self._weight_collections,
+ name='bias_layer',
+ **kwargs)
+ self._cols_to_vars = {}
+
+ def cols_to_vars(self):
+ """Returns a dict mapping _FeatureColumns to variables.
+
+ See `linear_model` for more information.
+ This is not populated till `call` is called i.e. layer is built.
+ """
+ return self._cols_to_vars
+
+ def call(self, features):
+ with variable_scope.variable_scope(self.name):
+ for column in self._feature_columns:
+ if not isinstance(
+ column,
+ (
+ fc_old._DenseColumn, # pylint: disable=protected-access
+ fc_old._CategoricalColumn)): # pylint: disable=protected-access
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ weighted_sums = []
+ ordered_columns = []
+ builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
+ for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
+ column = layer._feature_column # pylint: disable=protected-access
+ ordered_columns.append(column)
+ weighted_sum = layer(builder)
+ weighted_sums.append(weighted_sum)
+ self._cols_to_vars[column] = ops.get_collection(
+ ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
+
+ _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias,
+ self._bias_layer( # pylint: disable=not-callable
+ builder,
+ scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
+ name='weighted_sum')
+ bias = self._bias_layer.variables[0]
+ self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
+ return predictions
+
+ def _add_layers(self, layers):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ for name, layer in layers.items():
+ setattr(self, 'layer-%s' % name, layer)
+ return layers
+
+
+def _transform_features(features, feature_columns, state_manager):
+ """Returns transformed features based on features columns passed in.
+
+ Please note that most probably you would not need to use this function. Please
+ check `input_layer` and `linear_model` to see whether they will
+ satisfy your use case or not.
+
+ Example:
+
+ ```python
+ # Define features and transformations
+ crosses_a_x_b = crossed_column(
+ columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
+ price_buckets = bucketized_column(
+ source_column=numeric_column("price"), boundaries=[...])
+
+ columns = [crosses_a_x_b, price_buckets]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ transformed = transform_features(features=features, feature_columns=columns)
+
+ assertCountEqual(columns, transformed.keys())
+ ```
+
+ Args:
+ features: A mapping from key to tensors. `FeatureColumn`s look up via these
+ keys. For example `numeric_column('price')` will look at 'price' key in
+ this dict. Values can be a `SparseTensor` or a `Tensor` depends on
+ corresponding `FeatureColumn`.
+ feature_columns: An iterable containing all the `FeatureColumn`s.
+ state_manager: A StateManager object that holds the FeatureColumn state.
+
+ Returns:
+ A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
+ """
+ feature_columns = _normalize_feature_columns(feature_columns)
+ outputs = {}
+ with ops.name_scope(
+ None, default_name='transform_features', values=features.values()):
+ transformation_cache = FeatureTransformationCache(features)
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ with ops.name_scope(None, default_name=column.name):
+ outputs[column] = transformation_cache.get(column, state_manager)
+ return outputs
+
+
+def make_parse_example_spec(feature_columns):
+ """Creates parsing spec dictionary from input feature_columns.
+
+ The returned dictionary can be used as arg 'features' in `tf.parse_example`.
+
+ Typical usage example:
+
+ ```python
+ # Define features and transformations
+ feature_a = categorical_column_with_vocabulary_file(...)
+ feature_b = numeric_column(...)
+ feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...)
+ feature_a_x_feature_c = crossed_column(
+ columns=["feature_a", feature_c_bucketized], ...)
+
+ feature_columns = set(
+ [feature_b, feature_c_bucketized, feature_a_x_feature_c])
+ features = tf.parse_example(
+ serialized=serialized_examples,
+ features=make_parse_example_spec(feature_columns))
+ ```
+
+ For the above example, make_parse_example_spec would return the dict:
+
+ ```python
+ {
+ "feature_a": parsing_ops.VarLenFeature(tf.string),
+ "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
+ "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
+ }
+ ```
+
+ Args:
+ feature_columns: An iterable containing all feature columns. All items
+ should be instances of classes derived from `FeatureColumn`.
+
+ Returns:
+ A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
+ value.
+
+ Raises:
+ ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
+ instance.
+ """
+ result = {}
+ for column in feature_columns:
+ if not isinstance(column, FeatureColumn):
+ raise ValueError('All feature_columns must be FeatureColumn instances. '
+ 'Given: {}'.format(column))
+ config = column.parse_example_spec
+ for key, value in six.iteritems(config):
+ if key in result and value != result[key]:
+ raise ValueError(
+ 'feature_columns contain different parse_spec for key '
+ '{}. Given {} and {}'.format(key, value, result[key]))
+ result.update(config)
+ return result
+
+
+def embedding_column(
+ categorical_column, dimension, combiner='mean', initializer=None,
+ ckpt_to_load_from=None, tensor_name_in_ckpt=None, max_norm=None,
+ trainable=True):
+ """`_DenseColumn` that converts from sparse, categorical input.
+
+ Use this when your inputs are sparse, but you want to convert them to a dense
+ representation (e.g., to feed to a DNN).
+
+ Inputs must be a `_CategoricalColumn` created by any of the
+ `categorical_column_*` function. Here is an example of using
+ `embedding_column` with `DNNClassifier`:
+
+ ```python
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [embedding_column(video_id, 9),...]
+
+ estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
+
+ label_column = ...
+ def input_fn():
+ features = tf.parse_example(
+ ..., features=make_parse_example_spec(columns + [label_column]))
+ labels = features.pop(label_column.name)
+ return features, labels
+
+ estimator.train(input_fn=input_fn, steps=100)
+ ```
+
+ Here is an example using `embedding_column` with model_fn:
+
+ ```python
+ def model_fn(features, ...):
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [embedding_column(video_id, 9),...]
+ dense_tensor = input_layer(features, columns)
+ # Form DNN layers, calculate loss, and return EstimatorSpec.
+ ...
+ ```
+
+ Args:
+ categorical_column: A `_CategoricalColumn` created by a
+ `categorical_column_with_*` function. This column produces the sparse IDs
+ that are inputs to the embedding lookup.
+ dimension: An integer specifying dimension of the embedding, must be > 0.
+ combiner: A string specifying how to reduce if there are multiple entries
+ in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ 'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+ with bag-of-words columns. Each of this can be thought as example level
+ normalizations on the column. For more information, see
+ `tf.embedding_lookup_sparse`.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ ckpt_to_load_from: String representing checkpoint name/pattern from which to
+ restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+ which to restore the column weights. Required if `ckpt_to_load_from` is
+ not `None`.
+ max_norm: If not `None`, embedding values are l2-normalized to this value.
+ trainable: Whether or not the embedding is trainable. Default is True.
+
+ Returns:
+ `_DenseColumn` that converts from sparse input.
+
+ Raises:
+ ValueError: if `dimension` not > 0.
+ ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+ is specified.
+ ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: If eager execution is enabled.
+ """
+ if (dimension is None) or (dimension < 1):
+ raise ValueError('Invalid dimension {}.'.format(dimension))
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError('Must specify both `ckpt_to_load_from` and '
+ '`tensor_name_in_ckpt` or none of them.')
+
+ if (initializer is not None) and (not callable(initializer)):
+ raise ValueError('initializer must be callable if specified. '
+ 'Embedding of column_name: {}'.format(
+ categorical_column.name))
+ if initializer is None:
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1 / math.sqrt(dimension))
+
+ return EmbeddingColumn(
+ categorical_column=categorical_column,
+ dimension=dimension,
+ combiner=combiner,
+ initializer=initializer,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable)
+
+
+def shared_embedding_columns(
+ categorical_columns, dimension, combiner='mean', initializer=None,
+ shared_embedding_collection_name=None, ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+ """List of dense columns that convert from sparse, categorical input.
+
+ This is similar to `embedding_column`, except that it produces a list of
+ embedding columns that share the same embedding weights.
+
+ Use this when your inputs are sparse and of the same type (e.g. watched and
+ impression video IDs that share the same vocabulary), and you want to convert
+ them to a dense representation (e.g., to feed to a DNN).
+
+ Inputs must be a list of categorical columns created by any of the
+ `categorical_column_*` function. They must all be of the same type and have
+ the same arguments except `key`. E.g. they can be
+ categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
+ all columns could also be weighted_categorical_column.
+
+ Here is an example embedding of two features for a DNNClassifier model:
+
+ ```python
+ watched_video_id = categorical_column_with_vocabulary_file(
+ 'watched_video_id', video_vocabulary_file, video_vocabulary_size)
+ impression_video_id = categorical_column_with_vocabulary_file(
+ 'impression_video_id', video_vocabulary_file, video_vocabulary_size)
+ columns = shared_embedding_columns(
+ [watched_video_id, impression_video_id], dimension=10)
+
+ estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)
+
+ label_column = ...
+ def input_fn():
+ features = tf.parse_example(
+ ..., features=make_parse_example_spec(columns + [label_column]))
+ labels = features.pop(label_column.name)
+ return features, labels
+
+ estimator.train(input_fn=input_fn, steps=100)
+ ```
+
+ Here is an example using `shared_embedding_columns` with model_fn:
+
+ ```python
+ def model_fn(features, ...):
+ watched_video_id = categorical_column_with_vocabulary_file(
+ 'watched_video_id', video_vocabulary_file, video_vocabulary_size)
+ impression_video_id = categorical_column_with_vocabulary_file(
+ 'impression_video_id', video_vocabulary_file, video_vocabulary_size)
+ columns = shared_embedding_columns(
+ [watched_video_id, impression_video_id], dimension=10)
+ dense_tensor = input_layer(features, columns)
+ # Form DNN layers, calculate loss, and return EstimatorSpec.
+ ...
+ ```
+
+ Args:
+ categorical_columns: List of categorical columns created by a
+ `categorical_column_with_*` function. These columns produce the sparse IDs
+ that are inputs to the embedding lookup. All columns must be of the same
+ type and have the same arguments except `key`. E.g. they can be
+ categorical_column_with_vocabulary_file with the same vocabulary_file.
+ Some or all columns could also be weighted_categorical_column.
+ dimension: An integer specifying dimension of the embedding, must be > 0.
+ combiner: A string specifying how to reduce if there are multiple entries
+ in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
+ 'mean' the default. 'sqrtn' often achieves good accuracy, in particular
+ with bag-of-words columns. Each of this can be thought as example level
+ normalizations on the column. For more information, see
+ `tf.embedding_lookup_sparse`.
+ initializer: A variable initializer function to be used in embedding
+ variable initialization. If not specified, defaults to
+ `tf.truncated_normal_initializer` with mean `0.0` and standard deviation
+ `1/sqrt(dimension)`.
+ shared_embedding_collection_name: Optional collective name of these columns.
+ If not given, a reasonable name will be chosen based on the names of
+ `categorical_columns`.
+ ckpt_to_load_from: String representing checkpoint name/pattern from which to
+ restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
+ tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
+ which to restore the column weights. Required if `ckpt_to_load_from` is
+ not `None`.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
+ trainable: Whether or not the embedding is trainable. Default is True.
+
+ Returns:
+ A list of dense columns that converts from sparse input. The order of
+ results follows the ordering of `categorical_columns`.
+
+ Raises:
+ ValueError: if `dimension` not > 0.
+ ValueError: if any of the given `categorical_columns` is of different type
+ or has different arguments than the others.
+ ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
+ is specified.
+ ValueError: if `initializer` is specified and is not callable.
+ RuntimeError: if eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError('shared_embedding_columns are not supported when eager '
+ 'execution is enabled.')
+
+ if (dimension is None) or (dimension < 1):
+ raise ValueError('Invalid dimension {}.'.format(dimension))
+ if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
+ raise ValueError('Must specify both `ckpt_to_load_from` and '
+ '`tensor_name_in_ckpt` or none of them.')
+
+ if (initializer is not None) and (not callable(initializer)):
+ raise ValueError('initializer must be callable if specified.')
+ if initializer is None:
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1. / math.sqrt(dimension))
+
+ # Sort the columns so the default collection name is deterministic even if the
+ # user passes columns from an unsorted collection, such as dict.values().
+ sorted_columns = sorted(categorical_columns, key=lambda x: x.name)
+
+ c0 = sorted_columns[0]
+ num_buckets = c0.num_buckets
+ if not isinstance(c0, CategoricalColumn):
+ raise ValueError(
+ 'All categorical_columns must be subclasses of CategoricalColumn. '
+ 'Given: {}, of type: {}'.format(c0, type(c0)))
+ if isinstance(c0, WeightedCategoricalColumn):
+ c0 = c0.categorical_column
+ for c in sorted_columns[1:]:
+ if isinstance(c, WeightedCategoricalColumn):
+ c = c.categorical_column
+ if not isinstance(c, type(c0)):
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same type, or be weighted_categorical_column of the same type. '
+ 'Given column: {} of type: {} does not match given column: {} of '
+ 'type: {}'.format(c0, type(c0), c, type(c)))
+ if num_buckets != c.num_buckets:
+ raise ValueError(
+ 'To use shared_embedding_column, all categorical_columns must have '
+ 'the same number of buckets. Given column: {} with buckets: {} does '
+ 'not match column: {} with buckets: {}'.format(
+ c0, num_buckets, c, c.num_buckets))
+
+ if not shared_embedding_collection_name:
+ shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
+ shared_embedding_collection_name += '_shared_embedding'
+
+ result = []
+ for column in categorical_columns:
+ result.append(
+ SharedEmbeddingColumn(
+ categorical_column=column,
+ initializer=initializer,
+ dimension=dimension,
+ combiner=combiner,
+ shared_embedding_collection_name=shared_embedding_collection_name,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable))
+
+ return result
+
+
+def numeric_column(key,
+ shape=(1,),
+ default_value=None,
+ dtype=dtypes.float32,
+ normalizer_fn=None):
+ """Represents real valued or numerical features.
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ columns = [price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+
+ # or
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ shape: An iterable of integers specifies the shape of the `Tensor`. An
+ integer can be given which means a single dimension `Tensor` with given
+ width. The `Tensor` representing the column will have the shape of
+ [batch_size] + `shape`.
+ default_value: A single value compatible with `dtype` or an iterable of
+ values compatible with `dtype` which the column takes on during
+ `tf.Example` parsing if data is missing. A default value of `None` will
+ cause `tf.parse_example` to fail if an example does not contain this
+ column. If a single value is provided, the same value will be applied as
+ the default value for every item. If an iterable of values is provided,
+ the shape of the `default_value` should be equal to the given `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ normalizer_fn: If not `None`, a function that can be used to normalize the
+ value of the tensor after `default_value` is applied for parsing.
+ Normalizer function takes the input `Tensor` as its argument, and returns
+ the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
+ even though the most common use case of this function is normalization, it
+ can be used for any kind of Tensorflow transformations.
+
+ Returns:
+ A `NumericColumn`.
+
+ Raises:
+ TypeError: if any dimension in shape is not an int
+ ValueError: if any dimension in shape is not a positive integer
+ TypeError: if `default_value` is an iterable but not compatible with `shape`
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ shape = _check_shape(shape, key)
+ if not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype must be convertible to float. '
+ 'dtype: {}, key: {}'.format(dtype, key))
+ default_value = _check_default_value(shape, default_value, dtype, key)
+
+ if normalizer_fn is not None and not callable(normalizer_fn):
+ raise TypeError(
+ 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+
+ _assert_key_is_string(key)
+ return NumericColumn(
+ key,
+ shape=shape,
+ default_value=default_value,
+ dtype=dtype,
+ normalizer_fn=normalizer_fn)
+
+
+def bucketized_column(source_column, boundaries):
+ """Represents discretized dense input.
+
+ Buckets include the left boundary, and exclude the right boundary. Namely,
+ `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
+ `[1., 2.)`, and `[2., +inf)`.
+
+ For example, if the inputs are
+
+ ```python
+ boundaries = [0, 10, 100]
+ input tensor = [[-5, 10000]
+ [150, 10]
+ [5, 100]]
+ ```
+
+ then the output will be
+
+ ```python
+ output = [[0, 3]
+ [3, 2]
+ [1, 3]]
+ ```
+
+ Example:
+
+ ```python
+ price = numeric_column('price')
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+
+ # or
+ columns = [bucketized_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ `bucketized_column` can also be crossed with another categorical column using
+ `crossed_column`:
+
+ ```python
+ price = numeric_column('price')
+ # bucketized_column converts numerical feature to a categorical one.
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ # 'keywords' is a string feature.
+ price_x_keywords = crossed_column([bucketized_price, 'keywords'], 50K)
+ columns = [price_x_keywords, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Args:
+ source_column: A one-dimensional dense column which is generated with
+ `numeric_column`.
+ boundaries: A sorted list or tuple of floats specifying the boundaries.
+
+ Returns:
+ A `BucketizedColumn`.
+
+ Raises:
+ ValueError: If `source_column` is not a numeric column, or if it is not
+ one-dimensional.
+ ValueError: If `boundaries` is not a sorted list or tuple.
+ """
+ if not isinstance(source_column, NumericColumn):
+ raise ValueError(
+ 'source_column must be a column generated with numeric_column(). '
+ 'Given: {}'.format(source_column))
+ if len(source_column.shape) > 1:
+ raise ValueError(
+ 'source_column must be one-dimensional column. '
+ 'Given: {}'.format(source_column))
+ if (not boundaries or
+ not (isinstance(boundaries, list) or isinstance(boundaries, tuple))):
+ raise ValueError('boundaries must be a sorted list.')
+ for i in range(len(boundaries) - 1):
+ if boundaries[i] >= boundaries[i + 1]:
+ raise ValueError('boundaries must be a sorted list.')
+ return BucketizedColumn(source_column, tuple(boundaries))
+
+
+def _assert_string_or_int(dtype, prefix):
+ if (dtype != dtypes.string) and (not dtype.is_integer):
+ raise ValueError(
+ '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+
+
+def _assert_key_is_string(key):
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ 'key must be a string. Got: type {}. Given key: {}.'.format(
+ type(key), key))
+
+
+def categorical_column_with_hash_bucket(key,
+ hash_bucket_size,
+ dtype=dtypes.string):
+ """Represents sparse feature where ids are set by hashing.
+
+ Use this when your sparse features are in string or integer format, and you
+ want to distribute your inputs into a finite number of buckets by hashing.
+ output_id = Hash(input_feature_string) % bucket_size for string type input.
+ For int type input, the value is converted to its string representation first
+ and then hashed by the same formula.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example:
+
+ ```python
+ keywords = categorical_column_with_hash_bucket("keywords", 10K)
+ columns = [keywords, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+
+ # or
+ keywords_embedded = embedding_column(keywords, 16)
+ columns = [keywords_embedded, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ hash_bucket_size: An int > 1. The number of buckets.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `HashedCategoricalColumn`.
+
+ Raises:
+ ValueError: `hash_bucket_size` is not greater than 1.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if hash_bucket_size is None:
+ raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key))
+
+ if hash_bucket_size < 1:
+ raise ValueError('hash_bucket_size must be at least 1. '
+ 'hash_bucket_size: {}, key: {}'.format(
+ hash_bucket_size, key))
+
+ _assert_key_is_string(key)
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+
+ return HashedCategoricalColumn(key, hash_bucket_size, dtype)
+
+
+def categorical_column_with_vocabulary_file(key,
+ vocabulary_file,
+ vocabulary_size=None,
+ num_oov_buckets=0,
+ default_value=None,
+ dtype=dtypes.string):
+ """A `CategoricalColumn` with a vocabulary file.
+
+ Use this when your inputs are in string or integer format, and you have a
+ vocabulary file that maps each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example with `num_oov_buckets`:
+ File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
+ abbreviation. All inputs with values in that file are assigned an ID 0-49,
+ corresponding to its line number. All other values are hashed and assigned an
+ ID 50-54.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
+ num_oov_buckets=5)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Example with `default_value`:
+ File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
+ other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
+ in input, and other values missing from the file, will be assigned ID 0. All
+ others are assigned the corresponding line number 1-50.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
+ default_value=0)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ And to make an embedding with either:
+
+ ```python
+ columns = [embedding_column(states, 3),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_file: The vocabulary file name.
+ vocabulary_size: Number of the elements in the vocabulary. This must be no
+ greater than length of `vocabulary_file`, if less than length, later
+ values are ignored. If None, it is set to the length of `vocabulary_file`.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+ the input value. A positive `num_oov_buckets` can not be specified with
+ `default_value`.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to `-1`. This can not be specified with a positive
+ `num_oov_buckets`.
+ dtype: The type of features. Only string and integer types are supported.
+
+ Returns:
+ A `CategoricalColumn` with a vocabulary file.
+
+ Raises:
+ ValueError: `vocabulary_file` is missing or cannot be opened.
+ ValueError: `vocabulary_size` is missing or < 1.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: `dtype` is neither string nor integer.
+ """
+ if not vocabulary_file:
+ raise ValueError('Missing vocabulary_file in {}.'.format(key))
+
+ if vocabulary_size is None:
+ if not gfile.Exists(vocabulary_file):
+ raise ValueError('vocabulary_file in {} does not exist.'.format(key))
+
+ with gfile.GFile(vocabulary_file) as f:
+ vocabulary_size = sum(1 for _ in f)
+ logging.info(
+ 'vocabulary_size = %d in %s is inferred from the number of elements '
+ 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
+
+ # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
+ if vocabulary_size < 1:
+ raise ValueError('Invalid vocabulary_size in {}.'.format(key))
+ if num_oov_buckets:
+ if default_value is not None:
+ raise ValueError(
+ 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+ key))
+ if num_oov_buckets < 0:
+ raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+ num_oov_buckets, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
+ return VocabularyFileCategoricalColumn(
+ key=key,
+ vocabulary_file=vocabulary_file,
+ vocabulary_size=vocabulary_size,
+ num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
+ default_value=-1 if default_value is None else default_value,
+ dtype=dtype)
+
+
+def categorical_column_with_vocabulary_list(
+ key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
+ """A `_CategoricalColumn` with in-memory vocabulary.
+
+ Use this when your inputs are in string or integer format, and you have an
+ in-memory vocabulary mapping each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example with `num_oov_buckets`:
+ In the following example, each input in `vocabulary_list` is assigned an ID
+ 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
+ inputs are hashed and assigned an ID 4-5.
+
+ ```python
+ colors = categorical_column_with_vocabulary_list(
+ key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
+ num_oov_buckets=2)
+ columns = [colors, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ Example with `default_value`:
+ In the following example, each input in `vocabulary_list` is assigned an ID
+ 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
+ inputs are assigned `default_value` 0.
+
+
+ ```python
+ colors = categorical_column_with_vocabulary_list(
+ key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
+ columns = [colors, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ And to make an embedding with either:
+
+ ```python
+ columns = [embedding_column(colors, 3),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_list: An ordered iterable defining the vocabulary. Each feature
+ is mapped to the index of its value (if present) in `vocabulary_list`.
+ Must be castable to `dtype`.
+ dtype: The type of features. Only string and integer types are supported.
+ If `None`, it will be inferred from `vocabulary_list`.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to `-1`. This can not be specified with a positive
+ `num_oov_buckets`.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
+ hash of the input value. A positive `num_oov_buckets` can not be specified
+ with `default_value`.
+
+ Returns:
+ A `CategoricalColumn` with in-memory vocabulary.
+
+ Raises:
+ ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: if `dtype` is not integer or string.
+ """
+ if (vocabulary_list is None) or (len(vocabulary_list) < 1):
+ raise ValueError(
+ 'vocabulary_list {} must be non-empty, column_name: {}'.format(
+ vocabulary_list, key))
+ if len(set(vocabulary_list)) != len(vocabulary_list):
+ raise ValueError(
+ 'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
+ vocabulary_list, key))
+ vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
+ if num_oov_buckets:
+ if default_value != -1:
+ raise ValueError(
+ 'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
+ key))
+ if num_oov_buckets < 0:
+ raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
+ num_oov_buckets, key))
+ _assert_string_or_int(
+ vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
+ if dtype is None:
+ dtype = vocabulary_dtype
+ elif dtype.is_integer != vocabulary_dtype.is_integer:
+ raise ValueError(
+ 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
+ dtype, vocabulary_dtype, key))
+ _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
+
+ return VocabularyListCategoricalColumn(
+ key=key,
+ vocabulary_list=tuple(vocabulary_list),
+ dtype=dtype,
+ default_value=default_value,
+ num_oov_buckets=num_oov_buckets)
+
+
+def categorical_column_with_identity(key, num_buckets, default_value=None):
+ """A `CategoricalColumn` that returns identity values.
+
+ Use this when your inputs are integers in the range `[0, num_buckets)`, and
+ you want to use the input value itself as the categorical ID. Values outside
+ this range will result in `default_value` if specified, otherwise it will
+ fail.
+
+ Typically, this is used for contiguous ranges of integer indexes, but
+ it doesn't have to be. This might be inefficient, however, if many of IDs
+ are unused. Consider `categorical_column_with_hash_bucket` in that case.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ In the following examples, each input in the range `[0, 1000000)` is assigned
+ the same value. All other inputs are assigned `default_value` 0. Note that a
+ literal 0 in inputs will result in the same default ID.
+
+ Linear model:
+
+ ```python
+ video_id = categorical_column_with_identity(
+ key='video_id', num_buckets=1000000, default_value=0)
+ columns = [video_id, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ Embedding for a DNN model:
+
+ ```python
+ columns = [embedding_column(video_id, 9),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
+ default_value: If `None`, this column's graph operations will fail for
+ out-of-range inputs. Otherwise, this value must be in the range
+ `[0, num_buckets)`, and will replace inputs in that range.
+
+ Returns:
+ A `CategoricalColumn` that returns identity values.
+
+ Raises:
+ ValueError: if `num_buckets` is less than one.
+ ValueError: if `default_value` is not in range `[0, num_buckets)`.
+ """
+ if num_buckets < 1:
+ raise ValueError(
+ 'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
+ if (default_value is not None) and (
+ (default_value < 0) or (default_value >= num_buckets)):
+ raise ValueError(
+ 'default_value {} not in range [0, {}), column_name {}'.format(
+ default_value, num_buckets, key))
+ _assert_key_is_string(key)
+ return IdentityCategoricalColumn(
+ key=key, number_buckets=num_buckets, default_value=default_value)
+
+
+def indicator_column(categorical_column):
+ """Represents multi-hot representation of given categorical column.
+
+ - For DNN model, `indicator_column` can be used to wrap any
+ `categorical_column_*` (e.g., to feed to DNN). Consider to Use
+ `embedding_column` if the number of buckets/unique(values) are large.
+
+ - For Wide (aka linear) model, `indicator_column` is the internal
+ representation for categorical column when passing categorical column
+ directly (as any element in feature_columns) to `linear_model`. See
+ `linear_model` for details.
+
+ ```python
+ name = indicator_column(categorical_column_with_vocabulary_list(
+ 'name', ['bob', 'george', 'wanda'])
+ columns = [name, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+
+ dense_tensor == [[1, 0, 0]] # If "name" bytes_list is ["bob"]
+ dense_tensor == [[1, 0, 1]] # If "name" bytes_list is ["bob", "wanda"]
+ dense_tensor == [[2, 0, 0]] # If "name" bytes_list is ["bob", "bob"]
+ ```
+
+ Args:
+ categorical_column: A `CategoricalColumn` which is created by
+ `categorical_column_with_*` or `crossed_column` functions.
+
+ Returns:
+ An `IndicatorColumn`.
+ """
+ return IndicatorColumn(categorical_column)
+
+
+def weighted_categorical_column(
+ categorical_column, weight_feature_key, dtype=dtypes.float32):
+ """Applies weight values to a `_CategoricalColumn`.
+
+ Use this when each of your sparse inputs has both an ID and a value. For
+ example, if you're representing text documents as a collection of word
+ frequencies, you can provide 2 parallel sparse input features ('terms' and
+ 'frequencies' below).
+
+ Example:
+
+ Input `tf.Example` objects:
+
+ ```proto
+ [
+ features {
+ feature {
+ key: "terms"
+ value {bytes_list {value: "very" value: "model"}}
+ }
+ feature {
+ key: "frequencies"
+ value {float_list {value: 0.3 value: 0.1}}
+ }
+ },
+ features {
+ feature {
+ key: "terms"
+ value {bytes_list {value: "when" value: "course" value: "human"}}
+ }
+ feature {
+ key: "frequencies"
+ value {float_list {value: 0.4 value: 0.1 value: 0.2}}
+ }
+ }
+ ]
+ ```
+
+ ```python
+ categorical_column = categorical_column_with_hash_bucket(
+ column_name='terms', hash_bucket_size=1000)
+ weighted_column = weighted_categorical_column(
+ categorical_column=categorical_column, weight_feature_key='frequencies')
+ columns = [weighted_column, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ This assumes the input dictionary contains a `SparseTensor` for key
+ 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
+ the same indices and dense shape.
+
+ Args:
+ categorical_column: A `_CategoricalColumn` created by
+ `categorical_column_with_*` functions.
+ weight_feature_key: String key for weight values.
+ dtype: Type of weights, such as `tf.float32`. Only float and integer weights
+ are supported.
+
+ Returns:
+ A `CategoricalColumn` composed of two sparse features: one represents id,
+ the other represents weight (value) of the id feature in that example.
+
+ Raises:
+ ValueError: if `dtype` is not convertible to float.
+ """
+ if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype {} is not convertible to float.'.format(dtype))
+ return WeightedCategoricalColumn(
+ categorical_column=categorical_column,
+ weight_feature_key=weight_feature_key,
+ dtype=dtype)
+
+
+def crossed_column(keys, hash_bucket_size, hash_key=None):
+ """Returns a column for performing crosses of categorical features.
+
+ Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
+ the transformation can be thought of as:
+ Hash(cartesian product of features) % `hash_bucket_size`
+
+ For example, if the input features are:
+
+ * SparseTensor referred by first key:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [1, 0]: "b"
+ [1, 1]: "c"
+ }
+ ```
+
+ * SparseTensor referred by second key:
+
+ ```python
+ shape = [2, 1]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ }
+ ```
+
+ then crossed feature will look like:
+
+ ```python
+ shape = [2, 2]
+ {
+ [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
+ [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
+ [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
+ }
+ ```
+
+ Here is an example to create a linear model with crosses of string features:
+
+ ```python
+ keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
+ columns = [keywords_x_doc_terms, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ You could also use vocabulary lookup before crossing:
+
+ ```python
+ keywords = categorical_column_with_vocabulary_file(
+ 'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
+ keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
+ columns = [keywords_x_doc_terms, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ If an input feature is of numeric type, you can use
+ `categorical_column_with_identity`, or `bucketized_column`, as in the example:
+
+ ```python
+ # vertical_id is an integer categorical feature.
+ vertical_id = categorical_column_with_identity('vertical_id', 10K)
+ price = numeric_column('price')
+ # bucketized_column converts numerical feature to a categorical one.
+ bucketized_price = bucketized_column(price, boundaries=[...])
+ vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
+ columns = [vertical_id_x_price, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ To use crossed column in DNN model, you need to add it in an embedding column
+ as in this example:
+
+ ```python
+ vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
+ vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
+ dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
+ ```
+
+ Args:
+ keys: An iterable identifying the features to be crossed. Each element can
+ be either:
+ * string: Will use the corresponding feature which must be of string type.
+ * `CategoricalColumn`: Will use the transformed tensor produced by this
+ column. Does not support hashed categorical column.
+ hash_bucket_size: An int > 1. The number of buckets.
+ hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
+ function to combine the crosses fingerprints on SparseCrossOp (optional).
+
+ Returns:
+ A `CrossedColumn`.
+
+ Raises:
+ ValueError: If `len(keys) < 2`.
+ ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
+ ValueError: If any of the keys is `HashedCategoricalColumn`.
+ ValueError: If `hash_bucket_size < 1`.
+ """
+ if not hash_bucket_size or hash_bucket_size < 1:
+ raise ValueError('hash_bucket_size must be > 1. '
+ 'hash_bucket_size: {}'.format(hash_bucket_size))
+ if not keys or len(keys) < 2:
+ raise ValueError(
+ 'keys must be a list with length > 1. Given: {}'.format(keys))
+ for key in keys:
+ if (not isinstance(key, six.string_types) and
+ not isinstance(key, CategoricalColumn)):
+ raise ValueError(
+ 'Unsupported key type. All keys must be either string, or '
+ 'categorical column except HashedCategoricalColumn. '
+ 'Given: {}'.format(key))
+ if isinstance(key, HashedCategoricalColumn):
+ raise ValueError(
+ 'categorical_column_with_hash_bucket is not supported for crossing. '
+ 'Hashing before crossing will increase probability of collision. '
+ 'Instead, use the feature name as a string. Given: {}'.format(key))
+ return CrossedColumn(
+ keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
+
+
+class StateManager(object):
+ """Manages the state associated with FeatureColumns.
+
+ Some `FeatureColumn`s create variables or resources to assist their
+ computation. The `StateManager` is responsible for creating and storing these
+ objects since `FeatureColumn`s are supposed to be stateless configuration
+ only.
+ """
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ """Creates a new variable or returns an existing one.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ shape: variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ initializer: initializer instance (callable).
+
+ Returns:
+ The variable.
+ """
+ raise NotImplementedError('StateManager.get_variable')
+
+ def get_resource(self, feature_column, name, resource_creator):
+ """Creates a new resource or returns an existing one.
+
+ Resources can be things such as tables etc.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: Name of the resource.
+ resource_creator: A callable that can create the resource.
+
+ Returns:
+ The resource.
+ """
+ raise NotImplementedError('StateManager.get_resource')
+
+
+class FeatureColumn(object):
+ """Represents a feature column abstraction.
+
+ WARNING: Do not subclass this layer unless you know what you are doing:
+ the API is subject to future changes.
+
+ To distinguish between the concept of a feature family and a specific binary
+ feature within a family, we refer to a feature family like "country" as a
+ feature column. For example, we can have a feature in a `tf.Example` format:
+ {key: "country", value: [ "US" ]}
+ In this example the value of feature is "US" and "country" refers to the
+ column of the feature.
+
+ This class is an abstract class. Users should not create instances of this.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def name(self):
+ """Returns string. Used for naming."""
+ pass
+
+ @abc.abstractmethod
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns intermediate representation (usually a `Tensor`).
+
+ Uses `transformation_cache` to create an intermediate representation
+ (usually a `Tensor`) that other feature columns can use.
+
+ Example usage of `transformation_cache`:
+ Let's say a Feature column depends on raw feature ('raw') and another
+ `FeatureColumn` (input_fc). To access corresponding `Tensor`s,
+ transformation_cache will be used as follows:
+
+ ```python
+ raw_tensor = transformation_cache.get('raw', state_manager)
+ fc_tensor = transformation_cache.get(input_fc, state_manager)
+ ```
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+ """
+ pass
+
+ @abc.abstractproperty
+ def parse_example_spec(self):
+ """Returns a `tf.Example` parsing spec as dict.
+
+ It is used for get_parsing_spec for `tf.parse_example`. Returned spec is a
+ dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other
+ supported objects. Please check documentation of @{tf.parse_example} for all
+ supported spec objects.
+
+ Let's say a Feature column depends on raw feature ('raw') and another
+ `FeatureColumn` (input_fc). One possible implementation of
+ parse_example_spec is as follows:
+
+ ```python
+ spec = {'raw': tf.FixedLenFeature(...)}
+ spec.update(input_fc.parse_example_spec)
+ return spec
+ ```
+ """
+ pass
+
+ def create_state(self, state_manager):
+ """Uses the `state_manager` to create state for the FeatureColumn.
+
+ Args:
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables and variables.
+ """
+ pass
+
+
+class DenseColumn(FeatureColumn):
+ """Represents a column which can be represented as `Tensor`.
+
+ Some examples of this type are: numeric_column, embedding_column,
+ indicator_column.
+ """
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractproperty
+ def variable_shape(self):
+ """`TensorShape` of `get_dense_tensor`, without batch dimension."""
+ pass
+
+ @abc.abstractmethod
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns a `Tensor`.
+
+ The output of this function will be used by model-builder-functions. For
+ example the pseudo code of `input_layer` will be like:
+
+ ```python
+ def input_layer(features, feature_columns, ...):
+ outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
+ return tf.concat(outputs)
+ ```
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ `Tensor` of shape [batch_size] + `variable_shape`.
+ """
+ pass
+
+
+def _create_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ """Creates a weighted sum for a dense/categorical column for linear_model."""
+ if isinstance(column, CategoricalColumn):
+ return _create_categorical_column_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ weight_var=weight_var)
+ else:
+ return _create_dense_column_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ units=units,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ weight_var=weight_var)
+
+
+def _create_dense_column_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ """Create a weighted sum of a dense column for linear_model."""
+ tensor = column.get_dense_tensor(transformation_cache, state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=[num_elements, units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return math_ops.matmul(tensor, weight, name='weighted_sum')
+
+
+class CategoricalColumn(FeatureColumn):
+ """Represents a categorical feature.
+
+ A categorical feature typically handled with a @{tf.SparseTensor} of IDs.
+ """
+ __metaclass__ = abc.ABCMeta
+
+ IdWeightPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'IdWeightPair', ('id_tensor', 'weight_tensor'))
+
+ @abc.abstractproperty
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ pass
+
+ @abc.abstractmethod
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Returns an IdWeightPair.
+
+ `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
+ weights.
+
+ `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
+ `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
+ `SparseTensor` of `float` or `None` to indicate all weights should be
+ taken to be 1. If specified, `weight_tensor` must have exactly the same
+ shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
+ output of a `VarLenFeature` which is a ragged matrix.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ pass
+
+
+def _create_categorical_column_weighted_sum(column,
+ transformation_cache,
+ state_manager,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ weight_var=None):
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """Create a weighted sum of a categorical column for linear_model.
+
+ Note to maintainer: As implementation details, the weighted sum is
+ implemented via embedding_lookup_sparse toward efficiency. Mathematically,
+ they are the same.
+
+ To be specific, conceptually, categorical column can be treated as multi-hot
+ vector. Say:
+
+ ```python
+ x = [0 0 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `c` in this case, which is same as `w[2]`.
+
+ Another example is
+
+ ```python
+ x = [0 1 1] # categorical column input
+ w = [a b c] # weights
+ ```
+ The weighted sum is `b + c` in this case, which is same as `w[2] + w[3]`.
+
+ For both cases, we can implement weighted sum via embedding_lookup with
+ sparse_combiner = "sum".
+ """
+
+ sparse_tensors = column.get_sparse_tensors(transformation_cache,
+ state_manager)
+ id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
+ array_ops.shape(sparse_tensors.id_tensor)[0], -1
+ ])
+ weight_tensor = sparse_tensors.weight_tensor
+ if weight_tensor is not None:
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
+
+ if weight_var is not None:
+ weight = weight_var
+ else:
+ weight = variable_scope.get_variable(
+ name='weights',
+ shape=(column.num_buckets, units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=trainable,
+ collections=weight_collections)
+ return _safe_embedding_lookup_sparse(
+ weight,
+ id_tensor,
+ sparse_weights=weight_tensor,
+ combiner=sparse_combiner,
+ name='weighted_sum')
+
+
+class SequenceDenseColumn(FeatureColumn):
+ """Represents dense sequence data."""
+
+ __metaclass__ = abc.ABCMeta
+
+ TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
+
+ @abc.abstractmethod
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """Returns a `TensorSequenceLengthPair`.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ pass
+
+
+class FeatureTransformationCache(object):
+ """Handles caching of transformations while building the model.
+
+ `FeatureColumn` specifies how to digest an input column to the network. Some
+ feature columns require data transformations. This class caches those
+ transformations.
+
+ Some features may be used in more than one place. For example, one can use a
+ bucketized feature by itself and a cross with it. In that case we
+ should create only one bucketization op instead of creating ops for each
+ feature column separately. To handle re-use of transformed columns,
+ `FeatureTransformationCache` caches all previously transformed columns.
+
+ Example:
+ We're trying to use the following `FeatureColumn`s:
+
+ ```python
+ bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
+ keywords = fc.categorical_column_with_hash_buckets("keywords", ...)
+ age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
+ ... = linear_model(features,
+ [bucketized_age, keywords, age_X_keywords]
+ ```
+
+ If we transform each column independently, then we'll get duplication of
+ bucketization (one for cross, one for bucketization itself).
+ The `FeatureTransformationCache` eliminates this duplication.
+ """
+
+ def __init__(self, features):
+ """Creates a `FeatureTransformationCache`.
+
+ Args:
+ features: A mapping from feature column to objects that are `Tensor` or
+ `SparseTensor`, or can be converted to same via
+ `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
+ signifies a base feature (not-transformed). A `FeatureColumn` key
+ means that this `Tensor` is the output of an existing `FeatureColumn`
+ which can be reused.
+ """
+ self._features = features.copy()
+ self._feature_tensors = {}
+
+ def get(self, key, state_manager):
+ """Returns a `Tensor` for the given key.
+
+ A `str` key is used to access a base feature (not-transformed). When a
+ `FeatureColumn` is passed, the transformed feature is returned if it
+ already exists, otherwise the given `FeatureColumn` is asked to provide its
+ transformed output, which is then cached.
+
+ Args:
+ key: a `str` or a `FeatureColumn`.
+ state_manager: A StateManager object that holds the FeatureColumn state.
+
+ Returns:
+ The transformed `Tensor` corresponding to the `key`.
+
+ Raises:
+ ValueError: if key is not found or a transformed `Tensor` cannot be
+ computed.
+ """
+ if key in self._feature_tensors:
+ # FeatureColumn is already transformed or converted.
+ return self._feature_tensors[key]
+
+ if key in self._features:
+ feature_tensor = self._get_raw_feature_as_tensor(key)
+ self._feature_tensors[key] = feature_tensor
+ return feature_tensor
+
+ if isinstance(key, six.string_types):
+ raise ValueError('Feature {} is not in features dictionary.'.format(key))
+
+ if not isinstance(key, FeatureColumn):
+ raise TypeError('"key" must be either a "str" or "FeatureColumn". '
+ 'Provided: {}'.format(key))
+
+ column = key
+ logging.debug('Transforming feature_column %s.', column)
+ transformed = column.transform_feature(self, state_manager)
+ if transformed is None:
+ raise ValueError('Column {} is not supported.'.format(column.name))
+ self._feature_tensors[column] = transformed
+ return transformed
+
+ def _get_raw_feature_as_tensor(self, key):
+ """Gets the raw_feature (keyed by `key`) as `tensor`.
+
+ The raw feature is converted to (sparse) tensor and maybe expand dim.
+
+ For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
+ the rank is 1. This supports dynamic rank also. For rank 0 raw feature, will
+ error out as it is not supported.
+
+ Args:
+ key: A `str` key to access the raw feature.
+
+ Returns:
+ A `Tensor` or `SparseTensor`.
+
+ Raises:
+ ValueError: if the raw feature has rank 0.
+ """
+ raw_feature = self._features[key]
+ feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ raw_feature)
+
+ def expand_dims(input_tensor):
+ # Input_tensor must have rank 1.
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ return sparse_ops.sparse_reshape(
+ input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ else:
+ return array_ops.expand_dims(input_tensor, -1)
+
+ rank = feature_tensor.get_shape().ndims
+ if rank is not None:
+ if rank == 0:
+ raise ValueError(
+ 'Feature (key: {}) cannot have rank 0. Give: {}'.format(
+ key, feature_tensor))
+ return feature_tensor if rank != 1 else expand_dims(feature_tensor)
+
+ # Handle dynamic rank.
+ with ops.control_dependencies([
+ check_ops.assert_positive(
+ array_ops.rank(feature_tensor),
+ message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
+ key, feature_tensor))]):
+ return control_flow_ops.cond(
+ math_ops.equal(1, array_ops.rank(feature_tensor)),
+ lambda: expand_dims(feature_tensor),
+ lambda: feature_tensor)
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _shape_offsets(shape):
+ """Returns moving offset for each dimension given shape."""
+ offsets = []
+ for dim in reversed(shape):
+ if offsets:
+ offsets.append(dim * offsets[-1])
+ else:
+ offsets.append(dim)
+ offsets.reverse()
+ return offsets
+
+
+# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
+def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
+ """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
+
+ If `input_tensor` is already a `SparseTensor`, just return it.
+
+ Args:
+ input_tensor: A string or integer `Tensor`.
+ ignore_value: Entries in `dense_tensor` equal to this value will be
+ absent from the resulting `SparseTensor`. If `None`, default value of
+ `dense_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
+
+ Returns:
+ A `SparseTensor` with the same shape as `input_tensor`.
+
+ Raises:
+ ValueError: when `input_tensor`'s rank is `None`.
+ """
+ input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ input_tensor)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ return input_tensor
+ with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
+ if ignore_value is None:
+ if input_tensor.dtype == dtypes.string:
+ # Exception due to TF strings are converted to numpy objects by default.
+ ignore_value = ''
+ elif input_tensor.dtype.is_integer:
+ ignore_value = -1 # -1 has a special meaning of missing feature
+ else:
+ # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
+ # constructing a new numpy object of the given type, which yields the
+ # default value for that type.
+ ignore_value = input_tensor.dtype.as_numpy_dtype()
+ ignore_value = math_ops.cast(
+ ignore_value, input_tensor.dtype, name='ignore_value')
+ indices = array_ops.where(
+ math_ops.not_equal(input_tensor, ignore_value), name='indices')
+ return sparse_tensor_lib.SparseTensor(
+ indices=indices,
+ values=array_ops.gather_nd(input_tensor, indices, name='values'),
+ dense_shape=array_ops.shape(
+ input_tensor, out_type=dtypes.int64, name='dense_shape'))
+
+
+def _normalize_feature_columns(feature_columns):
+ """Normalizes the `feature_columns` input.
+
+ This method converts the `feature_columns` to list type as best as it can. In
+ addition, verifies the type and other parts of feature_columns, required by
+ downstream library.
+
+ Args:
+ feature_columns: The raw feature columns, usually passed by users.
+
+ Returns:
+ The normalized feature column list.
+
+ Raises:
+ ValueError: for any invalid inputs, such as empty, duplicated names, etc.
+ """
+ if isinstance(feature_columns, FeatureColumn):
+ feature_columns = [feature_columns]
+
+ if isinstance(feature_columns, collections.Iterator):
+ feature_columns = list(feature_columns)
+
+ if isinstance(feature_columns, dict):
+ raise ValueError('Expected feature_columns to be iterable, found dict.')
+
+ for column in feature_columns:
+ if not isinstance(column, FeatureColumn):
+ raise ValueError('Items of feature_columns must be a FeatureColumn. '
+ 'Given (type {}): {}.'.format(type(column), column))
+ if not feature_columns:
+ raise ValueError('feature_columns must not be empty.')
+ name_to_column = dict()
+ for column in feature_columns:
+ if column.name in name_to_column:
+ raise ValueError('Duplicate feature column name found for columns: {} '
+ 'and {}. This usually means that these columns refer to '
+ 'same base feature. Either one must be discarded or a '
+ 'duplicated but renamed item must be inserted in '
+ 'features dict.'.format(column,
+ name_to_column[column.name]))
+ name_to_column[column.name] = column
+
+ return feature_columns
+
+
+class NumericColumn(
+ DenseColumn,
+ collections.namedtuple(
+ 'NumericColumn',
+ ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
+ """see `numeric_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {
+ self.key:
+ parsing_ops.FixedLenFeature(self.shape, self.dtype,
+ self.default_value)
+ }
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class.
+
+ In this case, we apply the `normalizer_fn` to the input tensor.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Normalized input tensor.
+ Raises:
+ ValueError: If a SparseTensor is passed in.
+ """
+ input_tensor = transformation_cache.get(self.key, state_manager)
+ if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'The corresponding Tensor of numerical column must be a Tensor. '
+ 'SparseTensor is not supported. key: {}'.format(self.key))
+ if self.normalizer_fn is not None:
+ input_tensor = self.normalizer_fn(input_tensor)
+ return math_ops.to_float(input_tensor)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape(self.shape)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing numeric feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Dense `Tensor` created within `transform_feature`.
+ """
+ # Feature has been already transformed. Return the intermediate
+ # representation created by _transform_feature.
+ return transformation_cache.get(self, state_manager)
+
+
+class BucketizedColumn(DenseColumn, CategoricalColumn,
+ collections.namedtuple('BucketizedColumn',
+ ('source_column', 'boundaries'))):
+ """See `bucketized_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_bucketized'.format(self.source_column.name)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.source_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns bucketized categorical `source_column` tensor."""
+ source_tensor = transformation_cache.get(self.source_column, state_manager)
+ return math_ops._bucketize( # pylint: disable=protected-access
+ source_tensor,
+ boundaries=self.boundaries)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.TensorShape(
+ tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns one hot encoded dense `Tensor`."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ return array_ops.one_hot(
+ indices=math_ops.to_int64(input_tensor),
+ depth=len(self.boundaries) + 1,
+ on_value=1.,
+ off_value=0.)
+
+ @property
+ def num_buckets(self):
+ """See `CategoricalColumn` base class."""
+ # By construction, source_column is always one-dimensional.
+ return (len(self.boundaries) + 1) * self.source_column.shape[0]
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """Converts dense inputs to SparseTensor so downstream code can use it."""
+ input_tensor = transformation_cache.get(self, state_manager)
+ batch_size = array_ops.shape(input_tensor)[0]
+ # By construction, source_column is always one-dimensional.
+ source_dimension = self.source_column.shape[0]
+
+ i1 = array_ops.reshape(
+ array_ops.tile(
+ array_ops.expand_dims(math_ops.range(0, batch_size), 1),
+ [1, source_dimension]),
+ (-1,))
+ i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
+ # Flatten the bucket indices and unique them across dimensions
+ # E.g. 2nd dimension indices will range from k to 2*k-1 with k buckets
+ bucket_indices = (
+ array_ops.reshape(input_tensor, (-1,)) +
+ (len(self.boundaries) + 1) * i2)
+
+ indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
+ dense_shape = math_ops.to_int64(array_ops.stack(
+ [batch_size, source_dimension]))
+ sparse_tensor = sparse_tensor_lib.SparseTensor(
+ indices=indices,
+ values=bucket_indices,
+ dense_shape=dense_shape)
+ return CategoricalColumn.IdWeightPair(sparse_tensor, None)
+
+
+class EmbeddingColumn(
+ DenseColumn, SequenceDenseColumn,
+ collections.namedtuple(
+ 'EmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
+ """See `embedding_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_embedding'.format(self.categorical_column.name)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Transforms underlying `categorical_column`."""
+ return transformation_cache.get(self.categorical_column, state_manager)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.vector(self.dimension)
+
+ def _get_dense_tensor_internal(self, transformation_cache, state_manager):
+ """Private method that follows the signature of _get_dense_tensor."""
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sparse_ids = sparse_tensors.id_tensor
+ sparse_weights = sparse_tensors.weight_tensor
+
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_weights = state_manager.get_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer)
+
+ if self.ckpt_to_load_from is not None:
+ to_restore = embedding_weights
+ if isinstance(to_restore, variables.PartitionedVariable):
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
+ checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+ self.tensor_name_in_ckpt: to_restore
+ })
+
+ # Return embedding lookup result.
+ return _safe_embedding_lookup_sparse(
+ embedding_weights=embedding_weights,
+ sparse_ids=sparse_ids,
+ sparse_weights=sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns tensor after doing the embedding lookup.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Embedding lookup tensor.
+
+ Raises:
+ ValueError: `categorical_column` is SequenceCategoricalColumn.
+ """
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(transformation_cache, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self._get_dense_tensor_internal( # pylint: disable=protected-access
+ transformation_cache, state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _get_graph_for_variable(var):
+ if isinstance(var, variables.PartitionedVariable):
+ return list(var)[0].graph
+ else:
+ return var.graph
+
+
+class SharedEmbeddingColumn(
+ DenseColumn, SequenceDenseColumn,
+ collections.namedtuple(
+ 'SharedEmbeddingColumn',
+ ('categorical_column', 'dimension', 'combiner', 'initializer',
+ 'shared_embedding_collection_name', 'ckpt_to_load_from',
+ 'tensor_name_in_ckpt', 'max_norm', 'trainable'))):
+ """See `embedding_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_shared_embedding'.format(self.categorical_column.name)
+
+ @property
+ def shared_collection_name(self):
+ """Returns the shared name of this column.
+
+ A group of columns share an embedding. Each one of those columns would have
+ the same `shared_collection_name` by which they could be collectively
+ referred to.
+ """
+ return self.shared_embedding_collection_name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class."""
+ return transformation_cache.get(self.categorical_column, state_manager)
+
+ @property
+ def variable_shape(self):
+ """See `DenseColumn` base class."""
+ return tensor_shape.vector(self.dimension)
+
+ def _get_dense_tensor_internal(self, transformation_cache, state_manager):
+ """Private method that follows the signature of _get_dense_tensor."""
+ # This method is called from a variable_scope with name _var_scope_name,
+ # which is shared among all shared embeddings. Open a name_scope here, so
+ # that the ops for different columns have distinct names.
+ with ops.name_scope(None, default_name=self.name):
+ # Get sparse IDs and weights.
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sparse_ids = sparse_tensors.id_tensor
+ sparse_weights = sparse_tensors.weight_tensor
+
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ embedding_weights = state_manager.get_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ initializer=self.initializer)
+
+ if self.ckpt_to_load_from is not None:
+ to_restore = embedding_weights
+ if isinstance(to_restore, variables.PartitionedVariable):
+ to_restore = to_restore._get_variable_list() # pylint: disable=protected-access
+ checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
+ self.tensor_name_in_ckpt: to_restore
+ })
+
+ # Return embedding lookup result.
+ return _safe_embedding_lookup_sparse(
+ embedding_weights=embedding_weights,
+ sparse_ids=sparse_ids,
+ sparse_weights=sparse_weights,
+ combiner=self.combiner,
+ name='%s_weights' % self.name,
+ max_norm=self.max_norm)
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns the embedding lookup result."""
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ return self._get_dense_tensor_internal(transformation_cache, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In embedding_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ dense_tensor = self.get_dense_tensor_internal(transformation_cache,
+ state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _create_tuple(shape, value):
+ """Returns a tuple with given shape and filled with value."""
+ if shape:
+ return tuple([_create_tuple(shape[1:], value) for _ in range(shape[0])])
+ return value
+
+
+def _as_tuple(value):
+ if not nest.is_sequence(value):
+ return value
+ return tuple([_as_tuple(v) for v in value])
+
+
+def _check_shape(shape, key):
+ """Returns shape if it's valid, raises error otherwise."""
+ assert shape is not None
+ if not nest.is_sequence(shape):
+ shape = [shape]
+ shape = tuple(shape)
+ for dimension in shape:
+ if not isinstance(dimension, int):
+ raise TypeError('shape dimensions must be integer. '
+ 'shape: {}, key: {}'.format(shape, key))
+ if dimension < 1:
+ raise ValueError('shape dimensions must be greater than 0. '
+ 'shape: {}, key: {}'.format(shape, key))
+ return shape
+
+
+def _is_shape_and_default_value_compatible(default_value, shape):
+ """Verifies compatibility of shape and default_value."""
+ # Invalid condition:
+ # * if default_value is not a scalar and shape is empty
+ # * or if default_value is an iterable and shape is not empty
+ if nest.is_sequence(default_value) != bool(shape):
+ return False
+ if not shape:
+ return True
+ if len(default_value) != shape[0]:
+ return False
+ for i in range(shape[0]):
+ if not _is_shape_and_default_value_compatible(default_value[i], shape[1:]):
+ return False
+ return True
+
+
+def _check_default_value(shape, default_value, dtype, key):
+ """Returns default value as tuple if it's valid, otherwise raises errors.
+
+ This function verifies that `default_value` is compatible with both `shape`
+ and `dtype`. If it is not compatible, it raises an error. If it is compatible,
+ it casts default_value to a tuple and returns it. `key` is used only
+ for error message.
+
+ Args:
+ shape: An iterable of integers specifies the shape of the `Tensor`.
+ default_value: If a single value is provided, the same value will be applied
+ as the default value for every item. If an iterable of values is
+ provided, the shape of the `default_value` should be equal to the given
+ `shape`.
+ dtype: defines the type of values. Default value is `tf.float32`. Must be a
+ non-quantized, real integer or floating point type.
+ key: Column name, used only for error messages.
+
+ Returns:
+ A tuple which will be used as default value.
+
+ Raises:
+ TypeError: if `default_value` is an iterable but not compatible with `shape`
+ TypeError: if `default_value` is not compatible with `dtype`.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
+ """
+ if default_value is None:
+ return None
+
+ if isinstance(default_value, int):
+ return _create_tuple(shape, default_value)
+
+ if isinstance(default_value, float) and dtype.is_floating:
+ return _create_tuple(shape, default_value)
+
+ if callable(getattr(default_value, 'tolist', None)): # Handles numpy arrays
+ default_value = default_value.tolist()
+
+ if nest.is_sequence(default_value):
+ if not _is_shape_and_default_value_compatible(default_value, shape):
+ raise ValueError(
+ 'The shape of default_value must be equal to given shape. '
+ 'default_value: {}, shape: {}, key: {}'.format(
+ default_value, shape, key))
+ # Check if the values in the list are all integers or are convertible to
+ # floats.
+ is_list_all_int = all(
+ isinstance(v, int) for v in nest.flatten(default_value))
+ is_list_has_float = any(
+ isinstance(v, float) for v in nest.flatten(default_value))
+ if is_list_all_int:
+ return _as_tuple(default_value)
+ if is_list_has_float and dtype.is_floating:
+ return _as_tuple(default_value)
+ raise TypeError('default_value must be compatible with dtype. '
+ 'default_value: {}, dtype: {}, key: {}'.format(
+ default_value, dtype, key))
+
+
+class HashedCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('HashedCategoricalColumn',
+ ('key', 'hash_bucket_size', 'dtype'))):
+ """see `categorical_column_with_hash_bucket`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Hashes the values in the feature_column."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+ if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
+ raise ValueError('SparseColumn input must be a SparseTensor.')
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ if self.dtype == dtypes.string:
+ sparse_values = input_tensor.values
+ else:
+ sparse_values = string_ops.as_string(input_tensor.values)
+
+ sparse_id_values = string_ops.string_to_hash_bucket_fast(
+ sparse_values, self.hash_bucket_size, name='lookup')
+ return sparse_tensor_lib.SparseTensor(
+ input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.hash_bucket_size
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class VocabularyFileCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('VocabularyFileCategoricalColumn',
+ ('key', 'vocabulary_file', 'vocabulary_size',
+ 'num_oov_buckets', 'dtype', 'default_value'))):
+ """See `categorical_column_with_vocabulary_file`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_file` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ # TODO(rohanj): Use state manager to manage the index table creation.
+ return lookup_ops.index_table_from_file(
+ vocabulary_file=self.vocabulary_file,
+ num_oov_buckets=self.num_oov_buckets,
+ vocab_size=self.vocabulary_size,
+ default_value=self.default_value,
+ key_dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.vocabulary_size + self.num_oov_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class VocabularyListCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple(
+ 'VocabularyListCategoricalColumn',
+ ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
+):
+ """See `categorical_column_with_vocabulary_list`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Creates a lookup table for the vocabulary list."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if self.dtype.is_integer != input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Column dtype and SparseTensors dtype must be compatible. '
+ 'key: {}, column dtype: {}, tensor dtype: {}'.format(
+ self.key, self.dtype, input_tensor.dtype))
+
+ _assert_string_or_int(
+ input_tensor.dtype,
+ prefix='column_name: {} input_tensor'.format(self.key))
+
+ key_dtype = self.dtype
+ if input_tensor.dtype.is_integer:
+ # `index_table_from_tensor` requires 64-bit integer keys.
+ key_dtype = dtypes.int64
+ input_tensor = math_ops.to_int64(input_tensor)
+
+ # TODO(rohanj): Use state manager to manage the index table creation.
+ return lookup_ops.index_table_from_tensor(
+ vocabulary_list=tuple(self.vocabulary_list),
+ default_value=self.default_value,
+ num_oov_buckets=self.num_oov_buckets,
+ dtype=key_dtype,
+ name='{}_lookup'.format(self.key)).lookup(input_tensor)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return len(self.vocabulary_list) + self.num_oov_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class IdentityCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple('IdentityCategoricalColumn',
+ ('key', 'number_buckets', 'default_value'))):
+
+ """See `categorical_column_with_identity`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.key
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns a SparseTensor with identity values."""
+ input_tensor = _to_sparse_input_and_drop_ignore_values(
+ transformation_cache.get(self.key, state_manager))
+
+ if not input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Invalid input, not integer. key: {} dtype: {}'.format(
+ self.key, input_tensor.dtype))
+
+ values = math_ops.to_int64(input_tensor.values, name='values')
+ num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
+ zero = math_ops.to_int64(0, name='zero')
+ if self.default_value is None:
+ # Fail if values are out-of-range.
+ assert_less = check_ops.assert_less(
+ values, num_buckets, data=(values, num_buckets),
+ name='assert_less_than_num_buckets')
+ assert_greater = check_ops.assert_greater_equal(
+ values, zero, data=(values,),
+ name='assert_greater_or_equal_0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ values = array_ops.identity(values)
+ else:
+ # Assign default for out-of-range values.
+ values = array_ops.where(
+ math_ops.logical_or(
+ values < zero, values >= num_buckets, name='out_of_range'),
+ array_ops.fill(
+ dims=array_ops.shape(values),
+ value=math_ops.to_int64(self.default_value),
+ name='default_values'),
+ values)
+
+ return sparse_tensor_lib.SparseTensor(
+ indices=input_tensor.indices,
+ values=values,
+ dense_shape=input_tensor.dense_shape)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.number_buckets
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+class WeightedCategoricalColumn(
+ CategoricalColumn,
+ collections.namedtuple(
+ 'WeightedCategoricalColumn',
+ ('categorical_column', 'weight_feature_key', 'dtype'))):
+ """See `weighted_categorical_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_weighted_by_{}'.format(
+ self.categorical_column.name, self.weight_feature_key)
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ config = self.categorical_column.parse_example_spec
+ if self.weight_feature_key in config:
+ raise ValueError('Parse config {} already exists for {}.'.format(
+ config[self.weight_feature_key], self.weight_feature_key))
+ config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
+ return config
+
+ @property
+ def num_buckets(self):
+ """See `DenseColumn` base class."""
+ return self.categorical_column.num_buckets
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Applies weights to tensor generated from `categorical_column`'."""
+ weight_tensor = transformation_cache.get(self.weight_feature_key,
+ state_manager)
+ if weight_tensor is None:
+ raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
+ weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
+ weight_tensor)
+ if self.dtype != weight_tensor.dtype.base_dtype:
+ raise ValueError('Bad dtype, expected {}, but got {}.'.format(
+ self.dtype, weight_tensor.dtype))
+ if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
+ # The weight tensor can be a regular Tensor. In this case, sparsify it.
+ weight_tensor = _to_sparse_input_and_drop_ignore_values(
+ weight_tensor, ignore_value=0.0)
+ if not weight_tensor.dtype.is_floating:
+ weight_tensor = math_ops.to_float(weight_tensor)
+ return (transformation_cache.get(self.categorical_column, state_manager),
+ weight_tensor)
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ tensors = transformation_cache.get(self, state_manager)
+ return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
+
+
+class CrossedColumn(
+ CategoricalColumn,
+ collections.namedtuple('CrossedColumn',
+ ('keys', 'hash_bucket_size', 'hash_key'))):
+ """See `crossed_column`."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ feature_names = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, FeatureColumn):
+ feature_names.append(key.name)
+ else: # key must be a string
+ feature_names.append(key)
+ return '_X_'.join(sorted(feature_names))
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ config = {}
+ for key in self.keys:
+ if isinstance(key, FeatureColumn):
+ config.update(key.parse_example_spec)
+ else: # key must be a string
+ config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
+ return config
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Generates a hashed sparse cross from the input tensors."""
+ feature_tensors = []
+ for key in _collect_leaf_level_keys(self):
+ if isinstance(key, six.string_types):
+ feature_tensors.append(transformation_cache.get(key, state_manager))
+ elif isinstance(key, CategoricalColumn):
+ ids_and_weights = key.get_sparse_tensors(transformation_cache,
+ state_manager)
+ if ids_and_weights.weight_tensor is not None:
+ raise ValueError(
+ 'crossed_column does not support weight_tensor, but the given '
+ 'column populates weight_tensor. '
+ 'Given column: {}'.format(key.name))
+ feature_tensors.append(ids_and_weights.id_tensor)
+ else:
+ raise ValueError('Unsupported column type. Given: {}'.format(key))
+ return sparse_ops.sparse_cross_hashed(
+ inputs=feature_tensors,
+ num_buckets=self.hash_bucket_size,
+ hash_key=self.hash_key)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.hash_bucket_size
+
+ def get_sparse_tensors(self, transformation_cache, state_manager):
+ """See `CategoricalColumn` base class."""
+ return CategoricalColumn.IdWeightPair(
+ transformation_cache.get(self, state_manager), None)
+
+
+def _collect_leaf_level_keys(cross):
+ """Collects base keys by expanding all nested crosses.
+
+ Args:
+ cross: A `CrossedColumn`.
+
+ Returns:
+ A list of strings or `CategoricalColumn` instances.
+ """
+ leaf_level_keys = []
+ for k in cross.keys:
+ if isinstance(k, CrossedColumn):
+ leaf_level_keys.extend(_collect_leaf_level_keys(k))
+ else:
+ leaf_level_keys.append(k)
+ return leaf_level_keys
+
+
+# TODO(zakaria): Move this to embedding_ops and make it public.
+def _safe_embedding_lookup_sparse(embedding_weights,
+ sparse_ids,
+ sparse_weights=None,
+ combiner='mean',
+ default_id=None,
+ name=None,
+ partition_strategy='div',
+ max_norm=None):
+ """Lookup embedding results, accounting for invalid IDs and empty features.
+
+ The partitioned embedding in `embedding_weights` must all be the same shape
+ except for the first dimension. The first dimension is allowed to vary as the
+ vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
+ may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
+ partitioner.
+
+ Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+ with non-positive weight. For an entry with no features, the embedding vector
+ for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+
+ The ids and weights may be multi-dimensional. Embeddings are always aggregated
+ along the last dimension.
+
+ Args:
+ embedding_weights: A list of `P` float `Tensor`s or values representing
+ partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
+ created by partitioning along dimension 0. The total unpartitioned
+ shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
+ vocab size and `e_1, ..., e_m` are the embedding dimensions.
+ sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+ ids. `d_0` is typically batch size.
+ sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+ float weights corresponding to `sparse_ids`, or `None` if all weights
+ are be assumed to be 1.0.
+ combiner: A string specifying how to combine embedding results for each
+ entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+ the default.
+ default_id: The id to use for an entry with no features.
+ name: A name for this operation (optional).
+ partition_strategy: A string specifying the partitioning strategy.
+ Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+ max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
+ combining.
+
+
+ Returns:
+ Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
+
+ Raises:
+ ValueError: if `embedding_weights` is empty.
+ """
+ if embedding_weights is None:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+ if isinstance(embedding_weights, variables.PartitionedVariable):
+ embedding_weights = list(embedding_weights) # get underlying Variables.
+ if not isinstance(embedding_weights, list):
+ embedding_weights = [embedding_weights]
+ if len(embedding_weights) < 1:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+
+ dtype = sparse_weights.dtype if sparse_weights is not None else None
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
+
+ with ops.name_scope(name, 'embedding_lookup',
+ embedding_weights + [sparse_ids,
+ sparse_weights]) as scope:
+ # Reshape higher-rank sparse ids and weights to linear segment ids.
+ original_shape = sparse_ids.dense_shape
+ original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
+ original_rank = (
+ array_ops.size(original_shape)
+ if original_rank_dim.value is None
+ else original_rank_dim.value)
+ sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
+ math_ops.reduce_prod(
+ array_ops.slice(original_shape, [0], [original_rank - 1])),
+ array_ops.gather(original_shape, original_rank - 1)])
+ if sparse_weights is not None:
+ sparse_weights = sparse_tensor_lib.SparseTensor(
+ sparse_ids.indices,
+ sparse_weights.values, sparse_ids.dense_shape)
+
+ # Prune invalid ids and weights.
+ sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+ if combiner != 'sum':
+ sparse_ids, sparse_weights = _prune_invalid_weights(
+ sparse_ids, sparse_weights)
+
+ # Fill in dummy values for empty features, if necessary.
+ sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
+ default_id or
+ 0)
+ if sparse_weights is not None:
+ sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
+
+ result = embedding_ops.embedding_lookup_sparse(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=combiner,
+ partition_strategy=partition_strategy,
+ name=None if default_id is None else scope,
+ max_norm=max_norm)
+
+ if default_id is None:
+ # Broadcast is_row_empty to the same shape as embedding_lookup_result,
+ # for use in Select.
+ is_row_empty = array_ops.tile(
+ array_ops.reshape(is_row_empty, [-1, 1]),
+ array_ops.stack([1, array_ops.shape(result)[1]]))
+
+ result = array_ops.where(is_row_empty,
+ array_ops.zeros_like(result),
+ result,
+ name=scope)
+
+ # Reshape back from linear ids back into higher-dimensional dense result.
+ final_result = array_ops.reshape(
+ result,
+ array_ops.concat([
+ array_ops.slice(
+ math_ops.cast(original_shape, dtypes.int32), [0],
+ [original_rank - 1]),
+ array_ops.slice(array_ops.shape(result), [1], [-1])
+ ], 0))
+ final_result.set_shape(tensor_shape.unknown_shape(
+ (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
+ return final_result
+
+
+def _prune_invalid_ids(sparse_ids, sparse_weights):
+ """Prune invalid IDs (< 0) from the input ids and weights."""
+ is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
+ if sparse_weights is not None:
+ is_id_valid = math_ops.logical_and(
+ is_id_valid,
+ array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
+ if sparse_weights is not None:
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
+ return sparse_ids, sparse_weights
+
+
+def _prune_invalid_weights(sparse_ids, sparse_weights):
+ """Prune invalid weights (< 0) from the input ids and weights."""
+ if sparse_weights is not None:
+ is_weights_valid = math_ops.greater(sparse_weights.values, 0)
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
+ return sparse_ids, sparse_weights
+
+
+class IndicatorColumn(DenseColumn, SequenceDenseColumn,
+ collections.namedtuple('IndicatorColumn',
+ ('categorical_column'))):
+ """Represents a one-hot column for use in deep networks.
+
+ Args:
+ categorical_column: A `CategoricalColumn` which is created by
+ `categorical_column_with_*` function.
+ """
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return '{}_indicator'.format(self.categorical_column.name)
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Transformed feature `Tensor`.
+
+ Raises:
+ ValueError: if input rank is not known at graph building time.
+ """
+ id_weight_pair = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ id_tensor = id_weight_pair.id_tensor
+ weight_tensor = id_weight_pair.weight_tensor
+
+ # If the underlying column is weighted, return the input as a dense tensor.
+ if weight_tensor is not None:
+ weighted_column = sparse_ops.sparse_merge(
+ sp_ids=id_tensor,
+ sp_values=weight_tensor,
+ vocab_size=int(self.variable_shape[-1]))
+ # Remove (?, -1) index
+ weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
+ weighted_column.dense_shape)
+ return sparse_ops.sparse_tensor_to_dense(weighted_column)
+
+ dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
+ id_tensor, default_value=-1)
+
+ # One hot must be float for tf.concat reasons since all other inputs to
+ # input_layer are float32.
+ one_hot_id_tensor = array_ops.one_hot(
+ dense_id_tensor,
+ depth=self.variable_shape[-1],
+ on_value=1.0,
+ off_value=0.0)
+
+ # Reduce to get a multi-hot per example.
+ return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ @property
+ def variable_shape(self):
+ """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
+ return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
+
+ def get_dense_tensor(self, transformation_cache, state_manager):
+ """Returns dense `Tensor` representing feature.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+
+ Returns:
+ Dense `Tensor` created within `transform_feature`.
+
+ Raises:
+ ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
+ """
+ if isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must not be of type SequenceCategoricalColumn. '
+ 'Suggested fix A: If you wish to use input_layer, use a '
+ 'non-sequence categorical_column_with_*. '
+ 'Suggested fix B: If you wish to create sequence input, use '
+ 'sequence_input_layer instead of input_layer. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ return transformation_cache.get(self, state_manager)
+
+ def get_sequence_dense_tensor(self, transformation_cache, state_manager):
+ """See `SequenceDenseColumn` base class."""
+ if not isinstance(self.categorical_column, SequenceCategoricalColumn):
+ raise ValueError(
+ 'In indicator_column: {}. '
+ 'categorical_column must be of type SequenceCategoricalColumn '
+ 'to use sequence_input_layer. '
+ 'Suggested fix: Use one of sequence_categorical_column_with_*. '
+ 'Given (type {}): {}'.format(self.name, type(self.categorical_column),
+ self.categorical_column))
+ # Feature has been already transformed. Return the intermediate
+ # representation created by transform_feature.
+ dense_tensor = transformation_cache.get(self, state_manager)
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sparse_tensors.id_tensor)
+ return SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+def _verify_static_batch_size_equality(tensors, columns):
+ # bath_size is a tf.Dimension object.
+ expected_batch_size = None
+ for i in range(0, len(tensors)):
+ if tensors[i].shape[0].value is not None:
+ if expected_batch_size is None:
+ bath_size_column_index = i
+ expected_batch_size = tensors[i].shape[0]
+ elif not expected_batch_size.is_compatible_with(tensors[i].shape[0]):
+ raise ValueError(
+ 'Batch size (first dimension) of each feature must be same. '
+ 'Batch size of columns ({}, {}): ({}, {})'.format(
+ columns[bath_size_column_index].name, columns[i].name,
+ expected_batch_size, tensors[i].shape[0]))
+
+
+def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+ """Returns a [batch_size] Tensor with per-example sequence length."""
+ with ops.name_scope(None, 'sequence_length') as name_scope:
+ row_ids = sp_tensor.indices[:, 0]
+ column_ids = sp_tensor.indices[:, 1]
+ column_ids += array_ops.ones_like(column_ids)
+ seq_length = math_ops.to_int64(
+ math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # If the last n rows do not have ids, seq_length will have shape
+ # [batch_size - n]. Pad the remaining values with zeros.
+ n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+ padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+ return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
+
+class SequenceCategoricalColumn(FeatureColumn,
+ collections.namedtuple(
+ 'SequenceCategoricalColumn',
+ ('categorical_column'))):
+ """Represents sequences of categorical data."""
+
+ @property
+ def name(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.name
+
+ @property
+ def parse_example_spec(self):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.parse_example_spec
+
+ def transform_feature(self, transformation_cache, state_manager):
+ """See `FeatureColumn` base class."""
+ return self.categorical_column.transform_feature(transformation_cache,
+ state_manager)
+
+ @property
+ def num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.categorical_column.num_buckets
+
+ def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
+ """Returns an IdWeightPair.
+
+ `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
+ weights.
+
+ `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
+ `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
+ `SparseTensor` of `float` or `None` to indicate all weights should be
+ taken to be 1. If specified, `weight_tensor` must have exactly the same
+ shape and indices as `sp_ids`. Expected `SparseTensor` is same as parsing
+ output of a `VarLenFeature` which is a ragged matrix.
+
+ Args:
+ transformation_cache: A `FeatureTransformationCache` object to access
+ features.
+ state_manager: A `StateManager` to create / access resources such as
+ lookup tables.
+ """
+ sparse_tensors = self.categorical_column.get_sparse_tensors(
+ transformation_cache, state_manager)
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands final dimension, so that embeddings are not combined during
+ # embedding lookup.
+ check_id_rank = check_ops.assert_equal(
+ array_ops.rank(id_tensor), 2,
+ data=[
+ 'Column {} expected ID tensor of rank 2. '.format(self.name),
+ 'id_tensor shape: ', array_ops.shape(id_tensor)])
+ with ops.control_dependencies([check_id_rank]):
+ id_tensor = sparse_ops.sparse_reshape(
+ id_tensor,
+ shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+ if weight_tensor is not None:
+ check_weight_rank = check_ops.assert_equal(
+ array_ops.rank(weight_tensor), 2,
+ data=[
+ 'Column {} expected weight tensor of rank 2.'.format(self.name),
+ 'weight_tensor shape:', array_ops.shape(weight_tensor)])
+ with ops.control_dependencies([check_weight_rank]):
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor,
+ shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
new file mode 100644
index 0000000000..80a9d5d40e
--- /dev/null
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -0,0 +1,6583 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for feature_column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import copy
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column as fc_old
+from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
+from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
+from tensorflow.python.feature_column.feature_column_v2 import InputLayer
+from tensorflow.python.feature_column.feature_column_v2 import StateManager
+from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
+from tensorflow.python.feature_column.feature_column_v2 import _transform_features
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
+from tensorflow.python.platform import test
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import queue_runner_impl
+
+
+def _initialized_session(config=None):
+ sess = session.Session(config=config)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ return sess
+
+
+class LazyColumnTest(test.TestCase):
+
+ def test_transformations_called_once(self):
+
+ class TransformCounter(FeatureColumn):
+
+ def __init__(self):
+ self.num_transform = 0
+
+ @property
+ def name(self):
+ return 'TransformCounter'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ self.num_transform += 1 # Count transform calls.
+ return transformation_cache.get('a', state_manager)
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ column = TransformCounter()
+ self.assertEqual(0, column.num_transform)
+ transformation_cache.get(column, None)
+ self.assertEqual(1, column.num_transform)
+ transformation_cache.get(column, None)
+ self.assertEqual(1, column.num_transform)
+
+ def test_returns_transform_output(self):
+
+ class Transformer(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ return 'Output'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ column = Transformer()
+ self.assertEqual('Output', transformation_cache.get(column, None))
+ self.assertEqual('Output', transformation_cache.get(column, None))
+
+ def test_does_not_pollute_given_features_dict(self):
+
+ class Transformer(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'Transformer'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ return 'Output'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ features = {'a': [[2], [3.]]}
+ transformation_cache = FeatureTransformationCache(features=features)
+ transformation_cache.get(Transformer(), None)
+ self.assertEqual(['a'], list(features.keys()))
+
+ def test_error_if_feature_is_not_found(self):
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(ValueError,
+ 'bbb is not in features dictionary'):
+ transformation_cache.get('bbb', None)
+ with self.assertRaisesRegexp(ValueError,
+ 'bbb is not in features dictionary'):
+ transformation_cache.get(u'bbb', None)
+
+ def test_not_supported_feature_column(self):
+
+ class NotAProperColumn(FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotAProperColumn'
+
+ def transform_feature(self, transformation_cache, state_manager):
+ # It should return not None.
+ pass
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(ValueError,
+ 'NotAProperColumn is not supported'):
+ transformation_cache.get(NotAProperColumn(), None)
+
+ def test_key_should_be_string_or_feature_colum(self):
+
+ class NotAFeatureColumn(object):
+ pass
+
+ transformation_cache = FeatureTransformationCache(
+ features={'a': [[2], [3.]]})
+ with self.assertRaisesRegexp(
+ TypeError, '"key" must be either a "str" or "FeatureColumn".'):
+ transformation_cache.get(NotAFeatureColumn(), None)
+
+
+class NumericColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.numeric_column('aaa')
+ self.assertEqual('aaa', a.key)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual((1,), a.shape)
+ self.assertIsNone(a.default_value)
+ self.assertEqual(dtypes.float32, a.dtype)
+ self.assertIsNone(a.normalizer_fn)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.numeric_column(key=('aaa',))
+
+ def test_shape_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual((1, 2), a.shape)
+
+ def test_default_value_saved_as_tuple(self):
+ a = fc.numeric_column('aaa', default_value=4.)
+ self.assertEqual((4.,), a.default_value)
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
+ self.assertEqual(((3., 2.),), a.default_value)
+
+ def test_shape_and_default_value_compatibility(self):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.])
+ fc.numeric_column(
+ 'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]])
+ with self.assertRaisesRegexp(ValueError, 'The shape of default_value'):
+ fc.numeric_column(
+ 'aaa', shape=[3, 3], default_value=[[2, 3], [1, 2], [2, 3.]])
+
+ def test_default_value_type_check(self):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.float32)
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError, 'must be compatible with dtype'):
+ fc.numeric_column(
+ 'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.int32)
+ with self.assertRaisesRegexp(TypeError,
+ 'default_value must be compatible with dtype'):
+ fc.numeric_column('aaa', default_value=['string'])
+
+ def test_shape_must_be_positive_integer(self):
+ with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 1.0,
+ ])
+
+ with self.assertRaisesRegexp(ValueError,
+ 'shape dimensions must be greater than 0'):
+ fc.numeric_column(
+ 'aaa', shape=[
+ 0,
+ ])
+
+ def test_dtype_is_convertible_to_float(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'dtype must be convertible to float'):
+ fc.numeric_column('aaa', dtype=dtypes.string)
+
+ def test_scalar_default_value_fills_the_shape(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], default_value=2.)
+ self.assertEqual(((2., 2., 2.), (2., 2., 2.)), a.default_value)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
+ }, a.parse_example_spec)
+
+ def test_parse_example_no_default_value(self):
+ price = fc.numeric_column('price', shape=[2])
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+ def test_parse_example_with_default_value(self):
+ price = fc.numeric_column('price', shape=[2], default_value=11.)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ no_data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'something_else':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString(),
+ no_data.SerializeToString()],
+ features=fc.make_parse_example_spec([price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
+
+ def test_normalizer_fn_must_be_callable(self):
+ with self.assertRaisesRegexp(TypeError, 'must be a callable'):
+ fc.numeric_column('price', normalizer_fn='NotACallable')
+
+ def test_normalizer_fn_transform_feature(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None)
+ with self.test_session():
+ self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
+
+ def test_get_dense_tensor(self):
+
+ def _increment_two(input_tensor):
+ return input_tensor + 2.
+
+ price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[1., 2.], [5., 6.]]
+ })
+ self.assertEqual(
+ transformation_cache.get(price, None),
+ price.get_dense_tensor(transformation_cache, None))
+
+ def test_sparse_tensor_not_supported(self):
+ price = fc.numeric_column('price')
+ transformation_cache = FeatureTransformationCache({
+ 'price':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+ })
+ with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+ price.transform_feature(transformation_cache, None)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3., 2.]])
+ a_copy = copy.deepcopy(a)
+ self.assertEqual(a_copy.name, 'aaa')
+ self.assertEqual(a_copy.shape, (1, 2))
+ self.assertEqual(a_copy.default_value, ((3., 2.),))
+
+ def test_numpy_default_value(self):
+ a = fc.numeric_column(
+ 'aaa', shape=[1, 2], default_value=np.array([[3., 2.]]))
+ self.assertEqual(a.default_value, ((3., 2.),))
+
+ def test_linear_model(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[10.], [50.]], predictions.eval())
+
+
+class BucketizedColumnTest(test.TestCase):
+
+ def test_invalid_source_column_type(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'source_column must be a column generated with numeric_column'):
+ fc.bucketized_column(a, boundaries=[0, 1])
+
+ def test_invalid_source_column_shape(self):
+ a = fc.numeric_column('aaa', shape=[2, 3])
+ with self.assertRaisesRegexp(
+ ValueError, 'source_column must be one-dimensional column'):
+ fc.bucketized_column(a, boundaries=[0, 1])
+
+ def test_invalid_boundaries(self):
+ a = fc.numeric_column('aaa')
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=None)
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=1.)
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=[1, 0])
+ with self.assertRaisesRegexp(
+ ValueError, 'boundaries must be a sorted list'):
+ fc.bucketized_column(a, boundaries=[1, 1])
+
+ def test_name(self):
+ a = fc.numeric_column('aaa', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertEqual('aaa_bucketized', b.name)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ self.assertEqual({
+ 'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
+ }, b.parse_example_spec)
+
+ def test_variable_shape(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ # Column 'aaa` has shape [2] times three buckets -> variable_shape=[2, 3].
+ self.assertAllEqual((2, 3), b.variable_shape)
+
+ def test_num_buckets(self):
+ a = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ # Column 'aaa` has shape [2] times three buckets -> num_buckets=6.
+ self.assertEqual(6, b.num_buckets)
+
+ def test_parse_example(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([bucketized_price]))
+ self.assertIn('price', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+
+ def test_transform_feature(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformed_tensor = _transform_features({
+ 'price': [[-1., 1.], [5., 6.]]
+ }, [bucketized_price], None)
+ with _initialized_session():
+ self.assertAllEqual([[0, 1], [3, 4]],
+ transformed_tensor[bucketized_price].eval())
+
+ def test_get_dense_tensor_one_input_value(self):
+ """Tests _get_dense_tensor() for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1.], [1.], [5.], [6.]]
+ })
+ with _initialized_session():
+ bucketized_price_tensor = bucketized_price.get_dense_tensor(
+ transformation_cache, None)
+ self.assertAllClose(
+ # One-hot tensor.
+ [[[1., 0., 0., 0., 0.]],
+ [[0., 1., 0., 0., 0.]],
+ [[0., 0., 0., 1., 0.]],
+ [[0., 0., 0., 0., 1.]]],
+ bucketized_price_tensor.eval())
+
+ def test_get_dense_tensor_two_input_values(self):
+ """Tests _get_dense_tensor() for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1., 1.], [5., 6.]]
+ })
+ with _initialized_session():
+ bucketized_price_tensor = bucketized_price.get_dense_tensor(
+ transformation_cache, None)
+ self.assertAllClose(
+ # One-hot tensor.
+ [[[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
+ [[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]]],
+ bucketized_price_tensor.eval())
+
+ def test_get_sparse_tensors_one_input_value(self):
+ """Tests _get_sparse_tensors() for input with shape=[1]."""
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1.], [1.], [5.], [6.]]
+ })
+ with _initialized_session() as sess:
+ id_weight_pair = bucketized_price.get_sparse_tensors(
+ transformation_cache, None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ id_tensor_value = sess.run(id_weight_pair.id_tensor)
+ self.assertAllEqual(
+ [[0, 0], [1, 0], [2, 0], [3, 0]], id_tensor_value.indices)
+ self.assertAllEqual([0, 1, 3, 4], id_tensor_value.values)
+ self.assertAllEqual([4, 1], id_tensor_value.dense_shape)
+
+ def test_get_sparse_tensors_two_input_values(self):
+ """Tests _get_sparse_tensors() for input with shape=[2]."""
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'price': [[-1., 1.], [5., 6.]]
+ })
+ with _initialized_session() as sess:
+ id_weight_pair = bucketized_price.get_sparse_tensors(
+ transformation_cache, None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ id_tensor_value = sess.run(id_weight_pair.id_tensor)
+ self.assertAllEqual(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], id_tensor_value.indices)
+ # Values 0-4 correspond to the first column of the input price.
+ # Values 5-9 correspond to the second column of the input price.
+ self.assertAllEqual([0, 6, 3, 9], id_tensor_value.values)
+ self.assertAllEqual([2, 2], id_tensor_value.dense_shape)
+
+ def test_sparse_tensor_input_not_supported(self):
+ price = fc.numeric_column('price')
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
+ transformation_cache = FeatureTransformationCache({
+ 'price':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
+ })
+ with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
+ bucketized_price.transform_feature(transformation_cache, None)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('aaa', shape=[2])
+ a_bucketized = fc.bucketized_column(a, boundaries=[0, 1])
+ a_bucketized_copy = copy.deepcopy(a_bucketized)
+ self.assertEqual(a_bucketized_copy.name, 'aaa_bucketized')
+ self.assertAllEqual(a_bucketized_copy.variable_shape, (2, 3))
+ self.assertEqual(a_bucketized_copy.boundaries, (0, 1))
+
+ def test_linear_model_one_input_value(self):
+ """Tests linear_model() for input with shape=[1]."""
+ price = fc_old.numeric_column('price', shape=[1])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = fc.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.]], bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(bucketized_price_var.assign(
+ [[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_linear_model_two_input_values(self):
+ """Tests linear_model() for input with shape=[2]."""
+ price = fc_old.numeric_column('price', shape=[2])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = fc.linear_model(features, [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(bucketized_price_var.assign(
+ [[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
+ def test_keras_linear_model_one_input_value(self):
+ """Tests _LinearModel for input with shape=[1]."""
+ price = fc_old.numeric_column('price', shape=[1])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1.], [1.], [5.], [6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight variable per bucket, all initialized to zero.
+ self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 1st bucket, whose weight is 20.
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 4th bucket, whose weight is 50.
+ self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
+
+ def test_keras_linear_model_two_input_values(self):
+ """Tests _LinearModel for input with shape=[2]."""
+ price = fc_old.numeric_column('price', shape=[2])
+ bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ with ops.Graph().as_default():
+ features = {'price': [[-1., 1.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [bucketized_price])
+ bias = get_linear_model_bias()
+ bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ # One weight per bucket per input column, all initialized to zero.
+ self.assertAllClose(
+ [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
+ bucketized_price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(
+ bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
+ [60.], [70.], [80.], [90.], [100.]]))
+ # 1st example:
+ # price -1. is in the 0th bucket, whose weight is 10.
+ # price 1. is in the 6th bucket, whose weight is 70.
+ # 2nd example:
+ # price 5. is in the 3rd bucket, whose weight is 40.
+ # price 6. is in the 9th bucket, whose weight is 100.
+ self.assertAllClose([[80.], [140.]], predictions.eval())
+ sess.run(bias.assign([1.]))
+ self.assertAllClose([[81.], [141.]], predictions.eval())
+
+
+class HashedCategoricalColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual('aaa', a.key)
+ self.assertEqual(10, a.hash_bucket_size)
+ self.assertEqual(dtypes.string, a.dtype)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_hash_bucket(('key',), 10)
+
+ def test_bucket_size_should_be_given(self):
+ with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
+ fc.categorical_column_with_hash_bucket('aaa', None)
+
+ def test_bucket_size_should_be_positive(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'hash_bucket_size must be at least 1'):
+ fc.categorical_column_with_hash_bucket('aaa', 0)
+
+ def test_dtype_should_be_string_or_integer(self):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.string)
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_hash_bucket('aaa', 10)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(10, column.hash_bucket_size)
+ self.assertEqual(10, column.num_buckets)
+ self.assertEqual(dtypes.string, column.dtype)
+
+ def test_parse_spec_string(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, a.parse_example_spec)
+
+ def test_parse_spec_int(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.int32)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, a.parse_example_spec)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_hash_bucket('aaa', 10)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_strings_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ outputs = _transform_features({'wire': wire_tensor}, [hashed_sparse], None)
+ output = outputs[hashed_sparse]
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [6, 4, 1]
+ with self.test_session():
+ self.assertEqual(dtypes.int64, output.values.dtype)
+ self.assertAllEqual(expected_values, output.values.eval())
+ self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
+ self.assertAllEqual(wire_tensor.dense_shape.eval(),
+ output.dense_shape.eval())
+
+ def test_tensor_dtype_should_be_string_or_integer(self):
+ string_fc = fc.categorical_column_with_hash_bucket(
+ 'a_string', 10, dtype=dtypes.string)
+ int_fc = fc.categorical_column_with_hash_bucket(
+ 'a_int', 10, dtype=dtypes.int32)
+ float_fc = fc.categorical_column_with_hash_bucket(
+ 'a_float', 10, dtype=dtypes.string)
+ int_tensor = sparse_tensor.SparseTensor(
+ values=[101],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ string_tensor = sparse_tensor.SparseTensor(
+ values=['101'],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ float_tensor = sparse_tensor.SparseTensor(
+ values=[101.],
+ indices=[[0, 0]],
+ dense_shape=[1, 1])
+ transformation_cache = FeatureTransformationCache({
+ 'a_int': int_tensor,
+ 'a_string': string_tensor,
+ 'a_float': float_tensor
+ })
+ transformation_cache.get(string_fc, None)
+ transformation_cache.get(int_fc, None)
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ transformation_cache.get(float_fc, None)
+
+ def test_dtype_should_match_with_tensor(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ transformation_cache.get(hashed_sparse, None)
+
+ def test_ints_should_be_hashed(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=[101, 201, 301],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ output = transformation_cache.get(hashed_sparse, None)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_int32_64_is_compatible(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket(
+ 'wire', 10, dtype=dtypes.int64)
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=constant_op.constant([101, 201, 301], dtype=dtypes.int32),
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ transformation_cache = FeatureTransformationCache({'wire': wire_tensor})
+ output = transformation_cache.get(hashed_sparse, None)
+ # Check exact hashed output. If hashing changes this test will break.
+ expected_values = [3, 7, 5]
+ with self.test_session():
+ self.assertAllEqual(expected_values, output.values.eval())
+
+ def test_get_sparse_tensors(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ transformation_cache = FeatureTransformationCache({
+ 'wire':
+ sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ })
+ id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache,
+ None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(
+ transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_hash_bucket('aaa', 10)
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column._get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ transformation_cache = FeatureTransformationCache({
+ 'wire': (('omar', ''), ('stringer', 'marlo'))
+ })
+ id_weight_pair = hashed_sparse.get_sparse_tensors(transformation_cache,
+ None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ self.assertEqual(
+ transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 3: wire_var[3] = 4
+ # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
+ self.assertAllClose(((4.,), (6.,)), predictions.eval())
+
+
+class CrossedColumnTest(test.TestCase):
+
+ def test_keys_empty(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'keys must be a list with length > 1'):
+ fc.crossed_column([], 10)
+
+ def test_keys_length_one(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'keys must be a list with length > 1'):
+ fc.crossed_column(['a'], 10)
+
+ def test_key_type_unsupported(self):
+ with self.assertRaisesRegexp(ValueError, 'Unsupported key type'):
+ fc.crossed_column(['a', fc.numeric_column('c')], 10)
+
+ with self.assertRaisesRegexp(
+ ValueError, 'categorical_column_with_hash_bucket is not supported'):
+ fc.crossed_column(
+ ['a', fc.categorical_column_with_hash_bucket('c', 10)], 10)
+
+ def test_hash_bucket_size_negative(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], -1)
+
+ def test_hash_bucket_size_zero(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], 0)
+
+ def test_hash_bucket_size_none(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'hash_bucket_size must be > 1'):
+ fc.crossed_column(['a', 'c'], None)
+
+ def test_name(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_name_ordered_alphabetically(self):
+ """Tests that the name does not depend on the order of given columns."""
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+
+ crossed2 = fc.crossed_column([crossed1, 'c', b], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_name_leaf_keys_ordered_alphabetically(self):
+ """Tests that the name does not depend on the order of given columns."""
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d2', 'c'], 10)
+
+ crossed2 = fc.crossed_column([crossed1, 'd1', b], 10)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2.name)
+
+ def test_parse_spec(self):
+ a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed = fc.crossed_column([b, 'c'], 10)
+ self.assertEqual({
+ 'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
+ 'c': parsing_ops.VarLenFeature(dtypes.string),
+ }, crossed.parse_example_spec)
+
+ def test_num_buckets(self):
+ a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed = fc.crossed_column([b, 'c'], 15)
+ self.assertEqual(15, crossed.num_buckets)
+
+ def test_deep_copy(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32)
+ b = fc.bucketized_column(a, boundaries=[0, 1])
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
+ crossed2_copy = copy.deepcopy(crossed2)
+ self.assertEqual('a_bucketized_X_c_X_d1_X_d2', crossed2_copy.name,)
+ self.assertEqual(15, crossed2_copy.hash_bucket_size)
+ self.assertEqual(5, crossed2_copy.hash_key)
+
+ def test_parse_example(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'price':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[20., 110.])),
+ 'wire':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([price_cross_wire]))
+ self.assertIn('price', features)
+ self.assertIn('wire', features)
+ with self.test_session():
+ self.assertAllEqual([[20., 110.]], features['price'].eval())
+ wire_sparse = features['wire']
+ self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
+ # Use byte constants to pass the open-source test.
+ self.assertAllEqual([b'omar', b'stringer'], wire_sparse.values.eval())
+ self.assertAllEqual([1, 2], wire_sparse.dense_shape.eval())
+
+ def test_transform_feature(self):
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
+ hash_bucket_size = 10
+ price_cross_wire = fc.crossed_column(
+ [bucketized_price, 'wire'], hash_bucket_size)
+ features = {
+ 'price': constant_op.constant([[1., 2.], [5., 6.]]),
+ 'wire': sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ }
+ outputs = _transform_features(features, [price_cross_wire], None)
+ output = outputs[price_cross_wire]
+ with self.test_session() as sess:
+ output_val = sess.run(output)
+ self.assertAllEqual(
+ [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
+ for val in output_val.values:
+ self.assertIn(val, list(range(hash_bucket_size)))
+ self.assertAllEqual([2, 4], output_val.dense_shape)
+
+ def test_get_sparse_tensors(self):
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed1 = fc.crossed_column(['d1', 'd2'], 10)
+ crossed2 = fc.crossed_column([b, 'c', crossed1], 15, hash_key=5)
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ 'd1':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['d1A', 'd1B', 'd1C'],
+ dense_shape=(2, 2)),
+ 'd2':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['d2A', 'd2B', 'd2C'],
+ dense_shape=(2, 2)),
+ })
+ id_weight_pair = crossed2.get_sparse_tensors(transformation_cache, None)
+ with _initialized_session():
+ id_tensor_eval = id_weight_pair.id_tensor.eval()
+ self.assertAllEqual(
+ ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5),
+ (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13),
+ (1, 14), (1, 15)),
+ id_tensor_eval.indices)
+ # Check exact hashed output. If hashing changes this test will break.
+ # All values are within [0, hash_bucket_size).
+ expected_values = (
+ 6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11)
+ self.assertAllEqual(expected_values, id_tensor_eval.values)
+ self.assertAllEqual((2, 16), id_tensor_eval.dense_shape)
+
+ def test_get_sparse_tensors_simple(self):
+ """Same as test_get_sparse_tensors, but with simpler values."""
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ transformation_cache = FeatureTransformationCache({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ })
+ id_weight_pair = crossed.get_sparse_tensors(transformation_cache, None)
+ with _initialized_session():
+ id_tensor_eval = id_weight_pair.id_tensor.eval()
+ self.assertAllEqual(
+ ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)),
+ id_tensor_eval.indices)
+ # Check exact hashed output. If hashing changes this test will break.
+ # All values are within [0, hash_bucket_size).
+ expected_values = (1, 0, 1, 3, 4, 2)
+ self.assertAllEqual(expected_values, id_tensor_eval.values)
+ self.assertAllEqual((2, 4), id_tensor_eval.dense_shape)
+
+ def test_linear_model(self):
+ """Tests linear_model.
+
+ Uses data from test_get_sparse_tesnsors_simple.
+ """
+ a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc_old.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'a': constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return fc_old._CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ fc.linear_model({
+ t.name: sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
+ def test_keras_linear_model(self):
+ """Tests _LinearModel.
+
+ Uses data from test_get_sparse_tesnsors_simple.
+ """
+ a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc_old.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'a':
+ constant_op.constant(((-1., .5), (.5, 1.))),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+ bias = get_linear_model_bias()
+ crossed_var = get_linear_model_column_var(crossed)
+ with _initialized_session() as sess:
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
+ crossed_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
+ # Expected ids after cross = (1, 0, 1, 3, 4, 2)
+ self.assertAllClose(((3.,), (14.,)), predictions.eval())
+ sess.run(bias.assign((.1,)))
+ self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
+
+ def test_keras_linear_model_with_weights(self):
+
+ class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ """Produces sparse IDs and sparse weights."""
+
+ @property
+ def name(self):
+ return 'test_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {
+ self.name:
+ parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name):
+ parsing_ops.VarLenFeature(dtypes.float32),
+ }
+
+ @property
+ def _num_buckets(self):
+ return 5
+
+ def _transform_feature(self, inputs):
+ return (inputs.get(self.name),
+ inputs.get('{}_weights'.format(self.name)))
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ """Populates both id_tensor and weight_tensor."""
+ ids_and_weights = inputs.get(self)
+ return fc_old._CategoricalColumn.IdWeightPair(
+ id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
+
+ t = _TestColumnWithWeights()
+ crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
+ get_keras_linear_model_predictions({
+ t.name:
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[0, 1, 2],
+ dense_shape=(2, 2)),
+ '{}_weights'.format(t.name):
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=[1., 10., 2.],
+ dense_shape=(2, 2)),
+ 'c':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=['cA', 'cB', 'cC'],
+ dense_shape=(2, 2)),
+ }, (crossed,))
+
+
+def get_linear_model_bias(name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+
+def get_linear_model_column_var(column, name='linear_model'):
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ name + '/' + column.name)[0]
+
+
+def get_keras_linear_model_predictions(features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ keras_linear_model = _LinearModel(
+ feature_columns,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ name='linear_model')
+ retval = keras_linear_model(features) # pylint: disable=not-callable
+ if cols_to_vars is not None:
+ cols_to_vars.update(keras_linear_model.cols_to_vars())
+ return retval
+
+
+class LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc.linear_model(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc_old._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ fc.linear_model(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_dense_bias(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self, inputs, weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = fc.linear_model(features, [dense_and_sparse_column])
+ bias = get_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(dense_and_sparse_column_var.assign(
+ [[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = fc.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
+ 1000., 1100., 1200.
+ ], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc.linear_model(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = fc.linear_model(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_sparse_combiner_with_negative_weights(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {
+ 'wire_cast': wire_tensor,
+ 'weights': constant_op.constant([[1., 1., -1.0]])
+ }
+ predictions = fc.linear_model(
+ features, [wire_cast_weights], sparse_combiner='sum')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [-9985.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = fc.linear_model(features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc.linear_model(features, [price])
+
+ def test_dense_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3.], [4.]]
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_fills_cols_to_vars(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ cols_to_vars = {}
+ fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ self.assertAllEqual(cols_to_vars['bias'], [bias])
+ self.assertAllEqual(cols_to_vars[price1], [price1_var])
+ self.assertAllEqual(cols_to_vars[price2], [price2_var])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2', shape=3)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [6., 7.]],
+ 'price2': [[3., 4., 5.], [8., 9., 10.]]
+ }
+ cols_to_vars = {}
+ with variable_scope.variable_scope(
+ 'linear',
+ partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
+ fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
+ with _initialized_session():
+ self.assertEqual([0.], cols_to_vars['bias'][0].eval())
+ # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
+ self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
+ # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
+ # a [1, 1] Variable.
+ self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+
+ def test_dense_collection(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ fc.linear_model(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ fc.linear_model(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.linear_model(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ fc.linear_model(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc.linear_model(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc.linear_model(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ predictions = fc.linear_model(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.linear_model(features, [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': constant_op.constant([-1., 12.,]),
+ 'body-style': sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = fc.linear_model(features, [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = fc.linear_model(features, [price_buckets, body_style, country])
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc.linear_model(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc.linear_model(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+ def test_multiple_linear_models(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features1 = {'price': [[1.], [5.]]}
+ features2 = {'price': [[2.], [10.]]}
+ predictions1 = fc.linear_model(features1, [price])
+ predictions2 = fc.linear_model(features2, [price])
+ bias1 = get_linear_model_bias(name='linear_model')
+ bias2 = get_linear_model_bias(name='linear_model_1')
+ price_var1 = get_linear_model_column_var(price, name='linear_model')
+ price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias1.eval())
+ sess.run(price_var1.assign([[10.]]))
+ sess.run(bias1.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions1.eval())
+ self.assertAllClose([0.], bias2.eval())
+ sess.run(price_var2.assign([[10.]]))
+ sess.run(bias2.assign([5.]))
+ self.assertAllClose([[25.], [105.]], predictions2.eval())
+
+
+class _LinearModelTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ get_keras_linear_model_predictions(features={}, feature_columns=[])
+
+ def test_should_be_feature_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns='NotSupported')
+
+ def test_should_be_dense_or_categorical_column(self):
+
+ class NotSupportedColumn(fc_old._FeatureColumn):
+
+ @property
+ def name(self):
+ return 'NotSupportedColumn'
+
+ def _transform_feature(self, cache):
+ pass
+
+ @property
+ def _parse_example_spec(self):
+ pass
+
+ with self.assertRaisesRegexp(
+ ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.linear_model(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ get_keras_linear_model_predictions(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_dense_bias(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ sess.run(price_var.assign([[10.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[15.], [55.]], predictions.eval())
+
+ def test_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_and_sparse_bias(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [wire_cast, price])
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ sess.run(price_var.assign([[10.]]))
+ self.assertAllClose([[1015.], [10065.]], predictions.eval())
+
+ def test_dense_and_sparse_column(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+
+ @property
+ def name(self):
+ return 'dense_and_sparse_column'
+
+ @property
+ def _parse_example_spec(self):
+ return {self.name: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.name)
+
+ @property
+ def _variable_shape(self):
+ raise ValueError('Should not use this method.')
+
+ def _get_dense_tensor(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ raise ValueError('Should not use this method.')
+
+ @property
+ def _num_buckets(self):
+ return 4
+
+ def _get_sparse_tensors(self,
+ inputs,
+ weight_collections=None,
+ trainable=None):
+ sp_tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[2, 0, 3],
+ dense_shape=[2, 2])
+ return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+
+ dense_and_sparse_column = _DenseAndSparseColumn()
+ with ops.Graph().as_default():
+ sp_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {dense_and_sparse_column.name: sp_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [dense_and_sparse_column])
+ bias = get_linear_model_bias()
+ dense_and_sparse_column_var = get_linear_model_column_var(
+ dense_and_sparse_column)
+ with _initialized_session() as sess:
+ sess.run(
+ dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
+ [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [10015.]], predictions.eval())
+
+ def test_dense_multi_output(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((1, 3)), price_var.eval())
+ sess.run(price_var.assign([[10., 100., 1000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
+ predictions.eval())
+
+ def test_sparse_multi_output(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], units=3)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
+ sess.run(
+ wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
+ [1000., 1100.,
+ 1200.], [10000., 11000., 12000.]]))
+ sess.run(bias.assign([5., 6., 7.]))
+ self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
+ predictions.eval())
+
+ def test_dense_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_sparse_multi_rank(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = array_ops.sparse_placeholder(dtypes.string)
+ wire_value = sparse_tensor.SparseTensorValue(
+ values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
+ indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
+ dense_shape=[2, 2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(features, [wire_cast])
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
+ self.assertAllClose(
+ np.zeros((2, 1)),
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ self.assertAllClose(
+ [[1010.], [11000.]],
+ predictions.eval(feed_dict={wire_tensor: wire_value}))
+
+ def test_sparse_combiner(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {'wire_cast': wire_tensor}
+ predictions = get_keras_linear_model_predictions(
+ features, [wire_cast], sparse_combiner='mean')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [5010.]], predictions.eval())
+
+ def test_dense_multi_dimension_multi_output(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ predictions = get_keras_linear_model_predictions(
+ features, [price], units=3)
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose(np.zeros((3,)), bias.eval())
+ self.assertAllClose(np.zeros((2, 3)), price_var.eval())
+ sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
+ sess.run(bias.assign([2., 3., 4.]))
+ self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
+ predictions.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ get_keras_linear_model_predictions(features, [price])
+
+ def test_dense_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ predictions = get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price_var.assign([[10.], [100.]]))
+ self.assertAllClose([[210.], [650.]], predictions.eval())
+
+ def test_dense_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ with _initialized_session() as sess:
+ self.assertAllClose([0.], bias.eval())
+ self.assertAllClose([[0.], [0.]], price1_var.eval())
+ self.assertAllClose([[0.]], price2_var.eval())
+ self.assertAllClose([[0.], [0.]], predictions.eval())
+ sess.run(price1_var.assign([[10.], [100.]]))
+ sess.run(price2_var.assign([[1000.]]))
+ sess.run(bias.assign([7.]))
+ self.assertAllClose([[3217.], [4657.]], predictions.eval())
+
+ def test_fills_cols_to_vars(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ cols_to_vars = {}
+ get_keras_linear_model_predictions(
+ features, [price1, price2], cols_to_vars=cols_to_vars)
+ bias = get_linear_model_bias()
+ price1_var = get_linear_model_column_var(price1)
+ price2_var = get_linear_model_column_var(price2)
+ self.assertAllEqual(cols_to_vars['bias'], [bias])
+ self.assertAllEqual(cols_to_vars[price1], [price1_var])
+ self.assertAllEqual(cols_to_vars[price2], [price2_var])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2', shape=3)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [6., 7.]],
+ 'price2': [[3., 4., 5.], [8., 9., 10.]]
+ }
+ cols_to_vars = {}
+ with variable_scope.variable_scope(
+ 'linear',
+ partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
+ get_keras_linear_model_predictions(
+ features, [price1, price2], cols_to_vars=cols_to_vars)
+ with _initialized_session():
+ self.assertEqual([0.], cols_to_vars['bias'][0].eval())
+ # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
+ self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
+ # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
+ # a [1, 1] Variable.
+ self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
+ self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+
+ def test_dense_collection(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(
+ features, [price], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ self.assertIn(bias, my_vars)
+ self.assertIn(price_var, my_vars)
+
+ def test_sparse_collection(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(
+ features, [wire_cast], weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, my_vars)
+ self.assertIn(wire_cast_var, my_vars)
+
+ def test_dense_trainable_default(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price])
+ bias = get_linear_model_bias()
+ price_var = get_linear_model_column_var(price)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(price_var, trainable_vars)
+
+ def test_sparse_trainable_default(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast])
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ self.assertIn(bias, trainable_vars)
+ self.assertIn(wire_cast_var, trainable_vars)
+
+ def test_dense_trainable_false(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default() as g:
+ features = {'price': [[1.], [5.]]}
+ get_keras_linear_model_predictions(features, [price], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_sparse_trainable_false(self):
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ features = {'wire_cast': wire_tensor}
+ get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
+ trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertEqual([], trainable_vars)
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [price_a, wire_cast, price_b],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ with ops.Graph().as_default() as g:
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ 'wire_cast':
+ sparse_tensor.SparseTensor(
+ values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
+ }
+ get_keras_linear_model_predictions(
+ features, [wire_cast, price_b, price_a],
+ weight_collections=['my-vars'])
+ my_vars = g.get_collection('my-vars')
+ self.assertIn('price_a', my_vars[0].name)
+ self.assertIn('price_b', my_vars[1].name)
+ self.assertIn('wire_cast', my_vars[2].name)
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ get_keras_linear_model_predictions(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'must have the same size and shape'):
+ sess.run(
+ predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ predictions = get_keras_linear_model_predictions(features,
+ [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ predictions,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_with_numpy_input_fn(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price':
+ constant_op.constant([
+ -1.,
+ 12.,
+ ]),
+ 'body-style':
+ sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+
+ net = get_keras_linear_model_predictions(features,
+ [price_buckets, body_style])
+ with _initialized_session() as sess:
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ price = fc_old.numeric_column('price')
+ price_buckets = fc_old.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+
+ price_data = np.array([-1., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
+ country_data = np.array(['US', 'CA'])
+
+ net = get_keras_linear_model_predictions(
+ features, [price_buckets, body_style, country])
+ bias = get_linear_model_bias()
+ price_buckets_var = get_linear_model_column_var(price_buckets)
+ body_style_var = get_linear_model_column_var(body_style)
+ with _initialized_session() as sess:
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ get_keras_linear_model_predictions(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = get_keras_linear_model_predictions(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+class InputLayerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_retrieving_input(self):
+ features = {'a': [0.]}
+ input_layer = InputLayer(fc_old.numeric_column('a'))
+ inputs = self.evaluate(input_layer(features))
+ self.assertAllClose([[0.]], inputs)
+
+ def test_reuses_variables(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ inputs = input_layer(features)
+ variables = input_layer.variables
+
+ # Sanity check: test that the inputs are correct.
+ self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
+
+ # Check that only one variable was created.
+ self.assertEqual(1, len(variables))
+
+ # Check that invoking input_layer on the same features does not create
+ # additional variables
+ _ = input_layer(features)
+ self.assertEqual(1, len(variables))
+ self.assertEqual(variables[0], input_layer.variables[0])
+
+ def test_feature_column_input_layer_gradient(self):
+ with context.eager_mode():
+ sparse_input = sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (2, 0)),
+ values=(0, 1, 2),
+ dense_shape=(3, 3))
+
+ # Create feature columns (categorical and embedding).
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='a', num_buckets=3)
+ embedding_dimension = 2
+
+ def _embedding_column_initializer(shape, dtype, partition_info):
+ del shape # unused
+ del dtype # unused
+ del partition_info # unused
+ embedding_values = (
+ (1, 0), # id 0
+ (0, 1), # id 1
+ (1, 1)) # id 2
+ return embedding_values
+
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_embedding_column_initializer)
+
+ input_layer = InputLayer([embedding_column])
+ features = {'a': sparse_input}
+
+ def scale_matrix():
+ matrix = input_layer(features)
+ return 2 * matrix
+
+ # Sanity check: Verify that scale_matrix returns the correct output.
+ self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())
+
+ # Check that the returned gradient is correct.
+ grad_function = backprop.implicit_grad(scale_matrix)
+ grads_and_vars = grad_function()
+ indexed_slice = grads_and_vars[0][0]
+ gradient = grads_and_vars[0][0].values
+
+ self.assertAllEqual([0, 1, 2], indexed_slice.indices)
+ self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
+
+
+class FunctionalInputLayerTest(test.TestCase):
+
+ def test_raises_if_empty_feature_columns(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'feature_columns must not be empty'):
+ fc.input_layer(features={}, feature_columns=[])
+
+ def test_should_be_dense_column(self):
+ with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])
+
+ def test_does_not_support_dict_columns(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Expected feature_columns to be iterable, found dict.'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns={'a': fc_old.numeric_column('a')})
+
+ def test_bare_column(self):
+ with ops.Graph().as_default():
+ features = features = {'a': [0.]}
+ net = fc.input_layer(features, fc_old.numeric_column('a'))
+ with _initialized_session():
+ self.assertAllClose([[0.]], net.eval())
+
+ def test_column_generator(self):
+ with ops.Graph().as_default():
+ features = features = {'a': [0.], 'b': [1.]}
+ columns = (fc_old.numeric_column(key) for key in features)
+ net = fc.input_layer(features, columns)
+ with _initialized_session():
+ self.assertAllClose([[0., 1.]], net.eval())
+
+ def test_raises_if_duplicate_name(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Duplicate feature column name found for columns'):
+ fc.input_layer(
+ features={'a': [[0]]},
+ feature_columns=[
+ fc_old.numeric_column('a'),
+ fc_old.numeric_column('a')
+ ])
+
+ def test_one_column(self):
+ price = fc_old.numeric_column('price')
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1.], [5.]], net.eval())
+
+ def test_multi_dimension(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1., 2.], [5., 6.]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_raises_if_shape_mismatch(self):
+ price = fc_old.numeric_column('price', shape=2)
+ with ops.Graph().as_default():
+ features = {'price': [[1.], [5.]]}
+ with self.assertRaisesRegexp(
+ Exception,
+ r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
+ fc.input_layer(features, [price])
+
+ def test_reshaping(self):
+ price = fc_old.numeric_column('price', shape=[1, 2])
+ with ops.Graph().as_default():
+ features = {'price': [[[1., 2.]], [[5., 6.]]]}
+ net = fc.input_layer(features, [price])
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
+
+ def test_multi_column(self):
+ price1 = fc_old.numeric_column('price1', shape=2)
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1., 2.], [5., 6.]],
+ 'price2': [[3.], [4.]]
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session():
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
+
+ def test_fills_cols_to_vars(self):
+ # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ price1 = fc_old.numeric_column('price1')
+ dense_feature = fc_old.numeric_column('dense_feature')
+ dense_feature_bucketized = fc_old.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+
+ def test_fills_cols_to_vars_partitioned_variables(self):
+ price1 = fc_old.numeric_column('price1')
+ dense_feature = fc_old.numeric_column('dense_feature')
+ dense_feature_bucketized = fc_old.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ cols_to_vars = {}
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
+ self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+
+ def test_column_order(self):
+ price_a = fc_old.numeric_column('price_a')
+ price_b = fc_old.numeric_column('price_b')
+ with ops.Graph().as_default():
+ features = {
+ 'price_a': [[1.]],
+ 'price_b': [[3.]],
+ }
+ net1 = fc.input_layer(features, [price_a, price_b])
+ net2 = fc.input_layer(features, [price_b, price_a])
+ with _initialized_session():
+ self.assertAllClose([[1., 3.]], net1.eval())
+ self.assertAllClose([[1., 3.]], net2.eval())
+
+ def test_fails_for_categorical_column(self):
+ animal = fc_old.categorical_column_with_identity('animal', num_buckets=4)
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
+ fc.input_layer(features, [animal])
+
+ def test_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[1.], [5.], [7.]], # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc.input_layer(features, [price1, price2])
+
+ def test_subset_of_static_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ price3 = fc_old.numeric_column('price3')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]], # batchsize = 2
+ 'price3': [[3.], [4.], [5.]] # batchsize = 3
+ }
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
+ fc.input_layer(features, [price1, price2, price3])
+
+ def test_runtime_batch_size_mismatch(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
+ 'price2': [[3.], [4.]] # batchsize = 2
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ with self.assertRaisesRegexp(errors.OpError,
+ 'Dimensions of inputs should match'):
+ sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
+
+ def test_runtime_batch_size_matches(self):
+ price1 = fc_old.numeric_column('price1')
+ price2 = fc_old.numeric_column('price2')
+ with ops.Graph().as_default():
+ features = {
+ 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
+ }
+ net = fc.input_layer(features, [price1, price2])
+ with _initialized_session() as sess:
+ sess.run(
+ net,
+ feed_dict={
+ features['price1']: [[1.], [5.]],
+ features['price2']: [[1.], [5.]],
+ })
+
+ def test_multiple_layers_with_same_embedding_column(self):
+ some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc_old.embedding_column(
+ some_sparse_column, dimension=10)
+
+ with ops.Graph().as_default():
+ features = {
+ 'sparse_feature': [['a'], ['x']],
+ }
+ all_cols = [some_embedding_column]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that 2 variables get created in this case.
+ self.assertEqual(2, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ expected_var_names = [
+ 'input_layer/sparse_feature_embedding/embedding_weights:0',
+ 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ ]
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column(self):
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ all_cols = [embedding_column_a, embedding_column_b]
+ fc.input_layer(features, all_cols)
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+ all_cols = [embedding_column_a, embedding_column_b]
+
+ with ops.Graph().as_default():
+ features = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ fc.input_layer(features, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+
+ with ops.Graph().as_default():
+ features1 = {
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+
+ fc.input_layer(features1, all_cols)
+ # Make sure that only 1 variable gets created in this case.
+ self.assertEqual(1, len(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+
+ def test_with_numpy_input_fn(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ # one_hot_body_style has 3 dims in input_layer.
+ one_hot_body_style = fc_old.indicator_column(body_style)
+ # embedded_body_style has 5 dims in input_layer.
+ embedded_body_style = fc_old.embedding_column(
+ body_style, dimension=5, initializer=_initializer)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([11., 12., 13., 14.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_body_style])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[11., 12., 13., 14., 15., 0., 0., 1., 11.],
+ [1., 2., 3., 4., 5., 1., 0., 0., 12]],
+ sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_with_1d_sparse_tensor(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc_old.indicator_column(body_style)
+
+ # embedded_body_style has 5 dims in input_layer.
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc_old.embedding_column(
+ country, dimension=5, initializer=_initializer)
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': constant_op.constant([11., 12.,]),
+ 'body-style': sparse_tensor.SparseTensor(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,)),
+ # This is dense tensor for the categorical_column.
+ 'country': constant_op.constant(['CA', 'US']),
+ }
+ self.assertEqual(1, features['price'].shape.ndims)
+ self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
+ self.assertEqual(1, features['country'].shape.ndims)
+
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with _initialized_session() as sess:
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[0., 0., 1., 11., 12., 13., 14., 15., 11.],
+ [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
+ sess.run(net))
+
+ def test_with_1d_unknown_shape_sparse_tensor(self):
+ embedding_values = (
+ (1., 2.), # id 0
+ (6., 7.), # id 1
+ (11., 12.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+
+ # one_hot_body_style has 3 dims in input_layer.
+ body_style = fc_old.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ one_hot_body_style = fc_old.indicator_column(body_style)
+
+ # embedded_body_style has 5 dims in input_layer.
+ country = fc_old.categorical_column_with_vocabulary_list(
+ 'country', vocabulary_list=['US', 'JP', 'CA'])
+ embedded_country = fc_old.embedding_column(
+ country, dimension=2, initializer=_initializer)
+
+ # Provides 1-dim tensor and dense tensor.
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ 'body-style': array_ops.sparse_placeholder(dtypes.string),
+ # This is dense tensor for the categorical_column.
+ 'country': array_ops.placeholder(dtypes.string),
+ }
+ self.assertIsNone(features['price'].shape.ndims)
+ self.assertIsNone(features['body-style'].get_shape().ndims)
+ self.assertIsNone(features['country'].shape.ndims)
+
+ price_data = np.array([11., 12.])
+ body_style_data = sparse_tensor.SparseTensorValue(
+ indices=((0,), (1,)),
+ values=('sedan', 'hardtop'),
+ dense_shape=(2,))
+ country_data = np.array([['US'], ['CA']])
+
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_country])
+ self.assertEqual(1 + 3 + 2, net.shape[1])
+ with _initialized_session() as sess:
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
+ sess.run(
+ net,
+ feed_dict={
+ features['price']: price_data,
+ features['body-style']: body_style_data,
+ features['country']: country_data
+ }))
+
+ def test_with_rank_0_feature(self):
+ # price has 1 dimension in input_layer
+ price = fc_old.numeric_column('price')
+ features = {
+ 'price': constant_op.constant(0),
+ }
+ self.assertEqual(0, features['price'].shape.ndims)
+
+ # Static rank 0 should fail
+ with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
+ fc.input_layer(features, [price])
+
+ # Dynamic rank 0 should fail
+ features = {
+ 'price': array_ops.placeholder(dtypes.float32),
+ }
+ net = fc.input_layer(features, [price])
+ self.assertEqual(1, net.shape[1])
+ with _initialized_session() as sess:
+ with self.assertRaisesOpError('Feature .* cannot have rank 0'):
+ sess.run(net, feed_dict={features['price']: np.array(1)})
+
+
+class MakeParseExampleSpecTest(test.TestCase):
+
+ class _TestFeatureColumn(FeatureColumn,
+ collections.namedtuple('_TestFeatureColumn',
+ ('parse_spec'))):
+
+ @property
+ def name(self):
+ return "_TestFeatureColumn"
+
+ def transform_feature(self, transformation_cache, state_manager):
+ pass
+
+ @property
+ def parse_example_spec(self):
+ return self.parse_spec
+
+ def test_no_feature_columns(self):
+ actual = fc.make_parse_example_spec([])
+ self.assertDictEqual({}, actual)
+
+ def test_invalid_type(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'All feature_columns must be FeatureColumn instances.*invalid_column'):
+ fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}), 'invalid_column'))
+
+ def test_one_feature_column(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),))
+ self.assertDictEqual({key1: parse_spec1}, actual)
+
+ def test_two_feature_columns(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ key2 = 'key2'
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key2: parse_spec2})))
+ self.assertDictEqual({key1: parse_spec1, key2: parse_spec2}, actual)
+
+ def test_equal_keys_different_parse_spec(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'feature_columns contain different parse_spec for key key1'):
+ fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key1: parse_spec2})))
+
+ def test_equal_keys_equal_parse_spec(self):
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key1: parse_spec1})))
+ self.assertDictEqual({key1: parse_spec1}, actual)
+
+ def test_multiple_features_dict(self):
+ """parse_spc for one column is a dict with length > 1."""
+ key1 = 'key1'
+ parse_spec1 = parsing_ops.FixedLenFeature(
+ shape=(2,), dtype=dtypes.float32, default_value=0.)
+ key2 = 'key2'
+ parse_spec2 = parsing_ops.VarLenFeature(dtype=dtypes.string)
+ key3 = 'key3'
+ parse_spec3 = parsing_ops.VarLenFeature(dtype=dtypes.int32)
+ actual = fc.make_parse_example_spec(
+ (self._TestFeatureColumn({key1: parse_spec1}),
+ self._TestFeatureColumn({key2: parse_spec2, key3: parse_spec3})))
+ self.assertDictEqual(
+ {key1: parse_spec1, key2: parse_spec2, key3: parse_spec3}, actual)
+
+
+def _assert_sparse_tensor_value(test_case, expected, actual):
+ test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+ test_case.assertAllEqual(expected.indices, actual.indices)
+
+ test_case.assertEqual(
+ np.array(expected.values).dtype, np.array(actual.values).dtype)
+ test_case.assertAllEqual(expected.values, actual.values)
+
+ test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+ test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+
+class VocabularyFileCategoricalColumnTest(test.TestCase):
+
+ def setUp(self):
+ super(VocabularyFileCategoricalColumnTest, self).setUp()
+
+ # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
+ self._warriors_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/warriors_vocabulary.txt')
+ self._warriors_vocabulary_size = 5
+
+ # Contains strings, character names from 'The Wire': omar, stringer, marlo
+ self._wire_vocabulary_file_name = test.test_src_dir_path(
+ 'python/feature_column/testdata/wire_vocabulary.txt')
+ self._wire_vocabulary_size = 3
+
+ def test_defaults(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_file(
+ key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ self.assertEqual(7, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
+ num_oov_buckets=4, dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(7, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_vocabulary_file_none(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=None, vocabulary_size=3)
+
+ def test_vocabulary_file_empty_string(self):
+ with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='', vocabulary_size=3)
+
+ def test_invalid_vocabulary_file(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_vocabulary_size(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=-1)
+ with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=0)
+
+ def test_too_large_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size + 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
+ with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
+ with self.test_session():
+ lookup_ops.tables_initializer().run()
+
+ def test_invalid_num_oov_buckets(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ num_oov_buckets=-1)
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path', vocabulary_size=3,
+ dtype=dtypes.float64)
+
+ def test_invalid_buckets_and_default_value(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'both num_oov_buckets and default_value'):
+ fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100,
+ default_value=2)
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ dtype=dtypes.string)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_none_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+ values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+ dense_shape=(2, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 33, 0, 62), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_small_vocabulary_size(self):
+ # 'marlo' is the last entry in our vocabulary file, so be setting
+ # `vocabulary_size` to 1 less than number of entries in file, we take
+ # 'marlo' out of the vocabulary.
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size - 1)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((-1, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ default_value=default_value)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._warriors_vocabulary_file_name,
+ vocabulary_size=self._warriors_vocabulary_size,
+ dtype=dtypes.int32,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 60, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_file(
+ key='wire',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size,
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
+class VocabularyListCategoricalColumnTest(test.TestCase):
+
+ def test_defaults_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.string)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_list(
+ key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
+ def test_defaults_int(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36))
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
+ default_value=-99)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int32)
+ }, column.parse_example_spec)
+
+ def test_invalid_dtype(self):
+ with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.float32)
+
+ def test_invalid_mapping_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary dtype must be string or integer'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12., 24., 36.))
+
+ def test_mismatched_int_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
+ dtype=dtypes.int32)
+
+ def test_mismatched_string_dtype(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'dtype.*and vocabulary dtype.*do not match'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string)
+
+ def test_none_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=None)
+
+ def test_empty_mapping(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'vocabulary_list.*must be non-empty'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=tuple([]))
+
+ def test_duplicate_mapping(self):
+ with self.assertRaisesRegexp(ValueError, 'Duplicate keys'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 12))
+
+ def test_invalid_num_oov_buckets(self):
+ with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(12, 24, 36),
+ num_oov_buckets=-1)
+
+ def test_invalid_buckets_and_default_value(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'both num_oov_buckets and default_value'):
+ fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=(12, 24, 36),
+ num_oov_buckets=100,
+ default_value=2)
+
+ def test_invalid_input_dtype_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(12, 24, 36),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_invalid_input_dtype_string(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=(12, 24, 36))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example_string(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_parse_example_int(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=(11, 21, 31))
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=[11, 21],
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': (('marlo', ''), ('skywalker', 'omar'))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_default_value_in_vocabulary(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ default_value=2)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 2, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (1, 2)),
+ values=('marlo', 'skywalker', 'omar', 'heisenberg'),
+ dense_shape=(2, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 33, 0, 62), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((11, 100, 30, 22), dtype=np.int32),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, -1, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_dense_input(self):
+ default_value = -100
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32,
+ default_value=default_value)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa':
+ np.array(
+ ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32)
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=np.array((2, default_value, 0, 4), dtype=np.int64),
+ dense_shape=(3, 3)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_int32_with_oov_buckets(self):
+ column = fc.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
+ dtype=dtypes.int32,
+ num_oov_buckets=100)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1), (2, 2)),
+ values=(11, 100, 30, 22),
+ dense_shape=(3, 3))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((2, 60, 0, 4), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ wire_column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ wire_column = fc_old.categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'),
+ num_oov_buckets=1)
+ self.assertEqual(4, wire_column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ wire_column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }, (wire_column,))
+ bias = get_linear_model_bias()
+ wire_var = get_linear_model_column_var(wire_column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
+ # 'marlo' -> 2: wire_var[2] = 3
+ # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
+ self.assertAllClose(((3.,), (5.,)), predictions.eval())
+
+
+class IdentityCategoricalColumnTest(test.TestCase):
+
+ def test_constructor(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual('aaa', column.name)
+ self.assertEqual('aaa', column.key)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
+ def test_deep_copy(self):
+ original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column.parse_example_spec)
+
+ def test_invalid_num_buckets_zero(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=0)
+
+ def test_invalid_num_buckets_negative(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
+
+ def test_invalid_default_value_too_small(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=-1)
+
+ def test_invalid_default_value_too_big(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=3)
+
+ def test_invalid_input_dtype(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=30)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+ value=[11, 21]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([11, 21], dtype=np.int64),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_transform_feature(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ id_tensor = _transform_features({'aaa': inputs}, [column], None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+
+ def DISABLED_test_get_sparse_tensors_weight_collections(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }),
+ weight_collections=('my_weights',))
+
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertItemsEqual([], ops.get_collection('my_weights'))
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': ((0, -1), (1, 0))
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_inputs_too_small(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_greater_or_equal_0'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_inputs_too_big(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 99, 0),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_less_than_num_buckets'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_default_value(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 99),
+ dense_shape=(2, 2))
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int32)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ id_weight_pair = column.get_sparse_tensors(
+ FeatureTransformationCache({
+ 'aaa': inputs
+ }), None)
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=np.array((2, 2), dtype=np.int64)),
+ id_weight_pair.id_tensor.eval(feed_dict={
+ input_indices: ((0, 0), (1, 0), (1, 1)),
+ input_values: (1, -1, 99),
+ input_shape: (2, 2),
+ }))
+
+ def test_linear_model(self):
+ column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column.num_buckets)
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ column.name:
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+
+class TransformFeaturesTest(test.TestCase):
+
+ # All transform tests are distributed in column test.
+ # Here we only test multi column case and naming
+ def transform_multi_column(self):
+ bucketized_price = fc.bucketized_column(
+ fc.numeric_column('price'), boundaries=[0, 2, 4, 6])
+ hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
+ with ops.Graph().as_default():
+ features = {
+ 'price': [[-1.], [5.]],
+ 'wire':
+ sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ }
+ transformed = _transform_features(features,
+ [bucketized_price, hashed_sparse], None)
+ with _initialized_session():
+ self.assertIn(bucketized_price.name, transformed[bucketized_price].name)
+ self.assertAllEqual([[0], [3]], transformed[bucketized_price].eval())
+ self.assertIn(hashed_sparse.name, transformed[hashed_sparse].name)
+ self.assertAllEqual([6, 4, 1], transformed[hashed_sparse].values.eval())
+
+ def test_column_order(self):
+ """When the column is both dense and sparse, uses sparse tensors."""
+
+ class _LoggerColumn(FeatureColumn):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+ def transform_feature(self, transformation_cache, state_manager):
+ self.call_order = call_logger['count']
+ call_logger['count'] += 1
+ return 'Anything'
+
+ @property
+ def parse_example_spec(self):
+ pass
+
+ with ops.Graph().as_default():
+ column1 = _LoggerColumn('1')
+ column2 = _LoggerColumn('2')
+ call_logger = {'count': 0}
+ _transform_features({}, [column1, column2], None)
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+ call_logger = {'count': 0}
+ _transform_features({}, [column2, column1], None)
+ self.assertEqual(0, column1.call_order)
+ self.assertEqual(1, column2.call_order)
+
+
+class IndicatorColumnTest(test.TestCase):
+
+ def test_indicator_column(self):
+ a = fc.categorical_column_with_hash_bucket('a', 4)
+ indicator_a = fc.indicator_column(a)
+ self.assertEqual(indicator_a.categorical_column.name, 'a')
+ self.assertEqual(indicator_a.name, 'a_indicator')
+ self.assertEqual(indicator_a.variable_shape, [1, 4])
+
+ b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100)
+ indicator_b = fc.indicator_column(b)
+ self.assertEqual(indicator_b.categorical_column.name, 'b')
+ self.assertEqual(indicator_b.name, 'b_indicator')
+ self.assertEqual(indicator_b.variable_shape, [1, 100])
+
+ def test_1D_shape_succeeds(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_hash_bucket('animal', 4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal': ['fox', 'fox']
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
+
+ def test_2D_shape_succeeds(self):
+ # TODO(ispir/cassandrax): Swith to categorical_column_with_keys when ready.
+ animal = fc.indicator_column(
+ fc.categorical_column_with_hash_bucket('animal', 4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0]],
+ values=['fox', 'fox'],
+ dense_shape=[2, 1])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
+
+ def test_multi_hot(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
+
+ def test_multi_hot2(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
+ transformation_cache = FeatureTransformationCache({
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ })
+ output = transformation_cache.get(animal, None)
+ with self.test_session():
+ self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
+
+ def test_deep_copy(self):
+ a = fc.categorical_column_with_hash_bucket('a', 4)
+ column = fc.indicator_column(a)
+ column_copy = copy.deepcopy(column)
+ self.assertEqual(column_copy.categorical_column.name, 'a')
+ self.assertEqual(column.name, 'a_indicator')
+ self.assertEqual(column.variable_shape, [1, 4])
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_indicator]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_indicator = fc.indicator_column(a)
+ features = {
+ 'aaa': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ }
+ indicator_tensor = _transform_features(features, [a_indicator],
+ None)[a_indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
+
+ def test_transform_with_weighted_column(self):
+ # Github issue 12557
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'a']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
+ }
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+
+ def test_transform_with_missing_value_in_weighted_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ 'weights': constant_op.constant([[2., 4., 6.]])
+ }
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())
+
+ def test_transform_with_missing_value_in_categorical_column(self):
+ # Github issue 12583
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ indicator = fc.indicator_column(ids)
+ features = {
+ 'ids': constant_op.constant([['c', 'b', 'unknown']]),
+ }
+ indicator_tensor = _transform_features(features, [indicator],
+ None)[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
+
+ def test_linear_model(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = fc.linear_model(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+
+ predictions = get_keras_linear_model_predictions(features, [animal])
+ weight_var = get_linear_model_column_var(animal)
+ with _initialized_session():
+ # All should be zero-initialized.
+ self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
+ self.assertAllClose([[0.]], predictions.eval())
+ weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
+ self.assertAllClose([[2. + 3.]], predictions.eval())
+
+ def test_input_layer(self):
+ animal = fc_old.indicator_column(
+ fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ with ops.Graph().as_default():
+ features = {
+ 'animal':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
+ }
+ net = fc.input_layer(features, [animal])
+ with _initialized_session():
+ self.assertAllClose([[0., 1., 1., 0.]], net.eval())
+
+
+class _TestStateManager(StateManager):
+
+ def __init__(self, trainable=True):
+ # Dict of feature_column to a dict of variables.
+ self._all_variables = {}
+ self._trainable = trainable
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ if feature_column not in self._all_variables:
+ self._all_variables[feature_column] = {}
+ var_dict = self._all_variables[feature_column]
+ if name in var_dict:
+ return var_dict[name]
+ else:
+ var = variable_scope.get_variable(
+ name=name,
+ shape=shape,
+ initializer=initializer,
+ trainable=self._trainable)
+ var_dict[name] = var
+ return var
+
+
+class EmbeddingColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension)
+ self.assertIs(categorical_column, embedding_column.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('mean', embedding_column.combiner)
+ self.assertIsNone(embedding_column.ckpt_to_load_from)
+ self.assertIsNone(embedding_column.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column.max_norm)
+ self.assertTrue(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ combiner='my_combiner', initializer=lambda: 'my_initializer',
+ ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ self.assertIs(categorical_column, embedding_column.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('my_combiner', embedding_column.combiner)
+ self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column.max_norm)
+ self.assertFalse(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_deep_copy(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_dimension = 2
+ original = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ combiner='my_combiner', initializer=lambda: 'my_initializer',
+ ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ for embedding_column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', embedding_column.categorical_column.name)
+ self.assertEqual(3, embedding_column.categorical_column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.categorical_column.parse_example_spec)
+
+ self.assertEqual(embedding_dimension, embedding_column.dimension)
+ self.assertEqual('my_combiner', embedding_column.combiner)
+ self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column.max_norm)
+ self.assertFalse(embedding_column.trainable)
+ self.assertEqual('aaa_embedding', embedding_column.name)
+ self.assertEqual((embedding_dimension,), embedding_column.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column.parse_example_spec)
+
+ def test_invalid_initializer(self):
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+ fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_embedded = fc.embedding_column(a, dimension=2)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer']))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_embedded]))
+ self.assertIn('aaa', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+
+ def test_transform_feature(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ a_embedded = fc.embedding_column(a, dimension=2)
+ features = {
+ 'aaa': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ }
+ outputs = _transform_features(features, [a, a_embedded], None)
+ output_a = outputs[a]
+ output_embedded = outputs[a_embedded]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self, output_a.eval(), output_embedded.eval())
+
+ def test_get_dense_tensor(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def test_get_dense_tensor_3d(self):
+ # Inputs.
+ vocabulary_size = 4
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)),
+ values=(2, 0, 1, 1, 2),
+ dense_shape=(4, 2, 5))
+
+ # Embedding variable.
+ embedding_dimension = 3
+ embedding_values = (
+ (1., 2., 4.), # id 0
+ (3., 5., 1.), # id 1
+ (7., 11., 2.), # id 2
+ (2., 7., 12.) # id 3
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]]
+ ((7., 11., 2.), (0., 0., 0.)),
+ # example 1, ids [[], [0, 1]], embedding
+ # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]]
+ ((0., 0., 0.), (2., 3.5, 2.5)),
+ # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]]
+ ((0., 0., 0.), (0., 0., 0.)),
+ # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]]
+ ((3., 5., 1.), (7., 11., 2.)),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def DISABLED_test_get_dense_tensor_weight_collections(self):
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ embedding_column = fc.embedding_column(categorical_column, dimension=2)
+
+ # Provide sparse input and get dense result.
+ embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }),
+ weight_collections=('my_vars',))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ my_vars = ops.get_collection('my_vars')
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in my_vars]))
+
+ def test_get_dense_tensor_placeholder_inputs(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int64)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa':
+ sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval(
+ feed_dict={
+ input_indices: sparse_input.indices,
+ input_values: sparse_input.values,
+ input_shape: sparse_input.dense_shape,
+ }))
+
+ def test_get_dense_tensor_restore_from_ckpt(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable. The checkpoint file contains _embedding_values.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ ckpt_path = test.test_src_dir_path(
+ 'python/feature_column/testdata/embedding.ckpt')
+ ckpt_tensor = 'my_embedding'
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc.embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ ckpt_to_load_from=ckpt_path,
+ tensor_name_in_ckpt=ckpt_tensor)
+ state_manager = _TestStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup = embedding_column.get_dense_tensor(
+ FeatureTransformationCache({
+ 'aaa': sparse_input
+ }), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval())
+
+ def test_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v for v in ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars[
+ 'linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 4
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(batch_size, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column.name: sparse_input
+ }, (embedding_column,))
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_embedding/weights:0',
+ 'linear_model/aaa_embedding/embedding_weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_embedding/embedding_weights:0']
+ linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # example 2, ids [], embedding[2] = [0, 0]
+ # example 3, ids [1], embedding[3] = [3, 5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
+ self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
+
+ def test_input_layer(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, trainable_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+ def test_input_layer_not_trainable(self):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 4), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0, ids [2], embedding = [7, 11]
+ (7., 11.),
+ # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ (2., 3.5),
+ # example 2, ids [], embedding = [0, 0]
+ (0., 0.),
+ # example 3, ids [1], embedding = [3, 5]
+ (3., 5.),
+ )
+
+ # Build columns.
+ categorical_column = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = fc_old.embedding_column(
+ categorical_column,
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=False)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(
+ [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, global_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+
+class _TestSharedEmbeddingStateManager(StateManager):
+ """Manages the state for shared embedding columns.
+
+ This can handle multiple groups of shared embedding columns.
+ """
+
+ def __init__(self, trainable=True):
+ # Dict of shared_embedding_collection_name to a dict of variables.
+ self._all_variables = {}
+ self._trainable = trainable
+
+ def get_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ initializer=None):
+ if not isinstance(feature_column, fc.SharedEmbeddingColumn):
+ raise ValueError(
+ 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. '
+ 'Given type: {} '.format(type(feature_column)))
+
+ collection_name = feature_column.shared_collection_name
+ if collection_name not in self._all_variables:
+ self._all_variables[collection_name] = {}
+ var_dict = self._all_variables[collection_name]
+ if name in var_dict:
+ return var_dict[name]
+ else:
+ var = variable_scope.get_variable(
+ name=name,
+ shape=shape,
+ initializer=initializer,
+ trainable=self._trainable)
+ var_dict[name] = var
+ return var
+
+
+class SharedEmbeddingColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ [categorical_column_b, categorical_column_a],
+ dimension=embedding_dimension)
+ self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
+ self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual(embedding_dimension, embedding_column_b.dimension)
+ self.assertEqual('mean', embedding_column_a.combiner)
+ self.assertEqual('mean', embedding_column_b.combiner)
+ self.assertIsNone(embedding_column_a.ckpt_to_load_from)
+ self.assertIsNone(embedding_column_b.ckpt_to_load_from)
+ self.assertEqual('aaa_bbb_shared_embedding',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('aaa_bbb_shared_embedding',
+ embedding_column_b.shared_collection_name)
+ self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
+ self.assertIsNone(embedding_column_a.max_norm)
+ self.assertIsNone(embedding_column_b.max_norm)
+ self.assertTrue(embedding_column_a.trainable)
+ self.assertTrue(embedding_column_b.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
+ self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
+ self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+ self.assertEqual({
+ 'bbb': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_b.parse_example_spec)
+
+ def test_all_constructor_args(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ combiner='my_combiner',
+ initializer=lambda: 'my_initializer',
+ shared_embedding_collection_name='shared_embedding_collection_name',
+ ckpt_to_load_from='my_ckpt',
+ tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42.,
+ trainable=False)
+ self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
+ self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual(embedding_dimension, embedding_column_b.dimension)
+ self.assertEqual('my_combiner', embedding_column_a.combiner)
+ self.assertEqual('my_combiner', embedding_column_b.combiner)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_b.shared_collection_name)
+ self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
+ self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
+ self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column_a.max_norm)
+ self.assertEqual(42., embedding_column_b.max_norm)
+ self.assertFalse(embedding_column_a.trainable)
+ self.assertFalse(embedding_column_b.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
+ self.assertEqual((embedding_dimension,), embedding_column_a.variable_shape)
+ self.assertEqual((embedding_dimension,), embedding_column_b.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+ self.assertEqual({
+ 'bbb': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_b.parse_example_spec)
+
+ def test_deep_copy(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ embedding_dimension = 2
+ original_a, _ = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ combiner='my_combiner',
+ initializer=lambda: 'my_initializer',
+ shared_embedding_collection_name='shared_embedding_collection_name',
+ ckpt_to_load_from='my_ckpt',
+ tensor_name_in_ckpt='my_ckpt_tensor',
+ max_norm=42., trainable=False)
+ for embedding_column_a in (original_a, copy.deepcopy(original_a)):
+ self.assertEqual('aaa', embedding_column_a.categorical_column.name)
+ self.assertEqual(3, embedding_column_a.categorical_column.num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.categorical_column.parse_example_spec)
+
+ self.assertEqual(embedding_dimension, embedding_column_a.dimension)
+ self.assertEqual('my_combiner', embedding_column_a.combiner)
+ self.assertEqual('shared_embedding_collection_name',
+ embedding_column_a.shared_collection_name)
+ self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
+ self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
+ self.assertEqual(42., embedding_column_a.max_norm)
+ self.assertFalse(embedding_column_a.trainable)
+ self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
+ self.assertEqual((embedding_dimension,),
+ embedding_column_a.variable_shape)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, embedding_column_a.parse_example_spec)
+
+ def test_invalid_initializer(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
+ fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2,
+ initializer='not_fn')
+
+ def test_incompatible_column_type(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ categorical_column_c = fc.categorical_column_with_hash_bucket(
+ key='ccc', hash_bucket_size=3)
+ with self.assertRaisesRegexp(
+ ValueError, 'all categorical_columns must have the same type.*'
+ 'IdentityCategoricalColumn.*HashedCategoricalColumn'):
+ fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b, categorical_column_c],
+ dimension=2)
+
+ def test_weighted_categorical_column_ok(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ weighted_categorical_column_a = fc.weighted_categorical_column(
+ categorical_column_a, weight_feature_key='aaa_weights')
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ weighted_categorical_column_b = fc.weighted_categorical_column(
+ categorical_column_b, weight_feature_key='bbb_weights')
+ fc.shared_embedding_columns(
+ [weighted_categorical_column_a, categorical_column_b], dimension=2)
+ fc.shared_embedding_columns(
+ [categorical_column_a, weighted_categorical_column_b], dimension=2)
+ fc.shared_embedding_columns(
+ [weighted_categorical_column_a, weighted_categorical_column_b],
+ dimension=2)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ b = fc.categorical_column_with_vocabulary_list(
+ key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_embedded, b_embedded = fc.shared_embedding_columns(
+ [a, b], dimension=2)
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ 'bbb':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'stringer', b'marlo'])),
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_embedded, b_embedded]))
+ self.assertIn('aaa', features)
+ self.assertIn('bbb', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'stringer', b'marlo'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['bbb'].eval())
+
+ def test_transform_feature(self):
+ a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
+ a_embedded, b_embedded = fc.shared_embedding_columns(
+ [a, b], dimension=2)
+ features = {
+ 'aaa': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb': sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ outputs = _transform_features(features, [a, a_embedded, b, b_embedded],
+ None)
+ output_a = outputs[a]
+ output_a_embedded = outputs[a_embedded]
+ output_b = outputs[b]
+ output_b_embedded = outputs[b_embedded]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self, output_a.eval(), output_a_embedded.eval())
+ _assert_sparse_tensor_value(
+ self, output_b.eval(), output_b_embedded.eval())
+
+ def test_get_dense_tensor(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ input_features = {
+ 'aaa': input_a,
+ 'bbb': input_b
+ }
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups_a = (
+ # example 0:
+ (7., 11.), # ids [2], embedding = [7, 11]
+ # example 1:
+ (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ )
+ expected_lookups_b = (
+ # example 0:
+ (1., 2.), # ids [0], embedding = [1, 2]
+ # example 1:
+ (0., 0.), # ids [], embedding = [0, 0]
+ )
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension, initializer=_initializer)
+ state_manager = _TestSharedEmbeddingStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+ embedding_lookup_b = embedding_column_b.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(('embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
+ embedding_var = global_vars[0]
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, embedding_var.eval())
+ self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
+ self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
+
+ def DISABLED_test_get_dense_tensor_weight_collections(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+ input_features = {'aaa': input_a, 'bbb': input_b}
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ fc.input_layer(
+ input_features, [embedding_column_a, embedding_column_b],
+ weight_collections=('my_vars',))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in global_vars))
+ my_vars = ops.get_collection('my_vars')
+ self.assertItemsEqual(
+ ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
+ tuple(v.name for v in my_vars))
+
+ def test_get_dense_tensor_placeholder_inputs(self):
+ # Inputs.
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+ # Specify shape, because dense input must have rank specified.
+ input_a_placeholder = array_ops.placeholder(
+ dtype=dtypes.int64, shape=[None, 3])
+ input_b_placeholder = array_ops.placeholder(
+ dtype=dtypes.int64, shape=[None, 3])
+ input_features = {
+ 'aaa': input_a_placeholder,
+ 'bbb': input_b_placeholder,
+ }
+ feed_dict = {
+ input_a_placeholder: input_a,
+ input_b_placeholder: input_b,
+ }
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Build columns.
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension, initializer=_initializer)
+ state_manager = _TestSharedEmbeddingStateManager()
+
+ # Provide sparse input and get dense result.
+ embedding_lookup_a = embedding_column_a.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+ embedding_lookup_b = embedding_column_b.get_dense_tensor(
+ FeatureTransformationCache(input_features), state_manager)
+
+ with _initialized_session() as sess:
+ sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
+
+ def test_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array(
+ [[2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]]) # example 1, ids [0, 1]
+ input_b = np.array(
+ [[0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights do not follow the column name. But this is a rare use
+ # case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v for v in ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+ # example 1, ids [], embedding[1] = 0, 0]
+ # sum(embeddings * linear_weights)
+ # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
+ def test_keras_linear_model(self):
+ # Inputs.
+ batch_size = 2
+ vocabulary_size = 3
+ # -1 values are ignored.
+ input_a = np.array([
+ [2, -1, -1], # example 0, ids [2]
+ [0, 1, -1]
+ ]) # example 1, ids [0, 1]
+ input_b = np.array([
+ [0, -1, -1], # example 0, ids [0]
+ [-1, -1, -1]
+ ]) # example 1, ids []
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_shape = (vocabulary_size, embedding_dimension)
+ zeros_embedding_values = np.zeros(embedding_shape)
+
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual(embedding_shape, shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return zeros_embedding_values
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer)
+
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ categorical_column_a.name: input_a,
+ categorical_column_b.name: input_b,
+ }, (embedding_column_a, embedding_column_b))
+ # Linear weights do not follow the column name. But this is a rare use
+ # case, and fixing it would add too much complexity to the code.
+ expected_var_names = (
+ 'linear_model/bias_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/weights:0',
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ )
+ self.assertItemsEqual(
+ expected_var_names,
+ [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
+ trainable_vars = {
+ v.name: v
+ for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ }
+ self.assertItemsEqual(expected_var_names, trainable_vars.keys())
+ bias = trainable_vars['linear_model/bias_weights:0']
+ embedding_weights = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ linear_weights_a = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ linear_weights_b = trainable_vars[
+ 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ with _initialized_session():
+ # Predictions with all zero weights.
+ self.assertAllClose(np.zeros((1,)), bias.eval())
+ self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
+ self.assertAllClose(
+ np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
+ self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
+
+ # Predictions with all non-zero weights.
+ embedding_weights.assign((
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )).eval()
+ linear_weights_a.assign(((4.,), (6.,))).eval()
+ # example 0, ids [2], embedding[0] = [7, 11]
+ # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # sum(embeddings * linear_weights)
+ # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
+ linear_weights_b.assign(((3.,), (5.,))).eval()
+ # example 0, ids [0], embedding[0] = [1, 2]
+ # example 1, ids [], embedding[1] = 0, 0]
+ # sum(embeddings * linear_weights)
+ # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
+ self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
+
+ def _test_input_layer(self, trainable=True):
+ # Inputs.
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 4)),
+ values=(2, 0, 1),
+ dense_shape=(2, 5))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [0]
+ # example 1, ids []
+ indices=((0, 0),),
+ values=(0,),
+ dense_shape=(2, 5))
+
+ # Embedding variable.
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ # Expected lookup result, using combiner='mean'.
+ expected_lookups = (
+ # example 0:
+ # A ids [2], embedding = [7, 11]
+ # B ids [0], embedding = [1, 2]
+ (7., 11., 1., 2.),
+ # example 1:
+ # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # B ids [], embedding = [0, 0]
+ (2., 3.5, 0., 0.),
+ )
+
+ # Build columns.
+ categorical_column_a = fc_old.categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ categorical_column_b = fc_old.categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b],
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=trainable)
+
+ # Provide sparse input and get dense result.
+ input_layer = fc.input_layer(
+ features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
+ feature_columns=(embedding_column_b, embedding_column_a))
+
+ # Assert expected embedding variable and lookups.
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ tuple([v.name for v in global_vars]))
+ trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+ if trainable:
+ self.assertItemsEqual(
+ ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ tuple([v.name for v in trainable_vars]))
+ else:
+ self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
+ shared_embedding_vars = global_vars
+ with _initialized_session():
+ self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
+ self.assertAllEqual(expected_lookups, input_layer.eval())
+
+ def test_input_layer(self):
+ self._test_input_layer()
+
+ def test_input_layer_no_trainable(self):
+ self._test_input_layer(trainable=False)
+
+
+class WeightedCategoricalColumnTest(test.TestCase):
+
+ def test_defaults(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ self.assertEqual('ids_weighted_by_values', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'ids': parsing_ops.VarLenFeature(dtypes.int64),
+ 'values': parsing_ops.VarLenFeature(dtypes.float32)
+ }, column.parse_example_spec)
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('ids_weighted_by_values', column.name)
+ self.assertEqual(3, column.num_buckets)
+ self.assertEqual({
+ 'ids': parsing_ops.VarLenFeature(dtypes.int64),
+ 'values': parsing_ops.VarLenFeature(dtypes.float32)
+ }, column.parse_example_spec)
+
+ def test_invalid_dtype_none(self):
+ with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values',
+ dtype=None)
+
+ def test_invalid_dtype_string(self):
+ with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values',
+ dtype=dtypes.string)
+
+ def test_invalid_input_dtype(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ strings = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Bad dtype'):
+ _transform_features({'ids': strings, 'values': strings}, (column,), None)
+
+ def test_column_name_collision(self):
+ with self.assertRaisesRegexp(ValueError, r'Parse config.*already exists'):
+ fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3),
+ weight_feature_key='aaa').parse_example_spec()
+
+ def test_missing_weights(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(
+ ValueError, 'values is not in features dictionary'):
+ _transform_features({'ids': inputs}, (column,), None)
+
+ def test_parse_example(self):
+ a = fc.categorical_column_with_vocabulary_list(
+ key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
+ a_weighted = fc.weighted_categorical_column(a, weight_feature_key='weights')
+ data = example_pb2.Example(features=feature_pb2.Features(
+ feature={
+ 'aaa':
+ feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
+ value=[b'omar', b'stringer'])),
+ 'weights':
+ feature_pb2.Feature(float_list=feature_pb2.FloatList(
+ value=[1., 10.]))
+ }))
+ features = parsing_ops.parse_example(
+ serialized=[data.SerializeToString()],
+ features=fc.make_parse_example_spec([a_weighted]))
+ self.assertIn('aaa', features)
+ self.assertIn('weights', features)
+ with self.test_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([b'omar', b'stringer'], dtype=np.object_),
+ dense_shape=[1, 2]),
+ features['aaa'].eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [0, 1]],
+ values=np.array([1., 10.], dtype=np.float32),
+ dense_shape=[1, 2]),
+ features['weights'].eval())
+
+ def test_transform_features(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ weights = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0.5, 1.0, 0.1),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': inputs,
+ 'values': weights,
+ }, (column,), None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(inputs.values, dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=weights.indices,
+ values=np.array(weights.values, dtype=np.float32),
+ dense_shape=weights.dense_shape),
+ weight_tensor.eval())
+
+ def test_transform_features_dense_input(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ weights = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0.5, 1.0, 0.1),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': ((0, -1), (1, 0)),
+ 'values': weights,
+ }, (column,), None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=weights.indices,
+ values=np.array(weights.values, dtype=np.float32),
+ dense_shape=weights.dense_shape),
+ weight_tensor.eval())
+
+ def test_transform_features_dense_weights(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 1, 0),
+ dense_shape=(2, 2))
+ id_tensor, weight_tensor = _transform_features({
+ 'ids': inputs,
+ 'values': ((.5, 0.), (1., .1)),
+ }, (column,), None)[column]
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(inputs.values, dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_tensor.eval())
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((.5, 1., .1), dtype=np.float32),
+ dense_shape=(2, 2)),
+ weight_tensor.eval())
+
+ def test_keras_linear_model(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_keras_linear_model_mismatched_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
+ get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_keras_linear_model_mismatched_dense_values(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
+ # Disabling the constant folding optimizer here since it changes the
+ # error message differently on CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_keras_linear_model_mismatched_dense_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = get_keras_linear_model_predictions({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_linear_model(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(.5, 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ def test_linear_model_mismatched_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError, r'Dimensions.*are not compatible'):
+ fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (0, 1), (1, 0), (1, 1)),
+ values=(.5, 11., 1., .1),
+ dense_shape=(2, 2))
+ }, (column,))
+
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
+ # Disabling the constant folding optimizer here since it changes the
+ # error message differently on CPU and GPU.
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.constant_folding = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ with _initialized_session(config):
+ with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
+ predictions.eval()
+
+ def test_linear_model_mismatched_dense_shape(self):
+ column = fc_old.weighted_categorical_column(
+ categorical_column=fc_old.categorical_column_with_identity(
+ key='ids', num_buckets=3),
+ weight_feature_key='values')
+ with ops.Graph().as_default():
+ predictions = fc.linear_model({
+ 'ids': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,), (.1,))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] * weights[0, 0] = 1 * .5 = .5
+ # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
+ # = 3*1 + 2*.1 = 3+.2 = 3.2
+ self.assertAllClose(((.5,), (3.2,)), predictions.eval())
+
+ # TODO(ptucker): Add test with embedding of weighted categorical.
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index 3c5aebbce8..40788e24c4 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -28,6 +28,18 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+def has_fully_defined_shape(tensor):
+ """Returns true if tensor has a fully defined shape."""
+ return isinstance(tensor, ops.EagerTensor) or tensor.shape.is_fully_defined()
+
+
+def rank(tensor):
+ """Return a rank if it is a tensor, else return None."""
+ if isinstance(tensor, ops.Tensor):
+ return tensor._rank() # pylint: disable=protected-access
+ return None
+
+
def scalar_shape(unused_op):
"""Shape function for ops that output a scalar value."""
return [tensor_shape.scalar()]
diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py
new file mode 100644
index 0000000000..9ccae76147
--- /dev/null
+++ b/tensorflow/python/framework/error_interpolation.py
@@ -0,0 +1,92 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Function for interpolating formatted errors from the TensorFlow runtime.
+
+Exposes the function `interpolate` to interpolate messages with tags of the form
+^^type:name:format^^.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+import re
+import string
+
+import six
+
+_NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?"
+_FORMAT_REGEX = r"[A-Za-z0-9_.\-/${}:]+"
+_TAG_REGEX = r"\^\^({name}):({name}):({fmt})\^\^".format(
+ name=_NAME_REGEX, fmt=_FORMAT_REGEX)
+_INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX)
+_INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX)
+
+_ParseTag = collections.namedtuple("_ParseTag", ["type", "name", "format"])
+
+
+def _parse_message(message):
+ """Parses the message.
+
+ Splits the message into separators and tags. Tags are named tuples
+ representing the string ^^type:name:format^^ and they are separated by
+ separators. For example, in
+ "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789", there are two tags and
+ three separators. The separators are the numeric characters.
+
+ Args:
+ message: String to parse
+
+ Returns:
+ (list of separator strings, list of _ParseTags).
+
+ For example, if message is "123^^node:Foo:${file}^^456" then this function
+ returns (["123", "456"], [_ParseTag("node", "Foo", "${file}")])
+ """
+ seps = []
+ tags = []
+ pos = 0
+ while pos < len(message):
+ match = re.match(_INTERPOLATION_PATTERN, message[pos:])
+ if match:
+ seps.append(match.group(1))
+ tags.append(_ParseTag(match.group(3), match.group(4), match.group(5)))
+ pos += match.end()
+ else:
+ break
+ seps.append(message[pos:])
+ return seps, tags
+
+
+# TODO(jtkeeling): Modify to actually interpolate format strings rather than
+# echoing them.
+def interpolate(error_message):
+ """Interpolates an error message.
+
+ The error message can contain tags of the form ^^type:name:format^^ which will
+ be replaced.
+
+ Args:
+ error_message: A string to interpolate.
+
+ Returns:
+ The string with tags of the form ^^type:name:format^^ interpolated.
+ """
+ seps, tags = _parse_message(error_message)
+ subs = [string.Template(tag.format).safe_substitute({}) for tag in tags]
+ return "".join(
+ itertools.chain(*six.moves.zip_longest(seps, subs, fillvalue="")))
diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py
new file mode 100644
index 0000000000..ad448deb62
--- /dev/null
+++ b/tensorflow/python/framework/error_interpolation_test.py
@@ -0,0 +1,49 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.errors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import error_interpolation
+from tensorflow.python.platform import test
+
+
+class InterpolateTest(test.TestCase):
+
+ def testNothingToDo(self):
+ normal_string = "This is just a normal string"
+ interpolated_string = error_interpolation.interpolate(normal_string)
+ self.assertEqual(interpolated_string, normal_string)
+
+ def testOneTag(self):
+ one_tag_string = "^^node:Foo:${file}^^"
+ interpolated_string = error_interpolation.interpolate(one_tag_string)
+ self.assertEqual(interpolated_string, "${file}")
+
+ def testTwoTagsNoSeps(self):
+ two_tags_no_seps = "^^node:Foo:${file}^^^^node:Bar:${line}^^"
+ interpolated_string = error_interpolation.interpolate(two_tags_no_seps)
+ self.assertEqual(interpolated_string, "${file}${line}")
+
+ def testTwoTagsWithSeps(self):
+ two_tags_with_seps = "123^^node:Foo:${file}^^456^^node:Bar:${line}^^789"
+ interpolated_string = error_interpolation.interpolate(two_tags_with_seps)
+ self.assertEqual(interpolated_string, "123${file}456${line}789")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 259cab6699..6525607fae 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -23,6 +23,7 @@ from __future__ import print_function
import collections
import hashlib
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -33,12 +34,17 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+# This is to avoid a circular dependency with cond_v2_impl.
+cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+
class Defun(object):
"""Decorator used to define TensorFlow functions.
@@ -650,6 +656,41 @@ class _FuncGraph(ops.Graph):
# TODO(skyewm): is this needed?
self.extra_vars = []
+ # pylint: disable=g-doc-return-or-yield
+
+ @tf_contextlib.contextmanager
+ def container(self, container_name):
+ """Returns a context manager that specifies the resource container to use.
+
+ Overridden from @{tf.Graph} to update both the init_scope container
+ and the present inner container. This is necessary to make sure setting
+ containers applies correctly both to created variables and to stateful
+ ops.
+
+ Args:
+ container_name: container name string.
+
+ Returns:
+ A context manager for defining resource containers for stateful ops,
+ yields the container name.
+ """
+ original_container = self._container
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ original_init_container = ops.get_default_graph()._container
+ try:
+ self._container = container_name
+ with ops.init_scope():
+ ops.get_default_graph()._container = container_name
+ yield self._container
+ finally:
+ self._container = original_container
+ with ops.init_scope():
+ ops.get_default_graph()._container = original_init_container
+ # pylint: enable=protected-access
+
+ # pylint: enable=g-doc-return-or-yield
+
def getvar(
self,
getter,
@@ -720,6 +761,8 @@ class _FuncGraph(ops.Graph):
if ops._USE_C_SHAPES:
if isinstance(tensor, ops.EagerTensor):
handle_data = tensor._handle_data
+ if handle_data:
+ handle_data = handle_data.SerializeToString()
else:
handle_data = c_api.GetResourceHandleShapeAndType(
tensor.graph._c_graph, tensor._as_tf_output())
@@ -771,7 +814,9 @@ class _FuncGraph(ops.Graph):
def func_graph_from_py_func(func, arg_names, arg_types, name=None,
- capture_by_value=False, device=None):
+ capture_by_value=False, device=None,
+ colocation_stack=None, container=None,
+ collections_ref=None):
"""Returns a _FuncGraph generated from `func`.
Args:
@@ -784,6 +829,10 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
capture_by_value: boolean. If True, captured values will be copied into the
function body.
device: device name or function.
+ colocation_stack: A colocation stack (list) the _FuncGraph should use.
+ container: A container name the _FuncGraph should start with.
+ collections_ref: A reference to a collections dict the _FuncGraph should
+ use internally.
Returns:
A _FuncGraph.
@@ -794,7 +843,17 @@ def func_graph_from_py_func(func, arg_names, arg_types, name=None,
if not name:
name = _get_func_name(func)
func_graph = _FuncGraph(name, capture_by_value)
+
with func_graph.as_default(), ops.device(device):
+ # pylint: disable=protected-access
+ if collections_ref is not None:
+ func_graph._collections = collections_ref
+ if container is not None:
+ func_graph._container = container
+ if colocation_stack is not None:
+ func_graph._colocation_stack = colocation_stack
+ # pylint: enable=protected-access
+
# Create placeholders for the function arguments.
for (argname, argtype) in zip(arg_names, arg_types):
argholder = array_ops.placeholder(argtype, name=argname)
@@ -1170,3 +1229,13 @@ _DTYPE_TO_STR = {
dtypes.qint32: "qi32",
dtypes.bfloat16: "b16"
}
+
+
+def function_def_from_tf_function(c_func):
+ """Converts a SWIG-wrapped TF_Function* to a FunctionDef proto."""
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_FunctionToFunctionDef(c_func, buf)
+ data = c_api.TF_GetBuffer(buf)
+ fdef = function_pb2.FunctionDef()
+ fdef.ParseFromString(compat.as_bytes(data))
+ return fdef
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 4fecc41343..46c9c4c14a 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
@@ -25,6 +27,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import versions
+from tensorflow.python.ops import cond_v2_impl
+
+# This is to avoid a circular dependency with cond_v2_impl.
+cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=protected-access
def function_def_to_graph(fdef, input_shapes=None):
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 72eb7e0eeb..699d2b70d1 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -407,11 +407,11 @@ def import_graph_def(graph_def,
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
return_elements)
- # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
- # call cannot occur between creating the TF_Operations in the
+ # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
+ # Session.run call cannot occur between creating the TF_Operations in the
# TF_GraphImportGraphDefWithResults call and mutating the them in
# _ProcessNewOps.
- with graph._lock: # pylint: disable=protected-access
+ with graph._mutation_lock(): # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
try:
results = c_api.TF_GraphImportGraphDefWithResults(
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index a19a72c881..b07c57d265 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -55,14 +55,16 @@ from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import lock_util
from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
-_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "0") is not "0"
+_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0"
def tensor_id(tensor):
@@ -288,15 +290,8 @@ class Tensor(_TensorLike):
self._value_index = value_index
self._dtype = dtypes.as_dtype(dtype)
- if _USE_C_API:
- # This will be set by set_shape_and_handle_data_for_outputs.
- self._shape_val = None
- else:
- # The Python code requires all tensors start with a shape to support shape
- # inference on imported while loops. This isn't necessary with the C API
- # enabled because the C API provides the shapes for imported nodes.
- # TODO(skyewm): remove when _USE_C_API is removed.
- self._shape_val = tensor_shape.unknown_shape()
+ # This will be set by self.shape().
+ self._shape_val = None
# List of operations that use this Tensor as input. We maintain this list
# to easily navigate a computation graph.
@@ -384,7 +379,6 @@ class Tensor(_TensorLike):
if _USE_C_SHAPES:
self._shape_val = self._c_api_shape()
else:
- assert _USE_C_API
# Call set_shape_and_handle_data_for_outputs in topological order on all
# ops that are needed to compute self.op's shape. We do this instead of
# having set_shape_and_handle_data_for_outputs recursively call
@@ -508,8 +502,6 @@ class Tensor(_TensorLike):
else:
self._shape_val = self.shape.merge_with(shape)
- if not self._op._graph._c_graph: return
-
# Update C shape even if _USE_C_SHAPES = False, since we still want
# set_shape to be reflected in the C API graph for when we run it.
if not isinstance(shape, tensor_shape.TensorShape):
@@ -545,33 +537,14 @@ class Tensor(_TensorLike):
Returns:
A list of `Operation`s.
"""
- if self._op._c_op: # pylint: disable=protected-access
- consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
- self._as_tf_output())
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(name)
- for name in consumer_names
- ]
- # pylint: enable=protected-access
- else:
- return self._consumers
-
- def _add_consumer(self, consumer):
- """Add a consumer to this tensor.
-
- Args:
- consumer: an Operation.
-
- Raises:
- TypeError: if the consumer is not an Operation.
- """
+ consumer_names = c_api.TF_OperationOutputConsumers_wrapper(
+ self._as_tf_output())
# pylint: disable=protected-access
- assert not self._op._c_op, "Tensor._add_consumer doesn't work with C API"
+ return [
+ self.graph._get_operation_by_name_unsafe(name)
+ for name in consumer_names
+ ]
# pylint: enable=protected-access
- if not isinstance(consumer, Operation):
- raise TypeError("Consumer must be an Operation: %s" % consumer)
- self._consumers.append(consumer)
def _as_node_def_input(self):
"""Return a value to use for the NodeDef "input" attribute.
@@ -594,7 +567,6 @@ class Tensor(_TensorLike):
def _as_tf_output(self):
# pylint: disable=protected-access
- assert self.op._c_op
return c_api_util.tf_output(self.op._c_op, self.value_index)
# pylint: enable=protected-access
@@ -734,7 +706,7 @@ class _EagerTensorBase(Tensor):
"""
if self.dtype == dtypes.resource:
raise ValueError("Resource handles are not convertible to numpy.")
- return self.cpu()._numpy() # pylint: disable=protected-access
+ return self._cpu_nograd()._numpy() # pylint: disable=protected-access
# __int__ and __float__ may copy the tensor to CPU and
# only work for scalars; values are cast as per numpy.
@@ -808,8 +780,8 @@ class _EagerTensorBase(Tensor):
def _override_operator(name, func):
setattr(_EagerTensorBase, name, func)
- def _copy(self, ctx=None, device_name=None):
- """Copies tensor to dest device."""
+ def _copy_nograd(self, ctx=None, device_name=None):
+ """Copies tensor to dest device, but doesn't record the operation."""
# pylint: disable=protected-access
# Creates a new tensor on the dest device.
if ctx is None:
@@ -821,7 +793,11 @@ class _EagerTensorBase(Tensor):
new_tensor = self._copy_to_device(context=ctx._handle, device=device_name)
except core._NotOkStatusException as e:
six.raise_from(core._status_to_exception(e.code, e.message), None)
+ return new_tensor
+ def _copy(self, ctx=None, device_name=None):
+ """Copies tensor to dest device."""
+ new_tensor = self._copy_nograd(ctx, device_name)
# Record the copy on tape and define backprop copy as well.
if context.executing_eagerly():
self_device = self.device
@@ -852,6 +828,16 @@ class _EagerTensorBase(Tensor):
"""Returns the number of Tensor dimensions."""
return self.shape.ndims
+ def _cpu_nograd(self):
+ """A copy of this Tensor with contents backed by host memory.
+
+ The copy cannot be differentiated through.
+
+ Returns:
+ A CPU-memory backed Tensor object with the same contents as this Tensor.
+ """
+ return self._copy_nograd(context.context(), "CPU:0")
+
def cpu(self):
"""A copy of this Tensor with contents backed by host memory."""
return self._copy(context.context(), "CPU:0")
@@ -1722,18 +1708,8 @@ class Operation(object):
"a Tensor, or IndexedSlices: %s" % c)
control_input_ops.append(control_op)
- # Don't set private fields with C API enabled to catch users who need to
- # switch to public API.
- # TODO(skyewm): delete these fields once we remove _USE_C_API
- if not self._graph._c_graph:
- self._inputs_val = list(inputs) # Defensive copy.
- self._input_types_val = input_types
- self._control_inputs_val = control_input_ops
- self._node_def_val = copy.deepcopy(node_def)
- self._op_def_val = op_def
- else:
- # This will be set by self.inputs.
- self._inputs_val = None
+ # This will be set by self.inputs.
+ self._inputs_val = None
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._original_op = original_op
@@ -1742,10 +1718,8 @@ class Operation(object):
# Initialize self._c_op.
if c_op:
- # TODO(skyewm): remove this assert when we remove USE_C_API
- assert self._graph._c_graph # pylint: disable=protected-access
self._c_op = c_op
- elif self._graph._c_graph: # pylint: disable=protected-access
+ else:
if op_def is None:
op_def = self._graph._get_op_def(node_def.op)
# TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs.
@@ -1754,30 +1728,19 @@ class Operation(object):
op_def, inputs, node_def.attr)
self._c_op = _create_c_op(self._graph, node_def, grouped_inputs,
control_input_ops)
- else:
- self._c_op = None
-
- # Mark that we consume the inputs. This is unnecessary and unsupported with
- # the C API enabled, since the C API tracks the tensor consumers instead.
- if not self._c_op:
- for input_tensor in self._inputs_val:
- input_tensor._add_consumer(self) # pylint: disable=protected-access
# Initialize self._outputs.
- if self._c_op:
- num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
- output_types = [
- c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
- for i in range(num_outputs)]
- assert output_types is not None
- elif output_types is None:
- output_types = []
- self._output_types_val = output_types
+ num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
+ output_types = [
+ c_api.TF_OperationOutputType(c_api_util.tf_output(self._c_op, i))
+ for i in range(num_outputs)]
self._outputs = [
Tensor(self, i, output_type)
for i, output_type in enumerate(output_types)
]
+ self._graph._add_op(self) # pylint: disable=protected-access
+
if not c_op:
self._control_flow_post_processing()
@@ -1791,7 +1754,6 @@ class Operation(object):
control_flow_util.CheckInputFromValidContext(self, input_tensor.op)
if self._control_flow_context is not None:
self._control_flow_context.AddOp(self)
- self._recompute_node_def()
def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
"""Regroups a flat list of input tensors into scalar and sequence inputs.
@@ -1872,10 +1834,7 @@ class Operation(object):
@property
def name(self):
"""The full name of this operation."""
- if self._c_op:
- return c_api.TF_OperationName(self._c_op)
- else:
- return self._node_def_val.name
+ return c_api.TF_OperationName(self._c_op)
@property
def _id(self):
@@ -1891,10 +1850,7 @@ class Operation(object):
assigned, or an empty string if it has not been assigned to a
device.
"""
- if self._c_op:
- return c_api.TF_OperationDevice(self._c_op)
- else:
- return self._node_def_val.device
+ return c_api.TF_OperationDevice(self._c_op)
@property
def _output_types(self):
@@ -1907,28 +1863,21 @@ class Operation(object):
The length of this list indicates the number of output endpoints
of the operation.
"""
- if self._c_op:
- num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
- output_types = [
- c_api.TF_OperationOutputType(self._tf_output(i))
- for i in xrange(num_outputs)
- ]
- # TODO(iga): Remove this assert after converting to C API by default.
- # Just being a bit paranoid here.
- assert self._output_types_val == output_types
- # In all the tests we have output_types that are passed into
- # Operation.__init__ are a list of ints (which is illegal according
- # to the docstring), but input_types are instances of DType.
- # This extra assert is to catch if we ever use DType for output_types.
- if output_types:
- assert isinstance(output_types[0], int)
- return output_types
- else:
- return self._output_types_val
+ num_outputs = c_api.TF_OperationNumOutputs(self._c_op)
+ output_types = [
+ c_api.TF_OperationOutputType(self._tf_output(i))
+ for i in xrange(num_outputs)
+ ]
+ # In all the tests we have output_types that are passed into
+ # Operation.__init__ are a list of ints (which is illegal according
+ # to the docstring), but input_types are instances of DType.
+ # This extra assert is to catch if we ever use DType for output_types.
+ if output_types:
+ assert isinstance(output_types[0], int)
+ return output_types
def _tf_output(self, output_idx):
"""Create and return a new TF_Output for output_idx'th output of this op."""
- assert self._c_op
tf_output = c_api.TF_Output()
tf_output.oper = self._c_op
tf_output.index = output_idx
@@ -1936,7 +1885,6 @@ class Operation(object):
def _tf_input(self, input_idx):
"""Create and return a new TF_Input for input_idx'th input of this op."""
- assert self._c_op
tf_input = c_api.TF_Input()
tf_input.oper = self._c_op
tf_input.index = input_idx
@@ -1948,47 +1896,12 @@ class Operation(object):
Args:
device: string or device.. The device to set.
"""
- if self._c_op:
- c_api.SetRequestedDevice(
- self._graph._c_graph, # pylint: disable=protected-access
- self._c_op, # pylint: disable=protected-access
- compat.as_str(_device_string(device)))
- else:
- self._node_def_val.device = _device_string(device)
-
- def _add_input(self, tensor, dtype=None):
- """Add a new input to this operation.
+ c_api.SetRequestedDevice(
+ self._graph._c_graph, # pylint: disable=protected-access
+ self._c_op, # pylint: disable=protected-access
+ compat.as_str(_device_string(device)))
- Args:
- tensor: the Tensor to add as an input.
- dtype: tf.DType: type of the input; defaults to
- the tensor's dtype.
-
- Raises:
- TypeError: if tensor is not a Tensor,
- or if input tensor type is not convertible to dtype.
- ValueError: if the Tensor is from a different graph.
- """
- assert not self._c_op, (
- "Operation._add_input doesn't work with C API")
- if not isinstance(tensor, Tensor):
- raise TypeError("tensor must be a Tensor: %s" % tensor)
- _assert_same_graph(self, tensor)
- if dtype is None:
- dtype = tensor.dtype
- else:
- dtype = dtypes.as_dtype(dtype)
- if not dtype.is_compatible_with(tensor.dtype):
- raise TypeError(
- "Cannot convert a tensor of type %s to an input of type %s" %
- (tensor.dtype.name, dtype.name))
- self._inputs_val.append(tensor)
- self._input_types_val.append(dtype)
- tensor._add_consumer(self) # pylint: disable=protected-access
- self._recompute_node_def()
-
- # TODO(skyewm): Remove `update_dtype` when we enable the C API.
- def _update_input(self, index, tensor, update_dtype=True):
+ def _update_input(self, index, tensor):
"""Update the input to this operation at the given index.
NOTE: This is for TF internal use only. Please don't use it.
@@ -1996,7 +1909,6 @@ class Operation(object):
Args:
index: the index of the input to update.
tensor: the Tensor to be used as the input at the given index.
- update_dtype: If `False`, the type for this input is not updated.
Raises:
TypeError: if tensor is not a Tensor,
@@ -2013,20 +1925,12 @@ class Operation(object):
if not _USE_C_SHAPES:
set_shape_and_handle_data_for_outputs(self)
- if self._c_op:
- # Reset cached inputs.
- self._inputs_val = None
- c_api.UpdateEdge(
- self._graph._c_graph, # pylint: disable=protected-access
- tensor._as_tf_output(), # pylint: disable=protected-access
- self._tf_input(index))
- else:
- self._inputs_val[index].consumers().remove(self)
- self._inputs_val[index] = tensor
- if update_dtype:
- self._input_types_val[index] = tensor.dtype
- tensor._add_consumer(self) # pylint: disable=protected-access
- self._recompute_node_def()
+ # Reset cached inputs.
+ self._inputs_val = None
+ c_api.UpdateEdge(
+ self._graph._c_graph, # pylint: disable=protected-access
+ tensor._as_tf_output(), # pylint: disable=protected-access
+ self._tf_input(index))
def _add_control_inputs(self, ops):
"""Add a list of new control inputs to this operation.
@@ -2038,19 +1942,10 @@ class Operation(object):
TypeError: if ops is not a list of Operations.
ValueError: if any op in ops is from a different graph.
"""
- if self._c_op:
- for op in ops:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
- else:
- if ops:
- for op in ops:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- _assert_same_graph(self, op)
- self._control_inputs_val.append(op)
- self._recompute_node_def()
+ for op in ops:
+ if not isinstance(op, Operation):
+ raise TypeError("op must be an Operation: %s" % op)
+ c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
def _add_control_input(self, op):
"""Add a new control input to this operation.
@@ -2062,33 +1957,13 @@ class Operation(object):
TypeError: if op is not an Operation.
ValueError: if op is from a different graph.
"""
- if self._c_op:
- if not isinstance(op, Operation):
- raise TypeError("op must be an Operation: %s" % op)
- c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
- else:
- self._add_control_inputs([op])
+ if not isinstance(op, Operation):
+ raise TypeError("op must be an Operation: %s" % op)
+ c_api.AddControlInput(self._graph._c_graph, self._c_op, op._c_op) # pylint: disable=protected-access
def _remove_all_control_inputs(self):
"""Removes any control inputs to this operation."""
- if self._c_op:
- c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
- else:
- del self.control_inputs[:]
-
- # Methods below are used when building the NodeDef and Graph proto.
- def _recompute_node_def(self):
- # TODO(skyewm): remove this function when we switch to C API
- if self._c_op: return
-
- del self._node_def_val.input[:]
- # pylint: disable=protected-access
- self._node_def_val.input.extend(
- [t._as_node_def_input() for t in self._inputs_val])
- # pylint: enable=protected-access
- if self._control_inputs_val:
- self._node_def_val.input.extend(
- ["^%s" % op.name for op in self._control_inputs_val])
+ c_api.RemoveAllControlInputs(self._graph._c_graph, self._c_op) # pylint: disable=protected-access
def __str__(self):
return str(self.node_def)
@@ -2129,19 +2004,16 @@ class Operation(object):
@property
def inputs(self):
"""The list of `Tensor` objects representing the data inputs of this op."""
- if self._c_op:
- if self._inputs_val is None:
- tf_outputs = c_api.GetOperationInputs(self._c_op)
- # pylint: disable=protected-access
- retval = [
- self.graph._get_tensor_by_tf_output(tf_output)
- for tf_output in tf_outputs
- ]
- # pylint: enable=protected-access
- self._inputs_val = Operation._InputList(retval)
- return self._inputs_val
- else:
- return Operation._InputList(self._inputs_val)
+ if self._inputs_val is None:
+ tf_outputs = c_api.GetOperationInputs(self._c_op)
+ # pylint: disable=protected-access
+ retval = [
+ self.graph._get_tensor_by_tf_output(tf_output)
+ for tf_output in tf_outputs
+ ]
+ # pylint: enable=protected-access
+ self._inputs_val = Operation._InputList(retval)
+ return self._inputs_val
@property
def _inputs(self):
@@ -2155,15 +2027,12 @@ class Operation(object):
@property
def _input_types(self):
- if self._c_op:
- num_inputs = c_api.TF_OperationNumInputs(self._c_op)
- input_types = [
- dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
- for i in xrange(num_inputs)
- ]
- return input_types
- else:
- return self._input_types_val
+ num_inputs = c_api.TF_OperationNumInputs(self._c_op)
+ input_types = [
+ dtypes.as_dtype(c_api.TF_OperationInputType(self._tf_input(i)))
+ for i in xrange(num_inputs)
+ ]
+ return input_types
@_input_types.setter
def _input_types(self, value):
@@ -2183,16 +2052,13 @@ class Operation(object):
A list of `Operation` objects.
"""
- if self._c_op:
- control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(
- c_api.TF_OperationName(c_op)) for c_op in control_c_ops
- ]
- # pylint: enable=protected-access
- else:
- return self._control_inputs_val
+ control_c_ops = c_api.TF_OperationGetControlInputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
@property
def _control_outputs(self):
@@ -2205,18 +2071,13 @@ class Operation(object):
A list of `Operation` objects.
"""
- if self._c_op:
- control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
- # pylint: disable=protected-access
- return [
- self.graph._get_operation_by_name_unsafe(
- c_api.TF_OperationName(c_op)) for c_op in control_c_ops
- ]
- # pylint: enable=protected-access
- else:
- # TODO(apassos) this should be less inefficient.
- return [o for o in self._graph.get_operations()
- if self in o.control_inputs]
+ control_c_ops = c_api.TF_OperationGetControlOutputs_wrapper(self._c_op)
+ # pylint: disable=protected-access
+ return [
+ self.graph._get_operation_by_name_unsafe(
+ c_api.TF_OperationName(c_op)) for c_op in control_c_ops
+ ]
+ # pylint: enable=protected-access
@property
def _control_inputs(self):
@@ -2240,11 +2101,7 @@ class Operation(object):
@property
def type(self):
"""The type of the op (e.g. `"MatMul"`)."""
- if self._c_op:
- op_type = c_api.TF_OperationOpType(self._c_op)
- return op_type
- else:
- return self._node_def_val.op
+ return c_api.TF_OperationOpType(self._c_op)
@property
def graph(self):
@@ -2262,15 +2119,12 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
- if self._c_op:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_OperationToNodeDef(self._c_op, buf)
- data = c_api.TF_GetBuffer(buf)
- node_def = node_def_pb2.NodeDef()
- node_def.ParseFromString(compat.as_bytes(data))
- return node_def
- else:
- return self._node_def_val
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_OperationToNodeDef(self._c_op, buf)
+ data = c_api.TF_GetBuffer(buf)
+ node_def = node_def_pb2.NodeDef()
+ node_def.ParseFromString(compat.as_bytes(data))
+ return node_def
@property
def _node_def(self):
@@ -2289,10 +2143,7 @@ class Operation(object):
protocol buffer.
"""
# pylint: enable=line-too-long
- if self._c_op:
- return self._graph._get_op_def(self.type)
- else:
- return self._op_def_val
+ return self._graph._get_op_def(self.type)
@property
def _op_def(self):
@@ -2318,17 +2169,14 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if self._c_op:
- buf = c_api.TF_NewBufferFromString(
- compat.as_bytes(attr_value.SerializeToString()))
- try:
- # pylint: disable=protected-access
- c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
- # pylint: enable=protected-access
- finally:
- c_api.TF_DeleteBuffer(buf)
- else:
- self._node_def_val.attr[attr_name].CopyFrom(attr_value)
+ buf = c_api.TF_NewBufferFromString(
+ compat.as_bytes(attr_value.SerializeToString()))
+ try:
+ # pylint: disable=protected-access
+ c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf)
+ # pylint: enable=protected-access
+ finally:
+ c_api.TF_DeleteBuffer(buf)
def get_attr(self, name):
"""Returns the value of the attr of this op with the given `name`.
@@ -2343,21 +2191,15 @@ class Operation(object):
ValueError: If this op does not have an attr with the given `name`.
"""
fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
- if self._c_op:
- try:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
- data = c_api.TF_GetBuffer(buf)
- except errors.InvalidArgumentError as e:
- # Convert to ValueError for backwards compatibility.
- raise ValueError(str(e))
- x = attr_value_pb2.AttrValue()
- x.ParseFromString(data)
- else:
- if name not in self._node_def_val.attr:
- raise ValueError(
- "No attr named '" + name + "' in " + str(self._node_def_val))
- x = self._node_def_val.attr[name]
+ try:
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
+ data = c_api.TF_GetBuffer(buf)
+ except errors.InvalidArgumentError as e:
+ # Convert to ValueError for backwards compatibility.
+ raise ValueError(str(e))
+ x = attr_value_pb2.AttrValue()
+ x.ParseFromString(data)
# Treat an empty oneof value as an empty list.
if not x.WhichOneof("value"):
@@ -2577,9 +2419,9 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
def set_shape_and_handle_data_for_outputs(op):
"""Set the shapes and resource handle data for op's outputs.
- When _USE_C_API = True, this is lazily called when a tensor's shape is first
- requested. Usually this should work automatically, but some edge cases may
- require manually calling this first to make sure Tensor._shape_val and
+ When _USE_C_SHAPES = False, this is lazily called when a tensor's shape is
+ first requested. Usually this should work automatically, but some edge cases
+ may require manually calling this first to make sure Tensor._shape_val and
Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
Tensor).
"""
@@ -2772,6 +2614,10 @@ def _name_from_scope_name(name):
return name[:-1] if (name and name[-1] == "/") else name
+_MUTATION_LOCK_GROUP = 0
+_SESSION_RUN_LOCK_GROUP = 1
+
+
@tf_export("Graph")
class Graph(object):
"""A TensorFlow computation, represented as a dataflow graph.
@@ -2821,20 +2667,21 @@ class Graph(object):
def __init__(self):
"""Creates a new, empty Graph."""
- # Protects core state that can be returned via public accessors, as well as
- # synchronizes Session.run calls with methods that create and mutate ops
- # (e.g. Graph.create_op()). This synchronization is necessary because it's
- # illegal to modify an operation after it's been run. Thread-safety is
- # provided on a best-effort basis to support buggy programs, and is not
- # guaranteed by the public `tf.Graph` API.
- #
- # The lock must be reentrant because create_op can be called recursively due
- # to control flow. Without a reentrant lock, many methods would also need a
- # "locked" version or parameter (including generated code).
+ # Protects core state that can be returned via public accessors.
+ # Thread-safety is provided on a best-effort basis to support buggy
+ # programs, and is not guaranteed by the public `tf.Graph` API.
#
# NOTE(mrry): This does not protect the various stacks. A warning will
# be reported if these are used from multiple threads
self._lock = threading.RLock()
+ # The group lock synchronizes Session.run calls with methods that create
+ # and mutate ops (e.g. Graph.create_op()). This synchronization is
+ # necessary because it's illegal to modify an operation after it's been run.
+ # The group lock allows any number of threads to mutate ops at the same time
+ # but if any modification is going on, all Session.run calls have to wait.
+ # Similarly, if one or more Session.run calls are going on, all mutate ops
+ # have to wait until all Session.run calls have finished.
+ self._group_lock = lock_util.GroupLock(num_groups=2)
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
@@ -3083,15 +2930,12 @@ class Graph(object):
A `VersionDef`.
"""
# pylint: enable=line-too-long
- if self._c_graph:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_GraphVersions(self._c_graph, buf)
- data = c_api.TF_GetBuffer(buf)
- version_def = versions_pb2.VersionDef()
- version_def.ParseFromString(compat.as_bytes(data))
- return version_def
- else:
- return self._graph_def_versions
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_GraphVersions(self._c_graph, buf)
+ data = c_api.TF_GetBuffer(buf)
+ version_def = versions_pb2.VersionDef()
+ version_def.ParseFromString(compat.as_bytes(data))
+ return version_def
@property
def seed(self):
@@ -3185,40 +3029,22 @@ class Graph(object):
"""
# pylint: enable=line-too-long
- if self._c_graph:
- with self._lock:
- with c_api_util.tf_buffer() as buf:
- c_api.TF_GraphToGraphDef(self._c_graph, buf)
- data = c_api.TF_GetBuffer(buf)
- graph = graph_pb2.GraphDef()
- graph.ParseFromString(compat.as_bytes(data))
- # Strip the experimental library field iff it's empty.
- if not graph.library.function:
- graph.ClearField("library")
-
- if add_shapes:
- for node in graph.node:
- op = self._nodes_by_name[node.name]
- if op.outputs:
- node.attr["_output_shapes"].list.shape.extend(
- [output.get_shape().as_proto() for output in op.outputs])
- else:
- with self._lock:
- graph = graph_pb2.GraphDef()
- graph.versions.CopyFrom(self._graph_def_versions)
- bytesize = 0
- for op_id in sorted(self._nodes_by_id):
- op = self._nodes_by_id[op_id]
- if from_version is None or op_id > from_version:
- graph.node.extend([op.node_def])
- if op.outputs and add_shapes:
- assert "_output_shapes" not in graph.node[-1].attr
- graph.node[-1].attr["_output_shapes"].list.shape.extend(
- [output.get_shape().as_proto() for output in op.outputs])
- bytesize += op.node_def.ByteSize()
- if bytesize >= (1 << 31) or bytesize < 0:
- raise ValueError("GraphDef cannot be larger than 2GB.")
- self._copy_functions_to_graph_def(graph, bytesize)
+ with self._lock:
+ with c_api_util.tf_buffer() as buf:
+ c_api.TF_GraphToGraphDef(self._c_graph, buf)
+ data = c_api.TF_GetBuffer(buf)
+ graph = graph_pb2.GraphDef()
+ graph.ParseFromString(compat.as_bytes(data))
+ # Strip the experimental library field iff it's empty.
+ if not graph.library.function:
+ graph.ClearField("library")
+
+ if add_shapes:
+ for node in graph.node:
+ op = self._nodes_by_name[node.name]
+ if op.outputs:
+ node.attr["_output_shapes"].list.shape.extend(
+ [output.get_shape().as_proto() for output in op.outputs])
return graph, self._version
def as_graph_def(self, from_version=None, add_shapes=False):
@@ -3292,34 +3118,16 @@ class Graph(object):
# Add function to graph
# pylint: disable=protected-access
- if self._c_graph:
- # Handle functions created without using the C API. TODO(apassos,skyewm)
- # remove this when all functions are generated using the C API by default
- # as this will be unnecessary.
- if not function._c_func:
- serialized = function.definition.SerializeToString()
- c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- function._c_func = c_api_util.ScopedTFFunction(c_func)
- gradient = (function._grad_func._c_func.func if function._grad_func
- else None)
- c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
- else:
- # If there is already a function with the same name, raise an error
- # if bodies are different. Else, do nothing. The C API version above
- # has the same behavior.
- previous = self._functions.get(name, None)
- if previous:
- # This check is not ideal as we can have a hash collision with only
- # 32 bits in the hash, but the non C API mode is being deprecated.
- # Don't bother changing it now.
- if previous._hash_str == function._hash_str:
- return
- else:
- raise ValueError("Cannot add function (%s, hash %s) to graph (%s). "
- "Another function (%s, hash %s) is already defined "
- "with that name (%s)" % (
- function, function._hash_str, self,
- previous, previous._hash_str, name))
+ # Handle functions created without using the C API. TODO(apassos,skyewm)
+ # remove this when all functions are generated using the C API by default
+ # as this will be unnecessary.
+ if not function._c_func:
+ serialized = function.definition.SerializeToString()
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ function._c_func = c_api_util.ScopedTFFunction(c_func)
+ gradient = (function._grad_func._c_func.func if function._grad_func
+ else None)
+ c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
# pylint: enable=protected-access
self._functions[name] = function
@@ -3334,6 +3142,9 @@ class Graph(object):
return self._building_function
# Helper functions to create operations.
+ @deprecated_args(None,
+ "Shapes are always computed; don't use the compute_shapes "
+ "as it has no effect.", "compute_shapes")
def create_op(
self,
op_type,
@@ -3370,8 +3181,8 @@ class Graph(object):
proto).
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
the operation will have.
- compute_shapes: (Optional.) If True, shape inference will be performed
- to compute the shapes of the outputs.
+ compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+ computed).
compute_device: (Optional.) If True, device functions will be executed
to compute the device property of the Operation.
@@ -3381,8 +3192,9 @@ class Graph(object):
Returns:
An `Operation` object.
-
"""
+ del compute_shapes
+
self._check_not_finalized()
for idx, a in enumerate(inputs):
if not isinstance(a, Tensor):
@@ -3400,9 +3212,9 @@ class Graph(object):
input_ops = set([t.op for t in inputs])
control_inputs = self._control_dependencies_for_inputs(input_ops)
- # _create_op_helper mutates the new Operation. _lock ensures a Session.run
- # call cannot occur between creating and mutating the op.
- with self._lock:
+ # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a
+ # Session.run call cannot occur between creating and mutating the op.
+ with self._mutation_lock():
ret = Operation(
node_def,
self,
@@ -3412,18 +3224,7 @@ class Graph(object):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
-
- # Note: shapes are lazily computed with the C API enabled.
- #
- # TODO(skyewm): unlike in the original Python implementation, the C API
- # always computes shape information (even for function calls, which the
- # original Python shape inference code doesn't handle). Deprecate the
- # compute_shapes argument.
- if not _USE_C_API and compute_shapes:
- set_shape_and_handle_data_for_outputs(ret)
-
- self._create_op_helper(ret, compute_shapes=compute_shapes,
- compute_device=compute_device)
+ self._create_op_helper(ret, compute_device=compute_device)
return ret
def _create_op_from_tf_operation(self, c_op, compute_device=True):
@@ -3458,11 +3259,8 @@ class Graph(object):
self._create_op_helper(ret, compute_device=compute_device)
return ret
- def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
+ def _create_op_helper(self, op, compute_device=True):
"""Common logic for creating an op in this graph."""
- # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
- self._add_op(op)
-
# Apply any additional attributes requested. Do not overwrite any existing
# attributes.
for key, value in self._attr_scope_map.items():
@@ -3529,8 +3327,7 @@ class Graph(object):
# (2) "is_stateful" is set in OpDef
# (3) "container" attribute is in OpDef
# (4) "container" attribute is None
- # TODO(skyewm): remove op.op_def check when _USE_C_API is removed.
- if self._container and op.op_def and op.op_def.is_stateful:
+ if self._container and op.op_def.is_stateful:
try:
container_attr = op.get_attr("container")
except ValueError:
@@ -3817,17 +3614,14 @@ class Graph(object):
def _get_op_def(self, type): # pylint: disable=redefined-builtin
"""Returns the `OpDef` proto for `type`. `type` is a string."""
- if self._c_graph:
- with c_api_util.tf_buffer() as buf:
- # pylint: disable=protected-access
- c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
- # pylint: enable=protected-access
- data = c_api.TF_GetBuffer(buf)
- op_def = op_def_pb2.OpDef()
- op_def.ParseFromString(compat.as_bytes(data))
- return op_def
- else:
- return self._registered_ops[type]
+ with c_api_util.tf_buffer() as buf:
+ # pylint: disable=protected-access
+ c_api.TF_GraphGetOpDef(self._c_graph, compat.as_bytes(type), buf)
+ # pylint: enable=protected-access
+ data = c_api.TF_GetBuffer(buf)
+ op_def = op_def_pb2.OpDef()
+ op_def.ParseFromString(compat.as_bytes(data))
+ return op_def
def as_default(self):
"""Returns a context manager that makes this `Graph` the default graph.
@@ -3883,7 +3677,6 @@ class Graph(object):
contains many standard names for collections.
value: The value to add to the collection.
""" # pylint: disable=g-doc-exception
- _assert_collection_is_ok(name)
self._check_not_finalized()
with self._lock:
if name not in self._collections:
@@ -3930,7 +3723,6 @@ class Graph(object):
The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection.
""" # pylint: disable=g-doc-exception
- _assert_collection_is_ok(name)
with self._lock:
coll_list = self._collections.get(name, None)
if coll_list is None:
@@ -3960,7 +3752,6 @@ class Graph(object):
list contains the values in the order under which they were
collected.
""" # pylint: disable=g-doc-exception
- _assert_collection_is_ok(name)
with self._lock:
collection = self._collections.get(name, None)
if collection is None:
@@ -4956,6 +4747,20 @@ class Graph(object):
else:
self._graph_control_dependencies_stack = control_dependencies
+ def _mutation_lock(self):
+ """Returns a lock to guard code that creates & mutates ops.
+
+ See the comment for self._group_lock for more info.
+ """
+ return self._group_lock.group(_MUTATION_LOCK_GROUP)
+
+ def _session_run_lock(self):
+ """Returns a lock to guard code for Session.run.
+
+ See the comment for self._group_lock for more info.
+ """
+ return self._group_lock.group(_SESSION_RUN_LOCK_GROUP)
+
# TODO(agarwal): currently device directives in an outer eager scope will not
# apply to inner graph mode code. Fix that.
@@ -5384,7 +5189,8 @@ def init_scope():
@tf_export("enable_eager_execution")
-def enable_eager_execution(config=None, device_policy=None,
+def enable_eager_execution(config=None,
+ device_policy=None,
execution_mode=None):
"""Enables eager execution for the lifetime of this program.
@@ -5444,6 +5250,31 @@ def enable_eager_execution(config=None, device_policy=None,
TensorFlow graph, or if options provided conflict with a previous call
to this function.
"""
+ return enable_eager_execution_internal(
+ config, device_policy, execution_mode, None)
+
+
+def enable_eager_execution_internal(config=None,
+ device_policy=None,
+ execution_mode=None,
+ server_def=None):
+ """Enables eager execution for the lifetime of this program.
+
+ Most of the doc string for enable_eager_execution is relevant here as well.
+ Args:
+ config: See enable_eager_execution doc string
+ device_policy: See enable_eager_execution doc string
+ execution_mode: See enable_eager_execution doc string
+ server_def: (Optional.) A tensorflow::ServerDef proto.
+ Enables execution on remote devices. GrpcServers need to be started by
+ creating an identical server_def to this, and setting the appropriate
+ task_indexes, so that the servers can communicate. It will then be
+ possible to execute operations on remote devices.
+
+ Raises:
+ ValueError
+
+ """
if config is not None and not isinstance(config, config_pb2.ConfigProto):
raise TypeError(
"config must be a tf.ConfigProto, but got %s" % type(config))
@@ -5471,7 +5302,8 @@ def enable_eager_execution(config=None, device_policy=None,
context._context = context.Context(
config=config,
device_policy=device_policy,
- execution_mode=execution_mode)
+ execution_mode=execution_mode,
+ server_def=server_def)
elif ((config is not None and config is not context._context._config) or
(device_policy is not None and
device_policy is not context._context._device_policy) or
@@ -5830,7 +5662,8 @@ def add_to_collection(name, value):
value: The value to add to the collection.
@compatibility(eager)
- Collections are not supported when eager execution is enabled.
+ Collections are only supported in eager when variables are created inside an
+ EagerVariableStore (e.g. as part of a layer or template).
@end_compatibility
"""
get_default_graph().add_to_collection(name, value)
@@ -5848,7 +5681,8 @@ def add_to_collections(names, value):
value: The value to add to the collections.
@compatibility(eager)
- Collections are not supported when eager execution is enabled.
+ Collections are only supported in eager when variables are created inside an
+ EagerVariableStore (e.g. as part of a layer or template).
@end_compatibility
"""
get_default_graph().add_to_collections(names, value)
@@ -6141,14 +5975,6 @@ def get_from_proto_function(collection_name):
return None
-def _assert_collection_is_ok(collection_name):
- if context.executing_eagerly():
- if collection_name in GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access
- raise ValueError(
- "variable collections are not supported when eager execution is enabled."
- )
-
-
def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
"""Produce a nice error if someone converts an Operation to a Tensor."""
raise TypeError(("Can't convert Operation '%s' to Tensor "
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index b3bc800fee..150100d771 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -270,7 +270,6 @@ class OperationTest(test_util.TensorFlowTestCase):
op1 = ops.Operation(
ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
[dtypes.float32_ref, dtypes.float32])
- g._add_op(op1)
self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
self.assertEquals([], list(op1.inputs))
ref_t, nonref_t = op1.values()
@@ -279,14 +278,12 @@ class OperationTest(test_util.TensorFlowTestCase):
ops._NodeDef("RefInputFloatInput", "op2"),
g, [ref_t, nonref_t], [],
input_types=[dtypes.float32_ref, dtypes.float32])
- g._add_op(op2)
self.assertProtoEquals(
"op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
op2.node_def)
self.assertEquals([ref_t, nonref_t], list(op2.inputs))
op3 = ops.Operation(
ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
- g._add_op(op3)
self.assertProtoEquals(
"op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
op3.node_def)
@@ -1693,7 +1690,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
# e should be dominated by c.
self.assertEqual(e.op.control_inputs, [])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEager(self):
def future():
future.calls += 1
@@ -1878,7 +1875,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
class OpScopeTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNames(self):
with ops.name_scope("foo") as foo:
self.assertEqual("foo/", foo)
@@ -1909,7 +1906,7 @@ class OpScopeTest(test_util.TensorFlowTestCase):
with ops.name_scope("a//b/c") as foo10:
self.assertEqual("a//b/c/", foo10)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerDefaultScopeName(self):
with ops.name_scope(None, "default") as scope:
self.assertEqual(scope, "default/")
diff --git a/tensorflow/python/framework/random_seed_test.py b/tensorflow/python/framework/random_seed_test.py
index 1944922686..6696bffc6c 100644
--- a/tensorflow/python/framework/random_seed_test.py
+++ b/tensorflow/python/framework/random_seed_test.py
@@ -26,7 +26,7 @@ from tensorflow.python.platform import test
class RandomSeedTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRandomSeed(self):
test_cases = [
# Each test case is a tuple with input to get_seed:
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 35fff80c61..395cf43b3f 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -50,13 +50,13 @@ class TensorUtilTest(test.TestCase):
def testFloatN(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -68,13 +68,13 @@ class TensorUtilTest(test.TestCase):
def testFloatTyped(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -86,13 +86,13 @@ class TensorUtilTest(test.TestCase):
def testFloatTypeCoerce(self):
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -105,13 +105,13 @@ class TensorUtilTest(test.TestCase):
arr = np.asarray([10, 20, 30], dtype="int")
t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -123,13 +123,13 @@ class TensorUtilTest(test.TestCase):
def testFloatSizes(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -141,13 +141,13 @@ class TensorUtilTest(test.TestCase):
def testFloatSizes2(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -169,13 +169,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(
np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "\000\000\000\000\000\000$@\000\000\000\000\000\0004@\000\000\000\000\000\000>@"
@@ -206,13 +206,13 @@ class TensorUtilTest(test.TestCase):
self.assertEquals(np.float32, a.dtype)
self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000 A\000\000\240A\000\000\360A"
@@ -299,16 +299,16 @@ class TensorUtilTest(test.TestCase):
def testIntNDefaultType(self):
t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
- tensor_content: "\000\000\000\\n\000\000\000\024\000\000\000\036\000\000\000("
+ tensor_content: "\000\000\000\n\000\000\000\024\000\000\000\036\000\000\000("
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
- tensor_content: "\\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
+ tensor_content: "\n\000\000\000\024\000\000\000\036\000\000\000(\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int32, a.dtype)
@@ -380,16 +380,16 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(
[10, 20, 30], shape=[1, 3], dtype=dtypes.int64)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
- tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
+ tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
- tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int64, a.dtype)
@@ -398,16 +398,16 @@ class TensorUtilTest(test.TestCase):
def testLongNpArray(self):
t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
- tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
+ tensor_content: "\000\000\000\000\000\000\000\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
- tensor_content: "\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
+ tensor_content: "\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036\000\000\000\000\000\000\000"
""", t)
a = tensor_util.MakeNdarray(t)
self.assertEquals(np.int64, a.dtype)
@@ -419,13 +419,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\000\000\026\000\000\000\027\000\000\000"
@@ -435,7 +435,7 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)
t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint8)
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT8
tensor_shape { dim { size: 3 } }
tensor_content: "\025\026\027"
@@ -445,7 +445,7 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint8)
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT8
tensor_shape { dim { size: 3 } }
tensor_content: "\025\026\027"
@@ -456,13 +456,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\026\000\027\000"
@@ -473,13 +473,13 @@ class TensorUtilTest(test.TestCase):
t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
if sys.byteorder == "big":
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
- self.assertProtoEquals("""
+ self.assertProtoEquals(r"""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\025\000\026\000\027\000"
@@ -941,7 +941,7 @@ class ConstantValueTest(test.TestCase):
class ConstantValueAsShapeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConstant(self):
np_val = np.random.rand(3).astype(np.int32)
tf_val = constant_op.constant(np_val)
@@ -954,13 +954,13 @@ class ConstantValueAsShapeTest(test.TestCase):
tensor_shape.TensorShape([]),
tensor_util.constant_value_as_shape(tf_val))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testShape(self):
tf_val = array_ops.shape(constant_op.constant(0.0, shape=[1, 2, 3]))
c_val = tensor_util.constant_value_as_shape(tf_val)
self.assertEqual(tensor_shape.TensorShape([1, 2, 3]), c_val)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMinusOneBecomesNone(self):
tf_val = constant_op.constant([-1, 1, -1], shape=[3])
c_val = tensor_util.constant_value_as_shape(tf_val)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b56483f373..2bc2a189fa 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -27,6 +27,7 @@ import random
import re
import tempfile
import threading
+import unittest
import numpy as np
import six
@@ -61,13 +62,13 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
@@ -321,32 +322,6 @@ def NCHWToNHWC(input_tensor):
return [input_tensor[a] for a in new_axes[ndims]]
-# TODO(skyewm): remove this eventually
-# pylint: disable=protected-access
-def _use_c_api_wrapper(fn, use_c_api, *args, **kwargs):
- prev_value = ops._USE_C_API
- ops._USE_C_API = use_c_api
- try:
- # Reset the default graph so it has the C API enabled. We call
- # reset_default_graph() instead of creating a new default Graph context to
- # make this robust to tests that call reset_default_graph(), which requires
- # that the current default graph isn't nested.
- ops.reset_default_graph()
- fn(*args, **kwargs)
- finally:
- ops._USE_C_API = prev_value
- # Make sure default graph reflects prev_value in case next test doesn't call
- # reset_default_graph().
- ops.reset_default_graph()
-
-
-# pylint: disable=protected-access
-
-
-def c_api_and_cuda_enabled():
- return ops._USE_C_API and IsGoogleCudaEnabled()
-
-
def skip_if(condition):
"""Skips the decorated function if condition is or evaluates to True.
@@ -372,46 +347,6 @@ def skip_if(condition):
return real_skip_if
-# TODO(skyewm): remove this eventually
-def disable_c_api(fn):
- """Decorator for disabling the C API on a test.
-
- Note this disables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, False, *args, **kwargs)
-
- return wrapper
-
-
-# TODO(skyewm): remove this eventually
-def enable_c_api(fn):
- """Decorator for enabling the C API on a test.
-
- Note this enables the C API after running the test class's setup/teardown
- methods.
-
- Args:
- fn: the function to be wrapped
-
- Returns:
- The wrapped function
- """
-
- def wrapper(*args, **kwargs):
- _use_c_api_wrapper(fn, True, *args, **kwargs)
-
- return wrapper
-
-
def enable_c_shapes(fn):
"""Decorator for enabling C shapes on a test.
@@ -425,46 +360,19 @@ def enable_c_shapes(fn):
The wrapped function
"""
+ # pylint: disable=protected-access
def wrapper(*args, **kwargs):
prev_value = ops._USE_C_SHAPES
- # Only use C shapes if the C API is already enabled.
- ops._USE_C_SHAPES = ops._USE_C_API
+ ops._USE_C_SHAPES = True
try:
fn(*args, **kwargs)
finally:
ops._USE_C_SHAPES = prev_value
+ # pylint: enable=protected-access
return wrapper
-# This decorator is a hacky way to run all the test methods in a decorated
-# class with and without C API enabled.
-# TODO(iga): Remove this and its uses once we switch to using C API by default.
-def with_c_api(cls):
- """Adds methods that call original methods but with C API enabled.
-
- Note this enables the C API in new methods after running the test class's
- setup method. This can be a problem if some objects are created in it
- before the C API is enabled.
-
- Args:
- cls: class to decorate
-
- Returns:
- cls with new test methods added
- """
- # If the C API is already enabled, don't do anything. Some tests break if the
- # same test is run twice, so this allows us to turn on the C API by default
- # without breaking these tests.
- if ops._USE_C_API:
- return cls
-
- for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCApi", enable_c_api(value))
- return cls
-
-
def with_c_shapes(cls):
"""Adds methods that call original methods but with C API shapes enabled.
@@ -507,8 +415,28 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
+ collection_sizes_before = {
+ collection: len(ops.get_collection(collection))
+ for collection in ops.get_default_graph().collections}
for _ in range(3):
f(self, **kwargs)
+ # Note that gc.get_objects misses anything that isn't subject to garbage
+ # collection (C types). Collections are a common source of leaks, so we
+ # test for collection sizes explicitly.
+ for collection_key in ops.get_default_graph().collections:
+ collection = ops.get_collection(collection_key)
+ size_before = collection_sizes_before.get(collection_key, 0)
+ if len(collection) > size_before:
+ raise AssertionError(
+ ("Collection %s increased in size from "
+ "%d to %d (current items %s).")
+ % (collection_key, size_before, len(collection), collection))
+ # Make sure our collection checks don't show up as leaked memory by
+ # removing references to temporary variables.
+ del collection
+ del collection_key
+ del size_before
+ del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())
@@ -644,14 +572,15 @@ def assert_no_garbage_created(f):
def run_all_in_graph_and_eager_modes(cls):
- base_decorator = run_in_graph_and_eager_modes()
+ """Execute all test methods in the given class with and without eager."""
+ base_decorator = run_in_graph_and_eager_modes
for name, value in cls.__dict__.copy().items():
if callable(value) and name.startswith("test"):
setattr(cls, name, base_decorator(value))
return cls
-def run_in_graph_and_eager_modes(__unused__=None,
+def run_in_graph_and_eager_modes(func=None,
config=None,
use_gpu=True,
reset_test=True,
@@ -669,7 +598,7 @@ def run_in_graph_and_eager_modes(__unused__=None,
```python
class MyTests(tf.test.TestCase):
- @run_in_graph_and_eager_modes()
+ @run_in_graph_and_eager_modes
def test_foo(self):
x = tf.constant([1, 2])
y = tf.constant([3, 4])
@@ -686,7 +615,9 @@ def run_in_graph_and_eager_modes(__unused__=None,
Args:
- __unused__: Prevents silently skipping tests.
+ func: function to be annotated. If `func` is None, this method returns a
+ decorator the can be applied to a function. If `func` is not None this
+ returns the decorator applied to `func`.
config: An optional config_pb2.ConfigProto to use to configure the
session when executing graphs.
use_gpu: If True, attempt to run as many operations as possible on GPU.
@@ -708,20 +639,19 @@ def run_in_graph_and_eager_modes(__unused__=None,
eager execution enabled.
"""
- assert not __unused__, "Add () after run_in_graph_and_eager_modes."
-
def decorator(f):
- def decorated(self, **kwargs):
- with context.graph_mode():
- with self.test_session(use_gpu=use_gpu):
- f(self, **kwargs)
+ if tf_inspect.isclass(f):
+ raise ValueError(
+ "`run_test_in_graph_and_eager_modes` only supports test methods. "
+ "Did you mean to use `run_all_tests_in_graph_and_eager_modes`?")
- if reset_test:
- # This decorator runs the wrapped test twice.
- # Reset the test environment between runs.
- self.tearDown()
- self._tempdir = None
- self.setUp()
+ def decorated(self, **kwargs):
+ try:
+ with context.graph_mode():
+ with self.test_session(use_gpu=use_gpu, config=config):
+ f(self, **kwargs)
+ except unittest.case.SkipTest:
+ pass
def run_eagerly(self, **kwargs):
if not use_gpu:
@@ -736,10 +666,20 @@ def run_in_graph_and_eager_modes(__unused__=None,
assert_no_garbage_created(run_eagerly))
with context.eager_mode():
+ if reset_test:
+ # This decorator runs the wrapped test twice.
+ # Reset the test environment between runs.
+ self.tearDown()
+ self._tempdir = None
+ self.setUp()
+
run_eagerly(self, **kwargs)
return decorated
+ if func is not None:
+ return decorator(func)
+
return decorator
@@ -922,14 +862,13 @@ class TensorFlowTestCase(googletest.TestCase):
def _eval_tensor(self, tensor):
if tensor is None:
return None
- elif isinstance(tensor, ops.EagerTensor):
- return tensor.numpy()
- elif isinstance(tensor, resource_variable_ops.ResourceVariable):
- return tensor.read_value().numpy()
elif callable(tensor):
return self._eval_helper(tensor())
else:
- raise ValueError("Unsupported type %s." % type(tensor))
+ try:
+ return tensor.numpy()
+ except AttributeError as e:
+ six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
def _eval_helper(self, tensors):
if tensors is None:
@@ -1334,11 +1273,11 @@ class TensorFlowTestCase(googletest.TestCase):
b,
rtol=rtol,
atol=atol,
- msg="Mismatched value: a%s is different from b%s." % (path_str,
- path_str))
+ msg=("Mismatched value: a%s is different from b%s. %s" %
+ (path_str, path_str, msg)))
except TypeError as e:
- msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
- path_str, type(b))
+ msg = ("Error: a%s has %s, but b%s has %s. %s" %
+ (path_str, type(a), path_str, type(b), msg))
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
raise
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 0178908bcc..122c14c847 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -569,7 +569,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(a_np_rand, b_np_rand)
self.assertEqual(a_rand, b_rand)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_callable_evaluate(self):
def model():
return resource_variable_ops.ResourceVariable(
@@ -578,7 +578,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with context.eager_mode():
self.assertEqual(2, self.evaluate(model))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_nested_tensors_evaluate(self):
expected = {"a": 1, "b": 2, "nested": {"d": 3, "e": 4}}
nested = {"a": constant_op.constant(1),
@@ -588,6 +588,27 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(expected, self.evaluate(nested))
+ def test_run_in_graph_and_eager_modes(self):
+ l = []
+ def inc(self, with_brackets):
+ del self # self argument is required by run_in_graph_and_eager_modes.
+ mode = "eager" if context.executing_eagerly() else "graph"
+ with_brackets = "with_brackets" if with_brackets else "without_brackets"
+ l.append((with_brackets, mode))
+
+ f = test_util.run_in_graph_and_eager_modes(inc)
+ f(self, with_brackets=False)
+ f = test_util.run_in_graph_and_eager_modes()(inc)
+ f(self, with_brackets=True)
+
+ self.assertEqual(len(l), 4)
+ self.assertEqual(set(l), {
+ ("with_brackets", "graph"),
+ ("with_brackets", "eager"),
+ ("without_brackets", "graph"),
+ ("without_brackets", "eager"),
+ })
+
def test_get_node_def_from_graph(self):
graph_def = graph_pb2.GraphDef()
node_foo = graph_def.node.add()
@@ -595,6 +616,55 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
+ def test_run_in_eager_and_graph_modes_test_class(self):
+ msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*"
+ with self.assertRaisesRegexp(ValueError, msg):
+ @test_util.run_in_graph_and_eager_modes()
+ class Foo(object):
+ pass
+ del Foo # Make pylint unused happy.
+
+ def test_run_in_eager_and_graph_modes_skip_graph_runs_eager(self):
+ modes = []
+ def _test(self):
+ if not context.executing_eagerly():
+ self.skipTest("Skipping in graph mode")
+ modes.append("eager" if context.executing_eagerly() else "graph")
+ test_util.run_in_graph_and_eager_modes(_test)(self)
+ self.assertEqual(modes, ["eager"])
+
+ def test_run_in_eager_and_graph_modes_skip_eager_runs_graph(self):
+ modes = []
+ def _test(self):
+ if context.executing_eagerly():
+ self.skipTest("Skipping in eager mode")
+ modes.append("eager" if context.executing_eagerly() else "graph")
+ test_util.run_in_graph_and_eager_modes(_test)(self)
+ self.assertEqual(modes, ["graph"])
+
+ def test_run_in_graph_and_eager_modes_setup_in_same_mode(self):
+ modes = []
+ mode_name = lambda: "eager" if context.executing_eagerly() else "graph"
+
+ class ExampleTest(test_util.TensorFlowTestCase):
+
+ def runTest(self):
+ pass
+
+ def setUp(self):
+ modes.append("setup_" + mode_name())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBody(self):
+ modes.append("run_" + mode_name())
+
+ e = ExampleTest()
+ e.setUp()
+ e.testBody()
+
+ self.assertEqual(modes[0:2], ["setup_graph", "run_graph"])
+ self.assertEqual(modes[2:], ["setup_eager", "run_eager"])
+
class GarbageCollectionTest(test_util.TensorFlowTestCase):
@@ -619,7 +689,7 @@ class GarbageCollectionTest(test_util.TensorFlowTestCase):
ReferenceCycleTest().test_has_no_cycle()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_no_leaked_tensor_decorator(self):
class LeakedTensorTest(object):
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index af5d709f7e..7d07c77c79 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -158,6 +158,7 @@ def _get_config(layout_optimizer=True):
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
# do not remove duplicated nodes
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ rewrite_options.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrite_options, build_cost_model=1)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -1443,7 +1444,8 @@ class LayoutOptimizerTest(test.TestCase):
def testGradient(self):
meta_graph = _simple_metagraph()
rewrite_options = rewriter_config_pb2.RewriterConfig(
- layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON,
+ min_graph_nodes=-1)
optimized_graph = tf_optimizer.OptimizeGraph(
rewrite_options, meta_graph, cluster=_get_cluster())
@@ -1457,7 +1459,8 @@ class LayoutOptimizerTest(test.TestCase):
def testDepthwise(self):
meta_graph = _simple_metagraph(depthwise=True)
rewrite_options = rewriter_config_pb2.RewriterConfig(
- layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON,
+ min_graph_nodes=-1)
optimized_graph = tf_optimizer.OptimizeGraph(
rewrite_options, meta_graph, cluster=_get_cluster())
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index 7ed4b128e4..b658edff2d 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -76,7 +76,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
disable_model_pruning=True,
meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
- memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
+ min_graph_nodes=-1)
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
self.assertEqual(len(graph.node), graph_size + 2)
@@ -133,6 +134,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS), original_metagraph)
self.assertGreater(
@@ -158,6 +160,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS,
# Checks that name scope "gradients/" also match sub-scope.
@@ -297,6 +300,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
if 'Recomputed/' in node.name]))
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL),
metagraph)
self.assertEqual(
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 1c0f072dd3..5a9afe7257 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -47,6 +47,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
rewriter_config = rewriter_config_pb2.RewriterConfig()
rewriter_config.optimizers.append('constfold')
+ rewriter_config.min_graph_nodes = -1
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
@@ -68,6 +69,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Optimize the graph.
mg = meta_graph.create_meta_graph_def(graph=g)
rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.min_graph_nodes = -1
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
# Check that the nodes referenced in various collections have been preserved
@@ -109,6 +111,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Optimize the graph.
mg = meta_graph.create_meta_graph_def(graph=g)
rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.min_graph_nodes = -1
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
mg.graph_def.CopyFrom(optimized_graph)
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index fe40c9fbed..4056818a95 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -39,6 +39,7 @@ py_library(
"datasets/imdb.py",
"datasets/mnist.py",
"datasets/reuters.py",
+ "estimator/__init__.py",
"preprocessing/__init__.py",
"preprocessing/image.py",
"preprocessing/sequence.py",
@@ -135,7 +136,7 @@ py_library(
deps = [
":backend",
"//tensorflow/python/data",
- "//tensorflow/python/training/checkpointable:data_structures_base",
+ "//tensorflow/python/training/checkpointable:data_structures",
"@six_archive//:six",
],
)
@@ -450,6 +451,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
],
shard_count = 2,
+ tags = ["no_windows_gpu"],
)
py_test(
@@ -549,7 +551,7 @@ py_test(
py_test(
name = "gru_test",
- size = "medium",
+ size = "large",
srcs = ["layers/gru_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"], # http://b/62136390
@@ -719,6 +721,7 @@ py_test(
size = "medium",
srcs = ["preprocessing/image_test.py"],
srcs_version = "PY2AND3",
+ tags = ["nomsan"], # TODO(b/110990716) reenable
deps = [
":keras",
"//tensorflow/python:client_testlib",
@@ -858,7 +861,7 @@ py_test(
py_test(
name = "backend_test",
- size = "small",
+ size = "medium",
srcs = ["backend_test.py"],
srcs_version = "PY2AND3",
deps = [
@@ -866,6 +869,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:util",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/keras/__init__.py b/tensorflow/python/keras/__init__.py
index 3493069a5b..198c66d9e1 100644
--- a/tensorflow/python/keras/__init__.py
+++ b/tensorflow/python/keras/__init__.py
@@ -27,6 +27,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks
from tensorflow.python.keras import constraints
from tensorflow.python.keras import datasets
+from tensorflow.python.keras import estimator
from tensorflow.python.keras import initializers
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index a62dadb830..f608dea430 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -32,7 +32,7 @@ def softmax(x, axis=-1):
"""Softmax activation function.
Arguments:
- x : Tensor.
+ x : Input tensor.
axis: Integer, axis along which the softmax normalization is applied.
Returns:
@@ -49,28 +49,52 @@ def softmax(x, axis=-1):
s = math_ops.reduce_sum(e, axis=axis, keepdims=True)
return e / s
else:
- raise ValueError('Cannot apply softmax to a tensor that is 1D')
+ raise ValueError('Cannot apply softmax to a tensor that is 1D. '
+ 'Received input: %s' % (x,))
@tf_export('keras.activations.elu')
def elu(x, alpha=1.0):
+ """Exponential linear unit.
+
+ Arguments:
+ x: Input tensor.
+ alpha: A scalar, slope of negative section.
+
+ Returns:
+ The exponential linear activation: `x` if `x > 0` and
+ `alpha * (exp(x)-1)` if `x < 0`.
+
+ Reference:
+ - [Fast and Accurate Deep Network Learning by Exponential
+ Linear Units (ELUs)](https://arxiv.org/abs/1511.07289)
+ """
return K.elu(x, alpha)
@tf_export('keras.activations.selu')
def selu(x):
- """Scaled Exponential Linear Unit. (Klambauer et al., 2017).
+ """Scaled Exponential Linear Unit (SELU).
+
+ SELU is equal to: `scale * elu(x, alpha)`, where alpha and scale
+ are pre-defined constants. The values of `alpha` and `scale` are
+ chosen so that the mean and variance of the inputs are preserved
+ between two consecutive layers as long as the weights are initialized
+ correctly (see `lecun_normal` initialization) and the number of inputs
+ is "large enough" (see references for more information).
Arguments:
x: A tensor or variable to compute the activation function for.
Returns:
- Tensor with the same shape and dtype as `x`.
+ The scaled exponential unit activation: `scale * elu(x, alpha)`.
# Note
- To be used together with the initialization "lecun_normal".
- To be used together with the dropout variant "AlphaDropout".
+ References:
+ - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
"""
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
@@ -79,16 +103,44 @@ def selu(x):
@tf_export('keras.activations.softplus')
def softplus(x):
+ """Softplus activation function.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ The softplus activation: `log(exp(x) + 1)`.
+ """
return nn.softplus(x)
@tf_export('keras.activations.softsign')
def softsign(x):
+ """Softsign activation function.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ The softplus activation: `x / (abs(x) + 1)`.
+ """
return nn.softsign(x)
@tf_export('keras.activations.relu')
def relu(x, alpha=0., max_value=None):
+ """Rectified Linear Unit.
+
+ Arguments:
+ x: Input tensor.
+ alpha: Slope of the negative part. Defaults to zero.
+ max_value: Maximum value for the output.
+
+ Returns:
+ The (leaky) rectified linear unit activation: `x` if `x > 0`,
+ `alpha * x` if `x < 0`. If `max_value` is defined, the result
+ is truncated to this value.
+ """
return K.relu(x, alpha=alpha, max_value=max_value)
@@ -104,6 +156,19 @@ def sigmoid(x):
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
+ """Hard sigmoid activation function.
+
+ Faster to compute than sigmoid activation.
+
+ Arguments:
+ x: Input tensor.
+
+ Returns:
+ Hard sigmoid activation:
+ - `0` if `x < -2.5`
+ - `1` if `x > 2.5`
+ - `0.2 * x + 0.5` if `-2.5 <= x <= 2.5`.
+ """
return K.hard_sigmoid(x)
diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index f81f10719a..8df6d08611 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -31,7 +31,6 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import AveragePooling2D
from tensorflow.python.keras.layers import BatchNormalization
@@ -44,6 +43,7 @@ from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import tf_export
@@ -238,7 +238,7 @@ def DenseNet(blocks,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index fe1d0f2d4f..14e3b6aa60 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -31,7 +31,6 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import AveragePooling2D
from tensorflow.python.keras.layers import BatchNormalization
@@ -44,6 +43,7 @@ from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -354,7 +354,7 @@ def InceptionResNetV2(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index 857ad49dae..b5e28c781f 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -37,7 +37,6 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import AveragePooling2D
from tensorflow.python.keras.layers import BatchNormalization
@@ -48,6 +47,7 @@ from tensorflow.python.keras.layers import GlobalMaxPooling2D
from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -375,7 +375,7 @@ def InceptionV3(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index 9d845be0d5..e56c695a28 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -78,8 +78,7 @@ from tensorflow.python.keras import regularizers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine.network import get_source_inputs
+from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.layers import Conv2D
@@ -92,6 +91,7 @@ from tensorflow.python.keras.layers import Reshape
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.utils import conv_utils
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -317,7 +317,7 @@ def MobileNet(input_shape=None,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index b521bc6731..ff79b3a057 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -49,7 +49,6 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras.applications.inception_v3 import preprocess_input
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import add
from tensorflow.python.keras.layers import AveragePooling2D
@@ -65,6 +64,7 @@ from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.layers import SeparableConv2D
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -290,7 +290,7 @@ def NASNet(input_shape=None,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py
index 508550f445..6afc086812 100644
--- a/tensorflow/python/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/applications/resnet50.py
@@ -34,7 +34,6 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import AveragePooling2D
from tensorflow.python.keras.layers import BatchNormalization
@@ -277,7 +276,7 @@ def ResNet50(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index 659a6533e6..cef0230da9 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -32,7 +32,6 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Conv2D
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Flatten
@@ -202,7 +201,7 @@ def VGG16(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index 5e27ab8fb1..c4031f5510 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -32,7 +32,6 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Conv2D
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Flatten
@@ -211,7 +210,7 @@ def VGG19(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index e1be8a3c46..01397cfac2 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -44,7 +44,6 @@ from tensorflow.python.keras import layers
from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras.applications.imagenet_utils import decode_predictions
-from tensorflow.python.keras.engine.network import get_source_inputs
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.layers import Conv2D
@@ -55,6 +54,7 @@ from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import MaxPooling2D
from tensorflow.python.keras.layers import SeparableConv2D
from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -302,7 +302,7 @@ def Xception(include_top=True,
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
- inputs = get_source_inputs(input_tensor)
+ inputs = layer_utils.get_source_inputs(input_tensor)
else:
inputs = img_input
# Create model.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index af3d1fa33d..cb3423598b 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import itertools
import json
import os
import weakref
@@ -962,13 +963,14 @@ def zeros(shape, dtype=None, name=None):
[ 0., 0., 0., 0.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.ones')
@@ -995,13 +997,14 @@ def ones(shape, dtype=None, name=None):
[ 1., 1., 1., 1.]], dtype=float32)
```
"""
- if dtype is None:
- dtype = floatx()
- tf_dtype = dtypes_module.as_dtype(dtype)
- v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
- if py_all(v.get_shape().as_list()):
- return variable(v, dtype=dtype, name=name)
- return v
+ with ops.init_scope():
+ if dtype is None:
+ dtype = floatx()
+ tf_dtype = dtypes_module.as_dtype(dtype)
+ v = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
+ if py_all(v.get_shape().as_list()):
+ return variable(v, dtype=dtype, name=name)
+ return v
@tf_export('keras.backend.eye')
@@ -2794,10 +2797,15 @@ class Function(object):
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
# The main use case of `fetches` being passed to a model is the ability
- # to run custom updates (since the outputs of fetches are never returned).
+ # to run custom updates
# This requires us to wrap fetches in `identity` ops.
self.fetches = [array_ops.identity(x) for x in self.fetches]
self.session_kwargs = session_kwargs
+ # This mapping keeps track of the function that should receive the
+ # output from a fetch in `fetches`: { fetch: function(fetch_output) }
+ # A Callback can use this to register a function with access to the
+ # output values for a fetch it added.
+ self.fetch_callbacks = dict()
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
@@ -2807,6 +2815,7 @@ class Function(object):
self._feed_arrays = None
self._feed_symbols = None
self._symbol_vals = None
+ self._fetches = None
self._session = None
def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
@@ -2852,8 +2861,14 @@ class Function(object):
self._feed_arrays = feed_arrays
self._feed_symbols = feed_symbols
self._symbol_vals = symbol_vals
+ self._fetches = list(self.fetches)
self._session = session
+ def _call_fetch_callbacks(self, fetches_output):
+ for fetch, output in zip(self._fetches, fetches_output):
+ if fetch in self.fetch_callbacks:
+ self.fetch_callbacks[fetch](output)
+
def __call__(self, inputs):
if not isinstance(inputs, (list, tuple)):
raise TypeError('`inputs` should be a list or tuple.')
@@ -2880,21 +2895,24 @@ class Function(object):
feed_arrays.append(tensor)
# We need to do array conversion and type casting at this level, since
# `callable_fn` only supports exact matches.
- array_vals.append(np.asarray(value, dtype=tensor.dtype.base_dtype.name))
+ tensor_type = dtypes_module.as_dtype(tensor.dtype)
+ array_vals.append(np.asarray(value,
+ dtype=tensor_type.as_numpy_dtype))
+
if self.feed_dict:
for key in sorted(self.feed_dict.keys()):
array_vals.append(
np.asarray(self.feed_dict[key], dtype=key.dtype.base_dtype.name))
# Refresh callable if anything has changed.
- if (self._callable_fn is None or
- feed_arrays != self._feed_arrays or
+ if (self._callable_fn is None or feed_arrays != self._feed_arrays or
symbol_vals != self._symbol_vals or
- feed_symbols != self._feed_symbols or
+ feed_symbols != self._feed_symbols or self.fetches != self._fetches or
session != self._session):
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
fetched = self._callable_fn(*array_vals)
+ self._call_fetch_callbacks(fetched[-len(self._fetches):])
return fetched[:len(self.outputs)]
@@ -2973,30 +2991,29 @@ def rnn(step_function,
Arguments:
step_function: RNN step function.
- Parameters;
- input; tensor with shape `(samples, ...)` (no time dimension),
+ Args;
+ input; Tensor with shape `(samples, ...)` (no time dimension),
representing input for the batch of samples at a certain
time step.
- states; list of tensors.
+ states; List of tensors.
Returns;
- output; tensor with shape `(samples, output_dim)`
+ output; Tensor with shape `(samples, output_dim)`
(no time dimension).
- new_states; list of tensors, same length and shapes
+ new_states; List of tensors, same length and shapes
as 'states'. The first state in the list must be the
output tensor at the previous timestep.
- inputs: tensor of temporal data of shape `(samples, time, ...)`
+ inputs: Tensor of temporal data of shape `(samples, time, ...)`
(at least 3D).
- initial_states: tensor with shape (samples, output_dim)
+ initial_states: Tensor with shape `(samples, output_dim)`
(no time dimension),
containing the initial values for the states used in
the step function.
- go_backwards: boolean. If True, do the iteration over the time
+ go_backwards: Boolean. If True, do the iteration over the time
dimension in reverse order and return the reversed sequence.
- mask: binary tensor with shape `(samples, time, 1)`,
+ mask: Binary tensor with shape `(samples, time, 1)`,
with a zero for every element that is masked.
- constants: a list of constant values passed at each step.
- unroll: whether to unroll the RNN or to use a symbolic loop
- (`while_loop` or `scan` depending on backend).
+ constants: List of constant values passed at each step.
+ unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
Returns:
@@ -3158,10 +3175,16 @@ def rnn(step_function,
array_ops.stack(
[1, array_ops.shape(output)[1]]))
output = array_ops.where(tiled_mask_t, output, states[0])
- new_states = [
- array_ops.where(tiled_mask_t, new_states[i], states[i])
- for i in range(len(states))
- ]
+
+ masked_states = []
+ for i in range(len(states)):
+ states_dim = array_ops.shape(new_states[i])[1]
+ stacked_states_dim = array_ops.stack([1, states_dim])
+ tiled_mask = array_ops.tile(mask_t, stacked_states_dim)
+ masked_state = array_ops.where(tiled_mask, new_states[i], states[i])
+ masked_states.append(masked_state)
+ new_states = masked_states
+
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else:
@@ -3637,12 +3660,12 @@ def _preprocess_conv1d_input(x, data_format):
Returns:
A tensor.
"""
- tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
+ tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = array_ops.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
- tf_data_format = 'NCHW'
+ tf_data_format = 'NCW'
return x, tf_data_format
@@ -3741,10 +3764,8 @@ def conv1d(x,
x = temporal_padding(x, (left_pad, 0))
padding = 'valid'
padding = _preprocess_padding(padding)
- if data_format == 'channels_last':
- tf_data_format = 'NWC'
- else:
- tf_data_format = 'NCW'
+
+ x, tf_data_format = _preprocess_conv1d_input(x, data_format)
x = nn.convolution(
input=x,
filter=kernel,
@@ -3752,6 +3773,8 @@ def conv1d(x,
strides=(strides,),
padding=padding,
data_format=tf_data_format)
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
+ x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -3892,11 +3915,16 @@ def separable_conv1d(x,
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation_rate, int):
+ dilation_rate = (dilation_rate,)
+
x, tf_data_format = _preprocess_conv1d_input(x, data_format)
padding = _preprocess_padding(padding)
if not isinstance(strides, tuple):
strides = tuple(strides)
- if tf_data_format == 'NHWC':
+ if tf_data_format == 'NWC':
spatial_start_dim = 1
strides = (1,) + strides * 2 + (1,)
else:
@@ -3918,7 +3946,7 @@ def separable_conv1d(x,
x = array_ops.squeeze(x, [spatial_start_dim])
- if data_format == 'channels_first' and tf_data_format == 'NHWC':
+ if data_format == 'channels_first' and tf_data_format == 'NWC':
x = array_ops.transpose(x, (0, 2, 1)) # NWC -> NCW
return x
@@ -4238,45 +4266,115 @@ def pool3d(x,
return x
-def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
- """Apply 1D conv with un-shared weights.
-
- Arguments:
- inputs: 3D tensor with shape: (batch_size, steps, input_dim)
- kernel: the unshared weight for convolution,
- with shape (output_length, feature_dim, filters)
- kernel_size: a tuple of a single integer,
- specifying the length of the 1D convolution window
- strides: a tuple of a single integer,
- specifying the stride length of the convolution
- data_format: the data format, channels_first or channels_last
-
- Returns:
- the tensor after 1d conv with un-shared weights, with shape (batch_size,
- output_length, filters)
+def local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format=None):
+ """Apply N-D convolution with un-shared weights.
+
+ Arguments:
+ inputs: (N+2)-D tensor with shape
+ (batch_size, channels_in, d_in1, ..., d_inN)
+ if data_format='channels_first', or
+ (batch_size, d_in1, ..., d_inN, channels_in)
+ if data_format='channels_last'.
+ kernel: the unshared weight for N-D convolution,
+ with shape (output_items, feature_dim, channels_out), where
+ feature_dim = np.prod(kernel_size) * channels_in,
+ output_items = np.prod(output_shape).
+ kernel_size: a tuple of N integers, specifying the
+ spatial dimensions of the N-D convolution window.
+ strides: a tuple of N integers, specifying the strides
+ of the convolution along the spatial dimensions.
+ output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
+ dimensionality of the output.
+ data_format: string, "channels_first" or "channels_last".
+
+ Returns:
+ An (N+2)-D tensor with shape:
+ (batch_size, channels_out) + output_shape
+ if data_format='channels_first', or:
+ (batch_size,) + output_shape + (channels_out,)
+ if data_format='channels_last'.
Raises:
- ValueError: if `data_format` is neither `channels_last` or
- `channels_first`.
+ ValueError: if `data_format` is neither
+ `channels_last` nor `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: ' + str(data_format))
- stride = strides[0]
kernel_shape = int_shape(kernel)
- output_length = kernel_shape[0]
feature_dim = kernel_shape[1]
+ channels_out = kernel_shape[-1]
+ ndims = len(output_shape)
+ spatial_dimensions = list(range(ndims))
xs = []
- for i in range(output_length):
- slice_length = slice(i * stride, i * stride + kernel_size[0])
- xs.append(reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
+ output_axes_ticks = [range(axis_max) for axis_max in output_shape]
+ for position in itertools.product(*output_axes_ticks):
+ slices = [slice(None)]
+
+ if data_format == 'channels_first':
+ slices.append(slice(None))
+
+ slices.extend([slice(position[d] * strides[d],
+ position[d] * strides[d] + kernel_size[d])
+ for d in spatial_dimensions])
+
+ if data_format == 'channels_last':
+ slices.append(slice(None))
+
+ xs.append(reshape(inputs[slices], (1, -1, feature_dim)))
+
x_aggregate = concatenate(xs, axis=0)
- # Shape: `(output_length, batch_size, filters)`.
output = batch_dot(x_aggregate, kernel)
- return permute_dimensions(output, (1, 0, 2))
+ output = reshape(output, output_shape + (-1, channels_out))
+
+ if data_format == 'channels_first':
+ permutation = [ndims, ndims + 1] + spatial_dimensions
+ else:
+ permutation = [ndims] + spatial_dimensions + [ndims + 1]
+
+ return permute_dimensions(output, permutation)
+
+
+def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
+ """Apply 1D conv with un-shared weights.
+
+ Arguments:
+ inputs: 3D tensor with shape:
+ (batch_size, steps, input_dim)
+ if data_format is "channels_last" or
+ (batch_size, input_dim, steps)
+ if data_format is "channels_first".
+ kernel: the unshared weight for convolution,
+ with shape (output_length, feature_dim, filters).
+ kernel_size: a tuple of a single integer,
+ specifying the length of the 1D convolution window.
+ strides: a tuple of a single integer,
+ specifying the stride length of the convolution.
+ data_format: the data format, channels_first or channels_last.
+
+ Returns:
+ A 3d tensor with shape:
+ (batch_size, output_length, filters)
+ if data_format='channels_first'
+ or 3D tensor with shape:
+ (batch_size, filters, output_length)
+ if data_format='channels_last'.
+ """
+ output_shape = (kernel.shape[0],)
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
def local_conv2d(inputs,
@@ -4289,64 +4387,34 @@ def local_conv2d(inputs,
Arguments:
inputs: 4D tensor with shape:
- (batch_size, filters, new_rows, new_cols)
- if data_format='channels_first'
- or 4D tensor with shape:
- (batch_size, new_rows, new_cols, filters)
- if data_format='channels_last'.
+ (batch_size, filters, new_rows, new_cols)
+ if data_format='channels_first'
+ or 4D tensor with shape:
+ (batch_size, new_rows, new_cols, filters)
+ if data_format='channels_last'.
kernel: the unshared weight for convolution,
- with shape (output_items, feature_dim, filters)
+ with shape (output_items, feature_dim, filters).
kernel_size: a tuple of 2 integers, specifying the
- width and height of the 2D convolution window.
+ width and height of the 2D convolution window.
strides: a tuple of 2 integers, specifying the strides
- of the convolution along the width and height.
- output_shape: a tuple with (output_row, output_col)
- data_format: the data format, channels_first or channels_last
+ of the convolution along the width and height.
+ output_shape: a tuple with (output_row, output_col).
+ data_format: the data format, channels_first or channels_last.
Returns:
- A 4d tensor with shape:
+ A 4D tensor with shape:
(batch_size, filters, new_rows, new_cols)
if data_format='channels_first'
or 4D tensor with shape:
(batch_size, new_rows, new_cols, filters)
if data_format='channels_last'.
-
- Raises:
- ValueError: if `data_format` is neither
- `channels_last` or `channels_first`.
"""
- if data_format is None:
- data_format = image_data_format()
- if data_format not in {'channels_first', 'channels_last'}:
- raise ValueError('Unknown data_format: ' + str(data_format))
-
- stride_row, stride_col = strides
- output_row, output_col = output_shape
- kernel_shape = int_shape(kernel)
- feature_dim = kernel_shape[1]
- filters = kernel_shape[2]
-
- xs = []
- for i in range(output_row):
- for j in range(output_col):
- slice_row = slice(i * stride_row, i * stride_row + kernel_size[0])
- slice_col = slice(j * stride_col, j * stride_col + kernel_size[1])
- if data_format == 'channels_first':
- xs.append(
- reshape(inputs[:, :, slice_row, slice_col], (1, -1, feature_dim)))
- else:
- xs.append(
- reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim)))
-
- x_aggregate = concatenate(xs, axis=0)
- output = batch_dot(x_aggregate, kernel)
- output = reshape(output, (output_row, output_col, -1, filters))
-
- if data_format == 'channels_first':
- output = permute_dimensions(output, (2, 3, 0, 1))
- else:
- output = permute_dimensions(output, (2, 0, 1, 3))
- return output
+ return local_conv(inputs,
+ kernel,
+ kernel_size,
+ strides,
+ output_shape,
+ data_format)
@tf_export('keras.backend.bias_add')
@@ -4704,8 +4772,13 @@ def foldr(fn, elems, initializer=None, name=None):
# Load Keras default configuration from config file if present.
-_keras_base_dir = os.path.expanduser('~')
-_keras_dir = os.path.join(_keras_base_dir, '.keras')
+# Set Keras base dir path given KERAS_HOME env variable, if applicable.
+# Otherwise either ~/.keras or /tmp.
+if 'KERAS_HOME' in os.environ:
+ _keras_dir = os.environ.get('KERAS_HOME')
+else:
+ _keras_base_dir = os.path.expanduser('~')
+ _keras_dir = os.path.join(_keras_base_dir, '.keras')
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 58df263a4f..36478ea089 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -17,10 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
import scipy.sparse
from tensorflow.python import keras
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -274,6 +276,36 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
+ def test_function_fetch_callbacks(self):
+
+ class CallbackStub(object):
+
+ def __init__(self):
+ self.times_called = 0
+ self.callback_result = 0
+
+ def _fetch_callback(self, result):
+ self.times_called += 1
+ self.callback_result = result
+
+ with self.test_session():
+ callback = CallbackStub()
+ x_placeholder = keras.backend.placeholder(shape=())
+ y_placeholder = keras.backend.placeholder(shape=())
+
+ callback_op = x_placeholder * y_placeholder
+
+ f = keras.backend.function(
+ inputs=[x_placeholder, y_placeholder],
+ outputs=[x_placeholder + y_placeholder])
+ f.fetches.append(callback_op)
+ f.fetch_callbacks[callback_op] = callback._fetch_callback
+
+ _ = f([10., 20.])
+
+ self.assertEqual(callback.times_called, 1)
+ self.assertEqual(callback.callback_result, 200)
+
class BackendVariableTest(test.TestCase):
@@ -661,7 +693,7 @@ class BackendShapeOpsTest(test.TestCase):
np_kwargs={'data_format': 'channels_first'})
-class BackendNNOpsTest(test.TestCase):
+class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
def test_bias_add(self):
with self.test_session():
@@ -810,6 +842,118 @@ class BackendNNOpsTest(test.TestCase):
padding='same', data_format='channels_last')
self.assertEqual(y.get_shape().as_list(), [10, 5, 5])
+ def test_local_conv_channels_dim(self):
+ filters = 3
+ batch_size = 2
+
+ for input_shape in [(3, 5), (2, 3, 5), (2, 5, 3, 4)]:
+ channels_in = input_shape[0]
+ input_spatial_shape = input_shape[1:]
+ dim = len(input_spatial_shape)
+
+ inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
+ inputs_cf = keras.backend.variable(inputs)
+
+ for kernel_size in [1, 2]:
+ for stride in [1, 2]:
+ kernel_sizes = (kernel_size,) * dim
+ strides = (stride,) * dim
+
+ output_shape = tuple([(i - kernel_size + stride) // stride
+ for i in input_spatial_shape])
+
+ kernel_shape = (np.prod(output_shape),
+ np.prod(kernel_sizes) * channels_in,
+ filters)
+
+ kernel = np.random.normal(
+ 0,
+ 1,
+ output_shape + (channels_in, np.prod(kernel_sizes), filters)
+ )
+
+ kernel_cf = np.reshape(kernel, kernel_shape)
+ kernel_cf = keras.backend.variable(kernel_cf)
+
+ conv_cf = keras.backend.local_conv(inputs_cf,
+ kernel_cf,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_first')
+
+ inputs_cl = np.transpose(inputs, [0, 2] + list(range(3, dim + 2)) +
+ [1])
+ inputs_cl = keras.backend.variable(inputs_cl)
+
+ kernel_cl = np.reshape(
+ np.transpose(kernel, list(range(dim)) + [dim + 1, dim, dim + 2]),
+ kernel_shape
+ )
+ kernel_cl = keras.backend.variable(kernel_cl)
+
+ conv_cl = keras.backend.local_conv(inputs_cl,
+ kernel_cl,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_last')
+ with self.test_session():
+ conv_cf = keras.backend.eval(conv_cf)
+ conv_cl = keras.backend.eval(conv_cl)
+
+ self.assertAllCloseAccordingToType(
+ conv_cf,
+ np.transpose(conv_cl,
+ [0, dim + 1] + list(range(1, dim + 1))),
+ atol=1e-5
+ )
+
+ @parameterized.named_parameters(
+ ('local_conv1d', (5, 6), (3,), (1,), (3,)),
+ ('local_conv2d', (4, 5, 6), (3, 3), (1, 1), (2, 3)))
+ def test_local_conv_1d_and_2d(self,
+ input_shape,
+ kernel_sizes,
+ strides,
+ output_shape):
+ filters = 3
+ batch_size = 2
+
+ inputs = np.random.normal(0, 1, (batch_size,) + input_shape)
+ inputs = keras.backend.variable(inputs)
+
+ kernel = np.random.normal(0, 1, (np.prod(output_shape),
+ np.prod(kernel_sizes) * input_shape[-1],
+ filters))
+ kernel = keras.backend.variable(kernel)
+
+ local_conv = keras.backend.local_conv(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_last')
+ if len(output_shape) == 1:
+ local_conv_dim = keras.backend.local_conv1d(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ 'channels_last')
+ else:
+ local_conv_dim = keras.backend.local_conv2d(inputs,
+ kernel,
+ kernel_sizes,
+ strides,
+ output_shape,
+ 'channels_last')
+
+ with self.test_session():
+ local_conv = keras.backend.eval(local_conv)
+ local_conv_dim = keras.backend.eval(local_conv_dim)
+
+ self.assertAllCloseAccordingToType(local_conv, local_conv_dim)
+
def test_conv2d(self):
val = np.random.random((10, 4, 10, 10))
x = keras.backend.variable(val)
@@ -963,7 +1107,7 @@ class BackendNNOpsTest(test.TestCase):
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
with self.test_session():
- for (i, kwargs) in enumerate(kwargs_list):
+ for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
**kwargs)
@@ -1010,6 +1154,115 @@ class BackendNNOpsTest(test.TestCase):
for b_s, b_u_s in zip(state_list[2], state_list[3]):
self.assertAllClose(b_s, b_u_s, atol=1e-04)
+ def test_rnn_additional_states(self):
+ # implement a simple RNN
+ num_samples = 4
+ input_dim = 5
+ output_dim = 3
+ timesteps = 6
+
+ input_val = np.random.random(
+ (num_samples, timesteps, input_dim)).astype(np.float32)
+ init_state_val = np.random.random(
+ (num_samples, output_dim)).astype(np.float32)
+ w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
+ w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
+ np_mask = np.random.randint(2, size=(num_samples, timesteps))
+
+ def rnn_step_fn():
+ w_i = keras.backend.variable(w_i_val)
+ w_o = keras.backend.variable(w_o_val)
+
+ def step_function(x, states):
+ assert len(states) == 2
+ prev_output = states[0]
+ output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
+ return output, [output,
+ keras.backend.concatenate([output, output], axis=-1)]
+
+ return step_function
+
+ # test default setup
+ last_output_list = [[], [], [], [], [], []]
+ outputs_list = [[], [], [], [], [], []]
+ state_list = [[], [], [], [], [], []]
+ additional_state_list = [[], [], [], [], [], []]
+
+ rnn_fn = rnn_step_fn()
+ inputs = keras.backend.variable(input_val)
+ initial_states = [keras.backend.variable(init_state_val),
+ np.concatenate([init_state_val, init_state_val], axis=-1)]
+ mask = keras.backend.variable(np_mask)
+
+ kwargs_list = [
+ {'go_backwards': False, 'mask': None},
+ {'go_backwards': False, 'mask': None, 'unroll': True},
+ {'go_backwards': True, 'mask': None},
+ {'go_backwards': True, 'mask': None, 'unroll': True},
+ {'go_backwards': False, 'mask': mask},
+ {'go_backwards': False, 'mask': mask, 'unroll': True},
+ ]
+ with self.test_session():
+ for i, kwargs in enumerate(kwargs_list):
+ last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
+ initial_states,
+ **kwargs)
+ # check static shape inference
+ self.assertEqual(last_output.get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEqual(outputs.get_shape().as_list(),
+ [num_samples, timesteps, output_dim])
+ # for state in new_states:
+ # self.assertEquals(state.get_shape().as_list(),
+ # [num_samples, output_dim])
+ self.assertEqual(new_states[0].get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEqual(new_states[1].get_shape().as_list(),
+ [num_samples, 2 * output_dim])
+
+ last_output_list[i].append(keras.backend.eval(last_output))
+ outputs_list[i].append(keras.backend.eval(outputs))
+ self.assertEqual(len(new_states), 2)
+ state_list[i].append(keras.backend.eval(new_states[0]))
+ additional_state_list[i].append(keras.backend.eval(new_states[1]))
+
+ def assert_list_pairwise(z_list, atol=1e-05):
+ for (z1, z2) in zip(z_list[1:], z_list[:-1]):
+ self.assertAllClose(z1, z2, atol=atol)
+
+ assert_list_pairwise(last_output_list[0], atol=1e-04)
+ assert_list_pairwise(outputs_list[0], atol=1e-04)
+ assert_list_pairwise(state_list[0], atol=1e-04)
+ assert_list_pairwise(additional_state_list[0], atol=1e-04)
+ assert_list_pairwise(last_output_list[2], atol=1e-04)
+ assert_list_pairwise(outputs_list[2], atol=1e-04)
+ assert_list_pairwise(state_list[2], atol=1e-04)
+ assert_list_pairwise(additional_state_list[2], atol=1e-04)
+
+ for l, u_l in zip(last_output_list[0], last_output_list[1]):
+ self.assertAllClose(l, u_l, atol=1e-04)
+
+ for o, u_o in zip(outputs_list[0], outputs_list[1]):
+ self.assertAllClose(o, u_o, atol=1e-04)
+
+ for s, u_s in zip(state_list[0], state_list[1]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
+ for s, u_s in zip(additional_state_list[0], additional_state_list[1]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
+ for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
+ self.assertAllClose(b_l, b_u_l, atol=1e-04)
+
+ for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
+ self.assertAllClose(b_o, b_u_o, atol=1e-04)
+
+ for b_s, b_u_s in zip(state_list[2], state_list[3]):
+ self.assertAllClose(b_s, b_u_s, atol=1e-04)
+
+ for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
def test_normalize_batch_in_training(self):
val = np.random.random((10, 3, 10, 10))
x = keras.backend.variable(val)
@@ -1165,6 +1418,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.max(y), 2., atol=0.1)
self.assertAllClose(np.min(y), -2., atol=0.1)
+ def test_string_input(self):
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index a6dbe2ba71..5d66db232a 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -24,6 +24,7 @@ from collections import Iterable
from collections import OrderedDict
import csv
import json
+import math
import os
import time
@@ -31,8 +32,10 @@ import numpy as np
import six
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.util.tf_export import tf_export
@@ -424,7 +427,7 @@ class ModelCheckpoint(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('ModelCheckpoint mode %s is unknown, '
- 'fallback to auto mode.', (mode), RuntimeWarning)
+ 'fallback to auto mode.', mode)
mode = 'auto'
if mode == 'min':
@@ -451,7 +454,7 @@ class ModelCheckpoint(Callback):
current = logs.get(self.monitor)
if current is None:
logging.warning('Can save best model only with %s available, '
- 'skipping.', self.monitor, RuntimeWarning)
+ 'skipping.', self.monitor)
else:
if self.monitor_op(current, self.best):
if self.verbose > 0:
@@ -496,6 +499,9 @@ class EarlyStopping(Callback):
monitored has stopped increasing; in `auto`
mode, the direction is automatically inferred
from the name of the monitored quantity.
+ baseline: baseline value for the monitored quantity.
+ Training will stop if the model doesn't show improvement over the
+ baseline.
"""
def __init__(self,
@@ -503,19 +509,21 @@ class EarlyStopping(Callback):
min_delta=0,
patience=0,
verbose=0,
- mode='auto'):
+ mode='auto',
+ baseline=None):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.patience = patience
self.verbose = verbose
- self.min_delta = min_delta
+ self.baseline = baseline
+ self.min_delta = abs(min_delta)
self.wait = 0
self.stopped_epoch = 0
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
- 'fallback to auto mode.', mode, RuntimeWarning)
+ 'fallback to auto mode.', mode)
mode = 'auto'
if mode == 'min':
@@ -537,14 +545,17 @@ class EarlyStopping(Callback):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
- self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+ if self.baseline is not None:
+ self.best = self.baseline
+ else:
+ self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if current is None:
logging.warning('Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s',
- self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
+ self.monitor, ','.join(list(logs.keys())))
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
@@ -633,13 +644,35 @@ class LearningRateScheduler(Callback):
self.verbose = verbose
def on_epoch_begin(self, epoch, logs=None):
- if not hasattr(self.model.optimizer, 'lr'):
- raise ValueError('Optimizer must have a "lr" attribute.')
- lr = self.schedule(epoch)
+ # TODO(yashkatariya): Change the property checking when the learning
+ # rate attribute is unified across all TF Optimizers.
+ if isinstance(self.model.optimizer, optimizers.TFOptimizer):
+ if not hasattr(self.model.optimizer.optimizer, '_lr') and not hasattr(
+ self.model.optimizer.optimizer, '_learning_rate'):
+ raise ValueError(
+ 'TF Optimizer must have a "_lr" or "_learning_rate" attribute.')
+ else:
+ opt = self.model.optimizer.optimizer
+ if hasattr(opt, '_lr'):
+ opt_lr = Variable(opt._lr) # pylint: disable=protected-access
+ elif hasattr(opt, '_learning_rate'):
+ opt_lr = Variable(opt._learning_rate) # pylint: disable=protected-access
+ else:
+ if not hasattr(self.model.optimizer, 'lr'):
+ raise ValueError('Optimizer must have a "lr" attribute.')
+ else:
+ opt = self.model.optimizer
+ opt_lr = opt.lr
+
+ try: # new API
+ lr = float(K.get_value(opt_lr))
+ lr = self.schedule(epoch, lr)
+ except TypeError: # Support for old API for backward compatibility
+ lr = self.schedule(epoch)
if not isinstance(lr, (float, np.float32, np.float64)):
raise ValueError('The output of the "schedule" function '
'should be float.')
- K.set_value(self.model.optimizer.lr, lr)
+ K.set_value(opt_lr, lr)
if self.verbose > 0:
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
@@ -711,10 +744,16 @@ class TensorBoard(Callback):
self.write_grads = write_grads
self.write_images = write_images
self.batch_size = batch_size
+ self._current_batch = 0
+ # abstracted writer class to be able to stub for testing
+ self._writer_class = tf_summary.FileWriter
def set_model(self, model):
+ """Sets Keras model and creates summary ops."""
+
self.model = model
self.sess = K.get_session()
+ # only make histogram summary op if it hasn't already been made
if self.histogram_freq and self.merged is None:
for layer in self.model.layers:
for weight in layer.weights:
@@ -763,54 +802,56 @@ class TensorBoard(Callback):
self.merged = tf_summary.merge_all()
if self.write_graph:
- self.writer = tf_summary.FileWriter(self.log_dir, self.sess.graph)
+ self.writer = self._writer_class(self.log_dir, self.sess.graph)
else:
- self.writer = tf_summary.FileWriter(self.log_dir)
+ self.writer = self._writer_class(self.log_dir)
+
+ def _fetch_callback(self, summary):
+ self.writer.add_summary(
+ summary,
+ self._epoch + self._current_val_batch / self._validation_batches)
+ self._current_val_batch += 1
+
+ def on_train_begin(self, logs=None):
+ """Checks if histogram summaries can be run."""
+
+ if self.histogram_freq:
+ if 'validation_steps' in self.params:
+ self._validation_batches = self.params['validation_steps']
+ elif self.validation_data:
+ self._validation_batches = math.ceil(
+ self.validation_data[0].shape[0] / self.batch_size)
+ else:
+ raise ValueError('If printing histograms, validation data must be '
+ 'provided.')
+ if self._validation_batches == 0:
+ raise ValueError(
+ 'If printing histograms, validation data must have length > 0.')
+
+ def on_epoch_begin(self, epoch, logs=None):
+ """Add histogram op to Model test_function callbacks, reset batch count."""
+
+ # check if histogram summary should be run for this epoch
+ if self.histogram_freq and epoch % self.histogram_freq == 0:
+ self._epoch = epoch
+ self._current_val_batch = 0
+ # add the histogram summary op if it should run this epoch
+ if self.merged not in self.model.test_function.fetches:
+ self.model.test_function.fetches.append(self.merged)
+ self.model.test_function.fetch_callbacks[
+ self.merged] = self._fetch_callback
def on_epoch_end(self, epoch, logs=None):
+ """Checks if summary ops should run next epoch, logs scalar summaries."""
+
logs = logs or {}
- if not self.validation_data and self.histogram_freq:
- raise ValueError('If printing histograms, validation_data must be '
- 'provided, and cannot be a generator.')
- if self.validation_data and self.histogram_freq:
- if epoch % self.histogram_freq == 0:
-
- val_data = self.validation_data
- tensors = (
- self.model.inputs + self.model.targets + self.model.sample_weights)
-
- if self.model.uses_learning_phase:
- tensors += [K.learning_phase()]
-
- assert len(val_data) == len(tensors)
- val_size = val_data[0].shape[0]
- i = 0
- while i < val_size:
- step = min(self.batch_size, val_size - i)
- batch_val = []
- batch_val.append(val_data[0][i:i + step]
- if val_data[0] is not None else None)
- batch_val.append(val_data[1][i:i + step]
- if val_data[1] is not None else None)
- batch_val.append(val_data[2][i:i + step]
- if val_data[2] is not None else None)
- if self.model.uses_learning_phase:
- # do not slice the learning phase
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data[:-1]]
- batch_val.append(val_data[-1])
- else:
- batch_val = [x[i:i + step] if x is not None else None
- for x in val_data]
- feed_dict = {}
- for key, val in zip(tensors, batch_val):
- if val is not None:
- feed_dict[key] = val
- result = self.sess.run([self.merged], feed_dict=feed_dict)
- summary_str = result[0]
- self.writer.add_summary(summary_str, epoch)
- i += self.batch_size
+ # pop the histogram summary op after each epoch
+ if self.histogram_freq:
+ if self.merged in self.model.test_function.fetches:
+ self.model.test_function.fetches.remove(self.merged)
+ if self.merged in self.model.test_function.fetch_callbacks:
+ self.model.test_function.fetch_callbacks.pop(self.merged)
for name, value in logs.items():
if name in ['batch', 'size']:
@@ -901,7 +942,7 @@ class ReduceLROnPlateau(Callback):
"""
if self.mode not in ['auto', 'min', 'max']:
logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
- 'fallback to auto mode.', self.mode, RuntimeWarning)
+ 'fallback to auto mode.', self.mode)
self.mode = 'auto'
if (self.mode == 'min' or
(self.mode == 'auto' and 'acc' not in self.monitor)):
@@ -923,7 +964,7 @@ class ReduceLROnPlateau(Callback):
if current is None:
logging.warning('Reduce LR on plateau conditioned on metric `%s` '
'which is not available. Available metrics are: %s',
- self.monitor, ','.join(list(logs.keys())), RuntimeWarning)
+ self.monitor, ','.join(list(logs.keys())))
else:
if self.in_cooldown():
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index eb40fb4acc..244d48591c 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -27,11 +27,18 @@ import unittest
import numpy as np
+from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training.adam import AdamOptimizer
+from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
+
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -273,16 +280,43 @@ class KerasCallbacksTest(test.TestCase):
1, activation='sigmoid'),))
model.compile(
optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
- stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience)
weights = model.get_weights()
+ stopper = keras.callbacks.EarlyStopping(monitor='acc', patience=patience)
hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
assert len(hist.epoch) >= patience
# This should allow training to go for at least `patience` epochs
model.set_weights(weights)
hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
- assert len(hist.epoch) >= patience
+ assert len(hist.epoch) >= patience
+
+ def test_EarlyStopping_with_baseline(self):
+ with self.test_session():
+ np.random.seed(1337)
+ baseline = 0.5
+ (data, labels), _ = testing_utils.get_test_data(
+ train_samples=100,
+ test_samples=50,
+ input_shape=(1,),
+ num_classes=NUM_CLASSES)
+ model = keras.models.Sequential((keras.layers.Dense(
+ 1, input_dim=1, activation='relu'), keras.layers.Dense(
+ 1, activation='sigmoid'),))
+ model.compile(
+ optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
+
+ stopper = keras.callbacks.EarlyStopping(monitor='acc',
+ baseline=baseline)
+ hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
+ assert len(hist.epoch) == 1
+
+ patience = 3
+ stopper = keras.callbacks.EarlyStopping(monitor='acc',
+ patience=patience,
+ baseline=baseline)
+ hist = model.fit(data, labels, callbacks=[stopper], verbose=0, epochs=20)
+ assert len(hist.epoch) >= patience
def test_RemoteMonitor(self):
if requests is None:
@@ -321,8 +355,96 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=5,
verbose=0)
- assert (float(keras.backend.get_value(model.optimizer.lr)) - 0.2
- ) < keras.backend.epsilon()
+ assert (
+ float(keras.backend.get_value(
+ model.optimizer.lr)) - 0.2) < keras.backend.epsilon()
+
+ cbks = [keras.callbacks.LearningRateScheduler(lambda x, lr: lr / 2)]
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=2,
+ verbose=0)
+ assert (
+ float(keras.backend.get_value(
+ model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_TF_LearningRateScheduler_Adam(self):
+ with self.test_session():
+ with context.eager_mode():
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer=AdamOptimizer(),
+ metrics=['accuracy'])
+ cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=5,
+ verbose=0)
+ opt_lr = model.optimizer.optimizer._lr
+ self.assertLess(
+ float(keras.backend.get_value(
+ Variable(opt_lr))) - 0.2, keras.backend.epsilon())
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_TF_LearningRateScheduler_GradientDescent(self):
+ with self.test_session():
+ with context.eager_mode():
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer=GradientDescentOptimizer(1e-3),
+ metrics=['accuracy'])
+ cbks = [keras.callbacks.LearningRateScheduler(lambda x: 1. / (1. + x))]
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=5,
+ verbose=0)
+ opt_lr = model.optimizer.optimizer._learning_rate
+ self.assertLess(
+ float(keras.backend.get_value(
+ Variable(opt_lr))) - 0.2, keras.backend.epsilon())
def test_ReduceLROnPlateau(self):
with self.test_session():
@@ -767,21 +889,6 @@ class KerasCallbacksTest(test.TestCase):
for cb in cbs:
cb.on_train_end()
- # fit generator with validation data generator should raise ValueError if
- # histogram_freq > 0
- cbs = callbacks_factory(histogram_freq=1)
- with self.assertRaises(ValueError):
- model.fit_generator(
- data_generator(True),
- len(x_train),
- epochs=2,
- validation_data=data_generator(False),
- validation_steps=1,
- callbacks=cbs)
-
- for cb in cbs:
- cb.on_train_end()
-
# Make sure file writer cache is clear to avoid failures during cleanup.
writer_cache.FileWriterCache.clear()
@@ -856,6 +963,130 @@ class KerasCallbacksTest(test.TestCase):
callbacks=callbacks_factory(histogram_freq=1))
assert os.path.isdir(filepath)
+ def test_Tensorboard_histogram_summaries_in_test_function(self):
+
+ class FileWriterStub(object):
+
+ def __init__(self, logdir, graph=None):
+ self.logdir = logdir
+ self.graph = graph
+ self.steps_seen = []
+
+ def add_summary(self, summary, global_step):
+ summary_obj = summary_pb2.Summary()
+
+ # ensure a valid Summary proto is being sent
+ if isinstance(summary, bytes):
+ summary_obj.ParseFromString(summary)
+ else:
+ assert isinstance(summary, summary_pb2.Summary)
+ summary_obj = summary
+
+ # keep track of steps seen for the merged_summary op,
+ # which contains the histogram summaries
+ if len(summary_obj.value) > 1:
+ self.steps_seen.append(global_step)
+
+ def flush(self):
+ pass
+
+ def close(self):
+ pass
+
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=TRAIN_SAMPLES,
+ test_samples=TEST_SAMPLES,
+ input_shape=(INPUT_DIM,),
+ num_classes=NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test)
+ y_train = keras.utils.to_categorical(y_train)
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Dense(
+ NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
+ # non_trainable_weights: moving_variance, moving_mean
+ model.add(keras.layers.BatchNormalization())
+ model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ tsb._writer_class = FileWriterStub
+ cbks = [tsb]
+
+ # fit with validation data
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=3,
+ verbose=0)
+
+ self.assertAllEqual(tsb.writer.steps_seen, [0, 0.5, 1, 1.5, 2, 2.5])
+
+ def test_Tensorboard_histogram_summaries_with_generator(self):
+ np.random.seed(1337)
+ tmpdir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, tmpdir)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_dim=100, activation='relu'))
+ model.add(keras.layers.Dense(10, activation='softmax'))
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ tsb = keras.callbacks.TensorBoard(
+ log_dir=tmpdir,
+ histogram_freq=1,
+ write_images=True,
+ write_grads=True,
+ batch_size=5)
+ cbks = [tsb]
+
+ # fit with validation generator
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=cbks,
+ verbose=0)
+
+ with self.assertRaises(ValueError):
+ # fit with validation generator but no
+ # validation_steps
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=2,
+ validation_data=generator(),
+ callbacks=cbks,
+ verbose=0)
+
+ self.assertTrue(os.path.exists(tmpdir))
+
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
diff --git a/tensorflow/python/keras/datasets/boston_housing.py b/tensorflow/python/keras/datasets/boston_housing.py
index 8c043638c0..eeb7cbc44a 100644
--- a/tensorflow/python/keras/datasets/boston_housing.py
+++ b/tensorflow/python/keras/datasets/boston_housing.py
@@ -39,15 +39,15 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
assert 0 <= test_split < 1
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/keras-datasets/boston_housing.npz',
+ origin=origin_folder + 'boston_housing.npz',
file_hash=
'f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5')
- f = np.load(path)
- x = f['x']
- y = f['y']
- f.close()
+ with np.load(path) as f:
+ x = f['x']
+ y = f['y']
np.random.seed(seed)
indices = np.arange(len(x))
diff --git a/tensorflow/python/keras/datasets/fashion_mnist.py b/tensorflow/python/keras/datasets/fashion_mnist.py
index 45e27aad34..3f4c6c7413 100644
--- a/tensorflow/python/keras/datasets/fashion_mnist.py
+++ b/tensorflow/python/keras/datasets/fashion_mnist.py
@@ -33,9 +33,15 @@ def load_data():
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
+
+ License:
+ The copyright for Fashion-MNIST is held by Zalando SE.
+ Fashion-MNIST is licensed under the [MIT license](
+ https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).
+
"""
dirname = os.path.join('datasets', 'fashion-mnist')
- base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
+ base = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
files = [
'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
diff --git a/tensorflow/python/keras/datasets/imdb.py b/tensorflow/python/keras/datasets/imdb.py
index 411b3e8635..b73b024162 100644
--- a/tensorflow/python/keras/datasets/imdb.py
+++ b/tensorflow/python/keras/datasets/imdb.py
@@ -77,9 +77,10 @@ def load_data(path='imdb.npz',
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb.npz',
+ origin=origin_folder + 'imdb.npz',
file_hash='599dadb1135973df5b59232a0e9a887c')
with np.load(path) as f:
x_train, labels_train = f['x_train'], f['y_train']
@@ -140,9 +141,10 @@ def get_word_index(path='imdb_word_index.json'):
Returns:
The word index dictionary.
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json',
+ origin=origin_folder + 'imdb_word_index.json',
file_hash='bfafd718b763782e994055a2d397834f')
with open(path) as f:
return json.load(f)
diff --git a/tensorflow/python/keras/datasets/mnist.py b/tensorflow/python/keras/datasets/mnist.py
index 631189731a..a96b581960 100644
--- a/tensorflow/python/keras/datasets/mnist.py
+++ b/tensorflow/python/keras/datasets/mnist.py
@@ -34,13 +34,21 @@ def load_data(path='mnist.npz'):
Returns:
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
+
+ License:
+ Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,
+ which is a derivative work from original NIST datasets.
+ MNIST dataset is made available under the terms of the
+ [Creative Commons Attribution-Share Alike 3.0 license.](
+ https://creativecommons.org/licenses/by-sa/3.0/)
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
+ origin=origin_folder + 'mnist.npz',
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
- f = np.load(path)
- x_train, y_train = f['x_train'], f['y_train']
- x_test, y_test = f['x_test'], f['y_test']
- f.close()
- return (x_train, y_train), (x_test, y_test)
+ with np.load(path) as f:
+ x_train, y_train = f['x_train'], f['y_train']
+ x_test, y_test = f['x_test'], f['y_test']
+
+ return (x_train, y_train), (x_test, y_test)
diff --git a/tensorflow/python/keras/datasets/reuters.py b/tensorflow/python/keras/datasets/reuters.py
index b070ba8d12..cb796bb06c 100644
--- a/tensorflow/python/keras/datasets/reuters.py
+++ b/tensorflow/python/keras/datasets/reuters.py
@@ -75,9 +75,10 @@ def load_data(path='reuters.npz',
if kwargs:
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/reuters.npz',
+ origin=origin_folder + 'reuters.npz',
file_hash='87aedbeb0cb229e378797a632c1997b6')
with np.load(path) as f:
xs, labels = f['x'], f['y']
@@ -124,11 +125,10 @@ def get_word_index(path='reuters_word_index.json'):
Returns:
The word index dictionary.
"""
+ origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
path = get_file(
path,
- origin='https://s3.amazonaws.com/text-datasets/reuters_word_index.json',
+ origin=origin_folder + 'reuters_word_index.json',
file_hash='4d44cc38712099c9e383dc6e5f11a921')
- f = open(path)
- data = json.load(f)
- f.close()
- return data
+ with open(path) as f:
+ return json.load(f)
diff --git a/tensorflow/python/keras/engine/__init__.py b/tensorflow/python/keras/engine/__init__.py
index ec7c083199..26aed34766 100644
--- a/tensorflow/python/keras/engine/__init__.py
+++ b/tensorflow/python/keras/engine/__init__.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# TODO(fchollet): Remove hourglass imports once external code is done importing
+# non-public APIs.
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
-from tensorflow.python.keras.engine.network import get_source_inputs
-from tensorflow.python.keras.engine.network import Network
-from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.keras.utils.layer_utils import get_source_inputs
del absolute_import
del division
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 4814275fd5..e02792208b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase):
constraints on inputs that can be accepted by the layer.
"""
+ @checkpointable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase):
@activity_regularizer.setter
def activity_regularizer(self, regularizer):
"""Optional regularizer function for the output of this layer."""
- self._activity_regularizer = regularizer
+ self._activity_regularizer = self._no_dependency(regularizer)
@property
def trainable_weights(self):
@@ -459,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
- def add_weight(self, name, shape,
+ def add_weight(self,
+ name,
+ shape,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
constraint=None,
partitioner=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
getter=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -481,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
getter: Variable getter argument to be passed to the `Checkpointable` API.
Returns:
@@ -495,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
- ValueError: When giving unsupported dtype and no initializer.
+ ValueError: When giving unsupported dtype and no initializer or when
+ trainable has been set to True with synchronization set as `ON_READ`.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
@@ -504,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
# Initialize variable when no initializer provided
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
@@ -531,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase):
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
@@ -654,11 +685,12 @@ class Layer(checkpointable.CheckpointableBase):
# Handle Keras mask propagation from previous layer to current layer.
previous_mask = None
- if (not hasattr(self, '_compute_previous_mask') or
- self._compute_previous_mask):
+ if build_graph and (not hasattr(self, '_compute_previous_mask') or
+ self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = function_utils.fn_args(self.call)
+ self._call_fn_args = self._no_dependency(
+ function_utils.fn_args(self.call))
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
@@ -691,9 +723,10 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
- if all(hasattr(x, 'get_shape') for x in input_list):
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ if all(hasattr(x, 'shape') for x in input_list):
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
self.build(input_shapes)
+ self.built = True
# Check input assumptions set after layer building, e.g. input shape.
if build_graph or in_deferred_mode:
@@ -709,7 +742,7 @@ class Layer(checkpointable.CheckpointableBase):
# Deferred mode behavior: use `compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
if input_shapes is None:
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
output_shapes = self.compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
@@ -729,8 +762,6 @@ class Layer(checkpointable.CheckpointableBase):
if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
inputs, outputs = self._set_connectivity_metadata_(
inputs, outputs, args, kwargs)
-
- self.built = True
if context.executing_eagerly():
return outputs
@@ -1293,7 +1324,7 @@ class Layer(checkpointable.CheckpointableBase):
', but the layer isn\'t built. '
'You can build it manually via: `' + self.name +
'.build(batch_input_shape)`.')
- weight_shapes = [w.get_shape().as_list() for w in self.weights]
+ weight_shapes = [w.shape.as_list() for w in self.weights]
return int(sum([np.prod(w) for w in weight_shapes]))
@property
@@ -1376,7 +1407,7 @@ class Layer(checkpointable.CheckpointableBase):
if (spec.ndim is not None or
spec.min_ndim is not None or
spec.max_ndim is not None):
- if x.get_shape().ndims is None:
+ if x.shape.ndims is None:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'its rank is undefined, but the layer requires a '
@@ -1384,29 +1415,29 @@ class Layer(checkpointable.CheckpointableBase):
# Check ndim.
if spec.ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim != spec.ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
str(ndim) + '. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
if spec.max_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim > spec.max_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected max_ndim=' + str(spec.max_ndim) +
', found ndim=' + str(ndim))
if spec.min_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim < spec.min_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
': expected min_ndim=' + str(spec.min_ndim) +
', found ndim=' + str(ndim) +
'. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
# Check dtype.
if spec.dtype is not None:
if x.dtype != spec.dtype:
@@ -1416,7 +1447,7 @@ class Layer(checkpointable.CheckpointableBase):
', found dtype=' + str(x.dtype))
# Check specific shape axes.
if spec.axes:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for axis, value in spec.axes.items():
if hasattr(value, 'value'):
@@ -1429,7 +1460,7 @@ class Layer(checkpointable.CheckpointableBase):
' but received input with shape ' + str(shape))
# Check shape.
if spec.shape is not None:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for spec_dim, dim in zip(spec.shape, shape):
if spec_dim is not None and dim is not None:
@@ -1704,12 +1735,12 @@ class DeferredTensor(object):
def __str__(self):
return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
def __repr__(self):
return "<DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
@@ -1804,11 +1835,13 @@ def make_variable(name,
dtype=dtypes.float32,
initializer=None,
partition_info=None,
- trainable=True,
+ trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1834,11 +1867,21 @@ def make_variable(name,
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
caching_device: Passed to `vs.variable`.
validate_shape: Passed to `vs.variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: Not handled at this time.
Returns:
@@ -1870,5 +1913,7 @@ def make_variable(name,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
return v
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index b04dc3c60b..8a4018a0df 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -119,6 +119,12 @@ class InputLayer(base_layer.Layer):
self.is_placeholder = False
self._batch_input_shape = tuple(input_tensor.get_shape().as_list())
+ if context.executing_eagerly():
+ raise ValueError('You should not pass an input tensor when executing '
+ 'in eager mode. For example, instead of creating an '
+ 'InputLayer, you should instantiate your model and '
+ 'directly call it on your input.')
+
# Create an input node to add to self.outbound_node
# and set output_tensors' _keras_history.
input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access
@@ -209,7 +215,7 @@ def Input( # pylint: disable=invalid-name
if dtype is None:
dtype = K.floatx()
- if not shape and tensor is None:
+ if shape is None and tensor is None:
raise ValueError('Please provide to Input either a `shape`'
' or a `tensor` argument. Note that '
'`shape` does not include the batch '
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 6f27eea1e7..a4d96de74f 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -43,7 +43,8 @@ from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import data_structures_base
+from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import layer_utils as checkpointable_layer_utils
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -80,6 +81,20 @@ class Network(base_layer.Layer):
# Subclassed network
self._init_subclassed_network(**kwargs)
+ # Several Network methods have "no_automatic_dependency_tracking"
+ # annotations. Since Network does automatic dependency tracking on attribute
+ # assignment, including for common data structures such as lists, by default
+ # we'd have quite a few empty dependencies which users don't care about (or
+ # would need some way to ignore dependencies automatically, which is confusing
+ # when applied to user code). Some attributes, such as _layers, would cause
+ # structural issues (_layers being the place where Layers assigned to tracked
+ # attributes are stored).
+ #
+ # Aside from these aesthetic and structural issues, useless dependencies on
+ # empty lists shouldn't cause issues; adding or removing them will not break
+ # checkpoints, but may cause "all Python objects matched" assertions to fail
+ # (in which case less strict assertions may be substituted if necessary).
+ @checkpointable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
@@ -134,6 +149,7 @@ class Network(base_layer.Layer):
# restore operations when graph building.
self._in_progress_restore_finalizer = None
+ @checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
@@ -292,6 +308,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
+ @checkpointable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
@@ -361,14 +378,35 @@ class Network(base_layer.Layer):
self._track_checkpointable(
layer, name='layer-%d' % layer_index, overwrite=True)
+ def _no_dependency(self, value):
+ """Override to allow `Layer` to disable dependency tracking.
+
+ `CheckpointableBase` defines this method, whose semantics are "if a subclass
+ does dependency tracking, this method exempts `value`." Layer uses
+ `_no_dependency` to exempt some of its attribute assignments (conditional on
+ attribute assignment causing tracking in the subclass).
+
+ Args:
+ value: An object which will be assigned to an object attribute, whose
+ value should not be tracked.
+
+ Returns:
+ A wrapped object which, when assigned to an attribute, will not be
+ tracked (`value` will be stored in the attribute).
+ """
+ return data_structures.NoDependency(value)
+
def __setattr__(self, name, value):
- no_dependency = isinstance(value, checkpointable.NoDependency)
- if no_dependency:
- value = value.value
+ if not getattr(self, '_setattr_tracking', True):
+ super(Network, self).__setattr__(name, value)
+ return
+ no_dependency = isinstance(value, data_structures.NoDependency)
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
if isinstance(value, (
base_layer.Layer,
Network,
- data_structures_base.CheckpointableDataStructureBase)):
+ data_structures.CheckpointableDataStructure)):
try:
is_graph_network = self._is_graph_network
except AttributeError:
@@ -376,7 +414,9 @@ class Network(base_layer.Layer):
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
if not is_graph_network:
- if value not in self._layers:
+ # We need to check object identity to avoid de-duplicating empty
+ # container types which compare equal.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_use_resource_variables'):
# In subclassed models, legacy layers (tf.layers) must always use
@@ -384,12 +424,6 @@ class Network(base_layer.Layer):
value._use_resource_variables = True
if (not no_dependency
and isinstance(value, checkpointable.CheckpointableBase)):
- # Layer (and therefore Network/Model) inherit from CheckpointableBase
- # rather than Checkpointable, which means there is no Checkpointable
- # __setattr__ override (it would be a performance issue for functional
- # layers). Therefore Model tracks Checkpointable objects itself.
- self._track_checkpointable(
- checkpointable=value, name=name, overwrite=True)
if ( # For subclassed models only, users may add extra weights/variables
# simply by assigning them to attributes.
not self._is_graph_network
@@ -492,7 +526,8 @@ class Network(base_layer.Layer):
@property
def layers(self):
- return self._layers
+ return checkpointable_layer_utils.filter_empty_layer_containers(
+ self._layers)
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
@@ -528,6 +563,28 @@ class Network(base_layer.Layer):
raise ValueError('No such layer: ' + name)
@property
+ def _unfiltered_updates(self):
+ if context.executing_eagerly():
+ return []
+ updates = []
+ for layer in self.layers:
+ if isinstance(layer, Network):
+ updates += layer._unfiltered_updates
+ else:
+ updates += layer.updates
+ return updates
+
+ @property
+ def _unfiltered_losses(self):
+ losses = []
+ for layer in self.layers:
+ if isinstance(layer, Network):
+ losses += layer._unfiltered_losses
+ else:
+ losses += layer.losses
+ return losses
+
+ @property
def updates(self):
"""Retrieves the network's updates.
@@ -536,6 +593,8 @@ class Network(base_layer.Layer):
(e.g. will not include updates that were created by layers of this model
outside of the model).
+ When the network has no registered inputs, all updates are returned.
+
Effectively, `network.updates` behaves like `layer.updates`.
Concrete example:
@@ -581,22 +640,20 @@ class Network(base_layer.Layer):
if not self.trainable and not self.stateful:
return []
- updates = []
- for layer in self.layers:
- updates += layer.updates
+ updates = self._unfiltered_updates
# `updates` might contain irrelevant updates, so it needs to be filtered
# with respect to inputs the model has been called on.
- if self.inputs:
- relevant_inputs = self.inputs[:]
- else:
- relevant_inputs = []
- for i in range(1, len(self._inbound_nodes)):
+ relevant_inputs = []
+ for i in range(0, len(self._inbound_nodes)):
inputs = self.get_input_at(i)
if isinstance(inputs, list):
relevant_inputs += inputs
else:
relevant_inputs.append(inputs)
+ if not relevant_inputs:
+ return updates
+
reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates)
relevant_conditional_updates = [x for x in updates if x in reachable]
unconditional_updates = [
@@ -615,25 +672,25 @@ class Network(base_layer.Layer):
(e.g. will not include losses that depend on tensors
that aren't inputs to this model).
+ When the network has no registered inputs, all losses are returned.
+
Returns:
A list of loss tensors.
"""
- losses = []
- for layer in self.layers:
- losses += layer.losses
+ losses = self._unfiltered_losses
if context.executing_eagerly():
return losses
- if self.inputs:
- relevant_inputs = self.inputs[:]
- else:
- relevant_inputs = []
- for i in range(1, len(self._inbound_nodes)):
+ relevant_inputs = []
+ for i in range(0, len(self._inbound_nodes)):
inputs = self.get_input_at(i)
if isinstance(inputs, list):
relevant_inputs += inputs
else:
relevant_inputs.append(inputs)
+ if not relevant_inputs:
+ return losses
+
reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses)
relevant_conditional_losses = [x for x in losses if x in reachable]
unconditional_losses = [
@@ -643,14 +700,14 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
- return layer_utils.gather_trainable_weights(
+ return checkpointable_layer_utils.gather_trainable_weights(
trainable=self.trainable,
sub_layers=self.layers,
extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- return layer_utils.gather_non_trainable_weights(
+ return checkpointable_layer_utils.gather_non_trainable_weights(
trainable=self.trainable,
sub_layers=self.layers,
extra_variables=self._extra_variables)
@@ -1463,7 +1520,8 @@ class Network(base_layer.Layer):
ImportError: if yaml module is not found.
"""
if yaml is None:
- raise ImportError('Requires yaml module installed.')
+ raise ImportError(
+ 'Requires yaml module installed (`pip install pyyaml`).')
return yaml.dump(self._updated_config(), **kwargs)
def summary(self, line_length=None, positions=None, print_fn=None):
@@ -1495,47 +1553,6 @@ class Network(base_layer.Layer):
print_fn=print_fn)
-def get_source_inputs(tensor, layer=None, node_index=None):
- """Returns the list of input tensors necessary to compute `tensor`.
-
- Output will always be a list of tensors
- (potentially with 1 element).
-
- Arguments:
- tensor: The tensor to start from.
- layer: Origin layer of the tensor. Will be
- determined via tensor._keras_history if not provided.
- node_index: Origin node index of the tensor.
-
- Returns:
- List of input tensors.
- """
- if not hasattr(tensor, '_keras_history'):
- return tensor
-
- if layer is None or node_index:
- layer, node_index, _ = tensor._keras_history
- if not layer._inbound_nodes:
- return [tensor]
- else:
- node = layer._inbound_nodes[node_index]
- if not node.inbound_layers:
- # Reached an Input layer, stop recursion.
- return node.input_tensors
- else:
- source_tensors = []
- for i in range(len(node.inbound_layers)):
- x = node.input_tensors[i]
- layer = node.inbound_layers[i]
- node_index = node.node_indices[i]
- previous_sources = get_source_inputs(x, layer, node_index)
- # Avoid input redundancy.
- for x in previous_sources:
- if x not in source_tensors:
- source_tensors.append(x)
- return source_tensors
-
-
def _is_hdf5_filepath(filepath):
return filepath.endswith('.h5') or filepath.endswith('.keras')
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index 99ce64a469..d5ccd44604 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -106,7 +106,7 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
model_layers = model.layers
save_weights_to_hdf5_group(model_weights_group, model_layers)
- if include_optimizer and hasattr(model, 'optimizer'):
+ if include_optimizer and model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
logging.warning(
'TensorFlow optimizers do not '
@@ -323,7 +323,7 @@ def model_from_yaml(yaml_string, custom_objects=None):
ImportError: if yaml module is not found.
"""
if yaml is None:
- raise ImportError('Requires yaml module installed.')
+ raise ImportError('Requires yaml module installed (`pip install pyyaml`).')
config = yaml.load(yaml_string)
from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
return deserialize(config, custom_objects=custom_objects)
@@ -351,7 +351,10 @@ def preprocess_weights_for_loading(layer,
weights,
original_keras_version=None,
original_backend=None):
- """Converts layers weights from Keras 1 format to Keras 2.
+ """Preprocess layer weights between different Keras formats.
+
+ Converts layers weights from Keras 1 format to Keras 2 and also weights of
+ CuDNN layers in Keras 2.
Arguments:
layer: Layer instance.
@@ -363,7 +366,18 @@ def preprocess_weights_for_loading(layer,
Returns:
A list of weights values (Numpy arrays).
"""
- if layer.__class__.__name__ == 'Bidirectional':
+ def convert_nested_bidirectional(weights):
+ """Converts layers nested in `Bidirectional` wrapper.
+
+ This function uses `preprocess_weights_for_loading()` for converting
+ layers.
+
+ Arguments:
+ weights: List of weights values (Numpy arrays).
+
+ Returns:
+ A list of weights values (Numpy arrays).
+ """
num_weights_per_layer = len(weights) // 2
forward_weights = preprocess_weights_for_loading(
layer.forward_layer, weights[:num_weights_per_layer],
@@ -371,7 +385,69 @@ def preprocess_weights_for_loading(layer,
backward_weights = preprocess_weights_for_loading(
layer.backward_layer, weights[num_weights_per_layer:],
original_keras_version, original_backend)
- weights = forward_weights + backward_weights
+ return forward_weights + backward_weights
+
+ def convert_nested_time_distributed(weights):
+ """Converts layers nested in `TimeDistributed` wrapper.
+
+ This function uses `preprocess_weights_for_loading()` for converting nested
+ layers.
+
+ Arguments:
+ weights: List of weights values (Numpy arrays).
+
+ Returns:
+ A list of weights values (Numpy arrays).
+ """
+ return preprocess_weights_for_loading(
+ layer.layer, weights, original_keras_version, original_backend)
+
+ def convert_nested_model(weights):
+ """Converts layers nested in `Model` or `Sequential`.
+
+ This function uses `preprocess_weights_for_loading()` for converting nested
+ layers.
+
+ Arguments:
+ weights: List of weights values (Numpy arrays).
+
+ Returns:
+ A list of weights values (Numpy arrays).
+ """
+ new_weights = []
+ # trainable weights
+ for sublayer in layer.layers:
+ num_weights = len(sublayer.trainable_weights)
+ if num_weights > 0:
+ new_weights.extend(preprocess_weights_for_loading(
+ layer=sublayer,
+ weights=weights[:num_weights],
+ original_keras_version=original_keras_version,
+ original_backend=original_backend))
+ weights = weights[num_weights:]
+
+ # non-trainable weights
+ for sublayer in layer.layers:
+ num_weights = len([l for l in sublayer.weights
+ if l not in sublayer.trainable_weights])
+ if num_weights > 0:
+ new_weights.extend(preprocess_weights_for_loading(
+ layer=sublayer,
+ weights=weights[:num_weights],
+ original_keras_version=original_keras_version,
+ original_backend=original_backend))
+ weights = weights[num_weights:]
+ return new_weights
+
+ # Convert layers nested in Bidirectional/Model/Sequential.
+ # Both transformation should be ran for both Keras 1->2 conversion
+ # and for conversion of CuDNN layers.
+ if layer.__class__.__name__ == 'Bidirectional':
+ weights = convert_nested_bidirectional(weights)
+ if layer.__class__.__name__ == 'TimeDistributed':
+ weights = convert_nested_time_distributed(weights)
+ elif layer.__class__.__name__ in ['Model', 'Sequential']:
+ weights = convert_nested_model(weights)
if original_keras_version == '1':
if layer.__class__.__name__ == 'TimeDistributed':
@@ -446,35 +522,6 @@ def preprocess_weights_for_loading(layer,
recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
weights = [kernel, recurrent_kernel, bias]
- if layer.__class__.__name__ in ['Model', 'Sequential']:
- new_weights = []
- # trainable weights
- for sublayer in layer.layers:
- num_weights = len(sublayer.trainable_weights)
- if num_weights > 0:
- new_weights.extend(
- preprocess_weights_for_loading(
- layer=sublayer,
- weights=weights[:num_weights],
- original_keras_version=original_keras_version,
- original_backend=original_backend))
- weights = weights[num_weights:]
-
- # non-trainable weights
- for sublayer in layer.layers:
- num_weights = len([
- l for l in sublayer.weights if l not in sublayer.trainable_weights
- ])
- if num_weights > 0:
- new_weights.extend(
- preprocess_weights_for_loading(
- layer=sublayer,
- weights=weights[:num_weights],
- original_keras_version=original_keras_version,
- original_backend=original_backend))
- weights = weights[num_weights:]
- weights = new_weights
-
conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
if layer.__class__.__name__ in conv_layers:
if original_backend == 'theano':
@@ -486,6 +533,7 @@ def preprocess_weights_for_loading(layer,
if layer.__class__.__name__ == 'ConvLSTM2D':
weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
+ # convert CuDNN layers
return _convert_rnn_weights(layer, weights)
@@ -624,7 +672,7 @@ def _convert_rnn_weights(layer, weights):
kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
n_gates)
recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
- biases = weights[2].reshape((2, -1) if from_cudnn else -1)
+ biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
return [kernels, recurrent_kernels, biases]
if bias_shape == (2 * units * n_gates,):
@@ -806,7 +854,16 @@ def load_weights_from_hdf5_group_by_name(f, layers):
str(len(weight_values)) + ' element(s).')
# Set values.
for i in range(len(weight_values)):
- weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
+ if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
+ raise ValueError('Layer #' + str(k) +' (named "' + layer.name +
+ '"), weight ' + str(symbolic_weights[i]) +
+ ' has shape {}'.format(K.int_shape(
+ symbolic_weights[i])) +
+ ', but the saved weight has shape ' +
+ str(weight_values[i].shape) + '.')
+
+ else:
+ weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
K.batch_set_value(weight_value_tuples)
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index b5448a9be1..030328f2a6 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import os
import shutil
import tempfile
-
from absl.testing import parameterized
import numpy as np
@@ -31,6 +30,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import training
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
@@ -248,6 +248,82 @@ class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
self.assertAllClose(y, ref_y)
+ def test_sequential_weight_loading_group_name_with_incorrect_length(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ num_classes = 2
+ with self.test_session():
+ ref_model = keras.models.Sequential()
+ ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
+ name='d1'))
+ ref_model.add(keras.layers.Dense(num_classes, name='d2'))
+ ref_model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+
+ f_ref_model = h5py.File(h5_path, 'w')
+ saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
+
+ f_model = h5py.File(h5_path, 'r')
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, use_bias=False,
+ input_dim=input_dim, name='d1'))
+ model.add(keras.layers.Dense(num_classes, name='d2'))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ with self.assertRaisesRegexp(ValueError,
+ r'Layer #0 \(named \"d1\"\) expects 1 '
+ r'weight\(s\), but the saved weights have 2 '
+ r'element\(s\)\.'):
+ saving.load_weights_from_hdf5_group_by_name(f_model, model.layers)
+
+ def test_sequential_weight_loading_group_name_with_incorrect_shape(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ num_classes = 2
+ with self.test_session():
+ ref_model = keras.models.Sequential()
+ ref_model.add(keras.layers.Dense(num_hidden, input_dim=input_dim,
+ name='d1'))
+ ref_model.add(keras.layers.Dense(num_classes, name='d2'))
+ ref_model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+
+ f_ref_model = h5py.File(h5_path, 'w')
+ saving.save_weights_to_hdf5_group(f_ref_model, ref_model.layers)
+
+ f_model = h5py.File(h5_path, 'r')
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden + 5, input_dim=input_dim,
+ name='d1'))
+ model.add(keras.layers.Dense(num_classes, name='d2'))
+ model.compile(loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(lr=0.0001),
+ metrics=[keras.metrics.categorical_accuracy])
+ with self.assertRaisesRegexp(ValueError,
+ r'Layer #0 \(named "d1"\), weight '
+ r'<tf\.Variable \'d1_1\/kernel:0\' '
+ r'shape=\(3, 10\) dtype=float32> has '
+ r'shape \(3, 10\), but the saved weight has '
+ r'shape \(3, 5\)\.'):
+ saving.load_weights_from_hdf5_group_by_name(f_model, model.layers)
+
class TestWholeModelSaving(test.TestCase):
@@ -288,6 +364,30 @@ class TestWholeModelSaving(test.TestCase):
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
+ def test_sequential_model_saving_without_compile(self):
+ if h5py is None:
+ self.skipTest('h5py required to run this test')
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+
+ x = np.random.random((1, 3))
+ out = model.predict(x)
+ fd, fname = tempfile.mkstemp('.h5')
+
+ # Save the model without any compilation or training.
+ keras.models.save_model(model, fname)
+
+ new_model = keras.models.load_model(fname)
+ os.close(fd)
+ os.remove(fname)
+
+ out2 = new_model.predict(x)
+ self.assertAllClose(out, out2, atol=1e-05)
+
def test_sequential_model_saving_2(self):
if h5py is None:
self.skipTest('h5py required to run this test')
@@ -563,7 +663,7 @@ class SubclassedModel(training.Model):
class TestWeightSavingAndLoadingTFFormat(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_tensorflow_format_overwrite(self):
with self.test_session() as session:
model = SubclassedModel()
@@ -652,7 +752,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
self.assertAllClose(ref_y, restore_on_create_y)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model(self):
def _make_graph_model():
a = keras.layers.Input(shape=(2,))
@@ -662,7 +762,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self._weight_loading_test_template(_make_graph_model)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_weight_loading_subclassed_model(self):
self._weight_loading_test_template(SubclassedModel)
@@ -696,7 +796,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
y = self.evaluate(model(x))
self.assertAllClose(ref_y, y)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model_added_layer(self):
def _save_graph_model():
a = keras.layers.Input(shape=(2,))
@@ -716,7 +816,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
_save_graph_model, _restore_graph_model,
_restore_init_fn)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model_added_no_weight_layer(self):
def _save_graph_model():
a = keras.layers.Input(shape=(2,))
@@ -737,7 +837,7 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
_save_graph_model, _restore_graph_model,
_restore_init_fn)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_weight_loading_subclassed_model_added_layer(self):
class SubclassedModelRestore(training.Model):
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index 52e29b0ffa..371504a503 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -24,11 +24,12 @@ import copy
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import layers as layer_module
from tensorflow.python.keras.engine import base_layer
-from tensorflow.python.keras.engine import network
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.engine.training import Model
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -108,6 +109,7 @@ class Sequential(Model):
return self._layers[1:]
return self._layers
+ @checkpointable.no_automatic_dependency_tracking
def add(self, layer):
"""Adds a layer instance on top of the layer stack.
@@ -146,8 +148,6 @@ class Sequential(Model):
first_layer = layer.layers[0]
while isinstance(first_layer, (Model, Sequential)):
first_layer = first_layer.layers[0]
- batch_shape = first_layer._batch_input_shape
- dtype = first_layer.dtype
if hasattr(first_layer, '_batch_input_shape'):
batch_shape = first_layer._batch_input_shape
@@ -179,7 +179,7 @@ class Sequential(Model):
'use the functional API.')
self.outputs = [layer._inbound_nodes[-1].output_tensors[0]]
- self.inputs = network.get_source_inputs(self.outputs[0])
+ self.inputs = layer_utils.get_source_inputs(self.outputs[0])
elif self.outputs:
output_tensor = layer(self.outputs[0])
if isinstance(output_tensor, list):
@@ -193,6 +193,7 @@ class Sequential(Model):
else:
self._layers.append(layer)
+ @checkpointable.no_automatic_dependency_tracking
def pop(self):
"""Removes the last layer in the model.
@@ -212,6 +213,7 @@ class Sequential(Model):
self.outputs = [self.layers[-1].output]
self.build()
+ @checkpointable.no_automatic_dependency_tracking
def build(self, input_shape=None):
if input_shape and not self.inputs:
batch_shape = tuple(input_shape)
@@ -222,11 +224,16 @@ class Sequential(Model):
for layer in self._layers:
x = layer(x)
self.outputs = [x]
+ # Make sure that the model's input shape will be preserved during
+ # serialization.
+ if self._layers:
+ self._layers[0]._batch_input_shape = batch_shape
if self.inputs:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self.built = True
- self._track_layers(self._layers)
+ if self._layers:
+ self._track_layers(self._layers)
def predict_proba(self, x, batch_size=32, verbose=0):
"""Generates class probability predictions for the input samples.
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 69a288e69b..0f54e29cee 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -33,7 +33,7 @@ class TestSequential(test.TestCase):
"""Most Sequential model API tests are covered in `training_test.py`.
"""
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_basic_methods(self):
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_dim=2))
@@ -44,7 +44,7 @@ class TestSequential(test.TestCase):
self.assertEqual(len(model.weights), 2 * 2)
self.assertEqual(model.get_layer(name='dp').name, 'dp')
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_sequential_pop(self):
num_hidden = 5
input_dim = 3
@@ -77,7 +77,7 @@ class TestSequential(test.TestCase):
with self.assertRaises(TypeError):
model.pop()
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_sequential_deferred_build_with_np_arrays(self):
num_hidden = 5
input_dim = 3
@@ -102,7 +102,7 @@ class TestSequential(test.TestCase):
[None, num_classes])
self.assertEqual(len(model.weights), 2 * 2)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_sequential_deferred_build_with_dataset_iterators(self):
if not context.executing_eagerly():
# TODO(psv/fchollet): Add support for this use case in graph mode.
@@ -136,7 +136,7 @@ class TestSequential(test.TestCase):
[None, num_classes])
self.assertEqual(len(model.weights), 2 * 2)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_invalid_use_cases(self):
# Added objects must be layer instances
with self.assertRaises(TypeError):
@@ -160,7 +160,7 @@ class TestSequential(test.TestCase):
model.add(keras.layers.Dense(1, input_dim=1))
model.add(MyLayer())
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_nested_sequential_trainability(self):
input_dim = 20
num_units = 10
@@ -209,6 +209,30 @@ class TestSequential(test.TestCase):
x2 = model.predict(val_a)
assert np.abs(np.sum(x1 - x2)) > 1e-5
+ def test_sequential_deferred_build_serialization(self):
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+
+ model = keras.models.Sequential()
+ # We don't specify the input shape.
+ model.add(keras.layers.Dense(num_hidden))
+ model.add(keras.layers.Dense(num_classes))
+ model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
+ self.assertFalse(model.built)
+
+ x = np.random.random((batch_size, input_dim))
+ y = np.random.random((batch_size, num_classes))
+ model.train_on_batch(x, y)
+ self.assertTrue(model.built)
+
+ config = model.get_config()
+ new_model = keras.models.Sequential.from_config(config)
+ self.assertTrue(new_model.built)
+ self.assertEqual(len(model.layers), 2)
+ self.assertEqual(len(model.weights), 4)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 183e26e8bf..3eb69bd7f3 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -26,6 +26,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import input_layer as input_layer_lib
+from tensorflow.python.keras.engine import network as network_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@@ -62,7 +64,7 @@ class TopologyConstructionTest(test.TestCase):
inputs=True)
return inputs + 1
- x1 = keras.Input(shape=(1,))
+ x1 = input_layer_lib.Input(shape=(1,))
layer = MyLayer()
_ = layer.apply(x1)
@@ -70,7 +72,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_updates_for(x1)), 1)
self.assertEqual(len(layer.get_updates_for(None)), 1)
- x2 = keras.Input(shape=(1,))
+ x2 = input_layer_lib.Input(shape=(1,))
y2 = layer.apply(x2)
self.assertEqual(len(layer.updates), 3)
@@ -78,17 +80,17 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_updates_for(x2)), 1)
self.assertEqual(len(layer.get_updates_for(None)), 1)
- network = keras.engine.Network(x2, y2)
+ network = network_lib.Network(x2, y2)
self.assertEqual(len(network.updates), 2)
self.assertEqual(len(network.get_updates_for(x1)), 0)
self.assertEqual(len(network.get_updates_for(x2)), 1)
self.assertEqual(len(network.get_updates_for(None)), 1)
- x3 = keras.Input(shape=(1,))
+ x3 = input_layer_lib.Input(shape=(1,))
_ = layer.apply(x3)
self.assertEqual(len(network.updates), 2)
- x4 = keras.Input(shape=(1,))
+ x4 = input_layer_lib.Input(shape=(1,))
_ = network(x4)
self.assertEqual(len(network.updates), 3)
self.assertEqual(len(network.get_updates_for(x2)), 1)
@@ -104,7 +106,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(network.get_updates_for(x4)), 2)
def test_get_updates_bn(self):
- x1 = keras.Input(shape=(1,))
+ x1 = input_layer_lib.Input(shape=(1,))
layer = keras.layers.BatchNormalization()
_ = layer.apply(x1)
@@ -134,7 +136,7 @@ class TopologyConstructionTest(test.TestCase):
inputs=True)
return inputs + 1
- x1 = keras.Input(shape=(1,))
+ x1 = input_layer_lib.Input(shape=(1,))
layer = MyLayer()
_ = layer.apply(x1)
@@ -142,7 +144,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_losses_for(x1)), 1)
self.assertEqual(len(layer.get_losses_for(None)), 1)
- x2 = keras.Input(shape=(1,))
+ x2 = input_layer_lib.Input(shape=(1,))
y2 = layer.apply(x2)
self.assertEqual(len(layer.losses), 3)
@@ -150,17 +152,17 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(len(layer.get_losses_for(x2)), 1)
self.assertEqual(len(layer.get_losses_for(None)), 1)
- network = keras.engine.Network(x2, y2)
+ network = network_lib.Network(x2, y2)
self.assertEqual(len(network.losses), 2)
self.assertEqual(len(network.get_losses_for(x1)), 0)
self.assertEqual(len(network.get_losses_for(x2)), 1)
self.assertEqual(len(network.get_losses_for(None)), 1)
- x3 = keras.Input(shape=(1,))
+ x3 = input_layer_lib.Input(shape=(1,))
_ = layer.apply(x3)
self.assertEqual(len(network.losses), 2)
- x4 = keras.Input(shape=(1,))
+ x4 = input_layer_lib.Input(shape=(1,))
_ = network(x4)
self.assertEqual(len(network.losses), 3)
self.assertEqual(len(network.get_losses_for(x2)), 1)
@@ -177,8 +179,8 @@ class TopologyConstructionTest(test.TestCase):
def testTopologicalAttributes(self):
# test layer attributes / methods related to cross-layer connectivity.
- a = keras.Input(shape=(32,), name='input_a')
- b = keras.Input(shape=(32,), name='input_b')
+ a = input_layer_lib.Input(shape=(32,), name='input_a')
+ b = input_layer_lib.Input(shape=(32,), name='input_b')
# test input, output, input_shape, output_shape
test_layer = keras.layers.Dense(16, name='test_layer')
@@ -219,15 +221,15 @@ class TopologyConstructionTest(test.TestCase):
_ = new_dense.input_shape
with self.assertRaises(AttributeError):
new_dense = keras.layers.Dense(16)
- a = keras.Input(shape=(3, 32))
- a = keras.Input(shape=(5, 32))
+ a = input_layer_lib.Input(shape=(3, 32))
+ a = input_layer_lib.Input(shape=(5, 32))
a_2 = dense(a)
b_2 = dense(b)
_ = new_dense.input_shape
with self.assertRaises(AttributeError):
new_dense = keras.layers.Dense(16)
- a = keras.Input(shape=(3, 32))
- a = keras.Input(shape=(5, 32))
+ a = input_layer_lib.Input(shape=(3, 32))
+ a = input_layer_lib.Input(shape=(5, 32))
a_2 = dense(a)
b_2 = dense(b)
_ = new_dense.output_shape
@@ -239,7 +241,7 @@ class TopologyConstructionTest(test.TestCase):
def call(self, inputs):
return [inputs**2, inputs**3]
- x = keras.Input(shape=(32,))
+ x = input_layer_lib.Input(shape=(32,))
test_layer = PowersLayer()
p1, p2 = test_layer(x) # pylint: disable=not-callable
@@ -256,8 +258,8 @@ class TopologyConstructionTest(test.TestCase):
assert len(inputs) == 2
return inputs[0] + inputs[1]
- a = keras.Input(shape=(32,))
- b = keras.Input(shape=(32,))
+ a = input_layer_lib.Input(shape=(32,))
+ b = input_layer_lib.Input(shape=(32,))
test_layer = AddLayer()
y = test_layer([a, b]) # pylint: disable=not-callable
@@ -268,10 +270,10 @@ class TopologyConstructionTest(test.TestCase):
def testBasicNetwork(self):
# minimum viable network
- x = keras.Input(shape=(32,))
+ x = input_layer_lib.Input(shape=(32,))
dense = keras.layers.Dense(2)
y = dense(x)
- network = keras.engine.Network(x, y, name='dense_network')
+ network = network_lib.Network(x, y, name='dense_network')
# test basic attributes
self.assertEqual(network.name, 'dense_network')
@@ -282,7 +284,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights)
# test callability on Input
- x_2 = keras.Input(shape=(32,))
+ x_2 = input_layer_lib.Input(shape=(32,))
y_2 = network(x_2)
self.assertEqual(y_2.get_shape().as_list(), [None, 2])
@@ -506,7 +508,7 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)])
# test get_source_inputs
- self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b])
+ self.assertListEqual(keras.engine.get_source_inputs(c), [a, b])
# serialization / deserialization
json_config = model.to_json()
@@ -778,12 +780,12 @@ class TopologyConstructionTest(test.TestCase):
self.evaluate(getattr(b, '_keras_mask')))
self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
else:
- x = keras.Input(shape=(32,))
+ x = input_layer_lib.Input(shape=(32,))
y = MaskedLayer()(x) # pylint: disable=not-callable
- network = keras.engine.Network(x, y)
+ network = network_lib.Network(x, y)
# test callability on Input
- x_2 = keras.Input(shape=(32,))
+ x_2 = input_layer_lib.Input(shape=(32,))
y_2 = network(x_2)
self.assertEqual(y_2.get_shape().as_list(), [None, 32])
@@ -797,14 +799,14 @@ class TopologyConstructionTest(test.TestCase):
def reg(x):
return math_ops.reduce_sum(x)
- net_a_input = keras.Input((2,))
+ net_a_input = input_layer_lib.Input((2,))
net_a = net_a_input
net_a = keras.layers.Dense(2, kernel_initializer='ones',
use_bias=False,
activity_regularizer=reg)(net_a)
model_a = keras.Model([net_a_input], [net_a])
- net_b_input = keras.Input((2,))
+ net_b_input = input_layer_lib.Input((2,))
net_b = model_a(net_b_input)
model_b = keras.Model([net_b_input], [net_b])
@@ -817,7 +819,7 @@ class TopologyConstructionTest(test.TestCase):
with self.test_session():
x_val = np.random.random((10, 5))
- x = keras.Input(shape=(5,))
+ x = input_layer_lib.Input(shape=(5,))
a = keras.layers.Dense(5, name='A')
b = keras.layers.Dense(5, name='B')
output = a(b(a(b(x))))
@@ -837,7 +839,7 @@ class TopologyConstructionTest(test.TestCase):
def test_layer_sharing_at_heterogenous_depth_with_concat(self):
with self.test_session():
input_shape = (16, 9, 3)
- input_layer = keras.Input(shape=input_shape)
+ input_layer = input_layer_lib.Input(shape=input_shape)
a = keras.layers.Dense(3, name='dense_A')
b = keras.layers.Dense(3, name='dense_B')
@@ -924,7 +926,7 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testSimpleNetworkBuilding(self):
- inputs = keras.engine.Input(shape=(32,))
+ inputs = input_layer_lib.Input(shape=(32,))
if context.executing_eagerly():
self.assertIsInstance(inputs, base_layer.DeferredTensor)
self.assertEqual(inputs.dtype.name, 'float32')
@@ -937,8 +939,8 @@ class DeferredModeTest(test.TestCase):
self.assertEqual(x.shape.as_list(), [None, 2])
outputs = keras.layers.Dense(4)(x)
- network = keras.engine.Network(inputs, outputs)
- self.assertIsInstance(network, keras.engine.Network)
+ network = network_lib.Network(inputs, outputs)
+ self.assertIsInstance(network, network_lib.Network)
if context.executing_eagerly():
# It should be possible to call such a network on EagerTensors.
@@ -949,8 +951,8 @@ class DeferredModeTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testMultiIONetworkbuilding(self):
- input_a = keras.engine.Input(shape=(32,))
- input_b = keras.engine.Input(shape=(16,))
+ input_a = input_layer_lib.Input(shape=(32,))
+ input_b = input_layer_lib.Input(shape=(16,))
a = keras.layers.Dense(16)(input_a)
class AddLayer(keras.layers.Layer):
@@ -964,7 +966,7 @@ class DeferredModeTest(test.TestCase):
c = AddLayer()([a, input_b]) # pylint: disable=not-callable
c = keras.layers.Dense(2)(c)
- network = keras.engine.Network([input_a, input_b], [a, c])
+ network = network_lib.Network([input_a, input_b], [a, c])
if context.executing_eagerly():
a_val = constant_op.constant(
np.random.random((10, 32)).astype('float32'))
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 04a2aa7664..8e632651fa 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -41,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -114,6 +116,7 @@ class Model(Network):
# Create a cache for dataset - uninitialized iterators
self._dataset_iterator_cache = weakref.WeakKeyDictionary()
+ @checkpointable.no_automatic_dependency_tracking
def compile(self,
optimizer,
loss=None,
@@ -177,6 +180,11 @@ class Model(Network):
raise ValueError('Only TF native optimizers are supported in Eager mode.')
self.optimizer = optimizers.get(optimizer)
+ # We've disabled automatic dependency tracking for this method, but do want
+ # to add a checkpoint dependency on the optimizer if it's checkpointable.
+ if isinstance(self.optimizer, checkpointable.CheckpointableBase):
+ self._track_checkpointable(
+ self.optimizer, name='optimizer', overwrite=True)
self.loss = loss
self.metrics = metrics or []
self.loss_weights = loss_weights
@@ -409,11 +417,13 @@ class Model(Network):
else:
if sample_weight_mode == 'temporal':
sample_weights.append(array_ops.placeholder_with_default(
- [[1.]], shape=[None, None], name=name + '_sample_weights'))
+ constant_op.constant([[1.]], dtype=K.floatx()),
+ shape=[None, None], name=name + '_sample_weights'))
sample_weight_modes.append('temporal')
else:
sample_weights.append(array_ops.placeholder_with_default(
- [1.], shape=[None], name=name + '_sample_weights'))
+ constant_op.constant([1.], dtype=K.floatx()),
+ shape=[None], name=name + '_sample_weights'))
sample_weight_modes.append(None)
self.sample_weight_modes = sample_weight_modes
self._feed_sample_weight_modes = []
@@ -938,6 +948,7 @@ class Model(Network):
str(x[0].shape[0]) + ' samples')
return x, y, sample_weights
+ @checkpointable.no_automatic_dependency_tracking
def _set_inputs(self, inputs, training=None):
"""Set model's input and output specs based on the input data received.
@@ -986,6 +997,7 @@ class Model(Network):
else:
self._symbolic_set_inputs(inputs, training=training)
+ @checkpointable.no_automatic_dependency_tracking
def _eager_set_inputs(self, inputs):
"""Set model's input and output specs based on the input data received.
@@ -1008,14 +1020,16 @@ class Model(Network):
# to keep track of number of inputs and outputs and their ndim.
if isinstance(inputs, (list, tuple)):
if tensor_util.is_tensor(inputs[0]):
- dummy_output_values = self.call(inputs)
+ dummy_output_values = self.call(
+ training_utils.cast_if_floating_dtype(inputs))
else:
dummy_output_values = self.call(
[ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs])
dummy_input_values = list(inputs)
else:
if tensor_util.is_tensor(inputs):
- dummy_output_values = self.call(inputs)
+ dummy_output_values = self.call(
+ training_utils.cast_if_floating_dtype(inputs))
else:
dummy_output_values = self.call(
ops.convert_to_tensor(inputs, dtype=K.floatx()))
@@ -1036,6 +1050,7 @@ class Model(Network):
'output_%d' % (i + 1) for i in range(len(dummy_output_values))]
self.built = True
+ @checkpointable.no_automatic_dependency_tracking
def _symbolic_set_inputs(self, inputs, outputs=None, training=None):
"""Set model's inputs and output specs based.
@@ -1616,7 +1631,10 @@ class Model(Network):
# Validate and standardize user data.
inputs, _, _ = self._standardize_user_data(x)
if context.executing_eagerly():
- if not isinstance(inputs, iterator_ops.EagerIterator):
+ if (isinstance(x, iterator_ops.EagerIterator) or
+ (isinstance(x, dataset_ops.Dataset) and context.executing_eagerly())):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=K.floatx()) for val in inputs
]
diff --git a/tensorflow/python/keras/engine/training_arrays.py b/tensorflow/python/keras/engine/training_arrays.py
index 93f4f1bd1d..adefffab11 100644
--- a/tensorflow/python/keras/engine/training_arrays.py
+++ b/tensorflow/python/keras/engine/training_arrays.py
@@ -124,6 +124,10 @@ def fit_loop(model,
callback_metrics = copy.copy(out_labels) + [
'val_' + n for n in out_labels
]
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
else:
callback_metrics = copy.copy(out_labels)
@@ -156,7 +160,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -164,11 +168,17 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
cbk.validation_data = val_ins
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
# To prevent a slowdown, we find beforehand the arrays that need conversion.
feed = model._feed_inputs + model._feed_targets + model._feed_sample_weights
@@ -185,6 +195,7 @@ def fit_loop(model,
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
if steps_per_epoch is not None:
+ # Step-wise fit loop.
for step_index in range(steps_per_epoch):
batch_logs = {}
batch_logs['batch'] = step_index
@@ -215,7 +226,6 @@ def fit_loop(model,
val_inputs,
val_targets,
sample_weights=val_sample_weights,
- batch_size=batch_size,
steps=validation_steps,
verbose=0)
if not isinstance(val_outs, list):
@@ -224,6 +234,7 @@ def fit_loop(model,
for l, o in zip(out_labels, val_outs):
epoch_logs['val_' + l] = o
else:
+ # Sample-wise fit loop.
if shuffle == 'batch':
index_array = training_utils.batch_shuffle(index_array, batch_size)
elif shuffle:
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index a70b488f25..c78684c9f4 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -255,6 +255,8 @@ def iterator_fit_loop(model,
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(
x, y, class_weight=class_weight)
+ x = training_utils.cast_if_floating_dtype(x)
+ y = training_utils.cast_if_floating_dtype(y)
if sample_weights:
sample_weights = [
ops.convert_to_tensor(val, dtype=backend.floatx())
@@ -471,6 +473,8 @@ def iterator_test_loop(model, inputs, steps, verbose=0):
# Validate and standardize data.
x, y, sample_weights = model._standardize_user_data(x, y)
+ x = training_utils.cast_if_floating_dtype(x)
+ y = training_utils.cast_if_floating_dtype(y)
# Calculate model output, loss values.
loss_outs, loss, loss_metrics = _model_loss(
@@ -639,6 +643,7 @@ def iterator_predict_loop(model, inputs, steps, verbose=0):
# Validate and standardize data.
x, _, _ = model._standardize_user_data(x)
+ x = training_utils.cast_if_floating_dtype(x)
if model._expects_training_arg:
batch_outs = model.call(x[0] if len(x) == 1 else x, training=False)
@@ -814,7 +819,10 @@ def train_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss and the loss associated with each output.
"""
- if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ if len(inputs) and tensor_util.is_tensor(inputs[0]):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ targets = training_utils.cast_if_floating_dtype(targets)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
]
@@ -849,7 +857,10 @@ def test_on_batch(model, inputs, targets, sample_weights=None):
Returns:
total loss, loss and metrics associated with each output.
"""
- if len(inputs) and not tensor_util.is_tensor(inputs[0]):
+ if len(inputs) and tensor_util.is_tensor(inputs[0]):
+ inputs = training_utils.cast_if_floating_dtype(inputs)
+ targets = training_utils.cast_if_floating_dtype(targets)
+ else:
inputs = [
ops.convert_to_tensor(val, dtype=backend.floatx()) for val in inputs
]
@@ -978,7 +989,7 @@ def fit_loop(model,
callbacks.set_model(callback_model)
- callbacks.set_params({
+ callback_params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps_per_epoch,
@@ -986,9 +997,11 @@ def fit_loop(model,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics or [],
- })
- callbacks.on_train_begin()
- callback_model.stop_training = False
+ }
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ callbacks.set_params(callback_params)
+
for cbk in callbacks:
if not val_inputs:
cbk.validation_data = []
@@ -998,6 +1011,10 @@ def fit_loop(model,
cbk.validation_data = val_inputs + val_targets + val_sample_weights
else:
cbk.validation_data = val_inputs + val_targets
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
+ callback_model.stop_training = False
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index 7906d208eb..bdb3035129 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -403,6 +403,24 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_generator_methods(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(3,)))
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, 'mse', metrics=['mae'])
+
+ x = np.random.random((10, 3))
+ y = np.random.random((10, 4))
+
+ def iterator():
+ while True:
+ yield x, y
+
+ model.fit_generator(iterator(), steps_per_epoch=3, epochs=1)
+ model.evaluate_generator(iterator(), steps=3)
+ out = model.predict_generator(iterator(), steps=3)
+ self.assertEqual(out.shape, (30, 4))
+
class LossWeightingTest(test.TestCase):
@@ -629,7 +647,7 @@ class LossWeightingTest(test.TestCase):
class CorrectnessTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_loss_correctness(self):
# Test that training loss is the same in eager and graph
# (by comparing it to a reference value in a deterministic case)
@@ -650,7 +668,7 @@ class CorrectnessTest(test.TestCase):
self.assertEqual(
np.around(history.history['loss'][-1], decimals=4), 0.6173)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
model = keras.Sequential()
model.add(keras.layers.Dense(3,
@@ -671,7 +689,7 @@ class CorrectnessTest(test.TestCase):
outs = model.evaluate(x, y)
self.assertEqual(outs[1], 0.)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_loss_correctness_with_iterator(self):
# Test that training loss is the same in eager and graph
# (by comparing it to a reference value in a deterministic case)
@@ -694,7 +712,7 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness_with_iterator(self):
model = keras.Sequential()
model.add(
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index d81b384f0e..432cf2bddd 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -96,14 +96,25 @@ def fit_generator(model,
else:
callback_model = model
callbacks.set_model(callback_model)
- callbacks.set_params({
+
+ callback_params = {
'epochs': epochs,
'steps': steps_per_epoch,
'verbose': verbose,
'do_validation': do_validation,
'metrics': callback_metrics,
- })
- callbacks.on_train_begin()
+ }
+ if do_validation:
+ # need to create the test_function before start of the first epoch
+ # because TensorBoard callback on_epoch_begin adds summary to the
+ # list of fetches of the test_function
+ model._make_test_function()
+ # determine the number of validation batches given a generator
+ if validation_steps:
+ callback_params.update({'validation_steps': validation_steps})
+ elif isinstance(validation_data, Sequence):
+ callback_params.update({'validation_steps': len(validation_data)})
+ callbacks.set_params(callback_params)
enqueuer = None
val_enqueuer = None
@@ -149,6 +160,9 @@ def fit_generator(model,
output_generator = generator
callback_model.stop_training = False
+ # validation_data must be set before on_train_begin() is called
+ # so that TensorboardCallback can validate its input
+ callbacks.on_train_begin()
# Construct epoch logs.
epoch_logs = {}
while epoch < epochs:
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 5c02d36382..d9e548f01f 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -129,8 +129,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
epochs=1,
batch_size=5,
verbose=0)
@@ -138,8 +140,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
epochs=1,
batch_size=5,
verbose=1)
@@ -147,8 +151,10 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
validation_data=({
'input_a': input_a_np,
'input_b': input_b_np
@@ -162,8 +168,10 @@ class TrainingTest(test.TestCase):
model.train_on_batch({
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np})
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ })
# Test with lists for loss, metrics
loss = ['mae', 'mse']
@@ -285,16 +293,20 @@ class TrainingTest(test.TestCase):
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
batch_size=5,
verbose=0)
model.evaluate(
{
'input_a': input_a_np,
'input_b': input_b_np
- }, {'dense': output_d_np,
- 'dropout': output_e_np},
+ }, {
+ 'dense': output_d_np,
+ 'dropout': output_e_np
+ },
batch_size=5,
verbose=1)
@@ -349,9 +361,11 @@ class TrainingTest(test.TestCase):
with self.test_session():
test_inputs = [
- scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)]
+ scipy_sparse.random(6, 3, density=0.25).tocsr() for _ in range(2)
+ ]
test_outputs = [
- scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)]
+ scipy_sparse.random(6, i, density=0.25).tocsr() for i in range(3, 5)
+ ]
in1 = keras.layers.Input(shape=(3,))
in2 = keras.layers.Input(shape=(3,))
out1 = keras.layers.Dropout(0.5, name='dropout')(in1)
@@ -1682,7 +1696,7 @@ class TestTrainingWithDataTensors(test.TestCase):
model.train_on_batch([input_a_np, input_b_np],
[output_a_np, output_b_np])
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_metric_names_are_identical_in_graph_and_eager(self):
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -1709,7 +1723,7 @@ class TestTrainingWithDataTensors(test.TestCase):
class TestTrainingWithDatasetIterators(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_iterators_single_io(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
@@ -1721,8 +1735,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1786,8 +1800,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1799,7 +1813,7 @@ class TestTrainingWithDatasetIterators(test.TestCase):
ops.get_default_graph().finalize()
model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_iterators_running_out_of_data(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
@@ -1811,8 +1825,8 @@ class TestTrainingWithDatasetIterators(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(2)
dataset = dataset.batch(10)
@@ -1838,8 +1852,8 @@ class TestTrainingWithDataset(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1853,7 +1867,7 @@ class TestTrainingWithDataset(test.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_training_and_eval_methods_on_dataset(self):
with self.test_session():
x = keras.layers.Input(shape=(3,), name='input')
@@ -1865,8 +1879,8 @@ class TestTrainingWithDataset(test.TestCase):
metrics = ['mae']
model.compile(optimizer, loss, metrics=metrics)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
@@ -1928,8 +1942,8 @@ class TestTrainingWithDataset(test.TestCase):
model.compile(optimizer, loss)
# User forgets to batch the dataset
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 3))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
@@ -1938,8 +1952,8 @@ class TestTrainingWithDataset(test.TestCase):
model.train_on_batch(dataset)
# Wrong input shape
- inputs = np.zeros((10, 5), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
+ inputs = np.zeros((10, 5))
+ targets = np.zeros((10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index b93f999444..728a2b493b 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -553,6 +553,10 @@ def standardize_weights(y,
def has_symbolic_tensors(ls):
if context.executing_eagerly():
return False
+ return has_tensors(ls)
+
+
+def has_tensors(ls):
if isinstance(ls, (list, tuple)):
return any(tensor_util.is_tensor(v) for v in ls)
return tensor_util.is_tensor(ls)
@@ -692,3 +696,29 @@ def check_steps_argument(input_data, steps, steps_name):
input_type=input_type_str, steps_name=steps_name))
return True
return False
+
+
+def cast_if_floating_dtype(x):
+ """Casts the given data tensors to the default floating point type.
+
+ Casts only if the input is already a floating point type.
+ Args:
+ x: tensor or list/tuple of tensors.
+
+ Returns:
+ Converted input.
+
+ Raises:
+ RuntimeError: if data isn't tensors.
+ """
+ if not has_tensors(x):
+ raise RuntimeError(
+ 'Please provide tensors for casting, got: {x}'.format(x=x))
+
+ if isinstance(x, (list, tuple)):
+ return [
+ math_ops.cast(val, dtype=K.floatx())
+ if tensor_util.is_tensor(val) and val.dtype.is_floating else val
+ for val in x
+ ]
+ return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
diff --git a/tensorflow/python/keras/estimator/__init__.py b/tensorflow/python/keras/estimator/__init__.py
new file mode 100644
index 0000000000..b244beb5b5
--- /dev/null
+++ b/tensorflow/python/keras/estimator/__init__.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras estimator API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.tf_export import tf_export
+
+# Keras has undeclared dependency on tensorflow/estimator:estimator_py.
+# As long as you depend //third_party/py/tensorflow:tensorflow target
+# everything will work as normal.
+
+try:
+ from tensorflow.python.estimator import keras as keras_lib # pylint: disable=g-import-not-at-top
+ model_to_estimator = tf_export('keras.estimator.model_to_estimator')(
+ keras_lib.model_to_estimator)
+except Exception: # pylint: disable=broad-except
+
+ # pylint: disable=unused-argument
+ def stub_model_to_estimator(keras_model=None,
+ keras_model_path=None,
+ custom_objects=None,
+ model_dir=None,
+ config=None):
+ raise NotImplementedError(
+ 'tf.keras.estimator.model_to_estimator function not available in your '
+ 'installation.')
+ # pylint: enable=unused-argument
+
+ model_to_estimator = tf_export('keras.estimator.model_to_estimator')(
+ stub_model_to_estimator)
+
diff --git a/tensorflow/python/keras/initializers.py b/tensorflow/python/keras/initializers.py
index b9b2e9ad59..28beb6760d 100644
--- a/tensorflow/python/keras/initializers.py
+++ b/tensorflow/python/keras/initializers.py
@@ -23,6 +23,9 @@ import six
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops.init_ops import Constant
+from tensorflow.python.ops.init_ops import glorot_normal_initializer
+from tensorflow.python.ops.init_ops import glorot_uniform_initializer
+
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
@@ -80,52 +83,6 @@ def lecun_uniform(seed=None):
scale=1., mode='fan_in', distribution='uniform', seed=seed)
-@tf_export('keras.initializers.glorot_normal')
-def glorot_normal(seed=None):
- """Glorot normal initializer, also called Xavier normal initializer.
-
- It draws samples from a truncated normal distribution centered on 0
- with `stddev = sqrt(2 / (fan_in + fan_out))`
- where `fan_in` is the number of input units in the weight tensor
- and `fan_out` is the number of output units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- Glorot & Bengio, AISTATS 2010
- http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_avg', distribution='normal', seed=seed)
-
-
-@tf_export('keras.initializers.glorot_uniform')
-def glorot_uniform(seed=None):
- """Glorot uniform initializer, also called Xavier uniform initializer.
-
- It draws samples from a uniform distribution within [-limit, limit]
- where `limit` is `sqrt(6 / (fan_in + fan_out))`
- where `fan_in` is the number of input units in the weight tensor
- and `fan_out` is the number of output units in the weight tensor.
-
- Arguments:
- seed: A Python integer. Used to seed the random generator.
-
- Returns:
- An initializer.
-
- References:
- Glorot & Bengio, AISTATS 2010
- http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
- """
- return VarianceScaling(
- scale=1., mode='fan_avg', distribution='uniform', seed=seed)
-
-
@tf_export('keras.initializers.he_normal')
def he_normal(seed=None):
"""He normal initializer.
@@ -179,6 +136,8 @@ normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
+glorot_normal = glorot_normal_initializer
+glorot_uniform = glorot_uniform_initializer
# pylint: enable=invalid-name
diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py
index a54d6da839..c519e194bd 100644
--- a/tensorflow/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/initializers_test.py
@@ -71,7 +71,7 @@ class KerasInitializersTest(test.TestCase):
stddev=1,
seed=126),
tensor_shape,
- target_mean=0., target_std=None, target_max=2)
+ target_mean=0., target_max=2, target_min=-2)
def test_constant(self):
tensor_shape = (5, 6, 4)
@@ -83,49 +83,49 @@ class KerasInitializersTest(test.TestCase):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(3. / fan_in)
+ std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ target_mean=0., target_std=std)
def test_glorot_uniform(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(6. / (fan_in + fan_out))
+ std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ target_mean=0., target_std=std)
def test_he_uniform(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(6. / fan_in)
+ std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
- target_mean=0., target_max=scale, target_min=-scale)
+ target_mean=0., target_std=std)
def test_lecun_normal(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(1. / fan_in)
+ std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape,
- target_mean=0., target_std=None, target_max=2 * scale)
+ target_mean=0., target_std=std)
def test_glorot_normal(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(2. / (fan_in + fan_out))
+ std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
- target_mean=0., target_std=None, target_max=2 * scale)
+ target_mean=0., target_std=std)
def test_he_normal(self):
tensor_shape = (5, 6, 4, 2)
with self.test_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
- scale = np.sqrt(2. / fan_in)
+ std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
- target_mean=0., target_std=None, target_max=2 * scale)
+ target_mean=0., target_std=std)
def test_orthogonal(self):
tensor_shape = (20, 20)
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 8fb663a17e..e3a686f45d 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -20,15 +20,16 @@ from __future__ import print_function
# Generic layers.
# pylint: disable=g-bad-import-order
-from tensorflow.python.keras.engine import Input
-from tensorflow.python.keras.engine import InputLayer
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.input_layer import Input
+from tensorflow.python.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
# Advanced activations.
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.advanced_activations import PReLU
from tensorflow.python.keras.layers.advanced_activations import ELU
+from tensorflow.python.keras.layers.advanced_activations import ReLU
from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU
from tensorflow.python.keras.layers.advanced_activations import Softmax
@@ -86,9 +87,11 @@ from tensorflow.python.keras.layers.local import LocallyConnected2D
# Merge layers.
from tensorflow.python.keras.layers.merge import Add
+from tensorflow.python.keras.layers.merge import Subtract
from tensorflow.python.keras.layers.merge import Multiply
from tensorflow.python.keras.layers.merge import Average
from tensorflow.python.keras.layers.merge import Maximum
+from tensorflow.python.keras.layers.merge import Minimum
from tensorflow.python.keras.layers.merge import Concatenate
from tensorflow.python.keras.layers.merge import Dot
from tensorflow.python.keras.layers.merge import add
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 8ade3c3174..eba10da6f3 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -278,3 +278,40 @@ class Softmax(Layer):
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
+
+
+@tf_export('keras.layers.ReLU')
+class ReLU(Layer):
+ """Rectified Linear Unit activation function.
+
+ Input shape:
+ Arbitrary. Use the keyword argument `input_shape`
+ (tuple of integers, does not include the samples axis)
+ when using this layer as the first layer in a model.
+
+ Output shape:
+ Same shape as the input.
+
+ Arguments:
+ max_value: float >= 0. Maximum activation value.
+ """
+
+ def __init__(self, max_value=None, **kwargs):
+ super(ReLU, self).__init__(**kwargs)
+ self.support_masking = True
+ self.max_value = K.cast_to_floatx(max_value)
+ if self.max_value < 0.:
+ raise ValueError('max_value of Relu layer '
+ 'cannot be negative value: ' + str(max_value))
+
+ def call(self, inputs):
+ return activations.relu(inputs, max_value=self.max_value)
+
+ def get_config(self):
+ config = {'max_value': self.max_value}
+ base_config = super(ReLU, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @tf_utils.shape_type_conversion
+ def compute_output_shape(self, input_shape):
+ return input_shape
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 81c76db14c..9e1f15b1bc 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -62,6 +62,20 @@ class AdvancedActivationsTest(test.TestCase):
kwargs={'axis': 1},
input_shape=(2, 3, 4))
+ def test_relu(self):
+ with self.test_session():
+ testing_utils.layer_test(keras.layers.ReLU,
+ kwargs={'max_value': 10},
+ input_shape=(2, 3, 4))
+
+ def test_relu_with_invalid_arg(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'max_value of Relu layer cannot be negative value: -10'):
+ with self.test_session():
+ testing_utils.layer_test(keras.layers.ReLU,
+ kwargs={'max_value': -10},
+ input_shape=(2, 3, 4))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index ce1c84e98d..a57ac121ed 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -26,8 +26,8 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras.layers.pooling import AveragePooling1D
@@ -151,21 +151,23 @@ class Conv(Layer):
input_dim = int(input_shape[channel_axis])
kernel_shape = self.kernel_size + (input_dim, self.filters)
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.kernel = self.add_weight(
+ name='kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ trainable=True,
+ dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.input_spec = InputSpec(ndim=self.rank + 2,
@@ -380,11 +382,11 @@ class Conv2D(Conv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -611,11 +613,11 @@ class Conv2DTranspose(Conv2D):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -720,21 +722,23 @@ class Conv2DTranspose(Conv2D):
self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
kernel_shape = self.kernel_size + (self.filters, input_dim)
- self.kernel = self.add_variable(name='kernel',
- shape=kernel_shape,
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.kernel = self.add_weight(
+ name='kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ trainable=True,
+ dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.built = True
@@ -961,7 +965,7 @@ class Conv3DTranspose(Conv3D):
kernel_shape = self.kernel_size + (self.filters, input_dim)
self.input_spec = InputSpec(ndim=5, axes={channel_axis: input_dim})
- self.kernel = self.add_variable(
+ self.kernel = self.add_weight(
'kernel',
shape=kernel_shape,
initializer=self.kernel_initializer,
@@ -970,7 +974,7 @@ class Conv3DTranspose(Conv3D):
trainable=True,
dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(
+ self.bias = self.add_weight(
'bias',
shape=(self.filters,),
initializer=self.bias_initializer,
@@ -1191,6 +1195,7 @@ class SeparableConv(Conv):
dilation_rate=dilation_rate,
activation=activations.get(activation),
use_bias=use_bias,
+ bias_initializer=initializers.get(bias_initializer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
bias_constraint=bias_constraint,
@@ -1222,7 +1227,7 @@ class SeparableConv(Conv):
pointwise_kernel_shape = (
1,) * self.rank + (self.depth_multiplier * input_dim, self.filters)
- self.depthwise_kernel = self.add_variable(
+ self.depthwise_kernel = self.add_weight(
name='depthwise_kernel',
shape=depthwise_kernel_shape,
initializer=self.depthwise_initializer,
@@ -1230,7 +1235,7 @@ class SeparableConv(Conv):
constraint=self.depthwise_constraint,
trainable=True,
dtype=self.dtype)
- self.pointwise_kernel = self.add_variable(
+ self.pointwise_kernel = self.add_weight(
name='pointwise_kernel',
shape=pointwise_kernel_shape,
initializer=self.pointwise_initializer,
@@ -1239,13 +1244,14 @@ class SeparableConv(Conv):
trainable=True,
dtype=self.dtype)
if self.use_bias:
- self.bias = self.add_variable(name='bias',
- shape=(self.filters,),
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- trainable=True,
- dtype=self.dtype)
+ self.bias = self.add_weight(
+ name='bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ trainable=True,
+ dtype=self.dtype)
else:
self.bias = None
self.built = True
@@ -1447,11 +1453,11 @@ class SeparableConv2D(SeparableConv):
filters: Integer, the dimensionality of the output space
(i.e. the number of output filters in the convolution).
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1591,11 +1597,11 @@ class DepthwiseConv2D(Conv2D):
Arguments:
kernel_size: An integer or tuple/list of 2 integers, specifying the
- width and height of the 2D convolution window.
+ height and width of the 2D convolution window.
Can be a single integer to specify the same value for
all spatial dimensions.
strides: An integer or tuple/list of 2 integers,
- specifying the strides of the convolution along the width and height.
+ specifying the strides of the convolution along the height and width.
Can be a single integer to specify the same value for
all spatial dimensions.
Specifying any stride value != 1 is incompatible with specifying
@@ -1724,7 +1730,7 @@ class DepthwiseConv2D(Conv2D):
dilation_rate=self.dilation_rate,
data_format=self.data_format)
- if self.bias:
+ if self.use_bias:
outputs = backend.bias_add(
outputs,
self.bias,
@@ -2002,7 +2008,7 @@ class ZeroPadding2D(Layer):
Arguments:
padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric padding
- is applied to width and height.
+ is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric padding values for height and width:
@@ -2101,7 +2107,7 @@ class ZeroPadding3D(Layer):
Arguments:
padding: int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.
- If int: the same symmetric padding
- is applied to width and height.
+ is applied to height and width.
- If tuple of 3 ints:
interpreted as two different
symmetric padding values for height and width:
@@ -2261,12 +2267,12 @@ class Cropping1D(Layer):
class Cropping2D(Layer):
"""Cropping layer for 2D input (e.g. picture).
- It crops along spatial dimensions, i.e. width and height.
+ It crops along spatial dimensions, i.e. height and width.
Arguments:
cropping: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.
- If int: the same symmetric cropping
- is applied to width and height.
+ is applied to height and width.
- If tuple of 2 ints:
interpreted as two different
symmetric cropping values for height and width:
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py
index c731508b3c..84d794cada 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent.py
@@ -26,8 +26,8 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.layers.recurrent import RNN
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index 167cabaeec..f904744422 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -45,7 +45,7 @@ class Convolution1DTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, length, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_conv1d(self):
kwargs = {
'filters': 2,
@@ -117,7 +117,7 @@ class Conv2DTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, num_row, num_col, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_conv2d(self):
kwargs = {
'filters': 2,
@@ -192,7 +192,7 @@ class Conv2DTransposeTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, num_row, num_col, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_conv2dtranspose(self):
kwargs = {
'filters': 2,
@@ -258,7 +258,7 @@ class Conv3DTransposeTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, depth, num_row, num_col, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_conv3dtranspose(self):
kwargs = {
'filters': 2,
@@ -322,7 +322,7 @@ class SeparableConv1DTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, length, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_separable_conv1d(self):
kwargs = {
'filters': 2,
@@ -398,7 +398,7 @@ class SeparableConv2DTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, num_row, num_col, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_separable_conv2d(self):
kwargs = {
'filters': 2,
@@ -477,7 +477,7 @@ class Conv3DTest(test.TestCase):
kwargs=test_kwargs,
input_shape=(num_samples, depth, num_row, num_col, stack_size))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_conv3d(self):
kwargs = {
'filters': 2,
@@ -529,7 +529,7 @@ class Conv3DTest(test.TestCase):
class ZeroPaddingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_zero_padding_1d(self):
num_samples = 2
input_dim = 2
@@ -581,7 +581,7 @@ class ZeroPaddingTest(test.TestCase):
with self.assertRaises(ValueError):
keras.layers.ZeroPadding1D(padding=None)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_zero_padding_2d(self):
num_samples = 2
stack_size = 2
@@ -660,7 +660,7 @@ class ZeroPaddingTest(test.TestCase):
with self.assertRaises(ValueError):
keras.layers.ZeroPadding2D(padding=None)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_zero_padding_3d(self):
num_samples = 2
stack_size = 2
@@ -702,13 +702,13 @@ class ZeroPaddingTest(test.TestCase):
class UpSamplingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_upsampling_1d(self):
with self.test_session(use_gpu=True):
testing_utils.layer_test(
keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_upsampling_2d(self):
num_samples = 2
stack_size = 2
@@ -758,7 +758,7 @@ class UpSamplingTest(test.TestCase):
np.testing.assert_allclose(np_output, expected_out)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_upsampling_3d(self):
num_samples = 2
stack_size = 2
@@ -818,7 +818,7 @@ class UpSamplingTest(test.TestCase):
class CroppingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_cropping_1d(self):
num_samples = 2
time_length = 4
@@ -837,7 +837,7 @@ class CroppingTest(test.TestCase):
with self.assertRaises(ValueError):
keras.layers.Cropping1D(cropping=None)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_cropping_2d(self):
num_samples = 2
stack_size = 2
@@ -905,7 +905,7 @@ class CroppingTest(test.TestCase):
with self.assertRaises(ValueError):
keras.layers.Cropping2D(cropping=None)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_cropping_3d(self):
num_samples = 2
stack_size = 2
@@ -995,6 +995,7 @@ class DepthwiseConv2DTest(test.TestCase):
'bias_regularizer': 'l2',
'activity_regularizer': 'l2',
'depthwise_constraint': 'unit_norm',
+ 'use_bias': True,
'strides': (2, 2),
}
self._run_test(kwargs, 'depth_multiplier', [1])
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index df4c3915a3..f28cade474 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -19,11 +19,14 @@ from __future__ import division
from __future__ import print_function
import copy
+import sys
import types as python_types
+import warnings
import numpy as np
from tensorflow.python.eager import context
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
@@ -31,8 +34,8 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
@@ -714,6 +717,7 @@ class Lambda(Layer):
return self.mask
def get_config(self):
+ module = self.function.__module__
if isinstance(self.function, python_types.LambdaType):
function = generic_utils.func_dump(self.function)
function_type = 'lambda'
@@ -721,21 +725,26 @@ class Lambda(Layer):
function = self.function.__name__
function_type = 'function'
+ output_shape_module = None
if isinstance(self._output_shape, python_types.LambdaType):
output_shape = generic_utils.func_dump(self._output_shape)
output_shape_type = 'lambda'
+ output_shape_module = self._output_shape.__module__
elif callable(self._output_shape):
output_shape = self._output_shape.__name__
output_shape_type = 'function'
+ output_shape_module = self._output_shape.__module__
else:
output_shape = self._output_shape
output_shape_type = 'raw'
config = {
'function': function,
+ 'module': module,
'function_type': function_type,
'output_shape': output_shape,
'output_shape_type': output_shape_type,
+ 'output_shape_module': output_shape_module,
'arguments': self.arguments
}
base_config = super(Lambda, self).get_config()
@@ -745,8 +754,16 @@ class Lambda(Layer):
def from_config(cls, config, custom_objects=None):
config = config.copy()
globs = globals()
+ module = config.pop('module', None)
+ if module in sys.modules:
+ globs.update(sys.modules[module].__dict__)
+ elif module is not None:
+ # Note: we don't know the name of the function if it's a lambda.
+ warnings.warn('{} is not loaded, but a Lambda layer uses it. '
+ 'It may cause errors.'.format(module)
+ , UserWarning)
if custom_objects:
- globs = dict(list(globs.items()) + list(custom_objects.items()))
+ globs.update(custom_objects)
function_type = config.pop('function_type')
if function_type == 'function':
# Simple lookup in custom objects
@@ -760,6 +777,14 @@ class Lambda(Layer):
else:
raise TypeError('Unknown function type:', function_type)
+ output_shape_module = config.pop('output_shape_module', None)
+ if output_shape_module in sys.modules:
+ globs.update(sys.modules[output_shape_module].__dict__)
+ elif output_shape_module is not None:
+ # Note: we don't know the name of the function if it's a lambda.
+ warnings.warn('{} is not loaded, but a Lambda layer uses it. '
+ 'It may cause errors.'.format(output_shape_module)
+ , UserWarning)
output_shape_type = config.pop('output_shape_type')
if output_shape_type == 'function':
# Simple lookup in custom objects
@@ -882,34 +907,36 @@ class Dense(Layer):
'should be defined. Found `None`.')
self.input_spec = InputSpec(min_ndim=2,
axes={-1: input_shape[-1].value})
- self.kernel = self.add_variable('kernel',
- shape=[input_shape[-1].value, self.units],
- initializer=self.kernel_initializer,
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint,
- dtype=self.dtype,
- trainable=True)
+ self.kernel = self.add_weight(
+ 'kernel',
+ shape=[input_shape[-1].value, self.units],
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ constraint=self.kernel_constraint,
+ dtype=self.dtype,
+ trainable=True)
if self.use_bias:
- self.bias = self.add_variable('bias',
- shape=[self.units,],
- initializer=self.bias_initializer,
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint,
- dtype=self.dtype,
- trainable=True)
+ self.bias = self.add_weight(
+ 'bias',
+ shape=[self.units,],
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ constraint=self.bias_constraint,
+ dtype=self.dtype,
+ trainable=True)
else:
self.bias = None
self.built = True
def call(self, inputs):
inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
- shape = inputs.get_shape().as_list()
- if len(shape) > 2:
+ rank = common_shapes.rank(inputs)
+ if rank > 2:
# Broadcasting is required for the inputs.
- outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
- [0]])
+ outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
# Reshape the output back to the original ndim of the input.
if not context.executing_eagerly():
+ shape = inputs.get_shape().as_list()
output_shape = shape[:-1] + [self.units]
outputs.set_shape(output_shape)
else:
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index ff8af976b9..226403c592 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -51,7 +51,7 @@ class CoreLayersTest(test.TestCase):
dropout = keras.layers.Dropout(0.5)
self.assertEqual(True, dropout.supports_masking)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_spatial_dropout(self):
testing_utils.layer_test(
keras.layers.SpatialDropout1D,
@@ -78,7 +78,7 @@ class CoreLayersTest(test.TestCase):
kwargs={'rate': 0.5, 'data_format': 'channels_first'},
input_shape=(2, 3, 4, 4, 5))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_activation(self):
# with string argument
testing_utils.layer_test(
@@ -92,7 +92,7 @@ class CoreLayersTest(test.TestCase):
kwargs={'activation': keras.backend.relu},
input_shape=(3, 2))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_reshape(self):
testing_utils.layer_test(
keras.layers.Reshape,
@@ -114,12 +114,12 @@ class CoreLayersTest(test.TestCase):
kwargs={'target_shape': (-1, 1)},
input_shape=(None, None, 2))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_permute(self):
testing_utils.layer_test(
keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_flatten(self):
testing_utils.layer_test(
keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4))
@@ -134,7 +134,7 @@ class CoreLayersTest(test.TestCase):
np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3))
self.assertAllClose(outputs, target_outputs)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_repeat_vector(self):
testing_utils.layer_test(
keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2))
@@ -173,7 +173,7 @@ class CoreLayersTest(test.TestCase):
config = ld.get_config()
ld = keras.layers.Lambda.from_config(config)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dense(self):
testing_utils.layer_test(
keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2))
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index ad6594279d..cf2b0c476c 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -25,7 +25,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
+from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 9d186f8c58..8fd970239f 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import tempfile
from absl.testing import parameterized
import numpy as np
@@ -30,7 +32,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class CuDNNTest(test.TestCase, parameterized.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_cudnn_rnn_basics(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
@@ -58,7 +60,7 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
'go_backwards': go_backwards},
input_shape=(num_samples, timesteps, input_size))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_trainability(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
@@ -217,27 +219,14 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
out5 = model.predict(np.ones((num_samples, timesteps)))
self.assertNotEqual(out4.max(), out5.max())
- # TODO(psv): Add generic cross product helper function for parametrized tests.
@parameterized.named_parameters(
- ('cudnnlstm_to_lstm_unidirectional_impl_1', 'LSTM', False, False, 1),
- ('cudnnlstm_to_lstm_bidirectional_impl_1', 'LSTM', False, True, 1),
- ('lstm_to_cudnnlstm_unidirectional_impl_1', 'LSTM', True, False, 1),
- ('lstm_to_cudnnlstm_bidirectional_impl_1', 'LSTM', True, True, 1),
- ('cudnngru_to_gru_unidirectional_impl_1', 'GRU', False, False, 1),
- ('cudnngru_to_gru_bidirectional_impl_1', 'GRU', False, True, 1),
- ('gru_to_cudnngru_unidirectional_impl_1', 'GRU', True, False, 1),
- ('gru_to_cudnngru_bidirectional_impl_1', 'GRU', True, True, 1),
- ('cudnnlstm_to_lstm_unidirectional_impl_2', 'LSTM', False, False, 2),
- ('cudnnlstm_to_lstm_bidirectional_impl_2', 'LSTM', False, True, 2),
- ('lstm_to_cudnnlstm_unidirectional_impl_2', 'LSTM', True, False, 2),
- ('lstm_to_cudnnlstm_bidirectional_impl_2', 'LSTM', True, True, 2),
- ('cudnngru_to_gru_unidirectional_impl_2', 'GRU', False, False, 2),
- ('cudnngru_to_gru_bidirectional_impl_2', 'GRU', False, True, 2),
- ('gru_to_cudnngru_unidirectional_impl_2', 'GRU', True, False, 2),
- ('gru_to_cudnngru_bidirectional_impl_2', 'GRU', True, True, 2),
- )
+ *testing_utils.generate_combinations_with_testcase_name(
+ rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False],
+ bidirectional=[True, False], implementation=[1, 2],
+ model_nest_level=[1, 2], model_type=['seq', 'func']))
def test_load_weights_between_noncudnn_rnn(self, rnn_type, to_cudnn,
- bidirectional, implementation):
+ bidirectional, implementation,
+ model_nest_level, model_type):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
input_size = 10
@@ -261,14 +250,6 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
cudnn_rnn_layer_class = keras.layers.CuDNNGRU
rnn_layer_kwargs['reset_after'] = True
- def convert_weights(source_layer, target_layer):
- weights = source_layer.get_weights()
- weights = keras.engine.saving.preprocess_weights_for_loading(
- target_layer, weights)
- target_layer.set_weights(weights)
-
- input_layer = keras.layers.InputLayer(input_shape)
-
layer = rnn_layer_class(units, **rnn_layer_kwargs)
if bidirectional:
layer = keras.layers.Bidirectional(layer)
@@ -277,18 +258,96 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
if bidirectional:
cudnn_layer = keras.layers.Bidirectional(cudnn_layer)
- model = keras.models.Sequential([input_layer, layer])
- cudnn_model = keras.models.Sequential([input_layer, cudnn_layer])
+ model = self._make_nested_model(input_shape, layer, model_nest_level,
+ model_type)
+ cudnn_model = self._make_nested_model(input_shape, cudnn_layer,
+ model_nest_level, model_type)
+
+ if to_cudnn:
+ self._convert_model_weights(model, cudnn_model)
+ else:
+ self._convert_model_weights(cudnn_model, model)
+
+ self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs),
+ atol=1e-4)
+
+ def _make_nested_model(self, input_shape, layer, level=1, model_type='func'):
+ # example: make_nested_seq_model((1,), Dense(10), level=2).summary()
+ def make_nested_seq_model(input_shape, layer, level=1):
+ model = layer
+ for i in range(1, level + 1):
+ layers = [keras.layers.InputLayer(input_shape),
+ model] if (i == 1) else [model]
+ model = keras.models.Sequential(layers)
+ return model
+
+ # example: make_nested_func_model((1,), Dense(10), level=2).summary()
+ def make_nested_func_model(input_shape, layer, level=1):
+ model_input = keras.layers.Input(input_shape)
+ model = layer
+ for _ in range(level):
+ model = keras.models.Model(model_input, model(model_input))
+ return model
+
+ if model_type == 'func':
+ return make_nested_func_model(input_shape, layer, level)
+ elif model_type == 'seq':
+ return make_nested_seq_model(input_shape, layer, level)
+
+ def _convert_model_weights(self, source_model, target_model):
+ _, fname = tempfile.mkstemp('.h5')
+ source_model.save_weights(fname)
+ target_model.load_weights(fname)
+ os.remove(fname)
+
+ @parameterized.named_parameters(
+ *testing_utils.generate_combinations_with_testcase_name(
+ rnn_type=['LSTM', 'GRU'], to_cudnn=[True, False]))
+ def test_load_weights_between_noncudnn_rnn_time_distributed(self, rnn_type,
+ to_cudnn):
+ # Similar test as test_load_weights_between_noncudnn_rnn() but has different
+ # rank of input due to usage of TimeDistributed. Issue: #10356.
+ if test.is_gpu_available(cuda_only=True):
+ with self.test_session(use_gpu=True):
+ input_size = 10
+ steps = 6
+ timesteps = 6
+ input_shape = (timesteps, steps, input_size)
+ units = 2
+ num_samples = 32
+ inputs = np.random.random((num_samples, timesteps, steps, input_size))
+
+ rnn_layer_kwargs = {
+ 'recurrent_activation': 'sigmoid',
+ # ensure biases are non-zero and properly converted
+ 'bias_initializer': 'random_uniform',
+ }
+ if rnn_type == 'LSTM':
+ rnn_layer_class = keras.layers.LSTM
+ cudnn_rnn_layer_class = keras.layers.CuDNNLSTM
+ else:
+ rnn_layer_class = keras.layers.GRU
+ cudnn_rnn_layer_class = keras.layers.CuDNNGRU
+ rnn_layer_kwargs['reset_after'] = True
+
+ layer = rnn_layer_class(units, **rnn_layer_kwargs)
+ layer = keras.layers.TimeDistributed(layer)
+
+ cudnn_layer = cudnn_rnn_layer_class(units)
+ cudnn_layer = keras.layers.TimeDistributed(cudnn_layer)
+
+ model = self._make_nested_model(input_shape, layer)
+ cudnn_model = self._make_nested_model(input_shape, cudnn_layer)
if to_cudnn:
- convert_weights(layer, cudnn_layer)
+ self._convert_model_weights(model, cudnn_model)
else:
- convert_weights(cudnn_layer, layer)
+ self._convert_model_weights(cudnn_model, model)
- self.assertAllClose(
- model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4)
+ self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs),
+ atol=1e-4)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_cudnnrnn_bidirectional(self):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 25eeeee952..629a9ec9a1 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -22,7 +22,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
@@ -112,6 +112,7 @@ class Embedding(Layer):
self.activity_regularizer = regularizers.get(activity_regularizer)
self.embeddings_constraint = constraints.get(embeddings_constraint)
self.mask_zero = mask_zero
+ self.supports_masking = mask_zero
self.input_length = input_length
@tf_utils.shape_type_conversion
@@ -127,8 +128,8 @@ class Embedding(Layer):
def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
return None
- else:
- return math_ops.not_equal(inputs, 0)
+
+ return math_ops.not_equal(inputs, 0)
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
diff --git a/tensorflow/python/keras/layers/gru_test.py b/tensorflow/python/keras/layers/gru_test.py
index 234434f7a0..57f660b6d5 100644
--- a/tensorflow/python/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/layers/gru_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class GRULayerTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_return_sequences_GRU(self):
num_samples = 2
timesteps = 3
@@ -41,7 +41,7 @@ class GRULayerTest(test.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dynamic_behavior_GRU(self):
num_samples = 2
timesteps = 3
@@ -55,7 +55,7 @@ class GRULayerTest(test.TestCase):
y = np.random.random((num_samples, units))
model.train_on_batch(x, y)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dropout_GRU(self):
num_samples = 2
timesteps = 3
@@ -68,7 +68,7 @@ class GRULayerTest(test.TestCase):
'recurrent_dropout': 0.1},
input_shape=(num_samples, timesteps, embedding_dim))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_implementation_mode_GRU(self):
num_samples = 2
timesteps = 3
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 46c18b763e..0ebafe07cc 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -23,8 +23,8 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util.tf_export import tf_export
@@ -62,6 +62,16 @@ class LocallyConnected1D(Layer):
any `dilation_rate` value != 1.
padding: Currently only supports `"valid"` (case-insensitive).
`"same"` may be supported in the future.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, length, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, length)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
activation: Activation function to use.
If you don't specify anything, no activation is applied
(ie. "linear" activation: `a(x) = x`).
@@ -122,13 +132,17 @@ class LocallyConnected1D(Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
- input_dim = input_shape[2]
+ if self.data_format == 'channels_first':
+ input_dim, input_length = input_shape[1], input_shape[2]
+ else:
+ input_dim, input_length = input_shape[2], input_shape[1]
+
if input_dim is None:
raise ValueError('Axis 2 of input should be fully-defined. '
'Found shape:', input_shape)
- output_length = conv_utils.conv_output_length(
- input_shape[1], self.kernel_size[0], self.padding, self.strides[0])
- self.kernel_shape = (output_length, self.kernel_size[0] * input_dim,
+ self.output_length = conv_utils.conv_output_length(
+ input_length, self.kernel_size[0], self.padding, self.strides[0])
+ self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
self.filters)
self.kernel = self.add_weight(
shape=self.kernel_shape,
@@ -138,28 +152,43 @@ class LocallyConnected1D(Layer):
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- shape=(output_length, self.filters),
+ shape=(self.output_length, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
self.bias = None
- self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
+
+ if self.data_format == 'channels_first':
+ self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
+ else:
+ self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
self.built = True
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
- length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0],
+ if self.data_format == 'channels_first':
+ input_length = input_shape[2]
+ else:
+ input_length = input_shape[1]
+
+ length = conv_utils.conv_output_length(input_length, self.kernel_size[0],
self.padding, self.strides[0])
- return (input_shape[0], length, self.filters)
+
+ if self.data_format == 'channels_first':
+ return (input_shape[0], self.filters, length)
+ elif self.data_format == 'channels_last':
+ return (input_shape[0], length, self.filters)
def call(self, inputs):
- output = K.local_conv1d(inputs, self.kernel, self.kernel_size, self.strides)
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_length,), self.data_format)
+
if self.use_bias:
- output = K.bias_add(output, self.bias)
- if self.activation is not None:
- output = self.activation(output)
+ output = K.bias_add(output, self.bias, data_format=self.data_format)
+
+ output = self.activation(output)
return output
def get_config(self):
@@ -172,6 +201,8 @@ class LocallyConnected1D(Layer):
self.strides,
'padding':
self.padding,
+ 'data_format':
+ self.data_format,
'activation':
activations.serialize(self.activation),
'use_bias':
@@ -370,9 +401,8 @@ class LocallyConnected2D(Layer):
return (input_shape[0], rows, cols, self.filters)
def call(self, inputs):
- output = K.local_conv2d(inputs, self.kernel, self.kernel_size, self.strides,
- (self.output_row, self.output_col),
- self.data_format)
+ output = K.local_conv(inputs, self.kernel, self.kernel_size, self.strides,
+ (self.output_row, self.output_col), self.data_format)
if self.use_bias:
output = K.bias_add(output, self.bias, data_format=self.data_format)
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index 90ae1719e1..9639e0251f 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class LocallyConnectedLayersTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_locallyconnected_1d(self):
num_samples = 2
num_steps = 8
@@ -40,16 +40,17 @@ class LocallyConnectedLayersTest(test.TestCase):
for strides in [1]:
if padding == 'same' and strides != 1:
continue
-
- testing_utils.layer_test(
- keras.layers.LocallyConnected1D,
- kwargs={
- 'filters': filters,
- 'kernel_size': filter_length,
- 'padding': padding,
- 'strides': strides
- },
- input_shape=(num_samples, num_steps, input_dim))
+ for data_format in ['channels_first', 'channels_last']:
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected1D,
+ kwargs={
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'padding': padding,
+ 'strides': strides,
+ 'data_format': data_format
+ },
+ input_shape=(num_samples, num_steps, input_dim))
def test_locallyconnected_1d_regularization(self):
num_samples = 2
@@ -57,37 +58,41 @@ class LocallyConnectedLayersTest(test.TestCase):
input_dim = 5
filter_length = 3
filters = 4
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_regularizer': 'l2',
- 'bias_regularizer': 'l2',
- 'activity_regularizer': 'l2',
- }
-
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(len(layer.losses), 2)
- layer(
- keras.backend.variable(np.ones((num_samples, num_steps, input_dim))))
- self.assertEqual(len(layer.losses), 3)
-
- k_constraint = keras.constraints.max_norm(0.01)
- b_constraint = keras.constraints.max_norm(0.01)
- kwargs = {
- 'filters': filters,
- 'kernel_size': filter_length,
- 'kernel_constraint': k_constraint,
- 'bias_constraint': b_constraint,
- }
- with self.test_session():
- layer = keras.layers.LocallyConnected1D(**kwargs)
- layer.build((num_samples, num_steps, input_dim))
- self.assertEqual(layer.kernel.constraint, k_constraint)
- self.assertEqual(layer.bias.constraint, b_constraint)
-
- @tf_test_util.run_in_graph_and_eager_modes()
+ for data_format in ['channels_first', 'channels_last']:
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_regularizer': 'l2',
+ 'bias_regularizer': 'l2',
+ 'activity_regularizer': 'l2',
+ 'data_format': data_format
+ }
+
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(len(layer.losses), 2)
+ layer(
+ keras.backend.variable(np.ones((num_samples,
+ num_steps,
+ input_dim))))
+ self.assertEqual(len(layer.losses), 3)
+
+ k_constraint = keras.constraints.max_norm(0.01)
+ b_constraint = keras.constraints.max_norm(0.01)
+ kwargs = {
+ 'filters': filters,
+ 'kernel_size': filter_length,
+ 'kernel_constraint': k_constraint,
+ 'bias_constraint': b_constraint,
+ }
+ with self.test_session():
+ layer = keras.layers.LocallyConnected1D(**kwargs)
+ layer.build((num_samples, num_steps, input_dim))
+ self.assertEqual(layer.kernel.constraint, k_constraint)
+ self.assertEqual(layer.bias.constraint, b_constraint)
+
+ @tf_test_util.run_in_graph_and_eager_modes
def test_locallyconnected_2d(self):
num_samples = 8
filters = 3
@@ -113,6 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase):
},
input_shape=(num_samples, num_row, num_col, stack_size))
+ @tf_test_util.run_in_graph_and_eager_modes
def test_locallyconnected_2d_channels_first(self):
num_samples = 8
filters = 3
@@ -120,15 +126,14 @@ class LocallyConnectedLayersTest(test.TestCase):
num_row = 6
num_col = 10
- with self.test_session():
- testing_utils.layer_test(
- keras.layers.LocallyConnected2D,
- kwargs={
- 'filters': filters,
- 'kernel_size': 3,
- 'data_format': 'channels_first'
- },
- input_shape=(num_samples, num_row, num_col, stack_size))
+ testing_utils.layer_test(
+ keras.layers.LocallyConnected2D,
+ kwargs={
+ 'filters': filters,
+ 'kernel_size': 3,
+ 'data_format': 'channels_first'
+ },
+ input_shape=(num_samples, num_row, num_col, stack_size))
def test_locallyconnected_2d_regularization(self):
num_samples = 8
diff --git a/tensorflow/python/keras/layers/lstm_test.py b/tensorflow/python/keras/layers/lstm_test.py
index 87cb344bf8..ae381f5955 100644
--- a/tensorflow/python/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/layers/lstm_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class LSTMLayerTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_return_sequences_LSTM(self):
num_samples = 2
timesteps = 3
@@ -56,7 +56,7 @@ class LSTMLayerTest(test.TestCase):
outputs = model.layers[-1].output
self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units])
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dynamic_behavior_LSTM(self):
num_samples = 2
timesteps = 3
@@ -70,7 +70,7 @@ class LSTMLayerTest(test.TestCase):
y = np.random.random((num_samples, units))
model.train_on_batch(x, y)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dropout_LSTM(self):
num_samples = 2
timesteps = 3
@@ -83,7 +83,7 @@ class LSTMLayerTest(test.TestCase):
'recurrent_dropout': 0.1},
input_shape=(num_samples, timesteps, embedding_dim))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_implementation_mode_LSTM(self):
num_samples = 2
timesteps = 3
diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py
index 683e3e0ed1..f295af3fe0 100644
--- a/tensorflow/python/keras/layers/merge.py
+++ b/tensorflow/python/keras/layers/merge.py
@@ -250,6 +250,7 @@ class Add(_Merge):
return output
+@tf_export('keras.layers.Subtract')
class Subtract(_Merge):
"""Layer that subtracts two inputs.
@@ -336,6 +337,7 @@ class Maximum(_Merge):
return output
+@tf_export('keras.layers.Minimum')
class Minimum(_Merge):
"""Layer that computes the minimum (element-wise) a list of inputs.
@@ -446,8 +448,8 @@ class Concatenate(_Merge):
class Dot(_Merge):
"""Layer that computes a dot product between samples in two tensors.
- E.g. if applied to two tensors `a` and `b` of shape `(batch_size, n)`,
- the output will be a tensor of shape `(batch_size, 1)`
+ E.g. if applied to a list of two tensors `a` and `b` of shape
+ `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
where each entry `i` will be the dot product between
`a[i]` and `b[i]`.
@@ -586,6 +588,7 @@ def add(inputs, **kwargs):
return Add(**kwargs)(inputs)
+@tf_export('keras.layers.subtract')
def subtract(inputs, **kwargs):
"""Functional interface to the `Subtract` layer.
@@ -656,6 +659,7 @@ def maximum(inputs, **kwargs):
return Maximum(**kwargs)(inputs)
+@tf_export('keras.layers.minimum')
def minimum(inputs, **kwargs):
"""Functional interface to the `Minimum` layer.
diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py
index 8a097cf7f5..39bc98d039 100644
--- a/tensorflow/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/layers/merge_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class MergeLayersTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_add(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -76,7 +76,7 @@ class MergeLayersTest(test.TestCase):
with self.assertRaises(ValueError):
keras.layers.add([i1])
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_multiply(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -92,7 +92,7 @@ class MergeLayersTest(test.TestCase):
self.assertEqual(out.shape, (2, 4, 5))
self.assertAllClose(out, x1 * x2 * x3, atol=1e-4)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_average(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -106,7 +106,7 @@ class MergeLayersTest(test.TestCase):
self.assertEqual(out.shape, (2, 4, 5))
self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_maximum(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -120,7 +120,7 @@ class MergeLayersTest(test.TestCase):
self.assertEqual(out.shape, (2, 4, 5))
self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_minimum(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -134,7 +134,7 @@ class MergeLayersTest(test.TestCase):
self.assertEqual(out.shape, (2, 4, 5))
self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_concatenate(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
@@ -169,7 +169,7 @@ class MergeLayersTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'called on a list'):
keras.layers.concatenate([i1], axis=-1)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_dot(self):
i1 = keras.layers.Input(shape=(4,))
i2 = keras.layers.Input(shape=(4,))
@@ -215,7 +215,7 @@ class MergeLayersTest(test.TestCase):
dot = keras.layers.Dot(1)
dot.compute_output_shape(1)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_merge_subtract(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
diff --git a/tensorflow/python/keras/layers/noise.py b/tensorflow/python/keras/layers/noise.py
index a895caa25b..cb7cee3ebc 100644
--- a/tensorflow/python/keras/layers/noise.py
+++ b/tensorflow/python/keras/layers/noise.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py
index bde2185f03..aa2be62390 100644
--- a/tensorflow/python/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/layers/noise_test.py
@@ -40,7 +40,7 @@ class NoiseLayersTest(test.TestCase):
kwargs={'rate': 0.5},
input_shape=(3, 2, 3))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_AlphaDropout(self):
testing_utils.layer_test(
keras.layers.AlphaDropout,
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 7743d00c0f..58c8a8a66d 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -26,14 +26,15 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.tf_export import tf_export
@@ -180,11 +181,6 @@ class BatchNormalization(Layer):
self.renorm_clipping = renorm_clipping
self.renorm_momentum = renorm_momentum
- def _add_tower_local_variable(self, *args, **kwargs):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope('mean'):
- return self.add_variable(*args, **kwargs)
-
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -276,7 +272,7 @@ class BatchNormalization(Layer):
self.axis[idx] = x + 1 # Account for added dimension
if self.scale:
- self.gamma = self.add_variable(
+ self.gamma = self.add_weight(
name='gamma',
shape=param_shape,
dtype=param_dtype,
@@ -291,7 +287,7 @@ class BatchNormalization(Layer):
1.0, dtype=param_dtype, shape=param_shape)
if self.center:
- self.beta = self.add_variable(
+ self.beta = self.add_weight(
name='beta',
shape=param_shape,
dtype=param_dtype,
@@ -312,19 +308,23 @@ class BatchNormalization(Layer):
self._scope.set_partitioner(None)
else:
partitioner = None
- self.moving_mean = self._add_tower_local_variable(
+ self.moving_mean = self.add_weight(
name='moving_mean',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- self.moving_variance = self._add_tower_local_variable(
+ self.moving_variance = self.add_weight(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
@@ -335,12 +335,14 @@ class BatchNormalization(Layer):
# stack to be cleared. The nested ones use a `lambda` to set the desired
# device and ignore any devices that may be set by the custom getter.
def _renorm_variable(name, shape):
- var = self._add_tower_local_variable(
+ var = self.add_weight(
name=name,
shape=shape,
dtype=param_dtype,
initializer=init_ops.zeros_initializer(),
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
return var
with distribute_lib.get_distribution_strategy().colocate_vars_with(
@@ -364,11 +366,12 @@ class BatchNormalization(Layer):
def _assign_moving_average(self, variable, value, momentum):
with ops.name_scope(None, 'AssignMovingAvg',
[variable, value, momentum]) as scope:
- decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ with ops.colocate_with(variable):
+ decay = ops.convert_to_tensor(1.0 - momentum, name='decay')
+ if decay.dtype != variable.dtype.base_dtype:
+ decay = math_ops.cast(decay, variable.dtype.base_dtype)
+ update_delta = (variable - value) * decay
+ return state_ops.assign_sub(variable, update_delta, name=scope)
def _fused_batch_norm(self, inputs, training):
"""Returns the output of fused batch norm."""
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 10a82b285e..912e8bd619 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -20,8 +20,8 @@ from __future__ import print_function
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index cbd58a2287..2cd9939e66 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -27,14 +27,14 @@ from tensorflow.python.platform import test
class GlobalPoolingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
testing_utils.layer_test(
keras.layers.pooling.GlobalMaxPooling2D,
@@ -53,7 +53,7 @@ class GlobalPoolingTest(test.TestCase):
kwargs={'data_format': 'channels_last'},
input_shape=(3, 5, 6, 4))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_3d(self):
testing_utils.layer_test(
keras.layers.pooling.GlobalMaxPooling3D,
@@ -75,7 +75,7 @@ class GlobalPoolingTest(test.TestCase):
class Pooling2DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_maxpooling_2d(self):
pool_size = (3, 3)
for strides in [(1, 1), (2, 2)]:
@@ -88,7 +88,7 @@ class Pooling2DTest(test.TestCase):
},
input_shape=(3, 5, 6, 4))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_2d(self):
testing_utils.layer_test(
keras.layers.AveragePooling2D,
@@ -122,7 +122,7 @@ class Pooling2DTest(test.TestCase):
class Pooling3DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_maxpooling_3d(self):
pool_size = (3, 3, 3)
testing_utils.layer_test(
@@ -141,7 +141,7 @@ class Pooling3DTest(test.TestCase):
},
input_shape=(3, 4, 11, 12, 10))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_3d(self):
pool_size = (3, 3, 3)
testing_utils.layer_test(
@@ -163,7 +163,7 @@ class Pooling3DTest(test.TestCase):
class Pooling1DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_maxpooling_1d(self):
for padding in ['valid', 'same']:
for stride in [1, 2]:
@@ -173,7 +173,7 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
for padding in ['valid', 'same']:
for stride in [1, 2]:
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 7e509fb451..61775da47b 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -29,8 +29,8 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@@ -235,7 +235,8 @@ class RNN(Layer):
"""Base class for recurrent layers.
Arguments:
- cell: A RNN cell instance. A RNN cell is a class that has:
+ cell: A RNN cell instance or a list of RNN cell instances.
+ A RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
@@ -248,9 +249,9 @@ class RNN(Layer):
(one size per state). In this case, the first entry
(`state_size[0]`) should be the same as
the size of the cell output.
- It is also possible for `cell` to be a list of RNN cell instances,
- in which cases the cells get stacked on after the other in the RNN,
- implementing an efficient stacked RNN.
+ In the case that `cell` is a list of RNN cell instances, the cells
+ will be stacked on after the other in the RNN, implementing an
+ efficient stacked RNN.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py
index be306c0af7..7c45e08b5c 100644
--- a/tensorflow/python/keras/layers/serialization.py
+++ b/tensorflow/python/keras/layers/serialization.py
@@ -20,8 +20,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.keras.engine import Input
-from tensorflow.python.keras.engine import InputLayer
+from tensorflow.python.keras.engine.input_layer import Input
+from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.layers.advanced_activations import *
from tensorflow.python.keras.layers.convolutional import *
from tensorflow.python.keras.layers.convolutional_recurrent import *
diff --git a/tensorflow/python/keras/layers/simplernn_test.py b/tensorflow/python/keras/layers/simplernn_test.py
index 3d24b0d504..18fefbe84f 100644
--- a/tensorflow/python/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/layers/simplernn_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.training.rmsprop import RMSPropOptimizer
class SimpleRNNLayerTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_return_sequences_SimpleRNN(self):
num_samples = 2
timesteps = 3
@@ -41,7 +41,7 @@ class SimpleRNNLayerTest(test.TestCase):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dynamic_behavior_SimpleRNN(self):
num_samples = 2
timesteps = 3
@@ -55,7 +55,7 @@ class SimpleRNNLayerTest(test.TestCase):
y = np.random.random((num_samples, units))
model.train_on_batch(x, y)
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_dropout_SimpleRNN(self):
num_samples = 2
timesteps = 3
@@ -68,7 +68,7 @@ class SimpleRNNLayerTest(test.TestCase):
'recurrent_dropout': 0.1},
input_shape=(num_samples, timesteps, embedding_dim))
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_implementation_mode_SimpleRNN(self):
num_samples = 2
timesteps = 3
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index 7759561ef9..f651e03874 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -23,8 +23,8 @@ import copy
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.engine import InputSpec
-from tensorflow.python.keras.engine import Layer
+from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
@@ -45,7 +45,9 @@ class Wrapper(Layer):
"""
def __init__(self, layer, **kwargs):
+ assert isinstance(layer, Layer)
self.layer = layer
+ self._track_checkpointable(layer, name='layer')
# Tracks mapping of Wrapper inputs to inner layer inputs. Useful when
# the inner layer has update ops that depend on its inputs (as opposed
# to the inputs to the Wrapper layer).
@@ -154,19 +156,61 @@ class TimeDistributed(Wrapper):
Arguments:
layer: a layer instance.
+
+ Raises:
+ ValueError: If not initialized with a `Layer` instance.
"""
def __init__(self, layer, **kwargs):
+ if not isinstance(layer, Layer):
+ raise ValueError(
+ 'Please initialize `TimeDistributed` layer with a '
+ '`Layer` instance. You passed: {input}'.format(input=layer))
super(TimeDistributed, self).__init__(layer, **kwargs)
self.supports_masking = True
+ def _get_shape_tuple(self, init_tuple, tensor, start_idx, int_shape=None):
+ """Finds non-specific dimensions in the static shapes.
+
+ The static shapes are replaced with the corresponding dynamic shapes of the
+ tensor.
+
+ Arguments:
+ init_tuple: a tuple, the first part of the output shape
+ tensor: the tensor from which to get the (static and dynamic) shapes
+ as the last part of the output shape
+ start_idx: int, which indicate the first dimension to take from
+ the static shape of the tensor
+ int_shape: an alternative static shape to take as the last part
+ of the output shape
+ Returns:
+ The new int_shape with the first part from init_tuple
+ and the last part from either `int_shape` (if provided)
+ or `tensor.shape`, where every `None` is replaced by
+ the corresponding dimension from `tf.shape(tensor)`.
+ """
+ # replace all None in int_shape by K.shape
+ if int_shape is None:
+ int_shape = K.int_shape(tensor)[start_idx:]
+ if not any(not s for s in int_shape):
+ return init_tuple + tuple(int_shape)
+ shape = K.shape(tensor)
+ int_shape = list(int_shape)
+ for i, s in enumerate(int_shape):
+ if not s:
+ int_shape[i] = shape[start_idx + i]
+ return init_tuple + tuple(int_shape)
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
assert len(input_shape) >= 3
self.input_spec = InputSpec(shape=input_shape)
child_input_shape = [input_shape[0]] + input_shape[2:]
if not self.layer.built:
- self.layer.build(child_input_shape)
+ # The base layer class calls a conversion function on the input shape to
+ # convert it to a TensorShape. The conversion function requires a
+ # tuple which is why we cast the shape.
+ self.layer.build(tuple(child_input_shape))
self.layer.built = True
super(TimeDistributed, self).build()
self.built = True
@@ -212,18 +256,24 @@ class TimeDistributed(Wrapper):
input_length = input_shape[1]
if not input_length:
input_length = array_ops.shape(inputs)[1]
+ inner_input_shape = self._get_shape_tuple((-1,), inputs, 2)
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
input_uid = generic_utils.object_list_uid(inputs)
- inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
+ inputs = array_ops.reshape(inputs, inner_input_shape)
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
+ if generic_utils.has_arg(self.layer.call, 'mask') and mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ kwargs['mask'] = K.reshape(mask, inner_mask_shape)
y = self.layer.call(inputs, **kwargs)
if hasattr(y, '_uses_learning_phase'):
uses_learning_phase = y._uses_learning_phase
# Shape: (num_samples, timesteps, ...)
output_shape = self.compute_output_shape(input_shape).as_list()
- y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))
+ output_shape = self._get_shape_tuple(
+ (-1, input_length), y, 1, output_shape[2:])
+ y = array_ops.reshape(y, output_shape)
# Apply activity regularizer if any:
if (hasattr(self.layer, 'activity_regularizer') and
@@ -235,6 +285,80 @@ class TimeDistributed(Wrapper):
y._uses_learning_phase = True
return y
+ def compute_mask(self, inputs, mask=None):
+ """Computes an output mask tensor for Embedding layer.
+
+ This is based on the inputs, mask, and the inner layer.
+ If batch size is specified:
+ Simply return the input `mask`. (An rnn-based implementation with
+ more than one rnn inputs is required but not supported in tf.keras yet.)
+ Otherwise we call `compute_mask` of the inner layer at each time step.
+ If the output mask at each time step is not `None`:
+ (E.g., inner layer is Masking or RNN)
+ Concatenate all of them and return the concatenation.
+ If the output mask at each time step is `None` and the input mask is not
+ `None`:(E.g., inner layer is Dense)
+ Reduce the input_mask to 2 dimensions and return it.
+ Otherwise (both the output mask and the input mask are `None`):
+ (E.g., `mask` is not used at all)
+ Return `None`.
+
+ Arguments:
+ inputs: Tensor with shape [batch size, timesteps, ...] indicating the
+ input to TimeDistributed. If static shape information is available for
+ "batch size", `mask` is returned unmodified.
+ mask: Either None (indicating no masking) or a Tensor indicating the
+ input mask for TimeDistributed. The shape can be static or dynamic.
+
+ Returns:
+ Either None (no masking), or a [batch size, timesteps, ...] Tensor with
+ an output mask for the TimeDistributed layer with the shape beyond the
+ second dimension being the value of the input mask shape(if the computed
+ output mask is none), an output mask with the shape beyond the first
+ dimension being the value of the mask shape(if mask is not None) or
+ output mask with the shape beyond the first dimension being the
+ value of the computed output shape.
+
+ """
+ # cases need to call the layer.compute_mask when input_mask is None:
+ # Masking layer and Embedding layer with mask_zero
+ input_shape = K.int_shape(inputs)
+ if input_shape[0]:
+ # batch size matters, we currently do not handle mask explicitly
+ return mask
+ inner_mask = mask
+ if inner_mask is not None:
+ inner_mask_shape = self._get_shape_tuple((-1,), mask, 2)
+ inner_mask = K.reshape(inner_mask, inner_mask_shape)
+ input_uid = generic_utils.object_list_uid(inputs)
+ inner_inputs = self._input_map[input_uid]
+ output_mask = self.layer.compute_mask(inner_inputs, inner_mask)
+ if output_mask is None:
+ if mask is None:
+ return None
+ # input_mask is not None, and output_mask is None:
+ # we should return a not-None mask
+ output_mask = mask
+ for _ in range(2, len(K.int_shape(mask))):
+ output_mask = K.any(output_mask, axis=-1)
+ else:
+ # output_mask is not None. We need to reshape it
+ input_length = input_shape[1]
+ if not input_length:
+ input_length = K.shape(inputs)[1]
+ output_mask_int_shape = K.int_shape(output_mask)
+ if output_mask_int_shape is None:
+ # if the output_mask does not have a static shape,
+ # its shape must be the same as mask's
+ if mask is not None:
+ output_mask_int_shape = K.int_shape(mask)
+ else:
+ output_mask_int_shape = K.compute_output_shape(input_shape)[:-1]
+ output_mask_shape = self._get_shape_tuple(
+ (-1, input_length), output_mask, 1, output_mask_int_shape[1:])
+ output_mask = K.reshape(output_mask, output_mask_shape)
+ return output_mask
+
@tf_export('keras.layers.Bidirectional')
class Bidirectional(Wrapper):
@@ -249,7 +373,8 @@ class Bidirectional(Wrapper):
they will be returned as a list.
Raises:
- ValueError: In case of invalid `merge_mode` argument.
+ ValueError: If not initialized with a `Layer` instance or
+ In case of invalid `merge_mode` argument.
Examples:
@@ -265,6 +390,10 @@ class Bidirectional(Wrapper):
"""
def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
+ if not isinstance(layer, Layer):
+ raise ValueError(
+ 'Please initialize `Bidirectional` layer with a '
+ '`Layer` instance. You passed: {input}'.format(input=layer))
if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
raise ValueError('Invalid merge mode. '
'Merge mode should be one of '
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 5eab6aba8a..3f268acf5c 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -23,8 +23,10 @@ import copy
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -69,7 +71,7 @@ class _RNNCellWithConstants(keras.layers.Layer):
class TimeDistributedTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes
def test_timedistributed_dense(self):
model = keras.models.Sequential()
model.add(
@@ -85,6 +87,10 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
+ checkpointed_objects = set(checkpointable_util.list_objects(model))
+ for v in model.variables:
+ self.assertIn(v, checkpointed_objects)
+
def test_timedistributed_static_batch_size(self):
model = keras.models.Sequential()
model.add(
@@ -97,6 +103,13 @@ class TimeDistributedTest(test.TestCase):
epochs=1,
batch_size=10)
+ def test_timedistributed_invalid_init(self):
+ x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Please initialize `TimeDistributed` layer with a `Layer` instance.'):
+ keras.layers.TimeDistributed(x)
+
def test_timedistributed_conv2d(self):
with self.test_session():
model = keras.models.Sequential()
@@ -177,8 +190,8 @@ class TimeDistributedTest(test.TestCase):
x = keras.layers.Input(shape=(3, 2))
layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
_ = layer(x)
- assert len(layer.updates) == 2
- assert len(layer.trainable_weights) == 2
+ self.assertEquals(len(layer.updates), 2)
+ self.assertEquals(len(layer.trainable_weights), 2)
layer.trainable = False
assert not layer.updates
assert not layer.trainable_weights
@@ -186,6 +199,62 @@ class TimeDistributedTest(test.TestCase):
assert len(layer.updates) == 2
assert len(layer.trainable_weights) == 2
+ def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
+ with self.test_session():
+ # test with unspecified shape and Embeddings with mask_zero
+ model = keras.models.Sequential()
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.Embedding(5, 6, mask_zero=True),
+ input_shape=(None, None))) # N by t_1 by t_2 by 6
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.SimpleRNN(7, return_sequences=True)))
+ model.add(keras.layers.TimeDistributed(
+ keras.layers.SimpleRNN(8, return_sequences=False)))
+ model.add(keras.layers.SimpleRNN(1, return_sequences=False))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model_input = np.random.randint(low=1, high=5, size=(10, 3, 4),
+ dtype='int32')
+ for i in range(4):
+ model_input[i, i:, i:] = 0
+ model.fit(model_input,
+ np.random.random((10, 1)), epochs=1, batch_size=10)
+ mask_outputs = [model.layers[0].compute_mask(model.input)]
+ for layer in model.layers[1:]:
+ mask_outputs.append(layer.compute_mask(layer.input, mask_outputs[-1]))
+ func = keras.backend.function([model.input], mask_outputs[:-1])
+ mask_outputs_val = func([model_input])
+ ref_mask_val_0 = model_input > 0 # embedding layer
+ ref_mask_val_1 = ref_mask_val_0 # first RNN layer
+ ref_mask_val_2 = np.any(ref_mask_val_1, axis=-1) # second RNN layer
+ ref_mask_val = [ref_mask_val_0, ref_mask_val_1, ref_mask_val_2]
+ for i in range(3):
+ self.assertAllEqual(mask_outputs_val[i], ref_mask_val[i])
+ self.assertIs(mask_outputs[-1], None) # final layer
+
+ def test_TimeDistributed_with_masking_layer(self):
+ with self.test_session():
+ # test with Masking layer
+ model = keras.models.Sequential()
+ model.add(keras.layers.TimeDistributed(keras.layers.Masking(
+ mask_value=0.,), input_shape=(None, 4)))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(5)))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model_input = np.random.randint(low=1, high=5, size=(10, 3, 4))
+ for i in range(4):
+ model_input[i, i:, :] = 0.
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.fit(model_input,
+ np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+ mask_outputs = [model.layers[0].compute_mask(model.input)]
+ mask_outputs += [model.layers[1].compute_mask(model.layers[1].input,
+ mask_outputs[-1])]
+ func = keras.backend.function([model.input], mask_outputs)
+ mask_outputs_val = func([model_input])
+ self.assertEqual((mask_outputs_val[0]).all(),
+ model_input.all())
+ self.assertEqual((mask_outputs_val[1]).all(),
+ model_input.all())
+
class BidirectionalTest(test.TestCase):
@@ -220,6 +289,13 @@ class BidirectionalTest(test.TestCase):
model = keras.models.model_from_json(model.to_json())
model.summary()
+ def test_bidirectional_invalid_init(self):
+ x = constant_op.constant(np.zeros((1, 1)).astype('float32'))
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'Please initialize `Bidirectional` layer with a `Layer` instance.'):
+ keras.layers.Bidirectional(x)
+
def test_bidirectional_weight_loading(self):
rnn = keras.layers.SimpleRNN
samples = 2
@@ -424,6 +500,42 @@ class BidirectionalTest(test.TestCase):
layer.trainable = True
assert len(layer.trainable_weights) == 6
+ def test_Bidirectional_updates(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3, 2))
+ x_reachable_update = x * x
+ layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
+ _ = layer(x)
+ assert not layer.updates
+ assert not layer.get_updates_for(None)
+ assert not layer.get_updates_for(x)
+ layer.forward_layer.add_update(x_reachable_update, inputs=x)
+ layer.forward_layer.add_update(1, inputs=None)
+ layer.backward_layer.add_update(x_reachable_update, inputs=x)
+ layer.backward_layer.add_update(1, inputs=None)
+ assert len(layer.updates) == 4
+ assert len(layer.get_updates_for(None)) == 2
+ assert len(layer.get_updates_for(x)) == 2
+
+ def test_Bidirectional_losses(self):
+ with self.test_session():
+ x = keras.layers.Input(shape=(3, 2))
+ x_reachable_loss = x * x
+ layer = keras.layers.Bidirectional(
+ keras.layers.SimpleRNN(
+ 3, kernel_regularizer='l1', bias_regularizer='l1'))
+ _ = layer(x)
+ assert len(layer.losses) == 4
+ assert len(layer.get_losses_for(None)) == 4
+ assert not layer.get_losses_for(x)
+ layer.forward_layer.add_loss(x_reachable_loss, inputs=x)
+ layer.forward_layer.add_loss(1, inputs=None)
+ layer.backward_layer.add_loss(x_reachable_loss, inputs=x)
+ layer.backward_layer.add_loss(1, inputs=None)
+ assert len(layer.losses) == 8
+ assert len(layer.get_losses_for(None)) == 6
+ assert len(layer.get_losses_for(x)) == 2
+
def test_Bidirectional_with_constants(self):
with self.test_session():
# Test basic case.
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 8fb957da43..3ac4852eff 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@@ -173,7 +173,7 @@ def get_nested_model_3(input_dim, num_classes):
class ModelSubclassingTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_io_workflow_with_np_arrays(self):
num_classes = 2
num_samples = 100
@@ -192,7 +192,7 @@ class ModelSubclassingTest(test.TestCase):
model.fit(x, y, epochs=2, batch_size=32, verbose=0)
_ = model.evaluate(x, y, verbose=0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_multi_io_workflow_with_np_arrays(self):
num_classes = (2, 3)
num_samples = 1000
@@ -251,7 +251,7 @@ class ModelSubclassingTest(test.TestCase):
model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0)
_ = model.evaluate(steps=10, verbose=0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_io_workflow_with_dataset_iterators(self):
num_classes = 2
num_samples = 10
@@ -325,7 +325,7 @@ class ModelSubclassingTest(test.TestCase):
self.assertEqual(len(model.inputs), 2)
self.assertEqual(len(model.outputs), 2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_updates(self):
# test that updates get run during training
num_samples = 100
@@ -352,7 +352,74 @@ class ModelSubclassingTest(test.TestCase):
y_new = model.predict(x)
self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1)
- @test_util.run_in_graph_and_eager_modes()
+ def test_updates_and_losses_for_nested_models_in_subclassed_model(self):
+
+ # Case 1: deferred-build sequential nested in subclass.
+ class TestModel1(keras.Model):
+
+ def __init__(self):
+ super(TestModel1, self).__init__()
+ self.fc = keras.layers.Dense(10, input_shape=(784,),
+ activity_regularizer='l1')
+ self.bn = keras.Sequential([keras.layers.BatchNormalization(axis=1)])
+
+ def call(self, x):
+ return self.bn(self.fc(x))
+
+ with self.test_session():
+ model = TestModel1()
+
+ x = array_ops.ones(shape=[100, 784], dtype='float32')
+ model(x)
+ self.assertEqual(len(model.get_updates_for(x)), 2)
+ self.assertEqual(len(model.get_losses_for(x)), 1)
+
+ # Case 2: placeholder-sequential nested in subclass.
+ class TestModel2(keras.Model):
+
+ def __init__(self):
+ super(TestModel2, self).__init__()
+ self.fc = keras.layers.Dense(10, input_shape=(784,),
+ activity_regularizer='l1')
+ self.bn = keras.Sequential(
+ [keras.layers.BatchNormalization(axis=1, input_shape=(10,))])
+
+ def call(self, x):
+ return self.bn(self.fc(x))
+
+ with self.test_session():
+ model = TestModel2()
+
+ x = array_ops.ones(shape=[100, 784], dtype='float32')
+ model(x)
+ self.assertEqual(len(model.get_updates_for(x)), 2)
+ self.assertEqual(len(model.get_losses_for(x)), 1)
+
+ # Case 3: functional-API model nested in subclass.
+ inputs = keras.Input((10,))
+ outputs = keras.layers.BatchNormalization(axis=1)(inputs)
+ bn = keras.Model(inputs, outputs)
+
+ class TestModel3(keras.Model):
+
+ def __init__(self):
+ super(TestModel3, self).__init__()
+ self.fc = keras.layers.Dense(10, input_shape=(784,),
+ activity_regularizer='l1')
+ self.bn = bn
+
+ def call(self, x):
+ return self.bn(self.fc(x))
+
+ with self.test_session():
+ model = TestModel3()
+
+ x = array_ops.ones(shape=[100, 784], dtype='float32')
+ model(x)
+ self.assertEqual(len(model.get_updates_for(x)), 2)
+ self.assertEqual(len(model.get_losses_for(x)), 1)
+
+ @test_util.run_in_graph_and_eager_modes
def test_training_and_inference_behavior(self):
# test that dropout is applied in training and not inference
@@ -380,7 +447,7 @@ class ModelSubclassingTest(test.TestCase):
loss = model.train_on_batch(x, y)
self.assertGreater(loss, 0.1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_training_methods(self):
# test fit, train_on_batch
# on different input types: list, dict
@@ -433,14 +500,14 @@ class ModelSubclassingTest(test.TestCase):
model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
model.predict_on_batch([x1, x2])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_trainable_mutation(self):
# test that you can change `trainable` on a model or layer, and that
# it freezes the model state during training
# TODO(fchollet): add test after we unify BN behavior in eager and symbolic.
pass
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_saving(self):
num_classes = (2, 3)
@@ -482,7 +549,7 @@ class ModelSubclassingTest(test.TestCase):
self.assertAllClose(y_ref_1, y1, atol=1e-5)
self.assertAllClose(y_ref_2, y2, atol=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_summary(self):
class ToString(object):
@@ -508,7 +575,7 @@ class ModelSubclassingTest(test.TestCase):
model.summary(print_fn=print_fn)
self.assertTrue('Trainable params: 587' in print_fn.contents)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_subclass_nested_in_subclass(self):
num_classes = 2
num_samples = 100
@@ -531,7 +598,7 @@ class ModelSubclassingTest(test.TestCase):
self.assertEqual(len(model.trainable_weights),
6 + len(model.test_net.trainable_weights))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_graph_nested_in_subclass(self):
num_classes = 2
num_samples = 100
@@ -554,7 +621,7 @@ class ModelSubclassingTest(test.TestCase):
self.assertEqual(len(model.trainable_weights),
6 + len(model.test_net.trainable_weights))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_subclass_nested_in_graph(self):
num_classes = 2
num_samples = 100
@@ -576,7 +643,7 @@ class ModelSubclassingTest(test.TestCase):
len(model.non_trainable_weights), 4)
self.assertEqual(len(model.trainable_weights), 12)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_support_for_manual_training_arg(self):
# In most cases, the `training` argument is left unspecified, in which
# case it defaults to value corresponding to the Model method being used
@@ -612,8 +679,8 @@ class ModelSubclassingTest(test.TestCase):
def __init__(self):
super(Foo, self).__init__()
self.isdep = keras.layers.Dense(1)
- self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
- self.notdep_var = checkpointable.NoDependency(
+ self.notdep = data_structures.NoDependency(keras.layers.Dense(2))
+ self.notdep_var = data_structures.NoDependency(
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
m = Foo()
@@ -685,7 +752,7 @@ class CustomCallModel(keras.Model):
class CustomCallSignatureTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_no_inputs_in_signature(self):
model = CustomCallModel()
first = array_ops.ones([2, 3])
@@ -699,7 +766,7 @@ class CustomCallSignatureTests(test.TestCase):
output = model(first, second=second, training=False)
self.assertAllClose(expected_output, self.evaluate(output))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_inputs_in_signature(self):
class HasInputsAndOtherPositional(keras.Model):
@@ -716,7 +783,7 @@ class CustomCallSignatureTests(test.TestCase):
x1, x2 = keras.Input((1, 1)), keras.Input((1, 1))
model(x1, x2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_kwargs_in_signature(self):
class HasKwargs(keras.Model):
@@ -730,7 +797,7 @@ class CustomCallSignatureTests(test.TestCase):
if not context.executing_eagerly():
six.assertCountEqual(self, [arg], model.inputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_args_in_signature(self):
class HasArgs(keras.Model):
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index c616d8f24f..ad3819e6e7 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -129,7 +129,7 @@ class TestModelCloning(test.TestCase):
class CheckpointingTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_optimizer_dependency(self):
model = keras.models.Sequential()
model.add(keras.layers.Dense(1, input_shape=(4,)))
@@ -144,5 +144,19 @@ class CheckpointingTests(test.TestCase):
model.load_weights(save_prefix)
self.assertEqual(12., self.evaluate(beta1_power))
+class TestModelBackend(test.TestCase):
+
+ def test_model_backend_float64_use_cases(self):
+ # Test case for GitHub issue 19318
+ floatx = keras.backend.floatx()
+ keras.backend.set_floatx('float64')
+
+ x = keras.Input((5,))
+ y = keras.layers.Dense(1)(x)
+ model = keras.models.Model(x, y)
+ model.compile('rmsprop', 'mse')
+
+ keras.backend.set_floatx(floatx)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index f58aeaea1a..0b440185ca 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -19,17 +19,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-
import six
from six.moves import zip # pylint: disable=redefined-builtin
-from tensorflow.python.framework import dtypes as dtypes_module
-from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import distribute as distribute_lib
@@ -39,37 +35,6 @@ from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
-def clip_norm(g, c, n):
- """Clip a tensor by norm.
-
- Arguments:
- g: gradient tensor to clip.
- c: clipping threshold.
- n: norm of gradient tensor.
-
- Returns:
- Clipped gradient tensor.
- """
- if c > 0:
- condition = n >= c
- then_expression = lambda: math_ops.scalar_mul(c / n, g)
- else_expression = lambda: g
-
- # saving the shape to avoid converting sparse tensor to dense
- if isinstance(g, ops.Tensor):
- g_shape = copy.copy(g.get_shape())
- elif isinstance(g, ops.IndexedSlices):
- g_shape = copy.copy(g.dense_shape)
- if condition.dtype != dtypes_module.bool:
- condition = math_ops.cast(condition, 'bool')
- g = control_flow_ops.cond(condition, then_expression, else_expression)
- if isinstance(g, ops.Tensor):
- g.set_shape(g_shape)
- elif isinstance(g, ops.IndexedSlices):
- g._dense_shape = g_shape # pylint: disable=protected-access
- return g
-
-
@tf_export('keras.optimizers.Optimizer')
class Optimizer(object):
"""Abstract optimizer base class.
@@ -91,6 +56,9 @@ class Optimizer(object):
if k not in allowed_kwargs:
raise TypeError('Unexpected keyword argument '
'passed to optimizer: ' + str(k))
+ # checks that clipnorm >= 0 and clipvalue >= 0
+ if kwargs[k] < 0:
+ raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k]))
self.__dict__.update(kwargs)
self.updates = []
self.weights = []
@@ -119,12 +87,13 @@ class Optimizer(object):
'gradient defined (i.e. are differentiable). '
'Common ops without gradient: '
'K.argmax, K.round, K.eval.')
- if hasattr(self, 'clipnorm') and self.clipnorm > 0:
- norm = K.sqrt(
- sum([math_ops.reduce_sum(math_ops.square(g)) for g in grads]))
- grads = [clip_norm(g, self.clipnorm, norm) for g in grads]
- if hasattr(self, 'clipvalue') and self.clipvalue > 0:
- grads = [K.clip(g, -self.clipvalue, self.clipvalue) for g in grads]
+ if hasattr(self, 'clipnorm'):
+ grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads]
+ if hasattr(self, 'clipvalue'):
+ grads = [
+ clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue)
+ for g in grads
+ ]
return grads
def set_weights(self, weights):
@@ -719,12 +688,13 @@ class Nadam(Optimizer):
return dict(list(base_config.items()) + list(config.items()))
-class TFOptimizer(Optimizer, checkpointable.Checkpointable):
+class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
"""Wrapper class for native TensorFlow optimizers.
"""
def __init__(self, optimizer): # pylint: disable=super-init-not-called
self.optimizer = optimizer
+ self._track_checkpointable(optimizer, name='optimizer')
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 92b0cf3261..55fc3fdcf4 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -145,6 +145,12 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_negative_clipvalue_or_clipnorm(self):
+ with self.assertRaises(ValueError):
+ _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)
+ with self.assertRaises(ValueError):
+ _ = keras.optimizers.Adam(clipnorm=-2.0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index e7cb45d5e1..17aba7d86c 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from collections import OrderedDict
import numpy as np
from tensorflow.python import keras
@@ -183,3 +184,76 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
# for further checks in the caller function
return actual_output
+
+
+def _combine_named_parameters(**kwargs):
+ """Generate combinations based on its keyword arguments.
+
+ Two sets of returned combinations can be concatenated using +. Their product
+ can be computed using `times()`.
+
+ Args:
+ **kwargs: keyword arguments of form `option=[possibilities, ...]`
+ or `option=the_only_possibility`.
+
+ Returns:
+ a list of dictionaries for each combination. Keys in the dictionaries are
+ the keyword argument names. Each key has one value - one of the
+ corresponding keyword argument values.
+ """
+ if not kwargs:
+ return [OrderedDict()]
+
+ sort_by_key = lambda k: k[0][0]
+ kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
+ first = list(kwargs.items())[0]
+
+ rest = dict(list(kwargs.items())[1:])
+ rest_combined = _combine_named_parameters(**rest)
+
+ key = first[0]
+ values = first[1]
+ if not isinstance(values, list):
+ values = [values]
+
+ combinations = [
+ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
+ for v in values
+ for combined in rest_combined
+ ]
+ return combinations
+
+
+def generate_combinations_with_testcase_name(**kwargs):
+ """Generate combinations based on its keyword arguments using combine().
+
+ This function calls combine() and appends a testcase name to the list of
+ dictionaries returned. The 'testcase_name' key is a required for named
+ parameterized tests.
+
+ Args:
+ **kwargs: keyword arguments of form `option=[possibilities, ...]`
+ or `option=the_only_possibility`.
+
+ Returns:
+ a list of dictionaries for each combination. Keys in the dictionaries are
+ the keyword argument names. Each key has one value - one of the
+ corresponding keyword argument values.
+ """
+ combinations = _combine_named_parameters(**kwargs)
+ named_combinations = []
+ for combination in combinations:
+ assert isinstance(combination, OrderedDict)
+ name = ''.join([
+ '_{}_{}'.format(
+ ''.join(filter(str.isalnum, key)),
+ ''.join(filter(str.isalnum, str(value))))
+ for key, value in combination.items()
+ ])
+ named_combinations.append(
+ OrderedDict(
+ list(combination.items()) + [('testcase_name',
+ '_test{}'.format(name))]))
+
+ return named_combinations
+
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index a1f89d9d43..c1ee34ae46 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -324,12 +324,12 @@ def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
class Sequence(object):
"""Base object for fitting to a sequence of data, such as a dataset.
- Every `Sequence` must implements the `__getitem__` and the `__len__` methods.
+ Every `Sequence` must implement the `__getitem__` and the `__len__` methods.
If you want to modify your dataset between epochs you may implement
`on_epoch_end`.
The method `__getitem__` should return a complete batch.
- # Notes
+ Notes:
`Sequence` are a safer way to do multiprocessing. This structure guarantees
that the network will only train once
diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py
index f82e3277de..62674a9c77 100644
--- a/tensorflow/python/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/utils/io_utils.py
@@ -102,13 +102,12 @@ class HDF5Matrix(object):
idx = (self.start + key).tolist()
else:
raise IndexError
- elif isinstance(key, list):
+ else:
+ # Assume list/iterable
if max(key) + self.start < self.end:
idx = [x + self.start for x in key]
else:
raise IndexError
- else:
- raise IndexError
if self.normalizer is not None:
return self.normalizer(self.data[idx])
else:
diff --git a/tensorflow/python/keras/utils/io_utils_test.py b/tensorflow/python/keras/utils/io_utils_test.py
index 3895dca68e..81bb661edd 100644
--- a/tensorflow/python/keras/utils/io_utils_test.py
+++ b/tensorflow/python/keras/utils/io_utils_test.py
@@ -22,6 +22,7 @@ import os
import shutil
import numpy as np
+import six
from tensorflow.python import keras
from tensorflow.python.platform import test
@@ -95,6 +96,29 @@ class TestIOUtils(test.TestCase):
self.assertEqual(out_eval.shape, ())
self.assertGreater(out_eval, 0)
+ # test slicing for shortened array
+ self.assertEqual(len(x_train[0:]), len(x_train))
+
+ # test __getitem__ invalid use cases
+ with self.assertRaises(IndexError):
+ _ = x_train[1000]
+ with self.assertRaises(IndexError):
+ _ = x_train[1000: 1001]
+ with self.assertRaises(IndexError):
+ _ = x_train[[1000, 1001]]
+ with self.assertRaises(IndexError):
+ _ = x_train[six.moves.range(1000, 1001)]
+ with self.assertRaises(IndexError):
+ _ = x_train[np.array([1000])]
+ with self.assertRaises(TypeError):
+ _ = x_train[None]
+
+ # test normalizer
+ normalizer = lambda x: x + 1
+ normalized_x_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=0, end=150, normalizer=normalizer)
+ self.assertAllClose(normalized_x_train[0][0], x_train[0][0] + 1)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 88daff0461..1f28c59ea4 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -26,6 +26,47 @@ from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util.tf_export import tf_export
+def get_source_inputs(tensor, layer=None, node_index=None):
+ """Returns the list of input tensors necessary to compute `tensor`.
+
+ Output will always be a list of tensors
+ (potentially with 1 element).
+
+ Arguments:
+ tensor: The tensor to start from.
+ layer: Origin layer of the tensor. Will be
+ determined via tensor._keras_history if not provided.
+ node_index: Origin node index of the tensor.
+
+ Returns:
+ List of input tensors.
+ """
+ if not hasattr(tensor, '_keras_history'):
+ return tensor
+
+ if layer is None or node_index:
+ layer, node_index, _ = tensor._keras_history
+ if not layer._inbound_nodes:
+ return [tensor]
+ else:
+ node = layer._inbound_nodes[node_index]
+ if not node.inbound_layers:
+ # Reached an Input layer, stop recursion.
+ return node.input_tensors
+ else:
+ source_tensors = []
+ for i in range(len(node.inbound_layers)):
+ x = node.input_tensors[i]
+ layer = node.inbound_layers[i]
+ node_index = node.node_indices[i]
+ previous_sources = get_source_inputs(x, layer, node_index)
+ # Avoid input redundancy.
+ for x in previous_sources:
+ if x not in source_tensors:
+ source_tensors.append(x)
+ return source_tensors
+
+
def count_params(weights):
"""Count the total number of scalars composing the weights.
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e5442f04e3..e1c49bc852 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -196,7 +196,7 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
batch_size = shape[:1]
input_shape = shape[1:]
step = batch_size // parts
- if i == num_gpus - 1:
+ if i == parts - 1:
size = batch_size - step * i
else:
size = step
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5d29c2e5f8..838cf836f1 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -893,6 +893,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
+ "//tensorflow/python:sparse_grad",
"//tensorflow/python:sparse_ops",
],
)
@@ -1524,6 +1525,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
+ tags = ["no_windows_gpu"],
)
cuda_py_test(
@@ -2056,6 +2058,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
+ tags = ["no_windows_gpu"],
)
tf_py_test(
@@ -2754,6 +2757,7 @@ cuda_py_test(
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:partitioned_variables",
@@ -2841,6 +2845,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
],
shard_count = 20,
+ tags = ["nomsan"], # TODO(b/110990716) reenable
)
cuda_py_test(
@@ -3087,3 +3092,22 @@ tf_py_test(
data = [":invalid_op.so"],
tags = ["no_pip"],
)
+
+tf_py_test(
+ name = "cond_v2_test",
+ size = "small",
+ srcs = ["cond_v2_test.py"],
+ additional_deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:cond_v2",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:training",
+ ],
+ grpc_enabled = True,
+)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 08bf2d9c64..40567571e6 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1006,7 +1006,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
class ShapeSizeRankTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDenseShape(self):
t_value = [[0, 42], [24, 0]]
self.assertAllEqual((2, 2), self.evaluate(array_ops.shape(t_value)))
@@ -1018,7 +1018,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
self.assertEqual(4, self.evaluate(array_ops.size(t)))
self.assertEqual(2, self.evaluate(array_ops.rank(t)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSparseShape(self):
sp_value = sparse_tensor.SparseTensorValue(
indices=((0, 1), (1, 0)), values=(42, 24), dense_shape=(2, 2))
@@ -1031,7 +1031,7 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
self.assertEqual(4, self.evaluate(array_ops.size(sp)))
self.assertEqual(2, self.evaluate(array_ops.rank(sp)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSizeDtype(self):
tensor = [1]
self.assertEqual(dtypes.int32, self.evaluate(array_ops.size(tensor)).dtype)
@@ -1123,7 +1123,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
class ConcatSliceResourceTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConcatSlice(self):
r1 = test_ops.stub_resource_handle_op(container="a", shared_name="b")
r2 = test_ops.stub_resource_handle_op(container="a", shared_name="c")
diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py
index 9d54add264..94ed8ebd31 100644
--- a/tensorflow/python/kernel_tests/as_string_op_test.py
+++ b/tensorflow/python/kernel_tests/as_string_op_test.py
@@ -130,6 +130,16 @@ class AsStringOpTest(test.TestCase):
result = output.eval(feed_dict={input_: int_inputs_})
self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_])
+ def testHalfInt(self):
+ s = lambda strs: [x.decode("ascii") for x in strs]
+
+ with self.test_session():
+ input_ = array_ops.placeholder(dtypes.int16)
+ int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max]
+ output = string_ops.as_string(input_)
+ result = output.eval(feed_dict={input_: int_inputs_})
+ self.assertAllEqual(s(result), ["%d" % x for x in int_inputs_])
+
def testBool(self):
bool_inputs_ = [False, True]
s = lambda strs: [x.decode("ascii") for x in strs]
diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py
index 0ef08581c9..b98e5fd386 100644
--- a/tensorflow/python/kernel_tests/atrous_convolution_test.py
+++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py
@@ -124,7 +124,7 @@ class AtrousConvolutionTest(test.TestCase):
x, w, "VALID", dilation_rate=[2, 2], data_format="NCHW")
self.assertEqual(y.shape.as_list(), [1, 20, None, None])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAtrousConvolution2D(self):
with self._delay_checks() as add_check:
for padding in ["SAME", "VALID"]:
@@ -139,7 +139,7 @@ class AtrousConvolutionTest(test.TestCase):
dilation_rate=dilation_rate,
)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAtrousConvolution3D(self):
with self._delay_checks() as add_check:
for padding in ["SAME", "VALID"]:
@@ -158,7 +158,7 @@ class AtrousConvolutionTest(test.TestCase):
dilation_rate=dilation_rate,
)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAtrousConvolution1D(self):
with self._delay_checks() as add_check:
for padding in ["SAME", "VALID"]:
@@ -173,7 +173,7 @@ class AtrousConvolutionTest(test.TestCase):
dilation_rate=[rate],
)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAtrousConvolutionNC(self):
if test.is_gpu_available(cuda_only=True):
# "NCW" and "NCHW" formats are currently supported only on CUDA.
@@ -197,7 +197,7 @@ class AtrousConvolutionTest(test.TestCase):
data_format="NCHW",
)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAtrousSequence(self):
"""Tests optimization of sequence of atrous convolutions.
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index 08b03f8518..16fdedac41 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -172,7 +172,7 @@ class BetaincTest(test.TestCase):
tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
err = gradient_checker.compute_gradient_error(
[tf_gx_s], [gx_s.shape], tf_gout_t, gx_s.shape)
- print("betainc gradient err = %g " % err)
+ tf_logging.info("betainc gradient err = %g " % err)
self.assertLess(err, err_tolerance)
# Test broadcast gradient
@@ -181,7 +181,7 @@ class BetaincTest(test.TestCase):
tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
err = gradient_checker.compute_gradient_error(
[tf_gx_s], [()], tf_gout_t, ga_s.shape)
- print("betainc gradient err = %g " % err)
+ tf_logging.info("betainc gradient err = %g " % err)
self.assertLess(err, err_tolerance)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 92cd53a031..4e31b1ea2a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -910,7 +910,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
feature_1_values = [11, 27]
# Example 1: tree 0: 1.14, tree 1: 5.0, tree 2: 5.0 = >
- # logit = 0.1*5.0+0.2*5.0+1*5
+ # logit = 0.1*1.14+0.2*5.0+1*5
# Example 2: tree 0: 1.14, tree 1: 7.0, tree 2: -7 = >
# logit= 0.1*1.14+0.2*7.0-1*7.0
expected_logits = [[6.114], [-5.486]]
@@ -925,5 +925,147 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(expected_logits, logits)
+class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
+ """Tests feature contribs ops for model understanding."""
+
+ def testContribsMultipleTree(self):
+ """Tests that the contribs work when we have multiple trees."""
+ with self.test_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 28
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 7.62
+ original_leaf: {scalar: 2.1}
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.14
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 8.79
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf: {scalar: 5.5}
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 34
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_weights: 0.2
+ tree_weights: 1.0
+ tree_metadata: {
+ num_layers_grown: 1}
+ tree_metadata: {
+ num_layers_grown: 2}
+ tree_metadata: {
+ num_layers_grown: 1}
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29] # Unused. Feature is not in above ensemble.
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ bias = 2.1 * 0.1 # Root node of tree_0.
+ expected_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+ # example_0 : (bias, 0.1 * 1.14, 0.2 * 5.5 + .114, 0.2 * 5. + .114,
+ # 1.0 * 5.0 + 0.2 * 5. + .114)
+ # example_1 : (bias, 0.1 * 1.14, 0.2 * 7 + .114,
+ # 1.0 * -7. + 0.2 * 7 + .114)
+ expected_logits_paths = ((bias, 0.114, 1.214, 1.114, 6.114),
+ (bias, 0.114, 1.514, -5.486))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index 13b804875e..d55240297a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -139,6 +139,49 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(new_stamp, 1)
self.assertProtoEquals(expected_result, tree_ensemble)
+ def testBiasCenteringOnEmptyEnsemble(self):
+ """Test growing with bias centering on an empty ensemble."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ gradients = np.array([[5.]], dtype=np.float32)
+ hessians = np.array([[24.]], dtype=np.float32)
+
+ # Grow tree ensemble.
+ grow_op = boosted_trees_ops.center_bias(
+ tree_ensemble_handle,
+ mean_gradients=gradients,
+ mean_hessians=hessians,
+ l1=0.0,
+ l2=1.0
+ )
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(tree_ensemble.serialize())
+
+ tree_ensemble = boosted_trees_pb2.TreeEnsemble()
+ tree_ensemble.ParseFromString(serialized)
+
+ expected_result = """
+ trees {
+ nodes {
+ leaf {
+ scalar: -0.2
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ num_layers_grown: 0
+ is_finalized: false
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble)
+
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
with self.test_session() as session:
@@ -666,7 +709,6 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
num_layers_attempted: 1
last_layer_node_start: 1
last_layer_node_end: 3
-
}
""", tree_ensemble_config)
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 7ef841c96b..bda6ca5ca9 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -34,45 +34,45 @@ from tensorflow.python.platform import test
class AssertProperIterableTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_tensor_raises(self):
tensor = constant_op.constant(1)
with self.assertRaisesRegexp(TypeError, "proper"):
check_ops.assert_proper_iterable(tensor)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_sparse_tensor_raises(self):
ten = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
with self.assertRaisesRegexp(TypeError, "proper"):
check_ops.assert_proper_iterable(ten)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_ndarray_raises(self):
array = np.array([1, 2, 3])
with self.assertRaisesRegexp(TypeError, "proper"):
check_ops.assert_proper_iterable(array)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_single_string_raises(self):
mystr = "hello"
with self.assertRaisesRegexp(TypeError, "proper"):
check_ops.assert_proper_iterable(mystr)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_non_iterable_object_raises(self):
non_iterable = 1234
with self.assertRaisesRegexp(TypeError, "to be iterable"):
check_ops.assert_proper_iterable(non_iterable)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_list_does_not_raise(self):
list_of_stuff = [
constant_op.constant([11, 22]), constant_op.constant([1, 2])
]
check_ops.assert_proper_iterable(list_of_stuff)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_generator_does_not_raise(self):
generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant(
[1, 2]))
@@ -81,14 +81,14 @@ class AssertProperIterableTest(test.TestCase):
class AssertEqualTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with ops.control_dependencies([check_ops.assert_equal(small, small)]):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_scalar_comparison(self):
const_true = constant_op.constant(True, name="true")
const_false = constant_op.constant(False, name="false")
@@ -101,7 +101,7 @@ class AssertEqualTest(test.TestCase):
x = check_ops.assert_equal(small, small)
assert x is None
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_greater(self):
# Static check
static_small = constant_op.constant([1, 2], name="small")
@@ -179,7 +179,7 @@ First 2 elements of y:
check_ops.assert_equal(big, small, message="big does not equal small",
summarize=2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less(self):
# Static check
static_small = constant_op.constant([3, 1], name="small")
@@ -196,7 +196,7 @@ First 2 elements of y:
with self.assertRaisesOpError("small.*big"):
out.eval(feed_dict={small: [3, 1], big: [4, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
small = constant_op.constant([[1, 2], [1, 2]], name="small")
small_2 = constant_op.constant([1, 2], name="small_2")
@@ -204,7 +204,7 @@ First 2 elements of y:
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_equal_but_non_broadcastable_shapes(self):
small = constant_op.constant([1, 1, 1], name="small")
small_2 = constant_op.constant([1, 1], name="small_2")
@@ -219,13 +219,13 @@ First 2 elements of y:
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_not_equal_and_broadcastable_shapes(self):
cond = constant_op.constant([True, False], name="small")
with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
check_ops.assert_equal(cond, False, message="fail")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -236,7 +236,7 @@ First 2 elements of y:
class AssertNoneEqualTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_not_equal(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([10, 20], name="small")
@@ -245,7 +245,7 @@ class AssertNoneEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_equal(self):
small = constant_op.constant([3, 1], name="small")
with self.assertRaisesOpError("x != y did not hold"):
@@ -254,7 +254,7 @@ class AssertNoneEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3], name="big")
@@ -263,7 +263,7 @@ class AssertNoneEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
with self.test_session():
small = constant_op.constant([1, 1, 1], name="small")
@@ -280,7 +280,7 @@ class AssertNoneEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
with self.test_session():
larry = constant_op.constant([])
@@ -300,7 +300,7 @@ class AssertNoneEqualTest(test.TestCase):
class AssertAllCloseTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_equal(self):
x = constant_op.constant(1., name="x")
y = constant_op.constant(1., name="y")
@@ -309,7 +309,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_32_bit_due_to_default_rtol(self):
eps = np.finfo(np.float32).eps
# Default rtol/atol is 10*eps
@@ -320,7 +320,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_32_bit_due_to_default_atol(self):
eps = np.finfo(np.float32).eps
# Default rtol/atol is 10*eps
@@ -331,7 +331,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_64_bit_due_to_default_rtol(self):
eps = np.finfo(np.float64).eps
# Default rtol/atol is 10*eps
@@ -342,7 +342,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_64_bit_due_to_default_atol(self):
eps = np.finfo(np.float64).eps
# Default rtol/atol is 10*eps
@@ -353,7 +353,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_due_to_custom_rtol(self):
x = constant_op.constant(1., name="x")
y = constant_op.constant(1.1, name="y")
@@ -363,7 +363,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_close_enough_due_to_custom_atol(self):
x = constant_op.constant(0., name="x")
y = constant_op.constant(0.1, name="y", dtype=np.float32)
@@ -373,7 +373,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -381,7 +381,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(larry)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_atol_violated(self):
x = constant_op.constant(10., name="x")
y = constant_op.constant(10.2, name="y")
@@ -392,7 +392,7 @@ class AssertAllCloseTest(test.TestCase):
out = array_ops.identity(x)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_default_rtol_violated(self):
x = constant_op.constant(0.1, name="x")
y = constant_op.constant(0.0, name="y")
@@ -412,7 +412,7 @@ class AssertAllCloseTest(test.TestCase):
class AssertLessTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"):
@@ -422,7 +422,7 @@ class AssertLessTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_greater(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
@@ -431,7 +431,7 @@ class AssertLessTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_less(self):
small = constant_op.constant([3, 1], name="small")
big = constant_op.constant([4, 2], name="big")
@@ -439,7 +439,7 @@ class AssertLessTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
small = constant_op.constant([1], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -447,7 +447,7 @@ class AssertLessTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less_but_non_broadcastable_shapes(self):
small = constant_op.constant([1, 1, 1], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -462,7 +462,7 @@ class AssertLessTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -480,7 +480,7 @@ class AssertLessTest(test.TestCase):
class AssertLessEqualTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with ops.control_dependencies(
@@ -488,7 +488,7 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_greater(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
@@ -499,7 +499,7 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_less_equal(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -507,7 +507,7 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self):
small = constant_op.constant([1], name="small")
big = constant_op.constant([3, 1], name="big")
@@ -515,7 +515,7 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
small = constant_op.constant([3, 1], name="small")
big = constant_op.constant([1, 1, 1], name="big")
@@ -531,7 +531,7 @@ class AssertLessEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -543,7 +543,7 @@ class AssertLessEqualTest(test.TestCase):
class AssertGreaterTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with self.assertRaisesOpError("fail"):
@@ -553,7 +553,7 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
@@ -562,7 +562,7 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(big)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_greater(self):
small = constant_op.constant([3, 1], name="small")
big = constant_op.constant([4, 2], name="big")
@@ -570,7 +570,7 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
small = constant_op.constant([1], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -578,7 +578,7 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_greater_but_non_broadcastable_shapes(self):
small = constant_op.constant([1, 1, 1], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -593,7 +593,7 @@ class AssertGreaterTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -604,7 +604,7 @@ class AssertGreaterTest(test.TestCase):
class AssertGreaterEqualTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_equal(self):
small = constant_op.constant([1, 2], name="small")
with ops.control_dependencies(
@@ -612,7 +612,7 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 4], name="big")
@@ -623,7 +623,7 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_greater_equal(self):
small = constant_op.constant([1, 2], name="small")
big = constant_op.constant([3, 2], name="big")
@@ -632,7 +632,7 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self):
small = constant_op.constant([1], name="small")
big = constant_op.constant([3, 1], name="big")
@@ -641,7 +641,7 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
small = constant_op.constant([1, 1, 1], name="big")
big = constant_op.constant([3, 1], name="small")
@@ -657,7 +657,7 @@ class AssertGreaterEqualTest(test.TestCase):
out = array_ops.identity(small)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_both_empty(self):
larry = constant_op.constant([])
curly = constant_op.constant([])
@@ -669,14 +669,14 @@ class AssertGreaterEqualTest(test.TestCase):
class AssertNegativeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_negative(self):
frank = constant_op.constant([-1, -2], name="frank")
with ops.control_dependencies([check_ops.assert_negative(frank)]):
out = array_ops.identity(frank)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_positive(self):
doug = constant_op.constant([1, 2], name="doug")
with self.assertRaisesOpError("fail"):
@@ -686,7 +686,7 @@ class AssertNegativeTest(test.TestCase):
out = array_ops.identity(doug)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_zero(self):
claire = constant_op.constant([0], name="claire")
with self.assertRaisesOpError("x < 0 did not hold"):
@@ -694,7 +694,7 @@ class AssertNegativeTest(test.TestCase):
out = array_ops.identity(claire)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_doesnt_raise(self):
# A tensor is negative when it satisfies:
# For every element x_i in x, x_i < 0
@@ -708,7 +708,7 @@ class AssertNegativeTest(test.TestCase):
class AssertPositiveTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_negative(self):
freddie = constant_op.constant([-1, -2], name="freddie")
with self.assertRaisesOpError("fail"):
@@ -718,14 +718,14 @@ class AssertPositiveTest(test.TestCase):
out = array_ops.identity(freddie)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_positive(self):
remmy = constant_op.constant([1, 2], name="remmy")
with ops.control_dependencies([check_ops.assert_positive(remmy)]):
out = array_ops.identity(remmy)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_zero(self):
meechum = constant_op.constant([0], name="meechum")
with self.assertRaisesOpError("x > 0 did not hold"):
@@ -733,7 +733,7 @@ class AssertPositiveTest(test.TestCase):
out = array_ops.identity(meechum)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_doesnt_raise(self):
# A tensor is positive when it satisfies:
# For every element x_i in x, x_i > 0
@@ -747,7 +747,7 @@ class AssertPositiveTest(test.TestCase):
class AssertRankTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
tensor = constant_op.constant(1, name="my_tensor")
desired_rank = 1
@@ -768,7 +768,7 @@ class AssertRankTest(test.TestCase):
with self.assertRaisesOpError("fail.*my_tensor.*rank"):
array_ops.identity(tensor).eval(feed_dict={tensor: 0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
tensor = constant_op.constant(1, name="my_tensor")
desired_rank = 0
@@ -784,7 +784,7 @@ class AssertRankTest(test.TestCase):
[check_ops.assert_rank(tensor, desired_rank)]):
array_ops.identity(tensor).eval(feed_dict={tensor: 0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 0
@@ -802,7 +802,7 @@ class AssertRankTest(test.TestCase):
with self.assertRaisesOpError("my_tensor.*rank"):
array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 1
@@ -818,7 +818,7 @@ class AssertRankTest(test.TestCase):
[check_ops.assert_rank(tensor, desired_rank)]):
array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 2
@@ -836,7 +836,7 @@ class AssertRankTest(test.TestCase):
with self.assertRaisesOpError("my_tensor.*rank"):
array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_if_rank_is_not_scalar_static(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
@@ -852,7 +852,7 @@ class AssertRankTest(test.TestCase):
[check_ops.assert_rank(tensor, rank_tensor)]):
array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_if_rank_is_not_integer_static(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
with self.assertRaisesRegexp(TypeError,
@@ -873,7 +873,7 @@ class AssertRankTest(test.TestCase):
class AssertRankInTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
tensor_rank0 = constant_op.constant(42, name="my_tensor")
with self.assertRaisesRegexp(
@@ -890,7 +890,7 @@ class AssertRankInTest(test.TestCase):
with self.assertRaisesOpError("fail.*my_tensor.*rank"):
array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
tensor_rank0 = constant_op.constant(42, name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
@@ -906,7 +906,7 @@ class AssertRankInTest(test.TestCase):
check_ops.assert_rank_in(tensor_rank0, desired_ranks)]):
array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
tensor_rank1 = constant_op.constant([42, 43], name="my_tensor")
for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
@@ -924,7 +924,7 @@ class AssertRankInTest(test.TestCase):
tensor_rank1: (42.0, 43.0)
})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
tensor_rank1 = constant_op.constant((42, 43), name="my_tensor")
with self.assertRaisesRegexp(ValueError, "rank"):
@@ -942,7 +942,7 @@ class AssertRankInTest(test.TestCase):
tensor_rank1: (42.0, 43.0)
})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_if_rank_is_not_scalar_static(self):
tensor = constant_op.constant((42, 43), name="my_tensor")
desired_ranks = (
@@ -966,7 +966,7 @@ class AssertRankInTest(test.TestCase):
desired_ranks[1]: [2, 1],
})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_if_rank_is_not_integer_static(self):
tensor = constant_op.constant((42, 43), name="my_tensor")
with self.assertRaisesRegexp(TypeError,
@@ -987,7 +987,7 @@ class AssertRankInTest(test.TestCase):
class AssertRankAtLeastTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
tensor = constant_op.constant(1, name="my_tensor")
desired_rank = 1
@@ -1005,7 +1005,7 @@ class AssertRankAtLeastTest(test.TestCase):
with self.assertRaisesOpError("my_tensor.*rank"):
array_ops.identity(tensor).eval(feed_dict={tensor: 0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
tensor = constant_op.constant(1, name="my_tensor")
desired_rank = 0
@@ -1021,7 +1021,7 @@ class AssertRankAtLeastTest(test.TestCase):
[check_ops.assert_rank_at_least(tensor, desired_rank)]):
array_ops.identity(tensor).eval(feed_dict={tensor: 0})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 0
@@ -1037,7 +1037,7 @@ class AssertRankAtLeastTest(test.TestCase):
[check_ops.assert_rank_at_least(tensor, desired_rank)]):
array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 1
@@ -1053,7 +1053,7 @@ class AssertRankAtLeastTest(test.TestCase):
[check_ops.assert_rank_at_least(tensor, desired_rank)]):
array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
tensor = constant_op.constant([1, 2], name="my_tensor")
desired_rank = 2
@@ -1074,7 +1074,7 @@ class AssertRankAtLeastTest(test.TestCase):
class AssertNonNegativeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_negative(self):
zoe = constant_op.constant([-1, -2], name="zoe")
with self.assertRaisesOpError("x >= 0 did not hold"):
@@ -1082,14 +1082,14 @@ class AssertNonNegativeTest(test.TestCase):
out = array_ops.identity(zoe)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_zero_and_positive(self):
lucas = constant_op.constant([0, 2], name="lucas")
with ops.control_dependencies([check_ops.assert_non_negative(lucas)]):
out = array_ops.identity(lucas)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_doesnt_raise(self):
# A tensor is non-negative when it satisfies:
# For every element x_i in x, x_i >= 0
@@ -1103,14 +1103,14 @@ class AssertNonNegativeTest(test.TestCase):
class AssertNonPositiveTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_zero_and_negative(self):
tom = constant_op.constant([0, -2], name="tom")
with ops.control_dependencies([check_ops.assert_non_positive(tom)]):
out = array_ops.identity(tom)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_positive(self):
rachel = constant_op.constant([0, 2], name="rachel")
with self.assertRaisesOpError("x <= 0 did not hold"):
@@ -1118,7 +1118,7 @@ class AssertNonPositiveTest(test.TestCase):
out = array_ops.identity(rachel)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_doesnt_raise(self):
# A tensor is non-positive when it satisfies:
# For every element x_i in x, x_i <= 0
@@ -1132,14 +1132,14 @@ class AssertNonPositiveTest(test.TestCase):
class AssertIntegerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_integer(self):
integers = constant_op.constant([1, 2], name="integers")
with ops.control_dependencies([check_ops.assert_integer(integers)]):
out = array_ops.identity(integers)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_float(self):
floats = constant_op.constant([1.0, 2.0], name="floats")
with self.assertRaisesRegexp(TypeError, "Expected.*integer"):
@@ -1148,7 +1148,7 @@ class AssertIntegerTest(test.TestCase):
class AssertTypeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_when_correct_type(self):
integers = constant_op.constant([1, 2], dtype=dtypes.int64)
with ops.control_dependencies([
@@ -1156,7 +1156,7 @@ class AssertTypeTest(test.TestCase):
out = array_ops.identity(integers)
self.evaluate(out)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_raises_when_wrong_type(self):
floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16)
with self.assertRaisesRegexp(TypeError, "must be of type.*float32"):
@@ -1165,74 +1165,74 @@ class AssertTypeTest(test.TestCase):
class IsStrictlyIncreasingTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_constant_tensor_is_not_strictly_increasing(self):
self.assertFalse(self.evaluate(check_ops.is_strictly_increasing([1, 1, 1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_decreasing_tensor_is_not_strictly_increasing(self):
self.assertFalse(self.evaluate(
check_ops.is_strictly_increasing([1, 0, -1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_2d_decreasing_tensor_is_not_strictly_increasing(self):
self.assertFalse(
self.evaluate(check_ops.is_strictly_increasing([[1, 3], [2, 4]])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_increasing_tensor_is_increasing(self):
self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1, 2, 3])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_increasing_rank_two_tensor(self):
self.assertTrue(
self.evaluate(check_ops.is_strictly_increasing([[-1, 2], [3, 4]])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_tensor_with_one_element_is_strictly_increasing(self):
self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_is_strictly_increasing(self):
self.assertTrue(self.evaluate(check_ops.is_strictly_increasing([])))
class IsNonDecreasingTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_constant_tensor_is_non_decreasing(self):
self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 1, 1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_decreasing_tensor_is_not_non_decreasing(self):
self.assertFalse(self.evaluate(check_ops.is_non_decreasing([3, 2, 1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_2d_decreasing_tensor_is_not_non_decreasing(self):
self.assertFalse(self.evaluate(
check_ops.is_non_decreasing([[1, 3], [2, 4]])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_increasing_rank_one_tensor_is_non_decreasing(self):
self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1, 2, 3])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_increasing_rank_two_tensor(self):
self.assertTrue(self.evaluate(
check_ops.is_non_decreasing([[-1, 2], [3, 3]])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_tensor_with_one_element_is_non_decreasing(self):
self.assertTrue(self.evaluate(check_ops.is_non_decreasing([1])))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_empty_tensor_is_non_decreasing(self):
self.assertTrue(self.evaluate(check_ops.is_non_decreasing([])))
class FloatDTypeTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_assert_same_float_dtype(self):
self.assertIs(dtypes.float32,
check_ops.assert_same_float_dtype(None, None))
@@ -1286,7 +1286,7 @@ class FloatDTypeTest(test.TestCase):
class AssertScalarTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_assert_scalar(self):
check_ops.assert_scalar(constant_op.constant(3))
check_ops.assert_scalar(constant_op.constant("foo"))
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
new file mode 100644
index 0000000000..759db5d5f4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -0,0 +1,536 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for cond_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver
+from tensorflow.python.util import compat
+
+
+class NewCondTest(test.TestCase):
+
+ def _testCond(self, true_fn, false_fn, train_vals):
+ with self.test_session() as sess:
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+
+ expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
+ actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual")
+
+ expected_grad = gradients_impl.gradients(expected, train_vals)
+ actual_grad = gradients_impl.gradients(actual, train_vals)
+
+ expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
+ (expected, actual, expected_grad, actual_grad), {pred: True})
+ self.assertEqual(expected_val, actual_val)
+ self.assertEqual(expected_grad_val, actual_grad_val)
+
+ expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run(
+ (expected, actual, expected_grad, actual_grad), {pred: False})
+ self.assertEqual(expected_val, actual_val)
+ self.assertEqual(expected_grad_val, actual_grad_val)
+
+ def testBasic(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return x * 2.0
+
+ def false_fn():
+ return y * 3.0
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
+ def testBasic2(self):
+ x = constant_op.constant(1.0, name="x")
+ y = constant_op.constant(2.0, name="y")
+
+ def true_fn():
+ return x * y * 2.0
+
+ def false_fn():
+ return 2.0
+
+ self._testCond(true_fn, false_fn, [x])
+ self._testCond(true_fn, false_fn, [x, y])
+ self._testCond(true_fn, false_fn, [y])
+
+ def testNoInputs(self):
+ with self.test_session() as sess:
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+
+ def true_fn():
+ return constant_op.constant(1.0)
+
+ def false_fn():
+ return constant_op.constant(2.0)
+
+ out = cond_v2.cond_v2(pred, true_fn, false_fn)
+
+ self.assertEqual(sess.run(out, {pred: True}), [1.0])
+ self.assertEqual(sess.run(out, {pred: False}), [2.0])
+
+ def _createCond(self, name):
+ pred = constant_op.constant(True, name="pred")
+ x = constant_op.constant(1.0, name="x")
+
+ def true_fn():
+ return x
+
+ def false_fn():
+ return x + 1
+
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op
+
+ def testDefaultName(self):
+ with ops.Graph().as_default():
+ cond = self._createCond(None)
+ self.assertEqual(cond.name, "cond")
+ self.assertIn("cond_true", ops.get_default_graph()._functions)
+ self.assertIn("cond_false", ops.get_default_graph()._functions)
+
+ with ops.Graph().as_default():
+ with ops.name_scope("foo"):
+ cond = self._createCond("")
+ self.assertEqual(cond.name, "foo/cond")
+ self.assertIn("foo_cond_true", ops.get_default_graph()._functions)
+ self.assertIn("foo_cond_false", ops.get_default_graph()._functions)
+
+ cond2 = self._createCond(None)
+ self.assertEqual(cond2.name, "foo/cond_1")
+ self.assertIn("foo_cond_1_true", ops.get_default_graph()._functions)
+ self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
+
+ def testSecondDerivative(self):
+ with self.test_session() as sess:
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+ x = constant_op.constant(3.0, name="x")
+
+ def true_fn():
+ return math_ops.pow(x, 3)
+
+ def false_fn():
+ return x
+
+ cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
+ cond_grad = gradients_impl.gradients(cond, [x])
+ cond_grad_grad = gradients_impl.gradients(cond_grad, [x])
+
+ # d[x^3]/dx = 3x^2
+ true_val = sess.run(cond_grad, {pred: True})
+ self.assertEqual(true_val, [27.0])
+ # d[x]/dx = 1
+ false_val = sess.run(cond_grad, {pred: False})
+ self.assertEqual(false_val, [1.0])
+
+ true_val = sess.run(cond_grad_grad, {pred: True})
+ # d2[x^3]/dx2 = 6x
+ self.assertEqual(true_val, [18.0])
+ false_val = sess.run(cond_grad_grad, {pred: False})
+ # d2[x]/dx2 = 0
+ self.assertEqual(false_val, [0.0])
+
+ def testGradientOfDeserializedCond(self):
+ with ops.Graph().as_default():
+ pred = array_ops.placeholder(dtypes.bool, name="pred")
+ x = constant_op.constant(3.0, name="x")
+ ops.add_to_collection("x", x)
+
+ def true_fn():
+ return math_ops.pow(x, 3)
+
+ def false_fn():
+ return x
+
+ ops.add_to_collection("pred", pred)
+ cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
+ for c in cond:
+ ops.add_to_collection("cond", c)
+ meta_graph = saver.export_meta_graph()
+
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ saver.import_meta_graph(meta_graph)
+ x = ops.get_collection("x")[0]
+ pred = ops.get_collection("pred")[0]
+ cond = ops.get_collection("cond")
+ cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
+ cond_grad_grad = gradients_impl.gradients(
+ cond_grad, [x], name="cond_grad_grad")
+ # d[x^3]/dx = 3x^2
+ true_val = sess.run(cond_grad, {pred: True})
+ self.assertEqual(true_val, [27.0])
+ # d[x]/dx = 1
+ false_val = sess.run(cond_grad, {pred: False})
+ self.assertEqual(false_val, [1.0])
+
+ true_val = sess.run(cond_grad_grad, {pred: True})
+ # d2[x^3]/dx2 = 6x
+ self.assertEqual(true_val, [18.0])
+ false_val = sess.run(cond_grad_grad, {pred: False})
+ # d2[x]/dx2 = 0
+ self.assertEqual(false_val, [0.0])
+
+ def testLowering(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g) as sess:
+ out_cond = self._createCond("cond")
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+
+ # If lowering was enabled, there should be a `Switch` node
+ switch_found = any(
+ any(node.op == "Switch" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertTrue(switch_found,
+ "A `Switch` op should exist if the graph was lowered.")
+
+ # If lowering was enabled, there should be no `If` node
+ if_found = any(
+ any(node.op == "If" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertFalse(if_found,
+ "An `If` op was found, but it should be lowered.")
+
+ def testLoweringDisabledInXLA(self):
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Build the cond_v2 in an XLA context
+ xla_context = control_flow_ops.XLAControlFlowContext()
+ xla_context.Enter()
+ out_cond = self._createCond("cond")
+ xla_context.Exit()
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond, options=run_options, run_metadata=run_metadata)
+
+ # Lowering disabled in XLA, there should be no `Switch` node
+ switch_found = any(
+ any(node.op == "Switch" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertFalse(
+ switch_found,
+ "A `Switch` op exists, but the graph should not be lowered.")
+
+ # Lowering disabled in XLA, there should still be an `If` node
+ if_found = any(
+ any(node.op == "If" for node in graph.node)
+ for graph in run_metadata.partition_graphs
+ )
+
+ self.assertTrue(
+ if_found,
+ "An `If` op was not found, but the graph should not be lowered.")
+
+
+class CondV2CollectionTest(test.TestCase):
+
+ def testCollectionIntValueAccessInCond(self):
+ """Read values from graph collections inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = 2
+ y = 5
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+ def fn():
+ x_const = constant_op.constant(ops.get_collection("x")[0])
+ y_const = constant_op.constant(ops.get_collection("y")[0])
+ return math_ops.add(x_const, y_const)
+
+ cnd = cond_v2.cond_v2(True, fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionTensorValueAccessInCond(self):
+ """Read tensors from collections inside of cond_v2 & use them."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ ops.add_to_collection("x", x)
+ ops.add_to_collection("y", y)
+
+ def fn():
+ x_read = ops.get_collection("x")[0]
+ y_read = ops.get_collection("y")[0]
+ return math_ops.add(x_read, y_read)
+
+ cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
+ self.assertEquals(cnd[0].eval(), 7)
+
+ def testCollectionIntValueWriteInCond(self):
+ """Make sure Int writes to collections work inside of cond_v2."""
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ x = constant_op.constant(2)
+ y = constant_op.constant(5)
+ def true_fn():
+ z = math_ops.add(x, y)
+ ops.add_to_collection("z", 7)
+ return math_ops.mul(x, z)
+
+ def false_fn():
+ z = math_ops.add(x, y)
+ return math_ops.mul(x, z)
+
+ cnd = cond_v2.cond_v2(
+ True, true_fn,
+ false_fn)
+ self.assertEquals(cnd[0].eval(), 14)
+
+ read_z_collection = ops.get_collection("z")
+ self.assertEquals(read_z_collection, [7])
+
+
+class CondV2ContainerTest(test.TestCase):
+
+ def testContainer(self):
+ """Set containers outside & inside of cond_v2.
+
+ Make sure the containers are set correctly for both variable creation
+ (tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
+ """
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ v0 = variables.Variable([0])
+ q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ def container(node):
+ return node.op.get_attr("container")
+
+ self.assertEqual(compat.as_bytes(""), container(v0))
+ self.assertEqual(compat.as_bytes(""), container(q0.queue_ref))
+
+ def true_fn():
+ # When this branch is created in cond below,
+ # the container should begin with 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2t"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2t"), container(v2))
+ self.assertEqual(compat.as_bytes("l2t"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(2.0)
+
+ def false_fn():
+ # When this branch is created in cond below,
+ # the container should begin with 'l1'
+ v1 = variables.Variable([1])
+ q1 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ with ops.container("l2f"):
+ v2 = variables.Variable([2])
+ q2 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ v3 = variables.Variable([1])
+ q3 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v1))
+ self.assertEqual(compat.as_bytes("l1"), container(q1.queue_ref))
+ self.assertEqual(compat.as_bytes("l2f"), container(v2))
+ self.assertEqual(compat.as_bytes("l2f"), container(q2.queue_ref))
+ self.assertEqual(compat.as_bytes("l1"), container(v3))
+ self.assertEqual(compat.as_bytes("l1"), container(q3.queue_ref))
+
+ return constant_op.constant(6.0)
+
+ with ops.container("l1"):
+ cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
+ self.assertEquals(cnd_true[0].eval(), 2)
+
+ cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
+ self.assertEquals(cnd_false[0].eval(), 6)
+
+ v4 = variables.Variable([3])
+ q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+ v5 = variables.Variable([4])
+ q5 = data_flow_ops.FIFOQueue(1, dtypes.float32)
+
+ self.assertEqual(compat.as_bytes("l1"), container(v4))
+ self.assertEqual(compat.as_bytes("l1"), container(q4.queue_ref))
+ self.assertEqual(compat.as_bytes(""), container(v5))
+ self.assertEqual(compat.as_bytes(""), container(q5.queue_ref))
+
+
+class CondV2ColocationGroupAndDeviceTest(test.TestCase):
+
+ def testColocateWithBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testColocateWithInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant([2.0], name="b")
+
+ def fn2():
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a", b"loc:@b"], c.op.colocation_groups())
+ return c
+
+ with ops.colocate_with(a.op):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant([2.0], name="d")
+ self.assertEqual([b"loc:@a"], d.op.colocation_groups())
+
+ def testColocateWithInCondGraphPartitioning(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(
+ graph=g,
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})
+ ) as sess:
+
+ with ops.device("/device:CPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ with ops.device("/device:CPU:1"):
+ b = constant_op.constant([2.0], name="b")
+
+ def fn():
+ with ops.colocate_with(b.op):
+ c = math_ops.add(a, a, name="c")
+ return c
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
+
+ # We expect there to be two partitions because of the
+ # colocate_with. We are only running the cond, which has a data
+ # dependency on `a` but not on `b`. So, without the colocate_with
+ # we would expect execution on just one device.
+ self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+
+ def testDeviceBeforeCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:CPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+
+ def fn2():
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:GPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ def testDeviceInAndOutOfCond(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ def fn2():
+ with ops.device("/device:GPU:0"):
+ c = constant_op.constant(3.0)
+ self.assertEqual("/device:GPU:0", c.op.device)
+ return c
+
+ with ops.device("/device:CPU:0"):
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+
+ d = constant_op.constant(4.0)
+ self.assertEqual("/device:CPU:0", d.op.device)
+
+ def testDeviceInCondGraphPartitioning(self):
+ with ops.Graph().as_default() as g:
+ with self.test_session(
+ graph=g,
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})
+ ) as sess:
+
+ def fn():
+ with ops.device("/device:CPU:1"):
+ c = math_ops.add(a, a, name="c")
+ return c
+
+ with ops.device("/device:CPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(out_cond_2, options=run_options, run_metadata=run_metadata)
+
+ self.assertTrue(len(run_metadata.partition_graphs) >= 2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 79e419867d..ae6875340e 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class ConfusionMatrixTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testExample(self):
"""This is a test of the example provided in pydoc."""
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index 8e9d75667d..a0d5557b92 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -32,6 +32,9 @@ from tensorflow.python.util import compat
# TODO(josh11b): add tests with lists/tuples, Shape.
+# TODO(ashankar): Collapse with tests in constant_op_test.py and use something
+# like the test_util.run_in_graph_and_eager_modes decorator to confirm
+# equivalence between graph and eager execution.
class ConstantTest(test.TestCase):
def _testCpu(self, x):
@@ -280,6 +283,34 @@ class ConstantTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, None):
constant_op.constant([[1, 2], [3], [4, 5]])
+ # TODO(ashankar): This test fails with graph construction since
+ # tensor_util.make_tensor_proto (invoked from constant_op.constant)
+ # does not handle iterables (it relies on numpy conversion).
+ # For consistency, should graph construction handle Python objects
+ # that implement the sequence protocol (but not numpy conversion),
+ # or should eager execution fail on such sequences?
+ def testCustomSequence(self):
+
+ # This is inspired by how many objects in pandas are implemented:
+ # - They implement the Python sequence protocol
+ # - But may raise a KeyError on __getitem__(self, 0)
+ # See https://github.com/tensorflow/tensorflow/issues/20347
+ class MySeq(object):
+
+ def __getitem__(self, key):
+ if key != 1 and key != 3:
+ raise KeyError(key)
+ return key
+
+ def __len__(self):
+ return 2
+
+ def __iter__(self):
+ l = list([1, 3])
+ return l.__iter__()
+
+ self.assertAllEqual([1, 3], self.evaluate(constant_op.constant(MySeq())))
+
class AsTensorTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index a291bef0ad..474d06b8f3 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -312,8 +312,8 @@ class Conv2DTest(test.TestCase):
expected_values = self.evaluate(expected_results)
computed_values = self.evaluate(computed_results)
for e_value, c_value in zip(expected_values, computed_values):
- print("expected = ", e_value)
- print("actual = ", c_value)
+ tf_logging.info("expected = ", e_value)
+ tf_logging.info("actual = ", c_value)
self.assertAllClose(
e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4)
@@ -337,15 +337,15 @@ class Conv2DTest(test.TestCase):
for i in range(len(tensors)):
conv = tensors[i]
value = values[i]
- print("expected = ", expected)
- print("actual = ", value)
+ tf_logging.info("expected = ", expected)
+ tf_logging.info("actual = ", value)
tol = 1e-5
if value.dtype == np.float16:
tol = 1e-3
self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol)
self.assertShapeEqual(value, conv)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D1x1Filter(self):
expected_output = [
30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
@@ -358,7 +358,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Filter2x1Dilation(self):
self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 4, 4, 1],
@@ -367,7 +367,7 @@ class Conv2DTest(test.TestCase):
dilations=[2, 1],
padding="VALID")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DEmpty(self):
expected_output = []
self._VerifyValues(
@@ -377,7 +377,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DEmptyDilation(self):
self._VerifyDilatedConvValues(
tensor_in_sizes=[0, 2, 3, 3],
@@ -386,7 +386,7 @@ class Conv2DTest(test.TestCase):
dilations=[2, 1],
padding="VALID")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
@@ -397,7 +397,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2FilterDilation(self):
self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 2, 3, 3],
@@ -406,7 +406,7 @@ class Conv2DTest(test.TestCase):
dilations=[1, 2],
padding="VALID")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D1x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [
@@ -420,7 +420,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D1x2FilterDilation(self):
self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 2, 3, 3],
@@ -429,7 +429,7 @@ class Conv2DTest(test.TestCase):
dilations=[2, 1],
padding="VALID")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2FilterStride2(self):
expected_output = [2271.0, 2367.0, 2463.0]
self._VerifyValues(
@@ -439,7 +439,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2FilterStride2Same(self):
expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
self._VerifyValues(
@@ -449,7 +449,7 @@ class Conv2DTest(test.TestCase):
padding="SAME",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2FilterStride1x2(self):
expected_output = [58.0, 78.0, 98.0, 118.0, 138.0, 158.0]
self._VerifyValues(
@@ -459,7 +459,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSmallerThanStrideValid(self):
expected_output = [65, 95, 275, 305]
self._VerifyValues(
@@ -469,7 +469,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=expected_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSmallerThanStrideSame(self):
self._VerifyValues(
tensor_in_sizes=[1, 3, 3, 1],
@@ -492,7 +492,7 @@ class Conv2DTest(test.TestCase):
padding="SAME",
expected=[44, 28, 41, 16])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSizeMatchesInputSize(self):
self._VerifyValues(
tensor_in_sizes=[1, 2, 2, 1],
@@ -501,7 +501,7 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=[50, 60])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSizeMatchesInputSizeDilation(self):
self._VerifyDilatedConvValues(
tensor_in_sizes=[1, 3, 3, 1],
@@ -547,8 +547,8 @@ class Conv2DTest(test.TestCase):
# "values" consists of two tensors for two backprops
value = self.evaluate(conv)
self.assertShapeEqual(value, conv)
- print("expected = ", expected)
- print("actual = ", value)
+ tf_logging.info("expected = ", expected)
+ tf_logging.info("actual = ", value)
self.assertArrayNear(expected, value.flatten(), err)
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
@@ -587,9 +587,9 @@ class Conv2DTest(test.TestCase):
values.append(_GetVal(data_format, use_gpu))
for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
+ self.assertAllClose(values[0], values[i], rtol=1e-2, atol=1e-2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth1ValidBackpropInput(self):
expected_output = [1.0, 4.0, 4.0, 3.0, 10.0, 8.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -604,7 +604,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DEmptyBackpropInput(self):
expected_output = []
for (data_format, use_gpu) in GetTestConfigs():
@@ -619,7 +619,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth3ValidBackpropInput(self):
expected_output = [
14.0, 32.0, 50.0, 100.0, 163.0, 226.0, 167.0, 212.0, 257.0, 122.0,
@@ -639,7 +639,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-4)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth3ValidBackpropInputStride1x2(self):
expected_output = [
1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 7.0, 12.0, 11.0, 18.0, 15.0, 24.0, 12.0,
@@ -657,7 +657,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DStrideTwoFilterOneSameBackpropInput(self):
expected_output = [
1.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0,
@@ -675,7 +675,7 @@ class Conv2DTest(test.TestCase):
use_gpu=use_gpu,
err=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSizeMatchesInputSizeBackpropInput(self):
expected_output = [5.0, 11.0, 17.0, 23.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -723,8 +723,8 @@ class Conv2DTest(test.TestCase):
data_format=data_format)
value = self.evaluate(conv)
self.assertShapeEqual(value, conv)
- print("expected = ", expected)
- print("actual = ", value)
+ tf_logging.info("expected = ", expected)
+ tf_logging.info("actual = ", value)
self.assertArrayNear(expected, value.flatten(), 1e-5)
def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes,
@@ -759,7 +759,7 @@ class Conv2DTest(test.TestCase):
for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-4, atol=1e-4)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth1ValidBackpropFilter(self):
expected = [5.0, 8.0, 14.0, 17.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -773,7 +773,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DEmptyBackpropFilter(self):
expected = []
for (data_format, use_gpu) in GetTestConfigs():
@@ -787,7 +787,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DBackpropFilterWithEmptyInput(self):
expected = [0, 0, 0, 0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -801,7 +801,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth3ValidBackpropFilter(self):
expected = [
17.0, 22.0, 27.0, 22.0, 29.0, 36.0, 27.0, 36.0, 45.0, 32.0, 43.0, 54.0,
@@ -820,7 +820,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self):
expected = [161.0, 182.0, 287.0, 308.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -834,7 +834,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DStrideTwoFilterOneSameBackpropFilter(self):
expected_output = [78.]
for (data_format, use_gpu) in GetTestConfigs():
@@ -848,7 +848,7 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConv2DKernelSizeMatchesInputSizeBackpropFilter(self):
expected_output = [1.0, 2.0, 2.0, 4.0, 3.0, 6.0, 4.0, 8.0]
for (data_format, use_gpu) in GetTestConfigs():
@@ -912,8 +912,8 @@ class Conv2DTest(test.TestCase):
value_2 = sess.run(conv_2)
self.assertShapeEqual(value, conv)
self.assertShapeEqual(value_2, conv_2)
- print("expected = ", value_2)
- print("actual = ", value)
+ tf_logging.info("expected = ", value_2)
+ tf_logging.info("actual = ", value)
self.assertArrayNear(value_2.flatten(), value.flatten(), err)
# Testing for backprops
@@ -965,8 +965,8 @@ class Conv2DTest(test.TestCase):
value_2 = sess.run(conv_2)
self.assertShapeEqual(value, conv)
self.assertShapeEqual(value_2, conv_2)
- print("expected = ", value_2)
- print("actual = ", value)
+ tf_logging.info("expected = ", value_2)
+ tf_logging.info("actual = ", value)
self.assertArrayNear(value_2.flatten(), value.flatten(), err)
def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self):
@@ -1178,7 +1178,7 @@ class Conv2DTest(test.TestCase):
# since fp16 numerical gradients are too imprecise.
err = np.fabs(jacob_t - reference_jacob_t).max()
- print("conv_2d gradient error = ", err)
+ tf_logging.info("conv_2d gradient error = ", err)
self.assertLess(err, 0.002)
def testInputGradientValidPaddingStrideOne(self):
@@ -1546,7 +1546,7 @@ class DepthwiseConv2DTest(test.TestCase):
conv = nn_impl.depthwise_conv2d(
t1, t2, strides=[1, stride, stride, 1], padding=padding)
value = sess.run(conv)
- print("value = ", value)
+ tf_logging.info("value = ", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@@ -1668,7 +1668,7 @@ class SeparableConv2DTest(test.TestCase):
conv = array_ops.transpose(conv, [0, 2, 3, 1])
value = sess.run(conv)
- print("value = ", value)
+ tf_logging.info("value = ", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@@ -1826,7 +1826,7 @@ class Conv2DBenchmark(test.Benchmark):
wall_time = time.time() - start
self.report_benchmark(
name="conv_stack_iter_%d" % iter_index, wall_time=wall_time)
- print("conv_stack_iter_%d: %.4f" % (iter_index, wall_time))
+ tf_logging.info("conv_stack_iter_%d: %.4f" % (iter_index, wall_time))
def GetInceptionFwdTest(input_size, filter_size, stride, padding,
@@ -1897,19 +1897,19 @@ if __name__ == "__main__":
for index, (input_size_, filter_size_, output_size_, stride_,
padding_) in enumerate(GetShrunkInceptionShapes()):
setattr(Conv2DTest, "testInceptionFwd_" + str(index),
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionFwdTest(input_size_, filter_size_, stride_,
padding_)))
setattr(
Conv2DTest, "testInceptionFwdDilatedConv_" + str(index),
- test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest(
+ test_util.run_in_graph_and_eager_modes(GetInceptionFwdDilatedConvTest(
input_size_, filter_size_, stride_, padding_)))
setattr(Conv2DTest, "testInceptionBackInput_" + str(index),
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionBackInputTest(input_size_, filter_size_,
output_size_, stride_, padding_)))
setattr(Conv2DTest, "testInceptionBackFilter_" + str(index),
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionBackFilterTest(input_size_, filter_size_,
output_size_, [stride_, stride_],
padding_)))
@@ -1924,17 +1924,17 @@ if __name__ == "__main__":
fshape = [1, 1, 1, 256]
oshape = [1, 400, 400, 256]
setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)))
setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused",
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME")))
setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME",
gpu_only=True)))
setattr(Conv2DTest, "testInceptionBackFilter_No_Winograd_Nonfused",
- test_util.run_in_graph_and_eager_modes()(
+ test_util.run_in_graph_and_eager_modes(
GetInceptionBackFilterTest(ishape, fshape, oshape, [1, 1], "SAME",
gpu_only=True)))
test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 1128cd7a63..b61232cded 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -96,7 +96,8 @@ class UnaryOpTest(test.TestCase):
np_ans = np_func(x)
with self.test_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
- if x.dtype in (np.float32, np.float64):
+ if x.dtype in (np.float32, np.float64,
+ dtypes_lib.bfloat16.as_numpy_dtype):
y = 1.1 * tf_func(inx)
np_ans *= 1.1
else:
@@ -105,6 +106,8 @@ class UnaryOpTest(test.TestCase):
self.assertShapeEqual(np_ans, y)
if x.dtype == np.float16:
self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
+ elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
else:
self.assertAllClose(np_ans, tf_cpu)
@@ -241,6 +244,12 @@ class UnaryOpTest(test.TestCase):
math_ops.lgamma)
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -286,6 +295,12 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(x, np.arcsin, math_ops.asin)
self._compareBoth(x, np.arccos, math_ops.acos)
self._compareBoth(x, np.arctan, math_ops.atan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -334,6 +349,12 @@ class UnaryOpTest(test.TestCase):
self._compareBoth(k, np.arcsin, math_ops.asin)
self._compareBoth(k, np.arccos, math_ops.acos)
self._compareBoth(k, np.tan, math_ops.tan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -370,6 +391,12 @@ class UnaryOpTest(test.TestCase):
math_ops.lgamma)
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
self._compareBothSparse(x, np.abs, math_ops.abs)
self._compareBothSparse(x, np.negative, math_ops.negative)
@@ -644,12 +671,11 @@ class BinaryOpTest(test.TestCase):
self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
np.complex128):
- if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.igamma,
- math_ops.igammac, math_ops.zeta, math_ops.polygamma):
+ if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
+ math_ops.polygamma):
self._compareGradientX(x, y, np_func, tf_func)
self._compareGradientY(x, y, np_func, tf_func)
- if tf_func in (math_ops.igamma, math_ops.igammac, math_ops.zeta,
- math_ops.polygamma):
+ if tf_func in (math_ops.zeta, math_ops.polygamma):
# These methods only support gradients in the second parameter
self._compareGradientY(x, y, np_func, tf_func)
self._compareGpu(x, y, np_func, tf_func)
diff --git a/tensorflow/python/kernel_tests/dct_ops_test.py b/tensorflow/python/kernel_tests/dct_ops_test.py
index 93b2ff4561..97d7e2d8f9 100644
--- a/tensorflow/python/kernel_tests/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/dct_ops_test.py
@@ -40,50 +40,92 @@ def try_import(name): # pylint: disable=invalid-name
fftpack = try_import("scipy.fftpack")
+def _np_dct2(signals, norm=None):
+ """Computes the DCT-II manually with NumPy."""
+ # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
+ dct_size = signals.shape[-1]
+ dct = np.zeros_like(signals)
+ for k in range(dct_size):
+ phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
+ dct[..., k] = np.sum(signals * phi, axis=-1)
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ if norm == "ortho":
+ # The orthonormal scaling includes a factor of 0.5 which we combine with
+ # the overall scaling of 2.0 to cancel.
+ dct[..., 0] *= np.sqrt(1.0 / dct_size)
+ dct[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ dct *= 2.0
+ return dct
+
+
+def _np_dct3(signals, norm=None):
+ """Computes the DCT-III manually with NumPy."""
+ # SciPy's `dct` has a scaling factor of 2.0 which we follow.
+ # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
+ dct_size = signals.shape[-1]
+ signals = np.array(signals) # make a copy so we can modify
+ if norm == "ortho":
+ signals[..., 0] *= np.sqrt(4.0 / dct_size)
+ signals[..., 1:] *= np.sqrt(2.0 / dct_size)
+ else:
+ signals *= 2.0
+ dct = np.zeros_like(signals)
+ # X_k = 0.5 * x_0 +
+ # sum_{n=1}^{N-1} x_n * cos(\frac{pi}{N} * n * (k + 0.5)) k=0,...,N-1
+ half_x0 = 0.5 * signals[..., 0]
+ for k in range(dct_size):
+ phi = np.cos(np.pi * np.arange(1, dct_size) * (k + 0.5) / dct_size)
+ dct[..., k] = half_x0 + np.sum(signals[..., 1:] * phi, axis=-1)
+ return dct
+
+
+NP_DCT = {2: _np_dct2, 3: _np_dct3}
+NP_IDCT = {2: _np_dct3, 3: _np_dct2}
+
+
class DCTOpsTest(test.TestCase):
- def _np_dct2(self, signals, norm=None):
- """Computes the DCT-II manually with NumPy."""
- # X_k = sum_{n=0}^{N-1} x_n * cos(\frac{pi}{N} * (n + 0.5) * k) k=0,...,N-1
- dct_size = signals.shape[-1]
- dct = np.zeros_like(signals)
- for k in range(dct_size):
- phi = np.cos(np.pi * (np.arange(dct_size) + 0.5) * k / dct_size)
- dct[..., k] = np.sum(signals * phi, axis=-1)
- # SciPy's `dct` has a scaling factor of 2.0 which we follow.
- # https://github.com/scipy/scipy/blob/v0.15.1/scipy/fftpack/src/dct.c.src
- if norm == "ortho":
- # The orthonormal scaling includes a factor of 0.5 which we combine with
- # the overall scaling of 2.0 to cancel.
- dct[..., 0] *= np.sqrt(1.0 / dct_size)
- dct[..., 1:] *= np.sqrt(2.0 / dct_size)
- else:
- dct *= 2.0
- return dct
-
- def _compare(self, signals, norm, atol=5e-4, rtol=5e-4):
- """Compares the DCT to SciPy (if available) and a NumPy implementation."""
- np_dct = self._np_dct2(signals, norm)
- tf_dct = spectral_ops.dct(signals, type=2, norm=norm).eval()
+ def _compare(self, signals, norm, dct_type, atol=5e-4, rtol=5e-4):
+ """Compares (I)DCT to SciPy (if available) and a NumPy implementation."""
+ np_dct = NP_DCT[dct_type](signals, norm)
+ tf_dct = spectral_ops.dct(signals, type=dct_type, norm=norm).eval()
self.assertAllClose(np_dct, tf_dct, atol=atol, rtol=rtol)
+ np_idct = NP_IDCT[dct_type](signals, norm)
+ tf_idct = spectral_ops.idct(signals, type=dct_type, norm=norm).eval()
+ self.assertAllClose(np_idct, tf_idct, atol=atol, rtol=rtol)
if fftpack:
- scipy_dct = fftpack.dct(signals, type=2, norm=norm)
+ scipy_dct = fftpack.dct(signals, type=dct_type, norm=norm)
self.assertAllClose(scipy_dct, tf_dct, atol=atol, rtol=rtol)
+ scipy_idct = fftpack.idct(signals, type=dct_type, norm=norm)
+ self.assertAllClose(scipy_idct, tf_idct, atol=atol, rtol=rtol)
+ # Verify inverse(forward(s)) == s, up to a normalization factor.
+ tf_idct_dct = spectral_ops.idct(
+ tf_dct, type=dct_type, norm=norm).eval()
+ tf_dct_idct = spectral_ops.dct(
+ tf_idct, type=dct_type, norm=norm).eval()
+ if norm is None:
+ tf_idct_dct *= 0.5 / signals.shape[-1]
+ tf_dct_idct *= 0.5 / signals.shape[-1]
+ self.assertAllClose(signals, tf_idct_dct, atol=atol, rtol=rtol)
+ self.assertAllClose(signals, tf_dct_idct, atol=atol, rtol=rtol)
def test_random(self):
"""Test randomly generated batches of data."""
with spectral_ops_test_util.fft_kernel_label_map():
with self.test_session(use_gpu=True):
- for shape in ([2, 20], [1], [2], [3], [10], [2, 20], [2, 3, 25]):
+ for shape in ([1], [2], [3], [10], [2, 20], [2, 3, 25]):
signals = np.random.rand(*shape).astype(np.float32)
for norm in (None, "ortho"):
- self._compare(signals, norm)
+ self._compare(signals, norm, 2)
+ self._compare(signals, norm, 3)
def test_error(self):
signals = np.random.rand(10)
# Unsupported type.
with self.assertRaises(ValueError):
- spectral_ops.dct(signals, type=3)
+ spectral_ops.dct(signals, type=1)
# Unknown normalization.
with self.assertRaises(ValueError):
spectral_ops.dct(signals, norm="bad")
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 5e223b1828..7134e02c34 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -356,7 +356,7 @@ class DepthwiseConv2DTest(test.TestCase):
with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-0,
- dtypes.float32: 5e-4,
+ dtypes.float32: 8e-4,
dtypes.float64: 1e-12,
}[data_type]
diff --git a/tensorflow/python/kernel_tests/distributions/BUILD b/tensorflow/python/kernel_tests/distributions/BUILD
index cf2e8832fd..14532965d8 100644
--- a/tensorflow/python/kernel_tests/distributions/BUILD
+++ b/tensorflow/python/kernel_tests/distributions/BUILD
@@ -93,6 +93,7 @@ cuda_py_test(
size = "small",
srcs = ["categorical_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -134,6 +135,10 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
+ tags = [
+ "noguitar", # b/110489471
+ "notap", # b/110489471
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 095d1cde15..9ad77a54cb 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -57,14 +58,14 @@ def entropy(p):
class BernoulliTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testP(self):
p = [0.2, 0.4]
dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(p, self.evaluate(dist.probs))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLogits(self):
logits = [-42., 42.]
dist = bernoulli.Bernoulli(logits=logits)
@@ -82,7 +83,7 @@ class BernoulliTest(test.TestCase):
with self.test_session():
self.assertAllClose(special.logit(p), self.evaluate(dist.logits))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInvalidP(self):
invalid_ps = [1.01, 2.]
for p in invalid_ps:
@@ -104,7 +105,7 @@ class BernoulliTest(test.TestCase):
dist = bernoulli.Bernoulli(probs=p)
self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testShapes(self):
with self.test_session():
for batch_shape in ([], [1], [2, 3, 4]):
@@ -115,7 +116,7 @@ class BernoulliTest(test.TestCase):
self.assertAllEqual([], dist.event_shape.as_list())
self.assertAllEqual([], self.evaluate(dist.event_shape_tensor()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDtype(self):
dist = make_bernoulli([])
self.assertEqual(dist.dtype, dtypes.int32)
@@ -133,7 +134,7 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist64.dtype, dist64.sample(5).dtype)
self.assertEqual(dist64.dtype, dist64.mode().dtype)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def _testPmf(self, **kwargs):
dist = bernoulli.Bernoulli(**kwargs)
with self.test_session():
@@ -174,7 +175,7 @@ class BernoulliTest(test.TestCase):
p: [0.2, 0.3, 0.4]
}), [[0.2, 0.7, 0.4]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPmfInvalid(self):
p = [0.1, 0.2, 0.7]
with self.test_session():
@@ -184,7 +185,7 @@ class BernoulliTest(test.TestCase):
with self.assertRaisesOpError("Elements cannot exceed 1."):
self.evaluate(dist.prob([2, 0, 1]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPmfWithP(self):
p = [[0.2, 0.4], [0.3, 0.6]]
self._testPmf(probs=p)
@@ -226,21 +227,21 @@ class BernoulliTest(test.TestCase):
dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]])
self.assertEqual((2, 1), dist.log_prob(1).get_shape())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBoundaryConditions(self):
with self.test_session():
dist = bernoulli.Bernoulli(probs=1.0)
self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0)))
self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEntropyNoBatch(self):
p = 0.2
dist = bernoulli.Bernoulli(probs=p)
with self.test_session():
self.assertAllClose(self.evaluate(dist.entropy()), entropy(p))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEntropyWithBatch(self):
p = [[0.1, 0.7], [0.2, 0.6]]
dist = bernoulli.Bernoulli(probs=p, validate_args=False)
@@ -250,7 +251,7 @@ class BernoulliTest(test.TestCase):
[[entropy(0.1), entropy(0.7)], [entropy(0.2),
entropy(0.6)]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSampleN(self):
with self.test_session():
p = [0.2, 0.6]
@@ -272,6 +273,16 @@ class BernoulliTest(test.TestCase):
dist = bernoulli.Bernoulli(np.log([.2, .4]))
self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list())
+ @test_util.run_in_graph_and_eager_modes
+ def testNotReparameterized(self):
+ p = constant_op.constant([0.2, 0.6])
+ with backprop.GradientTape() as tape:
+ tape.watch(p)
+ dist = bernoulli.Bernoulli(probs=p)
+ samples = dist.sample(100)
+ grad_p = tape.gradient(samples, p)
+ self.assertIsNone(grad_p)
+
def testSampleActsLikeSampleN(self):
with self.test_session() as sess:
p = [0.2, 0.6]
@@ -282,18 +293,18 @@ class BernoulliTest(test.TestCase):
self.evaluate(dist.sample(n, seed)),
self.evaluate(dist.sample(n, seed)))
n = array_ops.placeholder(dtypes.int32)
- sample, sample = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
- feed_dict={n: 1000})
- self.assertAllEqual(sample, sample)
+ sample1, sample2 = sess.run([dist.sample(n, seed), dist.sample(n, seed)],
+ feed_dict={n: 1000})
+ self.assertAllEqual(sample1, sample2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMean(self):
with self.test_session():
p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32)
dist = bernoulli.Bernoulli(probs=p)
self.assertAllEqual(self.evaluate(dist.mean()), p)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarianceAndStd(self):
var = lambda p: p * (1. - p)
with self.test_session():
@@ -310,7 +321,7 @@ class BernoulliTest(test.TestCase):
[np.sqrt(var(0.5)), np.sqrt(var(0.4))]],
dtype=np.float32))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBernoulliBernoulliKL(self):
batch_size = 6
a_p = np.array([0.5] * batch_size, dtype=np.float32)
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index 4bc8303ebb..36f3ffc333 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -21,6 +21,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
@@ -282,6 +283,18 @@ class BetaTest(test.TestCase):
self.assertAllClose(
np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1)
+ def testBetaFullyReparameterized(self):
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(a)
+ tape.watch(b)
+ beta = beta_lib.Beta(a, b)
+ samples = beta.sample(100)
+ grad_a, grad_b = tape.gradient(samples, [a, b])
+ self.assertIsNotNone(grad_a)
+ self.assertIsNotNone(grad_b)
+
# Test that sampling with the same seed twice gives the same results.
def testBetaSampleMultipleTimes(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index ca2358fe99..d8939433ce 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -18,8 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
@@ -40,7 +42,7 @@ def make_categorical(batch_shape, num_classes, dtype=dtypes.int32):
return categorical.Categorical(logits, dtype=dtype)
-class CategoricalTest(test.TestCase):
+class CategoricalTest(test.TestCase, parameterized.TestCase):
def testP(self):
p = [0.2, 0.8]
@@ -131,7 +133,7 @@ class CategoricalTest(test.TestCase):
with self.test_session():
self.assertAllClose(dist.prob(0).eval(), 0.2)
- def testCDFWithDynamicEventShape(self):
+ def testCDFWithDynamicEventShapeKnownNdims(self):
"""Test that dynamically-sized events with unknown shape work."""
batch_size = 2
histograms = array_ops.placeholder(dtype=dtypes.float32,
@@ -167,6 +169,21 @@ class CategoricalTest(test.TestCase):
self.assertAllClose(actual_cdf_one, expected_cdf_one)
self.assertAllClose(actual_cdf_two, expected_cdf_two)
+ @parameterized.named_parameters(
+ ("test1", [0, 1], [[0.5, 0.3, 0.2], [1.0, 0.0, 0.0]], [0.0, 1.0]),
+ ("test2", [2, 5], [[0.9, 0.0, 0.0, 0.0, 0.0, 0.1],
+ [0.15, 0.2, 0.05, 0.35, 0.13, 0.12]], [0.9, 0.88]))
+ def testCDFWithDynamicEventShapeUnknownNdims(
+ self, events, histograms, expected_cdf):
+ """Test that dynamically-sized events with unknown shape work."""
+ event_ph = array_ops.placeholder_with_default(events, shape=None)
+ histograms_ph = array_ops.placeholder_with_default(histograms, shape=None)
+ dist = categorical.Categorical(probs=histograms_ph)
+ cdf_op = dist.cdf(event_ph)
+
+ actual_cdf = self.evaluate(cdf_op)
+ self.assertAllClose(actual_cdf, expected_cdf)
+
def testCDFWithBatch(self):
histograms = [[0.1, 0.2, 0.3, 0.25, 0.15],
[0.0, 0.75, 0.2, 0.05, 0.0]]
@@ -360,6 +377,15 @@ class CategoricalTest(test.TestCase):
self.assertAllClose(
[0.4**2 + 0.6**2], [prob_val[:, :, :, 1].mean()], atol=1e-2)
+ def testNotReparameterized(self):
+ p = constant_op.constant([0.3, 0.3, 0.4])
+ with backprop.GradientTape() as tape:
+ tape.watch(p)
+ dist = categorical.Categorical(p)
+ samples = dist.sample(100)
+ grad_p = tape.gradient(samples, p)
+ self.assertIsNone(grad_p)
+
def testLogPMFBroadcasting(self):
with self.test_session():
# 1 x 2 x 2
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index daea699514..1b9edcc85a 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -17,6 +17,9 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -250,10 +253,10 @@ class DirichletMultinomialTest(test.TestCase):
dist.variance(),
dist.stddev(),
])
- self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.06)
- self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.07)
- self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.07)
- self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
+ self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
+ self.assertAllClose(sample_cov_, analytic_cov, atol=0.05, rtol=0.)
+ self.assertAllClose(sample_var_, analytic_var, atol=0.05, rtol=0.)
+ self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
def testCovariance(self):
# Shape [2]
@@ -442,7 +445,7 @@ class DirichletMultinomialTest(test.TestCase):
dist.covariance(),
])
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
- self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.15)
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
self.assertAllClose(
actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
@@ -470,10 +473,25 @@ class DirichletMultinomialTest(test.TestCase):
dist.covariance(),
])
self.assertAllEqual([4], sample_mean.get_shape())
- self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.05)
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.20)
self.assertAllEqual([4, 4], sample_covariance.get_shape())
self.assertAllClose(
- actual_covariance_, sample_covariance_, atol=0., rtol=0.15)
+ actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
+
+ def testNotReparameterized(self):
+ total_count = constant_op.constant(5.0)
+ concentration = constant_op.constant([0.1, 0.1, 0.1])
+ with backprop.GradientTape() as tape:
+ tape.watch(total_count)
+ tape.watch(concentration)
+ dist = ds.DirichletMultinomial(
+ total_count=total_count,
+ concentration=concentration)
+ samples = dist.sample(100)
+ grad_total_count, grad_concentration = tape.gradient(
+ samples, [total_count, concentration])
+ self.assertIsNone(grad_total_count)
+ self.assertIsNone(grad_concentration)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index bcec6ef610..67ed0447ed 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -20,6 +20,7 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -190,10 +191,10 @@ class DirichletTest(test.TestCase):
dist.stddev(),
])
- self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.04)
- self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.06)
- self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.03)
- self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.02)
+ self.assertAllClose(sample_mean_, analytic_mean, atol=0.04, rtol=0.)
+ self.assertAllClose(sample_cov_, analytic_cov, atol=0.06, rtol=0.)
+ self.assertAllClose(sample_var_, analytic_var, atol=0.03, rtol=0.)
+ self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.)
def testVariance(self):
with self.test_session():
@@ -264,6 +265,15 @@ class DirichletTest(test.TestCase):
a=1., b=2.).cdf)[0],
0.01)
+ def testDirichletFullyReparameterized(self):
+ alpha = constant_op.constant([1.0, 2.0, 3.0])
+ with backprop.GradientTape() as tape:
+ tape.watch(alpha)
+ dirichlet = dirichlet_lib.Dirichlet(alpha)
+ samples = dirichlet.sample(100)
+ grad_alpha = tape.gradient(samples, alpha)
+ self.assertIsNotNone(grad_alpha)
+
def testDirichletDirichletKL(self):
conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index ebcd41b0e2..850da3e969 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -23,6 +23,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import nn_ops
@@ -163,6 +164,15 @@ class ExponentialTest(test.TestCase):
stats.expon(scale=1.0 / lam_v[i]).cdf)[0],
0.01)
+ def testFullyReparameterized(self):
+ lam = constant_op.constant([0.1, 1.0])
+ with backprop.GradientTape() as tape:
+ tape.watch(lam)
+ exponential = exponential_lib.Exponential(rate=lam)
+ samples = exponential.sample(100)
+ grad_lam = tape.gradient(samples, lam)
+ self.assertIsNotNone(grad_lam)
+
def testExponentialWithSoftplusRate(self):
with self.test_session():
lam = [-2.2, -3.4]
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 5e4813ac07..297e20264c 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -21,9 +21,10 @@ import importlib
import numpy as np
-from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
@@ -45,6 +46,7 @@ special = try_import("scipy.special")
stats = try_import("scipy.stats")
+@test_util.run_all_in_graph_and_eager_modes
class GammaTest(test.TestCase):
def testGammaShape(self):
@@ -53,9 +55,9 @@ class GammaTest(test.TestCase):
beta = constant_op.constant(11.0)
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
- self.assertEqual(gamma.batch_shape_tensor().eval(), (5,))
+ self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,))
self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5]))
- self.assertAllEqual(gamma.event_shape_tensor().eval(), [])
+ self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), [])
self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([]))
def testGammaLogPDF(self):
@@ -74,8 +76,8 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(log_pdf.eval(), expected_log_pdf)
- self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))
+ self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
+ self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
def testGammaLogPDFMultidimensional(self):
with self.test_session():
@@ -87,10 +89,10 @@ class GammaTest(test.TestCase):
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
log_pdf = gamma.log_prob(x)
- log_pdf_values = log_pdf.eval()
+ log_pdf_values = self.evaluate(log_pdf)
self.assertEqual(log_pdf.get_shape(), (6, 2))
pdf = gamma.prob(x)
- pdf_values = pdf.eval()
+ pdf_values = self.evaluate(pdf)
self.assertEqual(pdf.get_shape(), (6, 2))
if not stats:
return
@@ -108,10 +110,10 @@ class GammaTest(test.TestCase):
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
log_pdf = gamma.log_prob(x)
- log_pdf_values = log_pdf.eval()
+ log_pdf_values = self.evaluate(log_pdf)
self.assertEqual(log_pdf.get_shape(), (6, 2))
pdf = gamma.prob(x)
- pdf_values = pdf.eval()
+ pdf_values = self.evaluate(pdf)
self.assertEqual(pdf.get_shape(), (6, 2))
if not stats:
@@ -135,7 +137,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)
- self.assertAllClose(cdf.eval(), expected_cdf)
+ self.assertAllClose(self.evaluate(cdf), expected_cdf)
def testGammaMean(self):
with self.test_session():
@@ -146,7 +148,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.mean().eval(), expected_means)
+ self.assertAllClose(self.evaluate(gamma.mean()), expected_means)
def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self):
with self.test_session():
@@ -155,7 +157,7 @@ class GammaTest(test.TestCase):
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
expected_modes = (alpha_v - 1) / beta_v
self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(gamma.mode().eval(), expected_modes)
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self):
with self.test_session():
@@ -166,7 +168,7 @@ class GammaTest(test.TestCase):
rate=beta_v,
allow_nan_stats=False)
with self.assertRaisesOpError("x < y"):
- gamma.mode().eval()
+ self.evaluate(gamma.mode())
def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self):
with self.test_session():
@@ -179,7 +181,7 @@ class GammaTest(test.TestCase):
expected_modes = (alpha_v - 1) / beta_v
expected_modes[0] = np.nan
self.assertEqual(gamma.mode().get_shape(), (3,))
- self.assertAllClose(gamma.mode().eval(), expected_modes)
+ self.assertAllClose(self.evaluate(gamma.mode()), expected_modes)
def testGammaVariance(self):
with self.test_session():
@@ -190,7 +192,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.variance().eval(), expected_variances)
+ self.assertAllClose(self.evaluate(gamma.variance()), expected_variances)
def testGammaStd(self):
with self.test_session():
@@ -201,7 +203,7 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v)
- self.assertAllClose(gamma.stddev().eval(), expected_stddev)
+ self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev)
def testGammaEntropy(self):
with self.test_session():
@@ -212,10 +214,10 @@ class GammaTest(test.TestCase):
if not stats:
return
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
- self.assertAllClose(gamma.entropy().eval(), expected_entropy)
+ self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy)
def testGammaSampleSmallAlpha(self):
- with session.Session():
+ with self.test_session():
alpha_v = 0.05
beta_v = 1.0
alpha = constant_op.constant(alpha_v)
@@ -223,7 +225,7 @@ class GammaTest(test.TestCase):
n = 100000
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n,))
self.assertEqual(sample_values.shape, (n,))
self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
@@ -240,7 +242,7 @@ class GammaTest(test.TestCase):
atol=.15)
def testGammaSample(self):
- with session.Session():
+ with self.test_session():
alpha_v = 4.0
beta_v = 3.0
alpha = constant_op.constant(alpha_v)
@@ -248,7 +250,7 @@ class GammaTest(test.TestCase):
n = 100000
gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n,))
self.assertEqual(sample_values.shape, (n,))
self.assertTrue(self._kstest(alpha_v, beta_v, sample_values))
@@ -264,14 +266,26 @@ class GammaTest(test.TestCase):
stats.gamma.var(alpha_v, scale=1 / beta_v),
atol=.15)
+ def testGammaFullyReparameterized(self):
+ alpha = constant_op.constant(4.0)
+ beta = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(alpha)
+ tape.watch(beta)
+ gamma = gamma_lib.Gamma(concentration=alpha, rate=beta)
+ samples = gamma.sample(100)
+ grad_alpha, grad_beta = tape.gradient(samples, [alpha, beta])
+ self.assertIsNotNone(grad_alpha)
+ self.assertIsNotNone(grad_beta)
+
def testGammaSampleMultiDimensional(self):
- with session.Session():
+ with self.test_session():
alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1
gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v)
n = 10000
samples = gamma.sample(n, seed=137)
- sample_values = samples.eval()
+ sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (n, 10, 100))
self.assertEqual(sample_values.shape, (n, 10, 100))
zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100
@@ -283,11 +297,11 @@ class GammaTest(test.TestCase):
sample_values.mean(axis=0),
stats.gamma.mean(
alpha_bc, scale=1 / beta_bc),
- rtol=.035)
+ atol=0., rtol=.05)
self.assertAllClose(
sample_values.var(axis=0),
stats.gamma.var(alpha_bc, scale=1 / beta_bc),
- atol=4.5)
+ atol=10.0, rtol=0.)
fails = 0
trials = 0
for ai, a in enumerate(np.reshape(alpha_v, [-1])):
@@ -306,12 +320,12 @@ class GammaTest(test.TestCase):
return ks < 0.02
def testGammaPdfOfSampleMultiDims(self):
- with session.Session() as sess:
+ with self.test_session():
gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]])
num = 50000
samples = gamma.sample(num, seed=137)
pdfs = gamma.prob(samples)
- sample_vals, pdf_vals = sess.run([samples, pdfs])
+ sample_vals, pdf_vals = self.evaluate([samples, pdfs])
self.assertEqual(samples.get_shape(), (num, 2, 2))
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
@@ -345,18 +359,18 @@ class GammaTest(test.TestCase):
with self.test_session():
alpha_v = constant_op.constant(0.0, name="alpha")
beta_v = constant_op.constant(1.0, name="beta")
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- with self.assertRaisesOpError("alpha"):
- gamma.mean().eval()
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
+ self.evaluate(gamma.mean())
alpha_v = constant_op.constant(1.0, name="alpha")
beta_v = constant_op.constant(0.0, name="beta")
- gamma = gamma_lib.Gamma(concentration=alpha_v,
- rate=beta_v,
- validate_args=True)
- with self.assertRaisesOpError("beta"):
- gamma.mean().eval()
+ with self.assertRaisesOpError("x > 0"):
+ gamma = gamma_lib.Gamma(concentration=alpha_v,
+ rate=beta_v,
+ validate_args=True)
+ self.evaluate(gamma.mean())
def testGammaWithSoftplusConcentrationRate(self):
with self.test_session():
@@ -364,10 +378,10 @@ class GammaTest(test.TestCase):
beta_v = constant_op.constant([1.0, -3.6], name="beta")
gamma = gamma_lib.GammaWithSoftplusConcentrationRate(
concentration=alpha_v, rate=beta_v)
- self.assertAllEqual(nn_ops.softplus(alpha_v).eval(),
- gamma.concentration.eval())
- self.assertAllEqual(nn_ops.softplus(beta_v).eval(),
- gamma.rate.eval())
+ self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)),
+ self.evaluate(gamma.concentration))
+ self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)),
+ self.evaluate(gamma.rate))
def testGammaGammaKL(self):
alpha0 = np.array([3.])
@@ -377,15 +391,15 @@ class GammaTest(test.TestCase):
beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.])
# Build graph.
- with self.test_session() as sess:
+ with self.test_session():
g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0)
g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
x = g0.sample(int(1e4), seed=0)
kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
kl_actual = kullback_leibler.kl_divergence(g0, g1)
- # Execute graph.
- [kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])
+ # Execute graph.
+ [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual])
self.assertEqual(beta0.shape, kl_actual.get_shape())
@@ -399,7 +413,7 @@ class GammaTest(test.TestCase):
+ alpha0 * (beta1 / beta0 - 1.))
self.assertAllClose(kl_expected, kl_actual_, atol=0., rtol=1e-6)
- self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-2)
+ self.assertAllClose(kl_sample_, kl_actual_, atol=0., rtol=1e-1)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 918c7f63f2..24b243f647 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
@@ -255,6 +256,18 @@ class LaplaceTest(test.TestCase):
atol=0.)
self.assertTrue(self._kstest(loc_v, scale_v, sample_values))
+ def testLaplaceFullyReparameterized(self):
+ loc = constant_op.constant(4.0)
+ scale = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(loc)
+ tape.watch(scale)
+ laplace = laplace_lib.Laplace(loc=loc, scale=scale)
+ samples = laplace.sample(100)
+ grad_loc, grad_scale = tape.gradient(samples, [loc, scale])
+ self.assertIsNotNone(grad_loc)
+ self.assertIsNotNone(grad_scale)
+
def testLaplaceSampleMultiDimensional(self):
with session.Session():
loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index e24e8ade73..bfd40ba2b7 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -18,6 +18,8 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import backprop
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
@@ -310,10 +312,10 @@ class MultinomialTest(test.TestCase):
dist.covariance(),
])
self.assertAllEqual([4, 3, 2], sample_mean.get_shape())
- self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
self.assertAllEqual([4, 3, 2, 2], sample_covariance.get_shape())
self.assertAllClose(
- actual_covariance_, sample_covariance_, atol=0., rtol=0.10)
+ actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
def testSampleUnbiasedScalarBatch(self):
with self.test_session() as sess:
@@ -338,10 +340,24 @@ class MultinomialTest(test.TestCase):
dist.covariance(),
])
self.assertAllEqual([4], sample_mean.get_shape())
- self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.07)
+ self.assertAllClose(actual_mean_, sample_mean_, atol=0., rtol=0.10)
self.assertAllEqual([4, 4], sample_covariance.get_shape())
self.assertAllClose(
- actual_covariance_, sample_covariance_, atol=0., rtol=0.10)
+ actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
+
+ def testNotReparameterized(self):
+ total_count = constant_op.constant(5.0)
+ p = constant_op.constant([0.2, 0.6])
+ with backprop.GradientTape() as tape:
+ tape.watch(total_count)
+ tape.watch(p)
+ dist = multinomial.Multinomial(
+ total_count=total_count,
+ probs=p)
+ samples = dist.sample(100)
+ grad_total_count, grad_p = tape.gradient(samples, [total_count, p])
+ self.assertIsNone(grad_total_count)
+ self.assertIsNone(grad_p)
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index d793e03272..7ff48c0c10 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -23,6 +23,7 @@ import math
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -77,20 +78,20 @@ class NormalTest(test.TestCase):
self.assertEqual(expected, mu_shape)
self.assertEqual(expected, sigma_shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testParamShapes(self):
sample_shape = [10, 3, 4]
self._testParamShapes(sample_shape, sample_shape)
self._testParamShapes(constant_op.constant(sample_shape), sample_shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testParamStaticShapes(self):
sample_shape = [10, 3, 4]
self._testParamStaticShapes(sample_shape, sample_shape)
self._testParamStaticShapes(
tensor_shape.TensorShape(sample_shape), sample_shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalWithSoftplusScale(self):
with self.test_session():
mu = array_ops.zeros((10, 3))
@@ -100,7 +101,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual(
self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalLogPDF(self):
with self.test_session():
batch_size = 6
@@ -134,7 +135,7 @@ class NormalTest(test.TestCase):
self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf))
self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6
@@ -172,7 +173,7 @@ class NormalTest(test.TestCase):
self.assertAllClose(expected_log_pdf, log_pdf_values)
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalCDF(self):
with self.test_session():
batch_size = 50
@@ -194,7 +195,7 @@ class NormalTest(test.TestCase):
expected_cdf = stats.norm(mu, sigma).cdf(x)
self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalSurvivalFunction(self):
with self.test_session():
batch_size = 50
@@ -217,7 +218,7 @@ class NormalTest(test.TestCase):
expected_sf = stats.norm(mu, sigma).sf(x)
self.assertAllClose(expected_sf, self.evaluate(sf), atol=0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalLogCDF(self):
with self.test_session():
batch_size = 50
@@ -239,7 +240,7 @@ class NormalTest(test.TestCase):
if not stats:
return
expected_cdf = stats.norm(mu, sigma).logcdf(x)
- self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-5)
+ self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3)
def testFiniteGradientAtDifficultPoints(self):
for dtype in [np.float32, np.float64]:
@@ -261,7 +262,7 @@ class NormalTest(test.TestCase):
self.assertAllFinite(grads[0])
self.assertAllFinite(grads[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalLogSurvivalFunction(self):
with self.test_session():
batch_size = 50
@@ -285,7 +286,7 @@ class NormalTest(test.TestCase):
expected_sf = stats.norm(mu, sigma).logsf(x)
self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test.
with self.test_session():
@@ -307,7 +308,7 @@ class NormalTest(test.TestCase):
expected_entropy = stats.norm(mu_v, sigma_v).entropy()
self.assertAllClose(expected_entropy, self.evaluate(entropy))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalEntropy(self):
with self.test_session():
mu_v = np.array([1.0, 1.0, 1.0])
@@ -328,7 +329,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual(normal.batch_shape, entropy.get_shape())
self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalMeanAndMode(self):
with self.test_session():
# Mu will be broadcast to [7, 7, 7].
@@ -343,7 +344,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual((3,), normal.mode().get_shape())
self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalQuantile(self):
with self.test_session():
batch_size = 52
@@ -395,7 +396,7 @@ class NormalTest(test.TestCase):
def testQuantileFiniteGradientAtDifficultPointsFloat64(self):
self._baseQuantileFiniteGradientAtDifficultPoints(np.float64)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalVariance(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7]
@@ -407,7 +408,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual((3,), normal.variance().get_shape())
self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalStandardDeviation(self):
with self.test_session():
# sigma will be broadcast to [7, 7, 7]
@@ -419,7 +420,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual((3,), normal.stddev().get_shape())
self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalSample(self):
with self.test_session():
mu = constant_op.constant(3.0)
@@ -453,7 +454,19 @@ class NormalTest(test.TestCase):
self.assertAllEqual(expected_samples_shape, samples.get_shape())
self.assertAllEqual(expected_samples_shape, sample_values.shape)
- @test_util.run_in_graph_and_eager_modes()
+ def testNormalFullyReparameterized(self):
+ mu = constant_op.constant(4.0)
+ sigma = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(mu)
+ tape.watch(sigma)
+ normal = normal_lib.Normal(loc=mu, scale=sigma)
+ samples = normal.sample(100)
+ grad_mu, grad_sigma = tape.gradient(samples, [mu, sigma])
+ self.assertIsNotNone(grad_mu)
+ self.assertIsNotNone(grad_sigma)
+
+ @test_util.run_in_graph_and_eager_modes
def testNormalSampleMultiDimensional(self):
with self.test_session():
batch_size = 2
@@ -489,7 +502,7 @@ class NormalTest(test.TestCase):
self.assertAllEqual(expected_samples_shape, samples.get_shape())
self.assertAllEqual(expected_samples_shape, sample_values.shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNegativeSigmaFails(self):
with self.test_session():
with self.assertRaisesOpError("Condition x > 0 did not hold"):
@@ -497,7 +510,7 @@ class NormalTest(test.TestCase):
loc=[1.], scale=[-5.], validate_args=True, name="G")
self.evaluate(normal.mean())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalShape(self):
with self.test_session():
mu = constant_op.constant([-3.0] * 5)
@@ -524,7 +537,7 @@ class NormalTest(test.TestCase):
feed_dict={mu: 5.0,
sigma: [1.0, 2.0]}), [2])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNormalNormalKL(self):
batch_size = 6
mu_a = np.array([3.0] * batch_size)
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index 4565bf5c46..a634194ce5 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -89,7 +89,7 @@ class NdtriTest(test.TestCase):
all_true = np.ones_like(is_finite, dtype=np.bool)
self.assertAllEqual(all_true, is_finite)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNdtri(self):
"""Verifies that ndtri computation is correct."""
with self.test_session():
@@ -138,11 +138,11 @@ class NdtriTest(test.TestCase):
lambda x: special_math.ndtri(x), p) # pylint: disable=unnecessary-lambda
self.assertAllFinite(self.evaluate(grads[0]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNdtriFiniteGradientFloat32(self):
self._baseNdtriFiniteGradientTest(np.float32)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNdtriFiniteGradientFloat64(self):
self._baseNdtriFiniteGradientTest(np.float64)
diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py
index a4fdb658e8..05590542ef 100644
--- a/tensorflow/python/kernel_tests/distributions/student_t_test.py
+++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py
@@ -23,6 +23,7 @@ import math
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
@@ -172,11 +173,11 @@ class StudentTTest(test.TestCase):
sample_values = self.evaluate(samples)
n_val = 200000
self.assertEqual(sample_values.shape, (n_val,))
- self.assertAllClose(sample_values.mean(), mu_v, rtol=1e-2, atol=0)
+ self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0)
self.assertAllClose(
sample_values.var(),
sigma_v**2 * df_v / (df_v - 2),
- rtol=1e-2,
+ rtol=0.1,
atol=0)
self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
@@ -215,11 +216,11 @@ class StudentTTest(test.TestCase):
def testStudentSampleMultiDimensional(self):
with self.test_session():
batch_size = 7
- df = constant_op.constant([[3., 7.]] * batch_size)
+ df = constant_op.constant([[5., 7.]] * batch_size)
mu = constant_op.constant([[3., -3.]] * batch_size)
sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] *
batch_size)
- df_v = [3., 7.]
+ df_v = [5., 7.]
mu_v = [3., -3.]
sigma_v = [np.sqrt(10.), np.sqrt(15.)]
n = constant_op.constant(200000)
@@ -228,21 +229,21 @@ class StudentTTest(test.TestCase):
sample_values = self.evaluate(samples)
self.assertEqual(samples.get_shape(), (200000, batch_size, 2))
self.assertAllClose(
- sample_values[:, 0, 0].mean(), mu_v[0], rtol=1e-2, atol=0)
+ sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0)
self.assertAllClose(
sample_values[:, 0, 0].var(),
sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
- rtol=1e-1,
+ rtol=0.2,
atol=0)
self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
self.assertAllClose(
- sample_values[:, 0, 1].mean(), mu_v[1], rtol=1e-2, atol=0)
+ sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0)
self.assertAllClose(
sample_values[:, 0, 1].var(),
sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
- rtol=1e-1,
+ rtol=0.2,
atol=0)
- self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 1])
+ self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
def _checkKLApprox(self, df, mu, sigma, samples):
n = samples.size
@@ -272,7 +273,7 @@ class StudentTTest(test.TestCase):
self.assertEqual(student.entropy().get_shape(), (3,))
self.assertEqual(student.log_prob(2.).get_shape(), (3,))
self.assertEqual(student.prob(2.).get_shape(), (3,))
- self.assertEqual(student.sample(37, seed=123456).get_shape(), (37, 3,))
+ self.assertEqual(student.sample(37).get_shape(), (37, 3,))
_check(student_t.StudentT(df=[2., 3., 4.,], loc=2., scale=1.))
_check(student_t.StudentT(df=7., loc=[2., 3., 4.,], scale=1.))
@@ -445,15 +446,30 @@ class StudentTTest(test.TestCase):
self.assertEqual(samples.get_shape(), (num,))
self.assertEqual(pdfs.get_shape(), (num,))
self.assertEqual(mean.get_shape(), ())
- self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
+ self.assertNear(np.pi, np.mean(sample_vals), err=0.1)
self.assertNear(np.pi, mean_val, err=1e-6)
# Verify integral over sample*pdf ~= 1.
# Tolerance increased since eager was getting a value of 1.002041.
- self._assertIntegral(sample_vals, pdf_vals, err=3e-3)
+ self._assertIntegral(sample_vals, pdf_vals, err=5e-2)
if not stats:
return
self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
+ def testFullyReparameterized(self):
+ df = constant_op.constant(2.0)
+ mu = constant_op.constant(1.0)
+ sigma = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(df)
+ tape.watch(mu)
+ tape.watch(sigma)
+ student = student_t.StudentT(df=df, loc=mu, scale=sigma)
+ samples = student.sample(100)
+ grad_df, grad_mu, grad_sigma = tape.gradient(samples, [df, mu, sigma])
+ self.assertIsNotNone(grad_df)
+ self.assertIsNotNone(grad_mu)
+ self.assertIsNotNone(grad_sigma)
+
def testPdfOfSampleMultiDims(self):
student = student_t.StudentT(df=[7., 11.], loc=[[5.], [6.]], scale=3.)
self.assertAllEqual([], student.event_shape)
@@ -466,22 +482,22 @@ class StudentTTest(test.TestCase):
sample_vals, pdf_vals = self.evaluate([samples, pdfs])
self.assertEqual(samples.get_shape(), (num, 2, 2))
self.assertEqual(pdfs.get_shape(), (num, 2, 2))
- self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
- self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
- self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
- self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+ self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=0.1)
+ self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=0.1)
+ self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.05)
+ self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.05)
+ self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.05)
+ self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.05)
if not stats:
return
self.assertNear(
stats.t.var(7., loc=0., scale=3.), # loc d.n. effect var
np.var(sample_vals[:, :, 0]),
- err=.4)
+ err=1.0)
self.assertNear(
stats.t.var(11., loc=0., scale=3.), # loc d.n. effect var
np.var(sample_vals[:, :, 1]),
- err=.4)
+ err=1.0)
def _assertIntegral(self, sample_vals, pdf_vals, err=1.5e-3):
s_p = zip(sample_vals, pdf_vals)
diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py
index e74051c901..bc9c267b9a 100644
--- a/tensorflow/python/kernel_tests/distributions/uniform_test.py
+++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py
@@ -22,6 +22,7 @@ import importlib
import numpy as np
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
@@ -47,7 +48,7 @@ stats = try_import("scipy.stats")
class UniformTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformRange(self):
with self.test_session():
a = 3.0
@@ -57,7 +58,7 @@ class UniformTest(test.TestCase):
self.assertAllClose(b, self.evaluate(uniform.high))
self.assertAllClose(b - a, self.evaluate(uniform.range()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformPDF(self):
with self.test_session():
a = constant_op.constant([-3.0] * 5 + [15.0])
@@ -83,7 +84,7 @@ class UniformTest(test.TestCase):
log_pdf = uniform.log_prob(x)
self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformShape(self):
with self.test_session():
a = constant_op.constant([-3.0] * 5)
@@ -95,7 +96,7 @@ class UniformTest(test.TestCase):
self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), [])
self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformPDFWithScalarEndpoint(self):
with self.test_session():
a = constant_op.constant([0.0, 5.0])
@@ -108,7 +109,7 @@ class UniformTest(test.TestCase):
pdf = uniform.prob(x)
self.assertAllClose(expected_pdf, self.evaluate(pdf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformCDF(self):
with self.test_session():
batch_size = 6
@@ -132,7 +133,7 @@ class UniformTest(test.TestCase):
log_cdf = uniform.log_cdf(x)
self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformEntropy(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0])
@@ -142,7 +143,7 @@ class UniformTest(test.TestCase):
expected_entropy = np.log(b_v - a_v)
self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformAssertMaxGtMin(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
@@ -153,7 +154,7 @@ class UniformTest(test.TestCase):
uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True)
self.evaluate(uniform.low)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformSample(self):
with self.test_session():
a = constant_op.constant([3.0, 4.0])
@@ -168,15 +169,15 @@ class UniformTest(test.TestCase):
sample_values = self.evaluate(samples)
self.assertEqual(sample_values.shape, (100000, 2))
self.assertAllClose(
- sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-2)
+ sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.)
self.assertAllClose(
- sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-2)
+ sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.)
self.assertFalse(
np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v))
self.assertFalse(
np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def _testUniformSampleMultiDimensional(self):
# DISABLED: Please enable this test once b/issues/30149644 is resolved.
with self.test_session():
@@ -207,7 +208,7 @@ class UniformTest(test.TestCase):
self.assertAllClose(
sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformMean(self):
with self.test_session():
a = 10.0
@@ -218,7 +219,7 @@ class UniformTest(test.TestCase):
s_uniform = stats.uniform(loc=a, scale=b - a)
self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformVariance(self):
with self.test_session():
a = 10.0
@@ -229,7 +230,7 @@ class UniformTest(test.TestCase):
s_uniform = stats.uniform(loc=a, scale=b - a)
self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformStd(self):
with self.test_session():
a = 10.0
@@ -240,7 +241,7 @@ class UniformTest(test.TestCase):
s_uniform = stats.uniform(loc=a, scale=b - a)
self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformNans(self):
with self.test_session():
a = 10.0
@@ -258,7 +259,7 @@ class UniformTest(test.TestCase):
self.assertFalse(is_nan[0])
self.assertTrue(is_nan[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformSamplePdf(self):
with self.test_session():
a = 10.0
@@ -268,7 +269,7 @@ class UniformTest(test.TestCase):
self.evaluate(
math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformBroadcasting(self):
with self.test_session():
a = 10.0
@@ -279,7 +280,7 @@ class UniformTest(test.TestCase):
expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
self.assertAllClose(expected_pdf, self.evaluate(pdf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUniformSampleWithShape(self):
with self.test_session():
a = 10.0
@@ -299,6 +300,18 @@ class UniformTest(test.TestCase):
expected_pdf = [1.0, 0.1]
self.assertAllClose(expected_pdf, self.evaluate(pdf))
+ def testFullyReparameterized(self):
+ a = constant_op.constant(0.1)
+ b = constant_op.constant(0.8)
+ with backprop.GradientTape() as tape:
+ tape.watch(a)
+ tape.watch(b)
+ uniform = uniform_lib.Uniform(a, b)
+ samples = uniform.sample(100)
+ grad_a, grad_b = tape.gradient(samples, [a, b])
+ self.assertIsNotNone(grad_a)
+ self.assertIsNotNone(grad_b)
+
# Eager doesn't pass due to a type mismatch in one of the ops.
def testUniformFloat64(self):
uniform = uniform_lib.Uniform(
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index 2f256d3e8b..9d38ffcb4a 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -59,65 +59,6 @@ def _logit(x):
class AssertCloseTest(test.TestCase):
- def testAssertCloseIntegerDtype(self):
- x = array_ops.placeholder(dtypes.int32)
- y = x
- z = array_ops.placeholder(dtypes.int32)
- feed_dict = {x: [1, 5, 10, 15, 20], z: [2, 5, 10, 15, 20]}
- with self.test_session():
- with ops.control_dependencies([du.assert_close(x, y)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with ops.control_dependencies([du.assert_close(y, x)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(x, z)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(y, z)]):
- array_ops.identity(y).eval(feed_dict=feed_dict)
-
- def testAssertCloseNonIntegerDtype(self):
- x = array_ops.placeholder(dtypes.float32)
- y = x + 1e-8
- z = array_ops.placeholder(dtypes.float32)
- feed_dict = {x: [1., 5, 10, 15, 20], z: [2., 5, 10, 15, 20]}
- with self.test_session():
- with ops.control_dependencies([du.assert_close(x, y)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with ops.control_dependencies([du.assert_close(y, x)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(x, z)]):
- array_ops.identity(x).eval(feed_dict=feed_dict)
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(y, z)]):
- array_ops.identity(y).eval(feed_dict=feed_dict)
-
- @test_util.run_in_graph_and_eager_modes()
- def testAssertCloseEpsilon(self):
- x = [0., 5, 10, 15, 20]
- # x != y
- y = [0.1, 5, 10, 15, 20]
- # x = z
- z = [1e-8, 5, 10, 15, 20]
- with self.test_session():
- with ops.control_dependencies([du.assert_close(x, z)]):
- self.evaluate(array_ops.identity(x))
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(x, y)]):
- self.evaluate(array_ops.identity(x))
-
- with self.assertRaisesOpError("Condition x ~= y"):
- with ops.control_dependencies([du.assert_close(y, z)]):
- self.evaluate(array_ops.identity(y))
-
def testAssertIntegerForm(self):
# This should only be detected as an integer.
x = array_ops.placeholder(dtypes.float32)
@@ -150,21 +91,21 @@ class AssertCloseTest(test.TestCase):
class MaybeGetStaticTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetStaticInt(self):
x = 2
self.assertEqual(x, du.maybe_get_static_value(x))
self.assertAllClose(
np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetStaticNumpyArray(self):
x = np.array(2, dtype=np.int32)
self.assertEqual(x, du.maybe_get_static_value(x))
self.assertAllClose(
np.array(2.), du.maybe_get_static_value(x, dtype=np.float64))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetStaticConstant(self):
x = constant_op.constant(2, dtype=dtypes.int32)
self.assertEqual(np.array(2, dtype=np.int32), du.maybe_get_static_value(x))
@@ -179,7 +120,7 @@ class MaybeGetStaticTest(test.TestCase):
class GetLogitsAndProbsTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testImproperArguments(self):
with self.test_session():
with self.assertRaises(ValueError):
@@ -188,7 +129,7 @@ class GetLogitsAndProbsTest(test.TestCase):
with self.assertRaises(ValueError):
du.get_logits_and_probs(logits=[0.1], probs=[0.1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLogits(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
logits = _logit(p)
@@ -200,7 +141,7 @@ class GetLogitsAndProbsTest(test.TestCase):
self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.)
self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLogitsMultidimensional(self):
p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
logits = np.log(p)
@@ -212,7 +153,7 @@ class GetLogitsAndProbsTest(test.TestCase):
self.assertAllClose(self.evaluate(new_p), p)
self.assertAllClose(self.evaluate(new_logits), logits)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testProbability(self):
p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
@@ -223,7 +164,7 @@ class GetLogitsAndProbsTest(test.TestCase):
self.assertAllClose(_logit(p), self.evaluate(new_logits))
self.assertAllClose(p, self.evaluate(new_p))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testProbabilityMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
@@ -234,7 +175,7 @@ class GetLogitsAndProbsTest(test.TestCase):
self.assertAllClose(np.log(p), self.evaluate(new_logits))
self.assertAllClose(p, self.evaluate(new_p))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgs(self):
p = [0.01, 0.2, 0.5, 0.7, .99]
# Component less than 0.
@@ -265,7 +206,7 @@ class GetLogitsAndProbsTest(test.TestCase):
probs=p3, validate_args=False)
self.evaluate(prob)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testProbabilityValidateArgsMultidimensional(self):
p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
# Component less than 0. Still sums to 1.
@@ -367,7 +308,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase):
param)
checked_param.eval(feed_dict={param: np.ones([int(2**11+1)])})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUnsupportedDtype(self):
with self.test_session():
with self.assertRaises(TypeError):
@@ -552,7 +493,7 @@ class RotateTransposeTest(test.TestCase):
x = np.array(x)
return np.transpose(x, np.roll(np.arange(len(x.shape)), shift))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRollStatic(self):
with self.test_session():
if context.executing_eagerly():
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 159cba5fa3..c4d4ce780b 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
-from tensorflow.python.framework import dtypes
class DynamicStitchTestBase(object):
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index e53ca1dcaa..55d75cb474 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import itertools
+import math
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -31,6 +32,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
@@ -736,6 +738,222 @@ class EmbeddingLookupSparseTest(test.TestCase):
x, sp_ids, sp_weights, combiner="mean")
+class SafeEmbeddingLookupSparseTest(test.TestCase):
+
+ def _random_weights(self, vocab_size=4, embed_dim=4, num_shards=1):
+ assert vocab_size > 0
+ assert embed_dim > 0
+ assert num_shards > 0
+ assert num_shards <= vocab_size
+
+ embedding_weights = partitioned_variables.create_partitioned_variables(
+ shape=[vocab_size, embed_dim],
+ slicing=[num_shards, 1],
+ initializer=init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+ for w in embedding_weights:
+ w.initializer.run()
+ embedding_weights = [w.eval() for w in embedding_weights]
+ return embedding_weights
+
+ def _ids_and_weights_2d(self):
+ # Each row demonstrates a test case:
+ # Row 0: multiple valid ids, 1 invalid id, weighted mean
+ # Row 1: all ids are invalid (leaving no valid ids after pruning)
+ # Row 2: no ids to begin with
+ # Row 3: single id
+ # Row 4: all ids have <=0 weight
+ indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [4, 0], [4, 1]]
+ ids = [0, 1, -1, -1, 2, 0, 1]
+ weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
+ shape = [5, 4]
+
+ sparse_ids = sparse_tensor.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(ids, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ sparse_weights = sparse_tensor.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(weights, dtypes.float32),
+ constant_op.constant(shape, dtypes.int64))
+
+ return sparse_ids, sparse_weights
+
+ def _ids_and_weights_3d(self):
+ # Each (2-D) index demonstrates a test case:
+ # Index 0, 0: multiple valid ids, 1 invalid id, weighted mean
+ # Index 0, 1: all ids are invalid (leaving no valid ids after pruning)
+ # Index 0, 2: no ids to begin with
+ # Index 1, 0: single id
+ # Index 1, 1: all ids have <=0 weight
+ # Index 1, 2: no ids to begin with
+ indices = [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [1, 0, 0], [1, 1, 0],
+ [1, 1, 1]]
+ ids = [0, 1, -1, -1, 2, 0, 1]
+ weights = [1.0, 2.0, 1.0, 1.0, 3.0, 0.0, -0.5]
+ shape = [2, 3, 4]
+
+ sparse_ids = sparse_tensor.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(ids, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ sparse_weights = sparse_tensor.SparseTensor(
+ constant_op.constant(indices, dtypes.int64),
+ constant_op.constant(weights, dtypes.float32),
+ constant_op.constant(shape, dtypes.int64))
+
+ return sparse_ids, sparse_weights
+
+ def test_safe_embedding_lookup_sparse_return_zero_vector(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, sparse_weights = self._ids_and_weights_2d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, sparse_weights).eval())
+
+ self.assertAllClose(
+ embedding_lookup_result,
+ [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
+ 3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
+
+ def test_safe_embedding_lookup_sparse_return_special_vector(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, sparse_weights = self._ids_and_weights_2d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, sparse_weights, default_id=3).eval())
+
+ self.assertAllClose(
+ embedding_lookup_result,
+ [(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
+ 3.0, embedding_weights[0][3], embedding_weights[0][3],
+ embedding_weights[0][2], embedding_weights[0][3]])
+
+ def test_safe_embedding_lookup_sparse_no_weights(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, _ = self._ids_and_weights_2d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, None).eval())
+
+ self.assertAllClose(
+ embedding_lookup_result,
+ [(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
+ [0] * 4, embedding_weights[0][2], (
+ embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
+
+ def test_safe_embedding_lookup_sparse_partitioned(self):
+ with self.test_session():
+ embedding_weights = self._random_weights(num_shards=3)
+ sparse_ids, _ = self._ids_and_weights_2d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, None).eval())
+
+ embedding_weights = list(itertools.chain(*embedding_weights))
+ self.assertAllClose(embedding_lookup_result,
+ [(embedding_weights[0] + embedding_weights[1]) / 2.0,
+ [0] * 4, [0] * 4, embedding_weights[2],
+ (embedding_weights[0] + embedding_weights[1]) / 2.0])
+
+ def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
+ with self.test_session():
+ embedding_weights = self._random_weights(num_shards=3)
+ sparse_ids, sparse_weights = self._ids_and_weights_2d()
+
+ embedding_weights[1] = embedding_weights[1].astype(np.float64)
+ self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse,
+ embedding_weights, sparse_ids)
+ embedding_weights = [
+ constant_op.constant(w, dtype=dtypes.float64)
+ for w in embedding_weights
+ ]
+ self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
+ embedding_weights, sparse_ids, sparse_weights)
+
+ def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, sparse_weights = self._ids_and_weights_3d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, sparse_weights).eval())
+
+ self.assertAllClose(embedding_lookup_result, [[
+ (1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
+ [0] * 4, [0] * 4
+ ], [embedding_weights[0][2], [0] * 4, [0] * 4]])
+
+ def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, sparse_weights = self._ids_and_weights_3d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, sparse_weights, default_id=3).eval())
+
+ self.assertAllClose(
+ embedding_lookup_result,
+ [[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
+ 3.0, embedding_weights[0][3], embedding_weights[0][3]], [
+ embedding_weights[0][2], embedding_weights[0][3],
+ embedding_weights[0][3]
+ ]])
+
+ def test_safe_embedding_lookup_sparse_3d_no_weights(self):
+ with self.test_session():
+ embedding_weights = self._random_weights()
+ sparse_ids, _ = self._ids_and_weights_3d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, None).eval())
+
+ self.assertAllClose(embedding_lookup_result, [[(
+ embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [
+ 0
+ ] * 4], [
+ embedding_weights[0][2],
+ (embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4
+ ]])
+
+ def test_safe_embedding_lookup_sparse_3d_partitioned(self):
+ with self.test_session():
+ embedding_weights = self._random_weights(num_shards=3)
+ sparse_ids, _ = self._ids_and_weights_3d()
+
+ embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
+ embedding_weights, sparse_ids, None).eval())
+
+ embedding_weights = list(itertools.chain(*embedding_weights))
+ self.assertAllClose(embedding_lookup_result, [[
+ (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4
+ ], [
+ embedding_weights[2],
+ (embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4
+ ]])
+
+ def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
+ self):
+ with self.test_session():
+ embedding_weights = self._random_weights(num_shards=3)
+ sparse_ids, sparse_weights = self._ids_and_weights_3d()
+
+ embedding_weights[1] = embedding_weights[1].astype(np.float64)
+ self.assertRaises(TypeError, embedding_ops.safe_embedding_lookup_sparse,
+ embedding_weights, sparse_ids)
+ embedding_weights = [
+ constant_op.constant(w, dtype=dtypes.float64)
+ for w in embedding_weights
+ ]
+ self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
+ embedding_weights, sparse_ids, sparse_weights)
+
+
class DynamicStitchOpTest(test.TestCase):
def testCint32Cpu(self):
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index ce73e7ad3e..9e7b528338 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -125,12 +126,21 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
+ @test_util.run_in_graph_and_eager_modes
def testMultipleDequeues(self):
- with self.test_session() as session:
- q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
- q.enqueue_many([[1, 2, 3]]).run()
- a, b, c = session.run([q.dequeue(), q.dequeue(), q.dequeue()])
- self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue_many([[1, 2, 3]]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testQueuesDontShare(self):
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index facadc971f..24800d2b7a 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
@@ -56,7 +57,7 @@ def simple_scoped_fn(a, x):
class FunctionalOpsTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldl_Simple(self):
with self.test_session():
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -72,7 +73,7 @@ class FunctionalOpsTest(test.TestCase):
initializer=10)
self.assertAllEqual(880, self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldl_SingleInputMultiOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -83,7 +84,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(22, r_value[0])
self.assertAllEqual(20, r_value[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldl_MultiInputSingleOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -111,7 +112,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertEqual(len(variables.trainable_variables()), 1)
self.assertAllEqual(880, self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldr_Simple(self):
with self.test_session():
elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -127,7 +128,7 @@ class FunctionalOpsTest(test.TestCase):
initializer=10)
self.assertAllEqual(1282, self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldr_SingleInputMultiOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -138,7 +139,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(22, r_value[0])
self.assertAllEqual(20, r_value[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldr_MultiInputSingleOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -182,7 +183,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(720.0, self.evaluate(r))
# pylint: enable=unnecessary-lambda
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_Simple(self):
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
@@ -202,7 +203,7 @@ class FunctionalOpsTest(test.TestCase):
values=constant_op.constant([0, 1, 2]),
dense_shape=[2, 2]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMapOverScalarErrors(self):
with self.assertRaisesRegexp(ValueError, "not scalars"):
functional_ops.map_fn(lambda x: x, [1, 2])
@@ -251,7 +252,7 @@ class FunctionalOpsTest(test.TestCase):
r = gradients_impl.gradients(y, elems)[0]
self.assertAllEqual([4.0, 8.0, 12.0, 16.0, 20.0, 24.0], self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_SimpleNotTensor(self):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
@@ -260,7 +261,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(
np.array([(x + 3) * 2 for x in nums]), self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_SingleInputMultiOutput(self):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
@@ -275,7 +276,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual((nums + 3) * 2, received[0])
self.assertAllEqual(-(nums + 3) * 2, received[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_MultiOutputMismatchedDtype(self):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
@@ -287,7 +288,7 @@ class FunctionalOpsTest(test.TestCase):
nums,
dtype=[dtypes.int64, dtypes.int64])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSingleOutput(self):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
@@ -298,7 +299,7 @@ class FunctionalOpsTest(test.TestCase):
received = self.evaluate(r)
self.assertAllEqual(nums * nums + (-nums), received)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMap_MultiInputSameStructureOutput(self):
with self.test_session():
nums = np.array([1, 2, 3, 4, 5, 6])
@@ -313,7 +314,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(-nums, received[1])
self.assertAllEqual(nums, received[2])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_Simple(self):
with self.test_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
@@ -328,7 +329,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
# pylint: enable=unnecessary-lambda
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_Reverse(self):
with self.test_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
@@ -345,7 +346,7 @@ class FunctionalOpsTest(test.TestCase):
self.evaluate(r))
# pylint: enable=unnecessary-lambda
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_SingleInputMultiOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -357,7 +358,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0])
self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSingleOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -367,7 +368,7 @@ class FunctionalOpsTest(test.TestCase):
(elems + 1, -elems), initializer)
self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_MultiInputSameTypeOutput(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -377,7 +378,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(np.cumsum(elems), r_value[0])
self.assertAllEqual(np.cumsum(-elems), r_value[1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScan_MultiOutputMismatchedInitializer(self):
with self.test_session():
elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
@@ -408,7 +409,7 @@ class FunctionalOpsTest(test.TestCase):
results = np.array([6, 16, 38, 84, 178, 368])
self.assertAllEqual(results, self.evaluate(r))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScanFoldl_Nested(self):
with self.test_session():
elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data")
@@ -467,7 +468,7 @@ class FunctionalOpsTest(test.TestCase):
variables.global_variables_initializer().run()
sess.run(grad)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFoldShape(self):
with self.test_session():
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
@@ -479,7 +480,7 @@ class FunctionalOpsTest(test.TestCase):
y = functional_ops.foldl(fn, x, initializer=initializer)
self.assertAllEqual(y.get_shape(), self.evaluate(y).shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMapShape(self):
with self.test_session():
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
@@ -491,7 +492,7 @@ class FunctionalOpsTest(test.TestCase):
y = functional_ops.map_fn(lambda e: e, x)
self.assertIs(None, y.get_shape().dims)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMapEmptyScalar(self):
with self.test_session():
map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([]))
@@ -507,7 +508,7 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual([0, 3, 2], map_return.get_shape().dims)
self.assertAllEqual([0, 3, 2], self.evaluate(map_return).shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScanShape(self):
with self.test_session():
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
@@ -604,6 +605,25 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op)
self.assertEqual(mul, [6])
+ def testRemoteFunctionSameDeviceDirectSession(self):
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def _remote_fn(a, b):
+ return math_ops.multiply(a, b)
+
+ with ops.device("/cpu:0"):
+ a = variables.Variable(2, dtype=dtypes.int32)
+ b = variables.Variable(3, dtype=dtypes.int32)
+
+ with ops.device("/cpu:0"):
+ remote_op = functional_ops.remote_call(
+ args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0")
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ mul = sess.run(remote_op)
+ self.assertEqual(mul, [6])
+
def testRemoteFunctionCPUGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@@ -652,6 +672,24 @@ class FunctionalOpsTest(test.TestCase):
mul = sess.run(remote_op)
self.assertEqual(mul, 9.0)
+ def testRemoteFunctionGPUCPUStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(inp):
+ return array_ops.identity(inp)
+
+ a = array_ops.constant("a")
+
+ with ops.device("/gpu:0"):
+ remote_op = functional_ops.remote_call(
+ args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0")
+
+ with self.test_session() as sess:
+ ret = sess.run(remote_op)
+ self.assertAllEqual(ret, [b"a"])
+
def testRemoteFunctionCrossProcess(self):
workers, _ = test_util.create_local_cluster(2, 1)
@@ -1043,6 +1081,56 @@ class PartitionedCallTest(test.TestCase):
self.assertTrue(compat.as_bytes("CPU:1") in outputs[1].eval())
self.assertTrue(compat.as_bytes("CPU:2") in outputs[2].eval())
+ def testAssignAddResourceVariable(self):
+
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.Defun()
+ def AssignAdd():
+ v.assign_add(1.0)
+
+ op = functional_ops.partitioned_call(
+ args=AssignAdd.captured_inputs, f=AssignAdd)
+ _ = self.evaluate(variables.global_variables_initializer())
+ _ = self.evaluate(op)
+ value = self.evaluate(v.read_value())
+ self.assertEqual(value, 2.0)
+
+ def testFunctionWithResourcesOnDifferentDevices(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPUs available.")
+
+ with ops.device("/cpu:0"):
+ v_cpu_zero = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name="v_cpu_zero")
+
+ with ops.device("/cpu:1"):
+ v_cpu_one = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name="v_cpu_one")
+
+ with ops.device("/gpu:0"):
+ v_gpu = resource_variable_ops.ResourceVariable(
+ [0.0, 1.0, 2.0], name="v_gpu")
+
+ def sum_gather():
+ cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_zero, [1, 2]))
+ also_cpu_result = math_ops.reduce_sum(array_ops.gather(v_cpu_one, [1, 2]))
+ gpu_result = math_ops.reduce_sum(array_ops.gather(v_gpu, [1, 2]))
+ return cpu_result, also_cpu_result, gpu_result
+
+ defined = function.Defun()(sum_gather)
+ with self.test_session(
+ config=config_pb2.ConfigProto(
+ allow_soft_placement=False,
+ log_device_placement=True,
+ device_count={"CPU": 2})) as sess:
+ sess.run(variables.global_variables_initializer())
+ expected = sess.run(sum_gather())
+ result = sess.run(
+ functional_ops.partitioned_call(
+ args=defined.captured_inputs, f=defined))
+ self.assertAllEqual(expected, result)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index a9b55854f1..f6097ad489 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -362,6 +362,71 @@ class UniformUnitScalingInitializationTest(test.TestCase):
dtype=dtypes.string)
+class VarianceScalingInitializationTest(test.TestCase):
+
+ def testTruncatedNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(
+ distribution='truncated_normal')
+
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+ as mock_truncated_normal:
+ x = init(shape).eval()
+ self.assertTrue(mock_truncated_normal.called)
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+ def testNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(distribution='normal')
+
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'truncated_normal', wraps=random_ops.truncated_normal) \
+ as mock_truncated_normal:
+ x = init(shape).eval()
+ self.assertTrue(mock_truncated_normal.called)
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+ def testUntruncatedNormalDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(
+ distribution='untruncated_normal')
+
+ with self.test_session(use_gpu=True), \
+ test.mock.patch.object(
+ random_ops, 'random_normal', wraps=random_ops.random_normal) \
+ as mock_random_normal:
+ x = init(shape).eval()
+ self.assertTrue(mock_random_normal.called)
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+ def testUniformDistribution(self):
+ shape = [100, 100]
+ expect_mean = 0.
+ expect_var = 1. / shape[0]
+ init = init_ops.variance_scaling_initializer(distribution='uniform')
+
+ with self.test_session(use_gpu=True):
+ x = init(shape).eval()
+
+ self.assertNear(np.mean(x), expect_mean, err=1e-2)
+ self.assertNear(np.var(x), expect_var, err=1e-2)
+
+
# TODO(vrv): move to sequence_ops_test?
class RangeTest(test.TestCase):
@@ -765,7 +830,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1], [2], [3], [4], [5], [6]]:
convolution = convolutional.conv1d
inputs = random_ops.random_normal(shape, dtype=dtype)
@@ -860,7 +925,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]:
convolution = convolutional.conv2d
inputs = random_ops.random_normal(shape, dtype=dtype)
@@ -985,7 +1050,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
tol = 1e-3
gain = 3.14
# Check orthogonality/isometry by computing the ratio between
- # the 2-norms of the inputs and ouputs.
+ # the 2-norms of the inputs and outputs.
for kernel_size in [[1, 1, 1], [2, 2, 2], [3, 3, 3]]:
convolution = convolutional.conv3d
inputs = random_ops.random_normal(shape, dtype=dtype)
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 91be80322c..69d3aa4017 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -107,6 +107,10 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -124,6 +128,10 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -140,6 +148,10 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -177,6 +189,10 @@ cuda_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -213,4 +229,8 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
index 2b80f01b73..3ede2aceaa 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
@@ -80,7 +80,7 @@ class SquareLinearOperatorBlockDiagTest(
build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
expected_blocks = (
build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
@@ -91,26 +91,19 @@ class SquareLinearOperatorBlockDiagTest(
for block_shape in expected_blocks
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in expected_blocks
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = self.evaluate(matrices)
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorFullMatrix(
- m_ph, is_square=True) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorFullMatrix(
- m, is_square=True) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = block_diag.LinearOperatorBlockDiag(
+ [linalg.LinearOperatorFullMatrix(
+ l, is_square=True) for l in lin_op_matrices])
+
+ # Should be auto-set.
+ self.assertTrue(operator.is_square)
# Broadcast the shapes.
expected_shape = list(build_info.shape)
@@ -123,7 +116,7 @@ class SquareLinearOperatorBlockDiagTest(
block_diag_dense.set_shape(
expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])
- return operator, block_diag_dense, feed_dict
+ return operator, block_diag_dense
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 5713d16969..7261d4bb3b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -95,7 +95,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
# real, the matrix will not be real.
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating real spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -107,22 +107,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
# zero, so the operator will still be self-adjoint.
spectrum = math_ops.cast(spectrum, dtype)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, is_self_adjoint=True, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -149,7 +145,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating Hermitian spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -172,22 +168,18 @@ class LinearOperatorCirculantTestHermitianSpectrum(
spectrum = math_ops.fft(h_c)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -213,7 +205,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# Will be well conditioned enough to get accurate solves.
spectrum = linear_operator_test_util.random_sign_uniform(
@@ -222,22 +214,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
minval=1.,
maxval=2.)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -432,7 +420,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating Hermitian spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -455,22 +443,18 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
spectrum = math_ops.fft2d(h_c)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant2D(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant2D(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant2D(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
class LinearOperatorCirculant2DTestNonHermitianSpectrum(
@@ -486,7 +470,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# Will be well conditioned enough to get accurate solves.
spectrum = linear_operator_test_util.random_sign_uniform(
@@ -495,22 +479,18 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
minval=1.,
maxval=2.)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant2D(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant2D(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant2D(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
index f96b9ccdaa..612a50bcec 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
@@ -44,7 +44,7 @@ class SquareLinearOperatorCompositionTest(
self._rtol[dtypes.float32] = 1e-4
self._rtol[dtypes.complex64] = 1e-4
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
sess = ops.get_default_session()
shape = list(build_info.shape)
@@ -56,33 +56,23 @@ class SquareLinearOperatorCompositionTest(
for _ in range(num_operators)
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in range(num_operators)
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = sess.run(matrices)
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated each matrix to a numpy array.
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = linalg.LinearOperatorComposition(
+ [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices],
+ is_square=True)
+
matmul_order_list = list(reversed(matrices))
- mat = ops.convert_to_tensor(matmul_order_list[0])
+ mat = matmul_order_list[0]
for other_mat in matmul_order_list[1:]:
mat = math_ops.matmul(other_mat, mat)
- return operator, mat, feed_dict
+ return operator, mat
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
@@ -148,7 +138,7 @@ class NonSquareLinearOperatorCompositionTest(
self._rtol[dtypes.float32] = 1e-4
self._rtol[dtypes.complex64] = 1e-4
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
sess = ops.get_default_session()
shape = list(build_info.shape)
@@ -170,30 +160,22 @@ class NonSquareLinearOperatorCompositionTest(
shape_2, dtype=dtype)
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in range(num_operators)
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = sess.run(matrices)
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph])
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m) for m in matrices])
- feed_dict = None
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated each matrix to a numpy array.
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = linalg.LinearOperatorComposition(
+ [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices])
+
matmul_order_list = list(reversed(matrices))
- mat = ops.convert_to_tensor(matmul_order_list[0])
+ mat = matmul_order_list[0]
for other_mat in matmul_order_list[1:]:
mat = math_ops.matmul(other_mat, mat)
- return operator, mat, feed_dict
+ return operator, mat
def test_static_shapes(self):
operators = [
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 0a0e31c716..83cc8c483f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -34,25 +34,21 @@ class LinearOperatorDiagTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
diag = linear_operator_test_util.random_sign_uniform(
shape[:-1], minval=1., maxval=2., dtype=dtype)
+
+ lin_op_diag = diag
+
if use_placeholder:
- diag_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate the diag here because (i) you cannot feed a tensor, and (ii)
- # diag is random and we want the same value used for both mat and
- # feed_dict.
- diag = diag.eval()
- operator = linalg.LinearOperatorDiag(diag_ph)
- feed_dict = {diag_ph: diag}
- else:
- operator = linalg.LinearOperatorDiag(diag)
- feed_dict = None
+ lin_op_diag = array_ops.placeholder_with_default(diag, shape=None)
+
+ operator = linalg.LinearOperatorDiag(lin_op_diag)
- mat = array_ops.matrix_diag(diag)
+ matrix = array_ops.matrix_diag(diag)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
# Matrix with one positive eigenvalue and one zero eigenvalue.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
index b3da623b5e..1a40a29ec6 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -36,30 +35,20 @@ class SquareLinearOperatorFullMatrixTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_positive_definite_matrix(
shape, dtype)
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True)
- feed_dict = {matrix_ph: matrix}
- else:
- # is_square should be auto-detected here.
- operator = linalg.LinearOperatorFullMatrix(matrix)
- feed_dict = None
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
- return operator, mat, feed_dict
+ return operator, matrix
def test_is_x_flags(self):
# Matrix with two positive eigenvalues.
@@ -136,32 +125,20 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.float64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_positive_definite_matrix(
shape, dtype, force_well_conditioned=True)
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- # is_square is auto-set because of self_adjoint/pd.
- operator = linalg.LinearOperatorFullMatrix(
- matrix_ph, is_self_adjoint=True, is_positive_definite=True)
- feed_dict = {matrix_ph: matrix}
- else:
- operator = linalg.LinearOperatorFullMatrix(
- matrix, is_self_adjoint=True, is_positive_definite=True)
- feed_dict = None
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
-
- return operator, mat, feed_dict
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
+
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
+
+ return operator, matrix
def test_is_x_flags(self):
# Matrix with two positive eigenvalues.
@@ -210,26 +187,18 @@ class NonSquareLinearOperatorFullMatrixTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_normal(shape, dtype=dtype)
+
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- operator = linalg.LinearOperatorFullMatrix(matrix_ph)
- feed_dict = {matrix_ph: matrix}
- else:
- operator = linalg.LinearOperatorFullMatrix(matrix)
- feed_dict = None
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
- return operator, mat, feed_dict
+ return operator, matrix
def test_is_x_flags(self):
matrix = [[3., 2., 1.], [1., 1., 1.]]
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index 59f63f949e..35dcf4417c 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -43,7 +43,7 @@ class LinearOperatorIdentityTest(
# 16bit.
return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
assert shape[-1] == shape[-2]
@@ -54,13 +54,7 @@ class LinearOperatorIdentityTest(
num_rows, batch_shape=batch_shape, dtype=dtype)
mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
- # Nothing to feed since LinearOperatorIdentity takes no Tensor args.
- if use_placeholder:
- feed_dict = {}
- else:
- feed_dict = None
-
- return operator, mat, feed_dict
+ return operator, mat
def test_assert_positive_definite(self):
with self.test_session():
@@ -261,7 +255,7 @@ class LinearOperatorScaledIdentityTest(
# 16bit.
return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
assert shape[-1] == shape[-2]
@@ -274,24 +268,23 @@ class LinearOperatorScaledIdentityTest(
multiplier = linear_operator_test_util.random_sign_uniform(
shape=batch_shape, minval=1., maxval=2., dtype=dtype)
- operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier)
# Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args.
+ lin_op_multiplier = multiplier
+
if use_placeholder:
- multiplier_ph = array_ops.placeholder(dtype=dtype)
- multiplier = multiplier.eval()
- operator = linalg_lib.LinearOperatorScaledIdentity(
- num_rows, multiplier_ph)
- feed_dict = {multiplier_ph: multiplier}
- else:
- feed_dict = None
+ lin_op_multiplier = array_ops.placeholder_with_default(
+ multiplier, shape=None)
+
+ operator = linalg_lib.LinearOperatorScaledIdentity(
+ num_rows, lin_op_multiplier)
multiplier_matrix = array_ops.expand_dims(
array_ops.expand_dims(multiplier, -1), -1)
- mat = multiplier_matrix * linalg_ops.eye(
+ matrix = multiplier_matrix * linalg_ops.eye(
num_rows, batch_shape=batch_shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_positive_definite_does_not_raise_when_positive(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index 784c730bbc..e26b946151 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -101,7 +101,7 @@ class SquareLinearOperatorKroneckerTest(
def _tests_to_skip(self):
return ["det", "solve", "solve_with_broadcast"]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
expected_factors = build_info.__dict__["factors"]
matrices = [
@@ -110,26 +110,15 @@ class SquareLinearOperatorKroneckerTest(
for block_shape in expected_factors
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in expected_factors
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = self.evaluate(matrices)
- operator = kronecker.LinearOperatorKronecker(
- [linalg.LinearOperatorFullMatrix(
- m_ph, is_square=True) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = kronecker.LinearOperatorKronecker(
- [linalg.LinearOperatorFullMatrix(
- m, is_square=True) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(m, shape=None) for m in matrices]
+
+ operator = kronecker.LinearOperatorKronecker(
+ [linalg.LinearOperatorFullMatrix(
+ l, is_square=True) for l in lin_op_matrices])
matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)
@@ -138,7 +127,7 @@ class SquareLinearOperatorKroneckerTest(
if not use_placeholder:
kronecker_dense.set_shape(shape)
- return operator, kronecker_dense, feed_dict
+ return operator, kronecker_dense
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 8095f6419e..0e38dbd48d 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -49,12 +49,6 @@ class BaseLinearOperatorLowRankUpdatetest(object):
_use_v = None
@property
- def _dtypes_to_test(self):
- # TODO(langmore) Test complex types once cholesky works with them.
- # See comment in LinearOperatorLowRankUpdate.__init__.
- return [dtypes.float32, dtypes.float64]
-
- @property
def _operator_build_infos(self):
build_info = linear_operator_test_util.OperatorBuildInfo
# Previously we had a (2, 10, 10) shape at the end. We did this to test the
@@ -68,7 +62,16 @@ class BaseLinearOperatorLowRankUpdatetest(object):
build_info((3, 4, 4)),
build_info((2, 1, 4, 4))]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _gen_positive_diag(self, dtype, diag_shape):
+ if dtype.is_complex:
+ diag = linear_operator_test_util.random_uniform(
+ diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32)
+ return math_ops.cast(diag, dtype=dtype)
+
+ return linear_operator_test_util.random_uniform(
+ diag_shape, minval=1e-4, maxval=1., dtype=dtype)
+
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
# Recall A = L + UDV^H
shape = list(build_info.shape)
diag_shape = shape[:-1]
@@ -78,63 +81,46 @@ class BaseLinearOperatorLowRankUpdatetest(object):
# base_operator L will be a symmetric positive definite diagonal linear
# operator, with condition number as high as 1e4.
- base_diag = linear_operator_test_util.random_uniform(
- diag_shape, minval=1e-4, maxval=1., dtype=dtype)
- base_diag_ph = array_ops.placeholder(dtype=dtype)
+ base_diag = self._gen_positive_diag(dtype, diag_shape)
+ lin_op_base_diag = base_diag
# U
u = linear_operator_test_util.random_normal_correlated_columns(
u_perturbation_shape, dtype=dtype)
- u_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_u = u
# V
v = linear_operator_test_util.random_normal_correlated_columns(
u_perturbation_shape, dtype=dtype)
- v_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_v = v
# D
if self._is_diag_update_positive:
- diag_update = linear_operator_test_util.random_uniform(
- diag_update_shape, minval=1e-4, maxval=1., dtype=dtype)
+ diag_update = self._gen_positive_diag(dtype, diag_update_shape)
else:
diag_update = linear_operator_test_util.random_normal(
diag_update_shape, stddev=1e-4, dtype=dtype)
- diag_update_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_diag_update = diag_update
if use_placeholder:
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- base_diag = base_diag.eval()
- u = u.eval()
- v = v.eval()
- diag_update = diag_update.eval()
-
- # In all cases, set base_operator to be positive definite.
- base_operator = linalg.LinearOperatorDiag(
- base_diag_ph, is_positive_definite=True)
-
- operator = linalg.LinearOperatorLowRankUpdate(
- base_operator,
- u=u_ph,
- v=v_ph if self._use_v else None,
- diag_update=diag_update_ph if self._use_diag_update else None,
- is_diag_update_positive=self._is_diag_update_positive)
- feed_dict = {
- base_diag_ph: base_diag,
- u_ph: u,
- v_ph: v,
- diag_update_ph: diag_update}
- else:
- base_operator = linalg.LinearOperatorDiag(
- base_diag, is_positive_definite=True)
- operator = linalg.LinearOperatorLowRankUpdate(
- base_operator,
- u,
- v=v if self._use_v else None,
- diag_update=diag_update if self._use_diag_update else None,
- is_diag_update_positive=self._is_diag_update_positive)
- feed_dict = None
+ lin_op_base_diag = array_ops.placeholder_with_default(
+ base_diag, shape=None)
+ lin_op_u = array_ops.placeholder_with_default(u, shape=None)
+ lin_op_v = array_ops.placeholder_with_default(v, shape=None)
+ lin_op_diag_update = array_ops.placeholder_with_default(
+ diag_update, shape=None)
+
+ base_operator = linalg.LinearOperatorDiag(
+ lin_op_base_diag,
+ is_positive_definite=True,
+ is_self_adjoint=True)
+
+ operator = linalg.LinearOperatorLowRankUpdate(
+ base_operator,
+ lin_op_u,
+ v=lin_op_v if self._use_v else None,
+ diag_update=lin_op_diag_update if self._use_diag_update else None,
+ is_diag_update_positive=self._is_diag_update_positive)
# The matrix representing L
base_diag_mat = array_ops.matrix_diag(base_diag)
@@ -146,28 +132,28 @@ class BaseLinearOperatorLowRankUpdatetest(object):
if self._use_v and self._use_diag_update:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
- mat = base_diag_mat + math_ops.matmul(
+ matrix = base_diag_mat + math_ops.matmul(
u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
elif self._use_v:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
- mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
+ matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
elif self._use_diag_update:
# In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
expect_use_cholesky = self._is_diag_update_positive
- mat = base_diag_mat + math_ops.matmul(
+ matrix = base_diag_mat + math_ops.matmul(
u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
else:
# In this case, we have L + UU^H, which is PD since L > 0.
expect_use_cholesky = True
- mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)
+ matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)
if expect_use_cholesky:
self.assertTrue(operator._use_cholesky)
else:
self.assertFalse(operator._use_cholesky)
- return operator, mat, feed_dict
+ return operator, matrix
class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
@@ -186,6 +172,7 @@ class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
self._rtol[dtypes.float32] = 1e-5
self._atol[dtypes.float64] = 1e-10
self._rtol[dtypes.float64] = 1e-10
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
@@ -205,6 +192,7 @@ class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
self._rtol[dtypes.float32] = 1e-4
self._atol[dtypes.float64] = 1e-9
self._rtol[dtypes.float64] = 1e-9
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
@@ -223,6 +211,7 @@ class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
self._rtol[dtypes.float32] = 1e-5
self._atol[dtypes.float64] = 1e-10
self._rtol[dtypes.float64] = 1e-10
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
@@ -242,6 +231,7 @@ class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
self._rtol[dtypes.float32] = 1e-4
self._atol[dtypes.float64] = 1e-9
self._rtol[dtypes.float64] = 1e-9
+ self._rtol[dtypes.complex64] = 1e-4
class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index a57d2f085e..b389e0cbdf 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
@@ -32,34 +31,23 @@ class LinearOperatorLowerTriangularTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- @property
- def _dtypes_to_test(self):
- # TODO(langmore) Test complex types once supported by
- # matrix_triangular_solve.
- return [dtypes.float32, dtypes.float64]
-
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
# Upper triangle will be nonzero, but ignored.
# Use a diagonal that ensures this matrix is well conditioned.
tril = linear_operator_test_util.random_tril_matrix(
shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)
+ lin_op_tril = tril
+
if use_placeholder:
- tril_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate the tril here because (i) you cannot feed a tensor, and (ii)
- # tril is random and we want the same value used for both mat and
- # feed_dict.
- tril = tril.eval()
- operator = linalg.LinearOperatorLowerTriangular(tril_ph)
- feed_dict = {tril_ph: tril}
- else:
- operator = linalg.LinearOperatorLowerTriangular(tril)
- feed_dict = None
+ lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None)
+
+ operator = linalg.LinearOperatorLowerTriangular(lin_op_tril)
- mat = array_ops.matrix_band_part(tril, -1, 0)
+ matrix = array_ops.matrix_band_part(tril, -1, 0)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_non_singular(self):
# Singlular matrix with one positive eigenvalue and one zero eigenvalue.
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 7d367a9275..6f401358a2 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -177,6 +177,12 @@ if __name__ == '__main__':
MatrixUnaryFunctorGradientTest, 'MatrixDeterminantGradient', name,
_GetMatrixUnaryFunctorGradientTest(linalg_ops.matrix_determinant,
dtype, shape))
+ _AddTest(
+ MatrixUnaryFunctorGradientTest, 'LogMatrixDeterminantGradient',
+ name,
+ _GetMatrixUnaryFunctorGradientTest(
+ lambda x: linalg_ops.log_matrix_determinant(x)[1],
+ dtype, shape))
# Tests for gradients of matrix_solve_ls
for dtype in np.float32, np.float64:
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 49855200c2..bf82e08551 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -46,7 +46,7 @@ def scalar_shape():
@test_util.with_c_shapes
class ListOpsTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPushPop(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
@@ -54,14 +54,14 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.assertAllEqual(self.evaluate(e), 1.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPushPopGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testPushPop()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testStack(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
@@ -70,14 +70,14 @@ class ListOpsTest(test_util.TensorFlowTestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testStackGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testStack()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorListFromTensor(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
@@ -87,14 +87,14 @@ class ListOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(self.evaluate(e), 1.0)
self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testFromTensorGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testTensorListFromTensor()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetSetItem(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
@@ -104,14 +104,14 @@ class ListOpsTest(test_util.TensorFlowTestCase):
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
self.assertAllEqual(self.evaluate(t), [3.0, 2.0])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetSetGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testGetSetItem()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testUnknownShape(self):
l = list_ops.empty_tensor_list(
element_dtype=dtypes.float32, element_shape=-1)
@@ -122,7 +122,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
self.assertAllEqual(self.evaluate(e), 1.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCPUGPUCopy(self):
if not context.num_gpus():
return
@@ -140,7 +140,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGraphStack(self):
with context.graph_mode(), self.test_session():
tl = list_ops.empty_tensor_list(
@@ -152,7 +152,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
[[1]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoop(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -170,7 +170,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGraphStackSwitchDtype(self):
with context.graph_mode(), self.test_session():
list_ = list_ops.empty_tensor_list(
@@ -192,7 +192,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
self.assertAllEqual(self.evaluate(s1), np_s1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGraphStackInLoopSwitchDtype(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -216,7 +216,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
self.assertAllEqual(self.evaluate(s1), np_s1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSerialize(self):
# pylint: disable=g-import-not-at-top
try:
@@ -248,7 +248,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
worker_e = array_ops.identity(e)
self.assertAllEqual(self.evaluate(worker_e), [2.0])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPushPopGradients(self):
with backprop.GradientTape() as tape:
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
@@ -260,7 +260,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
e = 2 * e
self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testStackFromTensorGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
@@ -272,7 +272,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
grad = tape.gradient(result, [c])[0]
self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetSetGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
@@ -288,14 +288,14 @@ class ListOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
self.assertAllEqual(self.evaluate(grad_c2), 6.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSetOutOfBounds(self):
c = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testResourceVariableScatterGather(self):
c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
@@ -319,7 +319,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
[[1.0, 2.0]] * 4)
self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConcat(self):
c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
@@ -379,7 +379,7 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
element_dtype=dtypes.float32))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPushBackBatch(self):
c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 28c85fa13a..e635a71c78 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -59,7 +59,7 @@ class LoggingOpsTest(test.TestCase):
class PrintGradientTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPrintShape(self):
inp = constant_op.constant(2.0, shape=[100, 32])
inp_printed = logging_ops.Print(inp, [inp])
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 1123c20a16..87fc715783 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -118,6 +119,14 @@ class AbsoluteDifferenceLossTest(test.TestCase):
with self.test_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerNoMemoryLeaked(self):
+ # This is a somewhat convoluted way of testing that nothing gets added to
+ # a global collection.
+ predictions = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
+ labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ losses.absolute_difference(labels, predictions)
+
class SoftmaxCrossEntropyLossTest(test.TestCase):
@@ -246,6 +255,13 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 0.0, 3)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerNoMemoryLeaked(self):
+ logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
+ losses.sparse_softmax_cross_entropy(labels, logits)
+
def testAllCorrectInt64Labels(self):
with self.test_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index a0c372db7d..e95c729715 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -947,7 +947,7 @@ class PoolingTest(test.TestCase):
output_sizes,
x_init_value=x_init_value,
delta=1e-2)
- print("%s gradient error = " % func_name, err)
+ tf_logging.info("%s gradient error = " % func_name, err)
self.assertLess(err, err_tolerance)
def _ConstructAndTestSecondGradient(self,
@@ -1024,7 +1024,7 @@ class PoolingTest(test.TestCase):
input_sizes,
x_init_value=x_init_value,
delta=1e-2)
- print("%s second-order gradient error = " % func_name, err)
+ tf_logging.info("%s second-order gradient error = " % func_name, err)
self.assertLess(err, err_tolerance)
def _testMaxPoolGradValidPadding1_1(self, data_format, use_gpu):
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index b59e3dd7e7..50154a45a8 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -27,6 +27,7 @@ from six.moves import queue
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
@@ -35,6 +36,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
@@ -458,7 +460,7 @@ class PyFuncTest(test.TestCase):
self.assertEqual(initial_size, script_ops._py_funcs.size())
# ----- Tests for eager_py_func -----
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerSingleOutputInt32(self):
a = array_ops.ones((3, 3), dtype=dtypes.int32)
x = array_ops.ones((3, 1), dtype=dtypes.int32)
@@ -466,7 +468,7 @@ class PyFuncTest(test.TestCase):
ret = self.evaluate(output)
self.assertAllEqual(ret, [[3], [3], [3]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerSingleOutputFloat32(self):
with test_util.device(use_gpu=True):
a = array_ops.ones((3, 3), dtype=dtypes.float32)
@@ -475,7 +477,7 @@ class PyFuncTest(test.TestCase):
ret = self.evaluate(output)
self.assertAllClose(ret, [[3.0], [3.0], [3.0]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerArrayOutput(self):
with test_util.device(use_gpu=True):
a = array_ops.ones((3, 3), dtype=dtypes.float32)
@@ -485,7 +487,7 @@ class PyFuncTest(test.TestCase):
ret = self.evaluate(output)
self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerReturnNone(self):
with test_util.device(use_gpu=True):
def no_return_value():
@@ -498,7 +500,7 @@ class PyFuncTest(test.TestCase):
else:
self.assertIsNone(ret)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerPyFuncInDefun(self):
with test_util.device(use_gpu=True):
def wrapper():
@@ -510,7 +512,7 @@ class PyFuncTest(test.TestCase):
ret = self.evaluate(wrapped())
self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerExceptionHandling(self):
with test_util.device(use_gpu=True):
self._testExceptionHandling(
@@ -529,11 +531,10 @@ class PyFuncTest(test.TestCase):
self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerReturningVariableRaisesError(self):
def return_variable():
- variable = resource_variable_ops.ResourceVariable(0.0)
- return variable
+ return resource_variable_ops.ResourceVariable(0.0)
with self.assertRaisesRegexp(errors.UnknownError,
"Attempting to return a variable"):
@@ -541,6 +542,99 @@ class PyFuncTest(test.TestCase):
return_variable, inp=[], Tout=dtypes.float32)
self.evaluate(output)
+ @test_util.run_in_graph_and_eager_modes
+ def testEagerGradientTape(self):
+
+ def f(x):
+ return x**2
+
+ x = constant_op.constant(3.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
+ dy_dx = tape.gradient(y, x)
+ self.assertEqual(self.evaluate(dy_dx), 6.0)
+
+ def testEagerGradientGraph(self):
+
+ def f(x):
+ return x**2
+
+ x = constant_op.constant(3.0)
+ y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
+ dy_dx = gradients_impl.gradients(y, x)[0]
+ self.assertEqual(self.evaluate(dy_dx), 6.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEagerGradientTapeMultipleArgs(self):
+
+ def f(x, y):
+ return x**2 + y**2
+
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(4.0)
+ with backprop.GradientTape() as tape:
+ tape.watch(x)
+ tape.watch(y)
+ z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)
+
+ dz_dx, dz_dy = tape.gradient(z, [x, y])
+ self.assertEqual(self.evaluate(dz_dx), 6.0)
+ self.assertEqual(self.evaluate(dz_dy), 8.0)
+
+ def testEagerGradientGraphMultipleArgs(self):
+
+ def f(x, y):
+ return x**2 + y**2
+
+ x = constant_op.constant(3.0)
+ y = constant_op.constant(4.0)
+ z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)
+
+ dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
+ self.assertEqual(self.evaluate(dz_dx), 6.0)
+ self.assertEqual(self.evaluate(dz_dy), 8.0)
+
+ def testEagerGradientGraphLogHuber(self):
+
+ def log_huber(x, m):
+ if math_ops.abs(x) <= m:
+ return x**2
+ else:
+ return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))
+
+ x = array_ops.placeholder(dtypes.float32)
+ m = array_ops.placeholder(dtypes.float32)
+
+ y = script_ops.eager_py_func(
+ func=log_huber, inp=[x, m], Tout=dtypes.float32)
+ dy_dx = gradients_impl.gradients(y, x)[0]
+
+ with self.test_session() as sess:
+ # Takes the first branch of log_huber.
+ y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
+ self.assertEqual(y, 1.0)
+ self.assertEqual(dy_dx, 2.0)
+
+ def testEagerRespectsDevicePlacmentOfOp(self):
+
+ def f(x):
+ return math_ops.square(x)
+
+ def g(x):
+ return math_ops.add(x, x)
+
+ with ops.device("/CPU:0"):
+ # Explicitly ask for the py_funcs to execute on CPU, even if
+ # a GPU is available.
+ x = array_ops.placeholder(dtypes.float32)
+ y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
+ z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)
+
+ with self.test_session(use_gpu=True) as sess:
+ output = sess.run(z, feed_dict={x: 3.0})
+ self.assertEqual(output, 18.0)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/random/BUILD b/tensorflow/python/kernel_tests/random/BUILD
index acd7566eec..3b3a28fc9a 100644
--- a/tensorflow/python/kernel_tests/random/BUILD
+++ b/tensorflow/python/kernel_tests/random/BUILD
@@ -108,6 +108,23 @@ cuda_py_test(
)
cuda_py_test(
+ name = "random_grad_test",
+ size = "small",
+ srcs = ["random_grad_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:random_grad",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+cuda_py_test(
name = "random_poisson_test",
size = "medium",
srcs = ["random_poisson_test.py"],
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
index 051c7d86bf..bd64d61af8 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
@@ -54,7 +54,7 @@ native_sampler = random_ops.multinomial
class MultinomialTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSmallEntropy(self):
random_seed.set_random_seed(1618)
for output_dtype in [np.int32, np.int64]:
diff --git a/tensorflow/python/kernel_tests/random/random_grad_test.py b/tensorflow/python/kernel_tests/random/random_grad_test.py
new file mode 100644
index 0000000000..c1d455b785
--- /dev/null
+++ b/tensorflow/python/kernel_tests/random/random_grad_test.py
@@ -0,0 +1,240 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.random_grad."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_grad
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class AddLeadingUnitDimensionsTest(test.TestCase):
+
+ def testBasic(self):
+ ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 3)
+ self.assertAllEqual(ret.shape, [1, 1, 1, 3, 2, 1])
+
+ def testZeroExtraDimensions(self):
+ ret = random_grad.add_leading_unit_dimensions(array_ops.ones([3, 2, 1]), 0)
+ self.assertAllEqual(ret.shape, [3, 2, 1])
+
+ def testScalarInput(self):
+ ret = random_grad.add_leading_unit_dimensions(1.0, 2)
+ self.assertAllEqual(ret.shape, [1, 1])
+
+ def testUnknownShape(self):
+ x = array_ops.placeholder(dtypes.float32)
+ num_dimensions = array_ops.placeholder(dtypes.int32)
+ ret = random_grad.add_leading_unit_dimensions(x, num_dimensions)
+ with self.test_session() as sess:
+ ret_val = sess.run(ret, {x: np.ones([2, 2]), num_dimensions: 2})
+ self.assertAllEqual(ret_val.shape, [1, 1, 2, 2])
+
+
+class RandomGammaGradTest(test.TestCase):
+ """Tests for derivative of a sample ~ Gamma(alpha, beta) wrt alpha and beta.
+
+ The sample is an "implicit" function of alpha, beta and the independent random
+ noise u. The derivatives we are looking for are
+ d sample(alpha, beta, u) / dalpha (and dbeta).
+
+ The derivative w.r.t. beta is computed by the standard automatic
+ differentiation, so we trust that it is computed correctly.
+
+ The derivative w.r.t. alpha is computed by Eigen function, so we test it in
+ several ways. Unfortunately, the standard derivative checking by perturbing
+ the parameter is impossible here, because we cannot fix the value of u
+ in the random sampler. Instead, we compare the derivative for the given pair
+ of (sample, alpha) to the values computed in various ways, and also check
+ some statistical properties of the derivative.
+ """
+
+ def testGradientsShape(self):
+ shape = [2, 3]
+ alpha = array_ops.ones([2, 2])
+ beta = array_ops.ones([1, 2])
+ sample = random_ops.random_gamma(shape, alpha, beta)
+ grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta])
+ self.assertAllEqual(grads_alpha.shape, alpha.shape)
+ self.assertAllEqual(grads_beta.shape, beta.shape)
+
+ def testGradientsShapeWithOneSamplePerParameter(self):
+ shape = []
+ alpha = array_ops.ones([2, 2])
+ beta = array_ops.ones([1, 2])
+ sample = random_ops.random_gamma(shape, alpha, beta)
+ grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta])
+ self.assertAllEqual(grads_alpha.shape, alpha.shape)
+ self.assertAllEqual(grads_beta.shape, beta.shape)
+
+ def testGradientsUnknownShape(self):
+ shape = array_ops.placeholder(dtypes.int32)
+ alpha = array_ops.placeholder(dtypes.float32)
+ beta = array_ops.placeholder(dtypes.float32)
+ sample = random_ops.random_gamma(shape, alpha, beta)
+ grads_alpha, grads_beta = gradients_impl.gradients(sample, [alpha, beta])
+
+ alpha_val = np.ones([1, 2])
+ beta_val = np.ones([2, 1])
+ with self.test_session() as sess:
+ grads_alpha_val, grads_beta_val = sess.run(
+ [grads_alpha, grads_beta],
+ {alpha: alpha_val, beta: beta_val, shape: [2, 1]})
+ self.assertAllEqual(grads_alpha_val.shape, alpha_val.shape)
+ self.assertAllEqual(grads_beta_val.shape, beta_val.shape)
+
+ def _testCompareToExplicitDerivative(self, dtype):
+ """Compare to the explicit reparameterization derivative.
+
+ Verifies that the computed derivative satisfies
+ dsample / dalpha = d igammainv(alpha, u) / dalpha,
+ where u = igamma(alpha, sample).
+
+ Args:
+ dtype: TensorFlow dtype to perform the computations in.
+ """
+ delta = 1e-3
+ np_dtype = dtype.as_numpy_dtype
+ try:
+ from scipy import misc # pylint: disable=g-import-not-at-top
+ from scipy import special # pylint: disable=g-import-not-at-top
+
+ alpha_val = np.logspace(-2, 3, dtype=np_dtype)
+ alpha = constant_op.constant(alpha_val)
+ sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype)
+ actual = gradients_impl.gradients(sample, alpha)[0]
+
+ (sample_val, actual_val) = self.evaluate((sample, actual))
+
+ u = special.gammainc(alpha_val, sample_val)
+ expected_val = misc.derivative(
+ lambda alpha_prime: special.gammaincinv(alpha_prime, u),
+ alpha_val, dx=delta * alpha_val)
+
+ self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3)
+ except ImportError as e:
+ tf_logging.warn("Cannot use special functions in a test: %s" % str(e))
+
+ def testCompareToExplicitDerivativeFloat(self):
+ self._testCompareToExplicitDerivative(dtypes.float32)
+
+ def testCompareToExplicitDerivativeDouble(self):
+ self._testCompareToExplicitDerivative(dtypes.float64)
+
+ def _testCompareToImplicitDerivative(self, dtype):
+ """Compare to the implicit reparameterization derivative.
+
+ Let's derive the formula we compare to.
+
+ Start from the fact that CDF maps a random variable to the Uniform
+ random variable:
+ igamma(alpha, sample) = u, where u ~ Uniform(0, 1).
+
+ Apply d / dalpha to both sides:
+ d igamma(alpha, sample) / dalpha
+ + d igamma(alpha, sample) / dsample * dsample/dalpha = 0
+ d igamma(alpha, sample) / dalpha
+ + d igamma(alpha, sample) / dsample * dsample / dalpha = 0
+ dsample/dalpha = - (d igamma(alpha, sample) / dalpha)
+ / d igamma(alpha, sample) / dsample
+
+ This is the equation (8) of https://arxiv.org/abs/1805.08498
+
+ Args:
+ dtype: TensorFlow dtype to perform the computations in.
+ """
+ np_dtype = dtype.as_numpy_dtype
+ alpha = constant_op.constant(np.logspace(-2, 3, dtype=np_dtype))
+ sample = random_ops.random_gamma([], alpha, np_dtype(1.0), dtype=dtype)
+ actual = gradients_impl.gradients(sample, alpha)[0]
+
+ sample_sg = array_ops.stop_gradient(sample)
+ cdf = math_ops.igamma(alpha, sample_sg)
+ dcdf_dalpha, dcdf_dsample = gradients_impl.gradients(
+ cdf, [alpha, sample_sg])
+ # Numerically unstable due to division, do not try at home.
+ expected = -dcdf_dalpha / dcdf_dsample
+
+ (actual_val, expected_val) = self.evaluate((actual, expected))
+
+ self.assertAllClose(actual_val, expected_val, rtol=1e-3, atol=1e-3)
+
+ def testCompareToImplicitDerivativeFloat(self):
+ self._testCompareToImplicitDerivative(dtypes.float32)
+
+ def testCompareToImplicitDerivativeDouble(self):
+ self._testCompareToImplicitDerivative(dtypes.float64)
+
+ def testAverageAlphaGradient(self):
+ """Statistical test for the gradient.
+
+ Using the equation (5) of https://arxiv.org/abs/1805.08498, we have
+ 1 = d/dalpha E_{sample ~ Gamma(alpha, 1)} sample
+ = E_{sample ~ Gamma(alpha, 1)} dsample/dalpha.
+ Here we verify that the rhs is fairly close to one.
+ The convergence speed is not great, so we use many samples and loose bounds.
+ """
+ num_samples = 1000
+ alpha = constant_op.constant([0.8, 1e1, 1e3], dtype=dtypes.float32)
+ sample = random_ops.random_gamma([num_samples], alpha)
+ # We need to average the gradients, which is equivalent to averaging the
+ # samples and then doing backprop.
+ mean_sample = math_ops.reduce_mean(sample, axis=0)
+ dsample_dalpha = gradients_impl.gradients(mean_sample, alpha)[0]
+ dsample_dalpha_val = self.evaluate(dsample_dalpha)
+ self.assertAllClose(dsample_dalpha_val, [1.0] * 3, atol=1e-1, rtol=1e-1)
+
+ def testQuadraticLoss(self):
+ """Statistical test for the gradient.
+
+ The equation (5) of https://arxiv.org/abs/1805.08498 says
+ d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample)
+ = E_{sample ~ Gamma(alpha, 1)} df(sample)/dalpha.
+
+ Choose a quadratic loss function f(sample) = (sample - t)^2.
+ Then, the lhs can be computed analytically:
+ d/dalpha E_{sample ~ Gamma(alpha, 1)} f(sample)
+ = d/dalpha [ (alpha + alpha^2) - 2 * t * alpha + t^2 ]
+ = 1 + 2 * alpha - 2 * t.
+
+ We compare the Monte-Carlo estimate of the expectation with the
+ true gradient.
+ """
+ num_samples = 1000
+ t = 0.3
+ alpha = 0.5
+ expected = 1 + 2 * alpha - 2 * t
+
+ alpha = constant_op.constant(alpha)
+ sample = random_ops.random_gamma([num_samples], alpha, 1.0)
+ loss = math_ops.reduce_mean(math_ops.square(sample - t))
+ dloss_dalpha = gradients_impl.gradients(loss, alpha)[0]
+ dloss_dalpha_val = self.evaluate(dloss_dalpha)
+ self.assertAllClose(expected, dloss_dalpha_val, atol=1e-1, rtol=1e-1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 82a27eebee..8e06e1abfb 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -25,8 +25,6 @@ import shutil
import threading
import zlib
-import six
-
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -77,6 +75,69 @@ _TEXT = b"""Gaily bedight,
"""
+class TFCompressionTestCase(test.TestCase):
+
+ def setUp(self):
+ super(TFCompressionTestCase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ def _Record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _CreateFiles(self, options=None, prefix=""):
+ filenames = []
+ for i in range(self._num_files):
+ name = prefix + "tfrecord.%d.txt" % i
+ records = [self._Record(i, j) for j in range(self._num_records)]
+ fn = self._WriteRecordsToFile(records, name, options)
+ filenames.append(fn)
+ return filenames
+
+ def _WriteRecordsToFile(self, records, name="tfrecord", options=None):
+ fn = os.path.join(self.get_temp_dir(), name)
+ with tf_record.TFRecordWriter(fn, options=options) as writer:
+ for r in records:
+ writer.write(r)
+ return fn
+
+ def _ZlibCompressFile(self, infile, name="tfrecord.z"):
+ # zlib compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = zlib.compress(f.read())
+
+ zfn = os.path.join(self.get_temp_dir(), name)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ return zfn
+
+ def _GzipCompressFile(self, infile, name="tfrecord.gz"):
+ # gzip compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = f.read()
+
+ gzfn = os.path.join(self.get_temp_dir(), name)
+ with gzip.GzipFile(gzfn, "wb") as f:
+ f.write(cdata)
+ return gzfn
+
+ def _ZlibDecompressFile(self, infile, name="tfrecord"):
+ with open(infile, "rb") as f:
+ cdata = zlib.decompress(f.read())
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+ def _GzipDecompressFile(self, infile, name="tfrecord"):
+ with gzip.GzipFile(infile, "rb") as f:
+ cdata = f.read()
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+
class IdentityReaderTest(test.TestCase):
def _ExpectRead(self, sess, key, value, expected):
@@ -348,7 +409,7 @@ class TextLineReaderTest(test.TestCase):
k, v = sess.run([key, value])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(TFCompressionTestCase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -407,40 +468,18 @@ class FixedLengthRecordReaderTest(test.TestCase):
# gap_bytes=hop_bytes-record_bytes
def _CreateGzipFiles(self, num_records, gap_bytes):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with gzip.GzipFile(fn, "wb") as f:
- f.write(b"H" * self._header_bytes)
- if num_records > 0:
- f.write(self._Record(i, 0))
- for j in range(1, num_records):
- if gap_bytes > 0:
- f.write(b"G" * gap_bytes)
- f.write(self._Record(i, j))
- f.write(b"F" * self._footer_bytes)
+ filenames = self._CreateFiles(num_records, gap_bytes)
+ for fn in filenames:
+ # compress inplace.
+ self._GzipCompressFile(fn, fn)
return filenames
# gap_bytes=hop_bytes-record_bytes
def _CreateZlibFiles(self, num_records, gap_bytes):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with open(fn + ".tmp", "wb") as f:
- f.write(b"H" * self._header_bytes)
- if num_records > 0:
- f.write(self._Record(i, 0))
- for j in range(1, num_records):
- if gap_bytes > 0:
- f.write(b"G" * gap_bytes)
- f.write(self._Record(i, j))
- f.write(b"F" * self._footer_bytes)
- with open(fn + ".tmp", "rb") as f:
- cdata = zlib.compress(f.read())
- with open(fn, "wb") as zf:
- zf.write(cdata)
+ filenames = self._CreateFiles(num_records, gap_bytes)
+ for fn in filenames:
+ # compress inplace.
+ self._ZlibCompressFile(fn, fn)
return filenames
def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
@@ -477,10 +516,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
])
f.write(compat.as_bytes(all_records_str))
f.write(b"F" * self._footer_bytes)
- with open(fn + ".tmp", "rb") as f:
- cdata = zlib.compress(f.read())
- with open(fn, "wb") as zf:
- zf.write(cdata)
+ self._ZlibCompressFile(fn + ".tmp", fn)
return filenames
# gap_bytes=hop_bytes-record_bytes
@@ -529,7 +565,6 @@ class FixedLengthRecordReaderTest(test.TestCase):
for i in range(self._num_files):
for j in range(num_overlapped_records):
k, v = sess.run([key, value])
- print(v)
self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
self.assertAllEqual(self._OverlappedRecord(i, j), v)
@@ -579,25 +614,10 @@ class FixedLengthRecordReaderTest(test.TestCase):
files, num_overlapped_records, encoding="ZLIB")
-class TFRecordReaderTest(test.TestCase):
+class TFRecordReaderTest(TFCompressionTestCase):
def setUp(self):
super(TFRecordReaderTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- def _Record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _CreateFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = tf_record.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._Record(i, j))
- return filenames
def testOneEpoch(self):
files = self._CreateFiles()
@@ -647,107 +667,27 @@ class TFRecordReaderTest(test.TestCase):
self.assertEqual(self._num_files * self._num_records, num_v)
def testReadZlibFiles(self):
- files = self._CreateFiles()
- zlib_files = []
- for i, fn in enumerate(files):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
-
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ files = self._CreateFiles(options)
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
- queue.enqueue_many([zlib_files]).run()
+ queue.enqueue_many([files]).run()
queue.close().run()
for i in range(self._num_files):
for j in range(self._num_records):
k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i]))
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
self.assertAllEqual(self._Record(i, j), v)
def testReadGzipFiles(self):
- files = self._CreateFiles()
- gzip_files = []
- for i, fn in enumerate(files):
- with open(fn, "rb") as f:
- cdata = f.read()
-
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(zfn, "wb") as f:
- f.write(cdata)
- gzip_files.append(zfn)
-
- with self.test_session() as sess:
- options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
- reader = io_ops.TFRecordReader(name="test_reader", options=options)
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
-
- queue.enqueue_many([gzip_files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_records):
- k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
- self.assertAllEqual(self._Record(i, j), v)
-
-
-class TFRecordWriterZlibTest(test.TestCase):
-
- def setUp(self):
- super(TFRecordWriterZlibTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- def _Record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _CreateFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
- writer = tf_record.TFRecordWriter(fn, options=options)
- for j in range(self._num_records):
- writer.write(self._Record(i, j))
- writer.close()
- del writer
-
- return filenames
-
- def _WriteRecordsToFile(self, records, name="tf_record"):
- fn = os.path.join(self.get_temp_dir(), name)
- writer = tf_record.TFRecordWriter(fn, options=None)
- for r in records:
- writer.write(r)
- writer.close()
- del writer
- return fn
-
- def _ZlibCompressFile(self, infile, name="tfrecord.z"):
- # zlib compress the file and write compressed contents to file.
- with open(infile, "rb") as f:
- cdata = zlib.compress(f.read())
-
- zfn = os.path.join(self.get_temp_dir(), name)
- with open(zfn, "wb") as f:
- f.write(cdata)
- return zfn
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ files = self._CreateFiles(options)
- def testOneEpoch(self):
- files = self._CreateFiles()
with self.test_session() as sess:
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -760,196 +700,6 @@ class TFRecordWriterZlibTest(test.TestCase):
self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
self.assertAllEqual(self._Record(i, j), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
-
- def testZLibFlushRecord(self):
- fn = self._WriteRecordsToFile([b"small record"], "small_record")
- with open(fn, "rb") as h:
- buff = h.read()
-
- # creating more blocks and trailing blocks shouldn't break reads
- compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)
-
- output = b""
- for c in buff:
- if isinstance(c, int):
- c = six.int2byte(c)
- output += compressor.compress(c)
- output += compressor.flush(zlib.Z_FULL_FLUSH)
-
- output += compressor.flush(zlib.Z_FULL_FLUSH)
- output += compressor.flush(zlib.Z_FULL_FLUSH)
- output += compressor.flush(zlib.Z_FINISH)
-
- # overwrite the original file with the compressed data
- with open(fn, "wb") as h:
- h.write(output)
-
- with self.test_session() as sess:
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
- reader = io_ops.TFRecordReader(name="test_reader", options=options)
- queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
- key, value = reader.read(queue)
- queue.enqueue(fn).run()
- queue.close().run()
- k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
- self.assertAllEqual(b"small record", v)
-
- def testZlibReadWrite(self):
- """Verify that files produced are zlib compatible."""
- original = [b"foo", b"bar"]
- fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord")
- zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z")
-
- # read the compressed contents and verify.
- actual = []
- for r in tf_record.tf_record_iterator(
- zfn,
- options=tf_record.TFRecordOptions(
- tf_record.TFRecordCompressionType.ZLIB)):
- actual.append(r)
- self.assertEqual(actual, original)
-
- def testZlibReadWriteLarge(self):
- """Verify that writing large contents also works."""
-
- # Make it large (about 5MB)
- original = [_TEXT * 10240]
- fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
- zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")
-
- # read the compressed contents and verify.
- actual = []
- for r in tf_record.tf_record_iterator(
- zfn,
- options=tf_record.TFRecordOptions(
- tf_record.TFRecordCompressionType.ZLIB)):
- actual.append(r)
- self.assertEqual(actual, original)
-
- def testGzipReadWrite(self):
- """Verify that files produced are gzip compatible."""
- original = [b"foo", b"bar"]
- fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")
-
- # gzip compress the file and write compressed contents to file.
- with open(fn, "rb") as f:
- cdata = f.read()
- gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
- with gzip.GzipFile(gzfn, "wb") as f:
- f.write(cdata)
-
- actual = []
- for r in tf_record.tf_record_iterator(
- gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)):
- actual.append(r)
- self.assertEqual(actual, original)
-
-
-class TFRecordIteratorTest(test.TestCase):
-
- def setUp(self):
- super(TFRecordIteratorTest, self).setUp()
- self._num_records = 7
-
- def _Record(self, r):
- return compat.as_bytes("Record %d" % r)
-
- def _WriteCompressedRecordsToFile(
- self,
- records,
- name="tfrecord.z",
- compression_type=tf_record.TFRecordCompressionType.ZLIB):
- fn = os.path.join(self.get_temp_dir(), name)
- options = tf_record.TFRecordOptions(compression_type=compression_type)
- writer = tf_record.TFRecordWriter(fn, options=options)
- for r in records:
- writer.write(r)
- writer.close()
- del writer
- return fn
-
- def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS):
- with open(infile, "rb") as f:
- cdata = zlib.decompress(f.read(), wbits)
- zfn = os.path.join(self.get_temp_dir(), name)
- with open(zfn, "wb") as f:
- f.write(cdata)
- return zfn
-
- def testIterator(self):
- fn = self._WriteCompressedRecordsToFile(
- [self._Record(i) for i in range(self._num_records)],
- "compressed_records")
- options = tf_record.TFRecordOptions(
- compression_type=TFRecordCompressionType.ZLIB)
- reader = tf_record.tf_record_iterator(fn, options)
- for i in range(self._num_records):
- record = next(reader)
- self.assertAllEqual(self._Record(i), record)
- with self.assertRaises(StopIteration):
- record = next(reader)
-
- def testWriteZlibRead(self):
- """Verify compression with TFRecordWriter is zlib library compatible."""
- original = [b"foo", b"bar"]
- fn = self._WriteCompressedRecordsToFile(original,
- "write_zlib_read.tfrecord.z")
- zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
- self.assertEqual(actual, original)
-
- def testWriteZlibReadLarge(self):
- """Verify compression for large records is zlib library compatible."""
- # Make it large (about 5MB)
- original = [_TEXT * 10240]
- fn = self._WriteCompressedRecordsToFile(original,
- "write_zlib_read_large.tfrecord.z")
- zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record")
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
- self.assertEqual(actual, original)
-
- def testWriteGzipRead(self):
- original = [b"foo", b"bar"]
- fn = self._WriteCompressedRecordsToFile(
- original,
- "write_gzip_read.tfrecord.gz",
- compression_type=TFRecordCompressionType.GZIP)
-
- with gzip.GzipFile(fn, "rb") as f:
- cdata = f.read()
- zfn = os.path.join(self.get_temp_dir(), "tf_record")
- with open(zfn, "wb") as f:
- f.write(cdata)
-
- actual = []
- for r in tf_record.tf_record_iterator(zfn):
- actual.append(r)
- self.assertEqual(actual, original)
-
- def testBadFile(self):
- """Verify that tf_record_iterator throws an exception on bad TFRecords."""
- fn = os.path.join(self.get_temp_dir(), "bad_file")
- with tf_record.TFRecordWriter(fn) as writer:
- writer.write(b"123")
- fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
- with open(fn, "rb") as f:
- with open(fn_truncated, "wb") as f2:
- # DataLossError requires that we've written the header, so this must
- # be at least 12 bytes.
- f2.write(f.read(14))
- with self.assertRaises(errors_impl.DataLossError):
- for _ in tf_record.tf_record_iterator(fn_truncated):
- pass
-
class AsyncReaderTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 00d517e64e..e358293a90 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -145,14 +145,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertIn("<unprintable>", str(handle))
self.assertIn("<unprintable>", repr(handle))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDtypeSurvivesIdentity(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
id_handle = array_ops.identity(handle)
self.evaluate(resource_variable_ops.assign_variable_op(
id_handle, constant_op.constant(0, dtype=dtypes.int32)))
- @test_util.run_in_graph_and_eager_modes()
+ def testUnreadOpName(self):
+ v = resource_variable_ops.ResourceVariable(1.0)
+ self.assertNotEqual(v.name, v.assign_add(1.0).name)
+
+ @test_util.run_in_graph_and_eager_modes
def testCreateRead(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
@@ -161,7 +165,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
self.assertAllEqual(1, value)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testManyAssigns(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
create = resource_variable_ops.assign_variable_op(
@@ -179,7 +183,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(f, 1)
self.assertEqual(s, 2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAssignAdd(self):
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
self.evaluate(resource_variable_ops.assign_variable_op(
@@ -190,7 +194,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
self.assertEqual(read, 2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterAdd(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -203,7 +207,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterSub(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -216,7 +220,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[-1]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMul(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -229,7 +233,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterDiv(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -242,7 +246,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMin(self):
with ops.device("cpu:0"):
handle = resource_variable_ops.var_handle_op(
@@ -279,7 +283,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
meta_graph_two = saver.export_meta_graph(graph=graph)
self.assertEqual(meta_graph_def, meta_graph_two)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMax(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -292,7 +296,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[6]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterAddScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -305,7 +309,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterSubScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -318,7 +322,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[-1]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMulScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -331,7 +335,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterDivScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -344,7 +348,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMinScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -357,7 +361,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[3]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterMaxScalar(self):
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1, 1])
@@ -422,7 +426,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_update(ref, indices, updates)
self.assertAllEqual(ref.read_value(), [True, True, True])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConstraintArg(self):
constraint = lambda x: x
v = resource_variable_ops.ResourceVariable(
@@ -462,32 +466,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.OutOfRangeError):
state_ops.count_up_to(v, 1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitFnDtype(self):
v = resource_variable_ops.ResourceVariable(
initial_value=lambda: 1, dtype=dtypes.float32, name="var0")
self.assertEqual(dtypes.float32, v.value().dtype)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitFnNoDtype(self):
v = resource_variable_ops.ResourceVariable(initial_value=lambda: 1,
name="var2")
self.assertEqual(dtypes.int32, v.value().dtype)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitializeAllVariables(self):
v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.float32,
name="var0")
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(v.value()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOperatorOverload(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
self.evaluate(variables.global_variables_initializer())
self.assertEqual(2.0, self.evaluate(v + v))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAssignMethod(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
self.evaluate(variables.global_variables_initializer())
@@ -505,7 +509,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLoad(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
self.evaluate(variables.global_variables_initializer())
@@ -557,7 +561,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
variable_def=trainable_variable.to_proto())
.trainable)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSparseRead(self):
with self.test_session():
init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
@@ -579,7 +583,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEquals(v._handle, w._handle)
self.assertEquals(v._graph_element, w._graph_element)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAssignAddMethod(self):
v = resource_variable_ops.ResourceVariable(1.0, name="var0")
self.evaluate(variables.global_variables_initializer())
@@ -597,7 +601,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(assign_without_read)
self.assertEqual(4.0, self.evaluate(v.value()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAssignSubMethod(self):
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
self.evaluate(variables.global_variables_initializer())
@@ -615,7 +619,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.evaluate(assign_without_read)
self.assertEqual(0.0, self.evaluate(v.value()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDestroyResource(self):
v = resource_variable_ops.ResourceVariable(3.0, name="var0")
self.evaluate(variables.global_variables_initializer())
@@ -704,7 +708,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
self.assertEqual(300.0, self.evaluate(w_read))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testShape(self):
v = resource_variable_ops.ResourceVariable(
name="var4", initial_value=array_ops.ones(shape=[10, 20, 35]))
@@ -822,13 +826,23 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_add(v, [1], [3])
self.assertAllEqual([1.0, 5.0], v.numpy())
+ def testScatterNdAddStateOps(self):
+ with context.eager_mode():
+ v = resource_variable_ops.ResourceVariable(
+ [1, 1, 1, 1, 1, 1, 1, 1], dtype=dtypes.float32, name="add")
+ indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
+ expected = np.array([1, 12, 1, 11, 10, 1, 1, 13])
+ state_ops.scatter_nd_add(v, indices, updates)
+ self.assertAllClose(expected, v.numpy())
+
def testScatterUpdateCast(self):
with context.eager_mode():
v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="update")
state_ops.scatter_update(v, [1], [3])
self.assertAllEqual([1.0, 3.0], v.numpy())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScatterUpdateInvalidArgs(self):
v = resource_variable_ops.ResourceVariable([0, 1, 2, 3], name="update")
# The exact error and message differ between graph construction (where the
@@ -838,5 +852,62 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_same_as_var_dtype(self):
+ # read_dtype is same as dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float32)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # dtype is not read_dtype, return NotImplemented
+ self.assertEqual(
+ NotImplemented, v._dense_var_to_tensor(dtype=dtypes.float16))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float32)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float32, as_ref=True))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_dense_var_to_tensor_read_dtype_different_from_var_dtype(self):
+ # read_dtype is different from dtype
+ v = resource_variable_ops.ResourceVariable(1.0, dtype=dtypes.float32)
+ v = resource_variable_ops._MixedPrecisionVariable(v, dtypes.float16)
+ if not context.executing_eagerly():
+ v.initializer.run()
+
+ # as_ref is False
+ t = v._dense_var_to_tensor(as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ t = v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=False)
+ self.assertTrue(isinstance(t, ops.Tensor))
+ self.assertEqual(t.dtype, dtypes.float16)
+ self.assertEqual(self.evaluate(t), 1.0)
+
+ # as_ref is True
+ self.assertEqual(NotImplemented, v._dense_var_to_tensor(as_ref=True))
+ self.assertEqual(NotImplemented,
+ v._dense_var_to_tensor(dtype=dtypes.float16, as_ref=True))
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index fe5ad84c10..acee180a6c 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -81,6 +81,25 @@ class ScalarStateRNNCell(rnn_cell_impl.RNNCell):
return (input_, state + 1)
+class UnbalancedOutputRNNCell(rnn_cell_impl.RNNCell):
+ """RNN Cell generating (output, new_state) = (input + 1, state + 1)."""
+
+ @property
+ def output_size(self):
+ return tensor_shape.TensorShape(1), tensor_shape.TensorShape((2))
+
+ @property
+ def state_size(self):
+ return tensor_shape.TensorShape([])
+
+ def zero_state(self, batch_size, dtype):
+ return array_ops.zeros([], dtype=dtypes.int32)
+
+ def call(self, input_, state, scope=None):
+ concatenated = array_ops.concat((input_, input_), axis=-1)
+ return (input_, concatenated), state + 1
+
+
class TensorArrayStateRNNCell(rnn_cell_impl.RNNCell):
"""RNN Cell its state as a TensorArray."""
@@ -108,7 +127,7 @@ class RNNTest(test.TestCase):
self._seed = 23489
np.random.seed(self._seed)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInvalidSequenceLengthShape(self):
cell = Plus1RNNCell()
if context.executing_eagerly():
@@ -122,7 +141,7 @@ class RNNTest(test.TestCase):
dtype=dtypes.float32,
sequence_length=[[4]])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBatchSizeFromInput(self):
cell = Plus1RNNCell()
in_eager_mode = context.executing_eagerly()
@@ -162,7 +181,7 @@ class RNNTest(test.TestCase):
self.assertEqual(None, outputs.shape[0].value)
self.assertEqual(None, state.shape[0].value)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScalarStateIsAccepted(self):
cell = ScalarStateRNNCell()
in_eager_mode = context.executing_eagerly()
@@ -182,7 +201,29 @@ class RNNTest(test.TestCase):
self.assertAllEqual([[[1], [2], [3], [4]]], outputs)
self.assertAllEqual(4, state)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
+ def testUnbalancedOutputIsAccepted(self):
+ cell = UnbalancedOutputRNNCell()
+ in_eager_mode = context.executing_eagerly()
+
+ if in_eager_mode:
+ inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
+ else:
+ inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1))
+
+ with self.test_session() as sess:
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32, sequence_length=[4])
+ if not in_eager_mode:
+ outputs, state = sess.run(
+ [outputs, state], feed_dict={inputs: [[[1], [2], [3], [4]]]})
+
+ self.assertIsInstance(outputs, tuple)
+ self.assertAllEqual([[[1], [2], [3], [4]]], outputs[0])
+ self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
+ self.assertAllEqual(4, state)
+
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell()
in_eager_mode = context.executing_eagerly()
@@ -215,7 +256,7 @@ class RNNTest(test.TestCase):
cell_output, _ = cell(array_ops.zeros(in_shape, dtype), state_output)
self.assertAllEqual([batch_size, out_size], cell_output.shape.as_list())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCellsBuild(self):
f32 = dtypes.float32
f64 = dtypes.float64
@@ -227,6 +268,12 @@ class RNNTest(test.TestCase):
self._assert_cell_builds(rnn_cell_impl.GRUCell, f64, 5, 7, 3)
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f32, 5, 7, 3)
self._assert_cell_builds(rnn_cell_impl.LSTMCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndRNNCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndRNNCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyGRUCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyGRUCell, f64, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f32, 5, 7, 3)
+ self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3)
######### Benchmarking RNN code
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index faa4b49a8d..f9b9c77bbf 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -369,7 +369,7 @@ class ScatterNdTest(test.TestCase):
del input_ # input_ is not used in scatter_nd
return array_ops.scatter_nd(indices, updates, shape)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInvalidShape(self):
# TODO(apassos) figure out how to unify these errors
with self.assertRaises(errors.InvalidArgumentError
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 794be096b7..a82855dfeb 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -264,7 +264,9 @@ class UnsortedSegmentTest(SegmentReductionHelper):
# A subset of ops has been enabled for complex numbers
self.complex_ops_list = [(np.add, None,
- math_ops.unsorted_segment_sum, lambda t: 0)]
+ math_ops.unsorted_segment_sum, lambda t: 0),
+ (np.ndarray.__mul__, None,
+ math_ops.unsorted_segment_prod, lambda t: 1)]
self.differentiable_dtypes = [dtypes_lib.float16, dtypes_lib.float32,
dtypes_lib.float64]
self.all_dtypes = (self.differentiable_dtypes +
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 7368251ab6..34e34d9d1b 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -642,6 +642,29 @@ class TileTest(test.TestCase):
err = gradient_checker.compute_gradient_error(a, [4, 2], tiled, [4, 4])
self.assertLess(err, 1e-3)
+ def testGradientWithSparseGradWithRank1(self):
+ inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0],
+ dtype=dtypes.float32)
+ outputs = array_ops.gather(array_ops.tile(inputs, [3]),
+ [1, 5, 9, 3, 7, 2, 2, 2])
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
+ def testGradientWithSparseGradWithRank3(self):
+ inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0],
+ dtype=dtypes.float32)
+ inputs = array_ops.reshape(inputs, [-1, 1, 1])
+ outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]),
+ [1, 5, 9, 3, 7, 2, 2, 2])
+ with self.test_session():
+ error = gradient_checker.compute_gradient_error(
+ inputs, inputs.get_shape().as_list(),
+ outputs, outputs.get_shape().as_list())
+ self.assertLess(error, 1e-4)
+
def testShapeFunctionEdgeCases(self):
# Unknown multiples shape.
inp = constant_op.constant(0.0, shape=[4, 4, 4, 4])
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 5fc9bef218..402f67619b 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -225,7 +225,7 @@ class SliceTest(test.TestCase):
self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3])
m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1])
- self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None])
+ self.assertAllEqual(m2.get_shape().as_list(), [1, 2, 3])
def _testGradientSlice(self, input_shape, slice_begin, slice_size):
diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
index 27b39a626f..3847cebc7d 100644
--- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
@@ -300,6 +300,51 @@ class SerializeSparseTest(test.TestCase):
sparse_ops.serialize_many_sparse, sparse_ops.deserialize_sparse,
dtypes.variant)
+ def testVariantSerializeDeserializeScalar(self):
+ with self.test_session(use_gpu=False) as sess:
+ indices_value = np.array([[]], dtype=np.int64)
+ values_value = np.array([37], dtype=np.int32)
+ shape_value = np.array([], dtype=np.int64)
+ sparse_tensor = self._SparseTensorPlaceholder()
+ serialized = sparse_ops.serialize_sparse(
+ sparse_tensor, out_type=dtypes.variant)
+ deserialized = sparse_ops.deserialize_sparse(
+ serialized, dtype=dtypes.int32)
+ deserialized_value = sess.run(
+ deserialized,
+ feed_dict={
+ sparse_tensor.indices: indices_value,
+ sparse_tensor.values: values_value,
+ sparse_tensor.dense_shape: shape_value
+ })
+ self.assertAllEqual(deserialized_value.indices, indices_value)
+ self.assertAllEqual(deserialized_value.values, values_value)
+ self.assertAllEqual(deserialized_value.dense_shape, shape_value)
+
+ def testVariantSerializeDeserializeScalarBatch(self):
+ with self.test_session(use_gpu=False) as sess:
+ indices_value = np.array([[]], dtype=np.int64)
+ values_value = np.array([37], dtype=np.int32)
+ shape_value = np.array([], dtype=np.int64)
+ sparse_tensor = self._SparseTensorPlaceholder()
+ serialized = sparse_ops.serialize_sparse(
+ sparse_tensor, out_type=dtypes.variant)
+ stacked = array_ops.stack([serialized, serialized])
+ deserialized = sparse_ops.deserialize_sparse(stacked, dtype=dtypes.int32)
+ deserialized_value = sess.run(
+ deserialized,
+ feed_dict={
+ sparse_tensor.indices: indices_value,
+ sparse_tensor.values: values_value,
+ sparse_tensor.dense_shape: shape_value
+ })
+ self.assertAllEqual(deserialized_value.indices,
+ np.array([[0], [1]], dtype=np.int64))
+ self.assertAllEqual(deserialized_value.values,
+ np.array([37, 37], dtype=np.int32))
+ self.assertAllEqual(deserialized_value.dense_shape,
+ np.array([2], dtype=np.int64))
+
def _testDeserializeFailsWrongTypeHelper(self,
serialize_fn,
deserialize_fn,
diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
index da116601f8..97f30daf4a 100644
--- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -21,13 +21,15 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import sparse_ops
+import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class SparseSliceOpTest(test.TestCase):
- def _SparseTensor_4x6(self):
+ def _SparseTensor_4x6(self, val_dtype=np.int64):
# [0 | |2 | |4 |5 ]
# [ |11| |13|14| ]
# [20| | |23| |25]
@@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase):
[2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
np.int64)
val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
- np.int64)
+ val_dtype)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)
@@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])
+ def testGradients(self):
+ sp_input = self._SparseTensor_4x6(val_dtype=np.float32)
+ start_and_size = [([0, 0], [4, 2]),
+ ([0, 2], [5, 2]),
+ ([0, 4], [5, 3])]
+
+ with self.test_session(use_gpu=False):
+ for start, size in start_and_size:
+ sp_output = sparse_ops.sparse_slice(sp_input, start, size)
+ nnz_in = len(sp_input.values.eval())
+ nnz_out = len(sp_output.values.eval())
+
+ err = gradient_checker.compute_gradient_error(
+ [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,))
+ self.assertLess(err, 1e-3)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 8cfee3eb93..419cd5ecda 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -95,7 +95,7 @@ class SplitOpTest(test.TestCase):
sess.run(array_ops.split(value, size_splits), {size_splits: [2, 2, 6]})
self.assertTrue("Cannot infer num from shape" in str(context.exception))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testExplicitNum(self):
size_splits = array_ops.constant([2, 2, 6], dtype=dtypes.int32)
value = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
@@ -109,7 +109,7 @@ class SplitOpTest(test.TestCase):
self.assertAllEqual(r[1], value[2:4])
self.assertAllEqual(r[2], value[4:])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testListOfScalarTensors(self):
a = math_ops.to_int32(5)
b = math_ops.to_int32(6)
@@ -168,7 +168,7 @@ class SplitOpTest(test.TestCase):
offset += size_splits[i]
self.assertAllEqual(result[i], inp[slices])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSpecialCasesVariable(self):
self._testSpecialCasesVariable()
for dtype in _TEST_DTYPES:
@@ -210,13 +210,13 @@ class SplitOpTest(test.TestCase):
self.assertAllEqual(np_ans[i], out[i])
self.assertShapeEqual(np_ans[i], tf_ans[i])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSplitRows(self):
for dtype in _TEST_DTYPES:
inp = self._makeData((4, 4), dtype)
self._compare(inp, 0, 4)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSplitCols(self):
for dtype in _TEST_DTYPES:
inp = self._makeData((4, 4), dtype)
@@ -232,7 +232,7 @@ class SplitOpTest(test.TestCase):
self.assertEqual(out[i].shape, expected_shape)
self.assertEqual(expected_shape, tf_ans[i].get_shape())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEmpty(self):
# Note: np.split returns a rank-0 empty ndarray
# if the input ndarray is empty.
@@ -244,7 +244,7 @@ class SplitOpTest(test.TestCase):
self._testEmpty(inp, 2, 3, (8, 0, 7))
self._testEmpty(inp, 2, 7, (8, 0, 3))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testIdentity(self):
for dtype in _TEST_DTYPES:
inp = self._makeData((2, 2, 2), dtype)
@@ -252,7 +252,7 @@ class SplitOpTest(test.TestCase):
self._compare(inp, 1, 1)
self._compare(inp, 2, 1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSplitDim0(self):
for dtype in _TEST_DTYPES:
self._compare(self._makeData((6, 10, 18), dtype), 0, 3)
@@ -281,7 +281,7 @@ class SplitOpTest(test.TestCase):
offset += length
self.assertAllEqual(result[i], inp[slices])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRandom(self):
for dtype in _TEST_DTYPES:
for _ in range(5):
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index a5bd1b6ee0..e20daccb28 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -146,5 +146,101 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(shape, [3, 1])
+class StringSplitV2OpTest(test.TestCase):
+
+ def testSplitV2(self):
+ strings = ["pigs on the wing", "animals"]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings)
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
+ self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
+ self.assertAllEqual(shape, [2, 4])
+
+ def testSplitV2MultiCharSeparator(self):
+ # Match Python behavior:
+ # >>> '1<>2<>3'.split('<>')
+ # ['1', '2', '3']
+ # >>> "<><>4<>5<><>6<>".split("<>")
+ # ['', '', '4', '5', '', '6', '']
+ strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings, sep="<>")
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(
+ indices, [[0, 0], [0, 1], [0, 2],
+ [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]])
+ self.assertAllEqual(values, [b"1", b"2", b"3",
+ b"", b"", b"4", b"5", b"", b"6", b""])
+ self.assertAllEqual(shape, [2, 7])
+
+ def testSplitV2SimpleSeparator(self):
+ # Match Python behavior:
+ # >>> '1,2,3'.split(',')
+ # ['1', '2', '3']
+ # >>> '1,2,,3,'.split(',')
+ # ['1', '2', '', '3', '']
+ strings = ["1,2,3", "4,5,,6,"]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings, sep=',')
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
+ [1, 0], [1, 1], [1, 2], [1, 3], [1, 4]])
+ self.assertAllEqual(values, [b"1", b"2", b"3",
+ b"4", b"5", b"", b"6", b""])
+ self.assertAllEqual(shape, [2, 5])
+
+ def testSplitV2EmptySeparator(self):
+ # Match Python behavior:
+ # >>> '1 2 3'.split()
+ # ['1', '2', '3']
+ #>>> ' 1 2 3 '.split()
+ #['1', '2', '3']
+ strings = ["1 2 3", " 4 5 6 "]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings)
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
+ [1, 0], [1, 1], [1, 2]])
+ self.assertAllEqual(values, [b"1", b"2", b"3", b"4", b"5", b"6"])
+ self.assertAllEqual(shape, [2, 3])
+
+ def testSplitV2SimpleSeparatorMaxSplit(self):
+ # Match Python behavior:
+ # >>> '1,2,3'.split(',', maxsplit=1)
+ # ['1', '2,3']
+ # >>> '4,5,,6,'.split(',', maxsplit=1)
+ # ['4', '5,,6,']
+ strings = ["1,2,3", "4,5,,6,"]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1],
+ [1, 0], [1, 1]])
+ self.assertAllEqual(values, [b"1", b"2,3", b"4", b"5,,6,"])
+ self.assertAllEqual(shape, [2, 2])
+
+ def testSplitV2EmptySeparatorMaxSplit(self):
+ # Match Python behavior:
+ # '1 2 3'.split(maxsplit=1)
+ # ['1', '2 3']
+ # >>> " 4 5 6 ".split(maxsplit=1)
+ # ['4', '5 6 ']
+ strings = ["1 2 3", " 4 5 6 "]
+
+ with self.test_session() as sess:
+ tokens = string_ops.string_split_v2(strings, maxsplit=1)
+ indices, values, shape = sess.run(tokens)
+ self.assertAllEqual(indices, [[0, 0], [0, 1],
+ [1, 0], [1, 1]])
+ self.assertAllEqual(values, [b"1", b"2 3", b"4", b"5 6 "])
+ self.assertAllEqual(shape, [2, 2])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 1b935d5286..0b3a396d6b 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -150,7 +150,7 @@ class TemplateTest(test.TestCase):
# Parameters are tied, so the loss should have gone down after training.
self.assertLess(final_test_loss.numpy(), initial_test_loss.numpy())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_skip_stack_frames(self):
first = traceback.format_stack()
second = traceback.format_stack()
@@ -158,7 +158,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(1, len(result))
self.assertNotEqual(len(first), len(result))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_template_with_name(self):
tmpl1 = template.make_template("s1", variable_scoped_function)
tmpl2 = template.make_template("s1", variable_scoped_function)
@@ -204,7 +204,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(v1, v3)
self.assertEqual("s1/dummy:0", v1.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_template_in_scope(self):
tmpl1 = template.make_template("s1", variable_scoped_function)
tmpl2 = template.make_template("s1", variable_scoped_function)
@@ -221,7 +221,7 @@ class TemplateTest(test.TestCase):
self.assertEqual("scope/s1/dummy:0", v1.name)
self.assertEqual("scope/s1_1/dummy:0", v3.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_template_with_internal_reuse(self):
tmpl1 = template.make_template("s1", internally_variable_scoped_function)
tmpl2 = template.make_template("s1", internally_variable_scoped_function)
@@ -237,13 +237,13 @@ class TemplateTest(test.TestCase):
with self.assertRaises(ValueError):
tmpl1("not_test")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_template_without_name(self):
with self.assertRaisesRegexp(
ValueError, "name cannot be None."):
template.make_template(None, variable_scoped_function)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_make_template(self):
# Test both that we can call it with positional and keywords.
tmpl1 = template.make_template(
@@ -266,7 +266,7 @@ class TemplateTest(test.TestCase):
with self.assertRaises(ValueError):
tmpl()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_enforces_no_extra_trainable_variables_eager(self):
tmpl = template.make_template("s",
function_with_side_create,
@@ -287,7 +287,7 @@ class TemplateTest(test.TestCase):
trainable=False)
self.assertEqual(tmpl(name="1"), tmpl(name="2"))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_internal_variable_reuse(self):
def nested():
@@ -310,7 +310,7 @@ class TemplateTest(test.TestCase):
self.assertEqual("s1/nested/x:0", v1.name)
self.assertEqual("s1_1/nested/x:0", v3.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_nested_templates(self):
def nested_template():
@@ -360,7 +360,7 @@ class TemplateTest(test.TestCase):
self.assertEqual("nested", tmpl1._checkpoint_dependencies[0].name)
self.assertEqual("nested_1", tmpl1._checkpoint_dependencies[1].name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_nested_templates_with_defun(self):
def variable_scoped_function_no_return_value(trainable=True):
@@ -429,7 +429,7 @@ class TemplateTest(test.TestCase):
"a", partial, create_graph_function_=True)
self.assertAllEqual(tmpl(ops.convert_to_tensor(1.0)), 2.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_immediate_scope_creation(self):
# Create templates in scope a then call in scope b. make_template should
# capture the scope the first time it is called, and make_immediate_template
@@ -454,7 +454,7 @@ class TemplateTest(test.TestCase):
self.assertEqual("ctor_scope/a/dummy:0", inner_imm_var.name)
self.assertEqual("call_scope/b/dummy:0", inner_defer_var.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_scope_access(self):
# Ensure that we can access the scope inside the template, because the name
# of that scope may be different from the name we pass to make_template, due
@@ -479,7 +479,7 @@ class TemplateTest(test.TestCase):
# Template is called at the top level, so there is no preceding "foo_2".
self.assertEqual(tc.variable_scope.name, "blah")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_custom_getter(self):
# Custom getter that maintains call count and forwards to true getter
custom_getter_count = [0]
@@ -512,7 +512,7 @@ class TemplateTest(test.TestCase):
tmpl2()
self.assertEqual(custom_getter_count[0], 2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_fails_gracefully(self):
for create_scope_now in [True, False]:
def module_function_with_one_arg(inputs):
@@ -535,7 +535,7 @@ class TemplateTest(test.TestCase):
templatized_function(data)
self.assertTrue(templatized_function._variables_created)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_name_scopes_for_variable_scopes(self):
# Test that name scopes are not unnecessarily uniquified (but are
# still uniquified when necessary).
@@ -586,7 +586,7 @@ class TemplateTest(test.TestCase):
"Second application of template should also get "
"a freshly uniquified name scope.")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_global_variables(self):
# Make sure global_variables are created.
with variable_scope.variable_scope("foo"):
@@ -608,7 +608,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(1, len(ta.global_variables))
self.assertEqual(2, len(tb.global_variables))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_trainable_variables(self):
# Make sure trainable_variables are created.
with variable_scope.variable_scope("foo2"):
@@ -632,7 +632,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(1, len(ta.variables))
self.assertEqual(1, len(tb.variables))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_non_trainable_variables(self):
# Make sure non_trainable_variables are created.
with variable_scope.variable_scope("foo2"):
@@ -675,7 +675,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(0, len(ta.local_variables))
self.assertEqual(1, len(tb.local_variables))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_make_template_with_defun(self):
def variable_scoped_function_no_return_value(scope_name):
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index c0b36f143d..6de6fbe767 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -26,11 +26,13 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@@ -73,7 +75,7 @@ class TensorArrayTest(test.TestCase):
super(TensorArrayTest, cls).tearDownClass()
session_lib.Session.reset(cls._workers[0].target)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteRead(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -121,11 +123,11 @@ class TensorArrayTest(test.TestCase):
self._testTensorArrayWritePack(dtypes.complex128)
self._testTensorArrayWritePack(dtypes.string)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayWritePack(self):
self._testTensorArrayWritePackMaybeLegacy()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEmptyTensorArrayPack(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -159,7 +161,7 @@ class TensorArrayTest(test.TestCase):
convert([[4.0, 5.0], [104.0, 105.0], [204.0, 205.0], [6.0, 7.0],
[106.0, 107.0], [8.0, 9.0]]), c0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteConcat(self):
self._testTensorArrayWriteConcat(dtypes.float32)
self._testTensorArrayWriteConcat(dtypes.float64)
@@ -182,7 +184,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros()
@@ -198,7 +200,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros()
@@ -249,7 +251,7 @@ class TensorArrayTest(test.TestCase):
self._testTensorArrayUnpackRead(dtypes.complex128)
self._testTensorArrayUnpackRead(dtypes.string)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayUnpackRead(self):
self._testTensorArrayUnpackReadMaybeLegacy()
@@ -295,7 +297,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual(convert([]).reshape(0, 2), d1)
self.assertAllEqual(convert([[3.0, 301.0]]), d2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArraySplitRead(self):
self._testTensorArraySplitRead(dtypes.float32)
self._testTensorArraySplitRead(dtypes.float64)
@@ -395,7 +397,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual(t_g_ta_0, t_g_ta_1)
self.assertAllEqual([[4.0, 5.0]], d_r1_0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.test_session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
@@ -414,7 +416,7 @@ class TensorArrayTest(test.TestCase):
"resizeable and size is: 3"):
self.evaluate(ta.write(3, 3.0).flow)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayReadWrongIndexOrDataTypeFails(self):
with self.test_session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
@@ -448,7 +450,7 @@ class TensorArrayTest(test.TestCase):
"it has already been written to."):
self.evaluate(ta.write(2, 3.0).write(2, 3.0).flow)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayConcatIncompatibleShapesFails(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -480,7 +482,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError("shape"):
self.evaluate(w3.concat())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArraySplitIncompatibleShapesFails(self):
with self.test_session(use_gpu=True):
in_eager_mode = context.executing_eagerly()
@@ -549,7 +551,59 @@ class TensorArrayTest(test.TestCase):
dtypes.complex64, dtypes.complex128):
self._testTensorArrayWriteGradientAddMultipleAdds(dtype)
- @test_util.run_in_graph_and_eager_modes()
+ def testTensorArrayGradWithShapeKnownElementShape(self):
+ with self.test_session(use_gpu=True) as sess:
+ ta = tensor_array_ops.TensorArray(
+ size=3,
+ dtype=dtypes.float32,
+ element_shape=tensor_shape.TensorShape([2, 3]))
+ handle, flow = data_flow_ops.tensor_array_grad_with_shape(
+ handle=ta.handle,
+ flow_in=ta.flow,
+ shape_to_prepend=tensor_shape.TensorShape([4, 5]),
+ source="source")
+ ta_grad = tensor_array_ops.TensorArray(
+ dtypes.float32, handle=handle, flow=flow)
+ value = array_ops.placeholder(dtypes.float32)
+ ta_grad = ta_grad.write(0, value)
+ read_value = ta_grad.read(0)
+
+ # Make sure shape inference worked.
+ self.assertAllEqual([None, None, 2, 3], read_value.shape.as_list())
+ # Writing with wrong shape should not work.
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Could not write to TensorArray"):
+ fed_value = np.random.random([2, 3])
+ sess.run(read_value, feed_dict={value: fed_value})
+ # Writing with correct shape should work.
+ fed_value = np.random.random([4, 5, 2, 3])
+ self.assertAllClose(fed_value,
+ sess.run(read_value, feed_dict={value: fed_value}))
+
+ def testTensorArrayGradWithShapeUnknownElementShape(self):
+ with self.test_session(use_gpu=True) as sess:
+ ta = tensor_array_ops.TensorArray(
+ size=3, dtype=dtypes.float32,
+ element_shape=None) # Note that element_shape is unknown
+ handle, flow = data_flow_ops.tensor_array_grad_with_shape(
+ handle=ta.handle,
+ flow_in=ta.flow,
+ shape_to_prepend=tensor_shape.TensorShape([4, 5]),
+ source="source")
+ ta_grad = tensor_array_ops.TensorArray(
+ dtypes.float32, handle=handle, flow=flow)
+ value = array_ops.placeholder(dtypes.float32)
+ ta_grad = ta_grad.write(0, value)
+ read_value = ta_grad.read(0)
+
+ # Make sure shape inference worked.
+ self.assertIsNone(read_value.shape.ndims)
+ # Write with some shape and check read value.
+ fed_value = np.random.random([4, 5, 7])
+ self.assertAllClose(fed_value,
+ sess.run(read_value, feed_dict={value: fed_value}))
+
+ @test_util.run_in_graph_and_eager_modes
def testMultiTensorArray(self):
with self.test_session(use_gpu=True):
h1 = tensor_array_ops.TensorArray(
@@ -652,7 +706,7 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayGradientWritePackConcatAndRead(self):
self._testTensorArrayGradientWritePackConcatAndRead()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayReadTwice(self):
with self.test_session(use_gpu=True):
value = constant_op.constant([[1.0, -1.0], [10.0, -10.0]])
@@ -757,14 +811,14 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayGradientDynamicUnpackRead(self):
self._testTensorArrayGradientDynamicUnpackRead()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCloseTensorArray(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
self.evaluate(ta.close())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSizeTensorArray(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -772,7 +826,7 @@ class TensorArrayTest(test.TestCase):
s = ta.size()
self.assertAllEqual(3, self.evaluate(s))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWriteCloseTensorArray(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -870,7 +924,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllClose(grad_val.sum(axis=0), var_grad_t)
self.assertAllClose(grad_val.sum(axis=0), state0_grad_t)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWhileLoopWritePackGradients(self):
self._testWhileLoopWritePackGradients(
dynamic_size=False, dtype=dtypes.float32)
@@ -882,7 +936,7 @@ class TensorArrayTest(test.TestCase):
self._testWhileLoopWritePackGradients(
dynamic_size=True, dtype=dtypes.float32)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradSerialTwoLoops(self):
with self.test_session(use_gpu=True):
def loop(x):
@@ -1059,7 +1113,7 @@ class TensorArrayTest(test.TestCase):
r5 = w5.read(0)
self.assertAllEqual([5, 4, 2, 3], r5.get_shape().as_list())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def _testUnpackShape(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -1093,7 +1147,7 @@ class TensorArrayTest(test.TestCase):
def testUnpackShape(self):
self._testUnpackShape()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSplitShape(self):
with self.test_session(use_gpu=True):
ta = tensor_array_ops.TensorArray(
@@ -1235,7 +1289,7 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual([10.0, -10.0], read_vals[1])
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayWriteGatherAndGradients(self):
with self.test_session(use_gpu=True) as session:
ta = tensor_array_ops.TensorArray(
@@ -1379,7 +1433,7 @@ class TensorArrayTest(test.TestCase):
self.assertFalse(
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTensorArrayIdentity(self):
with self.test_session(use_gpu=True):
ta0 = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=2,
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 9dc4ec0f96..ae2a0ab29a 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -57,7 +57,7 @@ class VariableScopeTest(test.TestCase):
v1 = vs.get_variable("v", [1])
self.assertEqual(v, v1)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testResource(self):
vs = variable_scope._get_default_variable_store()
v1 = vs.get_variable("v", [1], use_resource=True)
@@ -87,7 +87,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(
set(expected_names), set([v.name for v in vs._vars.values()]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScopeInitializer(self):
init = init_ops.constant_initializer(0.3)
with variable_scope.variable_scope("tower0") as tower:
@@ -100,7 +100,7 @@ class VariableScopeTest(test.TestCase):
self.evaluate(variables_lib.variables_initializer([w]))
self.assertAllClose(self.evaluate(w.value()), 0.3)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScopeConstraint(self):
constraint = lambda x: 0. * x
with variable_scope.variable_scope("tower1") as tower:
@@ -117,7 +117,7 @@ class VariableScopeTest(test.TestCase):
variables_lib.global_variables_initializer().run()
self.assertAllEqual(compat.as_bytes(v.eval()), b"")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScopeDType(self):
with variable_scope.variable_scope("tower2") as tower:
with variable_scope.variable_scope("foo", dtype=dtypes.float16):
@@ -197,7 +197,33 @@ class VariableScopeTest(test.TestCase):
self.assertAllEqual([v1, v2], [v3, v4])
f()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
+ def testEagerVariablesStoreAddsToCollections(self):
+ store = variable_scope.EagerVariableStore()
+ with store.as_default():
+ trainable = variable_scope.get_variable("v1", [], trainable=True)
+ not_trainable = variable_scope.get_variable("v2", [], trainable=False)
+ concat = variable_scope.get_variable(
+ "v3", [], collections=[ops.GraphKeys.CONCATENATED_VARIABLES])
+ self.assertEqual(
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
+ [trainable, not_trainable])
+ self.assertEqual(
+ ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES),
+ [trainable, concat])
+ self.assertEqual(
+ ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES), [concat])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEagerVariablesOutsideStoreNotAddedToCollections(self):
+ if not context.executing_eagerly():
+ return
+ variable_scope.get_variable("v1", [], trainable=True)
+ variable_scope.get_variable("v2", [], trainable=False)
+ self.assertFalse(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ self.assertFalse(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+ @test_util.run_in_graph_and_eager_modes
def testInitFromNonTensorValue(self):
v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
self.evaluate(variables_lib.variables_initializer([v]))
@@ -213,7 +239,7 @@ class VariableScopeTest(test.TestCase):
with self.assertRaises(error):
variable_scope.get_variable("x4", initializer={})
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitFromNonInitializer(self):
# Test various dtypes with zeros initializer as following:
types = [
@@ -268,7 +294,7 @@ class VariableScopeTest(test.TestCase):
v_tower = variable_scope.get_variable("v", [])
self.assertFalse(v_tower.value().device.startswith(caching_device))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScopeRegularizer(self):
init = init_ops.constant_initializer(0.3)
@@ -313,7 +339,7 @@ class VariableScopeTest(test.TestCase):
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(3, len(losses)) # No new loss added.
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInitializeFromValue(self):
init = constant_op.constant(0.1)
w = variable_scope.get_variable("v", initializer=init)
@@ -402,7 +428,7 @@ class VariableScopeTest(test.TestCase):
sess.run(v0.initializer)
sess.run(add)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetVariableScope(self):
# Test the get_variable_scope() function and setting properties of result.
init = init_ops.constant_initializer(0.3)
@@ -423,7 +449,7 @@ class VariableScopeTest(test.TestCase):
new_init = variable_scope.get_variable_scope().initializer
self.assertEqual(new_init, None)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScope(self):
with variable_scope.variable_scope("tower4") as tower:
self.assertEqual(tower.name, "tower4")
@@ -442,7 +468,7 @@ class VariableScopeTest(test.TestCase):
with ops.name_scope("scope") as sc:
self.assertEqual(sc, "tower6/tower4/scope/")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testVarScopeNameScope(self):
with ops.name_scope("testVarScopeNameScope1"):
with variable_scope.variable_scope("tower") as tower:
@@ -935,7 +961,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(
constant_op.constant([], name="c").name, "another/inner/c:0")
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGetLocalVar(self):
# Check that local variable respects naming.
with variable_scope.variable_scope("outer") as outer:
@@ -1028,7 +1054,7 @@ class VariableScopeTest(test.TestCase):
"testGetCollection_foo/testGetCollection_a:0"
])
- def testGetTrainableVariables(self):
+ def testGetTrainableVariablesWithGetVariable(self):
with self.test_session():
_ = variable_scope.get_variable("testGetTrainableVariables_a", [])
with variable_scope.variable_scope(
@@ -1036,10 +1062,72 @@ class VariableScopeTest(test.TestCase):
_ = variable_scope.get_variable("testGetTrainableVariables_b", [])
_ = variable_scope.get_variable(
"testGetTrainableVariables_c", [], trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_d", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
+ self.assertEqual(
+ [v.name for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values sets trainable=True
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
+ def testGetTrainableVariablesWithVariable(self):
+ with self.test_session():
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b")
+ _ = variable_scope.variable(
+ 1.0, name="testGetTrainableVariables_c", trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_d",
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
self.assertEqual(
[v.name for v in scope.trainable_variables()],
- ["testGetTrainableVariables_foo/"
- "testGetTrainableVariables_b:0"])
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values sets trainable=True
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
def testGetGlobalVariables(self):
with self.test_session():
@@ -1227,6 +1315,31 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(v3, v4)
self.assertEqual(3, called[0]) # skipped one in the first new_scope
+ def testSynchronizationAndAggregationWithCustomGetter(self):
+ called = [0]
+ synchronization = variable_scope.VariableSynchronization.AUTO
+ aggregation = variable_scope.VariableAggregation.NONE
+
+ def custom_getter(getter, *args, **kwargs):
+ called[0] += 1
+
+ # Verify synchronization and aggregation kwargs are as expected.
+ self.assertEqual(kwargs["synchronization"], synchronization)
+ self.assertEqual(kwargs["aggregation"], aggregation)
+ return getter(*args, **kwargs)
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ variable_scope.get_variable("v", [1])
+ self.assertEqual(1, called[0])
+
+ with variable_scope.variable_scope("scope", custom_getter=custom_getter):
+ synchronization = variable_scope.VariableSynchronization.ON_READ
+ aggregation = variable_scope.VariableAggregation.MEAN
+ variable_scope.get_variable(
+ "v1", [1], synchronization=synchronization, aggregation=aggregation)
+
+ self.assertEqual(2, called[0])
+
def testCustomGetterWithReuse(self):
# Custom getter can choose to behave differently on reused variables.
def custom_getter(getter, *args, **kwargs):
@@ -1329,6 +1442,23 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertAllEqual(variable_names, ["forced_name"])
+ called = [False]
+
+ def creater_c(next_creator, **kwargs):
+ called[0] = True
+ self.assertEqual(kwargs["synchronization"],
+ variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual(kwargs["aggregation"],
+ variable_scope.VariableAggregation.MEAN)
+ return next_creator(**kwargs)
+
+ with variable_scope.variable_creator_scope(creater_c):
+ variable_scope.get_variable(
+ "v", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN)
+ self.assertTrue(called[0])
+
class PartitionInfoTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 62d596da91..2b9c62ad6f 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -642,6 +642,8 @@ class PartitionedVariableTest(test.TestCase):
iterated_partitions = list(partitioned_variable)
self.assertEqual(2, num_partitions)
self.assertEqual([v0, v1], iterated_partitions)
+ self.assertEqual([2], partitioned_variable.get_shape())
+ self.assertEqual([2], partitioned_variable.shape)
self.assertEqual([2], concatenated.get_shape())
self.assertEqual([2], concatenated.shape)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index eda036ece4..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_weight(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=None,
+ constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -190,8 +207,22 @@ class Layer(base_layer.Layer):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
+ ValueError: When trainable has been set to True with synchronization
+ set as `ON_READ`.
"""
-
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
for var in variable:
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
constraint=constraint,
partitioner=partitioner,
use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
getter=vs.get_variable)
if regularizer:
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index ab49e37b90..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.platform import test
class BaseLayerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLayerProperties(self):
layer = base_layers.Layer(name='my_layer')
self.assertEqual(layer.variables, [])
@@ -53,13 +53,13 @@ class BaseLayerTest(test.TestCase):
layer = base_layers.Layer(name='my_layer', trainable=False)
self.assertEqual(layer.trainable, False)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInt64Layer(self):
layer = base_layers.Layer(name='my_layer', dtype='int64')
layer.add_variable('my_var', [2, 2])
self.assertEqual(layer.name, 'my_layer')
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAddWeight(self):
layer = base_layers.Layer(name='my_layer')
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
# regularizers only supported in GRAPH mode.
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_var', [2, 2],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ # Test that sync `ON_READ` variables are defaulted to be non-trainable.
+ variable_3 = layer.add_variable(
+ 'sync_on_read_var', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+ def testInvalidTrainableSynchronizationCombination(self):
+ layer = base_layers.Layer(name='my_layer')
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.'):
+ _ = layer.add_variable(
+ 'v', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
def testReusePartitionedVaraiblesAndRegularizers(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
partitioner=partitioner,
reuse=reuse):
layer = base_layers.Layer(name='my_layer')
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_part_var', [4, 4],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
@@ -116,7 +138,7 @@ class BaseLayerTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):
core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCall(self):
class MyLayer(base_layers.Layer):
@@ -132,7 +154,7 @@ class BaseLayerTest(test.TestCase):
# op is only supported in GRAPH mode
self.assertEqual(outputs.op.name, 'my_layer/Square')
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDeepCopy(self):
class MyLayer(base_layers.Layer):
@@ -155,7 +177,7 @@ class BaseLayerTest(test.TestCase):
self.assertEqual(layer_copy._graph, layer._graph)
self.assertEqual(layer_copy._private_tensor, layer._private_tensor)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testScopeNaming(self):
class PrivateLayer(base_layers.Layer):
@@ -203,7 +225,7 @@ class BaseLayerTest(test.TestCase):
my_layer_scoped1.apply(inputs)
self.assertEqual(my_layer_scoped1._scope.name, 'var_scope/my_layer_1')
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecNdimCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -230,7 +252,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant([[1], [2]]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecMinNdimCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -258,7 +280,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant([[[1], [2]]]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecMaxNdimCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -286,7 +308,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant([[1], [2]]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecDtypeCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -306,7 +328,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant(1.0, dtype=dtypes.float32))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecAxesCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -328,7 +350,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant([[1, 2], [3, 4], [5, 6]]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testInputSpecShapeCheck(self):
class CustomerLayer(base_layers.Layer):
@@ -348,7 +370,7 @@ class BaseLayerTest(test.TestCase):
layer = CustomerLayer()
layer.apply(constant_op.constant([[1, 2, 3], [4, 5, 6]]))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoInputSpec(self):
class CustomerLayer(base_layers.Layer):
@@ -369,7 +391,7 @@ class BaseLayerTest(test.TestCase):
layer.apply(array_ops.placeholder('int32'))
layer.apply(array_ops.placeholder('int32', shape=(2, 3)))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_count_params(self):
dense = core_layers.Dense(16)
dense.build((None, 4))
@@ -379,7 +401,7 @@ class BaseLayerTest(test.TestCase):
with self.assertRaises(ValueError):
dense.count_params()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDictInputOutput(self):
class DictLayer(base_layers.Layer):
@@ -589,6 +611,5 @@ class BaseLayerTest(test.TestCase):
ValueError, 'Input graph and Layer graph are not the same'):
layer.apply(constant_op.constant([[1.]]))
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 267d78dbcb..36cef3855e 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -217,7 +217,6 @@ def conv1d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -421,7 +420,6 @@ def conv2d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -627,7 +625,6 @@ def conv3d(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -1266,7 +1263,6 @@ def conv2d_transpose(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
@@ -1438,7 +1434,6 @@ def conv3d_transpose(inputs,
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index abbacac442..aadff231da 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -184,7 +184,6 @@ def dense(
bias_constraint=bias_constraint,
trainable=trainable,
name=name,
- dtype=inputs.dtype.base_dtype,
_scope=name,
_reuse=reuse)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index cf45b07637..040c1cddc0 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -41,7 +41,7 @@ from tensorflow.python.platform import test
class DenseTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDenseProperties(self):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
self.assertEqual(dense.units, 2)
@@ -91,14 +91,14 @@ class DenseTest(test.TestCase):
core_layers.Dense(5)(inputs)
core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')(inputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCallTensorDot(self):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='my_dense')
inputs = random_ops.random_uniform((5, 4, 3), seed=1)
outputs = dense(inputs)
self.assertListEqual([5, 4, 2], outputs.get_shape().as_list())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoBias(self):
dense = core_layers.Dense(2, use_bias=False, name='my_dense')
inputs = random_ops.random_uniform((5, 2), seed=1)
@@ -112,7 +112,7 @@ class DenseTest(test.TestCase):
self.assertEqual(dense.kernel.name, 'my_dense/kernel:0')
self.assertEqual(dense.bias, None)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNonTrainable(self):
dense = core_layers.Dense(2, trainable=False, name='my_dense')
inputs = random_ops.random_uniform((5, 2), seed=1)
@@ -125,7 +125,7 @@ class DenseTest(test.TestCase):
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOutputShape(self):
dense = core_layers.Dense(7, activation=nn_ops.relu, name='my_dense')
inputs = random_ops.random_uniform((5, 3), seed=1)
@@ -165,7 +165,7 @@ class DenseTest(test.TestCase):
dense = core_layers.Dense(4, name='my_dense')
dense(inputs)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testActivation(self):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
inputs = random_ops.random_uniform((5, 3), seed=1)
@@ -325,7 +325,7 @@ class DenseTest(test.TestCase):
var_key = 'test2/dense/kernel'
self.assertEqual(var_dict[var_key].name, '%s:0' % var_key)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testComputeOutputShape(self):
dense = core_layers.Dense(2, activation=nn_ops.relu, name='dense1')
ts = tensor_shape.TensorShape
@@ -347,7 +347,7 @@ class DenseTest(test.TestCase):
dense.compute_output_shape(ts([None, 4, 3])).as_list())
# pylint: enable=protected-access
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testConstraints(self):
k_constraint = lambda x: x / math_ops.reduce_sum(x)
b_constraint = lambda x: x / math_ops.reduce_max(x)
@@ -369,7 +369,7 @@ def _get_variable_dict_from_varstore():
class DropoutTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDropoutProperties(self):
dp = core_layers.Dropout(0.5, name='dropout')
self.assertEqual(dp.rate, 0.5)
@@ -377,7 +377,7 @@ class DropoutTest(test.TestCase):
dp.apply(array_ops.ones(()))
self.assertEqual(dp.name, 'dropout')
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBooleanLearningPhase(self):
dp = core_layers.Dropout(0.5)
inputs = array_ops.ones((5, 3))
@@ -402,7 +402,7 @@ class DropoutTest(test.TestCase):
np_output = sess.run(dropped, feed_dict={training: False})
self.assertAllClose(np.ones((5, 5)), np_output)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDynamicNoiseShape(self):
inputs = array_ops.ones((5, 3, 2))
noise_shape = [None, 1, None]
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index d082e312e9..f7bc10a6a6 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -44,7 +44,7 @@ class BatchNormalization(keras_layers.BatchNormalization, base.Layer):
normalized, typically the features axis/axes. For instance, after a
`Conv2D` layer with `data_format="channels_first"`, set `axis=1`. If a
list of axes is provided, each axis in `axis` will be normalized
- simultaneously. Default is `-1` which takes uses last axis. Note: when
+ simultaneously. Default is `-1` which uses the last axis. Note: when
using multi-axis batch norm, the `beta`, `gamma`, `moving_mean`, and
`moving_variance` variables are the same rank as the input Tensor, with
dimension size 1 in all reduced (non-axis) dimensions).
@@ -308,7 +308,6 @@ def batch_normalization(inputs,
virtual_batch_size=virtual_batch_size,
adjustment=adjustment,
name=name,
- dtype=inputs.dtype.base_dtype,
_reuse=reuse,
_scope=name)
return layer.apply(inputs, training=training)
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index 77fa2c1f66..fde3a83770 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -446,6 +446,16 @@ npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
return x != static_cast<bfloat16>(0);
}
+int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
+ bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
+ const float start(buffer[0]);
+ const float delta = static_cast<float>(buffer[1]) - start;
+ for (npy_intp i = 2; i < length; ++i) {
+ buffer[i] = static_cast<bfloat16>(start + i * delta);
+ }
+ return 0;
+}
+
// NumPy casts
// Performs a NumPy array cast from type 'From' to 'To'.
@@ -548,6 +558,7 @@ bool Initialize() {
NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
+ NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
npy_bfloat16_ = PyArray_RegisterDataType(&NPyBfloat16_Descr);
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index 09d4b01fa4..bc928cd9e5 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -245,6 +245,20 @@ class Bfloat16NumPyTest(test.TestCase):
np.logaddexp(x.astype(bfloat16), y.astype(bfloat16)),
atol=2e-2)
+ def testArange(self):
+ self.assertAllEqual(
+ np.arange(100, dtype=np.float32).astype(bfloat16),
+ np.arange(100, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
+ np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
+ np.arange(-0., -7., -0.25, dtype=bfloat16))
+ self.assertAllEqual(
+ np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
+ np.arange(-16384., 16384., 64., dtype=bfloat16))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc
index 9df38d464c..ec1ba7b8f7 100644
--- a/tensorflow/python/lib/core/ndarray_tensor.cc
+++ b/tensorflow/python/lib/core/ndarray_tensor.cc
@@ -312,6 +312,40 @@ Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
return Status::OK();
}
+
+inline void FastMemcpy(void* dst, const void* src, size_t size) {
+ // clang-format off
+ switch (size) {
+ // Most compilers will generate inline code for fixed sizes,
+ // which is significantly faster for small copies.
+ case 1: memcpy(dst, src, 1); break;
+ case 2: memcpy(dst, src, 2); break;
+ case 3: memcpy(dst, src, 3); break;
+ case 4: memcpy(dst, src, 4); break;
+ case 5: memcpy(dst, src, 5); break;
+ case 6: memcpy(dst, src, 6); break;
+ case 7: memcpy(dst, src, 7); break;
+ case 8: memcpy(dst, src, 8); break;
+ case 9: memcpy(dst, src, 9); break;
+ case 10: memcpy(dst, src, 10); break;
+ case 11: memcpy(dst, src, 11); break;
+ case 12: memcpy(dst, src, 12); break;
+ case 13: memcpy(dst, src, 13); break;
+ case 14: memcpy(dst, src, 14); break;
+ case 15: memcpy(dst, src, 15); break;
+ case 16: memcpy(dst, src, 16); break;
+#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX) && \
+ !defined(IS_MOBILE_PLATFORM)
+ // On Linux, memmove appears to be faster than memcpy for
+ // large sizes, strangely enough.
+ default: memmove(dst, src, size); break;
+#else
+ default: memcpy(dst, src, size); break;
+#endif
+ }
+ // clang-format on
+}
+
} // namespace
// Converts the given TF_Tensor to a numpy ndarray.
@@ -362,8 +396,8 @@ Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
" bytes but TF_Tensor was ",
TF_TensorByteSize(tensor.get()), " bytes");
} else {
- memcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
- PyArray_NBYTES(py_array));
+ FastMemcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
+ PyArray_NBYTES(py_array));
}
// PyArray_Return turns rank 0 arrays into numpy scalars
@@ -377,7 +411,7 @@ Status PyArrayToTF_Tensor(PyObject* ndarray, Safe_TF_TensorPtr* out_tensor) {
// Make sure we dereference this array object in case of error, etc.
Safe_PyObjectPtr array_safe(make_safe(
- PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr)));
+ PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr)));
if (!array_safe) return errors::InvalidArgument("Not a ndarray.");
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
diff --git a/tensorflow/python/lib/core/numpy.h b/tensorflow/python/lib/core/numpy.h
index 25322b458b..d4621d61ee 100644
--- a/tensorflow/python/lib/core/numpy.h
+++ b/tensorflow/python/lib/core/numpy.h
@@ -29,7 +29,9 @@ limitations under the License.
#define NO_IMPORT_ARRAY
#endif
+// Place `<locale>` before <Python.h> to avoid build failure in macOS.
#include <Python.h>
+#include <locale>
#include "numpy/arrayobject.h"
#include "numpy/ufuncobject.h"
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 30c1a9c759..57139986af 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -55,37 +55,35 @@ struct PyCall {
string token;
// The device on which Tensors are stored; only used for EagerPyFunc.
- Device* device;
-
- // True if and only if the op has been placed on a GPU.
- bool gpu;
+ Device* device = nullptr;
// True if the call is associated with an EagerPyFunc.
- bool eager;
+ bool eager = false;
// Inputs and outputs of this function invocation.
std::vector<Tensor> ins;
std::vector<Tensor> out;
};
+bool IsCPUDevice(const Device* d) {
+ return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
+}
+
// Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
int64 n = call->ins.size();
PyObject* lst = PyList_New(n);
CHECK(lst);
+ // TFE_TensorHandle assumes that CPU is identified by nullptr.
+ Device* device = IsCPUDevice(call->device) ? nullptr : call->device;
for (int64 i = 0; i < n; ++i) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
if (call->eager) {
- if (call->gpu) {
- arg = EagerTensorFromHandle(
- new TFE_TensorHandle(t, call->device, call->device));
- } else {
- // TFE_TensorHandle assumes that CPU is identified by `nullptr`.
- arg = EagerTensorFromHandle(new TFE_TensorHandle(t, nullptr, nullptr));
- }
+ arg = EagerTensorFromHandle(new TFE_TensorHandle(t, device, device));
if (arg == nullptr) {
+ Py_DECREF(lst);
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
} else {
@@ -97,8 +95,9 @@ Status MakeArgTuple(const PyCall* call, PyObject** tuple) {
}
PyList_SetItem(lst, i, arg);
}
- *tuple = Py_BuildValue("(sON)", call->token.c_str(),
- call->gpu ? Py_True : Py_False, lst);
+ const char* device_name =
+ device == nullptr ? nullptr : device->attributes().name().c_str();
+ *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst);
CHECK(*tuple);
return Status::OK();
}
@@ -167,9 +166,40 @@ bool IsSingleNone(PyObject* obj) {
}
// Retrieves a Tensor from `eager_tensor` and stores it in `output_tensor`.
+// Validates that `output_tensor` is backed by memory in `expected_device`
+// (which is assumed to be a local device, one on which the kernel was
+// executed.)
+//
+// It may be nice to copy the tensor to the right device instead of failing if
+// it isn't already there. This is left as a future exercise. The required
+// device-copying logic is implemented in Python at the moment.
tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
+ const Device* expected_device,
const Tensor** output_tensor) {
- return EagerTensor_Handle(eager_tensor)->handle->Tensor(output_tensor);
+ auto handle = EagerTensor_Handle(eager_tensor)->handle;
+ Device* actual_device = nullptr;
+ TF_RETURN_IF_ERROR(handle->Device(&actual_device));
+ TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
+ // actual_device may be nullptr, which implies local CPU.
+ if (expected_device == actual_device) return Status::OK();
+ const string& expected_device_name = expected_device->attributes().name();
+ if (actual_device == nullptr) {
+ if (!IsCPUDevice(expected_device)) {
+ return errors::Internal(
+ "expected the py_func to return a Tensor backed by memory in ",
+ expected_device_name,
+ ", but is actually backed by local host memory. This is a bug.");
+ }
+ return Status::OK();
+ }
+ const string& actual_device_name = actual_device->attributes().name();
+ if (actual_device_name != expected_device_name) {
+ return errors::Internal(
+ "expected the py_func to return a Tensor backed by memory in ",
+ expected_device_name, ", but is actually in ", actual_device_name,
+ ". This is a bug.");
+ }
+ return Status::OK();
}
// Calls the registered py function through the trampoline.
@@ -224,7 +254,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
const PyObject* item = PyList_GetItem(result, i);
if (EagerTensor_CheckExact(item)) {
const Tensor* tensor = nullptr;
- s = ExtractTensorFromEagerTensor(item, &tensor);
+ s = ExtractTensorFromEagerTensor(item, call->device, &tensor);
if (s.ok()) t = *tensor;
} else {
s = errors::FailedPrecondition(
@@ -245,7 +275,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
DCHECK(call->eager);
if (result != Py_None) {
const Tensor* t = nullptr;
- s = ExtractTensorFromEagerTensor(result, &t);
+ s = ExtractTensorFromEagerTensor(result, call->device, &t);
if (s.ok()) call->out.push_back(*t);
}
} else if (PyArray_Check(result)) {
@@ -449,13 +479,11 @@ class PyFuncOp : public OpKernel {
explicit PyFuncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("token", &token_));
eager_ = type_string() == "EagerPyFunc";
- gpu_ = ctx->device_type().type_string() == DEVICE_GPU;
}
void Compute(OpKernelContext* ctx) override {
PyCall call;
call.token = token_;
- call.gpu = gpu_;
call.eager = eager_;
if (call.eager) {
// Eager's C API uses `Device`, whereas `OpKernelContext` stores a
@@ -464,6 +492,7 @@ class PyFuncOp : public OpKernel {
if (call.device == nullptr) {
ctx->CtxFailureWithWarning(
errors::Internal("Unrecognized device class"));
+ return;
}
}
@@ -508,9 +537,6 @@ class PyFuncOp : public OpKernel {
private:
string token_;
- // True if and only if this op has been placed on a GPU.
- bool gpu_;
-
// True if and only if this op should execute the python function eagerly,
// i.e., if and only if the eager attribute is set.
bool eager_;
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 386be35ba2..3b4f12ae31 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -88,6 +88,41 @@ bool IsPyDimension(PyObject* obj) {
return ret;
}
+// Sets *elem to a NEW reference to an element in seq on success.
+// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0.
+Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
+ *elem = PySequence_GetItem(seq, 0);
+ if (*elem != nullptr) return Status::OK();
+ // seq may implement the sequence protocol (i.e., implement __getitem__)
+ // but may legitimately not have a 0-th element (__getitem__(self, 0)
+ // raises a KeyError). For example:
+ // seq = pandas.Series([0, 1, 2], index=[2, 4, 6])
+ //
+ // We don't actually care for the element at key 0, any element will do
+ // for inferring the element types. All elements are expected to
+ // have the same type, and this will be validated when converting
+ // to an EagerTensor.
+ PyErr_Clear();
+ Safe_PyObjectPtr iter(PyObject_GetIter(seq));
+ if (PyErr_Occurred()) {
+ return errors::InvalidArgument("Cannot infer dtype of a ",
+ Py_TYPE(seq)->tp_name,
+ " object: ", PyExceptionFetch());
+ }
+ *elem = PyIter_Next(iter.get());
+ if (PyErr_Occurred()) {
+ return errors::InvalidArgument(
+ "Cannot infer dtype of a ", Py_TYPE(seq)->tp_name,
+ " object, as iter(<object>).next() failed: ", PyExceptionFetch());
+ }
+ if (*elem == nullptr) {
+ return errors::InvalidArgument("Cannot infer dtype of a ",
+ Py_TYPE(seq)->tp_name,
+ " object since it is an empty sequence");
+ }
+ return Status::OK();
+}
+
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
std::vector<Safe_PyObjectPtr> refs_to_clean;
while (true) {
@@ -98,7 +133,9 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
auto length = PySequence_Length(obj);
if (length > 0) {
shape->AddDim(length);
- obj = PySequence_GetItem(obj, 0);
+ PyObject* elem = nullptr;
+ TF_RETURN_IF_ERROR(SampleElementFromSequence(obj, &elem));
+ obj = elem;
refs_to_clean.push_back(make_safe(obj));
continue;
} else if (length == 0) {
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
index dcda1f4a44..6b6c82015f 100644
--- a/tensorflow/python/lib/core/py_util.cc
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/py_util.h"
+// Place `<locale>` before <Python.h> to avoid build failure in macOS.
#include <Python.h>
+#include <locale>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
diff --git a/tensorflow/python/lib/io/tf_record_test.py b/tensorflow/python/lib/io/tf_record_test.py
new file mode 100644
index 0000000000..dcc1a25f42
--- /dev/null
+++ b/tensorflow/python/lib/io/tf_record_test.py
@@ -0,0 +1,322 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf_record.TFRecordWriter and tf_record.tf_record_iterator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+import six
+
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+prefix_path = "third_party/tensorflow/core/lib"
+
+# pylint: disable=invalid-name
+TFRecordCompressionType = tf_record.TFRecordCompressionType
+# pylint: enable=invalid-name
+
+# Edgar Allan Poe's 'Eldorado'
+_TEXT = b"""Gaily bedight,
+ A gallant knight,
+ In sunshine and in shadow,
+ Had journeyed long,
+ Singing a song,
+ In search of Eldorado.
+
+ But he grew old
+ This knight so bold
+ And o'er his heart a shadow
+ Fell as he found
+ No spot of ground
+ That looked like Eldorado.
+
+ And, as his strength
+ Failed him at length,
+ He met a pilgrim shadow
+ 'Shadow,' said he,
+ 'Where can it be
+ This land of Eldorado?'
+
+ 'Over the Mountains
+ Of the Moon'
+ Down the Valley of the Shadow,
+ Ride, boldly ride,'
+ The shade replied,
+ 'If you seek for Eldorado!'
+ """
+
+
+class TFCompressionTestCase(test.TestCase):
+
+ def setUp(self):
+ super(TFCompressionTestCase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ def _Record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _CreateFiles(self, options=None, prefix=""):
+ filenames = []
+ for i in range(self._num_files):
+ name = prefix + "tfrecord.%d.txt" % i
+ records = [self._Record(i, j) for j in range(self._num_records)]
+ fn = self._WriteRecordsToFile(records, name, options)
+ filenames.append(fn)
+ return filenames
+
+ def _WriteRecordsToFile(self, records, name="tfrecord", options=None):
+ fn = os.path.join(self.get_temp_dir(), name)
+ with tf_record.TFRecordWriter(fn, options=options) as writer:
+ for r in records:
+ writer.write(r)
+ return fn
+
+ def _ZlibCompressFile(self, infile, name="tfrecord.z"):
+ # zlib compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = zlib.compress(f.read())
+
+ zfn = os.path.join(self.get_temp_dir(), name)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ return zfn
+
+ def _GzipCompressFile(self, infile, name="tfrecord.gz"):
+ # gzip compress the file and write compressed contents to file.
+ with open(infile, "rb") as f:
+ cdata = f.read()
+
+ gzfn = os.path.join(self.get_temp_dir(), name)
+ with gzip.GzipFile(gzfn, "wb") as f:
+ f.write(cdata)
+ return gzfn
+
+ def _ZlibDecompressFile(self, infile, name="tfrecord"):
+ with open(infile, "rb") as f:
+ cdata = zlib.decompress(f.read())
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+ def _GzipDecompressFile(self, infile, name="tfrecord"):
+ with gzip.GzipFile(infile, "rb") as f:
+ cdata = f.read()
+ fn = os.path.join(self.get_temp_dir(), name)
+ with open(fn, "wb") as f:
+ f.write(cdata)
+ return fn
+
+
+class TFRecordWriterTest(TFCompressionTestCase):
+
+ def setUp(self):
+ super(TFRecordWriterTest, self).setUp()
+
+ def _AssertFilesEqual(self, a, b, equal):
+ for an, bn in zip(a, b):
+ with open(an, "rb") as af, open(bn, "rb") as bf:
+ if equal:
+ self.assertEqual(af.read(), bf.read())
+ else:
+ self.assertNotEqual(af.read(), bf.read())
+
+ def testWriteReadZLibFiles(self):
+ # Write uncompressed then compress manually.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
+ files = self._CreateFiles(options, prefix="uncompressed")
+ zlib_files = [
+ self._ZlibCompressFile(fn, "tfrecord_%s.z" % i)
+ for i, fn in enumerate(files)
+ ]
+ self._AssertFilesEqual(files, zlib_files, False)
+
+ # Now write compressd and verify same.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ compressed_files = self._CreateFiles(options, prefix="compressed")
+ self._AssertFilesEqual(compressed_files, zlib_files, True)
+
+ # Decompress compress and verify same.
+ uncompressed_files = [
+ self._ZlibDecompressFile(fn, "tfrecord_%s.z" % i)
+ for i, fn in enumerate(compressed_files)
+ ]
+ self._AssertFilesEqual(uncompressed_files, files, True)
+
+ def testWriteReadGzipFiles(self):
+ # Write uncompressed then compress manually.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.NONE)
+ files = self._CreateFiles(options, prefix="uncompressed")
+ gzip_files = [
+ self._GzipCompressFile(fn, "tfrecord_%s.gz" % i)
+ for i, fn in enumerate(files)
+ ]
+ self._AssertFilesEqual(files, gzip_files, False)
+
+ # Now write compressd and verify same.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ compressed_files = self._CreateFiles(options, prefix="compressed")
+
+ # Note: Gzips written by TFRecordWriter add 'tfrecord_0' so
+ # compressed_files can't be compared with gzip_files
+
+ # Decompress compress and verify same.
+ uncompressed_files = [
+ self._GzipDecompressFile(fn, "tfrecord_%s.gz" % i)
+ for i, fn in enumerate(compressed_files)
+ ]
+ self._AssertFilesEqual(uncompressed_files, files, True)
+
+
+class TFRecordWriterZlibTest(TFCompressionTestCase):
+
+ def testZLibFlushRecord(self):
+ original = [b"small record"]
+ fn = self._WriteRecordsToFile(original, "small_record")
+ with open(fn, "rb") as h:
+ buff = h.read()
+
+ # creating more blocks and trailing blocks shouldn't break reads
+ compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)
+
+ output = b""
+ for c in buff:
+ if isinstance(c, int):
+ c = six.int2byte(c)
+ output += compressor.compress(c)
+ output += compressor.flush(zlib.Z_FULL_FLUSH)
+
+ output += compressor.flush(zlib.Z_FULL_FLUSH)
+ output += compressor.flush(zlib.Z_FULL_FLUSH)
+ output += compressor.flush(zlib.Z_FINISH)
+
+ # overwrite the original file with the compressed data
+ with open(fn, "wb") as h:
+ h.write(output)
+
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ actual = list(tf_record.tf_record_iterator(fn, options=options))
+ self.assertEqual(actual, original)
+
+ def testZlibReadWrite(self):
+ """Verify that files produced are zlib compatible."""
+ original = [b"foo", b"bar"]
+ fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord")
+ zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z")
+
+ # read the compressed contents and verify.
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ actual = list(tf_record.tf_record_iterator(zfn, options=options))
+ self.assertEqual(actual, original)
+
+ def testZlibReadWriteLarge(self):
+ """Verify that writing large contents also works."""
+
+ # Make it large (about 5MB)
+ original = [_TEXT * 10240]
+ fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
+ zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")
+
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ actual = list(tf_record.tf_record_iterator(zfn, options=options))
+ self.assertEqual(actual, original)
+
+ def testGzipReadWrite(self):
+ """Verify that files produced are gzip compatible."""
+ original = [b"foo", b"bar"]
+ fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")
+ gzfn = self._GzipCompressFile(fn, "tfrecord.gz")
+
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ actual = list(tf_record.tf_record_iterator(gzfn, options=options))
+ self.assertEqual(actual, original)
+
+
+class TFRecordIteratorTest(TFCompressionTestCase):
+
+ def setUp(self):
+ super(TFRecordIteratorTest, self).setUp()
+ self._num_records = 7
+
+ def testIterator(self):
+ records = [self._Record(0, i) for i in range(self._num_records)]
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(records, "compressed_records", options)
+
+ reader = tf_record.tf_record_iterator(fn, options)
+ for expected in records:
+ record = next(reader)
+ self.assertAllEqual(expected, record)
+ with self.assertRaises(StopIteration):
+ record = next(reader)
+
+ def testWriteZlibRead(self):
+ """Verify compression with TFRecordWriter is zlib library compatible."""
+ original = [b"foo", b"bar"]
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(original, "write_zlib_read.tfrecord.z",
+ options)
+
+ zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
+ actual = list(tf_record.tf_record_iterator(zfn))
+ self.assertEqual(actual, original)
+
+ def testWriteZlibReadLarge(self):
+ """Verify compression for large records is zlib library compatible."""
+ # Make it large (about 5MB)
+ original = [_TEXT * 10240]
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
+ fn = self._WriteRecordsToFile(original, "write_zlib_read_large.tfrecord.z",
+ options)
+ zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tfrecord")
+ actual = list(tf_record.tf_record_iterator(zfn))
+ self.assertEqual(actual, original)
+
+ def testWriteGzipRead(self):
+ original = [b"foo", b"bar"]
+ options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
+ fn = self._WriteRecordsToFile(original, "write_gzip_read.tfrecord.gz",
+ options)
+
+ gzfn = self._GzipDecompressFile(fn, "write_gzip_read.tfrecord")
+ actual = list(tf_record.tf_record_iterator(gzfn))
+ self.assertEqual(actual, original)
+
+ def testBadFile(self):
+ """Verify that tf_record_iterator throws an exception on bad TFRecords."""
+ fn = os.path.join(self.get_temp_dir(), "bad_file")
+ with tf_record.TFRecordWriter(fn) as writer:
+ writer.write(b"123")
+ fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
+ with open(fn, "rb") as f:
+ with open(fn_truncated, "wb") as f2:
+ # DataLossError requires that we've written the header, so this must
+ # be at least 12 bytes.
+ f2.write(f.read(14))
+ with self.assertRaises(errors_impl.DataLossError):
+ for _ in tf_record.tf_record_iterator(fn_truncated):
+ pass
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3678bd4c1f..fe459a96b9 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -568,7 +568,6 @@ ops.NotDifferentiable("Size")
@ops.RegisterGradient("Tile")
def _TileGrad(op, grad):
"""Sum reduces grad along the tiled dimensions."""
- assert isinstance(grad, ops.Tensor)
input_shape = array_ops.shape(op.inputs[0])
# We interleave multiples and input_shape to get split_shape,
# reshape grad to split_shape, and reduce along all even
@@ -581,6 +580,13 @@ def _TileGrad(op, grad):
split_shape = array_ops.reshape(
array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
axes = math_ops.range(0, array_ops.size(split_shape), 2)
+ # Sum reduces grad along the first dimension for IndexedSlices
+ if isinstance(grad, ops.IndexedSlices):
+ grad = math_ops.unsorted_segment_sum(
+ grad.values,
+ math_ops.mod(grad.indices, input_shape[0]),
+ input_shape[0])
+ split_shape = array_ops.concat([[1], split_shape[1:]], axis=0)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference
if not context.executing_eagerly():
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index fb81798602..361667ec49 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
+from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -1623,7 +1624,7 @@ def ones_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
+ `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
`complex64`, `complex128` or `bool`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
@@ -2609,14 +2610,6 @@ def where(condition, x=None, y=None, name=None):
raise ValueError("x and y must both be non-None or both be None.")
-@tf_export("reverse")
-def reverse(tensor, axis, name=None):
- return gen_array_ops.reverse_v2(tensor, axis, name)
-
-
-reverse.__doc__ = gen_array_ops.reverse_v2.__doc__
-
-
# pylint: disable=redefined-builtin
@tf_export("reverse_sequence")
@deprecation.deprecated_args(
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 2a2bcdd9d6..868a4f6b84 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -25,6 +25,8 @@ from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
diff --git a/tensorflow/python/ops/collective_ops.py b/tensorflow/python/ops/collective_ops.py
index a05fd15eca..98668facd5 100644
--- a/tensorflow/python/ops/collective_ops.py
+++ b/tensorflow/python/ops/collective_ops.py
@@ -22,7 +22,7 @@ from tensorflow.python.ops import gen_collective_ops
def all_reduce(t, group_size, group_key, instance_key, merge_op, final_op,
- subdiv_offsets=(0)):
+ subdiv_offsets=(0,)):
"""Reduces tensors collectively, across devices.
Args:
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 8e16cffdf4..9cc64ef9f6 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -37,11 +37,11 @@ class CollectiveOpTest(test.TestCase):
with ops.device('/CPU:0'):
in0 = constant_op.constant(t0)
colred0 = collective_ops.all_reduce(in0, 2, group_key, instance_key,
- 'Add', 'Div', [0])
+ 'Add', 'Div')
with ops.device('/CPU:1'):
in1 = constant_op.constant(t1)
colred1 = collective_ops.all_reduce(in1, 2, group_key, instance_key,
- 'Add', 'Div', [0])
+ 'Add', 'Div')
run_options = config_pb2.RunOptions()
run_options.experimental.collective_graph_key = 1
results = sess.run([colred0, colred1], options=run_options)
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_case.py b/tensorflow/python/ops/cond_v2.py
index b95202c5df..76173e0f30 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_case.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -1,4 +1,3 @@
-# =============================================================================
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,23 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""Test case base for testing proto operations."""
+"""cond_v2 wrapper module.
+
+This imports the cond_v2 method and all necessary dependencies (this is to avoid
+circular dependencies in the cond_v2 implementation). See cond_v2_impl for more
+information.
+"""
-# Python3 preparedness imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import ctypes as ct
-import os
-
-from tensorflow.python.platform import test
-
-
-class ProtoOpTestCase(test.TestCase):
+# pylint: disable=unused-import
+from tensorflow.python.framework import function
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.ops import gradients_impl
- def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
- super(ProtoOpTestCase, self).__init__(methodName)
- lib = os.path.join(os.path.dirname(__file__), 'libtestexample.so')
- if os.path.isfile(lib):
- ct.cdll.LoadLibrary(lib)
+from tensorflow.python.ops.cond_v2_impl import cond_v2
+# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
new file mode 100644
index 0000000000..d310f83dca
--- /dev/null
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -0,0 +1,479 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""cond_v2 and gradient.
+
+This is a version of cond that emits a single If op, as well as the gradient
+function for If ops produced by cond_v2. This will eventually replace the
+current tf.cond implementation once it reaches feature and performance parity.
+
+NOTE: most users of cond_v2 should import cond_v2, not this module! This module
+does not contain all the necessary imports to prevent circular dependencies,
+while cond_v2 does.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.util import compat
+
+
+# The following modules cannot be imported directly because they cause circular
+# dependencies. These are set in each corresponding module.
+_function = None
+_function_def_to_graph = None
+_gradients_impl = None
+
+# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
+# that they aren't part of the official public API. These protected members
+# often need to be used by implementation code however. Rather than litter the
+# code with pylint comments, we ignore protected access violations for
+# readability.
+# pylint: disable=protected-access
+
+
+def cond_v2(pred, true_fn, false_fn, name="cond"):
+ """Like tf.cond, except emits a single If op."""
+ if not name:
+ name = "cond"
+
+ with ops.name_scope(name) as scope:
+ # Identify if there is a caller device, & get the innermost if possible.
+ device_stack = ops.get_default_graph()._device_function_stack
+ caller_device = device_stack[-1] if device_stack else None
+
+ caller_colocation_stack = ops.get_default_graph()._colocation_stack
+ caller_container = ops.get_default_graph()._container
+ caller_collection_ref = ops.get_default_graph()._collections
+
+ func_name_prefix = scope.replace("/", "_")
+
+ true_graph = _function.func_graph_from_py_func(
+ true_fn, [], [],
+ name="%strue" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
+ false_graph = _function.func_graph_from_py_func(
+ false_fn, [], [],
+ name="%sfalse" % func_name_prefix,
+ device=caller_device,
+ colocation_stack=caller_colocation_stack,
+ collections_ref=caller_collection_ref,
+ container=caller_container)
+ _check_same_outputs(true_graph, false_graph)
+
+ # Add inputs to true_graph and false_graph to make them match. Note that
+ # this modifies true_graph and false_graph.
+ cond_inputs = _make_inputs_match(true_graph, false_graph,
+ true_graph.extra_inputs,
+ false_graph.extra_inputs)
+
+ # Add all intermediate tensors as function outputs so they're available for
+ # the gradient computation.
+
+ true_intermediates = _get_intermediates(true_graph)
+ false_intermediates = _get_intermediates(false_graph)
+
+ # Save the original number of outputs to return to the caller.
+ num_cond_outputs = len(true_graph.outputs)
+
+ # Make the number/type of new intermediate outputs match.
+ extra_true_outputs, extra_false_outputs = _pad_params(
+ true_graph, false_graph, true_intermediates, false_intermediates)
+
+ true_graph.outputs.extend(extra_true_outputs)
+ false_graph.outputs.extend(extra_false_outputs)
+
+ # Create the If op.
+ tensors = gen_functional_ops._if(
+ pred, cond_inputs, [t.dtype for t in true_graph.outputs],
+ _create_new_tf_function(true_graph),
+ _create_new_tf_function(false_graph),
+ name=scope)
+
+ # Set the flag to enable lowering on the `if` op if necessary
+ # Lowering allows cond_v2 to avoid some of the limitations of Functions,
+ # allowing users to specify devices & colocation inside of cond_v2 branches,
+ # and enabling non-strict evaluation & partial pruning of cond_v2 branches.
+ # This brings cond_v2 closer to feature parity with tf.cond.
+ #
+ # However, we do not lower `If` in the XLA context because it is easier for
+ # XLA to apply its own optimizations when dealing with un-lowered `If`
+ # operators than with lowered switch/merge control flow.
+ #
+ # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
+ if_op = tensors[0].op
+ if not control_flow_util.IsInXLAContext(if_op):
+ if_op._set_attr("_lower_using_switch_merge",
+ attr_value_pb2.AttrValue(b=True))
+
+ return tensors[:num_cond_outputs]
+
+
+@ops.RegisterGradient("If")
+def _IfGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of an If op produced by cond_v2."""
+ true_graph, false_graph = _get_func_graphs(op)
+
+ # Create grad functions that compute the gradient of the true/false forward
+ # graphs. These functions will capture tensors from the forward pass
+ # functions.
+ true_grad_graph = _create_grad_func(
+ true_graph, grads, _get_grad_fn_name(true_graph))
+ false_grad_graph = _create_grad_func(
+ false_graph, grads, _get_grad_fn_name(false_graph))
+
+ assert ([t.dtype for t in true_grad_graph.outputs] ==
+ [t.dtype for t in false_grad_graph.outputs])
+
+ # Match up the captured grad function inputs with outputs of 'op' and other
+ # external tensors.
+ true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph)
+ false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph)
+
+ # Make the inputs to true_grad_graph and false_grad_graph match. Note that
+ # this modifies true_grad_graph and false_grad_graph.
+ grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
+ true_grad_inputs, false_grad_inputs)
+
+ # Add all intermediate tensors as function outputs so they're available for
+ # higher-order gradient computations.
+
+ true_grad_intermediates = _get_intermediates(true_grad_graph)
+ false_grad_intermediates = _get_intermediates(false_grad_graph)
+
+ # Save the original number of gradient outputs to return.
+ num_grad_outputs = len(true_grad_graph.outputs)
+
+ # Make the number/type of new intermediate outputs match.
+ extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
+ true_grad_graph, false_grad_graph,
+ true_grad_intermediates, false_grad_intermediates)
+
+ true_grad_graph.outputs.extend(extra_true_grad_outputs)
+ false_grad_graph.outputs.extend(extra_false_grad_outputs)
+
+ # Create the gradient If op.
+ tensors = gen_functional_ops._if(
+ op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
+ _create_new_tf_function(true_grad_graph),
+ _create_new_tf_function(false_grad_graph))
+
+ # The predicate has no gradient.
+ return [None] + tensors[:num_grad_outputs]
+
+
+def _get_func_graphs(if_op):
+ """Returns `_FuncGraph`s for the input op branches.
+
+ Args:
+ if_op: The _If Operation.
+
+ Returns:
+ A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch.
+ """
+ def _get_func_graph_for_branch(branch_name):
+ """Generates and returns a _FuncGraph for the given branch."""
+ extra_inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = if_op.get_attr(branch_name).name
+ fdef = if_op.graph._get_function(func_name).definition
+ func_graph = _function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes)
+ func_graph.extra_inputs = extra_inputs
+ func_graph.extra_args = func_graph.inputs
+ func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ return func_graph
+
+ return (_get_func_graph_for_branch("then_branch"),
+ _get_func_graph_for_branch("else_branch"))
+
+
+def _grad_fn(func_graph, grads):
+ """The gradient function for each conditional branch.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ conditional branch in `func_graph`. This is done by differentiating
+ func_graph's outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function._FuncGraph. The corresponding forward-pass function.
+ grads: The list of input gradient Tensors.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ # Filter out untrainable function outputs.
+ # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
+ # cause _GradientsHelper to raise an exception (e.g. the implementation
+ # doesn't expect 'ys' to contain boolean tensors).
+ assert len(func_graph.outputs) == len(grads)
+ ys = []
+ grad_ys = []
+ for y, grad_y in zip(func_graph.outputs, grads):
+ if not _gradients_impl._IsTrainable(y):
+ continue
+ ys.append(y)
+ grad_ys.append(grad_y)
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _get_grad_inputs.
+ result = _gradients_impl._GradientsHelper(
+ ys, func_graph.inputs, grad_ys=grad_ys,
+ src_graph=func_graph)
+
+ # Functions can't return None; replace Nones with zero tensors.
+ # TODO(b/80444525): don't return anything here and make _IfGrad return None if
+ # both branches have zero gradient.
+ for i in range(len(result)):
+ if result[i] is None:
+ result[i] = array_ops.zeros_like(func_graph.inputs[i])
+
+ return result
+
+
+def _create_grad_func(func_graph, grads, name):
+ """Returns the _FuncGraph representation of _grad_fn."""
+ return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
+ [], [], name)
+
+
+def _get_grad_inputs(if_op, cond_graph, grad_graph):
+ """Returns the tensors we should pass to grad_graph.
+
+ This method handles tensors captured from cond_graph in grad_graph. It
+ converts these to suitable input tensors from the outer graph.
+
+ Args:
+ if_op: Operation. The forward-pass If op that uses cond_graph.
+ cond_graph: function._FuncGraph. The forward-pass function.
+ grad_graph: function._FuncGraph. The gradients function.
+
+ Returns:
+ A list of inputs tensors to be passed to grad_graph.
+ """
+ inputs = []
+
+ # Maps placeholders in cond_graph -> input tensor in outer graph.
+ forward_input_map = {v: k for k, v in cond_graph._captured.items()}
+
+ for t in grad_graph.extra_inputs:
+ if t.graph == ops.get_default_graph():
+ # t is in the outer graph (e.g. one of the input gradients).
+ inputs.append(t)
+ elif t in forward_input_map:
+ # t is an input placeholder in cond_graph. Get the corresponding input
+ # tensor in the outer graph.
+ assert t.graph == cond_graph
+ assert forward_input_map[t].graph == ops.get_default_graph()
+ inputs.append(forward_input_map[t])
+ else:
+ # t is an intermediate value in cond_graph. Get the corresponding output
+ # of 'if_op' (note that all intermediate values are outputs).
+ assert t.graph == cond_graph
+ output_idx = cond_graph.outputs.index(t)
+ inputs.append(if_op.outputs[output_idx])
+
+ return inputs
+
+
+def _create_new_tf_function(func_graph):
+ """Converts func_graph to a TF_Function and adds it to the current graph.
+
+ Args:
+ func_graph: function._FuncGraph
+
+ Returns:
+ The name of the new TF_Function.
+ """
+ c_func = c_api.TF_GraphToFunction_wrapper(
+ func_graph._c_graph,
+ compat.as_str(func_graph.name),
+ False, # append_hash_to_fn_name
+ None, # opers
+ [t._as_tf_output() for t in func_graph.inputs],
+ [t._as_tf_output() for t in func_graph.outputs],
+ [],
+ None, # opts
+ None) # description
+ _ = c_api_util.ScopedTFFunction(c_func)
+
+ # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
+ # deserializing it into a Python FunctionDef, then reserializing it to create
+ # a new TF_Function that we add to the graph.
+ fdef = _function.function_def_from_tf_function(c_func)
+ defined_func = _function._from_definition(fdef)
+ defined_func.add_to_graph(ops.get_default_graph())
+
+ return func_graph.name
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that aren't inputs or outputs."""
+ intermediates = []
+ for op in func_graph.get_operations():
+ for t in op.outputs:
+ if t in func_graph.inputs: continue
+ if t in func_graph.outputs: continue
+ intermediates.append(t)
+ return intermediates
+
+
+def _separate_unique_inputs(true_inputs, false_inputs):
+ """Separates tensors appearing only in true_inputs or false_inputs, or both.
+
+ Args:
+ true_inputs: list of Tensors
+ false_inputs: list of Tensors
+
+ Returns:
+ Three lists of Tensors:
+ 1. The tensors that appear in both true_inputs and false_inputs
+ 2. The tensors that only appear in true_inputs
+ 3. The tensors that only appear in false_inputs
+ """
+ true_inputs = set(true_inputs)
+ false_inputs = set(false_inputs)
+
+ shared_inputs = true_inputs.intersection(false_inputs)
+ true_only_inputs = true_inputs - false_inputs
+ false_only_inputs = false_inputs - true_inputs
+
+ return list(shared_inputs), list(true_only_inputs), list(false_only_inputs)
+
+
+def _pad_params(true_graph, false_graph, true_params, false_params):
+ """Returns new param lists that have matching signatures.
+
+ This is done by mirroring each param list in the other using dummy params.
+ There is no merging of params.
+
+ Args:
+ true_graph: function._FuncGraph
+ false_graph: function._FuncGraph
+ true_params: a list of Tensors from true_graph
+ false_params: a list of Tensors from false_graph
+
+ Returns:
+ A new list of Tensors in true_graph and a new list of Tensors in
+ false_graph. The two lists have the same number of Tensors, with matching
+ types and shapes across the lists.
+ """
+ new_true_params = (true_params +
+ _create_dummy_params(true_graph, false_params))
+ new_false_inputs = (_create_dummy_params(false_graph, true_params)
+ + false_params)
+ return new_true_params, new_false_inputs
+
+
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
+ """Modifies true_graph and false_graph so they have the same input signature.
+
+ This method reorders and/or adds parameters to true_graph and false_graph so
+ they have the same input signature, and updates the 'inputs', 'extra_inputs',
+ and '_captured' fields of both graphs accordingly. It uses the input tensors
+ from the outer graph to avoid duplicating shared arguments.
+
+ Args:
+ true_graph: function._FuncGraph
+ false_graph: function._FuncGraph
+ true_inputs: a list of Tensors in the outer graph. The inputs for
+ true_graph.
+ false_inputs: a list of Tensors in the outer graph. The inputs for
+ false_graph.
+
+ Returns:
+ A new list of Tensors from the outer graph that are the new inputs for both
+ true_graph and false_graph. This is a deduped version of true_inputs +
+ false_inputs.
+ """
+ shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
+ true_inputs, false_inputs)
+
+ new_inputs = shared_inputs + true_only_inputs + false_only_inputs
+
+ true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
+
+ true_graph.inputs = (
+ [true_input_to_param[t] for t in shared_inputs] +
+ [true_input_to_param[t] for t in true_only_inputs] +
+ _create_dummy_params(true_graph, false_only_inputs))
+
+ false_graph.inputs = (
+ [false_input_to_param[t] for t in shared_inputs] +
+ _create_dummy_params(false_graph, true_only_inputs) +
+ [false_input_to_param[t] for t in false_only_inputs])
+
+ # Rewrite the _FuncGraphs' state to reflect the new inputs.
+ true_graph.extra_inputs = new_inputs
+ false_graph.extra_inputs = new_inputs
+
+ true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
+ false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+
+ return new_inputs
+
+
+def _create_dummy_params(func_graph, template_tensors):
+ """Creates tensors in func_graph to represent template_tensors.
+
+ Args:
+ func_graph: function._FuncGraph.
+ template_tensors: a list of tensors in the outer graph.
+
+ Returns:
+ A list of tensors in func_graph.
+ """
+ with func_graph.as_default():
+ return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
+ for t in template_tensors]
+
+
+def _get_grad_fn_name(func_graph):
+ """Returns a unique name to use for the grad function of `func_graph`."""
+ name = "%s_grad" % func_graph.name
+
+ base_name = name
+ counter = 1
+ if ops.get_default_graph()._is_function(name):
+ name = "%s_%s" % (base_name, counter)
+ counter += 1
+
+ return name
+
+
+def _check_same_outputs(true_graph, false_graph):
+ """Raises an error if true_graph and false_graph have different outputs."""
+ true_output_types = [t.dtype for t in true_graph.outputs]
+ false_output_types = [t.dtype for t in false_graph.outputs]
+ if (len(true_graph.outputs) != len(false_graph.outputs) or
+ true_output_types != false_output_types):
+ raise ValueError(
+ "true_fn() and false_fn() must return the same number and type of "
+ "arguments, got:\n"
+ " true_fn: %s\n"
+ " false_fn: %s" % (true_output_types, false_output_types))
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index ee024ce64a..04545cceb7 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -24,6 +24,7 @@ from __future__ import print_function
import abc
import collections
import functools
+import os
import six
@@ -38,6 +39,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
@@ -57,6 +59,10 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+
+_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+
+
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple
@@ -596,7 +602,6 @@ def _EnforceShapeInvariant(merge_var, next_var):
enter = merge_var.op.inputs[0].op
assert util.IsLoopEnter(enter)
input_t = enter.inputs[0]
- assert input_t.shape == m_shape
raise ValueError(
"Input tensor '%s' enters the loop with shape %s, but has shape %s "
"after one iteration. To allow the shape to vary across iterations, "
@@ -1994,6 +1999,9 @@ def cond(pred,
```
"""
+ if _ENABLE_COND_V2:
+ return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)
+
# We needed to make true_fn/false_fn keyword arguments for
# backwards-compatibility. This check exists so that we can convert back to
# having them be positional arguments.
@@ -2729,7 +2737,8 @@ class WhileContext(ControlFlowContext):
self.outer_context.Exit()
else:
shape_acc = array_ops.zeros_like(
- array_ops.shape_internal(op.inputs[0], optimize=False),
+ array_ops.shape_internal(op.inputs[0], optimize=False,
+ out_type=dense_shape.dtype),
optimize=False)
if self.outer_context:
@@ -2923,7 +2932,8 @@ class WhileContext(ControlFlowContext):
return original_body_result, exit_vars
- def BuildLoop(self, pred, body, loop_vars, shape_invariants):
+ def BuildLoop(self, pred, body, loop_vars, shape_invariants,
+ return_same_structure):
"""Add the loop termination condition and body to the graph."""
# Keep original_loop_vars to identify which are TensorArrays
@@ -2934,9 +2944,10 @@ class WhileContext(ControlFlowContext):
loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
try:
self.Enter()
- # _BuildLoop calls _update_input in several places. _lock ensures a
- # Session.run call cannot occur between creating and mutating new ops.
- with ops.get_default_graph()._lock: # pylint: disable=protected-access
+ # _BuildLoop calls _update_input in several places. _mutation_lock()
+ # ensures a Session.run call cannot occur between creating and mutating
+ # new ops.
+ with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
original_body_result, exit_vars = self._BuildLoop(
pred, body, original_loop_vars, loop_vars, shape_invariants)
finally:
@@ -2950,7 +2961,11 @@ class WhileContext(ControlFlowContext):
packed_exit_vars = nest.pack_sequence_as(
structure=original_body_result,
flat_sequence=exit_vars_with_tensor_arrays)
- return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
+
+ if return_same_structure:
+ return packed_exit_vars
+ else:
+ return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
@@ -2990,7 +3005,8 @@ def while_loop(cond,
back_prop=True,
swap_memory=False,
name=None,
- maximum_iterations=None):
+ maximum_iterations=None,
+ return_same_structure=False):
"""Repeat `body` while the condition `cond` is true.
`cond` is a callable returning a boolean scalar tensor. `body` is a callable
@@ -3066,11 +3082,16 @@ def while_loop(cond,
to run. If provided, the `cond` output is AND-ed with an additional
condition ensuring the number of iterations executed is no greater than
`maximum_iterations`.
+ return_same_structure: If True, output has same structure as `loop_vars`. If
+ eager execution is enabled, this is ignored (and always treated as True).
Returns:
- The output tensors for the loop variables after the loop. When the length
- of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when
- the length of `loop_vars` is greater than 1 it returns a list.
+ The output tensors for the loop variables after the loop.
+ If `return_same_structure` is True, the return value has the same
+ structure as `loop_vars`.
+ If `return_same_structure` is False, the return value is a Tensor,
+ TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
+ otherwise.
Raises:
TypeError: if `cond` or `body` is not callable.
@@ -3125,6 +3146,7 @@ def while_loop(cond,
happen is that the thread updating `x` can never get ahead of the
counter thread because the thread incrementing `x` depends on the value
of the counter.
+
```python
import tensorflow as tf
@@ -3206,7 +3228,8 @@ def while_loop(cond,
# be encapsulated in the root context.
if loop_context.outer_context is None:
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
- result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
+ result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
+ return_same_structure)
if maximum_iterations is not None:
return result[1]
else:
@@ -3339,12 +3362,6 @@ def group(*inputs, **kwargs):
if not hasattr(inp, "device"):
raise TypeError("Expected tf.group() expected Tensor arguments not "
"'%s' with type '%s'" % (inp, type(inp)))
- if not hasattr(inp, "device"):
- if isinstance(inp, list):
- raise TypeError("To call tf.group() with a list, use "
- "tf.group(*[...]) not tf.group([...]).")
- raise TypeError("Expected tf.group() expected Tensor arguments not "
- "'%s' with type '%s'" % (inp, type(inp)))
dev = inp.device
if dev in ops_on_device:
ops_on_device[dev].append(inp)
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 59bb925df0..153548ae92 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -939,7 +939,7 @@ class CaseTest(test_util.TensorFlowTestCase):
class WhileLoopTestCase(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWhileLoopWithSingleVariable(self):
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, 10)
@@ -948,7 +948,7 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
self.assertEqual(self.evaluate(r), 10)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, 10)
@@ -958,6 +958,28 @@ class WhileLoopTestCase(test_util.TensorFlowTestCase):
# Expect a tuple since that is what the body returns.
self.assertEqual(self.evaluate(r), (10,))
+ def testWhileLoopSameReturnShape_False(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the tensor.
+ r = control_flow_ops.while_loop(c, b, [i, []])
+ self.assertEqual(self.evaluate(r), 10)
+
+ def testWhileLoopSameReturnShape_True(self):
+ i = constant_op.constant(0)
+ c = lambda i, _: math_ops.less(i, 10)
+
+ # Body returns a [tensor, []]
+ b = lambda i, _: [math_ops.add(i, 1), []]
+
+ # Should only return the original structure.
+ r = control_flow_ops.while_loop(c, b, [i, []], return_same_structure=True)
+ self.assertEqual(self.evaluate(r), [10, []])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py
index 907df85cd9..aacdaa7ad0 100644
--- a/tensorflow/python/ops/conv2d_benchmark.py
+++ b/tensorflow/python/ops/conv2d_benchmark.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import itertools
import time
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -28,22 +30,32 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import flags
from tensorflow.python.platform import test
+FLAGS = flags.FLAGS
-def build_graph(device, input_shape, filter_shape, strides, padding, dtype,
- num_iters, warmup_iters):
+flags.DEFINE_boolean(
+ "enable_layout_optimizer", False,
+ "If true, enables layout optimizer to update input data format for faster "
+ "execution of convolution ops.")
+
+
+def build_graph(device, dtype, data_format, input_shape, filter_shape, strides,
+ padding, num_iters, warmup_iters):
"""builds a graph containing a sequence of conv2d operations.
Args:
device: String, the device to run on.
+ dtype: Data type for the convolution.
+ data_format: A string from: "NHWC" or "NCHW". Data format for input and
+ output data.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
window for each dimension of input.
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use.
- dtype: Data type for the convolution.
num_iters: number of iterations to run conv2d.
warmup_iters: number of iterations for warmup runs.
@@ -57,22 +69,23 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype,
random_ops.truncated_normal(filter_shape, dtype=dtype))
outputs = []
- conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC")
+ conv2d_op = nn_ops.conv2d(
+ inp, filt, strides, padding, data_format=data_format)
outputs.append(conv2d_op)
for _ in range(1, num_iters):
with ops.control_dependencies([conv2d_op]):
conv2d_op = nn_ops.conv2d(
- inp, filt, strides, padding, data_format="NHWC")
+ inp, filt, strides, padding, data_format=data_format)
outputs.append(conv2d_op)
warmup_groups = []
warmup_conv2d_op = nn_ops.conv2d(
- inp, filt, strides, padding, data_format="NHWC")
+ inp, filt, strides, padding, data_format=data_format)
warmup_groups.append(warmup_conv2d_op)
for _ in range(1, warmup_iters):
with ops.control_dependencies([warmup_conv2d_op]):
warmup_conv2d_op = nn_ops.conv2d(
- inp, filt, strides, padding, data_format="NHWC")
+ inp, filt, strides, padding, data_format=data_format)
warmup_groups.append(warmup_conv2d_op)
return control_flow_ops.group(*warmup_groups), control_flow_ops.group(
*outputs)
@@ -81,12 +94,15 @@ def build_graph(device, input_shape, filter_shape, strides, padding, dtype,
class Conv2DBenchmark(test.Benchmark):
"""Benchmark conv2d!"""
- def _run_graph(self, device, input_shape, filter_shape, strides, padding,
- dtype, num_iters, warmup_iters):
+ def _run_graph(self, device, dtype, data_format, input_shape, filter_shape,
+ strides, padding, num_iters, warmup_iters):
"""runs the graph and print its execution time.
Args:
device: String, the device to run on.
+ dtype: Data type for the convolution.
+ data_format: A string from: "NHWC" or "NCHW". Data format for input and
+ output data.
input_shape: Shape of the input tensor.
filter_shape: Shape of the filter tensor.
strides: A list of ints. 1-D of length 4. The stride of sliding
@@ -94,7 +110,6 @@ class Conv2DBenchmark(test.Benchmark):
padding: A string from: "SAME", "VALID". The type of padding
algorithm to use. num_iters: Number of iterations to run the
benchmark.
- dtype: Data type for the convolution.
num_iters: number of iterations to run conv2d.
warmup_iters: number of iterations for warmup runs.
@@ -103,10 +118,27 @@ class Conv2DBenchmark(test.Benchmark):
"""
graph = ops.Graph()
with graph.as_default():
- warmup_outputs, outputs = build_graph(device, input_shape, filter_shape,
- strides, padding, dtype, num_iters,
- warmup_iters)
- with session_lib.Session(graph=graph) as session:
+ warmup_outputs, outputs = build_graph(device, dtype, data_format,
+ input_shape, filter_shape, strides,
+ padding, num_iters, warmup_iters)
+
+ config = config_pb2.ConfigProto()
+ config.graph_options.optimizer_options.opt_level = -1
+ rewrite_options = config.graph_options.rewrite_options
+
+ # Disable layout optimizer to not change input data_format.
+ rewrite_options.layout_optimizer = (
+ rewriter_config_pb2.RewriterConfig.ON if FLAGS.enable_layout_optimizer
+ else rewriter_config_pb2.RewriterConfig.OFF)
+ # Convolution ops are effectively noop in the test graph as we are not
+ # fetching the convolution outputs. Disable dependency optimizer to not
+ # remove the conv ops.
+ rewrite_options.dependency_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+
+ with session_lib.Session(graph=graph, config=config) as session:
+ # TODO(hinsu): Use run_op_benchmark method from test.Benchmark to run
+ # benchmark along with warmup.
variables.global_variables_initializer().run()
# warmup runs
session.run(warmup_outputs)
@@ -114,20 +146,21 @@ class Conv2DBenchmark(test.Benchmark):
start_time = time.time()
session.run(outputs)
duration = (time.time() - start_time) / num_iters
- print("%s %s inputshape:%s filtershape:%s strides:%s padding:%s "
+ print("%s %s %s inputshape:%s filtershape:%s strides:%s padding:%s "
"%d iters: %.8f sec" %
- (device, str(dtype), str(input_shape).replace(" ", ""),
- str(filter_shape).replace(" ", ""),
+ (device, str(dtype), data_format, str(input_shape).replace(
+ " ", ""), str(filter_shape).replace(" ", ""),
str(strides).replace(" ", ""), padding, num_iters, duration))
name_template = (
- "conv2d_{device}_{datatype}_input_shape_{inputshape}_"
+ "conv2d_{device}_{datatype}_{data_format}_input_shape_{inputshape}_"
"filter_shape_{filtershape}_strides_{strides}_padding_{padding}")
self.report_benchmark(
name=name_template.format(
device=device,
datatype=str(dtype),
+ data_format=str(data_format),
inputshape=str(input_shape).replace(" ", ""),
filtershape=str(filter_shape).replace(" ", ""),
strides=str(strides).replace(" ", ""),
@@ -140,24 +173,37 @@ class Conv2DBenchmark(test.Benchmark):
def benchmark_conv2d(self):
print("conv2d benchmark:")
- h = 500
- w = 500
- fh = 3
- fw = 3
- input_shapes = []
- filter_shapes = []
data_types = [dtypes.float32, dtypes.float16]
- for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]):
- input_shapes += [[b, h, w, c]]
- filter_shapes += [[fh, fw, c, b]]
- strides = [[1, 2, 2, 1]]
+ data_formats = ["NHWC", "NCHW"]
+ in_channels = list(range(3, 16))
+ out_channels = [4, 16, 32]
+ hw_strides = [[2, 2]]
paddings = ["VALID", "SAME"]
- for ishape, fshape in zip(input_shapes, filter_shapes):
- for dtype in data_types:
- for stride in strides:
- for padding in paddings:
- self._run_graph("gpu", ishape, fshape, stride, padding, dtype, 80,
- 2)
+
+ args_lists = [
+ data_types, data_formats, in_channels, out_channels, hw_strides,
+ paddings
+ ]
+ for args in itertools.product(*args_lists):
+ dtype, data_format, in_channel, out_channel, hw_stride, padding = args
+
+ # Keep batch size same as out channels just to reduce the number of
+ # different configurations to benchmark.
+ batch_size = out_channel
+ h, w, fh, fw = 500, 500, 3, 3
+ if data_format == "NHWC":
+ ishape = [batch_size, h, w, in_channel]
+ stride = [1] + hw_stride + [1]
+ elif data_format == "NCHW":
+ ishape = [batch_size, in_channel, h, w]
+ stride = [1, 1] + hw_stride
+ else:
+ raise ValueError("Unknown data_format: " + str(data_format))
+ fshape = [fh, fw, in_channel, out_channel]
+ num_iters = 80
+ warmup_iters = 2
+ self._run_graph("gpu", dtype, data_format, ishape, fshape, stride,
+ padding, num_iters, warmup_iters)
if __name__ == "__main__":
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d934f27cb9..ca24f11054 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -89,7 +89,7 @@ def custom_gradient(f):
operations in `f` to `x`.
- `grad_fn` is a function with the signature `g(*grad_ys)` which returns
a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
- to the `Tensor`s in `x. `grad_ys` is a `Tensor` or sequence of
+ to the `Tensor`s in `x`. `grad_ys` is a `Tensor` or sequence of
`Tensor`s the same size as `y` holding the initial value gradients for
each `Tensor` in `y`. If `f` uses `Variable`s (that are not part of the
inputs), i.e. through `get_variable`, then `grad_fn` should have
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 62c5adc385..abf597ca55 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
@@ -129,11 +130,6 @@ class QueueBase(object):
@{tf.RandomShuffleQueue} for concrete
implementations of this class, and instructions on how to create
them.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self, dtypes, shapes, names, queue_ref):
@@ -157,12 +153,7 @@ class QueueBase(object):
Raises:
ValueError: If one of the arguments is invalid.
- RuntimeError: If eager execution is enabled.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- "Queues are not supported when eager execution is enabled. "
- "Instead, please use tf.data to get data into your model.")
self._dtypes = dtypes
if shapes is not None:
if len(shapes) != len(dtypes):
@@ -179,6 +170,8 @@ class QueueBase(object):
self._queue_ref = queue_ref
if context.executing_eagerly():
self._name = context.context().scope_name
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ queue_ref, None)
else:
self._name = self._queue_ref.op.name.split("/")[-1]
@@ -605,6 +598,11 @@ class QueueBase(object):
else:
return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
+def _shared_name(shared_name):
+ if context.executing_eagerly():
+ return str(ops.uid())
+ return shared_name
+
@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
@@ -612,11 +610,6 @@ class RandomShuffleQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -690,7 +683,7 @@ class RandomShuffleQueue(QueueBase):
min_after_dequeue=min_after_dequeue,
seed=seed1,
seed2=seed2,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -702,11 +695,6 @@ class FIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -752,7 +740,7 @@ class FIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -767,11 +755,6 @@ class PaddingFIFOQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -831,7 +814,7 @@ class PaddingFIFOQueue(QueueBase):
component_types=dtypes,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
@@ -843,11 +826,6 @@ class PriorityQueue(QueueBase):
See @{tf.QueueBase} for a description of the methods on
this class.
-
- @compatibility(eager)
- Queues are not compatible with eager execution. Instead, please
- use `tf.data` to get data into your model.
- @end_compatibility
"""
def __init__(self,
@@ -899,7 +877,7 @@ class PriorityQueue(QueueBase):
component_types=types,
shapes=shapes,
capacity=capacity,
- shared_name=shared_name,
+ shared_name=_shared_name(shared_name),
name=name)
priority_dtypes = [_dtypes.int64] + types
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index f28f76b6c4..99d30b0bd1 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -84,13 +84,24 @@ class Beta(distribution.Distribution):
Distribution parameters are automatically broadcast in all functions; see
examples for details.
+ Warning: The samples can be zero due to finite precision.
+ This happens more often when some of the concentrations are very small.
+ Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
+ density.
+
+ Samples of this distribution are reparameterized (pathwise differentiable).
+ The derivatives are computed using the approach described in the paper
+
+ [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
+ Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
+
#### Examples
```python
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
- dist = Beta(alpha, beta)
+ dist = tf.distributions.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -106,7 +117,7 @@ class Beta(distribution.Distribution):
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
- dist = Beta(alpha, beta)
+ dist = tf.distributions.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -122,6 +133,18 @@ class Beta(distribution.Distribution):
dist.prob(x) # Shape [2, 3]
```
+ Compute the gradients of samples w.r.t. the parameters:
+
+ ```python
+ alpha = tf.constant(1.0)
+ beta = tf.constant(2.0)
+ dist = tf.distributions.Beta(alpha, beta)
+ samples = dist.sample(5) # Shape [5]
+ loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
+ # Unbiased stochastic gradients of the loss function
+ grads = tf.gradients(loss, [alpha, beta])
+ ```
+
"""
def __init__(self,
@@ -165,7 +188,7 @@ class Beta(distribution.Distribution):
dtype=self._total_concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
- reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration1,
self._concentration0,
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index b88a0518b6..dd25fce2ec 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -32,12 +32,8 @@ from tensorflow.python.ops.distributions import util as distribution_util
from tensorflow.python.util.tf_export import tf_export
-def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
+def _broadcast_cat_event_and_params(event, params, base_dtype):
"""Broadcasts the event or distribution parameters."""
- if event.shape.ndims is None:
- raise NotImplementedError(
- "Cannot broadcast with an event tensor of unknown rank.")
-
if event.dtype.is_integer:
pass
elif event.dtype.is_floating:
@@ -47,15 +43,18 @@ def _broadcast_cat_event_and_params(event, params, base_dtype=dtypes.int32):
else:
raise TypeError("`value` should have integer `dtype` or "
"`self.dtype` ({})".format(base_dtype))
-
- if params.get_shape()[:-1] == event.get_shape():
- params = params
- else:
- params *= array_ops.ones_like(
- array_ops.expand_dims(event, -1), dtype=params.dtype)
+ shape_known_statically = (
+ params.shape.ndims is not None and
+ params.shape[:-1].is_fully_defined() and
+ event.shape.is_fully_defined())
+ if not shape_known_statically or params.shape[:-1] != event.shape:
+ params *= array_ops.ones_like(event[..., array_ops.newaxis],
+ dtype=params.dtype)
params_shape = array_ops.shape(params)[:-1]
event *= array_ops.ones(params_shape, dtype=event.dtype)
- event.set_shape(tensor_shape.TensorShape(params.get_shape()[:-1]))
+ if params.shape.ndims is not None:
+ event.set_shape(tensor_shape.TensorShape(params.shape[:-1]))
+
return event, params
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 72567e62f7..9104a1d071 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -90,13 +90,24 @@ class Dirichlet(distribution.Distribution):
Distribution parameters are automatically broadcast in all functions; see
examples for details.
+ Warning: Some components of the samples can be zero due to finite precision.
+ This happens more often when some of the concentrations are very small.
+ Make sure to round the samples to `np.finfo(dtype).tiny` before computing the
+ density.
+
+ Samples of this distribution are reparameterized (pathwise differentiable).
+ The derivatives are computed using the approach described in the paper
+
+ [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
+ Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
+
#### Examples
```python
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
- dist = Dirichlet(alpha)
+ dist = tf.distributions.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -118,7 +129,7 @@ class Dirichlet(distribution.Distribution):
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
- dist = Dirichlet(alpha)
+ dist = tf.distributions.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -129,6 +140,17 @@ class Dirichlet(distribution.Distribution):
dist.prob(x) # shape: [2]
```
+ Compute the gradients of samples w.r.t. the parameters:
+
+ ```python
+ alpha = tf.constant([1.0, 2.0, 3.0])
+ dist = tf.distributions.Dirichlet(alpha)
+ samples = dist.sample(5) # Shape [5, 3]
+ loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
+ # Unbiased stochastic gradients of the loss function
+ grads = tf.gradients(loss, alpha)
+ ```
+
"""
def __init__(self,
@@ -165,7 +187,7 @@ class Dirichlet(distribution.Distribution):
dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
- reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration,
self._total_concentration],
@@ -290,10 +312,8 @@ class Dirichlet(distribution.Distribution):
if not self.validate_args:
return x
return control_flow_ops.with_dependencies([
- check_ops.assert_positive(
- x,
- message="samples must be positive"),
- distribution_util.assert_close(
+ check_ops.assert_positive(x, message="samples must be positive"),
+ check_ops.assert_near(
array_ops.ones([], dtype=self.dtype),
math_ops.reduce_sum(x, -1),
message="sample last-dimension must sum to `1`"),
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 0db4749507..c03ef967e6 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -212,7 +212,7 @@ class ReparameterizationType(object):
reparameterized, and straight-through gradients are either partially
unsupported or are not supported at all. In this case, for purposes of
e.g. RL or variational inference, it is generally safest to wrap the
- sample results in a `stop_gradients` call and instead use policy
+ sample results in a `stop_gradients` call and use policy
gradients / surrogate loss instead.
"""
@@ -722,11 +722,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_prob(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return math_ops.log(self._prob(value, **kwargs))
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.log(self._prob(value, **kwargs))
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@@ -749,11 +746,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._prob(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return math_ops.exp(self._log_prob(value, **kwargs))
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.exp(self._log_prob(value, **kwargs))
def prob(self, value, name="prob"):
"""Probability density/mass function.
@@ -776,11 +770,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_cdf(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return math_ops.log(self._cdf(value, **kwargs))
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.log(self._cdf(value, **kwargs))
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@@ -813,11 +804,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._cdf(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return math_ops.exp(self._log_cdf(value, **kwargs))
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.exp(self._log_cdf(value, **kwargs))
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@@ -846,11 +834,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_survival_function(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return math_ops.log1p(-self.cdf(value, **kwargs))
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.log1p(-self.cdf(value, **kwargs))
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@@ -884,11 +869,8 @@ class Distribution(_BaseDistribution):
value = ops.convert_to_tensor(value, name="value")
try:
return self._survival_function(value, **kwargs)
- except NotImplementedError as original_exception:
- try:
- return 1. - self.cdf(value, **kwargs)
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return 1. - self.cdf(value, **kwargs)
def survival_function(self, value, name="survival_function"):
"""Survival function.
@@ -933,10 +915,7 @@ class Distribution(_BaseDistribution):
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
- try:
- return self._quantile(value, **kwargs)
- except NotImplementedError as original_exception:
- raise original_exception
+ return self._quantile(value, **kwargs)
def quantile(self, value, name="quantile"):
"""Quantile function. Aka "inverse cdf" or "percent point function".
@@ -982,11 +961,8 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._variance()
- except NotImplementedError as original_exception:
- try:
- return math_ops.square(self._stddev())
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.square(self._stddev())
def _stddev(self):
raise NotImplementedError("stddev is not implemented")
@@ -1014,11 +990,8 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._stddev()
- except NotImplementedError as original_exception:
- try:
- return math_ops.sqrt(self._variance())
- except NotImplementedError:
- raise original_exception
+ except NotImplementedError:
+ return math_ops.sqrt(self._variance())
def _covariance(self):
raise NotImplementedError("covariance is not implemented")
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index 24bc3f3d3e..4325a14449 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -103,9 +103,6 @@ class Exponential(gamma.Gamma):
allow_nan_stats=allow_nan_stats,
validate_args=validate_args,
name=name)
- # While the Gamma distribution is not reparameterizable, the exponential
- # distribution is.
- self._reparameterization_type = True
self._parameters = parameters
self._graph_parents += [self._rate]
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 163a27f758..b631f0247c 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -55,7 +55,7 @@ class Gamma(distribution.Distribution):
```none
pdf(x; alpha, beta, x > 0) = x**(alpha - 1) exp(-x beta) / Z
- Z = Gamma(alpha) beta**alpha
+ Z = Gamma(alpha) beta**(-alpha)
```
where:
@@ -85,14 +85,35 @@ class Gamma(distribution.Distribution):
Distribution parameters are automatically broadcast in all functions; see
examples for details.
- WARNING: This distribution may draw 0-valued samples for small `concentration`
- values. See note in `tf.random_gamma` docstring.
+ Warning: The samples of this distribution are always non-negative. However,
+ the samples that are smaller than `np.finfo(dtype).tiny` are rounded
+ to this value, so it appears more often than it should.
+ This should only be noticeable when the `concentration` is very small, or the
+ `rate` is very large. See note in `tf.random_gamma` docstring.
+
+ Samples of this distribution are reparameterized (pathwise differentiable).
+ The derivatives are computed using the approach described in the paper
+
+ [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
+ Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
#### Examples
```python
- dist = Gamma(concentration=3.0, rate=2.0)
- dist2 = Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
+ dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ ```
+
+ Compute the gradients of samples w.r.t. the parameters:
+
+ ```python
+ concentration = tf.constant(3.0)
+ rate = tf.constant(2.0)
+ dist = tf.distributions.Gamma(concentration, rate)
+ samples = dist.sample(5) # Shape [5]
+ loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
+ # Unbiased stochastic gradients of the loss function
+ grads = tf.gradients(loss, [concentration, rate])
```
"""
@@ -141,7 +162,7 @@ class Gamma(distribution.Distribution):
dtype=self._concentration.dtype,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
- reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ reparameterization_type=distribution.FULLY_REPARAMETERIZED,
parameters=parameters,
graph_parents=[self._concentration,
self._rate],
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index 20a2d16181..e0cf6f86f1 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -80,6 +80,12 @@ class StudentT(distribution.Distribution):
variance. However it is not actually the std. deviation; the Student's
t-distribution std. dev. is `scale sqrt(df / (df - 2))` when `df > 2`.
+ Samples of this distribution are reparameterized (pathwise differentiable).
+ The derivatives are computed using the approach described in the paper
+
+ [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
+ Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
+
#### Examples
Examples of initialization of one or a batch of distributions.
@@ -118,6 +124,19 @@ class StudentT(distribution.Distribution):
dist.prob(3.0)
```
+ Compute the gradients of samples w.r.t. the parameters:
+
+ ```python
+ df = tf.constant(2.0)
+ loc = tf.constant(2.0)
+ scale = tf.constant(11.0)
+ dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+ samples = dist.sample(5) # Shape [5]
+ loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
+ # Unbiased stochastic gradients of the loss function
+ grads = tf.gradients(loss, [df, loc, scale])
+ ```
+
"""
# pylint: enable=line-too-long
@@ -168,7 +187,7 @@ class StudentT(distribution.Distribution):
(self._df, self._loc, self._scale))
super(StudentT, self).__init__(
dtype=self._scale.dtype,
- reparameterization_type=distribution.NOT_REPARAMETERIZED,
+ reparameterization_type=distribution.FULLY_REPARAMETERIZED,
validate_args=validate_args,
allow_nan_stats=allow_nan_stats,
parameters=parameters,
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 401676bf84..3e480a79f5 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -36,43 +36,6 @@ from tensorflow.python.ops import nn
from tensorflow.python.util import tf_inspect
-def assert_close(
- x, y, data=None, summarize=None, message=None, name="assert_close"):
- """Assert that x and y are within machine epsilon of each other.
-
- Args:
- x: Floating-point `Tensor`
- y: Floating-point `Tensor`
- data: The tensors to print out if the condition is `False`. Defaults to
- error message and first few entries of `x` and `y`.
- summarize: Print this many entries of each tensor.
- message: A string to prefix to the default message.
- name: A name for this operation (optional).
-
- Returns:
- Op raising `InvalidArgumentError` if |x - y| > machine epsilon.
- """
- message = message or ""
- x = ops.convert_to_tensor(x, name="x")
- y = ops.convert_to_tensor(y, name="y")
-
- if data is None:
- data = [
- message,
- "Condition x ~= y did not hold element-wise: x = ", x, "y = ", y
- ]
-
- if x.dtype.is_integer:
- return check_ops.assert_equal(
- x, y, data=data, summarize=summarize, message=message, name=name)
-
- with ops.name_scope(name, "assert_close", [x, y, data]):
- tol = np.finfo(x.dtype.as_numpy_dtype).eps
- condition = math_ops.reduce_all(math_ops.less_equal(math_ops.abs(x-y), tol))
- return control_flow_ops.Assert(
- condition, data, summarize=summarize)
-
-
def assert_integer_form(
x, data=None, summarize=None, message=None,
int_dtype=None, name="assert_integer_form"):
@@ -241,8 +204,12 @@ def get_logits_and_probs(logits=None,
dependencies = [check_ops.assert_non_negative(probs)]
if multidimensional:
probs = embed_check_categorical_event_shape(probs)
- dependencies += [assert_close(math_ops.reduce_sum(probs, -1), one,
- message="probs does not sum to 1.")]
+ dependencies += [
+ check_ops.assert_near(
+ math_ops.reduce_sum(probs, -1),
+ one,
+ message="probs does not sum to 1.")
+ ]
else:
dependencies += [check_ops.assert_less_equal(
probs, one, message="probs has components greater than 1.")]
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index bcc717b043..27c2fa7017 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
# Imports gradient definitions.
@@ -30,6 +31,7 @@ from tensorflow.python.ops import data_flow_grad # pylint: disable=unused-impor
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
@@ -43,8 +45,8 @@ def _clip(params, ids, max_norm):
Args:
params: A `Tensor` of embeddings retrieved by `gather`.
ids: The `ids` argument that was passed to `gather`.
- max_norm: If provided, the embeddings are l2-normalized to the value of
- max_norm.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value.
Returns:
A `Tensor` with the same type as `params`.
@@ -290,8 +292,8 @@ def embedding_lookup(
in `indices` are always validated to be within range. If assigned to GPU,
out-of-bound indices result in safe but unspecified behavior, which may
include raising an error.
- max_norm: If provided, embedding values are l2-normalized to the value of
- max_norm.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value.
Returns:
A `Tensor` with the same type as the tensors in `params`.
@@ -346,8 +348,8 @@ def embedding_lookup_sparse(params,
"mean" is the weighted sum divided by the total weight.
"sqrtn" is the weighted sum divided by the square root of the sum of the
squares of the weights.
- max_norm: If provided, each embedding is normalized to have l2 norm equal
- to max_norm before combining.
+ max_norm: If not `None`, each embedding is clipped if its l2-norm is
+ larger than this value, before combining.
Returns:
A dense tensor representing the combined embeddings for the
@@ -479,3 +481,158 @@ def embedding_lookup_sparse(params,
assert False, "Unrecognized combiner"
return embeddings
+
+
+@tf_export("nn.safe_embedding_lookup_sparse")
+def safe_embedding_lookup_sparse(embedding_weights,
+ sparse_ids,
+ sparse_weights=None,
+ combiner='mean',
+ default_id=None,
+ name=None,
+ partition_strategy='div',
+ max_norm=None):
+ """Lookup embedding results, accounting for invalid IDs and empty features.
+
+ The partitioned embedding in `embedding_weights` must all be the same shape
+ except for the first dimension. The first dimension is allowed to vary as the
+ vocabulary size is not necessarily a multiple of `P`. `embedding_weights`
+ may be a `PartitionedVariable` as returned by using `tf.get_variable()` with a
+ partitioner.
+
+ Invalid IDs (< 0) are pruned from input IDs and weights, as well as any IDs
+ with non-positive weight. For an entry with no features, the embedding vector
+ for `default_id` is returned, or the 0-vector if `default_id` is not supplied.
+
+ The ids and weights may be multi-dimensional. Embeddings are always aggregated
+ along the last dimension.
+
+ Args:
+ embedding_weights: A list of `P` float `Tensor`s or values representing
+ partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
+ created by partitioning along dimension 0. The total unpartitioned
+ shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
+ vocab size and `e_1, ..., e_m` are the embedding dimensions.
+ sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
+ ids. `d_0` is typically batch size.
+ sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
+ float weights corresponding to `sparse_ids`, or `None` if all weights
+ are be assumed to be 1.0.
+ combiner: A string specifying how to combine embedding results for each
+ entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
+ the default.
+ default_id: The id to use for an entry with no features.
+ name: A name for this operation (optional).
+ partition_strategy: A string specifying the partitioning strategy.
+ Currently `"div"` and `"mod"` are supported. Default is `"div"`.
+ max_norm: If not `None`, all embeddings are l2-normalized to max_norm before
+ combining.
+
+
+ Returns:
+ Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.
+
+ Raises:
+ ValueError: if `embedding_weights` is empty.
+ """
+ if embedding_weights is None:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+ if isinstance(embedding_weights, variables.PartitionedVariable):
+ embedding_weights = list(embedding_weights) # get underlying Variables.
+ if not isinstance(embedding_weights, list):
+ embedding_weights = [embedding_weights]
+ if len(embedding_weights) < 1:
+ raise ValueError('Missing embedding_weights %s.' % embedding_weights)
+
+ dtype = sparse_weights.dtype if sparse_weights is not None else None
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
+
+ with ops.name_scope(name, 'embedding_lookup',
+ embedding_weights + [sparse_ids,
+ sparse_weights]) as scope:
+ # Reshape higher-rank sparse ids and weights to linear segment ids.
+ original_shape = sparse_ids.dense_shape
+ original_rank_dim = sparse_ids.dense_shape.get_shape()[0]
+ original_rank = (
+ array_ops.size(original_shape)
+ if original_rank_dim.value is None
+ else original_rank_dim.value)
+ sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [
+ math_ops.reduce_prod(
+ array_ops.slice(original_shape, [0], [original_rank - 1])),
+ array_ops.gather(original_shape, original_rank - 1)])
+ if sparse_weights is not None:
+ sparse_weights = sparse_tensor.SparseTensor(
+ sparse_ids.indices,
+ sparse_weights.values, sparse_ids.dense_shape)
+
+ # Prune invalid ids and weights.
+ sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+ if combiner != 'sum':
+ sparse_ids, sparse_weights = _prune_invalid_weights(
+ sparse_ids, sparse_weights)
+
+ # Fill in dummy values for empty features, if necessary.
+ sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
+ default_id or
+ 0)
+ if sparse_weights is not None:
+ sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)
+
+ result = embedding_lookup_sparse(
+ embedding_weights,
+ sparse_ids,
+ sparse_weights,
+ combiner=combiner,
+ partition_strategy=partition_strategy,
+ name=None if default_id is None else scope,
+ max_norm=max_norm)
+
+ if default_id is None:
+ # Broadcast is_row_empty to the same shape as embedding_lookup_result,
+ # for use in Select.
+ is_row_empty = array_ops.tile(
+ array_ops.reshape(is_row_empty, [-1, 1]),
+ array_ops.stack([1, array_ops.shape(result)[1]]))
+
+ result = array_ops.where(is_row_empty,
+ array_ops.zeros_like(result),
+ result,
+ name=scope)
+
+ # Reshape back from linear ids back into higher-dimensional dense result.
+ final_result = array_ops.reshape(
+ result,
+ array_ops.concat([
+ array_ops.slice(
+ math_ops.cast(original_shape, dtypes.int32), [0],
+ [original_rank - 1]),
+ array_ops.slice(array_ops.shape(result), [1], [-1])
+ ], 0))
+ final_result.set_shape(tensor_shape.unknown_shape(
+ (original_rank_dim - 1).value).concatenate(result.get_shape()[1:]))
+ return final_result
+
+
+def _prune_invalid_ids(sparse_ids, sparse_weights):
+ """Prune invalid IDs (< 0) from the input ids and weights."""
+ is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
+ if sparse_weights is not None:
+ is_id_valid = math_ops.logical_and(
+ is_id_valid,
+ array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
+ if sparse_weights is not None:
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
+ return sparse_ids, sparse_weights
+
+
+def _prune_invalid_weights(sparse_ids, sparse_weights):
+ """Prune invalid weights (< 0) from the input ids and weights."""
+ if sparse_weights is not None:
+ is_weights_valid = math_ops.greater(sparse_weights.values, 0)
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
+ return sparse_ids, sparse_weights
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 30413f289a..4ecc74675a 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -775,7 +775,7 @@ def While(input_, cond, body, name=None, hostmem=None):
a string, non-empty means True and empty means False. If the
tensor is not a scalar, non-emptiness means True and False
otherwise.
- body: . A funcion takes a list of tensors and returns another
+ body: . A function takes a list of tensors and returns another
list tensors. Both lists have the same types as specified
by T.
name: A name for the operation (optional).
@@ -945,6 +945,61 @@ def For(start,
# pylint: enable=invalid-name,protected-access
-def partitioned_call(args, f):
- return gen_functional_ops.partitioned_call(
- args=args, Tout=[o.type for o in f.definition.signature.output_arg], f=f)
+def partitioned_call(args, f, tout=None, executing_eagerly=None):
+ """Executes a function while respecting device annotations.
+
+ Currently, only those functions that execute within the same address space
+ can be executed.
+
+ Args:
+ args: The arguments of the function, including captured inputs.
+ f: The function to execute; an instance of `_DefinedFunction` or
+ `_EagerDefinedFunction`.
+ tout: a list containing the output dtypes enums; if `None`, inferred from
+ the signature of `f`.
+ executing_eagerly: (Optional) A boolean indicating whether the context is
+ executing eagerly. If `None`, fetched from the global context.
+
+ Returns:
+ The list of `Tensor`s returned by invoking `f(args)`. If the function does
+ not return anything, then returns `None` if eager execution is enabled, or
+ the `Operation` if not.
+ """
+
+ if tout is None:
+ tout = tuple(x.type for x in f.definition.signature.output_arg)
+
+ if executing_eagerly is None:
+ executing_eagerly = context.executing_eagerly()
+
+ if executing_eagerly or len(tout):
+ if f.stateful_ops:
+ outputs = gen_functional_ops.stateful_partitioned_call(
+ args=args, Tout=tout, f=f)
+ else:
+ outputs = gen_functional_ops.partitioned_call(args=args, Tout=tout, f=f)
+ return outputs if outputs else None
+
+ # The generated binding returns an empty list for functions that don't
+ # return any Tensors, hence the need to use `create_op` directly.
+ args = [ops.internal_convert_to_tensor(x) for x in args]
+ tin_attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ type=[x.dtype.as_datatype_enum for x in args]))
+ tout_attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(type=tout))
+ func_attr = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(name=f.name))
+
+ graph = ops.get_default_graph()
+ f.add_to_graph(graph)
+ op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
+ op = graph.create_op(
+ op_name,
+ args,
+ tout,
+ compute_shapes=False,
+ name="PartitionedFunctionCall",
+ attrs={"Tin": tin_attr, "Tout": tout_attr, "f": func_attr})
+ outputs = op.outputs
+ return outputs if outputs else op
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 7385cb7585..b64a66be03 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import contextlib
+import sys
import warnings
import numpy as np
@@ -30,12 +31,14 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops # pylint: disable=unused-import
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
@@ -47,12 +50,17 @@ from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_grad # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
+# This is to avoid a circular dependency with cond_v2_impl.
+cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access
+
# Warn the user if we convert a sparse representation to dense with at
# least this number of elements.
_LARGE_SPARSE_NUM_ELEMENTS = 100000000
@@ -107,12 +115,14 @@ ops.register_tensor_conversion_function(ops.IndexedSlices,
_IndexedSlicesToTensor)
-def _MarkReachedOps(from_ops, reached_ops):
+def _MarkReachedOps(from_ops, reached_ops, func_graphs):
"""Mark all ops reached from "from_ops".
Args:
from_ops: list of Operations.
reached_ops: set of Operations.
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops.
"""
queue = collections.deque()
queue.extend(from_ops)
@@ -122,36 +132,11 @@ def _MarkReachedOps(from_ops, reached_ops):
reached_ops.add(op)
for output in op.outputs:
if _IsBackpropagatable(output):
- queue.extend(output.consumers())
-
-
-def _GatherInputs(to_ops, reached_ops):
- """List all inputs of to_ops that are in reached_ops.
+ queue.extend(_Consumers(output, func_graphs))
- Args:
- to_ops: list of Operations.
- reached_ops: set of Operations.
- Returns:
- The list of all inputs of to_ops that are in reached_ops.
- That list includes all elements of to_ops.
- """
- inputs = []
- queue = collections.deque()
- queue.extend(to_ops)
- while queue:
- op = queue.popleft()
- # We are interested in this op.
- if op in reached_ops:
- inputs.append(op)
- # Clear the boolean so we won't add the inputs again.
- reached_ops.remove(op)
- for inp in op.inputs:
- queue.append(inp.op)
- return inputs
-
-
-def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
+ xs):
"""Initialize the pending count for ops between two lists of Operations.
'pending_count[op]' indicates the number of backprop inputs
@@ -161,6 +146,11 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
to_ops: list of Operations.
from_ops: list of Operations.
colocate_gradients_with_ops: Python bool. See docstring of gradients().
+ func_graphs: list of function._FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops. This is
+ useful if to_ops occur in a function and from_ops are in an outer function
+ or graph.
+ xs: list of Tensors.
Returns:
A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -171,7 +161,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
"""
# Mark reachable ops from from_ops.
reached_ops = set()
- _MarkReachedOps(from_ops, reached_ops)
+ _MarkReachedOps(from_ops, reached_ops, func_graphs)
# X in reached_ops iff X is reachable from from_ops by a path of zero or more
# backpropagatable tensors.
@@ -190,7 +180,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
between_op_list.append(op)
# Clear the boolean so we won't add the inputs again.
reached_ops.remove(op)
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
queue.append(inp.op)
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
@@ -202,7 +192,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
# Initialize pending count for between ops.
pending_count = collections.defaultdict(int)
for op in between_op_list:
- for x in op.inputs:
+ for x in _Inputs(op, xs):
if x.op in between_ops:
pending_count[x.op] += 1
@@ -323,7 +313,7 @@ def _VerifyGeneratedGradients(grads, op):
"inputs %d" % (len(grads), op.node_def, len(op.inputs)))
-def _StopOps(from_ops, stop_gradient_ops, pending_count):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
"""The set of ops that terminate the gradient computation.
This computes the frontier of the forward graph *before* which backprop
@@ -339,6 +329,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
from_ops: list of Operations.
stop_gradient_ops: list of Operations never to backprop through.
pending_count: mapping from operation to number of backprop inputs.
+ xs: list of Tensors.
Returns:
The set of operations.
@@ -346,7 +337,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
stop_ops = set()
for op in from_ops:
is_stop_op = True
- for inp in op.inputs:
+ for inp in _Inputs(op, xs):
if pending_count[inp.op] > 0:
is_stop_op = False
break
@@ -366,15 +357,26 @@ def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pyli
yield
-def _SymGrad(op, out_grads):
+def _IsPartitionedCall(op):
+ return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
+
+
+def _SymGrad(op, out_grads, xs):
"""Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in op.inputs] + out_grads
- f_types = [x.dtype for x in op.inputs]
+ f_in = [x for x in _Inputs(op, xs)] + out_grads
+ f_types = [x.dtype for x in _Inputs(op, xs)]
f = attr_value_pb2.NameAttrList()
- f.name = op.type
+ if _IsPartitionedCall(op):
+ f.name = op.get_attr("f").name
+ else:
+ f.name = op.type
for k in op.node_def.attr:
f.attr[k].CopyFrom(op.node_def.attr[k])
- in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
+ # TODO(apassos) use a better dtype here
+ in_grads = functional_ops.symbolic_gradient(
+ input=f_in,
+ Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types],
+ f=f)
return in_grads
@@ -415,7 +417,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
return grad_fn()
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
"""Raises an error if we backprop through a loop var."""
# Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
# message.
@@ -429,7 +431,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
if curr_op in from_ops:
target_op = curr_op
break
- queue.extend(t.op for t in curr_op.inputs)
+ queue.extend(t.op for t in _Inputs(curr_op, xs))
assert target_op
raise ValueError(
"Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -439,6 +441,68 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
% target_op.name)
+def _MaybeCaptured(t):
+ """If t is a captured value placeholder, returns the original captured value.
+
+ Args:
+ t: Tensor
+
+ Returns:
+ A tensor, potentially from a different Graph/function._FuncGraph.
+ """
+ # pylint: disable=protected-access
+ if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder":
+ for input_t, placeholder_t in t.op.graph._captured.items():
+ if t == placeholder_t:
+ return _MaybeCaptured(input_t)
+ # pylint: enable=protected-access
+ return t
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _Inputs(op, xs):
+ """Returns the inputs of op, crossing closure boundaries where necessary.
+
+ Args:
+ op: Operation
+ xs: list of Tensors we are differentiating w.r.t.
+
+ Returns:
+ A list of tensors. The tensors may be from multiple
+ Graph/function._FuncGraphs if op is in a function._FuncGraph and has
+ captured inputs.
+ """
+ if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
+ # to a captured value. The algorithm needs to "see" `t` in this case, even
+ # if it's a function input for a captured value, whereas usually we'd like
+ # to traverse through these closures as if the captured value was the direct
+ # input to op.
+ return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+ else:
+ return op.inputs
+
+
+def _Consumers(t, func_graphs):
+ """Returns the consumers of t, crossing closure boundaries where necessary.
+
+ Args:
+ t: Tensor
+ func_graphs: a list of function._FuncGraphs that may have captured t.
+
+ Returns:
+ A list of tensors. The tensors will be from the current graph and/or
+ func_graphs.
+ """
+ consumers = t.consumers()
+ for func in func_graphs:
+ for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access
+ if input_t == t:
+ consumers.extend(_Consumers(placeholder, func_graphs))
+ return consumers
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -524,10 +588,10 @@ def gradients(ys,
RuntimeError: if called in Eager mode.
"""
- # Creating the gradient graph for control flow mutates Operations. _lock
- # ensures a Session.run call cannot occur between creating and mutating new
- # ops.
- with ops.get_default_graph()._lock: # pylint: disable=protected-access
+ # Creating the gradient graph for control flow mutates Operations.
+ # _mutation_lock ensures a Session.run call cannot occur between creating and
+ # mutating new ops.
+ with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
gate_gradients, aggregation_method, stop_gradients)
@@ -543,12 +607,19 @@ def _GradientsHelper(ys,
src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
- raise RuntimeError("tf.gradients not supported when eager execution "
- "is enabled. Use tf.contrib.eager.GradientTape "
- "instead.")
+ raise RuntimeError("tf.gradients is not supported when eager execution "
+ "is enabled. Use tf.GradientTape instead.")
if src_graph is None:
src_graph = ops.get_default_graph()
+ # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
+ # ancestor graphs. This is necessary for correctly handling captured values.
+ func_graphs = []
+ curr_graph = src_graph
+ while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access
+ func_graphs.append(curr_graph)
+ curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
+
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
@@ -583,12 +654,13 @@ def _GradientsHelper(ys,
# Initialize the pending count for ops in the connected subgraph from ys
# to the xs.
if len(ys) > 1:
- ys = [array_ops.identity(y) if y.consumers() else y for y in ys]
+ ys = [array_ops.identity(y) if _Consumers(y, func_graphs) else y
+ for y in ys]
to_ops = [t.op for t in ys]
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
reachable_to_ops, pending_count, loop_state = _PendingCount(
- to_ops, from_ops, colocate_gradients_with_ops)
+ to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
# Iterate over the collected ops.
#
@@ -622,7 +694,7 @@ def _GradientsHelper(ys,
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
- stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count)
+ stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
@@ -636,13 +708,19 @@ def _GradientsHelper(ys,
grad_fn = None
func_call = None
+ is_partitioned_call = _IsPartitionedCall(op)
# pylint: disable=protected-access
- is_func_call = src_graph._is_function(op.type)
+ is_func_call = (
+ src_graph._is_function(op.type) or is_partitioned_call)
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
if is_func_call:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
# Note that __defun is not set if the graph is
# imported. If it's set, we prefer to access the original
# defun.
@@ -671,7 +749,7 @@ def _GradientsHelper(ys,
op._control_flow_context.IsWhileContext() and
op._control_flow_context ==
ops.get_default_graph()._get_control_flow_context()):
- _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+ _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
# pylint: enable=protected-access
if (grad_fn or is_func_call) and has_out_grads:
@@ -703,7 +781,7 @@ def _GradientsHelper(ys,
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads))
+ lambda: _SymGrad(op, out_grads, xs))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len([x for x in in_grads
@@ -718,8 +796,8 @@ def _GradientsHelper(ys,
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
- in_grads = [None] * len(op.inputs)
- for i, (t_in, in_grad) in enumerate(zip(op.inputs, in_grads)):
+ in_grads = [None] * len(_Inputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
if in_grad is not None:
if (isinstance(in_grad, ops.Tensor) and
t_in.dtype != dtypes.resource):
@@ -737,7 +815,8 @@ def _GradientsHelper(ys,
loop_state.ExitGradWhileContext(op, before=False)
# Update pending count for the inputs of op and enqueue ready ops.
- _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state)
+ _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs)
if loop_state:
loop_state.PostProcessing()
@@ -756,9 +835,10 @@ def _HasAnyNotNoneGrads(grads, op):
return False
-def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
+def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs):
"""Update pending count for the inputs of op and enqueue ready ops."""
- for x in op.inputs:
+ for x in _Inputs(op, xs):
pending_count[x.op] -= 1
ready = (pending_count[x.op] == 0)
if loop_state and not ready:
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 6891501ae1..d02fcf4ee2 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -57,91 +57,8 @@ from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
-def _OpsBetween(to_ops, from_ops):
- """Build the list of operations between two lists of Operations.
-
- Args:
- to_ops: list of Operations.
- from_ops: list of Operations.
-
- Returns:
- The list of operations between "from_ops" and "to_ops", sorted by
- decreasing operation id. This list contains all elements of to_ops.
-
- TODO(touts): Think about returning an empty list if from_ops are not
- reachable from to_ops. Presently it returns to_ops in that case.
- """
- # Ops that are reachable from the output of "input_ops".
- reached_ops = set()
- # We only care to reach up to "output_ops" so we mark the
- # output ops as reached to avoid recursing past them.
- for op in to_ops:
- reached_ops.add(op)
- gradients_impl._MarkReachedOps(from_ops, reached_ops)
- between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
- between_ops.sort(key=lambda x: -x._id)
- return between_ops
-
-
-@test_util.with_c_api
class GradientsTest(test_util.TensorFlowTestCase):
- def _OpNames(self, op_list):
- return ["%s/%d" % (str(op.name), op._id) for op in op_list]
-
- def _assertOpListEqual(self, ops1, ops2):
- self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
-
- def testOpsBetweenSimple(self):
- with ops.Graph().as_default():
- t1 = constant(1.0)
- t2 = constant(2.0)
- t3 = array_ops.stack([t1, t2])
- # Full graph
- self._assertOpListEqual([t3.op, t2.op, t1.op],
- _OpsBetween([t3.op], [t1.op, t2.op]))
- # Only t1, t3.
- self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))
-
- def testOpsBetweenUnreachable(self):
- with ops.Graph().as_default():
- t1 = constant(1.0)
- t2 = constant(2.0)
- _ = array_ops.stack([t1, t2])
- t4 = constant(1.0)
- t5 = constant(2.0)
- t6 = array_ops.stack([t4, t5])
- # Elements of to_ops are always listed.
- self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))
-
- def testOpsBetweenCut(self):
- with ops.Graph().as_default():
- t1 = constant(1.0)
- t2 = constant(2.0)
- t3 = array_ops.stack([t1, t2])
- t4 = constant([1.0])
- t5 = array_ops.concat([t4, t3], 0)
- t6 = constant([2.0])
- t7 = array_ops.concat([t5, t6], 0)
- self._assertOpListEqual([t7.op, t5.op, t4.op],
- _OpsBetween([t7.op], [t4.op]))
-
- def testOpsBetweenCycle(self):
- with ops.Graph().as_default():
- t1 = constant(1.0)
- t2 = constant(2.0)
- t3 = array_ops.stack([t1, t2])
- t4 = array_ops.concat([t3, t3, t3], 0)
- t5 = constant([1.0])
- t6 = array_ops.concat([t4, t5], 0)
- t7 = array_ops.concat([t6, t3], 0)
- self._assertOpListEqual([t6.op, t4.op, t3.op],
- _OpsBetween([t6.op], [t3.op]))
- self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
- _OpsBetween([t7.op], [t1.op, t5.op]))
- self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
- _OpsBetween([t6.op], [t2.op, t5.op]))
-
def testGradients(self):
with ops.Graph().as_default():
inp = constant(1.0, shape=[32, 100], name="in")
@@ -520,6 +437,96 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
grad_func=grad_func, python_grad_func=self._PythonGradient)
f.add_to_graph(ops.Graph())
+ def testGradientWrtCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(x, 2.0, name="y")
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testGradientOfCaptured(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Foo():
+ g = gradients_impl.gradients(y, x)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedResourceVariable(self):
+ with ops.Graph().as_default():
+ var = resource_variable_ops.ResourceVariable(1.0, name="var")
+
+ @function.Defun()
+ def Foo():
+ y = math_ops.multiply(var, 2.0, name="y")
+ g = gradients_impl.gradients(y, var)
+ return g[0]
+
+ f = Foo()
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(sess.run(f), 2.0)
+
+ def testCapturedNested(self):
+ with ops.Graph().as_default():
+ x1 = constant_op.constant(1.0, name="x1")
+ x2 = constant_op.constant(2.0, name="x2")
+ x3 = math_ops.multiply(x1, x2, name="x3")
+
+ @function.Defun()
+ def Outer():
+ outer1 = array_ops.identity(x1, name="outer1")
+
+ @function.Defun()
+ def Inner():
+ inner1 = array_ops.identity(outer1, name="inner1")
+ inner2 = array_ops.identity(x2, name="inner2")
+ inner3 = array_ops.identity(x3, name="inner3")
+ return gradients_impl.gradients([inner1, inner2, inner3, x1],
+ [x1, x2])
+
+ return Inner()
+
+ x1_grad, x2_grad = Outer()
+ with self.test_session() as sess:
+ # 1.0 + None + 2.0 + 1.0 = 4.0
+ self.assertEqual(sess.run(x1_grad), 4.0)
+ # None + 1.0 + 1.0 + None = 2.0
+ self.assertEqual(sess.run(x2_grad), 2.0)
+
+ def testCapturedFromFunction(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1.0, name="x")
+
+ @function.Defun()
+ def Outer():
+ y = math_ops.multiply(x, 2.0, name="y")
+
+ @function.Defun()
+ def Inner():
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, y)
+ return g[0]
+
+ return Inner()
+
+ z_grad = Outer()
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(z_grad), 3.0)
+
class StopGradientTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index e907fc470b..5b384fd596 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -54,8 +55,10 @@ ops.NotDifferentiable('SampleDistortedBoundingBoxV2')
ops.NotDifferentiable('ExtractGlimpse')
ops.NotDifferentiable('NonMaxSuppression')
ops.NotDifferentiable('NonMaxSuppressionV2')
+ops.NotDifferentiable('NonMaxSuppressionWithOverlaps')
+# pylint: disable=invalid-name
def _assert(cond, ex_type, msg):
"""A polymorphic assert, works with tensors and boolean expressions.
@@ -258,14 +261,14 @@ def random_flip_up_down(image, seed=None):
dimension, which is `height`. Otherwise output the image as-is.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
- A 3-D tensor of the same type and shape as `image`.
-
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
@@ -280,13 +283,14 @@ def random_flip_left_right(image, seed=None):
second dimension, which is `width`. Otherwise output the image as-is.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
@@ -297,7 +301,8 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
Args:
- image: A 3-D tensor of shape `[height, width, channels].`
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
flip_index: The dimension along which to flip the image.
Vertical: 0, Horizontal: 1
seed: A Python integer. Used to create a random seed. See
@@ -306,22 +311,37 @@ def _random_flip(image, flip_index, seed, scope_name):
scope_name: Name of the scope in which the ops are added.
Returns:
- A 3-D tensor of the same type and shape as `image`.
+ A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
- mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(
- mirror_cond,
- lambda: array_ops.reverse(image, [flip_index]),
- lambda: image,
- name=scope)
- return fix_image_flip_shape(image, result)
+ image = _AssertAtLeast3DImage(image)
+ shape = image.get_shape()
+ if shape.ndims == 3 or shape.ndims is None:
+ uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
+ mirror_cond = math_ops.less(uniform_random, .5)
+ result = control_flow_ops.cond(
+ mirror_cond,
+ lambda: array_ops.reverse(image, [flip_index]),
+ lambda: image,
+ name=scope
+ )
+ return fix_image_flip_shape(image, result)
+ elif shape.ndims == 4:
+ uniform_random = random_ops.random_uniform(
+ [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ )
+ mirror_cond = math_ops.less(uniform_random, .5)
+ return array_ops.where(
+ mirror_cond,
+ image,
+ functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ )
+ else:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@tf_export('image.flip_left_right')
@@ -921,12 +941,13 @@ class ResizeMethod(object):
def resize_images(images,
size,
method=ResizeMethod.BILINEAR,
- align_corners=False):
+ align_corners=False,
+ preserve_aspect_ratio=False):
"""Resize `images` to `size` using the specified `method`.
Resized images will be distorted if their original aspect ratio is not
the same as `size`. To avoid distortions see
- @{tf.image.resize_image_with_crop_or_pad}.
+ @{tf.image.resize_image_with_pad}.
`method` can be one of:
@@ -953,6 +974,10 @@ def resize_images(images,
align_corners: bool. If True, the centers of the 4 corner pixels of the
input and output tensors are aligned, preserving the values at the
corner pixels. Defaults to `False`.
+ preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set,
+ then `images` will be resized to a size that fits in `size` while
+ preserving the aspect ratio of the original image. Scales up the image if
+ `size` is bigger than the current size of the `image`. Defaults to False.
Raises:
ValueError: if the shape of `images` is incompatible with the
@@ -991,6 +1016,28 @@ def resize_images(images,
new_height_const = size_const_as_shape[0].value
new_width_const = size_const_as_shape[1].value
+ if preserve_aspect_ratio:
+ # Get the current shapes of the image, even if dynamic.
+ _, current_height, current_width, _ = _ImageDimensions(images, rank=4)
+
+ # do the computation to find the right scale and height/width.
+ scale_factor_height = (math_ops.to_float(new_height_const) /
+ math_ops.to_float(current_height))
+ scale_factor_width = (math_ops.to_float(new_width_const) /
+ math_ops.to_float(current_width))
+ scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
+ scaled_height_const = math_ops.to_int32(scale_factor *
+ math_ops.to_float(current_height))
+ scaled_width_const = math_ops.to_int32(scale_factor *
+ math_ops.to_float(current_width))
+
+ # NOTE: Reset the size and other constants used later.
+ size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
+ dtypes.int32, name='size')
+ size_const_as_shape = tensor_util.constant_value_as_shape(size)
+ new_height_const = size_const_as_shape[0].value
+ new_width_const = size_const_as_shape[1].value
+
# If we can determine that the height and width will be unmodified by this
# transformation, we avoid performing the resize.
if all(x is not None
@@ -1024,6 +1071,106 @@ def resize_images(images,
return images
+@tf_export('image.resize_image_with_pad')
+def resize_image_with_pad(image,
+ target_height,
+ target_width,
+ method=ResizeMethod.BILINEAR):
+ """Resizes and pads an image to a target width and height.
+
+ Resizes an image to a target width and height by keeping
+ the aspect ratio the same without distortion. If the target
+ dimensions don't match the image dimensions, the image
+ is resized and then padded with zeroes to match requested
+ dimensions.
+
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ target_height: Target height.
+ target_width: Target width.
+ method: Method to use for resizing image. See `resize_images()`
+
+ Raises:
+ ValueError: if `target_height` or `target_width` are zero or negative.
+
+ Returns:
+ Resized and padded image.
+ If `images` was 4-D, a 4-D float Tensor of shape
+ `[batch, new_height, new_width, channels]`.
+ If `images` was 3-D, a 3-D float Tensor of shape
+ `[new_height, new_width, channels]`.
+ """
+ with ops.name_scope(None, 'resize_image_with_pad', [image]):
+ image = ops.convert_to_tensor(image, name='image')
+ image_shape = image.get_shape()
+ is_batch = True
+ if image_shape.ndims == 3:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ elif image_shape.ndims is None:
+ is_batch = False
+ image = array_ops.expand_dims(image, 0)
+ image.set_shape([None] * 4)
+ elif image_shape.ndims != 4:
+ raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+
+ assert_ops = _CheckAtLeast3DImage(image, require_static=False)
+ assert_ops += _assert(target_width > 0, ValueError,
+ 'target_width must be > 0.')
+ assert_ops += _assert(target_height > 0, ValueError,
+ 'target_height must be > 0.')
+
+ image = control_flow_ops.with_dependencies(assert_ops, image)
+
+ def max_(x, y):
+ if _is_tensor(x) or _is_tensor(y):
+ return math_ops.maximum(x, y)
+ else:
+ return max(x, y)
+
+ _, height, width, _ = _ImageDimensions(image, rank=4)
+
+ # convert values to float, to ease divisions
+ f_height = math_ops.cast(height, dtype=dtypes.float64)
+ f_width = math_ops.cast(width, dtype=dtypes.float64)
+ f_target_height = math_ops.cast(target_height, dtype=dtypes.float64)
+ f_target_width = math_ops.cast(target_width, dtype=dtypes.float64)
+
+ # Find the ratio by which the image must be adjusted
+ # to fit within the target
+ ratio = max_(f_width / f_target_width, f_height / f_target_height)
+ resized_height_float = f_height / ratio
+ resized_width_float = f_width / ratio
+ resized_height = math_ops.cast(
+ math_ops.floor(resized_height_float), dtype=dtypes.int32)
+ resized_width = math_ops.cast(
+ math_ops.floor(resized_width_float), dtype=dtypes.int32)
+
+ padding_height = (f_target_height - resized_height_float) / 2
+ padding_width = (f_target_width - resized_width_float) / 2
+ f_padding_height = math_ops.floor(padding_height)
+ f_padding_width = math_ops.floor(padding_width)
+ p_height = max_(0, math_ops.cast(f_padding_height, dtype=dtypes.int32))
+ p_width = max_(0, math_ops.cast(f_padding_width, dtype=dtypes.int32))
+
+ # Resize first, then pad to meet requested dimensions
+ resized = resize_images(image, [resized_height, resized_width], method)
+
+ padded = pad_to_bounding_box(resized, p_height, p_width, target_height,
+ target_width)
+
+ if padded.get_shape().ndims is None:
+ raise ValueError('padded contains no shape.')
+
+ _ImageDimensions(padded, rank=4)
+
+ if not is_batch:
+ padded = array_ops.squeeze(padded, squeeze_dims=[0])
+
+ return padded
+
+
@tf_export('image.per_image_standardization')
def per_image_standardization(image):
"""Linearly scales `image` to have zero mean and unit norm.
@@ -1451,6 +1598,75 @@ def adjust_hue(image, delta, name=None):
return convert_image_dtype(rgb_altered, orig_dtype)
+# pylint: disable=invalid-name
+@tf_export('image.random_jpeg_quality')
+def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
+ """Randomly changes jpeg encoding quality for inducing jpeg noise.
+
+ `min_jpeg_quality` must be in the interval `[0, 100]` and less than
+ `max_jpeg_quality`.
+ `max_jpeg_quality` must be in the interval `[0, 100]`.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ min_jpeg_quality: Minimum jpeg encoding quality to use.
+ max_jpeg_quality: Maximum jpeg encoding quality to use.
+ seed: An operation-specific seed. It will be used in conjunction
+ with the graph-level seed to determine the real seeds that will be
+ used in this operation. Please see the documentation of
+ set_random_seed for its interaction with the graph-level random seed.
+
+ Returns:
+ Adjusted image(s), same shape and DType as `image`.
+
+ Raises:
+ ValueError: if `min_jpeg_quality` or `max_jpeg_quality` is invalid.
+ """
+ if (min_jpeg_quality < 0 or max_jpeg_quality < 0 or
+ min_jpeg_quality > 100 or max_jpeg_quality > 100):
+ raise ValueError('jpeg encoding range must be between 0 and 100.')
+
+ if min_jpeg_quality >= max_jpeg_quality:
+ raise ValueError('`min_jpeg_quality` must be less than `max_jpeg_quality`.')
+
+ np.random.seed(seed)
+ jpeg_quality = np.random.randint(min_jpeg_quality, max_jpeg_quality)
+ return adjust_jpeg_quality(image, jpeg_quality)
+
+
+@tf_export('image.adjust_jpeg_quality')
+def adjust_jpeg_quality(image, jpeg_quality, name=None):
+ """Adjust jpeg encoding quality of an RGB image.
+
+ This is a convenience method that adjusts jpeg encoding quality of an
+ RGB image.
+
+ `image` is an RGB image. The image's encoding quality is adjusted
+ to `jpeg_quality`.
+ `jpeg_quality` must be in the interval `[0, 100]`.
+
+ Args:
+ image: RGB image or images. Size of the last dimension must be 3.
+ jpeg_quality: int. jpeg encoding quality.
+ name: A name for this operation (optional).
+
+ Returns:
+ Adjusted image(s), same shape and DType as `image`.
+ """
+ with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name:
+ image = ops.convert_to_tensor(image, name='image')
+ # Remember original dtype to so we can convert back if needed
+ orig_dtype = image.dtype
+ # Convert to uint8
+ image = convert_image_dtype(image, dtypes.uint8)
+ # Encode image to jpeg with given jpeg quality
+ image = gen_image_ops.encode_jpeg(image, quality=jpeg_quality)
+ # Decode jpeg image
+ image = gen_image_ops.decode_jpeg(image)
+ # Convert back to original dtype and return
+ return convert_image_dtype(image, orig_dtype)
+
+
@tf_export('image.random_saturation')
def random_saturation(image, lower, upper, seed=None):
"""Adjust the saturation of an RGB image by a random factor.
@@ -1538,13 +1754,13 @@ def is_jpeg(contents, name=None):
@tf_export('image.decode_image')
-def decode_image(contents, channels=None, name=None):
+def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.
Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
- appropriate operation to convert the input bytes `string` into a `Tensor` of
- type `uint8`.
+ appropriate operation to convert the input bytes `string` into a `Tensor`
+ of type `dtype`.
Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D
@@ -1556,10 +1772,11 @@ def decode_image(contents, channels=None, name=None):
contents: 0-D `string`. The encoded image bytes.
channels: An optional `int`. Defaults to `0`. Number of color channels for
the decoded image.
+ dtype: The desired DType of the returned `Tensor`.
name: A name for the operation (optional)
Returns:
- `Tensor` with type `uint8` with shape `[height, width, num_channels]` for
+ `Tensor` with type `dtype` and shape `[height, width, num_channels]` for
BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for
GIF images.
@@ -1583,7 +1800,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_decode, assert_channels]):
- return gen_image_ops.decode_bmp(contents)
+ return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype)
def _gif():
# Create assert to make sure that channels is not set to 1
@@ -1596,7 +1813,7 @@ def decode_image(contents, channels=None, name=None):
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_gif(contents)
+ return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype)
def check_gif():
# Create assert op to check that bytes are GIF decodable
@@ -1605,7 +1822,11 @@ def decode_image(contents, channels=None, name=None):
def _png():
"""Decodes a PNG image."""
- return gen_image_ops.decode_png(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_png(contents, channels,
+ dtype=dtypes.uint8
+ if dtype == dtypes.uint8
+ else dtypes.uint16), dtype)
def check_png():
"""Checks if an image is PNG."""
@@ -1621,7 +1842,8 @@ def decode_image(contents, channels=None, name=None):
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
- return gen_image_ops.decode_jpeg(contents, channels)
+ return convert_image_dtype(
+ gen_image_ops.decode_jpeg(contents, channels), dtype)
# Decode normal JPEG images (start with \xff\xd8\xff\xe0)
# as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1).
@@ -1872,6 +2094,50 @@ def non_max_suppression(boxes,
iou_threshold, score_threshold)
+@tf_export('image.non_max_suppression_overlaps')
+def non_max_suppression_with_overlaps(overlaps,
+ scores,
+ max_output_size,
+ overlap_threshold=0.5,
+ score_threshold=float('-inf'),
+ name=None):
+ """Greedily selects a subset of bounding boxes in descending order of score.
+
+ Prunes away boxes that have high overlap with previously selected boxes.
+ N-by-n overlap values are supplied as square matrix.
+ The output of this operation is a set of integers indexing into the input
+ collection of bounding boxes representing the selected boxes. The bounding
+ box coordinates corresponding to the selected indices can then be obtained
+ using the `tf.gather operation`. For example:
+ selected_indices = tf.image.non_max_suppression_overlaps(
+ overlaps, scores, max_output_size, iou_threshold)
+ selected_boxes = tf.gather(boxes, selected_indices)
+
+ Args:
+ overlaps: A 2-D float `Tensor` of shape `[num_boxes, num_boxes]`.
+ scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
+ score corresponding to each box (each row of boxes).
+ max_output_size: A scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ overlap_threshold: A float representing the threshold for deciding whether
+ boxes overlap too much with respect to the provided overlap values.
+ score_threshold: A float representing the threshold for deciding when to
+ remove boxes based on score.
+ name: A name for the operation (optional).
+
+ Returns:
+ selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the
+ selected indices from the overlaps tensor, where `M <= max_output_size`.
+ """
+ with ops.name_scope(name, 'non_max_suppression_overlaps'):
+ overlap_threshold = ops.convert_to_tensor(
+ overlap_threshold, name='overlap_threshold')
+ # pylint: disable=protected-access
+ return gen_image_ops._non_max_suppression_v3(
+ overlaps, scores, max_output_size, overlap_threshold, score_threshold)
+ # pylint: enable=protected-access
+
+
_rgb_to_yiq_kernel = [[0.299, 0.59590059,
0.2115], [0.587, -0.27455667, -0.52273617],
[0.114, -0.32134392, 0.31119955]]
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 72c889a2e6..cf9761803b 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -533,6 +533,37 @@ class FlipImageBenchmark(test.Benchmark):
iters=benchmark_rounds,
wall_time=step_time)
+ def _benchmarkBatchedRandomFlipLeftRight(self, device, cpu_count):
+ image_shape = [16, 299, 299, 3]
+ warmup_rounds = 100
+ benchmark_rounds = 1000
+ config = config_pb2.ConfigProto()
+ if cpu_count is not None:
+ config.inter_op_parallelism_threads = 1
+ config.intra_op_parallelism_threads = cpu_count
+ with session.Session("", graph=ops.Graph(), config=config) as sess:
+ with ops.device(device):
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ run_op = image_ops.random_flip_left_right(inputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
+ end = time.time()
+ step_time = (end - start) / benchmark_rounds
+ tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
+ print("benchmarkBatchedRandomFlipLeftRight_16_299_299_3_%s step_time: "
+ "%.2f us" %
+ (tag, step_time * 1e6))
+ self.report_benchmark(
+ name="benchmarkBatchedRandomFlipLeftRight_16_299_299_3_%s" % (tag),
+ iters=benchmark_rounds,
+ wall_time=step_time)
+
def benchmarkFlipLeftRightCpu1(self):
self._benchmarkFlipLeftRight("/cpu:0", 1)
@@ -551,6 +582,15 @@ class FlipImageBenchmark(test.Benchmark):
def benchmarkRandomFlipLeftRightGpu(self):
self._benchmarkRandomFlipLeftRight(test.gpu_device_name(), None)
+ def benchmarkBatchedRandomFlipLeftRightCpu1(self):
+ self._benchmarkBatchedRandomFlipLeftRight("/cpu:0", 1)
+
+ def benchmarkBatchedRandomFlipLeftRightCpuAll(self):
+ self._benchmarkBatchedRandomFlipLeftRight("/cpu:0", None)
+
+ def benchmarkBatchedRandomFlipLeftRightGpu(self):
+ self._benchmarkBatchedRandomFlipLeftRight(test.gpu_device_name(), None)
+
class AdjustHueBenchmark(test.Benchmark):
@@ -987,7 +1027,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_left_right(x_tf)
+ y = image_ops.random_flip_left_right(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_left_right"))
count_flipped = 0
@@ -1008,6 +1048,50 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
self.assertGreaterEqual(count_flipped, 20)
self.assertGreaterEqual(count_unflipped, 20)
+ def testRandomFlipLeftRightWithBatch(self):
+ batch_size = 16
+ seed = 42
+
+ # create single item of test data
+ x_np_raw = np.array(
+ [[1, 2, 3], [1, 2, 3]], dtype=np.uint8
+ ).reshape([1, 2, 3, 1])
+ y_np_raw = np.array(
+ [[3, 2, 1], [3, 2, 1]], dtype=np.uint8
+ ).reshape([1, 2, 3, 1])
+
+ # create batched test data
+ x_np = np.vstack([x_np_raw for _ in range(batch_size)])
+ y_np = np.vstack([y_np_raw for _ in range(batch_size)])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_left_right(x_tf, seed=seed)
+ self.assertTrue(y.op.name.startswith("random_flip_left_right"))
+
+ count_flipped = 0
+ count_unflipped = 0
+ for _ in range(100):
+ y_tf = y.eval()
+
+ # check every element of the batch
+ for i in range(batch_size):
+ if y_tf[i][0][0] == 1:
+ self.assertAllEqual(y_tf[i], x_np[i])
+ count_unflipped += 1
+ else:
+ self.assertAllEqual(y_tf[i], y_np[i])
+ count_flipped += 1
+
+ # 100 trials, each containing batch_size elements
+ # Mean: 50 * batch_size
+ # Std Dev: ~5 * sqrt(batch_size)
+ # Six Sigma: 50 * batch_size - (5 * 6 * sqrt(batch_size))
+ # = 50 * batch_size - 30 * sqrt(batch_size) = 800 - 30 * 4 = 680
+ six_sigma = 50 * batch_size - 30 * np.sqrt(batch_size)
+ self.assertGreaterEqual(count_flipped, six_sigma)
+ self.assertGreaterEqual(count_unflipped, six_sigma)
+
def testInvolutionUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
@@ -1057,9 +1141,11 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
+ seed = 42
+
with self.test_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_up_down(x_tf, seed=42)
+ y = image_ops.random_flip_up_down(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
count_flipped = 0
count_unflipped = 0
@@ -1079,6 +1165,50 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
self.assertGreaterEqual(count_flipped, 20)
self.assertGreaterEqual(count_unflipped, 20)
+ def testRandomFlipUpDownWithBatch(self):
+ batch_size = 16
+ seed = 42
+
+ # create single item of test data
+ x_np_raw = np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=np.uint8
+ ).reshape([1, 2, 3, 1])
+ y_np_raw = np.array(
+ [[4, 5, 6], [1, 2, 3]], dtype=np.uint8
+ ).reshape([1, 2, 3, 1])
+
+ # create batched test data
+ x_np = np.vstack([x_np_raw for _ in range(batch_size)])
+ y_np = np.vstack([y_np_raw for _ in range(batch_size)])
+
+ with self.test_session(use_gpu=True):
+ x_tf = constant_op.constant(x_np, shape=x_np.shape)
+ y = image_ops.random_flip_up_down(x_tf, seed=seed)
+ self.assertTrue(y.op.name.startswith("random_flip_up_down"))
+
+ count_flipped = 0
+ count_unflipped = 0
+ for _ in range(100):
+ y_tf = y.eval()
+
+ # check every element of the batch
+ for i in range(batch_size):
+ if y_tf[i][0][0] == 1:
+ self.assertAllEqual(y_tf[i], x_np[i])
+ count_unflipped += 1
+ else:
+ self.assertAllEqual(y_tf[i], y_np[i])
+ count_flipped += 1
+
+ # 100 trials, each containing batch_size elements
+ # Mean: 50 * batch_size
+ # Std Dev: ~5 * sqrt(batch_size)
+ # Six Sigma: 50 * batch_size - (5 * 6 * sqrt(batch_size))
+ # = 50 * batch_size - 30 * sqrt(batch_size) = 800 - 30 * 4 = 680
+ six_sigma = 50 * batch_size - 30 * np.sqrt(batch_size)
+ self.assertGreaterEqual(count_flipped, six_sigma)
+ self.assertGreaterEqual(count_unflipped, six_sigma)
+
def testInvolutionTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
@@ -1156,6 +1286,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
#Ops that support 4D input
for op in [
image_ops.flip_left_right, image_ops.flip_up_down,
+ image_ops.random_flip_left_right, image_ops.random_flip_up_down,
image_ops.transpose_image, image_ops.rot90
]:
transformed_unknown_dims_4 = op(p_unknown_dims_4)
@@ -1166,14 +1297,6 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
"must be at least three-dimensional"):
op(p_wrong_rank)
- for op in [
- image_ops.random_flip_left_right,
- image_ops.random_flip_up_down,
- ]:
- with self.assertRaisesRegexp(ValueError, "must be three-dimensional"):
- op(p_wrong_rank)
-
-
def testRot90GroupOrder(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
with self.test_session(use_gpu=True):
@@ -1208,41 +1331,6 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase):
y_np = np.rot90(image, k=k, axes=(1, 2))
self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k}))
-class RandomFlipTest(test_util.TensorFlowTestCase):
-
- def testRandomLeftRight(self):
- x_np = np.array([0, 1], dtype=np.uint8).reshape([1, 2, 1])
- num_iterations = 500
-
- hist = [0, 0]
- with self.test_session(use_gpu=True):
- x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_left_right(x_tf)
- for _ in xrange(num_iterations):
- y_np = y.eval().flatten()[0]
- hist[y_np] += 1
-
- # Ensure that each entry is observed within 4 standard deviations.
- four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
- self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
-
- def testRandomUpDown(self):
- x_np = np.array([0, 1], dtype=np.uint8).reshape([2, 1, 1])
- num_iterations = 500
-
- hist = [0, 0]
- with self.test_session(use_gpu=True):
- x_tf = constant_op.constant(x_np, shape=x_np.shape)
- y = image_ops.random_flip_up_down(x_tf)
- for _ in xrange(num_iterations):
- y_np = y.eval().flatten()[0]
- hist[y_np] += 1
-
- # Ensure that each entry is observed within 4 standard deviations.
- four_stddev = 4.0 * np.sqrt(num_iterations / 2.0)
- self.assertAllClose(hist, [num_iterations / 2.0] * 2, atol=four_stddev)
-
-
class AdjustContrastTest(test_util.TensorFlowTestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
@@ -2511,6 +2599,182 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
y = image_ops.resize_images(single_image, [55, 66])
self.assertTrue(y.op.name.startswith("resize_images"))
+ def _ResizeImageCall(self, x, max_h, max_w, preserve_aspect_ratio,
+ use_tensor_inputs):
+ if use_tensor_inputs:
+ target_max = ops.convert_to_tensor([max_h, max_w])
+ x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim)
+ feed_dict = {x_tensor: x}
+ else:
+ target_max = [max_h, max_w]
+ x_tensor = x
+ feed_dict = {}
+
+ y = image_ops.resize_images(x_tensor, target_max,
+ preserve_aspect_ratio=preserve_aspect_ratio)
+
+ with self.test_session(use_gpu=True):
+ return y.eval(feed_dict=feed_dict)
+
+ def _assertResizeEqual(self, x, x_shape, y, y_shape,
+ preserve_aspect_ratio=True,
+ use_tensor_inputs_options=None):
+ use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
+ target_height, target_width, _ = y_shape
+ x = np.array(x).reshape(x_shape)
+ y = np.array(y).reshape(y_shape)
+
+ for use_tensor_inputs in use_tensor_inputs_options:
+ y_tf = self._ResizeImageCall(x, target_height, target_width,
+ preserve_aspect_ratio, use_tensor_inputs)
+ self.assertAllClose(y, y_tf)
+
+ def _assertResizeCheckShape(self, x, x_shape, target_shape,
+ y_shape, preserve_aspect_ratio=True,
+ use_tensor_inputs_options=None):
+ use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
+ target_height, target_width = target_shape
+ x = np.array(x).reshape(x_shape)
+ y = np.zeros(y_shape)
+
+ for use_tensor_inputs in use_tensor_inputs_options:
+ y_tf = self._ResizeImageCall(x, target_height, target_width,
+ preserve_aspect_ratio, use_tensor_inputs)
+ self.assertShapeEqual(y, ops.convert_to_tensor(y_tf))
+
+ def testPreserveAspectRatioMultipleImages(self):
+ x_shape = [10, 100, 100, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [250, 250], [10, 250, 250, 10],
+ preserve_aspect_ratio=False)
+
+ def testPreserveAspectRatioNoOp(self):
+ x_shape = [10, 10, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeEqual(x, x_shape, x, x_shape)
+
+ def testPreserveAspectRatioSmaller(self):
+ x_shape = [100, 100, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [75, 50], [50, 50, 10])
+
+ def testPreserveAspectRatioSmallerMultipleImages(self):
+ x_shape = [10, 100, 100, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [75, 50], [10, 50, 50, 10])
+
+ def testPreserveAspectRatioLarger(self):
+ x_shape = [100, 100, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [150, 200], [150, 150, 10])
+
+ def testPreserveAspectRatioSameRatio(self):
+ x_shape = [1920, 1080, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+
+
+class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
+
+ def _ResizeImageWithPad(self, x, target_height, target_width,
+ use_tensor_inputs):
+ if use_tensor_inputs:
+ target_height = ops.convert_to_tensor(target_height)
+ target_width = ops.convert_to_tensor(target_width)
+ x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim)
+ feed_dict = {x_tensor: x}
+ else:
+ x_tensor = x
+ feed_dict = {}
+
+ y = image_ops.resize_image_with_pad(x_tensor, target_height,
+ target_width)
+ if not use_tensor_inputs:
+ self.assertTrue(y.get_shape().is_fully_defined())
+
+ with self.test_session(use_gpu=True):
+ return y.eval(feed_dict=feed_dict)
+
+ def _assertReturns(self,
+ x,
+ x_shape,
+ y,
+ y_shape,
+ use_tensor_inputs_options=None):
+ use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
+ target_height, target_width, _ = y_shape
+ x = np.array(x).reshape(x_shape)
+ y = np.array(y).reshape(y_shape)
+
+ for use_tensor_inputs in use_tensor_inputs_options:
+ y_tf = self._ResizeImageWithPad(x, target_height, target_width,
+ use_tensor_inputs)
+ self.assertAllClose(y, y_tf)
+
+ def _assertRaises(self,
+ x,
+ x_shape,
+ target_height,
+ target_width,
+ err_msg,
+ use_tensor_inputs_options=None):
+ use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
+ x = np.array(x).reshape(x_shape)
+
+ for use_tensor_inputs in use_tensor_inputs_options:
+ try:
+ self._ResizeImageWithPad(x, target_height, target_width,
+ use_tensor_inputs)
+ except Exception as e: # pylint: disable=broad-except
+ if err_msg not in str(e):
+ raise
+ else:
+ raise AssertionError("Exception not raised: %s" % err_msg)
+
+ def _assertShapeInference(self, pre_shape, height, width, post_shape):
+ image = array_ops.placeholder(dtypes.float32, shape=pre_shape)
+ y = image_ops.resize_image_with_pad(image, height, width)
+ self.assertEqual(y.get_shape().as_list(), post_shape)
+
+ def testNoOp(self):
+ x_shape = [10, 10, 10]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertReturns(x, x_shape, x, x_shape)
+
+ def testPad(self):
+ # Reduce vertical dimension
+ x = [1, 2, 3, 4, 5, 6, 7, 8]
+ x_shape = [2, 4, 1]
+
+ y = [0, 1, 3, 0]
+ y_shape = [1, 4, 1]
+
+ self._assertReturns(x, x_shape, y, y_shape)
+
+ # Reduce horizontal dimension
+ x = [1, 2, 3, 4, 5, 6, 7, 8]
+ x_shape = [2, 4, 1]
+
+ y = [1, 3, 0, 0]
+ y_shape = [2, 2, 1]
+
+ self._assertReturns(x, x_shape, y, y_shape)
+
+ x = [1, 2, 3, 4, 5, 6, 7, 8]
+ x_shape = [2, 4, 1]
+
+ y = [1, 3]
+ y_shape = [1, 2, 1]
+
+ self._assertReturns(x, x_shape, y, y_shape)
+
class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
@@ -3800,5 +4064,88 @@ class SobelEdgesTest(test_util.TensorFlowTestCase):
self.assertAllClose(expected_batch, actual_sobel)
+class DecodeImageTest(test_util.TensorFlowTestCase):
+
+ def testJpegUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/jpeg/testdata"
+ jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+ image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testPngUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/png/testdata"
+ png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+ image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(
+ image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testGifUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/gif/testdata"
+ gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+ image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testBmpUint16(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/bmp/testdata"
+ bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+ image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+ dtypes.uint16)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testJpegFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/jpeg/testdata"
+ jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
+ image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testPngFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/png/testdata"
+ png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
+ image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(
+ image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testGifFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/gif/testdata"
+ gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
+ image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+ def testBmpFloat32(self):
+ with self.test_session(use_gpu=True) as sess:
+ base = "tensorflow/core/lib/bmp/testdata"
+ bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
+ image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
+ image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
+ dtypes.float32)
+ image0, image1 = sess.run([image0, image1])
+ self.assertAllEqual(image0, image1)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 2df230d470..3132f7467f 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -43,7 +43,8 @@ from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.deprecation import (
+ deprecated, deprecated_arg_values)
from tensorflow.python.util.tf_export import tf_export
@@ -409,8 +410,10 @@ class UniformUnitScaling(Initializer):
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
- With `distribution="normal"`, samples are drawn from a truncated normal
- distribution centered on zero, with `stddev = sqrt(scale / n)`
+ With `distribution="truncated_normal" or "untruncated_normal"`,
+ samples are drawn from a truncated/untruncated normal
+ distribution with a mean of zero and a standard deviation (after truncation,
+ if used) `stddev = sqrt(scale / n)`
where n is:
- number of input units in the weight tensor, if mode = "fan_in"
- number of output units, if mode = "fan_out"
@@ -433,10 +436,14 @@ class VarianceScaling(Initializer):
"distribution" arguments.
"""
+ @deprecated_arg_values(
+ None,
+ "`normal` is a deprecated alias for `truncated_normal`",
+ distribution="normal")
def __init__(self,
scale=1.0,
mode="fan_in",
- distribution="normal",
+ distribution="truncated_normal",
seed=None,
dtype=dtypes.float32):
if scale <= 0.:
@@ -444,7 +451,8 @@ class VarianceScaling(Initializer):
if mode not in {"fan_in", "fan_out", "fan_avg"}:
raise ValueError("Invalid `mode` argument:", mode)
distribution = distribution.lower()
- if distribution not in {"normal", "uniform"}:
+ if distribution not in {"normal", "uniform",
+ "truncated_normal", "untruncated_normal"}:
raise ValueError("Invalid `distribution` argument:", distribution)
self.scale = scale
self.mode = mode
@@ -466,10 +474,15 @@ class VarianceScaling(Initializer):
scale /= max(1., fan_out)
else:
scale /= max(1., (fan_in + fan_out) / 2.)
- if self.distribution == "normal":
- stddev = math.sqrt(scale)
+ if self.distribution == "normal" or self.distribution == "truncated_normal":
+ # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+ stddev = math.sqrt(scale) / .87962566103423978
return random_ops.truncated_normal(
shape, 0.0, stddev, dtype, seed=self.seed)
+ elif self.distribution == "untruncated_normal":
+ stddev = math.sqrt(scale)
+ return random_ops.random_normal(
+ shape, 0.0, stddev, dtype, seed=self.seed)
else:
limit = math.sqrt(3.0 * scale)
return random_ops.random_uniform(
@@ -550,7 +563,9 @@ class ConvolutionDeltaOrthogonal(Initializer):
The shape of the tensor must have length 3, 4 or 5. The number of input
filters must not exceed the number of output filters. The center pixels of the
- tensor form an orthogonal matrix. Other pixels are set to be zero.
+ tensor form an orthogonal matrix. Other pixels are set to be zero. See
+ algorithm 2 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
+
Args:
gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
@@ -671,6 +686,7 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
+ See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
Args:
gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
@@ -806,6 +822,7 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
+ See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
Args:
gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
@@ -922,6 +939,7 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
+ See algorithm 1 [Xiao et al., 2018] in: https://arxiv.org/abs/1806.05393
Args:
gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
@@ -1118,7 +1136,7 @@ convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name
-@tf_export("glorot_uniform_initializer")
+@tf_export("glorot_uniform_initializer", "keras.initializers.glorot_uniform")
def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot uniform initializer, also called Xavier uniform initializer.
@@ -1142,7 +1160,7 @@ def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
scale=1.0, mode="fan_avg", distribution="uniform", seed=seed, dtype=dtype)
-@tf_export("glorot_normal_initializer")
+@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal")
def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
"""The Glorot normal initializer, also called Xavier normal initializer.
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 8cfe964b1c..20c46fbb82 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -42,7 +42,7 @@ __all__ = ["LinearOperator"]
class LinearOperator(object):
"""Base class defining a [batch of] linear operator[s].
- Subclasses of `LinearOperator` provide a access to common methods on a
+ Subclasses of `LinearOperator` provide access to common methods on a
(batch) matrix, without the need to materialize the matrix. This allows:
* Matrix free computations
@@ -69,11 +69,11 @@ class LinearOperator(object):
#### Shape compatibility
- `LinearOperator` sub classes should operate on a [batch] matrix with
+ `LinearOperator` subclasses should operate on a [batch] matrix with
compatible shape. Class docstrings should define what is meant by compatible
- shape. Some sub-classes may not support batching.
+ shape. Some subclasses may not support batching.
- An example is:
+ Examples:
`x` is a batch matrix with compatible shape for `matmul` if
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index 5beaea65a5..ed53decc00 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -231,8 +231,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
def _log_abs_determinant(self):
- return math_ops.reduce_sum(
+ log_det = math_ops.reduce_sum(
math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+ if self.dtype.is_complex:
+ log_det = math_ops.cast(log_det, dtype=self.dtype)
+ return log_det
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
diag_term = math_ops.conj(self._diag) if adjoint else self._diag
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 08e5896e10..2b2bf80f27 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -18,16 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -153,8 +152,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
`is_X` matrix property hints, which will trigger the appropriate code path.
Args:
- base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or
- `float64` `LinearOperator`. This is `L` above.
+ base_operator: Shape `[B1,...,Bb, M, N]`.
u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
This is `U` above.
diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
@@ -183,23 +181,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
Raises:
ValueError: If `is_X` flags are set in an inconsistent way.
"""
- # TODO(langmore) support complex types.
- # Complex types are not allowed due to tf.cholesky() requiring float.
- # If complex dtypes are allowed, we update the following
- # 1. is_diag_update_positive should still imply that `diag > 0`, but we need
- # to remind the user that this implies diag is real. This is needed
- # because if diag has non-zero imaginary part, it will not be
- # self-adjoint positive definite.
dtype = base_operator.dtype
- allowed_dtypes = [
- dtypes.float16,
- dtypes.float32,
- dtypes.float64,
- ]
- if dtype not in allowed_dtypes:
- raise TypeError(
- "Argument matrix must have dtype in %s. Found: %s"
- % (allowed_dtypes, dtype))
+
+ if diag_update is not None:
+ if is_diag_update_positive and dtype.is_complex:
+ logging.warn("Note: setting is_diag_update_positive with a complex "
+ "dtype means that diagonal is real and positive.")
if diag_update is None:
if is_diag_update_positive is False:
@@ -271,8 +258,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
self._set_diag_operators(diag_update, is_diag_update_positive)
self._is_diag_update_positive = is_diag_update_positive
- check_ops.assert_same_float_dtype((base_operator, self.u, self.v,
- self._diag_update))
self._check_shapes()
# Pre-compute the so-called "capacitance" matrix
@@ -407,6 +392,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
else:
det_c = linalg_ops.matrix_determinant(self._capacitance)
log_abs_det_c = math_ops.log(math_ops.abs(det_c))
+ if self.dtype.is_complex:
+ log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)
return log_abs_det_c + log_abs_det_d + log_abs_det_l
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index fb1eb2fedb..ca6d3f5405 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -119,8 +119,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
Args:
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
The lower triangular part of `tril` defines this operator. The strictly
- upper triangle is ignored. Allowed dtypes: `float16`, `float32`,
- `float64`.
+ upper triangle is ignored.
is_non_singular: Expect that this operator is non-singular.
This operator is non-singular if and only if its diagonal elements are
all non-zero.
@@ -137,7 +136,6 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
name: A name for this `LinearOperator`.
Raises:
- TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `is_square` is `False`.
"""
@@ -163,12 +161,12 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
def _check_tril(self, tril):
"""Static check of the `tril` argument."""
- # TODO(langmore) Add complex types once matrix_triangular_solve works for
- # them.
allowed_dtypes = [
dtypes.float16,
dtypes.float32,
dtypes.float64,
+ dtypes.complex64,
+ dtypes.complex128,
]
dtype = tril.dtype
if dtype not in allowed_dtypes:
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 1b5bb9470c..78c85db557 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -102,7 +102,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
raise NotImplementedError("operator_build_infos has not been implemented.")
@abc.abstractmethod
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
"""Build a batch matrix and an Operator that should have similar behavior.
Every operator acts like a (batch) matrix. This method returns both
@@ -118,9 +118,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
Returns:
operator: `LinearOperator` subclass instance.
mat: `Tensor` representing operator.
- feed_dict: Dictionary.
- If placholder is True, this must contains everything needed to be fed
- to sess.run calls at runtime to make the operator work.
"""
# Create a matrix as a numpy array with desired shape/dtype.
# Create a LinearOperator that should have the same behavior as the matrix.
@@ -189,12 +186,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_dense = operator.to_dense()
if not use_placeholder:
self.assertAllEqual(build_info.shape, op_dense.get_shape())
- op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict)
+ op_dense_v, mat_v = sess.run([op_dense, mat])
self.assertAC(op_dense_v, mat_v)
def test_det(self):
@@ -204,14 +201,13 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_det = operator.determinant()
if not use_placeholder:
self.assertAllEqual(build_info.shape[:-2], op_det.get_shape())
op_det_v, mat_det_v = sess.run(
- [op_det, linalg_ops.matrix_determinant(mat)],
- feed_dict=feed_dict)
+ [op_det, linalg_ops.matrix_determinant(mat)])
self.assertAC(op_det_v, mat_det_v)
def test_log_abs_det(self):
@@ -221,7 +217,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_log_abs_det = operator.log_abs_determinant()
_, mat_log_abs_det = linalg.slogdet(mat)
@@ -229,7 +225,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
self.assertAllEqual(
build_info.shape[:-2], op_log_abs_det.get_shape())
op_log_abs_det_v, mat_log_abs_det_v = sess.run(
- [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict)
+ [op_log_abs_det, mat_log_abs_det])
self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
def _test_matmul(self, with_batch):
@@ -246,7 +242,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for adjoint_arg in self._adjoint_arg_options:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
x = self._make_x(
operator, adjoint=adjoint, with_batch=with_batch)
@@ -264,7 +260,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
self.assertAllEqual(op_matmul.get_shape(),
mat_matmul.get_shape())
op_matmul_v, mat_matmul_v = sess.run(
- [op_matmul, mat_matmul], feed_dict=feed_dict)
+ [op_matmul, mat_matmul])
self.assertAC(op_matmul_v, mat_matmul_v)
def test_matmul(self):
@@ -289,7 +285,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for adjoint_arg in self._adjoint_arg_options:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
rhs = self._make_rhs(
operator, adjoint=adjoint, with_batch=with_batch)
@@ -307,8 +303,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
if not use_placeholder:
self.assertAllEqual(op_solve.get_shape(),
mat_solve.get_shape())
- op_solve_v, mat_solve_v = sess.run(
- [op_solve, mat_solve], feed_dict=feed_dict)
+ op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
self.assertAC(op_solve_v, mat_solve_v)
def test_solve(self):
@@ -326,14 +321,13 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_trace = operator.trace()
mat_trace = math_ops.trace(mat)
if not use_placeholder:
self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape())
- op_trace_v, mat_trace_v = sess.run(
- [op_trace, mat_trace], feed_dict=feed_dict)
+ op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace])
self.assertAC(op_trace_v, mat_trace_v)
def test_add_to_tensor(self):
@@ -343,15 +337,14 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_plus_2mat = operator.add_to_tensor(2 * mat)
if not use_placeholder:
self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape())
- op_plus_2mat_v, mat_v = sess.run(
- [op_plus_2mat, mat], feed_dict=feed_dict)
+ op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat])
self.assertAC(op_plus_2mat_v, 3 * mat_v)
@@ -362,7 +355,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_diag_part = operator.diag_part()
mat_diag_part = array_ops.matrix_diag_part(mat)
@@ -372,7 +365,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
op_diag_part.get_shape())
op_diag_part_, mat_diag_part_ = sess.run(
- [op_diag_part, mat_diag_part], feed_dict=feed_dict)
+ [op_diag_part, mat_diag_part])
self.assertAC(op_diag_part_, mat_diag_part_)
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 3cbbf3412a..b6b98d5c86 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -55,6 +55,17 @@ def _MatrixDeterminantGrad(op, grad):
return multipliers * a_adj_inv
+@ops.RegisterGradient("LogMatrixDeterminant")
+def _LogMatrixDeterminantGrad(op, _, grad_b):
+ """Gradient for LogMatrixDeterminant."""
+ a = op.inputs[0]
+ c = op.outputs[1]
+ a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True)
+ multipliers = array_ops.reshape(
+ grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0))
+ return multipliers * a_adj_inv
+
+
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
"""Gradient for Cholesky."""
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index 8276047cb6..df41933f8a 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -35,9 +35,12 @@ from tensorflow.python.util.tf_export import tf_export
# Assert and Print are special symbols in python, so we must
-# have an upper-case version of them. For users with Python 3 or Python 2.7
-# with `from __future__ import print_function`, we also allow lowercase.
-@tf_export("Print", "print")
+# have an upper-case version of them.
+#
+# For users with Python 3 or Python 2.7
+# with `from __future__ import print_function`, we could also allow lowercase.
+# See https://github.com/tensorflow/tensorflow/issues/18053
+@tf_export("Print")
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 0e547689cc..fb51fbc626 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -366,6 +366,10 @@ class KeyValueTensorInitializer(TableInitializerBase):
with ops.name_scope(
self._name, values=(table.table_ref, self._keys,
self._values)) as scope:
+ if context.executing_eagerly():
+ # Ensure a unique name when eager execution is enabled to avoid spurious
+ # sharing issues.
+ scope += str(ops.uid())
init_op = gen_lookup_ops.initialize_table_v2(
table.table_ref, self._keys, self._values, name=scope)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
@@ -1108,6 +1112,10 @@ def index_table_from_tensor(vocabulary_list,
shared_name = ""
with ops.name_scope(None, "hash_table") as hash_table_scope:
+ if context.executing_eagerly():
+ # Ensure a unique name when eager execution is enabled to avoid spurious
+ # sharing issues.
+ shared_name += str(ops.uid())
table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
init = KeyValueTensorInitializer(
table_keys,
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index de9b3c6909..66633c8b12 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -192,6 +192,11 @@ def compute_weighted_loss(
on some model parameters but you do not want this to affect the loss
gradient, you need to apply @{tf.stop_gradient} to `weights` before
passing them to `compute_weighted_loss`.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
Reduction.validate(reduction)
with ops.name_scope(scope, "weighted_loss", (losses, weights)):
@@ -260,6 +265,11 @@ def absolute_difference(
ValueError: If the shape of `predictions` doesn't match that of
`labels` or if the shape of `weights` is invalid or if `labels`
or `predictions` is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -306,6 +316,11 @@ def cosine_distance(
Raises:
ValueError: If `predictions` shape doesn't match `labels` shape, or
`axis`, `labels`, `predictions` or `weights` is `None`.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
axis = deprecated_argument_lookup("axis", axis, "dim", dim)
if axis is None:
@@ -353,6 +368,11 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
Raises:
ValueError: If the shapes of `logits` and `labels` don't match or
if `labels` or `logits` is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -416,6 +436,11 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
ValueError: If the shape of `predictions` doesn't match that of `labels` or
if the shape of `weights` is invalid. Also if `labels` or
`predictions` is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -477,6 +502,11 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
ValueError: If the shape of `predictions` doesn't match that of `labels` or
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -540,6 +570,11 @@ def mean_pairwise_squared_error(
ValueError: If the shape of `predictions` doesn't match that of `labels` or
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -618,6 +653,11 @@ def mean_squared_error(
ValueError: If the shape of `predictions` doesn't match that of `labels` or
if the shape of `weights` is invalid. Also if `labels` or `predictions`
is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
@@ -670,6 +710,11 @@ def sigmoid_cross_entropy(
ValueError: If the shape of `logits` doesn't match that of
`multi_class_labels` or if the shape of `weights` is invalid, or if
`weights` is None. Also if `multi_class_labels` or `logits` is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if multi_class_labels is None:
raise ValueError("multi_class_labels must not be None.")
@@ -731,6 +776,11 @@ def softmax_cross_entropy(
ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
or if the shape of `weights` is invalid or if `weights` is None. Also if
`onehot_labels` or `logits` is None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if onehot_labels is None:
raise ValueError("onehot_labels must not be None.")
@@ -828,7 +878,8 @@ def sparse_softmax_cross_entropy(
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
- `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
+ `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32` or
+ `float64`.
weights: Coefficients for the loss. This must be scalar or broadcastable to
`labels` (i.e. same rank and each dimension is either 1 or the same).
scope: the scope for the operations performed in computing the loss.
@@ -842,6 +893,11 @@ def sparse_softmax_cross_entropy(
Raises:
ValueError: If the shapes of `logits`, `labels`, and `weights` are
incompatible, or if any of them are None.
+
+ @compatbility(eager)
+ The `loss_collection` argument is ignored when executing eagerly. Consider
+ holding on to the return value or collecting losses via a `tf.keras.Model`.
+ @end_compatibility
"""
if labels is None:
raise ValueError("labels must not be None.")
diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py
index 10646af8a9..97bba46661 100644
--- a/tensorflow/python/ops/losses/util.py
+++ b/tensorflow/python/ops/losses/util.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@@ -32,7 +33,10 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
loss: A loss `Tensor`.
loss_collection: Optional collection to add the loss to.
"""
- if loss_collection:
+ # Since we have no way of figuring out when a training iteration starts or
+ # ends, holding on to a loss when executing eagerly is indistingishable from
+ # leaking memory. We instead leave the collection empty.
+ if loss_collection and not context.executing_eagerly():
ops.add_to_collection(loss_collection, loss)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 563c0b3ab3..f0c6bd532f 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -620,29 +620,59 @@ def _DigammaGrad(op, grad):
return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
+@ops.RegisterGradient("BesselI0e")
+def _BesselI0eGrad(op, grad):
+ """Compute gradient of bessel_i0e(x) with respect to its argument."""
+ x = op.inputs[0]
+ y = op.outputs[0]
+ with ops.control_dependencies([grad]):
+ return grad * (math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
+
+
+@ops.RegisterGradient("BesselI1e")
+def _BesselI1eGrad(op, grad):
+ """Compute gradient of bessel_i1e(x) with respect to its argument."""
+ x = op.inputs[0]
+ y = op.outputs[0]
+ with ops.control_dependencies([grad]):
+ # For x = 0, the correct gradient is 0.5.
+ # However, the main branch gives NaN because of the division by x, so
+ # we impute the gradient manually.
+ # An alternative solution is to express the gradient via bessel_i0e and
+ # bessel_i2e, but the latter is not yet implemented in Eigen.
+ eps = np.finfo(x.dtype.as_numpy_dtype).eps
+ zeros = array_ops.zeros_like(x)
+ x_is_not_tiny = math_ops.abs(x) > eps
+ safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
+ dy_dx = math_ops.bessel_i0e(safe_x) - y * (
+ math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
+ return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
+
+
@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
- """Returns gradient of igamma(a, x) with respect to x."""
- # TODO(ebrevdo): Perhaps add the derivative w.r.t. a
+ """Returns gradient of igamma(a, x) with respect to a and x."""
a = op.inputs[0]
x = op.inputs[1]
sa = array_ops.shape(a)
sx = array_ops.shape(x)
- unused_ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
+ ra, rx = gen_array_ops.broadcast_gradient_args(sa, sx)
- # Perform operations in log space before summing, because Gamma(a)
- # and Gamma'(a) can grow large.
- partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x) - math_ops.lgamma(a))
- # TODO(b/36815900): Mark None return values as NotImplemented
- return (None, array_ops.reshape(
- math_ops.reduce_sum(partial_x * grad, rx), sx))
+ with ops.control_dependencies([grad]):
+ partial_a = gen_math_ops.igamma_grad_a(a, x)
+ # Perform operations in log space before summing, because Gamma(a)
+ # and Gamma'(a) can grow large.
+ partial_x = math_ops.exp(-x + (a - 1) * math_ops.log(x)
+ - math_ops.lgamma(a))
+ return (array_ops.reshape(math_ops.reduce_sum(partial_a * grad, ra), sa),
+ array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))
@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
- """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. x."""
- _, igamma_grad_x = _IgammaGrad(op, grad)
- return None, -igamma_grad_x
+ """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
+ igamma_grad_a, igamma_grad_x = _IgammaGrad(op, grad)
+ return (-igamma_grad_a, -igamma_grad_x)
@ops.RegisterGradient("Betainc")
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 118b02c6c7..c28dca5137 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -37,11 +37,11 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
-from tensorflow.python.platform import tf_logging as logging
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_math_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -370,7 +370,7 @@ def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
Args:
- x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
@@ -651,6 +651,9 @@ def cast(x, dtype, name=None):
TypeError: If `x` cannot be cast to the `dtype`.
"""
base_type = dtypes.as_dtype(dtype).base_dtype
+ if isinstance(x,
+ (ops.Tensor, _resource_variable_type)) and base_type == x.dtype:
+ return x
with ops.name_scope(name, "Cast", [x]) as name:
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
@@ -1222,8 +1225,9 @@ def _ReductionDims(x, axis, reduction_indices):
return axis
else:
# Fast path: avoid creating Rank and Range ops if ndims is known.
- if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access
- return constant_op.constant(np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access
+ rank = common_shapes.rank(x)
+ if rank is not None:
+ return constant_op.constant(np.arange(rank), dtype=dtypes.int32)
if (isinstance(x, sparse_tensor.SparseTensor) and
x.dense_shape.get_shape().is_fully_defined()):
rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D.
@@ -1234,8 +1238,8 @@ def _ReductionDims(x, axis, reduction_indices):
def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
- """Set a reduction's output's shape to be a scalar if we are certain."""
- if (not output.shape.is_fully_defined()) and (not keepdims) and (
+ """Set a reduction's output shape to be a scalar if we are certain."""
+ if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
axis is None) and (reduction_indices is None):
output.set_shape(())
return output
@@ -1617,7 +1621,7 @@ def reduce_all(input_tensor,
entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
- If `axis` has no entries, all dimensions are reduced, and a
+ If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned.
For example:
@@ -1675,7 +1679,7 @@ def reduce_any(input_tensor,
entry in `axis`. If `keepdims` is true, the reduced dimensions
are retained with length 1.
- If `axis` has no entries, all dimensions are reduced, and a
+ If `axis` is None, all dimensions are reduced, and a
tensor with a single element is returned.
For example:
@@ -1990,7 +1994,7 @@ def matmul(a,
sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
use_sparse_matmul = (
a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
- if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and
+ if ((a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16) and
a.dtype != b.dtype):
# matmul currently doesn't handle mixed-precision inputs.
use_sparse_matmul = True
@@ -2225,8 +2229,8 @@ def sigmoid(x, name=None):
Returns:
A Tensor with the same type as `x`.
- @compatibility(numpy)
- Equivalent to np.scipy.special.expit
+ @compatibility(scipy)
+ Equivalent to scipy.special.expit
@end_compatibility
"""
with ops.name_scope(name, "Sigmoid", [x]) as name:
@@ -2954,6 +2958,67 @@ def polyval(coeffs, x, name=None):
p = c + p * x
return p
+
+@tf_export("math.bessel_i0e")
+def bessel_i0e(x, name=None):
+ """Computes the Bessel i0e function of `x` element-wise.
+
+ Exponentially scaled modified Bessel function of order 0 defined as
+ `bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
+
+ This function is faster and numerically stabler than `bessel_i0(x)`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.special.i0e
+ @end_compatibility
+ """
+ with ops.name_scope(name, "bessel_i0e", [x]) as name:
+ if isinstance(x, sparse_tensor.SparseTensor):
+ x_i0e = gen_math_ops.bessel_i0e(x.values, name=name)
+ return sparse_tensor.SparseTensor(
+ indices=x.indices, values=x_i0e, dense_shape=x.dense_shape)
+ else:
+ return gen_math_ops.bessel_i0e(x, name=name)
+
+
+@tf_export("math.bessel_i1e")
+def bessel_i1e(x, name=None):
+ """Computes the Bessel i1e function of `x` element-wise.
+
+ Exponentially scaled modified Bessel function of order 1 defined as
+ `bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
+
+ This function is faster and numerically stabler than `bessel_i1(x)`.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.special.i1e
+ @end_compatibility
+ """
+ with ops.name_scope(name, "bessel_i1e", [x]) as name:
+ if isinstance(x, sparse_tensor.SparseTensor):
+ x_i1e = gen_math_ops.bessel_i1e(x.values, name=name)
+ return sparse_tensor.SparseTensor(
+ indices=x.indices, values=x_i1e, dense_shape=x.dense_shape)
+ else:
+ return gen_math_ops.bessel_i1e(x, name=name)
+
+
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
# 1.0 API so we leave these here for backwards compatibility.
fft = gen_spectral_ops.fft
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 980c92b0d5..6b709e5e7f 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -37,14 +37,14 @@ log = np.log
class ReduceTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testReduceAllDims(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
with test_util.device(use_gpu=True):
y_tf = self.evaluate(math_ops.reduce_sum(x))
self.assertEqual(y_tf, 21)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testReduceExplicitAxes(self):
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
with test_util.device(use_gpu=True):
@@ -57,7 +57,7 @@ class ReduceTest(test_util.TensorFlowTestCase):
for axis in (None, (0, 1), (-1, -2), (-2, -1, 0, 1)):
self.assertEqual(self.evaluate(math_ops.reduce_sum(x, axis=axis)), 21)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testReduceInvalidAxis(self):
if context.executing_eagerly():
# The shape check is in run a graph construction time. In eager mode,
@@ -150,7 +150,7 @@ class LogSumExpTest(test_util.TensorFlowTestCase):
class RoundTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRounding(self):
x = np.arange(-5.0, 5.0, .25)
for dtype in [np.float32, np.double, np.int32]:
@@ -194,7 +194,7 @@ class ModTest(test_util.TensorFlowTestCase):
class SquaredDifferenceTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSquaredDifference(self):
for dtype in [np.int32, np.float16]:
x = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
@@ -207,7 +207,7 @@ class SquaredDifferenceTest(test_util.TensorFlowTestCase):
class ApproximateEqualTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testApproximateEqual(self):
for dtype in [np.float32, np.double]:
x = dtype(1)
@@ -235,10 +235,19 @@ class ApproximateEqualTest(test_util.TensorFlowTestCase):
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
self.assertAllEqual(z, z_tf)
+ def testApproximateEqualShape(self):
+ for dtype in [np.float32, np.double]:
+ x = np.array([1, 2], dtype=dtype)
+ y = np.array([[1, 2]], dtype=dtype)
+ # The inputs 'x' and 'y' must have the same shape.
+ with self.assertRaisesRegexp(
+ ValueError, "Shapes must be equal rank, but are 1 and 2"):
+ math_ops.approximate_equal(x, y)
+
class ScalarMulTest(test_util.TensorFlowTestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAcceptsRefs(self):
if context.executing_eagerly():
var = resource_variable_ops.ResourceVariable(10, name="var")
@@ -250,14 +259,14 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
self.evaluate(init)
self.assertEqual(30, self.evaluate(result))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAcceptsConstant(self):
const = constant_op.constant(10)
result = math_ops.scalar_mul(3, const)
with test_util.device(use_gpu=True):
self.assertEqual(30, self.evaluate(result))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAcceptsTensor(self):
tensor = array_ops.ones([10, 10])
result = math_ops.scalar_mul(3, tensor)
@@ -266,7 +275,7 @@ class ScalarMulTest(test_util.TensorFlowTestCase):
with test_util.device(use_gpu=True):
self.assertAllEqual(self.evaluate(expected), self.evaluate(result))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAcceptsIndexedSlices(self):
values = constant_op.constant([2, 3, 5, 7, 0, -1], shape=[3, 2])
indices = constant_op.constant([0, 2, 5])
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 47eea6ef6b..3aedeb6acd 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -34,20 +34,54 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
def metric_variable(shape, dtype, validate_shape=True, name=None):
- """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
+ """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
+ If running in a `DistributionStrategy` context, the variable will be
+ "tower local". This means:
+
+ * The returned object will be a container with separate variables
+ per replica/tower of the model.
+
+ * When writing to the variable, e.g. using `assign_add` in a metric
+ update, the update will be applied to the variable local to the
+ replica/tower.
+
+ * To get a metric's result value, we need to sum the variable values
+ across the replicas/towers before computing the final answer.
+ Furthermore, the final answer should be computed once instead of
+ in every replica/tower. Both of these are accomplished by
+ running the computation of the final result value inside
+ `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+ Inside the `merge_call()`, ops are only added to the graph once
+ and access to a tower-local variable in a computation returns
+ the sum across all replicas/towers.
+
+ Args:
+ shape: Shape of the created variable.
+ dtype: Type of the created variable.
+ validate_shape: (Optional) Whether shape validation is enabled for
+ the created variable.
+ name: (Optional) String name of the created variable.
+
+ Returns:
+ A (non-trainable) variable initialized to zero, or if inside a
+ `DistributionStrategy` scope a tower-local variable container.
+ """
+ # Note that synchronization "ON_READ" implies trainable=False.
return variable_scope.variable(
lambda: array_ops.zeros(shape, dtype),
- trainable=False,
collections=[
ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
],
validate_shape=validate_shape,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM,
name=name)
@@ -333,11 +367,15 @@ def mean(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- mean_t = _safe_div(total, count, 'value')
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ def aggregate_across_towers(_, t, c):
+ mean_t = _safe_div(t, c, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
+ return mean_t
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
+ mean_t = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total, count)
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -572,6 +610,17 @@ def _confusion_matrix_at_thresholds(labels,
return values, update_ops
+def _aggregate_variable(v, collections):
+
+ def f(distribution, value):
+ value = distribution.read_var(value)
+ if collections:
+ ops.add_to_collections(collections, value)
+ return value
+
+ return distribute_lib.get_tower_context().merge_call(f, v)
+
+
@tf_export('metrics.auc')
def auc(labels,
predictions,
@@ -757,14 +806,18 @@ def auc(labels,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
+ def aggregate_auc(_, values):
+ auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+ values['fp'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, auc_value)
+ return auc_value
+
+ auc_value = distribute_lib.get_tower_context().merge_call(
+ aggregate_auc, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -992,15 +1045,18 @@ def mean_per_class_accuracy(labels,
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- per_class_accuracy = _safe_div(count, total, None)
-
- mean_accuracy_v = math_ops.reduce_mean(
- per_class_accuracy, name='mean_accuracy')
- update_op = _safe_div(update_count_op, update_total_op, name='update_op')
+ def aggregate_mean_accuracy(_, count, total):
+ per_class_accuracy = _safe_div(count, total, None)
+ mean_accuracy_v = math_ops.reduce_mean(
+ per_class_accuracy, name='mean_accuracy')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_accuracy_v)
+ return mean_accuracy_v
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
+ mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
+ aggregate_mean_accuracy, count, total)
+ update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1071,7 +1127,7 @@ def mean_iou(labels,
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(name):
+ def compute_mean_iou(total_cm, name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1098,10 +1154,14 @@ def mean_iou(labels,
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
return result
- mean_iou_v = compute_mean_iou('mean_iou')
+ def mean_iou_across_towers(_, v):
+ mean_iou_v = compute_mean_iou(v, 'mean_iou')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_iou_v)
+ return mean_iou_v
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
+ mean_iou_v = distribute_lib.get_tower_context().merge_call(
+ mean_iou_across_towers, total_cm)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1310,12 +1370,16 @@ def mean_tensor(values,
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- mean_t = _safe_div(total, count, 'value')
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ def aggregate_across_towers(_, t, c):
+ mean_t = _safe_div(t, c, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
+ return mean_t
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
+ mean_t = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total, count)
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1413,12 +1477,9 @@ def _count_condition(values,
weights = math_ops.to_float(weights)
values = math_ops.multiply(values, weights)
- value_tensor = array_ops.identity(count)
- update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, value_tensor)
+ value_tensor = _aggregate_variable(count, metrics_collections)
+ update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -1525,13 +1586,12 @@ def false_negatives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fn',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['fn'])
+ fn_value = _aggregate_variable(values['fn'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fn'])
- return values['fn'], update_ops['fn']
+ return fn_value, update_ops['fn']
@tf_export('metrics.false_positives')
@@ -1635,13 +1695,12 @@ def false_positives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fp',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['fp'])
+ fp_value = _aggregate_variable(values['fp'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fp'])
- return values['fp'], update_ops['fp']
+ return fp_value, update_ops['fp']
@tf_export('metrics.true_negatives')
@@ -1745,13 +1804,12 @@ def true_negatives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tn',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['tn'])
+ tn_value = _aggregate_variable(values['tn'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tn'])
- return values['tn'], update_ops['tn']
+ return tn_value, update_ops['tn']
@tf_export('metrics.true_positives')
@@ -1855,13 +1913,12 @@ def true_positives_at_thresholds(labels,
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tp',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['tp'])
+ tp_value = _aggregate_variable(values['tp'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tp'])
- return values['tp'], update_ops['tp']
+ return tp_value, update_ops['tp']
@tf_export('metrics.precision')
@@ -1945,13 +2002,17 @@ def precision(labels,
return array_ops.where(
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
- p = compute_precision(true_p, false_p, 'value')
- update_op = compute_precision(true_positives_update_op,
- false_positives_update_op, 'update_op')
+ def once_across_towers(_, true_p, false_p):
+ p = compute_precision(true_p, false_p, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, p)
+ return p
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
+ p = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, true_p, false_p)
+ update_op = compute_precision(true_positives_update_op,
+ false_positives_update_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2025,13 +2086,17 @@ def precision_at_thresholds(labels,
def compute_precision(tp, fp, name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
- prec = compute_precision(values['tp'], values['fp'], 'value')
- update_op = compute_precision(update_ops['tp'], update_ops['fp'],
- 'update_op')
+ def precision_across_towers(_, values):
+ prec = compute_precision(values['tp'], values['fp'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, prec)
+ return prec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
+ prec = distribute_lib.get_tower_context().merge_call(
+ precision_across_towers, values)
+ update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+ 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2050,7 +2115,7 @@ def recall(labels,
The `recall` function creates two local variables, `true_positives`
and `false_negatives`, that are used to compute the recall. This value is
ultimately returned as `recall`, an idempotent operation that simply divides
- `true_positives` by the sum of `true_positives` and `false_negatives`.
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
For estimation of the metric over a stream of data, the function creates an
`update_op` that updates these variables and returns the `recall`. `update_op`
@@ -2117,13 +2182,17 @@ def recall(labels,
math_ops.greater(true_p + false_n, 0),
math_ops.div(true_p, true_p + false_n), 0, name)
- rec = compute_recall(true_p, false_n, 'value')
- update_op = compute_recall(true_positives_update_op,
- false_negatives_update_op, 'update_op')
+ def once_across_towers(_, true_p, false_n):
+ rec = compute_recall(true_p, false_n, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
+ return rec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
+ rec = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, true_p, false_n)
+ update_op = compute_recall(true_positives_update_op,
+ false_negatives_update_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2552,11 +2621,17 @@ def recall_at_top_k(labels,
class_id=class_id,
weights=weights)
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+ def aggregate_across_towers(_, tp, fn):
+ metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ return metric
+
+ metric = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, tp, fn)
+
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
@@ -2627,12 +2702,16 @@ def recall_at_thresholds(labels,
def compute_recall(tp, fn, name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
- rec = compute_recall(values['tp'], values['fn'], 'value')
- update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
+ def recall_across_towers(_, values):
+ rec = compute_recall(values['tp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
+ return rec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
+ rec = distribute_lib.get_tower_context().merge_call(
+ recall_across_towers, values)
+ update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -2698,13 +2777,16 @@ def root_mean_squared_error(labels,
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
+ def once_across_towers(_, mse):
+ rmse = math_ops.sqrt(mse)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rmse)
+ return rmse
- rmse = math_ops.sqrt(mse)
- update_rmse_op = math_ops.sqrt(update_mse_op)
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
+ rmse = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, mse)
+ update_rmse_op = math_ops.sqrt(update_mse_op)
if updates_collections:
ops.add_to_collections(updates_collections, update_rmse_op)
@@ -2797,15 +2879,19 @@ def sensitivity_at_specificity(labels,
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- sensitivity = compute_sensitivity_at_specificity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ def aggregate_across_towers(_, values):
+ sensitivity = compute_sensitivity_at_specificity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, sensitivity)
+ return sensitivity
+
+ sensitivity = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, values)
+
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
@@ -3070,11 +3156,16 @@ def _streaming_sparse_average_precision_at_top_k(labels,
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- update = _safe_scalar_div(total_update, max_update, name=scope)
+ def aggregate_across_towers(_, total_var, max_var):
+ mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_average_precision)
+ return mean_average_precision
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
+ mean_average_precision = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total_var, max_var)
+
+ update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
ops.add_to_collections(updates_collections, update)
@@ -3351,11 +3442,17 @@ def precision_at_top_k(labels,
class_id=class_id,
weights=weights)
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ def aggregate_across_towers(_, tp, fp):
+ metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ return metric
+
+ metric = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, tp, fp)
+
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
@@ -3583,15 +3680,19 @@ def specificity_at_sensitivity(labels,
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- specificity = compute_specificity_at_sensitivity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ def aggregate_across_towers(_, values):
+ specificity = compute_specificity_at_sensitivity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, specificity)
+ return specificity
+
+ specificity = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, values)
+
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 783d485892..f47f38e29e 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -621,7 +621,7 @@ def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
Args:
- counts: A `Tensor` containing a the total count of the data (one value).
+ counts: A `Tensor` containing the total count of the data (one value).
mean_ss: A `Tensor` containing the mean sufficient statistics: the (possibly
shifted) sum of the elements to average over.
variance_ss: A `Tensor` containing the variance sufficient statistics: the
@@ -689,6 +689,9 @@ def moments(
# Compute true mean while keeping the dims for proper broadcasting.
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
# sample variance, not unbiased variance
+ # Note: stop_gradient does not change the gradient that gets
+ # backpropagated to the mean from the variance calculation,
+ # because that gradient is zero
variance = math_ops.reduce_mean(
math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
axes,
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 0c2f5b06c4..41d54a6c2f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2009,7 +2009,8 @@ def sparse_softmax_cross_entropy_with_logits(
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Unscaled log probabilities of shape
- `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float32` or `float64`.
+ `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
+ `float64`.
name: A name for the operation (optional).
Returns:
@@ -2166,7 +2167,7 @@ def _calc_conv_flops(graph, node):
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
filter_in_depth = int(filter_shape[2])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats(
"flops",
(output_count * filter_in_depth * filter_height * filter_width * 2))
@@ -2184,7 +2185,7 @@ def _calc_depthwise_conv_flops(graph, node):
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
@@ -2594,7 +2595,7 @@ def _calc_dilation2d_flops(graph, node):
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 035b4735af..ae24ca0552 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -76,7 +76,7 @@ class SoftmaxTest(test_lib.TestCase):
z = u.sum(1)[:, np.newaxis]
return u / z
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSoftmax(self):
x_shape = [5, 10]
x_np = np.random.randn(*x_shape).astype(np.float32)
@@ -123,7 +123,7 @@ class LogPoissonLossTest(test_lib.TestCase):
lpl += np.ma.masked_array(stirling_approx, mask=(z <= 1)).filled(0.)
return lpl
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLogPoissonLoss(self):
x_shape = [5, 10]
x_np = np.random.randn(*x_shape).astype(np.float32)
@@ -164,7 +164,7 @@ class LogSoftmaxTest(test_lib.TestCase):
u = x - m
return u - np.log(np.sum(np.exp(u), 1, keepdims=True))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLogSoftmax(self):
x_shape = [5, 10]
x_np = np.random.randn(*x_shape).astype(np.float32)
@@ -201,7 +201,7 @@ class LogSoftmaxTest(test_lib.TestCase):
class L2LossTest(test_lib.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testL2Loss(self):
for dtype in [dtypes.float32, dtypes.float64]:
x = constant_op.constant(
@@ -235,7 +235,7 @@ class L2NormalizeTest(test_lib.TestCase):
norm = np.apply_along_axis(np.linalg.norm, dim, x)
return x / np.expand_dims(norm, dim)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testL2Normalize(self):
x_shape = [20, 7, 3]
np.random.seed(1)
@@ -246,7 +246,7 @@ class L2NormalizeTest(test_lib.TestCase):
y_tf = nn_impl.l2_normalize(x_tf, dim)
self.assertAllClose(y_np, self.evaluate(y_tf))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testL2NormalizeDimArray(self):
x_shape = [20, 7, 3]
np.random.seed(1)
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
new file mode 100644
index 0000000000..065c2caedc
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -0,0 +1,129 @@
+package(
+ default_visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "parallel_for",
+ srcs = [
+ "__init__.py",
+ "control_flow_ops.py",
+ "gradients.py",
+ "pfor.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ ":gradients",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:util",
+ "@absl_py//absl/flags",
+ ],
+)
+
+py_library(
+ name = "pfor_lib",
+ srcs = ["pfor.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "@absl_py//absl/flags",
+ ],
+)
+
+py_library(
+ name = "control_flow_ops",
+ srcs = ["control_flow_ops.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":pfor_lib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+cuda_py_test(
+ name = "control_flow_ops_test",
+ srcs = ["control_flow_ops_test.py"],
+ additional_deps = [
+ ":control_flow_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:logging_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tensor_array_grad",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "gradients",
+ srcs = ["gradients.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:util",
+ ],
+)
+
+cuda_py_test(
+ name = "gradients_test",
+ size = "large",
+ srcs = ["gradients_test.py"],
+ additional_deps = [
+ ":control_flow_ops",
+ ":gradients",
+ "//third_party/py/numpy",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python/ops/losses",
+ ],
+ tags = ["no_gpu"], # TODO(b/80127739): test is flaky
+)
diff --git a/tensorflow/python/ops/parallel_for/__init__.py b/tensorflow/python/ops/parallel_for/__init__.py
new file mode 100644
index 0000000000..b49d865968
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/__init__.py
@@ -0,0 +1,35 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops for pfor, for_loop, jacobian."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops.parallel_for import * # pylint: disable=wildcard-import
+from tensorflow.python.ops.parallel_for.control_flow_ops import for_loop
+from tensorflow.python.ops.parallel_for.control_flow_ops import pfor
+from tensorflow.python.ops.parallel_for.gradients import batch_jacobian
+from tensorflow.python.ops.parallel_for.gradients import jacobian
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'pfor',
+ 'for_loop',
+ 'jacobian',
+ 'batch_jacobian',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
new file mode 100644
index 0000000000..ccf2eb8214
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""for_loop and pfor ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops.parallel_for.pfor import PFor
+from tensorflow.python.util import nest
+
+
+def for_loop(loop_fn, loop_fn_dtypes, iters):
+ """Runs `loop_fn` `iters` times and stacks the outputs.
+
+
+ Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and
+ stacks corresponding outputs of the different runs.
+
+ Args:
+ loop_fn: A function that takes an int32 scalar tf.Tensor object representing
+ the iteration number, and returns a possibly nested structure of tensor
+ objects. The shape of these outputs should not depend on the input.
+ loop_fn_dtypes: dtypes for the outputs of loop_fn.
+ iters: Number of iterations for which to run loop_fn.
+
+ Returns:
+ Returns a nested structure of stacked output tensor objects with the same
+ nested structure as the output of `loop_fn`.
+ """
+
+ flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes)
+
+ def while_body(i, *ta_list):
+ """Body of while loop."""
+ fn_output = nest.flatten(loop_fn(i))
+ if len(fn_output) != len(flat_loop_fn_dtypes):
+ raise ValueError(
+ "Number of expected outputs, %d, does not match the number of "
+ "actual outputs, %d, from loop_fn" % (len(flat_loop_fn_dtypes),
+ len(fn_output)))
+ outputs = []
+ for out, ta in zip(fn_output, ta_list):
+ # TODO(agarwal): support returning Operation objects from loop_fn.
+ assert isinstance(out, ops.Tensor)
+ outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+ return tuple([i + 1] + outputs)
+
+ ta_list = control_flow_ops.while_loop(
+ lambda i, *ta: i < iters, while_body, [0] + [
+ tensor_array_ops.TensorArray(dtype, iters)
+ for dtype in flat_loop_fn_dtypes
+ ])[1:]
+
+ # TODO(rachelim): enable this for sparse tensors
+ return nest.pack_sequence_as(loop_fn_dtypes, [ta.concat() for ta in ta_list])
+
+
+def pfor(loop_fn, iters):
+ """Equivalent to running `loop_fn` `iters` times and stacking the outputs.
+
+ `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
+ times, with input from 0 to `iters - 1`, and stacking corresponding output of
+ each iteration. However the implementation does not use a tf.while_loop.
+ Instead it adds new operations to the graph that collectively compute the same
+ value as what running `loop_fn` in a loop would compute.
+
+
+ This is an experimental feature and currently has a lot of limitations:
+ - There should be no data depenendency between the different iterations. For
+ example, a future iteration should not depend on a value or side-effect of
+ a previous iteration.
+ - Stateful kernels may mostly not be supported since these often imply a
+ data dependency or ordering of the iterations. We do support a limited set
+ of such stateful kernels though (like RandomFoo, Variable operations like
+ reads, etc).
+ - Conversion works only on a limited set of kernels for which a converter
+ has been registered.
+ - loop_fn cannot currently contain control flow operations like
+ tf.while_loop or tf.cond.
+ - `loop_fn` should return nested structure of Tensors or Operations. However
+ if an Operation is returned, it should have zero outputs.
+ - The shape and dtype of `loop_fn` outputs should not depend on the input
+ to loop_fn.
+
+ Args:
+ loop_fn: A function that takes an int32 scalar tf.Tensor object representing
+ the iteration number, and returns a possibly nested structure of Tensor or
+ Operation objects.
+ iters: Number of iterations for which to run loop_fn.
+
+ Returns:
+ Returns a nested structure of stacked tensor objects with the same nested
+ structure as the output of `loop_fn`.
+ """
+ existing_ops = set(ops.get_default_graph().get_operations())
+ with ops.name_scope("loop_body"):
+ loop_var = array_ops.placeholder(dtypes.int32, shape=[])
+ loop_fn_outputs = loop_fn(loop_var)
+ new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
+ iters = ops.convert_to_tensor(iters)
+ with ops.name_scope("pfor"):
+ converter = PFor(loop_var, iters, new_ops)
+ outputs = []
+ for loop_fn_output in nest.flatten(loop_fn_outputs):
+ outputs.append(converter.convert(loop_fn_output))
+ return nest.pack_sequence_as(loop_fn_outputs, outputs)
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
new file mode 100644
index 0000000000..c0e66cb0b8
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -0,0 +1,1404 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for pfor and for_loop."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import flags
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gradients as gradient_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import nest
+
+
+class PForTest(test.TestCase):
+
+ def _run_targets(self, targets1, targets2=None, run_init=True):
+ targets1 = nest.flatten(targets1)
+ targets2 = ([] if targets2 is None else nest.flatten(targets2))
+ assert len(targets1) == len(targets2) or not targets2
+ if run_init:
+ init = variables.global_variables_initializer()
+ self.evaluate(init)
+ return self.evaluate(targets1 + targets2)
+
+ def run_and_assert_equal(self, targets1, targets2):
+ outputs = self._run_targets(targets1, targets2)
+ outputs = nest.flatten(outputs) # flatten SparseTensorValues
+ n = len(outputs) // 2
+ for i in range(n):
+ if outputs[i + n].dtype != np.object:
+ self.assertAllClose(outputs[i + n], outputs[i], rtol=1e-4, atol=1e-5)
+ else:
+ self.assertAllEqual(outputs[i + n], outputs[i])
+
+ def _test_loop_fn(self, loop_fn, iters, loop_fn_dtypes=dtypes.float32):
+ t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters)
+ t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters)
+ self.run_and_assert_equal(t1, t2)
+
+ def test_op_conversion_fallback_to_while_loop(self):
+ # Note that we used top_k op for this test. If a converter gets defined for
+ # it, we will need to find another op for which a converter has not been
+ # defined.
+ x = random_ops.random_uniform([3, 2, 4])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ return nn.top_k(x_i)
+
+ with self.assertRaisesRegexp(ValueError, "No converter defined"):
+ self._test_loop_fn(
+ loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32])
+ flags.FLAGS.op_conversion_fallback_to_while_loop = True
+ self._test_loop_fn(
+ loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32])
+ flags.FLAGS.op_conversion_fallback_to_while_loop = False
+
+
+class ArrayTest(PForTest):
+
+ def test_gather(self):
+ x = random_ops.random_uniform([3, 3, 3])
+
+ def loop_fn(i):
+ outputs = []
+ x_i = array_ops.gather(x, i)
+ for y in [x, x_i]:
+ axes = [0, 2, -1] if y == x else [0]
+ for axis in axes:
+ outputs.append(array_ops.gather(y, 2, axis=axis))
+ outputs.append(array_ops.gather(y, i, axis=axis))
+ outputs.append(array_ops.gather(y, [i], axis=axis))
+ outputs.append(array_ops.gather(y, [i, 2], axis=axis))
+ outputs.append(array_ops.gather(y, [[2, i], [i, 1]], axis=axis))
+ return outputs
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 20)
+
+ def test_shape(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ return array_ops.shape(x_i), array_ops.shape(x_i, out_type=dtypes.int64)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64])
+
+ def test_size(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ return array_ops.size(x_i), array_ops.size(x_i, out_type=dtypes.int64)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int64])
+
+ def test_rank(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ return array_ops.rank(x_i)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_shape_n(self):
+ x = random_ops.random_uniform([3, 2, 3])
+ y = random_ops.random_uniform([3])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ y_i = array_ops.gather(y, i)
+ return array_ops.shape_n([x_i, x, y, y_i]), array_ops.shape_n(
+ [x_i, x, y, y_i], out_type=dtypes.int64)
+
+ self._test_loop_fn(
+ loop_fn, 3, loop_fn_dtypes=[dtypes.int32] * 4 + [dtypes.int64] * 4)
+
+ def test_reshape(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.reshape(x1, [-1]), array_ops.reshape(x1, [1, 3, 1, -1])
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_expand_dims(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.expand_dims(
+ x1, axis=-1), array_ops.expand_dims(
+ x1, axis=1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_slice(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.slice(x1, begin=(0, 1), size=(2, 1))
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_tile(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.tile(x1, [2, 1])
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_tile_loop_dependent(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.tile(x1, [i, 1])
+
+ with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"):
+ pfor_control_flow_ops.pfor(loop_fn, 2)
+
+ def test_pack(self):
+ x = random_ops.random_uniform([3, 2, 3])
+ y = random_ops.random_uniform([2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.stack([x1, y], axis=-1)
+
+ self._test_loop_fn(loop_fn, 1)
+
+ def test_unpack(self):
+ x = random_ops.random_uniform([3, 2, 3, 4])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ return array_ops.unstack(
+ x_i, 4, axis=-1), array_ops.unstack(
+ x_i, 3, axis=1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 7)
+
+ def test_pad(self):
+ x = random_ops.random_uniform([3, 2, 3])
+ padding = constant_op.constant([[1, 2], [3, 4]])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.pad(x1, padding, mode="CONSTANT")
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_split(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.split(x1, 2, axis=0), array_ops.split(x1, 3, axis=-1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 5)
+
+ def test_transpose(self):
+ x = random_ops.random_uniform([3, 2, 3, 4])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.transpose(x1, [2, 1, 0])
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_zeros_like(self):
+ x = random_ops.random_uniform([3, 2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ z = array_ops.zeros_like(x1),
+ return z, z + x1
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_concat_v2(self):
+ x = random_ops.random_uniform([3, 2, 3])
+ y = random_ops.random_uniform([2, 3])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return array_ops.concat(
+ [x1, x1, y], axis=0), array_ops.concat(
+ [x1, x1, y], axis=-1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_unary_cwise_ops(self):
+ for op in [array_ops.identity, array_ops.stop_gradient]:
+ x = random_ops.random_uniform([3, 5])
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y = op(x1) + x1
+ loss = nn.l2_loss(y)
+ return op(x), y, gradient_ops.gradients(loss, x1)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+
+ def test_strided_slice(self):
+ x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ y = x_i[:2, ::2, 1::3, ..., array_ops.newaxis, 1]
+ loss = nn.l2_loss(y)
+ return y, gradient_ops.gradients(loss, x_i)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+
+class MathTest(PForTest):
+
+ def test_unary_cwise_ops(self):
+ for op in [
+ math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative,
+ math_ops.square
+ ]:
+ x = random_ops.random_uniform([3, 5])
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y = op(x1)
+ loss = math_ops.reduce_sum(y * y)
+ return op(x), y, gradient_ops.gradients(loss, x1)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+
+ def test_unary_cwise_no_grad(self):
+ for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]:
+ x = random_ops.random_uniform([3, 5])
+ if op == math_ops.logical_not:
+ x = x > 0
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ return op(array_ops.gather(x, i))
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=x.dtype)
+
+ def test_binary_cwise_ops(self):
+ logical_ops = [
+ math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor
+ ]
+ bool_ops = [
+ math_ops.less, math_ops.less_equal, math_ops.greater,
+ math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+ ]
+ float_ops = [
+ math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide,
+ math_ops.maximum, math_ops.minimum
+ ]
+ for op in logical_ops + bool_ops + float_ops:
+ x = random_ops.random_uniform([7, 3, 5])
+ y = random_ops.random_uniform([3, 5])
+ if op in logical_ops:
+ x = x > 0
+ y = y > 0
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y1 = array_ops.gather(y, i)
+ return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)
+
+ # pylint: enable=cell-var-from-loop
+
+ dtype = dtypes.float32 if op in float_ops else dtypes.bool
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5)
+
+ def test_addn(self):
+ x = random_ops.random_uniform([2, 3, 5])
+ y = random_ops.random_uniform([3, 5])
+ z = random_ops.random_uniform([3, 5])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return math_ops.add_n([x1, y, z])
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_matmul(self):
+ for tr_a in (True, False):
+ for tr_b in (True, False):
+ for stack_a in (True, False):
+ for stack_b in (True, False):
+ shape_a = (5, 3) if tr_a else (3, 5)
+ if stack_a:
+ shape_a = (2,) + shape_a
+ shape_b = (7, 5) if tr_b else (5, 7)
+ if stack_b:
+ shape_b = (2,) + shape_b
+
+ x = random_ops.random_uniform(shape_a)
+ y = random_ops.random_uniform(shape_b)
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i) if stack_a else x
+ b = array_ops.gather(y, i) if stack_b else y
+ return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_batch_matmul(self):
+ for tr_a in (True, False):
+ for tr_b in (True, False):
+ for stack_a in (True, False):
+ for stack_b in (True, False):
+ shape_a = (4, 5, 3) if tr_a else (4, 3, 5)
+ if stack_a:
+ shape_a = (2,) + shape_a
+ shape_b = (4, 7, 5) if tr_b else (4, 5, 7)
+ if stack_b:
+ shape_b = (2,) + shape_b
+
+ x = random_ops.random_uniform(shape_a)
+ y = random_ops.random_uniform(shape_b)
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i) if stack_a else x
+ b = array_ops.gather(y, i) if stack_b else y
+ return math_ops.matmul(a, b, transpose_a=tr_a, transpose_b=tr_b)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_reduction(self):
+ x = random_ops.random_uniform([2, 3, 4, 5])
+ for op in [
+ math_ops.reduce_sum, math_ops.reduce_prod, math_ops.reduce_max,
+ math_ops.reduce_min
+ ]:
+ for axis in ([1], None, [0, 2]):
+ for keepdims in (True, False):
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i)
+ return op(a, axis=axis, keepdims=keepdims)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_cum_sum(self):
+ x = random_ops.random_uniform([2, 3, 4, 5])
+ for axis in (1, -2):
+ for exclusive in (True, False):
+ for reverse in (True, False):
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i)
+ return math_ops.cumsum(
+ a, axis=axis, exclusive=exclusive, reverse=reverse)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_cum_prod(self):
+ x = random_ops.random_uniform([2, 3, 4, 5])
+ for axis in (1, -2):
+ for exclusive in (True, False):
+ for reverse in (True, False):
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i)
+ return math_ops.cumprod(
+ a, axis=axis, exclusive=exclusive, reverse=reverse)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+ def test_bias_add(self):
+ x_shape = [2, 3, 4, 5, 6]
+ x = random_ops.random_uniform(x_shape)
+ for data_format in ("NCHW", "NHWC"):
+ bias_dim = 2 if data_format == "NCHW" else -1
+ bias_shape = x_shape[bias_dim]
+ bias = random_ops.random_uniform([bias_shape])
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a = array_ops.gather(x, i)
+ y = nn.bias_add(a, bias, data_format=data_format)
+ loss = math_ops.reduce_sum(y * y)
+ return y, gradient_ops.gradients(loss, bias)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(
+ loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.float32])
+
+ def test_unsorted_segment_sum(self):
+ t = random_ops.random_uniform([3, 3, 2])
+ segment_ids = constant_op.constant([[0, 0, 2], [0, 1, 2], [2, 2, 2]])
+ num_segments = 3
+
+ def loop_fn(i):
+ data = array_ops.gather(t, i)
+ data_0 = array_ops.gather(t, 0)
+ seg_ids = array_ops.gather(segment_ids, i)
+ return (math_ops.unsorted_segment_sum(data, seg_ids, num_segments),
+ math_ops.unsorted_segment_sum(data_0, seg_ids, num_segments))
+
+ self._test_loop_fn(loop_fn, 3, [dtypes.float32] * 2)
+
+ def test_cast(self):
+ x = constant_op.constant([[1], [2]])
+ y = constant_op.constant([[1.0], [2.0]])
+
+ def loop_fn(i):
+ return (math_ops.cast(array_ops.gather(x, i), dtypes.float32),
+ math_ops.cast(array_ops.gather(y, i), dtypes.int32))
+
+ self._test_loop_fn(
+ loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.int32])
+
+ def test_tanh_axpy(self):
+ a = constant_op.constant(3.)
+ x = random_ops.random_uniform([4, 5])
+ y = random_ops.random_uniform([6, 5])
+ n = x.shape[0]
+
+ def loop_fn(i):
+ return math_ops.tanh(a * array_ops.gather(x, i) + array_ops.gather(y, i))
+
+ self._test_loop_fn(loop_fn, n)
+
+ def test_select(self):
+ cond = constant_op.constant([True, False])
+ a = random_ops.random_uniform([2, 3, 5])
+ b = random_ops.random_uniform([2, 3, 5])
+ for cond_shape in [2], [2, 3], [2, 3, 5]:
+ cond = random_ops.random_uniform(cond_shape) > 0.5
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ a_i = array_ops.gather(a, i)
+ b_i = array_ops.gather(b, i)
+ cond_i = array_ops.gather(cond, i)
+ return array_ops.where(cond_i, a_i, b_i)
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 2)
+
+
+class NNTest(PForTest):
+
+ def test_conv2d(self):
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ filt = random_ops.random_uniform([3, 3, 3, 7])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return nn.conv2d(
+ x1, filt, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_conv2d_backprop_input(self):
+ x_shape = [2, 12, 12, 3]
+ filt = random_ops.random_uniform([3, 3, 3, 7])
+ grad = random_ops.random_uniform([3, 2, 5, 5, 7])
+
+ def loop_fn(i):
+ grad1 = array_ops.gather(grad, i)
+ return nn.conv2d_backprop_input(
+ x_shape,
+ filt,
+ grad1,
+ strides=[1, 2, 2, 1],
+ padding="VALID",
+ data_format="NHWC")
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_conv2d_backprop_filter(self):
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ x_0 = array_ops.gather(x, 0)
+ filter_sizes = [3, 3, 3, 7]
+ grad = random_ops.random_uniform([3, 2, 5, 5, 7])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ grad_i = array_ops.gather(grad, i)
+ return [
+ nn.conv2d_backprop_filter(
+ inp,
+ filter_sizes,
+ grad_i,
+ strides=[1, 2, 2, 1],
+ padding="VALID",
+ data_format="NHWC") for inp in [x_i, x_0]
+ ]
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_avg_pool(self):
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ ksize = [1, 3, 3, 1]
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ output = nn.avg_pool(
+ x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
+ loss = nn.l2_loss(output)
+ return output, gradient_ops.gradients(loss, x1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_max_pool(self):
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ ksize = [1, 3, 3, 1]
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ output = nn.max_pool(
+ x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
+ loss = nn.l2_loss(output)
+ return output, gradient_ops.gradients(loss, x1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+ def test_fused_batch_norm(self):
+ data_formats = ["NHWC"]
+ if test.is_gpu_available():
+ data_formats.append("NCHW")
+ for is_training in (True, False):
+ for data_format in data_formats:
+ if data_format == "NCHW":
+ x = random_ops.random_uniform([3, 1, 2, 5, 5])
+ else:
+ x = random_ops.random_uniform([3, 1, 5, 5, 2])
+ scale = random_ops.random_uniform([2])
+ offset = random_ops.random_uniform([2])
+ mean = None if is_training else random_ops.random_uniform([2])
+ variance = None if is_training else random_ops.random_uniform([2])
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ outputs = nn.fused_batch_norm(
+ x1,
+ scale,
+ offset,
+ mean=mean,
+ variance=variance,
+ epsilon=0.01,
+ data_format=data_format,
+ is_training=is_training)
+ outputs = list(outputs)
+ # We only test the first value of outputs when is_training is False.
+ # It looks like CPU and GPU have different outputs for batch_mean and
+ # batch_variance for this case.
+ if not is_training:
+ outputs[1] = constant_op.constant(0.)
+ outputs[2] = constant_op.constant(0.)
+ loss = nn.l2_loss(outputs[0])
+ gradients = gradient_ops.gradients(loss, [x1, scale, offset])
+ return outputs + gradients
+
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 6)
+
+ def test_softmax_cross_entropy_with_logits(self):
+ logits = random_ops.random_uniform([3, 2, 4])
+ labels = random_ops.random_uniform([3, 2, 4])
+ labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)
+
+ def loop_fn(i):
+ logits_i = array_ops.gather(logits, i)
+ labels_i = array_ops.gather(labels, i)
+ loss = nn.softmax_cross_entropy_with_logits(
+ labels=labels_i, logits=logits_i)
+ return loss, gradient_ops.gradients(math_ops.reduce_sum(loss), logits_i)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+
+
+class RandomTest(PForTest):
+
+ # The random values generated in the two implementations are not guaranteed to
+ # match. So we only check the returned shapes.
+ def run_and_assert_equal(self, targets1, targets2):
+ outputs = self._run_targets(targets1, targets2)
+ n = len(outputs) // 2
+ for i in range(n):
+ self.assertAllEqual(outputs[i].shape, outputs[i + n].shape)
+
+ def test_random_uniform(self):
+
+ def loop_fn(_):
+ return random_ops.random_uniform([3])
+
+ self._test_loop_fn(loop_fn, 5)
+
+ def test_random_uniform_int(self):
+
+ def loop_fn(_):
+ return random_ops.random_uniform([3], maxval=1, dtype=dtypes.int32)
+
+ self._test_loop_fn(loop_fn, 5, loop_fn_dtypes=dtypes.int32)
+
+ def test_random_standard_normal(self):
+
+ def loop_fn(_):
+ return random_ops.random_normal([3])
+
+ self._test_loop_fn(loop_fn, 5)
+
+ def test_truncated_normal(self):
+
+ def loop_fn(_):
+ return random_ops.truncated_normal([3])
+
+ self._test_loop_fn(loop_fn, 5)
+
+ def test_random_gamma(self):
+
+ def loop_fn(_):
+ return random_ops.random_gamma([3], alpha=[0.5])
+
+ self._test_loop_fn(loop_fn, 5)
+
+ def test_random_poisson_v2(self):
+
+ def loop_fn(_):
+ return random_ops.random_poisson(lam=[1.3], shape=[3])
+
+ self._test_loop_fn(loop_fn, 5)
+
+
+class LoggingTest(PForTest):
+
+ def test_print(self):
+ x = random_ops.random_uniform([3, 5])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return logging_ops.Print(
+ x1, [x1, "x1", array_ops.shape(x1)], summarize=10)
+
+ self._test_loop_fn(loop_fn, 3)
+
+ def test_assert(self):
+
+ def loop_fn(i):
+ return control_flow_ops.Assert(i < 10, [i, [10], [i + 1]])
+
+ # TODO(agarwal): make this work with for_loop.
+ with session.Session() as sess:
+ sess.run(pfor_control_flow_ops.pfor(loop_fn, 3))
+
+
+class TensorArrayTest(PForTest):
+
+ def test_create_outside_and_read(self):
+
+ ta = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)
+
+ def loop_fn(i):
+ return ta.read(i), ta.read(0)
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2)
+
+ def test_create_outside_and_gather(self):
+
+ ta = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, 0).write(1, 1)
+
+ def loop_fn(i):
+ return ta.gather([i]), ta.gather([0, 1])
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2)
+
+ def test_create_outside_and_write_and_scatter(self):
+
+ t = tensor_array_ops.TensorArray(dtypes.int32, 10, clear_after_read=False)
+ handle = t.handle
+
+ def loop_fn(i):
+ ta = t.write(i + 2, 2 * i).write(i, 5)
+ ta = ta.scatter([4 + i], [4]).scatter([6 + i, 8 + i], [6 + i, 8 + i])
+ return ta.flow
+
+ t1 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
+ out1 = tensor_array_ops.TensorArray(
+ dtypes.int32, handle=handle, flow=t1[-1]).stack()
+ output1 = self._run_targets(out1)
+
+ t2 = pfor_control_flow_ops.for_loop(loop_fn, dtypes.float32, iters=2)
+ out2 = tensor_array_ops.TensorArray(
+ dtypes.int32, handle=handle, flow=t2[-1]).stack()
+ output2 = self._run_targets(out2)
+ self.assertAllClose(output2, output1)
+
+ def test_create_inside_and_write(self):
+
+ def loop_fn(i):
+ # TODO(agarwal): switching the order of writes to ta1 does not work.
+ ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).write(0, i).write(
+ 1, 1)
+ ta2 = tensor_array_ops.TensorArray(dtypes.int32, 1).write(0, 1)
+ return ta1.stack(), ta2.stack()
+
+ self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2)
+
+ def test_create_inside_and_scatter(self):
+
+ def loop_fn(i):
+ # TODO(agarwal): switching the order of scatter to ta1 does not work.
+ ta1 = tensor_array_ops.TensorArray(dtypes.int32, 2).scatter(
+ [0], [[i, 2]]).scatter([1], [[1, 2]])
+ ta2 = tensor_array_ops.TensorArray(dtypes.int32,
+ 2).scatter([0], [3]).scatter([1], [4])
+ return ta1.stack(), ta2.stack()
+
+ self._test_loop_fn(loop_fn, 3, [dtypes.int32] * 2)
+
+ def test_create_inside_and_read(self):
+
+ def loop_fn(i):
+ ta1 = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
+ ta2 = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
+ # TODO(agarwal): ta1.read(i) currently is not supported.
+ return ta1.read(0), ta2.read(0), ta2.read(i)
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3)
+
+ def test_create_inside_and_gather(self):
+
+ def loop_fn(i):
+ ta1 = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, i).write(1, 1)
+ ta2 = tensor_array_ops.TensorArray(
+ dtypes.int32, 2, clear_after_read=False).write(0, 1).write(1, 2)
+ # TODO(agarwal): ta1.read(i) currently is not supported.
+ return ta1.gather([0, 1]), ta2.gather([0, 1]), ta2.gather([i])
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 3)
+
+ def test_grad(self):
+ x = random_ops.random_uniform([3, 2])
+ ta = tensor_array_ops.TensorArray(
+ dtypes.float32, 3, clear_after_read=False).unstack(x)
+ y = math_ops.square(ta.stack())
+
+ def loop_fn(i):
+ y_i = array_ops.gather(y, i)
+ grad = gradient_ops.gradients(y_i, x)[0]
+ return array_ops.gather(grad, i)
+
+ t1 = pfor_control_flow_ops.pfor(loop_fn, iters=3)
+ # y = x * x. Hence dy/dx = 2 * x.
+ actual_grad = 2.0 * x
+ with session.Session() as sess:
+ actual_grad, computed_grad = sess.run([t1, actual_grad])
+ self.assertAllClose(actual_grad, computed_grad)
+
+
+class StackTest(PForTest):
+
+ def test_stack_inside_loop_invariant(self):
+
+ def loop_fn(_):
+ s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
+ op1 = data_flow_ops.stack_push_v2(s, 1)
+ with ops.control_dependencies([op1]):
+ op2 = data_flow_ops.stack_push_v2(s, 2)
+ with ops.control_dependencies([op2]):
+ e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ with ops.control_dependencies([e2]):
+ e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ return e1, e2
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2)
+
+ def test_stack_inside_push_loop_dependent(self):
+
+ def loop_fn(i):
+ s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
+ op1 = data_flow_ops.stack_push_v2(s, i)
+ with ops.control_dependencies([op1]):
+ op2 = data_flow_ops.stack_push_v2(s, 2)
+ with ops.control_dependencies([op2]):
+ e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ with ops.control_dependencies([e2]):
+ e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ return e1, e2
+
+ self._test_loop_fn(loop_fn, 2, [dtypes.int32] * 2)
+
+ def test_stack_outside_pop(self):
+ s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
+ op = data_flow_ops.stack_push_v2(s, 5)
+ with ops.control_dependencies([op]):
+ op = data_flow_ops.stack_push_v2(s, 6)
+ with ops.control_dependencies([op]):
+ op = data_flow_ops.stack_push_v2(s, 7)
+
+ def loop_fn(_):
+ e1 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ with ops.control_dependencies([e1]):
+ e2 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ return e1, e2
+
+ with ops.control_dependencies([op]):
+ e1, e2 = pfor_control_flow_ops.pfor(loop_fn, iters=2)
+ with ops.control_dependencies([e1, e2]):
+ e3 = data_flow_ops.stack_pop_v2(s, elem_type=dtypes.int32)
+ v1, v2, v3 = self._run_targets([e1, e2, e3], run_init=False)
+ self.assertAllEqual([7, 7], v1)
+ self.assertAllEqual([6, 6], v2)
+ self.assertAllEqual(5, v3)
+
+ def test_stack_outside_push(self):
+ s = data_flow_ops.stack_v2(max_size=4, elem_type=dtypes.int32)
+
+ def loop_fn(_):
+ return data_flow_ops.stack_push_v2(s, 7)
+
+ with self.assertRaisesRegexp(ValueError, "StackPushV2 not allowed.*"):
+ pfor_control_flow_ops.pfor(loop_fn, iters=2)
+
+
+# TODO(agarwal): test nested while_loops. This currently requires converting a
+# tf.cond.
+class ControlFlowTest(PForTest):
+
+ def test_while_outside_loop(self):
+
+ x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
+
+ def loop_fn(i):
+ return x + i
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_invariant_while(self):
+
+ def loop_fn(_):
+ return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_invariant_while_with_control_dependency(self):
+
+ def loop_fn(i):
+ with ops.control_dependencies([i]):
+ return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
+ [0])
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_while_with_stateful_ops(self):
+
+ def loop_fn(_):
+ return control_flow_ops.while_loop(
+ lambda j, x: j < 4,
+ lambda j, x: (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_while_unstacked_condition(self):
+
+ def loop_fn(i):
+ return control_flow_ops.while_loop(lambda j, x: j < 4,
+ lambda j, x: (j + 1, x + i), [0, 0])
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32, dtypes.int32])
+
+ def test_while(self):
+ x = random_ops.random_uniform([3, 5])
+ lengths = constant_op.constant([4, 0, 2])
+
+ def loop_fn(i):
+ x_i = array_ops.gather(x, i)
+ lengths_i = array_ops.gather(lengths, i)
+
+ _, total = control_flow_ops.while_loop(
+ lambda j, _: j < lengths_i,
+ lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
+ return total
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32])
+
+ def test_while_jacobian(self):
+ x = random_ops.random_uniform([1, 3])
+ y = random_ops.random_uniform([3, 3])
+
+ # out = x @ y @ y @ y @ y, where @ is matmul operator.
+ _, out = control_flow_ops.while_loop(
+ lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
+ [0, x])
+
+ def loop_fn(i):
+ out_i = array_ops.gather(out, i, axis=1)
+ return array_ops.reshape(gradient_ops.gradients(out_i, x)[0], [-1])
+
+ out = pfor_control_flow_ops.pfor(loop_fn, iters=3)
+
+ # The above code does not work with tf.while_loop instead of pfor. So we
+ # manually compute the expected output here.
+ # Note that gradient of output w.r.t is (y @ y @ y @ y)^T.
+ expected_output = y
+ for _ in range(3):
+ expected_output = math_ops.matmul(expected_output, y)
+ expected_output = array_ops.transpose(expected_output, [1, 0])
+
+ with session.Session() as sess:
+ out, expected = sess.run([out, expected_output])
+ self.assertAllClose(expected, out)
+
+ def test_tensor_array_as_loop_variable(self):
+
+ def loop_fn(i):
+
+ def body(j, ta):
+ ta = ta.write(j, i + j * j)
+ return j + 1, ta
+
+ _, ta = control_flow_ops.while_loop(
+ lambda j, _: j < 4, body,
+ (0, tensor_array_ops.TensorArray(dtypes.int32, size=4)))
+ return ta.stack()
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_read_tensor_array_partitioned_indices(self):
+ # Note that tensor array values are pfor loop dependent, and the while loop
+ # termination condition is also dependent on pfor iteration.
+ def loop_fn(i):
+ ta = tensor_array_ops.TensorArray(dtypes.int32, size=6)
+ ta = ta.unstack(i + list(range(5)))
+
+ def body(j, s):
+ return j + 1, s + ta.read(j)
+
+ _, s = control_flow_ops.while_loop(lambda j, _: j < i,
+ body,
+ (0, 0))
+ return s
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_external_while_loop_grad(self):
+ # Here we test that external while_loops that are extended from inside pfor
+ # (due to gradient calls) are not actually converted. If the below was
+ # converted all pfor iterations would write to the same tensor array
+ # indices.
+ x = constant_op.constant(1.)
+
+ def body(j, ta):
+ ta = ta.write(j, x)
+ return j + 1, ta
+
+ _, ta = control_flow_ops.while_loop(
+ lambda j, _: j < 4, body,
+ (0, tensor_array_ops.TensorArray(dtypes.float32, size=4)))
+ out = ta.stack()
+
+ def loop_fn(i):
+ out_i = array_ops.gather(out, i)
+ return gradient_ops.gradients(out_i, x)[0]
+
+ with session.Session() as sess:
+ # out is [x, x, x]. Hence the gradients should be [1, 1, 1].
+ self.assertAllEqual([1, 1, 1],
+ sess.run(pfor_control_flow_ops.pfor(loop_fn, 3)))
+
+ def test_tensor_array_grad(self):
+ inp = constant_op.constant(np.random.rand(3, 4, 2), dtype=dtypes.float32)
+ ta = tensor_array_ops.TensorArray(dtypes.float32, size=3)
+ ta = ta.unstack(inp)
+
+ def loop_fn(i):
+
+ def body(j, x):
+ value = ta.gather([j])
+ value = array_ops.gather(array_ops.reshape(value, [4, 2]), i)
+ return j + 1, x + value
+
+ _, out = control_flow_ops.while_loop(lambda j, _: j < 3, body,
+ (0, array_ops.zeros([2])))
+ out = math_ops.reduce_prod(out)
+ return out, gradient_ops.gradients(out, inp)[0]
+
+ pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
+ # Note that tf.while_loop does not work in the setup above. So we manually
+ # construct the equivalent computation of the above loops here.
+ real_out = math_ops.reduce_sum(inp, reduction_indices=[0])
+ real_out = math_ops.reduce_prod(real_out, reduction_indices=[1])
+ # Note that gradients of real_out will accumulate the gradients across the
+ # output value. Hence we do the same aggregation on pfor_out_grad.
+ real_out_grad = gradient_ops.gradients(real_out, inp)[0]
+ sum_pfor_out_grad = math_ops.reduce_sum(
+ pfor_out_grad, reduction_indices=[0])
+
+ with session.Session() as sess:
+ v1, v2, v1_grad, v2_grad = sess.run(
+ [pfor_out, real_out, sum_pfor_out_grad, real_out_grad])
+ self.assertAllClose(v1, v2)
+ self.assertAllClose(v1_grad, v2_grad)
+
+
+def dynamic_lstm_input_fn(batch_size, state_size, max_steps):
+ # We make inputs and sequence_length constant so that multiple session.run
+ # calls produce the same result.
+ inputs = constant_op.constant(
+ np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
+ sequence_length = np.random.randint(0, size=[batch_size], high=max_steps + 1)
+ sequence_length = constant_op.constant(sequence_length, dtype=dtypes.int32)
+ return inputs, sequence_length
+
+
+def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
+ cell = cell_fn(state_size)
+ inputs, sequence_length = dynamic_lstm_input_fn(batch_size,
+ state_size,
+ max_steps)
+ inputs_ta = tensor_array_ops.TensorArray(
+ dtypes.float32, size=max_steps, element_shape=[batch_size, state_size])
+ inputs_time_major = array_ops.transpose(inputs, [1, 0, 2])
+ inputs_ta = inputs_ta.unstack(inputs_time_major)
+ zeros = array_ops.zeros([state_size])
+
+ def loop_fn(i):
+ sequence_length_i = array_ops.gather(sequence_length, i)
+
+ def body_fn(t, state, ta):
+ inputs_t = array_ops.expand_dims(
+ array_ops.gather(inputs_ta.read(t), i), 0)
+ output, new_state = cell(inputs_t, state)
+ output = array_ops.reshape(output, [-1])
+ # TODO(agarwal): one optimization that dynamic_rnn uses is to avoid the
+ # array_ops.where when t < min(sequence_length). Doing that requires
+ # supporting tf.cond pfor conversion.
+ done = t >= sequence_length_i
+ output = array_ops.where(done, zeros, output)
+ ta = ta.write(t, output)
+ new_state = [array_ops.where(done, s, ns) for s, ns in
+ zip(nest.flatten(state), nest.flatten(new_state))]
+ new_state = nest.pack_sequence_as(state, new_state)
+ return t + 1, new_state, ta
+
+ def condition_fn(t, _, unused):
+ del unused
+ return t < max_steps
+
+ initial_state = cell.zero_state(1, dtypes.float32)
+ _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
+ 0, initial_state,
+ tensor_array_ops.TensorArray(dtypes.float32, max_steps)
+ ])
+
+ new_state = [array_ops.reshape(x, [-1]) for x in nest.flatten(state)]
+ new_state = nest.pack_sequence_as(initial_state, new_state)
+ return ta.stack(), new_state
+
+ pfor_output = pfor_control_flow_ops.pfor(loop_fn, batch_size)
+ tf_output = rnn.dynamic_rnn(
+ cell,
+ inputs,
+ sequence_length=sequence_length,
+ initial_state=cell.zero_state(batch_size, dtypes.float32))
+ return pfor_output, tf_output
+
+
+class RNNTest(PForTest):
+
+ def test_dynamic_rnn(self):
+ pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell,
+ 3, 5, 7)
+ self.run_and_assert_equal(pfor_outputs, tf_outputs)
+
+ def test_dynamic_lstm(self):
+ pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell,
+ 3, 5, 7)
+ self.run_and_assert_equal(pfor_outputs, tf_outputs)
+
+
+# TODO(agarwal): benchmark numbers on GPU for graphs based on while_loop
+# conversion don't look good. Some of it seems like lot of copies between host
+# and device. Optimize that.
+class Benchmarks(test.Benchmark):
+
+ def _run(self, targets, iters, name=None):
+
+ def _done(t):
+ # Note that we don't use tf.control_dependencies since that will not make
+ # sure that the computation on GPU has actually finished. So we fetch the
+ # first element of the output, and assume that this will not be called on
+ # empty tensors.
+ return array_ops.gather(array_ops.reshape(t, [-1]), 0)
+
+ targets = [_done(x) for x in nest.flatten(targets)]
+ sess = session.Session()
+ with sess:
+ init = variables.global_variables_initializer()
+ sess.run(init)
+ sess.run(targets)
+ begin = time.time()
+ for _ in range(iters):
+ sess.run(targets)
+ end = time.time()
+ avg_time_ms = 1000 * (end - begin) / iters
+ self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
+ return avg_time_ms
+
+ def benchmark_basic_while(self):
+ with ops.Graph().as_default():
+
+ def loop_fn(i):
+ _, s = control_flow_ops.while_loop(
+ lambda t, x: t < i,
+ lambda t, x: (t + 1, x + i),
+ [0, 0])
+ return s
+
+ iters = 50
+ pfor_output = pfor_control_flow_ops.pfor(loop_fn, iters)
+ for_loop_output = pfor_control_flow_ops.for_loop(loop_fn, dtypes.int32,
+ iters)
+ self._run(pfor_output, 100, name="pfor_basic")
+ self._run(for_loop_output, 100, name="for_loop_basic")
+
+ def benchmark_dynamic_rnn(self):
+ with ops.Graph().as_default():
+ pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicRNNCell,
+ 128, 512, 16)
+ self._run(pfor_outputs, 100, name="pfor_rnn")
+ self._run(tf_outputs, 100, name="tf_rnn")
+
+ def benchmark_dynamic_lstm(self):
+ with ops.Graph().as_default():
+ pfor_outputs, tf_outputs = create_dynamic_lstm(rnn_cell.BasicLSTMCell,
+ 128, 512, 16)
+ self._run(pfor_outputs, 100, name="pfor_lstm")
+ self._run(tf_outputs, 100, name="tf_lstm")
+
+
+class SparseTest(PForTest):
+
+ def test_var_loop_len(self):
+ num_iters = array_ops.placeholder(dtypes.int32)
+
+ def loop_fn(_):
+ return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
+ [3]) # [0, 2, 0]
+
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ with self.test_session() as sess:
+ sess.run(pfor, feed_dict={num_iters: 3})
+
+ def test_sparse_result_none_stacked(self):
+ num_iters = 10
+
+ def loop_fn(_):
+ return sparse_tensor.SparseTensor([[0], [1], [2]], [4, 5, 6],
+ [3]) # [0, 2, 0]
+
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+
+ indices = [[i, j] for i in range(num_iters) for j in range(3)]
+ values = [4, 5, 6] * num_iters
+ dense_shapes = [num_iters, 3]
+ # Expected result: [[4, 5, 6], [4, 5, 6], [4, 5, 6], ...]
+ manual = sparse_tensor.SparseTensor(indices, values, dense_shapes)
+ self.run_and_assert_equal(pfor, manual)
+
+ def test_sparse_result_all_stacked(self):
+ num_iters = 10
+
+ def loop_fn(i):
+ i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
+ indices = array_ops.expand_dims(i, 0)
+ return sparse_tensor.SparseTensor(indices, i, i + 1) # [0, ..., 0, i]
+
+ # Expected result: [[0], [0, 1], [0, 0, 2], [0, 0, 0, 3], ...]
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
+ list(range(num_iters)),
+ (num_iters, num_iters))
+ self.run_and_assert_equal(pfor, manual)
+
+ def test_sparse_result_indices_stacked(self):
+ num_iters = 10
+
+ def loop_fn(i):
+ i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
+ indices = array_ops.expand_dims(i, 0)
+ return sparse_tensor.SparseTensor(indices, [1], [num_iters])
+
+ # Expected result: identity matrix size num_iters * num_iters
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ manual = sparse_tensor.SparseTensor([[i, i] for i in range(num_iters)],
+ [1] * num_iters, (num_iters, num_iters))
+ self.run_and_assert_equal(pfor, manual)
+
+ def test_sparse_result_values_stacked(self):
+ num_iters = 10
+
+ def loop_fn(i):
+ i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
+ return sparse_tensor.SparseTensor([[0]], i, [num_iters]) # [i, 0, ..., 0]
+
+ # Expected result: [[1, 0, ...], [2, 0, ...], [3, 0, ...], ...]
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
+ list(range(num_iters)),
+ (num_iters, num_iters))
+ self.run_and_assert_equal(pfor, manual)
+
+ def test_sparse_result_shapes_stacked(self):
+ num_iters = 10
+
+ def loop_fn(i):
+ i = array_ops.expand_dims(math_ops.cast(i, dtypes.int64), 0)
+ return sparse_tensor.SparseTensor([[0]], [1], i + 1) # [1, 0, ..., 0]
+
+ # Expected result: [[1, 0, 0, ...], [1, 0, 0, ...], ...]
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ manual = sparse_tensor.SparseTensor([[i, 0] for i in range(num_iters)],
+ [1] * num_iters, (num_iters, num_iters))
+ self.run_and_assert_equal(pfor, manual)
+
+ def test_sparse_result_shapes_stacked_2D(self):
+ num_iters = 10
+
+ def loop_fn(i):
+ i = array_ops.expand_dims(math_ops.cast(i + 1, dtypes.int64), 0)
+ shape = array_ops.concat([i, i], 0)
+ return sparse_tensor.SparseTensor([[0, 0]], [1], shape) # [1, 0, ..., 0]
+
+ # Expected result: [[[1, 0, ...], [0, ..., 0], [0, ..., 0], ...], ...]
+ pfor = pfor_control_flow_ops.pfor(loop_fn, num_iters)
+ manual = sparse_tensor.SparseTensor([[i, 0, 0] for i in range(num_iters)],
+ [1] * num_iters,
+ (num_iters, num_iters, num_iters))
+ self.run_and_assert_equal(pfor, manual)
+
+
+class ParsingTest(PForTest):
+
+ def test_decode_csv(self):
+ csv_tensor = constant_op.constant([["1:2:3"], ["::"], ["7:8:9"]])
+ kwargs = {"record_defaults": [[10], [20], [30]], "field_delim": ":"}
+
+ def loop_fn(i):
+ line = array_ops.gather(csv_tensor, i)
+ return parsing_ops.decode_csv(line, **kwargs)
+
+ self._test_loop_fn(loop_fn, iters=3, loop_fn_dtypes=[dtypes.int32] * 3)
+
+ def test_parse_single_example(self):
+
+ def _int64_feature(*values):
+ return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
+
+ def _bytes_feature(*values):
+ return feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[v.encode("utf-8") for v in values]))
+
+ examples = constant_op.constant([
+ example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "dense_int": _int64_feature(i),
+ "dense_str": _bytes_feature(str(i)),
+ "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
+ "sparse_str": _bytes_feature(*["abc"] * i)
+ })).SerializeToString() for i in range(10)
+ ])
+
+ features = {
+ "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
+ "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
+ "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
+ "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
+ }
+
+ def loop_fn(i):
+ example_proto = array_ops.gather(examples, i)
+ f = parsing_ops.parse_single_example(example_proto, features)
+ return f
+
+ pfor = pfor_control_flow_ops.pfor(loop_fn, iters=10)
+ manual = parsing_ops.parse_example(examples, features)
+ self.run_and_assert_equal(pfor, manual)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
new file mode 100644
index 0000000000..ee3d5c9b86
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -0,0 +1,126 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Jacobian ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gradients as gradient_ops
+from tensorflow.python.ops.parallel_for import control_flow_ops
+from tensorflow.python.util import nest
+
+
+def jacobian(output, inputs, use_pfor=True):
+ """Computes jacobian of `output` w.r.t. `inputs`.
+
+ Args:
+ output: A tensor.
+ inputs: A tensor or a nested structure of tensor objects.
+ use_pfor: If true, uses pfor for computing the jacobian. Else uses
+ tf.while_loop.
+
+ Returns:
+ A tensor or a nested strucutre of tensors with the same structure as
+ `inputs`. Each entry is the jacobian of `output` w.rt. to the corresponding
+ value in `inputs`. If output has shape [y_1, ..., y_n] and inputs_i has
+ shape [x_1, ..., x_m], the corresponding jacobian has shape
+ [y_1, ..., y_n, x_1, ..., x_m].
+ """
+ flat_inputs = nest.flatten(inputs)
+ output_shape = array_ops.shape(output)
+ output = array_ops.reshape(output, [-1])
+
+ def loop_fn(i):
+ y = array_ops.gather(output, i)
+ return gradient_ops.gradients(y, flat_inputs)
+
+ try:
+ output_size = int(output.shape[0])
+ except TypeError:
+ output_size = array_ops.shape(output)[0]
+
+ if use_pfor:
+ pfor_outputs = control_flow_ops.pfor(loop_fn, output_size)
+ else:
+ pfor_outputs = control_flow_ops.for_loop(
+ loop_fn, [output.dtype] * len(flat_inputs), output_size)
+
+ for i, out in enumerate(pfor_outputs):
+ new_shape = array_ops.concat(
+ [output_shape, array_ops.shape(out)[1:]], axis=0)
+ out = array_ops.reshape(out, new_shape)
+ pfor_outputs[i] = out
+
+ return nest.pack_sequence_as(inputs, pfor_outputs)
+
+
+def batch_jacobian(output, inp, use_pfor=True):
+ """Computes and stacks jacobians of `output[i,...]` w.r.t. `input[i,...]`.
+
+ e.g.
+ x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+ y = x * x
+ jacobian = batch_jacobian(y, x)
+ # => [[[2, 0], [0, 4]], [[6, 0], [0, 8]]]
+
+ Args:
+ output: A tensor with shape [b, y1, ..., y_n]. `output[i,...]` should
+ only depend on `inp[i,...]`.
+ inp: A tensor with shape [b, x1, ..., x_m]
+ use_pfor: If true, uses pfor for computing the Jacobian. Else uses a
+ tf.while_loop.
+
+ Returns:
+ A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
+ is the jacobian of `output[i, ...]` w.r.t. `inp[i, ...]`, i.e. stacked
+ per-example jacobians.
+
+ Raises:
+ ValueError: if first dimension of `output` and `inp` do not match.
+ """
+ output_shape = output.shape
+ if not output_shape[0].is_compatible_with(inp.shape[0]):
+ raise ValueError("Need first dimension of output shape (%s) and inp shape "
+ "(%s) to match." % (output.shape, inp.shape))
+ if output_shape.is_fully_defined():
+ batch_size = int(output_shape[0])
+ output_row_size = output_shape.num_elements() // batch_size
+ else:
+ output_shape = array_ops.shape(output)
+ batch_size = output_shape[0]
+ output_row_size = array_ops.size(output) // batch_size
+ inp_shape = array_ops.shape(inp)
+ # Flatten output to 2-D.
+ with ops.control_dependencies(
+ [check_ops.assert_equal(batch_size, inp_shape[0])]):
+ output = array_ops.reshape(output, [batch_size, output_row_size])
+
+ def loop_fn(i):
+ y = array_ops.gather(output, i, axis=1)
+ return gradient_ops.gradients(y, inp)[0]
+
+ if use_pfor:
+ pfor_output = control_flow_ops.pfor(loop_fn, output_row_size)
+ else:
+ pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype,
+ output_row_size)
+ pfor_output = array_ops.reshape(pfor_output,
+ [output_row_size, batch_size, -1])
+ output = array_ops.transpose(pfor_output, [1, 0, 2])
+ new_shape = array_ops.concat([output_shape, inp_shape[1:]], axis=0)
+ return array_ops.reshape(output, new_shape)
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
new file mode 100644
index 0000000000..310a2154f7
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -0,0 +1,568 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for jacobian and batch_jacobian ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import training as keras_training
+from tensorflow.python.layers import layers as tf_layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients as gradient_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.parallel_for import control_flow_ops
+from tensorflow.python.ops.parallel_for import gradients
+from tensorflow.python.platform import test
+from tensorflow.python.util import nest
+
+
+class FullyConnectedModel(object):
+
+ def __init__(self, activation_size, num_layers):
+ self._layers = [
+ tf_layers.Dense(activation_size, activation=nn.relu)
+ for _ in range(num_layers)
+ ]
+
+ def __call__(self, inp):
+ activation = inp
+ for layer in self._layers:
+ activation = layer(activation)
+ return activation
+
+
+def fully_connected_model_fn(batch_size, activation_size, num_layers):
+ model = FullyConnectedModel(activation_size, num_layers)
+ inp = random_ops.random_normal([batch_size, activation_size])
+ return inp, model(inp)
+
+
+def lstm_model_fn(batch_size, state_size, steps):
+ inputs = [
+ random_ops.random_normal([batch_size, state_size]) for _ in range(steps)
+ ]
+ cell = rnn_cell.BasicLSTMCell(state_size)
+ init_state = cell.zero_state(batch_size, dtypes.float32)
+ state = init_state
+ for inp in inputs:
+ _, state = cell(inp, state)
+ return init_state.c, state.c
+
+
+def dynamic_lstm_model_fn(batch_size, state_size, max_steps):
+ # We make inputs and sequence_length constant so that multiple session.run
+ # calls produce the same result.
+ inputs = constant_op.constant(
+ np.random.rand(batch_size, max_steps, state_size), dtype=dtypes.float32)
+ sequence_length = constant_op.constant(
+ np.random.randint(0, size=[batch_size], high=max_steps + 1),
+ dtype=dtypes.int32)
+
+ cell = rnn_cell.BasicLSTMCell(state_size)
+ initial_state = cell.zero_state(batch_size, dtypes.float32)
+ return inputs, rnn.dynamic_rnn(
+ cell,
+ inputs,
+ sequence_length=sequence_length,
+ initial_state=initial_state)
+
+
+def create_fc_batch_jacobian(batch_size, activation_size, num_layers):
+ inp, output = fully_connected_model_fn(batch_size, activation_size,
+ num_layers)
+ pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True)
+ while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False)
+ return pfor_jacobian, while_jacobian
+
+
+def create_lstm_batch_jacobian(batch_size, state_size, steps):
+ inp, output = lstm_model_fn(batch_size, state_size, steps)
+ pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True)
+ while_jacobian = gradients.batch_jacobian(output, inp, use_pfor=False)
+ return pfor_jacobian, while_jacobian
+
+
+def create_dynamic_lstm_batch_jacobian(batch_size, state_size, max_steps):
+ inp, (_, final_state) = dynamic_lstm_model_fn(batch_size, state_size,
+ max_steps)
+ pfor_jacobian = gradients.batch_jacobian(final_state.c, inp, use_pfor=True)
+ # Note that use_pfor=False does not work above given the current limitations
+ # on implementation of while_loop. So we statically unroll the looping in the
+ # jacobian computation.
+ while_gradients = [
+ gradient_ops.gradients(array_ops.gather(final_state.c, i, axis=1), inp)[0]
+ for i in range(state_size)
+ ]
+ return pfor_jacobian, while_gradients
+
+
+def create_lstm_batch_hessian(batch_size, state_size, steps):
+ inp, output = lstm_model_fn(batch_size, state_size, steps)
+ pfor_jacobian = gradients.batch_jacobian(output, inp, use_pfor=True)
+ pfor_jacobian = array_ops.reshape(pfor_jacobian, [batch_size, -1])
+ pfor_hessian = gradients.batch_jacobian(pfor_jacobian, inp, use_pfor=True)
+ # TODO(agarwal): using two nested while_loop doesn't seem to work here.
+ # Hence we use pfor_jacobian for computing while_hessian.
+ while_jacobian = pfor_jacobian
+ while_hessian = gradients.batch_jacobian(while_jacobian, inp, use_pfor=False)
+ return pfor_hessian, while_hessian
+
+
+def create_lstm_hessian(batch_size, state_size, steps):
+ _, output = lstm_model_fn(batch_size, state_size, steps)
+ weights = variables.trainable_variables()
+ pfor_jacobians = gradients.jacobian(output, weights, use_pfor=True)
+ pfor_hessians = [
+ gradients.jacobian(x, weights, use_pfor=True) for x in pfor_jacobians
+ ]
+ # TODO(agarwal): using two nested while_loop doesn't seem to work here.
+ # Hence we use pfor_jacobians for computing while_hessians.
+ while_jacobians = pfor_jacobians
+ while_hessians = [
+ gradients.jacobian(x, weights, use_pfor=False) for x in while_jacobians
+ ]
+ return pfor_hessians, while_hessians
+
+
+def create_fc_per_eg_grad(batch_size, activation_size, num_layers):
+ inp = random_ops.random_normal([batch_size, activation_size])
+ layers = [
+ tf_layers.Dense(activation_size, activation=nn.relu)
+ for _ in range(num_layers)
+ ]
+ projection = tf_layers.Dense(1)
+
+ def model_fn(activation):
+ for layer in layers:
+ activation = layer(activation)
+ activation = projection(activation)
+ activation = nn.l2_loss(activation)
+ return gradient_ops.gradients(activation, variables.trainable_variables())
+
+ def loop_fn(i):
+ return model_fn(array_ops.expand_dims(array_ops.gather(inp, i), 0))
+
+ pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size)
+ loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()]
+ while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size)
+ return pfor_outputs, while_outputs
+
+
+def create_lstm_per_eg_grad(batch_size, state_size, steps):
+ inputs = [
+ random_ops.random_normal([batch_size, state_size]) for _ in range(steps)
+ ]
+ cell = rnn_cell.BasicLSTMCell(state_size)
+ init_state = cell.zero_state(batch_size, dtypes.float32)
+
+ def model_fn(inps, init_state):
+ state = init_state
+ for inp in inps:
+ _, state = cell(inp, state)
+ output = nn.l2_loss(state.c)
+ return gradient_ops.gradients(output, variables.trainable_variables())
+
+ def loop_fn(i):
+ loop_inputs = [
+ array_ops.expand_dims(array_ops.gather(x, i), 0) for x in inputs
+ ]
+ loop_init_state = rnn_cell.LSTMStateTuple(
+ *[array_ops.expand_dims(array_ops.gather(x, i), 0) for x in init_state])
+ return model_fn(loop_inputs, loop_init_state)
+
+ pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size)
+ loop_fn_dtypes = [x.dtype for x in variables.trainable_variables()]
+ while_outputs = control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, batch_size)
+ return pfor_outputs, while_outputs
+
+
+# Importing the code from tensorflow_models seems to cause errors. Hence we
+# duplicate the model definition here.
+# TODO(agarwal): Use the version in tensorflow_models/official instead.
+class Mnist(keras_training.Model):
+
+ def __init__(self, data_format):
+ """Creates a model for classifying a hand-written digit.
+
+ Args:
+ data_format: Either 'channels_first' or 'channels_last'.
+ """
+ super(Mnist, self).__init__()
+ if data_format == "channels_first":
+ self._input_shape = [-1, 1, 28, 28]
+ else:
+ assert data_format == "channels_last"
+ self._input_shape = [-1, 28, 28, 1]
+
+ self.conv1 = tf_layers.Conv2D(
+ 32, 5, padding="same", data_format=data_format, activation=nn.relu)
+ self.conv2 = tf_layers.Conv2D(
+ 64, 5, padding="same", data_format=data_format, activation=nn.relu)
+ self.fc1 = tf_layers.Dense(1024, activation=nn.relu)
+ self.fc2 = tf_layers.Dense(10)
+ self.dropout = tf_layers.Dropout(0.4)
+ self.max_pool2d = tf_layers.MaxPooling2D(
+ (2, 2), (2, 2), padding="same", data_format=data_format)
+
+ def __call__(self, inputs, training):
+ """Add operations to classify a batch of input images.
+
+ Args:
+ inputs: A Tensor representing a batch of input images.
+ training: A boolean. Set to True to add operations required only when
+ training the classifier.
+
+ Returns:
+ A logits Tensor with shape [<batch_size>, 10].
+ """
+ y = array_ops.reshape(inputs, self._input_shape)
+ y = self.conv1(y)
+ y = self.max_pool2d(y)
+ y = self.conv2(y)
+ y = self.max_pool2d(y)
+ y = tf_layers.flatten(y)
+ y = self.fc1(y)
+ y = self.dropout(y, training=training)
+ return self.fc2(y)
+
+
+def create_mnist_per_eg_grad(batch_size, data_format, training):
+ images = random_ops.random_uniform([batch_size, 28, 28])
+ sparse_labels = np.random.randint(
+ low=0, high=10, size=[batch_size]).astype(np.int32)
+ labels = np.zeros((batch_size, 10)).astype(np.float32)
+ labels[np.arange(batch_size), sparse_labels] = 1.
+ model = Mnist(data_format)
+
+ def loop_fn(i):
+ image = array_ops.gather(images, i)
+ label = array_ops.gather(labels, i)
+ logits = array_ops.reshape(model(image, training=training), [-1])
+ loss = losses.softmax_cross_entropy(
+ logits=logits, onehot_labels=label, reduction=losses.Reduction.NONE)
+ return gradient_ops.gradients(loss, variables.trainable_variables())
+
+ pfor_outputs = control_flow_ops.pfor(loop_fn, batch_size)
+ while_outputs = control_flow_ops.for_loop(
+ loop_fn, [dtypes.float32] * len(variables.trainable_variables()),
+ batch_size)
+ return pfor_outputs, while_outputs
+
+
+def create_mnist_per_eg_jacobian(batch_size, data_format, training):
+ images = random_ops.random_uniform([batch_size, 28, 28])
+ model = Mnist(data_format)
+
+ def loop_fn(i, use_pfor):
+ image = array_ops.gather(images, i)
+ logits = array_ops.reshape(model(image, training=training), [-1])
+ return gradients.jacobian(
+ logits, variables.trainable_variables(), use_pfor=use_pfor)
+
+ pfor_outputs = control_flow_ops.pfor(
+ functools.partial(loop_fn, use_pfor=True),
+ batch_size)
+ while_outputs = control_flow_ops.for_loop(
+ functools.partial(loop_fn, use_pfor=False),
+ [dtypes.float32] * len(variables.trainable_variables()), batch_size)
+ return pfor_outputs, while_outputs
+
+
+def create_fc_per_eg_jacobians(batch_size, activation_size, num_layers):
+ model = FullyConnectedModel(activation_size=activation_size,
+ num_layers=num_layers)
+ inp = random_ops.random_normal([batch_size, activation_size])
+ output = model(inp)
+ jacobians = gradients.jacobian(output, variables.trainable_variables())
+
+ def loop_fn(i, use_pfor):
+ inp_i = array_ops.expand_dims(array_ops.gather(inp, i), 0)
+ output = array_ops.reshape(model(inp_i), [-1])
+ return gradients.jacobian(
+ output, variables.trainable_variables(), use_pfor=use_pfor)
+
+ per_eg_jacobians_pfor = control_flow_ops.pfor(
+ functools.partial(loop_fn, use_pfor=True),
+ batch_size)
+ per_eg_jacobians_while = control_flow_ops.for_loop(
+ functools.partial(loop_fn, use_pfor=False),
+ [dtypes.float32] * len(variables.trainable_variables()), batch_size)
+ return jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while
+
+
+class GradientsTest(test.TestCase):
+
+ def run_and_assert_equal(self, targets1, targets2, atol=1e-4, rtol=1e-4):
+ targets1 = nest.flatten(targets1)
+ targets2 = nest.flatten(targets2)
+ assert len(targets1) == len(targets2)
+ init = variables.global_variables_initializer()
+ self.evaluate(init)
+ outputs = self.evaluate(targets1 + targets2)
+ n = len(outputs) // 2
+ for i in range(n):
+ self.assertAllClose(outputs[i], outputs[i + n], rtol=rtol, atol=atol)
+
+ def test_jacobian_fixed_shape(self):
+ x = random_ops.random_uniform([2, 2])
+ y = math_ops.matmul(x, x, transpose_a=True)
+ jacobian_pfor = gradients.jacobian(y, x, use_pfor=True)
+ jacobian_while = gradients.jacobian(y, x, use_pfor=False)
+ answer = ops.convert_to_tensor([[
+ gradient_ops.gradients(y[0][0], x)[0],
+ gradient_ops.gradients(y[0][1], x)[0]
+ ], [
+ gradient_ops.gradients(y[1][0], x)[0],
+ gradient_ops.gradients(y[1][1], x)[0]
+ ]])
+ self.run_and_assert_equal(answer, jacobian_pfor)
+ self.run_and_assert_equal(answer, jacobian_while)
+
+ def test_jacobian_unknown_shape(self):
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32, shape=[None, None])
+ y = math_ops.matmul(x, x, transpose_a=True)
+ jacobian_pfor = gradients.jacobian(y, x, use_pfor=True)
+ jacobian_while = gradients.jacobian(y, x, use_pfor=False)
+ answer = ops.convert_to_tensor([[
+ gradient_ops.gradients(y[0][0], x)[0],
+ gradient_ops.gradients(y[0][1], x)[0]
+ ], [
+ gradient_ops.gradients(y[1][0], x)[0],
+ gradient_ops.gradients(y[1][1], x)[0]
+ ]])
+ ans, pfor_value, while_value = sess.run(
+ [answer, jacobian_pfor, jacobian_while],
+ feed_dict={x: [[1, 2], [3, 4]]})
+ self.assertAllClose(ans, pfor_value)
+ self.assertAllClose(ans, while_value)
+
+ def test_batch_jacobian_bad_shapes(self):
+ x = random_ops.random_uniform([2, 2])
+ y = random_ops.random_uniform([3, 2])
+ with self.assertRaisesRegexp(ValueError, "Need first dimension of output"):
+ gradients.batch_jacobian(y, x, use_pfor=True)
+
+ def test_batch_jacobian_bad_unknown_shapes(self):
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.concat([x, x], axis=0)
+ jacobian = gradients.batch_jacobian(y, x)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "assertion failed"):
+ sess.run(jacobian, feed_dict={x: [[1, 2], [3, 4]]})
+
+ def test_batch_jacobian_fixed_shape(self):
+ x = random_ops.random_uniform([2, 3, 5])
+ y = x * x
+ batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True)
+ batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False)
+ two_x = 2 * x
+ answer = array_ops.stack(
+ [array_ops.diag(two_x[0]),
+ array_ops.diag(two_x[1])])
+ self.run_and_assert_equal(answer, batch_jacobian_pfor)
+ self.run_and_assert_equal(answer, batch_jacobian_while)
+
+ def test_batch_jacobian_unknown_shape(self):
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32)
+ y = x * x
+ batch_jacobian_pfor = gradients.batch_jacobian(y, x, use_pfor=True)
+ batch_jacobian_while = gradients.batch_jacobian(y, x, use_pfor=False)
+ two_x = 2 * x
+ answer = array_ops.stack(
+ [array_ops.diag(two_x[0]),
+ array_ops.diag(two_x[1])])
+ ans, pfor_value, while_value = sess.run(
+ [answer, batch_jacobian_pfor, batch_jacobian_while],
+ feed_dict={x: [[1, 2], [3, 4]]})
+ self.assertAllClose(ans, pfor_value)
+ self.assertAllClose(ans, while_value)
+
+ def test_fc_batch_jacobian(self):
+ pfor_jacobian, while_jacobian = create_fc_batch_jacobian(8, 4, 2)
+ self.run_and_assert_equal(pfor_jacobian, while_jacobian)
+
+ def test_lstm_batch_jacobian(self):
+ pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(8, 4, 2)
+ self.run_and_assert_equal(pfor_jacobian, while_jacobian)
+
+ def test_dynamic_lstm_batch_jacobian(self):
+ pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3)
+ with session.Session() as sess:
+ init = variables.global_variables_initializer()
+ sess.run(init)
+ pfor = sess.run(pfor_jacobian)
+ for i in range(4):
+ while_i = sess.run(while_gradients[i])
+ self.assertAllClose(while_i, pfor[:, i, ...])
+
+ def test_lstm_hessian(self):
+ pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 2)
+ self.run_and_assert_equal(pfor_hessian, while_hessian)
+
+ def test_lstm_batch_hessian(self):
+ pfor_hessian, while_hessian = create_lstm_batch_hessian(2, 2, 2)
+ self.run_and_assert_equal(pfor_hessian, while_hessian)
+
+ def test_fc_per_eg_grad(self):
+ pfor_outputs, while_outputs = create_fc_per_eg_grad(8, 4, 2)
+ self.run_and_assert_equal(pfor_outputs, while_outputs)
+
+ def test_lstm_per_eg_grad(self):
+ pfor_outputs, while_outputs = create_lstm_per_eg_grad(8, 4, 2)
+ self.run_and_assert_equal(pfor_outputs, while_outputs)
+
+ def test_mnist_per_eg_grad(self):
+ data_format = ("channels_first"
+ if test.is_gpu_available() else "channels_last")
+ # Note that we we are setting training=False here so that dropout produces
+ # the same result with pfor and with while_loop.
+ pfor_outputs, while_outputs = create_mnist_per_eg_grad(
+ 4, data_format, training=False)
+ self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3)
+
+ def test_mnist_per_eg_jacobian(self):
+ data_format = ("channels_first"
+ if test.is_gpu_available() else "channels_last")
+ # Note that we we are setting training=False here so that dropout produces
+ # the same result with pfor and with while_loop.
+ pfor_outputs, while_outputs = create_mnist_per_eg_jacobian(
+ 2, data_format, training=False)
+ self.run_and_assert_equal(pfor_outputs, while_outputs, rtol=1e-3)
+
+ def test_fc_jacobian(self):
+ jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = (
+ create_fc_per_eg_jacobians(batch_size=8,
+ activation_size=4,
+ num_layers=2))
+ self.run_and_assert_equal(jacobians, per_eg_jacobians_pfor,
+ rtol=2e-3, atol=1e-3)
+ self.run_and_assert_equal(jacobians, per_eg_jacobians_while,
+ rtol=2e-3, atol=1e-3)
+
+
+class GradientsBenchmarks(test.Benchmark):
+
+ def _run(self, targets, iters, name=None):
+
+ def _done(t):
+ # Note that we don't use tf.control_dependencies since that will not make
+ # sure that the computation on GPU has actually finished. So we fetch the
+ # first element of the output, and assume that this will not be called on
+ # empty tensors.
+ return array_ops.gather(array_ops.reshape(t, [-1]), 0)
+
+ targets = [_done(x) for x in nest.flatten(targets)]
+ sess = session.Session()
+ with sess:
+ init = variables.global_variables_initializer()
+ sess.run(init)
+ sess.run(targets)
+ begin = time.time()
+ for _ in range(iters):
+ sess.run(targets)
+ end = time.time()
+ avg_time_ms = 1000 * (end - begin) / iters
+ self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
+ return avg_time_ms
+
+ def benchmark_fc_batch_jacobian(self):
+ with ops.Graph().as_default():
+ pfor_jacobian, while_jacobian = create_fc_batch_jacobian(100, 32, 20)
+ self._run(pfor_jacobian, 100, name="fc_batch_jacobian_pfor")
+ self._run(while_jacobian, 20, name="fc_batch_jacobian_while")
+
+ def benchmark_lstm_batch_jacobian(self):
+ with ops.Graph().as_default():
+ pfor_jacobian, while_jacobian = create_lstm_batch_jacobian(100, 32, 8)
+ self._run(pfor_jacobian, 100, name="lstm_batch_jacobian_pfor")
+ self._run(while_jacobian, 20, name="lstm_batch_jacobian_while")
+
+ def benchmark_lstm_hessian(self):
+ with ops.Graph().as_default():
+ pfor_hessian, while_hessian = create_lstm_hessian(2, 2, 10)
+ self._run(pfor_hessian, 20, name="lstm_hessian_pfor")
+ self._run(while_hessian, 3, name="lstm_hessian_while_pfor")
+
+ def benchmark_lstm_batch_hessian(self):
+ with ops.Graph().as_default():
+ pfor_hessian, while_hessian = create_lstm_batch_hessian(4, 4, 10)
+ self._run(pfor_hessian, 100, name="lstm_batch_hessian_pfor")
+ self._run(while_hessian, 20, name="lstm_batch_hessian_while_pfor")
+
+ def benchmark_fc_per_eg_grad(self):
+ with ops.Graph().as_default():
+ pfor_outputs, while_outputs = create_fc_per_eg_grad(100, 32, 3)
+ self._run(pfor_outputs, 100, name="fc_per_eg_grad_pfor")
+ self._run(while_outputs, 20, name="fc_per_eg_grad_while")
+
+ def benchmark_lstm_per_eg_grad(self):
+ with ops.Graph().as_default():
+ pfor_outputs, while_outputs = create_lstm_per_eg_grad(100, 32, 8)
+ self._run(pfor_outputs, 100, name="lstm_per_eg_grad_pfor")
+ self._run(while_outputs, 20, name="lstm_per_eg_grad_while")
+
+ def benchmark_mnist_per_eg_grad(self):
+ with ops.Graph().as_default():
+ data_format = ("channels_first"
+ if test.is_gpu_available() else "channels_last")
+ pfor_outputs, while_outputs = create_mnist_per_eg_grad(
+ 128, data_format, training=True)
+ self._run(pfor_outputs, 20, name="mnist_per_eg_grad_pfor")
+ self._run(while_outputs, 20, name="mnist_per_eg_grad_while")
+
+ def benchmark_mnist_per_eg_jacobian(self):
+ with ops.Graph().as_default():
+ data_format = ("channels_first"
+ if test.is_gpu_available() else "channels_last")
+ pfor_outputs, while_outputs = create_mnist_per_eg_jacobian(
+ 16, data_format, training=True)
+ self._run(pfor_outputs, 20, name="mnist_per_eg_jacobian_pfor")
+ self._run(while_outputs, 20, name="mnist_per_eg_jacobian_while")
+
+ def benchmark_fc_per_eg_jacobian(self):
+ with ops.Graph().as_default():
+ jacobians, per_eg_jacobians_pfor, per_eg_jacobians_while = (
+ create_fc_per_eg_jacobians(batch_size=128,
+ activation_size=32,
+ num_layers=3))
+ self._run(jacobians, 30, name="fc_jacobians_pfor")
+ self._run(per_eg_jacobians_pfor, 100,
+ name="fc_per_eg_jacobians_pfor")
+ self._run(per_eg_jacobians_while, 10,
+ name="fc_per_eg_jacobians_while")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
new file mode 100644
index 0000000000..77ec3bc0d4
--- /dev/null
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -0,0 +1,2552 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compiled parallel-for loop."""
+# pylint: disable=missing-docstring
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from absl import flags
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_sparse_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
+
+flags.DEFINE_bool(
+ "op_conversion_fallback_to_while_loop", False,
+ "If true, falls back to using a while loop for ops for "
+ "which a converter is not defined.")
+
+
+def _stack(t, length):
+ """stacks `t` `length` times."""
+ ones = array_ops.ones_like(array_ops.shape(t))
+ multiples = array_ops.concat([length, ones], 0)
+ t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
+ return wrap(t, True)
+
+
+# The following stateful ops can be safely called once, and with the same
+# signature as the unconverted version, if their inputs are loop invariant.
+# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
+# plan is to map each read/write in the loop_fn to a corresponding merged
+# read/write in the converted graph. Writes need to be mergeable (e.g.
+# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
+# loop_fn, doing a one-to-one conversion will simulate executing such
+# instructions in lock-step across all iterations.
+passthrough_stateful_ops = set([
+ "VariableV2",
+ "VarHandleOp",
+ "ReadVariableOp",
+ "StackV2",
+ "TensorArrayWriteV3",
+ "TensorArrayReadV3",
+ "TensorArraySizeV3",
+])
+
+
+def _is_stateful_pfor_op(op):
+ if isinstance(op, WhileOp):
+ return op.is_stateful
+ if op.type == "Const":
+ # Const didn't have an op_def.
+ return False
+ if op.type in passthrough_stateful_ops:
+ return False
+ assert hasattr(op, "op_def") and op.op_def is not None, op
+ return op.op_def.is_stateful
+
+
+# pylint: disable=protected-access
+class WhileOp(object):
+ """Object for storing state for converting the outputs of a while_loop."""
+
+ def __init__(self, exit_node, pfor_ops):
+ """Initializer.
+
+ Args:
+ exit_node: A tensor output from the while_loop.
+ pfor_ops: list of ops inside the current pfor loop.
+ """
+ self._pfor_ops = set(pfor_ops)
+ self._pfor_op_ids = set([x._id for x in pfor_ops])
+ assert isinstance(exit_node, ops.Tensor)
+ self._while_context = exit_node.op._get_control_flow_context()
+ assert isinstance(self._while_context, control_flow_ops.WhileContext)
+ self._context_name = self._while_context.name
+ self._condition = self._while_context.pivot.op.inputs[0]
+ # Parts of an external while_loop could be created inside a pfor loop.
+ # However for the purpose here, we declare such loops to be external. Also
+ # note that we check if the condition was created inside or outside to
+ # determine if the while_loop was first created inside or outside.
+ # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
+ self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
+ if self._is_inside_loop:
+ for e in self._while_context.loop_exits:
+ assert self.op_is_inside_loop(e.op)
+
+ # Note the code below tries to reverse engineer an existing while_loop graph
+ # by assuming the following pattern of nodes.
+ #
+ # NextIteration <---- Body <--- Enter
+ # | ^
+ # V ___| Y
+ # Enter -> Merge -> Switch___
+ # ^ | N
+ # | V
+ # LoopCond Exit
+
+ # Node that elements in the list below correspond one-to-one with each
+ # other. i.e. these lists are the same size, and the i_th entry corresponds
+ # to different Operations/Tensors of a single cycle as illustrated above.
+ # List of Switch ops (ops.Operation) that feed into an Exit Node.
+ self._exit_switches = []
+ # List of inputs (ops.Tensor) to NextIteration.
+ self._body_outputs = []
+ # List of list of control inputs of the NextIteration nodes.
+ self._next_iter_control_inputs = []
+ # List of Merge ops (ops.Operation).
+ self._enter_merges = []
+ # List of output (ops.Tensor) of Exit nodes.
+ self._outputs = []
+
+ # List of Enter Tensors.
+ # There are two types of Enter nodes:
+ # - The Enter nodes that are used in the `loop_vars` argument to
+ # `while_loop` (see
+ # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
+ # these Enter nodes immediately below by tracing backwards from the Exit
+ # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
+ # diagram above. This allows us to have a 1:1 correspondence between the
+ # self._outputs and the first elements in self._enters.
+ # - The Enter nodes that are used only by the body. They don't appear in the
+ # `loop_vars` and are not returned from the `while_loop`. In Python code,
+ # they are usually captured by the body lambda. We collect them below by
+ # iterating over all the ops in the graph. They are appended to the end of
+ # self._enters or self._direct_enters, and don't correspond to any outputs
+ # in self._outputs. Note that we keep the resource/variant Enter nodes in
+ # self._direct_enters and the constructed while_loop's body uses them
+ # directly as opposed to passing them as loop variables. This is done
+ # because the while_body cannot partition the resource/variant Tensors, so
+ # it has to leave them unchanged.
+ self._enters = []
+ self._direct_enters = []
+
+ for e in self._while_context.loop_exits:
+ self._outputs.append(e.op.outputs[0])
+ switch = e.op.inputs[0].op
+ assert switch.type == "Switch", switch
+ self._exit_switches.append(switch)
+ merge = switch.inputs[0].op
+ assert merge.type == "Merge", merge
+ self._enter_merges.append(merge)
+ enter = merge.inputs[0].op
+ assert enter.type == "Enter", enter
+ self._enters.append(enter.outputs[0])
+ next_iter = merge.inputs[1].op
+ assert next_iter.type == "NextIteration", next_iter
+ self._body_outputs.append(next_iter.inputs[0])
+ self._next_iter_control_inputs.append(next_iter.control_inputs)
+
+ # Collect all the Enter nodes that are not part of `loop_vars`, the second
+ # category described above.
+ # Also track whether the loop body has any stateful ops.
+ self._is_stateful = False
+ for op in ops.get_default_graph().get_operations():
+ # TODO(agarwal): make sure this works with nested case.
+ control_flow_context = op._get_control_flow_context()
+ if control_flow_context is None:
+ continue
+ if control_flow_context.name == self._context_name:
+ self._is_stateful |= _is_stateful_pfor_op(op)
+ if op.type == "Enter":
+ output = op.outputs[0]
+ if output not in self._enters:
+ if output.dtype in (dtypes.resource, dtypes.variant):
+ if output not in self._direct_enters:
+ self._direct_enters.append(output)
+ else:
+ self._enters.append(output)
+
+ def __str__(self):
+ """String representation."""
+ return "while_loop(%s)" % self.name
+
+ @property
+ def inputs(self):
+ """Input to all the Enter nodes."""
+ return [x.op.inputs[0] for x in self._enters + self._direct_enters]
+
+ @property
+ def control_inputs(self):
+ """Control input to all the Enter nodes."""
+ control_inputs = []
+ for x in self._enters + self._direct_enters:
+ control_inputs.extend(x.op.control_inputs)
+ return control_inputs
+
+ @property
+ def outputs(self):
+ """Outputs of all the Exit nodes."""
+ return self._outputs
+
+ @property
+ def name(self):
+ """Context name for the while loop."""
+ return self._context_name
+
+ @property
+ def is_inside_loop(self):
+ """Returns true if the while_loop was created inside the pfor."""
+ return self._is_inside_loop
+
+ def op_is_inside_loop(self, op):
+ """True if op was created inside the pfor loop body."""
+ assert isinstance(op, ops.Operation)
+ # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
+ # since it appears there tensorflow API could return different python
+ # objects representing the same Operation node.
+ return op._id in self._pfor_op_ids
+
+ @property
+ def is_stateful(self):
+ return self._is_stateful
+
+ @property
+ def pfor_converter(self):
+ """Return a converter for the while loop."""
+ return self
+
+ def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
+ inputs_stacked):
+ """Create a PFor object for converting parts of the while_loop.
+
+ Args:
+ parent_pfor: PFor object being used for converting the while_loop.
+ indices: int32 Tensor of ids for the iterations that are still active
+ (i.e. did not exit the while_loop).
+ cond_stacked: True if the while_loop condition is stacked.
+ inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
+ that these Tensors are a subset of the loop variables for the generated
+ while_loop.
+ inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
+ indicating if the value is stacked or not.
+
+ Returns:
+ A PFor instance. The instance is initialized by adding conversion mappings
+ of nodes that will be external to the conversion that the returned
+ instance will be used for. e.g. Enter nodes as well as Merge and Switch
+ outputs are mapped to converted values.
+ """
+ num_outputs = len(self._outputs)
+ assert len(inputs) == len(self._enters)
+ assert len(inputs_stacked) == len(self._enters)
+ loop_var = parent_pfor.loop_var
+ loop_len = array_ops.size(indices)
+ pfor = PFor(
+ loop_var,
+ loop_len,
+ pfor_ops=self._pfor_ops,
+ all_indices=indices,
+ all_indices_partitioned=cond_stacked)
+ # Map all inputs of Enter nodes in self._direct_enters to their converted
+ # values.
+ for enter in self._direct_enters:
+ enter_input = enter.op.inputs[0]
+ converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
+ enter_input)
+ # Since these are resources / variants, they should be unstacked.
+ assert not stacked and not is_sparse_stacked, (enter, converted_enter)
+ pfor._add_conversion(enter, wrap(converted_enter, False))
+
+ # Map all Enter nodes to the inputs.
+ for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
+ pfor._add_conversion(enter, wrap(inp, stacked))
+ # Map outputs of Switch and Merge.
+ for i in range(num_outputs):
+ wrapped_inp = wrap(inputs[i], inputs_stacked[i])
+ merge = self._enter_merges[i]
+ pfor._add_conversion(merge.outputs[0], wrapped_inp)
+ # Note that second output of Merge is typically not used, except possibly
+ # as a control dependency. To avoid trying to output the correct value, we
+ # employ a hack here. We output a dummy invalid value with an incorrect
+ # dtype. This will allow control dependency to work but if using it as an
+ # input, it should typically lead to errors during graph construction due
+ # to dtype mismatch.
+ # TODO(agarwal): Check in the original graph to see if there are any
+ # consumers of this Tensor that use it as an input.
+ pfor._add_conversion(merge.outputs[1],
+ wrap(constant_op.constant(-1.0), False))
+ switch = self._exit_switches[i]
+ # Don't need to worry about switch.output[0] which will feed to Exit node.
+ pfor._add_conversion(switch.outputs[1], wrapped_inp)
+ return pfor
+
+ def _convert_enter(self, parent_pfor, enter):
+ """Converts an Enter node."""
+ inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
+ control_inputs = [
+ parent_pfor._convert_helper(x).t for x in enter.op.control_inputs
+ ]
+ if control_inputs:
+ with ops.control_dependencies(control_inputs):
+ inp = array_ops.identity(inp)
+ return inp, stacked
+
+ def _maybe_stacked(self, cache, inp):
+ """Heuristic to figue out if the coverting inp leads to a stacked value.
+
+
+ Args:
+ cache: map from Tensor to boolean indicating stacked/unstacked.
+ inp: input Tensor.
+
+ Returns:
+ True if `inp` could get stacked. If the function returns False, the
+ converted value should be guaranteed to be unstacked. If returning True,
+ it may or may not be stacked.
+ """
+ if inp in cache:
+ return cache[inp]
+ if not self.op_is_inside_loop(inp.op):
+ return False
+ op = inp.op
+ output = False
+ if op.type in [
+ "Shape",
+ "Rank"
+ "ShapeN",
+ "ZerosLike",
+ "TensorArrayV3",
+ "TensorArraySizeV3",
+ ]:
+ output = False
+ elif _is_stateful_pfor_op(op):
+ # This may be fairly aggressive.
+ output = True
+ elif op.type == "Exit":
+ # This may be fairly aggressive.
+ output = True
+ else:
+ for t in op.inputs:
+ if self._maybe_stacked(cache, t):
+ output = True
+ break
+ cache[inp] = output
+ return output
+
+ def _create_init_values(self, pfor_input):
+ """Create arguments passed to converted while_loop."""
+ with ops.name_scope("while_init"):
+ loop_len_vector = pfor_input.pfor.loop_len_vector
+ loop_len = loop_len_vector[0]
+ num_outputs = len(self._outputs)
+
+ inputs = []
+ maybe_stacked_cache = {}
+ # Convert all the Enters. Need to do this before checking for stacking
+ # below.
+ for i, enter in enumerate(self._enters):
+ inp, stacked = self._convert_enter(pfor_input.pfor, enter)
+ inputs.append(inp)
+ maybe_stacked_cache[enter] = stacked
+ # Since this enter node is part of the `loop_vars`, it corresponds to an
+ # output and its preceding switch. We mark this switch's output the same
+ # stackness, to act at the base case for the logic below. Below, we will
+ # be going through the body figuring out which inputs might need to be
+ # stacked and which inputs can safely remain unstacked.
+ if i < num_outputs:
+ maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked
+
+ # Shape invariants for init_values corresponding to self._enters.
+ input_shape_invariants = []
+ # TensorArrays for outputs of converted while loop
+ output_tas = []
+ # Shape invariants for output TensorArrays.
+ ta_shape_invariants = []
+ # List of booleans indicating stackness of inputs, i.e. tensors
+ # corresponding to self._enters.
+ inputs_stacked = []
+ for i, inp in enumerate(inputs):
+ enter = self._enters[i]
+ inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
+ # Note that even when an input is unstacked, the body could make it
+ # stacked. we use a heuristic below to figure out if body may be making
+ # it stacked.
+ if i < num_outputs:
+ body_output = self._body_outputs[i]
+ if enter.op in self._pfor_ops:
+ body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
+ body_output)
+ else:
+ # If constructed outside of pfor loop, then the output would not be
+ # stacked.
+ body_output_stacked = False
+ if body_output_stacked and not inp_stacked:
+ inp = _stack(inp, loop_len_vector).t
+ inputs[i] = inp
+ inp_stacked = True
+ # TODO(agarwal): other attributes for the TensorArray ?
+ output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
+ ta_shape_invariants.append(tensor_shape.TensorShape(None))
+
+ inputs_stacked.append(inp_stacked)
+ input_shape_invariants.append(tensor_shape.TensorShape(None))
+
+ # See documentation for __call__ for the structure of init_values.
+ init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
+ # TODO(agarwal): try stricter shape invariants
+ shape_invariants = (
+ [tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None)
+ ] + input_shape_invariants + ta_shape_invariants)
+
+ return init_values, inputs_stacked, shape_invariants
+
+ def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
+ """Handles case when condition is unstacked.
+
+ Note that all iterations end together. So we don't need to partition the
+ inputs. When all iterations are done, we write the inputs to the
+ TensorArrays. Note that we only write to index 0 of output_tas. Since all
+ iterations end together, they can all be output together.
+ """
+ not_all_done = array_ops.reshape(conditions, [])
+ new_output_tas = []
+ # pylint: disable=cell-var-from-loop
+ for i, out_ta in enumerate(output_tas):
+ inp = inputs[i]
+ new_output_tas.append(
+ control_flow_ops.cond(not_all_done,
+ lambda: out_ta,
+ lambda: out_ta.write(0, inp)))
+ # pylint: enable=cell-var-from-loop
+ return not_all_done, indices, inputs, new_output_tas
+
+ def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
+ output_tas):
+ num_outputs = len(self._outputs)
+ # Compute if all iterations are done.
+ not_all_done = math_ops.reduce_any(conditions)
+ conditions_int = math_ops.cast(conditions, dtypes.int32)
+ # Partition the indices.
+ done_indices, new_indices = data_flow_ops.dynamic_partition(
+ indices, conditions_int, 2)
+
+ new_inputs = []
+ new_output_tas = []
+ for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
+ # Partition the inputs.
+ if stacked:
+ done_inp, new_inp = data_flow_ops.dynamic_partition(
+ inp, conditions_int, 2)
+ else:
+ # TODO(agarwal): avoid this stacking. See TODO earlier in
+ # _process_cond_unstacked.
+ done_inp = _stack(inp, [array_ops.size(done_indices)]).t
+ new_inp = inp
+ new_inputs.append(new_inp)
+ # For iterations that are done, write them to TensorArrays.
+ if i < num_outputs:
+ out_ta = output_tas[i]
+ # Note that done_indices can be empty. done_inp should also be empty in
+ # that case.
+ new_output_tas.append(out_ta.scatter(done_indices, done_inp))
+ return not_all_done, new_indices, new_inputs, new_output_tas
+
+ def _process_body(self, pfor_input, inputs_stacked,
+ new_indices, cond_stacked, new_inputs,
+ not_all_done):
+ """Convert the body function."""
+
+ def true_fn(control_inputs, body_pfor, body_output, stacked):
+ """Converts the body function for all but last iteration.
+
+ This essentially converts body_output. Additionally, it needs to handle
+ any control dependencies on the NextIteration node. So it creates another
+ Identity node with the converted dependencies.
+ """
+ converted_control_inp = []
+ for x in control_inputs:
+ for t in x.outputs:
+ converted_control_inp.append(body_pfor._convert_helper(t).t)
+ if stacked:
+ # Note convert always does the stacking.
+ output = body_pfor.convert(body_output)
+ else:
+ output, convert_stacked, _ = body_pfor._convert_helper(body_output)
+ assert convert_stacked == stacked, body_output
+ with ops.control_dependencies(converted_control_inp):
+ return array_ops.identity(output)
+
+ body_pfor = self._init_pfor(pfor_input.pfor, new_indices,
+ cond_stacked, new_inputs,
+ inputs_stacked)
+ new_outputs = []
+
+ for i, (body_output, stacked) in enumerate(
+ zip(self._body_outputs, inputs_stacked)):
+ control_inp = self._next_iter_control_inputs[i]
+ out_dtype = body_output.dtype
+ # Note that we want to run the body only if not all pfor iterations are
+ # done. If all are done, we return empty tensors since these values will
+ # not be used. Notice that the value returned by the loop is based on
+ # TensorArrays and not directly on these returned values.
+ # pylint: disable=cell-var-from-loop
+ new_output = control_flow_ops.cond(
+ not_all_done,
+ lambda: true_fn(control_inp, body_pfor, body_output, stacked),
+ lambda: constant_op.constant([], dtype=out_dtype))
+ # pylint: enable=cell-var-from-loop
+ new_outputs.append(new_output)
+ return new_outputs
+
+ def __call__(self, pfor_input):
+ """Converter for the while_loop.
+
+ The conversion of a while_loop is another while_loop.
+
+ The arguments to this converted while_loop are as follows:
+ not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
+ are done.
+ indices: int32 1-D Tensor storing the id of the iterations that are not
+ done.
+ args: Remaining arguments. These can be divided into 3 categories:
+ - First set of arguments are the tensors that correspond to the initial
+ elements of self._enters. The elements that appear in original while
+ loop's `loop_vars`.
+ - The second set of arguments are the tensors that correspond to the
+ remaining elements of self._enters. These are the tensors that directly
+ enter the original while loop body.
+ - Finally, the last set of arguments are TensorArrays. These TensorArrays
+ correspond to the outputs of the original while_loop, i.e. to the
+ elements in self._outputs. Each TensorArray has `PFor.loop_len`
+ elements, i.e. the number of pfor iterations. At the end, the i'th
+ element of each TensorArray will contain the output computed by the
+ i'th iteration of pfor. Note that elements can be written into these
+ tensors arrays in any order, depending on when the corresponding pfor
+ iteration is done.
+ If the original while_loop had `k` tensors in its `loop_vars` and its body
+ directly captured `m` tensors, the `args` will contain `2 * k + m` values.
+
+ In each iteration, the while_loop body recomputes the condition for all
+ active pfor iterations to see which of them are now done. It then partitions
+ all the inputs and passes them along to the converted body. Values for all
+ the iterations that are done are written to TensorArrays indexed by the pfor
+ iteration number. When all iterations are done, the TensorArrays are stacked
+ to get the final value.
+
+ Args:
+ pfor_input: A PForInput object corresponding to the output of any Exit
+ node from this while loop.
+
+ Returns:
+ List of converted outputs.
+ """
+ # Create init_values that will be passed to the while_loop.
+ init_values, inputs_stacked, shape_invariants = self._create_init_values(
+ pfor_input)
+ # Note that we use a list as a hack since we need the nested function body
+ # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
+ # variables.
+ cond_is_stacked = [None]
+
+ def cond(not_all_done, *_):
+ return not_all_done
+
+ def body(not_all_done, indices, *args):
+ # See documentatin for __call__ for the structure of *args.
+ num_enters = len(self._enters)
+ inputs = args[:num_enters]
+ output_tas = args[num_enters:]
+ # TODO(agarwal): see which outputs have consumers and only populate the
+ # TensorArrays corresponding to those. Or do those paths get trimmed out
+ # from inside the while_loop body?
+ assert len(inputs) >= len(output_tas)
+ assert len(inputs) == len(inputs_stacked)
+
+ # Convert condition
+ with ops.name_scope("while_cond"):
+ # Note that we set cond_stacked to True here. At this point we don't
+ # know if it could be loop invariant, hence the conservative value is
+ # to assume stacked.
+ cond_pfor = self._init_pfor(pfor_input.pfor, indices,
+ cond_stacked=True,
+ inputs=inputs,
+ inputs_stacked=inputs_stacked)
+ conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
+ cond_is_stacked[0] = cond_stacked
+
+ # Recompute the new condition, write outputs of done iterations, and
+ # partition the inputs if needed.
+ if not cond_stacked:
+ (not_all_done, new_indices,
+ new_inputs, new_output_tas) = self._process_cond_unstacked(
+ conditions, indices, inputs, output_tas)
+ else:
+ (not_all_done, new_indices,
+ new_inputs, new_output_tas) = self._process_cond_stacked(
+ conditions, indices, inputs, inputs_stacked, output_tas)
+
+ # Convert body
+ with ops.name_scope("while_body"):
+ # Compute the outputs from the body.
+ new_outputs = self._process_body(pfor_input, inputs_stacked,
+ new_indices, cond_stacked, new_inputs,
+ not_all_done)
+
+ # Note that the first num_outputs new values of inputs are computed using
+ # the body. Rest of them were direct Enters into the condition/body and
+ # the partitioning done earlier is sufficient to give the new value.
+ num_outputs = len(self._outputs)
+ new_args = ([not_all_done, new_indices] + new_outputs + list(
+ new_inputs[num_outputs:]) + new_output_tas)
+ return tuple(new_args)
+
+ while_outputs = control_flow_ops.while_loop(
+ cond, body, init_values, shape_invariants=shape_invariants)
+ output_tas = while_outputs[-len(self._outputs):]
+ outputs = []
+ assert cond_is_stacked[0] is not None
+ for inp_stacked, ta in zip(inputs_stacked, output_tas):
+ if cond_is_stacked[0]:
+ outputs.append(wrap(ta.stack(), True))
+ else:
+ # Note that if while_loop condition is unstacked, all iterations exit at
+ # the same time and we wrote those outputs in index 0 of the tensor
+ # array.
+ outputs.append(wrap(ta.read(0), inp_stacked))
+ return outputs
+
+
+class _PforInput(object):
+ """Input object passed to registered pfor converters."""
+
+ def __init__(self, pfor, op, inputs):
+ """Creates a _PforInput object.
+
+ Args:
+ pfor: PFor converter object.
+ op: the Operation object that is being converted.
+ inputs: list of WrappedTensor objects representing converted values of the
+ inputs of `op`.
+ """
+ self.pfor = pfor
+ self._op = op
+ self._inputs = inputs
+
+ def stack_inputs(self, stack_indices=None):
+ """Stacks unstacked inputs at `stack_indices`.
+
+ Args:
+ stack_indices: indices of inputs at which stacking is done. If None,
+ stacking is done at all indices.
+ """
+ if stack_indices is None:
+ stack_indices = range(len(self._inputs))
+ length = self.pfor.loop_len_vector
+ for i in stack_indices:
+ inp = self._inputs[i]
+ if not inp.is_stacked:
+ self._inputs[i] = _stack(inp.t, length)
+
+ def expanddim_inputs_for_broadcast(self):
+ """Reshapes stacked inputs to prepare them for broadcast.
+
+ Since stacked inputs have an extra leading dimension, automatic broadcasting
+ rules could incorrectly try to expand dimensions before that leading
+ dimension. To avoid that, we reshape these stacked inputs to the maximum
+ rank they will need to be broadcasted to.
+ """
+ if not self._inputs:
+ return
+
+ # Find max rank
+ def _get_rank(x):
+ rank = array_ops.rank(x.t)
+ if not x.is_stacked:
+ rank += 1
+ return rank
+
+ ranks = [_get_rank(x) for x in self._inputs]
+ max_rank = ranks[0]
+ for rank in ranks[1:]:
+ max_rank = math_ops.maximum(rank, max_rank)
+
+ for i, inp in enumerate(self._inputs):
+ if inp.is_stacked:
+ shape = array_ops.shape(inp.t)
+ rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
+ ones = array_ops.tile([1], rank_diff)
+ new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
+ self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)
+
+ @property
+ def inputs(self):
+ return self._inputs
+
+ @property
+ def num_inputs(self):
+ return len(self._inputs)
+
+ def input(self, index):
+ assert len(self._inputs) > index, (index, self._inputs)
+ return self._inputs[index]
+
+ def stacked_input(self, index):
+ t, is_stacked, _ = self.input(index)
+ if not is_stacked:
+ op_type = self.op_type
+ op_def = getattr(self._op, "op_def", None)
+ if op_def is None:
+ input_name = "at index %d" % index
+ else:
+ input_name = "\"%s\"" % op_def.input_arg[index].name
+ raise ValueError("Input %s of op \"%s\" expected to be not loop invariant"
+ ".\nError while converting op %s"
+ "with converted inputs\n%s" % (input_name, op_type,
+ self._op, self.inputs))
+ return t
+
+ def unstacked_input(self, index):
+ t, is_stacked, _ = self.input(index)
+ if is_stacked:
+ op_type = self.op_type
+ op_def = getattr(self._op, "op_def", None)
+ if op_def is None:
+ input_name = "at index %d" % index
+ else:
+ input_name = "\"%s\"" % op_def.input_arg[index].name
+ raise ValueError("Input %s of op \"%s\" expected to be loop invariant"
+ ".\nError while converting op %s"
+ "with converted inputs\n%s" % (input_name, op_type,
+ self._op, self.inputs))
+ return t
+
+ @property
+ def op(self):
+ return self._op
+
+ @property
+ def op_type(self):
+ return self._op.type
+
+ def get_attr(self, attr):
+ return self._op.get_attr(attr)
+
+ @property
+ def outputs(self):
+ return self._op.outputs
+
+ def output(self, index):
+ assert index < len(self._op.outputs)
+ return self._op.outputs[index]
+
+
+_pfor_converter_registry = {}
+
+
+class RegisterPFor(object):
+ """Utility to register converters for pfor.
+
+ Usage:
+ @RegisterPFor(foo_op_type)
+ def _foo_converter(pfor_input):
+ ...
+
+ The above will register conversion function `_foo_converter` for handling
+ conversion of `foo_op_type`. During conversion, the registered functin will be
+ called with a single argument of type `PForInput` which will contain state
+ needed for the conversion. This registered function should output a list of
+ WrappedTensor object with the same length as the number of outputs of op being
+ converted. If the op had zero outputs, then it should return a ops.Operation
+ object.
+ """
+
+ def __init__(self, op_type):
+ """Creates an object to register a converter for op with type `op_type`."""
+ self.op_type = op_type
+
+ def __call__(self, converter):
+ name = self.op_type
+ assert name not in _pfor_converter_registry, "Re-registering %s " % name
+ _pfor_converter_registry[name] = converter
+ return converter
+
+
+class RegisterPForWithArgs(RegisterPFor):
+ """Utility to register converters for pfor.
+
+ Usage:
+ @RegisteRPFor(foo_op_type, foo=value, ....)
+ def _foo_converter(pfor_input, foo=None, ....):
+ ...
+
+ See RegisterPFor for details on the conversion function.
+ `RegisterPForWithArgs` allows binding extra arguments to the
+ conversion function at registration time.
+ """
+
+ def __init__(self, op_type, *args, **kw_args):
+ super(RegisterPForWithArgs, self).__init__(op_type)
+ self._args = args
+ self._kw_args = kw_args
+
+ def __call__(self, converter):
+
+ def _f(pfor_input):
+ return converter(pfor_input, self.op_type, *self._args, **self._kw_args)
+
+ super(RegisterPForWithArgs, self).__call__(_f)
+ return converter
+
+
+def _create_op(op_type, inputs, op_dtypes, attrs=None):
+ """Utility to create an op."""
+ return ops.get_default_graph().create_op(
+ op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
+
+
+WrappedTensor = collections.namedtuple("WrappedTensor",
+ ["t", "is_stacked", "is_sparse_stacked"])
+"""Wrapper around the result of a Tensor conversion.
+
+The additional fields are useful for keeping track of the conversion state as
+data flows through the ops in the loop body. For every op whose output is a
+Tensor, its converter should return either a WrappedTensor or a list of
+WrappedTensors.
+
+Args:
+ t: The converted tensor
+ is_stacked: True if the tensor is stacked, i.e. represents the results of all
+ the iterations of the loop, where each row i of the tensor corresponds to
+ that op's output on iteration i of the loop. False if the tensor is not
+ stacked, i.e. represents the result of the op on of a single iteration of
+ the loop, where the result does not vary between iterations.
+ is_sparse_stacked: True if the tensor corresponds to a component tensor
+ (indices, values, or dense_shape) of a sparse tensor, and has been logically
+ stacked via a sparse conversion.
+"""
+
+
+def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
+ """Helper to create a WrappedTensor object."""
+ assert isinstance(is_stacked, bool)
+ assert isinstance(is_sparse_stacked, bool)
+ assert isinstance(tensor, ops.Tensor)
+ assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
+ "stacked via a sparse "
+ "conversion, it must also be "
+ "stacked.")
+ return WrappedTensor(tensor, is_stacked, is_sparse_stacked)
+
+
+def _fallback_converter(pfor_input):
+ logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
+ output_dtypes = [x.dtype for x in pfor_input.outputs]
+ iters = pfor_input.pfor.loop_len_vector[0]
+
+ def while_body(i, *ta_list):
+ """Body of while loop."""
+ inputs = [
+ x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
+ ]
+ op_outputs = _create_op(
+ pfor_input.op_type,
+ inputs,
+ output_dtypes,
+ attrs=pfor_input.op.node_def.attr).outputs
+
+ outputs = []
+ for out, ta in zip(op_outputs, ta_list):
+ assert isinstance(out, ops.Tensor)
+ outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
+ return tuple([i + 1] + outputs)
+
+ ta_list = control_flow_ops.while_loop(
+ lambda i, *ta: i < iters, while_body, [0] + [
+ tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
+ ])[1:]
+ return tuple([wrap(ta.concat(), True) for ta in ta_list])
+
+
+class PFor(object):
+ """Implementation of rewrite of parallel-for loops.
+
+ This class takes a DAG or a set of DAGs representing the body of a
+ parallel-for loop, and adds new operations to the graph that implements
+ functionality equivalent to running that loop body for a specified number of
+ iterations. This new set of nodes may or may not use a tensorflow loop
+ construct.
+
+ The process of conversion does not delete or change any existing operations.
+ It only adds operations that efficiently implement the equivalent
+ functionality. We refer to the added ops as "converted ops".
+
+ The conversion process uses a simple greedy heuristic. It walks the loop body
+ and tries to express the functionality of running each node in a loop with a
+ new set of nodes. When converting an op several cases are possible:
+ - The op is not inside the loop body. Hence it can be used as is.
+ - The op does not depend on the iteration number and is stateless. In this
+ case, it can be used as is.
+ - The op is not stateful, and depends on iteration number only through control
+ dependencies. In this case, we can create a single op with same inputs and
+ attributes, but with "converted" control dependencies.
+ - The op is not stateful, and all its inputs are loop invariant. In this
+ case, similar to above, we can create a single op with same inputs and
+ attributes, but with "converted" control dependencies.
+ - The op is stateful or at least one of the inputs is not loop invariant. In
+ this case, we run the registered converter for that op to create a set of
+ converted ops. All nodes in the set will have converted control dependencies
+ corresponding to control dependencies of the original op. If the op returned
+ multiple outputs, "converted outputs" could be produced by different ops in
+ this set.
+ """
+
+ def __init__(self,
+ loop_var,
+ loop_len,
+ pfor_ops,
+ all_indices=None,
+ all_indices_partitioned=False):
+ """Creates an object to rewrite a parallel-for loop.
+
+ Args:
+ loop_var: ops.Tensor output of a Placeholder operation. The value should
+ be an int32 scalar representing the loop iteration number.
+ loop_len: A scalar or scalar Tensor representing the number of iterations
+ the loop is run for.
+ pfor_ops: List of all ops inside the loop body.
+ all_indices: If not None, an int32 vector with size `loop_len`
+ representing the iteration ids that are still active. These values
+ should be unique and sorted. However they may not be contiguous. This is
+ typically the case when inside a control flow construct which has
+ partitioned the indices of the iterations that are being converted.
+ all_indices_partitioned: If True, this object is being constructed from a
+ control flow construct where not all the pfor iterations are guaranteed
+ to be active.
+ """
+ assert isinstance(loop_var, ops.Tensor)
+ assert loop_var.op.type == "Placeholder"
+ self._loop_var = loop_var
+ loop_len_value = tensor_util.constant_value(loop_len)
+ if loop_len_value is not None:
+ loop_len = loop_len_value
+ self._loop_len_vector = array_ops.reshape(loop_len, [1])
+ self._all_indices_partitioned = all_indices_partitioned
+ if all_indices_partitioned:
+ assert all_indices is not None
+ self.all_indices = (
+ math_ops.range(loop_len) if all_indices is None else all_indices)
+
+ self._conversion_map = {}
+ self._conversion_map[loop_var] = wrap(self.all_indices, True)
+ self._pfor_ops = set(pfor_ops)
+ self._pfor_op_ids = set([x._id for x in pfor_ops])
+
+ def op_is_inside_loop(self, op):
+ """True if op was created inside the pfor loop body."""
+ assert isinstance(op, ops.Operation)
+ # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
+ # since it appears there tensorflow API could return different python
+ # objects representing the same Operation node.
+ return op._id in self._pfor_op_ids
+
+ def _convert_sparse(self, y):
+ """Returns the converted value corresponding to SparseTensor y.
+
+ For SparseTensors, instead of stacking the component tensors separately,
+ resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
+ rank) respectively for indices, values, and dense_shape (where N is the loop
+ length and m is the number of sparse tensor values per loop iter), we want
+ to logically stack the SparseTensors, to create a SparseTensor whose
+ components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
+ respectively.
+
+ Here, we try to get the conversion of each component tensor.
+ If the tensors are stacked via a sparse conversion, return the resulting
+ SparseTensor composed of the converted components. Otherwise, the component
+ tensors are either unstacked or stacked naively. In the latter case, we
+ unstack the component tensors to reform loop_len SparseTensor elements,
+ then correctly batch them.
+
+ The unstacked tensors must have the same rank. Each dimension of each
+ SparseTensor will expand to be the largest among all SparseTensor elements
+ for that dimension. For example, if there are N SparseTensors of rank 3
+ being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
+ the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).
+
+ Args:
+ y: A tf.SparseTensor.
+
+ Returns:
+ A tf.SparseTensor that is the converted value corresponding to y.
+ """
+ outputs = [
+ self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
+ ]
+ assert all(isinstance(o, WrappedTensor) for o in outputs)
+
+ if all(w.is_sparse_stacked for w in outputs):
+ return sparse_tensor.SparseTensor(*[w.t for w in outputs])
+
+ assert not any(w.is_sparse_stacked for w in outputs), (
+ "Error converting SparseTensor. All components should be logically "
+ "stacked, or none.")
+
+ # If component tensors were not sparsely stacked, they are either unstacked
+ # or stacked without knowledge that they are components of sparse tensors.
+ # In this case, we have to restack them.
+ return self._restack_sparse_tensor_logically(
+ *[self._unwrap_or_tile(w) for w in outputs])
+
+ def _restack_sparse_tensor_logically(self, indices, values, shape):
+ sparse_tensor_rank = indices.get_shape()[-1].value
+ if sparse_tensor_rank is not None:
+ sparse_tensor_rank += 1
+
+ def map_fn(args):
+ res = gen_sparse_ops.serialize_sparse(
+ args[0], args[1], args[2], out_type=dtypes.variant)
+ return res
+
+ # Applies a map function to the component tensors to serialize each
+ # sparse tensor element and batch them all, then deserializes the batch.
+ # TODO(rachelim): Try to do this without map_fn -- add the right offsets
+ # to shape and indices tensors instead.
+ result = functional_ops.map_fn(
+ map_fn, [indices, values, shape], dtype=dtypes.variant)
+ return sparse_ops.deserialize_sparse(
+ result, dtype=values.dtype, rank=sparse_tensor_rank)
+
+ def _unwrap_or_tile(self, wrapped_tensor):
+ """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it."""
+ output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
+ if is_stacked:
+ return output
+ else:
+ return _stack(output, self._loop_len_vector).t
+
+ def convert(self, y):
+ """Returns the converted value corresponding to y.
+
+ Args:
+ y: A ops.Tensor or a ops.Operation object. If latter, y should not have
+ any outputs.
+
+ Returns:
+ If y does not need to be converted, it returns y as is. Else it returns
+ the "converted value" corresponding to y.
+ """
+ if isinstance(y, sparse_tensor.SparseTensor):
+ return self._convert_sparse(y)
+ output = self._convert_helper(y)
+ if isinstance(output, WrappedTensor):
+ assert isinstance(y, ops.Tensor)
+ return self._unwrap_or_tile(output)
+ else:
+ assert isinstance(y, ops.Operation)
+ assert not y.outputs
+ assert isinstance(output, ops.Operation)
+ return output
+
+ def _was_converted(self, t):
+ """True if t is not a conversion of itself."""
+ converted_t = self._conversion_map[t]
+ return converted_t.t is not t
+
+ def _add_conversion(self, old_output, new_output):
+ self._conversion_map[old_output] = new_output
+
+ def _convert_helper(self, op_or_tensor):
+ stack = [op_or_tensor]
+ while stack:
+ y = stack[0]
+ if y in self._conversion_map:
+ assert isinstance(self._conversion_map[y],
+ (WrappedTensor, ops.Operation))
+ stack.pop(0)
+ continue
+ if isinstance(y, ops.Operation):
+ assert not y.outputs, (
+ "We only support converting Operation objects with no outputs. "
+ "Got %s", y)
+ y_op = y
+ else:
+ assert isinstance(y, ops.Tensor), y
+ y_op = y.op
+
+ is_while_loop = y_op.type == "Exit"
+ if is_while_loop:
+ while_op = WhileOp(y, pfor_ops=self._pfor_ops)
+ is_inside_loop = while_op.is_inside_loop
+ # If all nodes in the while_loop graph were created inside the pfor, we
+ # treat the whole loop subgraph as a single op (y_op) and try to convert
+ # it. For while_loops that are created completely or partially outside,
+ # we treat them as external and should be able to simply return the Exit
+ # node output as is without needing any conversion. Note that for
+ # while_loops that are partially constructed inside, we assume they will
+ # be loop invariant. If that is not the case, it will create runtime
+ # errors since the converted graph would depend on the self._loop_var
+ # placeholder.
+ if is_inside_loop:
+ y_op = while_op
+ else:
+ is_inside_loop = self.op_is_inside_loop(y_op)
+
+ # If this op was not created inside the loop body, we will return as is.
+ # 1. Convert inputs and control inputs.
+
+ def _add_to_stack(x):
+ if x not in self._conversion_map:
+ stack.insert(0, x)
+ return True
+ else:
+ return False
+
+ if is_inside_loop:
+ added_to_stack = False
+ for inp in y_op.inputs:
+ added_to_stack |= _add_to_stack(inp)
+ for cinp in y_op.control_inputs:
+ if cinp.outputs:
+ for t in cinp.outputs:
+ added_to_stack |= _add_to_stack(t)
+ else:
+ added_to_stack |= _add_to_stack(cinp)
+ if added_to_stack:
+ continue
+
+ converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
+ some_input_converted = any(
+ [self._was_converted(x) for x in y_op.inputs])
+ some_input_stacked = any([x.is_stacked for x in converted_inputs])
+
+ converted_control_ops = set()
+ some_control_input_converted = False
+ for cinp in y_op.control_inputs:
+ if cinp.outputs:
+ for t in cinp.outputs:
+ converted_t = self._conversion_map[t]
+ if self._was_converted(t):
+ some_control_input_converted = True
+ converted_control_ops.add(converted_t.t.op)
+ else:
+ converted_cinp = self._conversion_map[cinp]
+ assert isinstance(converted_cinp, ops.Operation)
+ if converted_cinp != cinp:
+ some_control_input_converted = True
+ converted_control_ops.add(converted_cinp)
+ converted_control_ops = list(converted_control_ops)
+ is_stateful = _is_stateful_pfor_op(y_op)
+ else:
+ converted_inputs = []
+ converted_control_ops = []
+ logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
+ converted_inputs, converted_control_ops)
+
+ # 2. Convert y_op
+ # If converting a while_loop, we let the while_loop convertor deal with
+ # putting the control dependencies appropriately.
+ control_dependencies = [] if is_while_loop else converted_control_ops
+ with ops.control_dependencies(control_dependencies), ops.name_scope(
+ y_op.name + "/pfor/"):
+ # None of the inputs and control inputs were converted.
+ if (not is_inside_loop or
+ (not is_stateful and not some_input_converted and
+ not some_control_input_converted)):
+ if y == y_op:
+ assert not isinstance(y_op, WhileOp)
+ new_outputs = y_op
+ else:
+ new_outputs = [wrap(x, False) for x in y_op.outputs]
+ elif not (is_stateful or is_while_loop or some_input_stacked):
+ # All inputs are unstacked or uncoverted but some control inputs are
+ # converted.
+ # TODO(rachelim): Handle the case where some inputs are sparsely
+ # stacked (i.e. any([x.is_sparse_stacked for x in converted_inputs]))
+ new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
+ [x.dtype for x in y_op.outputs],
+ y_op.node_def.attr)
+ if y == y_op:
+ new_outputs = new_op
+ else:
+ new_outputs = [wrap(x, False) for x in new_op.outputs]
+ else:
+ # Either some inputs are not loop invariant or op is stateful.
+ if hasattr(y_op, "pfor_converter"):
+ converter = y_op.pfor_converter
+ else:
+ converter = _pfor_converter_registry.get(y_op.type, None)
+ if converter is None:
+ if flags.FLAGS.op_conversion_fallback_to_while_loop:
+ converter = _fallback_converter
+ else:
+ raise ValueError(
+ "No converter defined for %s\n%s\ninputs: %s. "
+ "\nEither add a converter or set "
+ "--op_conversion_fallback_to_while_loop=True, "
+ "which may run slower" % (y_op.type, y_op, converted_inputs))
+ # TODO(rachelim): Handle the case where some inputs are sparsely
+ # stacked. We should only call the converter if it supports handling
+ # those inputs.
+ new_outputs = converter(_PforInput(self, y_op, converted_inputs))
+ if isinstance(new_outputs, WrappedTensor):
+ new_outputs = [new_outputs]
+ assert isinstance(new_outputs,
+ (list, tuple, ops.Operation)), new_outputs
+ logging.vlog(2, "converted %s %s", y_op, new_outputs)
+
+ # Insert into self._conversion_map
+ if y == y_op:
+ assert isinstance(new_outputs, ops.Operation)
+ self._add_conversion(y_op, new_outputs)
+ else:
+ for old_output, new_output in zip(y_op.outputs, new_outputs):
+ assert isinstance(new_output, WrappedTensor), (new_output, y, y_op)
+ self._add_conversion(old_output, new_output)
+ stack.pop(0)
+
+ return self._conversion_map[op_or_tensor]
+
+ @property
+ def loop_len_vector(self):
+ """Returns a single element vector whose value is number of iterations."""
+ return self._loop_len_vector
+
+ @property
+ def loop_var(self):
+ """Returns placeholder loop variable."""
+ return self._loop_var
+
+ @property
+ def pfor_ops(self):
+ return self._pfor_ops
+
+ @property
+ def all_indices_partitioned(self):
+ """all_indices_partitioned property.
+
+ Returns:
+ True if we are inside a control flow construct and not all pfor iterations
+ may be active.
+ """
+ return self._all_indices_partitioned
+
+# nn_ops
+
+
+def _flatten_first_two_dims(x):
+ """Merges first two dimensions."""
+ old_shape = array_ops.shape(x)
+ new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+def _unflatten_first_dim(x, first_dim):
+ """Splits first dimension into [first_dim, -1]."""
+ old_shape = array_ops.shape(x)
+ new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+def _inputs_with_flattening(pfor_input, input_indices):
+ """Stacks and flattens first dim of inputs at indices `input_indices`."""
+ if input_indices is None:
+ input_indices = []
+ pfor_input.stack_inputs(stack_indices=input_indices)
+ inputs = []
+ for i in range(pfor_input.num_inputs):
+ if i in input_indices:
+ inp = pfor_input.stacked_input(i)
+ inp = _flatten_first_two_dims(inp)
+ else:
+ inp = pfor_input.unstacked_input(i)
+ inputs.append(inp)
+ return inputs
+
+
+@RegisterPForWithArgs("Conv2D", dims=[0])
+@RegisterPForWithArgs("AvgPool", dims=[0])
+@RegisterPForWithArgs("MaxPool", dims=[0])
+@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
+@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
+def _convert_flatten_batch(pfor_input, op_type, dims):
+ del op_type
+ inputs = _inputs_with_flattening(pfor_input, dims)
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs]
+ return [wrap(x, True) for x in outputs]
+
+
+_channel_flatten_input_cache = {}
+
+
+def _channel_flatten_input(x, data_format):
+ """Merge the stack dimension with the channel dimension.
+
+ If S is pfor's stacking dimension, then,
+ - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose
+ should be cheap.
+ - for SNHWC, we transpose to NHWCS.
+ We then merge the S and C dimension.
+
+ Args:
+ x: ops.Tensor to transform.
+ data_format: "NCHW" or "NHWC".
+
+ Returns:
+ A 3-element tuple with the transformed value, along with the shape for
+ reshape and order for transpose required to transform back.
+ """
+
+ graph = ops.get_default_graph()
+ cache_key = (graph, x, data_format)
+ if cache_key not in _channel_flatten_input_cache:
+ x_shape = array_ops.shape(x)
+ if data_format == b"NCHW":
+ order = [1, 0, 2, 3, 4]
+ shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0)
+ reverse_order = order
+ else:
+ order = [1, 2, 3, 0, 4]
+ shape = array_ops.concat([x_shape[1:4], [-1]], axis=0)
+ reverse_order = [3, 0, 1, 2, 4]
+ # Move S dimension next to C dimension.
+ x = array_ops.transpose(x, order)
+ reverse_shape = array_ops.shape(x)
+ # Reshape to merge the S and C dimension.
+ x = array_ops.reshape(x, shape)
+ outputs = x, reverse_order, reverse_shape
+ _channel_flatten_input_cache[cache_key] = outputs
+ else:
+ outputs = _channel_flatten_input_cache[cache_key]
+ return outputs
+
+
+# Note that with training=True, running FusedBatchNorm on individual examples
+# is very different from running FusedBatchNorm on a batch of those examples.
+# This is because, for the latter case, the operation can be considered as first
+# computing the mean and variance over all the examples and then using these
+# to scale all those examples. This creates a data dependency between these
+# different "iterations" since the inputs to the scaling step depends on the
+# statistics coming from all these inputs.
+# As with other kernels, the conversion here effectively runs the kernel
+# independently for each iteration, and returns outputs by stacking outputs from
+# each of those iterations.
+@RegisterPFor("FusedBatchNorm")
+def _convert_fused_batch_norm(pfor_input):
+ is_training = pfor_input.get_attr("is_training")
+ # When BatchNorm is used with training=False, mean and variance are provided
+ # externally and used as is by the op. Thus, we can merge the S and N
+ # dimensions as we do for regular operations.
+ # When BatchNorm is used with training=True, mean and variance are computed
+ # for each channel across the batch dimension (first one). If we merge S and N
+ # dimensions, mean and variances will be computed over a larger set. So, we
+ # merge the S and C dimensions instead.
+ if not is_training:
+ # We return zeros for batch_mean and batch_variance output. Note that CPU
+ # and GPU seem to have different behavior for those two outputs. CPU outputs
+ # zero because these values are not used during inference. GPU outputs
+ # something, probably real means and variances.
+ inputs = _inputs_with_flattening(pfor_input, [0])
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ y = outputs[0]
+ n = pfor_input.pfor.loop_len_vector
+ y = _unflatten_first_dim(y, n)
+ mean = pfor_input.unstacked_input(3)
+ zeros = array_ops.zeros_like(mean)
+ return [wrap(y, True), wrap(zeros, False), wrap(zeros, False)]
+
+ pfor_input.stack_inputs()
+ data_format = pfor_input.get_attr("data_format")
+ # We merge the first dimension with the "C" dimension, run FusedBatchNorm, and
+ # then transpose back.
+ x = pfor_input.stacked_input(0)
+ x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
+ # Note that we stack all the other inputs as well so that they are the same
+ # size as the new size of the channel dimension.
+ inputs = [x] + [
+ array_ops.reshape(pfor_input.stacked_input(i), [-1])
+ for i in range(1, pfor_input.num_inputs)
+ ]
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ y = outputs[0]
+ y = array_ops.reshape(y, reverse_shape)
+ y = array_ops.transpose(y, reverse_order)
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
+ outputs = [y] + outputs
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("FusedBatchNormGrad")
+def _convert_fused_batch_norm_grad(pfor_input):
+ pfor_input.stack_inputs()
+ data_format = pfor_input.get_attr("data_format")
+ y_backprop = pfor_input.stacked_input(0)
+ y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format)
+ x = pfor_input.stacked_input(1)
+ x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format)
+ inputs = [y_backprop, x] + [
+ array_ops.reshape(pfor_input.stacked_input(i), [-1])
+ for i in range(2, pfor_input.num_inputs)
+ ]
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ x_backprop = outputs[0]
+ x_backprop = array_ops.reshape(x_backprop, x_reverse_shape)
+ x_backprop = array_ops.transpose(x_backprop, x_reverse_order)
+ n = pfor_input.pfor.loop_len_vector
+ outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]]
+ outputs = [x_backprop] + outputs
+ return [wrap(output, True) for output in outputs]
+
+
+@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0)
+@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0)
+def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims,
+ shape_dim):
+ del op_type
+ inputs = _inputs_with_flattening(pfor_input, flatten_dims)
+ n = pfor_input.pfor.loop_len_vector
+ # Adjust the `input_sizes` input.
+ ones = array_ops.ones(
+ [array_ops.shape(inputs[shape_dim])[0] - 1], dtype=n.dtype)
+ inputs[shape_dim] *= array_ops.concat([n, ones], axis=0)
+ outputs = _create_op(
+ pfor_input.op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ outputs = [_unflatten_first_dim(x, n) for x in outputs]
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("Conv2DBackpropFilter")
+def _convert_conv2d_backprop_filter(pfor_input):
+ pfor_input.stack_inputs(stack_indices=[2])
+ inputs, inputs_stacked, _ = pfor_input.input(0)
+ filter_sizes = pfor_input.unstacked_input(1)
+ grads = pfor_input.stacked_input(2)
+ strides = pfor_input.get_attr("strides")
+ padding = pfor_input.get_attr("padding")
+ use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu")
+ data_format = pfor_input.get_attr("data_format")
+ dilations = pfor_input.get_attr("dilations")
+ if inputs_stacked:
+ # TODO(agarwal): Implement this efficiently.
+ logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!")
+
+ def while_body(i, ta):
+ inp_i = inputs[i, ...]
+ grad_i = grads[i, ...]
+ output = nn_ops.conv2d_backprop_filter(
+ inp_i,
+ filter_sizes,
+ grad_i,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format,
+ dilations=dilations)
+ return i + 1, ta.write(i, array_ops.expand_dims(output, 0))
+
+ n = array_ops.reshape(pfor_input.pfor.loop_len_vector, [])
+ _, ta = control_flow_ops.while_loop(
+ lambda i, ta: i < n, while_body,
+ (0, tensor_array_ops.TensorArray(inputs.dtype, n)))
+ output = ta.concat()
+ return wrap(output, True)
+ else:
+ # We merge the stack dimension with the channel dimension of the gradients
+ # and pretend we had a larger filter (see change to filter_sizes below).
+ # Once the filter backprop is computed, we reshape and transpose back
+ # appropriately.
+ grads, _, _ = _channel_flatten_input(grads, data_format)
+ n = pfor_input.pfor.loop_len_vector
+ old_filter_sizes = filter_sizes
+ filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0)
+ output = nn_ops.conv2d_backprop_filter(
+ inputs,
+ filter_sizes,
+ grads,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format,
+ dilations=dilations)
+ new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0)
+ output = array_ops.reshape(output, new_filter_shape)
+ output = array_ops.transpose(output, [3, 0, 1, 2, 4])
+ return wrap(output, True)
+
+
+# array_ops
+
+
+@RegisterPForWithArgs("Identity", array_ops.identity)
+@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
+def _convert_identity(pfor_input, op_type, op_func):
+ del op_type
+ return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("Reshape")
+def _convert_reshape(pfor_input):
+ t = pfor_input.stacked_input(0)
+ shape = pfor_input.unstacked_input(1)
+ new_dim = array_ops.shape(t)[:1]
+ new_shape = array_ops.concat([new_dim, shape], axis=0)
+ return wrap(array_ops.reshape(t, new_shape), True)
+
+
+@RegisterPFor("ExpandDims")
+def _convert_expanddims(pfor_input):
+ t = pfor_input.stacked_input(0)
+ dim = pfor_input.unstacked_input(1)
+ dim += math_ops.cast(dim >= 0, dtypes.int32)
+ return wrap(array_ops.expand_dims(t, axis=dim), True)
+
+
+@RegisterPFor("Slice")
+def _convert_slice(pfor_input):
+ t = pfor_input.stacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ size = pfor_input.unstacked_input(2)
+ begin = array_ops.concat([[0], begin], axis=0)
+ size = array_ops.concat([[-1], size], axis=0)
+ return wrap(array_ops.slice(t, begin, size), True)
+
+
+@RegisterPFor("Tile")
+def _convert_tile(pfor_input):
+ t = pfor_input.stacked_input(0)
+ multiples = pfor_input.unstacked_input(1)
+ multiples = array_ops.concat([[1], multiples], 0)
+ return wrap(array_ops.tile(t, multiples), True)
+
+
+@RegisterPFor("Pack")
+def _convert_pack(pfor_input):
+ pfor_input.stack_inputs()
+ axis = pfor_input.get_attr("axis")
+ if axis >= 0:
+ axis += 1
+ return wrap(
+ array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True)
+
+
+@RegisterPFor("Unpack")
+def _convert_unpack(pfor_input):
+ value = pfor_input.stacked_input(0)
+ axis = pfor_input.get_attr("axis")
+ if axis >= 0:
+ axis += 1
+ num = pfor_input.get_attr("num")
+ return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)]
+
+
+@RegisterPFor("Pad")
+def _convert_pad(pfor_input):
+ t = pfor_input.stacked_input(0)
+ paddings = pfor_input.unstacked_input(1)
+ paddings = array_ops.concat([[[0, 0]], paddings], 0)
+ return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True)
+
+
+@RegisterPFor("Split")
+def _convert_split(pfor_input):
+ split_dim = pfor_input.unstacked_input(0)
+ t = pfor_input.stacked_input(1)
+ num_split = pfor_input.get_attr("num_split")
+ split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
+ return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
+
+
+@RegisterPFor("Transpose")
+def _convert_transpose(pfor_input):
+ t = pfor_input.stacked_input(0)
+ perm = pfor_input.unstacked_input(1)
+ new_perm = array_ops.concat([[0], perm + 1], axis=0)
+ return wrap(array_ops.transpose(t, new_perm), True)
+
+
+@RegisterPFor("ZerosLike")
+def _convert_zeroslike(pfor_input):
+ t = pfor_input.stacked_input(0)
+ shape = array_ops.shape(t)[1:]
+ return wrap(array_ops.zeros(shape, dtype=t.dtype), False)
+
+
+@RegisterPFor("Gather")
+@RegisterPFor("GatherV2")
+def _convert_gather(pfor_input):
+ param, param_stacked, _ = pfor_input.input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ op_type = pfor_input.op_type
+ if op_type == "Gather":
+ validate_indices = pfor_input.get_attr("validate_indices")
+ axis = 0
+ else:
+ validate_indices = None
+ axis = pfor_input.unstacked_input(2)
+ axis_value = tensor_util.constant_value(axis)
+ if axis_value is not None:
+ axis = axis_value
+ if indices_stacked and not param_stacked:
+ if indices == pfor_input.pfor.all_indices and axis == 0:
+ param_shape0 = param.shape[0].value
+ indices_shape0 = indices.shape[0].value
+ if param_shape0 is not None and indices_shape0 == param_shape0:
+ # Note that with loops and conditionals, indices may not be contiguous.
+ # However they will be sorted and unique. So if the shape matches, then
+ # it must be picking up all the rows of param.
+ return wrap(param, True)
+ # TODO(agarwal): use array_ops.slice here.
+ output = array_ops.gather(
+ param, indices, validate_indices=validate_indices, axis=axis)
+ if axis != 0:
+ axis = control_flow_ops.cond(
+ axis < 0, lambda: axis + array_ops.rank(param), lambda: axis)
+ order = array_ops.concat(
+ [[axis],
+ math_ops.range(axis),
+ math_ops.range(axis + 1, array_ops.rank(output))],
+ axis=0)
+ output = control_flow_ops.cond(
+ math_ops.equal(axis, 0), lambda: output,
+ lambda: array_ops.transpose(output, order))
+ return wrap(output, True)
+ if param_stacked:
+ loop_len_vector = pfor_input.pfor.loop_len_vector
+ pfor_input.stack_inputs(stack_indices=[1])
+ indices = pfor_input.stacked_input(1)
+ param_flat = _flatten_first_two_dims(param)
+
+ # Recompute indices to handle stacked param.
+ indices_offset = math_ops.range(
+ loop_len_vector[0]) * array_ops.shape(param)[1]
+ # Reshape indices_offset to allow broadcast addition
+ ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32)
+ new_shape = array_ops.concat([loop_len_vector, ones], axis=0)
+ indices_offset = array_ops.reshape(indices_offset, new_shape)
+ indices += indices_offset
+
+ # TODO(agarwal): handle axis != 0. May need to transpose param or
+ # array_ops.gather_nd.
+ if isinstance(axis, ops.Tensor):
+ axis_value = tensor_util.constant_value(axis)
+ else:
+ try:
+ axis_value = int(axis)
+ except TypeError:
+ axis_value = None
+ msg = ("Gather, where indices and param are both loop dependent, currently "
+ "requires axis=0")
+ if axis_value is not None and axis_value != 0:
+ raise ValueError("Error while converting %s. %s. Got axis=%d" %
+ (pfor_input.op, msg, axis))
+ with ops.control_dependencies(
+ [check_ops.assert_equal(axis, 0, message=msg)]):
+ output = array_ops.gather(param_flat, indices)
+ return wrap(output, True)
+
+
+@RegisterPFor("ConcatV2")
+def _convert_concatv2(pfor_input):
+ n = pfor_input.num_inputs
+ pfor_input.stack_inputs(stack_indices=range(n - 1))
+ axis = pfor_input.unstacked_input(n - 1)
+ axis += math_ops.cast(axis >= 0, axis.dtype)
+ return wrap(
+ array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis),
+ True)
+
+
+@RegisterPFor("StridedSlice")
+def _convert_strided_slice(pfor_input):
+ inp = pfor_input.stacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ end = pfor_input.unstacked_input(2)
+ strides = pfor_input.unstacked_input(3)
+ begin_mask = pfor_input.get_attr("begin_mask")
+ end_mask = pfor_input.get_attr("end_mask")
+ ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
+ new_axis_mask = pfor_input.get_attr("new_axis_mask")
+ shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
+
+ begin = array_ops.concat([[0], begin], axis=0)
+ end = array_ops.concat([[0], end], axis=0)
+ strides = array_ops.concat([[1], strides], axis=0)
+ begin_mask = begin_mask << 1 | 1
+ end_mask = end_mask << 1 | 1
+ ellipsis_mask <<= 1
+ new_axis_mask <<= 1
+ shrink_axis_mask <<= 1
+ return wrap(
+ array_ops.strided_slice(
+ inp,
+ begin,
+ end,
+ strides,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask), True)
+
+
+@RegisterPFor("StridedSliceGrad")
+def _convert_strided_slice_grad(pfor_input):
+ shape = pfor_input.unstacked_input(0)
+ begin = pfor_input.unstacked_input(1)
+ end = pfor_input.unstacked_input(2)
+ strides = pfor_input.unstacked_input(3)
+ dy = pfor_input.stacked_input(4)
+ begin_mask = pfor_input.get_attr("begin_mask")
+ end_mask = pfor_input.get_attr("end_mask")
+ ellipsis_mask = pfor_input.get_attr("ellipsis_mask")
+ new_axis_mask = pfor_input.get_attr("new_axis_mask")
+ shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask")
+
+ shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0)
+ begin = array_ops.concat([[0], begin], axis=0)
+ end = array_ops.concat([[0], end], axis=0)
+ strides = array_ops.concat([[1], strides], axis=0)
+ begin_mask = begin_mask << 1 | 1
+ end_mask = end_mask << 1 | 1
+ ellipsis_mask <<= 1
+ new_axis_mask <<= 1
+ shrink_axis_mask <<= 1
+ return wrap(
+ array_ops.strided_slice_grad(
+ shape,
+ begin,
+ end,
+ strides,
+ dy,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask), True)
+
+
+# math_ops
+
+
+@RegisterPFor("MatMul")
+def _convert_matmul(pfor_input):
+ # TODO(agarwal): Check if tiling is faster than two transposes.
+ a, a_stacked, _ = pfor_input.input(0)
+ b, b_stacked, _ = pfor_input.input(1)
+ tr_a = pfor_input.get_attr("transpose_a")
+ tr_b = pfor_input.get_attr("transpose_b")
+ if a_stacked and b_stacked:
+ output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True)
+ return output
+ elif a_stacked:
+ if tr_a:
+ a = array_ops.transpose(a, [0, 2, 1])
+ if a.shape.is_fully_defined():
+ x, y, z = a.shape
+ else:
+ x, y, z = [
+ array_ops.reshape(i, [])
+ for i in array_ops.split(array_ops.shape(a), 3)
+ ]
+ a = array_ops.reshape(a, [x * y, z])
+ prod = math_ops.matmul(a, b, transpose_b=tr_b)
+ return wrap(array_ops.reshape(prod, [x, y, -1]), True)
+ else:
+ assert b_stacked
+ if tr_b:
+ perm = [2, 0, 1]
+ b = array_ops.transpose(b, perm)
+ else:
+ # As an optimization, if one of the first two dimensions is 1, then we can
+ # reshape instead of transpose.
+ # TODO(agarwal): This check can be done inside Transpose kernel.
+ b_shape = array_ops.shape(b)
+ min_dim = math_ops.minimum(b_shape[0], b_shape[1])
+ perm = control_flow_ops.cond(
+ math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2])
+ new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]])
+ b = array_ops.transpose(b, perm)
+ b = array_ops.reshape(b, new_shape)
+
+ if b.shape.is_fully_defined():
+ x, y, z = b.shape
+ else:
+ x, y, z = [
+ array_ops.reshape(i, [])
+ for i in array_ops.split(array_ops.shape(b), 3)
+ ]
+ b = array_ops.reshape(b, [x, y * z])
+ prod = math_ops.matmul(a, b, transpose_a=tr_a)
+ prod = array_ops.reshape(prod, [-1, y, z])
+ prod = array_ops.transpose(prod, [1, 0, 2])
+ return wrap(prod, True)
+
+
+@RegisterPFor("BatchMatMul")
+def _convert_batch_mat_mul(pfor_input):
+ # TODO(agarwal): There may be a more efficient way to do this instead of
+ # stacking the inputs.
+ pfor_input.stack_inputs()
+ x = pfor_input.stacked_input(0)
+ y = pfor_input.stacked_input(1)
+ adj_x = pfor_input.get_attr("adj_x")
+ adj_y = pfor_input.get_attr("adj_y")
+
+ x = _flatten_first_two_dims(x)
+ y = _flatten_first_two_dims(y)
+ output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
+ output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector)
+ return wrap(output, True)
+
+
+@RegisterPForWithArgs("Sum", math_ops.reduce_sum)
+@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
+@RegisterPForWithArgs("Max", math_ops.reduce_max)
+@RegisterPForWithArgs("Min", math_ops.reduce_min)
+def _convert_reduction(pfor_input, _, op_func):
+ t = pfor_input.stacked_input(0)
+ indices = pfor_input.unstacked_input(1)
+ # Shift positive indices by one to account for the extra dimension.
+ indices += math_ops.cast(indices >= 0, dtypes.int32)
+ keep_dims = pfor_input.get_attr("keep_dims")
+ return wrap(op_func(t, indices, keepdims=keep_dims), True)
+
+
+@RegisterPForWithArgs("Cumsum", math_ops.cumsum)
+@RegisterPForWithArgs("Cumprod", math_ops.cumprod)
+def _convert_cumfoo(pfor_input, _, op_func):
+ t = pfor_input.stacked_input(0)
+ axis = pfor_input.unstacked_input(1)
+ # Shift positive indices by one to account for the extra dimension.
+ axis += math_ops.cast(axis >= 0, dtypes.int32)
+ exclusive = pfor_input.get_attr("exclusive")
+ reverse = pfor_input.get_attr("reverse")
+ return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True)
+
+
+@RegisterPFor("BiasAdd")
+def _convert_biasadd(pfor_input):
+ t = pfor_input.stacked_input(0)
+ bias = pfor_input.unstacked_input(1)
+ data_format = pfor_input.get_attr("data_format")
+ if data_format != b"NCHW":
+ return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
+ shape = array_ops.shape(t)
+ flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
+ t = array_ops.reshape(t, flattened_shape)
+ t = nn_ops.bias_add(t, bias, data_format=b"NCHW")
+ t = array_ops.reshape(t, shape)
+ return wrap(t, True)
+
+
+@RegisterPFor("UnsortedSegmentSum")
+def _convert_unsortedsegmentsum(pfor_input):
+ data, data_stacked, _ = pfor_input.input(0)
+ # TODO(agarwal): handle unstacked?
+ segment_ids = pfor_input.stacked_input(1)
+ # TODO(agarwal): handle stacked?
+ num_segments = pfor_input.unstacked_input(2)
+ if not data_stacked:
+ data = _stack(data, pfor_input.pfor.loop_len_vector).t
+ segment_shape = array_ops.shape(segment_ids)
+ n = segment_shape[0]
+ ones = array_ops.ones_like(segment_shape)[1:]
+ segment_offset = num_segments * math_ops.range(n)
+ segment_offset = array_ops.reshape(segment_offset,
+ array_ops.concat([[n], ones], axis=0))
+ segment_ids += segment_offset
+ num_segments *= n
+ output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments)
+ new_output_shape = array_ops.concat(
+ [[n, -1], array_ops.shape(output)[1:]], axis=0)
+ output = array_ops.reshape(output, new_output_shape)
+ return wrap(output, True)
+
+
+@RegisterPFor("Cast")
+def _convert_cast(pfor_input):
+ inp = pfor_input.stacked_input(0)
+ dtype = pfor_input.get_attr("DstT")
+ return wrap(math_ops.cast(inp, dtype), True)
+
+
+# Note that ops handled here do not have attributes except "T", and hence don't
+# need extra arguments passed to the cwise_op call below.
+@RegisterPForWithArgs("Add", math_ops.add)
+@RegisterPForWithArgs("Ceil", math_ops.ceil)
+@RegisterPForWithArgs("Equal", math_ops.equal)
+@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Floor", math_ops.floor)
+@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
+@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
+@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
+@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
+@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
+@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
+@RegisterPForWithArgs("Maximum", math_ops.maximum)
+@RegisterPForWithArgs("Minimum", math_ops.minimum)
+@RegisterPForWithArgs("Mul", math_ops.multiply)
+@RegisterPForWithArgs("Neg", math_ops.negative)
+@RegisterPForWithArgs("RealDiv", math_ops.divide)
+@RegisterPForWithArgs("Relu", nn_ops.relu)
+@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
+@RegisterPForWithArgs("Square", math_ops.square)
+@RegisterPForWithArgs("Sub", math_ops.subtract)
+@RegisterPForWithArgs("Tanh", math_ops.tanh)
+def _convert_cwise(pfor_input, op_type, op_func):
+ del op_type
+ pfor_input.expanddim_inputs_for_broadcast()
+ return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("Shape")
+def _convert_shape(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ return wrap(
+ array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:],
+ False)
+
+
+@RegisterPFor("ShapeN")
+def _convert_shape_n(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ shapes = [
+ array_ops.shape(x, out_type=out_type)[1:]
+ if stacked else array_ops.shape(x) for x, stacked, _ in pfor_input.inputs
+ ]
+ return [wrap(x, False) for x in shapes]
+
+
+@RegisterPFor("Size")
+def _convert_size(pfor_input):
+ out_type = pfor_input.get_attr("out_type")
+ n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type)
+ return wrap(
+ array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n,
+ False)
+
+
+@RegisterPFor("Rank")
+def _convert_rank(pfor_input):
+ return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False)
+
+
+@RegisterPFor("AddN")
+def _convert_addn(pfor_input):
+ # AddN does not support broadcasting.
+ pfor_input.stack_inputs()
+ return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True)
+
+
+@RegisterPFor("BiasAddGrad")
+def _convert_biasaddgrad(pfor_input):
+ grad = pfor_input.stacked_input(0)
+ fmt = pfor_input.get_attr("data_format")
+ if fmt == b"NCHW":
+ output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False)
+ else:
+ grad_shape = array_ops.shape(grad)
+ last_dim_shape = grad_shape[-1]
+ first_dim_shape = grad_shape[0]
+ output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape])
+ output = math_ops.reduce_sum(output, axis=[1], keepdims=False)
+ return wrap(output, True)
+
+
+# Some required ops are not exposed under the tf namespace. Hence relying on
+# _create_op to create them.
+@RegisterPForWithArgs("ReluGrad")
+@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SigmoidGrad")
+def _convert_grads(pfor_input, op_type, *args, **kw_args):
+ del args
+ del kw_args
+ # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we
+ # have to use tiling here.
+ pfor_input.stack_inputs()
+ outputs = _create_op(
+ op_type, [x.t for x in pfor_input.inputs],
+ [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+@RegisterPFor("Select")
+def _convert_select(pfor_input):
+ pfor_input.stack_inputs()
+ cond = pfor_input.stacked_input(0)
+ t = pfor_input.stacked_input(1)
+ e = pfor_input.stacked_input(2)
+ cond_rank = array_ops.rank(cond)
+ cond, t, e = control_flow_ops.cond(
+ cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]),
+ lambda: [cond, t, e])
+ outputs = _create_op(
+ pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ n = pfor_input.pfor.loop_len_vector
+ out = control_flow_ops.cond(cond_rank > 1,
+ lambda: _unflatten_first_dim(outputs[0], n),
+ lambda: outputs[0])
+ return [wrap(out, True) for x in outputs]
+
+
+# random_ops
+
+
+@RegisterPForWithArgs("RandomUniform")
+@RegisterPForWithArgs("RandomUniformInt")
+@RegisterPForWithArgs("RandomStandardNormal")
+@RegisterPForWithArgs("TruncatedNormal")
+@RegisterPForWithArgs("RandomGamma")
+@RegisterPForWithArgs("RandomPoissonV2")
+def _convert_random(pfor_input, op_type, *args, **kw_args):
+ del args
+ del kw_args
+ inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)]
+ # inputs[0] is "shape"
+ inputs[0] = array_ops.concat(
+ [pfor_input.pfor.loop_len_vector, inputs[0]], axis=0)
+ logging.warning(
+ "Note that %s inside pfor op may not give same output as "
+ "inside a sequential loop.", op_type)
+ outputs = _create_op(
+ op_type,
+ inputs, [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+# logging_ops
+
+
+@RegisterPFor("Assert")
+def _convert_assert(pfor_input):
+ cond, cond_stacked, _ = pfor_input.input(0)
+ if cond_stacked:
+ cond = math_ops.reduce_all(cond)
+
+ data_list = [x.t for x in pfor_input.inputs][1:]
+ return _create_op("Assert", [cond] + data_list, [],
+ attrs=pfor_input.op.node_def.attr)
+
+
+@RegisterPFor("Print")
+def _convert_print(pfor_input):
+ # Note that we don't stack all the inputs. Hence unstacked values are printed
+ # once here vs multiple times in a while_loop.
+ pfor_input.stack_inputs([0])
+ outputs = _create_op(
+ "Print", [x.t for x in pfor_input.inputs],
+ [x.dtype for x in pfor_input.outputs],
+ attrs=pfor_input.op.node_def.attr).outputs
+ return [wrap(x, True) for x in outputs]
+
+
+# data_flow_ops
+
+# TensorArray conversion is tricky since we don't support arrays of
+# TensorArrays. For converting them, we consider two distinct cases:
+#
+# 1. The array is constructed outside the pfor call, and read/written inside the
+# loop.
+# This is an easier case since we don't need to make an array of TensorArrays.
+# A correctness requirement is that these parallel iterations shouldn't attempt
+# to write to the same location. Hence at conversion time we disallow indices to
+# be loop-invariant as that would guarantee a collision. Even if the indices are
+# not loop-invariant, they could conflict and that shall trigger runtime errors.
+#
+# 2. The array is constructed and used entirely inside each pfor iteration.
+# For simplicity, here we require that the indices used for write/scatter are
+# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
+# different pfor iterations. We consider two sub_cases:
+#
+# 2a Elements written to the array are "stacked"
+# To simulate multiple TensorArrays, we may increase the dimension of each
+# element of the array. i.e. the i_th row of the j_th entry of the converted
+# TensorArray corresponds to to the j_th entry of the TensorArray in the i_th
+# pfor iteration.
+#
+# 2b Elements written to the array are "unstacked"
+# In this case we don't increase the dimensions to avoid redundant tiling. Each
+# iteration is trying to write the same value. So we convert that to a single
+# write.
+#
+# Here are some tricks used to implement the above:
+# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of
+# trying to trace whether future writes are stacked or unstacked in order to set
+# this attr, we set it to correspond to unknown shape.
+# - We use the "flow" output of the different ops to track whether the array
+# elements are stacked or unstacked. If a stacked write/scatter is done, we make
+# the flow stacked as well.
+# - We use some heuristic traversal of the graph to track whether the
+# TensorArray handle was created inside or outside the pfor loop.
+
+
+@RegisterPFor("TensorArrayV3")
+def _convert_tensor_array_v3(pfor_input):
+ size = pfor_input.unstacked_input(0)
+ dtype = pfor_input.get_attr("dtype")
+ dynamic_size = pfor_input.get_attr("dynamic_size")
+ clear_after_read = pfor_input.get_attr("clear_after_read")
+ identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
+ tensor_array_name = pfor_input.get_attr("tensor_array_name")
+ handle, flow = data_flow_ops.tensor_array_v3(
+ size,
+ dtype=dtype,
+ # We don't set element shape since we don't know if writes are stacked or
+ # not yet.
+ element_shape=None,
+ dynamic_size=dynamic_size,
+ clear_after_read=clear_after_read,
+ identical_element_shapes=identical_element_shapes,
+ tensor_array_name=tensor_array_name)
+ # Note we keep flow unstacked for now since we don't know if writes will be
+ # stacked or not.
+ return wrap(handle, False), wrap(flow, False)
+
+
+@RegisterPFor("TensorArraySizeV3")
+def _convert_tensor_array_size_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ flow, flow_stacked, _ = pfor_input.input(1)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ size = data_flow_ops.tensor_array_size_v3(handle, flow)
+ return wrap(size, False)
+
+
+def _handle_inside_pfor(pfor_input, handle):
+ """Returns True if handle was created inside the pfor loop."""
+ # We use some heuristic to find the original TensorArray creation op.
+ # The logic should handle the common cases (except cond based subgraphs).
+ # In theory the user could perform different operations on the handle (like
+ # Reshape, stack multiple handles, etc) which could break this logic.
+ # TODO(agarwal): handle Switch/Merge.
+ while handle.op.type in ("Enter", "Identity"):
+ handle = handle.op.inputs[0]
+ if handle.op.type not in [
+ "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"]:
+ raise ValueError("Unable to find source for handle %s" % handle)
+ else:
+ return pfor_input.pfor.op_is_inside_loop(handle.op)
+
+
+def _unstack_flow(value):
+ # TODO(agarwal): consider looking if this is a Tile op then get its input.
+ # This may avoid running the Tile operations.
+ return array_ops.gather(value, 0)
+
+
+@RegisterPFor("TensorArrayReadV3")
+def _convert_tensor_array_read_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ index, index_stacked, _ = pfor_input.input(1)
+ dtype = pfor_input.get_attr("dtype")
+ flow, flow_stacked, _ = pfor_input.input(2)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+
+ is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside_pfor:
+ # Note that if we are inside a control flow construct inside the pfor, and
+ # only some of the iterations are doing the read (i.e.
+ # `all_indices_partitioned` is True), then the read operation should only
+ # return values for the currently active pfor iterations (`all_indices`
+ # below). Hence, whenever the returned value is stacked (i.e. `flow` is
+ # stacked), we may need to do an extra gather after reading the values. Also
+ # note that if `is_inside` is false, then values in the tensor array are
+ # unstacked. So the check is only needed in this branch.
+ all_indices = pfor_input.pfor.all_indices
+ all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
+ # Note: flow_stacked indicates if values in the TensorArray are stacked or
+ # not.
+ if index_stacked:
+ if flow_stacked:
+ raise ValueError(
+ "It looks like TensorArrayReadV3 was called on a TensorArray whose"
+ " values are not loop-invariant, and the read indices were also"
+ " not loop invariant. This is currently unsupported.")
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, index, flow, dtype=dtype)
+ return wrap(value, True)
+ value = data_flow_ops.tensor_array_read_v3(
+ handle, index, flow, dtype=dtype)
+ if flow_stacked and all_indices_partitioned:
+ value = array_ops.gather(value, all_indices)
+ return wrap(value, flow_stacked)
+ # Values in the TensorArray should be unstacked (since different iterations
+ # couldn't write to the same location). So whether output is stacked or not
+ # depends on index_stacked.
+ if index_stacked:
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, index, flow, dtype=dtype)
+ else:
+ value = data_flow_ops.tensor_array_read_v3(
+ handle, index, flow, dtype=dtype)
+ return wrap(value, index_stacked)
+
+
+@RegisterPFor("TensorArrayWriteV3")
+def _convert_tensor_array_write_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ index, index_stacked, _ = pfor_input.input(1)
+ value, value_stacked, _ = pfor_input.input(2)
+ flow, flow_stacked, _ = pfor_input.input(3)
+ if value_stacked and pfor_input.pfor.all_indices_partitioned:
+ # Looks like we are in a control flow in a pfor where not all iterations are
+ # active now. We don't allow that since that could lead to different indices
+ # having different shapes which will be hard to merge later.
+ raise ValueError("Writing non loop invariant values to TensorArray from "
+ "inside a while_loop/cond not supported.")
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ if index_stacked:
+ raise ValueError("Need indices for %s to be loop invariant" % handle)
+ if not flow_stacked and not value_stacked:
+ flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
+ return wrap(flow_out, False)
+ else:
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
+ # this may or may not be a safe situation. flow is unstacked both for a
+ # freshly created TensorArray, as well as after unstacked values are
+ # written to it. If it is the latter, then we cannot write a stacked value
+ # now since that may cause runtime errors due to different shapes in the
+ # array. At the moment we are not able to handle this gracefully and
+ # distinguish between the two cases. That would require some heuristic
+ # traversal of the graph to figure out whether all the writes are
+ # unstacked or not.
+ flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+ else:
+ if not index_stacked:
+ raise ValueError("Need indices for %s to be not loop invariant" % handle)
+ # Note that even when index_stacked is true, actual values in index may
+ # still not be unique. However that will cause runtime error when executing
+ # the scatter operation below.
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+
+
+def _transpose_first_two_dims(value):
+ # TODO(agarwal): optimize if one of the dims == 1.
+ value_shape = array_ops.shape(value)
+ v0 = value_shape[0]
+ v1 = value_shape[1]
+ value = array_ops.reshape(value, [v0, v1, -1])
+ value = array_ops.transpose(value, [1, 0, 2])
+ new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
+ return array_ops.reshape(value, new_shape)
+
+
+@RegisterPFor("TensorArrayGatherV3")
+def _convert_tensor_array_gather_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ indices = array_ops.reshape(indices, [-1])
+ flow, flow_stacked, _ = pfor_input.input(2)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ dtype = pfor_input.get_attr("dtype")
+ # TODO(agarwal): support element_shape attr?
+
+ n = pfor_input.pfor.loop_len_vector
+ value = data_flow_ops.tensor_array_gather_v3(
+ handle, indices, flow, dtype=dtype)
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ # flow_stacked indicates if values in the TensorArray are stacked or not.
+ if indices_stacked:
+ if flow_stacked:
+ raise ValueError(
+ "It looks like TensorArrayGatherV3 was called on a TensorArray "
+ "whose values are not loop-invariant, and the indices were also "
+ "not loop invariant. This is currently unsupported.")
+ else:
+ value = _unflatten_first_dim(value, n)
+ return wrap(value, True)
+ else:
+ if flow_stacked:
+ # Since elements in this array are stacked and `value` was produced by
+ # gather, its first two dims are "gathered elements" and "stack
+ # dimension". Our semantics require these two to be flipped.
+ value = _transpose_first_two_dims(value)
+ return wrap(value, flow_stacked)
+ else:
+ # Values in the TensorArray should be unstacked (since different iterations
+ # couldn't write to the same location). So whether output is stacked or not
+ # depends on indices_stacked.
+ if indices_stacked:
+ value = _unflatten_first_dim(value, n)
+ return wrap(value, indices_stacked)
+
+
+@RegisterPFor("TensorArrayScatterV3")
+def _convert_tensor_array_scatter_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ indices, indices_stacked, _ = pfor_input.input(1)
+ indices = array_ops.reshape(indices, [-1])
+ value, value_stacked, _ = pfor_input.input(2)
+ flow, flow_stacked, _ = pfor_input.input(3)
+
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+
+ is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
+ if is_inside:
+ if indices_stacked:
+ raise ValueError("Need indices for %s to be loop invariant" % handle)
+ # Note that flow_stacked indicates if existing values in the array are
+ # stacked or not.
+ if not flow_stacked and not value_stacked:
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return wrap(flow_out, False)
+ if not value_stacked:
+ # TODO(agarwal): tile in the second dimension directly instead of
+ # transposing below.
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+
+ value = _transpose_first_two_dims(value)
+ # TODO(agarwal): Note that if a previous write was unstacked, flow will be
+ # unstacked, and a stacked value may be written here which may cause
+ # runtime error due to different elements having different shape. We do
+ # not try to prevent that.
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+ if not indices_stacked:
+ raise ValueError("Need indices for %s to be not loop invariant" % handle)
+ if not value_stacked:
+ value = _stack(value, pfor_input.pfor.loop_len_vector).t
+ value = _flatten_first_two_dims(value)
+ flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
+ flow)
+ return _stack(flow_out, pfor_input.pfor.loop_len_vector)
+
+
+@RegisterPFor("TensorArrayGradV3")
+def _convert_tensor_array_grad_v3(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ flow, flow_stacked, _ = pfor_input.input(1)
+ if flow_stacked:
+ flow = _unstack_flow(flow)
+ source = pfor_input.get_attr("source")
+ # TODO(agarwal): For now, we assume that gradients are stacked if the
+ # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
+ # will give runtime error due to incorrect shape being written to the
+ # accumulator. It is difficult to know in advance if gradients written will be
+ # stacked or not. Note that flow being stacked is not indicative of the
+ # gradient being stacked or not. Revisit this later.
+ shape_to_prepend = pfor_input.pfor.loop_len_vector
+ grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
+ handle=handle,
+ flow_in=flow,
+ shape_to_prepend=shape_to_prepend,
+ source=source)
+ flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
+ return [wrap(grad_handle, False), wrap(flow_out, True)]
+
+
+# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar
+# to TensorArrays, we convert them by changing the dimension of the elements
+# inside the stack.
+#
+# We consider two cases:
+#
+# 1. StackV2 is constructed and used entirely inside the pfor loop.
+# We keep a single Stack and perform the push/pop operations of all the
+# iterations in lock-step. We also assume that all the iterations perform these
+# operations. In case of dynamic control flow, if only some of the iterations
+# try to perform a push/pop, then the conversion may not work correctly and may
+# cause undefined behavior.
+# TODO(agarwal): test StackV2 with dynamic control flow.
+#
+# 2. StackV2 is constructed outside the pfor loop.
+# Performing stack push/pop in a parallel fashion is ill-defined. However given
+# that reading stacks created externally is a common operation when computing
+# jacobians, we provide some special semantics here as follows.
+# - disallow push operations to the stack
+# - pop operations are performed in lock step by all iterations, similar to the
+# case when the stack is created inside. A single value is popped during the
+# lock-step operation and broadcast to all the iterations. Values in the stack
+# are assumed to be loop-invariant.
+#
+# Some other implementation details:
+# We use an ugly logic to find whether values in Stack data structure are
+# loop invariant or not. When converting push/pop operations, we keep track of
+# whether the last conversion used a stacked value or not (see _stack_cache
+# below). As a result if an unstacked value is written first, subsequent stacked
+# writes are disallowed when they could have been allowed in theory.
+
+# Map from cache key based on StackV2 handle to a bool indicating whether values
+# are stacked or not.
+# TODO(agarwal): move _stack_cache inside pfor?
+_stack_cache = {}
+
+
+def _stack_cache_key(pfor_input):
+ """Create cache key corresponding to a stack handle."""
+ op_type = pfor_input.op_type
+ assert op_type in ["StackPushV2", "StackPopV2"], op_type
+ orig_handle = pfor_input.op.inputs[0]
+ while orig_handle.op.type in ["Identity", "Enter"]:
+ orig_handle = orig_handle.op.inputs[0]
+ assert orig_handle.op.type == "StackV2", orig_handle.op
+ return ops.get_default_graph(), pfor_input.pfor, orig_handle
+
+
+def _stack_handle_inside_pfor(handle, pfor_input):
+ while handle.op.type in ["Identity", "Enter"]:
+ handle = handle.op.inputs[0]
+ assert handle.op.type == "StackV2", (
+ "Unable to find StackV2 op. Got %s" % handle.op)
+ return pfor_input.pfor.op_is_inside_loop(handle.op)
+
+
+@RegisterPFor("StackPushV2")
+def _convert_stack_push_v2(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ elem, elem_stacked, _ = pfor_input.input(1)
+ swap_memory = pfor_input.get_attr("swap_memory")
+
+ if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input):
+ raise ValueError("StackPushV2 not allowed on stacks created outside pfor")
+ stack_cache_key = _stack_cache_key(pfor_input)
+ stacked = _stack_cache.get(stack_cache_key, None)
+ if stacked is None:
+ stacked = elem_stacked
+ _stack_cache[stack_cache_key] = stacked
+ else:
+ # If we previously made it unstacked then we can't revert to being stacked.
+ if not stacked and elem_stacked:
+ raise ValueError(
+ "It looks like the stack was previously determined to be loop"
+ " invariant, but we are now trying to push a loop dependent value"
+ " to it. This is currently unsupported.")
+ if stacked and not elem_stacked:
+ elem = _stack(elem, pfor_input.pfor.loop_len_vector).t
+ out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory)
+ return wrap(out, stacked)
+
+
+# Note that inputs to this convertor will be unstacked. However it should get
+# called since it is a stateful op.
+@RegisterPFor("StackPopV2")
+def _convert_stack_pop_v2(pfor_input):
+ handle = pfor_input.unstacked_input(0)
+ stack_cache_key = _stack_cache_key(pfor_input)
+ stacked = _stack_cache.get(stack_cache_key, None)
+ # If a StackPushV2 has not been converted yet, we default to unstacked since
+ # the push could be outside of pfor, or the covertor may not be called if the
+ # inputs are unconverted.
+ if stacked is None:
+ stacked = False
+ _stack_cache[stack_cache_key] = False
+ elem_type = pfor_input.get_attr("elem_type")
+ out = data_flow_ops.stack_pop_v2(handle, elem_type)
+ return wrap(out, stacked)
+
+
+# parsing_ops
+
+
+@RegisterPFor("DecodeCSV")
+def _convert_decode_csv(pfor_input):
+ lines = pfor_input.stacked_input(0)
+ record_defaults = [
+ pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
+ ]
+ field_delim = pfor_input.get_attr("field_delim")
+ use_quote_delim = pfor_input.get_attr("use_quote_delim")
+ select_cols = pfor_input.get_attr("select_cols")
+ if not select_cols:
+ select_cols = None
+ return [
+ wrap(t, True) for t in parsing_ops.decode_csv(
+ lines,
+ record_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ select_cols=select_cols)
+ ]
+
+
+@RegisterPFor("ParseSingleExample")
+def _convert_parse_single_example(pfor_input):
+ serialized = pfor_input.stacked_input(0)
+ dense_defaults = [
+ pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
+ ]
+ sparse_keys = pfor_input.get_attr("sparse_keys")
+ dense_keys = pfor_input.get_attr("dense_keys")
+ sparse_types = pfor_input.get_attr("sparse_types")
+ dense_shapes = pfor_input.get_attr("dense_shapes")
+ output = gen_parsing_ops.parse_example(
+ serialized=serialized,
+ names=[],
+ dense_defaults=dense_defaults,
+ sparse_keys=sparse_keys,
+ dense_keys=dense_keys,
+ sparse_types=sparse_types,
+ dense_shapes=dense_shapes)
+ return [wrap(t, True, True) for t in nest.flatten(output)]
diff --git a/tensorflow/python/ops/random_grad.py b/tensorflow/python/ops/random_grad.py
new file mode 100644
index 0000000000..baa8e2e2cd
--- /dev/null
+++ b/tensorflow/python/ops/random_grad.py
@@ -0,0 +1,65 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for operators defined in random_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_random_ops
+from tensorflow.python.ops import math_ops
+
+
+def add_leading_unit_dimensions(x, num_dimensions):
+ new_shape = array_ops.concat(
+ [array_ops.ones([num_dimensions], dtype=dtypes.int32),
+ array_ops.shape(x)], axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+@ops.RegisterGradient("RandomGamma")
+def _RandomGammaGrad(op, grad): # pylint: disable=invalid-name
+ """Returns the gradient of a Gamma sample w.r.t. alpha.
+
+ The gradient is computed using implicit differentiation, see
+ "Implicit Reparameterization Gradients" (https://arxiv.org/abs/1805.08498).
+
+ Args:
+ op: A `RandomGamma` operation. We assume that the inputs to the operation
+ are `shape` and `alpha` tensors, and the output is the `sample` tensor.
+ grad: The incoming gradient `dloss / dsample` of the same shape as
+ `op.outputs[0]`.
+
+ Returns:
+ A `Tensor` with derivatives `dloss / dalpha`
+ """
+ shape = op.inputs[0]
+ alpha = op.inputs[1]
+ sample = op.outputs[0]
+
+ with ops.control_dependencies([grad]):
+ # Make the parameters alpha broadcastable with samples by appending
+ # unit dimensions.
+ num_sample_dimensions = array_ops.shape(shape)[0]
+ alpha_broadcastable = add_leading_unit_dimensions(
+ alpha, num_sample_dimensions)
+ partial_a = gen_random_ops.random_gamma_grad(alpha_broadcastable, sample)
+
+ # The first input is shape; the second input is alpha.
+ return (None, math_ops.reduce_sum(
+ grad * partial_a, axis=math_ops.range(num_sample_dimensions)))
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 6a2dd3f1cd..b8738adf66 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -368,25 +368,41 @@ def random_gamma(shape,
`alpha` is the shape parameter describing the distribution(s), and `beta` is
the inverse scale parameter(s).
- Example:
+ Note: Because internal calculations are done using `float64` and casting has
+ `floor` semantics, we must manually map zero outcomes to the smallest
+ possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This
+ means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
+ should. This bias can only happen for small values of `alpha`, i.e.,
+ `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.
- samples = tf.random_gamma([10], [0.5, 1.5])
- # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
- # the samples drawn from each distribution
+ The samples are differentiable w.r.t. alpha and beta.
+ The derivatives are computed using the approach described in the paper
- samples = tf.random_gamma([7, 5], [0.5, 1.5])
- # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
- # represents the 7x5 samples drawn from each of the two distributions
+ [Michael Figurnov, Shakir Mohamed, Andriy Mnih.
+ Implicit Reparameterization Gradients, 2018](https://arxiv.org/abs/1805.08498)
- samples = tf.random_gamma([30], [[1.],[3.],[5.]], beta=[[3., 4.]])
- # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.
+ Example:
- Note: Because internal calculations are done using `float64` and casting has
- `floor` semantics, we must manually map zero outcomes to the smallest
- possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. This
- means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
- should. This bias can only happen for small values of `alpha`, i.e.,
- `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.
+ ```python
+ samples = tf.random_gamma([10], [0.5, 1.5])
+ # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
+ # the samples drawn from each distribution
+
+ samples = tf.random_gamma([7, 5], [0.5, 1.5])
+ # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
+ # represents the 7x5 samples drawn from each of the two distributions
+
+ alpha = tf.constant([[1.],[3.],[5.]])
+ beta = tf.constant([[3., 4.]])
+ samples = tf.random_gamma([30], alpha=alpha, beta=beta)
+ # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.
+
+ loss = tf.reduce_mean(tf.square(samples))
+ dloss_dalpha, dloss_dbeta = tf.gradients(loss, [alpha, beta])
+ # unbiased stochastic derivatives of the loss function
+ alpha.shape == dloss_dalpha.shape # True
+ beta.shape == dloss_dbeta.shape # True
+ ```
Args:
shape: A 1-D integer Tensor or Python array. The shape of the output samples
@@ -406,8 +422,9 @@ def random_gamma(shape,
name: Optional name for the operation.
Returns:
- samples: a `Tensor` of shape `tf.concat(shape, tf.shape(alpha + beta))`
- with values of type `dtype`.
+ samples: a `Tensor` of shape
+ `tf.concat([shape, tf.shape(alpha + beta)], axis=0)` with values of type
+ `dtype`.
"""
with ops.name_scope(name, "random_gamma", [shape, alpha, beta]):
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
@@ -421,8 +438,6 @@ def random_gamma(shape,
gen_random_ops.random_gamma(
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
-ops.NotDifferentiable("RandomGamma")
-
@tf_export("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
@@ -432,13 +447,15 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
Example:
- samples = tf.random_poisson([0.5, 1.5], [10])
- # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
- # the samples drawn from each distribution
+ ```python
+ samples = tf.random_poisson([0.5, 1.5], [10])
+ # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
+ # the samples drawn from each distribution
- samples = tf.random_poisson([12.2, 3.3], [7, 5])
- # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
- # represents the 7x5 samples drawn from each of the two distributions
+ samples = tf.random_poisson([12.2, 3.3], [7, 5])
+ # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
+ # represents the 7x5 samples drawn from each of the two distributions
+ ```
Args:
lam: A Tensor or Python value or N-D array of type `dtype`.
@@ -455,8 +472,8 @@ def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
name: Optional name for the operation.
Returns:
- samples: a `Tensor` of shape `tf.concat(shape, tf.shape(lam))` with
- values of type `dtype`.
+ samples: a `Tensor` of shape `tf.concat([shape, tf.shape(lam)], axis=0)`
+ with values of type `dtype`.
"""
with ops.name_scope(name, "random_poisson", [lam, shape]):
shape = ops.convert_to_tensor(shape, name="shape", dtype=dtypes.int32)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 7061b32808..70a89e5ebb 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -507,6 +507,9 @@ class ResourceVariable(variables.Variable):
else:
self._cached_value = None
if not context.executing_eagerly():
+ # Eager variables are only added to collections if they are part of an
+ # eager variable store (otherwise in an interactive session they would
+ # hog memory and cause OOM). This is done in ops/variable_scope.py.
ops.add_to_collections(collections, self)
elif ops.GraphKeys.GLOBAL_STEP in collections:
ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)
@@ -848,14 +851,15 @@ class ResourceVariable(variables.Variable):
operator: string. The operator name.
"""
+ tensor_oper = getattr(ops.Tensor, operator)
def _run_op(a, *args):
# pylint: disable=protected-access
value = a._AsTensor()
- return getattr(ops.Tensor, operator)(value, *args)
+ return tensor_oper(value, *args)
# Propagate __doc__ to wrapper
try:
- _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
+ _run_op.__doc__ = tensor_oper.__doc__
except AttributeError:
pass
@@ -863,6 +867,19 @@ class ResourceVariable(variables.Variable):
__array_priority__ = 100
+ def is_initialized(self, name=None):
+ """Checks whether a resource variable has been initialized.
+
+ Outputs boolean scalar indicating whether the tensor has been initialized.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of type `bool`.
+ """
+ return gen_resource_variable_ops.var_is_initialized_op(self.handle, name)
+
def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
"""Subtracts a value from this variable.
@@ -995,32 +1012,28 @@ class ResourceVariable(variables.Variable):
def __imul__(self, unused_other):
raise RuntimeError("Variable *= value not supported. Use "
- "variable.assign_mul(value) to modify the variable "
- "value and variable = variable * value to get a new "
- "Tensor object.")
+ "`var.assign(var * value)` to modify the variable or "
+ "`var = var * value` to get a new Tensor object.")
def __idiv__(self, unused_other):
raise RuntimeError("Variable /= value not supported. Use "
- "variable.assign_div(value) to modify the variable "
- "value and variable = variable / value to get a new "
- "Tensor object.")
+ "`var.assign(var / value)` to modify the variable or "
+ "`var = var / value` to get a new Tensor object.")
def __itruediv__(self, unused_other):
raise RuntimeError("Variable /= value not supported. Use "
- "variable.assign_div(value) to modify the variable "
- "value and variable = variable / value to get a new "
- "Tensor object.")
+ "`var.assign(var / value)` to modify the variable or "
+ "`var = var / value` to get a new Tensor object.")
def __irealdiv__(self, unused_other):
raise RuntimeError("Variable /= value not supported. Use "
- "variable.assign_div(value) to modify the variable "
- "value and variable = variable / value to get a new "
- "Tensor object.")
+ "`var.assign(var / value)` to modify the variable or "
+ "`var = var / value` to get a new Tensor object.")
def __ipow__(self, unused_other):
raise RuntimeError("Variable **= value not supported. Use "
- "value and variable = variable ** value to get a new "
- "Tensor object.")
+ "`var.assign(var ** value)` to modify the variable or "
+ "`var = var ** value` to get a new Tensor object.")
pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable)
@@ -1064,6 +1077,10 @@ class _UnreadVariable(ResourceVariable):
self._graph_element = self.read_value()
self._handle_deleter = deleter
+ @property
+ def name(self):
+ return self._parent_op.name
+
def value(self):
return self._read_variable_op()
@@ -1087,6 +1104,113 @@ class _UnreadVariable(ResourceVariable):
ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)
+
+class _MixedPrecisionVariable(ResourceVariable):
+ """Represents a variable that can return in desired dtype when read.
+
+ In mixed precision training, it is usually desirable to use different dtypes
+ for variables and computation. This class will be used to wrap created
+ ResourceVariable when mixed precision training is enabled. It allows layers to
+ perform computation in a different dtype than their variable dtypes, in order
+ to achieve higher performance without causing quality loss.
+ """
+
+ def __init__(self, var, read_dtype):
+ """Creates a MixedPrecisionVariable.
+
+ Args:
+ var: A ResourceVariable instance.
+ read_dtype: A tf.DType, the returned dtype when read, default to None.
+ Casting is performed if read_dtype is not None and differs from
+ var.dtype.
+ Returns:
+ An MixedPrecisionVariable instance.
+ Raises:
+ ValueError: if var is not a ResourceVariable instance, or read_dtype is
+ not a tf.DType instance.
+ """
+ # pylint: disable=super-init-not-called
+ # We do not call super init on purpose.
+ if not isinstance(var, ResourceVariable):
+ raise ValueError("InvalidArgument: var must be a ResourceVariable type.")
+ if not isinstance(read_dtype, dtypes.DType):
+ raise ValueError("InvalidArgument: read_dtype must be a tf.DType type.")
+
+ self._var = var
+ self._trainable = var.trainable
+ self._save_slice_info = None
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ self._in_graph_mode = var._in_graph_mode # pylint: disable=protected-access
+ self._handle = var.handle
+ self._shape = var.shape
+ self._initial_value = None
+ if isinstance(self.handle, ops.EagerTensor):
+ self._handle_name = ""
+ else:
+ self._handle_name = self.handle.name
+ self._unique_id = var._unique_id # pylint: disable=protected-access
+ self._dtype = var.dtype
+ self._constraint = None
+ self._cached_value = None
+ self._is_initialized_op = var._is_initialized_op # pylint: disable=protected-access
+ self._initializer_op = var._initializer_op # pylint: disable=protected-access
+ # This needs to be set before read_value() is called.
+ self._read_dtype = read_dtype
+ if context.executing_eagerly():
+ self._graph_element = None
+ else:
+ self._graph_element = self.read_value()
+ self._handle_deleter = (
+ var._handle_deleter if not self._in_graph_mode # pylint: disable=protected-access
+ else None)
+ # pylint: enable=super-init-not-called
+
+ @property
+ def name(self):
+ return self._var.name
+
+ def value(self):
+ return self._read_variable_op()
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def _read_variable_op(self):
+ with ops.colocate_with(self._handle):
+ res = gen_resource_variable_ops.read_variable_op(self._handle,
+ self._dtype)
+ if self._read_dtype != self._dtype:
+ return math_ops.cast(res, self._read_dtype)
+ else:
+ return res
+
+ def set_shape(self, shape):
+ self._shape = shape
+ self._cached_shape_as_list = None
+
+ @property
+ def op(self):
+ """The op for this variable."""
+ return self._var.op
+
+ @property
+ def read_dtype(self):
+ """The dtype of the returned tensor when reading the var."""
+ return self._read_dtype
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ del name
+ dtype = dtype or self.read_dtype
+ if dtype != self.read_dtype or as_ref:
+ return NotImplemented
+ else:
+ res = self.value()
+ return res
+
+ def _should_act_as_resource_variable(self):
+ """To pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 10d576c95b..deba133fb9 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
@@ -131,6 +132,18 @@ def _maybe_tensor_shape_from_tensor(shape):
return shape
+def _should_cache():
+ """Returns True if a default caching device should be set, otherwise False."""
+ if context.executing_eagerly():
+ return False
+ # Don't set a caching device when running in a loop, since it is possible that
+ # train steps could be wrapped in a tf.while_loop. In that scenario caching
+ # prevents forward computations in loop iterations from re-reading the
+ # updated weights.
+ ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
+ return control_flow_util.GetContainingWhileContext(ctxt) is None
+
+
# pylint: disable=unused-argument
def _rnn_step(
time, sequence_length, min_sequence_length, max_sequence_length,
@@ -558,7 +571,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -828,7 +841,8 @@ def _dynamic_rnn_loop(cell,
final_outputs = nest.pack_sequence_as(
structure=cell.output_size, flat_sequence=final_outputs)
if not in_graph_mode:
- final_outputs = array_ops.stack(final_outputs, axis=0)
+ final_outputs = nest.map_structure_up_to(
+ cell.output_size, lambda x: array_ops.stack(x, axis=0), final_outputs)
return (final_outputs, final_state)
@@ -1014,7 +1028,7 @@ def raw_rnn(cell, loop_fn,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -1227,7 +1241,7 @@ def static_rnn(cell,
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
with vs.variable_scope(scope or "rnn") as varscope:
- if not context.executing_eagerly():
+ if _should_cache():
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 05723c6960..70805fd572 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -54,16 +54,6 @@ from tensorflow.python.util.tf_export import tf_export
_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"
-
-# TODO(jblespiau): Remove this function when we are sure there are no longer
-# any usage (even if protected, it is being used). Prefer assert_like_rnncell.
-def _like_rnncell(cell):
- """Checks that a given object is an RNNCell by using duck typing."""
- conditions = [hasattr(cell, "output_size"), hasattr(cell, "state_size"),
- hasattr(cell, "zero_state"), callable(cell)]
- return all(conditions)
-
-
# This can be used with self.assertRaisesRegexp for assert_like_rnncell.
ASSERT_LIKE_RNNCELL_ERROR_REGEXP = "is not an RNNCell"
@@ -1329,48 +1319,3 @@ class MultiRNNCell(RNNCell):
array_ops.concat(new_states, 1))
return cur_inp, new_states
-
-
-class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable):
- """A simple wrapper for slim.rnn_cells."""
-
- def __init__(self, cell_fn):
- """Create a SlimRNNCell from a cell_fn.
-
- Args:
- cell_fn: a function which takes (inputs, state, scope) and produces the
- outputs and the new_state. Additionally when called with inputs=None and
- state=None it should return (initial_outputs, initial_state).
-
- Raises:
- TypeError: if cell_fn is not callable
- ValueError: if cell_fn cannot produce a valid initial state.
- """
- if not callable(cell_fn):
- raise TypeError("cell_fn %s needs to be callable", cell_fn)
- self._cell_fn = cell_fn
- self._cell_name = cell_fn.func.__name__
- init_output, init_state = self._cell_fn(None, None)
- output_shape = init_output.get_shape()
- state_shape = init_state.get_shape()
- self._output_size = output_shape.with_rank(2)[1].value
- self._state_size = state_shape.with_rank(2)[1].value
- if self._output_size is None:
- raise ValueError("Initial output created by %s has invalid shape %s" %
- (self._cell_name, output_shape))
- if self._state_size is None:
- raise ValueError("Initial state created by %s has invalid shape %s" %
- (self._cell_name, state_shape))
-
- @property
- def state_size(self):
- return self._state_size
-
- @property
- def output_size(self):
- return self._output_size
-
- def __call__(self, inputs, state, scope=None):
- scope = scope or self._cell_name
- output, state = self._cell_fn(inputs, state, scope=scope)
- return output, state
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index f8df9b2c78..af103d3cc7 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Script Language Operators. See the @{$python/script_ops} guide."""
# pylint: disable=g-bad-name
@@ -30,30 +29,55 @@ import numpy as np
import six
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
+# Map from EagerPyFunc token to tuple (tape, eager args, eager outputs);
+# used for differentiation.
+tape_cache = {}
+
class EagerFunc(object):
"""A wrapper for a function owned by an EagerPyFunc."""
- def __init__(self, func, Tout):
+ def __init__(self, func, Tout, is_grad_func):
"""Constructs an EagerFunc.
Args:
func: The function to wrap.
Tout: A list of datatypes for the output; an empty list if the output is
None.
+ is_grad_func: Whether this EagerFunc is the gradient of another
+ EagerPyFunc.
"""
self._func = func
self._out_dtypes = Tout
+ self._is_grad_func = is_grad_func
def _convert(self, value, dtype):
+ """Converts `value` to a tensor of type `dtype`, with error checking.
+
+ Args:
+ value: The tensor to convert.
+ dtype: The desired dtype.
+
+ Returns:
+ A tensor of type `dtype`, or a zeros tensor if value is None and
+ this function is in fact a grdient function.
+
+ Raises:
+ RuntimeError: if `value` is a variable.
+ """
+
if isinstance(value, resource_variable_ops.ResourceVariable):
raise RuntimeError(
"Attempting to return a variable from an eagerly executed py_func. "
@@ -61,22 +85,39 @@ class EagerFunc(object):
"be returned; to return the value of a variable, make sure to obtain "
"the Tensor backing it by calling `.read_value()` on the variable in "
"question: %s" % value)
+ if value is None and self._is_grad_func:
+ # Gradient functions may legitimately return a list that contains
+ # both Tensors and Python Nones. Unfortuantely this breaks the
+ # OpKernel, so for now we replace None objects with zeros, which is
+ # mathematically correct but will prevent short-circuiting gradient
+ # computations.
+ #
+ # TODO(akshayka): Make it possible to return a list of both Tensors and
+ # Nones from an EagerPyFunc.
+ return constant_op.constant(0.0, dtype=dtype)
return ops.convert_to_tensor(value, dtype=dtype)
- def __call__(self, on_gpu, args):
+ def __call__(self, device, token, args):
"""Passes `args` to `self._func`, which is executed eagerly."""
- with context.eager_mode():
+
+ with context.eager_mode(), backprop.GradientTape() as tape:
+ for tensor in args:
+ tape.watch(tensor)
ret = self._func(*args)
- maybe_copy_to_gpu = lambda x: x if not on_gpu else x.gpu()
- if isinstance(ret, (tuple, list)):
- return [
- maybe_copy_to_gpu(self._convert(x, dtype=dtype))
- for (x, dtype) in zip(ret, self._out_dtypes)
- ]
- elif ret is None:
- return ret
- else:
- return maybe_copy_to_gpu(self._convert(ret, dtype=self._out_dtypes[0]))
+ # Use tf.identity to copy the returned tensors to device if neccesary.
+ with ops.device(device):
+ if isinstance(ret, (tuple, list)):
+ outputs = [
+ array_ops.identity(self._convert(x, dtype=dtype))
+ for (x, dtype) in zip(ret, self._out_dtypes)
+ ]
+ elif ret is None:
+ outputs = None
+ else:
+ outputs = array_ops.identity(
+ self._convert(ret, dtype=self._out_dtypes[0]))
+ tape_cache[compat.as_bytes(token)] = (tape, args, outputs)
+ return outputs
class FuncRegistry(object):
@@ -89,7 +130,7 @@ class FuncRegistry(object):
def __init__(self):
self._lock = threading.Lock()
self._unique_id = 0 # GUARDED_BY(self._lock)
- # Only store weakrefs to the funtions. The strong reference is stored in
+ # Only store weakrefs to the functions. The strong reference is stored in
# the graph.
self._funcs = weakref.WeakValueDictionary()
@@ -133,14 +174,14 @@ class FuncRegistry(object):
else:
return result
- def __call__(self, token, on_gpu, args):
+ def __call__(self, token, device, args):
"""Calls the registered function for `token` with args.
Args:
token: A key into this `FuncRegistry` identifying which function to call.
- on_gpu: A boolean indicating whether or not `token`'s corresponding
- operation was placed on GPU; only used if the function registered for
- `token` is an `EagerPyFunc`.
+ device: Name of the device on which outputs of `token`'s corresponding
+ operation should be placed. Used iff the function registered for `token`
+ is an EagerPyFunc.
args: The arguments to pass to the function registered for `token`.
Returns:
@@ -153,7 +194,14 @@ class FuncRegistry(object):
if func is None:
raise ValueError("callback %s is not found" % token)
if isinstance(func, EagerFunc):
- return func(on_gpu, args)
+ # NB: Different invocations of the same py_func will share the same
+ # token, and the entries they stash in the tape_cache will collide.
+ # In practice, when executing a graph, this should only happen if
+ # the py_func is in a while_loop whose iterations are run in parallel
+ # or if the graph is being driven by concurrent session.run() calls.
+ #
+ # TODO(akshayka): Key the tape cache in a thread-safe way.
+ return func(device, token, args)
else:
ret = func(*args)
# Strings seem to lead to a memory leak here if they're not wrapped in a
@@ -184,7 +232,13 @@ _py_funcs = FuncRegistry()
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
-def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
+def _internal_py_func(func,
+ inp,
+ Tout,
+ stateful=None,
+ eager=False,
+ is_grad_func=False,
+ name=None):
"""See documentation for py_func and eager_py_func."""
is_list_or_tuple = False
@@ -194,7 +248,7 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
Tout = [Tout]
if eager:
- func = EagerFunc(func, Tout)
+ func = EagerFunc(func, Tout, is_grad_func)
token = _py_funcs.insert(func)
# We tie the registered function's lifetime with the current default graph,
@@ -231,34 +285,56 @@ def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None):
return result if is_list_or_tuple else result[0]
+# TODO(akshayka): Implement higher-order derivatives.
+@ops.RegisterGradient("EagerPyFunc")
+def _EagerPyFuncGrad(op, dy):
+ """Computes the gradient of an EagerPyFunc."""
+
+ token = op.get_attr("token")
+
+ def eagerly_executed_grad(dy):
+ tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token))
+ return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy)
+
+ with ops.control_dependencies(op.outputs):
+ return _internal_py_func(
+ func=eagerly_executed_grad,
+ inp=[dy] if isinstance(dy, ops.Tensor) else dy,
+ Tout=[tensor.dtype for tensor in op.inputs],
+ eager=True,
+ is_grad_func=True)
+
+
def eager_py_func(func, inp, Tout, name=None):
"""Wraps a python function into a TensorFlow op that executes it eagerly.
This function allows expressing computations in a TensorFlow graph as
Python functions. In particular, it wraps a Python function `func`
- in a TensorFlow operation that executes it with eager exeuction enabled. As a
- consequence, `tf.contrib.eager.py_func` makes it possible to express control
- flow using Python constructs (`if`, `while`, `for`, etc.), instead of
- TensorFlow control flow constructs (@{tf.cond}, @{tf.while_loop}). For
- example, you might use `tf.contrib.eager.py_func` to implement the log huber
- function:
+ in a once-differentiable TensorFlow operation that executes it with eager
+ exeuction enabled. As a consequence, `tf.contrib.eager.py_func` makes it
+ possible to express control flow using Python constructs (`if`, `while`,
+ `for`, etc.), instead of TensorFlow control flow constructs (@{tf.cond},
+ @{tf.while_loop}). For example, you might use `tf.contrib.eager.py_func` to
+ implement the log huber function:
```python
def log_huber(x, m):
if tf.abs(x) <= m:
- return x ** 2
+ return x**2
else:
- return m ** 2 * (1 - 2 * tf.log(m) + tf.log(x ** 2))
+ return m**2 * (1 - 2 * tf.log(m) + tf.log(x**2))
x = tf.placeholder(tf.float32)
m = tf.placeholder(tf.float32)
y = tf.contrib.eager.py_func(func=log_huber, inp=[x, m], Tout=tf.float32)
+ dy_dx = tf.gradients(y, x)[0]
with tf.Session() as sess:
# The session executes `log_huber` eagerly. Given the feed values below,
- # it will take the second branch, so `output` evaluates to 7.24372.
- output = sess.run(y, feed_dict={x: 3.0, m: 2.0})
+ # it will take the first branch, so `y` evaluates to 1.0 and
+ # `dy_dx` evaluates to 2.0.
+ y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
```
You can also use `tf.contrib.eager.py_func` to debug your models at runtime
@@ -277,10 +353,6 @@ def eager_py_func(func, inp, Tout, name=None):
that take Tensors as inputs, execute TensorFlow operations in their bodies,
and return Tensors as outputs.
- `tf.contrib.eager.py_func` is not differentiable, though a gradient may be
- implemented in the future; if you would like to differentiate through it,
- please file an issue on Github.
-
Like @{tf.py_func}, `tf.contrib.eager.py_func` has the following limitations
with respect to serialization and distribution:
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 97353d6c74..1223b290ff 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -116,6 +116,35 @@ def _SparseReduceSumGrad(op, out_grad):
None, None)
+@ops.RegisterGradient("SparseSlice")
+def _SparseSliceGrad(op, *grads):
+ """The backward operator for the SparseSlice op.
+
+ This op takes in the upstream gradient w.r.t. non-empty values of
+ the sliced `SparseTensor`, and outputs the gradients w.r.t.
+ the non-empty values of input `SparseTensor`.
+
+ Args:
+ op: the SparseSlice op
+ *grads: the incoming gradients, one element per output of `op`
+
+ Returns:
+ Gradient for each of the 5 input tensors of SparseSlice:
+ (indices, values, shape, start, size)
+ The gradients for the indices, shape, start and the size are None.
+ """
+ backprop_val_grad = grads[1]
+ input_indices = op.inputs[0]
+ input_start = op.inputs[3]
+ output_indices = op.outputs[0]
+
+ val_grad = gen_sparse_ops.sparse_slice_grad(
+ backprop_val_grad, input_indices, input_start, output_indices)
+ val_grad.set_shape(op.inputs[1].get_shape())
+ # (indices, values, shape, start, size)
+ return (None, val_grad, None, None, None)
+
+
@ops.RegisterGradient("SparseTensorDenseMatMul")
def _SparseTensorDenseMatMulGrad(op, grad):
"""Gradients for the dense tensor in the SparseTensorDenseMatMul op.
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 6204adef3b..9a10abfcf7 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -34,7 +34,7 @@ from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('lbeta')
-def lbeta(x, name='lbeta'):
+def lbeta(x, name=None):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define
@@ -64,7 +64,7 @@ def lbeta(x, name='lbeta'):
# This is consistent with a convention that the sum over the empty set 0, and
# the product is 1.
# This is standard. See https://en.wikipedia.org/wiki/Empty_set.
- with ops.name_scope(name, values=[x]):
+ with ops.name_scope(name, 'lbeta', [x]):
x = ops.convert_to_tensor(x, name='x')
# Note reduce_sum([]) = 0.
@@ -82,6 +82,54 @@ def lbeta(x, name='lbeta'):
return result
+@tf_export('math.bessel_i0')
+def bessel_i0(x, name=None):
+ """Computes the Bessel i0 function of `x` element-wise.
+
+ Modified Bessel function of order 0.
+
+ It is preferable to use the numerically stabler function `i0e(x)` instead.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.special.i0
+ @end_compatibility
+ """
+ with ops.name_scope(name, 'bessel_i0', [x]):
+ return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)
+
+
+@tf_export('math.bessel_i1')
+def bessel_i1(x, name=None):
+ """Computes the Bessel i1 function of `x` element-wise.
+
+ Modified Bessel function of order 1.
+
+ It is preferable to use the numerically stabler function `i1e(x)` instead.
+
+ Args:
+ x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.special.i1
+ @end_compatibility
+ """
+ with ops.name_scope(name, 'bessel_i1', [x]):
+ return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
+
+
@tf_export('einsum', 'linalg.einsum')
def einsum(equation, *inputs, **kwargs):
"""A generalized contraction between tensors of arbitrary dimension.
@@ -153,6 +201,8 @@ def einsum(equation, *inputs, **kwargs):
indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
"""
+ equation = equation.replace(' ', '')
+
name = kwargs.pop('name', None)
if kwargs:
raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index d7c3a7e8dc..9bc4098d5b 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -25,23 +25,25 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.platform import test
-
+from tensorflow.python.platform import tf_logging
class LBetaTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes
def test_one_dimensional_arg(self):
# Should evaluate to 1 and 1/2.
x_one = [1, 1.]
x_one_half = [2, 1.]
with self.test_session(use_gpu=True):
- self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_one)).eval())
- self.assertAllClose(0.5,
- math_ops.exp(
- special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose(
+ 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one))))
+ self.assertAllClose(
+ 0.5, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
self.assertEqual([], special_math_ops.lbeta(x_one).get_shape())
def test_one_dimensional_arg_dynamic(self):
@@ -52,7 +54,8 @@ class LBetaTest(test.TestCase):
ph = array_ops.placeholder(dtypes.float32)
beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
self.assertAllClose(1, beta_ph.eval(feed_dict={ph: x_one}))
- self.assertAllClose(0.5, beta_ph.eval(feed_dict={ph: x_one_half}))
+ self.assertAllClose(0.5,
+ beta_ph.eval(feed_dict={ph: x_one_half}))
def test_four_dimensional_arg_with_partial_shape_dynamic(self):
x_ = np.ones((3, 2, 3, 4))
@@ -65,15 +68,17 @@ class LBetaTest(test.TestCase):
with self.test_session(use_gpu=True):
x_ph = array_ops.placeholder(dtypes.float32, [3, 2, 3, None])
beta_ph = math_ops.exp(special_math_ops.lbeta(x_ph))
- self.assertAllClose(expected_beta_x, beta_ph.eval(feed_dict={x_ph: x_}))
+ self.assertAllClose(expected_beta_x,
+ beta_ph.eval(feed_dict={x_ph: x_}))
+ @test_util.run_in_graph_and_eager_modes
def test_two_dimensional_arg(self):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose([0.5, 0.5],
- math_ops.exp(
- special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose(
+ [0.5, 0.5],
+ self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
self.assertEqual((2,), special_math_ops.lbeta(x_one_half).get_shape())
def test_two_dimensional_arg_dynamic(self):
@@ -82,50 +87,59 @@ class LBetaTest(test.TestCase):
with self.test_session(use_gpu=True):
ph = array_ops.placeholder(dtypes.float32)
beta_ph = math_ops.exp(special_math_ops.lbeta(ph))
- self.assertAllClose([0.5, 0.5], beta_ph.eval(feed_dict={ph: x_one_half}))
+ self.assertAllClose([0.5, 0.5],
+ beta_ph.eval(feed_dict={ph: x_one_half}))
+ @test_util.run_in_graph_and_eager_modes
def test_two_dimensional_proper_shape(self):
# Should evaluate to 1/2.
x_one_half = [[2, 1.], [2, 1.]]
with self.test_session(use_gpu=True):
- self.assertAllClose([0.5, 0.5],
- math_ops.exp(
- special_math_ops.lbeta(x_one_half)).eval())
+ self.assertAllClose(
+ [0.5, 0.5],
+ self.evaluate(math_ops.exp(special_math_ops.lbeta(x_one_half))))
self.assertEqual(
(2,),
- array_ops.shape(special_math_ops.lbeta(x_one_half)).eval())
+ self.evaluate(array_ops.shape(special_math_ops.lbeta(x_one_half))))
self.assertEqual(
tensor_shape.TensorShape([2]),
special_math_ops.lbeta(x_one_half).get_shape())
+ @test_util.run_in_graph_and_eager_modes
def test_complicated_shape(self):
with self.test_session(use_gpu=True):
x = ops.convert_to_tensor(np.random.rand(3, 2, 2))
- self.assertAllEqual((3, 2),
- array_ops.shape(special_math_ops.lbeta(x)).eval())
+ self.assertAllEqual(
+ (3, 2), self.evaluate(array_ops.shape(special_math_ops.lbeta(x))))
self.assertEqual(
tensor_shape.TensorShape([3, 2]),
special_math_ops.lbeta(x).get_shape())
+ @test_util.run_in_graph_and_eager_modes
def test_length_1_last_dimension_results_in_one(self):
# If there is only one coefficient, the formula still works, and we get one
# as the answer, always.
x_a = [5.5]
x_b = [0.1]
with self.test_session(use_gpu=True):
- self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_a)).eval())
- self.assertAllClose(1, math_ops.exp(special_math_ops.lbeta(x_b)).eval())
+ self.assertAllClose(
+ 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_a))))
+ self.assertAllClose(
+ 1, self.evaluate(math_ops.exp(special_math_ops.lbeta(x_b))))
self.assertEqual((), special_math_ops.lbeta(x_a).get_shape())
+ @test_util.run_in_graph_and_eager_modes
def test_empty_rank1_returns_negative_infinity(self):
with self.test_session(use_gpu=True):
x = constant_op.constant([], shape=[0])
lbeta_x = special_math_ops.lbeta(x)
expected_result = constant_op.constant(-np.inf, shape=())
- self.assertAllEqual(expected_result.eval(), lbeta_x.eval())
+ self.assertAllEqual(self.evaluate(expected_result),
+ self.evaluate(lbeta_x))
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
+ @test_util.run_in_graph_and_eager_modes
def test_empty_rank2_with_zero_last_dim_returns_negative_infinity(self):
with self.test_session(use_gpu=True):
event_size = 0
@@ -134,9 +148,11 @@ class LBetaTest(test.TestCase):
lbeta_x = special_math_ops.lbeta(x)
expected_result = constant_op.constant(-np.inf, shape=[batch_size])
- self.assertAllEqual(expected_result.eval(), lbeta_x.eval())
+ self.assertAllEqual(self.evaluate(expected_result),
+ self.evaluate(lbeta_x))
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
+ @test_util.run_in_graph_and_eager_modes
def test_empty_rank2_with_zero_batch_dim_returns_empty(self):
with self.test_session(use_gpu=True):
batch_size = 0
@@ -146,10 +162,40 @@ class LBetaTest(test.TestCase):
expected_result = constant_op.constant([], shape=[batch_size])
- self.assertAllEqual(expected_result.eval(), lbeta_x.eval())
+ self.assertAllEqual(self.evaluate(expected_result),
+ self.evaluate(lbeta_x))
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
+class BesselTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_bessel_i0(self):
+ x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+ x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self.assertAllClose(special.i0(x_single),
+ self.evaluate(special_math_ops.bessel_i0(x_single)))
+ self.assertAllClose(special.i0(x_double),
+ self.evaluate(special_math_ops.bessel_i0(x_double)))
+ except ImportError as e:
+ tf_logging.warn('Cannot test special functions: %s' % str(e))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_bessel_i1(self):
+ x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+ x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self.assertAllClose(special.i1(x_single),
+ self.evaluate(special_math_ops.bessel_i1(x_single)))
+ self.assertAllClose(special.i1(x_double),
+ self.evaluate(special_math_ops.bessel_i1(x_double)))
+ except ImportError as e:
+ tf_logging.warn('Cannot test special functions: %s' % str(e))
+
+
class EinsumTest(test.TestCase):
simple_cases = [
@@ -195,6 +241,12 @@ class EinsumTest(test.TestCase):
'iJ,Jk->ik',
'iJ,Ki->JK',
'iJk,Jklm->Jk'
+ 'ij, jk, kl -> il',
+ 'a, ab, abc -> abc',
+ 'ab, ab, cd, cd, ef, ef -> ',
+ 'abc, bac',
+ 'iJ, Ki -> JK',
+ 'iJk, Jklm -> Jk'
]
long_cases = [
@@ -203,6 +255,8 @@ class EinsumTest(test.TestCase):
'ea,fb,gc,hd,abcd->efgh',
'ea,fb,abcd,gc,hd->efgh',
'abhe,hidj,jgba,hiab,gab',
+ 'efc, dbc, acf, fd -> abe',
+ 'abhe, hidj, jgba, hiab, gab',
]
invalid_cases = [
@@ -273,20 +327,20 @@ class EinsumTest(test.TestCase):
input_axes, _, _ = axes.partition('->')
for idx in input_axes.split(','):
- shape = [all_axes[ax] for ax in idx]
+ shape = [all_axes[ax] for ax in idx if ax.isalpha()]
input_vals.append(np.random.random(shape))
input_tensors = [constant_op.constant(val) for val in input_vals]
output_tensor = special_math_ops.einsum(axes, *input_tensors)
with self.test_session(use_gpu=True):
- output_value = output_tensor.eval()
+ output_value = self.evaluate(output_tensor)
correct_value = np.einsum(axes, *input_vals)
err = np.abs(correct_value - output_value).max()
- print(axes, err)
- assert err < 1e-8
+ # print(axes, err)
+ self.assertLess(err, 1e-8)
def test_input_is_placeholder(self):
with ops.Graph().as_default():
@@ -298,8 +352,7 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [[2], [1], [1]],
}
- np.testing.assert_almost_equal([[7]], sess.run(
- out, feed_dict=feed_dict))
+ self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 3))
@@ -310,7 +363,7 @@ class EinsumTest(test.TestCase):
m0: [[1, 2, 3]],
m1: [2, 1, 1],
}
- np.testing.assert_almost_equal([7], sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([7], sess.run(out, feed_dict=feed_dict))
# Tests for placeholders which have two or more None values
with ops.Graph().as_default():
@@ -322,8 +375,7 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [[3], [2]],
}
- np.testing.assert_almost_equal([[[7]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(2, 1))
@@ -334,8 +386,7 @@ class EinsumTest(test.TestCase):
m0: [[3], [2]],
m1: [[[1, 2]]],
}
- np.testing.assert_almost_equal([[[7]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7]]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, None, 2))
@@ -346,8 +397,7 @@ class EinsumTest(test.TestCase):
m0: [[[1, 2]]],
m1: [3, 2],
}
- np.testing.assert_almost_equal([[7]], sess.run(
- out, feed_dict=feed_dict))
+ self.assertAllClose([[7]], sess.run(out, feed_dict=feed_dict))
with ops.Graph().as_default():
m0 = array_ops.placeholder(dtypes.int32, shape=(None, 2, None, 2))
@@ -358,8 +408,7 @@ class EinsumTest(test.TestCase):
m0: [[[[1, 2]], [[2, 1]]]],
m1: [[3, 2]],
}
- np.testing.assert_almost_equal([[[7, 8]]],
- sess.run(out, feed_dict=feed_dict))
+ self.assertAllClose([[[7, 8]]], sess.run(out, feed_dict=feed_dict))
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index 28054f50ef..293aace728 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -167,8 +167,8 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
raise NotImplementedError("The DCT length argument is not implemented.")
if axis != -1:
raise NotImplementedError("axis must be -1. Got: %s" % axis)
- if dct_type != 2:
- raise ValueError("Only the Type II DCT is supported.")
+ if dct_type not in (2, 3):
+ raise ValueError("Only Types II and III (I)DCT are supported.")
if norm not in (None, "ortho"):
raise ValueError(
"Unknown normalization. Expected None or 'ortho', got: %s" % norm)
@@ -179,18 +179,20 @@ def _validate_dct_arguments(dct_type, n, axis, norm):
def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
"""Computes the 1D [Discrete Cosine Transform (DCT)][dct] of `input`.
- Currently only Type II is supported. Implemented using a length `2N` padded
- @{tf.spectral.rfft}, as described here: https://dsp.stackexchange.com/a/10606
+ Currently only Types II and III are supported. Type II is implemented using a
+ length `2N` padded @{tf.spectral.rfft}, as described here:
+ https://dsp.stackexchange.com/a/10606. Type III is a fairly straightforward
+ inverse of Type II (i.e. using a length `2N` padded @{tf.spectral.irfft}).
@compatibility(scipy)
- Equivalent to scipy.fftpack.dct for the Type-II DCT.
+ Equivalent to scipy.fftpack.dct for Type-II and Type-III DCT.
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
@end_compatibility
Args:
input: A `[..., samples]` `float32` `Tensor` containing the signals to
take the DCT of.
- type: The DCT type to perform. Must be 2.
+ type: The DCT type to perform. Must be 2 or 3.
n: For future expansion. The length of the transform. Must be `None`.
axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
norm: The normalization to apply. `None` for no normalization or `'ortho'`
@@ -201,8 +203,8 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
A `[..., samples]` `float32` `Tensor` containing the DCT of `input`.
Raises:
- ValueError: If `type` is not `2`, `n` is not `None, `axis` is not `-1`, or
- `norm` is not `None` or `'ortho'`.
+ ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
[dct]: https://en.wikipedia.org/wiki/Discrete_cosine_transform
"""
@@ -214,22 +216,91 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
axis_dim = input.shape[-1].value or _array_ops.shape(input)[-1]
axis_dim_float = _math_ops.to_float(axis_dim)
- scale = 2.0 * _math_ops.exp(_math_ops.complex(
- 0.0, -_math.pi * _math_ops.range(axis_dim_float) /
- (2.0 * axis_dim_float)))
-
- # TODO(rjryan): Benchmark performance and memory usage of the various
- # approaches to computing a DCT via the RFFT.
- dct2 = _math_ops.real(
- rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
-
- if norm == "ortho":
- n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
- n2 = n1 * _math_ops.sqrt(2.0)
- # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
- weights = _array_ops.pad(
- _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
- constant_values=n2)
- dct2 *= weights
-
- return dct2
+ if type == 2:
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0, -_math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+
+ # TODO(rjryan): Benchmark performance and memory usage of the various
+ # approaches to computing a DCT via the RFFT.
+ dct2 = _math_ops.real(
+ rfft(input, fft_length=[2 * axis_dim])[..., :axis_dim] * scale)
+
+ if norm == "ortho":
+ n1 = 0.5 * _math_ops.rsqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(2.0)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ dct2 *= weights
+
+ return dct2
+
+ elif type == 3:
+ if norm == "ortho":
+ n1 = _math_ops.sqrt(axis_dim_float)
+ n2 = n1 * _math_ops.sqrt(0.5)
+ # Use tf.pad to make a vector of [n1, n2, n2, n2, ...].
+ weights = _array_ops.pad(
+ _array_ops.expand_dims(n1, 0), [[0, axis_dim - 1]],
+ constant_values=n2)
+ input *= weights
+ else:
+ input *= axis_dim_float
+ scale = 2.0 * _math_ops.exp(
+ _math_ops.complex(
+ 0.0,
+ _math_ops.range(axis_dim_float) * _math.pi * 0.5 /
+ axis_dim_float))
+ dct3 = _math_ops.real(
+ irfft(
+ scale * _math_ops.complex(input, 0.0),
+ fft_length=[2 * axis_dim]))[..., :axis_dim]
+
+ return dct3
+
+
+# TODO(rjryan): Implement `type`, `n` and `axis` parameters.
+@tf_export("spectral.idct")
+def idct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disable=redefined-builtin
+ """Computes the 1D [Inverse Discrete Cosine Transform (DCT)][idct] of `input`.
+
+ Currently only Types II and III are supported. Type III is the inverse of
+ Type II, and vice versa.
+
+ Note that you must re-normalize by 1/(2n) to obtain an inverse if `norm` is
+ not `'ortho'`. That is:
+ `signal == idct(dct(signal)) * 0.5 / signal.shape[-1]`.
+ When `norm='ortho'`, we have:
+ `signal == idct(dct(signal, norm='ortho'), norm='ortho')`.
+
+ @compatibility(scipy)
+ Equivalent to scipy.fftpack.idct for Type-II and Type-III DCT.
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.idct.html
+ @end_compatibility
+
+ Args:
+ input: A `[..., samples]` `float32` `Tensor` containing the signals to take
+ the DCT of.
+ type: The IDCT type to perform. Must be 2 or 3.
+ n: For future expansion. The length of the transform. Must be `None`.
+ axis: For future expansion. The axis to compute the DCT along. Must be `-1`.
+ norm: The normalization to apply. `None` for no normalization or `'ortho'`
+ for orthonormal normalization.
+ name: An optional name for the operation.
+
+ Returns:
+ A `[..., samples]` `float32` `Tensor` containing the IDCT of `input`.
+
+ Raises:
+ ValueError: If `type` is not `2` or `3`, `n` is not `None, `axis` is not
+ `-1`, or `norm` is not `None` or `'ortho'`.
+
+ [idct]:
+ https://en.wikipedia.org/wiki/Discrete_cosine_transform#Inverse_transforms
+ """
+ _validate_dct_arguments(type, n, axis, norm)
+ inverse_type = {2: 3, 3: 2}[type]
+ return dct(input, type=inverse_type, n=n, axis=axis, norm=norm, name=name)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index a2d24711e2..d0e5f70025 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import cudnn_rnn_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import random_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 94d7458ec8..2c93cf72c7 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_resource_variable_ops
@@ -124,9 +123,7 @@ def is_variable_initialized(ref, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.is_variable_initialized(ref=ref, name=name)
# Handle resource variables.
- if context.executing_eagerly() or ref.op.type == "VarHandleOp":
- return gen_resource_variable_ops.var_is_initialized_op(ref.handle,
- name=name)
+ return ref.is_initialized(name=name)
@tf_export("assign_sub")
@@ -338,7 +335,6 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
Args:
ref: A Variable.
indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
- A Tensor. Must be one of the following types: int32, int64.
A tensor of indices into ref.
updates: A `Tensor`. Must have the same type as `ref`.
A Tensor. Must have the same type as ref. A tensor of updated
@@ -355,10 +351,9 @@ def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_nd_update(
ref, indices, updates, use_locking, name)
- with ops.control_dependencies([gen_state_ops.resource_scatter_nd_update(
- ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
- use_locking, name)]):
- return ref.read_value()
+ return ref._lazy_read(gen_state_ops.resource_scatter_nd_update( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
@tf_export("scatter_add")
@@ -396,7 +391,7 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None):
A tensor of indices into the first dimension of `ref`.
updates: A `Tensor`. Must have the same type as `ref`.
A tensor of updated values to store in `ref`.
- use_locking: An optional `bool`. Defaults to `True`.
+ use_locking: An optional `bool`. Defaults to `False`.
If True, the assignment will be protected by a lock;
otherwise the behavior is undefined, but may exhibit less contention.
name: A name for the operation (optional).
@@ -411,3 +406,67 @@ def scatter_add(ref, indices, updates, use_locking=False, name=None):
return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add( # pylint: disable=protected-access
ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
name=name))
+
+
+@tf_export("scatter_nd_add")
+def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
+ r"""Applies sparse addition to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be integer tensor, containing indices into `ref`.
+ It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to add 4 scattered elements to a rank-1 tensor to
+ 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ add = tf.scatter_nd_add(ref, indices, updates)
+ with tf.Session() as sess:
+ print sess.run(add)
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 13, 3, 14, 14, 6, 7, 20]
+
+ See @{tf.scatter_nd} for more details about how to make updates to
+ slices.
+
+ Args:
+ ref: A mutable `Tensor`. Must be one of the following types: `float32`,
+ `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
+ `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
+ `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into ref.
+ updates: A `Tensor`. Must have the same type as `ref`.
+ A tensor of updated values to add to ref.
+ use_locking: An optional `bool`. Defaults to `False`.
+ An optional bool. Defaults to True. If True, the assignment will
+ be protected by a lock; otherwise the behavior is undefined,
+ but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ A mutable `Tensor`. Has the same type as `ref`.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_nd_add(
+ ref, indices, updates, use_locking, name)
+ return ref._lazy_read(gen_state_ops.resource_scatter_nd_add( # pylint: disable=protected-access
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name))
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index ae79c01949..0280c89c10 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -91,6 +91,59 @@ def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=inv
shape.set_shape([2])
return sparse_tensor.SparseTensor(indices, values, shape)
+@tf_export("strings.split")
+def string_split_v2(source, sep=None, maxsplit=-1):
+ """Split elements of `source` based on `sep` into a `SparseTensor`.
+
+ Let N be the size of source (typically N will be the batch size). Split each
+ element of `source` based on `sep` and return a `SparseTensor`
+ containing the split tokens. Empty tokens are ignored.
+
+ For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+ then the output will be
+
+ st.indices = [0, 0;
+ 0, 1;
+ 1, 0;
+ 1, 1;
+ 1, 2]
+ st.shape = [2, 3]
+ st.values = ['hello', 'world', 'a', 'b', 'c']
+
+ If `sep` is given, consecutive delimiters are not grouped together and are
+ deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+ sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+ string, consecutive whitespace are regarded as a single separator, and the
+ result will contain no empty strings at the startor end if the string has
+ leading or trailing whitespace.
+
+ Note that the above mentioned behavior matches python's str.split.
+
+ Args:
+ source: `1-D` string `Tensor`, the strings to split.
+ sep: `0-D` string `Tensor`, the delimiter character.
+ maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
+
+ Raises:
+ ValueError: If sep is not a string.
+
+ Returns:
+ A `SparseTensor` of rank `2`, the strings split according to the delimiter.
+ The first column of the indices corresponds to the row in `source` and the
+ second column corresponds to the index of the split component in this row.
+ """
+ if sep is None:
+ sep = ''
+ sep = ops.convert_to_tensor(sep, dtype=dtypes.string)
+ source = ops.convert_to_tensor(source, dtype=dtypes.string)
+
+ indices, values, shape = gen_string_ops.string_split_v2(
+ source, sep=sep, maxsplit=maxsplit)
+ indices.set_shape([None, 2])
+ values.set_shape([None])
+ shape.set_shape([2])
+ return sparse_tensor.SparseTensor(indices, values, shape)
+
def _reduce_join_reduction_dims(x, axis, reduction_indices):
"""Returns range(rank(x) - 1, 0, -1) if reduction_indices is None."""
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index b80f84eb7c..00150fe688 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -306,10 +306,11 @@ def create_db_writer(db_uri,
def _make_summary_writer(name, factory, **kwargs):
resource = gen_summary_ops.summary_writer(shared_name=name)
init_op_fn = lambda: factory(resource, **kwargs)
- # TODO(apassos): Consider doing this instead.
- # if not context.executing_eagerly():
- # ops.get_default_session().run(init_op)
- ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op_fn())
+ init_op = init_op_fn()
+ if not context.executing_eagerly():
+ # TODO(apassos): Consider doing this instead.
+ # ops.get_default_session().run(init_op)
+ ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME, init_op)
return SummaryWriter(resource, init_op_fn)
@@ -380,7 +381,8 @@ def summary_writer_function(name, tensor, function, family=None):
with ops.device("cpu:0"):
op = smart_cond.smart_cond(
should_record_summaries(), record, _nothing, name="")
- ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
+ if not context.executing_eagerly():
+ ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
return op
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 355b0d961e..161d9687d6 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.deprecation import deprecated
@@ -295,66 +296,6 @@ class Template(checkpointable.CheckpointableBase):
# which is not the same as whether the scope has been created.
self._variables_created = False
- def _checkpointable_custom_creator(self, next_creator, name, initial_value,
- checkpointable_parent=None, **kwargs):
- """A variable creation hook which adds Checkpointable dependencies.
-
- Set during the `Template`'s first wrapped function execution. Ensures that
- (a) `Template` objects depend on `Template`s created inside them which
- create variables, and (b) that any variables not in a more deeply nested
- `Template` are added as dependencies directly.
-
- The `checkpointable_parent` argument is passed between `Template` custom
- creators but ignored when the variable object itself is created. This
- argument indicates (if not `None`) that a more deeply nested `Template` has
- already added the variable as a dependency, and that parent `Template`s
- should add a dependency on that `Template` rather than on the variable
- directly.
-
- Args:
- next_creator: See `variable_scope.variable_creator_scope`; the next
- creator in the chain.
- name: The (full, scope-influenced) name of the variable. The scope name
- for the Template itself is stripped for the purposes of object-based
- dependency tracking, but scopes within Templates are respected.
- initial_value: See `variable_scope.variable_creator_scope`. Taken
- explicitly so the argument can be re-named and used with
- `Checkpointable._add_variable_with_custom_getter`.
- checkpointable_parent: If not None, a more deeply nested Template object
- to add a dependency on (rather than depending on the variable directly).
- **kwargs: Passed through to the next creator.
- Returns:
- The output of `next_creator`: the fetched/created variable object.
- """
- def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
- inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
- # we don't want to propagate.
- return next_creator(
- initial_value=initializer,
- name=name,
- **inner_kwargs)
- if name.startswith(self._variable_scope.name):
- scope_stripped_name = name[len(self._variable_scope.name) + 1:]
- if not checkpointable_parent:
- return self._add_variable_with_custom_getter(
- initializer=initial_value,
- name=scope_stripped_name,
- getter=_call_next_creator_renaming_initializer,
- # Disable error checking for Checkpointable. Exceptions are instead
- # raised if necessary when the object-based saver tries to
- # save/restore the object.
- overwrite=True,
- checkpointable_parent=self,
- **kwargs)
- else:
- self._track_checkpointable(
- checkpointable_parent,
- name=checkpointable_parent._variable_scope.name[ # pylint: disable=protected-access
- len(self._variable_scope.name) + 1:],
- overwrite=True)
- return next_creator(name=name, initial_value=initial_value,
- checkpointable_parent=self, **kwargs)
-
def _call_func(self, args, kwargs):
try:
vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
@@ -365,8 +306,7 @@ class Template(checkpointable.CheckpointableBase):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
- with variable_scope.variable_creator_scope(
- self._checkpointable_custom_creator):
+ with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
@@ -634,8 +574,7 @@ class EagerTemplate(Template):
else:
# The first time we run, restore variables if necessary (via
# Checkpointable).
- with variable_scope.variable_creator_scope(
- self._checkpointable_custom_creator):
+ with checkpointable_util.capture_dependencies(template=self):
result = self._func(*args, **kwargs)
if self._variables_created:
diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py
index 1f70d69548..d341349804 100644
--- a/tensorflow/python/ops/tensor_array_grad.py
+++ b/tensorflow/python/ops/tensor_array_grad.py
@@ -34,6 +34,7 @@ ops.NotDifferentiable("TensorArrayCloseV2")
ops.NotDifferentiable("TensorArrayV3")
ops.NotDifferentiable("TensorArrayGradV3")
+ops.NotDifferentiable("TensorArrayGradWithShape")
ops.NotDifferentiable("TensorArraySizeV3")
ops.NotDifferentiable("TensorArrayCloseV3")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index fa34774622..77f67c18ee 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
- "get_variable", "get_local_variable", "variable_scope",
- "variable_op_scope", "no_regularizer"]
+__all__ = [
+ "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
+ "get_local_variable", "variable_scope", "variable_op_scope",
+ "no_regularizer", "VariableSynchronization", "VariableAggregation"
+]
class _PartitionInfo(object):
@@ -188,6 +190,38 @@ class _ReuseMode(enum.Enum):
# REUSE_FALSE = 2
# REUSE_TRUE = 3
+
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (eg. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
@@ -214,11 +248,23 @@ class _VariableStore(object):
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._store_eager_variables = False
- def get_variable(self, name, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- custom_getter=None, constraint=None):
+ def get_variable(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ custom_getter=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the stored
@@ -254,6 +300,8 @@ class _VariableStore(object):
forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys to add the `Variable` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where the
@@ -291,6 +339,15 @@ class _VariableStore(object):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -343,11 +400,22 @@ class _VariableStore(object):
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
- def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- constraint=None):
+ def _true_getter( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
is_scalar = (shape is not None
and isinstance(shape, collections_lib.Sequence)
and not shape)
@@ -397,11 +465,24 @@ class _VariableStore(object):
"name was already created with partitioning?" % name)
return self._get_single_variable(
- name=name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, reuse=reuse,
- trainable=trainable, collections=collections,
- caching_device=caching_device, validate_shape=validate_shape,
- use_resource=use_resource, constraint=constraint)
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ # Set trainable value based on synchronization value.
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
@@ -420,6 +501,8 @@ class _VariableStore(object):
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
+ "synchronization": synchronization,
+ "aggregation": aggregation,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
if "constraint" in function_utils.fn_args(custom_getter):
@@ -427,18 +510,36 @@ class _VariableStore(object):
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
- name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- reuse=reuse, trainable=trainable, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- constraint=constraint)
-
- def _get_partitioned_variable(
- self, name, partitioner, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- validate_shape=True, use_resource=None, constraint=None):
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ def _get_partitioned_variable(self,
+ name,
+ partitioner,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None):
"""Gets or creates a sharded variable list with these parameters.
The `partitioner` must be a callable that accepts a fully defined
@@ -688,12 +789,14 @@ class _VariableStore(object):
regularizer=None,
partition_info=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
validate_shape=True,
use_resource=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
@@ -713,6 +816,8 @@ class _VariableStore(object):
validate_shape: see get_variable.
use_resource: see get_variable.
constraint: see get_variable.
+ synchronization: see get_variable.
+ aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
@@ -793,7 +898,17 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
+ if context.executing_eagerly() and self._store_eager_variables:
+ if collections:
+ ops.add_to_collections(collections, v)
+ else:
+ ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, v)
+ if trainable:
+ ops.add_to_collection(ops.GraphKeys.TRAINABLE_VARIABLES, v)
+
if not context.executing_eagerly() or self._store_eager_variables:
# In eager mode we do not want to keep default references to Variable
# objects as this will prevent their memory from being released.
@@ -1037,14 +1152,16 @@ class VariableScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with this name or create a new one."""
if regularizer is None:
regularizer = self._regularizer
@@ -1082,12 +1199,22 @@ class VariableScope(object):
if dtype is None:
dtype = self._dtype
return var_store.get_variable(
- full_name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, reuse=reuse, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ full_name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(self,
var_store,
@@ -1096,7 +1223,7 @@ class VariableScope(object):
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1311,21 +1438,35 @@ def get_variable(name,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
return get_variable_scope().get_variable(
- _get_default_variable_store(), name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ _get_default_variable_store(),
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+
get_variable_or_local_docstring = (
"""%s
@@ -1422,29 +1563,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
@tf_export("get_local_variable")
-def get_local_variable(name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=False, # pylint: disable=unused-argument
- collections=None,
- caching_device=None,
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- custom_getter=None,
- constraint=None):
+def get_local_variable( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=False, # pylint: disable=unused-argument
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE,
+ custom_getter=None,
+ constraint=None):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
return get_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, trainable=False, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- custom_getter=custom_getter, constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=False,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ custom_getter=custom_getter,
+ constraint=constraint)
+
+
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
@@ -1778,6 +1934,23 @@ class variable_scope(object):
assert v.name == "foo/bar/v:0"
```
+ Simple example of how to reenter a premade variable scope safely:
+
+ ```python
+ with tf.variable_scope("foo") as vs:
+ pass
+
+ # Re-enter the variable scope.
+ with tf.variable_scope(vs,
+ auxiliary_name_scope=False) as vs1:
+ # Restore the original name_scope.
+ with tf.name_scope(vs1.original_name_scope):
+ v = tf.get_variable("v", [1])
+ assert v.name == "foo/v:0"
+ c = tf.constant([1], name="c")
+ assert c.name == "foo/c:0"
+ ```
+
Basic example of sharing a variable AUTO_REUSE:
```python
@@ -1900,7 +2073,8 @@ class variable_scope(object):
for this scope as well as all sub-scopes; if tf.AUTO_REUSE, we create
variables if they do not exist, and return them otherwise; if None, we
inherit the parent scope's reuse flag. When eager execution is enabled,
- this argument is always forced to be tf.AUTO_REUSE.
+ new variables are always created unless an EagerVariableStore or
+ template is currently active.
dtype: type of variables created in this scope (defaults to the type
in the passed scope, or inherited from parent scope).
use_resource: If False, all variables will be regular Variables. If True,
@@ -1915,7 +2089,9 @@ class variable_scope(object):
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
auxiliary_name_scope: If `True`, we create an auxiliary name scope with
- the scope. If `False`, we don't touch name scope.
+ the scope. If `False`, we don't create it. Note that the argument is
+ not inherited, and it only takes effect for once when creating. You
+ should only use it for re-entering a premade variable scope.
Returns:
A scope that can be captured and reused.
@@ -2174,11 +2350,28 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
return slice_dim, slice_shape
+def _get_trainable_value(synchronization, trainable):
+ """Computes the trainable value based on the given arguments."""
+ if synchronization == VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ.")
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+ return trainable
+
+
def default_variable_creator(next_creator=None, **kwargs):
"""Default variable creator."""
assert next_creator is None
initial_value = kwargs.get("initial_value", None)
- trainable = kwargs.get("trainable", True)
+ trainable = kwargs.get("trainable", None)
collections = kwargs.get("collections", None)
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
@@ -2186,6 +2379,12 @@ def default_variable_creator(next_creator=None, **kwargs):
dtype = kwargs.get("dtype", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.executing_eagerly()):
@@ -2213,25 +2412,35 @@ def _make_getter(captured_getter, captured_previous):
def variable(initial_value=None,
- trainable=True,
+ trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
dtype=None,
constraint=None,
- use_resource=None):
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
- return previous_getter(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name, dtype=dtype,
- constraint=constraint,
- use_resource=use_resource)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum None value.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ dtype=dtype,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
@tf_contextlib.contextmanager
@@ -2265,6 +2474,8 @@ def variable_creator_scope(variable_creator):
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: If `False`, allows the variable to be initialized with a
@@ -2283,6 +2494,15 @@ def variable_creator_scope(variable_creator):
constraint: A constraint function to be applied to the variable after
updates by some algorithms.
use_resource: if True, a ResourceVariable is always created.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
This set may grow over time, so it's important the signature of creators is as
mentioned above.
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 4be9f5eb68..d3b8da6d2a 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1093,39 +1093,40 @@ class Variable(checkpointable.CheckpointableBase):
def __imul__(self, other):
logging.log_first_n(
logging.WARN,
- "Variable *= will be deprecated. Use variable.assign_mul"
- " if you want assignment to the variable value or 'x = x * y'"
+ "Variable *= will be deprecated. Use `var.assign(var * other)`"
+ " if you want assignment to the variable value or `x = x * y`"
" if you want a new python Tensor object.", 1)
return self * other
def __idiv__(self, other):
logging.log_first_n(
logging.WARN,
- "Variable /= will be deprecated. Use variable.assign_div"
- " if you want assignment to the variable value or 'x = x / y'"
+ "Variable /= will be deprecated. Use `var.assign(var / other)`"
+ " if you want assignment to the variable value or `x = x / y`"
" if you want a new python Tensor object.", 1)
return self / other
def __itruediv__(self, other):
logging.log_first_n(
logging.WARN,
- "Variable /= will be deprecated. Use variable.assign_div"
- " if you want assignment to the variable value or 'x = x / y'"
+ "Variable /= will be deprecated. Use `var.assign(var / other)`"
+ " if you want assignment to the variable value or `x = x / y`"
" if you want a new python Tensor object.", 1)
return self / other
def __irealdiv__(self, other):
logging.log_first_n(
logging.WARN,
- "Variable /= will be deprecated. Use variable.assign_div"
- " if you want assignment to the variable value or 'x = x / y'"
+ "Variable /= will be deprecated. Use `var.assign(var / other)`"
+ " if you want assignment to the variable value or `x = x / y`"
" if you want a new python Tensor object.", 1)
return self / other
def __ipow__(self, other):
logging.log_first_n(
logging.WARN,
- "Variable **= will be deprecated. Use 'x = x ** y'"
+ "Variable **= will be deprecated. Use `var.assign(var ** other)`"
+ " if you want assignment to the variable value or `x = x ** y`"
" if you want a new python Tensor object.", 1)
return self ** other
@@ -1403,6 +1404,10 @@ class PartitionedVariable(object):
def dtype(self):
return self._dtype
+ @property
+ def shape(self):
+ return self.get_shape()
+
def get_shape(self):
return self._shape
@@ -1722,6 +1727,8 @@ def report_uninitialized_variables(var_list=None,
var_list.append(op.outputs[0])
with ops.name_scope(name):
# Run all operations on CPU
+ if var_list:
+ init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
with ops.device("/cpu:0"):
if not var_list:
# Return an empty tensor so we only need to check for returned tensor
@@ -1729,9 +1736,7 @@ def report_uninitialized_variables(var_list=None,
return array_ops.constant([], dtype=dtypes.string)
else:
# Get a 1-D boolean tensor listing whether each variable is initialized.
- variables_mask = math_ops.logical_not(
- array_ops.stack(
- [state_ops.is_variable_initialized(v) for v in var_list]))
+ variables_mask = math_ops.logical_not(array_ops.stack(init_vars))
# Get a 1-D string tensor containing all the variable names.
variable_names_tensor = array_ops.constant(
[s.op.name for s in var_list])
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index eba2baaf6f..fa17b17d10 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -66,11 +66,11 @@ def _global_report_benchmark(
if not isinstance(extras, dict):
raise TypeError("extras must be a dict")
- logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
- "throughput: %g %s", name, iters if iters is not None else -1,
- wall_time if wall_time is not None else -1, cpu_time if
- cpu_time is not None else -1, throughput if
- throughput is not None else -1, str(extras) if extras else "")
+ logging.info("Benchmark [%s] iters: %d, wall_time: %g, cpu_time: %g,"
+ "throughput: %g %s", name, iters if iters is not None else -1,
+ wall_time if wall_time is not None else -1, cpu_time if
+ cpu_time is not None else -1, throughput if
+ throughput is not None else -1, str(extras) if extras else "")
entries = test_log_pb2.BenchmarkEntries()
entry = entries.entry.add()
diff --git a/tensorflow/python/platform/self_check.py b/tensorflow/python/platform/self_check.py
index 966a094e55..844ae99918 100644
--- a/tensorflow/python/platform/self_check.py
+++ b/tensorflow/python/platform/self_check.py
@@ -78,7 +78,7 @@ def preload_check():
"Could not find %r. TensorFlow requires that this DLL be "
"installed in a directory that is named in your %%PATH%% "
"environment variable. Download and install CUDA %s from "
- "this URL: https://developer.nvidia.com/cuda-toolkit"
+ "this URL: https://developer.nvidia.com/cuda-90-download-archive"
% (build_info.cudart_dll_name, build_info.cuda_version_number))
if hasattr(build_info, "cudnn_dll_name") and hasattr(
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 9e49188c1e..f9891f3b1e 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -707,8 +707,10 @@ class PrintModelAnalysisTest(test.TestCase):
a = array_ops.constant(np.ones((100, 100)))
b = array_ops.constant(np.ones((100, 100)))
c = a * b
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.min_graph_nodes = -1
- with session.Session() as sess:
+ with session.Session(config=config) as sess:
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 500dc30cc3..5d7535cf34 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -59,6 +59,7 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetConfig;
%rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
%rename("%s") TFE_ContextOptionsSetAsync;
+%rename("%s") TFE_ContextOptionsSetServerDef;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 2609a5d222..076f2d8760 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -87,6 +87,30 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+)
+
+py_test(
+ name = "loader_test",
+ size = "small",
+ srcs = ["loader_test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":builder",
+ ":loader",
+ ":signature_def_utils",
+ ":utils",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
],
)
@@ -149,6 +173,7 @@ py_test(
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 24a13c0f33..e58be804c2 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -270,6 +270,18 @@ class SavedModelBuilder(object):
self._add_train_op(train_op)
+ def _maybe_create_saver(self, saver=None):
+ """Creates a sharded saver if one does not already exist."""
+ if not saver:
+ # Initialize a saver to generate a sharded output for all saveables in the
+ # current scope.
+ saver = tf_saver.Saver(
+ variables._all_saveable_objects(), # pylint: disable=protected-access
+ sharded=True,
+ write_version=saver_pb2.SaverDef.V2,
+ allow_empty=True)
+ return saver
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -277,7 +289,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel.
@@ -302,6 +315,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph. If None, a sharded Saver that restores all variables will
+ be used.
Raises:
AssertionError: If the variables for the SavedModel have not been saved
@@ -320,18 +336,11 @@ class SavedModelBuilder(object):
# Add assets and ops
self._add_collections(assets_collection, legacy_init_op, main_op, None)
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
@@ -350,7 +359,8 @@ class SavedModelBuilder(object):
legacy_init_op=None,
clear_devices=False,
main_op=None,
- strip_default_attrs=False):
+ strip_default_attrs=False,
+ saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel and saves variables.
@@ -377,6 +387,9 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ saver: An instance of tf.train.Saver that will be used to export the
+ metagraph and save variables. If None, a sharded Saver that restores
+ all variables will be used.
"""
# pylint: enable=line-too-long
@@ -403,13 +416,7 @@ class SavedModelBuilder(object):
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))
- # Initialize a saver to generate a sharded output for all saveables in the
- # current scope.
- saver = tf_saver.Saver(
- variables._all_saveable_objects(), # pylint: disable=protected-access
- sharded=True,
- write_version=saver_pb2.SaverDef.V2,
- allow_empty=True)
+ saver = self._maybe_create_saver(saver)
# Save the variables. Also, disable writing the checkpoint state proto. The
# file is not used during SavedModel loading. In addition, since a
@@ -421,8 +428,7 @@ class SavedModelBuilder(object):
# The graph almost certainly previously contained at least one Saver, and
# possibly several (e.g. one for loading a pretrained embedding, and another
- # for the model weights). However, a *new* Saver was just created that
- # includes all of the variables. Removing the preexisting ones was the
+ # for the model weights). Removing the preexisting ones was the
# motivation for the clear_extraneous_savers option, but it turns out that
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index bebf1d5e0d..e5f649fdab 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -28,6 +28,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
@@ -79,12 +80,14 @@ def _parse_saved_model(export_dir):
constants.SAVED_MODEL_FILENAME_PB))
-def _get_asset_tensors(export_dir, meta_graph_def_to_load):
+def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
"""Gets the asset tensors, if defined in the meta graph def to load.
Args:
export_dir: Directory where the SavedModel is located.
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
+ import_scope: Optional `string` -- if specified, prepend this followed by
+ '/' to all returned asset tensor names.
Returns:
A dictionary of asset tensors, keyed by the name of the asset tensor. The
@@ -104,7 +107,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
for asset_any_proto in assets_any_proto:
asset_proto = meta_graph_pb2.AssetFileDef()
asset_any_proto.Unpack(asset_proto)
- asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
+ tensor_name = asset_proto.tensor_info.name
+ if import_scope:
+ tensor_name = "%s/%s" % (import_scope, tensor_name)
+ asset_tensor_dict[tensor_name] = os.path.join(
compat.as_bytes(assets_directory),
compat.as_bytes(asset_proto.filename))
return asset_tensor_dict
@@ -179,7 +185,7 @@ def maybe_saved_model_directory(export_dir):
@tf_export("saved_model.loader.load")
-def load(sess, tags, export_dir, **saver_kwargs):
+def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
Args:
@@ -189,6 +195,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
SavedModel `save()` API.
export_dir: Directory in which the SavedModel protocol buffer and variables
to be loaded are located.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
**saver_kwargs: Optional keyword arguments passed through to Saver.
Returns:
@@ -198,11 +208,56 @@ def load(sess, tags, export_dir, **saver_kwargs):
Raises:
RuntimeError: MetaGraphDef associated with the tags cannot be found.
"""
- with sess.graph.as_default():
- # Build the SavedModel protocol buffer and find requested meta graph def.
- saved_model = _parse_saved_model(export_dir)
+ loader = SavedModelLoader(export_dir)
+ return loader.load(sess, tags, import_scope, **saver_kwargs)
+
+
+class SavedModelLoader(object):
+ """Load graphs and restore variable values from a `SavedModel`."""
+
+ def __init__(self, export_dir):
+ """Creates a `SavedModelLoader`.
+
+ Args:
+ export_dir: Directory in which the SavedModel protocol buffer and
+ variables to be loaded are located.
+ """
+ self._export_dir = export_dir
+ self._variables_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY),
+ compat.as_bytes(constants.VARIABLES_FILENAME))
+ self._saved_model = _parse_saved_model(export_dir)
+
+ @property
+ def export_dir(self):
+ """Directory containing the SavedModel."""
+ return self._export_dir
+
+ @property
+ def variables_path(self):
+ """Path to variable checkpoint files."""
+ return self._variables_path
+
+ @property
+ def saved_model(self):
+ """SavedModel object parsed from the export directory."""
+ return self._saved_model
+
+ def get_meta_graph_def_from_tags(self, tags):
+ """Return MetaGraphDef with the exact specified tags.
+
+ Args:
+ tags: A list or set of string tags that identify the MetaGraphDef.
+
+ Returns:
+ MetaGraphDef with the same tags.
+
+ Raises:
+ RuntimeError: if no metagraphs were found with the associated tags.
+ """
found_match = False
- for meta_graph_def in saved_model.meta_graphs:
+ for meta_graph_def in self._saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set(tags):
meta_graph_def_to_load = meta_graph_def
found_match = True
@@ -214,31 +269,100 @@ def load(sess, tags, export_dir, **saver_kwargs):
" could not be found in SavedModel. To inspect available tag-sets in"
" the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
)
-
- # Build a saver by importing the meta graph def to load.
- saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
-
- if saver:
- # Build the checkpoint path where the variables are located.
- variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
-
- # Restore the variables using the built saver in the provided session.
- saver.restore(sess, variables_path)
- else:
- tf_logging.info("The specified SavedModel has no variables; no "
- "checkpoints were restored.")
-
- # Get asset tensors, if any.
- asset_tensors_dictionary = _get_asset_tensors(export_dir,
- meta_graph_def_to_load)
-
- main_op_tensor = (
- _get_main_op_tensor(meta_graph_def_to_load) or
- (_get_legacy_init_op_tensor(meta_graph_def_to_load)))
- if main_op_tensor is not None:
- sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
-
return meta_graph_def_to_load
+
+ def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
+ """Load ops and nodes from SavedModel MetaGraph into graph.
+
+ Args:
+ graph: tf.Graph object.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
+
+ Returns:
+ Saver defined by the MetaGraph, which can be used to restore the variable
+ values.
+ """
+ meta_graph_def = self.get_meta_graph_def_from_tags(tags)
+ with graph.as_default():
+ return tf_saver.import_meta_graph(
+ meta_graph_def, import_scope=import_scope, **saver_kwargs)
+
+ def restore_variables(self, sess, saver, import_scope=None):
+ """Restore SavedModel variable values into the session.
+
+ Args:
+ sess: tf.Session to restore variable values.
+ saver: a tf.train.Saver object. Can be None if there are no variables in
+ graph. This may be the saver returned by the load_graph() function, or a
+ default `tf.train.Saver()`.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+
+ Raises:
+ ValueError: if no saver was passed to the saver argument, and there are
+ variables in the graph.
+ """
+ with sess.graph.as_default():
+ if (saver is None and
+ not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access
+ tf_logging.info("The specified SavedModel has no variables; no "
+ "checkpoints were restored.")
+ elif isinstance(saver, tf_saver.Saver):
+ saver.restore(sess, self._variables_path)
+ else:
+ raise ValueError(
+ "No tf.train.Saver object was passed to the function "
+ "SavedModelLoader.restore_variables. Since there are variables in "
+ "the graph, a saver is required.")
+
+ def run_init_ops(self, sess, tags, import_scope=None):
+ """Run initialization ops defined in the `MetaGraphDef`.
+
+ Args:
+ sess: tf.Session to restore variable values.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ """
+ meta_graph_def = self.get_meta_graph_def_from_tags(tags)
+ with sess.graph.as_default():
+ # Get asset tensors, if any.
+ asset_tensors_dictionary = _get_asset_tensors(
+ self._export_dir, meta_graph_def, import_scope=import_scope)
+
+ main_op_tensor = (
+ _get_main_op_tensor(meta_graph_def) or
+ (_get_legacy_init_op_tensor(meta_graph_def)))
+ if main_op_tensor is not None:
+ sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
+
+ def load(self, sess, tags, import_scope=None, **saver_kwargs):
+ """Load the MetaGraphDef graph and restore variable values into the session.
+
+ Args:
+ sess: tf.Session to restore variable values.
+ tags: a set of string tags identifying a MetaGraphDef.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
+ **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.
+
+ Returns:
+ `MetagraphDef` proto of the graph that was loaded.
+ """
+ with sess.graph.as_default():
+ saver = self.load_graph(sess.graph, tags, import_scope,
+ **saver_kwargs)
+ self.restore_variables(sess, saver, import_scope)
+ self.run_init_ops(sess, tags, import_scope)
+ return self.get_meta_graph_def_from_tags(tags)
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
new file mode 100644
index 0000000000..ce18859f6b
--- /dev/null
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -0,0 +1,217 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModelLoader class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import utils
+from tensorflow.python.training import saver as tf_saver
+
+
+def _get_export_dir(label):
+ return os.path.join(test.get_temp_dir(), label)
+
+SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
+SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
+
+
+class SavedModelLoaderTest(test.TestCase):
+
+ def setUp(self):
+ """Write test SavedModels to a temp directory."""
+ with session.Session(graph=ops.Graph()) as sess:
+ x = variables.Variable(5, name="x")
+ y = variables.Variable(11, name="y")
+ z = x + y
+ sess.run(variables.global_variables_initializer())
+
+ foo_sig_def = signature_def_utils.build_signature_def(
+ {"foo_input": utils.build_tensor_info(x)},
+ {"foo_output": utils.build_tensor_info(z)})
+ bar_sig_def = signature_def_utils.build_signature_def(
+ {"bar_x": utils.build_tensor_info(x),
+ "bar_y": utils.build_tensor_info(y)},
+ {"bar_z": utils.build_tensor_info(z)})
+
+ builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
+ builder.save()
+
+ # Write SavedModel with a main_op
+ assign_op = control_flow_ops.group(state_ops.assign(y, 7))
+
+ builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
+ main_op=assign_op)
+ builder.save()
+
+ def tearDown(self):
+ file_io.delete_recursively(test.get_temp_dir())
+
+ def test_load_function(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+ loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader2.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
+
+ def test_load_graph(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ graph = ops.Graph()
+ loader.load_graph(graph, ["foo_graph"])
+
+ x = graph.get_tensor_by_name("x:0")
+ y = graph.get_tensor_by_name("y:0")
+
+ with self.assertRaises(KeyError):
+ graph.get_tensor_by_name("z:0")
+
+ with self.test_session(graph=graph) as sess:
+ # Check that x and y are not initialized
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(x)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(y)
+
+ def test_load_with_import_scope(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ saver = loader.load_graph(sess.graph, ["foo_graph"], import_scope="baz")
+
+ # The default saver should not work when the import scope is set.
+ with self.assertRaises(errors.NotFoundError):
+ loader.restore_variables(sess, tf_saver.Saver())
+
+ loader.restore_variables(sess, saver)
+ loader.run_init_ops(sess, ["foo_graph"])
+
+ self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("baz/y:0").eval())
+
+ # Test combined load function.
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"], import_scope="baa")
+ self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
+
+ def test_restore_variables(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ with self.test_session(graph=ops.Graph()) as sess:
+ x = variables.Variable(0, name="x")
+ y = variables.Variable(0, name="y")
+ z = x * y
+
+ sess.run(variables.global_variables_initializer())
+
+ # There are variables to restore, so a saver must be created.
+ with self.assertRaises(ValueError):
+ loader.restore_variables(sess, None)
+
+ loader.restore_variables(sess, tf_saver.Saver())
+ self.assertEqual(55, z.eval())
+
+ def test_run_init_op(self):
+ loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
+ graph = ops.Graph()
+ saver = loader.load_graph(graph, ["foo_graph"])
+ with self.test_session(graph=graph) as sess:
+ loader.restore_variables(sess, saver)
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+ loader.run_init_ops(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
+
+ def test_parse_saved_model(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
+ self.assertIsNotNone(meta_graph)
+ self.assertIn("foo", meta_graph.signature_def)
+ self.assertIn("bar", meta_graph.signature_def)
+
+ def test_load_invalid_meta_graph(self):
+ loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags([])
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags([""])
+ with self.assertRaises(RuntimeError):
+ loader.get_meta_graph_def_from_tags(["not_a_graph"])
+
+ def test_load_saved_model_with_no_variables(self):
+ """Test that SavedModel runs saver when there appear to be no variables.
+
+ When no variables are detected, this may mean that the variables were saved
+ to different collections, or the collections weren't saved to the
+ SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
+ run in either of these cases.
+ """
+ path = _get_export_dir("no_variable_saved_model")
+ with session.Session(graph=ops.Graph()) as sess:
+ x = variables.Variable(5, name="x", collections=["not_global_variable"])
+ y = variables.Variable(11, name="y", collections=["not_global_variable"])
+ self.assertFalse(variables._all_saveable_objects())
+ z = x + y
+ sess.run(variables.variables_initializer([x, y]))
+
+ foo_sig_def = signature_def_utils.build_signature_def(
+ {"foo_input": utils.build_tensor_info(x)},
+ {"foo_output": utils.build_tensor_info(z)})
+
+ builder = saved_model_builder.SavedModelBuilder(path)
+ builder.add_meta_graph_and_variables(
+ sess, ["foo_graph"], {"foo": foo_sig_def},
+ saver=tf_saver.Saver([x, y]))
+ builder.save()
+
+ loader = loader_impl.SavedModelLoader(path)
+ with self.test_session(graph=ops.Graph()) as sess:
+ saver = loader.load_graph(sess.graph, ["foo_graph"])
+ self.assertFalse(variables._all_saveable_objects())
+ self.assertIsNotNone(saver)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo_graph"])
+ self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
+ self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 7302c77ad5..fb4732aca2 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -44,6 +44,7 @@ from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training
from tensorflow.python.util import compat
SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
@@ -1122,6 +1123,133 @@ class SavedModelTest(test.TestCase):
self.assertEqual(b"k1", v1.keys().eval())
self.assertEqual(3.0, v1.values().eval())
+ def testCustomSaver(self):
+ export_dir = self._get_export_dir("test_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ custom_saver = training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertFalse("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "my_saver/restore_all")
+
+ def testNoCustomSaver(self):
+ export_dir = self._get_export_dir("test_no_custom_saver")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ training.Saver(name="my_saver")
+ builder.add_meta_graph_and_variables(sess, ["tag"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, ["tag"], export_dir)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue("my_saver/restore_all" in graph_ops)
+ self.assertTrue("save/restore_all" in graph_ops)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name, "save/restore_all")
+
+ def testMultipleCustomSavers(self):
+ export_dir = self._get_export_dir("test_multiple_custom_savers")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ variables.Variable(1, name="v1")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(sess, ["tag_0"])
+
+ saver_1 = training.Saver()
+ builder.add_meta_graph(["tag_1"], saver=saver_1)
+
+ saver_2 = training.Saver()
+ builder.add_meta_graph(["tag_2"], saver=saver_2)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ def _validate_custom_saver(tag_name, saver_name):
+ with ops.Graph().as_default() as graph:
+ with self.test_session(graph=graph) as sess:
+ saved_graph = loader.load(sess, [tag_name], export_dir)
+ self.assertEqual(
+ saved_graph.saver_def.restore_op_name,
+ saver_name)
+
+ _validate_custom_saver("tag_0", "save/restore_all")
+ _validate_custom_saver("tag_1", "save_1/restore_all")
+ _validate_custom_saver("tag_2", "save_2/restore_all")
+
+ def testImportScope(self):
+ export_dir = self._get_export_dir("test_scoped_assets")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Build a SavedModel with a variable, an asset, and a constant tensor.
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+ asset_collection = self._build_asset_collection("foo.txt", "content_foo",
+ "asset_file_tensor")
+ constant_op.constant("constant value", name="constant_tensor_name")
+ builder.add_meta_graph_and_variables(
+ sess, ["tag_name"], assets_collection=asset_collection)
+
+ # Save the asset file path for later comparison.
+ asset_file_path = asset_collection[0].eval()
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Restore the SavedModel under an import_scope in a new graph/session.
+ graph_proto = loader.load(
+ sess, ["tag_name"], export_dir, import_scope="scope_name")
+
+ # The loaded variable tensor should be scoped, but its contents should be
+ # unchanged.
+ self.assertEqual(
+ "scope_name/v:0",
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
+ self.assertEqual(
+ 42,
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
+ # The loaded asset tensor should be scoped, but the asset file path and
+ # contents should be unchanged.
+ asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
+ self.assertEqual(1, len(asset_collection))
+ self.assertEqual(asset_file_path, asset_collection[0].eval())
+ self.assertEqual("scope_name/asset_file_tensor:0",
+ asset_collection[0].name)
+ # The static asset data inside graph_proto.collection_def should not be
+ # scoped.
+ self._validate_asset_collection(export_dir, graph_proto.collection_def,
+ "foo.txt", "content_foo",
+ "asset_file_tensor:0")
+
+ # The constant tensor should be scoped, but its contents should be
+ # unchanged.
+ self.assertEqual(
+ compat.as_bytes("constant value"),
+ ops.get_default_graph().get_tensor_by_name(
+ "scope_name/constant_tensor_name:0").eval())
+
def testClearDevices(self):
export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py
index c08e3cca00..95eca76496 100644
--- a/tensorflow/python/training/adadelta.py
+++ b/tensorflow/python/training/adadelta.py
@@ -46,6 +46,13 @@ class AdadeltaOptimizer(optimizer.Optimizer):
use_locking: If `True` use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
+ each be a callable that takes no arguments and returns the actual value to
+ use. This can be useful for changing these values across different
+ invocations of optimizer functions.
+ @end_compatibility
"""
super(AdadeltaOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
@@ -63,9 +70,13 @@ class AdadeltaOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "accum_update", self._name)
def _prepare(self):
- self._lr_t = ops.convert_to_tensor(self._lr, name="lr")
- self._rho_t = ops.convert_to_tensor(self._rho, name="rho")
- self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._lr)
+ rho = self._call_if_callable(self._rho)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name="lr")
+ self._rho_t = ops.convert_to_tensor(rho, name="rho")
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
accum = self.get_slot(var, "accum")
diff --git a/tensorflow/python/training/adadelta_test.py b/tensorflow/python/training/adadelta_test.py
index 50f435236b..2678016d24 100644
--- a/tensorflow/python/training/adadelta_test.py
+++ b/tensorflow/python/training/adadelta_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -32,44 +34,52 @@ from tensorflow.python.training import adadelta
class AdadeltaOptimizerTest(test.TestCase):
- def doTestBasic(self, use_resource=False):
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
num_updates = 4 # number of ADADELTA steps to perform
for dtype in [dtypes.half, dtypes.float32]:
for grad in [0.2, 0.1, 0.01]:
for lr in [1.0, 0.5, 0.1]:
- with self.test_session():
- var0_init = [1.0, 2.0]
- var1_init = [3.0, 4.0]
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable(
- var0_init, dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable(
- var1_init, dtype=dtype)
- else:
- var0 = variables.Variable(var0_init, dtype=dtype)
- var1 = variables.Variable(var1_init, dtype=dtype)
-
- grads = constant_op.constant([grad, grad], dtype=dtype)
-
- accum = 0.0
- accum_update = 0.0
-
- # ADADELTA gradient optimizer
- rho = 0.95
- epsilon = 1e-8
- adadelta_opt = adadelta.AdadeltaOptimizer(lr, rho, epsilon)
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ if use_callable_params:
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lambda: lr, # pylint: disable=cell-var-from-loop
+ rho=lambda: rho, # pylint: disable=cell-var-from-loop
+ epsilon=lambda: epsilon) # pylint: disable=cell-var-from-loop
+ else:
+ adadelta_opt = adadelta.AdadeltaOptimizer(
+ learning_rate=lr, rho=rho, epsilon=epsilon)
+ if not context.executing_eagerly():
adadelta_update = adadelta_opt.apply_gradients(
zip([grads, grads], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ # TODO(lxuechen): This is hard to test in eager mode,
+ # since the optimizer is not fully initialized until the first
+ # call to `apply_gradients`
opt_vars = adadelta_opt.variables()
self.assertStartsWith(opt_vars[0].name, var0._shared_name)
self.assertStartsWith(opt_vars[1].name, var0._shared_name)
self.assertStartsWith(opt_vars[2].name, var1._shared_name)
self.assertStartsWith(opt_vars[3].name, var1._shared_name)
self.assertEqual(4, len(opt_vars))
-
- variables.global_variables_initializer().run()
-
# Assign slots
slot = [None] * 2
slot_update = [None] * 2
@@ -91,36 +101,42 @@ class AdadeltaOptimizerTest(test.TestCase):
self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
self.assertFalse(slot_update[1] in variables.trainable_variables())
- # Fetch params to validate initial values
- self.assertAllClose(var0_init, var0.eval())
- self.assertAllClose(var1_init, var1.eval())
-
- update = [None] * num_updates
- tot_update = 0
- for step in range(num_updates):
- # Run adadelta update for comparison
- adadelta_update.run()
-
- # Perform initial update without previous accum values
- accum = accum * rho + (grad**2) * (1 - rho)
- update[step] = (np.sqrt(accum_update + epsilon) *
- (1. / np.sqrt(accum + epsilon)) * grad)
- accum_update = (accum_update * rho + (update[step]**2) *
- (1.0 - rho))
- tot_update += update[step] * lr
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, self.evaluate(var0))
+ self.assertAllClose(var1_init, self.evaluate(var1))
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ if not context.executing_eagerly():
+ self.evaluate(adadelta_update)
+ else:
+ adadelta_opt.apply_gradients(zip([grads, grads], [var0, var1]))
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (
+ np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (
+ accum_update * rho + (update[step]**2) * (1.0 - rho))
+ tot_update += update[step] * lr
+
+ if not context.executing_eagerly():
# Check that the accumulators have been updated
+ # TODO(lxuechen): This is hard to test in eager mode
for slot_idx in range(2):
self.assertAllCloseAccordingToType(
np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
- slot[slot_idx].eval(),
+ self.evaluate(slot[slot_idx]),
rtol=1e-5)
self.assertAllCloseAccordingToType(
np.array(
[accum_update, accum_update],
dtype=dtype.as_numpy_dtype()),
- slot_update[slot_idx].eval(),
+ self.evaluate(slot_update[slot_idx]),
rtol=1e-5)
# Check that the parameters have been updated
@@ -128,22 +144,28 @@ class AdadeltaOptimizerTest(test.TestCase):
np.array(
[var0_init[0] - tot_update, var0_init[1] - tot_update],
dtype=dtype.as_numpy_dtype()),
- var0.eval(),
+ self.evaluate(var0),
rtol=1e-5)
self.assertAllCloseAccordingToType(
np.array(
[var1_init[0] - tot_update, var1_init[1] - tot_update],
dtype=dtype.as_numpy_dtype()),
- var1.eval(),
+ self.evaluate(var1),
rtol=1e-5)
def testBasic(self):
- self.doTestBasic(use_resource=False)
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index deb4e6f546..6778f3c735 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -51,6 +51,13 @@ class AdagradOptimizer(optimizer.Optimizer):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate` can be a callable that
+ takes no arguments and returns the actual value to use. This can be useful
+ for changing these values across different invocations of optimizer
+ functions.
+ @end_compatibility
"""
if initial_accumulator_value <= 0.0:
raise ValueError("initial_accumulator_value must be positive: %s" %
@@ -78,8 +85,9 @@ class AdagradOptimizer(optimizer.Optimizer):
"accumulator", self._name)
def _prepare(self):
- self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
- name="learning_rate")
+ learning_rate = self._call_if_callable(self._learning_rate)
+ self._learning_rate_tensor = ops.convert_to_tensor(
+ learning_rate, name="learning_rate")
def _apply_dense(self, grad, var):
acc = self.get_slot(var, "accumulator")
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
index 15b007b46d..c9aec33d09 100644
--- a/tensorflow/python/training/adagrad_test.py
+++ b/tensorflow/python/training/adagrad_test.py
@@ -20,9 +20,11 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -34,40 +36,63 @@ from tensorflow.python.training import adagrad
class AdagradOptimizerTest(test.TestCase):
- def doTestBasic(self, use_locking=False, use_resource=False):
+ def doTestBasic(self,
+ use_locking=False,
+ use_resource=False,
+ use_callable_params=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- else:
- var0 = variables.Variable([1.0, 2.0], dtype=dtype)
- var1 = variables.Variable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- ada_opt = adagrad.AdagradOptimizer(
- 3.0, initial_accumulator_value=0.1, use_locking=use_locking)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ learning_rate = lambda: 3.0
+ if not use_callable_params:
+ learning_rate = learning_rate()
+
+ ada_opt = adagrad.AdagradOptimizer(
+ learning_rate, initial_accumulator_value=0.1, use_locking=use_locking)
+
+ if not context.executing_eagerly():
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
- # Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of adagrad
- for _ in range(3):
- ada_update.run()
- # Validate updated params
- self.assertAllCloseAccordingToType(
- np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+ self.evaluate(variables.global_variables_initializer())
+
+ # Fetch params to validate initial values
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([3.0, 4.0], v1_val)
+
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ if not context.executing_eagerly():
+ self.evaluate(ada_update)
+ else:
+ ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ # Validate updated params
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), v1_val)
def testBasic(self):
self.doTestBasic(use_locking=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(
+ use_locking=False, use_resource=True, use_callable_params=True)
+
def testBasicLocked(self):
self.doTestBasic(use_locking=True)
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 6fa3ff6658..b65c88e972 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -85,6 +85,13 @@ class AdamOptimizer(optimizer.Optimizer):
use_locking: If True use locks for update operations.
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
"""
super(AdamOptimizer, self).__init__(use_locking, name)
self._lr = learning_rate
@@ -128,10 +135,15 @@ class AdamOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "v", self._name)
def _prepare(self):
- self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
- self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
- self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
- self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._lr)
+ beta1 = self._call_if_callable(self._beta1)
+ beta2 = self._call_if_callable(self._beta2)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
+ self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
+ self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
+ self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
m = self.get_slot(var, "m")
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index bc68f24c6f..ccdc7e384d 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -150,7 +150,7 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
- def doTestBasic(self, use_resource=False):
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
with self.test_session(graph=ops.Graph()):
# Initialize variables for numpy implementation.
@@ -171,7 +171,17 @@ class AdamOptimizerTest(test.TestCase):
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
- opt = adam.AdamOptimizer()
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+ opt = adam.AdamOptimizer(learning_rate=learning_rate)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
opt_variables = opt.variables()
beta1_power, beta2_power = opt._get_beta_accumulators()
@@ -221,6 +231,10 @@ class AdamOptimizerTest(test.TestCase):
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py
index e7f88de1d2..5b372e82b3 100644
--- a/tensorflow/python/training/checkpoint_utils.py
+++ b/tensorflow/python/training/checkpoint_utils.py
@@ -147,7 +147,7 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
partitioner=lambda shape, dtype: [5, 1])
# Initialize all variables in `new_scope_1` from `old_scope_1`.
- init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/', 'new_scope_1'})
+ init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1'})
# Use names to specify which variables to initialize from checkpoint.
init_from_checkpoint('/tmp/model.ckpt',
@@ -219,8 +219,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
else:
var_name = ",".join([v.name for v in var])
_set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
- logging.info("Initialize variable %s from checkpoint %s with %s",
- var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
+ logging.debug("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
else:
scopes = ""
# TODO(vihanjain): Support list of 'current_var_or_name' here.
@@ -261,8 +261,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
if var is None:
var = _collect_partitioned_variable(var_name, store_vars)
_set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
- logging.info("Initialize variable %s from checkpoint %s with %s",
- var_name, ckpt_dir_or_file, full_tensor_name)
+ logging.debug("Initialize variable %s from checkpoint %s with %s",
+ var_name, ckpt_dir_or_file, full_tensor_name)
def _get_checkpoint_filename(ckpt_dir_or_file):
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index 87ba4dc91c..35007653a0 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -42,21 +42,39 @@ py_test(
)
py_library(
- name = "data_structures_base",
- srcs = ["data_structures_base.py"],
+ name = "tracking",
+ srcs = ["tracking.py"],
srcs_version = "PY2AND3",
deps = [
":base",
+ ":data_structures",
],
)
+py_test(
+ name = "tracking_test",
+ srcs = ["tracking_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ":tracking",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "layer_utils",
+ srcs = ["layer_utils.py"],
+ srcs_version = "PY2AND3",
+)
+
py_library(
name = "data_structures",
srcs = ["data_structures.py"],
srcs_version = "PY2AND3",
deps = [
":base",
- ":data_structures_base",
+ ":layer_utils",
],
)
@@ -83,6 +101,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":base",
+ ":tracking",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index cfe7259e1b..ee35b01328 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -33,6 +33,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
+from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
@@ -340,6 +341,34 @@ _SlotVariableRestoration = collections.namedtuple(
])
+def no_automatic_dependency_tracking(method):
+ """Disables automatic dependency tracking on attribute assignment.
+
+ Use to decorate any method of a Checkpointable object. Attribute assignment in
+ that method will not add dependencies (also respected in Model). Harmless if
+ used in a class which does not do automatic dependency tracking (which means
+ it's safe to use in base classes which may have subclasses which also inherit
+ from Checkpointable).
+
+ Args:
+ method: The method to decorate.
+ Returns:
+ A decorated method which sets and un-sets automatic dependency tracking for
+ the object the method is called on (not thread safe).
+ """
+
+ def _method_wrapper(self, *args, **kwargs):
+ previous_value = getattr(self, "_setattr_tracking", True)
+ self._setattr_tracking = False # pylint: disable=protected-access
+ try:
+ method(self, *args, **kwargs)
+ finally:
+ self._setattr_tracking = previous_value # pylint: disable=protected-access
+
+ return tf_decorator.make_decorator(
+ target=method, decorator_func=_method_wrapper)
+
+
class CheckpointableBase(object):
"""Base class for `Checkpointable` objects without automatic dependencies.
@@ -349,6 +378,11 @@ class CheckpointableBase(object):
checks.
"""
+ # CheckpointableBase does not do automatic dependency tracking, but uses the
+ # no_automatic_dependency_tracking decorator so it can avoid adding
+ # dependencies if a subclass is Checkpointable / inherits from Model (both of
+ # which have __setattr__ overrides).
+ @no_automatic_dependency_tracking
def _maybe_initialize_checkpointable(self):
"""Initialize dependency management.
@@ -386,6 +420,10 @@ class CheckpointableBase(object):
# building.
self._name_based_restores = set()
+ def _no_dependency(self, value):
+ """If automatic dependency tracking is enabled, ignores `value`."""
+ return value
+
def _name_based_attribute_restore(self, checkpoint):
"""Restore the object's attributes from a name-based checkpoint."""
self._name_based_restores.add(checkpoint)
@@ -463,7 +501,7 @@ class CheckpointableBase(object):
ValueError: If the variable name is not unique.
"""
self._maybe_initialize_checkpointable()
- if not overwrite and self._lookup_dependency(name) is not None:
+ if overwrite and self._lookup_dependency(name) is not None:
raise ValueError(
("A variable named '%s' already exists in this Checkpointable, but "
"Checkpointable._add_variable called to create another with "
@@ -593,9 +631,9 @@ class CheckpointableBase(object):
self._unconditional_checkpoint_dependencies[index] = new_reference
elif current_object is None:
self._unconditional_checkpoint_dependencies.append(new_reference)
- self._unconditional_dependency_names[name] = checkpointable
self._handle_deferred_dependencies(
name=name, checkpointable=checkpointable)
+ self._unconditional_dependency_names[name] = checkpointable
return checkpointable
def _handle_deferred_dependencies(self, name, checkpointable):
@@ -733,86 +771,3 @@ class CheckpointableBase(object):
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
-
-
-class NoDependency(object):
- """Allows attribute assignment to `Checkpointable` objects with no dependency.
-
- Example usage:
- ```python
- obj = Checkpointable()
- obj.has_dependency = tf.Variable(0., name="dep")
- obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
- assert obj.no_dependency.name == "nodep:0"
- ```
-
- `obj` in this example has a dependency on the variable "dep", and both
- attributes contain un-wrapped `Variable` objects.
-
- `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
- dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
- `Layer` to the attribute without a checkpoint dependency, but the `Model` will
- still track the `Layer` (so it will appear in `Model.layers`, and its
- variables will appear in `Model.variables`).
- """
-
- def __init__(self, value):
- self.value = value
-
-
-class NotCheckpointable(object):
- """Marks instances of child classes as unsaveable using an object-based API.
-
- Useful for marking objects which would otherwise look checkpointable because
- of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from
- `NotCheckpointable` does not prevent an object from being assigned to any
- attributes, but will throw an error on save/restore.
- """
- pass
-
-
-class Checkpointable(CheckpointableBase):
- """Manages dependencies on other objects.
-
- `Checkpointable` objects may have dependencies: other `Checkpointable` objects
- which should be saved if the object declaring the dependency is saved. A
- correctly saveable program has a dependency graph such that if changing a
- global variable affects an object (e.g. changes the behavior of any of its
- methods) then there is a chain of dependencies from the influenced object to
- the variable.
-
- Dependency edges have names, and are created implicitly when a
- `Checkpointable` object is assigned to an attribute of another
- `Checkpointable` object. For example:
-
- ```
- obj = Checkpointable()
- obj.v = ResourceVariable(0.)
- ```
-
- The `Checkpointable` object `obj` now has a dependency named "v" on a
- variable.
-
- `Checkpointable` objects may specify `Tensor`s to be saved and restored
- directly (e.g. a `Variable` indicating how to save itself) rather than through
- dependencies on other objects. See
- `Checkpointable._gather_saveables_for_checkpoint` for details.
- """
-
- def __setattr__(self, name, value):
- """Support self.foo = checkpointable syntax."""
- # Perform the attribute assignment, and potentially call other __setattr__
- # overrides such as that for tf.keras.Model.
- no_dependency = isinstance(value, NoDependency)
- if no_dependency:
- value = value.value
- super(Checkpointable, self).__setattr__(name, value)
- if not no_dependency and isinstance(value, CheckpointableBase):
- self._track_checkpointable(
- value, name=name,
- # Allow the user to switch the Checkpointable which is tracked by this
- # name, since assigning a new variable to an attribute has
- # historically been fine (e.g. Adam did this).
- # TODO(allenl): Should this be a warning once Checkpointable save/load
- # is usable?
- overwrite=True)
diff --git a/tensorflow/python/training/checkpointable/base_test.py b/tensorflow/python/training/checkpointable/base_test.py
index 0a274cdfed..950e9c5b53 100644
--- a/tensorflow/python/training/checkpointable/base_test.py
+++ b/tensorflow/python/training/checkpointable/base_test.py
@@ -17,33 +17,25 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import base
class InterfaceTests(test.TestCase):
- def testMultipleAssignment(self):
- root = checkpointable.Checkpointable()
- root.leaf = checkpointable.Checkpointable()
- root.leaf = root.leaf
- duplicate_name_dep = checkpointable.Checkpointable()
+ def testOverwrite(self):
+ root = base.CheckpointableBase()
+ leaf = base.CheckpointableBase()
+ root._track_checkpointable(leaf, name="leaf")
+ (current_name, current_dependency), = root._checkpoint_dependencies
+ self.assertIs(leaf, current_dependency)
+ self.assertEqual("leaf", current_name)
+ duplicate_name_dep = base.CheckpointableBase()
with self.assertRaises(ValueError):
root._track_checkpointable(duplicate_name_dep, name="leaf")
- # No error; we're overriding __setattr__, so we can't really stop people
- # from doing this while maintaining backward compatibility.
- root.leaf = duplicate_name_dep
root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
-
- def testNoDependency(self):
- root = checkpointable.Checkpointable()
- hasdep = checkpointable.Checkpointable()
- root.hasdep = hasdep
- nodep = checkpointable.Checkpointable()
- root.nodep = checkpointable.NoDependency(nodep)
- self.assertEqual(1, len(root._checkpoint_dependencies))
- self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
- self.assertIs(root.hasdep, hasdep)
- self.assertIs(root.nodep, nodep)
+ (current_name, current_dependency), = root._checkpoint_dependencies
+ self.assertIs(duplicate_name_dep, current_dependency)
+ self.assertEqual("leaf", current_name)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 69ed253fb2..019d43f09c 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -21,54 +21,127 @@ import collections
import six
-from tensorflow.python.keras.engine import base_layer
-from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import variables
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
-from tensorflow.python.training.checkpointable import data_structures_base
-
-
-# TODO(allenl): We could track regular Python data structures which get assigned
-# to Checkpointable objects. Making this work with restore-on-create would be
-# tricky; we'd need to re-create nested structures with our own wrapped objects
-# on assignment to an attribute, and track the user's original structure to make
-# sure they don't modify it except through the wrappers (since we could save the
-# user's updated structure, but would have no way to support restore-on-create
-# for those modifications).
-# TODO(allenl): A dictionary data structure would be good too.
-class CheckpointableDataStructure(
- data_structures_base.CheckpointableDataStructureBase):
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import layer_utils
+
+
+class NoDependency(object):
+ """Allows attribute assignment to `Checkpointable` objects with no dependency.
+
+ Example usage:
+ ```python
+ obj = Checkpointable()
+ obj.has_dependency = tf.Variable(0., name="dep")
+ obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
+ assert obj.no_dependency.name == "nodep:0"
+ ```
+
+ `obj` in this example has a dependency on the variable "dep", and both
+ attributes contain un-wrapped `Variable` objects.
+
+ `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
+ dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
+ `Layer` to the attribute without a checkpoint dependency, but the `Model` will
+ still track the `Layer` (so it will appear in `Model.layers`, and its
+ variables will appear in `Model.variables`).
+ """
+
+ def __init__(self, value):
+ self.value = value
+
+
+def _wrap_or_unwrap(value):
+ """Wraps basic data structures, unwraps NoDependency objects."""
+ if isinstance(value, NoDependency):
+ return value.value
+ if isinstance(value, base.CheckpointableBase):
+ return value # Skip conversion for already checkpointable objects.
+ elif isinstance(value, list):
+ return _ListWrapper(value)
+ else:
+ return value
+ # TODO(allenl): Handle other common data structures. Tuples will require
+ # special casing (tuple subclasses are not weak referenceable, so replacement
+ # with a wrapper that subclasses tuple on attribute assignment works poorly,
+ # and replacement with a wrapper that isn't a tuple is also problematic),
+ # probably a tree traversal where the leaves are non-tuples(/namedtuples) to
+ # come up with names. Dictionaries should look like lists.
+
+
+def sticky_attribute_assignment(checkpointable, name, value):
+ """Adds dependencies, generally called from __setattr__.
+
+ This behavior is shared between Checkpointable and Model.
+
+ Respects NoDependency indicators, but otherwise makes checkpointable objects
+ out of common data structures and tracks objects by their attribute names.
+
+ Args:
+ checkpointable: The object to add dependencies to (generally the one having
+ an attribute assigned).
+ name: The attribute name being assigned.
+ value: The value being assigned. Not necessarily a checkpointable object.
+
+ Returns:
+ The value which should be stored in the attribute (unwrapped from a
+ NoDependency object if necessary).
+ """
+ if isinstance(value, NoDependency):
+ add_dependency = False
+ else:
+ add_dependency = True
+ value = _wrap_or_unwrap(value)
+ if not add_dependency:
+ return value
+ if isinstance(value, base.CheckpointableBase):
+ checkpointable._track_checkpointable( # pylint: disable=protected-access
+ value, name=name,
+ # Allow the user to switch the Checkpointable which is tracked by this
+ # name, since assigning a new variable to an attribute has
+ # historically been fine (e.g. Adam did this).
+ overwrite=True)
+ return value
+
+
+class CheckpointableDataStructure(base.CheckpointableBase):
"""Base class for data structures which contain checkpointable objects."""
def __init__(self):
+ # An append-only ordered set
self._layers = []
+
self.trainable = True
self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
- if isinstance(value, checkpointable_lib.CheckpointableBase):
- self._track_checkpointable(value, name=name)
- if isinstance(value, variables.Variable):
- self._extra_variables.append(value)
- else:
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
+ if not isinstance(value, base.CheckpointableBase):
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
- if isinstance(value, (
- base_layer.Layer,
- data_structures_base.CheckpointableDataStructureBase)):
- if value not in self._layers:
+ if (isinstance(value, CheckpointableDataStructure)
+ or layer_utils.is_layer(value)):
+ # Check for object-identity rather than with __eq__ to avoid
+ # de-duplicating empty container types. Automatically generated list
+ # wrappers keep things like "[] == []" true, which means "[] in [[]]" is
+ # also true. This becomes not true once one of the lists is mutated.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, "_use_resource_variables"):
# In subclassed models, legacy layers (tf.layers) must always use
# resource variables.
value._use_resource_variables = True # pylint: disable=protected-access
+ return value
@property
def layers(self):
- return self._layers
+ return layer_utils.filter_empty_layer_containers(self._layers)
@property
def trainable_weights(self):
@@ -168,24 +241,28 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __init__(self, *args, **kwargs):
"""Construct a new sequence. Arguments are passed to `list()`."""
super(List, self).__init__()
- self._storage = list(*args, **kwargs)
+ self._storage = self._make_storage(*args, **kwargs)
for index, element in enumerate(self._storage):
- self._track_value(element, name=self._name_element(index))
+ self._storage[index] = self._track_value(
+ element, name=self._name_element(index))
+
+ def _make_storage(self, *args, **kwargs):
+ """Determines the backing storage (overridden in subclasses)."""
+ return list(*args, **kwargs)
def _name_element(self, index):
return "%d" % (index,)
def append(self, value):
"""Add a new checkpointable value."""
- self._track_value(value, self._name_element(len(self._storage)))
+ value = self._track_value(value, self._name_element(len(self._storage)))
self._storage.append(value)
def extend(self, values):
"""Add a sequence of checkpointable values."""
- for index_offset, value in enumerate(values):
- self._track_value(
- value, name=self._name_element(len(self._storage) + index_offset))
- self._storage.extend(values)
+ for value in values:
+ self._storage.append(self._track_value(
+ value, name=self._name_element(len(self._storage))))
def __iadd__(self, values):
self.extend(values)
@@ -193,9 +270,12 @@ class List(CheckpointableDataStructure, collections.Sequence):
def __add__(self, other):
if isinstance(other, List):
- return List(self._storage + other._storage) # pylint: disable=protected-access
+ return self.__class__(self._storage + other._storage) # pylint: disable=protected-access
else:
- return List(self._storage + other)
+ return self.__class__(self._storage + other)
+
+ def __radd__(self, other):
+ return self + other
def __getitem__(self, key):
return self._storage[key]
@@ -207,6 +287,144 @@ class List(CheckpointableDataStructure, collections.Sequence):
return "List(%s)" % (repr(self._storage),)
+class _ListWrapper(List, collections.MutableSequence,
+ # Shadowed, but there for isinstance checks.
+ list):
+ """Wraps the built-in `list` to support restore-on-create for variables.
+
+ Unlike `List`, this sequence type is mutable in the same ways built-in lists
+ are. Instead of throwing an error immediately like `List`, it records
+ problematic mutations (e.g. assigning a new element to a position already
+ occupied, meaning both elements get the same names at different times) and
+ refuses to save.
+
+ On assignment to an attribute of a Model or Checkpointable object, Python
+ lists are replaced with _ListWrapper. Wrapping a list in a
+ `tf.contrib.checkpoint.NoDependency` object prevents this.
+ """
+
+ def __init__(self, wrapped_list):
+ """Construct a new list wrapper.
+
+ Args:
+ wrapped_list: The initial value of the data structure. A shallow copy may
+ be maintained for error checking. `wrapped_list` itself should not be
+ modified directly after constructing the `_ListWrapper`, and if changes
+ are detected the `_ListWrapper` will throw an exception on save.
+ """
+ # Monotonic flags which indicate this object would not be restored properly,
+ # and therefore should throw an error on save to avoid giving the impression
+ # that restoring it will work.
+ self._non_append_mutation = False
+ self._external_modification = False
+ super(_ListWrapper, self).__init__(wrapped_list)
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ def _make_storage(self, wrapped_list):
+ """Use the user's original list for storage."""
+ return wrapped_list
+
+ def _check_external_modification(self):
+ """Checks for any changes to the wrapped list not through the wrapper."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ if self._storage != self._last_wrapped_list_snapshot:
+ self._external_modification = True
+ self._last_wrapped_list_snapshot = None
+
+ def _update_snapshot(self):
+ """Acknowledges tracked changes to the wrapped list."""
+ if self._external_modification or self._non_append_mutation:
+ return
+ self._last_wrapped_list_snapshot = list(self._storage)
+
+ @property
+ def _checkpoint_dependencies(self):
+ self._check_external_modification()
+ if self._non_append_mutation:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). A list element was replaced "
+ "(__setitem__), deleted, or inserted. In order to support "
+ "restoration on object creation, tracking is exclusively for "
+ "append-only data structures.\n\nIf you don't need this list "
+ "checkpointed, wrap it in a tf.contrib.checkpoint.NoDependency "
+ "object; it will be automatically un-wrapped and subsequently "
+ "ignored." % (self,)))
+ if self._external_modification:
+ raise ValueError(
+ ("Unable to save the object %s (a list wrapper constructed to track "
+ "checkpointable TensorFlow objects). The wrapped list was modified "
+ "outside the wrapper (its final value was %s, its value when a "
+ "checkpoint dependency was added was %s), which breaks restoration "
+ "on object creation.\n\nIf you don't need this list checkpointed, "
+ "wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
+ "automatically un-wrapped and subsequently ignored." % (
+ self, self._storage, self._last_wrapped_list_snapshot)))
+ return super(_ListWrapper, self)._checkpoint_dependencies
+
+ def __delitem__(self, key):
+ self._non_append_mutation = True
+ del self._storage[key]
+
+ def __setitem__(self, key, value):
+ self._non_append_mutation = True
+ self._storage[key] = value
+
+ def append(self, value):
+ """Add a new checkpointable value."""
+ self._check_external_modification()
+ super(_ListWrapper, self).append(value)
+ self._update_snapshot()
+
+ def extend(self, values):
+ """Add a sequence of checkpointable values."""
+ self._check_external_modification()
+ super(_ListWrapper, self).extend(values)
+ self._update_snapshot()
+
+ def __eq__(self, other):
+ return self._storage == getattr(other, "_storage", other)
+
+ def __ne__(self, other):
+ return self._storage != getattr(other, "_storage", other)
+
+ def __lt__(self, other):
+ return self._storage < getattr(other, "_storage", other)
+
+ def __le__(self, other):
+ return self._storage <= getattr(other, "_storage", other)
+
+ def __gt__(self, other):
+ return self._storage > getattr(other, "_storage", other)
+
+ def __ge__(self, other):
+ return self._storage >= getattr(other, "_storage", other)
+
+ def __hash__(self):
+ # List wrappers need to compare like regular lists, and so like regular
+ # lists they don't belong in hash tables.
+ raise TypeError("unhashable type: 'ListWrapper'")
+
+ def insert(self, index, obj):
+ self._non_append_mutation = True
+ self._storage.insert(index, obj)
+
+ def _track_value(self, value, name):
+ """Allows storage of non-checkpointable objects."""
+ try:
+ value = super(_ListWrapper, self)._track_value(value=value, name=name)
+ except ValueError:
+ # Even if this value isn't checkpointable, we need to make sure
+ # NoDependency objects get unwrapped.
+ value = sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ return value
+
+ def __repr__(self):
+ return "ListWrapper(%s)" % (repr(self._storage),)
+
+
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
@@ -221,8 +439,10 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
"""Construct a new sequence. Arguments are passed to `dict()`."""
super(Mapping, self).__init__()
self._storage = dict(*args, **kwargs)
- for key, value in self._storage.items():
- self._track_value(value, name=self._name_element(key))
+ self._storage.update(
+ {key: self._track_value(
+ value, name=self._name_element(key))
+ for key, value in self._storage.items()})
def _name_element(self, key):
if not isinstance(key, six.string_types):
@@ -232,13 +452,14 @@ class Mapping(CheckpointableDataStructure, collections.Mapping):
return str(key)
def __setitem__(self, key, value):
+ name = self._name_element(key)
+ value = self._track_value(value, name=name)
current_value = self._storage.setdefault(key, value)
if current_value is not value:
raise ValueError(
("Mappings are an append-only data structure. Tried to overwrite the "
"key '%s' with value %s, but it already contains %s")
% (key, value, current_value))
- self._track_value(value, name=self._name_element(key))
def update(self, *args, **kwargs):
for key, value in dict(*args, **kwargs).items():
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index b05b3a8800..ec8c9da809 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
class HasList(training.Model):
@@ -66,7 +67,7 @@ class HasList(training.Model):
class ListTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTracking(self):
model = HasList()
output = model(array_ops.ones([32, 2]))
@@ -106,13 +107,26 @@ class ListTests(test.TestCase):
model(model_input)
self.assertEqual(0, len(model.updates))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLossesForwarded(self):
model = HasList()
model_input = array_ops.ones([32, 2])
model(model_input)
self.assertEqual(2, len(model.losses))
+ def testModelContainersCompareEqual(self):
+ class HasEqualContainers(training.Model):
+
+ def __init__(self):
+ super(HasEqualContainers, self).__init__()
+ self.l1 = []
+ self.l2 = []
+
+ model = HasEqualContainers()
+ model.l1.append(HasEqualContainers())
+ model.l2.append(HasEqualContainers())
+ self.assertEqual([model.l1, model.l2], model.layers)
+
def testNotCheckpointable(self):
class NotCheckpointable(object):
pass
@@ -158,11 +172,62 @@ class ListTests(test.TestCase):
self.assertEqual([v], l.trainable_weights)
self.assertEqual([v2], l.non_trainable_weights)
+ def testListWrapperBasic(self):
+ # _ListWrapper, unlike List, compares like the built-in list type (since it
+ # is used to automatically replace lists).
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ self.assertEqual([a, a],
+ [a, a])
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual([a, a],
+ data_structures._ListWrapper([a, a]))
+ self.assertEqual(data_structures._ListWrapper([a, a]),
+ [a, a])
+ self.assertNotEqual([a, a],
+ [b, a])
+ self.assertNotEqual(data_structures._ListWrapper([a, a]),
+ data_structures._ListWrapper([b, a]))
+ self.assertNotEqual([a, a],
+ data_structures._ListWrapper([b, a]))
+ self.assertLess([a], [a, b])
+ self.assertLess(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertLessEqual([a], [a, b])
+ self.assertLessEqual(data_structures._ListWrapper([a]),
+ data_structures._ListWrapper([a, b]))
+ self.assertGreater([a, b], [a])
+ self.assertGreater(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertGreaterEqual([a, b], [a])
+ self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
+ data_structures._ListWrapper([a]))
+ self.assertEqual([a], data_structures._ListWrapper([a]))
+ self.assertEqual([a], list(data_structures.List([a])))
+ self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
+ self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
+ self.assertIsInstance(data_structures._ListWrapper([a]), list)
+
+ def testWrapperChangesList(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l_wrapper.append(1)
+ self.assertEqual([1], l)
+
+ def testListChangesWrapper(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ l.append(1)
+ self.assertEqual([1], l_wrapper)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
self.assertEqual(2, len(has_sequences))
self.assertNotIn(data_structures.List(), has_sequences)
+ with self.assertRaises(TypeError):
+ has_sequences.add(data_structures._ListWrapper([]))
class HasMapping(training.Model):
@@ -190,7 +255,7 @@ class HasMapping(training.Model):
class MappingTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testTracking(self):
model = HasMapping()
output = model(array_ops.ones([32, 2]))
diff --git a/tensorflow/python/training/checkpointable/layer_utils.py b/tensorflow/python/training/checkpointable/layer_utils.py
new file mode 100644
index 0000000000..978fcb2252
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/layer_utils.py
@@ -0,0 +1,93 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities related to layer/model functionality."""
+
+# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
+# once __init__ files no longer require all of tf.keras to be imported together.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def is_layer(obj):
+ """Implicit check for Layer-like objects."""
+ # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
+ return (hasattr(obj, "call")
+ and hasattr(obj, "build")
+ and hasattr(obj, "variables"))
+
+
+def filter_empty_layer_containers(layer_list):
+ """Filter out empty Layer-like containers."""
+ return [layer for layer in layer_list
+ # Filter out only empty Checkpointable data structures. Empty Networks
+ # will still show up in Model.layers.
+ if is_layer(layer) or getattr(layer, "layers", True)]
+
+
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected trainable weights/variables.
+ """
+ if not trainable:
+ return []
+ weights = []
+ for layer in sub_layers:
+ weights += layer.trainable_weights
+ trainable_extra_variables = [
+ v for v in extra_variables if v.trainable]
+ return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the non-trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected non-trainable weights/variables.
+ """
+ trainable_extra_variables = []
+ non_trainable_extra_variables = []
+ for v in extra_variables:
+ if v.trainable:
+ trainable_extra_variables.append(v)
+ else:
+ non_trainable_extra_variables.append(v)
+ weights = []
+ for layer in sub_layers:
+ weights += layer.non_trainable_weights
+ if not trainable:
+ trainable_weights = []
+ for layer in sub_layers:
+ trainable_weights += layer.trainable_weights
+ return (trainable_weights + trainable_extra_variables
+ + weights + non_trainable_extra_variables)
+ return weights + non_trainable_extra_variables
diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py
new file mode 100644
index 0000000000..bd0bed9d46
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/tracking.py
@@ -0,0 +1,72 @@
+"""Dependency tracking for checkpointable objects."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
+
+
+class NotCheckpointable(object):
+ """Marks instances of child classes as unsaveable using an object-based API.
+
+ Useful for marking objects which would otherwise look checkpointable because
+ of inheritance (e.g. through `Layer`) as not checkpointable. Inheriting from
+ `NotCheckpointable` does not prevent an object from being assigned to any
+ attributes, but will throw an error on save/restore.
+ """
+ pass
+
+
+class Checkpointable(base.CheckpointableBase):
+ """Manages dependencies on other objects.
+
+ `Checkpointable` objects may have dependencies: other `Checkpointable` objects
+ which should be saved if the object declaring the dependency is saved. A
+ correctly saveable program has a dependency graph such that if changing a
+ global variable affects an object (e.g. changes the behavior of any of its
+ methods) then there is a chain of dependencies from the influenced object to
+ the variable.
+
+ Dependency edges have names, and are created implicitly when a
+ `Checkpointable` object is assigned to an attribute of another
+ `Checkpointable` object. For example:
+
+ ```
+ obj = Checkpointable()
+ obj.v = ResourceVariable(0.)
+ ```
+
+ The `Checkpointable` object `obj` now has a dependency named "v" on a
+ variable.
+
+ `Checkpointable` objects may specify `Tensor`s to be saved and restored
+ directly (e.g. a `Variable` indicating how to save itself) rather than through
+ dependencies on other objects. See
+ `Checkpointable._gather_saveables_for_checkpoint` for details.
+ """
+
+ def __setattr__(self, name, value):
+ """Support self.foo = checkpointable syntax."""
+ if getattr(self, "_setattr_tracking", True):
+ value = data_structures.sticky_attribute_assignment(
+ checkpointable=self, value=value, name=name)
+ super(Checkpointable, self).__setattr__(name, value)
+
+ def _no_dependency(self, value):
+ """Override to allow CheckpointableBase to disable dependency tracking."""
+ return data_structures.NoDependency(value)
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
new file mode 100644
index 0000000000..96da0d6e47
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -0,0 +1,171 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+from tensorflow.python.util import nest
+
+
+class InterfaceTests(test.TestCase):
+
+ def testMultipleAssignment(self):
+ root = tracking.Checkpointable()
+ root.leaf = tracking.Checkpointable()
+ root.leaf = root.leaf
+ duplicate_name_dep = tracking.Checkpointable()
+ with self.assertRaisesRegexp(ValueError, "already declared"):
+ root._track_checkpointable(duplicate_name_dep, name="leaf")
+ # No error; we're overriding __setattr__, so we can't really stop people
+ # from doing this while maintaining backward compatibility.
+ root.leaf = duplicate_name_dep
+ root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True)
+ self.assertIs(duplicate_name_dep, root._lookup_dependency("leaf"))
+ (_, dep_object), = root._checkpoint_dependencies
+ self.assertIs(duplicate_name_dep, dep_object)
+
+ def testNoDependency(self):
+ root = tracking.Checkpointable()
+ hasdep = tracking.Checkpointable()
+ root.hasdep = hasdep
+ nodep = tracking.Checkpointable()
+ root.nodep = data_structures.NoDependency(nodep)
+ self.assertEqual(1, len(root._checkpoint_dependencies))
+ self.assertIs(root._checkpoint_dependencies[0].ref, root.hasdep)
+ self.assertIs(root.hasdep, hasdep)
+ self.assertIs(root.nodep, nodep)
+
+ class NoDependencyModel(training.Model):
+
+ @base.no_automatic_dependency_tracking
+ def __init__(self):
+ super(NoDependencyModel, self).__init__()
+ self.a = []
+ self.b = tracking.Checkpointable()
+
+ nodeps = NoDependencyModel()
+ self.assertEqual([nodeps], util.list_objects(nodeps))
+
+ def testListBasic(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ direct_a_dep, = a._checkpoint_dependencies
+ self.assertEqual("l", direct_a_dep.name)
+ self.assertIn(b, direct_a_dep.ref)
+ self.assertIn(c, direct_a_dep.ref)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testMutationDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ a.l = [b]
+ c = tracking.Checkpointable()
+ a.l.insert(0, c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testOutOfBandEditDirtiesList(self):
+ a = tracking.Checkpointable()
+ b = tracking.Checkpointable()
+ held_reference = [b]
+ a.l = held_reference
+ c = tracking.Checkpointable()
+ held_reference.append(c)
+ checkpoint = util.Checkpoint(a=a)
+ with self.assertRaisesRegexp(ValueError, "The wrapped list was modified"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNestedLists(self):
+ a = tracking.Checkpointable()
+ a.l = []
+ b = tracking.Checkpointable()
+ a.l.append([b])
+ c = tracking.Checkpointable()
+ a.l[0].append(c)
+ a_deps = util.list_objects(a)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ a.l[0].append(1)
+ d = tracking.Checkpointable()
+ a.l[0].append(d)
+ a_deps = util.list_objects(a)
+ self.assertIn(d, a_deps)
+ self.assertIn(b, a_deps)
+ self.assertIn(c, a_deps)
+ self.assertNotIn(1, a_deps)
+ e = tracking.Checkpointable()
+ f = tracking.Checkpointable()
+ a.l1 = [[], [e]]
+ a.l1[0].append(f)
+ a_deps = util.list_objects(a)
+ self.assertIn(e, a_deps)
+ self.assertIn(f, a_deps)
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l[0].append(data_structures.NoDependency([]))
+ a.l[0][-1].append(5)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ # Dirtying the inner list means the root object is unsaveable.
+ a.l[0][1] = 2
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoDepList(self):
+ a = training.Model()
+ a.l1 = data_structures.NoDependency([])
+ a.l1.insert(1, 0)
+ self.assertTrue(isinstance(a.l1, list))
+ checkpoint = util.Checkpoint(a=a)
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+ a.l2 = []
+ a.l2.insert(1, 0)
+ with self.assertRaisesRegexp(ValueError, "A list element was replaced"):
+ checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAssertions(self):
+ a = tracking.Checkpointable()
+ a.l = [numpy.zeros([2, 2])]
+ self.assertAllEqual([numpy.zeros([2, 2])], a.l)
+ self.assertAllClose([numpy.zeros([2, 2])], a.l)
+ nest.map_structure(self.assertAllClose, a.l, [numpy.zeros([2, 2])])
+ a.tensors = [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]
+ self.assertAllClose([numpy.ones([2, 2]), numpy.zeros([3, 3])],
+ self.evaluate(a.tensors))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 96e6d10791..6ae5765b13 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -39,8 +39,11 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import saveable_object as saveable_object_lib
from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training.checkpointable import base as checkpointable_lib
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -91,7 +94,7 @@ class _CheckpointRestoreCoordinator(object):
# use them (for example because of inconsistent references when
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
- self.all_python_objects = weakref.WeakSet()
+ self.all_python_objects = _ObjectIdentityWeakSet()
self.save_path = save_path
self.dtype_map = dtype_map
# When graph building, contains a list of ops to run to restore objects from
@@ -113,7 +116,7 @@ class _CheckpointRestoreCoordinator(object):
# `node` refers to an `Optimizer`, since only these have slot variables.
self.slot_restorations.setdefault(
slot_reference.original_variable_node_id, []).append(
- checkpointable_lib._SlotVariableRestoration( # pylint: disable=protected-access
+ base._SlotVariableRestoration( # pylint: disable=protected-access
optimizer_id=node_index,
slot_variable_id=slot_reference.slot_variable_node_id,
slot_name=slot_reference.slot_name))
@@ -257,27 +260,145 @@ def object_metadata(save_path):
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
try:
object_graph_string = reader.get_tensor(
- checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
+ base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
raise ValueError(
('The specified checkpoint "%s" does not appear to be object-based (it '
'is missing the key "%s"). Likely it was created with a name-based '
'saver and does not contain an object dependency graph.') % (
- save_path, checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
+ save_path, base.OBJECT_GRAPH_PROTO_KEY))
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
object_graph_proto.ParseFromString(object_graph_string)
return object_graph_proto
+class _ObjectIdentityWrapper(object):
+ """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
+
+ Since __eq__ is based on object identity, it's safe to also define __hash__
+ based on object ids. This lets us add unhashable types like checkpointable
+ _ListWrapper objects to object-identity collections.
+ """
+
+ def __init__(self, wrapped):
+ self._wrapped = wrapped
+
+ @property
+ def unwrapped(self):
+ return self._wrapped
+
+ def __eq__(self, other):
+ if isinstance(other, _ObjectIdentityWrapper):
+ return self._wrapped is other._wrapped # pylint: disable=protected-access
+ return self._wrapped is other
+
+ def __hash__(self):
+ # Wrapper id() is also fine for weakrefs. In fact, we rely on
+ # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
+ # weakref.ref(a) in _WeakObjectIdentityWrapper.
+ return id(self._wrapped)
+
+
+class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
+
+ def __init__(self, wrapped):
+ super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
+
+ @property
+ def unwrapped(self):
+ return self._wrapped()
+
+
+class _ObjectIdentityDictionary(collections.MutableMapping):
+ """A mutable mapping data structure which compares using "is".
+
+ This is necessary because we have checkpointable objects (_ListWrapper) which
+ have behavior identical to built-in Python lists (including being unhashable
+ and comparing based on the equality of their contents by default).
+ """
+
+ def __init__(self):
+ self._storage = {}
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
+
+ def __getitem__(self, key):
+ return self._storage[self._wrap_key(key)]
+
+ def __setitem__(self, key, value):
+ self._storage[self._wrap_key(key)] = value
+
+ def __delitem__(self, key):
+ del self._storage[self._wrap_key(key)]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ for key in self._storage:
+ yield key.unwrapped
+
+
+class _ObjectIdentityWeakKeyDictionary(_ObjectIdentityDictionary):
+ """Like weakref.WeakKeyDictionary, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self._storage))
+
+ def __iter__(self):
+ keys = self._storage.keys()
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ del self[key]
+ else:
+ yield unwrapped
+
+
+class _ObjectIdentityWeakSet(collections.MutableSet):
+ """Like weakref.WeakSet, but compares objects with "is"."""
+
+ def __init__(self):
+ self._storage = set()
+
+ def __contains__(self, key):
+ return _WeakObjectIdentityWrapper(key) in self._storage
+
+ def discard(self, key):
+ self._storage.discard(_WeakObjectIdentityWrapper(key))
+
+ def add(self, key):
+ self._storage.add(_WeakObjectIdentityWrapper(key))
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self))
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ self.discard(key)
+ else:
+ yield unwrapped
+
+
def _breadth_first_checkpointable_traversal(root_checkpointable):
"""Find shortest paths to all variables owned by dependencies of root."""
bfs_sorted = []
to_visit = collections.deque([root_checkpointable])
- path_to_root = {root_checkpointable: ()}
+ path_to_root = _ObjectIdentityDictionary()
+ path_to_root[root_checkpointable] = ()
while to_visit:
current_checkpointable = to_visit.popleft()
- if isinstance(current_checkpointable, checkpointable_lib.NotCheckpointable):
+ if isinstance(current_checkpointable, tracking.NotCheckpointable):
raise NotImplementedError(
("The object %s does not support object-based saving. File a feature "
"request if this limitation bothers you. In the meantime, you can "
@@ -335,7 +456,7 @@ def _slot_variable_naming_for_optimizer(optimizer_path):
def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
"""Gather and name slot variables."""
non_slot_objects = list(checkpointable_objects)
- slot_variables = {}
+ slot_variables = _ObjectIdentityDictionary()
for checkpointable in non_slot_objects:
if isinstance(checkpointable, optimizer_lib.Optimizer):
naming_scheme = _slot_variable_naming_for_optimizer(
@@ -498,11 +619,12 @@ def _serialize_object_graph(root_checkpointable, saveables_cache):
"""
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
slot_variables = _serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -533,11 +655,12 @@ def list_objects(root_checkpointable):
# to run.
checkpointable_objects, path_to_root = (
_breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = {
- obj: _object_prefix_from_path(path)
- for obj, path in path_to_root.items()}
- node_ids = {node: node_id for node_id, node
- in enumerate(checkpointable_objects)}
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
_serialize_slot_variables(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
@@ -564,6 +687,93 @@ def gather_initializers(root_checkpointable):
if hasattr(c, "initializer") and c.initializer is not None]
+@tf_contextlib.contextmanager
+def capture_dependencies(template):
+ """Capture variables created within this scope as `Template` dependencies.
+
+ Requires that `template.variable_scope` is active.
+
+ This scope is intended as a compatibility measure, allowing a checkpointable
+ object to add dependencies on variables created in a block of code which is
+ not aware of object-based saving (and instead uses variable names
+ heavily). This is how `Template` objects add dependencies on variables and
+ sub-`Template`s. Where possible, use `tf.make_template` directly.
+
+ Args:
+ template: The `Template` object to register dependencies with.
+
+ Yields:
+ None (when used as a context manager).
+ """
+ name_prefix = template.variable_scope.name
+
+ def _checkpointable_custom_creator(next_creator, name, initial_value,
+ checkpointable_parent=None, **kwargs):
+ """A variable creation hook which adds Checkpointable dependencies.
+
+ Set for example during a `Template`'s first wrapped function
+ execution. Ensures that (a) `template` depends on any checkpointable
+ objects using their own `capture_dependencies` scope inside this scope which
+ create variables, and (b) that any variables not in a more deeply nested
+ scope are added as dependencies directly.
+
+ The `checkpointable_parent` argument is passed between custom creators but
+ ignored when the variable object itself is created. This argument indicates
+ (if not `None`) that a more deeply nested scope has already added the
+ variable as a dependency, and that parent scopes should add a dependency on
+ that object rather than on the variable directly.
+
+ Args:
+ next_creator: See `variable_scope.variable_creator_scope`; the next
+ creator in the chain.
+ name: The (full, scope-influenced) name of the variable. The `name_prefix`
+ itself is stripped for the purposes of object-based dependency tracking,
+ but scopes opened within this scope are respected.
+ initial_value: See `variable_scope.variable_creator_scope`. Taken
+ explicitly so the argument can be re-named and used with
+ `Checkpointable._add_variable_with_custom_getter`.
+ checkpointable_parent: If not None, a more deeply nested checkpointable
+ object and its name prefix which were passed to `capture_dependencies`
+ to add a dependency on (rather than depending on the variable directly).
+ **kwargs: Passed through to the next creator.
+
+ Returns:
+ The output of `next_creator`: the fetched/created variable object.
+ """
+ def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
+ inner_kwargs.pop("name") # Ignored; this is the scope-stripped name which
+ # we don't want to propagate.
+ return next_creator(
+ initial_value=initializer,
+ name=name,
+ **inner_kwargs)
+ if name.startswith(name_prefix):
+ scope_stripped_name = name[len(name_prefix) + 1:]
+ if not checkpointable_parent:
+ return template._add_variable_with_custom_getter( # pylint: disable=protected-access
+ initializer=initial_value,
+ name=scope_stripped_name,
+ getter=_call_next_creator_renaming_initializer,
+ # Disable error checking for Checkpointable. Exceptions are instead
+ # raised if necessary when the object-based saver tries to
+ # save/restore the object.
+ overwrite=True,
+ checkpointable_parent=(template, name_prefix),
+ **kwargs)
+ else:
+ parent_object, parent_name_prefix = checkpointable_parent
+ template._track_checkpointable( # pylint: disable=protected-access
+ parent_object,
+ name=parent_name_prefix[len(name_prefix) + 1:],
+ overwrite=True)
+ return next_creator(
+ name=name, initial_value=initial_value,
+ checkpointable_parent=(template, name_prefix), **kwargs)
+
+ with variable_scope.variable_creator_scope(_checkpointable_custom_creator):
+ yield
+
+
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
def __init__(self, tensor, name):
@@ -899,7 +1109,7 @@ class CheckpointableSaver(object):
else:
# Maps Checkpointable objects -> attribute names -> SaveableObjects, to
# avoid re-creating SaveableObjects when graph building.
- self._saveable_object_cache = weakref.WeakKeyDictionary()
+ self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
@property
def _root_checkpointable(self):
@@ -950,11 +1160,11 @@ class CheckpointableSaver(object):
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
- assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
+ assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
named_variables.append(
_NoRestoreSaveable(
tensor=object_graph_tensor,
- name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
+ name=base.OBJECT_GRAPH_PROTO_KEY))
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
# save() is called so they pick up new Tensors passed to their
@@ -1044,7 +1254,7 @@ class CheckpointableSaver(object):
dtype_map = reader.get_variable_to_dtype_map()
try:
object_graph_string = reader.get_tensor(
- checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
+ base.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
# The object graph proto does not exist in this checkpoint. Try the
# name-based compatibility mode.
@@ -1090,7 +1300,7 @@ class CheckpointableSaver(object):
"file a feature request if this limitation bothers you.")
self._last_restore_checkpoint = checkpoint
self._last_restore_object_graph = object_graph_proto
- checkpointable_lib._CheckpointPosition( # pylint: disable=protected-access
+ base._CheckpointPosition( # pylint: disable=protected-access
checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
load_status = CheckpointLoadStatus(
checkpoint,
@@ -1100,7 +1310,7 @@ class CheckpointableSaver(object):
@tf_export("train.Checkpoint")
-class Checkpoint(checkpointable_lib.Checkpointable):
+class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.
`Checkpoint`'s constructor accepts keyword arguments whose values are types
@@ -1202,7 +1412,7 @@ class Checkpoint(checkpointable_lib.Checkpointable):
"""
super(Checkpoint, self).__init__()
for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
- if not isinstance(v, checkpointable_lib.CheckpointableBase):
+ if not isinstance(v, base.CheckpointableBase):
raise ValueError(
("`Checkpoint` was expecting a checkpointable object (an object "
"derived from `CheckpointableBase`), got %s. If you believe this "
@@ -1221,7 +1431,7 @@ class Checkpoint(checkpointable_lib.Checkpointable):
with ops.device("/cpu:0"):
# add_variable creates a dependency named "save_counter"; NoDependency
# prevents creating a second dependency named "_save_counter".
- self._save_counter = checkpointable_lib.NoDependency(
+ self._save_counter = data_structures.NoDependency(
add_variable(self, name="save_counter", initializer=0,
dtype=dtypes.int64))
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 8cdf5d7855..3c1a4a6f83 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -44,11 +44,12 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
-class NonLayerCheckpointable(checkpointable.Checkpointable):
+class NonLayerCheckpointable(tracking.Checkpointable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
@@ -101,7 +102,7 @@ class InterfaceTests(test.TestCase):
name="duplicate", initial_value=1.)
duplicate = checkpointable_utils.add_variable(
obj, name="duplicate", shape=[])
- with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"):
+ with self.assertRaisesRegexp(ValueError, "'duplicate'.*already declared"):
checkpointable_utils.add_variable(obj, name="duplicate", shape=[])
self.evaluate(checkpointable_utils.gather_initializers(obj))
@@ -136,7 +137,7 @@ class InterfaceTests(test.TestCase):
def testInitNotCalled(self):
- class NoInit(checkpointable.Checkpointable):
+ class NoInit(tracking.Checkpointable):
def __init__(self):
pass
@@ -145,7 +146,7 @@ class InterfaceTests(test.TestCase):
checkpointable_utils.add_variable(NoInit(), "var", shape=[])
def testShapeDtype(self):
- root = checkpointable.Checkpointable()
+ root = tracking.Checkpointable()
v1 = checkpointable_utils.add_variable(
root, name="v1", initializer=3., dtype=dtypes.float64)
self.assertEqual(dtypes.float64, v1.dtype)
@@ -177,7 +178,7 @@ class InterfaceTests(test.TestCase):
def testNotCheckpointable(self):
class CallsFunctionalStuff(
- checkpointable.NotCheckpointable, checkpointable.Checkpointable):
+ tracking.NotCheckpointable, tracking.Checkpointable):
pass
test_dir = self.get_temp_dir()
@@ -187,7 +188,7 @@ class InterfaceTests(test.TestCase):
checkpoint.save(prefix)
class CallsFunctionalStuffOtherMRO(
- checkpointable.Checkpointable, checkpointable.NotCheckpointable):
+ tracking.Checkpointable, tracking.NotCheckpointable):
pass
checkpoint_reversed = checkpointable_utils.Checkpoint(
@@ -217,7 +218,7 @@ class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
self._mirrored_variable.assign(tensor))
-class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
+class _OwnsMirroredVariables(base.CheckpointableBase):
"""A Checkpointable object which returns a more complex SaveableObject."""
def __init__(self):
@@ -232,7 +233,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
primary_variable=self.non_dep_variable,
mirrored_variable=self.mirrored,
name=name)
- return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+ return {base.VARIABLE_VALUE_KEY: _saveable_factory}
# The Saver sorts by name before parsing, so we need a name property.
@property
@@ -355,7 +356,7 @@ class CheckpointingTests(test.TestCase):
optimizer_node.slot_variables[0]
.slot_variable_node_id].attributes[0].checkpoint_key)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
checkpoint = checkpointable_utils.Checkpoint(v=v)
@@ -375,7 +376,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(44., self.evaluate(v.non_dep_variable))
self.assertEqual(44., self.evaluate(v.mirrored))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturnedWithGlobalName(self):
# The same object can also be saved using the name-based saver.
v = _OwnsMirroredVariables()
@@ -391,7 +392,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(42., self.evaluate(v.non_dep_variable))
self.assertEqual(42., self.evaluate(v.mirrored))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveRestore(self):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
@@ -512,7 +513,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(training_continuation + 1,
session.run(root.save_counter))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAgnosticUsage(self):
"""Graph/eager agnostic usage."""
# Does create garbage when executing eagerly due to ops.Graph() creation.
@@ -546,7 +547,7 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
# pylint: disable=cell-var-from-loop
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testWithDefun(self):
num_training_steps = 2
checkpoint_directory = self.get_temp_dir()
@@ -590,7 +591,7 @@ class CheckpointingTests(test.TestCase):
# pylint: enable=cell-var-from-loop
def _get_checkpoint_name(self, name):
- root = checkpointable.Checkpointable()
+ root = tracking.Checkpointable()
checkpointable_utils.add_variable(
root, name=name, shape=[1, 2], dtype=dtypes.float64)
(named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
@@ -611,18 +612,18 @@ class CheckpointingTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNumberedPath(self):
- root = checkpointable.Checkpointable()
- leaf = checkpointable.Checkpointable()
+ root = tracking.Checkpointable()
+ leaf = tracking.Checkpointable()
root.leaf = leaf
checkpointable_utils.add_variable(leaf, name="v", shape=[])
(named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
root, saveables_cache=None)
self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLocalNameValidation(self):
- root = checkpointable.Checkpointable()
- leaf = checkpointable.Checkpointable()
+ root = tracking.Checkpointable()
+ leaf = tracking.Checkpointable()
# Dots are escaped, which avoids conflicts with reserved names.
root._track_checkpointable(leaf, name=".ATTRIBUTES")
checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[])
@@ -660,16 +661,16 @@ class CheckpointingTests(test.TestCase):
optimizer.apply_gradients(
[(g, v) for g, v in zip(grad, model.vars)])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLateDependencyTracking(self):
- class Dependency(checkpointable.Checkpointable):
+ class Dependency(tracking.Checkpointable):
def build(self):
self.var = checkpointable_utils.add_variable(
self, "var", initializer=0.)
- class LateDependencies(checkpointable.Checkpointable):
+ class LateDependencies(tracking.Checkpointable):
def add_dep(self):
self.dep = Dependency()
@@ -692,16 +693,16 @@ class CheckpointingTests(test.TestCase):
status.run_restore_ops()
self.assertEqual(123., self.evaluate(load_into.dep.var))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDepAfterVar(self):
- class Dependency(checkpointable.Checkpointable):
+ class Dependency(tracking.Checkpointable):
def build(self):
self.var = checkpointable_utils.add_variable(
self, "var", initializer=0.)
- class DepAfterVar(checkpointable.Checkpointable):
+ class DepAfterVar(tracking.Checkpointable):
def add_dep(self):
dep = Dependency()
@@ -724,11 +725,11 @@ class CheckpointingTests(test.TestCase):
status.run_restore_ops()
self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = checkpointable.Checkpointable()
+ root = tracking.Checkpointable()
root.var = checkpointable_utils.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
@@ -751,7 +752,7 @@ class CheckpointingTests(test.TestCase):
14.))
slots_path = checkpointable_utils.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "with_slots"))
- new_root = checkpointable.Checkpointable()
+ new_root = tracking.Checkpointable()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
slot_status = checkpointable_utils.CheckpointableSaver(
@@ -789,11 +790,11 @@ class CheckpointingTests(test.TestCase):
self.evaluate(train_op)
slot_status.assert_consumed()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testOverlappingRestores(self):
checkpoint_directory = self.get_temp_dir()
- save_root = checkpointable.Checkpointable()
- save_root.dep = checkpointable.Checkpointable()
+ save_root = tracking.Checkpointable()
+ save_root.dep = tracking.Checkpointable()
save_root.dep.var = checkpointable_utils.add_variable(
save_root.dep, name="var", initializer=0.)
self.evaluate(state_ops.assign(save_root.dep.var, 12.))
@@ -802,13 +803,13 @@ class CheckpointingTests(test.TestCase):
self.evaluate(state_ops.assign(save_root.dep.var, 13.))
second_path = saver.save(os.path.join(checkpoint_directory, "second"))
- first_root = checkpointable.Checkpointable()
- second_root = checkpointable.Checkpointable()
+ first_root = tracking.Checkpointable()
+ second_root = tracking.Checkpointable()
first_status = checkpointable_utils.CheckpointableSaver(
first_root).restore(first_path)
second_status = checkpointable_utils.CheckpointableSaver(
second_root).restore(second_path)
- load_dep = checkpointable.Checkpointable()
+ load_dep = tracking.Checkpointable()
load_dep.var = checkpointable_utils.add_variable(
load_dep, name="var", shape=[])
first_root.dep = load_dep
@@ -822,13 +823,13 @@ class CheckpointingTests(test.TestCase):
# Try again with the order of the restore() reversed. The last restore
# determines the final value.
- first_root = checkpointable.Checkpointable()
- second_root = checkpointable.Checkpointable()
+ first_root = tracking.Checkpointable()
+ second_root = tracking.Checkpointable()
second_status = checkpointable_utils.CheckpointableSaver(
second_root).restore(second_path)
first_status = checkpointable_utils.CheckpointableSaver(
first_root).restore(first_path)
- load_dep = checkpointable.Checkpointable()
+ load_dep = tracking.Checkpointable()
load_dep.var = checkpointable_utils.add_variable(
load_dep, name="var", shape=[])
first_root.dep = load_dep
@@ -840,39 +841,39 @@ class CheckpointingTests(test.TestCase):
second_status.run_restore_ops()
self.assertEqual(12., self.evaluate(load_dep.var))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testAmbiguousLoad(self):
# Not OK to split one checkpoint object into two
checkpoint_directory = self.get_temp_dir()
- save_root = checkpointable.Checkpointable()
- save_root.dep_one = checkpointable.Checkpointable()
- save_root.dep_two = checkpointable.Checkpointable()
- dep_three = checkpointable.Checkpointable()
+ save_root = tracking.Checkpointable()
+ save_root.dep_one = tracking.Checkpointable()
+ save_root.dep_two = tracking.Checkpointable()
+ dep_three = tracking.Checkpointable()
save_root.dep_one.dep_three = dep_three
save_root.dep_two.dep_three = dep_three
checkpointable_utils.add_variable(dep_three, name="var", initializer=0.)
self.evaluate(checkpointable_utils.gather_initializers(save_root))
save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
os.path.join(checkpoint_directory, "ckpt"))
- load_root = checkpointable.Checkpointable()
+ load_root = tracking.Checkpointable()
status = checkpointable_utils.CheckpointableSaver(load_root).restore(
save_path)
- load_root.dep_one = checkpointable.Checkpointable()
- load_root.dep_two = checkpointable.Checkpointable()
- load_root.dep_one.dep_three = checkpointable.Checkpointable()
- load_root.dep_two.dep_three = checkpointable.Checkpointable()
+ load_root.dep_one = tracking.Checkpointable()
+ load_root.dep_two = tracking.Checkpointable()
+ load_root.dep_one.dep_three = tracking.Checkpointable()
+ load_root.dep_two.dep_three = tracking.Checkpointable()
checkpointable_utils.add_variable(
load_root.dep_one.dep_three, name="var", initializer=0.)
with self.assertRaises(AssertionError):
status.assert_consumed()
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testObjectsCombined(self):
# Currently fine to load two checkpoint objects into one Python object
checkpoint_directory = self.get_temp_dir()
- save_root = checkpointable.Checkpointable()
- save_root.dep_one = checkpointable.Checkpointable()
- save_root.dep_two = checkpointable.Checkpointable()
+ save_root = tracking.Checkpointable()
+ save_root.dep_one = tracking.Checkpointable()
+ save_root.dep_two = tracking.Checkpointable()
checkpointable_utils.add_variable(
save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64)
checkpointable_utils.add_variable(
@@ -880,8 +881,8 @@ class CheckpointingTests(test.TestCase):
self.evaluate(checkpointable_utils.gather_initializers(save_root))
save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
os.path.join(checkpoint_directory, "ckpt"))
- load_root = checkpointable.Checkpointable()
- load_root.dep_one = checkpointable.Checkpointable()
+ load_root = tracking.Checkpointable()
+ load_root.dep_one = tracking.Checkpointable()
load_root.dep_two = load_root.dep_one
v1 = checkpointable_utils.add_variable(
load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64)
@@ -893,12 +894,12 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(32., self.evaluate(v1))
self.assertEqual(64., self.evaluate(v2))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testDependencyLoop(self):
# Note: this test creates garbage during eager execution because it
# purposefully creates a reference cycle.
- first = checkpointable.Checkpointable()
- second = checkpointable.Checkpointable()
+ first = tracking.Checkpointable()
+ second = tracking.Checkpointable()
first.second = second
second.first = first
first.v = checkpointable_utils.add_variable(
@@ -911,10 +912,10 @@ class CheckpointingTests(test.TestCase):
os.path.join(checkpoint_directory, "ckpt"))
# Test deferred loading
- first_load = checkpointable.Checkpointable()
+ first_load = tracking.Checkpointable()
status = checkpointable_utils.CheckpointableSaver(
first_load).restore(save_path)
- second_load = checkpointable.Checkpointable()
+ second_load = tracking.Checkpointable()
first_load.second = second_load
second_load.first = first_load
with self.assertRaises(AssertionError):
@@ -939,13 +940,13 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testRestoreOnAssign(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_graph = ops.Graph()
with save_graph.as_default(), self.test_session(save_graph):
- first = checkpointable.Checkpointable()
+ first = tracking.Checkpointable()
first.var1 = variable_scope.get_variable(
name="outside_var", initializer=0.)
first.var2 = variable_scope.get_variable(
@@ -956,7 +957,7 @@ class CheckpointingTests(test.TestCase):
checkpoint_prefix)
restore_graph = ops.Graph()
with restore_graph.as_default(), self.test_session(restore_graph):
- second = checkpointable.Checkpointable()
+ second = tracking.Checkpointable()
second.var2 = variable_scope.get_variable(
name="blah", initializer=0.)
status = checkpointable_utils.CheckpointableSaver(
@@ -978,7 +979,7 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
@@ -989,11 +990,11 @@ class CheckpointingTests(test.TestCase):
saver.save(checkpoint_prefix)
self.assertEqual(before_ops, graph.get_operations())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCheckpointCleanup(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
self.evaluate(checkpointable_utils.gather_initializers(obj))
saver = checkpointable_utils.Checkpoint(obj=obj)
@@ -1009,11 +1010,11 @@ class CheckpointingTests(test.TestCase):
expected_filenames,
os.listdir(checkpoint_directory))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testCheckpointCleanupChangingVarList(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
self.evaluate(checkpointable_utils.gather_initializers(obj))
checkpoint = checkpointable_utils.Checkpoint(obj=obj)
@@ -1062,7 +1063,7 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
@@ -1132,7 +1133,7 @@ class CheckpointingTests(test.TestCase):
beta1_power, _ = optimizer._get_beta_accumulators()
self.assertAllEqual(3., self.evaluate(beta1_power))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_sequential(self):
model = sequential.Sequential()
checkpoint = checkpointable_utils.Checkpoint(model=model)
@@ -1164,7 +1165,7 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual([1., 2., 3., 4., 5.],
self.evaluate(deferred_second_dense.bias))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_initialize_if_not_restoring(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
@@ -1243,9 +1244,21 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
+class _ManualScope(tracking.Checkpointable):
+
+ def __call__(self):
+ with variable_scope.variable_scope("ManualScope") as vs:
+ self.variable_scope = vs
+ with checkpointable_utils.capture_dependencies(template=self):
+ return self._build()
+
+ def _build(self):
+ return variable_scope.get_variable(name="in_manual_scope", shape=[])
+
+
class TemplateTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_checkpointable_save_restore(self):
def _templated():
@@ -1255,14 +1268,23 @@ class TemplateTests(test.TestCase):
v2 = variable_scope.get_variable(
"v2", shape=[1], initializer=init_ops.zeros_initializer(),
use_resource=True)
- return v, v + 1., v2
+ manual = _ManualScope()
+ return v, v + 1., v2, manual, manual()
save_template = template.make_template("s1", _templated)
- v1_save, _, v2_save = save_template()
+ v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
+ six.assertCountEqual(
+ self,
+ [v1_save, v2_save, manual_scope, manual_scope_v, save_template],
+ checkpointable_utils.list_objects(save_template))
+ manual_dep, = manual_scope._checkpoint_dependencies
+ self.assertEqual("in_manual_scope", manual_dep.name)
+ self.assertIs(manual_scope_v, manual_dep.ref)
optimizer = adam.AdamOptimizer(0.0)
save_root = checkpointable_utils.Checkpoint(
my_template=save_template, optimizer=optimizer)
optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in save_template.variables])
self.evaluate([v.initializer for v in optimizer.variables()])
self.evaluate(v1_save.assign([12.]))
self.evaluate(v2_save.assign([14.]))
@@ -1275,17 +1297,19 @@ class TemplateTests(test.TestCase):
load_root = checkpointable_utils.Checkpoint(
my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
- var, var_plus_one, var2 = load_template()
+ var, var_plus_one, var2, _, _ = load_template()
load_optimizer.minimize(var.read_value)
- self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual(3, len(load_template._checkpoint_dependencies))
self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ self.assertEqual("ManualScope",
+ load_template._checkpoint_dependencies[2].name)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([12.], self.evaluate(var))
self.assertAllEqual([13.], self.evaluate(var_plus_one))
self.assertAllEqual([14.], self.evaluate(var2))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_checkpointable_save_restore_nested(self):
def _inner_template():
@@ -1386,7 +1410,7 @@ class CheckpointCompatibilityTests(test.TestCase):
sess=session, save_path=checkpoint_prefix,
global_step=root.optimizer_step)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testLoadFromNameBasedSaver(self):
"""Save a name-based checkpoint, load it using the object-based API."""
with test_util.device(use_gpu=True):
@@ -1448,7 +1472,7 @@ class CheckpointCompatibilityTests(test.TestCase):
class PythonMetadataTests(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveLoad(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py
index e31fa02d60..70e1ca4b5d 100644
--- a/tensorflow/python/training/device_util.py
+++ b/tensorflow/python/training/device_util.py
@@ -27,13 +27,15 @@ def canonicalize(d, default=None):
"""Canonicalize device string.
If d has missing components, the rest would be deduced from the `default`
- argument or from '/job:localhost/replica:0/task:0/device:CPU:0'. For example:
+ argument or from '/replica:0/task:0/device:CPU:0'. For example:
If d = '/cpu:0', default='/job:worker/task:1', it returns
'/job:worker/replica:0/task:1/device:CPU:0'.
If d = '/cpu:0', default='/job:worker', it returns
'/job:worker/replica:0/task:0/device:CPU:0'.
If d = '/gpu:0', default=None, it returns
- '/job:localhost/replica:0/task:0/device:GPU:0'.
+ '/replica:0/task:0/device:GPU:0'.
+
+ Note: This uses "job:localhost" as the default if executing eagerly.
Args:
d: a device string.
@@ -47,7 +49,9 @@ def canonicalize(d, default=None):
"Device type '%s' must be all-caps." % (d.device_type,))
# Fill in missing device fields using defaults.
result = tf_device.DeviceSpec(
- job="localhost", replica=0, task=0, device_type="CPU", device_index=0)
+ replica=0, task=0, device_type="CPU", device_index=0)
+ if context.executing_eagerly():
+ result.job = "localhost"
if default:
result.merge_from(tf_device.DeviceSpec.from_string(default))
result.merge_from(d)
diff --git a/tensorflow/python/training/device_util_test.py b/tensorflow/python/training/device_util_test.py
index 61525e21f5..cdbb08229d 100644
--- a/tensorflow/python/training/device_util_test.py
+++ b/tensorflow/python/training/device_util_test.py
@@ -52,7 +52,7 @@ class DeviceUtilTest(test.TestCase):
def testCanonicalizeWithoutDefaultDevice(self):
self.assertEqual(
device_util.canonicalize("/cpu:0"),
- "/job:localhost/replica:0/task:0/device:CPU:0")
+ "/replica:0/task:0/device:CPU:0")
self.assertEqual(
device_util.canonicalize("/job:worker/cpu:0"),
"/job:worker/replica:0/task:0/device:CPU:0")
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index ab8b37bb65..c719045c7f 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import threading
-import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
@@ -222,11 +221,11 @@ def has_distribution_strategy():
def get_loss_reduction():
- """Reduce `method_string` corresponding to the last loss reduction."""
+ """Reduce `aggregation` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if loss_reduction == losses_impl.Reduction.SUM:
- return "sum"
- return "mean"
+ return variable_scope.VariableAggregation.SUM
+ return variable_scope.VariableAggregation.MEAN
# ------------------------------------------------------------------------------
@@ -527,15 +526,21 @@ class DistributionStrategy(object):
V(`v`), output will have locality V(`v`) as well.
* `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
context, like `d.update()` except with locality N.
- * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
+ * `d.read_var(v)`: Gets the (read-only) value of the variable `v` (on
+ the device determined by the current device scope), aggregating
+ across towers for tower-local variables. Frequently, this will be
+ done automatically when using `v` in an expression or fetching it in
+ a cross-tower context, but this function can be used to force that
+ conversion happens at a particular point in time (for example, to
+ add the result of the conversion to a graph collection).
The standard pattern for updating variables is to:
1. Wrap your input dataset in `d.distribute_dataset()` and create an iterator.
2. Define each tower `d.call_for_each_tower()` up to the point of
getting a list of gradient, variable pairs.
- 3. Call `d.reduce("sum", t, v)` or `d.batch_reduce()` to sum the
- gradients (with locality T) into values with locality V(`v`).
+ 3. Call `d.reduce(VariableAggregation.SUM, t, v)` or `d.batch_reduce()` to sum
+ the gradients (with locality T) into values with locality V(`v`).
4. Call `d.update(v)` for each variable to update its value.
Steps 3 and 4 are done automatically by class `Optimizer` if you call
@@ -609,42 +614,20 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, reduce_method):
- """Inside this scope, new variables will not be mirrored.
+ def read_var(self, v):
+ """Reads the value of a variable.
- There will still be one component variable per tower, but there is
- no requirement that they stay in sync. Instead, when saving them
- or calling `fetch()`, we use the value that results when calling
- `reduce()` on all the towers' variables.
-
- Note: tower-local implies not trainable. Instead, it is expected
- that each tower will directly update (using `assign_add()` or
- whatever) its local variable instance but only the aggregated
- value (accessible using `fetch()`) will be exported from the
- model. When it is acceptable to only aggregate on export, we
- greatly reduce communication overhead by using tower-local
- variables.
-
- Note: All component variables will be initialized to the same
- value, using the initialization expression from the first tower.
- The values will match even if the initialization expression uses
- random numbers.
+ Returns the aggregate value of a tower-local variable, or the
+ (read-only) value of any other variable.
Args:
- reduce_method: String used as a `method_string` to `reduce()`
- to get the value to save when checkpointing.
+ v: A variable allocated within the scope of this `DistributionStrategy`.
Returns:
- A context manager.
+ A tensor representing the value of `v`, aggregated across towers if
+ necessary.
"""
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["use_resource"] = True
- kwargs["tower_local_reduce_method"] = reduce_method
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
+ raise NotImplementedError("must be implemented in descendants")
def colocate_vars_with(self, colocate_with_variable):
"""Scope that controls which devices variables will be created on.
@@ -796,12 +779,12 @@ class DistributionStrategy(object):
def _call_for_each_tower(self, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def reduce(self, method_string, value, destinations=None):
+ def reduce(self, aggregation, value, destinations=None):
"""Combine (via e.g. sum or mean) values across towers.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -816,18 +799,21 @@ class DistributionStrategy(object):
# TODO(josh11b): Return an unwrapped value if colocate_with is a
# single device.
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._reduce(method_string, value, destinations)
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._reduce(aggregation, value, destinations)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
raise NotImplementedError("must be implemented in descendants")
- def batch_reduce(self, method_string, value_destination_pairs):
+ def batch_reduce(self, aggregation, value_destination_pairs):
"""Combine multiple `reduce` calls into one for faster execution.
Args:
- method_string: A string indicating how to combine values, either
- "sum" or "mean".
+ aggregation: Indicates how a variable will be aggregated. Accepted values
+ are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -836,12 +822,17 @@ class DistributionStrategy(object):
"""
# TODO(josh11b): More docstring
_require_cross_tower_context(self)
- assert method_string in ("sum", "mean")
- return self._batch_reduce(method_string, value_destination_pairs)
-
- def _batch_reduce(self, method_string, value_destination_pairs):
- return [self.reduce(method_string, t, destinations=v)
- for t, v in value_destination_pairs]
+ assert aggregation in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN
+ ]
+ return self._batch_reduce(aggregation, value_destination_pairs)
+
+ def _batch_reduce(self, aggregation, value_destination_pairs):
+ return [
+ self.reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def update(self, var, fn, *args, **kwargs):
"""Run `fn` to update `var` using inputs mirrored to the same devices.
@@ -897,30 +888,6 @@ class DistributionStrategy(object):
def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
- def fetch(self, val, destination="/device:CPU:0", fn=lambda x: x):
- """Return a copy of `val` or `fn(val)` on `destination`.
-
- This is useful for getting a mirrored value onto a device. It
- will attempt to avoid a copy by checking if the value is already
- on the destination device.
-
- Args:
- val: Value (which may be mirrored) to copy.
- destination: A device string to copy the value to.
- fn: An optional function to apply to the value on the source
- device, before copying.
-
- Returns:
- A `Tensor` on `destination`.
- """
- _require_cross_tower_context(self)
- assert isinstance(destination, six.string_types)
- destination = device_util.resolve(destination)
- return self._fetch(val, destination, fn)
-
- def _fetch(self, val, destination, fn):
- raise NotImplementedError("must be implemented in descendants")
-
def unwrap(self, value):
"""Returns the list of all per-device values contained in `value`.
@@ -946,7 +913,7 @@ class DistributionStrategy(object):
return control_flow_ops.group(value, name=name)
# Special handling for the common case of one op.
v, = value
- if isinstance(v, ops.Tensor):
+ if hasattr(v, "op"):
v = v.op
return v
@@ -1094,10 +1061,6 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, reduce_method):
- """Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(reduce_method)
-
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
@@ -1144,22 +1107,11 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, reduce_method):
- """Does not set to resource variables."""
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["trainable"] = False
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
_require_distribution_strategy_scope(self)
@@ -1180,9 +1132,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
with TowerContext(self, tower_id=0):
return fn(*args, **kwargs)
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
# TODO(josh11b): Use destinations?
- del method_string, destinations
+ del aggregation, destinations
return value
def _update(self, var, fn, *args, **kwargs):
@@ -1197,11 +1149,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
return fn(*args, **kwargs)
- def _fetch(self, var, destination, fn):
- with ops.colocate_with(var):
- var = fn(var)
- with ops.device(destination):
- return array_ops.identity(var)
+ def read_var(self, tower_local_var):
+ return array_ops.identity(tower_local_var)
def _unwrap(self, distributed_value):
return [distributed_value]
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/training/distribute_test.py
index 0a4f19c31f..694145ede7 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/training/distribute_test.py
@@ -29,6 +29,14 @@ class _TestTowerContext(distribute.TowerContext):
return kwargs["test_arg"]
+def _get_test_variable(name, synchronization, aggregation):
+ return {
+ "name": name,
+ "synchronization": synchronization,
+ "aggregation": aggregation
+ }
+
+
class _TestStrategy(distribute.DistributionStrategy):
def _call_for_each_tower(self, fn, *args, **kwargs):
@@ -36,7 +44,8 @@ class _TestStrategy(distribute.DistributionStrategy):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
- return kwargs["name"]
+ return _get_test_variable(kwargs["name"], kwargs["synchronization"],
+ kwargs["aggregation"])
def _assert_in_default_state(t):
@@ -61,7 +70,11 @@ class TestStrategyTest(test.TestCase):
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
- self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
+ expected_value = _get_test_variable(
+ "bar", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="bar"))
with self.assertRaises(RuntimeError):
dist.call_for_each_tower(run_fn)
@@ -77,7 +90,27 @@ class TestStrategyTest(test.TestCase):
self.assertIs(dist, distribute.get_cross_tower_context())
self.assertTrue(distribute.has_distribution_strategy())
self.assertIs(dist, distribute.get_distribution_strategy())
- self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.AUTO,
+ variable_scope.VariableAggregation.NONE)
+ self.assertDictEqual(expected_value,
+ variable_scope.variable(1.0, name="baz"))
+ _assert_in_default_state(self)
+
+ def testSettingSynchronizationAndAggregation(self):
+ _assert_in_default_state(self)
+ dist = _TestStrategy()
+ with dist.scope():
+ expected_value = _get_test_variable(
+ "baz", variable_scope.VariableSynchronization.ON_WRITE,
+ variable_scope.VariableAggregation.MEAN)
+ self.assertDictEqual(
+ expected_value,
+ variable_scope.variable(
+ 1.0,
+ name="baz",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=variable_scope.VariableAggregation.MEAN))
_assert_in_default_state(self)
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index a07ad19a6e..ef50f6315d 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -41,6 +40,13 @@ class GradientDescentOptimizer(optimizer.Optimizer):
use_locking: If True use locks for update operations.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate` can be a callable that
+ takes no arguments and returns the actual value to use. This can be useful
+ for changing these values across different invocations of optimizer
+ functions.
+ @end_compatibility
"""
super(GradientDescentOptimizer, self).__init__(use_locking, name)
self._learning_rate = learning_rate
@@ -71,7 +77,6 @@ class GradientDescentOptimizer(optimizer.Optimizer):
return var.scatter_sub(delta, use_locking=self._use_locking)
def _prepare(self):
- if not context.executing_eagerly() or not isinstance(
- self._learning_rate_tensor, ops.EagerTensor):
- self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
- name="learning_rate")
+ learning_rate = self._call_if_callable(self._learning_rate)
+ self._learning_rate_tensor = ops.convert_to_tensor(
+ learning_rate, name="learning_rate")
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index f89a9c5838..b304e92421 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -83,6 +83,32 @@ class GradientDescentOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
var1.eval())
+ def testBasicCallableParams(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lr = lambda: 3.0
+ sgd_op = gradient_descent.GradientDescentOptimizer(lr).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
def testMinimizeResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.test_session():
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 10ab4c1137..51190264e8 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import math
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -87,6 +88,12 @@ def exponential_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for exponential_decay.")
@@ -95,14 +102,22 @@ def exponential_decay(learning_rate,
[learning_rate, global_step, decay_steps, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- return math_ops.multiply(
- learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ return math_ops.multiply(
+ learning_rate, math_ops.pow(decay_rate, p), name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.piecewise_constant")
@@ -141,48 +156,62 @@ def piecewise_constant(x, boundaries, values, name=None):
ValueError: if types of `x` and `boundaries` do not match, or types of all
`values` do not match or
the number of elements in the lists does not match.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if len(boundaries) != len(values) - 1:
raise ValueError(
"The length of boundaries should be 1 less than the length of values")
with ops.name_scope(name, "PiecewiseConstant",
[x, boundaries, values, name]) as name:
- x = ops.convert_to_tensor(x)
- # Avoid explicit conversion to x's dtype. This could result in faulty
- # comparisons, for example if floats are converted to integers.
boundaries = ops.convert_n_to_tensor(boundaries)
- for i, b in enumerate(boundaries):
- if b.dtype.base_dtype != x.dtype.base_dtype:
- # We can promote int32 boundaries to int64 without loss of precision.
- # This covers the most common case where the user passes in boundaries
- # as an array of Python integers.
- if (b.dtype.base_dtype == dtypes.int32 and
- x.dtype.base_dtype == dtypes.int64):
- b = math_ops.cast(b, x.dtype.base_dtype)
- boundaries[i] = b
- else:
- raise ValueError(
- "Boundaries (%s) must have the same dtype as x (%s)." %
- (b.dtype.base_dtype, x.dtype.base_dtype))
- # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
values = ops.convert_n_to_tensor(values)
- for v in values[1:]:
- if v.dtype.base_dtype != values[0].dtype.base_dtype:
- raise ValueError(
- "Values must have elements all with the same dtype (%s vs %s)." %
- (values[0].dtype.base_dtype, v.dtype.base_dtype))
- pred_fn_pairs = []
- pred_fn_pairs.append((x <= boundaries[0], lambda: values[0]))
- pred_fn_pairs.append((x > boundaries[-1], lambda: values[-1]))
- for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
- # Need to bind v here; can do this with lambda v=v: ...
- pred = (x > low) & (x <= high)
- pred_fn_pairs.append((pred, lambda v=v: v))
-
- # The default isn't needed here because our conditions are mutually
- # exclusive and exhaustive, but tf.case requires it.
- default = lambda: values[0]
- return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ x_recomp = ops.convert_to_tensor(x)
+ # Avoid explicit conversion to x's dtype. This could result in faulty
+ # comparisons, for example if floats are converted to integers.
+ for i, b in enumerate(boundaries):
+ if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
+ # We can promote int32 boundaries to int64 without loss of precision.
+ # This covers the most common case where the user passes in boundaries
+ # as an array of Python integers.
+ if (b.dtype.base_dtype == dtypes.int32 and
+ x_recomp.dtype.base_dtype == dtypes.int64):
+ b = math_ops.cast(b, x_recomp.dtype.base_dtype)
+ boundaries[i] = b
+ else:
+ raise ValueError(
+ "Boundaries (%s) must have the same dtype as x (%s)." %
+ (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
+ # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
+ for v in values[1:]:
+ if v.dtype.base_dtype != values[0].dtype.base_dtype:
+ raise ValueError(
+ "Values must have elements all with the same dtype (%s vs %s)." %
+ (values[0].dtype.base_dtype, v.dtype.base_dtype))
+ pred_fn_pairs = []
+ pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
+ pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
+ for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
+ # Need to bind v here; can do this with lambda v=v: ...
+ pred = (x_recomp > low) & (x_recomp <= high)
+ pred_fn_pairs.append((pred, lambda v=v: v))
+
+ # The default isn't needed here because our conditions are mutually
+ # exclusive and exhaustive, but tf.case requires it.
+ default = lambda: values[0]
+ return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.polynomial_decay")
@@ -263,6 +292,12 @@ def polynomial_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for polynomial_decay.")
@@ -272,27 +307,35 @@ def polynomial_decay(learning_rate,
]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
- decay_steps = math_ops.cast(decay_steps, dtype)
end_learning_rate = math_ops.cast(end_learning_rate, dtype)
power = math_ops.cast(power, dtype)
- if cycle:
- # Find the first multiple of decay_steps that is bigger than global_step.
- # If global_step is zero set the multiplier to 1
- multiplier = control_flow_ops.cond(
- math_ops.equal(global_step, 0), lambda: 1.0,
- lambda: math_ops.ceil(global_step / decay_steps))
- decay_steps = math_ops.multiply(decay_steps, multiplier)
- else:
- # Make sure that the global_step used is not bigger than decay_steps.
- global_step = math_ops.minimum(global_step, decay_steps)
-
- p = math_ops.div(global_step, decay_steps)
- return math_ops.add(
- math_ops.multiply(learning_rate - end_learning_rate,
- math_ops.pow(1 - p, power)),
- end_learning_rate,
- name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ decay_steps_recomp = math_ops.cast(decay_steps, dtype)
+ if cycle:
+ # Find the first multiple of decay_steps that is bigger than
+ # global_step. If global_step is zero set the multiplier to 1
+ multiplier = control_flow_ops.cond(
+ math_ops.equal(global_step_recomp, 0), lambda: 1.0,
+ lambda: math_ops.ceil(global_step_recomp / decay_steps))
+ decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
+ else:
+ # Make sure that the global_step used is not bigger than decay_steps.
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+
+ p = math_ops.div(global_step_recomp, decay_steps_recomp)
+ return math_ops.add(
+ math_ops.multiply(learning_rate - end_learning_rate,
+ math_ops.pow(1 - p, power)),
+ end_learning_rate,
+ name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.natural_exp_decay")
@@ -350,6 +393,12 @@ def natural_exp_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for natural_exp_decay.")
@@ -357,14 +406,23 @@ def natural_exp_decay(learning_rate,
[learning_rate, global_step, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- exponent = math_ops.exp(math_ops.multiply(math_ops.negative(decay_rate), p))
- return math_ops.multiply(learning_rate, exponent, name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ exponent = math_ops.exp(
+ math_ops.multiply(math_ops.negative(decay_rate), p))
+ return math_ops.multiply(learning_rate, exponent, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.inverse_time_decay")
@@ -432,6 +490,12 @@ def inverse_time_decay(learning_rate,
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("global_step is required for inverse_time_decay.")
@@ -439,15 +503,23 @@ def inverse_time_decay(learning_rate,
[learning_rate, global_step, decay_rate]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
decay_rate = math_ops.cast(decay_rate, dtype)
- p = global_step / decay_steps
- if staircase:
- p = math_ops.floor(p)
- const = math_ops.cast(constant_op.constant(1), learning_rate.dtype)
- denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
- return math_ops.div(learning_rate, denom, name=name)
+
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ p = global_step_recomp / decay_steps
+ if staircase:
+ p = math_ops.floor(p)
+ const = math_ops.cast(constant_op.constant(1), dtype)
+ denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
+ return math_ops.div(learning_rate, denom, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.cosine_decay")
@@ -492,6 +564,12 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("cosine decay requires global_step")
@@ -499,15 +577,23 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
- completed_fraction = global_step / decay_steps
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
- return math_ops.multiply(learning_rate, decayed)
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ completed_fraction = global_step_recomp / decay_steps
+ cosine_decayed = 0.5 * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+
+ decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.cosine_decay_restarts")
@@ -561,6 +647,12 @@ def cosine_decay_restarts(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("cosine decay restarts requires global_step")
@@ -568,40 +660,48 @@ def cosine_decay_restarts(learning_rate,
learning_rate = ops.convert_to_tensor(
learning_rate, name="initial_learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
first_decay_steps = math_ops.cast(first_decay_steps, dtype)
alpha = math_ops.cast(alpha, dtype)
t_mul = math_ops.cast(t_mul, dtype)
m_mul = math_ops.cast(m_mul, dtype)
- completed_fraction = global_step / first_decay_steps
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ completed_fraction = global_step_recomp / first_decay_steps
- def compute_step(completed_fraction, geometric=False):
- if geometric:
- i_restart = math_ops.floor(
- math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
- math_ops.log(t_mul))
+ def compute_step(completed_fraction, geometric=False):
+ """Helper for `cond` operation."""
+ if geometric:
+ i_restart = math_ops.floor(
+ math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
+ math_ops.log(t_mul))
- sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
- completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
+ sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
+ completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
- else:
- i_restart = math_ops.floor(completed_fraction)
- completed_fraction = completed_fraction - i_restart
+ else:
+ i_restart = math_ops.floor(completed_fraction)
+ completed_fraction -= i_restart
+
+ return i_restart, completed_fraction
- return i_restart, completed_fraction
+ i_restart, completed_fraction = control_flow_ops.cond(
+ math_ops.equal(t_mul, 1.0),
+ lambda: compute_step(completed_fraction, geometric=False),
+ lambda: compute_step(completed_fraction, geometric=True))
- i_restart, completed_fraction = control_flow_ops.cond(
- math_ops.equal(t_mul, 1.0),
- lambda: compute_step(completed_fraction, geometric=False),
- lambda: compute_step(completed_fraction, geometric=True))
+ m_fac = m_mul**i_restart
+ cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
+ constant_op.constant(math.pi) * completed_fraction))
+ decayed = (1 - alpha) * cosine_decayed + alpha
- m_fac = m_mul**i_restart
- cosine_decayed = 0.5 * m_fac * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * completed_fraction))
- decayed = (1 - alpha) * cosine_decayed + alpha
+ return math_ops.multiply(learning_rate, decayed, name=name)
- return math_ops.multiply(learning_rate, decayed, name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.linear_cosine_decay")
@@ -664,6 +764,12 @@ def linear_cosine_decay(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("linear cosine decay requires global_step")
@@ -671,21 +777,28 @@ def linear_cosine_decay(learning_rate,
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
num_periods = math_ops.cast(num_periods, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
alpha = math_ops.cast(alpha, dtype)
beta = math_ops.cast(beta, dtype)
- linear_decayed = (decay_steps - global_step) / decay_steps
- completed_fraction = global_step / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+
+ linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
+ return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
- linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
- return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
@tf_export("train.noisy_linear_cosine_decay")
@@ -756,6 +869,12 @@ def noisy_linear_cosine_decay(learning_rate,
learning rate.
Raises:
ValueError: if `global_step` is not supplied.
+
+ @compatibility(eager)
+ When eager execution is enabled, this function returns a function which in
+ turn returns the decayed learning rate Tensor. This can be useful for changing
+ the learning rate value across different invocations of optimizer functions.
+ @end_compatibility
"""
if global_step is None:
raise ValueError("noisy linear cosine decay requires global_step")
@@ -763,29 +882,36 @@ def noisy_linear_cosine_decay(learning_rate,
[learning_rate, global_step]) as name:
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
dtype = learning_rate.dtype
- global_step = math_ops.cast(global_step, dtype)
decay_steps = math_ops.cast(decay_steps, dtype)
- global_step = math_ops.minimum(global_step, decay_steps)
initial_variance = math_ops.cast(initial_variance, dtype)
variance_decay = math_ops.cast(variance_decay, dtype)
num_periods = math_ops.cast(num_periods, dtype)
alpha = math_ops.cast(alpha, dtype)
beta = math_ops.cast(beta, dtype)
- linear_decayed = (decay_steps - global_step) / decay_steps
- variance = initial_variance / (
- math_ops.pow(1.0 + global_step, variance_decay))
- std = math_ops.sqrt(variance)
- noisy_linear_decayed = (
- linear_decayed +
- random_ops.random_normal(linear_decayed.shape, stddev=std))
-
- completed_fraction = global_step / decay_steps
- fraction = 2.0 * num_periods * completed_fraction
- cosine_decayed = 0.5 * (
- 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
- noisy_linear_cosine_decayed = (
- (alpha + noisy_linear_decayed) * cosine_decayed + beta)
-
- return math_ops.multiply(
- learning_rate, noisy_linear_cosine_decayed, name=name)
+ def decayed_lr():
+ """Helper to recompute learning rate; most helpful in eager-mode."""
+ global_step_recomp = math_ops.cast(global_step, dtype)
+ global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
+ linear_decayed = (decay_steps - global_step_recomp) / decay_steps
+ variance = initial_variance / (
+ math_ops.pow(1.0 + global_step_recomp, variance_decay))
+ std = math_ops.sqrt(variance)
+ noisy_linear_decayed = (
+ linear_decayed + random_ops.random_normal(
+ linear_decayed.shape, stddev=std))
+
+ completed_fraction = global_step_recomp / decay_steps
+ fraction = 2.0 * num_periods * completed_fraction
+ cosine_decayed = 0.5 * (
+ 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
+ noisy_linear_cosine_decayed = (
+ (alpha + noisy_linear_decayed) * cosine_decayed + beta)
+
+ return math_ops.multiply(
+ learning_rate, noisy_linear_cosine_decayed, name=name)
+
+ if not context.executing_eagerly():
+ decayed_lr = decayed_lr()
+
+ return decayed_lr
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 60306e4f12..4f3cf01822 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -21,12 +21,9 @@ from __future__ import print_function
import math
from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import gen_state_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import learning_rate_decay
@@ -34,31 +31,35 @@ from tensorflow.python.training import learning_rate_decay
class LRDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testContinuous(self):
- with self.test_session():
- step = 5
- decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
- expected = .05 * 0.96 ** (5.0 / 10.0)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ self.evaluate(variables.global_variables_initializer())
+ step = 5
+ decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
+ expected = .05 * 0.96**(5.0 / 10.0)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testStaircase(self):
- with self.test_session():
- step = gen_state_ops.variable(shape=[], dtype=dtypes.int32,
- name="step", container="", shared_name="")
- assign_100 = state_ops.assign(step, 100)
- assign_1 = state_ops.assign(step, 1)
- assign_2 = state_ops.assign(step, 2)
- decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
- staircase=True)
- # No change to learning rate
- assign_1.op.run()
- self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
- assign_2.op.run()
- self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ if context.executing_eagerly():
+ step = resource_variable_ops.ResourceVariable(0)
+ self.evaluate(variables.global_variables_initializer())
+ decayed_lr = learning_rate_decay.exponential_decay(
+ .1, step, 3, 0.96, staircase=True)
+
+ # No change to learning rate due to staircase
+ expected = .1
+ self.evaluate(step.assign(1))
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ expected = .1
+ self.evaluate(step.assign(2))
+ self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)
+
# Decayed learning rate
- assign_100.op.run()
expected = .1 * 0.96 ** (100 // 3)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ self.evaluate(step.assign(100))
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testVariables(self):
with self.test_session():
@@ -79,38 +80,44 @@ class LRDecayTest(test_util.TensorFlowTestCase):
expected = .1 * 0.96 ** (100 // 3)
self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPiecewiseConstant(self):
x = resource_variable_ops.ResourceVariable(-999)
- def pc():
- return learning_rate_decay.piecewise_constant(x, [100, 110, 120],
- [1.0, 0.1, 0.01, 0.001])
+ decayed_lr = learning_rate_decay.piecewise_constant(
+ x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])
self.evaluate(variables.global_variables_initializer())
- self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6)
self.evaluate(x.assign(100))
- self.assertAllClose(self.evaluate(pc()), 1.0, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 1.0, 1e-6)
self.evaluate(x.assign(105))
- self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6)
self.evaluate(x.assign(110))
- self.assertAllClose(self.evaluate(pc()), 0.1, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.1, 1e-6)
self.evaluate(x.assign(120))
- self.assertAllClose(self.evaluate(pc()), 0.01, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.01, 1e-6)
self.evaluate(x.assign(999))
- self.assertAllClose(self.evaluate(pc()), 0.001, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.001, 1e-6)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testPiecewiseConstantEdgeCases(self):
x_int = resource_variable_ops.ResourceVariable(
0, dtype=variables.dtypes.int32)
boundaries, values = [-1.0, 1.0], [1, 2, 3]
with self.assertRaises(ValueError):
- learning_rate_decay.piecewise_constant(x_int, boundaries, values)
+ decayed_lr = learning_rate_decay.piecewise_constant(
+ x_int, boundaries, values)
+ if context.executing_eagerly():
+ decayed_lr()
+
x = resource_variable_ops.ResourceVariable(0.0)
boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
with self.assertRaises(ValueError):
- learning_rate_decay.piecewise_constant(x, boundaries, values)
+ decayed_lr = learning_rate_decay.piecewise_constant(
+ x, boundaries, values)
+ if context.executing_eagerly():
+ decayed_lr()
# Test that ref types are valid.
if not context.executing_eagerly():
@@ -123,221 +130,205 @@ class LRDecayTest(test_util.TensorFlowTestCase):
x_int64 = resource_variable_ops.ResourceVariable(
0, dtype=variables.dtypes.int64)
boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
- def pc():
- return learning_rate_decay.piecewise_constant(x_int64, boundaries, values)
+ decayed_lr = learning_rate_decay.piecewise_constant(
+ x_int64, boundaries, values)
self.evaluate(variables.global_variables_initializer())
- self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6)
self.evaluate(x_int64.assign(1))
- self.assertAllClose(self.evaluate(pc()), 0.4, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.4, 1e-6)
self.evaluate(x_int64.assign(2))
- self.assertAllClose(self.evaluate(pc()), 0.5, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.5, 1e-6)
self.evaluate(x_int64.assign(3))
- self.assertAllClose(self.evaluate(pc()), 0.6, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.6, 1e-6)
self.evaluate(x_int64.assign(4))
- self.assertAllClose(self.evaluate(pc()), 0.7, 1e-6)
+ self.assertAllClose(self.evaluate(decayed_lr), 0.7, 1e-6)
class LinearDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testHalfWay(self):
- with self.test_session():
- step = 5
- lr = 0.05
- end_lr = 0.0
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
- expected = lr * 0.5
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
+ expected = lr * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testEnd(self):
- with self.test_session():
- step = 10
- lr = 0.05
- end_lr = 0.001
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
- expected = end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testHalfWayWithEnd(self):
- with self.test_session():
- step = 5
- lr = 0.05
- end_lr = 0.001
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
- expected = (lr + end_lr) * 0.5
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
+ expected = (lr + end_lr) * 0.5
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testBeyondEnd(self):
- with self.test_session():
- step = 15
- lr = 0.05
- end_lr = 0.001
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
- expected = end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testBeyondEndWithCycle(self):
- with self.test_session():
- step = 15
- lr = 0.05
- end_lr = 0.001
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- cycle=True)
- expected = (lr - end_lr) * 0.25 + end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, cycle=True)
+ expected = (lr - end_lr) * 0.25 + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class SqrtDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testHalfWay(self):
- with self.test_session():
- step = 5
- lr = 0.05
- end_lr = 0.0
- power = 0.5
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- power=power)
- expected = lr * 0.5 ** power
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 5
+ lr = 0.05
+ end_lr = 0.0
+ power = 0.5
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = lr * 0.5**power
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testEnd(self):
- with self.test_session():
- step = 10
- lr = 0.05
- end_lr = 0.001
- power = 0.5
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- power=power)
- expected = end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 10
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testHalfWayWithEnd(self):
- with self.test_session():
- step = 5
- lr = 0.05
- end_lr = 0.001
- power = 0.5
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- power=power)
- expected = (lr - end_lr) * 0.5 ** power + end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 5
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = (lr - end_lr) * 0.5**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testBeyondEnd(self):
- with self.test_session():
- step = 15
- lr = 0.05
- end_lr = 0.001
- power = 0.5
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- power=power)
- expected = end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
-
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, power=power)
+ expected = end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+
+ @test_util.run_in_graph_and_eager_modes
def testBeyondEndWithCycle(self):
- with self.test_session():
- step = 15
- lr = 0.05
- end_lr = 0.001
- power = 0.5
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step, 10, end_lr,
- power=power, cycle=True)
- expected = (lr - end_lr) * 0.25 ** power + end_lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ step = 15
+ lr = 0.05
+ end_lr = 0.001
+ power = 0.5
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, 10, end_lr, power=power, cycle=True)
+ expected = (lr - end_lr) * 0.25**power + end_lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class PolynomialDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testBeginWithCycle(self):
- with self.test_session():
- lr = 0.001
- decay_steps = 10
- step = 0
- decayed_lr = learning_rate_decay.polynomial_decay(lr, step,
- decay_steps, cycle=True)
- expected = lr
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ lr = 0.001
+ decay_steps = 10
+ step = 0
+ decayed_lr = learning_rate_decay.polynomial_decay(
+ lr, step, decay_steps, cycle=True)
+ expected = lr
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class ExponentialDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testDecay(self):
initial_lr = 0.1
k = 10
decay_rate = 0.96
- step = gen_state_ops.variable(
- shape=[], dtype=dtypes.int32, name="step", container="", shared_name="")
- assign_step = state_ops.assign(step, 0)
- increment_step = state_ops.assign_add(step, 1)
- decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step,
- k, decay_rate)
- with self.test_session():
- assign_step.op.run()
- for i in range(k+1):
- expected = initial_lr * math.exp(-i / k * decay_rate)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
- increment_step.op.run()
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr, step, k,
+ decay_rate)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+ @test_util.run_in_graph_and_eager_modes
def testStaircase(self):
initial_lr = 0.1
k = 10
decay_rate = 0.96
- step = gen_state_ops.variable(
- shape=[], dtype=dtypes.int32, name="step", container="", shared_name="")
- assign_step = state_ops.assign(step, 0)
- increment_step = state_ops.assign_add(step, 1)
- decayed_lr = learning_rate_decay.natural_exp_decay(initial_lr,
- step,
- k,
- decay_rate,
- staircase=True)
- with self.test_session():
- assign_step.op.run()
- for i in range(k+1):
- expected = initial_lr * math.exp(-decay_rate * (i // k))
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
- increment_step.op.run()
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay.natural_exp_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr * math.exp(-decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
class InverseDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testDecay(self):
initial_lr = 0.1
k = 10
decay_rate = 0.96
- step = gen_state_ops.variable(
- shape=[], dtype=dtypes.int32, name="step", container="", shared_name="")
- assign_step = state_ops.assign(step, 0)
- increment_step = state_ops.assign_add(step, 1)
- decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr,
- step,
- k,
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr, step, k,
decay_rate)
- with self.test_session():
- assign_step.op.run()
- for i in range(k+1):
- expected = initial_lr / (1 + i / k * decay_rate)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
- increment_step.op.run()
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + i / k * decay_rate)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
+
+ @test_util.run_in_graph_and_eager_modes
def testStaircase(self):
initial_lr = 0.1
k = 10
decay_rate = 0.96
- step = gen_state_ops.variable(
- shape=[], dtype=dtypes.int32, name="step", container="", shared_name="")
- assign_step = state_ops.assign(step, 0)
- increment_step = state_ops.assign_add(step, 1)
- decayed_lr = learning_rate_decay.inverse_time_decay(initial_lr,
- step,
- k,
- decay_rate,
- staircase=True)
- with self.test_session():
- assign_step.op.run()
- for i in range(k+1):
- expected = initial_lr / (1 + decay_rate * (i // k))
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
- increment_step.op.run()
+ step = resource_variable_ops.ResourceVariable(0)
+ decayed_lr = learning_rate_decay.inverse_time_decay(
+ initial_lr, step, k, decay_rate, staircase=True)
+
+ self.evaluate(variables.global_variables_initializer())
+ for i in range(k + 1):
+ expected = initial_lr / (1 + decay_rate * (i // k))
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ self.evaluate(step.assign_add(1))
class CosineDecayTest(test_util.TensorFlowTestCase):
@@ -348,34 +339,35 @@ class CosineDecayTest(test_util.TensorFlowTestCase):
decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
return (1.0 - alpha) * decay + alpha
+ @test_util.run_in_graph_and_eager_modes
def testDecay(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay(
- initial_lr, step, num_training_steps)
- expected = self.np_cosine_decay(step, num_training_steps)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step,
+ num_training_steps)
+ expected = self.np_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testAlpha(self):
num_training_steps = 1000
initial_lr = 1.0
alpha = 0.1
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay(
- initial_lr, step, num_training_steps, alpha)
- expected = self.np_cosine_decay(step, num_training_steps, alpha)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay(initial_lr, step,
+ num_training_steps, alpha)
+ expected = self.np_cosine_decay(step, num_training_steps, alpha)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class CosineDecayRestartsTest(test_util.TensorFlowTestCase):
+
def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
alpha=0.0):
fac = 1.0
while step >= decay_steps:
- step = step - decay_steps
+ step -= decay_steps
decay_steps *= t_mul
fac *= m_mul
@@ -383,51 +375,51 @@ class CosineDecayRestartsTest(test_util.TensorFlowTestCase):
decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
return (1.0 - alpha) * decay + alpha
+ @test_util.run_in_graph_and_eager_modes
def testDecay(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay_restarts(
- initial_lr, step, num_training_steps)
- expected = self.np_cosine_decay_restarts(step, num_training_steps)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay_restarts(
+ initial_lr, step, num_training_steps)
+ expected = self.np_cosine_decay_restarts(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testAlpha(self):
num_training_steps = 1000
initial_lr = 1.0
alpha = 0.1
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay_restarts(
- initial_lr, step, num_training_steps, alpha=alpha)
- expected = self.np_cosine_decay_restarts(step, num_training_steps,
- alpha=alpha)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, alpha=alpha)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, alpha=alpha)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testMMul(self):
num_training_steps = 1000
initial_lr = 1.0
m_mul = 0.9
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay_restarts(
- initial_lr, step, num_training_steps, m_mul=m_mul)
- expected = self.np_cosine_decay_restarts(step, num_training_steps,
- m_mul=m_mul)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, m_mul=m_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, m_mul=m_mul)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testTMul(self):
num_training_steps = 1000
initial_lr = 1.0
t_mul = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.cosine_decay_restarts(
- initial_lr, step, num_training_steps, t_mul=t_mul)
- expected = self.np_cosine_decay_restarts(step, num_training_steps,
- t_mul=t_mul)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.cosine_decay_restarts(
+ initial_lr, step, num_training_steps, t_mul=t_mul)
+ expected = self.np_cosine_decay_restarts(
+ step, num_training_steps, t_mul=t_mul)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class LinearCosineDecayTest(test_util.TensorFlowTestCase):
@@ -444,65 +436,63 @@ class LinearCosineDecayTest(test_util.TensorFlowTestCase):
cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
return (alpha + linear_decayed) * cosine_decayed + beta
+ @test_util.run_in_graph_and_eager_modes
def testDefaultDecay(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.linear_cosine_decay(
- initial_lr, step, num_training_steps)
- expected = self.np_linear_cosine_decay(step, num_training_steps)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ expected = self.np_linear_cosine_decay(step, num_training_steps)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
+ @test_util.run_in_graph_and_eager_modes
def testNonDefaultDecay(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- decayed_lr = learning_rate_decay.linear_cosine_decay(
- initial_lr,
- step,
- num_training_steps,
- alpha=0.1,
- beta=1e-4,
- num_periods=5)
- expected = self.np_linear_cosine_decay(
- step,
- num_training_steps,
- alpha=0.1,
- beta=1e-4,
- num_periods=5)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ decayed_lr = learning_rate_decay.linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ expected = self.np_linear_cosine_decay(
+ step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
class NoisyLinearCosineDecayTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes
def testDefaultNoisyLinearCosine(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- # No numerical check because of noise
- decayed_lr = learning_rate_decay.noisy_linear_cosine_decay(
- initial_lr, step, num_training_steps)
- decayed_lr.eval()
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay.noisy_linear_cosine_decay(
+ initial_lr, step, num_training_steps)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr)
+ @test_util.run_in_graph_and_eager_modes
def testNonDefaultNoisyLinearCosine(self):
num_training_steps = 1000
initial_lr = 1.0
for step in range(0, 1500, 250):
- with self.test_session():
- # No numerical check because of noise
- decayed_lr = learning_rate_decay.noisy_linear_cosine_decay(
- initial_lr,
- step,
- num_training_steps,
- initial_variance=0.5,
- variance_decay=0.1,
- alpha=0.1,
- beta=1e-4,
- num_periods=5)
- decayed_lr.eval()
+ # No numerical check because of noise
+ decayed_lr = learning_rate_decay.noisy_linear_cosine_decay(
+ initial_lr,
+ step,
+ num_training_steps,
+ initial_variance=0.5,
+ variance_decay=0.1,
+ alpha=0.1,
+ beta=1e-4,
+ num_periods=5)
+ # Cannot be deterministically tested
+ self.evaluate(decayed_lr)
if __name__ == "__main__":
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index bd9fa79d8f..cb3ec6f053 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -61,8 +61,8 @@ class MomentumOptimizer(optimizer.Optimizer):
variable(s) track the values called `theta_t + mu*v_t` in the paper.
@compatibility(eager)
- When eager execution is enabled, learning_rate and momentum can each be a
- callable that takes no arguments and returns the actual value to use. This
+ When eager execution is enabled, `learning_rate` and `momentum` can each be
+ a callable that takes no arguments and returns the actual value to use. This
can be useful for changing these values across different invocations of
optimizer functions.
@end_compatibility
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index fece3370f3..7b06bffa4b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -298,7 +298,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
stop_grace_period_secs=120,
log_step_count_steps=100,
max_wait_secs=7200,
- save_checkpoint_steps=USE_DEFAULT):
+ save_checkpoint_steps=USE_DEFAULT,
+ summary_dir=None):
"""Creates a `MonitoredSession` for training.
For a chief, this utility sets proper session initializer/restorer. It also
@@ -348,6 +349,8 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
`save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
the default checkpoint saver isn't used. If both are provided, then only
`save_checkpoint_secs` is used. Default not enabled.
+ summary_dir: A string. Optional path to a directory where to
+ save summaries. If None, checkpoint_dir is used instead.
Returns:
A `MonitoredSession` object.
@@ -388,11 +391,12 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
master=master,
config=config)
- if checkpoint_dir:
+ summary_dir = summary_dir or checkpoint_dir
+ if summary_dir:
if log_step_count_steps and log_step_count_steps > 0:
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
- output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))
+ output_dir=summary_dir, every_n_steps=log_step_count_steps))
if (save_summaries_steps and save_summaries_steps > 0) or (
save_summaries_secs and save_summaries_secs > 0):
@@ -400,7 +404,9 @@ def MonitoredTrainingSession(master='', # pylint: disable=invalid-name
scaffold=scaffold,
save_steps=save_summaries_steps,
save_secs=save_summaries_secs,
- output_dir=checkpoint_dir))
+ output_dir=summary_dir))
+
+ if checkpoint_dir:
if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
save_checkpoint_steps and save_checkpoint_steps > 0):
all_hooks.append(basic_session_run_hooks.CheckpointSaverHook(
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index a9287a0f0d..f75db08059 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -77,9 +77,10 @@ def _deduplicate_indexed_slices(values, indices):
def _var_key(var):
- if context.executing_eagerly():
- return var._unique_id # pylint: disable=protected-access
- return (var.op.graph, var.op.name)
+ # TODO(ashankar): Consolidate handling for eager and graph
+ if hasattr(var, "op"):
+ return (var.op.graph, var.op.name)
+ return var._unique_id # pylint: disable=protected-access
class _OptimizableVariable(object):
@@ -461,7 +462,8 @@ class Optimizer(
# Have to be careful to call distribute_lib.get_loss_reduction()
# *after* loss() is evaluated, so we know what loss reduction it uses.
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if distribute_lib.get_loss_reduction() == "mean":
+ if (distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN):
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss_value *= (1. / num_towers)
@@ -478,7 +480,8 @@ class Optimizer(
"be a function when eager execution is enabled.")
# Scale loss if using a "mean" loss reduction and multiple towers.
- if distribute_lib.get_loss_reduction() == "mean":
+ if (distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN):
num_towers = distribute_lib.get_distribution_strategy().num_towers
if num_towers > 1:
loss *= (1. / num_towers)
@@ -649,7 +652,8 @@ class Optimizer(
towers. If `global_step` was not None, that operation also
increments `global_step`.
"""
- reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
# Note that this is called in a cross-tower context.
@@ -730,15 +734,15 @@ class Optimizer(
if not named_slots:
return None
- if hasattr(var, "_mirrored_container"):
+ if hasattr(var, "_distributed_container"):
# NOTE: If this isn't patched, then there is no `handle` in
# `_resource_apply_dense`.
- mirrored_container = var._mirrored_container()
- assert mirrored_container is not None
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
if context.executing_eagerly():
- key = mirrored_container._unique_id
+ key = distributed_container._unique_id
else:
- key = (mirrored_container.graph, mirrored_container._shared_name)
+ key = (distributed_container.graph, distributed_container._shared_name)
# pylint: enable=protected-access
mirrored_slot = named_slots.get(key, None)
if mirrored_slot is None: return None
@@ -839,7 +843,7 @@ class Optimizer(
def _get_non_slot_variable(self, name, graph=None):
non_slot = self._non_slot_dict.get((name, graph), None)
- if hasattr(non_slot, "_mirrored_container"):
+ if hasattr(non_slot, "_distributed_container"):
# This is a mirrored non-slot. In order to enable code like `_finish`
# to assign to a non-slot, return the current context replica.
return non_slot.get()
@@ -1211,3 +1215,7 @@ class Optimizer(
self._deferred_slot_restorations.setdefault(
slot_name, {}).setdefault(variable_key, []).append(
slot_variable_position)
+
+ def _call_if_callable(self, param):
+ """Call the function if param is callable."""
+ return param() if callable(param) else param
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 0cab6410e8..dfe9176bea 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.training import gradient_descent
class OptimizerTest(test.TestCase):
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testBasic(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -112,7 +112,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
var1.eval())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoVariables(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
# pylint: disable=cell-var-from-loop
@@ -127,7 +127,7 @@ class OptimizerTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'No.*variables'):
sgd_op.minimize(loss)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -145,7 +145,7 @@ class OptimizerTest(test.TestCase):
# var1 has no gradient
sgd_op.minimize(loss, var_list=[var1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_Minimize(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -161,7 +161,7 @@ class OptimizerTest(test.TestCase):
'No gradients provided for any variable'):
sgd_op.minimize(loss, var_list=[var0, var1])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNoGradientsForAnyVariables_ApplyGradients(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -175,7 +175,7 @@ class OptimizerTest(test.TestCase):
'No gradients provided for any variable'):
sgd_op.apply_gradients([(None, var0), (None, var1)])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testGradientsAsVariables(self):
for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
# Note that we name the variables uniquely here since the variables don't
@@ -215,7 +215,7 @@ class OptimizerTest(test.TestCase):
self.assertAllClose([-14., -13.], self.evaluate(var0))
self.assertAllClose([-6., -5.], self.evaluate(var1))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testComputeGradientsWithTensors(self):
x = ops.convert_to_tensor(1.0)
def f():
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index 341b970c92..f38c9861d6 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -92,6 +92,13 @@ class RMSPropOptimizer(optimizer.Optimizer):
computation and memory. Defaults to False.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
+
+ @compatibility(eager)
+ When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions.
+ @end_compatibility
"""
super(RMSPropOptimizer, self).__init__(use_locking, name)
self._learning_rate = learning_rate
@@ -120,12 +127,15 @@ class RMSPropOptimizer(optimizer.Optimizer):
self._zeros_slot(v, "momentum", self._name)
def _prepare(self):
- self._learning_rate_tensor = ops.convert_to_tensor(
- self._learning_rate, name="learning_rate")
- self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay")
- self._momentum_tensor = ops.convert_to_tensor(
- self._momentum, name="momentum")
- self._epsilon_tensor = ops.convert_to_tensor(self._epsilon, name="epsilon")
+ lr = self._call_if_callable(self._learning_rate)
+ decay = self._call_if_callable(self._decay)
+ momentum = self._call_if_callable(self._momentum)
+ epsilon = self._call_if_callable(self._epsilon)
+
+ self._learning_rate_tensor = ops.convert_to_tensor(lr, name="learning_rate")
+ self._decay_tensor = ops.convert_to_tensor(decay, name="decay")
+ self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
+ self._epsilon_tensor = ops.convert_to_tensor(epsilon, name="epsilon")
def _apply_dense(self, grad, var):
rms = self.get_slot(var, "rms")
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index ee5385596c..6043327384 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -24,6 +24,7 @@ import math
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -141,7 +142,7 @@ class RMSPropOptimizerTest(test.TestCase):
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 4 steps of RMSProp
- for t in range(1, 5):
+ for _ in range(1, 5):
update.run()
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
@@ -261,7 +262,7 @@ class RMSPropOptimizerTest(test.TestCase):
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 4 steps of RMSProp
- for t in range(1, 5):
+ for _ in range(1, 5):
update.run()
var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
@@ -444,6 +445,55 @@ class RMSPropOptimizerTest(test.TestCase):
(0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 2e-5)))
]), var1.eval())
+ def testCallableParams(self):
+ with context.eager_mode():
+ for dtype in [dtypes.half, dtypes.float32]:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ learning_rate = lambda: 2.0
+ decay = lambda: 0.9
+ momentum = lambda: 0.0
+ epsilon = lambda: 1.0
+ opt = rmsprop.RMSPropOptimizer(learning_rate, decay, momentum, epsilon)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Step 1: the rms accumulators where 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0))
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
+ ]), self.evaluate(var1))
+ # Step 2: the root mean square accumulators contain the previous update.
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901 + 1.0)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001 + 1.0))
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5 + 1.0))
+ ]), self.evaluate(var1))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 4d464135fd..1ee975fbe4 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import collections
import os.path
import re
-import sys
import time
import uuid
@@ -206,21 +205,19 @@ class BaseSaverBuilder(object):
filename_tensor: String Tensor.
saveables: List of BaseSaverBuilder.SaveableObject objects.
preferred_shard: Int. Shard to open first when loading a sharded file.
- restore_sequentially: Bool. If true, each restore is sequential.
+ restore_sequentially: Unused. Bool. If true, each restore is sequential.
Returns:
A list of Tensors resulting from reading 'saveable' from
'filename'.
"""
+ del restore_sequentially
all_tensors = []
- assign_ops = []
for saveable in saveables:
- restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
- with ops.control_dependencies(restore_control_inputs):
- all_tensors.extend(
- self.restore_op(filename_tensor, saveable, preferred_shard))
+ all_tensors.extend(
+ self.restore_op(filename_tensor, saveable, preferred_shard))
return all_tensors
# pylint: disable=unused-argument
@@ -1045,8 +1042,8 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
ckpt = CheckpointState()
text_format.Merge(file_content, ckpt)
if not ckpt.model_checkpoint_path:
- raise ValueError("Invalid checkpoint state loaded from %s",
- checkpoint_dir)
+ raise ValueError("Invalid checkpoint state loaded from "
+ + checkpoint_dir)
# For relative model_checkpoint_path and all_model_checkpoint_paths,
# prepend checkpoint_dir.
if not os.path.isabs(ckpt.model_checkpoint_path):
@@ -1373,23 +1370,6 @@ class Saver(object):
name, _ = p
return name
- def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
- """Returns the meta graph filename.
-
- Args:
- checkpoint_filename: Name of the checkpoint file.
- meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
-
- Returns:
- MetaGraph file name.
- """
- # If the checkpoint_filename is sharded, the checkpoint_filename could
- # be of format model.ckpt-step#-?????-of-shard#. For example,
- # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
- basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
- meta_graph_filename = ".".join([basename, meta_graph_suffix])
- return meta_graph_filename
-
def _RecordLastCheckpoint(self, latest_save_path):
"""Manages the list of the latest checkpoints."""
if not self.saver_def.max_to_keep:
@@ -1430,24 +1410,12 @@ class Saver(object):
# Otherwise delete the files.
try:
- checkpoint_prefix = self._CheckpointFilename(p)
- self._delete_file_if_exists(
- self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
- if self.saver_def.version == saver_pb2.SaverDef.V2:
- # V2 has a metadata file and some data files.
- self._delete_file_if_exists(checkpoint_prefix + ".index")
- self._delete_file_if_exists(checkpoint_prefix +
- ".data-?????-of-?????")
- else:
- # V1, Legacy. Exact match on the data file.
- self._delete_file_if_exists(checkpoint_prefix)
+ remove_checkpoint(
+ self._CheckpointFilename(p), self.saver_def.version,
+ meta_graph_suffix)
except Exception as e: # pylint: disable=broad-except
logging.warning("Ignoring: %s", str(e))
- def _delete_file_if_exists(self, filespec):
- for pathname in file_io.get_matching_files(filespec):
- file_io.delete_file(pathname)
-
def as_saver_def(self):
"""Generates a `SaverDef` representation of this saver.
@@ -1669,7 +1637,7 @@ class Saver(object):
raise exc
if write_meta_graph:
- meta_graph_filename = self._MetaGraphFilename(
+ meta_graph_filename = _meta_graph_filename(
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if not context.executing_eagerly():
with sess.graph.as_default():
@@ -1737,12 +1705,17 @@ class Saver(object):
save_path: Path where parameters were previously saved.
Raises:
- ValueError: If save_path is None.
+ ValueError: If save_path is None or not a valid checkpoint.
"""
if self._is_empty:
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+
+ if not checkpoint_exists(compat.as_text(save_path)):
+ raise ValueError("The passed save_path is not a valid checkpoint: "
+ + compat.as_text(save_path))
+
logging.info("Restoring parameters from %s", compat.as_text(save_path))
try:
if context.executing_eagerly():
@@ -1750,23 +1723,24 @@ class Saver(object):
else:
sess.run(self.saver_def.restore_op_name,
{self.saver_def.filename_tensor_name: save_path})
- except errors.NotFoundError:
- exception_type, exception_value, exception_traceback = sys.exc_info()
- # The checkpoint would not be loaded successfully as is. Try to parse it
- # as an object-based checkpoint.
- should_reraise = False
+ except errors.NotFoundError as err:
+ # There are three common conditions that might cause this error:
+ # 0. The file is missing. We ignore here, as this is checked above.
+ # 1. This is an object-based checkpoint trying name-based loading.
+ # 2. The graph has been altered and a variable or other name is missing.
+
+ # 1. The checkpoint would not be loaded successfully as is. Try to parse
+ # it as an object-based checkpoint.
try:
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
object_graph_string = reader.get_tensor(
checkpointable.OBJECT_GRAPH_PROTO_KEY)
except errors.NotFoundError:
- # This is not an object-based checkpoint, or the checkpoint doesn't
- # exist. Re-raise the original exception, but do it outside the except
- # block so the object graph lookup isn't included in the stack trace.
- should_reraise = True
- if should_reraise:
- six.reraise(exception_type, exception_value, exception_traceback)
- del exception_traceback # avoid reference cycles
+ # 2. This is not an object-based checkpoint, which likely means there
+ # is a graph mismatch. Re-raise the original error with
+ # a helpful message (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a Variable name or other graph key that is missing")
# This is an object-based checkpoint. We'll print a warning and then do
# the restore.
@@ -1778,6 +1752,11 @@ class Saver(object):
self._restore_from_object_based_checkpoint(
sess=sess, save_path=save_path,
object_graph_string=object_graph_string)
+ except errors.InvalidArgumentError as err:
+ # There is a mismatch between the graph and the checkpoint being loaded.
+ # We add a more reasonable error message here to help users (b/110263146)
+ raise _wrap_restore_error_with_msg(
+ err, "a mismatch between the current graph and the graph")
def _restore_from_object_based_checkpoint(self, sess, save_path,
object_graph_string):
@@ -1970,7 +1949,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
else:
- if variables._all_saveable_objects(): # pylint: disable=protected-access
+ if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
# Return the default saver instance for all graph variables.
return Saver()
else:
@@ -2121,6 +2100,63 @@ def get_checkpoint_mtimes(checkpoint_prefixes):
return mtimes
+@tf_export("train.remove_checkpoint")
+def remove_checkpoint(checkpoint_prefix,
+ checkpoint_format_version=saver_pb2.SaverDef.V2,
+ meta_graph_suffix="meta"):
+ """Removes a checkpoint given by `checkpoint_prefix`.
+
+ Args:
+ checkpoint_prefix: The prefix of a V1 or V2 checkpoint. Typically the result
+ of `Saver.save()` or that of `tf.train.latest_checkpoint()`, regardless of
+ sharded/non-sharded or V1/V2.
+ checkpoint_format_version: `SaverDef.CheckpointFormatVersion`, defaults to
+ `SaverDef.V2`.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+ """
+ _delete_file_if_exists(
+ _meta_graph_filename(checkpoint_prefix, meta_graph_suffix))
+ if checkpoint_format_version == saver_pb2.SaverDef.V2:
+ # V2 has a metadata file and some data files.
+ _delete_file_if_exists(checkpoint_prefix + ".index")
+ _delete_file_if_exists(checkpoint_prefix + ".data-?????-of-?????")
+ else:
+ # V1, Legacy. Exact match on the data file.
+ _delete_file_if_exists(checkpoint_prefix)
+
+
+def _delete_file_if_exists(filespec):
+ """Deletes files matching `filespec`."""
+ for pathname in file_io.get_matching_files(filespec):
+ file_io.delete_file(pathname)
+
+
+def _meta_graph_filename(checkpoint_filename, meta_graph_suffix="meta"):
+ """Returns the meta graph filename.
+
+ Args:
+ checkpoint_filename: Name of the checkpoint file.
+ meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
+
+ Returns:
+ MetaGraph file name.
+ """
+ # If the checkpoint_filename is sharded, the checkpoint_filename could
+ # be of format model.ckpt-step#-?????-of-shard#. For example,
+ # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
+ basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
+ meta_graph_filename = ".".join([basename, meta_graph_suffix])
+ return meta_graph_filename
+
+
+def _wrap_restore_error_with_msg(err, extra_verbiage):
+ err_msg = ("Restoring from checkpoint failed. This is most likely "
+ "due to {} from the checkpoint. Please ensure that you "
+ "have not altered the graph expected based on the checkpoint. "
+ "Original error:\n\n{}").format(extra_verbiage, err.message)
+ return err.__class__(err.node_def, err.op, err_msg)
+
+
ops.register_proto_function(
ops.GraphKeys.SAVERS,
proto_type=saver_pb2.SaverDef,
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index f1991093e0..ae9c244aaf 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -24,10 +24,8 @@ import math
import os
import random
import shutil
-import sys
import tempfile
import time
-import traceback
import numpy as np
import six
@@ -79,7 +77,8 @@ from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import base as checkpointable_base
+from tensorflow.python.training.checkpointable import tracking as checkpointable_tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -171,7 +170,7 @@ class SaverTest(test.TestCase):
def testBasic(self):
self.basicSaveRestore(variables.Variable)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
@@ -252,7 +251,7 @@ class SaverTest(test.TestCase):
self.assertAllEqual(w3.eval(), 3.0)
self.assertAllEqual(w4.eval(), 4.0)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testResourceSaveRestoreCachingDevice(self):
save_path = os.path.join(self.get_temp_dir(), "resource_cache")
with self.test_session(graph=ops_lib.Graph()) as sess:
@@ -368,8 +367,8 @@ class SaverTest(test.TestCase):
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.test_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
- with self.assertRaisesRegexp(errors.NotFoundError,
- "Failed to find any matching files for"):
+ with self.assertRaisesRegexp(
+ ValueError, "The passed save_path is not a valid checkpoint:"):
save.restore(sess, "invalid path")
def testInt64(self):
@@ -671,7 +670,7 @@ class SaverTest(test.TestCase):
save.restore(sess, save_path)
self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testSaveWithGlobalStep(self, pad_step_number=False):
save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
global_step_int = 5
@@ -809,7 +808,7 @@ class SaveRestoreShardedTest(test.TestCase):
self.assertEqual(save_path + "-?????-of-00002", val)
else:
self.assertEqual(save_path, val)
- meta_graph_filename = save._MetaGraphFilename(val)
+ meta_graph_filename = saver_module._meta_graph_filename(val)
self.assertEqual(save_path + ".meta", meta_graph_filename)
if save._write_version is saver_pb2.SaverDef.V1:
@@ -1185,13 +1184,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s3, s2], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s1))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertTrue(saver_module.checkpoint_exists(s3))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1202,13 +1201,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1222,14 +1221,14 @@ class MaxToKeepTest(test.TestCase):
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertCheckpointState(
model_checkpoint_path=s2,
all_model_checkpoint_paths=[s3, s2],
@@ -1240,13 +1239,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save2.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1260,14 +1259,14 @@ class MaxToKeepTest(test.TestCase):
# Created by the first helper.
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
# Deleted by the first helper.
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
# Even though the file for s1 exists, this saver isn't aware of it, which
# is why it doesn't end up in the checkpoint state.
self.assertCheckpointState(
@@ -1280,13 +1279,13 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual([s2, s1], save3.last_checkpoints)
self.assertFalse(saver_module.checkpoint_exists(s3))
self.assertFalse(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s3)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s3)))
self.assertTrue(saver_module.checkpoint_exists(s2))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s2)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s2)))
self.assertTrue(saver_module.checkpoint_exists(s1))
self.assertTrue(
- saver_module.checkpoint_exists(save._MetaGraphFilename(s1)))
+ saver_module.checkpoint_exists(saver_module._meta_graph_filename(s1)))
self.assertCheckpointState(
model_checkpoint_path=s1,
all_model_checkpoint_paths=[s2, s1],
@@ -1317,7 +1316,7 @@ class MaxToKeepTest(test.TestCase):
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
+ self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
s2 = save.save(sess, os.path.join(save_dir, "s2"))
self.assertEqual([s1, s2], save.last_checkpoints)
@@ -1325,27 +1324,27 @@ class MaxToKeepTest(test.TestCase):
self.assertEqual(2, len(gfile.Glob(s1)))
else:
self.assertEqual(4, len(gfile.Glob(s1 + "*")))
- self.assertTrue(gfile.Exists(save._MetaGraphFilename(s1)))
+ self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
+ self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
s3 = save.save(sess, os.path.join(save_dir, "s3"))
self.assertEqual([s2, s3], save.last_checkpoints)
self.assertEqual(0, len(gfile.Glob(s1 + "*")))
- self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
+ self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s2)))
else:
self.assertEqual(4, len(gfile.Glob(s2 + "*")))
- self.assertTrue(gfile.Exists(save._MetaGraphFilename(s2)))
+ self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s2)))
if save._write_version is saver_pb2.SaverDef.V1:
self.assertEqual(2, len(gfile.Glob(s3)))
else:
self.assertEqual(4, len(gfile.Glob(s3 + "*")))
- self.assertTrue(gfile.Exists(save._MetaGraphFilename(s3)))
+ self.assertTrue(gfile.Exists(saver_module._meta_graph_filename(s3)))
def testNoMaxToKeep(self):
save_dir = self._get_test_dir("no_max_to_keep")
@@ -1385,7 +1384,7 @@ class MaxToKeepTest(test.TestCase):
s1 = save.save(sess, os.path.join(save_dir, "s1"), write_meta_graph=False)
self.assertTrue(saver_module.checkpoint_exists(s1))
- self.assertFalse(gfile.Exists(save._MetaGraphFilename(s1)))
+ self.assertFalse(gfile.Exists(saver_module._meta_graph_filename(s1)))
class KeepCheckpointEveryNHoursTest(test.TestCase):
@@ -1395,7 +1394,7 @@ class KeepCheckpointEveryNHoursTest(test.TestCase):
gfile.MakeDirs(test_dir)
return test_dir
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(saver_module, "time")
def testNonSharded(self, mock_time):
save_dir = self._get_test_dir("keep_checkpoint_every_n_hours")
@@ -1515,7 +1514,7 @@ class SaveRestoreWithVariableNameMap(test.TestCase):
self.assertEqual(10.0, self.evaluate(v0))
self.assertEqual(20.0, self.evaluate(v1))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNonReshapeResourceVariable(self):
self._testNonReshape(resource_variable_ops.ResourceVariable)
@@ -2339,6 +2338,46 @@ class MetaGraphTest(test.TestCase):
10, size=[1, 10])
})
+ def testImportIntoNamescopeWithoutVariables(self):
+ # Save a simple graph that contains no variables into a checkpoint.
+ test_dir = self._get_test_dir("no_vars_graph")
+ filename = os.path.join(test_dir, "ckpt")
+ graph_1 = ops_lib.Graph()
+ with session.Session(graph=graph_1) as sess:
+ constant_op.constant([1, 2, 3], name="x")
+ constant_op.constant([1, 2, 3], name="y")
+ saver = saver_module.Saver(allow_empty=True)
+ saver.save(sess, filename)
+
+ # Create a fresh graph.
+ graph_2 = ops_lib.Graph()
+ with session.Session(graph=graph_2) as sess:
+ # Restore the above checkpoint under scope "subgraph_1".
+ new_saver_1 = saver_module.import_meta_graph(
+ filename + ".meta", graph=graph_2, import_scope="subgraph_1")
+ # There are no variables to restore, so import_meta_graph should not
+ # return a Saver.
+ self.assertIsNone(new_saver_1)
+
+ # Create a variable in graph_2 under scope "my_scope".
+ variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
+ sess.run(variables.global_variables_initializer())
+ # Restore the checkpoint into a different scope "subgraph_2".
+ new_saver_2 = saver_module.import_meta_graph(
+ filename + ".meta", graph=graph_2, import_scope="subgraph_2")
+ # Because the variable does not live in scope "subgraph_2",
+ # import_meta_graph should not attempt to restore the variable. So,
+ # import_meta_graph still won't return a Saver instance.
+ self.assertIsNone(new_saver_2)
+
+ # However, if we restore the checkpoint under scope "my_scope",
+ # import_meta_graph will detect the variable and return a Saver for
+ # restoring it. This should happen even when the variable does not
+ # originate from graph_1.
+ new_saver_3 = saver_module.import_meta_graph(
+ filename + ".meta", graph=graph_2, import_scope="my_scope")
+ self.assertIsInstance(new_saver_3, saver_module.Saver)
+
def testImportIntoImplicitNamescope(self):
# Test that we can import a meta graph into an implicit namescope.
test_dir = self._get_test_dir("import_into_namescope")
@@ -2581,6 +2620,20 @@ class SaverUtilsTest(test.TestCase):
self.assertEqual(2, len(mtimes))
self.assertTrue(mtimes[1] >= mtimes[0])
+ def testRemoveCheckpoint(self):
+ for sharded in (False, True):
+ for version in (saver_pb2.SaverDef.V2, saver_pb2.SaverDef.V1):
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ unused_v = variables.Variable(1.0, name="v")
+ variables.global_variables_initializer().run()
+ saver = saver_module.Saver(sharded=sharded, write_version=version)
+
+ path = os.path.join(self._base_dir, "%s-%s" % (sharded, version))
+ ckpt_prefix = saver.save(sess, path)
+ self.assertTrue(saver_module.checkpoint_exists(ckpt_prefix))
+ saver_module.remove_checkpoint(ckpt_prefix, version)
+ self.assertFalse(saver_module.checkpoint_exists(ckpt_prefix))
+
class ScopedGraphTest(test.TestCase):
@@ -2885,7 +2938,7 @@ class ScopedGraphTest(test.TestCase):
self.assertEqual(2.0, var_dict2["variable2:0"].eval())
-class _OwnsAVariableSimple(checkpointable.CheckpointableBase):
+class _OwnsAVariableSimple(checkpointable_base.CheckpointableBase):
"""A Checkpointable object which can be saved using a tf.train.Saver."""
def __init__(self):
@@ -2893,7 +2946,7 @@ class _OwnsAVariableSimple(checkpointable.CheckpointableBase):
name="non_dep_variable", initializer=6., use_resource=True)
def _gather_saveables_for_checkpoint(self):
- return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable}
+ return {checkpointable_base.VARIABLE_VALUE_KEY: self.non_dep_variable}
# The Saver sorts by name before parsing, so we need a name property.
@property
@@ -2918,7 +2971,7 @@ class _MirroringSaveable(
self._mirrored_variable.assign(tensor))
-class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
+class _OwnsMirroredVariables(checkpointable_base.CheckpointableBase):
"""A Checkpointable object which returns a more complex SaveableObject."""
def __init__(self):
@@ -2933,7 +2986,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
primary_variable=self.non_dep_variable,
mirrored_variable=self.mirrored,
name=name)
- return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+ return {checkpointable_base.VARIABLE_VALUE_KEY: _saveable_factory}
# The Saver sorts by name before parsing, so we need a name property.
@property
@@ -2941,7 +2994,7 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
return self.non_dep_variable.name
-class NonLayerCheckpointable(checkpointable.Checkpointable):
+class NonLayerCheckpointable(checkpointable_tracking.Checkpointable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
@@ -2967,7 +3020,7 @@ class MyModel(training.Model):
class CheckpointableCompatibilityTests(test.TestCase):
# TODO(allenl): Track down python3 reference cycles in these tests.
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsCheckpointable(self):
v = _OwnsAVariableSimple()
saver = saver_module.Saver(var_list=[v])
@@ -2980,7 +3033,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver.restore(sess, save_path)
self.assertEqual(42., self.evaluate(v.non_dep_variable))
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
saver = saver_module.Saver(var_list=[v])
@@ -3084,27 +3137,33 @@ class CheckpointableCompatibilityTests(test.TestCase):
errors.NotFoundError, "Key b not found in checkpoint"):
b_saver.restore(sess=sess, save_path=save_path)
- def testCheckpointNotFoundErrorRaised(self):
- # Restore does some tricky exception handling to figure out if it should
- # load an object-based checkpoint. Tests that the exception handling isn't
- # too broad.
- a = resource_variable_ops.ResourceVariable(1., name="a")
- saver = saver_module.Saver([a])
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.NotFoundError,
- "Failed to find any matching files for path_which_does_not_exist"):
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- try:
- saver.restore(sess=sess, save_path="path_which_does_not_exist")
- except errors.NotFoundError:
- # Make sure we don't have a confusing "During handling of the above
- # exception" block in Python 3.
- # pylint: disable=no-value-for-parameter
- exception_string = "\n".join(
- traceback.format_exception(*sys.exc_info()))
- # pylint: enable=no-value-for-parameter
- self.assertNotIn("NewCheckpointReader", exception_string)
+ with self.assertRaises(errors.NotFoundError) as cs:
+ b_saver.restore(sess=sess, save_path=save_path)
+
+ # Make sure we don't have a confusing "During handling of the above
+ # exception" block in Python 3.
+ self.assertNotIn("NewCheckpointReader", cs.exception.message)
+
+ def testGraphChangedForRestoreErrorRaised(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable(1., name="a")
+ a_saver = saver_module.Saver([a])
+
+ with self.test_session(graph=g) as sess:
+ sess.run(a.initializer)
+ save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
+
+ with ops_lib.Graph().as_default() as g:
+ a = variables.Variable([1.], name="a")
+ a_saver = saver_module.Saver([a])
+ with self.test_session(graph=g) as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "a mismatch between the current graph and the graph"):
+ a_saver.restore(sess=sess, save_path=save_path)
def testLoadFromObjectBasedGraph(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 376be39978..c8ed2b715d 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -87,6 +87,27 @@ def _call_location(outer=False):
return '%s:%d' % (entry[1], entry[2])
+def _wrap_decorator(wrapped_function):
+ """Indicate that one function wraps another.
+
+ This decorator wraps a function using `tf_decorator.make_decorator`
+ so that doc generation scripts can pick up original function
+ signature.
+ It would be better to use @functools.wrap decorator, but it would
+ not update function signature to match wrapped function in Python 2.
+
+ Args:
+ wrapped_function: The function that decorated function wraps.
+
+ Returns:
+ Function that accepts wrapper function as an argument and returns
+ `TFDecorator` instance.
+ """
+ def wrapper(wrapper_func):
+ return tf_decorator.make_decorator(wrapped_function, wrapper_func)
+ return wrapper
+
+
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
"""Deprecate a symbol in favor of a new name with identical semantics.
@@ -144,7 +165,7 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
if tf_inspect.isclass(func_or_class):
# Make a new class with __init__ wrapped in a warning.
- class NewClass(func_or_class): # pylint: disable=missing-docstring
+ class _NewClass(func_or_class): # pylint: disable=missing-docstring
__doc__ = decorator_utils.add_notice_to_docstring(
func_or_class.__doc__, 'Please use %s instead.' % name,
'DEPRECATED CLASS',
@@ -153,27 +174,28 @@ def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
__name__ = func_or_class.__name__
__module__ = _call_location(outer=True)
+ @_wrap_decorator(func_or_class.__init__)
def __init__(self, *args, **kwargs):
- if hasattr(NewClass.__init__, '__func__'):
+ if hasattr(_NewClass.__init__, '__func__'):
# Python 2
- NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
else:
# Python 3
- NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
+ _NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
if _PRINT_DEPRECATION_WARNINGS:
# We're making the alias as we speak. The original may have other
# aliases, so we cannot use it to check for whether it's already been
# warned about.
- if NewClass.__init__ not in _PRINTED_WARNING:
+ if _NewClass.__init__ not in _PRINTED_WARNING:
if warn_once:
- _PRINTED_WARNING[NewClass.__init__] = True
+ _PRINTED_WARNING[_NewClass.__init__] = True
logging.warning(
'From %s: The name %s is deprecated. Please use %s instead.\n',
_call_location(), deprecated_name, name)
- super(NewClass, self).__init__(*args, **kwargs)
+ super(_NewClass, self).__init__(*args, **kwargs)
- return NewClass
+ return _NewClass
else:
decorator_utils.validate_callable(func_or_class, 'deprecated')
diff --git a/tensorflow/python/util/deprecation_test.py b/tensorflow/python/util/deprecation_test.py
index bdd0bc48d2..1ea695e4d6 100644
--- a/tensorflow/python/util/deprecation_test.py
+++ b/tensorflow/python/util/deprecation_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_inspect
class DeprecatedAliasTest(test.TestCase):
@@ -73,6 +74,11 @@ class DeprecatedAliasTest(test.TestCase):
self.assertEqual(["test", "deprecated", "deprecated again"],
MyClass.init_args)
+ # Check __init__ signature matches for doc generation.
+ self.assertEqual(
+ tf_inspect.getfullargspec(MyClass.__init__),
+ tf_inspect.getfullargspec(deprecated_cls.__init__))
+
class DeprecationTest(test.TestCase):
diff --git a/tensorflow/python/util/lock_util.py b/tensorflow/python/util/lock_util.py
new file mode 100644
index 0000000000..0424960666
--- /dev/null
+++ b/tensorflow/python/util/lock_util.py
@@ -0,0 +1,128 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Locking related utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+
+class GroupLock(object):
+ """A lock to allow many members of a group to access a resource exclusively.
+
+ This lock provides a way to allow access to a resource by multiple threads
+ belonging to a logical group at the same time, while restricting access to
+ threads from all other groups. You can think of this as an extension of a
+ reader-writer lock, where you allow multiple writers at the same time. We
+ made it generic to support multiple groups instead of just two - readers and
+ writers.
+
+ Simple usage example with two groups accessing the same resource:
+
+ ```python
+ lock = GroupLock(num_groups=2)
+
+ # In a member of group 0:
+ with lock.group(0):
+ # do stuff, access the resource
+ # ...
+
+ # In a member of group 1:
+ with lock.group(1):
+ # do stuff, access the resource
+ # ...
+ ```
+
+ Using as a context manager with `.group(group_id)` is the easiest way. You
+ can also use the `acquire` and `release` method directly.
+ """
+
+ def __init__(self, num_groups=2):
+ """Initialize a group lock.
+
+ Args:
+ num_groups: The number of groups that will be accessing the resource under
+ consideration. Should be a positive number.
+
+ Returns:
+ A group lock that can then be used to synchronize code.
+
+ Raises:
+ ValueError: If num_groups is less than 1.
+ """
+ if num_groups < 1:
+ raise ValueError("num_groups must be a positive integer, got {}".format(
+ num_groups))
+ self._ready = threading.Condition(threading.Lock())
+ self._num_groups = num_groups
+ self._group_member_counts = [0] * self._num_groups
+
+ def group(self, group_id):
+ """Enter a context where the lock is with group `group_id`.
+
+ Args:
+ group_id: The group for which to acquire and release the lock.
+
+ Returns:
+ A context manager which will acquire the lock for `group_id`.
+ """
+ self._validate_group_id(group_id)
+ return self._Context(self, group_id)
+
+ def acquire(self, group_id):
+ """Acquire the group lock for a specific group `group_id`."""
+ self._validate_group_id(group_id)
+
+ self._ready.acquire()
+ while self._another_group_active(group_id):
+ self._ready.wait()
+ self._group_member_counts[group_id] += 1
+ self._ready.release()
+
+ def release(self, group_id):
+ """Release the group lock for a specific group `group_id`."""
+ self._validate_group_id(group_id)
+
+ self._ready.acquire()
+ self._group_member_counts[group_id] -= 1
+ if self._group_member_counts[group_id] == 0:
+ self._ready.notifyAll()
+ self._ready.release()
+
+ def _another_group_active(self, group_id):
+ return any(
+ c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id)
+
+ def _validate_group_id(self, group_id):
+ if group_id < 0 or group_id >= self._num_groups:
+ raise ValueError(
+ "group_id={} should be between 0 and num_groups={}".format(
+ group_id, self._num_groups))
+
+ class _Context(object):
+ """Context manager helper for `GroupLock`."""
+
+ def __init__(self, lock, group_id):
+ self._lock = lock
+ self._group_id = group_id
+
+ def __enter__(self):
+ self._lock.acquire(self._group_id)
+
+ def __exit__(self, type_arg, value_arg, traceback_arg):
+ del type_arg, value_arg, traceback_arg
+ self._lock.release(self._group_id)
diff --git a/tensorflow/python/util/lock_util_test.py b/tensorflow/python/util/lock_util_test.py
new file mode 100644
index 0000000000..cda8f95225
--- /dev/null
+++ b/tensorflow/python/util/lock_util_test.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lock_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import time
+
+from absl.testing import parameterized
+
+from tensorflow.python.platform import test
+from tensorflow.python.util import lock_util
+
+
+class GroupLockTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(1, 2, 3, 5, 10)
+ def testGroups(self, num_groups):
+ lock = lock_util.GroupLock(num_groups)
+ num_threads = 10
+ finished = set()
+
+ def thread_fn(thread_id):
+ time.sleep(random.random() * 0.1)
+ group_id = thread_id % num_groups
+ with lock.group(group_id):
+ time.sleep(random.random() * 0.1)
+ self.assertGreater(lock._group_member_counts[group_id], 0)
+ for g, c in enumerate(lock._group_member_counts):
+ if g != group_id:
+ self.assertEqual(0, c)
+ finished.add(thread_id)
+
+ threads = [
+ self.checkedThread(target=thread_fn, args=(i,))
+ for i in range(num_threads)
+ ]
+
+ for i in range(num_threads):
+ threads[i].start()
+ for i in range(num_threads):
+ threads[i].join()
+
+ self.assertEqual(set(range(num_threads)), finished)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 1104768ae8..d63f59a8c8 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True):
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
+ check_types: if `True` (default) types of sequences are checked as well,
+ including the keys of dictionaries. If set to `False`, for example a
+ list and a tuple of objects will look the same if they have the same
size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
+ considered to have the same shallow structure. Two types will also be
+ considered the same if they are both list subtypes (which allows "list"
+ and "_ListWrapper" from checkpointable dependency tracking to compare
+ equal).
Raises:
ValueError: If the two structures do not have the same number of elements or
diff --git a/tensorflow/python/util/serialization_test.py b/tensorflow/python/util/serialization_test.py
index 5000bcfad0..9d9cac2725 100644
--- a/tensorflow/python/util/serialization_test.py
+++ b/tensorflow/python/util/serialization_test.py
@@ -47,7 +47,7 @@ class SerializationTests(test.TestCase):
self.assertIs(round_trip[0], None)
self.assertEqual(round_trip[1], 2)
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_serialize_sequential(self):
model = sequential.Sequential()
model.add(core.Dense(4))
@@ -61,7 +61,7 @@ class SerializationTests(test.TestCase):
self.assertAllEqual([1, 1],
input_round_trip[0]["config"]["batch_input_shape"])
- @test_util.run_in_graph_and_eager_modes()
+ @test_util.run_in_graph_and_eager_modes
def test_serialize_model(self):
x = input_layer.Input(shape=[3])
y = core.Dense(10)(x)
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index bf3961c692..e154ffb68a 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -41,17 +41,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import functools
import sys
from tensorflow.python.util import tf_decorator
+ESTIMATOR_API_NAME = 'estimator'
+TENSORFLOW_API_NAME = 'tensorflow'
+
+_Attributes = collections.namedtuple(
+ 'ExportedApiAttributes', ['names', 'constants'])
+
+# Attribute values must be unique to each API.
+API_ATTRS = {
+ TENSORFLOW_API_NAME: _Attributes(
+ '_tf_api_names',
+ '_tf_api_constants'),
+ ESTIMATOR_API_NAME: _Attributes(
+ '_estimator_api_names',
+ '_estimator_api_constants')
+}
+
class SymbolAlreadyExposedError(Exception):
"""Raised when adding API names to symbol that already has API names."""
pass
-class tf_export(object): # pylint: disable=invalid-name
+class api_export(object): # pylint: disable=invalid-name
"""Provides ways to export symbols to the TensorFlow API."""
def __init__(self, *args, **kwargs):
@@ -63,15 +81,12 @@ class tf_export(object): # pylint: disable=invalid-name
overrides: List of symbols that this is overriding
(those overrided api exports will be removed). Note: passing overrides
has no effect on exporting a constant.
- allow_multiple_exports: Allows exporting the same symbol multiple
- times with multiple `tf_export` usages. Prefer however, to list all
- of the exported names in a single `tf_export` usage when possible.
-
+ api_name: Name of the API you want to generate (e.g. `tensorflow` or
+ `estimator`). Default is `tensorflow`.
"""
self._names = args
+ self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
- self._allow_multiple_exports = kwargs.get(
- 'allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -86,25 +101,24 @@ class tf_export(object): # pylint: disable=invalid-name
SymbolAlreadyExposedError: Raised when a symbol already has API names
and kwarg `allow_multiple_exports` not set.
"""
+ api_names_attr = API_ATTRS[self._api_name].names
+
# Undecorate overridden names
for f in self._overrides:
_, undecorated_f = tf_decorator.unwrap(f)
- del undecorated_f._tf_api_names # pylint: disable=protected-access
+ delattr(undecorated_f, api_names_attr)
_, undecorated_func = tf_decorator.unwrap(func)
# Check for an existing api. We check if attribute name is in
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
- if '_tf_api_names' in undecorated_func.__dict__:
- if self._allow_multiple_exports:
- undecorated_func._tf_api_names += self._names # pylint: disable=protected-access
- else:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (undecorated_func.__name__, undecorated_func._tf_api_names)) # pylint: disable=protected-access
- else:
- undecorated_func._tf_api_names = self._names # pylint: disable=protected-access
+ if api_names_attr in undecorated_func.__dict__:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (undecorated_func.__name__, getattr(
+ undecorated_func, api_names_attr))) # pylint: disable=protected-access
+ setattr(undecorated_func, api_names_attr, self._names)
return func
def export_constant(self, module_name, name):
@@ -126,8 +140,12 @@ class tf_export(object): # pylint: disable=invalid-name
name: (string) Current constant name.
"""
module = sys.modules[module_name]
- if not hasattr(module, '_tf_api_constants'):
- module._tf_api_constants = [] # pylint: disable=protected-access
+ if not hasattr(module, API_ATTRS[self._api_name].constants):
+ setattr(module, API_ATTRS[self._api_name].constants, [])
# pylint: disable=protected-access
- module._tf_api_constants.append((self._names, name))
+ getattr(module, API_ATTRS[self._api_name].constants).append(
+ (self._names, name))
+
+tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
+estimator_export = functools.partial(tf_export, api_name=ESTIMATOR_API_NAME)
diff --git a/tensorflow/python/util/tf_export_test.py b/tensorflow/python/util/tf_export_test.py
index ace3f054ba..b9e26ecb33 100644
--- a/tensorflow/python/util/tf_export_test.py
+++ b/tensorflow/python/util/tf_export_test.py
@@ -128,13 +128,6 @@ class ValidateExportTest(test.TestCase):
with self.assertRaises(tf_export.SymbolAlreadyExposedError):
export_decorator(_test_function)
- def testEAllowMultipleExports(self):
- _test_function._tf_api_names = ['name1', 'name2']
- tf_export.tf_export('nameRed', 'nameBlue', allow_multiple_exports=True)(
- _test_function)
- self.assertEquals(['name1', 'name2', 'nameRed', 'nameBlue'],
- _test_function._tf_api_names)
-
def testOverridesFunction(self):
_test_function2._tf_api_names = ['abc']
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index fbd6561767..ec20998bdd 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -300,6 +300,16 @@ def getsource(object): # pylint: disable=redefined-builtin
return _inspect.getsource(tf_decorator.unwrap(object)[1])
+def getsourcefile(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.getsourcefile."""
+ return _inspect.getsourcefile(tf_decorator.unwrap(object)[1])
+
+
+def getsourcelines(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.getsourcelines."""
+ return _inspect.getsourcelines(tf_decorator.unwrap(object)[1])
+
+
def isbuiltin(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.isbuiltin."""
return _inspect.isbuiltin(tf_decorator.unwrap(object)[1])
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index beaf350de1..2f6021c7d8 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -326,6 +326,18 @@ def test_decorated_function_with_defaults(a, b=2, c='Hello'):
self.assertEqual(
expected, tf_inspect.getsource(test_decorated_function_with_defaults))
+ def testGetSourceFile(self):
+ self.assertEqual(
+ __file__,
+ tf_inspect.getsourcefile(test_decorated_function_with_defaults))
+
+ def testGetSourceLines(self):
+ expected = inspect.getsourcelines(
+ test_decorated_function_with_defaults.decorated_target)
+ self.assertEqual(
+ expected,
+ tf_inspect.getsourcelines(test_decorated_function_with_defaults))
+
def testIsBuiltin(self):
self.assertEqual(
tf_inspect.isbuiltin(TestDecoratedClass),
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 8e839b523e..366f8a0deb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -33,6 +33,8 @@ namespace {
PyObject* CollectionsSequenceType = nullptr;
PyTypeObject* SparseTensorValueType = nullptr;
+const int kMaxItemsInCache = 1024;
+
bool WarnedThatSetIsNotSequence = false;
bool IsString(PyObject* o) {
@@ -196,11 +198,14 @@ int IsSequenceHelper(PyObject* o) {
// NOTE: This is never decref'd, but we don't want the type to get deleted
// as long as it is in the map. This should not be too much of a
// leak, as there should only be a relatively small number of types in the
- // map, and an even smaller number that are eligible for decref.
- Py_INCREF(type);
+ // map, and an even smaller number that are eligible for decref. As a
+ // precaution, we limit the size of the map to 1024.
{
mutex_lock l(g_type_to_sequence_map);
- type_to_sequence_map->insert({type, is_sequence});
+ if (type_to_sequence_map->size() < kMaxItemsInCache) {
+ Py_INCREF(type);
+ type_to_sequence_map->insert({type, is_sequence});
+ }
}
return is_sequence;
@@ -243,6 +248,9 @@ bool GetNextValuesForIterable(PyObject* nested,
std::vector<Safe_PyObjectPtr>* next_values) {
PyObject* item;
PyObject* iterator = PyObject_GetIter(nested);
+ if (iterator == nullptr || PyErr_Occurred()) {
+ return false;
+ }
while ((item = PyIter_Next(iterator)) != nullptr) {
next_values->emplace_back(item);
}
@@ -386,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
type2->tp_name);
return true;
}
- } else if (type1 != type2) {
+ } else if (type1 != type2
+ /* If both sequences are list types, don't complain. This allows
+ one to be a list subclass (e.g. _ListWrapper used for automatic
+ dependency tracking.) */
+ && !(PyList_Check(o1) && PyList_Check(o2))) {
*is_type_error = true;
*error_msg = tensorflow::strings::StrCat(
"The two namedtuples don't have the same sequence type. "
diff --git a/tensorflow/security/index.md b/tensorflow/security/index.md
index 44f51ad07b..ea39e17ab2 100644
--- a/tensorflow/security/index.md
+++ b/tensorflow/security/index.md
@@ -4,7 +4,7 @@ We regularly publish security advisories about using TensorFlow.
*Note*: In conjunction with these security advisories, we strongly encourage
TensorFlow users to read and understand TensorFlow's security model as outlined
-in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.md).
+in (https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md)[SECURITY.md].
| Advisory Number | Type | Versions affected | Reported by | Additional Information |
|-----------------|--------------------|:-----------------:|-----------------------|-----------------------------|
@@ -14,5 +14,5 @@ in [https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md](SECURITY.m
| [TFSA-2018-003](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-003.md) | TensorFlow Lite TOCO FlatBuffer Parsing Vulnerability | <= 1.7 | Blade Team of Tencent | |
| [TFSA-2018-002](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-002.md) | GIF File Parsing Null Pointer Dereference Error | <= 1.5 | Blade Team of Tencent | |
| [TFSA-2018-001](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/advisory/tfsa-2018-001.md) | BMP File Parser Out-of-bounds Read | <= 1.6 | Blade Team of Tencent | |
-| - | Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
+| - | Out Of Bounds Read | <= 1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 21295abed1..e742f8e8d5 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -2,6 +2,7 @@ licenses(["restricted"])
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
STREAM_EXECUTOR_HEADERS = glob([
"*.h",
@@ -51,6 +52,14 @@ cc_library(
] + if_static([":stream_executor_impl"]),
)
+cc_header_only_library(
+ name = "stream_executor_headers_lib",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":stream_executor",
+ ],
+)
+
cc_library(
name = "cuda_platform",
srcs = if_cuda_is_configured(
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 08fe153b59..874bf0e8cb 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2155,10 +2155,7 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
-// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
-#if CUDA_VERSION < 8000
- return false;
-#else
+ // GPUs < sm_50 don't support cublasGemmEx.
int cc_major, cc_minor;
if (stream->parent()->GetDeviceDescription().cuda_compute_capability(
&cc_major, &cc_minor) &&
@@ -2184,6 +2181,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
}
}
+ // Return false if we might be hitting a cuBLAS bug that produces the wrong
+ // result. See nvbugs/2156201, b/79126339.
+#if CUDA_VERSION >= 9000 && CUDA_VERSION < 9020
+ if ((algorithm == CUBLAS_GEMM_DEFAULT || algorithm >= CUBLAS_GEMM_ALGO13) &&
+ std::max({m, n, k}) >= 2097153 && cc_major < 7) {
+ return false;
+ }
+#endif
+
cudaDataType_t cuda_in_type = CUDADataType<InT>::type;
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
// we do the following compile-time check on the default value:
@@ -2213,7 +2219,6 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
timer->GetElapsedMilliseconds());
}
return result;
-#endif
}
bool CUDABlas::GetBlasGemmAlgorithms(
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
index 46e5deed84..124d5905b9 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -124,15 +124,20 @@ void Diagnostician::LogDiagnosticInformation() {
#ifdef __APPLE__
CFStringRef kext_ids[1];
kext_ids[0] = kDriverKextIdentifier;
- CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
- CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
+ CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1,
+ &kCFTypeArrayCallBacks);
+ CFDictionaryRef kext_infos =
+ KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
CFRelease(kext_id_query);
CFDictionaryRef cuda_driver_info = nullptr;
- if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
- bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue(cuda_driver_info, CFSTR("OSBundleStarted")));
+ if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier,
+ (const void **)&cuda_driver_info)) {
+ bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue(
+ cuda_driver_info, CFSTR("OSBundleStarted")));
if (!started) {
- LOG(INFO) << "kernel driver is installed, but does not appear to be running on this host "
+ LOG(INFO) << "kernel driver is installed, but does not appear to be "
+ "running on this host "
<< "(" << port::Hostname() << ")";
}
} else {
@@ -210,27 +215,27 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
"was unable to find libcuda.so DSO loaded into this program"));
#if defined(__APPLE__)
- // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib
- const string prefix("libcuda_");
- const string suffix("_mercury.dylib");
- for (uint32_t image_index = 0; image_index < _dyld_image_count(); ++image_index) {
- const string path(_dyld_get_image_name(image_index));
- const size_t suffix_pos = path.rfind(suffix);
- const size_t prefix_pos = path.rfind(prefix, suffix_pos);
- if (prefix_pos == string::npos ||
- suffix_pos == string::npos) {
- // no match
- continue;
- }
- const size_t start = prefix_pos + prefix.size();
- if (start >= suffix_pos) {
- // version not included
- continue;
- }
- const size_t length = suffix_pos - start;
- const string version = path.substr(start, length);
- result = StringToDriverVersion(version);
+ // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib
+ const string prefix("libcuda_");
+ const string suffix("_mercury.dylib");
+ for (uint32_t image_index = 0; image_index < _dyld_image_count();
+ ++image_index) {
+ const string path(_dyld_get_image_name(image_index));
+ const size_t suffix_pos = path.rfind(suffix);
+ const size_t prefix_pos = path.rfind(prefix, suffix_pos);
+ if (prefix_pos == string::npos || suffix_pos == string::npos) {
+ // no match
+ continue;
+ }
+ const size_t start = prefix_pos + prefix.size();
+ if (start >= suffix_pos) {
+ // version not included
+ continue;
}
+ const size_t length = suffix_pos - start;
+ const string version = path.substr(start, length);
+ result = StringToDriverVersion(version);
+ }
#else
#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA)
// Callback used when iterating through DSOs. Looks for the driver-interfacing
@@ -313,12 +318,15 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
#if defined(__APPLE__)
CFStringRef kext_ids[1];
kext_ids[0] = kDriverKextIdentifier;
- CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
- CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
+ CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void **)kext_ids, 1,
+ &kCFTypeArrayCallBacks);
+ CFDictionaryRef kext_infos =
+ KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
CFRelease(kext_id_query);
CFDictionaryRef cuda_driver_info = nullptr;
- if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
+ if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier,
+ (const void **)&cuda_driver_info)) {
// NOTE: OSX CUDA driver does not currently store the same driver version
// in kCFBundleVersionKey as is returned by cuDriverGetVersion
CFRelease(kext_infos);
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index c2c0c283b3..84916385a8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <functional>
#include <memory>
+#include <utility>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
@@ -55,6 +56,33 @@ namespace {
static_assert(CUDNN_VERSION >= 6000, "cuDNN needs to be version 6.0 or higher");
+// Exits the program if 'expr' doesn't return CUDNN_STATUS_SUCCESS.
+#define CHECK_CUDNN_OK(expr) CHECK_EQ(expr, CUDNN_STATUS_SUCCESS)
+
+// If 'expr' doesn't return CUDNN_STATUS_SUCCESS, returns from the current
+// function with a non-successful port::Status.
+#define RETURN_IF_CUDNN_ERROR(expr) \
+ do { \
+ cudnnStatus_t _status = expr; \
+ if (!SE_PREDICT_TRUE(_status == CUDNN_STATUS_SUCCESS)) { \
+ std::ostringstream oss; \
+ oss << ToString(_status) << "\nin " << __FILE__ << "(" << __LINE__ \
+ << "): '" << #expr << "'"; \
+ return port::Status(port::error::UNKNOWN, oss.str().c_str()); \
+ } \
+ } while (false)
+
+// Returns whether status is 'ok', and potentially logs the error.
+bool IsStatusOk(const port::Status& status, bool report_error) {
+ if (status.ok()) {
+ return true;
+ }
+ if (report_error) {
+ LOG(ERROR) << status.error_message();
+ }
+ return false;
+}
+
// Converts (via narrowing) a type T value to a type U, and checks that the
// value has no value change due to the conversion.
template <typename WideT, typename NarrowT>
@@ -89,26 +117,20 @@ string ToString(cudnnStatus_t status) {
return "CUDNN_STATUS_NOT_SUPPORTED";
case CUDNN_STATUS_LICENSE_ERROR:
return "CUDNN_STATUS_LICENSE_ERROR";
+ case CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING:
+ return "CUDNN_STATUS_RUNTIME_PREREQUISITE_MISSING";
+#if CUDNN_VERSION >= 7000
+ case CUDNN_STATUS_RUNTIME_IN_PROGRESS:
+ return "CUDNN_STATUS_RUNTIME_IN_PROGRESS";
+ case CUDNN_STATUS_RUNTIME_FP_OVERFLOW:
+ return "CUDNN_STATUS_RUNTIME_FP_OVERFLOW";
+#endif
default:
return port::StrCat("<unknown cudnn status: ", static_cast<int>(status),
">");
}
}
-string ToString(libraryPropertyType type) {
- switch (type) {
- case MAJOR_VERSION:
- return "MAJOR_VERSION";
- case MINOR_VERSION:
- return "MINOR_VERSION";
- case PATCH_LEVEL:
- return "PATCH_LEVEL";
- default:
- return port::StrCat(
- "<unknown libraryPropertyType: ", static_cast<int>(type), ">");
- }
-}
-
template <typename T>
cudnnDataType_t GetCudnnDataType();
@@ -150,9 +172,9 @@ class CudnnHandle {
} // namespace
-// Wraps a cuDNN handle and provides access to it through CudnnHandle instances,
-// which also locks a mutex, acquires the CUDA context, and sets the stream
-// that cuDNN should use to enqueue any work.
+// Wraps a cuDNN handle and provides access to it through CudnnHandle
+// instances, which also locks a mutex, acquires the CUDA context, and sets
+// the stream that cuDNN should use to enqueue any work.
//
// Note: CudnnSupport::cudnn_ should be the only instantiation of this class.
class CudnnAccess {
@@ -167,13 +189,13 @@ class CudnnAccess {
// Creates a CudnnHandle instance for stream.
//
- // cuDNN API calls using the same handle instance need to be serialized across
- // threads. This is guaranteed by CudnnHandle instances locking the mutex
- // owned by this class.
+ // cuDNN API calls using the same handle instance need to be serialized
+ // across threads. This is guaranteed by CudnnHandle instances locking the
+ // mutex owned by this class.
//
// Most cuDNN APIs taking a handle perform work on a CUDA stream. The
- // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN to
- // use the provided stream.
+ // CudnnHandle instance acquires the executor's CUDA context and sets cuDNN
+ // to use the provided stream.
//
// The stream argument may be null, which translates to the legacy default
// stream. See
@@ -187,7 +209,6 @@ class CudnnAccess {
CUstream cu_stream = stream ? AsCUDAStreamValue(stream) : cudaStreamLegacy;
auto status = cudnnSetStream(handle_, cu_stream);
CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Failed to set cuDNN stream.";
- using my_mutex_lock = mutex_lock;
return CudnnHandle(std::move(context), std::move(lock), handle_);
}
@@ -201,6 +222,8 @@ class CudnnAccess {
namespace {
+// A helper function to return the internal compute type for
+// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type);
cudnnConvolutionFwdAlgo_t ToConvForwardAlgo(dnn::AlgorithmDesc algorithm) {
@@ -264,16 +287,10 @@ cudnnConvolutionBwdFilterAlgo_t ToConvBackwardFilterAlgo(
}
}
-port::Status GetCudnnProperty(libraryPropertyType type, int* value) {
- cudnnStatus_t status = cudnnGetProperty(type, value);
- if (status != CUDNN_STATUS_SUCCESS) {
- const string error =
- port::StrCat("cudnnGetProperty failed for type: ", ToString(type),
- " with status: ", ToString(status));
- LOG(ERROR) << error;
- return port::Status(port::error::INTERNAL, error);
- }
- return port::Status::OK();
+port::StatusOr<int> GetCudnnProperty(libraryPropertyType type) {
+ int value;
+ RETURN_IF_CUDNN_ERROR(cudnnGetProperty(type, &value));
+ return value;
}
cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
@@ -294,9 +311,9 @@ cudnnRNNAlgo_t ToCudnnRNNAlgo(const dnn::AlgorithmDesc& algorithm) {
}
port::Status GetLoadedCudnnVersion(CudnnVersion* version) {
- TF_RETURN_IF_ERROR(GetCudnnProperty(MAJOR_VERSION, &version->major_version));
- TF_RETURN_IF_ERROR(GetCudnnProperty(MINOR_VERSION, &version->minor_version));
- TF_RETURN_IF_ERROR(GetCudnnProperty(PATCH_LEVEL, &version->patch_level));
+ SE_ASSIGN_OR_RETURN(version->major_version, GetCudnnProperty(MAJOR_VERSION));
+ SE_ASSIGN_OR_RETURN(version->minor_version, GetCudnnProperty(MINOR_VERSION));
+ SE_ASSIGN_OR_RETURN(version->patch_level, GetCudnnProperty(PATCH_LEVEL));
return port::Status::OK();
}
@@ -319,9 +336,11 @@ port::Status CudnnSupport::Init() {
". CuDNN library major and minor version needs to match or have "
"higher minor version in case of CuDNN 7.0 or later version. If "
"using a binary install, upgrade your CuDNN library. If building "
- "from sources, make sure the library loaded at runtime is compatible "
+ "from sources, make sure the library loaded at runtime is "
+ "compatible "
"with the version specified during compile configuration.");
LOG(ERROR) << error;
+ cudnnDestroy(cudnn_handle);
return port::Status(port::error::INTERNAL, error);
}
@@ -329,23 +348,17 @@ port::Status CudnnSupport::Init() {
return port::Status::OK();
}
- LOG(ERROR) << "could not create cudnn handle: " << ToString(status);
+ CHECK_EQ(cudnn_handle, nullptr);
+ LOG(ERROR) << "Could not create cudnn handle: " << ToString(status);
if (status == CUDNN_STATUS_NOT_INITIALIZED) {
auto result = cuda::Diagnostician::FindKernelDriverVersion();
if (!result.ok()) {
- LOG(ERROR) << "error retrieving driver version: "
+ LOG(ERROR) << "Error retrieving driver version: "
<< DriverVersionStatusToString(result);
} else {
const auto& version = result.ValueOrDie();
- LOG(ERROR) << "possibly insufficient driver version: "
+ LOG(ERROR) << "Possibly insufficient driver version: "
<< DriverVersionToString(version);
- // OS X kernel driver does not report version accurately
-#if !defined(__APPLE__)
- if (std::get<0>(version) < 340) {
- LOG(ERROR)
- << "cudnn library is only supported on 340.XX+ driver versions";
- }
-#endif
}
}
@@ -364,18 +377,129 @@ CudnnSupport::GetVersion() {
namespace {
-// Turns a BatchDescriptor structure into a cudnn tensor handle within a scope.
-class ScopedTensorDescriptor {
- public:
- ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
- cudnnDataType_t elem_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn tensor descriptor: "
- << ToString(status);
- }
+// Deleter functors for cuDNN types that need to be deleted.
+struct TensorDescriptorDeleter {
+ void operator()(cudnnTensorDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyTensorDescriptor(descriptor));
+ }
+};
+struct FilterDescriptorDeleter {
+ void operator()(cudnnFilterDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyFilterDescriptor(descriptor));
+ }
+};
+struct ConvolutionDescriptorDeleter {
+ void operator()(cudnnConvolutionDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyConvolutionDescriptor(descriptor));
+ }
+};
+struct PoolingDescriptorDeleter {
+ void operator()(cudnnPoolingDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyPoolingDescriptor(descriptor));
+ }
+};
+struct LrnDescriptorDeleter {
+ void operator()(cudnnLRNDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyLRNDescriptor(descriptor));
+ }
+};
+
+struct ActivationDescriptorDeleter {
+ void operator()(cudnnActivationDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyActivationDescriptor(descriptor));
+ }
+};
+struct DropoutDescriptorDeleter {
+ void operator()(cudnnDropoutDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyDropoutDescriptor(descriptor));
+ }
+};
+struct RnnDescriptorDeleter {
+ void operator()(cudnnRNNDescriptor_t descriptor) const {
+ CHECK_CUDNN_OK(cudnnDestroyRNNDescriptor(descriptor));
+ }
+};
+struct PersistentRnnPlanDeleter {
+ void operator()(cudnnPersistentRNNPlan_t plan) const {
+ CHECK_CUDNN_OK(cudnnDestroyPersistentRNNPlan(plan));
+ }
+};
+// RAII wrappers for cuDNN types.
+using TensorDescriptor =
+ std::unique_ptr<cudnnTensorStruct, TensorDescriptorDeleter>;
+using FilterDescriptor =
+ std::unique_ptr<cudnnFilterStruct, FilterDescriptorDeleter>;
+using ConvolutionDescriptor =
+ std::unique_ptr<cudnnConvolutionStruct, ConvolutionDescriptorDeleter>;
+using PoolingDescriptor =
+ std::unique_ptr<cudnnPoolingStruct, PoolingDescriptorDeleter>;
+using LrnDescriptor = std::unique_ptr<cudnnLRNStruct, LrnDescriptorDeleter>;
+using ActivationDescriptor =
+ std::unique_ptr<cudnnActivationStruct, ActivationDescriptorDeleter>;
+using DropoutDescriptor =
+ std::unique_ptr<cudnnDropoutStruct, DropoutDescriptorDeleter>;
+using RnnDescriptor = std::unique_ptr<cudnnRNNStruct, RnnDescriptorDeleter>;
+using PersistentRnnPlan =
+ std::unique_ptr<cudnnPersistentRNNPlan, PersistentRnnPlanDeleter>;
+
+// Factory methods for cuDNN types.
+TensorDescriptor CreateTensorDescriptor() {
+ cudnnTensorDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateTensorDescriptor(&result));
+ return TensorDescriptor(result);
+}
+FilterDescriptor CreateFilterDescriptor() {
+ cudnnFilterDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateFilterDescriptor(&result));
+ return FilterDescriptor(result);
+}
+ConvolutionDescriptor CreateConvolutionDescriptor() {
+ cudnnConvolutionDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateConvolutionDescriptor(&result));
+ return ConvolutionDescriptor(result);
+}
+PoolingDescriptor CreatePoolingDescriptor() {
+ cudnnPoolingDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreatePoolingDescriptor(&result));
+ return PoolingDescriptor(result);
+}
+LrnDescriptor CreateLrnDescriptor() {
+ cudnnLRNDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateLRNDescriptor(&result));
+ return LrnDescriptor(result);
+}
+ActivationDescriptor CreateActivationDescriptor() {
+ cudnnActivationDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateActivationDescriptor(&result));
+ return ActivationDescriptor(result);
+}
+DropoutDescriptor CreateDropoutDescriptor() {
+ cudnnDropoutDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateDropoutDescriptor(&result));
+ return DropoutDescriptor(result);
+}
+RnnDescriptor CreateRnnDescriptor() {
+ cudnnRNNDescriptor_t result;
+ CHECK_CUDNN_OK(cudnnCreateRNNDescriptor(&result));
+ return RnnDescriptor(result);
+}
+PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc,
+ int batch_size,
+ cudnnDataType_t data_type) {
+ cudnnPersistentRNNPlan_t result;
+ CHECK_CUDNN_OK(
+ cudnnCreatePersistentRNNPlan(rnn_desc, batch_size, data_type, &result));
+ return PersistentRnnPlan(result);
+}
+
+// Turns a BatchDescriptor structure into a cudnn tensor handle within a
+// scope.
+class CudnnTensorDescriptor {
+ public:
+ CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
+ cudnnDataType_t elem_type)
+ : handle_(CreateTensorDescriptor()) {
switch (batch_descriptor.layout()) {
case dnn::DataLayout::kBatchYXDepth:
case dnn::DataLayout::kBatchDepthYX: {
@@ -393,25 +517,16 @@ class ScopedTensorDescriptor {
&CheckedNarrowing<int64, int>);
std::transform(dims64.cbegin(), dims64.cend(), dims.begin(),
&CheckedNarrowing<int64, int>);
- status = cudnnSetTensorNdDescriptor(handle_, elem_type, nd, dims.data(),
- strides.data());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not convert BatchDescriptor "
- << batch_descriptor.ToString()
- << " to cudnn tensor descriptor: " << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(handle_.get(), elem_type, nd,
+ dims.data(), strides.data()))
+ << "batch_descriptor: " << batch_descriptor.ToString();
} break;
case dnn::DataLayout::kBatchDepthYX4: {
- status = cudnnSetTensor4dDescriptor(
- handle_, CUDNN_TENSOR_NCHW_VECT_C, elem_type,
+ CHECK_CUDNN_OK(cudnnSetTensor4dDescriptor(
+ handle_.get(), CUDNN_TENSOR_NCHW_VECT_C, elem_type,
batch_descriptor.count(), batch_descriptor.feature_map_count(),
- batch_descriptor.height(), batch_descriptor.width());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not convert BatchDescriptor "
- << batch_descriptor.ToString()
- << " to cudnn tensor descriptor: " << ToString(status);
- }
+ batch_descriptor.height(), batch_descriptor.width()))
+ << "batch_descriptor: " << batch_descriptor.ToString();
} break;
default:
LOG(FATAL) << "Unsupported tensor format "
@@ -420,37 +535,24 @@ class ScopedTensorDescriptor {
}
}
- ~ScopedTensorDescriptor() {
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn tensor descriptor: "
- << ToString(status);
- }
- }
-
- cudnnTensorDescriptor_t handle() const { return handle_; }
+ cudnnTensorDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnTensorDescriptor_t handle_; // Owned.
+ TensorDescriptor handle_;
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
};
-// Turns a FilterDescriptor structure into a cudnn filter handle within a scope.
-class ScopedFilterDescriptor {
+// Turns a FilterDescriptor structure into a cudnn filter handle within a
+// scope.
+class CudnnFilterDescriptor {
public:
- ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
- cudnnDataType_t elem_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateFilterDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn filter descriptor: "
- << ToString(status);
- }
-
+ CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
+ cudnnDataType_t elem_type)
+ : handle_(CreateFilterDescriptor()) {
// TODO(b/23032134): Even if the filter layout is not supported,
- // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because it
- // does not take layout as an input. Maybe force cuDNN by giving wrong
+ // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
+ // it does not take layout as an input. Maybe force cuDNN by giving wrong
// inputs intentionally?
cudnnTensorFormat_t format;
switch (filter_descriptor.layout()) {
@@ -475,32 +577,20 @@ class ScopedFilterDescriptor {
const auto& spatial_dims = filter_descriptor.input_filter_dims();
std::copy(spatial_dims.begin(), spatial_dims.end(), dims.begin() + 2);
- status = cudnnSetFilterNdDescriptor(handle_, elem_type, format, dims.size(),
- dims.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn filter descriptor: "
- << ToString(status);
- }
- }
-
- ~ScopedFilterDescriptor() {
- cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn filter descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetFilterNdDescriptor(handle_.get(), elem_type, format,
+ dims.size(), dims.data()));
}
- cudnnFilterDescriptor_t handle() const { return handle_; }
+ cudnnFilterDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnFilterDescriptor_t handle_; // Owned.
+ FilterDescriptor handle_; // Owned.
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
};
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
-static bool TensorOpMathEnabled() {
+bool TensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
@@ -513,7 +603,7 @@ static bool TensorOpMathEnabled() {
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
// for RNNs.
-static bool RnnTensorOpMathEnabled() {
+bool RnnTensorOpMathEnabled() {
static bool is_enabled = [] {
bool is_disabled = false;
TF_CHECK_OK(
@@ -524,15 +614,16 @@ static bool RnnTensorOpMathEnabled() {
return is_enabled;
}
-// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT
-// in batchnorm. This mode can be faster in some tasks because an optimized path
-// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute
-// capability 6.0 or higher. The reason we set it to false by default is that
-// this mode may use scaled atomic integer reduction that may cause a numerical
-// overflow for certain input data range.
+// A helper function to decide whether to use
+// CUDNN_BATCHNORM_SPATIAL_PERSISTENT in batchnorm. This mode can be faster in
+// some tasks because an optimized path may be selected for CUDNN_DATA_FLOAT
+// and CUDNN_DATA_HALF data types, compute capability 6.0 or higher. The
+// reason we set it to false by default is that this mode may use scaled
+// atomic integer reduction that may cause a numerical overflow for certain
+// input data range.
// TODO(yangzihao): Use autotune to choose between this mode and
// CUDNN_BATCHNORM_SPATIAL mode.
-static bool BatchnormSpatialPersistentEnabled() {
+bool BatchnormSpatialPersistentEnabled() {
static bool is_enabled = [] {
bool is_enabled = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
@@ -545,24 +636,18 @@ static bool BatchnormSpatialPersistentEnabled() {
// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
// within a scope.
-class ScopedConvolutionDescriptor {
+class CudnnConvolutionDescriptor {
public:
- ScopedConvolutionDescriptor(
+ CudnnConvolutionDescriptor(
const dnn::ConvolutionDescriptor& convolution_descriptor,
cudnnDataType_t data_type)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateConvolutionDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn convolution descriptor: "
- << ToString(status);
- }
+ : handle_(CreateConvolutionDescriptor()) {
const auto& strides64 = convolution_descriptor.strides();
const auto& padding64 = convolution_descriptor.padding();
const auto& dilations64 = convolution_descriptor.dilations();
- if (convolution_descriptor.pad_alignment() ==
- dnn::PadAlignment::kTensorFlowPadding) {
- LOG(ERROR) << "TensorFlow padding alignment is not supported.";
- }
+ CHECK_NE(convolution_descriptor.pad_alignment(),
+ dnn::PadAlignment::kTensorFlowPadding)
+ << "TensorFlow padding alignment is not supported.";
// cuDNN requires arrays of ints.
std::vector<int> strides(convolution_descriptor.ndims());
@@ -577,18 +662,14 @@ class ScopedConvolutionDescriptor {
std::transform(dilations64.cbegin(), dilations64.cend(), dilations.begin(),
&CheckedNarrowing<int64, int>);
- status = cudnnSetConvolutionNdDescriptor(
- handle_, convolution_descriptor.ndims(), padding.data(), strides.data(),
- dilations.data(),
+ CHECK_CUDNN_OK(cudnnSetConvolutionNdDescriptor(
+ handle_.get(), convolution_descriptor.ndims(), padding.data(),
+ strides.data(), dilations.data(),
// NOTE(keveman): cuDNN supports convolution and cross correlation.
// However, almost all the use cases do cross correlation, so just
// hard coding it here.
- CUDNN_CROSS_CORRELATION, data_type);
+ CUDNN_CROSS_CORRELATION, data_type));
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution descriptor: "
- << ToString(status);
- }
// NOTE(benbarsdell): This only applies if tensor op math is enabled
// and algo selection is set to Default.
this->set_use_tensor_op_math(true);
@@ -596,60 +677,39 @@ class ScopedConvolutionDescriptor {
#if CUDNN_MAJOR >= 7
VLOG(2) << "Requesting grouped convolution: "
<< convolution_descriptor.group_count();
- status = cudnnSetConvolutionGroupCount(
- handle_, convolution_descriptor.group_count());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution group count: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetConvolutionGroupCount(
+ handle_.get(), convolution_descriptor.group_count()));
#else
CHECK_EQ(convolution_descriptor.group_count(), 1)
<< "Requested grouped convolution for cuDNN version < 7";
#endif
}
- void set_use_tensor_op_math(bool use_tensor_op_math) {
+ void set_use_tensor_op_math(bool use_tensor_op_math) const {
#if CUDNN_VERSION >= 7000
cudnnMathType_t math_type =
(use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
if (TensorOpMathEnabled()) {
- cudnnStatus_t status = cudnnSetConvolutionMathType(handle_, math_type);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn convolution math type: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetConvolutionMathType(handle_.get(), math_type));
}
#endif
}
- ~ScopedConvolutionDescriptor() {
- cudnnStatus_t status = cudnnDestroyConvolutionDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn convolution descriptor: "
- << ToString(status);
- }
- }
-
- cudnnConvolutionDescriptor_t handle() const { return handle_; }
+ cudnnConvolutionDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnConvolutionDescriptor_t handle_; // Owned.
+ ConvolutionDescriptor handle_; // Owned.
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
};
// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
// within a scope.
-class ScopedPoolingDescriptor {
+class CudnnPoolingDescriptor {
public:
- explicit ScopedPoolingDescriptor(
+ explicit CudnnPoolingDescriptor(
const dnn::PoolingDescriptor& pooling_descriptor)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreatePoolingDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn pooling descriptor: "
- << ToString(status);
- }
+ : handle_(CreatePoolingDescriptor()) {
const std::vector<int64> strides64 = pooling_descriptor.strides();
const std::vector<int64> padding64 = pooling_descriptor.padding();
const std::vector<int64> shape64 = pooling_descriptor.window();
@@ -665,46 +725,29 @@ class ScopedPoolingDescriptor {
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
bool propagate_nans = pooling_descriptor.propagate_nans();
- status = cudnnSetPoolingNdDescriptor(
- handle_,
+ CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
+ handle_.get(),
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
? CUDNN_POOLING_MAX
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
- shape.data(), padding.data(), strides.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn pooling descriptor: "
- << ToString(status);
- }
- }
- ~ScopedPoolingDescriptor() {
- cudnnStatus_t status = cudnnDestroyPoolingDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn pooling descriptor: "
- << ToString(status);
- }
+ shape.data(), padding.data(), strides.data()));
}
- cudnnPoolingDescriptor_t handle() const { return handle_; }
+ cudnnPoolingDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnPoolingDescriptor_t handle_; // Owned.
+ PoolingDescriptor handle_; // Owned.
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
};
// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
-class ScopedNormalizeDescriptor {
+class CudnnNormalizeDescriptor {
public:
- explicit ScopedNormalizeDescriptor(
+ explicit CudnnNormalizeDescriptor(
const dnn::NormalizeDescriptor& normalize_descriptor)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateLRNDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn LRN descriptor: "
- << ToString(status);
- }
-
+ : handle_(CreateLrnDescriptor()) {
// The range specifies that the indices in the closed range
// [i - range, i + range] should be included in the normalization for index
// i. The lrnN value is the total number of elements in the range, so
@@ -725,42 +768,26 @@ class ScopedNormalizeDescriptor {
double lrnBeta = normalize_descriptor.beta();
double lrnK = normalize_descriptor.bias();
- status = cudnnSetLRNDescriptor(handle_, lrnN, lrnAlpha, lrnBeta, lrnK);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn LRN descriptor: " << ToString(status);
- }
- }
-
- ~ScopedNormalizeDescriptor() {
- cudnnStatus_t status = cudnnDestroyLRNDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn LRN descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(
+ cudnnSetLRNDescriptor(handle_.get(), lrnN, lrnAlpha, lrnBeta, lrnK));
}
- cudnnLRNDescriptor_t handle() const { return handle_; }
+ cudnnLRNDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnLRNDescriptor_t handle_; // Owned.
+ LrnDescriptor handle_; // Owned.
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
};
// Turns a ActivationDescriptor structure into a cudnn activation
// descriptor handle within a scope.
-class ScopedActivationDescriptor {
+class CudnnActivationDescriptor {
public:
- ScopedActivationDescriptor(dnn::ActivationMode activation_mode,
- cudnnNanPropagation_t nan_propagation,
- double value_max)
- : handle_(nullptr) {
- cudnnStatus_t status = cudnnCreateActivationDescriptor(&handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not create cudnn activation descriptor: "
- << ToString(status);
- }
-
+ CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
+ cudnnNanPropagation_t nan_propagation,
+ double value_max)
+ : handle_(CreateActivationDescriptor()) {
double relu_ceiling = 0.0;
cudnnActivationMode_t mode;
switch (activation_mode) {
@@ -786,28 +813,16 @@ class ScopedActivationDescriptor {
<< static_cast<int>(activation_mode);
}
- status = cudnnSetActivationDescriptor(handle_, mode, nan_propagation,
- relu_ceiling);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn activation descriptor: "
- << ToString(status);
- }
+ CHECK_CUDNN_OK(cudnnSetActivationDescriptor(handle_.get(), mode,
+ nan_propagation, relu_ceiling));
}
- ~ScopedActivationDescriptor() {
- cudnnStatus_t status = cudnnDestroyActivationDescriptor(handle_);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not destroy cudnn activation descriptor: "
- << ToString(status);
- }
- }
-
- cudnnActivationDescriptor_t handle() const { return handle_; }
+ cudnnActivationDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnActivationDescriptor_t handle_; // Owned.
+ ActivationDescriptor handle_; // Owned.
- SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
+ SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
};
cudnnDataType_t ToCudnnDataType(
@@ -873,117 +888,74 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
}
}
-template <typename Base>
-class MixinBase : public Base {};
-template <>
-class MixinBase<void> {};
-
-#define CUDNN_RETURN_IF_FAIL(STATUS, ...) \
- if (!SE_PREDICT_TRUE((STATUS) == CUDNN_STATUS_SUCCESS)) { \
- string error_msg = port::StrCat(ToString(STATUS), " ", __VA_ARGS__); \
- SetFailure(port::Status(port::error::UNKNOWN, error_msg)); \
- LOG(ERROR) << error_msg; \
- return; \
- }
+class CudnnDropoutDescriptor {
+ explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
+ : handle_(std::move(handle)) {}
-// TODO(csigg): Remove inheritance for code reuse.
-template <typename Base>
-class CudnnDescriptorCommon : public MixinBase<Base> {
public:
- bool ok() const { return status_.ok(); }
- port::Status Status() const { return status_; }
+ CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
- protected:
- void SetFailure(const port::Status& status) { status_.Update(status); }
- port::Status status_;
-};
+ static port::StatusOr<CudnnDropoutDescriptor> Create(
+ const CudnnHandle& cudnn, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ DropoutDescriptor handle = CreateDropoutDescriptor();
-class CudnnDropoutDescriptor : public CudnnDescriptorCommon<void> {
- public:
- CudnnDropoutDescriptor(const CudnnHandle& cudnn, float dropout, uint64 seed,
- ScratchAllocator* state_allocator)
- : handle_(nullptr) {
- cudnnStatus_t status;
- status = cudnnCreateDropoutDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create dropout descriptor");
-
- if (dropout == 0.f) {
- return;
+ if (dropout == 0.0f) {
+ // Return 'empty' dropout descriptor.
+ return CudnnDropoutDescriptor(std::move(handle));
}
DeviceMemory<uint8> state_memory;
if (state_allocator) {
size_t state_sizes_in_bytes = 0;
- status = cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes);
- CUDNN_RETURN_IF_FAIL(status, "Failed to query dropout state sizes");
-
- auto allocated =
- state_allocator->AllocateBytes(nullptr, state_sizes_in_bytes);
- if (!allocated.ok() ||
- (state_memory = allocated.ValueOrDie()) == nullptr) {
- string error_msg =
- port::StrCat("Failed to allocate Cudnn dropout state memory of ",
- state_sizes_in_bytes, " bytes.");
- status_ = port::Status(port::error::UNKNOWN, error_msg);
- LOG(ERROR) << error_msg;
- return;
- }
+ RETURN_IF_CUDNN_ERROR(
+ cudnnDropoutGetStatesSize(cudnn.handle(), &state_sizes_in_bytes));
+ SE_ASSIGN_OR_RETURN(state_memory, state_allocator->AllocateBytes(
+ nullptr, state_sizes_in_bytes));
}
- status = cudnnSetDropoutDescriptor(handle_, cudnn.handle(), dropout,
- state_memory.opaque(),
- state_memory.size(), seed);
- CUDNN_RETURN_IF_FAIL(
- status, port::StrCat(
- "Failed to set dropout descriptor with state memory size: ",
- state_memory.size(), " bytes."));
- }
+ RETURN_IF_CUDNN_ERROR(cudnnSetDropoutDescriptor(
+ handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
+ state_memory.size(), seed));
- ~CudnnDropoutDescriptor() {
- cudnnStatus_t status = cudnnDestroyDropoutDescriptor(handle_);
- // TODO(csigg): This is a no-op (error is not reported). Same below.
- CUDNN_RETURN_IF_FAIL(status, "Failed to destroy Cudnn dropout handle: ");
+ return CudnnDropoutDescriptor(std::move(handle));
}
- cudnnDropoutDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
+ cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
private:
- cudnnDropoutDescriptor_t handle_; // Owned.
- float dropout_;
- uint64 seed_;
+ DropoutDescriptor handle_; // Owned.
SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
};
-class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
- public:
- typedef dnn::RnnDescriptor::ParamsRegion ParamsRegion;
+class CudnnRnnParamsDescriptor {
typedef dnn::RnnDescriptor::ParamsRegions ParamsRegions;
- CudnnRnnParamsDescriptor(const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc);
- ~CudnnRnnParamsDescriptor() {
- cudnnStatus_t status = cudnnDestroyFilterDescriptor(handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to destroy RNN filter descriptor");
- }
- cudnnFilterDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
+
+ CudnnRnnParamsDescriptor(FilterDescriptor handle, int64 params_size_in_bytes,
+ ParamsRegions weights, ParamsRegions biases)
+ : handle_(std::move(handle)),
+ params_size_in_bytes_(params_size_in_bytes),
+ weights_(std::move(weights)),
+ biases_(std::move(biases)) {}
+
+ public:
+ CudnnRnnParamsDescriptor(CudnnRnnParamsDescriptor&&) = default;
+
+ static port::StatusOr<CudnnRnnParamsDescriptor> Create(
+ const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
+ cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
+ cudnnDirectionMode_t direction_mode, int num_layers);
+
+ cudnnFilterDescriptor_t handle() const { return handle_.get(); }
int64 params_size_in_bytes() const { return params_size_in_bytes_; }
ParamsRegions params_weights() const {
- if (!ok()) return ParamsRegions();
return weights_;
}
ParamsRegions params_biases() const {
- if (!ok()) return ParamsRegions();
return biases_;
}
private:
- int GetRegionCountPerLayer() const;
- cudnnFilterDescriptor_t handle_;
- const CudnnRnnDescriptor* rnn_desc_;
+ FilterDescriptor handle_;
int64 params_size_in_bytes_;
ParamsRegions weights_;
ParamsRegions biases_;
@@ -992,97 +964,98 @@ class CudnnRnnParamsDescriptor : public CudnnDescriptorCommon<void> {
} // namespace
-class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
- public:
- CudnnRnnDescriptor(const CudnnHandle& cudnn, int num_layers, int hidden_size,
- int input_size, int batch_size,
+class CudnnRnnDescriptor : public dnn::RnnDescriptor {
+ CudnnRnnDescriptor(const CudnnHandle& cudnn, cuda::RnnDescriptor rnn_desc,
+ PersistentRnnPlan rnn_plan, int num_layers,
+ int hidden_size, int input_size, int batch_size,
cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
cudnnDataType_t compute_type,
const dnn::AlgorithmConfig& algorithm_config,
- float dropout, uint64 seed,
- ScratchAllocator* state_allocator)
- : rnn_desc_(nullptr),
+ CudnnDropoutDescriptor dropout_desc,
+ CudnnRnnParamsDescriptor params_desc)
+ : rnn_desc_(std::move(rnn_desc)),
+ rnn_plan_(std::move(rnn_plan)),
num_layers_(num_layers),
hidden_size_(hidden_size),
input_size_(input_size),
batch_size_(batch_size),
- rnn_plan_(nullptr),
+ rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
data_type_(data_type),
compute_type_(compute_type),
- algorithm_config_(algorithm_config) {
- // Create the dropout handle.
- cudnn_dropout_desc_.reset(
- new CudnnDropoutDescriptor(cudnn, dropout, seed, state_allocator));
- if (!cudnn_dropout_desc_->ok()) {
- SetFailure(cudnn_dropout_desc_->Status());
- return;
- }
+ algorithm_config_(algorithm_config),
+ dropout_desc_(std::move(dropout_desc)),
+ params_desc_(std::move(params_desc)) {}
+
+ public:
+ CudnnRnnDescriptor(CudnnRnnDescriptor&& other) = default;
+
+ static port::StatusOr<CudnnRnnDescriptor> Create(
+ const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
+ int batch_size, cudnnRNNInputMode_t input_mode,
+ cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
+ cudnnDataType_t data_type, cudnnDataType_t compute_type,
+ const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
+ SE_ASSIGN_OR_RETURN(
+ CudnnDropoutDescriptor dropout_desc,
+ CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
+
+ cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor();
+ cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
- // Create the RNN handle
- cudnnStatus_t status = cudnnCreateRNNDescriptor(&rnn_desc_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
// TODO: allow the user to choose an algorithm.
- rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
- status = cudnnSetRNNDescriptor_v6(
- cudnn.handle(), /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
- /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
+ cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
/*inputMode=*/input_mode, /*direction=*/direction_mode,
- /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type);
- CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf(
- "Unable to update RNN descriptor with "
- "algo_id: %d and compute_type: %d",
- static_cast<int>(rnn_algo_),
- static_cast<int>(compute_type)));
-
- if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
- CHECK_GE(batch_size_, 0);
- status = cudnnCreatePersistentRNNPlan(rnn_desc_, batch_size_, data_type_,
- &rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan.");
- status = cudnnSetPersistentRNNPlan(rnn_desc_, rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
+ /*mode=*/rnn_mode, /*algo=*/rnn_algo,
+ /*dataType=*/compute_type));
+
+ PersistentRnnPlan rnn_plan;
+ if (rnn_algo == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
+ CHECK_GE(batch_size, 0);
+ rnn_plan = CreatePersistentRnnPlan(rnn_desc.get(), batch_size, data_type);
+ RETURN_IF_CUDNN_ERROR(
+ cudnnSetPersistentRNNPlan(rnn_desc.get(), rnn_plan.get()));
}
// Create the params handle.
- cudnn_params_desc_.reset(new CudnnRnnParamsDescriptor(cudnn, *this));
- if (!cudnn_params_desc_->ok()) {
- SetFailure(cudnn_params_desc_->Status());
- return;
- }
- set_use_tensor_op_math(algorithm_config_.algorithm().tensor_ops_enabled());
- }
- ~CudnnRnnDescriptor() override {
- if (rnn_desc_) {
- cudnnStatus_t status;
- if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
- status = cudnnDestroyPersistentRNNPlan(rnn_plan_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
- }
- status = cudnnDestroyRNNDescriptor(rnn_desc_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
- }
- }
- void set_use_tensor_op_math(bool use_tensor_op_math) {
+ SE_ASSIGN_OR_RETURN(auto params_desc,
+ CudnnRnnParamsDescriptor::Create(
+ cudnn, input_size, data_type, rnn_desc.get(),
+ rnn_mode, direction_mode, num_layers));
+
#if CUDNN_VERSION >= 7000
- cudnnMathType_t math_type =
- (use_tensor_op_math ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH);
- if (RnnTensorOpMathEnabled()) {
- cudnnStatus_t status = cudnnSetRNNMatrixMathType(rnn_desc_, math_type);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(FATAL) << "could not set cudnn RNN math type: " << ToString(status);
- }
+ // Require explicit algorithm config to enable tensor cores. Some configs
+ // return CUDNN_NOT_SUPPORTED when tensor ops are enabled (which is against
+ // the idiom that enabling tensor ops is only a hint: see nvbugs/2172799).
+ // We can only reasonably expect the user to handle the subsequent failure
+ // in profile mode, which is run with algorithms returned from
+ // GetRnnAlgorithms() (which are non-default and explicitly set whether to
+ // use tensor ops).
+ if (RnnTensorOpMathEnabled() &&
+ !algorithm_config.algorithm().is_default()) {
+ cudnnMathType_t math_type =
+ algorithm_config.algorithm().tensor_ops_enabled()
+ ? CUDNN_TENSOR_OP_MATH
+ : CUDNN_DEFAULT_MATH;
+ CHECK_CUDNN_OK(cudnnSetRNNMatrixMathType(rnn_desc.get(), math_type));
}
#endif
+
+ return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
+ num_layers, hidden_size, input_size, batch_size,
+ input_mode, direction_mode, rnn_mode, data_type,
+ compute_type, algorithm_config,
+ std::move(dropout_desc), std::move(params_desc));
}
- cudnnRNNDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return rnn_desc_;
- }
+
+ cudnnRNNDescriptor_t handle() const { return rnn_desc_.get(); }
int num_layers() const { return num_layers_; }
int hidden_size() const { return hidden_size_; }
int input_size() const { return input_size_; }
@@ -1096,27 +1069,21 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
return algorithm_config_;
}
int64 ParamsSizeInBytes() const override {
- return cudnn_params_desc_->params_size_in_bytes();
- }
- cudnnDropoutDescriptor_t dropout_handle() const {
- if (!cudnn_dropout_desc_) return nullptr;
- return cudnn_dropout_desc_->handle();
+ return params_desc_.params_size_in_bytes();
}
cudnnFilterDescriptor_t params_handle() const {
- if (!cudnn_params_desc_) return nullptr;
- return cudnn_params_desc_->handle();
+ return params_desc_.handle();
}
ParamsRegions ParamsWeightRegions() const override {
- if (!ok()) return ParamsRegions();
- return cudnn_params_desc_->params_weights();
+ return params_desc_.params_weights();
}
ParamsRegions ParamsBiasRegions() const override {
- if (!ok()) return ParamsRegions();
- return cudnn_params_desc_->params_biases();
+ return params_desc_.params_biases();
}
private:
- cudnnRNNDescriptor_t rnn_desc_;
+ cuda::RnnDescriptor rnn_desc_;
+ PersistentRnnPlan rnn_plan_;
int num_layers_;
int hidden_size_;
int input_size_;
@@ -1124,180 +1091,142 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
// algorithm.
int batch_size_;
cudnnRNNAlgo_t rnn_algo_;
- cudnnPersistentRNNPlan_t rnn_plan_;
cudnnRNNInputMode_t input_mode_;
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
cudnnDataType_t data_type_;
cudnnDataType_t compute_type_;
dnn::AlgorithmConfig algorithm_config_;
- std::unique_ptr<CudnnDropoutDescriptor> cudnn_dropout_desc_;
- std::unique_ptr<CudnnRnnParamsDescriptor> cudnn_params_desc_;
+ CudnnDropoutDescriptor dropout_desc_;
+ CudnnRnnParamsDescriptor params_desc_;
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
namespace {
-CudnnRnnParamsDescriptor::CudnnRnnParamsDescriptor(
- const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc)
- : handle_(nullptr), rnn_desc_(&rnn_desc), params_size_in_bytes_(0) {
- cudnnTensorDescriptor_t input_desc = nullptr;
- {
- // Query the params size.
- auto status = cudnnCreateTensorDescriptor(&input_desc);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create tensor descriptor");
- int dims[] = {1, rnn_desc.input_size(), 1};
- int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/input_desc, /*dataType=*/rnn_desc.data_type(),
- /*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to set tensor descriptor");
-
- size_t params_size = 0;
- status = cudnnGetRNNParamsSize(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*xDesc=*/input_desc, /*sizeInBytes=*/&params_size,
- /*dataType=*/rnn_desc.data_type());
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get RNN parameter size");
- params_size_in_bytes_ = static_cast<int64>(params_size);
- }
-
- {
- // Create the params descriptor.
- auto status = cudnnCreateFilterDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create RNN filter descriptor");
- int dims[] = {static_cast<int>(params_size_in_bytes_), 1, 1};
- status = cudnnSetFilterNdDescriptor(
- /*filterDesc=*/handle_, /*dataType=*/rnn_desc.data_type(),
- /*format=*/CUDNN_TENSOR_NCHW, /*nbDims=*/sizeof(dims) / sizeof(dims[0]),
- /*filterDimA=*/dims);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to update RNN filter descriptor");
- }
+port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
+ const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
+ cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
+ cudnnDirectionMode_t direction_mode, int num_layers) {
+ // Query the params size.
+ TensorDescriptor input_desc = CreateTensorDescriptor();
+ int tensor_dims[] = {1, input_size, 1};
+ int strides[] = {tensor_dims[1] * tensor_dims[2], tensor_dims[2], 1};
+ RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/input_desc.get(), /*dataType=*/data_type,
+ /*nbDims=*/sizeof(tensor_dims) / sizeof(tensor_dims[0]),
+ /*dimA=*/tensor_dims,
+ /*strideA=*/strides));
+
+ size_t params_size = 0;
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+ /*xDesc=*/input_desc.get(), /*sizeInBytes=*/&params_size,
+ /*dataType=*/data_type));
+ int64 params_size_in_bytes = static_cast<int64>(params_size);
+
+ FilterDescriptor filter_desc = CreateFilterDescriptor();
+ int filter_dims[] = {static_cast<int>(params_size_in_bytes), 1, 1};
+ RETURN_IF_CUDNN_ERROR(cudnnSetFilterNdDescriptor(
+ /*filterDesc=*/filter_desc.get(), /*dataType=*/data_type,
+ /*format=*/CUDNN_TENSOR_NCHW,
+ /*nbDims=*/sizeof(filter_dims) / sizeof(filter_dims[0]),
+ /*filterDimA=*/filter_dims));
+
+ // Create the weights and biases into the params buffer
+ int region_count_per_layer = [&] {
+ switch (rnn_mode) {
+ case CUDNN_RNN_RELU:
+ case CUDNN_RNN_TANH:
+ return 2;
+ case CUDNN_LSTM:
+ return 8;
+ case CUDNN_GRU:
+ return 6;
+ default:
+ LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
+ return 0;
+ }
+ }();
- {
- // Create the weights and biases into the params buffer
- int region_count_per_layer = GetRegionCountPerLayer();
- cudnnFilterDescriptor_t region_desc_handle = nullptr;
- auto status = cudnnCreateFilterDescriptor(&region_desc_handle);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to create filter descriptor");
- const int layer_count = rnn_desc.direction_mode() == CUDNN_UNIDIRECTIONAL
- ? rnn_desc.num_layers()
- : 2 * rnn_desc.num_layers();
- for (int layer = 0; layer < layer_count; layer++) {
- for (int region = 0; region < region_count_per_layer; region++) {
- for (int type = 0; type < 2; type++) {
- void* offset = nullptr;
- if (type == 0) {
- status = cudnnGetRNNLinLayerMatrixParams(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
- /*w=*/nullptr, /*linLayerID=*/region,
- /*linLayerMatDesc=*/region_desc_handle,
- /*linLayerMat=*/&offset);
- CUDNN_RETURN_IF_FAIL(
- status, "Cudnn fails to call cudnnGetRNNLinLayerMatrixParams");
- } else {
- status = cudnnGetRNNLinLayerBiasParams(
- /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
- /*layer=*/layer, /*xDesc=*/input_desc, /*wDesc=*/handle_,
- /*w=*/nullptr, /*linLayerID=*/region,
- /*linLayerBiasDesc=*/region_desc_handle,
- /*linLayerBias=*/&offset);
- CUDNN_RETURN_IF_FAIL(
- status, "Cudnn fails to call cudnnGetRNNLinLayerBiasParams");
- }
- int dims[] = {1, 1, 1};
- cudnnDataType_t data_type;
- cudnnTensorFormat_t tensor_format;
- int n_dims;
- status = cudnnGetFilterNdDescriptor(
- /*filterDesc=*/region_desc_handle,
- /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
- /*dataType=*/&data_type, /*format=*/&tensor_format,
- /*nbDims=*/&n_dims, /*filterDimA=*/dims);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to get filter description");
- int64 size = dims[0] * dims[1] * dims[2] *
- CudnnDataTypeToByteSize(rnn_desc.data_type());
- ParamsRegion region = {reinterpret_cast<int64>(offset), size};
- if (type == 0) {
- weights_.push_back(region);
- } else {
- biases_.push_back(region);
- }
- }
+ FilterDescriptor region_desc_handle = CreateFilterDescriptor();
+ const int layer_count =
+ direction_mode == CUDNN_UNIDIRECTIONAL ? num_layers : 2 * num_layers;
+
+ ParamsRegions weights;
+ ParamsRegions biases;
+
+ for (int layer = 0; layer < layer_count; layer++) {
+ for (int region = 0; region < region_count_per_layer; region++) {
+ for (int type = 0; type < 2; type++) {
+ void* offset = nullptr;
+ RETURN_IF_CUDNN_ERROR((type == 0 ? cudnnGetRNNLinLayerMatrixParams
+ : cudnnGetRNNLinLayerBiasParams)(
+ /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+ /*layer=*/layer, /*xDesc=*/input_desc.get(),
+ /*wDesc=*/filter_desc.get(),
+ /*w=*/nullptr, /*linLayerID=*/region,
+ /*linLayerMatDesc=*/region_desc_handle.get(),
+ /*linLayerMat or linLayerBias=*/&offset));
+ int dims[] = {1, 1, 1};
+ cudnnDataType_t data_type;
+ cudnnTensorFormat_t tensor_format;
+ int n_dims;
+ RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
+ /*filterDesc=*/region_desc_handle.get(),
+ /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
+ /*dataType=*/&data_type, /*format=*/&tensor_format,
+ /*nbDims=*/&n_dims, /*filterDimA=*/dims));
+ int64 size =
+ dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
+ dnn::RnnDescriptor::ParamsRegion region = {
+ reinterpret_cast<int64>(offset), size};
+ (type == 0 ? weights : biases).push_back(region);
}
}
- status = cudnnDestroyFilterDescriptor(region_desc_handle);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy filter descriptor");
}
- {
- // Release the dummy input tensor descriptor.
- auto status = cudnnDestroyTensorDescriptor(input_desc);
- CUDNN_RETURN_IF_FAIL(status, "Cudnn fails to destroy tensor descriptor");
- }
-}
-
-int CudnnRnnParamsDescriptor::GetRegionCountPerLayer() const {
- auto rnn_mode = rnn_desc_->rnn_mode();
- switch (rnn_mode) {
- case CUDNN_RNN_RELU:
- case CUDNN_RNN_TANH:
- return 2;
- case CUDNN_LSTM:
- return 8;
- case CUDNN_GRU:
- return 6;
- default:
- LOG(FATAL) << "Invalid RNN Mode: " << static_cast<int>(rnn_mode);
- }
+ return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
+ weights, biases);
}
} // namespace
class CudnnRnnSequenceTensorDescriptor
- : public CudnnDescriptorCommon<dnn::RnnSequenceTensorDescriptor> {
- public:
+ : public dnn::RnnSequenceTensorDescriptor {
CudnnRnnSequenceTensorDescriptor(CUDAExecutor* parent, int seq_length,
int batch_size, int data_size,
- cudnnDataType_t data_type)
+ cudnnDataType_t data_type,
+ TensorDescriptor handle)
: parent_(parent),
seq_length_(seq_length),
batch_size_(batch_size),
data_size_(data_size),
- data_type_(data_type) {
- cudnnTensorDescriptor_t handle = nullptr;
- if (seq_length <= 0) {
- string error_msg =
- port::StrCat("sequence length must be positive: ", seq_length);
- LOG(ERROR) << error_msg;
- SetFailure(port::Status(port::error::UNKNOWN, error_msg));
- return;
- }
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
+ data_type_(data_type),
+ handle_(std::move(handle)),
+ handles_(seq_length, handle_.get()) {}
+
+ public:
+ CudnnRnnSequenceTensorDescriptor(CudnnRnnSequenceTensorDescriptor&&) =
+ default;
+
+ static port::StatusOr<CudnnRnnSequenceTensorDescriptor> Create(
+ CUDAExecutor* parent, int seq_length, int batch_size, int data_size,
+ cudnnDataType_t data_type) {
+ CHECK_GT(seq_length, 0);
int dims[] = {batch_size, data_size, 1};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/handle, /*dataType=*/data_type,
+ TensorDescriptor tensor_desc = CreateTensorDescriptor();
+ RETURN_IF_CUDNN_ERROR(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/tensor_desc.get(), /*dataType=*/data_type,
/*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
- // Replicate handle across the number of steps.
- handles_.assign(seq_length, handle);
- }
-
- ~CudnnRnnSequenceTensorDescriptor() override {
- // Only the first one needs to be destroyed. All others are the same.
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handles_[0]);
- CUDNN_RETURN_IF_FAIL(status,
- "Failed to destroy sequence tensor descriptor");
+ /*strideA=*/strides));
+ return CudnnRnnSequenceTensorDescriptor(parent, seq_length, batch_size,
+ data_size, data_type,
+ std::move(tensor_desc));
}
const cudnnTensorDescriptor_t* handles() const {
- if (!ok()) return nullptr;
- CHECK(!handles_.empty()) << "handles cannot be empty";
return handles_.data();
}
@@ -1311,51 +1240,39 @@ class CudnnRnnSequenceTensorDescriptor
int batch_size_;
int data_size_;
cudnnDataType_t data_type_;
- std::vector<cudnnTensorDescriptor_t> handles_;
+ TensorDescriptor handle_;
+ std::vector<cudnnTensorDescriptor_t> handles_; // Copies of handle_.
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnSequenceTensorDescriptor);
};
-class CudnnRnnStateTensorDescriptor
- : public CudnnDescriptorCommon<dnn::RnnStateTensorDescriptor> {
+class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor {
public:
CudnnRnnStateTensorDescriptor(CUDAExecutor* parent, int num_layers,
int batch_size, int data_size,
cudnnDataType_t data_type)
: parent_(parent),
- handle_(nullptr),
+ handle_(CreateTensorDescriptor()),
num_layers_(num_layers),
batch_size_(batch_size),
data_size_(data_size),
data_type_(data_type) {
- cudnnStatus_t status = cudnnCreateTensorDescriptor(&handle_);
- CUDNN_RETURN_IF_FAIL(status, "Failed to create tensor descriptor");
int dims[] = {num_layers, batch_size, data_size};
int strides[] = {dims[1] * dims[2], dims[2], 1};
- status = cudnnSetTensorNdDescriptor(
- /*tensorDesc=*/handle_, /*dataType=*/data_type,
+ CHECK_CUDNN_OK(cudnnSetTensorNdDescriptor(
+ /*tensorDesc=*/handle_.get(), /*dataType=*/data_type,
/*nbDims=*/sizeof(dims) / sizeof(dims[0]), /*dimA=*/dims,
- /*strideA=*/strides);
- CUDNN_RETURN_IF_FAIL(status, "Failed to update tensor descriptor");
+ /*strideA=*/strides));
}
- ~CudnnRnnStateTensorDescriptor() override {
- if (!handle_) {
- cudnnStatus_t status = cudnnDestroyTensorDescriptor(handle_);
- CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN state tensor");
- }
- }
+ cudnnTensorDescriptor_t handle() const { return handle_.get(); }
- cudnnTensorDescriptor_t handle() const {
- if (!ok()) return nullptr;
- return handle_;
- }
int num_layers() const { return num_layers_; }
int batch_size() const { return batch_size_; }
int data_size() const { return data_size_; }
private:
CUDAExecutor* parent_;
- cudnnTensorDescriptor_t handle_;
+ TensorDescriptor handle_;
int num_layers_;
int batch_size_;
int data_size_;
@@ -1375,7 +1292,7 @@ struct RnnModelDims {
};
template <class T>
-bool ExtractAndCheckRnnForward(
+port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1388,103 +1305,89 @@ bool ExtractAndCheckRnnForward(
const CudnnRnnStateTensorDescriptor& output_h_desc,
const DeviceMemory<T>& output_h_data,
const CudnnRnnStateTensorDescriptor& output_c_desc,
- const DeviceMemory<T>& output_c_data, RnnModelDims* model_dims) {
+ const DeviceMemory<T>& output_c_data) {
// extract model parameters
- model_dims->num_layers = rnn_desc.num_layers();
- model_dims->batch_size = input_desc.batch_size();
- model_dims->seq_length = input_desc.seq_length();
- model_dims->hidden_size = rnn_desc.hidden_size();
- model_dims->input_size = input_desc.data_size();
- model_dims->dir_count =
+ RnnModelDims model_dims;
+ model_dims.num_layers = rnn_desc.num_layers();
+ model_dims.batch_size = input_desc.batch_size();
+ model_dims.seq_length = input_desc.seq_length();
+ model_dims.hidden_size = rnn_desc.hidden_size();
+ model_dims.input_size = input_desc.data_size();
+ model_dims.dir_count =
(rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
// check parameters
if (!(input_h_desc.num_layers() ==
- model_dims->num_layers * model_dims->dir_count &&
- input_h_desc.batch_size() == model_dims->batch_size &&
- input_h_desc.data_size() == model_dims->hidden_size)) {
- LOG(ERROR) << "Invalid input_h shape";
- return false;
+ model_dims.num_layers * model_dims.dir_count &&
+ input_h_desc.batch_size() == model_dims.batch_size &&
+ input_h_desc.data_size() == model_dims.hidden_size)) {
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
}
if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
input_h_desc.batch_size() == input_c_desc.batch_size() &&
input_h_desc.data_size() == input_c_desc.data_size())) {
- LOG(ERROR) << "Invalid input_c shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
}
- if (!(output_desc.seq_length() == model_dims->seq_length &&
- output_desc.batch_size() == model_dims->batch_size &&
+ if (!(output_desc.seq_length() == model_dims.seq_length &&
+ output_desc.batch_size() == model_dims.batch_size &&
output_desc.data_size() ==
- model_dims->hidden_size * model_dims->dir_count)) {
- LOG(ERROR) << "Invalid output shape";
- return false;
+ model_dims.hidden_size * model_dims.dir_count)) {
+ return port::Status(port::error::INVALID_ARGUMENT, "Invalid output shape");
}
if (!(input_h_desc.num_layers() == output_h_desc.num_layers() &&
input_h_desc.batch_size() == output_h_desc.batch_size() &&
input_h_desc.data_size() == output_h_desc.data_size())) {
- LOG(ERROR) << "Invalid output_h shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Invalid output_h shape");
}
if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
input_h_desc.batch_size() == output_c_desc.batch_size() &&
input_h_desc.data_size() == output_c_desc.data_size())) {
- LOG(ERROR) << "Invalid output_h shape";
- return false;
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Invalid output_c shape");
}
- return true;
+ return model_dims;
}
-bool CheckRNNParameterSize(const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc) {
+port::Status CheckRNNParameterSize(
+ const CudnnHandle& cudnn, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc) {
size_t params_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNParamsSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNParamsSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*xDesc=*/input_desc.handles()[0], /*sizeInBytes=*/&params_size_in_bytes,
- /*dataType=*/rnn_desc.data_type());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to check RNN param size: " << ToString(status);
- return false;
+ /*dataType=*/rnn_desc.data_type()));
+ if (static_cast<int64>(params_size_in_bytes) !=
+ rnn_desc.ParamsSizeInBytes()) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "Mismatching RNN parameter size");
}
- return static_cast<int64>(params_size_in_bytes) ==
- rnn_desc.ParamsSizeInBytes();
+ return port::Status::OK();
}
-bool CreateRnnWorkspace(Stream* stream, const CudnnHandle& cudnn,
- const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- ScratchAllocator* workspace_allocator,
- DeviceMemory<uint8>* workspace) {
+port::StatusOr<DeviceMemory<uint8>> CreateRnnWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ ScratchAllocator* workspace_allocator) {
// Query the workspace size.
size_t workspace_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNWorkspaceSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNWorkspaceSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/input_desc.seq_length(), /*xDesc=*/input_desc.handles(),
- /*sizeInBytes=*/&workspace_size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to query workspace size: " << ToString(status);
- return false;
- }
+ /*sizeInBytes=*/&workspace_size_in_bytes));
// Allocate the workspace.
- if (workspace_size_in_bytes > 0) {
- auto allocated =
- workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
- if (!allocated.ok() || (*workspace = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << port::StrCat("Failed to allocate RNN workspace of ",
- workspace_size_in_bytes, " bytes.");
- return false;
- }
- } else {
- *workspace = DeviceMemory<uint8>();
+ if (workspace_size_in_bytes == 0) {
+ return DeviceMemory<uint8>();
}
- return true;
+ return workspace_allocator->AllocateBytes(stream, workspace_size_in_bytes);
}
} // namespace
template <class T>
-bool CudnnSupport::DoRnnForwardImpl(
+port::Status CudnnSupport::DoRnnForwardImpl(
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1501,57 +1404,34 @@ bool CudnnSupport::DoRnnForwardImpl(
ScratchAllocator* reserve_space_allocator,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
- // extract model parameters
- RnnModelDims model_dims;
- bool res = ExtractAndCheckRnnForward(
- rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
- input_c_desc, input_c_data, params, output_desc, *output_data,
- output_h_desc, *output_h_data, output_c_desc, *output_c_data,
- &model_dims);
- if (!res) {
- LOG(ERROR) << "Invalid parameters for RNN Model";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(
+ RnnModelDims model_dims,
+ ExtractAndCheckRnnForward(
+ rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
+ input_c_desc, input_c_data, params, output_desc, *output_data,
+ output_h_desc, *output_h_data, output_c_desc, *output_c_data));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // check params size
- if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
- LOG(ERROR) << "Invalid parameters";
- return false;
- }
-
- // create the workspace
- DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
- workspace_allocator, &workspace)) {
- LOG(ERROR) << "Unable to create rnn workspace";
- return false;
- }
+ SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
+ CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+ workspace_allocator))
// query the reserve space size
// allocate the reserve space
DeviceMemory<uint8> reserve_space;
if (is_training) {
size_t reserve_space_size_in_bytes = 0;
- cudnnStatus_t status = cudnnGetRNNTrainingReserveSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetRNNTrainingReserveSize(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
- /*sizeInBytes=*/&reserve_space_size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Unable to query reserve space size: " << ToString(status);
- return false;
- }
+ /*sizeInBytes=*/&reserve_space_size_in_bytes));
if (reserve_space_size_in_bytes > 0) {
- auto allocated = reserve_space_allocator->AllocateBytes(
- stream, reserve_space_size_in_bytes);
- if (!allocated.ok() ||
- (reserve_space = allocated.ValueOrDie()) == nullptr) {
- LOG(ERROR) << "Failed to allocate RNN reserve space of "
- << reserve_space_size_in_bytes << " bytes.";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(reserve_space,
+ reserve_space_allocator->AllocateBytes(
+ stream, reserve_space_size_in_bytes));
}
}
@@ -1559,20 +1439,16 @@ bool CudnnSupport::DoRnnForwardImpl(
const bool is_profiling = output_profile_result != nullptr;
if (is_profiling) {
timer.reset(new CUDATimer(parent_));
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- // make the forward call
- cudnnStatus_t status;
+
if (!is_training) {
- status = cudnnRNNForwardInference(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNForwardInference(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1582,9 +1458,9 @@ bool CudnnSupport::DoRnnForwardImpl(
/*y=*/output_data->opaque(), /*hyDesc=*/output_h_desc.handle(),
/*hy=*/output_h_data->opaque(), /*cyDesc=*/output_c_desc.handle(),
/*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
- /*workSpaceSizeInBytes=*/workspace.size());
+ /*workSpaceSizeInBytes=*/workspace.size()));
} else {
- status = cudnnRNNForwardTraining(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNForwardTraining(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1596,35 +1472,24 @@ bool CudnnSupport::DoRnnForwardImpl(
/*cy=*/output_c_data->opaque(), /*workspace=*/workspace.opaque(),
/*workSpaceSizeInBytes=*/workspace.size(),
/*reserveSpace=*/reserve_space.opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space.size());
+ /*reserveSpaceSizeInBytes=*/reserve_space.size()));
}
+
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- return false;
- }
- if (status == CUDNN_STATUS_SUCCESS) {
- auto algo_desc = rnn_desc.algorithm_config().algorithm();
- output_profile_result->set_algorithm(algo_desc);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "Failed to call "
- << (is_training ? "cudnnRNNForwardTraining "
- : "cudnnRNNForwardInference ")
- << ToString(status);
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
+ auto algo_desc = rnn_desc.algorithm_config().algorithm();
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
template <class T>
-bool CudnnSupport::DoRnnBackwardImpl(
+port::Status CudnnSupport::DoRnnBackwardImpl(
Stream* stream, const CudnnRnnDescriptor& rnn_desc,
const CudnnRnnSequenceTensorDescriptor& input_desc,
const DeviceMemory<T>& input_data,
@@ -1648,53 +1513,38 @@ bool CudnnSupport::DoRnnBackwardImpl(
DeviceMemory<uint8>* reserve_space_data,
ScratchAllocator* workspace_allocator,
dnn::ProfileResult* output_profile_result) {
- // extract model parameters
- RnnModelDims model_dims;
- bool res = ExtractAndCheckRnnForward(
- rnn_desc, input_desc, input_data, input_h_desc, input_h_data,
- input_c_desc, input_c_data, params, output_desc, output_data,
- output_h_desc, output_h_data, output_c_desc, output_c_data, &model_dims);
- if (!res) {
- LOG(ERROR) << "Invalid parameters for RNN Model";
- return false;
- }
+ SE_ASSIGN_OR_RETURN(
+ RnnModelDims model_dims,
+ ExtractAndCheckRnnForward(rnn_desc, input_desc, input_data, input_h_desc,
+ input_h_data, input_c_desc, input_c_data,
+ params, output_desc, output_data, output_h_desc,
+ output_h_data, output_c_desc, output_c_data));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- // check params size
- if (!CheckRNNParameterSize(cudnn, rnn_desc, input_desc)) {
- LOG(ERROR) << "Invalid parameters";
- return false;
- }
-
- // create the workspace
- DeviceMemory<uint8> workspace;
- if (!CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
- workspace_allocator, &workspace)) {
- LOG(ERROR) << "Unable to create rnn workspace";
- return false;
- }
+ SE_RETURN_IF_ERROR(CheckRNNParameterSize(cudnn, rnn_desc, input_desc));
+ SE_ASSIGN_OR_RETURN(DeviceMemory<uint8> workspace,
+ CreateRnnWorkspace(stream, cudnn, rnn_desc, input_desc,
+ workspace_allocator));
std::unique_ptr<CUDATimer, TimerDeleter> timer;
const bool is_profiling = output_profile_result != nullptr;
if (is_profiling) {
timer.reset(new CUDATimer(parent_));
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- // make the backward data call
- cudnnStatus_t status = cudnnRNNBackwardData(
+
+ RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardData(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*yDesc=*/output_desc.handles(),
/*y=*/output_data.opaque(), /*dyDesc=*/output_desc.handles(),
- /*dy=*/output_backprop_data.opaque(), /*dhyDesc=*/output_h_desc.handle(),
+ /*dy=*/output_backprop_data.opaque(),
+ /*dhyDesc=*/output_h_desc.handle(),
/*dhy=*/output_h_backprop_data.opaque(),
/*dcyDesc=*/output_c_desc.handle(),
/*dcy=*/output_c_backprop_data.opaque(),
@@ -1705,24 +1555,17 @@ bool CudnnSupport::DoRnnBackwardImpl(
/*dhxDesc=*/input_h_desc.handle(),
/*dhx=*/input_h_backprop_data->opaque(),
/*dcxDesc=*/input_c_desc.handle(),
- /*dcx=*/input_c_backprop_data->opaque(), /*workspace=*/workspace.opaque(),
+ /*dcx=*/input_c_backprop_data->opaque(),
+ /*workspace=*/workspace.opaque(),
/*workSpaceSizeInBytes=*/workspace.size(),
/*reserveSpace=*/reserve_space_data->opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- }
- LOG(ERROR) << "Failed to call cudnnRNNBackwardData: " << ToString(status);
- return false;
- }
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
if (params_backprop_data != nullptr) {
// Clear the dw to zeros.
stream->ThenMemZero(params_backprop_data, params_backprop_data->size());
// make the backward weight call
- status = cudnnRNNBackwardWeights(
+ RETURN_IF_CUDNN_ERROR(cudnnRNNBackwardWeights(
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc.handle(),
/*seqLength=*/model_dims.seq_length, /*xDesc=*/input_desc.handles(),
/*x=*/input_data.opaque(), /*hxDesc=*/input_h_desc.handle(),
@@ -1732,19 +1575,12 @@ bool CudnnSupport::DoRnnBackwardImpl(
/*dwDesc=*/rnn_desc.params_handle(),
/*dw=*/params_backprop_data->opaque(),
/*reserveSpace=*/reserve_space_data->opaque(),
- /*reserveSpaceSizeInBytes=*/reserve_space_data->size());
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- }
- LOG(ERROR) << "Failed to call cudnnRNNBackwardWeights: "
- << ToString(status);
- return false;
- }
+ /*reserveSpaceSizeInBytes=*/reserve_space_data->size()));
}
+
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
auto algo_desc = rnn_desc.algorithm_config().algorithm();
output_profile_result->set_algorithm(algo_desc);
@@ -1752,7 +1588,7 @@ bool CudnnSupport::DoRnnBackwardImpl(
timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
@@ -1765,46 +1601,37 @@ CudnnSupport::createRnnDescriptor(
// Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
// not enqueueing anything into a stream, we pass in the null stream.
auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
- std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
- cudnn, num_layers, hidden_size, input_size, batch_size,
- ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
- ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
- GetRnnComputeType(data_type), algorithm_config, dropout, seed,
- state_allocator));
- if (!rnn_desc->ok()) {
- return rnn_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>(
- std::move(rnn_desc));
+ SE_ASSIGN_OR_RETURN(
+ CudnnRnnDescriptor rnn_desc,
+ CudnnRnnDescriptor::Create(
+ cudnn, num_layers, hidden_size, input_size, batch_size,
+ ToCudnnRnnInputMode(input_mode),
+ ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
+ ToCudnnDataType(data_type), GetRnnComputeType(data_type),
+ algorithm_config, dropout, seed, state_allocator));
+ return std::unique_ptr<dnn::RnnDescriptor>(
+ new CudnnRnnDescriptor(std::move(rnn_desc)));
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
CudnnSupport::createRnnSequenceTensorDescriptor(int seq_length, int batch_size,
int data_size,
dnn::DataType data_type) {
- std::unique_ptr<CudnnRnnSequenceTensorDescriptor> seq_desc(
- new CudnnRnnSequenceTensorDescriptor(parent_, seq_length, batch_size,
- data_size,
- ToCudnnDataType(data_type)));
- if (!seq_desc->ok()) {
- return seq_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>(
- std::move(seq_desc));
+ SE_ASSIGN_OR_RETURN(CudnnRnnSequenceTensorDescriptor descriptor,
+ CudnnRnnSequenceTensorDescriptor::Create(
+ parent_, seq_length, batch_size, data_size,
+ ToCudnnDataType(data_type)));
+ return std::unique_ptr<dnn::RnnSequenceTensorDescriptor>(
+ new CudnnRnnSequenceTensorDescriptor(std::move(descriptor)));
}
port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
CudnnSupport::createRnnStateTensorDescriptor(int num_layer, int batch_size,
int data_size,
dnn::DataType data_type) {
- std::unique_ptr<CudnnRnnStateTensorDescriptor> state_desc(
+ return std::unique_ptr<dnn::RnnStateTensorDescriptor>(
new CudnnRnnStateTensorDescriptor(parent_, num_layer, batch_size,
data_size, ToCudnnDataType(data_type)));
- if (!state_desc->ok()) {
- return state_desc->Status();
- }
- return port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>(
- std::move(state_desc));
}
bool CudnnSupport::DoRnnForward(
@@ -1840,12 +1667,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<Eigen::half>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<Eigen::half>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnForward(
@@ -1880,12 +1709,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<float>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnForward(
@@ -1921,12 +1752,14 @@ bool CudnnSupport::DoRnnForward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnForwardImpl<double>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, is_training, reserve_space_allocator, workspace_allocator,
- output_profile_result);
+ return IsStatusOk(
+ DoRnnForwardImpl<double>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data, is_training,
+ reserve_space_allocator, workspace_allocator, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -1969,14 +1802,17 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<Eigen::half>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<Eigen::half>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -2018,14 +1854,17 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<float>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<float>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoRnnBackward(
@@ -2068,121 +1907,351 @@ bool CudnnSupport::DoRnnBackward(
const CudnnRnnStateTensorDescriptor& cudnn_output_c_desc =
static_cast<const CudnnRnnStateTensorDescriptor&>(output_c_desc);
- return DoRnnBackwardImpl<double>(
- stream, cudnn_rnn_desc, cudnn_input_desc, input_data, cudnn_input_h_desc,
- input_h_data, cudnn_input_c_desc, input_c_data, params, cudnn_output_desc,
- output_data, cudnn_output_h_desc, output_h_data, cudnn_output_c_desc,
- output_c_data, output_backprop_data, output_h_backprop_data,
- output_c_backprop_data, input_backprop_data, input_h_backprop_data,
- input_c_backprop_data, params_backprop_data, reserve_space_data,
- workspace_allocator, output_profile_result);
+ return IsStatusOk(
+ DoRnnBackwardImpl<double>(
+ stream, cudnn_rnn_desc, cudnn_input_desc, input_data,
+ cudnn_input_h_desc, input_h_data, cudnn_input_c_desc, input_c_data,
+ params, cudnn_output_desc, output_data, cudnn_output_h_desc,
+ output_h_data, cudnn_output_c_desc, output_c_data,
+ output_backprop_data, output_h_backprop_data, output_c_backprop_data,
+ input_backprop_data, input_h_backprop_data, input_c_backprop_data,
+ params_backprop_data, reserve_space_data, workspace_allocator,
+ output_profile_result),
+ /*report_error=*/!output_profile_result);
}
namespace {
-inline cudnnConvolutionFwdAlgo_t GetCudnnConvolutionForwardAlgo(
- const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd,
- const ScopedFilterDescriptor& filter,
- const ScopedConvolutionDescriptor& conv,
- const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
+// TODO(csigg): Merge a lot of duplicate code below for forward, backward data,
+// and backward filter.
+
+port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
+ const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
+ const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
size_t memory_limit_bytes) {
cudnnConvolutionFwdPreference_t preference =
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
: CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
cudnnConvolutionFwdAlgo_t algo_to_use;
- auto status = cudnnGetConvolutionForwardAlgorithm(
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardAlgorithm(
cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
- output_nd.handle(), preference, memory_limit_bytes, &algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
- << "Unable to find a suitable algorithm for doing forward convolution";
+ output_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
return algo_to_use;
}
-dnn::AlgorithmDesc GetCudnnConvolutionForwardAlgorithm(
- Stream* stream, const CudnnHandle& cudnn,
- const dnn::AlgorithmConfig& algorithm_config, bool is_profiling,
- const ScopedTensorDescriptor& input_nd,
- const ScopedFilterDescriptor& filter,
- const ScopedConvolutionDescriptor& conv,
- const ScopedTensorDescriptor& output_nd,
- ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
- cudnnConvolutionFwdAlgo_t algo;
- bool use_tensor_ops;
- if (algorithm_config.algorithm().is_default()) {
- use_tensor_ops = true;
+port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
+GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
+ const CudnnTensorDescriptor& input_nd,
+ const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd,
+ bool specify_workspace_limit,
+ size_t memory_limit_bytes) {
+ cudnnConvolutionBwdDataPreference_t preference =
+ specify_workspace_limit
+ ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+ cudnnConvolutionBwdDataAlgo_t algo_to_use;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataAlgorithm(
+ cudnn.handle(), filter.handle(), output_nd.handle(), conv.handle(),
+ input_nd.handle(), preference, memory_limit_bytes, &algo_to_use));
+ return algo_to_use;
+}
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
+port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
+GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
+ const CudnnTensorDescriptor& input_nd,
+ const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd,
+ bool specify_workspace_limit,
+ size_t memory_limit_bytes) {
+ cudnnConvolutionBwdFilterPreference_t preference =
+ specify_workspace_limit
+ ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+ cudnnConvolutionBwdFilterAlgo_t algo_to_use;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm(
+ cudnn.handle(), input_nd.handle(), output_nd.handle(), conv.handle(),
+ filter.handle(), preference, memory_limit_bytes, &algo_to_use));
+ return algo_to_use;
+}
- algo = GetCudnnConvolutionForwardAlgo(
- cudnn, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/scratch_allocator != nullptr,
- memory_limit_bytes);
- } else {
- use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- algo = ToConvForwardAlgo(algorithm_config.algorithm());
- }
+port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmDesc& algorithm_desc,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+
+ // Query the size of the workspace and allocate it.
size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
cudnn.handle(),
/*xDesc=*/input_nd.handle(),
/*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
+ /*yDesc=*/output_nd.handle(), /*algo=*/ToConvForwardAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
int64 size_in_bytes_int64 = size_in_bytes;
- if (TF_PREDICT_FALSE(status != CUDNN_STATUS_SUCCESS)) {
- CHECK(is_profiling) << "Cannot query the size of workspace needed "
- "for the specified algorithm: "
- << algorithm_config.algorithm().algo_id() << " "
- << ToString(status);
- // Silently return when we are profiling.
- return dnn::AlgorithmDesc();
+
+ if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionForwardWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
+ }
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
}
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
+ }
+
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<DeviceMemory<uint8>>
+AllocateCudnnConvolutionBackwardDataWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmDesc& algorithm_desc,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+
+ // Query the size of the workspace and allocate it.
+ size_t size_in_bytes;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardDataWorkspaceSize(
+ cudnn.handle(),
+ /*wDesc=*/filter.handle(),
+ /*dyDesc=*/output_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*dxDesc=*/input_nd.handle(),
+ /*algo=*/ToConvBackwardDataAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
+ int64 size_in_bytes_int64 = size_in_bytes;
+
if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
- LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- if (TF_PREDICT_TRUE(is_profiling)) {
- return dnn::AlgorithmDesc();
- }
- } else if (size_in_bytes_int64 > 0) {
- port::StatusOr<DeviceMemory<uint8>> allocated;
- if (TF_PREDICT_TRUE(scratch_allocator)) {
- allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (TF_PREDICT_TRUE(allocated.ok())) {
- *scratch = allocated.ValueOrDie();
- } else {
- if (TF_PREDICT_TRUE(is_profiling)) {
- // Silently return when we are profiling.
- return dnn::AlgorithmDesc();
- }
- LOG(WARNING) << allocated.status().error_message();
- // For the int8 case, we fail at this point since the no_scratch
- // algorithm should be set to dnn::kDefaultAlgorithm.
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- }
- }
- if (TF_PREDICT_FALSE(!allocated.ok())) {
- if (algorithm_config.algorithm_no_scratch().is_default()) {
- use_tensor_ops = true;
- algo = GetCudnnConvolutionForwardAlgo(
- cudnn, input_nd, filter, conv, output_nd,
- /*specify_workspace_limit=*/false, 0);
- } else {
- use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- algo = ToConvForwardAlgo(algorithm_config.algorithm_no_scratch());
- }
- }
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
+ }
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
+ }
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
}
- return dnn::AlgorithmDesc(algo, use_tensor_ops);
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<DeviceMemory<uint8>>
+AllocateCudnnConvolutionBackwardFilterWorkspace(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmDesc& algorithm_desc,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd,
+ ScratchAllocator* scratch_allocator) {
+ // TODO(csigg): This has side effects on the convolution descriptor. It is
+ // functionally correct because the convolution is run with the algorithm of
+ // the last call to this function, but should be fixed anyway.
+ conv.set_use_tensor_op_math(algorithm_desc.tensor_ops_enabled());
+
+ // Query the size of the workspace and allocate it.
+ size_t size_in_bytes;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ cudnn.handle(),
+ /*xDesc=*/input_nd.handle(),
+ /*dyDesc=*/output_nd.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(),
+ /*algo=*/ToConvBackwardFilterAlgo(algorithm_desc),
+ /*sizeInBytes=*/&size_in_bytes));
+ int64 size_in_bytes_int64 = size_in_bytes;
+
+ if (TF_PREDICT_FALSE(size_in_bytes_int64 < 0)) {
+ return port::Status(
+ port::error::INTERNAL,
+ "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
+ "negative sizeInBytes value. This could be a cudnn bug.");
+ }
+
+ if (size_in_bytes_int64 == 0) {
+ return DeviceMemory<uint8>();
+ }
+
+ if (TF_PREDICT_FALSE(!scratch_allocator)) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "No scratch allocator provided");
+ }
+
+ return scratch_allocator->AllocateBytes(stream, size_in_bytes);
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionFwdAlgo_t algo,
+ GetCudnnConvolutionForwardAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
+ }
+
+ // Failed to allocate workspace for the first algorithm, fall back to the
+ // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+ SE_ASSIGN_OR_RETURN(
+ *scratch, AllocateCudnnConvolutionForwardWorkspace(
+ stream, cudnn, algorithm_config.algorithm_no_scratch(),
+ input_nd, filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdDataAlgo_t algo,
+ GetCudnnConvolutionBackwardDataAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
+ }
+
+ // Failed to allocate workspace for the first algorithm, fall back to the
+ // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+ SE_ASSIGN_OR_RETURN(
+ *scratch, AllocateCudnnConvolutionBackwardDataWorkspace(
+ stream, cudnn, algorithm_config.algorithm_no_scratch(),
+ input_nd, filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
+}
+
+port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
+ Stream* stream, const CudnnHandle& cudnn,
+ const dnn::AlgorithmConfig& algorithm_config,
+ const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
+ const CudnnConvolutionDescriptor& conv,
+ const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
+ DeviceMemory<uint8>* scratch) {
+ dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
+ if (algorithm_config.algorithm().is_default()) {
+ // Pick fastest algorithm within memory limit according to cuDNN's
+ // heuristics.
+ bool specify_workspace_limit = scratch_allocator != nullptr;
+ auto memory_limit_bytes =
+ specify_workspace_limit
+ ? std::max(scratch_allocator->GetMemoryLimitInBytes(stream), 0ll)
+ : 0ll;
+ SE_ASSIGN_OR_RETURN(cudnnConvolutionBwdFilterAlgo_t algo,
+ GetCudnnConvolutionBackwardFilterAlgo(
+ cudnn, input_nd, filter, conv, output_nd,
+ specify_workspace_limit, memory_limit_bytes));
+ algo_desc = dnn::AlgorithmDesc(
+ algo, algorithm_config.algorithm().tensor_ops_enabled());
+ }
+
+ auto scratch_or = AllocateCudnnConvolutionBackwardFilterWorkspace(
+ stream, cudnn, algo_desc, input_nd, filter, conv, output_nd,
+ scratch_allocator);
+
+ if (scratch_or.ok()) {
+ *scratch = scratch_or.ValueOrDie();
+ return algo_desc;
+ }
+
+ // Failed to allocate workspace for the first algorithm, fall back to the
+ // no_scratch algorithm.
+ if (algorithm_config.algorithm_no_scratch().is_default()) {
+ return port::Status(
+ port::error::INVALID_ARGUMENT,
+ "The primary convolution algorithm failed memory allocation, "
+ "while a secondary algorithm is not provided.");
+ }
+
+ SE_ASSIGN_OR_RETURN(*scratch,
+ AllocateCudnnConvolutionBackwardFilterWorkspace(
+ stream, cudnn, algorithm_config.algorithm(), input_nd,
+ filter, conv, output_nd, scratch_allocator));
+ return algorithm_config.algorithm_no_scratch();
}
// A helper class to set env-vars and choose options for cudnn-related
@@ -2215,9 +2284,7 @@ class CudnnEnvVar {
// algorithm through an env-var "TF_ENABLE_FFT_TILING_FORWARD=1".
struct FftTilingForward {
static constexpr const char* kName = "TF_ENABLE_FFT_TILING_FORWARD";
- // TODO(csigg): Enabling this algo causes XLA test failures, for example in
- // platforms/xla/tests/internal:convolution_test_gpu. See b/80018418.
- static constexpr bool kDefaultFlag = false; // CUDNN_VERSION >= 7000;
+ static constexpr bool kDefaultFlag = CUDNN_VERSION >= 7000;
};
// A helper struct to decide whether to enable the WINOGRAD_NONFUSED algorithms.
@@ -2282,8 +2349,6 @@ struct RnnDoFP32ComputationFP16Input {
static constexpr bool kDefaultFlag = false;
};
-// A helper function to return the internal compute type for
-// RNNs in cudnn.
cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
switch (data_type) {
case dnn::DataType::kFloat:
@@ -2304,7 +2369,7 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
} // namespace
template <class T>
-bool CudnnSupport::DoConvolveImpl(
+port::Status CudnnSupport::DoConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::FilterDescriptor& filter_descriptor,
@@ -2315,11 +2380,11 @@ bool CudnnSupport::DoConvolveImpl(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
- ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type);
- ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
- ScopedConvolutionDescriptor conv(convolution_descriptor,
- GetConvComputeType<T>());
+ CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
+ CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
auto cudnn = cudnn_->GetHandle(parent_, stream);
// Alpha is the scaling factor for input.
@@ -2334,177 +2399,75 @@ bool CudnnSupport::DoConvolveImpl(
: static_cast<void*>(&fbeta);
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionFwdAlgo_t algo;
- bool use_tensor_ops;
- DeviceMemory<uint8> scratch;
-
- // TODO(pauldonnelly): Replace the following code with a call to
- // GetCudnnConvolutionForwardAlgorithm().
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm = [&](bool specify_limit) {
- cudnnConvolutionFwdPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
- cudnnConvolutionFwdAlgo_t algo_to_use;
- auto status = cudnnGetConvolutionForwardAlgorithm(
- cudnn.handle(), input_nd.handle(), filter.handle(), conv.handle(),
- output_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing forward "
- "convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
- use_tensor_ops = true;
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionForwardAlgorithm(
+ stream, cudnn, algorithm_config, input_nd, filter,
+ conv, output_nd, scratch_allocator, &scratch));
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvForwardAlgo(algotype);
- use_tensor_ops = algotype.tensor_ops_enabled();
- conv.set_use_tensor_op_math(use_tensor_ops);
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionForwardWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*wDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
- /*yDesc=*/output_nd.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvForwardAlgo(algotype);
- use_tensor_ops = algotype.tensor_ops_enabled();
- conv.set_use_tensor_op_math(use_tensor_ops);
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING) << "cudnnGetConvolutionForwardWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
- auto status = cudnnConvolutionForward(
+
+ // Report an error if we might be hitting a cuDNN bug that accesses illegal
+ // memory. See nvbugs/2138754, b/80018418.
+ SE_RETURN_IF_ERROR([&] {
+ if (algo_desc.algo_id() != CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) {
+ return port::Status::OK();
+ }
+ if (input_descriptor.ndims() < 3) {
+ return port::Status::OK();
+ }
+ // Checks that a*b is within the valid range (as provided by NVIDIA).
+ auto check_sizes = [](size_t a, size_t b) {
+ if ((a * b * 4608 - 1) >> 31 == 0) {
+ return port::Status::OK();
+ }
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ "This configuration potentially accesses illegal memory.");
+ };
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.feature_map_count(),
+ output_descriptor.feature_map_count()));
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
+ input_descriptor.feature_map_count()));
+ SE_RETURN_IF_ERROR(check_sizes(input_descriptor.count(),
+ output_descriptor.feature_map_count()));
+ return port::Status::OK();
+ }());
+
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
- /*algo=*/algo, /*workSpace=*/scratch.opaque(),
+ /*algo=*/ToConvForwardAlgo(algo_desc), /*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/beta,
- /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+ /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
- }
- if (status == CUDNN_STATUS_SUCCESS) {
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- timer->Destroy();
- }
-
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
- }
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
template <typename Type, typename BiasType, typename ScaleType,
int cudnn_data_type, int cudnn_compute_type>
-bool CudnnSupport::DoFusedConvolveImpl(
+port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
@@ -2517,56 +2480,48 @@ bool CudnnSupport::DoFusedConvolveImpl(
DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- ScopedTensorDescriptor conv_input_nd(
+ if (activation_mode != dnn::ActivationMode::kRelu) {
+ return port::Status(port::error::INVALID_ARGUMENT,
+ "cudnnConvolutionBiasActivationForward() only supports "
+ "Relu activation.");
+ }
+
+ CudnnTensorDescriptor conv_input_nd(
conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- ScopedTensorDescriptor output_nd(
+ CudnnTensorDescriptor output_nd(
output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- ScopedFilterDescriptor filter(filter_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type));
- ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
- ScopedConvolutionDescriptor conv(
+ CudnnFilterDescriptor filter(filter_descriptor,
+ static_cast<cudnnDataType_t>(cudnn_data_type));
+ CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
+ CudnnConvolutionDescriptor conv(
convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
auto cudnn = cudnn_->GetHandle(parent_, stream);
+
const bool is_profiling = output_profile_result != nullptr;
- DeviceMemory<uint8> scratch;
- dnn::AlgorithmDesc algotype = GetCudnnConvolutionForwardAlgorithm(
- stream, cudnn, algorithm_config, is_profiling, conv_input_nd, filter,
- conv, output_nd, scratch_allocator, &scratch);
- if (algotype.is_default()) {
- if (!is_profiling) {
- LOG(ERROR) << "No suitable algorithm found";
- }
- return false;
- }
- auto algo = static_cast<cudnnConvolutionFwdAlgo_t>(algotype.algo_id());
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- if (activation_mode != dnn::ActivationMode::kRelu) {
- LOG(ERROR) << "cudnnConvolutionBiasActivationForward() only supports Relu "
- "activation.";
- return false;
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(
+ dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionForwardAlgorithm(
+ stream, cudnn, algorithm_config, conv_input_nd, filter, conv,
+ output_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- if (!timer->Init()) {
- return false;
- }
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- if (!timer->Start(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
}
}
// CUDNN v6 only supports CUDNN_NOT_PROPAGATE_NAN as the reluNanOpt for
// activation descriptor. Note that this will change the nan propagation
// behavior from separate conv, bias, and relu (which by default is
// CUDNN_PROPAGATE_NAN.
- ScopedActivationDescriptor activation_desc(
+ CudnnActivationDescriptor activation_desc(
activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
: side_input_data.opaque();
@@ -2576,7 +2531,8 @@ bool CudnnSupport::DoFusedConvolveImpl(
<< "\nconv_input_data.opaque() = " << conv_input_data.opaque()
<< "\nfilter.handle() = " << filter.handle()
<< "\nfilter_data.opaque() = " << filter_data.opaque()
- << "\nconv.handle() = " << conv.handle() << "\nalgo = " << algo
+ << "\nconv.handle() = " << conv.handle()
+ << "\nalgo = " << algo_desc.algo_id()
<< "\nscratch.opaque() = " << scratch.opaque()
<< "\nscratch.size() = " << scratch.size()
<< "\nside_input_scale = " << side_input_scale
@@ -2588,41 +2544,29 @@ bool CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
- auto status = cudnnConvolutionBiasActivationForward(
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
/*srcDesc=*/conv_input_nd.handle(), /*srcData=*/conv_input_data.opaque(),
/*filterDesc=*/filter.handle(), /*filterData=*/filter_data.opaque(),
- /*convDesc=*/conv.handle(), algo, /*workSpace=*/scratch.opaque(),
+ /*convDesc=*/conv.handle(), ToConvForwardAlgo(algo_desc),
+ /*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(), /*alpha2=*/&side_input_scale,
/*zDesc=*/output_nd.handle(), /*z=*/side_input_data_ptr,
/*biasDesc=*/bias_nd.handle(), /*bias=*/biases.opaque(),
/*activationDesc=*/activation_desc.handle(),
- /*destDesc=*/output_nd.handle(), /*destData=*/output_data->opaque());
+ /*yDesc=*/output_nd.handle(), /*y=*/output_data->opaque()));
if (is_profiling) {
if (!timer->Stop(AsCUDAStream(stream))) {
- timer->Destroy();
- return false;
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- if (status == CUDNN_STATUS_SUCCESS) {
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
-
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
- }
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+ return port::Status::OK();
}
bool CudnnSupport::GetConvolveAlgorithms(
@@ -2746,11 +2690,13 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* saved_inv_var, bool is_training,
std::function<const DeviceMemory<float>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- return DoBatchNormalizationForwardImpl<float, float>(
- stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
- estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
- batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ return IsStatusOk(
+ DoBatchNormalizationForwardImpl<float, float>(
+ stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
+ offset, estimated_mean, estimated_variance, x_desc, scale_offset_desc,
+ epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
+ is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ /*report_error=*/true);
}
bool CudnnSupport::DoBatchNormalizationForward(
@@ -2765,15 +2711,17 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* saved_inv_var, bool is_training,
std::function<const DeviceMemory<float>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- return DoBatchNormalizationForwardImpl<Eigen::half, float>(
- stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
- estimated_mean, estimated_variance, x_desc, scale_offset_desc, epsilon, y,
- batch_mean, batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ return IsStatusOk(
+ DoBatchNormalizationForwardImpl<Eigen::half, float>(
+ stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
+ estimated_mean, estimated_variance, x_desc, scale_offset_desc,
+ epsilon, y, batch_mean, batch_var, saved_mean, saved_inv_var,
+ is_training, std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ /*report_error=*/true);
}
template <class T, class U>
-bool CudnnSupport::DoBatchNormalizationForwardImpl(
+port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type,
dnn::DataType scale_data_type, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
@@ -2785,8 +2733,8 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
std::function<void()> inv_var_to_var) {
- ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
- ScopedTensorDescriptor scale_offset_descriptor(
+ CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
+ CudnnTensorDescriptor scale_offset_descriptor(
scale_offset_desc, ToCudnnDataType(scale_data_type));
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
@@ -2798,7 +2746,6 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
float zero = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = CUDNN_STATUS_SUCCESS;
if (is_training) {
CHECK_EQ(batch_mean->is_null(), batch_var->is_null())
<< "batch_mean and batch_var must both be null or both be non-null";
@@ -2815,26 +2762,21 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl(
batch_var_opaque = nullptr;
}
- status = cudnnBatchNormalizationForwardTraining(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardTraining(
cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
scale.opaque(), offset.opaque(), 1.0, batch_mean_opaque,
batch_var_opaque, epsilon, saved_mean->opaque(),
- saved_inv_var->opaque());
+ saved_inv_var->opaque()));
} else {
const void* maybe_inv_var = estimated_variance.opaque();
- status = cudnnBatchNormalizationForwardInference(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationForwardInference(
cudnn.handle(), mode, &one, &zero, x_descriptor.handle(), x.opaque(),
x_descriptor.handle(), y->opaque(), scale_offset_descriptor.handle(),
scale.opaque(), offset.opaque(), estimated_mean.opaque(), maybe_inv_var,
- epsilon);
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward batch normalization on stream: "
- << ToString(status);
- return false;
+ epsilon));
}
- return true;
+ return port::Status::OK();
}
bool CudnnSupport::DoBatchNormalizationBackward(
@@ -2845,10 +2787,11 @@ bool CudnnSupport::DoBatchNormalizationBackward(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<float>* x_backprop, DeviceMemory<float>* scale_backprop,
DeviceMemory<float>* offset_backprop) {
- return DoBatchNormalizationBackwardImpl(
- stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
- inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
- offset_backprop);
+ return IsStatusOk(DoBatchNormalizationBackwardImpl(
+ stream, CUDNN_DATA_FLOAT, CUDNN_DATA_FLOAT, y_backprop,
+ x, scale, mean, inv_var, x_desc, scale_offset_desc,
+ epsilon, x_backprop, scale_backprop, offset_backprop),
+ /*report_error=*/true);
}
bool CudnnSupport::DoBatchNormalizationBackward(
@@ -2859,14 +2802,15 @@ bool CudnnSupport::DoBatchNormalizationBackward(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<Eigen::half>* x_backprop, DeviceMemory<float>* scale_backprop,
DeviceMemory<float>* offset_backprop) {
- return DoBatchNormalizationBackwardImpl(
- stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop, x, scale, mean,
- inv_var, x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop,
- offset_backprop);
+ return IsStatusOk(DoBatchNormalizationBackwardImpl(
+ stream, CUDNN_DATA_HALF, CUDNN_DATA_FLOAT, y_backprop,
+ x, scale, mean, inv_var, x_desc, scale_offset_desc,
+ epsilon, x_backprop, scale_backprop, offset_backprop),
+ /*report_error=*/true);
}
template <class T, class U>
-bool CudnnSupport::DoBatchNormalizationBackwardImpl(
+port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
Stream* stream, int cudnn_input_type, int cudnn_scale_type,
const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
@@ -2874,9 +2818,9 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
DeviceMemory<U>* offset_backprop) {
- ScopedTensorDescriptor x_descriptor(
+ CudnnTensorDescriptor x_descriptor(
x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
- ScopedTensorDescriptor scale_offset_descriptor(
+ CudnnTensorDescriptor scale_offset_descriptor(
scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
@@ -2889,19 +2833,14 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnBatchNormalizationBackward(
+ RETURN_IF_CUDNN_ERROR(cudnnBatchNormalizationBackward(
cudnn.handle(), mode, &one, &zero, &one, &zero, x_descriptor.handle(),
x.opaque(), x_descriptor.handle(), y_backprop.opaque(),
x_descriptor.handle(), x_backprop->opaque(),
scale_offset_descriptor.handle(), scale.opaque(),
scale_backprop->opaque(), offset_backprop->opaque(), epsilon,
- mean.opaque(), inv_var.opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward batch normalization on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ mean.opaque(), inv_var.opaque()));
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolve(
@@ -2914,10 +2853,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<float>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<float>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolve(
@@ -2930,10 +2871,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<double>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<double>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolve(
@@ -2946,10 +2889,12 @@ bool CudnnSupport::DoConvolve(
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveImpl<Eigen::half>(
- stream, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveImpl<Eigen::half>(
+ stream, batch_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output_data,
+ scratch_allocator, algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -2965,13 +2910,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<double>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
+ CUDNN_DATA_DOUBLE>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -2987,13 +2934,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<float>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
+ CUDNN_DATA_FLOAT>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -3010,13 +2959,15 @@ bool CudnnSupport::DoFusedConvolve(
DeviceMemory<Eigen::half>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
+ CUDNN_DATA_FLOAT>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoFusedConvolve(
@@ -3040,13 +2991,15 @@ bool CudnnSupport::DoFusedConvolve(
"supported on GPUs with compute capability 6.1 or later.";
return false;
}
- return DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
- stream, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output_data, scratch_allocator, algorithm_config,
- output_profile_result);
+ return IsStatusOk(
+ DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
+ CUDNN_DATA_INT32>(
+ stream, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoTransformTensor(Stream* stream,
@@ -3057,27 +3010,22 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
dnn::DataType output_type, float scale,
DeviceMemoryBase* output_data) {
float beta = 0.0f;
- ScopedTensorDescriptor input_tensor_desc(
+ CudnnTensorDescriptor input_tensor_desc(
input_desc, ToCudnnDataType(input_type, input_desc.layout()));
- ScopedTensorDescriptor output_tensor_desc(
+ CudnnTensorDescriptor output_tensor_desc(
output_desc, ToCudnnDataType(output_type, output_desc.layout()));
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnTransformTensor(
- cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
- &beta, output_tensor_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "Could not transform a tensor with layout "
- << input_desc.ToString() << " and data type "
- << static_cast<int>(input_type) << " to another with layout "
- << output_desc.ToString() << " and data type "
- << static_cast<int>(output_type) << ": " << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnTransformTensor(
+ cudnn.handle(), &scale, input_tensor_desc.handle(), input_data.opaque(),
+ &beta, output_tensor_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
template <class T>
-bool CudnnSupport::DoConvolveBackwardDataImpl(
+port::Status CudnnSupport::DoConvolveBackwardDataImpl(
Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
const dnn::BatchDescriptor& output_descriptor,
@@ -3101,146 +3049,48 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
- ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
- ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
- ScopedConvolutionDescriptor conv(convolution_descriptor,
- GetConvComputeType<T>());
+ CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
+ CudnnTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
+ CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionBwdDataAlgo_t algo;
- DeviceMemory<uint8> scratch;
-
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
- auto get_algorithm =
- [&](bool specify_limit) -> cudnnConvolutionBwdDataAlgo_t {
- cudnnConvolutionBwdDataPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
- cudnnConvolutionBwdDataAlgo_t algo_to_use;
- cudnnStatus_t status = cudnnGetConvolutionBackwardDataAlgorithm(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "data convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
-
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvBackwardDataAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
- cudnn.handle(),
- /*filterDesc=*/filter.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/in_back_nd.handle(),
- /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvBackwardDataAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING) << "cudnnGetConvolutionBackwardDataWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionBackwardDataAlgorithm(
+ stream, cudnn, algorithm_config, in_back_nd, filter,
+ conv, out_back_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- timer->Init();
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- timer->Start(AsCUDAStream(stream));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
+ }
+ }
+
+ // Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
+ // zero-initialized.
+ // TODO(timshen): Add an nvbugs/ link.
+ if (CUDNN_VERSION >= 7000 &&
+ algorithm_config.algorithm().algo_id() ==
+ CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 &&
+ cudnn_type == CUDNN_DATA_HALF &&
+ algorithm_config.algorithm().tensor_ops_enabled() &&
+ input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ filter_descriptor.layout() == dnn::FilterLayout::kOutputInputYX &&
+ output_descriptor.layout() == dnn::DataLayout::kBatchDepthYX &&
+ (convolution_descriptor.vertical_filter_stride() > 1 ||
+ convolution_descriptor.horizontal_filter_stride() > 1)) {
+ stream->ThenMemZero(&scratch, scratch.size());
}
- auto status =
+ RETURN_IF_CUDNN_ERROR(
cudnnConvolutionBackwardData(cudnn.handle(),
/*alpha=*/alpha,
/*wDesc=*/filter.handle(),
@@ -3248,32 +3098,22 @@ bool CudnnSupport::DoConvolveBackwardDataImpl(
/*dyDesc=*/out_back_nd.handle(),
/*dy=*/backward_output_data.opaque(),
/*convDesc=*/conv.handle(),
- /*algo=*/algo,
+ /*algo=*/ToConvBackwardDataAlgo(algo_desc),
/*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(),
/*beta=*/beta,
/*dxDesc=*/in_back_nd.handle(),
- /*dx=*/backward_input_data->opaque());
+ /*dx=*/backward_input_data->opaque()));
if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- if (status == CUDNN_STATUS_SUCCESS) {
- bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolveBackwardData(
@@ -3287,11 +3127,13 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardData(
@@ -3305,11 +3147,13 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardData(
@@ -3323,15 +3167,17 @@ bool CudnnSupport::DoConvolveBackwardData(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor,
- backward_input_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardDataImpl(stream, filter_descriptor, filter_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor,
+ backward_input_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
template <class T>
-bool CudnnSupport::DoConvolveBackwardFilterImpl(
+port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& output_descriptor,
@@ -3355,148 +3201,60 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
auto cudnn = cudnn_->GetHandle(parent_, stream);
- ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
- ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
- ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
- ScopedConvolutionDescriptor conv(convolution_descriptor,
- GetConvComputeType<T>());
+ CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
+ CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetConvComputeType<T>());
const bool is_profiling = output_profile_result != nullptr;
- cudnnConvolutionBwdFilterAlgo_t algo;
- DeviceMemory<uint8> scratch;
-
- if (algorithm_config.algorithm().is_default()) {
- // With the default algorithm, use Cudnn's heuristics.
-
- // Lambda that retrieves the algorithm.
- // specify_limit will occur when we have a scratch allocator and it succeeds
- // in allocating; otherwise, we'll fall back to the "no workspace" version.
- auto get_algorithm = [&](bool specify_limit) {
- cudnnConvolutionBwdFilterPreference_t preference =
- specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
- : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
-
- auto memory_limit_bytes =
- scratch_allocator == nullptr
- ? 0
- : scratch_allocator->GetMemoryLimitInBytes(stream);
- if (memory_limit_bytes < 0) {
- memory_limit_bytes = 0;
- }
-
- cudnnConvolutionBwdFilterAlgo_t algo_to_use;
- cudnnStatus_t status = cudnnGetConvolutionBackwardFilterAlgorithm(
- cudnn.handle(),
- /*srcDesc=*/input_nd.handle(),
- /*diffDesc=*/out_back_nd.handle(),
- /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(),
- /*preference=*/preference,
- /*memoryLimitInBytes=*/memory_limit_bytes,
- /*algo=*/&algo_to_use);
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
- "algorithm for doing backward "
- "filter convolution";
- return algo_to_use;
- };
-
- algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
-
- if (scratch_allocator != nullptr) {
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- int64 size_in_bytes_int64 = size_in_bytes;
- if (status == CUDNN_STATUS_SUCCESS && size_in_bytes_int64 != 0) {
- if (size_in_bytes_int64 > 0) {
- auto allocated =
- scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- } else {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
- }
- // If we didn't allocate any scratch space (perhaps because of failed
- // allocation), we force a switch back to the "no workspace" algorithm.
- if (scratch == nullptr) {
- algo = get_algorithm(/*specify_limit=*/false);
- }
- } else {
- // An algorithm has been specified.
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm();
- algo = ToConvBackwardFilterAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
-
- size_t size_in_bytes;
- auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
- cudnn.handle(),
- /*xDesc=*/input_nd.handle(),
- /*dyDesc=*/out_back_nd.handle(), /*convDesc=*/conv.handle(),
- /*gradDesc=*/filter.handle(), /*algo=*/algo,
- /*sizeInBytes=*/&size_in_bytes);
- if (status != CUDNN_STATUS_SUCCESS) {
- if (is_profiling) {
- // Silently return when we are profiling.
- return false;
- }
- LOG(FATAL) << "Cannot query the size of workspace needed for the given "
- "algorithm: "
- << algorithm_config.algorithm().algo_id();
- }
- int64 size_in_bytes_int64 = size_in_bytes;
- if (size_in_bytes_int64 > 0) {
- if (scratch_allocator == nullptr) {
- LOG(FATAL) << "An allocator must be specified when scratch memory is "
- "needed";
- }
- auto allocated = scratch_allocator->AllocateBytes(stream, size_in_bytes);
- if (is_profiling && !allocated.ok()) {
- // Silently return when we are profiling.
- return false;
- }
- if (allocated.ok()) {
- scratch = allocated.ValueOrDie();
- } else {
- LOG(WARNING) << allocated.status().error_message();
- }
- if (scratch == nullptr) {
- CHECK(!algorithm_config.algorithm_no_scratch().is_default())
- << "The primary convolution algorithm failed memory allocation, "
- "while a secondary algorithm is not provided.";
- dnn::AlgorithmDesc algotype = algorithm_config.algorithm_no_scratch();
- algo = ToConvBackwardFilterAlgo(algotype);
- conv.set_use_tensor_op_math(algotype.tensor_ops_enabled());
- }
- } else if (size_in_bytes_int64 < 0) {
- LOG(WARNING)
- << "cudnnGetConvolutionBackwardFilterWorkspaceSize() returned "
- "negative sizeInBytes value. This could be a cudnn bug.";
- }
- }
+ DeviceMemory<uint8> scratch;
+ SE_ASSIGN_OR_RETURN(dnn::AlgorithmDesc algo_desc,
+ GetCudnnConvolutionBackwardFilterAlgorithm(
+ stream, cudnn, algorithm_config, input_nd, filter,
+ conv, out_back_nd, scratch_allocator, &scratch));
- std::unique_ptr<CUDATimer> timer;
+ std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (is_profiling) {
timer.reset(new CUDATimer(parent_)); // NOLINT
- timer->Init();
// The start and stop of the timer should be as close to the Cudnn call as
// possible. It is still possible for other threads to issue workload on
// to this stream. So it could take multiple profiling measurements.
- timer->Start(AsCUDAStream(stream));
+ if (!timer->Init() || !timer->Start(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to start timer");
+ }
}
- auto status = cudnnConvolutionBackwardFilter(
+ // Report an error if we might be hitting a cuDNN bug that produces incorrect
+ // results. See nvbugs/2072856
+ SE_RETURN_IF_ERROR([&] {
+ if (algo_desc.algo_id() != CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) {
+ return port::Status::OK();
+ }
+ if (output_descriptor.height() > 1 && output_descriptor.width() > 1) {
+ return port::Status::OK();
+ }
+ int convolution_size = output_descriptor.height() > 1
+ ? filter_descriptor.input_filter_height()
+ : filter_descriptor.input_filter_width();
+ if (convolution_size <= 32) {
+ return port::Status::OK();
+ }
+ cudnnConvolutionMode_t convolution_mode;
+ cudnnDataType_t compute_type;
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdDescriptor(
+ conv.handle(), 0, nullptr, nullptr, nullptr, nullptr, &convolution_mode,
+ &compute_type));
+ if (convolution_mode != CUDNN_CONVOLUTION) {
+ return port::Status::OK();
+ }
+ return port::Status(
+ port::error::FAILED_PRECONDITION,
+ "This configuration potentially produces incorrect results.");
+ }());
+
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
/*srcDesc=*/input_nd.handle(),
@@ -3504,33 +3262,22 @@ bool CudnnSupport::DoConvolveBackwardFilterImpl(
/*diffDesc=*/out_back_nd.handle(),
/*diffData=*/backward_output_data.opaque(),
/*convDesc=*/conv.handle(),
- /*algo=*/algo,
+ /*algo=*/ToConvBackwardFilterAlgo(algo_desc),
/*workSpace=*/scratch.opaque(),
/*workSpaceSizeInBytes=*/scratch.size(),
/*beta=*/beta,
/*gradDesc=*/filter.handle(),
- /*gradData=*/backward_filter_data->opaque());
-
+ /*dw=*/backward_filter_data->opaque()));
if (is_profiling) {
- timer->Stop(AsCUDAStream(stream));
- if (status == CUDNN_STATUS_SUCCESS) {
- bool use_tensor_ops = algorithm_config.algorithm().tensor_ops_enabled();
- dnn::AlgorithmDesc algotype(algo, use_tensor_ops);
- output_profile_result->set_algorithm(algotype);
- output_profile_result->set_elapsed_time_in_ms(
- timer->GetElapsedMilliseconds());
- }
- timer->Destroy();
- }
- if (status != CUDNN_STATUS_SUCCESS) {
- // Silently return when we are profiling.
- if (!is_profiling) {
- LOG(ERROR) << "failed to enqueue convolution on stream: "
- << ToString(status);
+ if (!timer->Stop(AsCUDAStream(stream))) {
+ return port::Status(port::error::INTERNAL, "Failed to stop timer");
}
- return false;
+ output_profile_result->set_algorithm(algo_desc);
+ output_profile_result->set_elapsed_time_in_ms(
+ timer->GetElapsedMilliseconds());
}
- return true;
+
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3544,11 +3291,13 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3562,11 +3311,13 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -3580,22 +3331,24 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
- return DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
- output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor,
- backward_filter_data, scratch_allocator,
- algorithm_config, output_profile_result);
+ return IsStatusOk(
+ DoConvolveBackwardFilterImpl(stream, input_descriptor, input_data,
+ output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor,
+ backward_filter_data, scratch_allocator,
+ algorithm_config, output_profile_result),
+ /*report_error=*/!output_profile_result);
}
template <class T>
-bool CudnnSupport::DoConvolveBackwardBiasImpl(
+port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data) {
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
- ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
- ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
+ CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
+ CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
// Alpha is the scaling factor for input.
float alpha = 1.0;
@@ -3603,15 +3356,10 @@ bool CudnnSupport::DoConvolveBackwardBiasImpl(
float beta = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnConvolutionBackwardBias(
+ RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardBias(
cudnn.handle(), &alpha, input_nd.handle(), input_data.opaque(), &beta,
- bias_nd.handle(), backward_bias_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward convolution on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ bias_nd.handle(), backward_bias_data->opaque()));
+ return port::Status::OK();
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3619,8 +3367,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3628,8 +3378,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoConvolveBackwardBias(
@@ -3637,8 +3389,10 @@ bool CudnnSupport::DoConvolveBackwardBias(
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) {
- return DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
- bias_descriptor, backward_bias_data);
+ return IsStatusOk(
+ DoConvolveBackwardBiasImpl(stream, input_descriptor, input_data,
+ bias_descriptor, backward_bias_data),
+ /*report_error=*/true);
}
bool CudnnSupport::DoMatMul(Stream* stream,
@@ -3781,7 +3535,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
const DeviceMemory<float>& biases,
const dnn::BatchDescriptor& dimensions,
DeviceMemory<float>* output_data) {
- ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
dnn::BatchDescriptor bias_dimensions;
bias_dimensions.set_count(1)
@@ -3789,7 +3543,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
.set_height(1)
.set_width(1)
.set_layout(dnn::DataLayout::kBatchYXDepth);
- ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
// cudnnAddTensor after R3 is in-place, so we need to copy input_data to
// output_data before doing the addition, unless the input and
@@ -3810,16 +3564,13 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnAddTensor(
- cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(), &beta,
- input_descriptor.handle(), output_data->opaque());
-
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
- return false;
- }
-
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnAddTensor(
+ cudnn.handle(), &alpha, bias_descriptor.handle(), biases.opaque(),
+ &beta, input_descriptor.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoActivate(Stream* stream,
@@ -3828,26 +3579,23 @@ bool CudnnSupport::DoActivate(Stream* stream,
const DeviceMemory<float>& input_data,
DeviceMemory<float>* output_data,
uint64 options) {
- ScopedActivationDescriptor activation_desc(
+ CudnnActivationDescriptor activation_desc(
activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
- ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
// Alpha is the input scaling factor.
float alpha = 1.0;
// Beta is the output scaling factor.
float beta = 0.0;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnActivationForward(
- cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
- input_data.opaque(), &beta, input_nd.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "stream " << stream
- << " could not enqueue activation: " << ToString(status);
- return false;
- }
-
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnActivationForward(
+ cudnn.handle(), activation_desc.handle(), &alpha, input_nd.handle(),
+ input_data.opaque(), &beta, input_nd.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3861,20 +3609,18 @@ bool CudnnSupport::DoPoolForward(
// Beta is the scaling factor for output.
double beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3888,20 +3634,18 @@ bool CudnnSupport::DoPoolForward(
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolForward(
@@ -3915,19 +3659,17 @@ bool CudnnSupport::DoPoolForward(
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingForward(
- cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
- input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue forward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, src_desc.handle(),
+ input_data.opaque(), &beta, dest_desc.handle(), output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -3943,22 +3685,20 @@ bool CudnnSupport::DoPoolBackward(
// Beta is the scaling factor for output.
double beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -3974,22 +3714,20 @@ bool CudnnSupport::DoPoolBackward(
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoPoolBackward(
@@ -4005,22 +3743,20 @@ bool CudnnSupport::DoPoolBackward(
// Beta is the scaling factor for output.
float beta = 0.0;
- ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
- ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
- ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
+ CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
+ CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
+ CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnPoolingBackward(
- cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
- output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
- src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
- output_diff_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to enqueue backward pooling on stream: "
- << ToString(status);
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnPoolingBackward(
+ cudnn.handle(), pooling_desc.handle(), &alpha, dest_desc.handle(),
+ output_data.opaque(), dest_desc.handle(), input_diff_data.opaque(),
+ src_desc.handle(), input_data.opaque(), &beta, src_desc.handle(),
+ output_diff_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoNormalize(
@@ -4044,8 +3780,8 @@ bool CudnnSupport::DoNormalizeWithDimensions(
return false;
}
- ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
- ScopedNormalizeDescriptor normalize(normalize_descriptor);
+ CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+ CudnnNormalizeDescriptor normalize(normalize_descriptor);
// Alpha is the scaling factor for input.
float alpha = 1.0f;
@@ -4055,15 +3791,14 @@ bool CudnnSupport::DoNormalizeWithDimensions(
auto cudnn = cudnn_->GetHandle(parent_, stream);
// Launch the normalization.
- auto status = cudnnLRNCrossChannelForward(
- cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
- dims.handle(), input_data.opaque(), &beta, dims.handle(),
- output_data->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to run cudnnLRNCrossChannelForward";
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelForward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ &alpha, dims.handle(), input_data.opaque(), &beta, dims.handle(),
+ output_data->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoNormalizeBackwardWithDimensions(
@@ -4082,23 +3817,22 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
return false;
}
- ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
- ScopedNormalizeDescriptor normalize(normalize_descriptor);
+ CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
+ CudnnNormalizeDescriptor normalize(normalize_descriptor);
float alpha = 1.0f;
float beta = 0.0f;
auto cudnn = cudnn_->GetHandle(parent_, stream);
- auto status = cudnnLRNCrossChannelBackward(
- cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha,
- dims.handle(), normalized_data.opaque(), dims.handle(),
- normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
- &beta, dims.handle(), raw_variable_gradient->opaque());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "failed to run cudnnLRNCrossChannelBackward";
- return false;
- }
- return true;
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnLRNCrossChannelBackward(
+ cudnn.handle(), normalize.handle(), CUDNN_LRN_CROSS_CHANNEL_DIM1,
+ &alpha, dims.handle(), normalized_data.opaque(), dims.handle(),
+ normalized_variable_gradient.opaque(), dims.handle(), raw_data.opaque(),
+ &beta, dims.handle(), raw_variable_gradient->opaque()));
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
bool CudnnSupport::DoDepthConcatenate(
@@ -4207,30 +3941,26 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
dnn::BatchDescriptor* output_batch_descriptor) {
- ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
- ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
- ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
+ CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
+ CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
+ CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
int dn = batch_descriptor.ndims() + 2;
std::vector<int> dims(dn); // in BDYX
- auto status = cudnnGetConvolutionNdForwardOutputDim(
- conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data());
- if (status != CUDNN_STATUS_SUCCESS) {
- LOG(ERROR) << "could not get output tensor for convolution: "
- << ToString(status);
- return false;
- }
-
- output_batch_descriptor->set_count(dims[0])
- .set_feature_map_count(dims[1])
- .set_layout(batch_descriptor.layout());
-
- for (int i = 0; i < batch_descriptor.ndims(); i++) {
- output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
- dims.rbegin()[i]);
- }
+ auto status = [&] {
+ RETURN_IF_CUDNN_ERROR(cudnnGetConvolutionNdForwardOutputDim(
+ conv.handle(), input_nd.handle(), filter.handle(), dn, dims.data()));
+ output_batch_descriptor->set_count(dims[0])
+ .set_feature_map_count(dims[1])
+ .set_layout(batch_descriptor.layout());
- return true;
+ for (int i = 0; i < batch_descriptor.ndims(); i++) {
+ output_batch_descriptor->set_spatial_dim(static_cast<dnn::DimIndex>(i),
+ dims.rbegin()[i]);
+ }
+ return port::Status::OK();
+ }();
+ return IsStatusOk(status, /*report_error=*/true);
}
} // namespace cuda
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index e2de3c62d8..c924d41cb5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -631,7 +631,7 @@ class CudnnSupport : public dnn::DnnSupport {
std::unique_ptr<class CudnnAccess> cudnn_;
template <class T, class U>
- bool DoBatchNormalizationForwardImpl(
+ port::Status DoBatchNormalizationForwardImpl(
Stream* stream, dnn::DataType input_data_type,
dnn::DataType scale_data_type, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& offset,
@@ -646,7 +646,7 @@ class CudnnSupport : public dnn::DnnSupport {
std::function<void()> inv_var_to_var);
template <class T, class U>
- bool DoBatchNormalizationBackwardImpl(
+ port::Status DoBatchNormalizationBackwardImpl(
Stream* stream, int cudnn_input_type, int cudnn_scale_type,
const DeviceMemory<T>& y_backprop, const DeviceMemory<T>& x,
const DeviceMemory<U>& scale, const DeviceMemory<U>& mean,
@@ -656,21 +656,20 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<U>* offset_backprop);
template <class T>
- bool DoConvolveImpl(Stream* stream,
- const dnn::BatchDescriptor& input_descriptor,
- const DeviceMemory<T>& input_data,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<T>& filter_data,
- const dnn::ConvolutionDescriptor& convolution_descriptor,
- const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<T>* output_data,
- ScratchAllocator* scratch_allocator,
- const dnn::AlgorithmConfig& algorithm_config,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoConvolveImpl(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<T>& input_data,
+ const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<T>& filter_data,
+ const dnn::ConvolutionDescriptor& convolution_descriptor,
+ const dnn::BatchDescriptor& output_descriptor,
+ DeviceMemory<T>* output_data, ScratchAllocator* scratch_allocator,
+ const dnn::AlgorithmConfig& algorithm_config,
+ dnn::ProfileResult* output_profile_result);
template <typename Type, typename BiasType, typename ScaleType,
int cudnn_data_type, int cudnn_compute_type>
- bool DoFusedConvolveImpl(
+ port::Status DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
@@ -685,9 +684,8 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardDataImpl(
- Stream* stream,
- const dnn::FilterDescriptor& filter_descriptor,
+ port::Status DoConvolveBackwardDataImpl(
+ Stream* stream, const dnn::FilterDescriptor& filter_descriptor,
const DeviceMemory<T>& filter_data,
const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
@@ -698,10 +696,10 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardFilterImpl(
+ port::Status DoConvolveBackwardFilterImpl(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& output_descriptor_in,
+ const dnn::BatchDescriptor& output_descriptor,
DeviceMemory<T> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
@@ -711,56 +709,56 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoConvolveBackwardBiasImpl(Stream* stream,
- const dnn::BatchDescriptor& input_descriptor,
- const DeviceMemory<T>& input_data,
- const dnn::BatchDescriptor& bias_descriptor,
- DeviceMemory<T>* backward_bias_data);
+ port::Status DoConvolveBackwardBiasImpl(
+ Stream* stream, const dnn::BatchDescriptor& input_descriptor,
+ const DeviceMemory<T>& input_data,
+ const dnn::BatchDescriptor& bias_descriptor,
+ DeviceMemory<T>* backward_bias_data);
template <class T>
- bool DoRnnForwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- const DeviceMemory<T>& input_data,
- const CudnnRnnStateTensorDescriptor& input_h_desc,
- const DeviceMemory<T>& input_h_data,
- const CudnnRnnStateTensorDescriptor& input_c_desc,
- const DeviceMemory<T>& input_c_data,
- const DeviceMemory<T>& params,
- const CudnnRnnSequenceTensorDescriptor& output_desc,
- DeviceMemory<T>* output_data,
- const CudnnRnnStateTensorDescriptor& output_h_desc,
- DeviceMemory<T>* output_h_data,
- const CudnnRnnStateTensorDescriptor& output_c_desc,
- DeviceMemory<T>* output_c_data, bool is_training,
- ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoRnnForwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ DeviceMemory<T>* output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ DeviceMemory<T>* output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ DeviceMemory<T>* output_c_data, bool is_training,
+ ScratchAllocator* reserve_space_allocator,
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
template <class T>
- bool DoRnnBackwardImpl(Stream* stream, const CudnnRnnDescriptor& rnn_desc,
- const CudnnRnnSequenceTensorDescriptor& input_desc,
- const DeviceMemory<T>& input_data,
- const CudnnRnnStateTensorDescriptor& input_h_desc,
- const DeviceMemory<T>& input_h_data,
- const CudnnRnnStateTensorDescriptor& input_c_desc,
- const DeviceMemory<T>& input_c_data,
- const DeviceMemory<T>& params,
- const CudnnRnnSequenceTensorDescriptor& output_desc,
- const DeviceMemory<T>& output_data,
- const CudnnRnnStateTensorDescriptor& output_h_desc,
- const DeviceMemory<T>& output_h_data,
- const CudnnRnnStateTensorDescriptor& output_c_desc,
- const DeviceMemory<T>& output_c_data,
- const DeviceMemory<T>& output_backprop_data,
- const DeviceMemory<T>& output_h_backprop_data,
- const DeviceMemory<T>& output_c_backprop_data,
- DeviceMemory<T>* input_backprop_data,
- DeviceMemory<T>* input_h_backprop_data,
- DeviceMemory<T>* input_c_backprop_data,
- DeviceMemory<T>* params_backprop_data,
- DeviceMemory<uint8>* reserve_space_data,
- ScratchAllocator* workspace_allocator,
- dnn::ProfileResult* output_profile_result);
+ port::Status DoRnnBackwardImpl(
+ Stream* stream, const CudnnRnnDescriptor& rnn_desc,
+ const CudnnRnnSequenceTensorDescriptor& input_desc,
+ const DeviceMemory<T>& input_data,
+ const CudnnRnnStateTensorDescriptor& input_h_desc,
+ const DeviceMemory<T>& input_h_data,
+ const CudnnRnnStateTensorDescriptor& input_c_desc,
+ const DeviceMemory<T>& input_c_data, const DeviceMemory<T>& params,
+ const CudnnRnnSequenceTensorDescriptor& output_desc,
+ const DeviceMemory<T>& output_data,
+ const CudnnRnnStateTensorDescriptor& output_h_desc,
+ const DeviceMemory<T>& output_h_data,
+ const CudnnRnnStateTensorDescriptor& output_c_desc,
+ const DeviceMemory<T>& output_c_data,
+ const DeviceMemory<T>& output_backprop_data,
+ const DeviceMemory<T>& output_h_backprop_data,
+ const DeviceMemory<T>& output_c_backprop_data,
+ DeviceMemory<T>* input_backprop_data,
+ DeviceMemory<T>* input_h_backprop_data,
+ DeviceMemory<T>* input_c_backprop_data,
+ DeviceMemory<T>* params_backprop_data,
+ DeviceMemory<uint8>* reserve_space_data,
+ ScratchAllocator* workspace_allocator,
+ dnn::ProfileResult* output_profile_result);
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
};
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index f2be68bc42..f11022ef1d 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -180,11 +180,11 @@ bool CUDAExecutor::FindOnDiskForComputeCapability(
static string GetBinaryDir(bool strip_exe) {
char exe_path[PATH_MAX] = {0};
#if defined(__APPLE__)
- uint32_t buffer_size = 0U;
- _NSGetExecutablePath(nullptr, &buffer_size);
- char unresolved_path[buffer_size];
- _NSGetExecutablePath(unresolved_path, &buffer_size);
- CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1);
+ uint32_t buffer_size = 0U;
+ _NSGetExecutablePath(nullptr, &buffer_size);
+ char unresolved_path[buffer_size];
+ _NSGetExecutablePath(unresolved_path, &buffer_size);
+ CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1);
#else
#if defined(PLATFORM_WINDOWS)
HMODULE hModule = GetModuleHandle(NULL);
diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h
index 70554ec931..e040cf86fa 100644
--- a/tensorflow/stream_executor/cuda/cuda_timer.h
+++ b/tensorflow/stream_executor/cuda/cuda_timer.h
@@ -37,8 +37,9 @@ class CUDATimer : public internal::TimerInterface {
explicit CUDATimer(CUDAExecutor *parent)
: parent_(parent), start_event_(nullptr), stop_event_(nullptr) {}
- // Note: teardown is explicitly handled in this API by a call to
+ // Note: teardown needs to be explicitly handled in this API by a call to
// StreamExecutor::DeallocateTimer(), which invokes Destroy().
+ // TODO(csigg): Change to RAII.
~CUDATimer() override {}
// Allocates the platform-specific pieces of the timer, called as part of
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 5315d1f3da..82aa8ceb32 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -141,6 +141,10 @@ string PadAlignmentString(PadAlignment alignment) {
return "unknown pad alignment";
}
+std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) {
+ return str << PadAlignmentString(alignment);
+}
+
string ShortPoolingModeString(PoolingMode mode) {
switch (mode) {
case PoolingMode::kMaximum:
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 3df5365c23..9eca5abe1a 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -469,6 +469,9 @@ enum class PadAlignment : int64 {
// Returns a string representation of the given padding alignment.
string PadAlignmentString(PadAlignment alignment);
+// Print alignment to str. Needed to use CHECK_EQ between two PadAlignments.
+std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment);
+
// Describes a convolution.
//
// Uses the named argument construction form:
@@ -710,7 +713,7 @@ class PoolingDescriptor {
class AlgorithmDesc {
public:
typedef int64 Index;
- AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(false) {}
+ AlgorithmDesc() : algo_(kDefaultAlgorithm), tensor_ops_enabled_(true) {}
AlgorithmDesc(Index a, bool use_tensor_ops)
: algo_(a), tensor_ops_enabled_(use_tensor_ops) {}
bool is_default() const { return algo_ == kDefaultAlgorithm; }
diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc
index 50a6edd80b..52efe771bc 100644
--- a/tensorflow/stream_executor/event.cc
+++ b/tensorflow/stream_executor/event.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include "tensorflow/stream_executor/event.h"
+#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
-#include "tensorflow/stream_executor/stream.h"
namespace stream_executor {
@@ -27,9 +27,12 @@ Event::Event(StreamExecutor* stream_exec)
stream_exec_->implementation()->CreateEventImplementation()) {}
Event::~Event() {
- auto status = stream_exec_->DeallocateEvent(this);
- if (!status.ok()) {
- LOG(ERROR) << status.error_message();
+ // Deal with nullptr implementation_, as this event may have been std::moved.
+ if (stream_exec_ && implementation_) {
+ auto status = stream_exec_->DeallocateEvent(this);
+ if (!status.ok()) {
+ LOG(ERROR) << status.error_message();
+ }
}
}
diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h
index 1f37262c78..9cc87a7c12 100644
--- a/tensorflow/stream_executor/event.h
+++ b/tensorflow/stream_executor/event.h
@@ -61,6 +61,9 @@ class Event {
// Returns a pointer to the underlying platform-specific implementation.
internal::EventInterface* implementation() { return implementation_.get(); }
+ Event(Event&&) = default;
+ Event& operator=(Event&&) = default;
+
private:
friend class Stream;
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 2c4819651a..3cd97b3cf1 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -26,8 +26,6 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/plugin_registry.h"
-bool FLAGS_stream_executor_cpu_real_clock_rate = false;
-
namespace stream_executor {
namespace host {
@@ -190,11 +188,8 @@ DeviceDescription *HostExecutor::PopulateDeviceDescription() const {
// doesn't result in thrashing or other badness? 4GiB chosen arbitrarily.
builder.set_device_memory_size(static_cast<uint64>(4) * 1024 * 1024 * 1024);
- float cycle_counter_frequency = 1e9;
- if (FLAGS_stream_executor_cpu_real_clock_rate) {
- cycle_counter_frequency = static_cast<float>(
- tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
- }
+ float cycle_counter_frequency = static_cast<float>(
+ tensorflow::profile_utils::CpuUtils::GetCycleCounterFrequency());
builder.set_clock_rate_ghz(cycle_counter_frequency / 1e9);
auto built = builder.Build();
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 4a98cfe164..9369183133 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -192,6 +192,7 @@ string ToVlogString(dnn::DataType data_type) {
case dnn::DataType::kInt8:
return "dnn::DataType::kInt8";
}
+ return "unknown DataType";
}
// Used together with PARAM to VLOG calls made to the stream. Intended
@@ -5227,24 +5228,11 @@ port::Status Stream::BlockHostUntilDone() {
return status;
}
- port::Status first_error;
- {
- // Wait until all active sub-streams have done their tasks.
- mutex_lock lock(mu_);
- for (auto &stream : sub_streams_) {
- if (!stream.second) {
- first_error.Update(stream.first->BlockHostUntilDone());
- // Set this sub-stream as available.
- stream.second = true;
- }
- }
- }
-
temporary_memory_manager_.DeallocateFinalizedTemporaries();
- first_error.Update(parent_->BlockHostUntilDone(this));
- CheckError(first_error.ok());
- return first_error;
+ port::Status error = parent_->BlockHostUntilDone(this);
+ CheckError(error.ok());
+ return error;
}
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3da1b856d6..e8885e1eb6 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -25,6 +25,7 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dnn.h"
@@ -156,14 +157,13 @@ class Stream {
const TypedKernel<Params...> &kernel, Args... args);
// Record a "start" event for the interval timer at this point in the
- // stream's
- // execution (relative to the previously and subsequently enqueued items in
- // the stream's execution). Streams may be started/stopped multiple times.
+ // stream's execution (relative to the previously and subsequently enqueued
+ // items in the stream's execution). Streams may be started/stopped multiple
+ // times.
Stream &ThenStartTimer(Timer *t);
// Record a "stop" event for the interval timer at this point in the
- // stream's
- // execution. See also Stream::ThenStartTimer.
+ // stream's execution. See also Stream::ThenStartTimer.
Stream &ThenStopTimer(Timer *t);
// TODO(leary) If work is added to the stream that is being depended upon,
@@ -179,8 +179,7 @@ class Stream {
//
// Checks that a stream does not wait for itself, and it is up to the
// user to guarantee that a stream does not come to wait on itself in a
- // cyclic
- // manner; in that case, behavior is undefined.
+ // cyclic manner; in that case, behavior is undefined.
//
// N.B. Base recursion case for the variadic ThenWaitFor.
Stream &ThenWaitFor(Stream *other);
@@ -1351,33 +1350,39 @@ class Stream {
DeviceMemory<std::complex<double>> *x, int incx);
// See BlasSupport::DoBlasGemm.
- Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha,
- const DeviceMemory<Eigen::half> &a, int lda,
- const DeviceMemory<Eigen::half> &b, int ldb, float beta,
- DeviceMemory<Eigen::half> *c, int ldc);
- Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha,
- const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta,
- DeviceMemory<float> *c, int ldc);
- Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, double alpha,
- const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc);
- Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<float> alpha,
- const DeviceMemory<std::complex<float>> &a, int lda,
- const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta,
- DeviceMemory<std::complex<float>> *c, int ldc);
- Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<double> alpha,
- const DeviceMemory<std::complex<double>> &a, int lda,
- const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta,
- DeviceMemory<std::complex<double>> *c, int ldc);
+ TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, float alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ float beta, DeviceMemory<Eigen::half> *c,
+ int ldc);
+ TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, float alpha,
+ const DeviceMemory<float> &a, int lda,
+ const DeviceMemory<float> &b, int ldb,
+ float beta, DeviceMemory<float> *c, int ldc);
+ TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k, double alpha,
+ const DeviceMemory<double> &a, int lda,
+ const DeviceMemory<double> &b, int ldb,
+ double beta, DeviceMemory<double> *c, int ldc);
+ TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k,
+ std::complex<float> alpha,
+ const DeviceMemory<std::complex<float>> &a,
+ int lda,
+ const DeviceMemory<std::complex<float>> &b,
+ int ldb, std::complex<float> beta,
+ DeviceMemory<std::complex<float>> *c, int ldc);
+ TF_EXPORT Stream &ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
+ uint64 m, uint64 n, uint64 k,
+ std::complex<double> alpha,
+ const DeviceMemory<std::complex<double>> &a,
+ int lda,
+ const DeviceMemory<std::complex<double>> &b,
+ int ldb, std::complex<double> beta,
+ DeviceMemory<std::complex<double>> *c,
+ int ldc);
Stream &ThenBlasGemmWithProfiling(blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index b222a4d82a..000795ff00 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -610,7 +610,7 @@ port::Status StreamExecutor::SynchronousMemcpyD2H(
port::Status StreamExecutor::SynchronousMemcpyH2D(
const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
- << ", size=" << size << ", device_dst" << device_dst->opaque() << ")"
+ << ", size=" << size << ", device_dst=" << device_dst->opaque() << ")"
<< StackTraceIfVLOG10();
port::Status result;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index b59f8e1f98..e4241667ad 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -148,6 +148,12 @@ def if_windows(a):
"//conditions:default": [],
})
+def if_not_windows_cuda(a):
+ return select({
+ clean_dep("//tensorflow:with_cuda_support_windows_override"): [],
+ "//conditions:default": a,
+ })
+
def if_linux_x86_64(a):
return select({
clean_dep("//tensorflow:linux_x86_64"): a,
@@ -241,6 +247,9 @@ def tf_opts_nortti_if_android():
# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
+def tf_features_nomodules_if_android():
+ return if_android(["-use_header_modules"])
+
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names, deps=None, is_external=True):
@@ -816,6 +825,9 @@ def tf_cc_test_mkl(srcs,
tags=[],
size="medium",
args=None):
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
for src in srcs:
native.cc_test(
name=src_to_test_name(src),
@@ -841,6 +853,7 @@ def tf_cc_test_mkl(srcs,
tags=tags,
size=size,
args=args,
+ features=disable_header_modules,
nocopts="-fno-exceptions")
@@ -919,6 +932,7 @@ def tf_gpu_kernel_library(srcs,
hdrs=[],
**kwargs):
copts = copts + _cuda_copts() + if_cuda(cuda_copts) + tf_copts()
+ kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
native.cc_library(
srcs=srcs,
@@ -959,6 +973,7 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
if not cuda_deps:
cuda_deps = []
+ kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
native.cc_library(
deps=deps + if_cuda(cuda_deps + [
clean_dep("//tensorflow/core:cuda"),
@@ -973,16 +988,17 @@ register_extension_info(
label_regex_for_dep = "{extension_name}",
)
-def tf_kernel_library(name,
- prefix=None,
- srcs=None,
- gpu_srcs=None,
- hdrs=None,
- deps=None,
- alwayslink=1,
- copts=None,
- is_external=False,
- **kwargs):
+def tf_kernel_library(
+ name,
+ prefix = None,
+ srcs = None,
+ gpu_srcs = None,
+ hdrs = None,
+ deps = None,
+ alwayslink = 1,
+ copts = None,
+ is_external = False,
+ **kwargs):
"""A rule to build a TensorFlow OpKernel.
May either specify srcs/hdrs or prefix. Similar to tf_cuda_library,
@@ -1012,6 +1028,7 @@ def tf_kernel_library(name,
deps = []
if not copts:
copts = []
+ textual_hdrs = []
copts = copts + tf_copts(is_external=is_external)
if prefix:
if native.glob([prefix + "*.cu.cc"], exclude=["*test*"]):
@@ -1022,8 +1039,13 @@ def tf_kernel_library(name,
srcs = srcs + native.glob(
[prefix + "*.cc"], exclude=[prefix + "*test*", prefix + "*.cu.cc"])
hdrs = hdrs + native.glob(
- [prefix + "*.h"], exclude=[prefix + "*test*", prefix + "*.cu.h"])
-
+ [prefix + "*.h"],
+ exclude = [prefix + "*test*", prefix + "*.cu.h", prefix + "*impl.h"],
+ )
+ textual_hdrs = native.glob(
+ [prefix + "*impl.h"],
+ exclude = [prefix + "*test*", prefix + "*.cu.h"],
+ )
cuda_deps = [clean_dep("//tensorflow/core:gpu_lib")]
if gpu_srcs:
for gpu_src in gpu_srcs:
@@ -1037,6 +1059,7 @@ def tf_kernel_library(name,
name=name,
srcs=srcs,
hdrs=hdrs,
+ textual_hdrs = textual_hdrs,
copts=copts,
cuda_deps=cuda_deps,
linkstatic=1, # Needed since alwayslink is broken in bazel b/27630669
@@ -1070,6 +1093,9 @@ def tf_mkl_kernel_library(name,
hdrs = hdrs + native.glob(
[prefix + "*.h"])
+ # -fno-exceptions in nocopts breaks compilation if header modules are enabled.
+ disable_header_modules = ["-use_header_modules"]
+
native.cc_library(
name=name,
srcs=if_mkl(srcs),
@@ -1077,7 +1103,8 @@ def tf_mkl_kernel_library(name,
deps=deps,
alwayslink=alwayslink,
copts=copts,
- nocopts=nocopts
+ nocopts=nocopts,
+ features = disable_header_modules
)
register_extension_info(
@@ -1301,6 +1328,7 @@ def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[], linkopts=[]):
name=basename + "_gpu",
srcs=gpu_srcs,
copts=_cuda_copts() + if_tensorrt(["-DGOOGLE_TENSORRT=1"]),
+ features = if_cuda(["-use_header_modules"]),
deps=deps + if_cuda(cuda_deps))
cuda_deps.extend([":" + basename + "_gpu"])
diff --git a/tensorflow/tf_framework_version_script.lds b/tensorflow/tf_framework_version_script.lds
new file mode 100644
index 0000000000..d4977f88c0
--- /dev/null
+++ b/tensorflow/tf_framework_version_script.lds
@@ -0,0 +1,11 @@
+VERS_1.0 {
+ # Hide libjpeg symbols to avoid symbol conflict with OpenCV
+ local:
+ jpeg_*;
+ jinit_*;
+ jdiv_round_up;
+ jround_up;
+ jzero_far;
+ jcopy_*;
+ jsimd_*;
+};
diff --git a/tensorflow/tools/api/generator/BUILD b/tensorflow/tools/api/generator/BUILD
index 3a28153e52..8c760e6f52 100644
--- a/tensorflow/tools/api/generator/BUILD
+++ b/tensorflow/tools/api/generator/BUILD
@@ -3,34 +3,37 @@
licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
-
+load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
+exports_files(
+ [
+ "LICENSE",
+ "create_python_api.py",
+ ],
+)
+
py_library(
name = "doc_srcs",
srcs = ["doc_srcs.py"],
srcs_version = "PY2AND3",
-)
-
-py_binary(
- name = "create_python_api",
- srcs = ["create_python_api.py"],
- srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":doc_srcs",
- "//tensorflow/python:no_contrib",
+ "//tensorflow/python:util",
],
)
py_test(
name = "create_python_api_test",
- srcs = ["create_python_api_test.py"],
+ srcs = [
+ "create_python_api.py",
+ "create_python_api_test.py",
+ ],
srcs_version = "PY2AND3",
deps = [
- ":create_python_api",
+ ":doc_srcs",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:no_contrib",
],
)
@@ -39,6 +42,7 @@ py_test(
srcs = ["doc_srcs_test.py"],
args = [
"--package=tensorflow.python",
+ "--api_name=tensorflow",
] + TENSORFLOW_API_INIT_FILES,
main = "doc_srcs_test.py",
srcs_version = "PY2AND3",
@@ -48,3 +52,20 @@ py_test(
"//tensorflow/python:no_contrib",
],
)
+
+py_test(
+ name = "estimator_doc_srcs_test",
+ srcs = ["doc_srcs_test.py"],
+ args = [
+ "--package=tensorflow.python.estimator",
+ "--api_name=estimator",
+ ] + ESTIMATOR_API_INIT_FILES,
+ main = "doc_srcs_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":doc_srcs",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:no_contrib",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
diff --git a/tensorflow/tools/api/generator/api_gen.bzl b/tensorflow/tools/api/generator/api_gen.bzl
index fe3e4d1434..ed164bf9e4 100644
--- a/tensorflow/tools/api/generator/api_gen.bzl
+++ b/tensorflow/tools/api/generator/api_gen.bzl
@@ -8,16 +8,16 @@ TENSORFLOW_API_INIT_FILES = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "debugging/__init__.py",
"distributions/__init__.py",
"distributions/bijectors/__init__.py",
+ "dtypes/__init__.py",
"errors/__init__.py",
- "estimator/__init__.py",
- "estimator/export/__init__.py",
- "estimator/inputs/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
"graph_util/__init__.py",
"image/__init__.py",
+ "io/__init__.py",
"initializers/__init__.py",
"keras/__init__.py",
"keras/activations/__init__.py",
@@ -68,6 +68,7 @@ TENSORFLOW_API_INIT_FILES = [
"nn/rnn_cell/__init__.py",
"profiler/__init__.py",
"python_io/__init__.py",
+ "quantization/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
@@ -91,6 +92,16 @@ TENSORFLOW_API_INIT_FILES = [
# END GENERATED FILES
]
+# keep sorted
+ESTIMATOR_API_INIT_FILES = [
+ # BEGIN GENERATED ESTIMATOR FILES
+ "__init__.py",
+ "estimator/__init__.py",
+ "estimator/export/__init__.py",
+ "estimator/inputs/__init__.py",
+ # END GENERATED ESTIMATOR FILES
+]
+
# Creates a genrule that generates a directory structure with __init__.py
# files that import all exported modules (i.e. modules with tf_export
# decorators).
@@ -107,19 +118,47 @@ TENSORFLOW_API_INIT_FILES = [
# template will be replaced with root imports collected by this genrule.
# srcs: genrule sources. If passing root_init_template, the template file
# must be included in sources.
-def gen_api_init_files(name,
- output_files=TENSORFLOW_API_INIT_FILES,
- root_init_template=None,
- srcs=[]):
- root_init_template_flag = ""
- if root_init_template:
- root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
- native.genrule(
- name = name,
- outs = output_files,
- cmd = (
- "$(location //tensorflow/tools/api/generator:create_python_api) " +
- root_init_template_flag + " --apidir=$(@D) $(OUTS)"),
- srcs = srcs,
- tools = ["//tensorflow/tools/api/generator:create_python_api"],
- )
+# api_name: Name of the project that you want to generate API files for
+# (e.g. "tensorflow" or "estimator").
+# package: Python package containing the @tf_export decorators you want to
+# process
+# package_dep: Python library target containing your package.
+
+def gen_api_init_files(
+ name,
+ output_files = TENSORFLOW_API_INIT_FILES,
+ root_init_template = None,
+ srcs = [],
+ api_name = "tensorflow",
+ package = "tensorflow.python",
+ package_dep = "//tensorflow/python:no_contrib",
+ output_package = "tensorflow"):
+ root_init_template_flag = ""
+ if root_init_template:
+ root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
+
+ api_gen_binary_target = "create_" + package + "_api"
+ native.py_binary(
+ name = "create_" + package + "_api",
+ srcs = ["//tensorflow/tools/api/generator:create_python_api.py"],
+ main = "//tensorflow/tools/api/generator:create_python_api.py",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ package_dep,
+ "//tensorflow/tools/api/generator:doc_srcs",
+ ],
+ )
+
+ native.genrule(
+ name = name,
+ outs = output_files,
+ cmd = (
+ "$(location :" + api_gen_binary_target + ") " +
+ root_init_template_flag + " --apidir=$(@D) --apiname=" +
+ api_name + " --package=" + package + " --output_package=" +
+ output_package + " $(OUTS)"),
+ srcs = srcs,
+ tools = [":" + api_gen_binary_target ],
+ visibility = ["//tensorflow:__pkg__"],
+ )
diff --git a/tensorflow/tools/api/generator/create_python_api.py b/tensorflow/tools/api/generator/create_python_api.py
index e3ab056efc..7f17360c91 100644
--- a/tensorflow/tools/api/generator/create_python_api.py
+++ b/tensorflow/tools/api/generator/create_python_api.py
@@ -25,11 +25,11 @@ import os
import sys
from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_export
from tensorflow.tools.api.generator import doc_srcs
+API_ATTRS = tf_export.API_ATTRS
-_API_CONSTANTS_ATTR = '_tf_api_constants'
-_API_NAMES_ATTR = '_tf_api_names'
_DEFAULT_PACKAGE = 'tensorflow.python'
_GENFILES_DIR_SUFFIX = 'genfiles/'
_SYMBOLS_TO_SKIP_EXPLICITLY = {
@@ -45,7 +45,7 @@ _GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
from __future__ import print_function
"""
-_GENERATED_FILE_FOOTER = "\n\ndel print_function\n"
+_GENERATED_FILE_FOOTER = '\n\ndel print_function\n'
class SymbolExposedTwiceError(Exception):
@@ -159,12 +159,13 @@ __all__.remove('print_function')
return module_text_map
-def get_api_init_text(package):
+def get_api_init_text(package, output_package, api_name):
"""Get a map from destination module to __init__.py code for that module.
Args:
package: Base python package containing python with target tf_export
decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Returns:
A dictionary where
@@ -179,7 +180,7 @@ def get_api_init_text(package):
for module in list(sys.modules.values()):
# Only look at tensorflow modules.
if (not module or not hasattr(module, '__name__') or
- package not in module.__name__):
+ module.__name__ is None or package not in module.__name__):
continue
# Do not generate __init__.py files for contrib modules for now.
if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
@@ -192,7 +193,7 @@ def get_api_init_text(package):
attr = getattr(module, module_contents_name)
# If attr is _tf_api_constants attribute, then add the constants.
- if module_contents_name == _API_CONSTANTS_ATTR:
+ if module_contents_name == API_ATTRS[api_name].constants:
for exports, value in attr:
for export in exports:
names = export.split('.')
@@ -201,15 +202,12 @@ def get_api_init_text(package):
-1, dest_module, module.__name__, value, names[-1])
continue
- try:
- _, attr = tf_decorator.unwrap(attr)
- except Exception as e:
- print('5555: %s %s' % (module, module_contents_name), file=sys.stderr)
- raise e
+ _, attr = tf_decorator.unwrap(attr)
# If attr is a symbol with _tf_api_names attribute, then
# add import for it.
- if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
- for export in attr._tf_api_names: # pylint: disable=protected-access
+ if (hasattr(attr, '__dict__') and
+ API_ATTRS[api_name].names in attr.__dict__):
+ for export in getattr(attr, API_ATTRS[api_name].names): # pylint: disable=protected-access
names = export.split('.')
dest_module = '.'.join(names[:-1])
module_code_builder.add_import(
@@ -220,7 +218,6 @@ def get_api_init_text(package):
# For e.g. if we import 'foo.bar.Value'. Then, we also
# import 'bar' in 'foo'.
imported_modules = set(module_code_builder.module_imports.keys())
- import_from = '.'
for module in imported_modules:
if not module:
continue
@@ -231,6 +228,9 @@ def get_api_init_text(package):
if submodule_index > 0:
parent_module += ('.' + module_split[submodule_index-1] if parent_module
else module_split[submodule_index-1])
+ import_from = output_package
+ if submodule_index > 0:
+ import_from += '.' + '.'.join(module_split[:submodule_index])
module_code_builder.add_import(
-1, parent_module, import_from,
module_split[submodule_index], module_split[submodule_index])
@@ -246,7 +246,7 @@ def get_module(dir_path, relative_to_dir):
relative_to_dir: Get module relative to this directory.
Returns:
- module that corresponds to the given directory.
+ Name of module that corresponds to the given directory.
"""
dir_path = dir_path[len(relative_to_dir):]
# Convert path separators to '/' for easier parsing below.
@@ -254,7 +254,7 @@ def get_module(dir_path, relative_to_dir):
return dir_path.replace('/', '.').strip('.')
-def get_module_docstring(module_name, package):
+def get_module_docstring(module_name, package, api_name):
"""Get docstring for the given module.
This method looks for docstring in the following order:
@@ -270,6 +270,7 @@ def get_module_docstring(module_name, package):
(excluding 'tensorflow.' prefix) to get a docstring for.
package: Base python package containing python with target tf_export
decorators.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Returns:
One-line docstring to describe the module.
@@ -277,8 +278,10 @@ def get_module_docstring(module_name, package):
# Module under base package to get a docstring from.
docstring_module_name = module_name
- if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
- docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
+ doc_sources = doc_srcs.get_doc_sources(api_name)
+
+ if module_name in doc_sources:
+ docsrc = doc_sources[module_name]
if docsrc.docstring:
return docsrc.docstring
if docsrc.docstring_module_name:
@@ -293,7 +296,8 @@ def get_module_docstring(module_name, package):
def create_api_files(
- output_files, package, root_init_template, output_dir):
+ output_files, package, root_init_template, output_dir, output_package,
+ api_name):
"""Creates __init__.py files for the Python API.
Args:
@@ -305,6 +309,7 @@ def create_api_files(
"#API IMPORTS PLACEHOLDER" comment in the template file will be replaced
with imports.
output_dir: output API root directory.
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
Raises:
ValueError: if an output file is not under api/ directory,
@@ -321,7 +326,7 @@ def create_api_files(
os.makedirs(os.path.dirname(file_path))
open(file_path, 'a').close()
- module_text_map = get_api_init_text(package)
+ module_text_map = get_api_init_text(package, output_package, api_name)
# Add imports to output files.
missing_output_files = []
@@ -336,8 +341,8 @@ def create_api_files(
if module or not root_init_template:
contents = (
_GENERATED_FILE_HEADER %
- get_module_docstring(module, package) + text +
- _GENERATED_FILE_FOOTER)
+ get_module_docstring(module, package, api_name) +
+ text + _GENERATED_FILE_FOOTER)
else:
# Read base init file
with open(root_init_template, 'r') as root_init_template_file:
@@ -375,6 +380,13 @@ def main():
help='Directory where generated output files are placed. '
'gendir should be a prefix of apidir. Also, apidir '
'should be a prefix of every directory in outputs.')
+ parser.add_argument(
+ '--apiname', required=True, type=str,
+ choices=API_ATTRS.keys(),
+ help='The API you want to generate.')
+ parser.add_argument(
+ '--output_package', default='tensorflow', type=str,
+ help='Root output package.')
args = parser.parse_args()
@@ -388,8 +400,8 @@ def main():
# Populate `sys.modules` with modules containing tf_export().
importlib.import_module(args.package)
- create_api_files(
- outputs, args.package, args.root_init_template, args.apidir)
+ create_api_files(outputs, args.package, args.root_init_template,
+ args.apidir, args.output_package, args.apiname)
if __name__ == '__main__':
diff --git a/tensorflow/tools/api/generator/create_python_api_test.py b/tensorflow/tools/api/generator/create_python_api_test.py
index 986340cf6d..1a7187463a 100644
--- a/tensorflow/tools/api/generator/create_python_api_test.py
+++ b/tensorflow/tools/api/generator/create_python_api_test.py
@@ -57,7 +57,9 @@ class CreatePythonApiTest(test.TestCase):
def testFunctionImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow')
expected_import = (
'from tensorflow.python.test_module '
'import test_op as test_op1')
@@ -73,7 +75,9 @@ class CreatePythonApiTest(test.TestCase):
def testClassImportIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow')
expected_import = ('from tensorflow.python.test_module '
'import TestClass')
self.assertTrue(
@@ -82,7 +86,9 @@ class CreatePythonApiTest(test.TestCase):
def testConstantIsAdded(self):
imports = create_python_api.get_api_init_text(
- package=create_python_api._DEFAULT_PACKAGE)
+ package=create_python_api._DEFAULT_PACKAGE,
+ output_package='tensorflow',
+ api_name='tensorflow')
expected = ('from tensorflow.python.test_module '
'import _TEST_CONSTANT')
self.assertTrue(expected in str(imports),
diff --git a/tensorflow/tools/api/generator/doc_srcs.py b/tensorflow/tools/api/generator/doc_srcs.py
index 74f6db98fd..ad1988494d 100644
--- a/tensorflow/tools/api/generator/doc_srcs.py
+++ b/tensorflow/tools/api/generator/doc_srcs.py
@@ -19,6 +19,8 @@ from __future__ import print_function
import collections
+from tensorflow.python.util import tf_export
+
# Specifies docstring source for a module.
# Only one of docstring or docstring_module_name should be set.
@@ -31,7 +33,7 @@ DocSource = collections.namedtuple(
# Each attribute of DocSource is optional.
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
-TENSORFLOW_DOC_SOURCES = {
+_TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'),
'compat': DocSource(docstring_module_name='util.compat'),
'distributions': DocSource(
@@ -41,7 +43,7 @@ TENSORFLOW_DOC_SOURCES = {
'gfile': DocSource(docstring_module_name='platform.gfile'),
'graph_util': DocSource(docstring_module_name='framework.graph_util'),
'image': DocSource(docstring_module_name='ops.image_ops'),
- 'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
+ 'keras.estimator': DocSource(docstring_module_name='keras.estimator'),
'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
'logging': DocSource(docstring_module_name='ops.logging_ops'),
'losses': DocSource(docstring_module_name='ops.losses.losses'),
@@ -63,3 +65,28 @@ TENSORFLOW_DOC_SOURCES = {
'train.queue_runner': DocSource(
docstring_module_name='training.queue_runner'),
}
+
+_ESTIMATOR_DOC_SOURCES = {
+ 'estimator': DocSource(
+ docstring_module_name='estimator_lib'),
+ 'estimator.export': DocSource(
+ docstring_module_name='export.export_lib'),
+ 'estimator.inputs': DocSource(
+ docstring_module_name='inputs.inputs'),
+}
+
+
+def get_doc_sources(api_name):
+ """Get a map from module to a DocSource object.
+
+ Args:
+ api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
+
+ Returns:
+ Map from module name to DocSource object.
+ """
+ if api_name == tf_export.TENSORFLOW_API_NAME:
+ return _TENSORFLOW_DOC_SOURCES
+ if api_name == tf_export.ESTIMATOR_API_NAME:
+ return _ESTIMATOR_DOC_SOURCES
+ return {}
diff --git a/tensorflow/tools/api/generator/doc_srcs_test.py b/tensorflow/tools/api/generator/doc_srcs_test.py
index 9ba95a3439..dbff904abe 100644
--- a/tensorflow/tools/api/generator/doc_srcs_test.py
+++ b/tensorflow/tools/api/generator/doc_srcs_test.py
@@ -32,34 +32,34 @@ FLAGS = None
class DocSrcsTest(test.TestCase):
def testModulesAreValidAPIModules(self):
- for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
+ for module_name in doc_srcs.get_doc_sources(FLAGS.api_name):
# Convert module_name to corresponding __init__.py file path.
file_path = module_name.replace('.', '/')
if file_path:
file_path += '/'
file_path += '__init__.py'
- if file_path not in FLAGS.outputs:
- self.assertFalse('%s is not a valid API module' % module_name)
+ self.assertIn(
+ file_path, FLAGS.outputs,
+ msg='%s is not a valid API module' % module_name)
def testHaveDocstringOrDocstringModule(self):
- for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
- if docsrc.docstring and docsrc.docstring_module_name:
- self.assertFalse(
- '%s contains DocSource has both a docstring and a '
- 'docstring_module_name. '
- 'Only one of "docstring" or "docstring_module_name" should be set.'
- % (module_name))
+ for module_name, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
+ self.assertFalse(
+ docsrc.docstring and docsrc.docstring_module_name,
+ msg=('%s contains DocSource has both a docstring and a '
+ 'docstring_module_name. Only one of "docstring" or '
+ '"docstring_module_name" should be set.') % (module_name))
def testDocstringModulesAreValidModules(self):
- for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
+ for _, docsrc in doc_srcs.get_doc_sources(FLAGS.api_name).items():
if docsrc.docstring_module_name:
doc_module_name = '.'.join([
FLAGS.package, docsrc.docstring_module_name])
- if doc_module_name not in sys.modules:
- sys.assertFalse(
- 'docsources_module %s is not a valid module under %s.' %
- (docsrc.docstring_module_name, FLAGS.package))
+ self.assertIn(
+ doc_module_name, sys.modules,
+ msg=('docsources_module %s is not a valid module under %s.' %
+ (docsrc.docstring_module_name, FLAGS.package)))
if __name__ == '__main__':
@@ -71,6 +71,9 @@ if __name__ == '__main__':
'--package', type=str,
help='Base package that imports modules containing the target tf_export '
'decorators.')
+ parser.add_argument(
+ '--api_name', type=str,
+ help='API name: tensorflow or estimator')
FLAGS, unparsed = parser.parse_known_args()
importlib.import_module(FLAGS.package)
diff --git a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
index f819b174c0..353e63127d 100644
--- a/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-g-p-u-options.pbtxt
@@ -72,6 +72,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "num_dev_to_dev_copy_streams"
+ number: 3
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
nested_type {
name: "VirtualDevices"
field {
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
new file mode 100644
index 0000000000..36b534af36
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-aggregation.pbtxt
@@ -0,0 +1,16 @@
+path: "tensorflow.VariableAggregation"
+tf_class {
+ is_instance: "<enum \'VariableAggregation\'>"
+ member {
+ name: "MEAN"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+ member {
+ name: "SUM"
+ mtype: "<enum \'VariableAggregation\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
index 8e539069da..c13eb7b8bb 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-scope.pbtxt
@@ -56,7 +56,7 @@ tf_class {
}
member_method {
name: "get_variable"
- argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'var_store\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'reuse\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "global_variables"
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
new file mode 100644
index 0000000000..7589bb2888
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.-variable-synchronization.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.VariableSynchronization"
+tf_class {
+ is_instance: "<enum \'VariableSynchronization\'>"
+ member {
+ name: "AUTO"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_READ"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+ member {
+ name: "ON_WRITE"
+ mtype: "<enum \'VariableSynchronization\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
index 8e7e945ed1..834f0954d5 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-dataset.pbtxt
@@ -24,7 +24,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -80,7 +80,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 5cfb2fd2f0..4d854a4cee 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
index 3327e5b274..601f095a60 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
index 9d59375282..587829a4c0 100644
--- a/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.data.-text-line-dataset.pbtxt
@@ -25,7 +25,7 @@ tf_class {
}
member_method {
name: "batch"
- argspec: "args=[\'self\', \'batch_size\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "cache"
@@ -81,7 +81,7 @@ tf_class {
}
member_method {
name: "padded_batch"
- argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
member_method {
name: "prefetch"
diff --git a/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt
new file mode 100644
index 0000000000..d9efe97821
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.debugging.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.debugging"
+tf_module {
+ member_method {
+ name: "check_numerics"
+ argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_finite"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_inf"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_nan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt
new file mode 100644
index 0000000000..98e1feed00
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.dtypes.pbtxt
@@ -0,0 +1,7 @@
+path: "tensorflow.dtypes"
+tf_module {
+ member_method {
+ name: "as_string"
+ argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 099838fa65..9dbb5d16a4 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 87bd19a23a..34a30c2874 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 111914f643..0c6b7e4a82 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 67e4ee02d0..9c1c072124 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'2\', \'None\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index e1289b975e..7391d4b07a 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'model_dir\', \'linear_feature_columns\', \'linear_optimizer\', \'dnn_feature_columns\', \'dnn_optimizer\', \'dnn_hidden_units\', \'dnn_activation_fn\', \'dnn_dropout\', \'label_dimension\', \'weight_column\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\', \'linear_sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Ftrl\', \'None\', \'Adagrad\', \'None\', \'<function relu instance>\', \'None\', \'1\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\', \'sum\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
index d030b2f51f..f50e375f7c 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'hidden_units\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'activation_fn\', \'dropout\', \'input_layer_partitioner\', \'config\', \'warm_start_from\', \'loss_reduction\', \'batch_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Adagrad\', \'<function relu instance>\', \'None\', \'None\', \'None\', \'None\', \'weighted_sum\', \'False\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
index cb578759ee..154f171e89 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
index fcd01bb663..4d46d1e6b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\', \'partitioner\', \'warm_start_from\', \'loss_reduction\', \'sparse_combiner\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\', \'None\', \'None\', \'weighted_sum\', \'sum\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
index 87543e374b..6ec3aba775 100644
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
@@ -21,6 +21,10 @@ tf_module {
argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "adjust_jpeg_quality"
+ argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "adjust_saturation"
argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -54,7 +58,7 @@ tf_module {
}
member_method {
name: "decode_image"
- argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
}
member_method {
name: "decode_jpeg"
@@ -81,6 +85,10 @@ tf_module {
argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
}
member_method {
+ name: "extract_image_patches"
+ argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "extract_jpeg_shape"
argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
}
@@ -113,6 +121,10 @@ tf_module {
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
}
member_method {
+ name: "non_max_suppression_overlaps"
+ argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
+ }
+ member_method {
name: "pad_to_bounding_box"
argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
@@ -145,6 +157,10 @@ tf_module {
argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "random_jpeg_quality"
+ argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "random_saturation"
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -165,8 +181,12 @@ tf_module {
argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "resize_image_with_pad"
+ argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
name: "resize_images"
- argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\'], varargs=None, keywords=None, defaults=[\'0\', \'False\'], "
+ argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
}
member_method {
name: "resize_nearest_neighbor"
diff --git a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt
index a6b6e5eceb..86340913e2 100644
--- a/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.initializers.variance_scaling.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/tensorflow.io.pbtxt
new file mode 100644
index 0000000000..3a36c168aa
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.io.pbtxt
@@ -0,0 +1,39 @@
+path: "tensorflow.io"
+tf_module {
+ member_method {
+ name: "decode_base64"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_compressed"
+ argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+ }
+ member_method {
+ name: "decode_json_example"
+ argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_raw"
+ argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
+ name: "encode_base64"
+ argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "matching_files"
+ argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "parse_tensor"
+ argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "read_file"
+ argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write_file"
+ argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index 11cdd6f0b5..40e82b18b6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index 4afad3e4df..8295905975 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt
index 7b0ad85eaa..f71292856c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-early-stopping.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\'], "
+ argspec: "args=[\'self\', \'monitor\', \'min_delta\', \'patience\', \'verbose\', \'mode\', \'baseline\'], varargs=None, keywords=None, defaults=[\'val_loss\', \'0\', \'0\', \'0\', \'auto\', \'None\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt
index 32a6f6ee88..03f4064b9e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.-variance-scaling.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
index 14a667870d..8645e54302 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.initializers.pbtxt
@@ -90,11 +90,11 @@ tf_module {
}
member_method {
name: "glorot_normal"
- argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "glorot_uniform"
- argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "he_normal"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
index 2bf973debb..86e328888e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
index 03f20e72c2..b0ed545781 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
index 4b46b8d15a..42f98ed03d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
index d8a1c76fd0..000898a4be 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index 622926bc4b..380b49f99c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 82100d8e09..82db5e6137 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index 408061077c..b6ff688ec3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
index a3c8031104..b41290f8b0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index e2dfaca29f..88a033e61f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index 4f068d2066..c1b9b96044 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index b8c261a743..f59f7727a3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
index 4ccd6cace6..7d3744ed92 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
index 2790e5fd85..3fd4ccdab2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -107,7 +107,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
index b1326bd0e6..ba21b50be4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index e3ac3dbf28..46f9fa2bbb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -188,7 +188,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
index 1117a695a3..c3ad326589 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index b9de142142..fd9eb43066 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
index deb535e06e..40d61688f2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index 9a9a223fba..b8c227d725 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
index 1c59b0bdf6..095d35e574 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
index 30cf5489f4..8f99961198 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 0ec69508d5..96d522a016 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
index 4cd8928403..de2824dab4 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 4b4912496d..1d563241d8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
index d0ad9cf567..c87e52c537 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
index 98cff95a7f..dccf5523e3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
index 2357498b46..7ac4116d92 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
index 3324cbff30..024f72705d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
index 6c81823654..4e0233331b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
index 487e04fd07..32d46ce8f3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
index 137e7cced4..858486c725 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index 7161665d25..f65d750926 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
index 24affa2481..2e71ef503d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
index 7ba19a4269..42533bcd21 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
index 503aa9162c..b5df169417 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
index 1737e590a2..0ea17919a9 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
index 021d024dc2..a33248bc00 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index 65387008bf..4ba21a25cd 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 4f791acf05..a7a570418e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index abc30e54e0..763bc23113 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 20791bb448..3c50a3d7f2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 449a91d873..ac78bdafad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index bb361e1297..275282d9d2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index e564bf3216..0e31e6058b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 4cb9cc3ec8..aacd0b1791 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index 5ed52b88ae..c236548663 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index f4559d29d7..6b9c0290aa 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 64e2d061e2..0d7b2211e6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index 3372ad6453..d080ad6aed 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index 08a6860bcd..fcb0a109da 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 22c9eab64f..1d0e22abd0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index 74c405ba9b..653c9f547b 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index 39f6f98193..cdbaf82cf6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
index 7b25e80b6b..230c5e9034 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 3619b8bfc4..511456e740 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 8ef3d71dd8..4a3492ebd6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -171,7 +171,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
index ecbaa9ce2c..5d05cf689f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
index 9b90db1e5e..7efa29be77 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt
@@ -97,7 +97,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index 3c60eaab7f..0ca8e0b52c 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 3dac1ff342..f754fa1da8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index 7f1b5db4d3..c9516b8f07 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
index b3e31000f3..850ecff974 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
index bbd9d1b0dc..7c69e31f9a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
index fe72beea80..fba42642d7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
index e9bf57b2b0..9c277411ea 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 0eecc58a2b..7c2f6ccc8a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index 96785a7d85..802178dba6 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index 42c46cccb3..e870dfe9ad 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
index ac816f68d4..c1337ce0cb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt
new file mode 100644
index 0000000000..ed27a62765
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-minimum.pbtxt
@@ -0,0 +1,176 @@
+path: "tensorflow.keras.layers.Minimum"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.layers.merge.Minimum\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.merge._Merge\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
index 9ae99563e9..b9f05cb3e5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 815f3bc2d1..336d9f76fb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
index e704992b4a..46282217e0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
index b3a58fa11e..42cd7e87ee 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt
new file mode 100644
index 0000000000..c00fa79adf
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -0,0 +1,175 @@
+path: "tensorflow.keras.layers.ReLU"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.layers.advanced_activations.ReLU\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'max_value\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
index 78f464583b..9f094a877a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
index 222344fd04..2f519a2438 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 55fddf576c..6b93116ba0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index 96314ce498..fd17115e27 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index 88bdf99566..4b37a94478 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 6eeea7a8d1..5bdadca74a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -100,7 +100,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 3050d46249..9dfda96fc8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index dda4c9358b..7b7684ccd2 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
index cc6275158b..3b15407fca 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index 5eb7e75047..6d04415267 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 500cb8c14e..04950654d5 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index 1113a7634f..c424e6dcc8 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -99,7 +99,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index c4b9f93561..1160d2840f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt
new file mode 100644
index 0000000000..740a03367b
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-subtract.pbtxt
@@ -0,0 +1,176 @@
+path: "tensorflow.keras.layers.Subtract"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.layers.merge.Subtract\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.merge._Merge\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "activity_regularizer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "inbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "input_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "losses"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "non_trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "outbound_nodes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_mask"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "trainable_weights"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "updates"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "variables"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "weights"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_loss"
+ argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_update"
+ argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add_variable"
+ argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "add_weight"
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+ }
+ member_method {
+ name: "build"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute_mask"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "compute_output_shape"
+ argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "count_params"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_input_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_losses_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_mask_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_output_shape_at"
+ argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_updates_for"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_weights"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_weights"
+ argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index 282c98d79a..a08c583adb 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
index acab93706b..c1294fed0f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -103,7 +103,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index a5ec228a07..dc401d3ed0 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index d8d8e0bfe9..4b5165ae97 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 97d6dc06fb..789af15fea 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
index ea9bb41b99..0536a7cee7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt
@@ -102,7 +102,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index e6d1d2e089..8915353ec3 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index f62017305f..6efb5ef15a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 07a1fde5bd..4c33c5d0bf 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -98,7 +98,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index 709eb5be55..9d7e5bb8c7 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -281,6 +281,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Minimum"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Multiply"
mtype: "<type \'type\'>"
}
@@ -297,6 +301,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "ReLU"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "RepeatVector"
mtype: "<type \'type\'>"
}
@@ -349,6 +357,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Subtract"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "ThresholdedReLU"
mtype: "<type \'type\'>"
}
@@ -409,7 +421,15 @@ tf_module {
argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
+ name: "minimum"
+ argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "multiply"
argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
+ member_method {
+ name: "subtract"
+ argspec: "args=[\'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index 62aa929d32..85f7c2bfed 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -119,7 +119,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 93ecbbce9b..5211657414 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -124,7 +124,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
index 11067058d5..c82e67526b 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
index 3259e706d7..1d031cb5f8 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
index e561f2f415..a8dda6655d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-average-pooling3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
index 3124a35c78..97f65ed894 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-batch-normalization.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
index b5ec61255a..ccd9578f0d 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
index b2c89ae66f..9cbb58d721 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d-transpose.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
index 9e4f4969dc..c75ea3911e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
index 9850e6d765..5dc834e514 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d-transpose.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
index be113826cc..96ab209874 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-conv3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
index 0d951bf633..7e9656b352 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-dense.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
index f1beeed9ef..e9a2269a6e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-dropout.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
index b75a012811..7d2eaaab2a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-flatten.pbtxt
@@ -108,7 +108,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
index 80e0fb228b..8bc3eb26e9 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-layer.pbtxt
@@ -106,7 +106,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
index 50ff484d73..6a0dcce56a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling1-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
index cea809744c..b6c84edf2a 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling2-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
index ab9e89554c..062a02fa59 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-max-pooling3-d.pbtxt
@@ -109,7 +109,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
index 4362568445..eaad0fb23e 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv1-d.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
index 3cad824cd3..ece28a8ce9 100644
--- a/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.layers.-separable-conv2-d.pbtxt
@@ -110,7 +110,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 00b9238543..3b5845f99a 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -69,6 +69,10 @@ tf_module {
argspec: "args=[\'chol\', \'rhs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "cross"
+ argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "det"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -141,6 +145,14 @@ tf_module {
argspec: "args=[\'tensor\', \'full_matrices\', \'compute_uv\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
}
member_method {
+ name: "tensor_diag"
+ argspec: "args=[\'diagonal\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tensor_diag_part"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "tensordot"
argspec: "args=[\'a\', \'b\', \'axes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
index 0b84165285..9add462396 100644
--- a/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.manip.pbtxt
@@ -1,7 +1,35 @@
path: "tensorflow.manip"
tf_module {
member_method {
+ name: "batch_to_space_nd"
+ argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "gather_nd"
+ argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reshape"
+ argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reverse"
+ argspec: "args=[\'tensor\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "roll"
argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
}
+ member_method {
+ name: "scatter_nd"
+ argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "space_to_batch_nd"
+ argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tile"
+ argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
index 897718c05e..a308c76ebc 100644
--- a/tensorflow/tools/api/golden/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.math.pbtxt
@@ -1,7 +1,239 @@
path: "tensorflow.math"
tf_module {
member_method {
+ name: "acos"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "acosh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "add"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "asin"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "asinh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atan2"
+ argspec: "args=[\'y\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "atanh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "bessel_i0"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "bessel_i0e"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "bessel_i1"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "bessel_i1e"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "betainc"
+ argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "ceil"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "cos"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "cosh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "digamma"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "erfc"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "exp"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "expm1"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "floor"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "greater"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "greater_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "igamma"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "igammac"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "invert_permutation"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "less"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "less_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "lgamma"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log1p"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_and"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_not"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "logical_or"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "maximum"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "minimum"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "not_equal"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "polygamma"
+ argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "polyval"
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "reciprocal"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rint"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rsqrt"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_max"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_mean"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_min"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_prod"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "segment_sum"
+ argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sin"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sinh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softplus"
+ argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softsign"
+ argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "squared_difference"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "tan"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_max"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_min"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_prod"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "unsorted_segment_sum"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "zeta"
+ argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 455590d866..d9e5b0d0fc 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -261,6 +261,10 @@ tf_module {
argspec: "args=[\'x\', \'weights\', \'biases\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "safe_embedding_lookup_sparse"
+ argspec: "args=[\'embedding_weights\', \'sparse_ids\', \'sparse_weights\', \'combiner\', \'default_id\', \'name\', \'partition_strategy\', \'max_norm\'], varargs=None, keywords=None, defaults=[\'None\', \'mean\', \'None\', \'None\', \'div\', \'None\'], "
+ }
+ member_method {
name: "sampled_softmax_loss"
argspec: "args=[\'weights\', \'biases\', \'labels\', \'inputs\', \'num_sampled\', \'num_classes\', \'num_true\', \'sampled_values\', \'remove_accidental_hits\', \'partition_strategy\', \'name\', \'seed\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'True\', \'mod\', \'sampled_softmax_loss\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
index a8d9e120cb..c74773000a 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-l-s-t-m-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
index c039890e1f..d251f54806 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
index 62c393de34..8a63b49180 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index f121ba7939..db1aae2757 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -120,7 +120,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
index 4583dc32b2..d76eab7eb8 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-g-r-u-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
index 5016b6ac30..944db6ac93 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-l-s-t-m-cell.pbtxt
@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
index 59623fc983..72b40cc9f7 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
index e2ab5aaee9..a5c2b4aefd 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt
@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index bd2a6d61f8..61d5f04b22 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -116,7 +116,7 @@ tf_class {
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 3051c4437e..4f90743fec 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -261,10 +261,18 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "VariableAggregation"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "VariableScope"
mtype: "<type \'type\'>"
}
member {
+ name: "VariableSynchronization"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "WholeFileReader"
mtype: "<type \'type\'>"
}
@@ -309,6 +317,10 @@ tf_module {
mtype: "<type \'module\'>"
}
member {
+ name: "debugging"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "distributions"
mtype: "<type \'module\'>"
}
@@ -317,6 +329,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "dtypes"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "errors"
mtype: "<type \'module\'>"
}
@@ -381,6 +397,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "io"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "keras"
mtype: "<type \'module\'>"
}
@@ -457,6 +477,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "quantization"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "quint16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
@@ -793,6 +817,10 @@ tf_module {
argspec: "args=[\'shape_x\', \'shape_y\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "broadcast_to"
+ argspec: "args=[\'input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "case"
argspec: "args=[\'pred_fn_pairs\', \'default\', \'exclusive\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'False\', \'case\'], "
}
@@ -1130,7 +1158,7 @@ tf_module {
}
member_method {
name: "get_local_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'synchronization\', \'aggregation\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\', \'None\'], "
}
member_method {
name: "get_seed"
@@ -1146,7 +1174,7 @@ tf_module {
}
member_method {
name: "get_variable"
- argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "get_variable_scope"
@@ -1290,7 +1318,7 @@ tf_module {
}
member_method {
name: "lbeta"
- argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'lbeta\'], "
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "less"
@@ -1533,10 +1561,6 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -2170,7 +2194,7 @@ tf_module {
}
member_method {
name: "while_loop"
- argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\'], "
+ argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
}
member_method {
name: "write_file"
diff --git a/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt
new file mode 100644
index 0000000000..6d865efed0
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.quantization.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.quantization"
+tf_module {
+ member_method {
+ name: "dequantize"
+ argspec: "args=[\'input\', \'min_range\', \'max_range\', \'mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_args"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_args_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'-6\', \'6\', \'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_per_channel"
+ argspec: "args=[\'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "fake_quant_with_min_max_vars_per_channel_gradient"
+ argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "quantized_concat"
+ argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
index ca8e5884b1..83bd703540 100644
--- a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
@@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "add_meta_graph"
- argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "add_meta_graph_and_variables"
- argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "save"
diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt
index 896e2160c6..511e6b4712 100644
--- a/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt
@@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
tf_module {
member_method {
name: "load"
- argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
+ argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
}
member_method {
name: "maybe_saved_model_directory"
diff --git a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
index 4f306540cc..6a421ef12d 100644
--- a/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.spectral.pbtxt
@@ -17,6 +17,10 @@ tf_module {
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "idct"
+ argspec: "args=[\'input\', \'type\', \'n\', \'axis\', \'norm\', \'name\'], varargs=None, keywords=None, defaults=[\'2\', \'None\', \'-1\', \'None\', \'None\'], "
+ }
+ member_method {
name: "ifft"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
index a3fbe95bba..9a831fed26 100644
--- a/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.strings.pbtxt
@@ -1,7 +1,43 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "join"
+ argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "regex_replace"
+ argspec: "args=[\'input\', \'pattern\', \'rewrite\', \'replace_global\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ }
+ member_method {
+ name: "split"
+ argspec: "args=[\'source\', \'sep\', \'maxsplit\'], varargs=None, keywords=None, defaults=[\'None\', \'-1\'], "
+ }
+ member_method {
+ name: "strip"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "substr"
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket"
+ argspec: "args=[\'string_tensor\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket_fast"
+ argspec: "args=[\'input\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_hash_bucket_strong"
+ argspec: "args=[\'input\', \'num_buckets\', \'key\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "to_number"
+ argspec: "args=[\'string_tensor\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt
index ddc553d7c9..2d067e4eff 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-checkpoint.pbtxt
@@ -1,7 +1,7 @@
path: "tensorflow.train.Checkpoint"
tf_class {
is_instance: "<class \'tensorflow.python.training.checkpointable.util.Checkpoint\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.Checkpointable\'>"
+ is_instance: "<class \'tensorflow.python.training.checkpointable.tracking.Checkpointable\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
member {
diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
index 9fb18e77af..b0fb04d7d4 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
@@ -242,7 +242,7 @@ tf_module {
}
member_method {
name: "MonitoredTrainingSession"
- argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\'], "
+ argspec: "args=[\'master\', \'is_chief\', \'checkpoint_dir\', \'scaffold\', \'hooks\', \'chief_only_hooks\', \'save_checkpoint_secs\', \'save_summaries_steps\', \'save_summaries_secs\', \'config\', \'stop_grace_period_secs\', \'log_step_count_steps\', \'max_wait_secs\', \'save_checkpoint_steps\', \'summary_dir\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'None\', \'None\', \'None\', \'None\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'None\', \'120\', \'100\', \'7200\', \'<object object instance>\', \'None\'], "
}
member_method {
name: "NewCheckpointReader"
@@ -401,6 +401,10 @@ tf_module {
argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], "
}
member_method {
+ name: "remove_checkpoint"
+ argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], "
+ }
+ member_method {
name: "replica_device_setter"
argspec: "args=[\'ps_tasks\', \'ps_device\', \'worker_device\', \'merge_devices\', \'cluster\', \'ps_ops\', \'ps_strategy\'], varargs=None, keywords=None, defaults=[\'0\', \'/job:ps\', \'/job:worker\', \'True\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt
index a58398d645..09d7bc03b4 100644
--- a/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.variance_scaling_initializer.pbtxt
@@ -5,7 +5,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'normal\', \'None\', \"<dtype: \'float32\'>\"], "
+ argspec: "args=[\'self\', \'scale\', \'mode\', \'distribution\', \'seed\', \'dtype\'], varargs=None, keywords=None, defaults=[\'1.0\', \'fan_in\', \'truncated_normal\', \'None\', \"<dtype: \'float32\'>\"], "
}
member_method {
name: "from_config"
diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
index 1cf330e702..3a48cf683c 100644
--- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
+++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
@@ -88,6 +88,9 @@ def _SanitizedMRO(obj):
"""
return_list = []
for cls in tf_inspect.getmro(obj):
+ if cls.__name__ == '_NewClass':
+ # Ignore class created by @deprecated_alias decorator.
+ continue
str_repr = str(cls)
return_list.append(str_repr)
if 'tensorflow' not in str_repr:
diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake
index d5dea4f3e4..e8c3199828 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cmake
+++ b/tensorflow/tools/ci_build/Dockerfile.cmake
@@ -28,6 +28,8 @@ RUN pip install --upgrade astor
RUN pip install --upgrade gast
RUN pip install --upgrade numpy
RUN pip install --upgrade termcolor
+RUN pip install keras_applications==1.0.2
+RUN pip install keras_preprocessing==1.0.1
# Install golang
RUN apt-get install -t xenial-backports -y golang-1.9
diff --git a/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le
new file mode 100644
index 0000000000..e879c34bbd
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.cpu.ppc64le
@@ -0,0 +1,20 @@
+FROM ubuntu:16.04
+
+LABEL maintainer="William Irons <wdirons@us.ibm.com>"
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa
+RUN /install/install_deb_packages.sh
+RUN apt-get update && apt-get install -y libopenblas-dev
+RUN /install/install_hdf5_ppc64le.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_bazel_from_source.sh
+RUN /install/install_proto3.sh
+RUN /install/install_buildifier_from_source.sh
+RUN /install/install_auditwheel.sh
+RUN /install/install_golang_ppc64le.sh
+
+# Set up the master bazelrc configuration file.
+COPY install/.bazelrc /etc/bazel.bazelrc
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
new file mode 100644
index 0000000000..8967138747
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
@@ -0,0 +1,28 @@
+FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04
+
+LABEL maintainer="William Irons <wdirons@us.ibm.com>"
+
+# In the Ubuntu 16.04 images, cudnn is placed in system paths. Move them to
+# /usr/local/cuda
+RUN cp -P /usr/include/cudnn.h /usr/local/cuda/include
+RUN cp -P /usr/lib/powerpc64le-linux-gnu/libcudnn* /usr/local/cuda/lib64
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa
+RUN /install/install_deb_packages.sh
+RUN apt-get update && apt-get install -y libopenblas-dev
+RUN /install/install_hdf5_ppc64le.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_bazel_from_source.sh
+RUN /install/install_golang_ppc64le.sh
+
+# Set up the master bazelrc configuration file.
+COPY install/.bazelrc /etc/bazel.bazelrc
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+
+# Configure the build for our CUDA configuration.
+ENV TF_NEED_CUDA 1
+ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu
index 3bc52b9ed6..7e5860aeec 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cpu
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cpu
@@ -1,4 +1,4 @@
-FROM launcher.gcr.io/google/rbe-debian8:r327695
+FROM launcher.gcr.io/google/rbe-ubuntu16-04:r327695
LABEL maintainer="Yu Yi <yiyu@google.com>"
# Copy install scripts
@@ -9,6 +9,6 @@ ENV CC /usr/local/bin/clang
ENV CXX /usr/local/bin/clang++
ENV AR /usr/bin/ar
-# Run pip install script for RBE Debian8 container.
+# Run pip install script for RBE Ubuntu 16-04 container.
RUN /install/install_pip_packages_remote.sh
RUN /install/install_pip_packages.sh
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index 5fa75e1d61..883bb93647 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -322,6 +322,10 @@ create_activate_virtualenv_and_install_tensorflow() {
pip install -v ${PIP_FLAGS} ${WHL_PATH} || \
die "pip install (forcing to reinstall tensorflow) FAILED"
echo "Successfully installed pip package ${TF_WHEEL_PATH}"
+
+ # Force downgrade setuptools.
+ pip install --upgrade setuptools==39.1.0
+
}
################################################################################
diff --git a/tensorflow/tools/ci_build/ci_build.sh b/tensorflow/tools/ci_build/ci_build.sh
index 1f0fd0387a..f6a50d3d4c 100755
--- a/tensorflow/tools/ci_build/ci_build.sh
+++ b/tensorflow/tools/ci_build/ci_build.sh
@@ -79,7 +79,7 @@ if [[ "${CONTAINER_TYPE}" == "cmake" ]]; then
fi
# Use nvidia-docker if the container is GPU.
-if [[ "${CONTAINER_TYPE}" == "gpu" ]]; then
+if [[ "${CONTAINER_TYPE}" == gpu* ]]; then
DOCKER_BINARY="nvidia-docker"
else
DOCKER_BINARY="docker"
@@ -99,7 +99,7 @@ BUILD_TAG="${BUILD_TAG:-tf_ci}"
# Add extra params for cuda devices and libraries for GPU container.
# And clear them if we are not building for GPU.
-if [[ "${CONTAINER_TYPE}" != "gpu" ]]; then
+if [[ "${CONTAINER_TYPE}" != gpu* ]]; then
GPU_EXTRA_PARAMS=""
fi
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 6aaeb14aee..08e2c3edd2 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -59,6 +59,9 @@
# TF_BUILD_BAZEL_CLEAN:
# Will perform "bazel clean", if and only if this variable
# is set to any non-empty and non-0 value
+# TF_BAZEL_BUILD_ONLY:
+# If it is set to any non-empty value that is not "0", Bazel
+# will only build specified targets
# TF_GPU_COUNT:
# Run this many parallel tests for serial builds.
# For now, only can be edited for PIP builds.
@@ -94,10 +97,6 @@
#
# This script can be used by Jenkins parameterized / matrix builds.
-# TODO(jhseu): Temporary for the gRPC pull request due to the
-# protobuf -> protobuf_archive rename. Remove later.
-TF_BUILD_BAZEL_CLEAN=1
-
# Helper function: Convert to lower case
to_lower () {
echo "$1" | tr '[:upper:]' '[:lower:]'
@@ -262,9 +261,9 @@ function set_script_variable() {
# Process container type
-if [[ ${CTYPE} == "cpu" ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then
+if [[ ${CTYPE} == cpu* ]] || [[ ${CTYPE} == "debian.jessie.cpu" ]]; then
:
-elif [[ ${CTYPE} == "gpu" ]]; then
+elif [[ ${CTYPE} == gpu* ]]; then
set_script_variable TF_NEED_CUDA 1
if [[ $TF_CUDA_CLANG == "1" ]]; then
@@ -414,6 +413,11 @@ fi
# this flag, and it only affects a few tests.
EXTRA_ARGS="${EXTRA_ARGS} --distinct_host_configuration=false"
+if [[ ! -z "${TF_BAZEL_BUILD_ONLY}" ]] &&
+ [[ "${TF_BAZEL_BUILD_ONLY}" != "0" ]];then
+ BAZEL_CMD=${BAZEL_BUILD_ONLY_CMD}
+fi
+
# Process PIP install-test option
if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] ||
[[ ${TF_BUILD_IS_PIP} == "both" ]]; then
@@ -422,12 +426,12 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] ||
BAZEL_TARGET=${TF_BUILD_BAZEL_TARGET}
fi
- if [[ ${CTYPE} == "cpu" ]] || \
+ if [[ ${CTYPE} == cpu* ]] || \
[[ ${CTYPE} == "debian.jessie.cpu" ]]; then
# CPU only command, fully parallel.
NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} ${EXTRA_ARGS} -- "\
"${BAZEL_TARGET}"
- elif [[ ${CTYPE} == "gpu" ]]; then
+ elif [[ ${CTYPE} == gpu* ]]; then
# GPU only command, run as many jobs as the GPU count only.
NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\
"--local_test_jobs=${TF_GPU_COUNT} "\
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 0dd32ad1a8..db37edf809 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -349,12 +349,12 @@ do_external_licenses_check(){
# Blacklist
echo ${MISSING_LICENSES_FILE}
- grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -v ${MISSING_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//third_party/" -e "@com_google_absl//absl" -e "@org_tensorflow//" -e "@com_github_googlecloudplatform_google_cloud_cpp//google" -v ${MISSING_LICENSES_FILE} > temp.txt
mv temp.txt ${MISSING_LICENSES_FILE}
# Whitelist
echo ${EXTRA_LICENSE_FILE}
- grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -v ${EXTRA_LICENSES_FILE} > temp.txt
+ grep -e "@bazel_tools//src" -e "@bazel_tools//tools/" -e "@com_google_absl//" -e "//external" -e "@local" -e "@com_github_googlecloudplatform_google_cloud_cpp//" -v ${EXTRA_LICENSES_FILE} > temp.txt
mv temp.txt ${EXTRA_LICENSES_FILE}
diff --git a/tensorflow/tools/ci_build/copy_binary.py b/tensorflow/tools/ci_build/copy_binary.py
index 420d390d2b..148526492d 100755
--- a/tensorflow/tools/ci_build/copy_binary.py
+++ b/tensorflow/tools/ci_build/copy_binary.py
@@ -32,7 +32,8 @@ import shutil
import tempfile
import zipfile
-TF_NIGHTLY_REGEX = r"(.+)tf_nightly(|_gpu)-(\d\.\d\.\d.dev[\d]{0,8})-(.+)\.whl"
+TF_NIGHTLY_REGEX = (r"(.+)tf_nightly(|_gpu)-(\d\.[\d]{1,2}"
+ "\.\d.dev[\d]{0,8})-(.+)\.whl")
BINARY_STRING_TEMPLATE = "%s-%s-%s.whl"
diff --git a/tensorflow/tools/ci_build/install/install_bazel_from_source.sh b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
new file mode 100755
index 0000000000..ddad00c5f0
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_bazel_from_source.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This script is to be used to install bzel on non x86_64 systems
+# It will compile bazel from source and install it in /usr/local/bin
+
+# Select bazel version.
+BAZEL_VERSION="0.11.0"
+
+set +e
+local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}')
+
+if [[ "$local_bazel_ver" == "$BAZEL_VERSION" ]]; then
+ exit 0
+fi
+
+set -e
+
+# Compile bazel from source
+mkdir -p /bazel
+cd /bazel
+
+curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip
+unzip bazel-$BAZEL_VERSION-dist.zip
+bash ./compile.sh
+cp output/bazel /usr/local/bin/
+rm -rf /bazel
diff --git a/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh
new file mode 100755
index 0000000000..a93c258fad
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_buildifier_from_source.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+BUILDTOOLS_VERSION="0.11.1"
+
+# Clone buildtools
+git clone -b $BUILDTOOLS_VERSION https://github.com/bazelbuild/buildtools
+cd buildtools
+
+# Build buildifier
+bazel build //buildifier
+sudo mv bazel-bin/buildifier/linux*stripped/buildifier /usr/local/bin
+
+# Build buildozer
+bazel build //buildozer
+sudo mv bazel-bin/buildozer/linux*stripped/buildozer /usr/local/bin
diff --git a/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh
new file mode 100755
index 0000000000..47d23a59b3
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_golang_ppc64le.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -ex
+
+GOLANG_URL="https://storage.googleapis.com/golang/go1.10.linux-ppc64le.tar.gz"
+
+sudo mkdir -p /usr/local
+wget -q -O - "${GOLANG_URL}" | sudo tar -C /usr/local -xz
diff --git a/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh
new file mode 100755
index 0000000000..4989d986b8
--- /dev/null
+++ b/tensorflow/tools/ci_build/install/install_hdf5_ppc64le.sh
@@ -0,0 +1,30 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+#This is required because pypi doesn't have a pre-built h5py binary for ppc64le
+#It has to be compiled from source during the install
+apt-get update
+apt-get install -y libhdf5-dev
+
+#h5py is not expecting the shared libraries to have _serial in the name.
+ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial.so /usr/lib/powerpc64le-linux-gnu/libhdf5.so
+ln -s /usr/lib/powerpc64le-linux-gnu/libhdf5_serial_hl.so /usr/lib/powerpc64le-linux-gnu/libhdf5_hl.so
+
+#pip is not installed yet, so use easy_install
+#CPATH is the location of hdf5.h
+CPATH=/usr/include/hdf5/serial/ easy_install -U h5py
+CPATH=/usr/include/hdf5/serial/ easy_install3 -U h5py
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index b3d3f23ec8..221b5b80fb 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -51,8 +51,8 @@ pip2 install --upgrade markdown==2.6.8
pip3 install --upgrade markdown==2.6.8
# Install protobuf.
-pip2 install --upgrade protobuf==3.3.0
-pip3 install --upgrade protobuf==3.3.0
+pip2 install --upgrade protobuf==3.6.0
+pip3 install --upgrade protobuf==3.6.0
# Remove obsolete version of six, which can sometimes confuse virtualenv.
rm -rf /usr/lib/python3/dist-packages/six*
@@ -113,3 +113,13 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
+
+# Keras
+pip2 install keras_applications==1.0.2
+pip3 install keras_applications==1.0.2
+pip2 install keras_preprocessing==1.0.1
+pip3 install keras_preprocessing==1.0.1
+
+# Install last working version of setuptools.
+pip2 install --upgrade setuptools==39.1.0
+pip3 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/ci_build/install/install_proto3.sh b/tensorflow/tools/ci_build/install/install_proto3.sh
index 7934002b2c..821d50baff 100755
--- a/tensorflow/tools/ci_build/install/install_proto3.sh
+++ b/tensorflow/tools/ci_build/install/install_proto3.sh
@@ -17,7 +17,7 @@
# Install protobuf3.
# Select protobuf version.
-PROTOBUF_VERSION="3.3.0"
+PROTOBUF_VERSION="3.6.0"
protobuf_ver_flat=$(echo $PROTOBUF_VERSION | sed 's/\.//g' | sed 's/^0*//g')
local_protobuf_ver=$(protoc --version)
local_protobuf_ver_flat=$(echo $local_protobuf_ver | sed 's/\.//g' | sed 's/^0*//g')
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 61d34c7304..45a30c6e82 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -48,7 +48,7 @@ pip3.5 install --upgrade absl-py
pip3.5 install --upgrade six==1.10.0
# Install protobuf.
-pip3.5 install --upgrade protobuf==3.3.0
+pip3.5 install --upgrade protobuf==3.6.0
# Remove obsolete version of six, which can sometimes confuse virtualenv.
rm -rf /usr/lib/python3/dist-packages/six*
@@ -84,4 +84,11 @@ pip3.5 install --upgrade termcolor
# Install last working version of setuptools.
pip3.5 install --upgrade setuptools==39.1.0
+# Keras
+pip3.5 install keras_applications==1.0.2
+pip3.5 install keras_preprocessing==1.0.1
+
+# Install last working version of setuptools.
+pip3.5 install --upgrade setuptools==39.1.0
+
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh)
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index fe2d2cf11c..d66b2aa18a 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -60,7 +60,7 @@ pip3 install --upgrade absl-py
pip3 install --upgrade six==1.10.0
# Install protobuf.
-pip3 install --upgrade protobuf==3.3.0
+pip3 install --upgrade protobuf==3.6.0
# Remove obsolete version of six, which can sometimes confuse virtualenv.
rm -rf /usr/lib/python3/dist-packages/six*
@@ -100,4 +100,8 @@ pip3 install --upgrade termcolor
# Install last working version of setuptools.
pip3 install --upgrade setuptools==39.1.0
+# Keras
+pip3.5 install keras_applications==1.0.2
+pip3.5 install keras_preprocessing==1.0.1
+
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
new file mode 100755
index 0000000000..50ee07e727
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
@@ -0,0 +1,47 @@
+#!/usr/bin/env bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ==============================================================================
+
+set -e
+set -x
+
+N_JOBS=$(grep -c ^processor /proc/cpuinfo)
+
+echo ""
+echo "Bazel will use ${N_JOBS} concurrent job(s)."
+echo ""
+
+# Run configure.
+export PYTHON_BIN_PATH=`which python2`
+
+export TF_NEED_CUDA=1
+export TF_CUDA_VERSION=9.0
+export TF_CUDNN_VERSION=7
+export TF_CUDA_COMPUTE_CAPABILITIES=3.7
+
+yes "" | $PYTHON_BIN_PATH configure.py
+
+# Run bazel test command. Double test timeouts to avoid flakes.
+# Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution
+# in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads
+# caused by executing multiple tests concurrently.
+bazel test --config=cuda --test_tag_filters=-no_oss,-oss_serial,-no_gpu,-benchmark-test \
+ --test_lang_filters=cc,py -k --jobs="${N_JOBS}" \
+ --test_timeout 300,450,1200,3600 --build_tests_only --test_env=KMP_BLOCKTIME=0\
+ --config=mkl --config=opt --test_output=errors --local_test_jobs=8 \
+ --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute -- \
+ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
+
diff --git a/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh
new file mode 100755
index 0000000000..68354bf7c1
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-gpu-test.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Usage: basic_mkl_test.sh
+
+# Helper function to traverse directories up until given file is found.
+function upsearch () {
+ test / == "$PWD" && return || \
+ test -e "$1" && echo "$PWD" && return || \
+ cd .. && upsearch "$1"
+}
+
+# Set up WORKSPACE.
+WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
+
+BUILD_TAG=mkl-gpu-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh gpu tensorflow/tools/ci_build/linux/gpu/run_mkl.sh
diff --git a/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh
new file mode 100755
index 0000000000..10a09a415a
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/mkl/basic-mkl-test.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Usage: basic_mkl_test.sh
+
+# Helper function to traverse directories up until given file is found.
+function upsearch () {
+ test / == "$PWD" && return || \
+ test -e "$1" && echo "$PWD" && return || \
+ cd .. && upsearch "$1"
+}
+
+# Set up WORKSPACE.
+WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
+
+BUILD_TAG=mkl-ci-test CI_BUILD_USER_FORCE_BADNAME=yes ${WORKSPACE}/tensorflow/tools/ci_build/ci_build.sh cpu tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
diff --git a/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh
new file mode 100755
index 0000000000..ad22ebe4eb
--- /dev/null
+++ b/tensorflow/tools/ci_build/linux/mkl/build-dev-container.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Build a whl and container with Intel(R) MKL support
+# Usage: build-dev-container.sh
+
+# Helper function to traverse directories up until given file is found.
+function upsearch () {
+ test / == "$PWD" && return || \
+ test -e "$1" && echo "$PWD" && return || \
+ cd .. && upsearch "$1"
+}
+
+# Set up WORKSPACE.
+WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}"
+
+TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH:-master}
+TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME:-intel-mkl/tensorflow}
+TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION:-nightly}
+
+echo "TF_DOCKER_BUILD_DEVEL_BRANCH=${TF_DOCKER_BUILD_DEVEL_BRANCH}"
+echo "TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}"
+echo "TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}"
+
+# build the python 2 container and whl
+TF_DOCKER_BUILD_TYPE="MKL" \
+ TF_DOCKER_BUILD_IS_DEVEL="YES" \
+ TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \
+ TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \
+ TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \
+ ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+
+# build the python 3 container and whl
+TF_DOCKER_BUILD_TYPE="MKL" \
+ TF_DOCKER_BUILD_IS_DEVEL="YES" \
+ TF_DOCKER_BUILD_DEVEL_BRANCH="${TF_DOCKER_BUILD_DEVEL_BRANCH}" \
+ TF_DOCKER_BUILD_IMAGE_NAME="${TF_DOCKER_BUILD_IMAGE_NAME}" \
+ TF_DOCKER_BUILD_VERSION="${TF_DOCKER_BUILD_VERSION}" \
+ TF_DOCKER_BUILD_PYTHON_VERSION="PYTHON3" \
+ ${WORKSPACE}/tensorflow/tools/docker/parameterized_docker_build.sh
+
diff --git a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
index 5eff3e415d..3d27e84b81 100755
--- a/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
+++ b/tensorflow/tools/ci_build/pi/build_raspberry_pi.sh
@@ -106,6 +106,8 @@ bazel build -c opt ${PI_COPTS} \
--copt=-fomit-frame-pointer --cpu=armeabi \
--crosstool_top=@local_config_arm_compiler//:toolchain \
--verbose_failures \
+ //tensorflow:libtensorflow.so \
+ //tensorflow:libtensorflow_framework.so \
//tensorflow/tools/benchmark:benchmark_model \
//tensorflow/tools/pip_package:build_pip_package
diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py
index 00bfcfd49b..30c318a58f 100755
--- a/tensorflow/tools/ci_build/update_version.py
+++ b/tensorflow/tools/ci_build/update_version.py
@@ -37,7 +37,7 @@ SETUP_PY = "%s/tools/pip_package/setup.py" % TF_SRC_DIR
README_MD = "./README.md"
DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel" % TF_SRC_DIR
GPU_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-gpu" % TF_SRC_DIR
-CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-cpu-mkl" % TF_SRC_DIR
+CPU_MKL_DEVEL_DOCKERFILE = "%s/tools/docker/Dockerfile.devel-mkl" % TF_SRC_DIR
RELEVANT_FILES = [TF_SRC_DIR,
VERSION_H,
SETUP_PY,
@@ -248,16 +248,6 @@ def update_md_files(old_version, new_version):
replace_string_in_line(r"<version>%s<\/version>" % old_version,
"<version>%s</version>" % new_version, filepath)
- # Update any links to colab notebooks.
- def colab_url(version):
- version_string = "%s.%s.%s" % (version.major, version.minor, version.patch)
- prefix = "https://colab.research.google.com/github/tensorflow/models/blob/r"
- return prefix + version_string + "/"
-
- replace_string_in_line(
- colab_url(old_version), colab_url(new_version),
- "%s/docs_src/get_started/eager.md" % TF_SRC_DIR)
-
def major_minor_change(old_version, new_version):
"""Check if a major or minor change occurred."""
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index 582188fc00..c03cbd9c66 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -14,136 +14,29 @@
# limitations under the License.
# ==============================================================================
#
-# C++ tests
-failing_cpu_cc_tests="\
- //tensorflow/core/kernels:control_flow_ops_test + \
- //tensorflow/core:example_example_parser_configuration_test + \
- //tensorflow/core:lib_core_status_test + \
- //tensorflow/core:lib_monitoring_collection_registry_test + \
- //tensorflow/core:lib_strings_numbers_test + \
- //tensorflow/core/platform/hadoop:hadoop_file_system_test + \
- //tensorflow/core:platform_file_system_test + \
- //tensorflow/core:platform_logging_test + \
- //tensorflow/core:util_sparse_sparse_tensor_test + \
- //tensorflow/cc:framework_gradient_checker_test + \
- //tensorflow/cc:framework_gradients_test + \
- //tensorflow/cc:gradients_array_grad_test + \
- //tensorflow/cc:gradients_math_grad_test + \
- //tensorflow/cc:gradients_nn_grad_test + \
- //tensorflow/cc/saved_model:loader_test \
-"
-
-broken_cpu_cc_tests="\
- //tensorflow/cc:framework_cc_ops_test + \
- //tensorflow/core/platform/cloud:time_util_test + \
- //tensorflow/core/platform/cloud:oauth_client_test + \
- //tensorflow/core/platform/cloud:http_request_test + \
- //tensorflow/core/platform/cloud:google_auth_provider_test + \
- //tensorflow/core/platform/cloud:gcs_file_system_test + \
- //tensorflow/core/kernels/cloud:bigquery_table_accessor_test + \
- //tensorflow/core/kernels/hexagon:graph_transferer_test + \
- //tensorflow/core/kernels:remote_fused_graph_execute_utils_test + \
- //tensorflow/core/kernels:requantize_op_test + \
- //tensorflow/core/kernels:requantization_range_op_test + \
- //tensorflow/core/kernels:quantized_reshape_op_test + \
- //tensorflow/core/kernels:quantized_pooling_ops_test + \
- //tensorflow/core/kernels:quantized_matmul_op_test + \
- //tensorflow/core/kernels:quantized_conv_ops_test + \
- //tensorflow/core/kernels:quantized_concat_op_test + \
- //tensorflow/core/kernels:quantized_bias_add_op_test + \
- //tensorflow/core/kernels:quantized_batch_norm_op_test + \
- //tensorflow/core/kernels:quantized_activation_ops_test + \
- //tensorflow/core/kernels:quantize_op_test + \
- //tensorflow/core/kernels:quantize_down_and_shrink_range_op_test + \
- //tensorflow/core/kernels:quantize_and_dequantize_op_test_gpu + \
- //tensorflow/core/kernels:quantize_and_dequantize_op_test + \
- //tensorflow/core/kernels:quantization_utils_test + \
- //tensorflow/core/kernels:debug_ops_test + \
- //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test_gpu + \
- //tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr_test + \
- //tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding_test + \
- //tensorflow/core/distributed_runtime/rpc:grpc_session_test_gpu + \
- //tensorflow/core/distributed_runtime/rpc:grpc_session_test + \
- //tensorflow/core/distributed_runtime/rpc:grpc_channel_test_gpu + \
- //tensorflow/core/distributed_runtime/rpc:grpc_channel_test + \
- //tensorflow/core/distributed_runtime:remote_device_test_gpu + \
- //tensorflow/core/distributed_runtime:remote_device_test + \
- //tensorflow/core/distributed_runtime:executor_test_gpu + \
- //tensorflow/core/distributed_runtime:executor_test + \
- //tensorflow/core/debug:debug_gateway_test + \
- //tensorflow/core/debug:debug_grpc_io_utils_test + \
- //tensorflow/core:util_reporter_test + \
- //tensorflow/core:util_memmapped_file_system_test + \
- //tensorflow/core:platform_subprocess_test + \
- //tensorflow/core:platform_profile_utils_cpu_utils_test + \
- //tensorflow/core:lib_jpeg_jpeg_mem_unittest + \
- //tensorflow/core/debug:debug_io_utils_test \
-"
-
-# lib_core_threadpool_test is timeout, but it passes when running alone
-extra_failing_gpu_cc_tests="\
- //tensorflow/core:lib_core_threadpool_test + \
- //tensorflow/core:cuda_libdevice_path_test + \
- //tensorflow/core:common_runtime_direct_session_test + \
- //tensorflow/core:common_runtime_direct_session_with_tracking_alloc_test + \
- //tensorflow/core:device_tracer_test + \
- //tensorflow/core:ops_math_grad_test \
-"
-
-exclude_cpu_cc_tests="${failing_cpu_cc_tests} + ${broken_cpu_cc_tests}"
-
-exclude_gpu_cc_tests="${extra_failing_gpu_cc_tests} + ${exclude_cpu_cc_tests}"
function run_configure_for_cpu_build {
- # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182
- # yes "" | ./configure doesn't work on Windows, so we set all the
- # environment variables in advance to avoid interact with the script.
- export TF_NEED_CUDA=0
- if [ -z "$TF_ENABLE_XLA" ]; then
- export TF_ENABLE_XLA=0
- fi
- if [ -z "$TF_NEED_MKL" ]; then
- export TF_NEED_MKL=0
- fi
- export TF_NEED_VERBS=0
- export TF_NEED_GCP=1
- export TF_NEED_HDFS=0
- export TF_NEED_OPENCL_SYCL=0
- echo "" | ./configure
+ yes "" | ./configure
}
function run_configure_for_gpu_build {
- # Due to a bug in Bazel: https://github.com/bazelbuild/bazel/issues/2182
- # yes "" | ./configure doesn't work on Windows, so we set all the
- # environment variables in advance to avoid interact with the script.
+ # Enable CUDA support
export TF_NEED_CUDA=1
- export TF_CUDA_VERSION=9.0
- export CUDA_TOOLKIT_PATH="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0"
- export TF_CUDNN_VERSION=7.0
- if [ -z "$CUDNN_INSTALL_PATH" ]; then
- export CUDNN_INSTALL_PATH="C:/tools/cuda"
- fi
- export TF_CUDA_COMPUTE_CAPABILITIES="3.7"
- if [ -z "$TF_ENABLE_XLA" ]; then
- export TF_ENABLE_XLA=0
- fi
- export TF_NEED_VERBS=0
- export TF_NEED_MKL=0
- export TF_NEED_GCP=0
- export TF_NEED_HDFS=0
- export TF_NEED_OPENCL_SYCL=0
-
- # TODO(pcloudy): Remove this after TensorFlow uses its own CRSOOTOOL
- # for GPU build on Windows
- export USE_MSVC_WRAPPER=1
- echo "" | ./configure
+ yes "" | ./configure
}
-function set_gcs_remote_cache_options {
- echo "build --experimental_remote_spawn_cache" >> "${TMP_BAZELRC}"
+function set_remote_cache_options {
+ echo "build --remote_instance_name=projects/tensorflow-testing-cpu" >> "${TMP_BAZELRC}"
echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}"
- echo "build --remote_http_cache=https://storage.googleapis.com/$GCS_BUCKET_NAME" >> "${TMP_BAZELRC}"
+ echo "build --remote_cache=remotebuildexecution.googleapis.com" >> "${TMP_BAZELRC}"
+ echo "build --tls_enabled=true" >> "${TMP_BAZELRC}"
+ echo "build --remote_timeout=3600" >> "${TMP_BAZELRC}"
+ echo "build --auth_enabled=true" >> "${TMP_BAZELRC}"
+ echo "build --spawn_strategy=remote" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Javac=remote" >> "${TMP_BAZELRC}"
+ echo "build --strategy=Closure=remote" >> "${TMP_BAZELRC}"
+ echo "build --genrule_strategy=remote" >> "${TMP_BAZELRC}"
echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}"
}
diff --git a/tensorflow/tools/ci_build/windows/bazel/common_env.sh b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
index 0e6c0227b7..3af132217e 100644
--- a/tensorflow/tools/ci_build/windows/bazel/common_env.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/common_env.sh
@@ -50,7 +50,14 @@ export PATH="/c/Program Files/Git/cmd:$PATH"
# Make sure we have pip in PATH
export PATH="/c/${PYTHON_BASE_PATH}/Scripts:$PATH"
+# Setting default values to CUDA related environment variables
+export TF_CUDA_VERSION=${TF_CUDA_VERSION:-9.0}
+export TF_CUDNN_VERSION=${TF_CUDNN_VERSION:-7.0}
+export TF_CUDA_COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES:-3.7}
+export CUDA_TOOLKIT_PATH=${CUDA_TOOLKIT_PATH:-"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v${TF_CUDA_VERSION}"}
+export CUDNN_INSTALL_PATH=${CUDNN_INSTALL_PATH:-"C:/tools/cuda"}
+
# Add Cuda and Cudnn dll directories into PATH
-export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/bin:$PATH"
-export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v9.0/extras/CUPTI/libx64:$PATH"
-export PATH="/c/tools/cuda/bin:$PATH"
+export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/bin:$PATH"
+export PATH="$(cygpath -u "${CUDA_TOOLKIT_PATH}")/extras/CUPTI/libx64:$PATH"
+export PATH="$(cygpath -u "${CUDNN_INSTALL_PATH}")/bin:$PATH"
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index f4a0b232ec..ed73401467 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -59,8 +59,8 @@ release_build=0
for ARG in "$@"; do
if [[ "$ARG" == --skip_test ]]; then
skip_test=1
- elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then
- set_gcs_remote_cache_options
+ elif [[ "$ARG" == --enable_remote_cache ]]; then
+ set_remote_cache_options
elif [[ "$ARG" == --release_build ]]; then
release_build=1
fi
@@ -77,11 +77,16 @@ fi
# to distinct them. This helps avoid building the same targets twice.
echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}"
-echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc
+# Enable short object file path to avoid long path issue on Windows.
+echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
+
+if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then
+ echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc
+fi
run_configure_for_cpu_build
-bazel build --announce_rc -c opt tensorflow/tools/pip_package:build_pip_package || exit $?
+bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $?
if [[ "$skip_test" == 1 ]]; then
exit 0
@@ -102,7 +107,7 @@ N_JOBS="${NUMBER_OF_PROCESSORS}"
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result testing system installed tensorflow
-bazel test -c opt -k --test_output=errors \
+bazel test --announce_rc --config=opt -k --test_output=errors \
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
--test_tag_filters=-no_pip,-no_windows,-no_oss \
--build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \
diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
index 4656afe025..cec5b717f8 100644
--- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
+++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat
@@ -30,7 +30,6 @@ IF DEFINED SWIG_EXE (ECHO SWIG_EXE is set to %SWIG_EXE%) ELSE (SET SWIG_EXE="C:\
IF DEFINED PY_EXE (ECHO PY_EXE is set to %PY_EXE%) ELSE (SET PY_EXE="C:\Program Files\Anaconda3\python.exe")
IF DEFINED PY_LIB (ECHO PY_LIB is set to %PY_LIB%) ELSE (SET PY_LIB="C:\Program Files\Anaconda3\libs\python35.lib")
IF DEFINED CUDNN_HOME (ECHO CUDNN_HOME is set to %CUDNN_HOME%) ELSE (SET CUDNN_HOME="c:\tools\cuda")
-verbosity:quiet
IF DEFINED DISABLE_FORCEINLINE (ECHO DISABLE_FORCEINLINE is set to %DISABLE_FORCEINLINE%) ELSE (SET DISABLE_FORCEINLINE="OFF")
SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 922bb67bbf..fe3bce428f 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -42,9 +42,58 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \
source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \
|| { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; }
+# Recreate an empty bazelrc file under source root
+export TMP_BAZELRC=.tmp.bazelrc
+rm -f "${TMP_BAZELRC}"
+touch "${TMP_BAZELRC}"
+
+function cleanup {
+ # Remove all options in .tmp.bazelrc
+ echo "" > "${TMP_BAZELRC}"
+}
+trap cleanup EXIT
+
+skip_test=0
+release_build=0
+
+for ARG in "$@"; do
+ if [[ "$ARG" == --skip_test ]]; then
+ skip_test=1
+ elif [[ "$ARG" == --enable_remote_cache ]]; then
+ set_remote_cache_options
+ elif [[ "$ARG" == --release_build ]]; then
+ release_build=1
+ fi
+done
+
+if [[ "$release_build" != 1 ]]; then
+ # --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
+ # by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
+ # Because this hurts the performance of TF, we don't enable it in release build.
+ echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+fi
+
+# The host and target platforms are the same in Windows build. So we don't have
+# to distinct them. This helps avoid building the same targets twice.
+echo "build --distinct_host_configuration=false" >> "${TMP_BAZELRC}"
+
+# Enable short object file path to avoid long path issue on Windows.
+echo "startup --output_user_root=${TMPDIR}" >> "${TMP_BAZELRC}"
+
+# Disable nvcc warnings to reduce log file size.
+echo "build --copt=-nvcc_options=disable-warnings" >> "${TMP_BAZELRC}"
+
+if ! grep -q "import %workspace%/${TMP_BAZELRC}" .bazelrc; then
+ echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc
+fi
+
run_configure_for_gpu_build
-bazel build -c opt tensorflow/tools/pip_package:build_pip_package || exit $?
+bazel build --announce_rc --config=opt tensorflow/tools/pip_package:build_pip_package || exit $?
+
+if [[ "$skip_test" == 1 ]]; then
+ exit 0
+fi
# Create a python test directory to avoid package name conflict
PY_TEST_DIR="py_test_dir"
@@ -59,8 +108,11 @@ reinstall_tensorflow_pip ${PIP_NAME}
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result testing system installed tensorflow
# GPU tests are very flaky when running concurrently, so set local_test_jobs=1
-bazel test -c opt -k --test_output=errors \
+bazel test --announce_rc --config=opt -k --test_output=errors \
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
- --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \
- --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,no_oss \
- --local_test_jobs=1 --build_tests_only //${PY_TEST_DIR}/tensorflow/python/...
+ --test_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss \
+ --build_tag_filters=-no_pip,-no_windows,-no_windows_gpu,-no_gpu,-no_pip_gpu,-no_oss --build_tests_only \
+ --local_test_jobs=1 --test_timeout="300,450,1200,3600" \
+ --flaky_test_attempts=3 \
+ //${PY_TEST_DIR}/tensorflow/python/... \
+ //${PY_TEST_DIR}/tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
index 583d1d5f09..fdbd1120b2 100755
--- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
@@ -41,7 +41,7 @@ run_configure_for_cpu_build
# build_libtensorflow_tarball in ../builds/libtensorflow.sh
# cannot be used on Windows since it relies on pkg_tar rules.
# So we do something special here
-bazel build -c opt --copt=/arch:AVX \
+bazel --output_user_root=${TMPDIR} build -c opt --copt=/arch:AVX \
tensorflow:libtensorflow.so \
tensorflow/tools/lib_package:clicenses_generate \
tensorflow/java:libtensorflow_jni.so \
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
new file mode 100644
index 0000000000..23cc4a21a9
--- /dev/null
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -0,0 +1,502 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Upgrader for Python scripts according to an API change specification."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import collections
+import os
+import shutil
+import sys
+import tempfile
+import traceback
+
+
+class APIChangeSpec(object):
+ """This class defines the transformations that need to happen.
+
+ This class must provide the following fields:
+
+ * `function_keyword_renames`: maps function names to a map of old -> new
+ argument names
+ * `function_renames`: maps function names to new function names
+ * `change_to_function`: a set of function names that have changed (for
+ notifications)
+ * `function_reorders`: maps functions whose argument order has changed to the
+ list of arguments in the new order
+ * `function_handle`: maps function names to custom handlers for the function
+
+ For an example, see `TFAPIChangeSpec`.
+ """
+
+
+class _FileEditTuple(
+ collections.namedtuple("_FileEditTuple",
+ ["comment", "line", "start", "old", "new"])):
+ """Each edit that is recorded by a _FileEditRecorder.
+
+ Fields:
+ comment: A description of the edit and why it was made.
+ line: The line number in the file where the edit occurs (1-indexed).
+ start: The line number in the file where the edit occurs (0-indexed).
+ old: text string to remove (this must match what was in file).
+ new: text string to add in place of `old`.
+ """
+
+ __slots__ = ()
+
+
+class _FileEditRecorder(object):
+ """Record changes that need to be done to the file."""
+
+ def __init__(self, filename):
+ # all edits are lists of chars
+ self._filename = filename
+
+ self._line_to_edit = collections.defaultdict(list)
+ self._errors = []
+
+ def process(self, text):
+ """Process a list of strings, each corresponding to the recorded changes.
+
+ Args:
+ text: A list of lines of text (assumed to contain newlines)
+ Returns:
+ A tuple of the modified text and a textual description of what is done.
+ Raises:
+ ValueError: if substitution source location does not have expected text.
+ """
+
+ change_report = ""
+
+ # Iterate of each line
+ for line, edits in self._line_to_edit.items():
+ offset = 0
+ # sort by column so that edits are processed in order in order to make
+ # indexing adjustments cumulative for changes that change the string
+ # length
+ edits.sort(key=lambda x: x.start)
+
+ # Extract each line to a list of characters, because mutable lists
+ # are editable, unlike immutable strings.
+ char_array = list(text[line - 1])
+
+ # Record a description of the change
+ change_report += "%r Line %d\n" % (self._filename, line)
+ change_report += "-" * 80 + "\n\n"
+ for e in edits:
+ change_report += "%s\n" % e.comment
+ change_report += "\n Old: %s" % (text[line - 1])
+
+ # Make underscore buffers for underlining where in the line the edit was
+ change_list = [" "] * len(text[line - 1])
+ change_list_new = [" "] * len(text[line - 1])
+
+ # Iterate for each edit
+ for e in edits:
+ # Create effective start, end by accounting for change in length due
+ # to previous edits
+ start_eff = e.start + offset
+ end_eff = start_eff + len(e.old)
+
+ # Make sure the edit is changing what it should be changing
+ old_actual = "".join(char_array[start_eff:end_eff])
+ if old_actual != e.old:
+ raise ValueError("Expected text %r but got %r" %
+ ("".join(e.old), "".join(old_actual)))
+ # Make the edit
+ char_array[start_eff:end_eff] = list(e.new)
+
+ # Create the underline highlighting of the before and after
+ change_list[e.start:e.start + len(e.old)] = "~" * len(e.old)
+ change_list_new[start_eff:end_eff] = "~" * len(e.new)
+
+ # Keep track of how to generate effective ranges
+ offset += len(e.new) - len(e.old)
+
+ # Finish the report comment
+ change_report += " %s\n" % "".join(change_list)
+ text[line - 1] = "".join(char_array)
+ change_report += " New: %s" % (text[line - 1])
+ change_report += " %s\n\n" % "".join(change_list_new)
+ return "".join(text), change_report, self._errors
+
+ def add(self, comment, line, start, old, new, error=None):
+ """Add a new change that is needed.
+
+ Args:
+ comment: A description of what was changed
+ line: Line number (1 indexed)
+ start: Column offset (0 indexed)
+ old: old text
+ new: new text
+ error: this "edit" is something that cannot be fixed automatically
+ Returns:
+ None
+ """
+
+ self._line_to_edit[line].append(
+ _FileEditTuple(comment, line, start, old, new))
+ if error:
+ self._errors.append("%s:%d: %s" % (self._filename, line, error))
+
+
+class _ASTCallVisitor(ast.NodeVisitor):
+ """AST Visitor that processes function calls.
+
+ Updates function calls from old API version to new API version using a given
+ change spec.
+ """
+
+ def __init__(self, filename, lines, api_change_spec):
+ self._filename = filename
+ self._file_edit = _FileEditRecorder(filename)
+ self._lines = lines
+ self._api_change_spec = api_change_spec
+
+ def process(self, lines):
+ return self._file_edit.process(lines)
+
+ def generic_visit(self, node):
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def _rename_functions(self, node, full_name):
+ function_renames = self._api_change_spec.function_renames
+ try:
+ new_name = function_renames[full_name]
+ self._file_edit.add("Renamed function %r to %r" % (full_name, new_name),
+ node.lineno, node.col_offset, full_name, new_name)
+ except KeyError:
+ pass
+
+ def _get_attribute_full_path(self, node):
+ """Traverse an attribute to generate a full name e.g. tf.foo.bar.
+
+ Args:
+ node: A Node of type Attribute.
+
+ Returns:
+ a '.'-delimited full-name or None if the tree was not a simple form.
+ i.e. `foo()+b).bar` returns None, while `a.b.c` would return "a.b.c".
+ """
+ curr = node
+ items = []
+ while not isinstance(curr, ast.Name):
+ if not isinstance(curr, ast.Attribute):
+ return None
+ items.append(curr.attr)
+ curr = curr.value
+ items.append(curr.id)
+ return ".".join(reversed(items))
+
+ def _find_true_position(self, node):
+ """Return correct line number and column offset for a given node.
+
+ This is necessary mainly because ListComp's location reporting reports
+ the next token after the list comprehension list opening.
+
+ Args:
+ node: Node for which we wish to know the lineno and col_offset
+ """
+ import re
+ find_open = re.compile("^\s*(\\[).*$")
+ find_string_chars = re.compile("['\"]")
+
+ if isinstance(node, ast.ListComp):
+ # Strangely, ast.ListComp returns the col_offset of the first token
+ # after the '[' token which appears to be a bug. Workaround by
+ # explicitly finding the real start of the list comprehension.
+ line = node.lineno
+ col = node.col_offset
+ # loop over lines
+ while 1:
+ # Reverse the text to and regular expression search for whitespace
+ text = self._lines[line - 1]
+ reversed_preceding_text = text[:col][::-1]
+ # First find if a [ can be found with only whitespace between it and
+ # col.
+ m = find_open.match(reversed_preceding_text)
+ if m:
+ new_col_offset = col - m.start(1) - 1
+ return line, new_col_offset
+ else:
+ if (reversed_preceding_text == "" or
+ reversed_preceding_text.isspace()):
+ line = line - 1
+ prev_line = self._lines[line - 1]
+ # TODO(aselle):
+ # this is poor comment detection, but it is good enough for
+ # cases where the comment does not contain string literal starting/
+ # ending characters. If ast gave us start and end locations of the
+ # ast nodes rather than just start, we could use string literal
+ # node ranges to filter out spurious #'s that appear in string
+ # literals.
+ comment_start = prev_line.find("#")
+ if comment_start == -1:
+ col = len(prev_line) - 1
+ elif find_string_chars.search(prev_line[comment_start:]) is None:
+ col = comment_start
+ else:
+ return None, None
+ else:
+ return None, None
+ # Most other nodes return proper locations (with notably does not), but
+ # it is not possible to use that in an argument.
+ return node.lineno, node.col_offset
+
+ def visit_Call(self, node): # pylint: disable=invalid-name
+ """Handle visiting a call node in the AST.
+
+ Args:
+ node: Current Node
+ """
+
+ # Find a simple attribute name path e.g. "tf.foo.bar"
+ full_name = self._get_attribute_full_path(node.func)
+
+ # Make sure the func is marked as being part of a call
+ node.func.is_function_for_call = True
+
+ if full_name:
+ # Call special handlers
+ function_handles = self._api_change_spec.function_handle
+ if full_name in function_handles:
+ function_handles[full_name](self._file_edit, node)
+
+ # Examine any non-keyword argument and make it into a keyword argument
+ # if reordering required.
+ function_reorders = self._api_change_spec.function_reorders
+ function_keyword_renames = (
+ self._api_change_spec.function_keyword_renames)
+
+ if full_name in function_reorders:
+ reordered = function_reorders[full_name]
+ for idx, arg in enumerate(node.args):
+ lineno, col_offset = self._find_true_position(arg)
+ if lineno is None or col_offset is None:
+ self._file_edit.add(
+ "Failed to add keyword %r to reordered function %r" %
+ (reordered[idx], full_name),
+ arg.lineno,
+ arg.col_offset,
+ "",
+ "",
+ error="A necessary keyword argument failed to be inserted.")
+ else:
+ keyword_arg = reordered[idx]
+ if (full_name in function_keyword_renames and
+ keyword_arg in function_keyword_renames[full_name]):
+ keyword_arg = function_keyword_renames[full_name][keyword_arg]
+ self._file_edit.add("Added keyword %r to reordered function %r" %
+ (reordered[idx], full_name), lineno, col_offset,
+ "", keyword_arg + "=")
+
+ # Examine each keyword argument and convert it to the final renamed form
+ renamed_keywords = ({} if full_name not in function_keyword_renames else
+ function_keyword_renames[full_name])
+ for keyword in node.keywords:
+ argkey = keyword.arg
+ argval = keyword.value
+
+ if argkey in renamed_keywords:
+ argval_lineno, argval_col_offset = self._find_true_position(argval)
+ if argval_lineno is not None and argval_col_offset is not None:
+ # TODO(aselle): We should scan backward to find the start of the
+ # keyword key. Unfortunately ast does not give you the location of
+ # keyword keys, so we are forced to infer it from the keyword arg
+ # value.
+ key_start = argval_col_offset - len(argkey) - 1
+ key_end = key_start + len(argkey) + 1
+ if (self._lines[argval_lineno - 1][key_start:key_end] == argkey +
+ "="):
+ self._file_edit.add("Renamed keyword argument from %r to %r" %
+ (argkey,
+ renamed_keywords[argkey]), argval_lineno,
+ argval_col_offset - len(argkey) - 1,
+ argkey + "=", renamed_keywords[argkey] + "=")
+ continue
+ self._file_edit.add(
+ "Failed to rename keyword argument from %r to %r" %
+ (argkey, renamed_keywords[argkey]),
+ argval.lineno,
+ argval.col_offset - len(argkey) - 1,
+ "",
+ "",
+ error="Failed to find keyword lexographically. Fix manually.")
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+ def visit_Attribute(self, node): # pylint: disable=invalid-name
+ """Handle bare Attributes i.e. [tf.foo, tf.bar].
+
+ Args:
+ node: Node that is of type ast.Attribute
+ """
+ full_name = self._get_attribute_full_path(node)
+ if full_name:
+ self._rename_functions(node, full_name)
+ if full_name in self._api_change_spec.change_to_function:
+ if not hasattr(node, "is_function_for_call"):
+ new_text = full_name + "()"
+ self._file_edit.add("Changed %r to %r" % (full_name, new_text),
+ node.lineno, node.col_offset, full_name, new_text)
+
+ ast.NodeVisitor.generic_visit(self, node)
+
+
+class ASTCodeUpgrader(object):
+ """Handles upgrading a set of Python files using a given API change spec."""
+
+ def __init__(self, api_change_spec):
+ if not isinstance(api_change_spec, APIChangeSpec):
+ raise TypeError("Must pass APIChangeSpec to ASTCodeUpgrader, got %s" %
+ type(api_change_spec))
+ self._api_change_spec = api_change_spec
+
+ def process_file(self, in_filename, out_filename):
+ """Process the given python file for incompatible changes.
+
+ Args:
+ in_filename: filename to parse
+ out_filename: output file to write to
+ Returns:
+ A tuple representing number of files processed, log of actions, errors
+ """
+
+ # Write to a temporary file, just in case we are doing an implace modify.
+ with open(in_filename, "r") as in_file, \
+ tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
+ ret = self.process_opened_file(in_filename, in_file, out_filename,
+ temp_file)
+
+ shutil.move(temp_file.name, out_filename)
+ return ret
+
+ # Broad exceptions are required here because ast throws whatever it wants.
+ # pylint: disable=broad-except
+ def process_opened_file(self, in_filename, in_file, out_filename, out_file):
+ """Process the given python file for incompatible changes.
+
+ This function is split out to facilitate StringIO testing from
+ tf_upgrade_test.py.
+
+ Args:
+ in_filename: filename to parse
+ in_file: opened file (or StringIO)
+ out_filename: output file to write to
+ out_file: opened file (or StringIO)
+ Returns:
+ A tuple representing number of files processed, log of actions, errors
+ """
+ process_errors = []
+ text = "-" * 80 + "\n"
+ text += "Processing file %r\n outputting to %r\n" % (in_filename,
+ out_filename)
+ text += "-" * 80 + "\n\n"
+
+ parsed_ast = None
+ lines = in_file.readlines()
+ try:
+ parsed_ast = ast.parse("".join(lines))
+ except Exception:
+ text += "Failed to parse %r\n\n" % in_filename
+ text += traceback.format_exc()
+ if parsed_ast:
+ visitor = _ASTCallVisitor(in_filename, lines, self._api_change_spec)
+ visitor.visit(parsed_ast)
+ out_text, new_text, process_errors = visitor.process(lines)
+ text += new_text
+ if out_file:
+ out_file.write(out_text)
+ text += "\n"
+ return 1, text, process_errors
+
+ # pylint: enable=broad-except
+
+ def process_tree(self, root_directory, output_root_directory,
+ copy_other_files):
+ """Processes upgrades on an entire tree of python files in place.
+
+ Note that only Python files. If you have custom code in other languages,
+ you will need to manually upgrade those.
+
+ Args:
+ root_directory: Directory to walk and process.
+ output_root_directory: Directory to use as base.
+ copy_other_files: Copy files that are not touched by this converter.
+
+ Returns:
+ A tuple of files processed, the report string ofr all files, and errors
+ """
+
+ # make sure output directory doesn't exist
+ if output_root_directory and os.path.exists(output_root_directory):
+ print("Output directory %r must not already exist." %
+ (output_root_directory))
+ sys.exit(1)
+
+ # make sure output directory does not overlap with root_directory
+ norm_root = os.path.split(os.path.normpath(root_directory))
+ norm_output = os.path.split(os.path.normpath(output_root_directory))
+ if norm_root == norm_output:
+ print("Output directory %r same as input directory %r" %
+ (root_directory, output_root_directory))
+ sys.exit(1)
+
+ # Collect list of files to process (we do this to correctly handle if the
+ # user puts the output directory in some sub directory of the input dir)
+ files_to_process = []
+ files_to_copy = []
+ for dir_name, _, file_list in os.walk(root_directory):
+ py_files = [f for f in file_list if f.endswith(".py")]
+ copy_files = [f for f in file_list if not f.endswith(".py")]
+ for filename in py_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(output_root_directory,
+ os.path.relpath(fullpath,
+ root_directory))
+ files_to_process.append((fullpath, fullpath_output))
+ if copy_other_files:
+ for filename in copy_files:
+ fullpath = os.path.join(dir_name, filename)
+ fullpath_output = os.path.join(output_root_directory,
+ os.path.relpath(
+ fullpath, root_directory))
+ files_to_copy.append((fullpath, fullpath_output))
+
+ file_count = 0
+ tree_errors = []
+ report = ""
+ report += ("=" * 80) + "\n"
+ report += "Input tree: %r\n" % root_directory
+ report += ("=" * 80) + "\n"
+
+ for input_path, output_path in files_to_process:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ file_count += 1
+ _, l_report, l_errors = self.process_file(input_path, output_path)
+ tree_errors += l_errors
+ report += l_report
+ for input_path, output_path in files_to_copy:
+ output_directory = os.path.dirname(output_path)
+ if not os.path.isdir(output_directory):
+ os.makedirs(output_directory)
+ shutil.copy(input_path, output_path)
+ return file_count, report, tree_errors
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 57a491255e..fd94d64268 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -63,7 +63,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.11.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 204b5b4dba..5ec43b8cb8 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -72,7 +72,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc
RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
>>/etc/bazel.bazelrc
# Install the most recent bazel release.
-ENV BAZEL_VERSION 0.11.0
+ENV BAZEL_VERSION 0.14.1
WORKDIR /
RUN mkdir /bazel && \
cd /bazel && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7 b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
new file mode 100644
index 0000000000..3bedc8cf34
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu-cuda9-cudnn7
@@ -0,0 +1,115 @@
+FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04
+
+LABEL maintainer="Gunhan Gulsoy <gunan@google.com>"
+
+# It is possible to override these for releases.
+ARG TF_BRANCH=master
+ARG BAZEL_VERSION=0.5.4
+ARG TF_AVAILABLE_CPUS=32
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ golang \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ python-pip \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ wget \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN pip --no-cache-dir install --upgrade \
+ pip setuptools
+
+RUN pip --no-cache-dir install \
+ ipykernel \
+ jupyter \
+ matplotlib \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ wheel \
+ && \
+ python -m ipykernel.kernelspec
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+# Set up Bazel.
+
+# Running bazel inside a `docker build` command causes trouble, cf:
+# https://github.com/bazelbuild/bazel/issues/134
+# The easiest solution is to set up a bazelrc file forcing --batch.
+RUN echo "startup --batch" >>/etc/bazel.bazelrc
+# Similarly, we need to workaround sandboxing issues:
+# https://github.com/bazelbuild/bazel/issues/418
+RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
+ >>/etc/bazel.bazelrc
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ wget --quiet https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ wget --quiet https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
+ chmod +x bazel-*.sh && \
+ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+
+# Download and build TensorFlow.
+WORKDIR /
+RUN git clone https://github.com/tensorflow/tensorflow.git && \
+ cd tensorflow && \
+ git checkout ${TF_BRANCH}
+WORKDIR /tensorflow
+
+# Configure the build for our CUDA configuration.
+ENV CI_BUILD_PYTHON=python \
+ LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH} \
+ CUDNN_INSTALL_PATH=/usr/lib/x86_64-linux-gnu \
+ PYTHON_BIN_PATH=/usr/bin/python \
+ PYTHON_LIB_PATH=/usr/local/lib/python2.7/dist-packages \
+ TF_NEED_CUDA=1 \
+ TF_CUDA_VERSION=9.0 \
+ TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0,6.1,7.0 \
+ TF_CUDNN_VERSION=7
+RUN ./configure
+
+# Build and Install TensorFlow.
+RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
+ LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs:${LD_LIBRARY_PATH} \
+ bazel build -c opt \
+ --config=cuda \
+ --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
+ --jobs=${TF_AVAILABLE_CPUS} \
+ tensorflow/tools/pip_package:build_pip_package && \
+ mkdir /pip_pkg && \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package /pip_pkg && \
+ pip --no-cache-dir install --upgrade /pip_pkg/tensorflow-*.whl && \
+ rm -rf /pip_pkg && \
+ rm -rf /root/.cache
+# Clean up pip wheel and Bazel cache when done.
+
+WORKDIR /root
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
new file mode 100755
index 0000000000..c85641b383
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -0,0 +1,128 @@
+FROM ubuntu:16.04
+
+LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
+
+# These parameters can be overridden by parameterized_docker_build.sh
+ARG TF_BUILD_VERSION=r1.9
+ARG PYTHON="python"
+ARG PYTHON3_DEV=""
+ARG WHL_DIR="/tmp/pip"
+ARG PIP="pip"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python-dev \
+ ${PYTHON3_DEV} \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
+ ${PYTHON} get-pip.py && \
+ rm get-pip.py
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ ipykernel \
+ jupyter \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && \
+ ${PYTHON} -m ipykernel.kernelspec
+
+RUN if [ "${PYTHON}" = "python3" ]; then \
+ ln -s -f /usr/bin/python3 /usr/bin/python; \
+ fi
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+# Set up Bazel.
+
+# Running bazel inside a `docker build` command causes trouble, cf:
+# https://github.com/bazelbuild/bazel/issues/134
+# The easiest solution is to set up a bazelrc file forcing --batch.
+RUN echo "startup --batch" >>/etc/bazel.bazelrc
+# Similarly, we need to workaround sandboxing issues:
+# https://github.com/bazelbuild/bazel/issues/418
+RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \
+ >>/etc/bazel.bazelrc
+# Install the most recent bazel release.
+ENV BAZEL_VERSION 0.14.1
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
+ chmod +x bazel-*.sh && \
+ ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+ cd / && \
+ rm -f /bazel/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh
+
+# Download and build TensorFlow.
+WORKDIR /tensorflow
+
+# Download and build TensorFlow.
+# Enable checking out both tags and branches
+RUN export TAG_PREFIX="v" && \
+ echo ${TF_BUILD_VERSION} | grep -q ^${TAG_PREFIX}; \
+ if [ $? -eq 0 ]; then \
+ git clone --depth=1 https://github.com/tensorflow/tensorflow.git . && \
+ git fetch --tags && \
+ git checkout ${TF_BUILD_VERSION}; \
+ else \
+ git clone --depth=1 --branch=${TF_BUILD_VERSION} https://github.com/tensorflow/tensorflow.git . ; \
+ fi
+
+RUN yes "" | ${PYTHON} configure.py
+
+ENV CI_BUILD_PYTHON ${PYTHON}
+
+# Set bazel build parameters in .bazelrc in parameterized_docker_build.sh
+# Use --copt=-march values to get optimized builds appropriate for the hardware
+# platform of your choice.
+# For ivy-bridge or sandy-bridge
+# --copt=-march="avx" \
+# For haswell, broadwell, or skylake
+# --copt=-march="avx2" \
+COPY .bazelrc /root/.bazelrc
+
+RUN tensorflow/tools/ci_build/builds/configured CPU \
+ bazel --bazelrc=/root/.bazelrc build -c opt \
+ tensorflow/tools/pip_package:build_pip_package && \
+ bazel-bin/tensorflow/tools/pip_package/build_pip_package "${WHL_DIR}" && \
+ ${PIP} --no-cache-dir install --upgrade "${WHL_DIR}"/tensorflow-*.whl && \
+ rm -rf /root/.cache
+# Clean up Bazel cache when done.
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
+
+WORKDIR /root
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
new file mode 100755
index 0000000000..139395d491
--- /dev/null
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -0,0 +1,75 @@
+FROM ubuntu:16.04
+
+LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
+
+# This parameter MUST be set by parameterized_docker_build.sh
+ARG TF_WHL_URL
+
+# Optional parameters
+ARG TF_BUILD_VERSION=r1.9
+ARG PYTHON="python"
+ARG PYTHON_DEV="python-dev"
+ARG PIP="pip"
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libpng12-dev \
+ libzmq3-dev \
+ pkg-config \
+ python \
+ ${PYTHON_DEV} \
+ rsync \
+ software-properties-common \
+ unzip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ ipykernel \
+ jupyter \
+ matplotlib \
+ numpy \
+ pandas \
+ scipy \
+ sklearn \
+ && \
+ python -m ipykernel.kernelspec
+
+COPY ${TF_WHL_URL} /
+RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \
+ rm -rf /${TF_WHL_URL}
+
+RUN if [ "${PYTHON}" = "python3" ]; then \
+ ln -s -f /usr/bin/python3 /usr/bin/python; \
+ fi
+
+# Set up our notebook config.
+COPY jupyter_notebook_config.py /root/.jupyter/
+
+# Copy sample notebooks.
+COPY notebooks /notebooks
+
+# Jupyter has issues with being run directly:
+# https://github.com/ipython/ipython/issues/7062
+# We just add a little wrapper script.
+COPY run_jupyter.sh /
+
+# TensorBoard
+EXPOSE 6006
+# IPython
+EXPOSE 8888
+
+WORKDIR "/notebooks"
+
+CMD ["/run_jupyter.sh", "--allow-root"]
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index 05de25f2cb..4681c5fd61 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -19,8 +19,8 @@
# parameterized_docker_build.sh
#
# The script obeys the following environment variables:
-# TF_DOCKER_BUILD_TYPE: (CPU | GPU)
-# CPU or GPU image
+# TF_DOCKER_BUILD_TYPE: (CPU | GPU | MKL)
+# CPU, GPU, or MKL image
#
# TF_DOCKER_BUILD_IS_DEVEL: (NO | YES)
# Is this developer image
@@ -87,6 +87,15 @@
# TF_DOCKER_BUILD_OPTIONS
# (Optional)
# Specifies the desired build options. Defaults to OPT.
+#
+# TF_DOCKER_BUILD_ARGS
+# (Optional)
+# A list (array) of docker build args. Will be passed to docker build
+# command as list of --build-arg parameters.
+#
+# TF_BAZEL_BUILD_OPTIONS
+# (Optional)
+# Bazel compiler flags to be passed to the bazelrc file
# Script directory
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
@@ -116,6 +125,8 @@ echo " TF_DOCKER_BUILD_IMAGE_NAME=${TF_DOCKER_BUILD_IMAGE_NAME}"
echo " TF_DOCKER_BUILD_VERSION=${TF_DOCKER_BUILD_VERSION}"
echo " TF_DOCKER_BUILD_PORT=${TF_DOCKER_BUILD_PORT}"
echo " TF_DOCKER_BUILD_PUSH_CMD=${TF_DOCKER_BUILD_PUSH_CMD}"
+echo " TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]:-()}"
+echo " TF_BAZEL_BUILD_OPTIONS=${TF_BAZEL_BUILD_OPTIONS}"
CONTAINER_PORT=${TF_DOCKER_BUILD_PORT:-8888}
@@ -149,6 +160,15 @@ fi
if [[ ${TF_DOCKER_BUILD_TYPE} == "cpu" ]]; then
DOCKER_BINARY="docker"
+elif [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ DOCKER_BINARY="docker"
+ FINAL_TAG="${FINAL_TAG}-mkl"
+ if [[ ${ORIG_DOCKERFILE} == *"."* ]]; then
+ # There is already a dot in the tag, use "-"
+ ORIG_DOCKERFILE="${ORIG_DOCKERFILE}-mkl"
+ else
+ ORIG_DOCKERFILE="${ORIG_DOCKERFILE}.mkl"
+ fi
elif [[ ${TF_DOCKER_BUILD_TYPE} == "gpu" ]]; then
DOCKER_BINARY="nvidia-docker"
@@ -203,6 +223,10 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
export TF_BUILD_OPTIONS=${TF_DOCKER_BUILD_OPTIONS}
export TF_BUILD_IS_PIP="PIP"
+ if [[ "${TF_DOCKER_BUILD_TYPE}" == "mkl" ]]; then
+ die "FAIL: Non-development MKL builds require a pre-built pip whl."
+ fi
+
if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then
export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\
"${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2"
@@ -255,25 +279,39 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
# Use string replacement to put the correct file name into the Dockerfile
PIP_WHL=$(basename "${PIP_WHL}")
- # Modify the non-devel Dockerfile to point to the correct pip whl file
- # location
- sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}" )
+ cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
+ else
+ # Modify the non-devel Dockerfile to point to the correct pip whl file
+ # location
+ sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\
"/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\
"COPY ${PIP_WHL} /\n"\
"RUN pip --no-cache-dir install /${PIP_WHL}" "${ORIG_DOCKERFILE}" \
- > "${DOCKERFILE}"
+ > "${DOCKERFILE}"
+ fi
echo "Using local pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}"
echo
-
else
echo "Downloading pip wheel from: ${TF_DOCKER_BUILD_CENTRAL_PIP}"
- echo
-
- # Modify the non-devel Dockerfile to point to the correct pip whl URL.
- sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ pushd "${TMP_DIR}/"
+ curl -O ${TF_DOCKER_BUILD_CENTRAL_PIP}
+ popd
+ PIP_WHL_PATH=`find ${TMP_DIR} -name "*.whl"`
+ PIP_WHL=$(basename "${PIP_WHL_PATH}")
+ echo "PIP_WHL= ${PIP_WHL}"
+ echo
+ TF_DOCKER_BUILD_ARGS+=("--build-arg TF_WHL_URL=${PIP_WHL}")
+ cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
+ else
+ # Modify the non-devel Dockerfile to point to the correct pip whl URL.
+ sed -e "/# --- DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/,"\
"/# --- ~ DO NOT EDIT OR DELETE BETWEEN THE LINES --- #/c"\
"RUN pip --no-cache-dir install ${TF_DOCKER_BUILD_CENTRAL_PIP}" "${ORIG_DOCKERFILE}" \
- > "${DOCKERFILE}"
+ > "${DOCKERFILE}"
+ fi
fi
echo "Modified Dockerfile at: ${DOCKERFILE}"
@@ -281,36 +319,66 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
# Modify python/pip version if necessary.
if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then
- if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
- sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \
- sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
- sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
- then
- echo "Modified Dockerfile for python version "\
-"${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}")
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON_DEV=python3-dev")
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3")
+ cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
else
- die "FAILED to modify ${DOCKERFILE} for python3"
+ if sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
+ sed -i -e 's/python-dev/python3-dev/g' "${DOCKERFILE}" && \
+ sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
+ sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
+ then
+ echo "Modified Dockerfile for python version "\
+ "${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
+ else
+ die "FAILED to modify ${DOCKERFILE} for python3"
+ fi
fi
fi
-else
+else # TF_DOCKER_BUILD_IS_DEVEL == 'yes'
DOCKERFILE="${TMP_DIR}/Dockerfile"
- # Modify the devel Dockerfile to specify the git branch
- sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \
- "${ORIG_DOCKERFILE}" > "${DOCKERFILE}"
+ # Set up Dockerfile ARGS for mkl build
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ if [[ -z "${TF_BAZEL_BUILD_OPTIONS// }" ]]; then
+ TF_BAZEL_BUILD_OPTIONS=("--config=mkl --copt=-mavx --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0")
+ else
+ TF_BAZEL_BUILD_OPTIONS="${TF_BAZEL_BUILD_OPTIONS}"
+ fi
+ TF_DOCKER_BUILD_ARGS+=("--build-arg TF_BUILD_VERSION=${TF_DOCKER_BUILD_DEVEL_BRANCH}")
+ echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}"
+
+ # Pass the build options to bazel using the user-specific .bazelrc file
+ echo "build ${TF_BAZEL_BUILD_OPTIONS}" >> ${TMP_DIR}/.bazelrc
+ cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
+ else
+ # Modify the devel Dockerfile to specify the git branch
+ sed "s/^RUN git clone --branch=.* --depth=1/RUN git clone --branch=${TF_DOCKER_BUILD_DEVEL_BRANCH} --depth=1/" \
+ "${ORIG_DOCKERFILE}" > "${DOCKERFILE}"
+ fi
# Modify python/pip version if necessary.
if [[ "${TF_DOCKER_BUILD_PYTHON_VERSION}" == "python3" ]]; then
- if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \
- sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
- sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \
- sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
- sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \
- sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
- then
- echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
+ if [[ ${TF_DOCKER_BUILD_TYPE} == "mkl" ]]; then
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON=${TF_DOCKER_BUILD_PYTHON_VERSION}")
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PYTHON3_DEV=python3-dev")
+ TF_DOCKER_BUILD_ARGS+=("--build-arg WHL_DIR=/tmp/pip3")
+ TF_DOCKER_BUILD_ARGS+=("--build-arg PIP=pip3")
+ cp "${ORIG_DOCKERFILE}" "${DOCKERFILE}"
else
- die "FAILED to modify ${DOCKERFILE} for python3"
+ if sed -i -e 's/python-dev/python-dev python3-dev/g' "${DOCKERFILE}" && \
+ sed -i -e 's/python /python3 /g' "${DOCKERFILE}" && \
+ sed -i -e 's^/tmp/pip^/tmp/pip3^g' "${DOCKERFILE}" && \
+ sed -i -e 's/pip /pip3 /g' "${DOCKERFILE}" && \
+ sed -i -e 's/ENV CI_BUILD_PYTHON python/ENV CI_BUILD_PYTHON python3/g' "${DOCKERFILE}" && \
+ sed -i -e 's^# RUN ln -s -f /usr/bin/python3 /usr/bin/python#^RUN ln -s -f /usr/bin/python3 /usr/bin/python^' "${DOCKERFILE}"
+ then
+ echo "Modified Dockerfile further for python version ${TF_DOCKER_BUILD_PYTHON_VERSION} at: ${DOCKERFILE}"
+ else
+ die "FAILED to modify ${DOCKERFILE} for python3"
+ fi
fi
fi
fi
@@ -319,8 +387,11 @@ fi
# Intermediate image name with tag
IMG="${USER}/tensorflow:${FINAL_TAG}"
echo "Building docker image with image name and tag: ${IMG}"
+echo "TF_DOCKER_BUILD_ARGS=${TF_DOCKER_BUILD_ARGS[@]}"
+CMD="${DOCKER_BINARY} build ${TF_DOCKER_BUILD_ARGS[@]} --no-cache --pull -t ${IMG} -f ${DOCKERFILE} ${TMP_DIR}"
+echo "CMD=${CMD}"
+${CMD}
-"${DOCKER_BINARY}" build --no-cache --pull -t "${IMG}" -f "${DOCKERFILE}" "${TMP_DIR}"
if [[ $? == "0" ]]; then
echo "${DOCKER_BINARY} build of ${IMG} succeeded"
else
@@ -340,7 +411,7 @@ fi
DOCKER_RUN_LOG="${TMP_DIR}/docker_run.log"
echo ""
echo "Running docker container from image ${IMG}..."
-echo " (Log file is at: ${DOCKER_RUN_LOG}"
+echo " Log file is at: ${DOCKER_RUN_LOG}"
echo ""
if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
@@ -386,7 +457,6 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
# Stop the running docker container
sleep 1
"${DOCKER_BINARY}" stop --time=0 ${CONTAINER_ID}
-
fi
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 58b5ef8345..2403e2d966 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -37,7 +37,11 @@ py_library(
srcs = ["parser.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = ["@astor_archive//:astor"],
+ deps = [
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "@astor_archive//:astor",
+ ],
)
py_test(
@@ -92,6 +96,7 @@ py_binary(
deps = [
":generate_lib",
"//tensorflow:tensorflow_py",
+ "//tensorflow/python:util",
"//tensorflow/python/debug:debug_py",
],
)
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 853ec6194f..e7634cd5dc 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import argparse
import fnmatch
import os
+import shutil
import six
@@ -81,12 +82,8 @@ def write_docs(output_dir,
raise ValueError("'output_dir' must be an absolute path.\n"
" output_dir='%s'" % output_dir)
- try:
- if not os.path.exists(output_dir):
- os.makedirs(output_dir)
- except OSError as e:
- print('Creating output dir "%s" failed: %s' % (output_dir, e))
- raise
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
# These dictionaries are used for table-of-contents generation below
# They will contain, after the for-loop below::
@@ -129,8 +126,6 @@ def write_docs(output_dir,
module_children.setdefault(subname, []).append(full_name)
break
- print('Writing docs for %s (%r).' % (full_name, py_object))
-
# Generate docs for `py_object`, resolving references.
page_info = parser.docs_for_object(full_name, py_object, parser_config)
@@ -151,10 +146,9 @@ def write_docs(output_dir,
text = text.encode('utf-8')
with open(path, 'wb') as f:
f.write(text)
- except OSError as e:
- print('Cannot write documentation for %s to %s: %s' % (full_name,
- directory, e))
- raise
+ except OSError:
+ raise OSError(
+ 'Cannot write documentation for %s to %s' % (full_name, directory))
if yaml_toc:
# Generate table of contents
@@ -394,16 +388,40 @@ def _build_guide_index(guide_src_dir):
class _UpdateTags(py_guide_parser.PyGuideParser):
- """Rewrites a Python guide so that each section has an explicit tag."""
+ """Rewrites a Python guide so that each section has an explicit id tag.
+
+ "section" here refers to blocks delimited by second level headings.
+ """
def process_section(self, line_number, section_title, tag):
self.replace_line(line_number, '<h2 id="%s">%s</h2>' % (tag, section_title))
+def update_id_tags_inplace(src_dir):
+ """Set explicit ids on all second-level headings to ensure back-links work.
+
+ Args:
+ src_dir: The directory of md-files to convert (inplace).
+ """
+ tag_updater = _UpdateTags()
+
+ for dirpath, _, filenames in os.walk(src_dir):
+ for base_name in filenames:
+ if not base_name.endswith('.md'):
+ continue
+ full_path = os.path.join(src_dir, dirpath, base_name)
+
+ # Tag updater loads the file, makes the replacements, and returns the
+ # modified file contents
+ content = tag_updater.process(full_path)
+ with open(full_path, 'w') as f:
+ f.write(content)
+
+
EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
-def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
+def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
"""Fix @{} references in all files under `src_dir` matching `file_pattern`.
A matching directory structure, with the modified files is
@@ -424,7 +442,6 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
using fnmatch. Non-matching files are copied unchanged.
"""
# Iterate through all the source files and process them.
- tag_updater = _UpdateTags()
for dirpath, _, filenames in os.walk(src_dir):
# How to get from `dirpath` to api_docs/python/
relative_path_to_root = os.path.relpath(
@@ -433,41 +450,32 @@ def _other_docs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
# Make the directory under output_dir.
new_dir = os.path.join(output_dir,
os.path.relpath(path=dirpath, start=src_dir))
- try:
- if not os.path.exists(new_dir):
- os.makedirs(new_dir)
- except OSError as e:
- print('Creating output dir "%s" failed: %s' % (new_dir, e))
- raise
+ if not os.path.exists(new_dir):
+ os.makedirs(new_dir)
for base_name in filenames:
if base_name in EXCLUDED:
- print('Skipping excluded file %s...' % base_name)
continue
full_in_path = os.path.join(dirpath, base_name)
+ # Set the `current_doc_full_name` so bad files can be reported on errors.
reference_resolver.current_doc_full_name = full_in_path
suffix = os.path.relpath(path=full_in_path, start=src_dir)
full_out_path = os.path.join(output_dir, suffix)
+ # Copy files that do not match the file_pattern, unmodified.
if not fnmatch.fnmatch(base_name, file_pattern):
- print('Copying un-matched file %s...' % suffix)
- open(full_out_path, 'wb').write(open(full_in_path, 'rb').read())
+ shutil.copyfile(full_in_path, full_out_path)
continue
- if dirpath.endswith('/api_guides/python'):
- print('Processing Python guide %s...' % base_name)
- content = tag_updater.process(full_in_path)
- else:
- print('Processing doc %s...' % suffix)
- content = open(full_in_path, 'rb').read().decode('utf-8')
+
+ with open(full_in_path, 'rb') as f:
+ content = f.read().decode('utf-8')
content = reference_resolver.replace_references(content,
relative_path_to_root)
with open(full_out_path, 'wb') as f:
f.write(content.encode('utf-8'))
- print('Done.')
-
class DocGenerator(object):
"""Main entry point for generating docs."""
@@ -554,15 +562,43 @@ class DocGenerator(object):
self._do_not_descend_map)
def build(self, flags):
- """Actually build the docs."""
+ """Build all the docs.
+
+ This produces two outputs
+
+ python api docs:
+
+ * generated from modules set with `set_py_modules`.
+ * written to '{FLAGS.output_dir}/api_docs/python/'
+
+ non-api docs:
+
+ * Everything in '{FLAGS.src_dir}' is copied to '{FLAGS.output_dir}'.
+ * '@{}' references in '.md' files are replaced with links.
+ * '.md' files under 'api_guides/python' have explicit ids set for their
+ second level headings.
+
+ Args:
+ flags:
+ * src_dir: Where to fetch the non-api-docs.
+ * base_dir: Base of the docs directory (Used to build correct
+ relative links).
+ * output_dir: Where to write the resulting docs.
+
+ Returns:
+ The number of errors encountered while processing.
+ """
+ # Extract the python api from the _py_modules
doc_index = build_doc_index(flags.src_dir)
visitor = self.run_extraction()
reference_resolver = self.make_reference_resolver(visitor, doc_index)
+ # Build the guide_index for the api_docs back links.
root_title = getattr(flags, 'root_title', 'TensorFlow')
guide_index = _build_guide_index(
os.path.join(flags.src_dir, 'api_guides/python'))
+ # Write the api docs.
parser_config = self.make_parser_config(visitor, reference_resolver,
guide_index, flags.base_dir)
output_dir = os.path.join(flags.output_dir, 'api_docs/python')
@@ -573,8 +609,16 @@ class DocGenerator(object):
yaml_toc=self.yaml_toc,
root_title=root_title,
search_hints=getattr(flags, 'search_hints', True))
- _other_docs(flags.src_dir, flags.output_dir, reference_resolver)
+ # Replace all the @{} references in files under `FLAGS.src_dir`
+ replace_refs(flags.src_dir, flags.output_dir, reference_resolver, '*.md')
+ # Fix the tags in the guide dir.
+ guide_dir = os.path.join(flags.output_dir, 'api_guides/python')
+ if os.path.exists(guide_dir):
+ update_id_tags_inplace(guide_dir)
+
+ # Report all errors found by the reference resolver, and return the error
+ # code.
parser_config.reference_resolver.log_errors()
return parser_config.reference_resolver.num_errors()
diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py
index ea6d28a02b..7a6f9fd9f7 100644
--- a/tensorflow/tools/docs/generate_lib_test.py
+++ b/tensorflow/tools/docs/generate_lib_test.py
@@ -51,7 +51,9 @@ class DummyVisitor(object):
class GenerateTest(googletest.TestCase):
- def test_write(self):
+ def get_test_objects(self):
+ # These are all mutable objects, so rebuild them for each test.
+ # Don't cache the objects.
module = sys.modules[__name__]
index = {
@@ -98,6 +100,11 @@ class GenerateTest(googletest.TestCase):
guide_index={},
base_dir=base_dir)
+ return reference_resolver, parser_config
+
+ def test_write(self):
+ _, parser_config = self.get_test_objects()
+
output_dir = googletest.GetTempDir()
generate_lib.write_docs(output_dir, parser_config, yaml_toc=True)
@@ -127,6 +134,107 @@ class GenerateTest(googletest.TestCase):
os.path.exists(
os.path.join(output_dir, 'tf/TestModule/test_function.md')))
+ def test_update_id_tags_inplace(self):
+ test_dir = googletest.GetTempDir()
+ test_sub_dir = os.path.join(test_dir, 'a/b')
+ os.makedirs(test_sub_dir)
+
+ test_path1 = os.path.join(test_dir, 'file1.md')
+ test_path2 = os.path.join(test_sub_dir, 'file2.md')
+ test_path3 = os.path.join(test_sub_dir, 'file3.notmd')
+
+ with open(test_path1, 'w') as f:
+ f.write('## abc&123')
+
+ with open(test_path2, 'w') as f:
+ f.write('# A Level 1 Heading\n')
+ f.write('## A Level 2 Heading')
+
+ with open(test_path3, 'w') as f:
+ f.write("## don\'t change this")
+
+ generate_lib.update_id_tags_inplace(test_dir)
+
+ with open(test_path1) as f:
+ content = f.read()
+
+ self.assertEqual(content, '<h2 id="abc_123">abc&123</h2>')
+
+ with open(test_path2) as f:
+ content = f.read()
+
+ self.assertEqual(
+ content, '# A Level 1 Heading\n'
+ '<h2 id="A_Level_2_Heading">A Level 2 Heading</h2>')
+
+ with open(test_path3) as f:
+ content = f.read()
+
+ self.assertEqual(content, "## don\'t change this")
+
+ def test_replace_refes(self):
+ test_dir = googletest.GetTempDir()
+ test_in_dir = os.path.join(test_dir, 'in')
+ test_in_dir_a = os.path.join(test_dir, 'in/a')
+ test_in_dir_b = os.path.join(test_dir, 'in/b')
+ os.makedirs(test_in_dir)
+ os.makedirs(test_in_dir_a)
+ os.makedirs(test_in_dir_b)
+
+ test_out_dir = os.path.join(test_dir, 'out')
+ os.makedirs(test_out_dir)
+
+ test_path1 = os.path.join(test_in_dir_a, 'file1.md')
+ test_path2 = os.path.join(test_in_dir_b, 'file2.md')
+ test_path3 = os.path.join(test_in_dir_b, 'file3.notmd')
+ test_path4 = os.path.join(test_in_dir_b, 'OWNERS')
+
+ with open(test_path1, 'w') as f:
+ f.write('Use `tf.test_function` to test things.')
+
+ with open(test_path2, 'w') as f:
+ f.write('Use @{tf.TestModule.TestClass.ChildClass} to test things.\n'
+ "`tf.whatever` doesn't exist")
+
+ with open(test_path3, 'w') as f:
+ file3_content = (
+ 'Not a .md file. Should be copied unchanged:'
+ '@{tf.TestModule.TestClass.ChildClass}, `tf.test_function`')
+ f.write(file3_content)
+
+ with open(test_path4, 'w') as f:
+ f.write('')
+
+ reference_resolver, _ = self.get_test_objects()
+ generate_lib.replace_refs(test_in_dir, test_out_dir, reference_resolver,
+ '*.md')
+
+ with open(os.path.join(test_out_dir, 'a/file1.md')) as f:
+ content = f.read()
+ self.assertEqual(
+ content,
+ 'Use <a href="../api_docs/python/tf/TestModule/test_function.md">'
+ '<code>tf.test_function</code></a> to test things.')
+
+ with open(os.path.join(test_out_dir, 'b/file2.md')) as f:
+ content = f.read()
+ self.assertEqual(
+ content,
+ 'Use '
+ '<a href="../api_docs/python/tf/TestModule/TestClass/ChildClass.md">'
+ '<code>tf.TestModule.TestClass.ChildClass</code></a> '
+ 'to test things.\n'
+ '`tf.whatever` doesn\'t exist')
+
+ with open(os.path.join(test_out_dir, 'b/file3.notmd')) as f:
+ content = f.read()
+ self.assertEqual(content, file3_content)
+
+ with self.assertRaises(IOError):
+ # This should fail. The OWNERS file should not be copied
+ with open(os.path.join(test_out_dir, 'b/OWNERS')) as f:
+ content = f.read()
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 50c9052741..ffb93027ed 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -25,12 +25,12 @@ import itertools
import json
import os
import re
-import sys
import astor
import six
from google.protobuf.message import Message as ProtoMessage
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
@@ -53,7 +53,7 @@ class _Errors(object):
template = 'ERROR:\n output file name: %s\n %s\n\n'
for full_name, message in self._errors:
- print(template % (full_name, message), file=sys.stderr)
+ logging.warn(template, full_name, message)
def append(self, full_name, message):
"""Add an error to the collection.
@@ -761,8 +761,9 @@ def _generate_signature(func, reverse_index):
lookup_text = public_name + default_text[len(internal_name):]
break
if default_text is lookup_text:
- print('WARNING: Using default arg, failed lookup: %s, repr: %r' %
- (default_text, default))
+ logging.warn(
+ 'WARNING: Using default arg, failed lookup: %s, repr: %r',
+ default_text, default)
else:
default_text = lookup_text
else:
@@ -1165,7 +1166,7 @@ class _ClassPageInfo(object):
if short_name in [
'__class__', '__base__', '__weakref__', '__doc__', '__module__',
'__dict__', '__abstractmethods__', '__slots__', '__getnewargs__',
- '__str__', '__repr__', '__hash__'
+ '__str__', '__repr__', '__hash__', '__reduce__'
]:
continue
@@ -1213,8 +1214,6 @@ class _ClassPageInfo(object):
if not child_doc.brief.strip() and short_name in [
'__del__', '__copy__'
]:
- print('Skipping %s, defined in %s, no docstring.' % (child_name,
- defining_class))
continue
try:
@@ -1371,7 +1370,8 @@ class _ModulePageInfo(object):
for name in member_names:
if name in ['__builtins__', '__doc__', '__file__',
- '__name__', '__path__', '__package__']:
+ '__name__', '__path__', '__package__',
+ '__cached__', '__loader__', '__spec__']:
continue
member_full_name = self.full_name + '.' + name if self.full_name else name
diff --git a/tensorflow/tools/docs/py_guide_parser.py b/tensorflow/tools/docs/py_guide_parser.py
index 328f42d18f..b00694dc40 100644
--- a/tensorflow/tools/docs/py_guide_parser.py
+++ b/tensorflow/tools/docs/py_guide_parser.py
@@ -44,7 +44,8 @@ class PyGuideParser(object):
def process(self, full_path):
"""Read and process the file at `full_path`."""
- md_string = open(full_path, 'rb').read().decode('utf-8')
+ with open(full_path, 'rb') as f:
+ md_string = f.read().decode('utf-8')
self._lines = md_string.split('\n')
seen = set()
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index 73dee98bae..cc2288a7fa 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -164,14 +164,17 @@ def get_git_version(git_base_path, git_tag_override):
"git", str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + git_base_path), "describe", "--long", "--tags"
]).strip())
- if git_tag_override:
+ if git_tag_override and val:
split_val = val.split("-")
- if len(split_val) != 3:
+ if len(split_val) < 3:
raise Exception(
("Expected git version in format 'TAG-COMMITS AFTER TAG-HASH' "
"but got '%s'") % val)
- split_val[0] = git_tag_override
- val = bytes("-".join(split_val))
+ # There might be "-" in the tag name. But we can be sure that the final
+ # two "-" are those inserted by the git describe command.
+ abbrev_commit = split_val[-1]
+ val = bytes(
+ "-".join([git_tag_override, "0", abbrev_commit]))
return val if val else unknown_label
except (subprocess.CalledProcessError, OSError):
return unknown_label
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 85660f94a8..f858411876 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
return Status::OK();
}
+Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
+ GraphDef* graph_def) {
+ std::unordered_set<string> input_names;
+ for (const string& input_name : context.input_names) {
+ input_names.insert(ParseTensorName(input_name).first.ToString());
+ }
+
+ for (NodeDef& node : *graph_def->mutable_node()) {
+ if (input_names.find(node.name()) == input_names.end()) {
+ continue;
+ }
+ if (node.op() == "PlaceholderWithDefault") {
+ node.set_op("Placeholder");
+ node.clear_input();
+ } else if (node.op() != "Placeholder") {
+ return errors::InvalidArgument(
+ "Input '", node.name(),
+ "' was expected to be a Placeholder or PlaceholderWithDefault op, "
+ "but was ",
+ node.op());
+ }
+ }
+ return Status::OK();
+}
+
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
@@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
input_graph_def,
[&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
output_graph_def);
+ TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));
return Status::OK();
}
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index a082399a87..dcdc3c2906 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("unused"));
}
- void TestRemoveUnusedNodesMultipleOutputs() {
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- auto root = tensorflow::Scope::NewRootScope();
-
- // a b
- // \ /
- // shape_n
- // \ /
- // c
- auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
- auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
- auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
- auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);
-
- GraphDef graph_def;
- TF_ASSERT_OK(root.ToGraphDef(&graph_def));
- GraphDef result_graph_def;
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));
-
- // Only one output of shape_n node is fed input. Hence the graph search
- // should propagate to inputs of shape_n. Nothing to remove here.
- std::map<string, const NodeDef*> node_map;
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(1, node_map.count("a"));
- EXPECT_EQ(1, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
-
- result_graph_def.Clear();
- TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
- graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
- &result_graph_def));
-
- // Both outputs of shape_n node are fed inputs. shape_n does not function
- // and inputs to shape_n should be removed.
- node_map.clear();
- graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
- EXPECT_EQ(0, node_map.count("a"));
- EXPECT_EQ(0, node_map.count("b"));
- EXPECT_EQ(1, node_map.count("c"));
- }
-
void TestMaxConstantSizeInBytes() {
auto root = tensorflow::Scope::NewRootScope();
@@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
-TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
- TestRemoveUnusedNodesMultipleOutputs();
-}
-
TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) {
TestMaxConstantSizeInBytes();
}
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 77f83b77a0..173f418dc8 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -115,6 +115,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
@@ -130,7 +131,7 @@ genrule(
"@highwayhash//:LICENSE",
"@jemalloc//:COPYING",
"@jpeg//:LICENSE.md",
- "@libxsmm_archive//:LICENSE",
+ "@libxsmm_archive//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
@@ -156,6 +157,7 @@ genrule(
"//third_party/fft2d:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
"@curl//:COPYING",
@@ -168,7 +170,7 @@ genrule(
"@highwayhash//:LICENSE",
"@jemalloc//:COPYING",
"@jpeg//:LICENSE.md",
- "@libxsmm_archive//:LICENSE",
+ "@libxsmm_archive//:LICENSE.md",
"@llvm//:LICENSE.TXT",
"@lmdb//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 9d4148c07f..c9d53f46c3 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -57,15 +57,18 @@ COMMON_PIP_DEPS = [
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/autograph:autograph",
"//tensorflow/contrib/autograph/converters:converters",
- "//tensorflow/contrib/autograph/converters:test_lib",
+ "//tensorflow/contrib/autograph/core:core",
+ "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/contrib/autograph/impl:impl",
+ "//tensorflow/contrib/autograph/lang:lang",
"//tensorflow/contrib/autograph/operators:operators",
"//tensorflow/contrib/autograph/pyct:pyct",
"//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
+ "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
- "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
+ "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
@@ -91,6 +94,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/timeseries:timeseries_pip",
"//tensorflow/contrib/tpu",
"//tensorflow/examples/tutorials/mnist:package",
+ "//tensorflow/python:cond_v2",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
@@ -126,6 +130,8 @@ filegroup(
"@astor_archive//:LICENSE",
"@aws//:LICENSE",
"@boringssl//:LICENSE",
+ "@com_github_googleapis_googleapis//:LICENSE",
+ "@com_github_googlecloudplatform_google_cloud_cpp//:LICENSE",
"@com_google_absl//:LICENSE",
"@com_googlesource_code_re2//:LICENSE",
"@cub_archive//:LICENSE.TXT",
@@ -143,7 +149,7 @@ filegroup(
"@jemalloc//:COPYING",
"@jpeg//:LICENSE.md",
"@kafka//:LICENSE",
- "@libxsmm_archive//:LICENSE",
+ "@libxsmm_archive//:LICENSE.md",
"@lmdb//:LICENSE",
"@local_config_nccl//:LICENSE",
"@local_config_sycl//sycl:LICENSE.text",
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index dc9d059bab..c630ca04b8 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,7 +45,7 @@ DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.9.0'
+_VERSION = '1.9.0-rc0'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
@@ -53,9 +53,9 @@ REQUIRED_PACKAGES = [
'gast >= 0.2.0',
'numpy >= 1.13.3',
'six >= 1.10.0',
- 'protobuf >= 3.4.0',
+ 'protobuf >= 3.6.0',
'setuptools <= 39.1.0',
- 'tensorboard >= 1.9.0, < 1.10.0',
+ 'tensorboard >= 1.8.0, < 1.9.0',
'termcolor >= 1.1.0',
]
@@ -84,7 +84,7 @@ else:
if 'tf_nightly' in project_name:
for i, pkg in enumerate(REQUIRED_PACKAGES):
if 'tensorboard' in pkg:
- REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.9.0a0, < 1.10.0a0'
+ REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.10.0a0, < 1.11.0a0'
break
# weakref.finalize and enum were introduced in Python 3.4
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
index aa56cc676d..15d7c70281 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
@@ -279,8 +279,13 @@ void Generator::AppendFieldValueAppend(const FieldDescriptor& field,
if (omit_default) {
Print("if (", field_expr, " != 0) {").Nest();
}
- Print("o->AppendEnumName(\"", field.name(), "\", ",
- GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));");
+ Print("const char* enum_name = ",
+ GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, ");");
+ Print("if (enum_name[0]) {").Nest();
+ Print("o->AppendEnumName(\"", field.name(), "\", enum_name);");
+ Unnest().Print("} else {").Nest();
+ Print("o->AppendNumeric(\"", field.name(), "\", ", field_expr, ");");
+ Unnest().Print("}");
if (omit_default) {
Unnest().Print("}");
}
@@ -540,18 +545,24 @@ void Generator::AppendParseMessageFunction(const Descriptor& md) {
for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
const auto* value_d = enum_d->value(enum_i);
const string& value_name = value_d->name();
- string condition = StrCat("value == \"", value_name,
- "\" || value == \"", value_d->number(), "\"");
- if (value_d->number() == 0) {
- StrAppend(&condition, " || value == \"-0\"");
- }
+ string condition = StrCat("value == \"", value_name, "\"");
Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
Nest();
Print(set_value_prefix, "(", value_prefix, value_name, ");");
Unnest();
}
+ Print("} else {");
+ Nest();
+ // Proto3 allows all numeric values.
+ Print("int32 int_value;");
+ Print("if (strings::SafeStringToNumeric(value, &int_value)) {");
+ Nest();
+ Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d),
+ ">(int_value));");
+ Unnest();
Print("} else {").Nest().Print("return false;").Unnest().Print("}");
+ Unnest().Print("}");
} else {
Print(field->cpp_type_name(), " value;");
switch (field->cpp_type()) {
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
index 6f0b4f47de..e67add72de 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
@@ -455,7 +455,10 @@ TEST(CreateProtoDebugStringLibTest, Enums) {
"repeated_nested_enum: 1"));
EXPECT_PARSE_SUCCESS("", "optional_nested_enum: -0");
- EXPECT_PARSE_FAILURE("optional_nested_enum: 6");
+ // TODO(amauryfa): restore the line below when protobuf::TextFormat also
+ // supports unknonwn enum values.
+ // EXPECT_PARSE_SUCCESS("optional_nested_enum: 6", "optional_nested_enum: 6");
+ EXPECT_PARSE_FAILURE("optional_nested_enum: 2147483648"); // > INT32_MAX
EXPECT_PARSE_FAILURE("optional_nested_enum: BARNONE");
EXPECT_PARSE_FAILURE("optional_nested_enum: 'BAR'");
EXPECT_PARSE_FAILURE("optional_nested_enum: \"BAR\" ");
diff --git a/tensorflow/tools/quantization/quantize_graph_test.py b/tensorflow/tools/quantization/quantize_graph_test.py
index df71840b64..92bb5127da 100644
--- a/tensorflow/tools/quantization/quantize_graph_test.py
+++ b/tensorflow/tools/quantization/quantize_graph_test.py
@@ -119,8 +119,8 @@ def are_tensors_near(a, b, tolerance):
flat_a = a.flatten()
flat_b = b.flatten()
if len(flat_a) != len(flat_b):
- print("Tensors are different sizes: " + str(len(flat_a)) + " vs " + str(
- len(flat_b)))
+ tf_logging.info("Tensors are different sizes: " + str(len(flat_a)) + " vs "
+ + str(len(flat_b)))
return False
value_count = len(flat_a)
how_many_different = 0
@@ -140,10 +140,10 @@ def are_tensors_near(a, b, tolerance):
if how_many_different == 0:
return True
else:
- print("Tensors have {0} different values ({1}%), with mean difference"
- " {2} and mean absolute difference {3}".format(
- how_many_different, proportion_different * 100, mean_difference,
- mean_abs_difference))
+ tf_logging.info("Tensors have {0} different values ({1}%), with mean"
+ " difference {2} and mean absolute difference {3}".format(
+ how_many_different, proportion_different * 100,
+ mean_difference, mean_abs_difference))
return False
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 79274d66ad..b712954d6d 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -131,11 +131,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "libxsmm_archive",
urls = [
- "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.8.1.tar.gz",
- "https://github.com/hfp/libxsmm/archive/1.8.1.tar.gz",
+ "https://mirror.bazel.build/github.com/hfp/libxsmm/archive/1.9.tar.gz",
+ "https://github.com/hfp/libxsmm/archive/1.9.tar.gz",
],
- sha256 = "2ade869c3f42f23b5263c7d594aa3c7e5e61ac6a3afcaf5d6e42899d2a7986ce",
- strip_prefix = "libxsmm-1.8.1",
+ sha256 = "cd8532021352b4a0290d209f7f9bfd7c2411e08286a893af3577a43457287bfa",
+ strip_prefix = "libxsmm-1.9",
build_file = clean_dep("//third_party:libxsmm.BUILD"),
)
@@ -155,12 +155,33 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_googlesource_code_re2",
urls = [
- "https://mirror.bazel.build/github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
- "https://github.com/google/re2/archive/26cd968b735e227361c9703683266f01e5df7857.tar.gz",
+ "https://mirror.bazel.build/github.com/google/re2/archive/2018-04-01.tar.gz",
+ "https://github.com/google/re2/archive/2018-04-01.tar.gz",
],
- sha256 = "e57eeb837ac40b5be37b2c6197438766e73343ffb32368efea793dfd8b28653b",
- strip_prefix = "re2-26cd968b735e227361c9703683266f01e5df7857",
+ sha256 = "2f945446b71336e7f5a2bcace1abcf0b23fbba368266c6a1be33de3de3b3c912",
+ strip_prefix = "re2-2018-04-01",
+ )
+
+ tf_http_archive(
+ name = "com_github_googlecloudplatform_google_cloud_cpp",
+ urls = [
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/f875700a023bdd706333cde45aee8758b272c357.tar.gz",
+ ],
+ sha256 = "a34f3c50b237686dc870b13baaa6a5836ce3473f2f2a02717299f0ff318372db",
+ strip_prefix = "google-cloud-cpp-f875700a023bdd706333cde45aee8758b272c357",
+ )
+
+ tf_http_archive(
+ name = "com_github_googleapis_googleapis",
+ urls = [
+ "https://mirror.bazel.build/github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ "https://github.com/googleapis/googleapis/archive/f81082ea1e2f85c43649bee26e0d9871d4b41cdb.zip",
+ ],
+ sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
+ strip_prefix="googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
+ build_file = clean_dep("//third_party:googleapis.BUILD"),
)
tf_http_archive(
@@ -198,11 +219,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "nasm",
urls = [
- "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2",
- "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2",
+ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
+ "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
],
- sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324",
- strip_prefix = "nasm-2.12.02",
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
build_file = clean_dep("//third_party:nasm.BUILD"),
)
@@ -232,11 +254,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "org_sqlite",
urls = [
- "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
- "https://www.sqlite.org/2018/sqlite-amalgamation-3230100.zip",
+ "https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
+ "https://www.sqlite.org/2018/sqlite-amalgamation-3240000.zip",
],
- sha256 = "4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc",
- strip_prefix = "sqlite-amalgamation-3230100",
+ sha256 = "ad68c1216c3a474cf360c7581a4001e952515b3649342100f2d7ca7c8e313da6",
+ strip_prefix = "sqlite-amalgamation-3240000",
build_file = clean_dep("//third_party:sqlite.BUILD"),
)
@@ -298,11 +320,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "absl_py",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz",
- "https://github.com/abseil/abseil-py/archive/ea8c4d2ddbf3fba610c4d613260561699b776db8.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
+ "https://github.com/abseil/abseil-py/archive/pypi-v0.2.2.tar.gz",
],
- sha256 = "c30b48e0d2580ef1412e55c5c0e1dab8db2ee4ab56e2075eccff29c90c7c7059",
- strip_prefix = "abseil-py-ea8c4d2ddbf3fba610c4d613260561699b776db8",
+ sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
+ strip_prefix = "abseil-py-pypi-v0.2.2",
)
tf_http_archive(
@@ -330,11 +352,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "protobuf_archive",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
- "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
],
- sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
- strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
)
# We need to import the protobuf library under the names com_google_protobuf
@@ -343,31 +365,31 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_google_protobuf",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
- "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
],
- sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
- strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
)
tf_http_archive(
name = "com_google_protobuf_cc",
urls = [
- "https://mirror.bazel.build/github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
- "https://github.com/google/protobuf/archive/396336eb961b75f03b25824fe86cf6490fb75e3a.tar.gz",
+ "https://mirror.bazel.build/github.com/google/protobuf/archive/v3.6.0.tar.gz",
+ "https://github.com/google/protobuf/archive/v3.6.0.tar.gz",
],
- sha256 = "846d907acf472ae233ec0882ef3a2d24edbbe834b80c305e867ac65a1f2c59e3",
- strip_prefix = "protobuf-396336eb961b75f03b25824fe86cf6490fb75e3a",
+ sha256 = "50a5753995b3142627ac55cfd496cebc418a2e575ca0236e29033c67bd5665f4",
+ strip_prefix = "protobuf-3.6.0",
)
tf_http_archive(
name = "nsync",
urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz",
- "https://github.com/google/nsync/archive/0559ce013feac8db639ee1bf776aca0325d28777.tar.gz",
+ "https://mirror.bazel.build/github.com/google/nsync/archive/1.20.0.tar.gz",
+ "https://github.com/google/nsync/archive/1.20.0.tar.gz",
],
- sha256 = "6284454c5cd8b1dae2eeb8cf5eb63004de930b5427ed5f6b1aa793513df6b361",
- strip_prefix = "nsync-0559ce013feac8db639ee1bf776aca0325d28777",
+ sha256 = "0c1b03962b2f8450f21e74a5a46116bf2d6009a807c57eb4207e974a8c4bb7dd",
+ strip_prefix = "nsync-1.20.0",
)
tf_http_archive(
@@ -383,21 +405,21 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "com_github_gflags_gflags",
urls = [
- "https://mirror.bazel.build/github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz",
- "https://github.com/gflags/gflags/archive/f8a0efe03aa69b3336d8e228b37d4ccb17324b88.tar.gz",
+ "https://mirror.bazel.build/github.com/gflags/gflags/archive/v2.2.1.tar.gz",
+ "https://github.com/gflags/gflags/archive/v2.2.1.tar.gz",
],
- sha256 = "4d222fab8f1ede4709cdff417d15a1336f862d7334a81abf76d09c15ecf9acd1",
- strip_prefix = "gflags-f8a0efe03aa69b3336d8e228b37d4ccb17324b88",
+ sha256 = "ae27cdbcd6a2f935baa78e4f21f675649271634c092b1be01469440495609d0e",
+ strip_prefix = "gflags-2.2.1",
)
tf_http_archive(
name = "pcre",
- sha256 = "ccdf7e788769838f8285b3ee672ed573358202305ee361cfec7a4a4fb005bbc7",
+ sha256 = "69acbc2fbdefb955d42a4c606dfde800c2885711d2979e356c0636efde9ec3b5",
urls = [
- "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.39.tar.gz",
- "http://ftp.exim.org/pub/pcre/pcre-8.39.tar.gz",
+ "https://mirror.bazel.build/ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
+ "http://ftp.exim.org/pub/pcre/pcre-8.42.tar.gz",
],
- strip_prefix = "pcre-8.39",
+ strip_prefix = "pcre-8.42",
build_file = clean_dep("//third_party:pcre.BUILD"),
)
@@ -415,26 +437,25 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "curl",
- sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6",
+ sha256 = "e9c37986337743f37fd14fe8737f246e97aec94b39d1b71e8a5973f72a9fc4f5",
urls = [
- "https://mirror.bazel.build/curl.haxx.se/download/curl-7.49.1.tar.gz",
- "https://curl.haxx.se/download/curl-7.49.1.tar.gz",
+ "https://mirror.bazel.build/curl.haxx.se/download/curl-7.60.0.tar.gz",
+ "https://curl.haxx.se/download/curl-7.60.0.tar.gz",
],
- strip_prefix = "curl-7.49.1",
+ strip_prefix = "curl-7.60.0",
build_file = clean_dep("//third_party:curl.BUILD"),
)
tf_http_archive(
name = "grpc",
urls = [
- "https://mirror.bazel.build/github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
- "https://github.com/grpc/grpc/archive/d184fa229d75d336aedea0041bd59cb93e7e267f.tar.gz",
+ "https://mirror.bazel.build/github.com/grpc/grpc/archive/v1.13.0.tar.gz",
+ "https://github.com/grpc/grpc/archive/v1.13.0.tar.gz",
],
- sha256 = "895b31310e718a61f7335759a778c068a6edde1c089883598a0830cbb7075673",
- strip_prefix = "grpc-d184fa229d75d336aedea0041bd59cb93e7e267f",
+ sha256 = "50db9cf2221354485eb7c3bd55a4c27190caef7048a2a1a15fbe60a498f98b44",
+ strip_prefix = "grpc-1.13.0",
)
-
tf_http_archive(
name = "linenoise",
sha256 = "7f51f45887a3d31b4ce4fa5965210a5e64637ceac12720cfce7954d6a2e812f7",
@@ -451,33 +472,33 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/bf13d093f13a295d71080614c3036ada591201d5.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93.tar.gz",
],
- sha256 = "3c5b4538a4df95090693bf6b758e861afc5b8c599592368f9dc57901f7560bd0",
- strip_prefix = "llvm-bf13d093f13a295d71080614c3036ada591201d5",
- build_file = clean_dep("//third_party/llvm:llvm.BUILD"),
+ sha256 = "280fdc888e2eb88a3a8cc4e7d3034fffc87f98e3e686be31f8c719c6e5b67d2d",
+ strip_prefix = "llvm-d5d94ca3a7f8526c2e4e5f663f9dc79ae5d39d93",
+ build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
tf_http_archive(
name = "lmdb",
urls = [
- "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz",
- "https://github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz",
+ "https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
+ "https://github.com/LMDB/lmdb/archive/LMDB_0.9.22.tar.gz",
],
- sha256 = "108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326",
- strip_prefix = "lmdb-LMDB_0.9.19/libraries/liblmdb",
+ sha256 = "f3927859882eb608868c8c31586bb7eb84562a40a6bf5cc3e13b6b564641ea28",
+ strip_prefix = "lmdb-LMDB_0.9.22/libraries/liblmdb",
build_file = clean_dep("//third_party:lmdb.BUILD"),
)
tf_http_archive(
name = "jsoncpp_git",
urls = [
- "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz",
- "https://github.com/open-source-parsers/jsoncpp/archive/11086dd6a7eba04289944367ca82cea71299ed70.tar.gz",
+ "https://mirror.bazel.build/github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
+ "https://github.com/open-source-parsers/jsoncpp/archive/1.8.4.tar.gz",
],
- sha256 = "07d34db40593d257324ec5fb9debc4dc33f29f8fb44e33a2eeb35503e61d0fe2",
- strip_prefix = "jsoncpp-11086dd6a7eba04289944367ca82cea71299ed70",
+ sha256 = "c49deac9e0933bcb7044f08516861a2d560988540b23de2ac1ad443b219afdb6",
+ strip_prefix = "jsoncpp-1.8.4",
build_file = clean_dep("//third_party:jsoncpp.BUILD"),
)
@@ -537,11 +558,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "kafka",
urls = [
- "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz",
- "https://github.com/edenhill/librdkafka/archive/v0.11.1.tar.gz",
+ "https://mirror.bazel.build/github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
+ "https://github.com/edenhill/librdkafka/archive/v0.11.4.tar.gz",
],
- sha256 = "dd035d57c8f19b0b612dd6eefe6e5eebad76f506e302cccb7c2066f25a83585e",
- strip_prefix = "librdkafka-0.11.1",
+ sha256 = "9d8f1eb7b0e29e9ab1168347c939cb7ae5dff00a39cef99e7ef033fd8f92737c",
+ strip_prefix = "librdkafka-0.11.4",
build_file = clean_dep("//third_party:kafka/BUILD"),
patch_file = clean_dep("//third_party/kafka:config.patch"),
)
@@ -627,6 +648,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
licenses = ["notice"], # Apache 2.0
)
+ java_import_external(
+ name = "com_squareup_javapoet",
+ jar_sha256 = "5bb5abdfe4366c15c0da3332c57d484e238bd48260d6f9d6acf2b08fdde1efea",
+ jar_urls = [
+ "http://mirror.bazel.build/repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
+ "http://repo1.maven.org/maven2/com/squareup/javapoet/1.9.0/javapoet-1.9.0.jar",
+ ],
+ licenses = ["notice"], # Apache 2.0
+ )
+
tf_http_archive(
name = "com_google_pprof",
urls = [
@@ -651,12 +682,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "cython",
- sha256 = "6dcd30b5ceb887b2b965ee7ceb82ea3acb5f0642fe2206c7636b45acea4798e5",
+ sha256 = "bccc9aa050ea02595b2440188813b936eaf345e85fb9692790cecfe095cf91aa",
urls = [
- "https://mirror.bazel.build/github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
- "https://github.com/cython/cython/archive/3732784c45cfb040a5b0936951d196f83a12ea17.tar.gz",
+ "https://mirror.bazel.build/github.com/cython/cython/archive/0.28.4.tar.gz",
+ "https://github.com/cython/cython/archive/0.28.4.tar.gz",
],
- strip_prefix = "cython-3732784c45cfb040a5b0936951d196f83a12ea17",
+ strip_prefix = "cython-0.28.4",
build_file = clean_dep("//third_party:cython.BUILD"),
delete = ["BUILD.bazel"],
)
@@ -664,11 +695,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "bazel_toolchains",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/37acf1841ab1475c98a152cb9e446460c8ae29e1.tar.gz",
],
- strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43",
- sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556",
+ strip_prefix = "bazel-toolchains-37acf1841ab1475c98a152cb9e446460c8ae29e1",
+ sha256 = "3b604699685c5c65dd3f6f17425570a4b2f00ddba2f750db15acc72e55bb098b",
)
tf_http_archive(
@@ -684,11 +715,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
tf_http_archive(
name = "flatbuffers",
- strip_prefix = "flatbuffers-971a68110e4fc1bace10fcb6deeb189e7e1a34ce",
- sha256 = "874088d2ee0d9f8524191f77209556415f03dd44e156276edf19e5b90ceb5f55",
+ strip_prefix = "flatbuffers-1.9.0",
+ sha256 = "5ca5491e4260cacae30f1a5786d109230db3f3a6e5a0eb45d0d0608293d247e3",
urls = [
- "https://mirror.bazel.build/github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz",
- "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz",
+ "https://mirror.bazel.build/github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
+ "https://github.com/google/flatbuffers/archive/v1.9.0.tar.gz",
],
build_file = clean_dep("//third_party/flatbuffers:flatbuffers.BUILD"),
)
@@ -722,6 +753,14 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
],
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
+ tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant",
+ sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
+ urls = ["https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ "https://storage.googleapis.com/download.tensorflow.org/models/tflite/coco_ssd_mobilenet_v1_0.75_quant_2018_06_29.zip",
+ ],
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
tf_http_archive(
name = "tflite_conv_actions_frozen",
@@ -754,6 +793,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
strip_prefix = "ovic",
)
+ tf_http_archive(
+ name = "build_bazel_rules_android",
+ sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
+ "https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip",
+ ],
+ strip_prefix = "rules_android-0.1.1",
+ )
+
##############################################################################
# BIND DEFINITIONS
#
@@ -778,10 +827,13 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
actual = "@grpc//:grpc_python_plugin",
)
- # gRPC has three empty C++ functions which it wants the user to define
- # at build time. https://github.com/grpc/grpc/issues/13590
native.bind(
name = "grpc_lib",
+ actual = "@grpc//:grpc++",
+ )
+
+ native.bind(
+ name = "grpc_lib_unsecure",
actual = "@grpc//:grpc++_unsecure",
)
diff --git a/third_party/android/BUILD b/third_party/android/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/android/BUILD
diff --git a/third_party/android/android.bzl.tpl b/third_party/android/android.bzl.tpl
new file mode 100644
index 0000000000..e6ed4994f3
--- /dev/null
+++ b/third_party/android/android.bzl.tpl
@@ -0,0 +1,9 @@
+"""Set up configurable Android SDK and NDK dependencies."""
+
+def android_workspace():
+ # String for replacement in Bazel template.
+ # These will either be replaced by android_sdk_repository if various ENV
+ # variables are set when `local_config_android` repo_rule is run, or they
+ # will be replaced by noops otherwise.
+ MAYBE_ANDROID_SDK_REPOSITORY
+ MAYBE_ANDROID_NDK_REPOSITORY
diff --git a/third_party/android/android_configure.BUILD.tpl b/third_party/android/android_configure.BUILD.tpl
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/android/android_configure.BUILD.tpl
diff --git a/third_party/android/android_configure.bzl b/third_party/android/android_configure.bzl
new file mode 100644
index 0000000000..da09bdf39e
--- /dev/null
+++ b/third_party/android/android_configure.bzl
@@ -0,0 +1,87 @@
+"""Repository rule for Android SDK and NDK autoconfiguration.
+
+`android_configure` depends on the following environment variables:
+
+ * `ANDROID_NDK_HOME`: Location of Android NDK root.
+ * `ANDROID_SDK_HOME`: Location of Android SDK root.
+ * `ANDROID_SDK_API_LEVEL`: Desired Android SDK API version.
+ * `ANDROID_NDK_API_LEVEL`: Desired Android NDK API version.
+ * `ANDROID_BUILD_TOOLS_VERSION`: Desired Android build tools version.
+"""
+
+# TODO(mikecase): Move logic for getting default values for the env variables
+# from configure.py script into this rule.
+
+_ANDROID_NDK_HOME = "ANDROID_NDK_HOME"
+_ANDROID_SDK_HOME = "ANDROID_SDK_HOME"
+_ANDROID_NDK_API_VERSION = "ANDROID_NDK_API_LEVEL"
+_ANDROID_SDK_API_VERSION = "ANDROID_SDK_API_LEVEL"
+_ANDROID_BUILD_TOOLS_VERSION = "ANDROID_BUILD_TOOLS_VERSION"
+
+_ANDROID_SDK_REPO_TEMPLATE = """
+ native.android_sdk_repository(
+ name="androidsdk",
+ path="%s",
+ api_level=%s,
+ build_tools_version="%s",
+ )
+"""
+
+_ANDROID_NDK_REPO_TEMPLATE = """
+ native.android_ndk_repository(
+ name="androidndk",
+ path="%s",
+ api_level=%s,
+ )
+"""
+
+def _android_autoconf_impl(repository_ctx):
+ """Implementation of the android_autoconf repository rule."""
+ sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
+ sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
+ build_tools_version = repository_ctx.os.environ.get(
+ _ANDROID_BUILD_TOOLS_VERSION)
+ ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
+ ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
+
+ sdk_rule = "pass"
+ if all([sdk_home, sdk_api_level, build_tools_version]):
+ sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
+ sdk_home, sdk_api_level, build_tools_version)
+
+ ndk_rule = "pass"
+ if all([ndk_home, ndk_api_level]):
+ ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
+
+ repository_ctx.template(
+ "BUILD",
+ Label("//third_party/android:android_configure.BUILD.tpl"))
+ repository_ctx.template(
+ "android.bzl",
+ Label("//third_party/android:android.bzl.tpl"),
+ substitutions={
+ "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
+ "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
+ })
+
+android_configure = repository_rule(
+ implementation = _android_autoconf_impl,
+ environ = [
+ _ANDROID_SDK_API_VERSION,
+ _ANDROID_NDK_API_VERSION,
+ _ANDROID_BUILD_TOOLS_VERSION,
+ _ANDROID_NDK_HOME,
+ _ANDROID_SDK_HOME,
+ ],
+)
+"""Writes Android SDK and NDK rules.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+android_configure(name = "local_config_android")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD
index 2dc921933c..5426f79e46 100644
--- a/third_party/aws.BUILD
+++ b/third_party/aws.BUILD
@@ -46,6 +46,8 @@ cc_library(
"aws-cpp-sdk-core/source/utils/xml/**/*.cpp",
"aws-cpp-sdk-core/source/utils/crypto/*.cpp",
"aws-cpp-sdk-core/source/utils/crypto/factory/**/*.cpp",
+ "aws-cpp-sdk-kinesis/include/**/*.h",
+ "aws-cpp-sdk-kinesis/source/**/*.cpp",
"aws-cpp-sdk-s3/include/**/*.h",
"aws-cpp-sdk-s3/source/**/*.cpp",
]),
@@ -72,6 +74,7 @@ cc_library(
}),
includes = [
"aws-cpp-sdk-core/include/",
+ "aws-cpp-sdk-kinesis/include/",
"aws-cpp-sdk-s3/include/",
],
deps = [
diff --git a/third_party/clang_toolchain/download_clang.bzl b/third_party/clang_toolchain/download_clang.bzl
index 02d2b78067..ab57b9dfa0 100644
--- a/third_party/clang_toolchain/download_clang.bzl
+++ b/third_party/clang_toolchain/download_clang.bzl
@@ -35,18 +35,18 @@ def download_clang(repo_ctx, out_folder):
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
- CLANG_REVISION = '332335'
+ CLANG_REVISION = '336424'
CLANG_SUB_REVISION = 1
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
'Linux_x64':
- '5c234e0bc43b2386984ac34ac9c200c35686f2f7fa5ded0db031055bbc7f3e52',
+ '2ea97e047470da648f5d078af008bce6891287592382cee3d53a1187d996da94',
'Mac':
- '69b94f16d261c0922c3853cdad768776f454dece2948363f1c4e20bc2ddbf95d',
+ 'c6e28909cce63ee35e0d51284d9f0f6e8838f7fb8b7a0dc9536c2ea900552df0',
'Win':
- '76c8897abf032f3e23598275517da60090f53cf35b673481f41fa98752d1ad37',
+ '1299fda7c4378bfb81337f7e5f351c8a1f953f51e0744e2170454b8d722f3db7',
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
diff --git a/third_party/codegen.BUILD b/third_party/codegen.BUILD
new file mode 100644
index 0000000000..df436c8163
--- /dev/null
+++ b/third_party/codegen.BUILD
@@ -0,0 +1,16 @@
+# -*- mode: python; -*-
+#
+# Description:
+# Extension to ast that allow ast -> python code generation.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # New BSD
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "com_github_andreif_codegen",
+ srcs = glob(["codegen.py"]),
+ srcs_version = "PY2AND3",
+)
diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD
index 4def6f9489..1638b72161 100644
--- a/third_party/curl.BUILD
+++ b/third_party/curl.BUILD
@@ -7,6 +7,7 @@ exports_files(["COPYING"])
CURL_WIN_COPTS = [
"/Iexternal/curl/lib",
+ "/DBUILDING_LIBCURL",
"/DHAVE_CONFIG_H",
"/DCURL_DISABLE_FTP",
"/DCURL_DISABLE_NTLM",
@@ -49,6 +50,8 @@ cc_library(
"lib/curl_addrinfo.c",
"lib/curl_addrinfo.h",
"lib/curl_base64.h",
+ "lib/curl_ctype.c",
+ "lib/curl_ctype.h",
"lib/curl_des.h",
"lib/curl_endian.h",
"lib/curl_fnmatch.c",
@@ -75,6 +78,7 @@ cc_library(
"lib/curl_sec.h",
"lib/curl_setup.h",
"lib/curl_setup_once.h",
+ "lib/curl_sha256.h",
"lib/curl_sspi.c",
"lib/curl_sspi.h",
"lib/curl_threads.c",
@@ -134,6 +138,8 @@ cc_library(
"lib/md5.c",
"lib/memdebug.c",
"lib/memdebug.h",
+ "lib/mime.c",
+ "lib/mime.h",
"lib/mprintf.c",
"lib/multi.c",
"lib/multihandle.h",
@@ -153,8 +159,8 @@ cc_library(
"lib/pop3.h",
"lib/progress.c",
"lib/progress.h",
- "lib/rawstr.c",
- "lib/rawstr.h",
+ "lib/rand.c",
+ "lib/rand.h",
"lib/rtsp.c",
"lib/rtsp.h",
"lib/security.c",
@@ -162,8 +168,11 @@ cc_library(
"lib/select.h",
"lib/sendf.c",
"lib/sendf.h",
+ "lib/setopt.c",
+ "lib/setopt.h",
"lib/setup-os400.h",
"lib/setup-vms.h",
+ "lib/sha256.c",
"lib/share.c",
"lib/share.h",
"lib/sigpipe.h",
@@ -179,10 +188,10 @@ cc_library(
"lib/splay.c",
"lib/splay.h",
"lib/ssh.h",
+ "lib/strcase.c",
+ "lib/strcase.h",
"lib/strdup.c",
"lib/strdup.h",
- "lib/strequal.c",
- "lib/strequal.h",
"lib/strerror.c",
"lib/strerror.h",
"lib/strtok.c",
@@ -241,13 +250,12 @@ cc_library(
}),
hdrs = [
"include/curl/curl.h",
- "include/curl/curlbuild.h",
- "include/curl/curlrules.h",
"include/curl/curlver.h",
"include/curl/easy.h",
"include/curl/mprintf.h",
"include/curl/multi.h",
"include/curl/stdcheaders.h",
+ "include/curl/system.h",
"include/curl/typecheck-gcc.h",
],
copts = select({
@@ -256,6 +264,7 @@ cc_library(
"//conditions:default": [
"-Iexternal/curl/lib",
"-D_GNU_SOURCE",
+ "-DBUILDING_LIBCURL",
"-DHAVE_CONFIG_H",
"-DCURL_DISABLE_FTP",
"-DCURL_DISABLE_NTLM", # turning it off in configure is not enough
@@ -676,6 +685,7 @@ genrule(
"# define SIZEOF_INT 4",
"# define SIZEOF_LONG 8",
"# define SIZEOF_OFF_T 8",
+ "# define SIZEOF_CURL_OFF_T 8",
"# define SIZEOF_SHORT 2",
"# define SIZEOF_SIZE_T 8",
"# define SIZEOF_TIME_T 8",
diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD
index e54c1a4501..759f8a9be9 100644
--- a/third_party/eigen.BUILD
+++ b/third_party/eigen.BUILD
@@ -69,3 +69,9 @@ cc_library(
includes = ["."],
visibility = ["//visibility:public"],
)
+
+filegroup(
+ name = "eigen_header_files",
+ srcs = EIGEN_MPL2_HEADER_FILES,
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index f661093bc9..203991b50f 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -17,21 +17,23 @@ load("//tensorflow:tensorflow.bzl", "if_mkl")
# INTEL_MKL end
load("//tensorflow:tensorflow.bzl", "if_mkl")
+EIGEN3_THIRD_PARTY_HEADERS = [
+ "Eigen/Core",
+ "Eigen/LU",
+ "Eigen/Cholesky",
+ "Eigen/Eigenvalues",
+ "Eigen/QR",
+ "Eigen/SVD",
+ "unsupported/Eigen/MatrixFunctions",
+ "unsupported/Eigen/SpecialFunctions",
+ "unsupported/Eigen/CXX11/ThreadPool",
+ "unsupported/Eigen/CXX11/Tensor",
+ "unsupported/Eigen/CXX11/FixedPoint",
+] + glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"])
+
cc_library(
name = "eigen3",
- hdrs = glob(["unsupported/Eigen/CXX11/src/FixedPoint/*.h"]) + [
- "Eigen/Core",
- "Eigen/LU",
- "Eigen/Cholesky",
- "Eigen/Eigenvalues",
- "Eigen/QR",
- "Eigen/SVD",
- "unsupported/Eigen/MatrixFunctions",
- "unsupported/Eigen/SpecialFunctions",
- "unsupported/Eigen/CXX11/ThreadPool",
- "unsupported/Eigen/CXX11/Tensor",
- "unsupported/Eigen/CXX11/FixedPoint",
- ],
+ hdrs = EIGEN3_THIRD_PARTY_HEADERS,
includes = if_mkl(["./mkl_include"]),
visibility = ["//visibility:public"],
deps = [
@@ -48,3 +50,35 @@ filegroup(
),
visibility = ["//tensorflow:__subpackages__"],
)
+
+filegroup(
+ name = "eigen_third_party_header_files",
+ srcs = EIGEN3_THIRD_PARTY_HEADERS,
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "install_eigen_headers",
+ srcs = [
+ "@eigen_archive//:eigen_header_files",
+ ":eigen_third_party_header_files",
+ ],
+ outs = ["include"],
+ cmd = """
+ mkdir $@
+ for f in $(locations @eigen_archive//:eigen_header_files) ; do
+ d="$${f%/*}"
+ d="$${d#*external/eigen_archive/}"
+
+ mkdir -p "$@/$${d}"
+ cp "$${f}" "$@/$${d}/"
+ done
+
+ for f in $(locations :eigen_third_party_header_files) ; do
+ d="$${f%/*}"
+
+ mkdir -p "$@/$${d}"
+ cp "$${f}" "$@/$${d}/"
+ done
+ """,
+)
diff --git a/third_party/eigen_fix_cuda_compilation.patch b/third_party/eigen_fix_cuda_compilation.patch
deleted file mode 100644
index b921a7c31d..0000000000
--- a/third_party/eigen_fix_cuda_compilation.patch
+++ /dev/null
@@ -1,38 +0,0 @@
-diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h
---- a/Eigen/src/Core/ProductEvaluators.h
-+++ b/Eigen/src/Core/ProductEvaluators.h
-@@ -137,7 +137,7 @@ struct Assignment<DstXprType, Product<Lh
- typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct)>::type>
- {
- typedef Product<Lhs,Rhs,Options> SrcXprType;
-- static EIGEN_STRONG_INLINE
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
- void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar,Scalar> &)
- {
- Index dstRows = src.rows();
-@@ -390,7 +390,7 @@ struct generic_product_impl<Lhs,Rhs,Dens
- typedef typename Product<Lhs,Rhs>::Scalar Scalar;
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // Same as: dst.noalias() = lhs.lazyProduct(rhs);
- // but easier on the compiler side
-@@ -398,14 +398,14 @@ struct generic_product_impl<Lhs,Rhs,Dens
- }
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // dst.noalias() += lhs.lazyProduct(rhs);
- call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op<typename Dst::Scalar,Scalar>());
- }
-
- template<typename Dst>
-- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
-+ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs)
- {
- // dst.noalias() -= lhs.lazyProduct(rhs);
- call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op<typename Dst::Scalar,Scalar>());
diff --git a/third_party/flatbuffers/flatbuffers.BUILD b/third_party/flatbuffers/flatbuffers.BUILD
index 824c97be60..639dff2cd0 100644
--- a/third_party/flatbuffers/flatbuffers.BUILD
+++ b/third_party/flatbuffers/flatbuffers.BUILD
@@ -98,6 +98,8 @@ cc_binary(
"grpc/src/compiler/cpp_generator.h",
"grpc/src/compiler/go_generator.cc",
"grpc/src/compiler/go_generator.h",
+ "grpc/src/compiler/java_generator.cc",
+ "grpc/src/compiler/java_generator.h",
"grpc/src/compiler/schema_interface.h",
"src/flatc_main.cpp",
"src/idl_gen_cpp.cpp",
diff --git a/third_party/googleapis.BUILD b/third_party/googleapis.BUILD
new file mode 100644
index 0000000000..95e999af18
--- /dev/null
+++ b/third_party/googleapis.BUILD
@@ -0,0 +1,45 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+licenses(["notice"]) # Apache 2.0
+exports_files(["LICENSE"])
+
+load("@protobuf_archive//:protobuf.bzl", "cc_proto_library")
+
+cc_proto_library(
+ name = "bigtable_protos",
+ srcs = [
+ "google/bigtable/admin/v2/bigtable_instance_admin.proto",
+ "google/bigtable/admin/v2/bigtable_table_admin.proto",
+ "google/bigtable/admin/v2/common.proto",
+ "google/bigtable/admin/v2/instance.proto",
+ "google/bigtable/admin/v2/table.proto",
+ "google/bigtable/v2/bigtable.proto",
+ "google/bigtable/v2/data.proto",
+ "google/iam/v1/iam_policy.proto",
+ "google/iam/v1/policy.proto",
+ "google/longrunning/operations.proto",
+ "google/rpc/status.proto",
+ "google/rpc/error_details.proto",
+ "google/api/annotations.proto",
+ "google/api/auth.proto",
+ "google/api/http.proto",
+ ],
+ include = ".",
+ protoc = "@protobuf_archive//:protoc",
+ default_runtime = "@protobuf_archive//:protobuf",
+ deps = ["@protobuf_archive//:cc_wkt_protos"],
+ use_grpc_plugin = True,
+)
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index 98cb326572..f638756d23 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -7,6 +7,7 @@ cc_toolchain_suite(
toolchains = {
"local|compiler": ":cc-compiler-local",
"darwin|compiler": ":cc-compiler-darwin",
+ "x64_windows|msvc-cl": ":cc-compiler-windows",
},
)
@@ -42,6 +43,20 @@ cc_toolchain(
supports_param_files = 0,
)
+cc_toolchain(
+ name = "cc-compiler-windows",
+ all_files = "%{win_linker_files}",
+ compiler_files = ":empty",
+ cpu = "x64_windows",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = "%{win_linker_files}",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 1,
+)
+
filegroup(
name = "empty",
srcs = [],
@@ -51,3 +66,8 @@ filegroup(
name = "crosstool_wrapper_driver_is_not_gcc",
srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
)
+
+filegroup(
+ name = "windows_msvc_wrapper_files",
+ srcs = glob(["windows/msvc_*"]),
+)
diff --git a/third_party/gpus/crosstool/CROSSTOOL.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl
index 60b19daf1d..3972c96a2f 100644
--- a/third_party/gpus/crosstool/CROSSTOOL.tpl
+++ b/third_party/gpus/crosstool/CROSSTOOL.tpl
@@ -22,6 +22,10 @@ default_toolchain {
cpu: "ppc"
toolchain_identifier: "local_linux"
}
+default_toolchain {
+ cpu: "x64_windows"
+ toolchain_identifier: "local_windows"
+}
toolchain {
abi_version: "local"
@@ -295,3 +299,1110 @@ toolchain {
%{host_compiler_includes}
}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+ %{host_compiler_warnings}
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag:"-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin/"
+ }
+ }
+ }
+
+ feature {
+ name: "undefined-dynamic"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-undefined"
+ flag: "dynamic_lookup"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ implies: "undefined-dynamic"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "%{host_compiler_path}" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enabled dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+%{host_compiler_includes}
+}
+
+toolchain {
+ toolchain_identifier: "local_windows"
+ host_system_name: "local"
+ target_system_name: "local"
+
+ abi_version: "local"
+ abi_libc_version: "local"
+ target_cpu: "x64_windows"
+ compiler: "msvc-cl"
+ target_libc: "msvcrt"
+
+%{cxx_builtin_include_directory}
+
+ tool_path {
+ name: "ar"
+ path: "%{msvc_lib_path}"
+ }
+ tool_path {
+ name: "ml"
+ path: "%{msvc_ml_path}"
+ }
+ tool_path {
+ name: "cpp"
+ path: "%{msvc_cl_path}"
+ }
+ tool_path {
+ name: "gcc"
+ path: "%{msvc_cl_path}"
+ }
+ tool_path {
+ name: "gcov"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "ld"
+ path: "%{msvc_link_path}"
+ }
+ tool_path {
+ name: "nm"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objcopy"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objdump"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "strip"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ supports_interface_shared_objects: true
+
+ # TODO(pcloudy): Review those flags below, they should be defined by cl.exe
+ compiler_flag: "/DCOMPILER_MSVC"
+
+ # Don't define min/max macros in windows.h.
+ compiler_flag: "/DNOMINMAX"
+
+ # Platform defines.
+ compiler_flag: "/D_WIN32_WINNT=0x0600"
+ # Turn off warning messages.
+ compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE"
+ compiler_flag: "/D_CRT_SECURE_NO_WARNINGS"
+ compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS"
+
+ # Useful options to have on for compilation.
+ # Increase the capacity of object files to 2^32 sections.
+ compiler_flag: "/bigobj"
+ # Allocate 500MB for precomputed headers.
+ compiler_flag: "/Zm500"
+ # Use unsigned char by default.
+ compiler_flag: "/J"
+ # Use function level linking.
+ compiler_flag: "/Gy"
+ # Use string pooling.
+ compiler_flag: "/GF"
+ # Catch C++ exceptions only and tell the compiler to assume that functions declared
+ # as extern "C" never throw a C++ exception.
+ compiler_flag: "/EHsc"
+
+ # Globally disabled warnings.
+ # Don't warn about elements of array being be default initialized.
+ compiler_flag: "/wd4351"
+ # Don't warn about no matching delete found.
+ compiler_flag: "/wd4291"
+ # Don't warn about diamond inheritance patterns.
+ compiler_flag: "/wd4250"
+ # Don't warn about insecure functions (e.g. non _s functions).
+ compiler_flag: "/wd4996"
+
+ linker_flag: "/MACHINE:X64"
+
+ feature {
+ name: "no_legacy_features"
+ }
+
+ # Suppress startup banner.
+ feature {
+ name: "nologo"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ flag_group {
+ flag: "/nologo"
+ }
+ }
+ }
+
+ feature {
+ name: 'has_configured_linker_path'
+ }
+
+ # This feature indicates strip is not supported, building stripped binary will just result a copy of orignial binary
+ feature {
+ name: 'no_stripping'
+ }
+
+ # This feature indicates this is a toolchain targeting Windows.
+ feature {
+ name: 'targets_windows'
+ implies: 'copy_dynamic_libraries_to_binary'
+ enabled: true
+ }
+
+ feature {
+ name: 'copy_dynamic_libraries_to_binary'
+ }
+
+ action_config {
+ config_name: 'assemble'
+ action_name: 'assemble'
+ tool {
+ tool_path: '%{msvc_ml_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'preprocess-assemble'
+ action_name: 'preprocess-assemble'
+ tool {
+ tool_path: '%{msvc_ml_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'c-compile'
+ action_name: 'c-compile'
+ tool {
+ tool_path: '%{msvc_cl_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-compile'
+ action_name: 'c++-compile'
+ tool {
+ tool_path: '%{msvc_cl_path}'
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-link-executable'
+ action_name: 'c++-link-executable'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ }
+
+ action_config {
+ config_name: 'c++-link-dynamic-library'
+ action_name: 'c++-link-dynamic-library'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-nodeps-dynamic-library'
+ action_name: 'c++-link-nodeps-dynamic-library'
+ tool {
+ tool_path: '%{msvc_link_path}'
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-static-library'
+ action_name: 'c++-link-static-library'
+ tool {
+ tool_path: '%{msvc_lib_path}'
+ }
+ implies: 'nologo'
+ implies: 'archiver_flags'
+ implies: 'input_param_flags'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ }
+
+ # TODO(b/65151735): Remove legacy_compile_flags feature when legacy fields are
+ # not used in this crosstool
+ feature {
+ name: 'legacy_compile_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'legacy_compile_flags'
+ flag: '%{legacy_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: "msvc_env"
+ env_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ env_entry {
+ key: "PATH"
+ value: "%{msvc_env_path}"
+ }
+ env_entry {
+ key: "INCLUDE"
+ value: "%{msvc_env_include}"
+ }
+ env_entry {
+ key: "LIB"
+ value: "%{msvc_env_lib}"
+ }
+ env_entry {
+ key: "TMP"
+ value: "%{msvc_env_tmp}"
+ }
+ env_entry {
+ key: "TEMP"
+ value: "%{msvc_env_tmp}"
+ }
+ }
+ }
+
+ feature {
+ name: 'include_paths'
+ flag_set {
+ action: "assemble"
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ flag_group {
+ iterate_over: 'quote_include_paths'
+ flag: '/I%{quote_include_paths}'
+ }
+ flag_group {
+ iterate_over: 'include_paths'
+ flag: '/I%{include_paths}'
+ }
+ flag_group {
+ iterate_over: 'system_include_paths'
+ flag: '/I%{system_include_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: "preprocessor_defines"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-module-compile"
+ flag_group {
+ flag: "/D%{preprocessor_defines}"
+ iterate_over: "preprocessor_defines"
+ }
+ }
+ }
+
+ # Tell Bazel to parse the output of /showIncludes
+ feature {
+ name: 'parse_showincludes'
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-module-compile'
+ action: 'c++-header-parsing'
+ flag_group {
+ flag: "/showIncludes"
+ }
+ }
+ }
+
+
+ feature {
+ name: 'generate_pdb_file'
+ requires: {
+ feature: 'dbg'
+ }
+ requires: {
+ feature: 'fastbuild'
+ }
+ }
+
+ feature {
+ name: 'shared_flag'
+ flag_set {
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/DLL'
+ }
+ }
+ }
+
+ feature {
+ name: 'linkstamps'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ expand_if_all_available: 'linkstamp_paths'
+ flag_group {
+ iterate_over: 'linkstamp_paths'
+ flag: '%{linkstamp_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: 'output_execpath_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'archiver_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'input_param_flags'
+ flag_set {
+ expand_if_all_available: 'interface_library_output_path'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/IMPLIB:%{interface_library_output_path}"
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libopts'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'libopts'
+ flag: '%{libopts}'
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libraries_to_link'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ iterate_over: 'libraries_to_link'
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file_group'
+ }
+ iterate_over: 'libraries_to_link.object_files'
+ flag_group {
+ flag: '%{libraries_to_link.object_files}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'interface_library'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'static_library'
+ }
+ flag_group {
+ expand_if_false: 'libraries_to_link.is_whole_archive'
+ flag: '%{libraries_to_link.name}'
+ }
+ flag_group {
+ expand_if_true: 'libraries_to_link.is_whole_archive'
+ flag: '/WHOLEARCHIVE:%{libraries_to_link.name}'
+ }
+ }
+ }
+ }
+ }
+
+ # Since this feature is declared earlier in the CROSSTOOL than
+ # "user_link_flags", this feature will be applied prior to it anwyhere they
+ # are both implied. And since "user_link_flags" contains the linkopts from
+ # the build rule, this allows the user to override the /SUBSYSTEM in the BUILD
+ # file.
+ feature {
+ name: 'linker_subsystem_flag'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/SUBSYSTEM:CONSOLE'
+ }
+ }
+ }
+
+ # The "user_link_flags" contains user-defined linkopts (from build rules)
+ # so it should be defined after features that declare user-overridable flags.
+ # For example the "linker_subsystem_flag" defines a default "/SUBSYSTEM" flag
+ # but we want to let the user override it, therefore "link_flag_subsystem" is
+ # defined earlier in the CROSSTOOL file than "user_link_flags".
+ feature {
+ name: 'user_link_flags'
+ flag_set {
+ expand_if_all_available: 'user_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'user_link_flags'
+ flag: '%{user_link_flags}'
+ }
+ }
+ }
+ feature {
+ name: 'legacy_link_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'legacy_link_flags'
+ flag: '%{legacy_link_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'linker_param_file'
+ flag_set {
+ expand_if_all_available: 'linker_param_file'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '@%{linker_param_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'static_link_msvcrt'
+ }
+
+ feature {
+ name: 'static_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MT"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MD"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'static_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MTd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MDd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dbg'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FULL"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'fastbuild'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FASTLINK"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'opt'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/O2"
+ flag: "/DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: 'user_compile_flags'
+ flag_set {
+ expand_if_all_available: 'user_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'user_compile_flags'
+ flag: '%{user_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'sysroot'
+ flag_set {
+ expand_if_all_available: 'sysroot'
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'sysroot'
+ flag: '--sysroot=%{sysroot}'
+ }
+ }
+ }
+
+ feature {
+ name: 'unfiltered_compile_flags'
+ flag_set {
+ expand_if_all_available: 'unfiltered_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'unfiltered_compile_flags'
+ flag: '%{unfiltered_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_output_flags'
+ flag_set {
+ action: 'assemble'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ flag: '/Zi'
+ }
+ }
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_assembly_file'
+ flag: '/Fa%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_preprocess_file'
+ flag: '/P'
+ flag: '/Fi%{output_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_input_flags'
+ flag_set {
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'source_file'
+ flag: '/c'
+ flag: '%{source_file}'
+ }
+ }
+ }
+
+ feature {
+ name : 'def_file',
+ flag_set {
+ expand_if_all_available: 'def_file_path'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEF:%{def_file_path}"
+ # We can specify a different DLL name in DEF file, /ignore:4070 suppresses
+ # the warning message about DLL name doesn't match the default one.
+ # See https://msdn.microsoft.com/en-us/library/sfkk2fz7.aspx
+ flag: "/ignore:4070"
+ }
+ }
+ }
+
+ feature {
+ name: 'windows_export_all_symbols'
+ }
+
+ feature {
+ name: 'no_windows_export_all_symbols'
+ }
+
+ linking_mode_flags { mode: DYNAMIC }
+}
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index 2558f46fd5..f4f4d0ee96 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -175,6 +175,11 @@ def InvokeNvcc(argv, log=False):
# any other reliable way to just get the list of source files to be compiled.
src_files = GetOptionValue(argv, 'c')
+ # Pass -w through from host to nvcc, but don't do anything fancier with
+ # warnings-related flags, since they're not necessarily the same across
+ # compilers.
+ warning_options = ' -w' if '-w' in argv else ''
+
if len(src_files) == 0:
return 1
if len(out_file) != 1:
@@ -205,6 +210,7 @@ def InvokeNvcc(argv, log=False):
nvccopts += defines
nvccopts += std_options
nvccopts += m_options
+ nvccopts += warning_options
if depfiles:
# Generate the dependency file
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl
new file mode 100644
index 0000000000..8f8fb3e423
--- /dev/null
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.bat.tpl
@@ -0,0 +1,20 @@
+:: Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: http://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+:: =============================================================================
+
+:: Invoke msvc_wrapper_for_nvcc.py, which is located in the same directory.
+@echo OFF
+set arg0=%~0
+for %%F in ("%arg0%") do set DRIVER_BIN=%%~dpF
+"%{python_binary}" -B "%DRIVER_BIN%\msvc_wrapper_for_nvcc.py" %*
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
new file mode 100644
index 0000000000..1a09756813
--- /dev/null
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows.
+
+DESCRIPTION:
+ This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('%{cpu_compiler}')
+GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
+
+NVCC_PATH = '%{nvcc_path}'
+NVCC_VERSION = '%{cuda_version}'
+NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
+supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from options.
+
+ Args:
+ option: The option whose value to extract, without the leading '/'.
+
+ Returns:
+ 1. A list of values, either directly following the option,
+ (eg., /opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., /opt val1 /opt val2).
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser(prefix_chars='/')
+ parser.add_argument('/' + option, nargs='*', action='append')
+ args, leftover = parser.parse_known_args(argv)
+ if args and vars(args)[option]:
+ return (sum(vars(args)[option], []), leftover)
+ return ([], leftover)
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ 1. The string that can be passed directly to nvcc.
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, leftover = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return (['--' + a for a in options], leftover)
+ return ([], leftover)
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ src_files = [f for f in argv if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ if len(src_files) == 0:
+ raise Error('No source files found for cuda compilation.')
+
+ out_file = [ f for f in argv if f.startswith('/Fo') ]
+ if len(out_file) != 1:
+ raise Error('Please sepecify exactly one output file for cuda compilation.')
+ out = ['-o', out_file[0][len('/Fo'):]]
+
+ nvcc_compiler_options, argv = GetNvccOptions(argv)
+
+ opt_option, argv = GetOptionValue(argv, 'O')
+ opt = ['-g', '-G']
+ if (len(opt_option) > 0 and opt_option[0] != 'd'):
+ opt = ['-O2']
+
+ include_options, argv = GetOptionValue(argv, 'I')
+ includes = ["-I " + include for include in include_options]
+
+ defines, argv = GetOptionValue(argv, 'D')
+ defines = ['-D' + define for define in defines]
+
+ undefines, argv = GetOptionValue(argv, 'U')
+ undefines = ['-U' + define for define in undefines]
+
+ # The rest of the unrecongized options should be passed to host compiler
+ host_compiler_options = [option for option in argv if option not in (src_files + out_file)]
+
+ m_options = ["-m64"]
+
+ nvccopts = ['-D_FORCE_INLINES']
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+ capability, capability, capability)]
+ nvccopts += nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += m_options
+ nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+ nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
+ # If we don't specify --keep-dir, nvcc will generate intermediate files under TEMP
+ # Put them under NVCC_TEMP_DIR instead, then Bazel can ignore files under NVCC_TEMP_DIR during dependency check
+ # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver
+ # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists.
+ if os.path.isfile(NVCC_TEMP_DIR):
+ os.remove(NVCC_TEMP_DIR)
+ if not os.path.exists(NVCC_TEMP_DIR):
+ os.makedirs(NVCC_TEMP_DIR)
+ nvccopts += ['--keep', '--keep-dir', NVCC_TEMP_DIR]
+ cmd = [NVCC_PATH] + nvccopts
+ if log:
+ Log(cmd)
+ proc = subprocess.Popen(cmd,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ env=os.environ.copy(),
+ shell=True)
+ proc.wait()
+ return proc.returncode
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log'))
+ and not flag.startswith(('-nvcc_options'))]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl
index 2a37c65bc7..f6b497f813 100644
--- a/third_party/gpus/cuda/BUILD.tpl
+++ b/third_party/gpus/cuda/BUILD.tpl
@@ -128,6 +128,15 @@ cc_library(
)
cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
name = "cufft",
srcs = ["cuda/lib/%{cufft_lib}"],
data = ["cuda/lib/%{cufft_lib}"],
diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl
new file mode 100644
index 0000000000..ff6b3cc351
--- /dev/null
+++ b/third_party/gpus/cuda/BUILD.windows.tpl
@@ -0,0 +1,163 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_nvcc",
+ values = {
+ "define": "using_cuda_nvcc=true",
+ },
+)
+
+config_setting(
+ name = "using_clang",
+ values = {
+ "define": "using_cuda_clang=true",
+ },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting(
+ name = "using_clang_opt",
+ values = {
+ "define": "using_cuda_clang=true",
+ "compilation_mode": "opt",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {"cpu": "darwin"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "freebsd",
+ values = {"cpu": "freebsd"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ %{cuda_headers}
+ ],
+ includes = [
+ ".",
+ "cuda/include",
+ "cuda/include/crt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudart_static",
+ # /WHOLEARCHIVE:cudart_static.lib will cause a
+ # "Internal error during CImplib::EmitThunk" error.
+ # Treat this library as interface library to avoid being whole archived when
+ # linking a DLL that depends on this.
+ # TODO(pcloudy): Remove this rule after b/111278841 is resolved.
+ interface_library = "cuda/lib/%{cudart_static_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cuda_driver",
+ interface_library = "cuda/lib/%{cuda_driver_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudart",
+ interface_library = "cuda/lib/%{cudart_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cublas",
+ interface_library = "cuda/lib/%{cublas_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cusolver",
+ interface_library = "cuda/lib/%{cusolver_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cudnn",
+ interface_library = "cuda/lib/%{cudnn_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cufft",
+ interface_library = "cuda/lib/%{cufft_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "curand",
+ interface_library = "cuda/lib/%{curand_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cublas",
+ ":cuda_headers",
+ ":cudart",
+ ":cudnn",
+ ":cufft",
+ ":curand",
+ ],
+)
+
+cc_library(
+ name = "cupti_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-extras",
+ ],
+ includes = [
+ ".",
+ "cuda/extras/CUPTI/include/",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_import(
+ name = "cupti_dsos",
+ interface_library = "cuda/lib/%{cupti_lib}",
+ system_provided = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "libdevice_root",
+ data = [":cuda-nvvm"],
+ visibility = ["//visibility:public"],
+)
+
+%{cuda_include_genrules}
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index c90c66912d..e848fa175c 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -20,6 +20,7 @@
`/usr/local/cuda`.
* `TF_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
`3.5,5.2`.
+ * `PYTHON_BIN_PATH`: The python binary path
"""
_GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH"
@@ -31,6 +32,7 @@ _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH"
_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO"
_TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
+_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
_DEFAULT_CUDA_VERSION = ""
_DEFAULT_CUDNN_VERSION = ""
@@ -44,12 +46,12 @@ _DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
# will be used. For example, when looking for the cudart libraries, the first
# attempt will be lib64/cudart inside the CUDA toolkit.
CUDA_LIB_PATHS = [
- "lib64/",
- "lib64/stubs/",
- "lib/x86_64-linux-gnu/",
- "lib/x64/",
- "lib/",
- "",
+ "lib64/",
+ "lib64/stubs/",
+ "lib/x86_64-linux-gnu/",
+ "lib/x64/",
+ "lib/",
+ "",
]
# Lookup paths for cupti.h, relative to the CUDA toolkit directory.
@@ -57,8 +59,8 @@ CUDA_LIB_PATHS = [
# On most systems, the cupti library is not installed in the same directory as
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_HEADER_PATHS = [
- "extras/CUPTI/include/",
- "include/cuda/CUPTI/",
+ "extras/CUPTI/include/",
+ "include/cuda/CUPTI/",
]
# Lookup paths for the cupti library, relative to the
@@ -66,25 +68,25 @@ CUPTI_HEADER_PATHS = [
# On most systems, the cupti library is not installed in the same directory as
# the other CUDA libraries but rather in a special extras/CUPTI directory.
CUPTI_LIB_PATHS = [
- "extras/CUPTI/lib64/",
- "lib/x86_64-linux-gnu",
- "lib64/",
- "extras/CUPTI/libx64/",
- "extras/CUPTI/lib/",
- "lib/",
+ "extras/CUPTI/lib64/",
+ "lib/x86_64-linux-gnu",
+ "lib64/",
+ "extras/CUPTI/libx64/",
+ "extras/CUPTI/lib/",
+ "lib/",
]
# Lookup paths for CUDA headers (cuda.h) relative to the CUDA toolkit directory.
CUDA_INCLUDE_PATHS = [
- "include/",
- "include/cuda/"
+ "include/",
+ "include/cuda/",
]
# Lookup paths for cudnn.h relative to the CUDNN install directory.
CUDNN_INCLUDE_PATHS = [
- "",
- "include/",
- "include/cuda/",
+ "",
+ "include/",
+ "include/cuda/",
]
# Lookup paths for NVVM libdevice relative to the CUDA directory toolkit.
@@ -92,686 +94,841 @@ CUDNN_INCLUDE_PATHS = [
# libdevice implements mathematical functions for GPU kernels, and is provided
# in NVVM bitcode (a subset of LLVM bitcode).
NVVM_LIBDEVICE_PATHS = [
- "nvvm/libdevice/",
- "share/cuda/",
+ "nvvm/libdevice/",
+ "share/cuda/",
+]
+
+# Files used to detect the NVVM libdevice path.
+NVVM_LIBDEVICE_FILES = [
+ # CUDA 9.0 has a single file.
+ "libdevice.10.bc",
+
+ # CUDA 8.0 has separate files for compute versions 2.0, 3.0, 3.5 and 5.0.
+ # Probing for one of them is sufficient.
+ "libdevice.compute_20.10.bc",
]
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
+load(
+ "@bazel_tools//tools/cpp:lib_cc_configure.bzl",
+ "escape_string",
+ "get_env_var",
+)
+load(
+ "@bazel_tools//tools/cpp:windows_cc_configure.bzl",
+ "find_msvc_tool",
+ "find_vc_path",
+ "setup_vc_env_vars",
+)
+
+def _get_python_bin(repository_ctx):
+ """Gets the python bin path."""
+ python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
+ if python_bin != None:
+ return python_bin
+ python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
+ python_bin_path = repository_ctx.which(python_bin_name)
+ if python_bin_path != None:
+ return str(python_bin_path)
+ auto_configure_fail("Cannot find python in PATH, please make sure " +
+ "python is installed and add its directory in PATH, or --define " +
+ "%s='/something/else'.\nPATH=%s" % (
+ _PYTHON_BIN_PATH,
+ repository_ctx.os.environ.get("PATH", ""),
+ ))
+
+def _get_nvcc_tmp_dir_for_windows(repository_ctx):
+ """Return the tmp directory for nvcc to generate intermediate source files."""
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
+ )
+ return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
+
+def _get_msvc_compiler(repository_ctx):
+ vc_path = find_vc_path(repository_ctx)
+ return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
+
+def _get_win_cuda_defines(repository_ctx):
+ """Return CROSSTOOL defines for Windows"""
+
+ # If we are not on Windows, return empty vaules for Windows specific fields.
+ # This ensures the CROSSTOOL file parser is happy.
+ if not _is_windows(repository_ctx):
+ return {
+ "%{msvc_env_tmp}": "",
+ "%{msvc_env_path}": "",
+ "%{msvc_env_include}": "",
+ "%{msvc_env_lib}": "",
+ "%{msvc_cl_path}": "",
+ "%{msvc_ml_path}": "",
+ "%{msvc_link_path}": "",
+ "%{msvc_lib_path}": "",
+ "%{cxx_builtin_include_directory}": "",
+ }
+
+ vc_path = find_vc_path(repository_ctx)
+ if not vc_path:
+ auto_configure_fail("Visual C++ build tools not found on your machine." +
+ "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using")
+ return {}
+
+ env = setup_vc_env_vars(repository_ctx, vc_path)
+ escaped_paths = escape_string(env["PATH"])
+ escaped_include_paths = escape_string(env["INCLUDE"])
+ escaped_lib_paths = escape_string(env["LIB"])
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
+ )
+
+ msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat"
+ msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace("\\", "/")
+ msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace("\\", "/")
+ msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace("\\", "/")
+
+ # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
+ # The generated files are guranteed to have unique name, so they can share the same tmp directory
+ escaped_cxx_include_directories = ["cxx_builtin_include_directory: \"%s\"" % _get_nvcc_tmp_dir_for_windows(repository_ctx)]
+ for path in escaped_include_paths.split(";"):
+ if path:
+ escaped_cxx_include_directories.append("cxx_builtin_include_directory: \"%s\"" % path)
+
+ return {
+ "%{msvc_env_tmp}": escaped_tmp_dir,
+ "%{msvc_env_path}": escaped_paths,
+ "%{msvc_env_include}": escaped_include_paths,
+ "%{msvc_env_lib}": escaped_lib_paths,
+ "%{msvc_cl_path}": msvc_cl_path,
+ "%{msvc_ml_path}": msvc_ml_path,
+ "%{msvc_link_path}": msvc_link_path,
+ "%{msvc_lib_path}": msvc_lib_path,
+ "%{cxx_builtin_include_directory}": "\n".join(escaped_cxx_include_directories),
+ }
# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
- """Find the C++ compiler."""
- # On Windows, we use Bazel's MSVC CROSSTOOL for GPU build
- # Return a dummy value for GCC detection here to avoid error
- if _is_windows(repository_ctx):
- return "/use/--config=win-cuda --cpu=x64_windows_msvc/instead"
-
- if _use_cuda_clang(repository_ctx):
- target_cc_name = "clang"
- cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
- if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
- return "extra_tools/bin/clang"
- else:
- target_cc_name = "gcc"
- cc_path_envvar = _GCC_HOST_COMPILER_PATH
- cc_name = target_cc_name
-
- if cc_path_envvar in repository_ctx.os.environ:
- cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
- if cc_name_from_env:
- cc_name = cc_name_from_env
- if cc_name.startswith("/"):
- # Absolute path, maybe we should make this supported by our which function.
- return cc_name
- cc = repository_ctx.which(cc_name)
- if cc == None:
- fail(("Cannot find {}, either correct your path or set the {}" +
- " environment variable").format(target_cc_name, cc_path_envvar))
- return cc
-
+ """Find the C++ compiler."""
+ if _is_windows(repository_ctx):
+ return _get_msvc_compiler(repository_ctx)
+
+ if _use_cuda_clang(repository_ctx):
+ target_cc_name = "clang"
+ cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
+ if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
+ return "extra_tools/bin/clang"
+ else:
+ target_cc_name = "gcc"
+ cc_path_envvar = _GCC_HOST_COMPILER_PATH
+ cc_name = target_cc_name
+
+ if cc_path_envvar in repository_ctx.os.environ:
+ cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
+ if cc_name_from_env:
+ cc_name = cc_name_from_env
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this supported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(("Cannot find {}, either correct your path or set the {}" +
+ " environment variable").format(target_cc_name, cc_path_envvar))
+ return cc
_INC_DIR_MARKER_BEGIN = "#include <...>"
-
# OSX add " (framework directory)" at the end of line, strip it.
_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
-_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
-def _cxx_inc_convert(path):
- """Convert path returned by cc -E xc++ in a complete path."""
- path = path.strip()
- if path.endswith(_OSX_FRAMEWORK_SUFFIX):
- path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
- return path
+_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
+def _cxx_inc_convert(path):
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ if path.endswith(_OSX_FRAMEWORK_SUFFIX):
+ path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
+ return path
def _normalize_include_path(repository_ctx, path):
- """Normalizes include paths before writing them to the crosstool.
+ """Normalizes include paths before writing them to the crosstool.
- If path points inside the 'crosstool' folder of the repository, a relative
- path is returned.
- If path points outside the 'crosstool' folder, an absolute path is returned.
- """
- path = str(repository_ctx.path(path))
- crosstool_folder = str(repository_ctx.path(".").get_child('crosstool'))
-
- if path.startswith(crosstool_folder):
- # We drop the path to "$REPO/crosstool" and a trailing path separator.
- return path[len(crosstool_folder)+1:]
- return path
+ If path points inside the 'crosstool' folder of the repository, a relative
+ path is returned.
+ If path points outside the 'crosstool' folder, an absolute path is returned.
+ """
+ path = str(repository_ctx.path(path))
+ crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
+ if path.startswith(crosstool_folder):
+ # We drop the path to "$REPO/crosstool" and a trailing path separator.
+ return path[len(crosstool_folder) + 1:]
+ return path
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
- """Compute the list of default C or C++ include directories."""
- if lang_is_cpp:
- lang = "c++"
- else:
- lang = "c"
- result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
- index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
- if index1 == -1:
- return []
- index1 = result.stderr.find("\n", index1)
- if index1 == -1:
- return []
- index2 = result.stderr.rfind("\n ")
- if index2 == -1 or index2 < index1:
- return []
- index2 = result.stderr.find("\n", index2 + 1)
- if index2 == -1:
- inc_dirs = result.stderr[index1 + 1:]
- else:
- inc_dirs = result.stderr[index1 + 1:index2].strip()
-
- return [
- _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
- for p in inc_dirs.split("\n")
- ]
+ """Compute the list of default C or C++ include directories."""
+ if lang_is_cpp:
+ lang = "c++"
+ else:
+ lang = "c"
+ result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+ return [
+ _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
+ for p in inc_dirs.split("\n")
+ ]
def get_cxx_inc_directories(repository_ctx, cc):
- """Compute the list of default C and C++ include directories."""
- # For some reason `clang -xc` sometimes returns include paths that are
- # different from the ones from `clang -xc++`. (Symlink and a dir)
- # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
- includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
- includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+ """Compute the list of default C and C++ include directories."""
- includes_cpp_set = depset(includes_cpp)
- return includes_cpp + [inc for inc in includes_c
- if inc not in includes_cpp_set]
+ # For some reason `clang -xc` sometimes returns include paths that are
+ # different from the ones from `clang -xc++`. (Symlink and a dir)
+ # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
+ includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
+ includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+ includes_cpp_set = depset(includes_cpp)
+ return includes_cpp + [
+ inc
+ for inc in includes_c
+ if inc not in includes_cpp_set
+ ]
def auto_configure_fail(msg):
- """Output failure message when cuda configuration fails."""
- red = "\033[0;31m"
- no_color = "\033[0m"
- fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
-# END cc_configure common functions (see TODO above).
+ """Output failure message when cuda configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
+# END cc_configure common functions (see TODO above).
def _host_compiler_includes(repository_ctx, cc):
- """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
-
- Args:
- repository_ctx: The repository context.
- cc: The path to the gcc host compiler.
-
- Returns:
- A string containing the cxx_builtin_include_directory for each of the gcc
- host compiler include directories, which can be added to the CROSSTOOL
- file.
- """
- inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
- inc_entries = []
- for inc_dir in inc_dirs:
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
- return "\n".join(inc_entries)
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
+
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
def _cuda_include_path(repository_ctx, cuda_config):
- """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
-
- Args:
- repository_ctx: The repository context.
- cc: The path to the gcc host compiler.
-
- Returns:
- A string containing the cxx_builtin_include_directory for each of the gcc
- host compiler include directories, which can be added to the CROSSTOOL
- file.
- """
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_config.cuda_toolkit_path,
- ".exe" if cuda_config.cpu_value == "Windows" else ""))
- result = repository_ctx.execute([nvcc_path, '-v',
- '/dev/null', '-o', '/dev/null'])
- target_dir = ""
- for one_line in result.stderr.splitlines():
- if one_line.startswith('#$ _TARGET_DIR_='):
- target_dir = (cuda_config.cuda_toolkit_path + '/' +
- one_line.replace('#$ _TARGET_DIR_=', '') + "/include")
- inc_entries = []
- if target_dir != "":
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
- default_include = cuda_config.cuda_toolkit_path + '/include'
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" %
- default_include)
- return "\n".join(inc_entries)
+ """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
+ Args:
+ repository_ctx: The repository context.
+ cc: The path to the gcc host compiler.
-def _enable_cuda(repository_ctx):
- if "TF_NEED_CUDA" in repository_ctx.os.environ:
- enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
- return enable_cuda == "1"
- return False
+ Returns:
+ A string containing the cxx_builtin_include_directory for each of the gcc
+ host compiler include directories, which can be added to the CROSSTOOL
+ file.
+ """
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if cuda_config.cpu_value == "Windows" else "",
+ ))
+ result = repository_ctx.execute([
+ nvcc_path,
+ "-v",
+ "/dev/null",
+ "-o",
+ "/dev/null",
+ ])
+ target_dir = ""
+ for one_line in result.stderr.splitlines():
+ if one_line.startswith("#$ _TARGET_DIR_="):
+ target_dir = (cuda_config.cuda_toolkit_path + "/" +
+ one_line.replace("#$ _TARGET_DIR_=", "") + "/include")
+ inc_entries = []
+ if target_dir != "":
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
+ default_include = cuda_config.cuda_toolkit_path + "/include"
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" %
+ default_include)
+ return "\n".join(inc_entries)
+def _enable_cuda(repository_ctx):
+ if "TF_NEED_CUDA" in repository_ctx.os.environ:
+ enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
+ return enable_cuda == "1"
+ return False
def _cuda_toolkit_path(repository_ctx):
- """Finds the cuda toolkit directory.
-
- Args:
- repository_ctx: The repository context.
+ """Finds the cuda toolkit directory.
- Returns:
- A speculative real path of the cuda toolkit install directory.
- """
- cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
- if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
- cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
- if not repository_ctx.path(cuda_toolkit_path).exists:
- auto_configure_fail("Cannot find cuda toolkit path.")
- return str(repository_ctx.path(cuda_toolkit_path).realpath)
+ Args:
+ repository_ctx: The repository context.
+ Returns:
+ A speculative real path of the cuda toolkit install directory.
+ """
+ cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
+ if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
+ cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
+ if not repository_ctx.path(cuda_toolkit_path).exists:
+ auto_configure_fail("Cannot find cuda toolkit path.")
+ return str(repository_ctx.path(cuda_toolkit_path).realpath)
def _cudnn_install_basedir(repository_ctx):
- """Finds the cudnn install directory."""
- cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
- if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
- cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
- if not repository_ctx.path(cudnn_install_path).exists:
- auto_configure_fail("Cannot find cudnn install path.")
- return cudnn_install_path
-
+ """Finds the cudnn install directory."""
+ cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
+ if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
+ cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
+ if not repository_ctx.path(cudnn_install_path).exists:
+ auto_configure_fail("Cannot find cudnn install path.")
+ return cudnn_install_path
def matches_version(environ_version, detected_version):
- """Checks whether the user-specified version matches the detected version.
-
- This function performs a weak matching so that if the user specifies only the
- major or major and minor versions, the versions are still considered matching
- if the version parts match. To illustrate:
-
- environ_version detected_version result
- -----------------------------------------
- 5.1.3 5.1.3 True
- 5.1 5.1.3 True
- 5 5.1 True
- 5.1.3 5.1 False
- 5.2.3 5.1.3 False
-
- Args:
- environ_version: The version specified by the user via environment
- variables.
- detected_version: The version autodetected from the CUDA installation on
- the system.
-
- Returns: True if user-specified version matches detected version and False
- otherwise.
- """
- environ_version_parts = environ_version.split(".")
- detected_version_parts = detected_version.split(".")
- if len(detected_version_parts) < len(environ_version_parts):
- return False
- for i, part in enumerate(detected_version_parts):
- if i >= len(environ_version_parts):
- break
- if part != environ_version_parts[i]:
- return False
- return True
-
+ """Checks whether the user-specified version matches the detected version.
+
+ This function performs a weak matching so that if the user specifies only the
+ major or major and minor versions, the versions are still considered matching
+ if the version parts match. To illustrate:
+
+ environ_version detected_version result
+ -----------------------------------------
+ 5.1.3 5.1.3 True
+ 5.1 5.1.3 True
+ 5 5.1 True
+ 5.1.3 5.1 False
+ 5.2.3 5.1.3 False
+
+ Args:
+ environ_version: The version specified by the user via environment
+ variables.
+ detected_version: The version autodetected from the CUDA installation on
+ the system.
+
+ Returns: True if user-specified version matches detected version and False
+ otherwise.
+ """
+ environ_version_parts = environ_version.split(".")
+ detected_version_parts = detected_version.split(".")
+ if len(detected_version_parts) < len(environ_version_parts):
+ return False
+ for i, part in enumerate(detected_version_parts):
+ if i >= len(environ_version_parts):
+ break
+ if part != environ_version_parts[i]:
+ return False
+ return True
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
-
def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
- """Detects the version of CUDA installed on the system.
-
- Args:
- repository_ctx: The repository context.
- cuda_toolkit_path: The CUDA install directory.
-
- Returns:
- String containing the version of CUDA.
- """
- # Run nvcc --version and find the line containing the CUDA version.
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_toolkit_path,
- ".exe" if cpu_value == "Windows" else ""))
- if not nvcc_path.exists:
- auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
- result = repository_ctx.execute([str(nvcc_path), '--version'])
- if result.stderr:
- auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
- lines = result.stdout.splitlines()
- version_line = lines[len(lines) - 1]
- if version_line.find(_NVCC_VERSION_PREFIX) == -1:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout)
-
- # Parse the CUDA version from the line containing the CUDA version.
- prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, '')
- parts = prefix_removed.split(",")
- if len(parts) != 2 or len(parts[0]) < 2:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout)
- full_version = parts[1].strip()
- if full_version.startswith('V'):
- full_version = full_version[1:]
-
- # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDA_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- auto_configure_fail(
- ("CUDA version detected from nvcc (%s) does not match " +
- "TF_CUDA_VERSION (%s)") % (full_version, environ_version))
-
- # We only use the version consisting of the major and minor version numbers.
- version_parts = full_version.split('.')
- if len(version_parts) < 2:
- auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
- if cpu_value == "Windows":
- version = "64_%s%s" % (version_parts[0], version_parts[1])
- else:
- version = "%s.%s" % (version_parts[0], version_parts[1])
- return version
+ """Detects the version of CUDA installed on the system.
+
+ Args:
+ repository_ctx: The repository context.
+ cuda_toolkit_path: The CUDA install directory.
+
+ Returns:
+ String containing the version of CUDA.
+ """
+
+ # Run nvcc --version and find the line containing the CUDA version.
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_toolkit_path,
+ ".exe" if cpu_value == "Windows" else "",
+ ))
+ if not nvcc_path.exists:
+ auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
+ result = repository_ctx.execute([str(nvcc_path), "--version"])
+ if result.stderr:
+ auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
+ lines = result.stdout.splitlines()
+ version_line = lines[len(lines) - 1]
+ if version_line.find(_NVCC_VERSION_PREFIX) == -1:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,
+ )
+ # Parse the CUDA version from the line containing the CUDA version.
+ prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "")
+ parts = prefix_removed.split(",")
+ if len(parts) != 2 or len(parts[0]) < 2:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,
+ )
+ full_version = parts[1].strip()
+ if full_version.startswith("V"):
+ full_version = full_version[1:]
+
+ # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDA_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ auto_configure_fail(
+ ("CUDA version detected from nvcc (%s) does not match " +
+ "TF_CUDA_VERSION (%s)") % (full_version, environ_version),
+ )
+
+ # We only use the version consisting of the major and minor version numbers.
+ version_parts = full_version.split(".")
+ if len(version_parts) < 2:
+ auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
+ if cpu_value == "Windows":
+ version = "64_%s%s" % (version_parts[0], version_parts[1])
+ else:
+ version = "%s.%s" % (version_parts[0], version_parts[1])
+ return version
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
_DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
-
def find_cuda_define(repository_ctx, header_dir, header_file, define):
- """Returns the value of a #define in a header file.
-
- Greps through a header file and returns the value of the specified #define.
- If the #define is not found, then raise an error.
-
- Args:
- repository_ctx: The repository context.
- header_dir: The directory containing the header file.
- header_file: The header file name.
- define: The #define to search for.
-
- Returns:
- The value of the #define found in the header.
- """
- # Confirm location of the header and grep for the line defining the macro.
- h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
- if not h_path.exists:
- auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
- result = repository_ctx.execute(
- # Grep one more lines as some #defines are splitted into two lines.
- ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
- if result.stderr:
- auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
-
- # Parse the version from the line defining the macro.
- if result.stdout.find(define) == -1:
- auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, h_path))
- # Split results to lines
- lines = result.stdout.split('\n')
- num_lines = len(lines)
- for l in range(num_lines):
- line = lines[l]
- if define in line: # Find the line with define
- version = line
- if l != num_lines-1 and line[-1] == '\\': # Add next line, if multiline
- version = version[:-1] + lines[l+1]
- break
- # Remove any comments
- version = version.split("//")[0]
- # Remove define name
- version = version.replace(define, "").strip()
- # Remove the code after the version number.
- version_end = version.find(" ")
- if version_end != -1:
- if version_end == 0:
- auto_configure_fail(
- "Cannot extract the version from line containing '%s' in %s" %
- (define, str(h_path)))
- version = version[:version_end].strip()
- return version
+ """Returns the value of a #define in a header file.
+
+ Greps through a header file and returns the value of the specified #define.
+ If the #define is not found, then raise an error.
+ Args:
+ repository_ctx: The repository context.
+ header_dir: The directory containing the header file.
+ header_file: The header file name.
+ define: The #define to search for.
+
+ Returns:
+ The value of the #define found in the header.
+ """
+
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)],
+ )
+ if result.stderr:
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
+
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
+ auto_configure_fail("Cannot find line containing '%s' in %s" %
+ (define, h_path))
+
+ # Split results to lines
+ lines = result.stdout.split("\n")
+ num_lines = len(lines)
+ for l in range(num_lines):
+ line = lines[l]
+ if define in line: # Find the line with define
+ version = line
+ if l != num_lines - 1 and line[-1] == "\\": # Add next line, if multiline
+ version = version[:-1] + lines[l + 1]
+ break
+
+ # Remove any comments
+ version = version.split("//")[0]
+
+ # Remove define name
+ version = version.replace(define, "").strip()
+
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)),
+ )
+ version = version[:version_end].strip()
+ return version
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
- """Detects the version of cuDNN installed on the system.
-
- Args:
- repository_ctx: The repository context.
- cpu_value: The name of the host operating system.
- cudnn_install_basedir: The cuDNN install directory.
-
- Returns:
- A string containing the version of cuDNN.
- """
- cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
- cudnn_install_basedir)
- major_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MAJOR)
- minor_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_MINOR)
- patch_version = find_cuda_define(
- repository_ctx, cudnn_header_dir, "cudnn.h", _DEFINE_CUDNN_PATCHLEVEL)
- full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
-
- # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDNN_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
- cudnn_install_basedir)
- auto_configure_fail(
- ("cuDNN version detected from %s (%s) does not match " +
- "TF_CUDNN_VERSION (%s)") %
- (str(cudnn_h_path), full_version, environ_version))
-
- # We only use the major version since we use the libcudnn libraries that are
- # only versioned with the major version (e.g. libcudnn.so.5).
- version = major_version
- if cpu_value == "Windows":
- version = "64_" + version
- return version
+ """Detects the version of cuDNN installed on the system.
+ Args:
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ cudnn_install_basedir: The cuDNN install directory.
-def _compute_capabilities(repository_ctx):
- """Returns a list of strings representing cuda compute capabilities."""
- if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
- return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
- capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
- capabilities = capabilities_str.split(",")
- for capability in capabilities:
- # Workaround for Skylark's lack of support for regex. This check should
- # be equivalent to checking:
- # if re.match("[0-9]+.[0-9]+", capability) == None:
- parts = capability.split(".")
- if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
- auto_configure_fail("Invalid compute capability: %s" % capability)
- return capabilities
+ Returns:
+ A string containing the version of cuDNN.
+ """
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cudnn_install_basedir,
+ )
+ major_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MAJOR,
+ )
+ minor_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MINOR,
+ )
+ patch_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_PATCHLEVEL,
+ )
+ full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+
+ # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDNN_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
+ cudnn_install_basedir)
+ auto_configure_fail(
+ ("cuDNN version detected from %s (%s) does not match " +
+ "TF_CUDNN_VERSION (%s)") %
+ (str(cudnn_h_path), full_version, environ_version),
+ )
+ # We only use the major version since we use the libcudnn libraries that are
+ # only versioned with the major version (e.g. libcudnn.so.5).
+ version = major_version
+ if cpu_value == "Windows":
+ version = "64_" + version
+ return version
-def get_cpu_value(repository_ctx):
- """Returns the name of the host operating system.
+def _compute_capabilities(repository_ctx):
+ """Returns a list of strings representing cuda compute capabilities."""
+ if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
+ return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
+ capabilities = capabilities_str.split(",")
+ for capability in capabilities:
+ # Workaround for Skylark's lack of support for regex. This check should
+ # be equivalent to checking:
+ # if re.match("[0-9]+.[0-9]+", capability) == None:
+ parts = capability.split(".")
+ if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
+ auto_configure_fail("Invalid compute capability: %s" % capability)
+ return capabilities
- Args:
- repository_ctx: The repository context.
+def get_cpu_value(repository_ctx):
+ """Returns the name of the host operating system.
- Returns:
- A string containing the name of the host operating system.
- """
- os_name = repository_ctx.os.name.lower()
- if os_name.startswith("mac os"):
- return "Darwin"
- if os_name.find("windows") != -1:
- return "Windows"
- result = repository_ctx.execute(["uname", "-s"])
- return result.stdout.strip()
+ Args:
+ repository_ctx: The repository context.
+ Returns:
+ A string containing the name of the host operating system.
+ """
+ os_name = repository_ctx.os.name.lower()
+ if os_name.startswith("mac os"):
+ return "Darwin"
+ if os_name.find("windows") != -1:
+ return "Windows"
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
def _is_windows(repository_ctx):
- """Returns true if the host operating system is windows."""
- return get_cpu_value(repository_ctx) == "Windows"
-
-def _lib_name(lib, cpu_value, version="", static=False):
- """Constructs the platform-specific name of a library.
-
- Args:
- lib: The name of the library, such as "cudart"
- cpu_value: The name of the host operating system.
- version: The version of the library.
- static: True the library is static or False if it is a shared object.
-
- Returns:
- The platform-specific name of the library.
- """
- if cpu_value in ("Linux", "FreeBSD"):
- if static:
- return "lib%s.a" % lib
- else:
- if version:
- version = ".%s" % version
- return "lib%s.so%s" % (lib, version)
- elif cpu_value == "Windows":
- return "%s.lib" % lib
- elif cpu_value == "Darwin":
- if static:
- return "lib%s.a" % lib
- else:
- if version:
- version = ".%s" % version
- return "lib%s%s.dylib" % (lib, version)
- else:
- auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
-
-
-def _find_cuda_lib(lib, repository_ctx, cpu_value, basedir, version="",
- static=False):
- """Finds the given CUDA or cuDNN library on the system.
-
- Args:
- lib: The name of the library, such as "cudart"
- repository_ctx: The repository context.
- cpu_value: The name of the host operating system.
- basedir: The install directory of CUDA or cuDNN.
- version: The version of the library.
- static: True if static library, False if shared object.
-
- Returns:
- Returns a struct with the following fields:
- file_name: The basename of the library found on the system.
- path: The full path to the library.
- """
- file_name = _lib_name(lib, cpu_value, version, static)
- for relative_path in CUDA_LIB_PATHS:
- path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
- if path.exists:
- return struct(file_name=file_name, path=str(path.realpath))
- auto_configure_fail("Cannot find cuda library %s" % file_name)
+ """Returns true if the host operating system is windows."""
+ return get_cpu_value(repository_ctx) == "Windows"
+def _lib_name(lib, cpu_value, version = "", static = False):
+ """Constructs the platform-specific name of a library.
-def _find_cupti_header_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing cupti.h
+ Args:
+ lib: The name of the library, such as "cudart"
+ cpu_value: The name of the host operating system.
+ version: The version of the library.
+ static: True the library is static or False if it is a shared object.
+
+ Returns:
+ The platform-specific name of the library.
+ """
+ if cpu_value in ("Linux", "FreeBSD"):
+ if static:
+ return "lib%s.a" % lib
+ else:
+ if version:
+ version = ".%s" % version
+ return "lib%s.so%s" % (lib, version)
+ elif cpu_value == "Windows":
+ return "%s.lib" % lib
+ elif cpu_value == "Darwin":
+ if static:
+ return "lib%s.a" % lib
+ elif version:
+ version = ".%s" % version
+ return "lib%s%s.dylib" % (lib, version)
+ else:
+ auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+
+def _find_cuda_lib(
+ lib,
+ repository_ctx,
+ cpu_value,
+ basedir,
+ version = "",
+ static = False):
+ """Finds the given CUDA or cuDNN library on the system.
+
+ Args:
+ lib: The name of the library, such as "cudart"
+ repository_ctx: The repository context.
+ cpu_value: The name of the host operating system.
+ basedir: The install directory of CUDA or cuDNN.
+ version: The version of the library.
+ static: True if static library, False if shared object.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(lib, cpu_value, version, static)
+ for relative_path in CUDA_LIB_PATHS:
+ path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+ auto_configure_fail("Cannot find cuda library %s" % file_name)
- On most systems, the cupti library is not installed in the same directory as
- the other CUDA libraries but rather in a special extras/CUPTI directory.
+def _find_cupti_header_dir(repository_ctx, cuda_config):
+ """Returns the path to the directory containing cupti.h
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
+ On most systems, the cupti library is not installed in the same directory as
+ the other CUDA libraries but rather in a special extras/CUPTI directory.
- Returns:
- The path of the directory containing the cupti header.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_HEADER_PATHS:
- if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the cupti header.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_HEADER_PATHS:
+ if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
def _find_cupti_lib(repository_ctx, cuda_config):
- """Finds the cupti library on the system.
-
- On most systems, the cupti library is not installed in the same directory as
- the other CUDA libraries but rather in a special extras/CUPTI directory.
-
- Args:
- repository_ctx: The repository context.
- cuda_config: The cuda configuration as returned by _get_cuda_config.
-
- Returns:
- Returns a struct with the following fields:
- file_name: The basename of the library found on the system.
- path: The full path to the library.
- """
- file_name = _lib_name("cupti", cuda_config.cpu_value,
- cuda_config.cuda_version)
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_LIB_PATHS:
- path = repository_ctx.path(
- "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name))
- if path.exists:
- return struct(file_name=file_name, path=str(path.realpath))
-
- auto_configure_fail("Cannot find cupti library %s" % file_name)
+ """Finds the cupti library on the system.
+
+ On most systems, the cupti library is not installed in the same directory as
+ the other CUDA libraries but rather in a special extras/CUPTI directory.
+
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The cuda configuration as returned by _get_cuda_config.
+
+ Returns:
+ Returns a struct with the following fields:
+ file_name: The basename of the library found on the system.
+ path: The full path to the library.
+ """
+ file_name = _lib_name(
+ "cupti",
+ cuda_config.cpu_value,
+ cuda_config.cuda_version,
+ )
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_LIB_PATHS:
+ path = repository_ctx.path(
+ "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name),
+ )
+ if path.exists:
+ return struct(file_name = file_name, path = str(path.realpath))
+
+ auto_configure_fail("Cannot find cupti library %s" % file_name)
def _find_libs(repository_ctx, cuda_config):
- """Returns the CUDA and cuDNN libraries on the system.
-
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
-
- Returns:
- Map of library names to structs of filename and path.
- """
- cpu_value = cuda_config.cpu_value
- return {
- "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
- "cudart": _find_cuda_lib(
- "cudart", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cudart_static": _find_cuda_lib(
- "cudart_static", repository_ctx, cpu_value,
- cuda_config.cuda_toolkit_path, cuda_config.cuda_version, static=True),
- "cublas": _find_cuda_lib(
- "cublas", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cusolver": _find_cuda_lib(
- "cusolver", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "curand": _find_cuda_lib(
- "curand", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cufft": _find_cuda_lib(
- "cufft", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version),
- "cudnn": _find_cuda_lib(
- "cudnn", repository_ctx, cpu_value, cuda_config.cudnn_install_basedir,
- cuda_config.cudnn_version),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config)
- }
+ """Returns the CUDA and cuDNN libraries on the system.
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
-def _find_cuda_include_path(repository_ctx, cuda_config):
- """Returns the path to the directory containing cuda.h
+ Returns:
+ Map of library names to structs of filename and path.
+ """
+ cpu_value = cuda_config.cpu_value
+ return {
+ "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
+ "cudart": _find_cuda_lib(
+ "cudart",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudart_static": _find_cuda_lib(
+ "cudart_static",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ static = True,
+ ),
+ "cublas": _find_cuda_lib(
+ "cublas",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cusolver": _find_cuda_lib(
+ "cusolver",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "curand": _find_cuda_lib(
+ "curand",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cufft": _find_cuda_lib(
+ "cufft",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudnn": _find_cuda_lib(
+ "cudnn",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cudnn_install_basedir,
+ cuda_config.cudnn_version,
+ ),
+ "cupti": _find_cupti_lib(repository_ctx, cuda_config),
+ }
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
+def _find_cuda_include_path(repository_ctx, cuda_config):
+ """Returns the path to the directory containing cuda.h
- Returns:
- The path of the directory containing the CUDA headers.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the CUDA headers.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
- """Returns the path to the directory containing cudnn.h
-
- Args:
- repository_ctx: The repository context.
- cudnn_install_basedir: The cudnn install directory as returned by
- _cudnn_install_basedir.
+ """Returns the path to the directory containing cudnn.h
- Returns:
- The path of the directory containing the cudnn header.
- """
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
- return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
- if repository_ctx.path("/usr/include/cudnn.h").exists:
- return "/usr/include"
- auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+ Args:
+ repository_ctx: The repository context.
+ cudnn_install_basedir: The cudnn install directory as returned by
+ _cudnn_install_basedir.
+ Returns:
+ The path of the directory containing the cudnn header.
+ """
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
+ return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
+ if repository_ctx.path("/usr/include/cudnn.h").exists:
+ return "/usr/include"
+ auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing libdevice in bitcode format.
+ """Returns the path to the directory containing libdevice in bitcode format.
- Args:
- repository_ctx: The repository context.
- cuda_config: The CUDA config as returned by _get_cuda_config
-
- Returns:
- The path of the directory containing the CUDA headers.
- """
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in NVVM_LIBDEVICE_PATHS:
- if repository_ctx.path("%s/%slibdevice.10.bc" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find libdevice.10.bc under %s" % cuda_toolkit_path)
+ Args:
+ repository_ctx: The repository context.
+ cuda_config: The CUDA config as returned by _get_cuda_config
+ Returns:
+ The path of the directory containing the CUDA headers.
+ """
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for libdevice_file in NVVM_LIBDEVICE_FILES:
+ for relative_path in NVVM_LIBDEVICE_PATHS:
+ if repository_ctx.path("%s/%s%s" % (cuda_toolkit_path, relative_path, libdevice_file)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find libdevice*.bc files under %s" % cuda_toolkit_path)
def _cudart_static_linkopt(cpu_value):
- """Returns additional platform-specific linkopts for cudart."""
- return "" if cpu_value == "Darwin" else "\"-lrt\","
+ """Returns additional platform-specific linkopts for cudart."""
+ return "" if cpu_value == "Darwin" else "\"-lrt\","
def _get_cuda_config(repository_ctx):
- """Detects and returns information about the CUDA installation on the system.
-
- Args:
- repository_ctx: The repository context.
-
- Returns:
- A struct containing the following fields:
- cuda_toolkit_path: The CUDA toolkit installation directory.
- cudnn_install_basedir: The cuDNN installation directory.
- cuda_version: The version of CUDA on the system.
- cudnn_version: The version of cuDNN on the system.
- compute_capabilities: A list of the system's CUDA compute capabilities.
- cpu_value: The name of the host operating system.
- """
- cpu_value = get_cpu_value(repository_ctx)
- cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
- cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
- cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
- cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
- return struct(
- cuda_toolkit_path = cuda_toolkit_path,
- cudnn_install_basedir = cudnn_install_basedir,
- cuda_version = cuda_version,
- cudnn_version = cudnn_version,
- compute_capabilities = _compute_capabilities(repository_ctx),
- cpu_value = cpu_value)
-
-
-def _tpl(repository_ctx, tpl, substitutions={}, out=None):
- if not out:
- out = tpl.replace(":", "/")
- repository_ctx.template(
- out,
- Label("//third_party/gpus/%s.tpl" % tpl),
- substitutions)
-
+ """Detects and returns information about the CUDA installation on the system.
+
+ Args:
+ repository_ctx: The repository context.
+
+ Returns:
+ A struct containing the following fields:
+ cuda_toolkit_path: The CUDA toolkit installation directory.
+ cudnn_install_basedir: The cuDNN installation directory.
+ cuda_version: The version of CUDA on the system.
+ cudnn_version: The version of cuDNN on the system.
+ compute_capabilities: A list of the system's CUDA compute capabilities.
+ cpu_value: The name of the host operating system.
+ """
+ cpu_value = get_cpu_value(repository_ctx)
+ cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
+ cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
+ cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
+ cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
+ return struct(
+ cuda_toolkit_path = cuda_toolkit_path,
+ cudnn_install_basedir = cudnn_install_basedir,
+ cuda_version = cuda_version,
+ cudnn_version = cudnn_version,
+ compute_capabilities = _compute_capabilities(repository_ctx),
+ cpu_value = cpu_value,
+ )
+
+def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions,
+ )
def _file(repository_ctx, label):
- repository_ctx.template(
- label.replace(":", "/"),
- Label("//third_party/gpus/%s.tpl" % label),
- {})
-
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {},
+ )
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
@@ -792,379 +949,498 @@ def error_gpu_disabled():
)
"""
-
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled")
error_gpu_disabled()
"""
-
def _create_dummy_repository(repository_ctx):
- cpu_value = get_cpu_value(repository_ctx)
-
- # Set up BUILD file for cuda/.
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]",
- })
- _tpl(repository_ctx, "cuda:BUILD",
- {
- "%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
- "%{cudart_static_lib}": _lib_name("cudart_static", cpu_value,
- static=True),
- "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
- "%{cudart_lib}": _lib_name("cudart", cpu_value),
- "%{cublas_lib}": _lib_name("cublas", cpu_value),
- "%{cusolver_lib}": _lib_name("cusolver", cpu_value),
- "%{cudnn_lib}": _lib_name("cudnn", cpu_value),
- "%{cufft_lib}": _lib_name("cufft", cpu_value),
- "%{curand_lib}": _lib_name("curand", cpu_value),
- "%{cupti_lib}": _lib_name("cupti", cpu_value),
- "%{cuda_include_genrules}": '',
- "%{cuda_headers}": '',
- })
-
- # Create dummy files for the CUDA toolkit since they are still required by
- # tensorflow/core/platform/default/build_config:cuda.
- repository_ctx.file("cuda/cuda/include/cuda.h", "")
- repository_ctx.file("cuda/cuda/include/cublas.h", "")
- repository_ctx.file("cuda/cuda/include/cudnn.h", "")
- repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
- _tpl(repository_ctx, "cuda:cuda_config.h",
- {
- "%{cuda_version}": _DEFAULT_CUDA_VERSION,
- "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
- "%{cuda_compute_capabilities}": ",".join([
- "CudaVersion(\"%s\")" % c
- for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES]),
- "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
- }, "cuda/cuda/cuda_config.h")
-
- # If cuda_configure is not configured to build with GPU support, and the user
- # attempts to build with --config=cuda, add a dummy build rule to intercept
- # this and fail with an actionable error message.
- repository_ctx.file("crosstool/error_gpu_disabled.bzl",
- _DUMMY_CROSSTOOL_BZL_FILE)
- repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
-
-
-def _execute(repository_ctx, cmdline, error_msg=None, error_details=None,
- empty_stdout_fine=False):
- """Executes an arbitrary shell command.
-
- Args:
- repository_ctx: the repository_ctx object
- cmdline: list of strings, the command to execute
- error_msg: string, a summary of the error if the command fails
- error_details: string, details about the error or steps to fix it
- empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
- it's an error
- Return:
- the result of repository_ctx.execute(cmdline)
- """
- result = repository_ctx.execute(cmdline)
- if result.stderr or not (empty_stdout_fine or result.stdout):
- auto_configure_fail(
- "\n".join([
- error_msg.strip() if error_msg else "Repository command failed",
- result.stderr.strip(),
- error_details if error_details else ""]))
- return result
-
+ cpu_value = get_cpu_value(repository_ctx)
+
+ # Set up BUILD file for cuda/.
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "False",
+ "%{cuda_extra_copts}": "[]",
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
+ "%{cudart_static_lib}": _lib_name(
+ "cudart_static",
+ cpu_value,
+ static = True,
+ ),
+ "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
+ "%{cudart_lib}": _lib_name("cudart", cpu_value),
+ "%{cublas_lib}": _lib_name("cublas", cpu_value),
+ "%{cusolver_lib}": _lib_name("cusolver", cpu_value),
+ "%{cudnn_lib}": _lib_name("cudnn", cpu_value),
+ "%{cufft_lib}": _lib_name("cufft", cpu_value),
+ "%{curand_lib}": _lib_name("curand", cpu_value),
+ "%{cupti_lib}": _lib_name("cupti", cpu_value),
+ "%{cuda_include_genrules}": "",
+ "%{cuda_headers}": "",
+ },
+ )
+
+ # Create dummy files for the CUDA toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:cuda.
+ repository_ctx.file("cuda/cuda/include/cuda.h", "")
+ repository_ctx.file("cuda/cuda/include/cublas.h", "")
+ repository_ctx.file("cuda/cuda/include/cudnn.h", "")
+ repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": _DEFAULT_CUDA_VERSION,
+ "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
+ "%{cuda_compute_capabilities}": ",".join([
+ "CudaVersion(\"%s\")" % c
+ for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ ]),
+ "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
+
+ # If cuda_configure is not configured to build with GPU support, and the user
+ # attempts to build with --config=cuda, add a dummy build rule to intercept
+ # this and fail with an actionable error message.
+ repository_ctx.file(
+ "crosstool/error_gpu_disabled.bzl",
+ _DUMMY_CROSSTOOL_BZL_FILE,
+ )
+ repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
+
+def _execute(
+ repository_ctx,
+ cmdline,
+ error_msg = None,
+ error_details = None,
+ empty_stdout_fine = False):
+ """Executes an arbitrary shell command.
+
+ Args:
+ repository_ctx: the repository_ctx object
+ cmdline: list of strings, the command to execute
+ error_msg: string, a summary of the error if the command fails
+ error_details: string, details about the error or steps to fix it
+ empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
+ it's an error
+ Return:
+ the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]),
+ )
+ return result
def _norm_path(path):
- """Returns a path with '/' and remove the trailing slash."""
- path = path.replace("\\", "/")
- if path[-1] == "/":
- path = path[:-1]
- return path
-
-
-def symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
- src_files = [], dest_files = []):
- """Returns a genrule to symlink(or copy if on Windows) a set of files.
-
- If src_dir is passed, files will be read from the given directory; otherwise
- we assume files are in src_files and dest_files
- """
- if src_dir != None:
- src_dir = _norm_path(src_dir)
- dest_dir = _norm_path(dest_dir)
- files = '\n'.join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
- # Create a list with the src_dir stripped to use for outputs.
- dest_files = files.replace(src_dir, '').splitlines()
- src_files = files.splitlines()
- command = []
- if not _is_windows(repository_ctx):
- # We clear folders that might have been generated previously to avoid
- # undesired inclusions
- command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
- command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
- command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
- command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
- outs = []
- for i in range(len(dest_files)):
- if dest_files[i] != "":
- # If we have only one file to link we do not want to use the dest_dir, as
- # $(@D) will include the full path to the file.
- dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s'
- command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest))
- outs.append(' "' + dest_dir + dest_files[i] + '",')
- genrule = _genrule(src_dir, genrule_name, " && ".join(command),
- "\n".join(outs))
- return genrule
-
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
+def symlink_genrule_for_dir(
+ repository_ctx,
+ src_dir,
+ dest_dir,
+ genrule_name,
+ src_files = [],
+ dest_files = []):
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
+
+ If src_dir is passed, files will be read from the given directory; otherwise
+ we assume files are in src_files and dest_files
+ """
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+
+ # Create a list with the src_dir stripped to use for outputs.
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+ if not _is_windows(repository_ctx):
+ # We clear folders that might have been generated previously to avoid
+ # undesired inclusions
+ command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
+ command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
+ command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
+ command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link we do not want to use the dest_dir, as
+ # $(@D) will include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+ # On Windows, symlink is not supported, so we just copy all the files.
+ cmd = "cp -f" if _is_windows(repository_ctx) else "ln -s"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ src_dir,
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
def _genrule(src_dir, genrule_name, command, outs):
- """Returns a string with a genrule.
-
- Genrule executes the given command and produces the given outputs.
- """
- return (
- 'genrule(\n' +
- ' name = "' +
- genrule_name + '",\n' +
- ' outs = [\n' +
- outs +
- '\n ],\n' +
- ' cmd = """\n' +
- command +
- '\n """,\n' +
- ')\n'
- )
+ """Returns a string with a genrule.
+ Genrule executes the given command and produces the given outputs.
+ """
+ return (
+ "genrule(\n" +
+ ' name = "' +
+ genrule_name + '",\n' +
+ " outs = [\n" +
+ outs +
+ "\n ],\n" +
+ ' cmd = """\n' +
+ command +
+ '\n """,\n' +
+ ")\n"
+ )
def _read_dir(repository_ctx, src_dir):
- """Returns a string with all files in a directory.
-
- Finds all files inside a directory, traversing subfolders and following
- symlinks. The returned string contains the full path of all files
- separated by line breaks.
- """
- if _is_windows(repository_ctx):
- src_dir = src_dir.replace("/", "\\")
- find_result = _execute(
- repository_ctx, ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
- empty_stdout_fine=True)
- # src_files will be used in genrule.outs where the paths must
- # use forward slashes.
- result = find_result.stdout.replace("\\", "/")
- else:
- find_result = _execute(
- repository_ctx, ["find", src_dir, "-follow", "-type", "f"],
- empty_stdout_fine=True)
- result = find_result.stdout
- return result
+ """Returns a string with all files in a directory.
+
+ Finds all files inside a directory, traversing subfolders and following
+ symlinks. The returned string contains the full path of all files
+ separated by line breaks.
+ """
+ if _is_windows(repository_ctx):
+ src_dir = src_dir.replace("/", "\\")
+ find_result = _execute(
+ repository_ctx,
+ ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
+ empty_stdout_fine = True,
+ )
+
+ # src_files will be used in genrule.outs where the paths must
+ # use forward slashes.
+ result = find_result.stdout.replace("\\", "/")
+ else:
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine = True,
+ )
+ result = find_result.stdout
+ return result
def _flag_enabled(repository_ctx, flag_name):
- if flag_name in repository_ctx.os.environ:
- value = repository_ctx.os.environ[flag_name].strip()
- return value == "1"
- return False
+ if flag_name in repository_ctx.os.environ:
+ value = repository_ctx.os.environ[flag_name].strip()
+ return value == "1"
+ return False
def _use_cuda_clang(repository_ctx):
- return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
+ return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
- if _use_cuda_clang(repository_ctx):
- capability_flags = ["--cuda-gpu-arch=sm_" +
- cap.replace(".", "") for cap in compute_capabilities]
- else:
- # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
- capability_flags = []
- return str(capability_flags)
+ if _use_cuda_clang(repository_ctx):
+ capability_flags = ["--cuda-gpu-arch=sm_" +
+ cap.replace(".", "") for cap in compute_capabilities]
+ else:
+ # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
+ capability_flags = []
+ return str(capability_flags)
def _create_local_cuda_repository(repository_ctx):
- """Creates the repository containing files set up to build with CUDA."""
- cuda_config = _get_cuda_config(repository_ctx)
-
- cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
- cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
- cuda_config.cudnn_install_basedir)
- cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
- nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
-
- # Set up symbolic links for the cuda toolkit by creating genrules to do
- # symlinking. We create one genrule for each directory we want to track under
- # cuda_toolkit_path
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- genrules = [symlink_genrule_for_dir(repository_ctx,
- cuda_include_path, "cuda/include", "cuda-include")]
- genrules.append(symlink_genrule_for_dir(repository_ctx,
- nvvm_libdevice_dir, "cuda/nvvm/libdevice", "cuda-nvvm"))
- genrules.append(symlink_genrule_for_dir(repository_ctx,
- cupti_header_dir, "cuda/extras/CUPTI/include", "cuda-extras"))
-
- cuda_libs = _find_libs(repository_ctx, cuda_config)
- cuda_lib_src = []
- cuda_lib_dest = []
- for lib in cuda_libs.values():
- cuda_lib_src.append(lib.path)
- cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(symlink_genrule_for_dir(repository_ctx, None, "", "cuda-lib",
- cuda_lib_src, cuda_lib_dest))
-
- # Set up the symbolic links for cudnn if cndnn was not installed to
- # CUDA_TOOLKIT_PATH.
- included_files = _read_dir(repository_ctx, cuda_include_path).replace(
- cuda_include_path, '').splitlines()
- if '/cudnn.h' not in included_files:
- genrules.append(symlink_genrule_for_dir(repository_ctx, None,
- "cuda/include/", "cudnn-include", [cudnn_header_dir + "/cudnn.h"],
- ["cudnn.h"]))
- else:
- genrules.append(
- 'filegroup(\n' +
+ """Creates the repository containing files set up to build with CUDA."""
+ cuda_config = _get_cuda_config(repository_ctx)
+
+ cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cuda_config.cudnn_install_basedir,
+ )
+ cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
+ nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
+
+ # Set up symbolic links for the cuda toolkit by creating genrules to do
+ # symlinking. We create one genrule for each directory we want to track under
+ # cuda_toolkit_path
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ genrules = [symlink_genrule_for_dir(
+ repository_ctx,
+ cuda_include_path,
+ "cuda/include",
+ "cuda-include",
+ )]
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ nvvm_libdevice_dir,
+ "cuda/nvvm/libdevice",
+ "cuda-nvvm",
+ ))
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ cupti_header_dir,
+ "cuda/extras/CUPTI/include",
+ "cuda-extras",
+ ))
+
+ cuda_libs = _find_libs(repository_ctx, cuda_config)
+ cuda_lib_src = []
+ cuda_lib_dest = []
+ for lib in cuda_libs.values():
+ cuda_lib_src.append(lib.path)
+ cuda_lib_dest.append("cuda/lib/" + lib.file_name)
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "cuda-lib",
+ cuda_lib_src,
+ cuda_lib_dest,
+ ))
+
+ # Set up the symbolic links for cudnn if cndnn was not installed to
+ # CUDA_TOOLKIT_PATH.
+ included_files = _read_dir(repository_ctx, cuda_include_path).replace(
+ cuda_include_path,
+ "",
+ ).splitlines()
+ if "/cudnn.h" not in included_files:
+ genrules.append(symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "cuda/include/",
+ "cudnn-include",
+ [cudnn_header_dir + "/cudnn.h"],
+ ["cudnn.h"],
+ ))
+ else:
+ genrules.append(
+ "filegroup(\n" +
' name = "cudnn-include",\n' +
- ' srcs = [],\n' +
- ')\n'
+ " srcs = [],\n" +
+ ")\n",
)
- # Set up BUILD file for cuda/
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx, cuda_config.compute_capabilities),
- })
- _tpl(repository_ctx, "cuda:BUILD",
- {
- "%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
- "%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
- "%{cudart_static_linkopt}": _cudart_static_linkopt(
- cuda_config.cpu_value),
- "%{cudart_lib}": cuda_libs["cudart"].file_name,
- "%{cublas_lib}": cuda_libs["cublas"].file_name,
- "%{cusolver_lib}": cuda_libs["cusolver"].file_name,
- "%{cudnn_lib}": cuda_libs["cudnn"].file_name,
- "%{cufft_lib}": cuda_libs["cufft"].file_name,
- "%{curand_lib}": cuda_libs["curand"].file_name,
- "%{cupti_lib}": cuda_libs["cupti"].file_name,
- "%{cuda_include_genrules}": "\n".join(genrules),
- "%{cuda_headers}": ('":cuda-include",\n' +
- ' ":cudnn-include",')
- })
-
- is_cuda_clang = _use_cuda_clang(repository_ctx)
-
- should_download_clang = is_cuda_clang and _flag_enabled(
- repository_ctx, _TF_DOWNLOAD_CLANG)
- if should_download_clang:
- download_clang(repository_ctx, "crosstool/extra_tools")
-
- # Set up crosstool/
- cc = find_cc(repository_ctx)
- cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
-
- host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
- cuda_defines = {}
- if is_cuda_clang:
- cuda_defines["%{host_compiler_path}"] = str(cc)
- cuda_defines["%{host_compiler_warnings}"] = """
+ # Set up BUILD file for cuda/
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "True",
+ "%{cuda_extra_copts}": _compute_cuda_extra_copts(
+ repository_ctx,
+ cuda_config.compute_capabilities,
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
+ "%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
+ "%{cudart_static_linkopt}": _cudart_static_linkopt(
+ cuda_config.cpu_value,
+ ),
+ "%{cudart_lib}": cuda_libs["cudart"].file_name,
+ "%{cublas_lib}": cuda_libs["cublas"].file_name,
+ "%{cusolver_lib}": cuda_libs["cusolver"].file_name,
+ "%{cudnn_lib}": cuda_libs["cudnn"].file_name,
+ "%{cufft_lib}": cuda_libs["cufft"].file_name,
+ "%{curand_lib}": cuda_libs["curand"].file_name,
+ "%{cupti_lib}": cuda_libs["cupti"].file_name,
+ "%{cuda_include_genrules}": "\n".join(genrules),
+ "%{cuda_headers}": ('":cuda-include",\n' +
+ ' ":cudnn-include",'),
+ },
+ "cuda/BUILD",
+ )
+
+ is_cuda_clang = _use_cuda_clang(repository_ctx)
+
+ should_download_clang = is_cuda_clang and _flag_enabled(
+ repository_ctx,
+ _TF_DOWNLOAD_CLANG,
+ )
+ if should_download_clang:
+ download_clang(repository_ctx, "crosstool/extra_tools")
+
+ # Set up crosstool/
+ cc = find_cc(repository_ctx)
+ cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
+
+ host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
+ cuda_defines = {}
+ if is_cuda_clang:
+ cuda_defines["%{host_compiler_path}"] = str(cc)
+ cuda_defines["%{host_compiler_warnings}"] = """
# Some parts of the codebase set -Werror and hit this warning, so
# switch it off for now.
flag: "-Wno-invalid-partial-specialization"
"""
- cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
- _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty"})
- repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
- else:
- cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
- cuda_defines["%{host_compiler_warnings}"] = ""
- # TODO(klimek): We currently need to inject "/" as builtin directory path
- # to disable bazel's dependency checks.
- # The problem is that:
- # - the python rules symlink the python headers into the bazel root
- # - the rules use 'includes' in the BUILD file to redirect includes of the
- # python headers through those paths
- # - bazel currently uses -isystem for include paths specified via 'includes'
- # - gcc follows symlinks when resolving files via -isystem paths, and puts
- # the resolved paths into the .d file, which makes the dependency check
- # fail for bazel
- # There are multiple possible ways to solve this:
- # 1. make bazel not use -isystem for paths specified via 'includes'
- # 2. cp the headers instead of symlinking them
- #
- # Once this is fixed, the right builtin directory path is:
- # (host_compiler_includes +
- # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
- # The cuda directory needs to be passed, as there is currently no rule
- # providing the cuda headers in the same way the python headers are
- # provided.
- cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
- nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
- (cuda_config.cuda_toolkit_path,
- ".exe" if cuda_config.cpu_value == "Windows" else "")))
- _tpl(repository_ctx, "crosstool:BUILD",
- {"%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc"})
- _tpl(repository_ctx,
- "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
- {
- "%{cpu_compiler}": str(cc),
- "%{cuda_version}": cuda_config.cuda_version,
- "%{nvcc_path}": nvcc_path,
- "%{gcc_host_compiler_path}": str(cc),
- "%{cuda_compute_capabilities}": ", ".join(
- ["\"%s\"" % c for c in cuda_config.compute_capabilities]),
- })
- _tpl(repository_ctx, "crosstool:CROSSTOOL", cuda_defines, out="crosstool/CROSSTOOL")
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
- _tpl(repository_ctx, "cuda:cuda_config.h",
- {
- "%{cuda_version}": cuda_config.cuda_version,
- "%{cudnn_version}": cuda_config.cudnn_version,
- "%{cuda_compute_capabilities}": ",".join(
- ["CudaVersion(\"%s\")" % c
- for c in cuda_config.compute_capabilities]),
- "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
- }, "cuda/cuda/cuda_config.h")
+ cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
+ _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"})
+ repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.bat", "")
+ else:
+ cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
+ cuda_defines["%{host_compiler_warnings}"] = ""
+
+ # TODO(klimek): We currently need to inject "/" as builtin directory path
+ # to disable bazel's dependency checks.
+ # The problem is that:
+ # - the python rules symlink the python headers into the bazel root
+ # - the rules use 'includes' in the BUILD file to redirect includes of the
+ # python headers through those paths
+ # - bazel currently uses -isystem for include paths specified via 'includes'
+ # - gcc follows symlinks when resolving files via -isystem paths, and puts
+ # the resolved paths into the .d file, which makes the dependency check
+ # fail for bazel
+ # There are multiple possible ways to solve this:
+ # 1. make bazel not use -isystem for paths specified via 'includes'
+ # 2. cp the headers instead of symlinking them
+ #
+ # Once this is fixed, the right builtin directory path is:
+ # (host_compiler_includes +
+ # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
+ # The cuda directory needs to be passed, as there is currently no rule
+ # providing the cuda headers in the same way the python headers are
+ # provided.
+ cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
+ nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
+ (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if _is_windows(repository_ctx) else "",
+ )))
+ _tpl(
+ repository_ctx,
+ "crosstool:BUILD",
+ {
+ "%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc",
+ "%{win_linker_files}": ":windows_msvc_wrapper_files",
+ },
+ )
+ wrapper_defines = {
+ "%{cpu_compiler}": str(cc),
+ "%{cuda_version}": cuda_config.cuda_version,
+ "%{nvcc_path}": nvcc_path,
+ "%{gcc_host_compiler_path}": str(cc),
+ "%{cuda_compute_capabilities}": ", ".join(
+ ["\"%s\"" % c for c in cuda_config.compute_capabilities],
+ ),
+ "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
+ }
+ _tpl(
+ repository_ctx,
+ "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
+ wrapper_defines,
+ )
+ _tpl(
+ repository_ctx,
+ "crosstool:windows/msvc_wrapper_for_nvcc.py",
+ wrapper_defines,
+ )
+ _tpl(
+ repository_ctx,
+ "crosstool:windows/msvc_wrapper_for_nvcc.bat",
+ {
+ "%{python_binary}": _get_python_bin(repository_ctx),
+ },
+ )
+
+ _tpl(
+ repository_ctx,
+ "crosstool:CROSSTOOL",
+ cuda_defines + _get_win_cuda_defines(repository_ctx),
+ out = "crosstool/CROSSTOOL",
+ )
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": cuda_config.cuda_version,
+ "%{cudnn_version}": cuda_config.cudnn_version,
+ "%{cuda_compute_capabilities}": ",".join(
+ [
+ "CudaVersion(\"%s\")" % c
+ for c in cuda_config.compute_capabilities
+ ],
+ ),
+ "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
- """Creates pointers to a remotely configured repo set up to build with CUDA."""
- _tpl(repository_ctx, "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx, _compute_capabilities(repository_ctx)),
-
- })
- _tpl(repository_ctx, "cuda:remote.BUILD",
- {
- "%{remote_cuda_repo}": remote_config_repo,
- }, "cuda/BUILD")
- _tpl(repository_ctx, "crosstool:remote.BUILD", {
- "%{remote_cuda_repo}": remote_config_repo,
- }, "crosstool/BUILD")
+ """Creates pointers to a remotely configured repo set up to build with CUDA."""
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "True",
+ "%{cuda_extra_copts}": _compute_cuda_extra_copts(
+ repository_ctx,
+ _compute_capabilities(repository_ctx),
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:remote.BUILD",
+ {
+ "%{remote_cuda_repo}": remote_config_repo,
+ },
+ "cuda/BUILD",
+ )
+ _tpl(repository_ctx, "crosstool:remote.BUILD", {
+ "%{remote_cuda_repo}": remote_config_repo,
+ }, "crosstool/BUILD")
def _cuda_autoconf_impl(repository_ctx):
- """Implementation of the cuda_autoconf repository rule."""
- if not _enable_cuda(repository_ctx):
- _create_dummy_repository(repository_ctx)
- else:
- if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
- _create_remote_cuda_repository(repository_ctx,
- repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO])
+ """Implementation of the cuda_autoconf repository rule."""
+ if not _enable_cuda(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
+ _create_remote_cuda_repository(
+ repository_ctx,
+ repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
+ )
else:
- _create_local_cuda_repository(repository_ctx)
-
+ _create_local_cuda_repository(repository_ctx)
cuda_configure = repository_rule(
implementation = _cuda_autoconf_impl,
@@ -1181,6 +1457,7 @@ cuda_configure = repository_rule(
_TF_CUDA_COMPUTE_CAPABILITIES,
_TF_CUDA_CONFIG_REPO,
"NVVMIR_LIBRARY_DIR",
+ _PYTHON_BIN_PATH,
],
)
diff --git a/third_party/jsoncpp.BUILD b/third_party/jsoncpp.BUILD
index 65f98410b2..cf3cba0555 100644
--- a/third_party/jsoncpp.BUILD
+++ b/third_party/jsoncpp.BUILD
@@ -6,7 +6,6 @@ cc_library(
name = "jsoncpp",
srcs = [
"include/json/assertions.h",
- "src/lib_json/json_batchallocator.h",
"src/lib_json/json_reader.cpp",
"src/lib_json/json_tool.h",
"src/lib_json/json_value.cpp",
@@ -20,9 +19,13 @@ cc_library(
"include/json/json.h",
"include/json/reader.h",
"include/json/value.h",
+ "include/json/version.h",
"include/json/writer.h",
],
- copts = ["-DJSON_USE_EXCEPTION=0"],
+ copts = [
+ "-DJSON_USE_EXCEPTION=0",
+ "-DJSON_HAS_INT64",
+ ],
includes = ["include"],
visibility = ["//visibility:public"],
deps = [":private"],
diff --git a/third_party/kafka/BUILD b/third_party/kafka/BUILD
index a839ca717e..75792b0d87 100644
--- a/third_party/kafka/BUILD
+++ b/third_party/kafka/BUILD
@@ -60,6 +60,8 @@ cc_library(
"src/rdkafka_event.h",
"src/rdkafka_feature.c",
"src/rdkafka_feature.h",
+ "src/rdkafka_header.c",
+ "src/rdkafka_header.h",
"src/rdkafka_int.h",
"src/rdkafka_interceptor.c",
"src/rdkafka_interceptor.h",
@@ -93,7 +95,6 @@ cc_library(
"src/rdkafka_sasl_int.h",
"src/rdkafka_sasl_plain.c",
"src/rdkafka_subscription.c",
- "src/rdkafka_subscription.h",
"src/rdkafka_timer.c",
"src/rdkafka_timer.h",
"src/rdkafka_topic.c",
@@ -105,6 +106,8 @@ cc_library(
"src/rdlist.h",
"src/rdlog.c",
"src/rdlog.h",
+ "src/rdmurmur2.c",
+ "src/rdmurmur2.h",
"src/rdports.c",
"src/rdports.h",
"src/rdposix.h",
diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD
index 78ed1f4e16..ee49d281ab 100644
--- a/third_party/libxsmm.BUILD
+++ b/third_party/libxsmm.BUILD
@@ -3,7 +3,7 @@
licenses(["notice"]) # BSD 3-clause
-exports_files(["LICENSE"])
+exports_files(["LICENSE.md"])
# Arguments to ./scripts/libxsmm_interface.py, see that file for detailed description.
# precision: SP & DP
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.autogenerated.BUILD
index e1c22c8151..bf9f9ca9cf 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.autogenerated.BUILD
@@ -11,7 +11,11 @@ load(
"cmake_var_string",
"expand_cmake_vars",
"gentbl",
- "llvm_target_cmake_vars",
+ "llvm_all_cmake_vars",
+ "llvm_copts",
+ "llvm_defines",
+ "llvm_linkopts",
+ "llvm_support_platform_specific_srcs_glob",
)
load(
"@org_tensorflow//third_party:common.bzl",
@@ -39,147 +43,25 @@ llvm_target_asm_printers = llvm_targets
llvm_target_disassemblers = llvm_targets
-# TODO(phawkins): the set of CMake variables was hardcoded for expediency.
-# However, we should really detect many of these via configure-time tests.
-
-# The set of CMake variables common to all targets.
-cmake_vars = {
- # Headers
- "HAVE_DIRENT_H": 1,
- "HAVE_DLFCN_H": 1,
- "HAVE_ERRNO_H": 1,
- "HAVE_EXECINFO_H": 1,
- "HAVE_FCNTL_H": 1,
- "HAVE_INTTYPES_H": 1,
- "HAVE_PTHREAD_H": 1,
- "HAVE_SIGNAL_H": 1,
- "HAVE_STDINT_H": 1,
- "HAVE_SYS_IOCTL_H": 1,
- "HAVE_SYS_MMAN_H": 1,
- "HAVE_SYS_PARAM_H": 1,
- "HAVE_SYS_RESOURCE_H": 1,
- "HAVE_SYS_STAT_H": 1,
- "HAVE_SYS_TIME_H": 1,
- "HAVE_SYS_TYPES_H": 1,
- "HAVE_TERMIOS_H": 1,
- "HAVE_UNISTD_H": 1,
- "HAVE_ZLIB_H": 1,
-
- # Features
- "HAVE_BACKTRACE": 1,
- "BACKTRACE_HEADER": "execinfo.h",
- "HAVE_DLOPEN": 1,
- "HAVE_FUTIMES": 1,
- "HAVE_GETCWD": 1,
- "HAVE_GETPAGESIZE": 1,
- "HAVE_GETRLIMIT": 1,
- "HAVE_GETRUSAGE": 1,
- "HAVE_GETTIMEOFDAY": 1,
- "HAVE_INT64_T": 1,
- "HAVE_ISATTY": 1,
- "HAVE_LIBEDIT": 1,
- "HAVE_LIBPTHREAD": 1,
- "HAVE_LIBZ": 1,
- "HAVE_MKDTEMP": 1,
- "HAVE_MKSTEMP": 1,
- "HAVE_MKTEMP": 1,
- "HAVE_PREAD": 1,
- "HAVE_PTHREAD_GETSPECIFIC": 1,
- "HAVE_PTHREAD_MUTEX_LOCK": 1,
- "HAVE_PTHREAD_RWLOCK_INIT": 1,
- "HAVE_REALPATH": 1,
- "HAVE_SBRK": 1,
- "HAVE_SETENV": 1,
- "HAVE_SETRLIMIT": 1,
- "HAVE_SIGALTSTACK": 1,
- "HAVE_STRERROR": 1,
- "HAVE_STRERROR_R": 1,
- "HAVE_STRTOLL": 1,
- "HAVE_SYSCONF": 1,
- "HAVE_UINT64_T": 1,
- "HAVE__UNWIND_BACKTRACE": 1,
-
- # LLVM features
- "ENABLE_BACKTRACES": 1,
- "LLVM_BINDIR": "/dev/null",
- "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0,
- "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0,
- "LLVM_ENABLE_THREADS": 1,
- "LLVM_ENABLE_ZLIB": 1,
- "LLVM_HAS_ATOMICS": 1,
- "LLVM_INCLUDEDIR": "/dev/null",
- "LLVM_INFODIR": "/dev/null",
- "LLVM_MANDIR": "/dev/null",
- "LLVM_NATIVE_TARGET": 1,
- "LLVM_NATIVE_TARGETINFO": 1,
- "LLVM_NATIVE_TARGETMC": 1,
- "LLVM_NATIVE_ASMPRINTER": 1,
- "LLVM_NATIVE_ASMPARSER": 1,
- "LLVM_NATIVE_DISASSEMBLER": 1,
- "LLVM_ON_UNIX": 1,
- "LLVM_PREFIX": "/dev/null",
- "LLVM_VERSION_MAJOR": 0,
- "LLVM_VERSION_MINOR": 0,
- "LLVM_VERSION_PATCH": 0,
- "LTDL_SHLIB_EXT": ".so",
- "PACKAGE_NAME": "llvm",
- "PACKAGE_STRING": "llvm tensorflow-trunk",
- "PACKAGE_VERSION": "tensorflow-trunk",
- "RETSIGTYPE": "void",
-}
-
-# CMake variables specific to the Linux platform
-linux_cmake_vars = {
- "HAVE_MALLOC_H": 1,
- "HAVE_LINK_H": 1,
- "HAVE_MALLINFO": 1,
- "HAVE_FUTIMENS": 1,
-}
-
-# CMake variables specific to the Darwin (Mac OS X) platform.
-darwin_cmake_vars = {
- "HAVE_MALLOC_MALLOC_H": 1,
-}
-
-# Select a set of CMake variables based on the platform.
-# TODO(phawkins): use a better method to select the right host triple, rather
-# than hardcoding x86_64.
-all_cmake_vars = select({
- "@org_tensorflow//tensorflow:darwin": cmake_var_string(
- cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
- darwin_cmake_vars,
- ),
- "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string(
- cmake_vars +
- llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
- linux_cmake_vars,
- ),
- "//conditions:default": cmake_var_string(
- cmake_vars +
- llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +
- linux_cmake_vars,
- ),
-})
-
# Performs CMake variable substitutions on configuration header files.
expand_cmake_vars(
name = "config_gen",
src = "include/llvm/Config/config.h.cmake",
- cmake_vars = all_cmake_vars,
+ cmake_vars = llvm_all_cmake_vars,
dst = "include/llvm/Config/config.h",
)
expand_cmake_vars(
name = "llvm_config_gen",
src = "include/llvm/Config/llvm-config.h.cmake",
- cmake_vars = all_cmake_vars,
+ cmake_vars = llvm_all_cmake_vars,
dst = "include/llvm/Config/llvm-config.h",
)
expand_cmake_vars(
name = "abi_breaking_gen",
src = "include/llvm/Config/abi-breaking.h.cmake",
- cmake_vars = all_cmake_vars,
+ cmake_vars = llvm_all_cmake_vars,
dst = "include/llvm/Config/abi-breaking.h",
)
@@ -240,14 +122,7 @@ cc_library(
"include/llvm/Config/config.h",
"include/llvm/Config/llvm-config.h",
],
- defines = [
- "LLVM_ENABLE_STATS",
- "__STDC_LIMIT_MACROS",
- "__STDC_CONSTANT_MACROS",
- "__STDC_FORMAT_MACROS",
- "_DEBUG",
- "LLVM_BUILD_GLOBAL_ISEL",
- ],
+ defines = llvm_defines,
includes = ["include"],
)
@@ -263,17 +138,6 @@ genrule(
# Rules that apply the LLVM tblgen tool.
gentbl(
- name = "intrinsics_gen",
- tbl_outs = [("-gen-intrinsic", "include/llvm/IR/Intrinsics.inc")],
- tblgen = ":llvm-tblgen",
- td_file = "include/llvm/IR/Intrinsics.td",
- td_srcs = glob([
- "include/llvm/CodeGen/*.td",
- "include/llvm/IR/Intrinsics*.td",
- ]),
-)
-
-gentbl(
name = "attributes_gen",
tbl_outs = [("-gen-attrs", "include/llvm/IR/Attributes.inc")],
tblgen = ":llvm-tblgen",
@@ -292,6 +156,42 @@ gentbl(
],
)
+gentbl(
+ name = "instcombine_transforms_gen",
+ tbl_outs = [(
+ "-gen-searchable-tables",
+ "lib/Transforms/InstCombine/InstCombineTables.inc",
+ )],
+ tblgen = ":llvm-tblgen",
+ td_file = "lib/Transforms/InstCombine/InstCombineTables.td",
+ td_srcs = glob([
+ "include/llvm/CodeGen/*.td",
+ "include/llvm/IR/Intrinsics*.td",
+ ]) + ["include/llvm/TableGen/SearchableTable.td"],
+)
+
+gentbl(
+ name = "intrinsic_enums_gen",
+ tbl_outs = [("-gen-intrinsic-enums", "include/llvm/IR/IntrinsicEnums.inc")],
+ tblgen = ":llvm-tblgen",
+ td_file = "include/llvm/IR/Intrinsics.td",
+ td_srcs = glob([
+ "include/llvm/CodeGen/*.td",
+ "include/llvm/IR/Intrinsics*.td",
+ ]),
+)
+
+gentbl(
+ name = "intrinsics_impl_gen",
+ tbl_outs = [("-gen-intrinsic-impl", "include/llvm/IR/IntrinsicImpl.inc")],
+ tblgen = ":llvm-tblgen",
+ td_file = "include/llvm/IR/Intrinsics.td",
+ td_srcs = glob([
+ "include/llvm/CodeGen/*.td",
+ "include/llvm/IR/Intrinsics*.td",
+ ]),
+)
+
# Binary targets used by Tensorflow.
cc_binary(
name = "llvm-tblgen",
@@ -299,11 +199,8 @@ cc_binary(
"utils/TableGen/*.cpp",
"utils/TableGen/*.h",
]),
- linkopts = [
- "-lm",
- "-ldl",
- "-lpthread",
- ],
+ copts = llvm_copts,
+ linkopts = llvm_linkopts,
stamp = 0,
deps = [
":config",
@@ -319,11 +216,8 @@ cc_binary(
"utils/FileCheck/*.cpp",
"utils/FileCheck/*.h",
]),
- linkopts = [
- "-ldl",
- "-lm",
- "-lpthread",
- ],
+ copts = llvm_copts,
+ linkopts = llvm_linkopts,
stamp = 0,
deps = [":support"],
)
@@ -494,7 +388,7 @@ cc_library(
"include/llvm/Target/AArch64/AsmParser/*.inc",
"lib/Target/AArch64/AsmParser/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_desc",
":aarch64_info",
@@ -519,7 +413,7 @@ cc_library(
"include/llvm/Target/AArch64/InstPrinter/*.inc",
"lib/Target/AArch64/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_target_gen",
":aarch64_utils",
@@ -542,7 +436,7 @@ cc_library(
"include/llvm/Target/AArch64/*.inc",
"lib/Target/AArch64/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_asm_printer",
":aarch64_desc",
@@ -575,14 +469,15 @@ cc_library(
"include/llvm/Target/AArch64/MCTargetDesc/*.inc",
"lib/Target/AArch64/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_asm_printer",
":aarch64_info",
":aarch64_target_gen",
":attributes_gen",
":config",
- ":intrinsics_gen",
+ ":intrinsic_enums_gen",
+ ":intrinsics_impl_gen",
":mc",
":support",
],
@@ -601,7 +496,7 @@ cc_library(
"include/llvm/Target/AArch64/Disassembler/*.inc",
"lib/Target/AArch64/Disassembler/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_desc",
":aarch64_info",
@@ -629,7 +524,7 @@ cc_library(
"lib/Target/AArch64/AArch64*.h",
"lib/Target/AArch64/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":code_gen",
":config",
@@ -652,7 +547,7 @@ cc_library(
"include/llvm/Target/AArch64/Utils/*.inc",
"lib/Target/AArch64/Utils/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AArch64"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AArch64"],
deps = [
":aarch64_target_gen",
":config",
@@ -674,6 +569,7 @@ cc_library(
"include/llvm/Transforms/AggressiveInstCombine/*.def",
"include/llvm/Transforms/AggressiveInstCombine/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -698,6 +594,7 @@ cc_library(
"include/llvm/Analysis/*.def",
"include/llvm/Analysis/*.inc",
]),
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -721,7 +618,7 @@ cc_library(
"include/llvm/Target/AMDGPU/MCTargetDesc/*.inc",
"lib/Target/AMDGPU/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_asm_printer",
":amdgpu_info",
@@ -746,7 +643,7 @@ cc_library(
"include/llvm/Target/AMDGPU/Disassembler/*.inc",
"lib/Target/AMDGPU/Disassembler/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_desc",
":amdgpu_info",
@@ -771,7 +668,7 @@ cc_library(
"include/llvm/Target/AMDGPU/TargetInfo/*.inc",
"lib/Target/AMDGPU/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_target_gen",
":config",
@@ -793,7 +690,7 @@ cc_library(
"include/llvm/Target/AMDGPU/Utils/*.inc",
"lib/Target/AMDGPU/Utils/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_target_gen",
":config",
@@ -816,7 +713,7 @@ cc_library(
"include/llvm/Target/AMDGPU/AsmParser/*.inc",
"lib/Target/AMDGPU/AsmParser/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_desc",
":amdgpu_info",
@@ -841,7 +738,7 @@ cc_library(
"include/llvm/Target/AMDGPU/InstPrinter/*.inc",
"lib/Target/AMDGPU/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_utils",
":config",
@@ -863,7 +760,7 @@ cc_library(
"include/llvm/Target/AMDGPU/*.inc",
"lib/Target/AMDGPU/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/AMDGPU"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/AMDGPU"],
deps = [
":amdgpu_asm_printer",
":amdgpu_desc",
@@ -899,7 +796,7 @@ cc_library(
"include/llvm/Target/ARM/AsmParser/*.inc",
"lib/Target/ARM/AsmParser/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_desc",
":arm_info",
@@ -925,7 +822,7 @@ cc_library(
"lib/Target/ARM/*.h",
"lib/Target/ARM/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_info",
":arm_target_gen",
@@ -949,7 +846,7 @@ cc_library(
"include/llvm/Target/ARM/*.inc",
"lib/Target/ARM/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":analysis",
":arm_asm_printer",
@@ -966,6 +863,7 @@ cc_library(
":selection_dag",
":support",
":target",
+ ":transform_utils",
],
)
@@ -984,14 +882,15 @@ cc_library(
"include/llvm/Target/ARM/MCTargetDesc/*.inc",
"lib/Target/ARM/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_asm_printer",
":arm_info",
":arm_target_gen",
":attributes_gen",
":config",
- ":intrinsics_gen",
+ ":intrinsic_enums_gen",
+ ":intrinsics_impl_gen",
":mc",
":mc_disassembler",
":support",
@@ -1011,7 +910,7 @@ cc_library(
"include/llvm/Target/ARM/Disassembler/*.inc",
"lib/Target/ARM/Disassembler/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_desc",
":arm_info",
@@ -1036,7 +935,7 @@ cc_library(
"include/llvm/Target/ARM/TargetInfo/*.inc",
"lib/Target/ARM/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_target_gen",
":config",
@@ -1059,7 +958,7 @@ cc_library(
"include/llvm/Target/ARM/Utils/*.inc",
"lib/Target/ARM/Utils/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/ARM"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/ARM"],
deps = [
":arm_target_gen",
":config",
@@ -1081,6 +980,7 @@ cc_library(
"include/llvm/AsmParser/*.def",
"include/llvm/AsmParser/*.inc",
]),
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1103,6 +1003,7 @@ cc_library(
"include/llvm/CodeGen/AsmPrinter/*.inc",
"lib/CodeGen/AsmPrinter/*.def",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":binary_format",
@@ -1133,6 +1034,7 @@ cc_library(
"include/llvm/BinaryFormat/ELFRelocs/*.def",
"include/llvm/BinaryFormat/WasmRelocs/*.def",
]),
+ copts = llvm_copts,
deps = [
":config",
":support",
@@ -1153,6 +1055,7 @@ cc_library(
"include/llvm/Bitcode/Reader/*.inc",
"include/llvm/Bitcode/BitstreamReader.h",
]),
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1176,6 +1079,7 @@ cc_library(
"include/llvm/Bitcode/BitcodeWriterPass.h",
"include/llvm/Bitcode/BitstreamWriter.h",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1200,6 +1104,7 @@ cc_library(
"include/llvm/CodeGen/*.inc",
"include/llvm/CodeGen/**/*.h",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":bit_reader",
@@ -1237,12 +1142,14 @@ cc_library(
"include/llvm/*.h",
"include/llvm/Analysis/*.def",
]),
+ copts = llvm_copts,
deps = [
":attributes_compat_gen",
":attributes_gen",
":binary_format",
":config",
- ":intrinsics_gen",
+ ":intrinsic_enums_gen",
+ ":intrinsics_impl_gen",
":support",
],
)
@@ -1260,6 +1167,7 @@ cc_library(
"include/llvm/DebugInfo/CodeView/*.def",
"include/llvm/DebugInfo/CodeView/*.inc",
]),
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1281,6 +1189,7 @@ cc_library(
"include/llvm/DebugInfo/MSF/*.def",
"include/llvm/DebugInfo/MSF/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":support",
@@ -1300,6 +1209,7 @@ cc_library(
"include/llvm/Demangle/*.def",
"include/llvm/Demangle/*.inc",
]),
+ copts = llvm_copts,
deps = [":config"],
)
@@ -1316,6 +1226,7 @@ cc_library(
"include/llvm/ExecutionEngine/*.def",
"include/llvm/ExecutionEngine/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1340,6 +1251,7 @@ cc_library(
"include/llvm/CodeGen/GlobalISel/*.def",
"include/llvm/CodeGen/GlobalISel/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":code_gen",
@@ -1369,6 +1281,7 @@ cc_library(
"include/llvm/Transforms/InstrProfiling.h",
"include/llvm/Transforms/PGOInstrumentation.h",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1393,10 +1306,12 @@ cc_library(
"include/llvm/Transforms/InstCombine/*.def",
"include/llvm/Transforms/InstCombine/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
":core",
+ ":instcombine_transforms_gen",
":support",
":transform_utils",
],
@@ -1418,6 +1333,7 @@ cc_library(
"include/llvm/Transforms/IPO/*.def",
"include/llvm/Transforms/IPO/*.inc",
]),
+ copts = llvm_copts,
deps = [
":aggressive_inst_combine",
":analysis",
@@ -1451,6 +1367,7 @@ cc_library(
"include/llvm/IRReader/*.def",
"include/llvm/IRReader/*.inc",
]),
+ copts = llvm_copts,
deps = [
":asm_parser",
":bit_reader",
@@ -1473,6 +1390,7 @@ cc_library(
"include/llvm/Linker/*.def",
"include/llvm/Linker/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1494,6 +1412,7 @@ cc_library(
"include/llvm/MC/*.def",
"include/llvm/MC/*.inc",
]),
+ copts = llvm_copts,
deps = [
":binary_format",
":config",
@@ -1515,6 +1434,7 @@ cc_library(
"include/llvm/MC/MCDisassembler/*.def",
"include/llvm/MC/MCDisassembler/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1535,6 +1455,7 @@ cc_library(
"include/llvm/MC/MCParser/*.def",
"include/llvm/MC/MCParser/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1555,7 +1476,7 @@ cc_library(
"include/llvm/Target/NVPTX/InstPrinter/*.inc",
"lib/Target/NVPTX/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/NVPTX"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":attributes_gen",
@@ -1579,7 +1500,7 @@ cc_library(
"include/llvm/Target/NVPTX/*.inc",
"lib/Target/NVPTX/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/NVPTX"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
":analysis",
":asm_printer",
@@ -1613,7 +1534,7 @@ cc_library(
"include/llvm/Target/NVPTX/MCTargetDesc/*.inc",
"lib/Target/NVPTX/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/NVPTX"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":config",
@@ -1639,7 +1560,7 @@ cc_library(
"lib/Target/NVPTX/NVPTX.h",
"lib/Target/NVPTX/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/NVPTX"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/NVPTX"],
deps = [
"nvptx_target_gen",
":attributes_gen",
@@ -1663,6 +1584,7 @@ cc_library(
"include/llvm/Object/*.def",
"include/llvm/Object/*.inc",
]),
+ copts = llvm_copts,
deps = [
":binary_format",
":bit_reader",
@@ -1688,6 +1610,7 @@ cc_library(
"include/llvm/Transforms/ObjCARC/*.def",
"include/llvm/Transforms/ObjCARC/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -1710,13 +1633,16 @@ cc_library(
"include/llvm/ExecutionEngine/Orc/*.def",
"include/llvm/ExecutionEngine/Orc/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":core",
":execution_engine",
+ ":mc",
":object",
":runtime_dyld",
":support",
+ ":target",
":transform_utils",
],
)
@@ -1734,7 +1660,7 @@ cc_library(
"include/llvm/Target/PowerPC/AsmParser/*.inc",
"lib/Target/PowerPC/AsmParser/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":config",
":mc",
@@ -1758,11 +1684,12 @@ cc_library(
"include/llvm/Target/PowerPC/InstPrinter/*.inc",
"lib/Target/PowerPC/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
- ":intrinsics_gen",
+ ":intrinsic_enums_gen",
+ ":intrinsics_impl_gen",
":mc",
":powerpc_info",
":powerpc_target_gen",
@@ -1783,7 +1710,7 @@ cc_library(
"include/llvm/Target/PowerPC/*.inc",
"lib/Target/PowerPC/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":analysis",
":asm_printer",
@@ -1815,11 +1742,12 @@ cc_library(
"include/llvm/Target/PowerPC/MCTargetDesc/*.inc",
"lib/Target/PowerPC/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
- ":intrinsics_gen",
+ ":intrinsic_enums_gen",
+ ":intrinsics_impl_gen",
":mc",
":powerpc_asm_printer",
":powerpc_info",
@@ -1841,7 +1769,7 @@ cc_library(
"include/llvm/Target/PowerPC/Disassembler/*.inc",
"lib/Target/PowerPC/Disassembler/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":config",
":mc_disassembler",
@@ -1865,12 +1793,11 @@ cc_library(
"lib/Target/PowerPC/PPC*.h",
"lib/Target/PowerPC/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/PowerPC"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/PowerPC"],
deps = [
":attributes_gen",
":config",
":core",
- ":intrinsics_gen",
":powerpc_target_gen",
":support",
":target",
@@ -1890,6 +1817,7 @@ cc_library(
"include/llvm/ProfileData/*.def",
"include/llvm/ProfileData/*.inc",
]),
+ copts = llvm_copts,
deps = [
":config",
":core",
@@ -1918,6 +1846,7 @@ cc_library(
"include/llvm/ExecutionEngine/RTDyldMemoryManager.h",
"include/llvm/ExecutionEngine/RuntimeDyld*.h",
]),
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -1945,6 +1874,7 @@ cc_library(
"include/llvm/Transforms/IPO.h",
"include/llvm/Transforms/IPO/SCCP.h",
]),
+ copts = llvm_copts,
deps = [
":aggressive_inst_combine",
":analysis",
@@ -1970,6 +1900,7 @@ cc_library(
"include/llvm/CodeGen/SelectionDAG/*.def",
"include/llvm/CodeGen/SelectionDAG/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":code_gen",
@@ -1988,14 +1919,12 @@ cc_library(
"lib/Support/*.c",
"lib/Support/*.cpp",
"lib/Support/*.inc",
- "lib/Support/Unix/*.inc",
- "lib/Support/Unix/*.h",
"include/llvm-c/*.h",
"include/llvm/CodeGen/MachineValueType.h",
"include/llvm/BinaryFormat/COFF.h",
"include/llvm/BinaryFormat/MachO.h",
"lib/Support/*.h",
- ]),
+ ] + llvm_support_platform_specific_srcs_glob),
hdrs = glob([
"include/llvm/Support/*.h",
"include/llvm/Support/*.def",
@@ -2007,6 +1936,7 @@ cc_library(
"include/llvm/BinaryFormat/MachO.def",
"include/llvm/Support/VCSRevision.h",
],
+ copts = llvm_copts,
deps = [
":config",
":demangle",
@@ -2029,6 +1959,7 @@ cc_library(
"include/llvm/TableGen/*.inc",
"include/llvm/Target/*.def",
]),
+ copts = llvm_copts,
deps = [
":config",
":mc",
@@ -2054,6 +1985,7 @@ cc_library(
"include/llvm/CodeGen/*.def",
"include/llvm/CodeGen/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2078,6 +2010,7 @@ cc_library(
"include/llvm/Transforms/Utils/*.def",
"include/llvm/Transforms/Utils/*.inc",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2101,6 +2034,7 @@ cc_library(
"include/llvm/Transforms/Vectorize/*.inc",
"include/llvm/Transforms/Vectorize.h",
]),
+ copts = llvm_copts,
deps = [
":analysis",
":config",
@@ -2124,7 +2058,7 @@ cc_library(
"include/llvm/Target/X86/AsmParser/*.inc",
"lib/Target/X86/AsmParser/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2149,7 +2083,7 @@ cc_library(
"include/llvm/Target/X86/InstPrinter/*.inc",
"lib/Target/X86/InstPrinter/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2173,7 +2107,7 @@ cc_library(
"include/llvm/Target/X86/*.inc",
"lib/Target/X86/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":analysis",
":asm_printer",
@@ -2206,7 +2140,7 @@ cc_library(
"include/llvm/Target/X86/MCTargetDesc/*.inc",
"lib/Target/X86/MCTargetDesc/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2231,7 +2165,7 @@ cc_library(
"include/llvm/Target/X86/Disassembler/*.inc",
"lib/Target/X86/Disassembler/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc_disassembler",
@@ -2254,7 +2188,7 @@ cc_library(
"include/llvm/Target/X86/TargetInfo/*.inc",
"lib/Target/X86/TargetInfo/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":config",
":mc",
@@ -2276,7 +2210,7 @@ cc_library(
"include/llvm/Target/X86/Utils/*.inc",
"lib/Target/X86/Utils/*.h",
]),
- copts = ["-Iexternal/llvm/lib/Target/X86"],
+ copts = llvm_copts + ["-Iexternal/llvm/lib/Target/X86"],
deps = [
":code_gen",
":config",
diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl
index 0efcf319bd..dfdacafceb 100644
--- a/third_party/llvm/llvm.bzl
+++ b/third_party/llvm/llvm.bzl
@@ -105,3 +105,143 @@ def expand_cmake_vars(name, src, dst, cmake_vars):
"< $< > $@")
)
+# TODO(phawkins): the set of CMake variables was hardcoded for expediency.
+# However, we should really detect many of these via configure-time tests.
+
+# The set of CMake variables common to all targets.
+cmake_vars = {
+ # Headers
+ "HAVE_DIRENT_H": 1,
+ "HAVE_DLFCN_H": 1,
+ "HAVE_ERRNO_H": 1,
+ "HAVE_EXECINFO_H": 1,
+ "HAVE_FCNTL_H": 1,
+ "HAVE_INTTYPES_H": 1,
+ "HAVE_PTHREAD_H": 1,
+ "HAVE_SIGNAL_H": 1,
+ "HAVE_STDINT_H": 1,
+ "HAVE_SYS_IOCTL_H": 1,
+ "HAVE_SYS_MMAN_H": 1,
+ "HAVE_SYS_PARAM_H": 1,
+ "HAVE_SYS_RESOURCE_H": 1,
+ "HAVE_SYS_STAT_H": 1,
+ "HAVE_SYS_TIME_H": 1,
+ "HAVE_SYS_TYPES_H": 1,
+ "HAVE_TERMIOS_H": 1,
+ "HAVE_UNISTD_H": 1,
+ "HAVE_ZLIB_H": 1,
+
+ # Features
+ "HAVE_BACKTRACE": 1,
+ "BACKTRACE_HEADER": "execinfo.h",
+ "HAVE_DLOPEN": 1,
+ "HAVE_FUTIMES": 1,
+ "HAVE_GETCWD": 1,
+ "HAVE_GETPAGESIZE": 1,
+ "HAVE_GETRLIMIT": 1,
+ "HAVE_GETRUSAGE": 1,
+ "HAVE_GETTIMEOFDAY": 1,
+ "HAVE_INT64_T": 1,
+ "HAVE_ISATTY": 1,
+ "HAVE_LIBEDIT": 1,
+ "HAVE_LIBPTHREAD": 1,
+ "HAVE_LIBZ": 1,
+ "HAVE_MKDTEMP": 1,
+ "HAVE_MKSTEMP": 1,
+ "HAVE_MKTEMP": 1,
+ "HAVE_PREAD": 1,
+ "HAVE_PTHREAD_GETSPECIFIC": 1,
+ "HAVE_PTHREAD_MUTEX_LOCK": 1,
+ "HAVE_PTHREAD_RWLOCK_INIT": 1,
+ "HAVE_REALPATH": 1,
+ "HAVE_SBRK": 1,
+ "HAVE_SETENV": 1,
+ "HAVE_SETRLIMIT": 1,
+ "HAVE_SIGALTSTACK": 1,
+ "HAVE_STRERROR": 1,
+ "HAVE_STRERROR_R": 1,
+ "HAVE_STRTOLL": 1,
+ "HAVE_SYSCONF": 1,
+ "HAVE_UINT64_T": 1,
+ "HAVE__UNWIND_BACKTRACE": 1,
+
+ # LLVM features
+ "ENABLE_BACKTRACES": 1,
+ "LLVM_BINDIR": "/dev/null",
+ "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0,
+ "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0,
+ "LLVM_ENABLE_THREADS": 1,
+ "LLVM_ENABLE_ZLIB": 1,
+ "LLVM_HAS_ATOMICS": 1,
+ "LLVM_INCLUDEDIR": "/dev/null",
+ "LLVM_INFODIR": "/dev/null",
+ "LLVM_MANDIR": "/dev/null",
+ "LLVM_NATIVE_TARGET": 1,
+ "LLVM_NATIVE_TARGETINFO": 1,
+ "LLVM_NATIVE_TARGETMC": 1,
+ "LLVM_NATIVE_ASMPRINTER": 1,
+ "LLVM_NATIVE_ASMPARSER": 1,
+ "LLVM_NATIVE_DISASSEMBLER": 1,
+ "LLVM_ON_UNIX": 1,
+ "LLVM_PREFIX": "/dev/null",
+ "LLVM_VERSION_MAJOR": 0,
+ "LLVM_VERSION_MINOR": 0,
+ "LLVM_VERSION_PATCH": 0,
+ "LTDL_SHLIB_EXT": ".so",
+ "PACKAGE_NAME": "llvm",
+ "PACKAGE_STRING": "llvm tensorflow-trunk",
+ "PACKAGE_VERSION": "tensorflow-trunk",
+ "RETSIGTYPE": "void",
+}
+
+# CMake variables specific to the Linux platform
+linux_cmake_vars = {
+ "HAVE_MALLOC_H": 1,
+ "HAVE_LINK_H": 1,
+ "HAVE_MALLINFO": 1,
+ "HAVE_FUTIMENS": 1,
+}
+
+# CMake variables specific to the Darwin (Mac OS X) platform.
+darwin_cmake_vars = {
+ "HAVE_MALLOC_MALLOC_H": 1,
+}
+
+# Select a set of CMake variables based on the platform.
+# TODO(phawkins): use a better method to select the right host triple, rather
+# than hardcoding x86_64.
+llvm_all_cmake_vars = select({
+ "@org_tensorflow//tensorflow:darwin": cmake_var_string(
+ cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
+ darwin_cmake_vars),
+ "@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string(
+ cmake_vars +
+ llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
+ linux_cmake_vars,
+ ),
+ "//conditions:default": cmake_var_string(
+ cmake_vars +
+ llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +
+ linux_cmake_vars),
+
+})
+
+llvm_linkopts = ["-ldl", "-lm", "-lpthread"]
+
+llvm_defines = [
+ "LLVM_ENABLE_STATS",
+ "__STDC_LIMIT_MACROS",
+ "__STDC_CONSTANT_MACROS",
+ "__STDC_FORMAT_MACROS",
+ "_DEBUG",
+ "LLVM_BUILD_GLOBAL_ISEL",
+]
+
+llvm_copts = []
+
+# Platform specific sources for libSupport.
+
+llvm_support_platform_specific_srcs_glob = [
+ "lib/Support/Unix/*.inc",
+ "lib/Support/Unix/*.h",
+]
diff --git a/third_party/mkl/LICENSE b/third_party/mkl/LICENSE
new file mode 100644
index 0000000000..9c8f3ea087
--- /dev/null
+++ b/third_party/mkl/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. \ No newline at end of file
diff --git a/third_party/nanopb.BUILD b/third_party/nanopb.BUILD
new file mode 100644
index 0000000000..d21866911b
--- /dev/null
+++ b/third_party/nanopb.BUILD
@@ -0,0 +1,23 @@
+# Description:
+# Nanopb, a tiny ANSI C protobuf implementation for use on embedded devices.
+
+licenses(["notice"]) # zlib license
+
+exports_files(["LICENSE.txt"])
+
+cc_library(
+ name = "nanopb",
+ srcs = [
+ "pb_common.c",
+ "pb_decode.c",
+ "pb_encode.c",
+ ],
+ hdrs = [
+ "pb.h",
+ "pb_common.h",
+ "pb_decode.h",
+ "pb_encode.h",
+ ],
+ includes = ["."],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD
index 341d58068b..89330eac54 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm.BUILD
@@ -8,45 +8,93 @@ exports_files(["LICENSE"])
cc_binary(
name = "nasm",
srcs = [
- "assemble.c",
- "assemble.h",
- "compiler.h",
- "crc64.c",
- "directiv.c",
- "directiv.h",
- "disp8.c",
- "disp8.h",
- "eval.c",
- "eval.h",
- "exprlib.c",
- "float.c",
- "float.h",
- "hashtbl.c",
- "hashtbl.h",
- "iflag.c",
- "iflag.h",
- "iflaggen.h",
- "ilog2.c",
- "insns.h",
- "insnsa.c",
- "insnsb.c",
- "insnsi.h",
- "labels.c",
- "labels.h",
- "lib/strlcpy.c",
- "listing.c",
- "listing.h",
- "macros.c",
- "md5.h",
- "md5c.c",
- "nasm.c",
- "nasm.h",
- "nasmlib.c",
- "nasmlib.h",
- "opflags.h",
+ "asm/assemble.c",
+ "asm/assemble.h",
+ "asm/directbl.c",
+ "asm/directiv.c",
+ "asm/directiv.h",
+ "asm/error.c",
+ "asm/eval.c",
+ "asm/eval.h",
+ "asm/exprdump.c",
+ "asm/exprlib.c",
+ "asm/float.c",
+ "asm/float.h",
+ "asm/labels.c",
+ "asm/listing.c",
+ "asm/listing.h",
+ "asm/nasm.c",
+ "asm/parser.c",
+ "asm/parser.h",
+ "asm/pptok.c",
+ "asm/pptok.h",
+ "asm/pragma.c",
+ "asm/preproc.c",
+ "asm/preproc.h",
+ "asm/preproc-nop.c",
+ "asm/quote.c",
+ "asm/quote.h",
+ "asm/rdstrnum.c",
+ "asm/segalloc.c",
+ "asm/stdscan.c",
+ "asm/stdscan.h",
+ "asm/strfunc.c",
+ "asm/tokens.h",
+ "asm/tokhash.c",
+ "common/common.c",
+ "config/unknown.h",
+ "disasm/disasm.c",
+ "disasm/disasm.h",
+ "disasm/sync.c",
+ "disasm/sync.h",
+ "include/compiler.h",
+ "include/disp8.h",
+ "include/error.h",
+ "include/hashtbl.h",
+ "include/iflag.h",
+ "include/insns.h",
+ "include/labels.h",
+ "include/md5.h",
+ "include/nasm.h",
+ "include/nasmint.h",
+ "include/nasmlib.h",
+ "include/opflags.h",
+ "include/perfhash.h",
+ "include/raa.h",
+ "include/rbtree.h",
+ "include/rdoff.h",
+ "include/saa.h",
+ "include/strlist.h",
+ "include/tables.h",
+ "include/ver.h",
+ "macros/macros.c",
+ "nasmlib/badenum.c",
+ "nasmlib/bsi.c",
+ "nasmlib/crc64.c",
+ "nasmlib/file.c",
+ "nasmlib/file.h",
+ "nasmlib/filename.c",
+ "nasmlib/hashtbl.c",
+ "nasmlib/ilog2.c",
+ "nasmlib/malloc.c",
+ "nasmlib/md5c.c",
+ "nasmlib/mmap.c",
+ "nasmlib/path.c",
+ "nasmlib/perfhash.c",
+ "nasmlib/raa.c",
+ "nasmlib/rbtree.c",
+ "nasmlib/readnum.c",
+ "nasmlib/realpath.c",
+ "nasmlib/saa.c",
+ "nasmlib/srcfile.c",
+ "nasmlib/string.c",
+ "nasmlib/strlist.c",
+ "nasmlib/ver.c",
+ "nasmlib/zerobuf.c",
"output/codeview.c",
"output/dwarf.h",
"output/elf.h",
+ "output/legacy.c",
"output/nulldbg.c",
"output/nullout.c",
"output/outaout.c",
@@ -56,9 +104,6 @@ cc_binary(
"output/outdbg.c",
"output/outelf.c",
"output/outelf.h",
- "output/outelf32.c",
- "output/outelf64.c",
- "output/outelfx32.c",
"output/outform.c",
"output/outform.h",
"output/outieee.c",
@@ -69,35 +114,31 @@ cc_binary(
"output/outrdf2.c",
"output/pecoff.h",
"output/stabs.h",
- "parser.c",
- "parser.h",
- "pptok.c",
- "pptok.h",
- "preproc.c",
- "preproc.h",
- "preproc-nop.c",
- "quote.c",
- "quote.h",
- "raa.c",
- "raa.h",
- "rbtree.c",
- "rbtree.h",
- "rdoff/rdoff.h",
- "realpath.c",
- "regflags.c",
- "regs.h",
- "regvals.c",
- "saa.c",
- "saa.h",
- "srcfile.c",
- "stdscan.c",
- "stdscan.h",
- "strfunc.c",
- "tables.h",
- "tokens.h",
- "tokhash.c",
- "ver.c",
+ "stdlib/snprintf.c",
+ "stdlib/strlcpy.c",
+ "stdlib/strnlen.c",
+ "stdlib/vsnprintf.c",
"version.h",
+ "x86/disp8.c",
+ "x86/iflag.c",
+ "x86/iflaggen.h",
+ "x86/insnsa.c",
+ "x86/insnsb.c",
+ "x86/insnsd.c",
+ "x86/insnsi.h",
+ "x86/insnsn.c",
+ "x86/regdis.c",
+ "x86/regdis.h",
+ "x86/regflags.c",
+ "x86/regs.c",
+ "x86/regs.h",
+ "x86/regvals.c",
+ ],
+ includes = [
+ "asm",
+ "include",
+ "output",
+ "x86",
],
copts = select({
":windows": [],
@@ -110,7 +151,10 @@ cc_binary(
defines = select({
":windows": [],
":windows_msvc": [],
- "//conditions:default": ["HAVE_SNPRINTF"],
+ "//conditions:default": [
+ "HAVE_SNPRINTF",
+ "HAVE_SYS_TYPES_H",
+ ],
}),
visibility = ["@jpeg//:__pkg__"],
)
diff --git a/third_party/repo.bzl b/third_party/repo.bzl
index cb67d3e961..9cee1fcc4b 100644
--- a/third_party/repo.bzl
+++ b/third_party/repo.bzl
@@ -16,7 +16,6 @@
_SINGLE_URL_WHITELIST = depset([
"arm_compiler",
- "ortools_archive",
])
def _is_windows(ctx):
diff --git a/third_party/sqlite.BUILD b/third_party/sqlite.BUILD
index 6da7953589..2876f305f1 100644
--- a/third_party/sqlite.BUILD
+++ b/third_party/sqlite.BUILD
@@ -5,6 +5,7 @@ licenses(["unencumbered"]) # Public Domain
SQLITE_COPTS = [
"-Os",
+ "-DSQLITE_ENABLE_JSON1",
"-DHAVE_DECL_STRERROR_R=1",
"-DHAVE_STDINT_H=1",
"-DHAVE_INTTYPES_H=1",
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
new file mode 100644
index 0000000000..fc3183a754
--- /dev/null
+++ b/third_party/toolchains/BUILD
@@ -0,0 +1,22 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+# Platform for use with remote execution with
+# custom container based off RBE Ubuntu16_04
+# http://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04
+# Built with //tensorflow/tools/ci_build/Dockerfile.rbe.cpu
+platform(
+ name = "rbe_ubuntu16_04-tf",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains//constraints:xenial",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/asci-toolchain/nosla-ubuntu16_04-tf@sha256:800a7b68cabef15419695c188ed33ed70adf678c2371b97b236f3ae26c38274d"
+ }""",
+)
diff --git a/third_party/toolchains/clang6/CROSSTOOL.tpl b/third_party/toolchains/clang6/CROSSTOOL.tpl
index 6b7e5a8808..ffba9850bb 100644
--- a/third_party/toolchains/clang6/CROSSTOOL.tpl
+++ b/third_party/toolchains/clang6/CROSSTOOL.tpl
@@ -76,9 +76,6 @@ toolchain {
# This adds a little bit more durability to our Clang build.
#
- # At the moment, this only only be needed for:
- # - add_boringssl_s390x.patch: --Wa,--noexecstack
- #
# Folks who do maintenance work on TF Bazel Clang should consider
# commenting out these lines, while doing that work, to gain a better
# understanding of what the intersection of support looks like between GCC
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 1c1e6afb65..3559375d5c 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -36,8 +36,6 @@ build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
build:cuda_clang --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda_clang --define=using_cuda=true --define=using_cuda_clang=true --define=using_clang=true
-build:win-cuda --define=using_cuda=true --define=using_cuda_nvcc=true
-
build:mkl --define=using_mkl=true
build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain